diff --git a/terracotta/drivers/raster_base.py b/terracotta/drivers/raster_base.py index 74616fea..6af50dfa 100644 --- a/terracotta/drivers/raster_base.py +++ b/terracotta/drivers/raster_base.py @@ -393,8 +393,11 @@ def _get_resampling_enum(method: str) -> Any: @staticmethod def _has_alpha_band(src: 'DatasetReader') -> bool: - from rasterio.enums import MaskFlags - return any([MaskFlags.per_dataset in flags for flags in src.mask_flag_enums]) + from rasterio.enums import MaskFlags, ColorInterp + return ( + any([MaskFlags.alpha in flags for flags in src.mask_flag_enums]) + or ColorInterp.alpha in src.colorinterp + ) @classmethod @trace('get_raster_tile') @@ -478,8 +481,9 @@ def _get_raster_tile(cls, path: str, *, # construct VRT vrt = es.enter_context( WarpedVRT( - src, crs=cls._TARGET_CRS, resampling=reproject_enum, add_alpha=True, - transform=vrt_transform, width=vrt_width, height=vrt_height + src, crs=cls._TARGET_CRS, resampling=reproject_enum, + transform=vrt_transform, width=vrt_width, height=vrt_height, + add_alpha=not cls._has_alpha_band(src) ) ) @@ -489,9 +493,11 @@ def _get_raster_tile(cls, path: str, *, tile_data = vrt.read( 1, resampling=resampling_enum, window=out_window, out_shape=tile_size ) - # read alpha mask - mask_idx = src.count + 1 + + # assemble alpha mask + mask_idx = vrt.count mask = vrt.read(mask_idx, window=out_window, out_shape=tile_size) == 0 + if src.nodata is not None: mask |= tile_data == src.nodata