From fa0bd2f0e47c77231206c7d37d447f2686583fef Mon Sep 17 00:00:00 2001 From: Mathieu Guillame-Bert Date: Tue, 17 Sep 2024 04:12:23 -0700 Subject: [PATCH] GZip file writer PiperOrigin-RevId: 675508141 --- yggdrasil_decision_forests/utils/BUILD | 2 + yggdrasil_decision_forests/utils/bytestream.h | 6 + .../utils/filesystem_default.h | 4 +- yggdrasil_decision_forests/utils/zlib.cc | 157 ++++++++++++++---- yggdrasil_decision_forests/utils/zlib.h | 40 ++++- yggdrasil_decision_forests/utils/zlib_test.cc | 67 ++++++++ 6 files changed, 242 insertions(+), 34 deletions(-) diff --git a/yggdrasil_decision_forests/utils/BUILD b/yggdrasil_decision_forests/utils/BUILD index c8b3b02a..d0a790f3 100644 --- a/yggdrasil_decision_forests/utils/BUILD +++ b/yggdrasil_decision_forests/utils/BUILD @@ -895,6 +895,7 @@ cc_library_ydf( "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@zlib", ], ) @@ -1363,6 +1364,7 @@ cc_test( ":test", ":zlib", "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", ], ) diff --git a/yggdrasil_decision_forests/utils/bytestream.h b/yggdrasil_decision_forests/utils/bytestream.h index 39964aad..42d1d392 100644 --- a/yggdrasil_decision_forests/utils/bytestream.h +++ b/yggdrasil_decision_forests/utils/bytestream.h @@ -92,6 +92,12 @@ class OutputByteStream { // Writes a chunk of bytes. virtual absl::Status Write(absl::string_view chunk) = 0; + + // Interrupts the stream. If closed, the stream cannot be used anymore. + // Streams are not required to be closed (i.e., they are closed automatically + // at destruction). However, closing a stream explicitly allows to capture the + // closing status. + virtual absl::Status Close() { return absl::OkStatus(); } }; class StringOutputByteStream : public OutputByteStream { diff --git a/yggdrasil_decision_forests/utils/filesystem_default.h b/yggdrasil_decision_forests/utils/filesystem_default.h index 431b2651..43d90df6 100644 --- a/yggdrasil_decision_forests/utils/filesystem_default.h +++ b/yggdrasil_decision_forests/utils/filesystem_default.h @@ -111,7 +111,7 @@ class FileInputByteStream absl::Status Open(absl::string_view path); absl::StatusOr ReadUpTo(char* buffer, int max_read) override; absl::StatusOr ReadExactly(char* buffer, int num_read) override; - absl::Status Close(); + absl::Status Close() override; private: std::ifstream file_stream_; @@ -122,7 +122,7 @@ class FileOutputByteStream public: absl::Status Open(absl::string_view path); absl::Status Write(absl::string_view chunk) override; - absl::Status Close(); + absl::Status Close() override; private: std::ofstream file_stream_; diff --git a/yggdrasil_decision_forests/utils/zlib.cc b/yggdrasil_decision_forests/utils/zlib.cc index 46fe3654..ecb3a6bb 100644 --- a/yggdrasil_decision_forests/utils/zlib.cc +++ b/yggdrasil_decision_forests/utils/zlib.cc @@ -28,9 +28,13 @@ #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "yggdrasil_decision_forests/utils/bytestream.h" #include "yggdrasil_decision_forests/utils/status_macros.h" #include + +#define ZLIB_CONST #include namespace yggdrasil_decision_forests::utils { @@ -38,28 +42,30 @@ namespace yggdrasil_decision_forests::utils { absl::StatusOr> GZipInputByteStream::Create(std::unique_ptr&& stream, size_t buffer_size) { - return std::make_unique(std::move(stream), buffer_size); + auto gz_stream = + std::make_unique(std::move(stream), buffer_size); + + gz_stream->deflate_stream_.zalloc = Z_NULL; + gz_stream->deflate_stream_.zfree = Z_NULL; + gz_stream->deflate_stream_.opaque = Z_NULL; + gz_stream->deflate_stream_.avail_in = 0; + gz_stream->deflate_stream_.next_in = Z_NULL; + if (inflateInit2(&gz_stream->deflate_stream_, 16 + MAX_WBITS) != Z_OK) { + return absl::InternalError("Cannot initialize gzip stream"); + } + // gz_stream->deflate_stream_.next_in = gz_stream->input_buffer_.data(); + // gz_stream->deflate_stream_.avail_in = 0; + // gz_stream->deflate_stream_.next_out = gz_stream->output_buffer_.data(); + // gz_stream->deflate_stream_.avail_out = 0; + gz_stream->deflate_stream_is_allocated_ = true; + return gz_stream; } GZipInputByteStream::GZipInputByteStream( std::unique_ptr&& stream, size_t buffer_size) : buffer_size_(buffer_size), stream_(std::move(stream)) { - deflate_stream_.zalloc = Z_NULL; - deflate_stream_.zfree = Z_NULL; - deflate_stream_.opaque = Z_NULL; - deflate_stream_.avail_in = 0; - deflate_stream_.next_in = Z_NULL; - if (inflateInit2(&deflate_stream_, 16 + MAX_WBITS) != Z_OK) { - CHECK(false); - } - input_buffer_.resize(buffer_size_); output_buffer_.resize(buffer_size_); - deflate_stream_.next_in = input_buffer_.data(); - deflate_stream_.next_out = output_buffer_.data(); - - deflate_stream_.avail_in = 0; - deflate_stream_.avail_out = 0; } absl::StatusOr GZipInputByteStream::ReadUpTo(char* buffer, int max_read) { @@ -84,21 +90,15 @@ absl::StatusOr GZipInputByteStream::ReadUpTo(char* buffer, int max_read) { // 2. Continue the decompression of the data in the input buffer. deflate_stream_.avail_out = buffer_size_; deflate_stream_.next_out = output_buffer_.data(); - int error_status = inflate(&deflate_stream_, Z_NO_FLUSH); - switch (error_status) { - case Z_NEED_DICT: - case Z_DATA_ERROR: - case Z_MEM_ERROR: - inflateEnd(&deflate_stream_); - return absl::InternalError("Internal error"); + const auto zlib_error = inflate(&deflate_stream_, Z_NO_FLUSH); + if (zlib_error != Z_OK && zlib_error != Z_STREAM_END) { + inflateEnd(&deflate_stream_); + return absl::InternalError(absl::StrCat("Internal error", zlib_error)); } - const int num_decompressed = buffer_size_ - deflate_stream_.avail_out; - output_buffer_begin_ = 0; - output_buffer_end_ = num_decompressed; - if (num_decompressed == 0) { - break; - } + const int produced_bytes = buffer_size_ - deflate_stream_.avail_out; + output_buffer_begin_ = 0; + output_buffer_end_ = produced_bytes; } // 3. Read non-compressed data from the underlying stream to the input @@ -154,6 +154,107 @@ absl::Status GZipInputByteStream::CloseDeflateStream() { return absl::OkStatus(); } +absl::StatusOr> +GZipOutputByteStream::Create(std::unique_ptr&& stream, + int compression_level, size_t buffer_size) { + if (compression_level != Z_DEFAULT_COMPRESSION) { + STATUS_CHECK_GT(compression_level, Z_NO_COMPRESSION); + STATUS_CHECK_LT(compression_level, Z_BEST_COMPRESSION); + } + auto gz_stream = + std::make_unique(std::move(stream), buffer_size); + + gz_stream->deflate_stream_.zalloc = Z_NULL; + gz_stream->deflate_stream_.zfree = Z_NULL; + gz_stream->deflate_stream_.opaque = Z_NULL; + gz_stream->deflate_stream_.avail_in = 0; + gz_stream->deflate_stream_.next_in = Z_NULL; + if (deflateInit2(&gz_stream->deflate_stream_, compression_level, Z_DEFLATED, + MAX_WBITS + 16, + /*memLevel=*/8, // 8 is the recommended default + Z_DEFAULT_STRATEGY) != Z_OK) { + return absl::InternalError("Cannot initialize gzip stream"); + } + gz_stream->deflate_stream_is_allocated_ = true; + return gz_stream; +} + +GZipOutputByteStream::GZipOutputByteStream( + std::unique_ptr&& stream, size_t buffer_size) + : buffer_size_(buffer_size), stream_(std::move(stream)) { + output_buffer_.resize(buffer_size_); +} + +GZipOutputByteStream::~GZipOutputByteStream() { + CloseInflateStream().IgnoreError(); +} + +absl::Status GZipOutputByteStream::Write(absl::string_view chunk) { + return WriteImpl(chunk, false); +} + +absl::Status GZipOutputByteStream::WriteImpl(absl::string_view chunk, + bool flush) { + if (chunk.empty() && !flush) { + return absl::OkStatus(); + } + deflate_stream_.next_in = reinterpret_cast(chunk.data()); + deflate_stream_.avail_in = chunk.size(); + + while (true) { + deflate_stream_.next_out = output_buffer_.data(); + deflate_stream_.avail_out = buffer_size_; + + const auto zlib_error = + deflate(&deflate_stream_, flush ? Z_FINISH : Z_NO_FLUSH); + + if (flush) { + if (zlib_error != Z_STREAM_END && !chunk.empty()) { + deflateEnd(&deflate_stream_); + return absl::InternalError(absl::StrCat("Internal error ", zlib_error, + ". Output buffer too small")); + } + } else { + if (zlib_error != Z_OK) { + deflateEnd(&deflate_stream_); + return absl::InternalError(absl::StrCat("Internal error ", zlib_error)); + } + } + + const size_t compressed_bytes = buffer_size_ - deflate_stream_.avail_out; + + if (compressed_bytes > 0) { + RETURN_IF_ERROR(stream_->Write(absl::string_view{ + reinterpret_cast(output_buffer_.data()), compressed_bytes})); + } + + if (deflate_stream_.avail_out != 0) { + break; + } + } + + return absl::OkStatus(); +} + +absl::Status GZipOutputByteStream::Close() { + RETURN_IF_ERROR(CloseInflateStream()); + if (stream_) { + return stream_->Close(); + } + return absl::OkStatus(); +} + +absl::Status GZipOutputByteStream::CloseInflateStream() { + if (deflate_stream_is_allocated_) { + deflate_stream_is_allocated_ = false; + RETURN_IF_ERROR(WriteImpl("", true)); + if (deflateEnd(&deflate_stream_) != Z_OK) { + return absl::InternalError("Cannot close deflate"); + } + } + return absl::OkStatus(); +} + } // namespace yggdrasil_decision_forests::utils #endif // THIRD_PARTY_YGGDRASIL_DECISION_FORESTS_UTILS_GZIP_H_ diff --git a/yggdrasil_decision_forests/utils/zlib.h b/yggdrasil_decision_forests/utils/zlib.h index 21adddb1..bc26ee2f 100644 --- a/yggdrasil_decision_forests/utils/zlib.h +++ b/yggdrasil_decision_forests/utils/zlib.h @@ -18,11 +18,13 @@ #include #include -#include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "yggdrasil_decision_forests/utils/bytestream.h" + +#define ZLIB_CONST #include namespace yggdrasil_decision_forests::utils { @@ -31,7 +33,7 @@ class GZipInputByteStream : public utils::InputByteStream { public: static absl::StatusOr> Create( std::unique_ptr&& stream, - size_t buffer_size = 3 /*1024 * 1024*/); + size_t buffer_size = 1024 * 1024); GZipInputByteStream(std::unique_ptr&& stream, size_t buffer_size); @@ -39,7 +41,7 @@ class GZipInputByteStream : public utils::InputByteStream { absl::StatusOr ReadUpTo(char* buffer, int max_read) override; absl::StatusOr ReadExactly(char* buffer, int num_read) override; - absl::Status Close(); + absl::Status Close() override; private: absl::Status CloseDeflateStream(); @@ -59,7 +61,37 @@ class GZipInputByteStream : public utils::InputByteStream { // zlib decompression state machine. z_stream deflate_stream_; // Was "deflate_stream_" allocated? - bool deflate_stream_is_allocated_ = true; + bool deflate_stream_is_allocated_ = false; +}; + +class GZipOutputByteStream : public utils::OutputByteStream { + public: + static absl::StatusOr> Create( + std::unique_ptr&& stream, + int compression_level = Z_DEFAULT_COMPRESSION, + size_t buffer_size = 1024 * 1024); + + GZipOutputByteStream(std::unique_ptr&& stream, + size_t buffer_size); + ~GZipOutputByteStream() override; + + absl::Status Write(absl::string_view chunk) override; + absl::Status Close() override; + + private: + absl::Status CloseInflateStream(); + absl::Status WriteImpl(absl::string_view chunk, bool flush); + + // Size of the compressed and uncompressed buffers. + size_t buffer_size_; + // Underlying stream of compressed data. + std::unique_ptr stream_; + // Buffer of compressed data. + std::vector output_buffer_; + // zlib decompression state machine. + z_stream deflate_stream_; + // Was "deflate_stream_" allocated? + bool deflate_stream_is_allocated_ = false; }; } // namespace yggdrasil_decision_forests::utils diff --git a/yggdrasil_decision_forests/utils/zlib_test.cc b/yggdrasil_decision_forests/utils/zlib_test.cc index 200bbc77..512e07d5 100644 --- a/yggdrasil_decision_forests/utils/zlib_test.cc +++ b/yggdrasil_decision_forests/utils/zlib_test.cc @@ -15,13 +15,17 @@ #include "yggdrasil_decision_forests/utils/zlib.h" +#include + #include +#include #include #include #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/log/log.h" +#include "absl/strings/str_cat.h" #include "yggdrasil_decision_forests/utils/filesystem.h" #include "yggdrasil_decision_forests/utils/logging.h" #include "yggdrasil_decision_forests/utils/test.h" @@ -29,6 +33,7 @@ namespace yggdrasil_decision_forests::utils { namespace { +using ::testing::TestWithParam; using yggdrasil_decision_forests::test::DataRootDirectory; std::string HelloPath() { @@ -71,5 +76,67 @@ TEST(GZip, ReadUnit) { EXPECT_OK(stream->Close()); } +struct GZipTestCase { + size_t content_size; + size_t buffer_size; + int compression; + bool random = true; +}; + +using GZipTestCaseTest = TestWithParam; + +INSTANTIATE_TEST_SUITE_P( + GZipTestCaseTestSuiteInstantiation, GZipTestCaseTest, + testing::ValuesIn({{1024 * 2, 1024, 8}, + {10, 1024, 8}, + {10, 256, 8}, + {0, 1024 * 1024, 8}, + {10, 1024 * 1024, 8}, + {10 * 1024 * 1024, 1024 * 1024, 8}, + {10 * 1024 * 1024, 1024 * 1024, 8, false}, + {10 * 1024 * 1024, 1024 * 1024, 1}, + {10 * 1024 * 1024, 1024 * 1024, -1}})); + +TEST_P(GZipTestCaseTest, WriteAndRead) { + const GZipTestCase& test_case = GetParam(); + std::uniform_int_distribution dist(0, 10); + std::mt19937_64 rng(1); + + auto tmp_dir = test::TmpDirectory(); + auto file_path = file::JoinPath(tmp_dir, "my_file.txt.gz"); + std::string content; + content.reserve(test_case.content_size); + if (test_case.random) { + for (int i = 0; i < test_case.content_size; i++) { + absl::StrAppend(&content, dist(rng)); + } + } else { + for (int i = 0; i < test_case.content_size / 2; i++) { + absl::StrAppend(&content, "AB"); + } + } + + { + auto file_stream = file::OpenOutputFile(file_path).value(); + auto stream = GZipOutputByteStream::Create(std::move(file_stream), + test_case.compression, + test_case.buffer_size) + .value(); + EXPECT_OK(stream->Write(content)); + EXPECT_OK(stream->Write(content)); + EXPECT_OK(stream->Close()); + } + + { + auto file_stream = file::OpenInputFile(file_path).value(); + auto stream = GZipInputByteStream::Create(std::move(file_stream), + test_case.buffer_size) + .value(); + auto read_content = stream->ReadAll().value(); + EXPECT_OK(stream->Close()); + EXPECT_EQ(content + content, read_content); + } +} + } // namespace } // namespace yggdrasil_decision_forests::utils