Leetcode Max Sum of Rectangle No Larger Than K

Leetcode Max Sum of Rectangle No Larger Than K。本题若使用常规方法,需要O(m^2 * n^2)的时间复杂度;但经过仔细的设计,可以将时间复杂度降低到O(min(m, n)^2 * max(m, n) * log(max(m, n)))。相关步骤如下:
1. 设矩阵中较大的维度为dim_max,较小的维度为dim_min。
2. 对较大维度上的每一项,沿较小维度求出前缀和,即:如果row <= col,则对每一列沿行方向求前缀和(否则对每一行沿列方向求前缀和)。
3. 针对较小的维度,枚举所有(i, j)(i <= j)组合;对每个组合,利用第2步的前缀和把第i到第j行(或列)之间的部分压缩成沿较大维度的一维数组,再求出其中最大且不大于k的子数组和。即:如果row <= col,则枚举所有起止行(i, j),把第i到第j行按列求和,再求最大且不大于k的子数组和。
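注:该算法需要一个子算法——求一维数组中最大且不大于k的子数组和。它基于前缀和实现:对当前前缀和P,在之前出现过的前缀和(包括0)中查找不小于P - k的最小值Q,则P - Q即为以当前位置结尾的最佳子数组和。查找使用二叉搜索树加速;树需保持平衡(例如使用std::set),才能保证每次查找为O(log n)。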
相关代码如下:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <queue>
#include <set>
#include <vector>

using namespace std;
/**
 * Node of the hand-rolled binary search tree manipulated by
 * Solution::addNode / getMinNodeGreaterThan / freeTree.
 * Nodes are heap-allocated with `new` and released by Solution::freeTree;
 * a Node does not free its children itself.
 */
struct Node {
    int val;                      // key stored in this node
    struct Node* left = nullptr;  // subtree with keys < val
    struct Node* right = nullptr; // subtree with keys > val
    // explicit: a bare int must not silently convert into a Node.
    explicit Node(int val) : val(val) {}
};
class Solution {
public:
    int maxSumSubmatrix(vector<vector<int> >& matrix, int k) {
        if (matrix.size() == 0) {
            return 0;
        }
        int max = 0;
        int tmp = 0;
        int flag = 0;
        int row1, row2, col1, col2;

        vector<vector<int> > tmpValue(matrix.size(),
                vector<int>(matrix[0].size(), 0));
        // Get the prefix sum of the maximum dimension
        getTmpVlaue(tmpValue, matrix);

        if (matrix.size() <= matrix[0].size()) {
            // row <= col
            // Get each [i, j] (i <= j, i,j belongs row) sub matrix maximum value
            for (row1 = 0; row1 < matrix.size(); ++row1) {
                for (row2 = row1; row2 < matrix.size(); ++row2) {
                    getResultByRow(tmpValue, row1, row2, k, max, flag);
                    if (max == k) {
                        return k;
                    }
                }
            }
        } else {
            // row > col
            // Get each [i, j] (i < j, i,j belongs col) sub matrix maximum value
            for (col1 = 0; col1 < matrix[0].size(); ++col1) {
                for (col2 = col1; col2 < matrix[0].size(); ++col2) {
                    getResultByCol(tmpValue, col1, col2, k, max, flag);
                    if (max == k) {
                        return k;
                    }
                }
            }
        }
        return max;
    }
    /** * Get the prefix sum for the maximum dimension */
    void getTmpVlaue(vector<vector<int> >& tmpValue,
            vector<vector<int> >& matrix) {
        int row_a = 0;
        int col_a = 0;
        // Judge which dimension is larger
        if (matrix.size() <= matrix[0].size()) {
            row_a = 1;
        } else {
            col_a = 1;
        }
        for (int row = 0; row < matrix.size(); ++row) {
            for (int col = 0; col < matrix[row].size(); ++col) {
                if ((row == 0 && row_a == 1) || (col == 0 && col_a == 1)) {
                    tmpValue[row][col] = matrix[row][col];
                } else {
                    tmpValue[row][col] = tmpValue[row - row_a][col - col_a]
                        + matrix[row][col];
                }
            }
        }
    }
    /** * Get the row major sub matrix maximum value */
    void getResultByRow(vector<vector<int> >& tmpValue, int row1, int row2,
            int k, int& max, int& flag) {
        int value = 0;
        int start = 0;
        int cur = 0;
        Node* ptr = nullptr;
        Node* root = new Node(0);
        for (cur = 0; cur < tmpValue[0].size(); ++cur) {
            if (row1 == 0) {
                value += tmpValue[row2][cur];
            } else {
                value += tmpValue[row2][cur] - tmpValue[row1 - 1][cur];
            }
            if (value == k) {
                max = k;
                return;
            }
            // Search the binary tree find the maximum approriately prefix sum
            ptr = getMinNodeGreaterThan(root, value - k);
            if (ptr && (value - ptr->val > max || flag == 0)) {
                max = value - ptr->val;
                flag = 1;
            }
            // Add the node to the binary search tree
            addNode(root, value);
        }
        // Free the binary search tree
        freeTree(root);
    }
    /** * Get the column major sub matrix maximum value */
    void getResultByCol(vector<vector<int> >& tmpValue, int col1, int col2,
            int k, int& max, int& flag) {
        int value = 0;
        int start = 0;
        int cur = 0;
        Node* ptr = nullptr;
        Node* root = new Node(0);
        for (cur = 0; cur < tmpValue.size(); ++cur) {
            if (col1 == 0) {
                value += tmpValue[cur][col2];
            } else {
                value += tmpValue[cur][col2] - tmpValue[cur][col1 - 1];
            }
            if (value == k) {
                max = k;
                return;
            }
            // Search the binary tree find the maximum approriately prefix sum
            ptr = getMinNodeGreaterThan(root, value - k);
            if (ptr && (value - ptr->val > max || flag == 0)) {
                max = value - ptr->val;
                flag = 1;
            }
            // Add the node to the binary search tree
            addNode(root, value);
        }
        // Free the binary search tree
        freeTree(root);
    }
    /** * Add the value the binary search tree with the given root */
    void addNode(Node* root, int val) {
        if (val == root->val) {
            return;
        }
        if (val < root->val) {
            if (root->left == nullptr) {
                root->left = new Node(val);
            } else {
                addNode(root->left, val);
            }
        } else {
            if (root->right == nullptr) {
                root->right = new Node(val);
            } else {
                addNode(root->right, val);
            }
        }
    }
    /** * Search the binary search tree get the minimum element which greater than * the given val. */
    Node* getMinNodeGreaterThan(Node* root, int val) {
        Node* cur = root;
        Node* re = nullptr;
        while (cur) {
            if (cur->val == val) {
                return cur;
            } else if (cur->val > val) {
                if (re == nullptr || cur->val < re->val) {
                    re = cur;
                }
                cur = cur->left;
            } else if (cur->val < val) {
                cur = cur->right;
            }
        }
        return re;
    }
    /** * Free the binary search tree with the given root */
    void freeTree(Node* root) {
        queue<Node*> leftElement;
        leftElement.push(root);
        Node* ptr = root;
        while (!leftElement.empty()) {
            ptr = leftElement.front();
            leftElement.pop();
            if (ptr->left) {
                leftElement.push(ptr->left);
            }
            if (ptr->right) {
                leftElement.push(ptr->right);
            }
            delete ptr;
        }
    }
};

/** * Test */
int main(int argc, char* argv[]) {
    Solution so;
    vector<vector<int> > matrix {{27,5,-20,-9,1,26,1,12,7,-4,8,7,-1,5,8},{16,28,8,3,16,28,-10,-7,-5,-13,7,9,20,-9,26},{24,-14,20,23,25,-16,-15,8,8,-6,-14,-6,12,-19,-13},{28,13,-17,20,-3,-18,12,5,1,25,25,-14,22,17,12},{7,29,-12,5,-5,26,-5,10,-5,24,-9,-19,20,0,18},{-7,-11,-8,12,19,18,-15,17,7,-1,-11,-10,-1,25,17},{-3,-20,-20,-7,14,-12,22,1,-9,11,14,-16,-5,-12,14},{-20,-4,-17,3,3,-18,22,-13,-1,16,-11,29,17,-2,22},{23,-15,24,26,28,-13,10,18,-6,29,27,-19,-19,-8,0},{5,9,23,11,-4,-20,18,29,-6,-4,-11,21,-6,24,12},{13,16,0,-20,22,21,26,-3,15,14,26,17,19,20,-5},{15,1,22,-6,1,-9,0,21,12,27,5,8,8,18,-1},{15,29,13,6,-11,7,-6,27,22,18,22,-3,-9,20,14},{26,-6,12,-10,0,26,10,1,11,-10,-16,-18,29,8,-8},{-19,14,15,18,-10,24,-9,-7,-19,-14,23,23,17,-5,6}};
    int max = so.maxSumSubmatrix(matrix, -100);
    cout << "maximum is: " << max << endl;
    return 0;
}
点赞