Skip to content

Commit

Permalink
feat: fix failure when calling SavePolicy within the Transaction meth…
Browse files Browse the repository at this point in the history
…od (#251)

Co-authored-by: junfengxu <[email protected]>
  • Loading branch information
Hill1126 and junfengxu authored Nov 13, 2024
1 parent aef8c1f commit 16aa502
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 3 deletions.
20 changes: 17 additions & 3 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -707,18 +707,17 @@ func (a *Adapter) Transaction(e casbin.IEnforcer, fc func(casbin.IEnforcer) erro
a.transactionMu.Lock()
defer a.transactionMu.Unlock()
var err error
oriAdapter := a.db
// reload policy from database to sync with the transaction
defer func() {
e.SetAdapter(&Adapter{db: oriAdapter, transactionMu: a.transactionMu})
e.SetAdapter(a.Copy())
err = e.LoadPolicy()
if err != nil {
panic(err)
}
}()
copyDB := *a.db
tx := copyDB.Begin(opts...)
b := &Adapter{db: tx, transactionMu: a.transactionMu}
b := a.Copy()
// copy enforcer to set the new adapter with transaction tx
copyEnforcer := e
copyEnforcer.SetAdapter(b)
Expand Down Expand Up @@ -946,6 +945,21 @@ func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [
return oldPolicies, tx.Commit().Error
}

func (a *Adapter) Copy() *Adapter {
oriAdapter := a.db
return &Adapter{
db: oriAdapter,
transactionMu: a.transactionMu,
driverName: a.driverName,
dataSourceName: a.dataSourceName,
databaseName: a.databaseName,
tablePrefix: a.tablePrefix,
tableName: a.tableName,
dbSpecified: a.dbSpecified,
isFiltered: a.isFiltered,
}
}

// Preview Pre-checking to avoid causing partial load success and partial failure deep
func (a *Adapter) Preview(rules *[]CasbinRule, model model.Model) error {
j := 0
Expand Down
30 changes: 30 additions & 0 deletions adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -766,3 +766,33 @@ func TestTransactionRace(t *testing.T) {
require.True(t, e.HasPolicy("jack", fmt.Sprintf("data%d", i), "write"))
}
}

func TestTransactionWithSavePolicy(t *testing.T) {
a := initAdapter(t, "mysql", "root:@tcp(127.0.0.1:3306)/", "casbin", "casbin_rule")
e, _ := casbin.NewEnforcer("examples/rbac_model.conf", a)
defer func() {
e.ClearPolicy()
err := e.SavePolicy()
if err != nil {
t.Fatalf("save policy err %v", err)
}
}()
err := e.GetAdapter().(*Adapter).Transaction(e, func(e casbin.IEnforcer) error {
_, err := e.AddPolicy("jack", "data1", "write")
if err != nil {
return err
}
_, err = e.AddPolicy("jack", "data2", "write")
if err != nil {
return err
}
err = e.SavePolicy()
if err != nil {
return err
}
return nil
})
if err != nil {
return
}
}

0 comments on commit 16aa502

Please sign in to comment.