Skip to content

Commit

Permalink
Add a test set to the random dataset generator
Browse files Browse the repository at this point in the history
  • Loading branch information
aghaderi committed May 23, 2021
1 parent e9b659e commit 9acb4e1
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
34 changes: 33 additions & 1 deletion dlrm_data_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,28 @@ def make_random_data_and_loader(args, ln_emb, m_den,
rand_seed=args.numpy_rand_seed
) # WARNING: generates a batch of lookups at once

test_data = RandomDataset(
m_den,
ln_emb,
args.data_size,
args.num_batches,
args.mini_batch_size,
args.num_indices_per_lookup,
args.num_indices_per_lookup_fixed,
1, # num_targets
args.round_targets,
args.data_generation,
args.data_trace_file,
args.data_trace_enable_padding,
reset_seed_on_access=True,
rand_data_dist=args.rand_data_dist,
rand_data_min=args.rand_data_min,
rand_data_max=args.rand_data_max,
rand_data_mu=args.rand_data_mu,
rand_data_sigma=args.rand_data_sigma,
rand_seed=args.numpy_rand_seed
)

collate_wrapper_random = collate_wrapper_random_offset
if offset_to_length_converter:
collate_wrapper_random = collate_wrapper_random_length
Expand All @@ -743,7 +765,17 @@ def make_random_data_and_loader(args, ln_emb, m_den,
pin_memory=False,
drop_last=False, # True
)
return train_data, train_loader

test_loader = torch.utils.data.DataLoader(
test_data,
batch_size=1,
shuffle=False,
num_workers=args.num_workers,
collate_fn=collate_wrapper_random,
pin_memory=False,
drop_last=False, # True
)
return train_data, train_loader, test_data, test_loader


def generate_random_data(
Expand Down
8 changes: 3 additions & 5 deletions dlrm_s_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,8 +1103,9 @@ def run():
# input and target at random
ln_emb = np.fromstring(args.arch_embedding_size, dtype=int, sep="-")
m_den = ln_bot[0]
train_data, train_ld = dp.make_random_data_and_loader(args, ln_emb, m_den)
train_data, train_ld, test_data, test_ld = dp.make_random_data_and_loader(args, ln_emb, m_den)
nbatches = args.num_batches if args.num_batches > 0 else len(train_ld)
nbatches_test = len(test_ld)

args.ln_emb = ln_emb.tolist()
if args.mlperf_logging:
Expand Down Expand Up @@ -1447,9 +1448,6 @@ def run():
if args.quantize_emb_with_bit != 32:
dlrm.quantize_embedding(args.quantize_emb_with_bit)
# print(dlrm)
assert (
args.data_generation == "dataset"
), "currently only dataset loader provides testset"

print("time/loss/accuracy (if enabled):")

Expand Down Expand Up @@ -1599,7 +1597,7 @@ def run():
)
should_test = (
(args.test_freq > 0)
and (args.data_generation == "dataset")
and (args.data_generation in ["dataset", "random"])
and (((j + 1) % args.test_freq == 0) or (j + 1 == nbatches))
)

Expand Down

0 comments on commit 9acb4e1

Please sign in to comment.