树上问题三题
例1
有一棵树共 \(n\) 个节点,\(n-1\) 条边。共有 \(q\) 个问询,每个问询关于两条路径:\(u\) 到 \(v\) 的路径,\(a\) 到 \(b\) 的路径,需要判断这两条路径是否相交,也就是判断两条路径是否有公共的节点。对于这 \(q\) 个问询,请统计共有几个问题的答案是路径相交的。
首先考虑暴力方法:
- 枚举树上所有 \(n\) 个节点,判断是否有节点同时在两个路径上。复杂度为 \(\Theta(nq\log n)\)
- 枚举路径 \((a,b)\) 上每一个节点 \(x\) ,判断 \(x\) 是否在路径 \((u,v)\) 上。平均复杂度为 \(\Theta(q\log^2 n)\),最差复杂度为 \(\Theta(nq\log n)\)
还能不能更快?
手算几个样例,可以发现规律:设 \(lca(u,v)\) 不高于 \(lca(a,b)\) ,两直线相交,必经过 \(lca(u,v)\),即必经过较低的 \(lca\)。
因此产生了正解思路:对于路径 \((a,b)\) 和路径 \((u,v)\),分别计算 \(lca(a,b)\) 和 \(lca(u,v)\)。当 \(lca(a,b)\) 高于 \(lca(u,v)\),则如果 \((a,b)\) 经过 \(lca(u,v)\),两路径相交。反之如果 \((u,v)\) 经过 \(lca(a,b)\),两路径相交。该算法时间复杂度为 \(\Theta(n\log n+q\log n)\)。
AC代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
ll ans;
int L,n,q,d[200009],p[200009][209],tI[200009],tO[200009],timer;
vector<int> to[200009];
void depth(int x,int dep){
d[x]=dep;
for(int i=0;i<to[x].size();i++){
if(d[to[x][i]]) continue;
depth(to[x][i],dep+1);
}
}
void init(int u,int fa){
tI[u]=++timer;
p[u][0]=fa;
for(int i=1;i<=L;i++) p[u][i]=p[p[u][i-1]][i-1];
for(int i=0;i<to[u].size();i++){
if(to[u][i]!=fa) init(to[u][i],u);
}
tO[u]=++timer;
}
bool up(int u,int v){
return !u||tI[u]<=tI[v]&&tO[v]<=tO[u];
}
int lca(int u,int v){
if(u==v) return u;
if(up(u,v)) return u;
if(up(v,u)) return v;
for(int i=L;i>=0;i--){
if(!up(p[u][i],v)) u=p[u][i];
}
return p[u][0];
}
int dst(int x,int y){
return d[x]+d[y]-2*d[lca(x,y)];
}
bool onPath(int x,int u,int v){
return dst(u,x)+dst(v,x)==dst(u,v);
}
int main(){
cin>>n;
for(int i=1;i<=n-1;i++){
int u,v; cin>>u>>v;
to[u].push_back(v);
to[v].push_back(u);
}
depth(1,1);init(1,0);
L=log(n)/log(2)+1;
cin>>q;
for(int i=1;i<=q;i++){
int a,b,u,v;cin>>u>>v>>a>>b;
int uv=lca(u,v),ab=lca(a,b);
if(d[ab]>d[uv]) ans+=onPath(ab,u,v);
else ans+=onPath(uv,a,b);
}
cout<<ans;
return 0;
}
例2
输入一棵树,输出该树的直径。
思路1:
枚举每一个转折点,计算以该点为转折点的两条最长路(注意不能有重叠的边),输出最大值。
AC代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
int ans,n,lst[100009],h1[100009],h2[100009];
vector<int> to[100009];
void dfs(int u,int fa){
for(int i=0;i<to[u].size();i++){
int v=to[u][i];
if(v==fa) continue;
dfs(v,u);
h2[u]=max(h2[u],h1[v]+1);
if(h2[u]>h1[u]) swap(h2[u],h1[u]);
}
return ;
}
void solve(){
dfs(1,0);
for(int i=1;i<=n;i++) lst[i]=h1[i]+h2[i];
ans=*max_element(lst+1,lst+1+n);
return ;
}
int main(){
cin>>n;
for(int i=1;i<=n-1;i++){
int u,v;
cin>>u>>v;
to[u].push_back(v);
to[v].push_back(u);
}
solve();
cout<<ans;
return 0;
}
思路2:
直径的 \(2\) 个端点必定是“边缘点”。执行 \(2\) 次DFS,第一次以任意点为根节点,找到距离 \(A\) 最远的点 \(B\)。再以 \(B\) 为根节点DFS,找到距离 \(B\) 最远的点 \(C\)。则 \(BC\) 一定为直径。
例3
春节期间,小明要去长辈家里拜年。假设小明和爸妈住在 \(A\) 点,爷爷奶奶家在 \(B\) 点,外公外婆家在 \(C\) 点。小明的拜年路线会从 \(A\) 点出发,然后判断 \(B\) 和 \(C\) 点哪个点距离 \(A\) 点更近,小明就先去哪个点拜年。假如 \(A\) 到 \(B\) 的距离比 \(A\) 到 \(C\) 的距离更近,那么小明就先从 \(A\) 到 \(B\) 再从 \(B\) 到 \(C\);假如 \(A\) 到 \(C\) 的距离比 \(A\) 到 \(B\) 的距离更近,那么小明就先从 \(A\) 到 \(C\) 再从 \(C\) 到 \(B\)。已知 \(A,B,C\) 所在地区的道路形成一棵树的形状,共 \(n\) 个点,\(n-1\) 条道路,所有点都能通过道路连通。但是 \(A,B,C\) 具体在哪个点的位置,我们并不知道,请问小明从 \(A\) 点开始拜年两处长辈所走的路径总长度最多可能是多少?
思路1:枚举转折点,预计算前 \(3\) 长路径
思路2:计算直径 \(BC\) \(+\) 枚举 \(A\)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
ll n,d[200009],tI[200009],tO[200009],p[200009][109],timer,L;
vector<ll> to[200009];
vector<ll> dis[200009];
void dfs(ll x,ll dep){
d[x]=dep;
for(ll i=0;i<to[x].size();i++){
ll y=to[x][i];
if(d[y]) continue;
dfs(y,dep+dis[x][i]);
}
return ;
}
void init(ll u,ll fa){
tI[u]=++timer;
p[u][0]=fa;
for(ll i=1;i<=L;i++) p[u][i]=p[p[u][i-1]][i-1];
for(ll i=0;i<to[u].size();i++){
if(to[u][i]!=fa) init(to[u][i],u);
}
tO[u]=++timer;
}
bool up(ll u,ll v){
return !u||tI[u]<=tI[v]&&tO[v]<=tO[u];
}
ll lca(ll u,ll v){
if(u==v) return u;
if(up(u,v)) return u;
if(up(v,u)) return v;
for(ll i=L;i>=0;i--){
if(!up(p[u][i],v)) u=p[u][i];
}
return p[u][0];
}
ll dst(ll x,ll y){
return d[x]+d[y]-2*d[lca(x,y)];
}
int main(){
scanf("%lld",&n);
for(ll i=1;i<=n-1;i++){
ll u,v,w;
scanf("%lld%lld%lld",&u,&v,&w);
to[u].push_back(v);
to[v].push_back(u);
dis[u].push_back(w);
dis[v].push_back(w);
}
dfs(1,1);
ll B=max_element(d+1,d+1+n)-d;
memset(d,0,sizeof(d));
dfs(B,1);
ll C=max_element(d+1,d+1+n)-d;
ll ans=0;
memset(d,0,sizeof(d));
dfs(1,1);
L=log(n)/log(2)+1;
init(1,0);
for(ll A=1;A<=n;A++){
ll AB=dst(A,B);
ll AC=dst(A,C);
ll cost=min(AB,AC)+dst(B,C);
ans=max(ans,cost);
}
printf("%lld",ans);
return 0;
}