From 0305a0ea94bdac1cbd8b37f7e64607a2264d9738 Mon Sep 17 00:00:00 2001
From: Pier Fiedorowicz <117680821+fiedorowicz1@users.noreply.github.com>
Date: Wed, 12 Jun 2024 17:50:55 -0700
Subject: [PATCH] Fix Hang in Python Dataset Reader with DistConv (#2457)

* Internally track mini batch index

* Remove redundant minibatch index
---
 .../readers/data_reader_python_dataset.cpp | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/data_ingestion/readers/data_reader_python_dataset.cpp b/src/data_ingestion/readers/data_reader_python_dataset.cpp
index fe0b4758306..07b7ea404d5 100644
--- a/src/data_ingestion/readers/data_reader_python_dataset.cpp
+++ b/src/data_ingestion/readers/data_reader_python_dataset.cpp
@@ -186,7 +186,14 @@ void python_dataset_reader::shuffle_responses(DataType* responses_ptr)
   execution_mode mode = exec_mode_from_string(get_role());
   dataset& ds = get_trainer().get_data_coordinator().get_dataset(mode);
 
-  uint64_t global_mb_size = ds.get_current_mini_batch_size();
+  uint64_t global_mb_size{};
+  if (m_dataset_minibatch_offset < (ds.get_num_iterations_per_epoch() - 1)) {
+    global_mb_size = ds.get_mini_batch_size();
+  }
+  else if (m_dataset_minibatch_offset ==
+           (ds.get_num_iterations_per_epoch() - 1)) {
+    global_mb_size = ds.get_last_mini_batch_size();
+  }
   uint64_t local_mb_size = global_mb_size / nprocs;
   uint64_t extra_samples = global_mb_size % nprocs;
 