Skip to content

Commit

Permalink
pkg gotmpl: support multiple chain types
Browse files Browse the repository at this point in the history
  • Loading branch information
amirylm committed Jan 9, 2025
1 parent 5b97045 commit 92edba3
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 190 deletions.
48 changes: 3 additions & 45 deletions genchains_aptos.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,13 @@
package main

import (
"bytes"
"fmt"
"go/format"
"html/template"
"os"
"sort"
"strconv"
"strings"
"unicode"
"text/template"

chain_selectors "github.com/smartcontractkit/chain-selectors"
"github.com/smartcontractkit/chain-selectors/internal/gotmpl"
)

const filename = "generated_chains_aptos.go"
Expand Down Expand Up @@ -48,7 +44,7 @@ var AptosALL = []AptosChain{
`)

func main() {
src, err := genChainsSourceCode()
src, err := gotmpl.Run(chainTemplate, chain_selectors.AptosChainIdToChainSelector, chain_selectors.AptosNameFromChainId)
if err != nil {
panic(err)
}
Expand All @@ -74,41 +70,3 @@ func main() {
panic(err)
}
}

func genChainsSourceCode() (string, error) {
var wr = new(bytes.Buffer)
chains := make([]chain, 0)

for ChainID, chainSel := range chain_selectors.AptosChainIdToChainSelector() {
name, err := chain_selectors.AptosNameFromChainId(ChainID)
if err != nil {
return "", err
}

chains = append(chains, chain{
ChainID: ChainID,
Selector: chainSel,
Name: name,
VarName: toVarName(name, chainSel),
})
}

sort.Slice(chains, func(i, j int) bool { return chains[i].VarName < chains[j].VarName })
if err := chainTemplate.ExecuteTemplate(wr, "", chains); err != nil {
return "", err
}
return wr.String(), nil
}

func toVarName(name string, chainSel uint64) string {
const unnamed = "TEST"
x := strings.ReplaceAll(name, "-", "_")
x = strings.ToUpper(x)
if len(x) > 0 && unicode.IsDigit(rune(x[0])) {
x = unnamed + "_" + x
}
if len(x) == 0 {
x = unnamed + "_" + strconv.FormatUint(chainSel, 10)
}
return x
}
30 changes: 3 additions & 27 deletions genchains_evm.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,14 @@ import (
"fmt"
"go/format"
"os"
"strconv"
"strings"
"text/template"
"unicode"

chain_selectors "github.com/smartcontractkit/chain-selectors"
"github.com/smartcontractkit/chain-selectors/internal/gotmpl"
)

const filename = "generated_chains_evm.go"

type chain struct {
EvmChainID uint64
Selector uint64
Name string
VarName string
}

var chainTemplate, _ = template.New("").Parse(`// Code generated by go generate please DO NOT EDIT
package chain_selectors
Expand All @@ -35,7 +26,7 @@ type Chain struct {
var (
{{ range . }}
{{.VarName}} = Chain{EvmChainID: {{ .EvmChainID }}, Selector: {{ .Selector }}, Name: "{{ .Name }}"}{{ end }}
{{.VarName}} = Chain{EvmChainID: {{ .ChainID }}, Selector: {{ .Selector }}, Name: "{{ .Name }}"}{{ end }}
)
var ALL = []Chain{
Expand All @@ -46,7 +37,7 @@ var ALL = []Chain{
`)

func main() {
src, err := gotmpl.Run(chainTemplate, &goNameEncoder{})
src, err := gotmpl.Run(chainTemplate, chain_selectors.EvmChainIdToChainSelector, chain_selectors.NameFromChainId)
if err != nil {
panic(err)
}
Expand All @@ -72,18 +63,3 @@ func main() {
panic(err)
}
}

type goNameEncoder struct{}

func (*goNameEncoder) VarName(name string, chainSel uint64) string {
const unnamed = "TEST"
x := strings.ReplaceAll(name, "-", "_")
x = strings.ToUpper(x)
if len(x) > 0 && unicode.IsDigit(rune(x[0])) {
x = unnamed + "_" + x
}
if len(x) == 0 {
x = unnamed + "_" + strconv.FormatUint(chainSel, 10)
}
return x
}
48 changes: 3 additions & 45 deletions genchains_solana.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,13 @@
package main

import (
"bytes"
"fmt"
"go/format"
"html/template"
"os"
"sort"
"strconv"
"strings"
"unicode"
"text/template"

chain_selectors "github.com/smartcontractkit/chain-selectors"
"github.com/smartcontractkit/chain-selectors/internal/gotmpl"
)

const filename = "generated_chains_solana.go"
Expand Down Expand Up @@ -48,7 +44,7 @@ var SolanaALL = []SolanaChain{
`)

func main() {
src, err := genChainsSourceCode()
src, err := gotmpl.Run(chainTemplate, chain_selectors.SolanaChainIdToChainSelector, chain_selectors.SolanaNameFromChainId)
if err != nil {
panic(err)
}
Expand All @@ -74,41 +70,3 @@ func main() {
panic(err)
}
}

func genChainsSourceCode() (string, error) {
var wr = new(bytes.Buffer)
chains := make([]chain, 0)

for ChainID, chainSel := range chain_selectors.SolanaChainIdToChainSelector() {
name, err := chain_selectors.SolanaNameFromChainId(ChainID)
if err != nil {
return "", err
}

chains = append(chains, chain{
ChainID: ChainID,
Selector: chainSel,
Name: name,
VarName: toVarName(name, chainSel),
})
}

sort.Slice(chains, func(i, j int) bool { return chains[i].VarName < chains[j].VarName })
if err := chainTemplate.ExecuteTemplate(wr, "", chains); err != nil {
return "", err
}
return wr.String(), nil
}

func toVarName(name string, chainSel uint64) string {
const unnamed = "TEST"
x := strings.ReplaceAll(name, "-", "_")
x = strings.ToUpper(x)
if len(x) > 0 && unicode.IsDigit(rune(x[0])) {
x = unnamed + "_" + x
}
if len(x) == 0 {
x = unnamed + "_" + strconv.FormatUint(chainSel, 10)
}
return x
}
48 changes: 48 additions & 0 deletions internal/gotmpl/encoder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package gotmpl

import (
"regexp"
"strconv"
"strings"
"unicode"

"golang.org/x/text/cases"
"golang.org/x/text/language"
)

func newNameEncoder() *nameEncoder {
return &nameEncoder{
re: regexp.MustCompile("[-_]+"),
caser: cases.Title(language.English),
}
}

type nameEncoder struct {
re *regexp.Regexp
caser cases.Caser
}

func (*nameEncoder) varName(name string, chainSel uint64) string {
const unnamed = "TEST"
x := strings.ReplaceAll(name, "-", "_")
x = strings.ToUpper(x)
if len(x) > 0 && unicode.IsDigit(rune(x[0])) {
x = unnamed + "_" + x
}
if len(x) == 0 {
x = unnamed + "_" + strconv.FormatUint(chainSel, 10)
}
return x
}

func (enc *nameEncoder) enumName(name string, chainSel uint64) string {
x := enc.re.ReplaceAllString(name, " ")
varName := strings.ReplaceAll(enc.caser.String(x), " ", "")
if len(varName) > 0 && unicode.IsDigit(rune(varName[0])) {
varName = "Test" + varName
}
if len(varName) == 0 {
varName = "Test" + strconv.FormatUint(chainSel, 10)
}
return varName
}
49 changes: 28 additions & 21 deletions internal/gotmpl/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,48 @@ import (
"bytes"
"sort"
"text/template"

chain_selectors "github.com/smartcontractkit/chain-selectors"
)

type NameEncoder interface {
VarName(name string, chainSel uint64) string
}

type chain struct {
EvmChainID uint64
Selector uint64
Name string
VarName string
// chain is a generic struct that can be used to represent chain families.
// C is the type of the chain ID.
//
// Supported types:
// EVM: uint64
// Solana: string
// Aptos: uint64
type chain[C uint64 | string] struct {
ChainID C
Selector uint64
Name string
VarName string
EnumName string
}

func Run(tmpl *template.Template, enc NameEncoder) (string, error) {
var wr = new(bytes.Buffer)
chains := make([]chain, 0)
// Run runs the template with the given chains and returns the result.
// C is the type of the chain ID.
func Run[C uint64 | string](tmpl *template.Template, chainSelFunc func() map[C]uint64, nameFunc func(C) (string, error)) (string, error) {
chains := make([]chain[C], 0)
enc := newNameEncoder()

for evmChainID, chainSel := range chain_selectors.EvmChainIdToChainSelector() {
name, err := chain_selectors.NameFromChainId(evmChainID)
for chainID, chainSel := range chainSelFunc() {
name, err := nameFunc(chainID)
if err != nil {
return "", err
}

chains = append(chains, chain{
EvmChainID: evmChainID,
Selector: chainSel,
Name: name,
VarName: enc.VarName(name, chainSel),
chains = append(chains, chain[C]{
ChainID: chainID,
Selector: chainSel,
Name: name,
VarName: enc.varName(name, chainSel),
EnumName: enc.enumName(name, chainSel),
})
}

sort.Slice(chains, func(i, j int) bool { return chains[i].VarName < chains[j].VarName })

var wr = new(bytes.Buffer)

if err := tmpl.Execute(wr, chains); err != nil {
return "", err
}
Expand Down
Loading

0 comments on commit 92edba3

Please sign in to comment.