diff --git a/yggdrasil_decision_forests/dataset/BUILD b/yggdrasil_decision_forests/dataset/BUILD index 3b7526a7..d42172b3 100644 --- a/yggdrasil_decision_forests/dataset/BUILD +++ b/yggdrasil_decision_forests/dataset/BUILD @@ -38,6 +38,7 @@ cc_library_ydf( deps = [ ":formats_cc_proto", "//yggdrasil_decision_forests/utils:logging", + "//yggdrasil_decision_forests/utils:status_macros", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -83,7 +84,9 @@ cc_library_ydf( ":example_reader_interface", ":formats", ":formats_cc_proto", + "//yggdrasil_decision_forests/utils:logging", "//yggdrasil_decision_forests/utils:status_macros", + "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", diff --git a/yggdrasil_decision_forests/dataset/example_reader.cc b/yggdrasil_decision_forests/dataset/example_reader.cc index 07a12dbd..8b73e31a 100644 --- a/yggdrasil_decision_forests/dataset/example_reader.cc +++ b/yggdrasil_decision_forests/dataset/example_reader.cc @@ -21,6 +21,7 @@ #include #include +#include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" @@ -28,6 +29,7 @@ #include "yggdrasil_decision_forests/dataset/example_reader_interface.h" #include "yggdrasil_decision_forests/dataset/formats.h" #include "yggdrasil_decision_forests/dataset/formats.pb.h" +#include "yggdrasil_decision_forests/utils/logging.h" #include "yggdrasil_decision_forests/utils/status_macros.h" namespace yggdrasil_decision_forests { @@ -56,6 +58,7 @@ absl::StatusOr> CreateExampleReader( absl::StatusOr IsFormatSupported(absl::string_view typed_path) { const auto path_format_or = GetDatasetPathAndTypeOrStatus(typed_path); if (!path_format_or.ok()) { + LOG(WARNING) << "Cannot parse typed path: " << path_format_or.status(); return false; } std::string sharded_path; diff --git a/yggdrasil_decision_forests/dataset/example_reader_test.cc b/yggdrasil_decision_forests/dataset/example_reader_test.cc index 7c1ec4be..806d4409 100644 --- a/yggdrasil_decision_forests/dataset/example_reader_test.cc +++ b/yggdrasil_decision_forests/dataset/example_reader_test.cc @@ -40,12 +40,19 @@ std::string DatasetDir() { "yggdrasil_decision_forests/" "test_data/dataset"); } - TEST(ExampleReader, CreateExampleReader) { - for (const auto& dataset_path : - {absl::StrCat("tfrecord+tfe:", - file::JoinPath(DatasetDir(), "toy.tfe-tfrecord@2")), - absl::StrCat("csv:", file::JoinPath(DatasetDir(), "toy.csv"))}) { + for (const auto& dataset_path : { + absl::StrCat("tfrecord+tfe:", + file::JoinPath(DatasetDir(), "toy.tfe-tfrecord@2")), + absl::StrCat("tfrecord:", + file::JoinPath(DatasetDir(), "toy.tfe-tfrecord@2")), + absl::StrCat("tfrecordv2+gz+tfe:", + file::JoinPath(DatasetDir(), "toy.tfe-tfrecord@2")), + absl::StrCat( + "tfrecordv2+tfe:", + file::JoinPath(DatasetDir(), "toy.nocompress-tfe-tfrecord@2")), + absl::StrCat("csv:", file::JoinPath(DatasetDir(), "toy.csv")), + }) { LOG(INFO) << "Create dataspec for " << dataset_path; proto::DataSpecificationGuide guide; proto::DataSpecification data_spec; diff --git a/yggdrasil_decision_forests/dataset/formats.cc b/yggdrasil_decision_forests/dataset/formats.cc index 39aa9790..ccc415c6 100644 --- a/yggdrasil_decision_forests/dataset/formats.cc +++ b/yggdrasil_decision_forests/dataset/formats.cc @@ -30,11 +30,66 @@ #include "absl/strings/substitute.h" #include "yggdrasil_decision_forests/dataset/formats.pb.h" #include "yggdrasil_decision_forests/utils/logging.h" +#include "yggdrasil_decision_forests/utils/status_macros.h" namespace yggdrasil_decision_forests { namespace dataset { -using proto::DatasetFormat; +struct Format { + absl::string_view extension; + absl::string_view prefix; + absl::string_view prefix_alias = ""; + proto::DatasetFormat proto_format; +}; + +const std::vector& GetFormats() { + static const std::vector* formats = []() { + auto* formats = new std::vector(); + + // RFC 4180-compliant CSV file. + formats->push_back({ + .extension = "csv", + .prefix = FORMAT_CSV, + .proto_format = proto::FORMAT_CSV, + }); + + // GZip compressed TF Record of binary serialized TensorFlow.Example proto. + // Use TensorFlow API to read the file (required TensorFlow C++ to be + // linked). Deprecated: Use FORMAT_TFE_TFRECORD_COMPRESSED_V2 instead. + formats->push_back({ + .extension = "tfrecord", + .prefix = FORMAT_TFE_TFRECORD, + .proto_format = proto::FORMAT_TFE_TFRECORD, + }); + + // Uncompressed TF Record of binary serialized TensorFlow.Example proto. + // Does not require TensorFlow C++ to be linked. + formats->push_back({ + .extension = "tfrecord", + .prefix = FORMAT_TFE_TFRECORDV2, + .proto_format = proto::FORMAT_TFE_TFRECORDV2, + }); + + // GZip compressed TF Record of binary serialized TensorFlow.Example proto. + // Does not require TensorFlow C++ to be linked. + formats->push_back({ + .extension = "tfrecord", + .prefix = "tfrecord", + .prefix_alias = "tfrecordv2+gz+tfe", + .proto_format = proto::FORMAT_TFE_TFRECORD_COMPRESSED_V2, + }); + + // Partially computed (e.g. non indexed) dataset cache. + formats->push_back({ + .extension = "partial_dataset_cache", + .prefix = FORMAT_PARTIAL_DATASET_CACHE, + .proto_format = proto::FORMAT_PARTIAL_DATASET_CACHE, + }); + + return formats; + }(); + return *formats; +} absl::StatusOr> SplitTypeAndPath( const absl::string_view typed_path) { @@ -72,70 +127,49 @@ std::pair GetDatasetPathAndType( return GetDatasetPathAndTypeOrStatus(typed_path).value(); } -absl::StatusOr> -GetDatasetPathAndTypeOrStatus(const absl::string_view typed_path) { - std::string path, prefix; - std::tie(prefix, path) = SplitTypeAndPath(typed_path).value(); - - static const google::protobuf::EnumDescriptor* enum_descriptor = - google::protobuf::GetEnumDescriptor(); - for (int format_idx = 0; format_idx < enum_descriptor->value_count(); - format_idx++) { - const auto format = static_cast( - enum_descriptor->value(format_idx)->number()); - if (format == proto::INVALID) { - continue; +std::string FormatToRecommendedExtension(proto::DatasetFormat proto_format) { + for (const auto& format : GetFormats()) { + if (format.proto_format == proto_format) { + return std::string(format.extension); } - if (DatasetFormatToPrefix(format) == prefix) { - return std::make_pair(std::string(path), format); + } + return ""; +} + +absl::StatusOr PrefixToFormat( + const absl::string_view prefix) { + for (const auto& format : GetFormats()) { + if (format.prefix == prefix || format.prefix_alias == prefix) { + return format.proto_format; } } return absl::InvalidArgumentError( - absl::StrCat("Unknown format \"", prefix, "\" in \"", typed_path, "\"")); + absl::StrCat("The format prefix \"", prefix, + "\" is unknown. Make sure the format reader is linked to " + "the binary.")); } -std::string FormatToRecommendedExtension(proto::DatasetFormat format) { - switch (format) { - case proto::INVALID: - LOG(FATAL) << "Invalid format"; - break; - case proto::FORMAT_CSV: - return "csv"; - case proto::FORMAT_TFE_TFRECORD: - return "tfrecord"; - case proto::FORMAT_TFE_TFRECORDV2: - return "tfrecordv2"; - case proto::FORMAT_PARTIAL_DATASET_CACHE: - return "partial_dataset_cache"; - } +absl::StatusOr> +GetDatasetPathAndTypeOrStatus(const absl::string_view typed_path) { + std::string path, prefix; + ASSIGN_OR_RETURN(std::tie(prefix, path), SplitTypeAndPath(typed_path)); + ASSIGN_OR_RETURN(const auto format, PrefixToFormat(prefix)); + return std::make_pair(std::string(path), format); } -std::string DatasetFormatToPrefix(proto::DatasetFormat format) { - switch (format) { - case proto::INVALID: - LOG(FATAL) << "Invalid format"; - break; - case proto::FORMAT_CSV: - return FORMAT_CSV; - case proto::FORMAT_TFE_TFRECORD: - return FORMAT_TFE_TFRECORD; - case proto::FORMAT_TFE_TFRECORDV2: - return FORMAT_TFE_TFRECORDV2; - case proto::FORMAT_PARTIAL_DATASET_CACHE: - return FORMAT_PARTIAL_DATASET_CACHE; +std::string DatasetFormatToPrefix(proto::DatasetFormat proto_format) { + for (const auto& format : GetFormats()) { + if (format.proto_format == proto_format) { + return std::string(format.prefix); + } } + return "unknown"; } std::string ListSupportedFormats() { std::vector supported_prefixes; - static const google::protobuf::EnumDescriptor* enum_descriptor = - google::protobuf::GetEnumDescriptor(); - for (int i = 0; i < enum_descriptor->value_count(); i++) { - const google::protobuf::EnumValueDescriptor* format_idx = enum_descriptor->value(i); - const auto format = static_cast(format_idx->number()); - if (format != proto::INVALID) { - supported_prefixes.push_back(DatasetFormatToPrefix(format)); - } + for (const auto& format : GetFormats()) { + supported_prefixes.push_back(std::string(format.prefix)); } return absl::StrJoin(supported_prefixes, ", "); } diff --git a/yggdrasil_decision_forests/dataset/formats.h b/yggdrasil_decision_forests/dataset/formats.h index ac0652cd..424ea870 100644 --- a/yggdrasil_decision_forests/dataset/formats.h +++ b/yggdrasil_decision_forests/dataset/formats.h @@ -32,16 +32,9 @@ namespace yggdrasil_decision_forests { namespace dataset { // File prefixes indicative of the file format. - -// RFC 4180-compliant CSV file. const char* const FORMAT_CSV = "csv"; - -// TFRecord of binary serialized TensorFlow.Example proto. Compressed with GZip. const char* const FORMAT_TFE_TFRECORD = "tfrecord+tfe"; -// TFRecord read by YDF directly. Currently, does not support GZip compression. const char* const FORMAT_TFE_TFRECORDV2 = "tfrecordv2+tfe"; - -// Partially computed (e.g. non indexed) dataset cache. const char* const FORMAT_PARTIAL_DATASET_CACHE = "partial_dataset_cache"; // Splits the format and path from a typed path. @@ -50,7 +43,7 @@ std::pair GetDatasetPathAndType( // Same as "GetDatasetPathAndType", but return a status in case of error. absl::StatusOr> -GetDatasetPathAndTypeOrStatus(const absl::string_view typed_path); +GetDatasetPathAndTypeOrStatus(absl::string_view typed_path); // Tests if a string is a typed path. bool IsTypedPath(absl::string_view maybe_typed_path); diff --git a/yggdrasil_decision_forests/dataset/formats.proto b/yggdrasil_decision_forests/dataset/formats.proto index b75ad879..2abfa595 100644 --- a/yggdrasil_decision_forests/dataset/formats.proto +++ b/yggdrasil_decision_forests/dataset/formats.proto @@ -17,6 +17,8 @@ syntax = "proto2"; package yggdrasil_decision_forests.dataset.proto; +// See dataset/formats.h for the definition of the formats. + // Supported dataset formats. enum DatasetFormat { INVALID = 0; @@ -24,5 +26,6 @@ enum DatasetFormat { reserved 2, 3, 4, 6; FORMAT_TFE_TFRECORD = 5; FORMAT_TFE_TFRECORDV2 = 8; + FORMAT_TFE_TFRECORD_COMPRESSED_V2 = 9; FORMAT_PARTIAL_DATASET_CACHE = 7; } diff --git a/yggdrasil_decision_forests/dataset/tensorflow_no_dep/BUILD b/yggdrasil_decision_forests/dataset/tensorflow_no_dep/BUILD index 5551b89a..3f2df690 100644 --- a/yggdrasil_decision_forests/dataset/tensorflow_no_dep/BUILD +++ b/yggdrasil_decision_forests/dataset/tensorflow_no_dep/BUILD @@ -13,11 +13,13 @@ cc_library_ydf( srcs = ["tf_record.cc"], hdrs = ["tf_record.h"], deps = [ + "//yggdrasil_decision_forests/utils:bytestream", "//yggdrasil_decision_forests/utils:filesystem", "//yggdrasil_decision_forests/utils:logging", "//yggdrasil_decision_forests/utils:protobuf", "//yggdrasil_decision_forests/utils:sharded_io", "//yggdrasil_decision_forests/utils:status_macros", + "//yggdrasil_decision_forests/utils:zlib", "@com_google_absl//absl/base:endian", "@com_google_absl//absl/crc:crc32c", "@com_google_absl//absl/log", @@ -33,7 +35,6 @@ cc_library_ydf( srcs = ["tf_record_tf_example.cc"], hdrs = ["tf_record_tf_example.h"], deps = [ - ":tf_example_cc_proto", ":tf_record", "//yggdrasil_decision_forests/dataset:data_spec_cc_proto", "//yggdrasil_decision_forests/dataset:data_spec_inference", diff --git a/yggdrasil_decision_forests/dataset/tensorflow_no_dep/tf_record.cc b/yggdrasil_decision_forests/dataset/tensorflow_no_dep/tf_record.cc index 15af97fc..c89de048 100644 --- a/yggdrasil_decision_forests/dataset/tensorflow_no_dep/tf_record.cc +++ b/yggdrasil_decision_forests/dataset/tensorflow_no_dep/tf_record.cc @@ -18,6 +18,7 @@ #include #include +#include #include "absl/base/internal/endian.h" #include "absl/crc/crc32c.h" @@ -26,15 +27,19 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "yggdrasil_decision_forests/utils/bytestream.h" #include "yggdrasil_decision_forests/utils/filesystem.h" #include "yggdrasil_decision_forests/utils/logging.h" #include "yggdrasil_decision_forests/utils/protobuf.h" +#include "yggdrasil_decision_forests/utils/zlib.h" namespace yggdrasil_decision_forests::dataset::tensorflow_no_dep { namespace { constexpr char kInvalidDataMessage[] = - "The file is not a non-compressed TFRecord or it is corrupted. If " - "you have a compressed TFRecord, decompress it first."; + "The data is not a valid non-compressed TF Record. The data is either " + "corrupted or (more likely) a gzip compressed TFRecord. In this later " + "case, fix the type prefix in the filepath. For example, replace " + "'tfrecordv2+tfe:' with 'tfrecord:' (recommended) or ('tfrecord+tfe')."; static const uint32_t kMaskDelta = 0xa282ead8ul; @@ -70,8 +75,13 @@ TFRecordReader::~TFRecordReader() { } absl::StatusOr> TFRecordReader::Create( - const absl::string_view path) { - ASSIGN_OR_RETURN(auto stream, file::OpenInputFile(path)); + const absl::string_view path, bool compressed) { + ASSIGN_OR_RETURN(std::unique_ptr stream, + file::OpenInputFile(path)); + if (compressed) { + ASSIGN_OR_RETURN(stream, + utils::GZipInputByteStream::Create(std::move(stream))); + } return absl::make_unique(std::move(stream)); } diff --git a/yggdrasil_decision_forests/dataset/tensorflow_no_dep/tf_record.h b/yggdrasil_decision_forests/dataset/tensorflow_no_dep/tf_record.h index 044014a6..373b3ce1 100644 --- a/yggdrasil_decision_forests/dataset/tensorflow_no_dep/tf_record.h +++ b/yggdrasil_decision_forests/dataset/tensorflow_no_dep/tf_record.h @@ -16,7 +16,6 @@ #ifndef YGGDRASIL_DECISION_FORESTS_DATASET_TENSORFLOW_NO_DEP_TF_RECORD_H_ #define YGGDRASIL_DECISION_FORESTS_DATASET_TENSORFLOW_NO_DEP_TF_RECORD_H_ -#include #include #include @@ -26,6 +25,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "yggdrasil_decision_forests/utils/bytestream.h" #include "yggdrasil_decision_forests/utils/filesystem.h" #include "yggdrasil_decision_forests/utils/protobuf.h" #include "yggdrasil_decision_forests/utils/sharded_io.h" @@ -39,7 +39,7 @@ class TFRecordReader { public: // Opens a TFRecord for reading. static absl::StatusOr> Create( - absl::string_view path); + absl::string_view path, bool compressed = false); ~TFRecordReader(); @@ -53,7 +53,7 @@ class TFRecordReader { // Closes the stream. absl::Status Close(); - TFRecordReader(std::unique_ptr&& stream) + TFRecordReader(std::unique_ptr&& stream) : stream_(std::move(stream)) {} // Value of the last read record. Includes skipped messages. @@ -63,7 +63,7 @@ class TFRecordReader { // Reads a CRC. absl::StatusOr ReadCRC(); - std::unique_ptr stream_; + std::unique_ptr stream_; std::string buffer_; }; @@ -71,12 +71,13 @@ class TFRecordReader { template class ShardedTFRecordReader : public utils::ShardedReader { public: - ShardedTFRecordReader() = default; + ShardedTFRecordReader(bool compressed = false) : compressed_(compressed) {}; absl::Status OpenShard(absl::string_view path) override; absl::StatusOr NextInShard(T* example) override; private: std::unique_ptr reader_; + bool compressed_; DISALLOW_COPY_AND_ASSIGN(ShardedTFRecordReader); }; @@ -127,7 +128,7 @@ absl::Status ShardedTFRecordReader::OpenShard(const absl::string_view path) { RETURN_IF_ERROR(reader_->Close()); reader_.reset(); } - ASSIGN_OR_RETURN(reader_, TFRecordReader::Create(path)); + ASSIGN_OR_RETURN(reader_, TFRecordReader::Create(path, compressed_)); return absl::OkStatus(); } diff --git a/yggdrasil_decision_forests/dataset/tensorflow_no_dep/tf_record_test.cc b/yggdrasil_decision_forests/dataset/tensorflow_no_dep/tf_record_test.cc index 3309c62a..d19c7b4f 100644 --- a/yggdrasil_decision_forests/dataset/tensorflow_no_dep/tf_record_test.cc +++ b/yggdrasil_decision_forests/dataset/tensorflow_no_dep/tf_record_test.cc @@ -104,6 +104,30 @@ TEST(TFRecord, Reader) { EXPECT_EQ(message_idx, 3); } +TEST(TFRecord, ReaderCompressed) { + ASSERT_OK_AND_ASSIGN( + auto reader, + TFRecordReader::Create( + file::JoinPath(DatasetDir(), "toy.tfe-tfrecord-00000-of-00002"), + /*compressed=*/true)); + + int message_idx = 0; + while (true) { + tensorflow::Example message; + ASSERT_OK_AND_ASSIGN(const bool has_value, reader->Next(&message)); + if (!has_value) { + break; + } + LOG(INFO) << message.DebugString(); + if (message_idx == 3) { + EXPECT_THAT(message, EqualsProto(ThirdExample())); + } + message_idx++; + } + ASSERT_OK(reader->Close()); + EXPECT_EQ(message_idx, 3); +} + TEST(TFRecord, ShardedReader) { ShardedTFRecordReader reader; ASSERT_OK(reader.Open( diff --git a/yggdrasil_decision_forests/dataset/tensorflow_no_dep/tf_record_tf_example.h b/yggdrasil_decision_forests/dataset/tensorflow_no_dep/tf_record_tf_example.h index 476d4a53..786f6887 100644 --- a/yggdrasil_decision_forests/dataset/tensorflow_no_dep/tf_record_tf_example.h +++ b/yggdrasil_decision_forests/dataset/tensorflow_no_dep/tf_record_tf_example.h @@ -34,6 +34,8 @@ using TFRecordV2TFExampleReader = ShardedTFRecordReader; REGISTER_AbstractTFExampleReader(TFRecordV2TFExampleReader, "FORMAT_TFE_TFRECORDV2"); +// Non-compressed TFRecord. + class TFRecordV2TFEToExampleReaderInterface : public TFExampleReaderToExampleReader { public: @@ -59,6 +61,34 @@ class TFRecordV2TFExampleReaderToDataSpecCreator REGISTER_AbstractDataSpecCreator(TFRecordV2TFExampleReaderToDataSpecCreator, "FORMAT_TFE_TFRECORDV2"); +// Compressed TFRecord. + +class TFRecordCompressedV2TFEToExampleReaderInterface + : public TFExampleReaderToExampleReader { + public: + TFRecordCompressedV2TFEToExampleReaderInterface( + const proto::DataSpecification& data_spec, + absl::optional> ensure_non_missing) + : TFExampleReaderToExampleReader(data_spec, ensure_non_missing) {} + + std::unique_ptr CreateReader() override { + return absl::make_unique(true); + } +}; +REGISTER_ExampleReaderInterface(TFRecordCompressedV2TFEToExampleReaderInterface, + "FORMAT_TFE_TFRECORD_COMPRESSED_V2"); + +class TFRecordCompressedV2TFExampleReaderToDataSpecCreator + : public TFExampleReaderToDataSpecCreator { + std::unique_ptr CreateReader() override { + return absl::make_unique(true); + } +}; + +REGISTER_AbstractDataSpecCreator( + TFRecordCompressedV2TFExampleReaderToDataSpecCreator, + "FORMAT_TFE_TFRECORD_COMPRESSED_V2"); + // Write tf.Examples in TFRecords. class TFRecordV2TFExampleWriter : public ShardedTFRecordWriter {}; diff --git a/yggdrasil_decision_forests/port/python/ydf/learner/learner_test.py b/yggdrasil_decision_forests/port/python/ydf/learner/learner_test.py index 9dceba52..214c108d 100644 --- a/yggdrasil_decision_forests/port/python/ydf/learner/learner_test.py +++ b/yggdrasil_decision_forests/port/python/ydf/learner/learner_test.py @@ -156,6 +156,28 @@ def test_adult_classification(self): logging.info("Evaluation: %s", evaluation) self.assertGreaterEqual(evaluation.accuracy, 0.864) + def test_adult_classification_on_tfrecord_dataset(self): + learner = specialized_learners.RandomForestLearner(label="income") + model = learner.train( + "tfrecord:" + + os.path.join( + test_utils.ydf_test_data_path(), + "dataset", + "adult_train.recordio.gz", + ) + ) + logging.info("Trained model: %s", model) + + # Evaluate the trained model. + evaluation = model.evaluate( + "tfrecord:" + + os.path.join( + test_utils.ydf_test_data_path(), "dataset", "adult_test.recordio.gz" + ) + ) + logging.info("Evaluation: %s", evaluation) + self.assertGreaterEqual(evaluation.accuracy, 0.864) + def test_two_center_regression(self): learner = specialized_learners.RandomForestLearner( label="target", task=generic_learner.Task.REGRESSION