Skip to content

Commit

Permalink
compute radii and colors on the fly (#768)
Browse files Browse the repository at this point in the history
* compute radii and colors on the fly

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix non-existing `colors` and `radii` keys

* fix connect

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove comments

* fix colors

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] authored Dec 13, 2024
1 parent e62e9b3 commit 317a485
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 37 deletions.
15 changes: 15 additions & 0 deletions app/src/components/api.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import * as THREE from "three";
import { JMOL_COLORS, covalentRadii } from "./data";
import * as znsocket from "znsocket";
import { client } from "../socket";

Expand Down Expand Up @@ -327,6 +328,12 @@ export const setupFrames = (
const currentFrameUpdatedFromSocketRef = useRef(true);
const customRoomAvailRef = useRef(false); // Track whether listening to the default room
const [updateStepInPlace, setUpdateStepInPlace] = useState(0);
const scaledRadii = useMemo(() => {
const minRadius = Math.min(...covalentRadii);
const maxRadius = Math.max(...covalentRadii);
const range = maxRadius - minRadius;
return covalentRadii.map((x: number) => (x - minRadius) / range + 0.3);
}, [covalentRadii]);

const setCurrentFrameFromObject = (frame: any) => {
frame = frame.value;
Expand All @@ -335,6 +342,14 @@ export const setupFrames = (
new THREE.Vector3(position[0], position[1], position[2]),
) as THREE.Vector3[];
console.log("frame updated");
if (!frame?.arrays?.colors) {
frame.arrays.colors = frame.numbers.map(
(x: number) => "#" + JMOL_COLORS[x].getHexString(),
);
}
if (!frame.arrays.radii) {
frame.arrays.radii = frame.numbers.map((x: number) => scaledRadii[x]);
}
setCurrentFrame(frame);
currentFrameUpdatedFromSocketRef.current = true;
};
Expand Down
125 changes: 125 additions & 0 deletions app/src/components/data.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import { Color } from "three";

export const JMOL_COLORS = [
new Color("#ff0000"),
new Color("#ffffff"),
new Color("#d9ffff"),
new Color("#cc80ff"),
new Color("#c2ff00"),
new Color("#ffb5b5"),
new Color("#909090"),
new Color("#2f50f8"),
new Color("#ff0d0d"),
new Color("#90df50"),
new Color("#b3e2f5"),
new Color("#ab5cf1"),
new Color("#89ff00"),
new Color("#bea6a6"),
new Color("#efc79f"),
new Color("#ff8000"),
new Color("#ffff2f"),
new Color("#1fef1f"),
new Color("#80d1e2"),
new Color("#8f40d3"),
new Color("#3cff00"),
new Color("#e6e6e6"),
new Color("#bec2c6"),
new Color("#a6a6ab"),
new Color("#8999c6"),
new Color("#9c79c6"),
new Color("#df6633"),
new Color("#ef909f"),
new Color("#50d050"),
new Color("#c78033"),
new Color("#7c80af"),
new Color("#c28f8f"),
new Color("#668f8f"),
new Color("#bc80e2"),
new Color("#ffa000"),
new Color("#a62929"),
new Color("#5cb8d1"),
new Color("#6f2daf"),
new Color("#00ff00"),
new Color("#93ffff"),
new Color("#93dfdf"),
new Color("#73c2c8"),
new Color("#53b5b5"),
new Color("#3a9e9e"),
new Color("#238f8f"),
new Color("#097c8b"),
new Color("#006985"),
new Color("#c0c0c0"),
new Color("#ffd98f"),
new Color("#a67573"),
new Color("#668080"),
new Color("#9e62b5"),
new Color("#d37900"),
new Color("#930093"),
new Color("#429eaf"),
new Color("#56168f"),
new Color("#00c800"),
new Color("#6fd3ff"),
new Color("#ffffc6"),
new Color("#d9ffc6"),
new Color("#c6ffc6"),
new Color("#a2ffc6"),
new Color("#8fffc6"),
new Color("#60ffc6"),
new Color("#45ffc6"),
new Color("#2fffc6"),
new Color("#1fffc6"),
new Color("#00ff9c"),
new Color("#00e675"),
new Color("#00d352"),
new Color("#00be38"),
new Color("#00ab23"),
new Color("#4dc2ff"),
new Color("#4da6ff"),
new Color("#2093d5"),
new Color("#257cab"),
new Color("#256695"),
new Color("#165386"),
new Color("#d0d0df"),
new Color("#ffd122"),
new Color("#b8b8d0"),
new Color("#a6534d"),
new Color("#565860"),
new Color("#9e4fb5"),
new Color("#ab5c00"),
new Color("#754f45"),
new Color("#428295"),
new Color("#420066"),
new Color("#007c00"),
new Color("#6fabf9"),
new Color("#00b9ff"),
new Color("#00a0ff"),
new Color("#008fff"),
new Color("#0080ff"),
new Color("#006bff"),
new Color("#535cf1"),
new Color("#785ce2"),
new Color("#894fe2"),
new Color("#a036d3"),
new Color("#b31fd3"),
new Color("#b31fb9"),
new Color("#b30da6"),
new Color("#bc0d86"),
new Color("#c60066"),
new Color("#cc0058"),
new Color("#d1004f"),
new Color("#d90045"),
new Color("#df0038"),
new Color("#e6002d"),
new Color("#eb0025"),
];

export const covalentRadii = [
1, 0.31, 0.28, 1.28, 0.96, 0.84, 0.76, 0.71, 0.66, 0.57, 0.58, 1.66, 1.41,
1.21, 1.11, 1.07, 1.05, 1.02, 1.06, 2.03, 1.76, 1.7, 1.6, 1.53, 1.39, 1.39,
1.32, 1.26, 1.24, 1.32, 1.22, 1.22, 1.2, 1.19, 1.2, 1.2, 1.16, 2.2, 1.95, 1.9,
1.75, 1.64, 1.54, 1.47, 1.46, 1.42, 1.39, 1.45, 1.44, 1.42, 1.39, 1.39, 1.38,
1.39, 1.4, 2.44, 2.15, 2.07, 2.04, 2.03, 2.01, 1.99, 1.98, 1.98, 1.96, 1.94,
1.92, 1.92, 1.89, 1.9, 1.87, 1.87, 1.75, 1.7, 1.62, 1.51, 1.44, 1.41, 1.36,
1.36, 1.32, 1.45, 1.46, 1.48, 1.4, 1.5, 1.5, 2.6, 2.21, 2.15, 2.06, 2.0, 1.96,
1.9, 1.87, 1.8, 1.69,
];
15 changes: 7 additions & 8 deletions tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,17 @@ def test_ase_converter(s22):

assert structures[3].calc.results == {"energy": 0.0, "predicted_energy": 1.0}

assert "colors" in structures[0].arrays
assert "radii" in structures[0].arrays
assert "colors" not in structures[0].arrays
assert "radii" not in structures[0].arrays

assert structures[4].info == {"key": "value"}


def test_exotic_atoms():
atoms = ase.Atoms("X", positions=[[0, 0, 0]])
atoms.arrays["colors"] = ["#ff0000"]
atoms.arrays["radii"] = [0.3]

new_atoms = znjson.loads(
znjson.dumps(atoms, cls=znjson.ZnEncoder.from_converters([ASEConverter])),
cls=znjson.ZnDecoder.from_converters([ASEConverter]),
Expand All @@ -60,8 +63,7 @@ def test_modified_atoms():
znjson.dumps(atoms, cls=znjson.ZnEncoder.from_converters([ASEConverter])),
cls=znjson.ZnDecoder.from_converters([ASEConverter]),
)
npt.assert_array_equal(new_atoms.arrays["colors"], ["#ffffff", "#ffffff"])
npt.assert_almost_equal(new_atoms.arrays["radii"], [0.3458333, 0.3458333])
npt.assert_array_equal(new_atoms.get_atomic_numbers(), [1, 1])

# subtract
atoms = new_atoms[:1]
Expand All @@ -70,8 +72,7 @@ def test_modified_atoms():
cls=znjson.ZnDecoder.from_converters([ASEConverter]),
)

npt.assert_array_equal(new_atoms.arrays["colors"], ["#ffffff"])
npt.assert_almost_equal(new_atoms.arrays["radii"], [0.3458333])
npt.assert_array_equal(new_atoms.get_atomic_numbers(), [1])

# add
atoms = new_atoms + ase.Atoms("H", positions=[[0, 0, 1]])
Expand All @@ -82,8 +83,6 @@ def test_modified_atoms():
)

npt.assert_array_equal(new_atoms.get_atomic_numbers(), [1, 1])
npt.assert_array_equal(new_atoms.arrays["colors"], ["#ffffff", "#ffffff"])
npt.assert_almost_equal(new_atoms.arrays["radii"], [0.3458333, 0.3458333])


def test_constraints_fixed_atoms():
Expand Down
22 changes: 0 additions & 22 deletions zndraw/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
import znjson
from ase.calculators.singlepoint import SinglePointCalculator
from ase.constraints import FixAtoms
from ase.data.colors import jmol_colors

from zndraw.draw import Object3D
from zndraw.type_defs import ASEDict
from zndraw.utils import get_scaled_radii, rgb2hex


class ASEConverter(znjson.ConverterBase):
Expand Down Expand Up @@ -86,28 +84,8 @@ def encode(self, obj: ase.Atoms) -> ASEDict:
# All additional information should be stored in calc.results
# and not in calc.arrays, thus we will not convert it here!
arrays = {}
if ("colors" not in obj.arrays) or ("" in obj.arrays["colors"]):
arrays["colors"] = [rgb2hex(jmol_colors[number]) for number in numbers]
else:
arrays["colors"] = (
obj.arrays["colors"].tolist()
if isinstance(obj.arrays["colors"], np.ndarray)
else obj.arrays["colors"]
)

if ("radii" not in obj.arrays) or (0 in obj.arrays["radii"]):
# arrays["radii"] = [covalent_radii[number] for number in numbers]
arrays["radii"] = [get_scaled_radii()[number] for number in numbers]
else:
arrays["radii"] = (
obj.arrays["radii"].tolist()
if isinstance(obj.arrays["radii"], np.ndarray)
else obj.arrays["radii"]
)

for key in obj.arrays:
if key in ["colors", "radii"]:
continue
if isinstance(obj.arrays[key], np.ndarray):
arrays[key] = obj.arrays[key].tolist()
else:
Expand Down
17 changes: 10 additions & 7 deletions zndraw/modify/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pydantic import Field

from zndraw.base import Extension
from zndraw.utils import get_scaled_radii

try:
from zndraw.modify import extras # noqa: F401
Expand Down Expand Up @@ -49,7 +50,9 @@ def run(self, vis: "ZnDraw", **kwargs) -> None:
camera_position = np.array(vis.camera["position"])[None, :] # 1,3

new_points = atom_positions[atom_ids] # N, 3
radii: np.ndarray = atoms.arrays["radii"][atom_ids][:, None]
radii = np.array(
[get_scaled_radii()[number] for number in atoms.numbers[atom_ids]]
)[:, None]
direction = camera_position - new_points
direction /= np.linalg.norm(direction, axis=1, keepdims=True)
new_points += direction * radii
Expand Down Expand Up @@ -167,8 +170,8 @@ def run(self, vis: "ZnDraw", **kwargs) -> None:
atom.position += np.array([self.x, self.y, self.z])
atom.symbol = self.symbol.name if self.symbol.name != "X" else atom.symbol
atoms += atom
del atoms.arrays["colors"]
del atoms.arrays["radii"]
atoms.arrays.pop("colors", None)
atoms.arrays.pop("radii", None)
if hasattr(atoms, "connectivity"):
del atoms.connectivity

Expand All @@ -187,8 +190,8 @@ def run(self, vis: "ZnDraw", **kwargs) -> None:
for atom_id in vis.selection:
atoms[atom_id].symbol = self.symbol.name

del atoms.arrays["colors"]
del atoms.arrays["radii"]
atoms.arrays.pop("colors", None)
atoms.arrays.pop("radii", None)
if hasattr(atoms, "connectivity"):
# vdW radii might change
del atoms.connectivity
Expand All @@ -208,8 +211,8 @@ def run(self, vis: "ZnDraw", **kwargs) -> None:
for point in vis.points:
atoms += ase.Atom(self.symbol.name, position=point)

del atoms.arrays["colors"]
del atoms.arrays["radii"]
atoms.arrays.pop("colors", None)
atoms.arrays.pop("radii", None)
if hasattr(atoms, "connectivity"):
del atoms.connectivity

Expand Down

0 comments on commit 317a485

Please sign in to comment.