Skip to content

树上问题

树的直径

两次DFS

从任意结点开始,做DFS,达到最远结点,再从该最远结点出发做DFS,达到其最远结点。两个最远结点就是直径两端。

只适用于正边权

树形DP

过程一

我们记录当1为树的根时,每个节点作为子树的根向下,所能延伸的最长路径长度 d1与次长路径(与最长路径无公共边)长度d2,那么直径就是对于每一个点,该点 d1 + d2 能取到的值中的最大值。

树形 DP 可以在存在负权边的情况下求解出树的直径。

constexpr int N = 10000 + 10;

int n, d = 0;
int d1[N], d2[N];
vector<int> E[N];

void dfs(int u, int fa) {
  d1[u] = d2[u] = 0;
  for (int v : E[u]) {
    if (v == fa) continue;
    dfs(v, u);
    int t = d1[v] + 1;
    if (t > d1[u])
      d2[u] = d1[u], d1[u] = t;
    else if (t > d2[u])
      d2[u] = t;
  }
  d = max(d, d1[u] + d2[u]);
}

int main() {
  scanf("%d", &n);
  for (int i = 1; i < n; i++) {
    int u, v;
    scanf("%d %d", &u, &v);
    E[u].push_back(v), E[v].push_back(u);
  }
  dfs(1, 0);
  printf("%d\n", d);
  return 0;
}

过程二

我们定义dp[u]:以u为根的子树中,从u出发的最长路径。转移方程:dp[u] = max(dp[u], dp[v]+w(u,v));其中,vu的叶子结点,w(a,b)代表结点a,b之间的权重。(子结点不是叶子)

对于树的直径,实际上是可以通过枚举从某个节点出发不同的两条路径相加的最大值求出。因此,在DP求解的过程中,我们只需要在更新dp[u]之前,计算d=max(d, dp[u]+dp[v]+w(u,v));即可算出直径d。

constexpr int N = 10000 + 10;

int n, d = 0;
int dp[N];
vector<int> E[N];

void dfs(int u, int fa) {
  for (int v : E[u]) {
    if (v == fa) continue;
    dfs(v, u);
    d = max(d, dp[u] + dp[v] + 1);
    dp[u] = max(dp[u], dp[v] + 1);
  }
}

int main() {
  scanf("%d", &n);
  for (int i = 1; i < n; i++) {
    int u, v;
    scanf("%d %d", &u, &v);
    E[u].push_back(v), E[v].push_back(u);
  }
  dfs(1, 0);
  printf("%d\n", d);
  return 0;
}

树的中心

在树中,如果节点x作为根节点时,从x出发的最长链最短,那么称x为这棵树的中心。

性质

  • 树的中心不一定唯一,但最多有2个,且这两个中心是相邻的。
  • 树的中心一定位于树的直径上。
  • 树上所有点到其最远点的路径一定交会于树的中心。
  • 当树的中心为根节点时,其到达直径端点的两条链分别为最长链和次长链。
  • 当通过在两棵树间连一条边以合并为一棵树时,连接两棵树的中心可以使新树的直径最小。
  • 树的中心到其他任意节点的距离不超过树直径的一半。

求法

寻找一个点x,使其作为根节点时,最长链的长度最短。

具体步骤

  1. 维护\(len1_x\),表示节点\(x\)子树内的最长链。
  2. 维护\(len2_x\),表示不与\(len1_x\)重叠的最长链。
  3. 维护\(up_x\),表示节点\(x\)子树外的最长链,该链必定经过\(x\)的父节点。
  4. 找到点\(x\)使得\(max(len1_x,up_x)\)最小,那么\(x\)即为树的中心。

树的重心

如果在树中选择某个节点并删除,这棵树将分为若干棵子树,统计子树节点数并记录最大值。取遍树上所有节点,使此最大值取到最小的节点被称为整个树的重心。

(这里以及下文中的「子树」若无特殊说明都是指无根树的子树,即包括「向上」的那棵子树,并且不包括整棵树自身。)

性质

  • 树的重心如果不唯一,则至多有两个,且这两个重心相邻。
  • 以树的重心为根时,所有子树的大小都不超过整棵树大小的一半。
  • 树中所有点到某个点的距离和中,到重心的距离和是最小的;如果有两个重心,那么到它们的距离和一样。
  • 把两棵树通过一条边相连得到一棵新的树,那么新的树的重心在连接原来两棵树的重心的路径上。
  • 在一棵树上添加或删除一个叶子,那么它的重心最多只移动一条边的距离。

求法

在 DFS 中计算每个子树的大小,记录「向下」的子树的最大大小,利用总点数 - 当前子树(这里的子树指有根树的子树)的大小得到「向上」的子树的大小,然后就可以依据定义找到重心了。

最近公共祖先LCA

性质

  1. \(LCA({u})=u\)
  2. \(u\)\(v\)的祖先,当且仅当\(LVA(u,v)=u\)
  3. 如果\(u\)不为\(v\)的祖先并且\(v\)不为\(u\)的祖先,那么\(u,v\)分别处于\(LVA(u,v)\)的两棵不同子树中;
  4. 前序遍历中,\(LCA(S)\)出现在所有\(S\)中元素之前,后序遍历中\(LCA(S)\)则出现在所有\(S\)中元素之后;
  5. 两点集并的最近公共祖先为两点集分别的最近公共祖先的最近公共祖先,即\(LCA(A \cup B)=LAC(LCA(A),LCA(B))\)
  6. 两点的最近公共祖先必定处在树上两点间的最短路上;
  7. \(d(u,v)=h(u)+h(v)-2h(LCA(u,v))\),其中\(d\)是树上两点间的距离,\(h\)代表某点到树根的距离。

倍增算法

通过预处理\(fa_{x,i}\)数组,减少游标跳转次数。\(fa_{x,i}\)表示点\(x\)的第\(2^i\)个祖先,该数组可通过DFS预处理出来。

第一阶段:将u,v两点跳转到同一深度。我们可以计算出u,v两点的深度差,记作y,对y进行二进制拆分,将y次游标跳转优化为「y的二进制表示所含 1的个数」次游标跳转。

第二阶段:从最大的i开始尝试,直到0,如果\(fa_{u,i}\not=fa_{v,i}\),则\(u \leftarrow fa_{u,i} \ ,v \leftarrow fa_{v,i}\),那么最后的LCA为\(fa_{u,0}\)

性质

预处理的时间复杂度是:\(O(n logn)\)

单次查询的复杂度是:\(O(logn)\)

倍增算法也能通过把fa数组的最小维放前面,减少 cache miss的次数

外倍增算法(也称为二分法或倍增法)是一种在树或图结构中查找两个节点的最近公共祖先(LCA)的算法。在这种算法中,fa 数组用于存储每个节点的所有祖先节点,其中 fa[j][i] 表示节点 i2^j 倍祖先。通过交换 fa 数组的两维,将较小的维度放在前面,可以提高程序效率,减少缓存未命中(cache miss)的次数。下面我将解释为什么这样做可以减少 cache miss,以及什么是 cache miss。

什么是 Cache Miss

Cache miss 是指 CPU 在访问缓存(Cache)时,没有找到所需数据的情况。当发生 cache miss 时,CPU 需要从更慢的内存层次(如主存)中获取数据,这会导致性能下降,因为主存的访问速度远低于缓存。

为什么交换 fa 数组的两维可以减少 Cache Miss

  1. 空间局部性原理:CPU 访问数据时,倾向于访问相邻的内存地址。如果数据在内存中是连续存储的,那么它们很可能会被一起加载到 Cache 中,从而提高 Cache 命中率。

  2. 减少 Cache Line 替换:当 fa 数组的较小维度放在前面时,对于每个节点,其所有祖先节点的信息更有可能被存储在连续的 Cache Line 中。这样,当算法访问一个节点的祖先信息时,相关的 Cache Line 已经加载到 Cache 中,减少了因访问不同维度数据而导致的 Cache Line 替换。

  3. 提高数据访问效率:在外倍增算法中,经常需要沿着树向上查找节点的祖先。如果较小的维度(通常是树的深度)放在前面,那么在查找过程中,相关的数据更有可能已经在 Cache 中,因为它们是连续访问的。这样可以减少因访问不同节点而导致的 Cache miss。

  4. 优化数据访问模式:通过交换 fa 数组的两维,可以使得数据访问模式更加友好,从而减少 Cache miss。例如,如果算法首先访问节点的直接父节点,然后是祖父节点,以此类推,那么将较小维度放在前面可以确保这些连续的祖先节点信息在 Cache 中是连续的。

综上所述,通过交换 fa 数组的两维,使得较小的维度放在前面,可以更好地利用空间局部性原理,减少 Cache Line 替换,提高数据访问效率,从而减少 cache miss 的次数,提高程序的整体性能。

#include <iostream>
#include <vector>
#include <cstring>
#include <algorithm>
#include <bitset>
#include <cmath>
#include <climits>
#include <cassert>
using namespace std;
const int N = 40000 + 10;

int fa[N][31], cost[N][31], dep[N];
vector<int> a[N], w[N];

void resolution();
void dfs(int, int);
int LCA(int, int);

int main(){
    ios::sync_with_stdio(0);
    int T;
    cin >> T;
    while(T--){
        resolution();
    }
    return 0;
}

void resolution(){
    int n, m;
    cin >> n >> m;

    memset(fa, 0, sizeof(fa));
    memset(cost, 0, sizeof(cost));
    memset(dep, 0, sizeof(dep));
    for(int i = 0; i <= n; i++){
        a[i].clear();
    }

    for(int i = 1; i <= n-1; i++){
        int u, v, c;
        cin >> u >> v >> c;
        a[u].push_back(v);
        a[v].push_back(u);
        w[u].push_back(c);
        w[v].push_back(c);
    }
    dfs(1, 0);
    for(int i = 1; i <= m; i++){
        int x, y;
        cin >> x >> y;
        cout << LCA(x, y) << endl;
    }
    return ;
}

void dfs(int u, int f){
    dep[u] = dep[f] + 1;
    fa[u][0] = f;
    // 初始化:其他的祖先节点:第 2^i 的祖先节点是第 2^(i-1) 的祖先节点的第2^(i-1) 的祖先节点。
    for(int i = 1; i <= 30; i++){
        fa[u][i] = fa[fa[u][i-1]][i-1];
        cost[u][i] = cost[fa[u][i-1]][i-1] + cost[u][i-1];
    }
    //遍历子结点
    for(int i = 0; i < a[u].size(); i++){
        if(a[u][i] == f)   continue;
        cost[a[u][i]][0] = w[u][i];
        dfs(a[u][i], u);
    }
    return ;
}

int LCA(int x, int y){
    //保持y的深度更深
    if(dep[x] > dep[y]) swap(x, y);
    // 让y和x在同一深度
    int tmp = dep[y] - dep[x], ans = 0;
    for(int i = 0; i <= 30 && tmp; tmp >>= 1, i++){
        if(tmp & 1){
            //当前位置为1
            ans += cost[y][i];
            y = fa[y][i];
        }
    }
    if(x == y){
        return ans;
    }
    for(int j = 30; j >= 0; j--){
        if(fa[x][j]!= fa[y][j]){
            // 如果说超出了树的限度,即不存在该结点,则fa[x][j] == fa[y][j] == 0
            ans += cost[x][j] + cost[y][j];
            x = fa[x][j];
            y = fa[y][j];
        }
    }
    // 由于if语句,该循环不会把x和y更新到公共祖先
    ans += cost[x][0] + cost[y][0];
    return ans;
}

Tarjan 算法

Tarjan 算法是一种 离线算法,需要使用并差集记录某个结点的祖先结点。做法如下:

  1. 首先接受输入边(邻接链表)、查询边(存储在另一个邻接链表内)。查询边其实是虚拟加上去的边,为了方便,每次输入查询边的时候,将这个边及其反向边都加入到 queryEdge 数组里。
  2. 然后对其进行一次 DFS 遍历,同时使用 visited 数组进行记录某个结点是否被访问过、parent记录当前结点的父亲结点。
  3. 其中涉及到了 回溯思想,我们每次遍历到某个结点的时候,认为这个结点的根结点就是它本身。让以这个结点为根节点的 DFS 全部遍历完毕了以后,再将这个结点的根节点设置为这个结点的父一级结点。
  4. 回溯的时候,如果以该节点为起点,queryEdge 查询边的另一个结点也恰好访问过了,则直接更新查询边的 LCA 结果。
  5. 最后输出结果。

性质

Tarjan 算法需要初始化并查集,所以预处理的时间复杂度为\(O(n)\)

朴素的 Tarjan 算法处理所有\(m\)次询问的时间复杂度为\(O(m \alpha(m+n,n)+n)\)。但是 Tarjan 算法的常数比倍增算法大。存在\(O(m+n)\)的实现。

讲解

讲解