[Leetcode解題] 3Sum - 前後指針解

26 November 2022
medium two-pointers bs apple amazon microsoft

題目

15. 3Sum

解題思路

這題要判斷給定的數列中,列出數列中三個數字的和為零的組合。這個問題可以先簡化成兩個數字,也就是TwoSum。

假設我們給定一個數列,還有一個給定的數字n,我們想知道,所有不重複兩個數字和為n的:

def twoSum(nums, n):
    ans = []
    # Our solution
    return ans

暴力解 $O(n^2)$

我們可以想到最簡單的方法,就是用兩層迴圈去窮舉所有可能,但我們要注意一點,因為題目提供的數列有可能有重複的數字,所以我們要先排序後,並跳過重複數字的部分。

def twoSum(nums, n):
    ans = []
    nums = sorted(nums)
    i = 0
    while i < len(nums):
        j = i+1
        while j < len(nums):
            if nums[i] + nums[j] == n:
                ans.append([nums[i], nums[j]])
            # Ensure that repeated number is not considered
            while j+1 < len(nums) and nums[j+1] == nums[j]:
                j += 1
            j += 1
        # Ensure that repeated number is not considered
        while i+1 < len(nums) and nums[i+1] == nums[i]:
            i += 1
        i += 1            
    return ans

使用這個方法,時間複雜度為排序$O(n log(n))$加上兩層迴圈$O(n^2)$,所以為$O(n^2)$

二分搜尋法(Binary Search) $O(nlog(n))$

大家觀察暴力解的作法,由於一開始我們已經將數列排序,所以第二層迴圈的部分可以用Binary Search來取代。

def binary_search(nums, n):
    if len(nums) == 0:
        return False
    upper_bound = len(nums)
    lower_bound = 0
    mid = (upper_bound + lower_bound) // 2
    if nums[mid] == n:
        return True
    if nums[mid] < n:
        return binary_search(nums[mid+1:], n)
    if nums[mid] > n:
        return binary_search(nums[:mid], n)

def twoSum(nums, n):
    ans = []
    nums = sorted(nums)
    i = 0
    while i < len(nums):
        # We want to search n-nums[i]
        if binary_search(nums[i+1:], n-nums[i]):
            ans.append([nums[i], n-nums[i]])
        # Ensure that repeated number is not considered
        while i+1 < len(nums) and nums[i+1] == nums[i]:
            i += 1
        i += 1            
    return ans

使用Binary Search,我們可以將時間複雜度降低到$O(n(logn))$

Hash

如果我們走排序後搜尋這條路,那麼二分搜尋法看起來是我們目前可得的最佳解,但是如果在空間足夠的情況下,我們可以使用Hash Table的解法,也就是我們是跑過一次序列,把有出現的值記錄起來,再來再跑過一次序列,每一次我只要對Hash Table查找(n-nums[i])的值是否存在於表中即可,時間為$O(1)$,我們可以進一步將時間壓縮到$O(n)$。

使用這個方法我們也同樣需要避免計算到重覆的組合,每一回合查找完畢,我們就可以將配對到的值移除Hash Table,因為同一個數字,一定只會對應到另一個值。

def twoSum(nums, n):
    ht = dict()
    
    for num in nums:
        ht[num] = True
    
    ans = []
    i = 0
    for num in nums:
        if n-num in ht:
            ans.append([num, n-num])
            ht.remove(num)
            ht.remove(n-num)
        i += 1            
    return ans

Greedy

如果我們的輸入可以確定是一個排序後的數列的話,其實還有另一個解法可以是$O(n)$的時間複雜度,我們用一個head指標從頭看,跟一個tail指標從尾巴看,根據兩者相加的值來移動兩個指標:

  1. 若nums[head] + nums[tail] < n,則移動head往前一格
  2. 若nums[head] + nums[tail] > n,則移動tail往後一格
  3. 若nums[head] + nums[tail] == n,則得到一組解,並同時加head往前一格,以及tail往後一格

照這樣走,我們最壞的情況就是將不斷只移動head或tail,所以時間複雜度也為$O(n)$,讓我們來仔細頗析裡面的原理。

讓我們回想暴力法的情況,代表我們需要檢驗$C^n_2$個可能性,但其中其實有一些組合是我們根本不用去看的,但當然前提是要是排序好的數列。

假設數列為[1,3,5,7,9],而n為14,我們一開始得知1+9=10<14,所以將head往前了一格,這個動作隱含的意義代表我們再也不用考慮1跟其他所有數字的可能性,而這也相當直覺,因為1連搭配最大的數字都還小於目標數字,更何況搭配更小的數字。

而若我們的n為6,1+9=10>6,所以我們將tail往後退了一格,這個動作隱含的是代表我們再也不用考慮9跟其他所有數字搭配的可能性,因為9連跟最小的數字搭配都還大於目標數字,更何況跟更大的數字配。

透過兩個指標來查找的方法相當高效,可以在每次的判斷中,直接去除掉許多不會發生的組合。

Three Sum

所以讓我們回到Three Sum的這題,由於排序不再是我們的Bottle Neck,所以我們可以直接使用兩個指標的作法,並再外層多加一層迴圈即可,所以時間複雜度為$O(n(logn)) + O(n^2)$,即為$O(n^2)$。

class Solution(object):
    def threeSum(self, nums):
        """
        :type nums: List[int]
        :rtype: List[List[int]]
        """
        if len(nums) < 3:
            return []
        nums = sorted(nums)
        i = 0
        prev = None
        out = []
        while i < len(nums)-2 and nums[i] <= 0:
            if prev == nums[i]:
                i += 1
                continue
            prev = nums[i]
            # Two Sum
            head = i+1
            tail = len(nums)-1
            while head < tail:
                if nums[head] + nums[tail] < -1 * nums[i]:
                    while head+1 < len(nums) and nums[head+1] == nums[head]:
                        head += 1
                    head += 1
                elif nums[head] + nums[tail] > -1 * nums[i]:
                    while tail-1 >= 0 and nums[tail-1] == nums[tail]:
                        tail -= 1
                    tail -= 1
                else:
                    out.append([nums[i], nums[head], nums[tail]])
                    while head+1 < len(nums) and nums[head+1] == nums[head]:
                        head += 1
                    while tail-1 >= 0 and nums[tail-1] == nums[tail]:
                        tail -= 1
                    head += 1
                    tail -= 1
        return out