
LeetCode 416: Partition Equal Subset Sum gives us another chance to practice multidimensional dynamic programming. The problem asks:
Given an integer array nums, return true if you can partition the array into two subsets A and B such that the sum of the elements in A equals the sum of the elements in B. Return false if this is not possible.
Recursive Solution
If you’re not sure how to solve a LeetCode problem, it’s useful to write a brute-force solution. It will help you understand the problem, and you can run it on small test cases as an initial correctness check. With dynamic programming problems, there’s the additional benefit of a step-by-step process to make your brute-force solution much faster. Let’s try that for Partition Equal Subset Sum.
The standard brute-force approach for DP problems uses a recursive design. Consider a recursive function IsEqual that takes a position in the array and the current subset sums for A and B, and returns true if the remaining elements can be assigned to A and B so that the two sums end up equal. We can use the following parameters:
nums: the input array, of size n.
pos: the current position in the input array.
sumA: the current sum of the elements in subset A.
sumB: the current sum of the elements in subset B.
As a base case, if we’re past the end of the array, we can just directly compare sumA and sumB. So if pos == n, we return true if sumA == sumB and false otherwise.
If we have not reached the end of the input array, then we have two choices: We can add the current number to sumA, or we can add it to sumB. Then we move to the next position in the array. If adding the current number to sumA eventually results in equal subset sums, or adding the current number to sumB eventually results in equal subset sums, then we have a valid solution. If it is not possible to get equal subset sums regardless of which subset we add the current number to, then a solution is not possible from this state. Here’s the recursive call to implement that:
return IsEqual(nums, pos+1, sumA + nums[pos], sumB) ||
       IsEqual(nums, pos+1, sumA, sumB + nums[pos])
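The recursive design can be sketched in Python as follows. This is a minimal sketch, not the author's code: is_equal and can_partition_brute_force are hypothetical snake_case stand-ins for IsEqual and a top-level driver. It runs in exponential time, so it is only suitable for small inputs.

```python
from typing import List


def is_equal(nums: List[int], pos: int, sum_a: int, sum_b: int) -> bool:
    """Return True if nums[pos:] can be assigned so the final sums of A and B match."""
    # Base case: every element has been assigned, so compare the two sums.
    if pos == len(nums):
        return sum_a == sum_b
    # Try adding nums[pos] to A; if that fails, try adding it to B.
    return (is_equal(nums, pos + 1, sum_a + nums[pos], sum_b)
            or is_equal(nums, pos + 1, sum_a, sum_b + nums[pos]))


def can_partition_brute_force(nums: List[int]) -> bool:
    # Start at the first element with both subsets empty.
    return is_equal(nums, 0, 0, 0)


print(can_partition_brute_force([1, 5, 11, 5]))  # True: {1, 5, 5} and {11}
print(can_partition_brute_force([1, 2, 3, 5]))   # False
```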
As usual with recursive solutions, the calls to IsEqual form a tree. Each call creates two child nodes, the left child where we add the current element to sumA and the right child where we add it to sumB. If either of these calls returns true, we have found a path through the tree that ends at a leaf node with sumA == sumB.
Memoization
The recursive solution will return the correct result for small test cases. But if we draw out the call tree, we’ll find many nodes with the same pos, sumA, and sumB values. That means we’re making many unnecessary recursive calls. To avoid an explosion of recursive calls, we should store the result of each unique call in a memo table.
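One way to check the claim about repeated nodes is to count them. The sketch below is a hypothetical instrumented variant (count_calls is not from the original post). It always explores both children, without the short-circuiting ||, so it walks the full call tree, and it records every (pos, sumA, sumB) triple it visits.

```python
from collections import Counter
from typing import List, Tuple


def count_calls(nums: List[int]) -> Tuple[int, int]:
    """Return (total calls, distinct (pos, sum_a, sum_b) states) for the full call tree."""
    seen: Counter = Counter()

    def visit(pos: int, sum_a: int, sum_b: int) -> None:
        seen[(pos, sum_a, sum_b)] += 1
        if pos == len(nums):
            return
        # Explore both children unconditionally to visit the entire tree.
        visit(pos + 1, sum_a + nums[pos], sum_b)
        visit(pos + 1, sum_a, sum_b + nums[pos])

    visit(0, 0, 0)
    return sum(seen.values()), len(seen)


print(count_calls([1] * 10))  # (2047, 66)
```

With ten equal elements, the full tree has 2,047 nodes but only 66 distinct states, so most of the calls repeat work that a memo table could skip.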
As currently designed, we would need a 3D array for our memo table, since we have three parameters. Let’s check if that is really necessary.
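Before reducing the dimensions, here is a sketch of what the three-parameter memo looks like. It uses a Python dictionary keyed on the (pos, sumA, sumB) tuple in place of a 3D array, which is a choice made for brevity; can_partition_memo_3d and its inner is_equal are hypothetical names.

```python
from typing import Dict, List, Tuple


def can_partition_memo_3d(nums: List[int]) -> bool:
    # Memo table keyed on all three changing parameters.
    memo: Dict[Tuple[int, int, int], bool] = {}

    def is_equal(pos: int, sum_a: int, sum_b: int) -> bool:
        key = (pos, sum_a, sum_b)
        # Reuse a result already computed for this exact state.
        if key in memo:
            return memo[key]
        if pos == len(nums):
            result = sum_a == sum_b
        else:
            result = (is_equal(pos + 1, sum_a + nums[pos], sum_b)
                      or is_equal(pos + 1, sum_a, sum_b + nums[pos]))
        memo[key] = result
        return result

    return is_equal(0, 0, 0)
```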
Let total be the sum of all the input array elements. We want to find sumA == sumB. We need to use all the elements, so sumA + sumB == total. Combining those two equations gives us sumA = total/2 and sumB = total/2. So we only need to keep track of one of the two sums, which we can compare to total/2. This lets us reduce the dimensions of our memo table to 2D.
Also, the problem details tell us that each input integer is between 1 and 100 (inclusive). So once our sum value exceeds total/2, there’s no way to get a solution, since the sum can never decrease. We can immediately return false in this case.
Finally, if total is odd, there’s no way for the sums of both subsets to be total/2. So in this case we can return false before we make any recursive calls.
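With these optimizations, we can implement our memo table using a boolean array dp with dimensions n+1 and total+1 to store each result once we have calculated it the first time. The second dimension has to reach total rather than total/2, because a caller stores the result for newSum before the callee checks whether newSum exceeds total/2; since newSum is a sum of distinct input elements, it never exceeds total. As an implementation detail, we'll want a boolean data type that can store three values: null, false, and true. A null value means we haven't calculated that result yet.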
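In Python, one way to get these three-valued entries is to fill the table with None and store True or False as results are computed. A minimal sketch, where make_memo_table is a hypothetical helper name:

```python
from typing import List, Optional


def make_memo_table(n: int, total: int) -> List[List[Optional[bool]]]:
    # One independent row per position; None means "not computed yet".
    # Note: [[None] * (total + 1)] * (n + 1) would alias a single row n + 1 times.
    return [[None] * (total + 1) for _ in range(n + 1)]


dp = make_memo_table(4, 22)
dp[1][5] = False
# Test for "unknown" with `is None`; `not dp[1][5]` would also be true for a
# stored False and cause that result to be recomputed.
print(dp[1][5] is None, dp[2][5] is None)  # False True
```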
Pseudocode
bool CanPartition(int[] nums)
    n = length of nums
    total = sum of all elements in nums
    // we can't divide an odd integer into two equal halves
    if total is odd, return false
    initialize bool dp[n+1][total+1] with every entry null
    return IsEqual(0, 0, nums, dp, n, total)

bool IsEqual(pos, sum, nums, dp, n, total)
    // sum is too large (and it can never decrease)
    if sum > total/2
        return false
    // finished processing; check the result
    if pos == n
        return sum == total/2
    // case 1: don't use the current number in the sum
    if dp[pos+1][sum] == null
        dp[pos+1][sum] = IsEqual(pos+1, sum, nums, dp, n, total)
    // case 2: do use the current number in the sum
    newSum = sum + nums[pos]
    if dp[pos+1][newSum] == null
        dp[pos+1][newSum] = IsEqual(pos+1, newSum, nums, dp, n, total)
    // if either case returned true, we have a true result
    return dp[pos+1][sum] || dp[pos+1][newSum]
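Here is a Python translation of the pseudocode, offered as a sketch rather than the post's reference solution. Python has no three-valued boolean type, so None stands in for null; can_partition and is_equal are snake_case renamings of CanPartition and IsEqual, and both n and total are passed explicitly because IsEqual needs them.

```python
from typing import List, Optional


def can_partition(nums: List[int]) -> bool:
    n = len(nums)
    total = sum(nums)
    # An odd total can't be split into two equal integer halves.
    if total % 2 == 1:
        return False
    # dp[pos][s]: None = not computed yet, otherwise the cached result.
    dp: List[List[Optional[bool]]] = [[None] * (total + 1) for _ in range(n + 1)]
    return is_equal(0, 0, nums, dp, n, total)


def is_equal(pos: int, s: int, nums: List[int],
             dp: List[List[Optional[bool]]], n: int, total: int) -> bool:
    half = total // 2
    # The sum is too large, and it can never decrease (all inputs are positive).
    if s > half:
        return False
    # Finished processing: check the result.
    if pos == n:
        return s == half
    # Case 1: don't use the current number in the sum.
    if dp[pos + 1][s] is None:
        dp[pos + 1][s] = is_equal(pos + 1, s, nums, dp, n, total)
    # Case 2: use the current number in the sum.
    new_sum = s + nums[pos]
    if dp[pos + 1][new_sum] is None:
        dp[pos + 1][new_sum] = is_equal(pos + 1, new_sum, nums, dp, n, total)
    # If either case returned True, the answer from this state is True.
    return bool(dp[pos + 1][s] or dp[pos + 1][new_sum])


print(can_partition([1, 5, 11, 5]))  # True
print(can_partition([1, 2, 3, 5]))   # False
```

The sum parameter is named s to avoid shadowing Python's built-in sum. Recursion depth grows with n, which stays well under Python's default recursion limit for the problem's input sizes.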
References
For an epic whiteboard editorial for this problem, see George Chrysochoou’s post on LeetCode.
(Image credit: DALL·E 3)
For an introduction to this year’s project, see A Project for 2024.