From c2c92cfbbe29a0a03d37542a0cf269abc998f3db Mon Sep 17 00:00:00 2001 From: hayk Date: Wed, 6 Nov 2024 14:50:48 -0500 Subject: [PATCH] new data reading layout --- .gitignore | 9 +- nt2/__init__.py | 5 +- nt2/containers/__init__.py | 0 nt2/containers/container.py | 148 ++++++ nt2/containers/fields.py | 229 +++++++++ nt2/containers/particles.py | 178 +++++++ nt2/containers/spectra.py | 138 +++++ nt2/containers/utils.py | 38 ++ nt2/plotters/__init__.py | 0 nt2/{ => plotters}/plot.py | 12 +- nt2/plotters/polarplot.py | 488 ++++++++++++++++++ nt2/read.py | 987 +----------------------------------- pyrightconfig.json | 3 + 13 files changed, 1255 insertions(+), 980 deletions(-) create mode 100644 nt2/containers/__init__.py create mode 100644 nt2/containers/container.py create mode 100644 nt2/containers/fields.py create mode 100644 nt2/containers/particles.py create mode 100644 nt2/containers/spectra.py create mode 100644 nt2/containers/utils.py create mode 100644 nt2/plotters/__init__.py rename nt2/{ => plotters}/plot.py (94%) create mode 100644 nt2/plotters/polarplot.py create mode 100644 pyrightconfig.json diff --git a/.gitignore b/.gitignore index 4a1d483..ae1febe 100644 --- a/.gitignore +++ b/.gitignore @@ -151,11 +151,6 @@ dmypy.json # Cython debug symbols cython_debug/ -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ test/ -temp/ \ No newline at end of file +temp/ +*.bak diff --git a/nt2/__init__.py b/nt2/__init__.py index 3d26edf..90d4cd0 100644 --- a/nt2/__init__.py +++ b/nt2/__init__.py @@ -1 +1,4 @@ -__version__ = "0.4.1" +__version__ = "0.5.0" + +from nt2.read import Data as Data +from nt2.plotters import polarplot as polarplot diff --git a/nt2/containers/__init__.py b/nt2/containers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nt2/containers/container.py b/nt2/containers/container.py new file mode 100644 index 0000000..07904d5 --- /dev/null +++ b/nt2/containers/container.py @@ -0,0 +1,148 @@ +import h5py +import numpy as np +from typing import Any +from dask.distributed import Client + + +def _read_attribs_SingleFile(file: h5py.File): + attribs = {} + for k in file.attrs.keys(): + attr = file.attrs[k] + if type(attr) is bytes or type(attr) is np.bytes_: + attribs[k] = attr.decode("UTF-8") + else: + attribs[k] = attr + return attribs + + +class Container: + def __init__( + self, path, single_file=False, pickle=True, greek=False, dask_props={} + ): + super(Container, self).__init__() + + self.client = Client(**dask_props) + if self.client.status == "running": + print("Dask client launched:") + print(self.client) + + self.configs: dict[str, Any] = { + "single_file": single_file, + "use_pickle": pickle, + "use_greek": greek, + } + self.path = path + self.metadata = {} + self.mesh = None + if self.configs["single_file"]: + try: + self.master_file: h5py.File | None = h5py.File(self.path, "r") + except OSError: + raise OSError(f"Could not open file {self.path}") + else: + self.master_file: h5py.File | None = None + raise NotImplementedError("Multiple files not yet supported") + + self.attrs = _read_attribs_SingleFile(self.master_file) + + if self.configs["single_file"]: + self.configs["ngh"] = int(self.master_file.attrs.get("NGhosts", 0)) + self.configs["layout"] = ( + "right" if self.master_file.attrs.get("LayoutRight", 1) == 1 else "left" + ) + self.configs["dimension"] = int(self.master_file.attrs.get("Dimension", 1)) + self.configs["coordinates"] = self.master_file.attrs.get( + "Coordinates", b"cart" + ).decode("UTF-8") + if self.configs["coordinates"] == "qsph": + self.configs["coordinates"] = "sph" + # if coordinates == "sph": + # self.metric = SphericalMetric() + # else: + # self.metric = MinkowskiMetric() + + def plotGrid(self, ax, **kwargs): + from matplotlib import patches + + xlim, ylim = ax.get_xlim(), ax.get_ylim() + options = { + "lw": 1, + "color": "k", + "ls": "-", + } + options.update(kwargs) + + if self.configs["coordinates"] == "cart": + for x in self.attrs["X1"]: + ax.plot([x, x], [self.attrs["X2Min"], self.attrs["X2Max"]], **options) + for y in self.attrs["X2"]: + ax.plot([self.attrs["X1Min"], self.attrs["X1Max"]], [y, y], **options) + else: + for r in self.attrs["X1"]: + ax.add_patch( + patches.Arc( + (0, 0), + 2 * r, + 2 * r, + theta1=-90, + theta2=90, + fill=False, + **options, + ) + ) + for th in self.attrs["X2"]: + ax.plot( + [ + self.attrs["X1Min"] * np.sin(th), + self.attrs["X1Max"] * np.sin(th), + ], + [ + self.attrs["X1Min"] * np.cos(th), + self.attrs["X1Max"] * np.cos(th), + ], + **options, + ) + ax.set(xlim=xlim, ylim=ylim) + + def print_container(self) -> str: + return f"Client {self.client}\n" + + # + # def makeMovie(self, plot, makeframes=True, **kwargs): + # """ + # Makes a movie from a plot function + # + # Parameters + # ---------- + # plot : function + # The plot function to use; accepts output timestep and dataset as arguments. + # makeframes : bool, optional + # Whether to make the frames, or just proceed to making the movie. Default is True. + # num_cpus : int, optional + # The number of CPUs to use for making the frames. Default is None. + # **kwargs : + # Additional keyword arguments passed to `ffmpeg`. + # """ + # import numpy as np + # + # if makeframes: + # makemovie = all( + # exp.makeFrames( + # plot, + # np.arange(len(self.t)), + # f"{self.attrs['simulation.name']}/frames", + # data=self, + # num_cpus=kwargs.pop("num_cpus", None), + # ) + # ) + # else: + # makemovie = True + # if makemovie: + # exp.makeMovie( + # input=f"{self.attrs['simulation.name']}/frames/", + # overwrite=True, + # output=f"{self.attrs['simulation.name']}.mp4", + # number=5, + # **kwargs, + # ) + # return True diff --git a/nt2/containers/fields.py b/nt2/containers/fields.py new file mode 100644 index 0000000..33bc355 --- /dev/null +++ b/nt2/containers/fields.py @@ -0,0 +1,229 @@ +import h5py +import xarray as xr +import numpy as np +from dask.array.core import from_array +from dask.array.core import stack + +from nt2.containers.container import Container +from nt2.containers.utils import _read_category_metadata_SingleFile + + +def _read_coordinates_SingleFile(coords: list[str], file: h5py.File): + for st in file: + group = file[st] + if isinstance(group, h5py.Group): + if any([k.startswith("X") for k in group if k is not None]): + # cell-centered coords + xc = { + c: ( + np.asarray(xi[:]) + if isinstance(xi := group[f"X{i+1}"], h5py.Dataset) and xi + else None + ) + for i, c in enumerate(coords[::-1]) + } + # cell edges + xe_min = { + f"{c}_1": ( + c, + ( + np.asarray(xi[:-1]) + if isinstance((xi := group[f"X{i+1}e"]), h5py.Dataset) + else None + ), + ) + for i, c in enumerate(coords[::-1]) + } + xe_max = { + f"{c}_2": ( + c, + ( + np.asarray(xi[1:]) + if isinstance((xi := group[f"X{i+1}e"]), h5py.Dataset) + else None + ), + ) + for i, c in enumerate(coords[::-1]) + } + return {"x_c": xc, "x_emin": xe_min, "x_emax": xe_max} + else: + raise ValueError(f"Unexpected type {type(file[st])}") + raise ValueError("Could not find coordinates in file") + + +def _preload_field_SingleFile( + k: str, + dim: int, + ngh: int, + outsteps: list[int], + times: list[float], + steps: list[int], + coords: list[str], + xc_coords: dict[str, str], + xe_min_coords: dict[str, str], + xe_max_coords: dict[str, str], + coord_replacements: list[tuple[str, str]], + field_replacements: list[tuple[str, str]], + layout: str, + file: h5py.File, +): + if dim == 1: + noghosts = slice(ngh, -ngh) if ngh > 0 else slice(None) + elif dim == 2: + noghosts = (slice(ngh, -ngh), slice(ngh, -ngh)) if ngh > 0 else slice(None) + elif dim == 3: + noghosts = ( + (slice(ngh, -ngh), slice(ngh, -ngh), slice(ngh, -ngh)) + if ngh > 0 + else slice(None) + ) + else: + raise ValueError("Invalid dimension") + + dask_arrays = [] + for s in outsteps: + dset = file[f"{s}/{k}"] + if isinstance(dset, h5py.Dataset): + array = from_array(np.transpose(dset) if layout == "right" else dset) + dask_arrays.append(array[noghosts]) + else: + raise ValueError(f"Unexpected type {type(dset)}") + + k_ = k[1:] + for c in coord_replacements: + if "_" not in k_: + k_ = k_.replace(c[0], c[1]) + else: + k_ = "_".join([k_.split("_")[0].replace(c[0], c[1])] + k_.split("_")[1:]) + for f in field_replacements: + k_ = k_.replace(*f) + + return k_, xr.DataArray( + stack(dask_arrays, axis=0), + dims=["t", *coords], + name=k_, + coords={ + "t": times, + "s": ("t", steps), + **xc_coords, + **xe_min_coords, + **xe_max_coords, + }, + ) + + +class FieldsContainer(Container): + def __init__(self, **kwargs): + super(FieldsContainer, self).__init__(**kwargs) + QuantityDict = { + "Ttt": "E", + "Ttx": "Px", + "Tty": "Py", + "Ttz": "Pz", + } + CoordinateDict = { + "cart": {"x": "x", "y": "y", "z": "z", "1": "x", "2": "y", "3": "z"}, + "sph": { + "r": "r", + "theta": "θ" if self.configs["use_greek"] else "th", + "phi": "φ" if self.configs["use_greek"] else "ph", + "1": "r", + "2": "θ" if self.configs["use_greek"] else "th", + "3": "φ" if self.configs["use_greek"] else "ph", + }, + } + if self.configs["single_file"]: + assert self.master_file is not None, "Master file not found" + self.metadata["fields"] = _read_category_metadata_SingleFile( + "f", self.master_file + ) + else: + try: + raise NotImplementedError("Multiple files not yet supported") + except OSError: + raise OSError(f"Could not open file {self.path}") + + coords = list(CoordinateDict[self.configs["coordinates"]].values())[::-1][ + -self.configs["dimension"] : + ] + + if self.configs["single_file"]: + self.mesh = _read_coordinates_SingleFile(coords, self.master_file) + else: + raise NotImplementedError("Multiple files not yet supported") + + self.fields = xr.Dataset() + + if len(self.metadata["fields"]["outsteps"]) > 0: + if self.configs["single_file"]: + for k in self.metadata["fields"]["quantities"]: + name, dset = _preload_field_SingleFile( + k, + dim=self.configs["dimension"], + ngh=self.configs["ngh"], + outsteps=self.metadata["fields"]["outsteps"], + times=self.metadata["fields"]["times"], + steps=self.metadata["fields"]["steps"], + coords=coords, + xc_coords=self.mesh["x_c"], + xe_min_coords=self.mesh["x_emin"], + xe_max_coords=self.mesh["x_emax"], + coord_replacements=list( + CoordinateDict[self.configs["coordinates"]].items() + ), + field_replacements=list(QuantityDict.items()), + layout=self.configs["layout"], + file=self.master_file, + ) + self.fields[name] = dset + else: + raise NotImplementedError("Multiple files not yet supported") + + def __del__(self): + if self.configs["single_file"] and self.master_file is not None: + self.master_file.close() + else: + raise NotImplementedError("Multiple files not yet supported") + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + if self.configs["single_file"] and self.master_file is not None: + self.master_file.close() + else: + raise NotImplementedError("Multiple files not yet supported") + + def print_fields(self) -> str: + def sizeof_fmt(num, suffix="B"): + for unit in ("", "K", "M", "G", "T", "P", "E", "Z"): + if abs(num) < 1e3: + return f"{num:3.1f} {unit}{suffix}" + num /= 1e3 + return f"{num:.1f} Y{suffix}" + + def compactify(lst): + c = "" + cntr = 0 + for l_ in lst: + if cntr > 5: + c += "\n " + cntr = 0 + c += l_ + ", " + cntr += 1 + return c[:-2] + + string = "" + field_keys = list(self.fields.data_vars.keys()) + + if len(field_keys) > 0: + string += "Fields:\n" + string += f" - data axes: {compactify(self.fields.indexes.keys())}\n" + string += f" - timesteps: {self.fields[field_keys[0]].shape[0]}\n" + string += f" - shape: {self.fields[field_keys[0]].shape[1:]}\n" + string += f" - quantities: {compactify(self.fields.data_vars.keys())}\n" + string += f" - total size: {sizeof_fmt(self.fields.nbytes)}\n" + else: + string += "Fields: empty\n" + + return string diff --git a/nt2/containers/particles.py b/nt2/containers/particles.py new file mode 100644 index 0000000..4440b9d --- /dev/null +++ b/nt2/containers/particles.py @@ -0,0 +1,178 @@ +import h5py +import numpy as np +import xarray as xr +from dask.array.core import from_array + + +from nt2.containers.container import Container +from nt2.containers.utils import _read_category_metadata_SingleFile + + +def _list_to_ragged(arr): + max_len = np.max([len(a) for a in arr]) + return map( + lambda a: np.concatenate([a, np.full(max_len - len(a), np.nan)]), + arr, + ) + + +def _read_species_SingleFile(first_step: int, file: h5py.File): + group = file[first_step] + if not isinstance(group, h5py.Group): + raise ValueError(f"Unexpected type {type(group)}") + species = np.unique( + [int(pq.split("_")[1]) for pq in group.keys() if pq.startswith("p")] + ) + return species + + +def _preload_particle_species_SingleFile( + s: int, + quantities: list[str], + coord_type: str, + outsteps: list[int], + times: list[float], + steps: list[int], + coord_replacements: dict[str, str], + file: h5py.File, +): + prtl_data = {} + for q in [ + f"X1_{s}", + f"X2_{s}", + f"X3_{s}", + f"U1_{s}", + f"U2_{s}", + f"U3_{s}", + f"W_{s}", + ]: + if q[0] in ["X", "U"]: + q_ = coord_replacements[q.split("_")[0]] + else: + q_ = q.split("_")[0] + if "p" + q not in quantities: + continue + if q not in prtl_data.keys(): + prtl_data[q_] = [] + for step_k in outsteps: + group = file[step_k] + if isinstance(group, h5py.Group): + if "p" + q in group.keys(): + prtl_data[q_].append(group["p" + q]) + else: + prtl_data[q_].append(np.full_like(prtl_data[q_][-1], np.nan)) + else: + raise ValueError(f"Unexpected type {type(file[step_k])}") + prtl_data[q_] = _list_to_ragged(prtl_data[q_]) + prtl_data[q_] = from_array(list(prtl_data[q_])) + prtl_data[q_] = xr.DataArray( + prtl_data[q_], + dims=["t", "id"], + name=q_, + coords={"t": times, "s": ("t", steps)}, + ) + if coord_type == "sph": + prtl_data["x"] = ( + prtl_data[coord_replacements["X1"]] + * np.sin(prtl_data[coord_replacements["X2"]]) + * np.cos(prtl_data[coord_replacements["X3"]]) + ) + prtl_data["y"] = ( + prtl_data[coord_replacements["X1"]] + * np.sin(prtl_data[coord_replacements["X2"]]) + * np.sin(prtl_data[coord_replacements["X3"]]) + ) + prtl_data["z"] = prtl_data[coord_replacements["X1"]] * np.cos( + prtl_data[coord_replacements["X2"]] + ) + return xr.Dataset(prtl_data) + + +class ParticleContainer(Container): + def __init__(self, **kwargs): + super(ParticleContainer, self).__init__(**kwargs) + PrtlDict = { + "cart": { + "X1": "x", + "X2": "y", + "X3": "z", + "U1": "ux", + "U2": "uy", + "U3": "uz", + }, + "sph": { + "X1": "r", + "X2": "θ" if self.configs["use_greek"] else "th", + "X3": "φ" if self.configs["use_greek"] else "ph", + "U1": "ur", + "U2": "uΘ" if self.configs["use_greek"] else "uth", + "U3": "uφ" if self.configs["use_greek"] else "uph", + }, + } + + if self.configs["single_file"]: + assert self.master_file is not None, "Master file not found" + self.metadata["particles"] = _read_category_metadata_SingleFile( + "p", self.master_file + ) + self._particles = {} + + if len(self.metadata["particles"]["outsteps"]) > 0: + if self.configs["single_file"]: + assert self.master_file is not None, "Master file not found" + species = _read_species_SingleFile( + self.metadata["particles"]["outsteps"][0], self.master_file + ) + for s in species: + self._particles[s] = _preload_particle_species_SingleFile( + s=s, + quantities=self.metadata["particles"]["quantities"], + coord_type=self.configs["coordinates"], + outsteps=self.metadata["particles"]["outsteps"], + times=self.metadata["particles"]["times"], + steps=self.metadata["particles"]["steps"], + coord_replacements=PrtlDict[self.configs["coordinates"]], + file=self.master_file, + ) + + @property + def particles(self): + return self._particles + + def print_particles(self) -> str: + def sizeof_fmt(num, suffix="B"): + for unit in ("", "K", "M", "G", "T", "P", "E", "Z"): + if abs(num) < 1e3: + return f"{num:3.1f} {unit}{suffix}" + num /= 1e3 + return f"{num:.1f} Y{suffix}" + + def compactify(lst): + c = "" + cntr = 0 + for l_ in lst: + if cntr > 5: + c += "\n " + cntr = 0 + c += l_ + ", " + cntr += 1 + return c[:-2] + + string = "" + if self.particles != {}: + species = [int(i) for i in self.particles.keys()] + string += "Particles:\n" + string += f" - species: {species}\n" + string += f" - data axes: {compactify(self.particles[species[0]].indexes.keys())}\n" + string += f" - timesteps: {self.particles[species[0]][list(self.particles[species[0]].data_vars.keys())[0]].shape[0]}\n" + string += f" - quantities: {compactify(self.particles[species[0]].data_vars.keys())}\n" + size = 0 + for s in species: + keys = list(self.particles[s].data_vars.keys()) + string += f" - species [{s}]:\n" + string += f" - number: {self.particles[s][keys[0]].shape[1]}\n" + size += self.particles[s].nbytes + string += f" - total size: {sizeof_fmt(size)}\n" + else: + string += "Particles: empty\n" + return string diff --git a/nt2/containers/spectra.py b/nt2/containers/spectra.py new file mode 100644 index 0000000..8825e9c --- /dev/null +++ b/nt2/containers/spectra.py @@ -0,0 +1,138 @@ +import h5py +import numpy as np +import xarray as xr +from dask.array.core import from_array +from dask.array.core import stack + +from nt2.containers.container import Container +from nt2.containers.utils import _read_category_metadata_SingleFile + + +def _read_species_SingleFile(first_step: int, file: h5py.File): + group = file[first_step] + if not isinstance(group, h5py.Group): + raise ValueError(f"Unexpected type {type(group)}") + species = np.unique( + [int(pq.split("_")[1]) for pq in group.keys() if pq.startswith("sN")] + ) + return species + + +def _read_spectra_bins_SingleFile(first_step: int, log_bins: bool, file: h5py.File): + group = file[first_step] + if not isinstance(group, h5py.Group): + raise ValueError(f"Unexpected type {type(group)}") + e_bins = group["sEbn"] + if not isinstance(e_bins, h5py.Dataset): + raise ValueError(f"Unexpected type {type(e_bins)}") + if log_bins: + e_bins = np.sqrt(e_bins[1:] * e_bins[:-1]) + else: + e_bins = (e_bins[1:] + e_bins[:-1]) / 2 + return e_bins + + +def _preload_spectra_SingleFile( + sp: int, + e_bins: np.ndarray, + outsteps: list[int], + times: list[float], + steps: list[int], + file: h5py.File, +): + dask_arrays = [] + for st in outsteps: + array = from_array(file[f"{st}/sN_{sp}"]) + dask_arrays.append(array) + + return xr.DataArray( + stack(dask_arrays, axis=0), + dims=["t", "e"], + name=f"n_{sp}", + coords={ + "t": times, + "s": ("t", steps), + "e": e_bins, + }, + ) + + +class SpectraContainer(Container): + def __init__(self, **kwargs): + super(SpectraContainer, self).__init__(**kwargs) + assert "single_file" in self.configs + assert "use_pickle" in self.configs + assert "use_greek" in self.configs + assert "path" in self.__dict__ + assert "metadata" in self.__dict__ + assert "mesh" in self.__dict__ + assert "attrs" in self.__dict__ + + if self.configs["single_file"]: + assert self.master_file is not None, "Master file not found" + self.metadata["spectra"] = _read_category_metadata_SingleFile( + "s", self.master_file + ) + self._spectra = xr.Dataset() + log_bins = self.attrs["output.spectra.log_bins"] + + if len(self.metadata["spectra"]["outsteps"]) > 0: + if self.configs["single_file"]: + assert self.master_file is not None, "Master file not found" + species = _read_species_SingleFile( + self.metadata["spectra"]["outsteps"][0], self.master_file + ) + else: + raise NotImplementedError("Multiple files not yet supported") + + e_bins = _read_spectra_bins_SingleFile( + self.metadata["spectra"]["outsteps"][0], log_bins, self.master_file + ) + + for sp in species: + self._spectra[f"n_{sp}"] = _preload_spectra_SingleFile( + sp, + e_bins, + self.metadata["spectra"]["outsteps"], + self.metadata["spectra"]["times"], + self.metadata["spectra"]["steps"], + self.master_file, + ) + + @property + def spectra(self): + return self._spectra + + def print_spectra(self) -> str: + def sizeof_fmt(num, suffix="B"): + for unit in ("", "K", "M", "G", "T", "P", "E", "Z"): + if abs(num) < 1e3: + return f"{num:3.1f} {unit}{suffix}" + num /= 1e3 + return f"{num:.1f} Y{suffix}" + + def compactify(lst): + c = "" + cntr = 0 + for l_ in lst: + if cntr > 5: + c += "\n " + cntr = 0 + c += l_ + ", " + cntr += 1 + return c[:-2] + + string = "" + spec_keys = list(self.spectra.data_vars.keys()) + + if len(spec_keys) > 0: + string += "Spectra:\n" + string += f" - data axes: {compactify(self.spectra.indexes.keys())}\n" + string += f" - timesteps: {self.spectra[spec_keys[0]].shape[0]}\n" + string += f" - # of bins: {self.spectra[spec_keys[0]].shape[1]}\n" + string += f" - quantities: {compactify(self.spectra.data_vars.keys())}\n" + string += f" - total size: {sizeof_fmt(self.spectra.nbytes)}\n" + else: + string += "Spectra: empty\n" + + return string diff --git a/nt2/containers/utils.py b/nt2/containers/utils.py new file mode 100644 index 0000000..d5a9718 --- /dev/null +++ b/nt2/containers/utils.py @@ -0,0 +1,38 @@ +import h5py +import numpy as np + + +def _read_category_metadata_SingleFile(prefix: str, file: h5py.File): + f_outsteps = [] + f_steps = [] + f_times = [] + f_quantities = None + for st in file: + group = file[st] + if isinstance(group, h5py.Group): + if any([k.startswith(prefix) for k in group if k is not None]): + if f_quantities is None: + f_quantities = [k for k in group.keys() if k.startswith(prefix)] + f_outsteps.append(st) + time_ds = group["Time"] + if isinstance(time_ds, h5py.Dataset): + f_times.append(time_ds[()]) + else: + raise ValueError(f"Unexpected type {type(time_ds)}") + step_ds = group["Step"] + if isinstance(step_ds, h5py.Dataset): + f_steps.append(int(step_ds[()])) + else: + raise ValueError(f"Unexpected type {type(step_ds)}") + + else: + raise ValueError(f"Unexpected type {type(file[st])}") + f_outsteps = sorted(f_outsteps, key=lambda x: int(x.replace("Step", ""))) + f_steps = sorted(f_steps) + f_times = np.array(sorted(f_times), dtype=np.float64) + return { + "quantities": f_quantities, + "outsteps": f_outsteps, + "steps": f_steps, + "times": f_times, + } diff --git a/nt2/plotters/__init__.py b/nt2/plotters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nt2/plot.py b/nt2/plotters/plot.py similarity index 94% rename from nt2/plot.py rename to nt2/plotters/plot.py index 69b12c4..7843fd1 100644 --- a/nt2/plot.py +++ b/nt2/plotters/plot.py @@ -2,7 +2,8 @@ def annotatePulsar( ax, data, rmax, rstar=1.1, ti=None, time=None, attrs={}, ax_props={}, star_props={} ): import numpy as np - import matplotlib as mpl + from matplotlib import lines + from matplotlib import patches if ti is None and time is None: raise ValueError("Must provide either ti or time") @@ -21,6 +22,7 @@ def annotatePulsar( "WARNING: No spinup time or spin period found, please specify explicitly as `attrs = {'psr_omega': ..., 'psr_spinup_time': ...}`" ) demo_rotation = False + phase = 0 else: phase = ( omega @@ -45,7 +47,7 @@ def annotatePulsar( for i in range(-int(rmax * 0.8) // 2 - 1, int(rmax * 0.8) // 2): if i != -1: ax.add_artist( - mpl.lines.Line2D( + lines.Line2D( [0, -0.1], [2 * (i + 1), 2 * (i + 1)], color=ax_props.get("color", "k"), @@ -85,7 +87,7 @@ def annotatePulsar( xs = np.concatenate([xs1, xs2[::-1]]) ys = np.concatenate([ys1, ys2[::-1]]) ax.add_artist( - mpl.patches.Polygon( + patches.Polygon( (rstar + 0.02) * np.array([xs, ys]).T, color=star_props.get("c1", "r"), lw=0, @@ -94,7 +96,7 @@ def annotatePulsar( ) ) ax.add_artist( - mpl.patches.Polygon( + patches.Polygon( (rstar + 0.02) * np.array([-xs, ys]).T, color=star_props.get("c2", "b"), lw=0, @@ -103,7 +105,7 @@ def annotatePulsar( ) ) ax.add_artist( - mpl.patches.Circle( + patches.Circle( (0, 0), rstar, color=star_props.get("color", "royalblue"), diff --git a/nt2/plotters/polarplot.py b/nt2/plotters/polarplot.py new file mode 100644 index 0000000..465872a --- /dev/null +++ b/nt2/plotters/polarplot.py @@ -0,0 +1,488 @@ +from dask.delayed import delayed +import xarray as xr +import numpy as np +from typing import Any + + +def DataIs2DPolar(ds): + return ("r" in ds.dims and ("θ" in ds.dims or "th" in ds.dims)) and len( + ds.dims + ) == 2 + + +def DipoleSampling(**kwargs): + """ + Returns an array of angles sampled from a dipole distribution. + + Parameters + ---------- + nth : int, optional + The number of angles to sample. Default is 30. + pole : float, optional + The fraction of the angles to sample from the poles. Default is 1/16. + + Returns + ------- + ndarray + An array of angles sampled from a dipole distribution. + """ + nth = kwargs.get("nth", 30) + pole = kwargs.get("pole", 1 / 16) + + nth_poles = int(nth * pole) + nth_equator = (nth - 2 * nth_poles) // 2 + return np.concatenate( + [ + np.linspace(0, np.pi * pole, nth_poles + 1)[1:], + np.linspace(np.pi * pole, np.pi / 2, nth_equator + 2)[1:-1], + np.linspace(np.pi * (1 - pole), np.pi, nth_poles + 1)[:-1], + ] + ) + + +def MonopoleSampling(**kwargs): + """ + Returns an array of angles sampled from a monopole distribution. + + Parameters + ---------- + nth : int, optional + The number of angles to sample. Default is 30. + + Returns + ------- + ndarray + An array of angles sampled from a monopole distribution. + """ + nth = kwargs.get("nth", 30) + + return np.linspace(0, np.pi, nth + 2)[1:-1] + + +@xr.register_dataset_accessor("polar") +class DatasetPolarPlotAccessor: + def __init__(self, xarray_obj): + self._obj = xarray_obj + + def pcolor(self, value, **kwargs): + assert "t" not in self._obj[value].dims, "Time must be specified" + assert DataIs2DPolar(self._obj), "Data must be 2D polar" + self._obj[value].polar.pcolor(**kwargs) + + def fieldplot( + self, + fr, + fth, + start_points=None, + sample=None, + invert_x=False, + invert_y=False, + **kwargs, + ): + """ + Plot field lines of a vector field defined by functions fr and fth. + + Parameters + ---------- + fr : string + Radial component of the vector field. + fth : string + Azimuthal component of the vector field. + start_points : array_like, optional + Starting points for the field lines. Either this or `sample` must be specified. + sample : dict, optional + Sampling template for generating starting points. Either this or `start_points` must be specified. + The template can be "dipole" or "monopole". The dict also contains the starting `radius`, + and the number of points in theta `nth` key. + invert_x : bool, optional + Whether to invert the x-axis. Default is False. + invert_y : bool, optional + Whether to invert the y-axis. Default is False. + **kwargs : + Additional keyword arguments passed to `fieldlines` and `ax.plot`. + + Raises + ------ + ValueError + If neither `start_points` nor `sample` are specified or if an unknown sampling template is given. + + Returns + ------- + None + + Examples + -------- + >>> ds.polar.fieldplot("Br", "Bth", sample={"template": "dipole", "nth": 30, "radius": 2.0}) + """ + import matplotlib.pyplot as plt + + if start_points is None and sample is None: + raise ValueError("Either start_points or sample must be specified") + elif start_points is None and sample is not None: + radius = sample.pop("radius", 1.5) + template = sample.pop("template", "dipole") + if template == "dipole": + start_points = [[radius, th] for th in DipoleSampling(**sample)] + elif template == "monopole": + start_points = [[radius, th] for th in MonopoleSampling(**sample)] + else: + raise ValueError("Unknown sampling template: " + template) + + fieldlines = self.fieldlines(fr, fth, start_points, **kwargs).compute() + ax = kwargs.pop("ax", plt.gca()) + for fieldline in fieldlines: + if invert_x: + fieldline[:, 0] = -fieldline[:, 0] + if invert_y: + fieldline[:, 1] = -fieldline[:, 1] + ax.plot(*fieldline.T, **kwargs) + + @delayed + def fieldlines(self, fr, fth, start_points, **kwargs): + """ + Compute field lines of a vector field defined by functions fr and fth. + + Parameters + ---------- + fr : string + Radial component of the vector field. + fth : string + Azimuthal component of the vector field. + start_points : array_like + Starting points for the field lines. + direction : str, optional + Direction to integrate in. Can be "both", "forward" or "backward". Default is "both". + stopWhen : callable, optional + Function that takes the current position and returns True if the integration should stop. Default is to never stop. + ds : float, optional + Integration step size. Default is 0.1. + maxsteps : int, optional + Maximum number of integration steps. Default is 1000. + + Returns + ------- + list + List of field lines. + + Examples + -------- + >>> ds.polar.fieldlines("Br", "Bth", [[2.0, np.pi / 4], [2.0, 3 * np.pi / 4]], stopWhen = lambda xy, rth: rth[0] > 5.0) + """ + + import numpy as np + from scipy.interpolate import RegularGridInterpolator + + assert "t" not in self._obj[fr].dims, "Time must be specified" + assert "t" not in self._obj[fth].dims, "Time must be specified" + assert DataIs2DPolar(self._obj), "Data must be 2D polar" + + useGreek = "θ" in self._obj.coords.keys() + + r, th = ( + self._obj.coords["r"].values, + self._obj.coords["θ" if useGreek else "th"].values, + ) + _, ths = np.meshgrid(r, th) + fxs = self._obj[fr] * np.sin(ths) + self._obj[fth] * np.cos(ths) + fys = self._obj[fr] * np.cos(ths) - self._obj[fth] * np.sin(ths) + + props: dict[str, Any] = { + "method": "nearest", + "bounds_error": False, + "fill_value": 0, + } + interpFx = RegularGridInterpolator((th, r), fxs.values, **props) + interpFy = RegularGridInterpolator((th, r), fys.values, **props) + return [ + self._fieldline(interpFx, interpFy, rth, **kwargs) for rth in start_points + ] + + def _fieldline(self, interp_fx, interp_fy, r_th_start, **kwargs): + import numpy as np + from copy import copy + + direction = kwargs.pop("direction", "both") + stopWhen = kwargs.pop("stopWhen", lambda _, __: False) + ds = kwargs.pop("ds", 0.1) + maxsteps = kwargs.pop("maxsteps", 1000) + + rmax = self._obj.r.max() + rmin = self._obj.r.min() + + def stop(xy, rth): + return ( + stopWhen(xy, rth) + or (rth[0] < rmin) + or (rth[0] > rmax) + or (rth[1] < 0) + or (rth[1] > np.pi) + ) + + def integrate(delta, counter): + r0, th0 = copy(r_th_start) + XY = np.array([r0 * np.sin(th0), r0 * np.cos(th0)]) + RTH = [r0, th0] + fieldline = np.array([XY]) + with np.errstate(divide="ignore", invalid="ignore"): + while range(counter, maxsteps): + x, y = XY + r = np.sqrt(x**2 + y**2) + th = np.arctan2(-y, x) + np.pi / 2 + RTH = [r, th] + vx = interp_fx((th, r))[()] + vy = interp_fy((th, r))[()] + vmag = np.sqrt(vx**2 + vy**2) + XY = XY + delta * np.array([vx, vy]) / vmag + if stop(XY, RTH) or np.isnan(XY).any() or np.isinf(XY).any(): + break + else: + fieldline = np.append(fieldline, [XY], axis=0) + return fieldline + + if direction == "forward": + return integrate(ds, 0) + elif direction == "backward": + return integrate(-ds, 0) + else: + cntr = 0 + f1 = integrate(ds, cntr) + f2 = integrate(-ds, cntr) + return np.append(f2[::-1], f1, axis=0) + + +@xr.register_dataarray_accessor("polar") +class PolarPlotAccessor: + def __init__(self, xarray_obj): + self._obj = xarray_obj + + def pcolor(self, **kwargs): + """ + Plots a pseudocolor plot of 2D polar data on a rectilinear projection. + + Parameters + ---------- + ax : Axes object, optional + The axes on which to plot. Default is the current axes. + cell_centered : bool, optional + Whether the data is cell-centered. Default is True. + cell_size : float, optional + If not cell_centered, defines the fraction of the cell to use for coloring. Default is 0.75. + cbar_size : str, optional + The size of the colorbar. Default is "5%". + cbar_pad : float, optional + The padding between the colorbar and the plot. Default is 0.05. + cbar_position : str, optional + The position of the colorbar. Default is "right". + cbar_ticksize : int or float, optional + The size of the ticks on the colorbar. Default is None. + title : str, optional + The title of the plot. Default is None. + invert_x : bool, optional + Whether to invert the x-axis. Default is False. + invert_y : bool, optional + Whether to invert the y-axis. Default is False. + ylabel : str, optional + The label for the y-axis. Default is "y". + xlabel : str, optional + The label for the x-axis. Default is "x". + label : str, optional + The label for the plot. Default is None. + + Returns + ------- + matplotlib.collections.Collection + The pseudocolor plot. + + Raises + ------ + AssertionError + If `ax` is a polar projection or if time is not specified or if data is not 2D polar. + + Notes + ----- + Additional keyword arguments are passed to `pcolormesh`. + """ + + import matplotlib.pyplot as plt + from matplotlib import colors + from matplotlib import tri + import matplotlib as mpl + from mpl_toolkits.axes_grid1 import make_axes_locatable + + useGreek = "θ" in self._obj.coords.keys() + + ax = kwargs.pop("ax", plt.gca()) + cbar_size = kwargs.pop("cbar_size", "5%") + cbar_pad = kwargs.pop("cbar_pad", 0.05) + cbar_pos = kwargs.pop("cbar_position", "right") + cbar_orientation = ( + "vertical" if cbar_pos == "right" or cbar_pos == "left" else "horizontal" + ) + cbar_ticksize = kwargs.pop("cbar_ticksize", None) + title = kwargs.pop("title", None) + invert_x = kwargs.pop("invert_x", False) + invert_y = kwargs.pop("invert_y", False) + ylabel = kwargs.pop("ylabel", "y") + xlabel = kwargs.pop("xlabel", "x") + label = kwargs.pop("label", None) + cell_centered = kwargs.pop("cell_centered", True) + cell_size = kwargs.pop("cell_size", 0.75) + + assert ax.name != "polar", "`ax` must be a rectilinear projection" + assert "t" not in self._obj.dims, "Time must be specified" + assert DataIs2DPolar(self._obj), "Data must be 2D polar" + ax.grid(False) + if type(kwargs.get("norm", None)) is colors.LogNorm: + cm = kwargs.get("cmap", "viridis") + cm = mpl.colormaps[cm] + cm.set_bad(cm(0)) + kwargs["cmap"] = cm + + vals = self._obj.values.flatten() + vals = np.concatenate((vals, vals)) + if not cell_centered: + drs = self._obj.coords["r_2"] - self._obj.coords["r_1"] + dths = ( + self._obj.coords["θ_2" if useGreek else "th_2"] + - self._obj.coords["θ_1" if useGreek else "th_1"] + ) + r1s = self._obj.coords["r_1"] - drs * cell_size / 2 + r2s = self._obj.coords["r_1"] + drs * cell_size / 2 + th1s = ( + self._obj.coords["θ_1" if useGreek else "th_1"] - dths * cell_size / 2 + ) + th2s = ( + self._obj.coords["θ_1" if useGreek else "th_1"] + dths * cell_size / 2 + ) + rs = np.ravel(np.column_stack((r1s, r2s))) + ths = np.ravel(np.column_stack((th1s, th2s))) + nr = len(rs) + nth = len(ths) + rs, ths = np.meshgrid(rs, ths) + rs = rs.flatten() + ths = ths.flatten() + points_1 = np.arange(nth * nr).reshape(nth, -1)[:-1:2, :-1:2].flatten() + points_2 = np.arange(nth * nr).reshape(nth, -1)[:-1:2, 1::2].flatten() + points_3 = np.arange(nth * nr).reshape(nth, -1)[1::2, 1::2].flatten() + points_4 = np.arange(nth * nr).reshape(nth, -1)[1::2, :-1:2].flatten() + + else: + rs = np.append(self._obj.coords["r_1"], self._obj.coords["r_2"][-1]) + ths = np.append( + self._obj.coords["θ_1" if useGreek else "th_1"], + self._obj.coords["θ_2" if useGreek else "th_2"][-1], + ) + nr = len(rs) + nth = len(ths) + rs, ths = np.meshgrid(rs, ths) + rs = rs.flatten() + ths = ths.flatten() + points_1 = np.arange(nth * nr).reshape(nth, -1)[:-1, :-1].flatten() + points_2 = np.arange(nth * nr).reshape(nth, -1)[:-1, 1:].flatten() + points_3 = np.arange(nth * nr).reshape(nth, -1)[1:, 1:].flatten() + points_4 = np.arange(nth * nr).reshape(nth, -1)[1:, :-1].flatten() + x, y = rs * np.sin(ths), rs * np.cos(ths) + if invert_x: + x = -x + if invert_y: + y = -y + triang = tri.Triangulation( + x, + y, + triangles=np.concatenate( + [ + np.array([points_1, points_2, points_3]).T, + np.array([points_1, points_3, points_4]).T, + ], + axis=0, + ), + ) + ax.set( + aspect="equal", + xlabel=xlabel, + ylabel=ylabel, + ) + im = ax.tripcolor(triang, vals, rasterized=True, shading="flat", **kwargs) + if cbar_pos is not None: + divider = make_axes_locatable(ax) + cax = divider.append_axes(cbar_pos, size=cbar_size, pad=cbar_pad) + _ = plt.colorbar( + im, + cax=cax, + label=self._obj.name if label is None else label, + orientation=cbar_orientation, + ) + if cbar_orientation == "vertical": + axis = cax.yaxis + else: + axis = cax.xaxis + axis.set_label_position(cbar_pos) + axis.set_ticks_position(cbar_pos) + if cbar_ticksize is not None: + cax.tick_params("both", labelsize=cbar_ticksize) + ax.set_title( + f"t={self._obj.coords['t'].values[()]:.2f}" if title is None else title + ) + return im + + def contour(self, **kwargs): + """ + Plots a pseudocolor plot of 2D polar data on a rectilinear projection. + + Parameters + ---------- + ax : Axes object, optional + The axes on which to plot. Default is the current axes. + invert_x : bool, optional + Whether to invert the x-axis. Default is False. + invert_y : bool, optional + Whether to invert the y-axis. Default is False. + + Returns + ------- + matplotlib.contour.QuadContourSet + The contour plot. + + Raises + ------ + AssertionError + If `ax` is a polar projection or if time is not specified or if data is not 2D polar. + + Notes + ----- + Additional keyword arguments are passed to `contour`. + """ + + import warnings + import matplotlib.pyplot as plt + + useGreek = "θ" in self._obj.coords.keys() + + ax = kwargs.pop("ax", plt.gca()) + title = kwargs.pop("title", None) + invert_x = kwargs.pop("invert_x", False) + invert_y = kwargs.pop("invert_y", False) + + assert ax.name != "polar", "`ax` must be a rectilinear projection" + assert "t" not in self._obj.dims, "Time must be specified" + assert DataIs2DPolar(self._obj), "Data must be 2D polar" + ax.grid(False) + r, th = np.meshgrid( + self._obj.coords["r"], self._obj.coords["θ" if useGreek else "th"] + ) + x, y = r * np.sin(th), r * np.cos(th) + if invert_x: + x = -x + if invert_y: + y = -y + ax.set( + aspect="equal", + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + im = ax.contour(x, y, self._obj.values, **kwargs) + + ax.set_title( + f"t={self._obj.coords['t'].values[()]:.2f}" if title is None else title + ) + return im diff --git a/nt2/read.py b/nt2/read.py index dd9a427..5c6af32 100644 --- a/nt2/read.py +++ b/nt2/read.py @@ -1,977 +1,30 @@ -import xarray as xr +from nt2.containers.fields import FieldsContainer +from nt2.containers.particles import ParticleContainer +from nt2.containers.spectra import SpectraContainer -import nt2.export as exp -useGreek = False -usePickle = False - - -def configure(use_greek=False, use_pickle=False): - global useGreek - global usePickle - useGreek = use_greek - usePickle = use_pickle - - -def DataIs2DPolar(ds): - return ("r" in ds.dims and ("θ" in ds.dims or "th" in ds.dims)) and len( - ds.dims - ) == 2 - - -def DipoleSampling(**kwargs): - """ - Returns an array of angles sampled from a dipole distribution. - - Parameters - ---------- - nth : int, optional - The number of angles to sample. Default is 30. - pole : float, optional - The fraction of the angles to sample from the poles. Default is 1/16. - - Returns - ------- - ndarray - An array of angles sampled from a dipole distribution. - """ - import numpy as np - - nth = kwargs.get("nth", 30) - pole = kwargs.get("pole", 1 / 16) - - nth_poles = int(nth * pole) - nth_equator = (nth - 2 * nth_poles) // 2 - return np.concatenate( - [ - np.linspace(0, np.pi * pole, nth_poles + 1)[1:], - np.linspace(np.pi * pole, np.pi / 2, nth_equator + 2)[1:-1], - np.linspace(np.pi * (1 - pole), np.pi, nth_poles + 1)[:-1], - ] - ) - - -def MonopoleSampling(**kwargs): +class Data(FieldsContainer, ParticleContainer, SpectraContainer): """ - Returns an array of angles sampled from a monopole distribution. - - Parameters - ---------- - nth : int, optional - The number of angles to sample. Default is 30. - - Returns - ------- - ndarray - An array of angles sampled from a monopole distribution. + A class to load Entity data and store it as a lazily loaded xarray Dataset. """ - import numpy as np - - nth = kwargs.get("nth", 30) - - return np.linspace(0, np.pi, nth + 2)[1:-1] - - -@xr.register_dataset_accessor("polar") -class DatasetPolarPlotAccessor: - import dask - - def __init__(self, xarray_obj): - self._obj = xarray_obj - - def pcolor(self, value, **kwargs): - assert "t" not in self._obj[value].dims, "Time must be specified" - assert DataIs2DPolar(self._obj), "Data must be 2D polar" - self._obj[value].polar.pcolor(**kwargs) - - def fieldplot( - self, - fr, - fth, - start_points=None, - sample=None, - invert_x=False, - invert_y=False, - **kwargs, - ): - """ - Plot field lines of a vector field defined by functions fr and fth. - - Parameters - ---------- - fr : string - Radial component of the vector field. - fth : string - Azimuthal component of the vector field. - start_points : array_like, optional - Starting points for the field lines. Either this or `sample` must be specified. - sample : dict, optional - Sampling template for generating starting points. Either this or `start_points` must be specified. - The template can be "dipole" or "monopole". The dict also contains the starting `radius`, - and the number of points in theta `nth` key. - invert_x : bool, optional - Whether to invert the x-axis. Default is False. - invert_y : bool, optional - Whether to invert the y-axis. Default is False. - **kwargs : - Additional keyword arguments passed to `fieldlines` and `ax.plot`. - - Raises - ------ - ValueError - If neither `start_points` nor `sample` are specified or if an unknown sampling template is given. - - Returns - ------- - None - - Examples - -------- - >>> ds.polar.fieldplot("Br", "Bth", sample={"template": "dipole", "nth": 30, "radius": 2.0}) - """ - import matplotlib.pyplot as plt - - if start_points is None and sample is None: - raise ValueError("Either start_points or sample must be specified") - elif start_points is None: - radius = sample.pop("radius", 1.5) - template = sample.pop("template", "dipole") - if template == "dipole": - start_points = [[radius, th] for th in DipoleSampling(**sample)] - elif template == "monopole": - start_points = [[radius, th] for th in MonopoleSampling(**sample)] - else: - raise ValueError("Unknown sampling template: " + template) - - fieldlines = self.fieldlines(fr, fth, start_points, **kwargs).compute() - ax = kwargs.pop("ax", plt.gca()) - for fieldline in fieldlines: - if invert_x: - fieldline[:, 0] = -fieldline[:, 0] - if invert_y: - fieldline[:, 1] = -fieldline[:, 1] - ax.plot(*fieldline.T, **kwargs) - - @dask.delayed - def fieldlines(self, fr, fth, start_points, **kwargs): - """ - Compute field lines of a vector field defined by functions fr and fth. - - Parameters - ---------- - fr : string - Radial component of the vector field. - fth : string - Azimuthal component of the vector field. - start_points : array_like - Starting points for the field lines. - direction : str, optional - Direction to integrate in. Can be "both", "forward" or "backward". Default is "both". - stopWhen : callable, optional - Function that takes the current position and returns True if the integration should stop. Default is to never stop. - ds : float, optional - Integration step size. Default is 0.1. - maxsteps : int, optional - Maximum number of integration steps. Default is 1000. - - Returns - ------- - list - List of field lines. - Examples - -------- - >>> ds.polar.fieldlines("Br", "Bth", [[2.0, np.pi / 4], [2.0, 3 * np.pi / 4]], stopWhen = lambda xy, rth: rth[0] > 5.0) - """ + def __init__(self, **kwargs): + super(Data, self).__init__(**kwargs) - import numpy as np - from scipy.interpolate import RegularGridInterpolator - - assert "t" not in self._obj[fr].dims, "Time must be specified" - assert "t" not in self._obj[fth].dims, "Time must be specified" - assert DataIs2DPolar(self._obj), "Data must be 2D polar" - - r, th = ( - self._obj.coords["r"].values, - self._obj.coords["θ" if useGreek else "th"].values, - ) - _, ths = np.meshgrid(r, th) - fxs = self._obj[fr] * np.sin(ths) + self._obj[fth] * np.cos(ths) - fys = self._obj[fr] * np.cos(ths) - self._obj[fth] * np.sin(ths) - - props = dict(method="nearest", bounds_error=False, fill_value=0) - interpFx = RegularGridInterpolator((th, r), fxs.values, **props) - interpFy = RegularGridInterpolator((th, r), fys.values, **props) - return [ - self._fieldline(interpFx, interpFy, rth, **kwargs) for rth in start_points - ] - - def _fieldline(self, interp_fx, interp_fy, r_th_start, **kwargs): - import numpy as np - from copy import copy - - direction = kwargs.pop("direction", "both") - stopWhen = kwargs.pop("stopWhen", lambda xy, rth: False) - ds = kwargs.pop("ds", 0.1) - maxsteps = kwargs.pop("maxsteps", 1000) - - rmax = self._obj.r.max() - rmin = self._obj.r.min() - - stop = ( - lambda xy, rth: stopWhen(xy, rth) - or (rth[0] < rmin) - or (rth[0] > rmax) - or (rth[1] < 0) - or (rth[1] > np.pi) - ) - - def integrate(delta, counter): - r0, th0 = copy(r_th_start) - XY = np.array([r0 * np.sin(th0), r0 * np.cos(th0)]) - RTH = [r0, th0] - fieldline = np.array([XY]) - with np.errstate(divide="ignore", invalid="ignore"): - while range(counter, maxsteps): - x, y = XY - r = np.sqrt(x**2 + y**2) - th = np.arctan2(-y, x) + np.pi / 2 - RTH = [r, th] - vx = interp_fx((th, r))[()] - vy = interp_fy((th, r))[()] - vmag = np.sqrt(vx**2 + vy**2) - XY = XY + delta * np.array([vx, vy]) / vmag - if stop(XY, RTH) or np.isnan(XY).any() or np.isinf(XY).any(): - break - else: - fieldline = np.append(fieldline, [XY], axis=0) - return fieldline - - if direction == "forward": - return integrate(ds, 0) - elif direction == "backward": - return integrate(-ds, 0) - else: - cntr = 0 - f1 = integrate(ds, cntr) - f2 = integrate(-ds, cntr) - return np.append(f2[::-1], f1, axis=0) - - -@xr.register_dataarray_accessor("polar") -class PolarPlotAccessor: - def __init__(self, xarray_obj): - self._obj = xarray_obj - - def pcolor(self, **kwargs): - """ - Plots a pseudocolor plot of 2D polar data on a rectilinear projection. - - Parameters - ---------- - ax : Axes object, optional - The axes on which to plot. Default is the current axes. - cell_centered : bool, optional - Whether the data is cell-centered. Default is True. - cell_size : float, optional - If not cell_centered, defines the fraction of the cell to use for coloring. Default is 0.75. - cbar_size : str, optional - The size of the colorbar. Default is "5%". - cbar_pad : float, optional - The padding between the colorbar and the plot. Default is 0.05. - cbar_position : str, optional - The position of the colorbar. Default is "right". - cbar_ticksize : int or float, optional - The size of the ticks on the colorbar. Default is None. - title : str, optional - The title of the plot. Default is None. - invert_x : bool, optional - Whether to invert the x-axis. Default is False. - invert_y : bool, optional - Whether to invert the y-axis. Default is False. - ylabel : str, optional - The label for the y-axis. Default is "y". - xlabel : str, optional - The label for the x-axis. Default is "x". - label : str, optional - The label for the plot. Default is None. - - Returns - ------- - matplotlib.collections.Collection - The pseudocolor plot. - - Raises - ------ - AssertionError - If `ax` is a polar projection or if time is not specified or if data is not 2D polar. - - Notes - ----- - Additional keyword arguments are passed to `pcolormesh`. - """ - - import numpy as np - import matplotlib.pyplot as plt - import matplotlib as mpl - from mpl_toolkits.axes_grid1 import make_axes_locatable - - ax = kwargs.pop("ax", plt.gca()) - cbar_size = kwargs.pop("cbar_size", "5%") - cbar_pad = kwargs.pop("cbar_pad", 0.05) - cbar_pos = kwargs.pop("cbar_position", "right") - cbar_orientation = ( - "vertical" if cbar_pos == "right" or cbar_pos == "left" else "horizontal" - ) - cbar_ticksize = kwargs.pop("cbar_ticksize", None) - title = kwargs.pop("title", None) - invert_x = kwargs.pop("invert_x", False) - invert_y = kwargs.pop("invert_y", False) - ylabel = kwargs.pop("ylabel", "y") - xlabel = kwargs.pop("xlabel", "x") - label = kwargs.pop("label", None) - cell_centered = kwargs.pop("cell_centered", True) - cell_size = kwargs.pop("cell_size", 0.75) - - assert ax.name != "polar", "`ax` must be a rectilinear projection" - assert "t" not in self._obj.dims, "Time must be specified" - assert DataIs2DPolar(self._obj), "Data must be 2D polar" - ax.grid(False) - if type(kwargs.get("norm", None)) == mpl.colors.LogNorm: - cm = kwargs.get("cmap", "viridis") - cm = mpl.colormaps[cm] - cm.set_bad(cm(0)) - kwargs["cmap"] = cm - - vals = self._obj.values.flatten() - vals = np.concatenate((vals, vals)) - if not cell_centered: - drs = self._obj.coords["r_2"] - self._obj.coords["r_1"] - dths = ( - self._obj.coords["θ_2" if useGreek else "th_2"] - - self._obj.coords["θ_1" if useGreek else "th_1"] - ) - r1s = self._obj.coords["r_1"] - drs * cell_size / 2 - r2s = self._obj.coords["r_1"] + drs * cell_size / 2 - th1s = ( - self._obj.coords["θ_1" if useGreek else "th_1"] - dths * cell_size / 2 - ) - th2s = ( - self._obj.coords["θ_1" if useGreek else "th_1"] + dths * cell_size / 2 - ) - rs = np.ravel(np.column_stack((r1s, r2s))) - ths = np.ravel(np.column_stack((th1s, th2s))) - nr = len(rs) - nth = len(ths) - rs, ths = np.meshgrid(rs, ths) - rs = rs.flatten() - ths = ths.flatten() - points_1 = np.arange(nth * nr).reshape(nth, -1)[:-1:2, :-1:2].flatten() - points_2 = np.arange(nth * nr).reshape(nth, -1)[:-1:2, 1::2].flatten() - points_3 = np.arange(nth * nr).reshape(nth, -1)[1::2, 1::2].flatten() - points_4 = np.arange(nth * nr).reshape(nth, -1)[1::2, :-1:2].flatten() - - else: - rs = np.append(self._obj.coords["r_1"], self._obj.coords["r_2"][-1]) - ths = np.append( - self._obj.coords["θ_1" if useGreek else "th_1"], - self._obj.coords["θ_2" if useGreek else "th_2"][-1], - ) - nr = len(rs) - nth = len(ths) - rs, ths = np.meshgrid(rs, ths) - rs = rs.flatten() - ths = ths.flatten() - points_1 = np.arange(nth * nr).reshape(nth, -1)[:-1, :-1].flatten() - points_2 = np.arange(nth * nr).reshape(nth, -1)[:-1, 1:].flatten() - points_3 = np.arange(nth * nr).reshape(nth, -1)[1:, 1:].flatten() - points_4 = np.arange(nth * nr).reshape(nth, -1)[1:, :-1].flatten() - x, y = rs * np.sin(ths), rs * np.cos(ths) - if invert_x: - x = -x - if invert_y: - y = -y - triang = mpl.tri.Triangulation( - x, - y, - triangles=np.concatenate( - [ - np.array([points_1, points_2, points_3]).T, - np.array([points_1, points_3, points_4]).T, - ], - axis=0, - ), + def __repr__(self) -> str: + return ( + self.print_container() + + "\n" + + self.print_fields() + + "\n" + + self.print_particles() + + "\n" + + self.print_spectra() ) - ax.set( - aspect="equal", - xlabel=xlabel, - ylabel=ylabel, - ) - im = ax.tripcolor(triang, vals, rasterized=True, shading="flat", **kwargs) - if cbar_pos is not None: - divider = make_axes_locatable(ax) - cax = divider.append_axes(cbar_pos, size=cbar_size, pad=cbar_pad) - _ = plt.colorbar( - im, - cax=cax, - label=self._obj.name if label is None else label, - orientation=cbar_orientation, - ) - if cbar_orientation == "vertical": - axis = cax.yaxis - else: - axis = cax.xaxis - axis.set_label_position(cbar_pos) - axis.set_ticks_position(cbar_pos) - if cbar_ticksize is not None: - cax.tick_params("both", labelsize=cbar_ticksize) - ax.set_title( - f"t={self._obj.coords['t'].values[()]:.2f}" if title is None else title - ) - return im - - def contour(self, **kwargs): - """ - Plots a pseudocolor plot of 2D polar data on a rectilinear projection. - - Parameters - ---------- - ax : Axes object, optional - The axes on which to plot. Default is the current axes. - invert_x : bool, optional - Whether to invert the x-axis. Default is False. - invert_y : bool, optional - Whether to invert the y-axis. Default is False. - - Returns - ------- - matplotlib.contour.QuadContourSet - The contour plot. - - Raises - ------ - AssertionError - If `ax` is a polar projection or if time is not specified or if data is not 2D polar. - - Notes - ----- - Additional keyword arguments are passed to `contour`. - """ - - import warnings - import numpy as np - import matplotlib.pyplot as plt - import matplotlib as mpl - from mpl_toolkits.axes_grid1 import make_axes_locatable - ax = kwargs.pop("ax", plt.gca()) - title = kwargs.pop("title", None) - invert_x = kwargs.pop("invert_x", False) - invert_y = kwargs.pop("invert_y", False) - - assert ax.name != "polar", "`ax` must be a rectilinear projection" - assert "t" not in self._obj.dims, "Time must be specified" - assert DataIs2DPolar(self._obj), "Data must be 2D polar" - ax.grid(False) - r, th = np.meshgrid( - self._obj.coords["r"], self._obj.coords["θ" if useGreek else "th"] - ) - x, y = r * np.sin(th), r * np.cos(th) - if invert_x: - x = -x - if invert_y: - y = -y - ax.set( - aspect="equal", - ) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - im = ax.contour(x, y, self._obj.values, **kwargs) - - return im - - -class Metric: - def __init__(self, base): - self.base = base - - -class MinkowskiMetric(Metric): - def __init__(self): - super().__init__("minkowski") - - def sqrt_h(self, **coords): - return 1 - - def h_11(self, **coords): - return 1 - - def h_22(self, **coords): - return 1 - - def h_33(self, **coords): - return 1 - - -class SphericalMetric(Metric): - def __init__(self): - super().__init__("spherical") - - def sqrt_h(self, r, th): - import numpy as np - - return r**2 * np.sin(th) - - def h_11(self, r, th): - return 1 - - def h_22(self, r, th): - return r**2 - - def h_33(self, r, th): - import numpy as np - - return r**2 * np.sin(th) ** 2 - - -class Data: - """ - A class to load data from the Entity single-HDF5 file and store it as a lazily loaded xarray Dataset. - - Parameters - ---------- - fname : str - The name of the HDF5 file to read. - - Attributes - ---------- - fname : str - The name of the HDF5 file. - file : h5py.File - The HDF5 file object. - dataset : xr.Dataset - The xarray Dataset containing the loaded data. - particles: list - The list of particle species in the simulation. Each element is an Xarray Dataset. - - Methods - ------- - __del__() - Closes the HDF5 file. - __getattr__(name) - Gets an attribute from the xarray Dataset. - __getitem__(name) - Gets an item from the xarray Dataset. - - Examples - -------- - >>> import nt2.read as nt2r - >>> data = nt2r.Data("Sim.h5") - >>> data.Bx.sel(t=10.0, method="nearest").plot() - """ - - def __init__(self, fname): - if usePickle: - import h5pickle as h5py - else: - import h5py - import dask.array as da - from functools import reduce - import numpy as np - - QuantityDict = { - "Ttt": "E", - "Ttx": "Px", - "Tty": "Py", - "Ttz": "Pz", - } - CoordinateDict = { - "cart": {"x": "x", "y": "y", "z": "z", "1": "x", "2": "y", "3": "z"}, - "sph": { - "r": "r", - "theta": "θ" if useGreek else "th", - "phi": "φ" if useGreek else "ph", - "1": "r", - "2": "θ" if useGreek else "th", - "3": "φ" if useGreek else "ph", - }, - } - PrtlDict = { - "cart": { - "X1": "x", - "X2": "y", - "X3": "z", - "U1": "ux", - "U2": "uy", - "U3": "uz", - }, - "sph": { - "X1": "r", - "X2": "θ" if useGreek else "th", - "X3": "φ" if useGreek else "ph", - "U1": "ur", - "U2": "uΘ" if useGreek else "uth", - "U3": "uφ" if useGreek else "uph", - }, - } - self.fname = fname - try: - self.file = h5py.File(self.fname, "r") - except OSError: - raise OSError(f"Could not open file {self.fname}") - ngh = int(self.file.attrs["NGhosts"]) - layout = "right" if self.file.attrs["LayoutRight"] == 1 else "left" - dimension = int(self.file.attrs["Dimension"]) - coordinates = self.file.attrs["Coordinates"].decode("UTF-8") - if coordinates == "qsph": - coordinates = "sph" - if coordinates == "sph": - self.metric = SphericalMetric() - else: - self.metric = MinkowskiMetric() - coords = list(CoordinateDict[coordinates].values())[::-1][-dimension:] - - for s in self.file.keys(): - if any([k.startswith("X") for k in self.file[s].keys()]): - # cell-centered coords - cc_coords = { - c: self.file[s][f"X{i+1}"] for i, c in enumerate(coords[::-1]) - } - # cell edges - cell_1 = { - f"{c}_1": ( - c, - self.file[s][f"X{i+1}e"][:-1], - ) - for i, c in enumerate(coords[::-1]) - } - cell_2 = { - f"{c}_2": ( - c, - self.file[s][f"X{i+1}e"][1:], - ) - for i, c in enumerate(coords[::-1]) - } - break - - if dimension == 1: - noghosts = slice(ngh, -ngh) if ngh > 0 else slice(None) - elif dimension == 2: - noghosts = (slice(ngh, -ngh), slice(ngh, -ngh)) if ngh > 0 else slice(None) - elif dimension == 3: - noghosts = ( - (slice(ngh, -ngh), slice(ngh, -ngh), slice(ngh, -ngh)) - if ngh > 0 - else slice(None) - ) - - self.dataset = xr.Dataset() - - # -------------------------------- load fields ------------------------------- # - fields = None - f_outsteps = [] - f_steps = [] - f_times = [] - for s in self.file.keys(): - if any([k.startswith("f") for k in self.file[s].keys()]): - if fields is None: - fields = [k for k in self.file[s].keys() if k.startswith("f")] - f_outsteps.append(s) - f_times.append(self.file[s]["Time"][()]) - f_steps.append(self.file[s]["Step"][()]) - - f_outsteps = sorted(f_outsteps, key=lambda x: int(x.replace("Step", ""))) - f_steps = sorted(f_steps) - f_times = np.array(sorted(f_times), dtype=np.float64) - - for k in self.file.attrs.keys(): - if ( - type(self.file.attrs[k]) == bytes - or type(self.file.attrs[k]) == np.bytes_ - ): - self.dataset.attrs[k] = self.file.attrs[k].decode("UTF-8") - else: - self.dataset.attrs[k] = self.file.attrs[k] - - for k in fields: - dask_arrays = [] - for s in f_outsteps: - array = da.from_array( - np.transpose(self.file[f"{s}/{k}"]) - if layout == "right" - else self.file[f"{s}/{k}"] - ) - dask_arrays.append(array[noghosts]) - - k_ = reduce( - lambda x, y: ( - x.replace(*y) - if "_" not in x - else "_".join([x.split("_")[0].replace(*y)] + x.split("_")[1:]) - ), - [k, *list(CoordinateDict[coordinates].items())], - ) - k_ = reduce( - lambda x, y: x.replace(*y), - [k_, *list(QuantityDict.items())], - )[1:] - x = xr.DataArray( - da.stack(dask_arrays, axis=0), - dims=["t", *coords], - name=k_, - coords={ - "t": f_times, - "s": ("t", f_steps), - **cc_coords, - **cell_1, - **cell_2, - }, - ) - self.dataset[k_] = x - - # ------------------------------ load particles ------------------------------ # - particles = None - p_outsteps = [] - p_steps = [] - p_times = [] - for s in self.file.keys(): - if any([k.startswith("p") for k in self.file[s].keys()]): - if particles is None: - particles = [k for k in self.file[s].keys() if k.startswith("p")] - p_outsteps.append(s) - p_times.append(self.file[s]["Time"][()]) - p_steps.append(self.file[s]["Step"][()]) - - p_outsteps = sorted(p_outsteps, key=lambda x: int(x.replace("Step", ""))) - p_steps = sorted(p_steps) - p_times = np.array(sorted(p_times), dtype=np.float64) - - self._particles = {} - - if len(p_outsteps) > 0: - species = np.unique( - [ - int(pq.split("_")[1]) - for pq in self.file[p_outsteps[0]].keys() - if pq.startswith("p") - ] - ) - - def list_to_ragged(arr): - max_len = np.max([len(a) for a in arr]) - return map( - lambda a: np.concatenate([a, np.full(max_len - len(a), np.nan)]), - arr, - ) - - for s in species: - prtl_data = {} - for q in [ - f"X1_{s}", - f"X2_{s}", - f"X3_{s}", - f"U1_{s}", - f"U2_{s}", - f"U3_{s}", - f"W_{s}", - ]: - if q[0] in ["X", "U"]: - q_ = PrtlDict[coordinates][q.split("_")[0]] - else: - q_ = q.split("_")[0] - if "p" + q not in particles: - continue - if q not in prtl_data.keys(): - prtl_data[q_] = [] - for step_k in p_outsteps: - if "p" + q in self.file[step_k].keys(): - prtl_data[q_].append(self.file[step_k]["p" + q]) - else: - prtl_data[q_].append( - np.full_like(prtl_data[q_][-1], np.nan) - ) - prtl_data[q_] = list_to_ragged(prtl_data[q_]) - prtl_data[q_] = da.from_array(list(prtl_data[q_])) - prtl_data[q_] = xr.DataArray( - prtl_data[q_], - dims=["t", "id"], - name=q_, - coords={"t": p_times, "s": ("t", p_steps)}, - ) - if coordinates == "sph": - prtl_data["x"] = ( - prtl_data[PrtlDict[coordinates]["X1"]] - * np.sin(prtl_data[PrtlDict[coordinates]["X2"]]) - * np.cos(prtl_data[PrtlDict[coordinates]["X3"]]) - ) - prtl_data["y"] = ( - prtl_data[PrtlDict[coordinates]["X1"]] - * np.sin(prtl_data[PrtlDict[coordinates]["X2"]]) - * np.sin(prtl_data[PrtlDict[coordinates]["X3"]]) - ) - prtl_data["z"] = prtl_data[PrtlDict[coordinates]["X1"]] * np.cos( - prtl_data[PrtlDict[coordinates]["X2"]] - ) - self._particles[s] = xr.Dataset(prtl_data) - - # ------------------------------- load spectra ------------------------------- # - spectra = None - s_outsteps = [] - s_steps = [] - s_times = [] - for s in self.file.keys(): - if any([k.startswith("s") for k in self.file[s].keys()]): - if spectra is None: - spectra = [k for k in self.file[s].keys() if k.startswith("s")] - s_outsteps.append(s) - s_times.append(self.file[s]["Time"][()]) - s_steps.append(self.file[s]["Step"][()]) - - s_outsteps = sorted(s_outsteps, key=lambda x: int(x.replace("Step", ""))) - s_steps = sorted(s_steps) - s_times = np.array(sorted(s_times), dtype=np.float64) - - self._spectra = xr.Dataset() - log_bins = self.file.attrs["output.spectra.log_bins"] - - if len(s_outsteps) > 0: - species = np.unique( - [ - int(pq.split("_")[1]) - for pq in self.file[s_outsteps[0]].keys() - if pq.startswith("sN") - ] - ) - e_bins = self.file[s_outsteps[0]]["sEbn"] - if log_bins: - e_bins = np.sqrt(e_bins[1:] * e_bins[:-1]) - else: - e_bins = (e_bins[1:] + e_bins[:-1]) / 2 - - for sp in species: - dask_arrays = [] - for st in s_outsteps: - array = da.from_array(self.file[f"{st}/sN_{sp}"]) - dask_arrays.append(array) - - x = xr.DataArray( - da.stack(dask_arrays, axis=0), - dims=["t", "e"], - name=f"n_{sp}", - coords={ - "t": s_times, - "s": ("t", s_steps), - "e": e_bins, - }, - ) - self._spectra[f"n_{sp}"] = x + def __str__(self) -> str: + return self.__repr__() def __del__(self): - self.file.close() - - def __getattr__(self, name): - return getattr(self.dataset, name) - - def __getitem__(self, name): - return self.dataset[name] - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.file.close() - - @property - def particles(self): - return self._particles - - @property - def spectra(self): - return self._spectra - - def plotGrid(self, ax, **kwargs): - import matplotlib as mpl - import numpy as np - - coordinates = self.file.attrs["Coordinates"].decode("UTF-8") - if coordinates == "qsph": - coordinates = "sph" - - xlim, ylim = ax.get_xlim(), ax.get_ylim() - options = { - "lw": 1, - "color": "k", - "ls": "-", - } - options.update(kwargs) - - if coordinates == "cart": - for x in self.attrs["X1"]: - ax.plot([x, x], [self.attrs["X2Min"], self.attrs["X2Max"]], **options) - for y in self.attrs["X2"]: - ax.plot([self.attrs["X1Min"], self.attrs["X1Max"]], [y, y], **options) - else: - for r in self.attrs["X1"]: - ax.add_patch( - mpl.patches.Arc( - (0, 0), - 2 * r, - 2 * r, - theta1=-90, - theta2=90, - fill=False, - **options, - ) - ) - for th in self.attrs["X2"]: - ax.plot( - [ - self.attrs["X1Min"] * np.sin(th), - self.attrs["X1Max"] * np.sin(th), - ], - [ - self.attrs["X1Min"] * np.cos(th), - self.attrs["X1Max"] * np.cos(th), - ], - **options, - ) - ax.set(xlim=xlim, ylim=ylim) - - def makeMovie(self, plot, makeframes=True, **kwargs): - """ - Makes a movie from a plot function - - Parameters - ---------- - plot : function - The plot function to use; accepts output timestep and dataset as arguments. - makeframes : bool, optional - Whether to make the frames, or just proceed to making the movie. Default is True. - num_cpus : int, optional - The number of CPUs to use for making the frames. Default is None. - **kwargs : - Additional keyword arguments passed to `ffmpeg`. - """ - import numpy as np - - if makeframes: - makemovie = all( - exp.makeFrames( - plot, - np.arange(len(self.t)), - f"{self.attrs['simulation.name']}/frames", - data=self, - num_cpus=kwargs.pop("num_cpus", None), - ) - ) - else: - makemovie = True - if makemovie: - exp.makeMovie( - input=f"{self.attrs['simulation.name']}/frames/", - overwrite=True, - output=f"{self.attrs['simulation.name']}.mp4", - number=5, - **kwargs, - ) - return True + self.client.close() + super().__del__() diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 0000000..bdfd610 --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,3 @@ +{ + "extraPath": ["./"], +}