Skip to content

Commit

Permalink
[TESTING] Added precision option to benchmark CSV data saving (#2933)
Browse files Browse the repository at this point in the history
This PR adds the `save_precision` argument to `benchmark.run()`, and
makes its default value 6 (which is somewhat arbitrary, but seems
reasonable). Right now it is not user configurable, and has an unusually
low default of `.1f`.

**Context:**

I was using the benchmarking capability, and found that the CSV which
was being generated had unacceptably low precision (`.1f`). This was
after I had run a long-ish benchmarking session, and I was surprised
that many of my benchmarks had an inference time of 0.

For the benchmarks I am performing, I need a higher degree of precision.
I do not think the downsides of higher precision, namely larger file
sizes for the CSVs is relevant compared to the downsides of losing data.
By making the value configurable, this gives us the best of both worlds.
  • Loading branch information
Wheest authored Jan 15, 2024
1 parent 150bfd8 commit ded6242
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions python/triton/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ def __init__(self, fn, benchmarks):
self.fn = fn
self.benchmarks = benchmarks

def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False, **kwrags):
def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False,
save_precision=6, **kwrags):
import os

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -325,7 +326,8 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b
print(bench.plot_name + ':')
print(df)
if save_path:
df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)
df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f",
index=False)
return df

def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs):
Expand Down

0 comments on commit ded6242

Please sign in to comment.