Showing 22 changed files with 323 additions and 26 deletions.
@@ -10,3 +10,4 @@ such as output naming, auxiliary outputs, and wrapper models.
output-naming
auxiliary-outputs
multi-gpu
@@ -0,0 +1,27 @@
Multi-GPU training
==================

Some of the architectures in metatensor-models support multi-GPU training.
In multi-GPU training, every batch of samples is split into smaller
mini-batches, and the computation for each mini-batch is run in parallel
on a different GPU. The gradients obtained on each device are then summed.
This approach reduces the time it takes to train models.
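The splitting-and-summing scheme above can be checked with a toy example: for a sum-of-squared-errors loss, the gradient over the full batch equals the sum of the gradients over its mini-batches. A plain-Python sketch (no GPUs involved, toy numbers chosen for illustration):

```python
# Toy model: loss(w) = sum((w*x - y)**2) over the batch,
# so d(loss)/dw = sum(2*x*(w*x - y)).
def grad(w, batch):
    return sum(2 * x * (w * x - y) for x, y in batch)

batch = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 7.0)]
w = 0.5

# Single-device gradient over the full batch
full = grad(w, batch)

# "Two GPUs": split the batch, compute a gradient per device, sum them
mini_batches = [batch[:2], batch[2:]]
summed = sum(grad(w, mb) for mb in mini_batches)

assert summed == full  # per-device gradients sum to the full-batch gradient
```

This identity is what makes the parallelization exact rather than approximate: each device sees fewer samples, but the summed update is the same as the single-device one.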

Here is a list of architectures supporting multi-GPU training:

SOAP-BPNN
---------

SOAP-BPNN supports distributed multi-GPU training in SLURM environments.
The options file to run distributed training with the SOAP-BPNN model looks
like this:

.. literalinclude:: ../../../examples/multi-gpu/soap-bpnn/options-distributed.yaml
   :language: yaml

and the SLURM submission script would look like this:

.. literalinclude:: ../../../examples/multi-gpu/soap-bpnn/submit-distributed.sh
   :language: shell
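For orientation, a distributed SLURM submission script for this kind of training typically requests one task per GPU and launches the trainer with ``srun``. The resource numbers and the ``mtt train`` entry point below are illustrative assumptions, not the actual contents of ``submit-distributed.sh``:

```shell
#!/bin/bash
#SBATCH --job-name=soap-bpnn-distributed
#SBATCH --nodes=2                # illustrative: two nodes
#SBATCH --ntasks-per-node=4      # one task (training process) per GPU
#SBATCH --gpus-per-node=4
#SBATCH --time=02:00:00

# srun starts one training process per task; each process reads its rank
# and world size from the SLURM environment (SLURM_PROCID, SLURM_NTASKS, ...).
srun mtt train options-distributed.yaml
```

The key design point is that SLURM, not the training code, decides process placement: one task per GPU lets each process bind to exactly one device.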
@@ -0,0 +1 @@
../../../tests/distributed/options-distributed.yaml
@@ -0,0 +1 @@
../../../tests/distributed/submit-distributed.sh
17 changes: 17 additions & 0 deletions
src/metatrain/utils/distributed/distributed_data_parallel.py
@@ -0,0 +1,17 @@
import torch


class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel):
    """
    DistributedDataParallel wrapper that inherits from
    :py:class:`torch.nn.parallel.DistributedDataParallel`
    and adds the ``outputs`` attribute of the wrapped module to it.

    :param module: The module to be parallelized.
    :param args: Arguments to be passed to the parent class.
    :param kwargs: Keyword arguments to be passed to the parent class.
    """

    def __init__(self, module: torch.nn.Module, *args, **kwargs):
        super().__init__(module, *args, **kwargs)
        self.outputs = module.outputs
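The point of this subclass is that a DDP wrapper does not expose attributes of the module it wraps, so code that reads ``model.outputs`` would fail on the wrapped object. The forwarding pattern can be illustrated without torch, using stand-in classes (all names here are hypothetical, for illustration only):

```python
class StandInDDP:
    """Stand-in for torch.nn.parallel.DistributedDataParallel: stores the
    wrapped module but does not expose the module's own attributes."""

    def __init__(self, module, *args, **kwargs):
        self.module = module


class ForwardingDDP(StandInDDP):
    """Same pattern as the subclass above: copy the wrapped module's
    ``outputs`` attribute onto the wrapper itself."""

    def __init__(self, module, *args, **kwargs):
        super().__init__(module, *args, **kwargs)
        self.outputs = module.outputs


class ToyModel:
    outputs = {"energy": "per-structure scalar"}


plain = StandInDDP(ToyModel())       # has no .outputs attribute
wrapped = ForwardingDDP(ToyModel())  # .outputs is forwarded
```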
@@ -0,0 +1,8 @@
from .slurm import is_slurm, is_slurm_main_process


def is_main_process():
    if is_slurm():
        return is_slurm_main_process()
    else:
        return True
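The ``slurm`` helpers imported above are not shown in this commit. A plausible sketch, assuming they inspect SLURM's standard environment variables (``SLURM_JOB_ID`` marks a SLURM allocation; ``SLURM_PROCID`` is the task's global rank), would be:

```python
import os


def is_slurm():
    # Inside a SLURM job, SLURM_JOB_ID is set for every task.
    return "SLURM_JOB_ID" in os.environ


def is_slurm_main_process():
    # By convention, the task with global rank 0 is the main process.
    return os.environ.get("SLURM_PROCID", "0") == "0"


def is_main_process():
    if is_slurm():
        return is_slurm_main_process()
    else:
        return True
```

A typical use is guarding side effects so that, under distributed training, only rank 0 writes logs and checkpoints while every other rank trains silently.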