diff --git a/frontend/api.go b/frontend/api.go index 747ce61b64..f12060ff33 100644 --- a/frontend/api.go +++ b/frontend/api.go @@ -119,8 +119,10 @@ type API interface { // AssertIsDifferent fails if i1 == i2 AssertIsDifferent(i1, i2 Variable) - // AssertIsBoolean fails if v != 0 ∥ v != 1 + // AssertIsBoolean fails if v != 0 and v != 1 AssertIsBoolean(i1 Variable) + // AssertIsCrumb fails if v ∉ {0,1,2,3} (crumb is a 2-bit variable; see https://en.wikipedia.org/wiki/Units_of_information) + AssertIsCrumb(i1 Variable) // AssertIsLessOrEqual fails if v > bound. // diff --git a/frontend/cs/r1cs/api_assertions.go b/frontend/cs/r1cs/api_assertions.go index 2ea2ec155d..b75f200f7a 100644 --- a/frontend/cs/r1cs/api_assertions.go +++ b/frontend/cs/r1cs/api_assertions.go @@ -81,6 +81,12 @@ func (builder *builder) AssertIsBoolean(i1 frontend.Variable) { } } +func (builder *builder) AssertIsCrumb(i1 frontend.Variable) { + i1 = builder.MulAcc(builder.Mul(-3, i1), i1, i1) + i1 = builder.MulAcc(builder.Mul(2, i1), i1, i1) + builder.AssertIsEqual(i1, 0) +} + // AssertIsLessOrEqual adds assertion in constraint builder (v ⩽ bound) // // bound can be a constant or a Variable diff --git a/frontend/cs/scs/api_assertions.go b/frontend/cs/scs/api_assertions.go index 3b0d489ee6..3fe9ef1d9a 100644 --- a/frontend/cs/scs/api_assertions.go +++ b/frontend/cs/scs/api_assertions.go @@ -128,6 +128,30 @@ func (builder *builder) AssertIsBoolean(i1 frontend.Variable) { } +func (builder *builder) AssertIsCrumb(i1 frontend.Variable) { + const errorMsg = "AssertIsCrumb: input is not a crumb" + if c, ok := builder.constantValue(i1); ok { + if i, ok := builder.cs.Uint64(c); ok && i < 4 { + return + } + panic(errorMsg) + } + + // i1 (i1-1) (i1-2) (i1-3) = (i1² - 3i1) (i1² - 3i1 + 2) + // take X := i1² - 3i1 and we get X (X+2) = 0 + + x := builder.MulAcc(builder.Mul(-3, i1), i1, i1).(expr.Term) + + // TODO @Tabaie Ideally this entire function would live in std/math/bits as it is quite specialized; + // however using two generic MulAccs and an AssertIsEqual results in three constraints rather than two. + builder.addPlonkConstraint(sparseR1C{ + xa: x.VID, + xb: x.VID, + qL: builder.cs.FromInterface(2), + qM: builder.tOne, + }) +} + // AssertIsLessOrEqual fails if v > bound func (builder *builder) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) { cv, vConst := builder.constantValue(v) diff --git a/std/compress/lzss/backref.go b/std/compress/lzss/backref.go index fc21ff69c8..587dcd3a50 100644 --- a/std/compress/lzss/backref.go +++ b/std/compress/lzss/backref.go @@ -50,19 +50,21 @@ type backref struct { func (b *backref) writeTo(w *bitio.Writer, i int) { w.TryWriteByte(b.bType.delimiter) w.TryWriteBits(uint64(b.length-1), b.bType.nbBitsLength) - if b.bType.dictOnly { - w.TryWriteBits(uint64(b.address), b.bType.nbBitsAddress) - } else { - w.TryWriteBits(uint64(i-b.address-1), b.bType.nbBitsAddress) + addrToWrite := b.address + if !b.bType.dictOnly { + addrToWrite = i - b.address - 1 } + w.TryWriteBits(uint64(addrToWrite), b.bType.nbBitsAddress) } func (b *backref) readFrom(r *bitio.Reader) { - b.length = int(r.TryReadBits(b.bType.nbBitsLength)) + 1 - if b.bType.dictOnly { - b.address = int(r.TryReadBits(b.bType.nbBitsAddress)) - } else { - b.address = int(r.TryReadBits(b.bType.nbBitsAddress)) + 1 + n := r.TryReadBits(b.bType.nbBitsLength) + b.length = int(n) + 1 + + n = r.TryReadBits(b.bType.nbBitsAddress) + b.address = int(n) + if !b.bType.dictOnly { + b.address++ } } diff --git a/std/compress/lzss/compress.go b/std/compress/lzss/compress.go index eb7c3fa4cb..a60dc981d5 100644 --- a/std/compress/lzss/compress.go +++ b/std/compress/lzss/compress.go @@ -3,6 +3,7 @@ package lzss import ( "bytes" "fmt" + "io" "math/bits" "github.com/consensys/gnark/std/compress/lzss/internal/suffixarray" @@ -103,7 +104,10 @@ func (compressor *Compressor) Compress(d []byte) (c []byte, err error) { // reset output buffer compressor.buf.Reset() - compressor.buf.WriteByte(byte(compressor.level)) + settings := settings{version: 0, level: compressor.level} + if err = settings.writeTo(&compressor.buf); err != nil { + return + } if compressor.level == NoCompression { compressor.buf.Write(d) return compressor.buf.Bytes(), nil @@ -206,11 +210,21 @@ func (compressor *Compressor) Compress(d []byte) (c []byte, err error) { if compressor.bw.TryError != nil { return nil, compressor.bw.TryError } - if err := compressor.bw.Close(); err != nil { + if err = compressor.bw.Close(); err != nil { return nil, err } - return compressor.buf.Bytes(), nil + if compressor.buf.Len() >= len(d)+settings.bitLen()/8 { + // compression was not worth it + compressor.buf.Reset() + settings.level = NoCompression + if err = settings.writeTo(&compressor.buf); err != nil { + return + } + _, err = compressor.buf.Write(d) + } + + return compressor.buf.Bytes(), err } // canEncodeSymbol returns true if the symbol can be encoded directly @@ -261,3 +275,29 @@ func max(a, b int) int { } return b } + +type settings struct { + version byte + level Level +} + +func (s *settings) writeTo(w io.Writer) error { + _, err := w.Write([]byte{s.version, byte(s.level)}) // 0 -> compressor release version + return err +} + +func (s *settings) readFrom(r io.ByteReader) (err error) { + if s.version, err = r.ReadByte(); err != nil { + return + } + if level, err := r.ReadByte(); err != nil { + return err + } else { + s.level = Level(level) + } + return +} + +func (s *settings) bitLen() int { + return 16 +} diff --git a/std/compress/lzss/decompress.go b/std/compress/lzss/decompress.go index 84b11db236..a7a35e7794 100644 --- a/std/compress/lzss/decompress.go +++ b/std/compress/lzss/decompress.go @@ -2,10 +2,10 @@ package lzss import ( "bytes" - "io" - + "errors" "github.com/consensys/gnark/std/compress" "github.com/icza/bitio" + "io" ) func DecompressGo(data, dict []byte) (d []byte, err error) { @@ -14,13 +14,19 @@ func DecompressGo(data, dict []byte) (d []byte, err error) { out.Grow(len(data)*6 + len(dict)) in := bitio.NewReader(bytes.NewReader(data)) - level := Level(in.TryReadByte()) - if level == NoCompression { - return data[1:], nil + var settings settings + if err = settings.readFrom(in); err != nil { + return + } + if settings.version != 0 { + return nil, errors.New("unsupported compressor version") + } + if settings.level == NoCompression { + return data[2:], nil } dict = augmentDict(dict) - shortBackRefType, longBackRefType, dictBackRefType := initBackRefTypes(len(dict), level) + shortBackRefType, longBackRefType, dictBackRefType := initBackRefTypes(len(dict), settings.level) bDict := backref{bType: dictBackRefType} bShort := backref{bType: shortBackRefType} @@ -56,60 +62,60 @@ func DecompressGo(data, dict []byte) (d []byte, err error) { return out.Bytes(), nil } -func ReadIntoStream(data, dict []byte, level Level) compress.Stream { - in := bitio.NewReader(bytes.NewReader(data)) +// ReadIntoStream reads the compressed data into a stream +// the stream is not padded with zeros as one obtained by a naive call to compress.NewStream may be +func ReadIntoStream(data, dict []byte, level Level) (compress.Stream, error) { - wordLen := int(level) + out, err := compress.NewStream(data, uint8(level)) + if err != nil { + return out, err + } + // now find out how much of the stream is padded zeros and remove them + byteReader := bytes.NewReader(data) + in := bitio.NewReader(byteReader) dict = augmentDict(dict) + var settings settings + if err := settings.readFrom(byteReader); err != nil { + return out, err + } shortBackRefType, longBackRefType, dictBackRefType := initBackRefTypes(len(dict), level) - bDict := backref{bType: dictBackRefType} - bShort := backref{bType: shortBackRefType} - bLong := backref{bType: longBackRefType} - - levelFromData := Level(in.TryReadByte()) - if levelFromData != NoCompression && levelFromData != level { - panic("compression mode mismatch") + // the main job of this function is to compute the right value for outLenBits + // so we can remove the extra zeros at the end of out + outLenBits := settings.bitLen() + if settings.level == NoCompression { + return out, nil } - - out := compress.Stream{ - NbSymbs: 1 << wordLen, + if settings.level != level { + return out, errors.New("compression mode mismatch") } - out.WriteNum(int(levelFromData), 8/wordLen) - s := in.TryReadByte() - for in.TryError == nil { - out.WriteNum(int(s), 8/wordLen) - - var b *backref + var b *backrefType switch s { case symbolShort: - // short back ref - b = &bShort + b = &shortBackRefType case symbolLong: - // long back ref - b = &bLong + b = &longBackRefType case symbolDict: - // dict back ref - b = &bDict + b = &dictBackRefType } - if b != nil && levelFromData != NoCompression { - b.readFrom(in) - address := b.address - if b != &bDict { - address-- - } - out.WriteNum(b.length-1, int(b.bType.nbBitsLength)/wordLen) - out.WriteNum(address, int(b.bType.nbBitsAddress)/wordLen) + if b == nil { + outLenBits += 8 + } else { + in.TryReadBits(b.nbBitsBackRef - 8) + outLenBits += int(b.nbBitsBackRef) } - s = in.TryReadByte() } if in.TryError != io.EOF { - panic(in.TryError) + return out, in.TryError } - return out + + return compress.Stream{ + D: out.D[:outLenBits/int(level)], + NbSymbs: out.NbSymbs, + }, nil } diff --git a/std/compress/lzss/e2e_test.go b/std/compress/lzss/e2e_test.go index b83e14b996..70053c5d77 100644 --- a/std/compress/lzss/e2e_test.go +++ b/std/compress/lzss/e2e_test.go @@ -37,12 +37,16 @@ func testCompressionE2E(t *testing.T, d, dict []byte, name string) { c, err := compressor.Compress(d) assert.NoError(t, err) - cStream := ReadIntoStream(c, dict, BestCompression) + cStream, err := compress.NewStream(c, uint8(compressor.level)) + assert.NoError(t, err) cSum, err := check(cStream, cStream.Len()) assert.NoError(t, err) - dSum, err := check(compress.NewStreamFromBytes(d), len(d)) + dStream, err := compress.NewStream(d, 8) + assert.NoError(t, err) + + dSum, err := check(dStream, len(d)) assert.NoError(t, err) circuit := compressionCircuit{ diff --git a/std/compress/lzss/snark.go b/std/compress/lzss/snark.go index 3bf764fff7..e12e6b2954 100644 --- a/std/compress/lzss/snark.go +++ b/std/compress/lzss/snark.go @@ -2,6 +2,7 @@ package lzss import ( "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/compress" "github.com/consensys/gnark/std/lookup/logderivlookup" ) @@ -9,39 +10,35 @@ import ( // d consists of bytes func Decompress(api frontend.API, c []frontend.Variable, cLength frontend.Variable, d []frontend.Variable, dict []byte, level Level) (dLength frontend.Variable, err error) { - wordLen := int(level) + wordNbBits := int(level) + + checkInputRange(api, c, wordNbBits) dict = augmentDict(dict) shortBackRefType, longBackRefType, dictBackRefType := initBackRefTypes(len(dict), level) - shortBrNbWords := int(shortBackRefType.nbBitsBackRef) / wordLen - longBrNbWords := int(longBackRefType.nbBitsBackRef) / wordLen - dictBrNbWords := int(dictBackRefType.nbBitsBackRef) / wordLen - byteNbWords := 8 / wordLen + shortBrNbWords := int(shortBackRefType.nbBitsBackRef) / wordNbBits + longBrNbWords := int(longBackRefType.nbBitsBackRef) / wordNbBits + dictBrNbWords := int(dictBackRefType.nbBitsBackRef) / wordNbBits + byteNbWords := 8 / wordNbBits - fileCompressionMode := readNum(api, c, byteNbWords, wordLen) - c = c[byteNbWords:] - cLength = api.Sub(cLength, byteNbWords) - api.AssertIsEqual(api.Mul(fileCompressionMode, fileCompressionMode), api.Mul(fileCompressionMode, wordLen)) // if fcm!=0, then fcm=wordLen + api.AssertIsEqual(compress.ReadNum(api, c, byteNbWords, wordNbBits), 0) // compressor version TODO @tabaie @gbotrel Handle this outside the circuit instead? + fileCompressionMode := compress.ReadNum(api, c[byteNbWords:], byteNbWords, wordNbBits) + c = c[2*byteNbWords:] + cLength = api.Sub(cLength, 2*byteNbWords) + api.AssertIsEqual(api.Mul(fileCompressionMode, fileCompressionMode), api.Mul(fileCompressionMode, wordNbBits)) // if fcm!=0, then fcm=wordNbBits decompressionNotBypassed := api.Sub(1, api.IsZero(fileCompressionMode)) - // assert that c are within range - cRangeTable := logderivlookup.New(api) - for i := 0; i < 1< 2 { + cRangeTable := logderivlookup.New(api) + for i := 0; i < 1< 0 { - lastSummand = nr.c[nr.nbWords] - } - for i := 1; i < nr.nbWords; i++ { // TODO Cache stepCoeff^nbWords - lastSummand = nr.api.Mul(lastSummand, nr.stepCoeff) - } - - nr.nxt = nr.api.Add(nr.api.DivUnchecked(nr.api.Sub(res, nr.c[0]), nr.stepCoeff), lastSummand) - - nr.c = nr.c[1:] - return res -} - func evaluatePlonkExpression(api frontend.API, a, b frontend.Variable, aCoeff, bCoeff, mCoeff, constant int) frontend.Variable { if plonkAPI, ok := api.(frontend.PlonkAPI); ok { return plonkAPI.EvaluatePlonkExpression(a, b, aCoeff, bCoeff, mCoeff, constant) diff --git a/std/compress/lzss/snark_test.go b/std/compress/lzss/snark_test.go index 726f1e8c99..82f4520313 100644 --- a/std/compress/lzss/snark_test.go +++ b/std/compress/lzss/snark_test.go @@ -19,7 +19,11 @@ func Test1ZeroSnark(t *testing.T) { testCompressionRoundTripSnark(t, []byte{0}, nil) } -func Test0To10Explicit(t *testing.T) { +func TestGoodCompressionSnark(t *testing.T) { + testCompressionRoundTripSnark(t, []byte{1, 2}, nil, withLevel(GoodCompression)) +} + +func Test0To10ExplicitSnark(t *testing.T) { testCompressionRoundTripSnark(t, []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, nil) } @@ -35,24 +39,27 @@ func TestNoCompressionSnark(t *testing.T) { c, err := compressor.Compress(d) require.NoError(t, err) - cStream := ReadIntoStream(c, dict, BestCompression) + decompressorLevel := BestCompression + + cStream, err := compress.NewStream(c, uint8(decompressorLevel)) + require.NoError(t, err) circuit := &DecompressionTestCircuit{ C: make([]frontend.Variable, cStream.Len()), D: d, Dict: dict, CheckCorrectness: true, - Level: BestCompression, + Level: decompressorLevel, } assignment := &DecompressionTestCircuit{ C: test_vector_utils.ToVariableSlice(cStream.D), CLength: cStream.Len(), } - test.NewAssert(t).SolvingSucceeded(circuit, assignment, test.WithBackends(backend.PLONK), test.WithCurves(ecc.BN254)) + test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithBackends(backend.PLONK), test.WithCurves(ecc.BN254)) } -func Test4ZerosBackref(t *testing.T) { +func Test4ZerosBackrefSnark(t *testing.T) { shortBackRefType, longBackRefType, _ := initBackRefTypes(0, BestCompression) @@ -68,7 +75,7 @@ func Test4ZerosBackref(t *testing.T) { ) } -func Test255_254_253(t *testing.T) { +func Test255_254_253Snark(t *testing.T) { testCompressionRoundTripSnark(t, []byte{255, 254, 253}, nil) } @@ -82,7 +89,7 @@ func Test3c2943Snark(t *testing.T) { } // Fuzz test the decompression -func FuzzSnark(f *testing.F) { +func FuzzSnark(f *testing.F) { // TODO This is always skipped f.Fuzz(func(t *testing.T, input, dict []byte) { if len(input) > maxInputSize { t.Skip("input too large") @@ -97,11 +104,20 @@ func FuzzSnark(f *testing.F) { }) } -func testCompressionRoundTripSnark(t *testing.T, d, dict []byte) { +type testCompressionRoundTripOption func(*Level) + +func withLevel(level Level) testCompressionRoundTripOption { + return func(l *Level) { + *l = level + } +} + +func testCompressionRoundTripSnark(t *testing.T, d, dict []byte, options ...testCompressionRoundTripOption) { level := BestCompression - if len(d) > 1000 { - level = GoodCompression + + for _, option := range options { + option(&level) } compressor, err := NewCompressor(dict, level) @@ -109,7 +125,8 @@ func testCompressionRoundTripSnark(t *testing.T, d, dict []byte) { c, err := compressor.Compress(d) require.NoError(t, err) - cStream := ReadIntoStream(c, dict, level) + cStream, err := ReadIntoStream(c, dict, level) + require.NoError(t, err) circuit := &DecompressionTestCircuit{ C: make([]frontend.Variable, cStream.Len()), @@ -129,7 +146,7 @@ func testCompressionRoundTripSnark(t *testing.T, d, dict []byte) { func testDecompressionSnark(t *testing.T, dict []byte, level Level, compressedStream ...interface{}) { var bb bytes.Buffer w := bitio.NewWriter(&bb) - bb.WriteByte(byte(level)) + bb.Write([]byte{0, byte(level)}) i := 0 for _, c := range compressedStream { switch v := c.(type) { @@ -156,36 +173,39 @@ func testDecompressionSnark(t *testing.T, dict []byte, level Level, compressedSt c := bb.Bytes() d, err := DecompressGo(c, dict) require.NoError(t, err) - cStream := ReadIntoStream(c, dict, BestCompression) + + cStream, err := ReadIntoStream(c, dict, level) + require.NoError(t, err) circuit := &DecompressionTestCircuit{ C: make([]frontend.Variable, cStream.Len()), D: d, Dict: dict, CheckCorrectness: true, - Level: BestCompression, + Level: level, } assignment := &DecompressionTestCircuit{ C: test_vector_utils.ToVariableSlice(cStream.D), CLength: cStream.Len(), } - test.NewAssert(t).SolvingSucceeded(circuit, assignment, test.WithBackends(backend.PLONK), test.WithCurves(ecc.BN254)) + test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithBackends(backend.PLONK), test.WithCurves(ecc.BN254)) } func TestReadBytes(t *testing.T) { - expected := []byte{0, 254, 0, 0} + expected := []byte{254, 0, 0, 0} circuit := &readBytesCircuit{ Words: make([]frontend.Variable, 8*len(expected)), WordNbBits: 1, Expected: expected, } - words := compress.NewStreamFromBytes(expected) + words, err := compress.NewStream(expected, 8) + assert.NoError(t, err) words = words.BreakUp(2) assignment := &readBytesCircuit{ Words: test_vector_utils.ToVariableSlice(words.D), } - test.NewAssert(t).SolvingSucceeded(circuit, assignment, test.WithBackends(backend.PLONK), test.WithCurves(ecc.BN254)) + test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithBackends(backend.PLONK), test.WithCurves(ecc.BN254)) } type readBytesCircuit struct { diff --git a/std/compress/lzss/snark_testing.go b/std/compress/lzss/snark_testing.go index dd75dd5ec8..3c8f5e5da9 100644 --- a/std/compress/lzss/snark_testing.go +++ b/std/compress/lzss/snark_testing.go @@ -27,16 +27,13 @@ type DecompressionTestCircuit struct { func (c *DecompressionTestCircuit) Define(api frontend.API) error { dBack := make([]frontend.Variable, len(c.D)) // TODO Try smaller constants - api.Println("maxLen(dBack)", len(dBack)) dLen, err := Decompress(api, c.C, c.CLength, dBack, c.Dict, c.Level) if err != nil { return err } if c.CheckCorrectness { - api.Println("got len", dLen, "expected", len(c.D)) api.AssertIsEqual(len(c.D), dLen) for i := range c.D { - api.Println("decompressed at", i, "->", dBack[i], "expected", c.D[i], "dBack", dBack[i]) api.AssertIsEqual(c.D[i], dBack[i]) } } @@ -61,13 +58,16 @@ func BenchCompressionE2ECompilation(dict []byte, name string) (constraint.Constr return nil, err } - cStream := ReadIntoStream(c, dict, GoodCompression) + cStream, err := compress.NewStream(c, uint8(compressor.level)) + if err != nil { + return nil, err + } circuit := compressionCircuit{ C: make([]frontend.Variable, cStream.Len()), D: make([]frontend.Variable, len(d)), Dict: make([]byte, len(dict)), - Level: GoodCompression, + Level: compressor.level, } var start int64 diff --git a/std/compress/snark_io.go b/std/compress/snark_io.go index 4598870202..c72cf0585c 100644 --- a/std/compress/snark_io.go +++ b/std/compress/snark_io.go @@ -19,3 +19,51 @@ func Pack(api frontend.API, words []frontend.Variable, wordLen int) []frontend.V } return res } + +type NumReader struct { + api frontend.API + c []frontend.Variable + stepCoeff int + maxCoeff int + nbWords int + nxt frontend.Variable +} + +func NewNumReader(api frontend.API, c []frontend.Variable, numNbBits, wordNbBits int) *NumReader { + nbWords := numNbBits / wordNbBits + stepCoeff := 1 << wordNbBits + nxt := ReadNum(api, c, nbWords, stepCoeff) + return &NumReader{ + api: api, + c: c, + stepCoeff: stepCoeff, + maxCoeff: 1 << numNbBits, + nxt: nxt, + nbWords: nbWords, + } +} + +func ReadNum(api frontend.API, c []frontend.Variable, nbWords, stepCoeff int) frontend.Variable { + res := frontend.Variable(0) + for i := 0; i < nbWords && i < len(c); i++ { + res = api.Add(c[i], api.Mul(res, stepCoeff)) + } + return res +} + +// Next returns the next number in the sequence. assumes bits past the end of the slice are 0 +func (nr *NumReader) Next() frontend.Variable { + res := nr.nxt + + if len(nr.c) != 0 { + nr.nxt = nr.api.Sub(nr.api.Mul(nr.nxt, nr.stepCoeff), nr.api.Mul(nr.c[0], nr.maxCoeff)) + + if nr.nbWords < len(nr.c) { + nr.nxt = nr.api.Add(nr.nxt, nr.c[nr.nbWords]) + } + + nr.c = nr.c[1:] + } + + return res +} diff --git a/std/compress/stream.go b/std/compress/stream.go index 4bc141047a..c0cdbf5b9c 100644 --- a/std/compress/stream.go +++ b/std/compress/stream.go @@ -31,12 +31,17 @@ func (s *Stream) At(i int) int { return s.D[i] } -func NewStreamFromBytes(in []byte) Stream { - d := make([]int, len(in)) - for i := range in { - d[i] = int(in[i]) +func NewStream(in []byte, bitsPerSymbol uint8) (Stream, error) { + d := make([]int, len(in)*8/int(bitsPerSymbol)) + r := bitio.NewReader(bytes.NewReader(in)) + for i := range d { + if n, err := r.ReadBits(bitsPerSymbol); err != nil { + return Stream{}, err + } else { + d[i] = int(n) + } } - return Stream{d, 256} + return Stream{d, 1 << int(bitsPerSymbol)}, nil } func (s *Stream) BreakUp(nbSymbs int) Stream { @@ -46,7 +51,7 @@ func (s *Stream) BreakUp(nbSymbs int) Stream { for i := range s.D { v := s.D[i] for j := 0; j < newPerOld; j++ { - d[i*newPerOld+j] = v % nbSymbs + d[(i+1)*newPerOld-j-1] = v % nbSymbs v /= nbSymbs } } diff --git a/test/api_assertions_test.go b/test/api_assertions_test.go new file mode 100644 index 0000000000..f53e2e2f25 --- /dev/null +++ b/test/api_assertions_test.go @@ -0,0 +1,28 @@ +package test + +import ( + "github.com/consensys/gnark/frontend" + "math/rand" + "testing" +) + +func TestIsCrumb(t *testing.T) { + c := []frontend.Variable{0, 1, 2, 3} + assert := NewAssert(t) + assert.SolvingSucceeded(&isCrumbCircuit{C: make([]frontend.Variable, len(c))}, &isCrumbCircuit{C: c}) + for n := 0; n < 20; n++ { + x := rand.Intn(65531) + 4 //#nosec G404 weak rng OK for test + assert.SolvingFailed(&isCrumbCircuit{C: []frontend.Variable{nil}}, &isCrumbCircuit{C: []frontend.Variable{x}}) + } +} + +type isCrumbCircuit struct { + C []frontend.Variable +} + +func (circuit *isCrumbCircuit) Define(api frontend.API) error { + for _, x := range circuit.C { + api.AssertIsCrumb(x) + } + return nil +} diff --git a/test/engine.go b/test/engine.go index 5702832bdb..26f51ae76c 100644 --- a/test/engine.go +++ b/test/engine.go @@ -464,6 +464,12 @@ func (e *engine) AssertIsBoolean(i1 frontend.Variable) { e.mustBeBoolean(b1) } +func (e *engine) AssertIsCrumb(i1 frontend.Variable) { + i1 = e.MulAcc(e.Mul(-3, i1), i1, i1) + i1 = e.MulAcc(e.Mul(2, i1), i1, i1) + e.AssertIsEqual(i1, 0) +} + func (e *engine) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) { bValue := e.toBigInt(bound)