Skip to content

Commit

Permalink
Add support for compressed tfrecords without tf dependencies.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 673388515
  • Loading branch information
achoum authored and copybara-github committed Sep 11, 2024
1 parent b7f5472 commit 0b87800
Show file tree
Hide file tree
Showing 12 changed files with 207 additions and 76 deletions.
3 changes: 3 additions & 0 deletions yggdrasil_decision_forests/dataset/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ cc_library_ydf(
deps = [
":formats_cc_proto",
"//yggdrasil_decision_forests/utils:logging",
"//yggdrasil_decision_forests/utils:status_macros",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down Expand Up @@ -83,7 +84,9 @@ cc_library_ydf(
":example_reader_interface",
":formats",
":formats_cc_proto",
"//yggdrasil_decision_forests/utils:logging",
"//yggdrasil_decision_forests/utils:status_macros",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
Expand Down
3 changes: 3 additions & 0 deletions yggdrasil_decision_forests/dataset/example_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@
#include <utility>
#include <vector>

#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "yggdrasil_decision_forests/dataset/data_spec.pb.h"
#include "yggdrasil_decision_forests/dataset/example_reader_interface.h"
#include "yggdrasil_decision_forests/dataset/formats.h"
#include "yggdrasil_decision_forests/dataset/formats.pb.h"
#include "yggdrasil_decision_forests/utils/logging.h"
#include "yggdrasil_decision_forests/utils/status_macros.h"

namespace yggdrasil_decision_forests {
Expand Down Expand Up @@ -56,6 +58,7 @@ absl::StatusOr<std::unique_ptr<ExampleReaderInterface>> CreateExampleReader(
absl::StatusOr<bool> IsFormatSupported(absl::string_view typed_path) {
const auto path_format_or = GetDatasetPathAndTypeOrStatus(typed_path);
if (!path_format_or.ok()) {
LOG(WARNING) << "Cannot parse typed path: " << path_format_or.status();
return false;
}
std::string sharded_path;
Expand Down
17 changes: 12 additions & 5 deletions yggdrasil_decision_forests/dataset/example_reader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,19 @@ std::string DatasetDir() {
"yggdrasil_decision_forests/"
"test_data/dataset");
}

TEST(ExampleReader, CreateExampleReader) {
for (const auto& dataset_path :
{absl::StrCat("tfrecord+tfe:",
file::JoinPath(DatasetDir(), "toy.tfe-tfrecord@2")),
absl::StrCat("csv:", file::JoinPath(DatasetDir(), "toy.csv"))}) {
for (const auto& dataset_path : {
absl::StrCat("tfrecord+tfe:",
file::JoinPath(DatasetDir(), "toy.tfe-tfrecord@2")),
absl::StrCat("tfrecord:",
file::JoinPath(DatasetDir(), "toy.tfe-tfrecord@2")),
absl::StrCat("tfrecordv2+gz+tfe:",
file::JoinPath(DatasetDir(), "toy.tfe-tfrecord@2")),
absl::StrCat(
"tfrecordv2+tfe:",
file::JoinPath(DatasetDir(), "toy.nocompress-tfe-tfrecord@2")),
absl::StrCat("csv:", file::JoinPath(DatasetDir(), "toy.csv")),
}) {
LOG(INFO) << "Create dataspec for " << dataset_path;
proto::DataSpecificationGuide guide;
proto::DataSpecification data_spec;
Expand Down
138 changes: 86 additions & 52 deletions yggdrasil_decision_forests/dataset/formats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,66 @@
#include "absl/strings/substitute.h"
#include "yggdrasil_decision_forests/dataset/formats.pb.h"
#include "yggdrasil_decision_forests/utils/logging.h"
#include "yggdrasil_decision_forests/utils/status_macros.h"

namespace yggdrasil_decision_forests {
namespace dataset {

using proto::DatasetFormat;
// Describes one dataset format: its recommended file extension, the
// typed-path prefix(es) that select it, and the matching proto enum value.
//
// Fields are set with designated initializers, so their declaration order
// must be preserved.
struct Format {
// Recommended file extension, without a leading dot (e.g. "csv").
absl::string_view extension;
// Primary typed-path prefix identifying the format.
absl::string_view prefix;
// Optional alternative prefix. Empty when the format has no alias.
absl::string_view prefix_alias = "";
// Corresponding value of the DatasetFormat proto enum.
proto::DatasetFormat proto_format;
};

const std::vector<Format>& GetFormats() {
static const std::vector<Format>* formats = []() {
auto* formats = new std::vector<Format>();

// RFC 4180-compliant CSV file.
formats->push_back({
.extension = "csv",
.prefix = FORMAT_CSV,
.proto_format = proto::FORMAT_CSV,
});

// GZip compressed TF Record of binary serialized TensorFlow.Example proto.
// Use TensorFlow API to read the file (required TensorFlow C++ to be
// linked). Deprecated: Use FORMAT_TFE_TFRECORD_COMPRESSED_V2 instead.
formats->push_back({
.extension = "tfrecord",
.prefix = FORMAT_TFE_TFRECORD,
.proto_format = proto::FORMAT_TFE_TFRECORD,
});

// Uncompressed TF Record of binary serialized TensorFlow.Example proto.
// Does not require TensorFlow C++ to be linked.
formats->push_back({
.extension = "tfrecord",
.prefix = FORMAT_TFE_TFRECORDV2,
.proto_format = proto::FORMAT_TFE_TFRECORDV2,
});

// GZip compressed TF Record of binary serialized TensorFlow.Example proto.
// Does not require TensorFlow C++ to be linked.
formats->push_back({
.extension = "tfrecord",
.prefix = "tfrecord",
.prefix_alias = "tfrecordv2+gz+tfe",
.proto_format = proto::FORMAT_TFE_TFRECORD_COMPRESSED_V2,
});

// Partially computed (e.g. non indexed) dataset cache.
formats->push_back({
.extension = "partial_dataset_cache",
.prefix = FORMAT_PARTIAL_DATASET_CACHE,
.proto_format = proto::FORMAT_PARTIAL_DATASET_CACHE,
});

return formats;
}();
return *formats;
}

absl::StatusOr<std::pair<std::string, std::string>> SplitTypeAndPath(
const absl::string_view typed_path) {
Expand Down Expand Up @@ -72,70 +127,49 @@ std::pair<std::string, proto::DatasetFormat> GetDatasetPathAndType(
return GetDatasetPathAndTypeOrStatus(typed_path).value();
}

absl::StatusOr<std::pair<std::string, proto::DatasetFormat>>
GetDatasetPathAndTypeOrStatus(const absl::string_view typed_path) {
std::string path, prefix;
std::tie(prefix, path) = SplitTypeAndPath(typed_path).value();

static const google::protobuf::EnumDescriptor* enum_descriptor =
google::protobuf::GetEnumDescriptor<DatasetFormat>();
for (int format_idx = 0; format_idx < enum_descriptor->value_count();
format_idx++) {
const auto format = static_cast<DatasetFormat>(
enum_descriptor->value(format_idx)->number());
if (format == proto::INVALID) {
continue;
std::string FormatToRecommendedExtension(proto::DatasetFormat proto_format) {
for (const auto& format : GetFormats()) {
if (format.proto_format == proto_format) {
return std::string(format.extension);
}
if (DatasetFormatToPrefix(format) == prefix) {
return std::make_pair(std::string(path), format);
}
return "";
}

// Maps a typed-path prefix (e.g. "csv") to its dataset format.
//
// A format is selected when `prefix` equals its primary prefix, or its alias
// when it has one. Returns an InvalidArgument error when no format in
// GetFormats() matches.
absl::StatusOr<proto::DatasetFormat> PrefixToFormat(
    const absl::string_view prefix) {
  for (const auto& format : GetFormats()) {
    // Formats without an alias carry an empty `prefix_alias`. Skip the alias
    // comparison for them so that an empty `prefix` is reported as unknown
    // instead of silently matching the first alias-less format.
    const bool matches_alias =
        !format.prefix_alias.empty() && format.prefix_alias == prefix;
    if (format.prefix == prefix || matches_alias) {
      return format.proto_format;
    }
  }
  return absl::InvalidArgumentError(
      absl::StrCat("The format prefix \"", prefix,
                   "\" is unknown. Make sure the format reader is linked to "
                   "the binary."));
}

std::string FormatToRecommendedExtension(proto::DatasetFormat format) {
switch (format) {
case proto::INVALID:
LOG(FATAL) << "Invalid format";
break;
case proto::FORMAT_CSV:
return "csv";
case proto::FORMAT_TFE_TFRECORD:
return "tfrecord";
case proto::FORMAT_TFE_TFRECORDV2:
return "tfrecordv2";
case proto::FORMAT_PARTIAL_DATASET_CACHE:
return "partial_dataset_cache";
}
// Splits a typed path into its path and dataset format.
//
// Returns the (path, format) pair on success. Errors from splitting the typed
// path or from resolving the prefix are returned to the caller unchanged.
absl::StatusOr<std::pair<std::string, proto::DatasetFormat>>
GetDatasetPathAndTypeOrStatus(const absl::string_view typed_path) {
  // `prefix_and_path.first` is the format prefix, `.second` the path.
  ASSIGN_OR_RETURN(const auto prefix_and_path, SplitTypeAndPath(typed_path));
  ASSIGN_OR_RETURN(const proto::DatasetFormat resolved_format,
                   PrefixToFormat(prefix_and_path.first));
  return std::make_pair(prefix_and_path.second, resolved_format);
}

std::string DatasetFormatToPrefix(proto::DatasetFormat format) {
switch (format) {
case proto::INVALID:
LOG(FATAL) << "Invalid format";
break;
case proto::FORMAT_CSV:
return FORMAT_CSV;
case proto::FORMAT_TFE_TFRECORD:
return FORMAT_TFE_TFRECORD;
case proto::FORMAT_TFE_TFRECORDV2:
return FORMAT_TFE_TFRECORDV2;
case proto::FORMAT_PARTIAL_DATASET_CACHE:
return FORMAT_PARTIAL_DATASET_CACHE;
// Returns the primary typed-path prefix of `proto_format`, or "unknown" when
// the format is not present in GetFormats().
std::string DatasetFormatToPrefix(proto::DatasetFormat proto_format) {
  const auto& known_formats = GetFormats();
  for (auto it = known_formats.begin(); it != known_formats.end(); ++it) {
    if (it->proto_format != proto_format) {
      continue;
    }
    // The first entry with a matching enum value wins.
    return std::string(it->prefix);
  }
  return "unknown";
}

std::string ListSupportedFormats() {
std::vector<std::string> supported_prefixes;
static const google::protobuf::EnumDescriptor* enum_descriptor =
google::protobuf::GetEnumDescriptor<DatasetFormat>();
for (int i = 0; i < enum_descriptor->value_count(); i++) {
const google::protobuf::EnumValueDescriptor* format_idx = enum_descriptor->value(i);
const auto format = static_cast<DatasetFormat>(format_idx->number());
if (format != proto::INVALID) {
supported_prefixes.push_back(DatasetFormatToPrefix(format));
}
for (const auto& format : GetFormats()) {
supported_prefixes.push_back(std::string(format.prefix));
}
return absl::StrJoin(supported_prefixes, ", ");
}
Expand Down
9 changes: 1 addition & 8 deletions yggdrasil_decision_forests/dataset/formats.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,9 @@ namespace yggdrasil_decision_forests {
namespace dataset {

// File prefixes indicative of the file format.

// RFC 4180-compliant CSV file.
const char* const FORMAT_CSV = "csv";

// TFRecord of binary serialized TensorFlow.Example proto. Compressed with GZip.
const char* const FORMAT_TFE_TFRECORD = "tfrecord+tfe";
// TFRecord read by YDF directly. Currently, does not support GZip compression.
const char* const FORMAT_TFE_TFRECORDV2 = "tfrecordv2+tfe";

// Partially computed (e.g. non indexed) dataset cache.
const char* const FORMAT_PARTIAL_DATASET_CACHE = "partial_dataset_cache";

// Splits the format and path from a typed path.
Expand All @@ -50,7 +43,7 @@ std::pair<std::string, proto::DatasetFormat> GetDatasetPathAndType(

// Same as "GetDatasetPathAndType", but return a status in case of error.
absl::StatusOr<std::pair<std::string, proto::DatasetFormat>>
GetDatasetPathAndTypeOrStatus(const absl::string_view typed_path);
GetDatasetPathAndTypeOrStatus(absl::string_view typed_path);

// Tests if a string is a typed path.
bool IsTypedPath(absl::string_view maybe_typed_path);
Expand Down
3 changes: 3 additions & 0 deletions yggdrasil_decision_forests/dataset/formats.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@ syntax = "proto2";

package yggdrasil_decision_forests.dataset.proto;

// See dataset/formats.h for the definition of the formats.

// Supported dataset formats.
enum DatasetFormat {
// Default value; not a dataset format.
INVALID = 0;
// RFC 4180-compliant CSV file.
FORMAT_CSV = 1;
// Field numbers that must not be reused.
reserved 2, 3, 4, 6;
// GZip compressed TFRecord of binary serialized TensorFlow.Example protos,
// read with the TensorFlow API. Deprecated in favor of
// FORMAT_TFE_TFRECORD_COMPRESSED_V2.
FORMAT_TFE_TFRECORD = 5;
// Uncompressed TFRecord of binary serialized TensorFlow.Example protos. Does
// not require TensorFlow C++ to be linked.
FORMAT_TFE_TFRECORDV2 = 8;
// GZip compressed TFRecord of binary serialized TensorFlow.Example protos.
// Does not require TensorFlow C++ to be linked.
FORMAT_TFE_TFRECORD_COMPRESSED_V2 = 9;
// Partially computed (e.g. non indexed) dataset cache.
FORMAT_PARTIAL_DATASET_CACHE = 7;
}
3 changes: 2 additions & 1 deletion yggdrasil_decision_forests/dataset/tensorflow_no_dep/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ cc_library_ydf(
srcs = ["tf_record.cc"],
hdrs = ["tf_record.h"],
deps = [
"//yggdrasil_decision_forests/utils:bytestream",
"//yggdrasil_decision_forests/utils:filesystem",
"//yggdrasil_decision_forests/utils:logging",
"//yggdrasil_decision_forests/utils:protobuf",
"//yggdrasil_decision_forests/utils:sharded_io",
"//yggdrasil_decision_forests/utils:status_macros",
"//yggdrasil_decision_forests/utils:zlib",
"@com_google_absl//absl/base:endian",
"@com_google_absl//absl/crc:crc32c",
"@com_google_absl//absl/log",
Expand All @@ -33,7 +35,6 @@ cc_library_ydf(
srcs = ["tf_record_tf_example.cc"],
hdrs = ["tf_record_tf_example.h"],
deps = [
":tf_example_cc_proto",
":tf_record",
"//yggdrasil_decision_forests/dataset:data_spec_cc_proto",
"//yggdrasil_decision_forests/dataset:data_spec_inference",
Expand Down
18 changes: 14 additions & 4 deletions yggdrasil_decision_forests/dataset/tensorflow_no_dep/tf_record.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <stdint.h>

#include <memory>
#include <utility>

#include "absl/base/internal/endian.h"
#include "absl/crc/crc32c.h"
Expand All @@ -26,15 +27,19 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "yggdrasil_decision_forests/utils/bytestream.h"
#include "yggdrasil_decision_forests/utils/filesystem.h"
#include "yggdrasil_decision_forests/utils/logging.h"
#include "yggdrasil_decision_forests/utils/protobuf.h"
#include "yggdrasil_decision_forests/utils/zlib.h"

namespace yggdrasil_decision_forests::dataset::tensorflow_no_dep {
namespace {
constexpr char kInvalidDataMessage[] =
"The file is not a non-compressed TFRecord or it is corrupted. If "
"you have a compressed TFRecord, decompress it first.";
"The data is not a valid non-compressed TF Record. The data is either "
"corrupted or (more likely) a gzip compressed TFRecord. In this later "
"case, fix the type prefix in the filepath. For example, replace "
"'tfrecordv2+tfe:' with 'tfrecord:' (recommended) or ('tfrecord+tfe').";

static const uint32_t kMaskDelta = 0xa282ead8ul;

Expand Down Expand Up @@ -70,8 +75,13 @@ TFRecordReader::~TFRecordReader() {
}

absl::StatusOr<std::unique_ptr<TFRecordReader>> TFRecordReader::Create(
const absl::string_view path) {
ASSIGN_OR_RETURN(auto stream, file::OpenInputFile(path));
const absl::string_view path, bool compressed) {
ASSIGN_OR_RETURN(std::unique_ptr<utils::InputByteStream> stream,
file::OpenInputFile(path));
if (compressed) {
ASSIGN_OR_RETURN(stream,
utils::GZipInputByteStream::Create(std::move(stream)));
}
return absl::make_unique<TFRecordReader>(std::move(stream));
}

Expand Down
Loading

0 comments on commit 0b87800

Please sign in to comment.