From b60edbd696f8ce22bdfc2262dbf67c0f9a3b3117 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Wed, 12 Sep 2018 11:36:54 +0200 Subject: [PATCH 01/15] remove /colormaps api endpoint --- terracotta/api/colormaps.py | 37 -------------- terracotta/api/flask_api.py | 6 +-- terracotta/api/legend.py | 83 -------------------------------- terracotta/handlers/colormaps.py | 15 ------ terracotta/handlers/legend.py | 32 ------------ tests/api/test_flask_api.py | 18 +++---- tests/handlers/test_colormaps.py | 7 --- tests/handlers/test_legend.py | 17 ------- 8 files changed, 8 insertions(+), 207 deletions(-) delete mode 100644 terracotta/api/colormaps.py delete mode 100644 terracotta/api/legend.py delete mode 100644 terracotta/handlers/colormaps.py delete mode 100644 terracotta/handlers/legend.py delete mode 100644 tests/handlers/test_colormaps.py delete mode 100644 tests/handlers/test_legend.py diff --git a/terracotta/api/colormaps.py b/terracotta/api/colormaps.py deleted file mode 100644 index a3d075c9..00000000 --- a/terracotta/api/colormaps.py +++ /dev/null @@ -1,37 +0,0 @@ -"""api/colormaps.py - -Flask route to handle /colormaps calls. -""" - -from flask import jsonify -from marshmallow import Schema, fields - -from terracotta.api.flask_api import convert_exceptions, metadata_api, spec - - -class ColormapSchema(Schema): - colormaps = fields.List(fields.String(example='viridis'), required=True) - - -@metadata_api.route('/colormaps', methods=['GET']) -@convert_exceptions -def get_colormaps() -> str: - """Get all registered colormaps. - --- - get: - summary: /colormaps - description: - Get all registered colormaps. 
For a preview see - https://matplotlib.org/examples/color/colormaps_reference.html - responses: - 200: - description: List of names of all colormaps - schema: ColormapSchema - """ - from terracotta.handlers.colormaps import colormaps - payload = {'colormaps': colormaps()} - schema = ColormapSchema() - return jsonify(schema.load(payload)) - - -spec.definition('Colormaps', schema=ColormapSchema) diff --git a/terracotta/api/flask_api.py b/terracotta/api/flask_api.py index 999f63c4..516c8534 100644 --- a/terracotta/api/flask_api.py +++ b/terracotta/api/flask_api.py @@ -82,10 +82,9 @@ def create_app(debug: bool = False, new_app.debug = debug # import submodules to populate blueprints - import terracotta.api.colormaps import terracotta.api.datasets import terracotta.api.keys - import terracotta.api.legend + import terracotta.api.colormap import terracotta.api.metadata import terracotta.api.rgb import terracotta.api.singleband @@ -95,10 +94,9 @@ def create_app(debug: bool = False, # register routes on API spec with new_app.test_request_context(): - spec.add_path(view=terracotta.api.colormaps.get_colormaps) spec.add_path(view=terracotta.api.datasets.get_datasets) spec.add_path(view=terracotta.api.keys.get_keys) - spec.add_path(view=terracotta.api.legend.get_legend) + spec.add_path(view=terracotta.api.colormap.get_colormap) spec.add_path(view=terracotta.api.metadata.get_metadata) spec.add_path(view=terracotta.api.rgb.get_rgb) spec.add_path(view=terracotta.api.singleband.get_singleband) diff --git a/terracotta/api/legend.py b/terracotta/api/legend.py deleted file mode 100644 index b78e8b3b..00000000 --- a/terracotta/api/legend.py +++ /dev/null @@ -1,83 +0,0 @@ -"""api/keys.py - -Flask route to handle /legend calls. 
-""" - -from typing import Any, Mapping, Dict -import json - -from flask import jsonify, request -from marshmallow import Schema, fields, validate, pre_load, ValidationError, EXCLUDE - -from terracotta.api.flask_api import convert_exceptions, metadata_api, spec -from terracotta.cmaps import AVAILABLE_CMAPS - - -class LegendEntrySchema(Schema): - value = fields.Number(required=True) - rgb = fields.List(fields.Number(), required=True, validate=validate.Length(equal=3)) - - -class LegendSchema(Schema): - legend = fields.Nested(LegendEntrySchema, many=True, required=True) - - -class LegendOptionSchema(Schema): - class Meta: - unknown = EXCLUDE - - stretch_range = fields.List( - fields.Number(), validate=validate.Length(equal=2), required=True, - description='Minimum and maximum value of colormap as JSON array ' - '(same as for /singleband and /rgb)' - ) - colormap = fields.String(description='Name of color map to use (see /colormap)', - missing=None, validate=validate.OneOf(AVAILABLE_CMAPS)) - num_values = fields.Int(description='Number of values to return', missing=255) - - @pre_load - def process_ranges(self, data: Mapping[str, Any]) -> Dict[str, Any]: - data = dict(data.items()) - var = 'stretch_range' - val = data.get(var) - if val: - try: - data[var] = json.loads(val) - except json.decoder.JSONDecodeError as exc: - raise ValidationError(f'Could not decode value for {var} as JSON') from exc - return data - - -@metadata_api.route('/legend', methods=['GET']) -@convert_exceptions -def get_legend() -> str: - """Get a legend mapping pixel values to colors - --- - get: - summary: /legend - description: - Get a legend mapping pixel values to colors. Use this to construct a color bar for a - dataset. 
- parameters: - - in: query - schema: LegendOptionSchema - responses: - 200: - description: Array containing data values and RGBA tuples - schema: LegendSchema - 400: - description: Query parameters are invalid - """ - from terracotta.handlers.legend import legend - - input_schema = LegendOptionSchema() - options = input_schema.load(request.args) - - payload = {'legend': legend(**options)} - - schema = LegendSchema() - return jsonify(schema.load(payload)) - - -spec.definition('LegendEntry', schema=LegendEntrySchema) -spec.definition('Legend', schema=LegendSchema) diff --git a/terracotta/handlers/colormaps.py b/terracotta/handlers/colormaps.py deleted file mode 100644 index 623abd97..00000000 --- a/terracotta/handlers/colormaps.py +++ /dev/null @@ -1,15 +0,0 @@ -"""handlers/colormaps.py - -Handle /colormaps API endpoint. -""" - -from typing import List - -from terracotta.profile import trace - - -@trace('colormaps_handler') -def colormaps() -> List[str]: - """Return all supported colormaps""" - from terracotta.cmaps import AVAILABLE_CMAPS - return list(sorted(AVAILABLE_CMAPS)) diff --git a/terracotta/handlers/legend.py b/terracotta/handlers/legend.py deleted file mode 100644 index 3176060a..00000000 --- a/terracotta/handlers/legend.py +++ /dev/null @@ -1,32 +0,0 @@ -"""handlers/legend.py - -Handle /legend API endpoint. 
-""" - -from typing import List, Tuple, TypeVar, Dict, Any - -import numpy as np - -from terracotta.profile import trace - -Number = TypeVar('Number', 'int', 'float') - - -@trace('legend_handler') -def legend(*, stretch_range: Tuple[Number, Number], - colormap: str = None, - num_values: int = 255) -> List[Dict[str, Any]]: - """Returns a list [{value=pixel value, rgb=rgb tuple}] for given stretch parameters""" - target_coords = np.linspace(stretch_range[0], stretch_range[1], num_values) - - if colormap is not None: - from terracotta.cmaps import get_cmap - cmap = get_cmap(colormap) - else: - # assemble greyscale cmap of shape (255, 3) - cmap = np.tile(np.arange(1, 256, dtype='uint8')[:, np.newaxis], (1, 3)) - - cmap_coords = np.around(np.linspace(0, len(cmap) - 1, num_values)).astype('uint8') - colors = cmap[cmap_coords] - - return [dict(value=p, rgb=c) for p, c in zip(target_coords.tolist(), colors.tolist())] diff --git a/tests/api/test_flask_api.py b/tests/api/test_flask_api.py index da156a95..f0fd03f2 100644 --- a/tests/api/test_flask_api.py +++ b/tests/api/test_flask_api.py @@ -19,12 +19,6 @@ def client(flask_app): yield client -def test_get_colormaps(client): - rv = client.get('/colormaps') - assert rv.status_code == 200 - assert 'jet' in json.loads(rv.data)['colormaps'] - - def test_get_keys(client, use_read_only_database): rv = client.get('/keys') assert rv.status_code == 200 @@ -171,16 +165,16 @@ def test_get_rgb_stretch(client, use_read_only_database, raster_file_xyz): assert np.asarray(img).shape == (*settings.TILE_SIZE, 3) -def test_get_legend(client): - rv = client.get('/legend?stretch_range=[0,1]&num_values=100') +def test_get_colormap(client): + rv = client.get('/colormap?stretch_range=[0,1]&num_values=100') assert rv.status_code == 200 - assert len(json.loads(rv.data)['legend']) == 100 + assert len(json.loads(rv.data)['colormap']) == 100 -def test_get_legend_extra_args(client): - rv = 
client.get('/legend?stretch_range=[0,1]&num_values=100&foo=bar&baz=quz') +def test_get_colormap_extra_args(client): + rv = client.get('/colormap?stretch_range=[0,1]&num_values=100&foo=bar&baz=quz') assert rv.status_code == 200 - assert len(json.loads(rv.data)['legend']) == 100 + assert len(json.loads(rv.data)['colormap']) == 100 def test_get_preview(client): diff --git a/tests/handlers/test_colormaps.py b/tests/handlers/test_colormaps.py deleted file mode 100644 index 2b32121d..00000000 --- a/tests/handlers/test_colormaps.py +++ /dev/null @@ -1,7 +0,0 @@ - -def test_colormaps(): - from terracotta.handlers import colormaps - assert colormaps.colormaps() - - from matplotlib.cm import cmap_d - assert all(cmap.lower() in colormaps.colormaps() for cmap in cmap_d) diff --git a/tests/handlers/test_legend.py b/tests/handlers/test_legend.py deleted file mode 100644 index d890f645..00000000 --- a/tests/handlers/test_legend.py +++ /dev/null @@ -1,17 +0,0 @@ -import numpy as np - - -def test_legend_handler(): - from terracotta.handlers import legend - leg = legend.legend(colormap='jet', stretch_range=[0., 1.], num_values=50) - assert leg - assert len(leg) == 50 - assert len(leg[0]['rgb']) == 3 - assert leg[0]['value'] == 0. and leg[-1]['value'] == 1. 
- - -def test_nocmap(): - from terracotta.handlers import legend - leg = legend.legend(stretch_range=[0., 1.], num_values=255) - leg_array = np.array([row['rgb'] for row in leg]) - np.testing.assert_array_equal(leg_array, np.tile(np.arange(1, 256)[:, np.newaxis], (1, 3))) From a43e40595df716af97bb4f4b7d09aa87ec7c3032 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Fri, 14 Sep 2018 11:54:36 +0200 Subject: [PATCH 02/15] properly test colormap call (fixes #47) --- terracotta/api/colormap.py | 86 +++++++++++++++++++++++++++++++++ terracotta/handlers/colormap.py | 34 +++++++++++++ tests/conftest.py | 9 ++-- tests/handlers/test_colormap.py | 75 ++++++++++++++++++++++++++++ 4 files changed, 198 insertions(+), 6 deletions(-) create mode 100644 terracotta/api/colormap.py create mode 100644 terracotta/handlers/colormap.py create mode 100644 tests/handlers/test_colormap.py diff --git a/terracotta/api/colormap.py b/terracotta/api/colormap.py new file mode 100644 index 00000000..02185045 --- /dev/null +++ b/terracotta/api/colormap.py @@ -0,0 +1,86 @@ +"""api/keys.py + +Flask route to handle /colormap calls. 
+""" + +from typing import Any, Mapping, Dict +import json + +from flask import jsonify, request +from marshmallow import Schema, fields, validate, pre_load, ValidationError, EXCLUDE + +from terracotta.api.flask_api import convert_exceptions, metadata_api, spec +from terracotta.cmaps import AVAILABLE_CMAPS + + +class colormapEntrySchema(Schema): + value = fields.Number(required=True) + rgb = fields.List(fields.Number(), required=True, validate=validate.Length(equal=3)) + + +class colormapSchema(Schema): + colormap = fields.Nested(colormapEntrySchema, many=True, required=True) + + +class colormapOptionSchema(Schema): + class Meta: + unknown = EXCLUDE + + stretch_range = fields.List( + fields.Number(), validate=validate.Length(equal=2), required=True, + description='Minimum and maximum value of colormap as JSON array ' + '(same as for /singleband and /rgb)' + ) + colormap = fields.String( + description='Name of color map to use (for a preview see ' + 'https://matplotlib.org/examples/color/colormaps_reference.html)', + missing=None, validate=validate.OneOf(AVAILABLE_CMAPS) + ) + num_values = fields.Int(description='Number of values to return', missing=255) + + @pre_load + def process_ranges(self, data: Mapping[str, Any]) -> Dict[str, Any]: + data = dict(data.items()) + var = 'stretch_range' + val = data.get(var) + if val: + try: + data[var] = json.loads(val) + except json.decoder.JSONDecodeError as exc: + raise ValidationError(f'Could not decode value for {var} as JSON') from exc + return data + + +@metadata_api.route('/colormap', methods=['GET']) +@convert_exceptions +def get_colormap() -> str: + """Get a colormap mapping pixel values to colors + --- + get: + summary: /colormap + description: + Get a colormap mapping pixel values to colors. Use this to construct a color bar for a + dataset. 
+ parameters: + - in: query + schema: colormapOptionSchema + responses: + 200: + description: Array containing data values and RGBA tuples + schema: colormapSchema + 400: + description: Query parameters are invalid + """ + from terracotta.handlers.colormap import colormap + + input_schema = colormapOptionSchema() + options = input_schema.load(request.args) + + payload = {'colormap': colormap(**options)} + + schema = colormapSchema() + return jsonify(schema.load(payload)) + + +spec.definition('colormapEntry', schema=colormapEntrySchema) +spec.definition('colormap', schema=colormapSchema) diff --git a/terracotta/handlers/colormap.py b/terracotta/handlers/colormap.py new file mode 100644 index 00000000..0df533a5 --- /dev/null +++ b/terracotta/handlers/colormap.py @@ -0,0 +1,34 @@ +"""handlers/colormap.py + +Handle /colormap API endpoint. +""" + +from typing import List, Tuple, TypeVar, Dict, Any + +import numpy as np + +from terracotta.profile import trace + +Number = TypeVar('Number', 'int', 'float') + + +@trace('colormap_handler') +def colormap(*, stretch_range: Tuple[Number, Number], + colormap: str = None, + num_values: int = 255) -> List[Dict[str, Any]]: + """Returns a list [{value=pixel value, rgb=rgb tuple}] for given stretch parameters""" + from terracotta import image + + target_coords = np.linspace(stretch_range[0], stretch_range[1], num_values) + + if colormap is not None: + from terracotta.cmaps import get_cmap + cmap = get_cmap(colormap) + else: + # assemble greyscale cmap of shape (255, 3) + cmap = np.tile(np.arange(1, 256, dtype='uint8')[:, np.newaxis], (1, 3)) + + cmap_coords = image.to_uint8(target_coords, *stretch_range) - 1 + colors = cmap[cmap_coords] + + return [dict(value=p, rgb=c) for p, c in zip(target_coords.tolist(), colors.tolist())] diff --git a/tests/conftest.py b/tests/conftest.py index acf14d0d..3b375c0e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,14 +15,11 @@ def pytest_unconfigure(config): 
@pytest.fixture(scope='session') -def raster_data(): - return np.arange(-128 * 256, 128 * 256, dtype='int16').reshape(256, 256) - - -@pytest.fixture(scope='session') -def raster_file(tmpdir_factory, raster_data): +def raster_file(tmpdir_factory): import affine + raster_data = np.arange(-128 * 256, 128 * 256, dtype='int16').reshape(256, 256) + profile = { 'driver': 'GTiff', 'dtype': 'int16', diff --git a/tests/handlers/test_colormap.py b/tests/handlers/test_colormap.py new file mode 100644 index 00000000..36afafd4 --- /dev/null +++ b/tests/handlers/test_colormap.py @@ -0,0 +1,75 @@ +import numpy as np +from PIL import Image + +import pytest + + +def test_colormap_handler(): + from terracotta.handlers import colormap + cmap = colormap.colormap(colormap='jet', stretch_range=[0., 1.], num_values=50) + assert cmap + assert len(cmap) == 50 + assert len(cmap[0]['rgb']) == 3 + assert cmap[0]['value'] == 0. and cmap[-1]['value'] == 1. + + +@pytest.mark.parametrize('stretch_range', [[0, 20000], [20000, 30000], [-50000, 50000]]) +@pytest.mark.parametrize('cmap_name', [None, 'jet']) +def test_colormap_consistency(use_read_only_database, read_only_database, raster_file_xyz, + stretch_range, cmap_name): + import terracotta + from terracotta.xyz import get_tile_data + from terracotta.handlers import singleband, colormap + + nodata = 10000 + ds_keys = ['val21', 'val22'] + + # get image with applied stretch and colormap + raw_img = singleband.singleband(ds_keys, raster_file_xyz, stretch_range=stretch_range, + colormap=cmap_name) + img_data = np.asarray(Image.open(raw_img).convert('RGBA')) + + # get raw data to compare to + driver = terracotta.get_driver(read_only_database) + tile_x, tile_y, tile_z = raster_file_xyz + + with driver.connect(): + tile_data = get_tile_data(driver, ds_keys, tile_x=tile_x, tile_y=tile_y, tile_z=tile_z, + tilesize=img_data.shape[:2]) + + # make sure all pixel values are included in colormap + num_values = stretch_range[1] - stretch_range[0] + 1 + + # get 
colormap for given stretch + cmap = colormap.colormap(colormap=cmap_name, stretch_range=stretch_range, + num_values=num_values) + cmap = dict(row.values() for row in cmap) + + # test nodata + nodata_mask = tile_data == nodata + assert np.all(img_data[nodata_mask, -1] == 0) + + # test clipping + below_mask = tile_data < stretch_range[0] + assert np.all(img_data[below_mask & ~nodata_mask, :-1] == cmap[stretch_range[0]]) + + above_mask = tile_data > stretch_range[1] + assert np.all(img_data[above_mask & ~nodata_mask, :-1] == cmap[stretch_range[1]]) + + # test values inside stretch_range + values_to_test = np.unique(tile_data) + values_to_test = values_to_test[(values_to_test >= stretch_range[0]) & + (values_to_test <= stretch_range[1]) & + (values_to_test != nodata)] + + for val in values_to_test: + rgb = cmap[val] + assert np.all(img_data[tile_data == val, :-1] == rgb) + + + +def test_nocmap(): + from terracotta.handlers import colormap + cmap = colormap.colormap(stretch_range=[0., 1.], num_values=255) + cmap_array = np.array([row['rgb'] for row in cmap]) + np.testing.assert_array_equal(cmap_array, np.tile(np.arange(1, 256)[:, np.newaxis], (1, 3))) From 85f131d62fe5bdb2035a00f7c2c5745983154b91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Fri, 14 Sep 2018 15:00:41 +0200 Subject: [PATCH 03/15] implement explicit color mapping support --- terracotta/drivers/base.py | 3 +- terracotta/drivers/raster_base.py | 21 ++++++---- terracotta/handlers/singleband.py | 46 +++++++++++++++------ terracotta/image.py | 68 ++++++++++++++++++++++++------- terracotta/xyz.py | 11 +++-- tests/handlers/test_colormap.py | 8 ---- tests/handlers/test_singleband.py | 28 +++++++++++++ 7 files changed, 139 insertions(+), 46 deletions(-) diff --git a/terracotta/drivers/base.py b/terracotta/drivers/base.py index ef0b3065..449a1b3b 100644 --- a/terracotta/drivers/base.py +++ b/terracotta/drivers/base.py @@ -71,7 +71,8 @@ def get_metadata(self, keys: Union[Sequence[str], 
Mapping[str, str]]) -> Dict[st def get_raster_tile(self, keys: Union[Sequence[str], Mapping[str, str]], *, bounds: Sequence[float] = None, tilesize: Sequence[int] = (256, 256), - nodata: Number = 0) -> np.ndarray: + nodata: Number = 0, + preserve_values: bool = False) -> np.ndarray: """Get raster tile as a NumPy array for given keys.""" pass diff --git a/terracotta/drivers/raster_base.py b/terracotta/drivers/raster_base.py index 9640b623..c3719d14 100644 --- a/terracotta/drivers/raster_base.py +++ b/terracotta/drivers/raster_base.py @@ -266,8 +266,10 @@ def _calculate_default_transform(src_crs: Union[Dict[str, str], str], def _get_raster_tile(self, keys: Tuple[str], *, bounds: Tuple[float, float, float, float] = None, tilesize: Tuple[int, int] = (256, 256), - nodata: Number = 0) -> np.ndarray: + nodata: Number = 0, + preserve_values: bool = False) -> np.ndarray: """Load a raster dataset from a file through rasterio. + Heavily inspired by mapbox/rio-tiler """ import rasterio @@ -283,8 +285,12 @@ def _get_raster_tile(self, keys: Tuple[str], *, path = path[keys] target_crs = 'epsg:3857' - upsampling_method = settings.UPSAMPLING_METHOD - upsampling_enum = self._get_resampling_enum(upsampling_method) + + if preserve_values: + upsampling_enum = downsampling_enum = self._get_resampling_enum('nearest') + else: + upsampling_enum = self._get_resampling_enum(settings.UPSAMPLING_METHOD) + downsampling_enum = self._get_resampling_enum(settings.DOWNSAMPLING_METHOD) with contextlib.ExitStack() as es: try: @@ -338,8 +344,7 @@ def _get_raster_tile(self, keys: Tuple[str], *, if window_ratio > 1: resampling_enum = upsampling_enum else: - downsampling_method = settings.DOWNSAMPLING_METHOD - resampling_enum = self._get_resampling_enum(downsampling_method) + resampling_enum = downsampling_enum # read data with warnings.catch_warnings(), trace('read_from_vrt'): @@ -354,7 +359,8 @@ def _get_raster_tile(self, keys: Tuple[str], *, def get_raster_tile(self, keys: Union[Sequence[str], 
Mapping[str, str]], *, bounds: Sequence[float] = None, tilesize: Sequence[int] = (256, 256), - nodata: Number = 0) -> np.ndarray: + nodata: Number = 0, + preserve_values: bool = False) -> np.ndarray: """Load tile with given keys or metadata""" # make sure all arguments are hashable _keys = self._key_dict_to_sequence(keys) @@ -362,5 +368,6 @@ def get_raster_tile(self, keys: Union[Sequence[str], Mapping[str, str]], *, tuple(_keys), bounds=tuple(bounds) if bounds else None, tilesize=tuple(tilesize), - nodata=nodata + nodata=nodata, + preserve_values=preserve_values ) diff --git a/terracotta/handlers/singleband.py b/terracotta/handlers/singleband.py index be523a88..ce263481 100644 --- a/terracotta/handlers/singleband.py +++ b/terracotta/handlers/singleband.py @@ -3,20 +3,27 @@ Handle /singleband API endpoint. """ -from typing import Sequence, Mapping, Union, Tuple, TypeVar +from typing import Sequence, Mapping, Union, Tuple, Optional, TypeVar, cast from typing.io import BinaryIO +import collections + from terracotta import get_settings, get_driver, image, xyz from terracotta.profile import trace Number = TypeVar('Number', int, float) +RGB = Tuple[Number, Number, Number] @trace('singleband_handler') -def singleband(keys: Union[Sequence[str], Mapping[str, str]], tile_xyz: Sequence[int], *, - colormap: str = None, stretch_range: Tuple[Number, Number] = None) -> BinaryIO: +def singleband(keys: Union[Sequence[str], Mapping[str, str]], + tile_xyz: Sequence[int], *, + colormap: Union[str, Mapping[Number, RGB], None] = None, + stretch_range: Tuple[Number, Number] = None) -> BinaryIO: """Return singleband image as PNG""" + colormap_: Union[str, Sequence[RGB], None] + try: tile_x, tile_y, tile_z = tile_xyz except ValueError: @@ -27,25 +34,40 @@ def singleband(keys: Union[Sequence[str], Mapping[str, str]], tile_xyz: Sequence else: stretch_min, stretch_max = stretch_range + preserve_values = isinstance(colormap, collections.Mapping) + settings = get_settings() driver = 
get_driver(settings.DRIVER_PATH, provider=settings.DRIVER_PROVIDER) with driver.connect(): metadata = driver.get_metadata(keys) tile_size = settings.TILE_SIZE - tile_data = xyz.get_tile_data(driver, keys, tile_x=tile_x, tile_y=tile_y, tile_z=tile_z, - tilesize=tile_size) + tile_data = xyz.get_tile_data( + driver, keys, tile_x=tile_x, tile_y=tile_y, tile_z=tile_z, + tilesize=tile_size, preserve_values=preserve_values + ) valid_mask = image.get_valid_mask(tile_data, nodata=metadata['nodata']) - stretch_range_ = list(metadata['range']) + if preserve_values: + # bin output image into supplied labels, starting at 1 + colormap = cast(Mapping, colormap) + + labels, label_colors = list(colormap.keys()), list(colormap.values()) + + colormap_ = label_colors + out = image.label(tile_data, labels) + else: + # determine stretch range from metadata and arguments + stretch_range_ = list(metadata['range']) - if stretch_min is not None: - stretch_range_[0] = stretch_min + if stretch_min is not None: + stretch_range_[0] = stretch_min - if stretch_max is not None: - stretch_range_[1] = stretch_max + if stretch_max is not None: + stretch_range_[1] = stretch_max - out = image.to_uint8(tile_data, *stretch_range_) + colormap_ = cast(Optional[str], colormap) + out = image.to_uint8(tile_data, *stretch_range_) - return image.array_to_png(out, transparency_mask=~valid_mask, colormap=colormap) + return image.array_to_png(out, transparency_mask=~valid_mask, colormap=colormap_) diff --git a/terracotta/image.py b/terracotta/image.py index 31182b57..136ea3a9 100644 --- a/terracotta/image.py +++ b/terracotta/image.py @@ -5,6 +5,7 @@ from typing import Sequence, Tuple, TypeVar, Union from typing.io import BinaryIO + from io import BytesIO import numpy as np @@ -14,17 +15,18 @@ from terracotta import exceptions, get_settings Number = TypeVar('Number', int, float) +Palette = Sequence[Tuple[Number, Number, Number]] @trace('array_to_png') def array_to_png(arr: np.ndarray, transparency_mask: np.ndarray 
= None, - colormap: str = None) -> BinaryIO: + colormap: Union[str, Palette, None] = None) -> BinaryIO: from terracotta.cmaps import get_cmap - settings = get_settings() - transparency: Union[Tuple[int, int, int], int] + + settings = get_settings() compress_level = settings.PNG_COMPRESS_LEVEL if arr.ndim == 3: # encode RGB image @@ -43,17 +45,31 @@ def array_to_png(arr: np.ndarray, transparency = 0 if colormap is not None: - try: - cmap_vals = get_cmap(colormap) - except ValueError as exc: - raise exceptions.InvalidArgumentsError( - f'Encountered invalid color map {colormap}') from exc - - palette = np.concatenate(( - np.zeros(3, dtype='uint8'), - cmap_vals.flatten() - )) - assert palette.shape == (3 * 256,) + if isinstance(colormap, str): + # get and apply colormap by name + try: + cmap_vals = get_cmap(colormap) + except ValueError as exc: + raise exceptions.InvalidArgumentsError( + f'Encountered invalid color map {colormap}') from exc + + palette = np.concatenate(( + np.zeros(3, dtype='uint8'), + cmap_vals.flatten() + )) + else: + # explicit mapping + if len(colormap) > 255: + msg = 'Explicit color map must contain less than 256 values' + raise exceptions.InvalidArgumentsError(msg) + + palette = np.concatenate(( + np.zeros(3, dtype='uint8'), + np.array(colormap, dtype='uint8').flatten(), + np.zeros(3 * (256 - len(colormap) - 1), dtype='uint8') + )) + + assert palette.shape == (3 * 256,), palette.shape else: palette = None @@ -91,7 +107,7 @@ def get_valid_mask(data: np.ndarray, nodata: Number) -> np.ndarray: """Return mask for data, masking out nodata and invalid values""" out = data != nodata - # Also mask out other invalid values if float + # also mask out other invalid values if float if np.issubdtype(data.dtype, np.floating): out &= np.isfinite(data) @@ -126,3 +142,25 @@ def to_uint8(data: np.ndarray, lower_bound: Number, upper_bound: Number) -> np.n """Re-scale an array to [1, 255] and cast to uint8 (0 is used for transparency)""" rescaled = 
contrast_stretch(data, (lower_bound, upper_bound), (1, 255), clip=True) return rescaled.astype(np.uint8) + + +def label(data: np.ndarray, labels: Sequence[Number]) -> np.ndarray: + """Create a labelled uint8 version of data, with output values starting at 1. + + Values not found in labels are replaced by 0. + + Example: + + >>> data = np.array([15, 16, 17]) + >>> label(data, [17, 15]) + np.array([2, 0, 1]) + + """ + if len(labels) > 255: + raise ValueError('Cannot fit more than 255 labels') + + out_data = np.zeros(data.shape, dtype='uint8') + for i, label in enumerate(labels, 1): + out_data[data == label] = i + + return out_data diff --git a/terracotta/xyz.py b/terracotta/xyz.py index f4936235..f7636cef 100644 --- a/terracotta/xyz.py +++ b/terracotta/xyz.py @@ -14,28 +14,33 @@ def get_tile_data(driver: Driver, keys: Union[Sequence[str], Mapping[str, str]], tile_x: int, tile_y: int, tile_z: int, *, - tilesize: Sequence[int] = (256, 256)) -> np.ndarray: + tilesize: Sequence[int] = (256, 256), + preserve_values: bool = False) -> np.ndarray: """Retrieve xyz tile data from given driver""" metadata = driver.get_metadata(keys) nodata = metadata['nodata'] wgs_bounds = metadata['bounds'] + if not tile_exists(wgs_bounds, tile_x, tile_y, tile_z): raise exceptions.TileOutOfBoundsError( f'Tile {tile_z}/{tile_x}/{tile_y} is outside image bounds' ) + target_bounds = get_xy_bounds(tile_x, tile_y, tile_z) - return driver.get_raster_tile(keys, bounds=target_bounds, tilesize=tilesize, nodata=nodata) + + return driver.get_raster_tile(keys, bounds=target_bounds, tilesize=tilesize, + nodata=nodata, preserve_values=preserve_values) def get_xy_bounds(tile_x: int, tile_y: int, tile_z: int) -> Tuple[float]: """Retrieve physical bounds covered by given xyz tile.""" mercator_tile = mercantile.Tile(x=tile_x, y=tile_y, z=tile_z) + return mercantile.xy_bounds(mercator_tile) def tile_exists(bounds: Sequence[float], tile_x: int, tile_y: int, tile_z: int) -> bool: """Check if a mercatile tile is 
inside a given bounds.""" - mintile = mercantile.tile(bounds[0], bounds[3], tile_z) maxtile = mercantile.tile(bounds[2], bounds[1], tile_z) diff --git a/tests/handlers/test_colormap.py b/tests/handlers/test_colormap.py index 36afafd4..08a4205a 100644 --- a/tests/handlers/test_colormap.py +++ b/tests/handlers/test_colormap.py @@ -65,11 +65,3 @@ def test_colormap_consistency(use_read_only_database, read_only_database, raster for val in values_to_test: rgb = cmap[val] assert np.all(img_data[tile_data == val, :-1] == rgb) - - - -def test_nocmap(): - from terracotta.handlers import colormap - cmap = colormap.colormap(stretch_range=[0., 1.], num_values=255) - cmap_array = np.array([row['rgb'] for row in cmap]) - np.testing.assert_array_equal(cmap_array, np.tile(np.arange(1, 256)[:, np.newaxis], (1, 3))) diff --git a/tests/handlers/test_singleband.py b/tests/handlers/test_singleband.py index a961fc12..7d7b59ea 100644 --- a/tests/handlers/test_singleband.py +++ b/tests/handlers/test_singleband.py @@ -61,3 +61,31 @@ def test_singleband_stretch(stretch_range, use_read_only_database, read_only_dat stretch_range_mask = (valid_data > stretch_range[0]) & (valid_data < stretch_range[1]) assert not np.any(np.isin(valid_img[stretch_range_mask], [1, 255])) assert np.all(valid_img[valid_data > stretch_range[1]] == 255) + + +def test_explicit_colormap(use_read_only_database, read_only_database, raster_file_xyz): + import terracotta + from terracotta.xyz import get_tile_data + from terracotta.handlers import singleband + + ds_keys = ['val21', 'val22'] + colormap = {i: (i, i, i) for i in range(150)} + + raw_img = singleband.singleband(ds_keys, raster_file_xyz, colormap=colormap) + img_data = np.asarray(Image.open(raw_img).convert('RGBA')) + + # get unstretched data to compare to + driver = terracotta.get_driver(read_only_database) + + tile_x, tile_y, tile_z = raster_file_xyz + + with driver.connect(): + tile_data = get_tile_data(driver, ds_keys, tile_x=tile_x, tile_y=tile_y, 
tile_z=tile_z, + tilesize=img_data.shape[:2]) + + # check that labels are mapped to colors correctly + for cmap_label, cmap_color in colormap.items(): + assert np.all(img_data[tile_data == cmap_label] == np.array([*cmap_color, 255])) + + # check that all data outside of labels is transparent + assert np.all(img_data[~np.isin(tile_data, colormap.keys()), -1] == 0) From 80caae16b450619bf3996092428c812b8fbdaf1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Sat, 15 Sep 2018 20:42:50 +0200 Subject: [PATCH 04/15] implement api handler and harden testing --- terracotta/api/singleband.py | 66 ++++++++++++++++++++++++------- terracotta/handlers/singleband.py | 8 ++-- tests/api/test_flask_api.py | 50 +++++++++++++++++++++++ tests/conftest.py | 7 ++++ tests/handlers/test_colormap.py | 1 + tests/handlers/test_singleband.py | 36 +---------------- 6 files changed, 116 insertions(+), 52 deletions(-) diff --git a/terracotta/api/singleband.py b/terracotta/api/singleband.py index 51d59891..3c8ac5ab 100644 --- a/terracotta/api/singleband.py +++ b/terracotta/api/singleband.py @@ -6,7 +6,8 @@ from typing import Any, Mapping, Dict import json -from marshmallow import Schema, fields, validate, pre_load, ValidationError, EXCLUDE +from marshmallow import (Schema, fields, validate, validates_schema, + pre_load, ValidationError, EXCLUDE) from flask import request, send_file from terracotta.api.flask_api import convert_exceptions, tile_api @@ -29,19 +30,55 @@ class Meta: description='Stretch range to use as JSON array, uses full range by default. 
' 'Null values indicate global minimum / maximum.', missing=None ) - colormap = fields.String(description='Colormap to apply to image (see /colormap)', - missing=None, validate=validate.OneOf(AVAILABLE_CMAPS)) + + colormap = fields.String( + description='Colormap to apply to image (see /colormap)', + validate=validate.OneOf(('explicit', *AVAILABLE_CMAPS)), missing=None + ) + + explicit_color_map = fields.Dict( + keys=fields.Number(), + values=fields.List(fields.Number, validate=validate.Length(equal=3)), + example='{{0: (255, 255, 255)}}', + description='Explicit value-color mapping to use as JSON object. ' + 'Must be given together with colormap=explicit. Color values can be ' + 'specified either as RGB tuple (in the range of [0, 255]), or as ' + 'hex strings.' + ) + + @validates_schema + def validate_cmap(self, data: Mapping[str, Any]) -> None: + if data.get('colormap', '') == 'explicit' and not data.get('explicit_color_map'): + raise ValidationError('explicit_color_map argument must be given for colormap=explicit', + 'colormap') + + if data.get('explicit_color_map') and data.get('colormap', '') != 'explicit': + raise ValidationError('explicit_color_map can only be given for colormap=explicit', + 'explicit_color_map') @pre_load - def process_ranges(self, data: Mapping[str, Any]) -> Dict[str, Any]: + def decode_json(self, data: Mapping[str, Any]) -> Dict[str, Any]: data = dict(data.items()) - var = 'stretch_range' - val = data.get(var) - if val: - try: - data[var] = json.loads(val) - except json.decoder.JSONDecodeError as exc: - raise ValidationError(f'Could not decode value for {var} as JSON') from exc + for var in ('stretch_range', 'explicit_color_map'): + val = data.get(var) + if val: + try: + data[var] = json.loads(val) + except json.decoder.JSONDecodeError as exc: + raise ValidationError(f'Could not decode value for {var} as JSON') from exc + + val = data.get('explicit_color_map') + if val and isinstance(val, dict): + for key, color in val.items(): + if 
isinstance(color, str): + hex_string = color.lstrip('#') + try: + rgb = [int(hex_string[i:i + 2], 16) for i in (0, 2, 4)] + data['explicit_color_map'][key] = rgb + except ValueError: + msg = f'Could not decode value {color} in explicit_color_map as hex string' + raise ValidationError(msg) + return data @@ -79,8 +116,9 @@ def get_singleband(tile_z: int, tile_y: int, tile_x: int, keys: str) -> Any: option_schema = SinglebandOptionSchema() options = option_schema.load(request.args) - image = singleband( - parsed_keys, tile_xyz, **options - ) + if options.get('colormap', '') == 'explicit': + options['colormap'] = options.pop('explicit_color_map') + + image = singleband(parsed_keys, tile_xyz, **options) return send_file(image, mimetype='image/png') diff --git a/terracotta/handlers/singleband.py b/terracotta/handlers/singleband.py index ce263481..5cd06b69 100644 --- a/terracotta/handlers/singleband.py +++ b/terracotta/handlers/singleband.py @@ -22,7 +22,7 @@ def singleband(keys: Union[Sequence[str], Mapping[str, str]], stretch_range: Tuple[Number, Number] = None) -> BinaryIO: """Return singleband image as PNG""" - colormap_: Union[str, Sequence[RGB], None] + cmap_or_palette: Union[str, Sequence[RGB], None] try: tile_x, tile_y, tile_z = tile_xyz @@ -55,7 +55,7 @@ def singleband(keys: Union[Sequence[str], Mapping[str, str]], labels, label_colors = list(colormap.keys()), list(colormap.values()) - colormap_ = label_colors + cmap_or_palette = label_colors out = image.label(tile_data, labels) else: # determine stretch range from metadata and arguments @@ -67,7 +67,7 @@ def singleband(keys: Union[Sequence[str], Mapping[str, str]], if stretch_max is not None: stretch_range_[1] = stretch_max - colormap_ = cast(Optional[str], colormap) + cmap_or_palette = cast(Optional[str], colormap) out = image.to_uint8(tile_data, *stretch_range_) - return image.array_to_png(out, transparency_mask=~valid_mask, colormap=colormap_) + return image.array_to_png(out, transparency_mask=~valid_mask, 
colormap=cmap_or_palette) diff --git a/tests/api/test_flask_api.py b/tests/api/test_flask_api.py index f0fd03f2..f887810b 100644 --- a/tests/api/test_flask_api.py +++ b/tests/api/test_flask_api.py @@ -93,6 +93,56 @@ def test_get_singleband_cmap(client, use_read_only_database, raster_file_xyz): assert np.asarray(img).shape == settings.TILE_SIZE +def test_get_singleband_explicit_cmap(client, use_read_only_database, raster_file_xyz): + import terracotta + settings = terracotta.get_settings() + + x, y, z = raster_file_xyz + explicit_cmap = {1: (0, 0, 0), 2.0: (255, 255, 255), 3: '#ffffff'} + + rv = client.get(f'/singleband/val11/val12/{z}/{x}/{y}.png?colormap=explicit' + f'&explicit_color_map={json.dumps(explicit_cmap)}') + print(f'/singleband/val11/val12/{z}/{x}/{y}.png?colormap=explicit' + f'&explicit_color_map={json.dumps(explicit_cmap)}') + assert rv.status_code == 200 + + img = Image.open(BytesIO(rv.data)) + assert np.asarray(img).shape == settings.TILE_SIZE + + +def test_get_singleband_explicit_cmap_invalid(client, use_read_only_database, raster_file_xyz): + import terracotta + settings = terracotta.get_settings() + + x, y, z = raster_file_xyz + explicit_cmap = {1: (0, 0, 0), 2: (255, 255, 255), 3: '#ffffff'} + + rv = client.get(f'/singleband/val11/val12/{z}/{x}/{y}.png?' 
+ f'explicit_color_map={json.dumps(explicit_cmap)}') + assert rv.status_code == 400 + + rv = client.get(f'/singleband/val11/val12/{z}/{x}/{y}.png?colormap=jet' + f'&explicit_color_map={json.dumps(explicit_cmap)}') + assert rv.status_code == 400 + + rv = client.get(f'/singleband/val11/val12/{z}/{x}/{y}.png?colormap=explicit') + assert rv.status_code == 400 + + explicit_cmap[3] = 'omgomg' + rv = client.get(f'/singleband/val11/val12/{z}/{x}/{y}.png?colormap=explicit' + f'&explicit_color_map={json.dumps(explicit_cmap)}') + assert rv.status_code == 400 + + explicit_cmap = [(255, 255, 255)] + rv = client.get(f'/singleband/val11/val12/{z}/{x}/{y}.png?colormap=explicit' + f'&explicit_color_map={json.dumps(explicit_cmap)}') + assert rv.status_code == 400 + + rv = client.get(f'/singleband/val11/val12/{z}/{x}/{y}.png?colormap=explicit' + f'&explicit_color_map=foo') + assert rv.status_code == 400 + + def test_get_singleband_stretch(client, use_read_only_database, raster_file_xyz): import terracotta settings = terracotta.get_settings() diff --git a/tests/conftest.py b/tests/conftest.py index 3b375c0e..755d2fab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,13 @@ def pytest_configure(config): os.environ['TC_TESTING'] = '1' + # prevent caching to keep tests independent + import terracotta + terracotta.update_settings( + METADATA_CACHE_SIZE=0, + RASTER_CACHE_SIZE=0 + ) + def pytest_unconfigure(config): os.environ['TC_TESTING'] = '0' diff --git a/tests/handlers/test_colormap.py b/tests/handlers/test_colormap.py index 08a4205a..741f886b 100644 --- a/tests/handlers/test_colormap.py +++ b/tests/handlers/test_colormap.py @@ -17,6 +17,7 @@ def test_colormap_handler(): @pytest.mark.parametrize('cmap_name', [None, 'jet']) def test_colormap_consistency(use_read_only_database, read_only_database, raster_file_xyz, stretch_range, cmap_name): + """Test consistency between /colormap and images returned by /singleband""" import terracotta from terracotta.xyz import 
get_tile_data from terracotta.handlers import singleband, colormap diff --git a/tests/handlers/test_singleband.py b/tests/handlers/test_singleband.py index 7d7b59ea..9c949919 100644 --- a/tests/handlers/test_singleband.py +++ b/tests/handlers/test_singleband.py @@ -30,40 +30,8 @@ def test_singleband_out_of_bounds(use_read_only_database, raster_file): singleband.singleband(keys, (10, 0, 0)) -@pytest.mark.parametrize('stretch_range', [[0, 20000], [10000, 20000], [-50000, 50000]]) -def test_singleband_stretch(stretch_range, use_read_only_database, read_only_database, raster_file_xyz): - import terracotta - from terracotta.xyz import get_tile_data - from terracotta.handlers import singleband - - ds_keys = ['val21', 'val22'] - - raw_img = singleband.singleband(ds_keys, raster_file_xyz, stretch_range=stretch_range) - img_data = np.asarray(Image.open(raw_img)) - - # get unstretched data to compare to - driver = terracotta.get_driver(read_only_database) - - tile_x, tile_y, tile_z = raster_file_xyz - - with driver.connect(): - tile_data = get_tile_data(driver, ds_keys, tile_x=tile_x, tile_y=tile_y, tile_z=tile_z, - tilesize=img_data.shape) - - # filter transparent values - valid_mask = tile_data != 0 - assert np.all(img_data[~valid_mask] == 0) - - valid_img = img_data[valid_mask] - valid_data = tile_data[valid_mask] - - assert np.all(valid_img[valid_data < stretch_range[0]] == 1) - stretch_range_mask = (valid_data > stretch_range[0]) & (valid_data < stretch_range[1]) - assert not np.any(np.isin(valid_img[stretch_range_mask], [1, 255])) - assert np.all(valid_img[valid_data > stretch_range[1]] == 255) - - -def test_explicit_colormap(use_read_only_database, read_only_database, raster_file_xyz): +def test_singleband_explicit_colormap(use_read_only_database, read_only_database, + raster_file_xyz): import terracotta from terracotta.xyz import get_tile_data from terracotta.handlers import singleband From 3a10f2d2870d511055640acb28c6e3014be04561 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Dion=20H=C3=A4fner?= Date: Sat, 15 Sep 2018 21:09:10 +0200 Subject: [PATCH 05/15] debug failing test --- tests/api/test_flask_api.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/api/test_flask_api.py b/tests/api/test_flask_api.py index f887810b..cb32b688 100644 --- a/tests/api/test_flask_api.py +++ b/tests/api/test_flask_api.py @@ -102,9 +102,7 @@ def test_get_singleband_explicit_cmap(client, use_read_only_database, raster_fil rv = client.get(f'/singleband/val11/val12/{z}/{x}/{y}.png?colormap=explicit' f'&explicit_color_map={json.dumps(explicit_cmap)}') - print(f'/singleband/val11/val12/{z}/{x}/{y}.png?colormap=explicit' - f'&explicit_color_map={json.dumps(explicit_cmap)}') - assert rv.status_code == 200 + assert rv.status_code == 200, rv.data.decode('utf-8') img = Image.open(BytesIO(rv.data)) assert np.asarray(img).shape == settings.TILE_SIZE From 51f493e4255b9d1f2f578a8ba0a60dceace9d7b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Sat, 15 Sep 2018 21:20:58 +0200 Subject: [PATCH 06/15] more debug --- terracotta/api/singleband.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/terracotta/api/singleband.py b/terracotta/api/singleband.py index 3c8ac5ab..e3f76e92 100644 --- a/terracotta/api/singleband.py +++ b/terracotta/api/singleband.py @@ -65,7 +65,8 @@ def decode_json(self, data: Mapping[str, Any]) -> Dict[str, Any]: try: data[var] = json.loads(val) except json.decoder.JSONDecodeError as exc: - raise ValidationError(f'Could not decode value for {var} as JSON') from exc + msg = f'Could not decode value {val} for {var} as JSON' + raise ValidationError(msg) from exc val = data.get('explicit_color_map') if val and isinstance(val, dict): From a9941ab873321ac5512be061e6536a283688aa35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Sat, 15 Sep 2018 21:42:22 +0200 Subject: [PATCH 07/15] escape json before passing to url --- tests/api/test_flask_api.py | 22 
++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/tests/api/test_flask_api.py b/tests/api/test_flask_api.py index cb32b688..c7538942 100644 --- a/tests/api/test_flask_api.py +++ b/tests/api/test_flask_api.py @@ -1,5 +1,6 @@ -import json from io import BytesIO +import json +import urllib.parse from PIL import Image import numpy as np @@ -93,15 +94,20 @@ def test_get_singleband_cmap(client, use_read_only_database, raster_file_xyz): assert np.asarray(img).shape == settings.TILE_SIZE +def urlsafe_json(payload): + payload_json = json.dumps(payload) + return urllib.parse.unquote_plus(payload_json) + + def test_get_singleband_explicit_cmap(client, use_read_only_database, raster_file_xyz): import terracotta settings = terracotta.get_settings() x, y, z = raster_file_xyz - explicit_cmap = {1: (0, 0, 0), 2.0: (255, 255, 255), 3: '#ffffff'} + explicit_cmap = {1: (0, 0, 0), 2.0: (255, 255, 255), 3: '#ffffff', 4: 'abcabc'} rv = client.get(f'/singleband/val11/val12/{z}/{x}/{y}.png?colormap=explicit' - f'&explicit_color_map={json.dumps(explicit_cmap)}') + f'&explicit_color_map={urlsafe_json(explicit_cmap)}') assert rv.status_code == 200, rv.data.decode('utf-8') img = Image.open(BytesIO(rv.data)) @@ -113,14 +119,14 @@ def test_get_singleband_explicit_cmap_invalid(client, use_read_only_database, ra settings = terracotta.get_settings() x, y, z = raster_file_xyz - explicit_cmap = {1: (0, 0, 0), 2: (255, 255, 255), 3: '#ffffff'} + explicit_cmap = {1: (0, 0, 0), 2: (255, 255, 255), 3: '#ffffff', 4: 'abcabc'} rv = client.get(f'/singleband/val11/val12/{z}/{x}/{y}.png?' 
- f'explicit_color_map={json.dumps(explicit_cmap)}') + f'explicit_color_map={urlsafe_json(explicit_cmap)}') assert rv.status_code == 400 rv = client.get(f'/singleband/val11/val12/{z}/{x}/{y}.png?colormap=jet' - f'&explicit_color_map={json.dumps(explicit_cmap)}') + f'&explicit_color_map={urlsafe_json(explicit_cmap)}') assert rv.status_code == 400 rv = client.get(f'/singleband/val11/val12/{z}/{x}/{y}.png?colormap=explicit') @@ -128,12 +134,12 @@ def test_get_singleband_explicit_cmap_invalid(client, use_read_only_database, ra explicit_cmap[3] = 'omgomg' rv = client.get(f'/singleband/val11/val12/{z}/{x}/{y}.png?colormap=explicit' - f'&explicit_color_map={json.dumps(explicit_cmap)}') + f'&explicit_color_map={urlsafe_json(explicit_cmap)}') assert rv.status_code == 400 explicit_cmap = [(255, 255, 255)] rv = client.get(f'/singleband/val11/val12/{z}/{x}/{y}.png?colormap=explicit' - f'&explicit_color_map={json.dumps(explicit_cmap)}') + f'&explicit_color_map={urlsafe_json(explicit_cmap)}') assert rv.status_code == 400 rv = client.get(f'/singleband/val11/val12/{z}/{x}/{y}.png?colormap=explicit' From 8d2e2c192b4137ee0c66c893da32a8517d8fe437 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Sun, 16 Sep 2018 12:16:12 +0200 Subject: [PATCH 08/15] fix url quoting --- tests/api/test_flask_api.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/api/test_flask_api.py b/tests/api/test_flask_api.py index c7538942..2f4ff34f 100644 --- a/tests/api/test_flask_api.py +++ b/tests/api/test_flask_api.py @@ -96,7 +96,7 @@ def test_get_singleband_cmap(client, use_read_only_database, raster_file_xyz): def urlsafe_json(payload): payload_json = json.dumps(payload) - return urllib.parse.unquote_plus(payload_json) + return urllib.parse.quote_plus(payload_json, safe=r',.[]{}:"') def test_get_singleband_explicit_cmap(client, use_read_only_database, raster_file_xyz): @@ -108,7 +108,9 @@ def test_get_singleband_explicit_cmap(client, use_read_only_database, 
raster_fil rv = client.get(f'/singleband/val11/val12/{z}/{x}/{y}.png?colormap=explicit' f'&explicit_color_map={urlsafe_json(explicit_cmap)}') + print(urlsafe_json(explicit_cmap)) assert rv.status_code == 200, rv.data.decode('utf-8') + assert False img = Image.open(BytesIO(rv.data)) assert np.asarray(img).shape == settings.TILE_SIZE From 9c81a582f6b02f43bdd6170497ccf411bd605ade Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Sun, 16 Sep 2018 12:24:10 +0200 Subject: [PATCH 09/15] yea i'm stupid --- tests/api/test_flask_api.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/api/test_flask_api.py b/tests/api/test_flask_api.py index 2f4ff34f..88fa6a25 100644 --- a/tests/api/test_flask_api.py +++ b/tests/api/test_flask_api.py @@ -108,9 +108,7 @@ def test_get_singleband_explicit_cmap(client, use_read_only_database, raster_fil rv = client.get(f'/singleband/val11/val12/{z}/{x}/{y}.png?colormap=explicit' f'&explicit_color_map={urlsafe_json(explicit_cmap)}') - print(urlsafe_json(explicit_cmap)) assert rv.status_code == 200, rv.data.decode('utf-8') - assert False img = Image.open(BytesIO(rv.data)) assert np.asarray(img).shape == settings.TILE_SIZE From 4f29f13ae98a30643d1e6b8cd01a61a3fb1ff9d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Sun, 16 Sep 2018 14:57:30 +0200 Subject: [PATCH 10/15] fix colormap class naming --- terracotta/api/colormap.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/terracotta/api/colormap.py b/terracotta/api/colormap.py index 02185045..828d2665 100644 --- a/terracotta/api/colormap.py +++ b/terracotta/api/colormap.py @@ -13,16 +13,16 @@ from terracotta.cmaps import AVAILABLE_CMAPS -class colormapEntrySchema(Schema): +class ColormapEntrySchema(Schema): value = fields.Number(required=True) rgb = fields.List(fields.Number(), required=True, validate=validate.Length(equal=3)) -class colormapSchema(Schema): - colormap = fields.Nested(colormapEntrySchema, 
many=True, required=True) +class ColormapSchema(Schema): + colormap = fields.Nested(ColormapEntrySchema, many=True, required=True) -class colormapOptionSchema(Schema): +class ColormapOptionSchema(Schema): class Meta: unknown = EXCLUDE @@ -63,24 +63,24 @@ def get_colormap() -> str: dataset. parameters: - in: query - schema: colormapOptionSchema + schema: ColormapOptionSchema responses: 200: description: Array containing data values and RGBA tuples - schema: colormapSchema + schema: ColormapSchema 400: description: Query parameters are invalid """ from terracotta.handlers.colormap import colormap - input_schema = colormapOptionSchema() + input_schema = ColormapOptionSchema() options = input_schema.load(request.args) payload = {'colormap': colormap(**options)} - schema = colormapSchema() + schema = ColormapSchema() return jsonify(schema.load(payload)) -spec.definition('colormapEntry', schema=colormapEntrySchema) -spec.definition('colormap', schema=colormapSchema) +spec.definition('ColormapEntry', schema=ColormapEntrySchema) +spec.definition('Colormap', schema=ColormapSchema) From 2ab093f9e5d45f6172a0e3809e2e7ab8515ead32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Mon, 17 Sep 2018 15:50:58 +0200 Subject: [PATCH 11/15] improve convex hull computation #69 #70 --- terracotta/drivers/raster_base.py | 42 ++++++++----------- tests/conftest.py | 6 ++- tests/drivers/test_staticmethods.py | 62 +++++++++++++++-------------- 3 files changed, 53 insertions(+), 57 deletions(-) diff --git a/terracotta/drivers/raster_base.py b/terracotta/drivers/raster_base.py index c3719d14..b6aff7d7 100644 --- a/terracotta/drivers/raster_base.py +++ b/terracotta/drivers/raster_base.py @@ -25,7 +25,7 @@ except ImportError: has_crick = False -from terracotta import get_settings, exceptions +from terracotta import get_settings, exceptions, image from terracotta.drivers.base import requires_connection, Driver from terracotta.profile import trace @@ -60,12 +60,12 @@ def 
_compute_image_stats_chunked(dataset: 'DatasetReader', nodata: Number) -> Optional[Dict[str, Any]]: """Loop over chunks and accumulate statistics""" from rasterio import features, warp, windows - from shapely import geometry, ops + from shapely import geometry total_count = valid_data_count = 0 tdigest = TDigest() sstats = SummaryStats() - convex_hulls = [] + convex_hull = geometry.Polygon() block_windows = [w for _, w in dataset.block_windows(1)] @@ -76,10 +76,7 @@ def _compute_image_stats_chunked(dataset: 'DatasetReader', total_count += int(block_data.size) - valid_data_mask = np.isfinite(block_data) - if not np.isnan(nodata): - valid_data_mask &= (block_data != nodata) - + valid_data_mask = image.get_valid_mask(block_data, nodata) valid_data = block_data[valid_data_mask] if valid_data.size == 0: @@ -87,14 +84,12 @@ def _compute_image_stats_chunked(dataset: 'DatasetReader', valid_data_count += int(valid_data.size) - # this formulation allows us to store only one convex hull per block, - # which should be relatively lightweight block_shapes = [geometry.shape(s) for s, _ in features.shapes( valid_data_mask.astype('uint8'), mask=valid_data_mask, transform=windows.transform(w, dataset.transform) )] - convex_hulls.append(ops.unary_union(block_shapes).convex_hull) + convex_hull = geometry.MultiPolygon([convex_hull, *block_shapes]).convex_hull tdigest.update(valid_data) sstats.update(valid_data) @@ -102,12 +97,9 @@ def _compute_image_stats_chunked(dataset: 'DatasetReader', if sstats.count() == 0: return None - # remove merge artefacts, transform, re-compute convex hull - convex_hull = ops.unary_union(convex_hulls).simplify(0) - convex_hull = warp.transform_geom( + convex_hull_wgs = warp.transform_geom( dataset.crs, 'epsg:4326', geometry.mapping(convex_hull) ) - convex_hull = geometry.shape(convex_hull).convex_hull return { 'valid_percentage': valid_data_count / total_count * 100, @@ -115,34 +107,32 @@ def _compute_image_stats_chunked(dataset: 'DatasetReader', 'mean': 
sstats.mean(), 'stdev': sstats.std(), 'percentiles': tdigest.quantile(np.arange(0.01, 1, 0.01)), - 'convex_hull': geometry.mapping(convex_hull) + 'convex_hull': convex_hull_wgs } @staticmethod def _compute_image_stats(dataset: 'DatasetReader', nodata: Number) -> Optional[Dict[str, Any]]: from rasterio import features, warp - from shapely import geometry, ops + from shapely import geometry raster_data = dataset.read(1) - valid_data_mask = np.isfinite(raster_data) - if not np.isnan(nodata): - valid_data_mask &= (raster_data != nodata) - + valid_data_mask = image.get_valid_mask(raster_data, nodata) valid_data = raster_data[valid_data_mask] if valid_data.size == 0: return None - raster_features = features.shapes( + raster_shapes = [geometry.shape(s) for s, _ in features.shapes( valid_data_mask.astype('uint8'), - mask=valid_data_mask, + mask=valid_data_mask.astype('bool'), transform=dataset.transform + )] + convex_hull = geometry.MultiPolygon(raster_shapes).convex_hull + convex_hull_wgs = warp.transform_geom( + dataset.crs, 'epsg:4326', geometry.mapping(convex_hull) ) - raster_shapes_wgs = [geometry.shape(warp.transform_geom(dataset.crs, 'epsg:4326', s)) - for s, _ in raster_features] - convex_hull = ops.unary_union(raster_shapes_wgs).convex_hull return { 'valid_percentage': valid_data.size / raster_data.size * 100, @@ -150,7 +140,7 @@ def _compute_image_stats(dataset: 'DatasetReader', 'mean': float(valid_data.mean()), 'stdev': float(valid_data.std()), 'percentiles': np.percentile(valid_data, np.arange(1, 100)), - 'convex_hull': geometry.mapping(convex_hull) + 'convex_hull': convex_hull_wgs } @staticmethod diff --git a/tests/conftest.py b/tests/conftest.py index 755d2fab..e7a39ea0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -56,10 +56,12 @@ def big_raster_file(tmpdir_factory): raster_data = np.random.randint(0, np.iinfo(np.uint16).max, size=(1024, 1024), dtype='uint16') # include some big nodata regions + ix, iy = np.indices(raster_data.shape) + 
circular_mask = np.sqrt((ix - raster_data.shape[0] / 2) ** 2 + + (iy - raster_data.shape[1] / 2) ** 2) > 400 + raster_data[circular_mask] = 0 raster_data[200:600, 400:800] = 0 raster_data[500, :] = 0 - raster_data[900:, :] = 0 - raster_data[800:, 800:] = 0 profile = { 'driver': 'GTiff', diff --git a/tests/drivers/test_staticmethods.py b/tests/drivers/test_staticmethods.py index c72556d5..5c4fbd51 100644 --- a/tests/drivers/test_staticmethods.py +++ b/tests/drivers/test_staticmethods.py @@ -25,12 +25,15 @@ def test_default_transform(): assert our_height == args[3] +def geometry_mismatch(shape1, shape2): + return shape1.symmetric_difference(shape2).area / shape1.union(shape2).area + + @pytest.mark.parametrize('use_chunks', [True, False]) def test_compute_metadata(big_raster_file, use_chunks): import rasterio import rasterio.features - from shapely.geometry import shape - from shapely.ops import unary_union + from shapely.geometry import shape, MultiPolygon, mapping import numpy as np from terracotta.drivers.raster_base import RasterDriver @@ -42,7 +45,7 @@ def test_compute_metadata(big_raster_file, use_chunks): src, bidx=1, as_mask=True, geographic=True )) - convex_hull = unary_union([shape(s['geometry']) for s in dataset_shape]).convex_hull + convex_hull = MultiPolygon([shape(s['geometry']) for s in dataset_shape]).convex_hull # compare mtd = RasterDriver.compute_metadata(str(big_raster_file), use_chunks=use_chunks) @@ -59,7 +62,7 @@ def test_compute_metadata(big_raster_file, use_chunks): rtol=0.01 ) - assert shape(mtd['convex_hull']).equals(convex_hull) + assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 1e-8 @pytest.mark.parametrize('use_chunks', [True, False]) @@ -71,12 +74,9 @@ def test_compute_metadata_invalid(invalid_raster_file, use_chunks): def test_compute_metadata_nocrick(big_raster_file): - import importlib - import rasterio import rasterio.features - from shapely.geometry import shape - from shapely.ops import unary_union + from 
shapely.geometry import shape, MultiPolygon, mapping import numpy as np with rasterio.open(str(big_raster_file)) as src: @@ -86,26 +86,30 @@ def test_compute_metadata_nocrick(big_raster_file): src, bidx=1, as_mask=True, geographic=True )) - convex_hull = unary_union([shape(s['geometry']) for s in dataset_shape]).convex_hull + convex_hull = MultiPolygon([shape(s['geometry']) for s in dataset_shape]).convex_hull import terracotta.drivers.raster_base - terracotta.drivers.raster_base.has_crick = False - - with pytest.warns(UserWarning): - mtd = terracotta.drivers.raster_base.RasterDriver.compute_metadata( - str(big_raster_file), use_chunks=True) - - # compare - np.testing.assert_allclose(mtd['valid_percentage'], 100 * valid_data.size / data.size) - np.testing.assert_allclose(mtd['range'], (valid_data.min(), valid_data.max())) - np.testing.assert_allclose(mtd['mean'], valid_data.mean()) - np.testing.assert_allclose(mtd['stdev'], valid_data.std()) - - # allow error of 1%, since we only compute approximate quantiles - np.testing.assert_allclose( - mtd['percentiles'], - np.percentile(valid_data, np.arange(1, 100)), - rtol=0.01 - ) - - assert shape(mtd['convex_hull']).equals(convex_hull) + try: + terracotta.drivers.raster_base.has_crick = False + + with pytest.warns(UserWarning): + mtd = terracotta.drivers.raster_base.RasterDriver.compute_metadata( + str(big_raster_file), use_chunks=True) + + # compare + np.testing.assert_allclose(mtd['valid_percentage'], 100 * valid_data.size / data.size) + np.testing.assert_allclose(mtd['range'], (valid_data.min(), valid_data.max())) + np.testing.assert_allclose(mtd['mean'], valid_data.mean()) + np.testing.assert_allclose(mtd['stdev'], valid_data.std()) + + # allow error of 1%, since we only compute approximate quantiles + np.testing.assert_allclose( + mtd['percentiles'], + np.percentile(valid_data, np.arange(1, 100)), + rtol=0.01 + ) + + assert geometry_mismatch(shape(mtd['convex_hull']), convex_hull) < 1e-8 + + finally: + 
terracotta.drivers.raster_base.has_crick = True From 43c8e62cd3cd2dc1ffbee87d990633d523b9760e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Mon, 17 Sep 2018 15:51:54 +0200 Subject: [PATCH 12/15] use union instead of intersection --- tests/drivers/test_staticmethods.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/drivers/test_staticmethods.py b/tests/drivers/test_staticmethods.py index 5c4fbd51..187f9b49 100644 --- a/tests/drivers/test_staticmethods.py +++ b/tests/drivers/test_staticmethods.py @@ -26,6 +26,7 @@ def test_default_transform(): def geometry_mismatch(shape1, shape2): + """Compute relative mismatch of two shapes""" return shape1.symmetric_difference(shape2).area / shape1.union(shape2).area From 1d2b856432bf5fc4b0a741e9ede28fcd66c73d9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Mon, 17 Sep 2018 16:29:11 +0200 Subject: [PATCH 13/15] add recipe to README --- README.md | 81 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 924299d0..59e4d6ec 100644 --- a/README.md +++ b/README.md @@ -195,11 +195,90 @@ For all available settings, their types and default values, have a look at the f [config.py](https://github.com/DHI-GRAS/terracotta/blob/master/terracotta/config.py) in the Terracotta code. +## Advanced recipes + +### Serving categorical data + +Categorical datasets are special in that the numerical pixel values carry no direct meaning, +but rather encode which category or label the pixel belongs to. Because labels must be preserved, +serving categorical data comes with its own set of complications: + +- Dynamical stretching does not make sense +- Nearest neighbor resampling must be used +- Labels must be mapped to colors consistently + +So far, Terracotta is agnostic of categories and labels, but the API is flexible enough to give +you the tools to build your own system. 
Categorical data can be served by following these steps: + +#### During ingestion + +1. Create an additional key to encode whether a dataset is categorical or not. E.g., if you are + currently using the keys `sensor`, `date`, and `band`, ingest your data with the keys + `[type, sensor, date, band]`, where `type` can take one of the values `categorical`, `index`, + `reflectance`, or whatever makes sense for your given application. +2. Attach a mapping `category name -> pixel value` to the metadata of your categorical dataset. + Using the Python API, this could e.g. be done like this: + ```python + import terracotta as tc + + driver = tc.get_driver('terracotta.sqlite') + + # assuming keys are [type, sensor, date, band] + keys = ['categorical', 'S2', '20181010', 'cloudmask'] + raster_path = 'cloud_mask.tif' + + category_map = { + 'clear land': 0, + 'clear water': 1, + 'cloud': 2, + 'cloud shadow': 3 + } + + with driver.connect(): + metadata = driver.compute_metadata(raster_path, extra_metadata={'categories': category_map}) + driver.insert(keys, raster_path, metadata=metadata) + ``` + +#### In the frontend + +Ingesting categorical data this way allows us to access it from the frontend. Given that your +Terracotta server runs at `example.com`, you can use the following functionality: + +- To get a list of all categorical data, simply send a GET request to + `example.com/datasets?type=categorical`. +- To get the available categories of a dataset, query + `example.com/metadata/categorical/S2/20181010/cloudmask`. The returned JSON object will contain + a section like this: + + ```json + { + ... + "extra_metadata": { + "categories": { + "clear land": 0, + "clear water": 1, + "cloud": 2, + "cloud shadow": 3 + } + } + } + ``` +- To get correctly labelled imagery, the frontend will have to pass an explicit color mapping of pixel + values to colors by using `/singleband`'s `explicit_color_map` argument. 
In our case, this could look + like this: + `example.com/singleband/categorical/S2/20181010/cloudmask/{z}/{x}/{y}.png?colormap=explicit&explicit_color_map={"0": "99d594", "1": "2b83ba", "2": "ffffff", "3": "404040"}`. + + Supplying an explicit color map in this fashion suppresses stretching, and forces Terracotta to only use + nearest neighbor resampling when reading the data. + + Colors can be passed as hex strings (as in this example) or RGB color tuples. In case you are looking + for a nice color scheme for your categorical datasets, [color brewer](http://colorbrewer2.org) features + some excellent suggestions. + ## Deployment on AWS λ The easiest way to deploy Terracotta on AWS λ is by using [Zappa](https://github.com/Miserlou/Zappa). - Example `zappa_settings.json` file: ```json From 80ae74f1e8da5610a74ac710fd2bb186fd4d3fff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Mon, 17 Sep 2018 16:32:53 +0200 Subject: [PATCH 14/15] prettify --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 59e4d6ec..17103586 100644 --- a/README.md +++ b/README.md @@ -218,6 +218,7 @@ you the tools to build your own system. Categorical data can be served by follow `reflectance`, or whatever makes sense for your given application. 2. Attach a mapping `category name -> pixel value` to the metadata of your categorical dataset. Using the Python API, this could e.g. be done like this: + ```python import terracotta as tc @@ -252,7 +253,6 @@ Terracotta server runs at `example.com`, you can use the following functionality ```json { - ... 
"extra_metadata": { "categories": { "clear land": 0, From 681e6d874a98a3af7435597b2cdc37db707868dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Mon, 17 Sep 2018 16:56:45 +0200 Subject: [PATCH 15/15] lighten multithreaded tests --- tests/drivers/test_drivers.py | 314 --------------------------- tests/drivers/test_raster_drivers.py | 8 +- 2 files changed, 6 insertions(+), 316 deletions(-) diff --git a/tests/drivers/test_drivers.py b/tests/drivers/test_drivers.py index ada3dba9..2fe064b6 100644 --- a/tests/drivers/test_drivers.py +++ b/tests/drivers/test_drivers.py @@ -1,11 +1,7 @@ -import numpy as np import pytest - DRIVERS = ['sqlite'] -METADATA_KEYS = ('bounds', 'nodata', 'range', 'mean', 'stdev', 'percentiles', 'metadata') - @pytest.mark.parametrize('provider', DRIVERS) def test_creation(tmpdir, provider): @@ -64,313 +60,3 @@ def test_recreation(tmpdir, provider): assert db.available_keys == keys assert db.get_datasets() == {} - -@pytest.mark.parametrize('provider', DRIVERS) -def test_insertion_and_retrieval(tmpdir, raster_file, provider): - from terracotta import drivers - dbfile = tmpdir.join('test.sqlite') - db = drivers.get_driver(str(dbfile), provider=provider) - keys = ('some', 'keys') - - db.create(keys) - db.insert(['some', 'value'], str(raster_file)) - - data = db.get_datasets() - assert list(data.keys()) == [('some', 'value')] - assert data[('some', 'value')] == str(raster_file) - - metadata = db.get_metadata(('some', 'value')) - assert all(key in metadata for key in METADATA_KEYS) - - -@pytest.mark.parametrize('provider', DRIVERS) -def test_where(tmpdir, raster_file, provider): - from terracotta import drivers - dbfile = tmpdir.join('test.sqlite') - db = drivers.get_driver(str(dbfile), provider=provider) - keys = ('some', 'keys') - - db.create(keys) - db.insert(['some', 'value'], str(raster_file)) - db.insert(['some', 'other_value'], str(raster_file)) - db.insert({'some': 'a', 'keys': 'third_value'}, str(raster_file)) - - data = 
db.get_datasets() - assert len(data) == 3 - - data = db.get_datasets(where=dict(some='some')) - assert len(data) == 2 - - data = db.get_datasets(where=dict(some='some', keys='value')) - assert list(data.keys()) == [('some', 'value')] - assert data[('some', 'value')] == str(raster_file) - - -@pytest.mark.parametrize('provider', DRIVERS) -def test_lazy_loading(tmpdir, raster_file, provider): - from terracotta import drivers - dbfile = tmpdir.join('test.sqlite') - db = drivers.get_driver(str(dbfile), provider=provider) - keys = ('some', 'keys') - - db.create(keys) - db.insert(['some', 'value'], str(raster_file), skip_metadata=False) - db.insert(['some', 'other_value'], str(raster_file), skip_metadata=True) - - datasets = db.get_datasets() - assert len(datasets) == 2 - - data1 = db.get_metadata(['some', 'value']) - data2 = db.get_metadata({'some': 'some', 'keys': 'other_value'}) - assert list(data1.keys()) == list(data2.keys()) - assert all(np.all(data1[k] == data2[k]) for k in data1.keys()) - - -@pytest.mark.parametrize('provider', DRIVERS) -def test_precomputed_metadata(tmpdir, raster_file, provider): - from terracotta import drivers - dbfile = tmpdir.join('test.sqlite') - db = drivers.get_driver(str(dbfile), provider=provider) - keys = ('some', 'keys') - - metadata = db.compute_metadata(str(raster_file)) - - db.create(keys) - db.insert(['some', 'value'], str(raster_file), metadata=metadata) - db.insert(['some', 'other_value'], str(raster_file)) - - datasets = db.get_datasets() - assert len(datasets) == 2 - - data1 = db.get_metadata(['some', 'value']) - data2 = db.get_metadata({'some': 'some', 'keys': 'other_value'}) - assert list(data1.keys()) == list(data2.keys()) - assert all(np.all(data1[k] == data2[k]) for k in data1.keys()) - - -@pytest.mark.parametrize('provider', DRIVERS) -def test_invalid_insertion(tmpdir, raster_file, provider): - from terracotta import drivers - - dbfile = tmpdir.join('test.sqlite') - db = drivers.get_driver(str(dbfile), provider=provider) 
- keys = ('key',) - - db.create(keys) - - def throw(*args, **kwargs): - raise NotImplementedError() - - db.compute_metadata = throw - - db.insert(['bar'], str(raster_file), skip_metadata=True) - - with pytest.raises(NotImplementedError): - db.insert(['foo'], str(raster_file), skip_metadata=False) - - datasets = db.get_datasets() - - assert ('bar',) in datasets - assert ('foo',) not in datasets - - -@pytest.mark.parametrize('provider', DRIVERS) -def test_invalid_group_insertion(tmpdir, raster_file, provider): - from terracotta import drivers - - dbfile = tmpdir.join('test.sqlite') - db = drivers.get_driver(str(dbfile), provider=provider) - keys = ('key',) - - db.create(keys) - - def throw(*args, **kwargs): - raise NotImplementedError() - - db.compute_metadata = throw - - with db.connect(): - db.insert(['bar'], str(raster_file), skip_metadata=True) - - with pytest.raises(NotImplementedError): - db.insert(['foo'], str(raster_file), skip_metadata=False) - - datasets = db.get_datasets() - - assert ('bar',) not in datasets - assert ('foo',) not in datasets - - -@pytest.mark.parametrize('provider', DRIVERS) -def test_insertion_cache(tmpdir, raster_file, provider): - from terracotta import drivers - - dbfile = tmpdir.join('test.sqlite') - db = drivers.get_driver(str(dbfile), provider=provider) - keys = ('key',) - - db.create(keys) - datasets_before = db.get_datasets() - db.insert(['foo'], str(raster_file), skip_metadata=True) - datasets_after = db.get_datasets() - - assert ('foo',) in datasets_after and ('foo',) not in datasets_before - - -def insertion_worker(key, dbfile, raster_file, provider): - from terracotta import drivers - db = drivers.get_driver(str(dbfile), provider=provider) - db.insert([key], str(raster_file), skip_metadata=False) - - -@pytest.mark.parametrize('provider', DRIVERS) -def test_multithreaded_insertion(tmpdir, raster_file, provider): - import functools - import concurrent.futures - from terracotta import drivers - - dbfile = 
tmpdir.join('test.sqlite') - db = drivers.get_driver(str(dbfile), provider=provider) - keys = ('key',) - - db.create(keys) - - key_vals = [str(i) for i in range(100)] - - worker = functools.partial(insertion_worker, dbfile=dbfile, raster_file=raster_file, - provider=provider) - - with concurrent.futures.ThreadPoolExecutor(10) as executor: - for result in executor.map(worker, key_vals): - pass - - datasets = db.get_datasets() - assert all((key,) in datasets for key in key_vals), datasets.keys() - - data1 = db.get_metadata(['77']) - data2 = db.get_metadata({'key': '99'}) - assert list(data1.keys()) == list(data2.keys()) - assert all(np.all(data1[k] == data2[k]) for k in data1.keys()) - - -@pytest.mark.parametrize('provider', DRIVERS) -def test_multiprocess_insertion(tmpdir, raster_file, provider): - import functools - import concurrent.futures - from terracotta import drivers - - dbfile = str(tmpdir.join('test.sqlite')) - raster_file = str(raster_file) - db = drivers.get_driver(dbfile, provider=provider) - keys = ('key',) - - db.create(keys) - - key_vals = [str(i) for i in range(100)] - - worker = functools.partial(insertion_worker, dbfile=dbfile, raster_file=raster_file, - provider=provider) - - with concurrent.futures.ProcessPoolExecutor(4) as executor: - for result in executor.map(worker, key_vals): - pass - - datasets = db.get_datasets() - assert all((key,) in datasets for key in key_vals) - - data1 = db.get_metadata(['77']) - data2 = db.get_metadata({'key': '99'}) - assert list(data1.keys()) == list(data2.keys()) - assert all(np.all(data1[k] == data2[k]) for k in data1.keys()) - - -@pytest.mark.parametrize('provider', DRIVERS) -def test_insertion_invalid_raster(tmpdir, invalid_raster_file, provider): - from terracotta import drivers - - dbfile = str(tmpdir.join('test.sqlite')) - db = drivers.get_driver(dbfile, provider=provider) - keys = ('key',) - - db.create(keys) - - with pytest.raises(ValueError): - db.insert(['val'], str(invalid_raster_file)) - - datasets = 
db.get_datasets() - assert ('val',) not in datasets - - -@pytest.mark.parametrize('provider', DRIVERS) -def test_raster_retrieval(tmpdir, raster_file, provider): - from terracotta import drivers - dbfile = tmpdir.join('test.sqlite') - db = drivers.get_driver(str(dbfile), provider=provider) - keys = ('some', 'keys') - - db.create(keys) - db.insert(['some', 'value'], str(raster_file)) - db.insert(['some', 'other_value'], str(raster_file)) - - data1 = db.get_raster_tile(['some', 'value'], tilesize=(256, 256)) - assert data1.shape == (256, 256) - - data2 = db.get_raster_tile(['some', 'other_value'], tilesize=(256, 256)) - assert data2.shape == (256, 256) - - np.testing.assert_array_equal(data1, data2) - - -@pytest.mark.parametrize('provider', DRIVERS) -def test_raster_duplicate(tmpdir, raster_file, provider): - from terracotta import drivers - dbfile = tmpdir.join('test.sqlite') - db = drivers.get_driver(str(dbfile), provider=provider) - keys = ('some', 'keys') - - db.create(keys) - db.insert(['some', 'value'], str(raster_file)) - db.insert(['some', 'value'], str(raster_file)) - - assert list(db.get_datasets().keys()) == [('some', 'value')] - - -@pytest.mark.parametrize('provider', DRIVERS) -def test_deletion(tmpdir, raster_file, provider): - from terracotta import drivers, exceptions - dbfile = tmpdir.join('test.sqlite') - db = drivers.get_driver(str(dbfile), provider=provider) - keys = ('some', 'keys') - - db.create(keys) - - dataset = {'some': 'some', 'keys': 'value'} - db.insert(dataset, str(raster_file)) - - data = db.get_datasets() - assert list(data.keys()) == [('some', 'value')] - assert data[('some', 'value')] == str(raster_file) - - metadata = db.get_metadata(('some', 'value')) - assert all(key in metadata for key in METADATA_KEYS) - - db.delete(dataset) - assert not db.get_datasets() - - with pytest.raises(exceptions.DatasetNotFoundError): - db.get_metadata(dataset) - - -@pytest.mark.parametrize('provider', DRIVERS) -def test_delete_nonexisting(tmpdir, 
raster_file, provider): - from terracotta import drivers, exceptions - dbfile = tmpdir.join('test.sqlite') - db = drivers.get_driver(str(dbfile), provider=provider) - keys = ('some', 'keys') - - db.create(keys) - - dataset = {'some': 'some', 'keys': 'value'} - - with pytest.raises(exceptions.DatasetNotFoundError): - db.delete(dataset) diff --git a/tests/drivers/test_raster_drivers.py b/tests/drivers/test_raster_drivers.py index 9f5f12fa..acae99ac 100644 --- a/tests/drivers/test_raster_drivers.py +++ b/tests/drivers/test_raster_drivers.py @@ -1,7 +1,6 @@ import numpy as np import pytest - DRIVERS = ['sqlite'] METADATA_KEYS = ('bounds', 'nodata', 'range', 'mean', 'stdev', 'percentiles', 'metadata') @@ -160,9 +159,14 @@ def test_insertion_cache(tmpdir, raster_file, provider): def insertion_worker(key, dbfile, raster_file, provider): + import time from terracotta import drivers db = drivers.get_driver(str(dbfile), provider=provider) - db.insert([key], str(raster_file), skip_metadata=False) + with db.connect(): + db.insert([key], str(raster_file), skip_metadata=True) + # keep connection open for a while to increase the chance of + # triggering a race condition + time.sleep(0.05) @pytest.mark.parametrize('provider', DRIVERS)