Merge branch 'master' into zhejiang/fix_runtime_dataloader_shuffle
loadams authored Sep 27, 2024
2 parents d40ba2e + d4e1895 commit 1787888
Showing 4 changed files with 203 additions and 3 deletions.
59 changes: 59 additions & 0 deletions .github/workflows/xpu-compile.yml
@@ -0,0 +1,59 @@
name: xpu-compile

on:
workflow_dispatch:
schedule:
- cron: "0 0 * * *"
pull_request:
paths:
- ".github/workflows/xpu-compile.yml"

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

permissions:
contents: read
issues: write

jobs:
compile-tests:
runs-on: [self-hosted, intel, xpu]
container:
image: intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04
ports:
- 80
options: --privileged -it --rm --device /dev/dri:/dev/dri -v /dev/dri/by-path:/dev/dri/by-path --ipc=host --cap-add=ALL

steps:
- uses: actions/checkout@v4
- name: Install prerequisite
run: |
apt-get update
apt-get install clinfo libaio-dev python3-pip -y
pip install torch==2.3.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/torch/
pip install intel-extension-for-pytorch==2.3.110+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/intel-extension-for-pytorch/
pip install oneccl_bind_pt==2.3.100+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/oneccl-bind-pt/
pip install torchvision==0.18.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/torchvision/
pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v3.0.0b2/triton_xpu-3.0.0b2-cp310-cp310-linux_x86_64.whl
pip install py-cpuinfo numpy
pip install .[dev,autotuning]
- name: Check container state
run: |
ldd --version
ds_report
python3 -c "import torch; print('torch:', torch.__version__, torch)"
python3 -c "import torch; import intel_extension_for_pytorch; print('XPU available:', torch.xpu.is_available())"
python3 -c "from deepspeed.accelerator import get_accelerator; print('accelerator:', get_accelerator()._name)"
pip list
- name: Compile Status
shell: bash
run: |
export FI_HMEM=system
ulimit -n 1048575
cd tests/torch_compile
export ZE_AFFINITY_MASK=0,1
deepspeed test_compile.py --deepspeed_config ds_config.json 2>&1 | tee log.txt
cat log.txt | grep "'graph_breaks'" | sed 's/,/ /g' | awk '{print $2}' >> $GITHUB_STEP_SUMMARY
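For reference, the last line of the "Compile Status" step scrapes the dynamo stats line printed by test_compile.py (a collections.Counter repr that contains 'graph_breaks') and appends its second whitespace-delimited field to the job summary. A name-based extraction of the same counter, as a minimal Python sketch (the log path comes from the workflow above; the helper name is hypothetical):

import re


def graph_breaks_from_log(path="log.txt"):
    # Scan for the printed Counter line and pull the integer that follows 'graph_breaks'.
    with open(path) as f:
        for line in f:
            match = re.search(r"'graph_breaks': (\d+)", line)
            if match:
                return int(match.group(1))
    return None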
7 changes: 4 additions & 3 deletions deepspeed/inference/engine.py
@@ -73,8 +73,9 @@ def __init__(self, model, config):
if hasattr(self.module, "config"):
TransformerPolicy.hf_model_config = self.module.config

-        if config.dtype == torch.half and not get_accelerator().is_fp16_supported():
-            raise ValueError("Type fp16 is not supported.")
+        if config.dtype not in get_accelerator().supported_dtypes():
+            raise ValueError(
+                f"Data type {config.dtype} is not supported by {get_accelerator().device_name()} accelerator")

# todo: keep this self.injection_dict because we don't use to change config.injection_policy API
# todo: this will get changed when Molly's PR on auto injection dict is merged
@@ -324,7 +325,7 @@ def _validate_args(self, mpu, replace_with_kernel_inject):
if self._config.checkpoint is not None and not isinstance(self._config.checkpoint, (str, dict)):
raise ValueError(f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}")

-        supported_dtypes = [None, torch.half, torch.int8, torch.float]
+        supported_dtypes = [None, torch.half, torch.int8, torch.float, torch.bfloat16]
if self._config.dtype not in supported_dtypes:
raise ValueError(f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")

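The inference engine now defers dtype validation to the accelerator abstraction instead of hard-coding an fp16 check. A minimal sketch of the new check, assuming an accelerator whose supported_dtypes() returns a list such as [torch.float, torch.half, torch.bfloat16]:

import torch
from deepspeed.accelerator import get_accelerator

dtype = torch.bfloat16
# Raises only when the requested dtype is missing from the accelerator's list,
# so bf16 inference can pass on accelerators (such as XPU) that report it.
if dtype not in get_accelerator().supported_dtypes():
    raise ValueError(f"Data type {dtype} is not supported by {get_accelerator().device_name()} accelerator")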
41 changes: 41 additions & 0 deletions tests/torch_compile/ds_config.json
@@ -0,0 +1,41 @@
{
"train_batch_size": 8,
"steps_per_print": 2000,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.001,
"betas": [
0.8,
0.999
],
"eps": 1e-8,
"weight_decay": 3e-7
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 0.001,
"warmup_num_steps": 1000
}
},
"gradient_clipping": 1.0,
"prescale_gradients": false,
"bf16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 500,
"hysteresis": 2,
"min_loss_scale": 1,
"initial_scale_power": 15
},
"wall_clock_breakdown": false,
"zero_optimization": {
"stage": 3,
"reduce_scatter": true,
"overlap_comm": false,
"contiguous_gradients": false
}
}
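This config drives the compile test below: ZeRO stage 3 with bf16 enabled, which matches the RandomDataset tensors being cast to torch.bfloat16. test_compile.py receives the path through --deepspeed_config; an equivalent direct call, sketched under the assumption that deepspeed.initialize also accepts this file as a path via its config argument:

import deepspeed
import torch

model = torch.nn.Linear(1024, 256, bias=False)
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config="ds_config.json")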
99 changes: 99 additions & 0 deletions tests/torch_compile/test_compile.py
@@ -0,0 +1,99 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import argparse
import deepspeed
from deepspeed.accelerator import get_accelerator
from deepspeed import comm

import torch
import intel_extension_for_pytorch # noqa: F401 # type: ignore
from torch.utils.data import Dataset, DataLoader

# Raise dynamo's per-code-object recompilation limit so repeated recompiles
# (e.g. from ZeRO-3 hooks) do not push torch.compile back to eager execution.
torch._dynamo.config.cache_size_limit = 100

import collections


def get_dynamo_stats():
# TODO: consider deepcopy'ing the entire counters struct and
# adding a helper to do subtraction on it
return collections.Counter({
"calls_captured": torch._dynamo.utils.counters["stats"]["calls_captured"],
"unique_graphs": torch._dynamo.utils.counters["stats"]["unique_graphs"],
"graph_breaks": sum(torch._dynamo.utils.counters["graph_break"].values()),
# NB: The plus removes zero counts
"unique_graph_breaks": len(+torch._dynamo.utils.counters["graph_break"]),
"autograd_captures": torch._dynamo.utils.counters["compiled_autograd"]["captures"],
"autograd_compiles": torch._dynamo.utils.counters["compiled_autograd"]["compiles"],
"cudagraph_skips": torch._dynamo.utils.counters["inductor"]["cudagraph_skips"],
})


class RandomDataset(Dataset):

def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size).to(torch.bfloat16)

def __getitem__(self, index):
return self.data[index]

def __len__(self):
return self.len


data_size = 1024
data_length = 100
rand_loader = DataLoader(dataset=RandomDataset(data_size, data_length), batch_size=1, shuffle=False)


class MyModule(torch.nn.Module):

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.fc0 = torch.nn.Linear(1024, 256, bias=False)
self.fc1 = torch.nn.Linear(256, 256, bias=False)
self.dropout = torch.nn.Dropout(0.5)

def forward(self, data, residual):
output = residual + self.fc1(self.fc0(self.dropout(data))) * 0.5
return output


model = MyModule()
params = model.parameters()

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher')
parser.add_argument('--deepspeed_config',
type=str,
default='ds_config.json',
help='path to DeepSpeed configuration file')
cmd_args = parser.parse_args()

# initialize the DeepSpeed engine
model_engine, optimizer, _, _ = deepspeed.initialize(args=cmd_args, model=model, model_parameters=params)
model_engine.compile()

residual = torch.rand(256, 256, dtype=torch.float).to(get_accelerator().current_device_name())

start_stats = get_dynamo_stats()

for step, batch in enumerate(rand_loader):
if step % 10 == 0 and comm.get_rank() == 0:
print(f'step={step}')
# forward() method
loss = model_engine(batch.to(get_accelerator().current_device_name()), residual).sum()
# runs backpropagation
model_engine.backward(loss)
# weight update
model_engine.step()

dynamo_stats = get_dynamo_stats()
dynamo_stats.subtract(start_stats)

if comm.get_rank() == 0:
print(dynamo_stats)
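
As wired up in the workflow above, this script is launched with "deepspeed test_compile.py --deepspeed_config ds_config.json". On rank 0 the final print emits the Counter whose line the "Compile Status" step greps into the job summary; fewer graph breaks indicate that torch.compile captured larger graphs across the ZeRO-3 forward/backward/step loop.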
