Skip to content

Commit

Permalink
[Breaking] Split utils into sub-modules, move typing from utils
Browse files Browse the repository at this point in the history
… to root (`pmv.typing`) (#248)

* migrate some
* more move

* clean up docstring

* ptable_hists_plotly | ptable_heatmap_splits_plotly support multiple annotations per element tile

* fix test_init.py not wrapping asserts in test function

* merge utils/misc.py into utils/__init__.py and utils/image.py into utils/plotting.py

* move ExperimentalWarning into __init__

* mv utils/typing.py to typing.py
---------

Co-authored-by: Janosh Riebesell <[email protected]>
  • Loading branch information
DanielYang59 and janosh authored Nov 20, 2024
1 parent 0d24ace commit bc2669b
Show file tree
Hide file tree
Showing 32 changed files with 1,056 additions and 887 deletions.
2 changes: 1 addition & 1 deletion assets/scripts/coordination/coordination_hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import pymatviz as pmv
from pymatviz.coordination import CnSplitMode
from pymatviz.utils import TEST_FILES
from pymatviz.utils.testing import TEST_FILES


pmv.set_plotly_template("pymatviz_white")
Expand Down
2 changes: 1 addition & 1 deletion assets/scripts/coordination/coordination_vs_cutoff_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pymatviz as pmv
from pymatviz.enums import Key
from pymatviz.utils import TEST_FILES
from pymatviz.utils.testing import TEST_FILES


pmv.set_plotly_template("pymatviz_white")
Expand Down
2 changes: 1 addition & 1 deletion assets/scripts/histogram/spacegroup_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


# %% Spacegroup histograms
for backend in pmv.BACKENDS:
for backend in pmv.typing.BACKENDS:
fig = pmv.spacegroup_bar(df_phonons[Key.spg_num], backend=backend)
pmv.io.save_and_compress_svg(fig, f"spg-num-hist-{backend}")

Expand Down
3 changes: 2 additions & 1 deletion assets/scripts/phonons/phonon_bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import pymatviz as pmv
from pymatviz.enums import Key
from pymatviz.utils.testing import TEST_FILES


# TODO: ffonons not working properly (see #195)
Expand All @@ -23,7 +24,7 @@
("mp-23907", "H2"),
):
docs = {}
for path in glob(f"{pmv.utils.TEST_FILES}/phonons/{mp_id}-{formula}-*.json.lzma"):
for path in glob(f"{TEST_FILES}/phonons/{mp_id}-{formula}-*.json.lzma"):
model_label = (
"CHGNet"
if "chgnet" in path
Expand Down
3 changes: 2 additions & 1 deletion assets/scripts/phonons/phonon_bands_and_dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import pymatviz as pmv
from pymatviz.enums import Key
from pymatviz.utils.testing import TEST_FILES


# TODO: ffonons not working properly (see #195)
Expand All @@ -24,7 +25,7 @@
("mp-23907", "H2"),
):
docs = {}
for path in glob(f"{pmv.utils.TEST_FILES}/phonons/{mp_id}-{formula}-*.json.lzma"):
for path in glob(f"{TEST_FILES}/phonons/{mp_id}-{formula}-*.json.lzma"):
key = path.split("-")[-1].split(".")[0]
with zopen(path) as file:
docs[key] = json.loads(file.read(), cls=MontyDecoder)
Expand Down
3 changes: 2 additions & 1 deletion assets/scripts/phonons/phonon_dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import pymatviz as pmv
from pymatviz.enums import Key
from pymatviz.utils.testing import TEST_FILES


# TODO: ffonons not working properly (see #195)
Expand All @@ -23,7 +24,7 @@
("mp-23907", "H2"),
):
docs = {}
for path in glob(f"{pmv.utils.TEST_FILES}/phonons/{mp_id}-{formula}-*.json.lzma"):
for path in glob(f"{TEST_FILES}/phonons/{mp_id}-{formula}-*.json.lzma"):
key = path.split("-")[-1].split(".")[0]
with zopen(path) as file:
docs[key] = json.loads(file.read(), cls=MontyDecoder)
Expand Down
2 changes: 1 addition & 1 deletion assets/scripts/structure_viz/structure_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import pymatviz as pmv
from pymatviz.enums import ElemColorScheme, Key
from pymatviz.utils import TEST_FILES
from pymatviz.utils.testing import TEST_FILES


df_phonons = load_dataset("matbench_phonons")
Expand Down
2 changes: 1 addition & 1 deletion assets/scripts/xrd/xrd_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pymatgen.core import Structure

import pymatviz as pmv
from pymatviz.utils import TEST_FILES
from pymatviz.utils.testing import TEST_FILES


pmv.set_plotly_template("pymatviz_white")
Expand Down
11 changes: 2 additions & 9 deletions pymatviz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
structure_viz,
sunburst,
templates,
typing,
uncertainty,
utils,
xrd,
Expand Down Expand Up @@ -82,15 +83,7 @@
set_plotly_template,
)
from pymatviz.uncertainty import error_decay_with_uncert, qq_gaussian
from pymatviz.utils import (
BACKENDS,
PKG_DIR,
ROOT,
df_ptable,
html_tag,
si_fmt,
si_fmt_int,
)
from pymatviz.utils import PKG_DIR, ROOT, df_ptable, html_tag, si_fmt, si_fmt_int
from pymatviz.xrd import xrd_pattern


Expand Down
5 changes: 4 additions & 1 deletion pymatviz/bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
from pymatgen.symmetry.groups import SpaceGroup

from pymatviz.enums import Key
from pymatviz.utils import PLOTLY, Backend, crystal_sys_from_spg_num, si_fmt_int
from pymatviz.typing import PLOTLY
from pymatviz.utils import crystal_sys_from_spg_num, si_fmt_int


if TYPE_CHECKING:
from typing import Any, Literal

from pymatviz.typing import Backend


def spacegroup_bar(
data: Sequence[int | str | Structure] | pd.Series,
Expand Down
4 changes: 2 additions & 2 deletions pymatviz/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
from pymatviz.bar import spacegroup_bar
from pymatviz.enums import ElemCountMode
from pymatviz.process_data import count_elements
from pymatviz.utils import BACKENDS, MATPLOTLIB, PLOTLY, Backend
from pymatviz.typing import BACKENDS, MATPLOTLIB, PLOTLY, Backend


if TYPE_CHECKING:
from collections.abc import Sequence
from typing import Any, Literal

from pymatviz.utils import ElemValues
from pymatviz.typing import ElemValues


def spacegroup_hist(*args: Any, **kwargs: Any) -> plt.Axes | go.Figure:
Expand Down
4 changes: 3 additions & 1 deletion pymatviz/powerups/both.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
from sklearn.metrics import mean_absolute_percentage_error as mape
from sklearn.metrics import r2_score

from pymatviz.utils import (
from pymatviz.typing import (
BACKENDS,
MATPLOTLIB,
PLOTLY,
VALID_FIG_NAMES,
VALID_FIG_TYPES,
AxOrFig,
Backend,
)
from pymatviz.utils import (
annotate,
get_fig_xy_range,
get_font_color,
Expand Down
4 changes: 3 additions & 1 deletion pymatviz/process_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
from pymatgen.core import Composition

from pymatviz.enums import ElemCountMode, Key
from pymatviz.utils import ElemValues, df_ptable
from pymatviz.utils import df_ptable


if TYPE_CHECKING:
from collections.abc import Sequence

from pymatviz.typing import ElemValues


def count_elements(
values: ElemValues,
Expand Down
8 changes: 2 additions & 6 deletions pymatviz/ptable/ptable_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,7 @@
OverwriteTileValueColor,
PTableProjector,
)
from pymatviz.utils import (
ColorElemTypeStrategy,
ElemValues,
get_cbar_label_formatter,
pick_bw_for_contrast,
)
from pymatviz.utils import get_cbar_label_formatter, pick_bw_for_contrast


if TYPE_CHECKING:
Expand All @@ -34,6 +29,7 @@
from pymatgen.core import Element

from pymatviz.ptable._process_data import PTableData
from pymatviz.typing import ColorElemTypeStrategy, ElemValues

# Custom types
ElemStr: TypeAlias = str # element as a str
Expand Down
61 changes: 34 additions & 27 deletions pymatviz/ptable/ptable_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
from pymatviz.colors import ELEM_TYPE_COLORS
from pymatviz.enums import ElemCountMode
from pymatviz.process_data import count_elements
from pymatviz.utils import (
from pymatviz.typing import (
VALID_COLOR_ELEM_STRATEGIES,
ColorElemTypeStrategy,
ElemValues,
df_ptable,
)
from pymatviz.utils import df_ptable


if TYPE_CHECKING:
Expand Down Expand Up @@ -400,7 +400,7 @@ def ptable_hists_plotly(
symbol_kwargs: dict[str, Any] | None = None,
# Annotation
annotations: dict[str, str | dict[str, Any]]
| Callable[[np.ndarray], str | dict[str, Any]]
| Callable[[Sequence[float]], str | dict[str, Any] | list[dict[str, Any]]]
| None = None,
# Element type colors
color_elem_strategy: ColorElemTypeStrategy = "background",
Expand Down Expand Up @@ -575,21 +575,25 @@ def ptable_hists_plotly(
# Pass the element's values to the callable
annotation = annotations(values)
else:
# Use dictionary lookup as before
# Use dictionary lookup
annotation = annotations.get(symbol, "")

if annotation: # Only add annotation if we have text
annotation = (
{"text": annotation} if isinstance(annotation, str) else annotation
)
anno_defaults = {
"font_size": (font_size or 8) * scale,
"font_color": font_color,
"x": 1,
"y": 0.97,
"showarrow": False,
}
fig.add_annotation(anno_defaults | xy_ref | annotation)
# Convert single annotation to list for uniform handling
for anno in (
[annotation] if isinstance(annotation, str | dict) else annotation
):
# Convert string annotations to dict format
anno_dict = {"text": anno} if isinstance(anno, str) else anno
anno_defaults = {
"font_size": (font_size or 8) * scale,
"x": 0.95,
"y": 0.95,
"showarrow": False,
"xanchor": "right",
"yanchor": "top",
}
fig.add_annotation(**anno_defaults | xy_ref | anno_dict)

if colorbar is not False:
colorbar = dict(orientation="h", lenmode="fraction", thickness=15) | (
Expand Down Expand Up @@ -925,18 +929,21 @@ def create_section_coords(
annotation = annotations.get(symbol, "")

if annotation: # Only add annotation if we have text
annotation = (
{"text": annotation} if isinstance(annotation, str) else annotation
)
anno_defaults = {
"font_size": (font_size or 8) * scale,
"x": 0.95,
"y": 0.95,
"showarrow": False,
"xanchor": "right",
"yanchor": "top",
}
fig.add_annotation(**anno_defaults | xy_ref | annotation)
# Convert single annotation to list for uniform handling
for anno in (
[annotation] if isinstance(annotation, str | dict) else annotation
):
# Convert string annotations to dict format
anno_dict = {"text": anno} if isinstance(anno, str) else anno
anno_defaults = {
"font_size": (font_size or 8) * scale,
"x": 0.95,
"y": 0.95,
"showarrow": False,
"xanchor": "right",
"yanchor": "top",
}
fig.add_annotation(**anno_defaults | xy_ref | anno_dict)

# Update layout
fig.layout.showlegend = False
Expand Down
44 changes: 44 additions & 0 deletions pymatviz/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Typing related: TypeAlias, generic types and so on."""

from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Literal, ParamSpec, TypeVar, get_args

import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go


if TYPE_CHECKING:
from typing import TypeAlias

AxOrFig: TypeAlias = plt.Axes | plt.Figure | go.Figure

Backend: TypeAlias = Literal["matplotlib", "plotly"]
BACKENDS = MATPLOTLIB, PLOTLY = get_args(Backend)

ColorElemTypeStrategy: TypeAlias = Literal["symbol", "background", "both", "off"]
VALID_COLOR_ELEM_STRATEGIES = get_args(ColorElemTypeStrategy)

CrystalSystem: TypeAlias = Literal[
"triclinic",
"monoclinic",
"orthorhombic",
"tetragonal",
"trigonal",
"hexagonal",
"cubic",
]

ElemValues: TypeAlias = dict[str | int, float] | pd.Series | Sequence[str]

T = TypeVar("T") # generic type for input validation
P = ParamSpec("P") # generic type for function parameters
R = TypeVar("R") # generic type for return value


VALID_FIG_TYPES = get_args(AxOrFig)
VALID_FIG_NAMES: str = " | ".join(
f"{t.__module__}.{t.__qualname__}" for t in VALID_FIG_TYPES
)
Loading

0 comments on commit bc2669b

Please sign in to comment.