Skip to content

Commit

Permalink
s.cfg.Input unchanged throughout the program
Browse files Browse the repository at this point in the history
  • Loading branch information
BrendanCoughlan5 committed Feb 12, 2025
1 parent 93f51a8 commit 0ed2257
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 154 deletions.
158 changes: 66 additions & 92 deletions pkg/snapshot/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,8 @@ func (s *SnapshotService) CreateSnapshot() error {
if err != nil {
return errors.Wrap(err, fmt.Sprintf("failed to resolve output file path '%s'", s.cfg.OutputFile))
}
s.cfg.OutputFile = resolvedFilePath

s.l.Sugar().Debugw("Resolved file paths", zap.String("Output", s.cfg.OutputFile))
s.l.Sugar().Debugw("Resolved file paths", zap.String("Output", resolvedFilePath))

if err := s.validateCreateSnapshotConfig(); err != nil {
return err
Expand All @@ -108,8 +107,8 @@ func (s *SnapshotService) CreateSnapshot() error {

s.l.Sugar().Infow("Successfully created snapshot")

outputHashFile := getHashName(s.cfg.OutputFile)
if err := saveOutputFileHash(s.cfg.OutputFile, outputHashFile); err != nil {
outputHashFile := getHashName(resolvedFilePath)
if err := saveOutputFileHash(resolvedFilePath, outputHashFile); err != nil {
return errors.Wrap(err, fmt.Sprintf("failed to save output file hash '%s'", outputHashFile))
}

Expand All @@ -123,95 +122,91 @@ func (s *SnapshotService) RestoreSnapshot() error {
s.tempFiles = s.tempFiles[:0] // Clear the tempFiles slice after cleanup
}()

if err := s.resolveAndDownloadRestoreInput(); err != nil {
return err
}

if err := s.validateRestoreConfig(); err != nil {
return err
}

restore, err := s.setupRestore()
if err != nil {
return err
}

restoreExec := restore.Exec(s.cfg.Input, pgcommands.ExecOptions{StreamPrint: false})
if restoreExec.Error != nil {
s.l.Sugar().Errorw("Failed to restore from snapshot",
zap.Error(restoreExec.Error.Err),
zap.String("output", restoreExec.Output),
)
return restoreExec.Error.Err
}

s.l.Sugar().Infow("Successfully restored from snapshot")
return nil
}

// resolveAndDownloadRestoreInput prepares the SnapshotService struct into a suitable format for restoring a snapshot
// and downloads the necessary files if necessary
func (s *SnapshotService) resolveAndDownloadRestoreInput() error {
var resolvedFilePath string
if isHttpURL(s.cfg.Input) {
inputUrl := s.cfg.Input

// Check if the input URL exists
snapshotExists, err := urlExists(inputUrl)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("error checking existence of snapshot URL '%s'", inputUrl))
}
if !snapshotExists {
return fmt.Errorf("snapshot file not found at '%s'. Ensure the file exists", inputUrl)
return errors.Wrap(fmt.Errorf("snapshot file not found at '%s'. Ensure the file exists", inputUrl), "snapshot file not found")
}

// Check if the hash URL exists
fileName, err := getFileNameFromURL(inputUrl)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("failed to extract file name from URL '%s'", inputUrl))
}

inputFilePath := filepath.Join(os.TempDir(), fileName)

// Download the snapshot file hash
if s.cfg.VerifyInput {
hashFilePath := getHashName(inputFilePath)
hashUrl := getHashName(inputUrl)

hashFileExists, err := urlExists(hashUrl)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("error checking existence of snapshot hash URL '%s'", hashUrl))
}
if !hashFileExists {
return fmt.Errorf("snapshot hash file not found at '%s'. Ensure the file exists or set --verify-input=false to skip verification", hashUrl)
return errors.Wrap(fmt.Errorf("snapshot hash file not found at '%s'. Ensure the file exists or set --verify-input=false to skip verification", hashUrl), "snapshot hash file not found")
}

err = downloadFile(hashUrl, hashFilePath)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("error downloading snapshot hash from '%s'", hashUrl))
}
s.tempFiles = append(s.tempFiles, hashFilePath)
}

// Download the snapshot file and assign to s.cfg.Input
fileName, err := getFileNameFromURL(inputUrl)
// Download the snapshot file
err = downloadFile(inputUrl, inputFilePath)
if err != nil {
return fmt.Errorf("failed to extract file name from URL: %w", err)
return errors.Wrap(err, fmt.Sprintf("error downloading snapshot from '%s'", inputUrl))
}
s.tempFiles = append(s.tempFiles, inputFilePath)

inputFilePath := filepath.Join(os.TempDir(), fileName)

s.cfg.Input, err = downloadFile(inputUrl, inputFilePath)
resolvedFilePath, err = resolveFilePath(inputFilePath)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("error downloading snapshot from '%s'", inputUrl))
return errors.Wrap(err, fmt.Sprintf("failed to resolve input file path '%s'", s.cfg.Input))
}
s.tempFiles = append(s.tempFiles, s.cfg.Input)
} else {
var err error
resolvedFilePath, err = resolveFilePath(s.cfg.Input)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("failed to resolve input file path '%s'", s.cfg.Input))
}
}

// Download the snapshot file hash
if s.cfg.VerifyInput {
hashFilePath := getHashName(inputFilePath)
hashUrl := getHashName(inputUrl)
hashFile, err := downloadFile(hashUrl, hashFilePath)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("error downloading snapshot hash from '%s'", hashUrl))
}
s.tempFiles = append(s.tempFiles, hashFile)
s.l.Sugar().Debugw("Resolved file path", zap.String("resolvedFilePath", resolvedFilePath))

// validate snapshot against the hash file
if s.cfg.VerifyInput {
if err := validateInputFileHash(resolvedFilePath, getHashName(resolvedFilePath)); err != nil {
return errors.Wrap(err, fmt.Sprintf("input file hash validation failed for '%s'", resolvedFilePath))
}
s.l.Sugar().Debugw("Input file hash validated successfully",
zap.String("input", resolvedFilePath),
zap.String("inputHashFile", getHashName(resolvedFilePath)),
)
}

resolvedFilePath, err := resolveFilePath(s.cfg.Input)
restore, err := s.setupRestore()
if err != nil {
return errors.Wrap(err, fmt.Sprintf("failed to resolve input file path '%s'", s.cfg.Input))
return err
}
s.cfg.Input = resolvedFilePath

s.l.Sugar().Debugw("Resolved file paths", zap.String("Input", s.cfg.Input))
restoreExec := restore.Exec(resolvedFilePath, pgcommands.ExecOptions{StreamPrint: false})
if restoreExec.Error != nil {
s.l.Sugar().Errorw("Failed to restore from snapshot",
zap.Error(restoreExec.Error.Err),
zap.String("output", restoreExec.Output),
)
return restoreExec.Error.Err
}

s.l.Sugar().Infow("Successfully restored from snapshot")
return nil
}

Expand Down Expand Up @@ -243,34 +238,13 @@ func (s *SnapshotService) setupSnapshotDump() (*pgcommands.Dump, error) {
if s.cfg.SchemaName != "" {
dump.Options = append(dump.Options, fmt.Sprintf("--schema=%s", s.cfg.SchemaName))
}

dump.SetFileName(s.cfg.OutputFile)

return dump, nil
}

func (s *SnapshotService) validateRestoreConfig() error {
if s.cfg.Input == "" {
return fmt.Errorf("restore snapshot file path i.e. `input` must be specified")
}

// s.cfg.input is resolved from a url in resolveAndDownloadRestoreInput()
info, err := os.Stat(s.cfg.Input)
if err != nil || info.IsDir() {
return fmt.Errorf("snapshot file does not exist: %s", s.cfg.Input)
}

if s.cfg.VerifyInput {
if err := validateInputFileHash(s.cfg.Input, getHashName(s.cfg.Input)); err != nil {
return errors.Wrap(err, fmt.Sprintf("input file hash validation failed for '%s'", s.cfg.Input))
}
s.l.Sugar().Debugw("Input file hash validated successfully",
zap.String("input", s.cfg.Input),
zap.String("inputHashFile", getHashName(s.cfg.Input)),
)
resolvedFilePath, err := resolveFilePath(s.cfg.OutputFile)
if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("failed to resolve output file path '%s'", s.cfg.OutputFile))
}
dump.SetFileName(resolvedFilePath)

return nil
return dump, nil
}

func (s *SnapshotService) setupRestore() (*pgcommands.Restore, error) {
Expand Down Expand Up @@ -362,22 +336,22 @@ func getFileHash(filePath string) ([]byte, error) {
return hasher.Sum(nil), nil
}

func downloadFile(url, downloadDestFilePath string) (string, error) {
func downloadFile(url, downloadDestFilePath string) error {
resp, err := http.Get(url)
if err != nil {
return "", errors.Wrap(err, "failed to initiate download")
return errors.Wrap(err, "failed to initiate download")
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("downloading error, received status code %d", resp.StatusCode)
return fmt.Errorf("downloading error, received status code %d", resp.StatusCode)
}

// Ensure the directory for the file path exists
dir := filepath.Dir(downloadDestFilePath)
if _, err := os.Stat(dir); os.IsNotExist(err) {
if err := os.MkdirAll(dir, os.ModePerm); err != nil {
return "", errors.Wrap(err, "failed to create directories to save the download file to")
return errors.Wrap(err, "failed to create directories to save the download file to")
}
}

Expand All @@ -390,16 +364,16 @@ func downloadFile(url, downloadDestFilePath string) (string, error) {

tmpFile, err := os.Create(downloadDestFilePath)
if err != nil {
return "", errors.Wrap(err, "failed to create local file")
return errors.Wrap(err, "failed to create local file")
}
defer tmpFile.Close()

_, err = io.Copy(io.MultiWriter(tmpFile, bar), resp.Body)
if err != nil {
return "", errors.Wrap(err, "failed to write to local file")
return errors.Wrap(err, "failed to write to local file")
}

return downloadDestFilePath, nil
return nil
}

// getHashName returns the hash name for a given file path or URL.
Expand Down
68 changes: 6 additions & 62 deletions pkg/snapshot/snapshot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,65 +84,6 @@ func TestSetupSnapshotDump(t *testing.T) {
assert.NotNil(t, dump, "Dump should not be nil")
}

func TestValidateRestoreConfig(t *testing.T) {
tempDir := t.TempDir()
snapshotFile := filepath.Join(tempDir, "TestValidateRestoreConfig.sql")
snapshotHashFile := filepath.Join(tempDir, "TestValidateRestoreConfig.sql.sha256sum")
content := []byte("test content")
err := os.WriteFile(snapshotFile, content, 0644)
assert.NoError(t, err, "Writing to snapshot file should not fail")

// Generate the hash using the CLI
cmd := exec.Command("shasum", "-a", "256", snapshotFile)
output, err := cmd.Output()
assert.NoError(t, err, "Generating hash with CLI should not fail")

// Extract the hash from the CLI output
hashParts := strings.Fields(string(output))
if len(hashParts) < 1 {
t.Fatal("Failed to parse hash from CLI output")
}
hashFileContent := fmt.Sprintf("%s %s\n", hashParts[0], filepath.Base(snapshotFile))
err = os.WriteFile(snapshotHashFile, []byte(hashFileContent), 0644)
assert.NoError(t, err, "Writing to hash file should not fail")

cfg := &SnapshotConfig{
Host: "localhost",
Port: 5432,
DbName: "testdb",
User: "testuser",
Password: "testpassword",
SchemaName: "public",
Input: snapshotFile,
VerifyInput: true,
}
l, _ := zap.NewDevelopment()
svc, err := NewSnapshotService(cfg, l)
assert.NoError(t, err, "NewSnapshotService should not return an error")
err = svc.validateRestoreConfig()
assert.NoError(t, err, "Restore config should be valid")
os.Remove(snapshotFile)
os.Remove(snapshotHashFile)
}

func TestValidateRestoreConfigMissingInputFile(t *testing.T) {
cfg := &SnapshotConfig{
Host: "localhost",
Port: 5432,
DbName: "testdb",
User: "testuser",
Password: "testpassword",
SchemaName: "public",
Input: "",
VerifyInput: true,
}
l, _ := zap.NewDevelopment()
svc, err := NewSnapshotService(cfg, l)
assert.NoError(t, err, "NewSnapshotService should not return an error")
err = svc.validateRestoreConfig()
assert.Error(t, err, "Restore config should be invalid if input file is missing")
}

func TestSetupRestore(t *testing.T) {
cfg := &SnapshotConfig{
Host: "localhost",
Expand Down Expand Up @@ -295,19 +236,22 @@ func TestDownloadFile(t *testing.T) {
uniqueID := fmt.Sprintf("%d", time.Now().UnixNano())
inputFileName := uniqueID + "_downloaded_snapshot.dump"

tempDir := t.TempDir()
inputFilePath := filepath.Join(tempDir, inputFileName)

// Use the test server's URL to test the downloadFile function
filePath, err := downloadFile(testServer.URL, inputFileName)
err := downloadFile(testServer.URL, inputFilePath)
assert.NoError(t, err, "downloadFile should not return an error")

// Schedule the file for removal after the test completes
defer func() {
if err := os.Remove(filePath); err != nil {
if err := os.Remove(inputFilePath); err != nil {
t.Logf("Failed to remove downloaded file: %v", err)
}
}()

// Verify the file was downloaded correctly
content, err := os.ReadFile(filePath)
content, err := os.ReadFile(inputFilePath)
assert.NoError(t, err, "Reading downloaded file should not fail")
assert.Contains(t, string(content), "This is a test file.", "Downloaded file content should match expected content")
}
Expand Down

0 comments on commit 0ed2257

Please sign in to comment.