Skip to content

Commit

Permalink
Map plot (#385)
Browse files Browse the repository at this point in the history
  • Loading branch information
aissah authored Jun 6, 2024
1 parent d319112 commit bca67c4
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 0 deletions.
2 changes: 2 additions & 0 deletions dascore/viz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .spectrogram import spectrogram
from .waterfall import waterfall
from .wiggle import wiggle
from .map_fiber import map_fiber


class VizPatchNameSpace(MethodNameSpace):
Expand All @@ -15,3 +16,4 @@ class VizPatchNameSpace(MethodNameSpace):
waterfall = waterfall
spectrogram = spectrogram
wiggle = wiggle
map_fiber = map_fiber
128 changes: 128 additions & 0 deletions dascore/viz/map_fiber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""Module for waterfall plotting."""
from __future__ import annotations

from collections.abc import Sequence
from typing import Literal

import matplotlib.pyplot as plt
import numpy as np

from dascore.constants import PatchType
from dascore.utils.patch import patch_function
from dascore.utils.plotting import (
_get_ax,
_get_cmap,
_get_dim_label,
)


def _set_scale(im, scale, scale_type, color_coords):
"""Set the scale of the color bar based on scale and scale_type."""
# check scale paramters
assert scale_type in {"absolute", "relative"}
assert isinstance(scale, float | int) or len(scale) == 2
# make sure we have a len two array
modifier = 1
if scale_type == "relative":
modifier = 0.5 * (np.nanmax(color_coords) - np.nanmin(color_coords))
# only one scale parameter provided, center around mean
if isinstance(scale, float):
mean = np.nanmean(color_coords)
scale = np.array([mean - scale * modifier, mean + scale * modifier])
im.set_clim(scale)


@patch_function()
def map_fiber(
patch: PatchType,
x: np.ndarray | str = "distance",
y: np.ndarray | str = "distance",
color: np.ndarray | str = "distance",
ax: plt.Axes | None = None,
cmap="cividis_r",
scale: float | Sequence[float] | None = None,
scale_type: Literal["relative", "absolute"] = "relative",
show=False,
) -> plt.Axes:
"""
Create a plot of the outline of the cable colorized by a given parameter.
Parameters
----------
patch
The Patch object.
x
x coordinate: can be an array or a str representing a patch coordinate.
y
y coordinate: can be an array or a str representing a patch coordinate.
color
The color parameter to plot: can be an array or a str representing a patch
attribute.
ax
A matplotlib object, if None create one.
cmap
A matplotlib colormap string or instance. Set to None to not plot the
colorbar.
scale
If not None, controls the saturation level of the colorbar.
Values can either be a float, to set upper and lower limit to the same
value centered around the mean of the data, or a length 2 tuple
specifying upper and lower limits. See `scale_type` for controlling how
values are scaled.
scale_type
Controls the type of scaling specified by `scale` parameter. Options
are:
relative - scale based on half the dynamic range in patch
absolute - scale based on absolute values provided to `scale`
show
If True, show the plot, else just return axis.
Examples
--------
>>> # Plot patch
>>> import dascore as dc
>>> patch = dc.get_example_patch("random_patch_with_lat_lon")
>>> patch = patch.set_units(latitude="m", longitude="m")
>>> _ = patch.viz.map_fiber("latitude", "longitude", "distance")
"""
dims = []
if isinstance(x, str):
assert x in patch.coords, f"{x} not found in patch coordinates"
dims.append(x)
x = patch.coords.get_array(x)
if isinstance(y, str):
assert y in patch.coords, f"{y} not found in patch coordinates"
dims.append(y)
y = patch.coords.get_array(y)
if isinstance(color, str):
assert color in patch.coords, f"{color} not found in patch coordinates"
data_type = color
data_units = patch.attrs.coords[color].units
color = patch.coords.get_array(color)
else:
data_type = ""
data_units = ""

ax = _get_ax(ax)
cmap = _get_cmap(cmap)

im = ax.scatter(x, y, c=color, cmap=cmap)

# scale colorbar
if scale is not None:
_set_scale(im, scale, scale_type, color)

# set axis labels
for dim, x in zip(dims, ["x", "y"]):
getattr(ax, f"set_{x}label")(_get_dim_label(patch, dim))

# add color bar with title
if cmap is not None:
cb = ax.get_figure().colorbar(im, ax=ax, fraction=0.05, pad=0.025)
dunits = f" ({data_units})" if (data_type and data_units) else f"{data_units}"
label = f"{data_type}{dunits}"
cb.set_label(label)

if show:
plt.show()
return ax
117 changes: 117 additions & 0 deletions tests/test_viz/test_map_fiber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Tests for waterfall plots."""
from __future__ import annotations

import matplotlib.pyplot as plt
import pytest

import dascore as dc
from dascore.utils.time import is_datetime64


def check_label_units(patch, ax, dims, color="distance"):
"""Ensure patch label units match axis."""
axis_dict = {0: "xaxis", 1: "yaxis"}
# dims = []
# Check coord-inate names
for coord_name in dims:
coord = patch.coords.coord_map[coord_name]
if is_datetime64(coord[0]):
continue # just skip datetimes for now.
index = dims.index(coord_name)
axis = getattr(ax, axis_dict[index])
label_text = axis.get_label().get_text().lower()
assert str(coord.units.units) in label_text
assert coord_name in label_text
# check colorbar labels
# cax = ax.images[-1].colorbar
coord = patch.coords.coord_map[color]
yaxis_label = ax.figure.get_children()[-1].yaxis.label.get_text()
assert str(coord.units.units) in yaxis_label


@pytest.fixture(scope="session")
def patch_random_start(event_patch_1):
"""Get a patch with a random, odd, starttime."""
random_starttime = dc.to_datetime64("2020-01-02T02:12:11.02232")
attrs = dict(event_patch_1.attrs)
coords = {i: v for i, v in event_patch_1.coords.items()}
time = coords["time"] - coords["time"].min()
coords["time"] = time + random_starttime
attrs["time_min"] = coords["time"].min()
attrs["time_max"] = coords["time"].max()
patch = event_patch_1.update(attrs=attrs, coords=coords)
return patch


class TestPlotMap:
"""Tests for map plot."""

def test_str_input(self, random_patch_with_lat_lon):
"""Call map_fiber plot, return."""
patch = random_patch_with_lat_lon.set_units(latitude="ft")
patch = patch.set_units(longitude="m")
ax = patch.viz.map_fiber("latitude", "longitude")

caxis_label = ax.figure.get_children()[-1].yaxis.label.get_text()

# check labels
assert "latitude" in ax.get_xlabel().lower()
assert "longitude" in ax.get_ylabel().lower()
assert "distance" in caxis_label
assert isinstance(ax, plt.Axes)

def test_array_inputs(self, random_patch_with_lat_lon):
"""Call map_fiber plot, return."""
lats = random_patch_with_lat_lon.coords.get_array("latitude")
lons = random_patch_with_lat_lon.coords.get_array("longitude")
data = 0.5 * (lats + lons)
ax = random_patch_with_lat_lon.viz.map_fiber(lats, lons, data)

assert isinstance(ax, plt.Axes)

def test_default_parameters(self, random_patch):
"""Call map_fiber plot, return."""
ax = random_patch.viz.map_fiber()

# check labels
assert "distance" in ax.get_ylabel().lower()
assert "distance" in ax.get_xlabel().lower()
assert isinstance(ax, plt.Axes)

def test_colorbar_scale(self, random_patch):
"""Tests for the scaling parameter."""
ax_scalar = random_patch.viz.map_fiber(scale=0.2)
assert ax_scalar is not None
seq_scalar = random_patch.viz.map_fiber(scale=[0.1, 0.3])
assert seq_scalar is not None

def test_colorbar_absolute_scale(self, random_patch):
"""Tests for absolute scaling of colorbar."""
patch = random_patch.new(data=random_patch.data * 100 - 50)
ax1 = patch.viz.map_fiber(scale_type="absolute", scale=(-50, 50))
assert ax1 is not None
ax2 = patch.viz.map_fiber(scale_type="absolute", scale=10)
assert ax2 is not None

def test_no_colorbar(self, random_patch):
"""Ensure the colorbar can be disabled."""
ax = random_patch.viz.map_fiber(cmap=None)
# ensure no colorbar was created.
assert len(ax.figure.get_children()) == 2

def test_units(self, random_patch_with_lat_lon):
"""Test that units show up in labels."""
# standard units

pa = random_patch_with_lat_lon.set_units(distance="m/s")
ax = pa.viz.map_fiber()
check_label_units(pa, ax, ["distance", "distance"])

new = pa.set_units(latitude="ft", longitude="m")
ax = new.viz.map_fiber("latitude", "longitude")
check_label_units(new, ax, ["latitude", "longitude"])

def test_show(self, random_patch, monkeypatch):
"""Ensure show path is callable."""
monkeypatch.setattr(plt, "show", lambda: None)
random_patch.viz.map_fiber(show=True)

0 comments on commit bca67c4

Please sign in to comment.