题解:AT_agc027_b [AGC027B] Garbage Collector

naughty-naught / 2024-10-15 / 原文

Problem Link

[AGC027B] Garbage Collector

题意

原题翻译已经很不错了,这里不再赘述。

思路

推论:每次取的垃圾数量应尽可能均分。

证明

如图,假设有 \(4\) 个垃圾需要被捡起,有两种取法:

  • 取一号垃圾+取二三四号垃圾。

  • 取一二号垃圾+取二三号垃圾。

前者所需能量为:\(\displaystyle 20x_1+16x_2+9x_3+4x_4\)

后者所需能量为:\(\displaystyle 18x_1+13x_2+9x_3+4x_4\)

由于保证所有 \(x_i\) 均大于 \(0\),所以后者明显更优。然后使用归纳法很容易证明推论是正确的。

之后思路就很清晰了,每次从 \(1\)\(n\) 枚举 \(k\),计算每段垃圾所需能量之和。

但是如何计算呢?

暴力计算式子很裸,由于我们一定是走到最后一个要取的垃圾再开始取垃圾,我们将 \(x\) 数组 reverse 一下,那么所需能量即为:

\[ x_1+\sum_{i=1}^{n} ((i+1)^2+p_i-p_{i+1}) \]

其实这个是可以化简的:

\[ x_1+\sum_{i=1}^{n} ((i+1)^2+p_i-p_{i+1}) = 5x_1+\sum_{i=2}^{n}(2i+3) \times p_i \]

这样,对于每个 \(k\),前缀和处理价值即可,每个 \(k\) 的复杂度是 \(\displaystyle O(\frac{n}{k})\),故总复杂度为 \(\displaystyle O(n \ln(n))\)

其他疑问见代码。

代码

// written by Naught

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef __int128 it;
// #define int long long
#define Maxn 400005
#define p2(x) (1ll*(x)*(x))
#define fo(i, l, r) for (int i = l; i <= r; ++i)
#define fr(i, r, l) for (int i = l; i >= r; --i)
// #define getchar()(p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
// char buf[1<<21], *p1 = buf, *p2 = buf;
// inline int read(int x=0, bool f=0, char c=getchar()) {for(;!isdigit(c);c=getchar()) f^=!(c^45);for(;isdigit(c);c=getchar()) x=(x<<1)+(x<<3)+(c^48);return f?-x:x;}
// inline ll lread(ll x=0, bool f=0, char c=getchar()) {for(;!isdigit(c);c=getchar()) f^=!(c^45);for(;isdigit(c);c=getchar()) x=(x<<1)+(x<<3)+(c^48);return f?-x:x;}
void train() {ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);}

int n, d[Maxn];
it ans = 2e24, x, s[Maxn];

void ct(it x)
{
    string anss;
    while(x)
    {
        anss = char(x%10+'0') + anss;
        x /= 10;
    }
    cout << anss << '\n';
}

int main()
{
    train();
    ll nn, xx;
    // n = read(), x = read();
    cin >> nn >> xx;
    n = nn, x = xx;
    fo(i, 1, n) cin >> d[i];
    reverse(d+1, d+n+1);
    fo(i, 1, n) s[i] = s[i-1]+d[i];
    fo(k, 1, n)
    {
        it res = k*x+5*s[k];
        fo(i, 2, (n+k-1)/k) res += (2*i+1)*(s[min(i*k, n)]-s[(i-1)*k]);
        ans = min(ans, res);
    }
    // cout << ans + x*n;
    ans += x*n;
    ct(ans);
    return 0;
}
/*
*/

Tips

因为会爆 long long,所以建议使用 __int128