Skip to content

Commit

Permalink
Fixed the issue of v0's output differing from the original implementa…
Browse files Browse the repository at this point in the history
…tion.
  • Loading branch information
YCC-ProjBackups committed Aug 25, 2023
1 parent 7fecde9 commit 1c4cce3
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 29 deletions.
24 changes: 8 additions & 16 deletions anisoap/representations/ellipsoidal_density_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,9 +448,9 @@ def __init__(
# Currently, gradients are not supported
if compute_gradients:
raise NotImplementedError(
"Sorry! Gradients have not yet been implemented")
#

"Sorry! Gradients have not yet been implemented"
)
# Initialize the radial basis class
if radial_basis_name not in ["monomial", "gto"]:
raise NotImplementedError(
Expand All @@ -459,7 +459,8 @@ def __init__(
)
if radial_gaussian_width != None and radial_basis_name != "gto":
raise ValueError(
"Gaussian width can only be provided with GTO basis")
"Gaussian width can only be provided with GTO basis"
)
elif radial_gaussian_width is None and radial_basis_name == "gto":
raise ValueError("Gaussian width must be provided with GTO basis")
elif type(radial_gaussian_width) == int:
Expand Down Expand Up @@ -529,8 +530,6 @@ def transform(self, frames, show_progress=False, normalize=True, *, version: int

# Define variables determining size of feature vector coming from frames
self.num_atoms_per_frame = np.array([len(frame) for frame in frames])

num_particle_types = len(species)

# Initialize arrays in which to store all features
self.feature_gradients = 0
Expand Down Expand Up @@ -567,11 +566,7 @@ def transform(self, frames, show_progress=False, normalize=True, *, version: int
frames[i].arrays["c_diameter[3]"][j] / 2,
]

# TypeError: NeighborList.__init__() takes 3 positional arguments but 4 were given
# Deleted the last "True" to resolve this error
# NeighborList(self.cutoff_radius, True, True) -> NeighborList(self.cutoff_radius, True)
self.nl = NeighborList(
self.cutoff_radius, True).compute(frame_generator)
self.nl = NeighborList(self.cutoff_radius, True, True).compute(frame_generator)

pairwise_ellip_feat = pairwise_ellip_expansion(
self.max_angular,
Expand All @@ -585,12 +580,9 @@ def transform(self, frames, show_progress=False, normalize=True, *, version: int
version=version
)

features = contract_pairwise_feat(
pairwise_ellip_feat, species, show_progress)
features = contract_pairwise_feat(pairwise_ellip_feat, species, show_progress)

if normalize:
normalized_features = self.radial_basis.orthonormalize_basis(
features)
return normalized_features
return self.radial_basis.orthonormalize_basis(features)
else:
return features
20 changes: 7 additions & 13 deletions tests/ex_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
_comp_version = [1, _MOST_RECENT_VER]
_test_files = [
"ellipsoid_frames",
"ell-trimers"
# "ell-trimers"
# "both_rotating_in_z", # Results in key error in frames.arrays['c_q']
# "face_to_face",
# "random_rotations",
Expand Down Expand Up @@ -348,8 +348,7 @@ def write_result_summary(file: TextIOWrapper, timer: SimpleTimer, err_dict: dict
file.write(f"{err_dict.get(key):.4e}\n")

file.write("\nNote: Number in parenthesis after runtime refers to corresponding runtime change compared to original implementation (version 0).\n")
file.write(
" Negative runtime change means computation was faster compared to the original implementation.\n")
file.write(" Negative runtime change means computation was faster compared to the original implementation.\n")


def write_raw_data(file: TextIOWrapper, raw_data: dict[str, list[list[float]]]):
Expand All @@ -364,10 +363,8 @@ def write_raw_data(file: TextIOWrapper, raw_data: dict[str, list[list[float]]]):


if __name__ == "__main__":
actual_file_name = "comp_result_v" + \
",".join([str(ver) for ver in _comp_version])
write_name = str(pathlib.Path(__file__).parent.absolute()) + \
"/time_results/" + actual_file_name + ".csv"
actual_file_name = "comp_result_v" + ",".join([str(ver) for ver in _comp_version])
write_name = str(pathlib.Path(__file__).parent.absolute()) + "/time_results/" + actual_file_name + ".csv"
raw_results = dict()
extra_infos = dict()
errors = dict()
Expand All @@ -385,16 +382,14 @@ def write_raw_data(file: TextIOWrapper, raw_data: dict[str, list[list[float]]]):

for ver in _comp_version:
for test_file in _test_files:
file_path = str(pathlib.Path(__file__).parent.parent.absolute(
)) + "/benchmarks/two_particle_gb/" + test_file + ".xyz"
file_path = str(pathlib.Path(__file__).parent.parent.absolute()) + "/benchmarks/two_particle_gb/" + test_file + ".xyz"

for (param_index, (param, repeat_no)) in enumerate(_params):
iter_str = get_key(ver, param_index + 1, test_file)

for rep_index in tqdm(range(repeat_no), desc=f"{iter_str}"):
single_pass_timer.mark_start()
comp_result, ex_info = single_pass(
file_path, param, version=ver)
comp_result, ex_info = single_pass(file_path, param, version=ver)
single_pass_timer.mark(iter_str)

# Only stores the result and extra info from the last iteration, as all iterations
Expand All @@ -403,8 +398,7 @@ def write_raw_data(file: TextIOWrapper, raw_data: dict[str, list[list[float]]]):
extra_infos.update({iter_str: ex_info})

# Get SSE based on the original implementation (v0) of equivalent parameter set and the test file
errors.update({iter_str: total_error(
raw_results.get(get_comp_key(iter_str)), comp_result)})
errors.update({iter_str: total_error(raw_results.get(get_comp_key(iter_str)), comp_result)})

# Make sure garbage collection does not interfere with the next iteration (version change).
ClebschGordanReal.cache_list.clear_cache()
Expand Down
35 changes: 35 additions & 0 deletions tests/time_results/comp_result_v0,1,2.csv
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,38 @@ Comparison of versions [0 1 2]
---------------- Initialization Info ----------------
initial_import time (sec), 0.0000

----------------- Parameter Summary -----------------
Parameter Set,l_max,sigma,r_cut,sigma_1,sigma_2,sigma_3,Rotation Quaternion
p1,10,2.0,5.0,3.0,2.0,1.0,-0.7071 - 0.7071i + 0.0000j + 0.0000k
p2,10,3.0,4.0,3.0,2.0,1.0,-0.7071 - 0.7071i + 0.0000j + 0.0000k
p3,7,4.0,6.0,3.0,2.0,1.0,-0.7071 - 0.7071i + 0.0000j + 0.0000k
p4,5,3.0,5.0,5.0,3.0,1.0,-0.7071 - 0.7071i + 0.0000j + 0.0000k
p5,10,2.0,5.0,3.0,3.0,0.8,-0.7071 - 0.7071i + 0.0000j + 0.0000k
p6,8,10.0,10.0,10.0,7.0,5.0,-0.7071 - 0.7071i + 0.0000j + 0.0000k

------------------ Overall Summary ------------------
Name,Maximum runtime (sec),Average runtime (sec),Median runtime (sec),SSE (from v0)
v0_p1_ellipsoid_frames,3.1423 (0.00%),2.7798 (0.00%),2.9694 (0.00%),0.0000e+00
v0_p2_ellipsoid_frames,6.3015 (0.00%),3.1810 (0.00%),2.4965 (0.00%),0.0000e+00
v0_p3_ellipsoid_frames,1.2019 (0.00%),0.9835 (0.00%),0.9155 (0.00%),0.0000e+00
v0_p4_ellipsoid_frames,0.5421 (0.00%),0.4584 (0.00%),0.4557 (0.00%),0.0000e+00
v0_p5_ellipsoid_frames,4.3704 (0.00%),3.2551 (0.00%),3.0284 (0.00%),0.0000e+00
v0_p6_ellipsoid_frames,1.8138 (0.00%),1.4145 (0.00%),1.3588 (0.00%),0.0000e+00

v1_p1_ellipsoid_frames,2.2992 (-26.83%),0.5607 (-79.83%),0.3151 (-89.39%),5.6160e-31
v1_p2_ellipsoid_frames,0.3154 (-95.00%),0.2812 (-91.16%),0.2741 (-89.02%),3.7748e-32
v1_p3_ellipsoid_frames,0.4770 (-60.32%),0.2263 (-76.99%),0.1833 (-79.98%),0.0000e+00
v1_p4_ellipsoid_frames,0.2562 (-52.73%),0.1531 (-66.60%),0.1365 (-70.05%),7.7037e-34
v1_p5_ellipsoid_frames,0.3155 (-92.78%),0.2995 (-90.80%),0.2974 (-90.18%),6.9333e-33
v1_p6_ellipsoid_frames,0.5709 (-68.53%),0.2688 (-81.00%),0.2205 (-83.77%),0.0000e+00

v2_p1_ellipsoid_frames,1.4603 (-53.53%),0.4608 (-83.42%),0.3267 (-89.00%),5.6160e-31
v2_p2_ellipsoid_frames,0.2883 (-95.42%),0.2729 (-91.42%),0.2699 (-89.19%),3.7748e-32
v2_p3_ellipsoid_frames,0.5480 (-54.40%),0.2346 (-76.15%),0.1811 (-80.21%),0.0000e+00
v2_p4_ellipsoid_frames,0.2527 (-53.38%),0.1467 (-68.00%),0.1294 (-71.61%),7.7037e-34
v2_p5_ellipsoid_frames,0.3103 (-92.90%),0.2950 (-90.94%),0.2933 (-90.32%),6.9333e-33
v2_p6_ellipsoid_frames,1.0280 (-43.32%),0.3344 (-76.36%),0.2294 (-83.12%),0.0000e+00

Note: Number in parenthesis after runtime refers to corresponding runtime change compared to original implementation (version 0).
Negative runtime change means computation was faster compared to the original implementation.

0 comments on commit 1c4cce3

Please sign in to comment.