Commit 520515e
revert onnx, pr, md bag changes (facebookresearch#143)
Authored by YazhiGao on Dec 7, 2020
Parent: b9c61a6
Showing 2 changed files with 71 additions and 25 deletions.
dlrm_s_pytorch.py (67 additions, 24 deletions)
@@ -109,11 +109,13 @@

 exc = getattr(builtins, "IOError", "FileNotFoundError")

+
 def time_wrap(use_gpu):
     if use_gpu:
         torch.cuda.synchronize()
     return time.time()

+
 def dlrm_wrap(X, lS_o, lS_i, use_gpu, device, ndevices=1):
     with record_function("DLRM forward"):
         if use_gpu:  # .cuda()
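These blank-line additions are PEP 8 fixes (two blank lines before top-level definitions). One note on time_wrap itself: CUDA kernels launch asynchronously, so time.time() alone can return before queued GPU work finishes; the torch.cuda.synchronize() call is what makes the reading meaningful. A minimal sketch of the intended usage pattern (the timed helper and the example workload are illustrative, not part of this commit):

import time

import torch

def time_wrap(use_gpu):
    if use_gpu:
        torch.cuda.synchronize()  # block until all queued kernels have finished
    return time.time()

def timed(fn, use_gpu):
    t0 = time_wrap(use_gpu)  # drain pending work before starting the clock
    out = fn()               # on GPU this may only enqueue asynchronous kernels
    t1 = time_wrap(use_gpu)  # drain again so the interval includes those kernels
    return out, t1 - t0

# e.g. _, secs = timed(lambda: torch.randn(512, 512) @ torch.randn(512, 512), use_gpu=False)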
@@ -132,6 +134,7 @@ def dlrm_wrap(X, lS_o, lS_i, use_gpu, device, ndevices=1):
                 )
         return dlrm(X.to(device), lS_o, lS_i)

+
 def loss_fn_wrap(Z, T, use_gpu, device):
     with record_function("DLRM loss compute"):
         if args.loss_function == "mse" or args.loss_function == "bce":
@@ -142,6 +145,7 @@ def loss_fn_wrap(Z, T, use_gpu, device):
             loss_sc_ = loss_ws_ * loss_fn_
             return loss_sc_.mean()

+
 # The following function is a wrapper to avoid checking this multiple times in the
 # loop below.
 def unpack_batch(b):
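For context on the two lines above: in the "wbce" branch dlrm.loss_fn is an element-wise BCE (reduction="none") and dlrm.loss_ws holds one weight per class, so indexing it with the 0/1 targets yields a per-sample weight before the final mean. A standalone sketch of that computation (the weights and the tiny batch are made up for illustration):

import torch
import torch.nn as nn

loss_fn = nn.BCELoss(reduction="none")       # element-wise, like dlrm.loss_fn for "wbce"
loss_ws = torch.tensor([0.1, 0.9])           # [weight for class 0, weight for class 1]

Z = torch.tensor([[0.8], [0.3], [0.6]])      # predicted click probabilities
T = torch.tensor([[1.0], [0.0], [1.0]])      # targets

loss_ws_ = loss_ws[T.view(-1).long()].view_as(T)  # per-sample weight lookup
loss_fn_ = loss_fn(Z, T)                          # per-sample BCE
loss_sc_ = loss_ws_ * loss_fn_
print(loss_sc_.mean())                            # scalar training loss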
@@ -248,8 +252,8 @@ def create_emb(self, m, ln, weighted_pooling=None):
                     sparse=True,
                 )
             elif self.md_flag and n > self.md_threshold:
-                _m = m[i]
                 base = max(m)
+                _m = m[i] if n > self.md_threshold else base
                 EE = PrEmbeddingBag(n, _m, base)
                 # use np initialization as below for consistency...
                 W = np.random.uniform(
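The restored lines re-enable the mixed-dimension path: base is the largest dimension in m, and a table keeps its solver-assigned reduced dimension m[i] only when its row count n exceeds md_threshold; PrEmbeddingBag then embeds at the reduced width and projects back up so every table still emits base-wide vectors. A rough sketch of that idea (ProjectedBag is a simplified stand-in, not the actual PrEmbeddingBag in tricks/md_embedding_bag.py):

import torch
import torch.nn as nn

class ProjectedBag(nn.Module):
    """Embed at a reduced dimension m, then project up to the shared base dimension."""
    def __init__(self, n, m, base):
        super().__init__()
        self.embs = nn.EmbeddingBag(n, m, mode="sum", sparse=True)
        self.proj = nn.Linear(m, base, bias=False) if m < base else nn.Identity()

    def forward(self, indices, offsets):
        return self.proj(self.embs(indices, offsets))

bag = ProjectedBag(n=10_000, m=16, base=64)
out = bag(torch.tensor([1, 7, 42]), torch.tensor([0, 1]))  # two bags -> shape (2, 64)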
@@ -496,19 +500,13 @@ def interact_features(self, x, ly):
     def forward(self, dense_x, lS_o, lS_i):
         if ext_dist.my_size > 1:
             # multi-node multi-device run
-            return self.distributed_forward(
-                dense_x, lS_o, lS_i
-            )
+            return self.distributed_forward(dense_x, lS_o, lS_i)
         elif self.ndevices <= 1:
             # single device run
-            return self.sequential_forward(
-                dense_x, lS_o, lS_i
-            )
+            return self.sequential_forward(dense_x, lS_o, lS_i)
         else:
             # single-node multi-device run
-            return self.parallel_forward(
-                dense_x, lS_o, lS_i
-            )
+            return self.parallel_forward(dense_x, lS_o, lS_i)

     def distributed_forward(self, dense_x, lS_o, lS_i):
         batch_size = dense_x.size()[0]
@@ -535,9 +533,7 @@ def distributed_forward(self, dense_x, lS_o, lS_i):

         # embeddings
         with record_function("DLRM embedding forward"):
-            ly = self.apply_emb(
-                lS_o, lS_i, self.emb_l, self.v_W_l
-            )
+            ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l)

         # WARNING: Note that at this point we have the result of the embedding lookup
         # for the entire batch on each rank. We would like to obtain partial results
@@ -579,9 +575,7 @@ def sequential_forward(self, dense_x, lS_o, lS_i):
         # print(x.detach().cpu().numpy())

         # process sparse features(using embeddings), resulting in a list of row vectors
-        ly = self.apply_emb(
-            lS_o, lS_i, self.emb_l, self.v_W_l
-        )
+        ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l)
         # for y in ly:
         #     print(y.detach().cpu().numpy())

@@ -666,9 +660,7 @@ def parallel_forward(self, dense_x, lS_o, lS_i):
         # print(x)

         # embeddings
-        ly = self.apply_emb(
-            lS_o, lS_i, self.emb_l, self.v_W_l
-        )
+        ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l)
         # debug prints
         # print(ly)

@@ -778,7 +770,6 @@ def inference(
             print("Warning: Skiping the batch %d with size %d" % (i, X_test.size(0)))
             continue

-
         # forward pass
         Z_test = dlrm_wrap(
             X_test,
@@ -1085,9 +1076,7 @@ def run():
         mlperf_logger.barrier()

     if args.data_generation == "dataset":
-        train_data, train_ld, test_data, test_ld = dp.make_criteo_data_and_loaders(
-            args
-        )
+        train_data, train_ld, test_data, test_ld = dp.make_criteo_data_and_loaders(args)
         table_feature_map = {idx: idx for idx in range(len(train_data.counts))}
         nbatches = args.num_batches if args.num_batches > 0 else len(train_ld)
         nbatches_test = len(test_ld)
@@ -1800,13 +1789,67 @@ def run():
         # print("inputs", X_onnx, lS_o_onnx, lS_i_onnx)
         # print("output", dlrm_wrap(X_onnx, lS_o_onnx, lS_i_onnx, use_gpu, device))
         dlrm_pytorch_onnx_file = "dlrm_s_pytorch.onnx"
+        batch_size = X_onnx.shape[0]
+        print("X_onnx.shape", X_onnx.shape)
+        if torch.is_tensor(lS_o_onnx):
+            print("lS_o_onnx.shape", lS_o_onnx.shape)
+        else:
+            for oo in lS_o_onnx:
+                print("oo.shape", oo.shape)
+        if torch.is_tensor(lS_i_onnx):
+            print("lS_i_onnx.shape", lS_i_onnx.shape)
+        else:
+            for ii in lS_i_onnx:
+                print("ii.shape", ii.shape)
+
+        # name inputs and outputs
+        o_inputs = (
+            ["offsets"]
+            if torch.is_tensor(lS_o_onnx)
+            else ["offsets_" + str(i) for i in range(len(lS_o_onnx))]
+        )
+        i_inputs = (
+            ["indices"]
+            if torch.is_tensor(lS_i_onnx)
+            else ["indices_" + str(i) for i in range(len(lS_i_onnx))]
+        )
+        all_inputs = ["dense_x"] + o_inputs + i_inputs
+        # debug prints
+        print("inputs", all_inputs)
+
+        # create dynamic_axis dictionaries
+        do_inputs = (
+            [{"offsets": {1: "batch_size"}}]
+            if torch.is_tensor(lS_o_onnx)
+            else [
+                {"offsets_" + str(i): {0: "batch_size"}} for i in range(len(lS_o_onnx))
+            ]
+        )
+        di_inputs = (
+            [{"indices": {1: "batch_size"}}]
+            if torch.is_tensor(lS_i_onnx)
+            else [
+                {"indices_" + str(i): {0: "batch_size"}} for i in range(len(lS_i_onnx))
+            ]
+        )
+        dynamic_axes = {"dense_x": {0: "batch_size"}, "pred": {0: "batch_size"}}
+        for do in do_inputs:
+            dynamic_axes.update(do)
+        for di in di_inputs:
+            dynamic_axes.update(di)
+        # debug prints
+        print(dynamic_axes)
+        # export model
         torch.onnx.export(
             dlrm,
             (X_onnx, lS_o_onnx, lS_i_onnx),
             dlrm_pytorch_onnx_file,
             verbose=True,
             use_external_data_format=True,
-            opset_version=10,
+            opset_version=11,
+            input_names=all_inputs,
+            output_names=["pred"],
+            dynamic_axes=dynamic_axes,
         )
         # recover the model back
         dlrm_pytorch_onnx = onnx.load("dlrm_s_pytorch.onnx")
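Since every input and the "pred" output now carry a batch_size dynamic axis, the exported graph accepts batch sizes other than the traced one. A hypothetical smoke test with onnxruntime (not part of this commit; it assumes a toy model with one embedding table, hence a single offsets_0/indices_0 pair, and note that use_external_data_format=True writes weights to side files that must stay next to the .onnx):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("dlrm_s_pytorch.onnx", providers=["CPUExecutionProvider"])
print([inp.name for inp in sess.get_inputs()])  # dense_x, offsets_0, ..., indices_0, ...

feeds = {
    "dense_x": np.random.rand(4, 13).astype(np.float32),  # batch of 4, e.g. 13 dense features
    "offsets_0": np.array([0, 1, 2, 3], dtype=np.int64),  # one lookup per sample
    "indices_0": np.array([3, 1, 4, 1], dtype=np.int64),
}
(pred,) = sess.run(["pred"], feeds)                       # works despite a non-traced batch size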
tricks/md_embedding_bag.py (4 additions, 1 deletion)
@@ -34,7 +34,10 @@ def md_solver(n, alpha, d0=None, B=None, round_dim=True, k=None):
     d = alpha_power_rule(n.type(torch.float) / k, alpha, d0=d0, B=B)
     if round_dim:
         d = pow_2_round(d)
-    return d
+    undo_sort = [0] * len(indices)
+    for i, v in enumerate(indices):
+        undo_sort[v] = i
+    return d[undo_sort]


 def alpha_power_rule(n, alpha, d0=None, B=None):
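The restored return path matters because md_solver sorts the row counts near the top of the function (via torch.sort, whose second return value is indices), so without inverting that permutation the dimensions would come back in sorted order rather than the caller's table order. A small worked example (the n and d values are arbitrary):

import torch

n = torch.tensor([100, 10, 1000])    # rows per table, in the caller's order
n_sorted, indices = torch.sort(n)    # indices == [1, 0, 2]

d = torch.tensor([4, 8, 16])         # dims assigned against the sorted order

undo_sort = [0] * len(indices)
for i, v in enumerate(indices):
    undo_sort[v] = i                 # position of original table v in the sorted order
print(d[undo_sort])                  # tensor([ 8,  4, 16]) -- back in the caller's order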
