Skip to content

Commit

Permalink
CBG-4117: Add type checking to mandatory audit fields during event va…
Browse files Browse the repository at this point in the history
…lidation (#7011)

* misc audit field type changes

* More audit field type corrections

* Read handlePutDbAuditConfig into bytes before unmarshalling so we can pass it into audit event as raw json

* Type check mandatory field values to ensure consistency with the defined event

* uint -> uint32 for audit ID
  • Loading branch information
bbrks authored Jul 25, 2024
1 parent 48d51d3 commit 36d1a81
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 13 deletions.
34 changes: 29 additions & 5 deletions base/audit_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package base

import (
"fmt"
"reflect"
"strconv"
)

Expand Down Expand Up @@ -73,7 +74,7 @@ const (
var fieldsByGroup = map[fieldGroup]map[string]any{
fieldGroupGlobal: {
AuditFieldTimestamp: "timestamp",
AuditFieldID: 123,
AuditFieldID: uint32(123),
AuditFieldName: "event name",
AuditFieldDescription: "event description",
},
Expand Down Expand Up @@ -150,7 +151,7 @@ func (ed *EventDescriptor) expandOptionalFieldGroups(groups []fieldGroup) {

func (i AuditID) MustValidateFields(f AuditFields) {
if err := i.ValidateFields(f); err != nil {
panic(fmt.Errorf("audit event %s(%s) invalid:\n%v", AuditEvents[i].Name, i, err))
panic(fmt.Errorf("audit event %q (%s) invalid:\n%v", i, AuditEvents[i].Name, err))
}
}

Expand All @@ -168,15 +169,38 @@ func (i AuditID) ValidateFields(f AuditFields) error {
func mandatoryFieldsPresent(fields, mandatoryFields AuditFields, baseName string) error {
me := &MultiError{}
for k, v := range mandatoryFields {
if _, ok := fields[k]; !ok {
me = me.Append(fmt.Errorf("missing mandatory field %s", baseName+k))
continue
}
if !matchingTypes(v, fields[k]) {
me = me.Append(fmt.Errorf("field value for %s%s must be of type %T but had %T", baseName, k, v, fields[k]))
continue
}
// recurse if map
if vv, ok := v.(map[string]any); ok {
if pv, ok := fields[k].(map[string]any); ok {
me = me.Append(mandatoryFieldsPresent(pv, vv, baseName+k+"."))
}
}
if _, ok := fields[k]; !ok {
me = me.Append(fmt.Errorf("missing mandatory field %s", baseName+k))
}
}
return me.ErrorOrNil()
}

// matchingTypes returns true if the types of a and b are the same.
func matchingTypes(a, b any) bool {
typeOfA, typeOfB := reflect.TypeOf(a), reflect.TypeOf(b)
if typeOfA == nil || typeOfB == nil {
return typeOfA == typeOfB
}
// deref
if typeOfA.Kind() == reflect.Pointer && typeOfB.Kind() != reflect.Pointer {
typeOfA = typeOfA.Elem()
} else if typeOfB.Kind() == reflect.Pointer && typeOfA.Kind() != reflect.Pointer {
typeOfB = typeOfB.Elem()
}
if typeOfA.ConvertibleTo(typeOfB) {
return true
}
return typeOfA.Kind() == typeOfB.Kind()
}
4 changes: 2 additions & 2 deletions base/logger_audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func expandFields(id AuditID, ctx context.Context, globalFields AuditFields, add
}

// static event data
fields[AuditFieldID] = uint64(id)
fields[AuditFieldID] = uint32(id)
fields[AuditFieldName] = AuditEvents[id].Name
fields[AuditFieldDescription] = AuditEvents[id].Description

Expand Down Expand Up @@ -89,7 +89,7 @@ func expandFields(id AuditID, ctx context.Context, globalFields AuditFields, add
}
}

fields[AuditFieldTimestamp] = time.Now()
fields[AuditFieldTimestamp] = time.Now().Format(time.RFC3339)

fields.merge(ctx, globalFields)
fields.merge(ctx, logCtx.RequestAdditionalAuditFields)
Expand Down
13 changes: 9 additions & 4 deletions rest/admin_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -793,9 +793,14 @@ func (h *handler) handleGetDbAuditConfig() error {
// PUT/POST audit config for database
func (h *handler) handlePutDbAuditConfig() error {

var body HandleDbAuditConfigBody
var bodyRaw []byte
err := h.mutateDbConfig(func(config *DbConfig) error {
if err := h.readJSONInto(&body); err != nil {
bodyRaw, err := h.readBody()
if err != nil {
return err
}
var body HandleDbAuditConfigBody
if err := base.JSONUnmarshal(bodyRaw, &body); err != nil {
return err
}

Expand Down Expand Up @@ -860,7 +865,7 @@ func (h *handler) handlePutDbAuditConfig() error {
}
base.Audit(h.ctx(), base.AuditIDAuditConfigChanged, base.AuditFields{
base.AuditFieldAuditScope: "db",
base.AuditFieldPayload: body,
base.AuditFieldPayload: string(bodyRaw),
})
return nil
}
Expand Down Expand Up @@ -2062,7 +2067,7 @@ func (h *handler) putReplication() error {
if err != nil {
return err
}
auditFields := base.AuditFields{base.AuditFieldReplicationID: replicationConfig.ID, base.AuditFieldPayload: body}
auditFields := base.AuditFields{base.AuditFieldReplicationID: replicationConfig.ID, base.AuditFieldPayload: string(body)}
if created {
h.writeStatus(http.StatusCreated, "Created")
base.Audit(h.ctx(), base.AuditIDISGRCreate, auditFields)
Expand Down
6 changes: 4 additions & 2 deletions rest/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,8 @@ func (h *handler) handlePostResync() error {

action := h.getQuery("action")
regenerateSequences, _ := h.getOptBoolQuery("regenerate_sequences", false)
reset := h.getBoolQuery("reset")

body, err := h.readBody()
if err != nil {
return err
Expand Down Expand Up @@ -353,7 +355,7 @@ func (h *handler) handlePostResync() error {
"database": h.db,
"regenerateSequences": regenerateSequences,
"collections": resyncPostReqBody.Scope,
"reset": h.getBoolQuery("reset"),
"reset": reset,
})
if err != nil {
return err
Expand All @@ -367,7 +369,7 @@ func (h *handler) handlePostResync() error {
base.Audit(h.ctx(), base.AuditIDDatabaseResyncStart, base.AuditFields{
"collections": resyncPostReqBody.Scope,
"regenerate_sequences": regenerateSequences,
"reset": h.getQuery("reset"),
"reset": reset,
})
} else {
dbState := atomic.LoadUint32(&h.db.State)
Expand Down

0 comments on commit 36d1a81

Please sign in to comment.