Codeforces 1988D(思维 + 树形DP)

佚名 / 2024-09-18 / 原文

题目链接

题意

有一棵包含 \(n\) 个结点的树。编号为 \(i(1 \le i \le n)\) 的结点上有一个攻击力为 \(a_i\) 的怪物。你要跟怪物对战若干回合,将它们全部杀死。
每一回合,所有存活着的怪物会对你进行一次攻击,假设它们的编号分别为 \(j_1、j_2、\cdot\cdot\cdot、j_m\),则损失的生命值为 \(a_{j_1} + a_{j_2} + \cdot\cdot\cdot + a_{j_m}\)
然后,你选择若干结点 \((u_1、u_2、u_3、\cdot\cdot\cdot、u_m)\),满足:\(u_1、u_2、u_3、\cdot\cdot\cdot、u_m\) 两两不在同一条边上,将 \(u_1、u_2、u_3、\cdot\cdot\cdot、u_m\) 位置上的怪物杀死。
若选择最优的方案,在所有怪物被杀死之后,你最少损失多少生命值。

数据范围:\(1 \le n \le 3 \times 10^5\)\(1 \le a_i \le 10^{12}\)

题解

首先,\(2\) 个回合一定可以把所有怪物打死,但是,显然这种做法不总是最优的。

考虑 树形DP,设 \(dp[U][i]\) 表示将以 \(U\) 为根的子树上的怪物全部杀死并且第 \(i\) 回合杀死结点 \(U\) 位置上的怪物 之后,最少损失多少生命值。

结论 :最优策略下的总回合数一定小于等于 $ log n + 1$

证明:对于任意一个结点 $ U $,假设它的邻点为 \(V_1、V_2、V_3、\cdot\cdot\cdot、V_m\),击杀它们的回合数分别为:\(t_1、t_2、t_3、\cdot\cdot\cdot、t_m\),那么在回合 \(MEX(t_1、t_2、\cdot\cdot\cdot、t_m)\) 击杀编号为 \(U\) 的结点上的怪物显然是最优的。
假设存在某一个怪兽,我们在第 \(i\) 回合击杀它是最优的,那么树上结点的个数至少为 \(f_i\)
有:\(f_1 = 1\)\(\forall i \ge 2, f_i = 1 + \sum\limits_{j = 1}^{i - 1} f_j\)
$\rightarrow f_i = 1 + f_{i - 1} + f_{i - 1} $

$\rightarrow f_i = 1 + 4 \times f_{i - 2} $

$\rightarrow f_i = 1 + 2^{i - 1} $

因为需要满足 \(f_i = 1 + 2^{i-1} \le n\),所以,\(j \le log(n - 1) + 1\),证毕。


DP 的转移方程:\(dp[U][i] = i \times a_U + \sum\limits_{V \in son(U)} \min\limits_{j \in [1, logn + 1] \&\& j \ne i} dp[V][j]\)

答案即:$\min\limits_{j = 1}^{logn + 1} dp[1][j] $。

时间复杂度为 \(\mathcal{O}(nlog^2n)\)

点击查看代码
#include <bits/stdc++.h>

using i64 = long long;

constexpr i64 inf = 1E18;

void solve() {
  int n;
  std::cin >> n;
  
  std::vector<i64> a(n + 1);
  for (int i = 1; i <= n; i++) {
    std::cin >> a[i];
  }

  std::vector<std::vector<int>> adj(n + 1);
  for (int i = 1; i < n; i++) {
    int u, v;
    std::cin >> u >> v;

    adj[u].push_back(v);
    adj[v].push_back(u);
  }

  const int M = std::log2(n) + 1;
  std::vector<std::vector<i64>> DP(n + 1, std::vector<i64>(M + 1));
  auto dfs = [&](auto self, int u, int p) -> void {
    for (int i = 1; i <= M; i++) {
      DP[u][i] = a[u] * i;
    }

    for (auto v : adj[u]) {
      if (v == p) {
        continue;
      }
      self(self, v, u);

      for (int i = 1; i <= M; i++) {
        i64 min = inf;
        for (int j = 1; j <= M; j++) {
          if (j != i) {
            min = std::min(min, DP[v][j]);
          }
        }
        DP[u][i] += min;
      }
    }
  };
  dfs(dfs, 1, 0);

  i64 ans = inf;
  for (int i = 1; i <= M; i++) {
    ans = std::min(ans, DP[1][i]);
  }
  std::cout << ans << "\n";
}

int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);

  int T;
  std::cin >> T;

  while (T--) {
    solve();
  }
  return 0;
}

Bonus

在更新 \(dp[U][i]\) 之前,可以记录 \(dp[V][j]\) 的前缀、后缀最小值,这样可以实现 \(\mathcal{O}(logn)\) 的更新。总的时间复杂度可以优化到 \(\mathcal{O}(nlogn)\)

点击查看代码
#include <bits/stdc++.h>

using i64 = long long;

constexpr i64 inf = 1E18;

void solve() {
  int n;
  std::cin >> n;
  
  std::vector<i64> a(n + 1);
  for (int i = 1; i <= n; i++) {
    std::cin >> a[i];
  }

  std::vector<std::vector<int>> adj(n + 1);
  for (int i = 1; i < n; i++) {
    int u, v;
    std::cin >> u >> v;

    adj[u].push_back(v);
    adj[v].push_back(u);
  }

  const int M = std::log2(n) + 1;
  std::vector<std::vector<i64>> DP(n + 1, std::vector<i64>(M + 1));
  std::vector<i64> pre(M + 1), suf(M + 1);
  auto dfs = [&](auto self, int u, int p) -> void {
    for (int i = 1; i <= M; i++) {
      DP[u][i] = a[u] * i;
    }

    for (auto v : adj[u]) {
      if (v == p) {
        continue;
      }
      self(self, v, u);

      pre[1] = DP[v][1], suf[M] = DP[v][M];
      for (int i = 2; i <= M; i++) {
        pre[i] = std::min(pre[i - 1], DP[v][i]);
      }
      for (int i = M - 1; i >= 1; i--) {
        suf[i] = std::min(suf[i + 1], DP[v][i]);
      }
      for (int i = 1; i <= M; i++) {
        i64 min = inf;
        if (i - 1 >= 1) {
          min = std::min(min, pre[i - 1]);
        }
        if (i + 1 <= M) {
          min = std::min(min, suf[i + 1]);
        }
        DP[u][i] += min;
      }

    }
  };
  dfs(dfs, 1, 0);

  i64 ans = inf;
  for (int i = 1; i <= M; i++) {
    ans = std::min(ans, DP[1][i]);
  }
  std::cout << ans << "\n";
}

int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);

  int T;
  std::cin >> T;

  while (T--) {
    solve();
  }
  return 0;
}