树上问题三题

客博的__80gnaij11__ / 2023-08-21 / 原文

例1

有一棵树共 \(n\) 个节点,\(n-1\) 条边。共有 \(q\) 个问询,每个问询关于两条路径:\(u\)\(v\) 的路径,\(a\)\(b\) 的路径,需要判断这两条路径是否相交,也就是判断两条路径是否有公共的节点。对于这 \(q\) 个问询,请统计共有几个问题的答案是路径相交的。

首先考虑暴力方法:

  1. 枚举树上所有 \(n\) 个节点,判断是否有节点同时在两个路径上。复杂度为 \(\Theta(nq\log n)\)
  2. 枚举路径 \((a,b)\) 上每一个节点 \(x\) ,判断 \(x\) 是否在路径 \((u,v)\) 上。平均复杂度为 \(\Theta(q\log^2 n)\),最差复杂度为 \(\Theta(nq\log n)\)

还能不能更快?

手算几个样例,可以发现规律:设 \(lca(u,v)\) 不高于 \(lca(a,b)\) ,两直线相交,必经过 \(lca(u,v)\),即必经过较低的 \(lca\)

因此产生了正解思路:对于路径 \((a,b)\) 和路径 \((u,v)\),分别计算 \(lca(a,b)\)\(lca(u,v)\)。当 \(lca(a,b)\) 高于 \(lca(u,v)\),则如果 \((a,b)\) 经过 \(lca(u,v)\),两路径相交。反之如果 \((u,v)\) 经过 \(lca(a,b)\),两路径相交。该算法时间复杂度为 \(\Theta(n\log n+q\log n)\)


AC代码:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
ll ans;
int L,n,q,d[200009],p[200009][209],tI[200009],tO[200009],timer;
vector<int> to[200009];
void depth(int x,int dep){
	d[x]=dep;
	for(int i=0;i<to[x].size();i++){
		if(d[to[x][i]])  continue;
		depth(to[x][i],dep+1);
	}
}
void init(int u,int fa){
	tI[u]=++timer;
	p[u][0]=fa;
	for(int i=1;i<=L;i++)  p[u][i]=p[p[u][i-1]][i-1];
	for(int i=0;i<to[u].size();i++){
		if(to[u][i]!=fa)  init(to[u][i],u);
	}
	tO[u]=++timer;
}
bool up(int u,int v){
	return !u||tI[u]<=tI[v]&&tO[v]<=tO[u];
}
int lca(int u,int v){
	if(u==v)  return u;
	if(up(u,v))  return u;
	if(up(v,u))  return v;
	for(int i=L;i>=0;i--){
		if(!up(p[u][i],v))  u=p[u][i];
	}
	return p[u][0];
}
int dst(int x,int y){
	return d[x]+d[y]-2*d[lca(x,y)];
}
bool onPath(int x,int u,int v){
	return dst(u,x)+dst(v,x)==dst(u,v);
}
int main(){
	cin>>n;
	for(int i=1;i<=n-1;i++){
		int u,v;  cin>>u>>v;
		to[u].push_back(v);
		to[v].push_back(u);
	}
	depth(1,1);init(1,0);
	L=log(n)/log(2)+1;
	cin>>q;
	for(int i=1;i<=q;i++){
		int a,b,u,v;cin>>u>>v>>a>>b;
		int uv=lca(u,v),ab=lca(a,b);
		if(d[ab]>d[uv])  ans+=onPath(ab,u,v);
		else  ans+=onPath(uv,a,b);
	}
	cout<<ans;
	return 0;
}

例2

输入一棵树,输出该树的直径。

思路1:

枚举每一个转折点,计算以该点为转折点的两条最长路(注意不能有重叠的边),输出最大值。

AC代码:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
int ans,n,lst[100009],h1[100009],h2[100009];
vector<int> to[100009];
void dfs(int u,int fa){
	for(int i=0;i<to[u].size();i++){
		int v=to[u][i];
		if(v==fa)  continue;
		dfs(v,u);
		h2[u]=max(h2[u],h1[v]+1);
		if(h2[u]>h1[u])  swap(h2[u],h1[u]);
	}
	return ;
}
void solve(){
	dfs(1,0);
	for(int i=1;i<=n;i++)  lst[i]=h1[i]+h2[i];
	ans=*max_element(lst+1,lst+1+n);
	return ;
}
int main(){
	cin>>n;
	for(int i=1;i<=n-1;i++){
		int u,v;
		cin>>u>>v;
		to[u].push_back(v);
		to[v].push_back(u);
	}
	solve();
	cout<<ans;
	return 0;
}


思路2:

直径的 \(2\) 个端点必定是“边缘点”。执行 \(2\) 次DFS,第一次以任意点为根节点,找到距离 \(A\) 最远的点 \(B\)。再以 \(B\) 为根节点DFS,找到距离 \(B\) 最远的点 \(C\)。则 \(BC\) 一定为直径。

例3

春节期间,小明要去长辈家里拜年。假设小明和爸妈住在 \(A\) 点,爷爷奶奶家在 \(B\) 点,外公外婆家在 \(C\) 点。小明的拜年路线会从 \(A\) 点出发,然后判断 \(B\)\(C\) 点哪个点距离 \(A\) 点更近,小明就先去哪个点拜年。假如 \(A\)\(B\) 的距离比 \(A\)\(C\) 的距离更近,那么小明就先从 \(A\)\(B\) 再从 \(B\)\(C\);假如 \(A\)\(C\) 的距离比 \(A\)\(B\) 的距离更近,那么小明就先从 \(A\)\(C\) 再从 \(C\)\(B\)。已知 \(A,B,C\) 所在地区的道路形成一棵树的形状,共 \(n\) 个点,\(n-1\) 条道路,所有点都能通过道路连通。但是 \(A,B,C\) 具体在哪个点的位置,我们并不知道,请问小明从 \(A\) 点开始拜年两处长辈所走的路径总长度最多可能是多少?


思路1:枚举转折点,预计算前 \(3\) 长路径

思路2:计算直径 \(BC\) \(+\) 枚举 \(A\)


#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
ll n,d[200009],tI[200009],tO[200009],p[200009][109],timer,L;
vector<ll> to[200009];
vector<ll> dis[200009];
void dfs(ll x,ll dep){
	d[x]=dep;
	for(ll i=0;i<to[x].size();i++){
		ll y=to[x][i];
		if(d[y])  continue;
		dfs(y,dep+dis[x][i]);
	}
	return ;
}
void init(ll u,ll fa){
	tI[u]=++timer;
	p[u][0]=fa;
	for(ll i=1;i<=L;i++)  p[u][i]=p[p[u][i-1]][i-1];
	for(ll i=0;i<to[u].size();i++){
		if(to[u][i]!=fa)  init(to[u][i],u);
	}
	tO[u]=++timer;
}
bool up(ll u,ll v){
	return !u||tI[u]<=tI[v]&&tO[v]<=tO[u];
}
ll lca(ll u,ll v){
	if(u==v)  return u;
	if(up(u,v))  return u;
	if(up(v,u))  return v;
	for(ll i=L;i>=0;i--){
		if(!up(p[u][i],v))  u=p[u][i];
	}
	return p[u][0];
}
ll dst(ll x,ll y){
	return d[x]+d[y]-2*d[lca(x,y)];
}
int main(){
	scanf("%lld",&n);
	for(ll i=1;i<=n-1;i++){
		ll u,v,w;
		scanf("%lld%lld%lld",&u,&v,&w);
		to[u].push_back(v);
		to[v].push_back(u);
		dis[u].push_back(w);
		dis[v].push_back(w);
	}
	dfs(1,1);
	ll B=max_element(d+1,d+1+n)-d;
	memset(d,0,sizeof(d));
	dfs(B,1);
	ll C=max_element(d+1,d+1+n)-d;
	ll ans=0;
	memset(d,0,sizeof(d));
	dfs(1,1);
	L=log(n)/log(2)+1;
	init(1,0);
	for(ll A=1;A<=n;A++){
		ll AB=dst(A,B);
		ll AC=dst(A,C);
		ll cost=min(AB,AC)+dst(B,C);
		ans=max(ans,cost); 
	}
	printf("%lld",ans);
	return 0;
}