diff --git a/README.md b/README.md index 89fa047..a26d072 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -Golang implementation of a basic tree walk interpreter +Golang implementation of a basic tree walk interpreter. Find the binaries for windows, linux and macos under the releases section. @@ -8,4 +8,5 @@ https://craftinginterpreters.com/ - [x] Representing Code - [x] Parsing Expressions - [x] Evaluating Expressions -- [x] Statements and State \ No newline at end of file +- [x] Statements and State +- [x] Control Flow \ No newline at end of file diff --git a/pkg/ast/expr.go b/pkg/ast/expr.go index 93a7763..229cb2f 100644 --- a/pkg/ast/expr.go +++ b/pkg/ast/expr.go @@ -57,4 +57,14 @@ type AssignExpr struct { func (a *AssignExpr) Accept(v ExprVisitor) interface{} { return v.VisitAssignExpr(a) +} + +type LogicalExpr struct { + Left Expr + Operator *scanner.Token + Right Expr +} + +func (l *LogicalExpr) Accept(v ExprVisitor) interface{} { + return v.VisitLogicalExpr(l) } \ No newline at end of file diff --git a/pkg/ast/parser.go b/pkg/ast/parser.go index 8e1c62a..a122c79 100644 --- a/pkg/ast/parser.go +++ b/pkg/ast/parser.go @@ -60,6 +60,22 @@ func (p *parser) statement() Stmt { return p.printStatement() } + if p.match(scanner.IF) { + return p.ifStatement() + } + + if p.match(scanner.FOR) { + return p.forStatement() + } + + if p.match(scanner.WHILE) { + return p.whileStatement() + } + + if p.match(scanner.LEFT_BRACE) { + return &BlockStmt{p.block()} + } + return p.exprStatement() } @@ -73,6 +89,100 @@ func (p *parser) printStatement() Stmt { return &PrintStmt{expr} } +func (p *parser) ifStatement() Stmt { + if !p.match(scanner.LEFT_PAREN) { + panic(fault.NewFault(p.tokens[p.current].Line, "expected '(' after if")) + } + + condition := p.expression() + if !p.match(scanner.RIGHT_PAREN) { + panic(fault.NewFault(p.tokens[p.current].Line, "expected ')' after conditional expression")) + } + + thenBranch := p.statement() + var elseBranch Stmt + if p.match(scanner.ELSE) { + elseBranch = p.statement() + } + + return &IfStmt{condition, thenBranch, elseBranch} +} + +func (p *parser) forStatement() Stmt { + if !p.match(scanner.LEFT_PAREN) { + panic(fault.NewFault(p.tokens[p.current].Line, "expected '(' after for")) + } + + var initializer Stmt + if p.match(scanner.SEMICOLON) { + initializer = nil + } else if p.match(scanner.VAR) { + initializer = p.varDeclaration() + } else { + initializer = p.exprStatement() + } + + var condition Expr + if p.tokens[p.current].TokenType != scanner.EOF && p.tokens[p.current].TokenType != scanner.SEMICOLON { + condition = p.expression() + } + if !p.match(scanner.SEMICOLON) { + panic(fault.NewFault(p.tokens[p.current].Line, "expected ';' after conditional expression")) + } + + var increment Expr + if p.tokens[p.current].TokenType != scanner.EOF && p.tokens[p.current].TokenType != scanner.RIGHT_PAREN { + increment = p.expression() + } + if !p.match(scanner.RIGHT_PAREN) { + panic(fault.NewFault(p.tokens[p.current].Line, "expected ')' after for clause")) + } + + body := p.statement() + if increment != nil { + body = &BlockStmt{[]Stmt{body, &ExprStmt{increment}}} + } + + if condition == nil { + condition = &LiteralExpr{true} + } + + body = &WhileStmt{condition, body} + + if initializer != nil { + body = &BlockStmt{[]Stmt{initializer, body}} + } + + return body +} + +func (p *parser) whileStatement() Stmt { + if !p.match(scanner.LEFT_PAREN) { + panic(fault.NewFault(p.tokens[p.current].Line, "expected '(' after while")) + } + + condition := p.expression() + if !p.match(scanner.RIGHT_PAREN) { + panic(fault.NewFault(p.tokens[p.current].Line, "expected ')' after conditional expression")) + } + + return &WhileStmt{condition, p.statement()} +} + +func (p *parser) block() []Stmt { + stmts := []Stmt{} + + for p.tokens[p.current].TokenType != scanner.EOF && p.tokens[p.current].TokenType != scanner.RIGHT_BRACE { + stmts = append(stmts, p.declaration()) + } + + if !p.match(scanner.RIGHT_BRACE) { + panic(fault.NewFault(p.tokens[p.current].Line, "expected '}' after block")) + } + + return stmts +} + func (p *parser) exprStatement() Stmt { expr := p.expression() @@ -88,7 +198,7 @@ func (p *parser) expression() Expr { } func (p *parser) assignment() Expr { - expr := p.equality() + expr := p.or() if p.match(scanner.EQUAL) { equals := p.tokens[p.current-1] @@ -97,13 +207,37 @@ func (p *parser) assignment() Expr { if variable, ok := expr.(*VariableExpr); ok { return &AssignExpr{variable.Name, value} } - + fault.NewFault(equals.Line, "invalid assignment target") } return expr } +func (p *parser) or() Expr { + left := p.and() + + for p.match(scanner.OR) { + operator := p.tokens[p.current-1] + right := p.and() + left = &LogicalExpr{left, &operator, right} + } + + return left +} + +func (p *parser) and() Expr { + left := p.equality() + + for p.match(scanner.AND) { + operator := p.tokens[p.current-1] + right := p.equality() + left = &LogicalExpr{left, &operator, right} + } + + return left +} + func (p *parser) equality() Expr { left := p.comparison() diff --git a/pkg/ast/stmt.go b/pkg/ast/stmt.go index 820003e..34e7ad6 100644 --- a/pkg/ast/stmt.go +++ b/pkg/ast/stmt.go @@ -29,4 +29,31 @@ type VarStmt struct { func (v *VarStmt) Accept(v_ StmtVisitor) interface{} { return v_.VisitVarStmt(v) +} + +type BlockStmt struct { + Statements []Stmt +} + +func (b *BlockStmt) Accept(v StmtVisitor) interface{} { + return v.VisitBlockStmt(b) +} + +type IfStmt struct { + Condition Expr + ThenBranch Stmt + ElseBranch Stmt +} + +func (i *IfStmt) Accept(v StmtVisitor) interface{} { + return v.VisitIfStmt(i) +} + +type WhileStmt struct { + Condition Expr + Body Stmt +} + +func (w *WhileStmt) Accept(v StmtVisitor) interface{} { + return v.VisitWhileStmt(w) } \ No newline at end of file diff --git a/pkg/ast/visitor.go b/pkg/ast/visitor.go index c510053..fee47c2 100644 --- a/pkg/ast/visitor.go +++ b/pkg/ast/visitor.go @@ -7,10 +7,14 @@ type ExprVisitor interface { VisitUnaryExpr(u *UnaryExpr) interface{} VisitVariableExpr(v *VariableExpr) interface{} VisitAssignExpr(a *AssignExpr) interface{} + VisitLogicalExpr(l *LogicalExpr) interface{} } type StmtVisitor interface { VisitExprStmt(e *ExprStmt) interface{} VisitPrintStmt(p *PrintStmt) interface{} VisitVarStmt(p *VarStmt) interface{} + VisitBlockStmt(b *BlockStmt) interface{} + VisitIfStmt(i *IfStmt) interface{} + VisitWhileStmt(w *WhileStmt) interface{} } \ No newline at end of file diff --git a/pkg/interpreter/environment.go b/pkg/interpreter/environment.go index abaaedf..fd87639 100644 --- a/pkg/interpreter/environment.go +++ b/pkg/interpreter/environment.go @@ -8,7 +8,8 @@ import ( ) type environment struct { - values map[string]interface{} + enclosing *environment + values map[string]interface{} } func (e *environment) get(name *scanner.Token) interface{} { @@ -16,6 +17,10 @@ func (e *environment) get(name *scanner.Token) interface{} { return value } + if e.enclosing != nil { + return e.enclosing.get(name) + } + message := fmt.Sprintf("undefined variable %s", name.Lexeme) panic(fault.NewFault(name.Line, message)) } @@ -23,6 +28,8 @@ func (e *environment) get(name *scanner.Token) interface{} { func (e *environment) assign(name *scanner.Token, value interface{}) { if _, ok := e.values[name.Lexeme]; ok { e.values[name.Lexeme] = value + } else if e.enclosing != nil { + e.enclosing.assign(name, value) } else { message := fmt.Sprintf("undefined variable %s", name.Lexeme) panic(fault.NewFault(name.Line, message)) diff --git a/pkg/interpreter/interpreter.go b/pkg/interpreter/interpreter.go index 20e14cb..30b4755 100644 --- a/pkg/interpreter/interpreter.go +++ b/pkg/interpreter/interpreter.go @@ -15,7 +15,7 @@ type interpreter struct { } func NewInterpreter() *interpreter { - env := &environment{map[string]interface{}{}} + env := &environment{nil, map[string]interface{}{}} return &interpreter{env} } @@ -64,6 +64,40 @@ func (i *interpreter) VisitVarStmt(v *ast.VarStmt) interface{} { return nil } +func (i *interpreter) VisitBlockStmt(b *ast.BlockStmt) interface{} { + prev := i.env + + defer func() { + i.env = prev + }() + + i.env = &environment{prev, map[string]interface{}{}} + for _, stmt := range b.Statements { + stmt.Accept(i) + } + + return nil +} + +func (i *interpreter) VisitIfStmt(i_ *ast.IfStmt) interface{} { + value := i_.Condition.Accept(i) + if isTruthy(value) { + i_.ThenBranch.Accept(i) + } else if i_.ElseBranch != nil { + i_.ElseBranch.Accept(i) + } + + return nil +} + +func (i *interpreter) VisitWhileStmt(w *ast.WhileStmt) interface{} { + for isTruthy(w.Condition.Accept(i)) { + w.Body.Accept(i) + } + + return nil +} + func (i *interpreter) VisitBinaryExpr(b *ast.BinaryExpr) interface{} { left := b.Left.Accept(i) right := b.Right.Accept(i) @@ -156,6 +190,16 @@ func (i *interpreter) VisitAssignExpr(a *ast.AssignExpr) interface{} { return value } +func (i *interpreter) VisitLogicalExpr(l *ast.LogicalExpr) interface{} { + left := l.Left.Accept(i) + + if (l.Operator.TokenType == scanner.OR && isTruthy(left)) || !isTruthy(left) { + return left + } + + return l.Right.Accept(i) +} + func (i *interpreter) checkNumberOperands(operator *scanner.Token, left interface{}, right interface{}) (float64, float64) { if leftValue, leftOk := left.(float64); leftOk { if rightValue, rightOk := right.(float64); rightOk { @@ -164,4 +208,16 @@ func (i *interpreter) checkNumberOperands(operator *scanner.Token, left interfac } panic(fault.NewFault(operator.Line, "operands must be numbers")) +} + +func isTruthy(value interface{}) bool { + if value == nil { + return false + } + + if boolean, ok := value.(bool); ok { + return boolean + } + + return true } \ No newline at end of file