Skip to content

Commit

Permalink
Change checkSvc to use warrantSvc instead of warrantRepo (#240)
Browse files Browse the repository at this point in the history
  • Loading branch information
akajla09 authored Oct 3, 2023
1 parent 183dd0c commit f48f881
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 591 deletions.
2 changes: 1 addition & 1 deletion cmd/warrant/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ func main() {
warrantSvc := warrant.NewService(svcEnv, warrantRepository, eventSvc, objectTypeSvc, objectSvc)

// Init check service
checkSvc := check.NewService(svcEnv, warrantRepository, eventSvc, objectTypeSvc, cfg.Check, nil)
checkSvc := check.NewService(svcEnv, warrantSvc, eventSvc, objectTypeSvc, cfg.Check, nil)

// Init query service
querySvc := query.NewService(svcEnv, objectTypeSvc, warrantSvc, objectSvc)
Expand Down
103 changes: 62 additions & 41 deletions pkg/authz/check/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@ import (
"github.com/warrant-dev/warrant/pkg/wookie"
)

const (
MaxWarrants = 5000
)

type CheckContextFunc func(ctx context.Context) (context.Context, error)

type CheckService struct {
service.BaseService
WarrantRepository warrant.WarrantRepository
warrantSvc warrant.Service
EventSvc event.Service
ObjectTypeSvc objecttype.Service
CheckConfig *config.CheckConfig
Expand All @@ -48,10 +52,10 @@ func defaultCreateCheckContext(ctx context.Context) (context.Context, error) {
return checkCtx, nil
}

func NewService(env service.Env, warrantRepo warrant.WarrantRepository, eventSvc event.Service, objectTypeSvc objecttype.Service, checkConfig *config.CheckConfig, checkContext CheckContextFunc) *CheckService {
func NewService(env service.Env, warrantSvc warrant.Service, eventSvc event.Service, objectTypeSvc objecttype.Service, checkConfig *config.CheckConfig, checkContext CheckContextFunc) *CheckService {
svc := &CheckService{
BaseService: service.NewBaseService(env),
WarrantRepository: warrantRepo,
warrantSvc: warrantSvc,
EventSvc: eventSvc,
ObjectTypeSvc: objectTypeSvc,
CheckConfig: checkConfig,
Expand All @@ -69,22 +73,35 @@ func (svc CheckService) getWithPolicyMatch(ctx context.Context, checkPipeline *p
checkPipeline.AcquireServiceLock()
defer checkPipeline.ReleaseServiceLock()

warrants, err := svc.WarrantRepository.GetAllMatchingObjectRelationAndSubject(ctx, spec.ObjectType, spec.ObjectId, spec.Relation, spec.Subject.ObjectType, spec.Subject.ObjectId, spec.Subject.Relation)
if err != nil || len(warrants) == 0 {
listParams := service.DefaultListParams(warrant.WarrantListParamParser{})
listParams.Limit = MaxWarrants
warrantSpecs, err := svc.warrantSvc.List(
ctx,
&warrant.FilterParams{
ObjectType: []string{spec.ObjectType},
ObjectId: []string{spec.ObjectId},
Relation: []string{spec.Relation},
SubjectType: []string{spec.Subject.ObjectType},
SubjectId: []string{spec.Subject.ObjectId},
SubjectRelation: []string{spec.Subject.Relation},
},
listParams,
)
if err != nil || len(warrantSpecs) == 0 {
return nil, err
}

// if a warrant without a policy is found, match it
for _, warrant := range warrants {
if warrant.GetPolicy() == "" {
return warrant.ToWarrantSpec(), nil
for _, warrant := range warrantSpecs {
if warrant.Policy == "" {
return &warrant, nil
}
}

for _, warrant := range warrants {
if warrant.GetPolicy() != "" {
for _, warrant := range warrantSpecs {
if warrant.Policy != "" {
if policyMatched := evalWarrantPolicy(warrant, spec.Context); policyMatched {
return warrant.ToWarrantSpec(), nil
return &warrant, nil
}
}
}
Expand All @@ -106,31 +123,33 @@ func (svc CheckService) getMatchingSubjects(ctx context.Context, checkPipeline *
return warrantSpecs, nil
}

warrants, err := svc.WarrantRepository.GetAllMatchingObjectAndRelation(
listParams := service.DefaultListParams(warrant.WarrantListParamParser{})
listParams.Limit = MaxWarrants
warrantSpecs, err = svc.warrantSvc.List(
ctx,
objectType,
objectId,
relation,
&warrant.FilterParams{
ObjectType: []string{objectType},
ObjectId: []string{objectId},
Relation: []string{relation},
},
listParams,
)
if err != nil {
return warrantSpecs, err
}

for _, warrant := range warrants {
if warrant.GetPolicy() == "" {
warrantSpecs = append(warrantSpecs, *warrant.ToWarrantSpec())
matchingSpecs := make([]warrant.WarrantSpec, 0)
for _, warrant := range warrantSpecs {
if warrant.Policy == "" {
matchingSpecs = append(matchingSpecs, warrant)
} else {
if policyMatched := evalWarrantPolicy(warrant, checkCtx); policyMatched {
warrantSpecs = append(warrantSpecs, *warrant.ToWarrantSpec())
matchingSpecs = append(matchingSpecs, warrant)
}
}
}

if err != nil {
return warrantSpecs, err
}

return warrantSpecs, nil
return matchingSpecs, nil
}

func (svc CheckService) getMatchingSubjectsBySubjectType(ctx context.Context, checkPipeline *pipeline, objectType string,
Expand All @@ -148,32 +167,34 @@ func (svc CheckService) getMatchingSubjectsBySubjectType(ctx context.Context, ch
return warrantSpecs, nil
}

warrants, err := svc.WarrantRepository.GetAllMatchingObjectAndRelationBySubjectType(
listParams := service.DefaultListParams(warrant.WarrantListParamParser{})
listParams.Limit = MaxWarrants
warrantSpecs, err = svc.warrantSvc.List(
ctx,
objectType,
objectId,
relation,
subjectType,
&warrant.FilterParams{
ObjectType: []string{objectType},
ObjectId: []string{objectId},
Relation: []string{relation},
SubjectType: []string{subjectType},
},
listParams,
)
if err != nil {
return warrantSpecs, err
}

for _, warrant := range warrants {
if warrant.GetPolicy() == "" {
warrantSpecs = append(warrantSpecs, *warrant.ToWarrantSpec())
matchingSpecs := make([]warrant.WarrantSpec, 0)
for _, warrant := range warrantSpecs {
if warrant.Policy == "" {
matchingSpecs = append(matchingSpecs, warrant)
} else {
if policyMatched := evalWarrantPolicy(warrant, checkCtx); policyMatched {
warrantSpecs = append(warrantSpecs, *warrant.ToWarrantSpec())
matchingSpecs = append(matchingSpecs, warrant)
}
}
}

if err != nil {
return warrantSpecs, err
}

return warrantSpecs, nil
return matchingSpecs, nil
}

func (svc CheckService) CheckMany(ctx context.Context, authInfo *service.AuthInfo, warrantCheck *CheckManySpec) (*CheckResultSpec, error) {
Expand Down Expand Up @@ -713,16 +734,16 @@ func (p *pipeline) execTasks(ctx context.Context, parentResultC chan<- result, t
}
}

func evalWarrantPolicy(w warrant.Model, policyCtx warrant.PolicyContext) bool {
func evalWarrantPolicy(w warrant.WarrantSpec, policyCtx warrant.PolicyContext) bool {
policyCtxWithWarrant := make(warrant.PolicyContext)
for k, v := range policyCtx {
policyCtxWithWarrant[k] = v
}
policyCtxWithWarrant["warrant"] = w

policyMatched, err := w.GetPolicy().Eval(policyCtxWithWarrant)
policyMatched, err := w.Policy.Eval(policyCtxWithWarrant)
if err != nil {
log.Err(err).Msgf("check: error while evaluating policy %s", w.GetPolicy())
log.Err(err).Msgf("check: error while evaluating policy %s", w.Policy)
return false
}

Expand Down
184 changes: 0 additions & 184 deletions pkg/authz/warrant/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,72 +120,6 @@ func (repo MySQLRepository) Delete(ctx context.Context, objectType string, objec
return nil
}

func (repo MySQLRepository) GetAllMatchingObject(ctx context.Context, objectType string, objectId string) ([]Model, error) {
models := make([]Model, 0)
warrants := make([]Warrant, 0)
err := repo.DB.SelectContext(
ctx,
&warrants,
`
SELECT id, objectType, objectId, relation, subjectType, subjectId, subjectRelation, policy, createdAt, updatedAt, deletedAt
FROM warrant
WHERE
objectType = ? AND
objectId = ? AND
deletedAt IS NULL
`,
objectType,
objectId,
)
if err != nil {
switch err {
case sql.ErrNoRows:
return models, nil
default:
return models, errors.Wrapf(err, "error deleting warrants with object %s:%s", objectType, objectId)
}
}

for i := range warrants {
models = append(models, &warrants[i])
}

return models, nil
}

func (repo MySQLRepository) GetAllMatchingSubject(ctx context.Context, subjectType string, subjectId string) ([]Model, error) {
models := make([]Model, 0)
warrants := make([]Warrant, 0)
err := repo.DB.SelectContext(
ctx,
&warrants,
`
SELECT id, objectType, objectId, relation, subjectType, subjectId, subjectRelation, policy, createdAt, updatedAt, deletedAt
FROM warrant
WHERE
subjectType = ? AND
subjectId = ? AND
deletedAt IS NULL
`,
subjectType,
subjectId,
)
if err != nil {
switch err {
case sql.ErrNoRows:
return models, nil
default:
return models, errors.Wrapf(err, "error deleting warrants with subject %s:%s", subjectType, subjectId)
}
}

for i := range warrants {
models = append(models, &warrants[i])
}

return models, nil
}

func (repo MySQLRepository) Get(ctx context.Context, objectType string, objectId string, relation string, subjectType string, subjectId string, subjectRelation string, policyHash string) (Model, error) {
var warrant Warrant
err := repo.DB.GetContext(
Expand Down Expand Up @@ -458,121 +392,3 @@ func (repo MySQLRepository) List(ctx context.Context, filterParams *FilterParams

return models, nil
}

func (repo MySQLRepository) GetAllMatchingObjectRelationAndSubject(ctx context.Context, objectType string, objectId string, relation string, subjectType string, subjectId string, subjectRelation string) ([]Model, error) {
models := make([]Model, 0)
warrants := make([]Warrant, 0)
err := repo.DB.SelectContext(
ctx,
&warrants,
`
SELECT id, objectType, objectId, relation, subjectType, subjectId, subjectRelation, policy, createdAt, updatedAt, deletedAt
FROM warrant
WHERE
objectType = ? AND
(objectId = ? OR objectId = ?) AND
relation = ? AND
subjectType = ? AND
subjectId = ? AND
subjectRelation = ? AND
deletedAt IS NULL
`,
objectType,
objectId,
Wildcard,
relation,
subjectType,
subjectId,
subjectRelation,
)
if err != nil {
switch err {
case sql.ErrNoRows:
return models, nil
default:
return nil, errors.Wrapf(err, "error getting warrants with object type %s, object id %s, relation %s, subject type %s, subject id %s, and subject relation %s", objectType, objectId, relation, subjectType, subjectId, subjectRelation)
}
}

for i := range warrants {
models = append(models, &warrants[i])
}

return models, nil
}

func (repo MySQLRepository) GetAllMatchingObjectAndRelation(ctx context.Context, objectType string, objectId string, relation string) ([]Model, error) {
models := make([]Model, 0)
warrants := make([]Warrant, 0)
err := repo.DB.SelectContext(
ctx,
&warrants,
`
SELECT id, objectType, objectId, relation, subjectType, subjectId, subjectRelation, policy, createdAt, updatedAt, deletedAt
FROM warrant
WHERE
objectType = ? AND
(objectId = ? OR objectId = ?) AND
relation = ? AND
deletedAt IS NULL
ORDER BY createdAt DESC, id DESC
`,
objectType,
objectId,
Wildcard,
relation,
)
if err != nil {
switch err {
case sql.ErrNoRows:
return models, nil
default:
return nil, errors.Wrapf(err, "error getting warrants with object type %s, object id %s, and relation %s", objectType, objectId, relation)
}
}

for i := range warrants {
models = append(models, &warrants[i])
}

return models, nil
}

func (repo MySQLRepository) GetAllMatchingObjectAndRelationBySubjectType(ctx context.Context, objectType string, objectId string, relation string, subjectType string) ([]Model, error) {
models := make([]Model, 0)
warrants := make([]Warrant, 0)
err := repo.DB.SelectContext(
ctx,
&warrants,
`
SELECT id, objectType, objectId, relation, subjectType, subjectId, subjectRelation, policy, createdAt, updatedAt, deletedAt
FROM warrant
WHERE
objectType = ? AND
(objectId = ? OR objectId = ?) AND
relation = ? AND
subjectType = ? AND
deletedAt IS NULL
ORDER BY createdAt DESC, id DESC
`,
objectType,
objectId,
Wildcard,
relation,
subjectType,
)
if err != nil {
switch err {
case sql.ErrNoRows:
return models, nil
default:
return nil, errors.Wrapf(err, "error getting warrants with object type %s, object id %s, and relation %s", objectType, objectId, relation)
}
}

for i := range warrants {
models = append(models, &warrants[i])
}

return models, nil
}
Loading

0 comments on commit f48f881

Please sign in to comment.