树上可持久化线段树

pengchujie / 2023-08-26 / 原文

例题传送门:Count on a tree

简要题意:有棵\(n\)个节点的树,每次点有个权值\(a_i\),每次询问给出\(u,v,k\),求\(u,v\)两个节点的简单路径上(包括\(u,v\))上第\(k\)小的点,保证数据有解,强制在线

\(1\le n,m\le 10^5,a_i\in[1,2^{31}-1]\)

首先,第\(k\)小就可以想到要可持久化线段树,动态开点。

但是我们发现,要尽量快的寻找\(u,v\)的简单路径上包含的节点。

然后我们就可以用一个很经典的\(trick\)\(s_i\)\(i\)到根节点的路径上包含的节点,则 \(u,v\)的简单路径上包含的节点为:

\[s_u+s_v-s_{LCA(u,v)}-s_{fa[LCA(u,v)]} \]

考虑可持久化线段树存每个\(i\)到根节点的简单路径上包含的节点,只有注意\(a_i\)的离散化就可以做出这道题了

上代码:

#include<bits/stdc++.h>
#define ll int
using namespace std;
const ll N=1e5+50;
ll n,m,u,v,k,ans;
vector<ll> e[N];
struct jgt
{
	ll x,pos;
}a[N];
ll re[N];
bool cmp1(jgt t1,jgt t2)
{
	return t1.x<t2.x;
}
bool cmp2(jgt t1,jgt t2)
{
	return t1.pos<t2.pos;
}
struct jgt1
{
	ll l,r,gs;
}tr[N*25];
ll rt[N],tot;
ll f[N][25],dep[N];
void add(ll &now,ll last,ll l,ll r,ll md)
{
	now=++tot;
	tr[now]=tr[last];
	tr[now].gs++;
	if(l==r) return ;
	ll mid=(l+r)/2;
	if(md<=mid) add(tr[now].l,tr[last].l,l,mid,md);
	else add(tr[now].r,tr[last].r,mid+1,r,md);
}
void dfs(ll wz,ll last)
{
	add(rt[wz],rt[last],1,n,a[wz].x);
	f[wz][0]=last;
	dep[wz]=dep[last]+1;
	for(ll i=1;i<=18;i++)
	f[wz][i]=f[f[wz][i-1]][i-1];
	for(ll i=0;i<e[wz].size();i++)
	{
		ll j=e[wz][i];
		if(j==last) continue;
		dfs(j,wz);
	}
}
ll LCA(ll x,ll y)
{
	if(dep[x]<dep[y]) swap(x,y);
	for(ll i=18;i>=0;i--)
	if(dep[f[x][i]]>=dep[y]) x=f[x][i];
	if(x==y) return x;
	for(ll i=18;i>=0;i--) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
	return f[x][0];
}
ll query(ll n1,ll n2,ll la1,ll la2,ll l,ll r,ll shu)
{
	if(l==r) return l;
	ll tzy=tr[tr[n1].l].gs+tr[tr[n2].l].gs-tr[tr[la1].l].gs-tr[tr[la2].l].gs;
	ll mid=(l+r)/2;
	if(tzy>=shu) return query(tr[n1].l,tr[n2].l,tr[la1].l,tr[la2].l,l,mid,shu);
	return query(tr[n1].r,tr[n2].r,tr[la1].r,tr[la2].r,mid+1,r,shu-tzy);
}
int main()
{
	scanf("%lld %lld",&n,&m);
	for(ll i=1;i<=n;i++)
	{
		scanf("%d",&a[i].x);
		a[i].pos=i;
	}
	for(ll i=1;i<n;i++)
	{
		scanf("%d %d",&u,&v);
		e[u].push_back(v);
		e[v].push_back(u);
	}
	sort(a+1,a+n+1,cmp1);
	for(ll i=1;i<=n;i++)
	re[i]=a[i].x,a[i].x=i;
	sort(a+1,a+n+1,cmp2);
	dfs(1,0);
	for(ll i=1;i<=m;i++)
	{
		scanf("%d %d %d",&u,&v,&k);
		u=(u^ans);
		ans=re[query(rt[u],rt[v],rt[LCA(u,v)],rt[f[LCA(u,v)][0]],1,n,k)];
		printf("%d\n",ans);
	}
	return 0;
}