Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Azure Workload Identity to authorize against RabbitMQ manageme… #4657

Merged
merged 4 commits into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ To learn more about active deprecations, we recommend checking [GitHub Discussio
- **Pulsar Scaler**: Improve error messages for unsuccessful connections ([#4563](https://github.com/kedacore/keda/issues/4563))
- **Security**: Enable secret scanning in GitHub repo
- **RabbitMQ Scaler**: Add support for `unsafeSsl` in trigger metadata ([#4448](https://github.com/kedacore/keda/issues/4448))
- **RabbitMQ Scaler**: Add support for `workloadIdentityResource` and utilize AzureAD Workload Identity for HTTP authorization ([#4716](https://github.com/kedacore/keda/issues/4716))
- **PostgreSQL Scaler**: Replace `lib/pq` with `pgx` ([#4704](https://github.com/kedacore/keda/issues/4704))
- **Prometheus Metrics**: Add new metric with KEDA build info ([#4647](https://github.com/kedacore/keda/issues/4647))
- **Prometheus Scaler**: Add support for Google Managed Prometheus ([#4675](https://github.com/kedacore/keda/pull/4675))
Expand Down
54 changes: 46 additions & 8 deletions pkg/scalers/rabbitmq_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import (
v2 "k8s.io/api/autoscaling/v2"
"k8s.io/metrics/pkg/apis/external_metrics"

"github.com/kedacore/keda/v2/apis/keda/v1alpha1"
"github.com/kedacore/keda/v2/pkg/scalers/azure"
kedautil "github.com/kedacore/keda/v2/pkg/util"
)

Expand Down Expand Up @@ -59,6 +61,7 @@ type rabbitMQScaler struct {
connection *amqp.Connection
channel *amqp.Channel
httpClient *http.Client
azureOAuth *azure.ADWorkloadIdentityTokenProvider
logger logr.Logger
}

Expand All @@ -85,6 +88,10 @@ type rabbitMQMetadata struct {
keyPassword string
enableTLS bool
unsafeSsl bool

// token provider for azure AD
workloadIdentityClientID string
workloadIdentityResource string
}

type queueInfo struct {
Expand Down Expand Up @@ -233,6 +240,13 @@ func parseRabbitMQMetadata(config *ScalerConfig) (*rabbitMQMetadata, error) {

meta.keyPassword = config.AuthParams["keyPassword"]

if config.PodIdentity.Provider == v1alpha1.PodIdentityProviderAzureWorkload {
if config.AuthParams["workloadIdentityResource"] != "" {
meta.workloadIdentityClientID = config.PodIdentity.IdentityID
meta.workloadIdentityResource = config.AuthParams["workloadIdentityResource"]
}
}

certGiven := meta.cert != ""
keyGiven := meta.key != ""
if certGiven != keyGiven {
Expand Down Expand Up @@ -264,6 +278,10 @@ func parseRabbitMQMetadata(config *ScalerConfig) (*rabbitMQMetadata, error) {
}
}

if meta.protocol == amqpProtocol && config.AuthParams["workloadIdentityResource"] != "" {
return nil, fmt.Errorf("workload identity is not supported for amqp protocol currently")
}

// Resolve queueName
if val, ok := config.TriggerMetadata["queueName"]; ok {
meta.queueName = val
Expand Down Expand Up @@ -464,9 +482,9 @@ func (s *rabbitMQScaler) Close(context.Context) error {
return nil
}

func (s *rabbitMQScaler) getQueueStatus() (int64, float64, error) {
func (s *rabbitMQScaler) getQueueStatus(ctx context.Context) (int64, float64, error) {
if s.metadata.protocol == httpProtocol {
info, err := s.getQueueInfoViaHTTP()
info, err := s.getQueueInfoViaHTTP(ctx)
if err != nil {
return -1, -1, err
}
Expand All @@ -488,12 +506,32 @@ func (s *rabbitMQScaler) getQueueStatus() (int64, float64, error) {
return int64(items.Messages), 0, nil
}

func getJSON(s *rabbitMQScaler, url string) (queueInfo, error) {
func getJSON(ctx context.Context, s *rabbitMQScaler, url string) (queueInfo, error) {
var result queueInfo
r, err := s.httpClient.Get(url)

request, err := http.NewRequest("GET", url, nil)
if err != nil {
return result, err
}

if s.metadata.workloadIdentityResource != "" {
if s.azureOAuth == nil {
s.azureOAuth = azure.NewAzureADWorkloadIdentityTokenProvider(ctx, s.metadata.workloadIdentityClientID, s.metadata.workloadIdentityResource)
}

err = s.azureOAuth.Refresh()
if err != nil {
return result, err
}

request.Header.Set("Authorization", "Bearer "+s.azureOAuth.OAuthToken())
}

r, err := s.httpClient.Do(request)
if err != nil {
return result, err
}

defer r.Body.Close()

if r.StatusCode == 200 {
Expand All @@ -518,7 +556,7 @@ func getJSON(s *rabbitMQScaler, url string) (queueInfo, error) {
return result, fmt.Errorf("error requesting rabbitMQ API status: %s, response: %s, from: %s", r.Status, body, url)
}

func (s *rabbitMQScaler) getQueueInfoViaHTTP() (*queueInfo, error) {
func (s *rabbitMQScaler) getQueueInfoViaHTTP(ctx context.Context) (*queueInfo, error) {
parsedURL, err := url.Parse(s.metadata.host)

if err != nil {
Expand Down Expand Up @@ -547,7 +585,7 @@ func (s *rabbitMQScaler) getQueueInfoViaHTTP() (*queueInfo, error) {
}

var info queueInfo
info, err = getJSON(s, getQueueInfoManagementURI)
info, err = getJSON(ctx, s, getQueueInfoManagementURI)

if err != nil {
return nil, err
Expand All @@ -572,8 +610,8 @@ func (s *rabbitMQScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpe
}

// GetMetricsAndActivity returns value for a supported metric and an error if there is a problem getting the metric
func (s *rabbitMQScaler) GetMetricsAndActivity(_ context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) {
messages, publishRate, err := s.getQueueStatus()
func (s *rabbitMQScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) {
messages, publishRate, err := s.getQueueStatus(ctx)
if err != nil {
return []external_metrics.ExternalMetricValue{}, false, s.anonymizeRabbitMQError(err)
}
Expand Down
38 changes: 26 additions & 12 deletions pkg/scalers/rabbitmq_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"time"

"github.com/stretchr/testify/assert"

"github.com/kedacore/keda/v2/apis/keda/v1alpha1"
)

const (
Expand All @@ -24,10 +26,12 @@ type parseRabbitMQMetadataTestData struct {
}

type parseRabbitMQAuthParamTestData struct {
metadata map[string]string
authParams map[string]string
isError bool
enableTLS bool
metadata map[string]string
podIdentity v1alpha1.AuthPodIdentity
authParams map[string]string
isError bool
enableTLS bool
workloadIdentity bool
}

type rabbitMQMetricIdentifier struct {
Expand Down Expand Up @@ -134,19 +138,23 @@ var testRabbitMQMetadata = []parseRabbitMQMetadataTestData{
}

var testRabbitMQAuthParamData = []parseRabbitMQAuthParamTestData{
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert", "key": "keey"}, false, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert", "key": "keey"}, false, true, false},
// success, TLS cert/key and assumed public CA
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey"}, false, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey"}, false, true, false},
// success, TLS cert/key + key password and assumed public CA
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey", "keyPassword": "keeyPassword"}, false, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey", "keyPassword": "keeyPassword"}, false, true, false},
// success, TLS CA only
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "ca": "caaa"}, false, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa"}, false, true, false},
// failure, TLS missing cert
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "ca": "caaa", "key": "kee"}, true, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa", "key": "kee"}, true, true, false},
// failure, TLS missing key
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert"}, true, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert"}, true, true, false},
// failure, TLS invalid
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "yes", "ca": "caaa", "cert": "ceert", "key": "kee"}, true, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "yes", "ca": "caaa", "cert": "ceert", "key": "kee"}, true, true, false},
// success, WorkloadIdentity
{map[string]string{"queueName": "sample", "hostFromEnv": host, "protocol": "http"}, v1alpha1.AuthPodIdentity{Provider: v1alpha1.PodIdentityProviderAzureWorkload, IdentityID: "client-id"}, map[string]string{"workloadIdentityResource": "rabbitmq-resource-id"}, false, false, true},
// failure, WoekloadIdentity not supported for amqp
{map[string]string{"queueName": "sample", "hostFromEnv": host, "protocol": "amqp"}, v1alpha1.AuthPodIdentity{Provider: v1alpha1.PodIdentityProviderAzureWorkload, IdentityID: "client-id"}, map[string]string{"workloadIdentityResource": "rabbitmq-resource-id"}, true, false, false},
}
var rabbitMQMetricIdentifiers = []rabbitMQMetricIdentifier{
{&testRabbitMQMetadata[1], 0, "s0-rabbitmq-sample"},
Expand Down Expand Up @@ -177,7 +185,7 @@ func TestRabbitMQParseMetadata(t *testing.T) {

func TestRabbitMQParseAuthParamData(t *testing.T) {
for _, testData := range testRabbitMQAuthParamData {
metadata, err := parseRabbitMQMetadata(&ScalerConfig{ResolvedEnv: sampleRabbitMqResolvedEnv, TriggerMetadata: testData.metadata, AuthParams: testData.authParams})
metadata, err := parseRabbitMQMetadata(&ScalerConfig{ResolvedEnv: sampleRabbitMqResolvedEnv, TriggerMetadata: testData.metadata, AuthParams: testData.authParams, PodIdentity: testData.podIdentity})
if err != nil && !testData.isError {
t.Error("Expected success but got error", err)
}
Expand All @@ -201,6 +209,12 @@ func TestRabbitMQParseAuthParamData(t *testing.T) {
t.Errorf("Expected key to be set to %v but got %v\n", testData.authParams["keyPassword"], metadata.key)
}
}
if metadata != nil && metadata.workloadIdentityClientID != "" && !testData.workloadIdentity {
t.Errorf("Expected workloadIdentity to be disabled but got %v as client ID and %v as resource\n", metadata.workloadIdentityClientID, metadata.workloadIdentityResource)
}
if metadata != nil && metadata.workloadIdentityClientID == "" && testData.workloadIdentity {
t.Error("Expected workloadIdentity to be enabled but was not\n")
}
}
}

Expand Down
63 changes: 53 additions & 10 deletions tests/scalers/rabbitmq/rabbitmq_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ data:
default_vhost = {{.VHostName}}
management.tcp.port = 15672
management.tcp.ip = 0.0.0.0
{{if .EnableOAuth}}
auth_backends.1 = rabbit_auth_backend_internal
auth_backends.2 = rabbit_auth_backend_oauth2
auth_backends.3 = rabbit_auth_backend_amqp
auth_oauth2.resource_server_id = {{.OAuthClientID}}
auth_oauth2.scope_prefix = rabbitmq.
auth_oauth2.additional_scopes_key = {{.OAuthScopesKey}}
auth_oauth2.jwks_url = {{.OAuthJwksURI}}
{{end}}
enabled_plugins: |
[rabbitmq_management].
---
Expand Down Expand Up @@ -158,33 +167,67 @@ spec:
`
)

type RabbitOAuthConfig struct {
Enable bool
ClientID string
ScopesKey string
JwksURI string
}

func WithoutOAuth() RabbitOAuthConfig {
return RabbitOAuthConfig{
Enable: false,
}
}

func WithAzureADOAuth(tenantID string, clientID string) RabbitOAuthConfig {
return RabbitOAuthConfig{
Enable: true,
ClientID: clientID,
ScopesKey: "roles",
JwksURI: fmt.Sprintf("https://login.microsoftonline.com/%s/discovery/keys", tenantID),
}
}

type templateData struct {
Namespace string
Connection string
QueueName string
HostName, VHostName string
Username, Password string
MessageCount int
EnableOAuth bool
OAuthClientID string
OAuthScopesKey string
OAuthJwksURI string
}

func RMQInstall(t *testing.T, kc *kubernetes.Clientset, namespace, user, password, vhost string) {
func RMQInstall(t *testing.T, kc *kubernetes.Clientset, namespace, user, password, vhost string, oauth RabbitOAuthConfig) {
helper.CreateNamespace(t, kc, namespace)
data := templateData{
Namespace: namespace,
VHostName: vhost,
Username: user,
Password: password,
Namespace: namespace,
VHostName: vhost,
Username: user,
Password: password,
EnableOAuth: oauth.Enable,
OAuthClientID: oauth.ClientID,
OAuthScopesKey: oauth.ScopesKey,
OAuthJwksURI: oauth.JwksURI,
}

helper.KubectlApplyWithTemplate(t, data, "rmqDeploymentTemplate", deploymentTemplate)
}

func RMQUninstall(t *testing.T, namespace, user, password, vhost string) {
func RMQUninstall(t *testing.T, namespace, user, password, vhost string, oauth RabbitOAuthConfig) {
data := templateData{
Namespace: namespace,
VHostName: vhost,
Username: user,
Password: password,
Namespace: namespace,
VHostName: vhost,
Username: user,
Password: password,
EnableOAuth: oauth.Enable,
OAuthClientID: oauth.ClientID,
OAuthScopesKey: oauth.ScopesKey,
OAuthJwksURI: oauth.JwksURI,
}

helper.KubectlDeleteWithTemplate(t, data, "rmqDeploymentTemplate", deploymentTemplate)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
)

// Load environment variables from .env file
var _ = godotenv.Load("../../.env")
var _ = godotenv.Load("../../../.env")

const (
testName = "rmq-queue-amqp-test"
Expand Down Expand Up @@ -79,7 +79,7 @@ func TestScaler(t *testing.T) {
kc := GetKubernetesClient(t)
data, templates := getTemplateData()

RMQInstall(t, kc, rmqNamespace, user, password, vhost)
RMQInstall(t, kc, rmqNamespace, user, password, vhost, WithoutOAuth())
CreateKubernetesResources(t, kc, testNamespace, data, templates)

assert.True(t, WaitForDeploymentReplicaReadyCount(t, kc, deploymentName, testNamespace, 0, 60, 1),
Expand All @@ -92,7 +92,7 @@ func TestScaler(t *testing.T) {
// cleanup
t.Log("--- cleaning up ---")
DeleteKubernetesResources(t, testNamespace, data, templates)
RMQUninstall(t, rmqNamespace, user, password, vhost)
RMQUninstall(t, rmqNamespace, user, password, vhost, WithoutOAuth())
}

func getTemplateData() (templateData, []Template) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
)

// Load environment variables from .env file
var _ = godotenv.Load("../../.env")
var _ = godotenv.Load("../../../.env")

const (
testName = "rmq-queue-amqp-vhost-test"
Expand Down Expand Up @@ -79,7 +79,7 @@ func TestScaler(t *testing.T) {
kc := GetKubernetesClient(t)
data, templates := getTemplateData()

RMQInstall(t, kc, rmqNamespace, user, password, vhost)
RMQInstall(t, kc, rmqNamespace, user, password, vhost, WithoutOAuth())
CreateKubernetesResources(t, kc, testNamespace, data, templates)

assert.True(t, WaitForDeploymentReplicaReadyCount(t, kc, deploymentName, testNamespace, 0, 60, 1),
Expand All @@ -92,7 +92,7 @@ func TestScaler(t *testing.T) {
// cleanup
t.Log("--- cleaning up ---")
DeleteKubernetesResources(t, testNamespace, data, templates)
RMQUninstall(t, rmqNamespace, user, password, vhost)
RMQUninstall(t, rmqNamespace, user, password, vhost, WithoutOAuth())
}

func getTemplateData() (templateData, []Template) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
)

// Load environment variables from .env file
var _ = godotenv.Load("../../.env")
var _ = godotenv.Load("../../../.env")

const (
testName = "rmq-queue-http-test"
Expand Down Expand Up @@ -80,7 +80,7 @@ func TestScaler(t *testing.T) {
kc := GetKubernetesClient(t)
data, templates := getTemplateData()

RMQInstall(t, kc, rmqNamespace, user, password, vhost)
RMQInstall(t, kc, rmqNamespace, user, password, vhost, WithoutOAuth())
CreateKubernetesResources(t, kc, testNamespace, data, templates)

assert.True(t, WaitForDeploymentReplicaReadyCount(t, kc, deploymentName, testNamespace, 0, 60, 1),
Expand All @@ -91,7 +91,7 @@ func TestScaler(t *testing.T) {
// cleanup
t.Log("--- cleaning up ---")
DeleteKubernetesResources(t, testNamespace, data, templates)
RMQUninstall(t, rmqNamespace, user, password, vhost)
RMQUninstall(t, rmqNamespace, user, password, vhost, WithoutOAuth())
}

func getTemplateData() (templateData, []Template) {
Expand Down
Loading