fetch SAs from apiserver #242

Closed · wants to merge 5 commits
1 change: 1 addition & 0 deletions main.go
@@ -181,6 +181,7 @@ func main() {
saInformer,
cmInformer,
composeRoleArnCache,
clientset.CoreV1(),
)
stop := make(chan struct{})
informerFactory.Start(stop)
76 changes: 56 additions & 20 deletions pkg/cache/cache.go
@@ -16,20 +16,26 @@
package cache

import (
"context"
"encoding/json"
"fmt"
"regexp"
"strconv"
"strings"
"sync"
"time"

"github.com/aws/amazon-eks-pod-identity-webhook/pkg"
"github.com/prometheus/client_golang/prometheus"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
coreinformers "k8s.io/client-go/informers/core/v1"
"k8s.io/client-go/kubernetes"
corev1 "k8s.io/client-go/kubernetes/typed/core/v1"
"k8s.io/client-go/tools/cache"
"k8s.io/client-go/util/retry"
"k8s.io/klog/v2"
)

@@ -80,8 +86,7 @@ type serviceAccountCache struct {
composeRoleArn ComposeRoleArn
defaultTokenExpiration int64
webhookUsage prometheus.Gauge
notificationHandlers map[string]chan struct{}
handlerMu sync.Mutex
notifications *notifications
}

type ComposeRoleArn struct {
@@ -156,20 +161,13 @@ func (c *serviceAccountCache) GetCommonConfigurations(name, namespace string) (u
return false, pkg.DefaultTokenExpiration
}

func (c *serviceAccountCache) getSA(req Request) (*Entry, chan struct{}) {
func (c *serviceAccountCache) getSA(req Request) (*Entry, <-chan struct{}) {
c.mu.RLock()
defer c.mu.RUnlock()
entry, ok := c.saCache[req.CacheKey()]
if !ok && req.RequestNotification {
klog.V(5).Infof("Service Account %s not found in cache, adding notification handler", req.CacheKey())
c.handlerMu.Lock()
defer c.handlerMu.Unlock()
notifier, found := c.notificationHandlers[req.CacheKey()]
if !found {
notifier = make(chan struct{})
c.notificationHandlers[req.CacheKey()] = notifier
}
return nil, notifier
return nil, c.notifications.create(req)
}
return entry, nil
}
@@ -264,13 +262,7 @@ func (c *serviceAccountCache) setSA(name, namespace string, entry *Entry) {
klog.V(5).Infof("Adding SA %q to SA cache: %+v", key, entry)
c.saCache[key] = entry

c.handlerMu.Lock()
defer c.handlerMu.Unlock()
if handler, found := c.notificationHandlers[key]; found {
klog.V(5).Infof("Notifying handlers for %q", key)
close(handler)
delete(c.notificationHandlers, key)
}
c.notifications.broadcast(key)
}

func (c *serviceAccountCache) setCM(name, namespace string, entry *Entry) {
@@ -280,7 +272,15 @@ func (c *serviceAccountCache) setCM(name, namespace string, entry *Entry) {
c.cmCache[namespace+"/"+name] = entry
}

func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenExpiration int64, saInformer coreinformers.ServiceAccountInformer, cmInformer coreinformers.ConfigMapInformer, composeRoleArn ComposeRoleArn) ServiceAccountCache {
func New(defaultAudience,
prefix string,
defaultRegionalSTS bool,
defaultTokenExpiration int64,
saInformer coreinformers.ServiceAccountInformer,
cmInformer coreinformers.ConfigMapInformer,
composeRoleArn ComposeRoleArn,
SAGetter corev1.ServiceAccountsGetter,
) ServiceAccountCache {
hasSynced := func() bool {
if cmInformer != nil {
return saInformer.Informer().HasSynced() && cmInformer.Informer().HasSynced()
@@ -289,6 +289,8 @@ func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenEx
}
}

// Rate limit to 10 concurrent requests against the API server.
saFetchRequests := make(chan *Request, 10)
c := &serviceAccountCache{
saCache: map[string]*Entry{},
cmCache: map[string]*Entry{},
@@ -299,9 +301,22 @@ func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenEx
defaultTokenExpiration: defaultTokenExpiration,
hasSynced: hasSynced,
webhookUsage: webhookUsage,
notificationHandlers: map[string]chan struct{}{},
notifications: newNotifications(saFetchRequests),
}

go func() {
for req := range saFetchRequests {
Contributor: This would mean we are making only one request at a time to the apiserver, right?

Contributor (Author): Wrapped this in a goroutine - thanks for the catch 💯

go func() {
sa, err := fetchFromAPI(SAGetter, req)
if err != nil {
klog.Errorf("fetching SA: %s, but got error from API: %v", req.CacheKey(), err)
return
}
c.addSA(sa)
}()
}
}()

saInformer.Informer().AddEventHandler(
cache.ResourceEventHandlerFuncs{
AddFunc: func(obj interface{}) {
@@ -351,6 +366,27 @@ func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenEx
return c
}

func fetchFromAPI(getter corev1.ServiceAccountsGetter, req *Request) (*v1.ServiceAccount, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*1)
defer cancel()

klog.V(5).Infof("fetching SA: %s", req.CacheKey())

var sa *v1.ServiceAccount
err := retry.OnError(retry.DefaultBackoff, func(err error) bool {
return errors.IsServerTimeout(err)
}, func() error {
res, err := getter.ServiceAccounts(req.Namespace).Get(ctx, req.Name, metav1.GetOptions{})
if err != nil {
return err
}
sa = res
return nil
})

return sa, err
}

func (c *serviceAccountCache) populateCacheFromCM(oldCM, newCM *v1.ConfigMap) error {
if newCM.Name != "pod-identity-webhook" {
return nil
97 changes: 88 additions & 9 deletions pkg/cache/cache_test.go
@@ -35,6 +35,7 @@ func TestSaCache(t *testing.T) {
defaultAudience: "sts.amazonaws.com",
annotationPrefix: "eks.amazonaws.com",
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
notifications: newNotifications(make(chan *Request, 10)),
}

resp := cache.Get(Request{Name: "default", Namespace: "default"})
@@ -69,9 +70,9 @@ func TestNotification(t *testing.T) {

t.Run("with one notification handler", func(t *testing.T) {
cache := &serviceAccountCache{
saCache: map[string]*Entry{},
notificationHandlers: map[string]chan struct{}{},
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
saCache: map[string]*Entry{},
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
notifications: newNotifications(make(chan *Request, 10)),
}

// test that the requested SA is not in the cache
@@ -106,9 +107,9 @@

t.Run("with 10 notification handlers", func(t *testing.T) {
cache := &serviceAccountCache{
saCache: map[string]*Entry{},
notificationHandlers: map[string]chan struct{}{},
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
saCache: map[string]*Entry{},
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
notifications: newNotifications(make(chan *Request, 5)),
}

// test that the requested SA is not in the cache
@@ -153,6 +154,63 @@
})
}

func TestFetchFromAPIServer(t *testing.T) {
testSA := &v1.ServiceAccount{
ObjectMeta: metav1.ObjectMeta{
Name: "my-sa",
Namespace: "default",
Annotations: map[string]string{
"eks.amazonaws.com/role-arn": "arn:aws:iam::111122223333:role/s3-reader",
"eks.amazonaws.com/token-expiration": "3600",
},
},
}
fakeSAClient := fake.NewSimpleClientset(testSA)

// use an empty informer to simulate the need to fetch SA from api server:
fakeEmptyClient := fake.NewSimpleClientset()
emptyInformerFactory := informers.NewSharedInformerFactory(fakeEmptyClient, 0)
emptyInformer := emptyInformerFactory.Core().V1().ServiceAccounts()

cache := New(
"sts.amazonaws.com",
"eks.amazonaws.com",
true,
86400,
emptyInformer,
nil,
ComposeRoleArn{},
fakeSAClient.CoreV1(),
)

stop := make(chan struct{})
emptyInformerFactory.Start(stop)
emptyInformerFactory.WaitForCacheSync(stop)
cache.Start(stop)
defer close(stop)

err := wait.ExponentialBackoff(wait.Backoff{Duration: 10 * time.Millisecond, Factor: 1.0, Steps: 3}, func() (bool, error) {
return len(fakeEmptyClient.Actions()) != 0, nil
})
if err != nil {
t.Fatalf("informer never called client: %v", err)
}

resp := cache.Get(Request{Name: "my-sa", Namespace: "default", RequestNotification: true})
assert.False(t, resp.FoundInCache, "Expected cache entry to not be found")

// wait for the notification while we fetch the SA from the API server:
select {
case <-resp.Notifier:
// expected
// test that the requested SA is now in the cache
resp := cache.Get(Request{Name: "my-sa", Namespace: "default", RequestNotification: false})
assert.True(t, resp.FoundInCache, "Expected cache entry to be found in cache")
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for notification")
}
}

func TestNonRegionalSTS(t *testing.T) {
trueStr := "true"
falseStr := "false"
@@ -237,7 +295,16 @@ func TestNonRegionalSTS(t *testing.T) {

testComposeRoleArn := ComposeRoleArn{}

cache := New(audience, "eks.amazonaws.com", tc.defaultRegionalSTS, 86400, informer, nil, testComposeRoleArn)
cache := New(
audience,
"eks.amazonaws.com",
tc.defaultRegionalSTS,
86400,
informer,
nil,
testComposeRoleArn,
fakeClient.CoreV1(),
)
stop := make(chan struct{})
informerFactory.Start(stop)
informerFactory.WaitForCacheSync(stop)
@@ -295,7 +362,8 @@ func TestPopulateCacheFromCM(t *testing.T) {
}

c := serviceAccountCache{
cmCache: make(map[string]*Entry),
cmCache: make(map[string]*Entry),
notifications: newNotifications(make(chan *Request, 10)),
}

{
@@ -353,6 +421,7 @@ func TestSAAnnotationRemoval(t *testing.T) {
saCache: make(map[string]*Entry),
annotationPrefix: "eks.amazonaws.com",
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
notifications: newNotifications(make(chan *Request, 10)),
}

c.addSA(oldSA)
@@ -416,6 +485,7 @@ func TestCachePrecedence(t *testing.T) {
defaultTokenExpiration: pkg.DefaultTokenExpiration,
annotationPrefix: "eks.amazonaws.com",
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
notifications: newNotifications(make(chan *Request, 10)),
}

{
@@ -514,7 +584,15 @@ func TestRoleArnComposition(t *testing.T) {
informerFactory := informers.NewSharedInformerFactory(fakeClient, 0)
informer := informerFactory.Core().V1().ServiceAccounts()

cache := New(audience, "eks.amazonaws.com", true, 86400, informer, nil, testComposeRoleArn)
cache := New(audience,
"eks.amazonaws.com",
true,
86400,
informer,
nil,
testComposeRoleArn,
fakeClient.CoreV1(),
)
stop := make(chan struct{})
informerFactory.Start(stop)
informerFactory.WaitForCacheSync(stop)
@@ -613,6 +691,7 @@ func TestGetCommonConfigurations(t *testing.T) {
defaultAudience: "sts.amazonaws.com",
annotationPrefix: "eks.amazonaws.com",
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
notifications: newNotifications(make(chan *Request, 10)),
}

if tc.serviceAccount != nil {
43 changes: 43 additions & 0 deletions pkg/cache/notifications.go
@@ -0,0 +1,43 @@
package cache

import (
"sync"

"k8s.io/klog/v2"
)

type notifications struct {
handlers map[string]chan struct{}
mu sync.Mutex
fetchRequests chan<- *Request
}

func newNotifications(saFetchRequests chan<- *Request) *notifications {
return &notifications{
handlers: map[string]chan struct{}{},
fetchRequests: saFetchRequests,
}
}

func (n *notifications) create(req Request) <-chan struct{} {
n.mu.Lock()
defer n.mu.Unlock()

notifier, found := n.handlers[req.CacheKey()]
if !found {
notifier = make(chan struct{})
n.handlers[req.CacheKey()] = notifier
n.fetchRequests <- &req
Contributor: We control the APIServer request rate through the size of the channel, but that has two downsides:

  1. The APIServer requests are not actually rate limited. There is no rate limiter on the channel consumption, and more than 100 requests (for different namespace/name pairs) could be sent in the same second, since the channel consumer starts a new goroutine to submit each request to the APIServer.
  2. If for some reason the channel consumer dies or the queue is full, the create function will block until the channel has spare capacity, which could unexpectedly delay pod creation for an arbitrary amount of time.

A better choice could be to use a larger channel size to minimize the chance of blocking on the channel write, and to implement a more robust consumer that limits the consumption rate (a sketch follows this file's diff). Even with an extremely high volume of requests queued in the channel and API requests that cannot be sent in time, the result would be that either the cache is synced before the grace period expires and the pod is mutated, or the cache is not synced and the pod is not mutated, with no prolonged delay to pod creation and no excessive requests to the APIServer.

}
return notifier
}

func (n *notifications) broadcast(key string) {
n.mu.Lock()
defer n.mu.Unlock()
if handler, found := n.handlers[key]; found {
klog.V(5).Infof("Notifying handlers for %q", key)
close(handler)
delete(n.handlers, key)
}
}
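
Following up on the reviewer's rate-limiting concern above, here is a minimal sketch of a throttled consumer. It is not part of this PR: it assumes golang.org/x/time/rate were added as a dependency (it is not vendored here) and reuses the PR's Request, fetchFromAPI, and addSA names purely for illustration.

```go
package cache

// Sketch only, not part of this PR. Assumes golang.org/x/time/rate is
// available; Request, fetchFromAPI, and addSA are the PR's own names.
import (
	"context"

	"golang.org/x/time/rate"
	corev1 "k8s.io/client-go/kubernetes/typed/core/v1"
	"k8s.io/klog/v2"
)

// startThrottledFetcher drains a generously buffered request channel, but only
// lets qps fetches per second (with a small burst) reach the API server.
func (c *serviceAccountCache) startThrottledFetcher(
	ctx context.Context,
	getter corev1.ServiceAccountsGetter,
	reqs <-chan *Request,
	qps float64,
	burst int,
) {
	limiter := rate.NewLimiter(rate.Limit(qps), burst)
	go func() {
		for req := range reqs {
			// Throttle in the consumer rather than via channel capacity, so
			// enqueueing callers only block when the buffer is truly full.
			if err := limiter.Wait(ctx); err != nil {
				return // context cancelled
			}
			go func(req *Request) {
				sa, err := fetchFromAPI(getter, req)
				if err != nil {
					klog.Errorf("fetching SA %s from API: %v", req.CacheKey(), err)
					return
				}
				c.addSA(sa)
			}(req)
		}
	}()
}
```

Keeping the goroutine-per-request dispatch inside the consumer means the limiter bounds how fast fetches are launched, while slow individual Get calls can still overlap.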
5 changes: 3 additions & 2 deletions pkg/handler/handler.go
@@ -433,9 +433,10 @@ func (m *Modifier) buildPodPatchConfig(pod *corev1.Pod) *podPatchConfig {
}

// Use the STS WebIdentity method if set
request := cache.Request{Namespace: pod.Namespace, Name: pod.Spec.ServiceAccountName, RequestNotification: true}
gracePeriodEnabled := m.saLookupGraceTime > 0
request := cache.Request{Namespace: pod.Namespace, Name: pod.Spec.ServiceAccountName, RequestNotification: gracePeriodEnabled}
Contributor (Author): I added this change to toggle RequestNotification only when the grace period is enabled.

Previously it was basically a no-op when the feature was disabled and RequestNotification was true, but now we fetch from the API server whenever it is true, so we only want it set when the feature is enabled.


response := m.Cache.Get(request)
if !response.FoundInCache && m.saLookupGraceTime > 0 {
if !response.FoundInCache && gracePeriodEnabled {
klog.Warningf("Service account %s not found in the cache. Waiting up to %s to be notified", request.CacheKey(), m.saLookupGraceTime)
select {
case <-response.Notifier:
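
For context, here is a sketch of the caller-side wait this hunk feeds into. The diff is truncated right after the select begins, so the re-lookup and timeout branches below are an assumption about the surrounding handler code rather than part of this change.

```go
// Sketch of the surrounding handler logic; only the lines above the select
// appear in the hunk, the two branches are assumptions.
response := m.Cache.Get(request)
if !response.FoundInCache && gracePeriodEnabled {
	klog.Warningf("Service account %s not found in the cache. Waiting up to %s to be notified", request.CacheKey(), m.saLookupGraceTime)
	select {
	case <-response.Notifier:
		// The background fetch populated the cache; look the entry up again.
		response = m.Cache.Get(cache.Request{Namespace: pod.Namespace, Name: pod.Spec.ServiceAccountName})
	case <-time.After(m.saLookupGraceTime):
		klog.Warningf("Service account %s still not in the cache after %s; continuing without it", request.CacheKey(), m.saLookupGraceTime)
	}
}
```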
1 change: 1 addition & 0 deletions vendor/modules.txt
@@ -678,6 +678,7 @@ k8s.io/client-go/util/consistencydetector
k8s.io/client-go/util/flowcontrol
k8s.io/client-go/util/homedir
k8s.io/client-go/util/keyutil
k8s.io/client-go/util/retry
k8s.io/client-go/util/watchlist
k8s.io/client-go/util/workqueue
# k8s.io/klog/v2 v2.130.1