Skip to content

Commit

Permalink
Tweak/add kind to gelu benchmark name (tracel-ai#1533)
Browse files Browse the repository at this point in the history
* Add kind to gelu benchmark name

* [backend-comparison] Compute column size in benchmarks report
  • Loading branch information
syl20bnr authored Mar 28, 2024
1 parent 279be04 commit 32a8d80
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
5 changes: 2 additions & 3 deletions backend-comparison/benches/custom_gelu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@ impl<B: Backend, const D: usize> Benchmark for CustomGeluBenchmark<B, D> {

fn name(&self) -> String {
match self.autodiff {
true => "gelu_autodiff",
false => "gelu",
true => format!("gelu_autodiff_{:?}", self.kind),
false => format!("gelu_{:?}", self.kind),
}
.into()
}

fn options(&self) -> Option<String> {
Expand Down
25 changes: 19 additions & 6 deletions backend-comparison/src/persistence/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,20 +242,33 @@ pub(crate) struct BenchmarkCollection {

impl Display for BenchmarkCollection {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Compute the max length for each column
let mut max_name_len = 0;
let mut max_backend_len = 0;
for record in self.records.iter() {
let backend_name = [record.backend.clone(), record.device.clone()].join("-");
max_name_len = max_name_len.max(record.results.name.len());
max_backend_len = max_backend_len.max(backend_name.len());
}
// Header
writeln!(
f,
"| {0:<15}| {1:<35}| {2:<15}|\n|{3:-<16}|{4:-<36}|{5:-<16}|",
"Benchmark", "Backend", "Median", "", "", ""
"| {:<width_name$} | {:<width_backend$} | Median |\n|{:->width_name$}--|{:->width_backend$}--|----------------|",
"Benchmark", "Backend", "", "", width_name = max_name_len, width_backend = max_backend_len
)?;
// Table entries
for record in self.records.iter() {
let backend = [record.backend.clone(), record.device.clone()].join("-");
let backend_name = [record.backend.clone(), record.device.clone()].join("-");
writeln!(
f,
"| {0:<15}| {1:<35}| {2:<15.3?}|",
record.results.name, backend, record.results.computed.median
"| {:<width_name$} | {:<width_backend$} | {:<15.3?}|",
record.results.name,
backend_name,
record.results.computed.median,
width_name = max_name_len,
width_backend = max_backend_len
)?;
}

Ok(())
}
}
Expand Down

0 comments on commit 32a8d80

Please sign in to comment.