diff --git a/dlrm_data_pytorch.py b/dlrm_data_pytorch.py index b815205c..4f14cf9d 100644 --- a/dlrm_data_pytorch.py +++ b/dlrm_data_pytorch.py @@ -730,6 +730,28 @@ def make_random_data_and_loader(args, ln_emb, m_den, rand_seed=args.numpy_rand_seed ) # WARNING: generates a batch of lookups at once + test_data = RandomDataset( + m_den, + ln_emb, + args.data_size, + args.num_batches, + args.mini_batch_size, + args.num_indices_per_lookup, + args.num_indices_per_lookup_fixed, + 1, # num_targets + args.round_targets, + args.data_generation, + args.data_trace_file, + args.data_trace_enable_padding, + reset_seed_on_access=True, + rand_data_dist=args.rand_data_dist, + rand_data_min=args.rand_data_min, + rand_data_max=args.rand_data_max, + rand_data_mu=args.rand_data_mu, + rand_data_sigma=args.rand_data_sigma, + rand_seed=args.numpy_rand_seed + ) + collate_wrapper_random = collate_wrapper_random_offset if offset_to_length_converter: collate_wrapper_random = collate_wrapper_random_length @@ -743,7 +765,17 @@ def make_random_data_and_loader(args, ln_emb, m_den, pin_memory=False, drop_last=False, # True ) - return train_data, train_loader + + test_loader = torch.utils.data.DataLoader( + test_data, + batch_size=1, + shuffle=False, + num_workers=args.num_workers, + collate_fn=collate_wrapper_random, + pin_memory=False, + drop_last=False, # True + ) + return train_data, train_loader, test_data, test_loader def generate_random_data( diff --git a/dlrm_s_pytorch.py b/dlrm_s_pytorch.py index 9b2f4522..eb352664 100644 --- a/dlrm_s_pytorch.py +++ b/dlrm_s_pytorch.py @@ -1103,8 +1103,9 @@ def run(): # input and target at random ln_emb = np.fromstring(args.arch_embedding_size, dtype=int, sep="-") m_den = ln_bot[0] - train_data, train_ld = dp.make_random_data_and_loader(args, ln_emb, m_den) + train_data, train_ld, test_data, test_ld = dp.make_random_data_and_loader(args, ln_emb, m_den) nbatches = args.num_batches if args.num_batches > 0 else len(train_ld) + nbatches_test = len(test_ld) args.ln_emb = ln_emb.tolist() if args.mlperf_logging: @@ -1447,9 +1448,6 @@ def run(): if args.quantize_emb_with_bit != 32: dlrm.quantize_embedding(args.quantize_emb_with_bit) # print(dlrm) - assert ( - args.data_generation == "dataset" - ), "currently only dataset loader provides testset" print("time/loss/accuracy (if enabled):") @@ -1599,7 +1597,7 @@ def run(): ) should_test = ( (args.test_freq > 0) - and (args.data_generation == "dataset") + and (args.data_generation in ["dataset", "random"]) and (((j + 1) % args.test_freq == 0) or (j + 1 == nbatches)) )