Skip to content

Commit

Permalink
Merge pull request #121 from azaurus1/119-implement-auth-type-functio…
Browse files Browse the repository at this point in the history
…nality-for-client

added bearer functionality + implemented checking for multiple auth types
  • Loading branch information
azaurus1 authored Mar 26, 2024
2 parents cc9e530 + 0a26eab commit b4bc29f
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 6 deletions.
105 changes: 99 additions & 6 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,51 +8,144 @@ import (
"os"
)

const controllerUrl = "controllerUrl"
const authToken = "authToken"
const authType = "authType"
const logger = "logger"

type Opt interface {
apply(*cfg)
Type() string
}

type cfg struct {
controllerUrl string
authToken string
authType string
httpAuthWriter httpAuthWriter
logger *slog.Logger
}

type clientOpt struct{ fn func(*cfg) }

// controllerUrlOpt is an option to set the controller url for the client
type controllerUrlOpt struct {
controllerUrl string
}

func (o *controllerUrlOpt) apply(c *cfg) {
c.controllerUrl = o.controllerUrl
}

func (o *controllerUrlOpt) Type() string {
return controllerUrl
}

// authTokenOpt is an option to set the auth token for the client
type authTokenOpt struct {
authToken string
}

func (o *authTokenOpt) apply(c *cfg) {
c.authToken = o.authToken
}

func (o *authTokenOpt) Type() string {
return authToken
}

// authTypeOpt is an option to set the auth type for the client
type authTypeOpt struct {
authType string
}

func (o *authTypeOpt) apply(c *cfg) {
c.authType = o.authType
}

func (o *authTypeOpt) Type() string {
return authType
}

// loggerOpt is an option to set the logger for the client
type loggerOpt struct {
logger *slog.Logger
}

func (o *loggerOpt) apply(c *cfg) {
c.logger = o.logger
}

func (o *loggerOpt) Type() string {
return logger
}

func (opt clientOpt) apply(cfg *cfg) { opt.fn(cfg) }

func ControllerUrl(pinotControllerUrl string) Opt {
return clientOpt{fn: func(cfg *cfg) { cfg.controllerUrl = pinotControllerUrl }}
return &controllerUrlOpt{controllerUrl: pinotControllerUrl}
}

func AuthToken(token string) Opt {
return clientOpt{fn: func(cfg *cfg) { cfg.authToken = token }}
return &authTokenOpt{authToken: token}
}

func Logger(logger *slog.Logger) Opt {
return clientOpt{fn: func(cfg *cfg) { cfg.logger = logger }}
return &loggerOpt{logger: logger}
}

func AuthType(authType string) Opt {
return &authTypeOpt{authType: authType}
}

func validateOpts(opts ...Opt) (*cfg, *url.URL, error) {

// with default auth writer that does nothing
optCfg := defaultCfg()
optCounts := make(map[string]int)
for _, opt := range opts {

switch opt.(type) {
case *authTypeOpt:
optCounts[authType]++
case *authTokenOpt:
optCounts[authToken]++
case *controllerUrlOpt:
optCounts[controllerUrl]++
case *loggerOpt:
optCounts[logger]++
default:
optCounts[opt.Type()]++
}

opt.apply(optCfg)

if optCounts[authType] > 1 {
return nil, nil, fmt.Errorf("multiple auth types provided")
}
}

// validate controller url
pinotControllerUrl, err := url.Parse(optCfg.controllerUrl)
if err != nil {
return nil, nil, fmt.Errorf("controller url is invalid: %w", err)
}

// TODO: remove the redundant check
// Currently this is designed to avoid a breaking change
if optCfg.authType != "" && optCfg.authToken == "" {
return nil, nil, fmt.Errorf("auth token is required when auth type is set")
}
// if auth token passed, handle authenticated requests
if optCfg.authToken != "" {
optCfg.httpAuthWriter = func(req *http.Request) {
req.Header.Set("Authorization", fmt.Sprintf("Basic %s", optCfg.authToken))
switch optCfg.authType {
case "Bearer":
optCfg.httpAuthWriter = func(req *http.Request) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", optCfg.authToken))
}
default:
optCfg.httpAuthWriter = func(req *http.Request) {
req.Header.Set("Authorization", fmt.Sprintf("Basic %s", optCfg.authToken))
}
}
}

Expand Down
135 changes: 135 additions & 0 deletions go-pinot-api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1548,3 +1548,138 @@ func TestGetNonEmptyTable(t *testing.T) {
assert.Equal(t, false, table.IsEmpty(), "Expected table to not be empty")

}

func TestEmptyAuthType(t *testing.T) {
mux := http.NewServeMux()

mux.HandleFunc(RouteClusterInfo, func(w http.ResponseWriter, r *http.Request) {
// Check the Authorization header
authHeader := r.Header.Get("Authorization")
assert.Equal(t, authHeader, "Basic your_token", "Expected Authorization header to be 'Basic your_token'")

w.WriteHeader(http.StatusOK)
fmt.Fprint(w, `{"clusterName": "PinotCluster"}`)
})
server := httptest.NewServer(mux)
defer server.Close()

// Create a new client
client := goPinotAPI.NewPinotAPIClient(
goPinotAPI.AuthType(""),
goPinotAPI.ControllerUrl(server.URL),
goPinotAPI.AuthToken("your_token"),
)

// Make a request (replace this with an actual API call)
_, err := client.GetClusterInfo()
assert.NoError(t, err, "Expected no error from client.Get")
}

func TestBasicAuthType(t *testing.T) {
mux := http.NewServeMux()

mux.HandleFunc(RouteClusterInfo, func(w http.ResponseWriter, r *http.Request) {
// Check the Authorization header
authHeader := r.Header.Get("Authorization")
assert.Equal(t, authHeader, "Basic your_token", "Expected Authorization header to be 'Basic your_token'")

w.WriteHeader(http.StatusOK)
fmt.Fprint(w, `{"clusterName": "PinotCluster"}`)
})
server := httptest.NewServer(mux)
defer server.Close()

// Create a new client
client := goPinotAPI.NewPinotAPIClient(
goPinotAPI.AuthType("Basic"),
goPinotAPI.ControllerUrl(server.URL),
goPinotAPI.AuthToken("your_token"),
)

// Make a request (replace this with an actual API call)
_, err := client.GetClusterInfo()
assert.NoError(t, err, "Expected no error from client.Get")
}

func TestBearerAuthType(t *testing.T) {
mux := http.NewServeMux()

mux.HandleFunc(RouteClusterInfo, func(w http.ResponseWriter, r *http.Request) {
// Check the Authorization header
authHeader := r.Header.Get("Authorization")
assert.Equal(t, authHeader, "Bearer your_token", "Expected Authorization header to be 'Bearer your_token'")

w.WriteHeader(http.StatusOK)
fmt.Fprint(w, `{"clusterName": "PinotCluster"}`)
})
server := httptest.NewServer(mux)
defer server.Close()

// Create a new client
client := goPinotAPI.NewPinotAPIClient(
goPinotAPI.AuthType("Bearer"),
goPinotAPI.ControllerUrl(server.URL),
goPinotAPI.AuthToken("your_token"),
)

// Make a request (replace this with an actual API call)
_, err := client.GetClusterInfo()
assert.NoError(t, err, "Expected no error from client.Get")
}

func TestNoAuthType(t *testing.T) {
mux := http.NewServeMux()

mux.HandleFunc(RouteClusterInfo, func(w http.ResponseWriter, r *http.Request) {
// Check the Authorization header
authHeader := r.Header.Get("Authorization")
assert.Equal(t, authHeader, "Basic your_token", "Expected Authorization header to be 'Basic your_token'")

w.WriteHeader(http.StatusOK)
fmt.Fprint(w, `{"clusterName": "PinotCluster"}`)
})
server := httptest.NewServer(mux)
defer server.Close()

// Create a new client
client := goPinotAPI.NewPinotAPIClient(
goPinotAPI.ControllerUrl(server.URL),
goPinotAPI.AuthToken("your_token"),
)

// Make a request (replace this with an actual API call)
_, err := client.GetClusterInfo()
assert.NoError(t, err, "Expected no error from client.Get")
}

func TestMultipleAuthTypes(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Logf("The code panicked with %v", r)
}
}()
mux := http.NewServeMux()

mux.HandleFunc(RouteClusterInfo, func(w http.ResponseWriter, r *http.Request) {
// Check the Authorization header
authHeader := r.Header.Get("Authorization")
assert.Equal(t, authHeader, "Basic your_token", "Expected Authorization header to be 'Basic your_token'")

w.WriteHeader(http.StatusOK)
fmt.Fprint(w, `{"clusterName": "PinotCluster"}`)
})
server := httptest.NewServer(mux)
defer server.Close()

// Create a new client
client := goPinotAPI.NewPinotAPIClient(
goPinotAPI.AuthType("Bearer"),
goPinotAPI.AuthType("Basic"),
goPinotAPI.ControllerUrl(server.URL),
goPinotAPI.AuthToken("your_token"),
)

// Make a request (replace this with an actual API call)
_, err := client.GetClusterInfo()
assert.Error(t, err, "Expected error from client.Get")
}

0 comments on commit b4bc29f

Please sign in to comment.