树的直径
概念
给定一棵树,树中每个边都有权值,树中两点之间的距离定义为连接两点的路径边权和.树中最远的两点之间的距离称之为树的直径,也就是树的最长链(偶尔也指代为一条路径).
树的直径有两种求法,都是O(n)的.一种是基于树形dp(之后再补上).一种是基于搜索.
搜索
我们随便找一个节点作为根(之前如果有根就不用换).进行一次搜索,找到距离根最远的节点p,然后再将p作为起点,寻找一个距离p最远的点q,那么pq就是直径.
简单证明:
假设根节点是直径的一端,那么根据直径的定义,p一定是直径的一端.
假设根节点不是直径的一端,而且找到的最远的p也不是直径的一端,假设直径的一端是p’.另一个端点是q,那么p’到q的距离可以拆分成 p’到根,根到q,如果将前半截更新成p到根,结果不会缩小.反证法得出p一定是直径(之一)的一端.
其实这里跑最短路也行,只不过搜索的复杂度是O( | V | + | E | )的.(参考<算法与实现>)算法与实现> |
模板
此题最后输出的是一半的直径(向上取整)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
// 将代码中的dfs和bfs替换不会影响结果,dfs第二个参数填0
#include<bits/stdc++.h>
using namespace std;
#define MAXM 200001
struct NODE {
int w;
int e;
int next; //next[i]表示与第i条边同起点的上一条边的储存位置
} edge[MAXM];
int cnt;
int head[MAXM];
void CFSInit() { // 链式前向星初始化
cnt = 1;
memset(head, 0, sizeof(head));
}
void add(int u, int v, int w) {
edge[cnt].w = w;
edge[cnt].e = v; //edge[i]表示第i条边的终点
edge[cnt].next = head[u]; //head[i]表示以i为起点的最后一条边的储存位置
head[u] = cnt++;
}
bool mark[MAXM];
bool vis[MAXM];
int maxn, maxid;
void dfs(int cur, int cursum) {
vis[cur] = 1;
if (mark[cur] && maxn < cursum) {
maxn = cursum;
maxid = cur;
}
for (int i = head[cur]; i; i = edge[i].next) {
// cerr << "cur " << cur << " " << edge[i].e << endl;
if (!vis[edge[i].e]) {
dfs(edge[i].e, cursum + 1);
}
}
}
void bfs(int cur) {
queue<pair<int, int> > q;
q.push(make_pair(cur, 0));
vis[cur] = 1;
while (!q.empty()) {
int u = q.front().first;
int dis = q.front().second;
if (mark[u] && dis > maxn) {
maxn = dis;
maxid = u;
}
q.pop();
for (int i = head[u]; i; i = edge[i].next) {
if (!vis[edge[i].e]) {
vis[edge[i].e] = 1;
q.push(make_pair(edge[i].e, dis + 1));
}
}
}
}
int main() {
int n, k;
while (scanf("%d%d", &n, &k) != EOF) {
memset(head, 0, sizeof(head));
memset(mark, 0, sizeof(mark));
memset(vis, 0, sizeof(vis));
cnt = 1;
int a, b, c;
for (int i = 0; i < n - 1; ++i) {
scanf("%d%d", &a, &b);
c = 1;
add(a, b, c);
add(b, a, c);
}
for (int i = 0; i < k; ++i) {
int tmp;
scanf("%d", &tmp);
mark[tmp] = true;
}
maxn = 0;
bfs(1);
int p = maxid;
maxn = 0;
memset(vis, 0, sizeof(vis));
bfs(p);
printf("%d\n", (maxn + 1) / 2);
}
return 0;
}