bzoj 4180: 字符串计数 (后缀自动机+矩阵乘法)
4180: 字符串计数
Time Limit: 10 Sec Memory Limit: 128 MB Submit: 140 Solved: 63 [Submit][Status][Discuss]Description
SD有一名神犇叫做Oxer,他觉得字符串的题目都太水了,于是便出了一道题来虐蒟蒻yts1999。 他给出了一个字符串T,字符串T中有且仅有4种字符 'A', 'B', 'C', 'D'。现在他要求蒟蒻yts1999构造一个新的字符串S,构造的方法是:进行多次操作,每一次操作选择T的一个子串,将其加入S的末尾。 对于一个可构造出的字符串S,可能有多种构造方案,Oxer定义构造字符串S所需的操作次数为所有构造方案中操作次数的最小值。 Oxer想知道对于给定的正整数N和字符串T,他所能构造出的所有长度为N的字符串S中,构造所需的操作次数最大的字符串的操作次数。 蒟蒻yts1999当然不会做了,于是向你求助。Input
第一行包含一个整数N,表示要构造的字符串长度。 第二行包含一个字符串T,T的意义如题所述。Output
输出文件包含一行,一个整数,为你所求出的最大的操作次数。Sample Input
5 ABCCADSample Output
5HINT
【样例说明】 例如字符串"AAAAA",该字符串所需操作次数为5,不存在能用T的子串构造出的,且所需操作次数比5大的字符串。 【数据规模和约定】 对于100%的数据,1 ≤ N ≤ 10^18,1 ≤ |T| ≤ 10^5。Source
By yts1999
题解:后缀自动机+矩阵乘法
感觉好久不做后缀自动机,思维能力明显退化了,想了好久才想出来,而且调的过程中还发现各种手残,感觉也是醉了。不过不得不说,其实这是一道不错的题。
看到10^18自然想到有很大的概率与矩阵乘法有关,然后看到模式串是10^5那么矩阵的状态一定不是模式串的位置,或者后缀自动机的结点之类的。因为我们需要用很多的子串来构造合法的序列,并且必须保证将两个串连接起来不能的到新的长度超过前面串的子串,否则一定存在操作次数更少的方案。那么如果我们要尽可能最大化操作的次数,就需要选取的可拼接的子串的长度尽可能小。
考虑构造转移矩阵,f[i][j]表示已i开头的子串可以接j开头的子串的最短长度。这个矩阵在构造的时候,可以对模式串建立后缀自动机,然后按照拓扑倒序,从后往前更新答案。mu[i][j]表示后缀自动机中的结点i向后选择一个最短长度使其可以接j开头的子串。mu[i][j]=min(mu[i][j],mu[ch[i][k]][j]+1)
f[i][j]=mu[ch[1][i]][j] 1号结点的i儿子能到达的所有串都是以i开头的子串。
对于这种最小值最大的问题,一般都可以用二分答案来做。所以我们二分一个操作数,然后用矩阵快速幂做这么多次,然后判断一下得到的最小长度是否>=n,如果大于等于n说明答案有可能最小。
#include<iostream> #include<cstring> #include<cstdio> #include<algorithm> #include<cmath> #define N 300005 #define M 4 #define LL long long using namespace std; LL m,mu[N][6]; const LL inf=1e18; int n,ch[N][6],fa[N],l[N],np,q,cnt,root,nq,p,vis[N],last; char s[N]; struct data{ LL a[20][20]; }num; void extend(int x) { int c=s[x]-'A'+1; p=last; np=last=++cnt; l[np]=l[p]+1; for (;p&&!ch[p][c];p=fa[p]) ch[p][c]=np; if (!p) fa[np]=root; else { int q=ch[p][c]; if (l[q]==l[p]+1) fa[np]=q; else { nq=++cnt; l[nq]=l[p]+1; memcpy(ch[nq],ch[q],sizeof(ch[q])); fa[nq]=fa[q]; fa[q]=fa[np]=nq; for (;ch[p][c]==q;p=fa[p]) ch[p][c]=nq; } } } void clear(data &a,LL x) { for (int i=1;i<=M;i++) for (int j=1;j<=M;j++) a.a[i][j]=x; } data mul(data a,data b) { data c; clear(c,inf); for (int i=1;i<=M;i++) for (int j=1;j<=M;j++) for (int k=1;k<=M;k++) c.a[i][j]=min(c.a[i][j],a.a[i][k]+b.a[k][j]); return c; } data quickpow(data num,LL x) { data ans; clear(ans,0); data base; base=num; while (x) { if (x&1) ans=mul(ans,base); x>>=1; base=mul(base,base); } return ans; } void dfs(int x) { if (vis[x]) return; vis[x]=1; for (int i=1;i<=4;i++) if (ch[x][i]) dfs(ch[x][i]),mu[x][i]=inf; for (int i=1;i<=4;i++) { if (!ch[x][i]) { mu[x][i]=1; continue; } for (int j=1;j<=4;j++) mu[x][j]=min(mu[x][j],mu[ch[x][i]][j]+1); } } void build() { dfs(root); for (int i=1;i<=M;i++) for (int j=1;j<=M;j++) num.a[i][j]=mu[ch[1][i]][j]; } bool check(LL x) { data ans1=quickpow(num,x); LL mx=inf; for (int i=1;i<=M;i++) for (int j=1;j<=M;j++) mx=min(mx,ans1.a[i][j]); return mx>=m; } int main() { freopen("a.in","r",stdin); freopen("my.out","w",stdout); scanf("%I64d",&m); scanf("%s",s+1); n=strlen(s+1); root=last=++cnt; for (int i=1;i<=n;i++) extend(i); build(); //for (int i=1;i<=4;i++,cout<<endl) // for (int j=1;j<=4;j++) cout<<num.a[i][j]<<" "; LL l=1; LL r=m; LL ans=m+1; while (l<=r) { LL mid=(l+r)/2; if (check(mid)) ans=min(ans,mid),r=mid-1; else l=mid+1; } if (ans==m+1) PRintf("0\n"); else printf("%I64d\n",ans); }