Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(WIP)Add torch example. #415

Draft
wants to merge 19 commits into
base: develop/v2
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions aiaccel/torch/lightning/datamodules/single_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import Any

from collections.abc import Callable, Sized

from torch.utils.data import DataLoader, Dataset, Subset

import lightning as lt

from aiaccel.torch.datasets import scatter_dataset


class SingleDataModule(lt.LightningDataModule):
    """LightningDataModule that builds train/val datasets from factory callables.

    Args:
        train_dataset_fn: Callable returning the training dataset when invoked.
        val_dataset_fn: Callable returning the validation dataset when invoked.
        batch_size: Batch size used for both dataloaders.
        num_workers: Number of worker processes per dataloader.
        wrap_scatter_dataset: If True, wrap each dataset with
            ``scatter_dataset`` so it is sharded across DDP processes.
    """

    def __init__(
        self,
        train_dataset_fn: Callable[..., Dataset[str]],
        val_dataset_fn: Callable[..., Dataset[str]],
        batch_size: int,
        num_workers: int = 10,
        wrap_scatter_dataset: bool = True,
    ):
        super().__init__()

        self.train_dataset_fn = train_dataset_fn
        self.val_dataset_fn = val_dataset_fn

        # Options shared by both dataloaders. `shuffle` is intentionally NOT
        # here: only the training dataloader should shuffle (shuffling
        # validation data is wasteful and previously happened by accident).
        self.default_dataloader_kwargs = dict[str, Any](
            batch_size=batch_size,
            num_workers=num_workers,
            # DataLoader raises if persistent_workers=True with num_workers=0.
            persistent_workers=num_workers > 0,
        )

        self.wrap_scatter_dataset = wrap_scatter_dataset

    def setup(self, stage: str | None = None) -> None:
        """Instantiate the datasets for the "fit" stage.

        Lightning also calls ``setup`` for "validate"/"test"/"predict";
        those stages are no-ops here instead of raising, so that
        ``trainer.validate()`` etc. do not crash.
        """
        self.train_dataset: Dataset[str] | Subset[str]
        self.val_dataset: Dataset[str] | Subset[str]

        if stage != "fit":
            return

        if self.wrap_scatter_dataset:
            # Shard each dataset across distributed processes.
            self.train_dataset = scatter_dataset(self.train_dataset_fn())
            self.val_dataset = scatter_dataset(self.val_dataset_fn())
        else:
            self.train_dataset = self.train_dataset_fn()
            self.val_dataset = self.val_dataset_fn()

        # Only map-style (Sized) datasets can report a length.
        if isinstance(self.train_dataset, Sized) and isinstance(self.val_dataset, Sized):
            print(f"Dataset size: {len(self.train_dataset)=}, {len(self.val_dataset)=}")

    def train_dataloader(self) -> DataLoader[Any]:
        """Return the training dataloader (shuffled, incomplete last batch dropped)."""
        return DataLoader(
            self.train_dataset,
            shuffle=True,
            drop_last=True,
            **self.default_dataloader_kwargs,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return the validation dataloader (deterministic order, all samples kept)."""
        return DataLoader(
            self.val_dataset,
            shuffle=False,
            drop_last=False,
            **self.default_dataloader_kwargs,
        )
39 changes: 39 additions & 0 deletions docs/source/user_guide/torch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,45 @@ Writing a simple training script
Running inference
-----------------

To run aiaccel on ABCI 3.0, you need an environment with Python 3.10. This guide
explains how to set up such an environment using Singularity.

.. note::

For details on how to use Singularity, please refer to the following documentation:
https://docs.abci.ai/v3/en/containers/

Create the following Singularity definition file:

.. code-block:: bash
:caption: aiaccel_env.def

BootStrap: docker

From: python:3.10

%post

pip install --upgrade pip

# aiaccel env
pip install aiaccel[torch]@git+https://github.com/aistairc/aiaccel.git@develop/v2

# torch/MNIST example env
pip install torchvision

Use the Singularity definition file to build a Singularity image file:

.. code-block:: bash

singularity build aiaccel.sif aiaccel_env.def

Use the Singularity image file to execute aiaccel:

.. code-block:: bash

singularity exec --nv aiaccel.sif python -m aiaccel.torch.apps.train $wd/config.yaml --working_directory $wd

Writing a DDP training script
-----------------------------

Expand Down
70 changes: 70 additions & 0 deletions examples/torch/MNIST/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Training configuration: ResNet-50 fine-tuned on MNIST (single process).
trainer:
  logger: True
  max_epochs: 10
  callbacks:
    # Save a checkpoint at every epoch and keep all of them (save_top_k: -1).
    - _target_: lightning.pytorch.callbacks.ModelCheckpoint
      filename: "{epoch:04d}"
      save_last: True
      save_top_k: -1

task:
  # Task class defined in examples/torch/MNIST/torchvision_task.py.
  _target_: torchvision_task.Resnet50Task
  model:
    # Pretrained ResNet-50 backbone from torchvision.
    _target_: torchvision.models.resnet50
    weights:
      _target_: hydra.utils.get_object
      path: torchvision.models.ResNet50_Weights.DEFAULT
  optimizer_config:
    _target_: aiaccel.torch.lightning.OptimizerConfig
    optimizer_generator:
      # Partial: the task supplies the model parameters at runtime.
      _partial_: True
      _target_: torch.optim.AdamW
      lr: 1.e-4
  num_classes: 10

datamodule:
  _target_: aiaccel.torch.lightning.datamodules.single_datamodule.SingleDataModule
  train_dataset_fn:
    # Partial: the datamodule instantiates the dataset inside setup().
    _partial_: true
    _target_: torchvision.datasets.MNIST
    root: "./dataset"
    train: True
    download: True
    transform:
      _target_: torchvision.transforms.Compose
      transforms:
        # Upscale 28x28 MNIST digits to the 256x256 input ResNet expects.
        - _target_: torchvision.transforms.Resize
          size:
            - 256
            - 256
        # Replicate the single gray channel to 3 channels for the RGB backbone.
        - _target_: torchvision.transforms.Grayscale
          num_output_channels: 3
        - _target_: torchvision.transforms.ToTensor
        - _target_: torchvision.transforms.Normalize
          mean:
            - 0.5
          std:
            - 0.5
  val_dataset_fn:
    _partial_: true
    _target_: torchvision.datasets.MNIST
    root: "./dataset"
    train: False
    download: True
    transform:
      _target_: torchvision.transforms.Compose
      transforms:
        - _target_: torchvision.transforms.Resize
          size:
            - 256
            - 256
        - _target_: torchvision.transforms.Grayscale
          num_output_channels: 3
        - _target_: torchvision.transforms.ToTensor
        - _target_: torchvision.transforms.Normalize
          mean:
            - 0.5
          std:
            - 0.5
  batch_size: 128
  # Single-process run: no dataset sharding across DDP workers.
  wrap_scatter_dataset: False
73 changes: 73 additions & 0 deletions examples/torch/MNIST/config_ddp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Training configuration: ResNet-50 fine-tuned on MNIST with DDP over 8 GPUs.
trainer:
  accelerator: "gpu"
  devices: 8
  strategy: "ddp"
  logger: True
  max_epochs: 10
  callbacks:
    # Save a checkpoint at every epoch and keep all of them (save_top_k: -1).
    - _target_: lightning.pytorch.callbacks.ModelCheckpoint
      filename: "{epoch:04d}"
      save_last: True
      save_top_k: -1

task:
  # Task class defined in examples/torch/MNIST/torchvision_task.py.
  _target_: torchvision_task.Resnet50Task
  model:
    # Pretrained ResNet-50 backbone from torchvision.
    _target_: torchvision.models.resnet50
    weights:
      _target_: hydra.utils.get_object
      path: torchvision.models.ResNet50_Weights.DEFAULT
  optimizer_config:
    _target_: aiaccel.torch.lightning.OptimizerConfig
    optimizer_generator:
      # Partial: the task supplies the model parameters at runtime.
      _partial_: True
      _target_: torch.optim.AdamW
      lr: 1.e-4
  num_classes: 10

datamodule:
  _target_: aiaccel.torch.lightning.datamodules.single_datamodule.SingleDataModule
  train_dataset_fn:
    # Partial: the datamodule instantiates the dataset inside setup().
    _partial_: true
    _target_: torchvision.datasets.MNIST
    root: "./dataset"
    train: True
    download: True
    transform:
      _target_: torchvision.transforms.Compose
      transforms:
        # Upscale 28x28 MNIST digits to the 256x256 input ResNet expects.
        - _target_: torchvision.transforms.Resize
          size:
            - 256
            - 256
        # Replicate the single gray channel to 3 channels for the RGB backbone.
        - _target_: torchvision.transforms.Grayscale
          num_output_channels: 3
        - _target_: torchvision.transforms.ToTensor
        - _target_: torchvision.transforms.Normalize
          mean:
            - 0.5
          std:
            - 0.5
  val_dataset_fn:
    _partial_: true
    _target_: torchvision.datasets.MNIST
    root: "./dataset"
    train: False
    download: True
    transform:
      _target_: torchvision.transforms.Compose
      transforms:
        - _target_: torchvision.transforms.Resize
          size:
            - 256
            - 256
        - _target_: torchvision.transforms.Grayscale
          num_output_channels: 3
        - _target_: torchvision.transforms.ToTensor
        - _target_: torchvision.transforms.Normalize
          mean:
            - 0.5
          std:
            - 0.5
  batch_size: 128
  # NOTE(review): Lightning's DDP already shards via DistributedSampler;
  # presumably that is why scatter_dataset wrapping is disabled here — confirm.
  wrap_scatter_dataset: False
42 changes: 42 additions & 0 deletions examples/torch/MNIST/torchvision_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Any

from torch import Tensor, nn
from torch.nn import functional as func
from torch.utils.data import DataLoader

import torchmetrics

from aiaccel.torch.lightning import OptimizerConfig, OptimizerLightningModule


class Resnet50Task(OptimizerLightningModule):
    """Multiclass image-classification task around a torchvision ResNet model.

    Args:
        model: Backbone network; if it exposes an ``fc`` layer with an integer
            ``in_features`` (as torchvision ResNets do), that head is replaced
            with a fresh ``Linear`` producing ``num_classes`` logits.
        optimizer_config: Optimizer factory configuration consumed by
            :class:`OptimizerLightningModule`.
        num_classes: Number of target classes.
    """

    def __init__(self, model: nn.Module, optimizer_config: OptimizerConfig, num_classes: int = 10):
        super().__init__(optimizer_config)
        self.model = model
        # Swap the classification head so a pretrained backbone emits
        # `num_classes` logits instead of its original class count.
        if hasattr(self.model.fc, "in_features") and isinstance(self.model.fc.in_features, int):
            self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

        # BUG FIX: these previously hard-coded num_classes=10, silently
        # disagreeing with the `num_classes` constructor argument.
        self.train_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x: Any) -> Any:
        """Run the backbone and return its logits."""
        return self.model(x)

    def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Compute and log cross-entropy loss and accuracy for one training batch."""
        # `batch` is an (images, labels) pair, not a DataLoader (annotation fixed).
        x, y = batch
        logits = self(x)
        loss = func.cross_entropy(logits, y)

        acc = self.train_accuracy(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Compute and log cross-entropy loss and accuracy for one validation batch."""
        x, y = batch
        logits = self(x)
        loss = func.cross_entropy(logits, y)

        acc = self.val_accuracy(logits, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
18 changes: 18 additions & 0 deletions examples/torch/MNIST/train.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#! /bin/bash

# PBS job script: train the MNIST example (single process) with aiaccel
# inside a Singularity container on ABCI 3.0.
#PBS -q rt_HG
#PBS -l select=1
#PBS -l walltime=1:00:00
#PBS -P grpname
#PBS -j oe

# Run from the submission directory; abort instead of training in the wrong place.
cd "${PBS_O_WORKDIR}" || exit 1

source /etc/profile.d/modules.sh
module load cuda/12.6/12.6.1

source path_to_aiaccel_env/bin/activate

# Working directory containing config.yaml (quoted below to survive spaces).
wd=path_to_working_directory

singularity exec --nv path_to_python.sif python -m aiaccel.torch.apps.train "$wd/config.yaml" --working_directory "$wd"
18 changes: 18 additions & 0 deletions examples/torch/MNIST/train_ddp.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#! /bin/bash

# PBS job script: train the MNIST example with DDP (config_ddp.yaml) using
# aiaccel inside a Singularity container on ABCI 3.0.
#PBS -q rt_HF
#PBS -l select=1
#PBS -l walltime=1:00:00
#PBS -P grpname
#PBS -j oe

# Run from the submission directory; abort instead of training in the wrong place.
cd "${PBS_O_WORKDIR}" || exit 1

source /etc/profile.d/modules.sh
module load cuda/12.6/12.6.1

source path_to_aiaccel_env/bin/activate

# Working directory containing config_ddp.yaml (quoted below to survive spaces).
wd=path_to_working_directory

singularity exec --nv path_to_python.sif python -m aiaccel.torch.apps.train "$wd/config_ddp.yaml" --working_directory "$wd"
Loading