Skip to content

Commit

Permalink
Chore/compression v1 (#940)
Browse files Browse the repository at this point in the history
* feat: compressor version and decision to bypass compression

* feat: compressor versioning in decompressor

* feat: compressor version in snark decompressor

* fix: decompressor no compression bug

* perf: isBoolean / isCrumb

* fix: ReadIntoStream

* fix nosec rng

* feat: make snark decompressor big endian

* revert bring back ReadIntoStream

* fix all but the "real data" test

* test discovered error in readNum for non-bit granularity

* fix numReader bug

* style more readable AssertIsCrumb const path

* docs: explain AssertIsCrumb

* refactor: concentrate settings purego io, ReadIntoStream to return errors

* doc explain ReadIntoStream

* refactor NewStream to return errors

* docs mathfmt

* revert: Try[Read|Write]Bits

* fix no compression not writing anything
  • Loading branch information
Tabaie authored Dec 1, 2023
1 parent acb6e50 commit 98a1c52
Show file tree
Hide file tree
Showing 14 changed files with 336 additions and 179 deletions.
4 changes: 3 additions & 1 deletion frontend/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,10 @@ type API interface {
// AssertIsDifferent fails if i1 == i2
AssertIsDifferent(i1, i2 Variable)

// AssertIsBoolean fails if v != 0 v != 1
// AssertIsBoolean fails if v != 0 and v != 1
AssertIsBoolean(i1 Variable)
// AssertIsCrumb fails if v ∉ {0,1,2,3} (crumb is a 2-bit variable; see https://en.wikipedia.org/wiki/Units_of_information)
AssertIsCrumb(i1 Variable)

// AssertIsLessOrEqual fails if v > bound.
//
Expand Down
6 changes: 6 additions & 0 deletions frontend/cs/r1cs/api_assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ func (builder *builder) AssertIsBoolean(i1 frontend.Variable) {
}
}

func (builder *builder) AssertIsCrumb(i1 frontend.Variable) {
i1 = builder.MulAcc(builder.Mul(-3, i1), i1, i1)
i1 = builder.MulAcc(builder.Mul(2, i1), i1, i1)
builder.AssertIsEqual(i1, 0)
}

// AssertIsLessOrEqual adds assertion in constraint builder (v ⩽ bound)
//
// bound can be a constant or a Variable
Expand Down
24 changes: 24 additions & 0 deletions frontend/cs/scs/api_assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,30 @@ func (builder *builder) AssertIsBoolean(i1 frontend.Variable) {

}

func (builder *builder) AssertIsCrumb(i1 frontend.Variable) {
const errorMsg = "AssertIsCrumb: input is not a crumb"
if c, ok := builder.constantValue(i1); ok {
if i, ok := builder.cs.Uint64(c); ok && i < 4 {
return
}
panic(errorMsg)
}

// i1 (i1-1) (i1-2) (i1-3) = (i1² - 3i1) (i1² - 3i1 + 2)
// take X := i1² - 3i1 and we get X (X+2) = 0

x := builder.MulAcc(builder.Mul(-3, i1), i1, i1).(expr.Term)

// TODO @Tabaie Ideally this entire function would live in std/math/bits as it is quite specialized;
// however using two generic MulAccs and an AssertIsEqual results in three constraints rather than two.
builder.addPlonkConstraint(sparseR1C{
xa: x.VID,
xb: x.VID,
qL: builder.cs.FromInterface(2),
qM: builder.tOne,
})
}

// AssertIsLessOrEqual fails if v > bound
func (builder *builder) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) {
cv, vConst := builder.constantValue(v)
Expand Down
20 changes: 11 additions & 9 deletions std/compress/lzss/backref.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,21 @@ type backref struct {
func (b *backref) writeTo(w *bitio.Writer, i int) {
w.TryWriteByte(b.bType.delimiter)
w.TryWriteBits(uint64(b.length-1), b.bType.nbBitsLength)
if b.bType.dictOnly {
w.TryWriteBits(uint64(b.address), b.bType.nbBitsAddress)
} else {
w.TryWriteBits(uint64(i-b.address-1), b.bType.nbBitsAddress)
addrToWrite := b.address
if !b.bType.dictOnly {
addrToWrite = i - b.address - 1
}
w.TryWriteBits(uint64(addrToWrite), b.bType.nbBitsAddress)
}

func (b *backref) readFrom(r *bitio.Reader) {
b.length = int(r.TryReadBits(b.bType.nbBitsLength)) + 1
if b.bType.dictOnly {
b.address = int(r.TryReadBits(b.bType.nbBitsAddress))
} else {
b.address = int(r.TryReadBits(b.bType.nbBitsAddress)) + 1
n := r.TryReadBits(b.bType.nbBitsLength)
b.length = int(n) + 1

n = r.TryReadBits(b.bType.nbBitsAddress)
b.address = int(n)
if !b.bType.dictOnly {
b.address++
}
}

Expand Down
46 changes: 43 additions & 3 deletions std/compress/lzss/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package lzss
import (
"bytes"
"fmt"
"io"
"math/bits"

"github.com/consensys/gnark/std/compress/lzss/internal/suffixarray"
Expand Down Expand Up @@ -103,7 +104,10 @@ func (compressor *Compressor) Compress(d []byte) (c []byte, err error) {

// reset output buffer
compressor.buf.Reset()
compressor.buf.WriteByte(byte(compressor.level))
settings := settings{version: 0, level: compressor.level}
if err = settings.writeTo(&compressor.buf); err != nil {
return
}
if compressor.level == NoCompression {
compressor.buf.Write(d)
return compressor.buf.Bytes(), nil
Expand Down Expand Up @@ -206,11 +210,21 @@ func (compressor *Compressor) Compress(d []byte) (c []byte, err error) {
if compressor.bw.TryError != nil {
return nil, compressor.bw.TryError
}
if err := compressor.bw.Close(); err != nil {
if err = compressor.bw.Close(); err != nil {
return nil, err
}

return compressor.buf.Bytes(), nil
if compressor.buf.Len() >= len(d)+settings.bitLen()/8 {
// compression was not worth it
compressor.buf.Reset()
settings.level = NoCompression
if err = settings.writeTo(&compressor.buf); err != nil {
return
}
_, err = compressor.buf.Write(d)
}

return compressor.buf.Bytes(), err
}

// canEncodeSymbol returns true if the symbol can be encoded directly
Expand Down Expand Up @@ -261,3 +275,29 @@ func max(a, b int) int {
}
return b
}

type settings struct {
version byte
level Level
}

func (s *settings) writeTo(w io.Writer) error {
_, err := w.Write([]byte{s.version, byte(s.level)}) // 0 -> compressor release version
return err
}

func (s *settings) readFrom(r io.ByteReader) (err error) {
if s.version, err = r.ReadByte(); err != nil {
return
}
if level, err := r.ReadByte(); err != nil {
return err
} else {
s.level = Level(level)
}
return
}

func (s *settings) bitLen() int {
return 16
}
90 changes: 48 additions & 42 deletions std/compress/lzss/decompress.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ package lzss

import (
"bytes"
"io"

"errors"
"github.com/consensys/gnark/std/compress"
"github.com/icza/bitio"
"io"
)

func DecompressGo(data, dict []byte) (d []byte, err error) {
Expand All @@ -14,13 +14,19 @@ func DecompressGo(data, dict []byte) (d []byte, err error) {
out.Grow(len(data)*6 + len(dict))
in := bitio.NewReader(bytes.NewReader(data))

level := Level(in.TryReadByte())
if level == NoCompression {
return data[1:], nil
var settings settings
if err = settings.readFrom(in); err != nil {
return
}
if settings.version != 0 {
return nil, errors.New("unsupported compressor version")
}
if settings.level == NoCompression {
return data[2:], nil
}

dict = augmentDict(dict)
shortBackRefType, longBackRefType, dictBackRefType := initBackRefTypes(len(dict), level)
shortBackRefType, longBackRefType, dictBackRefType := initBackRefTypes(len(dict), settings.level)

bDict := backref{bType: dictBackRefType}
bShort := backref{bType: shortBackRefType}
Expand Down Expand Up @@ -56,60 +62,60 @@ func DecompressGo(data, dict []byte) (d []byte, err error) {
return out.Bytes(), nil
}

func ReadIntoStream(data, dict []byte, level Level) compress.Stream {
in := bitio.NewReader(bytes.NewReader(data))
// ReadIntoStream reads the compressed data into a stream
// the stream is not padded with zeros as one obtained by a naive call to compress.NewStream may be
func ReadIntoStream(data, dict []byte, level Level) (compress.Stream, error) {

wordLen := int(level)
out, err := compress.NewStream(data, uint8(level))
if err != nil {
return out, err
}

// now find out how much of the stream is padded zeros and remove them
byteReader := bytes.NewReader(data)
in := bitio.NewReader(byteReader)
dict = augmentDict(dict)
var settings settings
if err := settings.readFrom(byteReader); err != nil {
return out, err
}
shortBackRefType, longBackRefType, dictBackRefType := initBackRefTypes(len(dict), level)

bDict := backref{bType: dictBackRefType}
bShort := backref{bType: shortBackRefType}
bLong := backref{bType: longBackRefType}

levelFromData := Level(in.TryReadByte())
if levelFromData != NoCompression && levelFromData != level {
panic("compression mode mismatch")
// the main job of this function is to compute the right value for outLenBits
// so we can remove the extra zeros at the end of out
outLenBits := settings.bitLen()
if settings.level == NoCompression {
return out, nil
}

out := compress.Stream{
NbSymbs: 1 << wordLen,
if settings.level != level {
return out, errors.New("compression mode mismatch")
}

out.WriteNum(int(levelFromData), 8/wordLen)

s := in.TryReadByte()

for in.TryError == nil {
out.WriteNum(int(s), 8/wordLen)

var b *backref
var b *backrefType
switch s {
case symbolShort:
// short back ref
b = &bShort
b = &shortBackRefType
case symbolLong:
// long back ref
b = &bLong
b = &longBackRefType
case symbolDict:
// dict back ref
b = &bDict
b = &dictBackRefType
}
if b != nil && levelFromData != NoCompression {
b.readFrom(in)
address := b.address
if b != &bDict {
address--
}
out.WriteNum(b.length-1, int(b.bType.nbBitsLength)/wordLen)
out.WriteNum(address, int(b.bType.nbBitsAddress)/wordLen)
if b == nil {
outLenBits += 8
} else {
in.TryReadBits(b.nbBitsBackRef - 8)
outLenBits += int(b.nbBitsBackRef)
}

s = in.TryReadByte()
}
if in.TryError != io.EOF {
panic(in.TryError)
return out, in.TryError
}
return out

return compress.Stream{
D: out.D[:outLenBits/int(level)],
NbSymbs: out.NbSymbs,
}, nil
}
8 changes: 6 additions & 2 deletions std/compress/lzss/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,16 @@ func testCompressionE2E(t *testing.T, d, dict []byte, name string) {
c, err := compressor.Compress(d)
assert.NoError(t, err)

cStream := ReadIntoStream(c, dict, BestCompression)
cStream, err := compress.NewStream(c, uint8(compressor.level))
assert.NoError(t, err)

cSum, err := check(cStream, cStream.Len())
assert.NoError(t, err)

dSum, err := check(compress.NewStreamFromBytes(d), len(d))
dStream, err := compress.NewStream(d, 8)
assert.NoError(t, err)

dSum, err := check(dStream, len(d))
assert.NoError(t, err)

circuit := compressionCircuit{
Expand Down
Loading

0 comments on commit 98a1c52

Please sign in to comment.