Skip to content

Commit

Permalink
Implementation of Frame class (#291)
Browse files Browse the repository at this point in the history
* Update data.py

Allows sending a manual coloring of atoms with the atoms.colors attribute

* Colors now in atoms.arrays

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Change to internal frame dataclass instead of ase.Atoms

* changes

* Implementation of frame class

* Manual fix of merge issues

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix for key error for vecField property

Fixed KeyError for not yet implemented property vecField

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test-case for frame class

* fix pbc and cell issues

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Minor bug fixes and improvements

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* removed data.py, no longer used

* fix return type, fix analysis

* fix inconsistency between ase.Atoms and Frames

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Paul Hohenberger <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: phohenberger <paulh@HO7>
Co-authored-by: phohenberger <[email protected]>
  • Loading branch information
5 people authored Dec 13, 2023
1 parent bd605b9 commit 2c7dda6
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 99 deletions.
9 changes: 6 additions & 3 deletions tests/test_data.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import pytest
from ase.build import bulk, molecule

from zndraw.data import atoms_from_json, atoms_to_json
from zndraw.frame import Frame


@pytest.mark.parametrize(
"atoms", [(molecule("C6H6")), (molecule("H2O")), (bulk("Cu", "fcc", a=3.6))]
)
def test_atoms_from_and_to_json_return_unchanged_atoms_object(atoms):
atom_copy = atoms.copy()
assert atom_copy == atoms_from_json(atoms_to_json(atoms))
frame = Frame.from_atoms(atoms)
assert frame == Frame(
positions=atoms.positions, numbers=atoms.numbers, cell=atoms.cell, pbc=atoms.pbc
)
assert frame.to_atoms() == atoms
85 changes: 0 additions & 85 deletions zndraw/data.py

This file was deleted.

253 changes: 253 additions & 0 deletions zndraw/frame.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
import dataclasses

import ase
import networkx as nx
import numpy as np
from ase.calculators.singlepoint import SinglePointCalculator
from ase.data.colors import jmol_colors

from zndraw.bonds import ASEComputeBonds
from zndraw.utils import get_radius, rgb2hex


@dataclasses.dataclass
class Frame:
"""
Primary Attributes:
These attributes directly contain data of your frame
that will be displayed in the visualizer.
-------------
positions : np.ndarray
contains the positions of each particle in 3 dimensions
cell : np.ndarray
contains the cell size of the frame
numbers : t.Union[np.ndarray, list, int] #TODO update in case int gets removed
contains the number of each individual atom
colors : np.ndarray
contains the hexadecimal color representation of each atom.
radii : np.ndarray
contains the radius of each atom that is displayed in the visualizer.
pbc : bool
determines periodic boundary conditions
connectivity : nx.Graph()
contains the bonds between singular atoms
calc : dict
contains properties of the frame, such as energy,
that can be viewed using the analyze function.
vector_field : dict
WIP: contains a flowfield that will be displayed in the simulation box.
will contain box-length, numbers of vectors per dimension and the directional
vectors themself.#
-------------
Secondary Attributes:
These attributes influence the usage of the primary attributes, such as if
bonds are displayed or in what way bonds are calculated.
-------------
bonds : bool
determines if bonds are drawn
auto_bonds : bool
if true uses module ase to calculate chemically accurate bonds of atoms
using the positions and numbers (e.g. number = 1 = Hydrogen)
"""

positions: np.ndarray = None
cell: np.ndarray = np.array([0.0, 0.0, 0.0])
numbers: np.ndarray = None
colors: np.ndarray = None
radii: np.ndarray = None
pbc: bool = False
connectivity: nx.Graph() = nx.empty_graph()
calc: dict = None
vector_field: dict = None

bonds: bool = True
auto_bonds: bool = True

def __post_init__(self):
"""
Converts all lists to np.ndarray
"""
for item in ["positions", "numbers", "colors", "radii"]:
if isinstance(getattr(self, item), list):
setattr(self, item, np.array(getattr(self, item)))

if not isinstance(self.cell, np.ndarray):
self.cell = np.array(self.cell)

@classmethod
def from_atoms(cls, atoms: ase.Atoms):
"""
Creates an instance of the frame class from an ase.Atoms object
"""
frame = cls(**atoms.arrays)

frame.cell = np.array(atoms.cell)
frame.pbc = atoms.pbc

if hasattr(atoms, "connectivity"):
frame.connectivity = atoms.connectivity

try:
calc_data = {}
for key, value in atoms.calc.results.items():
if isinstance(value, np.ndarray):
value = value.tolist()
calc_data[key] = value
frame.calc = calc_data
except (
RuntimeError,
AttributeError,
): # This exception happens, when there is no calc-attribute given.
pass

return frame

def to_atoms(self) -> ase.Atoms:
"""
Creates an ase.Atoms object from a Frame instance
"""
atoms = ase.Atoms(
positions=self.positions, numbers=self.numbers, cell=self.cell, pbc=self.pbc
)

# atoms.arrays["colors"] = self.colors # TODO: see https://github.com/zincware/ZnDraw/issues/279

atoms.connectivity = self.connectivity

if self.calc is not None:
atoms.calc = SinglePointCalculator(atoms)
atoms.calc.results = {
key: np.array(val) if isinstance(val, list) else val
for key, val in self.calc.items()
}

return atoms

# TODO: use instead of ASEComputeBonds() in to_dict()
# def calc_bonds(self):
# """
# Experimental Function, currently not in use
# """
# single_bond_multiplier: float = Field(1.2, le=2, ge=0)
# double_bond_multiplier: float = Field(0.9, le=1, ge=0)
# triple_bond_multiplier: float = Field(0.0, le=1, ge=0)

# cutoffs = [
# single_bond_multiplier,
# double_bond_multiplier,
# triple_bond_multiplier,
# ]
# frame_copy = copy.deepcopy(self)
# connectivity_matrix = np.zeros((len(self), len(self)), dtype=int)
# distance_matrix = self.dist_matrix()
# np.fill_diagonal(distance_matrix, np.inf)
# for cutoff in cutoffs:
# cutoffs = np.array(natural_cutoffs(frame_copy, mult=cutoff))
# cutoffs = cutoffs[:, None] + cutoffs[None, :]
# connectivity_matrix[distance_matrix <= cutoffs] += 1
# self.connectivity = nx.from_numpy_array(connectivity_matrix)

# def dist_matrix(self):
# """
# Experimental function, currently not in use
# """
# matrix = np.zeros((len(self), len(self)))
# for i in range(1, len(self)):
# for j in range(i, len(self)):
# matrix[i, j] = np.linalg.norm(self.positions[i] - self.positions[j])
# return matrix

def get_bonds(self) -> list:
"""
Returns a list than contains all bonds
"""
bonds = []
for edge in self.connectivity.edges:
bonds.append((edge[0], edge[1], self.connectivity.edges[edge]["weight"]))
return bonds

def __len__(self):
if isinstance(self.numbers, np.ndarray):
return self.numbers.size
elif isinstance(self.numbers, int):
return 1

def __eq__(self, other):
"""
Check for identity of two frame objects.
"""
try:
return (
len(self) == len(other)
and (self.positions == other.positions).all()
and (self.numbers == other.numbers).all()
and (self.cell == other.cell).all()
and (self.pbc == other.pbc).all()
and (self.connectivity == other.connectivity)
)
except AttributeError:
return NotImplemented

def to_dict(self) -> dict:
"""
Creates a dictionary than contains all the relevant information of the Frame object
"""
frame_dict = {}

for field in dataclasses.fields(self):
frame_dict[field.name] = getattr(self, field.name)

for key, value in frame_dict.items():
if isinstance(value, np.ndarray):
frame_dict[key] = value.tolist()

if frame_dict["colors"] is None:
frame_dict["colors"] = [
rgb2hex(jmol_colors[number]) for number in frame_dict["numbers"]
]

if frame_dict["radii"] is None:
frame_dict["radii"] = [
get_radius(number) for number in frame_dict["numbers"]
]

if self.bonds:
try:
if self.auto_bonds:
ase_bond_calculator = ASEComputeBonds()
self.connectivity = ase_bond_calculator.build_graph(self.to_atoms())
frame_dict["connectivity"] = self.get_bonds()
except AttributeError:
frame_dict["connectivity"] = []
else:
frame_dict["connectivity"] = []

return frame_dict

@classmethod
def from_dict(cls, data):
"""
Creates an instance of the Frame class from a dictionary
"""
frame = cls(
positions=np.array(data["positions"]),
cell=np.array(data["cell"]),
numbers=np.array(data["numbers"]),
colors=np.array(data["colors"]),
radii=np.array(data["radii"]),
pbc=data["pbc"],
calc=data["calc"],
)

if (
"vector_field" in data
): # currently there is the vector_field part in js missing. So this is useless at the moment
frame.vector_field = data["vector_field"]

if "connectivity" in data:
frame.connectivity = nx.Graph()
for edge in data["connectivity"]:
frame.connectivity.add_edge(edge[0], edge[1], weight=edge[2])

return frame
10 changes: 10 additions & 0 deletions zndraw/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,20 @@

import ase
import datamodel_code_generator
import numpy as np

SHARED = {"atoms": None}


def rgb2hex(value):
r, g, b = np.array(value * 255, dtype=int)
return "#%02x%02x%02x" % (r, g, b)


def get_radius(value):
return (0.25 * (2 - np.exp(-0.2 * value)),)


def get_port(default: int = 1234) -> int:
"""Get an open port."""
try:
Expand Down
Loading

0 comments on commit 2c7dda6

Please sign in to comment.