Skip to content

Commit

Permalink
multifile support
Browse files Browse the repository at this point in the history
  • Loading branch information
haykh committed Nov 6, 2024
1 parent c2c92cf commit 8f142cd
Show file tree
Hide file tree
Showing 6 changed files with 531 additions and 341 deletions.
35 changes: 21 additions & 14 deletions nt2/containers/container.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import h5py
import numpy as np
from typing import Any
Expand Down Expand Up @@ -40,22 +41,29 @@ def __init__(
except OSError:
raise OSError(f"Could not open file {self.path}")
else:
self.master_file: h5py.File | None = None
raise NotImplementedError("Multiple files not yet supported")
field_path = os.path.join(self.path, "fields")
file = os.path.join(field_path, os.listdir(field_path)[0])
try:
self.master_file: h5py.File | None = h5py.File(file, "r")
except OSError:
raise OSError(f"Could not open file {file}")

self.attrs = _read_attribs_SingleFile(self.master_file)

if self.configs["single_file"]:
self.configs["ngh"] = int(self.master_file.attrs.get("NGhosts", 0))
self.configs["layout"] = (
"right" if self.master_file.attrs.get("LayoutRight", 1) == 1 else "left"
)
self.configs["dimension"] = int(self.master_file.attrs.get("Dimension", 1))
self.configs["coordinates"] = self.master_file.attrs.get(
"Coordinates", b"cart"
).decode("UTF-8")
if self.configs["coordinates"] == "qsph":
self.configs["coordinates"] = "sph"
self.configs["ngh"] = int(self.master_file.attrs.get("NGhosts", 0))
self.configs["layout"] = (
"right" if self.master_file.attrs.get("LayoutRight", 1) == 1 else "left"
)
self.configs["dimension"] = int(self.master_file.attrs.get("Dimension", 1))
self.configs["coordinates"] = self.master_file.attrs.get(
"Coordinates", b"cart"
).decode("UTF-8")
if self.configs["coordinates"] == "qsph":
self.configs["coordinates"] = "sph"

if not self.configs["single_file"]:
self.master_file.close()
self.master_file = None
# if coordinates == "sph":
# self.metric = SphericalMetric()
# else:
Expand Down Expand Up @@ -107,7 +115,6 @@ def plotGrid(self, ax, **kwargs):
def print_container(self) -> str:
return f"Client {self.client}\n"

#
# def makeMovie(self, plot, makeframes=True, **kwargs):
# """
# Makes a movie from a plot function
Expand Down
188 changes: 50 additions & 138 deletions nt2/containers/fields.py
Original file line number Diff line number Diff line change
@@ -1,115 +1,13 @@
import os
import h5py
import xarray as xr
import numpy as np
from dask.array.core import from_array
from dask.array.core import stack

from nt2.containers.container import Container
from nt2.containers.utils import _read_category_metadata_SingleFile


def _read_coordinates_SingleFile(coords: list[str], file: h5py.File):
for st in file:
group = file[st]
if isinstance(group, h5py.Group):
if any([k.startswith("X") for k in group if k is not None]):
# cell-centered coords
xc = {
c: (
np.asarray(xi[:])
if isinstance(xi := group[f"X{i+1}"], h5py.Dataset) and xi
else None
)
for i, c in enumerate(coords[::-1])
}
# cell edges
xe_min = {
f"{c}_1": (
c,
(
np.asarray(xi[:-1])
if isinstance((xi := group[f"X{i+1}e"]), h5py.Dataset)
else None
),
)
for i, c in enumerate(coords[::-1])
}
xe_max = {
f"{c}_2": (
c,
(
np.asarray(xi[1:])
if isinstance((xi := group[f"X{i+1}e"]), h5py.Dataset)
else None
),
)
for i, c in enumerate(coords[::-1])
}
return {"x_c": xc, "x_emin": xe_min, "x_emax": xe_max}
else:
raise ValueError(f"Unexpected type {type(file[st])}")
raise ValueError("Could not find coordinates in file")


def _preload_field_SingleFile(
k: str,
dim: int,
ngh: int,
outsteps: list[int],
times: list[float],
steps: list[int],
coords: list[str],
xc_coords: dict[str, str],
xe_min_coords: dict[str, str],
xe_max_coords: dict[str, str],
coord_replacements: list[tuple[str, str]],
field_replacements: list[tuple[str, str]],
layout: str,
file: h5py.File,
):
if dim == 1:
noghosts = slice(ngh, -ngh) if ngh > 0 else slice(None)
elif dim == 2:
noghosts = (slice(ngh, -ngh), slice(ngh, -ngh)) if ngh > 0 else slice(None)
elif dim == 3:
noghosts = (
(slice(ngh, -ngh), slice(ngh, -ngh), slice(ngh, -ngh))
if ngh > 0
else slice(None)
)
else:
raise ValueError("Invalid dimension")

dask_arrays = []
for s in outsteps:
dset = file[f"{s}/{k}"]
if isinstance(dset, h5py.Dataset):
array = from_array(np.transpose(dset) if layout == "right" else dset)
dask_arrays.append(array[noghosts])
else:
raise ValueError(f"Unexpected type {type(dset)}")

k_ = k[1:]
for c in coord_replacements:
if "_" not in k_:
k_ = k_.replace(c[0], c[1])
else:
k_ = "_".join([k_.split("_")[0].replace(c[0], c[1])] + k_.split("_")[1:])
for f in field_replacements:
k_ = k_.replace(*f)

return k_, xr.DataArray(
stack(dask_arrays, axis=0),
dims=["t", *coords],
name=k_,
coords={
"t": times,
"s": ("t", steps),
**xc_coords,
**xe_min_coords,
**xe_max_coords,
},
)
from nt2.containers.utils import (
_read_category_metadata,
_read_coordinates,
_preload_field,
)


class FieldsContainer(Container):
Expand All @@ -134,50 +32,64 @@ def __init__(self, **kwargs):
}
if self.configs["single_file"]:
assert self.master_file is not None, "Master file not found"
self.metadata["fields"] = _read_category_metadata_SingleFile(
"f", self.master_file
self.metadata["fields"] = _read_category_metadata(
True, "f", self.master_file
)
else:
field_path = os.path.join(self.path, "fields")
files = sorted(os.listdir(field_path))
try:
raise NotImplementedError("Multiple files not yet supported")
self.fields_files = [
h5py.File(os.path.join(field_path, f), "r") for f in files
]
except OSError:
raise OSError(f"Could not open file {self.path}")
raise OSError(f"Could not open file in {field_path}")
self.metadata["fields"] = _read_category_metadata(
False, "f", self.fields_files
)

coords = list(CoordinateDict[self.configs["coordinates"]].values())[::-1][
-self.configs["dimension"] :
]

if self.configs["single_file"]:
self.mesh = _read_coordinates_SingleFile(coords, self.master_file)
assert self.master_file is not None, "Master file not found"
self.mesh = _read_coordinates(coords, self.master_file)
else:
raise NotImplementedError("Multiple files not yet supported")
self.mesh = _read_coordinates(coords, self.fields_files[0])

self.fields = xr.Dataset()
self._fields = xr.Dataset()

if len(self.metadata["fields"]["outsteps"]) > 0:
if self.configs["single_file"]:
for k in self.metadata["fields"]["quantities"]:
name, dset = _preload_field_SingleFile(
k,
dim=self.configs["dimension"],
ngh=self.configs["ngh"],
outsteps=self.metadata["fields"]["outsteps"],
times=self.metadata["fields"]["times"],
steps=self.metadata["fields"]["steps"],
coords=coords,
xc_coords=self.mesh["x_c"],
xe_min_coords=self.mesh["x_emin"],
xe_max_coords=self.mesh["x_emax"],
coord_replacements=list(
CoordinateDict[self.configs["coordinates"]].items()
),
field_replacements=list(QuantityDict.items()),
layout=self.configs["layout"],
file=self.master_file,
)
self.fields[name] = dset
else:
raise NotImplementedError("Multiple files not yet supported")
for k in self.metadata["fields"]["quantities"]:
name, dset = _preload_field(
single_file=self.configs["single_file"],
k=k,
dim=self.configs["dimension"],
ngh=self.configs["ngh"],
outsteps=self.metadata["fields"]["outsteps"],
times=self.metadata["fields"]["times"],
steps=self.metadata["fields"]["steps"],
coords=coords,
xc_coords=self.mesh["x_c"],
xe_min_coords=self.mesh["x_emin"],
xe_max_coords=self.mesh["x_emax"],
coord_replacements=list(
CoordinateDict[self.configs["coordinates"]].items()
),
field_replacements=list(QuantityDict.items()),
layout=self.configs["layout"],
file=(
self.master_file
if self.configs["single_file"] and self.master_file is not None
else self.fields_files
),
)
self.fields[name] = dset

@property
def fields(self):
return self._fields

def __del__(self):
if self.configs["single_file"] and self.master_file is not None:
Expand Down
Loading

0 comments on commit 8f142cd

Please sign in to comment.