CodeForces 494B Obsessive String ——(字符串DP+KMP)

  这题的题意就很晦涩。题意是:问有多少种方法,把字符串s划分成不重叠的子串(可以不使用完s的所有字符,但是这些子串必须不重叠),使得t串是所有这些新串的子串。譬如第一个样例,"ababa"和"aba",共有5种方法:{aba}(前3个),{aba}(后3个),{abab},{baba},{ababa}。

  先设s的长度为lena,t的长度为lenb。

  做法是:用dp[i]表示到i为止,有几种方案数,所以最终答案是dp[lena]。然后考虑转移。首先dp[i]至少等于dp[i-1]。然后考虑把包含了一个t串的a[1~i]的后缀给取出来,不妨设目前的a[1~i]为*****abc,,t是"abc",然后新增的方法数有:如果右边包含t串的是"abc",那么左边的选择有(dp[i-3]+1)种(1是左边为空串的情况),类似的,如果右边包含t串的是"*abc",那么左边是(dp[i-4]+1)。那么累和一下即是,(dp[0]+1)+(dp[1]+1)+(dp[2]+1)+...+(dp[i-lenb]+1)。显然的,dp[0] = 0。那么这个式子就可以化简为sum(dp[1~i-lenb])+i-lenb+1。sum部分可以利用前缀和优化,如果假设pre[i]是dp[1~i]的和的话,那么sum部分变为pre[i-lenb],其余部分是i-lenb+1。假设val[i]表示i这个位置开始往前,至少包含了一个t串时的位置,例如串"*****abc",val[8] = 6,串"*****abc*",val[9] = 6。理解了val数组的含义以后,我们就可以发现如果i这个位置是t串的最后一个位置的话,那么val[i] = i-lenb+1,否则呢,val[i] = val[i-1]即可。我们可以用kmp把能够处理的val直接处理出来,其余的val通过递推得到。有了val数组以后,这个dp的转移就可以完全得到了:dp[i] = dp[i-1] + pre[val[i]-1] + val[i]。然后dp的过程中维护一下前缀和即可,最后需要注意要对1e9+7取模。

  另外需要注意的是,如果说当前的串最后一个字符并不是t串的最后一个字符,例如"*****abc*",这样的话通过上面这个转移我们可以知道,右边的串也必须是以最后一个字符结尾的(实际上任何情况下右边这个串都是需要以最后一个字符结尾的,否则新增的这个字符就没有意义了,也就会出现和之前的方案重复的情况了)。这一点想清楚了,这个转移就没有任何的问题了。

  最终的代码如下:

 1 #include <stdio.h>
 2 #include <algorithm>
 3 #include <string.h>
 4 using namespace std;
 5 const int N = 1e5 + 5;
 6 const int mod = 1e9 + 7;
 7 
 8 char a[N],b[N];
 9 int lena,lenb,nxt[N];
10 int val[N];
11 int sum[N];
12 int dp[N];
13 void get_nxt()
14 {
15     nxt[1] = 0;
16     int j = 0;
17     for(int i=2;i<=lenb;i++)
18     {
19         while(j && b[j+1] != b[i]) j = nxt[j];
20         if(b[j+1] == b[i]) j++;
21         nxt[i] = j;
22     }
23     
24     j = 0;
25     for(int i=1;i<=lena;i++)
26     {
27         while(j && b[j+1] != a[i]) j = nxt[j];
28         if(b[j+1] == a[i]) j++;
29         if(j == lenb)
30         {
31             val[i] = i - lenb + 1;
32             j = nxt[j];
33         }
34     }
35 }
36 
37 int main()
38 {
39     scanf("%s%s",a+1,b+1);
40     lena = strlen(a+1);
41     lenb = strlen(b+1);
42     get_nxt();
43     for(int i=1;i<=lena;i++) if(!val[i]) val[i] = val[i-1];
44     for(int i=1;i<=lena;i++)
45     {
46         dp[i] = dp[i-1];
47         if(val[i]) dp[i] = (dp[i] + sum[val[i]-1] + val[i]) % mod;
48         sum[i] = (sum[i-1] + dp[i]) % mod;
49     }
50     printf("%d
",dp[lena]);
51     return 0;
52 }