1569. 将子数组重新排序得到同一个二叉搜索树的方案数

难度困难

给你一个数组 nums 表示 1n 的一个排列。我们按照元素在 nums 中的顺序依次插入一个初始为空的二叉搜索树(BST)。请你统计将 nums 重新排序后,统计满足如下条件的方案数:重排后得到的二叉搜索树与 nums 原本数字顺序得到的二叉搜索树相同。

比方说,给你 nums = [2,1,3] ,我们得到一棵 2 为根,1 为左孩子,3 为右孩子的树。数组 [2,3,1] 也能得到相同的 BST,但 [3,2,1] 会得到一棵不同的 BST 。

请你返回重排 nums 后,与原数组 nums 得到相同二叉搜索树的方案数。

由于答案可能会很大,请将结果对 10^9 + 7 取余数。

示例 1:

img

输入:nums = [2,1,3]
输出:1
解释:我们将 nums 重排, [2,3,1] 能得到相同的 BST 。没有其他得到相同 BST 的方案了。

示例 2:

img

输入:nums = [3,4,5,1,2]
输出:5
解释:下面 5 个数组会得到相同的 BST:
[3,1,2,4,5]
[3,1,4,2,5]
[3,1,4,5,2]
[3,4,1,2,5]
[3,4,1,5,2]

示例 3:

img

输入:nums = [1,2,3]
输出:0
解释:没有别的排列顺序能得到相同的 BST 。

提示:

  • 1 <= nums.length <= 1000
  • 1 <= nums[i] <= nums.length
  • nums 中所有数 互不相同

# 动态规划 + 组合数

f[ai]f[ai] 为当前节点作为根节点所得的满足条件的 BSTBST 方案

  • 设其左孩子节点个数为 leftSizeleftSize
  • 设其孩子节点个数为 totalSizetotalSize

f[ai]=CtotalSizeleftSize×f[ail]×f[air]f[ai]=C_{totalSize}^{leftSize} \times f[ail]\times f[air]

  • CtotalSizeleftSizeC_{totalSize}^{leftSize}:从 totalSizetotalSize 选择 leftSizeleftSize 个的组合数。

    由于小于 aiai 的数 leftleft 作为一个整体在剩余位置中无论放在哪里都不会影响最后生成的 BSTBST

    • 对于 1,21,2 的位置在剩余位置中任意选择 22 个位置放置,如下所示其中 33 个位置

      3____312__3_12_31_2_\begin{matrix}3 & \_ & \_&\_&\_\\3 & 1 & 2 &\_&\_\\3 & \_ & 1&2&\_\\3&1&\_&2&\_\end{matrix}\,

    leftleft 必然在根节点 33 的左边,所以 rightright 不会影响 leftleft 的放置位置

    • rightright 必然放置在根节点 33 的右边

    即可放的位置数:在 totalSizetotalSize 中选择 leftSizeleftSize 个位置

    • 剩余的位置即为 rightSizerightSize 的位置数量
  • f[ail]×f[air]f[ail]\times f[air]:根据乘法原理

    固定左边 f[ail]f[ail]f[ail]f[ail] 对应 f[air]f[air] 每一种方案

class Solution {
    int MOD = 1000000007;
    public int numOfWays(int[] nums) {
        int n = nums.length;
        // 计算组合数
        c = new long[n][n];
        // 叶子节点,组合数为 1
        c[0][0] = 1;
        for (int i = 1; i < n; i++) {
            c[0][i] = 1;
            for (int j = 1; j <= i; j++) {
                c[j][i] = (c[j][i - 1] + c[j - 1][i - 1]) % MOD;
            }
        }
        TreeNode root = new TreeNode(nums[0]);
        // 先构建一颗 BST
        for (int i = 1; i < n; i++) {
            insert(root, nums[i]);
        }
        return (int) (dfs(root) - 1) % MOD;
    }
    long[][] c;
    //f [ai] 为当前节点作为根节点所得的满足条件的 BST 方案
    public long dfs(TreeNode root) {
        if (root == null) {
            return 1;
        }
        long lf = dfs(root.left);
        long rf = dfs(root.right);
        int leftSize = root.left == null ? 0 : root.left.totalSize;
        int rightSize = root.right == null ? 0 : root.right.totalSize;
        root.totalSize = leftSize + rightSize + 1;
        return c[leftSize][root.totalSize - 1] * (lf * rf % MOD) % MOD;
    }
    public void insert(TreeNode root, int v) {
        TreeNode prev = root;
        while (root != null) {
            prev = root;
            if (root.v < v) {
                root = root.right;
            } else {
                root = root.left;
            }
        }
        if (prev.v < v) {
            prev.right = new TreeNode(v);
        } else {
            prev.left = new TreeNode(v);
        }
    }
}
class TreeNode {
    TreeNode left;
    TreeNode right;
    int totalSize;
    int v;
    public TreeNode(int v) {
        this.v = v;
    }
}

# 乘法逆元 + 费马小定理 + 并查集 + 动态规划

参考:将子数组重新排序得到同一个二叉查找树的方案数

从上述方法可以看出时间复杂度耗费在求解组合数与构建 BSTBST

组合数:Cnk=n(n1)(nk+1)k!=n!k!(nk)!C_n^k=\frac{n\cdot (n-1)\cdots(n-k+1)}{k!}=\frac{n!}{k!\cdot(n-k)!}\,

乘法逆元:若一个线性同余方程:ab1mod(m)ab\equiv 1\,mod\,(m),则称 bbamodma\, mod \,m 的乘法逆元,记作:a1a^{-1}

即:1a=bmod(m)\frac{1}{a}= b\,mod\,(m)

乘法逆元就是平时说的倒数

那么:ca=cbmod(m)\frac{c}{a} = c\cdot b\,mod\,(m)


如何求解 bb 呢?

mm 为质数,由费马小定理ab1mod(m)am1mod(m)ab\equiv1\,mod\,(m)\equiv a^{m-1}\,\,mod\,(m)

bam2mod(m)b\equiv a^{m-2}\,mod\,(m)

可以采用「快速幂算法」求解 am2mod(m)a^{m-2}\,mod\,(m)


  • 只需预处理所有 fac[i]=i!mod(m)fac[i] = i!\, mod \,(m)

  • 只需预处理所有 facInv[i]=(i!)1mod(m)facInv[i] = (i!)^{-1}\, mod \,(m)

那么 Cnk=n!k!(nk)!=fac[n]facInv[k]facInv[nk]mod(m)C_n^k=\frac{n!}{k!\cdot(n-k)!}=fac[n]\cdot facInv[k]\cdot facInv[n-k]\,mod\,(m)


如上所示,采用快速幂算法需要 log(m)log(m) 的时间,其实可以「线性求逆元」

m=ui+vm = u\cdot i + v,其中 u=m/i,v=m%iu=⌊m/i⌋ ,v=m \,\% \,i ,即:mm 除以 ii 的商和余数

两边同时 modmmod\,m

ui+v0(modm)u\cdot i + v \equiv 0\,(mod\, m)

u+vi10(modm)u + v \cdot i^{-1}\equiv 0\,(mod\, m)

uv1+i10(modm)u\cdot v^{-1} + i^{-1}\equiv 0\,(mod\, m)

i1uv1(modm)i^{-1}\equiv -u\cdot v^{-1} \,(mod\, m)

即:inv[i]=m/iinv[m%i](modm)inv[i] = - ⌊m/i⌋ \cdot inv[m \,\% \,i] \,(mod \,m)

  • fac[i]=i!mod(m)fac[i] = i!\, mod \,(m)
  • facInv[i]=(i!)1mod(m)=inv[i]facInv[i1]mod(m)facInv[i] = (i!)^{-1}\, mod \,(m)=inv[i] \cdot facInv[i - 1]\,mod\,(m)

构建 BSTBST

前提注意:numsnums 表示 11nn 的一个排列,所以每个元素有且仅出现一次

numnum 出现在 num1num - 1 之前,那么 num1num - 1 所在的树是 numnum 的左子树

num1num - 1 为根的树是 numnum 的右子树,由于 numnumnum1num - 1 之前,所以必然会走向 numnum 的左子树

numnum 出现在 num1num - 1 之后,那么 numnum 为所在的树是 num1num - 1 的右子树,且 numnum 的左子树为空

因为 num1num - 1numnum 之前遍历,若 vv 要走到 numnum 左子树,必然 v<numv < num

对于 num+1num + 1 同理。


那么如何寻找 num1num - 1num+1num + 1 所在子树的根节点呢?

可以采用并查集维护当前 numnum 所在子树的根节点

所以倒序遍历数组 numsnums

  • num+1num + 1 在之前就遍历过,那么寻找 num+1num + 1 所在子树的根节点 root=find(num+1)root = find(num + 1) ,并将 cur.right=rootcur.right = root

    否则:cur.right=nullcur.right = null

  • num1num - 1​ 在之前就遍历过,那么寻找 num1num - 1 所在子树的根节点 root=find(num1)root = find(num - 1) ,并将 cur.left=rootcur.left= root

    否则:cur.left=nullcur.left = null

并查集的维护过程

若轮播图加载不出来,请刷新

# 快速幂求解逆元

class Solution {
    int MOD = 1000000007;
    long[] fac;
    long[] facInv;
    int n;
    public int numOfWays(int[] nums) {
        n = nums.length;
        init();
        fac = new long[n];
        facInv = new long[n];
        fac[0] = 1;
        facInv[0] = 1;
        for (int i = 1; i < n; i++) {
            fac[i] = fac[i - 1] * i % MOD;
            facInv[i] = pow(fac[i], MOD - 2);
        }
        // 数字对应的节点
        Map<Integer, TreeNode> map = new HashMap<>();
        // 倒序构建 BST
        for (int i = n - 1; i >= 0; i--) {
            TreeNode node = new TreeNode(nums[i]);
            //num + 1 在之前遍历过
            if (map.containsKey(nums[i] + 1)) {
                int r = find(nums[i] + 1);
                node.right = map.get(r);
                union(r, nums[i]);
            }
            //num - 1 在之前遍历过
            if (map.containsKey(nums[i] - 1)) {
                int r = find(nums[i] - 1);
                node.left = map.get(r);
                union(r, nums[i]);
            }
            map.put(nums[i], node);
        }
        // 若结果刚好是 MOD 的整数倍时,ans = 0,那么 ans - 1 为负数
        return (int) ((dfs(map.get(nums[0])) - 1 + MOD) % MOD);
    }
    public long dfs(TreeNode root) {
        if (root == null) {
            return 1;
        }
        // 左边的方案数
        long lAns = dfs(root.left);
        // 右边的方案数
        long rAns = dfs(root.right);
        // 左子树的节点数
        int leftTotal = root.left == null ? 0 : root.left.totalSize;
        // 右子树的节点数
        int rightTotal = root.right == null ? 0 : root.right.totalSize;
        root.totalSize = leftTotal + rightTotal + 1;
        return fac[root.totalSize - 1] * facInv[leftTotal] % MOD * facInv[root.totalSize - 1 - leftTotal] % MOD * (lAns * rAns % MOD) % MOD;
    }
    // 快速幂:a^x % mod
    public long pow(long a, int x) {
        long ans = 1;
        long base = a;
        while (x > 0) {
            if ((x & 1) == 1) {
                ans = ans * base % MOD;
            }
            base = base * base % MOD;
            x >>= 1;
        }
        return ans;
    }
    int[] fa;
    public void init() {
        fa = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            fa[i] = i;
        }
    }
    public void union(int i, int j) {
        int lAncestor = find(i);
        int rAncestor = find(j);
        fa[lAncestor] = rAncestor;
    }
    public int find(int x) {
        if (x == fa[x]) {
            return x;
        }
        return fa[x] = find(fa[x]);
    }
}
class TreeNode {
    TreeNode left;
    TreeNode right;
    int totalSize;
    int v;
    public TreeNode(int v) {
        this.v = v;
    }
}

# 线性求解逆元

class Solution {
    int MOD = 1000000007;
    long[] fac;
    long[] inv;
    long[] facInv;
    int n;
    public int numOfWays(int[] nums) {
        n = nums.length;
        if (n == 1) {
            return 0;
        }
        init();
        fac = new long[n];
        inv = new long[n];
        facInv = new long[n];
        fac[0] = inv[0] = facInv[0] = 1;
        fac[1] = inv[1] = facInv[1] = 1;
        for (int i = 2; i < n; i++) {
            fac[i] = fac[i - 1] * i % MOD;
            inv[i] = -MOD / i * inv[MOD % i] % MOD;
            facInv[i] = (facInv[i - 1] * inv[i] % MOD + MOD) % MOD;
        }
        // 数字对应的节点
        Map<Integer, TreeNode> map = new HashMap<>();
        // 倒序构建 BST
        for (int i = n - 1; i >= 0; i--) {
            TreeNode node = new TreeNode(nums[i]);
            //num + 1 在之前遍历过
            if (map.containsKey(nums[i] + 1)) {
                int r = find(nums[i] + 1);
                node.right = map.get(r);
                union(r, nums[i]);
            }
            //num - 1 在之前遍历过
            if (map.containsKey(nums[i] - 1)) {
                int r = find(nums[i] - 1);
                node.left = map.get(r);
                union(r, nums[i]);
            }
            map.put(nums[i], node);
        }
        return (int) ((dfs(map.get(nums[0])) - 1 + MOD) % MOD);
    }
    public long dfs(TreeNode root) {
        if (root == null) {
            return 1;
        }
        // 左边的方案数
        long lAns = dfs(root.left);
        // 右边的方案数
        long rAns = dfs(root.right);
        // 左子树的节点数
        int leftTotal = root.left == null ? 0 : root.left.totalSize;
        // 右子树的节点数
        int rightTotal = root.right == null ? 0 : root.right.totalSize;
        root.totalSize = leftTotal + rightTotal + 1;
        return fac[root.totalSize - 1] * facInv[leftTotal] % MOD * facInv[root.totalSize - 1 - leftTotal] % MOD * (lAns * rAns % MOD) % MOD;
    }
    // 快速幂:a^x % mod
    public long pow(long a, int x) {
        long ans = 1;
        long base = a;
        while (x > 0) {
            if ((x & 1) == 1) {
                ans = ans * base % MOD;
            }
            base = base * base % MOD;
            x >>= 1;
        }
        return ans;
    }
    int[] fa;
    public void init() {
        fa = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            fa[i] = i;
        }
    }
    public void union(int i, int j) {
        int lAncestor = find(i);
        int rAncestor = find(j);
        fa[lAncestor] = rAncestor;
    }
    public int find(int x) {
        if (x == fa[x]) {
            return x;
        }
        return fa[x] = find(fa[x]);
    }
}
class TreeNode {
    TreeNode left;
    TreeNode right;
    int totalSize;
    int v;
    public TreeNode(int v) {
        this.v = v;
    }
}