From 34afa95dbad0b4e458efc427f72971ed2e6fdf4f Mon Sep 17 00:00:00 2001 From: Michal Hruby Date: Wed, 5 May 2021 20:06:06 +0100 Subject: [PATCH 1/3] Allow direct usage of go slices --- bufferpool.go | 30 ++++++++ gozstd.go | 76 +++++++++++--------- reader.go | 182 +++++++++++++++++++++++++++-------------------- writer.go | 190 ++++++++++++++++++++++++++++++-------------------- 4 files changed, 291 insertions(+), 187 deletions(-) create mode 100644 bufferpool.go diff --git a/bufferpool.go b/bufferpool.go new file mode 100644 index 0000000..b3d7000 --- /dev/null +++ b/bufferpool.go @@ -0,0 +1,30 @@ +package gozstd + +import ( + "bytes" + "sync" +) + +var compInBufPool = sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, 0, cstreamInBufSize)) + }, +} + +var compOutBufPool = sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, 0, cstreamOutBufSize)) + }, +} + +var decInBufPool = sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, 0, dstreamInBufSize)) + }, +} + +var decOutBufPool = sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, 0, dstreamOutBufSize)) + }, +} diff --git a/gozstd.go b/gozstd.go index cc9ffa4..85f07d0 100644 --- a/gozstd.go +++ b/gozstd.go @@ -7,29 +7,27 @@ package gozstd #include "zstd.h" #include "zstd_errors.h" -#include // for uintptr_t - // The following *_wrapper functions allow avoiding memory allocations // durting calls from Go. // See https://github.com/golang/go/issues/24450 . -static size_t ZSTD_compressCCtx_wrapper(uintptr_t ctx, uintptr_t dst, size_t dstCapacity, uintptr_t src, size_t srcSize, int compressionLevel) { - return ZSTD_compressCCtx((ZSTD_CCtx*)ctx, (void*)dst, dstCapacity, (const void*)src, srcSize, compressionLevel); +static size_t ZSTD_compressCCtx_wrapper(void *ctx, void *dst, size_t dstCapacity, const void *src, size_t srcSize, int compressionLevel) { + return ZSTD_compressCCtx((ZSTD_CCtx*)ctx, dst, dstCapacity, src, srcSize, compressionLevel); } -static size_t ZSTD_compress_usingCDict_wrapper(uintptr_t ctx, uintptr_t dst, size_t dstCapacity, uintptr_t src, size_t srcSize, uintptr_t cdict) { +static size_t ZSTD_compress_usingCDict_wrapper(void *ctx, void *dst, size_t dstCapacity, void *src, size_t srcSize, void *cdict) { return ZSTD_compress_usingCDict((ZSTD_CCtx*)ctx, (void*)dst, dstCapacity, (const void*)src, srcSize, (const ZSTD_CDict*)cdict); } -static size_t ZSTD_decompressDCtx_wrapper(uintptr_t ctx, uintptr_t dst, size_t dstCapacity, uintptr_t src, size_t srcSize) { +static size_t ZSTD_decompressDCtx_wrapper(void *ctx, void *dst, size_t dstCapacity, void *src, size_t srcSize) { return ZSTD_decompressDCtx((ZSTD_DCtx*)ctx, (void*)dst, dstCapacity, (const void*)src, srcSize); } -static size_t ZSTD_decompress_usingDDict_wrapper(uintptr_t ctx, uintptr_t dst, size_t dstCapacity, uintptr_t src, size_t srcSize, uintptr_t ddict) { +static size_t ZSTD_decompress_usingDDict_wrapper(void *ctx, void *dst, size_t dstCapacity, void *src, size_t srcSize, void *ddict) { return ZSTD_decompress_usingDDict((ZSTD_DCtx*)ctx, (void*)dst, dstCapacity, (const void*)src, srcSize, (const ZSTD_DDict*)ddict); } -static unsigned long long ZSTD_getFrameContentSize_wrapper(uintptr_t src, size_t srcSize) { +static unsigned long long ZSTD_getFrameContentSize_wrapper(void *src, size_t srcSize) { return ZSTD_getFrameContentSize((const void*)src, srcSize); } */ @@ -38,6 +36,7 @@ import "C" import ( "fmt" "io" + "reflect" "runtime" "sync" "unsafe" @@ -147,34 +146,37 @@ func 
compress(cctx, cctxDict *cctxWrapper, dst, src []byte, cd *CDict, compressi } func compressInternal(cctx, cctxDict *cctxWrapper, dst, src []byte, cd *CDict, compressionLevel int, mustSucceed bool) C.size_t { + dstHdr := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + srcHdr := (*reflect.SliceHeader)(unsafe.Pointer(&src)) + if cd != nil { result := C.ZSTD_compress_usingCDict_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(cctxDict.cctx))), - C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))), + unsafe.Pointer(cctxDict.cctx), + unsafe.Pointer(dstHdr.Data), C.size_t(cap(dst)), - C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))), + unsafe.Pointer(srcHdr.Data), C.size_t(len(src)), - C.uintptr_t(uintptr(unsafe.Pointer(cd.p)))) + unsafe.Pointer(cd.p)) // Prevent from GC'ing of dst and src during CGO call above. runtime.KeepAlive(dst) runtime.KeepAlive(src) if mustSucceed { - ensureNoError("ZSTD_compress_usingCDict_wrapper", result) + ensureNoError("ZSTD_compress_usingCDict", result) } return result } result := C.ZSTD_compressCCtx_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(cctx.cctx))), - C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))), + unsafe.Pointer(cctx.cctx), + unsafe.Pointer(dstHdr.Data), C.size_t(cap(dst)), - C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))), + unsafe.Pointer(srcHdr.Data), C.size_t(len(src)), C.int(compressionLevel)) // Prevent from GC'ing of dst and src during CGO call above. runtime.KeepAlive(dst) runtime.KeepAlive(src) if mustSucceed { - ensureNoError("ZSTD_compressCCtx_wrapper", result) + ensureNoError("ZSTD_compressCCtx", result) } return result } @@ -254,10 +256,8 @@ func decompress(dctx, dctxDict *dctxWrapper, dst, src []byte, dd *DDict) ([]byte } // Slow path - resize dst to fit decompressed data. - decompressBound := int(C.ZSTD_getFrameContentSize_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))), C.size_t(len(src)))) - // Prevent from GC'ing of src during CGO call above. - runtime.KeepAlive(src) + srcHdr := (*reflect.SliceHeader)(unsafe.Pointer(&src)) + decompressBound := int(C.ZSTD_getFrameContentSize_wrapper(unsafe.Pointer(srcHdr.Data), C.size_t(len(src)))) switch uint64(decompressBound) { case uint64(C.ZSTD_CONTENTSIZE_UNKNOWN): return streamDecompress(dst, src, dd) @@ -287,24 +287,28 @@ func decompress(dctx, dctxDict *dctxWrapper, dst, src []byte, dd *DDict) ([]byte } func decompressInternal(dctx, dctxDict *dctxWrapper, dst, src []byte, dd *DDict) C.size_t { - var n C.size_t + var ( + dstHdr = (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + srcHdr = (*reflect.SliceHeader)(unsafe.Pointer(&src)) + n C.size_t + ) if dd != nil { n = C.ZSTD_decompress_usingDDict_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(dctxDict.dctx))), - C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))), + unsafe.Pointer(dctxDict.dctx), + unsafe.Pointer(dstHdr.Data), C.size_t(cap(dst)), - C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))), + unsafe.Pointer(srcHdr.Data), C.size_t(len(src)), - C.uintptr_t(uintptr(unsafe.Pointer(dd.p)))) + unsafe.Pointer(dd.p)) } else { n = C.ZSTD_decompressDCtx_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(dctx.dctx))), - C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))), + unsafe.Pointer(dctx.dctx), + unsafe.Pointer(dstHdr.Data), C.size_t(cap(dst)), - C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))), + unsafe.Pointer(srcHdr.Data), C.size_t(len(src))) } - // Prevent from GC'ing of dst and src during CGO calls above. + // Prevent from GC'ing of dst and src during CGO call above. 
runtime.KeepAlive(dst) runtime.KeepAlive(src) return n @@ -317,13 +321,17 @@ func errStr(result C.size_t) string { } func ensureNoError(funcName string, result C.size_t) { + if zstdIsError(result) { + panic(fmt.Errorf("BUG: unexpected error in %s: %s", funcName, errStr(result))) + } +} + +func zstdIsError(result C.size_t) bool { if int(result) >= 0 { // Fast path - avoid calling C function. - return - } - if C.ZSTD_getErrorCode(result) != 0 { - panic(fmt.Errorf("BUG: unexpected error in %s: %s", funcName, errStr(result))) + return false } + return C.ZSTD_isError(result) != 0 } func streamDecompress(dst, src []byte, dd *DDict) ([]byte, error) { diff --git a/reader.go b/reader.go index 6766b71..3923f4f 100644 --- a/reader.go +++ b/reader.go @@ -7,14 +7,18 @@ package gozstd #include "zstd.h" #include "zstd_errors.h" -#include // for malloc/free -#include // for uintptr_t +typedef struct { + size_t dstSize; + size_t srcSize; + size_t dstPos; + size_t srcPos; +} ZSTD_EXT_BufferSizes; // The following *_wrapper functions allow avoiding memory allocations // durting calls from Go. // See https://github.com/golang/go/issues/24450 . -static size_t ZSTD_initDStream_usingDDict_wrapper(uintptr_t ds, uintptr_t dict) { +static size_t ZSTD_initDStream_usingDDict_wrapper(void *ds, void *dict) { ZSTD_DStream *zds = (ZSTD_DStream *)ds; size_t rv = ZSTD_DCtx_reset(zds, ZSTD_reset_session_only); if (rv != 0) { @@ -23,23 +27,27 @@ static size_t ZSTD_initDStream_usingDDict_wrapper(uintptr_t ds, uintptr_t dict) return ZSTD_DCtx_refDDict(zds, (ZSTD_DDict *)dict); } -static size_t ZSTD_freeDStream_wrapper(uintptr_t ds) { +static size_t ZSTD_freeDStream_wrapper(void *ds) { return ZSTD_freeDStream((ZSTD_DStream*)ds); } -static size_t ZSTD_decompressStream_wrapper(uintptr_t ds, uintptr_t output, uintptr_t input) { - return ZSTD_decompressStream((ZSTD_DStream*)ds, (ZSTD_outBuffer*)output, (ZSTD_inBuffer*)input); +static size_t ZSTD_decompressStream_wrapper(void *ds, void* dst, const void* src, ZSTD_EXT_BufferSizes* sizes) { + return ZSTD_decompressStream_simpleArgs((ZSTD_DStream*)ds, dst, sizes->dstSize, &sizes->dstPos, src, sizes->srcSize, &sizes->srcPos); } */ import "C" import ( + "bytes" "fmt" "io" + "reflect" "runtime" "unsafe" ) +const minDirectWriteBufferSize = 16 * 1024 + var ( dstreamInBufSize = C.ZSTD_DStreamInSize() dstreamOutBufSize = C.ZSTD_DStreamOutSize() @@ -51,11 +59,17 @@ type Reader struct { ds *C.ZSTD_DStream dd *DDict - inBuf *C.ZSTD_inBuffer - outBuf *C.ZSTD_outBuffer + inBufWrapper *bytes.Buffer + outBufWrapper *bytes.Buffer + + skipNextRead bool - inBufGo cMemPtr - outBufGo cMemPtr + readerPos int + inBuf []byte + outBuf []byte + // go doesn't allow passing pointers to structs with pointers to Go memory + // so we can't use ZSTD_inBuffer and ZSTD_outBuffer directly + sizes C.ZSTD_EXT_BufferSizes } // NewReader returns new zstd reader reading compressed data from r. 
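
The hunk above replaces the heap-allocated ZSTD_inBuffer/ZSTD_outBuffer with the plain ZSTD_EXT_BufferSizes struct because cgo forbids passing C a pointer to a struct that itself contains pointers to Go memory, which the zstd buffer structs would once they point at pooled Go slices. A minimal standalone sketch of that pattern, assuming nothing beyond the cgo pointer rules themselves (the copy_some helper and its names are illustrative, not part of zstd or this patch):

package main

/*
#include <stddef.h>
#include <string.h>

typedef struct {
	size_t dstSize;
	size_t dstPos;
} sizes;

// copy_some copies up to s->dstSize bytes from src into dst and records the
// number of bytes written in s->dstPos.
static void copy_some(void *dst, const void *src, size_t srcSize, sizes *s) {
	size_t n = srcSize < s->dstSize ? srcSize : s->dstSize;
	memcpy(dst, src, n);
	s->dstPos = n;
}
*/
import "C"

import (
	"fmt"
	"unsafe"
)

func main() {
	src := []byte("hello, cgo")
	dst := make([]byte, 32)

	// The struct passed by pointer holds only plain size_t fields, so cgo
	// allows it; the Go-backed buffers travel as separate raw pointers,
	// which is valid for the duration of the C call.
	s := C.sizes{dstSize: C.size_t(len(dst))}
	C.copy_some(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), C.size_t(len(src)), &s)

	fmt.Println(string(dst[:s.dstPos])) // prints "hello, cgo"
}

The wrapper functions in this patch do the same thing with ZSTD_decompressStream_simpleArgs and ZSTD_compressStream2_simpleArgs: the data pointers are passed as bare arguments, and the only struct C mutates contains nothing but integers.
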
@@ -73,37 +87,29 @@ func NewReaderDict(r io.Reader, dd *DDict) *Reader { ds := C.ZSTD_createDStream() initDStream(ds, dd) - inBuf := (*C.ZSTD_inBuffer)(C.malloc(C.sizeof_ZSTD_inBuffer)) - inBuf.src = C.malloc(dstreamInBufSize) - inBuf.size = 0 - inBuf.pos = 0 - - outBuf := (*C.ZSTD_outBuffer)(C.malloc(C.sizeof_ZSTD_outBuffer)) - outBuf.dst = C.malloc(dstreamOutBufSize) - outBuf.size = 0 - outBuf.pos = 0 + inBufWrapper := decInBufPool.Get().(*bytes.Buffer) + outBufWrapper := decOutBufPool.Get().(*bytes.Buffer) zr := &Reader{ - r: r, - ds: ds, - dd: dd, - inBuf: inBuf, - outBuf: outBuf, + r: r, + ds: ds, + dd: dd, + inBufWrapper: inBufWrapper, + outBufWrapper: outBufWrapper, + inBuf: inBufWrapper.Bytes(), + outBuf: outBufWrapper.Bytes(), } - zr.inBufGo = cMemPtr(zr.inBuf.src) - zr.outBufGo = cMemPtr(zr.outBuf.dst) - runtime.SetFinalizer(zr, freeDStream) return zr } // Reset resets zr to read from r using the given dictionary dd. func (zr *Reader) Reset(r io.Reader, dd *DDict) { - zr.inBuf.size = 0 - zr.inBuf.pos = 0 - zr.outBuf.size = 0 - zr.outBuf.pos = 0 + zr.readerPos = 0 + zr.sizes = C.ZSTD_EXT_BufferSizes{} + zr.inBuf = zr.inBuf[:0] + zr.outBuf = zr.outBuf[:0] zr.dd = dd initDStream(zr.ds, zr.dd) @@ -116,9 +122,7 @@ func initDStream(ds *C.ZSTD_DStream, dd *DDict) { if dd != nil { ddict = dd.p } - result := C.ZSTD_initDStream_usingDDict_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(ds))), - C.uintptr_t(uintptr(unsafe.Pointer(ddict)))) + result := C.ZSTD_initDStream_usingDDict_wrapper(unsafe.Pointer(ds), unsafe.Pointer(ddict)) ensureNoError("ZSTD_initDStream_usingDDict", result) } @@ -134,21 +138,23 @@ func (zr *Reader) Release() { return } - result := C.ZSTD_freeDStream_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(zr.ds)))) + result := C.ZSTD_freeDStream_wrapper(unsafe.Pointer(zr.ds)) ensureNoError("ZSTD_freeDStream", result) zr.ds = nil - C.free(zr.inBuf.src) - C.free(unsafe.Pointer(zr.inBuf)) - zr.inBuf = nil - - C.free(zr.outBuf.dst) - C.free(unsafe.Pointer(zr.outBuf)) - zr.outBuf = nil - zr.r = nil zr.dd = nil + + if zr.inBuf != nil { + zr.inBuf = nil + decInBufPool.Put(zr.inBufWrapper) + zr.inBufWrapper = nil + } + if zr.outBuf != nil { + zr.outBuf = nil + decOutBufPool.Put(zr.outBufWrapper) + zr.outBufWrapper = nil + } } // WriteTo writes all the data from zr to w. 
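
With the new minDirectWriteBufferSize path in this file, a Read that supplies a buffer of at least 16 KiB can have data decompressed straight into the caller's slice, and Release now returns the pooled buffers instead of freeing C memory. A hedged usage sketch of that behaviour, using only the package's existing public API (not part of the patch):

package main

import (
	"bytes"
	"fmt"
	"io"

	"github.com/valyala/gozstd"
)

func main() {
	compressed := gozstd.Compress(nil, bytes.Repeat([]byte("some payload "), 1<<12))

	zr := gozstd.NewReader(bytes.NewReader(compressed))
	defer zr.Release() // hands the pooled in/out buffers back to the pools

	// 64 KiB chunks are above the 16 KiB threshold, so decompression can
	// write directly into buf instead of going through the internal outBuf.
	buf := make([]byte, 64*1024)
	total := 0
	for {
		n, err := zr.Read(buf)
		total += n
		if err == io.EOF {
			break
		}
		if err != nil {
			panic(err)
		}
	}
	fmt.Printf("decompressed %d bytes\n", total)
}
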
@@ -157,16 +163,17 @@ func (zr *Reader) Release() { func (zr *Reader) WriteTo(w io.Writer) (int64, error) { nn := int64(0) for { - if zr.outBuf.pos == zr.outBuf.size { - if err := zr.fillOutBuf(); err != nil { + if zr.readerPos >= len(zr.outBuf) { + if _, err := zr.fillOutBuf(nil); err != nil { if err == io.EOF { return nn, nil } return nn, err } + zr.readerPos = 0 } - n, err := w.Write(zr.outBufGo[zr.outBuf.pos:zr.outBuf.size]) - zr.outBuf.pos += C.size_t(n) + n, err := w.Write(zr.outBuf[zr.readerPos:]) + zr.readerPos += n nn += int64(n) if err != nil { return nn, err @@ -180,51 +187,68 @@ func (zr *Reader) Read(p []byte) (int, error) { return 0, nil } - if zr.outBuf.pos == zr.outBuf.size { - if err := zr.fillOutBuf(); err != nil { + if zr.readerPos >= len(zr.outBuf) { + if len(p) >= minDirectWriteBufferSize { + // write directly into the target buffer + // but make sure to override its capacity + return zr.fillOutBuf(p[:len(p):len(p)]) + } + if _, err := zr.fillOutBuf(nil); err != nil { return 0, err } + zr.readerPos = 0 } - n := copy(p, zr.outBufGo[zr.outBuf.pos:zr.outBuf.size]) - zr.outBuf.pos += C.size_t(n) + n := copy(p, zr.outBuf[zr.readerPos:]) + zr.readerPos += n return n, nil } -func (zr *Reader) fillOutBuf() error { - if zr.inBuf.pos == zr.inBuf.size && zr.outBuf.size < dstreamOutBufSize { +func (zr *Reader) fillOutBuf(target []byte) (int, error) { + dst := target + if dst == nil { + dst = zr.outBuf + } + + if int(zr.sizes.srcPos) == len(zr.inBuf) && !zr.skipNextRead { // inBuf is empty and the previously decompressed data size - // is smaller than the maximum possible zr.outBuf.size. + // is smaller than the maximum possible dst.size. // This means that the internal buffer in zr.ds doesn't contain // more data to decompress, so read new data into inBuf. if err := zr.fillInBuf(); err != nil { - return err + return 0, err } } + zr.sizes.dstSize = C.size_t(cap(dst)) + zr.sizes.dstPos = 0 + + inHdr := (*reflect.SliceHeader)(unsafe.Pointer(&zr.inBuf)) + outHdr := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) tryDecompressAgain: + zr.sizes.srcSize = C.size_t(len(zr.inBuf)) + prevInBufPos := zr.sizes.srcPos + // Try decompressing inBuf into outBuf. - zr.outBuf.size = dstreamOutBufSize - zr.outBuf.pos = 0 - prevInBufPos := zr.inBuf.pos result := C.ZSTD_decompressStream_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(zr.ds))), - C.uintptr_t(uintptr(unsafe.Pointer(zr.outBuf))), - C.uintptr_t(uintptr(unsafe.Pointer(zr.inBuf)))) - zr.outBuf.size = zr.outBuf.pos - zr.outBuf.pos = 0 - - if C.ZSTD_getErrorCode(result) != 0 { - return fmt.Errorf("cannot decompress data: %s", errStr(result)) + unsafe.Pointer(zr.ds), unsafe.Pointer(outHdr.Data), unsafe.Pointer(inHdr.Data), &zr.sizes) + + zr.skipNextRead = int(zr.sizes.dstPos) == cap(dst) + if target == nil { + zr.outBuf = zr.outBuf[:zr.sizes.dstPos] + } + + if zstdIsError(result) { + return int(zr.sizes.dstPos), fmt.Errorf("cannot decompress data: %s", errStr(result)) } - if zr.outBuf.size > 0 { + if zr.sizes.dstPos > 0 { // Something has been decompressed to outBuf. Return it. - return nil + return int(zr.sizes.dstPos), nil } // Nothing has been decompressed from inBuf. - if zr.inBuf.pos != prevInBufPos && zr.inBuf.pos < zr.inBuf.size { + if zr.sizes.srcPos != prevInBufPos && int(zr.sizes.srcPos) < len(zr.inBuf) { // Data has been consumed from inBuf, but decompressed // into nothing. There is more data in inBuf, so try // decompressing it again. @@ -235,21 +259,25 @@ tryDecompressAgain: // decompressed into nothing and inBuf became empty. 
// Read more data into inBuf and try decompressing again. if err := zr.fillInBuf(); err != nil { - return err + return 0, err } + goto tryDecompressAgain } func (zr *Reader) fillInBuf() error { // Copy the remaining data to the start of inBuf. - copy(zr.inBufGo[:dstreamInBufSize], zr.inBufGo[zr.inBuf.pos:zr.inBuf.size]) - zr.inBuf.size -= zr.inBuf.pos - zr.inBuf.pos = 0 + if zr.sizes.srcPos > 0 && int(zr.sizes.srcPos) > cap(zr.inBuf)/2 { + copy(zr.inBuf[:cap(zr.inBuf)], zr.inBuf[zr.sizes.srcPos:]) + zr.inBuf = zr.inBuf[:len(zr.inBuf)-int(zr.sizes.srcPos)] + zr.sizes.srcPos = 0 + } readAgain: // Read more data into inBuf. - n, err := zr.r.Read(zr.inBufGo[zr.inBuf.size:dstreamInBufSize]) - zr.inBuf.size += C.size_t(n) + n, err := zr.r.Read(zr.inBuf[len(zr.inBuf):cap(zr.inBuf)]) + zr.inBuf = zr.inBuf[:len(zr.inBuf)+n] + if err == nil { if n == 0 { // Nothing has been read. Try reading data again. @@ -265,5 +293,5 @@ readAgain: // Do not wrap io.EOF, so the caller may notify the end of stream. return err } - return fmt.Errorf("cannot read data from the underlying reader: %s", err) + return fmt.Errorf("cannot read data from the underlying reader: %w", err) } diff --git a/writer.go b/writer.go index ddfb85a..87a3be1 100644 --- a/writer.go +++ b/writer.go @@ -7,47 +7,70 @@ package gozstd #include "zstd.h" #include "zstd_errors.h" -#include // for malloc/free -#include // for uintptr_t +typedef struct { + size_t dstSize; + size_t srcSize; + size_t dstPos; + size_t srcPos; +} ZSTD_EXT_BufferSizes; // The following *_wrapper functions allow avoiding memory allocations // durting calls from Go. // See https://github.com/golang/go/issues/24450 . - -static size_t ZSTD_CCtx_setParameter_wrapper(uintptr_t cs, ZSTD_cParameter param, int value) { +static size_t ZSTD_CCtx_setParameter_wrapper(void *cs, ZSTD_cParameter param, int value) { return ZSTD_CCtx_setParameter((ZSTD_CStream*)cs, param, value); } -static size_t ZSTD_initCStream_wrapper(uintptr_t cs, int compressionLevel) { +static size_t ZSTD_initCStream_wrapper(void *cs, int compressionLevel) { return ZSTD_initCStream((ZSTD_CStream*)cs, compressionLevel); } -static size_t ZSTD_CCtx_refCDict_wrapper(uintptr_t cc, uintptr_t dict) { +static size_t ZSTD_CCtx_refCDict_wrapper(void *cc, void *dict) { return ZSTD_CCtx_refCDict((ZSTD_CCtx*)cc, (ZSTD_CDict*)dict); } -static size_t ZSTD_freeCStream_wrapper(uintptr_t cs) { +static size_t ZSTD_freeCStream_wrapper(void *cs) { return ZSTD_freeCStream((ZSTD_CStream*)cs); } -static size_t ZSTD_compressStream_wrapper(uintptr_t cs, uintptr_t output, uintptr_t input) { - return ZSTD_compressStream((ZSTD_CStream*)cs, (ZSTD_outBuffer*)output, (ZSTD_inBuffer*)input); +static size_t ZSTD_compressStream_wrapper(void *cs, void* dst, const void* src, ZSTD_EXT_BufferSizes* sizes, ZSTD_EndDirective endOp) { + return ZSTD_compressStream2_simpleArgs((ZSTD_CStream*)cs, dst, sizes->dstSize, &sizes->dstPos, src, sizes->srcSize, &sizes->srcPos, endOp); } -static size_t ZSTD_flushStream_wrapper(uintptr_t cs, uintptr_t output) { - return ZSTD_flushStream((ZSTD_CStream*)cs, (ZSTD_outBuffer*)output); +static size_t ZSTD_flushStream_wrapper(void *cs, void *dst, ZSTD_EXT_BufferSizes* sizes) { + size_t res; + ZSTD_outBuffer outBuf; + + outBuf.dst = dst; + outBuf.size = sizes->dstSize; + outBuf.pos = sizes->dstPos; + + res = ZSTD_flushStream((ZSTD_CStream*)cs, &outBuf); + sizes->dstPos = outBuf.pos; + return res; } -static size_t ZSTD_endStream_wrapper(uintptr_t cs, uintptr_t output) { - return ZSTD_endStream((ZSTD_CStream*)cs, 
(ZSTD_outBuffer*)output); +static size_t ZSTD_endStream_wrapper(void *cs, void *dst, ZSTD_EXT_BufferSizes* sizes) { + size_t res; + ZSTD_outBuffer outBuf; + + outBuf.dst = dst; + outBuf.size = sizes->dstSize; + outBuf.pos = sizes->dstPos; + + res = ZSTD_endStream((ZSTD_CStream*)cs, &outBuf); + sizes->dstPos = outBuf.pos; + return res; } */ import "C" import ( + "bytes" "fmt" "io" + "reflect" "runtime" "unsafe" ) @@ -57,8 +80,6 @@ var ( cstreamOutBufSize = C.ZSTD_CStreamOutSize() ) -type cMemPtr *[1 << 30]byte - // Writer implements zstd writer. type Writer struct { w io.Writer @@ -67,11 +88,12 @@ type Writer struct { cs *C.ZSTD_CStream cd *CDict - inBuf *C.ZSTD_inBuffer - outBuf *C.ZSTD_outBuffer + inBufWrapper *bytes.Buffer + outBufWrapper *bytes.Buffer - inBufGo cMemPtr - outBufGo cMemPtr + inBuf []byte + outBuf []byte + sizes C.ZSTD_EXT_BufferSizes } // NewWriter returns new zstd writer writing compressed data to w. @@ -160,15 +182,8 @@ func NewWriterParams(w io.Writer, params *WriterParams) *Writer { cs := C.ZSTD_createCStream() initCStream(cs, *params) - inBuf := (*C.ZSTD_inBuffer)(C.malloc(C.sizeof_ZSTD_inBuffer)) - inBuf.src = C.malloc(cstreamInBufSize) - inBuf.size = 0 - inBuf.pos = 0 - - outBuf := (*C.ZSTD_outBuffer)(C.malloc(C.sizeof_ZSTD_outBuffer)) - outBuf.dst = C.malloc(cstreamOutBufSize) - outBuf.size = cstreamOutBufSize - outBuf.pos = 0 + inBufWrapper := compInBufPool.Get().(*bytes.Buffer) + outBufWrapper := compOutBufPool.Get().(*bytes.Buffer) zw := &Writer{ w: w, @@ -176,13 +191,12 @@ func NewWriterParams(w io.Writer, params *WriterParams) *Writer { wlog: params.WindowLog, cs: cs, cd: params.Dict, - inBuf: inBuf, - outBuf: outBuf, + inBufWrapper: inBufWrapper, + outBufWrapper: outBufWrapper, + inBuf: inBufWrapper.Bytes(), + outBuf: outBufWrapper.Bytes(), } - zw.inBufGo = cMemPtr(zw.inBuf.src) - zw.outBufGo = cMemPtr(zw.outBuf.dst) - runtime.SetFinalizer(zw, freeCStream) return zw } @@ -201,10 +215,9 @@ func (zw *Writer) Reset(w io.Writer, cd *CDict, compressionLevel int) { // ResetWriterParams resets zw to write to w using the given set of parameters. 
func (zw *Writer) ResetWriterParams(w io.Writer, params *WriterParams) { - zw.inBuf.size = 0 - zw.inBuf.pos = 0 - zw.outBuf.size = cstreamOutBufSize - zw.outBuf.pos = 0 + zw.inBuf = zw.inBuf[:0] + zw.outBuf = zw.outBuf[:0] + zw.sizes = C.ZSTD_EXT_BufferSizes{} zw.cd = params.Dict initCStream(zw.cs, *params) @@ -215,18 +228,18 @@ func (zw *Writer) ResetWriterParams(w io.Writer, params *WriterParams) { func initCStream(cs *C.ZSTD_CStream, params WriterParams) { if params.Dict != nil { result := C.ZSTD_CCtx_refCDict_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(cs))), - C.uintptr_t(uintptr(unsafe.Pointer(params.Dict.p)))) + unsafe.Pointer(cs), + unsafe.Pointer(params.Dict.p)) ensureNoError("ZSTD_CCtx_refCDict", result) } else { result := C.ZSTD_initCStream_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(cs))), + unsafe.Pointer(cs), C.int(params.CompressionLevel)) ensureNoError("ZSTD_initCStream", result) } result := C.ZSTD_CCtx_setParameter_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(cs))), + unsafe.Pointer(cs), C.ZSTD_cParameter(C.ZSTD_c_windowLog), C.int(params.WindowLog)) ensureNoError("ZSTD_CCtx_setParameter", result) @@ -244,21 +257,24 @@ func (zw *Writer) Release() { return } - result := C.ZSTD_freeCStream_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(zw.cs)))) + result := C.ZSTD_freeCStream_wrapper(unsafe.Pointer(zw.cs)) ensureNoError("ZSTD_freeCStream", result) zw.cs = nil - C.free(unsafe.Pointer(zw.inBuf.src)) - C.free(unsafe.Pointer(zw.inBuf)) - zw.inBuf = nil - - C.free(unsafe.Pointer(zw.outBuf.dst)) - C.free(unsafe.Pointer(zw.outBuf)) - zw.outBuf = nil - zw.w = nil zw.cd = nil + + if zw.inBufWrapper != nil { + zw.inBuf = nil + compInBufPool.Put(zw.inBufWrapper) + zw.inBufWrapper = nil + } + + if zw.outBufWrapper != nil { + zw.outBuf = nil + compOutBufPool.Put(zw.outBufWrapper) + zw.outBufWrapper = nil + } } // ReadFrom reads all the data from r and writes it to zw. @@ -272,13 +288,15 @@ func (zw *Writer) Release() { func (zw *Writer) ReadFrom(r io.Reader) (int64, error) { nn := int64(0) for { + inBuf := zw.inBuf[len(zw.inBuf):cap(zw.inBuf)] // Fill the inBuf. - for zw.inBuf.size < cstreamInBufSize { - n, err := r.Read(zw.inBufGo[zw.inBuf.size:cstreamInBufSize]) + for len(inBuf) > 0 { + n, err := r.Read(inBuf) // Sometimes n > 0 even when Read() returns an error. // This is true especially if the error is io.EOF. - zw.inBuf.size += C.size_t(n) + inBuf = inBuf[n:] + zw.inBuf = zw.inBuf[:len(zw.inBuf)+n] nn += int64(n) if err != nil { @@ -309,8 +327,8 @@ func (zw *Writer) Write(p []byte) (int, error) { } for { - n := copy(zw.inBufGo[zw.inBuf.size:cstreamInBufSize], p) - zw.inBuf.size += C.size_t(n) + n := copy(zw.inBuf[len(zw.inBuf):cap(zw.inBuf)], p) + zw.inBuf = zw.inBuf[:len(zw.inBuf)+n] p = p[n:] if len(p) == 0 { // Fast path - just copy the data to input buffer. 
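
Because the writer's buffers now come from compInBufPool/compOutBufPool rather than C.malloc, reusing a single Writer across many streams keeps those pooled buffers hot. A usage sketch with the existing public API (Reset, Close and Release are unchanged by this patch; only their buffer handling differs):

package main

import (
	"bytes"
	"fmt"

	"github.com/valyala/gozstd"
)

func main() {
	var out bytes.Buffer

	zw := gozstd.NewWriter(&out)
	defer zw.Release() // puts the pooled in/out buffers back for other writers

	for i := 0; i < 3; i++ {
		out.Reset()
		// Reuse the same Writer (and its pooled buffers) for a fresh stream.
		zw.Reset(&out, nil, gozstd.DefaultCompressionLevel)

		if _, err := zw.Write([]byte(fmt.Sprintf("frame #%d", i))); err != nil {
			panic(err)
		}
		if err := zw.Close(); err != nil {
			panic(err)
		}
		fmt.Printf("frame %d: %d compressed bytes\n", i, out.Len())
	}
}
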
@@ -323,19 +341,30 @@ func (zw *Writer) Write(p []byte) (int, error) { } func (zw *Writer) flushInBuf() error { - prevInBufPos := zw.inBuf.pos + zw.sizes.dstSize = C.size_t(cap(zw.outBuf)) + zw.sizes.dstPos = C.size_t(len(zw.outBuf)) + zw.sizes.srcSize = C.size_t(len(zw.inBuf)) + zw.sizes.srcPos = 0 + + outHdr := (*reflect.SliceHeader)(unsafe.Pointer(&zw.outBuf)) + inHdr := (*reflect.SliceHeader)(unsafe.Pointer(&zw.inBuf)) + result := C.ZSTD_compressStream_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(zw.cs))), - C.uintptr_t(uintptr(unsafe.Pointer(zw.outBuf))), - C.uintptr_t(uintptr(unsafe.Pointer(zw.inBuf)))) - ensureNoError("ZSTD_compressStream", result) + unsafe.Pointer(zw.cs), unsafe.Pointer(outHdr.Data), unsafe.Pointer(inHdr.Data), + &zw.sizes, C.ZSTD_e_continue) + ensureNoError("ZSTD_compressStream_wrapper", result) + + zw.outBuf = zw.outBuf[:zw.sizes.dstPos] // Move the remaining data to the start of inBuf. - copy(zw.inBufGo[:cstreamInBufSize], zw.inBufGo[zw.inBuf.pos:zw.inBuf.size]) - zw.inBuf.size -= zw.inBuf.pos - zw.inBuf.pos = 0 + if int(zw.sizes.srcPos) < len(zw.inBuf) { + copy(zw.inBuf[:cap(zw.inBuf)], zw.inBuf[zw.sizes.srcPos:len(zw.inBuf)]) + zw.inBuf = zw.inBuf[:len(zw.inBuf)-int(zw.sizes.srcPos)] + } else { + zw.inBuf = zw.inBuf[:0] + } - if zw.outBuf.size-zw.outBuf.pos > zw.outBuf.pos && prevInBufPos != zw.inBuf.pos { + if cap(zw.outBuf)-int(zw.sizes.dstPos) > int(zw.sizes.dstPos) && zw.sizes.srcPos > 0 { // There is enough space in outBuf and the last compression // succeeded, so don't flush outBuf yet. return nil @@ -347,20 +376,20 @@ func (zw *Writer) flushInBuf() error { } func (zw *Writer) flushOutBuf() error { - if zw.outBuf.pos == 0 { + if len(zw.outBuf) == 0 { // Nothing to flush. return nil } - outBuf := zw.outBufGo[:zw.outBuf.pos] - n, err := zw.w.Write(outBuf) - zw.outBuf.pos = 0 + bufLen := len(zw.outBuf) + n, err := zw.w.Write(zw.outBuf) + zw.outBuf = zw.outBuf[:0] if err != nil { return fmt.Errorf("cannot flush internal buffer to the underlying writer: %s", err) } - if n != len(outBuf) { + if n != bufLen { panic(fmt.Errorf("BUG: the underlying writer violated io.Writer contract and didn't return error after writing incomplete data; written %d bytes; want %d bytes", - n, len(outBuf))) + n, bufLen)) } return nil } @@ -368,7 +397,7 @@ func (zw *Writer) flushOutBuf() error { // Flush flushes the remaining data from zw to the underlying writer. func (zw *Writer) Flush() error { // Flush inBuf. - for zw.inBuf.size > 0 { + for len(zw.inBuf) > 0 { if err := zw.flushInBuf(); err != nil { return err } @@ -376,10 +405,14 @@ func (zw *Writer) Flush() error { // Flush the internal buffer to outBuf. 
for { + outHdr := (*reflect.SliceHeader)(unsafe.Pointer(&zw.outBuf)) + zw.sizes.dstSize = C.size_t(cap(zw.outBuf)) + zw.sizes.dstPos = C.size_t(len(zw.outBuf)) + result := C.ZSTD_flushStream_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(zw.cs))), - C.uintptr_t(uintptr(unsafe.Pointer(zw.outBuf)))) + unsafe.Pointer(zw.cs), unsafe.Pointer(outHdr.Data), &zw.sizes) ensureNoError("ZSTD_flushStream", result) + zw.outBuf = zw.outBuf[:zw.sizes.dstPos] if err := zw.flushOutBuf(); err != nil { return err } @@ -400,10 +433,15 @@ func (zw *Writer) Close() error { } for { + outHdr := (*reflect.SliceHeader)(unsafe.Pointer(&zw.outBuf)) + zw.sizes.dstSize = C.size_t(cap(zw.outBuf)) + zw.sizes.dstPos = C.size_t(len(zw.outBuf)) + result := C.ZSTD_endStream_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(zw.cs))), - C.uintptr_t(uintptr(unsafe.Pointer(zw.outBuf)))) + unsafe.Pointer(zw.cs), + unsafe.Pointer(outHdr.Data), &zw.sizes) ensureNoError("ZSTD_endStream", result) + zw.outBuf = zw.outBuf[:zw.sizes.dstPos] if err := zw.flushOutBuf(); err != nil { return err } From b90fa650e3fa777692f4a685deec05e43215add1 Mon Sep 17 00:00:00 2001 From: Michal Hruby Date: Sun, 25 Dec 2022 23:50:59 +0100 Subject: [PATCH 2/3] protect against malformed frames --- gozstd.go | 12 +++++++----- gozstd_test.go | 8 ++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/gozstd.go b/gozstd.go index 85f07d0..e90e0a2 100644 --- a/gozstd.go +++ b/gozstd.go @@ -45,6 +45,8 @@ import ( // DefaultCompressionLevel is the default compression level. const DefaultCompressionLevel = 3 // Obtained from ZSTD_CLEVEL_DEFAULT. +const maxFrameContentSize = 256 << 20 // 256 MB + // Compress appends compressed src to dst and returns the result. func Compress(dst, src []byte) []byte { return compressDictLevel(dst, src, nil, DefaultCompressionLevel) @@ -257,14 +259,14 @@ func decompress(dctx, dctxDict *dctxWrapper, dst, src []byte, dd *DDict) ([]byte // Slow path - resize dst to fit decompressed data. srcHdr := (*reflect.SliceHeader)(unsafe.Pointer(&src)) - decompressBound := int(C.ZSTD_getFrameContentSize_wrapper(unsafe.Pointer(srcHdr.Data), C.size_t(len(src)))) - switch uint64(decompressBound) { - case uint64(C.ZSTD_CONTENTSIZE_UNKNOWN): + contentSize := C.ZSTD_getFrameContentSize_wrapper(unsafe.Pointer(srcHdr.Data), C.size_t(len(src))) + switch { + case contentSize == C.ZSTD_CONTENTSIZE_UNKNOWN || contentSize > maxFrameContentSize: return streamDecompress(dst, src, dd) - case uint64(C.ZSTD_CONTENTSIZE_ERROR): + case contentSize == C.ZSTD_CONTENTSIZE_ERROR: return dst, fmt.Errorf("cannot decompress invalid src") } - decompressBound++ + decompressBound := int(contentSize) + 1 if n := dstLen + decompressBound - cap(dst); n > 0 { // This should be optimized since go 1.11 - see https://golang.org/doc/go1.11#performance-compiler. 
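
The guard above keeps a malformed frame header from triggering a huge allocation: a reported content size that is unknown or above maxFrameContentSize (256 MiB) routes the input through streamDecompress instead of being trusted for a single pre-allocation. A standalone sketch of just that sizing decision (the constant stand-ins and the planDecompression helper are illustrative, not part of the package, and the ZSTD_CONTENTSIZE_ERROR branch is omitted):

package main

import "fmt"

const (
	contentSizeUnknown  = ^uint64(0) // stands in for ZSTD_CONTENTSIZE_UNKNOWN
	maxFrameContentSize = 256 << 20  // the same 256 MiB cap the patch adds
)

// planDecompression mirrors the switch in decompress(): it returns how much to
// pre-allocate, or ok=false when the caller should fall back to streaming.
func planDecompression(reportedSize uint64) (prealloc int, ok bool) {
	if reportedSize == contentSizeUnknown || reportedSize > maxFrameContentSize {
		return 0, false // don't trust the header; stream instead
	}
	return int(reportedSize) + 1, true // +1 matches decompressBound in the patch
}

func main() {
	for _, size := range []uint64{4096, 1 << 40, contentSizeUnknown} {
		if n, ok := planDecompression(size); ok {
			fmt.Printf("reported %d: preallocate %d bytes\n", size, n)
		} else {
			fmt.Printf("reported %d: use streaming decompression\n", size)
		}
	}
}

The streaming fallback still decodes oversized or unknown-size frames correctly; it simply grows the output incrementally instead of allocating whatever the attacker-controlled header claims, which is what TestDecompressTooLarge below exercises.
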
diff --git a/gozstd_test.go b/gozstd_test.go index 3a465b8..e870505 100644 --- a/gozstd_test.go +++ b/gozstd_test.go @@ -54,6 +54,14 @@ func TestDecompressSmallBlockWithoutSingleSegmentFlag(t *testing.T) { }) } +func TestDecompressTooLarge(t *testing.T) { + src := []byte{40, 181, 47, 253, 228, 122, 118, 105, 67, 140, 234, 85, 20, 159, 67} + _, err := Decompress(nil, src) + if err == nil { + t.Fatalf("expecting error when decompressing malformed frame") + } +} + func mustUnhex(dataHex string) []byte { data, err := hex.DecodeString(dataHex) if err != nil { From 4edc66abd4340f5df453a81be52e6f5b5e93080b Mon Sep 17 00:00:00 2001 From: Michal Hruby Date: Mon, 26 Dec 2022 17:16:53 +0100 Subject: [PATCH 3/3] add noescape annotations --- gozstd.go | 24 +++++++++++++++++++----- gozstd_test.go | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 5 deletions(-) diff --git a/gozstd.go b/gozstd.go index e90e0a2..c01e2b5 100644 --- a/gozstd.go +++ b/gozstd.go @@ -147,9 +147,22 @@ func compress(cctx, cctxDict *cctxWrapper, dst, src []byte, cd *CDict, compressi return dst } +// noescape hides a pointer from escape analysis. It is the identity function +// but escape analysis doesn't think the output depends on the input. +// noescape is inlined and currently compiles down to zero instructions. +// This is copied from go's strings.Builder. Allows us to use stack-allocated +// slices. +//go:nosplit +//go:nocheckptr +func noescape(p unsafe.Pointer) unsafe.Pointer { + x := uintptr(p) + return unsafe.Pointer(x ^ 0) +} + func compressInternal(cctx, cctxDict *cctxWrapper, dst, src []byte, cd *CDict, compressionLevel int, mustSucceed bool) C.size_t { - dstHdr := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) - srcHdr := (*reflect.SliceHeader)(unsafe.Pointer(&src)) + // using noescape will allow this to work with stack-allocated slices + dstHdr := (*reflect.SliceHeader)(noescape(unsafe.Pointer(&dst))) + srcHdr := (*reflect.SliceHeader)(noescape(unsafe.Pointer(&src))) if cd != nil { result := C.ZSTD_compress_usingCDict_wrapper( @@ -180,6 +193,7 @@ func compressInternal(cctx, cctxDict *cctxWrapper, dst, src []byte, cd *CDict, c if mustSucceed { ensureNoError("ZSTD_compressCCtx", result) } + return result } @@ -258,7 +272,7 @@ func decompress(dctx, dctxDict *dctxWrapper, dst, src []byte, dd *DDict) ([]byte } // Slow path - resize dst to fit decompressed data. 
- srcHdr := (*reflect.SliceHeader)(unsafe.Pointer(&src)) + srcHdr := (*reflect.SliceHeader)(noescape(unsafe.Pointer(&src))) contentSize := C.ZSTD_getFrameContentSize_wrapper(unsafe.Pointer(srcHdr.Data), C.size_t(len(src))) switch { case contentSize == C.ZSTD_CONTENTSIZE_UNKNOWN || contentSize > maxFrameContentSize: @@ -290,8 +304,8 @@ func decompress(dctx, dctxDict *dctxWrapper, dst, src []byte, dd *DDict) ([]byte func decompressInternal(dctx, dctxDict *dctxWrapper, dst, src []byte, dd *DDict) C.size_t { var ( - dstHdr = (*reflect.SliceHeader)(unsafe.Pointer(&dst)) - srcHdr = (*reflect.SliceHeader)(unsafe.Pointer(&src)) + dstHdr = (*reflect.SliceHeader)(noescape(unsafe.Pointer(&dst))) + srcHdr = (*reflect.SliceHeader)(noescape(unsafe.Pointer(&src))) n C.size_t ) if dd != nil { diff --git a/gozstd_test.go b/gozstd_test.go index e870505..7729b66 100644 --- a/gozstd_test.go +++ b/gozstd_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/hex" "fmt" + "io" "math/rand" "runtime" "strings" @@ -54,6 +55,14 @@ func TestDecompressSmallBlockWithoutSingleSegmentFlag(t *testing.T) { }) } +func TestCompressEmpty(t *testing.T) { + var dst [64]byte + res := Compress(dst[:0], nil) + if len(res) > 0 { + t.Fatalf("unexpected non-empty compressed frame: %X", res) + } +} + func TestDecompressTooLarge(t *testing.T) { src := []byte{40, 181, 47, 253, 228, 122, 118, 105, 67, 140, 234, 85, 20, 159, 67} _, err := Decompress(nil, src) @@ -70,6 +79,48 @@ func mustUnhex(dataHex string) []byte { return data } +func TestCompressWithStackMove(t *testing.T) { + var srcBuf [96]byte + + n, err := io.ReadFull(rand.New(rand.NewSource(time.Now().Unix())), srcBuf[:]) + if err != nil { + t.Fatalf("cannot fill srcBuf with random data: %s", err) + } + + // We're running this twice, because the first run will allocate + // objects in sync.Pool, calls to which extend the stack, and the second + // run can skip those allocations and extend the stack right before + // the CGO call. + // Note that this test might require some go:nosplit annotations + // to force the stack move to happen exactly before the CGO call. + for i := 0; i < 2; i++ { + ch := make(chan struct{}) + go func() { + defer close(ch) + + var dstBuf [1416]byte + + res := Compress(dstBuf[:0], srcBuf[:n]) + + // make a copy of the result, so the original can remain on the stack + compressedCpy := make([]byte, len(res)) + copy(compressedCpy, res) + + orig, err := Decompress(nil, compressedCpy) + if err != nil { + panic(fmt.Errorf("cannot decompress: %s", err)) + } + if !bytes.Equal(orig, srcBuf[:n]) { + panic(fmt.Errorf("unexpected decompressed data; got %q; want %q", orig, srcBuf[:n])) + } + }() + // wait for the goroutine to finish + <-ch + } + + runtime.GC() +} + func TestCompressDecompressDistinctConcurrentDicts(t *testing.T) { // Build multiple distinct dicts. var cdicts []*CDict