Luogu3193 HNOI2008 GT考试 Description Solution Code

怎么这么神仙的题目呀

link

(n) 位数,求其中不含有长度为 (m) 的一个数字串的方案数

(nle 10^9,mle 20)

Solution

一开始想到了正难则反,但是发现两边重合之后容斥的重复情况不会推……

之后想到了一个 (dp)

(f_{i,j}) 为前 (i) 位数,有前 (j) 的后一个串的方案数

最后答案是 (sumlimits^{m-1}_{i=0} f_{n,i})

然后的转移?

这个匹配长度好像还是不可处理

(看题解ing)

发现可以用 (kmp) 算转移的方法

然后就又会了一点:

[f_{i,j}=sum^{n}_ {k=1} f_{i-1,k} imes g_{k,j} ]

(g[i][j]) 就是 长度转移方案数

计算方法:

		for(int i=2,j=0;i<=m;++i)
		{
			while(j&&s[i]!=s[j+1]) j=nxt[j];
			if(s[i]==s[j+1]) ++j; nxt[i]=j;
		}
		for(int i=0;i<m;++i)
		{
			for(int j='0';j<='9';++j)
			{
				int k=i;
				while(k&&s[k+1]!=j) k=nxt[k];
				if(s[k+1]==j) ++k;
				++g[i+1][k+1]; 
			}
		}

之后这个 (n) 的范围?

还是不大行

(我不知为啥能想到矩乘,怕是面向数据范围编程)

这里的优化还是挺妙的:

(dp) 式子是个矩乘的式子对吧……

(F[0]) 表示 (i) 位数的矩阵,就是由所有 (f_{i,j}) 构成的

这个就变成了 (F_i=F_{i-1} imes G)

(G) 就是 (g_{i,j}) 是固定的

同时我们发现这里的 (F_0) 的第一位是个 (1)

(这里好像可以当做一个单位矩阵)

还不用再处理一下,就没了

(kmp+dp) 矩阵快速幂优化还是挺厉害的

[ans=sum_{i=0}^{n-1} F[n]_{1,i} ]

如果不理解请先学习矩阵加速数列(蒟蒻看了30min的加速数列就懂了这题)

Code

#include<bits/stdc++.h>
using namespace std;
#define int long long
namespace yspm{
	inline int read()
	{
		int res=0,f=1; char k;
		while(!isdigit(k=getchar())) if(k=='-') f=-1;
		while(isdigit(k)) res=res*10+k-'0',k=getchar();
		return res*f;
	}
	const int N=30;
	int n,m,mod,nxt[N],ans;
	char s[N];
	struct mat{
		int a[N][N];
		int* operator [](int x){return a[x];}
		inline void init(){for(int i=1;i<=m;++i) a[i][i]=1; return ;}
	}g;
	inline mat mul(mat x,mat y)
	{
		mat ans; memset(ans.a,0,sizeof(ans.a));
		for(int i=1;i<=m;++i)
		{
			for(int j=1;j<=m;++j)
			{
				for(int k=1;k<=m;++k)
				{
					ans[i][j]+=x[i][k]*y[k][j]; 
				}
			}
		}
		for(int i=1;i<=m;++i) for(int j=1;j<=m;++j) ans[i][j]%=mod;
		return ans;
	}
	inline mat ksm(mat x,int y)
	{
		mat ans; ans.init();
		for(;y;y>>=1,x=mul(x,x)) if(y&1) ans=mul(ans,x);
		return ans;
	}
	signed main()
	{
		n=read(); m=read(); mod=read();
		scanf("%s",s+1); nxt[0]=1;
		for(int i=2,j=0;i<=m;++i)
		{
			while(j&&s[i]!=s[j+1]) j=nxt[j];
			if(s[i]==s[j+1]) ++j; nxt[i]=j;
		}
		for(int i=0;i<m;++i)
		{
			for(int j='0';j<='9';++j)
			{
				int k=i;
				while(k&&s[k+1]!=j) k=nxt[k];
				if(s[k+1]==j) ++k;
				++g[i+1][k+1]; 
			}
		}
		g=ksm(g,n);
		for(int i=1;i<=m;++i) ans+=g[1][i],ans%=mod;
		printf("%lld
",ans);
		return 0;
	}
}
signed main(){return yspm::main();}