[LeetCode 378.] Kth Smallest Element in a Sorted Matrix LeetCode 378. Kth Smallest Element in a Sorted Matrix

一道经典的二分查找的题目,特点在于查找对象从一位有序数组变成了二位行列有序数组。

题目描述

Given an n x n matrix where each of the rows and columns are sorted in ascending order, return the kth smallest element in the matrix.

Note that it is the kth smallest element in the sorted order, not the kth distinct element.

Example 1:

Input: matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8
Output: 13
Explanation: The elements in the matrix are [1,5,9,10,11,12,13,13,15], and the 8th smallest number is 13

Example 2:

Input: matrix = [[-5]], k = 1
Output: -5

Constraints:

  • n == matrix.length
  • n == matrix[i].length
  • 1 <= n <= 300
  • -109 <= matrix[i][j] <= -109
  • All the rows and columns of matrix are guaranteed to be sorted in non-degreasing order.
  • 1 <= k <= n2

解题思路

查找第k大元素这一类题目,常见的思路有快速选择算法、堆、二分查找等。
这道题也可以用堆来做,空间复杂度 O(K),时间复杂度 O(N*N*logK)
但是用堆没有利用到行列有序的特点,所以时间复杂度比较高。
另一种做法是二分查找,空间复杂度 O(1),时间复杂度 O(N*logN*logM),其中 M 代表矩阵最大值与最小值的差值。

参考代码

堆查找代码:

这里简单把所有元素都入队一次,其实可以进一步优化:如果 k 比 n 小的话,只有前 k 行前 k 列可能是第 k 小元素;前 k 行前 k 列里面,第一行最多有 k 个元素是前 k 小,第二行最多 k-1 个元素是前 k 小,第三行最多 k-2 个元素是前 k 小 …… 第 k行最多有一个元素是前 k 小。这样就把入队元素压缩到了 min(N*N, k*(k+1)/2)

/*
 * @lc app=leetcode id=378 lang=cpp
 *
 * [378] Kth Smallest Element in a Sorted Matrix
 */

// @lc code=start
class Solution {
public:
    int kthSmallest(vector<vector<int>>& matrix, int k) {
        assert(!matrix.empty());
        size_t n = matrix.size();
        priority_queue<int> q;
        for (int i=0; i<n; i++) {
            for (int j=0; j<n; j++) {
                q.push(matrix[i][j]);
                if (q.size() > k) {
                    q.pop();
                }
            }
        }
        return q.top();
    } // AC
};
// @lc code=end

二分查找代码:

/*
 * @lc app=leetcode id=378 lang=cpp
 *
 * [378] Kth Smallest Element in a Sorted Matrix
 */

// @lc code=start
class Solution {
    // each of the rows and columns are sorted
public:
    int kthSmallest(vector<vector<int>>& matrix, int kth) {
        assert(!matrix.empty());

        size_t n = matrix.size();
        int l = matrix[0][0], r = matrix[n-1][n-1];
        while (l < r) {
            int m = l + (r - l) / 2;
            int cnt = 0;
            for (int k=0; k<n; k++) {
                cnt += upper_bound(matrix[k].begin(), matrix[k].end(), m) - matrix[k].begin();
            } // upper_bound !!
            printf("%d is the %d-th number
", m, cnt);
            if (cnt < kth) l = m + 1;
            else r = m;
        }
        return l;
    } // AC
};
// @lc code=end