Loj 6485. LJJ 学二项式定理 Loj 6485. LJJ 学二项式定理
题目描述
LJJ 学完了二项式定理,发现这太简单了,于是他将二项式定理等号右边的式子修改了一下,代入了一定的值,并算出了答案。
但人口算毕竟会失误,他请来了你,让你求出这个答案来验证一下。
一共有 $ T $ 组数据,每组数据如下:
输入以下变量的值:$ n, s , a_0 , a_1 , a_2 , a_3$,求以下式子的值:
[Large left[ sum_{i=0}^n left( {nchoose i} cdot s^{i} cdot a_{imod 4}
ight)
ight] mod 998244353
]
其中 $ nchoose i $ 表示 $ frac{n!}{i!(n-i)!} $。
输入格式
第一行一个整数 (T),之后 (T) 行,一行六个整数 (n, s, a_0, a_1, a_2, a_3)。
输出格式
一共 (T) 行,每行一个整数表示答案。
样例
样例输入
6
1 2 3 4 5 6
2 3 4 5 6 1
3 4 5 6 1 2
4 5 6 1 2 3
5 6 1 2 3 4
6 1 2 3 4 5
样例输出
11
88
253
5576
31813
232
数据范围与提示
对于 $ 50% $ 的数据,$ T imes n leq 10^5 $;
对于 $ 100% $ 的数据,$ 1 leq T leq 10^5, 1 leq n leq 10 ^ {18}, 1 leq s, a_0, a_1, a_2, a_3 leq 10^{8} $。
(\)
前置知识:单位根反演
我们考虑对每个(d=0...3)计算
[Ans_d=left[ sum_{i=0}^n[i\%4==d] left( {nchoose i} cdot s^{i} cdot a_d
ight)
ight] mod 998244353
]
答案就是
[Ans=sum_{d=0}^3Ans_d
]
我们交换一下求和顺序:
[egin{align}
Ans_d=a_dsum_{i=0}^ninom{n}{i}s^i[i\%4==d]\
=a_dsum_{i=0}^ninom{n}{i}s^i[4|(i-d)]\
end{align}
]
直接套单位根反演的套路:
[[k|n]=sum_{i=0}^{k-1}(omega_k^n)^i\
Longrightarrow [4|(i-d)]=frac{1}{4}sum_{j=0}^3 (omega_4^{i-d})^j
]
再带回去:
[egin{align}
Ans_d&=a_dsum_{i=0}^ninom{n}{i}s^i[4|(i-d)]\
&=a_dsum_{i=0}^ninom{n}{i}s^ifrac{1}{4}sum_{j=0}^3(omega _4^{i-d})^j\
end{align}
]
这里(i)直接从(0)开始枚举是没有问题的,因为即使(i-d)为负一样满足等比数列求和。
在根据套路交换求和符号:
[egin{align}
Ans_d&=a_dfrac{1}{4}sum_{j=0}^3sum_{i=0}^ninom{n}{i}s^i(omega _4^{ij-dj})\
&=a_dfrac{1}{4}sum_{j=0}^3frac{1}{omega_4 ^{dj}} sum_{i=0}^ninom{n}{i}s^i(omega _4^{j})^i\
end{align}
]
我们设:
[f_n(x)=sum_{i=0}^ninom{n}{i}s^ix^i\
=(sx+1)^n
]
则:
[egin{align}
Ans_d&=a_dfrac{1}{4}sum_{j=0}^3frac{1}{omega_4 ^{dj}} sum_{i=0}^n
inom{n}{i}s^i(omega _4^{j})^i\
&=a_dfrac{1}{4}sum_{j=0}^3frac{1}{omega_4 ^{dj}} f(omega_4^j)\
&=a_dfrac{1}{4}sum_{j=0}^3frac{1}{omega_4 ^{dj}} (scdotomega_4^j+1)^n\
end{align}
]
模质数(p)意义下(omega_4^1)可以取(g^{frac{p-1}{4}})其中(g)是原根。
代码:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
inline ll Get() {ll x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}
const ll mod=998244353,g=3;
ll ksm(ll t,ll x) {
ll ans=1;
for(;x;x>>=1,t=t*t%mod)
if(x&1) ans=ans*t%mod;
return ans;
}
ll n,s,a[4];
ll w;
const ll inv4=ksm(4,mod-2);
int main() {
w=ksm(g,(mod-1)>>2);
int T=Get();
while(T--) {
n=Get(),s=Get();
for(int i=0;i<4;i++) a[i]=Get();
ll ans=0;
for(int d=0;d<4;d++) {
ll now=0;
for(int j=0;j<4;j++) {
(now+=ksm(ksm(w,d*j),mod-2)*ksm((s*ksm(w,j)+1)%mod,n))%=mod;
}
(ans+=now*a[d])%=mod;
}
cout<<ans*inv4%mod<<"
";
}
return 0;
}