From 817675a979e8777cbbb914fd3f40ca6df381cfd4 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Thu, 23 May 2024 10:42:17 -0400 Subject: [PATCH 1/4] Remake format guess functions --- src/snappy/snappy.py | 87 ++++++++++++++++++++++++------------ src/snappy/snappy_formats.py | 79 +++++++++++++++++++++++--------- 2 files changed, 115 insertions(+), 51 deletions(-) diff --git a/src/snappy/snappy.py b/src/snappy/snappy.py index aa1a22e..8e1e93b 100644 --- a/src/snappy/snappy.py +++ b/src/snappy/snappy.py @@ -149,23 +149,16 @@ def __init__(self): self.remains = None @staticmethod - def check_format(data): + def check_format(fin): """Checks that the given data starts with snappy framing format stream identifier. Raises UncompressError if it doesn't start with the identifier. :return: None """ - if len(data) < 6: - raise UncompressError("Too short data length") - chunk_type = struct.unpack("> 8) - chunk_type &= 0xff - if (chunk_type != _IDENTIFIER_CHUNK or - size != len(_STREAM_IDENTIFIER)): - raise UncompressError("stream missing snappy identifier") - chunk = data[4:4 + size] - if chunk != _STREAM_IDENTIFIER: - raise UncompressError("stream has invalid snappy identifier") + try: + return fin.read(len(_STREAM_HEADER_BLOCK)) == _STREAM_HEADER_BLOCK + except: + return False def decompress(self, data: bytes): """Decompress 'data', returning a string containing the uncompressed @@ -233,14 +226,23 @@ def __init__(self): self.remains = b"" @staticmethod - def check_format(data): + def check_format(fin): """Checks that there are enough bytes for a hadoop header We cannot actually determine if the data is really hadoop-snappy """ - if len(data) < 8: - raise UncompressError("Too short data length") - chunk_length = int.from_bytes(data[4:8], "big") + try: + from snappy.snappy_formats import check_unframed_format + size = fin.seek(0, 2) + fin.seek(0) + assert size >= 8 + + chunk_length = int.from_bytes(fin.read(4), "big") + assert chunk_length < size + fin.read(4) + return check_unframed_format(fin) + except: + return False def decompress(self, data: bytes): """Decompress 'data', returning a string containing the uncompressed @@ -319,16 +321,43 @@ def stream_decompress(src, decompressor.flush() # makes sure the stream ended well -def check_format(fin=None, chunk=None, - blocksize=_STREAM_TO_STREAM_BLOCK_SIZE, - decompressor_cls=StreamDecompressor): - ok = True - if chunk is None: - chunk = fin.read(blocksize) - if not chunk: - raise UncompressError("Empty input stream") - try: - decompressor_cls.check_format(chunk) - except UncompressError as err: - ok = False - return ok, chunk +def hadoop_stream_decompress( + src, + dst, + blocksize=_STREAM_TO_STREAM_BLOCK_SIZE, +): + c = HadoopStreamDecompressor() + while True: + data = src.read(blocksize) + if not data: + break + buf = c.decompress(data) + if buf: + dst.write() + dst.flush() + + +def hadoop_stream_compress( + src, + dst, + blocksize=_STREAM_TO_STREAM_BLOCK_SIZE, +): + c = HadoopStreamCompressor() + while True: + data = src.read(blocksize) + if not data: + break + buf = c.compress(data) + if buf: + dst.write() + dst.flush() + + +def raw_stream_decompress(src, dst): + data = src.read() + dst.write(decompress(data)) + + +def raw_stream_compress(src, dst): + data = src.read() + dst.write(compress(data)) diff --git a/src/snappy/snappy_formats.py b/src/snappy/snappy_formats.py index 51a54dd..1439418 100644 --- a/src/snappy/snappy_formats.py +++ b/src/snappy/snappy_formats.py @@ -8,40 +8,73 @@ from __future__ import absolute_import from .snappy import ( - stream_compress, stream_decompress, check_format, UncompressError) - + HadoopStreamDecompressor, StreamDecompressor, + hadoop_stream_compress, hadoop_stream_decompress, raw_stream_compress, + raw_stream_decompress, stream_compress, stream_decompress, + UncompressError +) -FRAMING_FORMAT = 'framing' # Means format auto detection. # For compression will be used framing format. # In case of decompression will try to detect a format from the input stream # header. -FORMAT_AUTO = 'auto' - -DEFAULT_FORMAT = FORMAT_AUTO +DEFAULT_FORMAT = "auto" -ALL_SUPPORTED_FORMATS = [FRAMING_FORMAT, FORMAT_AUTO] +ALL_SUPPORTED_FORMATS = ["framing", "auto"] _COMPRESS_METHODS = { - FRAMING_FORMAT: stream_compress, + "framing": stream_compress, + "hadoop": hadoop_stream_compress, + "raw": raw_stream_compress } _DECOMPRESS_METHODS = { - FRAMING_FORMAT: stream_decompress, + "framing": stream_decompress, + "hadoop": hadoop_stream_decompress, + "raw": raw_stream_decompress } # We will use framing format as the default to compression. # And for decompression, if it's not defined explicitly, we will try to # guess the format from the file header. -_DEFAULT_COMPRESS_FORMAT = FRAMING_FORMAT +_DEFAULT_COMPRESS_FORMAT = "framing" + + +def uvarint(fin): + result = 0 + shift = 0 + while True: + byte = fin.read(1)[0] + result |= (byte & 0x7F) << shift + if (byte & 0x80) == 0: + break + shift += 7 + return result + + +def check_unframed_format(fin): + fin.seek(0) + try: + size = uvarint(fin) + assert size < 2**32 - 1 + next_byte = fin.read(1)[0] + end = fin.seek(0, 2) + assert size < end + assert next_byte & 0b11 == 0 # must start with literal block + return True + except: + return False + # The tuple contains an ordered sequence of a format checking function and # a format-specific decompression function. # Framing format has it's header, that may be recognized. -_DECOMPRESS_FORMAT_FUNCS = ( - (check_format, stream_decompress), -) +_DECOMPRESS_FORMAT_FUNCS = { + "framed": stream_decompress, + "hadoop": hadoop_stream_decompress, + "raw": raw_stream_decompress +} def guess_format_by_header(fin): @@ -50,23 +83,25 @@ def guess_format_by_header(fin): :return: tuple of decompression method and a chunk that was taken from the input for format detection. """ - chunk = None - for check_method, decompress_func in _DECOMPRESS_FORMAT_FUNCS: - ok, chunk = check_method(fin=fin, chunk=chunk) - if not ok: - continue - return decompress_func, chunk - raise UncompressError("Can't detect archive format") + if StreamDecompressor.check_format(fin): + form = "framed" + elif HadoopStreamDecompressor.check_format(fin): + form = "hadoop" + elif check_unframed_format(fin): + form = "raw" + else: + raise UncompressError("Can't detect format") + return form, _DECOMPRESS_FORMAT_FUNCS[form] def get_decompress_function(specified_format, fin): - if specified_format == FORMAT_AUTO: + if specified_format == "auto": decompress_func, read_chunk = guess_format_by_header(fin) return decompress_func, read_chunk return _DECOMPRESS_METHODS[specified_format], None def get_compress_function(specified_format): - if specified_format == FORMAT_AUTO: + if specified_format == "auto": return _COMPRESS_METHODS[_DEFAULT_COMPRESS_FORMAT] return _COMPRESS_METHODS[specified_format] From d882ccecc32f6315b361c1d7d9f2a3cd51d8120d Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Thu, 23 May 2024 10:47:24 -0400 Subject: [PATCH 2/4] fix --- src/snappy/snappy_formats.py | 6 +++--- test_formats.py | 20 +++++++++----------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/snappy/snappy_formats.py b/src/snappy/snappy_formats.py index 1439418..2b61496 100644 --- a/src/snappy/snappy_formats.py +++ b/src/snappy/snappy_formats.py @@ -96,9 +96,9 @@ def guess_format_by_header(fin): def get_decompress_function(specified_format, fin): if specified_format == "auto": - decompress_func, read_chunk = guess_format_by_header(fin) - return decompress_func, read_chunk - return _DECOMPRESS_METHODS[specified_format], None + format, decompress_func = guess_format_by_header(fin) + return decompress_func + return _DECOMPRESS_METHODS[specified_format] def get_compress_function(specified_format): diff --git a/test_formats.py b/test_formats.py index 43afb91..4e499d7 100644 --- a/test_formats.py +++ b/test_formats.py @@ -3,12 +3,11 @@ from unittest import TestCase from snappy import snappy_formats as formats -from snappy.snappy import _CHUNK_MAX, UncompressError class TestFormatBase(TestCase): - compress_format = formats.FORMAT_AUTO - decompress_format = formats.FORMAT_AUTO + compress_format = "auto" + decompress_format = "auto" success = True def runTest(self): @@ -18,34 +17,33 @@ def runTest(self): compressed_stream = io.BytesIO() compress_func(instream, compressed_stream) compressed_stream.seek(0) - decompress_func, read_chunk = formats.get_decompress_function( + decompress_func = formats.get_decompress_function( self.decompress_format, compressed_stream ) decompressed_stream = io.BytesIO() decompress_func( compressed_stream, decompressed_stream, - start_chunk=read_chunk ) decompressed_stream.seek(0) self.assertEqual(data, decompressed_stream.read()) class TestFormatFramingFraming(TestFormatBase): - compress_format = formats.FRAMING_FORMAT - decompress_format = formats.FRAMING_FORMAT + compress_format = "framing" + decompress_format = "framing" success = True class TestFormatFramingAuto(TestFormatBase): - compress_format = formats.FRAMING_FORMAT - decompress_format = formats.FORMAT_AUTO + compress_format = "framing" + decompress_format = "auto" success = True class TestFormatAutoFraming(TestFormatBase): - compress_format = formats.FORMAT_AUTO - decompress_format = formats.FRAMING_FORMAT + compress_format = "auto" + decompress_format = "framing" success = True From a7e94c6f5a6e7e90afbbd5a926c878735721e11c Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Thu, 23 May 2024 10:57:04 -0400 Subject: [PATCH 3/4] add tests --- src/snappy/snappy.py | 4 ++-- src/snappy/snappy_formats.py | 7 ++++--- test_formats.py | 25 +++++++++++++++++++++++++ 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/src/snappy/snappy.py b/src/snappy/snappy.py index 8e1e93b..f8e5fe2 100644 --- a/src/snappy/snappy.py +++ b/src/snappy/snappy.py @@ -333,7 +333,7 @@ def hadoop_stream_decompress( break buf = c.decompress(data) if buf: - dst.write() + dst.write(buf) dst.flush() @@ -349,7 +349,7 @@ def hadoop_stream_compress( break buf = c.compress(data) if buf: - dst.write() + dst.write(buf) dst.flush() diff --git a/src/snappy/snappy_formats.py b/src/snappy/snappy_formats.py index 2b61496..ead6dbb 100644 --- a/src/snappy/snappy_formats.py +++ b/src/snappy/snappy_formats.py @@ -53,8 +53,9 @@ def uvarint(fin): return result -def check_unframed_format(fin): - fin.seek(0) +def check_unframed_format(fin, reset=False): + if reset: + fin.seek(0) try: size = uvarint(fin) assert size < 2**32 - 1 @@ -87,7 +88,7 @@ def guess_format_by_header(fin): form = "framed" elif HadoopStreamDecompressor.check_format(fin): form = "hadoop" - elif check_unframed_format(fin): + elif check_unframed_format(fin, reset=True): form = "raw" else: raise UncompressError("Can't detect format") diff --git a/test_formats.py b/test_formats.py index 4e499d7..6453b1e 100644 --- a/test_formats.py +++ b/test_formats.py @@ -20,6 +20,7 @@ def runTest(self): decompress_func = formats.get_decompress_function( self.decompress_format, compressed_stream ) + compressed_stream.seek(0) decompressed_stream = io.BytesIO() decompress_func( compressed_stream, @@ -47,6 +48,30 @@ class TestFormatAutoFraming(TestFormatBase): success = True +class TestFormatHadoop(TestFormatBase): + compress_format = "hadoop" + decompress_format = "hadoop" + success = True + + +class TestFormatRaw(TestFormatBase): + compress_format = "raw" + decompress_format = "raw" + success = True + + +class TestFormatHadoopAuto(TestFormatBase): + compress_format = "hadoop" + decompress_format = "auto" + success = True + + +class TestFormatRawAuto(TestFormatBase): + compress_format = "raw" + decompress_format = "auto" + success = True + + if __name__ == "__main__": import unittest unittest.main() From 0906e06d4504d3192006cb42f0b9c1c724edc68a Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Thu, 23 May 2024 11:01:16 -0400 Subject: [PATCH 4/4] docstrings --- src/snappy/snappy.py | 11 ++++------- src/snappy/snappy_formats.py | 10 ++++++++-- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/snappy/snappy.py b/src/snappy/snappy.py index f8e5fe2..6bd2b8b 100644 --- a/src/snappy/snappy.py +++ b/src/snappy/snappy.py @@ -150,10 +150,9 @@ def __init__(self): @staticmethod def check_format(fin): - """Checks that the given data starts with snappy framing format - stream identifier. - Raises UncompressError if it doesn't start with the identifier. - :return: None + """Does this stream start with a stream header block? + + True indicates that the stream can likely be decoded using this class. """ try: return fin.read(len(_STREAM_HEADER_BLOCK)) == _STREAM_HEADER_BLOCK @@ -227,9 +226,7 @@ def __init__(self): @staticmethod def check_format(fin): - """Checks that there are enough bytes for a hadoop header - - We cannot actually determine if the data is really hadoop-snappy + """Does this look like a hadoop snappy stream? """ try: from snappy.snappy_formats import check_unframed_format diff --git a/src/snappy/snappy_formats.py b/src/snappy/snappy_formats.py index ead6dbb..e230e0b 100644 --- a/src/snappy/snappy_formats.py +++ b/src/snappy/snappy_formats.py @@ -42,6 +42,7 @@ def uvarint(fin): + """Read uint64 nbumber from varint encoding in a stream""" result = 0 shift = 0 while True: @@ -54,6 +55,11 @@ def uvarint(fin): def check_unframed_format(fin, reset=False): + """Can this be read using the raw codec + + This function wil return True for all snappy raw streams, but + True does not mean that we can necessarily decode the stream. + """ if reset: fin.seek(0) try: @@ -81,8 +87,8 @@ def check_unframed_format(fin, reset=False): def guess_format_by_header(fin): """Tries to guess a compression format for the given input file by it's header. - :return: tuple of decompression method and a chunk that was taken from the - input for format detection. + + :return: format name (str), stream decompress function (callable) """ if StreamDecompressor.check_format(fin): form = "framed"