From df9a48d3118cc61b5f94df313c65c4a475d7ccce Mon Sep 17 00:00:00 2001 From: Jeff Haynie Date: Sun, 2 Feb 2025 15:26:03 -0600 Subject: [PATCH] first commit --- .github/workflows/go.yml | 26 +++ LICENSE | 19 ++ Makefile | 15 ++ README.md | 21 ++ cache/cache.go | 28 +++ cache/inmemory.go | 125 +++++++++++ cache/inmemory_test.go | 85 ++++++++ command/command.go | 279 ++++++++++++++++++++++++ command/nonwindows.go | 15 ++ command/windows.go | 15 ++ compress/gunzip.go | 87 ++++++++ compress/gunzip_test.go | 47 ++++ dns/dns.go | 197 +++++++++++++++++ dns/dns_test.go | 174 +++++++++++++++ go.mod | 29 +++ go.sum | 79 +++++++ logger/console.go | 210 ++++++++++++++++++ logger/console_test.go | 295 +++++++++++++++++++++++++ logger/gcloud.go | 11 + logger/init.go | 69 ++++++ logger/json.go | 189 +++++++++++++++++ logger/logger_test.go | 159 ++++++++++++++ logger/multi.go | 57 +++++ logger/test.go | 75 +++++++ logger/util.go | 5 + renovate.json | 17 ++ request/request.go | 410 +++++++++++++++++++++++++++++++++++ request/request_test.go | 261 +++++++++++++++++++++++ slice/slice.go | 57 +++++ slice/slice_test.go | 46 ++++ string/hash.go | 59 +++++ string/hash_test.go | 49 +++++ string/http.go | 20 ++ string/http_test.go | 45 ++++ string/interpolate.go | 55 +++++ string/interpolate_test.go | 39 ++++ string/json.go | 20 ++ string/mask.go | 93 ++++++++ string/mask_test.go | 97 +++++++++ string/random.go | 58 +++++ string/random_test.go | 27 +++ string/sha.go | 16 ++ string/sha_test.go | 47 ++++ string/string.go | 29 +++ string/string_test.go | 27 +++ string/util.go | 11 + string/util_test.go | 27 +++ sys/docker.go | 22 ++ sys/errors.go | 35 +++ sys/io.go | 425 +++++++++++++++++++++++++++++++++++++ sys/io_test.go | 45 ++++ sys/ip.go | 44 ++++ sys/net.go | 25 +++ sys/shutdown.go | 14 ++ 54 files changed, 4431 insertions(+) create mode 100644 .github/workflows/go.yml create mode 100644 LICENSE create mode 100644 Makefile create mode 100644 README.md create mode 100644 cache/cache.go create mode 100644 cache/inmemory.go create mode 100644 cache/inmemory_test.go create mode 100644 command/command.go create mode 100644 command/nonwindows.go create mode 100644 command/windows.go create mode 100644 compress/gunzip.go create mode 100644 compress/gunzip_test.go create mode 100644 dns/dns.go create mode 100644 dns/dns_test.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 logger/console.go create mode 100644 logger/console_test.go create mode 100644 logger/gcloud.go create mode 100644 logger/init.go create mode 100644 logger/json.go create mode 100644 logger/logger_test.go create mode 100644 logger/multi.go create mode 100644 logger/test.go create mode 100644 logger/util.go create mode 100644 renovate.json create mode 100644 request/request.go create mode 100644 request/request_test.go create mode 100644 slice/slice.go create mode 100644 slice/slice_test.go create mode 100644 string/hash.go create mode 100644 string/hash_test.go create mode 100644 string/http.go create mode 100644 string/http_test.go create mode 100644 string/interpolate.go create mode 100644 string/interpolate_test.go create mode 100644 string/json.go create mode 100644 string/mask.go create mode 100644 string/mask_test.go create mode 100644 string/random.go create mode 100644 string/random_test.go create mode 100644 string/sha.go create mode 100644 string/sha_test.go create mode 100644 string/string.go create mode 100644 string/string_test.go create mode 100644 string/util.go create mode 100644 string/util_test.go create mode 
100644 sys/docker.go create mode 100644 sys/errors.go create mode 100644 sys/io.go create mode 100644 sys/io_test.go create mode 100644 sys/ip.go create mode 100644 sys/net.go create mode 100644 sys/shutdown.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 0000000..0bd44af --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,26 @@ +name: Go +on: + push: + branches: [main] + pull_request: + branches: [main] +concurrency: + group: ${{ github.ref }} + cancel-in-progress: true +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "stable" + - name: Check vulnerabilities + run: | + go install golang.org/x/vuln/cmd/govulncheck@latest + govulncheck ./... + - name: Build + run: go build -v ./... + - name: Test + run: go test -v ./... diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..68d3e23 --- /dev/null +++ b/LICENSE @@ -0,0 +1,19 @@ +Copyright 2025 Agentuity, Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +--- + +Some code was originally adapted from and relicensed under the same. https://github.com/shopmonkeyus/go-common + +Copyright 2023-2024 Shopmonkey, Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..a21bea0 --- /dev/null +++ b/Makefile @@ -0,0 +1,15 @@ +.PHONY: all lint test vet tidy + +all: test + +lint: + @go fmt ./... + +vet: + @go vet ./... + +tidy: + @go mod tidy + +test: tidy lint vet + @go test -v -count=1 ./... 
diff --git a/README.md b/README.md new file mode 100644 index 0000000..44144a1 --- /dev/null +++ b/README.md @@ -0,0 +1,21 @@ + + +# Overview + +This repository contains the public shared utility code for Agentuity as a Go module. + +## Requirements + +You will need [Go](https://go.dev/dl/) version 1.23 or later to use this package. + +## Usage + +Import this module in your Go code with the following: + +```go +import "github.com/agentuity/go-common" +``` + +## License + +All files in this repository are licensed under the [MIT license](https://opensource.org/licenses/MIT). See the [LICENSE](./LICENSE) file for details. diff --git a/cache/cache.go b/cache/cache.go new file mode 100644 index 0000000..2669210 --- /dev/null +++ b/cache/cache.go @@ -0,0 +1,28 @@ +package cache + +import ( + "time" +) + +type Cache interface { + // Get returns a value from the cache: true if the key was found, the value itself (or nil), and any error. + Get(key string) (bool, any, error) + + // Set a value into the cache with an expiration. + Set(key string, val any, expires time.Duration) error + + // Hits returns the number of times a key has been accessed. + Hits(key string) (bool, int) + + // Expire will expire a key in the cache. + Expire(key string) (bool, error) + + // Close will shut down the cache. + Close() error +} + +type value struct { + object any + expires time.Time + hits int +} diff --git a/cache/inmemory.go b/cache/inmemory.go new file mode 100644 index 0000000..0a3316a --- /dev/null +++ b/cache/inmemory.go @@ -0,0 +1,125 @@ +package cache + +import ( + "context" + "sync" + "time" +) + +type inMemoryCache struct { + ctx context.Context + cancel context.CancelFunc + cache map[string]*value + mutex sync.Mutex + waitGroup sync.WaitGroup + once sync.Once + expiryCheck time.Duration +} + +var _ Cache = (*inMemoryCache)(nil) + +func (c *inMemoryCache) Get(key string) (bool, any, error) { + c.mutex.Lock() + val, ok := c.cache[key] + if ok { + val.hits++ + } + c.mutex.Unlock() + if ok { + if val.expires.Before(time.Now()) { + c.mutex.Lock() + delete(c.cache, key) + c.mutex.Unlock() + return false, nil, nil + } + return true, val.object, nil + } + return false, nil, nil } + +// Hits returns the number of times a key has been accessed.
+func (c *inMemoryCache) Hits(key string) (bool, int) { + c.mutex.Lock() + var val int + var found bool + if v, ok := c.cache[key]; ok { + val = v.hits + found = true + } + c.mutex.Unlock() + return found, val +} + +func (c *inMemoryCache) Set(key string, val any, expires time.Duration) error { + c.mutex.Lock() + if v, ok := c.cache[key]; ok { + v.hits = 0 + v.expires = time.Now().Add(expires) + v.object = val + } else { + c.cache[key] = &value{val, time.Now().Add(expires), 0} + } + c.mutex.Unlock() + return nil +} + +func (c *inMemoryCache) Expire(key string) (bool, error) { + c.mutex.Lock() + _, ok := c.cache[key] + if ok { + delete(c.cache, key) + } + c.mutex.Unlock() + return ok, nil +} + +func (c *inMemoryCache) Close() error { + c.once.Do(func() { + c.cancel() + c.waitGroup.Wait() + }) + return nil +} + +func (c *inMemoryCache) run() { + c.waitGroup.Add(1) + timer := time.NewTicker(c.expiryCheck) + defer func() { + timer.Stop() + c.waitGroup.Done() + }() + for { + select { + case <-c.ctx.Done(): + return + case <-timer.C: + now := time.Now() + c.mutex.Lock() + var expired []string + for key, val := range c.cache { + if val.expires.Before(now) { + expired = append(expired, key) + } + } + if len(expired) > 0 { + for _, key := range expired { + delete(c.cache, key) + } + } + c.mutex.Unlock() + } + } +} + +// New returns a new Cache implementation +func NewInMemory(parent context.Context, expiryCheck time.Duration) Cache { + ctx, cancel := context.WithCancel(parent) + c := &inMemoryCache{ + ctx: ctx, + cancel: cancel, + cache: make(map[string]*value), + expiryCheck: expiryCheck, + } + go c.run() + return c +} diff --git a/cache/inmemory_test.go b/cache/inmemory_test.go new file mode 100644 index 0000000..3ce91bd --- /dev/null +++ b/cache/inmemory_test.go @@ -0,0 +1,85 @@ +package cache + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestSimpleCache(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cache := NewInMemory(ctx, time.Second) + cache.Close() + cancel() +} + +func TestSetGetCache(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cache := NewInMemory(ctx, time.Minute) + found, val, err := cache.Get("test") + assert.NoError(t, err) + assert.False(t, found) + assert.Nil(t, val) + assert.NoError(t, cache.Set("test", "value", time.Millisecond*10)) + found, val, err = cache.Get("test") + assert.NoError(t, err) + assert.True(t, found) + assert.Equal(t, "value", val) + ok, hits := cache.Hits("test") + assert.True(t, ok) + assert.Equal(t, 1, hits) + time.Sleep(time.Millisecond * 11) + found, val, err = cache.Get("test") + assert.NoError(t, err) + assert.False(t, found) + assert.Nil(t, val) + ok, hits = cache.Hits("test") + assert.False(t, ok) + assert.Equal(t, 0, hits) + cache.Close() + cancel() +} + +func TestCacheBackgroundExpire(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cache := NewInMemory(ctx, time.Millisecond*100) + found, val, err := cache.Get("test") + assert.NoError(t, err) + assert.False(t, found) + assert.Nil(t, val) + assert.NoError(t, cache.Set("test", "value", 90*time.Millisecond)) + found, val, err = cache.Get("test") + assert.NoError(t, err) + assert.True(t, found) + assert.Equal(t, "value", val) + time.Sleep(time.Millisecond * 200) + c := cache.(*inMemoryCache) + c.mutex.Lock() + defer c.mutex.Unlock() + assert.Empty(t, c.cache) + cache.Close() + cancel() +} + +func TestCacheExpire(t *testing.T) { + ctx, cancel := 
context.WithCancel(context.Background()) + cache := NewInMemory(ctx, time.Millisecond*100) + found, val, err := cache.Get("test") + assert.NoError(t, err) + assert.False(t, found) + assert.Nil(t, val) + assert.NoError(t, cache.Set("test", "value", 90*time.Millisecond)) + found, val, err = cache.Get("test") + assert.NoError(t, err) + assert.True(t, found) + assert.Equal(t, "value", val) + cache.Expire("test") + c := cache.(*inMemoryCache) + c.mutex.Lock() + defer c.mutex.Unlock() + assert.Empty(t, c.cache) + cache.Close() + cancel() +} diff --git a/command/command.go b/command/command.go new file mode 100644 index 0000000..eaa12cc --- /dev/null +++ b/command/command.go @@ -0,0 +1,279 @@ +package command + +import ( + "context" + "fmt" + "io" + "os" + "os/exec" + "os/signal" + "path/filepath" + "regexp" + "strings" + "time" + + "github.com/agentuity/go-common/compress" + "github.com/agentuity/go-common/logger" +) + +func parseLastLines(fn string, n int) (string, error) { + file, err := os.Open(fn) + if err != nil { + return "", err + } + defer file.Close() + + stats, statsErr := file.Stat() + if statsErr != nil { + return "", statsErr + } + + buf := make([]byte, stats.Size()) + _, err = file.ReadAt(buf, 0) + if err != nil { + return "", err + } + + lines := strings.Split(string(buf), "\n") + totalLines := len(lines) + + start := totalLines - n + if start < 0 { + start = 0 + } + + lastLines := lines[start:] + + return strings.Join(lastLines, "\n"), nil +} + +type Uploader func(ctx context.Context, log logger.Logger, file string) (string, error) + +type ProcessCallback func(process *os.Process) + +type ForkArgs struct { + // required + Log logger.Logger + Command string + + // optional + Context context.Context + Args []string + Cwd string + BaseDir string // the base directory to upload if different from Dir + Dir string // the directory to store logs in + LogFilenameLabel string + SaveLogs bool + Env []string + SkipBundleOnSuccess bool + WriteToStd bool + ForwardInterrupt bool + LogFileSink bool + ProcessCallback ProcessCallback +} + +type ForkResult struct { + Duration time.Duration + LastErrorLines string + ProcessState *os.ProcessState + LogFileBundle string +} + +func (r *ForkResult) String() string { + pState := "" + if r.ProcessState != nil { + pState = r.ProcessState.String() + } + return fmt.Sprintf("ProcessState: %s, Duration: %s, LogFileBundle: %s", pState, r.Duration, r.LogFileBundle) +} + +var looksLikeJSONRegex = regexp.MustCompile(`^\s*[\[\{]`) + +func looksLikeJSON(s string) bool { + return looksLikeJSONRegex.MatchString(s) +} + +func formatCmd(cmdargs []string) string { + var args []string + for _, arg := range cmdargs { + if looksLikeJSON(arg) { + // quote json so i can paste it out of the logs and into my terminal 😤 + args = append(args, "'"+arg+"'") + } else { + args = append(args, arg) + } + } + return fmt.Sprintf("%s %s\n", os.Args[0], strings.Join(args, " ")) +} + +// GetExecutable returns the path to the current executable. +func GetExecutable() string { + ex, err := os.Executable() + if err != nil { + ex = os.Args[0] + } + return ex +} + +// Fork will run a command on the current executable.
+func Fork(args ForkArgs) (*ForkResult, error) { + started := time.Now() + if args.Log == nil { + args.Log = logger.NewConsoleLogger(logger.LevelInfo) + } + dir := args.Dir + executable := GetExecutable() + if dir == "" { + tmp, err := os.MkdirTemp("", filepath.Base(executable)+"-") + if err != nil { + return nil, fmt.Errorf("error creating temp dir: %w", err) + } + defer os.RemoveAll(tmp) + dir = tmp + } else { + if _, err := os.Stat(dir); err != nil && os.IsNotExist(err) { + os.MkdirAll(dir, 0755) + } + } + cmdargs := append([]string{args.Command}, args.Args...) + + args.Log.Trace("executing: %s", formatCmd(cmdargs)) + + ctx := args.Context + if ctx == nil { + ctx = context.Background() + } + + label := args.LogFilenameLabel + if label == "" { + label = "job-" + time.Now().Format("20060102-150405") + } + stderrFn := filepath.Join(dir, label+"_stderr.txt") + stdoutFn := filepath.Join(dir, label+"_stdout.txt") + + cmd := exec.CommandContext(ctx, executable, cmdargs...) + if args.Cwd != "" { + cmd.Dir = args.Cwd + } else { + cwd, err := os.Getwd() + if err != nil { + return nil, fmt.Errorf("error getting current working directory: %w", err) + } + cmd.Dir = cwd + } + if len(args.Env) > 0 { + cmd.Env = args.Env + args.Log.Trace("using custom env") + } else { + cmd.Env = os.Environ() + args.Log.Trace("using default env") + } + + var err error + var stderr, stdout *os.File + + if args.SaveLogs { + stderr, err = os.Create(stderrFn) + if err != nil { + return nil, fmt.Errorf("error creating temporary stderr log file: %w", err) + } + defer stderr.Close() + + if !args.LogFileSink { + stdout, err = os.Create(stdoutFn) + if err != nil { + return nil, fmt.Errorf("error creating temporary stdout log file: %w", err) + } + defer stdout.Close() + } + if args.WriteToStd { + cmd.Stdout = os.Stdout + cmd.Stderr = io.MultiWriter(stderr, os.Stderr) + } else { + cmd.Stderr = stderr + cmd.Stdout = stdout + } + stdout.WriteString(fmt.Sprintf("executing: %s\n", formatCmd(cmdargs))) + } else { + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + } + + if args.WriteToStd { + cmd.Stdin = os.Stdin + } else { + cmd.Stdin = nil + } + + setCommandProcessGroup(cmd) + + var result ForkResult + var resultError error + sigch := make(chan os.Signal, 1) + cctx, cancel := context.WithCancel(ctx) + defer cancel() + if args.ForwardInterrupt { + signal.Notify(sigch, os.Interrupt) + go func() { + select { + case <-cctx.Done(): + return + case <-sigch: + args.Log.Trace("forwarding interrupt to child process") + cmd.Process.Signal(os.Interrupt) + } + }() + } + + // notify the callback with the process once its running + if args.ProcessCallback != nil { + go func() { + for { + select { + case <-cctx.Done(): + return + case <-time.After(time.Millisecond * 10): + if cmd.Process != nil && cmd.Process.Pid > 0 { + args.ProcessCallback(cmd.Process) + return + } + } + } + }() + } + + if err := cmd.Run(); err != nil { + if args.SaveLogs { + stderr.Close() + stdout.Close() + lines, _ := parseLastLines(stderrFn, 10) + if lines == "" { + lines, _ = parseLastLines(stdoutFn, 10) + } + result.LastErrorLines = lines + } + resultError = err + } else if args.SaveLogs { + stderr.Close() + stdout.Close() + } + + result.ProcessState = cmd.ProcessState + result.Duration = time.Since(started) + + if args.SaveLogs { + if !args.SkipBundleOnSuccess || resultError != nil { + baseDir := dir + if args.BaseDir != "" { + baseDir = args.BaseDir + } + targz, err := compress.TarGzipDir(baseDir) + if err != nil { + return nil, fmt.Errorf("error compressing logs: %w", 
err) + } + result.LogFileBundle = targz + } + } + + return &result, resultError +} diff --git a/command/nonwindows.go b/command/nonwindows.go new file mode 100644 index 0000000..b8365a0 --- /dev/null +++ b/command/nonwindows.go @@ -0,0 +1,15 @@ +//go:build !windows +// +build !windows + +package command + +import ( + "os/exec" + "syscall" +) + +func setCommandProcessGroup(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + } +} diff --git a/command/windows.go b/command/windows.go new file mode 100644 index 0000000..ac85d82 --- /dev/null +++ b/command/windows.go @@ -0,0 +1,15 @@ +//go:build windows +// +build windows + +package command + +import ( + "os/exec" + "syscall" +) + +func setCommandProcessGroup(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP, + } +} diff --git a/compress/gunzip.go b/compress/gunzip.go new file mode 100644 index 0000000..7053d2b --- /dev/null +++ b/compress/gunzip.go @@ -0,0 +1,87 @@ +package compress + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "fmt" + "io" + "os" + "path/filepath" + "strings" +) + +// Gunzip will unzip data and return buffer inline +func Gunzip(data []byte) ([]byte, error) { + r, err := gzip.NewReader(bytes.NewReader(data)) + if err != nil { + return nil, err + } + defer r.Close() + + var resB bytes.Buffer + _, err = resB.ReadFrom(r) + if err != nil { + return nil, err + } + + return append([]byte(nil), resB.Bytes()...), nil +} + +func TarGz(srcDir string, outfile *os.File) error { + zr := gzip.NewWriter(outfile) + tw := tar.NewWriter(zr) + + baseDir := filepath.Base(srcDir) + // walk through every file in the folder + filepath.Walk(srcDir, func(file string, fi os.FileInfo, _ error) error { + // generate tar header + header, err := tar.FileInfoHeader(fi, file) + if err != nil { + return err + } + + header.Name = baseDir + strings.Replace(filepath.ToSlash(file), srcDir, "", -1) + + // write header + if err := tw.WriteHeader(header); err != nil { + return err + } + // if not a dir, write file content + if !fi.IsDir() { + data, err := os.Open(file) + if err != nil { + return err + } + if _, err := io.Copy(tw, data); err != nil { + return err + } + } + return nil + }) + + // produce tar + if err := tw.Close(); err != nil { + return err + } + // produce gzip + if err := zr.Close(); err != nil { + return err + } + + return nil +} + +// TarGzipDir will tar and gzip a directory and return the path to the file. You must delete the file when done. 
+func TarGzipDir(srcDir string) (string, error) { + tmpfn, err := os.CreateTemp("", "*.tar.gz") + if err != nil { + return "", fmt.Errorf("tmp: %w", err) + } + defer tmpfn.Close() + + if err := TarGz(srcDir, tmpfn); err != nil { + return "", err + } + return tmpfn.Name(), nil +} diff --git a/compress/gunzip_test.go b/compress/gunzip_test.go new file mode 100644 index 0000000..7fa93b6 --- /dev/null +++ b/compress/gunzip_test.go @@ -0,0 +1,47 @@ +package compress + +import ( + "bytes" + "testing" +) + +func TestGunzip(t *testing.T) { + tests := []struct { + name string + data []byte + want []byte + wantErr bool + }{ + { + name: "empty input", + data: []byte{}, + want: []byte{}, + wantErr: true, + }, + { + name: "valid gzip data", + data: []byte{31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 242, 72, 205, 201, 201, 87, 8, 207, 47, 202, 73, 1, 4, 0, 0, 255, 255, 86, 177, 23, 74, 11, 0, 0, 0}, + want: []byte("Hello World"), + wantErr: false, + }, + { + name: "invalid gzip data", + data: []byte{1, 2, 3, 4}, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Gunzip(tt.data) + if (err != nil) != tt.wantErr { + t.Errorf("Gunzip() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !bytes.Equal(got, tt.want) { + t.Errorf("Gunzip() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/dns/dns.go b/dns/dns.go new file mode 100644 index 0000000..9a90da1 --- /dev/null +++ b/dns/dns.go @@ -0,0 +1,197 @@ +package dns + +import ( + "context" + "encoding/json" + "fmt" + "math/rand" + "net" + "net/http" + "net/url" + "regexp" + "time" + + "github.com/agentuity/go-common/cache" +) + +var ErrInvalidIP = fmt.Errorf("invalid ip address resolved for hostname") + +type DNS interface { + // Lookup performs a DNS lookup for the given hostname and returns a valid IP address for the A record. + Lookup(ctx context.Context, hostname string) (bool, *net.IP, error) +} + +type RecordType uint8 + +const ( + A RecordType = 1 + CNAME RecordType = 5 +) + +type StatusType uint8 + +const ( + NoError StatusType = 0 + FormErr StatusType = 1 + ServFail StatusType = 2 + NXDomain StatusType = 3 + Refused StatusType = 5 + NotAuth StatusType = 9 + NotZone StatusType = 10 +) + +func (s StatusType) String() string { + switch s { + case NoError: + return "Success" + case FormErr: + return "Format Error" + case ServFail: + return "Server Fail" + case NXDomain: + return "Non-Existent Domain" + case Refused: + return "Query Refused" + case NotAuth: + return "Server Not Authoritative for zone" + case NotZone: + return "Name not contained in zone" + default: + return "Unknown DNS error" + } +} + +type Result struct { + Status StatusType `json:"Status"` + Answer []Answer `json:"Answer"` +} + +type Answer struct { + Name string `json:"name"` + Type RecordType `json:"type"` + TTL uint `json:"ttl"` + Data string `json:"data"` +} + +type dnsConfig struct { + FailIfLocal bool +} + +type Dns struct { + cache cache.Cache + isLocal bool +} + +var _ DNS = (*Dns)(nil) + +var ipv4 = regexp.MustCompile(`^(((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)(\.|$)){4})`) +var magicIpAddress = "169.254.169.254" + +func isPrivateIP(ip string) bool { + ipAddress := net.ParseIP(ip) + return ipAddress.IsPrivate() || ipAddress.IsLoopback() +} + +// Lookup performs a DNS lookup for the given hostname and returns a valid IP address for the A record. 
+func (d *Dns) Lookup(ctx context.Context, hostname string) (bool, *net.IP, error) { + if (hostname == "localhost" || hostname == "127.0.0.1") && d.isLocal { + return true, &net.IP{127, 0, 0, 1}, nil + } + if isPrivateIP(hostname) && !d.isLocal { + return false, nil, ErrInvalidIP + } + if hostname == magicIpAddress { + return false, nil, ErrInvalidIP + } + if ipv4.MatchString(hostname) { + ip := net.ParseIP(hostname) + if ip == nil { + return false, nil, fmt.Errorf("failed to parse ip address: %s", hostname) + } + return true, &ip, nil + } + cacheKey := fmt.Sprintf("dns:%s", hostname) + ok, val, _ := d.cache.Get(cacheKey) + if ok { + ip, ok := val.([]net.IP) + if ok { + // only 1 ip, return it + if len(ip) == 1 { + return true, &ip[0], nil + } + // more than 1 ip, return a random one + i := rand.Int31n(int32(len(ip))) + return true, &ip[i], nil + } + } + c, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(c, "GET", "https://cloudflare-dns.com/dns-query?name="+url.QueryEscape(hostname), nil) + if err != nil { + return false, nil, err + } + req.Header.Set("accept", "application/dns-json") + req.Header.Set("user-agent", "Shopmonkey (+https://shopmonkey.io)") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return false, nil, err + } + if resp.StatusCode != http.StatusOK { + return false, nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + var res Result + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return false, nil, fmt.Errorf("failed to decode dns json response: %w", err) + } + if res.Status != NoError { + return false, nil, fmt.Errorf("dns lookup failed: %s", res.Status) + } + var ips []net.IP + var minTTL uint + for _, a := range res.Answer { + if a.Type == A { + if minTTL == 0 || a.TTL < minTTL { + minTTL = a.TTL + } + ip := net.ParseIP(a.Data) + if ip == nil { + return false, nil, fmt.Errorf("failed to parse ip address: %s", a.Data) + } + ips = append(ips, ip) + } + } + if len(ips) == 0 { + return false, nil, fmt.Errorf("no A records found for %s", hostname) + } + if (ips[0].IsPrivate() || ips[0].IsLoopback()) && !d.isLocal { + return false, nil, ErrInvalidIP + } + expires := time.Duration(minTTL) * time.Second + if expires > time.Hour*24 { + expires = time.Hour * 24 + } + d.cache.Set(cacheKey, ips, expires) + return true, &ips[0], nil +} + +// New creates a new DNS caching resolver. +func New(cache cache.Cache, opts ...WithConfig) *Dns { + var config dnsConfig + for _, opt := range opts { + opt(&config) + } + val := &Dns{ + cache: cache, + isLocal: !config.FailIfLocal, + } + return val +} + +type WithConfig func(config *dnsConfig) + +// WithFailIfLocal will cause the DNS resolver to fail if the hostname is a local hostname. 
+func WithFailIfLocal() WithConfig { + return func(config *dnsConfig) { + config.FailIfLocal = true + } +} diff --git a/dns/dns_test.go b/dns/dns_test.go new file mode 100644 index 0000000..725e583 --- /dev/null +++ b/dns/dns_test.go @@ -0,0 +1,174 @@ +package dns + +import ( + "context" + "testing" + "time" + + "github.com/agentuity/go-common/cache" + "github.com/stretchr/testify/assert" +) + +func TestDNSIsValidAndCached(t *testing.T) { + c := cache.NewInMemory(context.Background(), time.Second) + defer c.Close() + d := New(c) + ok, ip, err := d.Lookup(context.Background(), "app.shopmonkey.cloud") + assert.NoError(t, err) + assert.True(t, ok) + assert.NotNil(t, ip) + ok, count := c.Hits("dns:app.shopmonkey.cloud") + assert.True(t, ok) + assert.Equal(t, 0, count) + + ok, ip, err = d.Lookup(context.Background(), "app.shopmonkey.cloud") + assert.NoError(t, err) + assert.True(t, ok) + assert.NotNil(t, ip) + ok, count = c.Hits("dns:app.shopmonkey.cloud") + assert.True(t, ok) + assert.Equal(t, 1, count) +} + +func TestDNSDomainIsInvalid(t *testing.T) { + c := cache.NewInMemory(context.Background(), time.Second) + defer c.Close() + d := New(c) + ok, ip, err := d.Lookup(context.Background(), "adasf123sdasdxc.dsadasdshopmonkey.cloud") + assert.Error(t, err, "dns lookup failed: Non-Existent Domain") + assert.False(t, ok) + assert.Nil(t, ip) + ok, count := c.Hits("dns:adasf123sdasdxc.dsadasdshopmonkey.cloud") + assert.False(t, ok) + assert.Equal(t, 0, count) +} + +func TestDNSLocalHost(t *testing.T) { + c := cache.NewInMemory(context.Background(), time.Second) + defer c.Close() + d := New(c) + ok, ip, err := d.Lookup(context.Background(), "localhost") + assert.NoError(t, err) + assert.True(t, ok) + assert.NotNil(t, ip) + ok, count := c.Hits("dns:localhost") + assert.False(t, ok) + assert.Equal(t, 0, count) + + ok, ip, err = d.Lookup(context.Background(), "localhost") + assert.NoError(t, err) + assert.True(t, ok) + assert.NotNil(t, ip) + ok, count = c.Hits("dns:localhost") + assert.False(t, ok) + assert.Equal(t, 0, count) +} + +func TestDNS127(t *testing.T) { + c := cache.NewInMemory(context.Background(), time.Second) + defer c.Close() + d := New(c) + ok, ip, err := d.Lookup(context.Background(), "127.0.0.1") + assert.NoError(t, err) + assert.True(t, ok) + assert.NotNil(t, ip) + ok, count := c.Hits("dns:127.0.0.1") + assert.False(t, ok) + assert.Equal(t, 0, count) + + ok, ip, err = d.Lookup(context.Background(), "127.0.0.1") + assert.NoError(t, err) + assert.True(t, ok) + assert.NotNil(t, ip) + ok, count = c.Hits("dns:127.0.0.1") + assert.False(t, ok) + assert.Equal(t, 0, count) +} + +func TestDNSIPAddressSkipped(t *testing.T) { + c := cache.NewInMemory(context.Background(), time.Second) + defer c.Close() + d := New(c) + ok, ip, err := d.Lookup(context.Background(), "81.0.0.1") + assert.NoError(t, err) + assert.True(t, ok) + assert.NotNil(t, ip) + ok, count := c.Hits("dns:81.0.0.1") + assert.False(t, ok) + assert.Equal(t, 0, count) + + ok, ip, err = d.Lookup(context.Background(), "81.0.0.1") + assert.NoError(t, err) + assert.True(t, ok) + assert.NotNil(t, ip) + ok, count = c.Hits("dns:81.0.0.1") + assert.False(t, ok) + assert.Equal(t, 0, count) +} + +func TestDNSPrivateIPSkipped(t *testing.T) { + c := cache.NewInMemory(context.Background(), time.Second) + defer c.Close() + d := New(c) + ok, ip, err := d.Lookup(context.Background(), "10.8.0.1") + assert.NoError(t, err) + assert.True(t, ok) + assert.NotNil(t, ip) + ok, count := c.Hits("dns:81.0.0.1") + assert.False(t, ok) + assert.Equal(t, 0, count) + 
+ ok, ip, err = d.Lookup(context.Background(), "81.0.0.1") + assert.NoError(t, err) + assert.True(t, ok) + assert.NotNil(t, ip) + ok, count = c.Hits("dns:81.0.0.1") + assert.False(t, ok) + assert.Equal(t, 0, count) +} + +func TestDNSTest(t *testing.T) { + c := cache.NewInMemory(context.Background(), time.Second) + defer c.Close() + d := New(c, WithFailIfLocal()) + ok, ip, err := d.Lookup(context.Background(), "customer1.app.localhost.my.company.127.0.0.1.nip.io") + assert.Error(t, err, ErrInvalidIP) + assert.False(t, ok) + assert.Nil(t, ip) +} + +func TestInvalidDNSEntries(t *testing.T) { + tests := []struct { + name string + hostname string + }{ + {"EmptyHostname", ""}, + {"InvalidCharacters", "invalid!hostname"}, + {"TooLongHostname", "this.is.a.very.long.hostname.that.exceeds.the.maximum.length.allowed.by.the.dns.specification.and.should.therefore.fail.validation"}, + {"HostnameWithSpaces", "hostname with spaces"}, + {"HostnameWithUnderscore", "hostname_with_underscore"}, + {"Unresolved DNS", "bugbounty.dod.network"}, + {"Invalid Hostname", "0xA9.0xFE.0xA9.0xFE"}, + {"Invalid IP Address", "169.254.169.254"}, + {"Invalid IP Address From DNS", "169.254.169.254.nip.io"}, + {"Local IP v6", "[::1]"}, + {"Invalid IP v6", "[::ffff:7f00:1]"}, + {"Invalid Virtual DNS", "localtest.me"}, + {"Invalid Virtual DNS to Private", "spoofed.burpcollaborator.net"}, + {"Docker Host Internal", "host.docker.internal"}, + } + + c := cache.NewInMemory(context.Background(), time.Second) + defer c.Close() + d := New(c) + d.isLocal = false + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ok, ip, err := d.Lookup(context.Background(), tt.hostname) + assert.Error(t, err, ErrInvalidIP) + assert.False(t, ok) + assert.Nil(t, ip) + }) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..d3513a8 --- /dev/null +++ b/go.mod @@ -0,0 +1,29 @@ +module github.com/agentuity/go-common + +go 1.23 + +require ( + github.com/cespare/xxhash/v2 v2.3.0 + github.com/cockroachdb/errors v1.11.3 + github.com/mattn/go-isatty v0.0.20 + github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 + github.com/stretchr/testify v1.9.0 + golang.org/x/sync v0.8.0 +) + +require ( + github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b // indirect + github.com/cockroachdb/redact v1.1.5 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/getsentry/sentry-go v0.27.0 // indirect + github.com/gogo/protobuf v1.3.2 // indirect + github.com/google/go-cmp v0.6.0 // indirect + github.com/kr/pretty v0.3.1 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.9.0 // indirect + golang.org/x/sys v0.26.0 // indirect + golang.org/x/text v0.19.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..95448a2 --- /dev/null +++ b/go.sum @@ -0,0 +1,79 @@ +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cockroachdb/errors v1.11.3 h1:5bA+k2Y6r+oz/6Z/RFlNeVCesGARKuC6YymtcDrbC/I= +github.com/cockroachdb/errors v1.11.3/go.mod h1:m4UIW4CDjx+R5cybPsNrRbreomiFqt8o1h1wUVazSd8= +github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b h1:r6VH0faHjZeQy818SGhaone5OnYfxFR/+AzdY3sf5aE= +github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b/go.mod 
h1:Vz9DsVWQQhf3vs21MhPMZpMGSht7O/2vFW2xusFUVOs= +github.com/cockroachdb/redact v1.1.5 h1:u1PMllDkdFfPWaNGMyLD1+so+aq3uUItthCFqzwPJ30= +github.com/cockroachdb/redact v1.1.5/go.mod h1:BVNblN9mBWFyMyqK1k3AAiSxhvhfK2oOZZ2lK+dpvRg= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/getsentry/sentry-go v0.27.0 h1:Pv98CIbtB3LkMWmXi4Joa5OOcwbmnX88sF5qbK3r3Ps= +github.com/getsentry/sentry-go v0.27.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY= +github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= +github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= +github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 h1:D0vL7YNisV2yqE55+q0lFuGse6U8lxlg7fYTctlT5Gc= +github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod 
h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/logger/console.go b/logger/console.go new file mode 100644 index 0000000..2e6341a --- /dev/null +++ b/logger/console.go @@ -0,0 +1,210 @@ +package logger + +import ( + 
"encoding/json" + "fmt" + "log" + "os" + "runtime" + "strings" + "time" + + gstrings "github.com/agentuity/go-common/string" + "github.com/mattn/go-isatty" +) + +const isWindows = runtime.GOOS == "windows" + +var noColor = os.Getenv("TERM") == "dumb" || + (!isatty.IsTerminal(os.Stdout.Fd()) && !isatty.IsCygwinTerminal(os.Stdout.Fd())) + +func color(val string) string { + if isWindows || noColor { + return "" + } + return val +} + +const ( + Reset = "\033[0m" + Gray = "\033[1;30m" + Red = "\033[31m" + Green = "\033[32m" + Yellow = "\033[33m" + Blue = "\033[34m" + Magenta = "\033[35m" + Cyan = "\033[36m" + White = "\033[37m" + BlueBold = "\033[34;1m" + MagentaBold = "\033[35;1m" + RedBold = "\033[31;1m" + YellowBold = "\033[33;1m" + WhiteBold = "\033[37;1m" + CyanBold = "\033[36;1m" + Purple = "\u001b[38;5;200m" +) + +type consoleLogger struct { + prefixes []string + metadata map[string]interface{} + traceLevelColor string + traceMessageColor string + debugLevelColor string + debugMessageColor string + infoLevelColor string + infoMessageColor string + warnLevelColor string + warnMessageColor string + errorLevelColor string + errorMessageColor string + sink Sink + logLevel LogLevel + sinkLogLevel LogLevel +} + +var _ Logger = (*consoleLogger)(nil) + +func (c *consoleLogger) Default(val string, def string) string { + if val == "" { + return def + } + return val +} + +// WithPrefix will return a new logger with a prefix prepended to the message +func (c *consoleLogger) WithPrefix(prefix string) Logger { + prefixes := make([]string, 0) + prefixes = append(prefixes, c.prefixes...) + if !gstrings.Contains(prefixes, prefix, false) { + prefixes = append(prefixes, prefix) + } + l := c.Clone(c.metadata, c.sink) + l.prefixes = prefixes + return l +} + +var isCI = os.Getenv("CI") != "" + +func (c *consoleLogger) Clone(kv map[string]interface{}, sink Sink) *consoleLogger { + prefixes := make([]string, 0) + prefixes = append(prefixes, c.prefixes...) + var tracecolor = Gray + if isCI { + tracecolor = Purple + } + return &consoleLogger{ + metadata: kv, + prefixes: prefixes, + traceLevelColor: c.Default(c.traceLevelColor, CyanBold), + traceMessageColor: c.Default(c.traceMessageColor, tracecolor), + debugLevelColor: c.Default(c.debugLevelColor, BlueBold), + debugMessageColor: c.Default(c.debugMessageColor, Green), + infoLevelColor: c.Default(c.infoLevelColor, YellowBold), + infoMessageColor: c.Default(c.infoMessageColor, WhiteBold), + warnLevelColor: c.Default(c.infoMessageColor, MagentaBold), + warnMessageColor: c.Default(c.warnMessageColor, Magenta), + errorLevelColor: c.Default(c.errorMessageColor, RedBold), + errorMessageColor: c.Default(c.errorMessageColor, Red), + sink: sink, + logLevel: c.logLevel, + sinkLogLevel: c.sinkLogLevel, + } +} + +func (c *consoleLogger) SetSink(sink Sink, level LogLevel) { + c.sink = sink + c.sinkLogLevel = level +} + +func (c *consoleLogger) With(metadata map[string]interface{}) Logger { + kv := metadata + if c.metadata != nil { + kv = make(map[string]interface{}) + for k, v := range c.metadata { + kv[k] = v + } + for k, v := range metadata { + kv[k] = v + } + } + if len(kv) == 0 { + kv = nil + } + return c.Clone(kv, c.sink) +} + +func (c *consoleLogger) Log(level LogLevel, levelColor string, messageColor string, levelString string, msg string, args ...interface{}) { + if level < c.logLevel && level < c.sinkLogLevel { + return + } + _msg := fmt.Sprintf(msg, args...) 
+ var prefix string + var suffix string + if len(c.prefixes) > 0 { + prefix = color(Purple) + strings.Join(c.prefixes, " ") + color(Reset) + " " + } + if c.metadata != nil { + buf, _ := json.Marshal(c.metadata) + _buf := string(buf) + if _buf != "{}" { + if isCI { + suffix = " " + color(MagentaBold) + _buf + color(Reset) + } else { + suffix = " " + color(Gray) + _buf + color(Reset) + } + } + } + var levelSuffix string + if len(levelString) < 5 { + levelSuffix = strings.Repeat(" ", 5-len(levelString)) + } + levelText := color(levelColor) + fmt.Sprintf("[%s]%s", levelString, levelSuffix) + color(Reset) + message := color(messageColor) + _msg + color(Reset) + out := fmt.Sprintf("%s %s%s%s", levelText, prefix, message, suffix) + if level >= c.logLevel { + log.Printf("%s\n", out) + } + if c.sink != nil && level >= c.sinkLogLevel { + ts := time.Now().Format(time.RFC3339Nano) + c.sink.Write([]byte(ts + " " + ansiColorStripper.ReplaceAllString(out, "") + "\n")) + } +} + +func (c *consoleLogger) Trace(msg string, args ...interface{}) { + c.Log(LevelTrace, c.traceLevelColor, c.traceMessageColor, "TRACE", msg, args...) +} + +func (c *consoleLogger) Debug(msg string, args ...interface{}) { + c.Log(LevelDebug, c.debugLevelColor, c.debugMessageColor, "DEBUG", msg, args...) +} + +func (c *consoleLogger) Info(msg string, args ...interface{}) { + c.Log(LevelInfo, c.infoLevelColor, c.infoMessageColor, "INFO", msg, args...) +} + +func (c *consoleLogger) Warn(msg string, args ...interface{}) { + c.Log(LevelWarn, c.warnLevelColor, c.warnMessageColor, "WARN", msg, args...) +} + +func (c *consoleLogger) Error(msg string, args ...interface{}) { + c.Log(LevelError, c.errorLevelColor, c.errorMessageColor, "ERROR", msg, args...) +} + +func (c *consoleLogger) Fatal(msg string, args ...interface{}) { + c.Log(LevelError, c.errorLevelColor, c.errorMessageColor, "ERROR", msg, args...) 
+ os.Exit(1) +} + +func (c *consoleLogger) SetLogLevel(level LogLevel) { + c.logLevel = level +} + +// NewConsoleLogger returns a new Logger instance which will log to the console +func NewConsoleLogger(levels ...LogLevel) SinkLogger { + if len(levels) > 0 { + return (&consoleLogger{logLevel: levels[0], sinkLogLevel: LevelNone}).Clone(nil, nil) + } + level := GetLevelFromEnv() + + return (&consoleLogger{logLevel: level, sinkLogLevel: LevelNone}).Clone(nil, nil) +} diff --git a/logger/console_test.go b/logger/console_test.go new file mode 100644 index 0000000..33fa8bf --- /dev/null +++ b/logger/console_test.go @@ -0,0 +1,295 @@ +package logger + +import ( + "bytes" + "log" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func captureOutput(f func()) string { + var buf bytes.Buffer + log.SetOutput(&buf) + f() + log.SetOutput(nil) + return buf.String() +} + +func TestConsoleLogger(t *testing.T) { + + logger := NewConsoleLogger().(*consoleLogger) + + tests := []struct { + level LogLevel + shouldContain []string + shouldNotContain []string + }{ + { + level: LevelTrace, + shouldContain: []string{"TRACE", "DEBUG", "INFO", "WARN", "ERROR"}, + shouldNotContain: []string{}, + }, + { + level: LevelDebug, + shouldContain: []string{"DEBUG", "INFO", "WARN", "ERROR"}, + shouldNotContain: []string{"TRACE"}, + }, + { + level: LevelInfo, + shouldContain: []string{"INFO", "WARN", "ERROR"}, + shouldNotContain: []string{"TRACE", "DEBUG"}, + }, + { + level: LevelWarn, + shouldContain: []string{"WARN", "ERROR"}, + shouldNotContain: []string{"TRACE", "DEBUG", "INFO"}, + }, + { + level: LevelError, + shouldContain: []string{"ERROR"}, + shouldNotContain: []string{"TRACE", "DEBUG", "INFO", "WARN"}, + }, + } + + for _, tt := range tests { + logger.SetLogLevel(tt.level) + output := captureOutput(func() { + logger.Trace("This is a trace message") + logger.Debug("This is a debug message") + logger.Info("This is an info message") + logger.Warn("This is a warn message") + logger.Error("This is an error message") + }) + for _, shouldContain := range tt.shouldContain { + assert.Contains(t, output, shouldContain) + } + for _, shouldNotContain := range tt.shouldNotContain { + assert.NotContains(t, output, shouldNotContain) + } + } +} + +func TestConsoleLoggerWithEnvLevel(t *testing.T) { + + tests := []struct { + level string + shouldContain []string + shouldNotContain []string + }{ + { + level: "TRACE", + shouldContain: []string{"TRACE", "DEBUG", "INFO", "WARN", "ERROR"}, + shouldNotContain: []string{}, + }, + { + level: "DEBUG", + shouldContain: []string{"DEBUG", "INFO", "WARN", "ERROR"}, + shouldNotContain: []string{"TRACE"}, + }, + { + level: "INFO", + shouldContain: []string{"INFO", "WARN", "ERROR"}, + shouldNotContain: []string{"TRACE", "DEBUG"}, + }, + { + level: "WARN", + shouldContain: []string{"WARN", "ERROR"}, + shouldNotContain: []string{"TRACE", "DEBUG", "INFO"}, + }, + { + level: "ERROR", + shouldContain: []string{"ERROR"}, + shouldNotContain: []string{"TRACE", "DEBUG", "INFO", "WARN"}, + }, + } + + for _, tt := range tests { + os.Setenv("SM_LOG_LEVEL", tt.level) + logger := NewConsoleLogger().(*consoleLogger) + + output := captureOutput(func() { + logger.Trace("This is a trace message") + logger.Debug("This is a debug message") + logger.Info("This is an info message") + logger.Warn("This is a warn message") + logger.Error("This is an error message") + }) + for _, shouldContain := range tt.shouldContain { + assert.Contains(t, output, shouldContain) + } + for _, shouldNotContain := range 
tt.shouldNotContain { + assert.NotContains(t, output, shouldNotContain) + } + os.Unsetenv("SM_LOG_LEVEL") + } +} + +func TestConsoleLoggerWithMetadata(t *testing.T) { + logger := NewConsoleLogger().(*consoleLogger) + + metadata := map[string]interface{}{ + "key1": "value1", + "key2": "value2", + } + logger = logger.With(metadata).(*consoleLogger) + + output := captureOutput(func() { + logger.Info("This is an info message with metadata") + }) + + assert.Contains(t, output, "This is an info message with metadata") + assert.Contains(t, output, `"key1":"value1"`) + assert.Contains(t, output, `"key2":"value2"`) +} + +func TestConsoleLoggerSinkTraceLevel(t *testing.T) { + logger := NewConsoleLogger().(*consoleLogger) + logger.SetLogLevel(LevelTrace) + + output := captureOutput(func() { + logger.Trace("This is a trace message") + }) + assert.Contains(t, output, "This is a trace message") + + output = captureOutput(func() { + logger.Debug("This is a debug message") + }) + assert.Contains(t, output, "This is a debug message") + + output = captureOutput(func() { + logger.Info("This is an info message") + }) + assert.Contains(t, output, "This is an info message") + + output = captureOutput(func() { + logger.Warn("This is a warn message") + }) + assert.Contains(t, output, "This is a warn message") + + output = captureOutput(func() { + logger.Error("This is an error message") + }) + assert.Contains(t, output, "This is an error message") +} + +func TestConsoleLoggerSinkDebugLevel(t *testing.T) { + logger := NewConsoleLogger().(*consoleLogger) + logger.SetLogLevel(LevelDebug) + + output := captureOutput(func() { + logger.Trace("This trace message should not be printed") + }) + assert.NotContains(t, output, "This trace message should not be printed") + + output = captureOutput(func() { + logger.Debug("This is a debug message") + }) + assert.Contains(t, output, "This is a debug message") + + output = captureOutput(func() { + logger.Info("This is an info message") + }) + assert.Contains(t, output, "This is an info message") + + output = captureOutput(func() { + logger.Warn("This is a warn message") + }) + assert.Contains(t, output, "This is a warn message") + + output = captureOutput(func() { + logger.Error("This is an error message") + }) + assert.Contains(t, output, "This is an error message") +} + +func TestConsoleLoggerSinkInfoLevel(t *testing.T) { + logger := NewConsoleLogger().(*consoleLogger) + logger.SetLogLevel(LevelInfo) + + output := captureOutput(func() { + logger.Trace("This trace message should not be printed") + }) + assert.NotContains(t, output, "This trace message should not be printed") + + output = captureOutput(func() { + logger.Debug("This debug message should not be printed") + }) + assert.NotContains(t, output, "This debug message should not be printed") + + output = captureOutput(func() { + logger.Info("This is an info message") + }) + assert.Contains(t, output, "This is an info message") + + output = captureOutput(func() { + logger.Warn("This is a warn message") + }) + assert.Contains(t, output, "This is a warn message") + + output = captureOutput(func() { + logger.Error("This is an error message") + }) + assert.Contains(t, output, "This is an error message") +} + +func TestConsoleLoggerSinkWarnLevel(t *testing.T) { + logger := NewConsoleLogger().(*consoleLogger) + logger.SetLogLevel(LevelWarn) + + output := captureOutput(func() { + logger.Trace("This trace message should not be printed") + }) + assert.NotContains(t, output, "This trace message should not be printed") + + output = 
captureOutput(func() { + logger.Debug("This debug message should not be printed") + }) + assert.NotContains(t, output, "This debug message should not be printed") + + output = captureOutput(func() { + logger.Info("This info message should not be printed") + }) + assert.NotContains(t, output, "This info message should not be printed") + + output = captureOutput(func() { + logger.Warn("This is a warn message") + }) + assert.Contains(t, output, "This is a warn message") + + output = captureOutput(func() { + logger.Error("This is an error message") + }) + assert.Contains(t, output, "This is an error message") +} + +func TestConsoleLoggerSinkErrorLevel(t *testing.T) { + logger := NewConsoleLogger().(*consoleLogger) + logger.SetLogLevel(LevelError) + + output := captureOutput(func() { + logger.Trace("This trace message should not be printed") + }) + assert.NotContains(t, output, "This trace message should not be printed") + + output = captureOutput(func() { + logger.Debug("This debug message should not be printed") + }) + assert.NotContains(t, output, "This debug message should not be printed") + + output = captureOutput(func() { + logger.Info("This info message should not be printed") + }) + assert.NotContains(t, output, "This info message should not be printed") + + output = captureOutput(func() { + logger.Warn("This warn message should not be printed") + }) + assert.NotContains(t, output, "This warn message should not be printed") + + output = captureOutput(func() { + logger.Error("This is an error message") + }) + assert.Contains(t, output, "This is an error message") +} diff --git a/logger/gcloud.go b/logger/gcloud.go new file mode 100644 index 0000000..9640c6e --- /dev/null +++ b/logger/gcloud.go @@ -0,0 +1,11 @@ +package logger + +// NewGCloudLogger returns a new Logger instance which can be used for structured Google Cloud logging +func NewGCloudLogger() Logger { + return NewJSONLogger() +} + +// NewGCloudLoggerWithSink returns a new Logger instance using a sink and suppressing the console logging +func NewGCloudLoggerWithSink(sink Sink, level LogLevel) Logger { + return NewJSONLoggerWithSink(sink, level) +} diff --git a/logger/init.go b/logger/init.go new file mode 100644 index 0000000..2c65cc2 --- /dev/null +++ b/logger/init.go @@ -0,0 +1,69 @@ +package logger + +import ( + "io" + "os" + "regexp" + "strings" +) + +// LogLevel defines the level of logging +type LogLevel int + +const ( + LevelTrace LogLevel = iota + LevelDebug + LevelInfo + LevelWarn + LevelError + LevelNone +) + +// GetLevelFromEnv will look at the environment variable `SM_LOG_LEVEL` and convert it into the appropriate LogLevel +func GetLevelFromEnv() LogLevel { + s := os.Getenv("SM_LOG_LEVEL") + switch strings.ToLower(s) { // Convert the string to lowercase to make it case-insensitive + case "trace": + return LevelTrace + case "debug": + return LevelDebug + case "info": + return LevelInfo + case "warn": + return LevelWarn + case "error": + return LevelError + default: + return LevelDebug // Default to debug for unknown or empty values + } +} + +type Sink io.Writer + +// Logger is an interface for logging +type Logger interface { + // With will return a new logger using metadata as the base context + With(metadata map[string]interface{}) Logger + // WithPrefix will return a new logger with a prefix prepended to the message + WithPrefix(prefix string) Logger + // Trace level logging + Trace(msg string, args ...interface{}) + // Debug level logging + Debug(msg string, args ...interface{}) + // Info level logging +
Info(msg string, args ...interface{}) + // Warning level logging + Warn(msg string, args ...interface{}) + // Error level logging + Error(msg string, args ...interface{}) + // Fatal level logging and exit with code 1 + Fatal(msg string, args ...interface{}) +} + +type SinkLogger interface { + Logger + // SetSink will set the sink, and level to sink + SetSink(sink Sink, level LogLevel) +} + +var ansiColorStripper = regexp.MustCompile("\x1b\\[[0-9;]*[mK]") diff --git a/logger/json.go b/logger/json.go new file mode 100644 index 0000000..2a30959 --- /dev/null +++ b/logger/json.go @@ -0,0 +1,189 @@ +package logger + +import ( + "encoding/json" + "fmt" + "log" + "regexp" + "strings" + "time" +) + +// JSONLogEntry defines a log entry +// this is modeled after the JSON format expected by Cloud Logging +// https://github.com/GoogleCloudPlatform/golang-samples/blob/08bc985b4973901c09344eabbe9d7d5add7dc656/run/logging-manual/main.go +type JSONLogEntry struct { + Timestamp time.Time `json:"timestamp,omitempty"` + Message string `json:"message"` + Severity string `json:"severity,omitempty"` + Trace string `json:"logging.googleapis.com/trace,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + // Logs Explorer allows filtering and display of this as `jsonPayload.component`. + Component string `json:"component,omitempty"` + logLevel LogLevel +} + +// String renders an entry structure to the JSON format expected by Cloud Logging. +func (e JSONLogEntry) String() string { + if e.Severity == "" { + e.Severity = "INFO" + } + out, err := json.Marshal(e) + if err != nil { + log.Printf("json.Marshal: %v", err) + } + return string(out) +} + +type jsonLogger struct { + metadata map[string]interface{} + traceID string + component string + sink Sink + sinkLogLevel LogLevel + noConsole bool + ts *time.Time // for unit testing + logLevel LogLevel +} + +var _ Logger = (*jsonLogger)(nil) + +func (c *jsonLogger) SetSink(sink Sink, level LogLevel) { + c.sink = sink + c.sinkLogLevel = level +} + +// WithPrefix will return a new logger with a prefix prepended to the message +func (c *jsonLogger) WithPrefix(prefix string) Logger { + newlogger := c.With(nil).(*jsonLogger) + if c.component == "" { + newlogger.component = prefix + } else { + if !strings.Contains(c.component, prefix) { + newlogger.component = c.component + " " + prefix + } + } + return newlogger +} + +func (c *jsonLogger) With(metadata map[string]interface{}) Logger { + traceID := c.traceID + component := c.component + if trace, ok := metadata["trace"].(string); ok { + traceID = trace + delete(metadata, "trace") + } + if comp, ok := metadata["component"].(string); ok { + component = comp + delete(metadata, "component") + } + kv := metadata + if c.metadata != nil { + kv = make(map[string]interface{}) + for k, v := range c.metadata { + kv[k] = v + } + for k, v := range metadata { + kv[k] = v + } + } + if len(kv) == 0 { + kv = nil + } + return &jsonLogger{ + metadata: kv, + traceID: traceID, + component: component, + noConsole: c.noConsole, + sink: c.sink, + sinkLogLevel: c.sinkLogLevel, + logLevel: c.logLevel, + } +} + +var bracketRegex = regexp.MustCompile(`\[(.*?)\]`) + +func (c *jsonLogger) tokenize(val string) string { + if bracketRegex.MatchString(val) { + vals := make([]string, 0) + for _, token := range bracketRegex.FindAllString(val, -1) { + vals = append(vals, bracketRegex.ReplaceAllString(token, "$1")) + } + return strings.Join(vals, ", ") + } + return val +} + +func (c *jsonLogger) Log(level LogLevel, severity string, msg string, 
args ...interface{}) { + if level < c.logLevel && level < c.sinkLogLevel { + return + } + _msg := msg + if len(args) > 0 { + _msg = fmt.Sprintf(msg, args...) + } + entry := JSONLogEntry{ + Severity: severity, + Message: _msg, + Trace: c.traceID, + Metadata: c.metadata, + Component: c.tokenize(c.component), + Timestamp: time.Now(), + } + if !c.noConsole && level >= c.logLevel { + log.Println(entry) + } + if c.sink != nil && level >= c.sinkLogLevel { + entry.Message = ansiColorStripper.ReplaceAllString(entry.Message, "") + if c.ts != nil { + entry.Timestamp = *c.ts // for testing + } + buf, _ := json.Marshal(entry) + if _, err := c.sink.Write(buf); err != nil { + log.Printf("sink.Write: %v", err) + } + } +} + +func (c *jsonLogger) Trace(msg string, args ...interface{}) { + c.Log(LevelTrace, "TRACE", msg, args...) +} + +func (c *jsonLogger) Debug(msg string, args ...interface{}) { + c.Log(LevelDebug, "DEBUG", msg, args...) +} + +func (c *jsonLogger) Info(msg string, args ...interface{}) { + c.Log(LevelInfo, "INFO", msg, args...) +} + +func (c *jsonLogger) Warn(msg string, args ...interface{}) { + c.Log(LevelWarn, "WARNING", msg, args...) +} + +func (c *jsonLogger) Error(msg string, args ...interface{}) { + c.Log(LevelError, "ERROR", msg, args...) +} + +func (c *jsonLogger) Fatal(msg string, args ...interface{}) { + c.Log(LevelError, "ERROR", msg, args...) +} + +func (c *jsonLogger) SetLogLevel(level LogLevel) { + c.logLevel = level +} + +// NewJSONLogger returns a new Logger instance which can be used for structured logging +func NewJSONLogger(levels ...LogLevel) Logger { + if len(levels) > 0 { + return &jsonLogger{logLevel: levels[0]} + } + level := GetLevelFromEnv() + return &jsonLogger{logLevel: level} + +} + +// NewJSONLoggerWithSink returns a new Logger instance using a sink and suppressing the console logging +func NewJSONLoggerWithSink(sink Sink, level LogLevel) SinkLogger { + return &jsonLogger{noConsole: true, sink: sink, sinkLogLevel: level} +} diff --git a/logger/logger_test.go b/logger/logger_test.go new file mode 100644 index 0000000..6dfca36 --- /dev/null +++ b/logger/logger_test.go @@ -0,0 +1,159 @@ +package logger + +import ( + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type testSink struct { + buf []byte +} + +func (s *testSink) Write(buf []byte) (int, error) { + s.buf = buf + return len(buf), nil +} + +func TestGCloudLogger(t *testing.T) { + sink := &testSink{} + log := NewGCloudLoggerWithSink(sink, LevelTrace) + jlog := log.(*jsonLogger) + tv := time.Date(2023, 10, 22, 12, 30, 0, 0, time.UTC) + jlog.ts = &tv + log.Trace("Hi") + assert.Equal(t, `{"timestamp":"2023-10-22T12:30:00Z","message":"Hi","severity":"TRACE"}`, string(sink.buf)) + wlog := log.WithPrefix("[hi]") + jlog = wlog.(*jsonLogger) + jlog.ts = &tv + wlog.Debug("hi") + assert.Equal(t, `{"timestamp":"2023-10-22T12:30:00Z","message":"hi","severity":"DEBUG","component":"hi"}`, string(sink.buf)) + w2log := wlog.WithPrefix("[bye]") + jlog = w2log.(*jsonLogger) + jlog.ts = &tv + w2log.Debug("hi") + assert.Equal(t, `{"timestamp":"2023-10-22T12:30:00Z","message":"hi","severity":"DEBUG","component":"hi, bye"}`, string(sink.buf)) +} + +func TestCombinedLogger(t *testing.T) { + sink := &testSink{} + log := NewTestLogger() + jsonLog := NewJSONLoggerWithSink(sink, LevelTrace) + tv := time.Date(2023, 10, 22, 12, 30, 0, 0, time.UTC) + jsonLog.(*jsonLogger).ts = &tv + combined := NewMultiLogger(log, jsonLog) + combined.Info("Ayyyyyy") + assert.Len(t, log.Logs, 1) + assert.Equal(t, 
`{"timestamp":"2023-10-22T12:30:00Z","message":"Ayyyyyy","severity":"INFO"}`, string(sink.buf)) +} + +func TestJSONLogger(t *testing.T) { + + logger := NewJSONLogger().(*jsonLogger) + + tests := []struct { + level LogLevel + shouldContain []string + shouldNotContain []string + }{ + { + level: LevelTrace, + shouldContain: []string{"TRACE", "DEBUG", "INFO", "WARN", "ERROR"}, + shouldNotContain: []string{}, + }, + { + level: LevelDebug, + shouldContain: []string{"DEBUG", "INFO", "WARN", "ERROR"}, + shouldNotContain: []string{"TRACE"}, + }, + { + level: LevelInfo, + shouldContain: []string{"INFO", "WARN", "ERROR"}, + shouldNotContain: []string{"TRACE", "DEBUG"}, + }, + { + level: LevelWarn, + shouldContain: []string{"WARN", "ERROR"}, + shouldNotContain: []string{"TRACE", "DEBUG", "INFO"}, + }, + { + level: LevelError, + shouldContain: []string{"ERROR"}, + shouldNotContain: []string{"TRACE", "DEBUG", "INFO", "WARN"}, + }, + } + + for _, tt := range tests { + logger.SetLogLevel(tt.level) + output := captureOutput(func() { + logger.Trace("This is a trace message") + logger.Debug("This is a debug message") + logger.Info("This is an info message") + logger.Warn("This is a warn message") + logger.Error("This is an error message") + }) + for _, shouldContain := range tt.shouldContain { + assert.Contains(t, output, shouldContain) + } + for _, shouldNotContain := range tt.shouldNotContain { + assert.NotContains(t, output, shouldNotContain) + } + } +} + +func TestJSONLoggerWithEnvLevel(t *testing.T) { + + tests := []struct { + level string + shouldContain []string + shouldNotContain []string + }{ + { + level: "TRACE", + shouldContain: []string{"TRACE", "DEBUG", "INFO", "WARN", "ERROR"}, + shouldNotContain: []string{}, + }, + { + level: "DEBUG", + shouldContain: []string{"DEBUG", "INFO", "WARN", "ERROR"}, + shouldNotContain: []string{"TRACE"}, + }, + { + level: "INFO", + shouldContain: []string{"INFO", "WARN", "ERROR"}, + shouldNotContain: []string{"TRACE", "DEBUG"}, + }, + { + level: "WARN", + shouldContain: []string{"WARN", "ERROR"}, + shouldNotContain: []string{"TRACE", "DEBUG", "INFO"}, + }, + { + level: "ERROR", + shouldContain: []string{"ERROR"}, + shouldNotContain: []string{"TRACE", "DEBUG", "INFO", "WARN"}, + }, + } + + for _, tt := range tests { + os.Setenv("SM_LOG_LEVEL", tt.level) + logger := NewJSONLogger().(*jsonLogger) + + output := captureOutput(func() { + logger.Trace("This is a trace message") + logger.Debug("This is a debug message") + logger.Info("This is an info message") + logger.Warn("This is a warn message") + logger.Error("This is an error message") + }) + for _, shouldContain := range tt.shouldContain { + assert.Contains(t, output, shouldContain) + } + for _, shouldNotContain := range tt.shouldNotContain { + assert.NotContains(t, output, shouldNotContain) + } + os.Unsetenv("SM_LOG_LEVEL") + } +} diff --git a/logger/multi.go b/logger/multi.go new file mode 100644 index 0000000..125186c --- /dev/null +++ b/logger/multi.go @@ -0,0 +1,57 @@ +package logger + +type muxLogger struct { + loggers []Logger +} + +func (m *muxLogger) With(metadata map[string]interface{}) Logger { + var newLoggers []Logger + for _, l := range m.loggers { + newLoggers = append(newLoggers, l.With(metadata)) + } + return NewMultiLogger(newLoggers...) +} + +func (m *muxLogger) WithPrefix(prefix string) Logger { + var newLoggers []Logger + for _, l := range m.loggers { + newLoggers = append(newLoggers, l.WithPrefix(prefix)) + } + return NewMultiLogger(newLoggers...) 
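As a usage sketch for the multi logger and JSON sink above (the file name is illustrative; any io.Writer satisfies Sink):

f, err := os.Create("app.ndjson")
if err != nil {
	panic(err)
}
defer f.Close()
// JSON entries go to f; console output is suppressed for this logger
jsonLog := logger.NewJSONLoggerWithSink(f, logger.LevelInfo)
// fan every call out to both the console logger and the JSON sink logger
log := logger.NewMultiLogger(logger.NewConsoleLogger(), jsonLog)
log.Info("request handled in %dms", 42)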
+} + +func (m *muxLogger) Trace(msg string, args ...interface{}) { + m.each(func(l Logger) { l.Trace(msg, args...) }) +} + +func (m *muxLogger) Debug(msg string, args ...interface{}) { + m.each(func(l Logger) { l.Debug(msg, args...) }) +} + +func (m *muxLogger) Info(msg string, args ...interface{}) { + m.each(func(l Logger) { l.Info(msg, args...) }) +} + +func (m *muxLogger) Warn(msg string, args ...interface{}) { + m.each(func(l Logger) { l.Warn(msg, args...) }) +} + +func (m *muxLogger) Error(msg string, args ...interface{}) { + m.each(func(l Logger) { l.Error(msg, args...) }) +} + +func (m *muxLogger) Fatal(msg string, args ...interface{}) { + m.each(func(l Logger) { l.Fatal(msg, args...) }) +} + +func (m *muxLogger) each(f func(Logger)) { + for _, l := range m.loggers { + f(l) + } +} + +func NewMultiLogger(loggers ...Logger) Logger { + return &muxLogger{ + loggers: loggers, + } +} diff --git a/logger/test.go b/logger/test.go new file mode 100644 index 0000000..fd5be46 --- /dev/null +++ b/logger/test.go @@ -0,0 +1,75 @@ +package logger + +import "os" + +type TestLogEntry struct { + Severity string + Message string + Arguments []interface{} +} + +type TestLogger struct { + metadata map[string]interface{} + Logs []TestLogEntry +} + +var _ Logger = (*TestLogger)(nil) + +func (c *TestLogger) WithSink(sink Sink, level LogLevel) Logger { + return c +} + +// WithPrefix will return a new logger with a prefix prepended to the message +func (c *TestLogger) WithPrefix(prefix string) Logger { + return c +} + +func (c *TestLogger) With(metadata map[string]interface{}) Logger { + kv := metadata + if c.metadata != nil { + kv = make(map[string]interface{}) + for k, v := range c.metadata { + kv[k] = v + } + for k, v := range metadata { + kv[k] = v + } + } + return &TestLogger{kv, c.Logs} +} + +func (c *TestLogger) Log(level string, msg string, args ...interface{}) { + c.Logs = append(c.Logs, TestLogEntry{level, msg, args}) +} + +func (c *TestLogger) Trace(msg string, args ...interface{}) { + c.Log("TRACE", msg, args...) +} + +func (c *TestLogger) Debug(msg string, args ...interface{}) { + c.Log("DEBUG", msg, args...) +} + +func (c *TestLogger) Info(msg string, args ...interface{}) { + c.Log("INFO", msg, args...) +} + +func (c *TestLogger) Warn(msg string, args ...interface{}) { + c.Log("WARNING", msg, args...) +} + +func (c *TestLogger) Error(msg string, args ...interface{}) { + c.Log("ERROR", msg, args...) +} + +func (c *TestLogger) Fatal(msg string, args ...interface{}) { + c.Log("FATAL", msg, args...) 
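A sketch of how TestLogger can back assertions in a unit test, assuming the package is imported as logger (doWork is a stand-in for the code under test):

func TestMyHandler(t *testing.T) {
	log := logger.NewTestLogger()
	// stand-in for the code under test; anything that accepts a logger.Logger works
	doWork := func(l logger.Logger) { l.Info("done in %d ms", 42) }
	doWork(log)
	if len(log.Logs) != 1 || log.Logs[0].Severity != "INFO" {
		t.Fatalf("expected a single INFO entry, got %+v", log.Logs)
	}
}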
+ os.Exit(1) +} + +// NewTestLogger returns a new Logger instance useful for testing +func NewTestLogger() *TestLogger { + return &TestLogger{ + Logs: make([]TestLogEntry, 0), + } +} diff --git a/logger/util.go b/logger/util.go new file mode 100644 index 0000000..0ddadec --- /dev/null +++ b/logger/util.go @@ -0,0 +1,5 @@ +package logger + +func WithKV(logger Logger, key string, value any) Logger { + return logger.With(map[string]interface{}{key: value}) +} diff --git a/renovate.json b/renovate.json new file mode 100644 index 0000000..5249302 --- /dev/null +++ b/renovate.json @@ -0,0 +1,17 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": ["config:base"], + "packageRules": [ + { + "description": "Automerge non-major updates", + "matchUpdateTypes": ["minor", "patch"], + "automerge": true + } + ], + "lockFileMaintenance": { + "enabled": true, + "automerge": true, + "automergeType": "pr", + "platformAutomerge": true + } +} diff --git a/request/request.go b/request/request.go new file mode 100644 index 0000000..9d6f959 --- /dev/null +++ b/request/request.go @@ -0,0 +1,410 @@ +package request + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "math" + "net" + ghttp "net/http" + "regexp" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/agentuity/go-common/dns" + cstr "github.com/agentuity/go-common/string" + "golang.org/x/sync/semaphore" +) + +var ErrTooManyAttempts = errors.New("too many attempts") + +const ( + userAgentHeaderValue = "Agentuity (+https://agentuity.com)" +) + +// Request is an interface for an HTTP request. +type Request interface { + // Method returns the HTTP method. + Method() string + // URL returns the URL. + URL() string + // Headers returns the headers. + Headers() map[string]string + // Payload returns the payload. + Payload() []byte +} + +type HTTPRequest struct { + method string + url string + headers map[string]string + payload []byte +} + +func (r *HTTPRequest) Method() string { + return r.method +} + +func (r *HTTPRequest) URL() string { + return r.url +} + +func (r *HTTPRequest) Headers() map[string]string { + return r.headers +} + +func (r *HTTPRequest) Payload() []byte { + return r.payload +} + +// NewHTTPRequest creates a new HTTPRequest that implements the Request interface. +func NewHTTPRequest(method string, url string, headers map[string]string, payload []byte) Request { + return &HTTPRequest{method, url, headers, payload} +} + +// NewHTTPGetRequest creates a new HTTPRequest that implements the Request interface for GET requests. +func NewHTTPGetRequest(url string, headers map[string]string) Request { + return &HTTPRequest{ghttp.MethodGet, url, headers, nil} +} + +// NewHTTPPostRequest creates a new HTTPRequest that implements the Request interface for POST requests. +func NewHTTPPostRequest(url string, headers map[string]string, payload []byte) Request { + return &HTTPRequest{ghttp.MethodPost, url, headers, payload} +} + +// Response is the response from an HTTP request. +type Response struct { + StatusCode int `json:"statusCode"` + Body []byte `json:"body,omitempty"` + Headers map[string]string `json:"headers"` + Attempts uint `json:"attempts"` + Latency time.Duration `json:"latency"` +} + +// Recorder is an interface for recording request / responses. +type Recorder interface { + OnResponse(ctx context.Context, req Request, resp *Response) +} + +// Http is an interface for making HTTP requests. +type Http interface { + // Deliver sends a request and returns a response. 
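To make the Request/Response/Http pieces above concrete, a small delivery sketch (the URL, header, and body are placeholders; it assumes the package is imported as request from this module):

func sendEvent(ctx context.Context) error {
	client := request.New()
	req := request.NewHTTPPostRequest("https://api.example.com/v1/events",
		map[string]string{"Content-Type": "application/json"},
		[]byte(`{"hello":"world"}`))
	resp, err := client.Deliver(ctx, req)
	if err != nil {
		// includes request.ErrTooManyAttempts once retries are exhausted
		return err
	}
	fmt.Printf("status=%d attempts=%d latency=%s\n", resp.StatusCode, resp.Attempts, resp.Latency)
	return nil
}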
+ Deliver(ctx context.Context, request Request) (*Response, error) +} + +// RetryBackoff is an interface for retrying a request with a backoff. +type RetryBackoff interface { + // BackOff returns the duration to wait before retrying. + BackOff(attempt uint) time.Duration +} + +type powerOfTwoBackoff struct { + min time.Duration + max time.Duration +} + +func (p *powerOfTwoBackoff) BackOff(attempt uint) time.Duration { + if attempt == 0 { + return p.min + } + ms := time.Duration(math.Pow(2, float64(attempt))) * p.min + if ms > p.max { + return p.max + } + return ms +} + +// NewMinMaxBackoff creates a new RetryBackoff that retries with a backoff of min * 2^attempt. +func NewMinMaxBackoff(min time.Duration, max time.Duration) RetryBackoff { + return &powerOfTwoBackoff{min, max} +} + +type http struct { + transport *ghttp.Transport + timeout time.Duration // set for testing but defaults to 55 seconds otherwise + dur time.Duration // set for testing but defaults to 1 second otherwise + recorder Recorder + count uint64 + semaphore *semaphore.Weighted + maxAttempts uint + backoff RetryBackoff +} + +var _ Http = (*http)(nil) + +func (h *http) shouldRetry(resp *ghttp.Response, err error) bool { + if err != nil { + msg := err.Error() + if strings.Contains(msg, "connection reset") || strings.Contains(msg, "connection refused") || strings.Contains(msg, "EOF") { + return true + } + } + if resp != nil { + switch resp.StatusCode { + case ghttp.StatusRequestTimeout, ghttp.StatusBadGateway, ghttp.StatusServiceUnavailable, ghttp.StatusGatewayTimeout, ghttp.StatusTooManyRequests: + return true + } + } + return false +} + +func (h *http) toResponse(resp *ghttp.Response, attempt uint, latency time.Duration) (*Response, error) { + if resp == nil { + return nil, nil + } + var body []byte + if resp.Body != nil { + defer resp.Body.Close() + b, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + body = b + } + headers := make(map[string]string) + for k, v := range resp.Header { + headers[k] = strings.Join(v, ", ") + } + return &Response{ + StatusCode: resp.StatusCode, + Body: body, + Headers: headers, + Attempts: attempt, + Latency: latency, + }, nil +} + +var isNumber = regexp.MustCompile("^[0-9]+$") + +func (h *http) generateRequestId(req Request) string { + count := atomic.AddUint64(&h.count, 1) + return fmt.Sprintf("%d/%s", count, cstr.NewHash(req.URL(), req.Payload(), time.Now().UnixNano())) +} + +func (h *http) Deliver(ctx context.Context, req Request) (*Response, error) { + started := time.Now() + if err := h.semaphore.Acquire(ctx, 1); err != nil { + return nil, fmt.Errorf("error acquiring semaphore: %w", err) + } + defer h.semaphore.Release(1) + var attempt uint + var resp *ghttp.Response + var response *Response + var c context.Context + var cancel context.CancelFunc + maxAttempts := h.maxAttempts + if maxAttempts == 1 { + // for testing we want to make sure we don't wait too long + c, cancel = context.WithTimeout(ctx, time.Second*3) + // we only want to try once for tests + } else { + c, cancel = context.WithTimeout(ctx, h.timeout) + } + defer cancel() + headers := req.Headers() + if headers == nil { + headers = make(map[string]string) + } + for attempt < maxAttempts { + attempt++ + reqId := h.generateRequestId(req) + var body io.Reader + payload := req.Payload() + if len(payload) > 0 { + body = bytes.NewBuffer(payload) + } + hreq, err := ghttp.NewRequestWithContext(c, req.Method(), req.URL(), body) + if err != nil { + return nil, err + } + for k, v := range headers { + 
hreq.Header.Set(k, v) + } + hreq.Header.Set("User-Agent", userAgentHeaderValue) + hreq.Header.Set("X-Request-Id", reqId) + hreq.Header.Set("X-Attempt", strconv.Itoa(int(attempt))) + resp, err = h.transport.RoundTrip(hreq) + if err != nil && (errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled)) { + return nil, err + } + if h.recorder != nil && resp != nil /*&& !req.TestOnly*/ { + r, err := h.toResponse(resp, attempt, time.Since(started)) + if err != nil { + return nil, fmt.Errorf("error converting response: %w", err) + } + // update our headers + headers["User-Agent"] = userAgentHeaderValue + headers["X-Request-Id"] = reqId + headers["X-Attempt"] = strconv.Itoa(int(attempt)) + h.recorder.OnResponse(ctx, req, r) + response = r // set it so we don't try and re-read the body again + } + if h.shouldRetry(resp, err) /*&& !req.TestOnly*/ { + select { + case <-ctx.Done(): + return nil, context.Canceled + default: + } + if attempt == maxAttempts { + // don't worry about sleeping and reading the body if we're not going to retry + break + } + ms := h.dur * time.Duration(attempt) + if h.backoff != nil { + ms = h.backoff.BackOff(attempt) + } + if resp != nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + ra := resp.Header.Get("Retry-After") + if ra != "" { + if isNumber.MatchString(ra) { + afterSeconds, _ := strconv.Atoi(ra) + if afterSeconds > 0 { + ms = time.Second * time.Duration(afterSeconds) + } + } else { + if tv, err := time.Parse(ghttp.TimeFormat, ra); err == nil { + ms = time.Until(tv) + } + } + } + } + if ms > 0 { + time.Sleep(ms) + } + continue + } + if err != nil { + return nil, err + } + if response == nil { + return h.toResponse(resp, attempt, time.Since(started)) + } + return response, nil + } + if response == nil && resp != nil { + r, err := h.toResponse(resp, attempt, time.Since(started)) + if err != nil { + return nil, err + } + return r, ErrTooManyAttempts + } + return response, ErrTooManyAttempts +} + +type configOpts struct { + recorder Recorder + max uint64 + dns dns.DNS + timeout time.Duration + dur time.Duration + maxAttempts uint + backoff RetryBackoff +} + +type ConfigOpt func(opts *configOpts) + +// New returns a new HTTP implementation. +func New(opts ...ConfigOpt) Http { + var c configOpts + c.timeout = time.Second * 55 + c.dur = time.Second + c.max = 100 + c.maxAttempts = 4 + c.backoff = NewMinMaxBackoff(time.Millisecond*50, time.Second*10) + for _, opt := range opts { + opt(&c) + } + if c.max <= 0 { + panic("max was nil") + } + if c.maxAttempts <= 0 { + panic("maxAttempts was nil") + } + var transport *ghttp.Transport + if c.dns != nil { + transport = &ghttp.Transport{ + DialContext: func(ctx context.Context, network string, addr string) (conn net.Conn, err error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + ok, ip, err := c.dns.Lookup(ctx, host) + if err != nil { + return nil, err + } + if !ok { + return nil, fmt.Errorf("dns lookup failed: couldn't find ip for %s", host) + } + var dialer net.Dialer + conn, err = dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port)) + return + }, + } + } else { + transport = ghttp.DefaultTransport.(*ghttp.Transport) + } + return &http{ + transport: transport, + timeout: c.timeout, + dur: c.dur, + recorder: c.recorder, + semaphore: semaphore.NewWeighted(int64(c.max)), + maxAttempts: c.maxAttempts, + backoff: c.backoff, + } +} + +// WithDNS sets the dns resolver for the http client. 
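A configuration sketch using the option functions defined in this file (the values are arbitrary):

client := request.New(
	request.WithMaxAttempts(5),
	request.WithTimeout(30*time.Second),
	// power-of-two backoff: roughly min * 2^attempt, capped at max
	request.WithBackoff(request.NewMinMaxBackoff(100*time.Millisecond, 5*time.Second)),
	request.WithMaxConcurrency(25),
)
_ = client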
+func WithDNS(dns dns.DNS) ConfigOpt { + return func(opts *configOpts) { + opts.dns = dns + } +} + +// WithRecorder sets the recorder for the http client. +func WithRecorder(recorder Recorder) ConfigOpt { + return func(opts *configOpts) { + opts.recorder = recorder + } +} + +// WithMaxConcurrency sets the max number of concurrent requests. +func WithMaxConcurrency(max uint64) ConfigOpt { + return func(opts *configOpts) { + opts.max = max + } +} + +// WithTimeout sets the timeout for the http client. +func WithTimeout(timeout time.Duration) ConfigOpt { + return func(opts *configOpts) { + opts.timeout = timeout + } +} + +// WithBackoffDuration sets the backoff duration for the http client. +func WithBackoffDuration(dur time.Duration) ConfigOpt { + return func(opts *configOpts) { + opts.dur = dur + } +} + +// WithMaxAttempts sets the max number of attempts for the http client. +func WithMaxAttempts(max uint) ConfigOpt { + return func(opts *configOpts) { + opts.maxAttempts = max + } +} + +// WithBackoff sets the backoff strategy for the http client. +func WithBackoff(backoff RetryBackoff) ConfigOpt { + return func(opts *configOpts) { + opts.backoff = backoff + } +} diff --git a/request/request_test.go b/request/request_test.go new file mode 100644 index 0000000..149b405 --- /dev/null +++ b/request/request_test.go @@ -0,0 +1,261 @@ +package request + +import ( + "context" + "fmt" + ghttp "net/http" + "net/http/httptest" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "golang.org/x/sync/semaphore" +) + +func TestHTTPOK(t *testing.T) { + h := New() + srv := httptest.NewServer(ghttp.HandlerFunc(func(w ghttp.ResponseWriter, r *ghttp.Request) { + assert.Contains(t, r.Header, "User-Agent") + assert.Contains(t, r.Header, "X-Request-Id") + assert.Equal(t, "1", r.Header.Get("X-Attempt")) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(ghttp.StatusOK) + w.Write([]byte(fmt.Sprintf(`{"message":"%s"}`, r.Method))) + })) + defer srv.Close() + resp, err := h.Deliver(context.Background(), NewHTTPPostRequest(srv.URL, map[string]string{}, []byte("hello"))) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, ghttp.StatusOK, resp.StatusCode) + assert.Equal(t, `{"message":"POST"}`, string(resp.Body)) + assert.Equal(t, "application/json", resp.Headers["Content-Type"]) + assert.Equal(t, uint(1), resp.Attempts) + resp, err = h.Deliver(context.Background(), NewHTTPGetRequest(srv.URL, map[string]string{})) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, ghttp.StatusOK, resp.StatusCode) + assert.Equal(t, `{"message":"GET"}`, string(resp.Body)) + assert.Equal(t, "application/json", resp.Headers["Content-Type"]) + assert.Equal(t, uint(1), resp.Attempts) +} + +func TestHTTPRetry(t *testing.T) { + h := New() + var count int + srv := httptest.NewServer(ghttp.HandlerFunc(func(w ghttp.ResponseWriter, r *ghttp.Request) { + count++ + assert.Contains(t, r.Header, "User-Agent") + assert.Contains(t, r.Header, "X-Request-Id") + assert.Equal(t, strconv.Itoa(count), r.Header.Get("X-Attempt")) + if count < 3 { + w.WriteHeader(ghttp.StatusBadGateway) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(ghttp.StatusOK) + w.Write([]byte(`{"message":"hello"}`)) + })) + defer srv.Close() + resp, err := h.Deliver(context.Background(), NewHTTPRequest(ghttp.MethodPost, srv.URL, map[string]string{}, nil)) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, ghttp.StatusOK, resp.StatusCode) + assert.Equal(t, 
`{"message":"hello"}`, string(resp.Body)) + assert.Equal(t, "application/json", resp.Headers["Content-Type"]) + assert.Equal(t, 3, count) + assert.Equal(t, uint(3), resp.Attempts) +} + +func TestHTTPRetryWithRetryAfterHeader(t *testing.T) { + h := New(WithMaxAttempts(3)) + var count int + ts := time.Now() + srv := httptest.NewServer(ghttp.HandlerFunc(func(w ghttp.ResponseWriter, r *ghttp.Request) { + count++ + if count < 2 { + w.Header().Set("Retry-After", "2") + w.WriteHeader(ghttp.StatusTooManyRequests) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(ghttp.StatusOK) + w.Write([]byte(`{"message":"hello"}`)) + })) + defer srv.Close() + resp, err := h.Deliver(context.Background(), NewHTTPRequest(ghttp.MethodPost, srv.URL, map[string]string{}, nil)) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, ghttp.StatusOK, resp.StatusCode) + assert.Equal(t, `{"message":"hello"}`, string(resp.Body)) + assert.Equal(t, "application/json", resp.Headers["Content-Type"]) + assert.Equal(t, 2, count) + assert.True(t, time.Since(ts) > 2*time.Second) + assert.Equal(t, uint(2), resp.Attempts) + assert.True(t, resp.Latency > 2*time.Second) +} + +func TestHTTPRetryWithRetryAfterHeaderAsTime(t *testing.T) { + h := New() + var count int + srv := httptest.NewServer(ghttp.HandlerFunc(func(w ghttp.ResponseWriter, r *ghttp.Request) { + count++ + if count < 2 { + w.Header().Set("Retry-After", time.Now().Add(3*time.Second).UTC().Format(ghttp.TimeFormat)) + w.WriteHeader(ghttp.StatusTooManyRequests) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(ghttp.StatusOK) + w.Write([]byte(`{"message":"hello"}`)) + })) + defer srv.Close() + resp, err := h.Deliver(context.Background(), NewHTTPRequest(ghttp.MethodPost, srv.URL, map[string]string{}, nil)) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, ghttp.StatusOK, resp.StatusCode) + assert.Equal(t, `{"message":"hello"}`, string(resp.Body)) + assert.Equal(t, "application/json", resp.Headers["Content-Type"]) + assert.Equal(t, 2, count) + assert.Equal(t, uint(2), resp.Attempts) + assert.True(t, resp.Latency > 2*time.Second) +} + +func TestHTTPRetryTimeout(t *testing.T) { + var h http + h.dur = 1 * time.Millisecond + h.timeout = time.Second + h.semaphore = semaphore.NewWeighted(1) + h.transport = &ghttp.Transport{} + h.maxAttempts = 3 + var count int + srv := httptest.NewServer(ghttp.HandlerFunc(func(w ghttp.ResponseWriter, r *ghttp.Request) { + count++ + w.WriteHeader(ghttp.StatusBadGateway) + })) + defer srv.Close() + resp, err := h.Deliver(context.Background(), NewHTTPRequest(ghttp.MethodPost, srv.URL, map[string]string{}, nil)) + assert.Error(t, err, ErrTooManyAttempts) + assert.NotNil(t, resp) + assert.Equal(t, ghttp.StatusBadGateway, resp.StatusCode) + assert.Equal(t, uint(3), resp.Attempts) + assert.True(t, resp.Latency > 0) +} + +type testRecord struct { + req Request + resp *Response +} + +var _ Recorder = (*testRecord)(nil) + +func (r *testRecord) OnResponse(ctx context.Context, req Request, resp *Response) { + r.req = req + r.resp = resp +} + +func TestHTTPRecorder(t *testing.T) { + var h http + var tr testRecord + h.recorder = &tr + h.dur = 1 * time.Millisecond + h.timeout = time.Second + h.semaphore = semaphore.NewWeighted(1) + h.transport = &ghttp.Transport{} + h.maxAttempts = 3 + var count int + srv := httptest.NewServer(ghttp.HandlerFunc(func(w ghttp.ResponseWriter, r *ghttp.Request) { + count++ + w.WriteHeader(ghttp.StatusBadGateway) + })) + defer srv.Close() + 
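For reference, a Recorder only has to implement OnResponse; a logging recorder might look like this sketch (the type name is invented):

type auditRecorder struct{}

func (auditRecorder) OnResponse(ctx context.Context, req request.Request, resp *request.Response) {
	log.Printf("%s %s -> %d after %d attempt(s) in %s",
		req.Method(), req.URL(), resp.StatusCode, resp.Attempts, resp.Latency)
}

// wired in with: client := request.New(request.WithRecorder(auditRecorder{}))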
resp, err := h.Deliver(context.Background(), NewHTTPRequest(ghttp.MethodPost, srv.URL, map[string]string{}, nil)) + assert.Error(t, err, ErrTooManyAttempts) + assert.NotNil(t, resp) + assert.Equal(t, ghttp.StatusBadGateway, resp.StatusCode) + assert.NotNil(t, tr.req) + assert.NotNil(t, tr.resp) + assert.Equal(t, srv.URL, tr.req.URL()) + assert.Equal(t, ghttp.StatusBadGateway, tr.resp.StatusCode) +} + +func TestHTTPTimeout(t *testing.T) { + var h http + h.timeout = time.Millisecond * 500 + h.dur = 1 * time.Millisecond + h.semaphore = semaphore.NewWeighted(1) + h.transport = &ghttp.Transport{} + h.maxAttempts = 3 + srv := httptest.NewServer(ghttp.HandlerFunc(func(w ghttp.ResponseWriter, r *ghttp.Request) { + time.Sleep(time.Second) + w.WriteHeader(ghttp.StatusOK) + })) + defer srv.Close() + resp, err := h.Deliver(context.Background(), NewHTTPRequest(ghttp.MethodPost, srv.URL, map[string]string{}, nil)) + assert.Error(t, err, context.DeadlineExceeded) + assert.Nil(t, resp) +} + +func TestHTTPMaxAttempts(t *testing.T) { + var h http + var tr testRecord + h.recorder = &tr + h.dur = 1 * time.Millisecond + h.timeout = time.Second + h.semaphore = semaphore.NewWeighted(1) + h.transport = &ghttp.Transport{} + h.maxAttempts = 1 + var count int + srv := httptest.NewServer(ghttp.HandlerFunc(func(w ghttp.ResponseWriter, r *ghttp.Request) { + count++ + w.WriteHeader(ghttp.StatusBadGateway) + })) + defer srv.Close() + resp, err := h.Deliver(context.Background(), NewHTTPGetRequest(srv.URL, nil)) + assert.Error(t, err, ErrTooManyAttempts) + assert.NotNil(t, resp) + assert.Equal(t, ghttp.StatusBadGateway, resp.StatusCode) + assert.NotNil(t, tr.req) + assert.NotNil(t, tr.resp) + assert.Equal(t, srv.URL, tr.req.URL()) + assert.Equal(t, ghttp.StatusBadGateway, tr.resp.StatusCode) + assert.Equal(t, uint(1), tr.resp.Attempts) +} + +type testBackoff struct { + count uint +} + +func (t *testBackoff) BackOff(attempt uint) time.Duration { + t.count++ + return time.Millisecond +} + +func TestHTTPBackoff(t *testing.T) { + var h http + var tr testRecord + var tb testBackoff + h.recorder = &tr + h.dur = 1 * time.Millisecond + h.timeout = time.Second + h.semaphore = semaphore.NewWeighted(1) + h.transport = &ghttp.Transport{} + h.maxAttempts = 3 + h.backoff = &tb + var count int + srv := httptest.NewServer(ghttp.HandlerFunc(func(w ghttp.ResponseWriter, r *ghttp.Request) { + count++ + w.WriteHeader(ghttp.StatusBadGateway) + })) + defer srv.Close() + resp, err := h.Deliver(context.Background(), NewHTTPGetRequest(srv.URL, nil)) + assert.Error(t, err, ErrTooManyAttempts) + assert.NotNil(t, resp) + assert.Equal(t, ghttp.StatusBadGateway, resp.StatusCode) + assert.NotNil(t, tr.req) + assert.NotNil(t, tr.resp) + assert.Equal(t, srv.URL, tr.req.URL()) + assert.Equal(t, ghttp.StatusBadGateway, tr.resp.StatusCode) + assert.Equal(t, uint(3), tr.resp.Attempts) + assert.Equal(t, uint(2), tb.count) +} diff --git a/slice/slice.go b/slice/slice.go new file mode 100644 index 0000000..e90bab2 --- /dev/null +++ b/slice/slice.go @@ -0,0 +1,57 @@ +package slice + +import "strings" + +type withOpts struct { + caseInsensitive bool +} + +type withOptsFunc func(opts *withOpts) + +// WithCaseInsensitive will make the contains functions case insensitive. +func WithCaseInsensitive() withOptsFunc { + return func(opts *withOpts) { + opts.caseInsensitive = true + } +} + +// Contains returns true if the slice contains the value. 
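A quick sketch of the helpers in this package (imported as slice from this module):

hosts := []string{"API", "Worker", "Cron"}
fmt.Println(slice.Contains(hosts, "api", slice.WithCaseInsensitive())) // true
fmt.Println(slice.ContainsAny(hosts, "db", "Cron"))                    // true
fmt.Println(slice.Omit(hosts, "Cron"))                                 // [API Worker]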
+func Contains(slice []string, val string, opts ...withOptsFunc) bool { + var withOpts withOpts + for _, opt := range opts { + opt(&withOpts) + } + for _, s := range slice { + if withOpts.caseInsensitive { + if strings.EqualFold(s, val) { + return true + } + } else { + if s == val { + return true + } + } + } + return false +} + +// ContainsAny returns true if the slice contain any of the values [val]. +func ContainsAny(slice []string, val ...string) bool { + for _, s := range slice { + if Contains(val, s) { + return true + } + } + return false +} + +// Omit returns a new slice with the values [val] omitted. +func Omit(slice []string, val ...string) []string { + var result []string + for _, s := range slice { + if !Contains(val, s) { + result = append(result, s) + } + } + return result +} diff --git a/slice/slice_test.go b/slice/slice_test.go new file mode 100644 index 0000000..fce2ffb --- /dev/null +++ b/slice/slice_test.go @@ -0,0 +1,46 @@ +package slice + +import ( + "testing" +) + +func TestContains(t *testing.T) { + slice := []string{"a", "b", "c"} + if !Contains(slice, "a") { + t.Errorf("expected true, got false") + } + if Contains(slice, "d") { + t.Errorf("expected false, got true") + } +} + +func TestContainsCaseInsensitive(t *testing.T) { + slice := []string{"a", "b", "c"} + if !Contains(slice, "a", WithCaseInsensitive()) { + t.Errorf("expected true, got false") + } + if !Contains(slice, "A", WithCaseInsensitive()) { + t.Errorf("expected true, got false") + } + if Contains(slice, "d") { + t.Errorf("expected false, got true") + } +} + +func TestContainsAny(t *testing.T) { + slice := []string{"a", "b", "c"} + if !ContainsAny(slice, "a", "d") { + t.Errorf("expected true, got false") + } + if ContainsAny(slice, "d", "e") { + t.Errorf("expected false, got true") + } +} + +func TestOmit(t *testing.T) { + slice := []string{"a", "b", "c"} + result := Omit(slice, "a", "c") + if len(result) != 1 || result[0] != "b" { + t.Errorf(`expected ["b"], got %v`, result) + } +} diff --git a/string/hash.go b/string/hash.go new file mode 100644 index 0000000..4fd358a --- /dev/null +++ b/string/hash.go @@ -0,0 +1,59 @@ +package string + +import ( + "encoding/json" + "fmt" + "hash/fnv" + "strconv" + "strings" + + xxhash "github.com/cespare/xxhash/v2" + gstr "github.com/savsgio/gotils/strconv" +) + +func NewHash64(val ...interface{}) uint64 { + sha := xxhash.New() + for _, v := range val { + switch r := v.(type) { + case string: + sha.WriteString(r) + case int, int8, int16, int32, int64: + sha.WriteString(fmt.Sprintf("%d", r)) + case float32, float64: + sha.WriteString(fmt.Sprintf("%f", r)) + case bool: + sha.WriteString(strconv.FormatBool(r)) + default: + buf, _ := json.Marshal(r) + sha.Write(buf) + } + } + return sha.Sum64() +} + +// NewHash returns a hash of one or more input variables using xxhash algorithm +func NewHash(val ...interface{}) string { + v := fmt.Sprintf("%x", NewHash64(val...)) + if len(v) == 16 { + return v + } + return strings.Repeat("0", 16-len(v)) + v +} + +// FNV1Hash will take a string and return a FNV-1 hash value as a uint32 +func FNV1Hash(val string) uint32 { + h := fnv.New32() + h.Write([]byte(val)) + return h.Sum32() +} + +// Modulo will take the value and return a modulo with the num length +func Modulo(value string, num int) int { + hasher := fnv.New32a() + hasher.Write(gstr.S2B(value)) + partition := int(hasher.Sum32()) % num + if partition < 0 { + partition = -partition + } + return partition +} diff --git a/string/hash_test.go b/string/hash_test.go new file mode 100644 
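A sketch of the hash helpers above, using the cstr import alias that request.go in this patch uses for the string package (the ID value is taken from the tests):

orderID := "6520692e92cb3f002353975c"
key := cstr.NewHash("order_number", orderID) // 16-char, zero-padded xxhash hex
shard := cstr.Modulo(orderID, 16)            // stable shard index in [0, 16)
fmt.Println(key, shard)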
index 0000000..bce37b6 --- /dev/null +++ b/string/hash_test.go @@ -0,0 +1,49 @@ +package string + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewHash(t *testing.T) { + assert.Equal(t, "b7b41276360564d4", NewHash("1")) + assert.Equal(t, "d7c9b97948142e4a", NewHash(true)) + assert.Equal(t, "ea8842e9ea2638fa", NewHash("hi")) + assert.Equal(t, "2e1472b57af294d1", NewHash(map[string]any{})) + assert.Equal(t, "02817afd559a4122", NewHash(1, true, "hi", map[string]any{})) + assert.Equal(t, "3ec9e10063179f3a", NewHash(nil)) + assert.Equal(t, "985b4bad3b2d15ee", NewHash("order_number", "6520692e92cb3f002353975c")) + + assert.Equal(t, "54789c6c18ea9933", NewHash("order_number", "627c06f38433f70025baba37")) + assert.Equal(t, "cf1ed9ce4383e878", NewHash("order_number", "627c08068433f70025baba84")) + assert.Equal(t, "4e0cb433f75e37a4", NewHash("order_number", "637e6de64cf72300244423d3")) + assert.Equal(t, "398e48e3cea51ab3", NewHash("order_number", "63f79ea5b233d000252414a5")) + assert.Equal(t, "680eea2ab6c7f3b1", NewHash("order_number", "6435790607b87d002407783e")) + assert.Equal(t, "0ee9fc3c7d732f71", NewHash("order_number", "64399fcc07b87d0024079ff9")) + assert.Equal(t, "8daa95f03762b484", NewHash("order_number", "64b82027ee46e20024b0050b")) + assert.Equal(t, "e725148afb79ec95", NewHash("order_number", "64b84f58ee46e20024b00b54")) + assert.Equal(t, "582060db3bdc366a", NewHash("order_number", "63ec0cefb233d0002523b2b9")) +} + +func TestNewHash64(t *testing.T) { + assert.Equal(t, uint64(0xb7b41276360564d4), NewHash64("1")) + assert.Equal(t, uint64(0xd7c9b97948142e4a), NewHash64(true)) + assert.Equal(t, uint64(0xea8842e9ea2638fa), NewHash64("hi")) + assert.Equal(t, uint64(0x2e1472b57af294d1), NewHash64(map[string]any{})) + assert.Equal(t, uint64(0x2817afd559a4122), NewHash64(1, true, "hi", map[string]any{})) + assert.Equal(t, uint64(0x3ec9e10063179f3a), NewHash64(nil)) + assert.Equal(t, uint64(0x985b4bad3b2d15ee), NewHash64("order_number", "6520692e92cb3f002353975c")) + assert.Equal(t, uint64(0x54789c6c18ea9933), NewHash64("order_number", "627c06f38433f70025baba37")) + assert.Equal(t, uint64(0xcf1ed9ce4383e878), NewHash64("order_number", "627c08068433f70025baba84")) +} + +func TestFNV1Hash(t *testing.T) { + assert.Equal(t, uint32(0x42f53a8d), FNV1Hash("order_number")) +} + +func TestModulo(t *testing.T) { + assert.Equal(t, 7, Modulo("order_number", 10)) + assert.Equal(t, 5, Modulo("order_number", 11)) + assert.Equal(t, 0, Modulo("order_number", 1)) +} diff --git a/string/http.go b/string/http.go new file mode 100644 index 0000000..fb357e5 --- /dev/null +++ b/string/http.go @@ -0,0 +1,20 @@ +package string + +import ( + "net/http" + "strings" +) + +// MaskHeaders will return a stringified version of headers +// masking the headers passed in by name +func MaskHeaders(h http.Header, maskHeaders []string) map[string]string { + hh := make(map[string]string, len(h)) + for k, v := range h { + if Contains(maskHeaders, k, true) { + hh[k] = Mask(strings.Join(v, ", ")) + continue + } + hh[k] = strings.Join(v, ", ") + } + return hh +} diff --git a/string/http_test.go b/string/http_test.go new file mode 100644 index 0000000..71d013d --- /dev/null +++ b/string/http_test.go @@ -0,0 +1,45 @@ +package string + +import ( + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMaskHeaders(t *testing.T) { + testCases := []struct { + name string + headers http.Header + maskHeaders []string + expected map[string]string + }{ + { + name: "mask one header", + 
headers: http.Header{"a": []string{"b"}}, + maskHeaders: []string{"a"}, + expected: map[string]string{"a": "*"}, + }, + { + name: "do not mask any headers", + headers: http.Header{"a": []string{"b"}}, + maskHeaders: []string{"c"}, + expected: map[string]string{"a": "b"}, + }, + { + name: "mask multiple headers", + headers: http.Header{"a": []string{"b"}, "b": []string{"c"}}, + maskHeaders: []string{"a", "b"}, + expected: map[string]string{"a": "*", "b": "*"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := MaskHeaders(tc.headers, tc.maskHeaders) + assert.Equal(t, tc.expected, result) + fmt.Println(tc.name) + }) + } +} diff --git a/string/interpolate.go b/string/interpolate.go new file mode 100644 index 0000000..01f1fbd --- /dev/null +++ b/string/interpolate.go @@ -0,0 +1,55 @@ +package string + +import ( + "fmt" + "regexp" + "strings" +) + +var re = regexp.MustCompile(`(\$?{(.*?)})`) + +// InterpolateString replaces { } in string with values from environment maps. +func InterpolateString(val string, env ...map[string]interface{}) (string, error) { + if val == "" { + return val, nil + } + var err error + val = re.ReplaceAllStringFunc(val, func(s string) string { + tok := re.FindStringSubmatch(s) + key := tok[2] + def := s + var required bool + if strings.HasPrefix(key, "!") { + key = key[1:] + required = true + } + if idx := strings.Index(key, ":-"); idx != -1 { + def = key[idx+2:] + key = key[:idx] + } + var v interface{} + for _, e := range env { + if nv, ok := e[key]; ok { + v = nv + break + } + } + if v == nil { + if required { + err = fmt.Errorf("required value not found for key '%s'", key) + } + return def + } + if v == "" { + if required { + err = fmt.Errorf("required value not found for key '%s'", key) + } + return def + } + return fmt.Sprint(v) + }) + if err != nil { + return "", err + } + return val, nil +} diff --git a/string/interpolate_test.go b/string/interpolate_test.go new file mode 100644 index 0000000..e1e0484 --- /dev/null +++ b/string/interpolate_test.go @@ -0,0 +1,39 @@ +package string + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestInterpolateStrings(t *testing.T) { + type testCase struct { + input string + env []map[string]interface{} + expectedVal string + expectedErr error + } + + testCases := []testCase{ + {input: "abc", env: nil, expectedVal: "abc", expectedErr: nil}, + {input: "", env: nil, expectedVal: "", expectedErr: nil}, + {input: "this is a {test}", env: []map[string]interface{}{{"test": "TEST"}}, expectedVal: "this is a TEST", expectedErr: nil}, + {input: "this is a {test} {notfound}", env: []map[string]interface{}{{"test": "TEST"}}, expectedVal: "this is a TEST {notfound}", expectedErr: nil}, + {input: "this is a {test:-notfound}", env: []map[string]interface{}{{"foo": "TEST"}}, expectedVal: "this is a notfound", expectedErr: nil}, + {input: "this is a {test:-fail}", env: []map[string]interface{}{{"test": "TEST"}}, expectedVal: "this is a TEST", expectedErr: nil}, + {input: "this is a {test}", env: []map[string]interface{}{{"test": 123}}, expectedVal: "this is a 123", expectedErr: nil}, + {input: "this is a {test}", env: []map[string]interface{}{{"test": nil}}, expectedVal: "this is a {test}", expectedErr: nil}, + {input: "this is a {test}", env: []map[string]interface{}{{"test": ""}}, expectedVal: "this is a {test}", expectedErr: nil}, + {input: "this is a {!test}", env: []map[string]interface{}{{"test": ""}}, expectedVal: "", expectedErr: fmt.Errorf("required value not 
found for key 'test'")}, + {input: "this is a {!test}", env: []map[string]interface{}{{"test": nil}}, expectedVal: "", expectedErr: fmt.Errorf("required value not found for key 'test'")}, + {input: "this is a ${test}", env: []map[string]interface{}{{"test": nil}}, expectedVal: "this is a ${test}", expectedErr: nil}, + {input: "this is a ${test:-foo}", env: []map[string]interface{}{{"test": "foo"}}, expectedVal: "this is a foo", expectedErr: nil}, + } + + for _, tc := range testCases { + actualVal, actualErr := InterpolateString(tc.input, tc.env...) + assert.Equal(t, tc.expectedVal, actualVal) + assert.Equal(t, tc.expectedErr, actualErr) + } +} diff --git a/string/json.go b/string/json.go new file mode 100644 index 0000000..de36573 --- /dev/null +++ b/string/json.go @@ -0,0 +1,20 @@ +package string + +import ( + "encoding/json" +) + +// JSONStringify converts any value to a JSON string. +func JSONStringify(val any, pretty ...bool) string { + var buf []byte + var err error + if len(pretty) > 0 && pretty[0] { + buf, err = json.MarshalIndent(val, "", " ") + } else { + buf, err = json.Marshal(val) + } + if err != nil { + panic(err) + } + return string(buf) +} diff --git a/string/mask.go b/string/mask.go new file mode 100644 index 0000000..641399a --- /dev/null +++ b/string/mask.go @@ -0,0 +1,93 @@ +package string + +import ( + "fmt" + "net/url" + "regexp" + "sort" + "strings" +) + +// Mask will mask a string by replacing the middle with asterisks. +func Mask(s string) string { + l := len(s) + if l == 0 { + return s + } + if l == 1 { + return "*" + } + h := int(l / 2) + return s[0:h] + strings.Repeat("*", l-h) +} + +// MaskURL returns a masked version of the URL string attempting to hide sensitive information. +func MaskURL(urlString string) (string, error) { + u, err := url.Parse(urlString) + if err != nil { + return "", fmt.Errorf("failed to parse URL: %w", err) + } + var str strings.Builder + str.WriteString(u.Scheme) + str.WriteString("://") + if u.User != nil { + str.WriteString(Mask(u.User.Username())) + pass, ok := u.User.Password() + if ok { + str.WriteString(":") + str.WriteString(Mask(pass)) + } + str.WriteString("@") + } + str.WriteString(u.Host) + p := u.Path + if p != "/" && p != "" { + str.WriteString("/") + if len(p) > 1 && p[0] == '/' { + str.WriteString(Mask(p[1:])) + } + } + var qs []string + for k, v := range u.Query() { + qs = append(qs, fmt.Sprintf("%s=%s", k, Mask(strings.Join(v, ",")))) + } + sort.Strings(qs) + if len(qs) > 0 { + str.WriteString("?") + str.WriteString(strings.Join(qs, "&")) + } + return str.String(), nil +} + +// MaskEmail masks the email address attempting to hide sensitive information. +func MaskEmail(val string) string { + tok := strings.Split(val, "@") + dot := strings.Split(tok[1], ".") + return Mask(tok[0]) + "@" + Mask(dot[0]) + "." + strings.Join(dot[1:], ".") +} + +var isURL = regexp.MustCompile(`^(\w+)://`) +var isEmail = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`) +var isJWT = regexp.MustCompile(`^[a-zA-Z0-9-_]+\.[a-zA-Z0-9-_]+\.[a-zA-Z0-9-_]+$`) + +// MaskArguments masks sensitive information in the given arguments. 
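A sketch combining InterpolateString with the masking helpers above (the keys and DSN are illustrative):

dsn, err := cstr.InterpolateString(
	"postgres://{DB_USER}:{!DB_PASS}@{DB_HOST:-localhost}/app",
	map[string]interface{}{"DB_USER": "svc", "DB_PASS": "s3cret"},
)
if err != nil {
	// {!DB_PASS} is required, so a missing or empty value surfaces here
	panic(err)
}
masked, _ := cstr.MaskURL(dsn) // postgres://s**:s3c***@localhost/a**
fmt.Println(masked)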
+func MaskArguments(args []string) []string { + masked := make([]string, len(args)) + for i, arg := range args { + if isURL.MatchString(arg) { + u, err := MaskURL(arg) + if err == nil { + masked[i] = u + } else { + masked[i] = Mask(arg) + } + } else if isEmail.MatchString(arg) { + masked[i] = MaskEmail(arg) + } else if isJWT.MatchString(arg) { + masked[i] = Mask(arg) + } else { + masked[i] = arg + } + } + return masked +} diff --git a/string/mask_test.go b/string/mask_test.go new file mode 100644 index 0000000..33a68ef --- /dev/null +++ b/string/mask_test.go @@ -0,0 +1,97 @@ +package string + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMasking(t *testing.T) { + testCases := []struct { + input string + expected string + }{ + {"foobar", "foo***"}, + {"foo", "f**"}, + {"f", "*"}, + {"", ""}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("Mask(%q)", tc.input), func(t *testing.T) { + output := Mask(tc.input) + assert.Equal(t, tc.expected, output) + }) + } +} + +func TestMaskUrl(t *testing.T) { + u, err := MaskURL("http://user:password@localhost:8080/path?query=1") + assert.NoError(t, err) + assert.Equal(t, "http://us**:pass****@localhost:8080/pa**?query=*", u) + + u, err = MaskURL("snowflake://FOO:thisisapassword@TFLXCJY-LU41011/TEST/PUBLIC") + assert.NoError(t, err) + assert.Equal(t, "snowflake://F**:thisisa********@TFLXCJY-LU41011/TEST/******", u) + + u, err = MaskURL("s3://bucket/folder?region=us-west-2&access-key-id=AKIAIOSFODNN7EXAMPLE&secret-access-key=wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY") + assert.NoError(t, err) + assert.Equal(t, "s3://bucket/fol***?access-key-id=AKIAIOSFOD**********®ion=us-w*****&secret-access-key=wJalrXUtnFEMI/K7MDEN********************", u) +} + +func TestMaskArguments(t *testing.T) { + tests := []struct { + name string + args []string + want []string + }{ + { + name: "Mask URL", + args: []string{"https://alice:bob@example.com/a/b?foo=bar", "http://user:password@localhost:8080/path?query=1", "s3://bucket/folder?region=us-west-2&access-key-id=AKIAIOSFODNN7EXAMPLE&secret-access", "mysql://user:password@localhost:3306/db?query=1"}, + want: []string{"https://al***:b**@example.com/a**?foo=b**", "http://us**:pass****@localhost:8080/pa**?query=*", "s3://bucket/fol***?access-key-id=AKIAIOSFOD**********®ion=us-w*****&secret-access=", "mysql://us**:pass****@localhost:3306/d*?query=*"}, + }, + { + name: "Mask Email", + args: []string{"user@example.com", "another.user@example.com"}, + want: []string{"us**@exa****.com", "anothe******@exa****.com"}, + }, + { + name: "Mask JWT", + args: []string{"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"}, + want: []string{"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpv******************************************************************************"}, + }, + { + name: "No Masking Needed", + args: []string{"hello", "world"}, + want: []string{"hello", "world"}, + }, + { + name: "Mixed Arguments", + args: []string{"http://example.com", "user@example.com", "hello"}, + want: []string{"http://example.com", "us**@exa****.com", "hello"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := MaskArguments(tt.args) + assert.Equal(t, tt.want, got) + }) + } + +} + +func TestMaskedEmail(t *testing.T) { + tests := []struct { + email string + expected string + }{ + {"test@example.com", "te**@exa****.com"}, + {"user@example.co.uk", 
"us**@exa****.co.uk"}, + } + for _, test := range tests { + result := MaskEmail(test.email) + assert.Equal(t, test.expected, result) + } +} diff --git a/string/random.go b/string/random.go new file mode 100644 index 0000000..cb6b427 --- /dev/null +++ b/string/random.go @@ -0,0 +1,58 @@ +package string + +import ( + "crypto/rand" + "fmt" + "io" + "math/big" +) + +// Adapted from https://elithrar.github.io/article/generating-secure-random-numbers-crypto-rand/ + +func init() { + assertAvailablePRNG() +} + +func assertAvailablePRNG() { + // Assert that a cryptographically secure PRNG is available. + // Panic otherwise. + buf := make([]byte, 1) + + _, err := io.ReadFull(rand.Reader, buf) + if err != nil { + panic(fmt.Sprintf("crypto/rand is unavailable: Read() failed with %#v", err)) + } +} + +// GenerateRandomBytes returns securely generated random bytes. +// It will return an error if the system's secure random +// number generator fails to function correctly, in which +// case the caller should not continue. +func GenerateRandomBytes(n int) ([]byte, error) { + b := make([]byte, n) + _, err := rand.Read(b) + // Note that err == nil only if we read len(b) bytes. + if err != nil { + return nil, err + } + + return b, nil +} + +// GenerateRandomString returns a securely generated random string. +// It will return an error if the system's secure random +// number generator fails to function correctly, in which +// case the caller should not continue. +func GenerateRandomString(n int) (string, error) { + const letters = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-" + ret := make([]byte, n) + for i := 0; i < n; i++ { + num, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters)))) + if err != nil { + return "", err + } + ret[i] = letters[num.Int64()] + } + + return string(ret), nil +} diff --git a/string/random_test.go b/string/random_test.go new file mode 100644 index 0000000..c25d7d1 --- /dev/null +++ b/string/random_test.go @@ -0,0 +1,27 @@ +package string + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGenerateRandomString(t *testing.T) { + str, err := GenerateRandomString(1) + assert.NoError(t, err) + assert.Len(t, str, 1) + t.Log(str) + str, err = GenerateRandomString(10) + assert.NoError(t, err) + assert.Len(t, str, 10) + t.Log(str) +} + +func TestGenerateRandomBytes(t *testing.T) { + str, err := GenerateRandomBytes(1) + assert.NoError(t, err) + assert.Len(t, str, 1) + str, err = GenerateRandomBytes(10) + assert.NoError(t, err) + assert.Len(t, str, 10) +} diff --git a/string/sha.go b/string/sha.go new file mode 100644 index 0000000..358935f --- /dev/null +++ b/string/sha.go @@ -0,0 +1,16 @@ +package string + +import ( + "crypto/sha256" + "encoding/hex" +) + +// SHA256 will return a sha 256 hash of the data in hex format +func SHA256(data []byte, extra ...[]byte) string { + h := sha256.New() + h.Write(data) + for _, d := range extra { + h.Write(d) + } + return hex.EncodeToString(h.Sum(nil)) +} diff --git a/string/sha_test.go b/string/sha_test.go new file mode 100644 index 0000000..c41965f --- /dev/null +++ b/string/sha_test.go @@ -0,0 +1,47 @@ +package string + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSHA256(t *testing.T) { + testCases := []struct { + name string + data []byte + extras [][]byte + want string + }{ + { + name: "single data", + data: []byte("foobar"), + want: "c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2", + }, + { + name: "data with extra", + data: []byte("foo"), + 
extras: [][]byte{[]byte("bar")}, + want: "c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2", + }, + { + name: "data with multiple extras", + data: []byte("f"), + extras: [][]byte{ + []byte("o"), + []byte("o"), + []byte("b"), + []byte("a"), + []byte("r"), + }, + want: "c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := SHA256(tc.data, tc.extras...) + assert.Equal(t, tc.want, got, "unexpected hash result") + }) + } +} diff --git a/string/string.go b/string/string.go new file mode 100644 index 0000000..581ac4f --- /dev/null +++ b/string/string.go @@ -0,0 +1,29 @@ +package string + +import "strings" + +// StringPointer returns a pointer to the trimmed string, or nil if the string is empty or only whitespace +func StringPointer(v string) *string { + if v != "" { + nv := strings.TrimSpace(v) + if nv == "" { + return nil + } else { + return &nv + } + } + return nil +} + +// ClearEmptyStringPointer will set the pointer to nil if the string is not nil but an empty string +func ClearEmptyStringPointer(v *string) *string { + if v != nil { + nv := strings.TrimSpace(*v) + if nv == "" { + return nil + } else { + return &nv + } + } + return nil +} diff --git a/string/string_test.go b/string/string_test.go new file mode 100644 index 0000000..d832ad0 --- /dev/null +++ b/string/string_test.go @@ -0,0 +1,27 @@ +package string + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStringPointer(t *testing.T) { + assert.Equal(t, "hi", *StringPointer("hi")) + assert.Equal(t, "hi", *StringPointer(" hi")) + assert.Equal(t, "hi", *StringPointer("hi ")) + assert.Equal(t, "hi", *StringPointer(" hi ")) + assert.Equal(t, "hi", *StringPointer("hi")) + assert.Equal(t, "hi", *StringPointer("hi")) + assert.Equal(t, "hi", *StringPointer("hi")) +} + +func TestClearEmptyStringPointer(t *testing.T) { + v := " hi " + assert.Equal(t, "hi", *ClearEmptyStringPointer(&v)) + v = "hi " + assert.Equal(t, "hi", *ClearEmptyStringPointer(&v)) + empty := " " + assert.Nil(t, ClearEmptyStringPointer(&empty)) + assert.Nil(t, ClearEmptyStringPointer(nil)) +} diff --git a/string/util.go b/string/util.go new file mode 100644 index 0000000..4f5f877 --- /dev/null +++ b/string/util.go @@ -0,0 +1,11 @@ +package string + +import "github.com/agentuity/go-common/slice" + +// Contains returns true if the needle string is found in the haystack slice +func Contains(haystack []string, needle string, caseInsensitive bool) bool { + if caseInsensitive { + return slice.Contains(haystack, needle, slice.WithCaseInsensitive()) + } + return slice.Contains(haystack, needle) +} diff --git a/string/util_test.go b/string/util_test.go new file mode 100644 index 0000000..fbd10c7 --- /dev/null +++ b/string/util_test.go @@ -0,0 +1,27 @@ +package string + +import ( + "testing" +) + +func TestContains(t *testing.T) { + tests := []struct { + haystack []string + needle string + caseInsensitive bool + expected bool + }{ + {[]string{"A", "B", "C"}, "A", false, true}, + {[]string{"A", "B", "C"}, "D", false, false}, + {[]string{"A", "B", "C"}, "a", true, true}, + {[]string{"A", "B", "C"}, "f", true, false}, + } + + for _, test := range tests { + result := Contains(test.haystack, test.needle, test.caseInsensitive) + if result != test.expected { + t.Errorf("Contains(%v, %v, %v) returned %v, expected %v", + test.haystack, test.needle, test.caseInsensitive, result, test.expected) + 
diff --git a/sys/docker.go b/sys/docker.go
new file mode 100644
index 0000000..b37dc06
--- /dev/null
+++ b/sys/docker.go
@@ -0,0 +1,22 @@
+package sys
+
+import (
+    "os"
+    "strings"
+)
+
+// IsRunningInsideDocker returns true if the process is running inside a docker container.
+func IsRunningInsideDocker() bool {
+    if Exists("/.dockerenv") {
+        return true
+    }
+
+    if Exists("/proc/1/cgroup") {
+        buf, _ := os.ReadFile("/proc/1/cgroup")
+        if len(buf) > 0 {
+            contents := strings.TrimSpace(string(buf))
+            return strings.Contains(contents, "docker") || strings.Contains(contents, "lxc") || strings.Contains(contents, "rt")
+        }
+    }
+    return false
+}
diff --git a/sys/errors.go b/sys/errors.go
new file mode 100644
index 0000000..ba33408
--- /dev/null
+++ b/sys/errors.go
@@ -0,0 +1,35 @@
+package sys
+
+import (
+    "os"
+    "runtime/pprof"
+    "strings"
+
+    "github.com/agentuity/go-common/logger"
+    "github.com/cockroachdb/errors"
+)
+
+// The call stack here is usually:
+// - panicError
+// - RecoverPanic
+// - panic()
+// so RecoverPanic should pop three frames.
+var depth = 3
+
+// RecoverPanic recovers from a panic and logs the error along with the current goroutines.
+func RecoverPanic(logger logger.Logger) {
+    if r := recover(); r != nil {
+        v := panicError(depth, r)
+        var str strings.Builder
+        pprof.Lookup("goroutine").WriteTo(&str, 2)
+        logger.Error("a panic has occurred: %s\ncurrent goroutines:\n\n%s", v, str.String())
+        os.Exit(2) // same exit code as panic
+    }
+}
+
+func panicError(depth int, r interface{}) error {
+    if err, ok := r.(error); ok {
+        return errors.WithStackDepth(err, depth+1)
+    }
+    return errors.NewWithDepthf(depth+1, "panic: %v", r)
+}
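// Illustrative usage of RecoverPanic above; this sketch is not part of the patch.
// The logger value would come from this module's logger package (logger/console.go
// and friends); its constructor is not shown in this section, so it is passed in
// rather than constructed here.
package example

import (
    "github.com/agentuity/go-common/logger"
    "github.com/agentuity/go-common/sys"
)

// startWorker launches a goroutine whose panics are logged (with a goroutine
// dump) and turned into a process exit with code 2 instead of a bare crash.
func startWorker(log logger.Logger) {
    go func() {
        // must be deferred so it runs while the panic is unwinding
        defer sys.RecoverPanic(log)
        var m map[string]int
        m["boom"] = 1 // panics: assignment to entry in nil map
    }()
}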
diff --git a/sys/io.go b/sys/io.go
new file mode 100644
index 0000000..afda704
--- /dev/null
+++ b/sys/io.go
@@ -0,0 +1,432 @@
+package sys
+
+import (
+    "archive/tar"
+    "archive/zip"
+    "compress/gzip"
+    "encoding/json"
+    "fmt"
+    "io"
+    "os"
+    "path"
+    "path/filepath"
+    "strings"
+)
+
+// CopyFile will copy src to dst
+func CopyFile(src, dst string) (int64, error) {
+    sourceFileStat, err := os.Stat(src)
+    if err != nil {
+        return 0, err
+    }
+
+    if !sourceFileStat.Mode().IsRegular() {
+        return 0, fmt.Errorf("%s is not a regular file", src)
+    }
+
+    source, err := os.Open(src)
+    if err != nil {
+        return 0, err
+    }
+    defer source.Close()
+
+    destination, err := os.Create(dst)
+    if err != nil {
+        return 0, err
+    }
+    defer destination.Close()
+    nBytes, err := io.Copy(destination, source)
+    return nBytes, err
+}
+
+// CopyDir will copy all files recursively from src to dst
+func CopyDir(src string, dst string) error {
+    var err error
+    var fds []os.DirEntry
+    var srcinfo os.FileInfo
+
+    if srcinfo, err = os.Stat(src); err != nil {
+        return fmt.Errorf("error reading %s: %w", src, err)
+    }
+
+    if err = os.MkdirAll(dst, srcinfo.Mode()); err != nil {
+        return fmt.Errorf("error mkdir %s: %w", dst, err)
+    }
+
+    if fds, err = os.ReadDir(src); err != nil {
+        return fmt.Errorf("error readdir %s: %w", src, err)
+    }
+    for _, fd := range fds {
+        srcfp := path.Join(src, fd.Name())
+        dstfp := path.Join(dst, fd.Name())
+
+        if fd.IsDir() {
+            if err = CopyDir(srcfp, dstfp); err != nil {
+                return fmt.Errorf("error copying directory from %s to %s: %w", srcfp, dstfp, err)
+            }
+        } else {
+            if _, err = CopyFile(srcfp, dstfp); err != nil {
+                return fmt.Errorf("error copying file from %s to %s: %w", srcfp, dstfp, err)
+            }
+        }
+    }
+    return nil
+}
+
+// Exists returns true if the filename or directory specified by fn exists.
+func Exists(fn string) bool {
+    if _, err := os.Stat(fn); os.IsNotExist(err) {
+        return false
+    }
+    return true
+}
+
+// ListDir will return an array of files recursively walking into sub directories
+func ListDir(dir string) ([]string, error) {
+    files, err := os.ReadDir(dir)
+    if err != nil {
+        return nil, err
+    }
+    res := make([]string, 0)
+    for _, file := range files {
+        if file.IsDir() {
+            newres, err := ListDir(filepath.Join(dir, file.Name()))
+            if err != nil {
+                return nil, err
+            }
+            res = append(res, newres...)
+        } else {
+            if file.Name() == ".DS_Store" {
+                continue
+            }
+            res = append(res, filepath.Join(dir, file.Name()))
+        }
+    }
+    return res, nil
+}
+
+// GzipFile compresses a file using gzip.
+func GzipFile(fn string) error {
+    infile, err := os.Open(fn)
+    if err != nil {
+        return fmt.Errorf("open: %w", err)
+    }
+    defer infile.Close()
+
+    outfile, err := os.Create(fn + ".gz")
+    if err != nil {
+        return fmt.Errorf("create: %w", err)
+    }
+    defer outfile.Close()
+
+    zw := gzip.NewWriter(outfile)
+    defer zw.Close()
+    _, err = io.Copy(zw, infile)
+    if err != nil {
+        return fmt.Errorf("copy: %w", err)
+    }
+
+    return nil
+}
+
+func TarGz(srcDir string, outfile *os.File) error {
+    zw := gzip.NewWriter(outfile)
+    tw := tar.NewWriter(zw)
+
+    baseDir := filepath.Base(srcDir)
+    // walk through every file in the folder
+    err := filepath.Walk(srcDir, func(file string, fi os.FileInfo, err error) error {
+        if err != nil {
+            return err
+        }
+        // generate tar header
+        header, err := tar.FileInfoHeader(fi, file)
+        if err != nil {
+            return err
+        }
+
+        header.Name = baseDir + strings.Replace(filepath.ToSlash(file), srcDir, "", -1)
+
+        // write header
+        if err := tw.WriteHeader(header); err != nil {
+            return err
+        }
+        // if not a dir, write file content
+        if !fi.IsDir() {
+            data, err := os.Open(file)
+            if err != nil {
+                return err
+            }
+            defer data.Close()
+            if _, err := io.Copy(tw, data); err != nil {
+                return err
+            }
+        }
+        return nil
+    })
+    if err != nil {
+        return err
+    }
+
+    // produce tar
+    if err := tw.Close(); err != nil {
+        return err
+    }
+
+    // produce gzip
+    if err := zw.Close(); err != nil {
+        return err
+    }
+
+    return nil
+}
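// Illustrative usage of the file helpers above; this sketch is not part of the
// patch, and the ./testdata and /tmp paths are just placeholders. It copies a
// tree, lists what was copied, and gzips each copied file (GzipFile writes
// <name>.gz next to the original).
package main

import (
    "fmt"
    "log"

    "github.com/agentuity/go-common/sys"
)

func main() {
    if err := sys.CopyDir("./testdata", "/tmp/testdata-copy"); err != nil {
        log.Fatal(err)
    }
    files, err := sys.ListDir("/tmp/testdata-copy")
    if err != nil {
        log.Fatal(err)
    }
    for _, f := range files {
        if err := sys.GzipFile(f); err != nil {
            log.Fatal(err)
        }
        fmt.Println("compressed", f+".gz")
    }
}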
+
+// JSONEncoder is an encoder that will allow you to encode one or more objects as JSON newline delimited output
+type JSONEncoder interface {
+    // Encode will encode v as a new line delimited JSON encoded string
+    Encode(v any) error
+    // Close a stream
+    Close() error
+    // Count returns the number of records written
+    Count() int
+}
+
+type JSONDecoder interface {
+    Decode(v any) error
+    // More returns true if there are more items in the stream
+    More() bool
+    // Count returns the number of records read
+    Count() int
+    // Close a stream
+    Close() error
+}
+
+type ndjsonWriter struct {
+    out   *os.File
+    gz    *gzip.Writer
+    enc   *json.Encoder
+    count int
+}
+
+var _ JSONEncoder = (*ndjsonWriter)(nil)
+
+func (n *ndjsonWriter) Close() error {
+    if n.gz != nil {
+        n.gz.Close()
+        n.gz = nil
+    }
+    if n.out != nil {
+        n.out.Close()
+        n.out = nil
+    }
+    return nil
+}
+
+func (n *ndjsonWriter) Encode(v any) error {
+    n.count++
+    return n.enc.Encode(v)
+}
+
+func (n *ndjsonWriter) Count() int {
+    return n.count
+}
+
+type ndjsonReader struct {
+    in    *os.File
+    gr    *gzip.Reader
+    dec   *json.Decoder
+    count int
+}
+
+var _ JSONDecoder = (*ndjsonReader)(nil)
+
+func (n *ndjsonReader) Count() int {
+    return n.count
+}
+
+func (n *ndjsonReader) Close() error {
+    if n.gr != nil {
+        n.gr.Close()
+        n.gr = nil
+    }
+    if n.in != nil {
+        n.in.Close()
+        n.in = nil
+    }
+    return nil
+}
+
+func (n *ndjsonReader) More() bool {
+    return n.dec.More()
+}
+
+func (n *ndjsonReader) Decode(v any) error {
+    if err := n.dec.Decode(v); err != nil {
+        return err
+    }
+    n.count++
+    return nil
+}
+
+// NewNDJSONEncoder will return a JSONEncoder which allows you to stream json as new line delimited JSON
+func NewNDJSONEncoder(fn string) (JSONEncoder, error) {
+    out, err := os.Create(fn)
+    if err != nil {
+        return nil, fmt.Errorf("error opening: %s. %w", fn, err)
+    }
+    var o io.Writer = out
+    var gw *gzip.Writer
+    if filepath.Ext(fn) == ".gz" {
+        gw = gzip.NewWriter(out)
+        o = gw
+    }
+    jw := json.NewEncoder(o)
+    return &ndjsonWriter{
+        out: out,
+        gz:  gw,
+        enc: jw,
+    }, nil
+}
+
+// NewNDJSONEncoderAppend is like NewNDJSONEncoder but appends to fn if it already exists
+func NewNDJSONEncoderAppend(fn string) (JSONEncoder, error) {
+    out, err := os.OpenFile(fn, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
+
+    if err != nil {
+        return nil, fmt.Errorf("error opening: %s. %w", fn, err)
+    }
+    var o io.Writer = out
+    var gw *gzip.Writer
+    if filepath.Ext(fn) == ".gz" {
+        gw = gzip.NewWriter(out)
+        o = gw
+    }
+    jw := json.NewEncoder(o)
+    return &ndjsonWriter{
+        out: out,
+        gz:  gw,
+        enc: jw,
+    }, nil
+}
+
+// NewNDJSONDecoder returns a decoder which can be used to read JSON new line delimited files
+func NewNDJSONDecoder(fn string) (JSONDecoder, error) {
+    in, err := os.Open(fn)
+    if err != nil {
+        return nil, fmt.Errorf("error opening: %s. %w", fn, err)
+    }
+    var i io.Reader = in
+    var gr *gzip.Reader
+    if filepath.Ext(fn) == ".gz" {
+        var err error
+        gr, err = gzip.NewReader(in)
+        if err != nil {
+            return nil, fmt.Errorf("gzip: error opening: %s. %w", fn, err)
+        }
+        i = gr
+    }
+    je := json.NewDecoder(i)
+    return &ndjsonReader{
+        in:  in,
+        gr:  gr,
+        dec: je,
+    }, nil
+}
+
+// WriteJSON encodes v as JSON and writes it to filename
+func WriteJSON(filename string, v any) error {
+    f, err := os.Create(filename)
+    if err != nil {
+        return err
+    }
+    defer f.Close()
+    if err := json.NewEncoder(f).Encode(v); err != nil {
+        return err
+    }
+    return nil
+}
+
+// WriteJSONLAppend appends each item in v to filename as new line delimited JSON
+func WriteJSONLAppend(filename string, v []interface{}) error {
+    enc, err := NewNDJSONEncoderAppend(filename)
+    if err != nil {
+        return err
+    }
+    defer enc.Close()
+    for _, item := range v {
+        if err := enc.Encode(item); err != nil {
+            return err
+        }
+    }
+    return nil
+}
+
+type ProcessDirWithDecoderCallback func(decoder JSONDecoder, filename string) error
+
+// ProcessDirWithDecoder will process all the JSON files in directory and call the callback
+func ProcessDirWithDecoder(dir string, callback ProcessDirWithDecoderCallback) error {
+    files, err := ListDir(dir)
+    if err != nil {
+        return err
+    }
+    for _, file := range files {
+        filename := file
+        if filepath.Ext(filename) == ".json" || filepath.Ext(filename) == ".gz" {
+            dec, err := NewNDJSONDecoder(filename)
+            if err != nil {
+                return err
+            }
+            defer dec.Close()
+            if err := callback(dec, filename); err != nil {
+                return err
+            }
+            dec.Close()
+        }
+    }
+    return nil
+}
+
+// Unzip a file to a directory
+func Unzip(src, dest string, flatten bool) error {
+    r, err := zip.OpenReader(src)
+    if err != nil {
+        return err
+    }
+    defer r.Close()
+
+    for _, f := range r.File {
+        rc, err := f.Open()
+        if err != nil {
+            return err
+        }
+
+        name := f.Name
+        if flatten {
+            name = filepath.Base(name)
+        }
+
+        fpath := filepath.Join(dest, name)
+        if f.FileInfo().IsDir() && !flatten {
+            os.MkdirAll(fpath, os.ModePerm)
+            rc.Close()
+        } else {
+            fdir := filepath.Dir(fpath)
+            if err := os.MkdirAll(fdir, os.ModePerm); err != nil {
+                rc.Close()
+                return err
+            }
+            out, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
+            if err != nil {
+                rc.Close()
+                return err
+            }
+            _, err = io.Copy(out, rc)
+            out.Close()
+            rc.Close()
+            if err != nil {
+                return err
+            }
+        }
+    }
+    return nil
+}
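// Illustrative round trip through the NDJSON helpers above; this sketch is not
// part of the patch and the /tmp path is a placeholder. Because the constructors
// switch to gzip when the filename ends in .gz, the same code works for both
// events.ndjson and events.ndjson.gz.
package main

import (
    "fmt"
    "io"
    "log"

    "github.com/agentuity/go-common/sys"
)

type event struct {
    ID   int    `json:"id"`
    Name string `json:"name"`
}

func main() {
    enc, err := sys.NewNDJSONEncoder("/tmp/events.ndjson.gz")
    if err != nil {
        log.Fatal(err)
    }
    for i := 0; i < 3; i++ {
        if err := enc.Encode(event{ID: i, Name: "example"}); err != nil {
            log.Fatal(err)
        }
    }
    enc.Close() // flushes and closes the gzip stream; required before reading the file back

    dec, err := sys.NewNDJSONDecoder("/tmp/events.ndjson.gz")
    if err != nil {
        log.Fatal(err)
    }
    defer dec.Close()
    for dec.More() {
        var e event
        if err := dec.Decode(&e); err != nil {
            if err == io.EOF {
                break
            }
            log.Fatal(err)
        }
        fmt.Printf("%d: %+v\n", dec.Count(), e)
    }
}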
diff --git a/sys/io_test.go b/sys/io_test.go
new file mode 100644
index 0000000..7a77a4b
--- /dev/null
+++ b/sys/io_test.go
@@ -0,0 +1,45 @@
+package sys
+
+import (
+    "archive/zip"
+    "os"
+    "path/filepath"
+    "testing"
+
+    "github.com/stretchr/testify/assert"
+)
+
+func TestTarGz(t *testing.T) {
+    baseDir := t.TempDir()
+
+    dir := filepath.Join(baseDir, "test")
+    assert.NoError(t, os.MkdirAll(dir, 0755))
+
+    assert.NoError(t, os.WriteFile(filepath.Join(dir, "foo"), []byte("bar"), 0644))
+    assert.NoError(t, os.Mkdir(filepath.Join(dir, "nested"), 0755))
+    assert.NoError(t, os.WriteFile(filepath.Join(dir, "nested", "foo2"), []byte("bar2"), 0644))
+
+    tarball, err := os.Create(filepath.Join(baseDir, "test.tar.gz"))
+    assert.NoError(t, err)
+    assert.NoError(t, TarGz(dir, tarball))
+    tarball.Close()
+
+}
+
+func TestUnzip(t *testing.T) {
+    baseDir := t.TempDir()
+    zf, err := os.Create(filepath.Join(baseDir, "foobar.zip"))
+    assert.NoError(t, err)
+    zw := zip.NewWriter(zf)
+    w, err := zw.Create("foo/foo.txt")
+    assert.NoError(t, err)
+    w.Write([]byte("bar"))
+    zw.Close()
+    zf.Close()
+    assert.NoError(t, Unzip(filepath.Join(baseDir, "foobar.zip"), baseDir, true))
+    assert.True(t, Exists(filepath.Join(baseDir, "foo.txt")))
+    os.Remove(filepath.Join(baseDir, "foo.txt"))
+    assert.NoError(t, Unzip(filepath.Join(baseDir, "foobar.zip"), baseDir, false))
+    assert.False(t, Exists(filepath.Join(baseDir, "foo.txt")))
+    assert.True(t, Exists(filepath.Join(baseDir, "foo", "foo.txt")))
+}
diff --git a/sys/ip.go b/sys/ip.go
new file mode 100644
index 0000000..30c3572
--- /dev/null
+++ b/sys/ip.go
@@ -0,0 +1,44 @@
+package sys
+
+import (
+    "errors"
+    "net"
+)
+
+// LocalIP will return the local IP address for the machine
+func LocalIP() (string, error) {
+    ifaces, err := net.Interfaces()
+    if err != nil {
+        return "", err
+    }
+    for _, iface := range ifaces {
+        if iface.Flags&net.FlagUp == 0 {
+            continue // interface down
+        }
+        if iface.Flags&net.FlagLoopback != 0 {
+            continue // loopback interface
+        }
+        addrs, err := iface.Addrs()
+        if err != nil {
+            return "", err
+        }
+        for _, addr := range addrs {
+            var ip net.IP
+            switch v := addr.(type) {
+            case *net.IPNet:
+                ip = v.IP
+            case *net.IPAddr:
+                ip = v.IP
+            }
+            if ip == nil || ip.IsLoopback() {
+                continue
+            }
+            ip = ip.To4()
+            if ip == nil {
+                continue // not an ipv4 address
+            }
+            return ip.String(), nil
+        }
+    }
+    return "", errors.New("are you connected to the network?")
+}
diff --git a/sys/net.go b/sys/net.go
new file mode 100644
index 0000000..3df151c
--- /dev/null
+++ b/sys/net.go
@@ -0,0 +1,25 @@
+package sys
+
+import (
+    "net"
+    "strings"
+)
+
+// GetFreePort asks the kernel for a free open port that is ready to use.
+func GetFreePort() (port int, err error) {
+    var a *net.TCPAddr
+    if a, err = net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
+        var l *net.TCPListener
+        if l, err = net.ListenTCP("tcp", a); err == nil {
+            defer l.Close()
+            return l.Addr().(*net.TCPAddr).Port, nil
+        }
+    }
+    return
+}
+
+// IsLocalhost returns true if the URL is localhost or 127.0.0.1 or 0.0.0.0.
+func IsLocalhost(url string) bool {
+    // technically 127.0.0.0 – 127.255.255.255 is the loopback range but most people use 127.0.0.1
+    return strings.Contains(url, "localhost") || strings.Contains(url, "127.0.0.1") || strings.Contains(url, "0.0.0.0")
+}
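// Illustrative usage of the network helpers above; this sketch is not part of
// the patch. It picks a free port, reports the machine's local IPv4 address,
// and serves a trivial handler on that port.
package main

import (
    "fmt"
    "log"
    "net/http"

    "github.com/agentuity/go-common/sys"
)

func main() {
    port, err := sys.GetFreePort()
    if err != nil {
        log.Fatal(err)
    }
    ip, err := sys.LocalIP()
    if err != nil {
        log.Fatal(err)
    }
    url := fmt.Sprintf("http://%s:%d", ip, port)
    fmt.Println(url, sys.IsLocalhost(url)) // false unless the chosen interface is a loopback alias

    // note: the port returned by GetFreePort is only a hint; another process
    // could claim it between the check and the ListenAndServe call below.
    log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", port), http.NotFoundHandler()))
}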
diff --git a/sys/shutdown.go b/sys/shutdown.go
new file mode 100644
index 0000000..114b2d4
--- /dev/null
+++ b/sys/shutdown.go
@@ -0,0 +1,14 @@
+package sys
+
+import (
+    "os"
+    "os/signal"
+    "syscall"
+)
+
+// CreateShutdownChannel returns a channel which can be used to block for a termination signal (SIGTERM, SIGINT, etc)
+func CreateShutdownChannel() chan os.Signal {
+    done := make(chan os.Signal, 1)
+    signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
+    return done
+}
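// Illustrative usage of CreateShutdownChannel above; this sketch is not part of
// the patch. It blocks main until SIGINT/SIGTERM arrives, then shuts an HTTP
// server down gracefully with a timeout.
package main

import (
    "context"
    "log"
    "net/http"
    "time"

    "github.com/agentuity/go-common/sys"
)

func main() {
    srv := &http.Server{Addr: ":8080"}
    go func() {
        if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
            log.Fatal(err)
        }
    }()

    <-sys.CreateShutdownChannel() // blocks until a termination signal is received

    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
    defer cancel()
    if err := srv.Shutdown(ctx); err != nil {
        log.Printf("graceful shutdown failed: %v", err)
    }
}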