hdu2196:Computer (树形DP)

xiaoxiao2021-02-27  368

题目传送门:http://acm.hdu.edu.cn/showproblem.php?pid=2196

题目大意:给定一个无根树,每一条边有长度,输出以每一个点为根时,树上的点到根的距离的最大值。多组数据。

题目分析:首先我们可以很容易地想出O(n^2)的算法:以每一个点为根进行一次DFS,同时维护一棵以node为根的子树内所有节点到node距离的最大值。

然而我们来想一下更优的算法:假设我们先选定1为根进行一次dfs,同样对于每一个节点node维护一棵以node为根的子树内所有节点到node距离的最大值。那么我们在以node为根的时候,树上的其他点到node的距离的最大值会在哪些地方取到呢:

不难发现,只有那些打叉的节点,即node所在子树距离最远的节点,以及node祖先中除了node所在子树以外的最远点,再加上node到它这个祖先的距离

那么后者的信息我们怎么维护呢?我们考虑两次DFS,第一次我们维护node所在子树距离最远的节点和距离次远的节点(注意他们必须不来自同一棵子树)。第二次DFS统计答案。如果当前DFS到node,往node距离最远的节点所在子树DFS时,用次远节点距离+当前边长更新其答案;往node其他子树DFS时,用最远节点距离+当前边长更新其答案。同时我们还要注意把祖先的信息传下来,取个max,这可以在DFS时传一个变量做到。

CODE:

#include<iostream> #include<string> #include<cstring> #include<cmath> #include<cstdio> #include<cstdlib> #include<stdio.h> #include<algorithm> using namespace std; const int maxn=10010; struct data { int obj,len; data *Next; } e[maxn<<1]; data *head[maxn]; int cur; long long dis1[maxn]; long long dis2[maxn]; int son1[maxn]; int son2[maxn]; long long ans[maxn]; int n; void Add(int x,int y,int l) { cur++; e[cur].obj=y; e[cur].len=l; e[cur].Next=head[x]; head[x]=e+cur; } void Release(int x,int y,long long z) { if (z>=dis1[x]) { dis2[x]=dis1[x]; son2[x]=son1[x]; dis1[x]=z; son1[x]=y; } else if (z>dis2[x]) { dis2[x]=z; son2[x]=y; } } void Dfs1(int node,int fa) { dis1[node]=dis2[node]=0; son1[node]=son2[node]=0; for (data *p=head[node]; p; p=p->Next) if (p->obj!=fa) { int son=p->obj; Dfs1(son,node); Release(node,son,dis1[son]+(long long)p->len); } } void Dfs2(int node,int fa,long long up) { ans[node]=max(up,dis1[node]); for (data *p=head[node]; p; p=p->Next) if (p->obj!=fa) { int son=p->obj; if (son==son1[node]) Dfs2(son,node, max(up,dis2[node])+(long long)p->len ); else Dfs2(son,node, max(up,dis1[node])+(long long)p->len ); } } int main() { freopen("c.in","r",stdin); freopen("c.out","w",stdout); while ( scanf("%d",&n),n ) { cur=-1; for (int i=1; i<=n; i++) head[i]=NULL; for (int i=2; i<=n; i++) { int x,y; scanf("%d%d",&x,&y); Add(i,x,y); Add(x,i,y); } Dfs1(1,1); Dfs2(1,1,0); for (int i=1; i<=n; i++) printf("%lld\n",ans[i]); n=0; } return 0; }

转载请注明原文地址: https://www.6miu.com/read-1180.html

最新回复(0)