diff --git a/sharding.go b/sharding.go index de71c53..ad774f5 100644 --- a/sharding.go +++ b/sharding.go @@ -33,6 +33,8 @@ type Sharding struct { _config Config _tables []any + + mutex sync.RWMutex } // Config specifies the configuration for sharding. @@ -266,8 +268,12 @@ func (s *Sharding) switchConn(db *gorm.DB) { // When DoubleWrite is enabled, we need to query database schema // information by table name during the migration. if _, ok := db.Get(ShardingIgnoreStoreKey); !ok { - s.ConnPool = &ConnPool{ConnPool: db.Statement.ConnPool, sharding: s} - db.Statement.ConnPool = s.ConnPool + s.mutex.Lock() + if db.Statement.ConnPool != nil { + s.ConnPool = &ConnPool{ConnPool: db.Statement.ConnPool, sharding: s} + db.Statement.ConnPool = s.ConnPool + } + s.mutex.Unlock() } } diff --git a/sharding_test.go b/sharding_test.go index 2f1057b..bdebbe1 100644 --- a/sharding_test.go +++ b/sharding_test.go @@ -1,6 +1,7 @@ package sharding import ( + "context" "fmt" "os" "regexp" @@ -8,6 +9,7 @@ import ( "strconv" "strings" "testing" + "time" "github.com/bwmarrin/snowflake" "github.com/longbridgeapp/assert" @@ -448,6 +450,36 @@ func TestReadWriteSplitting(t *testing.T) { assert.Equal(t, "iPhone", order.Product) } +func TestDataRace(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + ch := make(chan error) + + for i := 0; i < 2; i++ { + go func() { + for { + select { + case <-ctx.Done(): + return + default: + err := db.Model(&Order{}).Where("user_id", 100).Find(&[]Order{}).Error + if err != nil { + ch <- err + return + } + } + } + }() + } + + select { + case <-time.After(time.Millisecond * 50): + cancel() + case err := <-ch: + cancel() + t.Fatal(err) + } +} + func assertQueryResult(t *testing.T, expected string, tx *gorm.DB) { t.Helper() assert.Equal(t, toDialect(expected), middleware.LastQuery())