BZOJ 2555 = P5212 SubString 题解

laijinyi / 2024-09-23 / 原文

Statement

给你一个字符串 \(\text{init}\),要求你支持两个操作:

  1. 在当前字符串的后面插入一个字符串;
  2. 询问字符串 \(s\) 在当前字符串中出现了几次?(作为连续子串)

你必须在线支持这些操作:读入的串需要先用 mask 解码(mask 初值为 0,解码方式见代码中的 decodeWithMask),每次询问后 mask 异或上该次的答案。

Solution

extend 中令 link[cur] = q,相当于在 LCT 上执行 link(cur, q);每次 extend 末尾,新前缀的结束位置会加入 parent 树上根到 cur 路径上所有状态的 endpos 集合,相当于对这条链做链加 1。

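新建 copy 时,相当于 cut(q, link[q]),再 link(copy, link[q])、link(q, copy)、link(cur, copy);copy 与 q 的 endpos 相同,所以 copy 的初始权值要设为 q 的权值,之后同样做链加。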

询问时,沿转移边走到 s 对应的状态 u,答案就是 u 的 endpos 大小,即 u 的权值,相当于单点查询;如果中途走不下去,答案为 0。

于是用 SAM + LCT 即可解决。


再想一想有没有常数更小的方法。

换个角度:每次 extend 只在 cur 上单点加 1,询问 u 的答案就是 u 在 parent 树中的子树和,于是改用 LCT 维护子树和(虚儿子信息)。

注意到 parent 树的根始终是初始状态,Link、Cut 操作仅仅是改变某个点的父亲,不需要换根,用 access 就能实现。

于是没有链操作了,也不需要 Makeroot 和 Pushdown,常数小多了。

Code 1

有 down(下传链加标记与翻转标记),无 up(不维护子树信息)。

#include <bits/stdc++.h>
using namespace std;
// rep(v, lo, hi): ascending loop, v runs over every integer in [lo, hi].
#define rep(idx, lo, hi) for (int idx = (lo); idx <= (hi); ++idx)
// reo(v, hi, lo): descending loop, v runs from hi down to lo inclusive.
#define reo(idx, hi, lo) for (int idx = (hi); idx >= (lo); --idx)
using ll = long long;
// Capacity of every per-node array below (SAM states and LCT nodes).
constexpr int N = 8000000 + 10;

namespace LCT {
	// Link-cut tree for Code 1 (path add + point query). Node 0 is the null sentinel.
	//   val[u]: weight of node u (in this program: the endpos count of SAM state u - 1).
	//   add[u]: pending "+add[u] for every node of u's splay subtree"; already applied to val[u].
	//   rev[u]: pending subtree reversal; u's own two children are already swapped.
	int fa[N], ch[N][2], rev[N], val[N], add[N];
  #define get(u) (u == ch[fa[u]][1])
  #define nrt(u) (u == ch[fa[u]][0] || u == ch[fa[u]][1])
	// Tag node u with "+v". u may be 0 for a missing child; the sentinel's tags are never pushed further.
	void Add(int u, int v) {
		add[u] += v, val[u] += v;
	}
	// Tag node u with a reversal (swaps its children immediately).
	void swp(int u) {
		rev[u] ^= 1, std::swap(ch[u][0], ch[u][1]);
	}
	// Push u's pending tags down to its two splay children.
	void upd(int u) {
		if (add[u]) Add(ch[u][0], add[u]), Add(ch[u][1], add[u]), add[u] = 0;
		if (rev[u]) swp(ch[u][0]), swp(ch[u][1]), rev[u] = 0;
	}
	// Push pending tags on the path from the root of u's splay tree down to u, top-down.
	// Uses an explicit (reused) stack instead of recursion: a splay tree can be O(n) deep
	// in the worst case, and one call frame per level could overflow the call stack
	// when there are millions of nodes.
	void down(int u) {
		static std::vector<int> stk;
		stk.clear();
		stk.push_back(u);
		while (nrt(u)) u = fa[u], stk.push_back(u);
		while (!stk.empty()) upd(stk.back()), stk.pop_back();
	}
	// Rotate u above its splay parent (tags must already be pushed).
	void rot(int u) {
		int f = fa[u], g = fa[f], k = get(u);
		if (nrt(f)) ch[g][get(f)] = u;
		ch[f][k] = ch[u][!k];
		if (ch[u][!k]) fa[ch[u][!k]] = f;
		ch[u][!k] = f, fa[f] = u, fa[u] = g;
	}
	// Splay u to the root of its splay tree, pushing tags first.
	void splay(int u) {
		for (down(u); nrt(u); rot(u)) if (nrt(fa[u])) rot(get(u) == get(fa[u]) ? fa[u] : u);
	}
	// Turn the path from the root of u's represented tree to u into one splay tree.
	void access(int u) {
		for (int v = 0; u; v = u, u = fa[u]) splay(u), ch[u][1] = v;
	}
	// Re-root u's represented tree at u.
	void makert(int u) {
		access(u), splay(u), swp(u);
	}
	// Add edge u - v (assumes u and v are currently in different trees).
	void link(int u, int v) {
		makert(u), fa[u] = v;
	}
	// Remove edge u - v (assumes the edge exists).
	void cut(int u, int v) {
		makert(u), access(v), splay(v), ch[v][0] = fa[u] = 0;
	}
	// Expose the u..v path as one splay tree rooted at v (all tags on v pushed); returns v.
	int split(int u, int v) {
		makert(u), access(v), splay(v);
		return v;
	}
  #undef get
  #undef nrt
}

namespace SAM {
	// Suffix automaton over characters {0, 1}. State s is LCT node s + 1, so the LCT
	// forest mirrors the suffix-link tree with state 0 (LCT node 1) as its root.
	int sz, cur, last, len[N], link[N], nxt[N][2];

	// State 0 is the initial state and has no suffix link.
	void init() {
		link[0] = -1;
	}

	// Append character ch, keeping every state's LCT weight equal to its endpos count.
	void extend(int ch) {
		cur = ++sz;
		len[cur] = len[last] + 1;
		int p = last;
		// Walk suffix links, adding the ch-transition to cur until one already exists.
		while (p != -1 && !nxt[p][ch]) {
			nxt[p][ch] = cur;
			p = link[p];
		}
		if (p == -1) {
			link[cur] = 0;
			LCT::link(cur + 1, 1);
		} else {
			const int q = nxt[p][ch];
			if (len[q] == len[p] + 1) {
				link[cur] = q;
				LCT::link(cur + 1, q + 1);
			} else {
				const int clone = ++sz;
				// The clone inherits q's weight; split(q+1, q+1) pushes pending tags so val[q + 1] is exact.
				LCT::split(q + 1, q + 1);
				LCT::split(clone + 1, clone + 1);
				LCT::Add(clone + 1, LCT::val[q + 1]);
				// Splice the clone between q and q's old suffix-link parent, then hang cur under it.
				LCT::cut(q + 1, link[q] + 1);
				LCT::link(clone + 1, link[q] + 1);
				LCT::link(q + 1, clone + 1);
				LCT::link(cur + 1, clone + 1);
				link[clone] = link[q];
				link[q] = clone;
				link[cur] = clone;
				len[clone] = len[p] + 1;
				nxt[clone][0] = nxt[q][0];
				nxt[clone][1] = nxt[q][1];
				// Redirect transitions that pointed at q along the suffix-link chain.
				while (p != -1 && nxt[p][ch] == q) {
					nxt[p][ch] = clone;
					p = link[p];
				}
			}
		}
		last = cur;
		// The new prefix ends in every state on the root..cur path: path add 1.
		LCT::Add(LCT::split(1, cur + 1), 1);
	}
}

void decodeWithMask(string &s, int mask) {
	// Undo the online encoding with a deterministic chain of in-place swaps driven by mask.
	const int n = s.length();
	for (int j = 0; j < n; ++j) {
		mask = (mask * 131 + j) % n;
		swap(s[j], s[mask]);
	}
}

int main() {
	ios::sync_with_stdio(false), cin.tie(nullptr);
	int q;
	cin >> q;
	string init;
	cin >> init;
	SAM::init();
	for (char c : init) SAM::extend(c - 'A');
	int mask = 0;
	while (q--) {
		string type, str;
		cin >> type >> str;
		decodeWithMask(str, mask);
		if (type == "ADD") {
			for (char c : str) SAM::extend(c - 'A');
		} else if (type == "QUERY") {
			// Follow transitions from the initial state; a missing edge means zero occurrences.
			int state = 0;
			bool found = true;
			for (char c : str) {
				const int to = SAM::nxt[state][c - 'A'];
				if (!to) {
					found = false;
					break;
				}
				state = to;
			}
			int res = 0;
			if (found) {
				LCT::split(1, state + 1);
				res = LCT::val[state + 1];
			}
			cout << res << '\n';
			mask ^= res;
		}
	}
	return 0;
}

Code 2

有 up,无 down(维护子树信息,不需要下传标记)。

#include <bits/stdc++.h>
using namespace std;
// rep(v, lo, hi): ascending loop, v runs over every integer in [lo, hi].
#define rep(idx, lo, hi) for (int idx = (lo); idx <= (hi); ++idx)
// reo(v, hi, lo): descending loop, v runs from hi down to lo inclusive.
#define reo(idx, hi, lo) for (int idx = (hi); idx >= (lo); --idx)
using ll = long long;
// Capacity of every per-node array below (SAM states and LCT nodes).
constexpr int N = 8000000 + 10;

namespace LCT {
	int fa[N], ch[N][2], All[N], MyXu[N], val[N];
  #define get(u) (u == ch[fa[u]][1])
  #define nrt(u) (u == ch[fa[u]][0] || u == ch[fa[u]][1])
	void up(int u) {
		All[u] = All[ch[u][0]] + All[ch[u][1]] + val[u] + MyXu[u];
	}
	void rot(int u) {
		int f = fa[u], g = fa[f], k = get(u);
		if (nrt(f)) ch[g][get(f)] = u;
		ch[f][k] = ch[u][!k];
		if (ch[u][!k]) fa[ch[u][!k]] = f;
		ch[u][!k] = f, fa[f] = u, fa[u] = g, up(f), up(u);
	}
	void splay(int u) {
		for (; nrt(u); rot(u)) if (nrt(fa[u])) rot(get(u) == get(fa[u]) ? fa[u] : u);
	}
	void access(int u) {
		for (int v = 0; u; v = u, u = fa[u]) 
			splay(u), MyXu[u] = MyXu[u] - All[v] + All[ch[u][1]], ch[u][1] = v, up(u);
	}
	void Add(int u, int v) {
		access(u), splay(u), All[u] += v, val[u] += v;
	}
	void link(int u, int v) {
		fa[u] = v, MyXu[v] += All[u], access(u);
	}
	void cut(int u, int v) {
		access(u), splay(v), All[v] -= All[u], fa[u] = ch[v][1] = 0;
	}
	int qry(int u) {
		return access(u), splay(u), MyXu[u] + val[u];
	}
  #undef get
  #undef nrt
}

namespace SAM {
	// Suffix automaton over characters {0, 1}. State s is LCT node s + 1, so the LCT
	// forest mirrors the suffix-link tree with state 0 (LCT node 1) as its root.
	int sz, cur, last, len[N], link[N], nxt[N][2];

	// State 0 is the initial state and has no suffix link.
	void init() {
		link[0] = -1;
	}

	// Append character ch and put one unit of weight on the new prefix's state.
	void extend(int ch) {
		cur = ++sz;
		len[cur] = len[last] + 1;
		int p = last;
		// Walk suffix links, adding the ch-transition to cur until one already exists.
		while (p != -1 && !nxt[p][ch]) {
			nxt[p][ch] = cur;
			p = link[p];
		}
		if (p == -1) {
			link[cur] = 0;
			LCT::link(cur + 1, 1);
		} else {
			const int q = nxt[p][ch];
			if (len[q] == len[p] + 1) {
				link[cur] = q;
				LCT::link(cur + 1, q + 1);
			} else {
				const int clone = ++sz;
				// Splice the clone between q and q's old suffix-link parent, then hang cur under it.
				// The clone never receives a point add, so its own weight stays 0.
				LCT::cut(q + 1, link[q] + 1);
				LCT::link(clone + 1, link[q] + 1);
				LCT::link(q + 1, clone + 1);
				LCT::link(cur + 1, clone + 1);
				link[clone] = link[q];
				link[q] = clone;
				link[cur] = clone;
				len[clone] = len[p] + 1;
				nxt[clone][0] = nxt[q][0];
				nxt[clone][1] = nxt[q][1];
				// Redirect transitions that pointed at q along the suffix-link chain.
				while (p != -1 && nxt[p][ch] == q) {
					nxt[p][ch] = clone;
					p = link[p];
				}
			}
		}
		last = cur;
		LCT::Add(cur + 1, 1);
	}
}

void decodeWithMask(string &s, int mask) {
	// Undo the online encoding with a deterministic chain of in-place swaps driven by mask.
	const int n = s.length();
	for (int j = 0; j < n; ++j) {
		mask = (mask * 131 + j) % n;
		swap(s[j], s[mask]);
	}
}

int main() {
	ios::sync_with_stdio(false), cin.tie(nullptr);
	int q;
	cin >> q;
	string init;
	cin >> init;
	SAM::init();
	for (char c : init) SAM::extend(c - 'A');
	int mask = 0;
	while (q--) {
		string type, str;
		cin >> type >> str;
		decodeWithMask(str, mask);
		if (type == "ADD") {
			for (char c : str) SAM::extend(c - 'A');
		} else if (type == "QUERY") {
			// Follow transitions from the initial state; a missing edge means zero occurrences.
			int state = 0;
			bool found = true;
			for (char c : str) {
				const int to = SAM::nxt[state][c - 'A'];
				if (!to) {
					found = false;
					break;
				}
				state = to;
			}
			const int res = found ? LCT::qry(state + 1) : 0;
			cout << res << '\n';
			mask ^= res;
		}
	}
	return 0;
}