Skip to content

Commit

Permalink
GZip file writer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675508141
  • Loading branch information
achoum authored and copybara-github committed Sep 17, 2024
1 parent bccdcb7 commit fa0bd2f
Show file tree
Hide file tree
Showing 6 changed files with 242 additions and 34 deletions.
2 changes: 2 additions & 0 deletions yggdrasil_decision_forests/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,7 @@ cc_library_ydf(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@zlib",
],
)
Expand Down Expand Up @@ -1363,6 +1364,7 @@ cc_test(
":test",
":zlib",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
)
6 changes: 6 additions & 0 deletions yggdrasil_decision_forests/utils/bytestream.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ class OutputByteStream {

// Writes a chunk of bytes.
virtual absl::Status Write(absl::string_view chunk) = 0;

// Interrupts the stream. If closed, the stream cannot be used anymore.
// Streams are not required to be closed (i.e., they are closed automatically
// at destruction). However, closing a stream explicitly allows to capture the
// closing status.
virtual absl::Status Close() { return absl::OkStatus(); }
};

class StringOutputByteStream : public OutputByteStream {
Expand Down
4 changes: 2 additions & 2 deletions yggdrasil_decision_forests/utils/filesystem_default.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class FileInputByteStream
absl::Status Open(absl::string_view path);
absl::StatusOr<int> ReadUpTo(char* buffer, int max_read) override;
absl::StatusOr<bool> ReadExactly(char* buffer, int num_read) override;
absl::Status Close();
absl::Status Close() override;

private:
std::ifstream file_stream_;
Expand All @@ -122,7 +122,7 @@ class FileOutputByteStream
public:
absl::Status Open(absl::string_view path);
absl::Status Write(absl::string_view chunk) override;
absl::Status Close();
absl::Status Close() override;

private:
std::ofstream file_stream_;
Expand Down
157 changes: 129 additions & 28 deletions yggdrasil_decision_forests/utils/zlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,38 +28,44 @@
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "yggdrasil_decision_forests/utils/bytestream.h"
#include "yggdrasil_decision_forests/utils/status_macros.h"
#include <zconf.h>

#define ZLIB_CONST
#include <zlib.h>

namespace yggdrasil_decision_forests::utils {

absl::StatusOr<std::unique_ptr<GZipInputByteStream>>
GZipInputByteStream::Create(std::unique_ptr<utils::InputByteStream>&& stream,
size_t buffer_size) {
return std::make_unique<GZipInputByteStream>(std::move(stream), buffer_size);
auto gz_stream =
std::make_unique<GZipInputByteStream>(std::move(stream), buffer_size);

gz_stream->deflate_stream_.zalloc = Z_NULL;
gz_stream->deflate_stream_.zfree = Z_NULL;
gz_stream->deflate_stream_.opaque = Z_NULL;
gz_stream->deflate_stream_.avail_in = 0;
gz_stream->deflate_stream_.next_in = Z_NULL;
if (inflateInit2(&gz_stream->deflate_stream_, 16 + MAX_WBITS) != Z_OK) {
return absl::InternalError("Cannot initialize gzip stream");
}
// gz_stream->deflate_stream_.next_in = gz_stream->input_buffer_.data();
// gz_stream->deflate_stream_.avail_in = 0;
// gz_stream->deflate_stream_.next_out = gz_stream->output_buffer_.data();
// gz_stream->deflate_stream_.avail_out = 0;
gz_stream->deflate_stream_is_allocated_ = true;
return gz_stream;
}

GZipInputByteStream::GZipInputByteStream(
std::unique_ptr<utils::InputByteStream>&& stream, size_t buffer_size)
: buffer_size_(buffer_size), stream_(std::move(stream)) {
deflate_stream_.zalloc = Z_NULL;
deflate_stream_.zfree = Z_NULL;
deflate_stream_.opaque = Z_NULL;
deflate_stream_.avail_in = 0;
deflate_stream_.next_in = Z_NULL;
if (inflateInit2(&deflate_stream_, 16 + MAX_WBITS) != Z_OK) {
CHECK(false);
}

input_buffer_.resize(buffer_size_);
output_buffer_.resize(buffer_size_);
deflate_stream_.next_in = input_buffer_.data();
deflate_stream_.next_out = output_buffer_.data();

deflate_stream_.avail_in = 0;
deflate_stream_.avail_out = 0;
}

absl::StatusOr<int> GZipInputByteStream::ReadUpTo(char* buffer, int max_read) {
Expand All @@ -84,21 +90,15 @@ absl::StatusOr<int> GZipInputByteStream::ReadUpTo(char* buffer, int max_read) {
// 2. Continue the decompression of the data in the input buffer.
deflate_stream_.avail_out = buffer_size_;
deflate_stream_.next_out = output_buffer_.data();
int error_status = inflate(&deflate_stream_, Z_NO_FLUSH);
switch (error_status) {
case Z_NEED_DICT:
case Z_DATA_ERROR:
case Z_MEM_ERROR:
inflateEnd(&deflate_stream_);
return absl::InternalError("Internal error");
const auto zlib_error = inflate(&deflate_stream_, Z_NO_FLUSH);
if (zlib_error != Z_OK && zlib_error != Z_STREAM_END) {
inflateEnd(&deflate_stream_);
return absl::InternalError(absl::StrCat("Internal error", zlib_error));
}
const int num_decompressed = buffer_size_ - deflate_stream_.avail_out;
output_buffer_begin_ = 0;
output_buffer_end_ = num_decompressed;

if (num_decompressed == 0) {
break;
}
const int produced_bytes = buffer_size_ - deflate_stream_.avail_out;
output_buffer_begin_ = 0;
output_buffer_end_ = produced_bytes;
}

// 3. Read non-compressed data from the underlying stream to the input
Expand Down Expand Up @@ -154,6 +154,107 @@ absl::Status GZipInputByteStream::CloseDeflateStream() {
return absl::OkStatus();
}

absl::StatusOr<std::unique_ptr<GZipOutputByteStream>>
GZipOutputByteStream::Create(std::unique_ptr<utils::OutputByteStream>&& stream,
int compression_level, size_t buffer_size) {
if (compression_level != Z_DEFAULT_COMPRESSION) {
STATUS_CHECK_GT(compression_level, Z_NO_COMPRESSION);
STATUS_CHECK_LT(compression_level, Z_BEST_COMPRESSION);
}
auto gz_stream =
std::make_unique<GZipOutputByteStream>(std::move(stream), buffer_size);

gz_stream->deflate_stream_.zalloc = Z_NULL;
gz_stream->deflate_stream_.zfree = Z_NULL;
gz_stream->deflate_stream_.opaque = Z_NULL;
gz_stream->deflate_stream_.avail_in = 0;
gz_stream->deflate_stream_.next_in = Z_NULL;
if (deflateInit2(&gz_stream->deflate_stream_, compression_level, Z_DEFLATED,
MAX_WBITS + 16,
/*memLevel=*/8, // 8 is the recommended default
Z_DEFAULT_STRATEGY) != Z_OK) {
return absl::InternalError("Cannot initialize gzip stream");
}
gz_stream->deflate_stream_is_allocated_ = true;
return gz_stream;
}

GZipOutputByteStream::GZipOutputByteStream(
std::unique_ptr<utils::OutputByteStream>&& stream, size_t buffer_size)
: buffer_size_(buffer_size), stream_(std::move(stream)) {
output_buffer_.resize(buffer_size_);
}

GZipOutputByteStream::~GZipOutputByteStream() {
CloseInflateStream().IgnoreError();
}

absl::Status GZipOutputByteStream::Write(absl::string_view chunk) {
return WriteImpl(chunk, false);
}

absl::Status GZipOutputByteStream::WriteImpl(absl::string_view chunk,
bool flush) {
if (chunk.empty() && !flush) {
return absl::OkStatus();
}
deflate_stream_.next_in = reinterpret_cast<const Bytef*>(chunk.data());
deflate_stream_.avail_in = chunk.size();

while (true) {
deflate_stream_.next_out = output_buffer_.data();
deflate_stream_.avail_out = buffer_size_;

const auto zlib_error =
deflate(&deflate_stream_, flush ? Z_FINISH : Z_NO_FLUSH);

if (flush) {
if (zlib_error != Z_STREAM_END && !chunk.empty()) {
deflateEnd(&deflate_stream_);
return absl::InternalError(absl::StrCat("Internal error ", zlib_error,
". Output buffer too small"));
}
} else {
if (zlib_error != Z_OK) {
deflateEnd(&deflate_stream_);
return absl::InternalError(absl::StrCat("Internal error ", zlib_error));
}
}

const size_t compressed_bytes = buffer_size_ - deflate_stream_.avail_out;

if (compressed_bytes > 0) {
RETURN_IF_ERROR(stream_->Write(absl::string_view{
reinterpret_cast<char*>(output_buffer_.data()), compressed_bytes}));
}

if (deflate_stream_.avail_out != 0) {
break;
}
}

return absl::OkStatus();
}

absl::Status GZipOutputByteStream::Close() {
RETURN_IF_ERROR(CloseInflateStream());
if (stream_) {
return stream_->Close();
}
return absl::OkStatus();
}

absl::Status GZipOutputByteStream::CloseInflateStream() {
if (deflate_stream_is_allocated_) {
deflate_stream_is_allocated_ = false;
RETURN_IF_ERROR(WriteImpl("", true));
if (deflateEnd(&deflate_stream_) != Z_OK) {
return absl::InternalError("Cannot close deflate");
}
}
return absl::OkStatus();
}

} // namespace yggdrasil_decision_forests::utils

#endif // THIRD_PARTY_YGGDRASIL_DECISION_FORESTS_UTILS_GZIP_H_
40 changes: 36 additions & 4 deletions yggdrasil_decision_forests/utils/zlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@

#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "yggdrasil_decision_forests/utils/bytestream.h"

#define ZLIB_CONST
#include <zlib.h>

namespace yggdrasil_decision_forests::utils {
Expand All @@ -31,15 +33,15 @@ class GZipInputByteStream : public utils::InputByteStream {
public:
static absl::StatusOr<std::unique_ptr<GZipInputByteStream>> Create(
std::unique_ptr<utils::InputByteStream>&& stream,
size_t buffer_size = 3 /*1024 * 1024*/);
size_t buffer_size = 1024 * 1024);

GZipInputByteStream(std::unique_ptr<utils::InputByteStream>&& stream,
size_t buffer_size);
~GZipInputByteStream() override;

absl::StatusOr<int> ReadUpTo(char* buffer, int max_read) override;
absl::StatusOr<bool> ReadExactly(char* buffer, int num_read) override;
absl::Status Close();
absl::Status Close() override;

private:
absl::Status CloseDeflateStream();
Expand All @@ -59,7 +61,37 @@ class GZipInputByteStream : public utils::InputByteStream {
// zlib decompression state machine.
z_stream deflate_stream_;
// Was "deflate_stream_" allocated?
bool deflate_stream_is_allocated_ = true;
bool deflate_stream_is_allocated_ = false;
};

class GZipOutputByteStream : public utils::OutputByteStream {
public:
static absl::StatusOr<std::unique_ptr<GZipOutputByteStream>> Create(
std::unique_ptr<utils::OutputByteStream>&& stream,
int compression_level = Z_DEFAULT_COMPRESSION,
size_t buffer_size = 1024 * 1024);

GZipOutputByteStream(std::unique_ptr<utils::OutputByteStream>&& stream,
size_t buffer_size);
~GZipOutputByteStream() override;

absl::Status Write(absl::string_view chunk) override;
absl::Status Close() override;

private:
absl::Status CloseInflateStream();
absl::Status WriteImpl(absl::string_view chunk, bool flush);

// Size of the compressed and uncompressed buffers.
size_t buffer_size_;
// Underlying stream of compressed data.
std::unique_ptr<utils::OutputByteStream> stream_;
// Buffer of compressed data.
std::vector<Bytef> output_buffer_;
// zlib decompression state machine.
z_stream deflate_stream_;
// Was "deflate_stream_" allocated?
bool deflate_stream_is_allocated_ = false;
};

} // namespace yggdrasil_decision_forests::utils
Expand Down
Loading

0 comments on commit fa0bd2f

Please sign in to comment.