Sam's Numbers 矩阵快速幂优化dp

https://www.hackerrank.com/contests/hourrank-21/challenges/sams-numbers

设dp[s][i]表示产生的总和是s的时候,结尾符是i的所有合法方案数。

那么dp[s][i]可以由dp[s - i][1---m]中,abs(i - k) <= d的递推过来。

但是s很大,不能这样解决。

考虑到m只有10,而且dp[s][1]只能由dp[s - 1][1...m]递推过来。

那么先预处理dp[1--m][1--m]

写成m * m的一行(离散化一下就好)

写成dp[1][1]、dp[1][2].....dp[1][m]、、dp[2][1]、dp[2][2].....dp[2][m]、、dp[3][1]、dp[3][2]......dp[3][m]这样

那么dp[1][1]要递推变成dp[2][1],可以构造一个矩阵出来就好了。前m * m - m个可以由后面的得到。

构造矩阵难得就是dp[m][k]要递推到dp[m + 1][k],模拟一下就能找到递推式。主要是知道m比较小,可以暴力离散化来搞。

复杂度1e6 * logs

#include <bits/stdc++.h>
#define IOS ios::sync_with_stdio(false)
using namespace std;
#define inf (0x3f3f3f3f)
typedef long long int LL;
const int maxn = 100 + 2;
struct Matrix {
    LL a[maxn][maxn];
    int row;
    int col;
}ans, base;
//struct Matrix matrix_mul(struct Matrix a, struct Matrix b, int MOD) {  //求解矩阵a*b%MOD
//    struct Matrix c = {0};
//    c.row = a.row;
//    c.col = b.col;
//    for (int i = 1; i <= a.row; i++) {
//        for (int j = 1; j <= b.col; j++) {
//            for (int k = 1; k <= b.row; k++) {
//                c.a[i][j] += a.a[i][k] * b.a[k][j];
//                c.a[i][j] = (c.a[i][j] + MOD) % MOD;
//            }
//        }
//    }
//    return c;
//}

struct Matrix matrix_mul(struct Matrix a, struct Matrix b, int MOD) {
    struct Matrix c = {0};
    c.row = a.row;
    c.col = b.col;
    for (int i = 1; i <= a.row; ++i) {
        for (int k = 1; k <= a.col; ++k) {
            if (a.a[i][k]) {
                for (int j = 1; j <= b.col; ++j) {
                    c.a[i][j] += a.a[i][k] * b.a[k][j];
                    c.a[i][j] = (c.a[i][j] + MOD) % MOD;
                }
            }
        }
    }
    return c;
}
struct Matrix quick_matrix_pow(struct Matrix ans, struct Matrix base, LL n, int MOD) {
    while (n) {
        if (n & 1) {
            ans = matrix_mul(ans, base, MOD);
        }
        n >>= 1;
        base = matrix_mul(base, base, MOD);
    }
    return ans;
}
const int MOD = 1e9 + 9;
void add(int &x, int y) {
    x += y;
    if (x >= MOD) x -= MOD;
}
int dp[20][20];
LL s, m, d;
int getId(int x, int y) {
    return (x - 1) * m + y;
}
int res[maxn];
void work() {
    cin >> s >> m >> d;
    for (int i = 1; i <= m; ++i) dp[i][i] = 1;
    for (int i = 2; i <= m; ++i) {
        for (int j = 1; j <= m && j <= i; ++j) {
            for (int k = 1; k <= m; ++k) {
                if (abs(k - j) > d) continue;
                add(dp[i][j], dp[i - j][k]);
            }
        }
    }
    if (s <= m) {
        int ans = 0;
        for (int i = 1; i <= m; ++i) {
            add(ans, dp[s][i]);
        }
        cout << ans << endl;
        return;
    }
    for (int i = 1; i <= m; ++i) {
        for (int j = 1; j <= m; ++j) {
            res[getId(i, j)] = dp[i][j];
        }
    }
    base.col = base.row = m * m;
    int to = m + 1;
    for (int i = 1; i <= m * m - m; ++i) {
        base.a[to][i] = 1;
        to ++;
    }
    int now = 1;
    for (int i = m * m - m + 1; i <= m * m; ++i) {
//        int id = getId(m - now + 1, now);
        for (int j = 1; j <= m; ++j) {
            if (abs(j - now) > d) continue;
            base.a[getId(m - now + 1, j)][i] = 1;
        }
        now++;
    }
    ans.row = 1, ans.col = m * m;
    to = 1;
    for (int i = 1; i <= m; ++i) {
        for (int j = 1; j <= m; ++j) {
            ans.a[1][to] = dp[i][j];
            to++;
        }
    }
//
//    for (int i = 1; i <= m * m; ++i) {
//        for (int j = 1; j <= m * m; ++j) {
//            cout << base.a[i][j] << " ";
//        }
//        cout << endl;
//    }
//    cout << endl;
//    for (int i = 1; i <= m * m; ++i) {
//        cout << ans.a[1][i] << " ";
//    }
//    cout << endl;
    ans = quick_matrix_pow(ans, base, s - m, MOD);
    int out = 0;
    for (int i = m * m - m + 1; i <= m * m; ++i) {
        add(out, ans.a[1][i]);
    }
//    for (int i = 1; i <= m * m; ++i) {
//        cout << ans.a[1][i] << " ";
//    }
//    cout << endl;
    cout << out << endl;
}
int main() {
#ifdef local
    freopen("data.txt", "r", stdin);
//    freopen("data.txt", "w", stdout);
#endif
    IOS;
    work();
    return 0;
}
View Code