Skip to content

Commit

Permalink
macros! (chapter 5 content)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomas-young1 committed Jun 29, 2023
1 parent cf60c72 commit fa4b994
Show file tree
Hide file tree
Showing 14 changed files with 786 additions and 1 deletion.
25 changes: 25 additions & 0 deletions ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,28 @@ func (hl *HashLiteral) String() string {

return out.String()
}

type MacroLiteral struct {
Token token.Token // the 'macro' token
Parameters []*Identifier
Body *BlockStatement
}

func (ml *MacroLiteral) expressionNode() {}
func (ml *MacroLiteral) TokenLiteral() string { return ml.Token.Literal }
func (ml *MacroLiteral) String() string {
var out bytes.Buffer

params := []string{}
for _, p := range ml.Parameters {
params = append(params, p.String())
}

out.WriteString(ml.TokenLiteral())
out.WriteString("(")
out.WriteString(strings.Join(params, ", "))
out.WriteString(") ")
out.WriteString(ml.Body.String())

return out.String()
}
70 changes: 70 additions & 0 deletions ast/modify.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package ast

type ModifierFunc func(Node) Node

// TODO: error handling....... replace all underscores w/ proper error handling

func Modify(node Node, modifier ModifierFunc) Node {
switch node := node.(type) {

case *Program:
for i, statement := range node.Statements {
node.Statements[i], _ = Modify(statement, modifier).(Statement)
}

case *ExpressionStatement:
node.Expression, _ = Modify(node.Expression, modifier).(Expression)

case *InfixExpression:
node.Left, _ = Modify(node.Left, modifier).(Expression)
node.Right, _ = Modify(node.Right, modifier).(Expression)

case *PrefixExpression:
node.Right, _ = Modify(node.Right, modifier).(Expression)

case *IndexExpression:
node.Left, _ = Modify(node.Left, modifier).(Expression)
node.Index, _ = Modify(node.Index, modifier).(Expression)

case *IfExpression:
node.Condition, _ = Modify(node.Condition, modifier).(Expression)
node.Consequence, _ = Modify(node.Consequence, modifier).(*BlockStatement)
if node.Alternative != nil {
node.Alternative, _ = Modify(node.Alternative, modifier).(*BlockStatement)
}

case *BlockStatement:
for i := range node.Statements {
node.Statements[i], _ = Modify(node.Statements[i], modifier).(Statement)
}

case *ReturnStatement:
node.ReturnValue, _ = Modify(node.ReturnValue, modifier).(Expression)

case *LetStatement:
node.Value, _ = Modify(node.Value, modifier).(Expression)

case *FunctionLiteral:
for i := range node.Parameters {
node.Parameters[i], _ = Modify(node.Parameters[i], modifier).(*Identifier)
}
node.Body, _ = Modify(node.Body, modifier).(*BlockStatement)

case *ArrayLiteral:
for i := range node.Elements {
node.Elements[i], _ = Modify(node.Elements[i], modifier).(Expression)
}

case *HashLiteral:
newPairs := make(map[Expression]Expression)

for key, val := range node.Pairs {
newKey, _ := Modify(key, modifier).(Expression)
newVal, _ := Modify(val, modifier).(Expression)
newPairs[newKey] = newVal
}
node.Pairs = newPairs
}

return modifier(node)
}
151 changes: 151 additions & 0 deletions ast/modify_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package ast

import (
"reflect"
"testing"
)

func TestModify(t *testing.T) {
one := func() Expression { return &IntegerLiteral{Value: 1} }
two := func() Expression { return &IntegerLiteral{Value: 2} }

turnOneIntoTwo := func(node Node) Node {
integer, ok := node.(*IntegerLiteral)
if !ok {
return node
}

if integer.Value != 1 {
return node
}

integer.Value = 2
return integer
}

tests := []struct {
input Node
expected Node
}{
{
one(),
two(),
},
{
&Program{
Statements: []Statement{
&ExpressionStatement{Expression: one()},
},
},
&Program{
Statements: []Statement{
&ExpressionStatement{Expression: two()},
},
},
},
{
&InfixExpression{Left: one(), Operator: "+", Right: two()},
&InfixExpression{Left: two(), Operator: "+", Right: two()},
},
{
&InfixExpression{Left: two(), Operator: "+", Right: one()},
&InfixExpression{Left: two(), Operator: "+", Right: two()},
},
{
&PrefixExpression{Operator: "-", Right: one()},
&PrefixExpression{Operator: "-", Right: two()},
},
{
&IndexExpression{Left: one(), Index: one()},
&IndexExpression{Left: two(), Index: two()},
},
{
&IfExpression{
Condition: one(),
Consequence: &BlockStatement{
Statements: []Statement{
&ExpressionStatement{Expression: one()},
},
},
Alternative: &BlockStatement{
Statements: []Statement{
&ExpressionStatement{Expression: one()},
},
},
},
&IfExpression{
Condition: two(),
Consequence: &BlockStatement{
Statements: []Statement{
&ExpressionStatement{Expression: two()},
},
},
Alternative: &BlockStatement{
Statements: []Statement{
&ExpressionStatement{Expression: two()},
},
},
},
},
{
&ReturnStatement{ReturnValue: one()},
&ReturnStatement{ReturnValue: two()},
},
{
&LetStatement{Value: one()},
&LetStatement{Value: two()},
},
{
&FunctionLiteral{
Parameters: []*Identifier{},
Body: &BlockStatement{
Statements: []Statement{
&ExpressionStatement{Expression: one()},
},
},
},
&FunctionLiteral{
Parameters: []*Identifier{},
Body: &BlockStatement{
Statements: []Statement{
&ExpressionStatement{Expression: two()},
},
},
},
},
{
&ArrayLiteral{Elements: []Expression{one(), one()}},
&ArrayLiteral{Elements: []Expression{two(), two()}},
},
}

for _, tt := range tests {
modified := Modify(tt.input, turnOneIntoTwo)

equal := reflect.DeepEqual(modified, tt.expected)
if !equal {
t.Errorf("not equal. got=%#v, want=%#v", modified, tt.expected)
}
}

hashLiteral := &HashLiteral{
Pairs: map[Expression]Expression{
one(): one(),
two(): two(),
},
}

Modify(hashLiteral, turnOneIntoTwo)

for key, val := range hashLiteral.Pairs {
key, _ := key.(*IntegerLiteral)
if key.Value != 2 {
t.Errorf("value is not %d, got=%d", 2, key.Value)
}

val, _ := val.(*IntegerLiteral)
if val.Value != 2 {
t.Errorf("value is not %d, got=%d", 2, val.Value)
}
}
}
4 changes: 4 additions & 0 deletions evaluator/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ func Eval(node ast.Node, env *object.Environment) object.Object {
return &object.Function{Parameters: params, Body: body, Env: env}

case *ast.CallExpression:
if node.Function.TokenLiteral() == "quote" {
return quote(node.Arguments[0], env)
}

function := Eval(node.Function, env)
if isError(function) {
return function
Expand Down
119 changes: 119 additions & 0 deletions evaluator/macro_expansion.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package evaluator

import (
"monkey/ast"
"monkey/object"
)

// TODO: error handling and debugging (near end of 5th chap)

func DefineMacros(program *ast.Program, env *object.Environment) {
definitions := []int{}

for i, statement := range program.Statements {
if isMacroDefinition(statement) {
addMacro(statement, env)
definitions = append(definitions, i)
}
}

for i := len(definitions) - 1; i >= 0; i -= 1 {
definitionIndex := definitions[i]
program.Statements = append(
program.Statements[:definitionIndex],
program.Statements[definitionIndex+1:]...,
)
}
}

func isMacroDefinition(node ast.Statement) bool {
letStatement, ok := node.(*ast.LetStatement)
if !ok {
return false
}

_, ok = letStatement.Value.(*ast.MacroLiteral)
if !ok {
return ok
}

return true
}

func addMacro(stmt ast.Statement, env *object.Environment) {
letStatement, _ := stmt.(*ast.LetStatement)
macroliteral, _ := letStatement.Value.(*ast.MacroLiteral)

macro := &object.Macro{
Parameters: macroliteral.Parameters,
Env: env,
Body: macroliteral.Body,
}

env.Set(letStatement.Name.Value, macro)
}

func ExpandMacros(program *ast.Program, env *object.Environment) ast.Node {
return ast.Modify(program, func(node ast.Node) ast.Node {
callExpression, ok := node.(*ast.CallExpression)
if !ok {
return node
}

macro, ok := isMacroCall(callExpression, env)
if !ok {
return node
}

args := quoteArgs(callExpression)
evalEnv := extendMacroEnv(macro, args)

evaluated := Eval(macro.Body, evalEnv)

quote, ok := evaluated.(*object.Quote)
if !ok {
panic("we only support returning AST-nodes from macros")
}

return quote.Node
})
}

func isMacroCall(exp *ast.CallExpression, env *object.Environment) (*object.Macro, bool) {
identifier, ok := exp.Function.(*ast.Identifier)
if !ok {
return nil, false
}

obj, ok := env.Get(identifier.Value)
if !ok {
return nil, false
}

macro, ok := obj.(*object.Macro)
if !ok {
return nil, false
}

return macro, true
}

func quoteArgs(exp *ast.CallExpression) []*object.Quote {
args := []*object.Quote{}

for _, a := range exp.Arguments {
args = append(args, &object.Quote{Node: a})
}

return args
}

func extendMacroEnv(macro *object.Macro, args []*object.Quote) *object.Environment {
extended := object.NewEnclosedEnvironment(macro.Env)

for paramIdx, param := range macro.Parameters {
extended.Set(param.Value, args[paramIdx])
}

return extended
}
Loading

0 comments on commit fa4b994

Please sign in to comment.