Skip to content

Commit

Permalink
Changed the script to perform comparison between different versions
Browse files Browse the repository at this point in the history
  • Loading branch information
YCC-ProjBackups committed Aug 21, 2023
1 parent 2b64571 commit cc90003
Show file tree
Hide file tree
Showing 6 changed files with 1,291 additions and 155 deletions.
8 changes: 4 additions & 4 deletions anisoap/representations/ellipsoidal_density_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def pairwise_ellip_expansion(
radial_basis,
*,
timer: SimpleTimer = None,
moment_fn_lang: str = "rust"
version: int = None
):
"""
Function to compute the pairwise expansion <anlm|rho_ij> by combining the moments and the spherical to Cartesian
Expand Down Expand Up @@ -149,7 +149,7 @@ def pairwise_ellip_expansion(
else:
internal_timer3 = None

if moment_fn_lang == "rust":
if version is None or version >= 1:
# NOTE: This line was replaced with Rust implementation.
moments = compute_moments(precision, center, lmax + np.max(num_ns))
# Mark the timers for consistency
Expand Down Expand Up @@ -475,7 +475,7 @@ def __init__(

self.rotation_key = rotation_key

def transform(self, frames, show_progress=False, *, timer: SimpleTimer = None, moment_fn_lang: str = "rust"): # frames: List[Atoms]
def transform(self, frames, show_progress=False, *, version: int = None, timer: SimpleTimer = None): # frames: List[Atoms]
"""
Computes the features and (if compute_gradients == True) gradients
for all the provided frames. The features and gradients are stored in
Expand Down Expand Up @@ -585,7 +585,7 @@ def transform(self, frames, show_progress=False, *, timer: SimpleTimer = None, m
self.sph_to_cart,
self.radial_basis,
timer=internal_timer,
moment_fn_lang=moment_fn_lang
version=version
)
if timer is not None:
timer.mark("5-8. pairwise ellip expansion")
Expand Down
38 changes: 30 additions & 8 deletions anisoap/utils/code_timer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
import time
from collections import defaultdict
from typing import Callable
from enum import Enum


class SimpleTimerCollectMode(Enum):
AVG = 1,
SUM = 2,
MIN = 3,
MAX = 4,
MED = 5,

class SimpleTimer:
default_coll_mode = "avg"
# NOTE: Change this collect_mode default argument to change all measuring across the files
# except the one specified.
default_coll_mode = SimpleTimerCollectMode.AVG

def __init__(self):
self._internal_time = time.perf_counter()
Expand Down Expand Up @@ -55,12 +66,10 @@ def collect_trials(self, collect_fn: Callable[[list[float]], float]) -> dict[str
coll_dict[key] = collect_fn(val)
return coll_dict

# NOTE: Change this collect_mode default argument to change all measuring across the files
# except the one specified.
def collect_and_append(
self,
other: 'SimpleTimer',
collect_mode: str | Callable[[list[float]], float] = None
collect_mode: SimpleTimerCollectMode | Callable[[list[float]], float] = None
):
"""
Takes another SimpleTimer class as argument and calls average_trials
Expand All @@ -71,14 +80,16 @@ def collect_and_append(
if collect_mode == None:
collect_mode = SimpleTimer.default_coll_mode

if collect_mode == "avg":
if collect_mode == SimpleTimerCollectMode.AVG:
coll_dict = other.collect_trials(lambda x: sum(x) / len(x))
elif collect_mode == "sum":
elif collect_mode == SimpleTimerCollectMode.SUM:
coll_dict = other.collect_trials(lambda x: sum(x))
elif collect_mode == "max":
elif collect_mode == SimpleTimerCollectMode.MAX:
coll_dict = other.collect_trials(lambda x: max(x))
elif collect_mode == "min":
elif collect_mode == SimpleTimerCollectMode.MIN:
coll_dict = other.collect_trials(lambda x: min(x))
elif collect_mode == SimpleTimerCollectMode.MED:
coll_dict = other.collect_trials(lambda x: SimpleTimer._median(x))
else:
coll_dict = other.collect_trials(lambda x: collect_mode(x))
for key, val in coll_dict.items():
Expand Down Expand Up @@ -106,3 +117,14 @@ def _largest_leading_num(text: str) -> float:
break

return float(int_str) if len(int_str) > 0 else float("inf")

@staticmethod
def _median(x: list[float]):
sorted_x = sorted(x)
len_x = len(x)
half_index = len_x // 2 # floor

if len_x % 2 == 0: # even number of elements:
return (sorted_x[half_index - 1] + sorted_x[half_index]) / 2.0
else:
return sorted_x[half_index]
15 changes: 10 additions & 5 deletions anisoap/utils/cyclic_list.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

class CGRCacheList:
"""
This is a simple class that only exists to be used as a "private" cache
Expand All @@ -11,14 +13,12 @@ def __init__(self, size: int):
A constructor that makes an empty cyclic list.
"""
self._size = size
self._next_ins_index = 0
self._keys = []
self._cyclic_list = [None] * size # will be a list of tuples (key, value)
self.clear_cache()

def keys(self) -> list:
return self._keys

def insert(self, key, value):
def insert(self, key, value) -> None:
if key not in self.keys():
# Store (key, value) pair in cyclic list
self._cyclic_list[self._next_ins_index] = (key, value)
Expand All @@ -32,8 +32,13 @@ def insert(self, key, value):
# Update the index at which the next element should be inserted.
self._next_ins_index = (self._next_ins_index + 1) % self._size

def get_val(self, key):
def get_val(self, key) -> Any:
for element in self._cyclic_list:
if element is not None and key == element[0]:
return element[1]
raise IndexError(f"The specified key {key} is not in the list. Current keys in the list are: {self._keys}")

def clear_cache(self) -> None:
self._next_ins_index = 0
self._keys = []
self._cyclic_list = [None] * self._size # will be a list of tuples (key, value)
4 changes: 2 additions & 2 deletions anisoap/utils/equistore_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .cyclic_list import CGRCacheList

class ClebschGordanReal:
def __init__(self, l_max, *, timer: SimpleTimer = None, cache_list: CGRCacheList = None):
def __init__(self, l_max, *, version: int = None, cache_list: CGRCacheList = None, timer: SimpleTimer = None):
if timer is not None:
timer.mark_start()
self._l_max = l_max
Expand All @@ -29,7 +29,7 @@ def __init__(self, l_max, *, timer: SimpleTimer = None, cache_list: CGRCacheList
if timer is not None:
timer.mark("8-2. compute r2c and c2r")

if cache_list is not None:
if version >= 1 and cache_list is not None:
if l_max in cache_list.keys():
self._cg = cache_list.get_val(l_max)
if timer is not None:
Expand Down
Loading

0 comments on commit cc90003

Please sign in to comment.