1
/
5

要素数が2の冪でないセグメント木でも,再帰を使わずに二分探索できる

Photo by Raychan on Unsplash

 本記事では,要素数が2の冪とは限らないようなセグメント木上の二分探索の,再帰を使わない実装を例示し,それに準じる実装の正しさを定理証明支援系 Coq を用いて検証します.これによって,より高速な二分探索の実装を,信頼性を損なわずに得られることでしょう.

 前回の記事を書くために要素数が2の冪とは限らないようなセグメント木の実装について色々調べていたのですが,競技プログラマーの間では要素数が2の冪かどうかより,再帰を使わずに各種操作を実装できるかどうかの関心の方が高いように感じました.

 関数型言語では息を吐くように再帰を書くので忘れがちになりますが,関数呼び出しはレジスタの退避を伴う比較的コストの重い操作であり,関数からの return も分岐予測に優しくない,現代のアーキテクチャではこれまた重めの操作です.再帰を使わずに実装できるのならそうしたいと考えるのは自然な欲求でしょう.

 本記事では,要素数が2の冪とは限らないようなセグメント木上の二分探索の,再帰を使わない実装を例示し,それに対応する末尾再帰だけを使った実装の正しさを定理証明支援系 Coq を用いて検証します.これによって,これまでの記事で検証してきたものよりも高速な二分探索の実装を,これまで通り信頼性を担保したまま得られることでしょう.

再帰を使わない二分探索の実装

 形式検証を行う前に,要素数が2の冪とは限らないようなセグメント木上の二分探索の,再帰を使わない実装の導出を行いましょう.本節では,直観的に理解しやすい再帰的な定義からスタートして書き換えを行い,goto 文と等価な末尾再帰だけを使った実装を通して,再帰を使わずにループを使った実装を導出します.

再帰を使った二分探索の実装

 まずは直観的に理解しやすい,再帰を使った二分探索の実装を見てみましょう.単調な述語 predicate と配列に埋め込まれたセグメント木 segtree,セグメント木の葉の数 leaves および添え字 l を受け取って,セグメント木の添え字 [l, r) の要素の,モノイドの演算 M.op による積を predicate に渡した際に真になるような最大の r を返す関数 upper_bound predicate segtree leaves l は OCaml では以下のように再帰を使って定義できます.

(*
 * 添え字 n から始まる世代で,[l, r) の要素と p の積を predicate に渡した結果が真になるような
 * 最大の r を求めて,r と [l, r) の要素の積を返す関数
 * m は n から始まる世代の頂点の数
 *)
let rec upper_bound_rec predicate segtree m n l p =
  (* l が既にその世代で取りうる終端まで来ていたら,それ以上区間を広げようがない *)
  if m <= l
  then (l, p)
  else
    (* 添え字 n から始まる世代の l 番目の頂点を親から見た時,
     * 左側の子か右側の子かで場合分けをする.
     * 左側の子だった場合,親が保持する区間積の始点は現在見ている子のものと一致するので,
     * 親世代に丸投げしておけば倍速で探索できる.
     * 右側の子だった場合は保持している区間積の始点が一致しないので,今見ている頂点だけ
     * 線形探索して,左側の子から始まる区間にしてから親世代に投げる.
     *)
    let p' = if l mod 2 = 0 then p else M.op p segtree.(n + l) in
    if l mod 2 = 1 && not (predicate p')
    (* 線形探索した結果,そこが終端だった場合は当然そこで探索を打ち切る *)
    then (l, p)
    else
      let (l, p) =
        upper_bound_rec predicate segtree (m / 2) (m + n) ((l + 1) / 2) p' in
      if m <= l * 2
      (* 親世代で検索した右端が今見ている世代でも終端に相当していたら,それ以上探索しない *)
      then (l * 2, p)
      else
        (* 親世代で探索をすれば倍速で検索できるとは言っても添え字を1つ飛ばしで検索しているので
           取りこぼしてしまった奇数の添え字の部分を線形探索する *)
        let p' = M.op p segtree.(n + l * 2) in
        if predicate p'
        (* 親世代で取りこぼしていた奇数番目もいける場合 *)
        then (l * 2 + 1, p')
        (* 奇数番目は駄目な場合 *)
        else (l * 2, p)

let upper_bound predicate segtree leaves l =
  (* M.e はモノイドの単位元
     葉の世代が添え字 0 から始まっているのは単にそういうエンコードをしたからで,
     根の世代から要素を格納する実装もありうる *)
  upper_bound_rec predicate segtree leaves 0 l M.e

セグメント木上の二分探索を実装する上での直観も区間積のそれと同様,親世代で処理すれば倍速で片付くので,親世代で処理できない中途半端な部分だけ今の世代で処理してたらい回しにするといった感じです.やたら場合分けが多いように見えますが,境界検査が必要なので仕方ない面があります.

継続渡しスタイルによる二分探索の実装

 さて,この再帰で書かれた二分探索の実装をループだけ使った実装に書き換えるために,まずは goto 文と等価な末尾再帰だけを使って実装できないかどうか考えてみましょう.

 末尾再帰と言うのは再帰呼び出しを行った後で何もせず,得られた戻り値をそのまま関数の返り値にしてしまうような再帰呼び出しの事ですが,無論先程の upper_bound_rec に含まれる再帰呼び出しはそうなっていません.upper_bound_rec の再帰呼び出しで親世代での終端を得た後,今の世代の添え字に直したり親世代で処理できなかった奇数番目の添え字を調べたりと色々処理していますね.

 末尾でない再帰呼び出しを含む関数を,末尾再帰だけ含むような関数に直す最も素朴な方法は,その再帰呼び出しの後に行うはずの計算(継続)も受け取るようなスタイル,継続渡しスタイルに書き換えることです.先程の再帰を使った二分探索の実装を継続渡しスタイルに直すと,以下のようになります.

(* cont l p: 戻り値 l および p を返した後に行われるはずだった計算 *)
let rec upper_bound_rec predicate segtree m n l p cont =
  if m <= l
  (* 呼び出し元が戻り値を受け取った後にやるはずだった計算 cont l p も,呼び出された側でやる *)
  then cont l p
  else
    let p' = if l mod 2 = 0 then p else M.op p segtree.(n + l) in
    if l mod 2 = 1 && not (predicate p')
    then cont l p
    (* upper_bound_rec の返り値を得たあとにやりたかった計算も,クロージャとして渡してやる
       そうすると呼び出された側でその計算が行えるようになるので,そのまま呼び出し元の返り値に
       できる値を返せるようになる *)
    else upper_bound_rec predicate segtree (m / 2) (m + n) ((l + 1) / 2) p' (fun l p ->
      if m <= l * 2
      then cont (l * 2) p
      else
        let p' = M.op p segtree.(n + l * 2) in
        if predicate p'
        then cont (l * 2 + 1) p'
        else cont (l * 2) p)

let upper_bound predicate segtree leaves l =
  upper_bound_rec predicate segtree leaves 0 l M.e (fun l p -> (l, p))

 もっとも,末尾再帰の度にクロージャを作っていてはスタックの代わりにヒープに呼び出し元の情報を退避させているようなものですから,相変わらず効率的な実装にはなっていません.再帰を使わない等価な実装へ変換する前に,クロージャを作らないような実装にしたいものです.

クロージャを作らない二分探索の実装

 再帰の度にクロージャを作らないような二分探索の実装にするために,再帰呼び出しの後で行う計算 cont を再帰関数 upper_bound_rec の外へ出せないかどうか考えてみましょう.cont の中で使われている変数は predicate, segtree, m, n, l, p なので,それらを引数として cont l p と同じ値を返す関数 upper_bound_cont predicate segtree m n l p の定義を試みます.

 ここで問題となるのは,upper_bound_rec を再帰で呼び出した側が渡した cont の中で使われている変数 m, n は呼び出した側から見えている m, n であり,呼び出された側が引数として受け取るのは m / 2, m + n と食い違っている点です.セグメント木の葉の数が2の冪に限定されていれば,今見ている世代の頂点数 m は必ず2の冪になりますから,呼び出された側から見えている m / 2, m + n から呼び出し元の見ている m, n を復元するのは難しくありません.しかし今回はその制約を取り払っているので m は奇数になる可能性があり,2で割った事で m の偶奇の情報が欠けてしまっている以上復元は不可能でしょう.

 これを解決するために,各再帰呼び出しでの m の偶奇をリストにした s を追加の引数として渡してやると,以下のように cont l p と同じ値を返す関数 upper_bound_cont が定義できます.引数だけでは関数を定義するための情報が足りないのなら,その情報を持った引数を追加で受け取ろうと言う訳ですね.

let rec upper_bound_cont predicate segtree m n s l p =
  match s with
  | [] -> (l, p)
  | lsb :: s ->
      (* 元々の m の偶奇を s に格納していたので,m / 2 とその情報から m を復元できる *)
      let m = m * 2 + lsb in
      (* m さえ分かれば,元々の n も分かる *)
      let n = n - m in
      if m <= l * 2
      then upper_bound_cont predicate segtree m n s (l * 2) p
      else
        let p' = M.op p segtree.(n + l * 2) in
        if predicate p'
        then upper_bound_cont predicate segtree m n s (l * 2 + 1) p'
        else upper_bound_cont predicate segtree m n s (l * 2) p

 ただ,m の偶奇を格納するリスト s は再帰の深さ以上に長くなりませんし,格納するデータも1ビットあれば表現できるようなものなので,わざわざ連結リストにするのも富豪的な感じがします.これはこれで貧乏臭いのですが,m の偶奇をビット列として整数型に埋め込むようにすると,以下のように高速な実装が得られることでしょう.

let rec upper_bound_cont predicate segtree m n s l p =
  if s <= 1
  then (l, p)
  else
    (* 最下位ビットに m の偶奇の情報を格納している *)
    let m = m * 2 + s mod 2 in
    let n = n - m in
    if m <= l * 2
    (* 最下位ビットを除いたビット列を作りたければ,2で割れば良い *)
    then upper_bound_cont predicate segtree m n (s / 2) (l * 2) p
    else
      let p' = M.op p segtree.(n + l * 2) in
      if predicate p'
      then upper_bound_cont predicate segtree m n (s / 2) (l * 2 + 1) p'
      else upper_bound_cont predicate segtree m n (s / 2) (l * 2) p

let rec upper_bound_rec predicate segtree m n s l p =
  if m <= l
  then upper_bound_cont predicate segtree m n s l p
  else
    let p' = if l mod 2 = 0 then p else M.op p segtree.(n + l) in
    if l mod 2 = 1 && not (predicate p')
    then upper_bound_cont predicate segtree m n s l p
    else upper_bound_rec predicate segtree (m / 2) (m + n) (s * 2 + m mod 2) ((l + 1) / 2) p'

let upper_bound predicate segtree leaves l =
  upper_bound_rec predicate segtree leaves 0 1 M.e l

 OCaml は末尾再帰の最適化を行ってくれるので,この実装でも再帰を使わない実装と同じ機械語を吐いてくれるのですが,一応参考のために再帰を使わない実装も示しておきましょう.

 まずは対応を分かりやすくするために,末尾再帰を単純に goto 文に置き換えた実装を以下に示します.OCaml には goto 文が存在しないので,とりあえずC言語で書いておきますか.

/* MONOID はセグメント木の要素に入れるモノイドの型のつもり */
int upper_bound(bool (*predicate)(MONOID p), MONOID segtree[], int leaves, int l)
{
	/* MONOID_UNIT はモノイドの単位元のつもり */
	MONOID p = MONOID_UNIT;
	int m = leaves, n = 0, s = 1;

upper_bound_rec:
	if (m <= l) {
		goto upper_bound_cont;
	} else {
		MONOID pp;

		if (l % 2 == 0)
			pp = p;
		else
			/* MONOID_OP はモノイドの演算のつもり */
			pp = MONOID_OP(p, segtree[n + l]);

		if (l % 2 == 1 && !predicate(pp)) {
			goto upper_bound_cont;
		} else {
			s = s * 2 + m % 2;
			n += m;
			m /= 2;
			l = (l + 1) / 2;
			p = pp;
			goto upper_bound_rec;
		}
	}

upper_bound_cont:
	if (s <= 1) {
		/* C言語で複数引数を返すのは面倒なので,とりあえず範囲の終端だけ返す */
		return l;
	} else {
		m = m * 2 + s % 2;
		n -= m;
		if (m <= l * 2) {
			s /= 2;
			l *= 2;
			goto upper_bound_cont;
		} else {
			MONOID pp = MONOID_OP(p, segtree[n + l * 2]);

			if (predicate(pp)) {
				s /= 2;
				l = l * 2 + 1;
				p = pp;
				goto upper_bound_cont;
			} else {
				s /= 2;
				l *= 2;
				goto upper_bound_cont;
			}
		}
	}
}

C言語らしい書き方はせずあえて元の OCaml のコードに近い書き方をしてみたのですが,再帰と末尾再帰の対応を感じ取れたでしょうか.

 goto 文を放置しているとダイクストラ先生に叱られそうですし,while 文を使ってC言語らしいコードにリファクタリングするとしたら以下のような感じになります.

int upper_bound(bool (*predicate)(MONOID p), MONOID segtree[], int leaves, int l)
{
	MONOID p = MONOID_UNIT;
	int m = leaves, n = 0, s = 1;

	while (l < m) {
		if (l % 2 == 1) {
			MONOID pp = MONOID_OP(p, segtree[n + l]);

			if (!predicate(pp))
				break;

			l++;
			p = pp;
		}

		s = s * 2 + m % 2;
		n += m;
		m /= 2;
		l /= 2;
	}

	while (1 < s) {
		m = m * 2 + s % 2;
		n -= m;
		s /= 2;
		l *= 2;

		if (l < m) {
			MONOID pp = MONOID_OP(p, segtree[n + l]);

			if (predicate(pp)) {
				l++;
				p = pp;
			}
		}
	}

	return l;
}

C言語で書いたコードが正しいのかはあまり自信が無いですが,これで要素数が2の冪に限らないセグメント木上の二分探索を,再帰を使わずに実装出来ました.

効率的な末尾再帰による実装の検証

 要素数が2の冪とは限らないようなセグメント木上の二分探索を効率的に実装できたのですが,果たして導出の過程に誤りがないのか気になるところなので,一応 Coq で検証しておきましょう.Coq 上で手続きてきなコードを検証するのは少し面倒なので,それに変換する前の,OCaml で末尾再帰だけを使って書いた実装の正しさを今回は確認します.

 検証の方針ですが,まずは比較的検証しやすい再帰を使った実装で正しく二分探索ができている事を証明し,その実装と今回の末尾再帰だけを使った実装が同じ値を返す事を証明し,間接的に効率的な実装の正しさを証明します.前々回の記事でやったのと同じ方針ですね.

 再帰を使った実装の正しさは https://www.wantedly.com/companies/wantedly/post_articles/369478 等で既に検証しているので割愛して,末尾再帰だけを使った実装との対応を証明しましょう.これは,似たような関数二つが同じ値を返すくらいの事しか言ってないので,末尾再帰だけを使った二分探索の実装が本当に二分探索を行えるかに比べれば比較的簡単に示せます.

Lemma upper_bound_rec_correspondence segtree : forall rest_nodes leaves n s l p cont,
  valid_segtree leaves rest_nodes ->
  0 < s ->
  l <= size leaves ->
  (forall i, i < size leaves + size (encode rest_nodes) ->
        segtree (n + i) = nth idm (leaves ++ encode rest_nodes) i) ->
  (forall l p, l <= size leaves ->
          cont l p = upper_bound_iter_cont segtree (size leaves) n s l p) ->
  upper_bound_rec l p leaves rest_nodes cont
    = upper_bound_iter_rec segtree (size leaves) n s l p.
Proof.
  elim => /= [ | parents rest_nodes IH ] leaves ? s l ? ? /[1! upper_bound_iter_rec_equation']
        => /= [ /eqP -> /= ? /[1! leqn0] /eqP -> ? -> //
            | /andP [ /andP [ /eqP Hsize Heq ] ? ] ? ? Hsegtree Hcont ].
  case: ifPn => [ ? /[1! Hcont] // | ? ].
  rewrite Hsegtree ?size_cat; try lia.
  rewrite -Hsize nth_cat (_: l < size leaves); try lia.
  case: ifPn => [ /[1! Hcont] // | ? ].
  apply /IH => //= [ | | ? ? | k ? ? ]; try lia.
  - rewrite -addnA Hsegtree ?size_cat; try lia.
    by rewrite nth_cat ltnNge leq_addr addKn.
  - rewrite upper_bound_iter_cont_equation Hsize.
    have -> : (s.*2 + odd (size leaves) <= 1) = false by lia.
    have -> : (size leaves)./2.*2 + odd (s.*2 + odd (size leaves)) = size leaves by lia.
    have -> : (s.*2 + odd (size leaves))./2 = s by lia.
    rewrite addnK Hsegtree ?nth_cat ?size_cat; try lia.
    case (leqP (size leaves) k.*2) => [ ? /[1! Hcont] // /ltac:(lia) | ? ].
    by case: ifP => ? /[1! Hcont] //; lia.
Qed.

末尾再帰だけを使った実装を導出する際,継続 cont と同じ値を返す関数を定義するのだと書きましたが,まさにその直観と同じような内容がステートメントに含まれていて興味深いですね.

 この,再帰を使った二分探索の実装と末尾再帰だけを使った実装の対応を使えば,前者の正しさから後者の正しさは容易に導けます.

Corollary upper_bound_iter_correct m l segtree leaves rest_nodes :
  valid_segtree leaves rest_nodes ->
  (forall k, k <= size leaves ->
        P (\big[mul/idm]_(l <= i < k) nth idm leaves i) = (k <= m)) ->
  (forall i, i < size leaves + size (encode rest_nodes) ->
        segtree i = nth idm (leaves ++ encode rest_nodes) i) ->
  m <= size leaves ->
  l <= m ->
  upper_bound_iter segtree (size leaves) l = (m, \big[mul/idm]_(l <= i < m) nth idm leaves i).
Proof.
  move => ? ? ? ? ?.
  rewrite -(upper_bound_iter_correspondence _ _ _ rest_nodes) //; try lia.
  exact /upper_bound_correct.
Qed.

末尾再帰だけを使った二分探索の実装の使用例

 前回の記事にも書いた通り Coq で書いたコードを OCaml のコードに変換する機能が Coq 自体に備わっているので,折角ですから正しさを数学的に証明した,末尾再帰だけを使った二分探索の実装を使って AtCoder の問題を解いてみましょう.

 以下のようなコマンドを Coq 上で実行すると,

Extraction "segtreeQueries.ml" upper_bound_iter.

次のような OCaml のコードが得られます.

(** val odd : int -> bool **)

let rec odd = fun n -> n mod 2 = 1

(** val double : int -> int **)

let double = fun n -> n * 2

(** val half : int -> int **)

let rec half = fun n -> n / 2

(** val uphalf : int -> int **)

and uphalf = fun n -> (n + 1) / 2

(** val lsb : int -> int **)

let lsb = fun n -> n mod 2

(** val product_iter_rec :
    ('a1 -> 'a1 -> 'a1) -> (int ->
    'a1) -> int -> int -> int -> int
    -> 'a1 -> 'a1 -> 'a1 **)

let rec upper_bound_iter_cont mul p x x0 x1 x2 x3 x4 =
  if ( <= ) x2 (succ 0)
  then (x3, x4)
  else if ( <= )
            (( + ) (double x0)
              (lsb x2)) (double x3)
       then upper_bound_iter_cont mul
              p x
              (( + ) (double x0)
                (lsb x2))
              (( - ) x1
                (( + ) (double x0)
                  (lsb x2)))
              (half x2) (double x3) x4
       else if p
                 (mul x4
                   (x
                     (( + )
                       (( - ) x1
                         (( + )
                           (double x0)
                           (lsb x2)))
                       (double x3))))
            then upper_bound_iter_cont
                   mul p x
                   (( + ) (double x0)
                     (lsb x2))
                   (( - ) x1
                     (( + )
                       (double x0)
                       (lsb x2)))
                   (half x2) (succ
                   (double x3))
                   (mul x4
                     (x
                       (( + )
                         (( - ) x1
                           (( + )
                            (double
                            x0)
                            (lsb x2)))
                         (double x3))))
            else upper_bound_iter_cont
                   mul p x
                   (( + ) (double x0)
                     (lsb x2))
                   (( - ) x1
                     (( + )
                       (double x0)
                       (lsb x2)))
                   (half x2)
                   (double x3) x4

(** val upper_bound_iter_rec :
    ('a1 -> 'a1 -> 'a1) -> 'a1 pred
    -> (int -> 'a1) -> int -> int ->
    int -> int -> 'a1 -> int * 'a1 **)

let rec upper_bound_iter_rec mul p x x0 x1 x2 x3 x4 =
  if ( <= ) x0 x3
  then upper_bound_iter_cont mul p x
         x0 x1 x2 x3 x4
  else if odd x3
       then if p
                 (mul x4
                   (x (( + ) x1 x3)))
            then upper_bound_iter_rec
                   mul p x (half x0)
                   (( + ) x1 x0)
                   (( + ) (double x2)
                     (lsb x0)) (succ
                   (half x3))
                   (mul x4
                     (x (( + ) x1 x3)))
            else upper_bound_iter_cont
                   mul p x x0 x1 x2
                   x3 x4
       else upper_bound_iter_rec mul
              p x (half x0)
              (( + ) x1 x0)
              (( + ) (double x2)
                (lsb x0)) (half x3) x4

(** val upper_bound_iter :
    'a1 -> ('a1 -> 'a1 -> 'a1) -> 'a1
    pred -> (int -> 'a1) -> int ->
    int -> int * 'a1 **)

let upper_bound_iter idm mul p segtree m l =
  upper_bound_iter_rec mul p segtree
    m 0 (succ 0) l idm

 今回も前回同様に AtCoder Regular Contest 033 C問題の「データ構造」を解いて,その実行時間やメモリ使用量を比較してみることにしましょう.ほぼセグメント木上の二分探索を使うだけみたいな問題はこんな時に便利ですね.

 まず,今回検証した末尾再帰だけを使った二分探索の実装を使って,実際に「データ構造」を解いた結果は以下の通りです.

https://atcoder.jp/contests/arc033/submissions/32185102

実行時間は 158ms,メモリ使用量は 9212kB でした.

 前回再帰を使った実装で「データ構造」を解いた際の提出は以下の通りですが,実行時間 179ms,メモリ使用量 9372kB と確かに定数倍高速化が実現できていることが分かります.

https://atcoder.jp/contests/arc033/submissions/32140388

逆に言えばその程度の高速化でしかないので,どうせスタックが溢れる心配も無いのだし,正直可読性を優先して再帰で実装しても良いとは思います.OCaml ではなく別の言語ならループで書くことによるメリットが大きかったりするのかもしれないですが.

まとめ

 本記事では,要素数が2の冪とは限らないようなセグメント木上の二分探索の,再帰を使わない実装を例示し,それに対応する末尾再帰だけを使った実装の正しさを定理証明支援系 Coq を用いて検証しました.末尾再帰だけを使った実装はあまり直観的ではないものでしたが,Coq を用いて正当性を検証することで,実際に競技プログラミングの問題を解くのに使っても問題のない,信頼性の高いコードであることが確認できました.

 今回コンテストサイトに提出したコードは以下のリンクから参照できますし,セグメント木上の二分探索を Coq で検証した際のコードが気になった人は Gist から確認できます.

https://atcoder.jp/contests/arc033/submissions/32185102

Wantedly, Inc.では一緒に働く仲間を募集しています
1 いいね!
1 いいね!
同じタグの記事
今週のランキング
Wantedly, Inc.からお誘い
この話題に共感したら、メンバーと話してみませんか?