Skip to content

Commit

Permalink
Fix frame count calculation on journal recovery (#197)
Browse files Browse the repository at this point in the history
  • Loading branch information
benbjohnson authored Nov 24, 2022
1 parent 77f1414 commit 9bb4a2d
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 61 deletions.
143 changes: 82 additions & 61 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ func (db *DB) rollbackJournal(ctx context.Context) error {
defer func() { _ = dbFile.Close() }()

// Copy every journal page back into the main database file.
r := NewJournalReader(journalFile)
r := NewJournalReader(journalFile, db.pageSize)
for i := 0; ; i++ {
if err := r.Next(); err == io.EOF {
break
Expand Down Expand Up @@ -910,7 +910,7 @@ func (db *DB) CommitJournal(mode JournalMode) error {
return fmt.Errorf("cannot open journal file: %w", err)
}

journalPageMap, err := buildJournalPageMap(journalFile)
journalPageMap, err := buildJournalPageMap(journalFile, db.pageSize)
if err != nil {
return fmt.Errorf("cannot build journal page map: %w", err)
}
Expand Down Expand Up @@ -1532,8 +1532,8 @@ type dbVarJSON struct {
} `json:"locks"`
}

func buildJournalPageMap(f *os.File) (map[uint32]uint64, error) {
r := NewJournalReader(f)
func buildJournalPageMap(f *os.File, pageSize uint32) (map[uint32]uint64, error) {
r := NewJournalReader(f, pageSize)

// Generate a map of pages and their new checksums.
m := make(map[uint32]uint64)
Expand Down Expand Up @@ -1565,61 +1565,92 @@ func buildJournalPageMap(f *os.File) (map[uint32]uint64, error) {

// JournalReader represents a reader of the SQLite journal file format.
type JournalReader struct {
r io.Reader
n int64 // bytes read
frame []byte // frame buffer

frameN int32 // Number of pages in the segment, or -1 to mean all content to the end of the file
nonce uint32 // A random nonce for the checksum
initialSize uint32 // Initial size of the database in pages
sectorSize uint32 // Size of the disk sectors
pageSize uint32 // Size of each page, in bytes
f *os.File
fi os.FileInfo // cached file info
offset int64 // read offset
frame []byte // frame buffer

frameN int32 // Number of pages in the segment
nonce uint32 // A random nonce for the checksum
commit uint32 // Initial size of the database in pages
sectorSize uint32 // Size of the disk sectors
pageSize uint32 // Size of each page, in bytes
}

// NewJournalReader returns a new instance of JournalReader.
func NewJournalReader(r io.Reader) *JournalReader {
return &JournalReader{r: r}
func NewJournalReader(f *os.File, pageSize uint32) *JournalReader {
return &JournalReader{
f: f,
pageSize: pageSize,
}
}

// DatabaseSize returns the size of the database before the journal transaction, in bytes.
func (r *JournalReader) DatabaseSize() int64 {
return int64(r.initialSize) * int64(r.pageSize)
return int64(r.commit) * int64(r.pageSize)
}

// Next reads the next segment of the journal. Returns io.EOF if no more segments exist.
func (r *JournalReader) Next() error {
// Read journal header. Return EOF if header is invalid.
buf := make([]byte, len(SQLITE_JOURNAL_HEADER_STRING)+20)
n, err := io.ReadFull(r.r, buf)
r.n += int64(n)
if err != nil {
return io.EOF
} else if string(buf[:len(SQLITE_JOURNAL_HEADER_STRING)]) != SQLITE_JOURNAL_HEADER_STRING {
func (r *JournalReader) Next() (err error) {
// Determine journal size on initial call.
if r.fi == nil {
if r.fi, err = r.f.Stat(); err != nil {
return err
}
}

// Ensure offset is sector-aligned.
r.offset = journalHeaderOffset(r.offset, int64(r.sectorSize))

// Read full header.
hdr := make([]byte, SQLITE_JOURNAL_HEADER_SIZE)
if _, err := internal.ReadFullAt(r.f, hdr, r.offset); err == io.EOF || err == io.ErrUnexpectedEOF {
return io.EOF // no header or partial header
} else if err != nil {
return err
} else if isByteSliceZero(hdr) {
return io.EOF
}

// Read fields after header magic.
hdr := buf[len(SQLITE_JOURNAL_HEADER_STRING):]
r.frameN = int32(binary.BigEndian.Uint32(hdr[0:]))
r.nonce = binary.BigEndian.Uint32(hdr[4:])
r.initialSize = binary.BigEndian.Uint32(hdr[8:])
r.sectorSize = binary.BigEndian.Uint32(hdr[12:])
r.pageSize = binary.BigEndian.Uint32(hdr[16:])
if r.pageSize == 0 {
return fmt.Errorf("invalid page size in journal header: %d", r.pageSize)
// Read number of frames in journal segment. Set to -1 if no-sync was set
// and set to 0 if the journal was not sync'd. In these two cases we will
// calculate the frame count based on the journal size.
r.frameN = int32(binary.BigEndian.Uint32(hdr[8:]))
if r.frameN == -1 {
r.frameN = int32((r.fi.Size() - int64(r.sectorSize)) / int64(r.pageSize))
} else if r.frameN == 0 {
r.frameN = int32((r.fi.Size() - r.offset) / int64(r.pageSize))
}

// Create a buffer to read the segment frames.
r.frame = make([]byte, r.pageSize+4+4)
// Read remaining fields from header.
r.nonce = binary.BigEndian.Uint32(hdr[12:]) // cksumInit
r.commit = binary.BigEndian.Uint32(hdr[16:]) // dbSize

// Only read sector and page size from first journal header.
if r.offset == 0 {
r.sectorSize = binary.BigEndian.Uint32(hdr[20:])

// Use page size from journal reader, if set to 0.
pageSize := binary.BigEndian.Uint32(hdr[24:])
if pageSize == 0 {
pageSize = r.pageSize
}
if pageSize != r.pageSize {
return fmt.Errorf("journal header page size (%d) does not match database (%d)", pageSize, r.pageSize)
}
}

// Move to the end of the sector.
p := make([]byte, int64(r.sectorSize)-int64(len(buf)))
n, err = io.ReadFull(r.r, p)
r.n += int64(n)
if err != nil && err != io.EOF {
return fmt.Errorf("cannot seek to next sector: %w", err)
// Exit if file doesn't have more than the initial sector.
if r.offset+int64(r.sectorSize) > r.fi.Size() {
return io.EOF
}

// Move journal offset to first sector.
r.offset += int64(r.sectorSize)

// Create a buffer to read the segment frames.
r.frame = make([]byte, r.pageSize+4+4)

return nil
}

Expand All @@ -1632,34 +1663,24 @@ func (r *JournalReader) ReadFrame() (pgno uint32, data []byte, err error) {
}

// Read the next frame from the journal.
n, err := io.ReadFull(r.r, r.frame)
r.n += int64(n)
if err != nil {
n, err := internal.ReadFullAt(r.f, r.frame, r.offset)
if err == io.ErrUnexpectedEOF {
return 0, nil, io.EOF
} else if err != nil {
return 0, nil, err
}
r.frameN--

// At the end of the last frame, move to the next sector.
if r.frameN == 0 {
b := make([]byte, nextMultipleOf(r.n, int64(r.sectorSize))-r.n)
n, err := io.ReadFull(r.r, b)
r.n += int64(n)
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
return 0, nil, fmt.Errorf("seek to next journal segment: %w", err)
}
}
r.offset += int64(n)

return binary.BigEndian.Uint32(r.frame[0:]), r.frame[4 : len(r.frame)-4], nil
}

// nextMultipleOf returns the next multiple of denom based on v.
// Returns v if it is a multiple of denom.
func nextMultipleOf(v, denom int64) int64 {
mod := v % denom
if mod == 0 {
return v
// journalHeaderOffset returns a sector-aligned offset.
func journalHeaderOffset(offset, sectorSize int64) int64 {
if offset == 0 {
return 0
}
return v + (denom - mod)
return ((offset-1)/sectorSize + 1) * sectorSize
}

// readAndVerifyLTXFile reads an LTX file and verifies its integrity.
Expand Down
16 changes: 16 additions & 0 deletions internal/internal.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package internal

import (
"io"
"os"
)

Expand All @@ -17,3 +18,18 @@ func Sync(path string) error {
}
return f.Close()
}

// ReadFullAt fills buf completely from r, starting at byte offset off. It is
// the io.ReaderAt counterpart of io.ReadFull().
//
// It returns the number of bytes copied into buf. The error is nil when buf
// was filled, io.ErrUnexpectedEOF when r reported io.EOF after only part of
// buf was read, and otherwise the error returned by r.ReadAt (a bare io.EOF
// when nothing was read).
func ReadFullAt(r io.ReaderAt, buf []byte, off int64) (n int, err error) {
	// Keep issuing reads at the advancing offset until buf is full or the
	// underlying reader reports any error.
	for n < len(buf) {
		var nn int
		nn, err = r.ReadAt(buf[n:], off+int64(n))
		n += nn
		if err != nil {
			break
		}
	}

	switch {
	case n >= len(buf):
		// A full read succeeds even if the final ReadAt also returned an error.
		return n, nil
	case n > 0 && err == io.EOF:
		// Some, but not all, of the requested bytes were available.
		return n, io.ErrUnexpectedEOF
	default:
		return n, err
	}
}
40 changes: 40 additions & 0 deletions internal/internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package internal_test

import (
"io"
"strings"
"testing"

"github.com/superfly/litefs/internal"
)

// TestReadFullAt checks internal.ReadFullAt against in-memory string readers
// for a fully satisfied read, a read truncated by end-of-input, and a read
// that yields no bytes at all.
func TestReadFullAt(t *testing.T) {
	// The requested range lies entirely inside the source.
	t.Run("OK", func(t *testing.T) {
		p := make([]byte, 2)
		n, err := internal.ReadFullAt(strings.NewReader("abcde"), p, 2)
		if err != nil {
			t.Fatal(err)
		}
		if n != 2 {
			t.Fatalf("n=%v, want %v", n, 2)
		}
		if s := string(p); s != "cd" {
			t.Fatalf("buf=%q, want %q", s, "cd")
		}
	})

	// The source ends before the buffer is full; the tail of the buffer
	// must remain zeroed.
	t.Run("ErrUnexpectedEOF", func(t *testing.T) {
		p := make([]byte, 4)
		n, err := internal.ReadFullAt(strings.NewReader("abcde"), p, 2)
		if err != io.ErrUnexpectedEOF {
			t.Fatalf("unexpected error: %#v", err)
		}
		if n != 3 {
			t.Fatalf("n=%v, want %v", n, 3)
		}
		if s := string(p); s != "cde\x00" {
			t.Fatalf("buf=%q, want %q", s, "cde\x00")
		}
	})

	// Nothing can be read from an empty source.
	t.Run("EOF", func(t *testing.T) {
		p := make([]byte, 2)
		_, err := internal.ReadFullAt(strings.NewReader(""), p, 2)
		if err != io.EOF {
			t.Fatalf("unexpected error: %#v", err)
		}
	})
}

0 comments on commit 9bb4a2d

Please sign in to comment.