From 6640ebd22b45adb7c11317ad03caf69d197fa563 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 10 Oct 2023 16:21:49 -0400 Subject: [PATCH] merge conflict --- megablocks/layers/mlp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index 258fac30..1454f496 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -4,7 +4,7 @@ from megablocks.layers import weight_parallel as wp from megablocks.layers.arguments import Arguments, InitFn from megablocks import turbo_util as turbo -from megablocks import grouped_gemm_util as grouped_gemm +from megablocks import grouped_gemm_util as gg import stk import torch import torch.nn.functional as F @@ -522,6 +522,6 @@ def forward(self, x, tokens_per_expert): self.args.quantize_rematerialize_num_bits) # Compute the MLP. - x = grouped_gemm.gmm(x, w1, batch_sizes, trans_b=True) + x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) x = F.gelu(x, approximate="tanh") - return grouped_gemm.gmm(x, w2, batch_sizes) + return gg.ops.gmm(x, w2, batch_sizes)