Sunday, January 25, 2009

Erlang style binary matching in Scala

Binary Patterns


Erlang programs typically read binary data into Binaries, which are broadly equivalent to byte arrays in C-family languages. These binary objects can then be matched by patterns in function definitions and case expressions using the bit syntax. A simple example, which matches two 3-bit values and one 10-bit value, is shown below:
case Header of <<Version:3, Type:3, Len:10>> -> ...

The bit syntax can also pack values into binaries and unpack them into variables. There are lots of options which control the size and type of the extracted variables: signed, little-endian etc.

By comparison, working with byte arrays in Java is much more cumbersome. You would typically wrap your byte array in a ByteArrayInputStream and a DataInputStream to read the built-in types, and to read unusual numbers of bits you need to write your own bit-fiddling code. If we had a BitStream class, the previous example could be written like this in Java:

if (in.remaining() == 16) {    // exactly 16 unread bits left
    int version = in.getBits(3);
    int type = in.getBits(3);
    int len = in.getBits(10);
    ...
}
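For comparison, here is roughly what that hand-written bit fiddling looks like using only java.io, written in Scala since that is the language of the rest of this post. It is a minimal sketch: the ManualHeader object and its parse method are made-up names, and it assumes the header arrives as exactly two big-endian bytes.

```scala
import java.io.{ByteArrayInputStream, DataInputStream}

// Hypothetical helper: decodes a 16-bit header laid out as Version:3, Type:3,
// Len:10 (most significant bits first) with java.io plus shifts and masks.
object ManualHeader {
  def parse(bytes: Array[Byte]): Option[(Int, Int, Int)] =
    if (bytes.length != 2) None
    else {
      val in = new DataInputStream(new ByteArrayInputStream(bytes))
      val word = in.readUnsignedShort()  // big-endian, 0 to 65535
      val version = (word >>> 13) & 0x7  // top 3 bits
      val kind = (word >>> 10) & 0x7     // next 3 bits ("type" is reserved in Scala)
      val len = word & 0x3ff             // low 10 bits
      Some((version, kind, len))
    }
}
```

Every new field layout means working out a fresh set of shift and mask constants by hand, which is exactly the tedium the bit syntax removes.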

The situation in Scala is pretty much the same, as it doesn't have much of an IO library yet.
Because Erlang's party trick (actor concurrency) can be implemented as a Scala library, I thought it would be interesting to see whether something approximating the bit syntax could be implemented in Scala too.

Scala Pattern Matching


Scala has quite rich pattern matching. Its three most common uses are matching constants/literals, types, and case classes. Matching literals or constants is pretty much the classic switch statement. Matching on types nicely replaces Java's instanceof-then-cast idiom:

x match {
  case s: String => println("a String of length " + s.length)
  case i: Int    => println("an Int: " + i)
  case _         => println("something else")
}

Case classes are used to model algebraic data types, which are common in functional languages:

sealed abstract class Tree
case object Empty extends Tree
case class LeafNode(value: Int) extends Tree
case class InternalNode(left: Tree, right: Tree, value: Int) extends Tree

// sum is recursive, so it needs an explicit result type.
def sum(t: Tree): Int = t match {
  case Empty => 0
  case LeafNode(v) => v
  case InternalNode(l, r, v) => v + sum(l) + sum(r)
}

Obviously none of these kinds of pattern matching will let us do bit-syntax-style matching. Luckily, Scala has another pattern matching mechanism: extractors. Extractors are objects with an unapply method which returns Some[A] if the value matches and None if it doesn't. A simple example (from Programming in Scala) is an email address matcher:

object EmailAddress {
  def unapply(s: String): Option[(String, String)] = {
    val parts = s split "@"
    // Build the tuple explicitly rather than relying on auto-tupling.
    if (parts.length == 2) Some((parts(0), parts(1))) else None
  }
}

"foo@bar.com" match {
case EmailAddress(user, domain) => println("Hello " + user)
case _ => println("Unmatched")
}
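Under the hood, a pattern like EmailAddress(user, domain) is essentially a call to EmailAddress.unapply on the value being matched, followed by binding the fields of the returned tuple. The sketch below makes that call explicit; ExtractorDemo and greet are illustrative names only, and the hand-written version approximates what the compiler generates rather than reproducing it exactly.

```scala
object EmailAddress {
  def unapply(s: String): Option[(String, String)] = {
    val parts = s.split("@")
    if (parts.length == 2) Some((parts(0), parts(1))) else None
  }
}

// Hypothetical demo: the same logic as `case EmailAddress(user, domain) => ...`,
// but calling unapply by hand and matching on the Option it returns.
object ExtractorDemo {
  def greet(s: String): String = EmailAddress.unapply(s) match {
    case Some((user, _)) => "Hello " + user
    case None            => "Unmatched"
  }
}
```

The key point for what follows is that unapply only ever sees the single value it is handed.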

What immediately comes to mind is something like:

case BitPattern(Bits(3), Bits(3), Bits(10)) =>

But of course this won't work. All an extractor knows about is the value passed to its unapply method; it knows nothing about the context in which it was called, or about any nested patterns.
So any chance of exactly copying the Erlang syntax is scuppered. But if we are willing to define the pattern in a different place from where it is used, we can still achieve something.

Underlying implementation


To start with I'll assume a BitStream class that wraps a byte array. I won't go into the details as it is not very exciting (and my bit fiddling expertise isn't all that). We then need some classes to represent the various data types we want to pull out.

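// A typed piece of a binary pattern.
abstract class BitPart[T] {
  def bitLength: Int         // bits consumed, or -1 for variable length
  def take(b: BitStream): T  // reads this part's value from the stream
}

// An arbitrary number of bits, returned as an Int.
class Bits(c: Int) extends BitPart[Int] {
  def bitLength = c
  def take(b: BitStream) = b.getBits(c)
}

// A 32-bit signed integer.
object SInt extends BitPart[Int] {
  def bitLength = 32
  def take(b: BitStream) = b.getBits(bitLength)
}

// Everything left in the stream, as raw bytes.
object Remaining extends BitPart[Array[Byte]] {
  def bitLength = -1
  def take(b: BitStream) = b.remainingBytes
}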

Remaining is a special case that matches however many bytes are left, which lets us match arbitrary-length byte arrays. Obviously we would need a bunch more types for this to be a useful API.
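Since the post treats BitStream as a given, here is one possible version, purely as an assumption for readers who want something runnable. It offers the three members used in this post (getBits, remaining and remainingBytes) and reads bits most-significant first, matching Erlang's default big-endian bit order. It reads one bit at a time, so it is simple rather than fast.

```scala
// Hypothetical BitStream (an assumption, not the author's class): reads fields
// most-significant-bit first from a byte array.
class BitStream(bytes: Array[Byte]) {
  private var pos = 0 // read position, in bits

  // Number of unread bits.
  def remaining: Int = bytes.length * 8 - pos

  // Reads the next n bits (0 to 32) into the low bits of an Int; the value is
  // unsigned for n < 32, while n == 32 gives the raw two's-complement Int.
  def getBits(n: Int): Int = {
    require(n >= 0 && n <= 32 && n <= remaining, s"cannot read $n bits")
    var result = 0
    for (_ <- 0 until n) {
      val bit = (bytes(pos >> 3) >> (7 - (pos & 7))) & 1
      result = (result << 1) | bit
      pos += 1
    }
    result
  }

  // Returns every unread byte; the stream must be on a byte boundary.
  def remainingBytes: Array[Byte] = {
    require(pos % 8 == 0, "not on a byte boundary")
    val rest = bytes.drop(pos / 8)
    pos = bytes.length * 8
    rest
  }
}
```

Returning the n == 32 case as a two's-complement Int is what lets an SInt part built on getBits(32) come out signed.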
So, on to matching with these classes. It is not too difficult to take a sequence of BitPart instances, apply them to a BitStream, and return None if they don't match or Some[Array[Any]] if they do.

def matchPattern(bytes: Array[Byte], parts : BitPart[_]* ) : Option[Array[Any]] = {
...
}

I'm omitting the implementation again, because it is a bit long and not overly interesting.
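For the curious, a possible matchPattern is sketched below. This is not the author's omitted code: BitStream is a compact stand-in, Matcher is an invented wrapper object, and the rule it enforces (the fixed-size parts must fill the input exactly, except that a single trailing Remaining may absorb whatever whole bytes are left) is an assumption about sensible behaviour.

```scala
// Hypothetical compact BitStream, as assumed elsewhere in this post.
class BitStream(bytes: Array[Byte]) {
  private var pos = 0 // read position, in bits
  def getBits(n: Int): Int = {
    var result = 0
    for (_ <- 0 until n) {
      result = (result << 1) | ((bytes(pos >> 3) >> (7 - (pos & 7))) & 1)
      pos += 1
    }
    result
  }
  def remainingBytes: Array[Byte] = {
    val rest = bytes.drop(pos / 8)
    pos = bytes.length * 8
    rest
  }
}

abstract class BitPart[T] {
  def bitLength: Int // -1 means variable length
  def take(b: BitStream): T
}
class Bits(c: Int) extends BitPart[Int] {
  def bitLength = c
  def take(b: BitStream) = b.getBits(c)
}
object Remaining extends BitPart[Array[Byte]] {
  def bitLength = -1
  def take(b: BitStream) = b.remainingBytes
}

object Matcher {
  // Checks the sizes first, then reads each part in order into an Array[Any].
  def matchPattern(bytes: Array[Byte], parts: BitPart[_]*): Option[Array[Any]] = {
    val total = bytes.length * 8
    val fixed = parts.filter(_.bitLength >= 0).map(_.bitLength).sum
    val variable = parts.count(_.bitLength < 0)
    val fits =
      if (variable == 0) fixed == total
      else variable == 1 && parts.last.bitLength < 0 && fixed <= total && fixed % 8 == 0
    if (!fits) None
    else {
      val in = new BitStream(bytes)
      Some(parts.map(p => p.take(in): Any).toArray)
    }
  }
}
```

Because the sizes are checked before anything is read, a failed match never leaves a half-consumed stream behind.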
It is possible to write extractors that return a variable number of items, but that is not really what we want here, as the pattern definition and its use could then disagree without the compiler noticing. We would also lose type safety, and so would need to put type annotations on all the pattern variables: annoying and error prone.
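To see what goes wrong with the variable-arity route, here is a hypothetical unapplySeq extractor, LooseHeader, hard-wired to the 3/3/10-bit header for brevity. Its results come back as Seq[Any], so every variable needs a type pattern, and a pattern with the wrong number of variables still compiles but never matches.

```scala
// Hypothetical variable-arity extractor: decodes 3-, 3- and 10-bit fields from
// a 2-byte array and returns them as an untyped Seq[Any].
object LooseHeader {
  def unapplySeq(bytes: Array[Byte]): Option[Seq[Any]] =
    if (bytes.length != 2) None
    else {
      val word = ((bytes(0) & 0xff) << 8) | (bytes(1) & 0xff)
      Some(Seq(word >>> 13, (word >>> 10) & 0x7, word & 0x3ff))
    }
}

object LooseDemo {
  // Every variable is typed Any, so each needs a runtime-checked type pattern.
  def describe(bytes: Array[Byte]): String = bytes match {
    case LooseHeader(version: Int, kind: Int, len: Int) => s"v$version type $kind len $len"
    case _ => "no match"
  }

  // Compiles fine, but can never match because the extractor yields 3 items.
  def wrongArity(bytes: Array[Byte]): Boolean = bytes match {
    case LooseHeader(_, _) => true
    case _ => false
  }
}
```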

The Extractors


We will wrap this matchPattern method up in a type safe extractor.

class BitMatch3[T1,T2,T3] (p1: BitPart[T1], p2: BitPart[T2], p3: BitPart[T3]) {
def unapply(bytes : Array[Byte]) : Option[(T1,T2,T3)] = {
for (a <- matchPattern(bytes, p1, p2, p3))
yield (a(0).asInstanceOf[T1], a(1).asInstanceOf[T2], a(2).asInstanceOf[T3])
}
}

Let's try this out on our original example:

val Header = new BitMatch3(new Bits(3), new Bits(3), new Bits(10))
bArray match {
  case Header(version, kind, len) => ...  // "type" is a reserved word in Scala
  case _ =>
}

A little clunky maybe, but getting there. As you might have guessed from the name BitMatch3, we need a different extractor class for each number of elements. This is the same as Scala's tuples, which are defined for up to 22 elements. It is a common limitation in statically typed languages, which could possibly be overcome with Lisp-style macros or C++ templates.
Along with all these classes, I created an object with a bunch of overloaded methods to hide them, so that the end user doesn't have to use the ugly numbered names.

object Patterns {
def bitPattern[T](p: BitPart[T]) = new BitMatch1(p)
def bitPattern[T1,T2](p1: BitPart[T1],p2: BitPart[T2]) = new BitMatch2(p1,p2)
def bitPattern[T1,T2,T3](p1: BitPart[T1],p2: BitPart[T2],p3: BitPart[T3]) = new BitMatch3(p1,p2,p3)
def bitPattern[T1,T2,T3,T4](p1: BitPart[T1],p2: BitPart[T2],p3: BitPart[T3],p4: BitPart[T4]) = new BitMatch4(p1,p2,p3,p4)
...
}

Using the BitPart constructors and objects directly is a bit ugly, so the usual Scala DSL tricks come into play:

def int = SInt
def byte = SByte
def remaining = Remaining

class PatternInt(x :Int) {
def bits = new Bits(x)
def bytes = new Bytes(x)
}

implicit def intExtras(x : Int) = new PatternInt(x)

This makes the simple example a lot nicer:

val Header = bitPattern(3 bits, 3 bits, 10 bits)
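Putting the pieces together, the following is a self-contained sketch of the extractor classes, the overloaded bitPattern method and the bits DSL, cut down to fixed-size Bits parts only. The class names come from this post, but the bodies, the compact BitStream and the HeaderDemo object are assumptions. It uses an implicit class and the 3.bits dot form; with import scala.language.postfixOps, the postfix 3 bits form used above should also work.

```scala
// Hypothetical compact BitStream: reads bits most-significant first.
class BitStream(bytes: Array[Byte]) {
  private var pos = 0
  def getBits(n: Int): Int = {
    var result = 0
    for (_ <- 0 until n) {
      result = (result << 1) | ((bytes(pos >> 3) >> (7 - (pos & 7))) & 1)
      pos += 1
    }
    result
  }
}

abstract class BitPart[T] {
  def bitLength: Int
  def take(b: BitStream): T
}

class Bits(c: Int) extends BitPart[Int] {
  def bitLength = c
  def take(b: BitStream) = b.getBits(c)
}

object Patterns {
  // Fixed-size parts only: the pattern must account for every bit of the input.
  def matchPattern(bytes: Array[Byte], parts: BitPart[_]*): Option[Array[Any]] =
    if (parts.map(_.bitLength).sum != bytes.length * 8) None
    else {
      val in = new BitStream(bytes)
      Some(parts.map(p => p.take(in): Any).toArray)
    }

  class BitMatch2[T1, T2](p1: BitPart[T1], p2: BitPart[T2]) {
    def unapply(bytes: Array[Byte]): Option[(T1, T2)] =
      for (a <- matchPattern(bytes, p1, p2))
        yield (a(0).asInstanceOf[T1], a(1).asInstanceOf[T2])
  }

  class BitMatch3[T1, T2, T3](p1: BitPart[T1], p2: BitPart[T2], p3: BitPart[T3]) {
    def unapply(bytes: Array[Byte]): Option[(T1, T2, T3)] =
      for (a <- matchPattern(bytes, p1, p2, p3))
        yield (a(0).asInstanceOf[T1], a(1).asInstanceOf[T2], a(2).asInstanceOf[T3])
  }

  // Overloads that hide the numbered class names.
  def bitPattern[T1, T2](p1: BitPart[T1], p2: BitPart[T2]): BitMatch2[T1, T2] =
    new BitMatch2(p1, p2)
  def bitPattern[T1, T2, T3](p1: BitPart[T1], p2: BitPart[T2], p3: BitPart[T3]): BitMatch3[T1, T2, T3] =
    new BitMatch3(p1, p2, p3)

  // Lets us write 3.bits instead of new Bits(3).
  implicit class PatternInt(x: Int) {
    def bits: BitPart[Int] = new Bits(x)
  }
}

object HeaderDemo {
  import Patterns._

  val Header = bitPattern(3.bits, 3.bits, 10.bits)
  val Nibbles = bitPattern(4.bits, 4.bits)

  def describe(bytes: Array[Byte]): String = bytes match {
    case Header(version, kind, len) => s"version=$version type=$kind len=$len"
    case _                          => "not a header"
  }

  def nibbles(b: Byte): Option[(Int, Int)] = Array(b) match {
    case Nibbles(hi, lo) => Some((hi, lo))
    case _               => None
  }
}
```

Because each BitMatchN carries its part types, the variables bound by Header(version, kind, len) are statically typed as Int, with no annotations needed.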

A more complicated example could include lots of types and lots of DSL magic:

val Packet = bitPattern(2 bits, 6 bits, byte, unsigned int, float, remaining)

If you don't like the long bitPattern name, you could rename it to Erlang's << and add a def >> = this method to the BitMatch classes for symmetry, which would give you:
val Header = <<(3 bits, 3 bits, 10 bits)>>

Let's look at a more complicated example from Programming Erlang, which parses an IPv4 datagram:

-define(IP_VERSION, 4).

DgramSize = size(Dgram),
case Dgram of
    <<?IP_VERSION:4, HLen:4, SrvcType:8, TotLen:16, ID:16, Flgs:3, FragOff:13,
      TTL:8, Proto:8, HdrChkSum:16, SrcIP:32, DstIP:32, RestDgram/binary>>
      when HLen >= 5, 4 * HLen =< DgramSize ->
        ...

In Scala:

val IpVersion = 4
val IPv4Dgram = bitPattern(4 bits, 4 bits, byte, short, short, 3 bits, 13 bits,
                           byte, byte, short, int, int, remaining)

datagram match {
  // IpVersion is capitalised, so it is matched as a constant rather than bound.
  case IPv4Dgram(IpVersion, hLen, srvcType, totLen, id, flags, fragOff,
                 ttl, proto, hdrChkSum, srcIP, destIP, restDgram)
      if hLen >= 5 && (4 * hLen) <= datagram.length => ...
}

Patterns can be nested, so in the case above we could, for example, write (assuming an IPAddress extractor):

ttl, proto, hdrChkSum, IPAddress(127,0,0,1), destIP, restDgram)

This would only match datagrams whose source address is 127.0.0.1, i.e. those from localhost.
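An IPAddress extractor is not defined in this post; one plausible version, assumed here, splits the 32-bit int read for the address field into its four octets. Because literals are valid patterns, IPAddress(127, 0, 0, 1) then matches only that address. IPAddress and LocalhostDemo are hypothetical names.

```scala
// Hypothetical extractor: splits a 32-bit IPv4 address (as read by an int
// part) into its four octets, most significant first.
object IPAddress {
  def unapply(ip: Int): Option[(Int, Int, Int, Int)] =
    Some(((ip >>> 24) & 0xff, (ip >>> 16) & 0xff, (ip >>> 8) & 0xff, ip & 0xff))
}

object LocalhostDemo {
  // Nested literal patterns: only 127.0.0.1 matches.
  def isLocalhost(srcIP: Int): Boolean = srcIP match {
    case IPAddress(127, 0, 0, 1) => true
    case _                       => false
  }
}
```

Using >>> rather than >> keeps the top octet correct for addresses such as 192.168.0.1, whose 32-bit value is negative as an Int.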

Wrap-up


Erlang's bit syntax is compiled, and so is supposedly quite fast, whereas this Scala implementation involves a fair bit of boxing and allocation. So I wouldn't recommend it for the core of your high-performance server, but for occasional network or file tasks it is a pretty neat solution.
This implementation works on byte arrays, because that is pretty much what Erlang does, and going by Programming Erlang it is common practice to read a whole file into one big binary. I wouldn't normally want to do that, so pattern matching that worked on a lazy-list (scala.Stream) wrapper around an InputStream would be nice to have. The tricky part is managing when earlier parts of the Stream are garbage collected, as you could end up holding on to the whole file anyway.
Extractors are a very cool part of the Scala language, and it is nice to see how easy it is to imitate features of other languages.