Skip to content

Commit

Permalink
backend/ssm: minor cleanups (#76)
Browse files Browse the repository at this point in the history
Avoid the opaque-looking generated mock code and apply some
other minor cleanups.
  • Loading branch information
rogpeppe authored Oct 16, 2019
1 parent 1833a32 commit 86d03df
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 151 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ matrix:
before_script:
- docker run -d -p 2379:2379 quay.io/coreos/etcd /usr/local/bin/etcd -advertise-client-urls http://0.0.0.0:2379 -listen-client-urls http://0.0.0.0:2379
- docker run -d -p 8500:8500 --name consul consul
- docker run -d -p 8200:8200 --cap-add=IPC_LOCK -e 'VAULT_DEV_ROOT_TOKEN_ID=root' -e 'VAULT_DEV_LISTEN_ADDRESS=0.0.0.0:8200' vault:0.9.6
- docker run -d -p 8200:8200 --cap-add=IPC_LOCK -e 'VAULT_DEV_ROOT_TOKEN_ID=root' -e 'VAULT_DEV_LISTEN_ADDRESS=0.0.0.0:8200' vault:0.9.6
44 changes: 21 additions & 23 deletions backend/ssm/ssm.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,24 @@ import (

"github.com/aws/aws-sdk-go/service/ssm"
"github.com/aws/aws-sdk-go/service/ssm/ssmiface"

"github.com/heetch/confita/backend"
)

type Backend struct {
type ssmBackend struct {
client ssmiface.SSMAPI
ssmPath string
cache map[string][]byte
}

func NewBackend(ssm ssmiface.SSMAPI, path string) *Backend {
return &Backend{client: ssm, ssmPath: path}
// NewBackend returns a backend instance that uses the given SSMAPI implementation
// to retrieve keys from the parameter store at the given path.
func NewBackend(ssm ssmiface.SSMAPI, path string) backend.Backend {
return &ssmBackend{client: ssm, ssmPath: path}
}

func (b *Backend) Get(ctx context.Context, key string) ([]byte, error) {
// Get implements backend.Backend.Get by fetching the key from SSM params.
func (b *ssmBackend) Get(ctx context.Context, key string) ([]byte, error) {
if b.cache == nil {
err := b.fetchParams(ctx)
if err != nil {
Expand All @@ -30,11 +34,12 @@ func (b *Backend) Get(ctx context.Context, key string) ([]byte, error) {
return b.fromCache(ctx, key)
}

func (b *Backend) Name() string {
// Name implements backend.Backend.Name.
func (b *ssmBackend) Name() string {
return "ssm"
}

func (b *Backend) fetchParams(ctx context.Context) error {
func (b *ssmBackend) fetchParams(ctx context.Context) error {
b.cache = make(map[string][]byte)

ssmInput := &ssm.GetParametersByPathInput{
Expand All @@ -43,40 +48,33 @@ func (b *Backend) fetchParams(ctx context.Context) error {
WithDecryption: newBool(true),
MaxResults: newInt64(10),
}

for {
res, err := b.client.GetParametersByPathWithContext(ctx, ssmInput)
if err != nil {
return err
}

for _, p := range res.Parameters {
if p.Name != nil && p.Value != nil {
path := strings.Split(*p.Name, "/")
key := path[len(path)-1]
if key != "" {
b.cache[key] = []byte(*p.Value)
}
if p.Name == nil || p.Value == nil {
continue
}
path := strings.Split(*p.Name, "/")
if key := path[len(path)-1]; key != "" {
b.cache[key] = []byte(*p.Value)
}
}

if res.NextToken == nil {
break
}

ssmInput.NextToken = res.NextToken
}

return nil
}

func (b *Backend) fromCache(ctx context.Context, key string) ([]byte, error) {
v, ok := b.cache[key]
if !ok {
return nil, backend.ErrNotFound
func (b *ssmBackend) fromCache(ctx context.Context, key string) ([]byte, error) {
if v, ok := b.cache[key]; ok {
return v, nil
}

return v, nil
return nil, backend.ErrNotFound
}

func newBool(b bool) *bool {
Expand Down
230 changes: 105 additions & 125 deletions backend/ssm/ssm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,162 +5,111 @@ import (
"fmt"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/aws/aws-sdk-go/service/ssm/ssmiface"
"github.com/heetch/confita/backend"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)

type mockSSM struct {
mock.Mock
ssmiface.SSMAPI
}

func (_m *mockSSM) GetParametersByPathWithContext(_a0 context.Context, _a1 *ssm.GetParametersByPathInput, _a2 ...request.Option) (*ssm.GetParametersByPathOutput, error) {
_va := make([]interface{}, len(_a2))
for _i := range _a2 {
_va[_i] = _a2[_i]
}
var _ca []interface{}
_ca = append(_ca, _a0, _a1)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)

var r0 *ssm.GetParametersByPathOutput
if rf, ok := ret.Get(0).(func(context.Context, *ssm.GetParametersByPathInput, ...request.Option) *ssm.GetParametersByPathOutput); ok {
r0 = rf(_a0, _a1, _a2...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*ssm.GetParametersByPathOutput)
}
}

var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *ssm.GetParametersByPathInput, ...request.Option) error); ok {
r1 = rf(_a0, _a1, _a2...)
} else {
r1 = ret.Error(1)
}

return r0, r1
}

func TestAWSError(t *testing.T) {
client := new(mockSSM)
ssmOpts := getSSMOpts("/borked/")
ctx := context.Background()
expected := fmt.Errorf("aws down")
client.On("GetParametersByPathWithContext", ctx, ssmOpts).Return(
nil, expected)

client := newFakeSSM(t, "/borked/", []getParamsRequest{{
resultErr: fmt.Errorf("aws down"),
}})
b := NewBackend(client, "/borked/")
_, actual := b.Get(context.Background(), "some_key")
require.Equal(t, expected, actual)
_, err := b.Get(context.Background(), "some_key")
require.Error(t, err)
require.Contains(t, err.Error(), "aws down")
}

func TestNilNameAndValue(t *testing.T) {
client := new(mockSSM)
ssmOpts := getSSMOpts("/sup/")
ctx := context.Background()

client.On("GetParametersByPathWithContext", ctx, ssmOpts).Return(&ssm.GetParametersByPathOutput{
Parameters: []*ssm.Parameter{
{
// nil names and values in the response are ignored.
client := newFakeSSM(t, "/borked/", []getParamsRequest{{
result: &ssm.GetParametersByPathOutput{
Parameters: []*ssm.Parameter{{
Name: nil,
Value: newString("ignorevalue"),
}, {
Name: newString("ignorename"),
Value: nil,
},
{
Name: ptrString("/sup/key"),
Value: nil,
},
}, {
Name: newString("ignoreboth"),
Value: newString("ignoreboth"),
}, {
Name: newString("/sup/key"),
Value: newString("hello"),
}},
},
}, nil)

b := NewBackend(client, "/sup/")

_, actual := b.Get(context.Background(), "key")
require.Equal(t, backend.ErrNotFound, actual)
}})
b := NewBackend(client, "/borked/")
val, err := b.Get(context.Background(), "key")
require.NoError(t, err)
require.Equal(t, "hello", string(val))
}

func TestEmptyKey(t *testing.T) {
client := new(mockSSM)
ssmOpts := getSSMOpts("/sup/")
ctx := context.Background()

client.On("GetParametersByPathWithContext", ctx, ssmOpts).Return(&ssm.GetParametersByPathOutput{
Parameters: []*ssm.Parameter{
{
Name: ptrString("/sup/"),
Value: ptrString("a value"),
},
client := newFakeSSM(t, "/borked/", []getParamsRequest{{
result: &ssm.GetParametersByPathOutput{
Parameters: []*ssm.Parameter{{
Name: newString("/sup/key"),
Value: newString("hello"),
}},
},
}, nil)

b := NewBackend(client, "/sup/")

_, actual := b.Get(context.Background(), "")
require.Equal(t, backend.ErrNotFound, actual)
}})
b := NewBackend(client, "/borked/")
val, err := b.Get(context.Background(), "")
require.Equal(t, backend.ErrNotFound, err)
require.Equal(t, "", string(val))
}

func TestKeyNotFound(t *testing.T) {
client := new(mockSSM)
ssmOpts := getSSMOpts("/whatevs/")
ctx := context.Background()
client.On("GetParametersByPathWithContext", ctx, ssmOpts).Return(
&ssm.GetParametersByPathOutput{}, nil)

client := newFakeSSM(t, "/whatevs/", []getParamsRequest{{
result: &ssm.GetParametersByPathOutput{},
}})
b := NewBackend(client, "/whatevs/")
_, actual := b.Get(context.Background(), "some_key")
require.Equal(t, backend.ErrNotFound, actual)
}

func ptrString(str string) *string {
return &str
val, err := b.Get(context.Background(), "some_key")
require.Equal(t, backend.ErrNotFound, err)
require.Equal(t, "", string(val))
}

func TestKeysFound(t *testing.T) {
client := new(mockSSM)
ctx := context.Background()
ssmOpts := getSSMOpts("/yo/whatup/")
client.On("GetParametersByPathWithContext", ctx, ssmOpts).Return(
&ssm.GetParametersByPathOutput{
client := newFakeSSM(t, "/yo/whatup/", []getParamsRequest{{
result: &ssm.GetParametersByPathOutput{
Parameters: []*ssm.Parameter{
{Name: ptrString("/yo/whatup/a_key"), Value: ptrString("wow")},
{Name: ptrString("/yo/whatup/some_key"), Value: ptrString("wondrous")},
{Name: newString("/yo/whatup/a_key"), Value: newString("wow")},
{Name: newString("/yo/whatup/some_key"), Value: newString("wondrous")},
},
}, nil)

},
}})
b := NewBackend(client, "/yo/whatup/")
actual, err := b.Get(context.Background(), "a_key")
val, err := b.Get(context.Background(), "a_key")
require.Nil(t, err)
require.Equal(t, "wow", string(actual))
actual, err = b.Get(context.Background(), "some_key")
require.Equal(t, "wow", string(val))
val, err = b.Get(context.Background(), "some_key")
require.Nil(t, err)
require.Equal(t, "wondrous", string(actual))
require.Equal(t, "wondrous", string(val))
}

func TestSSMPagedCall(t *testing.T) {
client := new(mockSSM)
ctx := context.Background()
firstOpts := getSSMOpts("/a/path/")
client.On("GetParametersByPathWithContext", ctx, firstOpts).Return(
&ssm.GetParametersByPathOutput{
Parameters: []*ssm.Parameter{},
NextToken: ptrString("/a/path/your_key"),
}, nil)

secondOpts := getSSMOpts("/a/path/")
secondOpts.NextToken = ptrString("/a/path/your_key")
client.On("GetParametersByPathWithContext", ctx, secondOpts).Return(
&ssm.GetParametersByPathOutput{
client := newFakeSSM(t, "/a/path/", []getParamsRequest{{
result: &ssm.GetParametersByPathOutput{
Parameters: []*ssm.Parameter{
{Name: newString("/yo/whatup/a_key"), Value: newString("wow")},
{Name: newString("/yo/whatup/some_key"), Value: newString("wondrous")},
},
NextToken: newString("/a/path/your_key"),
},
}, {
expectToken: newString("/a/path/your_key"),
result: &ssm.GetParametersByPathOutput{
Parameters: []*ssm.Parameter{
{Name: ptrString("/a/path/your_key"), Value: ptrString("shazam")},
{Name: ptrString("/a/path/another_key"), Value: ptrString("kazam")},
{Name: newString("/a/path/your_key"), Value: newString("shazam")},
{Name: newString("/a/path/another_key"), Value: newString("kazam")},
},
NextToken: nil,
}, nil)
},
}})

b := NewBackend(client, "/a/path/")
actual, err := b.Get(context.Background(), "your_key")
Expand All @@ -171,11 +120,42 @@ func TestSSMPagedCall(t *testing.T) {
require.Equal(t, "kazam", string(actual))
}

func getSSMOpts(path string) *ssm.GetParametersByPathInput {
return &ssm.GetParametersByPathInput{
Path: &path,
Recursive: newBool(true),
WithDecryption: newBool(true),
MaxResults: newInt64(10),
type getParamsRequest struct {
expectToken *string
result *ssm.GetParametersByPathOutput
resultErr error
}

type fakeSSM struct {
ssmiface.SSMAPI
t *testing.T
// We always expect the backend to use the same path.
expectPath string
// The sequence of calls we're expecting.
calls []getParamsRequest
// The index of the next expected call.
call int
}

func newFakeSSM(t *testing.T, path string, calls []getParamsRequest) *fakeSSM {
return &fakeSSM{
t: t,
expectPath: path,
calls: calls,
}
}

func (f *fakeSSM) GetParametersByPathWithContext(ctx aws.Context, arg *ssm.GetParametersByPathInput, opts ...request.Option) (*ssm.GetParametersByPathOutput, error) {
if f.call >= len(f.calls) {
f.t.Errorf("too many calls to SSM (expected max of %d)", len(f.calls))
}
call := f.calls[f.call]
f.call++
require.Equal(f.t, newString(f.expectPath), arg.Path)
require.Equal(f.t, call.expectToken, arg.NextToken)
return call.result, call.resultErr
}

func newString(str string) *string {
return &str
}
Loading

0 comments on commit 86d03df

Please sign in to comment.