2019 牛客暑期多校 G subsequence 1 (dp+组合数)

题目:https://ac.nowcoder.com/acm/contest/885/G

题意:给你两个串,要求上面哪个串的子序列的值大于下面这个串的值的序列个数,不含前导零

思路:我们很容易就可以看出选择上面的那个的字符串的子序列时如果长度超过下面这个一定可以选用,同理长度短于的肯定不能用,长度大的部分我们可以用组合数来进行求解

所以难点就在于我们要求序列长度等于的时候有多少个序列满足要求,因为种数太多我们可以想到应该要用dp来优化复杂度,想一下,如果当前某一位低于了,那么这个串就肯定比他小了

所以我们要计算两种状态,到当前位置依然相等的个数 0  ,当前已经大于的序列个数 1

dp[i][j][k]

当前第一个串的第i位匹配第二个串的第j位第k种状态时有多少个序列

最后dp[n-1][m-1][1]就是长度相等时的答案

当然这里需要降维,把第一维去掉,我们就可以采用01背包那种做法,把第二种循环遍历顺序取反即可

最后我们还要判断以0开头的情况,这样是不合法的,但是我组合数时计算了值,所以我们最后还要去掉

 

#include<bits/stdc++.h>
#define maxn 3005
#define mod  998244353
using namespace std;
typedef long long ll;
ll dp[maxn][2];
ll c[maxn][maxn];
int n,m;
char s1[maxn],s2[maxn];
ll fact[maxn],ifact[maxn];
ll p(ll a,ll b){
    if(a+b>=mod){
        return a+b-mod;
    }
    else return a+b;
}
int main(){
    int t;
    scanf("%d",&t);
    c[1][0]=c[1][1]=1;
    for(int i=2;i<maxn;i++){
        c[i][0]=1;
        for(int j=1;j<=i;j++)
            c[i][j]=p(c[i-1][j],c[i-1][j-1]);
    }
    while(t--){
        memset(dp,0,sizeof(dp));
        scanf("%d%d",&n,&m);
        scanf("%s%s",s1,s2);
        for(int i=0;i<n;i++){//计算相等情况
            if(i!=0)
            {
                for(int j=min(i,m-1);j>=1;j--){
                    if(s1[i]==s2[j]){
                        dp[j][0]=p(dp[j][0],dp[j-1][0]);
                        dp[j][1]=p(dp[j][1],dp[j-1][1]);
                    }
                    else if(s1[i]>s2[j]){
                        dp[j][1]=p(dp[j][1],dp[j-1][0]);
                        dp[j][1]=p(dp[j][1],+dp[j-1][1]);
                        
                    }
                    else{
                        dp[j][1]=p(dp[j][1],dp[j-1][1]);
                    }
                }
            } 
            if(s1[i]==s2[0]) dp[0][0]++;
            else if(s1[i]>s2[0]) dp[0][1]++; 
        }
        ll sum=dp[m-1][1];
    //    printf("%lld
",sum); 
        for(int i=m+1;i<=n;i++){//计算大于长度的时候
            sum=p(sum,c[n][i]);
        } 
        for(int i=0;i<n;i++){//去掉以0开头
            if(s1[i]=='0'){
                ll len=n-i-1;
                for(int j=m;j<=n&&j<=len;j++){
                    sum=p(sum-c[len][j],mod);
                }
            }
        } 
        printf("%lld
",p(sum,mod));
    }
}