memoizeでPythonの再帰計算をキャッシュして高速化
背景
ここ数ヶ月,頭の体操とPythonでの数値処理に慣れるのとで,project eulerの問題を順に解いていってます*1.割と初等整数論を使った計算問題が多いので,単純に公式を調べてなるほどこういう公式があるんだーとか感心しながら計算してます*2.
で,その際によくあるのがProblem 3 - Project Eulerのように,3桁とかであれば簡単に手計算できるものについてこれが20桁だったときどうするか,という風に,計算量をちゃんと考えないと実時間で計算が終わらない問題です.そんなときによくやるのが,繰り返し使う値をキャッシュするってやつです.今回はこれについてのお話.
単純なキャッシュ
とりあえずPythonで簡単にキャッシュするだけなら,グローバルにリストを定義して,ひたすらそこに追加すればOKです.例えば100万以下の素数をすべて計算する場合,以下のようにprimesリストを最初に作っておいて,新しい素数が見つかったらキャッシュに追加しておけばOKです*3.
primes = [2] def is_prime(n): # すでに素数を探索している範囲では,素数リストで割り切れるかだけチェックすればOK for p in primes: if n % p == 0: return False # 素数判定の場合,対象の数の平方根以下の整数までしかチェックする必要はない for i in range(max(primes)+1, int(math.sqrt(n))+1): if n % i == 0: return False return True for i in range(2, 1000000): if is_prime(i): primes.append(i) print primes
とはいえ,この形だとグローバルに変数持ってるし,あんまり汎用性もなくてしょんぼりすることが多いです.
デコレータを使ってキャッシュ実装
デコレータとクロージャ
上記のようなキャッシュのことをmemoization*4.といいます.で,Pythonの場合デコレータ+クロージャを使ってこれを汎用的に実装可能です.以下の記事がよくまとまっているので読んでください.
で,このクロージャを使って,キャッシュ関数を付け加えましょうという話です.具体例としてProblem 76 - Project Eulerをみてみます.これは分割数についての問題で,補助関数を使うことで再帰的に解くことができます.具体的なコードは以下のとおり.
def p(k, n): if k > n: return 0 elif k == n: return 1 return p(k+1, n)+p(k, n-k) def calc(n): return sum([p(k, n-k) for k in range(1, int(n/2)+1)]) calc(100)
memoize関数の実装
これをまともに実行すると,16分くらいかかってしまいます.再帰処理の引数が2変数のtupleになっているため,組み合わせ分だけ再帰数が増加してしまうためです.これに対して以下のmemoize関数を当てはめます.memoize関数は,キャッシュ用のdictionaryを外部スコープに持つhelper()関数を戻します.このhelper関数は,(x, y)がすでにキャッシュにキーとして存在する場合には,それ以降を計算せずにキャッシュを返し,まだ存在しない場合のみf(x, y)を返します.
def memoize(f): cache = {} def helper(x, y): if (x, y) not in cache: cache[(x, y)] = f(x, y) return cache[(x, y)] return helper
これをデコレータとして用いることで,計算を高速化できます.ということで以下のようにすると,calc()からp(k, n)が呼ばれるたびに,memoize()によってキャッシュ機構がラップされたp(k, n)が返されるため,あとはキャッシュを用いて計算を行うことができます.これで実行時間が0.01秒程度まで短縮されます*5.
def memoize(f): cache = {} def helper(x, y): if (x, y) not in cache: cache[(x, y)] = f(x, y) return cache[(x, y)] return helper @memoize def p(k, n): if k > n: return 0 elif k == n: return 1 return p(k+1, n)+p(k, n-k) def calc(n): return sum([p(k, n-k) for k in range(1, int(n/2)+1)]) calc(100)
*1:年内には,もう少しちゃんとした紹介記事みたいなのを書きたいです.最初は100問解いたら書こうと思ってましたが,年内にはおわらなさそう...
*2:実際のところ,問題が当てはまる理論部分がわかれば,あとは数式をコードに落とし込むだけのことも多かったり.
*3:primesリストを空にしてないのは,is_prime()内のforループで空リストを呼ぶとエラーで落ちるからです.
*4:memorizationではないです.私自身はrが入ってるものだと勝手に思い込んで,ググっても全然引っかからないどういうことだろうとかしばしハマりました.
*5:つまり,ほとんどが同じ再帰処理を何度も繰り返しているだけだった,ということですね.
*6:別にargsでなくとも,変数に*をつけてあればなんでもOKです.さらに**と2つ重ねると辞書型の可変引数になります.