
Allow multiprocessing when preparing ICL dataset #1276

Open
sanjari-orb opened this issue Jun 13, 2024 · 8 comments
Labels
enhancement New feature or request

Comments

@sanjari-orb
Contributor

🚀 Feature Request

Allow passing a num_proc/num_workers parameter in InContextLearningDataset so that preparation of the dataset can use more than one process.

Motivation

When loading bigger ICL eval datasets, it is desirable to pass num_proc > 1 to the following map call, which preps each example in the dataset:

        self.dataset: HFDataset = self.dataset.map(
            self._prep_example,
            with_indices=True,
            fn_kwargs={
                'num_fewshot': num_fewshot,
                'prompt_string': prompt_string,
                'fewshot_rng': fewshot_rng,
            },
        )

Can we introduce a num_proc parameter in the InContextLearningDataset constructors so that the example preparation can instead be done like this:

        self.dataset: HFDataset = self.dataset.map(
            self._prep_example,
            with_indices=True,
            num_proc=num_proc,
            fn_kwargs={
                'num_fewshot': num_fewshot,
                'prompt_string': prompt_string,
                'fewshot_rng': fewshot_rng,
            },
        )

This greatly increases the speed of loading larger datasets.
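For context, here is a self-contained sketch of what num_proc does in datasets.Dataset.map; the toy dataset and prep function below are stand-ins, not foundry's actual code:

    from datasets import Dataset

    def prep_example(example, idx, prefix):
        # Stand-in for InContextLearningDataset._prep_example.
        example['prompt'] = f"{prefix}{example['text']}"
        return example

    ds = Dataset.from_dict({'text': [f'question {i}' for i in range(10_000)]})

    ds = ds.map(
        prep_example,
        with_indices=True,
        num_proc=4,  # split example prep across 4 worker processes
        fn_kwargs={'prefix': 'Q: '},
    )

Leaving num_proc unset (None) would keep today's single-process behavior.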

sanjari-orb added the enhancement label Jun 13, 2024
@dakinggg
Collaborator

@sanjari-orb sure! My only hesitation in doing this is that we've observed occasional hangs when using HF datasets and multiprocessing (huggingface/datasets#6393), but it should be fine, especially if we keep it single-process by default. Would be happy to accept a PR adding the arg.

@sanjari-orb
Contributor Author

Actually, we ended up seeing the same problem of the map() hanging while loading ICL evaluations with num_proc > 1, and unfortunately it happens frequently enough.
Do you have any insights into how this problem was solved in mosaicml?

@dakinggg
Collaborator

Unfortunately I have never managed to fully root cause this issue (feel free to comment on the datasets issue, as I don't think they have been able to fix it either). However, I believe it has something to do with multiple processes processing the same data at the same time. As a result, in the main dataloader we have local rank 0 go first, so that all the other ranks are just reading data cached on disk. We could probably apply the same logic in the ICL classes.

@sanjari-orb
Contributor Author

Could you give me a pointer to where this is being handled?

@dakinggg
Collaborator

Ah yeah sorry, meant to include the link.

    # Non local rank 0 ranks will wait here for local rank 0 to finish the data processing.
    # Once local rank 0 is done, the datasets are all cached on disk, and all other ranks
    # can just read them.
    if dist.get_local_rank() != 0:
        log.debug('Waiting for local_rank 0 to finish data prep')
        with dist.local_rank_zero_download_and_wait(signal_file_path):
            pass
for the wait, and
    # Now local rank 0 indicates to the other ranks that it is done
    if dist.get_local_rank() == 0:
        log.debug('Local rank 0 finished data prep')
        with open(signal_file_path, 'wb') as f:
            f.write(b'local_rank0_completed_data_prep')

    # All ranks sync up at this barrier, having completed data processing
    dist.barrier()

    # Last, local rank 0 cleans up the signal file
    if dist.get_local_rank() == 0:
        os.remove(signal_file_path)
for the cleanup. I added some nicer utils for this to composer also (mosaicml/composer#3396) but haven't updated foundry yet to use them.

@sanjari-orb
Contributor Author

We are already doing that here though, right?

    with dist.local_rank_zero_download_and_wait(destination_path):
        if dist.get_local_rank() == 0:
            get_file(dataset_uri, destination_path, overwrite=True)
    dataset = load_dataset(

@dakinggg
Collaborator

dakinggg commented Jun 21, 2024

Not quite. In the code I linked, we have rank 0 go first for the dataset load. In the code you linked, only rank 0 downloads the file, but then all ranks call load_dataset (and map) at the same time.
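
For concreteness, a hedged sketch of applying the rank-0-goes-first pattern to the ICL load; the function name, signal-file suffix, prep stub, and load_dataset arguments below are illustrative, not foundry's actual code:

    import os
    from composer.utils import dist
    from datasets import load_dataset

    def _prep_example_stub(example):
        # Stand-in for the real per-example prep; illustrative only.
        return example

    def load_icl_dataset_rank_zero_first(destination_path: str, num_proc: int = 1):
        signal_file_path = f'{destination_path}.prep_complete'

        # Non local rank 0 ranks wait here until local rank 0 has finished
        # loading and prepping; its results are then cached on disk.
        if dist.get_local_rank() != 0:
            with dist.local_rank_zero_download_and_wait(signal_file_path):
                pass

        # All ranks call load_dataset/map, but ranks != 0 only read the cache
        # written by local rank 0 (load_dataset args are illustrative).
        dataset = load_dataset('json', data_files=destination_path, split='train')
        dataset = dataset.map(_prep_example_stub, num_proc=num_proc)

        # Local rank 0 signals completion, all ranks sync, then rank 0 cleans up.
        if dist.get_local_rank() == 0:
            with open(signal_file_path, 'wb') as f:
                f.write(b'local_rank0_completed_data_prep')
        dist.barrier()
        if dist.get_local_rank() == 0:
            os.remove(signal_file_path)
        return dataset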

@sanjari-orb
Contributor Author

Ah gotcha. Okay let me try this. Thanks!
