Skip to content

Commit

Permalink
switch to allshards
Browse files Browse the repository at this point in the history
  • Loading branch information
longquanzheng authored Oct 11, 2021
1 parent 7ee9c0d commit 22197f4
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 6 deletions.
31 changes: 29 additions & 2 deletions common/persistence/sql/sqldriver/sharded.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"context"
"database/sql"
"fmt"
"reflect"

"github.com/jmoiron/sqlx"
"go.uber.org/multierr"
Expand Down Expand Up @@ -138,16 +139,42 @@ func (s *sharded) ExecDDL(dbShardID int, query string, args ...interface{}) (sql
}

func (s *sharded) SelectForSchemaQuery(dbShardID int, dest interface{}, query string, args ...interface{}) error {
if dbShardID == sqlplugin.DbShardUndefined || dbShardID == sqlplugin.DbAllShards {
if dbShardID == sqlplugin.DbShardUndefined {
return fmt.Errorf("invalid dbShardID %v shouldn't be used to SelectForSchemaQuery, there must be a bug", dbShardID)
}
if dbShardID == sqlplugin.DbAllShards {
var prevDest interface{}
for idx, db := range s.dbs {
err := db.Select(dest, query, args...)
if err != nil {
return err
}
if prevDest != nil && !reflect.DeepEqual(prevDest, dest) {
return fmt.Errorf("SelectForSchemaQuery fails for multiple database, value(%v) of shard(%v) is not the same as the value(%v) or shard(%v)", dest, idx, prevDest, idx-1)
}
prevDest = dest
}
}
return s.dbs[dbShardID].Select(dest, query, args...)
}

func (s *sharded) GetForSchemaQuery(dbShardID int, dest interface{}, query string, args ...interface{}) error {
if dbShardID == sqlplugin.DbShardUndefined || dbShardID == sqlplugin.DbAllShards {
if dbShardID == sqlplugin.DbShardUndefined {
return fmt.Errorf("invalid dbShardID %v shouldn't be used to Get, there must be a bug", dbShardID)
}
if dbShardID == sqlplugin.DbAllShards {
var prevDest interface{}
for idx, db := range s.dbs {
err := db.Get(dest, query, args...)
if err != nil {
return err
}
if prevDest != nil && !reflect.DeepEqual(prevDest, dest) {
return fmt.Errorf("GetForSchemaQuery fails for multiple database, value(%v) of shard(%v) is not the same as the value(%v) or shard(%v)", dest, idx, prevDest, idx-1)
}
prevDest = dest
}
}
return s.dbs[dbShardID].Get(dest, query, args...)
}

Expand Down
4 changes: 2 additions & 2 deletions common/persistence/sql/sqlplugin/mysql/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (mdb *db) CreateSchemaVersionTables() error {
// ReadSchemaVersion returns the current schema version for the keyspace
func (mdb *db) ReadSchemaVersion(database string) (string, error) {
var version string
err := mdb.driver.GetForSchemaQuery(sqlplugin.DbDefaultShard, &version, readSchemaVersionQuery, database)
err := mdb.driver.GetForSchemaQuery(sqlplugin.DbAllShards, &version, readSchemaVersionQuery, database)
return version, err
}

Expand All @@ -95,7 +95,7 @@ func (mdb *db) ExecSchemaOperationQuery(stmt string, args ...interface{}) error
// ListTables returns a list of tables in this database
func (mdb *db) ListTables(database string) ([]string, error) {
var tables []string
err := mdb.driver.SelectForSchemaQuery(sqlplugin.DbDefaultShard, &tables, fmt.Sprintf(listTablesQuery, database))
err := mdb.driver.SelectForSchemaQuery(sqlplugin.DbAllShards, &tables, fmt.Sprintf(listTablesQuery, database))
return tables, err
}

Expand Down
4 changes: 2 additions & 2 deletions common/persistence/sql/sqlplugin/postgres/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (pdb *db) CreateSchemaVersionTables() error {
// ReadSchemaVersion returns the current schema version for the keyspace
func (pdb *db) ReadSchemaVersion(database string) (string, error) {
var version string
err := pdb.driver.GetForSchemaQuery(sqlplugin.DbDefaultShard, &version, readSchemaVersionQuery, database)
err := pdb.driver.GetForSchemaQuery(sqlplugin.DbAllShards, &version, readSchemaVersionQuery, database)
return version, err
}

Expand All @@ -101,7 +101,7 @@ func (pdb *db) ExecSchemaOperationQuery(stmt string, args ...interface{}) error
// ListTables returns a list of tables in this database
func (pdb *db) ListTables(database string) ([]string, error) {
var tables []string
err := pdb.driver.SelectForSchemaQuery(sqlplugin.DbDefaultShard, &tables, listTablesQuery)
err := pdb.driver.SelectForSchemaQuery(sqlplugin.DbAllShards, &tables, listTablesQuery)
return tables, err
}

Expand Down

0 comments on commit 22197f4

Please sign in to comment.