diff --git a/go.mod b/go.mod index d339125..a8b5980 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/tools v0.13.0 // indirect + golang.org/x/tools v0.16.1 // indirect ) require ( @@ -63,11 +63,11 @@ require ( go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.25.0 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect - golang.org/x/net v0.15.0 // indirect + golang.org/x/net v0.19.0 // indirect golang.org/x/oauth2 v0.12.0 // indirect - golang.org/x/sys v0.12.0 // indirect - golang.org/x/term v0.12.0 // indirect - golang.org/x/text v0.13.0 // indirect + golang.org/x/sys v0.15.0 // indirect + golang.org/x/term v0.15.0 // indirect + golang.org/x/text v0.14.0 // indirect golang.org/x/time v0.3.0 // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect google.golang.org/appengine v1.6.8 // indirect diff --git a/go.sum b/go.sum index e28d602..8dec186 100644 --- a/go.sum +++ b/go.sum @@ -168,8 +168,6 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= -golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -181,8 +179,8 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= -golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= +golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.12.0 h1:smVPGxink+n1ZI5pkQa8y6fZT0RW0MgCO5bFpepy4B4= golang.org/x/oauth2 v0.12.0/go.mod h1:A74bZ3aGXgCY0qaIC9Ahg6Lglin4AMAco8cIv9baba4= @@ -206,18 +204,18 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.12.0 h1:/ZfYdc3zq+q02Rv9vGqTeSItdzZTSNDmfTi0mBAuidU= -golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= +golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= +golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= -golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -230,8 +228,8 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ= -golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= +golang.org/x/tools v0.16.1 h1:TLyB3WofjdOEepBHAU20JdNC1Zbg87elYofWYAY5oZA= +golang.org/x/tools v0.16.1/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/main.go b/main.go index 04ab7cf..24dc585 100644 --- a/main.go +++ b/main.go @@ -235,13 +235,10 @@ func setupAuthenticationServer(listenAddress, tlsCertPath, tlsKeyPath, tlsCaPath srv := grpc.NewServer(grpcOpts...) - authenticator, err := auth.NewAuthenticator( + authenticator := auth.NewAuthenticator( setupLog.WithName("cerberus.authenticator"), ) - if err != nil { - setupLog.Error(err, "unable to create and update authenticator") - return nil, nil, nil, err - } + auth.RegisterServer(srv, authenticator) return listener, srv, authenticator, nil } diff --git a/pkg/auth/authenticator.go b/pkg/auth/authenticator.go index 8733379..480ba42 100644 --- a/pkg/auth/authenticator.go +++ b/pkg/auth/authenticator.go @@ -2,12 +2,8 @@ package auth import ( "context" - "fmt" - "net" "net/http" "net/url" - "path/filepath" - "strings" "sync" "time" @@ -33,6 +29,8 @@ type Authenticator struct { cacheLock sync.RWMutex updateLock sync.Mutex + + validators []AuthenticationValidation } // ExtraHeaders are headers which will be added to the response @@ -84,7 +82,7 @@ func (a *Authenticator) TestAccess(request *Request, wsvc WebservicesCacheEntry) return } - ac, ok := a.accessTokensCache.ReadAccesstoken(token) + ac, ok := a.accessTokensCache.ReadAccessToken(token) if !ok { reason = CerberusReasonTokenNotFound return @@ -92,95 +90,18 @@ func (a *Authenticator) TestAccess(request *Request, wsvc WebservicesCacheEntry) newExtraHeaders.set(CerberusHeaderAccessToken, ac.ObjectMeta.Name) - reason, h := a.testPriority(ac, wsvc) - newExtraHeaders.merge(h) - if reason != "" { - return - } - - reason, h = a.testIPAccess(ac, wsvc, request) - newExtraHeaders.merge(h) - if reason != "" { - return - } - - reason, h = a.testDomainAccess(ac, wsvc, request) - newExtraHeaders.merge(h) - - if !ac.TestAccess(wsvc.Name) { - return + for _, validator := range a.validators { + var headers CerberusExtraHeaders + reason, headers = validator.Validate(&ac, &wsvc, request) + newExtraHeaders.merge(headers) + if reason != "" { + return + } } reason = CerberusReasonOK return } -func (a *Authenticator) testPriority(ac AccessTokensCacheEntry, wsvc WebservicesCacheEntry) (CerberusReason, CerberusExtraHeaders) { - newExtraHeaders := make(CerberusExtraHeaders) - priority := ac.Spec.Priority - minPriority := wsvc.Spec.MinimumTokenPriority - if priority < minPriority { - newExtraHeaders[CerberusHeaderAccessLimitReason] = TokenPriorityLowerThanServiceMinAccessLimit - newExtraHeaders[CerberusHeaderTokenPriority] = fmt.Sprint(priority) - newExtraHeaders[CerberusHeaderWebServiceMinPriority] = fmt.Sprint(minPriority) - return CerberusReasonAccessLimited, newExtraHeaders - } - return "", newExtraHeaders -} - -func (a *Authenticator) testIPAccess(ac AccessTokensCacheEntry, wsvc WebservicesCacheEntry, request *Request) (CerberusReason, CerberusExtraHeaders) { - newExtraHeaders := make(CerberusExtraHeaders) - if len(ac.Spec.AllowedIPs) > 0 { - ipList := make([]string, 0) - - // Retrieve "x-forwarded-for" and "referrer" headers from the request - xForwardedFor := request.Request.Header.Get("x-forwarded-for") - if xForwardedFor != "" { - ips := strings.Split(xForwardedFor, ", ") - ipList = append(ipList, ips...) - } - - // Retrieve "remoteAddr" from the request - remoteAddr := request.Request.RemoteAddr - host, _, err := net.SplitHostPort(remoteAddr) - if err != nil { - return CerberusReasonInvalidSourceIp, newExtraHeaders - } - if net.ParseIP(host) == nil { - return CerberusReasonEmptySourceIp, newExtraHeaders - } - ipList = append(ipList, host) - - // Check if IgnoreIP is true, skip IP list check - if !wsvc.Spec.IgnoreIP { - ipAllowed, err := checkIP(ipList, ac.Spec.AllowedIPs) - if err != nil { - return CerberusReasonBadIpList, newExtraHeaders - } - if !ipAllowed { - return CerberusReasonIpNotAllowed, newExtraHeaders - } - } - } - return "", newExtraHeaders -} - -func (a *Authenticator) testDomainAccess(ac AccessTokensCacheEntry, wsvc WebservicesCacheEntry, request *Request) (CerberusReason, CerberusExtraHeaders) { - newExtraHeaders := make(CerberusExtraHeaders) - referrer := request.Request.Header.Get("referrer") - - // Check if IgnoreDomain is true, skip domain list check - if !wsvc.Spec.IgnoreDomain && len(ac.Spec.AllowedDomains) > 0 && referrer != "" { - domainAllowed, err := CheckDomain(referrer, ac.Spec.AllowedDomains) - if err != nil { - return CerberusReasonBadDomainList, newExtraHeaders - } - if !domainAllowed { - return CerberusReasonDomainNotAllowed, newExtraHeaders - } - } - return "", newExtraHeaders -} - // readToken reads token from given Request object and // will return error if it not exists at expected header func (a *Authenticator) readToken(request *Request, wsvc WebservicesCacheEntry) (CerberusReason, string) { @@ -261,52 +182,91 @@ func readRequestContext(request *Request) (wsvc string, ns string, reason Cerber return } +// defineValidators creates a list of validations implemented. +// the validations will run along the order of the list. +func defineValidators() []AuthenticationValidation { + return []AuthenticationValidation{ + &AuthenticatorPriorityValidation{}, + &AuthenticationIPValidation{}, + &AuthenticationDomainValidation{}, + &AuthenticationTokenAccessValidation{}, + } +} + // NewAuthenticator creates new Authenticator object with given logger. // currently it's not returning any error -func NewAuthenticator(logger logr.Logger) (*Authenticator, error) { +func NewAuthenticator(logger logr.Logger) *Authenticator { a := Authenticator{ logger: logger, httpClient: &http.Client{}, } - return &a, nil + a.validators = defineValidators() + return &a } -// checkIP checks if given ip is a member of given CIDR networks or not -// ipAllowList should be CIDR notation of the networks or net.ParseError will be retuned -func checkIP(ips []string, ipAllowList []string) (bool, error) { - for _, ip := range ips { - clientIP := net.ParseIP(ip) +// validateUpstreamAuthRequest validates the service before calling the upstream. +// when calling the upstream authentication, one of read or write tokens must be +// empty and the upstream address must be a valid url. +func validateUpstreamAuthRequest(service WebservicesCacheEntry) CerberusReason { + if service.Spec.UpstreamHttpAuth.ReadTokenFrom == "" || + service.Spec.UpstreamHttpAuth.WriteTokenTo == "" { + return CerberusReasonTargetAuthTokenEmpty + } + if !govalidator.IsRequestURL(service.Spec.UpstreamHttpAuth.Address) { + return CerberusReasonInvalidUpstreamAddress + } + return "" +} - for _, AllowedRangeIP := range ipAllowList { - _, subnet, err := net.ParseCIDR(AllowedRangeIP) - if err != nil { - return false, err - } +// setupUpstreamAuthRequest create request object to call upstream authentication +func setupUpstreamAuthRequest(upstreamHttpAuth *v1alpha1.UpstreamHttpAuthService, request *Request) (*http.Request, error) { + token := request.Request.Header.Get(upstreamHttpAuth.ReadTokenFrom) + req, err := http.NewRequest("GET", upstreamHttpAuth.Address, nil) + if err != nil { + return nil, err + } + req.Header = http.Header{ + upstreamHttpAuth.WriteTokenTo: {token}, + "Content-Type": {"application/json"}, + } + return req, nil +} - if subnet.Contains(clientIP) { - return true, nil - } +// adjustTimeout sets timeout value for httpClient.timeout +func (a *Authenticator) adjustTimeout(timeout int, downstreamDeadline time.Time, hasDownstreamDeadline bool) { + a.httpClient.Timeout = time.Duration(timeout) * time.Millisecond + if hasDownstreamDeadline { + if time.Until(downstreamDeadline)-downstreamDeadlineOffset < a.httpClient.Timeout { + a.httpClient.Timeout = time.Until(downstreamDeadline) - downstreamDeadlineOffset } } - return false, nil } -// CheckDomain checks if given domain will match to one of the GLOB patterns in -// domainAllowedList (the list items should be valid patterns or ErrBadPattern will be returned) -func CheckDomain(domain string, domainAllowedList []string) (bool, error) { - for _, pattern := range domainAllowedList { - pattern = strings.ToLower(pattern) - domain = strings.ToLower(domain) - - matched, err := filepath.Match(pattern, domain) - if err != nil { - return false, err - } - if matched { - return true, nil +// copyUpstreamHeaders copy a listing caring headers from upstream response to +// response headers +func copyUpstreamHeaders(resp *http.Response, extraHeaders *ExtraHeaders, careHeaders []string) { + // Add requested careHeaders to extraHeaders for response + for header, values := range resp.Header { + for _, careHeader := range careHeaders { + if header == careHeader && len(values) > 0 { + (*extraHeaders)[header] = values[0] + break + } } } - return false, nil +} + +// processResponseError handles upstream response headers and translates them to +// meaningful CerberusReason values +func processResponseError(err error) CerberusReason { + if err == nil { + return CerberusReasonNotSet + } + if urlErr, ok := err.(*url.Error); ok && urlErr != nil && urlErr.Timeout() { + return CerberusReasonUpstreamAuthTimeout + } + return CerberusReasonUpstreamAuthFailed + } // checkServiceUpstreamAuth function is designed to validate the request through @@ -315,45 +275,22 @@ func (a *Authenticator) checkServiceUpstreamAuth(service WebservicesCacheEntry, downstreamDeadline, hasDownstreamDeadline := ctx.Deadline() serviceUpstreamAuthCalls.With(AddWithDownstreamDeadline(nil, hasDownstreamDeadline)).Inc() - if service.Spec.UpstreamHttpAuth.ReadTokenFrom == "" { - return false, CerberusReasonSourceAuthTokenEmpty - } - if service.Spec.UpstreamHttpAuth.WriteTokenTo == "" { - return false, CerberusReasonTargetAuthTokenEmpty - } - if !govalidator.IsRequestURL(service.Spec.UpstreamHttpAuth.Address) { - return false, CerberusReasonInvalidUpstreamAddress + if reason := validateUpstreamAuthRequest(service); reason != "" { + return false, reason } - - token := request.Request.Header.Get(service.Spec.UpstreamHttpAuth.ReadTokenFrom) - - // TODO: get http method from webservice crd - req, err := http.NewRequest("GET", service.Spec.UpstreamHttpAuth.Address, nil) + upstreamAuth := service.Spec.UpstreamHttpAuth + req, err := setupUpstreamAuthRequest(&upstreamAuth, request) if err != nil { return false, CerberusReasonUpstreamAuthNoReq } - - req.Header = http.Header{ - service.Spec.UpstreamHttpAuth.WriteTokenTo: {token}, - "Content-Type": {"application/json"}, - } - - a.httpClient.Timeout = time.Duration(service.Spec.UpstreamHttpAuth.Timeout) * time.Millisecond - if hasDownstreamDeadline { - if time.Until(downstreamDeadline)-downstreamDeadlineOffset < a.httpClient.Timeout { - a.httpClient.Timeout = time.Until(downstreamDeadline) - downstreamDeadlineOffset - } - } + a.adjustTimeout(upstreamAuth.Timeout, downstreamDeadline, hasDownstreamDeadline) reqStart := time.Now() resp, err := a.httpClient.Do(req) reqDuration := time.Since(reqStart) - if err != nil { - urlErr, ok := err.(*url.Error) - if ok && urlErr != nil && urlErr.Timeout() { - return false, CerberusReasonUpstreamAuthTimeout - } - return false, CerberusReasonUpstreamAuthFailed + + if reason := processResponseError(err); reason != "" { + return false, reason } labels := AddWithDownstreamDeadline(AddStatusLabel(nil, resp.StatusCode), hasDownstreamDeadline) @@ -363,17 +300,7 @@ func (a *Authenticator) checkServiceUpstreamAuth(service WebservicesCacheEntry, return false, CerberusReasonUnauthorized } // add requested careHeaders to extraHeaders for response - for header, values := range resp.Header { - for _, careHeader := range service.Spec.UpstreamHttpAuth.CareHeaders { - if header == careHeader { - if len(values) > 0 { - (*extraHeaders)[header] = values[0] - } - break - } - } - } - + copyUpstreamHeaders(resp, extraHeaders, service.Spec.UpstreamHttpAuth.CareHeaders) return true, CerberusReasonOK } @@ -383,6 +310,9 @@ func hasUpstreamAuth(service WebservicesCacheEntry) bool { return service.Spec.UpstreamHttpAuth.Address != "" } +// generateResponse initializes defaults for cerberus http result and creates a +// valid response from cerberus reasons and computed headers to inform the client +// that it has the access or not. func generateResponse(ok bool, reason CerberusReason, extraHeaders ExtraHeaders) *Response { var httpStatusCode int if ok { @@ -409,12 +339,16 @@ func generateResponse(ok bool, reason CerberusReason, extraHeaders ExtraHeaders) } } +// merge merges 2 CerberusExtraHeaders and replaces if a key was present before +// with the new value in argument map func (ch CerberusExtraHeaders) merge(h CerberusExtraHeaders) { for key, value := range h { ch[key] = value } } +// set sets the values in CerberusExtraHeaders +// (creates if it's absent and update if it's present) func (ch CerberusExtraHeaders) set(key CerberusHeaderName, value string) { ch[key] = value } diff --git a/pkg/auth/authenticator_cache.go b/pkg/auth/authenticator_cache.go index ff40ceb..f3e1a16 100644 --- a/pkg/auth/authenticator_cache.go +++ b/pkg/auth/authenticator_cache.go @@ -344,8 +344,8 @@ func (c *WebservicesCache) ReadWebservice(wsvc string) (WebservicesCacheEntry, b return r, ok } -// ReadAccesstoken -func (c *AccessTokensCache) ReadAccesstoken(rawToken string) (AccessTokensCacheEntry, bool) { +// ReadAccessToken +func (c *AccessTokensCache) ReadAccessToken(rawToken string) (AccessTokensCacheEntry, bool) { r, ok := (*c)[rawToken] return r, ok } diff --git a/pkg/auth/authenticator_cache_test.go b/pkg/auth/authenticator_cache_test.go new file mode 100644 index 0000000..f01fb46 --- /dev/null +++ b/pkg/auth/authenticator_cache_test.go @@ -0,0 +1,548 @@ +package auth + +import ( + "fmt" + "testing" + + "github.com/go-logr/logr" + "github.com/snapp-incubator/Cerberus/api/v1alpha1" + "github.com/snapp-incubator/Cerberus/pkg/testutils" + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +type Namer interface { + GetName() string + GetNamespace() string +} + +func localName(meta Namer) string { + return fmt.Sprintf("%s/%s", meta.GetNamespace(), meta.GetName()) +} +func TestEncodeLocalName(t *testing.T) { + type testCase struct { + namespace string + name string + expected string + } + + testCases := []testCase{ + { + namespace: "example", + name: "token1", + expected: "example/token1", + }, + { + namespace: "namespace", + name: "token2", + expected: "namespace/token2", + }, + } + + for _, tc := range testCases { + entry := AccessTokensCacheEntry{} + entry.Name = tc.name + entry.Namespace = tc.namespace + actual := encodeLocalName(entry) + if actual != tc.expected { + t.Errorf("encodeLocalName(%v) = %s; expected %s", entry, actual, tc.expected) + } + } +} + +func TestDecodeLocalName(t *testing.T) { + // Test cases + tests := []struct { + input string + expectedName string + expectedNamespace string + }{ + { + input: "example/token1", + expectedName: "example", + expectedNamespace: "token1", + }, + { + input: "namespace/token2", + expectedName: "namespace", + expectedNamespace: "token2", + }, + { + input: "token3", + expectedName: "token3", + expectedNamespace: "", + }, + } + + // Iterate through test cases + for _, test := range tests { + actualName, actualNamespace := decodeLocalName(test.input) + if actualName != test.expectedName || actualNamespace != test.expectedNamespace { + t.Errorf("decodeLocalName(%s) = (%s, %s); expected (%s, %s)", test.input, actualName, actualNamespace, test.expectedName, test.expectedNamespace) + } + } +} + +func TestBuildNewWebservicesCache(t *testing.T) { + // Create instances of mocked logger, AccessTokensCache, and WebservicesCache + sink := testutils.NewTestLogSink() + // Create an instance of the Authenticator with mocked dependencies + auth := &Authenticator{ + logger: logr.New(sink), + } + + // Create and prepare mock data for Kubernetes resources. + webservices := &v1alpha1.WebServiceList{ + Items: prepareWebservices(2), + } + noNamespaceService := v1alpha1.WebService{ + ObjectMeta: v1.ObjectMeta{ + Name: "nonamespace", + }, + } + webservices.Items = append(webservices.Items, noNamespaceService) + bindings := &v1alpha1.WebserviceAccessBindingList{ + Items: prepareWebserviceAccessBindings(2), + } + + newWebservicesCache := auth.buildNewWebservicesCache(webservices, bindings) + + assert.Equal(t, "info", sink.GetLog(0).Type) + assert.Equal(t, "webservice namespace is empty", sink.GetLog(0).Message) + assert.Equal(t, noNamespaceService.Name, sink.GetLog(0).KeyValues["webservice"]) + + // logs about ignored webservices + bindingLogs := (*sink.Logs)[1 : len(*sink.Logs)-1] + + getBindingfromLogs := func(logs testutils.Logs) []string { + bindings := make([]string, 0) + for _, v := range logs { + bindings = append(bindings, v.KeyValues["binding"].(string)) + } + return bindings + } + + getBindingfromFixtures := func(accessBindings []v1alpha1.WebserviceAccessBinding) []string { + bindings := make([]string, 0) + for _, v := range accessBindings { + bindings = append(bindings, localName(&v.ObjectMeta)) + } + return bindings + } + bindingsNamesFromLog := getBindingfromLogs(bindingLogs) + bindingsNamesFromFixtures := getBindingfromFixtures(bindings.Items) + + assert.ElementsMatch(t, bindingsNamesFromFixtures, bindingsNamesFromLog) + for _, log := range bindingLogs { + assert.Equal(t, "info", log.Type) + assert.Equal(t, "ignored some webservices over binding", log.Message) + } + + assert.Len(t, *newWebservicesCache, 2) +} + +func TestAllowNamespaceAndAdd(t *testing.T) { + wsce := WebservicesCacheEntry{ + allowedNamespacesCache: make(AllowedNamespacesCache), + } + wsce.allowedNamespacesCache["x"] = struct{}{} + wsce.allowedNamespacesCache["y"] = struct{}{} + + // Check adding new namespace directly to the cache interface + wsce.allowedNamespacesCache.add("test1") + assert.Contains(t, wsce.allowedNamespacesCache, "test1") + assert.Len(t, wsce.allowedNamespacesCache, 3) + + // Check adding new namespace with allowNamespace + wsce.allowNamespace("test2") + assert.Contains(t, wsce.allowedNamespacesCache, "test2") + assert.Len(t, wsce.allowedNamespacesCache, 4) + + // Adding x again should do nothing + wsce.allowNamespace("x") + wsce.allowedNamespacesCache.add("x") + assert.Len(t, wsce.allowedNamespacesCache, 4) + +} + +func TestCheckAccessFrom(t *testing.T) { + wsce := WebservicesCacheEntry{ + allowedNamespacesCache: make(AllowedNamespacesCache), + } + wsce.allowedNamespacesCache["x"] = struct{}{} + wsce.allowedNamespacesCache["y"] = struct{}{} + + assert.True(t, wsce.checkAccessFrom("x")) + assert.False(t, wsce.checkAccessFrom("z")) +} + +func TestGetSecretRawTokenMap(t *testing.T) { + // Test case 1: Secrets list is empty + emptySecrets := &corev1.SecretList{} + emptyMap := getSecretRawTokenMap(emptySecrets) + assert.Empty(t, emptyMap, "Result map should be empty for an empty secrets list") + + // Test case 2: Secrets list with one secret containing "token" field + secretWithData := corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "secret1"}, + Data: map[string][]byte{"token": []byte("my-secret-token")}, + } + secretsList := &corev1.SecretList{Items: []corev1.Secret{secretWithData}} + resultMap := getSecretRawTokenMap(secretsList) + assert.Len(t, resultMap, 1, "Result map should contain one entry") + assert.Equal(t, "my-secret-token", resultMap["secret1"], "Token value should be 'my-secret-token'") + + // Test case 3: Secrets list with one secret missing "token" field + secretWithoutToken := corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: "secret2"}, + Data: map[string][]byte{}, + } + secretsList.Items = append(secretsList.Items, secretWithoutToken) + resultMap = getSecretRawTokenMap(secretsList) + assert.Len(t, resultMap, 2, "Result map should contain two entries") + assert.Equal(t, "my-secret-token", resultMap["secret1"], "Token value should be 'my-secret-token'") + assert.Equal(t, "token-field-not-found", resultMap["secret2"], "Token value should be 'token-field-not-found'") +} + +func TestWebservicesCache_AllowWebserviceCallsFromNamespace(t *testing.T) { + // Create a mock WebservicesCache + cache := make(WebservicesCache) + cacheEntry := WebservicesCacheEntry{ + WebService: v1alpha1.WebService{}, + allowedNamespacesCache: make(AllowedNamespacesCache), + } + cacheEntry.allowedNamespacesCache["namespace1"] = struct{}{} + cache["webservice1"] = cacheEntry + + // Test case 1: Allow namespace access to the webservice + err := cache.allowWebserviceCallsFromNamespace("webservice1", "namespace1") + assert.NoError(t, err, "No error should occur") + assert.True(t, cache["webservice1"].checkAccessFrom("namespace1"), "Namespace should have access to the webservice") + + // Test case 2: Try to allow access for a non-existent webservice + err = cache.allowWebserviceCallsFromNamespace("nonexistentwebservice", "namespace1") + assert.Error(t, err, "Error should occur for nonexistent webservice") + assert.EqualError(t, err, "webservice not found in webservices cache", "Error message should indicate webservice not found") +} + +func TestWebservicesCache_validateWebservice(t *testing.T) { + // Mocking a WebservicesCacheEntry + wsc := make(WebservicesCache) + ws := v1alpha1.WebService{} + wsc[ws.LocalName()] = WebservicesCacheEntry{ + WebService: ws, + allowedNamespacesCache: make(AllowedNamespacesCache), + } + assert.NoError(t, wsc.validateWebservice(ws.LocalName())) + assert.ErrorContains(t, wsc.validateWebservice("not-defined"), "webservice not found in webservices cache") +} + +func TestWebservicesCacheEntry_checkAccessFrom(t *testing.T) { + // Mocking a WebservicesCacheEntry with allowed namespaces + wse := WebservicesCacheEntry{ + WebService: v1alpha1.WebService{}, + allowedNamespacesCache: make(AllowedNamespacesCache), + } + + assert.False(t, wse.checkAccessFrom("x")) + + wse.allowedNamespacesCache["x"] = struct{}{} + assert.True(t, wse.checkAccessFrom("x")) +} + +func TestWebservicesCache_CheckAccess(t *testing.T) { + // Create a mock WebservicesCache + cache := make(WebservicesCache) + cacheEntry := WebservicesCacheEntry{ + WebService: v1alpha1.WebService{}, + allowedNamespacesCache: make(AllowedNamespacesCache), + } + cacheEntry.allowedNamespacesCache["namespace1"] = struct{}{} + cache["webservice1"] = cacheEntry + + // Test case 1: Namespace is allowed + allowed, err := cache.checkAccess("webservice1", "namespace1") + assert.NoError(t, err, "No error should occur") + assert.True(t, allowed, "Namespace should have access to the webservice") + + // Test case 2: Namespace is not allowed + allowed, err = cache.checkAccess("webservice1", "namespace2") + assert.NoError(t, err, "No error should occur") + assert.False(t, allowed, "Namespace should not have access to the webservice") + + // Test case 3: Webservice not found + _, err = cache.checkAccess("nonexistentwebservice", "namespace1") + assert.Error(t, err, "Error should occur for nonexistent webservice") + assert.EqualError(t, err, "webservice not found in webservices cache", "Error message should indicate webservice not found") +} + +func TestWebservicesCache_ReadWebservice(t *testing.T) { + cache := make(WebservicesCache) + cacheEntry := WebservicesCacheEntry{ + WebService: v1alpha1.WebService{ + ObjectMeta: v1.ObjectMeta{ + Name: "webservice1", + Namespace: "does not matter", + }, + }, + allowedNamespacesCache: make(AllowedNamespacesCache), + } + cache["webservice1"] = cacheEntry + + value, ok := cache.ReadWebservice("webservice1") + assert.True(t, ok) + assert.Equal(t, cacheEntry, value) + + _, ok = cache.ReadWebservice("webservice2") + assert.False(t, ok) +} + +func TestAccessTokensCache_ReadAccessToken(t *testing.T) { + cache := make(AccessTokensCache) + cacheEntry := AccessTokensCacheEntry{ + AccessToken: v1alpha1.AccessToken{ + ObjectMeta: metav1.ObjectMeta{ + Name: "token1", + Namespace: "random-name", + }, + }, + allowedWebservicesCache: make(AllowedWebservicesCache), + } + cache["token1"] = cacheEntry + + value, ok := cache.ReadAccessToken("token1") + assert.True(t, ok) + assert.Equal(t, cacheEntry, value) + + _, ok = cache.ReadAccessToken("token2") + assert.False(t, ok) +} + +func TestAccessTokensCacheEntry_TestAccess(t *testing.T) { + cacheEntry := AccessTokensCacheEntry{ + AccessToken: v1alpha1.AccessToken{ + ObjectMeta: metav1.ObjectMeta{ + Name: "token1", + Namespace: "random-name", + }, + }, + allowedWebservicesCache: make(AllowedWebservicesCache), + } + cacheEntry.allowedWebservicesCache["webservice1"] = struct{}{} + + assert.True(t, cacheEntry.TestAccess("webservice1")) + assert.False(t, cacheEntry.TestAccess("webservice2")) + +} + +func TestAccessTokensCache_buildAllowedWebservicesCache(t *testing.T) { + cache := make(WebservicesCache) + cacheEntry1 := WebservicesCacheEntry{ + WebService: v1alpha1.WebService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "webservice1", + Namespace: "namespace1", + }, + }, + allowedNamespacesCache: AllowedNamespacesCache{"namespace1": struct{}{}}, + } + cacheEntry2 := WebservicesCacheEntry{ + WebService: v1alpha1.WebService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "webservice2", + }, + }, + allowedNamespacesCache: AllowedNamespacesCache{"namespace1": struct{}{}}, + } + cache[cacheEntry1.WebService.LocalName()] = cacheEntry1 + cache["namespace1/"+cacheEntry2.WebService.Name] = cacheEntry2 + + // Create a mock AccessTokensCacheEntry + accessTokenEntry1 := AccessTokensCacheEntry{ + AccessToken: v1alpha1.AccessToken{ + ObjectMeta: v1.ObjectMeta{ + Name: "token1", + Namespace: "namespace1", + }, + Spec: v1alpha1.AccessTokenSpec{ + AllowedWebservices: []*v1alpha1.WebserviceReference{ + { + Namespace: "namespace1", + Name: "webservice1", + }, + { + Namespace: "namespace2", // Invalid namespace + Name: "webservice2", + }, + }, + }, + }, + allowedWebservicesCache: make(AllowedWebservicesCache), + } + accessTokenEntry2 := AccessTokensCacheEntry{ + AccessToken: v1alpha1.AccessToken{ + ObjectMeta: v1.ObjectMeta{ + Name: "token2", + Namespace: "namespace1", + }, + Spec: v1alpha1.AccessTokenSpec{ + AllowedWebservices: []*v1alpha1.WebserviceReference{ + { + Namespace: "namespace1", + Name: "webservice2", + }, + { + Namespace: "namespace3", // Invalid namespace + Name: "webservice2", + }, + }, + }, + }, + allowedWebservicesCache: make(AllowedWebservicesCache), + } + tokenCache := make(AccessTokensCache) + + tokenCache["token1"] = accessTokenEntry1 + tokenCache["token2"] = accessTokenEntry2 + + // Checking ignores + ignored := tokenCache.buildAllowedWebservicesCache(&cache) + assert.Contains(t, ignored, localName(&accessTokenEntry1)) + assert.Contains(t, ignored, localName(&accessTokenEntry2)) + assert.Contains(t, ignored[localName(&accessTokenEntry1)], accessTokenEntry1.Spec.AllowedWebservices[1]) + assert.Contains(t, ignored[localName(&accessTokenEntry2)], accessTokenEntry2.Spec.AllowedWebservices[1]) + + // Checking token cache allowedWebservicesCache + expectedWebservice1 := accessTokenEntry1.Spec.AllowedWebservices[0] + expectedWebservice2 := accessTokenEntry2.Spec.AllowedWebservices[0] + assert.Contains(t, tokenCache, "token1") + assert.Contains(t, tokenCache["token1"].allowedWebservicesCache, expectedWebservice1.LocalName()) + assert.Contains(t, tokenCache, "token2") + assert.Contains(t, tokenCache["token2"].allowedWebservicesCache, expectedWebservice2.LocalName()) +} + +func TestAccessTokensCacheEntry_buildAllowedWebservicesCache(t *testing.T) { + // Create a mock WebservicesCache + cache := make(WebservicesCache) + cacheEntry1 := WebservicesCacheEntry{ + WebService: v1alpha1.WebService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "webservice1", + Namespace: "namespace1", + }, + }, + allowedNamespacesCache: AllowedNamespacesCache{"namespace1": struct{}{}}, + } + cacheEntry2 := WebservicesCacheEntry{ + WebService: v1alpha1.WebService{ + ObjectMeta: metav1.ObjectMeta{ + Name: "webservice2", + }, + }, + allowedNamespacesCache: AllowedNamespacesCache{"namespace1": struct{}{}}, + } + cache[cacheEntry1.WebService.LocalName()] = cacheEntry1 + cache["namespace1/"+cacheEntry2.WebService.Name] = cacheEntry2 + + // Create a mock AccessTokensCacheEntry + accessToken := AccessTokensCacheEntry{ + AccessToken: v1alpha1.AccessToken{ + ObjectMeta: v1.ObjectMeta{ + Name: "token1", + Namespace: "namespace1", + }, + Spec: v1alpha1.AccessTokenSpec{ + AllowedWebservices: []*v1alpha1.WebserviceReference{ + { + Namespace: "namespace1", + Name: "webservice1", + }, + { + Namespace: "", // Empty namespace should default to AccessToken's namespace + Name: "webservice2", + }, + { + Namespace: "namespace2", // Invalid namespace + Name: "webservice2", + }, + }, + }, + }, + allowedWebservicesCache: make(AllowedWebservicesCache), + } + + // Call the function + ignoredEntries := accessToken.buildAllowedWebservicesCache(&cache) + + // Check cache + assert.Len(t, accessToken.allowedWebservicesCache, 2) + assert.Contains(t, accessToken.allowedWebservicesCache, + cacheEntry1.LocalName()) + + // Check if it adds default namespace with tokens namespace + assert.Contains(t, accessToken.allowedWebservicesCache, + fmt.Sprintf("%s/%s", accessToken.Namespace, cacheEntry2.Name)) + + // Check results + assert.Len(t, ignoredEntries, 1, "There should be one ignored entry") + assert.Equal(t, "namespace2", ignoredEntries[0].Namespace, + "The ignored entry should have namespace 'namespace2'") +} + +func TestAuthenticator_buildNewAccessTokensCache(t *testing.T) { + sink := testutils.NewTestLogSink() + auth := Authenticator{ + logger: logr.New(sink), + } + secrets := &corev1.SecretList{} + tokens := &v1alpha1.AccessTokenList{ + Items: []v1alpha1.AccessToken{ + {}, + }, + } + newWebservicesCache := &WebservicesCache{} + tokens.Items[0].Name = "." + auth.buildNewAccessTokensCache(tokens, secrets, newWebservicesCache) + assert.Equal(t, sink.GetLog(0).Message, "dot character is not allowed in AccessToken name") + + tokens.Items[0].Name = "valid" + tokens.Items[0].Namespace = "." + newWebservicesCache = &WebservicesCache{} + auth.buildNewAccessTokensCache(tokens, secrets, newWebservicesCache) + + assert.Equal(t, sink.GetLog(2).Message, "dot character is not allowed in AccessToken namespace") + + tokens.Items[0].Name = "valid" + tokens.Items[0].Namespace = "valid" + newWebservicesCache = &WebservicesCache{} + auth.buildNewAccessTokensCache(tokens, secrets, newWebservicesCache) + assert.Equal(t, sink.GetLog(4).Message, "unable to find secret for accesstoken") + + secrets.Items = []corev1.Secret{ + { + ObjectMeta: v1.ObjectMeta{ + Name: "valid.valid", + }, + }, + } + newWebservicesCache = &WebservicesCache{} + auth.buildNewAccessTokensCache(tokens, secrets, newWebservicesCache) + assert.Equal(t, sink.GetLog(6).Message, "corresponding secret for accesstoken does not contain token field") + + secrets.Items = []corev1.Secret{ + { + ObjectMeta: v1.ObjectMeta{ + Name: "valid.valid", + }, + Data: map[string][]byte{ + "token": []byte("test-token"), + }, + }, + } + newWebservicesCache = &WebservicesCache{} + tokenCache := auth.buildNewAccessTokensCache(tokens, secrets, newWebservicesCache) + assert.Equal(t, tokens.Items[0], (*tokenCache)["test-token"].AccessToken) + +} diff --git a/pkg/auth/authenticator_filters.go b/pkg/auth/authenticator_filters.go new file mode 100644 index 0000000..42900bf --- /dev/null +++ b/pkg/auth/authenticator_filters.go @@ -0,0 +1,159 @@ +package auth + +import ( + "fmt" + "net" + "net/http" + "path/filepath" + "strings" +) + +// AuthenticationValidation Validation for IP restrictions +type AuthenticationValidation interface { + Validate(ac *AccessTokensCacheEntry, wc *WebservicesCacheEntry, request *Request) (CerberusReason, CerberusExtraHeaders) +} + +type AuthenticatorPriorityValidation struct{} + +func (apt *AuthenticatorPriorityValidation) Validate(ac *AccessTokensCacheEntry, + wsvc *WebservicesCacheEntry, _ *Request) (CerberusReason, CerberusExtraHeaders) { + + newExtraHeaders := make(CerberusExtraHeaders) + priority := ac.Spec.Priority + minPriority := wsvc.Spec.MinimumTokenPriority + if priority < minPriority { + newExtraHeaders[CerberusHeaderAccessLimitReason] = TokenPriorityLowerThanServiceMinAccessLimit + newExtraHeaders[CerberusHeaderTokenPriority] = fmt.Sprint(priority) + newExtraHeaders[CerberusHeaderWebServiceMinPriority] = fmt.Sprint(minPriority) + return CerberusReasonAccessLimited, newExtraHeaders + } + return CerberusReasonNotSet, newExtraHeaders +} + +type AuthenticationIPValidation struct{} + +// getIPListFromRequest extract IP addresses from request and it's headers +func getIPListFromRequest(request *http.Request) (CerberusReason, []string) { + ipList := make([]string, 0) + + // Retrieve "x-forwarded-for" and "referrer" headers from the request + xForwardedFor := request.Header.Get("x-forwarded-for") + if xForwardedFor != "" { + ips := strings.Split(xForwardedFor, ", ") + ipList = append(ipList, ips...) + } + + // Retrieve "remoteAddr" from the request + remoteAddr := request.RemoteAddr + host, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + return CerberusReasonInvalidSourceIp, nil + } + + if net.ParseIP(host) == nil { + return CerberusReasonEmptySourceIp, nil + } + + ipList = append(ipList, host) + return CerberusReasonNotSet, ipList +} + +// Validate validates IP access restrictions +func (ait *AuthenticationIPValidation) Validate( + ac *AccessTokensCacheEntry, wsvc *WebservicesCacheEntry, request *Request) (CerberusReason, CerberusExtraHeaders) { + newExtraHeaders := make(CerberusExtraHeaders) + if len(ac.Spec.AllowedIPs) > 0 { + + reason, ipList := getIPListFromRequest(&request.Request) + if reason != CerberusReasonNotSet { + return reason, newExtraHeaders + } + + // Check if IgnoreIP is true, skip IP list check + if !wsvc.Spec.IgnoreIP { + ipAllowed, err := checkIP(ipList, ac.Spec.AllowedIPs) + if err != nil { + return CerberusReasonBadIpList, newExtraHeaders + } + if !ipAllowed { + return CerberusReasonIpNotAllowed, newExtraHeaders + } + } + } + return CerberusReasonNotSet, newExtraHeaders + +} + +// checkIP checks if given ip is a member of given CIDR networks or not +// ipAllowList should be CIDR notation of the networks or net.ParseError will be retuned +func checkIP(ips []string, ipAllowList []string) (bool, error) { + for _, ip := range ips { + clientIP := net.ParseIP(ip) + + for _, AllowedRangeIP := range ipAllowList { + _, subnet, err := net.ParseCIDR(AllowedRangeIP) + if err != nil { + return false, err + } + + if subnet.Contains(clientIP) { + return true, nil + } + } + } + return false, nil +} + +// AuthenticationDomainValidation validates for domain definitions +type AuthenticationDomainValidation struct{} + +// Validate checks domain restrictions +func (adv *AuthenticationDomainValidation) Validate(ac *AccessTokensCacheEntry, + wsvc *WebservicesCacheEntry, request *Request) (CerberusReason, CerberusExtraHeaders) { + + newExtraHeaders := make(CerberusExtraHeaders) + referrer := request.Request.Header.Get("referrer") + + // Check if IgnoreDomain is true, skip domain list check + if !wsvc.Spec.IgnoreDomain && len(ac.Spec.AllowedDomains) > 0 && referrer != "" { + domainAllowed, err := CheckDomain(referrer, ac.Spec.AllowedDomains) + if err != nil { + return CerberusReasonBadDomainList, newExtraHeaders + } + if !domainAllowed { + return CerberusReasonDomainNotAllowed, newExtraHeaders + } + } + return CerberusReasonNotSet, newExtraHeaders + +} + +// CheckDomain checks if given domain will match to one of the GLOB patterns in +// domainAllowedList (the list items should be valid patterns or ErrBadPattern will be returned) +func CheckDomain(domain string, domainAllowedList []string) (bool, error) { + for _, pattern := range domainAllowedList { + pattern = strings.ToLower(pattern) + domain = strings.ToLower(domain) + + matched, err := filepath.Match(pattern, domain) + if err != nil { + return false, err + } + if matched { + return true, nil + } + } + return false, nil +} + +// AuthenticationTokenAccessValidation check for token and webservice access +type AuthenticationTokenAccessValidation struct{} + +// Validate checks token and webservice access +func (adv *AuthenticationTokenAccessValidation) Validate(ac *AccessTokensCacheEntry, + wsvc *WebservicesCacheEntry, request *Request) (CerberusReason, CerberusExtraHeaders) { + if !ac.TestAccess(wsvc.Name) { + return CerberusReasonWebserviceNotAllowed, CerberusExtraHeaders{} + } + return CerberusReasonNotSet, CerberusExtraHeaders{} +} diff --git a/pkg/auth/authenticator_filters_test.go b/pkg/auth/authenticator_filters_test.go new file mode 100644 index 0000000..17beabd --- /dev/null +++ b/pkg/auth/authenticator_filters_test.go @@ -0,0 +1,222 @@ +package auth + +import ( + "net/http" + "testing" + + "github.com/snapp-incubator/Cerberus/api/v1alpha1" + "github.com/stretchr/testify/assert" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestCheckIP(t *testing.T) { + // Test case 1: Empty IP allow list + ips := []string{"192.168.1.1"} + ipAllowList := []string{} + allowed, err := checkIP(ips, ipAllowList) + assert.NoError(t, err, "No error should occur") + assert.False(t, allowed, "IP should not be allowed") + + // Test case 2: IP is allowed + ips = []string{"192.168.1.1"} + ipAllowList = []string{"192.168.1.0/24"} + allowed, err = checkIP(ips, ipAllowList) + assert.NoError(t, err, "No error should occur") + assert.True(t, allowed, "IP should be allowed") + + // Test case 3: IP is not allowed + ips = []string{"192.168.2.1"} + ipAllowList = []string{"192.168.1.0/24"} + allowed, err = checkIP(ips, ipAllowList) + assert.NoError(t, err, "No error should occur") + assert.False(t, allowed, "IP should not be allowed") + + // Test case 4: Error while parsing IP allow list + ips = []string{"192.168.1.1"} + ipAllowList = []string{"invalidCIDR"} + allowed, err = checkIP(ips, ipAllowList) + assert.Error(t, err, "Error should occur") + assert.False(t, allowed, "IP should not be allowed due to error") + assert.EqualError(t, err, "invalid CIDR address: invalidCIDR", "Error message should indicate invalid CIDR") +} + +func TestCheckDomainComplex(t *testing.T) { + testCases := []struct { + domain string + domainAllowed []string + expectedResult bool + }{ + // Exact domain matches + {"example.com", []string{"example.com"}, true}, + {"sub.example.com", []string{"sub.example.com"}, true}, + {"sub.sub.example.com", []string{"sub.sub.example.com"}, true}, + + // Wildcard prefix and suffix matches + {"sub.example.com", []string{"*.example.com"}, true}, + {"example.net", []string{"example.*"}, true}, + + // Multiple patterns with mixed results + {"example.com", []string{"example.net", "*.example.com", "example.*"}, true}, + {"sub.sub.example.net", []string{"*.example.com", "example.*"}, false}, + {"example.org", []string{"*.example.com", "example.*"}, true}, + + // Case-insensitive matching + {"ExAmPlE.CoM", []string{"example.com"}, true}, + + // Character class [a-z0-9] + {"example1.com", []string{"example[0-9].com"}, true}, + {"examplea.com", []string{"example[a-z].com"}, true}, + {"exampleA.com", []string{"example[a-z].com"}, true}, + + // Multiple * wildcards + {"sub.sub.example.net", []string{"*.sub.*.net"}, true}, + {"sub.sub.example.net", []string{"*.*.*.net"}, true}, + {"sub.sub.example.net", []string{"*.example.net"}, true}, + + // ? wildcard character + {"example1.com", []string{"example?.com"}, true}, + {"example12.com", []string{"example?.com"}, false}, + } + + for _, tc := range testCases { + result, err := CheckDomain(tc.domain, tc.domainAllowed) + if result != tc.expectedResult { + t.Errorf("Domain: %s, Expected: %v, Got: %v", tc.domain, tc.expectedResult, result) + } + if err != nil { + t.Errorf("Domain: %s, Expected Error: nil, Got Error: %v", tc.domain, err) + } + } +} + +func TestAuthenticationDomainValidation_Validate(t *testing.T) { + // Test case 1: When IgnoreDomain is false, and referrer is in allowed domains + ac := &AccessTokensCacheEntry{ + AccessToken: v1alpha1.AccessToken{ + Spec: v1alpha1.AccessTokenSpec{AllowedDomains: []string{"example.com", "example.org"}}, + }, + } + wsvc := &WebservicesCacheEntry{WebService: v1alpha1.WebService{Spec: v1alpha1.WebServiceSpec{IgnoreDomain: false}}} + + request := &Request{Request: http.Request{Header: http.Header{}}} + request.Request.Header.Set("referrer", "example.com") + + auth := &AuthenticationDomainValidation{} + wsvc.Spec.IgnoreDomain = false + reason, headers := auth.Validate(ac, wsvc, request) + assert.Equal(t, CerberusReasonNotSet, reason, "Expected reason should be NotSet") + assert.Empty(t, headers, "Expected headers should be empty") + + // Test case 2: When IgnoreDomain is true + request.Request.Header.Set("referrer", "random") + wsvc.Spec.IgnoreDomain = true + reason, headers = auth.Validate(ac, wsvc, request) + assert.Equal(t, CerberusReasonNotSet, reason, "Expected reason should be NotSet") + assert.Empty(t, headers, "Expected headers should be empty") + + // Test case 3: When IgnoreDomain is true, and referrer is not in allowed domains + wsvc.Spec.IgnoreDomain = false + request.Request.Header.Set("referrer", "x.com") + reason, headers = auth.Validate(ac, wsvc, request) + assert.Equal(t, CerberusReasonDomainNotAllowed, reason, "Expected reason should be DomainNotAllowed") + assert.Empty(t, headers, "Expected headers should be empty") + + // Test case 4: When IgnoreDomain is true, and referrer is bad + wsvc.Spec.IgnoreDomain = false + request.Request.Header.Set("referrer", "x.com") + ac.Spec.AllowedDomains = []string{"["} + reason, headers = auth.Validate(ac, wsvc, request) + assert.Equal(t, CerberusReasonBadDomainList, reason, "Expected reason should be BadDomainList") + assert.Empty(t, headers, "Expected headers should be empty") + + // Test case 5: When no allowed domains are specified + ac.Spec.AllowedDomains = nil + wsvc.Spec.IgnoreDomain = false + reason, headers = auth.Validate(ac, wsvc, request) + assert.Equal(t, CerberusReasonNotSet, reason, "Expected reason should be NotSet") + assert.Empty(t, headers, "Expected headers should be empty") +} + +func TestAuthenticationTokenAccessValidation_Validate(t *testing.T) { + wsvc := WebservicesCacheEntry{} + wsvc.allowedNamespacesCache = make(AllowedNamespacesCache) + wsvc.allowedNamespacesCache["test-ns"] = struct{}{} + wsvc.Name = "test-ws" + wsvc.Namespace = "test-ns" + ac := AccessTokensCacheEntry{} + ac.AccessToken = v1alpha1.AccessToken{ + ObjectMeta: v1.ObjectMeta{ + Name: "test-token", + Namespace: "test-ns", + }, + } + ac.allowedWebservicesCache = make(AllowedWebservicesCache) + ac.allowedWebservicesCache["test-ws"] = struct{}{} + + atcv := AuthenticationTokenAccessValidation{} + + reason, _ := atcv.Validate(&ac, &wsvc, nil) + assert.Equal(t, reason, CerberusReasonNotSet) + + ac.allowedWebservicesCache = make(AllowedWebservicesCache) + ac.allowedWebservicesCache["test-ws-2"] = struct{}{} + reason, _ = atcv.Validate(&ac, &wsvc, nil) + assert.Equal(t, reason, CerberusReasonWebserviceNotAllowed) +} + +func TestGetIPListFromRequest(t *testing.T) { + // Test case 1: Valid x-forwarded-for header and remote address + request := &http.Request{ + Header: http.Header{"X-Forwarded-For": {"192.0.2.1, 198.51.100.2"}}, + RemoteAddr: "192.0.2.3:12345", + } + reason, ipList := getIPListFromRequest(request) + assert.Equal(t, CerberusReasonNotSet, reason) + assert.ElementsMatch(t, []string{"192.0.2.1", "198.51.100.2", "192.0.2.3"}, ipList) + + // Test case 2: Valid remote address only + request = &http.Request{ + RemoteAddr: "192.0.2.3:12345", + } + reason, ipList = getIPListFromRequest(request) + assert.Equal(t, CerberusReasonNotSet, reason) + assert.ElementsMatch(t, []string{"192.0.2.3"}, ipList) + + // Test case 3: Invalid remote address + request = &http.Request{ + RemoteAddr: "[invalid]", + } + reason, ipList = getIPListFromRequest(request) + assert.Equal(t, CerberusReasonInvalidSourceIp, reason) + assert.Nil(t, ipList) + + // Test case 4: Empty x-forwarded-for header, valid remote address + request = &http.Request{ + RemoteAddr: "192.0.2.3:12345", + } + reason, ipList = getIPListFromRequest(request) + assert.Equal(t, CerberusReasonNotSet, reason) + assert.ElementsMatch(t, []string{"192.0.2.3"}, ipList) + + // Test case 5: Empty x-forwarded-for header, invalid remote address + request = &http.Request{ + RemoteAddr: "[invalid]", + } + reason, ipList = getIPListFromRequest(request) + assert.Equal(t, CerberusReasonInvalidSourceIp, reason) + assert.Nil(t, ipList) + + // Test case 6: Empty x-forwarded-for header and remote address + request = &http.Request{} + reason, ipList = getIPListFromRequest(request) + assert.Equal(t, CerberusReasonInvalidSourceIp, reason) + assert.Nil(t, ipList) + + // Test case 7: Empty x-forwarded-for header and invalid remote address + request = &http.Request{ + RemoteAddr: "192.168.1.1.1:80", + } + reason, ipList = getIPListFromRequest(request) + assert.Equal(t, CerberusReasonEmptySourceIp, reason) + assert.Nil(t, ipList) +} diff --git a/pkg/auth/authenticator_test.go b/pkg/auth/authenticator_test.go index 2d064b3..247892e 100644 --- a/pkg/auth/authenticator_test.go +++ b/pkg/auth/authenticator_test.go @@ -1,19 +1,20 @@ package auth import ( + "context" + "errors" "fmt" "net/http" + "net/url" "testing" + "time" cerberusv1alpha1 "github.com/snapp-incubator/Cerberus/api/v1alpha1" + "github.com/stretchr/testify/assert" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "sigs.k8s.io/controller-runtime/pkg/client" - - "context" - - "github.com/stretchr/testify/assert" "sigs.k8s.io/controller-runtime/pkg/client/fake" ) @@ -24,12 +25,10 @@ var ( testDomain = "example.com" // Use a valid domain for testing subjects = generateSubjects(2) // Generates ["subject-1", "subject-2"] webservices = generateWebservices(2) // Generates ["webservice-1", "webservice-2"] - // tokenSecretRef = generateTokenSecretRef() - ) +// generateIPAllowList Generate a large IP allow list with unique subnets func generateIPAllowList(size int) []string { - // Generate a large IP allow list with unique subnets ipList := make([]string, size) for i := 0; i < size; i++ { ipList[i] = fmt.Sprintf("192.168.0.%d", i) @@ -37,8 +36,8 @@ func generateIPAllowList(size int) []string { return ipList } +// generateDomainAllowList Generate a large domain allow list with unique patterns func generateDomainAllowList(size int) []string { - // Generate a large domain allow list with unique patterns domainList := make([]string, size) for i := 0; i < size; i++ { domainList[i] = fmt.Sprintf("example%d.com", i) @@ -47,11 +46,7 @@ func generateDomainAllowList(size int) []string { return domainList } -// func generateTokenSecretRef() *corev1.LocalObjectReference { -// example := &corev1.LocalObjectReference{Name: "example-token-secret-ref"} -// return example -// } - +// generateSubjects create a list of subjects in form of string array func generateSubjects(subjectCount int) []string { subject := make([]string, subjectCount) @@ -62,6 +57,7 @@ func generateSubjects(subjectCount int) []string { return subject } +// generateWebservices create a list of webservices in form of LocalWebserviceReference array func generateWebservices(webserviceCount int) []cerberusv1alpha1.LocalWebserviceReference { webservice := make([]cerberusv1alpha1.LocalWebserviceReference, webserviceCount) @@ -88,55 +84,6 @@ func BenchmarkCheckDomainWithLargeInput(b *testing.B) { } } -func TestCheckDomainComplex(t *testing.T) { - testCases := []struct { - domain string - domainAllowed []string - expectedResult bool - }{ - // Exact domain matches - {"example.com", []string{"example.com"}, true}, - {"sub.example.com", []string{"sub.example.com"}, true}, - {"sub.sub.example.com", []string{"sub.sub.example.com"}, true}, - - // Wildcard prefix and suffix matches - {"sub.example.com", []string{"*.example.com"}, true}, - {"example.net", []string{"example.*"}, true}, - - // Multiple patterns with mixed results - {"example.com", []string{"example.net", "*.example.com", "example.*"}, true}, - {"sub.sub.example.net", []string{"*.example.com", "example.*"}, false}, - {"example.org", []string{"*.example.com", "example.*"}, true}, - - // Case-insensitive matching - {"ExAmPlE.CoM", []string{"example.com"}, true}, - - // Character class [a-z0-9] - {"example1.com", []string{"example[0-9].com"}, true}, - {"examplea.com", []string{"example[a-z].com"}, true}, - {"exampleA.com", []string{"example[a-z].com"}, true}, - - // Multiple * wildcards - {"sub.sub.example.net", []string{"*.sub.*.net"}, true}, - {"sub.sub.example.net", []string{"*.*.*.net"}, true}, - {"sub.sub.example.net", []string{"*.example.net"}, true}, - - // ? wildcard character - {"example1.com", []string{"example?.com"}, true}, - {"example12.com", []string{"example?.com"}, false}, - } - - for _, tc := range testCases { - result, err := CheckDomain(tc.domain, tc.domainAllowed) - if result != tc.expectedResult { - t.Errorf("Domain: %s, Expected: %v, Got: %v", tc.domain, tc.expectedResult, result) - } - if err != nil { - t.Errorf("Domain: %s, Expected Error: nil, Got Error: %v", tc.domain, err) - } - } -} - func TestReadService(t *testing.T) { authenticator := &Authenticator{ accessTokensCache: &AccessTokensCache{}, @@ -173,7 +120,7 @@ func TestReadService(t *testing.T) { expectedReason CerberusReason expectedCacheEntry WebservicesCacheEntry }{ - {wsvc, true, "", webservice}, + {wsvc, true, CerberusReasonNotSet, webservice}, {"nonexistent_service", false, CerberusReasonWebserviceNotFound, WebservicesCacheEntry{}}, } @@ -216,13 +163,13 @@ func TestReadToken(t *testing.T) { reason, token := authenticator.readToken(request, webservice) - if reason != "" { - t.Errorf("Expected reason to be empty.") - } + assert.Equal(t, reason, CerberusReasonNotSet, "Expected reason to be empty.") + assert.Equalf(t, token, "test-token", "Expected token to be 'test-token'. Got: %s", token) + + webservice.Spec.LookupHeader = "" + reason, _ = authenticator.readToken(request, webservice) + assert.Equal(t, reason, CerberusReasonLookupIdentifierEmpty, "lookup-identifier-empty") - if token != "test-token" { - t.Errorf("Expected token to be 'test-token'. Got: %s", token) - } } func TestUpdateCache(t *testing.T) { @@ -376,6 +323,9 @@ func TestTestAccessBadIPList(t *testing.T) { authenticator := &Authenticator{ accessTokensCache: &AccessTokensCache{}, webservicesCache: &WebservicesCache{}, + validators: []AuthenticationValidation{ + &AuthenticationIPValidation{}, + }, } tokenEntry := AccessTokensCacheEntry{ @@ -431,6 +381,7 @@ func TestTestAccessLimited(t *testing.T) { authenticator := &Authenticator{ accessTokensCache: &AccessTokensCache{}, webservicesCache: &WebservicesCache{}, + validators: defineValidators(), } // Assuming an a token with lower Priority than WebService threshold @@ -484,6 +435,8 @@ func TestTestAccessLimited(t *testing.T) { } +// setupTestEnvironment create test environment for kubernetes client enabled +// tests to mock the apis. func setupTestEnvironment(t *testing.T) (client.Client, *Authenticator) { // Initialize a Kubernetes client's scheme. scheme := runtime.NewScheme() @@ -508,6 +461,7 @@ func setupTestEnvironment(t *testing.T) (client.Client, *Authenticator) { return fakeClient, authenticator } +// prepareAccessTokens create a list of test AccessTokens for tests func prepareAccessTokens(count int) []cerberusv1alpha1.AccessToken { // Create and prepare access tokens with unique names. @@ -526,6 +480,7 @@ func prepareAccessTokens(count int) []cerberusv1alpha1.AccessToken { return accessTokens } +// prepareWebserviceAccessBindings create a list of test WebserviceAccessBindings for tests func prepareWebserviceAccessBindings(count int) []cerberusv1alpha1.WebserviceAccessBinding { // Create and prepare webservice access bindings with unique names. bindings := make([]cerberusv1alpha1.WebserviceAccessBinding, count) @@ -542,6 +497,7 @@ func prepareWebserviceAccessBindings(count int) []cerberusv1alpha1.WebserviceAcc return bindings } +// prepareWebservices creates a list of WebServices for tests func prepareWebservices(count int) []cerberusv1alpha1.WebService { // Create and prepare webservice resources with unique names. webservices := make([]cerberusv1alpha1.WebService, count) @@ -562,6 +518,7 @@ func prepareWebservices(count int) []cerberusv1alpha1.WebService { return webservices } +// prepareWebservices creates a list of WebServices for tests func prepareSecrets(count int) []corev1.Secret { // Create and prepare secrets with unique names. secrets := make([]corev1.Secret, count) @@ -577,6 +534,7 @@ func prepareSecrets(count int) []corev1.Secret { return secrets } +// createAccessTokens creates a list of access tokens in kubernetes fake client func createAccessTokens(t *testing.T, fakeClient client.Client, accessTokens ...cerberusv1alpha1.AccessToken) { ctx := context.Background() for _, token := range accessTokens { @@ -584,6 +542,7 @@ func createAccessTokens(t *testing.T, fakeClient client.Client, accessTokens ... } } +// createBindings creates a list of bindings in kubernetes fake client func createBindings(t *testing.T, fakeClient client.Client, bindings ...cerberusv1alpha1.WebserviceAccessBinding) { ctx := context.Background() for _, binding := range bindings { @@ -591,6 +550,7 @@ func createBindings(t *testing.T, fakeClient client.Client, bindings ...cerberus } } +// createWebservices creates a list of Webservices in kubernetes fake client func createWebservices(t *testing.T, fakeClient client.Client, webservices ...cerberusv1alpha1.WebService) { ctx := context.Background() for _, service := range webservices { @@ -598,6 +558,7 @@ func createWebservices(t *testing.T, fakeClient client.Client, webservices ...ce } } +// createSecrets creates a list of secrets in kubernetes fake client func createSecrets(t *testing.T, fakeClient client.Client, secrets ...corev1.Secret) { ctx := context.Background() for _, secret := range secrets { @@ -605,6 +566,8 @@ func createSecrets(t *testing.T, fakeClient client.Client, secrets ...corev1.Sec } } +// assertCachesPopulated asserts that webservicesCache is populated. +// authenticator.accessTokensCache is not supported yet func assertCachesPopulated(t *testing.T, authenticator *Authenticator) { authenticator.cacheLock.RLock() defer authenticator.cacheLock.RUnlock() @@ -613,3 +576,401 @@ func assertCachesPopulated(t *testing.T, authenticator *Authenticator) { //assert.NotEmpty(t, authenticator.accessTokensCache) assert.NotEmpty(t, authenticator.webservicesCache) } + +func TestToExtraHeaders(t *testing.T) { + // Test case 1: Empty input + emptyInput := CerberusExtraHeaders{} + result := toExtraHeaders(emptyInput) + assert.Empty(t, result, "Result should be empty for empty input") + + // Test case 2: Input with multiple headers + input := CerberusExtraHeaders{ + "Header1": "Value1", + "Header2": "Value2", + "Header3": "Value3", + } + expected := ExtraHeaders{ + "Header1": "Value1", + "Header2": "Value2", + "Header3": "Value3", + } + result = toExtraHeaders(input) + assert.Equal(t, expected, result, "Result should match expected extra headers") + + // Test case 3: Input with a single header + singleHeaderInput := CerberusExtraHeaders{ + "Header": "Value", + } + singleExpected := ExtraHeaders{ + "Header": "Value", + } + singleResult := toExtraHeaders(singleHeaderInput) + assert.Equal(t, singleExpected, singleResult, "Result should match expected extra headers") +} + +func TestCheckDomain(t *testing.T) { + // Test case 1: Domain matches one of the GLOB patterns + domain := "example.com" + domainAllowedList := []string{"*.com", "*.org", "example.*"} + matched, err := CheckDomain(domain, domainAllowedList) + assert.NoError(t, err, "No error should occur") + assert.True(t, matched, "Domain should match one of the GLOB patterns") + + // Test case 2: Domain matches one of the GLOB patterns and it + // does not care about the case + domain = "ExampLe.Com" + domainAllowedList = []string{"*.CoM", "*.org", "eXample.*"} + matched, err = CheckDomain(domain, domainAllowedList) + assert.NoError(t, err, "No error should occur") + assert.True(t, matched, "Domain should match one of the GLOB patterns") + + // Test case 3: Domain does not match any of the GLOB patterns + domain = "example.net" + domainAllowedList = []string{"*.com", "*.org", "google.*"} + matched, err = CheckDomain(domain, domainAllowedList) + assert.NoError(t, err, "No error should occur") + assert.False(t, matched, "Domain should not match any of the GLOB patterns") + + // Test case 4: Error while matching the domain with GLOB patterns + domain = "example.com" + domainAllowedList = []string{"[invalid pattern"} + matched, err = CheckDomain(domain, domainAllowedList) + assert.Error(t, err, "Error should occur") + assert.False(t, matched, "Domain should not match due to error") + assert.EqualError(t, err, "syntax error in pattern", "Error message should indicate syntax error") +} + +func Test_hasUpstreamAuth(t *testing.T) { + wsce := WebservicesCacheEntry{ + WebService: cerberusv1alpha1.WebService{ + Spec: cerberusv1alpha1.WebServiceSpec{ + UpstreamHttpAuth: cerberusv1alpha1.UpstreamHttpAuthService{ + Address: "", + }, + }, + }, + } + assert.False(t, hasUpstreamAuth(wsce)) + wsce.Spec.UpstreamHttpAuth.Address = "anything" + + assert.True(t, hasUpstreamAuth(wsce)) + +} + +func TestCerberusExtraHeaders_merge_set(t *testing.T) { + h := CerberusExtraHeaders{} + + h.merge(CerberusExtraHeaders{"a": "b"}) + + assert.Len(t, h, 1) + assert.Equal(t, h["a"], "b") + + h.merge(CerberusExtraHeaders{"a": "c", "x": "y", "z": "w"}) + assert.Len(t, h, 3) + assert.Equal(t, h["a"], "c") //Overwritten + assert.Equal(t, h["x"], "y") // Newly added + assert.Equal(t, h["z"], "w") // Newly added + + h.set(CerberusHeaderAccessToken, "test") + + assert.Equal(t, h[CerberusHeaderAccessToken], "test") +} + +func Test_readRequestContext(t *testing.T) { + // Test case 1: Valid context keys + request := &Request{ + Context: map[string]string{ + "webservice": "example-service", + "namespace": "example-namespace", + }, + } + wsvc, ns, reason := readRequestContext(request) + assert.Equal(t, "example-service", wsvc, "Webservice should match expected value") + assert.Equal(t, "example-namespace", ns, "Namespace should match expected value") + assert.Empty(t, reason, "Reason should be empty") + + // Test case 2: Webservice key missing + request = &Request{ + Context: map[string]string{ + "namespace": "example-namespace", + }, + } + wsvc, ns, reason = readRequestContext(request) + assert.Empty(t, wsvc, "Webservice should be empty") + assert.Empty(t, ns, "Namespace should be empty") + assert.Equal(t, CerberusReasonWebserviceEmpty, reason, "Reason should indicate webservice key missing") + + // Test case 3: Namespace key missing + request = &Request{ + Context: map[string]string{ + "webservice": "example-service", + }, + } + wsvc, ns, reason = readRequestContext(request) + assert.Empty(t, wsvc, "Webservice should be empty") + assert.Empty(t, ns, "Namespace should be empty") + assert.Equal(t, CerberusReasonWebserviceNamespaceEmpty, reason, "Reason should indicate namespace key missing") + + // Test case 4: Both keys missing + request = &Request{ + Context: map[string]string{}, + } + wsvc, ns, reason = readRequestContext(request) + assert.Empty(t, wsvc, "Webservice should be empty") + assert.Empty(t, ns, "Namespace should be empty") + assert.Equal(t, CerberusReasonWebserviceEmpty, reason, "Reason should indicate webservice key missing") +} + +func Test_generateResponse(t *testing.T) { + // Test case 1: Response is allowed with no extra headers + expectedResponse := &Response{ + Allow: true, + Response: http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + ExternalAuthHandlerHeader: {"cerberus"}, + CerberusHeaderReasonHeader: {"reason"}, + }, + }, + } + actualResponse := generateResponse(true, "reason", nil) + assert.Equal(t, expectedResponse.Allow, actualResponse.Allow, "Response should be allowed") + assert.Equal(t, expectedResponse.Response.StatusCode, actualResponse.Response.StatusCode, "HTTP status code should match") + assert.Equal(t, expectedResponse.Response.Header, actualResponse.Response.Header, "Response headers should match") + + // Test case 2: Response is not allowed with extra headers + extraHeaders := ExtraHeaders{"Extra-Header": "value"} + expectedResponse = &Response{ + Allow: false, + Response: http.Response{ + StatusCode: http.StatusUnauthorized, + Header: http.Header{ + ExternalAuthHandlerHeader: {"cerberus"}, + CerberusHeaderReasonHeader: {"reason"}, + "Extra-Header": {"value"}, + }, + }, + } + actualResponse = generateResponse(false, "reason", extraHeaders) + assert.Equal(t, expectedResponse.Allow, actualResponse.Allow, "Response should not be allowed") + assert.Equal(t, expectedResponse.Response.StatusCode, actualResponse.Response.StatusCode, "HTTP status code should match") + assert.Equal(t, expectedResponse.Response.Header, actualResponse.Response.Header, "Response headers should match") +} + +func TestValidateUpstreamAuthRequest(t *testing.T) { + // Test case 1: ReadTokenFrom and WriteTokenTo are empty + service := WebservicesCacheEntry{} + service.Spec.UpstreamHttpAuth.ReadTokenFrom = "" + service.Spec.UpstreamHttpAuth.WriteTokenTo = "" + reason := validateUpstreamAuthRequest(service) + assert.Equal(t, CerberusReasonTargetAuthTokenEmpty, reason, "Expected target auth token empty") + + // Test case 2: WriteTokenTo is empty + service = WebservicesCacheEntry{} + service.Spec.UpstreamHttpAuth.ReadTokenFrom = "token" + service.Spec.UpstreamHttpAuth.WriteTokenTo = "" + reason = validateUpstreamAuthRequest(service) + assert.Equal(t, CerberusReasonTargetAuthTokenEmpty, reason, "Expected target auth token empty") + + // Test case 3: ReadTokenFrom is empty + service = WebservicesCacheEntry{} + service.Spec.UpstreamHttpAuth.ReadTokenFrom = "" + service.Spec.UpstreamHttpAuth.WriteTokenTo = "token" + reason = validateUpstreamAuthRequest(service) + assert.Equal(t, CerberusReasonTargetAuthTokenEmpty, reason, "Expected target auth token empty") + + // Test case 4: Address is invalid + service = WebservicesCacheEntry{} + service.Spec.UpstreamHttpAuth.ReadTokenFrom = "token" + service.Spec.UpstreamHttpAuth.WriteTokenTo = "token" + service.Spec.UpstreamHttpAuth.Address = "not a valid URL" + reason = validateUpstreamAuthRequest(service) + assert.Equal(t, CerberusReasonInvalidUpstreamAddress, reason, "Expected invalid upstream address") + + // Test case 5: Everything is valid + service = WebservicesCacheEntry{} + service.Spec.UpstreamHttpAuth.ReadTokenFrom = "token" + service.Spec.UpstreamHttpAuth.WriteTokenTo = "token" + service.Spec.UpstreamHttpAuth.Address = "http://example.com" + reason = validateUpstreamAuthRequest(service) + assert.Empty(t, reason, "Expected no reason") +} + +// MockHTTPClient is a mock implementation of http.Client for testing purposes. +type MockTransport struct { + DoFunc func(req *http.Request) (*http.Response, error) +} + +// Do executes the provided HTTP request and returns the response. +func (c *MockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return c.DoFunc(req) +} + +func TestAdjustTimeoutWithHTTPClientMock(t *testing.T) { + // Test case 1: No downstream deadline + transport := &MockTransport{ + DoFunc: func(req *http.Request) (*http.Response, error) { + // Mock response + return &http.Response{ + StatusCode: http.StatusOK, + Body: http.NoBody, + }, nil + }, + } + + authenticator := Authenticator{ + httpClient: &http.Client{Transport: transport}, + } + + // TestCase 1: No deadline + timeout := 1000 + downstreamDeadline := time.Now() + hasDownstreamDeadline := false + + expectedTimeout := time.Duration(1000) * time.Millisecond + authenticator.adjustTimeout(timeout, downstreamDeadline, hasDownstreamDeadline) + assert.Equal(t, expectedTimeout, authenticator.httpClient.Timeout, "Timeout should match expected value") + + // TestCase 2, With deadline but it is a bad test because it is not deterministic + timeout = 1000 + downstreamDeadline = time.Now().Add(time.Duration(1000) * time.Millisecond) + hasDownstreamDeadline = true + expectedTimeout = time.Duration(1000)*time.Millisecond - downstreamDeadlineOffset + authenticator.adjustTimeout(timeout, downstreamDeadline, hasDownstreamDeadline) + assert.NotEqual(t, expectedTimeout, authenticator.httpClient.Timeout) + assert.LessOrEqual(t, expectedTimeout-authenticator.httpClient.Timeout, time.Duration(50)*time.Microsecond, "Timeout should match expected value") + + // TestCase 3, With deadline but it is more than httpTimeout + timeout = 10 + downstreamDeadline = time.Now().Add(time.Duration(100) * time.Millisecond) + hasDownstreamDeadline = true + expectedTimeout = time.Duration(10) * time.Millisecond + authenticator.adjustTimeout(timeout, downstreamDeadline, hasDownstreamDeadline) + assert.Equal(t, expectedTimeout, authenticator.httpClient.Timeout, "Timeout should match expected value") + +} + +func TestCopyUpstreamHeaders(t *testing.T) { + // Test case 1: Header is copied to extraHeaders + resp := &http.Response{ + Header: http.Header{ + "Header1": {"Value1"}, + "Header2": {"Value2"}, + }, + } + extraHeaders := make(ExtraHeaders) + careHeaders := []string{"Header1"} + + copyUpstreamHeaders(resp, &extraHeaders, careHeaders) + assert.Equal(t, "Value1", extraHeaders["Header1"], "Header1 should be copied to extraHeaders") + assert.Empty(t, extraHeaders["Header2"], "Header2 should not be copied to extraHeaders") + + // Test case 2: No headers are copied + resp = &http.Response{ + Header: http.Header{ + "Header1": {"Value1"}, + "Header2": {"Value2"}, + }, + } + extraHeaders = make(ExtraHeaders) + careHeaders = []string{} + + copyUpstreamHeaders(resp, &extraHeaders, careHeaders) + assert.Empty(t, extraHeaders, "No headers should be copied to extraHeaders") + + // Test case 3: Multiple headers are copied + resp = &http.Response{ + Header: http.Header{ + "Header1": {"Value1"}, + "Header2": {"Value2"}, + "Header3": {"Value3"}, + }, + } + extraHeaders = make(ExtraHeaders) + careHeaders = []string{"Header1", "Header3"} + + copyUpstreamHeaders(resp, &extraHeaders, careHeaders) + assert.Equal(t, "Value1", extraHeaders["Header1"], "Header1 should be copied to extraHeaders") + assert.Empty(t, extraHeaders["Header2"], "Header2 should not be copied to extraHeaders") + assert.Equal(t, "Value3", extraHeaders["Header3"], "Header3 should be copied to extraHeaders") +} + +// Mock error interface with timeout interface implementation for +// testing timeout errors in tests +type innerError struct { + timeout bool +} + +func (inner innerError) Timeout() bool { + return true +} + +func (inner innerError) Error() string { + panic("should not be used") +} +func TestProcessResponseError(t *testing.T) { + // Test case 1: No error + reason := processResponseError(nil) + assert.Equal(t, CerberusReasonNotSet, reason, "No error should return an empty string") + + // Test case 2: Timeout error + urlErr := &url.Error{ + Op: "Get", + URL: "http://example.com", + Err: &innerError{timeout: true}, + } + reason = processResponseError(urlErr) + assert.Equal(t, CerberusReasonUpstreamAuthTimeout, reason, "Timeout error should return upstream auth timeout") + + // Test case 2: Timeout error + urlErr = &url.Error{ + Op: "Get", + URL: "http://example.com", + Err: errors.New("no timeout implemeted"), + } + reason = processResponseError(urlErr) + assert.Equal(t, CerberusReasonUpstreamAuthFailed, reason, "Timeout error should return upstream auth timeout") + + // Test case 3: Other error + reason = processResponseError(errors.New("connection refused")) + assert.Equal(t, CerberusReasonUpstreamAuthFailed, reason, "Other errors should return upstream auth failed") +} + +func TestSetupUpstreamAuthRequest(t *testing.T) { + // Test case 1: Successful setup + upstreamAuth := &cerberusv1alpha1.UpstreamHttpAuthService{ + ReadTokenFrom: "X-Token-Read", + WriteTokenTo: "X-Token-Write", + Address: "http://example.com", + Timeout: 1000, + } + + request := &Request{ + Request: http.Request{ + Header: http.Header{ + "X-Token-Read": {"value"}, + }, + }, + } + + expectedReq, _ := http.NewRequest("GET", "http://example.com", nil) + expectedReq.Header = http.Header{ + "X-Token-Write": {"value"}, + "Content-Type": {"application/json"}, + } + + actualReq, actualErr := setupUpstreamAuthRequest(upstreamAuth, request) + assert.NoError(t, actualErr, "No error should occur") + assert.Equal(t, expectedReq.URL.String(), actualReq.URL.String(), "Request URL should match") + assert.Equal(t, expectedReq.Header, actualReq.Header, "Request headers should match") + + // Test case 2: Error from http.NewRequest + upstreamAuth = &cerberusv1alpha1.UpstreamHttpAuthService{ + Address: ":", + } // Empty service + request = &Request{} + + actualReq, actualErr = setupUpstreamAuthRequest(upstreamAuth, request) + assert.Nil(t, actualReq, "Request should be nil when there is an error") + assert.Error(t, actualErr, "Error should occur when service is empty") +} diff --git a/pkg/auth/cerberus_reasons.go b/pkg/auth/cerberus_reasons.go index aba4ad6..fd67e41 100644 --- a/pkg/auth/cerberus_reasons.go +++ b/pkg/auth/cerberus_reasons.go @@ -10,6 +10,10 @@ const ( // OTHER THAN CerberusReasonOK means that the request is NOT authenticated CerberusReasonOK CerberusReason = "ok" + // CerberusReasonNotSet means no reason is set during checks and it means + // the process should continue to find the reason + CerberusReasonNotSet CerberusReason = "" + // CerberusReasonUnauthorized means that given AccessToken is found but // it does NOT have access to requested Webservice CerberusReasonUnauthorized CerberusReason = "unauthorized" @@ -61,6 +65,10 @@ const ( // the request context is empty or it's not given at all CerberusReasonWebserviceEmpty CerberusReason = "webservice-empty" + // CerberusReasonWebserviceNotAllowed means that given webservice in + // the request context is empty or it's not given at all + CerberusReasonWebserviceNotAllowed CerberusReason = "webservice-not-allowed" + // CerberusReasonWebserviceNamespaceEmpty means that given namespace of webservice in // the request context is empty or it's not given at all CerberusReasonWebserviceNamespaceEmpty CerberusReason = "webservice-namespace-empty" diff --git a/pkg/auth/metrics_test.go b/pkg/auth/metrics_test.go new file mode 100644 index 0000000..525736b --- /dev/null +++ b/pkg/auth/metrics_test.go @@ -0,0 +1,84 @@ +package auth + +import ( + "testing" + + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" +) + +func TestAddReasonLabel(t *testing.T) { + // Test case 1: Adding reason label + labels := AddReasonLabel(nil, "random-reason") + assert.NotNil(t, labels, "Labels should not be nil") + assert.Equal(t, "random-reason", labels[CerberusReasonLabel]) + + // Test case 2: Existing labels + existingLabels := prometheus.Labels{"existing": "label"} + labels = AddReasonLabel(existingLabels, "random-reason") + assert.NotNil(t, labels, "Labels should not be nil") + assert.Equal(t, "random-reason", labels[CerberusReasonLabel]) + assert.Equal(t, "label", labels["existing"], "Existing label should remain unchanged") +} + +func TestAddKindLabel(t *testing.T) { + // Test case 1: Adding kind label + labels := AddKindLabel(nil, "some_kind") + assert.NotNil(t, labels, "Labels should not be nil") + assert.Equal(t, "some_kind", labels[ObjectKindLabel], "Kind label should be added") + + existingLabels := prometheus.Labels{"existing": "label"} + labels = AddKindLabel(existingLabels, "some_kind") + assert.NotNil(t, labels, "Labels should not be nil") + assert.Equal(t, "some_kind", labels[ObjectKindLabel]) + assert.Equal(t, "label", labels["existing"]) + +} +func TestAddStatusLabel(t *testing.T) { + // Test case 1: Adding status label + labels := AddStatusLabel(nil, 200) + assert.NotNil(t, labels, "Labels should not be nil") + assert.Equal(t, "200", labels[StatusCode], "Status label should be added") + + existingLabels := prometheus.Labels{"existing": "label"} + labels = AddStatusLabel(existingLabels, 200) + assert.NotNil(t, labels, "Labels should not be nil") + assert.Equal(t, "200", labels[StatusCode], "Status label should be '200'") + assert.Equal(t, "label", labels["existing"], "Existing label should remain unchanged") +} +func TestAddUpstreamAuthLabel(t *testing.T) { + // Test case 1: With upstream auth + labels := AddUpstreamAuthLabel(nil, "true") + assert.NotNil(t, labels, "Labels should not be nil") + assert.Equal(t, "true", labels[HasUpstreamAuth], "HasUpstreamAuth label should be true") + + // Test case 2: Without upstream auth + labels = AddUpstreamAuthLabel(nil, "false") + assert.NotNil(t, labels, "Labels should not be nil") + assert.Equal(t, "false", labels[HasUpstreamAuth], "HasUpstreamAuth label should be false") + + // Test case 3: Existing labels + existingLabels := prometheus.Labels{"existing": "label"} + labels = AddUpstreamAuthLabel(existingLabels, "true") + assert.NotNil(t, labels, "Labels should not be nil") + assert.Equal(t, "true", labels[HasUpstreamAuth], "HasUpstreamAuth label should be true") + assert.Equal(t, "label", labels["existing"], "Existing label should remain unchanged") +} +func TestAddWithDownstreamDeadline(t *testing.T) { + // Test case 1: With downstream deadline + labels := AddWithDownstreamDeadline(nil, true) + assert.NotNil(t, labels, "Labels should not be nil") + assert.Equal(t, "true", labels[WithDownstreamDeadlineLabel], "WithDownstreamDeadlineLabel should be true") + + // Test case 2: Without downstream deadline + labels = AddWithDownstreamDeadline(nil, false) + assert.NotNil(t, labels, "Labels should not be nil") + assert.Equal(t, "false", labels[WithDownstreamDeadlineLabel], "WithDownstreamDeadlineLabel should be false") + + // Test case 3: Existing labels + existingLabels := prometheus.Labels{"existing": "label"} + labels = AddWithDownstreamDeadline(existingLabels, true) + assert.NotNil(t, labels, "Labels should not be nil") + assert.Equal(t, "true", labels[WithDownstreamDeadlineLabel], "WithDownstreamDeadlineLabel should be true") + assert.Equal(t, "label", labels["existing"], "Existing label should remain unchanged") +} diff --git a/pkg/testutils/logging.go b/pkg/testutils/logging.go new file mode 100644 index 0000000..1a99305 --- /dev/null +++ b/pkg/testutils/logging.go @@ -0,0 +1,96 @@ +package testutils + +import ( + "github.com/go-logr/logr" +) + +// Log is a structure of logr logs targets easy assertion in tests +type Log struct { + Message string + Type string + Name string + KeyValues map[interface{}]interface{} +} + +// Logs is array of logs to define references +type Logs []Log + +// TestLogSink is a simple implementation of the LogSink interface for testing purposes. +type TestLogSink struct { + Logs *Logs + currentValues []interface{} + currentName string +} + +func (sink TestLogSink) GetLog(n int) Log { + return (*sink.Logs)[n] +} + +func NewTestLogSink() *TestLogSink { + return &TestLogSink{ + Logs: &Logs{}, + } +} + +// keysAndValuesToMap converts key-value pairs to a map. +func keysAndValuesToMap(keysAndValues []interface{}) map[interface{}]interface{} { + // Initialize an empty map + result := make(map[interface{}]interface{}) + // Ensure that there are an even number of arguments + if len(keysAndValues)%2 != 0 { + // If the number of arguments is odd, return an empty map + return result + } + + // Iterate over the key-value pairs + for i := 0; i < len(keysAndValues); i += 2 { + // Assign the key-value pairs to the map + key := keysAndValues[i] + value := keysAndValues[i+1] + result[key] = value + } + + return result +} + +// Init initializes the TestLogSink. +func (t *TestLogSink) Init(info logr.RuntimeInfo) { + // For simplicity, we don't use any information about the logr library here. +} + +// Enabled always returns true, assuming all logs should be captured in tests. +func (t *TestLogSink) Enabled(level int) bool { + return true +} + +// Info captures log messages. +func (t *TestLogSink) Info(level int, msg string, keysAndValues ...interface{}) { + if len(t.currentValues) > 0 { + keysAndValues = append(keysAndValues, t.currentValues...) + } + *t.Logs = append(*t.Logs, Log{Name: t.currentName, Type: "info", Message: msg, KeyValues: keysAndValuesToMap(keysAndValues)}) +} + +// Error captures error messages. +func (t *TestLogSink) Error(err error, msg string, keysAndValues ...interface{}) { + if len(t.currentValues) > 0 { + keysAndValues = append(keysAndValues, t.currentValues...) + } + *t.Logs = append(*t.Logs, Log{Name: t.currentName, Type: "error", Message: err.Error(), KeyValues: keysAndValuesToMap(keysAndValues)}) +} + +// WithValues is not used in this simple implementation. +func (t *TestLogSink) WithValues(keysAndValues ...interface{}) logr.LogSink { + sink := &TestLogSink{Logs: t.Logs} + sink.currentValues = append(t.currentValues, keysAndValues...) + sink.currentName = t.currentName + return sink +} + +// WithName is not used in this simple implementation. +func (t *TestLogSink) WithName(name string) logr.LogSink { + sink := &TestLogSink{Logs: t.Logs} + sink.currentValues = t.currentValues + sink.currentName = t.currentName + return sink +}