Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Feb 1, 2025
1 parent 0d5da7c commit 7e4eedd
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions benchmark/classes/hash_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@
if key1.is_cpu:
HashMap = torch.classes.pyg.CPUHashMap
elif key1.is_cuda:
HashMap = torch.classes.pyg.CUDHashMap
HashMap = torch.classes.pyg.CUDAHashMap
else:
raise NotImplementedError(f"Unsupported device '{device}'")
raise NotImplementedError(f"Unsupported device '{args.device}'")

t_init = t_get = 0
for i in range(num_warmups + num_steps):
Expand All @@ -62,7 +62,7 @@
t_start = time.perf_counter()
hash_map = torch.full((args.num_keys, ), fill_value=-1,
dtype=torch.long, device=args.device)
hash_map[key2] = torch.arange(args.num_keys)
hash_map[key2] = torch.arange(args.num_keys, device=args.device)
torch.cuda.synchronize()
if i >= num_warmups:
t_init += time.perf_counter() - t_start
Expand Down Expand Up @@ -92,7 +92,7 @@
if i >= num_warmups:
t_get += time.perf_counter() - t_start

print(f' Pandas Init: {t_init / num_steps:.4f}s')
print(f' Pandas Get: {t_get / num_steps:.4f}s')
print(f' Pandas Init: {t_init / num_steps:.4f}s')
print(f' Pandas Get: {t_get / num_steps:.4f}s')

assert out1.equal(torch.tensor(out3))
assert out1.equal(torch.tensor(out3))

0 comments on commit 7e4eedd

Please sign in to comment.