Skip to content

Commit

Permalink
created PKMySQLSequence
Browse files Browse the repository at this point in the history
  • Loading branch information
ricleal committed Oct 9, 2023
1 parent e491340 commit 10d5ba1
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 2 deletions.
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,22 @@ db.Use(sharding.Register(sharding.Config{
}, "orders")
```
### Use MySQL Sequence
There has built-in MySQL sequence primary key implementation in Gorm Sharding, you just configure `PrimaryKeyGenerator: sharding.PKMySQLSequence` to use.
You don't need create sequence manually, Gorm Sharding check and create when the MySQL sequence does not exists.
This sequence name followed `gorm_sharding_${table_name}_id_seq`, for example `orders` table, the sequence name is `gorm_sharding_orders_id_seq`.
```go
db.Use(sharding.Register(sharding.Config{
ShardingKey: "user_id",
NumberOfShards: 64,
PrimaryKeyGenerator: sharding.PKMySQLSequence,
}, "orders")
```
### No primary key
If your table doesn't have a primary key, or has a primary key that isn't called `id`, anyway, you don't want to auto-fill the `id` field, then you can set `PrimaryKeyGenerator` to `PKCustom` and have `PrimaryKeyGeneratorFn` return `0`.
Expand Down
35 changes: 35 additions & 0 deletions primary_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ const (
PKSnowflake = iota
// Use PostgreSQL sequence primary key generator
PKPGSequence
// Use MySQL sequence primary key generator
PKMySQLSequence
// Use custom primary key generator
PKCustom
)
Expand All @@ -15,6 +17,8 @@ func (s *Sharding) genSnowflakeKey(index int64) int64 {
return s.snowflakeNodes[index].Generate().Int64()
}

// PostgreSQL sequence

func (s *Sharding) genPostgreSQLSequenceKey(tableName string, index int64) int64 {
var id int64
err := s.DB.Raw("SELECT nextval('" + pgSeqName(tableName) + "')").Scan(&id).Error
Expand All @@ -31,3 +35,34 @@ func (s *Sharding) createPostgreSQLSequenceKeyIfNotExist(tableName string) error
func pgSeqName(table string) string {
return fmt.Sprintf("gorm_sharding_%s_id_seq", table)
}

// MySQL Sequence

func (s *Sharding) genMySQLSequenceKey(tableName string, index int64) int64 {
var id int64
err := s.DB.Exec("UPDATE `" + mySQLSeqName(tableName) + "` SET id = LAST_INSERT_ID(id + 1)").Error
if err != nil {
panic(err)
}
err = s.DB.Raw("SELECT LAST_INSERT_ID()").Scan(&id).Error
if err != nil {
panic(err)
}
return id
}

func (s *Sharding) createMySQLSequenceKeyIfNotExist(tableName string) error {
stmt := s.DB.Exec("CREATE TABLE IF NOT EXISTS `" + mySQLSeqName(tableName) + "` (id INT NOT NULL)")
if stmt.Error != nil {
return fmt.Errorf("failed to create sequence table: %w", stmt.Error)
}
stmt = s.DB.Exec("INSERT INTO `" + mySQLSeqName(tableName) + "` VALUES (0)")
if stmt.Error != nil {
return fmt.Errorf("failed to insert into sequence table: %w", stmt.Error)
}
return nil
}

func mySQLSeqName(table string) string {
return fmt.Sprintf("gorm_sharding_%s_id_seq", table)
}
21 changes: 20 additions & 1 deletion sharding.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,21 @@ func (s *Sharding) compile() error {
c.PrimaryKeyGeneratorFn = func(index int64) int64 {
return s.genPostgreSQLSequenceKey(t, index)
}
} else if c.PrimaryKeyGenerator == PKMySQLSequence {
err := s.createMySQLSequenceKeyIfNotExist(t)
if err != nil {
return err
}

c.PrimaryKeyGeneratorFn = func(index int64) int64 {
return s.genMySQLSequenceKey(t, index)
}
} else if c.PrimaryKeyGenerator == PKCustom {
if c.PrimaryKeyGeneratorFn == nil {
return errors.New("PrimaryKeyGeneratorFn is required when use PKCustom")
}
} else {
return errors.New("PrimaryKeyGenerator can only be one of PKSnowflake, PKPGSequence and PKCustom")
return errors.New("PrimaryKeyGenerator can only be one of PKSnowflake, PKPGSequence, PKMySQLSequence and PKCustom")
}

if c.ShardingAlgorithm == nil {
Expand Down Expand Up @@ -240,6 +249,16 @@ func (s *Sharding) Initialize(db *gorm.DB) error {
return fmt.Errorf("init postgresql sequence error, %w", err)
}
}
if c.PrimaryKeyGenerator == PKMySQLSequence {
err := s.DB.Exec("CREATE TABLE IF NOT EXISTS " + mySQLSeqName(t) + " (id INT NOT NULL)").Error
if err != nil {
return fmt.Errorf("init mysql create sequence error, %w", err)
}
err = s.DB.Exec("INSERT INTO " + mySQLSeqName(t) + " VALUES (0)").Error
if err != nil {
return fmt.Errorf("init mysql insert sequence error, %w", err)
}
}
}

s.snowflakeNodes = make([]*snowflake.Node, 1024)
Expand Down
31 changes: 30 additions & 1 deletion sharding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,11 @@ func dropTables() {
dbNoID.Exec("DROP TABLE IF EXISTS " + table)
dbRead.Exec("DROP TABLE IF EXISTS " + table)
dbWrite.Exec("DROP TABLE IF EXISTS " + table)
db.Exec(("DROP SEQUENCE IF EXISTS gorm_sharding_" + table + "_id_seq"))
if mysqlDialector() {
db.Exec(("DROP TABLE IF EXISTS gorm_sharding_" + table + "_id_seq"))
} else {
db.Exec(("DROP SEQUENCE IF EXISTS gorm_sharding_" + table + "_id_seq"))
}
}
}

Expand Down Expand Up @@ -417,6 +421,27 @@ func TestPKPGSequence(t *testing.T) {
assert.Equal(t, expected, middleware.LastQuery())
}

func TestPKMySQLSequence(t *testing.T) {
if !mysqlDialector() {
return
}

db, _ := gorm.Open(mysql.Open(dbURL()), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
})
shardingConfig.PrimaryKeyGenerator = PKMySQLSequence
middleware := Register(shardingConfig, &Order{})
db.Use(middleware)

db.Exec("UPDATE `" + mySQLSeqName("orders") + "` SET id = 42")
db.Create(&Order{UserID: 100, Product: "iPhone"})
expected := "INSERT INTO orders_0 (`user_id`, `product`, id) VALUES (?, ?, 43)"
if mariadbDialector() {
expected = expected + " RETURNING `id`"
}
assert.Equal(t, expected, middleware.LastQuery())
}

func TestReadWriteSplitting(t *testing.T) {
dbRead.Exec("INSERT INTO orders_0 (id, product, user_id) VALUES(1, 'iPad', 100)")
dbWrite.Exec("INSERT INTO orders_0 (id, product, user_id) VALUES(1, 'iPad', 100)")
Expand Down Expand Up @@ -531,3 +556,7 @@ func assertSfidQueryResult(t *testing.T, expected, lastQuery string) {
func mysqlDialector() bool {
return os.Getenv("DIALECTOR") == "mysql" || os.Getenv("DIALECTOR") == "mariadb"
}

func mariadbDialector() bool {
return os.Getenv("DIALECTOR") == "mariadb"
}

0 comments on commit 10d5ba1

Please sign in to comment.