Programming Praxis – List Intersection And Union

In today’s Programming Praxis exercise, our goal is to write union and intersection functions for lists in three different time complexities: O(n^2), O(n log n) and O(n). We then need to show that our functions actually have the correct complexity by timing them. Let’s get started, shall we?

Some imports:

import qualified Data.IntSet as I
import qualified Data.Set as S
import System.CPUTime
import Text.Printf

Rather than trying to come up with entirely different algorithms, I’m going to use this opportunity to show the importance of choosing an appropriate data structure for your algorithms. We do this by using essentially the same algorithm for all three complexities, each with a different data structure.

The simplest algorithm for union is to take all the elements in the first list plus the ones from the second list that we don’t already have. Taking the intersection of two lists can be achieved by taking all the elements from the first list that also appear in the second list. The interesting operation in both algorithms is checking whether an element is present in a list, since it determines the overall time complexity. On a linked list, this operation is O(n). There are other data structures, however, that offer better performance. The generic algorithms, therefore, need to know how to convert a linked list to the desired data structure and how to find an element in that data structure.

genUnion, genIntersect :: ([a] -> b) -> (a -> b -> Bool) -> [a] -> [a] -> [a]
genUnion     load get xs ys = xs ++ filter (not . flip get (load xs)) ys
genIntersect load get xs ys = filter (flip get (load ys)) xs

For the O(n^2) version, we use a linked list as our data structure, so there’s no conversion to be done.

union_n2, intersect_n2 :: Eq a => [a] -> [a] -> [a]
union_n2        = genUnion     id elem
intersect_n2    = genIntersect id elem

By storing our list in a Set, we get O(log n) lookup rather than O(n), which reduces the total complexity to O(n log n).

union_nlogn, intersect_nlogn :: Ord a => [a] -> [a] -> [a]
union_nlogn     = genUnion     S.fromList S.member
intersect_nlogn = genIntersect S.fromList S.member

Since we’re working with Ints, we can store them in an IntSet, whose lookup takes O(min(n, W)) time, where W is the number of bits in an Int. Because W is a fixed constant, lookup is effectively O(1), resulting in O(n) total time.

union_n, intersect_n :: [Int] -> [Int] -> [Int]
union_n         = genUnion     I.fromList I.member
intersect_n     = genIntersect I.fromList I.member

Timing how long something takes in Haskell is a bit trickier than in most languages, since Haskell uses lazy evaluation. If we don’t print the result or otherwise use it, Haskell will happily avoid doing any work and report that doing nothing took 0 ms. Since I’m only interested in the duration, I don’t want to print the result; instead I use the $! operator to evaluate the result to weak head normal form. Note that this only works if evaluating the result to weak head normal form evaluates all of it. An Int, for example, works fine, but if the result is a list only the first cons cell will be evaluated, which will once again lead to a time of 0 ms.

time :: a -> IO Integer
time f = do start <- getCPUTime
            _     <- return $! f
            end   <- getCPUTime
            -- getCPUTime reports picoseconds; dividing by 10^9 gives milliseconds
            return (div (end - start) (10^9))
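To make the caveat about lists concrete, here is a minimal sketch of how far $! actually evaluates. The value `booby` is a hypothetical test list, not part of the exercise: its first cell is fine, but touching its tail raises an error, so we can see exactly when the tail gets evaluated.

```haskell
import Control.Exception (ErrorCall, evaluate, try)

-- Hypothetical test value: the first cons cell is fine, the tail is an error.
booby :: [Int]
booby = 1 : error "tail was evaluated"

main :: IO ()
main = do
  -- ($!) only evaluates to weak head normal form, i.e. the first (:) cell,
  -- so the error hiding in the tail is never reached.
  _ <- return $! booby
  putStrLn "$!: tail untouched"
  -- length has to walk the whole spine of the list, so it does hit the error.
  r <- try (evaluate (length booby)) :: IO (Either ErrorCall Int)
  putStrLn (either (const "length: tail was forced") show r)
```

This is why wrapping a list-valued result in something like length before timing it makes the measurement meaningful.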

To benchmark a function, we feed it a number of lists that double in size each time and print the results in a row. We use the length function to force evaluation of the whole list; it returns an Int, which is fully evaluated once it is in weak head normal form.

benchmark :: ([Int] -> [Int] -> [Int]) -> [Int] -> IO ()
benchmark f ns = putStrLn . (printf "%7d" =<<) =<<
    mapM (\n -> time (length $ f [1..n] [1..n])) ns
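The row formatting in benchmark is fairly terse, so here is a small sketch of what the (printf "%7d" =<<) part does on its own. For lists, (f =<<) is the same as concatMap f, so every timing is rendered in a 7-character column and the columns are concatenated into one line. The name `formatRow` is made up for this illustration.

```haskell
import Text.Printf (printf)

-- Hypothetical helper: render each timing in a 7-character column
-- and join the columns into a single row.
formatRow :: [Integer] -> String
formatRow = concatMap (printf "%7d")

main :: IO ()
main = do
  putStrLn (formatRow [546, 2043, 8252, 32963])
  -- In the list monad, (f =<<) is concatMap f, so both spellings agree.
  print (formatRow [15, 46] == (printf "%7d" =<< [15, 46 :: Integer]))
```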

All that’s left to do is to first test whether our functions work correctly and then to benchmark them. I cut the O(n^2) version off at 80000 elements because I’m not that patient; adding the other four steps would take almost three hours.

main :: IO ()
main = do let a = [4,7,12,6,17,5,13]
          let b = [7,19,4,11,13,2,15]
          let testUnion f = print $ f a b == [4,7,12,6,17,5,13,19,11,2,15]
          let testIsect f = print $ f a b == [4,7,13]
          testUnion union_n2
          testIsect intersect_n2
          testUnion union_nlogn
          testIsect intersect_nlogn
          testUnion union_n
          testIsect intersect_n

          let ns = iterate (*2) 10000
          benchmark union_n2    $ take 4 ns
          benchmark union_nlogn $ take 8 ns
          benchmark union_n     $ take 8 ns

Here’s the resulting timing table, with all times in ms. As we can see, the O(n^2) version takes about four times as long each time the input length doubles, which is what we expect from an O(n^2) algorithm. The O(n log n) version grows a little faster than O(n), as expected, and the O(n) version appears to be linear. Looks like the default Haskell collection types perform like they should. The 0 and 15 ms timings at the beginning are a consequence of running this test on Windows, which has a timer resolution of 15 ms.

n              10000   20000   40000   80000  160000  320000  640000 1280000
O(n^2)           546    2043    8252   32963
O(n log n)        15      15      46      93     187     436     858    1872
O(n)               0       0      15      31      62     124     296     592
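One way to check the growth rates is to divide each timing by the one before it. The following is only a sketch using the numbers from the table above; the names `ratios`, `quadratic`, `linearithmic` and `linear` are invented for this illustration. A ratio close to 4 per doubling points to quadratic growth, and a ratio close to 2 (or slightly above) points to linear or n log n growth.

```haskell
-- Timings in ms, copied from the table above; n doubles at each step.
quadratic, linearithmic, linear :: [Double]
quadratic    = [546, 2043, 8252, 32963]
linearithmic = [15, 15, 46, 93, 187, 436, 858, 1872]
linear       = [0, 0, 15, 31, 62, 124, 296, 592]

-- Ratio of each timing to the previous one. Readings of 0 ms are below
-- the timer resolution, so they are skipped rather than divided by.
ratios :: [Double] -> [Double]
ratios ts = [ b / a | (a, b) <- zip ts (drop 1 ts), a > 0 ]

main :: IO ()
main = mapM_ (print . ratios) [quadratic, linearithmic, linear]
```

The smallest timings are too close to the 15 ms timer resolution to give reliable ratios, so only the later columns say much about the trend.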

As we can see, even if we keep the algorithm the same, the choice of data structure can have a huge impact on performance.