
3Sum

Medium
  • Arrays
  • Two Pointers
  • Sorting
NeetCode 150

the trick

Sorting turns the inner search into a two-pointer scan: fix the first number, then converge from both ends to hit the target. Sorted order also makes duplicate triplets easy to skip.

The idea

Find every unique triplet that sums to zero. The "unique" part is what makes this more than a triple loop. Once I fix the first number a, the rest collapses to a familiar problem: find two numbers in the remaining array that sum to -a.

The move that unlocks everything is sorting the array first. With sorted input, two pointers find every matching pair in one linear pass. Sorting also puts equal values next to each other, so I can skip duplicates cleanly.
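The pair search for one fixed a can be sketched on its own. This is a minimal sketch that assumes the input is already sorted ascending. `pairs_summing_to` and its `lo` parameter (the index where the remaining window starts) are hypothetical names chosen for illustration.

```python
def pairs_summing_to(sorted_nums: list[int], lo: int, target: int) -> list[tuple[int, int]]:
    """Return the unique value pairs in sorted_nums[lo:] that sum to target.

    Hypothetical helper for illustration; assumes sorted_nums is sorted ascending.
    """
    pairs: list[tuple[int, int]] = []
    left, right = lo, len(sorted_nums) - 1
    while left < right:
        s = sorted_nums[left] + sorted_nums[right]
        if s < target:
            left += 1   # sum too small: only a larger left value can raise it
        elif s > target:
            right -= 1  # sum too big: only a smaller right value can lower it
        else:
            pairs.append((sorted_nums[left], sorted_nums[right]))
            left += 1
            right -= 1
            # Equal values are adjacent, so step past repeats of the pair just recorded
            while left < right and sorted_nums[left] == sorted_nums[left - 1]:
                left += 1
            while left < right and sorted_nums[right] == sorted_nums[right + 1]:
                right -= 1
    return pairs
```

On the sorted sample `[-4, -1, -1, 0, 1, 2]`, fixing a = -1 at index 1 and searching from index 2 for target 1 gives `[(-1, 2), (0, 1)]`.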

Brute force, then the cut

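The naive version checks all triples with three nested loops, O(n^3), then dedupes with a set. That works, but it's slow, and the dedup is clumsy: each triplet has to be sorted before it goes into the set so that reorderings of the same values collapse to one key.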
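Here is a minimal sketch of that baseline, with the sort-before-insert dedup spelled out. `three_sum_brute` is a hypothetical name. Its main use is as a slow reference to test the faster version against.

```python
def three_sum_brute(nums: list[int]) -> list[list[int]]:
    """O(n^3) baseline: try every index triple, dedupe through a set."""
    n = len(nums)
    seen: set[tuple[int, ...]] = set()
    for i in range(n):
        for j in range(i + 1, n):
            for k in range(j + 1, n):
                if nums[i] + nums[j] + nums[k] == 0:
                    # Sort the values so (0, -1, 1) and (-1, 0, 1) share one key
                    seen.add(tuple(sorted((nums[i], nums[j], nums[k]))))
    return [list(t) for t in sorted(seen)]
```

On `[-1, 0, 1, 2, -1, -4]` it returns `[[-1, -1, 2], [-1, 0, 1]]`, the same triplets the two-pointer version finds.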

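After sorting, I walk one index i for the first element. For each i, two pointers, left and right, start at the two ends of the remaining window (i + 1 and the last index) and converge. If the three-way sum is too small, I step left forward to raise it. If it's too big, I step right back to lower it. If it's zero, I record the triplet and move both pointers inward. Because the array is sorted, advancing i past repeated values, and advancing left and right past repeated values after a hit, skips duplicate triplets without a set.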

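```python
def three_sum(nums: list[int]) -> list[list[int]]:
    nums.sort()  # sorts in place, so the caller's list is reordered
    res: list[list[int]] = []

    for i in range(len(nums) - 2):
        # Smallest of the three is already positive: no later triplet can sum to zero
        if nums[i] > 0:
            break
        # Same anchor value as the previous i would only repeat triplets
        if i > 0 and nums[i] == nums[i - 1]:
            continue

        left, right = i + 1, len(nums) - 1
        while left < right:
            total = nums[i] + nums[left] + nums[right]
            if total < 0:
                left += 1   # need a larger sum
            elif total > 0:
                right -= 1  # need a smaller sum
            else:
                res.append([nums[i], nums[left], nums[right]])
                left += 1
                right -= 1
                # Skip values equal to the ones just used so this triplet isn't recorded twice
                while left < right and nums[left] == nums[left - 1]:
                    left += 1
                while left < right and nums[right] == nums[right + 1]:
                    right -= 1

    return res
```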
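To see the pointer moves on the classic sample input, here is a sketch of the same loop instrumented with a step log. `three_sum_traced` and its log format are hypothetical additions for illustration. Unlike the version above, it sorts a copy, so the caller's list is left untouched.

```python
def three_sum_traced(nums: list[int]) -> tuple[list[list[int]], list[str]]:
    """Same two-pointer scan as three_sum, plus a log of every step (illustrative)."""
    nums = sorted(nums)  # work on a sorted copy instead of mutating the input
    res: list[list[int]] = []
    log: list[str] = []
    for i in range(len(nums) - 2):
        if nums[i] > 0:
            log.append(f"i={i}: nums[i]={nums[i]} > 0, stop")
            break
        if i > 0 and nums[i] == nums[i - 1]:
            log.append(f"i={i}: duplicate anchor {nums[i]}, skip")
            continue
        left, right = i + 1, len(nums) - 1
        while left < right:
            total = nums[i] + nums[left] + nums[right]
            log.append(f"i={i} l={left} r={right}: {nums[i]}+{nums[left]}+{nums[right]}={total}")
            if total < 0:
                left += 1
            elif total > 0:
                right -= 1
            else:
                res.append([nums[i], nums[left], nums[right]])
                left += 1
                right -= 1
                while left < right and nums[left] == nums[left - 1]:
                    left += 1
                while left < right and nums[right] == nums[right + 1]:
                    right -= 1
    return res, log


res, log = three_sum_traced([-1, 0, 1, 2, -1, -4])
print("\n".join(log))
print(res)  # [[-1, -1, 2], [-1, 0, 1]]
```

The sorted array is `[-4, -1, -1, 0, 1, 2]`. The -4 anchor walks left all the way across without a hit. The first -1 anchor finds both triplets. The second -1 is skipped as a duplicate anchor.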

Why it holds

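The nums[i] > 0 break is safe because the array is sorted: once the smallest of the three numbers is positive, the other two are positive as well, so the sum can't be zero. Skipping equal i values, and equal left/right values after a hit, makes each triplet appear once. The outer loop runs at most n times and each two-pointer pass is O(n), so the scan is O(n^2) time, which dominates the O(n log n) sort. Extra space is O(1) beyond the output and whatever the sort itself uses (Python's list.sort can need up to O(n) auxiliary memory).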
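One way to gain confidence in the duplicate-skipping argument is to compare the two-pointer scan against exhaustive search on many small random inputs. This is a minimal sketch. The helper names, value range, list sizes, and trial count are arbitrary choices for illustration. The two-pointer function here sorts a copy so the random input stays intact for the brute-force pass.

```python
import random


def three_sum(nums: list[int]) -> list[list[int]]:
    # Two-pointer scan as in the solution above, applied to a sorted copy
    nums = sorted(nums)
    res: list[list[int]] = []
    for i in range(len(nums) - 2):
        if nums[i] > 0:
            break
        if i > 0 and nums[i] == nums[i - 1]:
            continue
        left, right = i + 1, len(nums) - 1
        while left < right:
            total = nums[i] + nums[left] + nums[right]
            if total < 0:
                left += 1
            elif total > 0:
                right -= 1
            else:
                res.append([nums[i], nums[left], nums[right]])
                left += 1
                right -= 1
                while left < right and nums[left] == nums[left - 1]:
                    left += 1
                while left < right and nums[right] == nums[right + 1]:
                    right -= 1
    return res


def three_sum_brute(nums: list[int]) -> list[tuple[int, ...]]:
    # Exhaustive O(n^3) reference, deduped by sorted value tuples
    n = len(nums)
    seen: set[tuple[int, ...]] = set()
    for i in range(n):
        for j in range(i + 1, n):
            for k in range(j + 1, n):
                if nums[i] + nums[j] + nums[k] == 0:
                    seen.add(tuple(sorted((nums[i], nums[j], nums[k]))))
    return sorted(seen)


def cross_check(trials: int = 2000, seed: int = 0) -> None:
    rng = random.Random(seed)  # seeded so any failure is reproducible
    for _ in range(trials):
        nums = [rng.randint(-4, 4) for _ in range(rng.randint(0, 9))]
        fast = three_sum(nums)
        # No triplet may appear twice in the two-pointer output
        assert len(fast) == len({tuple(t) for t in fast}), nums
        # Same triplets as exhaustive search
        assert sorted(map(tuple, fast)) == three_sum_brute(nums), nums


cross_check()
```

The narrow value range makes repeated values common, so the anchor-skip and pointer-skip branches get exercised on most trials.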