diff --git a/v1/ast/annotations.go b/v1/ast/annotations.go index 9484535dec..5c817fb3e9 100644 --- a/v1/ast/annotations.go +++ b/v1/ast/annotations.go @@ -8,7 +8,7 @@ import ( "encoding/json" "fmt" "net/url" - "sort" + "slices" "strings" "github.com/open-policy-agent/opa/internal/deepcopy" @@ -18,12 +18,32 @@ import ( const ( annotationScopePackage = "package" - annotationScopeImport = "import" annotationScopeRule = "rule" annotationScopeDocument = "document" annotationScopeSubpackages = "subpackages" ) +var ( + scopeTerm = StringTerm("scope") + titleTerm = StringTerm("title") + entrypointTerm = StringTerm("entrypoint") + descriptionTerm = StringTerm("description") + organizationsTerm = StringTerm("organizations") + authorsTerm = StringTerm("authors") + relatedResourcesTerm = StringTerm("related_resources") + schemasTerm = StringTerm("schemas") + customTerm = StringTerm("custom") + refTerm = StringTerm("ref") + nameTerm = StringTerm("name") + emailTerm = StringTerm("email") + schemaTerm = StringTerm("schema") + definitionTerm = StringTerm("definition") + documentTerm = StringTerm(annotationScopeDocument) + packageTerm = StringTerm(annotationScopePackage) + ruleTerm = StringTerm(annotationScopeRule) + subpackagesTerm = StringTerm(annotationScopeSubpackages) +) + type ( // Annotations represents metadata attached to other AST nodes such as rules. Annotations struct { @@ -291,7 +311,6 @@ func (ar *AnnotationsRef) MarshalJSON() ([]byte, error) { } func scopeCompare(s1, s2 string) int { - o1 := scopeOrder(s1) o2 := scopeOrder(s2) @@ -342,7 +361,7 @@ func compareRelatedResources(a, b []*RelatedResourceAnnotation) int { } for i := range a { - if cmp := strings.Compare(a[i].String(), b[i].String()); cmp != 0 { + if cmp := a[i].Compare(b[i]); cmp != 0 { return cmp } } @@ -409,7 +428,9 @@ func (a *Annotations) Copy(node Node) *Annotations { cpy.Schemas[i] = a.Schemas[i].Copy() } - cpy.Custom = deepcopy.Map(a.Custom) + if a.Custom != nil { + cpy.Custom = deepcopy.Map(a.Custom) + } cpy.node = node @@ -425,19 +446,30 @@ func (a *Annotations) toObject() (*Object, *Error) { } if len(a.Scope) > 0 { - obj.Insert(StringTerm("scope"), StringTerm(a.Scope)) + switch a.Scope { + case annotationScopeDocument: + obj.Insert(scopeTerm, documentTerm) + case annotationScopePackage: + obj.Insert(scopeTerm, packageTerm) + case annotationScopeRule: + obj.Insert(scopeTerm, ruleTerm) + case annotationScopeSubpackages: + obj.Insert(scopeTerm, subpackagesTerm) + default: + obj.Insert(scopeTerm, StringTerm(a.Scope)) + } } if len(a.Title) > 0 { - obj.Insert(StringTerm("title"), StringTerm(a.Title)) + obj.Insert(titleTerm, StringTerm(a.Title)) } if a.Entrypoint { - obj.Insert(StringTerm("entrypoint"), BooleanTerm(true)) + obj.Insert(entrypointTerm, InternedBooleanTerm(true)) } if len(a.Description) > 0 { - obj.Insert(StringTerm("description"), StringTerm(a.Description)) + obj.Insert(descriptionTerm, StringTerm(a.Description)) } if len(a.Organizations) > 0 { @@ -445,19 +477,19 @@ func (a *Annotations) toObject() (*Object, *Error) { for _, org := range a.Organizations { orgs = append(orgs, StringTerm(org)) } - obj.Insert(StringTerm("organizations"), ArrayTerm(orgs...)) + obj.Insert(organizationsTerm, ArrayTerm(orgs...)) } if len(a.RelatedResources) > 0 { rrs := make([]*Term, 0, len(a.RelatedResources)) for _, rr := range a.RelatedResources { - rrObj := NewObject(Item(StringTerm("ref"), StringTerm(rr.Ref.String()))) + rrObj := NewObject(Item(refTerm, StringTerm(rr.Ref.String()))) if len(rr.Description) > 0 { - rrObj.Insert(StringTerm("description"), StringTerm(rr.Description)) + rrObj.Insert(descriptionTerm, StringTerm(rr.Description)) } rrs = append(rrs, NewTerm(rrObj)) } - obj.Insert(StringTerm("related_resources"), ArrayTerm(rrs...)) + obj.Insert(relatedResourcesTerm, ArrayTerm(rrs...)) } if len(a.Authors) > 0 { @@ -465,14 +497,14 @@ func (a *Annotations) toObject() (*Object, *Error) { for _, author := range a.Authors { aObj := NewObject() if len(author.Name) > 0 { - aObj.Insert(StringTerm("name"), StringTerm(author.Name)) + aObj.Insert(nameTerm, StringTerm(author.Name)) } if len(author.Email) > 0 { - aObj.Insert(StringTerm("email"), StringTerm(author.Email)) + aObj.Insert(emailTerm, StringTerm(author.Email)) } as = append(as, NewTerm(aObj)) } - obj.Insert(StringTerm("authors"), ArrayTerm(as...)) + obj.Insert(authorsTerm, ArrayTerm(as...)) } if len(a.Schemas) > 0 { @@ -480,21 +512,21 @@ func (a *Annotations) toObject() (*Object, *Error) { for _, s := range a.Schemas { sObj := NewObject() if len(s.Path) > 0 { - sObj.Insert(StringTerm("path"), NewTerm(s.Path.toArray())) + sObj.Insert(pathTerm, NewTerm(s.Path.toArray())) } if len(s.Schema) > 0 { - sObj.Insert(StringTerm("schema"), NewTerm(s.Schema.toArray())) + sObj.Insert(schemaTerm, NewTerm(s.Schema.toArray())) } if s.Definition != nil { def, err := InterfaceToValue(s.Definition) if err != nil { return nil, NewError(CompileErr, a.Location, "invalid definition in schema annotation: %s", err.Error()) } - sObj.Insert(StringTerm("definition"), NewTerm(def)) + sObj.Insert(definitionTerm, NewTerm(def)) } ss = append(ss, NewTerm(sObj)) } - obj.Insert(StringTerm("schemas"), ArrayTerm(ss...)) + obj.Insert(schemasTerm, ArrayTerm(ss...)) } if len(a.Custom) > 0 { @@ -502,7 +534,7 @@ func (a *Annotations) toObject() (*Object, *Error) { if err != nil { return nil, NewError(CompileErr, a.Location, "invalid custom annotation %s", err.Error()) } - obj.Insert(StringTerm("custom"), NewTerm(c)) + obj.Insert(customTerm, NewTerm(c)) } return &obj, nil @@ -563,7 +595,11 @@ func attachAnnotationsNodes(mod *Module) Errors { case *Package: a.Scope = annotationScopePackage case *Import: - a.Scope = annotationScopeImport + // Note that this isn't a valid scope, but set here so that the + // validate function called below can print an error message with + // a context that makes sense ("invalid scope: 'import'" instead of + // "invalid scope: '') + a.Scope = "import" } } @@ -681,7 +717,6 @@ func (s *SchemaAnnotation) Copy() *SchemaAnnotation { // Compare returns an integer indicating if s is less than, equal to, or greater // than other. func (s *SchemaAnnotation) Compare(other *SchemaAnnotation) int { - if cmp := s.Path.Compare(other.Path); cmp != 0 { return cmp } @@ -819,9 +854,7 @@ func (as *AnnotationSet) Flatten() FlatAnnotationsRefSet { } // Sort by path, then annotation location, for stable output - sort.SliceStable(refs, func(i, j int) bool { - return refs[i].Compare(refs[j]) < 0 - }) + slices.SortStableFunc(refs, (*AnnotationsRef).Compare) return refs } @@ -853,8 +886,8 @@ func (as *AnnotationSet) Chain(rule *Rule) AnnotationsRefSet { if len(refs) > 1 { // Sort by annotation location; chain must start with annotations declared closest to rule, then going outward - sort.SliceStable(refs, func(i, j int) bool { - return refs[i].Annotations.Location.Compare(refs[j].Annotations.Location) > 0 + slices.SortStableFunc(refs, func(a, b *AnnotationsRef) int { + return -a.Annotations.Location.Compare(b.Annotations.Location) }) }