Skip to content

Commit

Permalink
Merge pull request #68 from olachat/yinloo/support-retry
Browse files Browse the repository at this point in the history
support retry
  • Loading branch information
yinloo-ola authored Feb 5, 2025
2 parents 2b5b84b + 8ecea05 commit 2f27902
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 7 deletions.
103 changes: 103 additions & 0 deletions coredb/engine_ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,63 @@ package coredb
import (
"context"
"database/sql"
"errors"
"fmt"
"log"
"strings"
"time"
)

type IsNonRetryableErrorFunc func(err error) bool

// RetryConfig encapsulates retry parameters.
type RetryConfig struct {
MaxRetries int
InitialBackoff time.Duration
IsNonRetryableErrorFunc IsNonRetryableErrorFunc
}

// DefaultRetryConfig provides a reasonable default configuration
var DefaultRetryConfig = RetryConfig{
MaxRetries: 5,
InitialBackoff: 200 * time.Millisecond,
IsNonRetryableErrorFunc: IsNonRetryableError,
}

// IsNonRetryableError checks if an error is non-retryable.
func IsNonRetryableError(err error) bool {
if err == nil {
return false
}
// Example (Replace with your database's non-retryable errors)

// SQL specific errors that are not retryable
if errors.Is(err, sql.ErrNoRows) {
return true
}

// Example: Invalid SQL syntax
if strings.Contains(err.Error(), "syntax error") {
return true
}

if strings.Contains(err.Error(), "1146") { // Table doesn't exists
return true
}
if strings.Contains(err.Error(), "1064") { // No database selected
return true
}
if strings.Contains(err.Error(), "1149") { // Invalid SQL statement
return true
}
// Example: Authentication issues
if strings.Contains(err.Error(), "Access denied") {
return true
}

return false // Default is retryable
}

// FetchByPKCtx returns a row of T type with given primary key value
func FetchByPKCtx[T any](ctx context.Context, dbname string, tableName string, pkName []string, val ...any) (*T, error) {
sql := "WHERE `" + pkName[0] + "` = ?"
Expand Down Expand Up @@ -57,6 +110,56 @@ func ExecCtx(ctx context.Context, dbname string, query string, params ...any) (s
return mydb.ExecContext(ctx, query, params...)
}

// ExecWithRetry executes a query with retry logic on failure.
func ExecWithRetry(ctx context.Context, dbname string, query string, retryConfig RetryConfig, params ...any) (sql.Result, error) {
// Set defaults for invalid config
if retryConfig.MaxRetries <= 0 {
retryConfig.MaxRetries = DefaultRetryConfig.MaxRetries
}

if retryConfig.InitialBackoff <= 0 {
retryConfig.InitialBackoff = DefaultRetryConfig.InitialBackoff
}

// Use the default if NonRetryableErrorFunc is nil
nonRetryableErrorFunc := retryConfig.IsNonRetryableErrorFunc
if nonRetryableErrorFunc == nil {
nonRetryableErrorFunc = IsNonRetryableError
}

var result sql.Result
var err error
retryCount := 0
currentBackoff := retryConfig.InitialBackoff

for {
select {
case <-ctx.Done():
return result, fmt.Errorf("context cancelled during retry: %w", ctx.Err())
default:
result, err = ExecCtx(ctx, dbname, query, params...)
if err == nil {
return result, nil // Success!
}

if nonRetryableErrorFunc(err) {
return result, err // Fail immediately for non-retryable errors
}

retryCount++
if retryCount > retryConfig.MaxRetries {
log.Printf("Max retries (%d) exceeded for: %s, last error: %v", retryConfig.MaxRetries, query, err)
return result, fmt.Errorf("max retries exceeded, last error: %w", err)
}

delay := currentBackoff
log.Printf("Retrying attempt %d with delay %v. Last error: %v", retryCount, delay, err)
time.Sleep(delay)
currentBackoff *= 2
}
}
}

// FindOneCtx returns a row from given table type with where query.
// If no rows found, *T will be nil. No error will be returned.
func FindOneCtx[T any](ctx context.Context, dbname string, tableName string, where WhereQuery) (*T, error) {
Expand Down
50 changes: 50 additions & 0 deletions coredb/txengine/tx_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,56 @@ func RunTransaction(ctx context.Context, dbName string, fn func(ctx context.Cont
return
}

// RunTxWithRetry runs a transaction with retry logic on failure.
func RunTxWithRetry(ctx context.Context, dbName string, retryConfig coredb.RetryConfig, fn func(ctx context.Context, sqlTx *sql.Tx) error) (err error) {
// Set defaults for invalid config
if retryConfig.MaxRetries <= 0 {
retryConfig.MaxRetries = coredb.DefaultRetryConfig.MaxRetries
}

if retryConfig.InitialBackoff <= 0 {
retryConfig.InitialBackoff = coredb.DefaultRetryConfig.InitialBackoff
}

nonRetryableErrorFunc := retryConfig.IsNonRetryableErrorFunc
if nonRetryableErrorFunc == nil {
nonRetryableErrorFunc = coredb.IsNonRetryableError
}

var resultErr error
retryCount := 0
currentBackoff := retryConfig.InitialBackoff

for {
select {
case <-ctx.Done():
return fmt.Errorf("context cancelled during retry: %w", ctx.Err())
default:
resultErr = RunTransaction(ctx, dbName, fn)
if resultErr == nil {
return nil // Success!
}

if nonRetryableErrorFunc(resultErr) {
log.Printf("Non-retryable error: %v", resultErr)
return resultErr // Fail immediately for non-retryable errors
}

retryCount++
if retryCount > retryConfig.MaxRetries {
log.Printf("Max retries (%d) exceeded, last error: %v", retryConfig.MaxRetries, resultErr)
return fmt.Errorf("max retries exceeded, last error: %w", resultErr)

}

delay := currentBackoff
log.Printf("Retrying attempt %d with delay %v. Last error: %v", retryCount, delay, resultErr)
time.Sleep(delay)
currentBackoff *= 2
}
}
}

func runTransaction(ctx context.Context, tx *sql.Tx, conn *sql.Conn, fn func(ctx context.Context, sqlTx *sql.Tx) error) (err error) {
if tx == nil && conn == nil {
return errors.New("wrong usage. tx and conn cannot both be nil")
Expand Down
33 changes: 33 additions & 0 deletions tests/retry_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package tests

import (
"context"
"testing"
"time"

_ "github.com/go-sql-driver/mysql"
"github.com/olachat/gola/v2/coredb"
)

func TestExecWithRetry_Success(t *testing.T) {
ctx := context.Background()

_, err := coredb.ExecWithRetry(ctx, testDBName, "INSERT INTO test_table (name, email) VALUES (?, ?)", coredb.DefaultRetryConfig, "test", "[email protected]")
if err != nil {
t.Fatalf("Expected success, but got error: %v", err)
}
}

func TestExecWithRetry_Fail(t *testing.T) {
ctx := context.Background()
now := time.Now()
_, err := coredb.ExecWithRetry(ctx, testDBName, "INSERT INTO no_such_table (name, email) VALUES (?, ?)", coredb.DefaultRetryConfig, "test", "[email protected]")
if err == nil {
t.Fatalf("Expected error, but got success")
}
elapsed := time.Since(now)
if elapsed < (200+400+800+1600+3200)*time.Millisecond {
t.Fatalf("Expected retry to take at least 100ms, but took: %v", elapsed)
}
t.Logf("retry took: %v", elapsed)
}
16 changes: 9 additions & 7 deletions tests/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,18 @@ func init() {
panic(err)
}

// realdb, err := open()

// if err != nil {
// panic(err)
// }

coredb.Setup(func(dbname string, mode coredb.DBMode) *sql.DB {
return db
if dbname == testDBName {
return db
}
return nil
})

_, err = db.Exec("CREATE TABLE IF NOT EXISTS test_table (name VARCHAR(255), email VARCHAR(255))")
if err != nil {
panic(err)
}

// create tables
for _, tableName := range tableNames {
query, _ := testdata.Fixtures.ReadFile(tableName + ".sql")
Expand Down

0 comments on commit 2f27902

Please sign in to comment.