Datatype precision for compute kernel
Summary:
Adds datatype precision for compute kernels via a new option, `--comp-data-type`. This enables comm and compute kernels to use different datatype precisions.

## Background:
1) Previously, the output of the MatMul was always FP32, whereas real workloads run matmuls in different precisions (bfloat16, fp32, etc.).
2) Previously, the input datatypes supplied to the compute functions were also converted to FP32.
This diff fixes both issues via the `--comp-data-type` option.

Test case: N5625782, which returned FP32.
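
For illustration, a minimal standalone sketch of the intended behavior (the `dtype_map`, shapes, and `torch.mm` call below are placeholders, not the benchmark's own `supportedDtype`/`dtypeMap` tables or its GEMM kernel): the compute dtype is resolved separately from the collectives dtype, so the GEMM output follows `--comp-data-type` instead of being hard-coded to FP32.

```python
import torch

# Hypothetical mapping from the --comp-data-type string to a torch dtype;
# the benchmark keeps its own supportedDtype/dtypeMap tables.
dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}

comm_dtype = dtype_map["float32"]    # e.g. --data-type float32 for collectives
comp_dtype = dtype_map["bfloat16"]   # e.g. --comp-data-type bfloat16 for GEMM

device = "cuda" if torch.cuda.is_available() else "cpu"
m, n, p = 256, 512, 128
A = torch.rand((m, n), device=device, dtype=comp_dtype)
B = torch.rand((n, p), device=device, dtype=comp_dtype)
C = torch.empty((m, p), device=device, dtype=comp_dtype)  # preallocated output, like MMout

torch.mm(A, B, out=C)             # compute kernel runs in comp_dtype
assert C.dtype == comp_dtype      # output no longer hard-coded to FP32
```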

Reviewed By: kingchc

Differential Revision: D59994979

fbshipit-source-id: 7a834b37125f5c2b2b7ffd86520790e66b4754dd
ashramac authored and facebook-github-bot committed Jul 21, 2024
1 parent 7c820e5 commit 7b19f58
Showing 3 changed files with 35 additions and 15 deletions.
20 changes: 17 additions & 3 deletions train/comms/pt/comms.py
@@ -169,6 +169,13 @@ def readArgs(self, parser):
"For add or sub, '--mm-dim m n' uses the dimension of input annd output tensors are (m x n)"
"If only one value is provided, it uses the dimension of input and output tensors are (n x n), i.e., square tensors",
) # Matrix multiplication dim n, A[m,n] * B [n,p]
+ parser.add_argument(
+     "--comp-data-type",
+     type=str,
+     default="float32",
+     help="datatype for GEMM or other compute kernels except emb_lookup"
+     + str(self.supportedDtype),
+ )
# For emb lookup
parser.add_argument(
"--emb-dim",
@@ -362,6 +369,13 @@ def checkBasicArgs(self, args):
args.collective = self._checkPt2Pt(args)
args.device = self._check_device_type(args)

+ if args.comp_data_type not in self.supportedDtype:
+     logger.error(
+         f"Specified comp datatype: {args.comp_data_type} is not one of the supported data types: {str(self.supportedDtype)}"
+     )
+     comms_utils.gracefulExit()
+ args.comp_data_type = self.dtypeMap[args.comp_data_type]

if args.size_start_profiler:
args.size_start_profiler = comms_utils.parsesize(args.size_start_profiler)

@@ -974,13 +988,13 @@ def initCollectiveArgs(self, commsParams):
commsParams.mm_dim[1],
commsParams.mm_dim[1],
commsParams.mm_dim[2],
- commsParams.dtype,
+ commsParams.comp_data_type,
curDevice,
)

if self.report:
print(
f"[Rank {global_rank:>3}] mode: {commsParams.mode}, num_coll: {commsParams.num_coll}, kernel: {commsParams.kernel}, num_compute {commsParams.num_compute}, mm_dim {commsParams.mm_dim}"
f"[Rank {global_rank:>3}] mode: {commsParams.mode}, num_coll: {commsParams.num_coll}, collectives datatype: {commsParams.data_type}, kernel: {commsParams.kernel}, num_compute {commsParams.num_compute}, mm_dim {commsParams.mm_dim}, comp_datatype {self.collectiveArgs.MMout.dtype} "
)
elif commsParams.kernel == "emb_lookup":
comms_utils.init_emb_lookup(
@@ -1003,7 +1017,7 @@
) = self.prepComp(
commsParams.mm_dim[0],
commsParams.mm_dim[1],
- commsParams.dtype,
+ commsParams.comp_data_type,
curDevice,
commsParams.kernel,
)
28 changes: 17 additions & 11 deletions train/comms/pt/comms_utils.py
@@ -854,6 +854,7 @@ def __init__(
self.inSplit = args.i
self.outSplit = args.o
self.data_type = args.data_type
+ self.comp_data_type = args.comp_data_type
self.stepFactor = args.f
self.stepBytes = args.sb
self.srcOrDst = args.root
@@ -1361,18 +1362,19 @@ def prepGemmNotSquare(
mm0_dim1: int,
mm1_dim0: int,
mm1_dim1: int,
- dtype: str,
+ dtype: torch.dtype,
curDevice: str,
gemmTensor: torch.tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if gemmTensor is None:
- in1 = np.random.rand(mm0_dim0, mm0_dim1)
- in2 = np.random.rand(mm1_dim0, mm1_dim1)
-
- MMin1 = torch.FloatTensor(in1).to(curDevice)
- MMin2 = torch.FloatTensor(in2).to(curDevice)
+ MMin1 = self.backendFuncs.alloc_random(
+     (mm0_dim0, mm0_dim1), curDevice, dtype
+ )
+ MMin2 = self.backendFuncs.alloc_random(
+     (mm1_dim0, mm1_dim1), curDevice, dtype
+ )
MMout = self.backendFuncs.alloc_empty(
- (mm0_dim0, mm1_dim1), dtype, curDevice
+ (mm0_dim0, mm1_dim1), curRankDevice=curDevice, dtype=dtype
)
else:
mm_size0 = mm0_dim0 * mm0_dim1
@@ -1393,7 +1395,7 @@ def prepComp(
self,
mm_dim0: int,
mm_dim1: int,
- dtype: str,
+ dtype: torch.dtype,
curDevice: str,
kernel: str,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -1404,16 +1406,20 @@
[mm_dim0, mm_dim1], curDevice, dtype
)
compOut = self.backendFuncs.alloc_empty(
- [mm_dim0, mm_dim1], dtype, curDevice
+ [mm_dim0, mm_dim1], curDevice, dtype
)
else:
compOut = self.backendFuncs.alloc_empty(
- [mm_dim0, mm_dim1], dtype, curDevice
+ [mm_dim0, mm_dim1], curDevice, dtype
)
return (compOut, compIn1, compIn2)

def prepGemm(
- self, mm_dim: int, dtype: str, curDevice: str, gemmTensor: torch.tensor = None
+ self,
+ mm_dim: int,
+ dtype: torch.dtype,
+ curDevice: str,
+ gemmTensor: torch.tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return self.prepGemmNotSquare(
mm_dim, mm_dim, mm_dim, mm_dim, dtype, curDevice, gemmTensor
2 changes: 1 addition & 1 deletion train/comms/pt/pytorch_dist_backend.py
@@ -820,7 +820,7 @@ def alloc_embedding_tables(self, n, m, curRankDevice, dtype):
)
return EE

- def alloc_empty(self, sizeArr, dtype, curRankDevice):
+ def alloc_empty(self, sizeArr, curRankDevice, dtype):
return torch.empty(sizeArr, device=curRankDevice, dtype=dtype)

def clear_memory(self, collectiveArgs):
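As a usage note, the reordered `alloc_empty(sizeArr, curRankDevice, dtype)` now matches the positional order already used by `alloc_random` at the call sites above. A minimal sketch of the call pattern, assuming only PyTorch and eliding the surrounding backend class:

```python
import torch

def alloc_empty(sizeArr, curRankDevice, dtype):
    # Mirrors the backend helper: an uninitialized tensor on the target device.
    return torch.empty(sizeArr, device=curRankDevice, dtype=dtype)

# Call sites now pass the device before the dtype, as prepGemmNotSquare does.
MMout = alloc_empty((1024, 1024), "cpu", torch.bfloat16)
print(MMout.dtype)  # torch.bfloat16
```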
