Merge branch 'main' into units
frostedoyster authored Jun 13, 2024
2 parents 07fab42 + 5eae102 commit 27974b0
Showing 22 changed files with 323 additions and 26 deletions.
2 changes: 2 additions & 0 deletions .codecov.yml
@@ -10,5 +10,7 @@ coverage:
- "tests/.*"
- "examples/.*"
- "src/metatrain/experimental/.*"
- "src/metatrain/utils/distributed/.*"


comment: false
1 change: 1 addition & 0 deletions docs/src/advanced-concepts/index.rst
@@ -10,3 +10,4 @@ such as output naming, auxiliary outputs, and wrapper models.

output-naming
auxiliary-outputs
multi-gpu
27 changes: 27 additions & 0 deletions docs/src/advanced-concepts/multi-gpu.rst
@@ -0,0 +1,27 @@
Multi-GPU training
==================

Some of the architectures in metatensor-models support multi-GPU training.
In multi-GPU training, every batch of samples is split into smaller
mini-batches, and the computation for each mini-batch is run in parallel on a
different GPU. The gradients obtained on each device are then summed across
devices before the optimizer step. This approach reduces the wall-clock time
needed to train a model.
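
To make the mechanism concrete, below is a minimal, generic PyTorch sketch of
this data-parallel pattern; it is not metatensor-models code, and the model,
dataset and optimizer are stand-ins:

.. code-block:: python

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

    # one process per GPU; the launcher (torchrun, srun, ...) provides rank and world size
    dist.init_process_group(backend="nccl")
    device = torch.device("cuda", dist.get_rank() % torch.cuda.device_count())

    model = torch.nn.Linear(8, 1).to(device)   # stand-in for a real model
    model = DDP(model, device_ids=[device])    # gradients are all-reduced across GPUs
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    dataset = TensorDataset(torch.randn(256, 8), torch.randn(256, 1))
    sampler = DistributedSampler(dataset)      # each process sees its own shard of the data
    loader = DataLoader(dataset, batch_size=8, sampler=sampler)

    for epoch in range(10):
        sampler.set_epoch(epoch)               # reshuffle consistently across processes
        for x, y in loader:
            optimizer.zero_grad()
            loss = torch.nn.functional.mse_loss(model(x.to(device)), y.to(device))
            loss.backward()                    # DDP synchronizes gradients here
            optimizer.step()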

Here is a list of architectures supporting multi-GPU training:


SOAP-BPNN
---------

SOAP-BPNN supports distributed multi-GPU training in SLURM environments.
The options file to run distributed training with the SOAP-BPNN model looks
like this:

.. literalinclude:: ../../../examples/multi-gpu/soap-bpnn/options-distributed.yaml
:language: yaml

and the SLURM submission script would look like this:

.. literalinclude:: ../../../examples/multi-gpu/soap-bpnn/submit-distributed.sh
:language: shell
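
The two example files referenced above are added by this commit under
``examples/multi-gpu/soap-bpnn/`` (collapsed in this view). At a minimum, the
``training`` section of the options has to switch on the flags introduced in
this commit's default hypers; a sketch of just those keys (the surrounding
nesting of a full options file is omitted here):

.. code-block:: yaml

    training:
      distributed: True
      distributed_port: 39591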
1 change: 1 addition & 0 deletions examples/multi-gpu/soap-bpnn/options-distributed.yaml
1 change: 1 addition & 0 deletions examples/multi-gpu/soap-bpnn/submit-distributed.sh
1 change: 1 addition & 0 deletions pyproject.toml
@@ -14,6 +14,7 @@ dependencies = [
"metatensor-operations==0.2.1",
"metatensor-torch==0.5.1",
"omegaconf",
"python-hostlist",
"torch",
]

2 changes: 1 addition & 1 deletion src/metatrain/__main__.py
@@ -81,7 +81,7 @@ def main():
if callable == "train_model":
# define and create `checkpoint_dir` based on current directory, date and time
checkpoint_dir = _datetime_output_path(now=datetime.now())
os.makedirs(checkpoint_dir)
os.makedirs(checkpoint_dir, exist_ok=True) # exist_ok=True for distributed
args.checkpoint_dir = checkpoint_dir

log_file = checkpoint_dir / "train.log"
4 changes: 4 additions & 0 deletions src/metatrain/cli/train.py
@@ -22,6 +22,7 @@
)
from ..utils.data.dataset import _train_test_random_split
from ..utils.devices import pick_devices
from ..utils.distributed.logging import is_main_process
from ..utils.errors import ArchitectureError
from ..utils.io import check_suffix
from ..utils.omegaconf import (
@@ -394,6 +395,9 @@ def train_model(
except Exception as e:
raise ArchitectureError(e)

if not is_main_process():
return # only save and evaluate on the main process

###########################
# SAVE FINAL MODEL ########
###########################
2 changes: 2 additions & 0 deletions src/metatrain/experimental/soap_bpnn/default-hypers.yaml
@@ -20,6 +20,8 @@ model:
num_neurons_per_layer: 32

training:
distributed: False
distributed_port: 39591
batch_size: 8
num_epochs: 100
learning_rate: 0.001
16 changes: 11 additions & 5 deletions src/metatrain/experimental/soap_bpnn/tests/test_regression.py
@@ -42,8 +42,8 @@ def test_regression_init():
)

# if you need to change the hardcoded values:
torch.set_printoptions(precision=5)
print(output["mtt::U0"].block().values)
# torch.set_printoptions(precision=12)
# print(output["mtm::U0"].block().values)

torch.testing.assert_close(
output["mtt::U0"].block().values, expected_output, rtol=1e-5, atol=1e-5
@@ -90,12 +90,18 @@ def test_regression_train():
)

expected_output = torch.tensor(
[[-40.56458], [-56.51794], [-76.49743], [-77.32737], [-93.40791]]
[
[-40.592571258545],
[-56.522350311279],
[-76.571365356445],
[-77.384849548340],
[-93.445365905762],
]
)

# if you need to change the hardcoded values:
# torch.set_printoptions(precision=5)
# print(output["mtt::U0"].block().values)
# torch.set_printoptions(precision=12)
# print(output["mtm::U0"].block().values)

torch.testing.assert_close(
output["mtt::U0"].block().values, expected_output, rtol=1e-5, atol=1e-5
121 changes: 103 additions & 18 deletions src/metatrain/experimental/soap_bpnn/trainer.py
@@ -4,7 +4,8 @@
from typing import List, Union

import torch
from metatensor.learn.data import DataLoader
import torch.distributed
from torch.utils.data import DataLoader, DistributedSampler

from ...utils.composition import calculate_composition_weights
from ...utils.data import (
@@ -15,6 +16,8 @@
get_all_targets,
)
from ...utils.data.extract_targets import get_targets_dict
from ...utils.distributed.distributed_data_parallel import DistributedDataParallel
from ...utils.distributed.slurm import DistributedEnvironment
from ...utils.evaluate_model import evaluate_model
from ...utils.external_naming import to_external_name
from ...utils.logging import MetricLogger
@@ -46,18 +49,45 @@ def train(
val_datasets: List[Union[Dataset, torch.utils.data.Subset]],
checkpoint_dir: str,
):
dtype = train_datasets[0][0]["system"].positions.dtype

# only one device, as we don't support multi-gpu for now
assert len(devices) == 1
device = devices[0]
is_distributed = self.hypers["distributed"]

if is_distributed:
distr_env = DistributedEnvironment(self.hypers["distributed_port"])
torch.distributed.init_process_group(backend="nccl")
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
else:
rank = 0

if is_distributed:
if len(devices) > 1:
raise ValueError(
"Requested distributed training with the `multi-gpu` device. "
" If you want to run distributed training with SOAP-BPNN, please "
"set `device` to cuda."
)
# the calculation of the device number works both when GPUs on different
# processes are not visible to each other and when they are
device_number = distr_env.local_rank % torch.cuda.device_count()
device = torch.device("cuda", device_number)
else:
device = devices[
0
] # only one device, as we don't support multi-gpu for now
dtype = train_datasets[0][0]["system"].positions.dtype

logger.info(f"Training on device {device} with dtype {dtype}")
if is_distributed:
logger.info(f"Training on {world_size} devices with dtype {dtype}")
else:
logger.info(f"Training on device {device} with dtype {dtype}")
model.to(device=device, dtype=dtype)
if is_distributed:
model = DistributedDataParallel(model, device_ids=[device])

# Calculate and set the composition weights for all targets:
logger.info("Calculating composition weights")
for target_name in model.new_outputs:
for target_name in (model.module if is_distributed else model).new_outputs:
if "mtt::aux::" in target_name:
continue
# TODO: document transfer learning and say that outputs that are already
@@ -80,8 +110,8 @@
raise ValueError(
"Supplied atomic types are not present in the dataset."
)
model.set_composition_weights(
target_name, fixed_weights, list(atomic_types)
(model.module if is_distributed else model).set_composition_weights(
target_name, fixed_weights, atomic_types
)

else:
@@ -97,40 +127,75 @@
composition_weights, composition_types = calculate_composition_weights(
train_datasets_with_target, target_name
)
model.set_composition_weights(
(model.module if is_distributed else model).set_composition_weights(
target_name, composition_weights, composition_types
)

logger.info("Setting up data loaders")

if is_distributed:
train_samplers = [
DistributedSampler(
train_dataset,
num_replicas=world_size,
rank=rank,
shuffle=True,
drop_last=True,
)
for train_dataset in train_datasets
]
val_samplers = [
DistributedSampler(
val_dataset,
num_replicas=world_size,
rank=rank,
shuffle=False,
drop_last=False,
)
for val_dataset in val_datasets
]
else:
train_samplers = [None] * len(train_datasets)
val_samplers = [None] * len(val_datasets)

# Create dataloader for the training datasets:
train_dataloaders = []
for dataset in train_datasets:
for dataset, sampler in zip(train_datasets, train_samplers):
train_dataloaders.append(
DataLoader(
dataset=dataset,
batch_size=self.hypers["batch_size"],
shuffle=True,
sampler=sampler,
shuffle=(
sampler is None
), # the sampler takes care of this (if present)
drop_last=(
sampler is None
), # the sampler takes care of this (if present)
collate_fn=collate_fn,
)
)
train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True)

# Create dataloader for the validation datasets:
val_dataloaders = []
for dataset in val_datasets:
for dataset, sampler in zip(val_datasets, val_samplers):
val_dataloaders.append(
DataLoader(
dataset=dataset,
batch_size=self.hypers["batch_size"],
sampler=sampler,
shuffle=False,
drop_last=False,
collate_fn=collate_fn,
)
)
val_dataloader = CombinedDataLoader(val_dataloaders, shuffle=False)

# Extract all the possible outputs and their gradients:
train_targets = get_targets_dict(train_datasets, model.dataset_info)
train_targets = get_targets_dict(
train_datasets, (model.module if is_distributed else model).dataset_info
)
outputs_list = []
for target_name, target_info in train_targets.items():
outputs_list.append(target_name)
@@ -179,6 +244,9 @@ def train(
# Train the model:
logger.info("Starting training")
for epoch in range(self.hypers["num_epochs"]):
if is_distributed:
sampler.set_epoch(epoch)

train_rmse_calculator = RMSEAccumulator()
val_rmse_calculator = RMSEAccumulator()

@@ -207,12 +275,18 @@
targets = average_by_num_atoms(targets, systems, per_structure_targets)

train_loss_batch = loss_fn(predictions, targets)
train_loss += train_loss_batch.item()
train_loss_batch.backward()
optimizer.step()

if is_distributed:
# sum the loss over all processes
torch.distributed.all_reduce(train_loss_batch)
train_loss += train_loss_batch.item()
train_rmse_calculator.update(predictions, targets)
finalized_train_info = train_rmse_calculator.finalize(
not_per_atom=["positions_gradients"] + per_structure_targets
not_per_atom=["positions_gradients"] + per_structure_targets,
is_distributed=is_distributed,
device=device,
)

val_loss = 0.0
@@ -238,10 +312,16 @@
targets = average_by_num_atoms(targets, systems, per_structure_targets)

val_loss_batch = loss_fn(predictions, targets)

if is_distributed:
# sum the loss over all processes
torch.distributed.all_reduce(val_loss_batch)
val_loss += val_loss_batch.item()
val_rmse_calculator.update(predictions, targets)
finalized_val_info = val_rmse_calculator.finalize(
not_per_atom=["positions_gradients"] + per_structure_targets
not_per_atom=["positions_gradients"] + per_structure_targets,
is_distributed=is_distributed,
device=device,
)

lr_scheduler.step(val_loss)
@@ -264,10 +344,15 @@
metric_logger.log(
metrics=[finalized_train_info, finalized_val_info],
epoch=epoch,
rank=rank,
)

if epoch % self.hypers["checkpoint_interval"] == 0:
model.save_checkpoint(Path(checkpoint_dir) / f"model_{epoch}.ckpt")
if is_distributed:
torch.distributed.barrier()
(model.module if is_distributed else model).save_checkpoint(
Path(checkpoint_dir) / f"model_{epoch}.ckpt"
)

# early stopping criterion:
if val_loss < best_val_loss:
17 changes: 17 additions & 0 deletions src/metatrain/utils/distributed/distributed_data_parallel.py
@@ -0,0 +1,17 @@
import torch


class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel):
"""
DistributedDataParallel wrapper that inherits from
:py:class:`torch.nn.parallel.DistributedDataParallel`
and adds the ``outputs`` attribute of the wrapped module to it.

:param module: The module to be parallelized.
:param args: Arguments to be passed to the parent class.
:param kwargs: Keyword arguments to be passed to the parent class
"""

def __init__(self, module: torch.nn.Module, *args, **kwargs):
super(DistributedDataParallel, self).__init__(module, *args, **kwargs)
self.outputs = module.outputs
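
For reference, this is how the wrapper is used by the SOAP-BPNN trainer changes
in this commit (an illustrative snippet, not part of the file above): the model
is moved to its device first, then wrapped, and methods that
DistributedDataParallel does not forward, such as save_checkpoint, are reached
through .module:

model = model.to(device=device, dtype=dtype)
model = DistributedDataParallel(model, device_ids=[device])
outputs = model.outputs  # forwarded by this subclass
model.module.save_checkpoint("model_0.ckpt")  # not forwarded, so go through .module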
8 changes: 8 additions & 0 deletions src/metatrain/utils/distributed/logging.py
@@ -0,0 +1,8 @@
from .slurm import is_slurm, is_slurm_main_process


def is_main_process():
if is_slurm():
return is_slurm_main_process()
else:
return True
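
The ``is_slurm`` and ``is_slurm_main_process`` helpers are defined in
``slurm.py``, which is not expanded in this view. A plausible minimal sketch,
based only on standard SLURM environment variables and not on that file's
actual contents, would be:

import os


def is_slurm():
    # assumption: SLURM exports SLURM_JOB_ID for every task it launches
    return "SLURM_JOB_ID" in os.environ


def is_slurm_main_process():
    # SLURM_PROCID is the global rank of the task; rank 0 is the main process
    return os.environ.get("SLURM_PROCID", "0") == "0"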