diff --git a/Makefile b/Makefile index 237c72a0c1..5ce7724da2 100644 --- a/Makefile +++ b/Makefile @@ -537,7 +537,7 @@ test: $(BIN_DIR)/$(EXE) $(LIB_DIR)/libvg.a test/build_graph $(BIN_DIR)/shuf $(BI # Somebody has been polluting the test directory with temporary files that are not deleted after the tests. # To make git status more useful, we delete everything that looks like a temporary file. -clean-test: +clean-tests: cd test && rm -rf tmp && mkdir tmp && mv 2_2.mat build_graph.cpp default.mat tmp && rm -f *.* && mv tmp/* . && rmdir tmp docs: $(SRC_DIR)/*.cpp $(SRC_DIR)/*.hpp $(ALGORITHMS_SRC_DIR)/*.cpp $(ALGORITHMS_SRC_DIR)/*.hpp $(SUBCOMMAND_SRC_DIR)/*.cpp $(SUBCOMMAND_SRC_DIR)/*.hpp $(UNITTEST_SRC_DIR)/*.cpp $(UNITTEST_SRC_DIR)/*.hpp $(UNITTEST_SUPPORT_SRC_DIR)/*.cpp diff --git a/deps/gbwt b/deps/gbwt index 2aab62f066..99daaccf4c 160000 --- a/deps/gbwt +++ b/deps/gbwt @@ -1 +1 @@ -Subproject commit 2aab62f0664b2ce8eb4cffd6b3872c89b85fdfa6 +Subproject commit 99daaccf4c8141e5e28680be031a481eb843e2c6 diff --git a/deps/gbwtgraph b/deps/gbwtgraph index 9fa7f6fa82..ebb77a18c1 160000 --- a/deps/gbwtgraph +++ b/deps/gbwtgraph @@ -1 +1 @@ -Subproject commit 9fa7f6fa82f329399a99c1e84573bdae96940d1b +Subproject commit ebb77a18c1fd5d13bb9ac3cc1f6fc75807b742f2 diff --git a/src/gaf_sorter.cpp b/src/gaf_sorter.cpp new file mode 100644 index 0000000000..12d158b0fd --- /dev/null +++ b/src/gaf_sorter.cpp @@ -0,0 +1,564 @@ +#include "gaf_sorter.hpp" + +#include +#include +#include +#include +#include +#include + +// Needed for the temporary file creation. +#include "utility.hpp" + +// For reading and writing compressed temporary files. +#include "zstdutil.hpp" + +// For reading compressed input. +#include +#include + +namespace vg { + +//------------------------------------------------------------------------------ + +// Public class constants. 
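+// Pre-C++17, odr-used static constexpr members still need out-of-class definitions like these.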
+ +constexpr std::uint64_t GAFSorterRecord::MISSING_KEY; +const std::string GAFSorterRecord::GBWT_OFFSET_TAG = "GB:i:"; + +constexpr size_t GAFSorterParameters::THREADS; +constexpr size_t GAFSorterParameters::RECORDS_PER_FILE; +constexpr size_t GAFSorterParameters::FILES_PER_MERGE; +constexpr size_t GAFSorterParameters::BUFFER_SIZE; + +//------------------------------------------------------------------------------ + +void GAFSorterRecord::set_key(key_type type) { + if (type == key_node_interval) { + std::uint32_t min_id = std::numeric_limits::max(); + std::uint32_t max_id = 0; + str_view path = this->get_field(PATH_FIELD); + size_t start = 1; + while (start < path.size) { + std::uint32_t id = 0; + auto result = std::from_chars(path.data + start, path.data + path.size, id); + if (result.ec != std::errc()) { + this->key = MISSING_KEY; + return; + } + min_id = std::min(min_id, id); + max_id = std::max(max_id, id); + start = (result.ptr - path.data) + 1; + } + if (min_id == std::numeric_limits::max()) { + this->key = MISSING_KEY; + } else { + this->key = (static_cast(min_id) << 32) | max_id; + } + } else if (type == key_gbwt_pos) { + std::uint32_t node_id = std::numeric_limits::max(); + std::uint32_t offset = std::numeric_limits::max(); + this->for_each_field([&](size_t i, str_view value) -> bool { + if (i == PATH_FIELD && value.size > 1) { + auto result = std::from_chars(value.data + 1, value.data + value.size, node_id); + if (result.ec != std::errc()) { + return false; + } + } else if (i >= MANDATORY_FIELDS) { + size_t tag_size = GBWT_OFFSET_TAG.size(); + if (value.size > tag_size && value.substr(0, tag_size) == GBWT_OFFSET_TAG) { + auto result = std::from_chars(value.data + tag_size, value.data + value.size, offset); + return false; + } + } + return true; + }); + if (node_id == std::numeric_limits::max() || offset == std::numeric_limits::max()) { + // We either did not find both fields or failed to parse them. 
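+            // MISSING_KEY sorts these records to the end of the output.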
+ this->key = MISSING_KEY; + } else { + this->key = (static_cast(node_id) << 32) | offset; + } + } else if (type == key_hash) { + this->key = hasher(this->value); + } else { + this->key = MISSING_KEY; + } +} + +bool GAFSorterRecord::serialize(std::ostream& out) const { + bool success = true; + out.write(reinterpret_cast(&this->key), sizeof(this->key)); + success &= out.good(); + std::uint64_t length = this->value.size(); + out.write(reinterpret_cast(&length), sizeof(length)); + success &= out.good(); + out.write(this->value.data(), length); + success &= out.good(); + return success; +} + +bool GAFSorterRecord::write_line(std::ostream& out) const { + out << this->value << '\n'; + return out.good(); +} + +bool GAFSorterRecord::deserialize(std::istream& in) { + bool success = true; + in.read(reinterpret_cast(&this->key), sizeof(this->key)); + success &= in.good(); + std::uint64_t length = 0; + in.read(reinterpret_cast(&length), sizeof(length)); + success &= in.good(); + this->value.resize(length); + in.read(&this->value[0], length); + success &= in.good(); + return success; +} + +bool GAFSorterRecord::read_line(std::istream& in, key_type type) { + std::getline(in, this->value); + if (in.eof()) { + return false; + } + this->set_key(type); + return true; +} + +str_view GAFSorterRecord::get_field(size_t field) const { + str_view result; + this->for_each_field([&](size_t i, str_view value) -> bool { + if (i == field) { + result = value; + return false; + } + return true; + }); + return result; +} + +void GAFSorterRecord::for_each_field(const std::function& lambda) const { + size_t start = 0, end = 0; + size_t i = 0; + while (end != std::string::npos) { + end = this->value.find('\t', start); + if (!lambda(i, str_view(this->value).substr(start, end - start))) { + break; + } + start = end + 1; + i++; + } +} + +//------------------------------------------------------------------------------ + +GAFSorterFile::GAFSorterFile() : + records(0), + temporary(true), compressed(true), raw_gaf(false), removed(false), ok(true) { + this->name = temp_file::create("gaf-sorter"); +} + +GAFSorterFile::GAFSorterFile(const std::string& name) : + name(name), records(0), + temporary(false), compressed(false), raw_gaf(true), removed(false), ok(true) { +} + +GAFSorterFile::~GAFSorterFile() { + this->remove_temporary(); +} + +std::pair> GAFSorterFile::open_output() { + std::pair> result; + if (this->is_std_in_out()) { + result.first = &std::cout; + } else if (this->compressed) { + result.second.reset(new zstd_ofstream(this->name)); + result.first = result.second.get(); + } else { + result.second.reset(new std::ofstream(this->name, std::ios::binary)); + result.first = result.second.get(); + } + if (!result.first->good()) { + this->ok = false; + std::cerr << "error: [gaf_sorter] could not open output file " << this->name << std::endl; + } + return result; +} + +std::pair> GAFSorterFile::open_input() { + std::pair> result; + if (this->is_std_in_out()) { + result.first = &std::cin; + } else if (this->compressed) { + result.second.reset(new zstd_ifstream(this->name)); + result.first = result.second.get(); + } else { + result.second.reset(new std::ifstream(this->name, std::ios::binary)); + result.first = result.second.get(); + } + if (!result.first->good()) { + this->ok = false; + std::cerr << "error: [gaf_sorter] could not open input file " << this->name << std::endl; + } + return result; +} + +void GAFSorterFile::remove_temporary() { + if (this->temporary && !this->removed) { + temp_file::remove(this->name); + this->removed = 
true; + this->ok = false; + } +} + +//------------------------------------------------------------------------------ + +bool sort_gaf(const std::string& input_file, const std::string& output_file, const GAFSorterParameters& params) { + // Timestamp for the start and the total number of records. + auto start_time = std::chrono::high_resolution_clock::now(); + size_t total_records = 0; + auto report_time = [&]() { + if (params.progress) { + auto end_time = std::chrono::high_resolution_clock::now(); + double seconds = std::chrono::duration(end_time - start_time).count(); + std::cerr << "Sorted " << total_records << " records in " << seconds << " seconds" << std::endl; + } + }; + + // Worker threads. + size_t num_threads = std::max(params.threads, size_t(1)); + if (params.progress) { + std::cerr << "Sorting GAF records with " << num_threads << " worker threads" << std::endl; + } + std::vector threads(num_threads); + auto join_all = [&]() { + for (std::thread& thread : threads) { + if (thread.joinable()) { + thread.join(); + } + } + }; + + // Temporary output files. + std::vector> files; + auto check_all_files = [&]() -> bool { + bool all_ok = true; + for (std::unique_ptr& file : files) { + all_ok &= file->ok; + } + return all_ok; + }; + + // Initial sort. If a worker thread fails, we break on join. + size_t batch = 0; + size_t initial_batch_size = std::max(params.records_per_file, size_t(1)); + if (params.progress) { + std::cerr << "Initial sort: " << initial_batch_size << " records per file" << std::endl; + } + std::string peek; + htsFile* input = hts_open(input_file.c_str(), "r"); + if (input == nullptr) { + std::cerr << "error: [gaf_sorter] could not open input file " << input_file << std::endl; + return false; + } + while (true) { + // Read the next batch. + std::unique_ptr> lines(new std::vector()); + lines->reserve(initial_batch_size); + if (!peek.empty()) { + lines->push_back(std::move(peek)); + peek.clear(); + } + kstring_t s_buffer = KS_INITIALIZE; + std::string line; + while (lines->size() < initial_batch_size && hts_getline(input, '\n', &s_buffer) >= 0) { + lines->push_back(std::string(ks_str(&s_buffer), ks_len(&s_buffer))); + } + total_records += lines->size(); + + // Peek at the first line of the next batch to determine if there is only one batch. + if (batch == 0) { + if (hts_getline(input, '\n', &s_buffer) < 0) { + if (params.progress) { + std::cerr << "Sorting directly to the final output" << std::endl; + } + std::unique_ptr out(new GAFSorterFile(output_file)); + sort_gaf_lines(std::move(lines), params.key_type, params.stable, std::ref(*out)); + ks_free(&s_buffer); + hts_close(input); + if (out->ok) { + report_time(); + return true; + } else { + return false; + } + } + peek = std::string(ks_str(&s_buffer), ks_len(&s_buffer)); + } + ks_free(&s_buffer); + if (lines->empty()) { + break; + } + + // Sort the batch to a temporary file. + std::unique_ptr out(new GAFSorterFile()); + size_t thread_id = batch % num_threads; + if (threads[thread_id].joinable()) { + threads[thread_id].join(); + if (!files[batch - num_threads]->ok) { + break; + } + } + threads[thread_id] = std::thread(sort_gaf_lines, std::move(lines), params.key_type, params.stable, std::ref(*out)); + files.push_back(std::move(out)); + batch++; + } + hts_close(input); + join_all(); + if (!check_all_files()) { + return false; + } + if (params.progress) { + std::cerr << "Initial sort finished with " << total_records << " records in " << files.size() << " files" << std::endl; + } + + // Intermediate merges. 
If a worker thread fails, we break on join. + size_t files_per_merge = std::max(params.files_per_merge, size_t(2)); + size_t round = 0; + while (files.size() > files_per_merge) { + if (params.progress) { + std::cerr << "Round " << round << ": " << files_per_merge << " files per batch" << std::endl; + } + std::vector> next_files; + batch = 0; + for (size_t i = 0; i < files.size(); i += files_per_merge) { + if (i + 1 == files.size()) { + // If we have a single file left, just move it to the next round. + next_files.push_back(std::move(files[i])); + continue; + } + std::unique_ptr> batch_files(new std::vector()); + for (size_t j = 0; j < files_per_merge && i + j < files.size(); j++) { + batch_files->push_back(std::move(*(files[i + j]))); + } + std::unique_ptr out(new GAFSorterFile()); + size_t thread_id = batch % num_threads; + if (threads[thread_id].joinable()) { + threads[thread_id].join(); + if (!next_files[batch - num_threads]->ok) { + break; + } + } + threads[thread_id] = std::thread(merge_gaf_records, std::move(batch_files), std::ref(*out), params.buffer_size); + next_files.push_back(std::move(out)); + batch++; + } + join_all(); + files = std::move(next_files); + if (!check_all_files()) { + return false; + } + if (params.progress) { + std::cerr << "Round " << round << " finished with " << files.size() << " files" << std::endl; + } + round++; + } + + // Final merge. + { + if (params.progress) { + std::cerr << "Starting the final merge" << std::endl; + } + GAFSorterFile out(output_file); + std::unique_ptr> inputs(new std::vector()); + for (std::unique_ptr& file : files) { + inputs->push_back(std::move(*file)); + } + merge_gaf_records(std::move(inputs), out, params.buffer_size); + if (out.ok) { + report_time(); + return true; + } else { + return false; + } + } +} + +//------------------------------------------------------------------------------ + +void sort_gaf_lines( + std::unique_ptr> lines, + GAFSorterRecord::key_type key_type, + bool stable, + GAFSorterFile& output +) { + if (lines == nullptr) { + output.ok = false; + std::cerr << "error: [gaf_sorter] sort_gaf_lines() called with null lines" << std::endl; + return; + } + if (!output.ok || output.records > 0) { + output.ok = false; + std::cerr << "error: [gaf_sorter] sort_gaf_lines() called with an invalid output file" << std::endl; + return; + } + + // Convert the lines into GAFSorterRecord objects. + std::vector records; + records.reserve(lines->size()); + for (std::string& line : *lines) { + records.emplace_back(std::move(line), key_type); + } + lines.reset(); + + // Sort the records. + if (stable) { + std::stable_sort(records.begin(), records.end()); + } else { + std::sort(records.begin(), records.end()); + } + + // Write the sorted records to the output file. + auto out = output.open_output(); + if (!output.ok) { + // open_output() already prints an error message. 
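+        // The success flag is already false, so the caller will see the failure.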
+ return; + } + for (GAFSorterRecord& record : records) { + output.write(record, *out.first); + } + out.second.reset(); +} + +//------------------------------------------------------------------------------ + +void merge_gaf_records(std::unique_ptr> inputs, GAFSorterFile& output, size_t buffer_size) { + if (inputs == nullptr) { + output.ok = false; + std::cerr << "error: [gaf_sorter] merge_gaf_records() called with null inputs" << std::endl; + return; + } + if (buffer_size == 0) { + buffer_size = 1; + } + for (GAFSorterFile& input : *inputs) { + if (!input.ok || input.raw_gaf) { + output.ok = false; + std::cerr << "error: [gaf_sorter] merge_gaf_records() called an invalid input file" << std::endl; + return; + } + } + if (!output.ok || output.records > 0) { + output.ok = false; + std::cerr << "error: [gaf_sorter] merge_gaf_records called() with an invalid output file" << std::endl; + return; + } + + // Open the input files. + std::vector>> in; in.reserve(inputs->size()); + std::vector remaining; remaining.reserve(inputs->size()); + for (GAFSorterFile& input : *inputs) { + in.emplace_back(input.open_input()); + remaining.push_back(input.records); + if (!input.ok) { + // open_input() already prints an error message. + output.ok = false; + return; + } + } + + // Open the output file. + auto out = output.open_output(); + if (!output.ok) { + // open_output() already prints an error message. + return; + } + + // Input buffers. + std::vector> records; + records.resize(in.size()); + auto read_buffer = [&](size_t i) { + size_t count = std::min(buffer_size, remaining[i]); + if (count > 0) { + records[i].clear(); + for (size_t j = 0; j < count; j++) { + records[i].emplace_back(); + (*inputs)[i].read(records[i].back(), *(in[i].first)); + records[i].back().flip_key(); // Flip for the priority queue. + } + remaining[i] -= count; + if (!(*inputs)[i].ok) { + output.ok = false; + } + } + }; + for (size_t i = 0; i < in.size(); i++) { + read_buffer(i); + } + if (!output.ok) { + std::cerr << "error: [gaf_sorter] merge_gaf_records() failed to read the initial buffers" << std::endl; + return; + } + + // Output buffer. + std::vector buffer; + buffer.reserve(buffer_size); + auto write_buffer = [&]() { + for (GAFSorterRecord& record : buffer) { + output.write(record, *out.first); + } + buffer.clear(); + }; + + // Merge loop. + std::priority_queue> queue; + for (size_t i = 0; i < records.size(); i++) { + if (!records[i].empty()) { + queue.emplace(records[i].front(), i); + records[i].pop_front(); + } + } + while (!queue.empty()) { + GAFSorterRecord record = std::move(queue.top().first); + record.flip_key(); // Restore the original key. + size_t source = queue.top().second; + queue.pop(); + buffer.push_back(std::move(record)); + if (buffer.size() >= buffer_size) { + write_buffer(); + if (!output.ok) { + std::cerr << "error: [gaf_sorter] merge_gaf_records() failed to write to " << output.name << std::endl; + return; + } + } + if (records[source].empty()) { + read_buffer(source); + if (!output.ok) { + std::cerr << "error: [gaf_sorter] merge_gaf_records() failed to read from " << (*inputs)[source].name << std::endl; + return; + } + } + if (!records[source].empty()) { + queue.emplace(records[source].front(), source); + records[source].pop_front(); + } + } + if (!buffer.empty()) { + write_buffer(); + if (!output.ok) { + std::cerr << "error: [gaf_sorter] merge_gaf_records() failed to write to " << output.name << std::endl; + return; + } + } + + // Close the files. 
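+    // The unique_ptr in each pair owns the stream, so resetting it closes the temporary file.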
+ for (size_t i = 0; i < in.size(); i++) { + in[i].first = nullptr; + in[i].second.reset(); + } + out.second.reset(); +} + +//------------------------------------------------------------------------------ + +} // namespace vg diff --git a/src/gaf_sorter.hpp b/src/gaf_sorter.hpp new file mode 100644 index 0000000000..e760eb8f9d --- /dev/null +++ b/src/gaf_sorter.hpp @@ -0,0 +1,295 @@ +#ifndef VG_GAF_SORTER_HPP_INCLUDED +#define VG_GAF_SORTER_HPP_INCLUDED + +/** \file + * Tools for sorting GAF records. + * + * TODO: This could be an independent utility. + * TODO: Asynchronous I/O. + * TODO: Option for automatic detection of merge width to guarantee <= 2 rounds. + * TODO: Switch to std::string_view when we can. + */ + +#include +#include +#include +#include +#include +#include +#include +// #include +#include + +namespace vg { + +//------------------------------------------------------------------------------ + +/** + * This should be std::string_view, but apparently we are still using C++14 in Linux. + */ +struct str_view { + const char* data; + size_t size; + + str_view() : data(nullptr), size(0) {} + str_view(const char* data, size_t size) : data(data), size(size) {} + str_view(const std::string& str) : data(str.data()), size(str.size()) {} + + bool empty() const { return (this->size == 0); } + + char operator[](size_t i) const { return this->data[i]; } + + str_view substr(size_t start, size_t length) const { + return str_view(this->data + start, length); + } + + bool operator==(const str_view& another) const { + return (this->size == another.size && std::equal(this->data, this->data + this->size, another.data)); + } + + bool operator==(const std::string& another) const { + return (this->size == another.size() && std::equal(this->data, this->data + this->size, another.begin())); + } + + std::string to_string() const { return std::string(this->data, this->size); } +}; + +/** + * A record corresponding to a single line (alignment) in a GAF file. + * The record contains an integer key and the original line. + * Various types of keys can be derived from the value, but the line is not + * parsed beyond that. + */ +struct GAFSorterRecord { + /// Integer key. + std::uint64_t key; + + /// GAF line. + std::string value; + + /// Hasher used for random shuffling. + static std::hash hasher; + + /// Missing key. Records without a key are sorted to the end. + constexpr static std::uint64_t MISSING_KEY = std::numeric_limits::max(); + + /// Node offset for the GBWT starting position of the forward orientation + /// may be stored in this tag. + const static std::string GBWT_OFFSET_TAG; // "GB:i:" + + /// Types of keys that can be derived from the value. + enum key_type { + /// (minimum node id, maximum node id) in the path. + key_node_interval, + /// GBWT starting position for the forward orientation. + /// Derived from the path and tag "GB:i:". + key_gbwt_pos, + /// Hash of the value for random shuffling. + key_hash, + }; + + /// Default constructor. + GAFSorterRecord() : key(MISSING_KEY) {} + + /// Constructor that consumes the given value and sets the key. + GAFSorterRecord(std::string&& value, key_type type) : key(MISSING_KEY), value(std::move(value)) { + this->set_key(type); + } + + /// Records are sorted by key in ascending order. + bool operator<(const GAFSorterRecord& another) const { + return (this->key < another.key); + } + + /// Flips they key to reverse the order. + /// Sorting is based on ascending order, while priority queues return the largest element first. 
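+    /// Calling flip_key() twice restores the original key.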
+ void flip_key() { + this->key = std::numeric_limits::max() - this->key; + } + + /// Sets a key of the given type, or MISSING_KEY if the key cannot be derived. + void set_key(key_type type); + + /// Serializes the record to a stream. Returns true on success. + bool serialize(std::ostream& out) const; + + /// Writes the underlying GAF line to a stream. Returns true on success. + bool write_line(std::ostream& out) const; + + /// Deserializes the record from a stream. Returns true on success. + bool deserialize(std::istream& in); + + /// Reads a GAF line from a stream and sets the key. Returns true on success. + bool read_line(std::istream& in, key_type type); + + /// Returns a view of the given 0-based field, or an empty string if the field is missing. + str_view get_field(size_t field) const; + + /// Calls the given function with a 0-based field index and the field value. + /// Stops if the function returns false. + void for_each_field(const std::function& lambda) const; + +private: + constexpr static size_t PATH_FIELD = 5; + constexpr static size_t MANDATORY_FIELDS = 12; +}; + +//------------------------------------------------------------------------------ + +/** + * A file of GAFSorterRecords or GAF lines. + * + * The records are sorted in increasing order by key. + * The object is movable but not copyable. + */ +struct alignas(128) GAFSorterFile { + /// File name. + std::string name; + + /// Number of records. + size_t records; + + /// Is this a temporary file created with temp_file::create()? + bool temporary; + + /// Is this file compressed? + bool compressed; + + /// Is this a raw GAF file? + bool raw_gaf; + + /// Has the file been removed? + bool removed; + + /// Success flag. + bool ok; + + /// Default constructor that creates a compressed temporary file. + GAFSorterFile(); + + /// Constructor that creates a raw GAF file with the given name. + explicit GAFSorterFile(const std::string& name); + + /// If the file is temporary, the destructor removes the file. + ~GAFSorterFile(); + + GAFSorterFile(const GAFSorterFile&) = delete; + GAFSorterFile& operator=(const GAFSorterFile&) = delete; + GAFSorterFile(GAFSorterFile&&) = default; + GAFSorterFile& operator=(GAFSorterFile&&) = default; + + /// Returns an output stream to the file. + /// The first return value is the actual stream. + /// The second return value is a unique pointer which may contain a newly created stream. + /// Sets the success flag. + std::pair> open_output(); + + /// Writes the record to the file. + /// Updates the number of records and sets the success flag. + void write(const GAFSorterRecord& record, std::ostream& out) { + this->ok &= (this->raw_gaf ? record.write_line(out) : record.serialize(out)); + this->records++; + } + + /// Returns an input stream to the file. + /// The first return value is the actual stream. + /// The second return value is a unique pointer which may contain a newly created stream. + /// Sets the success flag. + std::pair> open_input(); + + /// Reads the next record from the file, assuming that this is not a raw GAF file. + /// Sets the success flag. + void read(GAFSorterRecord& record, std::istream& in) { + this->ok &= (this->raw_gaf ? false : record.deserialize(in)); + } + + /// Returns true if the file is actually stdin/stdout. + /// In that case, open_input() should not be called. + bool is_std_in_out() const { + return (this->name == "-"); + } + + /// Removes the file if it is temporary. 
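+    /// Also clears the success flag, as the records can no longer be read back.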
+ void remove_temporary(); +}; + +//------------------------------------------------------------------------------ + +/** + * Parameters for the GAF sorter. + */ +struct GAFSorterParameters { + /// Default for threads. + constexpr static size_t THREADS = 1; + + /// Default for records_per_file. + constexpr static size_t RECORDS_PER_FILE = 1000000; + + /// Default for files_per_merge. + constexpr static size_t FILES_PER_MERGE = 32; + + /// Default for buffer_size. + constexpr static size_t BUFFER_SIZE = 1000; + + /// Key type used for sorting. + GAFSorterRecord::key_type key_type = GAFSorterRecord::key_node_interval; + + /// Number of parallel sort/merge jobs. + size_t threads = THREADS; + + /// Number of records per file in the initial sort. + size_t records_per_file = RECORDS_PER_FILE; + + /// Number of files to merge at once. + size_t files_per_merge = FILES_PER_MERGE; + + /// Buffer size for reading and writing records. + size_t buffer_size = BUFFER_SIZE; + + /// Use stable sorting. + bool stable = false; + + /// Print progress information to stderr. + bool progress = false; +}; + +/** + * Sorts the given GAF file into the given output file. + * + * The initial round sorts the records into temporary files with params.records_per_file records each. + * Each successive round merges the temporary files into larger files until there is only one file left. + * Each merge job merges params.files_per_merge files. + * Use "-" for reading stdin / writing to stdout. + * Returns false and prints an error message on failure. + */ +bool sort_gaf(const std::string& input_file, const std::string& output_file, const GAFSorterParameters& params); + +/** + * Sorts the given GAF lines into the given output file, with an option to use stable sorting. + * + * The lines are converted into GAFSorterRecord objects, with the given key type. + * The original lines are consumed. + * Sets the ok flag in the output and prints an error message on failure. + * + * This function is intended to be used with std::thread. + */ +void sort_gaf_lines(std::unique_ptr> lines, GAFSorterRecord::key_type key_type, bool stable, GAFSorterFile& output); + +/** + * Merges the given files into a single output file. + * + * The records in each input file are assumed to be sorted with sort_gaf_lines(). + * Records are read and written in blocks of the given size. + * If the input files are in the same order as the corresponding batches in the initial sort, this is a stable merge. + * Consumes the inputs and removes the files if they are temporary. + * Sets the ok flag in the output and prints an error message on failure. + * + * This function is intended to be used with std::thread. + */ +void merge_gaf_records(std::unique_ptr> inputs, GAFSorterFile& output, size_t buffer_size); + +//------------------------------------------------------------------------------ + +} // namespace vg + +#endif // VG_GAF_SORTER_HPP_INCLUDED diff --git a/src/subcommand/gamsort_main.cpp b/src/subcommand/gamsort_main.cpp index c4ec9fd1ad..e335ea9cc2 100644 --- a/src/subcommand/gamsort_main.cpp +++ b/src/subcommand/gamsort_main.cpp @@ -1,93 +1,92 @@ -#include "../stream_sorter.hpp" -#include -#include "../stream_index.hpp" -#include +/** \file gamsort_main.cpp + * + * Defines the "vg gamsort" subcommand for sorting and shuffling GAM and GAF files. 
+ */ + #include "subcommand.hpp" -#include "vg/io/gafkluge.hpp" -#include "alignment.hpp" +#include "../alignment.hpp" +#include "../gaf_sorter.hpp" +#include "../stream_index.hpp" +#include "../stream_sorter.hpp" -/** -* GAM sort main -*/ +#include +#include -using namespace std; using namespace vg; using namespace vg::subcommand; + +//------------------------------------------------------------------------------ + +// We limit the max threads, and only allow thread count to be lowered, to +// prevent tcmalloc from giving each thread a very large heap for many +// threads. On my machine we can keep about 4 threads busy. +constexpr static size_t GAM_MAX_THREADS = 4; + +// gaf_sorter defaults to 1 thread. If we assume that the input and the output +// are bgzip-compressed, we should be able to saturate 5 worker threads. +// Because intermediate merges are independent, we can use more threads. +// The final single-threaded merge can use 5 bgzip threads. + void help_gamsort(char **argv) { - cerr << "gamsort: sort a GAM/GAF file, or index a sorted GAM file" << endl - << "Usage: " << argv[1] << " [Options] gamfile" << endl - << "Options:" << endl - << " -i / --index FILE produce an index of the sorted GAM file" << endl - << " -d / --dumb-sort use naive sorting algorithm (no tmp files, faster for small GAMs)" << endl - << " -s / --shuffle Shuffle reads by hash (GAM only)" << endl - << " -p / --progress Show progress." << endl - << " -G / --gaf-input Input is a GAF file." << endl - << " -c / --chunk-size Number of reads per chunk when sorting GAFs." << endl - << " -t / --threads Use the specified number of threads." << endl - << endl; + std::cerr << "usage: " << argv[0] << " " << argv[1] << " [options] input > output" << std::endl; + std::cerr << std::endl; + std::cerr << "Sort a GAM/GAF file, or index a sorted GAM file." << std::endl; + std::cerr << std::endl; + std::cerr << "General options:" << std::endl; + std::cerr << " -p, --progress show progress" << std::endl; + std::cerr << " -s, --shuffle shuffle reads by hash" << std::endl; + std::cerr << " -t, --threads N use N worker threads (default: " << GAM_MAX_THREADS << " for GAM, " << GAFSorterParameters::THREADS << " for GAF)" << std::endl; + std::cerr << std::endl; + std::cerr << "GAM sorting options:" << std::endl; + std::cerr << " -i, --index FILE produce an index of the sorted GAM file" << std::endl; + std::cerr << " -d, --dumb-sort use naive sorting algorithm (no tmp files, faster for small GAMs)" << std::endl; + std::cerr << std::endl; + std::cerr << "GAF sorting options:" << std::endl; + std::cerr << " -G, --gaf-input input is a GAF file" << std::endl; + std::cerr << " -c, --chunk-size N number of reads per chunk (default: " << GAFSorterParameters::RECORDS_PER_FILE << ")" << std::endl; + std::cerr << " -m, --merge-width N number of files to merge at once (default: " << GAFSorterParameters::FILES_PER_MERGE << ")" << std::endl; + std::cerr << " -S, --stable use stable sorting" << std::endl; + std::cerr << std::endl; } -// defines how to compare two GAF records -// first using 'rk1' tag (here, minimum node ID). If tied, use 'rk2' tag (here, maximum node ID) -struct compare_gaf { - bool operator()(const gafkluge::GafRecord& gaf1, const gafkluge::GafRecord& gaf2) { - // TODO find a way to not have to convert the node ids to string before and then back to int here? 
- long long rk11 = std::stoll(gaf1.opt_fields.find("rk1")->second.second); - long long rk12 = std::stoll(gaf2.opt_fields.find("rk1")->second.second); - long long rk21 = std::stoll(gaf1.opt_fields.find("rk2")->second.second); - long long rk22 = std::stoll(gaf2.opt_fields.find("rk2")->second.second); - return rk11 < rk12 || (rk11 == rk12 && rk21 < rk22); - } -}; -// defines a pair of a GAF record and the ID of the file it came from (used when merging sorted GAF files) -struct GafFile { - gafkluge::GafRecord gaf; - int file_i; -}; -// comparator used by the min-heap when merging sorted GAF files -struct greater_gaffile { - bool operator()(const GafFile& gf1, const GafFile& gf2) { - // TODO find a way to not have to convert the node ids to string before and then back to int here? - long long rk11 = std::stoll(gf1.gaf.opt_fields.find("rk1")->second.second); - long long rk12 = std::stoll(gf2.gaf.opt_fields.find("rk1")->second.second); - long long rk21 = std::stoll(gf1.gaf.opt_fields.find("rk2")->second.second); - long long rk22 = std::stoll(gf2.gaf.opt_fields.find("rk2")->second.second); - return rk11 > rk12 || (rk11 == rk12 && rk21 > rk22); - } -}; +//------------------------------------------------------------------------------ int main_gamsort(int argc, char **argv) { + // General options. + string input_format = "GAM"; + + // GAM sorting options. + size_t num_threads = GAM_MAX_THREADS; string index_filename; bool easy_sort = false; bool shuffle = false; bool show_progress = false; - string input_format = "GAM"; - int chunk_size = 1000000; // maximum number reads held in memory - // We limit the max threads, and only allow thread count to be lowered, to - // prevent tcmalloc from giving each thread a very large heap for many - // threads. - // On my machine we can keep about 4 threads busy. - size_t num_threads = 4; + + // GAF sorting options. + GAFSorterParameters gaf_params; + int c; optind = 2; // force optind past command positional argument while (true) { - static struct option long_options[] = - { - {"index", required_argument, 0, 'i'}, - {"dumb-sort", no_argument, 0, 'd'}, - {"shuffle", no_argument, 0, 's'}, - {"progress", no_argument, 0, 'p'}, - {"gaf-input", no_argument, 0, 'g'}, - {"chunk-size", required_argument, 0, 'c'}, - {"threads", required_argument, 0, 't'}, - {0, 0, 0, 0}}; + static struct option long_options[] = { + { "progress", no_argument, 0, 'p' }, + { "shuffle", no_argument, 0, 's' }, + { "threads", required_argument, 0, 't' }, + { "index", required_argument, 0, 'i' }, + { "dumb-sort", no_argument, 0, 'd' }, + { "gaf-input", no_argument, 0, 'G' }, + { "chunk-size", required_argument, 0, 'c' }, + { "merge-width", required_argument, 0, 'm' }, + { "stable", no_argument, 0, 'S' }, + { "help", no_argument, 0, 'h' }, + { 0, 0, 0, 0 } + }; int option_index = 0; - c = getopt_long(argc, argv, "i:dshpGt:c:", - long_options, &option_index); + c = getopt_long(argc, argv, "pst:i:dGc:m:Sh", long_options, &option_index); // Detect the end of the options. if (c == -1) @@ -95,27 +94,45 @@ int main_gamsort(int argc, char **argv) switch (c) { + // General options. + case 'p': + show_progress = true; + gaf_params.progress = true; + break; + case 's': + shuffle = true; + gaf_params.key_type = GAFSorterRecord::key_hash; + break; + case 't': + { + size_t parsed = std::max(parse(optarg), size_t(1)); + num_threads = std::min(parsed, num_threads); + gaf_params.threads = parsed; + } + break; + + // GAM sorting options. 
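+        // -i and -d have no effect on GAF input.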
case 'i': index_filename = optarg; break; case 'd': easy_sort = true; break; - case 's': - shuffle = true; - break; - case 'p': - show_progress = true; - break; + + // GAF sorting options. case 'G': input_format = "GAF"; break; case 'c': - chunk_size = parse(optarg); + gaf_params.records_per_file = parse(optarg); break; - case 't': - num_threads = min(parse(optarg), num_threads); + case 'm': + gaf_params.files_per_merge = parse(optarg); break; + case 'S': + gaf_params.stable = true; + break; + case 'h': case '?': default: @@ -141,7 +158,7 @@ int main_gamsort(int argc, char **argv) GAMSorter gs(shuffle ? GAMSorter::Order::RANDOM : GAMSorter::Order::BY_GRAPH_POSITION, show_progress); // Do a normal GAMSorter sort - unique_ptr index; + std::unique_ptr index; if (!index_filename.empty()) { // Make an index @@ -150,167 +167,27 @@ int main_gamsort(int argc, char **argv) if (easy_sort) { // Sort in a single pass in memory - gs.easy_sort(gam_in, cout, index.get()); + gs.easy_sort(gam_in, std::cout, index.get()); } else { // Sort using fan-in-limited temp file merging - gs.stream_sort(gam_in, cout, index.get()); + gs.stream_sort(gam_in, std::cout, index.get()); } if (index.get() != nullptr) { // Save the index - ofstream index_out(index_filename); + std::ofstream index_out(index_filename); index->save(index_out); } }); } else if (input_format == "GAF") { - if (shuffle) { - // TODO: Implement shuffling for GAF files by making the - // comparators switch modes and hashing the record strings. - // TODO: Is there a way to be less duplicative with the - // StreamSorter? - cerr << "[vg gamsort] Shuffling is not implemented for GAF files." << endl; - exit(1); - } - - std::string input_gaf_filename = get_input_file_name(optind, argc, argv); - - // where to store the chunk of GAF records that will be sorted, then written to disk, - // (then later merged with the other sorted chunks) - std::vector current_gaf_chunk; - int count = 0; // read count - int chunk_id = 0; // ID of the current chunk - std::vector chunk_files; // names of the chunk files - - // read input GAF file - htsFile* in = hts_open(input_gaf_filename.c_str(), "r"); - if (in == NULL) { - cerr << "[vg gamsort] couldn't open " << input_gaf_filename << endl; exit(1); - } - kstring_t s_buffer = KS_INITIALIZE; - gafkluge::GafRecord gaf; - - string chunk_outf = temp_file::create(); - if(show_progress){ - cerr << "Preparing temporary chunk " << chunk_outf << "..." << endl; - } - - while (vg::io::get_next_record_from_gaf(nullptr, nullptr, in, s_buffer, gaf) == true) { - // find the minimum and maximum node IDs - nid_t min_node = std::numeric_limits::max(); - nid_t max_node = 0; - for (size_t i = 0; i < gaf.path.size(); ++i) { - const auto& gaf_step = gaf.path[i]; - assert(gaf_step.is_stable == false); - assert(gaf_step.is_interval == false); - nid_t nodeid = std::stol(gaf_step.name); - if (min_node > nodeid){ - min_node = nodeid; - } - if (max_node < nodeid){ - max_node = nodeid; - } - } - // write them as new GAF tags 'rk1' and 'rk2' - // they'll get written in the temporary chunks to avoid having - // to find them again when merging them - gaf.opt_fields["rk1"] = make_pair('i', std::to_string(min_node)); - gaf.opt_fields["rk2"] = make_pair('i', std::to_string(max_node)); - current_gaf_chunk.push_back(gaf); - count++; - - // if we've read enough reads, sort them and write to disk - if(count == chunk_size){ - // sort by minimum node id - if(show_progress){ - cerr << " Sorting chunk..." 
<< endl; - } - std::stable_sort(current_gaf_chunk.begin(), current_gaf_chunk.end(), compare_gaf()); - // write to temporary file - if(show_progress){ - cerr << " Writing chunk..." << endl; - } - std::ofstream out_file(chunk_outf); - for (int ii=0; ii 0){ - // sort by minimum node id - if(show_progress){ - cerr << " Sorting chunk..." << endl; - } - std::stable_sort(current_gaf_chunk.begin(), current_gaf_chunk.end(), compare_gaf()); - // write to temporary file - if(show_progress){ - cerr << " Writing chunk..." << endl; - } - std::ofstream out_file(chunk_outf); - for (int ii=0; ii opened_files; - std::vector more_in_file; - std::vector opened_file_buffers; - // heap with the current GAF record of each file - std::priority_queue, greater_gaffile > opened_records; - - std::string line; - - // open the temp GAF files and read the first record - GafFile gf; - for(int ii=0; ii < chunk_files.size(); ii++){ - htsFile* in = hts_open(chunk_files[ii].c_str(), "r"); - if (in == NULL) { - cerr << "[vg::alignment.cpp] couldn't open " << input_gaf_filename << endl; exit(1); - } - opened_file_buffers.push_back(KS_INITIALIZE); - opened_files.push_back(in); - if(vg::io::get_next_record_from_gaf(nullptr, nullptr, opened_files.back(), opened_file_buffers.back(), gaf)){ - gf.gaf = gaf; - gf.file_i = ii; - opened_records.push(gf); - } - } - - while(opened_records.size() > 0){ - // which file will have the smallest record (i.e. to output first) - gf = opened_records.top(); - // remove the rk1/rk2 fields - gf.gaf.opt_fields.erase("rk1"); - gf.gaf.opt_fields.erase("rk2"); - // output smallest record - cout << gf.gaf << endl; - opened_records.pop(); - if(vg::io::get_next_record_from_gaf(nullptr, nullptr, opened_files[gf.file_i], opened_file_buffers[gf.file_i], gf.gaf)){ - opened_records.push(gf); - } - } - } + return 0; } diff --git a/src/unittest/gaf_sorter.cpp b/src/unittest/gaf_sorter.cpp new file mode 100644 index 0000000000..7122ba25e1 --- /dev/null +++ b/src/unittest/gaf_sorter.cpp @@ -0,0 +1,433 @@ +/** \file + * + * Unit tests for gaf_sorter.cpp, which provides tools for sorting GAF records. + */ + +#include "../gaf_sorter.hpp" +#include "../utility.hpp" + +#include +#include + +#include "catch.hpp" + +namespace vg { + +namespace unittest { + +//------------------------------------------------------------------------------ + +namespace { + +struct GAFInfo { + std::string line; + size_t id; + std::uint32_t min_node, max_node; + std::uint32_t first_node, gbwt_offset; + + GAFInfo(size_t id, size_t nodes, bool with_gbwt_offset) : + id(id), + min_node(std::numeric_limits::max()), max_node(0), + first_node(std::numeric_limits::max()), gbwt_offset(std::numeric_limits::max()) { + + // Name, query length, query start, query end, strand; + std::string nd = std::to_string(nodes); + this->line = "read" + std::to_string(id) + "\t" + nd + "\t0\t" + nd + "\t+\t"; + + // Path. + std::mt19937 rng(id); + for (size_t i = 0; i < nodes; i++) { + std::uint32_t node = rng() % 1000 + 1; + bool reverse = rng() % 2; + this->min_node = std::min(this->min_node, node); + this->max_node = std::max(this->max_node, node); + if (i == 0) { + this->first_node = node; + if (with_gbwt_offset) { + this->gbwt_offset = rng() % 1000; + } + } + this->line += (reverse ? "<" : ">") + std::to_string(node); + } + + // Path length, path start, path end, matches, alignment length, mapping quality. + this->line += "\t" + nd + "\t0\t" + nd + "\t" + nd + "\t" + nd + "\t60"; + + // Some arbitrary tags. + this->line += "\tab:Z:cd\tef:i:42"; + + // GBWT offset. 
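+        // Appended as GBWT_OFFSET_TAG ("GB:i:") followed by the offset value.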
+ if (with_gbwt_offset) { + this->line.push_back('\t'); + this->line += GAFSorterRecord::GBWT_OFFSET_TAG; + this->line += std::to_string(this->gbwt_offset); + } + + // More tags. + this->line += "\tgh:Z:ij\tkl:i:42"; + } + + std::uint64_t key(GAFSorterRecord::key_type type) const { + if (type == GAFSorterRecord::key_node_interval) { + return (static_cast(this->min_node) << 32) | this->max_node; + } else if (type == GAFSorterRecord::key_gbwt_pos) { + if (this->gbwt_offset == std::numeric_limits::max()) { + return GAFSorterRecord::MISSING_KEY; + } else { + return (static_cast(this->first_node) << 32) | this->gbwt_offset; + } + } else if (type == GAFSorterRecord::key_hash) { + return GAFSorterRecord::hasher(this->line); + } else { + return GAFSorterRecord::MISSING_KEY; + } + } + + // Returns a copy of the line that can be consumed. + std::string value() const { + return this->line; + } + + // Returns the id encoded in read name. + static std::uint32_t decode_id(const std::string& line) { + size_t pos = line.find('\t'); + return std::stoul(line.substr(4, pos - 4)); + } +}; + +std::unique_ptr> generate_gaf(size_t count, size_t path_length, double unaligned_probability) { + std::unique_ptr> result(new std::vector()); + result->reserve(count); + std::mt19937 rng(count ^ path_length); + for (size_t id = 0; id < count; id++) { + double p = static_cast(rng()) / rng.max(); + if (p < unaligned_probability) { + GAFInfo info(id, 0, false); + result->push_back(info.value()); + } else { + GAFInfo info(id, path_length, true); + result->push_back(info.value()); + } + } + return result; +} + +std::vector generate_records(size_t count, size_t path_length, double unaligned_probability) { + auto lines = generate_gaf(count, path_length, unaligned_probability); + std::vector result; + for (std::string line : *lines) { + result.emplace_back(std::move(line), GAFSorterRecord::key_node_interval); + } + return result; +} + +GAFSorterFile generate_sorted(size_t count, size_t path_length, double unaligned_probability, bool stable, const std::string* filename = nullptr) { + auto lines = generate_gaf(count, path_length, unaligned_probability); + GAFSorterFile output = (filename == nullptr ? GAFSorterFile() : GAFSorterFile(*filename)); + sort_gaf_lines(std::move(lines), GAFSorterRecord::key_node_interval, stable, output); + return output; +} + +void check_sorted(GAFSorterFile& file, bool raw_gaf, size_t lines, GAFSorterRecord::key_type key_type, bool stable) { + REQUIRE(file.ok); + REQUIRE(file.records == lines); + + std::pair> in; + if (raw_gaf) { + in.second = std::unique_ptr(new std::ifstream(file.name, std::ios::binary)); + in.first = in.second.get(); + } else { + in = file.open_input(); + } + + size_t line_num = 0; + GAFSorterRecord previous; + while (line_num < lines) { + GAFSorterRecord record; + if (raw_gaf) { + REQUIRE(record.read_line(*in.first, key_type)); + } else { + REQUIRE(record.deserialize(*in.first)); + } + REQUIRE((line_num == 0 || previous.key <= record.key)); + if (stable && line_num > 0 && previous.key == record.key) { + std::uint32_t prev_id = GAFInfo::decode_id(previous.value); + std::uint32_t curr_id = GAFInfo::decode_id(record.value); + REQUIRE(prev_id < curr_id); + } + previous = record; + line_num++; + } + + // There should not be any additional data. 
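+    // A successful read here would mean the file contains more data than expected.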
+ char c; + REQUIRE(!in.first->get(c)); +} + +void merge_and_check(std::unique_ptr> inputs, size_t buffer_size, size_t expected_records, GAFSorterRecord::key_type key_type) { + std::string filename = temp_file::create("gaf-sorter"); + GAFSorterFile output(filename); + merge_gaf_records(std::move(inputs), output, buffer_size); + check_sorted(output, true, expected_records, key_type, false); + temp_file::remove(filename); +} + +void integrated_test(size_t count, size_t path_length, double unaligned_probability, const GAFSorterParameters& params) { + // Generate the input. + std::string input_file = temp_file::create("gaf-sorter"); + std::ofstream out(input_file, std::ios::binary); + auto lines = generate_gaf(count, path_length, unaligned_probability); + for (const std::string& line : *lines) { + out << line << '\n'; + } + out.close(); + + // Sort the input. + std::string output_file = temp_file::create("gaf-sorter"); + REQUIRE(sort_gaf(input_file, output_file, params)); + temp_file::remove(input_file); + + // Check the output. + GAFSorterFile output(output_file); + output.records = count; // This is a new file object, so we need to set the record count. + check_sorted(output, true, count, params.key_type, false); + temp_file::remove(output_file); +} + +} // anonymous namespace + +//------------------------------------------------------------------------------ + +TEST_CASE("Records and keys", "[gaf_sorter]") { + SECTION("node interval keys") { + for (size_t id = 0; id < 10; id++) { + GAFInfo info(id, 10, false); + GAFSorterRecord record(info.value(), GAFSorterRecord::key_node_interval); + REQUIRE(record.key == info.key(GAFSorterRecord::key_node_interval)); + } + } + + SECTION("node interval key with an empty path") { + for (size_t id = 0; id < 10; id++) { + GAFInfo info(id, 0, false); + GAFSorterRecord record(info.value(), GAFSorterRecord::key_node_interval); + REQUIRE(record.key == GAFSorterRecord::MISSING_KEY); + } + } + + SECTION("GBWT position keys") { + for (size_t id = 0; id < 10; id++) { + GAFInfo info(id, 10, true); + GAFSorterRecord record(info.value(), GAFSorterRecord::key_gbwt_pos); + REQUIRE(record.key == info.key(GAFSorterRecord::key_gbwt_pos)); + } + } + + SECTION("GBWT position key with an empty path") { + for (size_t id = 0; id < 10; id++) { + GAFInfo info(id, 0, true); + GAFSorterRecord record(info.value(), GAFSorterRecord::key_gbwt_pos); + REQUIRE(record.key == GAFSorterRecord::MISSING_KEY); + } + } + + SECTION("GBWT position key without offset") { + for (size_t id = 0; id < 10; id++) { + GAFInfo info(id, 10, false); + GAFSorterRecord record(info.value(), GAFSorterRecord::key_gbwt_pos); + REQUIRE(record.key == GAFSorterRecord::MISSING_KEY); + } + } + + SECTION("hash keys") { + for (size_t id = 0; id < 10; id++) { + GAFInfo info(id, 10, false); + GAFSorterRecord record(info.value(), GAFSorterRecord::key_hash); + REQUIRE(record.key == info.key(GAFSorterRecord::key_hash)); + } + } +} + +TEST_CASE("Record serialization", "[gaf_sorter]") { + SECTION("records") { + auto records = generate_records(100, 10, 0.05); + std::string filename = temp_file::create("gaf-sorter"); + std::ofstream out(filename, std::ios::binary); + for (const GAFSorterRecord& record : records) { + record.serialize(out); + } + out.close(); + + std::ifstream in(filename, std::ios::binary); + for (size_t i = 0; i < records.size(); i++) { + GAFSorterRecord record; + REQUIRE(record.deserialize(in)); + REQUIRE(record.key == records[i].key); + REQUIRE(record.value == records[i].value); + } + char c; + 
REQUIRE(!in.get(c)); + } + + SECTION("lines") { + auto records = generate_records(120, 10, 0.05); + std::string filename = temp_file::create("gaf-sorter"); + std::ofstream out(filename, std::ios::binary); + for (const GAFSorterRecord& record : records) { + record.write_line(out); + } + out.close(); + + std::ifstream in(filename, std::ios::binary); + for (size_t i = 0; i < records.size(); i++) { + GAFSorterRecord record; + REQUIRE(record.read_line(in, GAFSorterRecord::key_node_interval)); + REQUIRE(record.key == records[i].key); + REQUIRE(record.value == records[i].value); + } + char c; + REQUIRE(!in.get(c)); + } +} + +//------------------------------------------------------------------------------ + +TEST_CASE("Sorting GAF records", "[gaf_sorter]") { + SECTION("raw GAF output") { + size_t n = 1000; + std::string filename = temp_file::create("gaf-sorter"); + GAFSorterFile output = generate_sorted(n, 10, 0.05, false, &filename); + check_sorted(output, true, n, GAFSorterRecord::key_node_interval, false); + temp_file::remove(filename); + } + + SECTION("empty GAF output") { + size_t n = 0; + std::string filename = temp_file::create("gaf-sorter"); + GAFSorterFile output = generate_sorted(n, 10, 0.05, false, &filename); + check_sorted(output, true, n, GAFSorterRecord::key_node_interval, false); + temp_file::remove(filename); + } + + SECTION("record output") { + size_t n = 1234; + GAFSorterFile output = generate_sorted(n, 10, 0.05, false); + check_sorted(output, false, n, GAFSorterRecord::key_node_interval, false); + } + + SECTION("empty record output") { + size_t n = 0; + GAFSorterFile output = generate_sorted(n, 10, 0.05, false); + check_sorted(output, false, n, GAFSorterRecord::key_node_interval, false); + } + + SECTION("stable sorting") { + size_t n = 1000; + GAFSorterFile output = generate_sorted(n, 10, 0.05, true); + check_sorted(output, false, n, GAFSorterRecord::key_node_interval, true); + } +} + +//------------------------------------------------------------------------------ + +TEST_CASE("Merging sorted files", "[gaf_sorter]") { + SECTION("three files") { + size_t n = 1000, expected_records = 0; + std::unique_ptr> inputs(new std::vector()); + for (size_t i = 0; i < 3; i++) { + inputs->push_back(generate_sorted(n + i, 10, 0.05, false)); + expected_records += inputs->back().records; + } + merge_and_check(std::move(inputs), 100, expected_records, GAFSorterRecord::key_node_interval); + } + + SECTION("one file is empty") { + size_t n = 1000, expected_records = 0; + std::unique_ptr> inputs(new std::vector()); + for (size_t i = 0; i < 3; i++) { + size_t count = (i == 1 ? 
0 : n + i); + inputs->push_back(generate_sorted(count, 10, 0.05, false)); + expected_records += inputs->back().records; + } + merge_and_check(std::move(inputs), 100, expected_records, GAFSorterRecord::key_node_interval); + } + + SECTION("all files are empty") { + size_t expected_records = 0; + std::unique_ptr> inputs(new std::vector()); + for (size_t i = 0; i < 3; i++) { + inputs->push_back(generate_sorted(0, 10, 0.05, false)); + expected_records += inputs->back().records; + } + merge_and_check(std::move(inputs), 100, expected_records, GAFSorterRecord::key_node_interval); + } + + SECTION("no input files") { + size_t expected_records = 0; + std::unique_ptr> inputs(new std::vector()); + merge_and_check(std::move(inputs), 100, expected_records, GAFSorterRecord::key_node_interval); + } +} + +//------------------------------------------------------------------------------ + +TEST_CASE("GAF sorting", "[gaf_sorter]") { + SECTION("one full batch") { + size_t n = 1000; + GAFSorterParameters params; + params.records_per_file = 1000; + integrated_test(n, 10, 0.05, params); + } + + SECTION("one partial batch") { + size_t n = 500; + GAFSorterParameters params; + params.records_per_file = 1000; + integrated_test(n, 10, 0.05, params); + } + + SECTION("one merge") { + size_t n = 2000; + GAFSorterParameters params; + params.records_per_file = 1000; + params.files_per_merge = 2; + integrated_test(n, 10, 0.05, params); + } + + SECTION("one merge + one batch") { + size_t n = 3000; + GAFSorterParameters params; + params.records_per_file = 1000; + params.files_per_merge = 2; + integrated_test(n, 10, 0.05, params); + } + + SECTION("multiple levels of merges") { + size_t n = 10000; + GAFSorterParameters params; + params.records_per_file = 1000; + params.files_per_merge = 2; + integrated_test(n, 10, 0.05, params); + } + + SECTION("multithreaded") { + size_t n = 10000; + GAFSorterParameters params; + params.records_per_file = 1000; + params.files_per_merge = 2; + params.threads = 2; + integrated_test(n, 10, 0.05, params); + } + + SECTION("empty input") { + size_t n = 0; + GAFSorterParameters params; + integrated_test(n, 10, 0.05, params); + } +} + +//------------------------------------------------------------------------------ + +} // namespace unittest + +} // namespace vg diff --git a/src/unittest/zstdutil.cpp b/src/unittest/zstdutil.cpp new file mode 100644 index 0000000000..7441444f2c --- /dev/null +++ b/src/unittest/zstdutil.cpp @@ -0,0 +1,136 @@ +/** \file + * + * Unit tests for zstdutil.cpp, which implements wrappers for Zstandard compression. 
+ */ + +#include "../zstdutil.hpp" +#include "../utility.hpp" + +#include +#include + +#include "catch.hpp" + +namespace vg { + +namespace unittest { + +//------------------------------------------------------------------------------ + +namespace { + +std::vector words { + "GATTACA", "CAT", "GATTA", "GAGA", "TAG", "GATT" +}; + +std::string generate_data(size_t num_words, size_t seed) { + std::string data; + std::mt19937 rng(num_words ^ seed); + for (size_t i = 0; i < num_words; i++) { + size_t index = rng() % words.size(); + data.insert(data.end(), words[index].begin(), words[index].end()); + } + return data; +} + +std::string compress_to_string(const std::string& data) { + std::stringstream compressed_stream; + zstd_compress_buf buffer(compressed_stream.rdbuf()); + buffer.sputn(const_cast(data.data()), data.size()); + buffer.pubsync(); + return compressed_stream.str(); +} + +void compress_to_file(const std::string& data, const std::string& filename) { + zstd_ofstream out(filename); + out.write(data.data(), data.size()); +} + +std::string decompress_string(const std::string& compressed, size_t expected_size) { + std::string decompressed; + decompressed.resize(expected_size); + std::stringstream compressed_stream(compressed); + zstd_decompress_buf buffer(compressed_stream.rdbuf()); + size_t n = buffer.sgetn(const_cast(decompressed.data()), decompressed.size()); + REQUIRE(n == expected_size); + REQUIRE(buffer.sgetc() == std::char_traits::eof()); + return decompressed; +} + +std::string decompress_file(const std::string& filename, size_t expected_size) { + std::string decompressed; + decompressed.resize(expected_size); + zstd_ifstream in(filename); + in.read(&decompressed[0], decompressed.size()); + REQUIRE(in); + REQUIRE(in.peek() == std::char_traits::eof()); + return decompressed; +} + +} // anonymous namespace + +//------------------------------------------------------------------------------ + +TEST_CASE("Compression with stream buffers", "[zstdutil]") { + SECTION("empty string") { + std::string compressed = compress_to_string(""); + std::string decompressed = decompress_string(compressed, 0); + REQUIRE(decompressed.empty()); + } + + SECTION("random words") { + for (size_t i = 0; i < 10; i++) { + std::string data = generate_data(1000, i); + std::string compressed = compress_to_string(data); + std::string decompressed = decompress_string(compressed, data.size()); + REQUIRE(decompressed == data); + } + } + + SECTION("large instance") { + for (size_t i = 0; i < 3; i++) { + std::string data = generate_data(1000000, i); + std::string compressed = compress_to_string(data); + std::string decompressed = decompress_string(compressed, data.size()); + REQUIRE(decompressed == data); + } + } +} + +TEST_CASE("Compression to files", "[zstdutil]") { + SECTION("empty string") { + std::string filename = temp_file::create("zstdutil"); + compress_to_file("", filename); + std::string decompressed = decompress_file(filename, 0); + REQUIRE(decompressed.empty()); + temp_file::remove(filename); + } + + SECTION("random words") { + for (size_t i = 0; i < 10; i++) { + std::string data = generate_data(1020, i); + std::string filename = temp_file::create("zstdutil"); + compress_to_file(data, filename); + std::string decompressed = decompress_file(filename, data.size()); + REQUIRE(decompressed == data); + temp_file::remove(filename); + } + } + + SECTION("large instance") { + for (size_t i = 0; i < 3; i++) { + std::string data = generate_data(1000020, i); + std::string filename = temp_file::create("zstdutil"); + 
+            compress_to_file(data, filename);
+            std::string decompressed = decompress_file(filename, data.size());
+            REQUIRE(decompressed == data);
+            temp_file::remove(filename);
+        }
+    }
+}
+
+//------------------------------------------------------------------------------
+
+} // namespace unittest
+
+} // namespace vg
diff --git a/src/zstdutil.cpp b/src/zstdutil.cpp
index 05511802bc..d732c22b3e 100644
--- a/src/zstdutil.cpp
+++ b/src/zstdutil.cpp
@@ -1,3 +1,116 @@
+#include "zstdutil.hpp"
+
+namespace vg {
+
+//------------------------------------------------------------------------------
+
+// Public class constants.
+
+constexpr int zstd_compress_buf::DEFAULT_COMPRESSION_LEVEL;
+
+//------------------------------------------------------------------------------
+
+zstd_compress_buf::zstd_compress_buf(std::streambuf* inner, int compression_level) :
+    inner(inner), context(ZSTD_createCCtx())
+{
+    this->in_buffer.resize(ZSTD_CStreamInSize());
+    this->setp(this->in_buffer.data(), this->in_buffer.data() + this->in_buffer.size());
+    this->out_buffer.resize(ZSTD_CStreamOutSize());
+    ZSTD_CCtx_setParameter(this->context, ZSTD_c_compressionLevel, compression_level);
+}
+
+zstd_compress_buf::~zstd_compress_buf() {
+    this->sync();
+    ZSTD_freeCCtx(this->context); this->context = nullptr;
+}
+
+zstd_compress_buf::int_type zstd_compress_buf::overflow(int_type ch) {
+    if (ch != traits_type::eof()) {
+        if (this->sync() == -1) {
+            return traits_type::eof();
+        }
+        *this->pptr() = traits_type::to_char_type(ch);
+        this->pbump(1);
+    }
+    return ch;
+}
+
+zstd_compress_buf::int_type zstd_compress_buf::sync() {
+    if (this->inner == nullptr) {
+        throw std::runtime_error("zstd_compress_buf: inner stream buffer is null");
+    }
+
+    ZSTD_inBuffer input = { this->pbase(), static_cast<size_t>(this->pptr() - this->pbase()), 0 };
+    ZSTD_EndDirective mode = (this->pptr() < this->epptr() ? ZSTD_e_end : ZSTD_e_continue);
+    bool finished = false;
+    while (!finished) {
+        ZSTD_outBuffer output = { this->out_buffer.data(), this->out_buffer.size(), 0 };
+        size_t result = ZSTD_compressStream2(this->context, &output, &input, mode);
+        if (ZSTD_isError(result)) {
+            std::string msg = "zstd_compress_buf: compression failed: " + std::string(ZSTD_getErrorName(result));
+            throw std::runtime_error(msg);
+        }
+        size_t n = this->inner->sputn(this->out_buffer.data(), output.pos);
+        if (n != output.pos) {
+            throw std::runtime_error("zstd_compress_buf: failed to write compressed data");
+        }
+        finished = (input.pos >= input.size);
+        if (mode == ZSTD_e_end) {
+            finished &= (result == 0);
+        }
+    }
+
+    this->setp(this->in_buffer.data(), this->in_buffer.data() + this->in_buffer.size());
+    return 0;
+}
+
+//------------------------------------------------------------------------------
+
+zstd_decompress_buf::zstd_decompress_buf(std::streambuf* inner) :
+    inner(inner), context(ZSTD_createDCtx())
+{
+    this->in_buffer.resize(ZSTD_DStreamInSize());
+    this->in_offset = this->in_buffer.size();
+    this->out_buffer.resize(ZSTD_DStreamOutSize());
+}
+
+zstd_decompress_buf::~zstd_decompress_buf() {
+    ZSTD_freeDCtx(this->context); this->context = nullptr;
+}
+
+zstd_decompress_buf::int_type zstd_decompress_buf::underflow() {
+    if (this->gptr() < this->egptr()) {
+        return traits_type::to_int_type(*this->gptr());
+    }
+
+    // Fill the input buffer if necessary.
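+    // in_offset records how much of the current compressed chunk in in_buffer
+    // has already been consumed by ZSTD_decompressStream(). Once the whole chunk
+    // has been consumed, pull the next chunk from the inner stream buffer; a
+    // short read shrinks in_buffer, so the final partial chunk is handled too.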
+    if (this->in_offset >= this->in_buffer.size()) {
+        size_t n = this->inner->sgetn(this->in_buffer.data(), this->in_buffer.size());
+        this->in_offset = 0;
+        this->in_buffer.resize(n);
+    }
+
+    // Decompress a chunk from the input buffer into the output buffer.
+    ZSTD_inBuffer input = { this->in_buffer.data(), this->in_buffer.size(), this->in_offset };
+    ZSTD_outBuffer output = { this->out_buffer.data(), this->out_buffer.size(), 0 };
+    size_t result = ZSTD_decompressStream(this->context, &output, &input);
+    if (ZSTD_isError(result)) {
+        std::string msg = "zstd_decompress_buf: decompression failed: " + std::string(ZSTD_getErrorName(result));
+        throw std::runtime_error(msg);
+    }
+    this->in_offset = input.pos;
+
+    // Tell the stream to use the output buffer.
+    this->setg(this->out_buffer.data(), this->out_buffer.data(), this->out_buffer.data() + output.pos);
+    return (output.pos > 0 ? traits_type::to_int_type(*this->gptr()) : traits_type::eof());
+}
+
+//------------------------------------------------------------------------------
+
+} // namespace vg
+
+//------------------------------------------------------------------------------
+
 //
 // -*- coding: utf-8-unix; -*-
 // Copyright (c) 2020 Tencent, Inc.
@@ -8,8 +121,6 @@
 // Desc:
 //
 
-#include "zstdutil.hpp"
-
 namespace zstdutil {
 
 int CompressString(const std::string& src, std::string& dst, int compressionlevel) {
@@ -136,4 +247,6 @@ int StreamDecompressString(const std::string& src, std::string& dst, int compres
   return 0;
 }
 
-} // namespace util
+} // namespace zstdutil
+
+//------------------------------------------------------------------------------
diff --git a/src/zstdutil.hpp b/src/zstdutil.hpp
index e04381474f..382b39a7ae 100644
--- a/src/zstdutil.hpp
+++ b/src/zstdutil.hpp
@@ -1,3 +1,116 @@
+
+#pragma once
+
+#include <fstream>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include <zstd.h>
+
+/** \file
+ * Wrappers for Zstandard compression and decompression.
+ *
+ * TODO: Override xsputn, xsgetn for faster compression?
+ * TODO: Move constructors for streams?
+ * TODO: is_open(), close() for streams?
+ */
+
+namespace vg {
+
+//------------------------------------------------------------------------------
+
+/// Zstandard compression buffer that writes to another stream buffer.
+class zstd_compress_buf : public std::streambuf {
+public:
+    // We should be using ZSTD_defaultCLevel(), but it requires v1.5.0, which is not available everywhere.
+    constexpr static int DEFAULT_COMPRESSION_LEVEL = 3;
+
+    explicit zstd_compress_buf(std::streambuf* inner, int compression_level = DEFAULT_COMPRESSION_LEVEL);
+    ~zstd_compress_buf();
+
+    zstd_compress_buf(const zstd_compress_buf&) = delete;
+    zstd_compress_buf& operator=(const zstd_compress_buf&) = delete;
+    zstd_compress_buf(zstd_compress_buf&&) = default;
+    zstd_compress_buf& operator=(zstd_compress_buf&&) = default;
+
+protected:
+    int_type overflow(int_type ch) override;
+    int sync() override;
+
+    std::streambuf* inner;
+    ZSTD_CCtx* context;
+    std::vector<char> in_buffer;
+    std::vector<char> out_buffer;
+};
+
+/// Zstandard decompression buffer that reads from another stream buffer.
+class zstd_decompress_buf : public std::streambuf {
+public:
+    explicit zstd_decompress_buf(std::streambuf* inner);
+    ~zstd_decompress_buf();
+
+    zstd_decompress_buf(const zstd_decompress_buf&) = delete;
+    zstd_decompress_buf& operator=(const zstd_decompress_buf&) = delete;
+    zstd_decompress_buf(zstd_decompress_buf&&) = default;
+    zstd_decompress_buf& operator=(zstd_decompress_buf&&) = default;
+
+protected:
+    int_type underflow() override;
+
+    std::streambuf* inner;
+    ZSTD_DCtx* context;
+    std::vector<char> in_buffer;
+    std::vector<char> out_buffer;
+    size_t in_offset;
+};
+
+//------------------------------------------------------------------------------
+
+/// Zstandard output file stream.
+/// The object cannot be copied or moved.
+class zstd_ofstream : public std::ostream {
+public:
+    explicit zstd_ofstream(const std::string& filename, int compression_level = zstd_compress_buf::DEFAULT_COMPRESSION_LEVEL) :
+        std::ostream(&buffer),
+        inner(filename, std::ios::binary),
+        buffer(inner.rdbuf(), compression_level) {}
+
+    zstd_ofstream(const zstd_ofstream&) = delete;
+    zstd_ofstream& operator=(const zstd_ofstream&) = delete;
+    zstd_ofstream(zstd_ofstream&&) = delete;
+    zstd_ofstream& operator=(zstd_ofstream&&) = delete;
+
+protected:
+    std::ofstream inner;
+    zstd_compress_buf buffer;
+};
+
+/// Zstandard input file stream.
+/// The object cannot be copied or moved.
+class zstd_ifstream : public std::istream {
+public:
+    explicit zstd_ifstream(const std::string& filename) :
+        std::istream(&buffer),
+        inner(filename, std::ios::binary),
+        buffer(inner.rdbuf()) {}
+
+    zstd_ifstream(const zstd_ifstream&) = delete;
+    zstd_ifstream& operator=(const zstd_ifstream&) = delete;
+    zstd_ifstream(zstd_ifstream&&) = delete;
+    zstd_ifstream& operator=(zstd_ifstream&&) = delete;
+
+protected:
+    std::ifstream inner;
+    zstd_decompress_buf buffer;
+};
+
+//------------------------------------------------------------------------------
+
+} // namespace vg
+
+// TODO: Get rid of these when we have something better.
+
 //
 // -*- coding: utf-8-unix; -*-
 // Copyright (c) 2020 Tencent, Inc.
@@ -8,11 +121,6 @@
 // Desc:
 //
 
-#pragma once
-
-#include <string>
-#include <zstd.h>
-
 namespace zstdutil {
 
 const int DEFAULTCOMPRESSLEVEL = 5;
@@ -32,4 +140,6 @@ int StreamDecompressString(const std::string& src, std::string& dst,
 int StreamCompressString(const std::string& src, std::string& dst,
                          int compressionlevel = DEFAULTCOMPRESSLEVEL);
 
-} // namespace util
+} // namespace zstdutil
+
+//------------------------------------------------------------------------------
diff --git a/test/t/42_vg_gamsort.t b/test/t/42_vg_gamsort.t
index 02f09aec0f..62584e7249 100644
--- a/test/t/42_vg_gamsort.t
+++ b/test/t/42_vg_gamsort.t
@@ -6,12 +6,14 @@ BASH_TAP_ROOT=../deps/bash-tap
 PATH=../bin:$PATH # for vg
 
-plan tests 4
+plan tests 8
 
 vg construct -r small/x.fa -v small/x.vcf.gz >x.vg
 vg index -x x.xg x.vg
 vg sim -n 1000 -l 100 -e 0.01 -i 0.005 -x x.xg -a >x.gam
+
+# GAM
 vg gamsort x.gam >x.sorted.gam
 
 vg view -aj x.sorted.gam | jq -r '.path.mapping | ([.[] | .position.node_id | tonumber] | min)' >min_ids.gamsorted.txt
@@ -26,4 +28,27 @@ vg gamsort --shuffle x.sorted.gam >x.shuffled.gam
 is "$?" "0" "GAMs can be shuffled"
 is "$(vg stats -a x.shuffled.gam)" "$(vg stats -a x.sorted.gam)" "Shuffling preserves read data"
 
-rm -f x.vg x.xg x.gam x.sorted.gam x.sorted.2.gam x.shuffled.gam min_ids.gamsorted.txt min_ids.sorted.txt x.sorted.gam.gai x.sorted.2.gam.gai
+rm -f x.sorted.gam x.sorted.2.gam x.shuffled.gam min_ids.gamsorted.txt min_ids.sorted.txt x.sorted.gam.gai x.sorted.2.gam.gai
+
+
+# GAF. Correctness is tested in unit tests, so we just check that the commands work.
+vg convert -G x.gam x.xg > x.gaf
+sort x.gaf > x.gaf.lexicographic
+
+vg gamsort -G x.gaf > x.sorted.gaf
+is "$?" "0" "GAFs can be sorted"
+sort x.sorted.gaf > x.sorted.gaf.lexicographic
+cmp x.gaf.lexicographic x.sorted.gaf.lexicographic
+is "$?" "0" "Sorting a GAF preserves read data"
+
+vg gamsort -G --shuffle x.gaf > x.shuffled.gaf
+is "$?" "0" "GAFs can be shuffled"
+sort x.shuffled.gaf > x.shuffled.gaf.lexicographic
+cmp x.gaf.lexicographic x.shuffled.gaf.lexicographic
+is "$?" "0" "Shuffling a GAF preserves read data"
+
+rm -f x.gaf x.gaf.lexicographic x.sorted.gaf x.sorted.gaf.lexicographic x.shuffled.gaf x.shuffled.gaf.lexicographic
+
+
+# Cleanup
+rm -f x.vg x.xg x.gam