LibreOJ 2537 「PKUWC2018」Minimax

Description

有一个\(n\)个点的以\(1\)为根的树,每个点至多有两个儿子

对于每个叶子,他的权值事确定的,且两两不同。对于一个非叶子结点\(i\),他有一个系数\(p_i(0<p_i<1)\),满足有\(p_i\)的概率该点权值为子节点权值最大值,其他情况下为最小值。

那么我们会发现每种权值根节点都可能取到,假设有\(m\)种权值,第\(i\)小的权值为\(v_i\)\(v_i\)成为根节点权值的概率为\(d_i\)。那么求: \[ \sum_{i = 1}^m i\cdot v_i\cdot d_i^2 \] \(1\leq n\leq 300000, 0\leq v_i\leq 10^9\)

Solution

从PKUWC开始,苦思冥想了很久,然后看到了一句每个点至多有两个儿子。

我想脱口而出一句μαλάκας。

考虑使用线段树来存储一个点所有可能的权值的概率。那么叶子很好处理,只有一个儿子的点只需要用儿子的树就行了,唯一麻烦的就是有两个儿子的点。

这个东西很显然可以线段树合并吧……注意到右边树第\(i\)小的概率\(p_i\)会发生如下变化: \[ p_i = p_i'(pS_{l, 1\ldots i - 1} + (1 - p)S_{l, i+1\ldots n}) \] 这里\(S_{l,a\ldots b}\)代表左边树本来\([a, b]\)这一段的和,\(p_i'\)代表变换前的\(p_i\)\(p\)就是这一个结点取最大的概率。

因此线段树合并的时候维护一下两边树分别在当前合并结点两侧的和就行了……

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <algorithm>
#include <functional>
#include <utility>
#include <vector>
using ll = long long;
const ll ha = 998244353LL;
ll pow_mod(ll a, ll b) {
ll ans = 1, res = a;
while(b) {
if(1LL & b) ans = ans * res % ha;
res = res * res % ha; b >>= 1;
}
return ans;
}
ll inv(ll x) {
return pow_mod(x, ha - 2LL);
}
const int bufsiz = 50 * 1024 * 1024;
char buf[bufsiz];
void *alloc(size_t size) {
static char *cur = buf;
if(cur - buf + size > bufsiz) {
return malloc(size);
} else {
char *ret = cur; cur += size;
return ret;
}
}

struct Node {
Node *lc, *rc;
ll s, mulv;
void maintain();
void paint(ll v);
void pushdown();
};
Node *nil;
void Node::maintain() {
s = (lc -> s + rc -> s) % ha;
}
void Node::paint(ll v) {
if(this != nil) {
mulv = mulv * v % ha;
s = s * v % ha;
}
}
void Node::pushdown() {
if(mulv != 1LL) {
lc -> paint(mulv);
rc -> paint(mulv);
mulv = 1LL;
}
}

void init_pool() {
nil = (Node*)alloc(sizeof(Node));
nil -> lc = nil -> rc = nil;
nil -> s = 0; nil -> mulv = 1;
}
Node *alloc_node(ll s = 1, Node *lc = nil, Node *rc = nil) {
Node *ret = (Node*)alloc(sizeof(Node));
ret -> s = s; ret -> mulv = 1;
ret -> lc = lc; ret -> rc = rc;
return ret;
}

Node *modify(Node *o, int L, int R, int p, ll v) {
if(o == nil) o = alloc_node(v, nil, nil);
if(L < R) {
int M = (L + R) / 2;
if(p <= M) {
o -> lc = modify(o -> lc, L, M, p, v);
} else {
o -> rc = modify(o -> rc, M + 1, R, p, v);
}
}
return o;
}
Node *merge(Node *l, Node *r, const ll &p, ll lp, ll ls, ll rp, ll rs) {
if(l == nil && r == nil) return nil;
if(l == nil) {
ll v1 = p * lp % ha;
ll v2 = ((1LL - p + ha) % ha) * ls % ha;
r -> paint((v1 + v2) % ha);
return r;
}
if(r == nil) {
ll v1 = p * rp % ha;
ll v2 = ((1 - p + ha) % ha) * rs % ha;
l -> paint((v1 + v2) % ha);
return l;
}
l -> pushdown(); r -> pushdown();
ll l1 = l -> lc -> s, l2 = l -> rc -> s;
ll r1 = r -> lc -> s, r2 = r -> rc -> s;
l -> lc = merge(l -> lc, r -> lc, p, lp, (ls + l2) % ha, rp, (rs + r2) % ha);
l -> rc = merge(l -> rc, r -> rc, p, (lp + l1) % ha, ls, (rp + r1) % ha, rs);
l -> maintain();
return l;
}

const int maxn = 300005;
int cnt;
ll A[maxn], p[maxn];
std::vector<int> G[maxn];
Node *solve(int x) {
if(G[x].size() == 0) {
int v = std::lower_bound(A + 1, A + 1 + cnt, p[x]) - A;
return modify(nil, 1, cnt, v, 1);
} else if(G[x].size() == 1) {
return solve(G[x][0]);
} else {
static const ll inv_w = inv(10000);
p[x] = p[x] * inv_w % ha;
Node *l = solve(G[x][0]), *r = solve(G[x][1]);
return merge(l, r, p[x], 0, 0, 0, 0);
}
}
ll calc(Node *o, int L, int R) {
if(o == nil) return 0;
if(L == R) {
ll prob = o -> s;
ll ret = prob * prob % ha;
ret = ret * ((ll)L * (A[L] % ha) % ha) % ha;
return ret;
} else {
o -> pushdown();
ll ret = 0; int M = (L + R) / 2;
ret = (ret + calc(o -> lc, L, M)) % ha;
ret = (ret + calc(o -> rc, M + 1, R)) % ha;
return ret;
}
}

int main() {
init_pool();
int n; scanf("%d", &n);
for(int i = 1; i <= n; i ++) {
int fa; scanf("%d", &fa);
G[fa].push_back(i);
}
cnt = 0;
for(int i = 1; i <= n; i ++) {
scanf("%lld", &p[i]);
if(G[i].empty()) {
A[++ cnt] = p[i];
}
}
std::sort(A + 1, A + 1 + cnt);
printf("%lld\n", calc(solve(1), 1, cnt));
return 0;
}