Skip to content

373. Find K Pairs with Smallest Sums

You are given two integer arrays nums1 and nums2 sorted in non-decreasing order and an integer k.

Define a pair (u, v) which consists of one element from the first array and one element from the second array.

Return the k pairs (u1, v1), (u2, v2), ..., (uk, vk) with the smallest sums.

Example 1:

Input: nums1 = [1,7,11], nums2 = [2,4,6], k = 3
Output: [[1,2],[1,4],[1,6]]
Explanation: The first 3 pairs are returned from the sequence: [1,2],[1,4],[1,6],[7,2],[7,4],[11,2],[7,6],[11,4],[11,6]

Example 2:

Input: nums1 = [1,1,2], nums2 = [1,2,3], k = 2
Output: [[1,1],[1,1]]
Explanation: The first 2 pairs are returned from the sequence: [1,1],[1,1],[1,2],[2,1],[1,2],[2,2],[1,3],[1,3],[2,3]

Example 3:

Input: nums1 = [1,2], nums2 = [3], k = 3
Output: [[1,3],[2,3]]
Explanation: All possible pairs are returned from the sequence: [1,3],[2,3]

Solution:

一、初步思路 为描述方便,下文把 \(nums_1\)记作 a,\(nums_2\)记作 b。

哪个数对的和最小?

由于数组是有序的,\((a[0],b[0])\) 是和最小的数对,计入答案。

哪个数对的和第二小?

次小只能是 \((a[0],b[1])\)\((a[1],b[0])\),其它没有计入答案的数对和不会比这两个更小。

\((a[0],b[1])\)\((a[1],b[0])\) 这两个数对和的大小还好比较,但如果要求第 k 小,就要涉及到更多的数对,那样就更加复杂了。如何按从小到大的顺序快速地求出这些数对呢?

二、借助最小堆

为了更高效地比大小,我们可以借助最小堆来优化。

堆中保存下标对 \((i,j)\),即可能成为下一个数对的 \(a\) 的下标 \(i\)\(b\) 的下标 \(j\)。堆顶是最小的\(a[i]+b[j]\)

初始把 \((0,0)\) 入堆。

每次 \((i,j)\) 出堆时,把候选项 \((i+1,j)\)\((i,j+1)\) 入堆。(和「初步思路」中的讨论一样,其它的不会比这两个更小。)

但这会导致一个问题:例如当 \((1,0)\) 出堆时,会把 \((1,1)\) 入堆;当 \((0,1)\) 出堆时,也会把 \((1,1)\) 入堆,这样堆中会有重复元素。为了避免有重复元素,还需要额外用一个哈希表记录在堆中的下标对。只有当下标对不在堆中时,才能入堆。

能否不用哈希表呢?

三、优化 换个角度,如果要把 \((i,j)\) 入堆,那么之前出堆的下标对是什么?

根据上面的讨论,出堆的下标对只能是 \((i−1,j)\)\((i,j−1)\)

只要保证 \((i−1,j)\)\((i,j−1)\) 的其中一个会将 \((i,j)\) 入堆,而另一个什么也不做,就不会出现重复了!

不妨规定 \((i,j−1)\) 出堆时,将 \((i,j)\) 入堆;而 \((i−1,j)\) 出堆时只计入答案,其它什么也不做。

换句话说,在 \((i,j)\) 出堆时,只需将 \((i,j+1)\) 入堆,无需将 \((i+1,j)\) 入堆。

但若按照该规则,初始仅把 \((0,0)\) 入堆的话,只会得到 $(0,1),(0,2),⋯ $这些下标对。

所以初始不仅要把 \((0,0)\) 入堆,\((1,0),(2,0),⋯\) 这些都要入堆。

代码实现时,为了方便比较大小,实际入堆的是三元组 \((a[i]+b[j],i,j)\)

另一种理解角度

示例 1 的 \(nums_1=[1,7,11]\)\(nums_2=[2,4,6]\)。我们把每个数对的和算出来,可以得到一个矩阵\(M\),其中 \(M_{i,j} =nums_1[i]+nums_2[j]\)。 $$ M = \left [ \begin{matrix} 3 & 5 & 7 \ 9 & 11 & 13 \ 13 & 15 & 17 \end{matrix} \right ] $$ 由于 \(nums_2\)是递增的,所以矩阵每一行都是递增的。问题相当于:

  • 合并 \(n\) 个升序列表,找前 k 小元素。(其中\(n\)\(nums_1\) 的长度)

根据 23. 合并 K 个升序链表 的 堆的做法:

  1. 把矩阵每一行的第一个数 \(M_{i,0}\)及其位置 \((i,0)\) 加到最小堆中。
  2. 循环 k 次。
  3. 每次循环,弹出堆顶,把堆顶 \(M_{i,j}\)的对应数对加入答案,把堆顶右边元素\(M_{i,j+1}\)及其位置 \((i,j+1)\) 入堆。
class Solution {
    public List<List<Integer>> kSmallestPairs(int[] nums1, int[] nums2, int k) {
       List<List<Integer>> result = new ArrayList<>();

       PriorityQueue<int[]> pq = new PriorityQueue<>((a, b) -> (a[2] - b[2]));

       //    2 4 6
       // 1  3
       // 7  9
       // 11 13
       for (int i = 0; i < Math.min(nums1.length, k); i++){
        pq.offer(new int[]{i, 0, nums1[i] + nums2[0]});
       }

       for (int i = 0; i < k; i++){
        int[] cur = pq.poll();
        int x = cur[0];
        int y = cur[1];
        result.add(List.of(nums1[x], nums2[y]));
        if (y + 1 < nums2.length){
            pq.offer(new int[]{x, y + 1, nums1[x] + nums2[y + 1]});
        }
       }

       return result;
    }
}
class Solution {
    public List<List<Integer>> kSmallestPairs(int[] nums1, int[] nums2, int k) {
       List<List<Integer>> result = new ArrayList<>(k); // 预分配空间
       PriorityQueue<int[]> pq = new PriorityQueue<>((a, b) -> a[0] - b[0]);

       for (int i = 0; i < Math.min(nums1.length, k); i++){ // 至多 k 个
        pq.add(new int[]{nums1[i] + nums2[0], i, 0});// {1+2, 0, 0}, {7 + 2, 1, 0}, {11 + 2, 2, 0}
        // {3, 0, 0}, {9, 1, 0}, {13, 2, 0}
       } 
       for (int z = 0; z < k; z++){
        int[] top = pq.poll();
        int i = top[1];
        int j = top[2];
        result.add(List.of(nums1[i], nums2[j]));
        if (j + 1 < nums2.length){
            pq.add(new int[]{nums1[i] + nums2[j + 1], i, j + 1});
        }
       }

       return result;
    }
}

// TC: O(klogmin(n,k)),其中 n 为 nums_1的长度。为了得到 k 个数对,需要循环 k 次,每次出堆入堆的时间复杂度为 logmin(n,k)。所以总的时间复杂度为 O(klogmin(n,k))。
// SC: O(min(n,k))。堆中至多有 O(min(n,k)) 个三元组。
class Solution {
    public List<List<Integer>> kSmallestPairs(int[] nums1, int[] nums2, int k) {
        int n = nums1.length;
        int m = nums2.length;
        List<List<Integer>> result = new ArrayList<>();
        if (n == 0 || m == 0 || k == 0) return result; // edge case

        PriorityQueue<int[]> pq = new PriorityQueue<>((a, b) -> 
            (nums1[a[0]] + nums2[a[1]]) - (nums1[b[0]] + nums2[b[1]])); // min heap => sorting by the pair sum

        for (int i = 0; i < Math.min(n, k); i++) {
            pq.offer(new int[]{i, 0}); // insert all pair of nums1[i] with nums2[0]
        }

        // Think of it as grid
        //           2    4    6  
        //  ------ | -- | -- | -- | --
        //  1      | 3  | 5  | 7  |
        //  7      | 9  | 11 | 13 |
        //  11     | 13 | 15 | 17 |

        while (k-- > 0 && !pq.isEmpty()) { // until k!=0 or pq != empty
            int[] curr = pq.poll(); // retrieve min
            int i = curr[0]; // index of nums1
            int j = curr[1]; // index of nums2

            result.add(new ArrayList<>(Arrays.asList(nums1[i], nums2[j]))); // add this pair to result

            if (j + 1 < m) { // next pair
                pq.offer(new int[]{i, j + 1});
            }
        }
        return result;
    }
}
class Solution {
    static class Cell{
        int x;
        int y;
        int value;

        Cell(int x, int y, int value){
            this.x = x;
            this.y = y;
            this.value = value;
        }
    }

    public List<List<Integer>> kSmallestPairs(int[] nums1, int[] nums2, int k) {
        int m = nums1.length;
        int n = nums2.length;

        List<List<Integer>> result = new ArrayList<>();

        //boolean[][] visited = new boolean[m][n];
        Set<Pair<Integer, Integer>> visited = new HashSet<Pair<Integer, Integer>>();

        PriorityQueue<Cell> minHeap = new PriorityQueue<Cell>(k, new Comparator<Cell>(){
            public int compare(Cell c1, Cell c2){
                if (c1.value == c2.value){
                    return 0;
                }else if (c1.value < c2.value){
                    return -1;
                }else{
                    return 1;
                }
            }
        });


        minHeap.offer(new Cell(0, 0, nums1[0]+ nums2[0]));
        // visited[0][0] = true;
        visited.add(new Pair<Integer,Integer>(0,0));

        for (int i = 0; i < k; i++){
            Cell cur = minHeap.poll();
            result.add(List.of(nums1[cur.x], nums2[cur.y]));

            if (cur.x + 1 < m && !visited.contains(new Pair<Integer, Integer>(cur.x+ 1, cur.y))){
                minHeap.offer(new Cell(cur.x + 1, cur.y, nums1[cur.x+1]+ nums2[cur.y]));
                // visited[cur.x+1][cur.y] = true;
                visited.add(new Pair<Integer, Integer>(cur.x+1, cur.y));
            }

            if (cur.y + 1 < n && !visited.contains(new Pair<Integer, Integer>(cur.x, cur.y + 1))){
                minHeap.offer(new Cell(cur.x, cur.y + 1, nums1[cur.x]+ nums2[cur.y+1]));
                visited.add(new Pair<Integer, Integer>(cur.x, cur.y + 1));
            }

        }

        return result;
    }
}
// TC: O(klogk)
// SC: O(k)
class Solution {
    static class Cell{
        int x;
        int y;
        int value;

        Cell(int x, int y, int value){
            this.x = x;
            this.y = y;
            this.value = value;
        }
    }

    public List<List<Integer>> kSmallestPairs(int[] nums1, int[] nums2, int k) {
        int m = nums1.length;
        int n = nums2.length;

        List<List<Integer>> result = new ArrayList<>();
        boolean[][] visited = new boolean[m][n];

        PriorityQueue<Cell> minHeap = new PriorityQueue<Cell>(k, new Comparator<Cell>(){
            public int compare(Cell c1, Cell c2){
                if (c1.value == c2.value){
                    return 0;
                }else if (c1.value < c2.value){
                    return -1;
                }else{
                    return 1;
                }
            }
        });

        minHeap.offer(new Cell(0, 0, nums1[0]+ nums2[0]));
        visited[0][0] = true;

        for (int i = 0; i < k; i++){
            Cell cur = minHeap.poll();
            result.add(List.of(nums1[cur.x], nums2[cur.y]));

            if (cur.x + 1 < m && !visited[cur.x+ 1][cur.y]){
                minHeap.offer(new Cell(cur.x + 1, cur.y, nums1[cur.x+1]+ nums2[cur.y]));
                visited[cur.x+1][cur.y] = true;
            }

            if (cur.y + 1 < n && !visited[cur.x][cur.y + 1]){
                minHeap.offer(new Cell(cur.x, cur.y + 1, nums1[cur.x]+ nums2[cur.y+1]));
                visited[cur.x][cur.y+1] = true;
            }

        }

        return result;
    }
}

// TC: O(klogk)
// SC: O(k)