From 5fd0ce7d6b4a67fc3cdca46de1b1b66cee0030ea Mon Sep 17 00:00:00 2001 From: amirylm <83904651+amirylm@users.noreply.github.com> Date: Fri, 1 Nov 2024 17:39:26 +0200 Subject: [PATCH] refactor code --- genchains.go | 35 ++---------- internal/gotmpl/run.go | 47 ++++++++++++++++ rs/genchains_rs.go | 119 +++++++++++++++-------------------------- 3 files changed, 96 insertions(+), 105 deletions(-) create mode 100644 internal/gotmpl/run.go diff --git a/genchains.go b/genchains.go index 09742c3..aa63572 100644 --- a/genchains.go +++ b/genchains.go @@ -3,17 +3,15 @@ package main import ( - "bytes" "fmt" "go/format" - "html/template" "os" - "sort" "strconv" "strings" + "text/template" "unicode" - chain_selectors "github.com/smartcontractkit/chain-selectors" + "github.com/smartcontractkit/chain-selectors/internal/gotmpl" ) const filename = "generated_chains.go" @@ -48,7 +46,7 @@ var ALL = []Chain{ `) func main() { - src, err := genChainsSourceCode() + src, err := gotmpl.Run(chainTemplate, &goNameEncoder{}) if err != nil { panic(err) } @@ -74,32 +72,9 @@ func main() { } } -func genChainsSourceCode() (string, error) { - var wr = new(bytes.Buffer) - chains := make([]chain, 0) - - for evmChainID, chainSel := range chain_selectors.EvmChainIdToChainSelector() { - name, err := chain_selectors.NameFromChainId(evmChainID) - if err != nil { - return "", err - } - - chains = append(chains, chain{ - EvmChainID: evmChainID, - Selector: chainSel, - Name: name, - VarName: toVarName(name, chainSel), - }) - } - - sort.Slice(chains, func(i, j int) bool { return chains[i].VarName < chains[j].VarName }) - if err := chainTemplate.ExecuteTemplate(wr, "", chains); err != nil { - return "", err - } - return wr.String(), nil -} +type goNameEncoder struct{} -func toVarName(name string, chainSel uint64) string { +func (*goNameEncoder) VarName(name string, chainSel uint64) string { const unnamed = "TEST" x := strings.ReplaceAll(name, "-", "_") x = strings.ToUpper(x) diff --git a/internal/gotmpl/run.go b/internal/gotmpl/run.go new file mode 100644 index 0000000..7e4b36c --- /dev/null +++ b/internal/gotmpl/run.go @@ -0,0 +1,47 @@ +package gotmpl + +import ( + "bytes" + "sort" + "text/template" + + chain_selectors "github.com/smartcontractkit/chain-selectors" +) + +type NameEncoder interface { + VarName(name string, chainSel uint64) string +} + +type chain struct { + EvmChainID uint64 + Selector uint64 + Name string + VarName string +} + +func Run(tmpl *template.Template, enc NameEncoder) (string, error) { + var wr = new(bytes.Buffer) + chains := make([]chain, 0) + + for evmChainID, chainSel := range chain_selectors.EvmChainIdToChainSelector() { + name, err := chain_selectors.NameFromChainId(evmChainID) + if err != nil { + return "", err + } + + chains = append(chains, chain{ + EvmChainID: evmChainID, + Selector: chainSel, + Name: name, + VarName: enc.VarName(name, chainSel), + }) + } + + sort.Slice(chains, func(i, j int) bool { return chains[i].VarName < chains[j].VarName }) + + if err := tmpl.Execute(wr, chains); err != nil { + return "", err + } + + return wr.String(), nil +} diff --git a/rs/genchains_rs.go b/rs/genchains_rs.go index 2add1f4..f715445 100644 --- a/rs/genchains_rs.go +++ b/rs/genchains_rs.go @@ -3,41 +3,40 @@ package main import ( - "bytes" "fmt" "os" "os/exec" "path" "regexp" - "sort" "strconv" "strings" "text/template" "unicode" + "github.com/smartcontractkit/chain-selectors/internal/gotmpl" "golang.org/x/text/cases" "golang.org/x/text/language" - "gopkg.in/yaml.v3" ) -func main() { +const ( + generatedFileName = "generated_chains.rs" + tmplFileName = "generated_chains.rs.tmpl" +) + +func wd() string { rsDir := os.Getenv("PWD") if !strings.HasSuffix(rsDir, "/rs") { rsDir = path.Join(rsDir, "rs") } - tmplRaw, err := os.ReadFile(path.Join(rsDir, "generated_chains.rs.tmpl")) - if err != nil { - panic(err) - } - chains, err := readChainsFromSelectors( - path.Join(rsDir, "..", "selectors.yml"), - path.Join(rsDir, "..", "test_selectors.yml"), - ) + return rsDir +} + +func main() { + rsDir := wd() + tmplRaw, err := os.ReadFile(path.Join(rsDir, tmplFileName)) if err != nil { panic(err) } - - generatedFileName := "generated_chains.rs" tmpl, err := template.New(generatedFileName).Parse(string(tmplRaw)) if err != nil { panic(err) @@ -48,20 +47,13 @@ func main() { if err != nil { panic(err) } - var wr = new(bytes.Buffer) - if err := tmpl.Execute(wr, chains); err != nil { - panic(err) - } - tmpFile := path.Join(os.TempDir(), generatedFileName) - if err := os.WriteFile(tmpFile, wr.Bytes(), 0644); err != nil { - panic(err) - } - defer os.Remove(tmpFile) - cmd := exec.Command("rustfmt", tmpFile) - if err := cmd.Run(); err != nil { + + raw, err := gotmpl.Run(tmpl, newRustNameEncoder()) + if err != nil { panic(err) } - formatted, err := os.ReadFile(tmpFile) + + formatted, err := rustfmt([]byte(raw)) if err != nil { panic(err) } @@ -70,69 +62,46 @@ func main() { fmt.Println("rust: no changes detected") return } + if err := os.WriteFile(generatedFilePath, formatted, 0644); err != nil { panic(err) } } -type SelectorsYamlEntry struct { - ChainName string `yaml:"name"` - ChainSelector uint64 `yaml:"selector"` -} - -type SelectorsYaml struct { - Selectors map[uint64]SelectorsYamlEntry `yaml:"selectors"` -} - -type Chain struct { - EvmChainID uint64 - Selector uint64 - Name string - VarName string -} - -func readSelectorsYaml(filePath string) (*SelectorsYaml, error) { - selectorsRaw, err := os.ReadFile(filePath) - if err != nil { - return nil, fmt.Errorf("failed to read selectors yml: %w", err) - } - var selectors SelectorsYaml - err = yaml.Unmarshal(selectorsRaw, &selectors) - if err != nil { - return nil, fmt.Errorf("failed to parse selectors yml: %w", err) +func rustfmt(src []byte) ([]byte, error) { + tmpFile := path.Join(os.TempDir(), generatedFileName) + if err := os.WriteFile(tmpFile, src, 0644); err != nil { + panic(err) } - return &selectors, nil -} + defer os.Remove(tmpFile) -func readChainsFromSelectors(selectorsYml, testSelectorsYml string) ([]Chain, error) { - selectors, err := readSelectorsYaml(selectorsYml) - if err != nil { - return nil, err + cmd := exec.Command("rustfmt", tmpFile) + if err := cmd.Run(); err != nil { + panic(err) } - testSelectors, err := readSelectorsYaml(testSelectorsYml) + formatted, err := os.ReadFile(tmpFile) if err != nil { - return nil, err - } - re := regexp.MustCompile("[-_]+") - caser := cases.Title(language.English) - chains := make([]Chain, 0, len(selectors.Selectors)+len(testSelectors.Selectors)) - for chainID, chain := range selectors.Selectors { - chains = append(chains, Chain{ - EvmChainID: chainID, - Selector: chain.ChainSelector, - Name: chain.ChainName, - VarName: toVarName(chain.ChainName, chain.ChainSelector, caser, re), - }) + panic(err) } - sort.Slice(chains, func(i, j int) bool { return chains[i].VarName < chains[j].VarName }) + return formatted, nil +} - return chains, nil +type rustNameEncoder struct { + re *regexp.Regexp + caser cases.Caser +} + +func newRustNameEncoder() *rustNameEncoder { + return &rustNameEncoder{ + re: regexp.MustCompile("[-_]+"), + caser: cases.Title(language.English), + } } -func toVarName(name string, chainSel uint64, caser cases.Caser, reSep *regexp.Regexp) string { - x := reSep.ReplaceAllString(name, " ") - varName := strings.ReplaceAll(caser.String(x), " ", "") +func (enc *rustNameEncoder) VarName(name string, chainSel uint64) string { + x := enc.re.ReplaceAllString(name, " ") + varName := strings.ReplaceAll(enc.caser.String(x), " ", "") if len(varName) > 0 && unicode.IsDigit(rune(varName[0])) { varName = "Test" + varName }