Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable continuous GeM computation. #309

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 70 additions & 3 deletions pysages/colvars/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from jaxopt import GradientDescent as minimize

from pysages.colvars.core import CollectiveVariable
from pysages.utils import gaussian, quaternion_from_euler, quaternion_matrix
from pysages.utils import (
gaussian, row_sum, quaternion_from_euler,
quaternion_matrix)


def rotate_pattern_with_quaternions(rot_q, pattern):
Expand Down Expand Up @@ -42,6 +44,9 @@ def __init__(
centre_j_id,
standard_deviation,
mesh_size,
number_of_added_sites=0,
Copy link
Collaborator Author

@maggiezimon maggiezimon Mar 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the continuous version of the LoM is to be used, the additional atoms should already be added to the reference (coordinates of the reference structure). In other words, reference should have $M$ rows corresponding to the original reference and additional $M_b$ rows representing the coordinates of extra atoms. With number_of_added_atoms we specify how many sites (rows) were included. The $M_b$ atoms are assumed to be outside the 1st shell, so their distances from the central site are larger than for the other $M$ sites.

width_of_switch_func=None,
scale_for_radial_distance=None
):

self.characteristic_distance = characteristic_distance
Expand All @@ -58,6 +63,18 @@ def __init__(
self.centre_j_coords = self.positions[self.centre_j_id]
self.standard_deviation = standard_deviation
self.mesh_size = mesh_size
# These settings are needed if continuous LoM is to be used
self.number_of_added_sites = number_of_added_sites
if self.number_of_added_sites > 0:
if width_of_switch_func is None:
self.width_of_switch_func = self.standard_deviation / 2
else:
self.width_of_switch_func = width_of_switch_func

if scale_for_radial_distance is None:
self.scale_for_radial_distance = 0.9
else:
self.scale_for_radial_distance = scale_for_radial_distance

def comp_pair_distance_squared(self, pos1):
displacement_fn, shift_fn = space.periodic(np.diag(self.simulation_box))
Expand All @@ -79,6 +96,13 @@ def _generate_neighborhood(self):

ids_of_neighbors = np.argsort(distances)[: len(self.reference)]

if self.number_of_added_sites > 0:
ids_of_neighbors_2nd_shell = ids_of_neighbors[
-self.number_of_added_sites:]
self.shell_distance = self.scale_for_radial_distance * np.mean(
distances[ids_of_neighbors_2nd_shell])
self._neighborhood_distances = distances[ids_of_neighbors]

coordinates = mic_vectors[ids_of_neighbors] + self.centre_j_coords
# Step 1: Translate to origin;
coordinates = coordinates.at[:].set(coordinates - np.mean(coordinates, axis=0))
Expand All @@ -95,9 +119,34 @@ def _generate_neighborhood(self):
self._neighbor_coords = np.array([n["coordinates"] for n in self._neighborhood])
self._orig_neighbor_coords = positions_of_all_nbrs[ids_of_neighbors]

def _switching_function(self, distance, width):
result = 0.5 * lax.erfc(
(distance - self.shell_distance) / width)
return result

def compute_score(self, optim_reference):
r = self._neighbor_coords - optim_reference
return np.prod(gaussian(1, self.standard_deviation, r))

if self.number_of_added_sites != 0:
width = self.width_of_switch_func
squared_dist = row_sum(r**2)
return np.exp(
- np.sum(
self._switching_function(
self._neighborhood_distances,
width) * squared_dist
) / (
2 * (self.standard_deviation ** 2) * np.sum(
self._switching_function(
self._neighborhood_distances, width)
)
)
)
else:
return np.prod(
gaussian(1,
self.standard_deviation * np.sqrt(
len(self.reference)), r))

def rotate_reference(self, random_euler_point):
# Perform rotation of the reference pattern;
Expand Down Expand Up @@ -153,7 +202,7 @@ def return_close(_, n):
close_sites,
)
# Return the locations of settled nighbours in the neighborhood;
# Settlled site should have a unique neighbor
# Settled site should have a unique neighbor
settled_neighbor_indices = np.where(np.sum(indices, axis=0) >= 1, 1, 0)
return settled_neighbor_indices

Expand Down Expand Up @@ -281,6 +330,9 @@ def calculate_lom(all_positions: np.array, neighborlist, simulation_box, params)
i,
params.standard_deviation,
params.mesh_size,
params.number_of_added_sites,
params.width_of_switch_func,
params.scale_for_radial_distance
).driver_match(
params.number_of_rotations,
params.number_of_opt_it,
Expand Down Expand Up @@ -339,6 +391,14 @@ class GeM(CollectiveVariable):
fractional_coords: bool
Set to True if NPT simulation is considered and the box size
changes; use periodic_general for constructing the neighborlist.
number_of_added_sites: int
Specify additional sites to the main reference for the continuous
calculation (skip if the continuous LoM is not needed).
width_of_switch_func: float
Width of the switching function for the continuous score function.
scale_for_radial_distance: float
Scaling factor for the mean radial distance of added sites
used in the continuous score function calculation.
Returns
-------
calculate_lom: float
Expand All @@ -357,6 +417,9 @@ def __init__(
mesh_size,
nbrs,
fractional_coords,
number_of_added_sites=0,
width_of_switch_func=None,
scale_for_radial_distance=None
):
super().__init__(indices, group_length=None)

Expand All @@ -369,6 +432,10 @@ def __init__(
self.mesh_size = mesh_size
self.nbrs = nbrs
self.fractional_coords = fractional_coords
# The parameters below are only used in the continuous version
self.number_of_added_sites = number_of_added_sites
self.width_of_switch_func = width_of_switch_func
self.scale_for_radial_distance = scale_for_radial_distance

@property
def function(self):
Expand Down
2 changes: 1 addition & 1 deletion pysages/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@
solve_pos_def,
try_import,
)
from .core import ToCPU, copy, dispatch, eps, first_or_all, gaussian, identity
from .core import ToCPU, copy, dispatch, eps, first_or_all, gaussian, identity, row_sum
from .transformations import quaternion_from_euler, quaternion_matrix
Loading