简易线段树
TypeScript
/**
 * Builds a simple segment tree (point update, inclusive range query) over a
 * copy of `inputArray`.
 *
 * @param inputArray Initial values. The array is copied, so mutating it
 *   afterwards does not affect the tree. May be empty.
 * @param operation Combines the results of two adjacent segments
 *   (e.g. sum, min, max). NOTE(review): ranges are split at arbitrary
 *   midpoints, so this is presumably required to be associative — confirm
 *   with callers.
 * @param operationFallback Returned for segments disjoint from a query range
 *   and then combined with real results via `operation`, so it should act as
 *   that operation's identity element (e.g. 0 for sum).
 * @returns `{ update, query }`:
 *   - `query(left, right)` folds the values at indices `left..=right`;
 *   - `update(index, value)` replaces the value at `index`.
 *   Both throw an `Error` when an index is not an integer or is out of range
 *   (which is every index when the input was empty).
 */
export function useSimpleSegmentTree<T>(
  inputArray: T[],
  operation: (a: T, b: T) => T,
  operationFallback: T,
) {
  const operate = operation
  const fallback = operationFallback
  const array = [...inputArray]
  const n = array.length

  // Children of node `i` are stored at `2i + 1` and `2i + 2` in `tree`.
  const getLeftNode = (node: number) => node * 2 + 1
  const getRightNode = (node: number) => node * 2 + 2
  const getMid = (left: number, right: number) =>
    Math.trunc((left + right) / 2)
  const getInfo = (node: number, start: number, end: number) => ({
    mid: getMid(start, end),
    leftNode: getLeftNode(node),
    rightNode: getRightNode(node),
  })

  /**
   * Allocates storage for a perfect binary tree whose leaf count is the
   * smallest power of two >= `length` (at least 1): `2 * leaves - 1` slots,
   * pre-filled with `fallback`. Integer doubling avoids relying on
   * floating-point `Math.log2` and never yields a negative array length.
   */
  function initTree(length: number): T[] {
    let leaves = 1
    while (leaves < length) leaves *= 2
    return new Array<T>(2 * leaves - 1).fill(fallback)
  }

  const tree: T[] = initTree(n)

  /** Recursively fills `tree[node]`, which covers `array[start..=end]`. */
  function buildTree(node: number, start: number, end: number) {
    if (start === end) {
      tree[node] = array[start]
      return
    }
    const { mid, leftNode, rightNode } = getInfo(node, start, end)
    buildTree(leftNode, start, mid)
    buildTree(rightNode, mid + 1, end)
    tree[node] = operate(tree[leftNode], tree[rightNode])
  }

  // An empty input has no root segment [0, n - 1]; building one would
  // recurse without terminating. `query`/`update` reject every index then.
  if (n > 0) buildTree(0, 0, n - 1)

  /**
   * Writes `value` into the leaf for `index` inside the subtree at `node`
   * (covering `[start, end]`), then recomputes each ancestor on the way back
   * up. Expects an integer `index` within `[start, end]`.
   */
  function updateTree(
    node: number,
    start: number,
    end: number,
    index: number,
    value: T,
  ) {
    if (start === end && index === start) {
      tree[node] = value
      return
    }
    const { mid, leftNode, rightNode } = getInfo(node, start, end)
    if (index <= mid) updateTree(leftNode, start, mid, index, value)
    else updateTree(rightNode, mid + 1, end, index, value)
    tree[node] = operate(tree[leftNode], tree[rightNode])
  }

  /**
   * Folds the values in `[left, right]` that fall inside the subtree at
   * `node` (covering `[start, end]`). Disjoint subtrees contribute
   * `fallback`; fully covered subtrees return their stored value directly.
   */
  function queryTree(
    node: number,
    start: number,
    end: number,
    left: number,
    right: number,
  ): T {
    if (right < start || left > end) return fallback
    if (left <= start && end <= right) return tree[node]
    const { mid, leftNode, rightNode } = getInfo(node, start, end)
    const leftResult = queryTree(leftNode, start, mid, left, right)
    const rightResult = queryTree(rightNode, mid + 1, end, left, right)
    return operate(leftResult, rightResult)
  }

  const query = (left: number, right: number) => {
    // NaN bounds fail every comparison in `queryTree` and would recurse past
    // the leaves forever; non-integer bounds are not valid indices anyway.
    if (
      !Number.isInteger(left) ||
      !Number.isInteger(right) ||
      left > right ||
      left < 0 ||
      right >= n
    ) {
      throw new Error('left 或 right 超出了范围')
    }
    return queryTree(0, 0, n - 1, left, right)
  }

  const update = (index: number, value: T) => {
    // A non-integer index (including NaN) never matches a leaf, so
    // `updateTree` would recurse past the leaves forever; reject it here.
    if (!Number.isInteger(index) || index < 0 || index >= n) {
      throw new Error('index 超出了数组范围')
    }
    updateTree(0, 0, n - 1, index, value)
  }

  return {
    update,
    query,
  }
}
JavaScript
/**
 * Builds a simple segment tree (point update, inclusive range query) over a
 * copy of `inputArray`.
 *
 * @template T
 * @param {T[]} inputArray Initial values. The array is copied, so mutating it
 *   afterwards does not affect the tree. May be empty.
 * @param {(a: T, b: T) => T} operation Combines the results of two adjacent
 *   segments (e.g. sum, min, max). NOTE(review): ranges are split at
 *   arbitrary midpoints, so this is presumably required to be associative —
 *   confirm with callers.
 * @param {T} operationFallback Returned for segments disjoint from a query
 *   range and then combined with real results via `operation`, so it should
 *   act as that operation's identity element (e.g. 0 for sum).
 * @returns {{ update: (index: number, value: T) => void,
 *             query: (left: number, right: number) => T }}
 *   `query(left, right)` folds the values at indices `left..=right`;
 *   `update(index, value)` replaces the value at `index`. Both throw an
 *   `Error` when an index is not an integer or is out of range (which is
 *   every index when the input was empty).
 */
export function useSimpleSegmentTree(inputArray, operation, operationFallback) {
  const operate = operation
  const fallback = operationFallback
  const array = [...inputArray]
  const n = array.length

  // Children of node `i` are stored at `2i + 1` and `2i + 2` in `tree`.
  const getLeftNode = (node) => node * 2 + 1
  const getRightNode = (node) => node * 2 + 2
  const getMid = (left, right) => Math.trunc((left + right) / 2)
  const getInfo = (node, start, end) => ({
    mid: getMid(start, end),
    leftNode: getLeftNode(node),
    rightNode: getRightNode(node),
  })

  /**
   * Allocates storage for a perfect binary tree whose leaf count is the
   * smallest power of two >= `length` (at least 1): `2 * leaves - 1` slots,
   * pre-filled with `fallback`. Integer doubling avoids relying on
   * floating-point `Math.log2` and never yields a negative array length.
   */
  function initTree(length) {
    let leaves = 1
    while (leaves < length) leaves *= 2
    return new Array(2 * leaves - 1).fill(fallback)
  }

  const tree = initTree(n)

  /** Recursively fills `tree[node]`, which covers `array[start..=end]`. */
  function buildTree(node, start, end) {
    if (start === end) {
      tree[node] = array[start]
      return
    }
    const { mid, leftNode, rightNode } = getInfo(node, start, end)
    buildTree(leftNode, start, mid)
    buildTree(rightNode, mid + 1, end)
    tree[node] = operate(tree[leftNode], tree[rightNode])
  }

  // An empty input has no root segment [0, n - 1]; building one would
  // recurse without terminating. `query`/`update` reject every index then.
  if (n > 0) buildTree(0, 0, n - 1)

  /**
   * Writes `value` into the leaf for `index` inside the subtree at `node`
   * (covering `[start, end]`), then recomputes each ancestor on the way back
   * up. Expects an integer `index` within `[start, end]`.
   */
  function updateTree(node, start, end, index, value) {
    if (start === end && index === start) {
      tree[node] = value
      return
    }
    const { mid, leftNode, rightNode } = getInfo(node, start, end)
    if (index <= mid) updateTree(leftNode, start, mid, index, value)
    else updateTree(rightNode, mid + 1, end, index, value)
    tree[node] = operate(tree[leftNode], tree[rightNode])
  }

  /**
   * Folds the values in `[left, right]` that fall inside the subtree at
   * `node` (covering `[start, end]`). Disjoint subtrees contribute
   * `fallback`; fully covered subtrees return their stored value directly.
   */
  function queryTree(node, start, end, left, right) {
    if (right < start || left > end) return fallback
    if (left <= start && end <= right) return tree[node]
    const { mid, leftNode, rightNode } = getInfo(node, start, end)
    const leftResult = queryTree(leftNode, start, mid, left, right)
    const rightResult = queryTree(rightNode, mid + 1, end, left, right)
    return operate(leftResult, rightResult)
  }

  const query = (left, right) => {
    // NaN bounds fail every comparison in `queryTree` and would recurse past
    // the leaves forever; non-integer bounds are not valid indices anyway.
    if (
      !Number.isInteger(left) ||
      !Number.isInteger(right) ||
      left > right ||
      left < 0 ||
      right >= n
    ) {
      throw new Error('left 或 right 超出了范围')
    }
    return queryTree(0, 0, n - 1, left, right)
  }

  const update = (index, value) => {
    // A non-integer index (including NaN) never matches a leaf, so
    // `updateTree` would recurse past the leaves forever; reject it here.
    if (!Number.isInteger(index) || index < 0 || index >= n) {
      throw new Error('index 超出了数组范围')
    }
    updateTree(0, 0, n - 1, index, value)
  }

  return {
    update,
    query,
  }
}
Rust
use std::ops::Deref;
/// Location of one segment-tree node: its slot in the flat storage vector
/// (`index`) and the inclusive element range `[start, end]` it covers.
#[derive(Clone, Copy)]
struct NodeCtx {
    index: usize,
    start: usize,
    end: usize,
}

/// Builds a node context from an `(index, start, end)` triple.
impl From<(usize, usize, usize)> for NodeCtx {
    fn from(triple: (usize, usize, usize)) -> Self {
        let (index, start, end) = triple;
        NodeCtx { index, start, end }
    }
}

/// Dereferencing a `NodeCtx` yields its storage slot, so callers can write
/// `tree[*node]`.
impl Deref for NodeCtx {
    type Target = usize;

    fn deref(&self) -> &Self::Target {
        &self.index
    }
}

impl NodeCtx {
    /// Splits this node's range at its midpoint into the left child
    /// (slot `2i + 1`, `[start, mid]`) and right child
    /// (slot `2i + 2`, `[mid + 1, end]`).
    fn split(&self) -> (NodeCtx, NodeCtx) {
        let NodeCtx { index, start, end } = *self;
        let mid = (start + end) / 2;
        let first_child = 2 * index + 1;
        let lower = NodeCtx {
            index: first_child,
            start,
            end: mid,
        };
        let upper = NodeCtx {
            index: first_child + 1,
            start: mid + 1,
            end,
        };
        (lower, upper)
    }

    /// True when the node covers exactly one element.
    fn is_leaf(&self) -> bool {
        self.end == self.start
    }

    /// True when `index` lies inside `[self.start, self.end]`.
    fn contains(&self, index: usize) -> bool {
        (self.start..=self.end).contains(&index)
    }

    /// True when `[self.start, self.end]` lies entirely inside `[start, end]`.
    fn is_contained_in_range(&self, start: usize, end: usize) -> bool {
        self.start >= start && end >= self.end
    }

    /// True when `[self.start, self.end]` and `[start, end]` share at least
    /// one position.
    fn overlaps(&self, start: usize, end: usize) -> bool {
        self.end >= start && self.start <= end
    }
}
/// Segment tree supporting point updates and inclusive range queries over a
/// fixed-length sequence of `Copy` values.
///
/// Nodes live in a flat vector: the root is slot 0 and the children of slot
/// `i` are slots `2i + 1` and `2i + 2`.
pub struct SegmentTree<T, F>
where
    T: Copy,
    F: Fn(T, T) -> T,
{
    /// Flat node storage; slots not written by the build keep `fallback`.
    tree: Vec<T>,
    /// Combines two child results into their parent's value.
    operate: F,
    /// Result for segments disjoint from a query range; should be the
    /// identity element of `operate`.
    fallback: T,
    /// Number of elements in the original input.
    size: usize,
}

impl<T, F> SegmentTree<T, F>
where
    T: Copy,
    F: Fn(T, T) -> T,
{
    /// Builds a tree over a copy of `input_array`.
    ///
    /// `operation` combines two adjacent segments (presumably required to be
    /// associative — confirm for your use). `operation_fallback` is returned
    /// for segments outside a query range and combined with real results, so
    /// it should be `operation`'s identity element.
    ///
    /// An empty `input_array` is accepted; every later `query`/`update` on
    /// such a tree panics as out of range.
    pub fn new(input_array: &[T], operation: F, operation_fallback: T) -> Self {
        let size = input_array.len();
        // Leaves are rounded up to a power of two `m`; a perfect binary tree
        // with `m` leaves has `2m - 1` nodes. `0.next_power_of_two() == 1`.
        let tree_length = 2 * size.next_power_of_two() - 1;
        let mut result = Self {
            tree: vec![operation_fallback; tree_length],
            operate: operation,
            fallback: operation_fallback,
            size,
        };
        // For an empty input there is no root range and `size - 1` would
        // underflow, so only build when there is at least one element.
        if size > 0 {
            let root = result.root();
            result.build_tree(root, input_array);
        }
        result
    }

    /// Context of the root node, covering `[0, size - 1]`.
    /// Callers must ensure `size > 0`.
    fn root(&self) -> NodeCtx {
        (0, 0, self.size - 1).into()
    }

    /// Recursively fills `node` (and its subtree) from `array`.
    fn build_tree(&mut self, node: NodeCtx, array: &[T]) {
        if node.is_leaf() {
            self.tree[*node] = array[node.start];
            return;
        }
        let (left_node, right_node) = node.split();
        self.build_tree(left_node, array);
        self.build_tree(right_node, array);
        self.tree[*node] = (self.operate)(self.tree[*left_node], self.tree[*right_node]);
    }

    /// Writes `value` into the leaf for `index` below `node`, then recomputes
    /// each ancestor on the way back up.
    fn update_tree(&mut self, node: NodeCtx, index: usize, value: T) {
        if node.is_leaf() {
            self.tree[*node] = value;
            return;
        }
        let (left_node, right_node) = node.split();
        if left_node.contains(index) {
            self.update_tree(left_node, index, value);
        } else if right_node.contains(index) {
            self.update_tree(right_node, index, value);
        }
        self.tree[*node] = (self.operate)(self.tree[*left_node], self.tree[*right_node]);
    }

    /// Folds the values in `[left, right]` that fall inside `node`'s range;
    /// disjoint subtrees contribute `fallback`.
    fn query_tree(&self, node: NodeCtx, left: usize, right: usize) -> T {
        if !node.overlaps(left, right) {
            return self.fallback;
        }
        if node.is_contained_in_range(left, right) {
            return self.tree[*node];
        }
        let (left_node, right_node) = node.split();
        let left_result = self.query_tree(left_node, left, right);
        let right_result = self.query_tree(right_node, left, right);
        (self.operate)(left_result, right_result)
    }

    /// Combines the values at indices `left..=right`.
    ///
    /// # Panics
    /// Panics with "超出范围" if `left > right` or `right >= size`
    /// (always, for an empty tree).
    pub fn query(&self, left: usize, right: usize) -> T {
        if left > right || right >= self.size {
            panic!("超出范围");
        }
        self.query_tree(self.root(), left, right)
    }

    /// Replaces the value at `index` with `value`.
    ///
    /// # Panics
    /// Panics with "超出范围" if `index >= size` (always, for an empty tree).
    pub fn update(&mut self, index: usize, value: T) {
        if index >= self.size {
            panic!("超出范围");
        }
        let root = self.root();
        self.update_tree(root, index, value)
    }
}