线段树
DANGER
注意!区间赋值操作并未得到验证,请谨慎使用!如果不需要使用区间赋值,推荐使用 简易线段树。
TypeScript
/** Combines the aggregates of two adjacent segments (left segment first). */
type MergeFn<T> = (a: T, b: T) => T
/** Returns `value` after applying the pending update `lazy` to a segment of `length` elements. */
type ApplyFn<T, L> = (value: T, lazy: L, length: number) => T
/** Folds a newer update `newLazy` into an older pending update `oldLazy`. */
type ComposeFn<L> = (oldLazy: L, newLazy: L) => L
/** Algebra describing one kind of range query + range update. */
interface Operation<T, L> {
  merge: MergeFn<T>
  /** How a pending update changes a node's aggregate (updates `tree`). */
  apply: ApplyFn<T, L>
  /** How two pending updates are combined (updates `lazy`). */
  compose: ComposeFn<L>
  /** Aggregate of an empty segment; returned for parts of a query outside the tree. */
  identityValue: () => T
  /** "No pending update"; used to initialise and reset lazy slots. */
  identityLazy: () => L
}
/**
 * Build a lazy-propagation segment tree over `data`.
 *
 * Whether a node has a pending update is tracked with an explicit flag rather
 * than by comparing its tag with `op.identityLazy()`. Otherwise an update whose
 * value equals the lazy identity (e.g. "assign 0" when the identity is 0) would
 * never be pushed down, and sub-range queries would return stale values.
 *
 * @param data initial elements (may be empty)
 * @param op   algebra describing merge / apply / compose and their identities
 * @returns `query(l, r)` (aggregate over the inclusive range) and
 *          `update(l, r, val)` (apply `val` to every element of the inclusive range)
 */
export function useSegmentTree<T, L>(data: T[], op: Operation<T, L>) {
  const n = data.length
  const size = 4 * n
  // A fresh identity per slot, so object-valued identities are never shared.
  const tree: T[] = Array.from({ length: size }, () => op.identityValue())
  const lazy: L[] = Array.from({ length: size }, () => op.identityLazy())
  // pending[node] is true while lazy[node] still has to be pushed to the children.
  const pending: boolean[] = Array<boolean>(size).fill(false)

  /** Fill node `node`, which covers data[l..r]. */
  function build(node: number, l: number, r: number) {
    if (l === r) {
      tree[node] = data[l]
    } else {
      const mid = (l + r) >> 1
      build(node * 2, l, mid)
      build(node * 2 + 1, mid + 1, r)
      tree[node] = op.merge(tree[node * 2], tree[node * 2 + 1])
    }
  }

  /** Apply `val` to a node covering `length` elements and remember it for its children. */
  function applyToNode(node: number, val: L, length: number) {
    tree[node] = op.apply(tree[node], val, length)
    lazy[node] = op.compose(lazy[node], val)
    pending[node] = true
  }

  /** Forward the pending update of `node` (covering [l, r]) to both children. */
  function pushDown(node: number, l: number, r: number) {
    if (!pending[node]) return
    const mid = (l + r) >> 1
    applyToNode(node * 2, lazy[node], mid - l + 1)
    applyToNode(node * 2 + 1, lazy[node], r - mid)
    lazy[node] = op.identityLazy()
    pending[node] = false
  }

  function updateRange(
    node: number,
    l: number,
    r: number,
    ql: number,
    qr: number,
    val: L,
  ) {
    if (ql > r || qr < l) return
    if (ql <= l && r <= qr) {
      applyToNode(node, val, r - l + 1)
      return
    }
    pushDown(node, l, r)
    const mid = (l + r) >> 1
    updateRange(node * 2, l, mid, ql, qr, val)
    updateRange(node * 2 + 1, mid + 1, r, ql, qr, val)
    tree[node] = op.merge(tree[node * 2], tree[node * 2 + 1])
  }

  function queryRange(
    node: number,
    l: number,
    r: number,
    ql: number,
    qr: number,
  ): T {
    if (ql > r || qr < l) return op.identityValue()
    if (ql <= l && r <= qr) return tree[node]
    pushDown(node, l, r)
    const mid = (l + r) >> 1
    return op.merge(
      queryRange(node * 2, l, mid, ql, qr),
      queryRange(node * 2 + 1, mid + 1, r, ql, qr),
    )
  }

  // An empty array has no root segment; building it would recurse forever.
  if (n > 0) build(1, 0, n - 1)
  return {
    query: (l: number, r: number) => queryRange(1, 0, n - 1, l, r),
    update: (l: number, r: number, val: L) =>
      updateRange(1, 0, n - 1, l, r, val),
  }
}
// Range sum with range add.
{
  const rangeAdd = useSegmentTree<number, number>([1, 2, 3, 4, 5], {
    identityValue: () => 0,
    identityLazy: () => 0,
    merge: (left, right) => left + right,
    // adding `delta` to `width` elements raises their sum by delta * width
    apply: (sum, delta, width) => sum + delta * width,
    // two pending additions add up
    compose: (pendingDelta, incomingDelta) => pendingDelta + incomingDelta,
  })
  const before = rangeAdd.query(0, 4)
  console.log(before) // 15
  rangeAdd.update(1, 3, 2) // -> [1, 4, 5, 6, 5]
  console.log(rangeAdd.query(0, 4)) // 21
}
// Range sum with range assignment.
{
  const rangeAssign = useSegmentTree<number, number>([1, 2, 3, 4, 5], {
    identityValue: () => 0,
    identityLazy: () => 0,
    merge: (left, right) => left + right,
    // assignment ignores the old sum: every element becomes `fill`
    apply: (_sum, fill, width) => fill * width,
    // the most recent assignment wins
    compose: (_pendingFill, incomingFill) => incomingFill,
  })
  const before = rangeAssign.query(0, 4)
  console.log(before) // 15
  rangeAssign.update(1, 3, 2) // -> [1, 2, 2, 2, 5]
  console.log(rangeAssign.query(0, 4)) // 12
}
JavaScript
/**
 * Build a lazy-propagation segment tree over `data`.
 *
 * Whether a node has a pending update is tracked with an explicit flag rather
 * than by comparing its tag with `op.identityLazy()`. Otherwise an update whose
 * value equals the lazy identity (e.g. "assign 0" when the identity is 0) would
 * never be pushed down, and sub-range queries would return stale values.
 *
 * @param {Array} data initial elements (may be empty)
 * @param {object} op the operation algebra:
 *   merge(a, b)               aggregate of two adjacent segments (left first);
 *   apply(value, lazy, length) aggregate after applying `lazy` to `length` elements;
 *   compose(oldLazy, newLazy) pending update equivalent to `oldLazy` then `newLazy`;
 *   identityValue()           aggregate of an empty segment;
 *   identityLazy()            "no pending update".
 * @returns {{query: Function, update: Function}} `query(l, r)` aggregates the
 *   inclusive range; `update(l, r, val)` applies `val` to every element in it.
 */
export function useSegmentTree(data, op) {
  const n = data.length
  const size = 4 * n
  // A fresh identity per slot, so object-valued identities are never shared.
  const tree = Array.from({ length: size }, () => op.identityValue())
  const lazy = Array.from({ length: size }, () => op.identityLazy())
  // pending[node] is true while lazy[node] still has to be pushed to the children.
  const pending = Array(size).fill(false)

  /** Fill node `node`, which covers data[l..r]. */
  function build(node, l, r) {
    if (l === r) {
      tree[node] = data[l]
    } else {
      const mid = (l + r) >> 1
      build(node * 2, l, mid)
      build(node * 2 + 1, mid + 1, r)
      tree[node] = op.merge(tree[node * 2], tree[node * 2 + 1])
    }
  }

  /** Apply `val` to a node covering `length` elements and remember it for its children. */
  function applyToNode(node, val, length) {
    tree[node] = op.apply(tree[node], val, length)
    lazy[node] = op.compose(lazy[node], val)
    pending[node] = true
  }

  /** Forward the pending update of `node` (covering [l, r]) to both children. */
  function pushDown(node, l, r) {
    if (!pending[node]) return
    const mid = (l + r) >> 1
    applyToNode(node * 2, lazy[node], mid - l + 1)
    applyToNode(node * 2 + 1, lazy[node], r - mid)
    lazy[node] = op.identityLazy()
    pending[node] = false
  }

  function updateRange(node, l, r, ql, qr, val) {
    if (ql > r || qr < l) return
    if (ql <= l && r <= qr) {
      applyToNode(node, val, r - l + 1)
      return
    }
    pushDown(node, l, r)
    const mid = (l + r) >> 1
    updateRange(node * 2, l, mid, ql, qr, val)
    updateRange(node * 2 + 1, mid + 1, r, ql, qr, val)
    tree[node] = op.merge(tree[node * 2], tree[node * 2 + 1])
  }

  function queryRange(node, l, r, ql, qr) {
    if (ql > r || qr < l) return op.identityValue()
    if (ql <= l && r <= qr) return tree[node]
    pushDown(node, l, r)
    const mid = (l + r) >> 1
    return op.merge(
      queryRange(node * 2, l, mid, ql, qr),
      queryRange(node * 2 + 1, mid + 1, r, ql, qr),
    )
  }

  // An empty array has no root segment; building it would recurse forever.
  if (n > 0) build(1, 0, n - 1)
  return {
    query: (l, r) => queryRange(1, 0, n - 1, l, r),
    update: (l, r, val) => updateRange(1, 0, n - 1, l, r, val),
  }
}
// Range sum with range add.
{
  const rangeAdd = useSegmentTree([1, 2, 3, 4, 5], {
    identityValue: () => 0,
    identityLazy: () => 0,
    merge: (left, right) => left + right,
    // adding `delta` to `width` elements raises their sum by delta * width
    apply: (sum, delta, width) => sum + delta * width,
    // two pending additions add up
    compose: (pendingDelta, incomingDelta) => pendingDelta + incomingDelta,
  })
  const before = rangeAdd.query(0, 4)
  console.log(before)
  rangeAdd.update(1, 3, 2)
  console.log(rangeAdd.query(0, 4))
}
// Range sum with range assignment.
{
  const rangeAssign = useSegmentTree([1, 2, 3, 4, 5], {
    identityValue: () => 0,
    identityLazy: () => 0,
    merge: (left, right) => left + right,
    // assignment ignores the old sum: every element becomes `fill`
    apply: (_sum, fill, width) => fill * width,
    // the most recent assignment wins
    compose: (_pendingFill, incomingFill) => incomingFill,
  })
  const before = rangeAssign.query(0, 4)
  console.log(before)
  rangeAssign.update(1, 3, 2)
  console.log(rangeAssign.query(0, 4))
}
Rust
use std::fmt::Debug;
/// Algebra of node values (a monoid): an associative `merge` together with
/// its neutral element `identity`.
pub trait Monoid {
    /// Value stored in every tree node.
    type Item: Clone + Debug;
    /// Neutral element of `merge`.
    fn identity() -> Self::Item;
    /// Combine the results of two adjacent sub-ranges (left one first).
    fn merge(a: &Self::Item, b: &Self::Item) -> Self::Item;
}
/// Lazy-update algebra layered on a [`Monoid`]: how a pending tag rewrites a
/// node value, and how two pending tags are folded into one.
pub trait LazyOp: Monoid {
    /// Pending-update tag. `PartialEq` lets the tree compare a tag with
    /// `identity_lazy()` to see whether anything is pending.
    type Lazy: Clone + Debug + PartialEq;
    /// Tag meaning "no operation".
    fn identity_lazy() -> Self::Lazy;
    /// Return `value` after applying `lazy` to a segment of `len` elements
    /// (`len` lets length-dependent aggregates such as sums scale the tag).
    fn apply(value: &Self::Item, lazy: &Self::Lazy, len: usize) -> Self::Item;
    /// Fold the newer tag `new` into `old` in place (`old <- compose(old, new)`);
    /// the exact semantics are defined by the implementor.
    fn compose(old: &mut Self::Lazy, new: &Self::Lazy);
}
/// Generic lazy-propagation segment tree. Nodes are numbered from 1; the
/// children of node `i` are `2i` and `2i + 1`.
pub struct SegmentTree<M: LazyOp> {
    /// Aggregated value of each node's segment.
    seg: Vec<M::Item>,
    /// Tag of each node that has not yet been pushed to its children.
    lz: Vec<M::Lazy>,
    /// Number of leaves (length of the slice the tree was built from).
    n: usize,
}
impl<M: LazyOp> SegmentTree<M> {
    /// Build a tree holding clones of the elements of `a`. An empty slice
    /// produces a tree on which every `update`/`query` panics.
    pub fn from_vec(a: &[M::Item]) -> Self {
        let n = a.len();
        // The recursive 1-based layout uses indices below 2 * next_power_of_two(n).
        let capacity = 2 * n.next_power_of_two();
        let mut tree = SegmentTree {
            n,
            seg: vec![M::identity(); capacity],
            lz: vec![M::identity_lazy(); capacity],
        };
        if let Some(last) = n.checked_sub(1) {
            tree.build(1, 0, last, a);
        }
        tree
    }

    /// Fill node `node`, which covers `a[lo..=hi]`.
    fn build(&mut self, node: usize, lo: usize, hi: usize, a: &[M::Item]) {
        if lo == hi {
            self.seg[node] = a[lo].clone();
        } else {
            let mid = lo + (hi - lo) / 2;
            self.build(2 * node, lo, mid, a);
            self.build(2 * node + 1, mid + 1, hi, a);
            self.pull(node);
        }
    }

    /// Recompute `node` from its two children.
    fn pull(&mut self, node: usize) {
        let (left, right) = (2 * node, 2 * node + 1);
        self.seg[node] = M::merge(&self.seg[left], &self.seg[right]);
    }

    /// Apply `tag` to node `node` covering `[lo, hi]` and fold it into the node's pending tag.
    fn apply_node(&mut self, node: usize, lo: usize, hi: usize, tag: &M::Lazy) {
        let updated = M::apply(&self.seg[node], tag, hi - lo + 1);
        self.seg[node] = updated;
        M::compose(&mut self.lz[node], tag);
    }

    /// Hand the pending tag of `node` (covering `[lo, hi]`) to both children and clear it.
    fn push(&mut self, node: usize, lo: usize, hi: usize) {
        if self.lz[node] == M::identity_lazy() {
            return; // nothing pending
        }
        let tag = std::mem::replace(&mut self.lz[node], M::identity_lazy());
        let mid = lo + (hi - lo) / 2;
        self.apply_node(2 * node, lo, mid, &tag);
        self.apply_node(2 * node + 1, mid + 1, hi, &tag);
    }

    /// Apply `val` to every position in `[ql, qr]`.
    ///
    /// # Panics
    /// If `ql > qr` or `qr >= n`.
    pub fn update(&mut self, ql: usize, qr: usize, val: &M::Lazy) {
        assert!(ql <= qr && qr < self.n);
        let last = self.n - 1;
        self.update_rec(1, 0, last, ql, qr, val);
    }

    fn update_rec(&mut self, node: usize, lo: usize, hi: usize, ql: usize, qr: usize, val: &M::Lazy) {
        let covered = ql <= lo && hi <= qr;
        if covered {
            self.apply_node(node, lo, hi, val);
            return;
        }
        self.push(node, lo, hi);
        let mid = lo + (hi - lo) / 2;
        if ql <= mid {
            self.update_rec(2 * node, lo, mid, ql, qr, val);
        }
        if mid < qr {
            self.update_rec(2 * node + 1, mid + 1, hi, ql, qr, val);
        }
        self.pull(node);
    }

    /// Aggregate of positions `[ql, qr]`. Takes `&mut self` because pending
    /// tags are pushed down along the way.
    ///
    /// # Panics
    /// If `ql > qr` or `qr >= n`.
    pub fn query(&mut self, ql: usize, qr: usize) -> M::Item {
        assert!(ql <= qr && qr < self.n);
        let last = self.n - 1;
        self.query_rec(1, 0, last, ql, qr)
    }

    fn query_rec(&mut self, node: usize, lo: usize, hi: usize, ql: usize, qr: usize) -> M::Item {
        if ql <= lo && hi <= qr {
            return self.seg[node].clone();
        }
        self.push(node, lo, hi);
        let mid = lo + (hi - lo) / 2;
        match (ql <= mid, qr > mid) {
            // query lies entirely in the left half
            (true, false) => self.query_rec(2 * node, lo, mid, ql, qr),
            // query lies entirely in the right half
            (false, true) => self.query_rec(2 * node + 1, mid + 1, hi, ql, qr),
            // query straddles the midpoint
            _ => {
                let left = self.query_rec(2 * node, lo, mid, ql, qr);
                let right = self.query_rec(2 * node + 1, mid + 1, hi, ql, qr);
                M::merge(&left, &right)
            }
        }
    }
}
// Range sum (supports both range assignment and range addition).
/// Sum of `i64` values: `merge` is addition and the identity is `0`.
#[derive(Debug, Clone, PartialEq)]
struct SumMonoid;
impl Monoid for SumMonoid {
    type Item = i64;
    fn identity() -> Self::Item {
        0
    }
    fn merge(a: &Self::Item, b: &Self::Item) -> Self::Item {
        a + b
    }
}
/// Pending tag of a range-sum tree.
#[derive(Debug, Clone, PartialEq)]
pub enum SumLazy {
    /// Nothing pending (the lazy identity).
    None,
    /// Add the payload to every element of the segment.
    Add(i64),
    /// Overwrite every element of the segment with the payload.
    Assign(i64),
}
impl LazyOp for SumMonoid {
    type Lazy = SumLazy;

    fn identity_lazy() -> Self::Lazy {
        SumLazy::None
    }

    /// Scale the tag by the segment length: `Add(d)` raises the sum by
    /// `d * len`, `Assign(x)` replaces it with `x * len`, `None` keeps it.
    fn apply(value: &Self::Item, lazy: &Self::Lazy, len: usize) -> Self::Item {
        let width = len as i64;
        match *lazy {
            SumLazy::None => *value,
            SumLazy::Add(delta) => *value + delta * width,
            SumLazy::Assign(fill) => fill * width,
        }
    }

    /// `new` is the more recent tag. An assignment replaces whatever is
    /// pending. An addition accumulates into a pending `Add`, shifts the value
    /// of a pending `Assign`, or becomes the pending tag. `None` changes nothing.
    fn compose(old: &mut Self::Lazy, new: &Self::Lazy) {
        let delta = match *new {
            SumLazy::None => return,
            SumLazy::Assign(_) => {
                *old = new.clone();
                return;
            }
            SumLazy::Add(delta) => delta,
        };
        match old {
            SumLazy::Add(acc) | SumLazy::Assign(acc) => *acc += delta,
            SumLazy::None => *old = SumLazy::Add(delta),
        }
    }
}
// DELETE: START
#[cfg(test)]
mod tests {
    use super::*;

    // --------- Example implementation: range assignment + range maximum ---------
    #[derive(Clone, Debug, PartialEq)]
    struct MaxAssignMonoid;
    impl Monoid for MaxAssignMonoid {
        type Item = i64; // maximum of the node's segment
        fn merge(a: &Self::Item, b: &Self::Item) -> Self::Item {
            *a.max(b)
        }
        fn identity() -> Self::Item {
            // Far below every value used in these tests, so it never wins a `max`.
            i64::MIN / 4
        }
    }

    #[derive(Clone, Debug, PartialEq)]
    enum AssignLazy {
        None,
        Set(i64),
    }

    impl LazyOp for MaxAssignMonoid {
        type Lazy = AssignLazy;
        fn apply(value: &Self::Item, lazy: &Self::Lazy, _len: usize) -> Self::Item {
            match lazy {
                AssignLazy::None => *value,
                AssignLazy::Set(x) => *x,
            }
        }
        fn compose(old: &mut Self::Lazy, new: &Self::Lazy) {
            match new {
                AssignLazy::None => (),
                AssignLazy::Set(x) => *old = AssignLazy::Set(*x),
            }
        }
        fn identity_lazy() -> Self::Lazy {
            AssignLazy::None
        }
    }

    #[test]
    fn test_sum_range_add() {
        let a = vec![1i64, 2, 3, 4, 5];
        let mut st = SegmentTree::<SumMonoid>::from_vec(&a);
        assert_eq!(st.query(0, 4), 15);
        st.update(1, 3, &SumLazy::Add(10)); // a[1..=3] += 10
        assert_eq!(st.query(0, 0), 1);
        assert_eq!(st.query(1, 1), 12);
        assert_eq!(st.query(0, 4), 45);
        assert_eq!(st.query(2, 4), 13 + 14 + 5);
        st.update(1, 3, &SumLazy::Assign(10)); // a[1..=3] = 10
        assert_eq!(st.query(1, 1), 10);
        assert_eq!(st.query(0, 4), 36);
        assert_eq!(st.query(2, 4), 10 + 10 + 5);
    }

    /// Assigning 0 must not be mistaken for "no pending tag", and a pending
    /// assignment followed by an addition must compose into a shifted assignment.
    #[test]
    fn test_sum_assign_zero_and_compose() {
        let a = vec![1i64, 2, 3, 4, 5];
        let mut st = SegmentTree::<SumMonoid>::from_vec(&a);
        st.update(0, 4, &SumLazy::Assign(0)); // tag stays on the root
        assert_eq!(st.query(2, 3), 0);
        st.update(0, 4, &SumLazy::Assign(2));
        st.update(0, 4, &SumLazy::Add(3)); // every element becomes 5
        assert_eq!(st.query(1, 2), 10);
        assert_eq!(st.query(0, 4), 25);
    }

    #[test]
    fn test_max_range_set() {
        let a = vec![1i64, 7, 3, 9, 5];
        let mut st = SegmentTree::<MaxAssignMonoid>::from_vec(&a);
        assert_eq!(st.query(0, 4), 9);
        st.update(1, 3, &AssignLazy::Set(4)); // -> [1, 4, 4, 4, 5]
        assert_eq!(st.query(0, 4), 5);
        assert_eq!(st.query(1, 2), 4);
    }
}
// DELETE: END
Java
import java.util.Arrays;
class SegmentTree {
    private final int[] tree, nums, lazy;
    /**
     * Value a lazy slot is reset to once its tag has been pushed down. Whether a node actually
     * has a pending tag is tracked in {@code hasLazy}, so assigning a range to this value (0)
     * is propagated correctly.
     */
    final int UNUSE = 0;
    /** {@code hasLazy[node]} is true while {@code lazy[node]} still has to reach the children. */
    private final boolean[] hasLazy;

    /**
     * Builds a range-assignment / range-sum segment tree over {@code nums}. The array is kept
     * by reference, not copied. An empty array yields a tree whose queries return 0 and whose
     * updates are no-ops.
     */
    public SegmentTree(int[] nums) {
        int n = nums.length;
        tree = new int[4 * n];
        lazy = new int[4 * n]; // lazy tag values
        hasLazy = new boolean[4 * n];
        Arrays.fill(lazy, UNUSE);
        this.nums = nums;
        // An empty array has no root segment; building it would index past the empty arrays.
        if (n > 0) {
            buildTree(0, 0, n - 1);
        }
    }

    /**
     * Builds the subtree rooted at {@code node}.
     *
     * @param node  index of the current node in {@code tree}
     * @param start first {@code nums} index covered by the node
     * @param end   last {@code nums} index covered by the node
     */
    private void buildTree(int node, int start, int end) {
        if (start == end) {
            tree[node] = nums[start];
            return;
        }
        int mid = start + (end - start) / 2;
        int leftNode = node * 2 + 1;
        int rightNode = node * 2 + 2;
        buildTree(leftNode, start, mid);
        buildTree(rightNode, mid + 1, end);
        tree[node] = tree[leftNode] + tree[rightNode];
    }

    /**
     * Assigns {@code val} to every element covered by {@code node} ([start, end]) and records
     * the tag as pending for the node's children.
     */
    private void applyAssign(int node, int start, int end, int val) {
        // For range add instead of range assignment, change these two `=` into `+=`
        // (this relies on lazy slots being reset to UNUSE == 0 after a push-down).
        tree[node] = val * (end - start + 1);
        lazy[node] = val;
        hasLazy[node] = true;
    }

    /** Forwards the pending tag of {@code node} ([start, end]) to both children and clears it. */
    private void pushDown(int node, int start, int end) {
        if (!hasLazy[node]) {
            return;
        }
        int mid = start + (end - start) / 2;
        applyAssign(node * 2 + 1, start, mid, lazy[node]);
        applyAssign(node * 2 + 2, mid + 1, end, lazy[node]);
        // Clear the current lazy tag.
        lazy[node] = UNUSE;
        hasLazy[node] = false;
    }

    /**
     * Assigns {@code val} to every {@code nums} index in [left, right] within this subtree.
     *
     * @param node  index of the current node in {@code tree}
     * @param start first {@code nums} index covered by the node
     * @param end   last {@code nums} index covered by the node
     * @param left  first {@code nums} index to update
     * @param right last {@code nums} index to update
     * @param val   value assigned to the range
     */
    private void updateTree(int node, int start, int end, int left, int right, int val) {
        // The update range does not intersect this node (also covers the empty tree).
        if (right < start || left > end) {
            return;
        }
        // Node fully covered: update it and leave a lazy tag instead of descending further.
        if (left <= start && end <= right) {
            applyAssign(node, start, end, val);
            return;
        }
        // Partially covered: push the pending tag down one level first.
        pushDown(node, start, end);
        int mid = start + (end - start) / 2;
        int leftNode = 2 * node + 1;
        int rightNode = 2 * node + 2;
        if (left <= mid) {
            updateTree(leftNode, start, mid, left, right, val);
        }
        if (right > mid) {
            updateTree(rightNode, mid + 1, end, left, right, val);
        }
        tree[node] = tree[leftNode] + tree[rightNode];
    }

    /**
     * Returns the sum of {@code nums} indices in [left, right] within this subtree.
     *
     * @param node  index of the current node in {@code tree}
     * @param start first {@code nums} index covered by the node
     * @param end   last {@code nums} index covered by the node
     * @param left  first {@code nums} index queried
     * @param right last {@code nums} index queried
     */
    private int queryTree(int node, int start, int end, int left, int right) {
        // Query range does not intersect this node.
        if (right < start || left > end) {
            return 0;
        }
        // Node fully inside the query range.
        if (start >= left && end <= right) {
            return tree[node];
        }
        // Partial overlap: push the pending tag down before reading the children.
        pushDown(node, start, end);
        int mid = start + (end - start) / 2;
        int leftNode = 2 * node + 1;
        int rightNode = 2 * node + 2;
        int leftSum = queryTree(leftNode, start, mid, left, right);
        int rightSum = queryTree(rightNode, mid + 1, end, left, right);
        return leftSum + rightSum;
    }

    /**
     * Sets a single {@code nums} index.
     *
     * @param index index in {@code nums}
     * @param val   new value
     */
    public void update(int index, int val) {
        updateTree(0, 0, nums.length - 1, index, index, val);
    }

    /**
     * Assigns {@code val} to every index in [left, right].
     *
     * @param left  first index
     * @param right last index
     * @param val   value to assign
     */
    public void update(int left, int right, int val) {
        updateTree(0, 0, nums.length - 1, left, right, val);
    }

    /**
     * Returns the sum of {@code nums} over [left, right].
     *
     * @param left  first index
     * @param right last index
     */
    public int query(int left, int right) {
        return queryTree(0, 0, nums.length - 1, left, right);
    }
}