bzoj 2142: 礼物【中国剩余定理+组合数学】

参考:http://blog.csdn.net/wzq_qwq/article/details/46709471
首先推组合数,设sum为每个人礼物数的和,那么答案为

[( C_{n}^{sum}C_{sum}^{w[1]}c_{sum-w[1]}^{w[2]}... ]

设w[0]=n-sum,然后化简成阶乘的形式:

[frac{n!}{w[0]!w[1]!...w[n]!} ]

注意到这里p不是质数,所以把p拆成质数的方相乘的形式,最后用中国剩余定理合并即可
然后现在的问题是怎么快速求出阶乘
假设当前的质数的方为p=3那么1x2x3x4x5x6x7x8x9x10x11=1x2x4x5x7x8x10x11x 3x(1x2x3),注意到后面又是一个阶乘,但是范围更小,所以可以递归来做,然后前面乘的3被模消去了

#include<iostream>
#include<cstdio>
using namespace std;
const int N=100005;
long long P,n,m,w[10],p[N],cnt[N],mod[N],tot,sum,a[N];
struct qwe
{
	int a,b;
};
void exgcd(long long a,long long b,long long &x,long long &y,long long &d)
{
	if(!b)
	{
		x=1;
		y=0;
		d=a;
		return;
	}
	exgcd(b,a%b,y,x,d);
	y=y-a/b*x;
}
long long china()
{
	long long d,x=0,y;
	for(int i=1;i<=tot;i++)
	{
		long long r=P/mod[i];
		exgcd(mod[i],r,d,y,d);
		x=(x+r*y*a[i])%P;
	}
	return (x+P)%P;
}
long long ksm(long long a,long long b,long long mod)
{
	long long r=1ll;
	while(b)
	{
		if(b&1)
			r=r*a%mod;
		a=a*a%mod;
		b>>=1;
	}
	return r;
}
long long inv(long long a,long long b)
{
	long long x,y,d;
	exgcd(a,b,x,y,d);
	return (x%b+b)%b;
}
qwe fac(long long k,long long n)
{
	qwe r;
	if(!n)
	{
		r.a=0,r.b=1;
		return r;
	}
	long long x=n/p[k],y=n/mod[k],ans=1ll;
	if(y)
	{
		for(int i=2;i<mod[k];i++)
			if(i%p[k]!=0)
				ans=ans*i%mod[k];
		ans=ksm(ans,y,mod[k]);
	}
	for(int i=y*mod[k]+1;i<=n;i++)
		if(i%p[k]!=0)
			ans=ans*i%mod[k];
	qwe tmp=fac(k,x);
	r.a=x+tmp.a,r.b=ans*tmp.b%P;
	return r;
}
long long clc(int k,long long n,long long m)
{
	if(n<m)
		return 0;
	qwe a=fac(k,n),b=fac(k,m),c=fac(k,n-m);
	return ksm(p[k],a.a-b.a-c.a,mod[k])*a.b%mod[k]*inv(b.b,mod[k])%mod[k]*inv(c.b,mod[k])%mod[k];
}
long long wk(long long n,long long m)
{
	for(int i=1;i<=tot;i++)
		a[i]=clc(i,n,m);
	return china();
}
int main()
{
	scanf("%lld%lld%lld",&P,&n,&m);
	for(int i=1;i<=m;i++)
		scanf("%lld",&w[i]),sum+=w[i];
	int x=P;
	for(int i=2;i*i<=x;i++)
		if(x%i==0)
		{
			p[++tot]=i;
			mod[tot]=1;
			while(x%i==0)
			{
				x/=i;
				cnt[tot]++;
				mod[tot]*=i;
			}
		}
	if(x>1)
	{
		p[++tot]=x;
		mod[tot]=x;
		cnt[tot]=1;
	}
	if(sum>n)
	{
		puts("Impossible");
		return 0;
	}
	long long ans=wk(n,sum)%P;
	for(int i=1;i<=m;i++)
	{
		ans=ans*wk(sum,w[i])%P;
		sum-=w[i];
	}
	printf("%lld
",ans);
	return 0;
}