AVL Tree Implementation in Clojure

Friday, July 17th, 2009

To learn Clojure (and data structures) better, I decided to write an AVL tree implementation in Clojure. It was complicated enough that I ran into problems with Clojure, while still being doable in about a day, so it was about the right size for what I wanted to do.

An AVL tree is a balanced binary search tree, optimized for fast lookups. It’s much like a Red-Black tree: both guarantee O(log n) lookups, but an AVL tree is more strictly balanced, so lookups tend to be a little faster in practice, typically at the cost of more rotations on insertion and removal. At any node, the heights of the node’s two subtrees can differ by at most 1; if this invariant does not hold, the tree is not a valid AVL tree. The implementation I wrote is entirely immutable: it never changes the data on a node, it just creates a new one. This is great for concurrency, since immutable data never needs locking.

(import '(java.util Random))
(use 'clojure.contrib.test-is)

These are just the imports needed to implement the AVL tree. The only Java library you need is for generating random numbers, and Clojure’s test-is framework is what I used for unit-testing some of the smaller functions.

(defstruct avl-tree :data :height :left :right)

This defines an avl-tree structure. I ended up using it in only one place, and usually just overwrote attributes with assoc, but it also serves as documentation of the structure.

(defn get-height
  "Returns the height of an AVL tree, or -1 if nil"
  ([tree] (if tree (tree :height) -1))
  {:test (fn []
	   (is (= (get-height {:height 0}) 0))
	   (is (= (get-height nil) -1)))})

This gets the height of an AVL tree, or just -1 if the tree doesn’t exist. This is very useful, since it allows treating empty subtrees the same as actual nodes for purposes of height calculations. If we just attempted to do (tree :height) on a nil tree, a null pointer exception would be thrown instead of -1 being returned.
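As a small illustration of the nil problem, here is a sketch of an alternative that leans on keyword lookup instead of an explicit if. The names height-of and call-as-fn are made up for this example and are not part of the implementation above. A keyword in function position accepts a default value and tolerates nil, whereas calling nil as a function throws.

```clojure
;; Hypothetical alternative to get-height: keyword lookup with a default.
;; Note one difference: a map that simply lacks :height also yields -1.
(defn height-of [tree]
  (:height tree -1))

;; Calls the tree as a function, as (tree :height) does in get-height.
;; This throws when tree is nil.
(defn call-as-fn [tree]
  (tree :height))

(height-of {:height 0})   ; the stored height
(height-of nil)           ; the default, -1

(try
  (call-as-fn nil)
  (catch Exception _ :threw))   ; the call throws, so the catch branch runs
```

Either style works; the explicit if in get-height makes the nil case more visible to the reader.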

(defn balance-factor
  "Returns the height of the right subtree - height of left subtree"
  ([tree] (- (get-height (tree :right)) (get-height (tree :left))))
  {:test (fn []
	   ;; Compare with = so the assertion actually checks the value
	   (is (= (balance-factor {:right {:height 3} :left {:height 5}}) -2)))})

This just returns the balance factor, which is the height of the right subtree minus the height of the left subtree. In a valid AVL tree it is -1, 0, or 1 at every node.

(defn rrotate
  "Performs a right rotation on an AVL tree"
  [tree]
  ;; The left child must exist: it becomes the new root of this subtree
  (let [right-height (inc (max (get-height (tree :right))
                               (get-height ((tree :left) :right))))]
    (assoc (tree :left)
      :height (inc (max right-height (get-height ((tree :left) :left))))
      :right (assoc tree
               :height right-height
               :left ((tree :left) :right)))))

(defn lrotate
  "Performs a left rotation on an AVL tree"
  [tree]
  ;; The right child must exist: it becomes the new root of this subtree
  (let [left-height (inc (max (get-height (tree :left))
                              (get-height ((tree :right) :left))))]
    (assoc (tree :right)
      :height (inc (max left-height (get-height ((tree :right) :right))))
      :left (assoc tree
              :height left-height
              :right ((tree :right) :left)))))

These functions perform left and right rotations on AVL trees. They also work on any other kind of binary tree, but there they will attach extra (probably incorrect) height information. Tree rotations are used in AVL and Red-Black trees to maintain their invariants: they rearrange the nodes of a binary search tree so that it remains a valid search tree, essentially by ‘rotating’ about a pivot node. For more information about these, look here.
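To make the idea of a rotation concrete, here is a stripped-down sketch on plain maps. The names rotate-right, in-order, chain and rotated are invented for this example, and heights are deliberately left out so that only the change in shape is visible.

```clojure
;; Hypothetical helper: a right rotation without any height bookkeeping.
(defn rotate-right
  "The left child (the pivot) becomes the new root of the subtree."
  [tree]
  (let [pivot (:left tree)]
    (assoc pivot :right (assoc tree :left (:right pivot)))))

(defn in-order
  "Returns the values of a tree as an in-order sequence."
  [tree]
  (when tree
    (concat (in-order (:left tree)) [(:data tree)] (in-order (:right tree)))))

;; A left-leaning chain: 3 at the root, then 2, then 1.
(def chain {:data 3 :left {:data 2 :left {:data 1}}})

(def rotated (rotate-right chain))

;; rotated now has 2 at the root, with 1 on its left and 3 on its right.
;; The in-order sequence is unchanged, so it is still a valid search tree.
(in-order chain)
(in-order rotated)
```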

(defn balance
  "Return a balanced version of the AVL tree"
  [tree]
  (if tree
    (if (< (Math/abs (balance-factor tree)) 2)
      tree
      (cond
        ;; Right-heavy: a single left rotation, unless the right child leans left.
        ;; The right child can have balance factor 0 after a removal.
        (= (balance-factor tree) 2) (if (>= (balance-factor (tree :right)) 0)
                                      (lrotate tree)
                                      (lrotate (assoc tree :right (rrotate (tree :right)))))
        ;; Left-heavy: the mirror image
        (= (balance-factor tree) -2) (if (<= (balance-factor (tree :left)) 0)
                                       (rrotate tree)
                                       (rrotate (assoc tree :left (lrotate (tree :left)))))))))

This function balances an unbalanced tree, which can arise after a node is inserted into or removed from an AVL tree. Since the tree was balanced before the node was inserted or removed, the maximum difference in subtree heights is 2, and if the balance factor for a node is either 2 or -2 you must rebalance the tree by performing rotations based on the balance factor of the subtrees. If the balance factor of the current node is 1, 0, or -1, the node is a valid AVL tree and does not need to be rebalanced. However, nodes closer to the root may still be invalid AVL trees and so need to be rebalanced.
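The choice of rotation can be read off the balance factors alone. The sketch below follows the usual textbook rule rather than quoting the code above, and the names height, bf and rotation-case are invented for the example. It only reports which rotation would apply and performs none.

```clojure
(defn height [tree] (:height tree -1))

(defn bf
  "Balance factor: right height minus left height."
  [tree]
  (- (height (:right tree)) (height (:left tree))))

(defn rotation-case
  "Names the rotation a node needs. Returns nil for factors outside -2..2,
   which cannot occur after a single insert or remove on a valid AVL tree."
  [tree]
  (let [b (bf tree)]
    (cond
      (< -2 b 2) :none
      ;; Right-heavy: double rotation only if the right child leans left.
      (= b 2)    (if (>= (bf (:right tree)) 0) :left :right-then-left)
      ;; Left-heavy: double rotation only if the left child leans right.
      (= b -2)   (if (<= (bf (:left tree)) 0) :right :left-then-right))))

;; Right-right chain: one left rotation is enough.
(rotation-case {:height 2 :right {:height 1 :right {:height 0}}})

;; Right-left zigzag: rotate the child right, then the node left.
(rotation-case {:height 2 :right {:height 1 :left {:height 0}}})
```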

(defn predecessor [tree]
  (#(if % (if (% :right) (recur (% :right)) %)) (tree :left)))

This finds the predecessor of a node: the largest node in its left subtree, which holds the largest value smaller than the current node’s. If there is no left subtree it returns nil. This makes use of the # reader macro, which creates an anonymous function, and %, which stands for the first argument to one of these anonymous functions. It also uses the recur form, which re-enters the function with new arguments without consuming stack space, so the recursive call is effectively tail-call optimized. Here recur is mainly a convenient way to call the anonymous function recursively; even with ordinary recursion, the tree would have to be very deep for this to throw a StackOverflowError.
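For comparison, the same descent can be written with an explicit loop. This is a sketch with invented names (rightmost, deep-chain); it walks down a deliberately degenerate chain to show that recur does not grow the stack.

```clojure
(defn rightmost
  "Follows :right links until there are none left."
  [tree]
  (loop [node tree]
    (if (:right node)
      (recur (:right node))   ; jumps back to loop, no new stack frame
      node)))

;; A degenerate, right-leaning chain 100000 nodes deep.
;; The first node created (data 0) ends up at the very bottom.
(def deep-chain
  (reduce (fn [t i] {:data i :right t}) nil (range 100000)))

(:data (rightmost deep-chain))   ; => 0
```

A balanced tree never gets anywhere near this deep, which is why the choice between recur and plain recursion matters little for an AVL tree.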

(defn tree-lookup
  "Returns the data from the tree corresponding to val if it exists, otherwise nil"
  ([tree val < >]
     (if tree
       (cond
	 (< (tree :data) val) (tree-lookup (tree :right) val < >)
	 (> (tree :data) val) (tree-lookup (tree :left) val < >)
	 true (tree :data))
       nil))
  ([tree val] (tree-lookup tree val < >)))

This is the standard binary search tree lookup function, generalized so that custom comparators can be passed in. It works on any type of BST, assuming it is implemented as maps with :left and :right fields. I also like how Clojure and other Lisps allow you to use names such as < and > for ordinary parameters, making for much cleaner-looking code. This function can be called in two ways: with all of [tree val < >], where tree is the tree to look in, val is the value to look for, and < and > are the comparators, or with just [tree val], where < and > default to numeric < and >.
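Custom comparators matter as soon as the data is not numeric, since the built-in < and > only accept numbers. Here is a sketch of a lookup over strings; lookup, str<, str> and words are names invented for this example, and the tree is built by hand.

```clojure
(defn lookup
  "Same shape as tree-lookup, with the comparators named lt and gt."
  [tree val lt gt]
  (when tree
    (cond
      (lt (:data tree) val) (recur (:right tree) val lt gt)
      (gt (:data tree) val) (recur (:left tree) val lt gt)
      :else (:data tree))))

;; String orderings built on compare.
(defn str< [a b] (neg? (compare a b)))
(defn str> [a b] (pos? (compare a b)))

;; A tiny hand-built search tree of words.
(def words {:data "mango"
            :left {:data "apple"}
            :right {:data "pear"}})

(lookup words "pear" str< str>)   ; => "pear"
(lookup words "kiwi" str< str>)   ; => nil
```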

(defn avl-insert
  "Inserts a new node into an AVL tree"
  ([tree val < >]
     (if tree
       (balance
	(cond
	  (> (tree :data) val)
	  (let [left (avl-insert (tree :left) val < >)]
	    (assoc tree
	      :height (inc (max (get-height left) (get-height (tree :right))))
	      :left left))
	  (< (tree :data) val)
	  (let [right (avl-insert (tree :right) val < >)]
	    (assoc tree
	      :height (inc (max (get-height (tree :left)) (get-height right)))
	      :right right))
	  true (assoc tree :data val)))
       (struct avl-tree val 0)))
  ([tree val] (avl-insert tree val < >)))

This is the function for inserting a node into an AVL tree. It also allows you to specify comparators. I’d like to be able to specify them on the AVL tree itself somehow, so there would be less chance of mistakes from passing in different < and > functions, but I couldn’t figure out a way to do it without having to specify the comparators whenever you create a new node anyway, which still leaves the problem. If the tree you pass in is nil, it just creates a new avl-tree with that value and returns it; otherwise, it does a standard BST insert and calls balance on each node it visited on the way down. While balance doesn’t need to be called on every one, the unnecessary calls do almost no work and so probably aren’t much of a performance hit.
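One possible way around the comparator problem, offered only as a sketch: keep the comparator in a small wrapper map around the root, so it is stated once and never repeated per node. Everything here (empty-tree, tree-insert, bst-insert, in-order) is hypothetical, and a plain unbalanced insert stands in for avl-insert to keep the example short. It also uses update, which arrived in Clojure 1.7, well after this post.

```clojure
;; The comparator travels with the tree, not with each call.
(defn empty-tree [lt] {:lt lt :root nil})

;; Plain, unbalanced BST insert, standing in for avl-insert.
;; A single less-than function is enough: "greater" is lt with arguments swapped.
(defn bst-insert [node val lt]
  (cond
    (nil? node)           {:data val :left nil :right nil}
    (lt val (:data node)) (assoc node :left (bst-insert (:left node) val lt))
    (lt (:data node) val) (assoc node :right (bst-insert (:right node) val lt))
    :else                 (assoc node :data val)))

(defn tree-insert [tree val]
  (update tree :root bst-insert val (:lt tree)))

(defn in-order [node]
  (when node
    (concat (in-order (:left node)) [(:data node)] (in-order (:right node)))))

;; The same values under two orderings.
(in-order (:root (reduce tree-insert (empty-tree <) [5 3 8 1])))   ; => (1 3 5 8)
(in-order (:root (reduce tree-insert (empty-tree >) [5 3 8 1])))   ; => (8 5 3 1)
```

The wrapper has to be unwrapped at the top of every public function, which is the price of keeping the nodes themselves free of comparator fields.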

(defn avl-remove
  "Removes a node from an AVL tree"
  ([tree val < >]
     (if tree
       (balance
        (cond
          ;; Pass the comparators down so custom orderings survive the recursion
          (< (tree :data) val) (let [right (avl-remove (tree :right) val < >)]
                                 (assoc tree
                                   :right right
                                   :height (inc (max (get-height (tree :left))
                                                     (get-height right)))))
          (> (tree :data) val) (let [left (avl-remove (tree :left) val < >)]
                                 (assoc tree
                                   :left left
                                   :height (inc (max (get-height (tree :right))
                                                     (get-height left)))))
          ;; Found the node; a leaf (height 0) is simply dropped
          true (if (not (= (tree :height) 0))
                 (let [new-tree
                       (if (predecessor tree)
                         (let [left (avl-remove (tree :left) ((predecessor tree) :data) < >)]
                           (assoc tree
                             :data ((predecessor tree) :data)
                             :height (inc (max (get-height (tree :right))
                                               (get-height left)))
                             :left left))
                         (tree :right))]
                   (assoc new-tree
                     :height (inc (max (get-height (new-tree :left))
                                       (get-height (new-tree :right)))))))))))
  ([tree val] (avl-remove tree val < >)))

This is the function to call to remove a node from an AVL tree, and it is the most complicated function in the set. To remove a node, you descend through the tree until you find the node you wish to remove. If it is a leaf node, you can just remove it. If the left subtree is nil, the node is just replaced with its right subtree. Otherwise, the node is given the data of its predecessor and then the predecessor is removed from the left subtree. This is all done immutably, of course. Once this new node is created, the tree is balanced at each node on the path from where the node was finally removed back up to the root, ending with a balanced AVL tree.
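The three removal cases are easier to see with the height bookkeeping and rebalancing stripped away. The following is a sketch on a plain, unbalanced BST with numeric keys; rightmost, bst-remove, in-order and sample are names invented for the example.

```clojure
(defn rightmost [node]
  (if (:right node) (recur (:right node)) node))

(defn bst-remove [node val]
  (cond
    (nil? node)          nil
    (< val (:data node)) (assoc node :left (bst-remove (:left node) val))
    (> val (:data node)) (assoc node :right (bst-remove (:right node) val))
    ;; Found it. A leaf and a node without a left subtree are the same case
    ;; here: both are replaced by the (possibly nil) right subtree.
    (nil? (:left node))  (:right node)
    ;; Otherwise take the predecessor's data and remove it from the left.
    :else (let [pred (:data (rightmost (:left node)))]
            (assoc node
                   :data pred
                   :left (bst-remove (:left node) pred)))))

(defn in-order [node]
  (when node
    (concat (in-order (:left node)) [(:data node)] (in-order (:right node)))))

(def sample {:data 5
             :left {:data 3 :left {:data 1} :right {:data 4}}
             :right {:data 8}})

(in-order (bst-remove sample 1))   ; leaf removed            => (3 4 5 8)
(in-order (bst-remove sample 8))   ; no left subtree         => (1 3 4 5)
(in-order (bst-remove sample 5))   ; predecessor 4 moves up  => (1 3 4 8)
```

Because every step uses assoc on the existing maps, sample itself is never modified; each call returns a new tree that shares the untouched subtrees.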

(defn assert-correct-heights [tree]
  (if tree
    (do
      ;; Check both directions: the balance factor must be -1, 0, or 1
      (assert (< (Math/abs (balance-factor tree)) 2))
      (assert (= (tree :height) (inc (max
                                      (assert-correct-heights (tree :left))
                                      (assert-correct-heights (tree :right))))))
      (tree :height))
    -1))

This is a function I used to test an AVL tree. It will go through the tree, make sure that all the heights are correct and that the tree is balanced at every node. If this is the case, then the tree is a valid AVL tree and it returns the height. If it is not, an exception is thrown.

(defn test-all []
  (dotimes [_ 100]
    (let [rand (Random.)
          tree (ref (avl-insert nil 50))]
      ;; Insert 100 random values, mutating the ref inside a transaction
      (dotimes [_ 100]
        (dosync
         (alter tree avl-insert (.nextInt rand 1000))))
      (assert-correct-heights @tree)
      ;; Remove each candidate value from the full tree without mutating it
      (let [snapshot @tree]
        (dotimes [x 1000]
          (let [t (avl-remove snapshot x)]
            (assert-correct-heights t)
            (assert (not (tree-lookup t x))))))
      ;; Remove every value in turn, mutating the ref until the tree is nil
      (dotimes [x 1000]
        (dosync
         (alter tree avl-remove x)
         (assert (not (tree-lookup @tree x)))
         (assert-correct-heights @tree))))))

This test puts the AVL tree through its paces. It creates a reference to an AVL tree so I could conveniently insert items into it. The tree initially already has one value inserted because, unfortunately, (ref nil) is not equal to nil and so was giving null pointer exceptions on the first insert. The tree then has 100 random elements inserted into it, using dosync and alter to mutate tree. Once these insertions are done, the tree is checked to ensure it is a valid AVL tree. Each value from 0 to 999 is then removed from it without mutation (so each removal starts again from the full tree and does not affect the next one), which in a tree of this size will probably exercise every possible removal scenario. Each time a node is removed, the resulting tree is checked for validity. After this, the values are removed one at a time in ascending order, this time mutating the tree, leaving a nil tree at the end. At each intermediate step, the tree is checked for validity. This entire process is repeated 100 times. This test can take a while to finish, but once it’s done you can be pretty sure the tree works.