Skip to content

Commit

Permalink
Merge pull request #163 from firedrakeproject/timing_prepostproc
Browse files Browse the repository at this point in the history
Timing object for pre/postproc callbacks
  • Loading branch information
JHopeCollins authored Jan 29, 2024
2 parents 5b8441a + f5c20af commit 68d8b34
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 0 deletions.
11 changes: 11 additions & 0 deletions case_studies/shallow_water/linear_gravity_bumps.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from firedrake.petsc import PETSc

from utils.timing import SolverTimer
from utils import units
from utils.planets import earth
import utils.shallow_water as swe
Expand Down Expand Up @@ -144,14 +145,18 @@

paradiag = miniapp.paradiag

timer = SolverTimer()


def window_preproc(swe_app, pdg, wndw):
PETSc.Sys.Print('')
PETSc.Sys.Print(f'### === --- Calculating time-window {wndw} --- === ###')
PETSc.Sys.Print('')
timer.start_timing()


def window_postproc(swe_app, pdg, wndw):
timer.stop_timing()
if pdg.layout.is_local(miniapp.save_step):
nt = (pdg.total_windows - 1)*pdg.ntimesteps + (miniapp.save_step + 1)
time = nt*pdg.aaoform.dt
Expand Down Expand Up @@ -191,3 +196,9 @@ def window_postproc(swe_app, pdg, wndw):
PETSc.Sys.Print(f'Maximum CFL = {max(miniapp.cfl_series)}')
PETSc.Sys.Print(f'Minimum CFL = {min(miniapp.cfl_series)}')
PETSc.Sys.Print('')

if timer.ntimes() > 1:
timer.times[0] = timer.times[1]

PETSc.Sys.Print(timer.string(timesteps_per_solve=window_length, ndigits=5))
PETSc.Sys.Print('')
1 change: 1 addition & 0 deletions utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@
import utils.vertical_slice # noqa: F401
import utils.mg # noqa: F401
import utils.diagnostics # noqa: F401
import utils.timing # noqa: F401
import utils.serial # noqa: F401
import utils.misc # noqa: F401
59 changes: 59 additions & 0 deletions utils/timing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from mpi4py import MPI

__all__ = ['Timer', 'SolverTimer']


class Timer:
'''
Time multiple similar actions.
'''
def __init__(self):
self.times = []

def start_timing(self):
'''
Start timing an action. This should be the last statement before the action starts.
'''
self.times.append(MPI.Wtime())

def stop_timing(self):
'''
Stop timing an action. This should be the first statement after the action stops.
'''
etime = MPI.Wtime()
stime = self.times[-1]
self.times[-1] = etime - stime

def total_time(self):
'''
The total duration of all actions timed.
'''
return sum(self.times)

def ntimes(self):
'''
The total number of actions timed.
'''
return len(self.times)

def average_time(self):
'''
The average duration of an action.
'''
return self.total_time()/self.ntimes()


class SolverTimer(Timer):
'''
Time multiple solves and print out total/average etc times.
'''
def string(self, timesteps_per_solve=1, ndigits=None):
rnd = lambda x: x if ndigits is None else round(x, ndigits)
total_time = self.total_time()
average_time = self.average_time()
timestep_time = average_time/timesteps_per_solve
string = ''\
+ f'Total solution time: {rnd(total_time)}\n' \
+ f'Average solve solution time: {rnd(average_time)}\n' \
+ f'Average timestep solution time: {rnd(timestep_time)}'
return string

0 comments on commit 68d8b34

Please sign in to comment.