Skip to content

Commit

Permalink
Add support for OAuth2 in HTTP target
Browse files Browse the repository at this point in the history
Inspired by official OAuth2 [RFC](https://www.rfc-editor.org/rfc/rfc6749) and example API provided by [Google Ads](https://developers.google.com/google-ads/api/rest/auth)
  • Loading branch information
pondzix committed May 23, 2024
1 parent 9f307e5 commit 2bfa424
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 16 deletions.
12 changes: 12 additions & 0 deletions assets/docs/configuration/targets/http-full-example.hcl
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,17 @@ target {

# Whether to enable setting headers dynamically
dynamic_headers = true

# Optional. One of client credentials required when authorizing using OAuth2.
oauth2_client_id = env.CLIENT_ID

# Optional. One of client credentials required when authorizing using OAuth2.
oauth2_client_secret = env.CLIENT_SECRET

# Optional. Required when using OAuth2. Long-lived token used to generate new short-lived access token when previous one experies.
oauth2_refresh_token = env.REFRESH_TOKEN

# Optional. Required when using OAuth2. URL to authorization server providing access token. E.g. for Goggle API "https://oauth2.googleapis.com/token"
oauth2_token_url = "https://my.auth.server/token"
}
}
41 changes: 36 additions & 5 deletions pkg/target/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ package target

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
Expand All @@ -25,6 +26,8 @@ import (

"github.com/snowplow/snowbridge/pkg/common"
"github.com/snowplow/snowbridge/pkg/models"

"golang.org/x/oauth2"
)

// HTTPTargetConfig configures the destination for records consumed
Expand All @@ -41,6 +44,11 @@ type HTTPTargetConfig struct {
CaFile string `hcl:"ca_file,optional" env:"TARGET_HTTP_TLS_CA_FILE"`
SkipVerifyTLS bool `hcl:"skip_verify_tls,optional" env:"TARGET_HTTP_TLS_SKIP_VERIFY_TLS"` // false
DynamicHeaders bool `hcl:"dynamic_headers,optional" env:"TARGET_HTTP_DYNAMIC_HEADERS"`

OAuth2ClientId string `hcl:"oauth2_client_id,optional" env:"TARGET_HTTP_OAUTH2_CLIENT_ID"`
OAuth2ClientSecret string `hcl:"oauth2_client_secret,optional" env:"TARGET_HTTP_OAUTH2_CLIENT_SECRET"`
OAuth2RefreshToken string `hcl:"oauth2_refresh_token,optional" env:"TARGET_HTTP_OAUTH2_REFRESH_TOKEN"`
OAuth2TokenUrl string `hcl:"oauth2_token_url,optional" env:"TARGET_HTTP_OAUTH2_TOKEN_URL"`
}

// HTTPTarget holds a new client for writing messages to HTTP endpoints
Expand Down Expand Up @@ -94,7 +102,7 @@ func addHeadersToRequest(request *http.Request, headers map[string]string, dynam

// newHTTPTarget creates a client for writing events to HTTP
func newHTTPTarget(httpURL string, requestTimeout int, byteLimit int, contentType string, headers string, basicAuthUsername string, basicAuthPassword string,
certFile string, keyFile string, caFile string, skipVerifyTLS bool, dynamicHeaders bool) (*HTTPTarget, error) {
certFile string, keyFile string, caFile string, skipVerifyTLS bool, dynamicHeaders bool, oAuth2ClientId string, oAuth2ClientSecret string, oAuth2RefreshToken string, oAuth2TokenUrl string) (*HTTPTarget, error) {
err := checkURL(httpURL)
if err != nil {
return nil, err
Expand All @@ -113,11 +121,11 @@ func newHTTPTarget(httpURL string, requestTimeout int, byteLimit int, contentTyp
transport.TLSClientConfig = tlsConfig
}

client := createHttpClient(oAuth2ClientId, oAuth2ClientSecret, oAuth2TokenUrl, oAuth2RefreshToken, transport)
client.Timeout = time.Duration(requestTimeout) * time.Second

return &HTTPTarget{
client: &http.Client{
Transport: transport,
Timeout: time.Duration(requestTimeout) * time.Second,
},
client: client,
httpURL: httpURL,
byteLimit: byteLimit,
contentType: contentType,
Expand All @@ -129,6 +137,25 @@ func newHTTPTarget(httpURL string, requestTimeout int, byteLimit int, contentTyp
}, nil
}

func createHttpClient(oAuth2ClientId string, oAuth2ClientSecret string, oAuth2TokenUrl string, oAuth2RefreshToken string, transport *http.Transport) *http.Client {
if oAuth2ClientId != "" {
oauth2Config := oauth2.Config{
ClientID: oAuth2ClientId,
ClientSecret: oAuth2ClientSecret,
Endpoint: oauth2.Endpoint{
TokenURL: oAuth2TokenUrl,
},
}

token := &oauth2.Token{RefreshToken: oAuth2RefreshToken}
return oauth2Config.Client(context.Background(), token)
} else {
return &http.Client{
Transport: transport,
}
}
}

// HTTPTargetConfigFunction creates HTTPTarget from HTTPTargetConfig
func HTTPTargetConfigFunction(c *HTTPTargetConfig) (*HTTPTarget, error) {
return newHTTPTarget(
Expand All @@ -144,6 +171,10 @@ func HTTPTargetConfigFunction(c *HTTPTargetConfig) (*HTTPTarget, error) {
c.CaFile,
c.SkipVerifyTLS,
c.DynamicHeaders,
c.OAuth2ClientId,
c.OAuth2ClientSecret,
c.OAuth2RefreshToken,
c.OAuth2TokenUrl,
)
}

Expand Down
34 changes: 23 additions & 11 deletions pkg/target/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,20 +305,20 @@ func TestAddHeadersToRequest_WithDynamicHeaders(t *testing.T) {
func TestNewHTTPTarget(t *testing.T) {
assert := assert.New(t)

httpTarget, err := newHTTPTarget("http://something", 5, 1048576, "application/json", "", "", "", "", "", "", true, false)
httpTarget, err := newHTTPTarget("http://something", 5, 1048576, "application/json", "", "", "", "", "", "", true, false, "", "", "", "")

assert.Nil(err)
assert.NotNil(httpTarget)

failedHTTPTarget, err1 := newHTTPTarget("something", 5, 1048576, "application/json", "", "", "", "", "", "", true, false)
failedHTTPTarget, err1 := newHTTPTarget("something", 5, 1048576, "application/json", "", "", "", "", "", "", true, false, "", "", "", "")

assert.NotNil(err1)
if err1 != nil {
assert.Equal("Invalid url for HTTP target: 'something'", err1.Error())
}
assert.Nil(failedHTTPTarget)

failedHTTPTarget2, err2 := newHTTPTarget("", 5, 1048576, "application/json", "", "", "", "", "", "", true, false)
failedHTTPTarget2, err2 := newHTTPTarget("", 5, 1048576, "application/json", "", "", "", "", "", "", true, false, "", "", "", "")
assert.NotNil(err2)
if err2 != nil {
assert.Equal("Invalid url for HTTP target: ''", err2.Error())
Expand All @@ -345,7 +345,7 @@ func TestHttpWrite_Simple(t *testing.T) {
server := createTestServerWithResponseCode(&results, &wg, tt.ResponseCode)
defer server.Close()

target, err := newHTTPTarget(server.URL, 5, 1048576, "application/json", "", "", "", "", "", "", true, false)
target, err := newHTTPTarget(server.URL, 5, 1048576, "application/json", "", "", "", "", "", "", true, false, "", "", "", "")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -381,7 +381,7 @@ func TestHttpWrite_Concurrent(t *testing.T) {
server := createTestServer(&results, &wg)
defer server.Close()

target, err := newHTTPTarget(server.URL, 5, 1048576, "application/json", "", "", "", "", "", "", true, false)
target, err := newHTTPTarget(server.URL, 5, 1048576, "application/json", "", "", "", "", "", "", true, false, "", "", "", "")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -424,7 +424,7 @@ func TestHttpWrite_Failure(t *testing.T) {
server := createTestServer(&results, &wg)
defer server.Close()

target, err := newHTTPTarget("http://NonexistentEndpoint", 5, 1048576, "application/json", "", "", "", "", "", "", true, false)
target, err := newHTTPTarget("http://NonexistentEndpoint", 5, 1048576, "application/json", "", "", "", "", "", "", true, false, "", "", "", "")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -465,7 +465,7 @@ func TestHttpWrite_InvalidResponseCode(t *testing.T) {
wg := sync.WaitGroup{}
server := createTestServerWithResponseCode(&results, &wg, tt.ResponseCode)
defer server.Close()
target, err := newHTTPTarget(server.URL, 5, 1048576, "application/json", "", "", "", "", "", "", true, false)
target, err := newHTTPTarget(server.URL, 5, 1048576, "application/json", "", "", "", "", "", "", true, false, "", "", "", "")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -499,7 +499,7 @@ func TestHttpWrite_Oversized(t *testing.T) {
server := createTestServer(&results, &wg)
defer server.Close()

target, err := newHTTPTarget(server.URL, 5, 1048576, "application/json", "", "", "", "", "", "", true, false)
target, err := newHTTPTarget(server.URL, 5, 1048576, "application/json", "", "", "", "", "", "", true, false, "", "", "", "")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -555,7 +555,11 @@ func TestHttpWrite_TLS(t *testing.T) {
string(`../../integration/http/localhost.key`),
string(`../../integration/http/rootCA.crt`),
false,
false)
false,
"",
"",
"",
"")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -588,7 +592,11 @@ func TestHttpWrite_TLS(t *testing.T) {
string(`../../integration/http/localhost.key`),
string(`../../integration/http/rootCA.crt`),
false,
false)
false,
"",
"",
"",
"")
if err2 != nil {
t.Fatal(err2)
}
Expand All @@ -614,7 +622,11 @@ func TestHttpWrite_TLS(t *testing.T) {
"",
"",
false,
false)
false,
"",
"",
"",
"")
if err4 != nil {
t.Fatal(err4)
}
Expand Down
128 changes: 128 additions & 0 deletions pkg/target/oauth2_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/**
* Copyright (c) 2020-present Snowplow Analytics Ltd.
* All rights reserved.
*
* This software is made available by Snowplow Analytics, Ltd.,
* under the terms of the Snowplow Limited Use License Agreement, Version 1.0
* located at https://docs.snowplow.io/limited-use-license-1.0
* BY INSTALLING, DOWNLOADING, ACCESSING, USING OR DISTRIBUTING ANY PORTION
* OF THE SOFTWARE, YOU AGREE TO THE TERMS OF SUCH LICENSE AGREEMENT.
*/

package target

import (
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/snowplow/snowbridge/pkg/models"
"github.com/snowplow/snowbridge/pkg/testutil"
"github.com/stretchr/testify/assert"
)

// that's what we configure in our target
const validClientId = "CLIENT_ID_TEST"
const validClientSecret = "CLIENT_SECRET_TEST"
const validRefreshToken = "REFRESH_TOKEN_TEST"
const validGrantType = "refresh_token"

// that's what is returned by mock token server and used as bearer token to authorize request to target server
const validAccessToken = "super_secret_access_token"

// This is mock server providing us the bearer access token. If you provide invalid details/something is misconfigured you get 400 HTTP status
func tokenServer() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.ParseForm()
clientId, clientSecret, _ := req.BasicAuth()
refreshToken := req.Form.Get("refresh_token")
grantType := req.Form.Get("grant_type")

if clientId == validClientId && clientSecret == validClientSecret && refreshToken == validRefreshToken && grantType == validGrantType {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
fmt.Fprintf(w, `{"access_token":"%s", "expires_in":3600}`, validAccessToken)
} else {
w.WriteHeader(400)
fmt.Fprintf(w, `{"error":"invalid_client"}`)
}
}))
}

// This is mock target server which requires us to provide valid access token. Without valid token you set 403 HTTP status
func targetServer() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.Header.Get("Authorization") == fmt.Sprintf("Bearer %s", validAccessToken) {
w.WriteHeader(200)
} else {
w.WriteHeader(403)
fmt.Fprintf(w, "Invalid access token")
}
}))
}

func TestHttpOAuth2_Success(t *testing.T) {
assert := assert.New(t)

writeResult, err := runTest(t, validClientId, validClientSecret, validRefreshToken)

assert.Nil(err)
assert.Equal(1, len(writeResult.Sent))
assert.Equal(0, len(writeResult.Failed))
}

func TestHttpOAuth2_CanNotFetchToken(t *testing.T) {
testCases := []struct {
Name string
InputClientId string
InputClientSecret string
InputRefreshToken string
}{
{Name: "Invalid client id", InputClientId: "INVALID", InputClientSecret: validClientSecret, InputRefreshToken: validRefreshToken},
{Name: "Invalid client secret", InputClientId: validClientId, InputClientSecret: "INVALID", InputRefreshToken: validRefreshToken},
{Name: "Invalid refresh token", InputClientId: validClientId, InputClientSecret: validClientSecret, InputRefreshToken: "INVALID"},
}

for _, tt := range testCases {
t.Run(tt.Name, func(t *testing.T) {
assert := assert.New(t)
writeResult, err := runTest(t, tt.InputClientId, tt.InputClientSecret, tt.InputRefreshToken)

assert.NotNil(err)
assert.Contains(err.Error(), `{"error":"invalid_client"}`)
assert.Equal(0, len(writeResult.Sent))
assert.Equal(1, len(writeResult.Failed))
})
}
}

func TestHttpOAuth2_CallTargetWithoutToken(t *testing.T) {
assert := assert.New(t)
writeResult, err := runTest(t, "", "", "")

assert.NotNil(err)
assert.Contains(err.Error(), `Got response status: 403 Forbidden`)
assert.Equal(0, len(writeResult.Sent))
assert.Equal(1, len(writeResult.Failed))
}

func runTest(t *testing.T, inputClientId string, inputClientSecret string, inputRefreshToken string) (*models.TargetWriteResult, error) {
tokenServer := tokenServer()
server := targetServer()
defer tokenServer.Close()
defer server.Close()

target := oauth2Target(t, server.URL, inputClientId, inputClientSecret, inputRefreshToken, tokenServer.URL)

message := testutil.GetTestMessages(1, "Hello Server!!", func() {})
return target.Write(message)
}

func oauth2Target(t *testing.T, targetUrl string, inputClientId string, inputClientSecret string, inputRefreshToken string, tokenServerUrl string) *HTTPTarget {
target, err := newHTTPTarget(targetUrl, 5, 1048576, "application/json", "", "", "", "", "", "", true, false, inputClientId, inputClientSecret, inputRefreshToken, tokenServerUrl)
if err != nil {
t.Fatal(err)
}
return target
}

0 comments on commit 2bfa424

Please sign in to comment.