LibreOJ 2478 「九省联考 2018」林克卡特树

Description

有一个\(n\)个点的带边权树,要求你取出\(k + 1\)条点不相交的链(链可以只有一个单点),使得这些链的边权和最大。

\(0\leq k< n\leq 3\times 10^5\),边权绝对值不超过\(10^6\)

Solution

wqs二分第三题……

下面说的\(k\)默认都是原题中的\(k + 1\)

要是\(nk\)不大的话,那么有一个经典的搞法就是树上DP,记状态\(f_{i,s.0/1/2}\)表示点\(i\)这个子树总共取了\(k\)条链,\(i\)本身在这些链中度数为\(0/1/2\)的情况的最优解,转移还是比较显(fan)然(suo)的……

然后\(nk\)很大……但是发现这种东西取了有收益,然后还有数量限制,要素察觉(意味深)

虽然这个东西的凸性只能感性理解或者打表理解了……我真的不会证啊

然后考虑wqs二分……具体方法就是wqs二分,里面的DP就把上面的DP去掉链数限制就行了。

然后存在一个问题就是,假设二分出来一个\(v\),但是代价为\(v\)的时候链数比\(k\)小,代价为\(v-1\)的时候链数就比\(k\)多了。这样的话当代价为\(v\)的时候照样存在\(k\)条链的最优解,所以在\(v\)时的最优解中补上一个\(kv\)就完了……

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <functional>
#include <utility>
const int maxn = 300005;
const int maxm = maxn << 1;
int first[maxn];
int next[maxm], to[maxm], dist[maxm];
void add_edge(int u, int v, int d) {
static int cnt = 0; cnt ++;
next[cnt] = first[u]; first[u] = cnt;
to[cnt] = v; dist[cnt] = d;
}
void ins_edge(int u, int v, int d) {
add_edge(u, v, d); add_edge(v, u, d);
}

typedef long long ll;
struct pii {
ll a; int b;
pii() {}
pii(ll x, int y) { a = x; b = y; }

bool operator <(const pii &res) const {
if(a == res.a) {
return b > res.b;
} else {
return a < res.a;
}
}
bool operator ==(const pii &res) const {
return (a == res.a) && (b == res.b);
}
};
pii operator +(const pii &a, const pii &b) {
return pii(a.a + b.a, a.b + b.b);
}

pii d[maxn][3];
void dfs(int x, int fa, const ll &cost) {
d[x][0] = pii(0LL, 0);
d[x][1] = d[x][2] = pii(-cost, 1);
for(int i = first[x]; i; i = next[i]) {
int v = to[i];
if(v != fa) {
dfs(v, x, cost);
pii maxv = std::max(d[v][0], std::max(d[v][1], d[v][2]));
pii t2 = std::max(maxv + d[x][2], d[x][1] + d[v][1] + pii((ll)dist[i] + cost, -1));
d[x][2] = std::max(d[x][2], t2);
pii t1 = std::max(maxv + d[x][1], d[x][0] + d[v][1] + pii((ll)dist[i], 0));
d[x][1] = std::max(d[x][1], t1);
d[x][0] = std::max(d[x][0], d[x][0] + maxv);
}
}
}

int main() {
int n, k; scanf("%d%d", &n, &k); k ++;
ll L = 0LL, R = 0LL;
for(int i = 1; i <= n - 1; i ++) {
int u, v, w; scanf("%d%d%d", &u, &v, &w);
ins_edge(u, v, w);
R += std::max(0, w); L = std::min(L, (ll)-w);
}
while(L < R) {
#ifdef LOCAL
printf("Range (%lld, %lld)\n", L, R); fflush(stdout);
#endif
ll M = L + (R - L) / 2LL;
dfs(1, -1, M);
pii t = std::max(d[1][0], std::max(d[1][1], d[1][2]));
if(t.b <= k) {
R = M;
} else {
L = M + 1LL;
}
}
dfs(1, -1, L);
pii t = std::max(d[1][0], std::max(d[1][1], d[1][2]));
printf("%lld\n", t.a + (ll)k * L);
return 0;
}