再帰関数をwhileループに変換する

動機

GLSLで

vec3 shade(...){
  if(material == glass){
    R * shade(reflection) + (1-R) * shade(refraction);
   }
}

みたいなシェーダを書こうとして、「あれ、GLSLで再帰関数って書けないの……?」となり、whileで書く方法を調べていたところ、上のような関数呼び出しの後でいろいろ処理して値を返すような一般の関数に直接応用できる手法をなかなか見つけられず、さらにいろいろ調べたところ、コールスタックっぽいことをするとうまくいくようなので、書いてみようと思いました。

コールスタック及び呼び出し規約周りの知識があるとより分かりやすいかもしれません。

この投稿(英語)がかなりわかりやすいですが、本稿では複数の再帰を持つ、より一般的な例を扱っています。

前処理

①変換元

似たようなプログラムをpythonで書きました。

def f(x):
    if x == 0:
        return 1
    elif x == 1:
        return 2 * f(x-1) + f(-(x-1))
    else:
        return 2 * f(x+1) * f(-(x-1))

②まずreturn hogeans = hoge; return ans にします。

def f(x):
    if x == 0:
        ans = 1
        return ans
    elif x == 1:
        ans = 2 * f(x-1) + f(-(x-1))
        return ans
    else:
        ans = 2 * f(x+1) * f(-(x-1))
        return ans

③次に関数呼び出し部を独立させます。 (コンパイラの中間表現を知っている方は ret = Call(f, x) をイメージするとわかりやすいかもしれません)

def f(x):
    if x == 0:
        ans = 1
        return ans
    elif x > 0:
        ret1 = f(x-1)
        ans = 2 * ret1
        ret2 = f(-(x-1))
        ans += ret2
        return ans
    else:
        ret1 =  f(x+1)
        ans = 2 * ret1
        ret2 = f(-(x+1))
        ans += ret2
        return ans

④エントリーポイントにラベルL0を、それ以降の関数呼び出し地点の直後に順にL1, L2, ...をつけます:

def f(x):
    #L0
    if x == 0:
        ans = 1
        return ans
    elif x > 0:
        ret1 = f(x-1)
        #L1
        ans = 2 * ret1
        ret2 = f(-(x-1))
        #L2
        ans += ret2
        return ans
    else:
        ret1 =  f(x+1)
        #L3
        ans = 2 * ret1
        ret2 = f(-(x+1))
        #L4
        ans += ret2
        return ans

⑤これをgoto文に書き換えて、基本ブロックを独立させます:

def f(x):
#L0
    if x == 0:
        ans = 1
        return ans
    elif x > 0:
        ret1 = f(x-1)
        goto L1
    else:
        ret1 =  f(x+1)
        goto L2
#L1
        ans = 2 * ret1
        ret2 = f(-(x-1))
        goto L2
#L2
        ans += ret2
        return ans
#L3
        ans = 2 * ret1
        ret2 = f(-(x+1))
#L4
        ans += ret2
        return ans

⑥gotoをwhileに書き換えます。returnはまだそのままにします。 pc という、「現在プログラム中のどの部分を処理しているか」の状態を表す変数で管理しています。

def f(x):
    pc = 0    # program counter
    while(True):
        if(pc == 0):
        #L0
            if x == 0:
                ans = 1
                return ans
            elif x > 0:
                ret1 = f(x-1)
                pc = 1
            else:
                ret1 =  f(x+1)
                pc = 1
        elif (pc == 1):
        #L1
                ans = 2 * ret1
                ret2 = f(-(x-1))
                pc = 2
        elif (pc == 2):
        #L2
                ans += ret2
                return ans
        elif (pc == 3):
        #L3
                ans = 2 * ret1
                ret2 = f(-(x+1))
                pc = 4
        else:
        #L4
                ans += ret2
                return ans

メイン処理

ここからが本番です。return とは「元の関数呼び出し文脈へのreturn」ですが、どうやってreturnするんやという問題を解決しなければなりません。 これを、コールスタック的な概念を導入することで解決します。アセンブリというかCPUの処理を模倣するようなイメージです。

ざっくりいうと、

  • 前の関数呼び出しの結果を格納する変数retを用意する。
  • 関数呼び出しのときは
    • スタックに、次に飛びたいラベルの番号、及び現在の(生きている)変数の値たちを積む。
      • 設計上は「どこかの関数呼び出し地点の前後で生きていうる変数すべて」としたほうがシンプルなので、その方針にします。
    • そのうえで関数引数をセットして、
    • pcを0にする(+while文の先頭に戻る)ことで「関数呼び出し」が成立する。
  • return
    • retに返したい値を代入する。
    • スタックがあるなら、
      • 呼び出し元で飛びたいと思っていたラベルの番号、及び呼び出し元での変数の値をスタックから取り出す。
      • pc及び変数の値をセットすることで、元の「関数呼び出し文脈」に帰れる。
    • スタックがないなら、return retすることで、"大元の関数fの値を返す"ようにする。

初めに「生きていうる変数」の一覧を作ります。④を見ると、xに加えansも生きている変数になりうることがわかります。

まずL0の関数呼び出しret1 = f(x-1); pc = 1を変換します。上の手順に従うと、

stack.append([1, x, ans])
x = x - 1
pc = 0

次にL2のans += ret2; return ansを変換します:

ans += ret
ret = ans
if len(stack) == 0:
    return ans
else:
    pc, x, ans = stack.pop()

同様に変換すれば、次のような⑦最終形が得られます:

def g(x):
    ans = 0        # temporary variable 
    ret = 0         # return value
    pc = 0         # program counter
    
    stack = []
    stack.append([pc, x, ans])
    while(True):
        if(pc == 0):
            if(x == 0):
                ret = 1
                if len(stack) == 0:
                    return ret
                else:
                    pc, x, ans = stack.pop()
            elif(x > 0):
                stack.append([1, x, ans])
                x = x - 1
                pc = 0
            else:
                stack.append([3, x, ans])
                x = x + 1
                pc = 0
        elif pc == 1:
            ans = 2 * ret
            stack.append([2, x, ans])
            x = -(x - 1)
            pc = 0
        elif pc == 2:
            ans += ret
            ret = ans
            if len(stack) == 0:
                return ret
            else:
                pc, x, ans = stack.pop()
        elif pc == 3:
            ans = 2 * ret
            stack.append([4, x, ans])
            x = -(x + 1)
            pc = 0
        elif pc == 4:
            ans += ret
            ret = ans
            if len(stack) == 0:
                return ret
            else:
                pc, x, ans = stack.pop()

ひょっとすると、①→③→④→⑦くらいに飛ばすほうがわかりやすいかもしれません。

これを使うと相互再帰も同様の手順で書けると思われます。

自動変換するプログラムとか書けそう。

余談

なお、もう少し厳密に「呼び出し規約」を書き下すと、次のようになります:

  • 「スタック」を用意する。プログラム中ではリストで実装している。
  • レジスタ」をいくつか用意する。これはプログラム中では変数である。
    • 「プログラムカウンタ」用のレジスタpcに加え、
    • 「返り値」の格納用レジスタ retRISC-Vでいうa0)
    • 「引数」の格納用レジスタxRISC-Vでいう a1, a2, ...。RISC-Vだとa0と共用だが混乱を避けるため別にしておく)
    • 「関数呼び出し時に生きている変数」を保持しておくための汎用レジスタx, ans(t0, t1, ...)
  • 「関数呼び出し」(ret1 = f(x-1)等)の手順は次の通り。
    • 「リターンアドレス」と「今生きている変数」x, ans を「スタック」に積んでおく。
      • 「リターンアドレス」は、次に向かうべきラベルになる。
    • 「引数」を引数格納用レジスタxに入れる。
    • 「呼び出し」を行うため「プログラムカウンタ」を関数の先頭にセットする。
  • 値を返す(returnのとき)ときは、
    • ret に値を格納する
    • スタックから生きている変数の値x, ansを回収し現状復帰する
      • レジスタ割り当てが面倒なので全部の呼び出し地点での生きている変数を全部保持しておくことにする
      • saveはcallerだがrestoreはcallee
      • スタックがない場合は"もともとの関数fが値を返す"場合に相当するためreturn retする。
    • スタックから回収したプログラムカウンタの値をpcレジスタにセットして呼び出し元文脈に戻る。

なので、本当は関数呼び出しの時点で、引数レジスタにあるxの値をcallee saveな汎用レジスタに移しておかなければならないのだろうけれど、それをすると余計ややこしくなりそう…