読者です 読者をやめる 読者になる 読者になる

YAMAGUCHI::weblog

土足で窓から失礼いたします。今日からあなたの息子になります。 当年とって92歳、下町の発明王、エジソンです。

n個の対応する括弧のパターン(続き)

はじめに

こんにちは、OCaml界の桑原茂一です。id:nishiohirokazutmp.pyの解説という記事を書いていて、「あ、そういえば前これ俺が問題出されたやつだった」というのを思い出しました。自分の方のブログの記事を見直してみたらメモ化をしてないのですごく遅いまま。というわけでメモ化をするようにコードを書き換えてみました。

関連リンク

前回書いたリンクはこちら。n=1 なら ()、 n=2 なら ()(), (()) という具合に開き括弧がn個で考えられるパターンは全部で何パターンあるでしょうか?という問題。

今回のGistはこちら

測定環境

CPU Intel Core 2 Duo L7500 1.60GHz
RAM 2.97GB
OS Windows XP SP3 (Cygwin 1.7.1)
OCaml 3.12.0

書いたコード

元のコード

まずは元のコードはこれ。

let map f l = List.rev (List.rev_map f l);;
let enclose x = "(" ^ x ^ ")";;
let cross lhs rhs =
  let rec cross_aux accu = function
    | [] -> accu
    | x::xs -> cross_aux (List.fold_left (fun a y -> (x^y)::a) accu rhs) xs
  in
  cross_aux [] lhs
;;

let rec solution = function
  | 0 -> []
  | n when n < 0 -> prerr_int n; invalid_arg "solution"
  | n -> List.rev_append (atomic n) (separate n)
and atomic = function
  | 1 -> ["()"]
  | n -> map enclose (solution (n-1))
and separate n =
  let rec separate_aux accu = function
    | 0 -> accu
    | k -> separate_aux (List.rev_append (cross (atomic k) (solution (n-k))) accu) (pred k)
  in
  separate_aux [] (n-1)

n=10くらいまでならまあ何とか表示してくれるんだけど、n=14とかになるともう全然処理が終わらない。やっぱり毎回パターンを作っているのが問題なんだと思う。というわけでメモ化でそれを改善してみます。

メモ化したコード

メモ化に際してHashtblを使いました。

メモ化、と言っても今回やっているのは単純に括弧数とそれに対応する組み合わせのリストのペアを作るだけです。つまりこんなイメージ:

1 : ['()']
2 : ['(())'; '()()']
3 : ['((()))'; '(()())'; '()(())'; '()()()'; '(())()']
...

で、単純にハッシュテーブルにあったらそれを使うし、なかったら作って、ハッシュテーブルに突っ込んでから返すというだけ。つまりこんな感じに書きかえる。

if Hashtbl.mem table n
then Hashtbl.find table n
else
  begin
    let v = someprocess n in
    Hashtbl.add table n v;
    v
  end

で、メモ化する範囲を「結果全体」にしたものと「ひと固まりのものとそうでないもの(eg, '()(())')」の2つの範囲に分けた場合でそれぞれsolver, fast_solverという名前で作ってみました。
とりあえず準備段階の部分は上のコードと共通なので、solverの部分だけを切り取ってます。全体が見たい人はGistのコード見てください。

(**
  val solution : int -> string list
  val atomic : int -> string list
  val separate : int -> string list
*)
let solver table n = 
  let rec solution n =
    match n with
    | 0 -> []
    | n when n < 0 -> prerr_int n; invalid_arg "solution"
    | n ->
        if Hashtbl.mem table n
        then Hashtbl.find table n
        else
          begin
            let v = List.rev_append (atomic n) (separate n) in
            Hashtbl.add table n v;
            v
          end
  and atomic n =
    match n with
    | 1 -> ["()"]
    | n -> map enclose (solution (n-1))
  and separate n =
    let rec separate_aux accu = function
      | 0 -> accu
      | k -> separate_aux (List.rev_append (cross (atomic k) (solution (n-k))) accu) (pred k)
    in
    separate_aux [] (n-1)
  in
  solution n
;;


let fast_solver tbl atbl stbl n =
  let rec solution n =
    match n with
    | 0 -> []
    | n when n < 0 -> prerr_int n; invalid_arg "solution"
    | n ->
        if Hashtbl.mem tbl n
        then Hashtbl.find tbl n
        else
          begin
            let v = List.rev_append (atomic n) (separate n) in
            Hashtbl.add tbl n v;
            v
          end
  and atomic n =
    match n with
    | 1 -> ["()"]
    | n -> 
        if Hashtbl.mem atbl n
        then Hashtbl.find atbl n
        else
          begin
            let v = map enclose (solution (n-1)) in
            Hashtbl.add atbl n v;
            v
          end
  and separate n =
    let rec separate_aux accu = function
      | 0 -> accu
      | k -> separate_aux (List.rev_append (cross (atomic k) (solution (n-k))) accu) (pred k)
    in
    if Hashtbl.mem stbl n
    then Hashtbl.find stbl n
    else
      begin
        let v = separate_aux [] (n-1) in
        Hashtbl.add stbl n v;
        v
      end
  in
  solution n
;;

n=14でsolverとfast_solverを実行するとそれぞれこのくらいの時間になる。

  • solver
$ time ./paren.exe
*** test start ***
2674440

real    0m6.599s
user    0m6.358s
sys     0m0.218s
  • fast_solver
$ time ./paren.exe
*** fast test start ***
2674440

real    0m4.099s
user    0m3.983s
sys     0m0.124s

でもしかし、

前回のコードをみたid:nishiohirokazuがこんなことを最後に言ってたんですよね。

あー、これ、OCamlは遅延評価じゃないし純粋でもないよね?解は出ているけど再利用はできてないんじゃないかな?solution 1とかを何度も何度も呼んでない?Pythonでn=12で0m0.888s, n=14で0m3.099sなので、OCamlでそれより遅いのはおかしいんだよ。

まだあと1秒も差があるなあ。うーむ、どうしたもんかな...

追記

Python忘年会で確認したら id:nishiohirokazu のPCはCore 2 Duoの2.4GHzだったみたい。改めて自分の家でMBP 15''(Intel Core i7, 2.66GHz)でfast_solverを計算させてみたらこんな感じになった。

$ time ./paren 
*** fast test start ***
2674440

real	0m2.134s
user	0m1.934s
sys	0m0.195s

まあCPU性能の差も出ているとはいえ、2秒は切ってほしいなあ。

追記2

キャミバ様が僕を木人形にしてくれました!

id:camlspotter添削有難うございました!

コンパイラ

エントリの中で質問がありましたので回答します。

おい、お前、お前の使ったコンパイラの名前を言ってみろ!
ocamlopt でコンパイルしたんだろうな?

もちろんです!

$ ocamlopt -o paren paren.ml
文字列のわけ

しかし… 今気づいたんだが、なんで括弧で溢れる文字列を作る必要があるんだ?!

最初ちょろちょろと括弧を見ながらprintfデバッグしてたのをそのまま使ってたなんてとても言えない!

うわらば

肝に銘じておきます!!

  • まず問題を良く読み、アルゴリズムを再検討しよう
  • 末尾再帰に気をつけよう
  • 無駄なデータの計算は避けよう
  • リスト系再帰関数を書いたら、fold かどうか確認しよう。大体が fold で綺麗に書けるはず
  • OCaml は自動的なコード最適化はほとんど行わない。それでも十分早い!さらに、工夫すればするだけスピードが上がり、そしてそのゲインは大体予想できる
  • 関数型パラダイムの元で破壊的操作は十分に統御可能。OCaml では purity にこだわるな!
  • 境界問題が発生しない場合は unsafe 操作も恐れずに使用しよう
  • 瞬間的にスピードが必要な場合、GC を一時的に止める事も考えよう