Worksheet: Algebraic Data Types and Pattern Matching
Table of Contents
- 1. Option Type
- 2. Tail Recursion
- 3. Lists
- 4. Classes in Scala
- 5. Algebraic Data Types via Case classes in Scala
- 6. Solutions
- 6.1. Solution: Maximum of a List
- 6.2. Solution: Find Element
- 6.3. Solution: Handle an Option - Pattern Matching
- 6.4. Solution: Handle an Option - flatMap
- 6.5. Solution: Scala - Looping in Scala with a Function
- 6.6. Solution: Scala - Counting Values
- 6.7. Solution: flatMap for Lists
- 6.8. Solution: Classes - Java Translation
- 6.9. Solution: MyList Case Class
- 6.10. Solution: Binary Tree Case Class
1. Option Type
1.1. Maximum of a List
Write a recursive function that returns the maximum of a list of integers.
Instead of throwing an exception when the list is empty, return an Option type.
That is, your function should have the following form.
def maximum (xs:List[Int]) : Option[Int] = { ... }
1.2. Find Element
Write a higher-order function that finds the first element in a list that satisfies a given predicate (function taking an element of the list and returning a Boolean value).
Instead of throwing an exception when no such element is found, return an Option type.
That is, your function should have the following form.
def find [X] (xs:List[X], p:X=>Boolean) : Option[X] = { ... }
1.3. Handle an Option - Pattern Matching
Consider the following Scala function.
It searches for the first String in a list of String references that is strictly longer than s and that contains s as a substring.
def oneFind (xs:List[String], s:String) : Option[String] = {
  xs.find ((x:String) => (x.length > s.length) && x.contains (s))
}
For example:
scala> oneFind (List ("the", "rain", "in", "spain"), "i")
res1: Option[String] = Some(rain)
scala> oneFind (List ("the", "rain", "in", "spain"), "in")
res2: Option[String] = Some(rain)
scala> oneFind (List ("the", "rain", "in", "spain"), "ain")
res3: Option[String] = Some(rain)
scala> oneFind (List ("the", "rain", "in", "spain"), "pain")
res16: Option[String] = Some(spain)
scala> oneFind (List ("the", "rain", "in", "spain"), "rain")
res4: Option[String] = None
Now write a function twoFind that calls oneFind and then uses the result (if there is one) in a second call to oneFind.
Use pattern-matching to do case analysis on the result of the first call to oneFind.
Your definition of twoFind should have the following form.
def twoFind (xs:List[String], s:String) : Option[String] = { ... }
Note that twoFind must not have return type Option[Option[String]]!
For example:
scala> twoFind (List ("the", "rain", "in", "spain"), "i")
res1: Option[String] = None
scala> twoFind (List ("the", "rain", "in", "spain"), "in")
res2: Option[String] = None
scala> twoFind (List ("the", "rain", "in", "spain"), "n")
res3: Option[String] = None
scala> twoFind (List ("the", "rain", "in", "spain", "trained", "after", "a", "sprain"), "in")
res4: Option[String] = Some(trained)
scala> twoFind (List ("the", "rain", "in", "spain", "trained", "after", "a", "sprain"), "i")
res5: Option[String] = Some(trained)
scala> twoFind (List ("the", "rain", "in", "spain", "trained", "after", "a", "sprain"), "n")
res6: Option[String] = Some(trained)
scala> twoFind (List ("the", "rain", "in", "spain", "trained", "after", "a", "sprain"), "ain")
res7: Option[String] = Some(trained)
scala> twoFind (List ("the", "rain", "in", "spain", "trained", "after", "a", "sprain"), "pain")
res8: Option[String] = None
1.4. Handle an Option - flatMap
Rewrite twoFind from the previous question to use flatMap for the Option type instead of explicit pattern-matching.
Note: if e has type Option[A] and you write e.flatMap (f), then it will have type Option[B] whenever f has type A=>Option[B].
You may find it helpful to look at the Scala API documentation for the Option type at https://www.scala-lang.org/api/current/scala/Option.html.
Variations of flatMap are idiomatic in many languages that use Option types.
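To make the typing rule in the note concrete, here is a minimal sketch on an unrelated problem. The helper names parseInt and reciprocal are hypothetical and are not part of the worksheet; toIntOption is the standard library method on String.

```scala
// Hypothetical helpers, used only to illustrate Option.flatMap.
def parseInt(s: String): Option[Int] = s.toIntOption

def reciprocal(n: Int): Option[Double] =
  if n == 0 then None else Some(1.0 / n)

// map applies the function inside the Option, so the two Options nest.
val nested: Option[Option[Double]] = parseInt("4").map(reciprocal)

// flatMap chains the two steps and keeps a single Option layer.
val flat: Option[Double] = parseInt("4").flatMap(reciprocal)    // Some(0.25)
val zero: Option[Double] = parseInt("0").flatMap(reciprocal)    // None: second step fails
val bad: Option[Double]  = parseInt("four").flatMap(reciprocal) // None: first step fails
```

Here A is Int and B is Double, so f has type Int=>Option[Double] and the result has type Option[Double].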
2. Tail Recursion
2.1. Scala - Looping in Scala with a Function
Write Scala code that uses a tail-recursive function to print the following:
1 potato
2 potato
3 potato
4 potato
You should define a function for this.
Include a tailrec annotation so that the Scala compiler tells you if your function is not tail recursive.
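As a reminder of how the annotation is used, here is a small sketch on a different problem. The function sumTo is a hypothetical example and is not the answer to this exercise.

```scala
import scala.annotation.tailrec

// Hypothetical example: sum the integers from 1 to n using an accumulator.
@tailrec
def sumTo(n: Int, acc: Int = 0): Int =
  if n <= 0 then acc
  else sumTo(n - 1, acc + n) // the recursive call is the last thing evaluated

// By contrast, the version below would be rejected if annotated with @tailrec,
// because the addition happens after the recursive call returns:
//
// def sumToNaive(n: Int): Int =
//   if n <= 0 then 0 else n + sumToNaive(n - 1)
```

For example, sumTo (4) evaluates to 10.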
2.2. Scala - Counting Values
Here is Java code to count the number of integers in a linked list that are between low and high values.
public class CountingValues {
  static class Node {
    int item;
    Node next;
    public Node (int item, Node next) {
      this.item = item;
      this.next = next;
    }
  }

  static int count (Node data, int low, int high) {
    int result = 0;
    while (data != null) {
      if (low <= data.item && data.item <= high) {
        result = result + 1;
      }
      data = data.next;
    }
    return result;
  }

  public static void main (String[] args) {
    Node data = null;
    for (int i = args.length - 1; i >= 0; i--) {
      data = new Node (Integer.parseInt (args[i]), data);
    }
    System.out.printf ("count (..., 5, 10) = %d%n", count (data, 5, 10));
  }
}
$ javac CountingValues.java
$ java CountingValues 1 2 3 4 5 6 7 8 9 10 11 12
count (..., 5, 10) = 6
Your task is to rewrite the count function in Scala. Your definition should be a tail-recursive function.
2.3. Scala - Which Functions Are Tail Recursive?
Consider each of the following Scala functions.
For each Scala function, is it tail recursive?
Test your answers by adding a tailrec annotation and pasting the function definition into the Scala console.
/* The Scala library includes many other functions on lists that are common.
 * Below we define our own versions of these functions for pedagogical purposes,
 * but normally the library versions would be used instead.
 */
def append [X] (xs:List[X], ys:List[X]) : List[X] =
  xs match
    case Nil => ys
    case z::zs => z :: append (zs, ys)

def flatten [X] (xss:List[List[X]]) : List[X] =
  xss match
    case Nil => Nil
    case ys::yss => ys ::: flatten (yss)

/* The "take" function takes the first n elements of a list. */
def take [X] (n:Int, xs:List[X]) : List[X] =
  if n <= 0 then
    Nil
  else
    xs match
      case Nil => Nil
      case y::ys => y :: take (n - 1, ys)

/* The "drop" function drops the first n elements of a list. */
def drop [X] (n:Int, xs:List[X]) : List[X] =
  if n <= 0 then
    xs
  else
    xs match
      case Nil => Nil
      case y::ys => drop (n - 1, ys)

val sampleList : List[Int] = (1 to 12).toList
val takeresult : List[Int] = take (3, sampleList)
val dropresult : List[Int] = drop (3, sampleList)

/* Reverse a list. This version is straightforward, but inefficient. Revisited later on. */
def reverse [X] (xs:List[X]) : List[X] =
  xs match
    case Nil => Nil
    case y::ys => reverse (ys) ::: List (y)

/* zip takes two lists and creates a single list containing pairs from the two lists
 * that occur at the same position. The definition shows more sophisticated pattern
 * matching: the use of wildcards; overlapping patterns; and decomposing pairs and
 * lists simultaneously. Note that the "xs" and "ys" in the third case shadow the
 * "xs" and "ys" parameters to the function.
 */
def zip [X,Y] (xs:List[X], ys:List[Y]) : List[(X,Y)] =
  (xs, ys) match
    case (Nil, _) => Nil
    case (_, Nil) => Nil
    case (x::xs, y::ys) => (x, y) :: zip (xs, ys)

val ziplist = zip (List (1, 2, 3), List ("the", "rain", "in", "spain"))

/* unzip decomposes a list of pairs into a pair of lists.
 * The recursive case illustrates pattern-matching the result of the
 * recursive call in order to apply an operation to the elements.
 */
def unzip [X,Y] (ps:List[(X,Y)]) : (List[X], List[Y]) =
  ps match
    case Nil => (Nil, Nil)
    case (x, y)::qs =>
      val (xs, ys) = unzip (qs)
      (x :: xs, y :: ys)

val unziplist = unzip (ziplist)

def reverse2 [X] (xs:List[X]) : List[X] =
  def aux (xs:List[X], result:List[X]) : List[X] =
    xs match
      case Nil => result
      case y::ys => aux (ys, y :: result)
  aux (xs, Nil)
3. Lists
3.1. flatMap for Lists
What is the result of evaluating the following Scala expressions?
def repeat [X] (x:X, n:Int) : List[X] = {
  if n == 0 then Nil else x :: repeat (x, n - 1)
}
val xs:List[(Char,Int)] = List (('a', 2), ('b', 4), ('c', 8))
val ys = xs.map ((p:(Char,Int)) => repeat (p._1, p._2))
val zs = xs.flatMap ((p:(Char,Int)) => repeat (p._1, p._2))
What types do ys and zs have, and how does flatMap differ from map?
Try to express flatMap using map and flatten.
4. Classes in Scala
4.1. Java Translation
Translate the following Java class definition into a corresponding Scala class definition.
You will need a companion object definition to model the static components.
In order to define the Scala class and object at the same time in the Scala REPL, you must enter :paste, paste your definitions together, and then press control-D.
class Counter {
  private static int numCounters = 0;
  final int id;
  int count;
  Counter (int initial) {
    this.id = numCounters;
    this.count = initial;
    numCounters++;
  }
  static int getNumCounters () { return numCounters; }
  int getId () { return this.id; }
  int getNextCount () { return this.count++; }
}
5. Algebraic Data Types via Case classes in Scala
5.1. MyList Case Class
Define your own MyList class in Scala (as in the slides).
Now implement the following functions:
enum MyList[+X] { ... }
import MyList._

def length [X] (xs:MyList[X]) : Int = { ... }
def toList [X] (xs:MyList[X]) : List[X] = { ... }
def fromList [X] (xs:List[X]) : MyList[X] = { ... }
def append [X] (xs:MyList[X], ys:MyList[X]) : MyList[X] = { ... }
def map [X,Y] (xs:MyList[X], f:X=>Y) : MyList[Y] = { ... }
5.2. Binary Tree Case Class
Define your own Tree class for binary trees in Scala (as in the slides).
Now implement the following functions:
enum Tree[+X] {
  case Leaf
  case Node(l:Tree[X], c:X, r:Tree[X])
}
import Tree.{Leaf,Node}

val tree1:Tree[Int] = Leaf
val tree2:Tree[Int] = Node (Leaf, 5, Leaf)
val tree3:Tree[Int] = Node (Node (Leaf, 2, Leaf), 3, Node (Leaf, 4, Leaf))

// Find the size of a binary tree. Leaves have size zero here.
def size [X] (t:Tree[X]) : Int = { ... }

// Insert a number into a sorted binary tree.
def insert [X] (x:X, t:Tree[X], lt:(X,X)=>Boolean) : Tree[X] = { ... }

// Put the elements of the tree into a list using an in-order traversal.
def inorder [X] (t:Tree[X]) : List[X] = { ... }
6. Solutions
6.1. Solution: Maximum of a List
def maximum (xs:List[Int]) : Option[Int] = {
  xs match {
    case Nil => None
    case y::ys => maximum (ys) match {
      case None => Some (y)
      case Some (z) => Some (y max z)
    }
  }
}
Suppose we start with maximum (List(11,31,21)). Computation goes as follows:
maximum (List(11,31,21))
=> maximum (List(31,21)) match [with y=11] { ... }
=> (maximum (List(21)) match [with y=31] {...}) match [with y=11] { ... }
=> ((maximum (List()) match [with y=21] {...}) match [with y=31] {...}) match [with y=11] { ... }
=> ((None match [with y=21] {...}) match [with y=31] {...}) match [with y=11] { ... }
=> (Some(21) match [with y=31] {...}) match [with y=11] { ... }
=> Some(31 max 21) match [with y=11] { ... }
=> Some(31) match [with y=11] { ... }
=> Some(11 max 31)
=> Some(31)
It is actually easier to understand if you start at Nil and run the function on progressively longer lists:
maximum(List())
=> None

maximum(List(21))
=> maximum(List()) match {...}
=> None match {...}
=> Some(21)

maximum(List(31,21))
=> maximum(List(21)) match {...}
=> Some(21) match {...}
=> Some(31 max 21)
=> Some(31)

maximum(List(11,31,21))
=> maximum(List(31,21)) match {...}
=> Some(31) match {...}
=> Some(11 max 31)
=> Some(31)
Recall that a match expression behaves like a chain of if-else tests: the cases are tried in order, from first to last, and the first one that matches is used. So the definition above is equivalent to:
def maximum (xs:List[Int]) : Option[Int] = {
  if (xs == Nil) { // case Nil
    None
  } else { // case y::ys
    val y = xs.head
    val ys = xs.tail
    val zoption = maximum (ys)
    if (zoption == None) { // case None
      Some (y)
    } else { // case Some(z)
      val z = zoption.get
      Some (y max z)
    }
  }
}
6.2. Solution: Find Element
def find [X] (xs:List[X], p:X=>Boolean) : Option[X] = {
  xs match {
    case Nil => None
    case y::ys if p (y) => Some (y)
    case _::ys => find (ys, p)
  }
}
This one is easier to trace because find is tail recursive:
def p = (x:Int) => (x==5)

find(List(11,5,21), p)
=> find(List(5,21), p)
=> Some(5)

find(List(11,21), p)
=> find(List(21), p)
=> find(List(), p)
=> None
6.3. Solution: Handle an Option - Pattern Matching
def oneFind (xs:List[String], s:String) : Option[String] = {
  xs.find ((x:String) => (x.length > s.length) && x.contains (s))
}

def twoFind (xs:List[String], s:String) : Option[String] = {
  oneFind (xs, s) match {
    case None => None
    case Some (t) => oneFind (xs, t)
  }
}
6.4. Solution: Handle an Option - flatMap
def oneFind (xs:List[String], s:String) : Option[String] = {
  xs.find ((x:String) => (x.length > s.length) && x.contains (s))
}

def twoFind (xs:List[String], s:String) : Option[String] = {
  oneFind (xs, s).flatMap ((t:String) => oneFind (xs, t))
}
6.5. Solution: Scala - Looping in Scala with a Function
import scala.annotation.tailrec

@tailrec
def foo (n:Int) : Unit =
  if n <= 4 then
    println ("%d potato".format (n))
    foo (n + 1)
scala> foo (1)
1 potato
2 potato
3 potato
4 potato
6.6. Solution: Scala - Counting Values
Here is one way to structure the program.
The tail-recursive aux function is nested inside count, which means it can only be called from inside count.
def count (xs:List[Int], low:Int, high:Int) : Int =
  def aux (ys:List[Int], result:Int) : Int =
    ys match
      case Nil => result
      case z::zs if low <= z && z <= high => aux (zs, result + 1)
      case _::zs => aux (zs, result)
  aux (xs, 0)
scala> count (List (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), 5, 10)
res1: Int = 6
6.7. Solution: flatMap for Lists
def repeat [X] (x:X, n:Int) : List[X] = {
  if n == 0 then Nil else x :: repeat (x, n - 1)
}
val xs:List[(Char,Int)] = List (('a', 2), ('b', 4), ('c', 8))
val ys = xs.map ((p:(Char,Int)) => repeat (p._1, p._2))
val zs = xs.flatMap ((p:(Char,Int)) => repeat (p._1, p._2))
val zs2 = xs.map ((p:(Char,Int)) => repeat (p._1, p._2)).flatten
(zs == zs2)
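The question also asks for the types of ys and zs. The sketch below repeats the definitions with explicit type annotations so that the difference between map and flatMap is visible in the types. The annotations and result comments are additions and do not appear in the original solution.

```scala
def repeat[X](x: X, n: Int): List[X] =
  if n == 0 then Nil else x :: repeat(x, n - 1)

val xs: List[(Char, Int)] = List(('a', 2), ('b', 4), ('c', 8))

// map keeps one inner list per input pair, so the result is a list of lists:
// List(List(a, a), List(b, b, b, b), List(c, c, c, c, c, c, c, c))
val ys: List[List[Char]] = xs.map(p => repeat(p._1, p._2))

// flatMap concatenates the inner lists into one flat list of 2 + 4 + 8 = 14 characters.
val zs: List[Char] = xs.flatMap(p => repeat(p._1, p._2))
```

So xs.flatMap (f) gives the same result as xs.map (f).flatten, which is what the comparison zs == zs2 checks.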
6.8. Solution: Classes - Java Translation
object Counter:
  private var numCounters = 0
  def getNumCounters : Int = numCounters
  def incNumCounters : Unit = numCounters = numCounters + 1

class Counter (initial:Int):
  private val id:Int = Counter.getNumCounters
  private var count:Int = initial
  Counter.incNumCounters
  def getId () : Int = id
  def getNextCount () : Int = { val tmp = count; count = count + 1; tmp }
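One possible variation, sketched here as an alternative and not as the worksheet's official answer: a class and its companion object can access each other's private members, so the increment does not need to be exposed through a public incNumCounters method. The values first and second are hypothetical names used only to exercise the class.

```scala
object Counter:
  private var numCounters = 0
  def getNumCounters: Int = numCounters

class Counter(initial: Int):
  // Reads and updates the companion's private field directly.
  val id: Int = Counter.numCounters
  private var count: Int = initial
  Counter.numCounters += 1

  def getId(): Int = id

  // Returns the current count, then increments it (like count++ in Java).
  def getNextCount(): Int =
    val tmp = count
    count += 1
    tmp

// Hypothetical usage: ids are handed out in creation order.
val first = new Counter(10)
val second = new Counter(20)
```

With these two instances, first.getId () is 0, second.getId () is 1, and Counter.getNumCounters is 2. As in the REPL exercise, the object and class must be defined together (for example with :paste) to be companions.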
6.9. Solution: MyList Case Class
enum MyList[+X]:
  case MyNil
  case MyCons(head:X, tail:MyList[X])
import MyList.*

def length [X] (xs:MyList[X]) : Int =
  xs match
    case MyNil => 0
    case MyCons (_, ys) => 1 + length (ys)

def toList [X] (xs:MyList[X]) : List[X] =
  xs match
    case MyNil => Nil
    case MyCons (y, ys) => y::toList(ys)

def fromList [X] (xs:List[X]) : MyList[X] =
  xs match
    case Nil => MyNil
    case y::ys => MyCons (y, fromList (ys))

def append [X] (xs:MyList[X], ys:MyList[X]) : MyList[X] =
  xs match
    case MyNil => ys
    case MyCons (z, zs) => MyCons (z, append (zs, ys))

def map [X,Y] (xs:MyList[X], f:X=>Y) : MyList[Y] =
  xs match
    case MyNil => MyNil
    case MyCons (y, ys) => MyCons (f (y), map (ys, f))

val xs = MyCons(11, MyCons(21, MyCons(31, MyNil)))
val len = length (xs)
6.10. Solution: Binary Tree Case Class
enum Tree[+X]:
  case Leaf
  case Node(l:Tree[X], c:X, r:Tree[X])
import Tree.*

val tree1:Tree[Int] = Leaf
val tree2:Tree[Int] = Node (Leaf, 5, Leaf)
val tree3:Tree[Int] = Node (Node (Leaf, 2, Leaf), 3, Node (Leaf, 4, Leaf))

// Find the size of a binary tree. Leaves have size zero here.
def size [X] (t:Tree[X]) : Int =
  t match
    case Leaf => 0
    case Node (t1, _, t2) => size (t1) + 1 + size (t2)

// Insert a number into a sorted binary tree.
def insert [X] (x:X, t:Tree[X], lt:(X,X)=>Boolean) : Tree[X] =
  t match
    case Leaf => Node (Leaf, x, Leaf)
    case Node (t1, c, t2) if (lt (x, c)) => Node (insert (x, t1, lt), c, t2)
    case Node (t1, c, t2) if (lt (c, x)) => Node (t1, c, insert (x, t2, lt))
    case Node (t1, c, t2) => Node (t1, c, t2)

// Put the elements of the tree into a list using an in-order traversal.
def inorder [X] (t:Tree[X]) : List[X] =
  t match
    case Leaf => Nil
    case Node (t1, c, t2) => inorder (t1) ::: List (c) ::: inorder (t2)
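A quick way to check insert and inorder together is to build a tree from an unsorted list and read it back in order. The sketch below restates compact versions of the two functions so that it runs on its own. The helper build and the sample data are hypothetical additions.

```scala
enum Tree[+X]:
  case Leaf
  case Node(l: Tree[X], c: X, r: Tree[X])
import Tree.*

def insert[X](x: X, t: Tree[X], lt: (X, X) => Boolean): Tree[X] = t match
  case Leaf => Node(Leaf, x, Leaf)
  case Node(l, c, r) if lt(x, c) => Node(insert(x, l, lt), c, r)
  case Node(l, c, r) if lt(c, x) => Node(l, c, insert(x, r, lt))
  case _ => t // x is already in the tree, so nothing changes

def inorder[X](t: Tree[X]): List[X] = t match
  case Leaf => Nil
  case Node(l, c, r) => inorder(l) ::: List(c) ::: inorder(r)

// Hypothetical helper: insert each element in turn, starting from an empty tree.
def build(xs: List[Int]): Tree[Int] =
  xs.foldLeft(Leaf: Tree[Int])((t, x) => insert(x, t, (a: Int, b: Int) => a < b))

// The duplicate 1 is dropped by insert, and the traversal comes out sorted.
val sorted: List[Int] = inorder(build(List(3, 1, 4, 1, 5, 9, 2, 6)))
```

Here sorted is List(1, 2, 3, 4, 5, 6, 9). Because insert ignores elements that are already present, the tree behaves like a sorted set.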