diff --git a/.github/workflows/website.yml b/.github/workflows/website.yml
index d1c8c008b..7f9f6023b 100644
--- a/.github/workflows/website.yml
+++ b/.github/workflows/website.yml
@@ -21,7 +21,7 @@ jobs:
cache: 'npm'
- uses: actions/setup-python@v4
with:
- python-version: '3.10'
+ python-version: '3.11'
- name: configure access to git repositories in package-lock.json
run: git config --global url."https://github.com/".insteadOf ssh://git@github.com
- name: install npm dependencies
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 8d5e7895e..3a43f19d4 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -8,6 +8,12 @@ numpy
ipywidgets
matplotlib
ase==3.22.1
+# These three pins are to avoid rdkit version 2024.3.5. Once a later version
+# is available, the latest version (assuming Python >= 3.11) of all three
+# should be ok.
+rdkit==2023.9.5
+stk==2024.9.23.1
+stko==2024.8.29.1
# required for the examples of chemiscope.explore
mace-torch
diff --git a/docs/src/python/reference.rst b/docs/src/python/reference.rst
index a65851ea3..63b6af907 100644
--- a/docs/src/python/reference.rst
+++ b/docs/src/python/reference.rst
@@ -28,3 +28,5 @@
.. autofunction:: chemiscope.explore
.. autofunction:: chemiscope.metatensor_featurizer
+
+.. autofunction:: chemiscope.convert_stk_bonds_as_shapes
diff --git a/python/README.md b/python/README.md
index 7a7fd5d58..acc9b97fd 100644
--- a/python/README.md
+++ b/python/README.md
@@ -24,6 +24,8 @@ import chemiscope
import ase.io
# read frames using ase
+# frames can also be stk objets, e.g.
+# frames = [stk.BuildingBlock(smiles="NCCN")]
frames = ase.io.read("structures.xyz", ":")
# add additional properties to display
@@ -44,6 +46,8 @@ import chemiscope
import ase.io
# read frames using ase
+# frames can also be stk objets, e.g.
+# frames = [stk.BuildingBlock(smiles="NCCN")]
frames = ase.io.read("structures.xyz", ":")
# add additional properties to display
diff --git a/python/chemiscope/__init__.py b/python/chemiscope/__init__.py
index 9723959c1..c275c6bcd 100644
--- a/python/chemiscope/__init__.py
+++ b/python/chemiscope/__init__.py
@@ -10,6 +10,7 @@
ellipsoid_from_tensor,
extract_properties,
librascal_atomic_environments,
+ convert_stk_bonds_as_shapes,
)
from .explore import explore, metatensor_featurizer # noqa: F401
from .version import __version__ # noqa: F401
diff --git a/python/chemiscope/input.py b/python/chemiscope/input.py
index 3b5a4060a..6ef495ef4 100644
--- a/python/chemiscope/input.py
+++ b/python/chemiscope/input.py
@@ -2,6 +2,7 @@
"""
Generate JSON input files for the default chemiscope visualizer.
"""
+
import gzip
import json
import os
@@ -282,6 +283,7 @@ def create_input(
data["structures"] = frames_to_json(frames)
n_structures = len(data["structures"])
n_atoms = sum(s["size"] for s in data["structures"])
+
else:
n_atoms = 0
@@ -378,6 +380,7 @@ def create_input(
# should be removed in version 0.6 of chemiscope.
if frames is not None:
found_one_from_frame = False
+
atom_properties = _list_atom_properties(frames)
for name in atom_properties:
if name in data["properties"]:
diff --git a/python/chemiscope/jupyter.py b/python/chemiscope/jupyter.py
index d690190b7..2364a54e4 100644
--- a/python/chemiscope/jupyter.py
+++ b/python/chemiscope/jupyter.py
@@ -155,7 +155,7 @@ def show_input(path, mode="default"):
try:
meta = dict_input["meta"]
- if meta is {"name": " "}:
+ if meta == {"name": " "}:
has_metadata = False
else:
has_metadata = True
diff --git a/python/chemiscope/structures/__init__.py b/python/chemiscope/structures/__init__.py
index 3f6fa9991..059e05b45 100644
--- a/python/chemiscope/structures/__init__.py
+++ b/python/chemiscope/structures/__init__.py
@@ -20,6 +20,15 @@
ase_tensors_to_ellipsoids,
ase_vectors_to_arrows,
)
+from ._stk import ( # noqa: F401
+ _stk_valid_structures,
+ _stk_to_json,
+ convert_stk_bonds_as_shapes,
+ _stk_all_atomic_environments,
+ _stk_composition_properties,
+ _stk_list_atom_properties,
+ _stk_list_structure_properties,
+)
def _guess_adapter(frames):
@@ -32,6 +41,10 @@ def _guess_adapter(frames):
if use_ase:
return ase_frames, "ASE"
+ stk_frames, use_stk = _stk_valid_structures(frames)
+ if use_stk:
+ return stk_frames, "stk"
+
raise Exception(f"unknown frame type: '{frames[0].__class__.__name__}'")
@@ -48,6 +61,8 @@ def frames_to_json(frames):
if adapter == "ASE":
return [_ase_to_json(frame) for frame in frames]
+ elif adapter == "stk":
+ return [_stk_to_json(frame) for frame in frames]
else:
raise Exception("reached unreachable code")
@@ -62,6 +77,9 @@ def _list_atom_properties(frames):
if adapter == "ASE":
return _ase_list_atom_properties(frames)
+ elif adapter == "stk":
+ return _stk_list_atom_properties(frames)
+
else:
raise Exception("reached unreachable code")
@@ -76,6 +94,8 @@ def _list_structure_properties(frames):
if adapter == "ASE":
return _ase_list_structure_properties(frames)
+ elif adapter == "stk":
+ return _stk_list_structure_properties(frames)
else:
raise Exception("reached unreachable code")
@@ -96,6 +116,12 @@ def extract_properties(frames, only=None, environments=None):
if adapter == "ASE":
return _ase_extract_properties(frames, only, environments)
+
+ elif adapter == "stk":
+ raise RuntimeError(
+ "stk molecules do not contain properties, you must manually provide them"
+ )
+
else:
raise Exception("reached unreachable code")
@@ -119,6 +145,10 @@ def composition_properties(frames, environments=None):
if adapter == "ASE":
return _ase_composition_properties(frames, environments)
+
+ elif adapter == "stk":
+ return _stk_composition_properties(frames, environments)
+
else:
raise Exception("reached unreachable code")
@@ -137,6 +167,8 @@ def all_atomic_environments(frames, cutoff=3.5):
if adapter == "ASE":
return _ase_all_atomic_environments(frames, cutoff)
+ elif adapter == "stk":
+ return _stk_all_atomic_environments(frames, cutoff)
else:
raise Exception("reached unreachable code")
diff --git a/python/chemiscope/structures/_ase.py b/python/chemiscope/structures/_ase.py
index c597624b4..f10538e1f 100644
--- a/python/chemiscope/structures/_ase.py
+++ b/python/chemiscope/structures/_ase.py
@@ -15,7 +15,11 @@
def _ase_valid_structures(frames):
- frames_list = list(frames)
+ try:
+ frames_list = list(frames)
+ except TypeError:
+ return [], False
+
if HAVE_ASE and isinstance(frames_list[0], ase.Atoms):
for frame in frames_list:
assert isinstance(frame, ase.Atoms)
diff --git a/python/chemiscope/structures/_stk.py b/python/chemiscope/structures/_stk.py
new file mode 100644
index 000000000..b2459d1b0
--- /dev/null
+++ b/python/chemiscope/structures/_stk.py
@@ -0,0 +1,212 @@
+import typing
+from collections import Counter
+
+try:
+ import stk
+
+ HAVE_STK = True
+except ImportError:
+ HAVE_STK = False
+
+
+def _stk_valid_structures(
+ frames: typing.Union[stk.Molecule, list[stk.Molecule]],
+) -> tuple[list[stk.Molecule], bool]:
+ if HAVE_STK and isinstance(frames, stk.Molecule):
+ # deal with the user passing a single frame
+ return [frames], True
+ elif HAVE_STK and isinstance(frames[0], stk.Molecule):
+ for frame in frames:
+ assert isinstance(frame, stk.Molecule)
+ return frames, True
+ else:
+ return frames, False
+
+
+def _stk_to_json(molecule: stk.Molecule) -> dict[str : typing.Union[int, list]]:
+ """Implementation of frame_to_json for stk.Molcule.
+
+ The current implementation assumes no periodic information, which is safe
+ for the majority of stk molecules. If necessary, we can add cell information
+ in the future.
+
+ """
+ pos_mat = molecule.get_position_matrix()
+ data = {}
+ data["size"] = molecule.get_num_atoms()
+ data["names"] = [atom.__class__.__name__ for atom in molecule.get_atoms()]
+ data["x"] = [float(pos_mat[atom.get_id()][0]) for atom in molecule.get_atoms()]
+ data["y"] = [float(pos_mat[atom.get_id()][1]) for atom in molecule.get_atoms()]
+ data["z"] = [float(pos_mat[atom.get_id()][2]) for atom in molecule.get_atoms()]
+
+ return data
+
+
+def _stk_all_atomic_environments(
+ frames: list[stk.Molecule],
+ cutoff: float,
+) -> list[tuple[int, int, float]]:
+ "Extract all atomic environments out of a set of stk Molecule objects"
+ environments = []
+ for structure_i, frame in enumerate(frames):
+ for atom in frame.get_atoms():
+ environments.append((structure_i, atom.get_id(), cutoff))
+
+ return environments
+
+
+def _stk_composition_properties(frames, environments=None):
+ all_elements = set()
+ for frame in frames:
+ all_elements.update([atom.__class__.__name__ for atom in frame.get_atoms()])
+ all_elements = set(all_elements)
+
+ composition = []
+ elements_count = {element: [] for element in all_elements}
+ for frame in frames:
+ counter = Counter([atom.__class__.__name__ for atom in frame.get_atoms()])
+
+ composition.append("".join(f"{i}{counter[i]}" for i in sorted(counter)))
+
+ dict_composition = dict(counter)
+
+ for element in all_elements:
+ if element in dict_composition:
+ elements_count[element].append(dict_composition[element])
+ else:
+ elements_count[element].append(0)
+
+ properties = {
+ f"n_{element}": {"values": values, "target": "structure"}
+ for element, values in elements_count.items()
+ }
+
+ properties["composition"] = {"values": composition, "target": "structure"}
+
+ if environments is not None:
+ atoms_mask = [[False] * len(f) for f in frames]
+ for structure, center, _ in environments:
+ atoms_mask[structure][center] = True
+ else:
+ atoms_mask = None
+
+ symbols = []
+ numbers = []
+ for i, frame in enumerate(frames):
+ if atoms_mask is None:
+ frame_symbols = [atom.__class__.__name__ for atom in frame.get_atoms()]
+ frame_numbers = [atom.get_atomic_number() for atom in frame.get_atoms()]
+
+ else:
+ frame_symbols = [atom.__class__.__name__ for atom in frame.get_atoms()][
+ atoms_mask[i]
+ ]
+ frame_numbers = [atom.get_atomic_number() for atom in frame.get_atoms()][
+ atoms_mask[i]
+ ]
+
+ symbols.extend(frame_symbols)
+ numbers.extend(frame_numbers)
+
+ properties["symbol"] = {"values": symbols, "target": "atom"}
+ properties["number"] = {"values": numbers, "target": "atom"}
+
+ return properties
+
+
+def convert_stk_bonds_as_shapes(
+ frames: list[stk.Molecule],
+ bond_color: str,
+ bond_radius: float,
+) -> dict[str, dict]:
+ """Convert connections between atom ids in each structure to shapes.
+
+ Parameters:
+
+ frames:
+ List of stk.Molecule objects, which each are structures in
+ chemiscope.
+
+ bond_colour:
+ How to colour the bonds added.
+
+ bond_radius:
+ Radius of bonds to add.
+
+
+ """
+
+ shape_dict: dict[str, dict] = {}
+ max_length = 0
+ for molecule in frames:
+ bonds_to_add = tuple(
+ (bond.get_atom1().get_id(), bond.get_atom2().get_id())
+ for bond in molecule.get_bonds()
+ )
+
+ for bid, bond_info in enumerate(bonds_to_add):
+ bname = f"bond_{bid}"
+
+ # Compute the bond vector.
+ position_matrix = molecule.get_position_matrix()
+ bond_geometry = {
+ "vector": (
+ position_matrix[bond_info[1]] - position_matrix[bond_info[0]]
+ ).tolist(),
+ "position": (position_matrix[bond_info[0]]).tolist(),
+ }
+
+ # Add the bond name to the dictionary to be iterated through.
+ if bname not in shape_dict:
+ if bname == "bond_0":
+ shape_dict[bname] = {
+ "kind": "cylinder",
+ "parameters": {
+ "global": {"radius": bond_radius, "color": bond_color},
+ "structure": [],
+ },
+ }
+
+ else:
+ num_to_add = len(shape_dict["bond_0"]["parameters"]["structure"])
+ shape_dict[bname] = {
+ "kind": "cylinder",
+ "parameters": {
+ "global": {"radius": bond_radius, "color": bond_color},
+ # Add zero placements for previously non-existant
+ # bond shapes up to the length of bond_0 -1,
+ # because that should already be at the current
+ # length that the new one should be.
+ "structure": [
+ {"vector": [0, 0, 0], "position": [0, 0, 0]}
+ for i in range(num_to_add - 1)
+ ],
+ },
+ }
+
+ # Add vector to the shape dictionary.
+ shape_dict[bname]["parameters"]["structure"].append(bond_geometry)
+ max_length = max(
+ (max_length, len(shape_dict[bname]["parameters"]["structure"]))
+ )
+
+ # Fill in bond shapes that are not the same length as the max length.
+ for bname in shape_dict.keys():
+ missing = max_length - len(shape_dict[bname]["parameters"]["structure"])
+ if missing == 0:
+ continue
+ for _ in range(missing):
+ fake_bond = {"vector": [0, 0, 0], "position": [0, 0, 0]}
+ shape_dict[bname]["parameters"]["structure"].append(fake_bond)
+
+ return shape_dict
+
+
+def _stk_list_atom_properties(frames: list[stk.Molecule]) -> list:
+ # stk cannot have atom properties or structure properties, so skipping.
+ return []
+
+
+def _stk_list_structure_properties(frames: list[stk.Molecule]) -> list:
+ # stk cannot have atom properties or structure properties, so skipping.
+ return []
diff --git a/python/examples/9-showing_custom_bonds.py b/python/examples/9-showing_custom_bonds.py
new file mode 100644
index 000000000..9c5af3658
--- /dev/null
+++ b/python/examples/9-showing_custom_bonds.py
@@ -0,0 +1,207 @@
+"""
+Showing custom bonds using stk
+==============================
+
+This example demonstrates how to add shapes into the chemiscope output such
+that custom bonds that would not automatically be assigned can be rendered.
+
+This is done by using `stk `_ to
+generate and analyse molecules, which comes with topology/bonding information
+by default (using the cheminformatic software rdkit).
+
+We use `stko `_ to calculate
+some rudimentary properties of `stk` molecules. `stko` can be installed with
+``pip install stko``.
+
+"""
+
+import stk
+import stko
+from rdkit.Chem import AllChem as rdkit
+
+import chemiscope
+
+# %%
+#
+# Generate a list of stk BuildingBlocks (representation of a molecule) with
+# properties. This also includes working with rdkit, which comes installed
+# with stk.
+
+rdkitmol = rdkit.MolFromSmiles("Cc1ccccc1")
+rdkitmol = rdkit.AddHs(rdkitmol)
+rdkit.Kekulize(rdkitmol)
+params = rdkit.ETKDGv3()
+params.randomSeed = 0xF00D
+rdkit.EmbedMolecule(rdkitmol, params)
+
+structures = [
+ # A building block.
+ stk.BuildingBlock(smiles="NCCN"),
+ # A mostly optimised cage molecule.
+ stk.ConstructedMolecule(
+ topology_graph=stk.cage.FourPlusSix(
+ building_blocks=(
+ stk.BuildingBlock(
+ smiles="NCCN",
+ functional_groups=[stk.PrimaryAminoFactory()],
+ ),
+ stk.BuildingBlock(
+ smiles="O=CC(C=O)C=O",
+ functional_groups=[stk.AldehydeFactory()],
+ ),
+ ),
+ optimizer=stk.MCHammer(),
+ ),
+ ),
+ # A metal-organic cage.
+ stk.ConstructedMolecule(
+ stk.cage.M2L4Lantern(
+ building_blocks=(
+ stk.BuildingBlock(
+ smiles="[Pd+2]",
+ functional_groups=(
+ stk.SingleAtom(stk.Pd(0, charge=2)) for i in range(4)
+ ),
+ position_matrix=[[0.0, 0.0, 0.0]],
+ ),
+ stk.BuildingBlock(
+ smiles=("C1=NC=CC(C2=CC=CC(C3=C" "C=NC=C3)=C2)=C1"),
+ functional_groups=[
+ stk.SmartsFunctionalGroupFactory(
+ smarts="[#6]~[#7X2]~[#6]",
+ bonders=(1,),
+ deleters=(),
+ ),
+ ],
+ ),
+ ),
+ # Ensure that bonds between the
+ # GenericFunctionalGroups of the ligand and the
+ # SingleAtom functional groups of the metal are
+ # dative.
+ reaction_factory=stk.DativeReactionFactory(
+ stk.GenericReactionFactory(
+ bond_orders={
+ frozenset(
+ {
+ stk.GenericFunctionalGroup,
+ stk.SingleAtom,
+ }
+ ): 9,
+ },
+ ),
+ ),
+ ),
+ ),
+ # A host guest molecule.
+ stk.ConstructedMolecule(
+ topology_graph=stk.host_guest.Complex(
+ host=stk.BuildingBlock.init_from_molecule(
+ stk.ConstructedMolecule(
+ topology_graph=stk.cage.FourPlusSix(
+ building_blocks=(
+ stk.BuildingBlock(
+ smiles="NC1CCCCC1N",
+ functional_groups=[
+ stk.PrimaryAminoFactory(),
+ ],
+ ),
+ stk.BuildingBlock(
+ smiles="O=Cc1cc(C=O)cc(C=O)c1",
+ functional_groups=[stk.AldehydeFactory()],
+ ),
+ ),
+ optimizer=stk.MCHammer(),
+ ),
+ )
+ ),
+ guests=stk.host_guest.Guest(
+ building_block=stk.BuildingBlock("[Br][Br]"),
+ ),
+ ),
+ ),
+ # From rdkit.
+ stk.BuildingBlock.init_from_rdkit_mol(rdkitmol),
+]
+
+
+# %%
+#
+# Write their properties using any method, here we show using stko:
+# https://stko-docs.readthedocs.io/en/latest/
+
+energy = stko.UFFEnergy()
+shape_calc = stko.ShapeCalculator()
+properties = {
+ "uffenergy": [energy.get_energy(molecule) for molecule in structures],
+ "aspheriticty": [
+ shape_calc.get_results(molecule).get_asphericity() for molecule in structures
+ ],
+}
+
+
+# %%
+#
+# Get the stk bonding information and convert them into shapes.
+shape_dict = chemiscope.convert_stk_bonds_as_shapes(
+ frames=structures,
+ bond_color="#fc5500",
+ bond_radius=0.12,
+)
+
+# Write the shape string for settings to turn them on automatically.
+shape_string = ",".join(shape_dict.keys())
+
+
+# %%
+#
+# A chemiscope widget showing the result without the added bonding.
+
+chemiscope.show(frames=structures, properties=properties)
+
+
+# %%
+#
+# Writing to a json.gz file, again without added bonding.
+
+chemiscope.write_input(
+ path="noshape_example.json.gz",
+ frames=structures,
+ properties=properties,
+ meta=dict(name="Missing bonds by automation."),
+ settings=chemiscope.quick_settings(x="aspheriticty", y="uffenergy", color=""),
+)
+
+
+# %%
+#
+# Now with added bonding information.
+
+chemiscope.show(
+ frames=structures,
+ properties=properties,
+ shapes=shape_dict,
+)
+
+# %%
+#
+# Write to json file with added shapes.
+
+chemiscope.write_input(
+ path="shape_example.json.gz",
+ frames=structures,
+ properties=properties,
+ meta=dict(name="Added all stk bonds."),
+ settings=chemiscope.quick_settings(
+ x="aspheriticty",
+ y="uffenergy",
+ color="",
+ structure_settings={
+ "shape": shape_string,
+ "atoms": True,
+ "bonds": False,
+ "spaceFilling": False,
+ },
+ ),
+ shapes=shape_dict,
+)
diff --git a/python/tests/create_input.py b/python/tests/create_input.py
index 7a16586ae..81b3409c6 100644
--- a/python/tests/create_input.py
+++ b/python/tests/create_input.py
@@ -2,323 +2,354 @@
import ase
import numpy as np
+import stk
from chemiscope import all_atomic_environments, create_input
+# These should be the same molecule!
TEST_FRAMES = [ase.Atoms("CO2")]
+TEST_FRAMES_STK = [stk.BuildingBlock("O=C=O")]
class TestCreateInputMeta(unittest.TestCase):
def test_meta(self):
- meta = {}
- data = create_input(frames=TEST_FRAMES, meta=meta)
- self.assertEqual(data["meta"]["name"], "")
- self.assertEqual(len(data["meta"].keys()), 1)
-
- meta = {"name": ""}
- data = create_input(frames=TEST_FRAMES, meta=meta)
- self.assertEqual(data["meta"]["name"], "")
- self.assertEqual(len(data["meta"].keys()), 1)
-
- meta = {"name": "foo"}
- data = create_input(frames=TEST_FRAMES, meta=meta)
- self.assertEqual(data["meta"]["name"], "foo")
- self.assertEqual(len(data["meta"].keys()), 1)
-
- meta = {"name": "foo", "description": "bar"}
- data = create_input(frames=TEST_FRAMES, meta=meta)
- self.assertEqual(data["meta"]["name"], "foo")
- self.assertEqual(data["meta"]["description"], "bar")
- self.assertEqual(len(data["meta"].keys()), 2)
-
- meta = {"name": "foo", "references": ["bar"]}
- data = create_input(frames=TEST_FRAMES, meta=meta)
- self.assertEqual(data["meta"]["name"], "foo")
- self.assertEqual(len(data["meta"]["references"]), 1)
- self.assertEqual(data["meta"]["references"][0], "bar")
- self.assertEqual(len(data["meta"].keys()), 2)
-
- meta = {"name": "foo", "authors": ["bar"]}
- data = create_input(frames=TEST_FRAMES, meta=meta)
- self.assertEqual(data["meta"]["name"], "foo")
- self.assertEqual(len(data["meta"]["authors"]), 1)
- self.assertEqual(data["meta"]["authors"][0], "bar")
- self.assertEqual(len(data["meta"].keys()), 2)
+ for TF in (TEST_FRAMES, TEST_FRAMES_STK):
+ meta = {}
+ data = create_input(frames=TF, meta=meta)
+ self.assertEqual(data["meta"]["name"], "")
+ self.assertEqual(len(data["meta"].keys()), 1)
+
+ meta = {"name": ""}
+ data = create_input(frames=TF, meta=meta)
+ self.assertEqual(data["meta"]["name"], "")
+ self.assertEqual(len(data["meta"].keys()), 1)
+
+ meta = {"name": "foo"}
+ data = create_input(frames=TF, meta=meta)
+ self.assertEqual(data["meta"]["name"], "foo")
+ self.assertEqual(len(data["meta"].keys()), 1)
+
+ meta = {"name": "foo", "description": "bar"}
+ data = create_input(frames=TF, meta=meta)
+ self.assertEqual(data["meta"]["name"], "foo")
+ self.assertEqual(data["meta"]["description"], "bar")
+ self.assertEqual(len(data["meta"].keys()), 2)
+
+ meta = {"name": "foo", "references": ["bar"]}
+ data = create_input(frames=TF, meta=meta)
+ self.assertEqual(data["meta"]["name"], "foo")
+ self.assertEqual(len(data["meta"]["references"]), 1)
+ self.assertEqual(data["meta"]["references"][0], "bar")
+ self.assertEqual(len(data["meta"].keys()), 2)
+
+ meta = {"name": "foo", "authors": ["bar"]}
+ data = create_input(frames=TF, meta=meta)
+ self.assertEqual(data["meta"]["name"], "foo")
+ self.assertEqual(len(data["meta"]["authors"]), 1)
+ self.assertEqual(data["meta"]["authors"][0], "bar")
+ self.assertEqual(len(data["meta"].keys()), 2)
def test_meta_unknown_keys_warning(self):
- meta = {"name": "foo", "what_is_this": "I don't know"}
- with self.assertWarns(UserWarning) as cm:
- data = create_input(frames=TEST_FRAMES, meta=meta)
+ for TF in (TEST_FRAMES, TEST_FRAMES_STK):
+ meta = {"name": "foo", "what_is_this": "I don't know"}
+ with self.assertWarns(UserWarning) as cm:
+ data = create_input(frames=TF, meta=meta)
- self.assertEqual(data["meta"]["name"], "foo")
- self.assertEqual(len(data["meta"].keys()), 1)
+ self.assertEqual(data["meta"]["name"], "foo")
+ self.assertEqual(len(data["meta"].keys()), 1)
- self.assertEqual(str(cm.warning), "ignoring unexpected metadata: what_is_this")
+ self.assertEqual(
+ str(cm.warning), "ignoring unexpected metadata: what_is_this"
+ )
def test_meta_conversions(self):
- meta = {"name": 33}
- data = create_input(frames=TEST_FRAMES, meta=meta)
- self.assertEqual(data["meta"]["name"], "33")
- self.assertEqual(len(data["meta"].keys()), 1)
-
- meta = {"name": ["foo", "bar"], "description": False}
- data = create_input(frames=TEST_FRAMES, meta=meta)
- self.assertEqual(data["meta"]["name"], "['foo', 'bar']")
- self.assertEqual(data["meta"]["description"], "False")
- self.assertEqual(len(data["meta"].keys()), 2)
-
- meta = {"name": "foo", "references": (3, False)}
- data = create_input(frames=TEST_FRAMES, meta=meta)
- self.assertEqual(data["meta"]["name"], "foo")
- self.assertEqual(len(data["meta"]["references"]), 2)
- self.assertEqual(data["meta"]["references"][0], "3")
- self.assertEqual(data["meta"]["references"][1], "False")
- self.assertEqual(len(data["meta"].keys()), 2)
-
- meta = {"name": "foo", "authors": (3, False)}
- data = create_input(frames=TEST_FRAMES, meta=meta)
- self.assertEqual(data["meta"]["name"], "foo")
- self.assertEqual(len(data["meta"]["authors"]), 2)
- self.assertEqual(data["meta"]["authors"][0], "3")
- self.assertEqual(data["meta"]["authors"][1], "False")
- self.assertEqual(len(data["meta"].keys()), 2)
+ for TF in (TEST_FRAMES, TEST_FRAMES_STK):
+ meta = {"name": 33}
+ data = create_input(frames=TF, meta=meta)
+ self.assertEqual(data["meta"]["name"], "33")
+ self.assertEqual(len(data["meta"].keys()), 1)
+
+ meta = {"name": ["foo", "bar"], "description": False}
+ data = create_input(frames=TF, meta=meta)
+ self.assertEqual(data["meta"]["name"], "['foo', 'bar']")
+ self.assertEqual(data["meta"]["description"], "False")
+ self.assertEqual(len(data["meta"].keys()), 2)
+
+ meta = {"name": "foo", "references": (3, False)}
+ data = create_input(frames=TF, meta=meta)
+ self.assertEqual(data["meta"]["name"], "foo")
+ self.assertEqual(len(data["meta"]["references"]), 2)
+ self.assertEqual(data["meta"]["references"][0], "3")
+ self.assertEqual(data["meta"]["references"][1], "False")
+ self.assertEqual(len(data["meta"].keys()), 2)
+
+ meta = {"name": "foo", "authors": (3, False)}
+ data = create_input(frames=TF, meta=meta)
+ self.assertEqual(data["meta"]["name"], "foo")
+ self.assertEqual(len(data["meta"]["authors"]), 2)
+ self.assertEqual(data["meta"]["authors"][0], "3")
+ self.assertEqual(data["meta"]["authors"][1], "False")
+ self.assertEqual(len(data["meta"].keys()), 2)
class TestCreateInputProperties(unittest.TestCase):
def test_properties(self):
- # values are numbers
- properties = {"name": {"target": "atom", "values": [2, 3, 4]}}
- data = create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(data["properties"]["name"]["target"], "atom")
- self.assertEqual(data["properties"]["name"]["values"], [2, 3, 4])
- self.assertEqual(len(data["properties"]["name"].keys()), 2)
-
- # values are strings
- properties = {"name": {"target": "atom", "values": ["2", "3", "4"]}}
- data = create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(data["properties"]["name"]["target"], "atom")
- self.assertEqual(data["properties"]["name"]["values"], ["2", "3", "4"])
- self.assertEqual(len(data["properties"]["name"].keys()), 2)
-
- properties = {
- "name": {
- "target": "atom",
- "values": [2, 3, 4],
- "description": "foo",
- },
- }
- data = create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(data["properties"]["name"]["target"], "atom")
- self.assertEqual(data["properties"]["name"]["description"], "foo")
- self.assertEqual(data["properties"]["name"]["values"], [2, 3, 4])
- self.assertEqual(len(data["properties"]["name"].keys()), 3)
-
- properties = {
- "name": {
- "target": "atom",
- "values": [2, 3, 4],
- "units": "foo",
- },
- }
- data = create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(data["properties"]["name"]["target"], "atom")
- self.assertEqual(data["properties"]["name"]["units"], "foo")
- self.assertEqual(data["properties"]["name"]["values"], [2, 3, 4])
- self.assertEqual(len(data["properties"]["name"].keys()), 3)
+ for TF in (TEST_FRAMES, TEST_FRAMES_STK):
+ # values are numbers
+ properties = {"name": {"target": "atom", "values": [2, 3, 4]}}
+ data = create_input(frames=TF, properties=properties)
+ self.assertEqual(data["properties"]["name"]["target"], "atom")
+ self.assertEqual(data["properties"]["name"]["values"], [2, 3, 4])
+ self.assertEqual(len(data["properties"]["name"].keys()), 2)
+
+ # values are strings
+ properties = {"name": {"target": "atom", "values": ["2", "3", "4"]}}
+ data = create_input(frames=TF, properties=properties)
+ self.assertEqual(data["properties"]["name"]["target"], "atom")
+ self.assertEqual(data["properties"]["name"]["values"], ["2", "3", "4"])
+ self.assertEqual(len(data["properties"]["name"].keys()), 2)
+
+ properties = {
+ "name": {
+ "target": "atom",
+ "values": [2, 3, 4],
+ "description": "foo",
+ },
+ }
+ data = create_input(frames=TF, properties=properties)
+ self.assertEqual(data["properties"]["name"]["target"], "atom")
+ self.assertEqual(data["properties"]["name"]["description"], "foo")
+ self.assertEqual(data["properties"]["name"]["values"], [2, 3, 4])
+ self.assertEqual(len(data["properties"]["name"].keys()), 3)
+
+ properties = {
+ "name": {
+ "target": "atom",
+ "values": [2, 3, 4],
+ "units": "foo",
+ },
+ }
+ data = create_input(frames=TF, properties=properties)
+ self.assertEqual(data["properties"]["name"]["target"], "atom")
+ self.assertEqual(data["properties"]["name"]["units"], "foo")
+ self.assertEqual(data["properties"]["name"]["values"], [2, 3, 4])
+ self.assertEqual(len(data["properties"]["name"].keys()), 3)
def test_ndarray_properties(self):
- # shape N
- properties = {"name": {"target": "atom", "values": np.array([2, 3, 4])}}
- data = create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(data["properties"]["name"]["target"], "atom")
- self.assertEqual(data["properties"]["name"]["values"], [2, 3, 4])
- self.assertEqual(len(data["properties"].keys()), 1)
-
- # shape N
- properties = {"name": {"target": "atom", "values": np.array(["2", "3", "4"])}}
- data = create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(data["properties"]["name"]["target"], "atom")
- self.assertEqual(data["properties"]["name"]["values"], ["2", "3", "4"])
- self.assertEqual(len(data["properties"].keys()), 1)
-
- # shape N x 1
- properties = {"name": {"target": "atom", "values": np.array([[2], [3], [4]])}}
- data = create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(data["properties"]["name"]["target"], "atom")
- self.assertEqual(data["properties"]["name"]["values"], [2, 3, 4])
- self.assertEqual(len(data["properties"].keys()), 1)
-
- # shape N x 3
- properties = {
- "name": {
- "target": "atom",
- "values": np.array([[1, 2, 4], [1, 2, 4], [1, 2, 4]]),
+ for TF in (TEST_FRAMES, TEST_FRAMES_STK):
+ # shape N
+ properties = {"name": {"target": "atom", "values": np.array([2, 3, 4])}}
+ data = create_input(frames=TF, properties=properties)
+ self.assertEqual(data["properties"]["name"]["target"], "atom")
+ self.assertEqual(data["properties"]["name"]["values"], [2, 3, 4])
+ self.assertEqual(len(data["properties"].keys()), 1)
+
+ # shape N
+ properties = {
+ "name": {"target": "atom", "values": np.array(["2", "3", "4"])}
}
- }
- data = create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(data["properties"]["name[1]"]["target"], "atom")
- self.assertEqual(data["properties"]["name[1]"]["values"], [1, 1, 1])
- self.assertEqual(data["properties"]["name[2]"]["target"], "atom")
- self.assertEqual(data["properties"]["name[2]"]["values"], [2, 2, 2])
- self.assertEqual(data["properties"]["name[3]"]["target"], "atom")
- self.assertEqual(data["properties"]["name[3]"]["values"], [4, 4, 4])
- self.assertEqual(len(data["properties"].keys()), 3)
+ data = create_input(frames=TF, properties=properties)
+ self.assertEqual(data["properties"]["name"]["target"], "atom")
+ self.assertEqual(data["properties"]["name"]["values"], ["2", "3", "4"])
+ self.assertEqual(len(data["properties"].keys()), 1)
+
+ # shape N x 1
+ properties = {
+ "name": {"target": "atom", "values": np.array([[2], [3], [4]])}
+ }
+ data = create_input(frames=TF, properties=properties)
+ self.assertEqual(data["properties"]["name"]["target"], "atom")
+ self.assertEqual(data["properties"]["name"]["values"], [2, 3, 4])
+ self.assertEqual(len(data["properties"].keys()), 1)
+
+ # shape N x 3
+ properties = {
+ "name": {
+ "target": "atom",
+ "values": np.array([[1, 2, 4], [1, 2, 4], [1, 2, 4]]),
+ }
+ }
+ data = create_input(frames=TF, properties=properties)
+ self.assertEqual(data["properties"]["name[1]"]["target"], "atom")
+ self.assertEqual(data["properties"]["name[1]"]["values"], [1, 1, 1])
+ self.assertEqual(data["properties"]["name[2]"]["target"], "atom")
+ self.assertEqual(data["properties"]["name[2]"]["values"], [2, 2, 2])
+ self.assertEqual(data["properties"]["name[3]"]["target"], "atom")
+ self.assertEqual(data["properties"]["name[3]"]["values"], [4, 4, 4])
+ self.assertEqual(len(data["properties"].keys()), 3)
def test_shortened_properties(self):
- # atom property
- properties = {"name": [2, 3, 4]}
- data = create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(data["properties"]["name"]["target"], "atom")
- self.assertEqual(data["properties"]["name"]["values"], [2, 3, 4])
- self.assertEqual(len(data["properties"]["name"].keys()), 2)
-
- # frame property
- properties = {"name": [2]}
- data = create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(data["properties"]["name"]["target"], "structure")
- self.assertEqual(data["properties"]["name"]["values"], [2])
- self.assertEqual(len(data["properties"]["name"].keys()), 2)
-
- # ndarray frame property
- properties = {"name": np.array([[2, 4]])}
- data = create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(data["properties"]["name[1]"]["target"], "structure")
- self.assertEqual(data["properties"]["name[1]"]["values"], [2])
- self.assertEqual(len(data["properties"]["name[1]"].keys()), 2)
-
- self.assertEqual(data["properties"]["name[2]"]["target"], "structure")
- self.assertEqual(data["properties"]["name[2]"]["values"], [4])
- self.assertEqual(len(data["properties"]["name[2]"].keys()), 2)
-
- # the initial properties object must not be changed
- self.assertEqual(type(properties["name"]), np.ndarray)
+ for TF in (TEST_FRAMES, TEST_FRAMES_STK):
+ # atom property
+ properties = {"name": [2, 3, 4]}
+ data = create_input(frames=TF, properties=properties)
+ self.assertEqual(data["properties"]["name"]["target"], "atom")
+ self.assertEqual(data["properties"]["name"]["values"], [2, 3, 4])
+ self.assertEqual(len(data["properties"]["name"].keys()), 2)
+
+ # frame property
+ properties = {"name": [2]}
+ data = create_input(frames=TF, properties=properties)
+ self.assertEqual(data["properties"]["name"]["target"], "structure")
+ self.assertEqual(data["properties"]["name"]["values"], [2])
+ self.assertEqual(len(data["properties"]["name"].keys()), 2)
+
+ # ndarray frame property
+ properties = {"name": np.array([[2, 4]])}
+ data = create_input(frames=TF, properties=properties)
+ self.assertEqual(data["properties"]["name[1]"]["target"], "structure")
+ self.assertEqual(data["properties"]["name[1]"]["values"], [2])
+ self.assertEqual(len(data["properties"]["name[1]"].keys()), 2)
+
+ self.assertEqual(data["properties"]["name[2]"]["target"], "structure")
+ self.assertEqual(data["properties"]["name[2]"]["values"], [4])
+ self.assertEqual(len(data["properties"]["name[2]"].keys()), 2)
+
+ # the initial properties object must not be changed
+ self.assertEqual(type(properties["name"]), np.ndarray)
def test_shortened_properties_errors(self):
- properties = {"name": ["2", "3"]}
- with self.assertRaises(ValueError) as cm:
- create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(
- str(cm.exception),
- "The length of property values is different from the number of "
- "structures and the number of atoms, we can not guess the target. "
- "Got n_atoms = 3, n_structures = 1, the length of property values "
- "is 2, for the 'name' property",
- )
+ for TF in (TEST_FRAMES, TEST_FRAMES_STK):
+ properties = {"name": ["2", "3"]}
+ with self.assertRaises(ValueError) as cm:
+ create_input(frames=TF, properties=properties)
+ self.assertEqual(
+ str(cm.exception),
+ "The length of property values is different from the number of "
+ "structures and the number of atoms, we can not guess the target. "
+ "Got n_atoms = 3, n_structures = 1, the length of property values "
+ "is 2, for the 'name' property",
+ )
- properties = {"name": ase.Atoms("CO2")}
- with self.assertRaises(ValueError) as cm:
- create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(
- str(cm.exception),
- "Property values should be either list or numpy array, got "
- " instead",
- )
+ properties = {"name": ase.Atoms("CO2")}
+ with self.assertRaises(ValueError) as cm:
+ create_input(frames=TF, properties=properties)
+ self.assertEqual(
+ str(cm.exception),
+ "Property values should be either list or numpy array, got "
+ " instead",
+ )
- properties = {"name": ["2", "3"]}
- frames_single_atoms = [ase.Atoms("C"), ase.Atoms("H")]
- with self.assertWarns(UserWarning) as cm:
- data = create_input(frames=frames_single_atoms, properties=properties)
+ properties = {"name": ["2", "3"]}
+ frames_single_atoms = [ase.Atoms("C"), ase.Atoms("H")]
+ with self.assertWarns(UserWarning) as cm:
+ data = create_input(frames=frames_single_atoms, properties=properties)
- self.assertEqual(data["properties"]["name"]["target"], "structure")
+ self.assertEqual(data["properties"]["name"]["target"], "structure")
- self.assertEqual(
- cm.warning.args[0],
- "The target of the property 'name' is ambiguous because there is the same "
- "number of atoms and structures. We will assume target=structure",
- )
+ self.assertEqual(
+ cm.warning.args[0],
+ "The target of the property 'name' is ambiguous because there "
+ "is the same number of atoms and structures. We will assume "
+ "target=structure",
+ )
def test_invalid_name(self):
- properties = {"": {"target": "atom", "values": [2, 3, 4]}}
- with self.assertRaises(Exception) as cm:
- create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(
- str(cm.exception), "the name of a property can not be the empty string"
- )
+ for TF in (TEST_FRAMES, TEST_FRAMES_STK):
+ properties = {"": {"target": "atom", "values": [2, 3, 4]}}
+ with self.assertRaises(Exception) as cm:
+ create_input(frames=TF, properties=properties)
+ self.assertEqual(
+ str(cm.exception),
+ "the name of a property can not be the empty string",
+ )
- properties = {False: {"target": "atom", "values": [2, 3, 4]}}
- with self.assertRaises(Exception) as cm:
- create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(
- str(cm.exception),
- "the name of a property name must be a string, "
- "got 'False' of type ",
- )
+ properties = {False: {"target": "atom", "values": [2, 3, 4]}}
+ with self.assertRaises(Exception) as cm:
+ create_input(frames=TF, properties=properties)
+ self.assertEqual(
+ str(cm.exception),
+ "the name of a property name must be a string, "
+ "got 'False' of type ",
+ )
def test_invalid_target(self):
- properties = {"name": {"values": [2, 3, 4]}}
- with self.assertRaises(Exception) as cm:
- create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(str(cm.exception), "missing 'target' for the 'name' property")
+ for TF in (TEST_FRAMES, TEST_FRAMES_STK):
+ properties = {"name": {"values": [2, 3, 4]}}
+ with self.assertRaises(Exception) as cm:
+ create_input(frames=TF, properties=properties)
+ self.assertEqual(
+ str(cm.exception), "missing 'target' for the 'name' property"
+ )
- properties = {"name": {"target": "atoms", "values": [2, 3, 4]}}
- with self.assertRaises(Exception) as cm:
- create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(
- str(cm.exception),
- "the target must be 'atom' or 'structure' for the 'name' property",
- )
+ properties = {"name": {"target": "atoms", "values": [2, 3, 4]}}
+ with self.assertRaises(Exception) as cm:
+ create_input(frames=TF, properties=properties)
+ self.assertEqual(
+ str(cm.exception),
+ "the target must be 'atom' or 'structure' for the 'name' property",
+ )
def test_invalid_types_metadata(self):
- properties = {"name": {"target": "atom", "values": [2, 3, 4], "units": False}}
- data = create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(data["properties"]["name"]["units"], "False")
+ for TF in (TEST_FRAMES, TEST_FRAMES_STK):
+ properties = {
+ "name": {"target": "atom", "values": [2, 3, 4], "units": False}
+ }
+ data = create_input(frames=TF, properties=properties)
+ self.assertEqual(data["properties"]["name"]["units"], "False")
- properties = {
- "name": {"target": "atom", "values": [2, 3, 4], "description": False}
- }
- data = create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(data["properties"]["name"]["description"], "False")
+ properties = {
+ "name": {"target": "atom", "values": [2, 3, 4], "description": False}
+ }
+ data = create_input(frames=TF, properties=properties)
+ self.assertEqual(data["properties"]["name"]["description"], "False")
def test_property_unknown_keys_warning(self):
- properties = {"name": {"target": "atom", "values": [2, 3, 4], "what": False}}
- with self.assertWarns(UserWarning) as cm:
- create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(str(cm.warning), "ignoring unexpected property key: what")
+ for TF in (TEST_FRAMES, TEST_FRAMES_STK):
+ properties = {
+ "name": {"target": "atom", "values": [2, 3, 4], "what": False}
+ }
+ with self.assertWarns(UserWarning) as cm:
+ create_input(frames=TF, properties=properties)
+ self.assertEqual(str(cm.warning), "ignoring unexpected property key: what")
def test_invalid_values_types(self):
- properties = {"name": {"target": "atom", "values": 3}}
- with self.assertRaises(Exception) as cm:
- create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(
- str(cm.exception), "unknown type () for property 'name'"
- )
+ for TF in (TEST_FRAMES, TEST_FRAMES_STK):
+ properties = {"name": {"target": "atom", "values": 3}}
+ with self.assertRaises(Exception) as cm:
+ create_input(frames=TF, properties=properties)
+ self.assertEqual(
+ str(cm.exception), "unknown type () for property 'name'"
+ )
- properties = {"name": {"target": "atom", "values": {"test": "bad"}}}
- with self.assertRaises(Exception) as cm:
- create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(
- str(cm.exception), "unknown type () for property 'name'"
- )
+ properties = {"name": {"target": "atom", "values": {"test": "bad"}}}
+ with self.assertRaises(Exception) as cm:
+ create_input(frames=TF, properties=properties)
+ self.assertEqual(
+ str(cm.exception), "unknown type () for property 'name'"
+ )
- properties = {"name": {"target": "atom", "values": [{}, {}, {}]}}
- with self.assertRaises(Exception) as cm:
- create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(
- str(cm.exception),
- "unsupported type in property 'name' values: should be string or number",
- )
+ properties = {"name": {"target": "atom", "values": [{}, {}, {}]}}
+ with self.assertRaises(Exception) as cm:
+ create_input(frames=TF, properties=properties)
+ self.assertEqual(
+ str(cm.exception),
+ "unsupported type in property 'name' values: should be string"
+ " or number",
+ )
def test_wrong_number_of_values(self):
- properties = {"name": {"target": "atom", "values": [2, 3]}}
- environments = [(0, 0, 3), (0, 1, 3), (0, 2, 3)]
- with self.assertRaises(Exception) as cm:
- create_input(
- frames=TEST_FRAMES, properties=properties, environments=environments
+ for TF in (TEST_FRAMES, TEST_FRAMES_STK):
+ properties = {"name": {"target": "atom", "values": [2, 3]}}
+ environments = [(0, 0, 3), (0, 1, 3), (0, 2, 3)]
+ with self.assertRaises(Exception) as cm:
+ create_input(
+ frames=TF, properties=properties, environments=environments
+ )
+ self.assertEqual(
+ str(cm.exception),
+ "wrong size for the property 'name' with target=='atom': "
+ "expected 3 values, got 2",
)
- self.assertEqual(
- str(cm.exception),
- "wrong size for the property 'name' with target=='atom': "
- "expected 3 values, got 2",
- )
- properties = {"name": {"target": "structure", "values": [2, 3, 5]}}
- with self.assertRaises(Exception) as cm:
- create_input(frames=TEST_FRAMES, properties=properties)
- self.assertEqual(
- str(cm.exception),
- "wrong size for the property 'name' with target=='structure': "
- "expected 1 values, got 3",
- )
+ properties = {"name": {"target": "structure", "values": [2, 3, 5]}}
+ with self.assertRaises(Exception) as cm:
+ create_input(frames=TF, properties=properties)
+ self.assertEqual(
+ str(cm.exception),
+ "wrong size for the property 'name' with target=='structure': "
+ "expected 1 values, got 3",
+ )
def test_property_only(self):
properties = {"name": [2, 3, 4]}
@@ -358,25 +389,27 @@ def test_property_only(self):
class TestCreateInputEnvironments(unittest.TestCase):
def test_manual_environments_list(self):
- environments = [
- (0, 0, 3.5),
- (1, 1, 2.5),
- (1, 2, 3),
- ]
- data = create_input(frames=TEST_FRAMES + TEST_FRAMES, environments=environments)
- self.assertEqual(len(data["environments"]), 3)
-
- for i, env in enumerate(data["environments"]):
- self.assertEqual(env["structure"], environments[i][0])
- self.assertEqual(env["center"], environments[i][1])
- self.assertEqual(env["cutoff"], environments[i][2])
+ for TF in (TEST_FRAMES, TEST_FRAMES_STK):
+ environments = [
+ (0, 0, 3.5),
+ (1, 1, 2.5),
+ (1, 2, 3),
+ ]
+ data = create_input(frames=TF + TF, environments=environments)
+ self.assertEqual(len(data["environments"]), 3)
+
+ for i, env in enumerate(data["environments"]):
+ self.assertEqual(env["structure"], environments[i][0])
+ self.assertEqual(env["center"], environments[i][1])
+ self.assertEqual(env["cutoff"], environments[i][2])
def test_all_environments(self):
- environments = all_atomic_environments(TEST_FRAMES, cutoff=6)
- for i, (structure, center, cutoff) in enumerate(environments):
- self.assertEqual(structure, 0)
- self.assertEqual(center, i)
- self.assertEqual(cutoff, 6)
+ for TF in (TEST_FRAMES, TEST_FRAMES_STK):
+ environments = all_atomic_environments(TF, cutoff=6)
+ for i, (structure, center, cutoff) in enumerate(environments):
+ self.assertEqual(structure, 0)
+ self.assertEqual(center, i)
+ self.assertEqual(cutoff, 6)
if __name__ == "__main__":
diff --git a/python/tests/stk_structures.py b/python/tests/stk_structures.py
new file mode 100644
index 000000000..844aeb711
--- /dev/null
+++ b/python/tests/stk_structures.py
@@ -0,0 +1,143 @@
+import unittest
+
+import stk
+
+import chemiscope
+
+BASE_FRAME = stk.BuildingBlock("N#CC")
+
+
+class TestStructures(unittest.TestCase):
+ """Conversion of structure data to chemiscope JSON"""
+
+ def test_structures(self):
+ data = chemiscope.create_input(BASE_FRAME)
+ self.assertEqual(len(data["structures"]), 1)
+ self.assertEqual(data["structures"][0]["size"], 6)
+ self.assertEqual(
+ data["structures"][0]["names"],
+ ["N", "C", "C", "H", "H", "H"],
+ )
+ self.assertEqual(
+ data["structures"][0]["x"],
+ [
+ 1.6991195138834223,
+ 0.7737143493209756,
+ -0.41192204250544034,
+ -0.7778845126633998,
+ -1.1777543806588109,
+ -0.10527292738297804,
+ ],
+ )
+ self.assertEqual(
+ data["structures"][0]["y"],
+ [
+ -1.2265369887154756,
+ -0.5721898035707434,
+ 0.28832060028277334,
+ 0.6076276888433211,
+ -0.27163665176706653,
+ 1.1744151549238042,
+ ],
+ )
+ self.assertEqual(
+ data["structures"][0]["z"],
+ [
+ -0.19321573000005213,
+ -0.10192268845612924,
+ 0.03435599430880268,
+ -0.9630155400427929,
+ 0.6165952621860082,
+ 0.6072027020039786,
+ ],
+ )
+ self.assertEqual(data["structures"][0].get("cell"), None)
+
+ # Not testing cell because stk implementation does not have that yet.
+
+ frame = BASE_FRAME.clone()
+ data = chemiscope.create_input([frame])
+ self.assertEqual(len(data["structures"]), 1)
+ self.assertEqual(data["structures"][0]["size"], 6)
+ self.assertEqual(
+ data["structures"][0]["names"],
+ ["N", "C", "C", "H", "H", "H"],
+ )
+ self.assertEqual(
+ data["structures"][0]["x"],
+ [
+ 1.6991195138834223,
+ 0.7737143493209756,
+ -0.41192204250544034,
+ -0.7778845126633998,
+ -1.1777543806588109,
+ -0.10527292738297804,
+ ],
+ )
+ self.assertEqual(
+ data["structures"][0]["y"],
+ [
+ -1.2265369887154756,
+ -0.5721898035707434,
+ 0.28832060028277334,
+ 0.6076276888433211,
+ -0.27163665176706653,
+ 1.1744151549238042,
+ ],
+ )
+ self.assertEqual(
+ data["structures"][0]["z"],
+ [
+ -0.19321573000005213,
+ -0.10192268845612924,
+ 0.03435599430880268,
+ -0.9630155400427929,
+ 0.6165952621860082,
+ 0.6072027020039786,
+ ],
+ )
+ self.assertEqual(data["structures"][0].get("cell"), None)
+
+
+class TestExtractProperties(unittest.TestCase):
+ """Properties extraction"""
+
+ def test_exception(self):
+ with self.assertRaises(RuntimeError):
+ chemiscope.extract_properties(BASE_FRAME)
+
+
+class TestCompositionProperties(unittest.TestCase):
+ """Composition properties"""
+
+ def test_composition(self):
+ properties = chemiscope.composition_properties([BASE_FRAME, BASE_FRAME])
+
+ self.assertEqual(len(properties.keys()), 6)
+
+ self.assertEqual(properties["composition"]["target"], "structure")
+ self.assertEqual(properties["composition"]["values"], ["C2H3N1", "C2H3N1"])
+
+ self.assertEqual(properties["n_C"]["target"], "structure")
+ self.assertEqual(properties["n_C"]["values"], [2, 2])
+
+ self.assertEqual(properties["n_N"]["target"], "structure")
+ self.assertEqual(properties["n_N"]["values"], [1, 1])
+
+ self.assertEqual(properties["n_H"]["target"], "structure")
+ self.assertEqual(properties["n_H"]["values"], [3, 3])
+
+ self.assertEqual(properties["symbol"]["target"], "atom")
+ self.assertEqual(
+ properties["symbol"]["values"],
+ ["N", "C", "C", "H", "H", "H", "N", "C", "C", "H", "H", "H"],
+ )
+
+ self.assertEqual(properties["number"]["target"], "atom")
+ self.assertEqual(
+ properties["number"]["values"], [7, 6, 6, 1, 1, 1, 7, 6, 6, 1, 1, 1]
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/python/tests/write_input.py b/python/tests/write_input.py
index 193d95e96..55cf17543 100644
--- a/python/tests/write_input.py
+++ b/python/tests/write_input.py
@@ -4,10 +4,12 @@
import unittest
import ase
+import stk
from chemiscope import write_input
TEST_FRAMES = [ase.Atoms("CO2")]
+TEST_FRAMES_STK = [stk.BuildingBlock("NCCN")]
def is_gz_file(filepath):
@@ -18,34 +20,37 @@ def is_gz_file(filepath):
class TestWriteInput(unittest.TestCase):
def test_file_path_as_dataset_name(self):
- with tempfile.TemporaryDirectory() as dirname:
- path = os.path.join(dirname, "test.json")
- write_input(path, TEST_FRAMES)
+ for TF in (TEST_FRAMES, TEST_FRAMES_STK):
+ with tempfile.TemporaryDirectory() as dirname:
+ path = os.path.join(dirname, "test.json")
+ write_input(path, TF)
- with open(path) as fd:
- data = json.load(fd)
+ with open(path) as fd:
+ data = json.load(fd)
- self.assertEqual(data["meta"]["name"], "test")
+ self.assertEqual(data["meta"]["name"], "test")
def test_create_gz_file(self):
- with tempfile.TemporaryDirectory() as dirname:
- path = os.path.join(dirname, "test.json")
- write_input(path, TEST_FRAMES)
- self.assertFalse(is_gz_file(path))
+ for TF in (TEST_FRAMES, TEST_FRAMES_STK):
+ with tempfile.TemporaryDirectory() as dirname:
+ path = os.path.join(dirname, "test.json")
+ write_input(path, TF)
+ self.assertFalse(is_gz_file(path))
- path = os.path.join(dirname, "test.json.gz")
- write_input(path, TEST_FRAMES)
- self.assertTrue(is_gz_file(path))
+ path = os.path.join(dirname, "test.json.gz")
+ write_input(path, TF)
+ self.assertTrue(is_gz_file(path))
def test_wrong_path(self):
- with tempfile.TemporaryDirectory() as dirname:
- path = os.path.join(dirname, "test.tmp")
- with self.assertRaises(Exception) as cm:
- write_input(path, TEST_FRAMES)
-
- self.assertEqual(
- str(cm.exception), "path should end with .json or .json.gz"
- )
+ for TF in (TEST_FRAMES, TEST_FRAMES_STK):
+ with tempfile.TemporaryDirectory() as dirname:
+ path = os.path.join(dirname, "test.tmp")
+ with self.assertRaises(Exception) as cm:
+ write_input(path, TF)
+
+ self.assertEqual(
+ str(cm.exception), "path should end with .json or .json.gz"
+ )
if __name__ == "__main__":
diff --git a/tox.ini b/tox.ini
index 86749e9e1..a3099cbfb 100644
--- a/tox.ini
+++ b/tox.ini
@@ -27,6 +27,8 @@ commands =
description = Run Python unit tests
deps =
ase==3.22.1
+ rdkit==2024.3.4
+ stk
commands =
pip install chemiscope[explore]