diff --git a/README.md b/README.md index 0c0a1ba..3c0debd 100644 --- a/README.md +++ b/README.md @@ -120,6 +120,22 @@ db.Use(sharding.Register(sharding.Config{ }, "orders") ``` +### Use MySQL Sequence + +There has built-in MySQL sequence primary key implementation in Gorm Sharding, you just configure `PrimaryKeyGenerator: sharding.PKMySQLSequence` to use. + +You don't need create sequence manually, Gorm Sharding check and create when the MySQL sequence does not exists. + +This sequence name followed `gorm_sharding_${table_name}_id_seq`, for example `orders` table, the sequence name is `gorm_sharding_orders_id_seq`. + +```go +db.Use(sharding.Register(sharding.Config{ + ShardingKey: "user_id", + NumberOfShards: 64, + PrimaryKeyGenerator: sharding.PKMySQLSequence, +}, "orders") +``` + ### No primary key If your table doesn't have a primary key, or has a primary key that isn't called `id`, anyway, you don't want to auto-fill the `id` field, then you can set `PrimaryKeyGenerator` to `PKCustom` and have `PrimaryKeyGeneratorFn` return `0`. diff --git a/primary_key.go b/primary_key.go index 3c5ffb9..9278eb9 100644 --- a/primary_key.go +++ b/primary_key.go @@ -7,6 +7,8 @@ const ( PKSnowflake = iota // Use PostgreSQL sequence primary key generator PKPGSequence + // Use MySQL sequence primary key generator + PKMySQLSequence // Use custom primary key generator PKCustom ) @@ -15,6 +17,8 @@ func (s *Sharding) genSnowflakeKey(index int64) int64 { return s.snowflakeNodes[index].Generate().Int64() } +// PostgreSQL sequence + func (s *Sharding) genPostgreSQLSequenceKey(tableName string, index int64) int64 { var id int64 err := s.DB.Raw("SELECT nextval('" + pgSeqName(tableName) + "')").Scan(&id).Error @@ -31,3 +35,34 @@ func (s *Sharding) createPostgreSQLSequenceKeyIfNotExist(tableName string) error func pgSeqName(table string) string { return fmt.Sprintf("gorm_sharding_%s_id_seq", table) } + +// MySQL Sequence + +func (s *Sharding) genMySQLSequenceKey(tableName string, index int64) int64 { + var id int64 + err := s.DB.Exec("UPDATE `" + mySQLSeqName(tableName) + "` SET id = LAST_INSERT_ID(id + 1)").Error + if err != nil { + panic(err) + } + err = s.DB.Raw("SELECT LAST_INSERT_ID()").Scan(&id).Error + if err != nil { + panic(err) + } + return id +} + +func (s *Sharding) createMySQLSequenceKeyIfNotExist(tableName string) error { + stmt := s.DB.Exec("CREATE TABLE IF NOT EXISTS `" + mySQLSeqName(tableName) + "` (id INT NOT NULL)") + if stmt.Error != nil { + return fmt.Errorf("failed to create sequence table: %w", stmt.Error) + } + stmt = s.DB.Exec("INSERT INTO `" + mySQLSeqName(tableName) + "` VALUES (0)") + if stmt.Error != nil { + return fmt.Errorf("failed to insert into sequence table: %w", stmt.Error) + } + return nil +} + +func mySQLSeqName(table string) string { + return fmt.Sprintf("gorm_sharding_%s_id_seq", table) +} diff --git a/sharding.go b/sharding.go index ad774f5..04f8448 100644 --- a/sharding.go +++ b/sharding.go @@ -145,12 +145,21 @@ func (s *Sharding) compile() error { c.PrimaryKeyGeneratorFn = func(index int64) int64 { return s.genPostgreSQLSequenceKey(t, index) } + } else if c.PrimaryKeyGenerator == PKMySQLSequence { + err := s.createMySQLSequenceKeyIfNotExist(t) + if err != nil { + return err + } + + c.PrimaryKeyGeneratorFn = func(index int64) int64 { + return s.genMySQLSequenceKey(t, index) + } } else if c.PrimaryKeyGenerator == PKCustom { if c.PrimaryKeyGeneratorFn == nil { return errors.New("PrimaryKeyGeneratorFn is required when use PKCustom") } } else { - return errors.New("PrimaryKeyGenerator can only be one of PKSnowflake, PKPGSequence and PKCustom") + return errors.New("PrimaryKeyGenerator can only be one of PKSnowflake, PKPGSequence, PKMySQLSequence and PKCustom") } if c.ShardingAlgorithm == nil { @@ -240,6 +249,16 @@ func (s *Sharding) Initialize(db *gorm.DB) error { return fmt.Errorf("init postgresql sequence error, %w", err) } } + if c.PrimaryKeyGenerator == PKMySQLSequence { + err := s.DB.Exec("CREATE TABLE IF NOT EXISTS " + mySQLSeqName(t) + " (id INT NOT NULL)").Error + if err != nil { + return fmt.Errorf("init mysql create sequence error, %w", err) + } + err = s.DB.Exec("INSERT INTO " + mySQLSeqName(t) + " VALUES (0)").Error + if err != nil { + return fmt.Errorf("init mysql insert sequence error, %w", err) + } + } } s.snowflakeNodes = make([]*snowflake.Node, 1024) diff --git a/sharding_test.go b/sharding_test.go index bdebbe1..ccc46f7 100644 --- a/sharding_test.go +++ b/sharding_test.go @@ -189,7 +189,11 @@ func dropTables() { dbNoID.Exec("DROP TABLE IF EXISTS " + table) dbRead.Exec("DROP TABLE IF EXISTS " + table) dbWrite.Exec("DROP TABLE IF EXISTS " + table) - db.Exec(("DROP SEQUENCE IF EXISTS gorm_sharding_" + table + "_id_seq")) + if mysqlDialector() { + db.Exec(("DROP TABLE IF EXISTS gorm_sharding_" + table + "_id_seq")) + } else { + db.Exec(("DROP SEQUENCE IF EXISTS gorm_sharding_" + table + "_id_seq")) + } } } @@ -417,6 +421,27 @@ func TestPKPGSequence(t *testing.T) { assert.Equal(t, expected, middleware.LastQuery()) } +func TestPKMySQLSequence(t *testing.T) { + if !mysqlDialector() { + return + } + + db, _ := gorm.Open(mysql.Open(dbURL()), &gorm.Config{ + DisableForeignKeyConstraintWhenMigrating: true, + }) + shardingConfig.PrimaryKeyGenerator = PKMySQLSequence + middleware := Register(shardingConfig, &Order{}) + db.Use(middleware) + + db.Exec("UPDATE `" + mySQLSeqName("orders") + "` SET id = 42") + db.Create(&Order{UserID: 100, Product: "iPhone"}) + expected := "INSERT INTO orders_0 (`user_id`, `product`, id) VALUES (?, ?, 43)" + if mariadbDialector() { + expected = expected + " RETURNING `id`" + } + assert.Equal(t, expected, middleware.LastQuery()) +} + func TestReadWriteSplitting(t *testing.T) { dbRead.Exec("INSERT INTO orders_0 (id, product, user_id) VALUES(1, 'iPad', 100)") dbWrite.Exec("INSERT INTO orders_0 (id, product, user_id) VALUES(1, 'iPad', 100)") @@ -531,3 +556,7 @@ func assertSfidQueryResult(t *testing.T, expected, lastQuery string) { func mysqlDialector() bool { return os.Getenv("DIALECTOR") == "mysql" || os.Getenv("DIALECTOR") == "mariadb" } + +func mariadbDialector() bool { + return os.Getenv("DIALECTOR") == "mariadb" +}