Skip to content

Commit

Permalink
Fixed viewer issues, fixed and added tests to cmd and util packages
Browse files Browse the repository at this point in the history
Co-Authored-By: Naomi Kramer <[email protected]>
  • Loading branch information
lisaSW and caffeinatedpixel committed Jul 17, 2024
1 parent 9faeff5 commit c450224
Show file tree
Hide file tree
Showing 15 changed files with 1,355 additions and 277 deletions.
18 changes: 10 additions & 8 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ import (
var ErrMissingDatabaseName = errors.New("database name is required")
var ErrMissingConfigPath = errors.New("config path parameter is required")
var ErrTooManyArguments = errors.New("too many arguments provided")
var ErrInvalidConfigObject = errors.New("config was nil or invalid")
var ErrCurrentVersionEmpty = errors.New("current version unset")
var ErrCheckingForUpdate = errors.New("error checking for newer version of RITA")

func Commands() []*cli.Command {
return []*cli.Command{
Expand All @@ -39,21 +42,20 @@ func ConfigFlag(required bool) *cli.StringFlag {
}
}

func CheckForUpdate(cCtx *cli.Context, afs afero.Fs) error {
func CheckForUpdate(cfg *config.Config) error {
// make sure config is not nil
if cfg == nil {
return ErrInvalidConfigObject
}

// get the current version
currentVersion := config.Version

// load config file
cfg, err := config.ReadFileConfig(afs, cCtx.String("config"))
if err != nil {
return fmt.Errorf("error loading config file: %w", err)
}

// check for update if version is set
if cfg.UpdateCheckEnabled && currentVersion != "" && currentVersion != "dev" {
newer, latestVersion, err := util.CheckForNewerVersion(github.NewClient(nil), currentVersion)
if err != nil {
return fmt.Errorf("error checking for newer version of RITA: %w", err)
return fmt.Errorf("%w: %w", ErrCheckingForUpdate, err)
}
if newer {
fmt.Printf("\n\t✨ A newer version (%s) of RITA is available! https://github.com/activecm/rita/releases ✨\n\n", latestVersion)
Expand Down
131 changes: 116 additions & 15 deletions cmd/cmd_test.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
package cmd_test

import (
"bytes"
"context"
"fmt"
"log"
"os"
"path/filepath"
"testing"

"github.com/activecm/rita/v5/cmd"
"github.com/activecm/rita/v5/config"
"github.com/activecm/rita/v5/database"
"github.com/activecm/rita/v5/util"
"github.com/google/go-github/github"

"github.com/joho/godotenv"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/testcontainers/testcontainers-go"
Expand Down Expand Up @@ -91,20 +96,7 @@ func (c *CmdTestSuite) TearDownSuite() {
// func (d *DatabaseTestSuite) TearDownTest() {}

// SetupSubTest is run before each subtest
func (c *CmdTestSuite) SetupSubTest() {
t := c.T()
fmt.Println("Running setup subtest...")

// drop all databases that may have been created during subtest
if c.server != nil && c.server.Conn != nil {
dbs, err := c.server.ListImportDatabases()
require.NoError(t, err, "listing databases should not produce an error")
for _, db := range dbs {
err := c.server.DeleteSensorDB(db.Name)
require.NoError(t, err, "dropping database should not produce an error")
}
}
}
// func (c *CmdTestSuite) SetupSubTest() {}

// TearDownSubTest is run after each subtest
// func (c *CmdTestSuite) TearDownSubTest() {}
Expand Down Expand Up @@ -148,8 +140,117 @@ func setupTestApp(commands []*cli.Command, flags []cli.Flag) (*cli.App, context.
// this prevents the test from exiting when testing for errors
app.ExitErrHandler = func(_ *cli.Context, _ error) {
// add any custom test logic, or assertions or leave it blank

}

return app, ctx
}

func TestCheckForUpdate(t *testing.T) {
// set up file system interface
afs := afero.NewOsFs()

// load the config file
cfg, err := config.ReadFileConfig(afs, ConfigPath)
require.NoError(t, err, "config should load without error")

// get latest release version
latestVersion, err := util.GetLatestReleaseVersion(github.NewClient(nil), "activecm", "rita")
require.NoError(t, err, "latest release version should be retrieved without error")

tests := []struct {
name string
cfg *config.Config
updateCheckEnabled bool
currentVersion string
expectedErr error
expectedOutput string
}{
{
name: "New version available",
updateCheckEnabled: true,
cfg: cfg,
currentVersion: "v0.0.0",
expectedOutput: fmt.Sprintf("\n\t✨ A newer version (%s) of RITA is available! https://github.com/activecm/rita/releases ✨\n\n", latestVersion),
},
{
name: "Error checking for newer version",
updateCheckEnabled: true,
cfg: cfg,
currentVersion: "notaversion",
expectedErr: cmd.ErrCheckingForUpdate,
},
{
name: "Update check disabled",
updateCheckEnabled: false,
cfg: cfg,
currentVersion: "1.0.0",
},
{
name: "Current version is dev",
updateCheckEnabled: true,
cfg: cfg,
currentVersion: "dev",
},
{
name: "Current version is empty",
updateCheckEnabled: true,
cfg: cfg,
currentVersion: "",
},
{
name: "Nil config",
cfg: nil,
expectedErr: cmd.ErrInvalidConfigObject,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// set update check enabled in config
if test.cfg != nil {
test.cfg.UpdateCheckEnabled = test.updateCheckEnabled
}

// override global variables and functions
config.Version = test.currentVersion

// capture stdout
output := captureOutput(t, func() {
err := cmd.CheckForUpdate(test.cfg)
// check error
if test.expectedErr != nil {
require.Contains(t, err.Error(), test.expectedErr.Error(), "error should contain expected value")
} else {
assert.NoError(t, err)
}
})

// Assert output
if test.expectedOutput != "" {
assert.Equal(t, test.expectedOutput, output)
}
})
}
}

// captureOutput captures stdout from a function
func captureOutput(t *testing.T, f func()) string {
t.Helper()

// capture stdout
old := os.Stdout
r, w, err := os.Pipe()
require.NoError(t, err)
os.Stdout = w

// run the function
f()

// close and restore stdout
w.Close()
os.Stdout = old
var buf bytes.Buffer
_, err = buf.ReadFrom(r)
require.NoError(t, err)
return buf.String()
}
2 changes: 1 addition & 1 deletion cmd/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ var DeleteCommand = &cli.Command{
}

// check for updates after running the command
if err := CheckForUpdate(cCtx, afero.NewOsFs()); err != nil {
if err := CheckForUpdate(cfg); err != nil {
return err
}

Expand Down
8 changes: 6 additions & 2 deletions cmd/delete_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,12 @@ func (c *CmdTestSuite) TestRunDeleteCmd() {
}

// validate that the expected databases remain
for _, db := range test.expectedRemainingDbs {
require.Contains(t, dbString, db, "database %s should not have been deleted", db)
require.ElementsMatch(t, test.expectedRemainingDbs, dbString, "remaining databases should match expected value")

// cleanup
for _, db := range test.dbs {
err := c.server.DeleteSensorDB(db.name)
require.NoError(t, err, "dropping database should not produce an error")
}

})
Expand Down
14 changes: 10 additions & 4 deletions cmd/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ var ErrInvalidLogHourRange = errors.New("could not parse hour from log file name
var ErrInvalidLogType = errors.New("incompatible log type")
var ErrIncompatibleFileExtension = errors.New("incompatible file extension")
var ErrSkippedDuplicateLog = errors.New("encountered file with same name but different extension, skipping file due to older last modified time")
var ErrMissingLogDirectory = errors.New("log directory flag is required")

type WalkError struct {
Path string
Expand Down Expand Up @@ -110,7 +111,7 @@ var ImportCommand = &cli.Command{
}

// check for updates after running the command
if err := CheckForUpdate(cCtx, afero.NewOsFs()); err != nil {
if err := CheckForUpdate(cfg); err != nil {
return err
}

Expand Down Expand Up @@ -303,7 +304,7 @@ func RunImportCmd(startTime time.Time, cfg *config.Config, afs afero.Fs, logDir

func ValidateLogDirectory(afs afero.Fs, logDir string) error {
if logDir == "" {
return fmt.Errorf("log directory flag is required")
return ErrMissingLogDirectory
}

dir, err := util.ParseRelativePath(logDir)
Expand Down Expand Up @@ -343,7 +344,12 @@ func ValidateDatabaseName(name string) error {
return nil
}

func parseFolderDate(folder string) (time.Time, error) {
// ParseFolderDate extracts the date from a given folder name
func ParseFolderDate(folder string) (time.Time, error) {
if folder == "" {
return time.Unix(0, 0), errors.New("folder name cannot be empty")
}

// check if the path is a directory
folderDate, err := time.Parse(time.DateOnly, folder)
if err != nil {
Expand Down Expand Up @@ -489,7 +495,7 @@ func WalkFiles(afs afero.Fs, root string) ([]HourlyZeekLogs, []WalkError, error)
}

parentDir := filepath.Base(filepath.Dir(file.path))
folderDate, err := parseFolderDate(parentDir)
folderDate, err := ParseFolderDate(parentDir)
if err != nil {
walkErrors = append(walkErrors, WalkError{Path: path, Error: err})
}
Expand Down
Loading

0 comments on commit c450224

Please sign in to comment.