HDU 6397 Character Encoding (combinatorics + inclusion-exclusion, generating functions)

Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 131072/131072 K (Java/Others)
Total Submission(s): 1473    Accepted Submission(s): 546


Problem Description
In computer science, a character is a letter, a digit, a punctuation mark or some other similar symbol. Since computers can only process numbers, number codes are used to represent characters, which is known as character encoding. A character encoding system establishes a bijection between the elements of an alphabet of a certain size n and the integers 0 to n-1. [...] The task: given n, m and k, count the words of length m in an encoding system of alphabet size n whose encoded numbers sum to k.

Since the answer may be large, you only need to output it modulo 998244353.
 
Input
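The first line of input is a single integer T, the number of test cases. Each test case is a line of three integers n, m, k: the size of the alphabet, the length of the word, and the required sum, respectively.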
 
Output
For each test case, display the answer modulo 998244353 in a single line.
 
Sample Input
4
2 3 3
2 3 4
3 3 3
128 3 340
 
Sample Output
1
0
7
903
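A brute-force reference is handy for checking the formulas below against the samples. This is a minimal sketch; `bruteCount` is a hypothetical helper name, and it enumerates every word, so it is only usable for tiny inputs.

```cpp
#include <cassert>

// Hypothetical brute-force checker: count words x[0..m-1] with
// 0 <= x[i] <= n-1 whose values sum to k, trying every value at every
// position. Roughly n^m leaves, so only for tiny inputs.
long long bruteCount(long long n, int m, long long k) {
    assert(n >= 1 && m >= 0 && k >= 0);
    if (m == 0) return k == 0 ? 1 : 0;   // empty word: only sum 0 is reachable
    long long total = 0;
    // value of the current position; values above the remaining k cannot help
    for (long long v = 0; v <= n - 1 && v <= k; ++v)
        total += bruteCount(n, m - 1, k - v);
    return total;
}
```

Calling it on the four sample lines, e.g. `bruteCount(128, 3, 340)`, should reproduce the sample output.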
 

Inclusion-exclusion approach

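Let xi be the code of the i-th character, so 0 <= xi <= n-1 and x1+x2+...+xm = k. Ignoring the upper bound, x1+x2+...+xm = k (xi >= 0) has C(k+m-1,m-1) solutions (stars and bars).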
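The stars-and-bars count can be checked for small cases by comparing an exact binomial against direct enumeration. A sketch under the assumption that all values fit in `long long`; `binom` and `countNonneg` are hypothetical helper names.

```cpp
#include <cassert>

// Exact binomial coefficient (no modulus), 0 when b < 0 or b > a.
long long binom(long long a, long long b) {
    if (b < 0 || b > a) return 0;
    long long r = 1;
    for (long long i = 1; i <= b; ++i)
        r = r * (a - b + i) / i;   // after step i, r == C(a-b+i, i); division is exact
    return r;
}

// Number of solutions of x1 + ... + xm = k with every xi >= 0 (no upper bound),
// obtained by letting the first variable take each value 0..k.
long long countNonneg(int m, long long k) {
    assert(m >= 0 && k >= 0);
    if (m == 0) return k == 0 ? 1 : 0;
    long long total = 0;
    for (long long v = 0; v <= k; ++v) total += countNonneg(m - 1, k - v);
    return total;
}
```

For example, `countNonneg(3, 4)` and `binom(4 + 3 - 1, 3 - 1)` should agree (both 15).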

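If c of the variables violate the bound (xi >= n), subtract n from each of those c variables; the new variables are still nonnegative: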

x1'+x2'+...+xm' = k-c*n (xi' >= 0), which has C(k-c*n+m-1,m-1) solutions.
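ans = sum over c = 0..min(m, k/n) of (-1)^c * C(m,c) * C(k-c*n+m-1,m-1)

Here (-1)^c is the inclusion-exclusion sign, C(m,c) counts the ways to choose which c variables violate the bound, and C(k-c*n+m-1,m-1) counts the solutions after the subtraction.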
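The inclusion-exclusion sum can be evaluated with exact integers for small inputs, which makes it easy to compare against the samples before switching to modular arithmetic. A sketch; `countIE` and `binom` are hypothetical names, and it assumes n, m >= 1 and small enough values that nothing overflows `long long`.

```cpp
#include <algorithm>
#include <cassert>

// Exact binomial coefficient (no modulus), 0 when b < 0 or b > a.
long long binom(long long a, long long b) {
    if (b < 0 || b > a) return 0;
    long long r = 1;
    for (long long i = 1; i <= b; ++i)
        r = r * (a - b + i) / i;   // exact at every step
    return r;
}

// ans = sum_{c=0}^{min(m, k/n)} (-1)^c * C(m, c) * C(k - c*n + m - 1, m - 1)
long long countIE(long long n, long long m, long long k) {
    assert(n >= 1 && m >= 1 && k >= 0);
    long long ans = 0;
    for (long long c = 0; c <= std::min(m, k / n); ++c) {
        // choose the c violating variables, then stars and bars on the rest
        long long term = binom(m, c) * binom(k - c * n + m - 1, m - 1);
        ans += (c % 2 == 0) ? term : -term;
    }
    return ans;
}
```

For the last sample the three terms are 58311 - 68373 + 10965 = 903.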

Generating-function approach

1+x+x^2+...+x^(n-1)=(1-x^n)/(1-x)

(1+x+x^2+...+x^(n-1))^m

=(1-x^n)^m/(1-x)^m
=(1-x^n)^m*(1-x)^(-m)
=(1-x^n)^m*(sum_{i>=0} C(m+i-1,m-1)*x^i)   // negative binomial expansion of (1-x)^(-m), the "nuclear weapon" from my previous post
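The generating function can also be checked numerically by literally expanding (1+x+...+x^(n-1))^m up to x^k. This sketch multiplies by the factor m times using a sliding-window sum; `coeffByDP` is a hypothetical name, and at O(m*k) it is a checking tool, not the intended solution.

```cpp
#include <cassert>
#include <vector>

// Coefficient of x^k in (1 + x + ... + x^(n-1))^m, modulo 998244353.
// Each multiplication computes new[s] = sum_{v=0}^{n-1} old[s-v],
// keeping only coefficients up to x^k.
long long coeffByDP(long long n, int m, int k) {
    const long long MOD = 998244353;
    assert(n >= 1 && m >= 0 && k >= 0);
    std::vector<long long> f(k + 1, 0);
    f[0] = 1;                                  // the polynomial "1" (m = 0)
    for (int step = 0; step < m; ++step) {
        std::vector<long long> g(k + 1, 0);
        long long window = 0;                  // holds f[s-n+1] + ... + f[s]
        for (int s = 0; s <= k; ++s) {
            window = (window + f[s]) % MOD;
            if (s - n >= 0) window = (window - f[s - n] + MOD) % MOD;
            g[s] = window;
        }
        f.swap(g);
    }
    return f[k];
}
```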

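ans = coefficient of x^k
Expand the left factor with the binomial theorem. Its term x^(n*i) must pair with the term x^(k-n*i) of the right factor, so
ans = sum (-1)^i*C(m,i)*C(m+k-n*i-1,m-1)   (0 <= i <= min(m, k/n))
which is exactly the inclusion-exclusion formula.

left factor:  term x^(n*i),   coefficient (-1)^i*C(m,i)
right factor: term x^(k-n*i), coefficient C(m+k-n*i-1,m-1)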
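As a sanity check of the formula, plug in the third sample (n = 3, m = 3, k = 3), where i runs over 0..min(m, k/n) = 0..1:

```latex
[x^3]\,(1+x+x^2)^3 = \sum_{i=0}^{1} (-1)^i \binom{3}{i} \binom{3+3-3i-1}{3-1}
= \binom{3}{0}\binom{5}{2} - \binom{3}{1}\binom{2}{2} = 10 - 3 = 7
```

This matches the expected output 7.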

AC code

#include <bits/stdc++.h>
#define mp make_pair
#define pb push_back
#define fi first
#define se second
#define all(a) (a).begin(), (a).end()
#define fillchar(a, x) memset(a, x, sizeof(a))
#define huan printf("\n");
#define debug(a,b) cout<<a<<" "<<b<<" "<<endl;
using namespace std;
const int maxn= 3e5+10;
const int inf = 0x3f3f3f3f,mod=998244353;
typedef long long ll;
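ll fac[maxn],inv[maxn];   // fac[i] = i! mod p; inv[i] ends up as (i!)^(-1) mod p
void init()
{
    fac[0]=fac[1]=1;
    inv[0]=inv[1]=1;
    for(ll i=2;i<maxn;i++)
    {
        fac[i]=fac[i-1]*i%mod;
        // linear-time modular inverse of i: inv[i] = -(p/i) * inv[p%i] mod p
        inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
    }
    // prefix products turn the inverses of 2..i into the inverse of i!
    for(ll i=2;i<maxn;i++)
        inv[i]=inv[i-1]*inv[i]%mod;
}
// C(x,y) mod p, 0 when y > x
ll C(ll x,ll y)
{
    if(y>x) return 0;
    if(y==0||x==0) return 1;
    return fac[x]*inv[y]%mod*inv[x-y]%mod;
}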
int main()
{
    ll n,m,k,t;
    init();
    cin>>t;
    while(t--)
    {
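        cin>>n>>m>>k;
        if(k==0)                 // only the all-zero word
        {
            cout<<1<<endl;
            continue;
        }
        else if((n-1)*m<k)       // even the largest possible sum (n-1)*m is below k
        {
            cout<<0<<endl;
            continue;
        }
        int c=min(k/n,m);        // at most min(k/n, m) variables can be >= n
        ll ans=0;
        for(int i=0;i<=c;i++)    // inclusion-exclusion over i violating variables
        {
            ll term=C(m,i)*C(k-i*n+m-1,m-1)%mod;
            if(i%2==0)
                ans=(ans+term)%mod;
            else
                ans=(ans-term+mod)%mod;
        }
        cout<<ans<<endl;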
    }
}
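An alternative to the linear-inverse precomputation in init() is to compute the inverse of N! once with Fermat's little theorem (a^(p-2) mod p, valid because 998244353 is prime) and then walk downward with invFac[i-1] = invFac[i] * i. This is a swapped-in technique, not the original code's; `powMod`, `Binomial` and `solve` are hypothetical names, and the table size N must be at least k+m-1.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

const long long MOD = 998244353;

// Fast exponentiation: base^exp mod MOD.
long long powMod(long long base, long long exp) {
    long long result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

// Factorials and inverse factorials up to N.
struct Binomial {
    std::vector<long long> fac, invFac;
    explicit Binomial(int N) : fac(N + 1), invFac(N + 1) {
        assert(N >= 0);
        fac[0] = 1;
        for (int i = 1; i <= N; ++i) fac[i] = fac[i - 1] * i % MOD;
        invFac[N] = powMod(fac[N], MOD - 2);            // Fermat: (N!)^(p-2) = (N!)^(-1)
        for (int i = N; i >= 1; --i) invFac[i - 1] = invFac[i] * i % MOD;
    }
    // C(x, y) mod MOD, 0 outside 0 <= y <= x; requires x <= N
    long long C(long long x, long long y) const {
        if (y < 0 || y > x) return 0;
        return fac[x] * invFac[y] % MOD * invFac[x - y] % MOD;
    }
};

// Same inclusion-exclusion sum as main(), written against Binomial.
long long solve(const Binomial& b, long long n, long long m, long long k) {
    long long ans = 0;
    for (long long i = 0; i <= std::min(m, k / n); ++i) {
        long long term = b.C(m, i) * b.C(k - i * n + m - 1, m - 1) % MOD;
        ans = (i % 2 == 0) ? (ans + term) % MOD : (ans - term + MOD) % MOD;
    }
    return ans;
}
```

Both precomputations are O(N) apart from the single O(log p) exponentiation; the Fermat version keeps factorials and inverse factorials in separately named arrays, which may be easier to read.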