[BZOJ4556][Tjoi2016&Heoi2016]字符串
[BZOJ4556][Tjoi2016&Heoi2016]字符串
试题描述
佳媛姐姐过生日的时候,她的小伙伴从某东上买了一个生日礼物。生日礼物放在一个神奇的箱子中。箱子外边写了一个长为n的字符串s,和m个问题。佳媛姐姐必须正确回答这m个问题,才能打开箱子拿到礼物,升职加薪,出任CEO,嫁给高富帅,走上人生巅峰。每个问题均有a,b,c,d四个参数,问你子串s[a..b]的所有子串和s[c..d]的最长公共前缀的长度的最大值是多少?佳媛姐姐并不擅长做这样的问题,所以她向你求助,你该如何帮助她呢?
输入
输入的第一行有两个正整数n,m,分别表示字符串的长度和询问的个数。接下来一行是一个长为n的字符串。接下来m行,每行有4个数a,b,c,d,表示询问s[a..b]的所有子串和s[c..d]的最长公共前缀的最大值。1<=n,m<=100,000,字符串中仅有小写英文字母,a<=b,c<=d,1<=a,b,c,d<=n
输出
对于每一次询问,输出答案。
输入示例
5 5 aaaaa 1 1 1 5 1 5 1 1 2 3 2 3 2 4 2 3 2 3 2 4
输出示例
1 1 2 2 2
数据规模及约定
见“输入”
题解
注意题目描述,不是 s[c..d] 中的子串,而是 s[c..d] 本身。那么我们可以想象就是用 s[c..d] 这段字符串去匹配原串,要求找到所有匹配位置,然后取范围在 s[a..b] 中的答案。
要找到所有匹配位置,就可以想到后缀自动机了,然而“后缀”自动机不能直接支持前缀操作(很明显 s[c..d] 匹配的是一段前缀),解决办法很简单,把串倒过来即可。
接下来,我们构建后缀自动机。大体思路便有了(我们令 A = n - a + 1, B = n - b + 1, C = n - c + 1, D = n - d + 1,那么就是区间 [D, C] 的一段后缀取匹配区间 [B, A] 中的字符串),肯定是要查询 right 集合中含有 C 的所有状态节点——即从 right = { C } 所对应的节点到根节点的路径,我们只需要看这一条链上所有 right 集合中是否含有 [B, A] 中的数就好了。
但是,有一些细节还欠考虑。那就是我们须要保证 [B, A] 中匹配到的子串全都在 [B, A] 这个范围内,所以仅仅查找 right 集合中是否含有 [B, A] 中的数是不够的。改进一下,我们二分答案 x,然后我们需要找 maxlength ≥ x 的节点,然后看是否有 [B + x - 1, A] 中的数就好了。但这样会变成 O(n log2 n) 的,当然 105 的范围是可接受的。
#include <iostream> #include <cstdio> #include <cstdlib> #include <cstring> #include <cctype> #include <algorithm> using namespace std; int read() { int x = 0, f = 1; char c = getchar(); while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); } while(isdigit(c)){ x = x * 10 + c - '0'; c = getchar(); } return x * f; } #define maxn 200010 #define maxc 26 #define maxlog 18 int n; char str[maxn]; int ToT, rt, lst, to[maxn][maxc], par[maxn], mxl[maxn], node[maxn]; void extend(int pos) { int x = str[pos] - 'a', np = ++ToT, p = lst; mxl[np] = mxl[p] + 1; node[pos] = np; lst = np; while(p && !to[p][x]) to[p][x] = np, p = par[p]; if(!p){ par[np] = rt; return ; } int q = to[p][x]; if(mxl[q] == mxl[p] + 1){ par[np] = q; return ; } int nq = ++ToT; mxl[nq] = mxl[p] + 1; par[nq] = par[q]; par[q] = par[np] = nq; memcpy(to[nq], to[q], sizeof(to[q])); while(p && to[p][x] == q) to[p][x] = nq, p = par[p]; return ; } struct Tree { int m, head[maxn], nxt[maxn], to[maxn], dl[maxn], dr[maxn], clo, fa[maxn][maxlog], val[maxn]; int tot, Rt[maxn], lc[maxn*maxlog], rc[maxn*maxlog], sumv[maxn*maxlog]; Tree(): m(0), clo(0) { memset(head, 0, sizeof(head)); } void AddEdge(int a, int b) { to[++m] = b; nxt[m] = head[a]; head[a] = m; return ; } void build(int u) { dl[u] = ++clo; for(int i = 1; i < maxlog; i++) fa[u][i] = fa[fa[u][i-1]][i-1]; for(int e = head[u]; e; e = nxt[e]) fa[to[e]][0] = u, build(to[e]); dr[u] = clo; return ; } void update(int& y, int x, int l, int r, int p) { sumv[y = ++tot] = sumv[x] + 1; if(l == r) return ; int mid = l + r >> 1; lc[y] = lc[x]; rc[y] = rc[x]; if(p <= mid) update(lc[y], lc[x], l, mid, p); else update(rc[y], rc[x], mid + 1, r, p); return ; } int query(int o, int l, int r, int ql, int qr) { if(!o) return 0; if(ql <= l && r <= qr) return sumv[o]; int mid = l + r >> 1, ans = 0; if(ql <= mid) ans += query(lc[o], l, mid, ql, qr); if(qr > mid) ans += query(rc[o], mid + 1, r, ql, qr); return ans; } void gettree() { for(int i = 2; i <= ToT; i++) AddEdge(par[i], i); build(1); for(int i = 1; i <= n; i++) val[dl[node[i]]] = i; // for(int i = 1; i <= ToT; i++) printf("%d%c", val[i], i < ToT ? ' ' : ' '); for(int i = 1; i <= ToT; i++) if(val[i]) update(Rt[i], Rt[i-1], 1, n, val[i]); else Rt[i] = Rt[i-1]; return ; } bool check(int x, int lft, int rgt, int ql, int qr) { int u = node[qr]; for(int i = maxlog - 1; i >= 0; i--) if(mxl[fa[u][i]] >= x) u = fa[u][i]; // printf("check(%d): node %d [%d, %d] val[%d, %d] ", x, u, dl[u], dr[u], lft + x - 1, rgt); return query(Rt[dr[u]], 1, n, lft + x - 1, rgt) - query(Rt[dl[u]-1], 1, n, lft + x - 1, rgt) > 0; } void solve(int q) { while(q--) { int rgt = n - read() + 1, lft = n - read() + 1, qr = n - read() + 1, ql = n - read() + 1; int l = 0, r = min(rgt - lft + 1, qr - ql + 1) + 1; while(r - l > 1) { int mid = l + r >> 1; if(check(mid, lft, rgt, ql, qr)) l = mid; else r = mid; } printf("%d ", l); } return ; } } sol; int main() { n = read(); int q = read(); scanf("%s", str + 1); for(int i = 1; i <= (n >> 1); i++) swap(str[i], str[n-i+1]); ToT = rt = 1; for(int i = 1; i <= n; i++) extend(i); sol.gettree(); sol.solve(q); return 0; }