CLOVER🍀

That was when it all began.

Nettyで簡易HTTPサーバを書く

久しぶりにNettyネタです。まあ、今回は写経ものですが。

ところで、今年の1月頃にNettyを触っていた時はバージョン3.3.0.Finalだったのですが、今はすでにバージョン3.4.2.Finalとなっています。上がるの早いなぁ…。

んで、今回はHTTPサーバを書いてみたいと思います。なんてことはなくて、「org.jboss.netty.example.http.file」パッケージにあるサンプルをScalaで書き直してみるものです。

まずbuild.sbt。

// sbt build definition for the netty-http-file sample project.
name := "netty-http-file"

version := "0.0.1"

// The sources in this post are compiled with Scala 2.9.2.
scalaVersion := "2.9.2"

organization := "littlewings"

// Netty 3.x, whose classes live in the org.jboss.netty.* packages used below.
libraryDependencies += "io.netty" % "netty" % "3.4.2.Final"

そういえば、sbtでScala 2.9.2を使うの、これが初めてのような…。

あとは、写経したファイルを載せていきます。
HttpStaticFileServer.scala

import java.net.InetSocketAddress
import java.util.concurrent.Executors

import org.jboss.netty.bootstrap.ServerBootstrap
import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory

/** Command-line entry point for the static file server. */
object HttpStaticFileServer {
  /** Port used when only the document root is given on the command line. */
  private val DefaultPort = 8080

  /**
   * Starts the server.
   *
   * Accepted arguments are `port documentRoot` or just `documentRoot` (port 8080).
   * Any other argument list, or a port that is not an integer in 0..65535, prints
   * a usage line to stderr and exits with status 1 instead of crashing.
   *
   * @param args command-line arguments as described above
   */
  def main(args: Array[String]): Unit = {
    val (port, documentRoot) =
      args.toList match {
        case portAsString :: root :: Nil => (parsePort(portAsString).getOrElse(usageAndExit()), root)
        case root :: Nil => (DefaultPort, root)
        case _ => usageAndExit()
      }

    new HttpStaticFileServer(port, documentRoot).run()
  }

  /** Parses a TCP port number; `None` for non-numeric or out-of-range input. */
  private def parsePort(s: String): Option[Int] =
    try Some(s.toInt).filter(p => p >= 0 && p <= 65535)
    catch { case e: NumberFormatException => None }

  /** Prints the expected command line to stderr and terminates the JVM with status 1. */
  private def usageAndExit(): Nothing = {
    System.err.println("Usage: HttpStaticFileServer [port] documentRoot")
    sys.exit(1)
  }
}

/**
 * Netty 3 NIO server that listens on `port` and serves files relative to `documentRoot`.
 *
 * @param port         TCP port to bind
 * @param documentRoot directory handed to the pipeline factory as the document root
 */
class HttpStaticFileServer(port: Int, documentRoot: String) {
  /** Builds the server bootstrap, installs the HTTP pipeline, logs a start-up line and binds. */
  def run(): Unit = {
    // Boss pool first, worker pool second, both unbounded cached pools.
    val channelFactory =
      new NioServerSocketChannelFactory(Executors.newCachedThreadPool, Executors.newCachedThreadPool)
    val bootstrap = new ServerBootstrap(channelFactory)
    bootstrap.setPipelineFactory(new HttpStaticFileServerPipelineFactory(documentRoot))

    val startedAt = new java.util.Date
    println("Boot HttpStaticFileServer[%d] root[%s] %s".format(port, documentRoot, startedAt))

    val address = new InetSocketAddress(port)
    bootstrap.bind(address)
  }
}

HttpStaticFileServerPipelineFactory.scala

import org.jboss.netty.channel.{ChannelPipeline, ChannelPipelineFactory, Channels}
import org.jboss.netty.handler.codec.http.{HttpChunkAggregator, HttpRequestDecoder, HttpResponseEncoder}
import org.jboss.netty.handler.stream.ChunkedWriteHandler

/**
 * Builds the per-connection Netty pipeline for the static file server.
 *
 * Handlers, in order: HTTP request decoder, chunk aggregator (maximum aggregated
 * content length 65536 bytes), HTTP response encoder, chunked writer (used when a
 * file body is streamed as chunks), and finally the file-serving handler.
 *
 * @param documentRoot directory the file-serving handler resolves request paths against
 */
class HttpStaticFileServerPipelineFactory(private val documentRoot: String) extends ChannelPipelineFactory {
  /** Returns a fresh pipeline, with a new handler instance for every channel. */
  @throws(classOf[Exception])
  def getPipeline: ChannelPipeline = {
    val pipeline = Channels.pipeline

    pipeline.addLast("decoder", new HttpRequestDecoder)
    pipeline.addLast("aggregator", new HttpChunkAggregator(65536))
    pipeline.addLast("encoder", new HttpResponseEncoder)
    // Correctly spelt name (was "chunckedWriter"), so that lookups such as
    // pipeline.get("chunkedWriter") find this handler.
    pipeline.addLast("chunkedWriter", new ChunkedWriteHandler)
    pipeline.addLast("handler", new HttpStaticFileServerHandler(documentRoot))

    pipeline
  }
}

HttpStaticFileServerHandler.scala

import java.io.{File, FileNotFoundException, RandomAccessFile, UnsupportedEncodingException}
import java.net.URLDecoder
import java.text.{DateFormat, SimpleDateFormat}
import java.util.{Calendar, Date, GregorianCalendar, Locale, TimeZone}

import javax.activation.MimetypesFileTypeMap

import org.jboss.netty.buffer.ChannelBuffers
import org.jboss.netty.channel.Channel
import org.jboss.netty.channel.ChannelFuture
import org.jboss.netty.channel.ChannelFutureListener
import org.jboss.netty.channel.ChannelFutureProgressListener
import org.jboss.netty.channel.ChannelHandlerContext
import org.jboss.netty.channel.DefaultFileRegion
import org.jboss.netty.channel.ExceptionEvent
import org.jboss.netty.channel.MessageEvent
import org.jboss.netty.channel.SimpleChannelUpstreamHandler
import org.jboss.netty.handler.codec.frame.TooLongFrameException
import org.jboss.netty.handler.codec.http.DefaultHttpResponse
import org.jboss.netty.handler.codec.http.HttpHeaders
import org.jboss.netty.handler.codec.http.HttpMethod
import org.jboss.netty.handler.codec.http.HttpRequest
import org.jboss.netty.handler.codec.http.HttpResponse
import org.jboss.netty.handler.codec.http.HttpResponseStatus
import org.jboss.netty.handler.codec.http.HttpVersion
import org.jboss.netty.handler.ssl.SslHandler
import org.jboss.netty.handler.stream.ChunkedFile
import org.jboss.netty.util.CharsetUtil

/** Constants and factories shared by every HttpStaticFileServerHandler instance. */
object HttpStaticFileServerHandler {
  /** SimpleDateFormat pattern for HTTP dates (Date, Expires, Last-Modified, If-Modified-Since). */
  val HTTP_DATE_FORMAT: String = "EEE, dd MMM yyyy HH:mm:ss zzz"

  /** Time zone ID in which outgoing HTTP dates are rendered. */
  val HTTP_DATE_GMT_TIMEZONE: String = "GMT"

  /** Lifetime in seconds advertised through the Expires and Cache-Control headers. */
  val HTTP_CACHE_SECONDS: Int = 60

  private val PlainTextUtf8 = "text/plain; charset=UTF-8"

  /**
   * Extension -> Content-Type overrides consulted before MimetypesFileTypeMap,
   * so that Scala sources and sbt files are shown as plain text in a browser.
   */
  val KNOWN_CONTENT_TYPES: Map[String, String] =
    Seq(".scala", ".sbt").map(_ -> PlainTextUtf8).toMap

  /**
   * Creates a new US-locale formatter for HTTP_DATE_FORMAT.
   * A fresh instance per call, because SimpleDateFormat is not thread-safe.
   */
  def newDateFormatter: DateFormat = {
    val formatter: DateFormat = new SimpleDateFormat(HTTP_DATE_FORMAT, Locale.US)
    formatter
  }
}

/**
 * Netty upstream handler that serves files below `documentRoot` for HTTP GET requests.
 *
 * A request is resolved to a file and answered in one of three ways:
 *  - 304 Not Modified, when the file is not newer than the client's If-Modified-Since date;
 *  - the file with Date/Expires/Cache-Control/Last-Modified/Content-Type headers;
 *  - a plain-text error status (405/403/404/400/500), after which the connection is closed.
 *
 * @param documentRoot directory that request paths are resolved against
 */
class HttpStaticFileServerHandler(private val documentRoot: String) extends SimpleChannelUpstreamHandler {
  /**
   * Handles one request.
   *
   * Each guard in the for-comprehension yields `None` to stop processing. The `orElse`
   * after it takes a by-name argument, so it runs only on `None` and sends the matching
   * error or 304 response.
   *
   * NOTE(review): only `HttpRequest` messages are matched; any other message type would
   * raise a MatchError. This presumably relies on the upstream HttpChunkAggregator
   * delivering whole requests.
   */
  @throws(classOf[Exception])
  override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = e.getMessage match {
    case request: HttpRequest =>
      for {
        _ <- booleanToOption(request.getMethod == HttpMethod.GET)
              .orElse(sendError(ctx, HttpResponseStatus.METHOD_NOT_ALLOWED))
        path <- sanitizeUri(request.getUri)
              .orElse(sendError(ctx, HttpResponseStatus.FORBIDDEN))
        file <- Some(new File(path))
        _ <- booleanToOption(file.exists && !file.isHidden)
              .orElse(sendError(ctx, HttpResponseStatus.NOT_FOUND))
        _ <- booleanToOption(file.isFile)
              .orElse(sendError(ctx, HttpResponseStatus.FORBIDDEN))
        _ <- isModified(request, file)
              .orElse(sendNotModified(ctx))
        raf <- randomAccessFile(file)
                .orElse(sendError(ctx, HttpResponseStatus.NOT_FOUND))
        response <- httpResponse(file, raf)
        ch <- Some(e.getChannel)
        _ <- Some(ch.write(response))
        writeFuture <- writeContents(ch, path, raf)
        _ <- booleanToOption(!HttpHeaders.isKeepAlive(request))
        _ <- Some(writeFuture.addListener(ChannelFutureListener.CLOSE))
      } yield Unit
  }

  /**
   * Writes the file body after the headers.
   *
   * Without an SslHandler in the pipeline, a DefaultFileRegion is written. Its resources
   * are released when the write completes, and progress is printed to stdout. With SSL,
   * the file is streamed as 8192-byte ChunkedFile chunks instead, presumably because a
   * file region would bypass the SslHandler's encryption.
   *
   * @return the future of the body write
   */
  private def writeContents(ch: Channel, path: String, raf: RandomAccessFile): Option[ChannelFuture] =
    ch.getPipeline.get(classOf[SslHandler]) match {
      case null =>
        val region = new DefaultFileRegion(raf.getChannel, 0, raf.length)
        val writeFuture = ch.write(region)
        writeFuture.addListener(new ChannelFutureProgressListener {
          def operationComplete(future: ChannelFuture): Unit = region.releaseExternalResources()
          def operationProgressed(future: ChannelFuture, amount: Long, current: Long, total: Long): Unit =
            printf("%s: %d / %d (+%d)%n", path, current, total, amount)
        })
        Some(writeFuture)
      case _ => Some(ch.write(new ChunkedFile(raf, 0, raf.length, 8192)))
    }

  /** Builds the 200 OK response headers (length, type, date and cache headers) for `file`. */
  private def httpResponse(file: File, raf: RandomAccessFile): Option[HttpResponse] = {
    val response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK)
    HttpHeaders.setContentLength(response, raf.length)
    setContentTypeHeader(response, file)
    setDateAndCacheHeaders(response, file)
    Some(response)
  }

  /**
   * Decides whether the file must be sent in full.
   *
   * Returns `Some(true)` when there is no usable If-Modified-Since header (absent, empty,
   * or not an HTTP date), or when the file changed after that date. Returns `None` when
   * the file is not newer than the date, which makes the caller answer 304 Not Modified.
   * Times are compared in whole seconds because HTTP dates have one-second resolution.
   */
  private def isModified(request: HttpRequest, file: File): Option[Boolean] = {
    val notModified =
      for {
        ifModifiedSince <- stringToOption(request.getHeader(HttpHeaders.Names.IF_MODIFIED_SINCE))
        ifModifiedSinceDate <- parseHttpDate(ifModifiedSince)
      } yield file.lastModified / 1000 <= ifModifiedSinceDate.getTime / 1000

    // A missing or unparsable header must mean "send the file", never 304.
    booleanToOption(!notModified.getOrElse(false))
  }

  /** Parses an HTTP date header value; `None` if it does not match HTTP_DATE_FORMAT. */
  private def parseHttpDate(value: String): Option[Date] =
    try Some(HttpStaticFileServerHandler.newDateFormatter.parse(value))
    catch { case e: java.text.ParseException => None }

  /** Opens `file` read-only; `None` if it cannot be opened (FileNotFoundException). */
  private def randomAccessFile(file: File): Option[RandomAccessFile] =
    try {
      Some(new RandomAccessFile(file, "r"))
    } catch {
      case e: FileNotFoundException => None
    }

  /** `Some(true)` for true, `None` for false, so a condition can act as a for-comprehension guard. */
  private def booleanToOption(b: Boolean): Option[Boolean] =
    b match {
      case true => Some(b)
      case false => None
    }

  /** `None` for a null or empty string, otherwise `Some(s)`. */
  private def stringToOption(s: String): Option[String] =
    s match {
      case null => None
      case "" => None
      case _ => Some(s)
    }

  /**
   * Answers TooLongFrameException with 400 Bad Request. Any other exception has its stack
   * trace printed and, if the channel is still connected, is answered with 500.
   */
  @throws(classOf[Exception])
  override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = {
    val ch = ctx.getChannel
    e.getCause match {
      case cause: TooLongFrameException => sendError(ctx, HttpResponseStatus.BAD_REQUEST)
      case cause =>
        cause.printStackTrace
        if (ch.isConnected) sendError(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR)
    }
  }

  /**
   * Maps a request URI to a path under `documentRoot`, or `None` to reject it.
   *
   * The processing steps are:
   *  - drop the query string;
   *  - percent-decode the rest;
   *  - convert '/' to the platform separator;
   *  - reject any path that has a component starting or ending with '.', which covers
   *    ".", ".." and dot-files.
   *
   * Malformed percent-escapes are rejected as well, rather than surfacing as an exception.
   */
  private def sanitizeUri(uri: String): Option[String] = {
    // The query string ("?...") is not part of the file path.
    val rawPath = uri.indexOf('?') match {
      case -1 => uri
      case i => uri.substring(0, i)
    }

    decodeUri(rawPath)
      .map(_.replace('/', File.separatorChar))
      .filterNot(_.contains(File.separator + "."))
      .filterNot(_.contains("." + File.separator))
      .filterNot(_.startsWith("."))
      .filterNot(_.endsWith("."))
      .map(u => documentRoot + File.separator + u)
  }

  /**
   * Percent-decodes `s` as UTF-8, falling back to ISO-8859-1 if UTF-8 is reported
   * unsupported. Returns `None` when `s` contains malformed escapes.
   */
  private def decodeUri(s: String): Option[String] = {
    def decodeWith(charset: String): Option[String] =
      try Some(URLDecoder.decode(s, charset))
      catch { case e: IllegalArgumentException => None }

    try decodeWith("UTF-8")
    catch { case e: UnsupportedEncodingException => decodeWith("ISO-8859-1") }
  }

  /**
   * Writes a plain-text error response with `status` and closes the channel.
   * Always returns `None`, so it can serve as an `orElse` fallback in messageReceived.
   */
  private def sendError[A](ctx: ChannelHandlerContext, status: HttpResponseStatus): Option[A] = {
    val response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, status)
    response.setHeader(HttpHeaders.Names.CONTENT_TYPE, "text/plain; charset=UTF-8")
    response.setContent(ChannelBuffers.copiedBuffer("Failure: " + status + "\r\n", CharsetUtil.UTF_8))

    ctx.getChannel.write(response).addListener(ChannelFutureListener.CLOSE)
    None
  }

  /** Writes a 304 Not Modified response with a Date header, closes the channel, and returns `None`. */
  private def sendNotModified[A](ctx: ChannelHandlerContext): Option[A] = {
    val response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.NOT_MODIFIED)
    setDateHeader(response)

    ctx.getChannel.write(response).addListener(ChannelFutureListener.CLOSE)
    None
  }

  /** Sets only the Date header. */
  private def setDateHeader(response: HttpResponse): Unit = bindDateHeader(response)((r, f, c) => ())

  /**
   * Sets Date, plus Expires (now + HTTP_CACHE_SECONDS), Cache-Control (private, max-age)
   * and Last-Modified (the file's modification time).
   */
  private def setDateAndCacheHeaders(response: HttpResponse, fileToCache: File): Unit =
    bindDateHeader(response) { (response, dateFormatter, time) =>
      time.add(Calendar.SECOND, HttpStaticFileServerHandler.HTTP_CACHE_SECONDS)
      response.setHeader(HttpHeaders.Names.EXPIRES, dateFormatter.format(time.getTime))
      // No space after '=': max-age takes delta-seconds directly.
      response.setHeader(HttpHeaders.Names.CACHE_CONTROL, "private, max-age=" + HttpStaticFileServerHandler.HTTP_CACHE_SECONDS)
      response.setHeader(HttpHeaders.Names.LAST_MODIFIED, dateFormatter.format(new Date(fileToCache.lastModified)))
    }

  /**
   * Sets the Date header to the current time, rendered in GMT.
   * Then passes the response, the GMT formatter and a calendar holding that time to `body`.
   */
  private def bindDateHeader[A](response: HttpResponse)(body: (HttpResponse, DateFormat, Calendar) => A): A = {
    val dateFormatter = HttpStaticFileServerHandler.newDateFormatter
    dateFormatter.setTimeZone(TimeZone.getTimeZone(HttpStaticFileServerHandler.HTTP_DATE_GMT_TIMEZONE))

    val time = new GregorianCalendar()
    response.setHeader(HttpHeaders.Names.DATE, dateFormatter.format(time.getTime))
    body(response, dateFormatter, time)
  }

  /**
   * Sets Content-Type from the file name's extension.
   * Entries in KNOWN_CONTENT_TYPES take precedence; otherwise MimetypesFileTypeMap
   * decides. Files without an extension also go to MimetypesFileTypeMap.
   */
  private def setContentTypeHeader(response: HttpResponse, file: File): Unit = {
    val name = file.getName
    // Names without '.' have no extension; lastIndexOf gives -1, which substring would reject.
    val ext = name.lastIndexOf('.') match {
      case -1 => ""
      case i => name.substring(i)
    }

    val contentType =
      HttpStaticFileServerHandler.KNOWN_CONTENT_TYPES.get(ext).getOrElse {
        val mimeTypesMap = new MimetypesFileTypeMap
        mimeTypesMap.getContentType(file.getPath)
      }
    response.setHeader(HttpHeaders.Names.CONTENT_TYPE, contentType)
  }
}

オリジナルからの大きな変更点は

  • HttpStaticFileServerHandlerクラスのコンストラクタ引数に、DocumentRootを追加
  • HttpStaticFileServerHandler#messageReceivedを、for式で表記する

です。その他は、あんまり大したことありません。あと、.scalaと.sbtはブラウザでそのまま表示できるよう、Content-Typeをtext/plain（UTF-8）にして特別扱いしています。

せっかくScalaで書くのだから、returnを書きたくないというところからfor式で書くに至りました。もちろん、こんなことをしたがために、ムダにハマりましたが。

この部分ですね。

  /*
   * Handles one HTTP request (excerpt). Each guard yields None to stop the
   * for-comprehension. The `.orElse(...)` after it takes a by-name argument, so the
   * error (or 304) response is sent only when the guard produced None.
   */
  @throws(classOf[Exception])
  override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = e.getMessage match {
    case request: HttpRequest =>
      for {
        // Only GET is served; any other method gets 405.
        _ <- booleanToOption(request.getMethod == HttpMethod.GET)
              .orElse(sendError(ctx, HttpResponseStatus.METHOD_NOT_ALLOWED))
        // URIs rejected by sanitizeUri get 403.
        path <- sanitizeUri(request.getUri)
              .orElse(sendError(ctx, HttpResponseStatus.FORBIDDEN))
        file <- Some(new File(path))
        // Missing or hidden files get 404.
        _ <- booleanToOption(file.exists && !file.isHidden)
              .orElse(sendError(ctx, HttpResponseStatus.NOT_FOUND))
        // Anything that is not a regular file (e.g. a directory) gets 403.
        _ <- booleanToOption(file.isFile)
              .orElse(sendError(ctx, HttpResponseStatus.FORBIDDEN))
        // When isModified yields None, a 304 is sent instead of the body.
        _ <- isModified(request, file)
              .orElse(sendNotModified(ctx))
        // If the file cannot be opened, 404.
        raf <- randomAccessFile(file)
                .orElse(sendError(ctx, HttpResponseStatus.NOT_FOUND))
        // Headers first, then the body; writeContents returns the body's write future.
        response <- httpResponse(file, raf)
        ch <- Some(e.getChannel)
        _ <- Some(ch.write(response))
        writeFuture <- writeContents(ch, path, raf)
        // Close after the body only for non-keep-alive requests; otherwise the guard
        // yields None here and the channel stays open.
        _ <- booleanToOption(!HttpHeaders.isKeepAlive(request))
        _ <- Some(writeFuture.addListener(ChannelFutureListener.CLOSE))
      } yield Unit
  }

手っ取り早くfor式に乗せるため、Optionを使いました。

動作自体は、オリジナルと変わりません。

起動引数は、ポート番号とドキュメントルートの2つか、ドキュメントルートのみです。

$ sbt "run ."

で、このプロジェクトをドキュメントルートとしてポート8080で起動します。

ところで、NettyのHTTPデコーダは、リクエスト行にHTTPのバージョンを付けないとエラーになってしまうようです…。ブラウザからアクセスすると問題なく動作しますが、telnetでバージョンを省略してアクセスすると

$ telnet localhost 8080
Trying 127.0.0.1...
Connected to localhost.
Escape character is '^]'.
GET /build.sbt
HTTP/1.1 500 Internal Server Error
Content-Type: text/plain; charset=UTF-8

Failure: 500 Internal Server Error
Connection closed by foreign host.

こうなります。

裏では、こんなスタックトレースが。

java.lang.IllegalArgumentException: empty text
	at org.jboss.netty.handler.codec.http.HttpVersion.<init>(HttpVersion.java:97)
	at org.jboss.netty.handler.codec.http.HttpVersion.valueOf(HttpVersion.java:62)
	at org.jboss.netty.handler.codec.http.HttpRequestDecoder.createMessage(HttpRequestDecoder.java:76)
	at org.jboss.netty.handler.codec.http.HttpMessageDecoder.decode(HttpMessageDecoder.java:187)
	at org.jboss.netty.handler.codec.http.HttpMessageDecoder.decode(HttpMessageDecoder.java:101)
	at org.jboss.netty.handler.codec.replay.ReplayingDecoder.callDecode(ReplayingDecoder.java:548)
	at org.jboss.netty.handler.codec.replay.ReplayingDecoder.messageReceived(ReplayingDecoder.java:445)
	at org.jboss.netty.channel.Channels.fireMessageReceived(Channels.java:268)
	at org.jboss.netty.channel.Channels.fireMessageReceived(Channels.java:255)
	at org.jboss.netty.channel.socket.nio.NioWorker.read(NioWorker.java:94)
	at org.jboss.netty.channel.socket.nio.AbstractNioWorker.processSelectedKeys(AbstractNioWorker.java:372)
	at org.jboss.netty.channel.socket.nio.AbstractNioWorker.run(AbstractNioWorker.java:246)
	at org.jboss.netty.channel.socket.nio.NioWorker.run(NioWorker.java:38)
	at java.util.concurrent.ThreadPoolExecutor$Worker.runTask(ThreadPoolExecutor.java:886)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:908)
	at java.lang.Thread.run(Thread.java:662)

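これ、HTTPのバージョンをクラスのインスタンスで表しているため、何も指定しないと特定のバージョンにマップできないからなんでしょうねぇ（スタックトレースの通り、HttpVersion.valueOfに空文字列が渡されて例外になっています）。厳格すぎる気もしますが、HTTP/1.0くらいをデフォルトにして動作してくれてもいいんじゃ?って気もします。なお、バージョンのないリクエスト行はHTTP/0.9形式のリクエストなので、telnetで試すときは「GET /build.sbt HTTP/1.0」のようにバージョンを付けて、続けて空行を送ればこのエラーは避けられます。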