loj #547. 「LibreOJ β Round #7」匹配字符串 #547. 「LibreOJ β Round #7」匹配字符串

 

题目描述

对于一个 01 串(即由字符 0 和 1 组成的字符串)sss,我们称 sss 合法,当且仅当串 sss 的任意一个长度为 mmm 的子串 s′s's​​,不为全 1 串。

请求出所有长度为 nnn 的 01 串中,有多少合法的串,答案对 655376553765537 取模。

输入格式

输入共一行,包含两个正整数 n,mn,mn,m。

输出格式

输出共一行,表示所求的和对 655376553765537 取模的结果。

样例

样例输入 1

5 2

样例输出 1

13

样例解释 1

以下是所有合法的串:

00000
00001
00010
00100
00101
01000
01001
01010
10000
10001
10010
10100
10101

样例输入 2

2018 7

样例输出 2

27940

#include<iostream>
#include<cstdio>
#include<cstring>
#define mod 65537
using namespace std;
int n,m;
struct node{
    int n,m;
    int a[3][3];
    node(){memset(a,0,sizeof(a));}
    node operator * (const node &b)const{
        node res;
        res.n=n;res.m=b.m;
        for(int i=1;i<=n;i++)
            for(int j=1;j<=b.m;j++)
                for(int k=1;k<=m;k++)
                    res.a[i][j]+=1LL*a[i][k]*b.a[k][j]%mod;
        return res;
    }
};
bool check(int sta){
    int cnt=0;
    for(int i=1;i<=n;i++){
        if(sta&(1<<i-1))cnt++;
        else cnt=0;
        if(cnt>=m)return 0;
    }
    return 1;
}
node Pow(node x,int y){
    node res;
    res.n=2;res.m=2;
    res.a[1][1]=1;res.a[2][2]=1;
    while(y){
        if(y&1)res=res*x;
        x=x*x;
        y>>=1;
    }
    return res;
}
void work1(){
    node a;
    a.n=1;a.m=2;
    a.a[1][1]=1;a.a[1][2]=1;
    node b;
    b.n=b.m=2;
    b.a[1][1]=1;b.a[1][2]=1;b.a[2][1]=1;
    b=Pow(b,n-1);
    a=a*b;
    int ans=(a.a[1][1]+a.a[1][2])%mod;
    printf("%d",ans);
}
int main(){
    scanf("%d%d",&n,&m);
    if(m==1){puts("1");return 0;}
    if(m==2){work1();return 0;}
    int ans=0;
    for(int sta=0;sta<(1<<n);sta++)
        if(check(sta)){
            ans++;
            if(ans>=mod)ans-=mod;
        }
    printf("%d",ans);
    return 0;
}
13分 矩阵快速幂优化dp+枚举
#include<iostream>
#include<cstdio>
#include<cstring>
#define mod 65537
using namespace std;
long long n,m;
int tmp[6005],b[6005],c[6005],ans,inv[mod],fac[mod],bin[3005];
int Pow(int x,int y){
    int res=1;
    while(y){
        if(y&1)res=1LL*res*x%mod;
        x=1LL*x*x%mod;
        y>>=1;
    }
    return res;
}
int C(int n,int m){
    if(m>n)return 0;
    return 1LL*fac[n]*inv[m]%mod*inv[n-m]%mod;
}
int Lucas(long long n,long long m){
    if(!m)return 1;
    return 1LL*Lucas(n/mod,m/mod)*C(n%mod,m%mod)%mod;
}
void mul(int *b,int *c){
    for(int i=0;i<m;i++)
        for(int j=0;j<m;j++)
        tmp[i+j]=(tmp[i+j]+1LL*b[i]*c[j]%mod)%mod;
    for(int i=2*m-2;i>=m;i--)
        for(int j=1;j<=m;j++)
        tmp[i-j]=(tmp[i-j]+tmp[i])%mod;
    for(int i=0;i<m;i++)b[i]=tmp[i],tmp[i]=tmp[i+m]=0;
}
void solve1(){
    bin[0]=c[0]=1;
    if(m==1)b[0]=1;
    else b[1]=1;
    for(int i=1;i<m;i++)
        bin[i]=1LL*2*bin[i-1]%mod;
    while(n){
        if(n&1)mul(c,b);
        mul(b,b);
        n>>=1;
    }
    for(int i=0;i<m;i++)
        ans=(ans+1LL*bin[i]*c[i]%mod)%mod;
    cout<<ans;
}
int s(long long n){
    long long base=Pow(Pow(2,m+1),mod-2);
    long long cc=Pow(2,n);int res=0;
    for(int k=0;k*(m+1)<=n;k++){
        long long tt=1LL*Lucas(n-k*m,k)*cc%mod;
        tt=(k&1)?mod-tt:tt;
        res=(res+tt)>=mod?res-mod+tt:res+tt;
        cc=1LL*cc*base%mod;
    }
    return res;
}
void solve2(){
    fac[0]=fac[1]=1;
    for(int i=2;i<mod;i++)fac[i]=1LL*fac[i-1]*i%mod;
    inv[mod-1]=mod-1;
    for(int i=mod-1;i>=1;i--)inv[i-1]=1LL*inv[i]*i%mod;
    ans=s(n+1)-s(n);
    printf("%d
",(ans<0)?ans+mod:ans);
}
int main(){
    cin>>n>>m;
    if(m<=2500)solve1();
    else solve2();
    return 0;
}
100分