From db5294f51b581f10354a086a2cd029c9a1c92842 Mon Sep 17 00:00:00 2001 From: michel-laterman Date: Wed, 20 Sep 2023 09:33:34 -0700 Subject: [PATCH] Fix checks, add unit tests --- internal/pkg/api/handlePGPRequest.go | 29 +++-- internal/pkg/api/handlePGPRequest_test.go | 127 ++++++++++++++++++++++ internal/pkg/api/server_test.go | 2 +- internal/pkg/config/config_test.go | 4 + internal/pkg/config/pgp.go | 4 + 5 files changed, 155 insertions(+), 11 deletions(-) create mode 100644 internal/pkg/api/handlePGPRequest_test.go diff --git a/internal/pkg/api/handlePGPRequest.go b/internal/pkg/api/handlePGPRequest.go index 1bff06d3f..cb6ff5d6a 100644 --- a/internal/pkg/api/handlePGPRequest.go +++ b/internal/pkg/api/handlePGPRequest.go @@ -1,8 +1,13 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + package api import ( "context" "errors" + "fmt" "io" "io/fs" "net/http" @@ -16,11 +21,15 @@ import ( "go.elastic.co/apm/v2" ) -const defaultKeyName = "default.pgp" +const ( + defaultKeyName = "default.pgp" + defaultKeyPermissions = 0o0600 +) var ( ErrTLSRequired = errors.New("api call requires a TLS connection") - ErrPGPPermissions = errors.New("pgp key permissions are not 0700") + ErrPGPPermissions = fmt.Errorf("pgp key permissions are not %#o", defaultKeyPermissions) + ErrUpstreamStatus = errors.New("upstream http server status error") ) type PGPRetrieverT struct { @@ -47,7 +56,6 @@ func (pt *PGPRetrieverT) handlePGPKey(zlog zerolog.Logger, w http.ResponseWriter } zlog = zlog.With().Str(LogEnrollAPIKeyID, key.ID).Logger() ctx := zlog.WithContext(r.Context()) - r = r.WithContext(ctx) p, err := pt.getPGPKey(ctx, zlog) if err != nil { @@ -61,10 +69,8 @@ func (pt *PGPRetrieverT) handlePGPKey(zlog zerolog.Logger, w http.ResponseWriter // getPGPKey will return the PGP key bytes // // First the local cache will be checked -// // If it's not found in the cache, we attempt to read from disk // If it's found we set the cache and return the bytes -// // If it's not found on disk we attempt to retrieve the upstream key // If that succeeds we set the cache then write to disk (with best effort). func (pt *PGPRetrieverT) getPGPKey(ctx context.Context, zlog zerolog.Logger) ([]byte, error) { @@ -81,7 +87,7 @@ func (pt *PGPRetrieverT) getPGPKey(ctx context.Context, zlog zerolog.Logger) ([] } p, err := pt.getPGPFromDir(ctx, key) - // sucessfully retrieved from disk + // successfully retrieved from disk if err == nil { pt.cache.SetPGPKey(key, p) return p, nil @@ -106,14 +112,14 @@ func (pt *PGPRetrieverT) getPGPKey(ctx context.Context, zlog zerolog.Logger) ([] // // Key contents are only returned if the key has valid permission bits. func (pt *PGPRetrieverT) getPGPFromDir(ctx context.Context, key string) ([]byte, error) { - span, ctx := apm.StartSpan(ctx, "getPGPFromDir", "process") + span, _ := apm.StartSpan(ctx, "getPGPFromDir", "process") defer span.End() stat, err := os.Stat(filepath.Join(pt.cfg.Dir, key)) if err != nil { return nil, err } - if stat.Mode().Perm() != 0700 { // TODO determine what permission bits we want to check + if stat.Mode().Perm() != defaultKeyPermissions { // TODO determine what permission bits we want to check return nil, ErrPGPPermissions } return os.ReadFile(filepath.Join(pt.cfg.Dir, key)) @@ -134,6 +140,9 @@ func (pt *PGPRetrieverT) getPGPFromUpstream(ctx context.Context) ([]byte, error) return nil, err } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%w: %d", ErrUpstreamStatus, resp.StatusCode) + } return io.ReadAll(resp.Body) } @@ -142,7 +151,7 @@ func (pt *PGPRetrieverT) getPGPFromUpstream(ctx context.Context) ([]byte, error) // If the directory does not exist it will create it // Otherwise it is treated as a best-effort attempt func (pt *PGPRetrieverT) writeKeyToDir(ctx context.Context, zlog zerolog.Logger, key string, p []byte) { - span, ctx := apm.StartSpan(ctx, "writeKeyToDir", "process") + span, _ := apm.StartSpan(ctx, "writeKeyToDir", "process") defer span.End() _, err := os.Stat(pt.cfg.Dir) @@ -158,7 +167,7 @@ func (pt *PGPRetrieverT) writeKeyToDir(ctx context.Context, zlog zerolog.Logger, } } - err = os.WriteFile(filepath.Join(pt.cfg.Dir, key), p, 0700) + err = os.WriteFile(filepath.Join(pt.cfg.Dir, key), p, defaultKeyPermissions) if err != nil { zlog.Error().Err(err).Str("path", filepath.Join(pt.cfg.Dir, key)).Msg("Unable to write file.") return diff --git a/internal/pkg/api/handlePGPRequest_test.go b/internal/pkg/api/handlePGPRequest_test.go new file mode 100644 index 000000000..2dacb5806 --- /dev/null +++ b/internal/pkg/api/handlePGPRequest_test.go @@ -0,0 +1,127 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package api + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/elastic/fleet-server/v7/internal/pkg/config" + "github.com/elastic/fleet-server/v7/internal/pkg/testing/cache" + testlog "github.com/elastic/fleet-server/v7/internal/pkg/testing/log" + "github.com/stretchr/testify/require" +) + +func Test_PGPRetrieverT_getPGPKey(t *testing.T) { + tests := []struct { + name string + cache func() *cache.MockCache + dirSetup func(t *testing.T) string + upstreamStatus int + content []byte + err error + }{{ + name: "found in cache", + cache: func() *cache.MockCache { + m := cache.NewMockCache() + m.On("GetPGPKey", defaultKeyName).Return([]byte("test"), true).Once() + return m + }, + dirSetup: func(t *testing.T) string { + return "" + }, + content: []byte("test"), + err: nil, + }, { + name: "found in dir", + cache: func() *cache.MockCache { + m := cache.NewMockCache() + m.On("GetPGPKey", defaultKeyName).Return([]byte{}, false).Once() + m.On("SetPGPKey", defaultKeyName, []byte("test")).Once() + return m + }, + dirSetup: func(t *testing.T) string { + dir := t.TempDir() + err := os.WriteFile(filepath.Join(dir, defaultKeyName), []byte("test"), defaultKeyPermissions) + require.NoError(t, err) + return dir + }, + content: []byte("test"), + err: nil, + }, { + name: "found in dir with incorrect permissions", + cache: func() *cache.MockCache { + m := cache.NewMockCache() + m.On("GetPGPKey", defaultKeyName).Return([]byte{}, false).Once() + return m + }, + dirSetup: func(t *testing.T) string { + dir := t.TempDir() + err := os.WriteFile(filepath.Join(dir, defaultKeyName), []byte("test"), 0o0660) + require.NoError(t, err) + return dir + }, + content: nil, + err: ErrPGPPermissions, + }, { + name: "failed upstream request", + cache: func() *cache.MockCache { + m := cache.NewMockCache() + m.On("GetPGPKey", defaultKeyName).Return([]byte{}, false).Once() + return m + }, + dirSetup: func(t *testing.T) string { + dir := t.TempDir() + return dir + }, + upstreamStatus: 400, + content: nil, + err: ErrUpstreamStatus, + }, { + name: "upstream request succeeded", + cache: func() *cache.MockCache { + m := cache.NewMockCache() + m.On("GetPGPKey", defaultKeyName).Return([]byte{}, false).Once() + m.On("SetPGPKey", defaultKeyName, []byte("test")).Once() + return m + }, + dirSetup: func(t *testing.T) string { + dir := t.TempDir() + return dir + }, + upstreamStatus: 200, + content: []byte(`test`), + err: nil, + }} + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockCache := tc.cache() + dir := tc.dirSetup(t) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tc.upstreamStatus) + w.Write([]byte(`test`)) + })) + defer server.Close() + + pt := &PGPRetrieverT{ + cache: mockCache, + cfg: config.PGP{ + UpstreamURL: server.URL, + Dir: dir, + }, + } + + content, err := pt.getPGPKey(context.Background(), testlog.SetLogger(t)) + require.ErrorIs(t, err, tc.err) + require.Equal(t, tc.content, content) + mockCache.AssertExpectations(t) + }) + } +} diff --git a/internal/pkg/api/server_test.go b/internal/pkg/api/server_test.go index 0de84dd42..3e265755e 100644 --- a/internal/pkg/api/server_test.go +++ b/internal/pkg/api/server_test.go @@ -48,7 +48,7 @@ func Test_server_Run(t *testing.T) { et, err := NewEnrollerT(verCon, cfg, nil, c) require.NoError(t, err) - srv := NewServer(addr, cfg, ct, et, nil, nil, nil, nil, fbuild.Info{}, nil, nil, nil, nil) + srv := NewServer(addr, cfg, ct, et, nil, nil, nil, nil, fbuild.Info{}, nil, nil, nil, nil, nil) errCh := make(chan error) var wg sync.WaitGroup diff --git a/internal/pkg/config/config_test.go b/internal/pkg/config/config_test.go index 5232724b5..44b13c816 100644 --- a/internal/pkg/config/config_test.go +++ b/internal/pkg/config/config_test.go @@ -128,6 +128,10 @@ func TestConfig(t *testing.T) { Limits: generateServerLimits(12500), Bulk: defaultServerBulk(), GC: defaultServerGC(), + PGP: PGP{ + UpstreamURL: defaultPGPUpstreamURL, + Dir: filepath.Join(retrieveExecutableDir(), defaultPGPDirectoryName), + }, }, Cache: generateCache(12500), Monitor: Monitor{ diff --git a/internal/pkg/config/pgp.go b/internal/pkg/config/pgp.go index f8cbc6d31..f9e8f98f4 100644 --- a/internal/pkg/config/pgp.go +++ b/internal/pkg/config/pgp.go @@ -1,3 +1,7 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + package config import (