BZOJ 4545 DQS 的 trie 题解

laijinyi / 2024-09-23 / 原文

Statement

维护一棵树,边权 \(\in\{\texttt a,\texttt b,\texttt c\}\),根为 \(1\),定义这棵树的子串为从 \(1\) 走到所有点构成的字符串的所有后缀,需要支持以下操作:

  • 问当前树的本质不同子串数
  • 给一个点添加一棵子树
  • 问一个串在当前树中作为子串的出现次数

Solution

直接在线构建广义 SAM,并用 LCT 维护 parent 树上的子树和(parent 树会因克隆节点而发生 link / cut)


考虑如何减小常数。

离线,先把最后的 trie 对应的广义 SAM、parent 树建出来

操作 2 相当于把新加入的树点所对应的 SAM 状态由未出现变成出现,即在 parent 树的 dfs 序上单点加 \(1\)

操作 1 相当于对所有出现过的状态及其在 parent 树上的全部祖先(每个状态只算一次)求 \(\sum \text{len}(u)-\text{len}(\text{link}(u))\)

操作 3 先在 SAM 上走出询问串对应的状态 \(u\)(走不通则答案为 \(0\)),答案即为 \(u\) 在 parent 树上的子树和

于是在 parent 树的 dfs 序上用树状数组做单点加、区间和即可;对于操作 1,每激活一个状态就沿 \(\text{link}\) 向上,把尚未计入的状态的 \(\text{len}(u)-\text{len}(\text{link}(u))\) 累加进答案,每个状态至多被计入一次

Code 1

广义 SAM + LCT

#include <bits/stdc++.h>
using namespace std;
// Inclusive loop helpers: rep counts up from j to k, reo counts down from j
// to k (reo is not used in this program).
#define rep(i, j, k) for (int i = (j); i <= (k); ++i)
#define reo(i, j, k) for (int i = (j); i >= (k); --i)
typedef long long ll;
// Shared capacity of the LCT, SAM, adjacency-list and per-vertex arrays.
// NOTE(review): a SAM can have close to twice as many states as inserted
// strings; confirm this bound against the problem's limits.
const int N = 4e5 + 10;

namespace LCT {
	// Rooted link-cut tree (no make_root, hence no reversal tags) that
	// maintains subtree sums through "virtual" children. In this file it
	// mirrors the SAM parent tree, with SAM state s stored at node s + 1.
	//   fa[u]   : splay parent, or the path-parent when u is a splay root
	//   ch[u]   : left / right splay children (shallower / deeper on the path)
	//   val[u]  : weight stored on u itself
	//   MyXu[u] : total weight reachable through u's virtual children
	//   All[u]  : weight of u's splay subtree including all virtual subtrees
	// Node 0 is the null node; All[0] must stay 0.
	int fa[N], ch[N][2], All[N], MyXu[N], val[N];
	// get(u): 1 iff u is the right child of its splay parent.
	#define get(u) ((u) == ch[fa[(u)]][1])
	// nrt(u): u is not the root of its splay tree.
	#define nrt(u) ((u) == ch[fa[(u)]][0] || (u) == ch[fa[(u)]][1])
	// Recompute All[u] from its splay children, virtual total and own weight.
	void up(int u) {
		All[u] = All[ch[u][0]] + All[ch[u][1]] + MyXu[u] + val[u];
	}
	// Single splay rotation of u over its parent f (g is the grandparent).
	void rot(int u) {
		int f = fa[u], g = fa[f], k = get(u);
		if (nrt(f)) ch[g][get(f)] = u;
		ch[f][k] = ch[u][!k];
		if (ch[u][!k]) fa[ch[u][!k]] = f;
		ch[u][!k] = f, fa[f] = u, fa[u] = g, up(f), up(u);
	}
	// Bring u to the root of its splay tree. No pushdown is needed because
	// this structure keeps no lazy tags.
	void splay(int u) {
		for (; nrt(u); rot(u)) if (nrt(fa[u])) rot(get(u) == get(fa[u]) ? fa[u] : u);
	}
	// Make the root-to-u path preferred, ending at u. At every step the old
	// preferred child's weight moves into MyXu and the new one leaves it.
	void access(int u) {
		for (int v = 0; u; v = u, u = fa[u])
			splay(u), MyXu[u] = MyXu[u] - All[v] + All[ch[u][1]], ch[u][1] = v, up(u);
	}
	// Add v to the weight stored on node u.
	void Add(int u, int v) {
		access(u), splay(u), All[u] += v, val[u] += v;
	}
	// Attach u as a child of v. u must be the root of its own tree and of its
	// auxiliary tree (fa[u] == 0), e.g. a fresh node or one that was just cut.
	void link(int u, int v) {
		// After access + splay, v is the root of the top auxiliary tree. Keep
		// All[v] consistent with the new virtual weight right away instead of
		// leaving it stale until a later rotation / up() recomputes it.
		access(v), splay(v), fa[u] = v, MyXu[v] += All[u], All[v] += All[u];
	}
	// Detach u from its parent v (v must be u's parent). After access(u) and
	// splay(v), u is the only node in v's right splay subtree.
	void cut(int u, int v) {
		access(u), splay(v), All[v] -= All[u], ch[v][1] = fa[u] = 0;
	}
	// Total weight of the subtree rooted at u (u included): after access(u)
	// nothing deeper is on u's preferred path, so it is all virtual + val[u].
	int qry(int u) {
		return access(u), splay(u), MyXu[u] + val[u];
	}
	#undef get
	#undef nrt
}

namespace SAM {
	// Online generalized suffix automaton over letters 0..2 ('a'..'c').
	// State s is mirrored by LCT node s + 1 (LCT node 0 is null), and every
	// suffix-link change is replayed in the LCT via LCT::link / LCT::cut.
	//   sum: running total of len[cur] - len[link[cur]] over the new
	//        (non-clone) states created so far; printed as the answer to the
	//        distinct-substring query.
	int sz, cur, len[N], link[N], nxt[N][3];
	ll sum;
	// State 0 is the root; link -1 terminates every suffix-link walk.
	void init() {
		link[0] = -1;
	}
	// Inserts (string of state `last`) + letter `ch` and returns the state
	// that holds it. Every path also adds 1 to that state's LCT weight, i.e.
	// counts one more tree vertex whose root path ends with this string.
	int extend(int ch, int last) {
		if (nxt[last][ch]) {
			// The extended string is already present in the automaton.
			if (len[last] + 1 == len[nxt[last][ch]]) return LCT::Add(nxt[last][ch] + 1, 1), nxt[last][ch];
			else {
				// It shares a state with longer strings: split off a clone of
				// length len[last] + 1 and use that. sum is not touched here,
				// because splitting a state keeps the total of
				// len - len(link) unchanged.
				int p = last, q = nxt[last][ch], copy = ++sz;
				// Re-hang q in the LCT: detach it from link[q], put the clone
				// in its place, then hang q under the clone. link[q] must still
				// hold its old value here.
				LCT::cut(q + 1, link[q] + 1), LCT::link(copy + 1, link[q] + 1), LCT::link(q + 1, copy + 1);
				link[copy] = link[q], link[q] = copy, len[copy] = len[p] + 1;
				rep(i, 0, 2) nxt[copy][i] = nxt[q][i];
				// Redirect the ch-transitions into q along the suffix links.
				for (; ~p; p = link[p])
					if (nxt[p][ch] == q) nxt[p][ch] = copy;
					else break;
				LCT::Add(copy + 1, 1);
				return copy;
			}
		}
		// Regular SAM extension with a brand-new state cur.
		cur = ++sz, len[cur] = len[last] + 1;
		int p = last;
		for (; ~p; p = link[p])
			if (!nxt[p][ch]) nxt[p][ch] = cur;
			else break;
		if (!~p) {
			// Fell off the root: cur's suffix link is the root (LCT node 1).
			link[cur] = 0, LCT::link(cur + 1, 1);
		} else {
			int q = nxt[p][ch];
			if (len[p] + 1 == len[q]) {
				link[cur] = q, LCT::link(cur + 1, q + 1);
			} else {
				// Clone q; in the LCT the clone replaces q under link[q], and
				// both q and cur become its children. The cut must run before
				// link[q] is overwritten below.
				int copy = ++sz;
				LCT::cut(q + 1, link[q] + 1);
				LCT::link(copy + 1, link[q] + 1);
				LCT::link(q + 1, copy + 1);
				LCT::link(cur + 1, copy + 1);
				link[copy] = link[q], link[q] = copy, link[cur] = copy, len[copy] = len[p] + 1;
				rep(i, 0, 2) nxt[copy][i] = nxt[q][i];
				for (; ~p; p = link[p])
					if (nxt[p][ch] == q) nxt[p][ch] = copy;
					else break;
			}
		}
		sum += len[cur] - len[link[cur]];
		LCT::Add(cur + 1, 1);
		return cur;
	}
}

// Adjacency list ("chained forward star") of the input tree. Each undirected
// edge is stored twice, hence the N << 1 capacity; h[u] is u's first edge.
struct Edge {
	int t, n, w;  // target vertex, next edge in the chain, letter (c - 'a' at call sites)
} E[N << 1];
int tot, h[N];
// Append the directed edge u -> v labelled w to u's list.
void Add(int u, int v, int w) {
	// Standard C++ aggregate initialization; the former (Edge){...} form is a
	// C99 compound literal that C++ compilers only accept as an extension.
	E[++tot] = Edge{v, h[u], w}, h[u] = tot;
}
int pos[N], vis[N];

void BFS(int rt) {
	queue<int> q;
	q.push(rt);
	while (!q.empty()) {
		int u = q.front();
		q.pop();
		vis[u] = 1;
		int p = pos[u];
		for (int i = h[u]; i; i = E[i].n) {
			int v = E[i].t, w = E[i].w;
			if (!vis[v]) {
				pos[v] = SAM::extend(w, p), q.push(v);
			}
		}
	}
}

int n0, m;

int main() {
	ios::sync_with_stdio(false), cin.tie(nullptr);
	int id;
	cin >> id >> n0;
	SAM::init();
	rep(i, 1, n0 - 1) {
		int u, v;
		char c;
		cin >> u >> v >> c, Add(u, v, c - 'a'), Add(v, u, c - 'a');
	}
	BFS(1);
	cin >> m;
	while (m--) {
		int opt;
		cin >> opt;
		if (opt == 1) {
			cout << SAM::sum << '\n';
		}
		if (opt == 2) {
			int rt, s;
			cin >> rt >> s;
			rep(i, 1, s - 1) {
				int u, v;
				char c;
				cin >> u >> v >> c, Add(u, v, c - 'a'), Add(v, u, c - 'a');
			}
			BFS(rt);
		}
		if (opt == 3) {
			string s;
			cin >> s;
			int len = s.length(), p = 0, ok = 1;
			rep(i, 0, len - 1) {
				int now = s[i] - 'a';
				if (SAM::nxt[p][now]) {
					p = SAM::nxt[p][now];
				} else {
					ok = 0;
					break;
				}
			}
			if (ok) {
				cout << LCT::qry(p + 1) << '\n';
			} else {
				cout << "0\n";
			}
		}
	}
	return 0;
}

Code 2

离线 + 树状数组

#include <bits/stdc++.h>
using namespace std;
// Inclusive loop helpers: rep counts up from j to k, reo counts down from j
// to k (reo is not used in this program).
#define rep(i, j, k) for (int i = (j); i <= (k); ++i)
#define reo(i, j, k) for (int i = (j); i >= (k); --i)
typedef long long ll;
// Shared capacity of the SAM, parent-tree, Fenwick, adjacency-list and
// per-operation arrays. NOTE(review): a SAM can have close to twice as many
// states as inserted strings; confirm this bound against the problem limits.
const int N = 4e5 + 10;

// Adjacency list ("chained forward star") of the final tree. Each undirected
// edge is stored twice, hence the N << 1 capacity; h[u] is u's first edge.
struct Edge {
	int t, n, w;  // target vertex, next edge in the chain, letter (c - 'a' at call sites)
} e[N << 1];
int tot, h[N];
// Append the directed edge u -> v labelled w to u's list.
void Add(int u, int v, int w) {
	// Standard C++ aggregate initialization; the former (Edge){...} form is a
	// C99 compound literal that C++ compilers only accept as an extension.
	e[++tot] = Edge{v, h[u], w}, h[u] = tot;
}
// Size of the initial tree and number of operations.
int n0, m;

// One buffered operation: the whole input is read before any query is answered.
struct Oper {
	int opt, rt, s;  // operation type; for type 2: vertex rt and size s (s - 1 edges follow)
	string S;        // for type 3: the query pattern
} Ops[N];

// Offline suffix automaton over letters 0..2: sz = last used state,
// len[s] = longest string length, link[s] = suffix link, nxt[s][c] = transition.
int sz, cur, len[N], link[N], nxt[N][3];
// State 0 is the root; a link of -1 ends every suffix-link walk.
void init() {
	link[0] = -1;
}
// Appends letter ch to the string of state `last` and returns the new state
// (also left in the global cur). A clone is created when the target state
// would otherwise mix strings of different end-position sets.
int extend(int ch, int last) {
	cur = ++sz;
	len[cur] = len[last] + 1;
	int p = last;
	while (p != -1 && !nxt[p][ch]) {
		nxt[p][ch] = cur;
		p = link[p];
	}
	if (p == -1) {
		link[cur] = 0;
		return cur;
	}
	const int q = nxt[p][ch];
	if (len[q] == len[p] + 1) {
		link[cur] = q;
		return cur;
	}
	const int clone = ++sz;
	link[clone] = link[q];
	link[q] = clone;
	link[cur] = clone;
	len[clone] = len[p] + 1;
	for (int c = 0; c < 3; ++c) nxt[clone][c] = nxt[q][c];
	while (p != -1 && nxt[p][ch] == q) {
		nxt[p][ch] = clone;
		p = link[p];
	}
	return cur;
}

int pos[N];

void BFS(int be) {
	queue<int> q;
	q.push(be);
	while (!q.empty()) {
		int u = q.front(), p = pos[u];
		q.pop();
		for (int i = h[u]; i; i = e[i].n) {
			int v = e[i].t, w = e[i].w;
			if (!~pos[v]) {
				if (!nxt[p][w]) {
					pos[v] = extend(w, p);
				} else {
					pos[v] = nxt[p][w];
				}
				q.push(v);
			}
		}
	}
}

vector<int> G[N];
int tim, siz[N], dfn[N];
void DFS(int u) {
	dfn[u] = ++tim, siz[u] = 1;
	for (int v : G[u]) DFS(v), siz[u] += siz[v];
}
void BuildLinkTree() {
	rep(i, 1, sz) G[link[i]].push_back(i);
	DFS(0);
}

// Fenwick tree over dfn positions 1..tim: point add, prefix and range sums.
struct BIT {
	ll sum[N];
	// Reset positions 1..tim to zero.
	void init() {
		for (int i = 1; i <= tim; ++i) sum[i] = 0;
	}
	// Add v at position x.
	void upd(int x, int v) {
		while (x <= tim) {
			sum[x] += v;
			x += x & -x;
		}
	}
	// Sum over positions 1..x.
	ll qry(int x) {
		ll acc = 0;
		while (x) {
			acc += sum[x];
			x &= x - 1;  // drop the lowest set bit
		}
		return acc;
	}
	// Sum over positions l..r.
	ll qry(int l, int r) {
		return qry(r) - qry(l - 1);
	}
} bit;
int vis[N];
ll sum;

int main() {
	ios::sync_with_stdio(false), cin.tie(nullptr);
	int id;
	cin >> id >> n0;
	rep(i, 1, n0 - 1) {
		int u, v;
		char c;
		cin >> u >> v >> c, Add(u, v, c - 'a'), Add(v, u, c - 'a');
	}
	cin >> m;
	rep(i, 1, m) {
		cin >> Ops[i].opt;
		if (Ops[i].opt == 2) {
			cin >> Ops[i].rt >> Ops[i].s;
			rep(p, 1, Ops[i].s - 1) {
				int u, v;
				char c;
				cin >> u >> v >> c, Add(u, v, c - 'a'), Add(v, u, c - 'a');
			}
		}
		if (Ops[i].opt == 3) cin >> Ops[i].S;
	}
	memset(pos, -1, sizeof(pos));
	init(), pos[1] = 0, BFS(1), BuildLinkTree(), bit.init();
	int cnt = n0;
	auto update = [&](int x) {
		x = pos[x];
		bit.upd(dfn[x], 1);
		while (x && !vis[x]) sum += len[x] - len[link[x]], vis[x] = 1, x = link[x];
	};
	rep(i, 1, cnt) update(i);
	rep(i, 1, m) {
		int opt = Ops[i].opt;
		if (opt == 1) {
			cout << sum << '\n';
		}
		if (opt == 2) {
			rep(p, cnt + 1, cnt + Ops[i].s - 1) update(p);
			cnt += Ops[i].s - 1;
		}
		if (opt == 3) {
			string S = Ops[i].S;
			int u = 0, Slen = S.length(), ok = 1;
			rep(p, 0, Slen - 1) {
				if (nxt[u][S[p] - 'a']) {
					u = nxt[u][S[p] - 'a'];
				} else {
					ok = 0;
					break;
				}
			}
			if (ok) {
				cout << bit.qry(dfn[u], dfn[u] + siz[u] - 1) << '\n';
			} else {
				cout << "0\n";
			}
		}
	}
	return 0;
}