diff --git a/assets/docs/configuration/targets/http-full-example.hcl b/assets/docs/configuration/targets/http-full-example.hcl index 99ff1a05..46ed9a54 100644 --- a/assets/docs/configuration/targets/http-full-example.hcl +++ b/assets/docs/configuration/targets/http-full-example.hcl @@ -41,5 +41,17 @@ target { # Whether to enable setting headers dynamically dynamic_headers = true + + # Optional. One of client credentials required when authorizing using OAuth2. + oauth2_client_id = env.CLIENT_ID + + # Optional. One of client credentials required when authorizing using OAuth2. + oauth2_client_secret = env.CLIENT_SECRET + + # Optional. Required when using OAuth2. Long-lived token used to generate new short-lived access token when previous one experies. + oauth2_refresh_token = env.REFRESH_TOKEN + + # Optional. Required when using OAuth2. URL to authorization server providing access token. E.g. for Goggle API "https://oauth2.googleapis.com/token" + oauth2_token_url = "https://my.auth.server/token" } } diff --git a/pkg/target/http.go b/pkg/target/http.go index 75f7dc76..276d13a0 100644 --- a/pkg/target/http.go +++ b/pkg/target/http.go @@ -13,6 +13,7 @@ package target import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -25,6 +26,8 @@ import ( "github.com/snowplow/snowbridge/pkg/common" "github.com/snowplow/snowbridge/pkg/models" + + "golang.org/x/oauth2" ) // HTTPTargetConfig configures the destination for records consumed @@ -41,6 +44,11 @@ type HTTPTargetConfig struct { CaFile string `hcl:"ca_file,optional" env:"TARGET_HTTP_TLS_CA_FILE"` SkipVerifyTLS bool `hcl:"skip_verify_tls,optional" env:"TARGET_HTTP_TLS_SKIP_VERIFY_TLS"` // false DynamicHeaders bool `hcl:"dynamic_headers,optional" env:"TARGET_HTTP_DYNAMIC_HEADERS"` + + OAuth2ClientId string `hcl:"oauth2_client_id,optional" env:"TARGET_HTTP_OAUTH2_CLIENT_ID"` + OAuth2ClientSecret string `hcl:"oauth2_client_secret,optional" env:"TARGET_HTTP_OAUTH2_CLIENT_SECRET"` + OAuth2RefreshToken string `hcl:"oauth2_refresh_token,optional" env:"TARGET_HTTP_OAUTH2_REFRESH_TOKEN"` + OAuth2TokenUrl string `hcl:"oauth2_token_url,optional" env:"TARGET_HTTP_OAUTH2_TOKEN_URL"` } // HTTPTarget holds a new client for writing messages to HTTP endpoints @@ -94,7 +102,7 @@ func addHeadersToRequest(request *http.Request, headers map[string]string, dynam // newHTTPTarget creates a client for writing events to HTTP func newHTTPTarget(httpURL string, requestTimeout int, byteLimit int, contentType string, headers string, basicAuthUsername string, basicAuthPassword string, - certFile string, keyFile string, caFile string, skipVerifyTLS bool, dynamicHeaders bool) (*HTTPTarget, error) { + certFile string, keyFile string, caFile string, skipVerifyTLS bool, dynamicHeaders bool, oAuth2ClientId string, oAuth2ClientSecret string, oAuth2RefreshToken string, oAuth2TokenUrl string) (*HTTPTarget, error) { err := checkURL(httpURL) if err != nil { return nil, err @@ -113,11 +121,11 @@ func newHTTPTarget(httpURL string, requestTimeout int, byteLimit int, contentTyp transport.TLSClientConfig = tlsConfig } + client := createHttpClient(oAuth2ClientId, oAuth2ClientSecret, oAuth2TokenUrl, oAuth2RefreshToken, transport) + client.Timeout = time.Duration(requestTimeout) * time.Second + return &HTTPTarget{ - client: &http.Client{ - Transport: transport, - Timeout: time.Duration(requestTimeout) * time.Second, - }, + client: client, httpURL: httpURL, byteLimit: byteLimit, contentType: contentType, @@ -129,6 +137,25 @@ func newHTTPTarget(httpURL string, requestTimeout int, byteLimit int, contentTyp }, nil } +func createHttpClient(oAuth2ClientId string, oAuth2ClientSecret string, oAuth2TokenUrl string, oAuth2RefreshToken string, transport *http.Transport) *http.Client { + if oAuth2ClientId != "" { + oauth2Config := oauth2.Config{ + ClientID: oAuth2ClientId, + ClientSecret: oAuth2ClientSecret, + Endpoint: oauth2.Endpoint{ + TokenURL: oAuth2TokenUrl, + }, + } + + token := &oauth2.Token{RefreshToken: oAuth2RefreshToken} + return oauth2Config.Client(context.Background(), token) + } else { + return &http.Client{ + Transport: transport, + } + } +} + // HTTPTargetConfigFunction creates HTTPTarget from HTTPTargetConfig func HTTPTargetConfigFunction(c *HTTPTargetConfig) (*HTTPTarget, error) { return newHTTPTarget( @@ -144,6 +171,10 @@ func HTTPTargetConfigFunction(c *HTTPTargetConfig) (*HTTPTarget, error) { c.CaFile, c.SkipVerifyTLS, c.DynamicHeaders, + c.OAuth2ClientId, + c.OAuth2ClientSecret, + c.OAuth2RefreshToken, + c.OAuth2TokenUrl, ) } diff --git a/pkg/target/http_test.go b/pkg/target/http_test.go index 456c6024..c4a83b87 100644 --- a/pkg/target/http_test.go +++ b/pkg/target/http_test.go @@ -305,12 +305,12 @@ func TestAddHeadersToRequest_WithDynamicHeaders(t *testing.T) { func TestNewHTTPTarget(t *testing.T) { assert := assert.New(t) - httpTarget, err := newHTTPTarget("http://something", 5, 1048576, "application/json", "", "", "", "", "", "", true, false) + httpTarget, err := newHTTPTarget("http://something", 5, 1048576, "application/json", "", "", "", "", "", "", true, false, "", "", "", "") assert.Nil(err) assert.NotNil(httpTarget) - failedHTTPTarget, err1 := newHTTPTarget("something", 5, 1048576, "application/json", "", "", "", "", "", "", true, false) + failedHTTPTarget, err1 := newHTTPTarget("something", 5, 1048576, "application/json", "", "", "", "", "", "", true, false, "", "", "", "") assert.NotNil(err1) if err1 != nil { @@ -318,7 +318,7 @@ func TestNewHTTPTarget(t *testing.T) { } assert.Nil(failedHTTPTarget) - failedHTTPTarget2, err2 := newHTTPTarget("", 5, 1048576, "application/json", "", "", "", "", "", "", true, false) + failedHTTPTarget2, err2 := newHTTPTarget("", 5, 1048576, "application/json", "", "", "", "", "", "", true, false, "", "", "", "") assert.NotNil(err2) if err2 != nil { assert.Equal("Invalid url for HTTP target: ''", err2.Error()) @@ -345,7 +345,7 @@ func TestHttpWrite_Simple(t *testing.T) { server := createTestServerWithResponseCode(&results, &wg, tt.ResponseCode) defer server.Close() - target, err := newHTTPTarget(server.URL, 5, 1048576, "application/json", "", "", "", "", "", "", true, false) + target, err := newHTTPTarget(server.URL, 5, 1048576, "application/json", "", "", "", "", "", "", true, false, "", "", "", "") if err != nil { t.Fatal(err) } @@ -381,7 +381,7 @@ func TestHttpWrite_Concurrent(t *testing.T) { server := createTestServer(&results, &wg) defer server.Close() - target, err := newHTTPTarget(server.URL, 5, 1048576, "application/json", "", "", "", "", "", "", true, false) + target, err := newHTTPTarget(server.URL, 5, 1048576, "application/json", "", "", "", "", "", "", true, false, "", "", "", "") if err != nil { t.Fatal(err) } @@ -424,7 +424,7 @@ func TestHttpWrite_Failure(t *testing.T) { server := createTestServer(&results, &wg) defer server.Close() - target, err := newHTTPTarget("http://NonexistentEndpoint", 5, 1048576, "application/json", "", "", "", "", "", "", true, false) + target, err := newHTTPTarget("http://NonexistentEndpoint", 5, 1048576, "application/json", "", "", "", "", "", "", true, false, "", "", "", "") if err != nil { t.Fatal(err) } @@ -465,7 +465,7 @@ func TestHttpWrite_InvalidResponseCode(t *testing.T) { wg := sync.WaitGroup{} server := createTestServerWithResponseCode(&results, &wg, tt.ResponseCode) defer server.Close() - target, err := newHTTPTarget(server.URL, 5, 1048576, "application/json", "", "", "", "", "", "", true, false) + target, err := newHTTPTarget(server.URL, 5, 1048576, "application/json", "", "", "", "", "", "", true, false, "", "", "", "") if err != nil { t.Fatal(err) } @@ -499,7 +499,7 @@ func TestHttpWrite_Oversized(t *testing.T) { server := createTestServer(&results, &wg) defer server.Close() - target, err := newHTTPTarget(server.URL, 5, 1048576, "application/json", "", "", "", "", "", "", true, false) + target, err := newHTTPTarget(server.URL, 5, 1048576, "application/json", "", "", "", "", "", "", true, false, "", "", "", "") if err != nil { t.Fatal(err) } @@ -555,7 +555,11 @@ func TestHttpWrite_TLS(t *testing.T) { string(`../../integration/http/localhost.key`), string(`../../integration/http/rootCA.crt`), false, - false) + false, + "", + "", + "", + "") if err != nil { t.Fatal(err) } @@ -588,7 +592,11 @@ func TestHttpWrite_TLS(t *testing.T) { string(`../../integration/http/localhost.key`), string(`../../integration/http/rootCA.crt`), false, - false) + false, + "", + "", + "", + "") if err2 != nil { t.Fatal(err2) } @@ -614,7 +622,11 @@ func TestHttpWrite_TLS(t *testing.T) { "", "", false, - false) + false, + "", + "", + "", + "") if err4 != nil { t.Fatal(err4) } diff --git a/pkg/target/oauth2_test.go b/pkg/target/oauth2_test.go new file mode 100644 index 00000000..58bdad15 --- /dev/null +++ b/pkg/target/oauth2_test.go @@ -0,0 +1,128 @@ +/** + * Copyright (c) 2020-present Snowplow Analytics Ltd. + * All rights reserved. + * + * This software is made available by Snowplow Analytics, Ltd., + * under the terms of the Snowplow Limited Use License Agreement, Version 1.0 + * located at https://docs.snowplow.io/limited-use-license-1.0 + * BY INSTALLING, DOWNLOADING, ACCESSING, USING OR DISTRIBUTING ANY PORTION + * OF THE SOFTWARE, YOU AGREE TO THE TERMS OF SUCH LICENSE AGREEMENT. + */ + +package target + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/snowplow/snowbridge/pkg/models" + "github.com/snowplow/snowbridge/pkg/testutil" + "github.com/stretchr/testify/assert" +) + +// that's what we configure in our target +const validClientId = "CLIENT_ID_TEST" +const validClientSecret = "CLIENT_SECRET_TEST" +const validRefreshToken = "REFRESH_TOKEN_TEST" +const validGrantType = "refresh_token" + +// that's what is returned by mock token server and used as bearer token to authorize request to target server +const validAccessToken = "super_secret_access_token" + +// This is mock server providing us the bearer access token. If you provide invalid details/something is misconfigured you get 400 HTTP status +func tokenServer() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + req.ParseForm() + clientId, clientSecret, _ := req.BasicAuth() + refreshToken := req.Form.Get("refresh_token") + grantType := req.Form.Get("grant_type") + + if clientId == validClientId && clientSecret == validClientSecret && refreshToken == validRefreshToken && grantType == validGrantType { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + fmt.Fprintf(w, `{"access_token":"%s", "expires_in":3600}`, validAccessToken) + } else { + w.WriteHeader(400) + fmt.Fprintf(w, `{"error":"invalid_client"}`) + } + })) +} + +// This is mock target server which requires us to provide valid access token. Without valid token you set 403 HTTP status +func targetServer() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.Header.Get("Authorization") == fmt.Sprintf("Bearer %s", validAccessToken) { + w.WriteHeader(200) + } else { + w.WriteHeader(403) + fmt.Fprintf(w, "Invalid access token") + } + })) +} + +func TestHttpOAuth2_Success(t *testing.T) { + assert := assert.New(t) + + writeResult, err := runTest(t, validClientId, validClientSecret, validRefreshToken) + + assert.Nil(err) + assert.Equal(1, len(writeResult.Sent)) + assert.Equal(0, len(writeResult.Failed)) +} + +func TestHttpOAuth2_CanNotFetchToken(t *testing.T) { + testCases := []struct { + Name string + InputClientId string + InputClientSecret string + InputRefreshToken string + }{ + {Name: "Invalid client id", InputClientId: "INVALID", InputClientSecret: validClientSecret, InputRefreshToken: validRefreshToken}, + {Name: "Invalid client secret", InputClientId: validClientId, InputClientSecret: "INVALID", InputRefreshToken: validRefreshToken}, + {Name: "Invalid refresh token", InputClientId: validClientId, InputClientSecret: validClientSecret, InputRefreshToken: "INVALID"}, + } + + for _, tt := range testCases { + t.Run(tt.Name, func(t *testing.T) { + assert := assert.New(t) + writeResult, err := runTest(t, tt.InputClientId, tt.InputClientSecret, tt.InputRefreshToken) + + assert.NotNil(err) + assert.Contains(err.Error(), `{"error":"invalid_client"}`) + assert.Equal(0, len(writeResult.Sent)) + assert.Equal(1, len(writeResult.Failed)) + }) + } +} + +func TestHttpOAuth2_CallTargetWithoutToken(t *testing.T) { + assert := assert.New(t) + writeResult, err := runTest(t, "", "", "") + + assert.NotNil(err) + assert.Contains(err.Error(), `Got response status: 403 Forbidden`) + assert.Equal(0, len(writeResult.Sent)) + assert.Equal(1, len(writeResult.Failed)) +} + +func runTest(t *testing.T, inputClientId string, inputClientSecret string, inputRefreshToken string) (*models.TargetWriteResult, error) { + tokenServer := tokenServer() + server := targetServer() + defer tokenServer.Close() + defer server.Close() + + target := oauth2Target(t, server.URL, inputClientId, inputClientSecret, inputRefreshToken, tokenServer.URL) + + message := testutil.GetTestMessages(1, "Hello Server!!", func() {}) + return target.Write(message) +} + +func oauth2Target(t *testing.T, targetUrl string, inputClientId string, inputClientSecret string, inputRefreshToken string, tokenServerUrl string) *HTTPTarget { + target, err := newHTTPTarget(targetUrl, 5, 1048576, "application/json", "", "", "", "", "", "", true, false, inputClientId, inputClientSecret, inputRefreshToken, tokenServerUrl) + if err != nil { + t.Fatal(err) + } + return target +}