4. 寻找两个正序数组的中位数

题目来源:https://leetcode-cn.com/problems/median-of-two-sorted-arrays/

给定两个大小为 m 和 n 的正序(从小到大)数组 nums1nums2

请你找出这两个正序数组的中位数,并且要求算法的时间复杂度为 O(log(m + n))

你可以假设 nums1nums2 不会同时为空。

示例 1:

nums1 = [1, 3]

nums2 = [2]

则中位数是 2.0

示例 2:

nums1 = [1, 2]

nums2 = [3, 4]

则中位数是 (2 + 3)/2 = 2.5

解法

C++代码:

解法一:辅助数组

class Solution {
public:
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        int len1 = nums1.size(),len2 = nums2.size();
        int len = len1 + len2;
        int *nums = new int[len]; //辅助数组
        int p=0,q=0;
        for(int i=0;i<len;i++)
        {
            if(p<len1 && q<len2) //数组没有越界
            {
                nums[i] = nums1[p] > nums2[q] ? nums2[q++] : nums1[p++];
            }else
            {
                nums[i] = p==len1 ? nums2[q++] : nums1[p++]; 
            }
        }
        return len % 2 == 0 ? (nums[len/2-1]+nums[len/2])/2.0 : nums[len/2];  
    }
};

执行用时 :16 ms, 在所有 C++ 提交中击败了90.62%的用户

内存消耗 :7.4 MB, 在所有 C++ 提交中击败了100.00%的用户

复杂度分析:

时间复杂度:O(m+n)

空间复杂度:O(m+n)


解法二:双指针遍历两个数组

class Solution {
public:
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        int len1 = nums1.size(),len2 = nums2.size();
        int mid = (len1 + len2)/2;
        int left = -1, right = -1;
        int p = 0, q = 0;
        for(int i=0; i<=mid; i++)
        {
            left = right;
            if(p < len1 && (q >= len2 || nums1[p] < nums2[q]))
            {
                right = nums1[p++];
            }else
            {
                right = nums2[q++];
            }
        }
        if((len1 + len2) % 2 == 0)
            return (right+left)/2.0;
        return right;
    }
};

执行用时 :16 ms, 在所有 C++ 提交中击败了90.62%的用户

内存消耗 :7.1 MB, 在所有 C++ 提交中击败了100.00%的用户

复杂度分析:

时间复杂度:O(m+n)

空间复杂度:O(1)


注:以上两个解法均没有满足题意:算法的时间复杂度为 O(log(m + n))

而要满足题意,需要用二分法的思想去解决问题。


解法三:第K小数解法(折半删除法)

第K小数解法:

假设下面这个场景:

A组 1 2 3 4 5 6 7 8  
B组 9 10 11 12 13 14 15 16 17

总共17个数,中位数为第9个数,即K=9,寻找第9小的数。

解法过程:

轮次 K K/2 谁小 删除
1 9 (最开始的K) 4 A A
2 5 = 9 - 4(A中4个已经删除) 2 A A
3 3 = 5 - 2(A中2个已经删除) 1 A A
4 2 = 3 - 1(A中1个已经删除) 1 A A
5 1 = 2 - 1(A中1个已经删除) 0   选择B首元素

第1轮:A组 1、2、3、4被选中;B组9、10、11、12被选中。因为A组 4 < B组的12,因此A组前四个一定不包括中位数,因此删除;

第2轮:A组 5、6被选中;B组9、10被选中。因为A组 6 < B组的10,因此A组这两个一定不包括中位数,因此删除;

第3轮:A组7被选中;B组9被选中。因为A组 7 < B组的9,因此A组7一定不是中位数,因此删除;

第4轮:A组8被选中;B组9被选中。因为A组 8 < B组的9,因此A组8一定不是中位数,因此删除;

第5轮:K == 1,A组元素已经全部删除了,因此B首元素为第k个数,即为所求的中位数。


下对该解法进行简单的证明:

要证第K小数折半删除的方法,仅需要证明每次删除的数都在第k小数的左边即可。

已知条件:有两个有序数组,元素总数个数为n

以上述第一轮为例,后续删除同理可得

仅需证 4 的右边存在的元素个数 >= n - k + 1

又 4 的右边存在的元素个数包括A组比4大的所有元素,也包括B组>=12的所有元素。

因此 4 右边存在的元素个数 = n - (k / 2 * 2 - 1) (注:k / 2 为 K 整除 2)

所以仅需证 n - (k / 2 * 2 - 1)>= n - k + 1

证明:

n - (k / 2 * 2 - 1)>= n - k + 1……………………………(1)

k - (K / 2 * 2 ) >= 0………………………………………………(2)

k >= (K / 2 * 2 )………………………………………………….(3)

当 K = 偶数时 k>= k 显然成立

当 K = 奇数时 k>= k - 1 显然成立

∴ 每次删除的数都在第k小数的左边,即证明该算法的可行性。

扩展:

如果题目给的是n个数组呢?

解法如上,进行折n删除即可。

C++代码:

注:最开始用非递归的办法去实现,但最后要考虑很多特殊的情况,代码冗余过多,因此选择使用递归的办法求解。

递归版

class Solution {
public:
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        int n = nums1.size();
        int m = nums2.size();
        int left = (n + m + 1) / 2;
        int right = (n + m + 2) / 2;
        return (getKth(nums1, 0, n - 1, nums2, 0, m - 1, left) + getKth(nums1, 0, n - 1, nums2, 0, m - 1, right)) / 2.0;  
    }
private:
    int getKth(vector<int>& nums1, int start1, int end1, vector<int>& nums2, int start2, int end2, int k) {
        int len1 = end1 - start1 + 1;
        int len2 = end2 - start2 + 1;
        if (len1 > len2) return getKth(nums2, start2, end2, nums1, start1, end1, k);
        if (len1 == 0) return nums2[start2 + k - 1];

        if (k == 1) return min(nums1[start1], nums2[start2]);

        int i = start1 + min(len1, k / 2) - 1;
        int j = start2 + min(len2, k / 2) - 1;

        if (nums1[i] > nums2[j]) {
            return getKth(nums1, start1, end1, nums2, j + 1, end2, k - (j - start2 + 1));
        }
        else {
            return getKth(nums1, i + 1, end1, nums2, start2, end2, k - (i - start1 + 1));
        }
    }
};

执行用时 :16 ms, 在所有 C++ 提交中击败了90.62%的用户

内存消耗 :7.1 MB, 在所有 C++ 提交中击败了100.00%的用户

非递归版

非递归版本试验排错了很久…要考虑很多种情况,因此能用递归解决问题不建议用非递归方法解决

class Solution {
public:
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        int len1 = nums1.size();
        int len2 = nums2.size();
        if(len1 > len2)  //保持nums1是元素个数最小的
            return findMedianSortedArrays(nums2,nums1);
        int k = (len1 + len2 + 1)/2; //寻找第k小数
        if(len1 == 0) return (len1+len2)%2 == 0 ? (nums2[k-1]+nums2[k])/2.0 : nums2[k-1];
        int start1=0,start2=0,end1=len1-1,end2=len2-1;
        while(k / 2 != 0){
            int len = k / 2;
            if(len1 - start1 >= len){
                end1 = start1 + len - 1;
                end2 = start2 + len - 1;
                start1 = nums1[end1] >= nums2[end2] ? start1 : end1 + 1;
                start2 = nums1[end1] >= nums2[end2] ? end2 + 1 : start2;
                k -= len;
            }else{
                if(start1 >= len1){ //数组一已经完全遍历
                    return (len1+len2)%2 == 0 ? (nums2[start2+k]+nums2[start2+k-1])/2.0 : nums2[start2 + k-1];
                }else{              //数组一还未完全遍历
                    end1 = len1 - 1;
                    end2 = start2 + (len * 2) - (len1 - start1) - 1;
                    k = nums1[end1] >= nums2[end2] ? k - (len*2)+(len1 - start1) : k - (end1 - start1 + 1);
                    start1 = nums1[end1] >= nums2[end2] ? start1 : end1 + 1; //数组遍历完
                    start2 = nums1[end1] >= nums2[end2] ? end2 + 1 : start2;
                }
            }
        }
        if(start1 >= len1){ //数组一已经完全遍历
            return (len1 + len2)%2==0 ? (nums2[start2+k]+nums2[start2+k-1])/2.0 : nums2[start2+k-1]; 
        }
        if((len1 + len2) % 2 != 0) return min(nums1[start1], nums2[start2]);
        else{
            int left,right;
            if(nums1[start1] > nums2[start2]){
                left = nums2[start2++];
                right = start2 < len2 ? min(nums1[start1],nums2[start2]) : nums1[start1];
            }
            else{
                left = nums1[start1++];
                right = start1 < len1 ? min(nums1[start1],nums2[start2]) : nums2[start2];
            }
            return (left+right)/2.0;
        }
    }
};

执行用时 :12 ms, 在所有 C++ 提交中击败了97.78%的用户

内存消耗 :7.1 MB, 在所有 C++ 提交中击败了100.00%的用户

以上两种版本均能满足题目要求的时间复杂度为O(log(m + n)),空间复杂度为O(1)