As I’ve been learning more about Scala and functional programming, I’ve been looking for ways to accomplish more tasks with recursive programming techniques. As part of my studies I put together a number of Scala recursion examples below, including:
- Sum
- Product
- Max
- Fibonacci
- Factorial
I won’t write too much about recursion theory today, just some basic thoughts. I’ll come back here and add more when I have some good thoughts or better examples to share.
Thinking in recursion
When I’m going to write a recursive method, I usually think about it like this:
- I know I want to do something with a collection of data elements.
- Therefore, my function will usually take this collection as an argument.
- Within the function I usually have two branches:
  - In the first case, when I’ve reached the end of the collection, I do some “ending” operation. For instance, in the Sum example below, when I reach Nil (the empty list at the end of a List), I return 0 and let the recursive method calls unroll.
  - In the second case, when the function is not at the end of the list, I write the code for my main algorithm. It operates on the current element in the collection (the ‘head’ element) and then recursively calls the function, passing it the remainder of the collection (the ‘tail’).
- As the function calls unroll, the function returns whatever it is that I’m calculating. For instance, the sum, product, and max functions that follow return an Int. The Fibonacci example prints its results as it goes along, so it doesn’t return a useful value (its return type is Unit).
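As a rough sketch of that two-branch shape, here’s a hypothetical `length` function. It isn’t one of the examples in this article; it just shows the “ending” case for Nil and the main case that handles the head and recurses on the tail:

```scala
object RecursionPattern extends App {

  // Hypothetical example: count the elements of a list.
  def length[A](xs: List[A]): Int = xs match {
    case Nil       => 0                  // "ending" case: the empty list
    case _ :: tail => 1 + length(tail)   // main case: count the head, recurse on the tail
  }

  println(length(List("a", "b", "c")))  // 3
}
```

This version isn’t tail-recursive, because the `1 + ...` addition happens after the recursive call returns.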
As another note, in some cases it helps to have an “accumulator” function inside your main function. I show this in the examples that follow, and I’ll describe it more at some point in the future.
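To make the accumulator idea a little more concrete, here’s a minimal sketch using a hypothetical `reverse` function. The outer function supplies the starting value (an empty list), and the inner function carries the partial result along on each call. Because the recursive call is the last thing the inner function does, it’s tail-recursive, and `@tailrec` asks the compiler to verify that:

```scala
import scala.annotation.tailrec

object AccumulatorPattern extends App {

  // Hypothetical example: reverse a list using an inner accumulator function.
  def reverse[A](xs: List[A]): List[A] = {
    @tailrec
    def loop(remaining: List[A], acc: List[A]): List[A] = remaining match {
      case Nil          => acc                     // done: the accumulator holds the result
      case head :: tail => loop(tail, head :: acc) // move the head onto the accumulator
    }
    loop(xs, Nil)
  }

  println(reverse(List(1, 2, 3)))  // List(3, 2, 1)
}
```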
A recursive “sum” function
The following code shows three ways to calculate the sum of a List[Int] recursively. I don’t think the first approach is practical: it’s simple, but because each recursive call adds a frame to the stack, it results in a StackOverflowError when the list is large.
The second approach shows how to fix the first approach by using a tail-recursive algorithm. This solution uses the “accumulator” I mentioned above.
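The third approach shows how to use an if/else construct instead of a match expression. It’s taken from the Stack Overflow page cited in the code comments. Like the first approach, it isn’t tail-recursive, so it has the same problem with large lists.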
With that introduction, here’s the code:
```scala
package recursion

import scala.annotation.tailrec

/**
 * Different ways to calculate the sum of a list using
 * recursive Scala methods.
 */
object Sum extends App {

  val list = List.range(1, 100)
  println(sum(list))
  println(sum2(list))
  println(sum3(list))

  // (1) yields a "java.lang.StackOverflowError" with large lists
  def sum(ints: List[Int]): Int = ints match {
    case Nil => 0
    case x :: tail => x + sum(tail)
  }

  // (2) tail-recursive solution
  def sum2(ints: List[Int]): Int = {
    @tailrec
    def sumAccumulator(ints: List[Int], accum: Int): Int = {
      ints match {
        case Nil => accum
        case x :: tail => sumAccumulator(tail, accum + x)
      }
    }
    sumAccumulator(ints, 0)
  }

  // (3) good descriptions of recursion here:
  // stackoverflow.com/questions/12496959/summing-values-in-a-list
  // this example is from that page (like (1), it is not tail-recursive):
  def sum3(xs: List[Int]): Int = {
    if (xs.isEmpty) 0
    else xs.head + sum3(xs.tail)
  }

}
```
I don’t want to stray too far from the point of this article, but while I’m talking about “sum” algorithms, another way to calculate the sum of a List[Int] in Scala is to use the reduceLeft method on the List:
```scala
def sumWithReduce(ints: List[Int]): Int = {
  ints.reduceLeft(_ + _)
}
```
(That’s all I’ll say about reduceLeft today.)
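One thing worth knowing about reduceLeft is that it has no starting value, so calling it on an empty list throws an exception. foldLeft takes an explicit starting value, so an empty list simply sums to 0. This sketch (the object and method names are just for illustration) shows the difference:

```scala
import scala.util.Try

object ReduceVsFold extends App {

  // reduceLeft has no starting value, so it throws on an empty list
  def sumWithReduce(ints: List[Int]): Int = ints.reduceLeft(_ + _)

  // foldLeft starts from 0, so an empty list sums to 0
  def sumWithFold(ints: List[Int]): Int = ints.foldLeft(0)(_ + _)

  println(sumWithFold(List(1, 2, 3)))         // 6
  println(sumWithFold(Nil))                   // 0
  println(Try(sumWithReduce(Nil)).isFailure)  // true
}
```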
Calculating the “product” of a List[Int] recursively
Calculating the product of a List[Int] is very similar to calculating the sum: you just multiply the values inside the function and return 1 in the Nil case. Therefore I’ll just show the following code without discussing it:
```scala
package recursion

import scala.annotation.tailrec

/**
 * Different ways to calculate the product of a List[Int] recursively.
 */
object Product extends App {

  val list = List(1, 2, 3, 4)
  println(product(list))
  println(product2(list))

  // (1) basic recursion; yields a "java.lang.StackOverflowError" with large lists
  def product(ints: List[Int]): Int = ints match {
    case Nil => 1
    case x :: tail => x * product(tail)
  }

  // (2) tail-recursive solution
  def product2(ints: List[Int]): Int = {
    @tailrec
    def productAccumulator(ints: List[Int], accum: Int): Int = {
      ints match {
        case Nil => accum
        case x :: tail => productAccumulator(tail, accum * x)
      }
    }
    productAccumulator(ints, 1)
  }

}
```
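Because the tail-recursive sum and product differ only in the starting value and the operator, you can pull the shared recursion into one helper. Here’s a minimal sketch of that idea; `foldInts` is a hypothetical name, and this is essentially a hand-written version of what foldLeft does:

```scala
import scala.annotation.tailrec

object FoldInts extends App {

  // Hypothetical helper: a tail-recursive left fold over a List[Int].
  // `start` is the result for Nil; `op` combines the accumulator with each element.
  def foldInts(ints: List[Int], start: Int)(op: (Int, Int) => Int): Int = {
    @tailrec
    def loop(remaining: List[Int], accum: Int): Int = remaining match {
      case Nil       => accum
      case x :: tail => loop(tail, op(accum, x))
    }
    loop(ints, start)
  }

  def sum(ints: List[Int]): Int = foldInts(ints, 0)(_ + _)
  def product(ints: List[Int]): Int = foldInts(ints, 1)(_ * _)

  println(sum(List(1, 2, 3, 4)))      // 10
  println(product(List(1, 2, 3, 4)))  // 24
}
```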
Calculating the “max” of a List[Int] recursively
Calculating the “max” of a List[Int] recursively is a little different than calculating the sum or product. In this algorithm you need to keep track of the highest value found as you go along, so I jump right into using an accumulator function inside the outer function.
I show two approaches in the source code below, the first using a match expression and the second using an if/else expression:
```scala
package recursion

import scala.annotation.tailrec

object Max extends App {

  val list = List.range(1, 100000)
  println(max(list))
  println(max2(list))

  // 1 - using `match`
  // The accumulator is seeded with the first element (not 0), so lists of
  // negative numbers work; an empty list has no maximum, so it throws.
  def max(ints: List[Int]): Int = {
    @tailrec
    def maxAccum(ints: List[Int], theMax: Int): Int = {
      ints match {
        case Nil => theMax
        case x :: tail =>
          val newMax = if (x > theMax) x else theMax
          maxAccum(tail, newMax)
      }
    }
    ints match {
      case Nil => throw new IllegalArgumentException("max of empty list")
      case x :: tail => maxAccum(tail, x)
    }
  }

  // 2 - using if/else
  def max2(ints: List[Int]): Int = {
    @tailrec
    def maxAccum2(ints: List[Int], theMax: Int): Int = {
      if (ints.isEmpty) {
        theMax
      } else {
        val newMax = if (ints.head > theMax) ints.head else theMax
        maxAccum2(ints.tail, newMax)
      }
    }
    require(ints.nonEmpty, "max of empty list")
    maxAccum2(ints.tail, ints.head)
  }

}
```
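An empty list has no maximum, so any max function has to decide what to do with Nil. One alternative, sketched below as a hypothetical `maxOption` method, makes that explicit in the return type with Option[Int]. It seeds the accumulator with the list’s first element, which also handles lists that contain only negative numbers:

```scala
import scala.annotation.tailrec

object SafeMax extends App {

  // Hypothetical variant: returns None for an empty list instead of inventing a maximum.
  def maxOption(ints: List[Int]): Option[Int] = {
    @tailrec
    def maxAccum(remaining: List[Int], theMax: Int): Int = remaining match {
      case Nil       => theMax
      case x :: tail => maxAccum(tail, if (x > theMax) x else theMax)
    }
    ints match {
      case Nil       => None
      case x :: tail => Some(maxAccum(tail, x))  // seed with the first element
    }
  }

  println(maxOption(List(3, 9, 2)))     // Some(9)
  println(maxOption(List(-5, -2, -7)))  // Some(-2)
  println(maxOption(Nil))               // None
}
```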
A Scala Fibonacci recursion example
The code below shows one way to calculate a Fibonacci sequence recursively using Scala:
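```scala
package recursion

import scala.annotation.tailrec

/**
 * Calculating a Fibonacci sequence recursively using Scala.
 */
object Fibonacci extends App {

  // prints 3, 5, 8, 13, ... and stops after the first value over 1,000,000
  fib(1, 2)

  @tailrec
  def fib(prevPrev: Int, prev: Int): Unit = {
    val next = prevPrev + prev
    println(next)
    // stop recursing here rather than calling System.exit
    if (next <= 1000000) fib(prev, next)
  }

}
```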
There are other ways to calculate a Fibonacci sequence, but this one works because the function receives the two previous values as arguments and prints each new value as it goes along, so it never needs to return a result.
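If you’d rather have a function that returns a value than one that prints, here’s a minimal sketch that carries the two previous values in accumulator parameters and returns the nth number. The name `FibNth.fib` is hypothetical, and it assumes the common convention fib(0) = 0 and fib(1) = 1. It uses BigInt because Fibonacci numbers quickly outgrow an Int:

```scala
import scala.annotation.tailrec

object FibNth extends App {

  // Returns the nth Fibonacci number, with fib(0) = 0 and fib(1) = 1.
  def fib(n: Int): BigInt = {
    require(n >= 0, "n must be non-negative")
    @tailrec
    def loop(i: Int, a: BigInt, b: BigInt): BigInt =
      if (i == 0) a else loop(i - 1, b, a + b)  // shift the pair (a, b) forward one step
    loop(n, BigInt(0), BigInt(1))
  }

  println((0 to 10).map(fib).mkString(", "))  // 0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55
}
```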
Recursive factorial algorithms
Finally, without much discussion, the following Scala code shows two different recursive factorial algorithms, the second of which is tail-recursive:
```scala
package recursion

import scala.annotation.tailrec

object Factorial extends App {

  println(factorial(5))
  println(factorial2(5))

  // 1 - basic recursive factorial method
  def factorial(n: Int): Int = {
    if (n == 0) 1
    else n * factorial(n - 1)
  }

  // 2 - tail-recursive factorial method
  def factorial2(n: Long): Long = {
    @tailrec
    def factorialAccumulator(acc: Long, n: Long): Long = {
      if (n == 0) acc
      else factorialAccumulator(n * acc, n - 1)
    }
    factorialAccumulator(1, n)
  }

}
```
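Both factorial methods above eventually overflow: an Int can hold 12! but not 13!, and a Long can hold 20! but not 21!. As a sketch of one way around that, here’s a tail-recursive variant that uses BigInt so the result can keep growing without wrapping around. The name `BigFactorial` is just for illustration:

```scala
import scala.annotation.tailrec

object BigFactorial extends App {

  // Tail-recursive factorial using BigInt, so large results don't overflow.
  def factorial(n: Int): BigInt = {
    require(n >= 0, "n must be non-negative")
    @tailrec
    def loop(acc: BigInt, k: Int): BigInt =
      if (k <= 1) acc else loop(acc * k, k - 1)
    loop(BigInt(1), n)
  }

  println(factorial(5))   // 120
  println(factorial(21))  // 51090942171709440000
}
```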