Skip to content

Commit

Permalink
feat: support default UpdateCallback (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rainshaw authored Sep 4, 2022
1 parent 0e70bf2 commit 06fcda7
Show file tree
Hide file tree
Showing 6 changed files with 289 additions and 220 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ func main() {

// Set callback to local example
_ = w.SetUpdateCallback(updateCallback)

// Or use the default callback
// _ = w.SetUpdateCallback(rediswatcher.DefaultUpdateCallback(e))

// Update the policy to test the effect.
// You should see "[casbin rules updated]" in the log.
Expand All @@ -77,4 +80,3 @@ func main() {
## License

This project is under Apache 2.0 License. See the [LICENSE](LICENSE) file for the full license text.
>>>>>>> 243bd42 (refactor)
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module github.com/casbin/redis-watcher/v2
go 1.18

require (
github.com/casbin/casbin/v2 v2.47.2
github.com/casbin/casbin/v2 v2.53.1
github.com/go-redis/redis/v8 v8.11.5
github.com/google/uuid v1.3.0
)
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible h1:1G1pk05UrOh0NlF1oeaaix1x8XzrfjIDK47TY0Zehcw=
github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0=
github.com/casbin/casbin/v2 v2.47.2 h1:FVdlX0GEYWpYj7IdSThBpidLr8Bp+yfvlmVNec5INtw=
github.com/casbin/casbin/v2 v2.47.2/go.mod h1:vByNa/Fchek0KZUgG5wEsl7iFsiviAYKRtgrQfcJqHg=
github.com/casbin/casbin/v2 v2.53.1 h1:uD/1LMHEPOkn1Xw5UmLnOJxdBPI7Zz85VbdPLJhivxo=
github.com/casbin/casbin/v2 v2.53.1/go.mod h1:vByNa/Fchek0KZUgG5wEsl7iFsiviAYKRtgrQfcJqHg=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
Expand Down
55 changes: 0 additions & 55 deletions util.go

This file was deleted.

151 changes: 132 additions & 19 deletions watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
"strings"
"sync"

"github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/model"

"github.com/casbin/casbin/v2/persist"
rds "github.com/go-redis/redis/v8"
)
Expand All @@ -25,14 +25,66 @@ type Watcher struct {
ctx context.Context
}

func DefaultUpdateCallback(e casbin.IEnforcer) func(string) {
return func(msg string) {
msgStruct := &MSG{}

err := msgStruct.UnmarshalBinary([]byte(msg))
if err != nil {
log.Println(err)
return
}

var res bool
switch msgStruct.Method {
case Update, UpdateForSavePolicy:
err = e.LoadPolicy()
res = true
case UpdateForAddPolicy:
res, err = e.SelfAddPolicy(msgStruct.Sec, msgStruct.Ptype, msgStruct.Rule)
case UpdateForAddPolicies:
res, err = e.SelfAddPolicies(msgStruct.Sec, msgStruct.Ptype, msgStruct.Rules)
case UpdateForRemovePolicy:
res, err = e.SelfRemovePolicy(msgStruct.Sec, msgStruct.Ptype, msgStruct.Rule)
case UpdateForRemoveFilteredPolicy:
res, err = e.SelfRemoveFilteredPolicy(msgStruct.Sec, msgStruct.Ptype, msgStruct.FieldIndex, msgStruct.FieldValues...)
case UpdateForRemovePolicies:
res, err = e.SelfRemovePolicies(msgStruct.Sec, msgStruct.Ptype, msgStruct.Rules)
default:
err = errors.New("unknown update type")
}
if err != nil {
log.Println(err)
}
if !res {
log.Println("callback update policy failed")
}
}
}

type MSG struct {
Method string
ID string
Sec string
Ptype string
Params interface{}
Method UpdateType
ID string
Sec string
Ptype string
Rule []string
Rules [][]string
FieldIndex int
FieldValues []string
}

type UpdateType string

const (
Update UpdateType = "Update"
UpdateForAddPolicy UpdateType = "UpdateForAddPolicy"
UpdateForRemovePolicy UpdateType = "UpdateForRemovePolicy"
UpdateForRemoveFilteredPolicy UpdateType = "UpdateForRemoveFilteredPolicy"
UpdateForSavePolicy UpdateType = "UpdateForSavePolicy"
UpdateForAddPolicies UpdateType = "UpdateForAddPolicies"
UpdateForRemovePolicies UpdateType = "UpdateForRemovePolicies"
)

func (m *MSG) MarshalBinary() ([]byte, error) {
return json.Marshal(m)
}
Expand Down Expand Up @@ -105,7 +157,10 @@ func NewWatcherWithCluster(addrs string, option WatcherOptions) (persist.Watcher
close: make(chan struct{}),
}

w.initConfig(option, true)
err := w.initConfig(option, true)
if err != nil {
return nil, err
}

if err := w.subClient.Ping(w.ctx).Err(); err != nil {
return nil, err
Expand Down Expand Up @@ -171,7 +226,7 @@ func NewPublishWatcher(addr string, option WatcherOptions) (persist.Watcher, err
return w, nil
}

// SetUpdateCallback SetUpdateCallBack sets the update callback function invoked by the watcher
// SetUpdateCallback sets the update callback function invoked by the watcher
// when the policy is updated. Defaults to Enforcer.LoadPolicy()
func (w *Watcher) SetUpdateCallback(callback func(string)) error {
w.l.Lock()
Expand All @@ -186,7 +241,14 @@ func (w *Watcher) Update() error {
return w.logRecord(func() error {
w.l.Lock()
defer w.l.Unlock()
return w.pubClient.Publish(context.Background(), w.options.Channel, &MSG{"Update", w.options.LocalID, "", "", ""}).Err()
return w.pubClient.Publish(
context.Background(),
w.options.Channel,
&MSG{
Method: Update,
ID: w.options.LocalID,
},
).Err()
})
}

Expand All @@ -196,7 +258,16 @@ func (w *Watcher) UpdateForAddPolicy(sec, ptype string, params ...string) error
return w.logRecord(func() error {
w.l.Lock()
defer w.l.Unlock()
return w.pubClient.Publish(context.Background(), w.options.Channel, &MSG{"UpdateForAddPolicy", w.options.LocalID, sec, ptype, params}).Err()
return w.pubClient.Publish(
context.Background(),
w.options.Channel,
&MSG{
Method: UpdateForAddPolicy,
ID: w.options.LocalID,
Sec: sec,
Ptype: ptype,
Rule: params,
}).Err()
})
}

Expand All @@ -206,7 +277,17 @@ func (w *Watcher) UpdateForRemovePolicy(sec, ptype string, params ...string) err
return w.logRecord(func() error {
w.l.Lock()
defer w.l.Unlock()
return w.pubClient.Publish(context.Background(), w.options.Channel, &MSG{"UpdateForRemovePolicy", w.options.LocalID, sec, ptype, params}).Err()
return w.pubClient.Publish(
context.Background(),
w.options.Channel,
&MSG{
Method: UpdateForRemovePolicy,
ID: w.options.LocalID,
Sec: sec,
Ptype: ptype,
Rule: params,
},
).Err()
})
}

Expand All @@ -216,11 +297,16 @@ func (w *Watcher) UpdateForRemoveFilteredPolicy(sec, ptype string, fieldIndex in
return w.logRecord(func() error {
w.l.Lock()
defer w.l.Unlock()
return w.pubClient.Publish(context.Background(), w.options.Channel,
&MSG{"UpdateForRemoveFilteredPolicy", w.options.LocalID,
sec,
ptype,
fmt.Sprintf("%d %s", fieldIndex, strings.Join(fieldValues, " ")),
return w.pubClient.Publish(
context.Background(),
w.options.Channel,
&MSG{
Method: UpdateForRemoveFilteredPolicy,
ID: w.options.LocalID,
Sec: sec,
Ptype: ptype,
FieldIndex: fieldIndex,
FieldValues: fieldValues,
},
).Err()
})
Expand All @@ -232,7 +318,14 @@ func (w *Watcher) UpdateForSavePolicy(model model.Model) error {
return w.logRecord(func() error {
w.l.Lock()
defer w.l.Unlock()
return w.pubClient.Publish(context.Background(), w.options.Channel, &MSG{"UpdateForSavePolicy", w.options.LocalID, "", "", model}).Err()
return w.pubClient.Publish(
context.Background(),
w.options.Channel,
&MSG{
Method: UpdateForSavePolicy,
ID: w.options.LocalID,
},
).Err()
})
}

Expand All @@ -242,7 +335,17 @@ func (w *Watcher) UpdateForAddPolicies(sec string, ptype string, rules ...[]stri
return w.logRecord(func() error {
w.l.Lock()
defer w.l.Unlock()
return w.pubClient.Publish(context.Background(), w.options.Channel, &MSG{"UpdateForAddPolicies", w.options.LocalID, sec, ptype, rules}).Err()
return w.pubClient.Publish(
context.Background(),
w.options.Channel,
&MSG{
Method: UpdateForAddPolicies,
ID: w.options.LocalID,
Sec: sec,
Ptype: ptype,
Rules: rules,
},
).Err()
})
}

Expand All @@ -252,7 +355,17 @@ func (w *Watcher) UpdateForRemovePolicies(sec string, ptype string, rules ...[]s
return w.logRecord(func() error {
w.l.Lock()
defer w.l.Unlock()
return w.pubClient.Publish(context.Background(), w.options.Channel, &MSG{"UpdateForRemovePolicies", w.options.LocalID, sec, ptype, rules}).Err()
return w.pubClient.Publish(
context.Background(),
w.options.Channel,
&MSG{
Method: UpdateForRemovePolicies,
ID: w.options.LocalID,
Sec: sec,
Ptype: ptype,
Rules: rules,
},
).Err()
})
}

Expand Down
Loading

0 comments on commit 06fcda7

Please sign in to comment.