diff --git a/train/comms/pt/comms.py b/train/comms/pt/comms.py
index bcc48b52..e0a800ec 100755
--- a/train/comms/pt/comms.py
+++ b/train/comms/pt/comms.py
@@ -169,6 +169,13 @@ def readArgs(self, parser):
             "For add or sub, '--mm-dim m n' uses the dimension of input annd output tensors are (m x n)"
             "If only one value is provided, it uses the dimension of input and output tensors are (n x n), i.e., square tensors",
         )  # Matrix multiplication dim n, A[m,n] * B [n,p]
+        parser.add_argument(
+            "--comp-data-type",
+            type=str,
+            default="float32",
+            help="datatype for GEMM or other compute kernels except emb_lookup"
+            + str(self.supportedDtype),
+        )
         # For emb lookup
         parser.add_argument(
             "--emb-dim",
@@ -362,6 +369,13 @@ def checkBasicArgs(self, args):
         args.collective = self._checkPt2Pt(args)
         args.device = self._check_device_type(args)

+        if args.comp_data_type not in self.supportedDtype:
+            logger.error(
+                f"Specified comp datatype: {args.comp_data_type} is not one of the supported commstyle: {str(self.supportedDtype)}"
+            )
+            comms_utils.gracefulExit()
+        args.comp_data_type = self.dtypeMap[args.comp_data_type]
+
         if args.size_start_profiler:
             args.size_start_profiler = comms_utils.parsesize(args.size_start_profiler)

@@ -974,13 +988,13 @@ def initCollectiveArgs(self, commsParams):
                     commsParams.mm_dim[1],
                     commsParams.mm_dim[1],
                     commsParams.mm_dim[2],
-                    commsParams.dtype,
+                    commsParams.comp_data_type,
                     curDevice,
                 )
                 if self.report:
                     print(
-                        f"[Rank {global_rank:>3}] mode: {commsParams.mode}, num_coll: {commsParams.num_coll}, kernel: {commsParams.kernel}, num_compute {commsParams.num_compute}, mm_dim {commsParams.mm_dim}"
+                        f"[Rank {global_rank:>3}] mode: {commsParams.mode}, num_coll: {commsParams.num_coll}, collectives datatype: {commsParams.data_type}, kernel: {commsParams.kernel}, num_compute {commsParams.num_compute}, mm_dim {commsParams.mm_dim}, comp_datatype {self.collectiveArgs.MMout.dtype} "
                     )
             elif commsParams.kernel == "emb_lookup":
                 comms_utils.init_emb_lookup(
@@ -1003,7 +1017,7 @@ def initCollectiveArgs(self, commsParams):
                 ) = self.prepComp(
                     commsParams.mm_dim[0],
                     commsParams.mm_dim[1],
-                    commsParams.dtype,
+                    commsParams.comp_data_type,
                     curDevice,
                     commsParams.kernel,
                 )
diff --git a/train/comms/pt/comms_utils.py b/train/comms/pt/comms_utils.py
index fc01cdc0..f5b32ab7 100644
--- a/train/comms/pt/comms_utils.py
+++ b/train/comms/pt/comms_utils.py
@@ -854,6 +854,7 @@ def __init__(
         self.inSplit = args.i
         self.outSplit = args.o
         self.data_type = args.data_type
+        self.comp_data_type = args.comp_data_type
         self.stepFactor = args.f
         self.stepBytes = args.sb
         self.srcOrDst = args.root
@@ -1361,18 +1362,19 @@ def prepGemmNotSquare(
         mm0_dim1: int,
         mm1_dim0: int,
         mm1_dim1: int,
-        dtype: str,
+        dtype: torch.dtype,
         curDevice: str,
         gemmTensor: torch.tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         if gemmTensor is None:
-            in1 = np.random.rand(mm0_dim0, mm0_dim1)
-            in2 = np.random.rand(mm1_dim0, mm1_dim1)
-
-            MMin1 = torch.FloatTensor(in1).to(curDevice)
-            MMin2 = torch.FloatTensor(in2).to(curDevice)
+            MMin1 = self.backendFuncs.alloc_random(
+                (mm0_dim0, mm0_dim1), curDevice, dtype
+            )
+            MMin2 = self.backendFuncs.alloc_random(
+                (mm1_dim0, mm1_dim1), curDevice, dtype
+            )
             MMout = self.backendFuncs.alloc_empty(
-                (mm0_dim0, mm1_dim1), dtype, curDevice
+                (mm0_dim0, mm1_dim1), curRankDevice=curDevice, dtype=dtype
             )
         else:
             mm_size0 = mm0_dim0 * mm0_dim1
@@ -1393,7 +1395,7 @@ def prepComp(
         self,
         mm_dim0: int,
         mm_dim1: int,
-        dtype: str,
+        dtype: torch.dtype,
         curDevice: str,
         kernel: str,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -1404,16 +1406,20 @@ def prepComp(
                 [mm_dim0, mm_dim1], curDevice, dtype
             )
             compOut = self.backendFuncs.alloc_empty(
-                [mm_dim0, mm_dim1], dtype, curDevice
+                [mm_dim0, mm_dim1], curDevice, dtype
             )
         else:
             compOut = self.backendFuncs.alloc_empty(
-                [mm_dim0, mm_dim1], dtype, curDevice
+                [mm_dim0, mm_dim1], curDevice, dtype
             )
         return (compOut, compIn1, compIn2)

     def prepGemm(
-        self, mm_dim: int, dtype: str, curDevice: str, gemmTensor: torch.tensor = None
+        self,
+        mm_dim: int,
+        dtype: torch.dtype,
+        curDevice: str,
+        gemmTensor: torch.tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         return self.prepGemmNotSquare(
             mm_dim, mm_dim, mm_dim, mm_dim, dtype, curDevice, gemmTensor
diff --git a/train/comms/pt/pytorch_dist_backend.py b/train/comms/pt/pytorch_dist_backend.py
index 9f874315..d7e76b1d 100644
--- a/train/comms/pt/pytorch_dist_backend.py
+++ b/train/comms/pt/pytorch_dist_backend.py
@@ -820,7 +820,7 @@ def alloc_embedding_tables(self, n, m, curRankDevice, dtype):
         )
         return EE

-    def alloc_empty(self, sizeArr, dtype, curRankDevice):
+    def alloc_empty(self, sizeArr, curRankDevice, dtype):
         return torch.empty(sizeArr, device=curRankDevice, dtype=dtype)

     def clear_memory(self, collectiveArgs):
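Note (not part of the patch): the sketch below illustrates, under assumed names, how the new --comp-data-type value is meant to flow. The CLI string is validated against supportedDtype and mapped to a torch.dtype in checkBasicArgs(), then carried as commsParams.comp_data_type into prepGemmNotSquare()/prepComp(), which allocate the compute tensors through the reordered alloc_empty(sizeArr, curRankDevice, dtype). The names dtype_map and prep_gemm_sketch are hypothetical stand-ins, not functions from this repo.

    import torch

    # Hypothetical stand-in for self.dtypeMap (assumption; the real map lives on the benchmark class).
    dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}

    def prep_gemm_sketch(m, n, p, comp_data_type="float32", device="cpu"):
        """Rough analogue of prepGemmNotSquare() using plain torch calls."""
        if comp_data_type not in dtype_map:  # mirrors the supportedDtype check in checkBasicArgs()
            raise ValueError(f"unsupported compute dtype: {comp_data_type}")
        dtype = dtype_map[comp_data_type]
        mm_in1 = torch.rand(m, n, device=device).to(dtype)  # stands in for backendFuncs.alloc_random()
        mm_in2 = torch.rand(n, p, device=device).to(dtype)
        mm_out = torch.empty((m, p), device=device, dtype=dtype)  # same argument order as the new alloc_empty()
        return mm_out, mm_in1, mm_in2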