一道精妙的树链剖分优化DP的题目

题目

bzoj3522&4543: [Poi2014]Hotel

Time Limit: 20 Sec Memory Limit: 128 MB Submit: 524 Solved: 305

Description

有一个树形结构的宾馆,\(n\)个房间,\(n-1\)条无向边,每条边的长度相同,任意两个房间可以相互到达。吉丽要给他的三个妹子各开(一个)房(间)。三个妹子住的房间要互不相同(否则要打起来了),为了让吉丽满意,你需要让三个房间两两距离相同。 有多少种方案能让吉丽满意? ## Input 第一行一个数\(n\)。 接下来\(n-1\)行,每行两个数\(x,y\),表示\(x\)\(y\)之间有一条边相连。 ## Output 让吉丽满意的方案数。 ## 数据范围 \(n<=5000\)(for T3522), \(n<=100000\)(for T4543) # 解题报告 1. \(O(n^2)\)的做法; 三个房间两两距离相同->存在一个点到三点距离相同,且三点到该店的路径上没有公共点; 其实实现起来很简单,就是枚举中间节点,然后分别dfs与它相连的子树(无根形态),使用三个数组:s1[i] s2[i] f[i],分别表示当前子树枚举前深度为\(i\)的节点个数,当前子树枚举前深度为\(i\)可以贡献答案的点对数,和当前枚举子树中深度为\(i\)的节点个数; 然后可以通过s2[i]=s1[i]*f[i],ans+=s2[i]*f[i],s1[i]+=f[i] 进行转移和统计答案; 2. 上面的状态设计应该是不能再优化了,因为相邻节点之间的转移都是\(O(n)\)的,所以考虑使用另一种状态\(f[i][j]\)表示\(i\)节点的子树中,距离\(i\)距离为\(j\)的节点数,\(g[i][j]\)表示在\(i\)的子树中,需要和到达\(i\)距离为\(j\)的节点构成答案的点对数; 得到转移: \[f[x][0]=1;\] \[ans+=f[x][0];\] \[f[x][j]=f[y][j-1]\] \[g[x][j]=g[y][j+1]\] \[g[x][j]=f[x][j] * f[y][j-1]\] 如果记每个点下方的深度为\(de[i]\),对于每一个当前点\(i\),总的时间代价是\(\Sigma de[y] (father(y)=x)\) 很开心的一点是对于\(x\)枚举的第一个儿子\(y\),f[x][j]=f[y][j-1],g[x][j]=g[y][j+1]可以直接通过指针的移动完成,时间复杂度是\(O(1)\)的,同时得到的启发式空间也可以共用,如果每次都首先完成当前点下方延伸最深的儿子(也就是深度重儿子),那么时间复杂度就可以进化成\(\Sigma de[y]-de[x]+1 (father(y)=x)\)发现展开并化简,最后只剩下一层叶节点,时间复杂度是\(O(n)\); 这样做,一条重链共用的也是一段连续的空间,空间复杂度\(O(n)\)(可以理解为重链长度*常数);


通过轻重链实现启发式的转移得到极大优化的树形dp;

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cstdio>
using namespace std;
const int N=100001;
char *cp=(char *) malloc(10000000);
struct edge {
int next,to;
edge(int next=0,int to=0)
:next(next),to(to) {}
} e[N<<1];
int n,son[N],dep[N],head[N],tot;
long long *nxt,*f[N],*g[N],mem[N<<3],ans;
inline void in(int &x) {
for (;*cp<'0'||*cp>'9';cp++);
for (x=0;*cp>='0'&&*cp<='9';cp++)
x=x*10+*cp-48;
}
inline void add(int x,int y) {
e[++tot]=edge(head[x],y),head[x]=tot;
}
void dfs(int x,int fa) {
son[x]=x,dep[x]=dep[fa]+1; int i,y;
for (i=head[x];i;i=e[i].next) {
y=e[i].to; if (y==fa) continue;
dfs(y,x); if (dep[son[y]]>dep[son[x]])
son[x]=son[y];
}
for (i=head[x];i;i=e[i].next) {
y=e[i].to;
if (y!=fa&&(son[y]!=son[x]||x==1)) {
y=son[y],nxt+=dep[y]-dep[x]+2,f[y]=nxt;
++nxt,++nxt,g[y]=nxt,nxt+=(dep[y]-dep[x]+2)<<1;
}
}
}
void dp(int x,int fa) {
int i,y,j;
for (i=head[x];i;i=e[i].next) {
y=e[i].to; if (y==fa) continue;
dp(y,x); if (son[y]==son[x])
f[x]=f[y]-1,g[x]=g[y]+1;
} f[x][0]=1,ans+=g[x][0];
for (i=head[x];i;i=e[i].next) {
y=e[i].to; if (y==fa) continue;
if (son[y]==son[x]) continue;
for(j=0;j<=dep[son[y]]-dep[x];j++) {
ans+=f[x][j-1]*g[y][j]+g[x][j+1]*f[y][j];
}
for(j=0;j<=dep[son[y]]-dep[x];j++) {
g[x][j-1]+=g[y][j];
g[x][j+1]+=f[x][j+1]*f[y][j];
f[x][j+1]+=f[y][j];
}
}
}
int main() {
freopen("bzoj4543.in","r",stdin);
freopen("bzoj4543.out","w",stdout);
fread(cp,1,10000000,stdin);
in(n); int i,j,x,y;
for (i=1;i<n;++i)
in(x),in(y),add(x,y),add(y,x);
nxt=mem,nxt++,nxt++;
dfs(1,0),dp(1,0);
printf("%lld\n",ans);
return 0;
}

留言