CLOVER🍀

That was when it all began.

JBoss NettyでWebSocket

今度は、NettyのチュートリアルにあるWebSocketのサンプルをScalaで写経しました。チュートリアルの写経とはいえ、WebSocketを使ったプログラムを書くのは今回が初めてです♪

写経元のチュートリアル
http://docs.jboss.org/netty/3.2/xref/org/jboss/netty/example/http/websocket/package-summary.html

WebSocketServer.scala

import java.net.InetSocketAddress
import java.util.concurrent.Executors

import org.jboss.netty.bootstrap.ServerBootstrap
import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory

/** Entry point of the WebSocket demo: serves HTTP/WebSocket on port 8080. */
object WebSocketServer {
  def main(args: Array[String]): Unit = {
    // NioServerSocketChannelFactory(bossExecutor, workerExecutor)
    val channelFactory = new NioServerSocketChannelFactory(
      Executors.newCachedThreadPool(),
      Executors.newCachedThreadPool()
    )
    val server = new ServerBootstrap(channelFactory)

    // Every accepted connection gets a fresh HTTP pipeline from this factory.
    server.setPipelineFactory(new WebSocketServerPipelineFactory)

    val listenAddress = new InetSocketAddress(8080)
    server.bind(listenAddress)
  }
}

WebSocketServerHandler.scala

import java.security.MessageDigest

import org.jboss.netty.buffer.{ChannelBuffer, ChannelBuffers}
import org.jboss.netty.channel.{ChannelFuture, ChannelFutureListener, ChannelHandlerContext, ChannelPipeline, ExceptionEvent, MessageEvent, SimpleChannelUpstreamHandler}
import org.jboss.netty.handler.codec.http.{DefaultHttpResponse, HttpHeaders, HttpMethod, HttpRequest, HttpResponse, HttpResponseStatus, HttpVersion}
import org.jboss.netty.handler.codec.http.HttpHeaders.{Names, Values}
import org.jboss.netty.handler.codec.http.websocket.{DefaultWebSocketFrame, WebSocketFrame, WebSocketFrameDecoder, WebSocketFrameEncoder}
import org.jboss.netty.util.CharsetUtil

/** Constants shared by the WebSocketServerHandler class. */
object WebSocketServerHandler {
  /** Request URI on which the WebSocket handshake is accepted. */
  final val WEBSOCKET_PATH: String = "/websocket"
}

/**
 * Upstream handler for the WebSocket echo demo.
 *
 * Handles three kinds of traffic:
 *  - It serves the demo page at `/`.
 *  - It performs the WebSocket opening handshake at `WEBSOCKET_PATH`, supporting both the
 *    variant with a `Sec-WebSocket-Key1`/`Key2` challenge and the variant without one.
 *  - It writes back every received frame's text data in upper case.
 *
 * Any other GET URI, and any non-GET method, is answered with 403 Forbidden.
 */
class WebSocketServerHandler extends SimpleChannelUpstreamHandler {
  import WebSocketServerHandler._

  /** Dispatches HTTP requests and WebSocket frames; any other message type is ignored. */
  @throws(classOf[Exception])
  override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit = e.getMessage match {
    case req: HttpRequest      => handleHttpRequest(ctx, req)
    case frame: WebSocketFrame => handleWebSocketFrame(ctx, frame)
    // Without this case an unexpected message type would raise a MatchError.
    case _                     =>
  }

  /** Only GET is served; every other method gets 403 Forbidden. */
  @throws(classOf[Exception])
  private def handleHttpRequest(ctx: ChannelHandlerContext, req: HttpRequest): Unit = req.getMethod match {
    case HttpMethod.GET => handleGetMethod(ctx, req)
    case _              => sendHttpResponse(ctx, req, new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.FORBIDDEN))
  }

  /** Routes a GET request by URI: demo page, WebSocket handshake, or 403. */
  private def handleGetMethod(ctx: ChannelHandlerContext, req: HttpRequest): Unit = req.getUri match {
    case "/"                                      => sendIndexPage(ctx, req)
    case WEBSOCKET_PATH if isUpgradeRequest(req) => performHandshake(ctx, req)
    case _                                        => sendHttpResponse(ctx, req, new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.FORBIDDEN))
  }

  /** Sends the HTML demo page, pointing its WebSocket at this server's `WEBSOCKET_PATH`. */
  private def sendIndexPage(ctx: ChannelHandlerContext, req: HttpRequest): Unit = {
    val res = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK)
    val content = WebSocketServerIndexPage.getContent(getWebSocketLocation(req))

    res.setHeader(Names.CONTENT_TYPE, "text/html; charset=UTF-8")
    HttpHeaders.setContentLength(res, content.readableBytes)

    res.setContent(content)
    sendHttpResponse(ctx, req, res)
  }

  /** True when both `Connection: Upgrade` and `Upgrade: WebSocket` are present (case-insensitive). */
  private def isUpgradeRequest(req: HttpRequest): Boolean =
    Values.UPGRADE.equalsIgnoreCase(req.getHeader(Names.CONNECTION)) &&
      Values.WEBSOCKET.equalsIgnoreCase(req.getHeader(Names.UPGRADE))

  /**
   * Performs the WebSocket handshake for the request.
   *
   * Builds the 101 handshake response and swaps the pipeline's HTTP codecs for WebSocket
   * frame codecs. It then writes the response.
   */
  private def performHandshake(ctx: ChannelHandlerContext, req: HttpRequest): Unit = {
    val res = new DefaultHttpResponse(HttpVersion.HTTP_1_1, new HttpResponseStatus(101, "Web Socket Protocol Handshake"))
    res.addHeader(Names.UPGRADE, Values.WEBSOCKET)
    res.addHeader(Names.CONNECTION, Values.UPGRADE)

    if (req.containsHeader(Names.SEC_WEBSOCKET_KEY1) && req.containsHeader(Names.SEC_WEBSOCKET_KEY2)) {
      // Challenge variant: Sec-* headers plus an MD5 answer as the response body.
      res.addHeader(Names.SEC_WEBSOCKET_ORIGIN, req.getHeader(Names.ORIGIN))
      res.addHeader(Names.SEC_WEBSOCKET_LOCATION, getWebSocketLocation(req))
      Option(req.getHeader(Names.SEC_WEBSOCKET_PROTOCOL)).foreach(res.addHeader(Names.SEC_WEBSOCKET_PROTOCOL, _))
      res.setContent(challengeResponse(req))
    } else {
      // Variant without a challenge: only the unprefixed WebSocket-* headers.
      res.addHeader(Names.WEBSOCKET_ORIGIN, req.getHeader(Names.ORIGIN))
      res.addHeader(Names.WEBSOCKET_LOCATION, getWebSocketLocation(req))
      Option(req.getHeader(Names.WEBSOCKET_PROTOCOL)).foreach(res.addHeader(Names.WEBSOCKET_PROTOCOL, _))
    }

    // These names must match the ones registered by WebSocketServerPipelineFactory.
    val p = ctx.getChannel.getPipeline
    p.remove("aggregator")
    p.replace("decoder", "wsdecoder", new WebSocketFrameDecoder)

    // The encoder is replaced only after the write is issued. The intent is that the
    // HttpResponse is still encoded by the HTTP encoder.
    ctx.getChannel.write(res)

    p.replace("encoder", "wsencoder", new WebSocketFrameEncoder)
  }

  /**
   * Computes the MD5 answer to the handshake challenge.
   *
   * The digest input is the two decoded keys as 32-bit ints, followed by the first 8 bytes
   * of the request body. `readLong` throws if the body is shorter than 8 bytes.
   */
  private def challengeResponse(req: HttpRequest): ChannelBuffer = {
    val input = ChannelBuffers.buffer(16)
    input.writeInt(decodeKey(req.getHeader(Names.SEC_WEBSOCKET_KEY1)))
    input.writeInt(decodeKey(req.getHeader(Names.SEC_WEBSOCKET_KEY2)))
    input.writeLong(req.getContent.readLong)
    ChannelBuffers.wrappedBuffer(MessageDigest.getInstance("MD5").digest(input.array))
  }

  /**
   * Decodes one challenge key into a 32-bit int.
   *
   * The key's digits are read as a Long, divided by the number of spaces in the key, and
   * truncated to Int. Throws NumberFormatException if the key has no digits (or they
   * overflow a Long), and ArithmeticException if it has no spaces.
   */
  private def decodeKey(key: String): Int =
    (key.replaceAll("[^0-9]", "").toLong / key.replaceAll("[^ ]", "").length).toInt

  /** Writes back the frame's text data, upper-cased, on the same channel. */
  private def handleWebSocketFrame(ctx: ChannelHandlerContext, frame: WebSocketFrame): Unit =
    ctx.getChannel.write(new DefaultWebSocketFrame(frame.getTextData.toUpperCase))

  /**
   * Writes an HTTP response.
   *
   * For a non-200 status, the body is replaced with the status's string form. The channel
   * is closed after the write if the response is not 200 or the request is not keep-alive.
   */
  private def sendHttpResponse(ctx: ChannelHandlerContext, req: HttpRequest, res: HttpResponse): Unit = {
    val ok = res.getStatus.getCode == 200
    if (!ok) {
      res.setContent(ChannelBuffers.copiedBuffer(res.getStatus.toString, CharsetUtil.UTF_8))
      HttpHeaders.setContentLength(res, res.getContent.readableBytes)
    }

    val f = ctx.getChannel.write(res)
    if (!HttpHeaders.isKeepAlive(req) || !ok) f.addListener(ChannelFutureListener.CLOSE)
  }

  /** Prints the stack trace of any pipeline exception and closes the channel. */
  override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = {
    e.getCause.printStackTrace()
    e.getChannel.close()
  }

  /** Builds `ws://<Host header>/websocket`; a missing Host header is rendered as "null". */
  private def getWebSocketLocation(req: HttpRequest): String = "ws://" + req.getHeader(Names.HOST) + WEBSOCKET_PATH
}

WebSocketServerIndexPage.scala

import org.jboss.netty.buffer.{ChannelBuffer, ChannelBuffers}
import org.jboss.netty.util.CharsetUtil

/** Renders the HTML demo page served at `/`. */
object WebSocketServerIndexPage {
  /**
   * Builds the demo page.
   *
   * Received WebSocket messages are appended with `textContent`, so echoed data is shown
   * as text rather than parsed as HTML. The page is encoded as UTF-8, consistent with the
   * `text/html; charset=UTF-8` Content-Type the handler sends with it.
   *
   * @param webSocketLocation URL substituted into the page's `new WebSocket("...")` call
   * @return the rendered page as a buffer
   */
  def getContent(webSocketLocation: String): ChannelBuffer = {
    val html =
      """|<html><head><title>Web Socket Test</title></head>
         |<body>
         |<script type="text/javascript">
         |var socket;
         |if (window.WebSocket) {
         |  socket = new WebSocket("%s");
         |  socket.onmessage = function(event) { writeConsole(event.data); };
         |  socket.onopen = function(event) { alert("Web Socket opened!"); };
         |  socket.onclose = function(event) { alert("Web Socket closed."); };
         |} else {
         |  alert("Your browser does not support Web Socket.");
         |}
         |
         |function send(message) {
         |  if (!window.WebSocket) { return; }
         |  if (socket.readyState == WebSocket.OPEN) {
         |    socket.send(message);
         |  } else {
         |    alert("The socket is not open.");
         |  }
         |}
         |
         |function writeConsole(message) {
         |  var p = document.createElement("p");
         |  p.textContent = message;
         |  document.getElementById("console").appendChild(p);
         |}
         |</script>
         |<form onsubmit="return false;">
         |<input type="text" name="message" value="Hello, World!"/>
         |<input type="button" value="Send Web Socket Data" onclick="send(this.form.message.value)" />
         |</form>
         |<div id="console"></div>
         |</body>
         |</html>
         |""".stripMargin.format(webSocketLocation)

    ChannelBuffers.copiedBuffer(html, CharsetUtil.UTF_8)
  }
}

WebSocketServerPipelineFactory.scala

import org.jboss.netty.channel.{ChannelPipeline, ChannelPipelineFactory, Channels}
import org.jboss.netty.handler.codec.http.{HttpChunkAggregator, HttpRequestDecoder, HttpResponseEncoder}

/**
 * Builds the initial HTTP pipeline for each accepted connection.
 *
 * The names "decoder", "aggregator" and "encoder" matter: WebSocketServerHandler removes
 * or replaces those entries when it upgrades a connection to WebSocket.
 */
class WebSocketServerPipelineFactory extends ChannelPipelineFactory {
  @throws(classOf[Exception])
  def getPipeline: ChannelPipeline = {
    val p = Channels.pipeline
    p.addLast("decoder", new HttpRequestDecoder())
    // Aggregated request content is capped at 65536 bytes.
    p.addLast("aggregator", new HttpChunkAggregator(64 * 1024))
    p.addLast("encoder", new HttpResponseEncoder())
    p.addLast("handler", new WebSocketServerHandler())
    p
  }
}

一部、チュートリアル通りではありません。受信したメッセージはalertで表示するのではなく、formより下のdivに書き込むようにしていますし、Scalaの機能を使って端折っている箇所もあります。あと、オリジナルのWebSocketServerHandler.javaでは、static importを使っている割には呼び出し方を省略したりしなかったりとまちまちだったので、勉強を兼ねてstatic import(Scalaではコンパニオンオブジェクトのメンバーのimport)は今回は外しました。

動作そのものはチュートリアルと同じなので、Google Chromeで
http://localhost:8080/
にアクセスし、テキストフィールドに文字を入力してボタンを押すと、入力した文字列が大文字に変換されてformの下に表示されます。残念ながら、メインで利用しているブラウザのFirefoxではWebSocketが未サポートなので動きません…。

WebSocketを扱う場合は、NettyではなくJettyを使えばもっと楽に書けるらしいですね。その辺りは、また今度。

あと、この手のDaemonプログラムをsbtで直に実行してしまうと、無理に殺すことになってsbtごと死んでしまうので、コンパイルと実行は別々のsbtプロセスでやりました。

片方のsbtコンソールでコンパイル

> compile
[info] Compiling 1 Scala source to /xxxxx/netty-websocket/target/scala-2.9.0.final/classes...
[success] Total time: 1 s, completed Jul 18, 2011 9:53:39 PM

もう片方では直接実行。

$ sbt run
[info] Set current project to default (in build file:/xxxxx/netty-websocket/)
[info] Running WebSocketServer 

先にコンパイルしておかないと、sbt runした時にコンパイルが走るため、この場合はコンパイルに時間がかかり悲しい思いをすることになります…。