diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index b79baa9e..4668d652 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -11,8 +11,8 @@ jobs: name: Run golangci-lint runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: golangci-lint - uses: golangci/golangci-lint-action@v2 + uses: golangci/golangci-lint-action@v3 with: version: v1.46.2 diff --git a/samlsp/basic_assertion_handler.go b/samlsp/basic_assertion_handler.go new file mode 100644 index 00000000..095f1063 --- /dev/null +++ b/samlsp/basic_assertion_handler.go @@ -0,0 +1,17 @@ +package samlsp + +import ( + "github.com/crewjam/saml" +) + +var _ SamlAssertionHandler = BasicSamlAssertionHandler{} + +// BasicSamlAssertionHandler is an implementation of SamlAssertionHandler that has +// an empty HandleAssertion function to retain useability. +type BasicSamlAssertionHandler struct{} + +// HandleAssertion is called and passed saml assertion +// this can add extra functionality and should return any error that occurs. +func (as BasicSamlAssertionHandler) HandleAssertion(assertion *saml.Assertion) error { + return nil +} diff --git a/samlsp/middleware.go b/samlsp/middleware.go index 834a79c1..759043f8 100644 --- a/samlsp/middleware.go +++ b/samlsp/middleware.go @@ -39,12 +39,13 @@ import ( // SAML service provider already has a private key, we borrow that key // to sign the JWTs as well. type Middleware struct { - ServiceProvider saml.ServiceProvider - OnError func(w http.ResponseWriter, r *http.Request, err error) - Binding string // either saml.HTTPPostBinding or saml.HTTPRedirectBinding - ResponseBinding string // either saml.HTTPPostBinding or saml.HTTPArtifactBinding - RequestTracker RequestTracker - Session SessionProvider + ServiceProvider saml.ServiceProvider + OnError func(w http.ResponseWriter, r *http.Request, err error) + Binding string // either saml.HTTPPostBinding or saml.HTTPRedirectBinding + ResponseBinding string // either saml.HTTPPostBinding or saml.HTTPArtifactBinding + RequestTracker RequestTracker + Session SessionProvider + AssertionHandler SamlAssertionHandler } // ServeHTTP implements http.Handler and serves the SAML-specific HTTP endpoints @@ -92,6 +93,12 @@ func (m *Middleware) ServeACS(w http.ResponseWriter, r *http.Request) { return } + assertionErr := m.AssertionHandler.HandleAssertion(assertion) + if assertionErr != nil { + m.OnError(w, r, assertionErr) + return + } + m.CreateSessionFromAssertion(w, r, assertion, m.ServiceProvider.DefaultRedirectURI) return } diff --git a/samlsp/new.go b/samlsp/new.go index 81fa75f6..e9ed8554 100644 --- a/samlsp/new.go +++ b/samlsp/new.go @@ -131,6 +131,12 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider { } } +// DefaultSamlAssertionHandler returns the default SamlAssertionHandler for the provided options, +// a BasicSamlAssertionHandler configured to do nothing. +func DefaultSamlAssertionHandler(opts Options) BasicSamlAssertionHandler { + return BasicSamlAssertionHandler{} +} + // New creates a new Middleware with the default providers for the // given options. // @@ -139,11 +145,12 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider { // in the returned Middleware. func New(opts Options) (*Middleware, error) { m := &Middleware{ - ServiceProvider: DefaultServiceProvider(opts), - Binding: "", - ResponseBinding: saml.HTTPPostBinding, - OnError: DefaultOnError, - Session: DefaultSessionProvider(opts), + ServiceProvider: DefaultServiceProvider(opts), + Binding: "", + ResponseBinding: saml.HTTPPostBinding, + OnError: DefaultOnError, + Session: DefaultSessionProvider(opts), + AssertionHandler: DefaultSamlAssertionHandler(opts), } m.RequestTracker = DefaultRequestTracker(opts, &m.ServiceProvider) if opts.UseArtifactResponse { diff --git a/samlsp/saml_assertion_handler.go b/samlsp/saml_assertion_handler.go new file mode 100644 index 00000000..1230aaf7 --- /dev/null +++ b/samlsp/saml_assertion_handler.go @@ -0,0 +1,9 @@ +package samlsp + +import "github.com/crewjam/saml" + +// SamlAssertionHandler is an interface implemented by types that can handle +// assertions and add extra functionality +type SamlAssertionHandler interface { + HandleAssertion(assertion *saml.Assertion) error +}