Skip to content

Commit

Permalink
initramfs test: convert to errors.Is
Browse files Browse the repository at this point in the history
Signed-off-by: Chris Koch <[email protected]>
  • Loading branch information
hugelgupf committed Feb 8, 2024
1 parent 5309bc4 commit dd94bfe
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 50 deletions.
5 changes: 4 additions & 1 deletion uroot/initramfs/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package initramfs

import (
"errors"
"fmt"
"log"
"os"
Expand Down Expand Up @@ -137,11 +138,13 @@ func (af *Files) AddFileNoFollow(src string, dest string) error {
return af.addFile(src, dest, false)
}

var errAbsoluteName = errors.New("record name must not be absolute")

// AddRecord adds a cpio.Record into the archive at `r.Name`.
func (af *Files) AddRecord(r cpio.Record) error {
r.Name = path.Clean(r.Name)
if filepath.IsAbs(r.Name) {
return fmt.Errorf("record name %q must not be absolute", r.Name)
return fmt.Errorf("%w: %q", errAbsoluteName, r.Name)
}

if src, ok := af.Files[r.Name]; ok {
Expand Down
86 changes: 37 additions & 49 deletions uroot/initramfs/files_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
package initramfs

import (
"errors"
"fmt"
"io"
"os"
"path/filepath"
"reflect"
"strings"
"testing"

"github.com/u-root/mkuimage/cpio"
Expand All @@ -34,12 +34,12 @@ func TestFilesAddFileNoFollow(t *testing.T) {
}

for i, tt := range []struct {
name string
af *Files
src string
dest string
result *Files
errContains string
name string
af *Files
src string
dest string
result *Files
err error
}{
{
name: "just add a file",
Expand Down Expand Up @@ -70,11 +70,8 @@ func TestFilesAddFileNoFollow(t *testing.T) {
} {
t.Run(fmt.Sprintf("Test %02d: %s", i, tt.name), func(t *testing.T) {
err := tt.af.AddFileNoFollow(tt.src, tt.dest)
if err != nil && !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
}
if err == nil && len(tt.errContains) > 0 {
t.Errorf("Got no error, want %v", tt.errContains)
if !errors.Is(err, tt.err) {
t.Errorf("AddFileNoFollow = %v, want %v", err, tt.err)
}

if tt.result != nil && !reflect.DeepEqual(tt.af, tt.result) {
Expand All @@ -97,25 +94,22 @@ func TestFilesAddFile(t *testing.T) {

symlinkToDir3 := filepath.Join(dir3, "fooSymDir/")
fooDir := filepath.Join(dir3, "fooDir")
//nolint:errcheck
{
os.Create(filepath.Join(dir, "foo"))
os.Create(filepath.Join(dir, "foo2"))
os.Symlink(filepath.Join(dir, "foo2"), filepath.Join(dir2, "foo3"))
_ = os.WriteFile(filepath.Join(dir, "foo"), nil, 0o777)
_ = os.WriteFile(filepath.Join(dir, "foo2"), nil, 0o777)
_ = os.Symlink(filepath.Join(dir, "foo2"), filepath.Join(dir2, "foo3"))

os.Mkdir(fooDir, os.ModePerm)
os.Symlink(fooDir, symlinkToDir3)
os.Create(filepath.Join(fooDir, "foo"))
os.Create(filepath.Join(fooDir, "bar"))
}
_ = os.Mkdir(fooDir, os.ModePerm)
_ = os.Symlink(fooDir, symlinkToDir3)
_ = os.WriteFile(filepath.Join(fooDir, "foo"), nil, 0o777)
_ = os.WriteFile(filepath.Join(fooDir, "bar"), nil, 0o777)

for i, tt := range []struct {
name string
af *Files
src string
dest string
result *Files
errContains string
name string
af *Files
src string
dest string
result *Files
err error
}{
{
name: "just add a file",
Expand Down Expand Up @@ -171,7 +165,7 @@ func TestFilesAddFile(t *testing.T) {
"bar/foo": "/some/other/place",
},
},
errContains: "already exists in archive",
err: os.ErrExist,
},
{
name: "add a file that exists in Records",
Expand All @@ -187,7 +181,7 @@ func TestFilesAddFile(t *testing.T) {
"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
},
},
errContains: "already exists in archive",
err: os.ErrExist,
},
{
name: "add a file that already exists in Files, but is the same one",
Expand Down Expand Up @@ -261,18 +255,15 @@ func TestFilesAddFile(t *testing.T) {
"bar/foo/foo2": "/some/place/real/zed",
},
},
src: dir,
dest: "bar/foo",
errContains: "already exists in archive",
src: dir,
dest: "bar/foo",
err: os.ErrExist,
},
} {
t.Run(fmt.Sprintf("Test %02d: %s", i, tt.name), func(t *testing.T) {
err := tt.af.AddFile(tt.src, tt.dest)
if err != nil && !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
}
if err == nil && len(tt.errContains) > 0 {
t.Errorf("Got no error, want %v", tt.errContains)
if !errors.Is(err, tt.err) {
t.Errorf("AddFile = %v, want %v", err, tt.err)
}

if tt.result != nil && !reflect.DeepEqual(tt.af, tt.result) {
Expand All @@ -287,8 +278,8 @@ func TestFilesAddRecord(t *testing.T) {
af *Files
record cpio.Record

result *Files
errContains string
result *Files
err error
}{
{
af: NewFiles(),
Expand All @@ -312,7 +303,7 @@ func TestFilesAddRecord(t *testing.T) {
"bar/foo": "/some/other/place",
},
},
errContains: "already exists in archive",
err: os.ErrExist,
},
{
af: &Files{
Expand All @@ -326,7 +317,7 @@ func TestFilesAddRecord(t *testing.T) {
"bar/foo": cpio.Symlink("bar/foo", "/some/other/place"),
},
},
errContains: "already exists in archive",
err: os.ErrExist,
},
{
af: &Files{
Expand All @@ -342,17 +333,14 @@ func TestFilesAddRecord(t *testing.T) {
},
},
{
record: cpio.Symlink("/bar/foo", ""),
errContains: "must not be absolute",
record: cpio.Symlink("/bar/foo", ""),
err: errAbsoluteName,
},
} {
t.Run(fmt.Sprintf("Test %02d", i), func(t *testing.T) {
err := tt.af.AddRecord(tt.record)
if err != nil && !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("Error is %v, does not contain %v", err, tt.errContains)
}
if err == nil && len(tt.errContains) > 0 {
t.Errorf("Got no error, want %v", tt.errContains)
if !errors.Is(err, tt.err) {
t.Errorf("AddRecord = %v, want %v", err, tt.err)
}

if !reflect.DeepEqual(tt.af, tt.result) {
Expand Down

0 comments on commit dd94bfe

Please sign in to comment.