P3311 [SDOI2014] 数数

SunnyYuan的博客 / 2024-09-26 / 原文

参考题解做法。

题目

image

思路

数位 dp + AC 自动机好题。

直接往下递归,dfs(u, ver, limit, st) 表示目前在数字 \(n\) 的第 \(u\) 位进行讨论,\(ver\) 表示当前在 AC 自动机上的节点,\(limit\) 是是否步步紧逼 \(n\),只要位数不足 \(n\) 的位数或者有一位小于 \(n\) 的那一位就不叫步步紧逼,\(st\) 表示现在是否已经进入数字,因为很多数字位数不如 \(n\),就相当于在它们前面填充 \(0\)

在往下递归过程中,如果遇到边界,那么立刻返回 1,注意对 \(0\) 的特判(题目中说是 \(1\)\(n\),不是 \(0\)\(n\));可以确认这个 DP 是个 DAG,所以加上记忆化搜索,避免 TLE;getfail 时注意传递标识。

代码

#include <bits/stdc++.h>

using namespace std;

const int N = 1510, mod = 1e9 + 7;

string s, n;
int t;
int f[N][N * 10][2][2];

struct _ac {
    int ch[N][10], fail[N * 10], idx;
    bool val[N * 10];

    void insert(string& s) {
        int p = 0;
        for (auto x : s) {
            int u = x - '0';
            if (!ch[p][u]) ch[p][u] = ++idx;
            p = ch[p][u];
        }
        val[p] = 1;
    }

    void getfail() {
        queue<int> q;

        for (int i = 0; i < 10; i++) {
            if (ch[0][i]) {
                q.push(ch[0][i]);
            }
        }

        while (q.size()) {
            int t = q.front();
            q.pop();
            for (int i = 0; i < 10; i++) {
                if (ch[t][i]) {
                    fail[ch[t][i]] = ch[fail[t]][i];
                    q.push(ch[t][i]);
                    val[ch[t][i]] |= val[fail[ch[t][i]]];
                }
                else ch[t][i] = ch[fail[t]][i];
            }
        }
    }
} ac;

int dfs(int u, int ver, bool limit, bool st) {      // u : 数字 n 的长度, ver : 对应 ac 自动机的节点编号, limit : 是否被限制, st : 是否还未进入数字(用 0 填充)
    if (ac.val[ver]) return 0;                      // 如果遇到标记,立即返回
    if (u >= n.size()) return !st;                  // 注意对 0 的去除
    if (f[u][ver][limit][st] != -1) return f[u][ver][limit][st];// 记忆化搜索
    int up = limit ? n[u] - '0' : 9;                // 限制
    int ans = 0;                                    // 结果

    for (int i = 0; i <= up; i++) {
        bool nxt_limit = (limit && i == up) ? true : false;
        bool nxt_st = (st && i == 0) ? true : false;
        int  nxt_ver = (st && i == 0) ? 0 : ac.ch[ver][i];
        ans = (ans + dfs(u + 1, nxt_ver, nxt_limit, nxt_st)) % mod;
    }
    f[u][ver][limit][st] = ans;
    return ans;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    memset(f, -1, sizeof(f));

    cin >> n;
    cin >> t;
    while (t--) {
        cin >> s;
        ac.insert(s);
    }
    ac.getfail();
    int res = dfs(0, 0, true, true);
    cout << res << '\n';
    return 0;
}