雅礼集训2019Day1 T3—Math(矩阵快速幂优化dp)

描述

给出 nnmmxx,你需要求出下列式子的值:

(imki)=nimsin(kix)sum_{(sum_{i≤m}{k_i})=n}prod_{i≤m}sin(k_i∗x)
其中kik_i为正整数,由于答案非常大,你只需要输出答案(保证不为 0)的正负(如果是负数输出负号,否则输出正号)和从左往右第一个非 0 数位上的数字即可。


Solution

设?(?, ?)表示 k 的序列大小为 m,k 之和为 n 时的答案。
我们对kmk_m进行分类讨论来得到?(?, ?)的转移:

  1. kmk_m=1,则?(? − 1, ? − 1) ∗ sin(?) −→ ?(?, ?)
  2. kmk_m> 1,考虑对???(? ∗ ?)进行下列变换:
    ???(? ∗ ?) = ???(?) ∗ ???((? − 1) ∗ ?) + ???(?) ∗ ???((? − 1) ∗ ?)
    又???(?) ∗ ???(? ∗ ?) = ???(?) ∗ (???(?) ∗ ???((? − 1) ∗ ?) − ???(?) ∗ ???((? − 1) ∗ ?))
    = ???(?) ∗ ???(?) ∗ ???((? − 1) ∗ ?) + (???2
    (?) − 1) ∗ ???((? − 1) ∗ ?)
    = ???(?) ∗ (???(?) ∗ ???((? − 1) ∗ ?) + ???(?) ∗ ???((? − 1) ∗ ?)) − ???((? − 1) ∗ ?)
    = ???(?) ∗ ???(? ∗ ?) − ???((? − 1) ∗ ?)
    所以???(? ∗ ?) = 2 ???(?) ∗ ???((? − 1) ∗ ?) − ???((? − 2) ∗ ?)
    则这部分的贡献为:
    2?(? − 1, ?) ∗ ???(?) − ?(? − 2, ?)longrightarrow ?(?, ?)
    显然可以用矩阵乘法优化这个过程。
    所以复杂度为O(m3logn)O(m^3logn)

#include<bits/stdc++.h>
using namespace std;
const int RLEN=1<<20|1;
#define ll long long
inline char gc(){
	static char ibuf[RLEN],*ib,*ob;
	(ib==ob)&&(ob=(ib=ibuf)+fread(ibuf,1,RLEN,stdin));
	return (ib==ob)?EOF:*ib++;
}
inline int read(){
	char ch=gc();
	int res=0,f=1;
	while(!isdigit(ch))f^=ch=='-',ch=gc();
	while(isdigit(ch))res=(res+(res<<2)<<1)+(ch^48),ch=gc();
	return f?res:-res;
}
const int N=62;
int siz;
struct mat{
	double a[N][N];
	mat(){memset(a,0,sizeof(a));}
	inline const double *const operator [](const int &oset)const{return a[oset];}
	inline double *const operator [](const int &oset){return a[oset];}
	friend inline mat operator *(const mat &a,const mat &b){
		mat c;
		for(int i=1;i<=siz;i++)
			for(int j=1;j<=siz;j++)
				for(int k=1;k<=siz;k++)	
					c[i][k]+=a[i][j]*b[j][k];
		return c;
	}
}A,B;
inline mat ksm(mat a,int b,mat res){
	for(;b;b>>=1,a=a*a)if(b&1)res=res*a;
	return res;
} 
int T;
int n,m;
double x;
int main(){
	cin>>T;
	while(T--){
		cin>>m>>n>>x;siz=m*2;
		A=mat(),B=mat();
		for(int i=m+1;i<=m*2;i++){
			A[i][i-m]=1;
			A[i-m][i]=-1;
			A[i][i]=2*cos(x);
			if(i<m*2)A[i][i+1]=sin(x);
		}
		B[1][1]=sin(x);
		B[1][m+1]=sin(2*x);
		B[1][m+2]=sin(x)*sin(x);
		B=ksm(A,n-1,B);
		double res=B[1][m];
		if(res<0)putchar('-'),res=-res;
		else putchar('+');
		while(res<1)res*=10;
		while(res>=10)res/=10;
		putchar((int)(floor(res))^48);puts("");
	}
}