CLOVER🍀

That was when it all began.

Learning UDP Network Programming with Clojure - Non Blocking IO

Last time I skipped Clojure entirely when writing Reliable UDP, but this time I'm starting with Clojure.

The topic this time is a UDP echo client/server using NIO. Its visible behavior is the same as before, so I'll skip describing it...

I started with the Clojure version, and boy, did it give me trouble.

udp_nio_server.clj

(import '(java.net DatagramPacket InetSocketAddress)
        '(java.nio ByteBuffer)
        '(java.nio.channels DatagramChannel SelectionKey Selector)
        '(java.util Date))
(require '[clojure.string :as str])

(def timeout 50000)
(def buffer (ByteBuffer/allocate 8192))
(def output-queue (ref []))
(def address (InetSocketAddress. 50000))

(defn log [word & more-words]
  (println (str "[" (Date.) "] " (str/join " " (cons word more-words)))))

(defn handle-readable [key]
  (let [channel (.channel key)]
    (.clear buffer)
    ;; in non-blocking mode receive returns nil when no datagram is waiting
    (when-let [address (.receive channel buffer)]
      (.flip buffer)
      (let [bytes (byte-array (.limit buffer))]
        (.get buffer bytes)
        (let [reply-word (str "Non Blocking Server Reply => " (String. bytes "UTF-8"))
              reply-word-binary (.getBytes reply-word "UTF-8")]
          (.clear buffer)
          (.put buffer reply-word-binary)
          (.flip buffer)))
      ;; send either sends the whole datagram or nothing (returns 0)
      (.send channel buffer address)

      (let [cnt (.remaining buffer)]
        (when (pos? cnt)
          ;; not sent: queue it and ask the selector to report OP_WRITE
          (let [bytes (byte-array cnt)]
            (.get buffer bytes)
            (dosync
             (alter output-queue conj (DatagramPacket. bytes cnt address))))
          (.interestOps key (bit-or SelectionKey/OP_READ SelectionKey/OP_WRITE)))))))

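(defn handle-writable [key]
  (let [channel (.channel key)]
    (loop []
      (if-let [packet (first @output-queue)]
        (do
          (.clear buffer)
          (.put buffer (.getData packet))
          (.flip buffer)
          ;; send to the address stored in the queued packet
          (when (pos? (.send channel buffer (.getSocketAddress packet)))
            ;; sent: drop it (keeping a vector so conj still appends) and go on
            (dosync (alter output-queue (comp vec rest)))
            (recur)))
        ;; queue drained: stop asking for OP_WRITE, or select never blocks again
        (.interestOps key SelectionKey/OP_READ)))))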

(with-open [channel (DatagramChannel/open)]
  (.configureBlocking channel false)
  (.bind (.socket channel) address)

  (let [selector (Selector/open)]
    (.register channel selector SelectionKey/OP_READ)

    (log "Non Blocking Clojure UDP Server" address "Startup.")

    (doseq [_ (take-while true? (repeatedly #(not (empty? (.keys selector)))))]
      (.select selector timeout)
      (let [selected-keys (.selectedKeys selector)]
        (locking selected-keys
          (do
            (doseq [key selected-keys]
              (when (.isValid key)
                (when (.isReadable key)
                  (handle-readable key))
                (when (.isWritable key)
                  (handle-writable key))))
            (.clear selected-keys)))))))

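Unlike TCP, where there are separate ServerSocketChannel and SocketChannel classes, with UDP everything is unified into a single DatagramChannel, so depending on how you look at it, maybe this is simpler...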
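As a minimal sketch of that point (not part of the original code), the same DatagramChannel class can play both the receiving role and the sending role. This assumes a working loopback interface, binds to port 0 so the OS picks a free port instead of the article's 50000, and leaves both channels in blocking mode to keep it short; loopback-roundtrip is a hypothetical helper name.

```clojure
(import '(java.net InetAddress InetSocketAddress)
        '(java.nio ByteBuffer)
        '(java.nio.channels DatagramChannel))

;; Hypothetical helper: one DatagramChannel acts as the "server" (bound),
;; another as the "client" (just sends). Same class, two roles.
(defn loopback-roundtrip [^String msg]
  (with-open [server (DatagramChannel/open)
              client (DatagramChannel/open)]
    ;; port 0 = let the OS choose a free port on the loopback address
    (.bind (.socket server) (InetSocketAddress. (InetAddress/getLoopbackAddress) 0))
    (let [server-addr (.getLocalSocketAddress (.socket server))
          buf (ByteBuffer/allocate 1024)]
      ;; blocking mode here, so send and receive simply wait until done
      (.send client (ByteBuffer/wrap (.getBytes msg "UTF-8")) server-addr)
      (.receive server buf)
      (.flip buf)
      (String. (.array buf) 0 (.limit buf) "UTF-8"))))
```

There is no accept step and no per-client channel: the receiving side just gets datagrams along with the sender's address.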

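The key point here is to switch the DatagramChannel to non-blocking mode right after opening it, because a channel still in blocking mode cannot be registered with a Selector (register throws IllegalBlockingModeException):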

(with-open [channel (DatagramChannel/open)]
  (.configureBlocking channel false)
  (.bind (.socket channel) address)
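A minimal sketch of why the order matters, using only JDK classes: register succeeds on a non-blocking channel and is rejected with IllegalBlockingModeException on a channel left in the default blocking mode. try-register is a hypothetical helper written for this illustration.

```clojure
(import '(java.nio.channels DatagramChannel Selector SelectionKey
                            IllegalBlockingModeException))

;; Hypothetical helper: returns :registered or :illegal-blocking-mode
(defn try-register [non-blocking?]
  (with-open [channel (DatagramChannel/open)
              selector (Selector/open)]
    ;; a freshly opened channel is in blocking mode
    (.configureBlocking channel (not non-blocking?))
    (try
      (.register channel selector SelectionKey/OP_READ)
      :registered
      (catch IllegalBlockingModeException _
        :illegal-blocking-mode))))
```

(try-register false) gives :illegal-blocking-mode, while (try-register true) gives :registered.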

The part that calls Selector#select is more or less the same as with TCP:

      (.select selector timeout)
      (let [selected-keys (.selectedKeys selector)]
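To show what one pass of that loop sees, here is a sketch over loopback, assuming nothing else sends to the ephemeral port: select with a timeout returns 0 when nothing is ready, and once a datagram arrives it returns 1 with a readable key in selectedKeys. select-once is a hypothetical name, and the 100 ms and 1000 ms timeouts are arbitrary values for the example.

```clojure
(import '(java.net InetAddress InetSocketAddress)
        '(java.nio ByteBuffer)
        '(java.nio.channels DatagramChannel Selector SelectionKey))

;; Hypothetical helper: returns [ready-before-send ready-after-send readable?]
(defn select-once []
  (with-open [server (DatagramChannel/open)
              client (DatagramChannel/open)
              selector (Selector/open)]
    (.configureBlocking server false)
    (.bind (.socket server) (InetSocketAddress. (InetAddress/getLoopbackAddress) 0))
    (.register server selector SelectionKey/OP_READ)
    ;; nothing has been sent yet, so this times out and returns 0
    (let [idle (.select selector 100)]
      (.send client (ByteBuffer/wrap (.getBytes "ping" "UTF-8"))
             (.getLocalSocketAddress (.socket server)))
      ;; now the server channel is readable
      (let [ready (.select selector 1000)
            key   (first (.selectedKeys selector))]
        [idle ready (boolean (and key (.isReadable key)))]))))
```

This is also why the server loop can use a timeout: a 0 return just means "nothing happened yet, go around again".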

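One difference is that the initial DatagramChannel#register call does not need SelectionKey.OP_ACCEPT. With UDP there are no connections to accept, so a DatagramChannel only supports OP_READ and OP_WRITE (its validOps):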

  (let [selector (Selector/open)]
    (.register channel selector SelectionKey/OP_READ)
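A quick sketch to confirm that, using only JDK calls: validOps on a DatagramChannel is OP_READ | OP_WRITE, and asking register for OP_ACCEPT anyway is rejected with IllegalArgumentException. accept-supported? and datagram-valid-ops are hypothetical names for this illustration.

```clojure
(import '(java.nio.channels DatagramChannel Selector SelectionKey))

;; The readiness operations a DatagramChannel supports
(def datagram-valid-ops
  (with-open [channel (DatagramChannel/open)]
    (.validOps channel)))

;; Hypothetical helper: does register accept OP_ACCEPT on a DatagramChannel?
(defn accept-supported? []
  (with-open [channel (DatagramChannel/open)
              selector (Selector/open)]
    (.configureBlocking channel false)
    (try
      (.register channel selector SelectionKey/OP_ACCEPT)
      true
      ;; register rejects ops outside validOps
      (catch IllegalArgumentException _ false))))
```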

Beyond that, the tangle of loop/recur and nested conditionals is a bit messy; please forgive it...

Next, the client side.
udp_nio_client.clj

(import '(java.net InetSocketAddress)
        '(java.nio ByteBuffer)
        '(java.nio.channels DatagramChannel)
        '(java.util Date))
(require '[clojure.string :as str])

(def timeout 50000)
(def buffer (ByteBuffer/allocate 8192))
(def address (InetSocketAddress. 50000))

(defn log [word & more-words]
  (println (str "[" (Date.) "] " (str/join " " (cons word more-words)))))

(with-open [channel (DatagramChannel/open)]
  (.configureBlocking channel false)
  (.connect channel address)

  (doseq [_ (take-while #(not (.isConnected %))
                        (repeat channel))])

  (log "Non Blocking Clojure UDP Client" "Startup.")

  (letfn [(send-receive [word]
            (.clear buffer)
            (.put buffer (.getBytes word "UTF-8"))
            (.flip buffer)

            (loop [n (.send channel buffer address)]
              (when (= n 0)
                (recur (.send channel buffer address))))

            (.clear buffer)

            (loop [addr (.receive channel buffer)]
              (when (nil? addr)
                (recur (.receive channel buffer))))

            (.flip buffer)
            ;; decode exactly the received bytes (offset, length) as UTF-8
            (println (String. (.array buffer)
                              (.position buffer)
                              (.remaining buffer)
                              "UTF-8")))]
    (doseq [word (take-while #(not (or (nil? %) (empty? %) (= % "exit")))
                             (repeatedly read-line))]
      (send-receive word))))

Because the channel is in non-blocking mode, I ended up with this: a loop that keeps waiting until DatagramChannel#isConnected returns true.

  (doseq [_ (take-while #(not (.isConnected %))
                        (repeat channel))])
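For comparison, a small sketch suggesting that this wait may not be needed: DatagramChannel#connect performs no handshake, it only fixes the peer address, so isConnected is already true when connect returns, even in non-blocking mode and even with nothing listening on the target port. connected-right-away? is a hypothetical name; port 50000 just mirrors the article.

```clojure
(import '(java.net InetAddress InetSocketAddress)
        '(java.nio.channels DatagramChannel))

;; Hypothetical helper: is the channel connected as soon as connect returns?
(defn connected-right-away? []
  (with-open [channel (DatagramChannel/open)]
    (.configureBlocking channel false)
    ;; no packet is exchanged here; connect only records the peer address
    (.connect channel (InetSocketAddress. (InetAddress/getLoopbackAddress) 50000))
    (.isConnected channel)))
```

If this holds on your JDK, the take-while loop over isConnected finishes on its first check and could simply be dropped.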

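This isn't really related to the main topic, but I got oddly stuck on DatagramChannel's send/receive. In non-blocking mode, send can return 0 and receive returns null when nothing has arrived, instead of waiting, so both have to be retried. It's this part: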

            (loop [n (.send channel buffer address)]
              (when (= n 0)
                (recur (.send channel buffer address))))

            (.clear buffer)

            (loop [addr (.receive channel buffer)]
              (when (nil? addr)
                (recur (.receive channel buffer))))
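The return values those loops depend on can be seen in isolation with a sketch like this one, assuming a loopback interface and a channel that sends to its own ephemeral port: receive on an empty non-blocking channel returns nil at once, and send returns the number of bytes written (it would be 0 only if the OS send buffer were full). nonblocking-results is a hypothetical name.

```clojure
(import '(java.net InetAddress InetSocketAddress)
        '(java.nio ByteBuffer)
        '(java.nio.channels DatagramChannel))

;; Hypothetical helper: returns {:received <receive result> :sent <send result>}
(defn nonblocking-results []
  (with-open [channel (DatagramChannel/open)]
    (.configureBlocking channel false)
    (.bind (.socket channel) (InetSocketAddress. (InetAddress/getLoopbackAddress) 0))
    (let [buf      (ByteBuffer/allocate 1024)
          ;; nothing has been sent yet: returns nil immediately, no blocking
          received (.receive channel buf)
          self     (.getLocalSocketAddress (.socket channel))
          ;; send to our own port: returns the byte count, 4 for "ping"
          sent     (.send channel (ByteBuffer/wrap (.getBytes "ping" "UTF-8")) self)]
      {:received received :sent sent})))
```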

I really wanted to handle this with repeat and take-while, but it didn't quite work out, so I gave up and wrote the retries as explicit self-recursion with loop/recur.
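One likely reason repeat didn't fit: (repeat x) evaluates x once and yields that same value forever, so (repeat (.receive channel buffer)) calls receive only a single time. repeatedly calls its function for every element, and some stops at the first truthy result, which gives a loop-free spin-wait. The sketch below assumes a loopback interface and a channel echoing to itself; spin-until and echo-to-self are hypothetical names, and it busy-waits just like the original loops.

```clojure
(import '(java.net InetAddress InetSocketAddress)
        '(java.nio ByteBuffer)
        '(java.nio.channels DatagramChannel))

;; Hypothetical helper: call f until it returns something truthy, return that
(defn spin-until [f]
  (some identity (repeatedly f)))

(defn echo-to-self [^String msg]
  (with-open [channel (DatagramChannel/open)]
    (.configureBlocking channel false)
    (.bind (.socket channel) (InetSocketAddress. (InetAddress/getLoopbackAddress) 0))
    (let [self (.getLocalSocketAddress (.socket channel))
          buf  (ByteBuffer/allocate 1024)]
      ;; retry send while it reports 0 bytes written
      (spin-until #(let [n (.send channel (ByteBuffer/wrap (.getBytes msg "UTF-8")) self)]
                     (when (pos? n) n)))
      ;; retry receive while it returns nil
      (spin-until #(.receive channel buf))
      (.flip buf)
      (String. (.array buf) 0 (.limit buf) "UTF-8"))))
```

Because repeatedly is not chunked, some calls receive only until the first datagram shows up, so no extra datagrams are consumed.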

Running it looks like this. First, start the server.

$ clj udp_nio_server.clj 
[Mon Sep 23 22:31:12 JST 2013] Non Blocking Clojure UDP Server 0.0.0.0/0.0.0.0:50000 Startup.

Then the client.

$ clj udp_nio_client.clj 
[Mon Sep 23 22:31:47 JST 2013] Non Blocking Clojure UDP Client Startup.
Hello World
Non Blocking Server Reply => Hello World
こんにちは、世界
Non Blocking Server Reply => こんにちは、世界
Hello Clojure!!
Non Blocking Server Reply => Hello Clojure!!
exit

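I couldn't find any examples of NIO with UDP, especially for the client side (I guess there's just no demand for it...), so I improvised an implementation while looking at TCP examples.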

Well, it works, so I guess that's good enough for now...

Finally, the Scala version. It's a bit long, but this one I could write without any trouble. Which just shows how much my Clojure skills are lacking...
UdpNioClientServer.scala

import scala.collection.JavaConverters._

import java.net.{DatagramPacket, InetSocketAddress}
import java.nio.ByteBuffer
import java.nio.channels.{DatagramChannel, SelectionKey, Selector}
import java.nio.charset.StandardCharsets
import java.util.Date

import UdpNioHelper._

object UdpNioHelper {
  val TIMEOUT: Long = 5000L
  val PORT: Int = 50000

  def log(msg: Any, moreMsgs: Any*): Unit =
    println(s"[${new Date}] ${(msg :: moreMsgs.toList).mkString(" ")}")

  implicit class AutoCloseableWrapper[A <: AutoCloseable](val underlying: A) extends AnyVal {
    def foreach(fun: A => Unit): Unit =
      try {
        fun(underlying)
      } finally {
        underlying.close()
      }
  }
}

object UdpNioServer {
  val buffer = ByteBuffer.allocate(8192)
  var outputQueue = Vector[DatagramPacket]()

  def main(args: Array[String]): Unit = {
    for (channel <- DatagramChannel.open()) {
      val address = new InetSocketAddress(PORT)
      channel.configureBlocking(false)
      channel.socket.bind(address)

      log("Non Blocking Scala UDP Server", address, "Startup.")

      val selector = Selector.open()
      channel.register(selector, SelectionKey.OP_READ)

      Iterator
        .continually(selector.keys)
        .takeWhile(!_.isEmpty)
        .foreach { _ =>
          selector.select(TIMEOUT)

          val selectedKeys = selector.selectedKeys
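          selectedKeys synchronized {
            for (key <- selectedKeys.asScala) {
              if (key.isValid) {
                if (key.isReadable) {
                  handleReadable(key)
                }

                if (key.isWritable) {
                  handleWritable(key)
                }
              }
            }

            // clear only after the whole set has been processed;
            // clearing inside the loop would invalidate the running iterator
            selectedKeys.clear()
          }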
        }
    }
  }

  private def handleReadable(key: SelectionKey): Unit = {
    val channel = key.channel.asInstanceOf[DatagramChannel]

    buffer.clear()
    channel.receive(buffer) match {
      case null =>
      case address =>
        buffer.flip()

        val bytes = Array.ofDim[Byte](buffer.limit)
        buffer.get(bytes)
        val replyWord = "Non Blocking Server Reply => " + new String(bytes, StandardCharsets.UTF_8)
        val replyWordBinary = replyWord.getBytes(StandardCharsets.UTF_8)

        buffer.clear()
        buffer.put(replyWordBinary)
        buffer.flip()

        channel.send(buffer, address)

        val count = buffer.remaining
        if (count > 0) {
          val remainBytes = Array.ofDim[Byte](count)
          buffer.get(remainBytes)
          outputQueue = outputQueue :+ new DatagramPacket(remainBytes, count, address)

          key.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE)
        }
    }
  }

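  private def handleWritable(key: SelectionKey): Unit = {
    val channel = key.channel.asInstanceOf[DatagramChannel]

    var blocked = false
    while (!blocked && outputQueue.nonEmpty) {
      val packet = outputQueue.head

      buffer.clear()
      buffer.put(packet.getData, packet.getOffset, packet.getLength)
      buffer.flip()

      // send returns 0 if the datagram could not be sent yet:
      // keep it queued and wait for the next OP_WRITE instead of spinning
      if (channel.send(buffer, packet.getSocketAddress) > 0) {
        outputQueue = outputQueue.tail
      } else {
        blocked = true
      }
    }

    // queue drained: drop OP_WRITE, otherwise select() keeps returning immediately
    if (outputQueue.isEmpty) {
      key.interestOps(SelectionKey.OP_READ)
    }
  }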
}

object UdpNioClient {
  def main(args: Array[String]): Unit = {
    val address = new InetSocketAddress(PORT)
    val buffer = ByteBuffer.allocate(8192)

    for (channel <- DatagramChannel.open()) {
      channel.configureBlocking(false)
      channel.connect(address)

      Iterator
        .continually(channel)
        .takeWhile(!_.isConnected)
        .foreach { _ => }

      log("Non Blocking Scala UDP Client", "Startup.")

      Iterator
        .continually(scala.io.StdIn.readLine())
        .takeWhile(word => word != null && !word.isEmpty && word != "exit")
        .foreach { word =>
          buffer.clear()
          buffer.put(word.getBytes(StandardCharsets.UTF_8))
          buffer.flip()

          Iterator
            .continually(channel.send(buffer, address))
            .takeWhile(_ == 0)
            .foreach { _ => }

          buffer.clear()

          Iterator
            .continually(channel.receive(buffer))
            .takeWhile(_ == null)
            .foreach { _ =>  }
        
          buffer.flip()
          println(new String(buffer.array,
                             buffer.position,
                             buffer.remaining,
                             StandardCharsets.UTF_8))
      }
    }
  }
}