Skip to content

Commit

Permalink
feat: Support for AutoMigrate (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
a631807682 authored Apr 22, 2022
1 parent d9ee537 commit c6e2cad
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 0 deletions.
77 changes: 77 additions & 0 deletions dialector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package sharding

import (
"fmt"

"gorm.io/gorm"
)

type ShardingDialector struct {
gorm.Dialector
sharding *Sharding
}

type ShardingMigrator struct {
gorm.Migrator
sharding *Sharding
dialector gorm.Dialector
}

func NewShardingDialector(d gorm.Dialector, s *Sharding) ShardingDialector {
return ShardingDialector{
Dialector: d,
sharding: s,
}
}

func (d ShardingDialector) Migrator(db *gorm.DB) gorm.Migrator {
m := d.Dialector.Migrator(db)
return ShardingMigrator{
Migrator: m,
sharding: d.sharding,
dialector: d.Dialector,
}
}

func (m ShardingMigrator) AutoMigrate(dst ...interface{}) error {
noShardingDsts := make([]interface{}, 0)
for _, model := range dst {
stmt := &gorm.Statement{DB: m.sharding.DB}
if err := stmt.Parse(model); err == nil {
if cfg, ok := m.sharding.configs[stmt.Table]; ok {
// support sharding table
suffixs := cfg.ShardingSuffixs()
if len(suffixs) == 0 {
return fmt.Errorf("sharding table:%s suffixs is empty", stmt.Table)
}

for _, suffix := range suffixs {
shardingTable := stmt.Table + suffix
tx := stmt.DB.Session(&gorm.Session{}).Table(shardingTable)
if err := m.dialector.Migrator(tx).AutoMigrate(model); err != nil {
return err
}
}

if cfg.DoubleWrite {
noShardingDsts = append(noShardingDsts, model)
}
} else {
noShardingDsts = append(noShardingDsts, model)
}
} else {
return err
}
}

if len(noShardingDsts) > 0 {
if err := m.Migrator.AutoMigrate(noShardingDsts...); err != nil {
return err
}
}
return nil
}

// TODO: DropTable drop sharding table
// func (m ShardingMigrator) DropTable(dst ...interface{}) error {
// }
28 changes: 28 additions & 0 deletions sharding.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,19 @@ type Config struct {
// }
ShardingAlgorithm func(columnValue interface{}) (suffix string, err error)

// ShardingSuffixs specifies a function to generate all table's suffix.
// Used to support Migrator.
// For example, this function get a mod all sharding suffixs.
//
// func () (suffixs []string) {
// numberOfShards := 5
// for i := 0; i < numberOfShards; i++ {
// suffixs = append(suffixs, fmt.Sprintf("_%02d", i%numberOfShards))
// }
// return
// }
ShardingSuffixs func() (suffixs []string)

// ShardingAlgorithmByPrimaryKey specifies a function to generate the sharding
// table's suffix by the primary key. Used when no sharding key specified.
// For example, this function use the Snowflake library to generate the suffix.
Expand Down Expand Up @@ -161,10 +174,24 @@ func (s *Sharding) compile() error {
return "", fmt.Errorf("default algorithm only support integer and string column," +
"if you use other type, specify you own ShardingAlgorithm")
}

return fmt.Sprintf(c.tableFormat, id%int(c.NumberOfShards)), nil
}
}

if c.ShardingSuffixs == nil {
c.ShardingSuffixs = func() (suffixs []string) {
for i := 0; i < int(c.NumberOfShards); i++ {
suffix, err := c.ShardingAlgorithm(i)
if err != nil {
return nil
}
suffixs = append(suffixs, suffix)
}
return
}
}

if c.ShardingAlgorithmByPrimaryKey == nil {
if c.PrimaryKeyGenerator == PKSnowflake {
c.ShardingAlgorithmByPrimaryKey = func(id int64) (suffix string) {
Expand Down Expand Up @@ -194,6 +221,7 @@ func (s *Sharding) LastQuery() string {

// Initialize implement for Gorm plugin interface
func (s *Sharding) Initialize(db *gorm.DB) error {
db.Dialector = NewShardingDialector(db.Dialector, s)
s.DB = db
s.registerCallbacks(db)

Expand Down
15 changes: 15 additions & 0 deletions sharding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"os"
"regexp"
"sort"
"strings"
"testing"

Expand Down Expand Up @@ -151,6 +152,20 @@ func dropTables() {
}
}

func TestAutoMigrate(t *testing.T) {
targetTables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories"}
for _, table := range targetTables {
db.Exec("DROP TABLE IF EXISTS " + table)
db.Exec(("DROP SEQUENCE IF EXISTS gorm_sharding_" + table + "_id_seq"))
}

db.AutoMigrate(&Order{}, &Category{})
tables, _ := db.Migrator().GetTables()
sort.Strings(tables)
sort.Strings(targetTables)
assert.Equal(t, tables, targetTables)
}

func TestInsert(t *testing.T) {
tx := db.Create(&Order{ID: 100, UserID: 100, Product: "iPhone"})
assertQueryResult(t, `INSERT INTO orders_0 ("user_id", "product", "id") VALUES ($1, $2, $3) RETURNING "id"`, tx)
Expand Down

0 comments on commit c6e2cad

Please sign in to comment.