
Mutual recursion in Clojure – the trampoline

We’ve seen how recursion can be implemented efficiently (that is, without stack overflow), and also how the loop special form can be used to simulate a loop in an imperative language. These techniques work well if we’re dealing with simple recursion, in which a single function calls itself.

A more complex case is mutual recursion, in which two or more functions call each other recursively. Since recur always jumps back to the start of the function (or loop) it appears in, it can’t be used to call a different function, so it can’t help us avoid stack overflow in mutual recursion. Most simple examples of mutual recursion are somewhat contrived, and the one below is no exception, but it should serve to get the point across.

Suppose we are given a number and want to alternate between taking the square root and dividing by 2 until the result is less than 1. We can set up one function for taking the square root and another for dividing by 2, and use mutual recursion to perform the desired task:

(declare div2-recur)
(defn sqrt-recur [n]
  (do
    (println "sqrt-recur:" n)
    (if (< n 1)
      n
      (div2-recur (Math/sqrt n)))))

(defn div2-recur [n]
  (do
    (println "div2-recur:" n)
    (if (< n 1)
      n
      (sqrt-recur (/ n 2)))))

The first line, (declare div2-recur), is needed because sqrt-recur refers to div2-recur before div2-recur has been defined. Without the declaration, the compiler will complain that it is unable to resolve the symbol div2-recur.

The sqrt-recur function takes a number n as an argument, prints it out, then tests if it is less than 1. If so, n is returned and the function ends. If not, a recursive call is made to div2-recur, passing the square root of n as the argument.

div2-recur works the same way, except it passes n/2 back to sqrt-recur. We can test the functions at the REPL:

=> (sqrt-recur 5)
sqrt-recur: 5
div2-recur: 2.23606797749979
sqrt-recur: 1.118033988749895
div2-recur: 1.057371263440564
sqrt-recur: 0.528685631720282
0.528685631720282

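The number on the last line is the return value. The final call, to sqrt-recur, receives 0.528685631720282, finds that it is less than 1, and returns it; that value is then passed back up through every pending call.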

However, because the recursive calls are not optimized, every call adds a frame to the stack. In this particular example the values shrink so quickly that even a very large starting number needs only a few dozen calls, but a mutual recursion that runs for many steps will eventually throw a StackOverflowError. The solution to this problem is not to use recur, as we did with single recursion, but rather a function called trampoline. To use it, we need to make a slight modification to each function. The revised functions are:

(declare div2-recur2)
(defn sqrt-recur2 [n]
  (do
    (println "sqrt-recur2:" n)
    (if (< n 1)
      n
      #(div2-recur2 (Math/sqrt n)))))

(defn div2-recur2 [n]
  (do
    (println "div2-recur2:" n)
    (if (< n 1)
      n
      #(sqrt-recur2 (/ n 2)))))

Apart from renaming the functions (and the labels they print), the only change we made is the addition of a # in the last line of each function. #(...) is Clojure’s shorthand for an anonymous function, so what was a function call in the first version becomes an anonymous function of no arguments in the revised version. That is, instead of making a recursive call to div2-recur2, sqrt-recur2 now returns an anonymous function that will make that call when it is invoked. As it stands, neither function calls the other any more. You can verify this by calling either function directly at the REPL:

=> (sqrt-recur2 5)
sqrt-recur2: 5
#<FunctionTest$sqrt_recur2__1356$fn__1358 FunctionTest$sqrt_recur2__1356$fn__1358@5506d>

The println in sqrt-recur2 prints its argument, and then we get a cryptic line of output: the REPL’s printed representation of the anonymous function returned from the last line of sqrt-recur2.
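Each of these anonymous functions takes no arguments (such a function is often called a thunk) and performs one more step of the computation when it is called. As a sketch, we can drive the steps by hand, which is exactly the chore trampoline will automate. The block below redefines sqrt-recur2 and div2-recur2 without their println calls (an assumption made only so the snippet runs on its own and stays quiet), and step1 through result are illustrative names:

```clojure
;; Quiet copies of sqrt-recur2 and div2-recur2 (println calls dropped)
;; so this snippet can be evaluated on its own.
(declare div2-recur2)

(defn sqrt-recur2 [n]
  (if (< n 1) n #(div2-recur2 (Math/sqrt n))))

(defn div2-recur2 [n]
  (if (< n 1) n #(sqrt-recur2 (/ n 2))))

;; Each call returns either a final number or a zero-argument function.
(def step1 (sqrt-recur2 5))  ; thunk that will call div2-recur2 on (Math/sqrt 5)
(def step2 (step1))          ; thunk that will call sqrt-recur2 on 1.118...
(def step3 (step2))          ; thunk that will call div2-recur2 on 1.057...
(def step4 (step3))          ; thunk that will call sqrt-recur2 on 0.528...
(def result (step4))         ; 0.528... is below 1, so a plain number comes back

(fn? step1)   ;=> true
(fn? result)  ;=> false
```

Calling the thunks one after another like this never nests the calls, so the stack never grows; the only problem is that we have to know in advance how many times to call them.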

However, if we use the built-in Clojure function trampoline, the functionality is magically restored:

=> (trampoline sqrt-recur2 5)
sqrt-recur2: 5
div2-recur2: 2.23606797749979
sqrt-recur2: 1.118033988749895
div2-recur2: 1.057371263440564
sqrt-recur2: 0.528685631720282
0.528685631720282

So what exactly is trampoline doing? We can find out by looking at its code (taken from the clojure.core source, with the docstring and metadata omitted):

(defn trampoline
  ([f]
     (let [ret (f)]
       (if (fn? ret)
         (recur ret)
         ret)))
  ([f & args]
     (trampoline #(apply f args))))

trampoline has two arities: it can be called with a single argument, or with a function followed by any number of arguments. Since we invoked it with two arguments (sqrt-recur2 and 5), the second arity, ([f & args] ...), runs first, so we’ll look at that one first. It assumes its first argument is a function (sqrt-recur2 in our example above) and wraps the call in an anonymous function of no arguments, #(apply f args), which it then passes back to trampoline itself, thus causing the single-argument version of trampoline to be called.

The anonymous function uses apply to call f on the remaining arguments. In our example, the only remaining argument is 5, so calling it evaluates (sqrt-recur2 5). The single-argument version does exactly that: it calls the function it was given, with no arguments, and binds the result to ret. It then tests whether ret is itself a function. If it is, recur jumps back to the top of trampoline with ret as the new f, so that function gets called next; if it isn’t, ret is returned. In our example, calling sqrt-recur2 with an argument of 5 returns an anonymous function that will call div2-recur2 with the square root of 5. So trampoline recurs with that function, calls it, and gets back yet another anonymous function, this time wrapping a call to sqrt-recur2. Eventually, one of these calls returns just a number (once the value drops below 1). That number is not a function, so the looping in trampoline stops, and the number itself is returned.
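The two arities described above can be sketched as a hand-written loop. my-trampoline and countdown below are hypothetical names for illustration, not part of clojure.core; the logic mirrors the source shown above, but uses an explicit loop so each bounce is visible:

```clojure
;; my-trampoline: a hypothetical re-implementation of clojure.core/trampoline.
(defn my-trampoline
  ([f]
   (loop [ret (f)]             ; call the zero-argument function once
     (if (fn? ret)
       (recur (ret))           ; got another function: call it and go round again
       ret)))                  ; got a plain value: that is the answer
  ([f & args]
   ;; wrap f and its arguments in a zero-argument function,
   ;; then hand it to the single-argument arity
   (my-trampoline #(apply f args))))

;; countdown: a hypothetical helper that returns a thunk until n reaches 0.
(defn countdown [n]
  (if (zero? n)
    :done
    #(countdown (dec n))))

(my-trampoline countdown 100000)   ;=> :done
(trampoline countdown 100000)      ;=> :done, same result from clojure.core
```

Because the loop only ever holds one pending value, the stack depth stays constant however many bounces are needed.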

What has happened is that trampoline handles the mutual recursion by converting it into single recursion within the trampoline function, so that recur can be used to optimize the recursion and prevent a stack overflow. This is yet another example of the power of first-class functions (functions that can be passed to other functions as arguments and returned from them as values).
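The square-root example only ever runs for a few steps, so it doesn’t really stress the stack. A sketch of a more demanding case is the classic even/odd pair, where each function hands off to the other once per unit of n. The names my-even? and my-odd? are hypothetical, and the functions assume n is a non-negative integer:

```clojure
(declare my-odd?)

;; n is even if it is 0, or if n - 1 is odd.
(defn my-even? [n]
  (if (zero? n) true #(my-odd? (dec n))))

;; n is odd if it is not 0 and n - 1 is even.
(defn my-odd? [n]
  (if (zero? n) false #(my-even? (dec n))))

;; A million hand-offs between the two functions, run in constant stack space.
(trampoline my-even? 1000000)  ;=> true
(trampoline my-odd? 7)         ;=> true
```

Written with direct calls instead of #(...), the same pair would need about a million nested stack frames for the first call, which would typically throw a StackOverflowError on a default JVM stack.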

One final note. Forcing the user to call this program by explicitly typing out a call to trampoline isn’t very friendly, so it’s a good idea to provide a wrapper that hides what’s going on. That’s easy enough to do by defining an auxiliary function:

(defn sqrt-div2 [n]
  (trampoline sqrt-recur2 n))

Now we can run the code by typing just (sqrt-div2 5).
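If the two helper functions are never meant to be called directly, one possible variation (a sketch, not the approach taken above) is to define them locally with letfn, which lets local functions refer to each other. No declare is needed and only the wrapper is visible to callers. sqrt-div2-local, sqrt-step and div2-step are hypothetical names, and the println calls are dropped for brevity:

```clojure
(defn sqrt-div2-local [n]
  ;; letfn binds both local functions at once, so each can call the other.
  (letfn [(sqrt-step [x]
            (if (< x 1) x #(div2-step (Math/sqrt x))))
          (div2-step [x]
            (if (< x 1) x #(sqrt-step (/ x 2))))]
    (trampoline sqrt-step n)))

(sqrt-div2-local 5)    ; approximately 0.528685631720282
(sqrt-div2-local 0.5)  ;=> 0.5, already below 1 so no steps are taken
```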
