Chapter 3 Nested Expressions

The previous chapter would compile expressions like 3+4, 4*2, etc. That’s of course not very interesting, so let’s add nested expressions: (3+4)*2 and also 3+4*2.

We will see how to generate code using auxiliary memory to store temporary results, being clever about storing them in registers as opposed to main memory when possible. We will also see how to support proper operator precedence in the parser.

Themes: recursion, memory hierarchy.

3.1 Language Definition

To start off, the grammar becomes recursive, and so does our AST data type:

<exp> ::= <num>
       |  <exp> + <exp>
       |  <exp> - <exp>
       |  <exp> * <exp>
       |  <exp> / <exp>
       |  <exp> % <exp>

Abstract syntax as Scala data structure:

abstract class Exp
case class Lit(x: Int) extends Exp
case class Plus(x: Exp, y: Exp) extends Exp
case class Minus(x: Exp, y: Exp) extends Exp
case class Times(x: Exp, y: Exp) extends Exp
case class Div(x: Exp, y: Exp) extends Exp
case class Mod(x: Exp, y: Exp) extends Exp

3.2 Direct Interpreter and Compiler

The interpreter becomes recursive, too:

type Val = Int

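def eval(e: Exp): Val = e match {
  case Lit(x)     => x
  case Plus(x,y)  => eval(x) + eval(y)
  case Minus(x,y) => eval(x) - eval(y)
  case Times(x,y) => eval(x) * eval(y)
  case Div(x,y)   => eval(x) / eval(y)
  case Mod(x,y)   => eval(x) % eval(y)
}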

Example:

eval(Plus(Lit(3),Lit(4))) // => 7

And the direct compiler, too:

type Code = String

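def trans(e: Exp): Code = e match {
  case Lit(x)     => s"$x"
  case Plus(x,y)  => s"(${trans(x)} + ${trans(y)})"
  case Minus(x,y) => s"(${trans(x)} - ${trans(y)})"
  case Times(x,y) => s"(${trans(x)} * ${trans(y)})"
  case Div(x,y)   => s"(${trans(x)} / ${trans(y)})"
  case Mod(x,y)   => s"(${trans(x)} % ${trans(y)})"
}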

Examples:

trans(Plus(Lit(3),Lit(4)))              // => "(3 + 4)"
trans(Plus(Lit(1),Plus(Lit(2),Lit(3)))) // => "(1 + (2 + 3))"

Note the use of parentheses in the output to preserve evaluation order and grouping.

Note (again): compilers and interpreters are fundamentally linked (specialization, Futamura projections).

3.3 Compiling to Machine Code

Our first compiler is basically the identity transform. It does not lower the level of abstraction.

Machine language does not support nested expressions: typically, instructions only operate on atomic values. Hence we need to store intermediate results somewhere.

Key strategy: first define an interpreter. Lower its level of abstraction until it reaches the desired level. Think about what can be precomputed statically, i.e., nail down the phase distinction. Then map directly to the target, 1:1.

3.3.1 Architecture Refresher

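Machine language, e.g., on Intel Skylake. Assembly language is itself an abstraction: internally, the CPU decodes instructions into µops (micro-operations) and exploits instruction-level parallelism (ILP).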

3.3.1.1 Memory Hierarchy

Memory hierarchy: DRAM (and other storage technologies such as Flash or NVM), caches, registers. Fast memory is expensive, so there is less of it.

3.3.2 Notional Machine

Here’s our updated conceptual model of the machine. We’re using a single level of memory for now, but we could make the split between registers and main memory (DRAM) explicit, too.

val memory = new Array[Int](MEM_SIZE)
memory(0) = 4
memory(1) = 5
memory(2) = memory(0) + memory(1)

Some architectures (like x86) only allow two operands, one of which (the destination) is updated in place:

memory(i) += memory(j)

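In fact, x86 does not allow both operands to be in main memory: at most one operand of an instruction like addq can be a memory location, so the other has to be a register (or, as a source, an immediate constant).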

Other architectures, like RISC-V, are even more restrictive: arithmetic instructions can only read and write registers. These architectures provide separate instructions to load memory into registers and to store registers back to memory (load/store architecture). On the other hand, RISC-V supports the full three-address form:

r1 = r2 + r3
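Since a two-address instruction overwrites one of its operands, a three-address operation d = a op b has to be expressed as a copy followed by an in-place update. Below is a small sketch of that lowering in the notional memory(...) notation; `lower` is a hypothetical helper, and the case where the destination coincides with the right operand is only handled for commutative operators.

```scala
// Sketch: expressing d = a op b with two-address operations, in the notional
// memory(...) notation. `lower` is a hypothetical helper, not part of the notes.
def lower(d: Int, a: Int, op: Char, b: Int): List[String] =
  if (d == a) List(s"memory($d) ${op}= memory($b)")   // already in place
  else if (d != b) List(s"memory($d) = memory($a)",    // copy the left operand,
                        s"memory($d) ${op}= memory($b)") // then update in place
  else if (op == '+' || op == '*') List(s"memory($d) ${op}= memory($a)") // d == b: swap operands
  else sys.error(s"memory($d) = memory($a) $op memory($d) needs a temporary")

lower(2, 0, '+', 1).foreach(println)
```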

3.3.3 Stack-Based Interpreter and Compiler

To understand what we need to do, we refactor our interpreter so that it doesn’t return values directly, but instead follows a convention for where in memory the result values are stored. For this, the interpreter receives a target memory address sp as a parameter. For literals the implementation is clear:

var memory = new Array[Int](MEM_SIZE)

def eval(e: Exp, sp: Int): Unit = e match {
  case Lit(x)     => memory(sp) = x
  case Plus(x,y)  => eval(x,sp)
                     ???
  ...
}

But what should we do for nested expressions? The key observation is that we can follow a stack discipline: any memory above sp can freely be used as scratch space:

var memory = new Array[Int](MEM_SIZE)

def eval(e: Exp, sp: Int): Unit = e match {
  case Lit(x)     => memory(sp) = x
  case Plus(x,y)  => eval(x,sp); eval(y,sp+1)
                     memory(sp) += memory(sp+1)
  ...
}
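For completeness, here is the memory-passing interpreter with all five operators as a self-contained script (MEM_SIZE = 16 is an arbitrary choice for this sketch). Each operator evaluates its left operand into slot sp, its right operand into slot sp+1, and then combines them in place.

```scala
abstract class Exp
case class Lit(x: Int) extends Exp
case class Plus(x: Exp, y: Exp) extends Exp
case class Minus(x: Exp, y: Exp) extends Exp
case class Times(x: Exp, y: Exp) extends Exp
case class Div(x: Exp, y: Exp) extends Exp
case class Mod(x: Exp, y: Exp) extends Exp

val MEM_SIZE = 16 // arbitrary size for this sketch
val memory = new Array[Int](MEM_SIZE)

// Evaluate e and store its value in memory(sp); slots above sp are scratch space.
def eval(e: Exp, sp: Int): Unit = e match {
  case Lit(x)      => memory(sp) = x
  case Plus(x, y)  => eval(x, sp); eval(y, sp + 1); memory(sp) += memory(sp + 1)
  case Minus(x, y) => eval(x, sp); eval(y, sp + 1); memory(sp) -= memory(sp + 1)
  case Times(x, y) => eval(x, sp); eval(y, sp + 1); memory(sp) *= memory(sp + 1)
  case Div(x, y)   => eval(x, sp); eval(y, sp + 1); memory(sp) /= memory(sp + 1)
  case Mod(x, y)   => eval(x, sp); eval(y, sp + 1); memory(sp) %= memory(sp + 1)
}

eval(Times(Plus(Lit(3), Lit(4)), Lit(2)), 0)
println(memory(0)) // 14
```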

We can derive a stack-based compiler that follows exactly the same pattern:

def trans(e: Exp, sp: Int): Unit = e match {
  case Lit(x)     => println(s"memory($sp) = $x")
  case Plus(x,y)  => trans(x,sp); trans(y,sp+1)
                     println(s"memory($sp) += memory(${sp+1})")
  ...
}

Example:

trans(Plus(Lit(1),Plus(Lit(2),Lit(3))),0) // 1+(2+3)

Output:

memory(0) = 1
memory(1) = 2
memory(2) = 3
memory(1) += memory(2)
memory(0) += memory(1)
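To make the connection between the compiler and the interpreter concrete, we can emit instructions as data instead of strings, then both print them in the notation above and execute them. A sketch, restricted to + and - for brevity; `Instr`, `Load`, `BinOp`, `show`, and `exec` are hypothetical names. Running the compiled code should give the same result as interpreting the expression directly.

```scala
abstract class Exp
case class Lit(x: Int) extends Exp
case class Plus(x: Exp, y: Exp) extends Exp
case class Minus(x: Exp, y: Exp) extends Exp

sealed trait Instr
case class Load(dst: Int, x: Int) extends Instr              // memory(dst) = x
case class BinOp(op: Char, dst: Int, src: Int) extends Instr // memory(dst) op= memory(src)

// Same recursion pattern as trans above, but returning the instruction list.
def trans(e: Exp, sp: Int): List[Instr] = e match {
  case Lit(x)      => List(Load(sp, x))
  case Plus(x, y)  => trans(x, sp) ++ trans(y, sp + 1) :+ BinOp('+', sp, sp + 1)
  case Minus(x, y) => trans(x, sp) ++ trans(y, sp + 1) :+ BinOp('-', sp, sp + 1)
}

def show(i: Instr): String = i match {
  case Load(d, x)      => s"memory($d) = $x"
  case BinOp(op, d, s) => s"memory($d) ${op}= memory($s)"
}

// Run straight-line code on a fresh memory; with sp = 0 the result ends up in memory(0).
def exec(code: List[Instr], size: Int = 16): Int = {
  val memory = new Array[Int](size)
  code.foreach {
    case Load(d, x)       => memory(d) = x
    case BinOp('+', d, s) => memory(d) += memory(s)
    case BinOp('-', d, s) => memory(d) -= memory(s)
    case BinOp(op, _, _)  => sys.error(s"unsupported operator $op")
  }
  memory(0)
}

val code = trans(Plus(Lit(1), Plus(Lit(2), Lit(3))), 0)
code.map(show).foreach(println)
println(exec(code)) // 6
```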

3.3.4 Targeting x86

3.3.4.1 Registers Only

For x86, as discussed above, we have to be explicit about which kind of memory will be used. Let’s use CPU registers as the first idea:

val regs = Seq("%rbx", "%rcx", "%rdi", "%rsi", "%r8", "%r9")
def trans(e: Exp, sp: Int): Unit = e match {
  case Lit(x)     => println(s"movq $$$x, ${regs(sp)}")
  case Plus(x,y)  => trans(x,sp); trans(y,sp+1)
                     println(s"addq ${regs(sp+1)}, ${regs(sp)}")
  ...
}

trans(Plus(Lit(1),Plus(Lit(2),Lit(3))),0) // 1+(2+3)

movq $1, %rbx
movq $2, %rcx
movq $3, %rdi
addq %rdi, %rcx
addq %rcx, %rbx

Note: why not rax?

We used the fast memory (registers) instead of (slow) main memory. Now the problem is that we only have a finite set of registers, so at some point we can’t avoid using main memory.
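How many registers does the registers-only scheme actually need? Since trans(Lit(x), sp) touches slot sp, and trans(Plus(x,y), sp) evaluates x from slot sp and y from slot sp+1, we can count the slots recursively. A sketch for Plus only; `slots`, `rightNested`, and `leftNested` are hypothetical helpers. The translation works exactly when this count does not exceed the six registers.

```scala
abstract class Exp
case class Lit(x: Int) extends Exp
case class Plus(x: Exp, y: Exp) extends Exp

val regs = Seq("%rbx", "%rcx", "%rdi", "%rsi", "%r8", "%r9")

// Number of consecutive stack slots (starting at sp) that trans(e, sp) uses.
def slots(e: Exp): Int = e match {
  case Lit(_)     => 1
  case Plus(x, y) => math.max(slots(x), 1 + slots(y)) // y starts one slot higher
}

// 1+(2+(...+n)) and ((1+2)+...)+n
def rightNested(n: Int): Exp = (1 until n).foldRight(Lit(n): Exp)((i, acc) => Plus(Lit(i), acc))
def leftNested(n: Int): Exp = (2 to n).foldLeft(Lit(1): Exp)((acc, i) => Plus(acc, Lit(i)))

println(slots(rightNested(3)))  // 3: fits in regs
println(slots(rightNested(10))) // 10: more than regs.length, so regs(sp) would fail
println(slots(leftNested(10)))  // 2
```

Note how strongly the answer depends on the shape of the tree: left-nested sums need only two slots regardless of their length, while right-nested sums need one slot per operand.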

Note: introduce x86 stack

3.3.5 Stack-Based Evaluation with Register Allocation

3.3.5.1 Idea 1: Registers First, Memory Second

def mem(i: Int) = if (i < regs.length) regs(i) else s"stack(${i-regs.length})"

def trans(e: Exp, sp: Int): Unit = e match {
  case Lit(x)     => println(s"movq $$$x, ${mem(sp)}")
  case Plus(x,y)  => trans(x,sp); trans(y,sp+1)
                     println(s"addq ${mem(sp+1)}, ${mem(sp)}")
  ...
}

Problem: we still need an accumulator or some other scratch register, because most instructions cannot operate on two memory operands and need at least one register operand.
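One way to patch Idea 1 is to route the memory-to-memory case through a scratch register. A minimal sketch for Plus only, assuming %rax is kept free as the scratch register and keeping the stack(k) pseudo-operands for spilled slots (this is the notes’ notation, not real x86 syntax); `out`, `emitln`, and `isReg` are hypothetical helpers.

```scala
abstract class Exp
case class Lit(x: Int) extends Exp
case class Plus(x: Exp, y: Exp) extends Exp

val regs = Seq("%rbx", "%rcx", "%rdi", "%rsi", "%r8", "%r9")
val out = scala.collection.mutable.ListBuffer[String]()
def emitln(s: String): Unit = out += s

def isReg(i: Int) = i < regs.length
def mem(i: Int) = if (isReg(i)) regs(i) else s"stack(${i - regs.length})"

def trans(e: Exp, sp: Int): Unit = e match {
  // immediate to register or memory is allowed (for immediates that fit in 32 bits)
  case Lit(x) => emitln(s"movq $$$x, ${mem(sp)}")
  case Plus(x, y) =>
    trans(x, sp); trans(y, sp + 1)
    if (isReg(sp)) emitln(s"addq ${mem(sp + 1)}, ${mem(sp)}") // at most one memory operand
    else { // both operands in memory: go through the scratch register
      emitln(s"movq ${mem(sp + 1)}, %rax")
      emitln(s"addq %rax, ${mem(sp)}")
    }
}

val e = (1 until 8).foldRight(Lit(8): Exp)((i, acc) => Plus(Lit(i), acc)) // 1+(2+(...+8))
trans(e, 0)
out.foreach(println)
```

This fixes correctness, but the scheme still keeps the bottom of the stack in registers and the top in memory.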

3.3.5.2 Idea 2: Registers as Circular Buffer

We’re most likely to operate on the top of the stack, so when the stack grows, it would make much more sense to store the top of the stack in registers, rather than the bottom, as with the idea above. So the next idea is to use our n available registers as the top n stack slots. We do this by treating the register file as a circular buffer:

var sp = 0
def mem(i: Int) = regs(i % regs.length) // TODO: may need to access mem, too?
def grow() = {
  // slot sp shares a register with slot sp-n, which must be evicted first
  if (sp >= regs.length) emitln(s"push ${mem(sp)} # evict ${sp-regs.length}")
  sp += 1
}
def shrink() = {
  sp -= 1
  // slot sp is dead now: refill its register with slot sp-n
  if (sp >= regs.length) emitln(s"pop ${mem(sp)} # reload ${sp-regs.length}")
}
def trans(e: Exp): Unit = e match {
  case Lit(x)     => grow(); emitln(s"movq $$$x, ${mem(sp-1)} # ${sp-1}")
  // emit the operation before shrink(): reloading first would overwrite y's register
  case Plus(x,y)  => trans(x); trans(y); emitln(s"addq ${mem(sp-1)}, ${mem(sp-2)}"); shrink()
  case Minus(x,y) => trans(x); trans(y); emitln(s"subq ${mem(sp-1)}, ${mem(sp-2)}"); shrink()
  ...
}

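Problem: too much memory traffic once we’ve exceeded the available registers: every shrink immediately reloads a value, even if the stack grows again right away and that value has to be evicted again. Solution: add some slack. Don’t try to refill all registers immediately when the stack shrinks, but cache up to n elements.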
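To see the circular buffer in action, here is a runnable sketch of Idea 2 with three registers, compiling 1+(2+(3+4)). `out` and `emitln` are hypothetical stand-ins for the notes’ output helper. Note that the operation is emitted before shrink(): if shrink() reloaded first, the pop would overwrite the register that holds the right operand y.

```scala
abstract class Exp
case class Lit(x: Int) extends Exp
case class Plus(x: Exp, y: Exp) extends Exp

val regs = Seq("%rax", "%rdx", "%rcx") // three registers, as in the example further below
val out = scala.collection.mutable.ListBuffer[String]()
def emitln(s: String): Unit = out += s

var sp = 0 // number of stack slots in use
def mem(i: Int) = regs(i % regs.length)
def grow(): Unit = {
  // slot sp shares a register with slot sp - n: spill that one first
  if (sp >= regs.length) emitln(s"push ${mem(sp)} # evict ${sp - regs.length}")
  sp += 1
}
def shrink(): Unit = {
  sp -= 1
  // slot sp is dead now: eagerly refill its register with slot sp - n
  if (sp >= regs.length) emitln(s"pop ${mem(sp)} # reload ${sp - regs.length}")
}
def trans(e: Exp): Unit = e match {
  case Lit(x)     => grow(); emitln(s"movq $$$x, ${mem(sp - 1)} # ${sp - 1}")
  case Plus(x, y) => trans(x); trans(y); emitln(s"addq ${mem(sp - 1)}, ${mem(sp - 2)}"); shrink()
}

trans(Plus(Lit(1), Plus(Lit(2), Plus(Lit(3), Lit(4)))))
out.foreach(println)
```

The fourth literal evicts slot 0 from %rax, and the pop right after the first addq brings it back, so the final addq %rdx, %rax leaves 10 in %rax.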

3.3.5.3 Idea 3: Registers as Circular Buffer with Lazy Reloading

The idea is to treat the register file as a ring buffer that caches the top n elements of the stack, and to do so lazily. Tracking exactly n elements would cause too much traffic back and forth. So instead, we allow the buffer to become empty as the stack shrinks, which enables us to reload values lazily. The benefit is that there is no memory traffic as long as the stack height stays within the window given by the register set. (There is some flexibility in how to do this exactly. We chose to reload before each write, so that writes always go to registers.)

var sp = 0
var offset = 0

def mem(i: Int) =
  if (i >= offset) regs(i % regs.length)
  else s"stack($i)" // TODO

def grow() = {
  // about to write to sp
  if (sp == offset + regs.length) {
    emitln(s"push ${mem(offset)} # evict ${offset}")
    offset += 1
  }
  sp += 1
}

def shrink() = {
  sp -= 1
  // about to write to sp-1
  if (sp == offset) {
    offset -= 1
    emitln(s"pop ${mem(offset)} # reload ${offset}")
  }
}

def trans(e: Exp): Unit = e match {
  case Lit(x)     => grow(); emitln(s"movq $$$x, ${mem(sp-1)} # ${sp-1}");
  case Plus(x,y)  => trans(x); trans(y); shrink(); emitln(s"addq ${mem(sp)}, ${mem(sp-1)}")
  case Minus(x,y) => trans(x); trans(y); shrink(); emitln(s"subq ${mem(sp)}, ${mem(sp-1)}")
  ...
}

Result for (1+(2+(3+(4+(5+(6+(7+(8+(9+10))))))))) with only three available registers %rax, %rdx, %rcx:

movq $1, %rax     # 0
movq $2, %rdx     # 1
movq $3, %rcx     # 2
push %rax         # evict 0
movq $4, %rax     # 3
push %rdx         # evict 1
movq $5, %rdx     # 4
push %rcx         # evict 2
movq $6, %rcx     # 5
push %rax         # evict 3
movq $7, %rax     # 6
push %rdx         # evict 4
movq $8, %rdx     # 7
push %rcx         # evict 5
movq $9, %rcx     # 8
push %rax         # evict 6
movq $10, %rax    # 9
addq %rax, %rcx
addq %rcx, %rdx
pop %rax          # reload 6
addq %rdx, %rax
pop %rcx          # reload 5
addq %rax, %rcx
pop %rdx          # reload 4
addq %rcx, %rdx
pop %rax          # reload 3
addq %rdx, %rax
pop %rcx          # reload 2
addq %rax, %rcx
pop %rdx          # reload 1
addq %rcx, %rdx
pop %rax          # reload 0
addq %rdx, %rax
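We can double-check this listing mechanically. The sketch below is a self-contained version of the lazy ring buffer plus a tiny simulator for the emitted subset of x86 (movq with an immediate, addq, subq, push, pop). `out`, `emitln`, and `simulate` are hypothetical helpers, and the simulator only exists to test the generated code; it is not a real assembler.

```scala
abstract class Exp
case class Lit(x: Int) extends Exp
case class Plus(x: Exp, y: Exp) extends Exp
case class Minus(x: Exp, y: Exp) extends Exp

val regs = Seq("%rax", "%rdx", "%rcx")
val out = scala.collection.mutable.ListBuffer[String]()
def emitln(s: String): Unit = out += s

var sp = 0     // stack height
var offset = 0 // lowest slot held in a register; slots below it are on the x86 stack
def mem(i: Int) = regs(i % regs.length) // only called for slots >= offset

def grow(): Unit = {
  if (sp == offset + regs.length) { // all registers busy: evict the oldest slot
    emitln(s"push ${mem(offset)} # evict $offset")
    offset += 1
  }
  sp += 1
}
def shrink(): Unit = {
  sp -= 1
  if (sp == offset) { // left operand (slot sp-1) was evicted: reload it now
    offset -= 1
    emitln(s"pop ${mem(offset)} # reload $offset")
  }
}
def trans(e: Exp): Unit = e match {
  case Lit(x)      => grow(); emitln(s"movq $$$x, ${mem(sp - 1)} # ${sp - 1}")
  case Plus(x, y)  => trans(x); trans(y); shrink(); emitln(s"addq ${mem(sp)}, ${mem(sp - 1)}")
  case Minus(x, y) => trans(x); trans(y); shrink(); emitln(s"subq ${mem(sp)}, ${mem(sp - 1)}")
}

// Execute the emitted lines (comments after '#' are ignored); returns the final registers.
def simulate(lines: Seq[String]): Map[String, Long] = {
  val reg = scala.collection.mutable.Map[String, Long]()
  var stack = List.empty[Long]
  for (line <- lines) line.takeWhile(_ != '#').split("[ ,]+").filter(_.nonEmpty) match {
    case Array("movq", imm, r) => reg(r) = imm.drop(1).toLong // drop the '$'
    case Array("addq", s, d)   => reg(d) = reg(d) + reg(s)
    case Array("subq", s, d)   => reg(d) = reg(d) - reg(s)
    case Array("push", r)      => stack = reg(r) :: stack
    case Array("pop", r)       => reg(r) = stack.head; stack = stack.tail
    case other                 => sys.error(s"unknown instruction: $line")
  }
  reg.toMap
}

val e = (1 to 9).foldRight(Lit(10): Exp)((i, acc) => Plus(Lit(i), acc)) // 1+(2+(...+(9+10)))
trans(e)
out.foreach(println)
println(simulate(out.toList)("%rax")) // 55
```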

Todo: pick a better example that illustrates stack window, perhaps (1+(2+(3+(4+5)))) + (1+(2+(3+(4+5)))) or something like that

3.4 Parsing

Let’s recap the grammar we defined for the AST. Before we had

<exp> ::= <num> + <num>
       | ...

which mapped directly to parser logic, but now we have:

<exp> ::= <num>
       |  <exp> + <exp>
       |  <exp> - <exp>
       |  <exp> * <exp>
       |  <exp> / <exp>
       |  <exp> % <exp>

Key difficulty for parsing: this grammar is ambiguous. A string like 3+4*5 matches the grammar in more than one way, namely as Plus(3,Times(4,5)) or as Times(Plus(3,4),5), which evaluate to 23 and 35, respectively. The parser must decide which one to pick.

Base class for our parser – unchanged from last chapter.

def isDigit(c: Char) = '0' <= c && c <= '9'

def getNum(): Int = {
   if (in.hasNext(isDigit)) (in.next() - '0')
   else expected("number")
}

Todo: use version for multiple digits
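A multi-digit getNum can accumulate digits from left to right until the next character is no longer a digit. A sketch; the Reader class is a hypothetical stand-in for the previous chapter’s input reader (offering the same hasNext/peek/next operations), and the reader is passed as a parameter here only to keep the example self-contained. Overflow is not checked.

```scala
// Hypothetical minimal reader over a string.
class Reader(s: String) {
  private var pos = 0
  def hasNext(p: Char => Boolean): Boolean = pos < s.length && p(s(pos))
  def peek: Char = s(pos)
  def next(): Char = { val c = s(pos); pos += 1; c }
}

def isDigit(c: Char) = '0' <= c && c <= '9'
def expected(what: String): Nothing = throw new Exception(s"$what expected")

def getNum(in: Reader): Int = {
  if (!in.hasNext(isDigit)) expected("number")
  var n = 0
  while (in.hasNext(isDigit)) n = 10 * n + (in.next() - '0') // shift in one digit
  n
}

val r = new Reader("123+4")
println(getNum(r)) // 123
println(r.peek)    // +
```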

We add a convenience function to consume a given delimiter character:

def accept(c: Char) =
   if (in.hasNext(_ == c)) in.next()
   else expected(s"'$c'")

3.4.1 Parenthesized Expressions

First, parenthesized expressions like (3+4)+5.

We define a grammar useful for parsing this:

<term> ::= '(' <expr> ')' | <num>
<expr> ::= <term> '+' <term>
       |  ...

Recall parseTerm:

def parseTerm: Exp = Lit(getNum)

We change it to:

def parseTerm(): Exp = if (in.peek == '(') {
  accept('(')
  val res = parseExpression()
  accept(')')
  res
} else Lit(getNum())

Now we can parse and run:

run(trans(parse("(3+4)*5")))   // 35

Take-away: parser becomes recursive, just like the AST definition and the code generator.

3.4.2 Operator Sequences

What about sequences like 3+4+5? We want this to be left-associative, i.e., parse as (3+4)+5. Grammar:

<expr> ::= <expr> '+' <term> | <term>

If we wanted it to be right-associative, i.e.,

<expr> ::= <term> '+' <expr> | <term>

we could do this:

def expr(): Exp = {
   val x = term()
   if (in.hasNext(isOperator)) in.peek match {
      case '+' => in.next(); Plus(x,expr())
      case '-' => in.next(); Minus(x,expr())
      ...
   } else x
}

But we cannot put the recursion on the left, i.e.:

def expr(): Exp = {
   val x = expr()
   if (in.hasNext(isOperator)) in.peek match {
      case '+' => in.next(); Plus(x,term())
      case '-' => in.next(); Minus(x,term())
      ...
   } else x
}

Why? The first thing expr() does is call itself, without ever advancing the input position. A recursive call in exactly the same state cannot make progress, so the result is an infinite recursion (in practice, a stack overflow).

However, this is easy to fix. We turn the left recursion into iteration and assemble the parse tree from the left, one term at a time. On the grammar level, we change

<expr> ::= <expr> '+' <term> | <term>

into

<expr> ::= <term> ['+' <term>]*

and in the implementation, we create the corresponding AST nodes bottom up.

In code:

def expr(): Exp = {
   var res = term()
   while (in.hasNext(isOperator)) {
      res = in.peek match {
         case '+' => in.next(); Plus(res,term())
         case '-' => in.next(); Minus(res,term())
         //...
      }
   }
   res
}

We can now successfully parse expressions like “1+2+3” into

  Plus(Plus(Lit(1),Lit(2)),Lit(3))

or the equivalent of “(1+2)+3”.
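Associativity is not just cosmetic: for subtraction, the two tree shapes compute different values. The sketch below contrasts the iterative (left-associative) loop with the right-recursive version on "1-2-3". Assumptions: single-digit terms without error handling, and a hypothetical Reader class standing in for the chapter’s input reader.

```scala
abstract class Exp
case class Lit(x: Int) extends Exp
case class Plus(x: Exp, y: Exp) extends Exp
case class Minus(x: Exp, y: Exp) extends Exp

// Hypothetical minimal reader over a string.
class Reader(s: String) {
  private var pos = 0
  def hasNext(p: Char => Boolean): Boolean = pos < s.length && p(s(pos))
  def next(): Char = { val c = s(pos); pos += 1; c }
}

def isOperator(c: Char) = c == '+' || c == '-'
def term(in: Reader): Exp = Lit(in.next() - '0') // single digit only

// Left-associative: fold each new term into the tree built so far.
def exprLeft(in: Reader): Exp = {
  var res = term(in)
  while (in.hasNext(isOperator)) {
    res = in.next() match {
      case '+' => Plus(res, term(in))
      case '-' => Minus(res, term(in))
    }
  }
  res
}

// Right-associative: recursion on the right.
def exprRight(in: Reader): Exp = {
  val x = term(in)
  if (in.hasNext(isOperator)) in.next() match {
    case '+' => Plus(x, exprRight(in))
    case '-' => Minus(x, exprRight(in))
  } else x
}

def eval(e: Exp): Int = e match {
  case Lit(x)      => x
  case Plus(x, y)  => eval(x) + eval(y)
  case Minus(x, y) => eval(x) - eval(y)
}

println(eval(exprLeft(new Reader("1-2-3"))))  // -4, i.e. (1-2)-3
println(eval(exprRight(new Reader("1-2-3")))) // 2, i.e. 1-(2-3)
```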

But what about “1+2*3”?

With the current logic, this will parse as “(1+2)*3”, which is probably not what we want.

3.4.3 Operator Precedence

Now the problem is that 3+4*5 parses as (3+4)*5, but we would really like it to parse according to standard precedence rules, namely 3+(4*5).

To achieve this, we split the old ‘term’ into two layers: ‘factor’ (numbers and parenthesized expressions) and a new ‘term’. Now a term is a list of factors delimited by *, /, %, and an expr is a list of terms delimited by +, -.

Grammar with explicit operator precedence:

<addop>  ::= '+' | '-'
<mulop>  ::= '*' | '/' | '%'
<factor> ::= <num> | '(' <expr> ')'
<term>   ::= <factor> [<mulop> <factor>]*
<expr>   ::= <term> [<addop> <term>]*

Code:

def factor(): Exp = if (in.peek == '(') {
  accept('(')
  val res = expr()
  accept(')')
  res
} else Lit(getNum())

def term() = {
  var x = factor()
  while (in.hasNext(Set('*','/','%'))) {
    x = in.peek match {
      case '*' => in.next(); Times(x, factor())
      case '/' => in.next(); Div(x, factor())
      case '%' => in.next(); Mod(x, factor())
    }
  }
  x
}

def expr() = {
  var x = term()
  while (in.hasNext(Set('+','-'))) {
    x = in.peek match {
      case '+' => in.next(); Plus(x, term())
      case '-' => in.next(); Minus(x, term())
    }
  }
  x
}

Result:

assert(parse("(3+4)*5") == Times(Plus(Lit(3),Lit(4)),Lit(5)))
assert(runProgram(transProgram(parse("3+4*5"))) == 23)
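Putting the pieces together, here is a self-contained version of the three-layer parser plus the direct interpreter. Assumptions for this sketch: single-digit numbers, no whitespace, a hypothetical Reader class in place of the chapter’s input reader, and a `parse` helper that also checks that all input was consumed. Instead of in.peek, factor tests for '(' with hasNext, so that it does not fail at the end of the input.

```scala
abstract class Exp
case class Lit(x: Int) extends Exp
case class Plus(x: Exp, y: Exp) extends Exp
case class Minus(x: Exp, y: Exp) extends Exp
case class Times(x: Exp, y: Exp) extends Exp
case class Div(x: Exp, y: Exp) extends Exp
case class Mod(x: Exp, y: Exp) extends Exp

// Hypothetical minimal reader over a string.
class Reader(s: String) {
  private var pos = 0
  def hasNext(p: Char => Boolean): Boolean = pos < s.length && p(s(pos))
  def next(): Char = { val c = s(pos); pos += 1; c }
  def atEnd: Boolean = pos >= s.length
}

var in = new Reader("")
def expected(what: String): Nothing = sys.error(s"$what expected")
def accept(c: Char): Unit = if (in.hasNext(_ == c)) in.next() else expected(s"'$c'")
def getNum(): Int = if (in.hasNext(_.isDigit)) in.next() - '0' else expected("number")

def factor(): Exp = if (in.hasNext(_ == '(')) {
  accept('('); val res = expr(); accept(')'); res
} else Lit(getNum())

def term(): Exp = {
  var x = factor()
  while (in.hasNext(Set('*', '/', '%'))) {
    x = in.next() match {
      case '*' => Times(x, factor())
      case '/' => Div(x, factor())
      case '%' => Mod(x, factor())
    }
  }
  x
}

def expr(): Exp = {
  var x = term()
  while (in.hasNext(Set('+', '-'))) {
    x = in.next() match {
      case '+' => Plus(x, term())
      case '-' => Minus(x, term())
    }
  }
  x
}

def parse(s: String): Exp = {
  in = new Reader(s)
  val res = expr()
  if (!in.atEnd) expected("end of input")
  res
}

def eval(e: Exp): Int = e match {
  case Lit(x)      => x
  case Plus(x, y)  => eval(x) + eval(y)
  case Minus(x, y) => eval(x) - eval(y)
  case Times(x, y) => eval(x) * eval(y)
  case Div(x, y)   => eval(x) / eval(y)
  case Mod(x, y)   => eval(x) % eval(y)
}

println(parse("3+4*5"))       // Plus(Lit(3),Times(Lit(4),Lit(5)))
println(eval(parse("3+4*5"))) // 23
```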

3.5 Generalize

As a final simplification, we can collapse the AST nodes for all primitives into a single class that uses the operator as a name. We also provision for primitives with more than two arguments, or fewer, such as unary minus. Finally, op is a String instead of a Char, so that in the future it can also be a multi-character operator such as ==, !=, <=, >=, etc.

Language:

abstract class Exp
case class Lit(x: Int) extends Exp
case class Prim(op: String, xs: List[Exp]) extends Exp

Codegen:

def trans(e: Exp): Unit = e match {
  case Lit(x)               => grow(); emitln(s"movq $$$x, ${mem(sp-1)} # ${sp-1}");
  case Prim("+",List(x,y))  => trans(x); trans(y); shrink(); emitln(s"addq ${mem(sp)}, ${mem(sp-1)}")
  case Prim("-",List(x,y))  => trans(x); trans(y); shrink(); emitln(s"subq ${mem(sp)}, ${mem(sp-1)}")
  ...
}
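On the interpreter side, the generalized AST can be handled by evaluating the argument list and then dispatching on the operator name and arity. A sketch, including unary minus as an example of a primitive with a single argument; the error message format is an arbitrary choice.

```scala
abstract class Exp
case class Lit(x: Int) extends Exp
case class Prim(op: String, xs: List[Exp]) extends Exp

def eval(e: Exp): Int = e match {
  case Lit(x) => x
  case Prim(op, xs) => (op, xs.map(eval)) match { // evaluate arguments left to right
    case ("-", List(a))    => -a // unary minus
    case ("+", List(a, b)) => a + b
    case ("-", List(a, b)) => a - b
    case ("*", List(a, b)) => a * b
    case ("/", List(a, b)) => a / b
    case ("%", List(a, b)) => a % b
    case _                 => sys.error(s"unknown primitive $op/${xs.length}")
  }
}

println(eval(Prim("-", List(Prim("+", List(Lit(3), Lit(4))))))) // -7
```

In the code generator, unary minus would correspondingly map to a single instruction on the top stack slot, for example x86’s negq.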

Parser: generic operator precedence can be handled with precedence climbing, as follows.

We define a grammar that can be easily updated with new operators:

<op>   ::= '+' | '-' | '*' | '/' | '%'
<atom> ::= <num> | '(' <expr> ')'
<expr> ::= <atom> [<op> <atom>]*

Todo: include numeric precedence in grammar

In addition, we define a precedence level for each operator. From low to high:

  • ‘+’, ‘-’
  • ‘*’, ‘/’, ‘%’

We also need to define the associativity of each operator.

2 * 3 / 4 => (2 * 3) / 4 // left associative
x = y = z => x = (y = z) // right associative (in C)

Code:

def atom(): Exp = ??? // <num> | '(' expr() ')'

def prec(op: String) = op match { // higher binds tighter
  case "+" | "-"       => 1
  case "*" | "/" | "%" => 2
}
def assoc(op: String) = op match { // 1 left, 0 right
  case "+" | "-"       => 1
  case "*" | "/" | "%" => 1
}

def binop(min: Int): Exp = {
  var res = atom()
  while (in.hasNext(isOperator) && prec(in.peek.toString) >= min) {
    val op = in.peek.toString; in.next()
    val nextMin = prec(op) + assoc(op) // + 1 for left assoc
    res = Prim(op, List(res, binop(nextMin)))
  }
  res
}

def expr() = binop(0)

Note: examples
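A self-contained version of binop with a few worked examples. Assumptions for this sketch: single-digit numbers, no whitespace, no error checking for the closing ')', a hypothetical Reader class, and a hypothetical right-associative exponentiation operator ‘^’ (precedence 3, assoc 0) added only to exercise the associativity rule. The helper `show` parenthesizes every binary node so that the grouping is visible.

```scala
abstract class Exp
case class Lit(x: Int) extends Exp
case class Prim(op: String, xs: List[Exp]) extends Exp

// Hypothetical minimal reader over a string.
class Reader(s: String) {
  private var pos = 0
  def hasNext(p: Char => Boolean): Boolean = pos < s.length && p(s(pos))
  def peek: Char = s(pos)
  def next(): Char = { val c = s(pos); pos += 1; c }
}
var in = new Reader("")

def isOperator(c: Char) = "+-*/%^".contains(c)
def prec(op: String) = op match { // higher binds tighter
  case "+" | "-"       => 1
  case "*" | "/" | "%" => 2
  case "^"             => 3 // hypothetical exponentiation operator
}
def assoc(op: String) = op match { // 1 left, 0 right
  case "^" => 0
  case _   => 1
}

def atom(): Exp =
  if (in.hasNext(_ == '(')) { in.next(); val res = binop(0); in.next(); res } // skips ')'
  else Lit(in.next() - '0') // single digit

def binop(min: Int): Exp = {
  var res = atom()
  while (in.hasNext(isOperator) && prec(in.peek.toString) >= min) {
    val op = in.next().toString
    val nextMin = prec(op) + assoc(op) // left assoc: the right operand needs tighter ops
    res = Prim(op, List(res, binop(nextMin)))
  }
  res
}

def parse(s: String): Exp = { in = new Reader(s); binop(0) }

def show(e: Exp): String = e match {
  case Lit(x)               => x.toString
  case Prim(op, List(a, b)) => s"(${show(a)}${op}${show(b)})"
}

for (s <- List("1-2-3", "2+3*4", "2*3+4", "(1+2)*3", "2^3^2"))
  println(s"$s => ${show(parse(s))}")
// 1-2-3 => ((1-2)-3)
// 2+3*4 => (2+(3*4))
// 2*3+4 => ((2*3)+4)
// (1+2)*3 => ((1+2)*3)
// 2^3^2 => (2^(3^2))
```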

3.6 Pretty-Printing

Sometimes we want to print out an AST for inspection. We can use our direct compiler and improve the output based on our general operator precedence facilities.
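One possible approach, sketched under the assumption that we reuse the prec and assoc tables from the precedence-climbing parser: mirror binop’s thresholds when printing, and only add parentheses when a subexpression binds less tightly than its context requires. For a left-associative operator, the left child may have the same precedence without parentheses, but the right child may not; for right-associative operators it is the other way around. `pretty` is a hypothetical name.

```scala
abstract class Exp
case class Lit(x: Int) extends Exp
case class Prim(op: String, xs: List[Exp]) extends Exp

def prec(op: String) = op match { // higher binds tighter
  case "+" | "-"       => 1
  case "*" | "/" | "%" => 2
}
def assoc(op: String) = op match { // 1 left, 0 right
  case "+" | "-" | "*" | "/" | "%" => 1
}

// Print e; parenthesize only if e binds less tightly than the context requires (min).
def pretty(e: Exp, min: Int = 0): String = e match {
  case Lit(x) => x.toString
  case Prim(op, List(a, b)) =>
    val p = prec(op)
    val s = pretty(a, p + 1 - assoc(op)) + op + pretty(b, p + assoc(op))
    if (p < min) s"($s)" else s
}

println(pretty(Prim("*", List(Prim("+", List(Lit(1), Lit(2))), Lit(3))))) // (1+2)*3
println(pretty(Prim("-", List(Lit(1), Prim("-", List(Lit(2), Lit(3))))))) // 1-(2-3)
```

Because the thresholds match the parser’s, printing a tree and parsing the result again should yield the same tree.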

Todo

3.7 Where are we?

Wrap up. Outcomes.

Things missing from parser:

  • Handle whitespace
  • Keep track of source positions for errors