From 7d22b76bcd5478aa849a020c0aa7278d42c26d3f Mon Sep 17 00:00:00 2001 From: panpeng Date: Tue, 29 Oct 2024 10:10:58 +0800 Subject: [PATCH 1/5] enables sharding for a single table with flexible support for multiple partition keys. --- README.md | 2 + conn_pool.go | 13 +- sharding.go | 17 +- test/sharding_test.go | 351 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 377 insertions(+), 6 deletions(-) create mode 100644 test/sharding_test.go diff --git a/README.md b/README.md index 3c0debd..1505100 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,8 @@ err = db.Exec("DELETE FROM orders WHERE product_id = 3").Error fmt.Println(err) // ErrMissingShardingKey ``` +The example demonstrating a single table supporting multiple partitioning strategies is(单表支持多种分表策略的例子在这里)[here](./test/sharding_test.go). + The full example is [here](./examples/order.go). > 🚨 NOTE: Gorm config `PrepareStmt: true` is not supported for now. diff --git a/conn_pool.go b/conn_pool.go index c83adb6..c0b8e85 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -3,6 +3,7 @@ package sharding import ( "context" "database/sql" + "fmt" "time" "gorm.io/gorm" @@ -28,7 +29,7 @@ func (pool ConnPool) ExecContext(ctx context.Context, query string, args ...any) curTime = time.Now() ) - ftQuery, stQuery, table, err := pool.sharding.resolve(query, args...) + ftQuery, stQuery, table, err := pool.sharding.resolve(ctx, query, args...) if err != nil { return nil, err } @@ -36,7 +37,11 @@ func (pool ConnPool) ExecContext(ctx context.Context, query string, args ...any) pool.sharding.querys.Store("last_query", stQuery) if table != "" { - if r, ok := pool.sharding.configs[table]; ok { + key := table + if shardingKey, ok := ctx.Value("sharding_key").(string); ok { + key = fmt.Sprintf("%s_%v", table, shardingKey) + } + if r, ok := pool.sharding.configs[key]; ok { if r.DoubleWrite { pool.sharding.Logger.Trace(ctx, curTime, func() (sql string, rowsAffected int64) { result, _ := pool.ConnPool.ExecContext(ctx, ftQuery, args...) @@ -63,7 +68,7 @@ func (pool ConnPool) QueryContext(ctx context.Context, query string, args ...any curTime = time.Now() ) - _, stQuery, _, err := pool.sharding.resolve(query, args...) + _, stQuery, _, err := pool.sharding.resolve(ctx, query, args...) if err != nil { return nil, err } @@ -80,7 +85,7 @@ func (pool ConnPool) QueryContext(ctx context.Context, query string, args ...any } func (pool ConnPool) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { - _, query, _, _ = pool.sharding.resolve(query, args...) + _, query, _, _ = pool.sharding.resolve(ctx, query, args...) pool.sharding.querys.Store("last_query", query) return pool.ConnPool.QueryRowContext(ctx, query, args...) diff --git a/sharding.go b/sharding.go index 2f29d18..ab37625 100644 --- a/sharding.go +++ b/sharding.go @@ -1,6 +1,7 @@ package sharding import ( + "context" "errors" "fmt" "hash/crc32" @@ -110,6 +111,13 @@ func Register(config Config, tables ...any) *Sharding { } } +// enables sharding for a single table with flexible support for multiple partition keys. +func RegisterWithKeys(configs map[string]Config) *Sharding { + return &Sharding{ + configs: configs, + } +} + func (s *Sharding) compile() error { if s.configs == nil { s.configs = make(map[string]Config) @@ -297,7 +305,7 @@ func (s *Sharding) switchConn(db *gorm.DB) { } // resolve split the old query to full table query and sharding table query -func (s *Sharding) resolve(query string, args ...any) (ftQuery, stQuery, tableName string, err error) { +func (s *Sharding) resolve(ctx context.Context, query string, args ...any) (ftQuery, stQuery, tableName string, err error) { ftQuery = query stQuery = query if len(s.configs) == 0 { @@ -344,7 +352,12 @@ func (s *Sharding) resolve(query string, args ...any) (ftQuery, stQuery, tableNa } tableName = table.Name.Name - r, ok := s.configs[tableName] + key := tableName + // If sharding key is set in context, use it to get the sharding config. + if shardingKey, ok := ctx.Value("sharding_key").(string); ok { + key = fmt.Sprintf("%s_%v", tableName, shardingKey) + } + r, ok := s.configs[key] if !ok { return } diff --git a/test/sharding_test.go b/test/sharding_test.go new file mode 100644 index 0000000..20b561b --- /dev/null +++ b/test/sharding_test.go @@ -0,0 +1,351 @@ +package test + +import ( + "context" + "fmt" + "math/rand" + "testing" + "time" + + "gorm.io/driver/mysql" + "gorm.io/gorm" + "gorm.io/gorm/logger" + "gorm.io/sharding" +) + +var globalDB *gorm.DB + +type Order struct { + ID int64 `gorm:"primaryKey"` + OrderId string `gorm:"sharding:order_id"` // 指明 OrderId 是分片键 + UserID int64 `gorm:"sharding:user_id"` + ProductID int64 + OrderDate time.Time + OrderYear int `gorm:"sharding:order_year"` +} + +func InitGormDb() *gorm.DB { + log := logger.Default.LogMode(logger.Info) + // 连接到 MySQL 数据库 + dsn := "user:password@tcp(ip:port)/sharding?charset=utf8mb4&parseTime=True&loc=Local" + db, err := gorm.Open(mysql.New(mysql.Config{ + DSN: dsn, + }), &gorm.Config{ + Logger: log, + }) + if err != nil { + panic("failed to connect database") + } + globalDB = db + return db +} + +// orders 表的分表键为order_year,根据order_year分表 +func customShardingAlgorithmWithOrderYear(value any) (suffix string, err error) { + if year, ok := value.(int); ok { + return fmt.Sprintf("_%d", year), nil + } + return "", fmt.Errorf("invalid order_date") +} + +// orders 表的分表键为user_id,根据user_id分表 +func customShardingAlgorithmWithUserId(value any) (suffix string, err error) { + if userId, ok := value.(int64); ok { + return fmt.Sprintf("_%d", userId%4), nil + } + return "", fmt.Errorf("invalid user_id") +} + +// orders 表的分表键为user_id,根据order_id分表 +func customShardingAlgorithmWithOrderId(value any) (suffix string, err error) { + if orderId, ok := value.(string); ok { + // 截取字符串,截取前8位,获取年份 + orderId = orderId[0:8] + orderDate, err := time.Parse("20060102", orderId) + if err != nil { + return "", fmt.Errorf("invalid order_date") + } + year := orderDate.Year() + return fmt.Sprintf("_%d", year), nil + } + return "", fmt.Errorf("invalid order_date") +} + +// customePrimaryKeyGeneratorFn 自定义主键生成函数 +func customePrimaryKeyGeneratorFn(tableIdx int64) int64 { + var id int64 + seqTableName := "gorm_sharding_orders_id_seq" // 序列表名 + db := globalDB + err := db.Exec("UPDATE `" + seqTableName + "` SET id = id+1").Error + if err != nil { + panic(err) + } + err = db.Raw("SELECT id FROM " + seqTableName + " ORDER BY id DESC LIMIT 1").Scan(&id).Error + if err != nil { + panic(err) + } + return id +} +func Test_Gorm_CreateTable(t *testing.T) { + // 初始化 Gorm DB + db := InitGormDb() + + // 创建gorm_sharding_orders_id_seq表 + err := db.Exec(`CREATE TABLE IF NOT EXISTS gorm_sharding_orders_id_seq ( + id BIGINT PRIMARY KEY NOT NULL DEFAULT 1 + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4`).Error + if err != nil { + panic("failed to create table") + } + // 插入一条记录 + err = db.Exec(`INSERT INTO gorm_sharding_orders_id_seq (id) VALUES (1)`).Error + if err != nil { + panic("failed to insert data") + } + + // 预先创建 4 个分片表。 + // orders_0, orders_1, orders_2, orders_3 + // 根据 user_id 分片键策略,每个分片表存储 user_id 取模 4 余数为 0, 1, 2, 3 的订单数据。 + for i := 0; i < 4; i++ { + table := fmt.Sprintf("orders_%d", i) + // 删除已存在的表(如果存在) + db.Exec(`DROP TABLE IF EXISTS ` + table) + // 创建新的分片表 + db.Exec(`CREATE TABLE ` + table + ` ( + id BIGINT PRIMARY KEY, + order_id VARCHAR(50), + user_id INT, + product_id INT, + order_date DATETIME + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4`) + } + + // 创建 Order 表 + // 根据 order_id 分片键策略,每个分片表存储不同年份的订单数据。 + // 也可根据order_year分片策略路由到不同的分片表。 + // orders_2024, orders_2025 + err = db.Exec(`CREATE TABLE IF NOT EXISTS orders_2024 ( + id BIGINT PRIMARY KEY, + order_id VARCHAR(50), + user_id INT, + product_id INT, + order_date DATETIME, + order_year INT + )`).Error + if err != nil { + panic("failed to create table") + } + err = db.Exec(`CREATE TABLE IF NOT EXISTS orders_2025 ( + id BIGINT PRIMARY KEY, + order_id VARCHAR(50), + user_id INT, + product_id INT, + order_date DATETIME, + order_year INT + )`).Error + if err != nil { + panic("failed to create table") + } +} + +func Test_Gorm_Sharding_WithKeys(t *testing.T) { + // 初始化 Gorm DB + db := InitGormDb() + + // 分表策略配置 + configWithOrderYear := sharding.Config{ + ShardingKey: "order_year", + ShardingAlgorithm: customShardingAlgorithmWithOrderYear, // 使用自定义的分片算法 + PrimaryKeyGenerator: sharding.PKCustom, // 使用自定义的主键生成函数 + PrimaryKeyGeneratorFn: customePrimaryKeyGeneratorFn, // 自定义主键生成函数 + } + configWithUserId := sharding.Config{ + ShardingKey: "user_id", + NumberOfShards: 4, + ShardingAlgorithm: customShardingAlgorithmWithUserId, // 使用自定义的分片算法 + PrimaryKeyGenerator: sharding.PKSnowflake, // 使用 Snowflake 算法生成主键 + } + configWithOrderId := sharding.Config{ + ShardingKey: "order_id", + ShardingAlgorithm: customShardingAlgorithmWithOrderId, // 使用自定义的分片算法 + PrimaryKeyGenerator: sharding.PKCustom, + PrimaryKeyGeneratorFn: customePrimaryKeyGeneratorFn, + } + mapConfig := make(map[string]sharding.Config) + mapConfig["orders_order_year"] = configWithOrderYear + mapConfig["orders_user_id"] = configWithUserId + mapConfig["orders_order_id"] = configWithOrderId + + // 配置 Gorm Sharding 中间件,注册分表策略配置 + middleware := sharding.RegisterWithKeys(mapConfig) // 逻辑表名为 "orders" + db.Use(middleware) + + // 根据order_year分片键策略,插入和查询示例 + InsertOrderByOrderYearKey(db) + FindByOrderYearKey(db, 2024) + + // 根据user_id分片键策略,插入和查询示例 + InsertOrderByUserId(db) + FindByUserIDKey(db, int64(100)) + + // 根据order_id分片键策略,插入、查询、更新和删除示例 + InsertOrderByOrderIdKey(db) + FindOrderByOrderIdKey(db, "20240101ORDER0002") + UpdateByOrderIdKey(db, "20240101ORDER0002") + DeleteByOrderIdKey(db, "20240101ORDER8480") +} + +func InsertOrderByOrderYearKey(db *gorm.DB) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ctx = context.WithValue(ctx, "sharding_key", "order_year") + db = db.WithContext(ctx) + // 随机2024年或者2025年 + orderYear := rand.Intn(2) + 2024 + // 随机userId + userId := rand.Intn(100) + orderId := fmt.Sprintf("%d0101ORDER%04v", orderYear, rand.Int31n(10000)) + // 示例:插入订单数据 + order := Order{ + OrderId: orderId, + UserID: int64(userId), + ProductID: 100, + OrderDate: time.Date(orderYear, 1, 1, 0, 0, 0, 0, time.UTC), + OrderYear: orderYear, + } + err := db.Table("orders").Create(&order).Error + if err != nil { + fmt.Println("Error creating order:", err) + } + return err +} +func FindByOrderYearKey(db *gorm.DB, orderYear int) ([]Order, error) { + // 查询示例 + var orders []Order + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ctx = context.WithValue(ctx, "sharding_key", "order_year") + db = db.WithContext(ctx) + db = db.Table("orders") + err := db.Model(&Order{}).Where("order_year=? and product_id=? and order_id=?", orderYear, 102, "20240101ORDER0002").Find(&orders).Error + if err != nil { + fmt.Println("Error querying orders:", err) + } + fmt.Printf("sharding key order_year Selected orders: %#v\n", orders) + return orders, err +} + +func InsertOrderByOrderIdKey(db *gorm.DB) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ctx = context.WithValue(ctx, "sharding_key", "order_id") + db = db.WithContext(ctx) + // 随机2024年或者2025年 + orderYear := rand.Intn(2) + 2024 + // 随机userId + userId := rand.Intn(100) + orderId := fmt.Sprintf("%d0101ORDER%04v", orderYear, rand.Int31n(10000)) + // 示例:插入订单数据 + order := Order{ + OrderId: orderId, + UserID: int64(userId), + ProductID: 100, + OrderDate: time.Date(orderYear, 1, 1, 0, 0, 0, 0, time.UTC), + OrderYear: orderYear, + } + db = db.Table("orders") + err := db.Create(&order).Error + if err != nil { + fmt.Println("Error creating order:", err) + } + return err +} + +func UpdateByOrderIdKey(db *gorm.DB, orderId string) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ctx = context.WithValue(ctx, "sharding_key", "order_id") + db = db.WithContext(ctx) + db = db.Table("orders") + err := db.Model(&Order{}).Where("order_id=?", orderId).Update("product_id", 102).Error + if err != nil { + fmt.Println("Error updating order:", err) + } + return err +} + +func DeleteByOrderIdKey(db *gorm.DB, orderId string) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ctx = context.WithValue(ctx, "sharding_key", "order_id") + db = db.WithContext(ctx) + db = db.Table("orders") + err := db.Where("order_id=? and product_id=?", orderId, 100).Delete(&Order{}).Error + if err != nil { + fmt.Println("Error deleting order:", err) + } + return err +} +func FindOrderByOrderIdKey(db *gorm.DB, orderId string) ([]Order, error) { + var orders []Order + // 查询示例 + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ctx = context.WithValue(ctx, "sharding_key", "order_id") + db = db.WithContext(ctx) + db = db.Table("orders") + err := db.Model(&Order{}).Where("order_id=?", orderId).Find(&orders).Error + if err != nil { + fmt.Println("Error querying orders:", err) + } + fmt.Printf("sharding key order_id Selected orders: %#v\n", orders) + return orders, err +} + +type OrderByUserId struct { + ID int64 `gorm:"primaryKey"` + OrderId string `gorm:"sharding:order_id"` // 指明 OrderId 是分片键 + UserID int64 `gorm:"sharding:user_id"` + ProductID int64 + OrderDate time.Time +} + +func InsertOrderByUserId(db *gorm.DB) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ctx = context.WithValue(ctx, "sharding_key", "user_id") + db = db.WithContext(ctx) + // 随机2024年或者2025年 + orderYear := rand.Intn(2) + 2024 + // 随机userId + userId := rand.Intn(100) + orderId := fmt.Sprintf("%d0101ORDER%04v", orderYear, rand.Int31n(10000)) + // 示例:插入订单数据 + order := OrderByUserId{ + OrderId: orderId, + UserID: int64(userId), + ProductID: 100, + OrderDate: time.Date(orderYear, 1, 1, 0, 0, 0, 0, time.UTC), + } + err := db.Table("orders").Create(&order).Error + if err != nil { + fmt.Println("Error creating order:", err) + } + return err +} + +func FindByUserIDKey(db *gorm.DB, userID int64) ([]Order, error) { + var orders []Order + // 查询示例 + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ctx = context.WithValue(ctx, "sharding_key", "user_id") + db = db.WithContext(ctx) + db = db.Table("orders") + err := db.Model(&Order{}).Where("user_id = ?", userID).Find(&orders).Error + if err != nil { + fmt.Println("Error querying orders:", err) + } + fmt.Printf("sharding key user_id Selected orders: %#v\n", orders) + return orders, err +} From 165e3d671aedd2be3228d6d62721d687c176cbd6 Mon Sep 17 00:00:00 2001 From: Jason Lee Date: Wed, 30 Oct 2024 14:43:53 +0800 Subject: [PATCH 2/5] AI translate comment to English. --- test/sharding_test.go | 96 +++++++++++++++++++++---------------------- 1 file changed, 48 insertions(+), 48 deletions(-) diff --git a/test/sharding_test.go b/test/sharding_test.go index 20b561b..3efe7a2 100644 --- a/test/sharding_test.go +++ b/test/sharding_test.go @@ -17,7 +17,7 @@ var globalDB *gorm.DB type Order struct { ID int64 `gorm:"primaryKey"` - OrderId string `gorm:"sharding:order_id"` // 指明 OrderId 是分片键 + OrderId string `gorm:"sharding:order_id"` // Specify that OrderId is the sharding key UserID int64 `gorm:"sharding:user_id"` ProductID int64 OrderDate time.Time @@ -26,7 +26,7 @@ type Order struct { func InitGormDb() *gorm.DB { log := logger.Default.LogMode(logger.Info) - // 连接到 MySQL 数据库 + // Connect to MySQL database dsn := "user:password@tcp(ip:port)/sharding?charset=utf8mb4&parseTime=True&loc=Local" db, err := gorm.Open(mysql.New(mysql.Config{ DSN: dsn, @@ -40,7 +40,7 @@ func InitGormDb() *gorm.DB { return db } -// orders 表的分表键为order_year,根据order_year分表 +// The sharding key of the orders table is order_year, sharding based on order_year func customShardingAlgorithmWithOrderYear(value any) (suffix string, err error) { if year, ok := value.(int); ok { return fmt.Sprintf("_%d", year), nil @@ -48,7 +48,7 @@ func customShardingAlgorithmWithOrderYear(value any) (suffix string, err error) return "", fmt.Errorf("invalid order_date") } -// orders 表的分表键为user_id,根据user_id分表 +// The sharding key of the orders table is user_id, sharding based on user_id func customShardingAlgorithmWithUserId(value any) (suffix string, err error) { if userId, ok := value.(int64); ok { return fmt.Sprintf("_%d", userId%4), nil @@ -56,10 +56,10 @@ func customShardingAlgorithmWithUserId(value any) (suffix string, err error) { return "", fmt.Errorf("invalid user_id") } -// orders 表的分表键为user_id,根据order_id分表 +// The sharding key of the orders table is order_id, sharding based on order_id func customShardingAlgorithmWithOrderId(value any) (suffix string, err error) { if orderId, ok := value.(string); ok { - // 截取字符串,截取前8位,获取年份 + // Extract the first 8 characters of the string to get the year orderId = orderId[0:8] orderDate, err := time.Parse("20060102", orderId) if err != nil { @@ -71,10 +71,10 @@ func customShardingAlgorithmWithOrderId(value any) (suffix string, err error) { return "", fmt.Errorf("invalid order_date") } -// customePrimaryKeyGeneratorFn 自定义主键生成函数 +// customePrimaryKeyGeneratorFn Custom primary key generation function func customePrimaryKeyGeneratorFn(tableIdx int64) int64 { var id int64 - seqTableName := "gorm_sharding_orders_id_seq" // 序列表名 + seqTableName := "gorm_sharding_orders_id_seq" // Sequence table name db := globalDB err := db.Exec("UPDATE `" + seqTableName + "` SET id = id+1").Error if err != nil { @@ -87,42 +87,42 @@ func customePrimaryKeyGeneratorFn(tableIdx int64) int64 { return id } func Test_Gorm_CreateTable(t *testing.T) { - // 初始化 Gorm DB + // Initialize Gorm DB db := InitGormDb() - // 创建gorm_sharding_orders_id_seq表 + // Create gorm_sharding_orders_id_seq table err := db.Exec(`CREATE TABLE IF NOT EXISTS gorm_sharding_orders_id_seq ( id BIGINT PRIMARY KEY NOT NULL DEFAULT 1 ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4`).Error if err != nil { panic("failed to create table") } - // 插入一条记录 + // Insert a record err = db.Exec(`INSERT INTO gorm_sharding_orders_id_seq (id) VALUES (1)`).Error if err != nil { panic("failed to insert data") } - // 预先创建 4 个分片表。 + // Pre-create 4 shard tables. // orders_0, orders_1, orders_2, orders_3 - // 根据 user_id 分片键策略,每个分片表存储 user_id 取模 4 余数为 0, 1, 2, 3 的订单数据。 + // According to the user_id sharding key strategy, each shard table stores order data with user_id modulo 4 remainder of 0, 1, 2, 3. for i := 0; i < 4; i++ { table := fmt.Sprintf("orders_%d", i) - // 删除已存在的表(如果存在) + // Drop existing table (if exists) db.Exec(`DROP TABLE IF EXISTS ` + table) - // 创建新的分片表 + // Create new shard table db.Exec(`CREATE TABLE ` + table + ` ( id BIGINT PRIMARY KEY, - order_id VARCHAR(50), - user_id INT, - product_id INT, - order_date DATETIME + order_id VARCHAR(50), + user_id INT, + product_id INT, + order_date DATETIME ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4`) } - // 创建 Order 表 - // 根据 order_id 分片键策略,每个分片表存储不同年份的订单数据。 - // 也可根据order_year分片策略路由到不同的分片表。 + // Create Order table + // According to the order_id sharding key strategy, each shard table stores order data of different years. + // It can also be routed to different shard tables according to the order_year sharding strategy. // orders_2024, orders_2025 err = db.Exec(`CREATE TABLE IF NOT EXISTS orders_2024 ( id BIGINT PRIMARY KEY, @@ -149,25 +149,25 @@ func Test_Gorm_CreateTable(t *testing.T) { } func Test_Gorm_Sharding_WithKeys(t *testing.T) { - // 初始化 Gorm DB + // Initialize Gorm DB db := InitGormDb() - // 分表策略配置 + // Sharding strategy configuration configWithOrderYear := sharding.Config{ ShardingKey: "order_year", - ShardingAlgorithm: customShardingAlgorithmWithOrderYear, // 使用自定义的分片算法 - PrimaryKeyGenerator: sharding.PKCustom, // 使用自定义的主键生成函数 - PrimaryKeyGeneratorFn: customePrimaryKeyGeneratorFn, // 自定义主键生成函数 + ShardingAlgorithm: customShardingAlgorithmWithOrderYear, // Use custom sharding algorithm + PrimaryKeyGenerator: sharding.PKCustom, // Use custom primary key generation function + PrimaryKeyGeneratorFn: customePrimaryKeyGeneratorFn, // Custom primary key generation function } configWithUserId := sharding.Config{ ShardingKey: "user_id", NumberOfShards: 4, - ShardingAlgorithm: customShardingAlgorithmWithUserId, // 使用自定义的分片算法 - PrimaryKeyGenerator: sharding.PKSnowflake, // 使用 Snowflake 算法生成主键 + ShardingAlgorithm: customShardingAlgorithmWithUserId, // Use custom sharding algorithm + PrimaryKeyGenerator: sharding.PKSnowflake, // Use Snowflake algorithm to generate primary key } configWithOrderId := sharding.Config{ ShardingKey: "order_id", - ShardingAlgorithm: customShardingAlgorithmWithOrderId, // 使用自定义的分片算法 + ShardingAlgorithm: customShardingAlgorithmWithOrderId, // Use custom sharding algorithm PrimaryKeyGenerator: sharding.PKCustom, PrimaryKeyGeneratorFn: customePrimaryKeyGeneratorFn, } @@ -176,19 +176,19 @@ func Test_Gorm_Sharding_WithKeys(t *testing.T) { mapConfig["orders_user_id"] = configWithUserId mapConfig["orders_order_id"] = configWithOrderId - // 配置 Gorm Sharding 中间件,注册分表策略配置 - middleware := sharding.RegisterWithKeys(mapConfig) // 逻辑表名为 "orders" + // Configure Gorm Sharding middleware, register sharding strategy configuration + middleware := sharding.RegisterWithKeys(mapConfig) // Logical table name is "orders" db.Use(middleware) - // 根据order_year分片键策略,插入和查询示例 + // Insert and query examples based on order_year sharding key strategy InsertOrderByOrderYearKey(db) FindByOrderYearKey(db, 2024) - // 根据user_id分片键策略,插入和查询示例 + // Insert and query examples based on user_id sharding key strategy InsertOrderByUserId(db) FindByUserIDKey(db, int64(100)) - // 根据order_id分片键策略,插入、查询、更新和删除示例 + // Insert, query, update, and delete examples based on order_id sharding key strategy InsertOrderByOrderIdKey(db) FindOrderByOrderIdKey(db, "20240101ORDER0002") UpdateByOrderIdKey(db, "20240101ORDER0002") @@ -200,12 +200,12 @@ func InsertOrderByOrderYearKey(db *gorm.DB) error { defer cancel() ctx = context.WithValue(ctx, "sharding_key", "order_year") db = db.WithContext(ctx) - // 随机2024年或者2025年 + // Randomly 2024 or 2025 orderYear := rand.Intn(2) + 2024 - // 随机userId + // Random userId userId := rand.Intn(100) orderId := fmt.Sprintf("%d0101ORDER%04v", orderYear, rand.Int31n(10000)) - // 示例:插入订单数据 + // Example: Insert order data order := Order{ OrderId: orderId, UserID: int64(userId), @@ -220,7 +220,7 @@ func InsertOrderByOrderYearKey(db *gorm.DB) error { return err } func FindByOrderYearKey(db *gorm.DB, orderYear int) ([]Order, error) { - // 查询示例 + // Query example var orders []Order ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -240,12 +240,12 @@ func InsertOrderByOrderIdKey(db *gorm.DB) error { defer cancel() ctx = context.WithValue(ctx, "sharding_key", "order_id") db = db.WithContext(ctx) - // 随机2024年或者2025年 + // Randomly 2024 or 2025 orderYear := rand.Intn(2) + 2024 - // 随机userId + // Random userId userId := rand.Intn(100) orderId := fmt.Sprintf("%d0101ORDER%04v", orderYear, rand.Int31n(10000)) - // 示例:插入订单数据 + // Example: Insert order data order := Order{ OrderId: orderId, UserID: int64(userId), @@ -288,7 +288,7 @@ func DeleteByOrderIdKey(db *gorm.DB, orderId string) error { } func FindOrderByOrderIdKey(db *gorm.DB, orderId string) ([]Order, error) { var orders []Order - // 查询示例 + // Query example ctx, cancel := context.WithCancel(context.Background()) defer cancel() ctx = context.WithValue(ctx, "sharding_key", "order_id") @@ -304,7 +304,7 @@ func FindOrderByOrderIdKey(db *gorm.DB, orderId string) ([]Order, error) { type OrderByUserId struct { ID int64 `gorm:"primaryKey"` - OrderId string `gorm:"sharding:order_id"` // 指明 OrderId 是分片键 + OrderId string `gorm:"sharding:order_id"` // Specify that OrderId is the sharding key UserID int64 `gorm:"sharding:user_id"` ProductID int64 OrderDate time.Time @@ -315,12 +315,12 @@ func InsertOrderByUserId(db *gorm.DB) error { defer cancel() ctx = context.WithValue(ctx, "sharding_key", "user_id") db = db.WithContext(ctx) - // 随机2024年或者2025年 + // Randomly 2024 or 2025 orderYear := rand.Intn(2) + 2024 - // 随机userId + // Random userId userId := rand.Intn(100) orderId := fmt.Sprintf("%d0101ORDER%04v", orderYear, rand.Int31n(10000)) - // 示例:插入订单数据 + // Example: Insert order data order := OrderByUserId{ OrderId: orderId, UserID: int64(userId), @@ -336,7 +336,7 @@ func InsertOrderByUserId(db *gorm.DB) error { func FindByUserIDKey(db *gorm.DB, userID int64) ([]Order, error) { var orders []Order - // 查询示例 + // Query example ctx, cancel := context.WithCancel(context.Background()) defer cancel() ctx = context.WithValue(ctx, "sharding_key", "user_id") From 6fd67a83563097963991c5896b51c549ac2fec20 Mon Sep 17 00:00:00 2001 From: Jason Lee Date: Wed, 30 Oct 2024 14:47:26 +0800 Subject: [PATCH 3/5] Update sharding_test.go --- test/sharding_test.go | 55 ++++++++++++++++++++++--------------------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/test/sharding_test.go b/test/sharding_test.go index 3efe7a2..4617c4a 100644 --- a/test/sharding_test.go +++ b/test/sharding_test.go @@ -152,33 +152,34 @@ func Test_Gorm_Sharding_WithKeys(t *testing.T) { // Initialize Gorm DB db := InitGormDb() - // Sharding strategy configuration - configWithOrderYear := sharding.Config{ - ShardingKey: "order_year", - ShardingAlgorithm: customShardingAlgorithmWithOrderYear, // Use custom sharding algorithm - PrimaryKeyGenerator: sharding.PKCustom, // Use custom primary key generation function - PrimaryKeyGeneratorFn: customePrimaryKeyGeneratorFn, // Custom primary key generation function - } - configWithUserId := sharding.Config{ - ShardingKey: "user_id", - NumberOfShards: 4, - ShardingAlgorithm: customShardingAlgorithmWithUserId, // Use custom sharding algorithm - PrimaryKeyGenerator: sharding.PKSnowflake, // Use Snowflake algorithm to generate primary key - } - configWithOrderId := sharding.Config{ - ShardingKey: "order_id", - ShardingAlgorithm: customShardingAlgorithmWithOrderId, // Use custom sharding algorithm - PrimaryKeyGenerator: sharding.PKCustom, - PrimaryKeyGeneratorFn: customePrimaryKeyGeneratorFn, - } - mapConfig := make(map[string]sharding.Config) - mapConfig["orders_order_year"] = configWithOrderYear - mapConfig["orders_user_id"] = configWithUserId - mapConfig["orders_order_id"] = configWithOrderId - // Configure Gorm Sharding middleware, register sharding strategy configuration - middleware := sharding.RegisterWithKeys(mapConfig) // Logical table name is "orders" - db.Use(middleware) + // Logical table name is "orders" + db.Use(sharding.RegisterWithKeys(map[string]sharding.Config{ + "orders_order_year": { + ShardingKey: "order_year", + // Use custom sharding algorithm + ShardingAlgorithm: customShardingAlgorithmWithOrderYear, + // Use custom primary key generation function + PrimaryKeyGenerator: sharding.PKCustom, + // Custom primary key generation function + PrimaryKeyGeneratorFn: customePrimaryKeyGeneratorFn, + }, + "orders_user_id": { + ShardingKey: "user_id", + NumberOfShards: 4, + // Use custom sharding algorithm + ShardingAlgorithm: customShardingAlgorithmWithUserId, + // Use Snowflake algorithm to generate primary key + PrimaryKeyGenerator: sharding.PKSnowflake, + }, + "orders_order_id": { + ShardingKey: "order_id", + // Use custom sharding algorithm + ShardingAlgorithm: customShardingAlgorithmWithOrderId, + PrimaryKeyGenerator: sharding.PKCustom, + PrimaryKeyGeneratorFn: customePrimaryKeyGeneratorFn, + }, + })) // Insert and query examples based on order_year sharding key strategy InsertOrderByOrderYearKey(db) @@ -231,7 +232,7 @@ func FindByOrderYearKey(db *gorm.DB, orderYear int) ([]Order, error) { if err != nil { fmt.Println("Error querying orders:", err) } - fmt.Printf("sharding key order_year Selected orders: %#v\n", orders) + fmt.Printf("sharding key order_year Selected orders: %#v\nn", orders) return orders, err } From c5a8a2416a2e328474b666437bc7924a7535d3ea Mon Sep 17 00:00:00 2001 From: panpeng Date: Thu, 31 Oct 2024 16:29:36 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=E5=8D=95=E8=A1=A8=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=A4=9A=E5=88=86=E8=A1=A8=E7=AD=96=E7=95=A5=E4=BC=98=E5=8C=96?= =?UTF-8?q?1=EF=BC=9A=E5=88=86=E8=A1=A8=E7=AD=96=E7=95=A5=E7=9A=84key?= =?UTF-8?q?=E6=8B=BC=E6=8E=A5=E5=B0=81=E8=A3=85=E5=88=B0=E7=BB=84=E4=BB=B6?= =?UTF-8?q?=E5=86=85=EF=BC=8C=E4=B8=8D=E5=AF=B9=E7=94=A8=E6=88=B7=E5=BC=80?= =?UTF-8?q?=E6=94=BE2.=E4=B8=8A=E4=B8=8B=E6=96=87=E4=B8=ADsharding=5Fkey?= =?UTF-8?q?=E6=94=B9=E4=B8=BAconst3.=E4=BF=AE=E5=A4=8D=E4=B8=80=E4=B8=AApa?= =?UTF-8?q?nic?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- conn_pool.go | 14 +++++---- sharding.go | 68 ++++++++++++++++++++++++++++++++++++++----- test/sharding_test.go | 47 +++++++++++++++++------------- 3 files changed, 96 insertions(+), 33 deletions(-) diff --git a/conn_pool.go b/conn_pool.go index c0b8e85..ed80563 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -3,7 +3,6 @@ package sharding import ( "context" "database/sql" - "fmt" "time" "gorm.io/gorm" @@ -38,14 +37,17 @@ func (pool ConnPool) ExecContext(ctx context.Context, query string, args ...any) if table != "" { key := table - if shardingKey, ok := ctx.Value("sharding_key").(string); ok { - key = fmt.Sprintf("%s_%v", table, shardingKey) + key, err = pool.sharding.getConfigKey(ctx, table) + if err != nil { + return nil, err } if r, ok := pool.sharding.configs[key]; ok { if r.DoubleWrite { pool.sharding.Logger.Trace(ctx, curTime, func() (sql string, rowsAffected int64) { result, _ := pool.ConnPool.ExecContext(ctx, ftQuery, args...) - rowsAffected, _ = result.RowsAffected() + if result != nil { + rowsAffected, _ = result.RowsAffected() + } return pool.sharding.Explain(ftQuery, args...), rowsAffected }, pool.sharding.Error) } @@ -55,7 +57,9 @@ func (pool ConnPool) ExecContext(ctx context.Context, query string, args ...any) var result sql.Result result, err = pool.ConnPool.ExecContext(ctx, stQuery, args...) pool.sharding.Logger.Trace(ctx, curTime, func() (sql string, rowsAffected int64) { - rowsAffected, _ = result.RowsAffected() + if result != nil { + rowsAffected, _ = result.RowsAffected() + } return pool.sharding.Explain(stQuery, args...), rowsAffected }, pool.sharding.Error) diff --git a/sharding.go b/sharding.go index ab37625..bbef9ee 100644 --- a/sharding.go +++ b/sharding.go @@ -16,15 +16,20 @@ import ( ) var ( - ErrMissingShardingKey = errors.New("sharding key or id required, and use operator =") - ErrInvalidID = errors.New("invalid id format") - ErrInsertDiffSuffix = errors.New("can not insert different suffix table in one query ") + ErrMissingShardingKey = errors.New("sharding key or id required, and use operator =") + ErrInvalidID = errors.New("invalid id format") + ErrInsertDiffSuffix = errors.New("can not insert different suffix table in one query ") + ErrShardingKeyNotExistInContext = errors.New("the value passed in the context is not the sharding key") + ErrMissingTableName = errors.New("table name is required") ) var ( ShardingIgnoreStoreKey = "sharding_ignore" ) +// ContextKeyForShardingKey is the context key for sharding key. +const ContextKeyForShardingKey = "sharding_key" + type Sharding struct { *gorm.DB ConnPool *ConnPool @@ -47,6 +52,10 @@ type Config struct { // For example, for a product order table, you may want to split the rows by `user_id`. ShardingKey string + // logical table name.Suport multiple table names with same sharding key. + // For example, for user and order table, you may want to shard by `user_id`. + TableNames []string + // NumberOfShards specifies how many tables you want to sharding. NumberOfShards uint @@ -112,10 +121,53 @@ func Register(config Config, tables ...any) *Sharding { } // enables sharding for a single table with flexible support for multiple partition keys. -func RegisterWithKeys(configs map[string]Config) *Sharding { +func RegisterWithKeys(configs []Config) (*Sharding, error) { + mapConfig := make(map[string]Config, len(configs)) + for _, config := range configs { + for _, tableName := range config.TableNames { + configKey, err := generateConfigsKey(tableName, config.ShardingKey) + if err != nil { + return nil, err + } + mapConfig[configKey] = config + } + } return &Sharding{ - configs: configs, + configs: mapConfig, + }, nil +} + +// generates the key for the sharding config. +func generateConfigsKey(tableName, shardingKey string) (string, error) { + // Table name cannot be empty + if tableName == "" { + return "", ErrMissingTableName + } + if shardingKey == "" { + return "", ErrMissingShardingKey } + return fmt.Sprintf("%s_%s", tableName, shardingKey), nil +} + +// get the configs key for using it to get the sharding config. +func (s *Sharding) getConfigKey(ctx context.Context, tableName string) (string, error) { + configKey := tableName + if shardingKey, ok := ctx.Value(ContextKeyForShardingKey).(string); ok { + // If sharding key is set in context, use it to get the sharding config. + configKey = fmt.Sprintf("%s_%s", tableName, shardingKey) + } else { + // If sharding key is not set in context, use the table name as the key. + return configKey, nil + } + + // check if the sharding key exists in the configs. + _, exis := s.configs[configKey] + if !exis { + return "", ErrShardingKeyNotExistInContext + } + + // If sharding key is not set in context, use the table name as the key. + return configKey, nil } func (s *Sharding) compile() error { @@ -353,9 +405,9 @@ func (s *Sharding) resolve(ctx context.Context, query string, args ...any) (ftQu tableName = table.Name.Name key := tableName - // If sharding key is set in context, use it to get the sharding config. - if shardingKey, ok := ctx.Value("sharding_key").(string); ok { - key = fmt.Sprintf("%s_%v", tableName, shardingKey) + key, err = s.getConfigKey(ctx, tableName) + if err != nil { + return } r, ok := s.configs[key] if !ok { diff --git a/test/sharding_test.go b/test/sharding_test.go index 4617c4a..53d09eb 100644 --- a/test/sharding_test.go +++ b/test/sharding_test.go @@ -154,32 +154,39 @@ func Test_Gorm_Sharding_WithKeys(t *testing.T) { // Configure Gorm Sharding middleware, register sharding strategy configuration // Logical table name is "orders" - db.Use(sharding.RegisterWithKeys(map[string]sharding.Config{ - "orders_order_year": { - ShardingKey: "order_year", + shardingConfig, err := sharding.RegisterWithKeys([]sharding.Config{ + { + ShardingKey: "order_year", + TableNames: []string{"orders"}, // Use custom sharding algorithm - ShardingAlgorithm: customShardingAlgorithmWithOrderYear, + ShardingAlgorithm: customShardingAlgorithmWithOrderYear, // Use custom primary key generation function - PrimaryKeyGenerator: sharding.PKCustom, + PrimaryKeyGenerator: sharding.PKCustom, // Custom primary key generation function PrimaryKeyGeneratorFn: customePrimaryKeyGeneratorFn, }, - "orders_user_id": { - ShardingKey: "user_id", - NumberOfShards: 4, + { + ShardingKey: "user_id", + TableNames: []string{"orders"}, + NumberOfShards: 4, // Use custom sharding algorithm - ShardingAlgorithm: customShardingAlgorithmWithUserId, + ShardingAlgorithm: customShardingAlgorithmWithUserId, // Use Snowflake algorithm to generate primary key PrimaryKeyGenerator: sharding.PKSnowflake, }, - "orders_order_id": { - ShardingKey: "order_id", + { + ShardingKey: "order_id", + TableNames: []string{"orders"}, // Use custom sharding algorithm ShardingAlgorithm: customShardingAlgorithmWithOrderId, PrimaryKeyGenerator: sharding.PKCustom, PrimaryKeyGeneratorFn: customePrimaryKeyGeneratorFn, }, - })) + }) + if err != nil { + panic(err) + } + db.Use(shardingConfig) // Insert and query examples based on order_year sharding key strategy InsertOrderByOrderYearKey(db) @@ -199,7 +206,7 @@ func Test_Gorm_Sharding_WithKeys(t *testing.T) { func InsertOrderByOrderYearKey(db *gorm.DB) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, "sharding_key", "order_year") + ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_year") db = db.WithContext(ctx) // Randomly 2024 or 2025 orderYear := rand.Intn(2) + 2024 @@ -225,7 +232,7 @@ func FindByOrderYearKey(db *gorm.DB, orderYear int) ([]Order, error) { var orders []Order ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, "sharding_key", "order_year") + ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_year") db = db.WithContext(ctx) db = db.Table("orders") err := db.Model(&Order{}).Where("order_year=? and product_id=? and order_id=?", orderYear, 102, "20240101ORDER0002").Find(&orders).Error @@ -239,7 +246,7 @@ func FindByOrderYearKey(db *gorm.DB, orderYear int) ([]Order, error) { func InsertOrderByOrderIdKey(db *gorm.DB) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, "sharding_key", "order_id") + ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_id") db = db.WithContext(ctx) // Randomly 2024 or 2025 orderYear := rand.Intn(2) + 2024 @@ -265,7 +272,7 @@ func InsertOrderByOrderIdKey(db *gorm.DB) error { func UpdateByOrderIdKey(db *gorm.DB, orderId string) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, "sharding_key", "order_id") + ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_id") db = db.WithContext(ctx) db = db.Table("orders") err := db.Model(&Order{}).Where("order_id=?", orderId).Update("product_id", 102).Error @@ -278,7 +285,7 @@ func UpdateByOrderIdKey(db *gorm.DB, orderId string) error { func DeleteByOrderIdKey(db *gorm.DB, orderId string) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, "sharding_key", "order_id") + ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_id") db = db.WithContext(ctx) db = db.Table("orders") err := db.Where("order_id=? and product_id=?", orderId, 100).Delete(&Order{}).Error @@ -292,7 +299,7 @@ func FindOrderByOrderIdKey(db *gorm.DB, orderId string) ([]Order, error) { // Query example ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, "sharding_key", "order_id") + ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_id") db = db.WithContext(ctx) db = db.Table("orders") err := db.Model(&Order{}).Where("order_id=?", orderId).Find(&orders).Error @@ -314,7 +321,7 @@ type OrderByUserId struct { func InsertOrderByUserId(db *gorm.DB) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, "sharding_key", "user_id") + ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "user_id") db = db.WithContext(ctx) // Randomly 2024 or 2025 orderYear := rand.Intn(2) + 2024 @@ -340,7 +347,7 @@ func FindByUserIDKey(db *gorm.DB, userID int64) ([]Order, error) { // Query example ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, "sharding_key", "user_id") + ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "user_id") db = db.WithContext(ctx) db = db.Table("orders") err := db.Model(&Order{}).Where("user_id = ?", userID).Find(&orders).Error From b1745bd42f3d09e284cc2e839b1c7f90d02ce6d2 Mon Sep 17 00:00:00 2001 From: panpeng Date: Wed, 6 Nov 2024 15:10:47 +0800 Subject: [PATCH 5/5] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=B3=A8=E5=86=8C?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E7=AD=96=E7=95=A5=E7=9A=84=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sharding.go | 38 ++++++++------------------------------ test/sharding_test.go | 27 ++++++++++++--------------- 2 files changed, 20 insertions(+), 45 deletions(-) diff --git a/sharding.go b/sharding.go index bbef9ee..8468d9e 100644 --- a/sharding.go +++ b/sharding.go @@ -23,13 +23,12 @@ var ( ErrMissingTableName = errors.New("table name is required") ) -var ( +const ( ShardingIgnoreStoreKey = "sharding_ignore" + // ContextKeyForShardingKey is the context key for sharding key. + ShardingContextKey = "sharding_key" ) -// ContextKeyForShardingKey is the context key for sharding key. -const ContextKeyForShardingKey = "sharding_key" - type Sharding struct { *gorm.DB ConnPool *ConnPool @@ -52,10 +51,6 @@ type Config struct { // For example, for a product order table, you may want to split the rows by `user_id`. ShardingKey string - // logical table name.Suport multiple table names with same sharding key. - // For example, for user and order table, you may want to shard by `user_id`. - TableNames []string - // NumberOfShards specifies how many tables you want to sharding. NumberOfShards uint @@ -121,38 +116,21 @@ func Register(config Config, tables ...any) *Sharding { } // enables sharding for a single table with flexible support for multiple partition keys. -func RegisterWithKeys(configs []Config) (*Sharding, error) { - mapConfig := make(map[string]Config, len(configs)) - for _, config := range configs { - for _, tableName := range config.TableNames { - configKey, err := generateConfigsKey(tableName, config.ShardingKey) - if err != nil { - return nil, err - } - mapConfig[configKey] = config - } - } +func RegisterWithKeys(configs map[string]Config) (*Sharding, error) { return &Sharding{ - configs: mapConfig, + configs: configs, }, nil } // generates the key for the sharding config. -func generateConfigsKey(tableName, shardingKey string) (string, error) { - // Table name cannot be empty - if tableName == "" { - return "", ErrMissingTableName - } - if shardingKey == "" { - return "", ErrMissingShardingKey - } - return fmt.Sprintf("%s_%s", tableName, shardingKey), nil +func GenerateConfigsKey(tableName, shardingKey string) string { + return fmt.Sprintf("%s_%s", tableName, shardingKey) } // get the configs key for using it to get the sharding config. func (s *Sharding) getConfigKey(ctx context.Context, tableName string) (string, error) { configKey := tableName - if shardingKey, ok := ctx.Value(ContextKeyForShardingKey).(string); ok { + if shardingKey, ok := ctx.Value(ShardingContextKey).(string); ok { // If sharding key is set in context, use it to get the sharding config. configKey = fmt.Sprintf("%s_%s", tableName, shardingKey) } else { diff --git a/test/sharding_test.go b/test/sharding_test.go index 53d09eb..d482319 100644 --- a/test/sharding_test.go +++ b/test/sharding_test.go @@ -154,10 +154,9 @@ func Test_Gorm_Sharding_WithKeys(t *testing.T) { // Configure Gorm Sharding middleware, register sharding strategy configuration // Logical table name is "orders" - shardingConfig, err := sharding.RegisterWithKeys([]sharding.Config{ - { + shardingConfig, err := sharding.RegisterWithKeys(map[string]sharding.Config{ + sharding.GenerateConfigsKey("orders", "order_year"): { ShardingKey: "order_year", - TableNames: []string{"orders"}, // Use custom sharding algorithm ShardingAlgorithm: customShardingAlgorithmWithOrderYear, // Use custom primary key generation function @@ -165,18 +164,16 @@ func Test_Gorm_Sharding_WithKeys(t *testing.T) { // Custom primary key generation function PrimaryKeyGeneratorFn: customePrimaryKeyGeneratorFn, }, - { + sharding.GenerateConfigsKey("orders", "user_id"): { ShardingKey: "user_id", - TableNames: []string{"orders"}, NumberOfShards: 4, // Use custom sharding algorithm ShardingAlgorithm: customShardingAlgorithmWithUserId, // Use Snowflake algorithm to generate primary key PrimaryKeyGenerator: sharding.PKSnowflake, }, - { + sharding.GenerateConfigsKey("orders", "order_id"): { ShardingKey: "order_id", - TableNames: []string{"orders"}, // Use custom sharding algorithm ShardingAlgorithm: customShardingAlgorithmWithOrderId, PrimaryKeyGenerator: sharding.PKCustom, @@ -206,7 +203,7 @@ func Test_Gorm_Sharding_WithKeys(t *testing.T) { func InsertOrderByOrderYearKey(db *gorm.DB) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_year") + ctx = context.WithValue(ctx, sharding.ShardingContextKey, "order_year") db = db.WithContext(ctx) // Randomly 2024 or 2025 orderYear := rand.Intn(2) + 2024 @@ -232,7 +229,7 @@ func FindByOrderYearKey(db *gorm.DB, orderYear int) ([]Order, error) { var orders []Order ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_year") + ctx = context.WithValue(ctx, sharding.ShardingContextKey, "order_year") db = db.WithContext(ctx) db = db.Table("orders") err := db.Model(&Order{}).Where("order_year=? and product_id=? and order_id=?", orderYear, 102, "20240101ORDER0002").Find(&orders).Error @@ -246,7 +243,7 @@ func FindByOrderYearKey(db *gorm.DB, orderYear int) ([]Order, error) { func InsertOrderByOrderIdKey(db *gorm.DB) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_id") + ctx = context.WithValue(ctx, sharding.ShardingContextKey, "order_id") db = db.WithContext(ctx) // Randomly 2024 or 2025 orderYear := rand.Intn(2) + 2024 @@ -272,7 +269,7 @@ func InsertOrderByOrderIdKey(db *gorm.DB) error { func UpdateByOrderIdKey(db *gorm.DB, orderId string) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_id") + ctx = context.WithValue(ctx, sharding.ShardingContextKey, "order_id") db = db.WithContext(ctx) db = db.Table("orders") err := db.Model(&Order{}).Where("order_id=?", orderId).Update("product_id", 102).Error @@ -285,7 +282,7 @@ func UpdateByOrderIdKey(db *gorm.DB, orderId string) error { func DeleteByOrderIdKey(db *gorm.DB, orderId string) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_id") + ctx = context.WithValue(ctx, sharding.ShardingContextKey, "order_id") db = db.WithContext(ctx) db = db.Table("orders") err := db.Where("order_id=? and product_id=?", orderId, 100).Delete(&Order{}).Error @@ -299,7 +296,7 @@ func FindOrderByOrderIdKey(db *gorm.DB, orderId string) ([]Order, error) { // Query example ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "order_id") + ctx = context.WithValue(ctx, sharding.ShardingContextKey, "order_id") db = db.WithContext(ctx) db = db.Table("orders") err := db.Model(&Order{}).Where("order_id=?", orderId).Find(&orders).Error @@ -321,7 +318,7 @@ type OrderByUserId struct { func InsertOrderByUserId(db *gorm.DB) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "user_id") + ctx = context.WithValue(ctx, sharding.ShardingContextKey, "user_id") db = db.WithContext(ctx) // Randomly 2024 or 2025 orderYear := rand.Intn(2) + 2024 @@ -347,7 +344,7 @@ func FindByUserIDKey(db *gorm.DB, userID int64) ([]Order, error) { // Query example ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctx = context.WithValue(ctx, sharding.ContextKeyForShardingKey, "user_id") + ctx = context.WithValue(ctx, sharding.ShardingContextKey, "user_id") db = db.WithContext(ctx) db = db.Table("orders") err := db.Model(&Order{}).Where("user_id = ?", userID).Find(&orders).Error