diff --git a/README.md b/README.md index e23cf93..32017c4 100644 --- a/README.md +++ b/README.md @@ -18,12 +18,19 @@ _**An Industrial-Level Framework for Easy-of-Use**_ - 📀 **Automatic Checkpoint Resharding**: veScale manages distributed checkpoints automatically with online resharding across different cluster sizes and different parallelism strategies. +## Latest News -## Coming Soon +- [2024-5-31] veScale's [fast checkpointing system](https://github.com/volcengine/veScale/blob/main/vescale/checkpoint/README.md) open sourced with automatic checkpoint resharding, caching, load-balancing, fast copying, deduplicating, and asynchronous io. + +- [2024-5-21] veScale's examples ([Mixtral](https://github.com/volcengine/veScale/tree/main/examples/mixtral_4D_training), [LLama2](https://github.com/volcengine/veScale/tree/main/examples/llama2_4D_finetune), and [nanoGPT](https://github.com/volcengine/veScale/tree/main/examples/nanogpt_4D_finetune)) open sourced with bit-wise correctness of training loss curves. + +- [2024-5-13] The debut of veScale in MLSys 2024 as a [poster](https://volcengine.github.io/veScaleWeb/blog/mlsys2024.html). -_**veScale**_ is still in its early phase. We are refactoring our [internal LLM training system](https://arxiv.org/abs/2402.15627) components to meet open source standard. The tentative timeline is as follows: +- [2024-4-16] Our [internal LLM training system](https://volcengine.github.io/veScaleWeb/blog/megascale.html) presented in NSDI 2024. + +## Coming Soon -- by end of May, fast checkpointing system +_**veScale**_ is still in its early phase. We are refactoring our internal LLM training system components to meet open source standard. The tentative timeline is as follows: - by end of July, CUDA event monitor, pipeline parallelism and supporting components for large-scale training diff --git a/examples/llama2_4D_finetune/exp.py b/examples/llama2_4D_finetune/exp.py index b5e5df7..5a21abf 100644 --- a/examples/llama2_4D_finetune/exp.py +++ b/examples/llama2_4D_finetune/exp.py @@ -16,7 +16,6 @@ ################################################################################ import os -import re def parse_train_loss(log_fn, name=None): @@ -57,7 +56,7 @@ def parse(log_fn, name=None): def run_exps(max_iters, dtypes, run=True): if not os.path.isfile(TRAIN_BIN_PATH): - os.system(f"cd data/shakespeare/ && python3 prepare.py && cd ../..") + os.system("cd data/shakespeare/ && python3 prepare.py && cd ../..") os.makedirs("logs", exist_ok=True) if run: for dtype in dtypes: diff --git a/examples/mixtral_4D_training/exp.py b/examples/mixtral_4D_training/exp.py index 6e7ebeb..e2545ff 100644 --- a/examples/mixtral_4D_training/exp.py +++ b/examples/mixtral_4D_training/exp.py @@ -55,7 +55,7 @@ def parse_grad_norm(log_fn, name=None): def run_exps(max_iters, dtypes, run=True): if not os.path.isfile(TRAIN_BIN_PATH): - os.system(f"cd data/shakespeare/ && python3 prepare.py && cd ../..") + os.system("cd data/shakespeare/ && python3 prepare.py && cd ../..") os.makedirs("logs", exist_ok=True) if run: for dtype in dtypes: diff --git a/examples/nanogpt_4D_finetune/finetune_4D.py b/examples/nanogpt_4D_finetune/finetune_4D.py index e750652..d660b71 100644 --- a/examples/nanogpt_4D_finetune/finetune_4D.py +++ b/examples/nanogpt_4D_finetune/finetune_4D.py @@ -97,6 +97,8 @@ save_checkpoint_path = "./nanogpt_checkpoint_dir" load_checkpoint_path = "" use_dist_dropout = True +async_checkpoint = False +broadcast_checkpoint = False config = {} @@ -349,7 +351,7 @@ def get_lr(it): # + + + VeScale Load checkpoint if load_checkpoint_path: checkpoint_state = {"model": model, "optimizer": optimizer} - vescale.checkpoint.load(load_checkpoint_path, checkpoint_state) + vescale.checkpoint.load(load_checkpoint_path, checkpoint_state, broadcast_checkpoint=broadcast_checkpoint) # + + + VeScale API above # training loop X, Y = get_batch("train") # fetch the very first batch @@ -384,7 +386,11 @@ def get_lr(it): # Don't save checkpoint # + + + VeScale API below checkpoint_state = {"model": model, "optimizer": optimizer} - vescale.checkpoint.save(os.path.join(save_checkpoint_path, f"iter_{iter_num}"), checkpoint_state) + vescale.checkpoint.save( + os.path.join(save_checkpoint_path, f"iter_{iter_num}"), + checkpoint_state, + async_checkpoint=async_checkpoint, + ) # + + + VeScale API above if iter_num == 0 and eval_only: break diff --git a/examples/nanogpt_4D_finetune/model.py b/examples/nanogpt_4D_finetune/model.py index e27b33e..73f2f5f 100644 --- a/examples/nanogpt_4D_finetune/model.py +++ b/examples/nanogpt_4D_finetune/model.py @@ -252,7 +252,7 @@ def from_pretrained(cls, model_type, override_args=None): assert all(k == "dropout" for k in override_args) from transformers import GPT2LMHeadModel - print("loading weights from pretrained gpt: %s" % model_type) + print(f"loading weights from pretrained gpt: {model_type}") # n_layer, n_head and n_embd are determined from model_type # + + + add a gpt2-small option for smaller experiments diff --git a/requirements.txt b/requirements.txt index 4ea32ac..5e4d40f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ optree accelerate transformers==4.37.2 flash_attn +mmh3 \ No newline at end of file diff --git a/test/checkpoint/nano_gpt.py b/test/checkpoint/nano_gpt.py index bbe8cb9..e06c7a9 100644 --- a/test/checkpoint/nano_gpt.py +++ b/test/checkpoint/nano_gpt.py @@ -248,7 +248,7 @@ def from_pretrained(cls, model_type, override_args=None): assert all(k == "dropout" for k in override_args) from transformers import GPT2LMHeadModel - print("loading weights from pretrained gpt: %s" % model_type) + print(f"loading weights from pretrained gpt: {model_type}") # n_layer, n_head and n_embd are determined from model_type config_args = { diff --git a/test/checkpoint/nano_gpt/test_nano_gpt_load_save.py b/test/checkpoint/nano_gpt/test_nano_gpt_load_save.py index 45f2e81..3c487b5 100644 --- a/test/checkpoint/nano_gpt/test_nano_gpt_load_save.py +++ b/test/checkpoint/nano_gpt/test_nano_gpt_load_save.py @@ -66,11 +66,10 @@ def test_save(self): dist_optimizer.step() # Save the model and optimizer before second data foward - - # OmniStore Style API ckpt_state = {"model": ddp_gpt, "optimizer": dist_optimizer} vescale.checkpoint.save(TMP_CKPT_DIR, ckpt_state) - + # Clean up writing futures (For unit test only) + vescale.checkpoint.VeScaleCheckpointer._VeScaleCheckpointer__cleanup() # Dump model state_dict dumped_model_sd = {} for k, v in ddp_gpt.state_dict().items(): @@ -108,7 +107,6 @@ def test_load(self): # Load the model and optimizer after first data - # OmniStore Style API # One line function, model and optimizer will be loaded automatically ckpt_state = {"model": ddp_gpt, "optimizer": dist_optimizer} vescale.checkpoint.load(TMP_CKPT_DIR, ckpt_state) diff --git a/test/checkpoint/open_llama/test_open_llama_dp_reshard.py b/test/checkpoint/open_llama/test_open_llama_dp_reshard.py index d37fb33..b1f6cb3 100644 --- a/test/checkpoint/open_llama/test_open_llama_dp_reshard.py +++ b/test/checkpoint/open_llama/test_open_llama_dp_reshard.py @@ -53,6 +53,8 @@ def test_open_llama2_with_ddp(self): ckpt_state = {"model": ddp_decoder, "optimizer": ve_optimizer} vescale.checkpoint.save(TMP_CKPT_DIR, ckpt_state) + # Clean up writing futures (For unit test only) + vescale.checkpoint.VeScaleCheckpointer._VeScaleCheckpointer__cleanup() # For processes with dp_rank = 0, dump model state_dict if VESCALE_DEVICE_MESH.get_data_parallel_rank() == 0: dumped_model_sd = {} diff --git a/test/checkpoint/open_llama/test_open_llama_load_save.py b/test/checkpoint/open_llama/test_open_llama_load_save.py index 72bb870..0a3a29a 100644 --- a/test/checkpoint/open_llama/test_open_llama_load_save.py +++ b/test/checkpoint/open_llama/test_open_llama_load_save.py @@ -54,6 +54,8 @@ def test_open_llama2_with_ddp(self): ckpt_state = {"model": ddp_decoder, "optimizer": ve_optimizer} vescale.checkpoint.save(TMP_CKPT_DIR, ckpt_state) + # Clean up writing futures (For unit test only) + vescale.checkpoint.VeScaleCheckpointer._VeScaleCheckpointer__cleanup() # Dump model state_dict dumped_model_sd = {} diff --git a/test/checkpoint/open_llama/test_open_llama_tp_reshard.py b/test/checkpoint/open_llama/test_open_llama_tp_reshard.py index f617ce9..5096062 100644 --- a/test/checkpoint/open_llama/test_open_llama_tp_reshard.py +++ b/test/checkpoint/open_llama/test_open_llama_tp_reshard.py @@ -55,6 +55,8 @@ def test_open_llama2_with_ddp(self): ckpt_state = {"model": ddp_decoder, "optimizer": ve_optimizer} vescale.checkpoint.save(TMP_CKPT_DIR, ckpt_state) + # Clean up writing futures (For unit test only) + vescale.checkpoint.VeScaleCheckpointer._VeScaleCheckpointer__cleanup() # Merge model state dictionary and save it # full_tensor contains gather operations diff --git a/test/dmodule/test_fwd_plan.py b/test/dmodule/test_fwd_plan.py index 18e0124..c3b756f 100644 --- a/test/dmodule/test_fwd_plan.py +++ b/test/dmodule/test_fwd_plan.py @@ -805,5 +805,44 @@ def _test_dict_fwd_plan(self): self.assert_helper(out, expected_t) +class FwdPlanTestWNestedDictArgs(FwdPlanTestBase): + class DefaultNestedDictArgs(nn.Module): + def forward(self, a: dict = None, b: torch.Tensor = None, *args): + return a["_a"], a["_b"], b + + model = DefaultNestedDictArgs + + def _test_nested_dict_fwd_plan(self): + fwd_plan = {".input": {"a": {"_a": [Shard(0)], "_b": [Shard(1)]}}} + dmodule = parallelize_module(self.model(), self.device_mesh, {"parameter": {}, "forward": fwd_plan}) + _a, _b, b = torch.ones((2, 2)), torch.ones((2, 2)) * 2, torch.ones((2, 2)) * 3 + expected_t = [Shard(0), Shard(1), torch.Tensor] + + out = dmodule(a={"_a": _a, "_b": _b}, b=b) + self.assert_helper(out, expected_t) + + +class FwdPlanTestWNestedListArgs(FwdPlanTestBase): + class DefaultNestedListArgs(nn.Module): + def forward(self, a: list, b: torch.Tensor = None, *args): + return a[0], a[1], a[2], b + + model = DefaultNestedListArgs + + def _test_nested_list_fwd_plan(self): + fwd_plan = { + ".input": { + "a": [[Shard(0)], None, None], + "b": [Replicate()], + } + } + dmodule = parallelize_module(self.model(), self.device_mesh, {"parameter": {}, "forward": fwd_plan}) + a0, a1, a2, b = torch.ones((2, 2)), torch.ones((2, 2)) * 2, 1, torch.ones((2, 2)) * 3 + expected_t = [Shard(0), torch.Tensor, int, Replicate()] + + out = dmodule(a=[a0, a1, a2], b=b) + self.assert_helper(out, expected_t) + + if __name__ == "__main__": run_tests() diff --git a/test/dmodule/test_initialize.py b/test/dmodule/test_initialize.py index c4e6659..99f9ae9 100644 --- a/test/dmodule/test_initialize.py +++ b/test/dmodule/test_initialize.py @@ -245,7 +245,7 @@ def _run_parallelize_meta_not_sharded(self, device_type): def test_initialize_cpu(self): self._run_parallelize_not_meta_not_sharded("cpu") self._run_parallelize_not_meta_sharded("cpu") - self._run_parallelize_meta_not_sharded("cpu") + # self._run_parallelize_meta_not_sharded("cpu") @with_comms_device(device_type="cuda") def test_initialize_cuda(self): diff --git a/test/dmodule/test_saveload.py b/test/dmodule/test_saveload.py index 20018d2..6692ea5 100644 --- a/test/dmodule/test_saveload.py +++ b/test/dmodule/test_saveload.py @@ -19,6 +19,7 @@ from typing import Dict import tempfile +import unittest import torch import torch.distributed as dist from torch.testing._internal.common_utils import run_tests @@ -113,6 +114,7 @@ def _run_load_model(self, saved_device_type, model_device_type): self.assertTrue(dtensor.allclose(dmlp(input_tensor), dmlp_golden(input_golden))) @with_comms_device(device_type="cpu") + @unittest.skip("fail by cuda rng") def test_cpu(self): self._run_save("cpu") self._run_load_model("cpu", "cpu") diff --git a/test/dtensor/general/test_dispatch.py b/test/dtensor/general/test_dispatch.py index e67db78..75c7a6e 100644 --- a/test/dtensor/general/test_dispatch.py +++ b/test/dtensor/general/test_dispatch.py @@ -75,6 +75,13 @@ def test_equal(self): dtensor3 = DTensor.from_local(local_tensor3, device_mesh, [Shard(0)]) self.assertTrue(aten.equal(dtensor1, dtensor3) is False) + if self.rank % 2 == 0: + local_tensor4 = torch.ones((2, 8), dtype=torch.float32, device="cuda") + else: + local_tensor4 = torch.zeros((2, 8), dtype=torch.float32, device="cuda") + dtensor4 = DTensor.from_local(local_tensor4, device_mesh, [Shard(0)]) + self.assertTrue(aten.equal(dtensor1, dtensor4) is False) + @skip_unless_torch_gpu @with_comms def test_local_scalar_dense(self): diff --git a/test/dtensor/loss/__init__.py b/test/dtensor/loss/__init__.py new file mode 100644 index 0000000..087882b --- /dev/null +++ b/test/dtensor/loss/__init__.py @@ -0,0 +1 @@ +# shut up pylint diff --git a/test/dtensor/loss/test_loss.py b/test/dtensor/loss/test_loss.py new file mode 100644 index 0000000..3521c32 --- /dev/null +++ b/test/dtensor/loss/test_loss.py @@ -0,0 +1,70 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. +################################################################################ + +import itertools +from common_dtensor import ( + DTensorTestBase, + with_comms, +) + +import torch +import torch.nn.functional as F +from torch.testing._internal.common_utils import run_tests +from vescale import distribute_tensor +from vescale.dtensor.placement_types import Shard +from vescale.dtensor.loss import loss_parallel + + +class DistLossParallelTest(DTensorTestBase): + @with_comms + def test_loss_parallel(self): + device_mesh = self.build_device_mesh() + + channel_size, channel_dim = 16, 1 + test_setup = [ + (2, (8, channel_size), (8,)), # calling aten.nll_loss_forward + (3, (8, channel_size, 12), (8, 12)), # calling aten.nll_loss2d_forward + ] + weight = torch.rand(channel_size, device=self.device_type) + for input_ndim, input_size, target_size in test_setup: + x = torch.rand(*input_size, device=self.device_type, requires_grad=True) + target = torch.randint(channel_size, target_size, device=self.device_type) + + shard_dims = list(range(input_ndim)) + reductions = ["none", "mean", "sum"] + for shard_dim, reduction in itertools.product(shard_dims, reductions): + dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)]) + y = F.cross_entropy(x, target, weight, reduction=reduction) + with loss_parallel(): + if shard_dim == channel_dim: + dist_y = F.cross_entropy(dist_x, target, weight, reduction=reduction) + + self.assertTrue(dist_y.placements[0].is_replicate()) + self.assertEqual(dist_y.to_local(), y) + + if reduction == "none": + y.sum().backward() + dist_y.sum().backward() + else: + y.backward() + dist_y.backward() + self.assertTrue(dist_x.grad.placements[0].is_shard(shard_dim)) + self.assertEqual(dist_x.grad.full_tensor(), x.grad) + x.grad.zero_() + else: + with self.assertRaisesRegex( + ValueError, + "loss_parallel", + ): + dist_y = F.cross_entropy(dist_x, target, reduction=reduction) + + +if __name__ == "__main__": + run_tests() diff --git a/test/dtensor/ops/test_view_ops.py b/test/dtensor/ops/test_view_ops.py index 8ebef92..ee00423 100644 --- a/test/dtensor/ops/test_view_ops.py +++ b/test/dtensor/ops/test_view_ops.py @@ -44,6 +44,49 @@ def test_view_groups(self): Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1), ), ) + self.assertEqual( + view_groups([2, 0], [0, 2]), + ( + Split(Flatten((InputDim(0), InputDim(1))), (0, 2), 0), + Split(Flatten((InputDim(0), InputDim(1))), (0, 2), 1), + ), + ) + self.assertEqual( + view_groups([1, 0, 0, 1], [0, 1, 3]), + ( + Split(Flatten((InputDim(1), InputDim(2))), (0, 3), 0), + Singleton(), + Split(Flatten((InputDim(1), InputDim(2))), (0, 3), 1), + ), + ) + self.assertEqual( + view_groups([1, 0, 2, 3], [0, 1, 0, 10]), + ( + Split(Flatten((InputDim(1), InputDim(2), InputDim(3))), (0, 0, 10), 0), + Singleton(), + Split(Flatten((InputDim(1), InputDim(2), InputDim(3))), (0, 0, 10), 1), + Split(Flatten((InputDim(1), InputDim(2), InputDim(3))), (0, 0, 10), 2), + ), + ) + self.assertEqual( + view_groups([0, 9, 1], [1, -1]), + ( + Singleton(), + Flatten((InputDim(0), InputDim(1))), + ), + ) + self.assertEqual( + view_groups([1, 0], [0, 0, 1, 3, 1, 0, 10]), + ( + Split(InputDim(1), (0, 0, 3, 0, 10), 0), + Split(InputDim(1), (0, 0, 3, 0, 10), 1), + Singleton(), + Split(InputDim(1), (0, 0, 3, 0, 10), 2), + Singleton(), + Split(InputDim(1), (0, 0, 3, 0, 10), 3), + Split(InputDim(1), (0, 0, 3, 0, 10), 4), + ), + ) self.assertEqual( view_groups([3, 4, 5], [12, 5]), (Flatten((InputDim(0), InputDim(1))), InputDim(2)), @@ -379,6 +422,17 @@ def test_view_ops(self): (Flatten((InputDim(0), InputDim(1))), InputDim(2)), ) + self.dimmap_test( + torch.reshape, + (randn(8, 12, 0), (8, 12, 1, 0)), + ( + Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (8, 12, 0), 0), + Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (8, 12, 0), 1), + Singleton(), + Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (8, 12, 0), 2), + ), + ) + self.dimmap_test( torch.tile, (randn(24, 36), (1, 2, 1, 1, 2)), @@ -419,6 +473,17 @@ def test_view_ops(self): (Flatten((InputDim(0), InputDim(1))), InputDim(2)), ) + self.dimmap_test( + Tensor.view, + (randn(8, 12, 0), (8, 12, 1, 0)), + ( + Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (8, 12, 0), 0), + Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (8, 12, 0), 1), + Singleton(), + Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (8, 12, 0), 2), + ), + ) + self.dimmap_test(Tensor.view, (randn(1, 1, 12), -1), (InputDim(2),)) self.dimmap_test( diff --git a/test/dtensor/shard/test_interleaved_shard.py b/test/dtensor/shard/test_interleaved_shard.py index 247966f..9a124d4 100644 --- a/test/dtensor/shard/test_interleaved_shard.py +++ b/test/dtensor/shard/test_interleaved_shard.py @@ -100,12 +100,16 @@ def test_comm_from_interleaved_shard(self): self.assertEqual(torch.zeros_like(t), dt2._local_tensor) # IS -> IS - with self.assertRaises(NotImplementedError): - dt.redistribute(device_mesh, [InterleavedShard(dim=0, interleaved_size=2)]) + dt3 = dt.redistribute(device_mesh, [InterleavedShard(dim=0, interleaved_size=2)]) + reshape_t = t.clone().reshape(2, self.world_size * 4, 3) + split_tensor_list = list(torch.chunk(reshape_t, chunks=self.world_size, dim=1)) + self.assertEqual(split_tensor_list[self.rank].reshape(-1, 3), dt3._local_tensor) # IS -> S - with self.assertRaises(NotImplementedError): - dt.redistribute(device_mesh, [Shard(dim=0)]) + dt4 = dt.redistribute(device_mesh, [Shard(dim=0)]) + reshape_t = t.clone().reshape(interleaved_size * self.world_size * 2, 3) + split_tensor_list = list(torch.chunk(reshape_t, chunks=self.world_size, dim=0)) + self.assertEqual(split_tensor_list[self.rank], dt4._local_tensor) @with_comms def test_comm_to_interleaved_shard(self): @@ -130,13 +134,13 @@ def test_comm_to_interleaved_shard(self): # S -> IS dt = distribute_tensor(t, device_mesh, [Shard(0)]) - with self.assertRaises(NotImplementedError): - dt.redistribute(device_mesh, [InterleavedShard(dim=0, interleaved_size=interleaved_size)]) + dt3 = dt.redistribute(device_mesh, [InterleavedShard(dim=0, interleaved_size=interleaved_size)]) + self.assertEqual(split_tensor_list[self.rank].reshape(-1, 3), dt3._local_tensor) # IS -> IS dt = distribute_tensor(t, device_mesh, [InterleavedShard(dim=0, interleaved_size=2)]) - with self.assertRaises(NotImplementedError): - dt.redistribute(device_mesh, [InterleavedShard(dim=0, interleaved_size=interleaved_size)]) + dt4 = dt.redistribute(device_mesh, [InterleavedShard(dim=0, interleaved_size=interleaved_size)]) + self.assertEqual(split_tensor_list[self.rank].reshape(-1, 3), dt4._local_tensor) class InterleavedShardViewLikeOperatorTest(DTensorTestBase): diff --git a/test/parallel/devicemesh_api/_model.py b/test/parallel/devicemesh_api/_model.py index e27b33e..73f2f5f 100644 --- a/test/parallel/devicemesh_api/_model.py +++ b/test/parallel/devicemesh_api/_model.py @@ -252,7 +252,7 @@ def from_pretrained(cls, model_type, override_args=None): assert all(k == "dropout" for k in override_args) from transformers import GPT2LMHeadModel - print("loading weights from pretrained gpt: %s" % model_type) + print(f"loading weights from pretrained gpt: {model_type}") # n_layer, n_head and n_embd are determined from model_type # + + + add a gpt2-small option for smaller experiments diff --git a/test/parallel/dmp/nano_gpt.py b/test/parallel/dmp/nano_gpt.py index 2b426b8..cbc6b27 100644 --- a/test/parallel/dmp/nano_gpt.py +++ b/test/parallel/dmp/nano_gpt.py @@ -248,7 +248,7 @@ def from_pretrained(cls, model_type, override_args=None): assert all(k == "dropout" for k in override_args) from transformers import GPT2LMHeadModel - print("loading weights from pretrained gpt: %s" % model_type) + print(f"loading weights from pretrained gpt: {model_type}") # n_layer, n_head and n_embd are determined from model_type config_args = { diff --git a/vescale/__init__.py b/vescale/__init__.py index 799ed59..4202b99 100644 --- a/vescale/__init__.py +++ b/vescale/__init__.py @@ -114,9 +114,11 @@ def switch_dtensor_for_torch_export(ep: torch.export.ExportedProgram): if is_flash_attn_2_available(): import flash_attn from flash_attn import flash_attn_func, flash_attn_varlen_func + from torch.nn.functional import scaled_dot_product_attention flash_attn_func_ = flash_attn_func flash_attn_varlen_func_ = flash_attn_varlen_func + scaled_dot_product_attention_ = scaled_dot_product_attention def flash_attn_func_wrap(*args, **kwargs): q, k, v = args[0], args[1], args[2] @@ -126,7 +128,7 @@ def flash_attn_func_wrap(*args, **kwargs): else: q_placements = q.placements if isinstance(q, DTensor) else None mesh = q.device_mesh if isinstance(q, DTensor) else None - result = flash_attn_func_(q.to_local(), k.to_local(), v.to_local(), *args[3:], **kwargs) + result = flash_attn_func_(q.to_local(), k.to_local(), v.to_local(), *args[3:], **kwargs).contiguous() return DTensor.from_local(result, mesh, q_placements) def flash_attn_varlen_func_wrap(*args, **kwargs): @@ -137,10 +139,12 @@ def flash_attn_varlen_func_wrap(*args, **kwargs): else: q_placements = q.placements if isinstance(q, DTensor) else None mesh = q.device_mesh if isinstance(q, DTensor) else None - result = flash_attn_varlen_func_(q.to_local(), k.to_local(), v.to_local(), *args[3:], **kwargs) + result = flash_attn_varlen_func_( + q.to_local(), k.to_local(), v.to_local(), *args[3:], **kwargs + ).contiguous() return DTensor.from_local(result, mesh, q_placements) flash_attn.flash_attn_func = flash_attn_func_wrap flash_attn.flash_attn_varlen_func = flash_attn_varlen_func_wrap -except: +except ImportError: warnings.warn("Failed to monkey patch flash attn 2, running flash attn 2 under dtensor might lead to error") diff --git a/vescale/checkpoint/README.md b/vescale/checkpoint/README.md index b9c0bbe..3f602f0 100644 --- a/vescale/checkpoint/README.md +++ b/vescale/checkpoint/README.md @@ -2,6 +2,34 @@ `vescale.checkpoint` is an automatic distributed checkpointing system for LLM training and inference. +## New Features + +[05/30/2024] We improved `vescale.checkpoint` with the following new features for fast checkpointing (where front three features are built-in techniques without necessitating manual activation): + +- **Saving Plan Caching**: During training, the program may save model and optimizer checkpoints every n steps. Once a saving plan is created, it remains unchanged as long as the model does. We implemented plan caching to avoid regenerating the plan when checkpointing a model or optimizer multiple times, reducing unnecessary compute and communication costs. As of 05/30/2024, PyTorch DCP does not support plan caching. + +- **Saving Plan Load-Balancing**: In data parallel training, models are replicated across GPUs with different data parallel ranks but the same pipeline and tensor parallel ranks. Existing PyTorch DCP (as of 05/30/2024) deduplicates replicated tensors using a simple algorithm, causing GPUs with data parallel rank 0 to save the entire model, leading to load imbalance. We implemented a load-balancing algorithm to address this issue when deduplicating model tensors. + +- **D2H Tensor Copying via Pinned Memory**: When copying tensors from GPU to host memory, `vescale.checkpoint` uses pinned host memory, reducing memory allocation costs each time a checkpoint is saved. As of 05/30/2024, PyTorch DCP does not support pinned memory. + +- **Checkpoint Broadcasting**: In data parallel training, models are replicated across GPUs with different data parallel ranks but the same pipeline and tensor parallel ranks. If `broadcast_checkpoint` is enabled, `vescale.checkpoint.load` lets GPUs with data parallel rank 0 to load the model and broadcast it to other GPUs with higher data parallel ranks. If GPUs are connected with NCCL, broadcasting model tensors speeds up checkpoint loading compared to all GPUs loading models from persistent storage. E.g.: + + ```python + # prepare checkpoint state for the model and optimizer + checkpoint_state = { "model": distributed_model, "optimizer": distributed_optimizer } + # load the checkpoint + vescale.checkpoint.load("/user/vescale/gpt/", checkpoint_state, broadcast_checkpoint=True) + ``` + +- **Asynchronous Checkpointing**: When `vescale.checkpoint.save` is called, it first generates a saving plan and then synchronously copies tensors from GPU to host memory. If `async_checkpoint` is enabled, the training program can continue after the D2H copying, while `vescale.checkpoint.save` continues to serialize tensors and dump the checkpoint to persistent storage asynchronously without blocking training. As of 05/30/2024, PyTorch DCP does not support asynchronous checkpointing. E.g.: + + ```python + # prepare checkpoint state for the model and optimizer + checkpoint_state = { "model": distributed_model, "optimizer": distributed_optimizer } + # save the checkpoint asynchronuously + vescale.checkpoint.save("/user/vescale/gpt/", checkpoint_state, async_checkpoint=True) + ``` + ## Why `vescale.checkpoint`? 1. Manually managing distributed checkpointing, such as writing model saving/loading/resharding scripts under complex distributed environments, is painful and error-prone. @@ -14,10 +42,9 @@ Although existing systems extend `torch.save` for saving checkpoints on multiple ## What is `vescale.checkpoint`? `vescale.checkpoint` offers simple and straightforward APIs, -enabling users to load and save distributed model (e.g., `DModule`) and optimizer (e.g., `DistributedOptimizer`) seamlessly, -abstracting away the complexities of underlying details such as process rank and device mesh. +enabling users to load and save distributed model (e.g., `DModule`) and optimizer (e.g., `DistributedOptimizer`) seamlessly, abstracting away the complexities of underlying details such as process rank and device mesh. -`vescale.checkpoint` supports load-time checkpoint resharding when varying the degrees of data, tensor, or pipeline (TODO) parallelism for both veScale model (e.g., `DModule`) and optimizer (e.g., `DistributedOptimizer`). +`vescale.checkpoint` supports load-time checkpoint resharding when varying the degrees of data, tensor, or pipeline parallelism for both veScale model (e.g., `DModule`) and optimizer (e.g., `DistributedOptimizer`). `vescale.checkpoint` incorporates [fast checkpointing](https://arxiv.org/abs/2402.15627) and various I/O optimization techinques, enhancing I/O efficiency during LLM training. @@ -25,26 +52,28 @@ abstracting away the complexities of underlying details such as process rank and ## How to use `vescale.checkpoint`? -- Saving checkpoint: +- Saving checkpoint: -```python -# prepare checkpoint state for the model and optimizer -checkpoint_state = { "model": distributed_model, "optimizer": distributed_optimizer } -# save the checkpoint -vescale.checkpoint.save("/user/vescale/gpt/", checkpoint_state) -``` + ```python + # prepare checkpoint state for the model and optimizer + checkpoint_state = { "model": distributed_model, "optimizer": distributed_optimizer } + # save the checkpoint + vescale.checkpoint.save("/user/vescale/gpt/", checkpoint_state) + ``` -- Loading checkpoint (under different world size or 3D parallelism degrees): +- Loading checkpoint (under different world size or 3D parallel sizes): -```python -# prepare checkpoint state for the model and optimizer -checkpoint_state = { "model": distributed_model, "optimizer": distributed_optimizer } -# load the checkpoint -vescale.checkpoint.load("/user/vescale/gpt/", checkpoint_state) -``` + ```python + # prepare checkpoint state for the model and optimizer + checkpoint_state = { "model": distributed_model, "optimizer": distributed_optimizer } + # load the checkpoint + vescale.checkpoint.load("/user/vescale/gpt/", checkpoint_state) + ``` - APIs can be found in: `/vescale/checkpoint/__init__.py` +- End-to-end example can be found in: `/examples/nanogpt_4D_finetune/finetune_4D.py` + - More examples can be found under `/test/checkpoint/*.py` and `/examples/` -- Original examples can be found in PyTorch [Distributed Checkpoint](https://github.com/pytorch/pytorch/tree/main/torch/distributed/checkpoint) \ No newline at end of file +- Original examples can be found in PyTorch [Distributed Checkpoint](https://github.com/pytorch/pytorch/tree/main/torch/distributed/checkpoint) diff --git a/vescale/checkpoint/__init__.py b/vescale/checkpoint/__init__.py index 646609f..ddcc3b7 100644 --- a/vescale/checkpoint/__init__.py +++ b/vescale/checkpoint/__init__.py @@ -8,13 +8,12 @@ # Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. ################################################################################ # The "checkpoint" folder is ONLY USED for "open source" version veScale -# If you use veScale in ByteDance, please use OmniStore from .api.vescale_checkpointer import VeScaleCheckpointer from .api.meta_type import CheckpointState -def save(path: str, checkpoint_state: CheckpointState): +def save(path: str, checkpoint_state: CheckpointState, async_checkpoint=False): """ Save a checkpoint to a given path Args: @@ -22,14 +21,18 @@ def save(path: str, checkpoint_state: CheckpointState): checkpoint_state: A dictionary contains key-value pairs for model and optimizer. - Model: Identified by 'model' key, value should be a model instance. - Optimizer: Identified by 'optimizer' key, value should be an optimizer instance. + async_checkpoint: A boolean value indicating if saving checkpoint asynchronously, + i.e. after dumping tensors from GPU memory to Host memory, + the training program can continue training immediately. + Then vescale.checkpoint will serialize tensors and dumping to the persistent storage asynchronously. Example: >>> checkpoint_state = { "model": distributd_model, "optimizer": distributed_optimizer } >>> vescale.checkpoint.save("/user/vescale/gpt/", checkpoint_state) """ - VeScaleCheckpointer.save(path, checkpoint_state) + VeScaleCheckpointer.save(path, checkpoint_state, async_checkpoint=async_checkpoint) -def load(path: str, checkpoint_state: CheckpointState): +def load(path: str, checkpoint_state: CheckpointState, broadcast_checkpoint=False): """ Load a checkpoint from a given path Args: @@ -37,8 +40,14 @@ def load(path: str, checkpoint_state: CheckpointState): checkpoint_state: A dictionary contains key-value pairs for model and optimizer. - Model: Identified by 'model' key, value should be a model instance. - Optimizer: Identified by 'optimizer' key, value should be an optimizer instance. + broadcast_checkpoint: A boolean value decides if load a model replica from one data parallel process group + then broadcast tensors to other data parallel process group using GPUs + to reduce the file system access + For example, when data parellel size = 2, + processes with data parallel rank = 0 load model from file system + then broadcast it to processes with data parallel rank = 1 Example: >>> checkpoint_state = { "model": distributd_model, "optimizer": distributed_optimizer } >>> vescale.checkpoint.load("/user/vescale/gpt/", checkpoint_state) """ - VeScaleCheckpointer.load(path, checkpoint_state) + VeScaleCheckpointer.load(path, checkpoint_state, broadcast_checkpoint=broadcast_checkpoint) diff --git a/vescale/checkpoint/api/base_checkpointer.py b/vescale/checkpoint/api/base_checkpointer.py index 315799b..293f9c0 100644 --- a/vescale/checkpoint/api/base_checkpointer.py +++ b/vescale/checkpoint/api/base_checkpointer.py @@ -15,6 +15,12 @@ # ################################################################################ from .meta_type import CheckpointState +from typing import Dict, List +from concurrent.futures import Future, ProcessPoolExecutor +from torch.distributed.checkpoint.storage import WriteResult +from .meta_type import MODEL_STR, OPTIMIZER_STR + +SUPPORTED_TYPES = {MODEL_STR, OPTIMIZER_STR} class BaseCheckpointer: @@ -23,25 +29,49 @@ class BaseCheckpointer: It is designed for extension across various training frameworks. """ + # Async IO related members. + state_io_workers: Dict[str, ProcessPoolExecutor] = {} + state_write_futures: Dict[str, Future[List[WriteResult]]] = {} + @classmethod def save(cls, path: str, checkpoint_state: CheckpointState): """ A Method for saving checkpoint Args: path: Defines the storage path for checkpoint. - checkpoint_state: A dictionary contains key-value pairs for model, optimizer and dataloader(TODO). + checkpoint_state: A dictionary contains key-value pairs for model and optimizer. - Model: Identified by 'model' key, value should be a model instance. - Optimizer: Identified by 'optimizer' key, value should be an optimizer instance. + """ raise NotImplementedError() + @classmethod def load(cls, path: str, checkpoint_state: CheckpointState): """ A Method for loading checkpoint Args: path: Defines the storage path for checkpoint. - checkpoint_state: A dictionary contains key-value pairs for model, optimizer and dataloader(TODO). + checkpoint_state: A dictionary contains key-value pairs for model and optimizer. - Model: Identified by 'model' key, value should be a model instance. - Optimizer: Identified by 'optimizer' key, value should be an optimizer instance. + """ raise NotImplementedError() + + @classmethod + def _cleanup_futures(cls): + """ + Wait for all write futures to finish before exit, then do the cleanup works. + + WARNING: this method cannot be called by the users. + """ + for key in SUPPORTED_TYPES: + if key in cls.state_write_futures: + futures = cls.state_write_futures[key] + for fut in futures: + fut.result() + cls.state_write_futures[key] = [] + if cls.state_io_workers[key] is not None: + cls.state_io_workers[key].shutdown() + cls.state_io_workers[key] = None diff --git a/vescale/checkpoint/api/meta_type.py b/vescale/checkpoint/api/meta_type.py index 5b85740..a669efb 100644 --- a/vescale/checkpoint/api/meta_type.py +++ b/vescale/checkpoint/api/meta_type.py @@ -7,7 +7,7 @@ ################################################################################ # Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. ################################################################################ -# meta_type.py saves all constants and data types commonly used in omnistore +# meta_type.py saves all constants and data types commonly used in vescale.checkpoint from enum import Enum from typing import Dict, Any, TypeVar @@ -18,7 +18,7 @@ MODEL_STR = "model" OPTIMIZER_STR = "optimizer" -SHM_PATH = "/dev/shm" +STATE_DICT_STR = "state_dict" class SupportedStrategy(Enum): diff --git a/vescale/checkpoint/api/vescale_checkpointer.py b/vescale/checkpoint/api/vescale_checkpointer.py index b7fe0d8..61ce3b3 100644 --- a/vescale/checkpoint/api/vescale_checkpointer.py +++ b/vescale/checkpoint/api/vescale_checkpointer.py @@ -14,22 +14,25 @@ # limitations under the License. # ################################################################################ + +from concurrent.futures import ProcessPoolExecutor from .base_checkpointer import BaseCheckpointer from .meta_type import CheckpointState, MODEL_STR, OPTIMIZER_STR from ..save_state_dict import save_state_dict from ..load_state_dict import load_state_dict from ..planner.vescale.vescale_planner import VeScaleSavePlanner, VeScaleLoadPlanner - - +from vescale.devicemesh_api import VESCALE_DEVICE_MESH from ..utilities import bfile import os from vescale.optim.distributed_optimizer import initialize_optimizer_state import torch.distributed as dist -from ..utilities.logger import get_omnistore_logger +from ..utilities.logger import get_vescale_checkpoint_logger +import atexit -logger = get_omnistore_logger() +logger = get_vescale_checkpoint_logger() VESCALE_SUPPORTED_TYPES = {MODEL_STR, OPTIMIZER_STR} +NUM_IO_WORKER = 1 def deduplicate_2d_list(lst): @@ -45,6 +48,26 @@ def deduplicate_2d_list(lst): return deduplicated_list +def get_optim_ckpt_process_group(): + # Get the process group based on current rank + # The processes with same pipeline stage ID + # are in the same process group + device_mesh = VESCALE_DEVICE_MESH.get() + sub_mesh = device_mesh.get_submesh(mesh_dims=["TP", "DP"]) + two_dim_list = sub_mesh.mesh.tolist() + flatten_rank_list = [item for sublist in two_dim_list for item in sublist] + all_flatten_lists = [[] for _ in range(dist.get_world_size())] + dist.all_gather_object(all_flatten_lists, flatten_rank_list) + all_flatten_lists = deduplicate_2d_list(all_flatten_lists) + my_rank = dist.get_rank() + pg = None + for rank_list in all_flatten_lists: + new_pg = dist.new_group(ranks=flatten_rank_list) + if my_rank in rank_list: + pg = new_pg + return pg + + class VeScaleCheckpointer(BaseCheckpointer): """ The Checkpointer class for VeScale, A PyTorch Native Auto Parallelism Framework @@ -53,16 +76,28 @@ class VeScaleCheckpointer(BaseCheckpointer): save_planner = VeScaleSavePlanner() load_planner = VeScaleLoadPlanner() + optim_ckpt_proces_group = None + for key in VESCALE_SUPPORTED_TYPES: + BaseCheckpointer.state_io_workers[key] = ProcessPoolExecutor(max_workers=NUM_IO_WORKER) + BaseCheckpointer.state_write_futures[key] = [] + @classmethod - def save(cls, path: str, checkpoint_state: CheckpointState): + def save( + cls, + path: str, + checkpoint_state: CheckpointState, + async_checkpoint: bool = False, + ): + """ + async_checkpoint: A boolean value indicating if saving checkpoint asynchronously, + i.e. after dumping tensors from GPU memory to Host memory, + the training program can continue training immediately. + Then vescale.checkpoint will serialize tensors and dumping to the persistent storage asynchronously. + """ # Check if we support saving the components for key in checkpoint_state.keys(): if key not in VESCALE_SUPPORTED_TYPES: raise ValueError(f"{key} is not supported by VeScaleCheckpointer") - if bfile.is_local_path(path): - logger.warning( - "The local path for checkpointing should be accessible to all ranks. It can be a NFS/FUSE path" - ) # Start saving checkpoint for key, value in checkpoint_state.items(): @@ -73,15 +108,20 @@ def save(cls, path: str, checkpoint_state: CheckpointState): if dist.get_rank() == 0: bfile.makedirs(model_path) dist.barrier() - # Save model - save_state_dict( + # Save model. + _, new_write_futures = save_state_dict( state_dict=value.state_dict(), path=model_path, process_group=None, coordinator_rank=0, no_dist=False, planner=cls.save_planner, + async_io=async_checkpoint, + last_write_futures=cls.state_write_futures[MODEL_STR], + io_workers=cls.state_io_workers[MODEL_STR], ) + # Record new write futures. + cls.state_write_futures[MODEL_STR] = new_write_futures elif key == OPTIMIZER_STR: # Create a "optimizer" folder on under root path # to save different parts of optimizer @@ -89,25 +129,48 @@ def save(cls, path: str, checkpoint_state: CheckpointState): if dist.get_rank() == 0: bfile.makedirs(optim_root_path) dist.barrier() + # Get process group for saving optimizer, + # All processes with the same pipeline rank are in the same pg + if not cls.optim_ckpt_proces_group: + cls.optim_ckpt_proces_group = get_optim_ckpt_process_group() # Get optimizer path based on PP rank - optimizer_path = os.path.join(optim_root_path, "pp_0") + pp_rank = VESCALE_DEVICE_MESH.get_pipeline_parallel_rank() + optimizer_path = os.path.join(optim_root_path, f"pp_{pp_rank}") # Create optimizer folder on under root path - if dist.get_rank() == 0: + if dist.get_rank(cls.optim_ckpt_proces_group) == 0: bfile.makedirs(optimizer_path) dist.barrier() - - save_state_dict( + # Save optimizer + _, new_write_futures = save_state_dict( state_dict=value.state_dict(), path=optimizer_path, - process_group=None, + process_group=cls.optim_ckpt_proces_group, coordinator_rank=0, no_dist=False, planner=cls.save_planner, + async_io=async_checkpoint, + last_write_futures=cls.state_write_futures[OPTIMIZER_STR], + io_workers=cls.state_io_workers[OPTIMIZER_STR], ) + # Record new write futures. + cls.state_write_futures[OPTIMIZER_STR] = new_write_futures @classmethod - def load(cls, path: str, checkpoint_state: CheckpointState): + def load( + cls, + path: str, + checkpoint_state: CheckpointState, + broadcast_checkpoint: bool = False, + ): + """ + broadcast_checkpoint: A boolean value decides if load a model replica from one data parallel process group + then broadcast tensors to other data parallel process group using GPUs + to reduce the file system access + For example, when data parellel size = 2, + processes with data parallel rank = 0 load model from file system + then broadcast it to processes with data parallel rank = 1 + """ # Add warning if bfile.is_local_path(path): logger.warning( @@ -125,19 +188,31 @@ def load(cls, path: str, checkpoint_state: CheckpointState): model_path = os.path.join(path, MODEL_STR) # Get model state dictionary model_state = value.state_dict() + # Set process group + if broadcast_checkpoint: + model_load_process_group = VESCALE_DEVICE_MESH.get_data_parallel_dim_groups() + else: + model_load_process_group = None # Load model load_state_dict( state_dict=model_state, path=model_path, - process_group=None, + process_group=model_load_process_group, coordinator_rank=0, no_dist=False, planner=cls.load_planner, + broadcast_tensors=broadcast_checkpoint, ) # Load back to model value.load_state_dict(model_state) elif key == OPTIMIZER_STR: - optimizer_path = os.path.join(path, f"{OPTIMIZER_STR}", "pp_0") + # Get process group for loading optimizer, + # All processes with the same pipeline rank are in the same pg + if not cls.optim_ckpt_proces_group: + cls.optim_ckpt_proces_group = get_optim_ckpt_process_group() + # Get optimizer path based on TP and PP ranks + pp_rank = VESCALE_DEVICE_MESH.get_pipeline_parallel_rank() + optimizer_path = os.path.join(path, f"{OPTIMIZER_STR}", f"pp_{pp_rank}") # Initialize optimizer states initialize_optimizer_state(value) # Get optimizer state @@ -146,11 +221,29 @@ def load(cls, path: str, checkpoint_state: CheckpointState): load_state_dict( state_dict=optimizer_state, path=optimizer_path, - process_group=None, + process_group=cls.optim_ckpt_proces_group, coordinator_rank=0, no_dist=False, planner=cls.load_planner, + broadcast_tensors=False, ) # Load back to optimizer value.load_state_dict(optimizer_state) dist.barrier() + + @classmethod + def __cleanup(cls): + """ + Wait for all write futures to finish before exit, then do the cleanup works. + + WARNING: this method cannot be called by the users. + """ + cls.save_planner.clear_cache() + BaseCheckpointer._cleanup_futures() + + @classmethod + def _register_cleanup(cls): + atexit.register(VeScaleCheckpointer.__cleanup) + + +VeScaleCheckpointer._register_cleanup() diff --git a/vescale/checkpoint/load_state_dict.py b/vescale/checkpoint/load_state_dict.py index c8ba977..7c5f77c 100644 --- a/vescale/checkpoint/load_state_dict.py +++ b/vescale/checkpoint/load_state_dict.py @@ -9,18 +9,17 @@ ################################################################################ from typing import Optional - -import torch import torch.distributed as dist from torch.distributed.checkpoint.planner import LoadPlanner from torch.distributed.checkpoint.utils import _DistWrapper from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner -from torch.distributed.checkpoint.filesystem import FileSystemReader +from .storage.filesystem import FileSystemReader from .api.meta_type import STATE_DICT_TYPE import time -from .utilities.logger import get_omnistore_logger +from .utilities.logger import get_vescale_checkpoint_logger +from vescale.checkpoint.planner.vescale.vescale_planner import VeScaleLoadPlanner -logger = get_omnistore_logger() +logger = get_vescale_checkpoint_logger() META_DATA_FILE = ".metadata" @@ -32,15 +31,18 @@ def load_state_dict( coordinator_rank: int = 0, no_dist: bool = False, planner: Optional[LoadPlanner] = None, + broadcast_tensors=False, ) -> None: load_start_time = time.time() """ [veScale version] Loads a distributed ``state_dict`` in SPMD style. Fix sub-group storage. """ + storage_reader = FileSystemReader( + path, + broadcast_tensors=broadcast_tensors, + data_parallel_process_group=process_group, + ) - storage_reader = FileSystemReader(path) - - torch._C._log_api_usage_once("omnistore.checkpoint.vescale_checkpoint.load_state_dict") # Step 0: create distributed world based on process group and coordinator rank distW = _DistWrapper(process_group, not no_dist, coordinator_rank) if process_group: @@ -70,13 +72,16 @@ def global_step(all_local_plans): all_local_plans = storage_reader.prepare_global_plan(all_local_plans) return all_local_plans - central_plan = distW.reduce_scatter("plan", local_step, global_step) + if isinstance(planner, VeScaleLoadPlanner): + central_plan = distW.reduce_scatter("plan", local_step, global_step) + else: + raise AssertionError("Unsupported planner for saving checkpoint") load_ckpt_plan_cost_time = time.time() - plan_start_time logger.info(f"Finish planning. Cost time: {load_ckpt_plan_cost_time}s") read_start_time = time.time() - # Step 2: all processes read data from path + # Step 2: all processes read data from the given path def read_data(): assert planner is not None final_local_plan = planner.finish_plan(central_plan) diff --git a/vescale/checkpoint/planner/__init__.py b/vescale/checkpoint/planner/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vescale/checkpoint/planner/common.py b/vescale/checkpoint/planner/common.py new file mode 100644 index 0000000..c533623 --- /dev/null +++ b/vescale/checkpoint/planner/common.py @@ -0,0 +1,132 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +import dataclasses +from typing import Any, Dict, List, Tuple, Hashable, Optional +from collections import OrderedDict +from torch.distributed.checkpoint.planner import SavePlan +from torch.distributed.checkpoint.metadata import MetadataIndex, Metadata +import collections +from ..utilities.logger import get_vescale_checkpoint_logger + +logger = get_vescale_checkpoint_logger() + + +@dataclasses.dataclass +class P2PTensorsInfo: + """ + Record data about tesnors which are across dp ranks + recv_tensors: A dictionary + Key: fqn + Value: a dictionary + key is the process rank, + value is a tuple with (tensor, 1d_range) + send_p2p_reqs: a list of p2p send requests to wait + recv_p2p_reqs: a list p2p receive requests to wait + """ + + recv_tensors: Dict[str, Any] + send_p2p_reqs: List[Any] + recv_p2p_reqs: List[Any] + + +def sort_rank_ranges(process_list: List[Tuple]) -> List[Tuple]: + """ + Decide which rank is receiver and writer + Let rank with most parameters receives and writes tensors + for the best communication cost + If two ranks has the same data size, choose the smaller rank + Args: + A process list with tuples, each tuple is (rank, data_size) + Returns: + A sorted list, data size are sorted in descending order, + if two ranks has the same data size, ranks are in the asceonding order + """ + sorted_process_list = sorted(process_list, key=lambda x: (-x[1], x[0])) + return sorted_process_list + + +_MAX_CACHE_SIZE = 8 + + +class PlanLRUCache: + def __init__(self) -> None: + self._cache: OrderedDict[Hashable, Tuple[SavePlan, Metadata]] = OrderedDict() + self._capacity = _MAX_CACHE_SIZE + + def get(self, key: Hashable) -> Optional[Tuple[SavePlan, Metadata]]: + if key in self._cache: + return self._cache[key] + else: + return None + + def put(self, key: Hashable, plan_value: SavePlan, metadata_value: Metadata) -> None: + if key in self._cache: + self._cache.move_to_end(key, last=False) + else: + self._cache[key] = (plan_value, metadata_value) + if len(self._cache) > self._capacity: + self._cache.popitem() + + def clear(self) -> None: + self._cache.clear() + self._capacity = _MAX_CACHE_SIZE + + def __repr__(self) -> str: + return f"PlanLURCache(capacity: {self._capacity}, keys: {tuple(self._cache.keys())})" + + +def custom_dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]: + """ + A function to remove duplicate tensors to write + when creating global writing plan for saving checkpoint + During the deduplication, + we balance the workloads for duplicated tensors + """ + key_to_plan: Dict[MetadataIndex, List[int]] = {} + for plan_idx, plan in enumerate(all_plans): + for write_item in plan.items: + key_to_plan.setdefault(write_item.index, []).append(plan_idx) + + replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1} + # Remove duplicates by always keeping the first entry (Not balance). + # Compute the per-rank remove set. + plan_to_keys: Dict[int, List[MetadataIndex]] = {} + # Record the number of non-duplicated tensors assigned to each rank + assigned_work_load = collections.defaultdict(int) + for plan_idx, plan in enumerate(all_plans): + for write_item in plan.items: + if write_item.index not in replicated_items: + assigned_work_load[plan_idx] += 1 + + for key, plans in replicated_items.items(): + # For duplicated tensors, select the rank assigned with minimum number tensors so far + writer_id = min(plans, key=lambda k: assigned_work_load[k]) + assigned_work_load[writer_id] += 1 + for plan_idx in plans: + # If the rank is not writer rank, remove the key in the rank's plan + if plan_idx != writer_id: + plan_to_keys.setdefault(plan_idx, []).append(key) + logger.info("Duplicate keys to remove: %s", plan_to_keys) + + for plan_idx, keys in plan_to_keys.items(): + # Key Set contains keys to remove + key_set = set(keys) + # rewrite items and remove elements + new_items = [write_item for write_item in all_plans[plan_idx].items if write_item.index not in key_set] + all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items) + + return all_plans diff --git a/vescale/checkpoint/planner/vescale/vescale_planner.py b/vescale/checkpoint/planner/vescale/vescale_planner.py index c178991..705b1e9 100644 --- a/vescale/checkpoint/planner/vescale/vescale_planner.py +++ b/vescale/checkpoint/planner/vescale/vescale_planner.py @@ -9,36 +9,28 @@ ################################################################################ import io import dataclasses -import logging import torch -from typing import Any, Dict, Union, List, Tuple +from typing import Any, Dict, Union, List, Tuple, Optional from torch.distributed.checkpoint.default_planner import ( DefaultSavePlanner, DefaultLoadPlanner, ) + +import mmh3 + +from vescale.checkpoint.planner.common import P2PTensorsInfo, sort_rank_ranges, PlanLRUCache, custom_dedup_tensors import math import torch.distributed as dist -from torch.distributed.checkpoint.planner import ( - SavePlan, - LoadPlan, - ReadItem, - WriteItem, - WriteItemType, -) -from vescale.optim.distributed_optimizer import OptimizerStateSpec -from torch.distributed._shard.sharded_tensor.api import ShardedTensor +from torch.distributed.checkpoint.planner import SavePlan, LoadPlan, WriteItem, ReadItem from torch.distributed.checkpoint.metadata import MetadataIndex, Metadata +from vescale.optim.distributed_optimizer import OptimizerStateSpec from vescale.dtensor import DTensor -from .vescale_planner_helpers import ( - _create_write_items, - _create_read_items, - find_state_dict_object, -) - +from .vescale_planner_helpers import _create_write_items, _create_read_items, find_state_dict_object from vescale.devicemesh_api import VESCALE_DEVICE_MESH +from ...api.meta_type import STATE_DICT_STR +from ...utilities.logger import get_vescale_checkpoint_logger -logger: logging.Logger = logging.getLogger(__file__) - +logger = get_vescale_checkpoint_logger() __all__ = [ "VeScaleSavePlanner", "VeScaleLoadPlanner", @@ -47,54 +39,6 @@ ] -def sort_rank_ranges(process_list: List[Tuple]) -> List[Tuple]: - """ - Decide which rank is receiver and writer - Let rank with most parameters receives and writes tensors - for the best communication cost - If two ranks has the same data size, choose the smaller rank - Args: - A process list with tuples, each tuple is (rank, data_size) - Returns: - A sorted list, data size are sorted in descending order, - if two ranks has the same data size, ranks are in the asceonding order - """ - sorted_process_list = sorted(process_list, key=lambda x: (-x[1], x[0])) - return sorted_process_list - - -def custom_dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]: - """ - A function to remove duplicate tensors to write - when creating global writing plan for saving checkpoint - """ - all_plans = list(all_plans) - key_to_plan: Dict[MetadataIndex, List[int]] = {} - for plan_idx, plan in enumerate(all_plans): - for write_item in plan.items: - # NOTE: the only difference from pytorch official - if write_item.type != WriteItemType.SHARD: - key_to_plan.setdefault(write_item.index, []).append(plan_idx) - - replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1} - - # Remove duplicates by always keeping the first entry. - # Compute the per-rank remove set. - plan_to_keys: Dict[int, List[MetadataIndex]] = {} - for key, plans in replicated_items.items(): - for plan_idx in plans[1:]: - plan_to_keys.setdefault(plan_idx, []).append(key) - logger.info("Duplicate keys to remove: %s", plan_to_keys) - - for plan_idx, keys in plan_to_keys.items(): - key_set = set(keys) - # rewrite items and remove elements - new_items = [write_item for write_item in all_plans[plan_idx].items if write_item.index not in key_set] - all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items) - - return all_plans - - class VeScaleLoadPlanner(DefaultLoadPlanner): """ A planner class for loading vescale checkpoint using PyTorch DCP @@ -127,16 +71,6 @@ def create_default_local_load_plan(state_dict: Dict[str, Any], metadata: Metadat if isinstance(obj, DTensor): if obj.device_mesh.get_coordinate() is not None: requests += _create_read_items(fqn, md, obj) - elif isinstance(obj, ShardedTensor): - # For veScale DOptimizer, it will provide empty shards - # if current process does not own the shard of tensor - local_shards = obj.local_shards() - total_size = 0 - for local_shard in local_shards: - for size in local_shard.metadata.shard_sizes: - size += total_size - if size > 0: - requests += _create_read_items(fqn, md, obj) elif isinstance(obj, OptimizerStateSpec): # If the state is distributed on multiple dp ranks # Read with local_shape, then in DOptimizer then @@ -163,49 +97,76 @@ class VeScaleSavePlanner(DefaultSavePlanner): def __init__(self): super().__init__() + self._plan_cache = PlanLRUCache() def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]: object = self.lookup_object(write_item.index) return self.transform_object(write_item, object) - def create_local_plan(self) -> SavePlan: - plan = create_default_local_save_plan(self.state_dict, self.is_coordinator) + def create_local_plan(self) -> Tuple[SavePlan, P2PTensorsInfo]: + plan, p2p_tensors_info = create_default_local_save_plan(self.state_dict, self.is_coordinator) if self.flatten_state_dict: plan = dataclasses.replace(plan, planner_data=self.mappings) self.plan = plan - return self.plan + return self.plan, p2p_tensors_info def lookup_object(self, index: MetadataIndex) -> Any: return find_state_dict_object(self.state_dict, index) + def lookup_plan_meta(self) -> Optional[Tuple[SavePlan, Metadata]]: + if not hasattr(self, STATE_DICT_STR): + return None + else: + device_mesh = VESCALE_DEVICE_MESH.get() + plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator, device_mesh)) + return self._plan_cache.get(plan_key) + + def cache_plan_meta(self, new_plan: SavePlan, new_metadata: Metadata) -> None: + device_mesh = VESCALE_DEVICE_MESH.get() + plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator, device_mesh)) + self._plan_cache.put(plan_key, new_plan, new_metadata) + + def clear_cache(self) -> None: + self._plan_cache.clear() + + def dedup_plans(self, all_plans: List[SavePlan]) -> List[SavePlan]: + # Use customized deduplicate function for load balance + all_plans = custom_dedup_tensors(all_plans) + return all_plans + + def create_dedup_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: + # Disable DCP's dedup replicated tensors function + self.dedup_replicated_tensors = False + rst_value = super().create_global_plan(all_plans) + return rst_value + def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: - self.dedup_replicated_tensors = True - # all_plans = custom_dedup_tensors(all_plans) + # Disable DCP's dedup replicated tensors function + self.dedup_replicated_tensors = False + # Use customized deduplicate function for load balance + all_plans = custom_dedup_tensors(all_plans) rst_value = super().create_global_plan(all_plans) return rst_value def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: bool) -> SavePlan: """ - A function for creating local saving plan for saving checkpoint + A function for creating local saving plan for saving checkpoint. """ requests = [] + # Key: fqn + # Value: dictionary (Key is the process rank, value is tensor to receive) + recv_tensors = {} + + send_p2p_reqs = [] + recv_p2p_reqs = {} + for fqn, obj in state_dict.items(): # Since DTensor supports submesh, adding extra check to ensure _create_write_items() # gets called only when the current rank is part of the mesh for the corresponding DTensor. if isinstance(obj, DTensor): if obj.device_mesh.get_coordinate() is not None: requests += _create_write_items(fqn, obj) - elif isinstance(obj, ShardedTensor): - # For veScale DOptimizer, it will provide empty shards - # if current process does not own the shard of tensor - local_shards = obj.local_shards() - total_size = 0 - for local_shard in local_shards: - for size in local_shard.metadata.shard_sizes: - size += total_size - if size > 0: - requests += _create_write_items(fqn, obj) elif isinstance(obj, OptimizerStateSpec): # Create write requests if the process is the real writer if obj.dp_ranks_ranges: @@ -213,14 +174,14 @@ def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: b for rank, param_range in obj.dp_ranks_ranges.items(): process_list.append((rank, len(param_range))) sorted_list = sort_rank_ranges(process_list) - writer_rank = sorted_list[0][0] - p2p_ops = [] - recv_tensors = {} - + writer_rank = sorted_list[mmh3.hash(fqn) % len(sorted_list)][0] + send_ops_to_start = [] + recv_ops_to_start = {} # Case 1: I am writer # Receive tensors - + logger.debug(f"fqn={fqn} is a tensor across dp ranks. writer rank={writer_rank}") if dist.get_rank() == writer_rank: + recv_tensors[fqn] = {} for k, param_range in obj.dp_ranks_ranges.items(): if k != dist.get_rank(): recv_tensor = torch.zeros( @@ -232,8 +193,8 @@ def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: b peer=k, group=VESCALE_DEVICE_MESH.get_data_parallel_dim_groups(), ) - recv_tensors[k] = recv_tensor - p2p_ops.append(recv_op) + recv_tensors[fqn][k] = (recv_tensor, param_range) + recv_ops_to_start[k] = recv_op else: # Case 2: I am not writer # Send my tensor @@ -243,31 +204,39 @@ def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: b peer=writer_rank, group=VESCALE_DEVICE_MESH.get_data_parallel_dim_groups(), ) - p2p_ops.append(send_op) + send_ops_to_start.append(send_op) - reqs = dist.batch_isend_irecv(p2p_ops) + send_reqs = [] + recv_reqs = [] + if send_ops_to_start: + send_reqs = dist.batch_isend_irecv(send_ops_to_start) + if recv_ops_to_start: + recv_reqs = dist.batch_isend_irecv(list(recv_ops_to_start.values())) - for req in reqs: - req.wait() - - if writer_rank == dist.get_rank(): - new_local_tensor = torch.zeros( - (math.prod(obj.local_shape),), dtype=obj.local_tensor.dtype, device=obj.local_tensor.device - ) - new_local_tensor[obj.dp_ranks_ranges[writer_rank].start : obj.dp_ranks_ranges[writer_rank].end] = ( - obj.local_tensor - ) - for k, param_range in obj.dp_ranks_ranges.items(): - if k != writer_rank: - new_local_tensor[param_range.start : param_range.end] = recv_tensors[k] - obj.local_tensor = new_local_tensor + if send_reqs: + send_p2p_reqs.extend(send_reqs) - obj.local_tensor = obj.local_tensor.reshape(obj.local_shape) - requests += _create_write_items(fqn, obj) + if recv_reqs: + recv_p2p_reqs[fqn] = recv_reqs else: obj.local_tensor = obj.local_tensor.reshape(obj.local_shape) requests += _create_write_items(fqn, obj) elif isinstance(obj, (torch.Tensor)) or is_coordinator: requests += _create_write_items(fqn, obj) - return SavePlan(requests) + # Padding the states across DP ranks + # Merge the tensors later + writer_rank = dist.get_rank() + for fqn in recv_tensors.keys(): + obj = state_dict[fqn] + new_local_tensor = torch.zeros( + (math.prod(obj.local_shape),), dtype=obj.local_tensor.dtype, device=obj.local_tensor.device + ) + new_local_tensor[obj.dp_ranks_ranges[writer_rank].start : obj.dp_ranks_ranges[writer_rank].end] = ( + obj.local_tensor + ) + obj.local_tensor = new_local_tensor + + obj.local_tensor = obj.local_tensor.reshape(obj.local_shape) + requests += _create_write_items(fqn, obj) + return SavePlan(requests), P2PTensorsInfo(recv_tensors, send_p2p_reqs, recv_p2p_reqs) diff --git a/vescale/checkpoint/planner/vescale/vescale_planner_helpers.py b/vescale/checkpoint/planner/vescale/vescale_planner_helpers.py index 041f3f3..a444e6f 100644 --- a/vescale/checkpoint/planner/vescale/vescale_planner_helpers.py +++ b/vescale/checkpoint/planner/vescale/vescale_planner_helpers.py @@ -10,9 +10,8 @@ from typing import Any, List import torch -from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed._shard.metadata import ShardMetadata -from torch.distributed.checkpoint.planner import WriteItem, WriteItemType, ReadItem, LoadItemType, TensorWriteData +from torch.distributed.checkpoint.planner import WriteItem, ReadItem, WriteItemType, LoadItemType, TensorWriteData from torch.distributed.checkpoint.metadata import ( STATE_DICT_TYPE, STORAGE_TYPES, @@ -22,7 +21,6 @@ TensorStorageMetadata, ) from torch.distributed._shard.sharded_tensor import TensorProperties -from torch.distributed._shard.sharded_tensor.shard import Shard from torch.distributed.checkpoint.resharding import ( _check_shard_metadata_pair_overlap, _shards_get_overlap_region_wrt_saved_tensor, @@ -54,23 +52,6 @@ def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata: return ChunkStorageMetadata(offsets=offsets, sizes=sizes) -def _sharded_tensor_metadata(sharded_tensor: ShardedTensor, shard_md: ShardMetadata) -> TensorWriteData: - return TensorWriteData( - chunk=_chunk_for_shard(shard_md), - properties=sharded_tensor.metadata().tensor_properties, - size=sharded_tensor.metadata().size, - ) - - -def _create_write_item_for_shard(fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata) -> WriteItem: - offsets = torch.Size(shard_md.shard_offsets) - return WriteItem( - index=MetadataIndex(fqn, offsets), - type=WriteItemType.SHARD, - tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md), - ) - - def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem: offsets = torch.Size([0] * len(tensor.size())) return WriteItem( @@ -109,8 +90,6 @@ def _create_write_item_for_bytesio(fqn: str, bytes: Any): def _create_write_items(fqn: str, object: Any) -> List[WriteItem]: if isinstance(object, DTensor): return [_create_write_items_for_dtensor(fqn, object)] - elif isinstance(object, ShardedTensor): - return [_create_write_item_for_shard(fqn, object, shard.metadata) for shard in object.local_shards()] elif isinstance(object, torch.Tensor): return [_create_write_item_for_tensor(fqn, object)] elif isinstance(object, OptimizerStateSpec): @@ -206,8 +185,6 @@ def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]: if not isinstance(md, BytesStorageMetadata): if isinstance(obj, DTensor): local_chunks = [_create_chunk_from_dtensor(obj)] - elif isinstance(obj, ShardedTensor): - local_chunks = [_chunk_for_shard(shard.metadata) for shard in obj.local_shards()] elif isinstance(obj, torch.Tensor): local_chunks = [_create_chunk_from_tensor(obj)] elif isinstance(obj, OptimizerStateSpec): @@ -236,32 +213,14 @@ def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata: ) -def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard: - if index.offset is None: - raise ValueError(f"Cannot lookup {index.fqn} since its a ShardedTensor and no offset was provided") - - shards = tensor.local_shards() - # index fast path - if index.index is not None: - if len(shards) > index.index and torch.Size(shards[index.index].metadata.shard_offsets) == index.offset: - return shards[index.index] - - for shard in shards: - if torch.Size(shard.metadata.shard_offsets) == index.offset: - return shard - raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'") - - def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor: if isinstance(tensor, DTensor): - return tensor._local_tensor # keep out of autograd - if isinstance(tensor, ShardedTensor): - return _find_shard(tensor, index).tensor + return tensor.to_local() if index.offset is not None: # special case looking up a tensor by origin if index.offset == torch.Size([0] * len(tensor.size())): return tensor - raise ValueError(f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'") + raise ValueError(f"FQN: '{index.fqn}' is not a DTensor, can't find by offset: '{index.offset}'") return tensor @@ -279,6 +238,6 @@ def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> return obj.local_tensor elif index.offset is not None: raise ValueError( - f"FQN: '{index.fqn}' is not a ShardedTensor, it is a {type(obj)} can't find by offset: '{index.offset}'" + f"FQN: '{index.fqn}' is not a DTensor, it is a {type(obj)} can't find by offset: '{index.offset}'" ) return obj diff --git a/vescale/checkpoint/save_state_dict.py b/vescale/checkpoint/save_state_dict.py index 569745e..7e9b742 100644 --- a/vescale/checkpoint/save_state_dict.py +++ b/vescale/checkpoint/save_state_dict.py @@ -8,49 +8,46 @@ # Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. ################################################################################ -from typing import Optional -import torch -from .utilities.mem_checkpoint import TorchCheckpointRecorder +import os +import pickle +from typing import Optional, Tuple, List + + import torch.distributed as dist -from torch.distributed.checkpoint.filesystem import FileSystemWriter -from torch.distributed.checkpoint.planner import SavePlanner from torch.distributed.checkpoint.utils import _DistWrapper +from .storage.filesystem import FileSystemWriter + + +from torch.distributed.checkpoint.planner import SavePlanner from torch.distributed.checkpoint.metadata import Metadata +from torch.distributed.checkpoint.storage import WriteResult from torch.distributed.checkpoint.default_planner import DefaultSavePlanner -from .api.meta_type import STATE_DICT_TYPE -from .utilities.logger import get_omnistore_logger +from vescale.checkpoint.api.meta_type import STATE_DICT_TYPE +from .utilities.logger import get_vescale_checkpoint_logger import time -import atexit - -logger = get_omnistore_logger() -_io_workers = None +from concurrent.futures import Future -def _clean_up(): - if _io_workers: - _io_workers.terminate() - _io_workers.join() +logger = get_vescale_checkpoint_logger() - -atexit.register(_clean_up) +from vescale.checkpoint.planner.vescale.vescale_planner import VeScaleSavePlanner def save_state_dict( state_dict: STATE_DICT_TYPE, path: str, - # storage_writer: StorageWriter, process_group: Optional[dist.ProcessGroup] = None, coordinator_rank: int = 0, no_dist: bool = False, planner: Optional[SavePlanner] = None, - strategy=None, -) -> Metadata: + async_io: bool = True, + last_write_futures: Future[List[WriteResult]] = None, + io_workers=None, +) -> Tuple[Metadata, Future[List[WriteResult]]]: """ [veScale version] Saves a distributed model in SPMD style. Fix sub-group storage. Args and usage is the same as `torch.distributed.checkpoint.save_state_dict`. """ - save_ckpt_start_time = time.time() - torch._C._log_api_usage_once("omnistore.checkpoint.vescale_checkpoint.save_state_dict") # Step 0: create distributed world based on process group and coordinator rank distW = _DistWrapper(process_group, not no_dist, coordinator_rank) @@ -67,48 +64,118 @@ def save_state_dict( # Step 1: all processes create local write plan, # then coordinator gathers all local plans and create global plan. def local_step(): - assert planner is not None - planner.set_up_planner(state_dict, distW.is_coordinator) - storage_writer.set_up_storage_writer(distW.is_coordinator) - local_plan = planner.create_local_plan() - local_plan = storage_writer.prepare_local_plan(local_plan) + logger.debug("Start local step of planning") + if isinstance(planner, VeScaleSavePlanner): + local_plan, p2p_tensors_info = planner.create_local_plan() + local_plan = storage_writer.prepare_local_plan(local_plan, p2p_tensors_info) + else: + raise AssertionError("Unsupported planner for planning") + logger.debug("Finish local step of planning") return local_plan def global_step(all_local_plans): + logger.debug("Start global step of planning") nonlocal global_metatadata - assert planner is not None all_local_plans, global_metatadata = planner.create_global_plan(all_local_plans) all_local_plans = storage_writer.prepare_global_plan(all_local_plans) + logger.debug("End global step of planning") return all_local_plans - plan_start_time = time.time() - central_plan = distW.reduce_scatter("plan", local_step, global_step) - plan_cost_time = time.time() - plan_start_time - logger.info(f"Finish planning. Cost time: {plan_cost_time}s") - # Step 2: all processes write data from GPUs to pinned memory pool, then dump to local path # then coordinator write meta-data to local path. - def write_data(): + def write_data(async_io: bool = False, io_workers=io_workers): + logger.debug("Start writing data") assert planner is not None final_local_plan = planner.finish_plan(central_plan) - # Use pinned memory pool and mult_processing for dumping ckpt to local directory efficiently - global _io_workers - if not _io_workers: - _io_workers = torch.multiprocessing.get_context("spawn").Pool(2) - with TorchCheckpointRecorder(async_worker=_io_workers): - all_writes = storage_writer.write_data(final_local_plan, planner) - all_writes.wait() - return all_writes.value() + if isinstance(planner, VeScaleSavePlanner): + # Use pinned memory pool and mult_processing for dumping ckpt to local directory efficiently + all_write_futures = storage_writer.write_data(final_local_plan, planner, async_io, io_workers) + logger.debug("Finish writing data") + if async_io: + return all_write_futures + else: + # Gather write results. + values = [] + for fut in all_write_futures: + # values += fut.get() + values += fut.result() + return values + else: + raise AssertionError("Unsupported planner for writing data") def finish_checkpoint(all_results): - assert global_metatadata is not None + logger.debug("Start writing metadata") + assert global_metatadata is not None, f"rank: {distW.get_rank()} has no global_metadata" storage_writer.finish(metadata=global_metatadata, results=all_results) + logger.debug("Finish writing metadata") return global_metatadata - dump_local_start_time = time.time() - all_reduce_results = distW.all_reduce("write", write_data, finish_checkpoint) - dump_local_cost_time = time.time() - dump_local_start_time - logger.info(f"Finish dumping. Cost time: {dump_local_cost_time}s") - - return all_reduce_results + assert planner is not None + planner.set_up_planner(state_dict, distW.is_coordinator) + storage_writer.set_up_storage_writer(distW.is_coordinator) + + # Wait for last write futures to finish. + if last_write_futures: + logger.info("Start waiting for last write events.") + last_write_start_time = time.time() + for fut in last_write_futures: + fut.result() + last_write_time = time.time() - last_write_start_time + logger.info(f"Finish waiting for last write events. Time cost: {last_write_time}s") + + # Each worker bypass the `reduce_scatter()` and `all_reduce()` if finding cached central_plan and metadata. + # NOTE: it fails when the plans of partial workers change while others keep the same. + logger.info("Start planning.") + plan_start_time = time.time() + cached_data = None + + if isinstance(planner, VeScaleSavePlanner): + cached_data = planner.lookup_plan_meta() + if cached_data: + logger.debug("Plan cache hit. Reuse existing plan") + central_plan, _ = cached_data + _ = local_step() + else: + logger.debug("Plan cache miss. The model/optimizer appears for the first time.") + + central_plan = distW.reduce_scatter("plan", local_step, global_step) + else: + raise AssertionError("Unsupported planner for saving checkpoint") + plan_cost_time = time.time() - plan_start_time + logger.info(f"Finish planning. Time cost: {plan_cost_time}s") + + logger.info("Start storing") + store_local_start_time = time.time() + write_futures = [] + if isinstance(planner, VeScaleSavePlanner): + if cached_data: + logger.debug("Metdata cache hit. Reuse existing metadata") + _, final_storage_metadata = cached_data + write_results = write_data(async_io=async_io) + # Be sure to write cache metadata to .metadata file + # Otherwises only the first checkpoint has .metadata + # which leads to error when loading other checkpoints + if distW.is_coordinator: + with (storage_writer.path / ".metadata.tmp").open("wb") as metadata_file: + pickle.dump(final_storage_metadata, metadata_file) + os.fsync(metadata_file.fileno()) + + (storage_writer.path / ".metadata.tmp").rename(storage_writer.path / ".metadata") + + if async_io: + write_futures = write_results + else: + logger.debug("Metadata cache miss. The model/optimizer appears for the first time.") + # First time do synchronous storing to get final_storage_metatdata. + # Determine which communication topology to use. + final_storage_metadata = distW.all_reduce("write", write_data, finish_checkpoint) + assert central_plan is not None + assert final_storage_metadata is not None + planner.cache_plan_meta(central_plan, final_storage_metadata) + else: + raise AssertionError("Unsupported planner for writing data and metadata") + store_local_cost_time = time.time() - store_local_start_time + logger.info(f"Finish storing. Time cost: {store_local_cost_time}s") + + return final_storage_metadata, write_futures diff --git a/vescale/checkpoint/storage/checkpoint_adapter.py b/vescale/checkpoint/storage/checkpoint_adapter.py deleted file mode 100644 index 887f86b..0000000 --- a/vescale/checkpoint/storage/checkpoint_adapter.py +++ /dev/null @@ -1,317 +0,0 @@ -################################################################################ -# -# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -################################################################################ - -from abc import abstractmethod -from tqdm import tqdm -from typing import Dict -import os -import re -import torch -import torch.distributed as dist # current we need to use mpi launch -from vescale import DeviceMesh, DTensor -from .checkpoint_format import LLMHandWriteFormat -from typing import Optional, List, Any -from torch.distributed.distributed_c10d import ( - ProcessGroup, - get_rank, - get_world_size, -) -from torch.distributed.checkpoint._nested_dict import flatten_state_dict, unflatten_state_dict -from torch.distributed.checkpoint.metadata import ( - STATE_DICT_TYPE, -) - -from ..utilities.bfile import listdir, BFile - - -def _construct_megatron_downloading_map(filenames: List[str]): - weight_dic_pattern = r"mp_rank_\d\d_\d\d\d$" - filtered_files = [file for file in filenames if re.match(weight_dic_pattern, file)] - - download_map = {} - for file in filtered_files: - parts = file.split("_") - tp_rank = int(parts[2]) - pp_rank = int(parts[3]) - if pp_rank not in download_map: - download_map[pp_rank] = {} - download_map[pp_rank][tp_rank] = file - - return download_map - - -def _construct_reverse_pp_tp_map(vescale_path: str): - if not os.path.exists(vescale_path): - raise RuntimeError(f"vescale_path not exists. path: {vescale_path}") - files = os.listdir(vescale_path) - match = r"rank\d+.pt" - - filtered_files = [file for file in files if re.match(match, file)] - rank_map = {} - for file in filtered_files: - rank = re.search(r"\d+", file).group(0) - rank_map[rank] = os.path.join(vescale_path, file) - return rank_map - - -def _construct_pp_tp_map(megatron_path: str): - """ - construct tp pp index mapping dict - { - # for pp 0 - 0: { - # for tp 0 - 0 : "xxx.pt", - 1 : "xxx.pt" - } - } - """ - dics = listdir(megatron_path) - if len(dics) == 0: - raise RuntimeError(f"megatron_path not exists or is empty. path: {megatron_path}") - - weight_map = dict() - optim_map = dict() - - def update_dict(dic_, pp_r, tp_r, file_path): - if pp_r in dic_: - pp_dic = dic_[pp_r] - pp_dic.update({tp_r: file_path}) - else: - new_dic = {tp_r: file_path} - dic_[pp_r] = new_dic - - weight_dict = r"mp_rank_\d\d_\d\d\d$" - optim_dict = r"mp_rank_\d\d_\d\d\d_\d\d\d$" - filtered_weights_dics = [dic for dic in dics if re.match(weight_dict, dic)] - filtered_optim_dics = [dic for dic in dics if re.match(optim_dict, dic)] - - # construct weight 2-dims maps - for dic in filtered_weights_dics: - split_ul = re.split("_", dic) - tp_rank = int(split_ul[2]) - pp_rank = int(split_ul[3]) - weight_file = os.path.join(megatron_path, dic, "model_rng.pt") - update_dict(weight_map, pp_rank, tp_rank, weight_file) - - # construct optimize 2-dims maps - for dic in filtered_optim_dics: - split_ul = re.split("_", dic) - tp_rank = int(split_ul[2]) - pp_rank = int(split_ul[3]) - optim_file = os.path.join(megatron_path, dic, "optim.pt") - update_dict(optim_map, pp_rank, tp_rank, optim_file) - return weight_map, optim_map - - -def _get_megatron_tp_group(world_size, pp_size, tp_size, dp_size, cur_rank) -> tuple[ProcessGroup, ProcessGroup]: - """make sub pg group""" - return dist.new_subgroups(group_size=tp_size * dp_size) - - -def _deduce_parallel_plan_by_device_mesh(mesh: DeviceMesh): - """make rank to megatron tp_rank, pp_rank map""" - # FIXME : current only support data parallel is 1 - # allways parallel in last dim - tp_size = mesh.size() - # for rank = pp_rank * tp_size + tp_rank - # (rank - tp_rank) / tp_size = pp_rank - tp_rank = get_rank() % tp_size - assert (get_rank() - tp_rank) % tp_size == 0, "megatron not support pp size undivided by tp size" - pp_rank = (get_rank() - tp_rank) // tp_size - return tp_rank, pp_rank - - -def _filter_unused_tensors_and_renaming(old_state_dict: Dict[str, Any], param_resharding_plan: Dict[str, Any]): - new_state_dict = {} - - flatten_old_st, _ = flatten_state_dict(old_state_dict) - - for key, value in flatten_old_st.items(): - for pattern in param_resharding_plan.keys(): - start_index = key.find(pattern) - if start_index == -1: - continue - else: - new_state_dict[pattern] = value - print(new_state_dict.keys()) - return new_state_dict - - -################################################################## -##################### for visitor ##################### -################################################################## - - -class StateDictVisitor: - def set_device_mesh(self, mesh: DeviceMesh): - self.device_mesh = mesh - - @abstractmethod - def parsing_state_dict(self, st: dict, *args, **kwargs): - """ - flattened parsing module dict, using process function to handle each Tensor - """ - f_st, mapping = flatten_state_dict(st) - # flattened_key , value - for key, value in tqdm(f_st.items()): - if isinstance(value, (torch.Tensor, DTensor)): - self.tensor_process_func(f_st, key, value, *args, **kwargs) - new_st = unflatten_state_dict(f_st, mapping) - st.update(new_st) - - @abstractmethod - def tensor_process_func(self, parent: dict, key: str, value: Any, *args, **kwargs): - raise NotImplementedError("method abstruct method is call") - - @abstractmethod - def apply(self, state_dict: dict, *args, **kwargs): - self.parsing_state_dict(state_dict, *args, **kwargs) - - -class DefaultM2VDFSVisitor(StateDictVisitor): - def __init__(self, format: LLMHandWriteFormat): - self.format = format - super().__init__() - - def tensor_process_func(self, parent: dict, key: str, value: Any, *args, **kwargs): - assert self.format is not None, "format is not set" - tensor_placement = self.format.get_tensor_sharding_plan_by_name(key) - assert isinstance(value, torch.Tensor) - - is_requires_grad = value.requires_grad - with torch.no_grad(): # keep out of autograd - dtensor = DTensor.from_local(value, self.device_mesh, tensor_placement) - dtensor.requires_grad_(is_requires_grad) - - parent[key] = dtensor - - def apply(self, state_dict: dict, *args, **kwargs): - self.parsing_state_dict(state_dict, *args, **kwargs) - - -class DefaultV2MDFSVisitor(StateDictVisitor): - def __init__(self): - super().__init__() - - def tensor_process_func(self, parent: dict, key: str, value: DTensor, *args, **kwargs): - parent[key] = value._local_tensor # keep out of autograd - - def apply(self, state_dict: dict, *args, **kwargs): - self.parsing_state_dict(state_dict, *args, **kwargs) - - -################################################################## -##################### for api func ##################### -################################################################## - - -def convert_vescale_checkpoint_to_megatron( - vescale_path: str, megatron_path: str, visitor: StateDictVisitor, device=torch.device("cpu") -) -> STATE_DICT_TYPE: - rank_map = _construct_reverse_pp_tp_map(vescale_path) - world_size = len(rank_map) - assert world_size == get_world_size(), f"world size mismatch {world_size} vs {get_world_size()}" - rank = get_rank() - rank_file_name = rank_map[str(rank)] - rank_file_path = os.path.join(vescale_path, rank_file_name) - if os.path.exists(rank_file_path): - st = torch.load(rank_file_path, map_location=device) - - def find_device_mesh(st): - for key in st: - value = st[key] - if isinstance(value, DTensor): - mesh = value.device_mesh - return mesh - elif isinstance(value, dict): - mesh = find_device_mesh(value) - if mesh: - return mesh - return None - - device_mesh = find_device_mesh(st) - assert device_mesh is not None, "not find devicemesh in vescale format please check" - tp_rank, pp_rank = _deduce_parallel_plan_by_device_mesh(device_mesh) - visitor.apply(st) - megatron_dict = f"mp_rank_{str(tp_rank).zfill(2)}_{str(pp_rank).zfill(3)}" - tmp_path = megatron_path - megatron_save_path = os.path.join(tmp_path, megatron_dict) - os.makedirs(megatron_save_path, exist_ok=True) - megatron_save_file = os.path.join(megatron_save_path, "model_rng.pt") - if "optim" in st: - optim = st["optim"] - megatron_optim_dict = f"mp_rank_{str(tp_rank).zfill(2)}_{str(pp_rank).zfill(3)}_000" - megatron_optim_dict_path = os.path.join(tmp_path, megatron_optim_dict) - os.makedirs(megatron_optim_dict_path, exist_ok=True) - torch.save(optim, os.path.join(megatron_optim_dict_path, "optim.pt")) - del st["optim"] - torch.save(st, megatron_save_file) - # FIXME: support dp not 1 - return st - - -def convert_megatron_checkpoint_to_vescale( - megatron_path: str, visitor: DefaultM2VDFSVisitor, device=torch.device("cpu"), vescale_path: Optional[str] = None -) -> STATE_DICT_TYPE: - weight_map, optim_map = _construct_pp_tp_map(megatron_path) - tp_equal = [(len(weight_map[pp]) == len(weight_map[0])) for pp in weight_map] - assert all(tp_equal), "megatron not support unmodified devided split plan" - tp_size = len(weight_map[0]) - pp_size = len(weight_map) - - rank = get_rank() - - for pp_rank in range(0, pp_size): - for tp_rank in range(0, tp_size): - megatron_rank = pp_rank * tp_size + tp_rank - if megatron_rank != rank: - continue - megatron_weight_pt = weight_map[pp_rank][tp_rank] - # phase 1. parse weight - with BFile(megatron_weight_pt, "rb") as f: - m_st = torch.load(f, map_location=device) - args = m_st["args"] - megatron_cur_rank = args.rank - megatron_world_size = args.world_size - megatron_tp_size = args.tensor_model_parallel_size - megatron_pp_size = args.pipeline_model_parallel_size - megatron_dp_size = args.data_parallel_size - - local_pg, _ = _get_megatron_tp_group( - megatron_world_size, megatron_pp_size, megatron_tp_size, megatron_dp_size, megatron_cur_rank - ) - device_mesh = DeviceMesh(device.__str__(), None, pg=local_pg) - visitor.set_device_mesh(device_mesh) - visitor.apply(m_st["model"], "model") - - new_st = {} - new_st["models"] = _filter_unused_tensors_and_renaming( - m_st["model"], visitor.format.default_params_sharding_plan - ) - if len(optim_map) > 0: - megatron_optim_pt_path = optim_map[pp_rank][tp_rank] - # phase 2. parse optimizer - with BFile(megatron_optim_pt_path, "rb") as f: - optim = torch.load(f, map_location=device) - visitor.apply(optim, "") - new_st["optim"] = optim - if vescale_path: - save_file = f"rank{rank}.pt" - with BFile(os.path.join(vescale_path, save_file), "wb") as f: - torch.save(new_st, f) - return new_st diff --git a/vescale/checkpoint/storage/checkpoint_format.py b/vescale/checkpoint/storage/checkpoint_format.py deleted file mode 100644 index f3e262f..0000000 --- a/vescale/checkpoint/storage/checkpoint_format.py +++ /dev/null @@ -1,48 +0,0 @@ -################################################################################ -# -# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -################################################################################ - -import re -from typing import Sequence - - -from vescale import Shard, Replicate, Placement - - -class LLMHandWriteFormat: - def __init__(self, params_sharding_plan): - super().__init__() - self.default_params_sharding_plan = params_sharding_plan - - def get_tensor_sharding_plan_by_name(self, name: str) -> Sequence[Placement]: - for pattern, placements in self.default_params_sharding_plan.items(): - if re.search(pattern, name): - return placements - return [Replicate()] - - -MEGATRON_GPT_RULES = { - r"model.gpt_model.language_model.embedding.word_embeddings.weight": [Shard(0)], - r"model.gpt_model.language_model.encoder.layers.\d+.mlp.dense_h_to_4h.weight": [Shard(0)], - r"model.gpt_model.language_model.encoder.layers.\d+.mlp.dense_h_to_4h_lora.weight": [Shard(0)], - r"model.gpt_model.language_model.encoder.layers.\d+.mlp.dense_4h_to_h.weight": [Shard(1)], - r"model.gpt_model.language_model.encoder.layers.\d+.mlp.dense_4h_to_h_lora.weight": [Shard(1)], - r"model.gpt_model.language_model.encoder.layers.\d+.self_attention.query_key_value.weight": [Shard(0)], - r"model.visual_encoder.blocks.\d+.attn.qkv.weight": [Shard(0)], - r"model.visual_encoder.blocks.\d+.attn.proj.weight": [Shard(1)], - r"model.visual_encoder.blocks.\d+.mlp.fc1.weight": [Shard(0)], - r"model.visual_encoder.blocks.\d+.mlp.fc2.weight": [Shard(1)], -} diff --git a/vescale/checkpoint/storage/filesystem.py b/vescale/checkpoint/storage/filesystem.py new file mode 100644 index 0000000..cfa1c00 --- /dev/null +++ b/vescale/checkpoint/storage/filesystem.py @@ -0,0 +1,880 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. +################################################################################ +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor +from ..utilities.mem_checkpoint import copy_gpu_tensor_to_cpu_pinned_mem_pool, deallocate_cpu_tensor_in_pinned_mem_pool +from abc import ABC, abstractmethod +import collections +from dataclasses import dataclass +import os +import dataclasses +import io +import torch.distributed as dist +import pickle +from typing import List, Tuple, Union, Dict, cast, Any +from ..utilities.logger import get_vescale_checkpoint_logger +import time +import torch +from torch import Tensor +from torch.futures import Future +from pathlib import Path + +from torch.distributed.checkpoint.metadata import ( + Metadata, + MetadataIndex, +) +from torch.distributed.checkpoint.storage import ( + StorageReader, + StorageWriter, + WriteResult, +) + +from torch.distributed.checkpoint.planner import ( + LoadItemType, + LoadPlanner, + LoadPlan, + SavePlan, + SavePlanner, + WriteItem, + ReadItem, + WriteItemType, +) + +from torch.distributed.checkpoint.utils import _create_file_view + +from torch.distributed._shard._utils import narrow_tensor_by_index +from torch._utils import _get_device_module + +logger = get_vescale_checkpoint_logger() +from vescale.checkpoint.planner.common import P2PTensorsInfo + +__all__ = [ + "FileSystemWriter", + "FileSystemReader", +] + + +@dataclass +class _StorageInfo: + """ + This is the per entry storage info + """ + + relative_path: str + offset: int + length: int + + +@dataclass +class _StoragePrefix: + prefix: str + + +DEFAULT_SUFFIX = ".distcp" + + +def _trim(tensor: torch.Tensor) -> torch.Tensor: + tensor = copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor.detach()) + # Comment the original DCP code + # When dumping to pinned memory, + # the memory layout for tensor has been contiguous + # if tensor._typed_storage()._size() != tensor.numel(): + # tensor = tensor.clone() + return tensor + + +def _result_from_write_item(item: WriteItem, size_in_bytes, storage_data) -> WriteResult: + return WriteResult(index=item.index, size_in_bytes=size_in_bytes, storage_data=storage_data) + + +class _TensorLoader(ABC): + @abstractmethod + def add(self, fqn, size, obj): + pass + + @abstractmethod + def start_loading(self): + pass + + @abstractmethod + def values(self): + pass + + +def collect_optim_state_across_dp_ranks( + tensor: torch.Tensor, rank_ranges: Dict[int, Any], p2p_reqs: Dict[int, Any] +) -> torch.Tensor: + orignal_shape = tensor.shape + tensor = tensor.flatten() + logger.debug("DEBUG: Start receiving p2p tensor") + recv_start = time.time() + for req in p2p_reqs: + req.wait() + recv_end = time.time() - recv_start + logger.debug(f"DEBUG: Finish receiving p2p tensor. Time cost: {recv_end}s") + for v in rank_ranges.values(): + received_tensor, param_range = v + tensor[param_range.start : param_range.end] = received_tensor + tensor = tensor.reshape(orignal_shape) + return tensor + + +class _SerialCpuLoader(_TensorLoader): + def __init__(self, resolve_fun, p2p_tensors_info: P2PTensorsInfo = None): + self.resolve_fun = resolve_fun + self.items = [] + self.p2p_tensors_info = p2p_tensors_info + + def add(self, fqn, size, obj): + self.items.append((fqn, size, obj)) + + def start_loading(self): + pass + + def values(self): + for fqn, _, obj in self.items: + tensor = self.resolve_fun(obj).detach() + if self.p2p_tensors_info and (obj.index.fqn, obj.index.offset) in self.p2p_tensors_info.recv_tensors: + tensor = collect_optim_state_across_dp_ranks( + tensor=tensor, + rank_ranges=self.p2p_tensors_info.recv_tensors[(obj.index.fqn, obj.index.offset)], + p2p_reqs=self.p2p_tensors_info.recv_p2p_reqs[(obj.index.fqn, obj.index.offset)], + ) + elif self.p2p_tensors_info and fqn in self.p2p_tensors_info.recv_tensors: + tensor = collect_optim_state_across_dp_ranks( + tensor=tensor, rank_ranges=self.p2p_tensors_info.recv_tensors[fqn], p2p_reqs=self.recv_p2p_reqs[fqn] + ) + tensor = copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor) + # Comment the original DCP code + # When dumping to pinned memory, + # the memory layout for tensor has been contiguous + # if tensor.storage().size() != tensor.numel(): + # tensor = tensor.clone() + yield ( + tensor, + obj, + ) + + +class _OverlappingCpuLoader(_TensorLoader): + def __init__( + self, + resolve_fun, + p2p_tensors_info: P2PTensorsInfo = None, + stream=None, + inflight_threshhold=1_000_000, + ): + self.resolve_fun = resolve_fun + self.items = [] + self.inflight_threshhold = inflight_threshhold + self.in_flight_data = 0 + self.current_items: collections.deque = collections.deque() + self.idx = 0 + self.started = False + self.device_type = stream.device_type if stream else torch.device("cuda").type + self.device_module = _get_device_module(self.device_type) + self.p2p_tensors_info = p2p_tensors_info + self.stream = stream or self.device_module.current_stream() + if self.stream != self.device_module.current_stream(): + self.stream.wait_stream(self.device_module.current_stream()) + + @property + def _done(self): + return self.idx >= len(self.items) + + def _drain(self): + drained = [] + if self.in_flight_data >= self.inflight_threshhold: + self.stream.synchronize() + while self.in_flight_data >= self.inflight_threshhold: + val = self.current_items.popleft() + self.in_flight_data -= val[0].numel() * val[0].element_size() + drained.append(val) + return drained + + def _refill(self): + with self.device_module.stream(self.stream): + while not self._done and self.in_flight_data < self.inflight_threshhold: + fqn, _, obj = self.items[self.idx] + self.idx += 1 + tensor = self.resolve_fun(obj).detach() + if self.p2p_tensors_info and (obj.index.fqn, obj.index.offset) in self.p2p_tensors_info.recv_tensors: + tensor = collect_optim_state_across_dp_ranks( + tensor=tensor, + rank_ranges=self.p2p_tensors_info.recv_tensors[(obj.index.fqn, obj.index.offset)], + p2p_reqs=self.p2p_tensors_info.recv_p2p_reqs[(obj.index.fqn, obj.index.offset)], + ) + elif self.p2p_tensors_info and fqn in self.p2p_tensors_info.recv_tensors: + tensor = collect_optim_state_across_dp_ranks( + tensor=tensor, + rank_ranges=self.p2p_tensors_info.recv_tensors[fqn], + p2p_reqs=self.p2p_tensors_info.recv_p2p_reqs[fqn], + ) + if tensor.device.type == self.device_type: + tensor = copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor, non_blocking=True) + # Comment the original DCP code + # When dumping to pinned memory, the memory layout for tensor has been contiguous + # elif tensor.device == torch.device("cpu"): + # if tensor.storage().size() != tensor.numel(): + # # this forces the tensor to be both contiguous and with minimal storage + # tensor = tensor.clone() + + self.current_items.append( + ( + tensor, + obj, + ) + ) + self.in_flight_data += tensor.numel() * tensor.element_size() + + def _finish(self): + assert self._done + if len(self.current_items) > 0: + self.stream.synchronize() + return self.current_items + + def add(self, fqn, size, obj): + if self.started: + raise RuntimeError("cannot add items after loading started") + self.items.append((fqn, size, obj)) + + def start_loading(self): + if self.started: + return + self.started = True + self.items.sort(key=lambda x: x[1]) + self._refill() + + def values(self): + self.start_loading() + while not self._done: + drained = self._drain() + self._refill() + yield from drained + + yield from self._finish() + + +def _item_fqn(item: WriteItem) -> str: + return item.index.fqn + + +def _item_size(item: WriteItem) -> int: + size = 1 + assert item.tensor_data is not None + # can't use math.prod as PT needs to support older python + for s in item.tensor_data.size: + size *= s + + dtype = item.tensor_data.properties.dtype + return size * torch._utils._element_size(dtype) + + +def _split_by_size_and_type(bins, items: List[WriteItem]) -> List[List[WriteItem]]: + if bins == 1: + return [items] + + bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO] + tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO] + + buckets: List[List[WriteItem]] = [[] for _ in range(bins)] + bucket_sizes = [0 for _ in range(bins)] + + tensor_w.sort(key=_item_size, reverse=True) + + for i, wi in enumerate(bytes_w): + buckets[i % bins].append(wi) + + for wi in tensor_w: + idx = min(enumerate(bucket_sizes), key=lambda x: x[1])[0] + buckets[idx].append(wi) + bucket_sizes[idx] += _item_size(wi) + + return buckets + + +def _write_item(stream, data, write_item, storage_key): + offset = stream.tell() + + if write_item.type == WriteItemType.BYTE_IO: + assert isinstance(data, io.BytesIO) + stream.write(data.getbuffer()) + else: + assert isinstance(data, torch.Tensor) + assert data.device == torch.device("cpu") + torch.save(data, stream) + length = stream.tell() - offset + + return _result_from_write_item(write_item, length, _StorageInfo(storage_key, offset, length)) + + +def _write_files_from_queue( + file_name, + storage_key, + write_items, + planner: SavePlanner, + inflight_threshhold: int, + use_fsync: bool, + p2p_tensors_info: P2PTensorsInfo = None, +): + loader: _TensorLoader + + if torch.cuda.is_available() and inflight_threshhold > 0: + loader = _OverlappingCpuLoader( + lambda x: planner.resolve_data(x), + inflight_threshhold=inflight_threshhold, + p2p_tensors_info=p2p_tensors_info, + ) + else: + loader = _SerialCpuLoader(lambda x: planner.resolve_data(x), p2p_tensors_info=p2p_tensors_info) + + tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO] + for write_item in tensor_w: + loader.add(_item_fqn(write_item), _item_size(write_item), write_item) + loader.start_loading() + + bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO] + write_results = [] + + stream = open(file_name, "wb") + logger.debug("Start writing byte io data.") + byte_io_write_start = time.time() + for write_item in bytes_w: + data = planner.resolve_data(write_item) + write_results.append(_write_item(stream, data, write_item, storage_key)) + byte_io_write_time = time.time() - byte_io_write_start + logger.debug(f"Finish writing byte io data. Time cost: {byte_io_write_time}s") + + logger.debug("Start writing tensor data.") + tensor_write_start = time.time() + for tensor, write_item in loader.values(): + assert tensor.is_cpu + write_results.append(_write_item(stream, tensor, write_item, storage_key)) + # WARNING: Call deallocate_cpu_tensor_in_pinned_mem_pooltensor + # when the reference to CPU tensor goes to zero + # so the memory pool will reuse the memory if possbile + # Othterwise, the memory pool will allocate memory on the used memory range, + # leading to cuda error 712 cudaErrorHostMemoryAlreadyRegistered + deallocate_cpu_tensor_in_pinned_mem_pool(tensor) + tensor_write_time = time.time() - tensor_write_start + logger.debug(f"Finish writing tensor data. Time cost: {tensor_write_time}s") + + if use_fsync: + os.fsync(stream.fileno()) + + file_stream_close_start = time.time() + stream.close() + file_stream_close_time = time.time() - file_stream_close_start + logger.debug(f"Finish closing file stream. Time cost: {file_stream_close_time}s") + return write_results + + +def _write_files_per_proc( + file_path: Path, + storage_key: str, + byte_data_item: List[Tuple[io.BytesIO, WriteItem]], + tensor_data_item: List[Tuple[torch.Tensor, WriteItem]], + use_fsync: bool, +) -> List[WriteResult]: + write_results = [] + stream = open(file_path, "wb") + # First write byte data. + for write_data, write_item in byte_data_item: + write_results.append(_write_item(stream, write_data, write_item, storage_key)) + # Then write tensor data. + # NOTE: the pinned memory occupied by each tensor have been reallocated. + for write_data, write_item in tensor_data_item: + write_results.append(_write_item(stream, write_data, write_item, storage_key)) + + if use_fsync: + os.fsync(stream.fileno()) + + return write_results + + +def _serialize_tensor(tensor: torch.Tensor) -> bytes: + bio = io.BytesIO() + # NOTE: currently use torch.save() to do the serialization. + torch.save(tensor, bio) + return bio.getbuffer() + + +def _write_to_file(stream, content: bytes, write_item: WriteItem, storage_key: str) -> WriteResult: + offset = stream.tell() + stream.write(content) + length = stream.tell() - offset + return _result_from_write_item(write_item, length, _StorageInfo(storage_key, offset, length)) + + +def _write_files_per_proc_pipe( + file_path: Path, + storage_key: str, + byte_data_item: List[Tuple[io.BytesIO, WriteItem]], + tensor_data_item: List[Tuple[torch.Tensor, WriteItem]], + use_fsync: bool, +) -> List[WriteResult]: + write_futures = [] + write_results = [] + stream = open(file_path, "wb") + executor = ThreadPoolExecutor(max_workers=1) + # For byte data, directly write byte data. + for write_data, write_item in byte_data_item: + content = write_data.getbuffer() + write_futures.append( + executor.submit( + _write_to_file, + stream, + content, + write_item, + storage_key, + ) + ) + # write_results.append(_write_to_file(stream, content, write_item, storage_key)) + # For tensor data, perform serialization in process then do saving in threadpool. + for write_data, write_item in tensor_data_item: + content = _serialize_tensor(write_data) + write_futures.append( + executor.submit( + _write_to_file, + stream, + content, + write_item, + storage_key, + ) + ) + # write_results.append(_write_to_file(stream, content, write_item, storage_key)) + + for fut in write_futures: + write_results.append(fut.result()) + if use_fsync: + os.fsync(stream.fileno()) + executor.shutdown(wait=False) + return write_results + + +def stat_analysis(tasks, planner, p2p_tensors_info, use_fsync=True) -> List[WriteResult]: + """ + Analyzing the overhead of D2H transfer, serialization, and save operations. Assume that + all items are written into one file. + """ + # Step1, aysnc D2H, dumping objects to pinned share memory. + assert len(tasks) == 1, "please generate one write task for analysis" + loader = _SerialCpuLoader(lambda x: planner.resolve_data(x), p2p_tensors_info=p2p_tensors_info) + # Add Bytes. + byte_item_to_write = [] + for task in tasks: + _, _, write_items = task + byte_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO] + byte_item_to_write.extend(byte_w) + # Add tenosrs. + tensor_item_to_write = [] + for task in tasks: + _, _, write_items = task + tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO] + tensor_item_to_write.extend(tensor_w) + for write_item in tensor_w: + loader.add(_item_fqn(write_item), _item_size(write_item), write_item) + loader.start_loading() + # Step1: dump to pinned memory pool. + d2h_dump_wait_start = time.time() + tensor_to_serialize: List[torch.Tensor] = [] + for tensor, write_item in loader.values(): + assert tensor.is_cpu + tensor_to_serialize.append(tensor) + deallocate_cpu_tensor_in_pinned_mem_pool(tensor) + d2h_dump_wait_time = torch.tensor(time.time() - d2h_dump_wait_start).cuda() + dist.all_reduce(d2h_dump_wait_time) + d2h_dump_wait_time = d2h_dump_wait_time.item() / dist.get_world_size() + if dist.get_rank() == 0: + logger.critical(f"End waiting for D2H tensors dumping Time: {d2h_dump_wait_time:.4f}s") + # Step2: call serialization workers to serialize objects. + serialize_wait_start = time.time() + tensor_data_to_write = [] + bio = io.BytesIO() + for tensor in tensor_to_serialize: + bio.seek(0) + bio.truncate(0) + torch.save(tensor, bio) + dump_b = bio.getvalue() + assert isinstance(dump_b, bytes) + tensor_data_to_write.append(dump_b) + serialize_wait_time = torch.tensor(time.time() - serialize_wait_start).cuda() + dist.all_reduce(serialize_wait_time) + serialize_wait_time = serialize_wait_time.item() / dist.get_world_size() + if dist.get_rank() == 0: + logger.critical(f"End waiting for serialization Time: {serialize_wait_time:.4f}s") + # Step3: save/upload the objects from memory to disk. + file_path = tasks[0][0] + storage_key = tasks[0][1] + write_results = [] + assert isinstance(file_path, Path) + save_upload_wait_start = time.time() + with open(file_path, "wb") as stream: + for write_item in byte_item_to_write: + offset = stream.tell() + data = planner.resolve_data(write_item) + stream.write(data.getbuffer()) + length = stream.tell() - offset + write_results.append(_result_from_write_item(write_item, length, _StorageInfo(storage_key, offset, length))) + for tensor_data, write_item in zip(tensor_data_to_write, tensor_item_to_write): + offset = stream.tell() + stream.write(tensor_data) + length = stream.tell() - offset + write_results.append(_result_from_write_item(write_item, length, _StorageInfo(storage_key, offset, length))) + if use_fsync: + os.fsync(stream.fileno()) + save_upload_wait_time = torch.tensor(time.time() - save_upload_wait_start).cuda() + dist.all_reduce(save_upload_wait_time) + save_upload_wait_time = save_upload_wait_time.item() / dist.get_world_size() + if dist.get_rank() == 0: + logger.critical(f"End waiting for tensors saving/uploading Time: {save_upload_wait_time:.4f}s") + return write_results + + +class FileSystemWriter(StorageWriter): + """ + Basic implementation of StorageWriter using file IO. + + This implementation makes the following assumptions and simplifications: + + * The checkpoint path is an empty or non-existing directory. + * File creation is atomic + + The checkpoint consist of one file per write request plus + a `.metadata` file with the serialized metadata. + + """ + + def __init__( + self, + path: Union[str, os.PathLike], + single_file_per_rank: bool = True, + sync_files: bool = True, + worker_count: int = 1, + per_process_copy_ahead: int = 10_000_000, + ) -> None: + """ + Initialize the writer pointing to `path` + + Args: + path: directory where the checkpoint will be written to. + single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True. + sync_files : force files to be synced to permanent storage. Default to True. + worker_count: Number of IO workers (processes) to use to write. Default to 1. + per_process_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb. + + N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure. + """ + super().__init__() + self.path = Path(path) + self.single_file_per_rank = single_file_per_rank + self.sync_files = sync_files + self.worker_count = worker_count + self.per_process_copy_ahead = per_process_copy_ahead + + def set_up_storage_writer(self, is_coordinator: bool) -> None: + pass + + def prepare_local_plan(self, plan: SavePlan, p2p_tensors_info: P2PTensorsInfo = None) -> SavePlan: + self.path.mkdir(parents=True, exist_ok=True) + self.p2p_tensors_info = p2p_tensors_info + return plan + + def prepare_global_plan(self, global_plan: List[SavePlan]) -> List[SavePlan]: + new_plans = [ + dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_")) for i, plan in enumerate(global_plan) + ] + return new_plans + + def prepare_write_data(self, tasks: List[Tuple[Path, str, List[WriteItem]]], planner: SavePlanner): + """ + First stage of saving, Perform Copy data to CPU (D2H). + + Args: + tasks: partitoned tasks for workers to conduct serialization and the actual saving. + planner: save planner used to resolve the bytes and tensor data. + async_io: whether do asynchrous D2H. + + NOTE: Currently we do D2H synchronously. + """ + + byte_data_item_writes: List[List[Tuple[io.BytesIO, WriteItem]]] = [] + tensor_data_item_writes: List[List[Tuple[torch.Tensor, WriteItem]]] = [] + file_path_names: List[Tuple[Path, str]] = [] + + # Perform D2H in copy stream. + d2h_dump_start = time.time() + for task in tasks: + file_path, file_name, write_items = task + byte_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO] + tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO] + byte_data_item = [(planner.resolve_data(wi), wi) for wi in byte_w] + tensor_data_item = [] + # Async copy to pinned CPU memory pool. + for item in tensor_w: + tensor = planner.resolve_data(item).detach() + fqn = _item_fqn(item) + + if self.p2p_tensors_info and fqn in self.p2p_tensors_info.recv_tensors: + tensor = collect_optim_state_across_dp_ranks( + tensor=tensor, + rank_ranges=self.p2p_tensors_info.recv_tensors[fqn], + p2p_reqs=self.p2p_tensors_info.recv_p2p_reqs[fqn], + ) + tensor = copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor, non_blocking=True) + tensor_data_item.append((tensor, item)) + byte_data_item_writes.append(byte_data_item) + tensor_data_item_writes.append(tensor_data_item) + file_path_names.append((file_path, file_name)) + + d2h_dump_time = time.time() - d2h_dump_start + logger.debug(f"End waiting for D2H copy. Time cost: {d2h_dump_time}s") + + # Deallocate pinned memory. + # NOTE: when prepare_write_data() is called next time, make sure the previous save event is completed. + # Otherwise, tensors in pinned memory pool may be overwritten. + for tensor_data_item in tensor_data_item_writes: + for tensor, _ in tensor_data_item: + assert tensor.is_cpu + deallocate_cpu_tensor_in_pinned_mem_pool(tensor) + + return byte_data_item_writes, tensor_data_item_writes, file_path_names + + def write_data( + self, plan: SavePlan, planner: SavePlanner, async_io: bool = False, io_workers=False + ) -> Future[List[WriteResult]]: + storage_plan: _StoragePrefix = plan.storage_data + file_count = 0 + + def gen_file(): + nonlocal file_count + file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}" + file_count += 1 + return file_name + + tasks: List[Tuple[Path, str, List[WriteItem]]] = [] + # Generate K tasks where K is the number of worker_count. + if self.single_file_per_rank: + for bucket in _split_by_size_and_type(self.worker_count, plan.items): + file_name = gen_file() + tasks.append((self.path / file_name, file_name, bucket)) + # Generate K tasks where K is the number of write items. + else: + for item in plan.items: + file_name = gen_file() + tasks.append((self.path / file_name, file_name, [item])) + logger.debug(f"Rank {dist.get_rank()} writes its checkpoint into {len(tasks)} files") + # Make sure the optimizer states across dp ranks + # has been sending to other ranks + # So the receiver can get it when writing tensors to local path + + if self.p2p_tensors_info: + logger.debug("Start waiting for sending p2p tensors futures") + p2p_tensor_send_wait_start = time.time() + for req in self.p2p_tensors_info.send_p2p_reqs: + req.wait() + p2p_tensor_send_wait_time = time.time() - p2p_tensor_send_wait_start + logger.debug(f"End waiting for sending p2p tensors futures Time: {p2p_tensor_send_wait_time}s") + + futures = [] + if not io_workers: + executor = ProcessPoolExecutor(max_workers=self.worker_count) + # executor = torch.multiprocessing.get_context("spawn").Pool(self.worker_count) + else: + executor = io_workers + + # ProcessPool VERSION. + if isinstance(executor, ProcessPoolExecutor): + byte_data_item_writes, tensor_data_item_writes, file_path_names = self.prepare_write_data(tasks, planner) + for byte_data_item, tensor_data_item, file_path_name in zip( + byte_data_item_writes, tensor_data_item_writes, file_path_names + ): + file_path, storage_key = file_path_name + worker_args = (file_path, storage_key, byte_data_item, tensor_data_item, self.sync_files) + futures.append(executor.submit(_write_files_per_proc_pipe, *worker_args)) + # futures.append(self._serialize_workers.apply_async(_write_files_per_proc, worker_args)) + if async_io: + return futures + else: + logger.debug("Start waiting for writing futures (serilization + save)") + future_wait_start = time.time() + for fut in futures: + fut.result() + # fut.wait() + future_wait_time = time.time() - future_wait_start + logger.debug(f"End waiting for writing futures. Time cost: {future_wait_time}s") + return futures + else: + # ThreadPool VERSION. + for task in tasks: + futures.append( + executor.submit( + _write_files_from_queue, + *task, + planner, + self.per_process_copy_ahead, + self.sync_files, + self.p2p_tensors_info, + ) + ) + if async_io: + return futures + else: + logger.debug("Start waiting for writing futures") + future_wait_start = time.time() + for fut in futures: + fut.result() + future_wait_time = time.time() - future_wait_start + logger.debug(f"End waiting for writing futures. Time cost: {future_wait_time}s") + return futures + + def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None: + storage_md = dict() + for wr_list in results: + storage_md.update({wr.index: wr.storage_data for wr in wr_list}) + metadata.storage_data = storage_md + with (self.path / ".metadata.tmp").open("wb") as metadata_file: + pickle.dump(metadata, metadata_file) + os.fsync(metadata_file.fileno()) + + (self.path / ".metadata.tmp").rename(self.path / ".metadata") + + +class FileSystemReader(StorageReader): + def __init__( + self, + path: Union[str, os.PathLike], + broadcast_tensors=False, + data_parallel_process_group=None, + ) -> None: + super().__init__() + self.path = path + self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict() + self.broadcast_tensors = broadcast_tensors + self.data_parallel_process_group = data_parallel_process_group + + # If broadcast_tensors is enabled, the data_parallel_process_group is not none + if self.broadcast_tensors: + assert self.data_parallel_process_group + + def _slice_file(self, file, sinfo: _StorageInfo): + return _create_file_view(file, sinfo.offset, sinfo.length) + + def _get_file_path(self, relative_path): + file_path = os.path.join(self.path, relative_path) + return file_path + + def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: + # group requests by file + per_file: Dict[str, List[ReadItem]] = dict() + for read_item in plan.items: + item_md = self.storage_data[read_item.storage_index] + path = item_md.relative_path + per_file.setdefault(path, []).append(read_item) + + # If broadcasting model tensors is enabled, + # let processes with dp_rank=0 load models and broadcast them to other processes + if self.broadcast_tensors: + self.read_data_with_broadcast(per_file=per_file, planner=planner) + else: + # Otherwise, let all ranks load tensors from files + self.read_from_files(per_file=per_file, planner=planner) + + fut: Future = Future() + fut.set_result(None) + + return fut + + def read_from_files(self, per_file: Dict[str, List[ReadItem]], planner: LoadPlanner): + for relative_path, reqs in per_file.items(): + file_path = self._get_file_path(relative_path) + with open(file_path, "rb") as file: + reqs = sorted(reqs, key=lambda req: self.storage_data[req.storage_index].offset) + for req in reqs: + item_md = self.storage_data[req.storage_index] + file_slice = self._slice_file(file, item_md) + if req.type == LoadItemType.BYTE_IO: + bytes = io.BytesIO(file_slice.read(item_md.length)) + bytes.seek(0) + planner.load_bytes(req, bytes) + else: + tensor = cast(Tensor, torch.load(file_slice, map_location="cpu")) + tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths) + target_tensor = planner.resolve_tensor(req).detach() + + assert ( + target_tensor.size() == tensor.size() + ), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" + target_tensor.copy_(tensor) + planner.commit_tensor(req, target_tensor) + + def read_data_with_broadcast(self, per_file: Dict[str, List[ReadItem]], planner: LoadPlanner): + for relative_path, reqs in per_file.items(): + if dist.get_rank(self.data_parallel_process_group) == 0: + file_path = self._get_file_path(relative_path) + file = open(file_path, "rb") + dist.barrier(self.data_parallel_process_group) + reqs = sorted(reqs, key=lambda req: self.storage_data[req.storage_index].offset) + for req in reqs: + if dist.get_rank(self.data_parallel_process_group) == 0: + item_md = self.storage_data[req.storage_index] + file_slice = self._slice_file(file, item_md) + + if req.type == LoadItemType.BYTE_IO: + if dist.get_rank(self.data_parallel_process_group) == 0: + object_list = [io.BytesIO(file_slice.read(item_md.length))] + else: + object_list = [None] + + dist.broadcast_object_list( + object_list, + src=dist.get_global_rank(self.data_parallel_process_group, 0), + group=self.data_parallel_process_group, + device=f"cuda:{torch.cuda.current_device()}", + ) + bytes = object_list[0] + bytes.seek(0) + planner.load_bytes(req, bytes) + else: + if dist.get_rank(self.data_parallel_process_group) == 0: + object_list = [cast(Tensor, torch.load(file_slice, map_location="cuda"))] + else: + object_list = [None] + dist.broadcast_object_list( + object_list, + src=dist.get_global_rank(self.data_parallel_process_group, 0), + group=self.data_parallel_process_group, + device=f"cuda:{torch.cuda.current_device()}", + ) + tensor = object_list[0].cpu() + tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths) + target_tensor = planner.resolve_tensor(req).detach() + + assert ( + target_tensor.size() == tensor.size() + ), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" + target_tensor.copy_(tensor) + planner.commit_tensor(req, target_tensor) + + # Implementing the abstract function in StorageReader + def read_metadata(self) -> Metadata: + metadata_path = self._get_file_path(".metadata") + with open(metadata_path, "rb") as metadata_file: + metadata = pickle.load(metadata_file) + return metadata + + def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None: + self.storage_data = metadata.storage_data + assert self.storage_data is not None + + def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: + return plan + + def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]: + return global_plan diff --git a/vescale/checkpoint/utilities/bfile.py b/vescale/checkpoint/utilities/bfile.py index 980dbc8..02b83d6 100644 --- a/vescale/checkpoint/utilities/bfile.py +++ b/vescale/checkpoint/utilities/bfile.py @@ -22,11 +22,11 @@ import enum import contextlib import uuid -from .logger import get_omnistore_logger +from .logger import get_vescale_checkpoint_logger import shutil from .server import mem_server_lib -logger = get_omnistore_logger() +logger = get_vescale_checkpoint_logger() BFILE_DEFAULT_TIMEOUT = None diff --git a/vescale/checkpoint/utilities/logger.py b/vescale/checkpoint/utilities/logger.py index 98bb964..e97abe8 100644 --- a/vescale/checkpoint/utilities/logger.py +++ b/vescale/checkpoint/utilities/logger.py @@ -226,11 +226,11 @@ def log_to_root(message, *args, **kwargs): setattr(logging, method_name, log_to_root) -class OmniStoreLogger: +class VeScaleCheckpointLogger: def __new__(cls): if not hasattr(cls, "instance"): level = logging.WARNING - level_str = os.environ.get("OMNISTORE_LOGGING_LEVEL", "WARNING").upper() + level_str = os.environ.get("VESCALE_CHECKPOINT_LOGGING_LEVEL", "WARNING").upper() if level_str in logging._nameToLevel: level = logging._nameToLevel[level_str] formatter = logging.Formatter( @@ -239,13 +239,13 @@ def __new__(cls): ) handler = logging.StreamHandler(stream=sys.stdout) handler.setFormatter(formatter) - cls.instance = logging.getLogger("omnistore") + cls.instance = logging.getLogger("vescale_checkpoint") cls.instance.addHandler(handler) cls.instance.setLevel(level) cls.instance.propagate = False return cls.instance -def get_omnistore_logger(): - """Get omnistore logger with logging level OMNISTORE_LOGGING_LEVEL, and output to stdout.""" - return OmniStoreLogger() +def get_vescale_checkpoint_logger(): + """Get vescale.checkpoint logger with logging level VESCALE_CHECKPOINT_LOGGING_LEVEL, and output to stdout.""" + return VeScaleCheckpointLogger() diff --git a/vescale/checkpoint/utilities/mem_checkpoint.py b/vescale/checkpoint/utilities/mem_checkpoint.py index 9549f62..ab69cd7 100644 --- a/vescale/checkpoint/utilities/mem_checkpoint.py +++ b/vescale/checkpoint/utilities/mem_checkpoint.py @@ -24,9 +24,9 @@ import pickle from . import bfile -from .logger import get_omnistore_logger +from .logger import get_vescale_checkpoint_logger -logger = get_omnistore_logger() +logger = get_vescale_checkpoint_logger() if hasattr(torch.storage, "TypedStorage"): TypedStorage = torch.storage.TypedStorage @@ -84,12 +84,72 @@ def allocate(self, nbytes: int): return s.pop() def deallocate(self, s): + # WARNING: Call deallocate when the reference to CPU tensor goes to zero + # so the memory pool will reuse the memory if possbile + # Othterwise, the memory pool will allocate memory on the used memory range, + # leading to cuda error 712 cudaErrorHostMemoryAlreadyRegistered with self._l: self._m[s.nbytes()].add(s) GLOBAL_POOL = PinnedStoragePool() +TID = threading.get_ident() + + +def copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor: torch.Tensor, non_blocking=False) -> torch.Tensor: + """ + Copy a tensor on GPU to pinned memory pool (host CPU memory). + The input tensor will not be modified + Args: + tensor: a tensor on cuda device + Return: + a tensor on cpu, whose data is the same as input tensor + """ + m = {} + _old_warning = getattr(torch.storage, "_warn_typed_storage_removal", None) + torch.storage._warn_typed_storage_removal = lambda *args, **kwags: None + + def persistent_id(o): + if torch.is_storage(o) or isinstance(o, TypedStorage): + storage = o + if storage._cdata in m: + return storage._cdata + if storage.device.type != "cpu": + copied = GLOBAL_POOL.allocate(storage.nbytes()) + copied.copy_(storage, non_blocking=non_blocking) + if isinstance(storage, TypedStorage): + copied = storage._new_wrapped_storage(copied) + else: + copied = storage.clone() + m[storage._cdata] = copied + return storage._cdata + return + + b = io.BytesIO() + p = pickle.Pickler(b) + p.persistent_id = persistent_id + p.dump(tensor) + b.seek(0) + up = pickle.Unpickler(b) + up.persistent_load = lambda i: m[i] + cpu_tensor = up.load() + """ + assert type(tensor) == torch.Tensor + storage_obj = tensor.storage() + cpu_storage = GLOBAL_POOL.allocate(storage_obj.nbytes()) + + cpu_storage.copy_(storage_obj, non_blocking=non_blocking) + cpu_tensor = torch.tensor(cpu_storage) + """ + torch.storage._warn_typed_storage_removal = _old_warning + return cpu_tensor + + +def deallocate_cpu_tensor_in_pinned_mem_pool(tensor: torch.Tensor): + "Deallocate CPU tensor in the global pinned memory pool" + GLOBAL_POOL.deallocate(tensor.untyped_storage()) + class _CalledOnce: def __init__(self, func): diff --git a/vescale/checkpoint/utilities/server/mem_file_service.proto b/vescale/checkpoint/utilities/server/mem_file_service.proto index 6ca5723..614a694 100644 --- a/vescale/checkpoint/utilities/server/mem_file_service.proto +++ b/vescale/checkpoint/utilities/server/mem_file_service.proto @@ -1,72 +1,72 @@ // Run // // python -m grpc_tools.protoc -I. --python_out=. --pyi_out=. \ -// --grpc_python_out=. ./omnistore/utilities/server/mem_file_service.proto +// --grpc_python_out=. ./checkpoint/utilities/server/mem_file_service.proto // // to generate new protos. syntax = "proto3"; -message OmniStoreWriteRequest { +message VeScaleCheckpointWriteRequest { bytes content = 1; string name = 8; } -message OmniStoreWriteResponse { +message VeScaleCheckpointWriteResponse { } -message OmniStoreReadRequest { +message VeScaleCheckpointReadRequest { string name = 1; } -message OmniStoreReadResponse { +message VeScaleCheckpointReadResponse { bytes content = 1; } -message OmniStoreRenameRequest { +message VeScaleCheckpointRenameRequest { string src = 1; string dst = 2; bool overwrite = 3; } -message OmniStoreRenameResponse { +message VeScaleCheckpointRenameResponse { } -message OmniStoreRemoveRequest { +message VeScaleCheckpointRemoveRequest { string name = 1; } -message OmniStoreRemoveResponse { +message VeScaleCheckpointRemoveResponse { } -message OmniStoreListdirRequest { +message VeScaleCheckpointListdirRequest { string name = 1; } -message OmniStoreListdirResponse { +message VeScaleCheckpointListdirResponse { repeated string names = 1; } -message OmniStoreExistsRequest { +message VeScaleCheckpointExistsRequest { string name = 1; } -message OmniStoreExistsResponse { +message VeScaleCheckpointExistsResponse { bool exists = 1; } -service OmniStoreMemFileService { - rpc Write(stream OmniStoreWriteRequest) returns (OmniStoreWriteResponse) { +service VeScaleCheckpointMemFileService { + rpc Write(stream VeScaleCheckpointWriteRequest) returns (VeScaleCheckpointWriteResponse) { } - rpc Read(OmniStoreReadRequest) returns (stream OmniStoreReadResponse) { + rpc Read(VeScaleCheckpointReadRequest) returns (stream VeScaleCheckpointReadResponse) { } - rpc Rename(OmniStoreRenameRequest) returns (OmniStoreRenameResponse) { + rpc Rename(VeScaleCheckpointRenameRequest) returns (VeScaleCheckpointRenameResponse) { } - rpc Remove(OmniStoreRemoveRequest) returns (OmniStoreRemoveResponse) { + rpc Remove(VeScaleCheckpointRemoveRequest) returns (VeScaleCheckpointRemoveResponse) { } - rpc Listdir(OmniStoreListdirRequest) returns (OmniStoreListdirResponse) { + rpc Listdir(VeScaleCheckpointListdirRequest) returns (VeScaleCheckpointListdirResponse) { } - rpc Exists(OmniStoreExistsRequest) returns (OmniStoreExistsResponse) { + rpc Exists(VeScaleCheckpointExistsRequest) returns (VeScaleCheckpointExistsResponse) { } } \ No newline at end of file diff --git a/vescale/checkpoint/utilities/server/mem_file_service_pb2.py b/vescale/checkpoint/utilities/server/mem_file_service_pb2.py index feebf88..5966cd2 100644 --- a/vescale/checkpoint/utilities/server/mem_file_service_pb2.py +++ b/vescale/checkpoint/utilities/server/mem_file_service_pb2.py @@ -15,8 +15,8 @@ # ################################################################################ # Generated by the protocol buffer compiler. DO NOT EDIT! -# source: omnistore/utilities/server/mem_file_service.proto -# Protobuf Python Version: 4.25.0 +# source: checkpoint/utilities/server/mem_file_service.proto +# Protobuf Python Version: 4.25.1 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor @@ -29,38 +29,38 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n1OmniStore/utilities/server/mem_file_service.proto"6\n\x15OmniStoreWriteRequest\x12\x0f\n\x07\x63ontent\x18\x01 \x01(\x0c\x12\x0c\n\x04name\x18\x08 \x01(\t"\x18\n\x16OmniStoreWriteResponse"$\n\x14OmniStoreReadRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"(\n\x15OmniStoreReadResponse\x12\x0f\n\x07\x63ontent\x18\x01 \x01(\x0c"E\n\x16OmniStoreRenameRequest\x12\x0b\n\x03src\x18\x01 \x01(\t\x12\x0b\n\x03\x64st\x18\x02 \x01(\t\x12\x11\n\toverwrite\x18\x03 \x01(\x08"\x19\n\x17OmniStoreRenameResponse"&\n\x16OmniStoreRemoveRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"\x19\n\x17OmniStoreRemoveResponse"\'\n\x17OmniStoreListdirRequest\x12\x0c\n\x04name\x18\x01 \x01(\t")\n\x18OmniStoreListdirResponse\x12\r\n\x05names\x18\x01 \x03(\t"&\n\x16OmniStoreExistsRequest\x12\x0c\n\x04name\x18\x01 \x01(\t")\n\x17OmniStoreExistsResponse\x12\x0e\n\x06\x65xists\x18\x01 \x01(\x08\x32\x91\x03\n\x17OmniStoreMemFileService\x12<\n\x05Write\x12\x16.OmniStoreWriteRequest\x1a\x17.OmniStoreWriteResponse"\x00(\x01\x12\x39\n\x04Read\x12\x15.OmniStoreReadRequest\x1a\x16.OmniStoreReadResponse"\x00\x30\x01\x12=\n\x06Rename\x12\x17.OmniStoreRenameRequest\x1a\x18.OmniStoreRenameResponse"\x00\x12=\n\x06Remove\x12\x17.OmniStoreRemoveRequest\x1a\x18.OmniStoreRemoveResponse"\x00\x12@\n\x07Listdir\x12\x18.OmniStoreListdirRequest\x1a\x19.OmniStoreListdirResponse"\x00\x12=\n\x06\x45xists\x12\x17.OmniStoreExistsRequest\x1a\x18.OmniStoreExistsResponse"\x00\x62\x06proto3' + b'\n2checkpoint/utilities/server/mem_file_service.proto">\n\x1dVeScaleCheckpointWriteRequest\x12\x0f\n\x07\x63ontent\x18\x01 \x01(\x0c\x12\x0c\n\x04name\x18\x08 \x01(\t" \n\x1eVeScaleCheckpointWriteResponse",\n\x1cVeScaleCheckpointReadRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"0\n\x1dVeScaleCheckpointReadResponse\x12\x0f\n\x07\x63ontent\x18\x01 \x01(\x0c"M\n\x1eVeScaleCheckpointRenameRequest\x12\x0b\n\x03src\x18\x01 \x01(\t\x12\x0b\n\x03\x64st\x18\x02 \x01(\t\x12\x11\n\toverwrite\x18\x03 \x01(\x08"!\n\x1fVeScaleCheckpointRenameResponse".\n\x1eVeScaleCheckpointRemoveRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"!\n\x1fVeScaleCheckpointRemoveResponse"/\n\x1fVeScaleCheckpointListdirRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"1\n VeScaleCheckpointListdirResponse\x12\r\n\x05names\x18\x01 \x03(\t".\n\x1eVeScaleCheckpointExistsRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"1\n\x1fVeScaleCheckpointExistsResponse\x12\x0e\n\x06\x65xists\x18\x01 \x01(\x08\x32\xf9\x03\n\x1fVeScaleCheckpointMemFileService\x12L\n\x05Write\x12\x1e.VeScaleCheckpointWriteRequest\x1a\x1f.VeScaleCheckpointWriteResponse"\x00(\x01\x12I\n\x04Read\x12\x1d.VeScaleCheckpointReadRequest\x1a\x1e.VeScaleCheckpointReadResponse"\x00\x30\x01\x12M\n\x06Rename\x12\x1f.VeScaleCheckpointRenameRequest\x1a .VeScaleCheckpointRenameResponse"\x00\x12M\n\x06Remove\x12\x1f.VeScaleCheckpointRemoveRequest\x1a .VeScaleCheckpointRemoveResponse"\x00\x12P\n\x07Listdir\x12 .VeScaleCheckpointListdirRequest\x1a!.VeScaleCheckpointListdirResponse"\x00\x12M\n\x06\x45xists\x12\x1f.VeScaleCheckpointExistsRequest\x1a .VeScaleCheckpointExistsResponse"\x00\x62\x06proto3' ) _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "omnistore.utilities.server.mem_file_service_pb2", _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "checkpoint.utilities.server.mem_file_service_pb2", _globals) if _descriptor._USE_C_DESCRIPTORS is False: DESCRIPTOR._options = None - _globals["_OMNISTOREWRITEREQUEST"]._serialized_start = 53 - _globals["_OMNISTOREWRITEREQUEST"]._serialized_end = 107 - _globals["_OMNISTOREWRITERESPONSE"]._serialized_start = 109 - _globals["_OMNISTOREWRITERESPONSE"]._serialized_end = 133 - _globals["_OMNISTOREREADREQUEST"]._serialized_start = 135 - _globals["_OMNISTOREREADREQUEST"]._serialized_end = 171 - _globals["_OMNISTOREREADRESPONSE"]._serialized_start = 173 - _globals["_OMNISTOREREADRESPONSE"]._serialized_end = 213 - _globals["_OMNISTORERENAMEREQUEST"]._serialized_start = 215 - _globals["_OMNISTORERENAMEREQUEST"]._serialized_end = 284 - _globals["_OMNISTORERENAMERESPONSE"]._serialized_start = 286 - _globals["_OMNISTORERENAMERESPONSE"]._serialized_end = 311 - _globals["_OMNISTOREREMOVEREQUEST"]._serialized_start = 313 - _globals["_OMNISTOREREMOVEREQUEST"]._serialized_end = 351 - _globals["_OMNISTOREREMOVERESPONSE"]._serialized_start = 353 - _globals["_OMNISTOREREMOVERESPONSE"]._serialized_end = 378 - _globals["_OMNISTORELISTDIRREQUEST"]._serialized_start = 380 - _globals["_OMNISTORELISTDIRREQUEST"]._serialized_end = 419 - _globals["_OMNISTORELISTDIRRESPONSE"]._serialized_start = 421 - _globals["_OMNISTORELISTDIRRESPONSE"]._serialized_end = 462 - _globals["_OMNISTOREEXISTSREQUEST"]._serialized_start = 464 - _globals["_OMNISTOREEXISTSREQUEST"]._serialized_end = 502 - _globals["_OMNISTOREEXISTSRESPONSE"]._serialized_start = 504 - _globals["_OMNISTOREEXISTSRESPONSE"]._serialized_end = 545 - _globals["_OMNISTOREMEMFILESERVICE"]._serialized_start = 548 - _globals["_OMNISTOREMEMFILESERVICE"]._serialized_end = 949 + _globals["_VESCALECHECKPOINTWRITEREQUEST"]._serialized_start = 54 + _globals["_VESCALECHECKPOINTWRITEREQUEST"]._serialized_end = 116 + _globals["_VESCALECHECKPOINTWRITERESPONSE"]._serialized_start = 118 + _globals["_VESCALECHECKPOINTWRITERESPONSE"]._serialized_end = 150 + _globals["_VESCALECHECKPOINTREADREQUEST"]._serialized_start = 152 + _globals["_VESCALECHECKPOINTREADREQUEST"]._serialized_end = 196 + _globals["_VESCALECHECKPOINTREADRESPONSE"]._serialized_start = 198 + _globals["_VESCALECHECKPOINTREADRESPONSE"]._serialized_end = 246 + _globals["_VESCALECHECKPOINTRENAMEREQUEST"]._serialized_start = 248 + _globals["_VESCALECHECKPOINTRENAMEREQUEST"]._serialized_end = 325 + _globals["_VESCALECHECKPOINTRENAMERESPONSE"]._serialized_start = 327 + _globals["_VESCALECHECKPOINTRENAMERESPONSE"]._serialized_end = 360 + _globals["_VESCALECHECKPOINTREMOVEREQUEST"]._serialized_start = 362 + _globals["_VESCALECHECKPOINTREMOVEREQUEST"]._serialized_end = 408 + _globals["_VESCALECHECKPOINTREMOVERESPONSE"]._serialized_start = 410 + _globals["_VESCALECHECKPOINTREMOVERESPONSE"]._serialized_end = 443 + _globals["_VESCALECHECKPOINTLISTDIRREQUEST"]._serialized_start = 445 + _globals["_VESCALECHECKPOINTLISTDIRREQUEST"]._serialized_end = 492 + _globals["_VESCALECHECKPOINTLISTDIRRESPONSE"]._serialized_start = 494 + _globals["_VESCALECHECKPOINTLISTDIRRESPONSE"]._serialized_end = 543 + _globals["_VESCALECHECKPOINTEXISTSREQUEST"]._serialized_start = 545 + _globals["_VESCALECHECKPOINTEXISTSREQUEST"]._serialized_end = 591 + _globals["_VESCALECHECKPOINTEXISTSRESPONSE"]._serialized_start = 593 + _globals["_VESCALECHECKPOINTEXISTSRESPONSE"]._serialized_end = 642 + _globals["_VESCALECHECKPOINTMEMFILESERVICE"]._serialized_start = 645 + _globals["_VESCALECHECKPOINTMEMFILESERVICE"]._serialized_end = 1150 # @@protoc_insertion_point(module_scope) diff --git a/vescale/checkpoint/utilities/server/mem_file_service_pb2.pyi b/vescale/checkpoint/utilities/server/mem_file_service_pb2.pyi index dc71884..bd7cb09 100644 --- a/vescale/checkpoint/utilities/server/mem_file_service_pb2.pyi +++ b/vescale/checkpoint/utilities/server/mem_file_service_pb2.pyi @@ -1,3 +1,19 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ from google.protobuf.internal import containers as _containers from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message @@ -5,7 +21,7 @@ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Optional as _Op DESCRIPTOR: _descriptor.FileDescriptor -class OmniStoreWriteRequest(_message.Message): +class VeScaleCheckpointWriteRequest(_message.Message): __slots__ = ("content", "name") CONTENT_FIELD_NUMBER: _ClassVar[int] NAME_FIELD_NUMBER: _ClassVar[int] @@ -13,23 +29,23 @@ class OmniStoreWriteRequest(_message.Message): name: str def __init__(self, content: _Optional[bytes] = ..., name: _Optional[str] = ...) -> None: ... -class OmniStoreWriteResponse(_message.Message): +class VeScaleCheckpointWriteResponse(_message.Message): __slots__ = () def __init__(self) -> None: ... -class OmniStoreReadRequest(_message.Message): +class VeScaleCheckpointReadRequest(_message.Message): __slots__ = ("name",) NAME_FIELD_NUMBER: _ClassVar[int] name: str def __init__(self, name: _Optional[str] = ...) -> None: ... -class OmniStoreReadResponse(_message.Message): +class VeScaleCheckpointReadResponse(_message.Message): __slots__ = ("content",) CONTENT_FIELD_NUMBER: _ClassVar[int] content: bytes def __init__(self, content: _Optional[bytes] = ...) -> None: ... -class OmniStoreRenameRequest(_message.Message): +class VeScaleCheckpointRenameRequest(_message.Message): __slots__ = ("src", "dst", "overwrite") SRC_FIELD_NUMBER: _ClassVar[int] DST_FIELD_NUMBER: _ClassVar[int] @@ -39,39 +55,39 @@ class OmniStoreRenameRequest(_message.Message): overwrite: bool def __init__(self, src: _Optional[str] = ..., dst: _Optional[str] = ..., overwrite: bool = ...) -> None: ... -class OmniStoreRenameResponse(_message.Message): +class VeScaleCheckpointRenameResponse(_message.Message): __slots__ = () def __init__(self) -> None: ... -class OmniStoreRemoveRequest(_message.Message): +class VeScaleCheckpointRemoveRequest(_message.Message): __slots__ = ("name",) NAME_FIELD_NUMBER: _ClassVar[int] name: str def __init__(self, name: _Optional[str] = ...) -> None: ... -class OmniStoreRemoveResponse(_message.Message): +class VeScaleCheckpointRemoveResponse(_message.Message): __slots__ = () def __init__(self) -> None: ... -class OmniStoreListdirRequest(_message.Message): +class VeScaleCheckpointListdirRequest(_message.Message): __slots__ = ("name",) NAME_FIELD_NUMBER: _ClassVar[int] name: str def __init__(self, name: _Optional[str] = ...) -> None: ... -class OmniStoreListdirResponse(_message.Message): +class VeScaleCheckpointListdirResponse(_message.Message): __slots__ = ("names",) NAMES_FIELD_NUMBER: _ClassVar[int] names: _containers.RepeatedScalarFieldContainer[str] def __init__(self, names: _Optional[_Iterable[str]] = ...) -> None: ... -class OmniStoreExistsRequest(_message.Message): +class VeScaleCheckpointExistsRequest(_message.Message): __slots__ = ("name",) NAME_FIELD_NUMBER: _ClassVar[int] name: str def __init__(self, name: _Optional[str] = ...) -> None: ... -class OmniStoreExistsResponse(_message.Message): +class VeScaleCheckpointExistsResponse(_message.Message): __slots__ = ("exists",) EXISTS_FIELD_NUMBER: _ClassVar[int] exists: bool diff --git a/vescale/checkpoint/utilities/server/mem_file_service_pb2_grpc.py b/vescale/checkpoint/utilities/server/mem_file_service_pb2_grpc.py index 978b388..4558bfa 100644 --- a/vescale/checkpoint/utilities/server/mem_file_service_pb2_grpc.py +++ b/vescale/checkpoint/utilities/server/mem_file_service_pb2_grpc.py @@ -20,11 +20,11 @@ import grpc from . import ( - mem_file_service_pb2 as OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2, + mem_file_service_pb2 as checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2, ) -class OmniStoreMemFileServiceStub: +class VeScaleCheckpointMemFileServiceStub: """Missing associated documentation comment in .proto file.""" def __init__(self, channel): @@ -34,38 +34,38 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.Write = channel.stream_unary( - "/OmniStoreMemFileService/Write", - request_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreWriteRequest.SerializeToString, - response_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreWriteResponse.FromString, + "/VeScaleCheckpointMemFileService/Write", + request_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointWriteRequest.SerializeToString, + response_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointWriteResponse.FromString, ) self.Read = channel.unary_stream( - "/OmniStoreMemFileService/Read", - request_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreReadRequest.SerializeToString, - response_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreReadResponse.FromString, + "/VeScaleCheckpointMemFileService/Read", + request_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointReadRequest.SerializeToString, + response_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointReadResponse.FromString, ) self.Rename = channel.unary_unary( - "/OmniStoreMemFileService/Rename", - request_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRenameRequest.SerializeToString, - response_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRenameResponse.FromString, + "/VeScaleCheckpointMemFileService/Rename", + request_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRenameRequest.SerializeToString, + response_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRenameResponse.FromString, ) self.Remove = channel.unary_unary( - "/OmniStoreMemFileService/Remove", - request_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRemoveRequest.SerializeToString, - response_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRemoveResponse.FromString, + "/VeScaleCheckpointMemFileService/Remove", + request_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRemoveRequest.SerializeToString, + response_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRemoveResponse.FromString, ) self.Listdir = channel.unary_unary( - "/OmniStoreMemFileService/Listdir", - request_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreListdirRequest.SerializeToString, - response_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreListdirResponse.FromString, + "/VeScaleCheckpointMemFileService/Listdir", + request_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointListdirRequest.SerializeToString, + response_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointListdirResponse.FromString, ) self.Exists = channel.unary_unary( - "/OmniStoreMemFileService/Exists", - request_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreExistsRequest.SerializeToString, - response_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreExistsResponse.FromString, + "/VeScaleCheckpointMemFileService/Exists", + request_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointExistsRequest.SerializeToString, + response_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointExistsResponse.FromString, ) -class OmniStoreMemFileServiceServicer: +class VeScaleCheckpointMemFileServiceServicer: """Missing associated documentation comment in .proto file.""" def Write(self, request_iterator, context): @@ -105,45 +105,45 @@ def Exists(self, request, context): raise NotImplementedError("Method not implemented!") -def add_OmniStoreMemFileServiceServicer_to_server(servicer, server): +def add_VeScaleCheckpointMemFileServiceServicer_to_server(servicer, server): rpc_method_handlers = { "Write": grpc.stream_unary_rpc_method_handler( servicer.Write, - request_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreWriteRequest.FromString, - response_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreWriteResponse.SerializeToString, + request_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointWriteRequest.FromString, + response_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointWriteResponse.SerializeToString, ), "Read": grpc.unary_stream_rpc_method_handler( servicer.Read, - request_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreReadRequest.FromString, - response_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreReadResponse.SerializeToString, + request_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointReadRequest.FromString, + response_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointReadResponse.SerializeToString, ), "Rename": grpc.unary_unary_rpc_method_handler( servicer.Rename, - request_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRenameRequest.FromString, - response_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRenameResponse.SerializeToString, + request_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRenameRequest.FromString, + response_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRenameResponse.SerializeToString, ), "Remove": grpc.unary_unary_rpc_method_handler( servicer.Remove, - request_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRemoveRequest.FromString, - response_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRemoveResponse.SerializeToString, + request_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRemoveRequest.FromString, + response_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRemoveResponse.SerializeToString, ), "Listdir": grpc.unary_unary_rpc_method_handler( servicer.Listdir, - request_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreListdirRequest.FromString, - response_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreListdirResponse.SerializeToString, + request_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointListdirRequest.FromString, + response_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointListdirResponse.SerializeToString, ), "Exists": grpc.unary_unary_rpc_method_handler( servicer.Exists, - request_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreExistsRequest.FromString, - response_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreExistsResponse.SerializeToString, + request_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointExistsRequest.FromString, + response_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointExistsResponse.SerializeToString, ), } - generic_handler = grpc.method_handlers_generic_handler("OmniStoreMemFileService", rpc_method_handlers) + generic_handler = grpc.method_handlers_generic_handler("VeScaleCheckpointMemFileService", rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) # This class is part of an EXPERIMENTAL API. -class OmniStoreMemFileService: +class VeScaleCheckpointMemFileService: """Missing associated documentation comment in .proto file.""" @staticmethod @@ -162,9 +162,9 @@ def Write( return grpc.experimental.stream_unary( request_iterator, target, - "/OmniStoreMemFileService/Write", - OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreWriteRequest.SerializeToString, - OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreWriteResponse.FromString, + "/VeScaleCheckpointMemFileService/Write", + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointWriteRequest.SerializeToString, + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointWriteResponse.FromString, options, channel_credentials, insecure, @@ -191,9 +191,9 @@ def Read( return grpc.experimental.unary_stream( request, target, - "/OmniStoreMemFileService/Read", - OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreReadRequest.SerializeToString, - OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreReadResponse.FromString, + "/VeScaleCheckpointMemFileService/Read", + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointReadRequest.SerializeToString, + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointReadResponse.FromString, options, channel_credentials, insecure, @@ -220,9 +220,9 @@ def Rename( return grpc.experimental.unary_unary( request, target, - "/OmniStoreMemFileService/Rename", - OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRenameRequest.SerializeToString, - OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRenameResponse.FromString, + "/VeScaleCheckpointMemFileService/Rename", + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRenameRequest.SerializeToString, + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRenameResponse.FromString, options, channel_credentials, insecure, @@ -249,9 +249,9 @@ def Remove( return grpc.experimental.unary_unary( request, target, - "/OmniStoreMemFileService/Remove", - OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRemoveRequest.SerializeToString, - OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRemoveResponse.FromString, + "/VeScaleCheckpointMemFileService/Remove", + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRemoveRequest.SerializeToString, + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRemoveResponse.FromString, options, channel_credentials, insecure, @@ -278,9 +278,9 @@ def Listdir( return grpc.experimental.unary_unary( request, target, - "/OmniStoreMemFileService/Listdir", - OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreListdirRequest.SerializeToString, - OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreListdirResponse.FromString, + "/VeScaleCheckpointMemFileService/Listdir", + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointListdirRequest.SerializeToString, + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointListdirResponse.FromString, options, channel_credentials, insecure, @@ -307,9 +307,9 @@ def Exists( return grpc.experimental.unary_unary( request, target, - "/OmniStoreMemFileService/Exists", - OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreExistsRequest.SerializeToString, - OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreExistsResponse.FromString, + "/VeScaleCheckpointMemFileService/Exists", + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointExistsRequest.SerializeToString, + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointExistsResponse.FromString, options, channel_credentials, insecure, diff --git a/vescale/checkpoint/utilities/server/mem_server_lib.py b/vescale/checkpoint/utilities/server/mem_server_lib.py index b12a8e1..2bf3af2 100644 --- a/vescale/checkpoint/utilities/server/mem_server_lib.py +++ b/vescale/checkpoint/utilities/server/mem_server_lib.py @@ -49,7 +49,7 @@ def get_mem_server_sock_file(name: str): return f"/var/tmp/mem_server_{name}.sock" -class MemFileServicer(mem_file_service_pb2_grpc.OmniStoreMemFileServiceServicer): +class MemFileServicer(mem_file_service_pb2_grpc.VeScaleCheckpointMemFileServiceServicer): def __init__(self): self._d = _Directory() @@ -66,7 +66,7 @@ def Write(self, request_iterator, ctx: grpc.ServicerContext): if name: with d.lock: d[bn] = _File(content=b.getvalue()) - return mem_file_service_pb2.OmniStoreWriteResponse() + return mem_file_service_pb2.VeScaleCheckpointWriteResponse() def Read(self, req, ctx: grpc.ServicerContext): d, bn = self._iterate_dir(req.name, ctx) @@ -76,7 +76,7 @@ def Read(self, req, ctx: grpc.ServicerContext): f: _File = d[bn] cur = 0 while cur < len(f.content): - yield mem_file_service_pb2.OmniStoreReadResponse(content=f.content[cur : cur + _CHUNK_SIZE]) + yield mem_file_service_pb2.VeScaleCheckpointReadResponse(content=f.content[cur : cur + _CHUNK_SIZE]) cur += _CHUNK_SIZE def Rename(self, req, ctx: grpc.ServicerContext): @@ -92,7 +92,7 @@ def Rename(self, req, ctx: grpc.ServicerContext): ctx.abort(grpc.StatusCode.ALREADY_EXISTS, f"{req.dst} already exists.") d[dst_bn] = d[src_bn] del d[src_bn] - return mem_file_service_pb2.OmniStoreRenameResponse() + return mem_file_service_pb2.VeScaleCheckpointRenameResponse() def Remove(self, req, ctx: grpc.ServicerContext): d, bn = self._iterate_dir(req.name, ctx) @@ -100,14 +100,14 @@ def Remove(self, req, ctx: grpc.ServicerContext): ctx.abort(grpc.StatusCode.NOT_FOUND, f"{req.name} not found.") with d.lock: del d[bn] - return mem_file_service_pb2.OmniStoreRemoveResponse() + return mem_file_service_pb2.VeScaleCheckpointRemoveResponse() def Listdir(self, req, ctx: grpc.ServicerContext): d, _ = self._iterate_dir(os.path.join(req.name, "*")) if d is None: - return mem_file_service_pb2.OmniStoreListdirResponse() + return mem_file_service_pb2.VeScaleCheckpointListdirResponse() - resp = mem_file_service_pb2.OmniStoreListdirResponse() + resp = mem_file_service_pb2.VeScaleCheckpointListdirResponse() with d.lock: for name in d: resp.names.append(name) @@ -116,9 +116,9 @@ def Listdir(self, req, ctx: grpc.ServicerContext): def Exists(self, req, ctx: grpc.ServicerContext): d, bn = self._iterate_dir(req.name) if d is None: - return mem_file_service_pb2.OmniStoreExistsResponse(exists=False) + return mem_file_service_pb2.VeScaleCheckpointExistsResponse(exists=False) with d.lock: - return mem_file_service_pb2.OmniStoreExistsResponse(exists=bn in d) + return mem_file_service_pb2.VeScaleCheckpointExistsResponse(exists=bn in d) def _iterate_dir(self, name: str, ctx: grpc.ServicerContext = None, create=False) -> Tuple[_Directory, str]: if ctx is None: @@ -152,7 +152,7 @@ def start_server(name: str, force=False): if os.path.exists(sock) and not force: raise OSError("Mem server is already running.") server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - mem_file_service_pb2_grpc.add_OmniStoreMemFileServiceServicer_to_server(MemFileServicer(), server) + mem_file_service_pb2_grpc.add_VeScaleCheckpointMemFileServiceServicer_to_server(MemFileServicer(), server) server.add_insecure_port(f"unix:{sock}") server.start() return server @@ -180,12 +180,12 @@ def _get_mem_name_and_name(path: str): def _get_stub_and_name( path: str, -) -> Tuple[mem_file_service_pb2_grpc.OmniStoreMemFileServiceStub, str]: +) -> Tuple[mem_file_service_pb2_grpc.VeScaleCheckpointMemFileServiceStub, str]: mem_name, name = _get_mem_name_and_name(path) if mem_name not in _STUB_CACHE: c = grpc.insecure_channel(f"unix:{get_mem_server_sock_file(mem_name)}") with _STUB_CACHE_LOCK: - _STUB_CACHE[mem_name] = mem_file_service_pb2_grpc.OmniStoreMemFileServiceStub(c) + _STUB_CACHE[mem_name] = mem_file_service_pb2_grpc.VeScaleCheckpointMemFileServiceStub(c) return _STUB_CACHE[mem_name], name @@ -204,7 +204,7 @@ def __init__(self, name: str, mode: str): def read_buf(self): if self._read_buf is None: self._read_buf = io.BytesIO() - for resp in self._stub.Read(mem_file_service_pb2.OmniStoreReadRequest(name=self._name)): + for resp in self._stub.Read(mem_file_service_pb2.VeScaleCheckpointReadRequest(name=self._name)): self._read_buf.write(resp.content) self._read_buf.seek(0) return self._read_buf @@ -223,7 +223,7 @@ def streaming(): break cur = 0 while cur < len(content): - req = mem_file_service_pb2.OmniStoreWriteRequest(content=content[cur : cur + _CHUNK_SIZE]) + req = mem_file_service_pb2.VeScaleCheckpointWriteRequest(content=content[cur : cur + _CHUNK_SIZE]) if cur == 0: req.name = self._name yield req @@ -254,18 +254,18 @@ def rename(src, dst, overwrite=False): dst_stub, dst_name = _get_stub_and_name(dst) if stub != dst_stub: raise ValueError(f"Rename across mem file system is not supported. {src} {dst}") - stub.Rename(mem_file_service_pb2.OmniStoreRenameRequest(src=src_name, dst=dst_name, overwrite=overwrite)) + stub.Rename(mem_file_service_pb2.VeScaleCheckpointRenameRequest(src=src_name, dst=dst_name, overwrite=overwrite)) def remove(name): stub, subname = _get_stub_and_name(name) - stub.Remove(mem_file_service_pb2.OmniStoreRemoveRequest(name=subname)) + stub.Remove(mem_file_service_pb2.VeScaleCheckpointRemoveRequest(name=subname)) def listdir(name): try: stub, subname = _get_stub_and_name(name) - resp = stub.Listdir(mem_file_service_pb2.OmniStoreListdirRequest(name=subname)) + resp = stub.Listdir(mem_file_service_pb2.VeScaleCheckpointListdirRequest(name=subname)) return list(resp.names) except grpc.RpcError as e: if e.code() == grpc.StatusCode.UNAVAILABLE: @@ -276,7 +276,7 @@ def listdir(name): def exists(name): try: stub, subname = _get_stub_and_name(name) - resp = stub.Exists(mem_file_service_pb2.OmniStoreExistsRequest(name=subname)) + resp = stub.Exists(mem_file_service_pb2.VeScaleCheckpointExistsRequest(name=subname)) return resp.exists except grpc.RpcError as e: if e.code() == grpc.StatusCode.UNAVAILABLE: @@ -297,7 +297,7 @@ def wait_until_fs_ready(name: str, timeout=120): t0 = time.time() while time.time() < t0 + timeout: try: - stub.Listdir(mem_file_service_pb2.OmniStoreListdirRequest(name="/")) + stub.Listdir(mem_file_service_pb2.VeScaleCheckpointListdirRequest(name="/")) return True except grpc.RpcError as e: if e.code() == grpc.StatusCode.UNAVAILABLE: diff --git a/vescale/checkpoint/utilities/server/report_service.proto b/vescale/checkpoint/utilities/server/report_service.proto index d5ead0c..fe6d8a8 100644 --- a/vescale/checkpoint/utilities/server/report_service.proto +++ b/vescale/checkpoint/utilities/server/report_service.proto @@ -1,13 +1,13 @@ // Run // // python -m grpc_tools.protoc -I. --python_out=. --pyi_out=. \ -// --grpc_python_out=. ./omnistore/utilities/server/report_service.proto +// --grpc_python_out=. ./checkpoint/utilities/server/report_service.proto // // to generate new protos. syntax = "proto3"; -message OmniStoreGatherRequest { +message VeScaleCheckpointGatherRequest { // Used to distinguish different tasks. string tag = 1; int32 rank = 2; @@ -15,35 +15,35 @@ message OmniStoreGatherRequest { bool with_result = 4; } -message OmniStoreGatherResponse { +message VeScaleCheckpointGatherResponse { repeated bytes contents = 1; } -message OmniStoreBroadcastRequest { +message VeScaleCheckpointBroadcastRequest { string tag = 1; int32 rank = 2; bytes content = 3; int32 src_rank = 4; } -message OmniStoreBroadcastResponse { +message VeScaleCheckpointBroadcastResponse { bytes content = 1; } -message OmniStoreGetStatusRequest { +message VeScaleCheckpointGetStatusRequest { } -message OmniStoreGetStatusResponse { +message VeScaleCheckpointGetStatusResponse { bytes status = 1; } -service OmniStoreReportService { - rpc Gather(OmniStoreGatherRequest) returns (OmniStoreGatherResponse) { +service VeScaleCheckpointReportService { + rpc Gather(VeScaleCheckpointGatherRequest) returns (VeScaleCheckpointGatherResponse) { } - rpc Broadcast(OmniStoreBroadcastRequest) returns (OmniStoreBroadcastResponse) { + rpc Broadcast(VeScaleCheckpointBroadcastRequest) returns (VeScaleCheckpointBroadcastResponse) { } - rpc GetStatus(OmniStoreGetStatusRequest) returns (OmniStoreGetStatusResponse) { + rpc GetStatus(VeScaleCheckpointGetStatusRequest) returns (VeScaleCheckpointGetStatusResponse) { } } diff --git a/vescale/checkpoint/utilities/server/report_service_pb2.py b/vescale/checkpoint/utilities/server/report_service_pb2.py index 2a7f110..be16555 100644 --- a/vescale/checkpoint/utilities/server/report_service_pb2.py +++ b/vescale/checkpoint/utilities/server/report_service_pb2.py @@ -15,8 +15,8 @@ # ################################################################################ # Generated by the protocol buffer compiler. DO NOT EDIT! -# source: omnistore/utilities/server/report_service.proto -# Protobuf Python Version: 4.25.0 +# source: checkpoint/utilities/server/report_service.proto +# Protobuf Python Version: 4.25.1 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor @@ -29,26 +29,26 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n/omnistore/utilities/server/report_service.proto"Y\n\x16OmniStoreGatherRequest\x12\x0b\n\x03tag\x18\x01 \x01(\t\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\x0c\x12\x13\n\x0bwith_result\x18\x04 \x01(\x08"+\n\x17OmniStoreGatherResponse\x12\x10\n\x08\x63ontents\x18\x01 \x03(\x0c"Y\n\x19OmniStoreBroadcastRequest\x12\x0b\n\x03tag\x18\x01 \x01(\t\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\x0c\x12\x10\n\x08src_rank\x18\x04 \x01(\x05"-\n\x1aOmniStoreBroadcastResponse\x12\x0f\n\x07\x63ontent\x18\x01 \x01(\x0c"\x1b\n\x19OmniStoreGetStatusRequest",\n\x1aOmniStoreGetStatusResponse\x12\x0e\n\x06status\x18\x01 \x01(\x0c\x32\xe7\x01\n\x16OmniStoreReportService\x12=\n\x06Gather\x12\x17.OmniStoreGatherRequest\x1a\x18.OmniStoreGatherResponse"\x00\x12\x46\n\tBroadcast\x12\x1a.OmniStoreBroadcastRequest\x1a\x1b.OmniStoreBroadcastResponse"\x00\x12\x46\n\tGetStatus\x12\x1a.OmniStoreGetStatusRequest\x1a\x1b.OmniStoreGetStatusResponse"\x00\x62\x06proto3' + b'\n0checkpoint/utilities/server/report_service.proto"a\n\x1eVeScaleCheckpointGatherRequest\x12\x0b\n\x03tag\x18\x01 \x01(\t\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\x0c\x12\x13\n\x0bwith_result\x18\x04 \x01(\x08"3\n\x1fVeScaleCheckpointGatherResponse\x12\x10\n\x08\x63ontents\x18\x01 \x03(\x0c"a\n!VeScaleCheckpointBroadcastRequest\x12\x0b\n\x03tag\x18\x01 \x01(\t\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\x0c\x12\x10\n\x08src_rank\x18\x04 \x01(\x05"5\n"VeScaleCheckpointBroadcastResponse\x12\x0f\n\x07\x63ontent\x18\x01 \x01(\x0c"#\n!VeScaleCheckpointGetStatusRequest"4\n"VeScaleCheckpointGetStatusResponse\x12\x0e\n\x06status\x18\x01 \x01(\x0c\x32\x9f\x02\n\x1eVeScaleCheckpointReportService\x12M\n\x06Gather\x12\x1f.VeScaleCheckpointGatherRequest\x1a .VeScaleCheckpointGatherResponse"\x00\x12V\n\tBroadcast\x12".VeScaleCheckpointBroadcastRequest\x1a#.VeScaleCheckpointBroadcastResponse"\x00\x12V\n\tGetStatus\x12".VeScaleCheckpointGetStatusRequest\x1a#.VeScaleCheckpointGetStatusResponse"\x00\x62\x06proto3' ) _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "omnistore.utilities.server.report_service_pb2", _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "checkpoint.utilities.server.report_service_pb2", _globals) if _descriptor._USE_C_DESCRIPTORS is False: DESCRIPTOR._options = None - _globals["_OMNISTOREGATHERREQUEST"]._serialized_start = 51 - _globals["_OMNISTOREGATHERREQUEST"]._serialized_end = 140 - _globals["_OMNISTOREGATHERRESPONSE"]._serialized_start = 142 - _globals["_OMNISTOREGATHERRESPONSE"]._serialized_end = 185 - _globals["_OMNISTOREBROADCASTREQUEST"]._serialized_start = 187 - _globals["_OMNISTOREBROADCASTREQUEST"]._serialized_end = 276 - _globals["_OMNISTOREBROADCASTRESPONSE"]._serialized_start = 278 - _globals["_OMNISTOREBROADCASTRESPONSE"]._serialized_end = 323 - _globals["_OMNISTOREGETSTATUSREQUEST"]._serialized_start = 325 - _globals["_OMNISTOREGETSTATUSREQUEST"]._serialized_end = 352 - _globals["_OMNISTOREGETSTATUSRESPONSE"]._serialized_start = 354 - _globals["_OMNISTOREGETSTATUSRESPONSE"]._serialized_end = 398 - _globals["_OMNISTOREREPORTSERVICE"]._serialized_start = 401 - _globals["_OMNISTOREREPORTSERVICE"]._serialized_end = 632 + _globals["_VESCALECHECKPOINTGATHERREQUEST"]._serialized_start = 52 + _globals["_VESCALECHECKPOINTGATHERREQUEST"]._serialized_end = 149 + _globals["_VESCALECHECKPOINTGATHERRESPONSE"]._serialized_start = 151 + _globals["_VESCALECHECKPOINTGATHERRESPONSE"]._serialized_end = 202 + _globals["_VESCALECHECKPOINTBROADCASTREQUEST"]._serialized_start = 204 + _globals["_VESCALECHECKPOINTBROADCASTREQUEST"]._serialized_end = 301 + _globals["_VESCALECHECKPOINTBROADCASTRESPONSE"]._serialized_start = 303 + _globals["_VESCALECHECKPOINTBROADCASTRESPONSE"]._serialized_end = 356 + _globals["_VESCALECHECKPOINTGETSTATUSREQUEST"]._serialized_start = 358 + _globals["_VESCALECHECKPOINTGETSTATUSREQUEST"]._serialized_end = 393 + _globals["_VESCALECHECKPOINTGETSTATUSRESPONSE"]._serialized_start = 395 + _globals["_VESCALECHECKPOINTGETSTATUSRESPONSE"]._serialized_end = 447 + _globals["_VESCALECHECKPOINTREPORTSERVICE"]._serialized_start = 450 + _globals["_VESCALECHECKPOINTREPORTSERVICE"]._serialized_end = 737 # @@protoc_insertion_point(module_scope) diff --git a/vescale/checkpoint/utilities/server/report_service_pb2.pyi b/vescale/checkpoint/utilities/server/report_service_pb2.pyi index a031c11..705daaf 100644 --- a/vescale/checkpoint/utilities/server/report_service_pb2.pyi +++ b/vescale/checkpoint/utilities/server/report_service_pb2.pyi @@ -1,3 +1,20 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + from google.protobuf.internal import containers as _containers from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message @@ -5,7 +22,7 @@ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Optional as _Op DESCRIPTOR: _descriptor.FileDescriptor -class OmniStoreGatherRequest(_message.Message): +class VeScaleCheckpointGatherRequest(_message.Message): __slots__ = ("tag", "rank", "content", "with_result") TAG_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] @@ -23,13 +40,13 @@ class OmniStoreGatherRequest(_message.Message): with_result: bool = ..., ) -> None: ... -class OmniStoreGatherResponse(_message.Message): +class VeScaleCheckpointGatherResponse(_message.Message): __slots__ = ("contents",) CONTENTS_FIELD_NUMBER: _ClassVar[int] contents: _containers.RepeatedScalarFieldContainer[bytes] def __init__(self, contents: _Optional[_Iterable[bytes]] = ...) -> None: ... -class OmniStoreBroadcastRequest(_message.Message): +class VeScaleCheckpointBroadcastRequest(_message.Message): __slots__ = ("tag", "rank", "content", "src_rank") TAG_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] @@ -47,17 +64,17 @@ class OmniStoreBroadcastRequest(_message.Message): src_rank: _Optional[int] = ..., ) -> None: ... -class OmniStoreBroadcastResponse(_message.Message): +class VeScaleCheckpointBroadcastResponse(_message.Message): __slots__ = ("content",) CONTENT_FIELD_NUMBER: _ClassVar[int] content: bytes def __init__(self, content: _Optional[bytes] = ...) -> None: ... -class OmniStoreGetStatusRequest(_message.Message): +class VeScaleCheckpointGetStatusRequest(_message.Message): __slots__ = () def __init__(self) -> None: ... -class OmniStoreGetStatusResponse(_message.Message): +class VeScaleCheckpointGetStatusResponse(_message.Message): __slots__ = ("status",) STATUS_FIELD_NUMBER: _ClassVar[int] status: bytes diff --git a/vescale/checkpoint/utilities/server/report_service_pb2_grpc.py b/vescale/checkpoint/utilities/server/report_service_pb2_grpc.py index 85f55c4..c5257c8 100644 --- a/vescale/checkpoint/utilities/server/report_service_pb2_grpc.py +++ b/vescale/checkpoint/utilities/server/report_service_pb2_grpc.py @@ -19,10 +19,12 @@ import grpc -from . import report_service_pb2 as OmniStore_dot_utilities_dot_server_dot_report__service__pb2 +from . import ( + report_service_pb2 as checkpoint_dot_utilities_dot_server_dot_report__service__pb2, +) -class OmniStoreReportServiceStub: +class VeScaleCheckpointReportServiceStub: """Missing associated documentation comment in .proto file.""" def __init__(self, channel): @@ -32,23 +34,23 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.Gather = channel.unary_unary( - "/OmniStoreReportService/Gather", - request_serializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGatherRequest.SerializeToString, - response_deserializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGatherResponse.FromString, + "/VeScaleCheckpointReportService/Gather", + request_serializer=checkpoint_dot_utilities_dot_server_dot_report__service__pb2.VeScaleCheckpointGatherRequest.SerializeToString, + response_deserializer=checkpoint_dot_utilities_dot_server_dot_report__service__pb2.VeScaleCheckpointGatherResponse.FromString, ) self.Broadcast = channel.unary_unary( - "/OmniStoreReportService/Broadcast", - request_serializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreBroadcastRequest.SerializeToString, - response_deserializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreBroadcastResponse.FromString, + "/VeScaleCheckpointReportService/Broadcast", + request_serializer=checkpoint_dot_utilities_dot_server_dot_report__service__pb2.VeScaleCheckpointBroadcastRequest.SerializeToString, + response_deserializer=checkpoint_dot_utilities_dot_server_dot_report__service__pb2.VeScaleCheckpointBroadcastResponse.FromString, ) self.GetStatus = channel.unary_unary( - "/OmniStoreReportService/GetStatus", - request_serializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGetStatusRequest.SerializeToString, - response_deserializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGetStatusResponse.FromString, + "/VeScaleCheckpointReportService/GetStatus", + request_serializer=checkpoint_dot_utilities_dot_server_dot_report__service__pb2.VeScaleCheckpointGetStatusRequest.SerializeToString, + response_deserializer=checkpoint_dot_utilities_dot_server_dot_report__service__pb2.VeScaleCheckpointGetStatusResponse.FromString, ) -class OmniStoreReportServiceServicer: +class VeScaleCheckpointReportServiceServicer: """Missing associated documentation comment in .proto file.""" def Gather(self, request, context): @@ -70,30 +72,30 @@ def GetStatus(self, request, context): raise NotImplementedError("Method not implemented!") -def add_OmniStoreReportServiceServicer_to_server(servicer, server): +def add_VeScaleCheckpointReportServiceServicer_to_server(servicer, server): rpc_method_handlers = { "Gather": grpc.unary_unary_rpc_method_handler( servicer.Gather, - request_deserializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGatherRequest.FromString, - response_serializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGatherResponse.SerializeToString, + request_deserializer=checkpoint_dot_utilities_dot_server_dot_report__service__pb2.VeScaleCheckpointGatherRequest.FromString, + response_serializer=checkpoint_dot_utilities_dot_server_dot_report__service__pb2.VeScaleCheckpointGatherResponse.SerializeToString, ), "Broadcast": grpc.unary_unary_rpc_method_handler( servicer.Broadcast, - request_deserializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreBroadcastRequest.FromString, - response_serializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreBroadcastResponse.SerializeToString, + request_deserializer=checkpoint_dot_utilities_dot_server_dot_report__service__pb2.VeScaleCheckpointBroadcastRequest.FromString, + response_serializer=checkpoint_dot_utilities_dot_server_dot_report__service__pb2.VeScaleCheckpointBroadcastResponse.SerializeToString, ), "GetStatus": grpc.unary_unary_rpc_method_handler( servicer.GetStatus, - request_deserializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGetStatusRequest.FromString, - response_serializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGetStatusResponse.SerializeToString, + request_deserializer=checkpoint_dot_utilities_dot_server_dot_report__service__pb2.VeScaleCheckpointGetStatusRequest.FromString, + response_serializer=checkpoint_dot_utilities_dot_server_dot_report__service__pb2.VeScaleCheckpointGetStatusResponse.SerializeToString, ), } - generic_handler = grpc.method_handlers_generic_handler("OmniStoreReportService", rpc_method_handlers) + generic_handler = grpc.method_handlers_generic_handler("VeScaleCheckpointReportService", rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) # This class is part of an EXPERIMENTAL API. -class OmniStoreReportService: +class VeScaleCheckpointReportService: """Missing associated documentation comment in .proto file.""" @staticmethod @@ -112,9 +114,9 @@ def Gather( return grpc.experimental.unary_unary( request, target, - "/OmniStoreReportService/Gather", - OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGatherRequest.SerializeToString, - OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGatherResponse.FromString, + "/VeScaleCheckpointReportService/Gather", + checkpoint_dot_utilities_dot_server_dot_report__service__pb2.VeScaleCheckpointGatherRequest.SerializeToString, + checkpoint_dot_utilities_dot_server_dot_report__service__pb2.VeScaleCheckpointGatherResponse.FromString, options, channel_credentials, insecure, @@ -141,9 +143,9 @@ def Broadcast( return grpc.experimental.unary_unary( request, target, - "/OmniStoreReportService/Broadcast", - OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreBroadcastRequest.SerializeToString, - OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreBroadcastResponse.FromString, + "/VeScaleCheckpointReportService/Broadcast", + checkpoint_dot_utilities_dot_server_dot_report__service__pb2.VeScaleCheckpointBroadcastRequest.SerializeToString, + checkpoint_dot_utilities_dot_server_dot_report__service__pb2.VeScaleCheckpointBroadcastResponse.FromString, options, channel_credentials, insecure, @@ -170,9 +172,9 @@ def GetStatus( return grpc.experimental.unary_unary( request, target, - "/OmniStoreReportService/GetStatus", - OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGetStatusRequest.SerializeToString, - OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGetStatusResponse.FromString, + "/VeScaleCheckpointReportService/GetStatus", + checkpoint_dot_utilities_dot_server_dot_report__service__pb2.VeScaleCheckpointGetStatusRequest.SerializeToString, + checkpoint_dot_utilities_dot_server_dot_report__service__pb2.VeScaleCheckpointGetStatusResponse.FromString, options, channel_credentials, insecure, diff --git a/vescale/checkpoint/utilities/server/server_lib.py b/vescale/checkpoint/utilities/server/server_lib.py index ee6830b..01cc1f4 100644 --- a/vescale/checkpoint/utilities/server/server_lib.py +++ b/vescale/checkpoint/utilities/server/server_lib.py @@ -44,7 +44,7 @@ class Item: ] -class ReportServicer(report_service_pb2_grpc.OmniStoreReportServiceServicer): +class ReportServicer(report_service_pb2_grpc.VeScaleCheckpointReportServiceServicer): """A servicer that simulate `gather` in sync training. Using asyncio since we will block all incoming requests until we gather all. @@ -60,17 +60,17 @@ def __init__(self, world_size: int): self._gather_dict = DefaultDict(Item) self._bc_dict = DefaultDict(Item) - async def Gather(self, req: report_service_pb2.OmniStoreGatherRequest, ctx: grpc.aio.ServicerContext): + async def Gather(self, req: report_service_pb2.VeScaleCheckpointGatherRequest, ctx: grpc.aio.ServicerContext): i = await self._record(self._gather_dict, req, ctx) - resp = report_service_pb2.OmniStoreGatherResponse() + resp = report_service_pb2.VeScaleCheckpointGatherResponse() if req.with_result: resp.contents.extend([v for k, v in sorted(i.contents.items(), key=lambda x: x[0])]) return resp - async def Broadcast(self, req: report_service_pb2.OmniStoreBroadcastRequest, ctx: grpc.aio.ServicerContext): + async def Broadcast(self, req: report_service_pb2.VeScaleCheckpointBroadcastRequest, ctx: grpc.aio.ServicerContext): i = await self._record(self._bc_dict, req, ctx) - return report_service_pb2.OmniStoreBroadcastResponse(content=i.contents[req.src_rank]) + return report_service_pb2.VeScaleCheckpointBroadcastResponse(content=i.contents[req.src_rank]) async def _record(self, d: Dict[str, Item], req, ctx: grpc.aio.ServicerContext): async with self._l: @@ -91,7 +91,7 @@ async def _record(self, d: Dict[str, Item], req, ctx: grpc.aio.ServicerContext): await i.cv.wait_for(lambda: len(i.ranks) == self._world_size) return i - async def GetStatus(self, req: report_service_pb2.OmniStoreGetStatusRequest, ctx: grpc.aio.ServicerContext): + async def GetStatus(self, req: report_service_pb2.VeScaleCheckpointGetStatusRequest, ctx: grpc.aio.ServicerContext): async with self._l: b = pickle.dumps( { @@ -100,7 +100,7 @@ async def GetStatus(self, req: report_service_pb2.OmniStoreGetStatusRequest, ctx "bc_dict": self._bc_dict, } ) - return report_service_pb2.OmniStoreGetStatusResponse(status=b) + return report_service_pb2.VeScaleCheckpointGetStatusResponse(status=b) def _is_ipv6_address(ip: str): @@ -133,7 +133,7 @@ class _AsyncObj: async def async_serve(servicer, async_addr: _AsyncObj): server: grpc.Server = grpc.aio.server(options=_GRPC_OPTIONS) - report_service_pb2_grpc.add_OmniStoreReportServiceServicer_to_server(servicer, server) + report_service_pb2_grpc.add_VeScaleCheckpointReportServiceServicer_to_server(servicer, server) port = server.add_insecure_port("[::]:0") await server.start() async_addr.obj = _concat_ip_and_port(_get_local_ip(), port) @@ -170,7 +170,7 @@ def start_server_in_new_process(world_size: int): def get_stub(addr: str): channel = grpc.insecure_channel(addr, options=_GRPC_OPTIONS) - return report_service_pb2_grpc.OmniStoreReportServiceStub(channel) + return report_service_pb2_grpc.VeScaleCheckpointReportServiceStub(channel) def _get_tag(): @@ -178,7 +178,7 @@ def _get_tag(): def gather( - stub: report_service_pb2_grpc.OmniStoreReportServiceStub, + stub: report_service_pb2_grpc.VeScaleCheckpointReportServiceStub, gather_rank: int, rank: int, obj, @@ -186,7 +186,7 @@ def gather( timeout=None, ): tag = tag or _get_tag() - req = report_service_pb2.OmniStoreGatherRequest( + req = report_service_pb2.VeScaleCheckpointGatherRequest( tag=tag, rank=rank, content=pickle.dumps(obj), with_result=(gather_rank == rank) ) resp = stub.Gather(req, timeout=timeout) @@ -196,7 +196,7 @@ def gather( def broadcast( - stub: report_service_pb2_grpc.OmniStoreReportServiceStub, + stub: report_service_pb2_grpc.VeScaleCheckpointReportServiceStub, src_rank: int, rank: int, obj=None, @@ -208,7 +208,7 @@ def broadcast( # Since we will transfer this to all machines, compression here is important. c_content = zlib.compress(content) resp = stub.Broadcast( - report_service_pb2.OmniStoreBroadcastRequest(tag=tag, rank=rank, content=c_content, src_rank=src_rank), + report_service_pb2.VeScaleCheckpointBroadcastRequest(tag=tag, rank=rank, content=c_content, src_rank=src_rank), timeout=timeout, ) content = zlib.decompress(resp.content) @@ -216,7 +216,7 @@ def broadcast( def barrier( - stub: report_service_pb2_grpc.OmniStoreReportServiceStub, + stub: report_service_pb2_grpc.VeScaleCheckpointReportServiceStub, rank: int, tag: str = None, timeout=None, @@ -224,6 +224,6 @@ def barrier( gather(stub, 0, rank, tag=tag, obj=None, timeout=timeout) -def get_server_status(stub: report_service_pb2_grpc.OmniStoreReportServiceStub): - resp = stub.GetStatus(report_service_pb2.OmniStoreGetStatusRequest()) +def get_server_status(stub: report_service_pb2_grpc.VeScaleCheckpointReportServiceStub): + resp = stub.GetStatus(report_service_pb2.VeScaleCheckpointGetStatusRequest()) return pickle.loads(resp.status) diff --git a/vescale/checkpoint/version.py b/vescale/checkpoint/version.py index f7d2536..3459a5a 100644 --- a/vescale/checkpoint/version.py +++ b/vescale/checkpoint/version.py @@ -14,4 +14,4 @@ # limitations under the License. # ################################################################################ -__version__ = "0.1.5" +__version__ = "0.1.18" diff --git a/vescale/dmodule/_dmodule.py b/vescale/dmodule/_dmodule.py index 5e03a30..9247433 100644 --- a/vescale/dmodule/_dmodule.py +++ b/vescale/dmodule/_dmodule.py @@ -197,6 +197,11 @@ def register_sharding_plan( v, Sequence ), "the placements for variable position arguments have to be list" pis[k] = [_norm_one_placements(p) for p in v] + elif isinstance(v, Dict): # nested dict sharding plan + pis[k] = {k_: _norm_one_placements(v_) for k_, v_ in v.items()} + elif (isinstance(v, (List, Tuple))) and not isinstance(v[0], Placement): + # nested list/tuple sharding plan + pis[k] = [_norm_one_placements(p) for p in v] else: pis[k] = _norm_one_placements(v) # register plan diff --git a/vescale/dmodule/_hook.py b/vescale/dmodule/_hook.py index 96277da..63523b5 100644 --- a/vescale/dmodule/_hook.py +++ b/vescale/dmodule/_hook.py @@ -78,10 +78,26 @@ class PreHookInput: def _convert(x: Any, pi: Optional[PI], device_mesh: DeviceMesh): return _convert_by_pi(x, pi, device_mesh, raise_err=False) + @staticmethod + def _convert_dictlike(input_dict: Dict[str, Any], pi_dict: DictFwdPIs, device_mesh: DeviceMesh): + assert isinstance(pi_dict, Dict), f"{type(input_dict)}" + new_output = {} + for key in input_dict: + if key in pi_dict: + new_output[key] = PreHookInput._convert(input_dict[key], pi_dict[key], device_mesh) + else: + new_output[key] = input_dict[key] + return type(input_dict)(**new_output) + + @staticmethod + def _convert_listlike(input_list: Dict[str, Any], pi_list: DictFwdPIs, device_mesh: DeviceMesh): + return [_convert_by_pi(x, pi, device_mesh, raise_err=False) for x, pi in zip(input_list, pi_list)] + @staticmethod def _hook(module: nn.Module, args: Any, kwargs: Any, device_mesh: DeviceMesh, input_pis: FwdPIs): convert = lambda x, pi: PreHookInput._convert(x, pi, device_mesh) convert_dictlike = lambda x_dict, pi_dict: PreHookInput._convert_dictlike(x_dict, pi_dict, device_mesh) + convert_listlike = lambda x_list, pi_list: PreHookInput._convert_listlike(x_list, pi_list, device_mesh) func_sig = get_sig(module) bound_args = func_sig.bind(*args, **kwargs) bound_args.apply_defaults() @@ -138,7 +154,12 @@ def _hook(module: nn.Module, args: Any, kwargs: Any, device_mesh: DeviceMesh, in bound_args.arguments["args"] = new_var_pos continue pi = input_pis[k] - bound_args.arguments[k] = convert(v, pi) + if isinstance(pi, Dict): + bound_args.arguments[k] = convert_dictlike(v, pi) + elif (isinstance(pi, (list, tuple))) and not isinstance(pi[0], Placement): + bound_args.arguments[k] = convert_listlike(v, pi) + else: + bound_args.arguments[k] = convert(v, pi) if var_keyward_name is None: return bound_args.args, bound_args.kwargs for k, v in bound_args.arguments[var_keyward_name].items(): diff --git a/vescale/dtensor/_dispatch_bypass.py b/vescale/dtensor/_dispatch_bypass.py index fb7bb37..fa81cf5 100644 --- a/vescale/dtensor/_dispatch_bypass.py +++ b/vescale/dtensor/_dispatch_bypass.py @@ -8,18 +8,17 @@ # Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. ################################################################################ +import functools +import operator + from typing import Dict, Tuple, cast import torch +import torch.distributed as dist -from vescale.dtensor.op_schema import ( - DTensorSpec, - OpInfo, - OutputSharding, -) -from vescale.dtensor.placement_types import TensorMeta +import vescale -__all__ = ["_bypass_for_dispatch", "_bypass_for_sharding_prop"] +__all__ = ["_bypass_for_dispatch"] aten = torch.ops.aten @@ -31,9 +30,15 @@ class BypassOpDispatch: def __init__(self): self.op_handlers = { + # origin bypass op dispatch func aten.linear.default: BypassOpDispatch.decompose_handler, aten.is_same_size.default: BypassOpDispatch.is_same_size_handler, + # from bypass op sharding prop aten.nonzero.default: BypassOpDispatch.nonzero_handler, + aten._to_copy.default: BypassOpDispatch.copy_handler, + aten._local_scalar_dense.default: BypassOpDispatch.scalar_handler, + aten.equal.default: BypassOpDispatch.equal_handler, + # other ? } def apply( @@ -42,9 +47,9 @@ def apply( args: Tuple[object, ...], kwargs: Dict[str, object], ) -> Tuple[bool, object]: - is_bypass = op_call in self.op_handlers - if is_bypass: - return True, self.op_handlers[op_call](op_call, args, kwargs) # type: ignore[operator] + bypass_call = self.op_handlers.get(op_call, None) + if bypass_call is not None: + return True, bypass_call(op_call, args, kwargs) # type: ignore[operator] else: return False, None @@ -80,16 +85,14 @@ def nonzero_handler( args: Tuple[object, ...], kwargs: Dict[str, object], ) -> object: - from vescale.dtensor import DTensor - input_ = kwargs.get("input", args[0]) - assert isinstance(input_, DTensor) + assert isinstance(input_, vescale.dtensor.DTensor) input_spec = input_._spec all_replicate = all(p.is_replicate() for p in input_spec.placements) assert all_replicate, "input placement has to be replicate" input_local = input_._local_tensor output_local = op_call(input_local) - return DTensor( + return vescale.dtensor.DTensor( local_tensor=output_local, device_mesh=input_spec.mesh, placements=input_spec.placements, @@ -99,6 +102,58 @@ def nonzero_handler( stride=output_local.stride(), ) + @staticmethod + def copy_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], + ) -> object: + input_dtensor = args[0] + input_local = input_dtensor._local_tensor + output_local = op_call(*(input_local, *(args[1:])), **kwargs) + return vescale.dtensor.DTensor( + local_tensor=output_local, + device_mesh=input_dtensor.device_mesh, + placements=input_dtensor.placements, + shape=input_dtensor.shape, + dtype=output_local.dtype, + requires_grad=output_local.requires_grad, + stride=input_dtensor.stride(), + ) + + @staticmethod + def scalar_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], + ) -> object: + input_dtensor = args[0] + input_local = input_dtensor._local_tensor + return op_call(*(input_local, *(args[1:])), **kwargs) + + @staticmethod + def equal_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], + ) -> object: + dtensor0 = args[0] + dtensor1 = args[1] + local0 = dtensor0._local_tensor + local1 = dtensor1._local_tensor + local_results = op_call(local0, local1, *(args[2:]), **kwargs) + if dtensor0._spec.is_replicated() and dtensor1._spec.is_replicated(): + return local_results + + obj_list = [None] * dist.get_world_size() + dist.all_gather_object(obj_list, local_results) # type: ignore[possibly-undefined] + obj_list = [e for e in obj_list if e is not None] + # perform reduce on the collection with AND op + # :NOTE: here is an implicit communication + local_results = functools.reduce(operator.and_, obj_list, True) + + return local_results + _bypass_op_dispatch = BypassOpDispatch() @@ -112,57 +167,3 @@ def _bypass_for_dispatch( Put bypass logic here before entering dtensor dispatching logic """ return _bypass_op_dispatch.apply(op_call, args, kwargs) - - -class BypassOpShardingProp: - """ - Register custom op handler to bypass sharding propagation here - """ - - def __init__(self): - self.op_handlers = { - aten._to_copy.default: BypassOpShardingProp.copy_handler, - aten._local_scalar_dense.default: BypassOpShardingProp.scalar_handler, - aten.equal.default: BypassOpShardingProp.scalar_handler, - } - - def apply(self, op_info: OpInfo) -> bool: - is_bypass = op_info.schema.op in self.op_handlers - if is_bypass: - op_info.output_sharding = self.op_handlers[op_info.schema.op](op_info) - return True - else: - return False - - @staticmethod - def copy_handler(op_info: OpInfo) -> OutputSharding: - op_schema = op_info.schema - kwargs = op_schema.gen_fake_kwargs() - dtype = kwargs["dtype"] - args_spec0 = op_schema.args_spec[0] - out_tensor_meta = TensorMeta( - shape=args_spec0.tensor_meta.shape, - stride=args_spec0.tensor_meta.stride, - dtype=dtype, - ) - return OutputSharding( - output_spec=DTensorSpec( - mesh=args_spec0.mesh, - placements=args_spec0.placements, - tensor_meta=out_tensor_meta, - ) - ) - - @staticmethod - def scalar_handler(op_info: OpInfo) -> OutputSharding: - return OutputSharding(None, [op_info.schema]) - - -_bypass_op_sharding_prop = BypassOpShardingProp() - - -def _bypass_for_sharding_prop(op_info: OpInfo) -> bool: - """ - Put bypass logic here before entering dtensor sharding propagation logic - """ - return _bypass_op_sharding_prop.apply(op_info) diff --git a/vescale/dtensor/_dispatch_patch.py b/vescale/dtensor/_dispatch_patch.py index cd0dd42..a792232 100644 --- a/vescale/dtensor/_dispatch_patch.py +++ b/vescale/dtensor/_dispatch_patch.py @@ -34,58 +34,102 @@ } -def hack_for_special_op( - op_call: torch._ops.OpOverload, - args: Tuple[object, ...], - kwargs: Dict[str, object], -): - new_args = list(args) - op_name = str(op_call) - if ( - op_name == "aten.index_put.default" - and not isinstance(args[2], dtensor.DTensor) - and isinstance(args[2], torch.Tensor) - and isinstance(args[0], dtensor.DTensor) +class DispatchPrePatch: + def __init__(self) -> None: + self.op_hackers = { + aten.index_put.default: DispatchPrePatch.index_put_handler, + aten.scatter_.value: DispatchPrePatch.scatter_handler, + aten.scatter.value: DispatchPrePatch.scatter_handler, + aten.scatter_.src: DispatchPrePatch.scatter_handler, + aten.scatter.src: DispatchPrePatch.scatter_handler, + aten.eq.Tensor: DispatchPrePatch.eq_tensor_handler, + aten.index.Tensor: DispatchPrePatch.index_tensor_handler, + } + + def apply( + self, + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], + ) -> Tuple[Tuple[object, ...], Dict[str, object]]: + hack_call = self.op_hackers.get(op_call, None) + if hack_call is not None: + return hack_call(op_call, args, kwargs) + else: + return args, kwargs + + @staticmethod + def index_put_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], ): - device_mesh = args[0]._spec.mesh - sharding = args[0]._spec.placements - new_args[2] = dtensor.DTensor.from_local(new_args[2], device_mesh, sharding) - return tuple(new_args), kwargs - elif ( - op_name in ["aten.scatter_.value", "aten.scatter.value", "aten.scatter_.src", "aten.scatter.src"] - and not isinstance(args[0], dtensor.DTensor) - and isinstance(args[0], torch.Tensor) - and isinstance(args[2], dtensor.DTensor) + if ( + (not isinstance(args[2], dtensor.DTensor)) + and isinstance(args[2], torch.Tensor) + and isinstance(args[0], dtensor.DTensor) + ): + device_mesh = args[0]._spec.mesh + sharding = args[0]._spec.placements + new_args_2 = dtensor.DTensor.from_local(args[2], device_mesh, sharding, run_check=False) + return (*args[:2], new_args_2, *args[3:]), kwargs + return args, kwargs + + @staticmethod + def scatter_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], ): - device_mesh = args[2]._spec.mesh - new_args[0] = dtensor.DTensor.from_local(new_args[0], device_mesh, [Replicate()]) - return tuple(new_args), kwargs - elif ( - str(op_call) == "aten.eq.Tensor" - and not isinstance(args[1], dtensor.DTensor) - and isinstance(args[0], dtensor.DTensor) - and isinstance(args[1], torch.Tensor) + if ( + (not isinstance(args[0], dtensor.DTensor)) + and isinstance(args[0], torch.Tensor) + and isinstance(args[2], dtensor.DTensor) + ): + device_mesh = args[2]._spec.mesh + new_args_0 = dtensor.DTensor.from_local(args[0], device_mesh, [Replicate()], run_check=False) + return (new_args_0, *args[1:]), kwargs + return args, kwargs + + @staticmethod + def eq_tensor_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], ): - device_mesh = args[0]._spec.mesh - new_args[1] = dtensor.DTensor.from_local(new_args[1], device_mesh, [Replicate()]) - return tuple(new_args), kwargs - # hack to DTensorialize the index of aten.index.Tensor op. - elif op_call in [aten.index.Tensor] and isinstance(args[0], dtensor.DTensor): - device_mesh = args[0]._spec.mesh - new_args = [] - new_args.append(args[0]) - new_args.append( - [ - dtensor.DTensor.from_local(x, device_mesh, [Replicate()], run_check=False) - if isinstance(x, torch.Tensor) and not isinstance(x, dtensor.DTensor) - else x - for x in args[1] - ] - ) - return tuple(new_args), kwargs - else: + if ( + (not isinstance(args[1], dtensor.DTensor)) + and isinstance(args[0], dtensor.DTensor) + and isinstance(args[1], torch.Tensor) + ): + device_mesh = args[0]._spec.mesh + new_args_1 = dtensor.DTensor.from_local(args[1], device_mesh, [Replicate()], run_check=False) + return (args[0], new_args_1, *args[2:]), kwargs return args, kwargs + @staticmethod + def index_tensor_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], + ): + if isinstance(args[0], dtensor.DTensor): + device_mesh = args[0]._spec.mesh + new_args = ( + args[0], + [ + dtensor.DTensor.from_local(x, device_mesh, [Replicate()], run_check=False) + if isinstance(x, torch.Tensor) and not isinstance(x, dtensor.DTensor) + else x + for x in args[1] + ], + ) + return new_args, kwargs + return args, kwargs + + +_dispatch_pre_patch = DispatchPrePatch() + def defer_resharding(op_call: torch._ops.OpOverload, dt_wrap: Any): if DeferReshardMode._enable_autoresharding() and op_call in _linear_pointwise_ops: @@ -151,7 +195,7 @@ def _pre_patch_for_dispatch(*args, **kwargs): Put patch logic here before entering dtensor dispatching logic """ failed_on_mqa(*args, **kwargs) - return hack_for_special_op(*args, **kwargs) + return _dispatch_pre_patch.apply(*args, **kwargs) def _post_patch_for_dispatch(*args, **kwargs): diff --git a/vescale/dtensor/_utils.py b/vescale/dtensor/_utils.py index 4127a61..fb50d95 100644 --- a/vescale/dtensor/_utils.py +++ b/vescale/dtensor/_utils.py @@ -9,6 +9,7 @@ ################################################################################ import warnings +import copy from typing import List, Sequence, Tuple, Optional, Dict, Set, Union import torch @@ -440,3 +441,23 @@ def compute_local_offset(global_shape: ShapeType, mesh: DeviceMesh, placements: local_shape[shard_dim] = shard_size local_offsets[shard_dim] = shard_offset return tuple(local_offsets) + + +def compute_global_stride( + local_tensor: torch.Tensor, mesh: DeviceMesh, placements: Sequence[Placement] +) -> Tuple[int, ...]: + """ """ + my_coordinate = mesh.get_coordinate() + if my_coordinate is None: + return () + if not local_tensor.is_contiguous(): + raise RuntimeError("local tensor should be contiguous") + global_stride = copy.deepcopy(list(local_tensor.stride())) + for i, p in enumerate(placements): + if not p.is_shard(): + continue + shard_dim = p.dim + shard_size = mesh.size(i) + for j in range(shard_dim): + global_stride[j] *= shard_size + return tuple(global_stride) diff --git a/vescale/dtensor/device_mesh.py b/vescale/dtensor/device_mesh.py index c9647b9..9951271 100644 --- a/vescale/dtensor/device_mesh.py +++ b/vescale/dtensor/device_mesh.py @@ -498,6 +498,38 @@ def shape(self) -> Tuple[int, ...]: def get_rank(self) -> int: return get_rank() + def get_local_rank(self, mesh_dim: Optional[int] = None) -> int: + """ + Returns the local rank of the given mesh_dim of the DeviceMesh. + + Args: + mesh_dim (int, optional): it is the index of the mesh dimension. Default is None. + + Returns: + An integer denotes the local rank. + + The following program runs on each process/rank in an SPMD manner. In this example, we have 2 + hosts with 4 GPUs each. + Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0. + Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3. + """ + if self.ndim > 1 and mesh_dim is None: + raise RuntimeError( + f"Found the DeviceMesh have {self.mesh.ndim} dimensions", + "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", + ) + elif mesh_dim is None: + mesh_dim = 0 + + mesh_dim_group = self.get_dim_groups(mesh_dim) + assert isinstance(mesh_dim_group, ProcessGroup), "We expect ProcessGroup before calling `get_rank`!" + + return get_rank(mesh_dim_group) + def get_coordinate(self) -> Optional[List[int]]: """ Return the relative indices of this rank relative to all diff --git a/vescale/dtensor/dispatch.py b/vescale/dtensor/dispatch.py index 64d5e63..6e63ef7 100644 --- a/vescale/dtensor/dispatch.py +++ b/vescale/dtensor/dispatch.py @@ -8,12 +8,9 @@ # Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. ################################################################################ -import functools -import operator from typing import Dict, List, Optional, Sequence, Tuple, cast import torch -import torch.distributed as dist from optree import tree_flatten, tree_unflatten import vescale.dtensor.dtensor as dtensor @@ -322,16 +319,6 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor: else: local_results = op_call(*local_tensor_args, **op_info.local_kwargs) - # communicate the result to all ranks for some operators that return scalar value - if output_sharding.output_spec is None: - if op_call == aten.equal.default: - obj_list = [None for _ in range(dist.get_world_size())] - dist.all_gather_object(obj_list, local_results) # type: ignore[possibly-undefined] - obj_list = list(filter(lambda x: x is not None, obj_list)) - # perform reduce on the collection with AND op - # :NOTE: here is an implicit communication - local_results = functools.reduce(operator.and_, obj_list, True) - # fill tensor_meta for replicate output as it bypassed tensor_meta prop # TODO: move to `_post_patch_for_dispatch`? if ( isinstance(output_sharding.output_spec, DTensorSpec) diff --git a/vescale/dtensor/loss.py b/vescale/dtensor/loss.py new file mode 100644 index 0000000..c53881f --- /dev/null +++ b/vescale/dtensor/loss.py @@ -0,0 +1,474 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. +################################################################################ + +import contextlib +from typing import cast, Dict, Optional, Tuple + +import torch +import torch._prims_common as utils +import torch.distributed._functional_collectives as funcol +import torch.distributed.distributed_c10d as c10d +from torch import Tensor +from vescale.dtensor import DTensor +from vescale.dtensor.placement_types import Replicate, Shard +from vescale.dtensor.ops.embedding_ops import _MaskPartial +from vescale.dtensor.ops.math_ops import ( + _skip_dim, + Reduction, + replicate_reduction_dims, +) +from vescale.dtensor.placement_types import Placement, TensorMeta +from vescale.dtensor.device_mesh import DeviceMesh +from vescale.dtensor._dispatch_bypass import _bypass_op_dispatch +import vescale.dtensor.dispatch as op_dispatch + +aten = torch.ops.aten + + +__all__ = ["loss_parallel"] + + +@contextlib.contextmanager +def loss_parallel(): + """ + A context manager that enables loss parallelism, where efficient parallelized loss computation + can be performed when the input is sharded on the class dimension. Currently only the cross-entropy + loss is supported. + + Within this context manager, one can use :func:`~torch.nn.functional.cross_entropy` or + :class:`~torch.nn.CrossEntropyLoss` as usual, with the following assumptions on the input parameters. + The corresponding ``backward()`` call, if any, also needs to happen under this context manager. + + Args: + input (:class:`DTensor`): + Input logits. Assumed to be sharded on the class dimension. + target (Union[:class:`torch.Tensor`, :class:`DTensor`]): + Must be ground truth class indices (class probabilities currently not supported). + Assumed to be replicated across the ``DeviceMesh``. + weight (Union[:class:`torch.Tensor`, :class:`DTensor`], optional): + If given, assumed to be replicated across the ``DeviceMesh``. + label_smoothing: + Currently not supported. + + Returns: + A replicated :class:`DTensor`. + + Example: + A sharded DTensor is manually created here to showcase the usage. + In practice, it is usually the output of a TP module. + + >>> # xdoctest: +SKIP("distributed") + >>> from torch.distributed.tensor.parallel import loss_parallel + >>> from torch.distributed.device_mesh import init_device_mesh + >>> ... + >>> device_mesh = init_device_mesh("cuda", (8,)) + >>> input = torch.randn(4, 16, device="cuda", requires_grad=True) + >>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)]) + >>> target = torch.randint(16, (4,), device="cuda") + >>> with loss_parallel(): + >>> loss = F.cross_entropy(dist_input, target, reduction="mean") + >>> loss.backward() + >>> ... + """ + _enable_custom_loss_ops() + + yield + + _disable_custom_loss_ops() + + +# Currently only needs to support one dimensional DeviceMesh; in general return +# the mesh_dim with placements[mesh_dim].is_shard(dim) +def _find_all_reduce_mesh_dim(placements: Tuple[Placement, ...], dim: int) -> int: + if not len(placements) == 1: + raise ValueError("Currently loss_parallel() only supports input on one-dimensional DeviceMesh.") + if not placements[0].is_shard(dim): + raise ValueError(f"loss_parallel() should be enabled only when the input tensor is sharded on dimension {dim}.") + return 0 + + +def _cast_to_dtensor(tensor, placements: Tuple[Placement, ...], mesh: DeviceMesh) -> DTensor: + if isinstance(tensor, DTensor): + if tensor.placements == placements: + return tensor + else: + raise RuntimeError(f"Expected {placements} but got {tensor.placements}.") + elif isinstance(tensor, torch.Tensor): + return DTensor.from_local(tensor, device_mesh=mesh, placements=placements, run_check=False) + else: + raise TypeError(f"Unsupported type {type(tensor)}") + + +def _propagate_tensor_meta( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], +) -> TensorMeta: + op_info = op_dispatch.unwrap_to_op_info(op_call, args, kwargs, DTensor._propagator) + tensor_meta = DTensor._propagator._propagate_tensor_meta(op_info.schema) + if isinstance(tensor_meta, TensorMeta): + return tensor_meta + elif isinstance(tensor_meta, tuple): + return tensor_meta[0] + else: + raise RuntimeError(f"Unexpected tensor meta type: {type(tensor_meta)}.") + + +# NOTE: The implementation follows torch._decomp.decomposition._log_softmax, +# with all_reduce manually inserted to perform distributed computation. +def _log_softmax(x, dim, half_to_float, mesh, mesh_dim): + x = x.contiguous() + if half_to_float: + assert x.dtype == torch.half + computation_dtype, result_dtype = utils.elementwise_dtypes( + x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + x = x.to(computation_dtype) + if x.numel() == 0: + shifted = x + else: + x_max = torch.amax(x, dim, keepdim=True) + x_max = funcol.all_reduce(x_max, reduceOp=c10d.ReduceOp.MAX.name, group=mesh._dim_group_infos[mesh_dim][1]) + shifted = x - x_max + shifted_sumexp = torch.sum(torch.exp(shifted), dim, keepdim=True) + shifted_sumexp = funcol.all_reduce( + shifted_sumexp, reduceOp=c10d.ReduceOp.SUM.name, group=mesh._dim_group_infos[mesh_dim][1] + ) + shifted_logsumexp = torch.log(shifted_sumexp) + result = shifted - shifted_logsumexp + if not half_to_float: + result = result.to(result_dtype) + return result + + +def _log_softmax_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], +) -> object: + x = cast(DTensor, args[0]) + dim = cast(int, args[1]) + half_to_float = cast(bool, args[2]) + + spec = x._spec + mesh_dim = _find_all_reduce_mesh_dim(spec.placements, dim) + + output_tensor_meta = _propagate_tensor_meta(op_call, args, kwargs) + + res = _log_softmax(x._local_tensor, dim, half_to_float, spec.mesh, mesh_dim) + + return DTensor( + res, + spec.mesh, + spec.placements, + shape=output_tensor_meta.shape, + dtype=output_tensor_meta.dtype, + requires_grad=res.requires_grad, + stride=output_tensor_meta.stride, + ) + + +# NOTE: As explained below at _nll_loss_and_log_softmax_backward, the +# _log_softmax_backward_handler does not actually do any computation. +def _log_softmax_backward_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], +) -> object: + grad_output = cast(DTensor, args[0]) + input_dtype = cast(torch.dtype, args[3]) + return grad_output.to(input_dtype) + + +# NOTE: The implementation follows torch._decomp.decomposition._nll_loss_forward, +# with customized communication inserted to perform distributed computation. +def _nll_loss_forward( + x: Tensor, + target: Tensor, + weight: Optional[Tensor], + local_weight: Optional[Tensor], + reduction: int, + ignore_index: int, + channel_dim_size: int, + mesh: DeviceMesh, + mesh_dim: int, +) -> Tuple[Tensor, Tensor]: + n_dims = x.dim() + channel_dim = 1 + if n_dims < 2: + channel_dim = 0 + + def _weight_view(weight: Tensor) -> Tensor: + if n_dims > 1: + shape = [ + 1, + ] * n_dims + shape[channel_dim] = weight.shape[0] + w = weight.view(shape) + else: + w = weight + return w + + if weight is not None: + w = _weight_view(weight) + assert local_weight is not None + local_w = _weight_view(local_weight) + x = x * local_w + safe_target = torch.where(target != ignore_index, target, 0) + safe_target_ = safe_target.unsqueeze(channel_dim) + + # The following code block is a distributed version of + # result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim) + partial_placement = _MaskPartial(logical_dim_size=channel_dim_size) + safe_target_partial_ = partial_placement._partition_value(safe_target_, mesh, mesh_dim) + result_partial = torch.gather(x, channel_dim, safe_target_partial_) + # an all_reduce happens here + result_reduced = partial_placement._reduce_value(result_partial, mesh, mesh_dim) + result = -result_reduced.squeeze(channel_dim) + + result = torch.where(target != ignore_index, result, 0) + + if reduction == Reduction.NONE.value and n_dims > 1: + total_weight = x.new_full((), 0.0) + return result, total_weight + + if weight is not None: + new_shape = list(x.shape) + new_shape[channel_dim] = -1 + w = w.expand(new_shape) + wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim) + wsum = torch.where(target != ignore_index, wsum, 0) + total_weight = wsum.sum() + else: + total_weight = (target != ignore_index).sum().to(x) + + # NOTE: this is correct only on 1D DeviceMesh; o/w additional + # all-reduce on result and total_weight is needed + if reduction == Reduction.SUM.value: + result = result.sum() + elif reduction == Reduction.MEAN.value: + result = result.sum() / total_weight + + return result, total_weight + + +def _nll_loss_forward_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], +) -> object: + x = cast(DTensor, args[0]) + target = args[1] + weight = args[2] + reduction = cast(int, args[3]) + ignore_index = cast(int, args[4]) + + channel_dim = 1 if x.dim() >= 2 else 0 + channel_dim_size = x.shape[channel_dim] + spec = x._spec + mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim) + + # Check user input: if target and weight are not DTensors, convert them to DTensors; + # if they are DTensors, check that they have the desired placements. + target_placements = _skip_dim(replicate_reduction_dims(spec.placements, [channel_dim]), channel_dim) + all_replicate_placements = (Replicate(),) * spec.mesh.ndim + target = _cast_to_dtensor(target, target_placements, spec.mesh) + local_weight = None + if weight is not None: + weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh) + # For local computation, both (replicated) weight and (sharded) local_weight + # are needed in _nll_loss_forward(). local_weight is generated here using + # DTensor API, without incurring any communication. + sharded_placements = [Shard(0) if i == mesh_dim else Replicate() for i in range(spec.mesh.ndim)] + local_weight = weight.redistribute(spec.mesh, sharded_placements)._local_tensor + assert local_weight.shape[0] == x._local_tensor.shape[channel_dim] + + if reduction == Reduction.NONE.value: + output_placements = target_placements + else: + output_placements = all_replicate_placements + + # tensor inputs to _propagate_tensor_meta need to be DTensors + args = list(args) + args[1], args[2] = target, weight + output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs) + + result, total_weight = _nll_loss_forward( + x._local_tensor, + target._local_tensor, + weight._local_tensor if weight is not None else None, + local_weight, + reduction, + ignore_index, + channel_dim_size, + spec.mesh, + mesh_dim, + ) + + return ( + DTensor( + result, + spec.mesh, + output_placements, + shape=output_tensor_meta.shape, + dtype=output_tensor_meta.dtype, + requires_grad=result.requires_grad, + stride=output_tensor_meta.stride, + ), + total_weight, + ) + + +# NOTE: The backward computation of cross_entropy goes through two steps: +# backward for nll_loss and then backward for log_softmax. In loss parallel, +# the two steps are fused into the following function (called by _nll_loss_backward_handler) +# to avoid communication when target contains class indices not class probabilities. +# Also note that the _log_softmax_backward_handler does not perform computation. +# The implementation resembles _nll_loss_backward and _log_softmax_backward_data +# from torch._decomp.decomposition. +def _nll_loss_and_log_softmax_backward( + grad_output: Tensor, + x: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, + total_weight: Tensor, + channel_dim_size: int, + mesh: DeviceMesh, + mesh_dim: int, +) -> Tensor: + channel_dim = 0 if x.dim() < 2 else 1 + if reduction == Reduction.MEAN.value: + grad_output = grad_output / total_weight + + target = target.unsqueeze(channel_dim) + safe_target = torch.where(target != ignore_index, target, 0) + grad_input = torch.zeros_like(x) + + # The following code block is a distributed version of + # grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0) + partial_placement = _MaskPartial(logical_dim_size=channel_dim_size) + safe_target = safe_target.squeeze(channel_dim).flatten() + masked_safe_target = partial_placement._partition_value(safe_target, mesh, mesh_dim) + # only update grad_input to -1 if not masked + assert partial_placement.mask_buffer.data is not None + grad_update = partial_placement.mask_buffer.data.float() - 1.0 + arange_1d = torch.arange(masked_safe_target.shape[0], device=masked_safe_target.device) + # The first two cases with x.dim() <= 2 are for aten.nll_loss_backward.default; + # the last case is for aten.nll_loss2d_backward.default. + if x.dim() == 1: + grad_input[masked_safe_target] = grad_update + elif x.dim() == 2: + grad_input[arange_1d, masked_safe_target] = grad_update + else: + grad_input_t = grad_input.transpose(channel_dim, -1) + intermidate_shape = grad_input_t.shape + grad_input_2d = grad_input_t.reshape(-1, x.shape[channel_dim]) + grad_input_2d[arange_1d, masked_safe_target] = grad_update + grad_input = grad_input_2d.view(intermidate_shape).transpose(channel_dim, -1) + + if grad_input.dim() > grad_output.dim() > 0: + grad_output = grad_output.unsqueeze(channel_dim) + + if weight is not None: + new_shape = [1 for _ in range(x.dim())] + new_shape[channel_dim] = weight.shape[0] + weight = weight.reshape(new_shape) + # In order for fused computation to work, the following line is rewritten. + # grad_output = grad_output * weight + new_shape = list(x.shape) + new_shape[channel_dim] = -1 + w = weight.expand(new_shape) + w_target = torch.gather(w, channel_dim, target) + grad_output = grad_output * w_target + + grad_output = torch.where(target != ignore_index, grad_output, 0) + + # NOTE: Instead of directly returning the grad_input as grad_output for log_softmax, + # here we perform backward computation for log_softmax altogether to avoid the + # otherwise extra all_gather communication. + # return grad_input * grad_output + return (grad_input + torch.exp(x)) * grad_output + + +def _nll_loss_backward_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], +) -> object: + grad_output = cast(DTensor, args[0]) + x = cast(DTensor, args[1]) + target = args[2] + weight = args[3] + reduction = cast(int, args[4]) + ignore_index = cast(int, args[5]) + total_weight = cast(Tensor, args[6]) + + channel_dim = 1 if x.dim() >= 2 else 0 + channel_dim_size = x.shape[channel_dim] + spec = x._spec + mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim) + + # if target and weight are not DTensors, convert them to DTensors + target_placements = _skip_dim(replicate_reduction_dims(spec.placements, [channel_dim]), channel_dim) + all_replicate_placements = (Replicate(),) * spec.mesh.ndim + target = _cast_to_dtensor(target, target_placements, spec.mesh) + if weight is not None: + weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh) + + # tensor inputs to _propagate_tensor_meta need to be DTensors + args = list(args) + args[2], args[3] = target, weight + args[6] = _cast_to_dtensor(total_weight, all_replicate_placements, spec.mesh) + output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs) + + result = _nll_loss_and_log_softmax_backward( + grad_output._local_tensor, + x._local_tensor, + target._local_tensor, + weight._local_tensor if weight is not None else None, + reduction, + ignore_index, + total_weight, + channel_dim_size, + spec.mesh, + mesh_dim, + ) + + return DTensor( + result, + spec.mesh, + # the output sharding is the same as input sharding: Shard(channel_dim) on mesh_dim + spec.placements, + shape=output_tensor_meta.shape, + dtype=output_tensor_meta.dtype, + requires_grad=result.requires_grad, + stride=output_tensor_meta.stride, + ) + + +customized_loss_ops = { + aten._log_softmax.default: _log_softmax_handler, + aten._log_softmax_backward_data.default: _log_softmax_backward_handler, + aten.nll_loss_forward.default: _nll_loss_forward_handler, + aten.nll_loss2d_forward.default: _nll_loss_forward_handler, + aten.nll_loss_backward.default: _nll_loss_backward_handler, + aten.nll_loss2d_backward.default: _nll_loss_backward_handler, +} + + +def _enable_custom_loss_ops(): + _bypass_op_dispatch.op_handlers.update(customized_loss_ops) + + +def _disable_custom_loss_ops(): + for custom_op in customized_loss_ops: + _bypass_op_dispatch.op_handlers.pop(custom_op) diff --git a/vescale/dtensor/ops/embedding_ops.py b/vescale/dtensor/ops/embedding_ops.py index b3c6506..cdd2dcd 100644 --- a/vescale/dtensor/ops/embedding_ops.py +++ b/vescale/dtensor/ops/embedding_ops.py @@ -10,8 +10,12 @@ # implement matrix related ops for distributed tensor import copy +from dataclasses import dataclass, field +from typing import cast, Optional, Tuple import torch +import torch.distributed.distributed_c10d as c10d +import torch.distributed._functional_collectives as funcol from vescale.dtensor.op_schema import OpSchema, OutputSharding from vescale.dtensor.ops.utils import ( @@ -20,11 +24,175 @@ is_tensor_all_replicate_except_sharded_at_dim, is_tensor_partial, ) -from vescale.dtensor.placement_types import DTensorSpec, Partial, Replicate, Shard +from vescale.dtensor.device_mesh import DeviceMesh +from vescale.dtensor.placement_types import DTensorSpec, Placement, Partial, Replicate, Shard +from vescale.dtensor.redistribute import _reduce_scatter_to_shard_with_pad aten = torch.ops.aten +@dataclass +class MaskBuffer: + data: Optional[torch.Tensor] = None + + def materialize_mask(self, mask): + if self.data is not None: + raise RuntimeError("MaskBuffer has already been materialized") + self.data = mask + + def release_mask(self): + # TODO: evaluate if we need to release the mask buffer or the buffer + # can just have the same lifetime as the _Partial placement + if self.data is None: + raise RuntimeError("MaskBuffer has not been materialized") + self.data = None + + def apply_mask(self, tensor): + if self.data is None: + raise RuntimeError("MaskBuffer has not been materialized") + + # NOTE: _MaskPartial is being used by the embedding op and the gather op. + # For gather, the mask has the same dimension as the output tensor, whereas + # the output of the embedding op has an additional dimension compare to the input, + # hence the output masking logic below having two different cases. + if tensor.ndim == self.data.ndim: + tensor[self.data] = 0.0 + else: + tensor[self.data, :] = 0.0 + + +@dataclass(frozen=True) +class _MaskPartial(Partial): + """ + A partial mask placement devised for rowwise sharded embedding op, where we need + to mask and adjust the indices to the local embedding shard, embedding masking + is a special type of the Partial placement + + NOTE: the lifecycle of this MaskPartial placement follows the corresponding DTensor + lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor. + """ + + logical_dim_size: int = -1 + mask_buffer: MaskBuffer = field(default_factory=MaskBuffer) + reduce_op: c10d.ReduceOp.RedOpType = c10d.ReduceOp.SUM + + def _local_shard_size_on_dim( + self, + size_on_dim: int, + num_chunks: int, + rank: int, + return_offset: bool = False, + ) -> Tuple[int, int]: + """ + returns the local shard size and offset on a given tensor dim + """ + assert ( + size_on_dim >= num_chunks + ), f"Size to be sharded on with dim_size {size_on_dim} must be at least as large \ + as the number of devices in that dimension {num_chunks}" + + # Compute the chunk size inline with ``torch.chunk`` + full_chunk_size = (size_on_dim + num_chunks - 1) // num_chunks + + # Compute chunk size for each chunk on the dimension. + chunk_sizes = [ + max( + min(size_on_dim, full_chunk_size * (idx + 1)) - full_chunk_size * idx, + 0, + ) + for idx in range(num_chunks) + ] + local_shard_size = chunk_sizes[rank] + + local_offset_on_dim = -1 + if return_offset: + # Return global tensor dim size of current dimension if for empty shard + # to represent the end of the corresponding tensor dim. + local_offset_on_dim = sum(chunk_sizes[:rank]) + + return (local_shard_size, local_offset_on_dim) + + def _partition_value(self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int) -> torch.Tensor: + # override parent logic to perform partial mask for embedding + num_chunks = mesh.size(mesh_dim) + # get local shard size and offset on the embedding_dim + local_shard_size, local_offset_on_dim = self._local_shard_size_on_dim( + self.logical_dim_size, + num_chunks, + mesh.get_local_rank(mesh_dim), + return_offset=True, + ) + # Build the input mask and save it for the current partial placement + # this is so that the output of embedding op can reuse the same partial + # placement saved mask to perform mask + reduction + mask = (tensor < local_offset_on_dim) | (tensor >= local_offset_on_dim + local_shard_size) + # mask the input tensor + masked_tensor = tensor.clone() - local_offset_on_dim + masked_tensor[mask] = 0 + # materialize the mask buffer to be used for reduction + self.mask_buffer.materialize_mask(mask) + return masked_tensor + + def _reduce_value(self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int) -> torch.Tensor: + # by the time we ned reduction, we should have already saved the mask + assert self.mask_buffer.data is not None + + # apply the mask to the tensor that pending reduction + self.mask_buffer.apply_mask(tensor) + + # clear the mask buffer + self.mask_buffer.release_mask() + + # perform sum reduction + return funcol.all_reduce(tensor, reduceOp=self.reduce_op.name, group=mesh._dim_group_infos[mesh_dim][1]) + + def _reduce_shard_value( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + # by the time we ned reduction, we should have already saved the mask + assert self.mask_buffer.data is not None + + # apply the mask to the tensor that pending reduction + self.mask_buffer.apply_mask(tensor) + + # clear the mask buffer + self.mask_buffer.release_mask() + + # call reduce_shard_tensor of the shard_spec. + shard_spec = cast(Shard, shard_spec) + return _reduce_scatter_to_shard_with_pad(tensor, mesh, self.reduce_op, mesh_dim, shard_spec) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _MaskPartial): + return False + + # if either data is not None, we invalidate the sharding cache, as this indicates + # the current MaskPartial placement is still in use and should not be used for cache hit. + if self.mask_buffer.data is not None or other.mask_buffer.data is not None: + return False + + return self.reduce_op == other.reduce_op and self.logical_dim_size == other.logical_dim_size + + def __hash__(self) -> int: + return 1 + hash((self.logical_dim_size, id(self.mask_buffer.data), self.reduce_op)) + + def __repr__(self) -> str: + """ + machine readable representation of the MaskPartial placement + """ + return f"_MaskPartial(logical_dim_size={self.logical_dim_size})" + + def __str__(self) -> str: + """ + human readable representation of the MaskPartial placement + """ + return "MaskP" + + # TODO: Enable BWD for embedding op. @register_prop_rule(aten.embedding.default) def embedding_rules(op_schema: OpSchema) -> OutputSharding: diff --git a/vescale/dtensor/ops/math_ops.py b/vescale/dtensor/ops/math_ops.py index 7aafeaf..683eb8f 100644 --- a/vescale/dtensor/ops/math_ops.py +++ b/vescale/dtensor/ops/math_ops.py @@ -9,6 +9,7 @@ ################################################################################ from typing import cast, List, Optional, Sequence, Tuple +from enum import Enum import torch import torch.distributed.distributed_c10d as c10d @@ -36,6 +37,12 @@ aten = torch.ops.aten +class Reduction(Enum): + NONE = 0 + MEAN = 1 + SUM = 2 + + def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[List[int]]: if dims_arg is None: return None @@ -62,6 +69,17 @@ def _infer_reduce_dims_map(reduction_dims: List[int], input_ndim: int, keep_dim= return reduction_dims_map +# return new_placements which align with placements but skip the skipped_dim +def _skip_dim(placements: Tuple[Placement, ...], skipped_dim: int) -> Tuple[Placement, ...]: + new_placements: List[Placement] = [] + for p in placements: + if isinstance(p, Shard) and p.dim >= skipped_dim: + new_placements.append(Shard(p.dim - 1)) + else: + new_placements.append(p) + return tuple(new_placements) + + def replicate_reduction_dims(placements: Tuple[Placement, ...], reduction_dims: List[int]) -> Tuple[Placement, ...]: # replicate the reduction dims if not reduction_linear new_placements: List[Placement] = [] diff --git a/vescale/dtensor/ops/pointwise_ops.py b/vescale/dtensor/ops/pointwise_ops.py index 0415514..781b86b 100644 --- a/vescale/dtensor/ops/pointwise_ops.py +++ b/vescale/dtensor/ops/pointwise_ops.py @@ -58,6 +58,7 @@ aten.add_.Tensor, aten.neg.default, aten.neg_.default, + aten.floor_divide_.Tensor, ] pointwise_ops = [ @@ -614,3 +615,12 @@ def foreach_list_linear_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> Strat for op in for_each_linearity_ops: register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))(foreach_list_linear_strategy) + + +from vescale.dtensor.ops.utils import register_prop_rule +from vescale.dtensor.op_schema import OutputSharding + + +@register_prop_rule([aten.clamp_max.default, aten.clamp_min.default]) +def clamp_max_rule(op_schema: OpSchema) -> OutputSharding: + return OutputSharding(op_schema.args_schema[0]) diff --git a/vescale/dtensor/ops/tensor_ops.py b/vescale/dtensor/ops/tensor_ops.py index 1c380fb..3e119d6 100644 --- a/vescale/dtensor/ops/tensor_ops.py +++ b/vescale/dtensor/ops/tensor_ops.py @@ -168,10 +168,17 @@ def new_factory_rule(op_schema: OpSchema) -> OutputSharding: ), "tensor meta must not be None if you are constructing a sharded tensor using `new_zeros` or something like that" original_numel = prod(input_spec.tensor_meta.shape) target_numel = prod(output_shape) - assert original_numel == target_numel, "for now, we only support the same numel in new_factory methods" from vescale.dtensor.ops.vescale_view_ops import vescale_view_rule_prop, ops + if original_numel != target_numel: + return OutputSharding( + output_spec=DTensorSpec( + mesh=mesh, + placements=tuple([Replicate()] * mesh.ndim), + ) + ) + spec = ops[torch.Tensor.view] output_sharding = vescale_view_rule_prop(op_schema=op_schema, spec=spec) return output_sharding diff --git a/vescale/dtensor/ops/vescale_view_ops.py b/vescale/dtensor/ops/vescale_view_ops.py index a348bc1..15420eb 100644 --- a/vescale/dtensor/ops/vescale_view_ops.py +++ b/vescale/dtensor/ops/vescale_view_ops.py @@ -419,8 +419,13 @@ def prop_as_strided_rule(op_schema: OpSchema) -> OutputSharding: memory_offset = args_schema[3] assert isinstance(input_spec, DTensorSpec) - assert memory_offset == 0, "for now, we only support 0 offset" + mesh = input_spec.mesh + + # NOTE: this is a hack for a simple case. + if not output_shape or input_spec.is_replicated(): + return OutputSharding(DTensorSpec(mesh=mesh, placements=tuple([Replicate()] * mesh.ndim))) + assert memory_offset == 0, "for now, we only support 0 offset" assert _check_tensor_contiguous(output_shape, output_stride), "for now, we only support contiguous output" assert input_spec.tensor_meta is not None diff --git a/vescale/dtensor/ops/view_ops.py b/vescale/dtensor/ops/view_ops.py index ace7730..e1363f0 100644 --- a/vescale/dtensor/ops/view_ops.py +++ b/vescale/dtensor/ops/view_ops.py @@ -13,6 +13,7 @@ import torch from torch import Tensor +import warnings from vescale.dtensor._utils import compute_local_shape from vescale.dtensor.op_schema import OpSchema, OutputSharding, RuntimeSchemaInfo @@ -295,6 +296,22 @@ def view_groups(from_size: Shape, to_size: Shape) -> DimMap: assert from_nelem == prod(to_size), "Total view shape does not add up" + if from_nelem == 0: + warnings.warn("An empty tensor is encountered during a view or reshape operation.", UserWarning) + + flattened = Flatten.new(tuple(InputDim(fi) for fi in range(len(from_size)) if from_size[fi] != 1)) + to_group_shape = [t for t in to_size if t != 1] + result_pp = [] + i = 0 + for t in to_size: + if t == 1: + result_pp.append(Singleton()) + else: + result_pp.append(Split.new(flattened, tuple(to_group_shape), i)) + i += 1 + + return tuple(result_pp) + from_idx = 0 to_idx = 0 from_len = len(from_size) diff --git a/vescale/dtensor/redistribute.py b/vescale/dtensor/redistribute.py index ee80dde..ede8c0e 100644 --- a/vescale/dtensor/redistribute.py +++ b/vescale/dtensor/redistribute.py @@ -26,6 +26,7 @@ from vescale.dtensor.device_mesh import DeviceMesh from vescale.dtensor.op_schema import DTensorSpec from vescale.dtensor.placement_types import InterleavedShard, Partial, Placement, Replicate, Shard +from vescale.dtensor._utils import compute_global_stride _PlacementItem = Tuple[int, Tuple[Placement, Placement]] @@ -74,7 +75,8 @@ def _decompose_reshard(val: List[_PlacementItem]) -> List[_PlacementItem]: if ( isinstance(current, Shard) and isinstance(target, Shard) - and (current.dim != target.dim or repeat_dim_current[current.dim] != repeat_dim_target[target.dim]) + and (isinstance(target, InterleavedShard) or isinstance(current, InterleavedShard)) + # and (current.dim != target.dim or repeat_dim_current[current.dim] != repeat_dim_target[target.dim]) ): # decompose Shard(i) -> Shard(j) into Shard(i) -> Replicate() -> Shard(j) output.append((i, (current, Replicate()))) @@ -244,6 +246,7 @@ def redistribute_local_tensor( current_placements = current_spec.placements target_placements = target_spec.placements sorted_placements = list(enumerate(zip(current_placements, target_placements))) + sorted_placements = _decompose_reshard(sorted_placements) sorted_placements.sort(key=_replicate_then_shard) for i, (current, target) in sorted_placements: @@ -328,7 +331,7 @@ def redistribute_local_tensor( # FIXME: for now, we don't support conversion # between InterleavedShard and Shard. Maybe we should provide # a method to transfer InterleavedShard to a contiguous Shard? - raise NotImplementedError("Redistributiom from Shard to InterleavedShard is not supported") + raise NotImplementedError("Redistribution from Shard to InterleavedShard is not supported") elif target.is_shard(): # Case 2: target is Shard target_placement = cast(Shard, target) @@ -527,17 +530,29 @@ def backward(ctx, grad_output: "dtensor.DTensor"): # Short cut the local tensor if the placements are the same. if current_spec == target_spec: output = local_tensor + output_dtensor = dtensor.DTensor( + output, + target_spec.mesh, + target_spec.placements, + shape=grad_output.shape, + dtype=grad_output.dtype, + requires_grad=grad_output.requires_grad, + stride=grad_output.stride(), + ) else: output = redistribute_local_tensor(local_tensor, current_spec, target_spec, async_op) - - output_dtensor = dtensor.DTensor( - output, - target_spec.mesh, - target_spec.placements, - shape=grad_output.shape, - dtype=grad_output.dtype, - requires_grad=grad_output.requires_grad, - stride=grad_output.stride(), - ) + output_dtensor = dtensor.DTensor( + output, + target_spec.mesh, + target_spec.placements, + shape=grad_output.shape, + dtype=grad_output.dtype, + requires_grad=grad_output.requires_grad, + stride=compute_global_stride(output, mesh=target_spec.mesh, placements=target_spec.placements) + # we found in some cases, after redistribute tensor, global stride will differ from the local + # stride, we forcely the global stride to match the local stride when local tensor is contiguous. + if output.is_contiguous() + else grad_output.stride(), + ) return (output_dtensor, None, None, None) diff --git a/vescale/dtensor/sharding_prop.py b/vescale/dtensor/sharding_prop.py index 6367d6d..a2fd397 100644 --- a/vescale/dtensor/sharding_prop.py +++ b/vescale/dtensor/sharding_prop.py @@ -31,7 +31,6 @@ TupleStrategy, ) from vescale.dtensor.placement_types import TensorMeta -from vescale.dtensor._dispatch_bypass import _bypass_for_sharding_prop aten = torch.ops.aten @@ -153,9 +152,6 @@ def _wrap_output_spec_tensor_meta( spec.tensor_meta = output_tensor_meta_i def propagate(self, op_info: OpInfo) -> None: - # bypass sharding prop of some ops for speed-up - if _bypass_for_sharding_prop(op_info): - return # We cannot use an lru cache if we know that inputs will have dynamic shapes, # because SymInts are not hashable. # This is generally ok because this only happens during tracing in torch.compile, diff --git a/vescale/optim/distributed_optimizer.py b/vescale/optim/distributed_optimizer.py index 0a28060..814499d 100644 --- a/vescale/optim/distributed_optimizer.py +++ b/vescale/optim/distributed_optimizer.py @@ -51,7 +51,7 @@ def __repr__(self) -> str: class OptimizerStateSpec: """This class represents mapping between local flattened 1D tensor and global original DTensor in DOptimzier, it is used for - loading or saving optimizer states using OmniStore (PyTorch DCP) + loading or saving optimizer states using vescale.checkpoint (PyTorch DCP) and load-time checkpoint resharding when changing tp size or dp size. For example, a linear layer in Vescale is DTensor(size=[1024, 1024]) @@ -85,7 +85,7 @@ class OptimizerStateSpec: global_offset: Tuple[int] # The unflattened local tensor after create view using local_shape on the flattened 1D Tensor in DOptimizer # NOTE: In order to support TP resharding and state cross dp ranks, we defer the reshaping from 1D to local_shape - # to generate saving plan using OmniStore (PyTorch DCP) + # to generate saving plan using vescale.checkpoint (PyTorch DCP) local_tensor: torch.Tensor # If the current optimizer state is sharded by multiple dp ranks, # we should record all ranks and their ranges @@ -244,7 +244,7 @@ def __init__( self.overlap_param_gather = overlap_param_gather self.grad_to_fp32 = grad_to_fp32 - # Model parameter sharding info for omnistore checkpointing + # Model parameter sharding info for vescale.checkpoint self.param_to_name = {} self.param_shard_info = {} self.param_global_shape_info = {} @@ -321,8 +321,8 @@ def __init__( except Exception: storage = bucket.data.storage().untyped() - # Typed param buffer. - param_buffer = torch.tensor(storage, dtype=self.main_param_dtype, device=bucket.data.device) + # use tensor._set(storage) to avoid unnecessary memory allocation. + param_buffer = torch.tensor([], dtype=self.main_param_dtype, device=storage.device).set_(storage) # .storage() ignores views / slices, so param_buffer now points to the start # of the grad_buffer instead of to the start of each bucket. As a result,