Skip to content

Commit

Permalink
Separate run options from model (#421)
Browse files Browse the repository at this point in the history
Instead of storing them directly on the model, pass them as a param.

The model now stores the defaults which come from the amod file. These may overridden by the command line or the web API.
  • Loading branch information
asmaloney authored Mar 1, 2024
1 parent ddcc20b commit 90dd90f
Show file tree
Hide file tree
Showing 12 changed files with 81 additions and 76 deletions.
12 changes: 7 additions & 5 deletions actr/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ type Model struct {

Productions []*Production

runoptions.Options
// These defaults come from the amod file and may be overridden on the command line
// or by web requests.
DefaultParams runoptions.Options

// Used to validate our parameters
parameters param.ParametersInterface
Expand Down Expand Up @@ -80,7 +82,7 @@ func (model *Model) Initialize() {
model.Procedural = modules.NewProcedural()
model.Modules = append(model.Modules, model.Procedural)

model.LogLevel = "info"
model.DefaultParams = runoptions.New()

// Declare our parameters
loggingParam := param.NewStr(
Expand Down Expand Up @@ -293,16 +295,16 @@ func (model *Model) SetParam(kv *keyvalue.KeyValue) (err error) {

switch kv.Key {
case "log_level":
model.LogLevel = runoptions.ACTRLogLevel(*value.Str)
model.DefaultParams.LogLevel = runoptions.ACTRLogLevel(*value.Str)

case "trace_activations":
boolVal, _ := value.AsBool() // already validated
model.TraceActivations = boolVal
model.DefaultParams.TraceActivations = boolVal

case "random_seed":
seed := uint32(*value.Number)

model.RandomSeed = &seed
model.DefaultParams.RandomSeed = &seed

default:
return param.ErrUnrecognizedOption{Option: kv.Key}
Expand Down
37 changes: 19 additions & 18 deletions framework/ccm_pyactr/ccm_pyactr.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/asmaloney/gactar/util/filesystem"
"github.com/asmaloney/gactar/util/issues"
"github.com/asmaloney/gactar/util/numbers"
"github.com/asmaloney/gactar/util/runoptions"
)

//go:embed ccm_print.py
Expand Down Expand Up @@ -115,8 +116,8 @@ func (c CCMPyACTR) Model() (model *actr.Model) {

// Run generates the python code from the amod file, writes it to disk, creates a "run" file
// to actually run the model, and returns the output (stdout and stderr combined).
func (c *CCMPyACTR) Run(initialBuffers framework.InitialBuffers) (result *framework.RunResult, err error) {
runFile, err := c.WriteModel(c.tmpPath, initialBuffers)
func (c *CCMPyACTR) Run(options *runoptions.Options, initialBuffers framework.InitialBuffers) (result *framework.RunResult, err error) {
runFile, err := c.WriteModel(c.tmpPath, options, initialBuffers)
if err != nil {
return
}
Expand All @@ -137,7 +138,7 @@ func (c *CCMPyACTR) Run(initialBuffers framework.InitialBuffers) (result *framew
}

// WriteModel converts the internal actr.Model to Python and writes it to a file.
func (c *CCMPyACTR) WriteModel(path string, initialBuffers framework.InitialBuffers) (outputFileName string, err error) {
func (c *CCMPyACTR) WriteModel(path string, options *runoptions.Options, initialBuffers framework.InitialBuffers) (outputFileName string, err error) {
// If our model has a print statement, then write out our support file
if c.model.HasPrintStatement() {
err = framework.WriteSupportFile(path, ccmPrintFileName, ccmPrintPython)
Expand All @@ -147,7 +148,7 @@ func (c *CCMPyACTR) WriteModel(path string, initialBuffers framework.InitialBuff
}

// If our model is tracing activations, then write out our support file
if c.model.TraceActivations {
if options.TraceActivations {
err = framework.WriteSupportFile(path, gactarActivateTraceFileName, gactarActivateTraceFile)
if err != nil {
return
Expand All @@ -164,7 +165,7 @@ func (c *CCMPyACTR) WriteModel(path string, initialBuffers framework.InitialBuff
return "", err
}

_, err = c.GenerateCode(initialBuffers)
_, err = c.GenerateCode(options, initialBuffers)
if err != nil {
return
}
Expand All @@ -178,7 +179,7 @@ func (c *CCMPyACTR) WriteModel(path string, initialBuffers framework.InitialBuff
}

// GenerateCode converts the internal actr.Model to Python code.
func (c *CCMPyACTR) GenerateCode(initialBuffers framework.InitialBuffers) (code []byte, err error) {
func (c *CCMPyACTR) GenerateCode(options *runoptions.Options, initialBuffers framework.InitialBuffers) (code []byte, err error) {
patterns, err := framework.ParseInitialBuffers(c.model, initialBuffers)
if err != nil {
return
Expand All @@ -195,13 +196,13 @@ func (c *CCMPyACTR) GenerateCode(initialBuffers framework.InitialBuffers) (code

memory := c.model.Memory

c.writeImports()
c.writeImports(options)

c.Write("\n\n")

// random
if c.model.RandomSeed != nil {
c.Writeln("random.seed(%d)", *c.model.RandomSeed)
if options.RandomSeed != nil {
c.Writeln("random.seed(%d)", *options.RandomSeed)
c.Write("\n\n")
}

Expand Down Expand Up @@ -237,7 +238,7 @@ func (c *CCMPyACTR) GenerateCode(initialBuffers framework.InitialBuffers) (code
c.Writeln(" %s = Memory(%s)", memory.ModuleName(), memory.BufferName())
}

if c.model.TraceActivations {
if options.TraceActivations {
c.Writeln(" trace = ActivateTrace(%s)", memory.ModuleName())
}

Expand Down Expand Up @@ -286,7 +287,7 @@ func (c *CCMPyACTR) GenerateCode(initialBuffers framework.InitialBuffers) (code
c.Writeln("")
}

if c.model.LogLevel == "info" {
if options.LogLevel == "info" {
// this turns on some logging at the high level
c.Writeln(" def __init__(self):")
c.Writeln(" super().__init__(log=True)")
Expand All @@ -308,7 +309,7 @@ func (c *CCMPyACTR) GenerateCode(initialBuffers framework.InitialBuffers) (code

c.Writeln("")

c.writeMain()
c.writeMain(options)

code = c.GetContents()
return
Expand Down Expand Up @@ -346,8 +347,8 @@ func (c CCMPyACTR) writeAuthors() {
c.Writeln("")
}

func (c CCMPyACTR) writeImports() {
if c.model.RandomSeed != nil {
func (c CCMPyACTR) writeImports(runOptions *runoptions.Options) {
if runOptions.RandomSeed != nil {
c.Writeln("import random")
}

Expand Down Expand Up @@ -379,7 +380,7 @@ func (c CCMPyACTR) writeImports() {
c.Write("from python_actr import %s\n", strings.Join(additionalImports, ", "))
}

if c.model.LogLevel == "detail" {
if runOptions.LogLevel == "detail" {
c.Writeln("from python_actr import log, log_everything")
}

Expand All @@ -388,7 +389,7 @@ func (c CCMPyACTR) writeImports() {
c.Writeln(fmt.Sprintf("from %s import CCMPrint", ccmPrintImportName))
}

if c.model.TraceActivations {
if runOptions.TraceActivations {
c.Writeln("")
c.Writeln(fmt.Sprintf("from %s import ActivateTrace", gactarActivateTraceImportName))
}
Expand Down Expand Up @@ -489,11 +490,11 @@ func (c CCMPyACTR) writeProductions() {
}
}

func (c CCMPyACTR) writeMain() {
func (c CCMPyACTR) writeMain(runOptions *runoptions.Options) {
c.Writeln("if __name__ == \"__main__\":")
c.Writeln(fmt.Sprintf(" model = %s()", c.className))

if c.model.LogLevel == "detail" {
if runOptions.LogLevel == "detail" {
c.Writeln(" log(summary=1)")
c.Writeln(" log_everything(model)")
}
Expand Down
7 changes: 4 additions & 3 deletions framework/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/asmaloney/gactar/actr"

"github.com/asmaloney/gactar/util/issues"
"github.com/asmaloney/gactar/util/runoptions"
"github.com/asmaloney/gactar/util/version"
)

Expand Down Expand Up @@ -50,9 +51,9 @@ type Framework interface {
SetModel(model *actr.Model) (err error)
Model() (model *actr.Model)

Run(initialBuffers InitialBuffers) (result *RunResult, err error)
WriteModel(path string, initialBuffers InitialBuffers) (outputFileName string, err error)
GenerateCode(initialBuffers InitialBuffers) (code []byte, err error)
Run(options *runoptions.Options, initialBuffers InitialBuffers) (result *RunResult, err error)
WriteModel(path string, options *runoptions.Options, initialBuffers InitialBuffers) (outputFileName string, err error)
GenerateCode(options *runoptions.Options, initialBuffers InitialBuffers) (code []byte, err error)
}

type List map[string]Framework
Expand Down
31 changes: 16 additions & 15 deletions framework/pyactr/pyactr.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/asmaloney/gactar/util/filesystem"
"github.com/asmaloney/gactar/util/issues"
"github.com/asmaloney/gactar/util/numbers"
"github.com/asmaloney/gactar/util/runoptions"
)

//go:embed pyactr_print.py
Expand Down Expand Up @@ -132,8 +133,8 @@ func (p PyACTR) Model() (model *actr.Model) {
return p.model
}

func (p *PyACTR) Run(initialBuffers framework.InitialBuffers) (result *framework.RunResult, err error) {
runFile, err := p.WriteModel(p.tmpPath, initialBuffers)
func (p *PyACTR) Run(options *runoptions.Options, initialBuffers framework.InitialBuffers) (result *framework.RunResult, err error) {
runFile, err := p.WriteModel(p.tmpPath, options, initialBuffers)
if err != nil {
return
}
Expand All @@ -156,7 +157,7 @@ func (p *PyACTR) Run(initialBuffers framework.InitialBuffers) (result *framework
}

// WriteModel converts the internal actr.Model to Python and writes it to a file.
func (p *PyACTR) WriteModel(path string, initialBuffers framework.InitialBuffers) (outputFileName string, err error) {
func (p *PyACTR) WriteModel(path string, options *runoptions.Options, initialBuffers framework.InitialBuffers) (outputFileName string, err error) {
// If our model has a print statement, then write out our support file
if p.model.HasPrintStatement() {
err = framework.WriteSupportFile(path, pyactrPrintFileName, pyactrPrintPython)
Expand All @@ -175,7 +176,7 @@ func (p *PyACTR) WriteModel(path string, initialBuffers framework.InitialBuffers
return "", err
}

_, err = p.GenerateCode(initialBuffers)
_, err = p.GenerateCode(options, initialBuffers)
if err != nil {
return
}
Expand All @@ -189,7 +190,7 @@ func (p *PyACTR) WriteModel(path string, initialBuffers framework.InitialBuffers
}

// GenerateCode converts the internal actr.Model to Python code.
func (p *PyACTR) GenerateCode(initialBuffers framework.InitialBuffers) (code []byte, err error) {
func (p *PyACTR) GenerateCode(options *runoptions.Options, initialBuffers framework.InitialBuffers) (code []byte, err error) {
patterns, err := framework.ParseInitialBuffers(p.model, initialBuffers)
if err != nil {
return
Expand All @@ -204,13 +205,13 @@ func (p *PyACTR) GenerateCode(initialBuffers framework.InitialBuffers) (code []b

p.writeHeader()

p.writeImports()
p.writeImports(options)

p.Writeln("")

// random
if p.model.RandomSeed != nil {
p.Writeln("numpy.random.seed(%d)\n", *p.model.RandomSeed)
if options.RandomSeed != nil {
p.Writeln("numpy.random.seed(%d)\n", *options.RandomSeed)
}

memory := p.model.Memory
Expand Down Expand Up @@ -253,7 +254,7 @@ func (p *PyACTR) GenerateCode(initialBuffers framework.InitialBuffers) (code []b
p.Writeln(" rule_firing=%s,", numbers.Float64Str(*procedural.DefaultActionTime))
}

if p.model.TraceActivations {
if options.TraceActivations {
p.Writeln(" activation_trace=True,")
}

Expand Down Expand Up @@ -329,7 +330,7 @@ func (p *PyACTR) GenerateCode(initialBuffers framework.InitialBuffers) (code []b
p.Writeln("")

// ...add our code to run
p.writeMain()
p.writeMain(options)

code = p.GetContents()
return
Expand Down Expand Up @@ -367,8 +368,8 @@ func (p PyACTR) writeAuthors() {
p.Writeln("")
}

func (p PyACTR) writeImports() {
if p.model.RandomSeed != nil {
func (p PyACTR) writeImports(runOptions *runoptions.Options) {
if runOptions.RandomSeed != nil {
p.Writeln("import numpy")
}

Expand Down Expand Up @@ -491,20 +492,20 @@ func (p PyACTR) writeProductions() {
}
}

func (p PyACTR) writeMain() {
func (p PyACTR) writeMain(runOptions *runoptions.Options) {
p.Writeln("# Main")
p.Writeln("if __name__ == '__main__':")

options := []string{"gui=False"}

if p.model.LogLevel == "min" {
if runOptions.LogLevel == "min" {
options = append(options, "trace=False")
}

p.Writeln(" sim = %s.simulation( %s )", p.className, strings.Join(options, ", "))
p.Writeln(" sim.run()")

if p.model.LogLevel != "min" {
if runOptions.LogLevel != "min" {
p.Writeln(" if goal.test_buffer('full'):")
p.Writeln(" print('chunk left in goal: ' + str(goal.pop()))")
p.Writeln(" if %s.retrieval.test_buffer('full'):", p.className)
Expand Down
2 changes: 1 addition & 1 deletion framework/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func GenerateCodeFromFile(fw Framework, inputFile string, initialBuffers Initial
return
}

code, err = fw.GenerateCode(initialBuffers)
code, err = fw.GenerateCode(&model.DefaultParams, initialBuffers)
if err != nil {
return
}
Expand Down
Loading

0 comments on commit 90dd90f

Please sign in to comment.