Skip to content

Commit

Permalink
[DSG-8949] support Schema Registry bearer auth with static token (#1103)
Browse files Browse the repository at this point in the history
  • Loading branch information
yangpei1214 authored Nov 13, 2023
1 parent ba5aa40 commit 8d530db
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func main() {
panic(fmt.Sprintf("Failed to create producer: %s", err))
}

client, err := schemaregistry.NewClient(schemaregistry.NewConfigWithAuthentication(
client, err := schemaregistry.NewClient(schemaregistry.NewConfigWithBasicAuthentication(
schemaRegistryAPIEndpoint,
schemaRegistryAPIKey,
schemaRegistryAPISecret))
Expand Down
39 changes: 36 additions & 3 deletions schemaregistry/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ type Config struct {
// SaslUsername specifies the password for SASL.
SaslPassword string

// BearerAuthToken specifies the token for authentication.
BearerAuthToken string
// BearerAuthCredentialsSource specifies how to determine the credentials.
BearerAuthCredentialsSource string
// BearerAuthLogicalCluster specifies the target SR logical cluster id. It is required for Confluent Cloud Schema Registry
BearerAuthLogicalCluster string
// BearerAuthIdentityPoolID specifies the identity pool ID. It is required for Confluent Cloud Schema Registry
BearerAuthIdentityPoolID string

// SslCertificateLocation specifies the location of SSL certificates.
SslCertificateLocation string
// SslKeyLocation specifies the location of SSL keys.
Expand All @@ -64,9 +73,6 @@ func NewConfig(url string) *Config {

c.SchemaRegistryURL = url

c.BasicAuthUserInfo = ""
c.BasicAuthCredentialsSource = "URL"

c.SaslMechanism = "GSSAPI"
c.SaslUsername = ""
c.SaslPassword = ""
Expand All @@ -84,6 +90,7 @@ func NewConfig(url string) *Config {

// NewConfigWithAuthentication returns a new configuration instance using basic authentication.
// For Confluent Cloud, use the API key for the username and the API secret for the password.
// This method is deprecated.
func NewConfigWithAuthentication(url string, username string, password string) *Config {
c := NewConfig(url)

Expand All @@ -92,3 +99,29 @@ func NewConfigWithAuthentication(url string, username string, password string) *

return c
}

// NewConfigWithBasicAuthentication returns a new configuration instance using basic authentication.
// For Confluent Cloud, use the API key for the username and the API secret for the password.
func NewConfigWithBasicAuthentication(url string, username string, password string) *Config {
c := NewConfig(url)

c.BasicAuthUserInfo = fmt.Sprintf("%s:%s", username, password)
c.BasicAuthCredentialsSource = "USER_INFO"

return c
}

// NewConfigWithBearerAuthentication returns a new configuration instance using bearer authentication.
// For Confluent Cloud, targetSr(`bearer.auth.logical.cluster` and
// identityPoolID(`bearer.auth.identity.pool.id`) is required
func NewConfigWithBearerAuthentication(url, token, targetSr, identityPoolID string) *Config {

c := NewConfig(url)

c.BearerAuthToken = token
c.BearerAuthCredentialsSource = "STATIC_TOKEN"
c.BearerAuthLogicalCluster = targetSr
c.BearerAuthIdentityPoolID = identityPoolID

return c
}
10 changes: 9 additions & 1 deletion schemaregistry/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,16 @@ import (
func TestConfigWithAuthentication(t *testing.T) {
maybeFail = initFailFunc(t)

c := NewConfigWithAuthentication("mock://", "username", "password")
c := NewConfigWithBasicAuthentication("mock://", "username", "password")

maybeFail("BasicAuthCredentialsSource", expect(c.BasicAuthCredentialsSource, "USER_INFO"))
maybeFail("BasicAuthUserInfo", expect(c.BasicAuthUserInfo, "username:password"))
}

func TestConfigWithBearerAuth(t *testing.T) {
maybeFail = initFailFunc(t)
c := NewConfigWithBearerAuthentication("mock://", "token", "lsrc-123", "poolID")
maybeFail("BearerAuthCredentialsSource", expect(c.BearerAuthCredentialsSource, "STATIC_TOKEN"))
maybeFail("BearerAuthLogicalCluster", expect(c.BearerAuthLogicalCluster, "lsrc-123"))
maybeFail("BearerAuthIdentityPoolID", expect(c.BearerAuthIdentityPoolID, "poolID"))
}
68 changes: 57 additions & 11 deletions schemaregistry/rest_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ const (
subjectConfig = config + "/%s"
mode = "/mode"
modeConfig = mode + "/%s"

targetSRClusterKey = "Target-Sr-Cluster"
targetIdentityPoolIDKey = "Confluent-Identity-Pool-Id"
)

// REST API request
Expand Down Expand Up @@ -251,28 +254,71 @@ func configureUSERINFOAuth(conf *Config, header http.Header) error {

}

func configureStaticTokenAuth(conf *Config, header http.Header) error {
bearerToken := conf.BearerAuthToken
if len(bearerToken) == 0 {
return fmt.Errorf("config bearer.auth.token must be specified when bearer.auth.credentials.source is" +
" specified with STATIC_TOKEN")
}
header.Add("Authorization", fmt.Sprintf("Bearer %s", bearerToken))
setBearerAuthExtraHeaders(conf, header)
return nil
}

func setBearerAuthExtraHeaders(conf *Config, header http.Header) {
targetIdentityPoolID := conf.BearerAuthIdentityPoolID
if len(targetIdentityPoolID) > 0 {
header.Add(targetIdentityPoolIDKey, targetIdentityPoolID)
}

targetSr := conf.BearerAuthLogicalCluster
if len(targetSr) > 0 {
header.Add(targetSRClusterKey, targetSr)
}
}

// newAuthHeader returns a base64 encoded userinfo string identified on the configured credentials source
func newAuthHeader(service *url.URL, conf *Config) (http.Header, error) {
// Remove userinfo from url regardless of source to avoid confusion/conflicts
defer func() {
service.User = nil
}()

source := conf.BasicAuthCredentialsSource

header := http.Header{}

basicSource := conf.BasicAuthCredentialsSource
bearerSource := conf.BearerAuthCredentialsSource

var err error
switch strings.ToUpper(source) {
case "URL":
err = configureURLAuth(service, header)
case "SASL_INHERIT":
err = configureSASLAuth(conf, header)
case "USER_INFO":
err = configureUSERINFOAuth(conf, header)
default:
err = fmt.Errorf("unrecognized value for basic.auth.credentials.source %s", source)
if len(basicSource) != 0 && len(bearerSource) != 0 {
return header, fmt.Errorf("only one of basic.auth.credentials.source or bearer.auth.credentials.source" +
" may be specified")
} else if len(basicSource) != 0 {
switch strings.ToUpper(basicSource) {
case "URL":
err = configureURLAuth(service, header)
case "SASL_INHERIT":
err = configureSASLAuth(conf, header)
case "USER_INFO":
err = configureUSERINFOAuth(conf, header)
default:
err = fmt.Errorf("unrecognized value for basic.auth.credentials.source %s", basicSource)
}
} else if len(bearerSource) != 0 {
switch strings.ToUpper(bearerSource) {
case "STATIC_TOKEN":
err = configureStaticTokenAuth(conf, header)
//case "OAUTHBEARER":
// err = configureOauthBearerAuth(conf, header)
//case "SASL_OAUTHBEARER_INHERIT":
// err = configureSASLOauth()
//case "CUSTOM":
// err = configureCustomOauth(conf, header)
default:
err = fmt.Errorf("unrecognized value for bearer.auth.credentials.source %s", bearerSource)
}
}

return header, err
}

Expand Down
85 changes: 85 additions & 0 deletions schemaregistry/rest_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package schemaregistry

import (
"crypto/tls"
"net/url"
"strings"
"testing"
)
Expand Down Expand Up @@ -83,3 +84,87 @@ func TestConfigureTLS(t *testing.T) {
t.Errorf("Should work with valid CA, certificate and key, got %s", err)
}
}

func TestNewAuthHeader(t *testing.T) {
url, err := url.Parse("mock://")
if err != nil {
t.Errorf("Should work with empty config, got %s", err)
}

config := &Config{}

config.BearerAuthCredentialsSource = "STATIC_TOKEN"
config.BasicAuthCredentialsSource = "URL"

_, err = newAuthHeader(url, config)
if err == nil {
t.Errorf("Should not work with both basic auth source and bearer auth source")
}

// testing bearer auth
config.BasicAuthCredentialsSource = ""
_, err = newAuthHeader(url, config)
if err == nil {
t.Errorf("Should not work if bearer auth token is empty")
}

config.BearerAuthToken = "token"
config.BearerAuthLogicalCluster = "lsrc-123"
config.BearerAuthIdentityPoolID = "poolID"
headers, err := newAuthHeader(url, config)
if err != nil {
t.Errorf("Should work with bearer auth token, got %s", err)
} else {
if val, exists := headers["Authorization"]; !exists || len(val) == 0 ||
!strings.EqualFold(val[0], "Bearer token") {
t.Errorf("Should have header with key Authorization")
}
if val, exists := headers[targetIdentityPoolIDKey]; !exists || len(val) == 0 ||
!strings.EqualFold(val[0], "poolID") {
t.Errorf("Should have header with key Confluent-Identity-Pool-Id")
}
if val, exists := headers[targetSRClusterKey]; !exists || len(val) == 0 ||
!strings.EqualFold(val[0], "lsrc-123") {
t.Errorf("Should have header with key Target-Sr-Cluster")
}
}

config.BearerAuthCredentialsSource = "other"
_, err = newAuthHeader(url, config)
if err == nil {
t.Errorf("Should not work if bearer auth source is invalid")
}

// testing basic auth
config.BearerAuthCredentialsSource = ""
config.BasicAuthCredentialsSource = "USER_INFO"
config.BasicAuthUserInfo = "username:password"
_, err = newAuthHeader(url, config)
if err != nil {
t.Errorf("Should work with basic auth token, got %s", err)
}

config.BasicAuthCredentialsSource = "URL"
_, err = newAuthHeader(url, config)
if err != nil {
t.Errorf("Should work with basic auth token, got %s", err)
} else if val, exists := headers["Authorization"]; !exists || len(val) == 0 {
t.Errorf("Should have header with key Authorization")
}

config.BasicAuthCredentialsSource = "SASL_INHERIT"
config.SaslUsername = "username"
config.SaslPassword = "password"
_, err = newAuthHeader(url, config)
if err != nil {
t.Errorf("Should work with basic auth token, got %s", err)
} else if val, exists := headers["Authorization"]; !exists || len(val) == 0 {
t.Errorf("Should have header with key Authorization")
}

config.BasicAuthCredentialsSource = "other"
_, err = newAuthHeader(url, config)
if err == nil {
t.Errorf("Should not work if basic auth source is invalid")
}
}
1 change: 1 addition & 0 deletions schemaregistry/schemaregistry_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ type Client interface {
func NewClient(conf *Config) (Client, error) {

urlConf := conf.SchemaRegistryURL
// for testing
if strings.HasPrefix(urlConf, "mock://") {
url, err := url.Parse(urlConf)
if err != nil {
Expand Down

0 comments on commit 8d530db

Please sign in to comment.