セグメント木

蟻本p.153

https://ikatakos.com/pot/programming_algorithm/data_structure/segment_tree

https://algo-logic.info/segment-tree/

https://kujira16.hateblo.jp/entry/2016/12/15/000000

セグメント木

セグメント木 (Segment Tree) は要素列を表現するデータ構造の1つ。
要素列 \(A = a_1 a_2 a_3 \ldots a_n\) に対して、区間 \(A_{[l, r)}\) に対する操作を高速に行うことができる。
オンラインクエリに対応できることが特徴。
逆にオフラインの場合はクエリをソートすることで、別のデータ構造を用いたさらなる高速化ができる場合がある。

原理

以下、簡単のため要素数 \(n\) を2のべき乗とする。

セグメント木では \(n\) 個の要素を次のように、長さ \(2n-1\) の配列で表現する。
最下層は要素列のインデックスを割り当て、上層には要素列の「区間」を割り当てる。
segment_tree_slice
セグメント木では、任意の区間に対する問い合わせを「連続する小区間に対する問い合わせ」として表現する。
例えば区間 \([1, 7]\) に対する問い合わせは、小区間 \([1], [2,3], [4,7]\) への問い合わせ結果を用いて作成する。
segment_tree_from1to7
また、実装においては1-indexedを採用し、長さ \(2n\) の配列をセグメント木とみなす。
これにより、インデックスに関して次の性質が成り立つ。
  • ビット長が木の深さを表す。

  • 兄弟ノードは下位1ビットだけが異なる。

segment_tree_indice
なお、もとの配列の長さ \(n\) が2のべき乗でないときは、\(n\) より大きい2のべき乗の長さの配列とみなしてセグメント木を作る。
使用しないノードには初期値を代入する。

Range Minimum Query (RMQ)

要素列をセグメント木に格納すると、区間に対する問い合わせを高速化できる。

  • 初期化: \(\mathcal{O}(n)\)

  • 要素1つの更新: \(\mathcal{O}(\log{n})\)

  • 区間 \([l, r)\) の最小値を求める問い合わせ: \(\mathcal{O}(\log{n})\)

RMQでは節点に「区間の最小値」を持たせる。
例えば数列 \([3, 7, 5, 5, 4, -3, 11, 2]\) は、RMQで次のように表現される。
rmq_init

1点更新

要素を更新するにはセグメント木を下層から上層へと辿りつつ、通過した節点に書かれた区間最小値を更新していく。

例えば配列の3番目の要素を \(5 \rightarrow 2\) に更新する場合、次のような更新経路を辿る。

rmq_update

問い合わせ (トップダウン)

問い合わせの実装は2通りある。
トップダウン型では分割統治法を用いる。問い合わせ区間をセグメント木に合わせて分割し、小区間ごとの問い合わせ結果をマージする。
範囲外の区間を表すノードも見ることになるため、\(\mathcal{O}(\log{n})\) は最良計算量となる。次に示すボトムアップ型より遅い。

問い合わせ (ボトムアップ)

ボトムアップ型では更新操作と同様に、問い合わせ区間の走査を下層から上層に向かって行う。
下層から上層へと辿りながら経路上の区間最小値を取得していき、今までに取得した値より小さい値であれば保持している値を更新する。
区間の左端 \(l\) は次のように更新する。
辿るノードはすべて「問い合わせ区間に含まれるノード」または「問い合わせ区間に含まれるかもしれないノード」のいずれかである。
  • ノードの添字が 奇数 (図中の緑ノード) ならば、ノードに書かれた数値は区間最小値である。値を取得して 右上 のノードへ移動する。

  • ノードの添字が偶数ならば、ノードに書かれた数値は区間最小値でない可能性がある。値を取得せずに直上のノードへ移動する。

rmq_query_bottomup_l

同様に、区間の右端 \(r\) も更新する。

  • ノードの添字が 偶数 (図中の青ノード) ならば、ノードに書かれた数値は区間最小値である。値を取得して 左上 のノードへ移動する。

  • ノードの添字が奇数ならば、ノードに書かれた数値は区間最小値でない可能性がある。値を取得せずに直上のノードへ移動する。

rmq_query_bottomup_r
\(l, r\) の経路が交差するまで区間最小値の取得・更新を続ける。
最終的に下図の赤色ノードの区間最小値 \(\{2, -3, 11\}\) が取り込まれ、問い合わせ区間の最小値が \(\min(2, -3, 11) = -3\) と求まる。
rmq_query_bottomup_cross

ボトムアップ型では問い合わせ区間外のノードを見ることがないため、平均計算量 \(\mathcal{O}(\log{n})\) が実現できる。

実装

ボトムアップ型、トップダウン型それぞれの実装を示す。実際の問い合わせはボトムアップ型で処理する。

[1]:
class RMQ():

    def __init__(self, size, op=min, init_value=10**8):
        """初期化"""
        self.size = size
        self.op = op
        self.init_value = init_value
        n = 2 ** ((size-1).bit_length())
        treesize = n * 2
        st = [init_value] * treesize
        self.st = st
        self.offset = len(st) // 2

    @classmethod
    def from_array(cls, a, op=min, init_value=10**8):
        st = cls(len(a), op=op, init_value=init_value)
        for i, x in enumerate(a):
            st.update(i, x)
        return st

    def update(self, key, value):
        """値の更新"""
        k = self.offset + key
        self.st[k] = value
        k >>= 1
        while k > 0:
            self.st[k] = self.op(self.st[k * 2], self.st[k * 2 + 1])
            k >>= 1

    def _query_bottomup(self, a, b):
        """区間[a, b) に対する累積操作
        """
        a += self.offset
        b += self.offset - 1
        s = self.init_value
        while a < b:
            if a & 1:
                s = self.op(s, self.st[a])
                a += 1
            a >>= 1
            if not b & 1:
                s = self.op(s, self.st[b])
                b -= 1
            b >>= 1
        if a == b:
            s = self.op(s, self.st[a])
        return s

    def _query_topdown(self, a, b, k=1, l=0, r=-1):
        """区間[a, b) に対する累積操作
        k: 着目しているノード (1-indexed)
        l: 探索区間 st[l, r) の左端 (0-indexed)
        r: 探索区間 st[l, r) の右端 (0-indexed)
        """
        if r == -1:
            r = self.offset
        if r <= a or b <= l:
            return self.init_value
        if a <= l and r <= b:
            return self.st[k]
        mid = (l + r) // 2
        lv = self._query_topdown(a, b, k * 2, l, mid)
        rv = self._query_topdown(a, b, k * 2 + 1, mid, r)
        return self.op(lv, rv)

    def query(self, a, b):
        """区間[a, b) に対する累積操作"""
        if a > b:
            raise ValueError("a must be less than equal b.")
        return self._query_bottomup(a, b)
[2]:
A = [3, 7, 5, 5, 4, -3, 11, 2]

rmq = RMQ.from_array(A)
rmq.update(2, 2)
print(rmq.query(2, 7))
-3
セグメント木による区間に対する操作は、最小値以外の操作でも利用できる。
具体的には、次の条件を満たせばよい。
  • 結合法則が成り立つ \((a \cdot b) \cdot c = a \cdot (b \cdot c)\)

  • 単位元 \(e\) をもつ \(a \cdot e = e \cdot a = a\)

操作 op には二項演算子を、初期値 init_value には単位元を指定する。

クエリ

操作

初期値

operator.add

0

operator.mul

1

最小値

min

+INF

最大値

max

-INF

AND

operator.and_

1

OR

operator.or_

0

XOR

operator.xor

0

GCD

math.gcd

0

LCM

1

Range Add Query (RAQ)

区間に対する高速な更新操作を実現する。

  • 初期化: \(\mathcal{O}(n)\)

  • 区間更新、区間に対する加算: \(\mathcal{O}(\log{n})\)

  • 1点問い合わせ、値の取得: \(\mathcal{O}(\log{n})\)

RAQでは葉や節点に「区間に加算する値」を持たせる。
例えば数列 \([3, 7, 5, 5, 4, -3, 11, 2]\) と「区間 \([2, 8]\) に対する値 \(4\) の加算」は、RAQで次のように表現される。
raq_init

1点問い合わせ

問い合わせでは、セグメント木を下層から上層へと辿る。
通過したすべての葉と節点の値の和が問い合わせ結果となる。

例えば配列の3番目の要素を取得する場合、次のような経路を辿る。取得される値は \(+5+4+0+0=9\) となる。

raq_query

区間更新

区間更新では区間を小区間に分割し、それぞれの区間に対して加算を行う。
例えば区間 \([2, 6]\) の全要素に値 \(2\) を加算するとき、更新対象は次の赤ノードとなる。
raq_update_target
区間更新はRMQの区間問い合わせと同様に、分割統治法によるトップダウン型、またはビット演算を利用したボトムアップ型のいずれかの方法で処理することができる。
ここでは計算量が小さいボトムアップ型のみ説明する。
上の図について、ノードの添字の偶奇に着目すると、ノードの更新経路はRMQの区間問い合わせと同じ経路を辿ることがわかる。
よって、左端 \(l\) と右端 \(r\) を起点とする更新経路が交差するまでセグメント木の下層から上層へ辿り、通過したノードの値を更新すればよい。

実際、RAQの区間更新の実装は、RMQの区間問い合わせの「ノードの値の取得」を「ノードの更新」に置き換えただけである。

raq_update_lr

実装

区間更新の実装については、トップダウン型とボトムアップ型の両方を示す。

[3]:
class RAQ():

    def __init__(self, size):
        """初期化"""
        self.size = size
        n = 2 ** ((size-1).bit_length())
        treesize = n * 2
        st = [0] * treesize
        self.st = st
        self.offset = len(st) // 2

    @classmethod
    def from_array(cls, a):
        st = cls(len(a))
        for i, x in enumerate(a):
            st.add(i, i+1, x)
        return st

    def _add_topdown(self, a, b, value, k=1, l=0, r=-1):
        """区間[a, b) に対する加算
        k: 着目しているノード (1-indexed)
        l: 探索区間 st[l, r) の左端 (0-indexed)
        r: 探索区間 st[l, r) の右端 (0-indexed)
        """
        if r == -1:
            r = self.offset
        if r <= a or b <= l:
            return
        if l == r - 1:
            self.st[k] += value
            return
        if a <= l and r <= b:
            self.st[k] += value
            return
        mid = (l + r) // 2
        self._add(a, b, value, k * 2, l, mid)
        self._add(a, b, value, k * 2 + 1, mid, r)

    def _add_bottomup(self, a, b, value):
        """区間[a, b) に対する加算
        """
        a += self.offset
        b += self.offset - 1
        while a < b:
            if a & 1:
                self.st[a] += value
                a += 1
            a >>= 1
            if not b & 1:
                self.st[b] += value
                b -= 1
            b >>= 1
        if a == b:
            self.st[a] += value

    def add(self, a, b, value):
        """区間[a, b) に対する加算"""
        if a > b:
            raise ValueError("a must be less than equal b.")
        return self._add_bottomup(a, b, value)

    def get(self, key):
        """値の取得"""
        offset = len(self.st) // 2
        k = offset + key
        v = self.st[k]
        k >>= 1
        while k > 0:
            v += self.st[k]
            k >>= 1
        return v
[4]:
A = [3, 7, 5, 5, 4, -3, 11, 2]

raq = RAQ.from_array(A)
raq.add(2, 4, 4)
raq.add(4, 8, 4)
print(raq.get(2))
raq.add(2, 6, 2)
print(raq.get(2))
9
11