diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a5e2915..08b73ef 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -40,9 +40,10 @@ jobs: env: DIALECTOR: postgres - DATABASE_URL: postgres://gorm:gorm@localhost:5432/sharding-test - DATABASE_READ_URL: postgres://gorm:gorm@localhost:5432/sharding-read-test - DATABASE_WRITE_URL: postgres://gorm:gorm@localhost:5432/sharding-write-test + DB_URL: postgres://gorm:gorm@localhost:5432/sharding-test + DB_NOID_URL: postgres://gorm:gorm@localhost:5432/sharding-noid-test + DB_READ_URL: postgres://gorm:gorm@localhost:5432/sharding-read-test + DB_WRITE_URL: postgres://gorm:gorm@localhost:5432/sharding-write-test steps: - name: Set up Go uses: actions/setup-go@v4 @@ -50,6 +51,9 @@ jobs: go-version: "1.20" id: go + - name: Create No ID Database + run: PGPASSWORD=gorm psql -h localhost -U gorm -d sharding-test -c 'CREATE DATABASE "sharding-noid-test";' + - name: Create Read Database run: PGPASSWORD=gorm psql -h localhost -U gorm -d sharding-test -c 'CREATE DATABASE "sharding-read-test";' @@ -93,9 +97,10 @@ jobs: env: DIALECTOR: mysql - DATABASE_URL: gorm:gorm@tcp(127.0.0.1:3306)/sharding-test?charset=utf8mb4&parseTime=True&loc=Local - DATABASE_READ_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-read-test?charset=utf8mb4&parseTime=True&loc=Local - DATABASE_WRITE_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-write-test?charset=utf8mb4&parseTime=True&loc=Local + DB_URL: gorm:gorm@tcp(127.0.0.1:3306)/sharding-test?charset=utf8mb4&parseTime=True&loc=Local + DB_NOID_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-noid-test?charset=utf8mb4&parseTime=True&loc=Local + DB_READ_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-read-test?charset=utf8mb4&parseTime=True&loc=Local + DB_WRITE_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-write-test?charset=utf8mb4&parseTime=True&loc=Local steps: - name: Set up Go uses: actions/setup-go@v4 @@ -103,13 +108,14 @@ jobs: go-version: "1.20" id: go + - name: Create No ID Database + run: mysqladmin -h 127.0.0.1 -uroot -pgorm create sharding-noid-test + - name: Create Read Database run: mysqladmin -h 127.0.0.1 -uroot -pgorm create sharding-read-test - #run: mysql -e 'CREATE DATABASE sharding-read-test' -ugorm -pgorm - name: Create Write Database run: mysqladmin -h 127.0.0.1 -uroot -pgorm create sharding-write-test - #run: mysql -e 'CREATE DATABASE sharding-write-test' -ugorm -pgorm - name: Check out code into the Go module directory uses: actions/checkout@v3 @@ -125,7 +131,7 @@ jobs: strategy: matrix: - dbversion: ["mariadb:latest"] + dbversion: ["mariadb:10.11"] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -148,9 +154,10 @@ jobs: env: DIALECTOR: mariadb - DATABASE_URL: gorm:gorm@tcp(127.0.0.1:3306)/sharding-test?charset=utf8mb4&parseTime=True&loc=Local - DATABASE_READ_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-read-test?charset=utf8mb4&parseTime=True&loc=Local - DATABASE_WRITE_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-write-test?charset=utf8mb4&parseTime=True&loc=Local + DB_URL: gorm:gorm@tcp(127.0.0.1:3306)/sharding-test?charset=utf8mb4&parseTime=True&loc=Local + DB_NOID_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-noid-test?charset=utf8mb4&parseTime=True&loc=Local + DB_READ_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-read-test?charset=utf8mb4&parseTime=True&loc=Local + DB_WRITE_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-write-test?charset=utf8mb4&parseTime=True&loc=Local steps: - name: Set up Go uses: actions/setup-go@v4 @@ -158,13 +165,14 @@ jobs: go-version: "1.20" id: go + - name: Create No ID Database + run: mysqladmin -h 127.0.0.1 -uroot -pgorm create sharding-noid-test + - name: Create Read Database run: mysqladmin -h 127.0.0.1 -uroot -pgorm create sharding-read-test - #run: mysql -e 'CREATE DATABASE sharding-read-test' -ugorm -pgorm - name: Create Write Database run: mysqladmin -h 127.0.0.1 -uroot -pgorm create sharding-write-test - #run: mysql -e 'CREATE DATABASE sharding-write-test' -ugorm -pgorm - name: Check out code into the Go module directory uses: actions/checkout@v3 diff --git a/README.md b/README.md index 6fb441e..0c0a1ba 100644 --- a/README.md +++ b/README.md @@ -120,6 +120,10 @@ db.Use(sharding.Register(sharding.Config{ }, "orders") ``` +### No primary key + +If your table doesn't have a primary key, or has a primary key that isn't called `id`, anyway, you don't want to auto-fill the `id` field, then you can set `PrimaryKeyGenerator` to `PKCustom` and have `PrimaryKeyGeneratorFn` return `0`. + ## Combining with dbresolver > 🚨 NOTE: Use dbresolver first. diff --git a/conn_pool.go b/conn_pool.go index 0603eb2..3a6d311 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -22,7 +22,7 @@ func (pool ConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stm return pool.ConnPool.PrepareContext(ctx, query) } -func (pool ConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { +func (pool ConnPool) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { ftQuery, stQuery, table, err := pool.sharding.resolve(query, args...) if err != nil { return nil, err @@ -42,7 +42,7 @@ func (pool ConnPool) ExecContext(ctx context.Context, query string, args ...inte } // https://github.com/go-gorm/gorm/blob/v1.21.11/callbacks/query.go#L18 -func (pool ConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { +func (pool ConnPool) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { ftQuery, stQuery, table, err := pool.sharding.resolve(query, args...) if err != nil { return nil, err @@ -61,7 +61,7 @@ func (pool ConnPool) QueryContext(ctx context.Context, query string, args ...int return pool.ConnPool.QueryContext(ctx, stQuery, args...) } -func (pool ConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { +func (pool ConnPool) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { _, query, _, _ = pool.sharding.resolve(query, args...) pool.sharding.querys.Store("last_query", query) diff --git a/dialector.go b/dialector.go index da39703..7498617 100644 --- a/dialector.go +++ b/dialector.go @@ -33,7 +33,7 @@ func (d ShardingDialector) Migrator(db *gorm.DB) gorm.Migrator { } } -func (m ShardingMigrator) AutoMigrate(dst ...interface{}) error { +func (m ShardingMigrator) AutoMigrate(dst ...any) error { shardingDsts, noShardingDsts, err := m.splitShardingDsts(dst...) if err != nil { return err @@ -61,7 +61,7 @@ func (m ShardingMigrator) AutoMigrate(dst ...interface{}) error { return nil } -func (m ShardingMigrator) DropTable(dst ...interface{}) error { +func (m ShardingMigrator) DropTable(dst ...any) error { shardingDsts, noShardingDsts, err := m.splitShardingDsts(dst...) if err != nil { return err @@ -84,15 +84,15 @@ func (m ShardingMigrator) DropTable(dst ...interface{}) error { type shardingDst struct { table string - dst interface{} + dst any } // splite sharding or normal dsts -func (m ShardingMigrator) splitShardingDsts(dsts ...interface{}) (shardingDsts []shardingDst, - noShardingDsts []interface{}, err error) { +func (m ShardingMigrator) splitShardingDsts(dsts ...any) (shardingDsts []shardingDst, + noShardingDsts []any, err error) { shardingDsts = make([]shardingDst, 0) - noShardingDsts = make([]interface{}, 0) + noShardingDsts = make([]any, 0) for _, model := range dsts { stmt := &gorm.Statement{DB: m.sharding.DB} err = stmt.Parse(model) diff --git a/go.mod b/go.mod index bf42936..ea8f184 100644 --- a/go.mod +++ b/go.mod @@ -1,17 +1,17 @@ module gorm.io/sharding -go 1.20 +go 1.21 require ( github.com/bwmarrin/snowflake v0.3.0 github.com/longbridgeapp/assert v1.1.0 github.com/longbridgeapp/sqlparser v0.3.1 - golang.org/x/exp v0.0.0-20230510235704-dd950f8aeaea - gorm.io/driver/mysql v1.4.7 - gorm.io/driver/postgres v1.5.0 - gorm.io/gorm v1.25.1 + golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 + gorm.io/driver/mysql v1.5.1 + gorm.io/driver/postgres v1.5.2 + gorm.io/gorm v1.25.4 gorm.io/hints v1.1.2 - gorm.io/plugin/dbresolver v1.4.1 + gorm.io/plugin/dbresolver v1.4.7 ) require ( @@ -19,12 +19,12 @@ require ( github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.3.0 // indirect + github.com/jackc/pgx/v5 v5.3.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/testify v1.8.1 // indirect - golang.org/x/crypto v0.6.0 // indirect - golang.org/x/text v0.7.0 // indirect + golang.org/x/crypto v0.8.0 // indirect + golang.org/x/text v0.9.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 8001bd4..e4fba1b 100644 --- a/go.sum +++ b/go.sum @@ -15,6 +15,8 @@ github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/ github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgx/v5 v5.3.0 h1:/NQi8KHMpKWHInxXesC8yD4DhkXPrVhmnwYkjp9AmBA= github.com/jackc/pgx/v5 v5.3.0/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8= +github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU= +github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8= github.com/jackc/puddle/v2 v2.2.0/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= @@ -54,8 +56,12 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= +golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ= +golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/exp v0.0.0-20230510235704-dd950f8aeaea h1:vLCWI/yYrdEHyN2JzIzPO3aaQJHQdp89IZBA/+azVC4= golang.org/x/exp v0.0.0-20230510235704-dd950f8aeaea/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= @@ -78,6 +84,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= @@ -93,8 +101,12 @@ gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c= gorm.io/driver/mysql v1.4.7 h1:rY46lkCspzGHn7+IYsNpSfEv9tA+SU4SkkB+GFX125Y= gorm.io/driver/mysql v1.4.7/go.mod h1:SxzItlnT1cb6e1e4ZRpgJN2VYtcqJgqnHxWr4wsP8oc= +gorm.io/driver/mysql v1.5.1 h1:WUEH5VF9obL/lTtzjmML/5e6VfFR/788coz2uaVCAZw= +gorm.io/driver/mysql v1.5.1/go.mod h1:Jo3Xu7mMhCyj8dlrb3WoCaRd1FhsVh+yMXb1jUInf5o= gorm.io/driver/postgres v1.5.0 h1:u2FXTy14l45qc3UeCJ7QaAXZmZfDDv0YrthvmRq1l0U= gorm.io/driver/postgres v1.5.0/go.mod h1:FUZXzO+5Uqg5zzwzv4KK49R8lvGIyscBOqYrtI1Ce9A= +gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0= +gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8= gorm.io/driver/sqlite v1.5.0 h1:zKYbzRCpBrT1bNijRnxLDJWPjVfImGEn0lSnUY5gZ+c= gorm.io/driver/sqlite v1.5.0/go.mod h1:kDMDfntV9u/vuMmz8APHtHF0b4nyBB7sfCieC6G8k8I= gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= @@ -103,7 +115,12 @@ gorm.io/gorm v1.24.7-0.20230306060331-85eaf9eeda11/go.mod h1:L4uxeKpfBml98NYqVqw gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= gorm.io/gorm v1.25.1 h1:nsSALe5Pr+cM3V1qwwQ7rOkw+6UeLrX5O4v3llhHa64= gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= +gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= +gorm.io/gorm v1.25.4 h1:iyNd8fNAe8W9dvtlgeRI5zSVZPsq3OpcTu37cYcpCmw= +gorm.io/gorm v1.25.4/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= gorm.io/hints v1.1.2 h1:b5j0kwk5p4+3BtDtYqqfY+ATSxjj+6ptPgVveuynn9o= gorm.io/hints v1.1.2/go.mod h1:/ARdpUHAtyEMCh5NNi3tI7FsGh+Cj/MIUlvNxCNCFWg= gorm.io/plugin/dbresolver v1.4.1 h1:Ug4LcoPhrvqq71UhxtF346f+skTYoCa/nEsdjvHwEzk= gorm.io/plugin/dbresolver v1.4.1/go.mod h1:CTbCtMWhsjXSiJqiW2R8POvJ2cq18RVOl4WGyT5nhNc= +gorm.io/plugin/dbresolver v1.4.7 h1:ZwtwmJQxTx9us7o6zEHFvH1q4OeEo1pooU7efmnunJA= +gorm.io/plugin/dbresolver v1.4.7/go.mod h1:l4Cn87EHLEYuqUncpEeTC2tTJQkjngPSD+lo8hIvcT0= diff --git a/sharding.go b/sharding.go index ff66327..b6e9c29 100644 --- a/sharding.go +++ b/sharding.go @@ -32,9 +32,9 @@ type Sharding struct { snowflakeNodes []*snowflake.Node _config Config - _tables []interface{} - - mutex sync.RWMutex + _tables []any + + mutex sync.RWMutex } // Config specifies the configuration for sharding. @@ -56,13 +56,13 @@ type Config struct { // table's suffix by the column value. // For example, this function implements a mod sharding algorithm. // - // func(value interface{}) (suffix string, err error) { + // func(value any) (suffix string, err error) { // if uid, ok := value.(int64);ok { // return fmt.Sprintf("_%02d", user_id % 64), nil // } // return "", errors.New("invalid user_id") // } - ShardingAlgorithm func(columnValue interface{}) (suffix string, err error) + ShardingAlgorithm func(columnValue any) (suffix string, err error) // ShardingSuffixs specifies a function to generate all table's suffix. // Used to support Migrator and generate PrimaryKey. @@ -95,6 +95,7 @@ type Config struct { // PrimaryKeyGeneratorFn specifies a function to generate the primary key. // When use auto-increment like generator, the tableIdx argument could ignored. // For example, this function use the Snowflake library to generate the primary key. + // If you don't want to auto-fill the `id` or use a primary key that isn't called `id`, just return 0. // // func(tableIdx int64) int64 { // return nodes[tableIdx].Generate().Int64() @@ -102,7 +103,7 @@ type Config struct { PrimaryKeyGeneratorFn func(tableIdx int64) int64 } -func Register(config Config, tables ...interface{}) *Sharding { +func Register(config Config, tables ...any) *Sharding { return &Sharding{ _config: config, _tables: tables, @@ -165,7 +166,7 @@ func (s *Sharding) compile() error { } else if c.NumberOfShards < 10000 { c.tableFormat = "_%04d" } - c.ShardingAlgorithm = func(value interface{}) (suffix string, err error) { + c.ShardingAlgorithm = func(value any) (suffix string, err error) { id := 0 switch value := value.(type) { case int: @@ -279,7 +280,7 @@ func (s *Sharding) switchConn(db *gorm.DB) { } // resolve split the old query to full table query and sharding table query -func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, tableName string, err error) { +func (s *Sharding) resolve(query string, args ...any) (ftQuery, stQuery, tableName string, err error) { ftQuery = query stQuery = query if len(s.configs) == 0 { @@ -295,7 +296,7 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, var condition sqlparser.Expr var isInsert bool var insertNames []*sqlparser.Ident - var inserExpressions []*sqlparser.Exprs + var insertExpressions []*sqlparser.Exprs var insertStmt *sqlparser.InsertStatement switch stmt := expr.(type) { @@ -313,7 +314,7 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, table = stmt.TableName isInsert = true insertNames = stmt.ColumnNames - inserExpressions = stmt.Expressions + insertExpressions = stmt.Expressions insertStmt = stmt case *sqlparser.UpdateStatement: condition = stmt.Condition @@ -334,12 +335,12 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, var suffix string if isInsert { var newTable *sqlparser.TableName - for _, inserExpression := range inserExpressions { - var value interface{} + for _, insertExpression := range insertExpressions { + var value any var id int64 var keyFind bool columnNames := insertNames - insertValues := inserExpression.Exprs + insertValues := insertExpression.Exprs value, id, keyFind, err = s.insertValue(r.ShardingKey, insertNames, insertValues, args...) if err != nil { return @@ -368,17 +369,22 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, break } } - if fillID { - suffixWord := strings.Replace(suffix, "_", "", 1) - tblIdx, err := strconv.Atoi(suffixWord) - if err != nil { - tblIdx = slices.Index(r.ShardingSuffixs(), suffixWord) - if tblIdx == -1 { - return ftQuery, stQuery, tableName, errors.New("table suffix '" + suffixWord + "' is not in ShardingSuffixs. In order to generate the primary key, ShardingSuffixs should include all table suffixes") - } - //return ftQuery, stQuery, tableName, err + suffixWord := strings.Replace(suffix, "_", "", 1) + tblIdx, err := strconv.Atoi(suffixWord) + if err != nil { + tblIdx = slices.Index(r.ShardingSuffixs(), suffixWord) + if tblIdx == -1 { + return ftQuery, stQuery, tableName, errors.New("table suffix '" + suffixWord + "' is not in ShardingSuffixs. In order to generate the primary key, ShardingSuffixs should include all table suffixes") } - id := r.PrimaryKeyGeneratorFn(int64(tblIdx)) + //return ftQuery, stQuery, tableName, err + } + + id := r.PrimaryKeyGeneratorFn(int64(tblIdx)) + if id == 0 { + fillID = false + } + + if fillID { columnNames = append(insertNames, &sqlparser.Ident{Name: "id"}) insertValues = append(insertValues, &sqlparser.NumberLit{Value: strconv.FormatInt(id, 10)}) } @@ -386,7 +392,7 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, if fillID { insertStmt.ColumnNames = columnNames - inserExpression.Exprs = insertValues + insertExpression.Exprs = insertValues } } @@ -395,7 +401,7 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, stQuery = insertStmt.String() } else { - var value interface{} + var value any var id int64 var keyFind bool value, id, keyFind, err = s.nonInsertValue(r.ShardingKey, condition, args...) @@ -430,7 +436,7 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, return } -func getSuffix(value interface{}, id int64, keyFind bool, r Config) (suffix string, err error) { +func getSuffix(value any, id int64, keyFind bool, r Config) (suffix string, err error) { if keyFind { suffix, err = r.ShardingAlgorithm(value) if err != nil { @@ -446,7 +452,7 @@ func getSuffix(value interface{}, id int64, keyFind bool, r Config) (suffix stri return } -func (s *Sharding) insertValue(key string, names []*sqlparser.Ident, exprs []sqlparser.Expr, args ...interface{}) (value interface{}, id int64, keyFind bool, err error) { +func (s *Sharding) insertValue(key string, names []*sqlparser.Ident, exprs []sqlparser.Expr, args ...any) (value any, id int64, keyFind bool, err error) { if len(names) != len(exprs) { return nil, 0, keyFind, errors.New("column names and expressions mismatch") } @@ -474,7 +480,7 @@ func (s *Sharding) insertValue(key string, names []*sqlparser.Ident, exprs []sql return } -func (s *Sharding) nonInsertValue(key string, condition sqlparser.Expr, args ...interface{}) (value interface{}, id int64, keyFind bool, err error) { +func (s *Sharding) nonInsertValue(key string, condition sqlparser.Expr, args ...any) (value any, id int64, keyFind bool, err error) { err = sqlparser.Walk(sqlparser.VisitFunc(func(node sqlparser.Node) error { if n, ok := node.(*sqlparser.BinaryExpr); ok { if x, ok := n.X.(*sqlparser.Ident); ok { diff --git a/sharding_test.go b/sharding_test.go index 4343329..bdebbe1 100644 --- a/sharding_test.go +++ b/sharding_test.go @@ -31,74 +31,95 @@ type Category struct { Name string } -func databaseURL() string { - databaseURL := os.Getenv("DATABASE_URL") - if len(databaseURL) == 0 { - databaseURL = "postgres://localhost:5432/sharding-test?sslmode=disable" +func dbURL() string { + dbURL := os.Getenv("DB_URL") + if len(dbURL) == 0 { + dbURL = "postgres://localhost:5432/sharding-test?sslmode=disable" if mysqlDialector() { - databaseURL = "root@tcp(127.0.0.1:3306)/sharding-test?charset=utf8mb4" + dbURL = "root@tcp(127.0.0.1:3306)/sharding-test?charset=utf8mb4" } } - return databaseURL + return dbURL } -func databaseReadURL() string { - databaseURL := os.Getenv("DATABASE_READ_URL") - if len(databaseURL) == 0 { - databaseURL = "postgres://localhost:5432/sharding-read-test?sslmode=disable" +func dbNoIDURL() string { + dbURL := os.Getenv("DB_NOID_URL") + if len(dbURL) == 0 { + dbURL = "postgres://localhost:5432/sharding-noid-test?sslmode=disable" if mysqlDialector() { - databaseURL = "root@tcp(127.0.0.1:3306)/sharding-read-test?charset=utf8mb4" + dbURL = "root@tcp(127.0.0.1:3306)/sharding-noid-test?charset=utf8mb4" } } - return databaseURL + return dbURL } -func databaseWriteURL() string { - databaseURL := os.Getenv("DATABASE_WRITE_URL") - if len(databaseURL) == 0 { - databaseURL = "postgres://localhost:5432/sharding-write-test?sslmode=disable" +func dbReadURL() string { + dbURL := os.Getenv("DB_READ_URL") + if len(dbURL) == 0 { + dbURL = "postgres://localhost:5432/sharding-read-test?sslmode=disable" if mysqlDialector() { - databaseURL = "root@tcp(127.0.0.1:3306)/sharding-write-test?charset=utf8mb4" + dbURL = "root@tcp(127.0.0.1:3306)/sharding-read-test?charset=utf8mb4" } } - return databaseURL + return dbURL +} + +func dbWriteURL() string { + dbURL := os.Getenv("DB_WRITE_URL") + if len(dbURL) == 0 { + dbURL = "postgres://localhost:5432/sharding-write-test?sslmode=disable" + if mysqlDialector() { + dbURL = "root@tcp(127.0.0.1:3306)/sharding-write-test?charset=utf8mb4" + } + } + return dbURL } var ( dbConfig = postgres.Config{ - DSN: databaseURL(), + DSN: dbURL(), + PreferSimpleProtocol: true, + } + dbNoIDConfig = postgres.Config{ + DSN: dbNoIDURL(), PreferSimpleProtocol: true, } dbReadConfig = postgres.Config{ - DSN: databaseReadURL(), + DSN: dbReadURL(), PreferSimpleProtocol: true, } dbWriteConfig = postgres.Config{ - DSN: databaseWriteURL(), + DSN: dbWriteURL(), PreferSimpleProtocol: true, } - db, dbRead, dbWrite *gorm.DB + db, dbNoID, dbRead, dbWrite *gorm.DB - shardingConfig Config - middleware *Sharding - node, _ = snowflake.NewNode(1) + shardingConfig, shardingConfigNoID Config + middleware, middlewareNoID *Sharding + node, _ = snowflake.NewNode(1) ) func init() { if mysqlDialector() { - db, _ = gorm.Open(mysql.Open(databaseURL()), &gorm.Config{ + db, _ = gorm.Open(mysql.Open(dbURL()), &gorm.Config{ + DisableForeignKeyConstraintWhenMigrating: true, + }) + dbNoID, _ = gorm.Open(mysql.Open(dbNoIDURL()), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, }) - dbRead, _ = gorm.Open(mysql.Open(databaseReadURL()), &gorm.Config{ + dbRead, _ = gorm.Open(mysql.Open(dbReadURL()), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, }) - dbWrite, _ = gorm.Open(mysql.Open(databaseWriteURL()), &gorm.Config{ + dbWrite, _ = gorm.Open(mysql.Open(dbWriteURL()), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, }) } else { db, _ = gorm.Open(postgres.New(dbConfig), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, }) + dbNoID, _ = gorm.Open(postgres.New(dbNoIDConfig), &gorm.Config{ + DisableForeignKeyConstraintWhenMigrating: true, + }) dbRead, _ = gorm.Open(postgres.New(dbReadConfig), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, }) @@ -114,7 +135,18 @@ func init() { PrimaryKeyGenerator: PKSnowflake, } + shardingConfigNoID = Config{ + DoubleWrite: true, + ShardingKey: "user_id", + NumberOfShards: 4, + PrimaryKeyGenerator: PKCustom, + PrimaryKeyGeneratorFn: func(_ int64) int64 { + return 0 + }, + } + middleware = Register(shardingConfig, &Order{}) + middlewareNoID = Register(shardingConfigNoID, &Order{}) fmt.Println("Clean only tables ...") dropTables() @@ -130,6 +162,10 @@ func init() { user_id bigint, product text )`) + dbNoID.Exec(`CREATE TABLE ` + table + ` ( + user_id bigint, + product text + )`) dbRead.Exec(`CREATE TABLE ` + table + ` ( id bigint PRIMARY KEY, user_id bigint, @@ -143,12 +179,14 @@ func init() { } db.Use(middleware) + dbNoID.Use(middlewareNoID) } func dropTables() { tables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories"} for _, table := range tables { db.Exec("DROP TABLE IF EXISTS " + table) + dbNoID.Exec("DROP TABLE IF EXISTS " + table) dbRead.Exec("DROP TABLE IF EXISTS " + table) dbWrite.Exec("DROP TABLE IF EXISTS " + table) db.Exec(("DROP SEQUENCE IF EXISTS gorm_sharding_" + table + "_id_seq")) @@ -185,6 +223,12 @@ func TestInsert(t *testing.T) { assertQueryResult(t, `INSERT INTO orders_0 ("user_id", "product", "id") VALUES ($1, $2, $3) RETURNING "id"`, tx) } +func TestInsertNoID(t *testing.T) { + dbNoID.Create(&Order{UserID: 100, Product: "iPhone"}) + expected := `INSERT INTO orders_0 ("user_id", "product") VALUES ($1, $2) RETURNING "id"` + assert.Equal(t, toDialect(expected), middlewareNoID.LastQuery()) +} + func TestFillID(t *testing.T) { db.Create(&Order{UserID: 100, Product: "iPhone"}) expected := `INSERT INTO orders_0 ("user_id", "product", id) VALUES` @@ -268,10 +312,17 @@ func TestSelect12(t *testing.T) { } func TestSelect13(t *testing.T) { - tx := db.Raw("SELECT 1").Find(&[]Order{}) + var n int + tx := db.Raw("SELECT 1").Find(&n) assertQueryResult(t, `SELECT 1`, tx) } +func TestSelect14(t *testing.T) { + dbNoID.Model(&Order{}).Where("user_id = 101").Find(&[]Order{}) + expected := `SELECT * FROM orders_1 WHERE user_id = 101` + assert.Equal(t, toDialect(expected), middlewareNoID.LastQuery()) +} + func TestUpdate(t *testing.T) { tx := db.Model(&Order{}).Where("user_id = ?", 100).Update("product", "new title") assertQueryResult(t, `UPDATE orders_0 SET "product" = $1 WHERE user_id = $2`, tx) @@ -327,7 +378,7 @@ func TestNoSharding(t *testing.T) { func TestPKSnowflake(t *testing.T) { var db *gorm.DB if mysqlDialector() { - db, _ = gorm.Open(mysql.Open(databaseURL()), &gorm.Config{ + db, _ = gorm.Open(mysql.Open(dbURL()), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, }) } else { @@ -372,7 +423,7 @@ func TestReadWriteSplitting(t *testing.T) { var db *gorm.DB if mysqlDialector() { - db, _ = gorm.Open(mysql.Open(databaseWriteURL()), &gorm.Config{ + db, _ = gorm.Open(mysql.Open(dbWriteURL()), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, }) } else {