add comm part to et_replay #169

Closed
wants to merge 4 commits into from
2 changes: 1 addition & 1 deletion et_replay/et_replay_utils.py
@@ -368,7 +368,7 @@ def build_torchscript_func(n):
if (
n.op_schema == ""
or n.name == "aten::record_stream"
or n.name.startswith("aten::_foreach")
#or n.name.startswith("aten::_foreach")
):
return None, None

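Note on the hunk above: `build_torchscript_func` skips ops it cannot compile into a TorchScript function, and this change comments out the blanket skip of `aten::_foreach` ops so they now fall through to normal handling. A minimal, self-contained sketch of that skip pattern; `should_skip` is a hypothetical stand-in, not a function from the repo:

```python
# Hypothetical sketch (not part of et_replay): the same skip pattern used in
# build_torchscript_func, with the aten::_foreach check commented out.
def should_skip(op_schema: str, name: str) -> bool:
    return (
        op_schema == ""
        or name == "aten::record_stream"
        # or name.startswith("aten::_foreach")
    )

# Ops with an empty schema or aten::record_stream are still skipped;
# foreach ops now fall through to TorchScript function building.
assert should_skip("", "aten::add")
assert should_skip("some schema", "aten::record_stream")
assert not should_skip("some schema", "aten::_foreach_add")
```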
312 changes: 133 additions & 179 deletions et_replay/tools/comm_replay.py
@@ -12,7 +12,7 @@
import logging
import os
import time
from typing import Dict, List, Set
from typing import Dict, List, Set, Tuple

import numpy as np
import torch
@@ -121,6 +121,8 @@ def __init__(self):
self.out_path = ""
self.outputRanks = None
self.colls_per_batch = -1
self.coll_in_batch_num = 0
self.replay_start_time = -1
self.use_timestamp = False
self.num_replays = 1
self.profiler_num_replays_start = 0
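The two attributes added in this hunk, `coll_in_batch_num` and `replay_start_time`, lift what used to be per-call locals of the replay loop into instance state so a per-op replay method can see them across calls. A rough sketch of the idea, with hypothetical names outside the diff:

```python
import time

# Hypothetical sketch: keeping the batch counter and the replay start time on
# the instance lets a per-op replay method read and update them across calls,
# instead of threading locals through one long loop.
class ReplayStateSketch:
    def __init__(self) -> None:
        self.coll_in_batch_num = 0   # collectives seen in the current batch
        self.replay_start_time = -1  # ns timestamp taken once per trace replay

    def start_trace(self) -> None:
        self.coll_in_batch_num = 0
        self.replay_start_time = time.monotonic_ns()
```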
@@ -610,12 +612,35 @@ def hashEtCommsOp(self, commsOp: commsArgs) -> int:

return hash(op)

def generate_io_tensors(
self,
curComm: commsArgs,
commsParams: commsParamsHolderBase,
regenerateTensors: bool
) -> Tuple[torch.Tensor, torch.Tensor]:
# Use the exactly specified inMsgSize/outMsgSize when called from trace replay.
# This avoids regenerating sizes, such as in _prep_all_gather_base.
commsParams.size_from_trace = True
commsParams.dtype = self.dtypeMap[curComm.dtype]
if not curComm.id or regenerateTensors:
return super().prepComm(curComm, commsParams)
else:
commsOpHash = self.hashEtCommsOp(curComm)
if commsOpHash in self.et_to_tensors:
# Allocate input/output tensors on the first replay; otherwise reuse the previously allocated ones.
super().prepComm(curComm, commsParams, False)
(ipTensor, opTensor) = self.et_to_tensors[commsOpHash]
else:
(ipTensor, opTensor) = super().prepComm(curComm, commsParams, True)
self.et_to_tensors[commsOpHash] = (ipTensor, opTensor)
return (ipTensor, opTensor)
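The new `generate_io_tensors` helper pulls tensor preparation out of `prepComms`: tensors are regenerated when the op has no id or `regenerateTensors` is set, and otherwise cached per op hash so later replays reuse the same buffers. A simplified, hypothetical sketch of that cache-by-hash pattern (names such as `make_tensors` are illustrative, not from the repo):

```python
from typing import Dict, Tuple

import torch

# Hypothetical sketch of the cache-by-hash pattern behind generate_io_tensors:
# allocate input/output tensors the first time an op hash is seen, then reuse
# the same buffers on later replays unless regeneration is requested.
_cache: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}

def make_tensors(
    op_hash: int, in_elems: int, out_elems: int, regenerate: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    if regenerate or op_hash not in _cache:
        _cache[op_hash] = (torch.zeros(in_elems), torch.zeros(out_elems))
    return _cache[op_hash]

# First call allocates; a second call with the same hash reuses the buffers.
a = make_tensors(hash(("all_reduce", 1024)), 1024, 1024)
b = make_tensors(hash(("all_reduce", 1024)), 1024, 1024)
assert a[0] is b[0]
```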

def prepComms(
self,
curComm: commsArgs,
commsParams: commsParamsHolderBase,
regenerateTensors: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Prepares the appropriate tensors for the current collective communication.

@@ -686,22 +711,7 @@ def prepComms(
f"shrink message sizes to curInNumElem {curComm.inMsgSize}, curOutNumElem {curComm.outMsgSize}"
)

# Use exactly specified inMsgSize/outMsgSize if call from trace replay
# This avoid regenerating sizes such as in _prep_all_gather_base
commsParams.size_from_trace = True
commsParams.dtype = self.dtypeMap[curComm.dtype]
if not curComm.id or regenerateTensors:
return super().prepComm(curComm, commsParams)
else:
commsOpHash = self.hashEtCommsOp(curComm)
if commsOpHash in self.et_to_tensors:
# Allocate input/output tensors if first time replay, otherwise the previous ones.
super().prepComm(curComm, commsParams, False)
(ipTensor, opTensor) = self.et_to_tensors[commsOpHash]
else:
(ipTensor, opTensor) = super().prepComm(curComm, commsParams, True)
self.et_to_tensors[commsOpHash] = (ipTensor, opTensor)
return (ipTensor, opTensor)
return self.generate_io_tensors(curComm, commsParams, regenerateTensors)

def commRebalance(self, curComm: commsArgs) -> None:
"""
@@ -989,189 +999,133 @@ def replayTrace(
Returns:
None
"""
self.coll_in_batch_num = 0
self.replay_start_time = time.monotonic_ns()
for cnt, curComm in enumerate(self.comms_trace[: self.max_msg_cnt]):
self.replaySingle(commsParams, curComm, cnt, warmup)

def replaySingle(self, commsParams: commsParamsHolderBase, curComm: commsArgs, cnt: int, warmup: bool = False):
if warmup:
logLable = "[Warm-up]"
else:
logLable = f"[Replay {self.replayIter}]"

coll_in_batch_num = 0
startTime = time.monotonic_ns()
for cnt, curComm in enumerate(self.comms_trace[: self.max_msg_cnt]):
curBlocks = curComm.markerStack if curComm.markerStack is not None else []
curBlockStack = (
" ".join(curBlocks) if len(curBlocks) > 0 else "Unamed/Unknown"
)

# Replay compute
if curComm.compute is not None:
# Prepare to run the compute function
computeFunc = self.prepComputeReplay(commsParams, curComm)

# Running the kernel
logger.info(
f"{logLable}[Rank {self.collectiveArgs.global_rank:3}] [{cnt+1} / {self.max_msg_cnt}] Replaying {curComm.compute}"
)

# Run the kernel and report the total time
(latency, global_latency) = self.runCompute(
func=computeFunc, curBlockStack=curBlockStack
)
recordName = curComm.compute

# Replay comm
else:
if warmup:
self.commRebalance(curComm)

# Get the name of the collective from the comm object
collName = paramToCommName(curComm.comms)
(groupRank, groupDesc) = self.getCommGroupInfo(curComm, commsParams)
# Skip comm if the local process doesn't belong to the PG or encounter an unexpected collective
if (
collName not in self.allowList
or groupRank == -1
or (
collName in ("send", "isend")
and curComm.src_rank != self.backendFuncs.get_global_rank()
)
or (
collName in ("recv", "irecv")
and curComm.dst_rank != self.backendFuncs.get_global_rank()
)
):
continue
curBlocks = curComm.markerStack if curComm.markerStack is not None else []
curBlockStack = (
" ".join(curBlocks) if len(curBlocks) > 0 else "Unamed/Unknown"
)

if groupRank >= 0:
commDesc = f"{str(curComm.comms)}: NumElemsIn={curComm.inMsgSize}, NumElemsOut={curComm.outMsgSize}, Dtype={curComm.dtype}"
if curComm.comms == "all_to_allv":
commDesc += (
f", InSplit={curComm.inSplit}, OutSplit={curComm.outSplit}"
)
if curComm.comms in supportedP2pOps:
commDesc += f", Src_Rank={curComm.src_rank}, Dst_Rank={curComm.dst_rank}"
logger.info(
f"{logLable}[Rank {self.collectiveArgs.global_rank:3}] [{cnt+1} / {self.max_msg_cnt}] Replaying {commDesc} with {groupDesc}"
)
# Replay compute
if curComm.compute is not None:
# Prepare to run the compute function
computeFunc = self.prepComputeReplay(commsParams, curComm)

# read fields and prepare the tensors
(
self.collectiveArgs.ipTensor,
self.collectiveArgs.opTensor,
) = self.prepComms(curComm, commsParams, not self.reuse_tensors)
# Running the kernel
logger.info(
f"{logLable}[Rank {self.collectiveArgs.global_rank:3}] [{cnt+1} / {self.max_msg_cnt}] Replaying {curComm.compute}"
)

if not warmup and self.colls_per_batch > 0 and coll_in_batch_num == 0:
batch_begin = time.monotonic()
# Run the kernel and report the total time
(latency, global_latency) = self.runCompute(
func=computeFunc, curBlockStack=curBlockStack
)
recordName = curComm.compute

# wait for collective timestamp if enabled.
if not warmup and self.use_timestamp:
self.waitForTimestamp(curComm, startTime)
# Replay comm
else:
if warmup:
self.commRebalance(curComm)

# send comm request to pytorch backend
(latency, global_latency) = self.runComms(
collName, curComm, curBlockStack
# Get the name of the collective from the comm object
collName = paramToCommName(curComm.comms)
(groupRank, groupDesc) = self.getCommGroupInfo(curComm, commsParams)
# Skip comm if the local process doesn't belong to the PG or encounter an unexpected collective
if (
collName not in self.allowList
or groupRank == -1
or (
collName in ("send", "isend")
and curComm.src_rank != self.backendFuncs.get_global_rank()
)

# perform data validation check on the final opTensor
if (
self.is_blocking
and commsParams.dcheck == 1
and collName not in ("wait", "barrier")
):
commsParams.collective = collName
commsParams.srcOrDst = (
curComm.root if curComm.root is not None else 0
)
self.dcheck(
commsParams, curComm.outMsgSize, self.collectiveArgs.opTensor
)

# calculating batch latency (batch defined by --colls-per-batch)
if not warmup and collName == "wait" and self.colls_per_batch > 0:
coll_in_batch_num += 1
if coll_in_batch_num == self.colls_per_batch:
batch_latency = (
time.monotonic() - batch_begin
) * 1e3 # make it millisecond
coll_in_batch_num = 0
self.batchLat.append(batch_latency)

recordName = collName

if not warmup:
# record performance metrics
self.recordCommReplay(
commsParams,
curComm,
recordName,
latency,
curBlockStack,
global_latency,
curBlocks,
or (
collName in ("recv", "irecv")
and curComm.dst_rank != self.backendFuncs.get_global_rank()
)
):
return

if self.backendFuncs.get_global_rank() == 0:
if groupRank >= 0:
commDesc = f"{str(curComm.comms)}: NumElemsIn={curComm.inMsgSize}, NumElemsOut={curComm.outMsgSize}, Dtype={curComm.dtype}"
if curComm.comms == "all_to_allv":
commDesc += (
f", InSplit={curComm.inSplit}, OutSplit={curComm.outSplit}"
)
if curComm.comms in supportedP2pOps:
commDesc += f", Src_Rank={curComm.src_rank}, Dst_Rank={curComm.dst_rank}"
logger.info(
f"{logLable}[{cnt+1} / {self.max_msg_cnt}] Replayed {recordName} in block [{curBlockStack}]... {global_latency:.2f} us"
f"{logLable}[Rank {self.collectiveArgs.global_rank:3}] [{cnt+1} / {self.max_msg_cnt}] Replaying {commDesc} with {groupDesc}"
)

def replaySingle(
self,
commsParams: commsParamsHolderBase,
id: int,
regenerateTensors: bool = True,
) -> torch.Tensor:
"""
Replay comms trace.
Args:
commsParams: Run-time parameters for replay.
id: comms op id.
Returns:
Output tensor.
"""
for _, curComm in enumerate(self.comms_trace[: self.max_msg_cnt]):
if curComm.id == id:
collName = paramToCommName(curComm.comms)
if collName not in self.allowList:
return torch.Tensor()

curBlocks = (
curComm.markerStack if curComm.markerStack is not None else []
)
curBlockStack = (
" ".join(curBlocks) if len(curBlocks) > 0 else "Unamed/Unknown"
)
# read fields and prepare the tensors
(
self.collectiveArgs.ipTensor,
self.collectiveArgs.opTensor,
) = self.prepComms(curComm, commsParams, not self.reuse_tensors)

if self.backendFuncs.get_global_rank() == 0:
logger.debug(
f"[Rank {self.collectiveArgs.global_rank:3}] Replaying \n{str(curComm.comms)}\n"
)
if not warmup and self.colls_per_batch > 0 and self.coll_in_batch_num == 0:
self.batch_begin = time.monotonic()

# read fields and prepare the tensors
(
self.collectiveArgs.ipTensor,
self.collectiveArgs.opTensor,
) = self.prepComms(curComm, commsParams, regenerateTensors)
# wait for collective timestamp if enabled.
if not warmup and self.use_timestamp:
self.waitForTimestamp(curComm, self.replay_start_time)

# send comm request to pytorch backend
(latency, global_latency) = self.runComms(
collName, curComm, curBlockStack
)

# send comm request to pytorch backend
(latency, global_latency) = self.runComms(
collName, curComm, curBlockStack
# perform data validation check on the final opTensor
if (
self.is_blocking
and commsParams.dcheck == 1
and collName not in ("wait", "barrier")
):
commsParams.collective = collName
commsParams.srcOrDst = (
curComm.root if curComm.root is not None else 0
)
self.dcheck(
commsParams, curComm.outMsgSize, self.collectiveArgs.opTensor
)

# perform data validation check on the final opTensor
if (
self.is_blocking
and commsParams.dcheck == 1
and collName not in ("wait", "barrier")
):
commsParams.collective = collName
commsParams.srcOrDst = (
curComm.root if curComm.root is not None else 0
)
self.dcheck(
commsParams, curComm.outMsgSize, self.collectiveArgs.opTensor
)
# calculating batch latency (batch defined by --colls-per-batch)
if not warmup and collName == "wait" and self.colls_per_batch > 0:
self.coll_in_batch_num += 1
if self.coll_in_batch_num == self.colls_per_batch:
batch_latency = (
time.monotonic() - self.batch_begin
) * 1e3 # make it millisecond
self.coll_in_batch_num = 0
self.batchLat.append(batch_latency)

recordName = collName

if not warmup:
# record performance metrics
self.recordCommReplay(
commsParams,
curComm,
recordName,
latency,
curBlockStack,
global_latency,
curBlocks,
)

return self.collectiveArgs.opTensor
if self.backendFuncs.get_global_rank() == 0:
logger.info(
f"{logLable}[{cnt+1} / {self.max_msg_cnt}] Replayed {recordName} in block [{curBlockStack}]... {global_latency:.2f} us"
)

def benchTime(self, commsParams: commsParamsHolderBase) -> None:
"""
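The largest hunk above splits the old monolithic `replayTrace` loop into a thin loop plus a per-op `replaySingle`, with the batch counter and replay start time held on the instance. A compact, hypothetical sketch of that structure (class and method names below are illustrative only, not the repo's API); note that since a batch can span several per-op calls, the sketch also keeps the batch start time on the instance so it survives between calls:

```python
import time
from typing import List

# Hypothetical sketch of the replayTrace/replaySingle split: the outer loop only
# iterates the trace; per-op work (timing and batch-latency accounting) lives in
# the per-op method, with batch state kept on the instance across calls.
class TraceReplaySketch:
    def __init__(self, colls_per_batch: int = -1):
        self.colls_per_batch = colls_per_batch
        self.coll_in_batch_num = 0
        self.replay_start_time = -1
        self.batch_lat: List[float] = []
        self.batch_begin = 0.0

    def replay_trace(self, trace, warmup: bool = False) -> None:
        self.coll_in_batch_num = 0
        self.replay_start_time = time.monotonic_ns()
        for cnt, op in enumerate(trace):
            self.replay_single(op, cnt, warmup)

    def replay_single(self, op, cnt: int, warmup: bool = False) -> None:
        # Mark the start of a new batch (batch defined by --colls-per-batch).
        if not warmup and self.colls_per_batch > 0 and self.coll_in_batch_num == 0:
            self.batch_begin = time.monotonic()
        # ... run the collective here ...
        if not warmup and op == "wait" and self.colls_per_batch > 0:
            self.coll_in_batch_num += 1
            if self.coll_in_batch_num == self.colls_per_batch:
                self.batch_lat.append((time.monotonic() - self.batch_begin) * 1e3)
                self.coll_in_batch_num = 0

# Usage: two "wait" ops per batch over a toy trace.
r = TraceReplaySketch(colls_per_batch=2)
r.replay_trace(["all_reduce", "wait", "all_gather", "wait"])
print(r.batch_lat)  # one batch latency, in milliseconds
```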