diff --git a/internal/gotmpl/encoder.go b/internal/gotmpl/encoder.go index ba8d5f7..a0e52df 100644 --- a/internal/gotmpl/encoder.go +++ b/internal/gotmpl/encoder.go @@ -6,27 +6,21 @@ import ( "strings" "unicode" + "github.com/mr-tron/base58" "golang.org/x/text/cases" "golang.org/x/text/language" ) -func newNameEncoder() *nameEncoder { - return &nameEncoder{ - re: regexp.MustCompile("[-_]+"), - caser: cases.Title(language.English), - } -} - -type nameEncoder struct { - re *regexp.Regexp - caser cases.Caser -} +var ( + reSeperators = regexp.MustCompile("[-_]+") + caser = cases.Title(language.English) +) -func (*nameEncoder) varName(name string, chainSel uint64) string { +func encodeVarName(name string, chainSel uint64) string { const unnamed = "TEST" x := strings.ReplaceAll(name, "-", "_") x = strings.ToUpper(x) - if len(x) > 0 && unicode.IsDigit(rune(x[0])) { + if len(x) > 0 && (unicode.IsDigit(rune(x[0])) || isSolTestChain(name)) { x = unnamed + "_" + x } if len(x) == 0 { @@ -35,14 +29,23 @@ func (*nameEncoder) varName(name string, chainSel uint64) string { return x } -func (enc *nameEncoder) enumName(name string, chainSel uint64) string { - x := enc.re.ReplaceAllString(name, " ") - varName := strings.ReplaceAll(enc.caser.String(x), " ", "") - if len(varName) > 0 && unicode.IsDigit(rune(varName[0])) { - varName = "Test" + varName +func encodeEnumName(name string, chainSel uint64) string { + const unnamed = "Test" + x := reSeperators.ReplaceAllString(name, " ") + varName := strings.ReplaceAll(caser.String(x), " ", "") + if len(varName) > 0 && (unicode.IsDigit(rune(varName[0])) || isSolTestChain(name)) { + varName = unnamed + varName } if len(varName) == 0 { - varName = "Test" + strconv.FormatUint(chainSel, 10) + varName = unnamed + strconv.FormatUint(chainSel, 10) } return varName } + +// for evm, the above condition is used to detect if name == chainId == (some number) -> which means its a test chain +// for solana, as chainId is not a number but a base58 encoded hash, we cannot use the above condition +// we need to check if the name == chainId == a valid base58 encoded hash +func isSolTestChain(name string) bool { + _, err := base58.Decode(name) + return err == nil +} diff --git a/internal/gotmpl/run.go b/internal/gotmpl/run.go index 80dd2e7..d1be56a 100644 --- a/internal/gotmpl/run.go +++ b/internal/gotmpl/run.go @@ -25,7 +25,6 @@ type chain[C uint64 | string] struct { // C is the type of the chain ID. func Run[C uint64 | string](tmpl *template.Template, chainSelFunc func() map[C]uint64, nameFunc func(C) (string, error)) (string, error) { chains := make([]chain[C], 0) - enc := newNameEncoder() for chainID, chainSel := range chainSelFunc() { name, err := nameFunc(chainID) @@ -37,8 +36,8 @@ func Run[C uint64 | string](tmpl *template.Template, chainSelFunc func() map[C]u ChainID: chainID, Selector: chainSel, Name: name, - VarName: enc.varName(name, chainSel), - EnumName: enc.enumName(name, chainSel), + VarName: encodeVarName(name, chainSel), + EnumName: encodeEnumName(name, chainSel), }) }