Skip to content

Commit

Permalink
Merge pull request #7 from superfly/ptr
Browse files Browse the repository at this point in the history
Don't mutate url on provider struct
  • Loading branch information
btoews authored Oct 27, 2023
2 parents 68dc8a2 + 1433780 commit 5436dd8
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 24 deletions.
19 changes: 19 additions & 0 deletions .github/workflows/push.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name: Push

on:
push:
branches:
- main
pull_request:
types:
- opened
- synchronize
- reopened

jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-go@v3
- run: go test -v ./...
36 changes: 36 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ package ssokenizer
import (
"context"
"net/http"

"github.com/sirupsen/logrus"
)

type contextKey string

const (
contextKeyTransaction contextKey = "transaction"
contextKeyProvider contextKey = "provider"
contextKeyLog contextKey = "log"
)

func withTransaction(r *http.Request, t *Transaction) *http.Request {
Expand All @@ -27,3 +30,36 @@ func withProvider(r *http.Request, p *provider) *http.Request {
func getProvider(r *http.Request) *provider {
return r.Context().Value(contextKeyProvider).(*provider)
}

// Updates the logrus.FieldLogger in the context with added data. Requests are
// logged by Transaction.ReturnData/ReturnError.
func WithLog(r *http.Request, l logrus.FieldLogger) *http.Request {
return r.WithContext(context.WithValue(r.Context(), contextKeyLog, l))
}

// Updates the logrus.FieldLogger in the context with "error" field. Requests
// are logged by Transaction.ReturnData/ReturnError.
func WithError(r *http.Request, err error) *http.Request {
return WithLog(r, GetLog(r).WithError(err))
}

// Updates the logrus.FieldLogger in the context with added field. Requests
// are logged by Transaction.ReturnData/ReturnError.
func WithField(r *http.Request, key string, value any) *http.Request {
return WithLog(r, GetLog(r).WithField(key, value))
}

// Updates the logrus.FieldLogger in the context with added fields. Requests
// are logged by Transaction.ReturnData/ReturnError.
func WithFields(r *http.Request, fields logrus.Fields) *http.Request {
return WithLog(r, GetLog(r).WithFields(fields))
}

// Gets the logrus.FieldLogger from the context. Requests are logged by
// Transaction.ReturnData/ReturnError.
func GetLog(r *http.Request) logrus.FieldLogger {
if l, ok := r.Context().Value(contextKeyLog).(logrus.FieldLogger); ok {
return l
}
return logrus.StandardLogger()
}
103 changes: 92 additions & 11 deletions oauth2/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package oauth2

import (
"crypto/subtle"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -73,6 +75,8 @@ func (p *provider) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

func (p *provider) handleStart(w http.ResponseWriter, r *http.Request) {
defer getLog(r).WithField("status", http.StatusFound).Info()

tr := ssokenizer.GetTransaction(r)
http.Redirect(w, r, p.config(r).AuthCodeURL(tr.Nonce, oauth2.AccessTypeOffline), http.StatusFound)
}
Expand All @@ -82,39 +86,44 @@ func (p *provider) handleCallback(w http.ResponseWriter, r *http.Request) {
params := r.URL.Query()

if errParam := params.Get("error"); errParam != "" {
r = withError(r, fmt.Errorf("error param: %s", errParam))
tr.ReturnError(w, r, errParam)
return
}

state := params.Get("state")
if state == "" {
logrus.Warn("missing state")
r = withError(r, errors.New("missing state"))
tr.ReturnError(w, r, "bad response")
return
}

if subtle.ConstantTimeCompare([]byte(tr.Nonce), []byte(state)) != 1 {
logrus.WithFields(logrus.Fields{"have": state, "want": tr.Nonce}).Warn("bad state")
r = withError(r, errors.New("bad state"))
r = withFields(r, logrus.Fields{"have": state, "want": tr.Nonce})
tr.ReturnError(w, r, "bad response")
return
}

code := params.Get("code")
if code == "" {
logrus.Warn("missing code")
r = withError(r, errors.New("missing code"))
tr.ReturnError(w, r, "bad response")
return
}

tok, err := p.config(r).Exchange(r.Context(), code, oauth2.AccessTypeOffline)
if err != nil {
logrus.WithError(err).Warn("failed exchange")
r = withError(r, fmt.Errorf("failed exchange: %w", err))
tr.ReturnError(w, r, "bad response")
return
}

r = withIdToken(r, tok)

if t := tok.Type(); t != "Bearer" {
logrus.WithField("type", t).Warn("unrecognized token type")
r = withField(r, "type", t)
r = withError(r, errors.New("unrecognized token type"))
tr.ReturnError(w, r, "bad response")
return
}
Expand All @@ -131,7 +140,7 @@ func (p *provider) handleCallback(w http.ResponseWriter, r *http.Request) {

sealed, err := secret.Seal(p.sealKey)
if err != nil {
logrus.WithError(err).Warn("failed seal")
r = withError(r, fmt.Errorf("failed seal: %w", err))
tr.ReturnError(w, r, "seal error")
return
}
Expand All @@ -145,19 +154,33 @@ func (p *provider) handleCallback(w http.ResponseWriter, r *http.Request) {
func (p *provider) handleRefresh(w http.ResponseWriter, r *http.Request) {
refreshToken, ok := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer ")
if !ok {
logrus.Warn("refresh: missing token")
getLog(r).
WithField("status", http.StatusUnauthorized).
Info("refresh: missing token")

w.WriteHeader(http.StatusUnauthorized)
return
}

tok, err := p.config(r).TokenSource(r.Context(), &oauth2.Token{RefreshToken: refreshToken}).Token()
if err != nil {
logrus.WithError(err).Warn("refresh")
getLog(r).
WithField("status", http.StatusBadGateway).
WithError(err).
Info("refresh")

w.WriteHeader(http.StatusBadGateway)
return
}

r = withIdToken(r, tok)

if t := tok.Type(); t != "Bearer" {
logrus.WithField("type", t).Warn("unrecognized token type")
getLog(r).
WithField("status", http.StatusInternalServerError).
WithField("type", t).
Info("unrecognized token type")

w.WriteHeader(http.StatusInternalServerError)
return
}
Expand All @@ -175,7 +198,11 @@ func (p *provider) handleRefresh(w http.ResponseWriter, r *http.Request) {

sealed, err := secret.Seal(p.sealKey)
if err != nil {
logrus.WithError(err).Warn("refresh: failed seal")
getLog(r).
WithField("status", http.StatusInternalServerError).
WithError(err).
Info("refresh: failed seal")

w.WriteHeader(http.StatusInternalServerError)
return
}
Expand All @@ -185,9 +212,16 @@ func (p *provider) handleRefresh(w http.ResponseWriter, r *http.Request) {

if _, err := w.Write([]byte(sealed)); err != nil {
// status already written
logrus.WithError(err).Warn("refresh: write response")
getLog(r).
WithError(err).
Info("refresh: write response")

return
}

getLog(r).
WithField("status", http.StatusOK).
Info()
}

func (p *provider) config(r *http.Request) *Config {
Expand All @@ -210,3 +244,50 @@ func (p *provider) requestValidators(r *http.Request) []tokenizer.RequestValidat
re := regexp.MustCompile(fmt.Sprintf("^(%s|%s)$", regexp.QuoteMeta(r.Host), p.AllowedHostPattern))
return []tokenizer.RequestValidator{tokenizer.AllowHostPattern(re)}
}

// logging helpers. aliased for convenience
var (
getLog = ssokenizer.GetLog
withError = ssokenizer.WithError
withField = ssokenizer.WithField
withFields = ssokenizer.WithFields
)

// logging helper. Tries to find and parse user info from id token.
func withIdToken(r *http.Request, tok *oauth2.Token) *http.Request {
idToken, ok := tok.Extra("id_token").(string)
if !ok {
return r
}

parts := strings.Split(idToken, ".")
if len(parts) < 2 {
return r
}

jbody, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return r
}

var body struct {
Sub string `json:"sub"`
HD string `json:"hd"`
Email string `json:"email"`
}
if err := json.Unmarshal(jbody, &body); err != nil {
return r
}

if body.Sub != "" {
r = withField(r, "sub", body.Sub)
}
if body.HD != "" {
r = withField(r, "hd", body.HD)
}
if body.Email != "" {
r = withField(r, "email", body.Email)
}

return r
}
2 changes: 1 addition & 1 deletion provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
type provider struct {
name string
handler http.Handler
returnURL *url.URL
returnURL url.URL
}

// Arbitrary configuration type for providers to implement.
Expand Down
27 changes: 15 additions & 12 deletions ssokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ssokenizer

import (
"context"
"errors"
"fmt"
"net"
"net/http"
Expand Down Expand Up @@ -30,10 +31,8 @@ type Server struct {
// provided to tokenizer by the relying party in order to use the sealed token.
func NewServer(sealKey string) *Server {
s := &Server{
sealKey: sealKey,
providers: map[string](*provider){
"health": &provider{handler: handleHealth},
},
sealKey: sealKey,
providers: make(map[string](*provider)),
}

s.http = &http.Server{Handler: s}
Expand All @@ -42,11 +41,17 @@ func NewServer(sealKey string) *Server {
}

func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
logrus.WithFields(logrus.Fields{"method": r.Method, "uri": r.URL.Path, "host": r.Host}).Info()

providerName, rest, _ := strings.Cut(strings.TrimPrefix(r.URL.Path, "/"), "/")
if providerName == "health" {
fmt.Fprintln(w, "ok")
return
}

r = WithFields(r, logrus.Fields{"method": r.Method, "uri": r.URL.Path, "host": r.Host})

provider, ok := s.providers[providerName]
if !ok {
GetLog(r).WithField("status", http.StatusNotFound).Info()
w.WriteHeader(http.StatusNotFound)
return
}
Expand All @@ -60,21 +65,21 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {

if tc, err := r.Cookie(transactionCookieName); err != http.ErrNoCookie && tc.Value != "" {
if err := unmarshalTransaction(t, tc.Value); err != nil {
logrus.WithError(err).Warn("bad transaction cookie")
r = WithError(r, fmt.Errorf("bad transaction cookie: %w", err))
t.ReturnError(w, r, "bad request")
return
}

if time.Now().After(t.Expiry) {
logrus.Warn("expired transaction")
r = WithError(r, errors.New("expired transaction"))
t.ReturnError(w, r, "expired")
return
}
}

ts, err := t.marshal()
if err != nil {
logrus.WithError(err).Warn("marshal transaction cookie")
r = WithError(r, fmt.Errorf("marshal transaction cookie: %w", err))
t.ReturnError(w, r, "unexpected error")
return
}
Expand Down Expand Up @@ -107,7 +112,7 @@ func (s *Server) AddProvider(name string, pc ProviderConfig, returnURL string, a
s.providers[name] = &provider{
name: name,
handler: p,
returnURL: ru,
returnURL: *ru,
}

return nil
Expand Down Expand Up @@ -138,5 +143,3 @@ func (s *Server) Start(address string) error {
func (s *Server) Shutdown(ctx context.Context) error {
return s.http.Shutdown(ctx)
}

var handleHealth http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("ok")) }
2 changes: 2 additions & 0 deletions transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ func (t *Transaction) ReturnError(w http.ResponseWriter, r *http.Request, msg st
}

func (t *Transaction) returnData(w http.ResponseWriter, r *http.Request, data map[string]string) {
defer GetLog(r).WithField("status", http.StatusFound).Info()

t.setCookie(w, r, "")

returnURL := getProvider(r).returnURL
Expand Down

0 comments on commit 5436dd8

Please sign in to comment.