Skip to content

Commit

Permalink
feat: Add file.Untar (#10)
Browse files Browse the repository at this point in the history
* feat: Add file.Untar

* feat: Create dirs if no dir entry

* feat: Keep track of made directories and only create if it does not exist

* feat: Untar supports symlinks
  • Loading branch information
cszatmary authored Dec 18, 2021
1 parent 55c8a7f commit c575e2e
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 59 deletions.
126 changes: 111 additions & 15 deletions file/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
package file

import (
"archive/tar"
"bytes"
"compress/gzip"
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
)
Expand All @@ -27,23 +29,23 @@ func Exists(path string) bool {
return true
}

// Download creates or replaces a file at downloadPath by reading from r.
func Download(downloadPath string, r io.Reader) (int64, error) {
// Download creates or replaces a file at dst by reading from r.
func Download(dst string, r io.Reader) (int64, error) {
// Check if file exists
downloadDir := filepath.Dir(downloadPath)
if err := os.MkdirAll(downloadDir, mkdirDefaultPerms); err != nil {
return 0, fmt.Errorf("failed to create directory %q: %w", downloadDir, err)
dstDir := filepath.Dir(dst)
if err := os.MkdirAll(dstDir, mkdirDefaultPerms); err != nil {
return 0, fmt.Errorf("failed to create directory %q: %w", dstDir, err)
}

// Write payload to target dir
f, err := os.Create(downloadPath)
f, err := os.Create(dst)
if err != nil {
return 0, fmt.Errorf("failed to create file %q: %w", downloadPath, err)
return 0, fmt.Errorf("failed to create file %q: %w", dst, err)
}
defer f.Close()
n, err := io.Copy(f, r)
if err != nil {
return 0, fmt.Errorf("failed writing data to file %q: %w", downloadPath, err)
return 0, fmt.Errorf("failed writing data to file %q: %w", dst, err)
}
return n, nil
}
Expand Down Expand Up @@ -109,29 +111,31 @@ func copyDirContents(src, dst string, info os.FileInfo) error {
return fmt.Errorf("failed to create directory %q: %w", dst, err)
}

contents, err := ioutil.ReadDir(src)
contents, err := os.ReadDir(src)
if err != nil {
return fmt.Errorf("failed to read contents of directory %q: %w", src, err)
}

for _, item := range contents {
srcItemPath := filepath.Join(src, item.Name())
dstItemPath := filepath.Join(dst, item.Name())
fi, err := item.Info()
if err != nil {
return fmt.Errorf("failed to get info of %q: %w", srcItemPath, err)
}

if item.IsDir() {
err := copyDirContents(srcItemPath, dstItemPath, item)
err := copyDirContents(srcItemPath, dstItemPath, fi)
if err != nil {
return fmt.Errorf("failed to copy directory %q: %w", srcItemPath, err)
}
continue
}
if !item.Mode().IsRegular() {
if !fi.Mode().IsRegular() {
// Unsupported file type, ignore
continue
}

err := copyFile(srcItemPath, dstItemPath, item)
if err != nil {
if err := copyFile(srcItemPath, dstItemPath, fi); err != nil {
return fmt.Errorf("failed to copy file %q: %w", srcItemPath, err)
}
}
Expand Down Expand Up @@ -170,3 +174,95 @@ func DirLen(path string) (int, error) {
list, err := dir.Readdirnames(0)
return len(list), err
}

// Untar reads the tar archive from r and extracts its entries into dir.
// It can handle gzip-compressed tar files; compression is detected
// automatically from the leading bytes of r.
//
// Regular files, directories, and symlinks are supported. Any other entry
// type causes Untar to stop and return an error. Parent directories are
// created as needed, even when the archive has no explicit directory entries.
//
// Note that Untar will overwrite any existing files with the same path
// as files in the archive. An existing non-directory at the location of a
// symlink entry is replaced by the symlink.
//
// Entries whose name resolves to a location outside of dir (for example
// names with ".." components) are rejected with an error before anything is
// written for them. Symlink targets are stored as-is and are not validated,
// so Untar should still only be used on archives from trusted sources.
func Untar(dir string, r io.Reader) error {
	// Determine if we are dealing with a gzip-compressed tar file.
	// gzip files are identified by the first 3 bytes (two magic bytes
	// followed by the compression method, 8 = deflate).
	// See section 2.3.1. of RFC 1952: https://www.ietf.org/rfc/rfc1952.txt
	buf := make([]byte, 3)
	if _, err := io.ReadFull(r, buf); err != nil {
		return fmt.Errorf("unable to check if tar file is gzip-compressed: %w", err)
	}

	// Need to create a new reader with the 3 bytes added back to move back to the
	// start of the file. Can do this by concatenating buf with r.
	rr := io.MultiReader(bytes.NewReader(buf), r)
	if buf[0] == 0x1f && buf[1] == 0x8b && buf[2] == 8 {
		gzr, err := gzip.NewReader(rr)
		if err != nil {
			return fmt.Errorf("unable to read gzip-compressed tar file: %w", err)
		}
		defer gzr.Close()
		rr = gzr
	}
	tr := tar.NewReader(rr)

	// Loop through each entry in the archive and extract it.
	// Keep track of the dirs created so we don't waste time creating the same dir multiple times.
	madeDirs := make(map[string]struct{})
	for {
		header, err := tr.Next()
		if err == io.EOF {
			// End of the archive, we are done.
			return nil
		} else if err != nil {
			return fmt.Errorf("untar: read error: %w", err)
		}

		dst := filepath.Join(dir, header.Name)
		// Archives are untrusted input: refuse entries such as "../../etc/passwd"
		// that would otherwise be written outside of the extraction directory.
		if !untarPathWithinDir(dir, dst) {
			return fmt.Errorf("untar: entry %q would be extracted outside of %q", header.Name, dir)
		}

		// Ensure the parent directory exists. Usually this shouldn't be required since there
		// should be a directory entry in the tar file that created the directory beforehand.
		// However, testing has revealed that this is not always the case and there can be
		// tar files without directory entries so we should handle those cases.
		parentDir := filepath.Dir(dst)
		if _, ok := madeDirs[parentDir]; !ok {
			if err := os.MkdirAll(parentDir, mkdirDefaultPerms); err != nil {
				return fmt.Errorf("untar: create directory error: %w", err)
			}
			madeDirs[parentDir] = struct{}{}
		}

		mode := header.FileInfo().Mode()
		switch {
		case mode.IsDir():
			if err := os.MkdirAll(dst, mkdirDefaultPerms); err != nil {
				return fmt.Errorf("untar: create directory error: %w", err)
			}
			// Mark the dir as created so files in this dir don't need to create it again.
			madeDirs[dst] = struct{}{}
		case mode.IsRegular():
			// Now we can create the actual file. Untar will overwrite any existing files.
			f, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode.Perm())
			if err != nil {
				return fmt.Errorf("untar: create file error: %w", err)
			}
			n, err := io.Copy(f, tr)

			// We need to manually close the file here instead of using defer since defer runs when
			// the function exits and would cause all files to remain open until this loop is finished.
			if closeErr := f.Close(); closeErr != nil && err == nil {
				err = closeErr
			}
			if err != nil {
				return fmt.Errorf("untar: error writing file to %s: %w", dst, err)
			}
			// Make sure the right amount of bytes were written just to be safe.
			if n != header.Size {
				return fmt.Errorf("untar: only wrote %d bytes to %s; expected %d", n, dst, header.Size)
			}
		case mode&os.ModeSymlink != 0:
			// os.Symlink fails if dst already exists. Remove an existing file or
			// symlink first so that extracting over a previous extraction works,
			// matching the overwrite behaviour of regular files. Directories are
			// left alone and will make os.Symlink report the conflict.
			if fi, err := os.Lstat(dst); err == nil && !fi.IsDir() {
				if err := os.Remove(dst); err != nil {
					return fmt.Errorf("untar: symlink error: %w", err)
				}
			}
			// Entry is a symlink, need to create a symlink to the target
			if err := os.Symlink(header.Linkname, dst); err != nil {
				return fmt.Errorf("untar: symlink error: %w", err)
			}
		default:
			return fmt.Errorf("tar file entry %s has unsupported file type %v", header.Name, mode)
		}
	}
}

// untarPathWithinDir reports whether path is dir itself or lies beneath dir.
func untarPathWithinDir(dir, path string) bool {
	rel, err := filepath.Rel(dir, path)
	if err != nil {
		return false
	}
	if rel == ".." {
		return false
	}
	// A relative path that starts with "../" climbs out of dir. A name that
	// merely begins with two dots (e.g. "..data") is still inside dir.
	parentPrefix := ".." + string(filepath.Separator)
	return len(rel) < len(parentPrefix) || rel[:len(parentPrefix)] != parentPrefix
}
131 changes: 87 additions & 44 deletions file/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package file_test

import (
"errors"
"io/ioutil"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -48,22 +47,15 @@ func TestDownload(t *testing.T) {
}

// Make sure file was actually written
data, err := ioutil.ReadFile(downloadPath)
if err != nil {
t.Fatalf("failed to read file %v", err)
}
gotContent := string(data)
if gotContent != content {
t.Errorf("got %q, want %q", gotContent, content)
}
assertFile(t, downloadPath, content)
}

func TestCopyFile(t *testing.T) {
tmpdir := t.TempDir()
src := filepath.Join(tmpdir, "src")
dst := filepath.Join(tmpdir, "dst")
const content = `this is some file content`
err := ioutil.WriteFile(src, []byte(content), 0o644)
err := os.WriteFile(src, []byte(content), 0o644)
if err != nil {
t.Fatalf("failed to write file %v", err)
}
Expand All @@ -72,14 +64,7 @@ func TestCopyFile(t *testing.T) {
if err != nil {
t.Errorf("want nil error, got %v", err)
}
data, err := ioutil.ReadFile(dst)
if err != nil {
t.Fatalf("failed to read file %v", err)
}
gotContent := string(data)
if gotContent != content {
t.Errorf("got %q, want %q", gotContent, content)
}
assertFile(t, dst, content)
}

func TestCopyFileNotRegularFile(t *testing.T) {
Expand All @@ -106,12 +91,12 @@ func TestCopyDirContents(t *testing.T) {
t.Fatalf("failed to create dir: %s", err)
}
const barfileContent = "bar"
err = ioutil.WriteFile(filepath.Join(src, "barfile"), []byte(barfileContent), 0o644)
err = os.WriteFile(filepath.Join(src, "barfile"), []byte(barfileContent), 0o644)
if err != nil {
t.Fatalf("failed to create file: %s", err)
}
const bazfileContent = "baz"
err = ioutil.WriteFile(filepath.Join(src, "foodir", "bazfile"), []byte(bazfileContent), 0o644)
err = os.WriteFile(filepath.Join(src, "foodir", "bazfile"), []byte(bazfileContent), 0o644)
if err != nil {
t.Fatalf("failed to create file: %s", err)
}
Expand All @@ -120,32 +105,16 @@ func TestCopyDirContents(t *testing.T) {
if err != nil {
t.Errorf("want nil error, got %v", err)
}

data, err := ioutil.ReadFile(filepath.Join(dst, "barfile"))
if err != nil {
t.Fatalf("failed to read file %v", err)
}
gotContent := string(data)
if gotContent != barfileContent {
t.Errorf("got %q, want %q", gotContent, barfileContent)
}

data, err = ioutil.ReadFile(filepath.Join(dst, "foodir", "bazfile"))
if err != nil {
t.Fatalf("failed to read file %v", err)
}
gotContent = string(data)
if gotContent != bazfileContent {
t.Errorf("got %q, want %q", gotContent, bazfileContent)
}
assertFile(t, filepath.Join(dst, "barfile"), barfileContent)
assertFile(t, filepath.Join(dst, "foodir", "bazfile"), bazfileContent)
}

func TestCopyDirContentsNotDir(t *testing.T) {
tmpdir := t.TempDir()
src := filepath.Join(tmpdir, "src")
dst := filepath.Join(tmpdir, "dst")
const content = `this is some file content`
err := ioutil.WriteFile(src, []byte(content), 0o644)
err := os.WriteFile(src, []byte(content), 0o644)
if err != nil {
t.Fatalf("failed to write file %v", err)
}
Expand All @@ -163,12 +132,12 @@ func TestDirSize(t *testing.T) {
t.Fatalf("failed to create dir: %s", err)
}
const barfileContent = "bar"
err = ioutil.WriteFile(filepath.Join(tmpdir, "barfile"), []byte(barfileContent), 0o644)
err = os.WriteFile(filepath.Join(tmpdir, "barfile"), []byte(barfileContent), 0o644)
if err != nil {
t.Fatalf("failed to create file: %s", err)
}
const bazfileContent = "baz"
err = ioutil.WriteFile(filepath.Join(tmpdir, "foodir", "bazfile"), []byte(bazfileContent), 0o644)
err = os.WriteFile(filepath.Join(tmpdir, "foodir", "bazfile"), []byte(bazfileContent), 0o644)
if err != nil {
t.Fatalf("failed to create file: %s", err)
}
Expand All @@ -189,7 +158,7 @@ func TestDirSizeNotDir(t *testing.T) {
tmpdir := t.TempDir()
const barfileContent = "bar"
barfilePath := filepath.Join(tmpdir, "barfile")
err := ioutil.WriteFile(barfilePath, []byte(barfileContent), 0o644)
err := os.WriteFile(barfilePath, []byte(barfileContent), 0o644)
if err != nil {
t.Fatalf("failed to create file: %s", err)
}
Expand All @@ -209,11 +178,11 @@ func TestDirLen(t *testing.T) {
if err != nil {
t.Fatalf("failed to create dir: %s", err)
}
err = ioutil.WriteFile(filepath.Join(tmpdir, "barfile"), []byte("bar"), 0o644)
err = os.WriteFile(filepath.Join(tmpdir, "barfile"), []byte("bar"), 0o644)
if err != nil {
t.Fatalf("failed to create file: %s", err)
}
err = ioutil.WriteFile(filepath.Join(tmpdir, "bazfile"), []byte("baz"), 0o644)
err = os.WriteFile(filepath.Join(tmpdir, "bazfile"), []byte("baz"), 0o644)
if err != nil {
t.Fatalf("failed to create file: %s", err)
}
Expand All @@ -227,3 +196,77 @@ func TestDirLen(t *testing.T) {
t.Errorf("got dir len %d, want %d", n, want)
}
}

// TestUntar checks that Untar extracts the same two files from a plain tar
// archive, a gzip-compressed archive, and an archive that carries no
// directory entries.
func TestUntar(t *testing.T) {
	type testCase struct {
		name string
		path string
	}
	cases := []testCase{
		{name: "normal tar file", path: "testdata/basic.tar"},
		{name: "gzip-compressed tar file", path: "testdata/basic.tgz"},
		{name: "tar file without directories", path: "testdata/basic_nodirs.tgz"},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			archive, openErr := os.Open(tc.path)
			if openErr != nil {
				t.Fatalf("failed to open %s: %v", tc.path, openErr)
			}
			t.Cleanup(func() {
				archive.Close()
			})

			outDir := t.TempDir()
			if untarErr := file.Untar(outDir, archive); untarErr != nil {
				t.Fatalf("want nil error, got %v", untarErr)
			}

			topLevelFile := filepath.Join(outDir, "a.txt")
			assertFile(t, topLevelFile, "This is a file\n")

			// Reading a file inside b/ succeeds only if the b directory was created.
			nestedFile := filepath.Join(outDir, "b/c.txt")
			assertFile(t, nestedFile, "This is another file inside a directory\n")
		})
	}
}

// TestUntarSymlink checks that a symlink entry in the archive is recreated
// with its original link target and that it resolves to the expected file.
func TestUntarSymlink(t *testing.T) {
	const path = "testdata/basic_symlink.tgz"
	const wantLink = "../a.txt"

	archive, openErr := os.Open(path)
	if openErr != nil {
		t.Fatalf("failed to open %s: %v", path, openErr)
	}
	t.Cleanup(func() {
		archive.Close()
	})

	outDir := t.TempDir()
	if untarErr := file.Untar(outDir, archive); untarErr != nil {
		t.Fatalf("want nil error, got %v", untarErr)
	}

	assertFile(t, filepath.Join(outDir, "a.txt"), "This is a file\n")

	// The entry at b/c.txt must be a symlink that stores the expected target.
	cPath := filepath.Join(outDir, "b/c.txt")
	gotLink, readErr := os.Readlink(cPath)
	if readErr != nil {
		t.Fatalf("failed to read link %s: %v", cPath, readErr)
	}
	if gotLink != wantLink {
		t.Errorf("got symlink %q, want %q", gotLink, wantLink)
	}

	// Reading through the symlink must yield the contents of a.txt.
	assertFile(t, cPath, "This is a file\n")
}

// assertFile is a test helper that reads the file at path and compares its
// contents with want. It stops the test if the file cannot be read and
// reports a (non-fatal) error if the contents differ.
func assertFile(t *testing.T, path, want string) {
	t.Helper()

	data, readErr := os.ReadFile(path)
	if readErr != nil {
		t.Fatalf("failed to read file %s: %v", path, readErr)
	}
	if got := string(data); got != want {
		t.Errorf("got %q, want %q", got, want)
	}
}
Binary file added file/testdata/basic.tar
Binary file not shown.
Binary file added file/testdata/basic.tgz
Binary file not shown.
Binary file added file/testdata/basic_nodirs.tgz
Binary file not shown.
Binary file added file/testdata/basic_symlink.tgz
Binary file not shown.

0 comments on commit c575e2e

Please sign in to comment.