Skip to content

Commit

Permalink
ENH: Low memory adjoint
Browse files Browse the repository at this point in the history
This commit is an upgrade of the adjoint model allowing time step
checkpoints to lower the maximum memory usage. By default, in the
forward sweep of the adjoint, at each time, n arrays of shape (nrow, ncol)
where pushed. This arrays were poped at the end of the forward sweep which
result in some case (large domain and time steps) to a too high memory peak.
This commit adds checkpoints in the forward sweep to reduce the memory peak
and pass the variables that must be checkpointed at each time (i.e. fluxes and
states) to a vector of active cell instead of a grid of number of rows and columns.

This upgrade should be pretty efficient for large domain and time step calibration
  • Loading branch information
inoelloc committed Mar 5, 2024
1 parent 255a212 commit 8f1ebb0
Show file tree
Hide file tree
Showing 24 changed files with 4,348 additions and 3,528 deletions.
40 changes: 28 additions & 12 deletions doc/source/contributor_guide/development_process_details.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1107,36 +1107,52 @@ located here, ``tapenade/makefile``. The generation of the ``forward_db.f90`` fi
.. code-block:: fortran
subroutine store_timestep(mesh, output, returns, t, iret, q)
subroutine store_time_step(setup, mesh, output, returns, checkpoint_variable, time_step)
implicit none
type(SetupDT), intent(in) :: setup
type(MeshDT), intent(in) :: mesh
type(OutputDT), intent(inout) :: output
type(ReturnsDT), intent(inout) :: returns
integer, intent(in) :: t
integer, intent(inout) :: iret
real(sp), dimension(mesh%nrow, mesh%ncol), intent(in) :: q
type(Checkpoint_VariableDT), intent(in) :: checkpoint_variable
integer, intent(in) :: time_step
integer :: i
integer :: i, k, time_step_returns
do i = 1, mesh%ng
output%response%q(i, t) = q(mesh%gauge_pos(i, 1), mesh%gauge_pos(i, 2))
k = mesh%rowcol_to_ind_ac(mesh%gauge_pos(i, 1), mesh%gauge_pos(i, 2))
output%response%q(i, time_step) = checkpoint_variable%ac_qz(k, setup%nqz)
end do
!$AD start-exclude
if (allocated(returns%mask_time_step)) then
if (returns%mask_time_step(t)) then
iret = iret + 1
if (returns%rr_states_flag) returns%rr_states(iret) = output%rr_final_states
if (returns%q_domain_flag) returns%q_domain(:, :, iret) = q
if (returns%mask_time_step(time_step)) then
time_step_returns = returns%time_step_to_returns_time_step(time_step)
!% Return states
if (returns%rr_states_flag) then
do i = 1, setup%nrrs
call ac_vector_to_matrix(mesh, checkpoint_variable%ac_rr_states(:, i), &
& returns%rr_states(time_step_returns)%values(:, :, i))
end do
end if
!% Return discharge grid
if (returns%q_domain_flag) then
call ac_vector_to_matrix(mesh, checkpoint_variable%ac_qz(:, setup%nqz), &
& returns%q_domain(:, :, time_step_returns))
end if
end if
end if
!$AD end-exclude
end subroutine store_timestep
end subroutine store_time_step
Why has this section of code been removed from the differentiation? Firstly, Tapenade was returning a warning
(for some reason) and secondly, quite simply, this section allows you to store intermediate results which
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ are given in the table below.
- :math:`\max_{t\in\mathbf{E}} Q(t)`
- mm

where :math:`dt` is the timestep.
where :math:`dt` is the time step.

Now, denote :math:`S_s^*` and :math:`S_s` are observed and simulated signature type respectively. For each signature type :math:`s`,
the corresponding signature based efficiency metric is computed depending on if the signature is:
Expand Down
36 changes: 20 additions & 16 deletions f90wrap/finalize_f90wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def sed_internal_import(pyf: pathlib.PosixPath):
os.system(f'sed -i "s/from libfcore/from smash.fcore/g" {pyf}')


def get_flagged_attr(f90f: pathlib.PosixPath) -> dict[str, list[str]]:
def get_flagged_attr(f90f: pathlib.PosixPath) -> dict[str, set[str]]:
"""
Get the flagged derived type attributes in Fortran
Expand Down Expand Up @@ -120,33 +120,37 @@ def get_flagged_attr(f90f: pathlib.PosixPath) -> dict[str, list[str]]:
It adds one "_" at the beginning of each pseudo-private attribute
"""

index = []
index_array = []
char = []
char_array = []
private = []
index = set()
index_array = set()
char = set()
char_array = set()
private = set()

with open(f90f) as f:
for line in f:
# pass commented line
if line.strip().startswith("!"):
continue

if "!$F90W" in line:
ind_double_2dots = line.find("::") + 2

subline = line[ind_double_2dots:].strip().lower()

if "index-array" in subline:
index_array.append(subline.split(" ")[0])
index_array.add(subline.split(" ")[0])

elif "index" in subline:
index.append(subline.split(" ")[0])
index.add(subline.split(" ")[0])

if "char-array" in subline:
char_array.append(subline.split(" ")[0])
char_array.add(subline.split(" ")[0])

elif "char" in subline:
char.append(subline.split(" ")[0])
char.add(subline.split(" ")[0])

if "private" in subline:
private.append(subline.split(" ")[0])
private.add(subline.split(" ")[0])

res = {
"index": index,
Expand All @@ -159,7 +163,7 @@ def get_flagged_attr(f90f: pathlib.PosixPath) -> dict[str, list[str]]:
return res


def sed_index_decorator(pyf: pathlib.PosixPath, attribute: list[str]):
def sed_index_decorator(pyf: pathlib.PosixPath, attribute: set[str]):
"""
Modify Python script to handle index decorator for specific attributes.
Done by using the unix command sed in place
Expand Down Expand Up @@ -187,7 +191,7 @@ def sed_index_decorator(pyf: pathlib.PosixPath, attribute: list[str]):
os.system(f'sed -i "/\\b{attr}.setter/a \\\t\\@f90wrap_setter_index" {pyf}')


def sed_index_array_decorator(pyf: pathlib.PosixPath, attribute: list[str]):
def sed_index_array_decorator(pyf: pathlib.PosixPath, attribute: set[str]):
"""
Modify Python script to handle index array decorator for specific attributes.
Done by using the unix command sed in place
Expand Down Expand Up @@ -215,7 +219,7 @@ def sed_index_array_decorator(pyf: pathlib.PosixPath, attribute: list[str]):
os.system(f'sed -i "/\\b{attr}.setter/a \\\t\\@f90wrap_setter_index_array" {pyf}')


def sed_char_decorator(pyf: pathlib.PosixPath, attribute: list[str]):
def sed_char_decorator(pyf: pathlib.PosixPath, attribute: set[str]):
"""
Modify Python script to handle character decorator for specific attributes.
Done by using the unix command sed in place
Expand All @@ -241,7 +245,7 @@ def sed_char_decorator(pyf: pathlib.PosixPath, attribute: list[str]):
os.system(f'sed -i "/def {attr}(self)/i \\\t\\@f90wrap_getter_char" {pyf}')


def sed_char_array_decorator(pyf: pathlib.PosixPath, attribute: list[str]):
def sed_char_array_decorator(pyf: pathlib.PosixPath, attribute: set[str]):
"""
Modify Python script to handle character array decorator for specific attributes.
Done by using the unix command sed in place
Expand Down Expand Up @@ -269,7 +273,7 @@ def sed_char_array_decorator(pyf: pathlib.PosixPath, attribute: list[str]):
os.system(f'sed -i "/\\b{attr}.setter/a \\\t\\@f90wrap_setter_char_array" {pyf}')


def sed_private_property(pyf: pathlib.PosixPath, attribute: list[str]):
def sed_private_property(pyf: pathlib.PosixPath, attribute: set[str]):
"""
Modify Python script make pseudo-private property for specific attributes.
Done by using the unix command sed in place
Expand Down
Loading

0 comments on commit 8f1ebb0

Please sign in to comment.