diff --git a/aws/service/sts/api/apiactions.go b/aws/service/sts/api/apiactions.go new file mode 100644 index 0000000..ada0855 --- /dev/null +++ b/aws/service/sts/api/apiactions.go @@ -0,0 +1,9 @@ +package api + +type STSOperation int + +//go:generate stringer -type=STSOperation $GOFILE +const ( + UnknownOperation STSOperation = iota + AssumeRoleWithWebIdentity +) \ No newline at end of file diff --git a/aws/service/sts/api/stsoperation_string.go b/aws/service/sts/api/stsoperation_string.go new file mode 100644 index 0000000..874e18b --- /dev/null +++ b/aws/service/sts/api/stsoperation_string.go @@ -0,0 +1,24 @@ +// Code generated by "stringer -type=STSOperation apiactions.go"; DO NOT EDIT. + +package api + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[UnknownOperation-0] + _ = x[AssumeRoleWithWebIdentity-1] +} + +const _STSOperation_name = "UnknownOperationAssumeRoleWithWebIdentity" + +var _STSOperation_index = [...]uint8{0, 16, 41} + +func (i STSOperation) String() string { + if i < 0 || i >= STSOperation(len(_STSOperation_index)-1) { + return "STSOperation(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _STSOperation_name[_STSOperation_index[i]:_STSOperation_index[i+1]] +} diff --git a/aws/service/sts/server.go b/aws/service/sts/server.go index 2aeb90c..1094162 100644 --- a/aws/service/sts/server.go +++ b/aws/service/sts/server.go @@ -12,6 +12,7 @@ import ( "github.com/VITObelgium/fakes3pp/aws/credentials" "github.com/VITObelgium/fakes3pp/aws/service" "github.com/VITObelgium/fakes3pp/aws/service/iam" + "github.com/VITObelgium/fakes3pp/aws/service/sts/api" "github.com/VITObelgium/fakes3pp/aws/service/sts/oidc" "github.com/VITObelgium/fakes3pp/aws/service/sts/session" "github.com/VITObelgium/fakes3pp/requestctx" @@ -170,9 +171,10 @@ type stsClaims map[string]interface{} // - RoleSessionName // - WebIdentityToken following the structure func (s *STSServer)assumeRoleWithWebIdentity(ctx context.Context, w http.ResponseWriter, r *http.Request) { - + requestctx.SetOperation(r, api.AssumeRoleWithWebIdentity) claims := stsClaims{} defer slog.InfoContext(ctx, "Auditlog", "claims", claims) + requestctx.AddAccessLogInfo(r, "sts", slog.Any("claims", claims)) token := r.Form.Get(stsWebIdentityToken) diff --git a/cmd/almost-e2e_test.go b/cmd/almost-e2e_test.go index c1b94c4..971a22c 100644 --- a/cmd/almost-e2e_test.go +++ b/cmd/almost-e2e_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/http" "testing" "time" @@ -599,3 +600,55 @@ func TestListingOfS3BucketHasExpectedObjects(t *testing.T) { assertObjectInBucketListing(t, listObjects, "team.txt") } +func TestAuditLogEntry(t *testing.T) { + tearDownProxy, getSignedToken, stsServer, s3Server := testingFixture(t) + defer tearDownProxy() + teardownLog, getCapturedStructuredLogEntries := testutils.CaptureStructuredLogsFixture(t, slog.LevelInfo, nil) + defer teardownLog() + + //GIVEN we run another test scenario + //_GIVEN token for team that does have access + token := getSignedToken("mySubject", time.Minute * 20, session.AWSSessionTags{PrincipalTags: map[string][]string{testTeamTag: {testAllowedTeam}}}) + creds := getCredentialsFromTestStsProxy(t, token, "my-session", testPolicyAllowTeamFolderARN, stsServer) + + //_WHEN access is attempted that required the team information + content, err := getTestBucketObjectContent(t, testRegion1, testTeamFile, credentials.FromAwsFormat(creds), s3Server) + + //_THEN the file content should be returned + if err != nil { + t.Errorf("Could not get team file even though part of correct team. got %s", err) + } + expectedContent := "teamSecret123" + if content != expectedContent { + t.Errorf("Got %s, expected %s", content, expectedContent) + } + + //WHEN we get the logs + logEntries := getCapturedStructuredLogEntries() + //THEN we have 1 access log entry per service (sts & s3) + accesslogEntries := logEntries.GetEntriesWithMsg(t, "Request end") + if len(accesslogEntries) != 2 { + t.Errorf("Invalid number of access log entries. Expected 2 got: %d", len(accesslogEntries)) + } + + //WHEN we check the s3 auditlog entry + s3Entry := accesslogEntries.GetEntriesContainingField(t, "s3")[0] + //Then the operation should be GetObject + operation := s3Entry.GetStringField(t, "Operation") + if operation != "GetObject" { + t.Errorf("Wrong operation present in s3 access log. Expected GetObject got %s", operation) + } + if s3Entry.GetFloat64(t, "HTTP status") != 200 { + t.Error("HTTPS status should have been a 200") + } + + //WHEN we check the sts audit log entry + stsEntry := accesslogEntries.GetEntriesContainingField(t, "sts")[0] + //Then the operation should be AssumeRoleWithWebIdentity + operation = stsEntry.GetStringField(t, "Operation") + if operation != "AssumeRoleWithWebIdentity" { + t.Errorf("Wrong operation present in sts access log. Expected AssumeRoleWithWebIdentity got %s", operation) + + } + +} diff --git a/middleware/observability.go b/middleware/observability.go index 108f815..52ddd91 100644 --- a/middleware/observability.go +++ b/middleware/observability.go @@ -126,6 +126,12 @@ func logFinalRequestDetails(ctx context.Context, lvl slog.Level, startTime time. requestLogAttrs := getRequestCtxLogAttrs(rCtx) requestLogAttrs = append(requestLogAttrs, slog.Int64("Total ms", time.Since(startTime).Milliseconds())) requestLogAttrs = append(requestLogAttrs, slog.Uint64("Bytes sent", rCtx.BytesSent)) + requestLogAttrs = append(requestLogAttrs, slog.Uint64("Bytes received", rCtx.BytesReceived)) + operation := "unknown" + if rCtx.Operation != nil { + operation = rCtx.Operation.String() + } + requestLogAttrs = append(requestLogAttrs, slog.String("Operation", operation)) requestLogAttrs = append(requestLogAttrs, slog.Int("HTTP status", rCtx.HTTPStatus)) requestLogAttrs = append(requestLogAttrs, rCtx.GetAccessLogInfo()...) slog.LogAttrs( @@ -138,7 +144,7 @@ func logFinalRequestDetails(ctx context.Context, lvl slog.Level, startTime time. func getRequestCtxLogAttrs(r *requestctx.RequestCtx) (logAttrs []slog.Attr) { - logAttrs = append(logAttrs, slog.Time("Time", r.Time)) + logAttrs = append(logAttrs, slog.Time("StartTime", r.Time)) logAttrs = append(logAttrs, slog.String("RemoteIP", r.RemoteIP)) logAttrs = append(logAttrs, slog.String("RequestURI", r.RequestURI)) logAttrs = append(logAttrs, slog.String("Referer", r.Referer)) diff --git a/testutils/loghelpers.go b/testutils/loghelpers.go index 6f239cf..8ad3ae9 100644 --- a/testutils/loghelpers.go +++ b/testutils/loghelpers.go @@ -2,6 +2,7 @@ package testutils import ( "bytes" + "encoding/json" "io" "log/slog" "testing" @@ -50,4 +51,87 @@ func CaptureLogFixture(tb testing.TB, lvl slog.Level, fe logging.ForceEnabler) ( } return teardown, getCapturedLogLines -} \ No newline at end of file +} + +//A fixture to capture structured logs +func CaptureStructuredLogsFixture (tb testing.TB, lvl slog.Level, fe logging.ForceEnabler) (teardown func()(), getCapturedLogEntries func()(StructuredLogEntries)) { + teardown, getCapturedLogLines := CaptureLogFixture(tb, lvl, fe) + + getCapturedLogEntries = func() (StructuredLogEntries) { + capturedEntries := StructuredLogEntries{} + for _, line := range getCapturedLogLines() { + entry := StructuredLogEntry{} + err := json.Unmarshal([]byte(line), &entry) + if err != nil { + tb.Errorf("could not convert %s to structured logging entry", line) + tb.Fail() + } else { + capturedEntries = append(capturedEntries, entry) + } + } + return capturedEntries + } + return teardown, getCapturedLogEntries +} + +type StructuredLogEntry map[string]any +type StructuredLogEntries []StructuredLogEntry + +func (s StructuredLogEntry) GetStringField(t testing.TB, fieldName string) (string) { + fieldValue, ok := s[fieldName] + if ok { + stringValue, ok := fieldValue.(string) + if ok { + return stringValue + } + t.Errorf("field %s is not a string", fieldName) + } + t.Errorf("field %s is not present", fieldName) + t.FailNow() + return "" +} + +//Default choice by a JSON unmarshaller for a number +func (s StructuredLogEntry) GetFloat64(t testing.TB, fieldName string) (float64) { + fieldValue, ok := s[fieldName] + if ok { + floatValue, ok := fieldValue.(float64) + if ok { + return floatValue + } + t.Errorf("field %s is not a number", fieldName) + } + t.Errorf("field %s is not present", fieldName) + t.FailNow() + return 0.0 +} + +func (s StructuredLogEntry) GetLevel(t testing.TB) (string) { + return s.GetStringField(t, "level") +} + +func (s StructuredLogEntry) GetMsg(t testing.TB) (string) { + return s.GetStringField(t, "msg") +} + +func (s *StructuredLogEntries) GetEntriesWithMsg(t testing.TB, msgValue string) (StructuredLogEntries) { + filteredEntries := StructuredLogEntries{} + for _, entry := range *s { + msg := entry.GetMsg(t) + if msg == msgValue { + filteredEntries = append(filteredEntries, entry) + } + } + return filteredEntries +} + +func (s *StructuredLogEntries) GetEntriesContainingField(t testing.TB, fieldName string) (StructuredLogEntries) { + filteredEntries := StructuredLogEntries{} + for _, entry := range *s { + _, ok := entry[fieldName] + if ok { + filteredEntries = append(filteredEntries, entry) + } + } + return filteredEntries +}