diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml new file mode 100644 index 0000000..6726f01 --- /dev/null +++ b/.github/workflows/push.yml @@ -0,0 +1,19 @@ +name: Push + +on: + push: + branches: + - main + pull_request: + types: + - opened + - synchronize + - reopened + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-go@v3 + - run: go test -v ./... diff --git a/context.go b/context.go index 123a2f8..0c93660 100644 --- a/context.go +++ b/context.go @@ -3,6 +3,8 @@ package ssokenizer import ( "context" "net/http" + + "github.com/sirupsen/logrus" ) type contextKey string @@ -10,6 +12,7 @@ type contextKey string const ( contextKeyTransaction contextKey = "transaction" contextKeyProvider contextKey = "provider" + contextKeyLog contextKey = "log" ) func withTransaction(r *http.Request, t *Transaction) *http.Request { @@ -27,3 +30,36 @@ func withProvider(r *http.Request, p *provider) *http.Request { func getProvider(r *http.Request) *provider { return r.Context().Value(contextKeyProvider).(*provider) } + +// Updates the logrus.FieldLogger in the context with added data. Requests are +// logged by Transaction.ReturnData/ReturnError. +func WithLog(r *http.Request, l logrus.FieldLogger) *http.Request { + return r.WithContext(context.WithValue(r.Context(), contextKeyLog, l)) +} + +// Updates the logrus.FieldLogger in the context with "error" field. Requests +// are logged by Transaction.ReturnData/ReturnError. +func WithError(r *http.Request, err error) *http.Request { + return WithLog(r, GetLog(r).WithError(err)) +} + +// Updates the logrus.FieldLogger in the context with added field. Requests +// are logged by Transaction.ReturnData/ReturnError. +func WithField(r *http.Request, key string, value any) *http.Request { + return WithLog(r, GetLog(r).WithField(key, value)) +} + +// Updates the logrus.FieldLogger in the context with added fields. Requests +// are logged by Transaction.ReturnData/ReturnError. +func WithFields(r *http.Request, fields logrus.Fields) *http.Request { + return WithLog(r, GetLog(r).WithFields(fields)) +} + +// Gets the logrus.FieldLogger from the context. Requests are logged by +// Transaction.ReturnData/ReturnError. +func GetLog(r *http.Request) logrus.FieldLogger { + if l, ok := r.Context().Value(contextKeyLog).(logrus.FieldLogger); ok { + return l + } + return logrus.StandardLogger() +} diff --git a/oauth2/oauth2.go b/oauth2/oauth2.go index 9906fa6..0f7cc6e 100644 --- a/oauth2/oauth2.go +++ b/oauth2/oauth2.go @@ -2,6 +2,8 @@ package oauth2 import ( "crypto/subtle" + "encoding/base64" + "encoding/json" "errors" "fmt" "net/http" @@ -73,6 +75,8 @@ func (p *provider) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (p *provider) handleStart(w http.ResponseWriter, r *http.Request) { + defer getLog(r).WithField("status", http.StatusFound).Info() + tr := ssokenizer.GetTransaction(r) http.Redirect(w, r, p.config(r).AuthCodeURL(tr.Nonce, oauth2.AccessTypeOffline), http.StatusFound) } @@ -82,39 +86,44 @@ func (p *provider) handleCallback(w http.ResponseWriter, r *http.Request) { params := r.URL.Query() if errParam := params.Get("error"); errParam != "" { + r = withError(r, fmt.Errorf("error param: %s", errParam)) tr.ReturnError(w, r, errParam) return } state := params.Get("state") if state == "" { - logrus.Warn("missing state") + r = withError(r, errors.New("missing state")) tr.ReturnError(w, r, "bad response") return } if subtle.ConstantTimeCompare([]byte(tr.Nonce), []byte(state)) != 1 { - logrus.WithFields(logrus.Fields{"have": state, "want": tr.Nonce}).Warn("bad state") + r = withError(r, errors.New("bad state")) + r = withFields(r, logrus.Fields{"have": state, "want": tr.Nonce}) tr.ReturnError(w, r, "bad response") return } code := params.Get("code") if code == "" { - logrus.Warn("missing code") + r = withError(r, errors.New("missing code")) tr.ReturnError(w, r, "bad response") return } tok, err := p.config(r).Exchange(r.Context(), code, oauth2.AccessTypeOffline) if err != nil { - logrus.WithError(err).Warn("failed exchange") + r = withError(r, fmt.Errorf("failed exchange: %w", err)) tr.ReturnError(w, r, "bad response") return } + r = withIdToken(r, tok) + if t := tok.Type(); t != "Bearer" { - logrus.WithField("type", t).Warn("unrecognized token type") + r = withField(r, "type", t) + r = withError(r, errors.New("unrecognized token type")) tr.ReturnError(w, r, "bad response") return } @@ -131,7 +140,7 @@ func (p *provider) handleCallback(w http.ResponseWriter, r *http.Request) { sealed, err := secret.Seal(p.sealKey) if err != nil { - logrus.WithError(err).Warn("failed seal") + r = withError(r, fmt.Errorf("failed seal: %w", err)) tr.ReturnError(w, r, "seal error") return } @@ -145,19 +154,33 @@ func (p *provider) handleCallback(w http.ResponseWriter, r *http.Request) { func (p *provider) handleRefresh(w http.ResponseWriter, r *http.Request) { refreshToken, ok := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer ") if !ok { - logrus.Warn("refresh: missing token") + getLog(r). + WithField("status", http.StatusUnauthorized). + Info("refresh: missing token") + w.WriteHeader(http.StatusUnauthorized) return } tok, err := p.config(r).TokenSource(r.Context(), &oauth2.Token{RefreshToken: refreshToken}).Token() if err != nil { - logrus.WithError(err).Warn("refresh") + getLog(r). + WithField("status", http.StatusBadGateway). + WithError(err). + Info("refresh") + w.WriteHeader(http.StatusBadGateway) return } + + r = withIdToken(r, tok) + if t := tok.Type(); t != "Bearer" { - logrus.WithField("type", t).Warn("unrecognized token type") + getLog(r). + WithField("status", http.StatusInternalServerError). + WithField("type", t). + Info("unrecognized token type") + w.WriteHeader(http.StatusInternalServerError) return } @@ -175,7 +198,11 @@ func (p *provider) handleRefresh(w http.ResponseWriter, r *http.Request) { sealed, err := secret.Seal(p.sealKey) if err != nil { - logrus.WithError(err).Warn("refresh: failed seal") + getLog(r). + WithField("status", http.StatusInternalServerError). + WithError(err). + Info("refresh: failed seal") + w.WriteHeader(http.StatusInternalServerError) return } @@ -185,9 +212,16 @@ func (p *provider) handleRefresh(w http.ResponseWriter, r *http.Request) { if _, err := w.Write([]byte(sealed)); err != nil { // status already written - logrus.WithError(err).Warn("refresh: write response") + getLog(r). + WithError(err). + Info("refresh: write response") + return } + + getLog(r). + WithField("status", http.StatusOK). + Info() } func (p *provider) config(r *http.Request) *Config { @@ -210,3 +244,50 @@ func (p *provider) requestValidators(r *http.Request) []tokenizer.RequestValidat re := regexp.MustCompile(fmt.Sprintf("^(%s|%s)$", regexp.QuoteMeta(r.Host), p.AllowedHostPattern)) return []tokenizer.RequestValidator{tokenizer.AllowHostPattern(re)} } + +// logging helpers. aliased for convenience +var ( + getLog = ssokenizer.GetLog + withError = ssokenizer.WithError + withField = ssokenizer.WithField + withFields = ssokenizer.WithFields +) + +// logging helper. Tries to find and parse user info from id token. +func withIdToken(r *http.Request, tok *oauth2.Token) *http.Request { + idToken, ok := tok.Extra("id_token").(string) + if !ok { + return r + } + + parts := strings.Split(idToken, ".") + if len(parts) < 2 { + return r + } + + jbody, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return r + } + + var body struct { + Sub string `json:"sub"` + HD string `json:"hd"` + Email string `json:"email"` + } + if err := json.Unmarshal(jbody, &body); err != nil { + return r + } + + if body.Sub != "" { + r = withField(r, "sub", body.Sub) + } + if body.HD != "" { + r = withField(r, "hd", body.HD) + } + if body.Email != "" { + r = withField(r, "email", body.Email) + } + + return r +} diff --git a/provider.go b/provider.go index c3fd976..0692d9f 100644 --- a/provider.go +++ b/provider.go @@ -10,7 +10,7 @@ import ( type provider struct { name string handler http.Handler - returnURL *url.URL + returnURL url.URL } // Arbitrary configuration type for providers to implement. diff --git a/ssokenizer.go b/ssokenizer.go index 6c42a9c..dbc598b 100644 --- a/ssokenizer.go +++ b/ssokenizer.go @@ -2,6 +2,7 @@ package ssokenizer import ( "context" + "errors" "fmt" "net" "net/http" @@ -30,10 +31,8 @@ type Server struct { // provided to tokenizer by the relying party in order to use the sealed token. func NewServer(sealKey string) *Server { s := &Server{ - sealKey: sealKey, - providers: map[string](*provider){ - "health": &provider{handler: handleHealth}, - }, + sealKey: sealKey, + providers: make(map[string](*provider)), } s.http = &http.Server{Handler: s} @@ -42,11 +41,17 @@ func NewServer(sealKey string) *Server { } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - logrus.WithFields(logrus.Fields{"method": r.Method, "uri": r.URL.Path, "host": r.Host}).Info() - providerName, rest, _ := strings.Cut(strings.TrimPrefix(r.URL.Path, "/"), "/") + if providerName == "health" { + fmt.Fprintln(w, "ok") + return + } + + r = WithFields(r, logrus.Fields{"method": r.Method, "uri": r.URL.Path, "host": r.Host}) + provider, ok := s.providers[providerName] if !ok { + GetLog(r).WithField("status", http.StatusNotFound).Info() w.WriteHeader(http.StatusNotFound) return } @@ -60,13 +65,13 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { if tc, err := r.Cookie(transactionCookieName); err != http.ErrNoCookie && tc.Value != "" { if err := unmarshalTransaction(t, tc.Value); err != nil { - logrus.WithError(err).Warn("bad transaction cookie") + r = WithError(r, fmt.Errorf("bad transaction cookie: %w", err)) t.ReturnError(w, r, "bad request") return } if time.Now().After(t.Expiry) { - logrus.Warn("expired transaction") + r = WithError(r, errors.New("expired transaction")) t.ReturnError(w, r, "expired") return } @@ -74,7 +79,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { ts, err := t.marshal() if err != nil { - logrus.WithError(err).Warn("marshal transaction cookie") + r = WithError(r, fmt.Errorf("marshal transaction cookie: %w", err)) t.ReturnError(w, r, "unexpected error") return } @@ -107,7 +112,7 @@ func (s *Server) AddProvider(name string, pc ProviderConfig, returnURL string, a s.providers[name] = &provider{ name: name, handler: p, - returnURL: ru, + returnURL: *ru, } return nil @@ -138,5 +143,3 @@ func (s *Server) Start(address string) error { func (s *Server) Shutdown(ctx context.Context) error { return s.http.Shutdown(ctx) } - -var handleHealth http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("ok")) } diff --git a/transaction.go b/transaction.go index 849b1b2..1abad4e 100644 --- a/transaction.go +++ b/transaction.go @@ -44,6 +44,8 @@ func (t *Transaction) ReturnError(w http.ResponseWriter, r *http.Request, msg st } func (t *Transaction) returnData(w http.ResponseWriter, r *http.Request, data map[string]string) { + defer GetLog(r).WithField("status", http.StatusFound).Info() + t.setCookie(w, r, "") returnURL := getProvider(r).returnURL