[树状数组] 学习笔记

Zwjay's Blog / 2023-08-23 / 原文

原理

int lowbit(int x)
{
	return x & (-x);
}

void add(int x, int k)
{
	for (; x <= n; x += lowbit(x)) c[x] += k;
}

int query(int x)
{
	int ans = 0;
	for (; x ; x -= lowbit(x)) ans += c[x];
	return ans;
}

单点修改 + 区间查询

int lowbit(int x)
{
	return x & (-x);
}
void add(int x, int k)
{
	for (; x <= n; x += lowbit(x)) c[x] += k;
}
int query(int x)
{
	int ans = 0;
	for (; x ; x -= lowbit(x)) ans += c[x];
	return ans;
}
void init()
{
	for (int i = 1; i <= n; i ++) add(i, a[i]);
}
void get_add(int x, int k)
{
	add(x, k);
}
int get_query(int x, int y)
{
	return query(y) - query(x - 1);
}

单点查询 + 区间修改

设差分数组 \(b\),定义 \(b[i]=a[i]-a[i-1](a[0]=0,i\in[1,n])\),则有两个性质:

  • \(a[i]+=k\Leftrightarrow \begin{cases}b[l] +=k\\b[r+1] -=k\end{cases}~(i\in[l,r])\)

  • \(a[x] = \sum\limits_{i=1}^{x}b[i]\)

int lowbit(int x)
{
	return x & (-x);
}
void add(int x, int k)
{
	for (; x <= n; x += lowbit(x)) c[x] += k;
}
int query(int x)
{
	int ans = 0;
	for (; x ; x -= lowbit(x)) ans += c[x];
	return ans;
}
void init()
{
	for (int i = 1; i <= n; i ++)
	{
		b[i] = a[i] - a[i-1];
		add(i, b[i]);
	}
}
void get_add(int x, int y, int k)
{
	add(x, k);
	add(y + 1, -k);
}
int get_query(int x)
{
	return query(x);
}

区间查询 + 区间修改

在差分思想的基础上,若查询原数组 \(a\)\([1,p]\) 的区间和,则有:

\[\begin{aligned}\sum\limits_{i=1}^{p}a[i] &= \sum\limits_{i=1}^{p}\sum\limits_{j=1}^{i}b[j] \\ &=\sum\limits_{i=1}^{p}[~b[i]\cdot(p-i+1)~] \\ &=\sum\limits_{i=1}^{p}[~b[i]\cdot(p+1)-b[i]\cdot i~] \\ &=\sum\limits_{i=1}^{p}[~b[i]\cdot(p+1)~]-\sum\limits_{i=1}^{p}(~b[i]\cdot i~) \end{aligned}\]

到最后一个式子就把不同的变量分离了,实现更简单,\(b[i]\cdot i\) 可以再建一个树状数组代替。

注意:树状数组中的下标不是原数组的下标,在 \(p\) 这个位置上加,所以是\(\times p\) ,而不是 \(\times x\)

int lowbit(int x)
{
	return x & (-x);
}
void add(int x, int k)
{
	int p = x;
	for (; x <= n; x += lowbit(x)) 
	{
		c1[x] += k;
		c2[x] += k * p;
	}
}
int query(int x)
{
	int ans = 0, p = x;
	for (; x ; x -= lowbit(x)) 
		ans += (p + 1) * c1[x] - c2[x];
	return ans;
}
void init()
{
	for (int i = 1; i <= n; i ++)
	{
		b[i] = a[i] - a[i-1];
		add(i, b[i]);
	}
}
void get_add(int x, int y, int k)
{
	add(x, k);
	add(y + 1, -k);
}
int get_query(int x, int y)
{
	return query(y) - query(x - 1);
}