diff --git a/pseudolabeling/train_val_test_split.py b/pseudolabeling/train_val_test_split.py
index a434c47..e22953a 100644
--- a/pseudolabeling/train_val_test_split.py
+++ b/pseudolabeling/train_val_test_split.py
@@ -13,6 +13,7 @@ def main(args):
         for d in os.listdir(args.source_dir)
         if os.path.isdir(os.path.join(args.source_dir, d))
     ]
+    print(f"Subdirectories of source dir {args.source_dir}: {dset_dirs}")
 
     for dset_dir in dset_dirs:
         dset_path = os.path.join(args.source_dir, dset_dir)
@@ -44,16 +45,46 @@ def main(args):
         for dataset, files in zip(
             ["train", "val", "test"], [train_files, val_files, test_files]
         ):
-            split_path = os.path.join(args.output_dir, dset_dir, dataset)
+            split_path = os.path.join(args.output_dir, dataset, dset_dir)
             print(f"Move {dset_path} -----------> {split_path}")
             os.makedirs(split_path, exist_ok=True)  # Create class directory in split
+            # Choose the transfer function once per split: copying leaves
+            # source_dir intact, moving (the default) removes files from it.
+            transfer = shutil.copy if args.copy else shutil.move
             for file in files:
-                shutil.move(
+                transfer(
                     os.path.join(dset_path, file), os.path.join(split_path, file)
                 )
 
 
 if __name__ == "__main__":
+    """
+    Given a source directory containing the data for multiple modalities, e.g.,
+
+    ```
+    |--source_dir/
+    |   |--modality_a/
+    |   |--modality_b/
+    |   |--modality_c/
+    ```
+
+    move (or, with --copy, copy) the files into a specified output_dir/ with the structure:
+    ```
+    |--output_dir/
+    |   |--train/
+    |   |   |--modality_a/
+    |   |   |--modality_b/
+    |   |   |--modality_c/
+    |   |--val/
+    |   |   |--modality_a/
+    |   |   |--modality_b/
+    |   |   |--modality_c/
+    |   |--test/
+    |   |   |--modality_a/
+    |   |   |--modality_b/
+    |   |   |--modality_c/
+    ```
+    """
     parser = argparse.ArgumentParser(
         description="Partition datasets into train, val, and test splits."
     )
@@ -87,6 +118,13 @@ def main(args):
         default=False,
         help="Whether to shuffle shards befores splitting. Otherwise, train is 0, 1, 2, etc.",
     )
+    # A store_true flag rather than type=bool: argparse would pass the raw
+    # string to bool(), so any non-empty value (even "False") would be True.
+    parser.add_argument(
+        "--copy",
+        action="store_true",
+        help="Whether to copy the files instead of move. Defaults to False.",
+    )
 
     args = parser.parse_args()
     main(args)