LibreOJ 2320 「清华集训 2017」生成树计数

Description

有一个图有\(n\)个树形连通块,第\(i\)个有\(a_i\)个点。你需要再连\(n - 1\)条边使得这个图变成一个树。

对于一个方案\(T\),我们记\(d_i\)表示第\(i\)个连通块向外连了多少边。那么我们称其价值为: \[ \mathrm{val}(T) = (\sum_{i = 1}^n d_i^m)(\prod_{i = 1}^n d_i^m) \] 其中\(m\)为给定常数。求所有可能的树的价值之和。

\(n\leq 30000, m\leq 30, a_i\in Z_{998244353}\)

Solution

辣鸡卡常题毁我青春.jpg

先化简一下式子: \[ \begin{aligned} \quad&\sum_{T} (\sum_{i = 1}^nd_i^m)(\prod_{i = 1}^nd_i^m)\\ =&\sum_{T}\sum_{i = 1}^n d_i^{2m}\prod_{j\neq i} d_j^m \end{aligned} \] 观察到答案只与每个连通块的度数有关,很容易想到Prufer序列。把每个连通块看成一个点,我们知道如果一个点度数为\(d_i\),那么它在Prufer序列中恰好出现\(d_i - 1\)次;另外,它连出去的每条边都要在块内的\(a_i\)个点中选一个端点,因此还要乘上\(a_i^{d_i}\)。这样问题就转化成了一个带权的排列计数问题……这正是指数型生成函数的用武之地。

那么我们很容易想到定义两类生成函数(下面\(k = d_i - 1\)): \[ \begin{aligned} A_i(x) &= \sum_{k = 0}^{+\infty}\frac{(k + 1)^{2m}a_i^{k + 1}x^k}{k!}\\ B_i(x) &= \sum_{k = 0}^{+\infty}\frac{(k + 1)^{m}a_i^{k + 1}x^k}{k!} \end{aligned} \] 很显然答案就是\((n - 2)!\,[x^{n - 2}]\sum_{i = 1}^n A_i(x)\prod_{j\neq i} B_j(x)\)。然后下面先考虑一下那个\(A_i(x)\)……注意到它的系数全都和\(k + 1\)有关,但是\(x\)的次数为\(k\),因此用积分将其右移(记\(C_i(x) = \int A_i(x)\mathrm{d}x\);第二行补上的\(k = 0\)项只是常数,求导后会消失): \[ \begin{aligned} C_i(x) &= \sum_{k = 1}^{+\infty} \frac{k^{2m} a_i^k x^k}{k!}\\ &=\sum_{k = 0}^{+\infty} \frac{k^{2m} a_i^k x^k}{k!}\\ &=\sum_{k = 0}^{+\infty}\frac{(a_ix)^k}{k!}\sum_{j = 0}^{2m}\begin{Bmatrix}2m\\ j\end{Bmatrix}k^{\underline{j}}\\ &=\sum_{j = 0}^{2m}\begin{Bmatrix}2m\\ j\end{Bmatrix}j!\sum_{k = 0}^{+\infty}\frac{(a_ix)^k}{k!}\binom{k}{j}\\ &=\sum_{j = 0}^{2m}\begin{Bmatrix}2m\\ j\end{Bmatrix}j!\cdot\frac{a_i^jx^j}{j!}e^{a_ix}\\ &=\sum_{j = 0}^{2m}\begin{Bmatrix}2m\\ j\end{Bmatrix}a_i^jx^je^{a_ix} \end{aligned} \] 然后再推回去: \[ \begin{aligned} A_i(x) &= C_i'(x)\\ &= \sum_{j = 0}^{2m}\begin{Bmatrix}2m\\ j\end{Bmatrix}a_i^j(jx^{j - 1}e^{a_ix} + a_ix^je^{a_ix})\\ &= e^{a_ix}\sum_{j = 0}^{2m}\left(\begin{Bmatrix}2m\\ j + 1\end{Bmatrix}(j + 1)a_i^{j + 1}x^j + \begin{Bmatrix}2m\\ j\end{Bmatrix}a_i^{j + 1}x^j\right)\\ &= e^{a_ix}\sum_{j = 0}^{2m}a_i^{j + 1}\begin{Bmatrix}2m + 1\\ j + 1\end{Bmatrix}x^j \end{aligned} \] 最后一步用了第二类斯特林数的递推公式\(\begin{Bmatrix}n + 1\\ k\end{Bmatrix} = k\begin{Bmatrix}n\\ k\end{Bmatrix} + \begin{Bmatrix}n\\ k - 1\end{Bmatrix}\)……类似地我们可以得到\(B_i(x)=e^{a_ix}\sum_{j = 0}^{m}a_i^{j + 1}\begin{Bmatrix}m + 1\\ j + 1\end{Bmatrix}x^j\)

然后我们把所有\(e^{a_ix}\)因子都提出来,它们的乘积是\(e^{sx}\)(其中\(s = \sum_{i = 1}^n a_i\)),剩下的部分就全都是次数不超过\(2m\)的多项式了……

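最后我们发现,要求的是\(\sum_{i = 1}^n A_i(x)\prod_{j\neq i}B_j(x)\)去掉因子\(e^{sx}\)(\(s = \sum_{i = 1}^n a_i\))之后的多项式部分,分治FFT(NTT)一下即可:对每个区间同时维护\(\prod B_j\)和\(\sum_i A_i\prod_{j\neq i}B_j\)两个多项式,合并时做几次乘法;由于最终只需要\(x^{n - 2}\)及以下的系数,每次合并后都可以截断到\(n - 2\)次。记最终得到的多项式为\(P(x)\),则答案为\((n - 2)!\sum_{k = 0}^{n - 2}\frac{s^k}{k!}[x^{n - 2 - k}]P(x)\)。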

注意常数啊……这玩意卡常剧毒。讲个笑话,我ban掉NTT的循环展开之后卡过了最后一个点……

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <functional>
#include <memory>
#include <utility>
#include <vector>
using ll = long long;
constexpr ll ha = 998244353LL;  // NTT-friendly prime modulus used throughout

// Binary exponentiation: returns a^b mod ha (intended for b >= 0).
inline ll pow_mod(ll a, ll b) {
    ll result = 1;
    for (ll base = a; b != 0; b >>= 1) {
        if (b & 1LL) {
            result = result * base % ha;
        }
        base = base * base % ha;
    }
    return result;
}

// Modular inverse via Fermat's little theorem (ha is prime); x must be nonzero mod ha.
inline ll inv(ll x) {
    return pow_mod(x, ha - 2LL);
}

const int maxn = 30005;
ll fac[maxn], ifac[maxn];

// Fills fac[k] = k! and ifac[k] = (k!)^{-1} modulo ha for every 0 <= k <= 30000.
// The inverses are produced downward from a single modular inverse of 30000!.
inline void process_fac() {
    const int top = 30000;
    fac[0] = 1;
    for (int k = 1; k <= top; ++k) {
        fac[k] = fac[k - 1] * k % ha;
    }
    ifac[top] = inv(fac[top]);
    for (int k = top; k >= 1; --k) {
        // 1/(k-1)! = k * (1/k!)
        ifac[k - 1] = ifac[k] * k % ha;
    }
}
// S[i][j] = Stirling number of the second kind {i brace j} modulo ha.
// Filled for 0 <= j <= i <= 62 (div_con reads rows up to 2m + 1, i.e. m <= 30);
// all other entries keep their zero static initialisation.
ll S[105][105];
inline void process_S() {
    const int n = 62;
    S[0][0] = 1;
    for(int i = 1; i <= n; i ++) {
        S[i][1] = S[i][i] = 1;
        for(int j = 2; j < i; j ++) {
            // {i, j} = j * {i-1, j} + {i-1, j-1}
            S[i][j] = S[i - 1][j] * (ll)j % ha + S[i - 1][j - 1];
            // >= (not >) so a sum equal to ha is reduced to 0 and entries stay in [0, ha).
            if(S[i][j] >= ha) S[i][j] -= ha;
        }
    }
}

// Returns the low `bi` bits of x in reversed order (bit k moves to bit bi-1-k);
// bits of x at positions >= bi are ignored. Not called elsewhere in this file.
inline int flip(int bi, int x) {
    int reversed = 0;
    for (int bit = 0; bit < bi; ++bit) {
        reversed = (reversed << 1) | ((x >> bit) & 1);
    }
    return reversed;
}
// In-place iterative radix-2 number-theoretic transform modulo ha (primitive root 3).
//   A    : array of exactly 2^bi coefficients; callers here pass values in [0, ha).
//   bi   : log2 of the transform length.
//   flag : false = forward transform; true = inverse transform, including the 1/2^bi scaling.
// Outputs are kept in [0, ha).
inline void NTT(ll *A, int bi, bool flag = false) {
    int n = 1 << bi;
    // Bit-reversal permutation: j tracks the bit-reversed counterpart of i.
    for(int i = 1, j = 0; i < n; i ++) {
        int k = n;
        do j ^= (k >>= 1); while ((j & k) == 0);
        if(i < j) std::swap(A[j], A[i]);
    }
    for(int L = 1; L < n; L <<= 1) {
        // xi_n is a primitive (2L)-th root of unity; its inverse for the inverse transform.
        ll c = (ha - 1LL) / (ll(L << 1));
        if(flag) c = ha - 1LL - c;
        ll xi_n = pow_mod(3LL, c);
        ll *B = A + L;
        for(int i = 0; i < n; i += (L << 1)) {
            ll xi = 1;
            for(int j = i; j < i + L; j ++) {
                ll x = A[j], y = B[j] * xi % ha;
                A[j] = x + y;      if(A[j] >= ha) A[j] -= ha;
                B[j] = x - y + ha; if(B[j] >= ha) B[j] -= ha;
                xi = xi * xi_n % ha;
            }
        }
    }
    if(flag) {
        static const ll inv_2 = inv(2);
        ll inv_n = 1; int cnt = bi;
        while(cnt --) inv_n = inv_n * inv_2 % ha;
        for(int i = 0; i < n; i ++) {
            A[i] = A[i] * inv_n % ha;
        }
    }
}

const int bufsiz = 200 * 1024 * 1024;
// Backing arena for alloc(); aligned so every returned pointer is suitably aligned.
alignas(std::max_align_t) char buf[bufsiz];

// Bump allocator: returns `size` bytes from `buf`, falling back to malloc() once the
// arena is exhausted. Memory is never released. Every request is rounded up to
// alignof(std::max_align_t) so each returned pointer is aligned for any scalar type.
inline void *alloc(size_t size) {
    static char *cur = buf;
    constexpr size_t align = alignof(std::max_align_t);
    size = (size + align - 1) / align * align;
    if(static_cast<size_t>(cur - buf) + size > static_cast<size_t>(bufsiz)) {
        return malloc(size);
    }
    char *ret = cur;
    cur += size;
    return ret;
}

struct Poly {
ll *A; int n;
};
inline Poly *new_poly(int len) {
Poly *ret = (Poly*)alloc(sizeof(Poly));
ret -> n = len;
ret -> A = (ll*)alloc(sizeof(Poly) * (len + 1));
return ret;
}
using pii = std::pair<Poly*, Poly*>;
ll a[maxn]; int n, m;

// Divide and conquer over the blocks L..R (inclusive, 1-based indices into a[]).
// For each block i, with the e^{a_i x} factor removed, define
//   f_i(x) = sum_{k=0}^{m}  a_i^{k+1} S(m+1,  k+1) x^k
//   g_i(x) = sum_{k=0}^{2m} a_i^{k+1} S(2m+1, k+1) x^k
// Returns (first, second) with
//   first  = prod_{i=L..R} f_i(x)
//   second = sum_{i=L..R} g_i(x) * prod_{j=L..R, j != i} f_j(x)
// Results of merge steps are truncated to degree n - 2, since main() reads only up to x^{n-2}.
// Polynomials are allocated with new_poly() and never freed.
pii div_con(int L, int R) {
    if(L == R) {
        Poly *A = new_poly(m), *B = new_poly(2 * m);
        ll pw = a[L];  // a_L^{i+1}
        for(int i = 0; i <= (m << 1); i ++) {
            if(i <= m) A -> A[i] = pw * S[m + 1][i + 1] % ha;
            B -> A[i] = pw * S[m << 1 | 1][i + 1] % ha;
            pw = pw * a[L] % ha;
        }
        return std::make_pair(A, B);
    }
    // Scratch NTT buffers shared by every call. Sharing is safe because both recursive
    // calls below return before this level writes to any of them.
    static ll t1[maxn << 2], t2[maxn << 2], t3[maxn << 2], t4[maxn << 2], t5[maxn << 2];
    int M = (L + R) / 2;
    pii tmp_l = div_con(L, M), tmp_r = div_con(M + 1, R);
    Poly *la = tmp_l.first, *lb = tmp_l.second;
    Poly *ra = tmp_r.first, *rb = tmp_r.second;
    int nb_n = la -> n + ra -> n;                              // degree of la * ra
    int nn = std::max(la -> n + rb -> n, lb -> n + ra -> n);  // degree of la*rb + lb*ra
    int len = 1, bi = 0;
    while(len <= std::max(nb_n, nn)) {
        len <<= 1; bi ++;
    }

    // Transform la and ra once; t1/t2 are reused for the second product below.
    std::copy(la -> A, (la -> A) + la -> n + 1, t1);
    std::fill(t1 + la -> n + 1, t1 + len, 0LL);
    std::copy(ra -> A, (ra -> A) + ra -> n + 1, t2);
    std::fill(t2 + ra -> n + 1, t2 + len, 0LL);
    NTT(t1, bi); NTT(t2, bi);

    // first = la * ra (pointwise product, manually unrolled by 4 when len allows).
    if(len < 4) {
        for(int i = 0; i < len; i ++) t5[i] = t1[i] * t2[i] % ha;
    } else {
        for(int i = 0; i < len; i += 4) {
            t5[i] = t1[i] * t2[i] % ha;
            t5[i + 1] = t1[i + 1] * t2[i + 1] % ha;
            t5[i + 2] = t1[i + 2] * t2[i + 2] % ha;
            t5[i + 3] = t1[i + 3] * t2[i + 3] % ha;
        }
    }
    NTT(t5, bi, true);
    nb_n = std::min(nb_n, n - 2);
    Poly *A = new_poly(nb_n); std::copy(t5, t5 + nb_n + 1, A -> A);

    // second = la * rb + lb * ra, summed in the frequency domain so only one inverse NTT is needed.
    std::copy(lb -> A, (lb -> A) + lb -> n + 1, t3);
    std::fill(t3 + lb -> n + 1, t3 + len, 0LL);
    std::copy(rb -> A, (rb -> A) + rb -> n + 1, t4);
    std::fill(t4 + rb -> n + 1, t4 + len, 0LL);
    NTT(t3, bi); NTT(t4, bi);
    for(int i = 0; i < len; i ++) {
        t1[i] = t1[i] * t4[i] % ha;
        t2[i] = t2[i] * t3[i] % ha;
        t1[i] += t2[i];
        if(t1[i] >= ha) t1[i] -= ha;  // >= keeps the value in [0, ha)
    }
    NTT(t1, bi, true);
    nn = std::min(nn, n - 2);
    Poly *B = new_poly(nn); std::copy(t1, t1 + nn + 1, B -> A);
    return std::make_pair(A, B);
}

// Reads n, m and the block sizes a_1..a_n, then prints the summed value over all trees mod ha.
// With s = sum a_i and Q = div_con(1, n).second, the answer is
//   (n-2)! * [x^{n-2}] e^{s x} Q(x) = (n-2)! * sum_{i=0}^{n-2} s^i / i! * Q[n-2-i].
int main() {
    process_fac(); process_S();
    if(scanf("%d%d", &n, &m) != 2) return 1;
    // n must be positive (div_con(1, 0) never terminates) and within the fac/ifac tables.
    if(n < 1 || n > 30000) return 1;
    ll s = 0;
    for(int i = 1; i <= n; i ++) {
        if(scanf("%lld", &a[i]) != 1) return 1;
        a[i] %= ha;  // keep weights reduced so later products cannot overflow
        s += a[i];
        if(s >= ha) s -= ha;
    }
    if(n == 1) {
        // No edge is added, so d_1 = 0 and val = 0^m * 0^m (with 0^0 = 1): 0 for m >= 1, 1 for m = 0.
        // Handled separately because the general formula would read fac[n - 2] = fac[-1].
        printf("%d\n", m == 0 ? 1 : 0);
        return 0;
    }
    Poly *B = div_con(1, n).second;
    ll ans = 0;
    ll pw = 1LL;  // s^i
    for(int i = 0; i <= n - 2; i ++) {
        const int k = n - 2 - i;
        // Q can have degree below n - 2 (e.g. m == 0); missing coefficients are zero.
        if(k <= B -> n) {
            ll delta = pw * ifac[i] % ha;
            delta = delta * (B -> A[k]) % ha;
            ans += delta;
            if(ans >= ha) ans -= ha;
        }
        pw = pw * s % ha;
    }
    ans = ans * fac[n - 2] % ha;
    printf("%lld\n", ans);
    return 0;
}