1504. 统计全 1 子矩形

难度中等 159

给你一个 m x n 的二进制矩阵 mat ,请你返回有多少个 子矩形 的元素全部都是 1 。

示例 1:

img

输入:mat = [[1,0,1],[1,1,0],[1,1,0]]

输出:13

解释:
有 6 个 1x1 的矩形。
有 2 个 1x2 的矩形。
有 3 个 2x1 的矩形。
有 1 个 2x2 的矩形。
有 1 个 3x1 的矩形。
矩形数目总共 = 6 + 2 + 3 + 1 + 1 = 13 。

示例 2:

img

输入:mat = [[0,1,1,0],[0,1,1,1],[1,1,1,0]]

输出:24

解释:

有 8 个 1x1 的子矩形。
有 5 个 1x2 的子矩形。
有 2 个 1x3 的子矩形。
有 4 个 2x1 的子矩形。
有 2 个 2x2 的子矩形。
有 2 个 3x1 的子矩形。
有 1 个 3x2 的子矩形。
矩形数目总共 = 8 + 5 + 2 + 4 + 2 + 2 + 1 = 24 。

提示:

  • 1 <= m, n <= 150
  • mat[i][j] 仅包含 01

# 单调栈 + 动态规划

枚举每一层,设 height[j]height[j] 为当前层位置 jj 的向上延伸的高度

若当前位置为 jj,一种直观的做法是倒序枚举 [0j][0\sim j] 的位置并计算对答案的贡献值直至遇到 height[j]==0height[j] == 0 退出循环

例如:height:0,1,3,4,0height:0,1,3,4,0

若当前元素为 33,那么倒序枚举,则对答案的贡献值为 4+3+14+3+1

有没有什么做法可以将免去倒序枚举,在 O(1)O(1) 的时间得出答案?


dp[j]dp[j] 为当前层位置 jj 所包含的子矩形的个数,以位置 jj 作为最小高度值

  • 那么只需要往前寻到找第一个小于 height[j]height[j] 的元素下标 kk

    此时:dp[j]=dp[k]+(jk)height[j]dp[j] = dp[k] + (j - k) * height[j]

    然后累加到答案中

  • 可以参考下图:

/*
    第二部分:
    矩形的数量由
    下标 k 所指的柱子
    所能形成的矩形数量
    决定─┐
         │          ┌第一部分:(j - k) * height [i][j]
        ┌┴─┐        │
        │  │        │ 
        │  │        │  
        │  │    ┌──┐│   
        │  │ ┌──┼──┼┴─┐
        │ ┌┴─┤..│..│..│
       ┌┴─┤xx│..│..│..│
       │xx│xx│..│..│..│
     ──┴──┴──┴──┴──┴──┴──
            │        │
            k        j  
*/

参考代码

class Solution {
    public int numSubmat(int[][] mat) {
        int ans = 0;
        int n = mat.length;
        int m = mat[0].length;
        // 单调栈 + 动态规划
        // 每一层的高
        int[] height = new int[m + 1];
        for (int i = 1; i <= n; i++) {
            // 上一个小于他的元素
            // 递增 st
            Deque<int[]> st = new ArrayDeque<>();
            st.addFirst(new int[]{0, 0});
            for (int j = 1; j <= m; j++) {
                height[j] = mat[i - 1][j - 1] == 0 ? 0 : height[j] + 1;
                while (st.size() > 1 && height[st.peekFirst()[0]] >= height[j]) {
                    st.pollFirst();
                }
                // 上一个小于其元素的下标以及当前为右下角的子矩阵个数
                int[] dp = st.peekFirst();
                int cur = dp[1] + (j - dp[0]) * height[j];
                ans += cur;
                //dp[i] = dp[j] + sum(j + 1, i)
                st.addFirst(new int[]{j, cur});
            }
        }
        return ans;
    }
}

时间复杂度O(nm)O(nm)

空间复杂度O(m)O(m)