近日刷到Leetcode 4题目。感觉这个题目作为标记hard的题目,还是很有意思的。
题目如下:
There are two sorted arrays nums1 and nums2 of size m and n respectively.
Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).
简单来说,就是在两串已经排序好的数组,找到其中的中位数。
但是这道题中比较反直觉的一点就是它要求了
时间复杂度要是O(log (m+n))。理论上说,任何一个程序的时间复杂度一定大于空间复杂度。因为程序开辟这些空间,或者将这些空间中填入数据,就已经要跟空间复杂度相同的时间了。
解法
那这道题是怎么做的呢。Leetcode只用你写调用的函数就好。所以它的时间复杂度是调用函数的时间复杂度,而不是整个程序的时间复杂度(整个程序的时间复杂度只要是O(m+n))。
解法是很容易想到的。因为题目已经要求时间复杂度是O(log(m+n)),所以显然可以使用二分的思想。这里提供一个找到两个序列中第k大数的思路。
每次比较两个序列的中位数大小。相应的就可以丢掉其中一个数列的一半。比如两个序列A、B长度为10、20,而B的中位数比A的大,这就说明的B的中位数一定大于15个数。如果k<15的话,那就可以直接丢掉B序列的后面一半了。然后再次调用这个函数就可以了。
具体实现还有一些细节要考虑。最后只要找第(m+n)/2大的数就好了。
代码附在最后。
时间复杂度
这里来看时间复杂度。首先程序主体部分复杂度是O(log(n+m))的。但是我开始时候的困惑是数组vector的size这布操作,复杂度是O(n)嘛?
查阅相关资料可以看到,现在vector size()时间复杂度已经是常数了。说明vector的实现方式比string高端了一点。vector本身维护了size,而这也方便了越界检查吧。
第二个巧妙之处在于,数组传参时使用的是地址,而不是把每个参数都拷贝到对应一层。所以数据传参的时间复杂度也是常数了。
题解代码
class Solution {
public:
int find(vector<int>& nums1, vector<int>& nums2, int st1, int ed1, int st2, int ed2, int n) {
int l1 = ed1-st1+1;
int l2 = ed2-st2+1;
int m1=(st1+ed1)/2;
int m2=(st2+ed2)/2;
int big= l1+l2-n+1;
if(l1==0)
return st2+n-1+nums1.size();
if(l2==0)
return st1+n-1;
if(n==1)
return (nums1[st1]<nums2[st2])?st1:st2+nums1.size();
if(l1==1 && l2==1){
return (nums1[st1]>nums2[st2])?st1:st2+nums1.size();
}
if(l1==1)
return (nums2[st2+n-1]<=nums1[st1])?(st2+n-1+nums1.size()):
( (nums2[st2+n-2]<=nums1[st1])?st1:( st2+n-2+nums1.size() )
);
if(l2==1)
return (nums1[st1+n-1]<=nums2[st2])?(st1+n-1):
( (nums1[st1+n-2]<=nums2[st2])?(st2+nums1.size()):( st1+n-2 )
);
int small1=m1-st1, small2=m2-st2;
int large1=ed1-m1, large2=ed2-m2;
if(n<=small1){
return find(nums1,nums2,st1,m1-1,st2,ed2,n);
}
if(n<=small2){
return find(nums1,nums2,st1,ed1,st2,m2-1,n);
}
if( big<=large1 ){
return find(nums1,nums2,m1+1,ed1,st2,ed2,n-small1-1);
}
if( big<=large2 ){
return find(nums1,nums2,st1,ed1,m2+1,ed2,n-small2-1);
}
if( n<=(l1+l2)/2 ){
if(nums1[m1]>nums2[m2]){
return find(nums1,nums2,st1,m1,st2,ed2,n);
} else {
return find(nums1,nums2,st1,ed1,st2,m2,n);
}
} else {
if(nums1[m1]>nums2[m2]){
return find(nums1,nums2,st1,ed1,m2+1,ed2,n-small2-1);
} else {
return find(nums1,nums2,m1+1,ed1,st2,ed2,n-small1-1);
}
}
}
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int s1=nums1.size();
int s2=nums2.size();
if((s1+s2)%2==1){
int res=find(nums1,nums2,0,s1-1,0,s2-1,(s1+s2)/2+1);
if(res<s1)
return nums1[res];
else
return nums2[res-s1];
} else {
int res1=find(nums1,nums2,0,s1-1,0,s2-1,(s1+s2)/2);
int res2=find(nums1,nums2,0,s1-1,0,s2-1,(s1+s2)/2+1);
double aa,bb;
if(res1<s1)
aa= nums1[res1];
else
aa= nums2[res1-s1];
if(res2<s1)
bb= nums1[res2];
else
bb= nums2[res2-s1];
return (aa+bb)/2;
}
}
};