Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix properties not loading #263

Merged
merged 3 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 23 additions & 28 deletions zndraw/analyse/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
import itertools
import logging
import typing as t
from typing import Any

import ase
import numpy as np
import pandas as pd
import plotly.express as px
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field

from zndraw.utils import set_global_atoms
from zndraw.utils import SHARED, set_global_atoms

log = logging.getLogger(__name__)


def _schema_from_atoms(schema, cls):
return cls.model_json_schema_from_atoms(schema)


class Distance(BaseModel):
method: t.Literal["Distance"] = "Distance"
discriminator: t.Literal["Distance"] = Field("Distance")

smooth: bool = False

Expand Down Expand Up @@ -51,20 +54,21 @@ def run(self, atoms_lst, ids):


class Properties2D(BaseModel):
method: t.Literal["Properties2D"] = "Properties2D"

discriminator: t.Literal["Properties2D"] = Field("Properties2D")
x_data: str = "step"
y_data: str = "energy"
color: str = "energy"
fix_aspect_ratio: bool = True

model_config = ConfigDict(json_schema_extra=_schema_from_atoms)

@classmethod
def model_json_schema(cls, *args, **kwargs) -> dict[str, Any]:
schema = super().model_json_schema(*args, **kwargs)
log.debug(f"GATHERING PROPERTIES FROM {ATOMS=}") # noqa: F821
def model_json_schema_from_atoms(cls, schema: dict) -> dict:
ATOMS = SHARED["atoms"]
log.debug(f"GATHERING PROPERTIES FROM {ATOMS=}")
try:
available_properties = list(ATOMS.calc.results) # noqa: F821
available_properties += list(ATOMS.arrays) # noqa: F821
available_properties = list(ATOMS.calc.results)
available_properties += list(ATOMS.arrays)
available_properties += ["step"]
schema["properties"]["x_data"]["enum"] = available_properties
schema["properties"]["y_data"]["enum"] = available_properties
Expand Down Expand Up @@ -117,22 +121,24 @@ def run(self, atoms_lst, ids):


class Properties1D(BaseModel):
method: t.Literal["Properties1D"] = "Properties1D"
discriminator: t.Literal["Properties1D"] = Field("Properties1D")

value: str = "energy"
smooth: bool = False

model_config = ConfigDict(json_schema_extra=_schema_from_atoms)

@classmethod
def model_json_schema(cls, *args, **kwargs) -> dict[str, Any]:
schema = super().model_json_schema(*args, **kwargs)
def model_json_schema_from_atoms(cls, schema: dict) -> dict:
ATOMS = SHARED["atoms"]
try:
available_properties = list(
ATOMS.calc.results.keys() # noqa: F821
ATOMS.calc.results.keys()
) # global ATOMS object
log.debug(f"AVAILABLE PROPERTIES: {available_properties=}")
schema["properties"]["value"]["enum"] = available_properties
except AttributeError:
pass
print(f"{ATOMS=}")
return schema

def run(self, atoms_lst, ids):
Expand All @@ -156,7 +162,7 @@ def run(self, atoms_lst, ids):
def get_analysis_class(methods):
class Analysis(BaseModel):
method: methods = Field(
..., description="Analysis method", discriminator="method"
..., description="Analysis method", discriminator="discriminator"
)

def run(self, *args, **kwargs) -> list[ase.Atoms]:
Expand All @@ -170,15 +176,4 @@ def model_json_schema_from_atoms(
result = cls.model_json_schema(*args, **kwargs)
return result

@classmethod
def model_json_schema(cls, *args, **kwargs) -> dict[str, t.Any]:
schema = super().model_json_schema(*args, **kwargs)
for prop in [x.__name__ for x in t.get_args(methods)]:
schema["$defs"][prop]["properties"]["method"]["options"] = {
"hidden": True
}
schema["$defs"][prop]["properties"]["method"]["type"] = "string"

return schema

return Analysis
10 changes: 6 additions & 4 deletions zndraw/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import ase
import datamodel_code_generator

SHARED = {"atoms": None}


def get_port(default: int = 1234) -> int:
"""Get an open port."""
Expand All @@ -29,11 +31,11 @@ def get_port(default: int = 1234) -> int:

@contextlib.contextmanager
def set_global_atoms(atoms: ase.Atoms):
"""Temporarily create a global 'ATOMS' variable."""
global ATOMS
ATOMS = atoms
"""Temporarily create a 'SHARED["atoms"]' variable."""
# TODO: I know this is bad, but I don't know how to do it better - send help!
SHARED["atoms"] = atoms
yield
del ATOMS
SHARED["atoms"] = None


class ZnDrawLoggingHandler(logging.Handler):
Expand Down
5 changes: 4 additions & 1 deletion zndraw/zndraw.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,13 @@ def analysis_schema(self):
except IndexError:
atoms = ase.Atoms()

schema = cls.model_json_schema_from_atoms(atoms)
hide_discriminator_field(schema)

self.socket.emit(
"analysis:schema",
{
"schema": cls.model_json_schema_from_atoms(atoms),
"schema": schema,
"sid": self._target_sid,
},
)
Expand Down