In today’s Programming Praxis exercise, our goal is to implement a Set data structure. Let’s get started, shall we?

import Data.Hashable
import qualified Data.HashTable.IO as H
import Data.List (sort)

The data structure underlying our Set will be a hashtable. The downside is that all operations become monadic; the advantage is that Set elements only need Eq and Hashable instances rather than Ord. Initially I used the HashTable from Data.HashTable.ST.Basic, but I decided that having everything run in the IO monad would be more convenient for callers.

data Set a = Set (H.BasicHashTable a ())

new, member, adjoin and delete are thin wrappers around the existing hashtable functions. Since we only care about the keys in the hashtable, we simply store unit, (), as every value. Additionally, we make adjoin and delete return the set they modified (the same underlying table, updated in place) to make chaining operations easier.

new :: IO (Set a)
new = fmap Set H.new
member :: (Eq a, Hashable a) => a -> Set a -> IO Bool
member x (Set s) = fmap (maybe False $ const True) $ H.lookup s x
adjoin :: (Eq a, Hashable a) => a -> Set a -> IO (Set a)
adjoin x (Set s) = H.insert s x () >> return (Set s)
delete :: (Eq a, Hashable a) => a -> Set a -> IO (Set a)
delete x (Set s) = H.delete s x >> return (Set s)
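
Because the hashtable is mutable, the set returned by adjoin and delete is the same table as the one passed in, not a copy. The sketch below illustrates this with the wrappers from above; demo is a hypothetical name introduced only for this example, and the hashtables and hashable packages are assumed to be installed.

```haskell
import Data.Hashable (Hashable)
import qualified Data.HashTable.IO as H

-- Same definitions as in the post.
data Set a = Set (H.BasicHashTable a ())

new :: IO (Set a)
new = fmap Set H.new

member :: (Eq a, Hashable a) => a -> Set a -> IO Bool
member x (Set s) = fmap (maybe False (const True)) (H.lookup s x)

adjoin :: (Eq a, Hashable a) => a -> Set a -> IO (Set a)
adjoin x (Set s) = H.insert s x () >> return (Set s)

delete :: (Eq a, Hashable a) => a -> Set a -> IO (Set a)
delete x (Set s) = H.delete s x >> return (Set s)

-- Hypothetical demo: delete through one handle, then query the other.
demo :: IO (Bool, Bool, Bool)
demo = do
  s     <- adjoin (1 :: Int) =<< adjoin 2 =<< new
  s2    <- delete 2 s     -- s2 wraps the same table as s
  has1  <- member 1 s     -- 1 was never removed
  has2  <- member 2 s     -- removed via s2, so also gone from s
  has3  <- member 3 s2    -- 3 was never inserted
  return (has1, has2, has3)
```

Running demo should yield (True, False, False). In other words, chaining is a convenience only: if you need to keep the old contents around, build a fresh set first.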

fold is a convenience function that reorders the parameters of the existing fold on hashtables and ignores the values, which results in significantly cleaner code in some of the functions below.

fold :: (a -> b -> IO b) -> Set a -> b -> IO b
fold f (Set s) x = H.foldM (\a (k,_) -> f k a) x s
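
To give an idea of how fold reads at a call site, here is a small sketch that sums the elements of a set. The names fromList, total and demo are hypothetical helpers introduced for this example only; they are not part of the solution.

```haskell
import Data.Hashable (Hashable)
import qualified Data.HashTable.IO as H

-- Same definitions as in the post.
data Set a = Set (H.BasicHashTable a ())

fold :: (a -> b -> IO b) -> Set a -> b -> IO b
fold f (Set s) x = H.foldM (\a (k, _) -> f k a) x s

-- Hypothetical helper: build a set from a list, storing () for every key.
fromList :: (Eq a, Hashable a) => [a] -> IO (Set a)
fromList xs = do
  t <- H.new
  mapM_ (\x -> H.insert t x ()) xs
  return (Set t)

-- Hypothetical example: the step function gets the element first and
-- the accumulator second, matching the argument order of fold.
total :: Set Int -> IO Int
total s = fold (\k acc -> return (k + acc)) s 0

demo :: IO Int
demo = total =<< fromList [1, 2, 3, 2]
```

Since the duplicate 2 collapses into a single key, demo should return 6. Addition is order-independent, so the unspecified traversal order of the hashtable does not matter here.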

For a union, we simply insert all the keys of both sets in a new one. Thanks to the fold function we can chain everything together nice and neat.

union :: (Eq a, Hashable a) => Set a -> Set a -> IO (Set a)
union s1 s2 = fold adjoin s2 =<< fold adjoin s1 =<< new

Since intersect and minus are virtually identical, I’ve refactored the common code into a combine function.

combine :: (Eq a, Hashable a) => (Bool -> Bool) -> Set a -> Set a -> IO (Set a)
combine cond s1 s2 = fold (\k a -> member k s2 >>= \b ->
    if cond b then adjoin k a else return a) s1 =<< new

The intersect function takes the elements from the first set that do exist in the other one…

intersect :: (Eq a, Hashable a) => Set a -> Set a -> IO (Set a)
intersect = combine id

and the minus function takes the ones that don’t.

minus :: (Eq a, Hashable a) => Set a -> Set a -> IO (Set a)
minus = combine not
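
One thing worth spelling out is that combine only ever walks the first set, so the argument order matters for minus but not for intersect. The sketch below is a self-contained variant that writes the step function of combine with do-notation (a stylistic assumption on my part, the behaviour should be the same); demo is a hypothetical name used only here.

```haskell
import Data.Hashable (Hashable)
import qualified Data.HashTable.IO as H
import Data.List (sort)

-- Same definitions as in the post.
data Set a = Set (H.BasicHashTable a ())

new :: IO (Set a)
new = fmap Set H.new

member :: (Eq a, Hashable a) => a -> Set a -> IO Bool
member x (Set s) = fmap (maybe False (const True)) (H.lookup s x)

adjoin :: (Eq a, Hashable a) => a -> Set a -> IO (Set a)
adjoin x (Set s) = H.insert s x () >> return (Set s)

fold :: (a -> b -> IO b) -> Set a -> b -> IO b
fold f (Set s) x = H.foldM (\a (k, _) -> f k a) x s

toList :: Set a -> IO [a]
toList s = fold (\k ks -> return (k : ks)) s []

-- combine with the lambda pulled out into a named step function.
combine :: (Eq a, Hashable a) => (Bool -> Bool) -> Set a -> Set a -> IO (Set a)
combine cond s1 s2 = fold step s1 =<< new
  where
    -- Keep k when cond approves of its membership in the second set.
    step k acc = do
      inOther <- member k s2
      if cond inOther then adjoin k acc else return acc

-- Hypothetical demo: intersection, then difference in both directions.
demo :: IO ([Int], [Int], [Int])
demo = do
  s     <- adjoin 1 =<< adjoin 2 =<< adjoin 3 =<< new
  t     <- adjoin 3 =<< adjoin 4 =<< adjoin 5 =<< new
  both  <- toList =<< combine id s t
  onlyS <- toList =<< combine not s t
  onlyT <- toList =<< combine not t s
  return (sort both, sort onlyS, sort onlyT)
```

With s = {1,2,3} and t = {3,4,5}, demo should return ([3],[1,2],[4,5]). The results are sorted because the hashtable gives no ordering guarantee.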

To convert the hashtable to a list we just cons all the elements together. Note that since the order is determined by the hashing algorithm, the resulting list is not guaranteed to be ordered. Hence you will see calls to sort in the tests when the results are checked.

toList :: Set a -> IO [a]
toList s = fold ((return .) . (:)) s []

We could calculate the size of the set with another fold, but going through toList is shorter and more intuitive, at the cost of building an intermediate list.

size :: Set a -> IO Int
size = fmap length . toList
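
For comparison, the fold-based alternative mentioned above could look roughly like the sketch below. sizeFold, fromList and demo are hypothetical names introduced for this example; the version in the solution remains the toList one.

```haskell
import Data.Hashable (Hashable)
import qualified Data.HashTable.IO as H

-- Same definitions as in the post.
data Set a = Set (H.BasicHashTable a ())

fold :: (a -> b -> IO b) -> Set a -> b -> IO b
fold f (Set s) x = H.foldM (\a (k, _) -> f k a) x s

-- Hypothetical helper: build a set from a list, storing () for every key.
fromList :: (Eq a, Hashable a) => [a] -> IO (Set a)
fromList xs = do
  t <- H.new
  mapM_ (\x -> H.insert t x ()) xs
  return (Set t)

-- Hypothetical alternative to size: count elements directly, ignoring
-- the element itself and forcing the counter at every step.
sizeFold :: Set a -> IO Int
sizeFold s = fold (\_ n -> return $! n + 1) s 0

demo :: IO Int
demo = sizeFold =<< fromList "hello"
```

The string "hello" contains four distinct characters, so demo should return 4. This version never materialises the list of elements, which might matter for very large sets; for the sizes used in the tests the difference is negligible.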

Some tests to see if everything is working correctly:

main :: IO ()
main = do s <- adjoin 1 =<< adjoin 2 =<< adjoin 3 =<< new
          t <- adjoin 3 =<< adjoin 4 =<< adjoin 5 =<< new
          print . (== [1..3]) . sort =<< toList s
          print . (== 3) =<< size s
          print . (== [3..5]) . sort =<< toList t
          print . (== 3) =<< size t
          print . (== [3 :: Int]) =<< toList =<< intersect s t
          print . (== [1..5]) . sort =<< toList =<< union s t
          print . (== [1..2]) . sort =<< toList =<< minus s t