Commit: better barriers in trackers

IlyasMoutawwakil committed Apr 28, 2024
1 parent b6ad2d2 commit 89380de
Showing 4 changed files with 35 additions and 42 deletions.
10 changes: 0 additions & 10 deletions optimum_benchmark/backends/pytorch/backend.py
@@ -26,7 +26,6 @@

if is_deepspeed_available():
import deepspeed
import deepspeed.comm


if is_zentorch_available():
@@ -349,23 +348,14 @@ def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:

@torch.inference_mode()
def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
if self.config.deepspeed_inference:
deepspeed.comm.barrier()

return self.pretrained_model.forward(**inputs, **kwargs)

@torch.inference_mode()
def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
if self.config.deepspeed_inference:
deepspeed.comm.barrier()

return self.pretrained_model.generate(**inputs, **kwargs)

@torch.inference_mode()
def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
if self.config.deepspeed_inference:
deepspeed.comm.barrier()

return self.pretrained_model.generate(**inputs, **kwargs)

@torch.inference_mode()
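The deletions above drop the per-call deepspeed.comm.barrier() from forward, prefill and generate; rank alignment is instead handled once around the measured region by the trackers. As a point of reference, a minimal sketch of that pattern with plain torch.distributed, assuming an already initialized process group; the helper name measured_region is illustrative and not part of the repository:

from contextlib import contextmanager

import torch.distributed


@contextmanager
def measured_region():
    # Hypothetical helper: align all ranks once around the measured block
    # instead of inside every model call.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()

    yield

    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()


# usage sketch (all ranks run the same script):
#     with measured_region():
#         outputs = model(**inputs)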
12 changes: 6 additions & 6 deletions optimum_benchmark/trackers/energy.py
@@ -164,22 +164,22 @@ def __init__(self, backend: str, device: str, device_ids: Optional[str] = None):

@contextmanager
def track(self, file_prefix: str = "task"):
if self.is_asynchronous:
torch.cuda.synchronize()

if self.is_distributed:
torch.distributed.barrier()

if self.is_asynchronous:
torch.cuda.synchronize()

self.emission_tracker.start_task()

yield

if self.is_asynchronous:
torch.cuda.synchronize()

if self.is_distributed:
torch.distributed.barrier()

if self.is_asynchronous:
torch.cuda.synchronize()

emission_data: EmissionsData = self.emission_tracker.stop_task()

with open(f"{file_prefix}_codecarbon.json", "w") as f:
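To make the flow above easier to follow in isolation, here is a condensed sketch of task-based energy measurement with codecarbon's start_task/stop_task, with the distributed barrier and CUDA synchronization applied before starting and before stopping the task. The tracker construction is omitted, the function name measure_energy is illustrative, and treating EmissionsData as a plain dataclass for asdict is an assumption:

import json
from dataclasses import asdict

import torch
import torch.distributed
from codecarbon import EmissionsTracker


def measure_energy(tracker: EmissionsTracker, work, file_prefix: str = "task"):
    # Align ranks and flush pending CUDA work so every process starts
    # measuring from a comparable point.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    tracker.start_task()

    work()  # the code under measurement

    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    emissions_data = tracker.stop_task()

    # assumes EmissionsData can be serialized with dataclasses.asdict
    with open(f"{file_prefix}_codecarbon.json", "w") as f:
        json.dump(asdict(emissions_data), f, indent=4)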
37 changes: 22 additions & 15 deletions optimum_benchmark/trackers/latency.py
@@ -142,11 +142,17 @@ def reset(self):

@contextmanager
def track(self):
if self.is_distributed:
torch.distributed.barrier()

if self.is_asynchronous:
yield from self._pytorch_cuda_latency()
else:
yield from self._cpu_latency()

if self.is_distributed:
torch.distributed.barrier()

def _pytorch_cuda_latency(self):
self.start_events.append(torch.cuda.Event(enable_timing=True))
self.start_events[-1].record()
@@ -164,9 +170,6 @@ def _cpu_latency(self):
self.end_events.append(time.perf_counter())

def get_latency(self) -> Latency:
if self.is_distributed:
torch.distributed.barrier()

if self.is_asynchronous:
torch.cuda.synchronize()

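For readers unfamiliar with the event-based path used when the backend is PyTorch on CUDA, here is a minimal standalone sketch of torch.cuda.Event timing, the mechanism behind the start_events/end_events lists; the workload is illustrative:

import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

x = torch.randn(4096, 4096, device="cuda")

start.record()   # enqueued on the current CUDA stream
y = x @ x        # asynchronous kernel launch
end.record()

torch.cuda.synchronize()              # wait until both events have completed
elapsed_ms = start.elapsed_time(end)  # elapsed time in milliseconds
print(f"matmul took {elapsed_ms:.3f} ms ({elapsed_ms / 1e3:.6f} s)")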
@@ -205,6 +208,11 @@ def __init__(self, device: str, backend: str) -> None:
self.is_asynchronous = self.backend == "pytorch" and self.device == "cuda"
self.is_distributed = is_torch_distributed_available() and torch.distributed.is_initialized()

if self.is_asynchronous:
LOGGER.info("\t+ Tracking latency using Pytorch CUDA events")
else:
LOGGER.info("\t+ Tracking latency using CPU performance counter")

self.start_events: List[Union[float, torch.cuda.Event]] = []
self.end_events: List[Union[float, torch.cuda.Event]] = []

@@ -227,9 +235,6 @@ def on_step_end(self, *args, **kwargs):
self.end_events.append(time.perf_counter())

def get_latency(self) -> Latency:
if self.is_distributed:
torch.distributed.barrier()

if self.is_asynchronous:
torch.cuda.synchronize()

@@ -251,6 +256,11 @@ def __init__(self, device: str, backend: str):
self.is_asynchronous = self.backend == "pytorch" and self.device == "cuda"
self.is_distributed = is_torch_distributed_available() and torch.distributed.is_initialized()

if self.is_asynchronous:
LOGGER.info("\t+ Tracking latency using Pytorch CUDA events")
else:
LOGGER.info("\t+ Tracking latency using CPU performance counter")

self.start_time: Optional[float] = None
self.next_is_prefill_end_decode_start: Optional[bool] = None

@@ -272,6 +282,9 @@ def reset(self):

@contextmanager
def track(self):
if self.is_distributed:
torch.distributed.barrier()

if self.is_asynchronous:
self.prefill_start_events.append(torch.cuda.Event(enable_timing=True))
self.prefill_start_events[-1].record()
@@ -290,6 +303,9 @@ def track(self):
else:
self.decode_end_events.append(time.perf_counter())

if self.is_distributed:
torch.distributed.barrier()

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
assert (
self.next_is_prefill_end_decode_start is not None
@@ -311,9 +327,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
return scores

def get_prefill_latency(self) -> Latency:
if self.is_distributed:
torch.distributed.barrier()

if self.is_asynchronous:
torch.cuda.synchronize()

@@ -332,9 +345,6 @@ def get_prefill_latency(self) -> Latency:
return Latency.from_values(latencies_list, unit=LATENCY_UNIT)

def get_decode_latency(self) -> Latency:
if self.is_distributed:
torch.distributed.barrier()

if self.is_asynchronous:
torch.cuda.synchronize()

@@ -352,9 +362,6 @@ def get_decode_latency(self) -> Latency:
return Latency.from_values(latencies_list, unit=LATENCY_UNIT)

def get_per_token_latency(self) -> Latency:
if self.is_distributed:
torch.distributed.barrier()

if self.is_asynchronous:
torch.cuda.synchronize()

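The __call__(input_ids, scores) method above follows the transformers logits-processor protocol, so the tracker is invoked once per generated token and can timestamp every decoding step. A simplified sketch of that mechanism, using CPU timestamps only and an illustrative checkpoint ("gpt2"); the real tracker additionally distinguishes prefill from decode and uses CUDA events on GPU:

import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList


class TokenTimestamper:
    """Records a timestamp every time generate() produces a new token."""

    def __init__(self):
        self.timestamps = []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        self.timestamps.append(time.perf_counter())
        return scores  # scores are passed through unchanged


tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

timestamper = TokenTimestamper()
inputs = tokenizer("Hello", return_tensors="pt")
model.generate(**inputs, max_new_tokens=8, logits_processor=LogitsProcessorList([timestamper]))

# successive differences approximate per-token decode latency in seconds
per_token_s = [t2 - t1 for t1, t2 in zip(timestamper.timestamps, timestamper.timestamps[1:])]
print(per_token_s)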
18 changes: 7 additions & 11 deletions optimum_benchmark/trackers/memory.py
@@ -107,12 +107,11 @@ def __init__(

if self.monitored_pid is None:
self.monitored_pid = int(os.environ.get("ISOLATED_PROCESS_PID", os.getpid()))

LOGGER.info("\t+ Tracking RAM memory")
LOGGER.info(f"\t+ Tracking RAM memory of process with PID [{self.monitored_pid}]")

if self.device == "cuda":
self.device_ids = list(map(int, self.device_ids.split(",")))
LOGGER.info(f"\t+ Tracking VRAM memory of CUDA devices {self.device_ids}")
LOGGER.info(f"\t+ Tracking VRAM memory of CUDA devices with IDs [{self.device_ids}]")

if self.uses_cuda_pytorch_allocator:
self.num_pytorch_devices = torch.cuda.device_count()
@@ -138,39 +137,36 @@ def reset(self):

@contextmanager
def track(self):
if self.is_distributed:
torch.distributed.barrier()

if self.uses_cuda_pytorch_allocator:
yield from self._cuda_pytorch_memory()
elif self.device == "cuda":
yield from self._cuda_memory()
else:
yield from self._cpu_memory()

def _cuda_pytorch_memory(self):
if self.is_distributed:
torch.distributed.barrier()

def _cuda_pytorch_memory(self):
for device in range(self.num_pytorch_devices):
try:
torch.cuda.synchronize(device=device)
torch.cuda.reset_peak_memory_stats(device=device)
except Exception as e:
LOGGER.warning(f"\t\t+ Could not reset max memory stats for device {device}: {e}")

torch.cuda.synchronize()

yield from self._cuda_memory()

torch.cuda.synchronize()

self.max_allocated_memory = sum(
torch.cuda.max_memory_allocated(device=device) / 1e6 for device in range(self.num_pytorch_devices)
)
self.max_reserved_memory = sum(
torch.cuda.max_memory_reserved(device=device) / 1e6 for device in range(self.num_pytorch_devices)
)

if self.is_distributed:
torch.distributed.barrier()

def _cuda_memory(self):
child_connection, parent_connection = Pipe()

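As a standalone illustration of the allocator statistics that _cuda_pytorch_memory aggregates, a small sketch for a single CUDA device; the workload is illustrative and the /1e6 conversion to megabytes follows the code above:

import torch

device = 0  # the tracker loops over all PyTorch-visible devices

torch.cuda.synchronize(device=device)
torch.cuda.reset_peak_memory_stats(device=device)  # restart the peak counters

x = torch.randn(8192, 8192, device=f"cuda:{device}")  # allocates ~268 MB of float32
y = x @ x                                             # peak usage grows during the matmul
torch.cuda.synchronize(device=device)

max_allocated_mb = torch.cuda.max_memory_allocated(device=device) / 1e6
max_reserved_mb = torch.cuda.max_memory_reserved(device=device) / 1e6
print(f"peak allocated: {max_allocated_mb:.1f} MB, peak reserved: {max_reserved_mb:.1f} MB")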
