Skip to content

Commit

Permalink
decompressionobj: implement reading across frames
Browse files Browse the repository at this point in the history
All APIs need to gain the ability to transparently read across
frames. This commit teaches the stdlib decompressionobj interface to
transparently read across frames.

Behavior is controlled via a boolean named argument, ``read_across_frames``.
It defaults to False to preserve existing behavior.

Related to #196.

I haven't tested this very thoroughly, but the added fuzzing test
passing is usually a good indicator that this works as intended.
  • Loading branch information
indygreg committed May 15, 2023
1 parent f781a5f commit 685dc6f
Show file tree
Hide file tree
Showing 10 changed files with 157 additions and 22 deletions.
9 changes: 8 additions & 1 deletion c-ext/decompressobj.c
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ static PyObject *DecompressionObj_decompress(ZstdDecompressionObj *self,
}
}

if (0 == zresult) {
if (0 == zresult && !self->readAcrossFrames) {
self->finished = 1;

/* We should only get here at most once. */
Expand All @@ -98,6 +98,13 @@ static PyObject *DecompressionObj_decompress(ZstdDecompressionObj *self,

break;
}
else if (0 == zresult && self->readAcrossFrames) {
if (input.pos == input.size) {
break;
} else {
output.pos = 0;
}
}
else if (input.pos == input.size && output.pos == 0) {
break;
}
Expand Down
9 changes: 6 additions & 3 deletions c-ext/decompressor.c
Original file line number Diff line number Diff line change
Expand Up @@ -397,13 +397,14 @@ PyObject *Decompressor_decompress(ZstdDecompressor *self, PyObject *args,
static ZstdDecompressionObj *Decompressor_decompressobj(ZstdDecompressor *self,
PyObject *args,
PyObject *kwargs) {
static char *kwlist[] = {"write_size", NULL};
static char *kwlist[] = {"write_size", "read_across_frames", NULL};

ZstdDecompressionObj *result = NULL;
size_t outSize = ZSTD_DStreamOutSize();
PyObject *readAcrossFrames = NULL;

if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|k:decompressobj", kwlist,
&outSize)) {
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|kO:decompressobj", kwlist,
&outSize, &readAcrossFrames)) {
return NULL;
}

Expand All @@ -426,6 +427,8 @@ static ZstdDecompressionObj *Decompressor_decompressobj(ZstdDecompressor *self,
result->decompressor = self;
Py_INCREF(result->decompressor);
result->outSize = outSize;
result->readAcrossFrames =
readAcrossFrames ? PyObject_IsTrue(readAcrossFrames) : 0;

return result;
}
Expand Down
1 change: 1 addition & 0 deletions c-ext/python-zstandard.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ typedef struct {

ZstdDecompressor *decompressor;
size_t outSize;
int readAcrossFrames;
int finished;
PyObject *unused_data;
} ZstdDecompressionObj;
Expand Down
16 changes: 15 additions & 1 deletion docs/news.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,21 @@ Other Actions Not Blocking Release
0.22.0 (not yet released)
=========================

None yet.
Backwards Compatibility Notes
-----------------------------

* ``ZstdDecompressor.decompressobj()`` will change ``read_across_frames`` to
default to ``True`` in a future release. If you depend on the current
functionality of stopping at frame boundaries, start explicitly passing
``read_across_frames=False`` to preserve the current behavior.

Changes
-------

* ``ZstdDecompressor.decompressobj()`` now accepts a ``read_across_frames``
boolean named argument to control whether to transparently read across
multiple zstd frames. It defaults to ``False`` to preserve existing
behavior.

0.21.0 (released 2023-04-16)
============================
Expand Down
16 changes: 14 additions & 2 deletions rust-ext/src/decompressionobj.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,21 @@ use {
// State for the incremental decompression object created by
// `ZstdDecompressor.decompressobj()`.
pub struct ZstdDecompressionObj {
// Decompression context; the creating decompressor hands over a clone of
// its own `Arc`, so the context is shared rather than copied.
dctx: Arc<DCtx<'static>>,
// Caller-supplied `write_size`.
// NOTE(review): presumably the size of the output buffer used for each
// decompress step — confirm against `decompress()`.
write_size: usize,
// When false, reaching the end of a frame marks the object as finished.
// When true, decoding continues into the next frame while input remains.
read_across_frames: bool,
// Starts out false; set once a frame has been fully decoded in
// single-frame mode (`read_across_frames == false`).
finished: bool,
// Starts out empty.
// NOTE(review): presumably receives the input bytes that follow the first
// frame in single-frame mode — confirm against `decompress()`.
unused_data: Vec<u8>,
}

impl ZstdDecompressionObj {
pub fn new(dctx: Arc<DCtx<'static>>, write_size: usize) -> PyResult<Self> {
pub fn new(
dctx: Arc<DCtx<'static>>,
write_size: usize,
read_across_frames: bool,
) -> PyResult<Self> {
Ok(ZstdDecompressionObj {
dctx,
write_size,
read_across_frames,
finished: false,
unused_data: vec![],
})
Expand Down Expand Up @@ -68,7 +74,7 @@ impl ZstdDecompressionObj {
chunks.append(chunk)?;
}

if zresult == 0 {
if zresult == 0 && !self.read_across_frames {
self.finished = true;
// TODO clear out decompressor?

Expand All @@ -78,6 +84,12 @@ impl ZstdDecompressionObj {
}

break;
} else if zresult == 0 && self.read_across_frames {
if in_buffer.pos == in_buffer.size {
break;
} else {
dest_buffer.clear();
}
} else if in_buffer.pos == in_buffer.size && dest_buffer.len() < dest_buffer.capacity()
{
break;
Expand Down
5 changes: 3 additions & 2 deletions rust-ext/src/decompressor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,12 @@ impl ZstdDecompressor {
Ok(PyBytes::new(py, &last_buffer))
}

#[pyo3(signature = (write_size=None))]
#[pyo3(signature = (write_size=None, read_across_frames=false))]
fn decompressobj(
&self,
py: Python,
write_size: Option<usize>,
read_across_frames: bool,
) -> PyResult<ZstdDecompressionObj> {
if let Some(write_size) = write_size {
if write_size < 1 {
Expand All @@ -387,7 +388,7 @@ impl ZstdDecompressor {

self.setup_dctx(py, true)?;

ZstdDecompressionObj::new(self.dctx.clone(), write_size)
ZstdDecompressionObj::new(self.dctx.clone(), write_size, read_across_frames)
}

fn memory_size(&self) -> usize {
Expand Down
26 changes: 25 additions & 1 deletion tests/test_decompressor_decompressobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_write_size(self):
dobj = dctx.decompressobj(write_size=i + 1)
self.assertEqual(dobj.decompress(data), source)

def test_multiple_frames(self):
def test_multiple_frames_default(self):
cctx = zstd.ZstdCompressor()
foo = cctx.compress(b"foo")
bar = cctx.compress(b"bar")
Expand All @@ -121,3 +121,27 @@ def test_multiple_frames(self):
self.assertEqual(dobj.decompress(foo + bar), b"foo")
self.assertEqual(dobj.unused_data, bar)
self.assertEqual(dobj.unconsumed_tail, b"")

def test_read_across_frames_false(self):
    # Build two independent frames and feed them back to back.
    compressor = zstd.ZstdCompressor()
    first_frame = compressor.compress(b"foo")
    second_frame = compressor.compress(b"bar")
    payload = first_frame + second_frame

    decompressor = zstd.ZstdDecompressor()
    dobj = decompressor.decompressobj(read_across_frames=False)
    decoded = dobj.decompress(payload)

    # Decoding stops at the first frame boundary; the second frame is
    # handed back verbatim through ``unused_data``.
    self.assertEqual(decoded, b"foo")
    self.assertEqual(dobj.unused_data, second_frame)
    self.assertEqual(dobj.unconsumed_tail, b"")

def test_read_across_frames_true(self):
    # Two separately compressed frames, concatenated into one input.
    compressor = zstd.ZstdCompressor()
    first_frame = compressor.compress(b"foo")
    second_frame = compressor.compress(b"bar")
    payload = first_frame + second_frame

    decompressor = zstd.ZstdDecompressor()
    dobj = decompressor.decompressobj(read_across_frames=True)
    decoded = dobj.decompress(payload)

    # Both frames are decoded in a single call and nothing is left over.
    self.assertEqual(decoded, b"foobar")
    self.assertEqual(dobj.unused_data, b"")
    self.assertEqual(dobj.unconsumed_tail, b"")
54 changes: 52 additions & 2 deletions tests/test_decompressor_fuzzing.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,9 @@ def test_random_output_sizes(
),
read_sizes=strategies.data(),
)
def test_multiple_frames(self, chunks, level, write_size, read_sizes):
def test_read_across_frames_false(
self, chunks, level, write_size, read_sizes
):
cctx = zstd.ZstdCompressor(level=level)

source = io.BytesIO()
Expand All @@ -545,7 +547,9 @@ def test_multiple_frames(self, chunks, level, write_size, read_sizes):
compressed.seek(0)

dctx = zstd.ZstdDecompressor()
dobj = dctx.decompressobj(write_size=write_size)
dobj = dctx.decompressobj(
write_size=write_size, read_across_frames=False
)

decompressed = io.BytesIO()

Expand All @@ -565,6 +569,52 @@ def test_multiple_frames(self, chunks, level, write_size, read_sizes):

self.assertEqual(decompressed.getvalue(), source_chunks[0])

@hypothesis.given(
    chunks=strategies.lists(
        strategies.sampled_from(random_input_data()),
        min_size=2,
        max_size=10,
    ),
    level=strategies.integers(min_value=1, max_value=5),
    write_size=strategies.integers(
        min_value=1,
        max_value=4 * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
    ),
    read_sizes=strategies.data(),
)
def test_read_across_frames_true(
    self, chunks, level, write_size, read_sizes
):
    # Every chunk is compressed into its own frame, so the compressed
    # stream contains between 2 and 10 concatenated frames.
    cctx = zstd.ZstdCompressor(level=level)

    # ``source`` accumulates the expected plaintext; ``compressed`` the
    # concatenated frames that are fed to the decompression object.
    source = io.BytesIO()
    compressed = io.BytesIO()

    for chunk in chunks:
        source.write(chunk)
        compressed.write(cctx.compress(chunk))

    compressed.seek(0)

    dctx = zstd.ZstdDecompressor()
    dobj = dctx.decompressobj(
        write_size=write_size, read_across_frames=True
    )

    decompressed = io.BytesIO()

    # Feed the input in randomly sized slices so frame boundaries land at
    # arbitrary offsets within (and between) decompress() calls.
    while True:
        read_size = read_sizes.draw(strategies.integers(1, 4096))
        chunk = compressed.read(read_size)
        if not chunk:
            break

        decompressed.write(dobj.decompress(chunk))

    # With read_across_frames=True the output must be the concatenation
    # of every original chunk, not just the first frame.
    self.assertEqual(decompressed.getvalue(), source.getvalue())


@unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")
class TestDecompressor_read_to_iter_fuzzing(unittest.TestCase):
Expand Down
4 changes: 3 additions & 1 deletion zstandard/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,9 @@ class ZstdDecompressor(object):
*,
closefd=False,
) -> ZstdDecompressionReader: ...
def decompressobj(self, write_size: int = ...) -> ZstdDecompressionObj: ...
def decompressobj(
self, write_size: int = ..., read_across_frames: bool = False
) -> ZstdDecompressionObj: ...
def read_to_iter(
self,
reader: Union[IO[bytes], ByteString],
Expand Down
39 changes: 30 additions & 9 deletions zstandard/backend_cffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2918,8 +2918,9 @@ class ZstdDecompressionObj(object):
subsequent calls needs to be concatenated to reassemble the full
decompressed byte sequence.
Each instance is single use: once an input frame is decoded,
``decompress()`` can no longer be called.
If ``read_across_frames=False``, each instance is single use: once an
input frame is decoded, ``decompress()`` can no longer be called. If
``read_across_frames=True``, instances can decode multiple frames.
>>> dctx = zstandard.ZstdDecompressor()
>>> dobj = dctx.decompressobj()
Expand All @@ -2941,10 +2942,11 @@ class ZstdDecompressionObj(object):
efficient as other APIs.
"""

def __init__(self, decompressor, write_size):
def __init__(self, decompressor, write_size, read_across_frames):
self._decompressor = decompressor
self._write_size = write_size
self._finished = False
self._read_across_frames = read_across_frames
self._unused_input = b""

def decompress(self, data):
Expand Down Expand Up @@ -2991,13 +2993,22 @@ def decompress(self, data):
chunks.append(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])

# 0 is only seen when a frame is fully decoded *and* fully flushed.
# But there may be extra input data: make that available to
# `unused_input`.
if zresult == 0:
# Behavior depends on whether we're in single or multiple frame
# mode.
if zresult == 0 and not self._read_across_frames:
# Mark the instance as done and make any unconsumed input available
# for retrieval.
self._finished = True
self._decompressor = None
self._unused_input = data[in_buffer.pos : in_buffer.size]
break
elif zresult == 0 and self._read_across_frames:
# We're at the end of a fully flushed frame and we can read more.
# Try to read more if there's any more input.
if in_buffer.pos == in_buffer.size:
break
else:
out_buffer.pos = 0

# We're not at the end of the frame *or* we're not fully flushed.

Expand Down Expand Up @@ -3899,21 +3910,31 @@ def stream_reader(
self, source, read_size, read_across_frames, closefd=closefd
)

def decompressobj(self, write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE):
def decompressobj(
self,
write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
read_across_frames=False,
):
"""Obtain a standard library compatible incremental decompressor.
See :py:class:`ZstdDecompressionObj` for more documentation
and usage examples.
:param write_size:
:param write_size: size of internal output buffer to collect decompressed
chunks in.
:param read_across_frames: whether to read across multiple zstd frames.
If False, reading stops after 1 frame and subsequent decompress
attempts will raise an exception.
:return:
:py:class:`zstandard.ZstdDecompressionObj`
"""
if write_size < 1:
raise ValueError("write_size must be positive")

self._ensure_dctx()
return ZstdDecompressionObj(self, write_size=write_size)
return ZstdDecompressionObj(
self, write_size=write_size, read_across_frames=read_across_frames
)

def read_to_iter(
self,
Expand Down

0 comments on commit 685dc6f

Please sign in to comment.