Luogu5346 做题记录

Lgx_ / 2024-12-19 / 原文

入坑 PGF。 link

考虑 \(F(x),G(x)\) 分别表示 从起始状态到达最终状态 和 从最终状态到达最终状态 的 OGF,注意不一定是第一次到达。设 \(H(x)\) 表示从起始状态第一次到达最终状态的 OGF,那么有 \(H(x)F(x) = G(x)\)\(H(x) = \dfrac {G(x)} {F(x)}\)。答案为 \(H'(1) = \dfrac {F'(1)G(1) - F(1)G'(1)} {G^2(1)}\)

由于是有序排列组合,可以计算出两个生成函数对应的 EGF:

\[\hat F(x) = \prod_{i = 1} ^ n \dfrac {e^{p_i x} + (-1)^{s_i}e^{-p_ix}} 2 \]

\[\hat G(x) = \prod_{i = 1} ^ n \dfrac {e^{p_i x} + e^{-p_ix}} 2 \]

可以通过背包转化为 \(\hat F(x) = \sum\limits_i a_i e^{ix}, \ \hat G(x) = \sum\limits_i b_i e^{ix}\) 的形式。

进而得到 \(F(x) = \mathscr L \hat F(x) = \sum\limits_i \dfrac {a_i} {1 - ix}, \ G(x) = \mathscr L \hat G(x) = \sum\limits_i \dfrac {b_i} {1 - ix}\)

注意到当 \(x = 1\) 时,\(1 - x = 0\) 不收敛。但是整个答案的分式是收敛的,所以考虑上下同时乘以 \(1 - x\)。对于 \(F(1)\)

\[\begin{aligned} (1 - x)F(x) &= \sum\limits_i \dfrac {a_i (1 - x)} {1 - ix} \\ &= a_1 + \sum\limits_{i \not = 1} \dfrac {a_i (1 - x)} {1 - ix} \end{aligned}\]

\[(1 - x) F(1) = a_1 \]

对于 \(F'(1)\)

\[\begin{aligned} (1 - x) F'(x) &= \sum\limits_i \left( \dfrac {a_i (1 - x)} {1 - ix} \right)' \\ &= \sum\limits_{i \not = 1} \dfrac {-a_i (1 - ix) - a_i (1 - x) (-i)} {(1 - ix) ^ 2} \\ &= \sum\limits_{i \not = 1} \dfrac {(i - 1)a_i} {(1 - ix) ^ 2} \end{aligned}\]

\[\begin{aligned} (1 - x) F'(1) &= \sum\limits_{i \not = 1} \dfrac {(i - 1) a_i} {(1 - i) ^ 2} \\ &= \sum\limits_{i \not = 1} \dfrac {a_i} {i - 1} \end{aligned}\]

\(G(1), G'(1)\) 的求解同理,时间复杂度 \(\mathcal O(n^2 p)\)

点击查看代码
#include <bits/stdc++.h>
#define ll long long
#define fi first
#define se second
#define pir pair <ll, ll>
#define mkp make_pair
#define pb push_back
using namespace std;
template <class T>
void rd(T &x) {
	char c; ll f = 1;
	while(!isdigit(c = getchar()))
		if(c == '-') f = -1;
	x = c - '0';
	while(isdigit(c = getchar())) x = x * 10 + c - '0';
	x *= f;
}
const ll maxn = 2e5 + 10, mod = 998244353, inv = mod - mod / 2;
void add(ll &x, const ll y) { x = x + y >= mod? x + y - mod : x + y; }
ll n, a[maxn], p[maxn], sum, f[110][100010], g[110][100010];
ll suminv;
ll power(ll a, ll b = mod - 2) {
	ll s = 1;
	while(b) {
		if(b & 1) s = s * a %mod;
		a = a * a %mod, b >>= 1;
	} return s;
}
int main() {
	rd(n);
	for(ll i = 1; i <= n; i++) rd(a[i]);
	for(ll i = 1; i <= n; i++) rd(p[i]), sum += p[i];
	f[0][sum] = g[0][sum] = 1; suminv = power(sum);
	for(ll i = 1; i <= n; i++)
		for(ll j = 0; j <= 2 * sum; j++) {
			if(j >= p[i])
				f[i][j] += f[i - 1][j - p[i]],
				g[i][j] += g[i - 1][j - p[i]];
			if(j + p[i] <= 2 * sum)
				f[i][j] += (a[i]? mod - 1 : 1) * f[i - 1][j + p[i]] %mod,
				g[i][j] += g[i - 1][j + p[i]];
			f[i][j] = f[i][j] * inv %mod, g[i][j] = g[i][j] * inv %mod;
		}
	ll f1 = f[n][sum << 1], _f1 = 0, g1 = g[n][sum << 1], _g1 = 0;
	for(ll i = 0; i < 2 * sum; i++) {
		ll iv = (i - 2 * sum + mod) * suminv %mod;
		iv = power(iv);
		_f1 = (_f1 + f[n][i] * iv) %mod;
		_g1 = (_g1 + g[n][i] * iv) %mod;
	}
	printf("%lld", (_f1 * g1 - f1 * _g1 %mod + mod)
	 %mod * power(g1 * g1 %mod) %mod);
	return 0;
}