牛客Wannafly挑战赛23F 计数(循环卷积+拉格朗日插值/单位根反演)

传送门
直接的想法就是设 (x^k) 为边权,矩阵树定理一波后取出 (x^{nk}) 的系数即可
也就是求出模 (x^k) 意义下的循环卷积的常数项
考虑插值出最后多项式,类比 (DFT) 的方法
假设我们要求

[C_i=sum_{j=0}^{n}sum_{k=0}^{n}A_jB_k[(j+k)~mod~n=i] ]

(A,B,C) 为多项式
我们知道了 (A,B)(n) 个点值

[A(w_n^i)=sum_{k=0}^{n}A_kw_n^{ik} ]

[B(w_n^i)=sum_{k=0}^{n}B_kw_n^{ik} ]

那么

[C(w_n^k)=sum_{i=0}^{n}sum_{j=0}^{n}A_iw_n^{ik}B_jw_n^{jk}=sum_{i=0}^{n}sum_{j=0}^{n}A_iB_jw_n^{k(i+j)} ]

而根据消去引理 (w_n^{k(i+j)}=w_n^{k((i+j)~mod~n)})
所以

[C(w_n^k)=sum_{l=0}^{n}sum_{i=0}^{n}sum_{j=0}^{n}[(i+j)~mod~n=l]A_iB_jw_n^{kl} ]

正好对应了循环卷积,所以只要求得到 (w_n^{k},(k=0...n-1)) 的点值就可以得到最后的多项式了
这道题 (p~mod~k=1) 所以直接用原根就好了,最后插值一下

upd: 其实最后并不需要插值
根据单位根反演

[[k|x]=frac{1}{k}sum_{i=0}^{k-1}omega_{k}^{ix} ]

把多项式的每一项都换成这个东西,得到的值就是要的答案
也就是说直接带入每一个单位根,把矩阵树定理得到的权值加起来最后除去 (k) 就好了

# include <bits/stdc++.h>
using namespace std;
typedef long long ll;

int n, m, k, mod, g, pr[233333], tot, a[105][105];
int xi[105], yi[105];

struct Edge {
	int u, v, w;
} edge[10005];

inline int Pow(ll x, int y) {
	register ll ret = 1;
	for (; y; y >>= 1, x = x * x % mod)
		if (y & 1) ret = ret * x % mod;
	return ret;
}

inline void Inc(int &x, int y) {
	x = x + y >= mod ? x + y - mod : x + y;
}

inline void Getrt() {
	register int x, i, j;
	for (x = mod - 1, i = 2; i * i <= x; ++i)
		if (x % i == 0) {
			pr[++tot] = i;
			while (x % i == 0) x /= i;
		}
	if (x > 1) pr[++tot] = x;
	for (x = mod - 1, i = 2; i <= x; ++i) {
		for (g = i, j = 1; g && j <= tot; ++j)
			if (Pow(g, x / pr[j]) == 1) g = 0;
		if (g) break;
	}
}

inline int Gauss() {
	register int ans = 1, i, j, l, inv;
	for (i = 1; i < n; ++i) {
		for (j = i; j < n; ++j)
			if (a[j][i]) {
				if (i != j) swap(a[i], a[j]), ans = mod - ans;
				break;
			}
		for (j = i + 1; j < n; ++j)
			if (a[j][i]) {
				inv = (ll)a[j][i] * Pow(a[i][i], mod - 2) % mod;
				for (l = i; l < n; ++l) Inc(a[j][l], mod - (ll)a[i][l] * inv % mod);
			}
		ans = (ll)ans * a[i][i] % mod;
	}
	return ans;
}

int main() {
	register int i, j, w, u, v, ans;
	scanf("%d%d%d%d", &n, &m, &k, &mod), Getrt();
	for (i = 1; i <= m; ++i) scanf("%d%d%d", &edge[i].u, &edge[i].v, &edge[i].w);
	xi[0] = 1, xi[1] = Pow(g, (mod - 1) / k);
	for (i = 0; i < k; ++i) {
		if (i > 1) xi[i] = (ll)xi[i - 1] * xi[1] % mod;
		memset(a, 0, sizeof(a));
		for (j = 1; j <= m; ++j) {
			u = edge[j].u, v = edge[j].v, w = Pow(xi[i], edge[j].w);
			Inc(a[u][u], w), Inc(a[v][v], w), Inc(a[u][v], mod - w), Inc(a[v][u], mod - w);
		}
		yi[i] = Gauss();
	}
	for (i = ans = 0; i < k; ++i) {
		for (w = yi[i], j = 0; j < k; ++j)
			if (i ^ j) w = (ll)w * (mod - xi[j]) % mod * Pow((xi[i] + mod - xi[j]) % mod, mod - 2) % mod;
		Inc(ans, w);
	}
	printf("%d
", ans);
	return 0;
}