-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
611 changed files
with
108,041 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
#Run manually to reformat a file: | ||
#clang-format -i --style=file <file> | ||
BasedOnStyle: Google | ||
DerivePointerAlignment: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
[flake8] | ||
ignore = E203, E266, W503, E741 | ||
max-line-length = 120 | ||
per-file-ignores = __init__.py:F401 atorch/distributed/distributed.py:F401 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
[settings] | ||
multi_line_output=3 | ||
line_length=120 | ||
known_third_party = GPy,accelerate,agd,apex,data,datasets,deepspeed,distutils,dlrover,einops,evaluate,example_utils,fairscale,flash_attn,google,grpc,instruction_dataset_utils,matplotlib,model,modeling,networkx,numpy,packaging,pandas,peft,psutil,pymoo,pyomo,pytest,redis,safetensors,scipy,sklearn,tiktoken,torch,torch_npu,torchvision,tqdm,transformers,triton,typing_extensions,utils,yaml | ||
include_trailing_comma=True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
--- | ||
repos: | ||
- repo: https://github.com/pre-commit/mirrors-isort | ||
rev: v5.10.1 | ||
hooks: | ||
- id: isort | ||
exclude: _pb2.py|_pb2_grpc.py | ||
args: [--settings-path, atorch, "--profile", "black"] | ||
- repo: https://github.com/psf/black | ||
rev: 22.6.0 | ||
hooks: | ||
- id: black | ||
exclude: _pb2.py|_pb2_grpc.py | ||
args: [--line-length=120] | ||
- repo: https://github.com/pre-commit/pre-commit-hooks | ||
rev: v2.4.0 | ||
hooks: | ||
- id: flake8 | ||
exclude: __init__.py|_pb2.py|_pb2_grpc.py | ||
args: [ | ||
"--max-line-length=120", | ||
"--ignore=E721,W503,E203,E266,E741", | ||
] | ||
- repo: https://github.com/pre-commit/mirrors-mypy | ||
rev: v0.981 | ||
hooks: | ||
- id: mypy | ||
exclude: _pb2.py|_pb2_grpc.py|auto/engine/servicer.py | ||
args: [--ignore-missing-imports, --follow-imports=skip, --namespace-packages, --no-strict-optional, --show-error-codes] | ||
additional_dependencies: ["types_requests", "types-PyYAML"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Legal Disclaimer | ||
|
||
Within this source code, the comments in Chinese shall be the original, governing version. Any comment in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail. | ||
|
||
法律免责声明 | ||
|
||
关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
# ATorch | ||
<div id="top" align="center"> | ||
|
||
<img src="docs/img/atorch.png" alt="Editor" width="500"> | ||
|
||
ATorch: Make large model training more efficient and reproducible for everyone. | ||
|
||
|
||
|
||
[![GitHub Repo stars](https://img.shields.io/github/stars/intelligent-machine-learning/dlrover?style=social)](https://github.com/intelligent-machine-learning/dlrover/stargazers) | ||
[![Build](https://github.com/intelligent-machine-learning/dlrover/actions/workflows/main.yml/badge.svg)](https://github.com/intelligent-machine-learning/dlrover/actions/workflows/main.yml) | ||
[![PyPI Status Badge](https://badge.fury.io/py/atorch.svg)](https://pypi.org/project/atorch/) | ||
|
||
</div> | ||
|
||
|
||
## Table of Contents | ||
<ul> | ||
<li><a href="#Features">Features</a> </li> | ||
<li><a href="#Installation">Installation</a></li> | ||
<li><a href="#Getting-Started">Getting Started</a></li> | ||
<li><a href="#Contributing">Contributing</a></li> | ||
|
||
</ul> | ||
|
||
|
||
ATorch is an extension library of PyTorch developed by Ant Group's AI Infrastructure team. By decoupling model definition from training optimization strategy, ATorch supports efficient and easy-to-use model training experience. The design principle is to minimally disrupt the native PyTorch programming style. Through its API, ATorch provides performance optimizations in aspects such as I/O, preprocessing, computation, and communication (including automatic optimization). ATorch has supported large-scale pretraining of LLMs with over 100 billion parameters and thousands of A100/H100 GPUs. | ||
|
||
## Features | ||
|
||
![atorch_diagram](docs/img/atorch_fig.png) | ||
* Easy-to-use interface | ||
* [auto_accelerate](docs/auto_accelerate_api.md) API | ||
* ATorchTrainer (ongoing work) | ||
* Solutions for large-scale model training | ||
* support efficient large model initialization, checkpoint save/load, and restart with elastic resources. | ||
* Automatic/semi-automatic optimization | ||
* Acceleration Engine for automatic optimization | ||
* Semi-automatic optimization supports custom optimization | ||
* Hybrid parallelism support (arbitrary combination of fsdp/zero/ddp/tp/sp/pp) | ||
* High performance operators | ||
* Flash attention 2 with custom mask support | ||
* Transformer ops | ||
* High-performance MOE | ||
* sub-graph compilation | ||
* Checkpointing | ||
* Mixed precision | ||
* Communication optimization | ||
* Cached sharding | ||
* Effective optimizers for fast training convergence | ||
* [AGD optimizer](docs/README-AGD.md) | ||
* [WSAM optimizer](docs/README-WSAM.md) | ||
* IO/Preprocessing | ||
* CPU/GPU coworker to speedup data preprocessing | ||
* IO optimization for different dataset | ||
* Elastic and fault tolerance | ||
* Hardware error detection and migration (with dlrover) | ||
* GPU elastic training support | ||
* HangDetector (detecting and automatically restarting distributed training if it hangs) | ||
|
||
## Installation | ||
|
||
ATorch supports PyTorch with version >= 1.12, and version 2.1 or above is preferred. | ||
For example, you can use docker image <code>registry.cn-hangzhou.aliyuncs.com/atorch/atorch-open-20240430:pt210</code>) which has PyTorch 2.1 installed. | ||
|
||
### Install From PyPI | ||
Install atorch in any PyTorch-preinstalled environment (such as a container created with the docker image above) with <code>pip</code>: | ||
|
||
``` | ||
pip install atorch | ||
``` | ||
|
||
### Install From Source Files | ||
|
||
``` | ||
# clone repository | ||
git clone https://github.com/intelligent-machine-learning/dlrover.git | ||
cd dlrover/atorch | ||
# build package, optional set version. | ||
bash dev/scripts/build.sh [version] | ||
# install the created package in dist directory. Note that if version is set, file name is different. | ||
pip install dist/atorch-0.1.0.dev0-py3-none-any.whl | ||
``` | ||
|
||
|
||
## Getting Started | ||
|
||
### Run Examples | ||
|
||
|
||
- To run [auto_accelerate examples](examples/auto_accelerate): | ||
``` | ||
cd dlrover/atorch/examples/auto_accelerate | ||
# Single process train | ||
python train.py --model_type toy | ||
# Distributed train | ||
python -m atorch.distributed.run --nproc_per_node 2 train.py --model_type llama --distributed --load_strategy --use_fsdp --use_amp --use_module_replace --use_checkpointing | ||
``` | ||
|
||
- [Llama2 pretrain/finetune examples](examples/llama2) | ||
|
||
- [Optimizer (AGD, WSAM) Examples](examples/optimizer) | ||
|
||
### Documentations | ||
|
||
[auto_accelerate](docs/auto_accelerate_api.md) | ||
|
||
[AGD optimizer](docs/README-AGD.md) | ||
|
||
[WSAM optimizer](docs/README-WSAM.md) | ||
|
||
|
||
|
||
|
||
## Contributing | ||
Contributions are welcome! If you have any suggestions, ideas, or bug reports, please open an issue or submit a pull request. | ||
|
||
## CI/CD | ||
|
||
We leverage the power of [GitHub Actions](https://github.com/features/actions) to automate our development, release and deployment workflows. Please check out this [documentation](.github/workflows/README.md) on how the automated workflows are operated. | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import logging | ||
import os | ||
from importlib.metadata import version | ||
|
||
from .distributed.distributed import coworker_size, init_distributed, local_rank, rank, reset_distributed, world_size | ||
|
||
try: | ||
__version__ = version("atorch") | ||
except ImportError: | ||
__version__ = "0.0.1dev" | ||
|
||
os.environ["PIPPY_PIN_DEVICE"] = "0" | ||
|
||
# patch with atorch addon if exists and not disabled by ATORCH_DISABLE_ADDON env. | ||
disable_addon = False | ||
disable_addon_env = os.getenv("ATORCH_DISABLE_ADDON") | ||
if disable_addon_env is not None and disable_addon_env.lower() in ["true", "t", "1", "y", "yes"]: | ||
disable_addon = True | ||
|
||
if disable_addon: | ||
logging.warning("atorch_addon disabled by env ATORCH_DISABLE_ADDON.") | ||
|
||
addon_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "atorch_addon.py") | ||
|
||
if not disable_addon and os.path.exists(addon_file): | ||
try: | ||
import atorch.atorch_addon | ||
except ImportError: | ||
logging.warning("Failed to import atorch_addon!") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from atorch.common.log_utils import default_logger as logger | ||
|
||
try: | ||
from .amp import initialize, load_state_dict, master_params, scale_loss, state_dict | ||
from .hook import sample_list_to_type | ||
except ImportError: | ||
logger.info("Apex not available") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
from apex import amp | ||
|
||
|
||
def initialize( | ||
model, | ||
optimizers=None, | ||
enabled=True, | ||
opt_level="O1", | ||
keep_batchnorm_fp32=None, | ||
loss_scale=None, | ||
min_loss_scale=None, | ||
max_loss_scale=16777216.0, | ||
**kwargs, | ||
): | ||
""" | ||
Wrap `apex.amp.initialize`. | ||
Initialize your models, optimizers, and the Torch tensor and functional | ||
namespace according to the chosen opt_level and overridden properties, | ||
if any. | ||
Args: | ||
models: Models to modify/cast. | ||
optimizers: Optimizers to modify/cast. REQUIRED for training, optional | ||
for inference. | ||
enabled: If False, renders all Amp calls no-ops, so your script should | ||
run as if Amp were not present. | ||
opt_level: Pure or mixed precision optimization level. Accepted values | ||
are “O0”, “O1”, “O2”, and “O3”, explained in detail above. | ||
keep_batchnorm_fp32: Optional property override. If passed as a | ||
string, must be the string “True” or “False”. | ||
loss_scale: Optional property override. If passed as a string, must be | ||
a string representing a number, e.g., “128.0”, or the string | ||
“dynamic”. | ||
min_loss_scale: Sets a floor for the loss scale values that can be | ||
chosen by dynamic loss scaling. The default value of None means | ||
that no floor is imposed. If dynamic loss scaling is not used, | ||
min_loss_scale is ignored. | ||
max_loss_scale: Sets a ceiling for the loss scale values that can be | ||
chosen by dynamic loss scaling. If dynamic loss scaling is not | ||
used, max_loss_scale is ignored. | ||
Returns: | ||
Model(s) and optimizer(s) modified according to the opt_level. If | ||
either the models or optimizers args were lists, the corresponding | ||
return value will also be a list. | ||
""" | ||
return amp.initialize( | ||
model, | ||
optimizers=optimizers, | ||
enabled=enabled, | ||
opt_level=opt_level, | ||
keep_batchnorm_fp32=keep_batchnorm_fp32, | ||
loss_scale=loss_scale, | ||
min_loss_scale=min_loss_scale, | ||
max_loss_scale=max_loss_scale, | ||
**kwargs, | ||
) | ||
|
||
|
||
def scale_loss(loss, optimizers, **kwargs): | ||
""" | ||
Wrap `apex.amp.scale_loss`. | ||
On context manager entrance, creates scaled_loss = (loss.float())*current | ||
loss scale. scaled_loss is yielded so that the user can call | ||
scaled_loss.backward(): | ||
with amp.scale_loss(loss, optimizer) as scaled_loss: | ||
scaled_loss.backward() | ||
On context manager exit (if delay_unscale=False), the gradients are | ||
checked for infs/NaNs and unscaled, so that optimizer.step() can be | ||
called. | ||
""" | ||
return amp.scale_loss(loss, optimizers, **kwargs) | ||
|
||
|
||
def master_params(optimizer): | ||
""" | ||
Wrap `apex.amp.master_params`. | ||
Generator expression that iterates over the params owned by optimizer. | ||
Returns: | ||
optimizer: An optimizer previously returned from amp.initialize. | ||
""" | ||
yield from amp.master_params(optimizer) | ||
|
||
|
||
def state_dict(destination=None): | ||
""" | ||
Wrap `apex.amp.state_dict`. | ||
To properly save and load your amp training, amp.state_dict() contains | ||
all loss_scalers. | ||
""" | ||
return amp.state_dict(destination=destination) | ||
|
||
|
||
def load_state_dict(state_dict): | ||
""" | ||
Wrap `apex.amp.load_state_dict`. | ||
""" | ||
return amp.load_state_dict(state_dict) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from typing import Dict, List | ||
|
||
from apex.amp import _initialize | ||
from torch import Tensor | ||
|
||
to_type_original = _initialize.to_type | ||
|
||
|
||
def sample_list_to_type(dtype, t): | ||
""" | ||
Hook `_initialize.to_type`. Original `to_type` only handle the case | ||
that `t` is a torch.Tensor. `sample_list_to_type` can also handle | ||
the case that t is a list or a dict. | ||
""" | ||
if isinstance(t, Dict): | ||
for k, v in t.items(): | ||
if isinstance(v, Tensor): | ||
if v.is_floating_point(): | ||
t[k] = v.to(dtype) | ||
return t | ||
elif isinstance(t, List): | ||
for i, elem in enumerate(t): | ||
if isinstance(elem, Tensor): | ||
if elem.is_floating_point(): | ||
t[i] = elem.to(dtype) | ||
return t | ||
else: | ||
return to_type_original(dtype, t) | ||
|
||
|
||
_initialize.to_type = sample_list_to_type |
Oops, something went wrong.