動機
GLSLで
vec3 shade(...){ if(material == glass){ R * shade(reflection) + (1-R) * shade(refraction); } }
みたいなシェーダを書こうとして、「あれ、GLSLで再帰関数って書けないの……?」となり、whileで書く方法を調べていたところ、上のような関数呼び出しの後でいろいろ処理して値を返すような一般の関数に直接応用できる手法をなかなか見つけられず、さらにいろいろ調べたところ、コールスタックっぽいことをするとうまくいくようなので、書いてみようと思いました。
コールスタック及び呼び出し規約周りの知識があるとより分かりやすいかもしれません。
この投稿(英語)がかなりわかりやすいですが、本稿では複数の再帰を持つ、より一般的な例を扱っています。
前処理
①変換元
似たようなプログラムをpythonで書きました。
def f(x): if x == 0: return 1 elif x == 1: return 2 * f(x-1) + f(-(x-1)) else: return 2 * f(x+1) * f(-(x-1))
②まずreturn hoge
を ans = hoge; return ans
にします。
def f(x): if x == 0: ans = 1 return ans elif x == 1: ans = 2 * f(x-1) + f(-(x-1)) return ans else: ans = 2 * f(x+1) * f(-(x-1)) return ans
③次に関数呼び出し部を独立させます。
(コンパイラの中間表現を知っている方は ret = Call(f, x)
をイメージするとわかりやすいかもしれません)
def f(x): if x == 0: ans = 1 return ans elif x > 0: ret1 = f(x-1) ans = 2 * ret1 ret2 = f(-(x-1)) ans += ret2 return ans else: ret1 = f(x+1) ans = 2 * ret1 ret2 = f(-(x+1)) ans += ret2 return ans
④エントリーポイントにラベルL0
を、それ以降の関数呼び出し地点の直後に順にL1, L2, ...
をつけます:
def f(x): #L0 if x == 0: ans = 1 return ans elif x > 0: ret1 = f(x-1) #L1 ans = 2 * ret1 ret2 = f(-(x-1)) #L2 ans += ret2 return ans else: ret1 = f(x+1) #L3 ans = 2 * ret1 ret2 = f(-(x+1)) #L4 ans += ret2 return ans
⑤これをgoto文に書き換えて、基本ブロックを独立させます:
def f(x): #L0 if x == 0: ans = 1 return ans elif x > 0: ret1 = f(x-1) goto L1 else: ret1 = f(x+1) goto L2 #L1 ans = 2 * ret1 ret2 = f(-(x-1)) goto L2 #L2 ans += ret2 return ans #L3 ans = 2 * ret1 ret2 = f(-(x+1)) #L4 ans += ret2 return ans
⑥gotoをwhileに書き換えます。returnはまだそのままにします。
pc
という、「現在プログラム中のどの部分を処理しているか」の状態を表す変数で管理しています。
def f(x): pc = 0 # program counter while(True): if(pc == 0): #L0 if x == 0: ans = 1 return ans elif x > 0: ret1 = f(x-1) pc = 1 else: ret1 = f(x+1) pc = 1 elif (pc == 1): #L1 ans = 2 * ret1 ret2 = f(-(x-1)) pc = 2 elif (pc == 2): #L2 ans += ret2 return ans elif (pc == 3): #L3 ans = 2 * ret1 ret2 = f(-(x+1)) pc = 4 else: #L4 ans += ret2 return ans
メイン処理
ここからが本番です。return
とは「元の関数呼び出し文脈へのreturn」ですが、どうやってreturnするんやという問題を解決しなければなりません。
これを、コールスタック的な概念を導入することで解決します。アセンブリというかCPUの処理を模倣するようなイメージです。
ざっくりいうと、
- 前の関数呼び出しの結果を格納する変数
ret
を用意する。 - 関数呼び出しのときは
- スタックに、次に飛びたいラベルの番号、及び現在の(生きている)変数の値たちを積む。
- 設計上は「どこかの関数呼び出し地点の前後で生きていうる変数すべて」としたほうがシンプルなので、その方針にします。
- そのうえで関数引数をセットして、
- pcを0にする(+while文の先頭に戻る)ことで「関数呼び出し」が成立する。
- スタックに、次に飛びたいラベルの番号、及び現在の(生きている)変数の値たちを積む。
return
は- retに返したい値を代入する。
- スタックがあるなら、
- 呼び出し元で飛びたいと思っていたラベルの番号、及び呼び出し元での変数の値をスタックから取り出す。
- pc及び変数の値をセットすることで、元の「関数呼び出し文脈」に帰れる。
- スタックがないなら、
return ret
することで、"大元の関数fの値を返す"ようにする。
初めに「生きていうる変数」の一覧を作ります。④を見ると、xに加えansも生きている変数になりうることがわかります。
まずL0の関数呼び出しret1 = f(x-1); pc = 1
を変換します。上の手順に従うと、
stack.append([1, x, ans]) x = x - 1 pc = 0
次にL2のans += ret2; return ans
を変換します:
ans += ret ret = ans if len(stack) == 0: return ans else: pc, x, ans = stack.pop()
同様に変換すれば、次のような⑦最終形が得られます:
def g(x): ans = 0 # temporary variable ret = 0 # return value pc = 0 # program counter stack = [] stack.append([pc, x, ans]) while(True): if(pc == 0): if(x == 0): ret = 1 if len(stack) == 0: return ret else: pc, x, ans = stack.pop() elif(x > 0): stack.append([1, x, ans]) x = x - 1 pc = 0 else: stack.append([3, x, ans]) x = x + 1 pc = 0 elif pc == 1: ans = 2 * ret stack.append([2, x, ans]) x = -(x - 1) pc = 0 elif pc == 2: ans += ret ret = ans if len(stack) == 0: return ret else: pc, x, ans = stack.pop() elif pc == 3: ans = 2 * ret stack.append([4, x, ans]) x = -(x + 1) pc = 0 elif pc == 4: ans += ret ret = ans if len(stack) == 0: return ret else: pc, x, ans = stack.pop()
ひょっとすると、①→③→④→⑦くらいに飛ばすほうがわかりやすいかもしれません。
これを使うと相互再帰も同様の手順で書けると思われます。
自動変換するプログラムとか書けそう。
余談
なお、もう少し厳密に「呼び出し規約」を書き下すと、次のようになります:
- 「スタック」を用意する。プログラム中では
リスト
で実装している。 - 「レジスタ」をいくつか用意する。これはプログラム中では
変数
である。 - 「関数呼び出し」(
ret1 = f(x-1)
等)の手順は次の通り。- 「リターンアドレス」と「今生きている変数」
x
,ans
を「スタック」に積んでおく。- 「リターンアドレス」は、次に向かうべき
ラベル
になる。
- 「リターンアドレス」は、次に向かうべき
- 「引数」を引数格納用レジスタ
x
に入れる。 - 「呼び出し」を行うため「プログラムカウンタ」を関数の先頭にセットする。
- 「リターンアドレス」と「今生きている変数」
- 値を返す(returnのとき)ときは、
なので、本当は関数呼び出しの時点で、引数レジスタにあるxの値をcallee saveな汎用レジスタに移しておかなければならないのだろうけれど、それをすると余計ややこしくなりそう…