CLOVER🍀

That was when it all began.

Gauche本の継続を、Scalaで書く

久々の更新です。先週末、先々週末は共に出勤だったため、さすがにブログ更新どころじゃなかったですからね…。今週末は、両方ともお休みです。

まあ、次の週末が休みとなる可能性は、けっこう微妙なのですが…。

さて、Gauche本を片手にSchemeGaucheの勉強をしていたのですが、やっと19章の「継続」に辿り着きました。Gaucheで書いていて、へ〜って思うところもあるのですが、やっぱりまだ理解しきれていないように思うので、思考することを兼ねてScalaで書き直すことにしました。

Gauche本の19.2「Schemeによる継続渡しの表現」〜19.3「さらに継続を渡して」を対象に、Scalaで書き換えます。

まずは基礎となる関数、findFoldとprocess。

import scala.annotation.tailrec

@tailrec
def findFold[T](pred: T => Boolean,
                proc: (T, List[T]) => List[T],
                seed: List[T],
                xs: List[T]): List[T] = {
  xs match {
    case Nil => seed
    case y :: ys if pred(y) =>
      val seed2 = proc(y, seed)
      findFold(pred, proc, seed2, ys)
    case _ => findFold(pred, proc, seed, xs.tail)
  }
}

def process[T](elt: T, seed: List[T]): List[T] = {
  println("found: " + elt)
  elt :: seed
}

関数findFoldは、初期値seedを取り、引数xsを分解しつつ各要素に対して述語predを適用してtrueが返れば、分解した要素とseedをprocに適用し、これを初期値にfindFoldを再帰呼び出しする関数です。

このサンプルで何度も登場する関数とListは、あらかじめ宣言しておきました。

val tenList = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
def odd(x: Int) = x % 2 != 0

これを、REPLで読み込みます。

scala> :load cont.scala
Loading cont.scala...
import scala.annotation.tailrec
tenList: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
odd: (x: Int)Boolean
findFold: [T](pred: T => Boolean, proc: (T, List[T]) => List[T], seed: List[T], xs: List[T])List[T]
process: [T](elt: T, seed: List[T])List[T]

Gauche本通りに実行。

scala> findFold[Int](odd, process, Nil, tenList)
found: 1
found: 3
found: 5
found: 7
found: 9
res28: List[Int] = List(9, 7, 5, 3, 1)

findFoldを型指定して呼び出さなければならないところが、少し悲しい…。各要素が述語oddがtrueを返した時に表示され、processによって繋げられたListが最後に返ってきています。

で、ここからはGauche本に沿ってこの関数を継続スタイルに変えていきます。findFoldのこの部分を変えるんだそうな。

      val seed2 = proc(y, seed)
      findFold(pred, proc, seed2, ys)

そして、引数procを引数に継続を取るものに変更したのが以下。

def findFold2[T](pred: T => Boolean,
                 procWithCont: (T, List[T], (List[T] => List[T])) => List[T],
                 seed: List[T],
                 xs: List[T]): List[T] = {
  xs match {
    case Nil => seed
    case y :: ys if pred(y) =>
      procWithCont(y, seed, seed2 => findFold2(pred, procWithCont, seed2, ys))
    case _ => findFold2(pred, procWithCont, seed, xs.tail)
  }
}

全部型を書かなくてはいけないので、Gaucheと比べるとちょっと面倒…。procWithContが引数に取る継続は、とりあえず簡単のためListを取りListを返す関数、ということにしておきました。

次に、関数processを継続を受けとるように変更したのが以下。

def processWithCont[T](elt: T, seed: List[T], cont: List[T] => List[T]): List[T] = {
  println("found: " + elt)
  cont(elt :: seed)
}

REPLで読み込んで実行すると、こうなります。

scala> findFold2[Int](odd, processWithCont, Nil, tenList)
found: 1
found: 3
found: 5
found: 7
found: 9
res30: List[Int] = List(9, 7, 5, 3, 1)

見た目上の結果は、何も変わっていません…。が、findFold2では自己再帰呼び出しではなくて、あくまで引数に取った関数procWithContの結果が返っているため、動き的にはちょっと違います。findFold2が繰り返し呼び出される間に、procWithCont関数が挟まっているわけですね。

これを明らかにするため、Gauche本が提示していた以下の関数…をScalaで書き直しました。

var next: (() => List[_]) = () => Nil
def break[T](value: List[T]) = value

def processWithBreak[T](elt: T, seed: List[T], cont: List[T] => List[T]): List[T] = {
  next = () => {
    println("found: " + elt)
    cont(elt :: seed)
  }
  break(Nil)
}

nextは途中の処理内容を束縛するための変数で、breakは与えられた引数をそのまま返す関数です。

では、REPLで読み込んで実行。

scala> findFold2[Int](odd, processWithBreak, Nil, tenList)
res31: List[Int] = List()

空のListが返ってきました。結果は?ってことで、変数nextを評価してみます。

scala> next()
found: 1
res32: List[Any] = List()

なんか出力されましたね。戻り値は、相変わらず空のListですが。そのまま、続けて実行。

scala> next()
found: 3
res33: List[Any] = List()

scala> next()
found: 5
res34: List[Any] = List()

scala> next()
found: 7
res35: List[Any] = List()

scala> next()
found: 9
res36: List[Any] = List(9, 7, 5, 3, 1)

これは、さっきまで一括で出力されていたものが、変数nextの評価時に1つずつ進んで行っている感じですね。5回目を呼び出したところで、戻り値としてこれまで返ってきていた、5つの要素のListが返却されています。

ポイントは、関数findFold2の以下の部分、

      procWithCont(y, seed, seed2 => findFold2(pred, procWithCont, seed2, ys))

そして関数processWithBreakですね。

def processWithBreak[T](elt: T, seed: List[T], cont: List[T] => List[T]): List[T] = {
  next = () => {
    println("found: " + elt)
    cont(elt :: seed)
  }
  break(Nil)
}

findFold2の引数に渡したListが、これ以上分解不可能になるまではnextを呼び出すことで、次のfindFold2関数の呼び出しが実行されます。この時、findFold2の戻り値はbreak関数の結果になります。なので、Listが尽きるまではnextの戻り値がNilだったというわけですね。

Listを分解しきったところでは、findFold2関数の結果は以下のcase式にマッチします。よって、それ以降は変数nextを何回評価してもseedの値が返ります。

    case Nil => seed

この時のseedの値は、nextに再束縛されて評価され続けていた

    cont(elt :: seed)

の累積となるので、

res37: List[Any] = List(9, 7, 5, 3, 1)

が返るというわけですね。

最後、本に習ってfindFold関数自体を継続を取るように変更。引数多いなぁ。

def findFoldWithCont[T](pred: T => Boolean,
                        procWithCont: (T, List[T], (List[T] => List[T])) => List[T],
                        seed: List[T],
                        xs: List[T],
                        cont: List[T] => List[T]): List[T] = {
  xs match {
    case Nil => cont(seed)
    case y :: ys if pred(y) =>
      procWithCont(y, seed, seed2 => findFoldWithCont(pred, procWithCont, seed2, ys, cont))
    case _ =>
      findFoldWithCont(pred, procWithCont, seed, xs.tail, cont)
  }
}

実行。

scala> findFoldWithCont[Int](odd, processWithCont, Nil, tenList, xs => { println(xs); xs })
found: 1
found: 3
found: 5
found: 7
found: 9
List(9, 7, 5, 3, 1)
res40: List[Int] = List(9, 7, 5, 3, 1)

res40の手前にあるListの出力が、第5引数の無名関数でprintlnしている部分ですね。関数の結果自体も、この無名関数の引数そのものになっています。

なかなか長かったですが、けっこう面白かったです。まあ、継続の勉強としてはまだまだ初歩なんですけどね…。

一応、全コードを貼っておきます。
cont.scala

import scala.annotation.tailrec

val tenList = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
def odd(x: Int) = x % 2 != 0

@tailrec
def findFold[T](pred: T => Boolean,
                proc: (T, List[T]) => List[T],
                seed: List[T],
                xs: List[T]): List[T] = {
  xs match {
    case Nil => seed
    case y :: ys if pred(y) =>
      val seed2 = proc(y, seed)
      findFold(pred, proc, seed2, ys)
    case _ => findFold(pred, proc, seed, xs.tail)
  }
}

def process[T](elt: T, seed: List[T]): List[T] = {
  println("found: " + elt)
  elt :: seed
}

def findFold2[T](pred: T => Boolean,
                 procWithCont: (T, List[T], (List[T] => List[T])) => List[T],
                 seed: List[T],
                 xs: List[T]): List[T] = {
  xs match {
    case Nil => seed
    case y :: ys if pred(y) =>
      procWithCont(y, seed, seed2 => findFold2(pred, procWithCont, seed2, ys))
    case _ => findFold2(pred, procWithCont, seed, xs.tail)
  }
}

def processWithCont[T](elt: T, seed: List[T], cont: List[T] => List[T]): List[T] = {
  println("found: " + elt)
  cont(elt :: seed)
}

var next: (() => List[_]) = () => Nil
def break[T](value: List[T]) = value

def processWithBreak[T](elt: T, seed: List[T], cont: List[T] => List[T]): List[T] = {
  next = () => {
    println("found: " + elt)
    cont(elt :: seed)
  }
  break(Nil)
}

def findFoldWithCont[T](pred: T => Boolean,
                        procWithCont: (T, List[T], (List[T] => List[T])) => List[T],
                        seed: List[T],
                        xs: List[T],
                        cont: List[T] => List[T]): List[T] = {
  xs match {
    case Nil => cont(seed)
    case y :: ys if pred(y) =>
      procWithCont(y, seed, seed2 => findFoldWithCont(pred, procWithCont, seed2, ys, cont))
    case _ =>
      findFoldWithCont(pred, procWithCont, seed, xs.tail, cont)
  }
}