Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevent interface type array from causing runtime errors #7361

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions finisher_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ func (db *DB) Create(value interface{}) (tx *DB) {
if db.CreateBatchSize > 0 {
return db.CreateInBatches(value, db.CreateBatchSize)
}

tx = db.getInstance()
tx.Statement.Dest = value
return tx.callbacks.Create().Execute(tx)
return db.create(value)
}

// CreateInBatches inserts value in batches of batchSize
Expand Down Expand Up @@ -63,12 +60,15 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {

tx.RowsAffected = rowsAffected
default:
tx = db.getInstance()
tx.Statement.Dest = value
tx = tx.callbacks.Create().Execute(tx)
db.create(value)
}
return
}
func (db *DB) create(value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = value
return tx.callbacks.Create().Execute(tx)
}

// Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
func (db *DB) Save(value interface{}) (tx *DB) {
Expand Down
7 changes: 6 additions & 1 deletion scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
switch reflectValueType.Kind() {
case reflect.Array, reflect.Slice:
reflectValueType = reflectValueType.Elem()
if reflectValueType.Kind() == reflect.Interface && reflectValue.Len() > 0 {
reflectValueType = reflect.Indirect(reflectValue.Index(0)).Elem().Type()
}
}
isPtr := reflectValueType.Kind() == reflect.Ptr
if isPtr {
Expand Down Expand Up @@ -318,7 +321,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
} else {
elem = reflect.New(reflectValueType)
}

if elem.Type().Kind() == reflect.Interface {
elem = elem.Elem()
}
db.scanIntoStruct(rows, elem, values, fields, joinFields)

if !update {
Expand Down
4 changes: 3 additions & 1 deletion schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,10 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam

for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
if modelType.Kind() == reflect.Interface && value.Len() > 0 {
modelType = reflect.Indirect(value.Index(0)).Elem().Type()
}
}

if modelType.Kind() != reflect.Struct {
if modelType.PkgPath() == "" {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
Expand Down
Loading