Skip to content

Commit

Permalink
Fixes plotting point clouds from 1-D data
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesVarndell committed Jun 25, 2024
1 parent 2265c65 commit 5e5351a
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 14 deletions.
5 changes: 5 additions & 0 deletions src/earthkit/plots/quickmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ def scatter(*args, **kwargs):
"""Quick plot"""


@_quickmap
def point_cloud(*args, **kwargs):
"""Quick plot"""


@_quickmap
def block(*args, **kwargs):
"""Quick plot"""
Expand Down
2 changes: 1 addition & 1 deletion src/earthkit/plots/sources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_source(*args, data=None, x=None, y=None, z=None, u=None, v=None, **kwarg
if len(args) == 1 and core_data is None:
core_data = args[0]
if core_data is not None:
if data.__class__.__name__ in ("Dataset", "DataArray"):
if core_data.__class__.__name__ in ("Dataset", "DataArray"):
cls = XarraySource
elif isinstance(core_data, ek_data.core.Base):
cls = EarthkitSource
Expand Down
5 changes: 4 additions & 1 deletion src/earthkit/plots/sources/earthkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ def metadata(self, key, default=None):
"""
value = super().metadata(key, default)
if value == default:
value = self.data.metadata(key, default=default)
try:
value = self.data.metadata(key, default=default)
except NotImplementedError:
pass
return value

def datetime(self, *args, **kwargs):
Expand Down
21 changes: 10 additions & 11 deletions src/earthkit/plots/sources/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,18 +190,17 @@ def y_values(self):
def z_values(self):
"""The z values of the data."""
values = None
if len(self.dims) > 1:
if self._z is None:
if not hasattr(self.data, "data_vars"):
data = self.data
else:
data = self.data[list(self.data.data_vars)[0]]
if self._z is None:
if not hasattr(self.data, "data_vars"):
data = self.data
else:
data = self.data[self._z]
values = data.values
x, y = self.extract_xy()
if [y, x] != [c for c in data.dims if c in [y, x]]:
values = values.T
data = self.data[list(self.data.data_vars)[0]]
else:
data = self.data[self._z]
values = data.values
# x, y = self.extract_xy()
# if [y, x] != [c for c in data.dims if c in [y, x]]:
# values = values.T

return values

Expand Down
7 changes: 6 additions & 1 deletion src/earthkit/plots/styles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,12 @@ def scatter(self, ax, x, y, values, s=3, *args, **kwargs):
if values is not None:
kwargs = {**self.to_scatter_kwargs(values), **kwargs}
kwargs.pop("extend", None)
if values is not None and missing_values is not None and np.isnan(values).any():
if (
values is not None
and missing_values is not None
and np.isnan(values).any()
and len(values.shape) > 1
):
missing_idx = np.where(np.isnan(values))
missing_x = x[missing_idx]
missing_y = y[missing_idx]
Expand Down

0 comments on commit 5e5351a

Please sign in to comment.