Skip to content

Commit

Permalink
drop black for ruff-format
Browse files Browse the repository at this point in the history
fix typos and mypy
  • Loading branch information
janosh committed Nov 16, 2023
1 parent 870ca54 commit 80c9cb9
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 35 deletions.
21 changes: 6 additions & 15 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,32 @@ ci:

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.284
rev: v0.1.5
hooks:
- id: ruff
args: [--fix]
- id: ruff-format

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace

- repo: https://github.com/psf/black
rev: 23.7.0
hooks:
- id: black

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.5.0
rev: v1.7.0
hooks:
- id: mypy

- repo: https://github.com/codespell-project/codespell
rev: v2.2.5
rev: v2.2.6
hooks:
- id: codespell
stages: [commit, commit-msg]
exclude_types: [html]
additional_dependencies: [tomli] # needed to read pyproject.toml below py3.11
args: [--check-filenames]

- repo: https://github.com/MarcoGorelli/cython-lint
rev: v0.15.0
Expand All @@ -45,12 +42,6 @@ repos:
args: [--no-pycodestyle]
- id: double-quote-cython-strings

- repo: https://github.com/nbQA-dev/nbQA
rev: 1.7.0
hooks:
- id: nbqa-ruff
args: [--fix]

- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
hooks:
Expand Down
14 changes: 8 additions & 6 deletions matgl/ext/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ class Relaxer:

def __init__(
self,
potential: Potential | None = None,
potential: Potential,
state_attr: torch.Tensor | None = None,
optimizer: Optimizer | str = "FIRE",
relax_cell: bool = True,
Expand All @@ -209,7 +209,9 @@ def __init__(
"""
self.optimizer: Optimizer = OPTIMIZERS[optimizer.lower()].value if isinstance(optimizer, str) else optimizer
self.calculator = M3GNetCalculator(
potential=potential, state_attr=state_attr, stress_weight=stress_weight # type: ignore
potential=potential,
state_attr=state_attr,
stress_weight=stress_weight, # type: ignore
)
self.relax_cell = relax_cell
self.potential = potential
Expand Down Expand Up @@ -295,7 +297,7 @@ def __len__(self):
return len(self.energies)

def as_pandas(self) -> pd.DataFrame:
"""Returns: DataFrame of energies, forces, streeses, cells and atom_positions."""
"""Returns: DataFrame of energies, forces, stresses, cells and atom_positions."""
return pd.DataFrame(
{
"energies": self.energies,
Expand Down Expand Up @@ -368,18 +370,18 @@ def __init__(
taut (float): time constant for Berendsen temperature coupling
taup (float): time constant for pressure coupling
friction (float): friction coefficient for nvt_langevin, typically set to 1e-4 to 1e-2
andersen_prob (float): random collision probility for nvt_andersen, typically set to 1e-4 to 1e-1
andersen_prob (float): random collision probability for nvt_andersen, typically set to 1e-4 to 1e-1
ttime (float): Characteristic timescale of the thermostat, in ASE internal units
pfactor (float): A constant in the barostat differential equation.
external_stress (float): The external stress in eV/A^3.
Either 3x3 tensor,6-vector or a scalar representing pressure
Either 3x3 tensor,6-vector or a scalar representing pressure
compressibility_au (float): compressibility of the material in A^3/eV
trajectory (str or Trajectory): Attach trajectory object
logfile (str): open this file for recording MD outputs
loginterval (int): write to log file every interval steps
append_trajectory (bool): Whether to append to prev trajectory.
mask (np.array): either a tuple of 3 numbers (0 or 1) or a symmetric 3x3 array indicating,
which strain values may change for NPT simulations.
which strain values may change for NPT simulations.
"""
if isinstance(atoms, (Structure, Molecule)):
atoms = AseAtomsAdaptor().get_atoms(atoms)
Expand Down
22 changes: 9 additions & 13 deletions matgl/layers/_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import sympy
import torch
from torch import nn
from torch import Tensor, nn

import matgl
from matgl.layers._three_body import combine_sbf_shf
Expand Down Expand Up @@ -239,24 +239,24 @@ def __init__(self, max_l: int, use_phi: bool = True):
func = sympy.functions.special.spherical_harmonics.Znm(lval, m, theta, phi).expand(func=True)
funcs.append(func)
# replace all theta with cos(theta)
costheta = sympy.symbols("costheta")
funcs = [i.subs({theta: sympy.acos(costheta)}) for i in funcs]
cos_theta = sympy.symbols("costheta")
funcs = [i.subs({theta: sympy.acos(cos_theta)}) for i in funcs]
self.orig_funcs = [sympy.simplify(i).evalf() for i in funcs]
self.funcs = [sympy.lambdify([costheta, phi], i, [{"conjugate": _conjugate}, torch]) for i in self.orig_funcs]
self.funcs = [sympy.lambdify([cos_theta, phi], i, [{"conjugate": torch.conj}, torch]) for i in self.orig_funcs]
self.funcs[0] = _y00

def forward(self, costheta, phi=None):
def __call__(self, cos_theta, phi=None):
"""Args:
costheta: Cosine of the azimuthal angle
cos_theta: Cosine of the azimuthal angle
phi: torch.Tensor, the polar angle.
Returns: [n, m] spherical harmonic results, where n is the number
of angles. The column is arranged following
`[Y_0^0, Y_1^{-1}, Y_1^{0}, Y_1^1, Y_2^{-2}, ...]`
"""
# costheta = torch.tensor(costheta, dtype=torch.complex64)
# cos_theta = torch.tensor(cos_theta, dtype=torch.complex64)
# phi = torch.tensor(phi, dtype=torch.complex64)
return torch.stack([func(costheta, phi) for func in self.funcs], axis=1)
return torch.stack([func(cos_theta, phi) for func in self.funcs], axis=1)
# results = results.type(dtype=DataType.torch_float)
# return results

Expand All @@ -276,11 +276,7 @@ def _y00(theta, phi):
return 0.5 * torch.ones_like(theta) * sqrt(1.0 / pi)


def _conjugate(x):
return torch.conj(x)


def spherical_bessel_smooth(r, cutoff: float = 5.0, max_n: int = 10):
def spherical_bessel_smooth(r: Tensor, cutoff: float = 5.0, max_n: int = 10) -> Tensor:
"""This is an orthogonal basis with first
and second derivative at the cutoff
equals to zero. The function was derived from the order 0 spherical Bessel
Expand Down
4 changes: 3 additions & 1 deletion matgl/models/_m3gnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,9 @@ def __init__(
if task_type == "classification":
raise ValueError("Classification task cannot be extensive.")
self.final_layer = WeightedReadOut(
in_feats=dim_node_embedding, dims=[units, units], num_targets=ntargets # type: ignore
in_feats=dim_node_embedding,
dims=[units, units],
num_targets=ntargets, # type: ignore
)

self.max_n = max_n
Expand Down

0 comments on commit 80c9cb9

Please sign in to comment.