HDU4758 Walk Through Squares AC自动机&&dp
这道题当时做的时候觉得是数论题,包含两个01串什么的,但是算重复的时候又很蛋疼,赛后听说是字符串,然后就觉得很有可能。昨天队友问到这一题,在学了AC自动机之后就觉得简单了许多。那个时候不懂AC自动机,不知道什么是状态,因此没有想到有效的dp方法。
题意是这样的,给定两个RD串,譬如RRD,DDR这样子的串,然后现在要你向右走(R)m步,向下走(D)n步,问有多少种走法能够包含给定的两个串。
一个传统的dp思想是这样的 dp[i][j][x][y][k],表示走了i步R,j步D,x,y表示两个串各匹配了多少各,k表示的是1,2串匹配的一个4进制数(00,01,10,11,你懂的,11表示都匹配了,10表示匹配了1串)。 但是这样一来空间开不下,二来当某个点失配的时候我们不知道当前的x,y会转移到哪里,这个时候很自然的,我们就想到了AC自动机,AC自动机压入两个串只需要不超过串的总长度的结点,而且当我们在自动机上转移的时候,我们可以知道失配的时候转移到哪里。所以重新定义一下就是 dp[i][j][k][x] k表示自动机上的状态,x表示4进制数。转移的时候就考虑由当前的状态dp[i][j][k][x]转移到dp[i+1][j][nxt1][nxtx] dp[i][j+1].... 其中新的状态nxt以及对应的四进制数转移就需要根据AC自动机的失配算出来。 如果预处理出当失配时回到的那个结点感觉可能会更快一些。 我代码里多写了个dfs,主要是预处理了 到达改状态时对应的四进制数,所以转移的时候只需要或一下就可以了。
第一次做AC自动机上的dp然后 1A了,好开心!
#pragma warning(disable:4996) #include<iostream> #include<cstring> #include<string> #include<cstdio> #include<algorithm> #include<vector> #include<cmath> #include<queue> #define maxn 2000 #define mod 1000000007 using namespace std; struct Trie { Trie * go[2]; Trie *fail; int sta; void init() { memset(go, 0, sizeof(go)), fail == NULL; sta = 0; } }pool[maxn],*root; int tot; void insert(char *c,int type) { int len = strlen(c); Trie *p = root; for (int i = 0; i < len; i++){ int ind = c[i] == 'R' ? 0 : 1; if (p->go[ind] != NULL){ p = p->go[ind]; } else{ pool[tot].init(); p->go[ind] = &pool[tot++]; p = p->go[ind]; } } p->sta |= type; } void getFail() { queue<Trie*> que; que.push(root); root->fail = NULL; while (!que.empty()) { Trie *temp = que.front(); que.pop(); Trie *p = NULL; for (int i = 0; i < 2; i++){ if (temp->go[i] != NULL){ if (temp == root) temp->go[i]->fail = root; else{ p = temp->fail; while (p != NULL){ if (p->go[i] != NULL){ temp->go[i]->fail = p->go[i]; break; } p = p->fail; } if (p == NULL) temp->go[i]->fail = root; } que.push(temp->go[i]); } } } } int dfs(Trie *x){ if (x == NULL) return 0; return x->sta |= dfs(x->fail); } int m, n; int dp[120][120][240][4]; char str[120]; int main() { int T; cin >> T; while (T--) { tot = 0; root = &pool[tot++]; root->init(); scanf("%d%d", &m, &n); for (int i = 1; i <= 2; i++){ scanf("%s", str); insert(str, i); } getFail(); for (int i = 0; i < tot; i++){ dfs(&pool[i]); } for (int i = 0; i <= m; i++){ for (int j = 0; j <= n; j++){ for (int k = 0; k <= tot; k++){ for (int x = 0; x < 4; x++){ dp[i][j][k][x] = 0; } } } } dp[0][0][0][0] = 1; for (int i = 0; i <= m; i++){ for (int j = 0; j <= n; j++){ for (int k = 0; k < tot; k++){ for (int x = 0; x < 4; x++){ Trie* p = &pool[k]; if (p->go[0] != NULL){ (dp[i + 1][j][p->go[0] - pool][x | p->go[0]->sta] += dp[i][j][k][x]) %= mod; } else{ Trie *temp = p->fail; while (temp != NULL) { if (temp->go[0] != NULL){ (dp[i + 1][j][temp->go[0] - pool][x | temp->go[0]->sta] += dp[i][j][k][x]) %= mod; break; } temp = temp->fail; } if (temp == NULL) (dp[i + 1][j][0][x | root->sta] += dp[i][j][k][x]) %= mod; } if (p->go[1] != NULL){ (dp[i][j + 1][p->go[1] - pool][x | p->go[1]->sta] += dp[i][j][k][x]) %= mod; } else{ Trie *temp = p->fail; while (temp != NULL) { if (temp->go[1] != NULL){ (dp[i][j + 1][temp->go[1] - pool][x | temp->go[1]->sta] += dp[i][j][k][x]) %= mod; break; } temp = temp->fail; } if (temp == NULL) (dp[i][j + 1][0][x | root->sta] += dp[i][j][k][x]) %= mod; } } } } } int ans = 0; for (int i = 0; i < tot; i++){ ans = ans + dp[m][n][i][3]; ans %= mod; } printf("%d ", ans); } return 0; }