Skip to content

Commit

Permalink
refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
amirylm committed Nov 1, 2024
1 parent a72fd5c commit 5fd0ce7
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 105 deletions.
35 changes: 5 additions & 30 deletions genchains.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,15 @@
package main

import (
"bytes"
"fmt"
"go/format"
"html/template"
"os"
"sort"
"strconv"
"strings"
"text/template"
"unicode"

chain_selectors "github.com/smartcontractkit/chain-selectors"
"github.com/smartcontractkit/chain-selectors/internal/gotmpl"
)

const filename = "generated_chains.go"
Expand Down Expand Up @@ -48,7 +46,7 @@ var ALL = []Chain{
`)

func main() {
src, err := genChainsSourceCode()
src, err := gotmpl.Run(chainTemplate, &goNameEncoder{})
if err != nil {
panic(err)
}
Expand All @@ -74,32 +72,9 @@ func main() {
}
}

func genChainsSourceCode() (string, error) {
var wr = new(bytes.Buffer)
chains := make([]chain, 0)

for evmChainID, chainSel := range chain_selectors.EvmChainIdToChainSelector() {
name, err := chain_selectors.NameFromChainId(evmChainID)
if err != nil {
return "", err
}

chains = append(chains, chain{
EvmChainID: evmChainID,
Selector: chainSel,
Name: name,
VarName: toVarName(name, chainSel),
})
}

sort.Slice(chains, func(i, j int) bool { return chains[i].VarName < chains[j].VarName })
if err := chainTemplate.ExecuteTemplate(wr, "", chains); err != nil {
return "", err
}
return wr.String(), nil
}
type goNameEncoder struct{}

func toVarName(name string, chainSel uint64) string {
func (*goNameEncoder) VarName(name string, chainSel uint64) string {
const unnamed = "TEST"
x := strings.ReplaceAll(name, "-", "_")
x = strings.ToUpper(x)
Expand Down
47 changes: 47 additions & 0 deletions internal/gotmpl/run.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package gotmpl

import (
"bytes"
"sort"
"text/template"

chain_selectors "github.com/smartcontractkit/chain-selectors"
)

type NameEncoder interface {
VarName(name string, chainSel uint64) string
}

type chain struct {
EvmChainID uint64
Selector uint64
Name string
VarName string
}

func Run(tmpl *template.Template, enc NameEncoder) (string, error) {
var wr = new(bytes.Buffer)
chains := make([]chain, 0)

for evmChainID, chainSel := range chain_selectors.EvmChainIdToChainSelector() {
name, err := chain_selectors.NameFromChainId(evmChainID)
if err != nil {
return "", err
}

chains = append(chains, chain{
EvmChainID: evmChainID,
Selector: chainSel,
Name: name,
VarName: enc.VarName(name, chainSel),
})
}

sort.Slice(chains, func(i, j int) bool { return chains[i].VarName < chains[j].VarName })

if err := tmpl.Execute(wr, chains); err != nil {
return "", err
}

return wr.String(), nil
}
119 changes: 44 additions & 75 deletions rs/genchains_rs.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,40 @@
package main

import (
"bytes"
"fmt"
"os"
"os/exec"
"path"
"regexp"
"sort"
"strconv"
"strings"
"text/template"
"unicode"

"github.com/smartcontractkit/chain-selectors/internal/gotmpl"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"gopkg.in/yaml.v3"
)

func main() {
const (
generatedFileName = "generated_chains.rs"
tmplFileName = "generated_chains.rs.tmpl"
)

func wd() string {
rsDir := os.Getenv("PWD")
if !strings.HasSuffix(rsDir, "/rs") {
rsDir = path.Join(rsDir, "rs")
}
tmplRaw, err := os.ReadFile(path.Join(rsDir, "generated_chains.rs.tmpl"))
if err != nil {
panic(err)
}
chains, err := readChainsFromSelectors(
path.Join(rsDir, "..", "selectors.yml"),
path.Join(rsDir, "..", "test_selectors.yml"),
)
return rsDir
}

func main() {
rsDir := wd()
tmplRaw, err := os.ReadFile(path.Join(rsDir, tmplFileName))
if err != nil {
panic(err)
}

generatedFileName := "generated_chains.rs"
tmpl, err := template.New(generatedFileName).Parse(string(tmplRaw))
if err != nil {
panic(err)
Expand All @@ -48,20 +47,13 @@ func main() {
if err != nil {
panic(err)
}
var wr = new(bytes.Buffer)
if err := tmpl.Execute(wr, chains); err != nil {
panic(err)
}
tmpFile := path.Join(os.TempDir(), generatedFileName)
if err := os.WriteFile(tmpFile, wr.Bytes(), 0644); err != nil {
panic(err)
}
defer os.Remove(tmpFile)
cmd := exec.Command("rustfmt", tmpFile)
if err := cmd.Run(); err != nil {

raw, err := gotmpl.Run(tmpl, newRustNameEncoder())
if err != nil {
panic(err)
}
formatted, err := os.ReadFile(tmpFile)

formatted, err := rustfmt([]byte(raw))
if err != nil {
panic(err)
}
Expand All @@ -70,69 +62,46 @@ func main() {
fmt.Println("rust: no changes detected")
return
}

if err := os.WriteFile(generatedFilePath, formatted, 0644); err != nil {
panic(err)
}
}

type SelectorsYamlEntry struct {
ChainName string `yaml:"name"`
ChainSelector uint64 `yaml:"selector"`
}

type SelectorsYaml struct {
Selectors map[uint64]SelectorsYamlEntry `yaml:"selectors"`
}

type Chain struct {
EvmChainID uint64
Selector uint64
Name string
VarName string
}

func readSelectorsYaml(filePath string) (*SelectorsYaml, error) {
selectorsRaw, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("failed to read selectors yml: %w", err)
}
var selectors SelectorsYaml
err = yaml.Unmarshal(selectorsRaw, &selectors)
if err != nil {
return nil, fmt.Errorf("failed to parse selectors yml: %w", err)
func rustfmt(src []byte) ([]byte, error) {
tmpFile := path.Join(os.TempDir(), generatedFileName)
if err := os.WriteFile(tmpFile, src, 0644); err != nil {
panic(err)
}
return &selectors, nil
}
defer os.Remove(tmpFile)

func readChainsFromSelectors(selectorsYml, testSelectorsYml string) ([]Chain, error) {
selectors, err := readSelectorsYaml(selectorsYml)
if err != nil {
return nil, err
cmd := exec.Command("rustfmt", tmpFile)
if err := cmd.Run(); err != nil {
panic(err)
}
testSelectors, err := readSelectorsYaml(testSelectorsYml)
formatted, err := os.ReadFile(tmpFile)
if err != nil {
return nil, err
}
re := regexp.MustCompile("[-_]+")
caser := cases.Title(language.English)
chains := make([]Chain, 0, len(selectors.Selectors)+len(testSelectors.Selectors))
for chainID, chain := range selectors.Selectors {
chains = append(chains, Chain{
EvmChainID: chainID,
Selector: chain.ChainSelector,
Name: chain.ChainName,
VarName: toVarName(chain.ChainName, chain.ChainSelector, caser, re),
})
panic(err)
}

sort.Slice(chains, func(i, j int) bool { return chains[i].VarName < chains[j].VarName })
return formatted, nil
}

return chains, nil
type rustNameEncoder struct {
re *regexp.Regexp
caser cases.Caser
}

func newRustNameEncoder() *rustNameEncoder {
return &rustNameEncoder{
re: regexp.MustCompile("[-_]+"),
caser: cases.Title(language.English),
}
}

func toVarName(name string, chainSel uint64, caser cases.Caser, reSep *regexp.Regexp) string {
x := reSep.ReplaceAllString(name, " ")
varName := strings.ReplaceAll(caser.String(x), " ", "")
func (enc *rustNameEncoder) VarName(name string, chainSel uint64) string {
x := enc.re.ReplaceAllString(name, " ")
varName := strings.ReplaceAll(enc.caser.String(x), " ", "")
if len(varName) > 0 && unicode.IsDigit(rune(varName[0])) {
varName = "Test" + varName
}
Expand Down

0 comments on commit 5fd0ce7

Please sign in to comment.