线段树

DANGER

注意:区间赋值操作尚未经过充分验证,请谨慎使用!使用区间赋值时,表示“无延迟标记”的单位元(哨兵值)必须与任何可能被赋的值都不同——例如不要用 0,否则把区间赋值为 0 时标记不会被下传,后续查询会返回旧值。如果不需要区间赋值,推荐使用简易线段树。

TypeScript
/** Combines the values of two adjacent child segments into their parent's value. */
type MergeFn<T> = (left: T, right: T) => T
/** Returns `current` after applying pending tag `tag` to a node covering `length` elements. */
type ApplyFn<T, L> = (current: T, tag: L, length: number) => T
/** Folds a newer tag into an older tag that is still pending on the same node. */
type ComposeFn<L> = (older: L, newer: L) => L

/**
 * The algebra a segment tree is parameterized over: a value type `T` aggregated
 * with `merge`, and a lazy-tag type `L` describing deferred range updates.
 */
interface Operation<T, L> {
    /** Neutral value for `merge`; also what a query over no elements yields. */
    identityValue: () => T
    /**
     * The "no pending update" tag. It must differ from every tag actually passed
     * to an update, otherwise such an update is mistaken for "no tag".
     */
    identityLazy: () => L
    /** Combines two child values. */
    merge: MergeFn<T>
    /** How a lazy tag changes a node's value. */
    apply: ApplyFn<T, L>
    /** How a new lazy tag combines with the one already pending. */
    compose: ComposeFn<L>
}

/**
 * Builds a generic lazy-propagation segment tree over `data`.
 *
 * A node counts as having no pending tag when `JSON.stringify` of its tag equals
 * `JSON.stringify(op.identityLazy())`. The identity tag must therefore differ from
 * every tag ever passed to `update`; e.g. for range assignment use `null`, not `0`,
 * as the "no assignment" tag.
 *
 * @param data initial values; position i starts as `data[i]`
 * @param op   the value/lazy algebra
 * @returns `query(l, r)`: `op.merge` folded over positions l..r (inclusive);
 *          `update(l, r, val)`: applies tag `val` to positions l..r (inclusive).
 *          For an empty `data`, `query` returns `op.identityValue()` and `update` is a no-op.
 */
export function useSegmentTree<T, L>(data: T[], op: Operation<T, L>) {
    const n = data.length
    // 1-indexed heap layout: node k has children 2k and 2k + 1; 4n slots are enough.
    const tree: T[] = Array(4 * n).fill(op.identityValue())
    const lazy: L[] = Array(4 * n).fill(op.identityLazy())
    // Serialize the identity tag once instead of on every pushDown.
    const identityLazyKey = JSON.stringify(op.identityLazy())

    /** Fills `node`, which covers data[l..r]. */
    function build(node: number, l: number, r: number) {
        if (l === r) {
            tree[node] = data[l]
            return
        }
        const mid = (l + r) >> 1
        build(node * 2, l, mid)
        build(node * 2 + 1, mid + 1, r)
        tree[node] = op.merge(tree[node * 2], tree[node * 2 + 1])
    }

    /** Applies `val` to `node` (covering `length` elements) and records it as pending for its children. */
    function applyToNode(node: number, val: L, length: number) {
        tree[node] = op.apply(tree[node], val, length)
        lazy[node] = op.compose(lazy[node], val)
    }

    /** Pushes `node`'s pending tag, if any, down to both children and clears it. */
    function pushDown(node: number, l: number, r: number) {
        if (JSON.stringify(lazy[node]) === identityLazyKey) return
        const mid = (l + r) >> 1
        applyToNode(node * 2, lazy[node], mid - l + 1)
        applyToNode(node * 2 + 1, lazy[node], r - mid)
        lazy[node] = op.identityLazy()
    }

    /** Applies `val` to every position in [ql, qr] inside `node`'s segment [l, r]. */
    function updateRange(
        node: number,
        l: number,
        r: number,
        ql: number,
        qr: number,
        val: L,
    ) {
        if (ql > r || qr < l) return
        if (ql <= l && r <= qr) {
            // Fully covered: tag this node and defer the children.
            applyToNode(node, val, r - l + 1)
            return
        }
        pushDown(node, l, r)
        const mid = (l + r) >> 1
        updateRange(node * 2, l, mid, ql, qr, val)
        updateRange(node * 2 + 1, mid + 1, r, ql, qr, val)
        tree[node] = op.merge(tree[node * 2], tree[node * 2 + 1])
    }

    /** Folds the positions of [ql, qr] that lie inside `node`'s segment [l, r]. */
    function queryRange(
        node: number,
        l: number,
        r: number,
        ql: number,
        qr: number,
    ): T {
        if (ql > r || qr < l) return op.identityValue()
        if (ql <= l && r <= qr) return tree[node]
        pushDown(node, l, r)
        const mid = (l + r) >> 1
        return op.merge(
            queryRange(node * 2, l, mid, ql, qr),
            queryRange(node * 2 + 1, mid + 1, r, ql, qr),
        )
    }

    // With no elements there is no root segment [0, n - 1]: build(1, 0, -1) would
    // never reach a leaf and would recurse until the stack overflows.
    if (n > 0) build(1, 0, n - 1)

    return {
        query: (l: number, r: number): T =>
            n > 0 ? queryRange(1, 0, n - 1, l, r) : op.identityValue(),
        update: (l: number, r: number, val: L): void => {
            if (n > 0) updateRange(1, 0, n - 1, l, r, val)
        },
    }
}

// Range sum with range add
{
    const rangeAdd = useSegmentTree<number, number>([1, 2, 3, 4, 5], {
        identityValue: () => 0,
        identityLazy: () => 0,
        merge: (x, y) => x + y,
        apply: (sum, delta, count) => sum + delta * count,
        compose: (pending, incoming) => pending + incoming,
    })

    const total = () => rangeAdd.query(0, 4)
    console.log(total()) // 15
    rangeAdd.update(1, 3, 2) // add 2 to positions 1..3
    console.log(total()) // 21
}

// Range sum with range assignment.
// The "no pending assignment" tag is null rather than 0: with 0 as the identity,
// assigning 0 would look like "no tag" and would never be pushed down to children.
{
    const seg = useSegmentTree<number, number | null>([1, 2, 3, 4, 5], {
        merge: (a, b) => a + b,
        apply: (value, lazy, length) => (lazy === null ? value : lazy * length),
        compose: (oldLazy, newLazy) => (newLazy === null ? oldLazy : newLazy),
        identityValue: () => 0,
        identityLazy: () => null,
    })

    console.log(seg.query(0, 4)) // 15
    seg.update(1, 3, 2)
    console.log(seg.query(0, 4)) // 12
}
JavaScript
/**
 * Builds a generic lazy-propagation segment tree over `data`.
 *
 * `op` supplies the algebra:
 * - `merge(a, b)` combines two child values.
 * - `apply(value, lazy, length)` applies a tag to a node covering `length` elements.
 * - `compose(oldLazy, newLazy)` folds a newer tag into a pending one.
 * - `identityValue()` and `identityLazy()` return the neutral value and the "no pending tag" marker.
 *
 * A node counts as tag-free when `JSON.stringify` of its tag equals that of
 * `identityLazy()`. The identity tag must therefore differ from every tag passed to `update`.
 *
 * @param {Array} data initial values; position i starts as data[i]
 * @param {object} op the algebra described above
 * @returns {{query: Function, update: Function}} `query(l, r)` folds positions l..r
 *   (inclusive); `update(l, r, val)` applies `val` to positions l..r (inclusive).
 *   For an empty `data`, `query` returns `op.identityValue()` and `update` does nothing.
 */
export function useSegmentTree(data, op) {
    const n = data.length
    // 1-indexed heap layout: node k has children 2k and 2k + 1; 4n slots are enough.
    const tree = Array(4 * n).fill(op.identityValue())
    const lazy = Array(4 * n).fill(op.identityLazy())
    // Serialize the identity tag once instead of on every pushDown.
    const identityLazyKey = JSON.stringify(op.identityLazy())

    // Fills `node`, which covers data[l..r].
    function build(node, l, r) {
        if (l === r) {
            tree[node] = data[l]
            return
        }
        const mid = (l + r) >> 1
        build(node * 2, l, mid)
        build(node * 2 + 1, mid + 1, r)
        tree[node] = op.merge(tree[node * 2], tree[node * 2 + 1])
    }

    // Applies `val` to `node` (covering `length` elements) and records it as pending.
    function applyToNode(node, val, length) {
        tree[node] = op.apply(tree[node], val, length)
        lazy[node] = op.compose(lazy[node], val)
    }

    // Pushes `node`'s pending tag, if any, down to both children and clears it.
    function pushDown(node, l, r) {
        if (JSON.stringify(lazy[node]) === identityLazyKey) return
        const mid = (l + r) >> 1
        applyToNode(node * 2, lazy[node], mid - l + 1)
        applyToNode(node * 2 + 1, lazy[node], r - mid)
        lazy[node] = op.identityLazy()
    }

    // Applies `val` to every position in [ql, qr] inside `node`'s segment [l, r].
    function updateRange(node, l, r, ql, qr, val) {
        if (ql > r || qr < l) return
        if (ql <= l && r <= qr) {
            applyToNode(node, val, r - l + 1)
            return
        }
        pushDown(node, l, r)
        const mid = (l + r) >> 1
        updateRange(node * 2, l, mid, ql, qr, val)
        updateRange(node * 2 + 1, mid + 1, r, ql, qr, val)
        tree[node] = op.merge(tree[node * 2], tree[node * 2 + 1])
    }

    // Folds the positions of [ql, qr] that lie inside `node`'s segment [l, r].
    function queryRange(node, l, r, ql, qr) {
        if (ql > r || qr < l) return op.identityValue()
        if (ql <= l && r <= qr) return tree[node]
        pushDown(node, l, r)
        const mid = (l + r) >> 1
        return op.merge(
            queryRange(node * 2, l, mid, ql, qr),
            queryRange(node * 2 + 1, mid + 1, r, ql, qr),
        )
    }

    // With no elements there is no root segment [0, n - 1]: build(1, 0, -1) would
    // never reach a leaf and would recurse until the stack overflows.
    if (n > 0) build(1, 0, n - 1)

    return {
        query: (l, r) => (n > 0 ? queryRange(1, 0, n - 1, l, r) : op.identityValue()),
        update: (l, r, val) => {
            if (n > 0) updateRange(1, 0, n - 1, l, r, val)
        },
    }
}
// Range sum with range add
{
    const addOp = {
        merge: (left, right) => left + right,
        apply: (sum, delta, count) => sum + delta * count,
        compose: (pending, incoming) => pending + incoming,
        identityValue: () => 0,
        identityLazy: () => 0,
    }
    const rangeAdd = useSegmentTree([1, 2, 3, 4, 5], addOp)
    const total = () => rangeAdd.query(0, 4)
    console.log(total())
    rangeAdd.update(1, 3, 2) // add 2 to positions 1..3
    console.log(total())
}
// Range sum with range assignment.
// The "no pending assignment" tag is null rather than 0: with 0 as the identity,
// assigning 0 would look like "no tag" and would never be pushed down to children.
{
    const seg = useSegmentTree([1, 2, 3, 4, 5], {
        merge: (a, b) => a + b,
        apply: (value, lazy, length) => (lazy === null ? value : lazy * length),
        compose: (oldLazy, newLazy) => (newLazy === null ? oldLazy : newLazy),
        identityValue: () => 0,
        identityLazy: () => null,
    })
    console.log(seg.query(0, 4))
    seg.update(1, 3, 2)
    console.log(seg.query(0, 4))
}
Rust
use std::fmt::Debug;

/// Value algebra (a monoid): an associative way to combine segment aggregates,
/// together with its neutral element.
pub trait Monoid {
    /// Type of the aggregate stored for a segment.
    type Item: Clone + Debug;

    /// Neutral element: combining it with any `x` is expected to give back `x`.
    fn identity() -> Self::Item;

    /// Combines the aggregates of two adjacent sub-segments (left `a`, right `b`).
    fn merge(a: &Self::Item, b: &Self::Item) -> Self::Item;
}

/// Lazy-tag algebra layered on a [`Monoid`]: how a pending tag changes a node's
/// aggregate, and how two tags pending on the same node combine.
pub trait LazyOp: Monoid {
    /// Type of a pending (lazy) tag; `PartialEq` lets a tag be compared with `identity_lazy()`.
    type Lazy: Clone + Debug + PartialEq;

    /// The tag meaning "no pending operation".
    fn identity_lazy() -> Self::Lazy;

    /// Returns `value` after applying `lazy` to a node covering `len` elements
    /// (`len` lets e.g. a sum scale an added amount by the segment length).
    fn apply(value: &Self::Item, lazy: &Self::Lazy, len: usize) -> Self::Item;

    /// Folds `new` into `old` in place (`old <- compose(old, new)`); the exact
    /// semantics are defined by the implementor.
    fn compose(old: &mut Self::Lazy, new: &Self::Lazy);
}

/// Generic lazy-propagation segment tree parameterized by a [`LazyOp`] algebra.
pub struct SegmentTree<M: LazyOp> {
    /// Aggregate of each node, indexed by node id.
    seg: Vec<M::Item>,
    /// Pending lazy tag of each node, parallel to `seg`.
    lz: Vec<M::Lazy>,
    /// Number of elements (leaves) covered by the tree.
    n: usize,
}

impl<M: LazyOp> SegmentTree<M> {
    /// Builds a tree whose position `i` initially holds `a[i]`.
    ///
    /// Nodes are 1-indexed (root 1, children of `i` at `2i` and `2i + 1`); storage
    /// is twice the next power of two of `a.len()`.
    pub fn from_vec(a: &[M::Item]) -> Self {
        let n = a.len();
        let capacity = 2 * n.next_power_of_two();
        let mut tree = SegmentTree {
            n,
            seg: vec![M::identity(); capacity],
            lz: vec![M::identity_lazy(); capacity],
        };
        // Only a non-empty input has a root segment [0, n - 1] to build.
        if let Some(last) = n.checked_sub(1) {
            tree.build(1, 0, last, a);
        }
        tree
    }

    /// Recursively fills node `idx`, which covers `a[l..=r]`.
    fn build(&mut self, idx: usize, l: usize, r: usize, a: &[M::Item]) {
        if l == r {
            self.seg[idx] = a[l].clone();
        } else {
            let mid = (l + r) / 2;
            let (left, right) = (2 * idx, 2 * idx + 1);
            self.build(left, l, mid, a);
            self.build(right, mid + 1, r, a);
            self.pull(idx);
        }
    }

    /// Recomputes node `idx` from its two children.
    fn pull(&mut self, idx: usize) {
        let combined = M::merge(&self.seg[2 * idx], &self.seg[2 * idx + 1]);
        self.seg[idx] = combined;
    }

    /// Applies `lazy` to node `idx` (covering `[l, r]`) and records it as pending.
    fn apply_node(&mut self, idx: usize, l: usize, r: usize, lazy: &M::Lazy) {
        let width = r - l + 1;
        self.seg[idx] = M::apply(&self.seg[idx], lazy, width);
        M::compose(&mut self.lz[idx], lazy);
    }

    /// Hands node `idx`'s pending tag (if it is not the identity) to both children.
    fn push(&mut self, idx: usize, l: usize, r: usize) {
        let ident = M::identity_lazy();
        if self.lz[idx] == ident {
            return;
        }
        // Take the tag out, leaving the identity behind; children never touch lz[idx].
        let pending = std::mem::replace(&mut self.lz[idx], M::identity_lazy());
        let mid = (l + r) / 2;
        self.apply_node(2 * idx, l, mid, &pending);
        self.apply_node(2 * idx + 1, mid + 1, r, &pending);
    }

    /// Applies `val` to every position in `[ql, qr]`; panics on an invalid range.
    pub fn update(&mut self, ql: usize, qr: usize, val: &M::Lazy) {
        assert!(ql <= qr && qr < self.n);
        let last = self.n - 1;
        self.update_rec(1, 0, last, ql, qr, val);
    }

    fn update_rec(&mut self, idx: usize, l: usize, r: usize, ql: usize, qr: usize, val: &M::Lazy) {
        let covered = ql <= l && r <= qr;
        if covered {
            self.apply_node(idx, l, r, val);
        } else {
            self.push(idx, l, r);
            let mid = (l + r) / 2;
            if ql <= mid {
                self.update_rec(2 * idx, l, mid, ql, qr, val);
            }
            if mid < qr {
                self.update_rec(2 * idx + 1, mid + 1, r, ql, qr, val);
            }
            self.pull(idx);
        }
    }

    /// Returns the merged aggregate of `[ql, qr]`; panics on an invalid range.
    pub fn query(&mut self, ql: usize, qr: usize) -> M::Item {
        assert!(ql <= qr && qr < self.n);
        let last = self.n - 1;
        self.query_rec(1, 0, last, ql, qr)
    }

    fn query_rec(&mut self, idx: usize, l: usize, r: usize, ql: usize, qr: usize) -> M::Item {
        if ql <= l && r <= qr {
            return self.seg[idx].clone();
        }
        self.push(idx, l, r);
        let mid = (l + r) / 2;
        let (left, right) = (2 * idx, 2 * idx + 1);
        if qr <= mid {
            self.query_rec(left, l, mid, ql, qr)
        } else if mid < ql {
            self.query_rec(right, mid + 1, r, ql, qr)
        } else {
            let a = self.query_rec(left, l, mid, ql, qr);
            let b = self.query_rec(right, mid + 1, r, ql, qr);
            M::merge(&a, &b)
        }
    }
}

// Range sum that supports both range assignment and range addition.
#[derive(Debug, Clone, PartialEq)]
struct SumMonoid;

impl Monoid for SumMonoid {
    type Item = i64;

    fn identity() -> Self::Item {
        0
    }

    fn merge(a: &Self::Item, b: &Self::Item) -> Self::Item {
        a + b
    }
}

/// Pending operation for [`SumMonoid`] nodes.
#[derive(Debug, Clone, PartialEq)]
pub enum SumLazy {
    /// No pending operation.
    None,
    /// Add the amount to every element of the segment.
    Add(i64),
    /// Set every element of the segment to the value.
    Assign(i64),
}

impl LazyOp for SumMonoid {
    type Lazy = SumLazy;

    fn identity_lazy() -> Self::Lazy {
        SumLazy::None
    }

    /// `Add(d)` grows the segment sum by `d * len`; `Assign(x)` makes it `x * len`.
    fn apply(value: &Self::Item, lazy: &Self::Lazy, len: usize) -> Self::Item {
        let count = len as i64;
        match *lazy {
            SumLazy::None => *value,
            SumLazy::Add(delta) => *value + delta * count,
            SumLazy::Assign(x) => x * count,
        }
    }

    /// An assignment replaces whatever is pending. An addition stacks onto a
    /// pending addition or assignment, or becomes the pending tag if none exists.
    fn compose(old: &mut Self::Lazy, new: &Self::Lazy) {
        match *new {
            SumLazy::None => {}
            SumLazy::Assign(_) => *old = new.clone(),
            SumLazy::Add(delta) => match old {
                SumLazy::None => *old = SumLazy::Add(delta),
                SumLazy::Add(acc) | SumLazy::Assign(acc) => *acc += delta,
            },
        }
    }
}

// DELETE: START
#[cfg(test)]
mod tests {
    use super::*;

    // --------- Example implementation: range assignment + range maximum ---------
    #[derive(Clone, Debug, PartialEq)]
    struct MaxAssignMonoid;

    impl Monoid for MaxAssignMonoid {
        type Item = i64; // maximum of the node's segment
        fn merge(a: &Self::Item, b: &Self::Item) -> Self::Item {
            *a.max(b)
        }
        fn identity() -> Self::Item {
            i64::MIN / 4
        }
    }

    #[derive(Clone, Debug, PartialEq)]
    enum AssignLazy {
        None,
        Set(i64),
    }

    impl LazyOp for MaxAssignMonoid {
        type Lazy = AssignLazy;
        fn apply(value: &Self::Item, lazy: &Self::Lazy, _len: usize) -> Self::Item {
            match lazy {
                AssignLazy::None => *value,
                AssignLazy::Set(x) => *x,
            }
        }
        fn compose(old: &mut Self::Lazy, new: &Self::Lazy) {
            match new {
                AssignLazy::None => (),
                AssignLazy::Set(x) => *old = AssignLazy::Set(*x),
            }
        }
        fn identity_lazy() -> Self::Lazy {
            AssignLazy::None
        }
    }

    #[test]
    fn test_sum_range_add() {
        let a = vec![1i64, 2, 3, 4, 5];
        let mut st = SegmentTree::<SumMonoid>::from_vec(&a);
        assert_eq!(st.query(0, 4), 15);
        st.update(1, 3, &SumLazy::Add(10)); // a[1..=3] += 10
        assert_eq!(st.query(0, 0), 1);
        assert_eq!(st.query(1, 1), 12);
        assert_eq!(st.query(0, 4), 45);
        assert_eq!(st.query(2, 4), 13 + 14 + 5);
        st.update(1, 3, &SumLazy::Assign(10)); // a[1..=3] = 10
        assert_eq!(st.query(1, 1), 10);
        assert_eq!(st.query(0, 4), 36);
        assert_eq!(st.query(2, 4), 10 + 10 + 5);
    }

    #[test]
    fn test_sum_assign_zero_and_compose() {
        let a = vec![1i64, 2, 3, 4, 5];
        let mut st = SegmentTree::<SumMonoid>::from_vec(&a);
        // Assigning 0 must still reach the children: the "no tag" marker is
        // SumLazy::None, not Assign(0).
        st.update(0, 4, &SumLazy::Assign(0));
        assert_eq!(st.query(0, 0), 0);
        assert_eq!(st.query(0, 4), 0);
        // A pending Assign followed by an Add on the same node composes to Assign(x + d).
        st.update(0, 4, &SumLazy::Assign(3));
        st.update(0, 4, &SumLazy::Add(1));
        assert_eq!(st.query(2, 2), 4);
        assert_eq!(st.query(0, 4), 20);
        // A partial add on top of the composed assignment.
        st.update(1, 2, &SumLazy::Add(-4));
        assert_eq!(st.query(1, 2), 0);
        assert_eq!(st.query(0, 4), 12);
    }

    #[test]
    fn test_max_range_set() {
        let a = vec![1i64, 7, 3, 9, 5];
        let mut st = SegmentTree::<MaxAssignMonoid>::from_vec(&a);
        assert_eq!(st.query(0, 4), 9);
        st.update(1, 3, &AssignLazy::Set(4));
        assert_eq!(st.query(0, 4), 5.max(4).max(1));
        assert_eq!(st.query(1, 2), 4);
    }
}

// DELETE: END
Java

import java.util.Arrays;

class SegmentTree {

    private final int[] tree, nums, lazy;
    // hasLazy[node] is true while lazy[node] holds an assignment not yet pushed to the
    // node's children. Tracking this separately, instead of comparing lazy[node] with a
    // sentinel, lets every int value (including 0) be assigned to a range.
    private final boolean[] hasLazy;
    // Value written into a cleared lazy slot. It no longer signals "no pending tag".
    final int UNUSE = 0;

    /**
     * Builds a range-assignment / range-sum segment tree.
     * For an empty array, queries return 0 and updates do nothing.
     *
     * @param nums initial values; kept by reference, elements are only read while building
     */
    public SegmentTree(int[] nums) {
        int n = nums.length;
        tree = new int[4 * n];
        lazy = new int[4 * n]; // lazy tag values
        hasLazy = new boolean[4 * n];
        Arrays.fill(lazy, UNUSE);
        this.nums = nums;
        // With no elements there is no root segment [0, n - 1]; building it would index past tree.
        if (n > 0) {
            buildTree(0, 0, n - 1);
        }
    }

    /**
     * Builds the segment tree.
     *
     * @param node  index in tree of the current node
     * @param start first nums index covered by the current node
     * @param end   last nums index covered by the current node
     */
    private void buildTree(int node, int start, int end) {
        if (start == end) {
            tree[node] = nums[start];
            return;
        }
        int mid = start + (end - start) / 2;
        int leftNode = node * 2 + 1;
        int rightNode = node * 2 + 2;

        buildTree(leftNode, start, mid);
        buildTree(rightNode, mid + 1, end);

        tree[node] = tree[leftNode] + tree[rightNode];
    }

    /**
     * Pushes the pending assignment of {@code node} (if any) to its two children.
     */
    private void pushDown(int node, int start, int end) {
        if (!hasLazy[node]) {
            return;
        }
        int leftNode = node * 2 + 1;
        int rightNode = node * 2 + 2;
        int mid = start + (end - start) / 2;

        // For range add instead of range assignment, change the four "=" on the
        // tree[...] and lazy[...] lines below to "+=".
        tree[leftNode] = lazy[node] * (mid - start + 1);
        lazy[leftNode] = lazy[node];
        hasLazy[leftNode] = true;
        tree[rightNode] = lazy[node] * (end - mid);
        lazy[rightNode] = lazy[node];
        hasLazy[rightNode] = true;
        // Clear the current node's lazy tag.
        lazy[node] = UNUSE;
        hasLazy[node] = false;
    }

    /**
     * Updates the segment tree.
     *
     * @param node  index in tree of the current node
     * @param start first nums index covered by the current node
     * @param end   last nums index covered by the current node
     * @param left  first nums index to update
     * @param right last nums index to update
     * @param val   value assigned to the range
     */
    private void updateTree(int node, int start, int end, int left, int right, int val) {
        // The update range does not intersect this node's range: nothing to do.
        if (right < start || left > end) {
            return;
        }
        // The node is fully covered: update it and set its lazy tag instead of recursing.
        if (left <= start && end <= right) {
            // For range add instead of range assignment, change the "=" on the
            // tree[node] and lazy[node] lines below to "+=".
            tree[node] = (end - start + 1) * val;
            lazy[node] = val;
            hasLazy[node] = true;
            return;
        }

        // Pass the current node's lazy tag down one level.
        pushDown(node, start, end);

        int mid = start + (end - start) / 2;
        int leftNode = 2 * node + 1;
        int rightNode = 2 * node + 2;

        // The left subtree needs updating.
        if (left <= mid) {
            updateTree(leftNode, start, mid, left, right, val);
        }
        // The right subtree needs updating.
        if (right > mid) {
            updateTree(rightNode, mid + 1, end, left, right, val);
        }
        // Recompute the current node.
        tree[node] = tree[leftNode] + tree[rightNode];
    }

    /**
     * Queries the segment tree.
     *
     * @param node  index in tree of the current node
     * @param start first nums index covered by the current node
     * @param end   last nums index covered by the current node
     * @param left  first nums index of the query
     * @param right last nums index of the query
     * @return sum of the queried positions inside this node's range
     */
    private int queryTree(int node, int start, int end, int left, int right) {
        // The query range does not intersect this node's range.
        if (right < start || left > end) {
            return 0;
        }
        // This node's range lies entirely inside the query range: return it directly.
        if (start >= left && end <= right) {
            return tree[node];
        }

        // Partial overlap: push the lazy tag down first.
        pushDown(node, start, end);

        int mid = start + (end - start) / 2;
        int leftNode = 2 * node + 1;
        int rightNode = 2 * node + 2;
        int leftSum = queryTree(leftNode, start, mid, left, right);
        int rightSum = queryTree(rightNode, mid + 1, end, left, right);
        return leftSum + rightSum;
    }

    /**
     * Sets a single position of nums.
     *
     * @param index nums index to set
     * @param val   new value
     */
    public void update(int index, int val) {
        if (nums.length == 0) {
            return;
        }
        updateTree(0, 0, nums.length - 1, index, index, val);
    }

    /**
     * Assigns {@code val} to every position in [left, right].
     *
     * @param left  first nums index
     * @param right last nums index
     * @param val   value to assign
     */
    public void update(int left, int right, int val) {
        if (nums.length == 0) {
            return;
        }
        updateTree(0, 0, nums.length - 1, left, right, val);
    }

    /**
     * Returns the sum of nums over [left, right] (0 for an empty tree).
     *
     * @param left  first index of the range
     * @param right last index of the range
     */
    public int query(int left, int right) {
        if (nums.length == 0) {
            return 0;
        }
        return queryTree(0, 0, nums.length - 1, left, right);
    }
}