deepspeed comm ?
IlyasMoutawwakil committed Apr 28, 2024
1 parent 900367a · commit d3ba3c0
Showing 5 changed files with 45 additions and 36 deletions.
optimum_benchmark/backends/pytorch/backend.py (22 changes: 14 additions & 8 deletions)
@@ -26,6 +26,7 @@

if is_deepspeed_available():
import deepspeed
import deepspeed.comm

if is_torch_distributed_available():
import torch.distributed
@@ -132,10 +133,6 @@ def __init__(self, config: PyTorchConfig):
model=self.pretrained_model, config=self.config.deepspeed_inference_config
)

if is_torch_distributed_available() and torch.distributed.is_initialized():
LOGGER.info("\t+ Waiting for torch.distributed process group to synchronize")
torch.distributed.barrier()

def validate_library(self) -> None:
if self.config.library == "timm":
LOGGER.info(f"\t+ Using Timm's {self.automodel_class.__name__}")
@@ -358,14 +355,23 @@ def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:

@torch.inference_mode()
def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
if self.config.deepspeed_inference:
torch.distributed.barrier(group=deepspeed.comm.get_world_group())

return self.pretrained_model.forward(**inputs, **kwargs)

@torch.inference_mode()
def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
if self.config.deepspeed_inference:
torch.distributed.barrier(group=deepspeed.comm.get_world_group())

return self.pretrained_model.generate(**inputs, **kwargs)

@torch.inference_mode()
def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
if self.config.deepspeed_inference:
torch.distributed.barrier(group=deepspeed.comm.get_world_group())

return self.pretrained_model.generate(**inputs, **kwargs)

@torch.inference_mode()
@@ -399,10 +405,6 @@ def seed(self):
torch.cuda.manual_seed_all(self.config.seed)

def clean(self) -> None:
if is_torch_distributed_available() and torch.distributed.is_initialized():
LOGGER.info("\t+ Waiting for torch.distributed process group to synchronize")
torch.distributed.barrier()

if hasattr(self, "tmpdir"):
LOGGER.info("\t+ Cleaning backend temporary directory")
self.tmpdir.cleanup()
@@ -411,3 +413,7 @@ def clean(self) -> None:
LOGGER.info("\t+ Deleting pretrained model")
del self.pretrained_model
gc.collect()

if self.config.device == "cuda":
LOGGER.info("\t+ Emptying CUDA cache")
torch.cuda.empty_cache()
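
For context, the guard that forward, prefill and generate now share reduces to a small helper. This is an illustrative sketch only, assuming DeepSpeed is installed and the launcher has already initialized torch.distributed; `model`, `inputs`, `kwargs` and `use_deepspeed_inference` are hypothetical stand-ins for the backend's own attributes.

```python
import torch
import torch.distributed
import deepspeed.comm


@torch.inference_mode()
def synced_forward(model, inputs, kwargs, use_deepspeed_inference: bool):
    # When DeepSpeed tensor-parallel inference is active, align every rank on
    # DeepSpeed's world group before the measured call so timings share a start point.
    if use_deepspeed_inference:
        torch.distributed.barrier(group=deepspeed.comm.get_world_group())
    return model(**inputs, **kwargs)
```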
optimum_benchmark/launchers/torchrun/launcher.py (11 changes: 6 additions & 5 deletions)
@@ -93,24 +93,25 @@ def target(worker, queue, lock, log_level, *worker_args, launch_config: LaunchConfig):
@record
def entrypoint(worker, queue, lock, log_level, *worker_args):
torch.distributed.init_process_group()
process_group = torch.distributed.group.WORLD

rank = torch.distributed.get_rank()
rank = torch.distributed.get_rank(group=process_group)

if rank == 0:
setup_logging(level=log_level, prefix=f"TORCHRUN-RANK-{rank}")
else:
setup_logging(level=log_level, prefix=f"TORCHRUN-RANK-{rank}")
setup_logging(level="ERROR", prefix=f"TORCHRUN-RANK-{rank}")

if torch.cuda.is_available():
LOGGER.info("\t+ Setting torch.distributed cuda device")
torch.cuda.set_device(rank % torch.cuda.device_count())

torch.distributed.barrier()
torch.distributed.barrier(group=process_group)
output = worker(*worker_args)
torch.distributed.barrier()
torch.distributed.barrier(group=process_group)

lock.acquire()
queue.put(output)
lock.release()

torch.distributed.destroy_process_group()
torch.distributed.destroy_process_group(group=process_group)
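
Stripped of logging and queue plumbing, the entrypoint above follows this skeleton; the sketch assumes it is launched via torchrun with one process per GPU, and `worker` stands in for whatever callable the benchmark dispatches.

```python
import torch
import torch.distributed


def entrypoint(worker, *worker_args):
    # torchrun sets MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE, so no arguments are needed here.
    torch.distributed.init_process_group()
    process_group = torch.distributed.group.WORLD

    rank = torch.distributed.get_rank(group=process_group)

    if torch.cuda.is_available():
        # Pin one visible GPU per local rank.
        torch.cuda.set_device(rank % torch.cuda.device_count())

    # Fence the worker so all ranks enter and leave the benchmark together.
    torch.distributed.barrier(group=process_group)
    output = worker(*worker_args)
    torch.distributed.barrier(group=process_group)

    torch.distributed.destroy_process_group(group=process_group)
    return output
```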
optimum_benchmark/trackers/energy.py (26 changes: 13 additions & 13 deletions)
@@ -164,32 +164,32 @@ def __init__(self, backend: str, device: str, device_ids: Optional[str] = None):

@contextmanager
def track(self, file_prefix: str = "task"):
if self.is_distributed:
torch.distributed.barrier()

if self.is_asynchronous:
torch.cuda.synchronize()

if self.is_distributed:
torch.distributed.barrier()

self.emission_tracker.start_task()

yield

emission_data: EmissionsData = self.emission_tracker.stop_task()
if self.is_asynchronous:
torch.cuda.synchronize()

self.cpu_energy = emission_data.cpu_energy
self.gpu_energy = emission_data.gpu_energy
self.ram_energy = emission_data.ram_energy
self.total_energy = emission_data.energy_consumed
if self.is_distributed:
torch.distributed.barrier()

emission_data: EmissionsData = self.emission_tracker.stop_task()

with open(f"{file_prefix}_codecarbon.json", "w") as f:
LOGGER.info(f"\t+ Saving codecarbon emission data to {file_prefix}_codecarbon.json")
dump(asdict(emission_data), f, indent=4)

if self.is_asynchronous:
torch.cuda.synchronize()

if self.is_distributed:
torch.distributed.barrier()
self.cpu_energy = emission_data.cpu_energy
self.gpu_energy = emission_data.gpu_energy
self.ram_energy = emission_data.ram_energy
self.total_energy = emission_data.energy_consumed

def get_energy(self) -> Energy:
return Energy(
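
The tracking window above hinges on aligning ranks and draining CUDA work on both sides of a codecarbon task. A minimal sketch of that shape, assuming codecarbon's task API (`start_task`/`stop_task`) and an already-initialized process group; the tracker and boolean flags are hypothetical parameters rather than the tracker class's real attributes.

```python
from contextlib import contextmanager

import torch
import torch.distributed
from codecarbon import EmissionsTracker


@contextmanager
def energy_window(tracker: EmissionsTracker, distributed: bool, asynchronous: bool):
    # Open the window only once all ranks are aligned and the GPU queue is drained.
    if distributed:
        torch.distributed.barrier()
    if asynchronous:
        torch.cuda.synchronize()

    tracker.start_task()
    yield

    # Drain and re-align before closing the window so no in-flight work is missed.
    if asynchronous:
        torch.cuda.synchronize()
    if distributed:
        torch.distributed.barrier()

    data = tracker.stop_task()
    print(f"energy consumed: {data.energy_consumed} kWh")
```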
optimum_benchmark/trackers/latency.py (20 changes: 10 additions & 10 deletions)
@@ -164,8 +164,8 @@ def _cpu_latency(self):
self.end_events.append(time.perf_counter())

def get_latency(self) -> Latency:
# if self.is_distributed:
# torch.distributed.barrier()
if self.is_distributed:
torch.distributed.barrier()

if self.is_asynchronous:
torch.cuda.synchronize()
@@ -227,8 +227,8 @@ def on_step_end(self, *args, **kwargs):
self.end_events.append(time.perf_counter())

def get_latency(self) -> Latency:
# if self.is_distributed:
# torch.distributed.barrier()
if self.is_distributed:
torch.distributed.barrier()

if self.is_asynchronous:
torch.cuda.synchronize()
@@ -311,8 +311,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
return scores

def get_prefill_latency(self) -> Latency:
# if self.is_distributed:
# torch.distributed.barrier()
if self.is_distributed:
torch.distributed.barrier()

if self.is_asynchronous:
torch.cuda.synchronize()
@@ -332,8 +332,8 @@ def get_prefill_latency(self) -> Latency:
return Latency.from_values(latencies_list, unit=LATENCY_UNIT)

def get_decode_latency(self) -> Latency:
# if self.is_distributed:
# torch.distributed.barrier()
if self.is_distributed:
torch.distributed.barrier()

if self.is_asynchronous:
torch.cuda.synchronize()
@@ -352,8 +352,8 @@ def get_decode_latency(self) -> Latency:
return Latency.from_values(latencies_list, unit=LATENCY_UNIT)

def get_per_token_latency(self) -> Latency:
# if self.is_distributed:
# torch.distributed.barrier()
if self.is_distributed:
torch.distributed.barrier()

if self.is_asynchronous:
torch.cuda.synchronize()
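
The get_latency readers above pair an optional distributed barrier with torch.cuda.synchronize() before timings are read. A rough sketch of that recipe using CUDA events, assuming an initialized process group whenever `distributed` is true; `fn` is a hypothetical callable to time.

```python
import torch
import torch.distributed


def read_latency_ms(fn, distributed: bool = False) -> float:
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    fn()
    end.record()

    # Align ranks first, then wait for the recorded events to complete
    # before reading them.
    if distributed:
        torch.distributed.barrier()
    torch.cuda.synchronize()

    return start.elapsed_time(end)  # milliseconds
```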
optimum_benchmark/trackers/memory.py (2 changes: 2 additions & 0 deletions)
@@ -156,7 +156,9 @@ def _cuda_pytorch_memory(self):
LOGGER.warning(f"\t\t+ Could not reset max memory stats for device {device}: {e}")

torch.cuda.synchronize()

yield from self._cuda_memory()

torch.cuda.synchronize()

self.max_allocated_memory = sum(
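
The synchronize calls framing the yield above follow the usual peak-memory recipe: drain pending kernels before and after the measured region so the peak counters reflect finished work. A minimal sketch, with `run_workload` as a hypothetical placeholder:

```python
import torch


def peak_allocated_mb(run_workload, device: int = 0) -> float:
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.synchronize(device)

    run_workload()

    # Wait for all queued kernels before reading the peak counter.
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device) / 1e6
```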
