834. 树中距离之和

难度困难 375

给定一个无向、连通的树。树中有 n 个标记为 0...n-1 的节点以及 n-1 条边 。

给定整数 n 和数组 edgesedges[i] = [ai, bi] 表示树中的节点 aibi 之间有一条边。

返回长度为 n 的数组 answer ,其中 answer[i] 是树中第 i 个节点与所有其他节点之间的距离之和。

示例 1:

img

输入: n = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]

输出: [8,12,6,10,10,10]

解释:树如图所示。

我们可以计算出 dist (0,1) + dist (0,2) + dist (0,3) + dist (0,4) + dist (0,5)

也就是 1 + 1 + 2 + 2 + 2 = 8。 因此,answer [0] = 8,以此类推。

示例 2:

img

输入: n = 1, edges = []

输出: [0]

示例 3:

img

输入: n = 2, edges = [[1,0]]

输出: [1,1]

提示:

  • 1 <= n <= 3 * 104
  • edges.length == n - 1
  • edges[i].length == 2
  • 0 <= ai, bi < n
  • ai != bi
  • 给定的输入保证为有效的树

# 换根 dp

class Solution {
    List<Integer>[] graph;
    public int[] sumOfDistancesInTree(int n, int[][] edges) {
        this.n = n;
        graph = new List[n];
        for (int i = 0; i < n; i++) {
            graph[i] = new ArrayList<>();
        }
        for (int[] edge : edges) {
            graph[edge[0]].add(edge[1]);
            graph[edge[1]].add(edge[0]);
        }
        sum = new int[n];
        cnt = new int[n];
        ans = new int[n];
        dfs(0, -1);
        ans[0] = sum[0];
        reRoot(0, -1);
        return ans;
    }
    int n;
    //0 为根节点,子树 i 到其所有子节点的距离之和
    int[] sum;
    //0 为根节点,子树 i 及其孩子节点的个数
    int[] cnt;
    int[] ans;
    public int[] dfs(int x, int fa) {
        List<Integer> list = graph[x];
        int curSum = 0;
        int curCnt = 0;
        for (Integer y : list) {
            if (y != fa) {
                int[] dfs = dfs(y, x);
                curCnt += dfs[1];
                curSum += dfs[0];
            }
        }
        sum[x] = curSum + curCnt;
        cnt[x] = curCnt + 1;
        return new int[]{sum[x], cnt[x]};
    }
    public void reRoot(int u, int fa) {
        //u -> v 的转换 =>  v -> u
        // 与 u 为跟节点的不包含 v 的其他节点全部 + 1
        //u 为根节点,到子树 v 的所有距离全部 -1,
        /* 例如
               a1       b1
                \     /
                 u - v - b3 - b4
                /     \
               a2      b2
            现在是 u 转化为 v
            dis [a1], dis [a2] 全部加 1
            dis [b1], dis [b2], dis [b3], dis [b4] 全部减 1
            然后对其求和
         可以等价转化 :
            以 v 为根节点 sum [u] = 以 u 为根节点 sum [u] - (以 u 为根节点 sum [v] + v 及其子节点个数 (因为 v -> u 也有 1 d))
            以 v 为根节点 sum [v] = 以 u 为根节点 sum [v] + 以 v 为根节点 sum [u] + 除 v 之外的包含 u 的相邻节点个数
            ==========================================
            同时也需要修改不同节点为根节点其子树的数量
            以 v 为根节点 cnt [u] = 以 u 为根节点 cnt [u] - 以 u 为根节点 cnt [v]
            以 v 为根节点 cnt [v] = 以 u 为根节点 cnt [u]
         */
        List<Integer> list = graph[u];
        for (Integer v : list) {
            if (v == fa) {
                continue;
            }
            
            int tf = sum[u];
            int tu = sum[v];
            int tcf = cnt[u];
            int tcu = cnt[v];
			// 除 v 之外的包含 u 的相邻节点个数    
            int count = n - cnt[v];
            // 以 v 为根节点 sum [u] = 以 u 为根节点 sum [u] - (以 u 为根节点 sum [v] + v 及其子节点个数)
            sum[u] = sum[u] - sum[v] - cnt[v];
            // 以 v 为根节点 sum [v] = 以 u 为根节点 sum [v] + 以 v 为根节点 sum [u] + 除 v 之外的包含 u 的相邻节点个数
            sum[v] = sum[v] + sum[u] + count;
            
            // 以 v 为根节点 cnt [u] = 以 u 为根节点 cnt [u] - 以 u 为根节点 cnt [v]
            cnt[u] = cnt[u] - cnt[v];
            // 以 v 为根节点 cnt [v] = 以 u 为根节点 cnt [u]
            cnt[v] = tcf;
            
            ans[v] = sum[v];
            reRoot(v, u);
            
            sum[u] = tf;
            sum[v] = tu;
            cnt[u] = tcf;
            cnt[v] = tcu;  
        }
    }
}