diff --git a/environment-dev.yml b/environment-dev.yml index eed054a3..fb6c038e 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -22,4 +22,6 @@ dependencies: # test/dev - ipykernel - ipython + - packaging + - pooch - pytest diff --git a/monet/monet_accessor.py b/monet/monet_accessor.py index c659e18a..5ccc1514 100644 --- a/monet/monet_accessor.py +++ b/monet/monet_accessor.py @@ -4,6 +4,8 @@ import pandas as pd import xarray as xr +from .plots import _set_outline_patch_alpha + try: import xesmf # noqa: F401 @@ -1021,7 +1023,7 @@ def _check_kwargs_and_set_defaults(**kwargs): kwargs["filename"] = "monet_xesmf_regrid_file.nc" return kwargs - def quick_imshow(self, map_kws={}, roll_dateline=False, **kwargs): + def quick_imshow(self, map_kws=None, roll_dateline=False, **kwargs): """This function takes an xarray DataArray and quickly cerates a figure using cartopy and the matplotlib imshow. Note that this should only be used for regular grids. @@ -1050,6 +1052,9 @@ def quick_imshow(self, map_kws={}, roll_dateline=False, **kwargs): from .plots import _dynamic_fig_size from .plots.mapgen import draw_map + if map_kws is None: + map_kws = {} + sns.set_context("notebook", font_scale=1.2) da = _dataset_to_monet(self._obj) da = _monet_to_latlon(da) @@ -1069,10 +1074,7 @@ def quick_imshow(self, map_kws={}, roll_dateline=False, **kwargs): kwargs.pop("transform", None) if "ax" not in kwargs: ax = draw_map(**map_kws) - try: - ax.axes.outline_patch.set_alpha(0) - except AttributeError: - ax.outline_patch.set_alpha(0) + _set_outline_patch_alpha(ax) if roll_dateline: _ = ( da.squeeze() @@ -1084,7 +1086,7 @@ def quick_imshow(self, map_kws={}, roll_dateline=False, **kwargs): plt.tight_layout() return ax - def quick_map(self, map_kws={}, roll_dateline=False, **kwargs): + def quick_map(self, map_kws=None, roll_dateline=False, **kwargs): """This function takes an xarray DataArray and quickly cerates a figure using cartopy and the matplotlib pcolormesh @@ -1112,6 +1114,9 @@ def quick_map(self, map_kws={}, roll_dateline=False, **kwargs): from .plots import _dynamic_fig_size from .plots.mapgen import draw_map + if map_kws is None: + map_kws = {} + sns.set_context("notebook") da = _dataset_to_monet(self._obj) crs_p = ccrs.PlateCarree() @@ -1126,10 +1131,7 @@ def quick_map(self, map_kws={}, roll_dateline=False, **kwargs): transform = kwargs.pop("transform", crs_p) if "ax" not in kwargs: ax = draw_map(**map_kws) - try: - ax.axes.outline_patch.set_alpha(0) - except AttributeError: - ax.outline_patch.set_alpha(0) + _set_outline_patch_alpha(ax) if roll_dateline: _ = da.roll(x=int(len(da.x) / 2), roll_coords=True).plot( x="longitude", y="latitude", ax=ax, transform=transform, **kwargs @@ -1139,7 +1141,7 @@ def quick_map(self, map_kws={}, roll_dateline=False, **kwargs): plt.tight_layout() return ax - def quick_contourf(self, map_kws={}, roll_dateline=False, **kwargs): + def quick_contourf(self, map_kws=None, roll_dateline=False, **kwargs): """This function takes an xarray DataArray and quickly cerates a figure using cartopy and the matplotlib contourf @@ -1167,6 +1169,9 @@ def quick_contourf(self, map_kws={}, roll_dateline=False, **kwargs): from monet.plots import _dynamic_fig_size from monet.plots.mapgen import draw_map + if map_kws is None: + map_kws = {} + sns.set_context("notebook") da = _dataset_to_monet(self._obj) crs_p = ccrs.PlateCarree() @@ -1185,10 +1190,7 @@ def quick_contourf(self, map_kws={}, roll_dateline=False, **kwargs): kwargs.pop("transform", None) if "ax" not in kwargs: ax = draw_map(**map_kws) - try: - ax.axes.outline_patch.set_alpha(0) - except AttributeError: - ax.outline_patch.set_alpha(0) + _set_outline_patch_alpha(ax) if roll_dateline: _ = da.roll(x=int(len(da.x) / 2), roll_coords=True).plot.contourf( x="longitude", y="latitude", ax=ax, transform=transform, **kwargs diff --git a/monet/plots/__init__.py b/monet/plots/__init__.py index d04d0d05..f7908e28 100644 --- a/monet/plots/__init__.py +++ b/monet/plots/__init__.py @@ -1,3 +1,5 @@ +import warnings + from .colorbars import cmap_discretize, colorbar_index from .mapgen import draw_map from .plots import ( @@ -187,8 +189,8 @@ def sp_scatter_bias( colorbar=True, **kwargs, ) - if ~outline: - ax.outline_patch.set_alpha(0) + if not outline: + _set_outline_patch_alpha(ax) if global_map: plt.xlim([-180, 180]) plt.ylim([-90, 90]) @@ -197,3 +199,19 @@ def sp_scatter_bias( return ax except ValueError: exit + + +def _set_outline_patch_alpha(ax, alpha=0): + for f in [ + lambda alpha: ax.axes.outline_patch.set_alpha(alpha), + lambda alpha: ax.outline_patch.set_alpha(alpha), + lambda alpha: ax.spines["geo"].set_alpha(alpha), + ]: + try: + f(alpha) + except AttributeError: + continue + else: + break + else: + warnings.warn("unable to set outline_patch alpha", stacklevel=2) diff --git a/tests/test_plot.py b/tests/test_plot.py new file mode 100644 index 00000000..fe8c26b2 --- /dev/null +++ b/tests/test_plot.py @@ -0,0 +1,23 @@ +import cartopy +import pytest +import xarray as xr +from packaging.version import Version + +import monet # noqa: F401 + +cartopy_version = Version(cartopy.__version__) + +da = xr.tutorial.load_dataset("air_temperature").air.isel(time=1) + + +@pytest.mark.parametrize("which", ["imshow", "map", "contourf"]) +def test_quick_(which): + getattr(da.monet, f"quick_{which}")() + + +if __name__ == "__main__": + import matplotlib.pyplot as plt + + test_quick_("map") + + plt.show()