Skip to content

Commit

Permalink
Move framework list into model run options (#419)
Browse files Browse the repository at this point in the history
(More set up for CLI options)
  • Loading branch information
asmaloney authored Feb 29, 2024
1 parent 2d2e6f2 commit cc379b1
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 105 deletions.
13 changes: 0 additions & 13 deletions actr/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,19 +109,6 @@ func (model *Model) Initialize() {
model.parameters = parameters
}

func (m *Model) SetRunOptions(options *runoptions.Options) {
if options == nil {
return
}

m.LogLevel = options.LogLevel
m.TraceActivations = options.TraceActivations

if options.RandomSeed != nil {
m.RandomSeed = options.RandomSeed
}
}

func (model *Model) AddImplicitChunk(chunkName string) {
model.ImplicitChunks = append(model.ImplicitChunks, chunkName)
}
Expand Down
11 changes: 0 additions & 11 deletions framework/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
package framework

import (
"slices"
"time"

"github.com/asmaloney/gactar/actr"
Expand Down Expand Up @@ -89,13 +88,3 @@ func (l List) Exists(framework string) bool {

return false
}

// IsValidFramework returns if the framework name is in our list of valid ones or not.
func IsValidFramework(frameworkName string) bool {
return slices.Contains(ValidFrameworks, frameworkName)
}

// ValidNamedFrameworks returns the list of all valid framework names without "all".
func ValidNamedFrameworks() []string {
return ValidFrameworks[1:]
}
16 changes: 0 additions & 16 deletions modes/web/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,6 @@ var (
ErrNoModel = errors.New("no model loaded")
)

type ErrFrameworkNotActive struct {
Name string
}

func (e ErrFrameworkNotActive) Error() string {
return fmt.Sprintf("framework %q is not active on server", e.Name)
}

type ErrInvalidFrameworkName struct {
Name string
}

func (e ErrInvalidFrameworkName) Error() string {
return fmt.Sprintf("invalid framework name: %q", e.Name)
}

type ErrInvalidModelID struct {
ID int
}
Expand Down
27 changes: 19 additions & 8 deletions modes/web/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ type Model struct {
actrModel *actr.Model
}

// runOptionsJSON is the JSON version of runoptions.Options
type runOptionsJSON struct {
Frameworks []string `json:"frameworks,omitempty"` // list of frameworks to run on (if empty, "all")
LogLevel *string `json:"logLevel,omitempty"`
TraceActivations *bool `json:"traceActivations,omitempty"`
RandomSeed *uint32 `json:"randomSeed,omitempty"`
Frameworks runoptions.FrameworkNameList `json:"frameworks,omitempty"` // list of frameworks to run on (if empty, "all")
LogLevel *string `json:"logLevel,omitempty"`
TraceActivations *bool `json:"traceActivations,omitempty"`
RandomSeed *uint32 `json:"randomSeed,omitempty"`
}

func initModels(w *Web) {
Expand Down Expand Up @@ -82,14 +83,24 @@ func (w *Web) loadModel(sessionID int, amodFile string) (model *Model, err error
return
}

// actrOptions converts runOptions into actr.Options
func actrOptions(options *runOptionsJSON) *runoptions.Options {
// actrOptionsFromJSON converts runOptionsJSON into actr.Options
func (w Web) actrOptionsFromJSON(options *runOptionsJSON) (runoptions.Options, error) {
if options == nil {
return nil
return runoptions.Options{}, nil
}

activeFrameworkNames := w.settings.Frameworks.Names()

options.Frameworks.NormalizeFrameworkList(activeFrameworkNames)

err := options.Frameworks.VerifyFrameworkList(activeFrameworkNames)
if err != nil {
return runoptions.Options{}, err
}

opts := runoptions.New()

opts.Frameworks = options.Frameworks
opts.RandomSeed = options.RandomSeed

if options.LogLevel != nil {
Expand All @@ -100,7 +111,7 @@ func actrOptions(options *runOptionsJSON) *runoptions.Options {
opts.TraceActivations = *options.TraceActivations
}

return &opts
return opts, nil
}

func generateModel(amodFile string) (model *actr.Model, err error) {
Expand Down
8 changes: 3 additions & 5 deletions modes/web/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,13 @@ func (w *Web) runModelSessionHandler(rw http.ResponseWriter, req *http.Request)
return
}

data.Options.Frameworks = w.normalizeFrameworkList(data.Options.Frameworks)

err = w.verifyFrameworkList(data.Options.Frameworks)
aoptions, err := w.actrOptionsFromJSON(&data.Options)
if err != nil {
encodeErrorResponse(rw, err)
return
}

model.actrModel.SetRunOptions(actrOptions(&data.Options))
model.actrModel.Options = aoptions

// ensure temp dir exists
// https://github.com/asmaloney/gactar/issues/103
Expand All @@ -85,7 +83,7 @@ func (w *Web) runModelSessionHandler(rw http.ResponseWriter, req *http.Request)
return
}

resultMap := w.runModel(model.actrModel, data.Buffers, data.Options.Frameworks)
resultMap := w.runModel(model.actrModel, data.Buffers)

for key := range resultMap {
result := resultMap[key]
Expand Down
54 changes: 9 additions & 45 deletions modes/web/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"net/http"
"os"
"path"
"slices"
"sort"
"strconv"
"strings"
Expand All @@ -24,7 +23,6 @@ import (
"github.com/asmaloney/gactar/framework"

"github.com/asmaloney/gactar/util/cli"
"github.com/asmaloney/gactar/util/container"
"github.com/asmaloney/gactar/util/issues"
"github.com/asmaloney/gactar/util/validate"
"github.com/asmaloney/gactar/util/version"
Expand Down Expand Up @@ -158,21 +156,19 @@ func (w Web) runModelHandler(rw http.ResponseWriter, req *http.Request) {
return
}

data.Options.Frameworks = w.normalizeFrameworkList(data.Options.Frameworks)

err = w.verifyFrameworkList(data.Options.Frameworks)
model, log, err := amod.GenerateModel(data.AMODFile)
if err != nil {
encodeErrorResponse(rw, err)
encodeIssueResponse(rw, log)
return
}

model, log, err := amod.GenerateModel(data.AMODFile)
aoptions, err := w.actrOptionsFromJSON(data.Options)
if err != nil {
encodeIssueResponse(rw, log)
encodeErrorResponse(rw, err)
return
}

model.SetRunOptions(actrOptions(data.Options))
model.Options = aoptions

initialGoal := strings.TrimSpace(data.Goal)
initialBuffers := framework.InitialBuffers{
Expand All @@ -189,7 +185,7 @@ func (w Web) runModelHandler(rw http.ResponseWriter, req *http.Request) {
return
}

resultMap := w.runModel(model, initialBuffers, data.Options.Frameworks)
resultMap := w.runModel(model, initialBuffers)

rr := runResult{
Issues: log.AllIssues(),
Expand All @@ -205,45 +201,13 @@ func (w Web) runModelHandler(rw http.ResponseWriter, req *http.Request) {
encodeResponse(rw, json.RawMessage(string(results)))
}

// normalizeFrameworkList will look for "all" and replace it with all available
// framework names. It will then return a unique and sorted list of framework names.
func (w Web) normalizeFrameworkList(list []string) (normalized []string) {
normalized = list

if list == nil || slices.Contains(list, "all") {
normalized = w.settings.Frameworks.Names()
}

normalized = container.UniqueAndSorted(normalized)
return
}

// verifyFrameworkList will check that each name is of a valid framework and that
// it is active on this server.
func (w Web) verifyFrameworkList(list []string) (err error) {
for _, name := range list {
if !framework.IsValidFramework(name) {
err = &ErrInvalidFrameworkName{Name: name}
return
}

// we have a valid name, check if it is active
if _, ok := w.settings.Frameworks[name]; !ok {
err = &ErrFrameworkNotActive{Name: name}
return
}
}

return
}

func (w Web) runModel(model *actr.Model, initialBuffers framework.InitialBuffers, frameworkNames []string) (resultMap frameworkRunResultMap) {
resultMap = make(frameworkRunResultMap, len(frameworkNames))
func (w Web) runModel(model *actr.Model, initialBuffers framework.InitialBuffers) (resultMap frameworkRunResultMap) {
resultMap = make(frameworkRunResultMap, len(model.Options.Frameworks))

var wg sync.WaitGroup
var mutex = &sync.Mutex{}

for _, name := range frameworkNames {
for _, name := range model.Options.Frameworks {
f := w.settings.Frameworks[name]

wg.Add(1)
Expand Down
3 changes: 2 additions & 1 deletion util/frameworkutil/frameworkutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ import (

"github.com/asmaloney/gactar/util/chalk"
"github.com/asmaloney/gactar/util/cli"
"github.com/asmaloney/gactar/util/runoptions"
)

// CreateFrameworks takes a slice of framework names and some settings,
// creates any valid ones, and returns a list of them.
// If "names" is empty it will try to create all valid frameworks.
func CreateFrameworks(settings *cli.Settings, names []string) (list framework.List) {
if len(names) == 0 {
names = framework.ValidNamedFrameworks()
names = runoptions.ValidNamedFrameworks()
}

list = framework.List{}
Expand Down
19 changes: 19 additions & 0 deletions util/runoptions/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package runoptions

import "fmt"

type ErrFrameworkNotActive struct {
Name string
}

func (e ErrFrameworkNotActive) Error() string {
return fmt.Sprintf("framework %q is not active on server", e.Name)
}

type ErrInvalidFrameworkName struct {
Name string
}

func (e ErrInvalidFrameworkName) Error() string {
return fmt.Sprintf("invalid framework name: %q", e.Name)
}
68 changes: 62 additions & 6 deletions util/runoptions/runoptions.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,39 @@
// Package runoptions preovide structs and types to pass around options used when running models.
package runoptions

import "slices"
import (
"slices"

"github.com/asmaloney/gactar/util/container"
)

type ACTRLogLevel string

var ACTRLoggingLevels = []string{
"min",
"info",
"detail",
}
var (
// ValidFrameworks lists the valid options for choosing frameworks on the command line and in the
// interactive case. Make sure "all" is the first entry as we use [1:] to get the rest.
ValidFrameworks = []string{"all", "ccm", "pyactr", "vanilla"}

ACTRLoggingLevels = []string{
"min",
"info",
"detail",
}
)

// ValidLogLevel returns whether the string is a valid logging level or not.
func ValidLogLevel(e string) bool {
return slices.Contains(ACTRLoggingLevels, e)
}

// FrameworkNameList is a list of framework names used in the run options.
type FrameworkNameList []string

// Options are options used when running a model.
type Options struct {
// List of frameworks to run on (if empty, "all")
Frameworks FrameworkNameList

// One of 'min', 'info', or 'detail'
LogLevel ACTRLogLevel

Expand All @@ -33,8 +49,48 @@ type Options struct {
// New returns a default-initialized Options struct.
func New() Options {
return Options{
Frameworks: FrameworkNameList{"all"},
LogLevel: ACTRLogLevel("info"),
TraceActivations: false,
RandomSeed: nil,
}
}

// IsValidFramework returns if the framework name is in our list of valid ones or not.
func IsValidFramework(frameworkName string) bool {
return slices.Contains(ValidFrameworks, frameworkName)
}

// ValidNamedFrameworks returns the list of all valid framework names without "all".
func ValidNamedFrameworks() []string {
return ValidFrameworks[1:]
}

// NormalizeFrameworkList will look for "all" and replace it with all available
// framework names.
func (f *FrameworkNameList) NormalizeFrameworkList(activeFrameworks FrameworkNameList) {
if f == nil || slices.Contains(*f, "all") {
*f = activeFrameworks
}

*f = container.UniqueAndSorted(*f)
}

// VerifyFrameworkList will check that each name is of a valid framework and that
// it is active on this server.
func (f FrameworkNameList) VerifyFrameworkList(activeFrameworks FrameworkNameList) (err error) {
for _, name := range f {
if !IsValidFramework(name) {
err = &ErrInvalidFrameworkName{Name: name}
return
}

// we have a valid name, check if it is active
if !slices.Contains(activeFrameworks, name) {
err = &ErrFrameworkNotActive{Name: name}
return
}
}

return
}

0 comments on commit cc379b1

Please sign in to comment.