Skip to content

Commit

Permalink
add force l2 norm scaled error and max error to utils
Browse files Browse the repository at this point in the history
  • Loading branch information
jmusiel committed Oct 31, 2023
1 parent a269025 commit d68c29c
Showing 1 changed file with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions finetuna/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from finetuna.calcs import DeltaCalc
import numpy as np
from numpy.linalg import norm
from ase.db import connect


def convert_to_singlepoint(images):
Expand Down Expand Up @@ -345,6 +346,14 @@ def add_hookean_constraint(
image.set_constraint(cons)


def get_online_asedb_as_parent_images(db_path):
images = []
with connect(db_path) as queried_db:
for row in queried_db.select(check=True):
images.append(asedb_row_to_atoms(row))
return images


def force_l2_norm_err(forces0, forces1):
diff_array = forces0 - forces1
l2_norm_vector = norm(diff_array, axis=1)
Expand All @@ -367,3 +376,52 @@ def force_cos_sim(forces0, forces1):
cos_sim = dot_array / (magnitudes0 * magnitudes1)
mean_cos_sim = np.mean(cos_sim)
return mean_cos_sim


def force_l2_norm_scaled_error(forces0, forces1):
# forces0 is ground truth
diff_array = forces0 - forces1
l2_norm_vector = norm(diff_array, axis=1)
forces0_norm_vector = norm(forces0, axis=1)
scaled_l2_norm_vector = l2_norm_vector / forces0_norm_vector
scaled_mae = np.mean(scaled_l2_norm_vector)
return (
np.max(scaled_l2_norm_vector),
scaled_mae,
np.mean(l2_norm_vector),
np.mean(forces0_norm_vector),
np.mean(norm(forces1, axis=1)),
)


def force_l2_norm_max_err(forces0, forces1):
diff_array = forces0 - forces1
l2_norm_vector = norm(diff_array, axis=1)
mae = np.max(l2_norm_vector)
return mae


if __name__ == "__main__":
from ase.io import Trajectory

i = 0
images = []
ml_images = []
re_images = []
db_path = "/home/jovyan/shared-scratch/joe/jobs/finetuna_2_tests/all_sudheesh_data/5_mae_threshold_005_reduce/Sudheesh_mof_zeolite/Cu_zeolites/MOR_T1_T4_96_141_1/finetuna/oal_queried_images.db"
with connect(db_path) as queried_db:
for row in queried_db.select(check=True):
i += 1
images.append(asedb_row_to_atoms(row, calc="parent"))
ml_images.append(asedb_row_to_atoms(row, calc="ml"))
re_images.append(asedb_row_to_atoms(row, calc="retrained"))
print(
(
i,
force_l2_norm_scaled_error(
images[-1].get_forces(), ml_images[-1].get_forces()
),
)
)
if i > 100000:
break

0 comments on commit d68c29c

Please sign in to comment.