-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
247 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
"""Module for waterfall plotting.""" | ||
from __future__ import annotations | ||
|
||
from collections.abc import Sequence | ||
from typing import Literal | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
from dascore.constants import PatchType | ||
from dascore.utils.patch import patch_function | ||
from dascore.utils.plotting import ( | ||
_get_ax, | ||
_get_cmap, | ||
_get_dim_label, | ||
) | ||
|
||
|
||
def _set_scale(im, scale, scale_type, color_coords): | ||
"""Set the scale of the color bar based on scale and scale_type.""" | ||
# check scale paramters | ||
assert scale_type in {"absolute", "relative"} | ||
assert isinstance(scale, float | int) or len(scale) == 2 | ||
# make sure we have a len two array | ||
modifier = 1 | ||
if scale_type == "relative": | ||
modifier = 0.5 * (np.nanmax(color_coords) - np.nanmin(color_coords)) | ||
# only one scale parameter provided, center around mean | ||
if isinstance(scale, float): | ||
mean = np.nanmean(color_coords) | ||
scale = np.array([mean - scale * modifier, mean + scale * modifier]) | ||
im.set_clim(scale) | ||
|
||
|
||
@patch_function() | ||
def map_fiber( | ||
patch: PatchType, | ||
x: np.ndarray | str = "distance", | ||
y: np.ndarray | str = "distance", | ||
color: np.ndarray | str = "distance", | ||
ax: plt.Axes | None = None, | ||
cmap="cividis_r", | ||
scale: float | Sequence[float] | None = None, | ||
scale_type: Literal["relative", "absolute"] = "relative", | ||
show=False, | ||
) -> plt.Axes: | ||
""" | ||
Create a plot of the outline of the cable colorized by a given parameter. | ||
Parameters | ||
---------- | ||
patch | ||
The Patch object. | ||
x | ||
x coordinate: can be an array or a str representing a patch coordinate. | ||
y | ||
y coordinate: can be an array or a str representing a patch coordinate. | ||
color | ||
The color parameter to plot: can be an array or a str representing a patch | ||
attribute. | ||
ax | ||
A matplotlib object, if None create one. | ||
cmap | ||
A matplotlib colormap string or instance. Set to None to not plot the | ||
colorbar. | ||
scale | ||
If not None, controls the saturation level of the colorbar. | ||
Values can either be a float, to set upper and lower limit to the same | ||
value centered around the mean of the data, or a length 2 tuple | ||
specifying upper and lower limits. See `scale_type` for controlling how | ||
values are scaled. | ||
scale_type | ||
Controls the type of scaling specified by `scale` parameter. Options | ||
are: | ||
relative - scale based on half the dynamic range in patch | ||
absolute - scale based on absolute values provided to `scale` | ||
show | ||
If True, show the plot, else just return axis. | ||
Examples | ||
-------- | ||
>>> # Plot patch | ||
>>> import dascore as dc | ||
>>> patch = dc.get_example_patch("random_patch_with_lat_lon") | ||
>>> patch = patch.set_units(latitude="m", longitude="m") | ||
>>> _ = patch.viz.map_fiber("latitude", "longitude", "distance") | ||
""" | ||
dims = [] | ||
if isinstance(x, str): | ||
assert x in patch.coords, f"{x} not found in patch coordinates" | ||
dims.append(x) | ||
x = patch.coords.get_array(x) | ||
if isinstance(y, str): | ||
assert y in patch.coords, f"{y} not found in patch coordinates" | ||
dims.append(y) | ||
y = patch.coords.get_array(y) | ||
if isinstance(color, str): | ||
assert color in patch.coords, f"{color} not found in patch coordinates" | ||
data_type = color | ||
data_units = patch.attrs.coords[color].units | ||
color = patch.coords.get_array(color) | ||
else: | ||
data_type = "" | ||
data_units = "" | ||
|
||
ax = _get_ax(ax) | ||
cmap = _get_cmap(cmap) | ||
|
||
im = ax.scatter(x, y, c=color, cmap=cmap) | ||
|
||
# scale colorbar | ||
if scale is not None: | ||
_set_scale(im, scale, scale_type, color) | ||
|
||
# set axis labels | ||
for dim, x in zip(dims, ["x", "y"]): | ||
getattr(ax, f"set_{x}label")(_get_dim_label(patch, dim)) | ||
|
||
# add color bar with title | ||
if cmap is not None: | ||
cb = ax.get_figure().colorbar(im, ax=ax, fraction=0.05, pad=0.025) | ||
dunits = f" ({data_units})" if (data_type and data_units) else f"{data_units}" | ||
label = f"{data_type}{dunits}" | ||
cb.set_label(label) | ||
|
||
if show: | ||
plt.show() | ||
return ax |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
"""Tests for waterfall plots.""" | ||
from __future__ import annotations | ||
|
||
import matplotlib.pyplot as plt | ||
import pytest | ||
|
||
import dascore as dc | ||
from dascore.utils.time import is_datetime64 | ||
|
||
|
||
def check_label_units(patch, ax, dims, color="distance"): | ||
"""Ensure patch label units match axis.""" | ||
axis_dict = {0: "xaxis", 1: "yaxis"} | ||
# dims = [] | ||
# Check coord-inate names | ||
for coord_name in dims: | ||
coord = patch.coords.coord_map[coord_name] | ||
if is_datetime64(coord[0]): | ||
continue # just skip datetimes for now. | ||
index = dims.index(coord_name) | ||
axis = getattr(ax, axis_dict[index]) | ||
label_text = axis.get_label().get_text().lower() | ||
assert str(coord.units.units) in label_text | ||
assert coord_name in label_text | ||
# check colorbar labels | ||
# cax = ax.images[-1].colorbar | ||
coord = patch.coords.coord_map[color] | ||
yaxis_label = ax.figure.get_children()[-1].yaxis.label.get_text() | ||
assert str(coord.units.units) in yaxis_label | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def patch_random_start(event_patch_1): | ||
"""Get a patch with a random, odd, starttime.""" | ||
random_starttime = dc.to_datetime64("2020-01-02T02:12:11.02232") | ||
attrs = dict(event_patch_1.attrs) | ||
coords = {i: v for i, v in event_patch_1.coords.items()} | ||
time = coords["time"] - coords["time"].min() | ||
coords["time"] = time + random_starttime | ||
attrs["time_min"] = coords["time"].min() | ||
attrs["time_max"] = coords["time"].max() | ||
patch = event_patch_1.update(attrs=attrs, coords=coords) | ||
return patch | ||
|
||
|
||
class TestPlotMap: | ||
"""Tests for map plot.""" | ||
|
||
def test_str_input(self, random_patch_with_lat_lon): | ||
"""Call map_fiber plot, return.""" | ||
patch = random_patch_with_lat_lon.set_units(latitude="ft") | ||
patch = patch.set_units(longitude="m") | ||
ax = patch.viz.map_fiber("latitude", "longitude") | ||
|
||
caxis_label = ax.figure.get_children()[-1].yaxis.label.get_text() | ||
|
||
# check labels | ||
assert "latitude" in ax.get_xlabel().lower() | ||
assert "longitude" in ax.get_ylabel().lower() | ||
assert "distance" in caxis_label | ||
assert isinstance(ax, plt.Axes) | ||
|
||
def test_array_inputs(self, random_patch_with_lat_lon): | ||
"""Call map_fiber plot, return.""" | ||
lats = random_patch_with_lat_lon.coords.get_array("latitude") | ||
lons = random_patch_with_lat_lon.coords.get_array("longitude") | ||
data = 0.5 * (lats + lons) | ||
ax = random_patch_with_lat_lon.viz.map_fiber(lats, lons, data) | ||
|
||
assert isinstance(ax, plt.Axes) | ||
|
||
def test_default_parameters(self, random_patch): | ||
"""Call map_fiber plot, return.""" | ||
ax = random_patch.viz.map_fiber() | ||
|
||
# check labels | ||
assert "distance" in ax.get_ylabel().lower() | ||
assert "distance" in ax.get_xlabel().lower() | ||
assert isinstance(ax, plt.Axes) | ||
|
||
def test_colorbar_scale(self, random_patch): | ||
"""Tests for the scaling parameter.""" | ||
ax_scalar = random_patch.viz.map_fiber(scale=0.2) | ||
assert ax_scalar is not None | ||
seq_scalar = random_patch.viz.map_fiber(scale=[0.1, 0.3]) | ||
assert seq_scalar is not None | ||
|
||
def test_colorbar_absolute_scale(self, random_patch): | ||
"""Tests for absolute scaling of colorbar.""" | ||
patch = random_patch.new(data=random_patch.data * 100 - 50) | ||
ax1 = patch.viz.map_fiber(scale_type="absolute", scale=(-50, 50)) | ||
assert ax1 is not None | ||
ax2 = patch.viz.map_fiber(scale_type="absolute", scale=10) | ||
assert ax2 is not None | ||
|
||
def test_no_colorbar(self, random_patch): | ||
"""Ensure the colorbar can be disabled.""" | ||
ax = random_patch.viz.map_fiber(cmap=None) | ||
# ensure no colorbar was created. | ||
assert len(ax.figure.get_children()) == 2 | ||
|
||
def test_units(self, random_patch_with_lat_lon): | ||
"""Test that units show up in labels.""" | ||
# standard units | ||
|
||
pa = random_patch_with_lat_lon.set_units(distance="m/s") | ||
ax = pa.viz.map_fiber() | ||
check_label_units(pa, ax, ["distance", "distance"]) | ||
|
||
new = pa.set_units(latitude="ft", longitude="m") | ||
ax = new.viz.map_fiber("latitude", "longitude") | ||
check_label_units(new, ax, ["latitude", "longitude"]) | ||
|
||
def test_show(self, random_patch, monkeypatch): | ||
"""Ensure show path is callable.""" | ||
monkeypatch.setattr(plt, "show", lambda: None) | ||
random_patch.viz.map_fiber(show=True) |