スタックオーバーフローしない相互再帰

概要

相互再帰(互いに呼びだし合う再帰)が
末尾再帰(再帰呼び出しが末尾のみに出現)の場合、
ループでスタックを積まずに実行することができる。
このための高階関数を用意する。

参考: 相互再帰
http://ja.wikipedia.org/wiki/相互再帰

相互再帰の例: Intの偶数、奇数判定
package mut

object Even {
  def main(args : Array[String]) : Unit = {
    println(isEven(100))
    println(isOdd(100))
    println(isEven(101))
    println(isOdd(101))
    println(isEven(10000))
    println(isOdd(10000))
  }
  
  def isEven(n : Int) : Boolean = {
    if (n < 0) { throw new IllegalArgumentException }
    if (n == 0) true else isOdd(n-1)
  }
  
  def isOdd(n : Int) : Boolean = {
    if (n < 0) { throw new IllegalArgumentException }
    if (n == 0) false else isEven(n-1)
  }
}

自然数n > 1が偶数であることは、(n-1)が奇数であることと同値。
自然数n > 1が奇数であることは、(n-1)が偶数であることと同値。
という方針。
ただし、メソッド呼び出しは簡単にスタックオーバーフローする。

方針

定義域がA型の関数fと、B型の関数gが、相互再帰
C型の値を返す関数として実装する。

fを、C型の結果を返すかgに渡すためのB型の値を作る関数
gを、C型の結果を返すかfに渡すためのA型の値を作る関数
として、Either型を返す関数として定義する。

相互再帰呼び出しの代わりに、fとgをループで交互に実行する。
C型が返ってきたら終了。

Either型は、RightかLeftがパターンマッチで判定できるdisjoint sumなので、
A = CやB = CのときでもOK.

実装

(A => C).applyをオーバーライドして実装。
もう片方の関数B => Cも使えるようにanotherをメソッドとして用意する。

単純にwhile(true) + パターンマッチ。
whileの後でRuntimeExceptionを投げているのは、
ブロックの最後にC型の値を書かないとコンパイルが通らなかったから。これはJavaと違う挙動。

package mut

object MutRecMain{
  def main(args : Array[String]) : Unit = { 
    val f : Int => Either[Int,Boolean] = { n : Int =>
      if(n < 0) throw new IllegalArgumentException
      if(n == 0) Right(true) else Left(n-1) 
    }
    val g : Int => Either[Int,Boolean] = { n : Int =>
      if(n < 0) throw new IllegalArgumentException
      if(n == 0) Right(false) else Left(n-1) 
    }
    
    val isEven = new MutRec(f,g)
    val isOdd = isEven.another(_)
    println("0 is even: " + isEven(0))
    println("0 is odd: " + isOdd(0))
    println("1 is even: " + isEven(1))
    println("1 is odd: " + isOdd(1))
    println("10000000 is even: " + isEven(10000000))
    println("10000000 is odd: " + isOdd(10000000))
  }
}

class MutRec[A,B,C](f : A => Either[B,C], g : B => Either[A,C]) extends (A => C){  
  def apply(a : A) : C = {
    var a2 = a
    while(true){
      f(a2) match {
        case Right(c) => return c
        case Left(b) => {
          g(b) match {
            case Right(c) => return c
            case Left(a) => a2 = a
          }
        }
      }
    }
    throw new RuntimeException
  }

  def another (b : B) : C = g(b) match {
    case Right(c) => c
    case Left(a) => apply(a)
  }  
}

実行結果

0 is even: true
0 is odd: false
1 is even: false
1 is odd: true
10000000 is even: true
10000000 is odd: false

引数が10000000でも大丈夫。

言うまでもないことですが、効率化のためには剰余を使いましょう。

val isEven = (n : Int) => n % 2 == 0