题目描述
Given an n x n matrix where each of the rows and columns are sorted in ascending order, return the kth smallest element in the matrix.
Note that it is the kth smallest element in the sorted order, not the kth distinct element.
Example 1:
Input: matrix = [[1,5,9],[10,11,13],[12,13,15] ], k = 8
Output: 13
Explanation: The elements in the matrix are [1,5,9,10,11,12,13,13,15], and the 8th smallest number is 13
Example 2:
Input: matrix = [[-5] ], k = 1
Output: -5
Constraints:
- n == matrix.length
- n == matrix[i].length
- 1 <= n <= 300
- -10^9 <= matrix[i][j] <= 10^9
- All the rows and columns of matrix are guaranteed to be sorted in non-decreasing order.
- 1 <= k <= n^2
题目分析
这个问题和合并k个有序链表基本相同;只是数据表现方式不一样而已;
这里求解的具体思路为:
- 将matrix每一行看做一个有序链表;那么第一个最小元素肯定在第一列的某一行;
- 找到第一个最小元素之后,这个所在元素所在行往前移动1格;这样,第二个最小元素依然可以看做是在第一列的某一行
由于每次都需要在第一列的N个元素中找到最小值,那么最小堆的性质完美符合。因此,我们将第一列数据构建为一个最小堆;每弹出一个堆中最小元素之后,位于同一行的下一个元素压入堆中;如果某一行已经全部弹出,跳过即可
综上,代码如下:
示例代码
func kthSmallest(matrix [][]int, k int) int {
heap := heap{}
// 初始化堆,将第一列元素添加到堆中
for i := 0; i < len(matrix); i++ {
heap.Add(&item{
val: matrix[i][0],
idx: i * len(matrix),
})
}
count := 0
cols := len(matrix[0])
for {
node := heap.Pop()
if node == nil {
break
}
count++
// 如果弹出的元素已经满足k个,说明此时弹出的元素即为kth smallest element
if count == k {
return node.val
}
// 否则索引+1
idx := node.idx + 1
// 如果新的索引为列数的倍数,说明当前行元素已经遍历完了,这种情况跳过即可
if idx%cols != 0 {
// 计算对应的row, col
row, col := idx/cols, idx%cols
heap.Add(&item{
val: matrix[row][col],
idx: idx,
})
}
}
return -1
}
type item struct {
idx int // 将matrix的row,col转化为1维后的索引
val int // 对应的值
}
type heap struct {
list []*item
}
// 弹出堆顶元素
func (h *heap) Pop() *item {
if len(h.list) == 0 {
return nil
}
ans := h.list[0]
// 将末尾元素移动至堆顶
h.list[0] = h.list[len(h.list)-1]
h.list = h.list[:len(h.list)-1]
// 堆化
h.adjustFromTop()
return ans
}
func (h *heap) Add(item *item) {
h.list = append(h.list, item)
h.adjustFromBottom()
}
// 自顶往下堆化
func (h *heap) adjustFromTop() {
idx, swapidx := 0, 0
for {
if 2*idx+1 < len(h.list) && h.list[2*idx+1].val < h.list[swapidx].val {
swapidx = 2*idx + 1
}
if 2*idx+2 < len(h.list) && h.list[2*idx+2].val < h.list[swapidx].val {
swapidx = 2*idx + 2
}
if idx == swapidx {
break
}
h.list[idx], h.list[swapidx] = h.list[swapidx], h.list[idx]
idx = swapidx
}
}
// 自底往上堆化
func (h *heap) adjustFromBottom() {
idx := len(h.list) - 1
for {
if idx == 0 {
break
}
parentidx := (idx - 1) / 2
if h.list[parentidx].val < h.list[idx].val {
break
}
h.list[idx], h.list[parentidx] = h.list[parentidx], h.list[idx]
idx = parentidx
}
}