From 649eb598e221e72ced5a08d24448717f22aba76c Mon Sep 17 00:00:00 2001 From: Aditya Kajla Date: Wed, 21 Jun 2023 16:18:02 -0700 Subject: [PATCH] Implement 'wookie' protocol (#163) * Create zookie-like 'wookie' implementation for specifying consistency/freshness on reads * Persist wookies on updates and send back in responses * Fix postgres createWookie query * Send back latest wookie in all check/read calls * Simplify wookie service methods * Check service Check() to use wookie read * Implement wookie service methods * Allow wookieSafeReads() within WithWookieUpdate() * Don't err out if wookie values not present * Add basic wookie token unit test * Address cr comments --- cmd/warrant/main.go | 20 +- .../mysql/000005_create_wookie_table.down.sql | 5 + .../mysql/000005_create_wookie_table.up.sql | 10 + .../000006_create_wookie_table.down.sql | 5 + .../000006_create_wookie_table.up.sql | 9 + .../000005_create_wookie_table.down.sql | 1 + .../sqlite/000005_create_wookie_table.up.sql | 5 + pkg/authz/check/handlers.go | 16 +- pkg/authz/check/service.go | 347 ++++++++++-------- pkg/authz/feature/handlers.go | 4 +- pkg/authz/feature/service.go | 10 +- pkg/authz/object/handlers.go | 4 +- pkg/authz/object/service.go | 11 +- pkg/authz/objecttype/handlers.go | 22 +- pkg/authz/objecttype/service.go | 172 +++++---- pkg/authz/permission/handlers.go | 4 +- pkg/authz/permission/service.go | 10 +- pkg/authz/pricingtier/handlers.go | 4 +- pkg/authz/pricingtier/service.go | 10 +- pkg/authz/role/handlers.go | 4 +- pkg/authz/role/service.go | 10 +- pkg/authz/tenant/handlers.go | 4 +- pkg/authz/tenant/service.go | 10 +- pkg/authz/user/handlers.go | 4 +- pkg/authz/user/service.go | 11 +- pkg/authz/warrant/handlers.go | 11 +- pkg/authz/warrant/service.go | 67 ++-- pkg/authz/wookie/http.go | 34 ++ pkg/authz/wookie/model.go | 36 ++ pkg/authz/wookie/mysql.go | 76 ++++ pkg/authz/wookie/postgres.go | 75 ++++ pkg/authz/wookie/repository.go | 42 +++ pkg/authz/wookie/service.go | 158 ++++++++ pkg/authz/wookie/sqlite.go | 76 ++++ pkg/authz/wookie/token.go | 60 +++ pkg/authz/wookie/token_test.go | 68 ++++ 36 files changed, 1104 insertions(+), 311 deletions(-) create mode 100644 migrations/datastore/mysql/000005_create_wookie_table.down.sql create mode 100644 migrations/datastore/mysql/000005_create_wookie_table.up.sql create mode 100644 migrations/datastore/postgres/000006_create_wookie_table.down.sql create mode 100644 migrations/datastore/postgres/000006_create_wookie_table.up.sql create mode 100644 migrations/datastore/sqlite/000005_create_wookie_table.down.sql create mode 100644 migrations/datastore/sqlite/000005_create_wookie_table.up.sql create mode 100644 pkg/authz/wookie/http.go create mode 100644 pkg/authz/wookie/model.go create mode 100644 pkg/authz/wookie/mysql.go create mode 100644 pkg/authz/wookie/postgres.go create mode 100644 pkg/authz/wookie/repository.go create mode 100644 pkg/authz/wookie/service.go create mode 100644 pkg/authz/wookie/sqlite.go create mode 100644 pkg/authz/wookie/token.go create mode 100644 pkg/authz/wookie/token_test.go diff --git a/cmd/warrant/main.go b/cmd/warrant/main.go index 27206a07..6bf1c054 100644 --- a/cmd/warrant/main.go +++ b/cmd/warrant/main.go @@ -18,6 +18,7 @@ import ( tenant "github.com/warrant-dev/warrant/pkg/authz/tenant" user "github.com/warrant-dev/warrant/pkg/authz/user" warrant "github.com/warrant-dev/warrant/pkg/authz/warrant" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/config" "github.com/warrant-dev/warrant/pkg/database" "github.com/warrant-dev/warrant/pkg/event" @@ -25,11 +26,11 @@ import ( ) const ( - MySQLDatastoreMigrationVersion = 000004 + MySQLDatastoreMigrationVersion = 000005 MySQLEventstoreMigrationVersion = 000003 - PostgresDatastoreMigrationVersion = 000005 + PostgresDatastoreMigrationVersion = 000006 PostgresEventstoreMigrationVersion = 000004 - SQLiteDatastoreMigrationVersion = 000004 + SQLiteDatastoreMigrationVersion = 000005 SQLiteEventstoreMigrationVersion = 000003 ) @@ -195,22 +196,29 @@ func main() { } eventSvc := event.NewService(svcEnv, eventRepository, cfg.Eventstore.SynchronizeEvents, nil) + // Init wookie repo and service + wookieRepository, err := wookie.NewRepository(svcEnv.DB()) + if err != nil { + log.Fatal().Err(err).Msg("Could not initialize WookieRepository") + } + wookieSvc := wookie.NewService(svcEnv, wookieRepository) + // Init object type repo and service objectTypeRepository, err := objecttype.NewRepository(svcEnv.DB()) if err != nil { log.Fatal().Err(err).Msg("Could not initialize ObjectTypeRepository") } - objectTypeSvc := objecttype.NewService(svcEnv, objectTypeRepository, eventSvc) + objectTypeSvc := objecttype.NewService(svcEnv, objectTypeRepository, eventSvc, wookieSvc) // Init warrant repo and service warrantRepository, err := warrant.NewRepository(svcEnv.DB()) if err != nil { log.Fatal().Err(err).Msg("Could not initialize WarrantRepository") } - warrantSvc := warrant.NewService(svcEnv, warrantRepository, eventSvc, objectTypeSvc) + warrantSvc := warrant.NewService(svcEnv, warrantRepository, eventSvc, objectTypeSvc, wookieSvc) // Init check service - checkSvc := check.NewService(svcEnv, warrantRepository, eventSvc, objectTypeSvc) + checkSvc := check.NewService(svcEnv, warrantRepository, eventSvc, objectTypeSvc, wookieSvc) // Init object repo and service objectRepository, err := object.NewRepository(svcEnv.DB()) diff --git a/migrations/datastore/mysql/000005_create_wookie_table.down.sql b/migrations/datastore/mysql/000005_create_wookie_table.down.sql new file mode 100644 index 00000000..665e5da4 --- /dev/null +++ b/migrations/datastore/mysql/000005_create_wookie_table.down.sql @@ -0,0 +1,5 @@ +BEGIN; + +DROP TABLE IF EXISTS wookie; + +COMMIT; diff --git a/migrations/datastore/mysql/000005_create_wookie_table.up.sql b/migrations/datastore/mysql/000005_create_wookie_table.up.sql new file mode 100644 index 00000000..8679fa7d --- /dev/null +++ b/migrations/datastore/mysql/000005_create_wookie_table.up.sql @@ -0,0 +1,10 @@ +BEGIN; + +CREATE TABLE IF NOT EXISTS wookie ( + id bigint NOT NULL AUTO_INCREMENT, + ver bigint NOT NULL, + createdAt timestamp(6) NULL DEFAULT CURRENT_TIMESTAMP(6), + PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; + +COMMIT; diff --git a/migrations/datastore/postgres/000006_create_wookie_table.down.sql b/migrations/datastore/postgres/000006_create_wookie_table.down.sql new file mode 100644 index 00000000..665e5da4 --- /dev/null +++ b/migrations/datastore/postgres/000006_create_wookie_table.down.sql @@ -0,0 +1,5 @@ +BEGIN; + +DROP TABLE IF EXISTS wookie; + +COMMIT; diff --git a/migrations/datastore/postgres/000006_create_wookie_table.up.sql b/migrations/datastore/postgres/000006_create_wookie_table.up.sql new file mode 100644 index 00000000..1e24d157 --- /dev/null +++ b/migrations/datastore/postgres/000006_create_wookie_table.up.sql @@ -0,0 +1,9 @@ +BEGIN; + +CREATE TABLE IF NOT EXISTS wookie ( + id bigserial PRIMARY KEY, + ver bigserial, + created_at timestamp(6) NULL DEFAULT CURRENT_TIMESTAMP(6) +); + +COMMIT; diff --git a/migrations/datastore/sqlite/000005_create_wookie_table.down.sql b/migrations/datastore/sqlite/000005_create_wookie_table.down.sql new file mode 100644 index 00000000..81b92eca --- /dev/null +++ b/migrations/datastore/sqlite/000005_create_wookie_table.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS wookie; diff --git a/migrations/datastore/sqlite/000005_create_wookie_table.up.sql b/migrations/datastore/sqlite/000005_create_wookie_table.up.sql new file mode 100644 index 00000000..0c4a0bcb --- /dev/null +++ b/migrations/datastore/sqlite/000005_create_wookie_table.up.sql @@ -0,0 +1,5 @@ +CREATE TABLE IF NOT EXISTS wookie ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + ver INTEGER, + createdAt DATETIME DEFAULT CURRENT_TIMESTAMP +); diff --git a/pkg/authz/check/handlers.go b/pkg/authz/check/handlers.go index c5807cd5..a4fd03af 100644 --- a/pkg/authz/check/handlers.go +++ b/pkg/authz/check/handlers.go @@ -5,6 +5,7 @@ import ( objecttype "github.com/warrant-dev/warrant/pkg/authz/objecttype" warrant "github.com/warrant-dev/warrant/pkg/authz/warrant" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/service" ) @@ -12,9 +13,12 @@ func (svc CheckService) Routes() ([]service.Route, error) { return []service.Route{ // Standard Authorization service.WarrantRoute{ - Pattern: "/v2/authorize", - Method: "POST", - Handler: service.NewRouteHandler(svc, AuthorizeHandler), + Pattern: "/v2/authorize", + Method: "POST", + Handler: service.ChainMiddleware( + service.NewRouteHandler(svc, AuthorizeHandler), + wookie.ClientTokenMiddleware, + ), OverrideAuthMiddlewareFunc: service.ApiKeyAndSessionAuthMiddleware, }, }, nil @@ -49,10 +53,11 @@ func AuthorizeHandler(svc CheckService, w http.ResponseWriter, r *http.Request) Debug: sessionCheckManySpec.Debug, } - checkResult, err := svc.CheckMany(r.Context(), authInfo, &checkManySpec) + checkResult, updatedWookie, err := svc.CheckMany(r.Context(), authInfo, &checkManySpec) if err != nil { return err } + wookie.AddAsResponseHeader(w, updatedWookie) service.SendJSONResponse(w, checkResult) return nil @@ -64,10 +69,11 @@ func AuthorizeHandler(svc CheckService, w http.ResponseWriter, r *http.Request) return err } - checkResult, err := svc.CheckMany(r.Context(), authInfo, &checkManySpec) + checkResult, updatedWookie, err := svc.CheckMany(r.Context(), authInfo, &checkManySpec) if err != nil { return err } + wookie.AddAsResponseHeader(w, updatedWookie) service.SendJSONResponse(w, checkResult) return nil diff --git a/pkg/authz/check/service.go b/pkg/authz/check/service.go index 6f58af1f..ce3418c6 100644 --- a/pkg/authz/check/service.go +++ b/pkg/authz/check/service.go @@ -8,6 +8,7 @@ import ( "github.com/rs/zerolog/log" objecttype "github.com/warrant-dev/warrant/pkg/authz/objecttype" warrant "github.com/warrant-dev/warrant/pkg/authz/warrant" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/event" "github.com/warrant-dev/warrant/pkg/service" ) @@ -17,14 +18,16 @@ type CheckService struct { WarrantRepository warrant.WarrantRepository EventSvc event.EventService ObjectTypeSvc objecttype.ObjectTypeService + WookieSvc wookie.WookieService } -func NewService(env service.Env, warrantRepo warrant.WarrantRepository, eventSvc event.EventService, objectTypeSvc objecttype.ObjectTypeService) CheckService { +func NewService(env service.Env, warrantRepo warrant.WarrantRepository, eventSvc event.EventService, objectTypeSvc objecttype.ObjectTypeService, wookieSvc wookie.WookieService) CheckService { return CheckService{ BaseService: service.NewBaseService(env), WarrantRepository: warrantRepo, EventSvc: eventSvc, ObjectTypeSvc: objectTypeSvc, + WookieSvc: wookieSvc, } } @@ -56,7 +59,7 @@ func (svc CheckService) getMatchingSubjects(ctx context.Context, objectType stri log.Ctx(ctx).Debug().Msgf("Getting matching subjects for %s:%s#%s@___%s", objectType, objectId, relation, checkCtx) warrantSpecs := make([]warrant.WarrantSpec, 0) - objectTypeSpec, err := svc.ObjectTypeSvc.GetByTypeId(ctx, objectType) + objectTypeSpec, _, err := svc.ObjectTypeSvc.GetByTypeId(ctx, objectType) if err != nil { return warrantSpecs, err } @@ -96,7 +99,7 @@ func (svc CheckService) getMatchingSubjectsBySubjectType(ctx context.Context, ob log.Ctx(ctx).Debug().Msgf("Getting matching subjects for %s:%s#%s@%s:___%s", objectType, objectId, relation, subjectType, checkCtx) warrantSpecs := make([]warrant.WarrantSpec, 0) - objectTypeSpec, err := svc.ObjectTypeSvc.GetByTypeId(ctx, objectType) + objectTypeSpec, _, err := svc.ObjectTypeSvc.GetByTypeId(ctx, objectType) if err != nil { return warrantSpecs, err } @@ -187,7 +190,7 @@ func (svc CheckService) checkRule(ctx context.Context, authInfo *service.AuthInf return true, decisionPath, nil default: if rule.OfType == "" && rule.WithRelation == "" { - return svc.Check(ctx, authInfo, CheckSpec{ + match, decisionPath, _, err := svc.Check(ctx, authInfo, CheckSpec{ CheckWarrantSpec: CheckWarrantSpec{ ObjectType: warrantSpec.ObjectType, ObjectId: warrantSpec.ObjectId, @@ -197,6 +200,7 @@ func (svc CheckService) checkRule(ctx context.Context, authInfo *service.AuthInf }, Debug: warrantCheck.Debug, }) + return match, decisionPath, err } matchingWarrants, err := svc.getMatchingSubjectsBySubjectType(ctx, warrantSpec.ObjectType, warrantSpec.ObjectId, rule.WithRelation, rule.OfType, warrantSpec.Context) @@ -205,7 +209,7 @@ func (svc CheckService) checkRule(ctx context.Context, authInfo *service.AuthInf } for _, matchingWarrant := range matchingWarrants { - match, decisionPath, err := svc.Check(ctx, authInfo, CheckSpec{ + match, decisionPath, _, err := svc.Check(ctx, authInfo, CheckSpec{ CheckWarrantSpec: CheckWarrantSpec{ ObjectType: matchingWarrant.Subject.ObjectType, ObjectId: matchingWarrant.Subject.ObjectId, @@ -229,157 +233,164 @@ func (svc CheckService) checkRule(ctx context.Context, authInfo *service.AuthInf } } -func (svc CheckService) CheckMany(ctx context.Context, authInfo *service.AuthInfo, warrantCheck *CheckManySpec) (*CheckResultSpec, error) { +func (svc CheckService) CheckMany(ctx context.Context, authInfo *service.AuthInfo, warrantCheck *CheckManySpec) (*CheckResultSpec, *wookie.Token, error) { start := time.Now().UTC() if warrantCheck.Op != "" && warrantCheck.Op != objecttype.InheritIfAllOf && warrantCheck.Op != objecttype.InheritIfAnyOf { - return nil, service.NewInvalidParameterError("op", "must be either anyOf or allOf") + return nil, nil, service.NewInvalidParameterError("op", "must be either anyOf or allOf") } var checkResult CheckResultSpec checkResult.DecisionPath = make(map[string][]warrant.WarrantSpec, 0) - if warrantCheck.Op == objecttype.InheritIfAllOf { - var processingTime int64 - for _, warrantSpec := range warrantCheck.Warrants { - match, decisionPath, err := svc.Check(ctx, authInfo, CheckSpec{ - CheckWarrantSpec: warrantSpec, - Debug: warrantCheck.Debug, - }) - if err != nil { - return nil, err - } - if warrantCheck.Debug { - checkResult.ProcessingTime = processingTime + time.Since(start).Milliseconds() - if len(decisionPath) > 0 { - checkResult.DecisionPath[warrantSpec.String()] = decisionPath + newWookie, e := svc.WookieSvc.WookieSafeRead(ctx, func(wkCtx context.Context) error { + if warrantCheck.Op == objecttype.InheritIfAllOf { + var processingTime int64 + for _, warrantSpec := range warrantCheck.Warrants { + match, decisionPath, _, err := svc.Check(wkCtx, authInfo, CheckSpec{ + CheckWarrantSpec: warrantSpec, + Debug: warrantCheck.Debug, + }) + if err != nil { + return err } - } - var eventMeta map[string]interface{} - if warrantSpec.Context != nil { - eventMeta = make(map[string]interface{}) - eventMeta["context"] = warrantSpec.Context - } + if warrantCheck.Debug { + checkResult.ProcessingTime = processingTime + time.Since(start).Milliseconds() + if len(decisionPath) > 0 { + checkResult.DecisionPath[warrantSpec.String()] = decisionPath + } + } - if !match { - err = svc.EventSvc.TrackAccessDeniedEvent(ctx, warrantSpec.ObjectType, warrantSpec.ObjectId, warrantSpec.Relation, warrantSpec.Subject.ObjectType, warrantSpec.Subject.ObjectId, warrantSpec.Subject.Relation, eventMeta) - if err != nil { - return nil, err + var eventMeta map[string]interface{} + if warrantSpec.Context != nil { + eventMeta = make(map[string]interface{}) + eventMeta["context"] = warrantSpec.Context } - checkResult.Code = http.StatusForbidden - checkResult.Result = NotAuthorized - return &checkResult, nil - } + if !match { + err = svc.EventSvc.TrackAccessDeniedEvent(wkCtx, warrantSpec.ObjectType, warrantSpec.ObjectId, warrantSpec.Relation, warrantSpec.Subject.ObjectType, warrantSpec.Subject.ObjectId, warrantSpec.Subject.Relation, eventMeta) + if err != nil { + return err + } - err = svc.EventSvc.TrackAccessAllowedEvent(ctx, warrantSpec.ObjectType, warrantSpec.ObjectId, warrantSpec.Relation, warrantSpec.Subject.ObjectType, warrantSpec.Subject.ObjectId, warrantSpec.Subject.Relation, eventMeta) - if err != nil { - return nil, err + checkResult.Code = http.StatusForbidden + checkResult.Result = NotAuthorized + return nil + } + + err = svc.EventSvc.TrackAccessAllowedEvent(wkCtx, warrantSpec.ObjectType, warrantSpec.ObjectId, warrantSpec.Relation, warrantSpec.Subject.ObjectType, warrantSpec.Subject.ObjectId, warrantSpec.Subject.Relation, eventMeta) + if err != nil { + return err + } } + + checkResult.Code = http.StatusOK + checkResult.Result = Authorized + return nil } - checkResult.Code = http.StatusOK - checkResult.Result = Authorized - return &checkResult, nil - } + if warrantCheck.Op == objecttype.InheritIfAnyOf { + var processingTime int64 + for _, warrantSpec := range warrantCheck.Warrants { + match, decisionPath, _, err := svc.Check(wkCtx, authInfo, CheckSpec{ + CheckWarrantSpec: warrantSpec, + Debug: warrantCheck.Debug, + }) + if err != nil { + return err + } - if warrantCheck.Op == objecttype.InheritIfAnyOf { - var processingTime int64 - for _, warrantSpec := range warrantCheck.Warrants { - match, decisionPath, err := svc.Check(ctx, authInfo, CheckSpec{ - CheckWarrantSpec: warrantSpec, - Debug: warrantCheck.Debug, - }) - if err != nil { - return nil, err - } + if warrantCheck.Debug { + checkResult.ProcessingTime = processingTime + time.Since(start).Milliseconds() + if len(decisionPath) > 0 { + checkResult.DecisionPath[warrantSpec.String()] = decisionPath + } + } - if warrantCheck.Debug { - checkResult.ProcessingTime = processingTime + time.Since(start).Milliseconds() - if len(decisionPath) > 0 { - checkResult.DecisionPath[warrantSpec.String()] = decisionPath + var eventMeta map[string]interface{} + if warrantSpec.Context != nil { + eventMeta = make(map[string]interface{}) + eventMeta["context"] = warrantSpec.Context } - } - var eventMeta map[string]interface{} - if warrantSpec.Context != nil { - eventMeta = make(map[string]interface{}) - eventMeta["context"] = warrantSpec.Context - } + if match { + err = svc.EventSvc.TrackAccessAllowedEvent(wkCtx, warrantSpec.ObjectType, warrantSpec.ObjectId, warrantSpec.Relation, warrantSpec.Subject.ObjectType, warrantSpec.Subject.ObjectId, warrantSpec.Subject.Relation, eventMeta) + if err != nil { + return err + } - if match { - err = svc.EventSvc.TrackAccessAllowedEvent(ctx, warrantSpec.ObjectType, warrantSpec.ObjectId, warrantSpec.Relation, warrantSpec.Subject.ObjectType, warrantSpec.Subject.ObjectId, warrantSpec.Subject.Relation, eventMeta) - if err != nil { - return nil, err + checkResult.Code = http.StatusOK + checkResult.Result = Authorized + return nil } - checkResult.Code = http.StatusOK - checkResult.Result = Authorized - return &checkResult, nil - } - - if !match { - err := svc.EventSvc.TrackAccessDeniedEvent(ctx, warrantSpec.ObjectType, warrantSpec.ObjectId, warrantSpec.Relation, warrantSpec.Subject.ObjectType, warrantSpec.Subject.ObjectId, warrantSpec.Subject.Relation, eventMeta) - if err != nil { - return nil, err + if !match { + err := svc.EventSvc.TrackAccessDeniedEvent(wkCtx, warrantSpec.ObjectType, warrantSpec.ObjectId, warrantSpec.Relation, warrantSpec.Subject.ObjectType, warrantSpec.Subject.ObjectId, warrantSpec.Subject.Relation, eventMeta) + if err != nil { + return err + } } } + + checkResult.Code = http.StatusForbidden + checkResult.Result = NotAuthorized + return nil } - checkResult.Code = http.StatusForbidden - checkResult.Result = NotAuthorized - return &checkResult, nil - } + if len(warrantCheck.Warrants) > 1 { + return service.NewInvalidParameterError("warrants", "must include operator when including multiple warrants") + } - if len(warrantCheck.Warrants) > 1 { - return nil, service.NewInvalidParameterError("warrants", "must include operator when including multiple warrants") - } + warrantSpec := warrantCheck.Warrants[0] + match, decisionPath, _, err := svc.Check(wkCtx, authInfo, CheckSpec{ + CheckWarrantSpec: warrantSpec, + Debug: warrantCheck.Debug, + }) + if err != nil { + return err + } - warrantSpec := warrantCheck.Warrants[0] - match, decisionPath, err := svc.Check(ctx, authInfo, CheckSpec{ - CheckWarrantSpec: warrantSpec, - Debug: warrantCheck.Debug, - }) - if err != nil { - return nil, err - } + if warrantCheck.Debug { + checkResult.ProcessingTime = time.Since(start).Milliseconds() + if len(decisionPath) > 0 { + checkResult.DecisionPath[warrantSpec.String()] = decisionPath + } + } - if warrantCheck.Debug { - checkResult.ProcessingTime = time.Since(start).Milliseconds() - if len(decisionPath) > 0 { - checkResult.DecisionPath[warrantSpec.String()] = decisionPath + var eventMeta map[string]interface{} + if warrantSpec.Context != nil { + eventMeta = make(map[string]interface{}) + eventMeta["context"] = warrantSpec.Context } - } - var eventMeta map[string]interface{} - if warrantSpec.Context != nil { - eventMeta = make(map[string]interface{}) - eventMeta["context"] = warrantSpec.Context - } + if match { + err = svc.EventSvc.TrackAccessAllowedEvent(wkCtx, warrantSpec.ObjectType, warrantSpec.ObjectId, warrantSpec.Relation, warrantSpec.Subject.ObjectType, warrantSpec.Subject.ObjectId, warrantSpec.Subject.Relation, eventMeta) + if err != nil { + return err + } - if match { - err = svc.EventSvc.TrackAccessAllowedEvent(ctx, warrantSpec.ObjectType, warrantSpec.ObjectId, warrantSpec.Relation, warrantSpec.Subject.ObjectType, warrantSpec.Subject.ObjectId, warrantSpec.Subject.Relation, eventMeta) - if err != nil { - return nil, err + checkResult.Code = http.StatusOK + checkResult.Result = Authorized + return nil } - checkResult.Code = http.StatusOK - checkResult.Result = Authorized - return &checkResult, nil - } + err = svc.EventSvc.TrackAccessDeniedEvent(wkCtx, warrantSpec.ObjectType, warrantSpec.ObjectId, warrantSpec.Relation, warrantSpec.Subject.ObjectType, warrantSpec.Subject.ObjectId, warrantSpec.Subject.Relation, eventMeta) + if err != nil { + return err + } - err = svc.EventSvc.TrackAccessDeniedEvent(ctx, warrantSpec.ObjectType, warrantSpec.ObjectId, warrantSpec.Relation, warrantSpec.Subject.ObjectType, warrantSpec.Subject.ObjectId, warrantSpec.Subject.Relation, eventMeta) - if err != nil { - return nil, err + checkResult.Code = http.StatusForbidden + checkResult.Result = NotAuthorized + return nil + }) + if e != nil { + return nil, nil, e } - - checkResult.Code = http.StatusForbidden - checkResult.Result = NotAuthorized - return &checkResult, nil + return &checkResult, newWookie, nil } // Check returns true if the subject has a warrant (explicitly or implicitly) for given objectType:objectId#relation and context -func (svc CheckService) Check(ctx context.Context, authInfo *service.AuthInfo, warrantCheck CheckSpec) (match bool, decisionPath []warrant.WarrantSpec, err error) { +func (svc CheckService) Check(ctx context.Context, authInfo *service.AuthInfo, warrantCheck CheckSpec) (bool, []warrant.WarrantSpec, *wookie.Token, error) { log.Ctx(ctx).Debug().Msgf("Checking for warrant %s", warrantCheck.String()) // Used to automatically append tenant context for session token w/ tenantId checks @@ -387,64 +398,76 @@ func (svc CheckService) Check(ctx context.Context, authInfo *service.AuthInfo, w svc.appendTenantContext(&warrantCheck, authInfo.TenantId) } - // Check for direct warrant match -> doc:readme#viewer@[10] - matchedWarrant, err := svc.getWithPolicyMatch(ctx, warrantCheck.CheckWarrantSpec) - if err != nil { - return false, decisionPath, err - } + var match bool + decisionPath := make([]warrant.WarrantSpec, 0) + var newWookie *wookie.Token + newWookie, e := svc.WookieSvc.WookieSafeRead(ctx, func(wkCtx context.Context) error { + // Check for direct warrant match -> doc:readme#viewer@[10] + matchedWarrant, err := svc.getWithPolicyMatch(ctx, warrantCheck.CheckWarrantSpec) + if err != nil { + return err + } - if matchedWarrant != nil { - return true, []warrant.WarrantSpec{*matchedWarrant}, nil - } + if matchedWarrant != nil { + match = true + decisionPath = []warrant.WarrantSpec{*matchedWarrant} + return nil + } - // Check against indirectly related warrants - matchingWarrants, err := svc.getMatchingSubjects(ctx, warrantCheck.ObjectType, warrantCheck.ObjectId, warrantCheck.Relation, warrantCheck.Context) - if err != nil { - return false, decisionPath, err - } + // Check against indirectly related warrants + matchingWarrants, err := svc.getMatchingSubjects(ctx, warrantCheck.ObjectType, warrantCheck.ObjectId, warrantCheck.Relation, warrantCheck.Context) + if err != nil { + return err + } - for _, matchingWarrant := range matchingWarrants { - if matchingWarrant.Subject.Relation == "" { - continue + for _, matchingWarrant := range matchingWarrants { + if matchingWarrant.Subject.Relation == "" { + continue + } + + match, decisionPath, _, err = svc.Check(ctx, authInfo, CheckSpec{ + CheckWarrantSpec: CheckWarrantSpec{ + ObjectType: matchingWarrant.Subject.ObjectType, + ObjectId: matchingWarrant.Subject.ObjectId, + Relation: matchingWarrant.Subject.Relation, + Subject: warrantCheck.Subject, + Context: warrantCheck.Context, + }, + Debug: warrantCheck.Debug, + }) + if err != nil { + return err + } + + if match { + decisionPath = append(decisionPath, matchingWarrant) + return nil + } } - match, decisionPath, err := svc.Check(ctx, authInfo, CheckSpec{ - CheckWarrantSpec: CheckWarrantSpec{ - ObjectType: matchingWarrant.Subject.ObjectType, - ObjectId: matchingWarrant.Subject.ObjectId, - Relation: matchingWarrant.Subject.Relation, - Subject: warrantCheck.Subject, - Context: warrantCheck.Context, - }, - Debug: warrantCheck.Debug, - }) + // Attempt to match against defined rules for target relation + objectTypeSpec, _, err := svc.ObjectTypeSvc.GetByTypeId(ctx, warrantCheck.ObjectType) if err != nil { - return false, decisionPath, err + return err } - if match { - decisionPath = append(decisionPath, matchingWarrant) - return true, decisionPath, nil + relationRule := objectTypeSpec.Relations[warrantCheck.Relation] + match, decisionPath, err = svc.checkRule(ctx, authInfo, warrantCheck, &relationRule) + if err != nil { + return err } - } - - // Attempt to match against defined rules for target relation - objectTypeSpec, err := svc.ObjectTypeSvc.GetByTypeId(ctx, warrantCheck.ObjectType) - if err != nil { - return false, decisionPath, err - } - relationRule := objectTypeSpec.Relations[warrantCheck.Relation] - match, decisionPath, err = svc.checkRule(ctx, authInfo, warrantCheck, &relationRule) - if err != nil { - return false, decisionPath, err - } + if match { + return nil + } - if match { - return true, decisionPath, nil + match = false + return nil + }) + if e != nil { + return false, decisionPath, nil, e } - - return false, decisionPath, nil + return match, decisionPath, newWookie, nil } func (svc CheckService) appendTenantContext(warrantCheck *CheckSpec, tenantId string) { diff --git a/pkg/authz/feature/handlers.go b/pkg/authz/feature/handlers.go index 20de28fe..ecaed7b5 100644 --- a/pkg/authz/feature/handlers.go +++ b/pkg/authz/feature/handlers.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/gorilla/mux" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/service" ) @@ -114,10 +115,11 @@ func DeleteHandler(svc FeatureService, w http.ResponseWriter, r *http.Request) e return service.NewMissingRequiredParameterError("featureId") } - err := svc.DeleteByFeatureId(r.Context(), featureId) + newWookie, err := svc.DeleteByFeatureId(r.Context(), featureId) if err != nil { return err } + wookie.AddAsResponseHeader(w, newWookie) return nil } diff --git a/pkg/authz/feature/service.go b/pkg/authz/feature/service.go index 252dfcba..39d3e66e 100644 --- a/pkg/authz/feature/service.go +++ b/pkg/authz/feature/service.go @@ -5,6 +5,7 @@ import ( object "github.com/warrant-dev/warrant/pkg/authz/object" objecttype "github.com/warrant-dev/warrant/pkg/authz/objecttype" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/event" "github.com/warrant-dev/warrant/pkg/service" ) @@ -114,14 +115,15 @@ func (svc FeatureService) UpdateByFeatureId(ctx context.Context, featureId strin return updatedFeatureSpec, nil } -func (svc FeatureService) DeleteByFeatureId(ctx context.Context, featureId string) error { +func (svc FeatureService) DeleteByFeatureId(ctx context.Context, featureId string) (*wookie.Token, error) { + var newWookie *wookie.Token err := svc.Env().DB().WithinTransaction(ctx, func(txCtx context.Context) error { err := svc.Repository.DeleteByFeatureId(txCtx, featureId) if err != nil { return err } - err = svc.ObjectSvc.DeleteByObjectTypeAndId(txCtx, objecttype.ObjectTypeFeature, featureId) + newWookie, err = svc.ObjectSvc.DeleteByObjectTypeAndId(txCtx, objecttype.ObjectTypeFeature, featureId) if err != nil { return err } @@ -134,8 +136,8 @@ func (svc FeatureService) DeleteByFeatureId(ctx context.Context, featureId strin return nil }) if err != nil { - return err + return nil, err } - return nil + return newWookie, nil } diff --git a/pkg/authz/object/handlers.go b/pkg/authz/object/handlers.go index 48160949..c5cba522 100644 --- a/pkg/authz/object/handlers.go +++ b/pkg/authz/object/handlers.go @@ -5,6 +5,7 @@ import ( "net/url" "github.com/gorilla/mux" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/service" ) @@ -90,10 +91,11 @@ func GetHandler(svc ObjectService, w http.ResponseWriter, r *http.Request) error func DeleteHandler(svc ObjectService, w http.ResponseWriter, r *http.Request) error { objectType := mux.Vars(r)["objectType"] objectId := mux.Vars(r)["objectId"] - err := svc.DeleteByObjectTypeAndId(r.Context(), objectType, objectId) + newWookie, err := svc.DeleteByObjectTypeAndId(r.Context(), objectType, objectId) if err != nil { return err } + wookie.AddAsResponseHeader(w, newWookie) w.Header().Set("Content-type", "application/json") w.WriteHeader(http.StatusOK) diff --git a/pkg/authz/object/service.go b/pkg/authz/object/service.go index d02e70d9..90c720bc 100644 --- a/pkg/authz/object/service.go +++ b/pkg/authz/object/service.go @@ -5,6 +5,7 @@ import ( "fmt" warrant "github.com/warrant-dev/warrant/pkg/authz/warrant" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/event" "github.com/warrant-dev/warrant/pkg/service" ) @@ -44,14 +45,15 @@ func (svc ObjectService) Create(ctx context.Context, objectSpec ObjectSpec) (*Ob return newObject.ToObjectSpec(), nil } -func (svc ObjectService) DeleteByObjectTypeAndId(ctx context.Context, objectType string, objectId string) error { +func (svc ObjectService) DeleteByObjectTypeAndId(ctx context.Context, objectType string, objectId string) (*wookie.Token, error) { + var newWookie *wookie.Token err := svc.Env().DB().WithinTransaction(ctx, func(txCtx context.Context) error { err := svc.Repository.DeleteByObjectTypeAndId(txCtx, objectType, objectId) if err != nil { return err } - err = svc.WarrantSvc.DeleteRelatedWarrants(txCtx, objectType, objectId) + newWookie, err = svc.WarrantSvc.DeleteRelatedWarrants(txCtx, objectType, objectId) if err != nil { return err } @@ -59,7 +61,10 @@ func (svc ObjectService) DeleteByObjectTypeAndId(ctx context.Context, objectType return nil }) - return err + if err != nil { + return nil, err + } + return newWookie, nil } func (svc ObjectService) GetByObjectTypeAndId(ctx context.Context, objectType string, objectId string) (*ObjectSpec, error) { diff --git a/pkg/authz/objecttype/handlers.go b/pkg/authz/objecttype/handlers.go index 2b464594..96c206fa 100644 --- a/pkg/authz/objecttype/handlers.go +++ b/pkg/authz/objecttype/handlers.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/gorilla/mux" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/service" ) @@ -23,6 +24,7 @@ func (svc ObjectTypeService) Routes() ([]service.Route, error) { Method: "GET", Handler: service.ChainMiddleware( service.NewRouteHandler(svc, ListHandler), + wookie.ClientTokenMiddleware, service.ListMiddleware[ObjectTypeListParamParser], ), }, @@ -31,7 +33,10 @@ func (svc ObjectTypeService) Routes() ([]service.Route, error) { service.WarrantRoute{ Pattern: "/v1/object-types/{type}", Method: "GET", - Handler: service.NewRouteHandler(svc, GetHandler), + Handler: service.ChainMiddleware( + service.NewRouteHandler(svc, GetHandler), + wookie.ClientTokenMiddleware, + ), }, // update @@ -62,10 +67,11 @@ func CreateHandler(svc ObjectTypeService, w http.ResponseWriter, r *http.Request return err } - createdObjectTypeSpec, err := svc.Create(r.Context(), objectTypeSpec) + createdObjectTypeSpec, newWookie, err := svc.Create(r.Context(), objectTypeSpec) if err != nil { return err } + wookie.AddAsResponseHeader(w, newWookie) service.SendJSONResponse(w, createdObjectTypeSpec) return nil @@ -73,10 +79,11 @@ func CreateHandler(svc ObjectTypeService, w http.ResponseWriter, r *http.Request func ListHandler(svc ObjectTypeService, w http.ResponseWriter, r *http.Request) error { listParams := service.GetListParamsFromContext[ObjectTypeListParamParser](r.Context()) - objectTypeSpecs, err := svc.List(r.Context(), listParams) + objectTypeSpecs, updatedWookie, err := svc.List(r.Context(), listParams) if err != nil { return err } + wookie.AddAsResponseHeader(w, updatedWookie) service.SendJSONResponse(w, objectTypeSpecs) return nil @@ -84,10 +91,11 @@ func ListHandler(svc ObjectTypeService, w http.ResponseWriter, r *http.Request) func GetHandler(svc ObjectTypeService, w http.ResponseWriter, r *http.Request) error { typeId := mux.Vars(r)["type"] - objectTypeSpec, err := svc.GetByTypeId(r.Context(), typeId) + objectTypeSpec, updatedWookie, err := svc.GetByTypeId(r.Context(), typeId) if err != nil { return err } + wookie.AddAsResponseHeader(w, updatedWookie) service.SendJSONResponse(w, objectTypeSpec) return nil @@ -101,10 +109,11 @@ func UpdateHandler(svc ObjectTypeService, w http.ResponseWriter, r *http.Request } typeId := mux.Vars(r)["type"] - updatedObjectTypeSpec, err := svc.UpdateByTypeId(r.Context(), typeId, objectTypeSpec) + updatedObjectTypeSpec, newWookie, err := svc.UpdateByTypeId(r.Context(), typeId, objectTypeSpec) if err != nil { return err } + wookie.AddAsResponseHeader(w, newWookie) service.SendJSONResponse(w, updatedObjectTypeSpec) return nil @@ -112,10 +121,11 @@ func UpdateHandler(svc ObjectTypeService, w http.ResponseWriter, r *http.Request func DeleteHandler(svc ObjectTypeService, w http.ResponseWriter, r *http.Request) error { typeId := mux.Vars(r)["type"] - err := svc.DeleteByTypeId(r.Context(), typeId) + newWookie, err := svc.DeleteByTypeId(r.Context(), typeId) if err != nil { return err } + wookie.AddAsResponseHeader(w, newWookie) w.Header().Set("Content-type", "application/json") w.WriteHeader(http.StatusOK) diff --git a/pkg/authz/objecttype/service.go b/pkg/authz/objecttype/service.go index 599e8e44..b5cdd7fd 100644 --- a/pkg/authz/objecttype/service.go +++ b/pkg/authz/objecttype/service.go @@ -3,6 +3,7 @@ package authz import ( "context" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/event" "github.com/warrant-dev/warrant/pkg/service" ) @@ -13,118 +14,155 @@ type ObjectTypeService struct { service.BaseService Repository ObjectTypeRepository EventSvc event.EventService + WookieSvc wookie.WookieService } -func NewService(env service.Env, repository ObjectTypeRepository, eventSvc event.EventService) ObjectTypeService { +func NewService(env service.Env, repository ObjectTypeRepository, eventSvc event.EventService, wookieSvc wookie.WookieService) ObjectTypeService { return ObjectTypeService{ BaseService: service.NewBaseService(env), Repository: repository, EventSvc: eventSvc, + WookieSvc: wookieSvc, } } -func (svc ObjectTypeService) Create(ctx context.Context, objectTypeSpec ObjectTypeSpec) (*ObjectTypeSpec, error) { +func (svc ObjectTypeService) Create(ctx context.Context, objectTypeSpec ObjectTypeSpec) (*ObjectTypeSpec, *wookie.Token, error) { _, err := svc.Repository.GetByTypeId(ctx, objectTypeSpec.Type) if err == nil { - return nil, service.NewDuplicateRecordError("ObjectType", objectTypeSpec.Type, "An objectType with the given type already exists") + return nil, nil, service.NewDuplicateRecordError("ObjectType", objectTypeSpec.Type, "An objectType with the given type already exists") } objectType, err := objectTypeSpec.ToObjectType() if err != nil { - return nil, err + return nil, nil, err } - newObjectTypeId, err := svc.Repository.Create(ctx, objectType) - if err != nil { - return nil, err - } - - newObjectType, err := svc.Repository.GetById(ctx, newObjectTypeId) - if err != nil { - return nil, err - } + var newObjectTypeSpec *ObjectTypeSpec + newWookie, e := svc.WookieSvc.WithWookieUpdate(ctx, func(txCtx context.Context) error { + newObjectTypeId, err := svc.Repository.Create(txCtx, objectType) + if err != nil { + return err + } - newObjectTypeSpec, err := newObjectType.ToObjectTypeSpec() - if err != nil { - return nil, err - } + newObjectType, err := svc.Repository.GetById(txCtx, newObjectTypeId) + if err != nil { + return err + } - err = svc.EventSvc.TrackResourceCreated(ctx, ResourceTypeObjectType, newObjectType.GetTypeId(), newObjectTypeSpec) - if err != nil { - return nil, err - } + newObjectTypeSpec, err = newObjectType.ToObjectTypeSpec() + if err != nil { + return err + } - return newObjectTypeSpec, nil -} + err = svc.EventSvc.TrackResourceCreated(txCtx, ResourceTypeObjectType, newObjectType.GetTypeId(), newObjectTypeSpec) + if err != nil { + return err + } -func (svc ObjectTypeService) GetByTypeId(ctx context.Context, typeId string) (*ObjectTypeSpec, error) { - objectType, err := svc.Repository.GetByTypeId(ctx, typeId) - if err != nil { - return nil, err + return nil + }) + if e != nil { + return nil, nil, e } - - return objectType.ToObjectTypeSpec() + return newObjectTypeSpec, newWookie, nil } -func (svc ObjectTypeService) List(ctx context.Context, listParams service.ListParams) ([]ObjectTypeSpec, error) { - objectTypes, err := svc.Repository.List(ctx, listParams) - if err != nil { - return nil, err +func (svc ObjectTypeService) GetByTypeId(ctx context.Context, typeId string) (*ObjectTypeSpec, *wookie.Token, error) { + var objectTypeSpec *ObjectTypeSpec + newWookie, e := svc.WookieSvc.WookieSafeRead(ctx, func(wkCtx context.Context) error { + objectType, err := svc.Repository.GetByTypeId(wkCtx, typeId) + if err != nil { + return err + } + + objectTypeSpec, err = objectType.ToObjectTypeSpec() + if err != nil { + return err + } + return nil + }) + if e != nil { + return nil, nil, e } + return objectTypeSpec, newWookie, e +} +func (svc ObjectTypeService) List(ctx context.Context, listParams service.ListParams) ([]ObjectTypeSpec, *wookie.Token, error) { objectTypeSpecs := make([]ObjectTypeSpec, 0) - for _, objectType := range objectTypes { - objectTypeSpec, err := objectType.ToObjectTypeSpec() + newWookie, e := svc.WookieSvc.WookieSafeRead(ctx, func(wkCtx context.Context) error { + objectTypes, err := svc.Repository.List(wkCtx, listParams) if err != nil { - return objectTypeSpecs, err + return err } - objectTypeSpecs = append(objectTypeSpecs, *objectTypeSpec) - } + for _, objectType := range objectTypes { + objectTypeSpec, err := objectType.ToObjectTypeSpec() + if err != nil { + return err + } - return objectTypeSpecs, nil + objectTypeSpecs = append(objectTypeSpecs, *objectTypeSpec) + } + + return nil + }) + if e != nil { + return nil, nil, e + } + return objectTypeSpecs, newWookie, nil } -func (svc ObjectTypeService) UpdateByTypeId(ctx context.Context, typeId string, objectTypeSpec ObjectTypeSpec) (*ObjectTypeSpec, error) { +func (svc ObjectTypeService) UpdateByTypeId(ctx context.Context, typeId string, objectTypeSpec ObjectTypeSpec) (*ObjectTypeSpec, *wookie.Token, error) { currentObjectType, err := svc.Repository.GetByTypeId(ctx, typeId) if err != nil { - return nil, err + return nil, nil, err } - updateTo, err := objectTypeSpec.ToObjectType() if err != nil { - return nil, err + return nil, nil, err } - currentObjectType.SetDefinition(updateTo.Definition) - err = svc.Repository.UpdateByTypeId(ctx, typeId, currentObjectType) - if err != nil { - return nil, err - } + var updatedObjectTypeSpec *ObjectTypeSpec + newWookie, e := svc.WookieSvc.WithWookieUpdate(ctx, func(txCtx context.Context) error { + err := svc.Repository.UpdateByTypeId(txCtx, typeId, currentObjectType) + if err != nil { + return err + } - updatedObjectTypeSpec, err := svc.GetByTypeId(ctx, typeId) - if err != nil { - return nil, err - } + updatedObjectTypeSpec, _, err = svc.GetByTypeId(txCtx, typeId) + if err != nil { + return err + } - err = svc.EventSvc.TrackResourceUpdated(ctx, ResourceTypeObjectType, typeId, updatedObjectTypeSpec) - if err != nil { - return nil, err - } + err = svc.EventSvc.TrackResourceUpdated(txCtx, ResourceTypeObjectType, typeId, updatedObjectTypeSpec) + if err != nil { + return err + } - return updatedObjectTypeSpec, nil + return nil + }) + if e != nil { + return nil, nil, e + } + return updatedObjectTypeSpec, newWookie, nil } -func (svc ObjectTypeService) DeleteByTypeId(ctx context.Context, typeId string) error { - err := svc.Repository.DeleteByTypeId(ctx, typeId) - if err != nil { - return err - } +func (svc ObjectTypeService) DeleteByTypeId(ctx context.Context, typeId string) (*wookie.Token, error) { + newWookie, e := svc.WookieSvc.WithWookieUpdate(ctx, func(txCtx context.Context) error { + err := svc.Repository.DeleteByTypeId(txCtx, typeId) + if err != nil { + return err + } - err = svc.EventSvc.TrackResourceDeleted(ctx, ResourceTypeObjectType, typeId, nil) - if err != nil { - return err - } + err = svc.EventSvc.TrackResourceDeleted(txCtx, ResourceTypeObjectType, typeId, nil) + if err != nil { + return err + } - return nil + return nil + }) + if e != nil { + return nil, e + } + return newWookie, nil } diff --git a/pkg/authz/permission/handlers.go b/pkg/authz/permission/handlers.go index f4721be9..477f556c 100644 --- a/pkg/authz/permission/handlers.go +++ b/pkg/authz/permission/handlers.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/gorilla/mux" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/service" ) @@ -114,10 +115,11 @@ func DeleteHandler(svc PermissionService, w http.ResponseWriter, r *http.Request return service.NewMissingRequiredParameterError("permissionId") } - err := svc.DeleteByPermissionId(r.Context(), permissionId) + newWookie, err := svc.DeleteByPermissionId(r.Context(), permissionId) if err != nil { return err } + wookie.AddAsResponseHeader(w, newWookie) return nil } diff --git a/pkg/authz/permission/service.go b/pkg/authz/permission/service.go index a29dcec5..6e391b83 100644 --- a/pkg/authz/permission/service.go +++ b/pkg/authz/permission/service.go @@ -5,6 +5,7 @@ import ( object "github.com/warrant-dev/warrant/pkg/authz/object" objecttype "github.com/warrant-dev/warrant/pkg/authz/objecttype" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/event" "github.com/warrant-dev/warrant/pkg/service" ) @@ -115,14 +116,15 @@ func (svc PermissionService) UpdateByPermissionId(ctx context.Context, permissio return updatedPermissionSpec, nil } -func (svc PermissionService) DeleteByPermissionId(ctx context.Context, permissionId string) error { +func (svc PermissionService) DeleteByPermissionId(ctx context.Context, permissionId string) (*wookie.Token, error) { + var newWookie *wookie.Token err := svc.Env().DB().WithinTransaction(ctx, func(txCtx context.Context) error { err := svc.Repository.DeleteByPermissionId(txCtx, permissionId) if err != nil { return err } - err = svc.ObjectSvc.DeleteByObjectTypeAndId(txCtx, objecttype.ObjectTypePermission, permissionId) + newWookie, err = svc.ObjectSvc.DeleteByObjectTypeAndId(txCtx, objecttype.ObjectTypePermission, permissionId) if err != nil { return err } @@ -135,8 +137,8 @@ func (svc PermissionService) DeleteByPermissionId(ctx context.Context, permissio return nil }) if err != nil { - return err + return nil, err } - return nil + return newWookie, nil } diff --git a/pkg/authz/pricingtier/handlers.go b/pkg/authz/pricingtier/handlers.go index 0a6c2cfa..fb0464a7 100644 --- a/pkg/authz/pricingtier/handlers.go +++ b/pkg/authz/pricingtier/handlers.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/gorilla/mux" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/service" ) @@ -114,10 +115,11 @@ func DeleteHandler(svc PricingTierService, w http.ResponseWriter, r *http.Reques return service.NewMissingRequiredParameterError("pricingTierId") } - err := svc.DeleteByPricingTierId(r.Context(), pricingTierId) + newWookie, err := svc.DeleteByPricingTierId(r.Context(), pricingTierId) if err != nil { return err } + wookie.AddAsResponseHeader(w, newWookie) return nil } diff --git a/pkg/authz/pricingtier/service.go b/pkg/authz/pricingtier/service.go index 352c17ba..dc1b7fcd 100644 --- a/pkg/authz/pricingtier/service.go +++ b/pkg/authz/pricingtier/service.go @@ -5,6 +5,7 @@ import ( object "github.com/warrant-dev/warrant/pkg/authz/object" objecttype "github.com/warrant-dev/warrant/pkg/authz/objecttype" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/event" "github.com/warrant-dev/warrant/pkg/service" ) @@ -115,14 +116,15 @@ func (svc PricingTierService) UpdateByPricingTierId(ctx context.Context, pricing return updatedPricingTierSpec, nil } -func (svc PricingTierService) DeleteByPricingTierId(ctx context.Context, pricingTierId string) error { +func (svc PricingTierService) DeleteByPricingTierId(ctx context.Context, pricingTierId string) (*wookie.Token, error) { + var newWookie *wookie.Token err := svc.Env().DB().WithinTransaction(ctx, func(txCtx context.Context) error { err := svc.Repository.DeleteByPricingTierId(txCtx, pricingTierId) if err != nil { return err } - err = svc.ObjectSvc.DeleteByObjectTypeAndId(txCtx, objecttype.ObjectTypePricingTier, pricingTierId) + newWookie, err = svc.ObjectSvc.DeleteByObjectTypeAndId(txCtx, objecttype.ObjectTypePricingTier, pricingTierId) if err != nil { return err } @@ -135,8 +137,8 @@ func (svc PricingTierService) DeleteByPricingTierId(ctx context.Context, pricing return nil }) if err != nil { - return err + return nil, err } - return nil + return newWookie, nil } diff --git a/pkg/authz/role/handlers.go b/pkg/authz/role/handlers.go index dbd45d89..46bbdcdc 100644 --- a/pkg/authz/role/handlers.go +++ b/pkg/authz/role/handlers.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/gorilla/mux" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/service" ) @@ -114,10 +115,11 @@ func DeleteHandler(svc RoleService, w http.ResponseWriter, r *http.Request) erro return service.NewMissingRequiredParameterError("roleId") } - err := svc.DeleteByRoleId(r.Context(), roleId) + newWookie, err := svc.DeleteByRoleId(r.Context(), roleId) if err != nil { return err } + wookie.AddAsResponseHeader(w, newWookie) return nil } diff --git a/pkg/authz/role/service.go b/pkg/authz/role/service.go index ec489ed7..4a9db1dc 100644 --- a/pkg/authz/role/service.go +++ b/pkg/authz/role/service.go @@ -5,6 +5,7 @@ import ( object "github.com/warrant-dev/warrant/pkg/authz/object" objecttype "github.com/warrant-dev/warrant/pkg/authz/objecttype" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/event" "github.com/warrant-dev/warrant/pkg/service" ) @@ -115,14 +116,15 @@ func (svc RoleService) UpdateByRoleId(ctx context.Context, roleId string, roleSp return updatedRoleSpec, nil } -func (svc RoleService) DeleteByRoleId(ctx context.Context, roleId string) error { +func (svc RoleService) DeleteByRoleId(ctx context.Context, roleId string) (*wookie.Token, error) { + var newWookie *wookie.Token err := svc.Env().DB().WithinTransaction(ctx, func(txCtx context.Context) error { err := svc.Repository.DeleteByRoleId(txCtx, roleId) if err != nil { return err } - err = svc.ObjectSvc.DeleteByObjectTypeAndId(txCtx, objecttype.ObjectTypeRole, roleId) + newWookie, err = svc.ObjectSvc.DeleteByObjectTypeAndId(txCtx, objecttype.ObjectTypeRole, roleId) if err != nil { return err } @@ -135,8 +137,8 @@ func (svc RoleService) DeleteByRoleId(ctx context.Context, roleId string) error return nil }) if err != nil { - return err + return nil, err } - return nil + return newWookie, nil } diff --git a/pkg/authz/tenant/handlers.go b/pkg/authz/tenant/handlers.go index 7e3e4ba7..542ad452 100644 --- a/pkg/authz/tenant/handlers.go +++ b/pkg/authz/tenant/handlers.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/gorilla/mux" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/service" ) @@ -110,10 +111,11 @@ func UpdateHandler(svc TenantService, w http.ResponseWriter, r *http.Request) er func DeleteHandler(svc TenantService, w http.ResponseWriter, r *http.Request) error { tenantId := mux.Vars(r)["tenantId"] - err := svc.DeleteByTenantId(r.Context(), tenantId) + newWookie, err := svc.DeleteByTenantId(r.Context(), tenantId) if err != nil { return err } + wookie.AddAsResponseHeader(w, newWookie) w.Header().Set("Content-type", "application/json") w.WriteHeader(http.StatusOK) diff --git a/pkg/authz/tenant/service.go b/pkg/authz/tenant/service.go index b1d0f9b3..f7e328cd 100644 --- a/pkg/authz/tenant/service.go +++ b/pkg/authz/tenant/service.go @@ -7,6 +7,7 @@ import ( "github.com/pkg/errors" object "github.com/warrant-dev/warrant/pkg/authz/object" objecttype "github.com/warrant-dev/warrant/pkg/authz/objecttype" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/event" "github.com/warrant-dev/warrant/pkg/service" ) @@ -130,14 +131,15 @@ func (svc TenantService) UpdateByTenantId(ctx context.Context, tenantId string, return updatedTenantSpec, nil } -func (svc TenantService) DeleteByTenantId(ctx context.Context, tenantId string) error { +func (svc TenantService) DeleteByTenantId(ctx context.Context, tenantId string) (*wookie.Token, error) { + var newWookie *wookie.Token err := svc.Env().DB().WithinTransaction(ctx, func(txCtx context.Context) error { err := svc.Repository.DeleteByTenantId(txCtx, tenantId) if err != nil { return err } - err = svc.ObjectSvc.DeleteByObjectTypeAndId(txCtx, objecttype.ObjectTypeTenant, tenantId) + newWookie, err = svc.ObjectSvc.DeleteByObjectTypeAndId(txCtx, objecttype.ObjectTypeTenant, tenantId) if err != nil { return err } @@ -150,8 +152,8 @@ func (svc TenantService) DeleteByTenantId(ctx context.Context, tenantId string) return nil }) if err != nil { - return err + return nil, err } - return nil + return newWookie, nil } diff --git a/pkg/authz/user/handlers.go b/pkg/authz/user/handlers.go index b2f53737..4dc32d6a 100644 --- a/pkg/authz/user/handlers.go +++ b/pkg/authz/user/handlers.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/gorilla/mux" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/service" ) @@ -109,10 +110,11 @@ func UpdateHandler(svc UserService, w http.ResponseWriter, r *http.Request) erro func DeleteHandler(svc UserService, w http.ResponseWriter, r *http.Request) error { userId := mux.Vars(r)["userId"] - err := svc.DeleteByUserId(r.Context(), userId) + newWookie, err := svc.DeleteByUserId(r.Context(), userId) if err != nil { return err } + wookie.AddAsResponseHeader(w, newWookie) w.Header().Set("Content-type", "application/json") w.WriteHeader(http.StatusOK) diff --git a/pkg/authz/user/service.go b/pkg/authz/user/service.go index 7c67b72c..ff8f7ed1 100644 --- a/pkg/authz/user/service.go +++ b/pkg/authz/user/service.go @@ -7,6 +7,7 @@ import ( "github.com/pkg/errors" object "github.com/warrant-dev/warrant/pkg/authz/object" objecttype "github.com/warrant-dev/warrant/pkg/authz/objecttype" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/event" "github.com/warrant-dev/warrant/pkg/service" ) @@ -130,14 +131,15 @@ func (svc UserService) UpdateByUserId(ctx context.Context, userId string, userSp return updatedUserSpec, nil } -func (svc UserService) DeleteByUserId(ctx context.Context, userId string) error { +func (svc UserService) DeleteByUserId(ctx context.Context, userId string) (*wookie.Token, error) { + var newWookie *wookie.Token err := svc.Env().DB().WithinTransaction(ctx, func(txCtx context.Context) error { err := svc.Repository.DeleteByUserId(txCtx, userId) if err != nil { return err } - err = svc.ObjectSvc.DeleteByObjectTypeAndId(txCtx, objecttype.ObjectTypeUser, userId) + newWookie, err = svc.ObjectSvc.DeleteByObjectTypeAndId(txCtx, objecttype.ObjectTypeUser, userId) if err != nil { return err } @@ -150,5 +152,8 @@ func (svc UserService) DeleteByUserId(ctx context.Context, userId string) error return nil }) - return err + if err != nil { + return nil, err + } + return newWookie, nil } diff --git a/pkg/authz/warrant/handlers.go b/pkg/authz/warrant/handlers.go index 47f0c06b..aa6f94d3 100644 --- a/pkg/authz/warrant/handlers.go +++ b/pkg/authz/warrant/handlers.go @@ -3,6 +3,7 @@ package authz import ( "net/http" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/service" ) @@ -24,6 +25,7 @@ func (svc WarrantService) Routes() ([]service.Route, error) { Method: "GET", Handler: service.ChainMiddleware( service.NewRouteHandler(svc, ListHandler), + wookie.ClientTokenMiddleware, service.ListMiddleware[WarrantListParamParser], ), }, @@ -54,10 +56,11 @@ func CreateHandler(svc WarrantService, w http.ResponseWriter, r *http.Request) e } } - createdWarrant, err := svc.Create(r.Context(), warrantSpec) + createdWarrant, newWookie, err := svc.Create(r.Context(), warrantSpec) if err != nil { return err } + wookie.AddAsResponseHeader(w, newWookie) service.SendJSONResponse(w, createdWarrant) return nil @@ -81,10 +84,11 @@ func ListHandler(svc WarrantService, w http.ResponseWriter, r *http.Request) err filters.Subject.Relation = subjectRelation } - warrants, err := svc.List(r.Context(), &filters, listParams) + warrants, updatedWookie, err := svc.List(r.Context(), &filters, listParams) if err != nil { return err } + wookie.AddAsResponseHeader(w, updatedWookie) service.SendJSONResponse(w, warrants) return nil @@ -97,10 +101,11 @@ func DeleteHandler(svc WarrantService, w http.ResponseWriter, r *http.Request) e return err } - err = svc.Delete(r.Context(), warrantSpec) + newWookie, err := svc.Delete(r.Context(), warrantSpec) if err != nil { return err } + wookie.AddAsResponseHeader(w, newWookie) w.Header().Set("Content-type", "application/json") w.WriteHeader(http.StatusOK) diff --git a/pkg/authz/warrant/service.go b/pkg/authz/warrant/service.go index 6dafbe9e..7b49410b 100644 --- a/pkg/authz/warrant/service.go +++ b/pkg/authz/warrant/service.go @@ -4,6 +4,7 @@ import ( "context" objecttype "github.com/warrant-dev/warrant/pkg/authz/objecttype" + wookie "github.com/warrant-dev/warrant/pkg/authz/wookie" "github.com/warrant-dev/warrant/pkg/event" "github.com/warrant-dev/warrant/pkg/service" ) @@ -13,38 +14,40 @@ type WarrantService struct { Repository WarrantRepository EventSvc event.EventService ObjectTypeSvc objecttype.ObjectTypeService + WookieSvc wookie.WookieService } -func NewService(env service.Env, repository WarrantRepository, eventSvc event.EventService, objectTypeSvc objecttype.ObjectTypeService) WarrantService { +func NewService(env service.Env, repository WarrantRepository, eventSvc event.EventService, objectTypeSvc objecttype.ObjectTypeService, wookieService wookie.WookieService) WarrantService { return WarrantService{ BaseService: service.NewBaseService(env), Repository: repository, EventSvc: eventSvc, ObjectTypeSvc: objectTypeSvc, + WookieSvc: wookieService, } } -func (svc WarrantService) Create(ctx context.Context, warrantSpec WarrantSpec) (*WarrantSpec, error) { +func (svc WarrantService) Create(ctx context.Context, warrantSpec WarrantSpec) (*WarrantSpec, *wookie.Token, error) { // Check that objectType is valid - objectTypeDef, err := svc.ObjectTypeSvc.GetByTypeId(ctx, warrantSpec.ObjectType) + objectTypeDef, _, err := svc.ObjectTypeSvc.GetByTypeId(ctx, warrantSpec.ObjectType) if err != nil { - return nil, service.NewInvalidParameterError("objectType", "The given object type does not exist.") + return nil, nil, service.NewInvalidParameterError("objectType", "The given object type does not exist.") } // Check that relation is valid for objectType _, exists := objectTypeDef.Relations[warrantSpec.Relation] if !exists { - return nil, service.NewInvalidParameterError("relation", "An object type with the given relation does not exist.") + return nil, nil, service.NewInvalidParameterError("relation", "An object type with the given relation does not exist.") } // Check that warrant does not already exist _, err = svc.Repository.Get(ctx, warrantSpec.ObjectType, warrantSpec.ObjectId, warrantSpec.Relation, warrantSpec.Subject.ObjectType, warrantSpec.Subject.ObjectId, warrantSpec.Subject.Relation, warrantSpec.Policy.Hash()) if err == nil { - return nil, service.NewDuplicateRecordError("Warrant", warrantSpec, "A warrant with the given objectType, objectId, relation, subject, and policy already exists") + return nil, nil, service.NewDuplicateRecordError("Warrant", warrantSpec, "A warrant with the given objectType, objectId, relation, subject, and policy already exists") } var createdWarrant Model - err = svc.Env().DB().WithinTransaction(ctx, func(txCtx context.Context) error { + newWookie, e := svc.WookieSvc.WithWookieUpdate(ctx, func(txCtx context.Context) error { warrant, err := warrantSpec.ToWarrant() if err != nil { return err @@ -73,29 +76,35 @@ func (svc WarrantService) Create(ctx context.Context, warrantSpec WarrantSpec) ( return nil }) - if err != nil { - return nil, err + if e != nil { + return nil, nil, e } - return createdWarrant.ToWarrantSpec(), nil + return createdWarrant.ToWarrantSpec(), newWookie, nil } -func (svc WarrantService) List(ctx context.Context, filterOptions *FilterOptions, listParams service.ListParams) ([]*WarrantSpec, error) { +func (svc WarrantService) List(ctx context.Context, filterOptions *FilterOptions, listParams service.ListParams) ([]*WarrantSpec, *wookie.Token, error) { warrantSpecs := make([]*WarrantSpec, 0) - warrants, err := svc.Repository.List(ctx, filterOptions, listParams) - if err != nil { - return nil, err - } + newWookie, e := svc.WookieSvc.WookieSafeRead(ctx, func(wkCtx context.Context) error { + warrants, err := svc.Repository.List(wkCtx, filterOptions, listParams) + if err != nil { + return err + } - for _, warrant := range warrants { - warrantSpecs = append(warrantSpecs, warrant.ToWarrantSpec()) - } + for _, warrant := range warrants { + warrantSpecs = append(warrantSpecs, warrant.ToWarrantSpec()) + } - return warrantSpecs, nil + return nil + }) + if e != nil { + return nil, nil, e + } + return warrantSpecs, newWookie, nil } -func (svc WarrantService) Delete(ctx context.Context, warrantSpec WarrantSpec) error { - err := svc.Env().DB().WithinTransaction(ctx, func(txCtx context.Context) error { +func (svc WarrantService) Delete(ctx context.Context, warrantSpec WarrantSpec) (*wookie.Token, error) { + newWookie, e := svc.WookieSvc.WithWookieUpdate(ctx, func(txCtx context.Context) error { warrantToDelete, err := warrantSpec.ToWarrant() if err != nil { return nil @@ -124,15 +133,15 @@ func (svc WarrantService) Delete(ctx context.Context, warrantSpec WarrantSpec) e return nil }) - if err != nil { - return err + if e != nil { + return nil, e } - return nil + return newWookie, nil } -func (svc WarrantService) DeleteRelatedWarrants(ctx context.Context, objectType string, objectId string) error { - err := svc.Env().DB().WithinTransaction(ctx, func(txCtx context.Context) error { +func (svc WarrantService) DeleteRelatedWarrants(ctx context.Context, objectType string, objectId string) (*wookie.Token, error) { + newWookie, e := svc.WookieSvc.WithWookieUpdate(ctx, func(txCtx context.Context) error { err := svc.Repository.DeleteAllByObject(txCtx, objectType, objectId) if err != nil { return err @@ -145,9 +154,9 @@ func (svc WarrantService) DeleteRelatedWarrants(ctx context.Context, objectType return nil }) - if err != nil { - return err + if e != nil { + return nil, e } - return nil + return newWookie, nil } diff --git a/pkg/authz/wookie/http.go b/pkg/authz/wookie/http.go new file mode 100644 index 00000000..25083527 --- /dev/null +++ b/pkg/authz/wookie/http.go @@ -0,0 +1,34 @@ +package wookie + +import ( + "context" + "net/http" + + "github.com/rs/zerolog/hlog" +) + +const WarrantTokenHeaderName = "Warrant-Token" + +func ClientTokenMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + headerVal := r.Header.Get(WarrantTokenHeaderName) + if headerVal != "" { + clientWookie, err := FromString(headerVal) + if err != nil { + hlog.FromRequest(r).Warn().Msgf("invalid client-supplied wookie header: %s", headerVal) + next.ServeHTTP(w, r) + return + } + newContext := context.WithValue(r.Context(), ClientTokenKey{}, clientWookie) + next.ServeHTTP(w, r.WithContext(newContext)) + return + } + next.ServeHTTP(w, r) + }) +} + +func AddAsResponseHeader(w http.ResponseWriter, token *Token) { + if token != nil { + w.Header().Set(WarrantTokenHeaderName, token.String()) + } +} diff --git a/pkg/authz/wookie/model.go b/pkg/authz/wookie/model.go new file mode 100644 index 00000000..41be7b14 --- /dev/null +++ b/pkg/authz/wookie/model.go @@ -0,0 +1,36 @@ +package wookie + +import "time" + +type Model interface { + GetID() int64 + GetVersion() int64 + GetCreatedAt() time.Time + ToToken() *Token +} + +type Wookie struct { + ID int64 `mysql:"id" postgres:"id" sqlite:"id"` + Version int64 `mysql:"ver" postgres:"ver" sqlite:"ver"` + CreatedAt time.Time `mysql:"createdAt" postgres:"created_at" sqlite:"createdAt"` +} + +func (w Wookie) GetID() int64 { + return w.ID +} + +func (w Wookie) GetVersion() int64 { + return w.Version +} + +func (w Wookie) GetCreatedAt() time.Time { + return w.CreatedAt +} + +func (w Wookie) ToToken() *Token { + return &Token{ + ID: w.ID, + Version: w.Version, + Timestamp: w.CreatedAt, + } +} diff --git a/pkg/authz/wookie/mysql.go b/pkg/authz/wookie/mysql.go new file mode 100644 index 00000000..22c4e828 --- /dev/null +++ b/pkg/authz/wookie/mysql.go @@ -0,0 +1,76 @@ +package wookie + +import ( + "context" + + "github.com/pkg/errors" + "github.com/warrant-dev/warrant/pkg/database" +) + +type MySQLRepository struct { + database.SQLRepository +} + +func NewMySQLRepository(db *database.MySQL) MySQLRepository { + return MySQLRepository{ + database.NewSQLRepository(db), + } +} + +func (repo MySQLRepository) Create(ctx context.Context, version int64) (int64, error) { + result, err := repo.DB(ctx).ExecContext( + ctx, + ` + INSERT INTO wookie ( + ver + ) + VALUES (?) + `, + version, + ) + if err != nil { + return -1, errors.Wrap(err, "error creating wookie") + } + id, err := result.LastInsertId() + if err != nil { + return -1, errors.Wrap(err, "error creating wookie") + } + return id, nil +} + +func (repo MySQLRepository) GetById(ctx context.Context, id int64) (Model, error) { + var wookie Wookie + err := repo.DB(ctx).GetContext( + ctx, + &wookie, + ` + SELECT id, ver, createdAt + FROM wookie + WHERE + id = ? + `, + id, + ) + if err != nil { + return nil, errors.Wrap(err, "error getting wookie") + } + return &wookie, nil +} + +func (repo MySQLRepository) GetLatest(ctx context.Context) (Model, error) { + var wookie Wookie + err := repo.DB(ctx).GetContext( + ctx, + &wookie, + ` + SELECT id, ver, createdAt + FROM wookie + WHERE + id = (SELECT MAX(id) FROM wookie) + `, + ) + if err != nil { + return nil, errors.Wrap(err, "error getting latest wookie") + } + return &wookie, nil +} diff --git a/pkg/authz/wookie/postgres.go b/pkg/authz/wookie/postgres.go new file mode 100644 index 00000000..dfbe467e --- /dev/null +++ b/pkg/authz/wookie/postgres.go @@ -0,0 +1,75 @@ +package wookie + +import ( + "context" + + "github.com/pkg/errors" + "github.com/warrant-dev/warrant/pkg/database" +) + +type PostgresRepository struct { + database.SQLRepository +} + +func NewPostgresRepository(db *database.Postgres) PostgresRepository { + return PostgresRepository{ + database.NewSQLRepository(db), + } +} + +func (repo PostgresRepository) Create(ctx context.Context, version int64) (int64, error) { + var newWookieId int64 + err := repo.DB(ctx).GetContext( + ctx, + &newWookieId, + ` + INSERT INTO wookie ( + ver + ) + VALUES (?) + RETURNING id + `, + version, + ) + if err != nil { + return -1, errors.Wrap(err, "error creating wookie") + } + return newWookieId, nil +} + +func (repo PostgresRepository) GetById(ctx context.Context, id int64) (Model, error) { + var wookie Wookie + err := repo.DB(ctx).GetContext( + ctx, + &wookie, + ` + SELECT id, ver, created_at + FROM wookie + WHERE + id = ? + `, + id, + ) + if err != nil { + return nil, errors.Wrap(err, "error getting wookie") + } + return &wookie, nil +} + +func (repo PostgresRepository) GetLatest(ctx context.Context) (Model, error) { + var wookie Wookie + err := repo.DB(ctx).GetContext( + ctx, + &wookie, + ` + SELECT id, ver, created_at + FROM wookie + WHERE + id = (SELECT MAX(id) FROM wookie) + `, + ) + if err != nil { + return nil, errors.Wrap(err, "error getting latest wookie") + } + return &wookie, nil +} diff --git a/pkg/authz/wookie/repository.go b/pkg/authz/wookie/repository.go new file mode 100644 index 00000000..18f9b552 --- /dev/null +++ b/pkg/authz/wookie/repository.go @@ -0,0 +1,42 @@ +package wookie + +import ( + "context" + "fmt" + + "github.com/pkg/errors" + "github.com/warrant-dev/warrant/pkg/database" +) + +type WookieRepository interface { + Create(ctx context.Context, version int64) (int64, error) + GetById(ctx context.Context, id int64) (Model, error) + GetLatest(ctx context.Context) (Model, error) +} + +func NewRepository(db database.Database) (WookieRepository, error) { + switch db.Type() { + case database.TypeMySQL: + mysql, ok := db.(*database.MySQL) + if !ok { + return nil, errors.New(fmt.Sprintf("invalid %s database config", database.TypeMySQL)) + } + return NewMySQLRepository(mysql), nil + case database.TypePostgres: + postgres, ok := db.(*database.Postgres) + if !ok { + return nil, errors.New(fmt.Sprintf("invalid %s database config", database.TypePostgres)) + } + + return NewPostgresRepository(postgres), nil + case database.TypeSQLite: + sqlite, ok := db.(*database.SQLite) + if !ok { + return nil, errors.New(fmt.Sprintf("invalid %s database config", database.TypeSQLite)) + } + + return NewSQLiteRepository(sqlite), nil + default: + return nil, errors.New(fmt.Sprintf("unsupported database type %s specified", db.Type())) + } +} diff --git a/pkg/authz/wookie/service.go b/pkg/authz/wookie/service.go new file mode 100644 index 00000000..802c7452 --- /dev/null +++ b/pkg/authz/wookie/service.go @@ -0,0 +1,158 @@ +package wookie + +import ( + "context" + + "github.com/pkg/errors" + "github.com/rs/zerolog/log" + "github.com/warrant-dev/warrant/pkg/database" + "github.com/warrant-dev/warrant/pkg/service" +) + +const currentWookieVersion = 1 + +type updateWookieKey struct{} + +type wookieQueryContextKey struct{} + +type WookieService struct { + service.BaseService + Repository WookieRepository +} + +func NewService(env service.Env, repository WookieRepository) WookieService { + return WookieService{ + BaseService: service.NewBaseService(env), + Repository: repository, + } +} + +// Apply given updateFunc() and create a new wookie for this update. Returns the new wookie token. +func (svc WookieService) WithWookieUpdate(ctx context.Context, updateFunc func(txCtx context.Context) error) (*Token, error) { + _, hasQueryWookie := ctx.Value(wookieQueryContextKey{}).(*Token) + if hasQueryWookie { + return nil, errors.New("invalid state: can't call WookieUpdate() within WookieSafeRead()") + } + + updateWookie, hasUpdateWookie := ctx.Value(updateWookieKey{}).(*Token) + // An update is already in progress so continue with that ctx + if hasUpdateWookie { + e := updateFunc(ctx) + if e != nil { + return nil, e + } + return updateWookie, nil + } + + // Otherwise create a new tx and new update wookie + var newWookie *Token + e := svc.Env().DB().WithinTransaction(ctx, func(txCtx context.Context) error { + newWookieId, err := svc.Repository.Create(txCtx, currentWookieVersion) + if err != nil { + return err + } + token, err := svc.Repository.GetById(txCtx, newWookieId) + if err != nil { + return err + } + newWookie = token.ToToken() + wkCtx := context.WithValue(txCtx, updateWookieKey{}, newWookie) + err = updateFunc(wkCtx) + if err != nil { + return err + } + + return nil + }) + if e != nil { + return nil, e + } + return newWookie, nil +} + +func (svc WookieService) WookieSafeRead(ctx context.Context, readFunc func(wkCtx context.Context) error) (*Token, error) { + // A read is already in progress so continue with that ctx + queryWookie, hasQueryWookie := ctx.Value(wookieQueryContextKey{}).(*Token) + if hasQueryWookie { + e := readFunc(ctx) + if e != nil { + return nil, e + } + return queryWookie, nil + } + + // If client didn't pass a wookie, run readFunc() with existing ctx and return latest wookie if present + clientWookie, hasClientWookie := ctx.Value(ClientTokenKey{}).(Token) + if !hasClientWookie { + // TODO: Ideally the server should default to some trailing wookie value here. For now, default to 'unsafe' op to always use up-to-date db. + unsafeCtx := context.WithValue(ctx, database.UnsafeOp{}, true) + writerLatest, e := svc.Repository.GetLatest(unsafeCtx) + var latestWookieToReturn *Token + if e != nil { + log.Ctx(ctx).Warn().Err(e).Msg("error getting writer latest wookie") + latestWookieToReturn = nil + } + var wkCtx context.Context + if writerLatest != nil { + latestWookieToReturn = writerLatest.ToToken() + wkCtx = context.WithValue(unsafeCtx, wookieQueryContextKey{}, latestWookieToReturn) + } else { + wkCtx = unsafeCtx + } + e = readFunc(wkCtx) + if e != nil { + return nil, e + } + return latestWookieToReturn, nil + } + + // Otherwise, compare client wookie to a reader's latest wookie to see if we can use it + var latestWookieToReturn *Token + e := svc.Env().DB().WithinConsistentRead(ctx, func(connCtx context.Context) error { + unsafe := false + + // First, get the reader's latest wookie + readerLatest, err := svc.Repository.GetLatest(connCtx) + if err != nil { + log.Ctx(ctx).Warn().Err(err).Msg("error getting reader latest wookie") + unsafe = true + } + + // Compare reader wookie against client-provided wookie + if readerLatest != nil { + if readerLatest.GetID() < clientWookie.ID { + // Reader is behind so op is unsafe + unsafe = true + } else { + // Reader is up-to-date or ahead so is safe to use + unsafe = false + latestWookieToReturn = readerLatest.ToToken() + } + } + + wkCtx := context.WithValue(connCtx, database.UnsafeOp{}, unsafe) + if unsafe { + // Get writer's latest wookie + writerLatest, err := svc.Repository.GetLatest(wkCtx) + if err != nil { + log.Ctx(ctx).Warn().Err(err).Msg("error getting writer latest wookie") + latestWookieToReturn = nil + } + if writerLatest != nil { + latestWookieToReturn = writerLatest.ToToken() + } + } + + // Execute read + readCtx := context.WithValue(wkCtx, wookieQueryContextKey{}, latestWookieToReturn) + err = readFunc(readCtx) + if err != nil { + return err + } + return nil + }) + if e != nil { + return nil, e + } + return latestWookieToReturn, nil +} diff --git a/pkg/authz/wookie/sqlite.go b/pkg/authz/wookie/sqlite.go new file mode 100644 index 00000000..7ecbc234 --- /dev/null +++ b/pkg/authz/wookie/sqlite.go @@ -0,0 +1,76 @@ +package wookie + +import ( + "context" + + "github.com/pkg/errors" + "github.com/warrant-dev/warrant/pkg/database" +) + +type SQLiteRepository struct { + database.SQLRepository +} + +func NewSQLiteRepository(db *database.SQLite) SQLiteRepository { + return SQLiteRepository{ + database.NewSQLRepository(db), + } +} + +func (repo SQLiteRepository) Create(ctx context.Context, version int64) (int64, error) { + result, err := repo.DB(ctx).ExecContext( + ctx, + ` + INSERT INTO wookie ( + ver + ) + VALUES (?) + `, + version, + ) + if err != nil { + return -1, errors.Wrap(err, "error creating wookie") + } + id, err := result.LastInsertId() + if err != nil { + return -1, errors.Wrap(err, "error creating wookie") + } + return id, nil +} + +func (repo SQLiteRepository) GetById(ctx context.Context, id int64) (Model, error) { + var wookie Wookie + err := repo.DB(ctx).GetContext( + ctx, + &wookie, + ` + SELECT id, ver, createdAt + FROM wookie + WHERE + id = ? + `, + id, + ) + if err != nil { + return nil, errors.Wrap(err, "error getting wookie") + } + return &wookie, nil +} + +func (repo SQLiteRepository) GetLatest(ctx context.Context) (Model, error) { + var wookie Wookie + err := repo.DB(ctx).GetContext( + ctx, + &wookie, + ` + SELECT id, ver, createdAt + FROM wookie + WHERE + id = (SELECT MAX(id) FROM wookie) + `, + ) + if err != nil { + return nil, errors.Wrap(err, "error getting latest wookie") + } + return &wookie, nil +} diff --git a/pkg/authz/wookie/token.go b/pkg/authz/wookie/token.go new file mode 100644 index 00000000..95a804f6 --- /dev/null +++ b/pkg/authz/wookie/token.go @@ -0,0 +1,60 @@ +package wookie + +import ( + "encoding/base64" + "fmt" + "strconv" + "strings" + "time" + + "github.com/pkg/errors" +) + +// ClientToken context key +type ClientTokenKey struct{} + +type Token struct { + ID int64 + Version int64 + Timestamp time.Time +} + +// Get string representation of token (to set as header) +func (t Token) String() string { + s := fmt.Sprintf("%d;%d;%d", t.ID, t.Version, t.Timestamp.UnixMicro()) + return base64.StdEncoding.EncodeToString([]byte(s)) +} + +// De-serialize token from string (from header) +func FromString(wookieString string) (Token, error) { + if wookieString == "" { + return Token{}, errors.New("empty wookie string") + } + decodedStr, err := base64.StdEncoding.DecodeString(wookieString) + if err != nil { + return Token{}, errors.New("invalid wookie string") + } + parts := strings.Split(string(decodedStr), ";") + if len(parts) != 3 { + return Token{}, errors.New("invalid wookie string") + } + id, err := strconv.ParseInt(parts[0], 0, 64) + if err != nil { + return Token{}, errors.New("invalid id in wookie string") + } + version, err := strconv.ParseInt(parts[1], 0, 64) + if err != nil { + return Token{}, errors.New("invalid version in wookie string") + } + microTs, err := strconv.ParseInt(parts[2], 0, 64) + if err != nil { + return Token{}, errors.New("invalid timestamp in wookie string") + } + timestamp := time.UnixMicro(microTs) + + return Token{ + ID: id, + Version: version, + Timestamp: timestamp, + }, nil +} diff --git a/pkg/authz/wookie/token_test.go b/pkg/authz/wookie/token_test.go new file mode 100644 index 00000000..687635cd --- /dev/null +++ b/pkg/authz/wookie/token_test.go @@ -0,0 +1,68 @@ +package wookie + +import ( + "reflect" + "testing" + "time" +) + +func TestBasicSerialization(t *testing.T) { + ts := time.UnixMicro(1687375083854) + token := Token{ + ID: 25, + Version: 1, + Timestamp: ts, + } + tokenString := token.String() + expectedString := "MjU7MTsxNjg3Mzc1MDgzODU0" + if tokenString != expectedString { + t.Fatalf("expected token string: %s, actual token string: %s", expectedString, tokenString) + } + + deserializedToken, err := FromString(tokenString) + if err != nil { + t.Fatalf("unexpected error when deserializing token %v", err) + } + + if !reflect.DeepEqual(token, deserializedToken) { + t.Fatal("Deserialized token should be equal") + } +} + +func TestVariousInvalidTokenStrings(t *testing.T) { + invalidString := "" + _, err := FromString(invalidString) + if err == nil { + t.Fatalf("token string %s is invalid and should not deserialize", invalidString) + } + + invalidString = "***" + _, err = FromString(invalidString) + if err == nil { + t.Fatalf("token string %s is invalid and should not deserialize", invalidString) + } + + invalidString = "MjU7MQ==" + _, err = FromString(invalidString) + if err == nil { + t.Fatalf("token string %s is invalid and should not deserialize", invalidString) + } + + invalidString = "aGk7MTsxNjg3Mzc1MDgzODU0" + _, err = FromString(invalidString) + if err == nil { + t.Fatalf("token string %s is invalid and should not deserialize", invalidString) + } + + invalidString = "MjU7eW91OzE2ODczNzUwODM4NTQ=" + _, err = FromString(invalidString) + if err == nil { + t.Fatalf("token string %s is invalid and should not deserialize", invalidString) + } + + invalidString = "MjU7MTthc2Rm" + _, err = FromString(invalidString) + if err == nil { + t.Fatalf("token string %s is invalid and should not deserialize", invalidString) + } +}