Skip to content

Commit

Permalink
made definition of maxradial consistent with num_n and updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
arthur-lin1027 authored and rosecers committed Nov 16, 2023
1 parent 3362f78 commit 5128fc8
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
7 changes: 4 additions & 3 deletions anisoap/representations/radial_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def __init__(
raise ValueError(f"{self.radial_basis} is not an implemented basis.")

# As part of the initialization, compute the number of radial basis
# functions, nmax, for each angular frequency l.
# functions, num_n, for each angular frequency l.
# If nmax is given, num_n = nmax + 1 (n ranges from 0 to nmax)
self.num_radial_functions = []
for l in range(max_angular + 1):
if max_radial is None:
Expand All @@ -125,9 +126,9 @@ def __init__(
)
if not isinstance(max_radial[l], int):
raise ValueError("`max_radial` must be None, int, or list of int")
self.num_radial_functions.append(max_radial[l])
self.num_radial_functions.append(max_radial[l] + 1)
elif isinstance(max_radial, int):
self.num_radial_functions.append(max_radial)
self.num_radial_functions.append(max_radial + 1)
else:
raise ValueError("`max_radial` must be None, int, or list of int")

Expand Down
12 changes: 6 additions & 6 deletions tests/test_radial_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ def test_radial_functions_n7(self):
num_ns = basis_gto.get_num_radial_functions()

# We specify max_radial so it's decoupled from max_angular.
num_ns_exact = [5, 5, 5, 5, 5, 5, 5]
assert len(num_ns) == len(num_ns_exact)
max_ns_exact = [5, 5, 5, 5, 5, 5, 5]
assert len(num_ns) == len(max_ns_exact)
for l, num in enumerate(num_ns):
assert num == num_ns_exact[l]
assert num == max_ns_exact[l] + 1

def test_radial_functions_n8(self):
basis_gto = RadialBasis(
Expand All @@ -58,10 +58,10 @@ def test_radial_functions_n8(self):
num_ns = basis_gto.get_num_radial_functions()

# We specify max_radial so it's decoupled from max_angular.
num_ns_exact = [1, 2, 3, 4, 5, 6, 7]
assert len(num_ns) == len(num_ns_exact)
max_ns_exact = [1, 2, 3, 4, 5, 6, 7]
assert len(num_ns) == len(max_ns_exact)
for l, num in enumerate(num_ns):
assert num == num_ns_exact[l]
assert num == max_ns_exact[l] + 1


class TestBadInputs:
Expand Down

0 comments on commit 5128fc8

Please sign in to comment.