diff --git a/internal/gitrepo/gitrepo.go b/internal/gitrepo/gitrepo.go index 3ae05a9..2e2e560 100644 --- a/internal/gitrepo/gitrepo.go +++ b/internal/gitrepo/gitrepo.go @@ -85,7 +85,7 @@ func (fc fallbackCatter) Cat(ctx context.Context, dst io.Writer, wantType object for nextID := id; ; { got, r, err := fc.OpenObject(ctx, nextID) if err != nil { - return fmt.Errorf("cat %v %v: %v", wantType, id, err) + return fmt.Errorf("cat %v %v: %w", wantType, id, err) } if got.Type == wantType { _, err := io.Copy(dst, r) diff --git a/internal/repocache/repocache.go b/internal/repocache/repocache.go index 1f0b7b7..83d4e9e 100644 --- a/internal/repocache/repocache.go +++ b/internal/repocache/repocache.go @@ -25,6 +25,7 @@ import ( "embed" "errors" "fmt" + "hash" "io" "gg-scm.io/pkg/git/githash" @@ -107,73 +108,123 @@ func migrate(conn *sqlite.Conn) (err error) { return nil } -// Cat copies the content of the given object from the cache into dst. +// OpenObject returns a reader for the given object in the cache. // If the object is not present in the cache, -// then Cat will return an error that wraps [ErrObjectNotFound]. -// If Cat does not return an error, -// it guarantees that the bytes written to dst match the hash. -func (c *Cache) Cat(ctx context.Context, dst io.Writer, id githash.SHA1) (_ object.Type, err error) { +// then OpenObject will return an error that wraps [ErrObjectNotFound]. +// If the returned reader returns an EOF, +// it guarantees that the bytes read from it match the hash. +func (c *Cache) OpenObject(ctx context.Context, id githash.SHA1) (_ object.Prefix, _ io.ReadCloser, err error) { c.conn.SetInterrupt(ctx.Done()) defer c.conn.SetInterrupt(nil) defer sqlitex.Transaction(c.conn)(&err) - _, tp, err := cat(c.conn, dst, id) - return tp, err + _, prefix, rc, err := openObject(c.conn, id) + return prefix, rc, err } -func stat(conn *sqlite.Conn, id githash.SHA1) (oid int64, tp object.Type, uncompressedSize int64, err error) { - uncompressedSize = -1 +// Stat returns the prefix of the given object. +// If the object is not present in the cache, +// then Stat will return an error that wraps [ErrObjectNotFound]. +func (c *Cache) Stat(ctx context.Context, id githash.SHA1) (_ object.Prefix, err error) { + c.conn.SetInterrupt(ctx.Done()) + defer c.conn.SetInterrupt(nil) + defer sqlitex.Transaction(c.conn)(&err) + _, prefix, err := stat(c.conn, id) + return prefix, err +} + +func stat(conn *sqlite.Conn, id githash.SHA1) (oid int64, prefix object.Prefix, err error) { + prefix.Size = -1 err = sqlitex.ExecuteTransientFS(conn, sqlFiles, "objects/find.sql", &sqlitex.ExecOptions{ Named: map[string]any{ ":sha1": id[:], }, ResultFunc: func(stmt *sqlite.Stmt) error { oid = stmt.GetInt64("oid") - tp = object.Type(stmt.GetText("type")) - uncompressedSize = stmt.GetInt64("uncompressed_size") + prefix.Type = object.Type(stmt.GetText("type")) + prefix.Size = stmt.GetInt64("uncompressed_size") return nil }, }) if err != nil { - return 0, "", 0, fmt.Errorf("read git object %v: %v", id, err) + return 0, object.Prefix{}, fmt.Errorf("read git object %v: %v", id, err) } - if uncompressedSize < 0 { - return 0, "", 0, fmt.Errorf("read git object %v: %w", id, ErrObjectNotFound) + if prefix.Size < 0 { + return 0, object.Prefix{}, fmt.Errorf("read git object %v: %w", id, ErrObjectNotFound) } - return oid, tp, uncompressedSize, nil + return oid, prefix, nil } -func cat(conn *sqlite.Conn, dst io.Writer, id githash.SHA1) (oid int64, _ object.Type, err error) { - defer sqlitex.Save(conn)(&err) +// objectReader is an open handle to a Git object. +// It verifies the read content on EOF. +type objectReader struct { + blob *sqlite.Blob + zr io.ReadCloser + err error + + id githash.SHA1 + hash hash.Hash + remaining int64 +} - oid, tp, uncompressedSize, err := stat(conn, id) +func openObject(conn *sqlite.Conn, id githash.SHA1) (oid int64, _ object.Prefix, _ io.ReadCloser, err error) { + oid, prefix, err := stat(conn, id) if err != nil { - return 0, "", err + return 0, object.Prefix{}, nil, err } + + // Intentionally not holding onto a transaction or savepoint here. compressedContent, err := conn.OpenBlob("", objectsTable, contentColumn, oid, false) if err != nil { - return 0, "", fmt.Errorf("read git object %v: %v", id, err) + return oid, prefix, nil, fmt.Errorf("read git object %v: %v", id, err) } - defer compressedContent.Close() - h := sha1.New() - h.Write(object.AppendPrefix(nil, tp, uncompressedSize)) - uncompressedContent, err := zlib.NewReader(compressedContent) - if err != nil { - return 0, "", fmt.Errorf("read git object %v: %v", id, err) + r := &objectReader{ + blob: compressedContent, + id: id, + remaining: prefix.Size, + hash: sha1.New(), } - gotSize, err := io.Copy(io.MultiWriter(h, dst), uncompressedContent) - uncompressedContent.Close() + r.zr, err = zlib.NewReader(compressedContent) if err != nil { - return 0, "", fmt.Errorf("read git object %v: %v", id, err) + compressedContent.Close() + return oid, prefix, nil, fmt.Errorf("read git object %v: %v", id, err) + } + r.hash.Write(object.AppendPrefix(nil, prefix.Type, prefix.Size)) + return oid, prefix, r, nil +} + +func (r *objectReader) Read(p []byte) (n int, err error) { + if r.err != nil { + return 0, r.err } - if gotSize != uncompressedSize { - return 0, "", fmt.Errorf("read git object %v: corrupted content (advertised size was %d bytes; found %d bytes)", id, uncompressedSize, gotSize) + if int64(len(p)) > r.remaining { + p = p[:r.remaining] } - var gotHash githash.SHA1 - h.Sum(gotHash[:0]) - if gotHash != id { - return 0, "", fmt.Errorf("read git object %v: corrupted content (hash = %v)", id, gotHash) + if len(p) > 0 { + n, err = r.zr.Read(p) + r.hash.Write(p[:n]) + r.remaining -= int64(n) } - return oid, tp, nil + switch { + case err != nil && err != io.EOF: + return n, fmt.Errorf("read git object %v: %w", r.id, err) + case err == io.EOF && r.remaining > 0: + return n, fmt.Errorf("read git object %v: %w", r.id, io.ErrUnexpectedEOF) + case r.remaining == 0: + var gotHash githash.SHA1 + r.hash.Sum(gotHash[:0]) + if gotHash != r.id { + return n, fmt.Errorf("read git object %v: corrupted content (hash = %v)", r.id, gotHash) + } + r.err = io.EOF + return n, io.EOF + } + return n, nil +} + +func (r *objectReader) Close() error { + r.err = fmt.Errorf("read git object %v: closed", r.id) + r.zr.Close() + return r.blob.Close() } // Close releases all resources associated with the cache connection. diff --git a/internal/repocache/repocache_test.go b/internal/repocache/repocache_test.go index 54b6177..24cba19 100644 --- a/internal/repocache/repocache_test.go +++ b/internal/repocache/repocache_test.go @@ -29,6 +29,7 @@ import ( "gg-scm.io/pkg/git/object" "gg-scm.io/pkg/git/packfile/client" "gg-scm.io/tool/internal/filesystem" + "gg-scm.io/tool/internal/gitrepo" "github.com/google/go-cmp/cmp" "zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite/sqlitex" @@ -185,13 +186,10 @@ func TestCopyFrom(t *testing.T) { } got := new(bytes.Buffer) - gotType, err := cache.Cat(ctx, got, commitObjectName) + err = gitrepo.NewRepository(cache).Cat(ctx, got, object.TypeCommit, commitObjectName) if err != nil { t.Fatal("Cat:", err) } - if wantType := object.TypeCommit; gotType != wantType { - t.Errorf("type = %q; want %q", gotType, wantType) - } if diff := cmp.Diff(want, got.Bytes()); diff != "" { t.Errorf("content (-want +got):\n%s", diff) } diff --git a/internal/repocache/sync.go b/internal/repocache/sync.go index 750b13f..903a0cb 100644 --- a/internal/repocache/sync.go +++ b/internal/repocache/sync.go @@ -18,7 +18,6 @@ package repocache import ( "bufio" - "bytes" "compress/zlib" "context" "crypto/sha1" @@ -189,15 +188,21 @@ func insertObject(conn *sqlite.Conn, insertStmt *sqlite.Stmt, name githash.SHA1, func indexCommit(conn *sqlite.Conn, id githash.SHA1) (err error) { defer sqlitex.Save(conn)(&err) - buf := new(bytes.Buffer) - oid, tp, err := cat(conn, buf, id) + oid, prefix, rc, err := openObject(conn, id) if err != nil { return fmt.Errorf("index commit %v: %v", id, err) } - if tp != object.TypeCommit { - return fmt.Errorf("index commit %v: not a commit (found %v)", id, tp) + if prefix.Type != object.TypeCommit { + rc.Close() + return fmt.Errorf("index commit %v: not a commit (found %v)", id, prefix.Type) } - parsed, err := object.ParseCommit(buf.Bytes()) + buf := make([]byte, prefix.Size) + _, err = io.ReadFull(rc, buf) + rc.Close() + if err != nil { + return fmt.Errorf("index commit %v: %v", id, err) + } + parsed, err := object.ParseCommit(buf) if err != nil { return fmt.Errorf("index commit %v: %v", id, err) }