Skip to content

Commit

Permalink
tests: add coverage for request log
Browse files Browse the repository at this point in the history
  • Loading branch information
Peter Van Bouwel committed Feb 8, 2025
1 parent 5ebf19d commit 11a7068
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 3 deletions.
9 changes: 9 additions & 0 deletions aws/service/sts/api/apiactions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package api

type STSOperation int

//go:generate stringer -type=STSOperation $GOFILE
const (
UnknownOperation STSOperation = iota
AssumeRoleWithWebIdentity
)
24 changes: 24 additions & 0 deletions aws/service/sts/api/stsoperation_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion aws/service/sts/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/VITObelgium/fakes3pp/aws/credentials"
"github.com/VITObelgium/fakes3pp/aws/service"
"github.com/VITObelgium/fakes3pp/aws/service/iam"
"github.com/VITObelgium/fakes3pp/aws/service/sts/api"
"github.com/VITObelgium/fakes3pp/aws/service/sts/oidc"
"github.com/VITObelgium/fakes3pp/aws/service/sts/session"
"github.com/VITObelgium/fakes3pp/requestctx"
Expand Down Expand Up @@ -170,9 +171,10 @@ type stsClaims map[string]interface{}
// - RoleSessionName
// - WebIdentityToken following the structure
func (s *STSServer)assumeRoleWithWebIdentity(ctx context.Context, w http.ResponseWriter, r *http.Request) {

requestctx.SetOperation(r, api.AssumeRoleWithWebIdentity)
claims := stsClaims{}
defer slog.InfoContext(ctx, "Auditlog", "claims", claims)
requestctx.AddAccessLogInfo(r, "sts", slog.Any("claims", claims))

token := r.Form.Get(stsWebIdentityToken)

Expand Down
53 changes: 53 additions & 0 deletions cmd/almost-e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"testing"
"time"
Expand Down Expand Up @@ -599,3 +600,55 @@ func TestListingOfS3BucketHasExpectedObjects(t *testing.T) {
assertObjectInBucketListing(t, listObjects, "team.txt")
}

func TestAuditLogEntry(t *testing.T) {
tearDownProxy, getSignedToken, stsServer, s3Server := testingFixture(t)
defer tearDownProxy()
teardownLog, getCapturedStructuredLogEntries := testutils.CaptureStructuredLogsFixture(t, slog.LevelInfo, nil)
defer teardownLog()

//GIVEN we run another test scenario
//_GIVEN token for team that does have access
token := getSignedToken("mySubject", time.Minute * 20, session.AWSSessionTags{PrincipalTags: map[string][]string{testTeamTag: {testAllowedTeam}}})
creds := getCredentialsFromTestStsProxy(t, token, "my-session", testPolicyAllowTeamFolderARN, stsServer)

//_WHEN access is attempted that required the team information
content, err := getTestBucketObjectContent(t, testRegion1, testTeamFile, credentials.FromAwsFormat(creds), s3Server)

//_THEN the file content should be returned
if err != nil {
t.Errorf("Could not get team file even though part of correct team. got %s", err)
}
expectedContent := "teamSecret123"
if content != expectedContent {
t.Errorf("Got %s, expected %s", content, expectedContent)
}

//WHEN we get the logs
logEntries := getCapturedStructuredLogEntries()
//THEN we have 1 access log entry per service (sts & s3)
accesslogEntries := logEntries.GetEntriesWithMsg(t, "Request end")
if len(accesslogEntries) != 2 {
t.Errorf("Invalid number of access log entries. Expected 2 got: %d", len(accesslogEntries))
}

//WHEN we check the s3 auditlog entry
s3Entry := accesslogEntries.GetEntriesContainingField(t, "s3")[0]
//Then the operation should be GetObject
operation := s3Entry.GetStringField(t, "Operation")
if operation != "GetObject" {
t.Errorf("Wrong operation present in s3 access log. Expected GetObject got %s", operation)
}
if s3Entry.GetFloat64(t, "HTTP status") != 200 {
t.Error("HTTPS status should have been a 200")
}

//WHEN we check the sts audit log entry
stsEntry := accesslogEntries.GetEntriesContainingField(t, "sts")[0]
//Then the operation should be AssumeRoleWithWebIdentity
operation = stsEntry.GetStringField(t, "Operation")
if operation != "AssumeRoleWithWebIdentity" {
t.Errorf("Wrong operation present in sts access log. Expected AssumeRoleWithWebIdentity got %s", operation)

}

}
8 changes: 7 additions & 1 deletion middleware/observability.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,12 @@ func logFinalRequestDetails(ctx context.Context, lvl slog.Level, startTime time.
requestLogAttrs := getRequestCtxLogAttrs(rCtx)
requestLogAttrs = append(requestLogAttrs, slog.Int64("Total ms", time.Since(startTime).Milliseconds()))
requestLogAttrs = append(requestLogAttrs, slog.Uint64("Bytes sent", rCtx.BytesSent))
requestLogAttrs = append(requestLogAttrs, slog.Uint64("Bytes received", rCtx.BytesReceived))
operation := "unknown"
if rCtx.Operation != nil {
operation = rCtx.Operation.String()
}
requestLogAttrs = append(requestLogAttrs, slog.String("Operation", operation))
requestLogAttrs = append(requestLogAttrs, slog.Int("HTTP status", rCtx.HTTPStatus))
requestLogAttrs = append(requestLogAttrs, rCtx.GetAccessLogInfo()...)
slog.LogAttrs(
Expand All @@ -138,7 +144,7 @@ func logFinalRequestDetails(ctx context.Context, lvl slog.Level, startTime time.


func getRequestCtxLogAttrs(r *requestctx.RequestCtx) (logAttrs []slog.Attr) {
logAttrs = append(logAttrs, slog.Time("Time", r.Time))
logAttrs = append(logAttrs, slog.Time("StartTime", r.Time))
logAttrs = append(logAttrs, slog.String("RemoteIP", r.RemoteIP))
logAttrs = append(logAttrs, slog.String("RequestURI", r.RequestURI))
logAttrs = append(logAttrs, slog.String("Referer", r.Referer))
Expand Down
86 changes: 85 additions & 1 deletion testutils/loghelpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package testutils

import (
"bytes"
"encoding/json"
"io"
"log/slog"
"testing"
Expand Down Expand Up @@ -50,4 +51,87 @@ func CaptureLogFixture(tb testing.TB, lvl slog.Level, fe logging.ForceEnabler) (
}

return teardown, getCapturedLogLines
}
}

//A fixture to capture structured logs
func CaptureStructuredLogsFixture (tb testing.TB, lvl slog.Level, fe logging.ForceEnabler) (teardown func()(), getCapturedLogEntries func()(StructuredLogEntries)) {
teardown, getCapturedLogLines := CaptureLogFixture(tb, lvl, fe)

getCapturedLogEntries = func() (StructuredLogEntries) {
capturedEntries := StructuredLogEntries{}
for _, line := range getCapturedLogLines() {
entry := StructuredLogEntry{}
err := json.Unmarshal([]byte(line), &entry)
if err != nil {
tb.Errorf("could not convert %s to structured logging entry", line)
tb.Fail()
} else {
capturedEntries = append(capturedEntries, entry)
}
}
return capturedEntries
}
return teardown, getCapturedLogEntries
}

type StructuredLogEntry map[string]any
type StructuredLogEntries []StructuredLogEntry

func (s StructuredLogEntry) GetStringField(t testing.TB, fieldName string) (string) {
fieldValue, ok := s[fieldName]
if ok {
stringValue, ok := fieldValue.(string)
if ok {
return stringValue
}
t.Errorf("field %s is not a string", fieldName)
}
t.Errorf("field %s is not present", fieldName)
t.FailNow()
return ""
}

//Default choice by a JSON unmarshaller for a number
func (s StructuredLogEntry) GetFloat64(t testing.TB, fieldName string) (float64) {
fieldValue, ok := s[fieldName]
if ok {
floatValue, ok := fieldValue.(float64)
if ok {
return floatValue
}
t.Errorf("field %s is not a number", fieldName)
}
t.Errorf("field %s is not present", fieldName)
t.FailNow()
return 0.0
}

func (s StructuredLogEntry) GetLevel(t testing.TB) (string) {
return s.GetStringField(t, "level")
}

func (s StructuredLogEntry) GetMsg(t testing.TB) (string) {
return s.GetStringField(t, "msg")
}

func (s *StructuredLogEntries) GetEntriesWithMsg(t testing.TB, msgValue string) (StructuredLogEntries) {
filteredEntries := StructuredLogEntries{}
for _, entry := range *s {
msg := entry.GetMsg(t)
if msg == msgValue {
filteredEntries = append(filteredEntries, entry)
}
}
return filteredEntries
}

func (s *StructuredLogEntries) GetEntriesContainingField(t testing.TB, fieldName string) (StructuredLogEntries) {
filteredEntries := StructuredLogEntries{}
for _, entry := range *s {
_, ok := entry[fieldName]
if ok {
filteredEntries = append(filteredEntries, entry)
}
}
return filteredEntries
}

0 comments on commit 11a7068

Please sign in to comment.