k - 路径(mid vension)

onlyblues / 2024-10-18 / 原文

k - 路径(mid vension)

题目描述

这是问题的中等版本。三个版本的做法之间可能有交集。

给定一颗 $n$ 个点的无根树,每个节点有类型和权值,分别用 $c$ 和 $w$ 表示。一条简单路径的权值为这条路径包含的节点权值之和。

对于一个长度为 $len+1(len>1)$ 的序列 $a$,假如 $a_1 ​ =a_{len+1} = len$,并且 $a_2, a_3, \ldots, a_{len}$ 是一个长度为 $len−1$ 的排列,那么称序列 $a$ 是 $len$ - 特殊排列。

假如从 $u$ 走到 $v$ 长度为 $k+1$ 的简单路径依次经过的节点类型组成的序列是一个 $k$ - 特殊排列,那么这个简单路径是一条 $k$ - 路径。

特别的,假如一条简单路径只由两个类型为 $1$ 的节点组成,那么这条路径是 $1$- 路径。

对于 $k=1,2,3, \ldots, n$,输出一个整数,表示所有经过点 $x$ 的 $k$ - 路径的最大权值,假如不存在,输出 $−1$。

提示:本题输入输出数据量较大,建议选手使用快速的输入输出方式。

输入描述:

第一行输入一个整数 $T(1 \leq T \leq 10^4)$,表示测试数据组数。接下来是 $T$ 个测试用例。

每个测试用例第一行输入两个整数 $n,x(2 \leq n \leq 10^6 ,1 \leq x \leq n)$。

第二行有 $n$ 个数 $c_i(1 \leq c_i \leq n)$,表示每个节点的类型。

第三行有 $n$ 个数 $w_i(−10^9 \leq w_i ​\leq 10^9)$,表示每个节点的权值。

然后 $n−1$ 行,每一行有两个整数 $u,v(1 \leq u,v \leq n)$,表示有一条 $u$ 到 $v$ 的无向边。

保证所有测试用例 $n$ 的和不超过 $10^6$。

输出描述:

对于每个测试用例,输出一行,包含 $n$ 个空格分隔的整数,表示答案。

示例1

输入

1
8 1
1 2 3 4 2 2 4 3
7 -8 -5 4 10 1 -7 -9
1 2
2 3
3 4
1 7
1 5
1 6
1 8

输出

-1 18 -15 -9 -1 -1 -1 -1

说明

在第一个测试用例中,简单路径 $(5,6)$ 的权值为 $18$,并且这条简单路径的节点类型组成的序列为 $[2,1,2]$ 是一个 $2$ - 特殊排列,所以这条简单路径是一个 $2$ - 路径,并且显然没有其它任何一条 $2$ - 路径的权值比这条路径的权值大,所以对于 $k=2$,输出 $18$。

示例2

输入

4
10 1
1 2 3 4 5 2 3 4 1 5
-5 9 7 8 3 -7 -10 -7 6 -1
1 2
2 3
3 4
4 5
10 9
1 6
1 7
1 10
1 8
10 1
1 2 3 4 5 5 3 2 4 2
-2 6 -1 4 10 3 4 -1 3 5
1 2
2 3
3 4
4 5
1 8
10 7
1 9
9 6
1 10
10 1
1 2 3 4 5 4 5 5 5 2
-1 -6 -3 7 8 8 5 9 -9 -5
1 2
2 3
3 4
4 5
1 10
7 9
1 6
1 9
1 8
10 1
1 2 3 4 5 5 2 5 5 3
-9 -9 6 -3 3 -9 9 -1 -3 -6
1 2
2 3
3 4
4 5
1 8
1 7
9 10
1 6
1 10

输出

-1 -3 1 12 21 -1 -1 -1 -1 -1
-1 9 -1 10 -1 -1 -1 -1 -1 -1
-1 -12 -1 5 14 -1 -1 -1 -1 -1
-1 -9 -18 -1 -13 -1 -1 -1 -1 -1

 

解题思路

  前不久在 H. Robin Hood Archery 这题学到了异或哈希,没想到很快就用上了,不过可惜的是赛时压根就没看这题,不然还真有可能做出来(虽然补题时因为一个很睿智的错误 debug 了几小时)。

  因为选出的路径要包含节点 $x$,因此容易想到以 $x$ 为根去 dfs 找出链来构成这样的路径(补题的时候习惯性把 $1$ 写成根结果 debug 了半天都找不出这个问题,样例也很坑就是了)。那么包含节点 $x$ 的路径有两种,一种是直接以 $x$ 为端点的一条链,另一种是两条以 $x$ 为端点的链构成的路径。实际上,这些都可以视为同一种路径。在处理时,我们统一以两条链来构成路径。那么一条合法的路径首先要满足 $c_{v_1} = c_{v_2} = c_x$,其中 $v_1$ 和 $v_2$ 是路径的两个端点,同时路径的长度要恰好为 $c_x +1$。并且除了两个端点外,路径上的其他点的 $c_i$ 要构成 $1 \sim c_x - 1$ 的排列。

  如果我们把路径上的点看作成序列,要快速判断一个序列是不是排列,可以用开头提到的异或哈希来实现。具体来说,我们把 $1 \sim n$ 分别映射成一个随机的 64 位无符号整数,即 $i \to f(i), \, (i = 1, \ldots, n)$,这么做是为了降低哈希冲突。定义 $g(i) = f(1) \oplus f(2) \oplus \cdots \oplus f(i)$,因此如果一个序列 $[a_1, \ldots, a_n]$ 是 $1 \sim n$ 的排列,那么应该有 $f(a_1) \oplus \cdots \oplus f(a_n) = g(n)$(这里利用了异或运算 $x \oplus y = y \oplus x$ 这一性质)。

  我们先以 $x$ 为根 dfs 求出以下信息:

  • $h_v$:从 $x$ 到 $v$ 的路径上所有点 $i$ 的 $f(i)$ 的异或和。
  • $s_v$:从 $x$ 到 $v$ 的路径上所有点 $i$ 的 $w_i$ 的和。
  • $d_v$:从 $x$ 到 $v$ 的路径上所有点数量(路径长度)。
  • $q_k$:一个集合,存储所有 $c_v = k$ 的点的编号 $v$。

  因此我们可以在每个 $q_k$ 中求出经过节点 $x$ 的最大 $k$-路径。枚举 $q_k$ 中的点 $v$,如果 $q_k$ 中另外一个点 $u$ 能与 $v$ 构成合法路径的端点,那么就要同时满足以下两个条件:

$$\begin{cases}
h_u \oplus h_v = g(k-1) \oplus f(c_x) \\
d_u + d_v - 1 = k+1
\end{cases}
\Rightarrow
\begin{cases}
h_u = g(k-1) \oplus f(c_x) \oplus h_v \\
d_u = k+2 - d_v
\end{cases}$$

  为了快速查询满足上述条件的 $u$ 对应的最大 $s_u$,我们可以开一个 std::map<array<ULL>, LL> mp,枚举完一个点后把二元组 $(h_v, d_v)$ 作为键,$s_v$ 作为值记录其中。当存在合法的 $u$ 时,路径的最大权值就是 mp[{g[k - 1] ^ f[c[x]] ^ h[v], k + 2 - d[v]}] + s[v] - w[x]

  k - 路径(hard vension) 没学过点分治不会,之后学了再来补吧()

  AC 代码如下,时间复杂度为 $O(n \log{n})$:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;
typedef unsigned long long ULL;

mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());

const int N = 1e6 + 5, M = N * 2;

int c[N], w[N];
int h[N], e[M], ne[M], idx;
ULL f[N], g[N], hs[N], d[N];
LL s[N];
vector<int> q[N];

void add(int u, int v) {
    e[idx] = v, ne[idx] = h[u], h[u] = idx++;
}

void dfs(int u, int p) {
    hs[u] = hs[p] ^ f[c[u]];
    s[u] = s[p] + w[u];
    d[u] = d[p] + 1;
    q[c[u]].push_back(u);
    for (int i = h[u]; i != -1; i = ne[i]) {
        int v = e[i];
        if (v == p) continue;
        dfs(v, u);
    }
}

void solve() {
    int n, x;
    cin >> n >> x;
    for (int i = 1; i <= n; i++) {
        cin >> c[i];
    }
    for (int i = 1; i <= n; i++) {
        cin >> w[i];
    }
    idx = 0;
    memset(h, -1, n + 1 << 2);
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        cin >> u >> v;
        add(u, v), add(v, u);
    }
    for (int i = 1; i <= n; i++) {
        f[i] = rng();
        g[i] = g[i - 1] ^ f[i];
        q[i].clear();
    }
    dfs(x, 0);
    for (int k = 1; k <= n; k++) {
        if (q[k].empty()) {
            cout << -1 << ' ';
            continue;
        }
        LL ret = -1e18;
        map<array<ULL, 2>, LL> mp;
        for (auto &v : q[k]) {
            if (d[v] > k + 1) continue;
            ULL a = g[k - 1] ^ f[c[x]] ^ hs[v], b = k + 2 - d[v];
            if (mp.count({a, b})) ret = max(ret, s[v] + mp[{a, b}] - w[x]);
            if (mp.count({hs[v], d[v]})) mp[{hs[v], d[v]}] = max(mp[{hs[v], d[v]}], s[v]);
            else mp[{hs[v], d[v]}] = s[v];
        }
        cout << (ret == LL(-1e18) ? -1 : ret) << ' ';
    }
    cout << '\n';
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int t;
    cin >> t;
    while (t--) {
        solve();
    }
    
    return 0;
}

 

参考资料

  【比赛题目讲解】牛客小白月赛102:https://www.bilibili.com/video/BV1no21YsESX?p=7