不動点演算子と再帰呼び出しのアスペクト

動機

再帰呼び出しにメモ化がかかってほしい。

概要

Scalaでは、高階関数と遅延評価が使えるので
不動点演算子を実装することができる。
自分で再帰を書く代わりに不動点演算子を使う書き方を試み
再帰呼び出しのメモ化ができることを確認する。

方針

不動点演算子とは、関数を別の関数に移すファンクターをとり、
その最小不動点を与える関数である。
参考: どう書く.org
http://ja.doukaku.org/150/nested/

再帰関数を、ファンクターの最小不動点として抽象化する。
不動点演算子をmixinで拡張することで、再帰呼び出しをインクリメンタルに拡張できるようにする。

再帰で実装

例題はフィボナッチ数列のn番目を求める関数。
不動点演算子を考える前に、まずは再帰で実装。
その後継承でメモ化してみる。

参考: (Javaの)継承で再帰呼び出しを書き換える過去ログ
http://d.hatena.ne.jp/kya-zinc/20090506/1241606381

package fix
import scala.collection.mutable

object FibMain {
  def main(args : Array[String]) : Unit = {
    println(new Fib()(6))
    println(new MemoFib()(60))
  }
}

class Fib extends (Int => BigInt){
  override def apply(n : Int) : BigInt = {
    if(n <= 1) 1 else apply(n-1) + apply(n-2)
  }
}

class MemoFib extends Fib {
  val map : mutable.Map[Int,BigInt] = mutable.Map()
  override def apply(n : Int) : BigInt = {
    if(!map.contains(n)) map += (n -> super.apply(n))
    map(n)
  }
}

メモ化で(ハッシュ関数が理想的なら)線形時間で計算できるようになる。
ただ、継承ではメモ化の実装を他の関数と共有できないのが問題。

不動点演算子で実装

fixを、FixpointFunctionクラスのパラメータ
fmapの不動点として実装する。
不動点が関数なので(A => B)をオーバーライドする。

その後、フィボナッチ関数をつくるファンクターfmapを与える。

package fix

object FixMain {
  import FixpointFunction.fix
  def main (args : Array[String]) : Unit = {
    val fmap : (Int => BigInt) => Int => BigInt = {
      f => x =>
        if (x <= 1) 1 else f(x-1) + f(x-2)
    }

    //Simple fixpoint
    val fib = fix(fmap)
    
    for(i <- 0 to 8){
      println(fib(i))
    }
  }
}

class FixpointFunction[A,B](fmap : (A => B) => A => B) extends (A => B){
  override def apply(x : A) = fix(fmap)(x)
  
  def fix(fmap : (A => B) => A => B)(x : A) : B = 
    fmap(fix(fmap))(x)
  
  override def toString = "<fixpoint>"
}

object FixpointFunction {
  def fix[A,B](fmap : (A => B) => A => B) = new FixpointFunction(fmap)
}

実行結果

1
1
2
3
5
8
13
21
34

fmapの不動点としてフィボナッチ数列を計算する関数が実装された。
これで、自分で再帰を書かなくても
FixpointFunctionが再帰呼び出しをしてくれるよ!

拡張

Fixpoint.fixをオーバーライドするtraitをmixinすることで
挙動を変えることができる。
再帰呼びだしの前後にprintするTraceFixpointFunctionと
メモ化の実装MemoFixpointFunctionを定義し、
フィボナッチ数列を計算する関数にmixinする。

package fix
import collection.mutable

object FixMain2 {
  def main(args : Array[String]) : Unit = {
    val fmap : (Int => BigInt) => Int => BigInt = {
      f => x =>
        if (x <= 1) 1 else f(x-1) + f(x-2)
    }

    //Mixin trace!
    class Fib2 extends FixpointFunction[Int,BigInt](fmap) with
      TraceFixpointFunction[Int,BigInt]{
      override def toString = "Fib2"
    }
    val fib2 = new Fib2()

    //Mixin memoize!
    class Fib3 extends Fib2 with MemoFixpointFunction[Int,BigInt]{
      override def toString = "Fib3"
    }
    val fib3 = new Fib3()

    println("====fib2====")
    println(fib2(5))
    println("====fib3====")
    println(fib3(60))
  }
}

trait TraceFixpointFunction[A,B] extends FixpointFunction[A,B] {
  override def fix(fmap : (A => B) => A => B)(x : A) : B = {
    println(toString + " called with " + x)
    val result = super.fix(fmap)(x)
    println(toString + " return " + result)
    result
  }
}

trait MemoFixpointFunction[A,B] extends FixpointFunction[A,B]{
  val history = mutable.Map[A,B]()
  
  override def fix(fmap : (A => B) => A => B)(x : A) : B = {
    if(!history.contains(x)) history += (x -> super.fix(fmap)(x))
    history(x)
  }
}

実行結果

====fib2====
Fib2 called with 5
Fib2 called with 4
Fib2 called with 3
Fib2 called with 2
Fib2 called with 1
Fib2 return 1
Fib2 called with 0
Fib2 return 1
Fib2 return 2
Fib2 called with 1
(中略)
Fib2 return 8
8
====fib3====
Fib3 called with 60
Fib3 called with 59
(中略)
Fib3 return 1548008755920
Fib3 return 2504730781961
2504730781961

Fib3にはmixinで両方の機能が追加できた。
一般的に、再帰呼び出しアスペクトとして
何か処理を挟みたい場合に使える。