From 7df8aa400b9b19e555dd05f5fd3a049743d36cee Mon Sep 17 00:00:00 2001 From: Mario Vasilev Date: Wed, 20 Mar 2024 18:12:12 +0000 Subject: [PATCH] [fix]: account for y labels being offset by NUM_SPECIAL_TOKENS when calling np.bincount in emnist balance subsampling np.bincount will prepend zeros for elements that were not found starting from 0 to y_min_element-1; this will bias the mean to be lower if not controlled and will result in fewer samples in the balanced dataset --- text_recognizer/data/emnist.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index cb4252a..adf4f22 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -130,10 +130,12 @@ def _process_raw_dataset(filename: str, dirname: Path): shutil.rmtree("matlab") -def _sample_to_balance(x, y): +def _sample_to_balance(x, y, y_min_element=NUM_SPECIAL_TOKENS): """Because the dataset is not balanced, we take at most the mean number of instances per class.""" np.random.seed(42) - num_to_sample = int(np.bincount(y.flatten()).mean()) + # np.bincount always starts counting from 0, so only take + # result for elements that actually occur in y; + num_to_sample = int(np.bincount(y.flatten())[y_min_element:].mean()) all_sampled_inds = [] for label in np.unique(y.flatten()): inds = np.where(y == label)[0]