Improve deletion of timer (#176)
marvinfriede authored Dec 18, 2024
1 parent 1d04fdf commit 7f85729
Showing 11 changed files with 288 additions and 13 deletions.
@@ -1,4 +1,4 @@
.. _dev_errors:
.. _help_errors:

Common Errors
=============
@@ -65,3 +65,26 @@ To avoid this error, manually import ``torch._dynamo`` in the code. For example:
    if __tversion__ in ((2, 3, 0), (2, 3, 1)):
        import torch._dynamo
TimerError: Timer '<interaction>' is running. Use .stop() to stop it.
----------------------------------------------------------------------

This error occurs when a calculation is launched again without resetting the
timer. The timer lives in the global namespace and is started automatically
when ``dxtb`` is imported.

If you do not need the timer at all, you can delete it:

.. code-block:: python

    from dxtb import kill_timer

    kill_timer()

If you only want to disable the timer temporarily, use:

.. code-block:: python

    from dxtb import timer

    timer.disable()
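A minimal sketch of the second option in a full script (the water geometry is
taken from the examples in this commit; ``timer.disable()`` is the only timer
call involved and the rest is illustrative):

.. code-block:: python

    import torch

    import dxtb
    from dxtb import timer

    timer.disable()  # silence the global timer for this session

    numbers = torch.tensor([8, 1, 1])
    positions = torch.tensor(
        [
            [-2.95915993, 1.40005084, 0.24966306],
            [-2.1362031, 1.4795743, -1.38758999],
            [-2.40235213, 2.84218589, 1.24419946],
        ],
        dtype=torch.double,
    )

    calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, dtype=torch.double)

    # repeated calculations in the same session no longer touch the timer
    for _ in range(3):
        energy = calc.get_energy(positions)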
File renamed without changes.
File renamed without changes.
File renamed without changes.
14 changes: 10 additions & 4 deletions docs/source/index.rst
@@ -31,16 +31,22 @@
   Installation <03_for_developers/installation>
   Testing <03_for_developers/testing>
   Style <03_for_developers/style>
   Common Errors <03_for_developers/errors>

.. toctree::
   :hidden:
   :caption: Help
   :maxdepth: 2

   Common Errors <04_help/errors>

.. toctree::
   :hidden:
   :caption: About
   :maxdepth: 2

   Literature <04_about/literature>
   Related Works <04_about/related>
   License <04_about/license>
   Literature <05_about/literature>
   Related Works <05_about/related>
   License <05_about/license>

.. toctree::
   :hidden:
98 changes: 98 additions & 0 deletions examples/issues/183/run.py
@@ -0,0 +1,98 @@
import torch

import dxtb
from dxtb.typing import DD

dd: DD = {"device": torch.device("cpu"), "dtype": torch.double}

numbers = torch.tensor([8, 1, 1], device=dd["device"])
positions = torch.tensor(
    [
        [-2.95915993, 1.40005084, 0.24966306],
        [-2.1362031, 1.4795743, -1.38758999],
        [-2.40235213, 2.84218589, 1.24419946],
    ],
    requires_grad=True,
    **dd,
)

opts = {
    "scf_mode": dxtb.labels.SCF_MODE_FULL,
    "cache_enabled": True,
}
calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, opts=opts, **dd)
assert calc.integrals.hcore is not None


def get_energy_force(calc: dxtb.Calculator):
    forces = calc.get_forces(positions, create_graph=True)
    energy = calc.get_energy(positions)
    return energy, forces


es2 = calc.interactions.get_interaction("ES2")
es2.gexp = es2.gexp.clone().detach().requires_grad_(True)

hcore = calc.integrals.hcore
hcore.selfenergy = hcore.selfenergy.clone().detach().requires_grad_(True)

# energy and AD force
# energy, force = get_energy_force(calc)

# AD gradient w.r.t. params
energy, force = get_energy_force(calc)
de_dparam = torch.autograd.grad(
    energy, (es2.gexp, hcore.selfenergy), retain_graph=True
)

calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, opts=opts, **dd)

es2 = calc.interactions.get_interaction("ES2")
es2.gexp = es2.gexp.clone().detach().requires_grad_(True)
hcore = calc.integrals.hcore
hcore.selfenergy = hcore.selfenergy.clone().detach().requires_grad_(True)

pos = positions.clone().detach().requires_grad_(True)
energy = calc.get_energy(pos)
force = -torch.autograd.grad(energy, pos, create_graph=True)[0]
dfnorm_dparam = torch.autograd.grad(
    torch.norm(force), (es2.gexp, hcore.selfenergy)
)

# Numerical gradient w.r.t. params
dparam = 2e-6
calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, **dd)
es2 = calc.interactions.get_interaction("ES2")

es2.gexp += dparam / 2
energy1, force1 = get_energy_force(calc)

calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, **dd)
es2 = calc.interactions.get_interaction("ES2")

es2.gexp -= dparam / 2
energy2, force2 = get_energy_force(calc)

de_dgexp = (energy1 - energy2) / dparam
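# `de_dgexp` above is a central finite difference with step h = dparam:
#   dE/dgexp ≈ [E(gexp + h/2) - E(gexp - h/2)] / h,
# serving as a numerical reference for the AD gradient `de_dparam[0]`.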

print(f"dE / dgexp (AD) = {de_dparam[0]: .8f}")
print(f"dE / dgexp (Num) = {de_dgexp: .8f}")

dF_dgexp = (torch.norm(force1) - torch.norm(force2)) / dparam
print(f"d|F| / dgexp (AD) = {dfnorm_dparam[0]: .8f}")
print(f"d|F| / dgexp (Num) = {dF_dgexp: .8f}")

calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, opts=opts, **dd)
calc.integrals.hcore.selfenergy[0] += dparam / 2
energy1, force1 = get_energy_force(calc)
calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, opts=opts, **dd)
calc.integrals.hcore.selfenergy[0] -= dparam / 2
energy2, force2 = get_energy_force(calc)

de_dp = (energy1 - energy2) / dparam
print(f"dE / dselfenergy[0] (AD) = {de_dparam[1][0]: .8f}")
print(f"dE / dselfenergy[0] (Num) = {de_dp: .8f}")

df_dp = (torch.norm(force1) - torch.norm(force2)) / dparam
print(f"d|F| / dselfenergy[0] (AD) = {dfnorm_dparam[1][0]: .8f}")
print(f"d|F| / dselfenergy[0] (Num) = {df_dp: .8f}")
53 changes: 53 additions & 0 deletions examples/issues/187/run.py
@@ -0,0 +1,53 @@
import torch

import dxtb
from dxtb.typing import DD

############################################
# Setup
############################################

dd: DD = {"device": torch.device("cpu"), "dtype": torch.double}

numbers = torch.tensor([8, 1, 1], device=dd["device"])
positions = torch.tensor(
    [
        [-2.95915993, 1.40005084, 0.24966306],
        [-2.1362031, 1.4795743, -1.38758999],
        [-2.40235213, 2.84218589, 1.24419946],
    ],
    requires_grad=True,
    **dd,
)

opts = {"verbosity": 0}
calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, opts=opts, **dd)
assert calc.integrals.hcore is not None


############################################
# Minimization
############################################

from dxtb._src.exlibs.xitorch.optimize import minimize


def get_energy(positions) -> torch.Tensor:
    return calc.get_energy(positions)


minpos = minimize(
    get_energy,
    positions,
    method="gd",
    maxiter=200,
    step=1e-2,
    verbose=True,
)
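# `minimize` above runs plain gradient descent ("gd"): up to `maxiter`
# update steps of size `step` along the negative energy gradient,
# returning the optimized coordinates in `minpos`.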


print("\nInitial geometry:")
print(positions.detach().numpy())

print("Optimized geometry:")
print(minpos.detach().numpy())
3 changes: 3 additions & 0 deletions examples/issues/README.md
@@ -0,0 +1,3 @@
# GitHub Issues

This directory contains working solutions to GitHub issues raised by users.
11 changes: 8 additions & 3 deletions src/dxtb/__init__.py
@@ -22,7 +22,7 @@
"""

# import timer first to get correct total time
from dxtb._src.timing import timer
from dxtb._src.timing import timer, kill_timer

timer.start("Import")
timer.start("PyTorch", parent_uid="Import")
@@ -54,20 +54,25 @@

###############################################################################

# stop timers and remove from global namespace
# stop timers and remove PyTorch from global namespace for cleaner API
del torch
timer.stop("dxtb")
timer.stop("Import")

###############################################################################

__all__ = [
    "calculators",
    "components",
    #
    "calculators",
    "Calculator",
    "GFN1_XTB",
    "GFN2_XTB",
    #
    "IndexHelper",
    #
    "kill_timer",
    "timer",
    #
    "__version__",
]
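A brief sketch of the top-level usage enabled by the updated ``__all__`` (both
calls are described in the Common Errors page above; illustrative only):

import dxtb

dxtb.timer.disable()  # temporarily silence all timings
# ... run calculations ...
dxtb.kill_timer()     # or delete the global timer entirely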
71 changes: 67 additions & 4 deletions src/dxtb/_src/timing/timer.py
@@ -29,7 +29,7 @@

import time

__all__ = ["timer"]
__all__ = ["timer", "create_timer", "kill_timer"]


class TimerError(Exception):
@@ -188,8 +188,8 @@ def __init__(
    ) -> None:
        self.label = label
        self.timers = {}
        self._enabled = True
        self._subtimer_parent_map = {}
        self._enabled: bool = True
        self._subtimer_parent_map: dict[str, str] = {}
        self._autostart = autostart
        self._cuda_sync = cuda_sync
        self._only_parents = only_parents
@@ -368,6 +368,9 @@ def kill(self) -> None:
        self.reset()
        self.stop_all()

        self.timers.clear()
        self._subtimer_parent_map.clear()

    def get_time(self, uid: str) -> float:
        """
        Get the elapsed time of a timer.
@@ -463,6 +466,66 @@ def print(self, v: int = 5, precision: int = 3) -> None:  # pragma: no cover
            precision=precision,
        )

    def __str__(self) -> str:  # pragma: no cover
        """Return a string representation of the :class:`._Timers` instance."""
        timers_repr = ", ".join(
            f"'{label}': {timer.elapsed_time:.3f}s"
            for label, timer in self.timers.items()
        )

        return (
            f"{self.__class__.__name__}("
            f"label={self.label!r}, "
            f"enabled={self._enabled}, "
            f"cuda_sync={self._cuda_sync}, "
            f"only_parents={self._only_parents}, "
            f"timers={{{timers_repr}}}"
            f")"
        )

    def __repr__(self) -> str:  # pragma: no cover
        """Return a string representation of the :class:`._Timers` instance."""
        return str(self)


def create_timer(autostart: bool = True, cuda_sync: bool = False) -> _Timers:
    """
    Create a new timer instance.

    Parameters
    ----------
    autostart : bool, optional
        Whether to start the total timer automatically. Defaults to ``True``.
    cuda_sync : bool, optional
        Whether to call :func:`torch.cuda.synchronize` after CUDA operations.
        Defaults to ``False``.

    Returns
    -------
    _Timers
        Instance of the timer class.

    Note
    ----
    Delete the timer instance with :func:`.kill_timer` when it is no longer
    needed or when reusing it raises errors.
    """
    global timer
    timer = _Timers(autostart=autostart, cuda_sync=cuda_sync)
    return timer


def kill_timer() -> None:
    """Delete the global timer instance."""
    global timer
    if "timer" not in globals():
        raise TimerError(
            "Cannot delete timer instance; timer was never initialized."
        )

    timer.kill()
    del timer


timer = _Timers(autostart=True, cuda_sync=False)
timer = create_timer(autostart=True, cuda_sync=False)
"""Global instance of the timer class."""
26 changes: 25 additions & 1 deletion test/test_utils/test_timer.py
@@ -24,7 +24,13 @@

import pytest

from dxtb._src.timing.timer import TimerError, _sync, _Timers
from dxtb._src.timing.timer import (
    TimerError,
    _sync,
    _Timers,
    create_timer,
    kill_timer,
)


def test_fail() -> None:
@@ -83,3 +89,21 @@ def test_sync_true(mocker_avail, mocker_sync) -> None:

    mocker_avail.assert_called_once()
    mocker_sync.assert_called_once()


def test_kill() -> None:
    create_timer()
    kill_timer()

    assert "timer" not in globals()


def test_kill_fail() -> None:
    # In case no timer exists, create one first and then kill it
    create_timer()
    kill_timer()
    assert "timer" not in globals()

    # now trigger the error
    with pytest.raises(TimerError):
        kill_timer()
