diff --git a/anisoap/representations/radial_basis.py b/anisoap/representations/radial_basis.py index f9e3397..7c105c8 100644 --- a/anisoap/representations/radial_basis.py +++ b/anisoap/representations/radial_basis.py @@ -110,7 +110,8 @@ def __init__( raise ValueError(f"{self.radial_basis} is not an implemented basis.") # As part of the initialization, compute the number of radial basis - # functions, nmax, for each angular frequency l. + # functions, num_n, for each angular frequency l. + # If nmax is given, num_n = nmax + 1 (n ranges from 0 to nmax) self.num_radial_functions = [] for l in range(max_angular + 1): if max_radial is None: @@ -125,9 +126,9 @@ def __init__( ) if not isinstance(max_radial[l], int): raise ValueError("`max_radial` must be None, int, or list of int") - self.num_radial_functions.append(max_radial[l]) + self.num_radial_functions.append(max_radial[l] + 1) elif isinstance(max_radial, int): - self.num_radial_functions.append(max_radial) + self.num_radial_functions.append(max_radial + 1) else: raise ValueError("`max_radial` must be None, int, or list of int") diff --git a/tests/test_radial_basis.py b/tests/test_radial_basis.py index 723b7df..8035186 100644 --- a/tests/test_radial_basis.py +++ b/tests/test_radial_basis.py @@ -43,10 +43,10 @@ def test_radial_functions_n7(self): num_ns = basis_gto.get_num_radial_functions() # We specify max_radial so it's decoupled from max_angular. - num_ns_exact = [5, 5, 5, 5, 5, 5, 5] - assert len(num_ns) == len(num_ns_exact) + max_ns_exact = [5, 5, 5, 5, 5, 5, 5] + assert len(num_ns) == len(max_ns_exact) for l, num in enumerate(num_ns): - assert num == num_ns_exact[l] + assert num == max_ns_exact[l] + 1 def test_radial_functions_n8(self): basis_gto = RadialBasis( @@ -58,10 +58,10 @@ def test_radial_functions_n8(self): num_ns = basis_gto.get_num_radial_functions() # We specify max_radial so it's decoupled from max_angular. - num_ns_exact = [1, 2, 3, 4, 5, 6, 7] - assert len(num_ns) == len(num_ns_exact) + max_ns_exact = [1, 2, 3, 4, 5, 6, 7] + assert len(num_ns) == len(max_ns_exact) for l, num in enumerate(num_ns): - assert num == num_ns_exact[l] + assert num == max_ns_exact[l] + 1 class TestBadInputs: