Skip to content

Commit

Permalink
Fix checks, add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
michel-laterman committed Sep 20, 2023
1 parent 19dc846 commit db5294f
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 11 deletions.
29 changes: 19 additions & 10 deletions internal/pkg/api/handlePGPRequest.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License;
// you may not use this file except in compliance with the Elastic License.

package api

import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"net/http"
Expand All @@ -16,11 +21,15 @@ import (
"go.elastic.co/apm/v2"
)

const defaultKeyName = "default.pgp"
const (
defaultKeyName = "default.pgp"
defaultKeyPermissions = 0o0600
)

var (
ErrTLSRequired = errors.New("api call requires a TLS connection")
ErrPGPPermissions = errors.New("pgp key permissions are not 0700")
ErrPGPPermissions = fmt.Errorf("pgp key permissions are not %#o", defaultKeyPermissions)
ErrUpstreamStatus = errors.New("upstream http server status error")
)

type PGPRetrieverT struct {
Expand All @@ -47,7 +56,6 @@ func (pt *PGPRetrieverT) handlePGPKey(zlog zerolog.Logger, w http.ResponseWriter
}
zlog = zlog.With().Str(LogEnrollAPIKeyID, key.ID).Logger()
ctx := zlog.WithContext(r.Context())
r = r.WithContext(ctx)

p, err := pt.getPGPKey(ctx, zlog)
if err != nil {
Expand All @@ -61,10 +69,8 @@ func (pt *PGPRetrieverT) handlePGPKey(zlog zerolog.Logger, w http.ResponseWriter
// getPGPKey will return the PGP key bytes
//
// First the local cache will be checked
//
// If it's not found in the cache, we attempt to read from disk
// If it's found we set the cache and return the bytes
//
// If it's not found on disk we attempt to retrieve the upstream key
// If that succeeds we set the cache then write to disk (with best effort).
func (pt *PGPRetrieverT) getPGPKey(ctx context.Context, zlog zerolog.Logger) ([]byte, error) {
Expand All @@ -81,7 +87,7 @@ func (pt *PGPRetrieverT) getPGPKey(ctx context.Context, zlog zerolog.Logger) ([]
}
p, err := pt.getPGPFromDir(ctx, key)

// sucessfully retrieved from disk
// successfully retrieved from disk
if err == nil {
pt.cache.SetPGPKey(key, p)
return p, nil
Expand All @@ -106,14 +112,14 @@ func (pt *PGPRetrieverT) getPGPKey(ctx context.Context, zlog zerolog.Logger) ([]
//
// Key contents are only returned if the key has valid permission bits.
func (pt *PGPRetrieverT) getPGPFromDir(ctx context.Context, key string) ([]byte, error) {
span, ctx := apm.StartSpan(ctx, "getPGPFromDir", "process")
span, _ := apm.StartSpan(ctx, "getPGPFromDir", "process")
defer span.End()

stat, err := os.Stat(filepath.Join(pt.cfg.Dir, key))
if err != nil {
return nil, err
}
if stat.Mode().Perm() != 0700 { // TODO determine what permission bits we want to check
if stat.Mode().Perm() != defaultKeyPermissions { // TODO determine what permission bits we want to check
return nil, ErrPGPPermissions
}
return os.ReadFile(filepath.Join(pt.cfg.Dir, key))
Expand All @@ -134,6 +140,9 @@ func (pt *PGPRetrieverT) getPGPFromUpstream(ctx context.Context) ([]byte, error)
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%w: %d", ErrUpstreamStatus, resp.StatusCode)
}
return io.ReadAll(resp.Body)
}

Expand All @@ -142,7 +151,7 @@ func (pt *PGPRetrieverT) getPGPFromUpstream(ctx context.Context) ([]byte, error)
// If the directory does not exist it will create it
// Otherwise it is treated as a best-effort attempt
func (pt *PGPRetrieverT) writeKeyToDir(ctx context.Context, zlog zerolog.Logger, key string, p []byte) {
span, ctx := apm.StartSpan(ctx, "writeKeyToDir", "process")
span, _ := apm.StartSpan(ctx, "writeKeyToDir", "process")
defer span.End()

_, err := os.Stat(pt.cfg.Dir)
Expand All @@ -158,7 +167,7 @@ func (pt *PGPRetrieverT) writeKeyToDir(ctx context.Context, zlog zerolog.Logger,
}
}

err = os.WriteFile(filepath.Join(pt.cfg.Dir, key), p, 0700)
err = os.WriteFile(filepath.Join(pt.cfg.Dir, key), p, defaultKeyPermissions)
if err != nil {
zlog.Error().Err(err).Str("path", filepath.Join(pt.cfg.Dir, key)).Msg("Unable to write file.")
return
Expand Down
127 changes: 127 additions & 0 deletions internal/pkg/api/handlePGPRequest_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License;
// you may not use this file except in compliance with the Elastic License.

package api

import (
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"

"github.com/elastic/fleet-server/v7/internal/pkg/config"
"github.com/elastic/fleet-server/v7/internal/pkg/testing/cache"
testlog "github.com/elastic/fleet-server/v7/internal/pkg/testing/log"
"github.com/stretchr/testify/require"
)

func Test_PGPRetrieverT_getPGPKey(t *testing.T) {
tests := []struct {
name string
cache func() *cache.MockCache
dirSetup func(t *testing.T) string
upstreamStatus int
content []byte
err error
}{{
name: "found in cache",
cache: func() *cache.MockCache {
m := cache.NewMockCache()
m.On("GetPGPKey", defaultKeyName).Return([]byte("test"), true).Once()
return m
},
dirSetup: func(t *testing.T) string {
return ""
},
content: []byte("test"),
err: nil,
}, {
name: "found in dir",
cache: func() *cache.MockCache {
m := cache.NewMockCache()
m.On("GetPGPKey", defaultKeyName).Return([]byte{}, false).Once()
m.On("SetPGPKey", defaultKeyName, []byte("test")).Once()
return m
},
dirSetup: func(t *testing.T) string {
dir := t.TempDir()
err := os.WriteFile(filepath.Join(dir, defaultKeyName), []byte("test"), defaultKeyPermissions)
require.NoError(t, err)
return dir
},
content: []byte("test"),
err: nil,
}, {
name: "found in dir with incorrect permissions",
cache: func() *cache.MockCache {
m := cache.NewMockCache()
m.On("GetPGPKey", defaultKeyName).Return([]byte{}, false).Once()
return m
},
dirSetup: func(t *testing.T) string {
dir := t.TempDir()
err := os.WriteFile(filepath.Join(dir, defaultKeyName), []byte("test"), 0o0660)

Check failure on line 66 in internal/pkg/api/handlePGPRequest_test.go

View workflow job for this annotation

GitHub Actions / lint (linux)

G306: Expect WriteFile permissions to be 0600 or less (gosec)
require.NoError(t, err)
return dir
},
content: nil,
err: ErrPGPPermissions,
}, {
name: "failed upstream request",
cache: func() *cache.MockCache {
m := cache.NewMockCache()
m.On("GetPGPKey", defaultKeyName).Return([]byte{}, false).Once()
return m
},
dirSetup: func(t *testing.T) string {
dir := t.TempDir()
return dir
},
upstreamStatus: 400,
content: nil,
err: ErrUpstreamStatus,
}, {
name: "upstream request succeeded",
cache: func() *cache.MockCache {
m := cache.NewMockCache()
m.On("GetPGPKey", defaultKeyName).Return([]byte{}, false).Once()
m.On("SetPGPKey", defaultKeyName, []byte("test")).Once()
return m
},
dirSetup: func(t *testing.T) string {
dir := t.TempDir()
return dir
},
upstreamStatus: 200,
content: []byte(`test`),
err: nil,
}}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mockCache := tc.cache()
dir := tc.dirSetup(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tc.upstreamStatus)
w.Write([]byte(`test`))

Check failure on line 109 in internal/pkg/api/handlePGPRequest_test.go

View workflow job for this annotation

GitHub Actions / lint (linux)

Error return value of `w.Write` is not checked (errcheck)
}))
defer server.Close()

pt := &PGPRetrieverT{
cache: mockCache,
cfg: config.PGP{
UpstreamURL: server.URL,
Dir: dir,
},
}

content, err := pt.getPGPKey(context.Background(), testlog.SetLogger(t))
require.ErrorIs(t, err, tc.err)
require.Equal(t, tc.content, content)
mockCache.AssertExpectations(t)
})
}
}
2 changes: 1 addition & 1 deletion internal/pkg/api/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func Test_server_Run(t *testing.T) {
et, err := NewEnrollerT(verCon, cfg, nil, c)
require.NoError(t, err)

srv := NewServer(addr, cfg, ct, et, nil, nil, nil, nil, fbuild.Info{}, nil, nil, nil, nil)
srv := NewServer(addr, cfg, ct, et, nil, nil, nil, nil, fbuild.Info{}, nil, nil, nil, nil, nil)
errCh := make(chan error)

var wg sync.WaitGroup
Expand Down
4 changes: 4 additions & 0 deletions internal/pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ func TestConfig(t *testing.T) {
Limits: generateServerLimits(12500),
Bulk: defaultServerBulk(),
GC: defaultServerGC(),
PGP: PGP{
UpstreamURL: defaultPGPUpstreamURL,
Dir: filepath.Join(retrieveExecutableDir(), defaultPGPDirectoryName),
},
},
Cache: generateCache(12500),
Monitor: Monitor{
Expand Down
4 changes: 4 additions & 0 deletions internal/pkg/config/pgp.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License;
// you may not use this file except in compliance with the Elastic License.

package config

import (
Expand Down

0 comments on commit db5294f

Please sign in to comment.