From 542b9fb6bcb496a55ff993ffb7ecd54435567fb5 Mon Sep 17 00:00:00 2001 From: teks Date: Thu, 6 Feb 2025 07:51:40 -0500 Subject: [PATCH] 993 progress: convert 3 more files to pytest (#1520) * 993 convert test_item_assets.py to pytest * 993 progress on test_item_collection.py: * replace unittest asserts with pytest asserts * remove maxDiff * 993 remove class (nonfunctional atm) * 993 remove self from methods (still not functional) * 993 install fixtures & finish first test conversion * 993 complete conversion to pytest for test_item_collection.py * 993 pre-commit fixups * 993 start on test_collection.py * 993 revise assertion * 993 replace assertEqual with assert; cmd line: sed -i.bak -E 's/self.assertEqual\((.*), (.*)\)/assert \1 == \2/' tests/test_collection.py note needed a bit of editing as multiple commas in one line could confuse sed * 993 convert rest of asserts * 993 convert CollectionTest to top-level functions * 993 remove self to prep for declassification * 993 convert ExtentTest method tests to functions * 993 convert Collection subclass tests to pytest * 993 complete test_collection.py conversion to pytest * 993 fastidious revisions by ruff --- tests/test_collection.py | 1003 +++++++++++++++++---------------- tests/test_item_assets.py | 110 ++-- tests/test_item_collection.py | 300 +++++----- 3 files changed, 716 insertions(+), 697 deletions(-) diff --git a/tests/test_collection.py b/tests/test_collection.py index 0ccc5a6b1..5aafa1ea0 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -3,7 +3,6 @@ import json import os import tempfile -import unittest from collections.abc import Iterator from copy import deepcopy from datetime import datetime @@ -32,499 +31,512 @@ TEST_DATETIME = datetime(2020, 3, 14, 16, 32) -class ProviderTest(unittest.TestCase): - def test_to_from_dict(self) -> None: - provider_dict = { - "name": "Remote Data, Inc", - "description": "Producers of awesome spatiotemporal assets", - "roles": ["producer", "processor"], - "url": "http://remotedata.io", - "extension:field": "some value", - } - expected_extra_fields = {"extension:field": provider_dict["extension:field"]} - - provider = Provider.from_dict(provider_dict) - - self.assertEqual(provider_dict["name"], provider.name) - self.assertEqual(provider_dict["description"], provider.description) - self.assertEqual(provider_dict["roles"], provider.roles) - self.assertEqual(provider_dict["url"], provider.url) - self.assertDictEqual(expected_extra_fields, provider.extra_fields) - - self.assertDictEqual(provider_dict, provider.to_dict()) - - -class CollectionTest(unittest.TestCase): - def test_spatial_extent_from_coordinates(self) -> None: - extent = SpatialExtent.from_coordinates(ARBITRARY_GEOM["coordinates"]) - - self.assertEqual(len(extent.bboxes), 1) - bbox = extent.bboxes[0] - self.assertEqual(len(bbox), 4) - for x in bbox: - self.assertTrue(isinstance(x, float)) - - def test_read_eo_items_are_heritable(self) -> None: - cat = TestCases.case_5() - item = next(cat.get_items(recursive=True)) - - self.assertTrue(EOExtension.has_extension(item)) - - def test_save_uses_previous_catalog_type(self) -> None: - collection = TestCases.case_8() - assert collection.STAC_OBJECT_TYPE == pystac.STACObjectType.COLLECTION - self.assertEqual(collection.catalog_type, CatalogType.SELF_CONTAINED) - with tempfile.TemporaryDirectory() as tmp_dir: - collection.normalize_hrefs(tmp_dir) - href = collection.self_href - collection.save() - - collection2 = pystac.Collection.from_file(href) - self.assertEqual(collection2.catalog_type, CatalogType.SELF_CONTAINED) - - def test_clone_uses_previous_catalog_type(self) -> None: - catalog = TestCases.case_8() - assert catalog.catalog_type == CatalogType.SELF_CONTAINED - clone = catalog.clone() - self.assertEqual(clone.catalog_type, CatalogType.SELF_CONTAINED) - - def test_clone_cant_mutate_original(self) -> None: - collection = TestCases.case_8() - assert collection.keywords is not None - self.assertListEqual(collection.keywords, ["disaster", "open"]) - clone = collection.clone() - clone.extra_fields["test"] = "extra" - self.assertNotIn("test", collection.extra_fields) - assert clone.keywords is not None - clone.keywords.append("clone") - self.assertListEqual(clone.keywords, ["disaster", "open", "clone"]) - self.assertListEqual(collection.keywords, ["disaster", "open"]) - self.assertNotEqual(id(collection.summaries), id(clone.summaries)) - - def test_multiple_extents(self) -> None: - cat1 = TestCases.case_1() - country = cat1.get_child("country-1") - assert country is not None - col1 = country.get_child("area-1-1") - assert col1 is not None - col1.validate() - self.assertIsInstance(col1, Collection) - validate_dict(col1.to_dict(), pystac.STACObjectType.COLLECTION) - - multi_ext_uri = TestCases.get_path("data-files/collections/multi-extent.json") - with open(multi_ext_uri) as f: - multi_ext_dict = json.load(f) - validate_dict(multi_ext_dict, pystac.STACObjectType.COLLECTION) - self.assertIsInstance(Collection.from_dict(multi_ext_dict), Collection) - - multi_ext_col = Collection.from_file(multi_ext_uri) - multi_ext_col.validate() - ext = multi_ext_col.extent - extent_dict = multi_ext_dict["extent"] - self.assertIsInstance(ext, Extent) - self.assertIsInstance(ext.spatial.bboxes[0], list) - self.assertEqual(len(ext.spatial.bboxes), 3) - self.assertDictEqual(ext.to_dict(), extent_dict) - - cloned_ext = ext.clone() - self.assertDictEqual(cloned_ext.to_dict(), multi_ext_dict["extent"]) - - def test_extra_fields(self) -> None: - catalog = TestCases.case_2() - collection = catalog.get_child("1a8c1632-fa91-4a62-b33e-3a87c2ebdf16") - assert collection is not None - - collection.extra_fields["test"] = "extra" - - with tempfile.TemporaryDirectory() as tmp_dir: - p = os.path.join(tmp_dir, "collection.json") - collection.save_object(include_self_link=False, dest_href=p) - with open(p) as f: - col_json = json.load(f) - self.assertTrue("test" in col_json) - self.assertEqual(col_json["test"], "extra") - - read_col = pystac.Collection.from_file(p) - self.assertTrue("test" in read_col.extra_fields) - self.assertEqual(read_col.extra_fields["test"], "extra") - - def test_update_extents(self) -> None: - catalog = TestCases.case_2() - base_collection = catalog.get_child("1a8c1632-fa91-4a62-b33e-3a87c2ebdf16") - assert isinstance(base_collection, Collection) - base_extent = base_collection.extent - collection = base_collection.clone() - - item1 = Item( - id="test-item-1", - geometry=ARBITRARY_GEOM, - bbox=[-180, -90, 180, 90], - datetime=TEST_DATETIME, - properties={"key": "one"}, - stac_extensions=["eo", "commons"], - ) - - item2 = Item( - id="test-item-1", - geometry=ARBITRARY_GEOM, - bbox=[-180, -90, 180, 90], - datetime=None, - properties={ - "start_datetime": datetime_to_str(datetime(2000, 1, 1, 12, 0, 0, 0)), - "end_datetime": datetime_to_str(datetime(2000, 2, 1, 12, 0, 0, 0)), - }, - stac_extensions=["eo", "commons"], - ) - - collection.add_item(item1) - - collection.update_extent_from_items() - self.assertEqual([[-180, -90, 180, 90]], collection.extent.spatial.bboxes) - self.assertEqual( - len(base_extent.spatial.bboxes[0]), len(collection.extent.spatial.bboxes[0]) - ) - - self.assertNotEqual( - base_extent.temporal.intervals, collection.extent.temporal.intervals - ) - collection.remove_item("test-item-1") - collection.update_extent_from_items() - self.assertNotEqual([[-180, -90, 180, 90]], collection.extent.spatial.bboxes) - collection.add_item(item2) - - collection.update_extent_from_items() - - self.assertEqual( - [ - [ - item2.common_metadata.start_datetime, - base_extent.temporal.intervals[0][1], - ] - ], - collection.extent.temporal.intervals, - ) - - def test_supplying_href_in_init_does_not_fail(self) -> None: - test_href = "http://example.com/collection.json" - spatial_extent = SpatialExtent(bboxes=[ARBITRARY_BBOX]) - temporal_extent = TemporalExtent(intervals=[[TEST_DATETIME, None]]) - - collection_extent = Extent(spatial=spatial_extent, temporal=temporal_extent) - collection = Collection( - id="test", description="test desc", extent=collection_extent, href=test_href - ) - - self.assertEqual(collection.get_self_href(), test_href) - - def test_collection_with_href_caches_by_href(self) -> None: - collection = pystac.Collection.from_file( - TestCases.get_path("data-files/examples/hand-0.8.1/collection.json") - ) - cache = collection._resolved_objects - - # Since all of our STAC objects have HREFs, everything should be - # cached only by HREF - self.assertEqual(len(cache.id_keys_to_objects), 0) - - @pytest.mark.block_network - def test_assets(self) -> None: - path = TestCases.get_path("data-files/collections/with-assets.json") - with open(path) as f: - data = json.load(f) - collection = pystac.Collection.from_dict(data) - collection.validate() - - def test_get_assets(self) -> None: - collection = pystac.Collection.from_file( - TestCases.get_path("data-files/collections/with-assets.json") - ) - - media_type_filter = collection.get_assets(media_type=pystac.MediaType.PNG) - self.assertCountEqual(media_type_filter.keys(), ["thumbnail"]) - role_filter = collection.get_assets(role="thumbnail") - self.assertCountEqual(role_filter.keys(), ["thumbnail"]) - multi_filter = collection.get_assets( - media_type=pystac.MediaType.PNG, role="thumbnail" - ) - self.assertCountEqual(multi_filter.keys(), ["thumbnail"]) - - no_filter = collection.get_assets() - self.assertIsNot(no_filter, collection.assets) - self.assertCountEqual(no_filter.keys(), ["thumbnail"]) - no_filter["thumbnail"].description = "foo" - assert collection.assets["thumbnail"].description != "foo" - - no_assets = collection.get_assets(media_type=pystac.MediaType.HDF) - self.assertEqual(no_assets, {}) - - def test_removing_optional_attributes(self) -> None: - path = TestCases.get_path("data-files/collections/with-assets.json") - with open(path) as file: - data = json.load(file) - data["title"] = "dummy title" - data["stac_extensions"] = ["dummy extension"] - data["keywords"] = ["key", "word"] - data["providers"] = [{"name": "pystac"}] - collection = pystac.Collection.from_dict(data) - - # Assert we have everything set - assert collection.title - assert collection.stac_extensions - assert collection.keywords - assert collection.providers - assert collection.summaries - assert collection.assets - - # Remove all of the optional stuff - collection.title = None - collection.stac_extensions = [] - collection.keywords = [] - collection.providers = [] - collection.summaries = pystac.Summaries({}) - collection.assets = {} - - collection_as_dict = collection.to_dict() - for key in ( - "title", - "stac_extensions", - "keywords", - "providers", - "summaries", - "assets", - ): - assert key not in collection_as_dict - - def test_from_dict_preserves_dict(self) -> None: - path = TestCases.get_path("data-files/collections/with-assets.json") - with open(path) as f: - collection_dict = json.load(f) - param_dict = deepcopy(collection_dict) - - # test that the parameter is preserved - _ = Collection.from_dict(param_dict) - self.assertEqual(param_dict, collection_dict) - - # assert that the parameter is not preserved with - # non-default parameter - _ = Collection.from_dict(param_dict, preserve_dict=False, migrate=False) - self.assertNotEqual(param_dict, collection_dict) - - def test_from_dict_set_root(self) -> None: - path = TestCases.get_path("data-files/examples/hand-0.8.1/collection.json") - with open(path) as f: - collection_dict = json.load(f) - catalog = pystac.Catalog(id="test", description="test desc") - collection = Collection.from_dict(collection_dict, root=catalog) - self.assertIs(collection.get_root(), catalog) - - def test_schema_summary(self) -> None: - collection = pystac.Collection.from_file( - TestCases.get_path( - "data-files/examples/1.0.0/collection-only/collection-with-schemas.json" - ) - ) - instruments_schema = get_required( - collection.summaries.get_schema("instruments"), - collection.summaries, - "instruments", - ) - - self.assertIsInstance(instruments_schema, dict) - - def test_from_invalid_dict_raises_exception(self) -> None: - stac_io = pystac.StacIO.default() - catalog_dict = stac_io.read_json( - TestCases.get_path("data-files/catalogs/test-case-1/catalog.json") - ) - with self.assertRaises(pystac.STACTypeError): - _ = pystac.Collection.from_dict(catalog_dict) - - def test_clone_preserves_assets(self) -> None: - path = TestCases.get_path("data-files/collections/with-assets.json") - original_collection = Collection.from_file(path) - assert len(original_collection.assets) > 0 - assert all( - asset.owner is original_collection - for asset in original_collection.assets.values() - ) - - cloned_collection = original_collection.clone() - - for key in original_collection.assets: - with self.subTest(f"Preserves {key} asset"): - self.assertIn(key, cloned_collection.assets) - cloned_asset = cloned_collection.assets.get(key) - if cloned_asset is not None: - with self.subTest(f"Sets owner for {key}"): - self.assertIs(cloned_asset.owner, cloned_collection) - - def test_to_dict_no_self_href(self) -> None: - temporal_extent = TemporalExtent(intervals=[[TEST_DATETIME, None]]) - spatial_extent = SpatialExtent(bboxes=ARBITRARY_BBOX) - extent = Extent(spatial=spatial_extent, temporal=temporal_extent) - collection = Collection( - id="an-id", description="A test Collection", extent=extent +def test_provider_to_from_dict() -> None: + provider_dict = { + "name": "Remote Data, Inc", + "description": "Producers of awesome spatiotemporal assets", + "roles": ["producer", "processor"], + "url": "http://remotedata.io", + "extension:field": "some value", + } + expected_extra_fields = {"extension:field": provider_dict["extension:field"]} + + provider = Provider.from_dict(provider_dict) + + assert ( + provider_dict["name"], + provider_dict["description"], + provider_dict["roles"], + provider_dict["url"], + expected_extra_fields, + provider_dict, + ) == ( + provider.name, + provider.description, + provider.roles, + provider.url, + provider.extra_fields, + provider.to_dict(), + ) + + +def test_spatial_extent_from_coordinates() -> None: + extent = SpatialExtent.from_coordinates(ARBITRARY_GEOM["coordinates"]) + + assert len(extent.bboxes) == 1 + bbox = extent.bboxes[0] + assert len(bbox) == 4 + for x in bbox: + assert isinstance(x, float) + + +def test_read_eo_items_are_heritable() -> None: + cat = TestCases.case_5() + item = next(cat.get_items(recursive=True)) + + assert EOExtension.has_extension(item) + + +def test_save_uses_previous_catalog_type() -> None: + collection = TestCases.case_8() + assert collection.STAC_OBJECT_TYPE == pystac.STACObjectType.COLLECTION + assert collection.catalog_type == CatalogType.SELF_CONTAINED + with tempfile.TemporaryDirectory() as tmp_dir: + collection.normalize_hrefs(tmp_dir) + href = collection.self_href + collection.save() + + collection2 = pystac.Collection.from_file(href) + assert collection2.catalog_type == CatalogType.SELF_CONTAINED + + +def test_clone_uses_previous_catalog_type() -> None: + catalog = TestCases.case_8() + assert catalog.catalog_type == CatalogType.SELF_CONTAINED + clone = catalog.clone() + assert clone.catalog_type == CatalogType.SELF_CONTAINED + + +def test_clone_cant_mutate_original() -> None: + collection = TestCases.case_8() + assert collection.keywords == ["disaster", "open"] + clone = collection.clone() + clone.extra_fields["test"] = "extra" + assert "test" not in collection.extra_fields + assert clone.keywords is not None + clone.keywords.append("clone") + assert clone.keywords == ["disaster", "open", "clone"] + assert collection.keywords == ["disaster", "open"] + assert id(collection.summaries) != id(clone.summaries) + + +def test_multiple_extents() -> None: + cat1 = TestCases.case_1() + country = cat1.get_child("country-1") + assert country is not None + col1 = country.get_child("area-1-1") + assert col1 is not None + col1.validate() + assert isinstance(col1, Collection) + validate_dict(col1.to_dict(), pystac.STACObjectType.COLLECTION) + + multi_ext_uri = TestCases.get_path("data-files/collections/multi-extent.json") + with open(multi_ext_uri) as f: + multi_ext_dict = json.load(f) + validate_dict(multi_ext_dict, pystac.STACObjectType.COLLECTION) + assert isinstance(Collection.from_dict(multi_ext_dict), Collection) + + multi_ext_col = Collection.from_file(multi_ext_uri) + multi_ext_col.validate() + ext = multi_ext_col.extent + extent_dict = multi_ext_dict["extent"] + assert isinstance(ext, Extent) + assert isinstance(ext.spatial.bboxes[0], list) + assert len(ext.spatial.bboxes) == 3 + assert ext.to_dict() == extent_dict + + cloned_ext = ext.clone() + assert cloned_ext.to_dict() == multi_ext_dict["extent"] + + +def test_extra_fields() -> None: + catalog = TestCases.case_2() + collection = catalog.get_child("1a8c1632-fa91-4a62-b33e-3a87c2ebdf16") + assert collection is not None + + collection.extra_fields["test"] = "extra" + + with tempfile.TemporaryDirectory() as tmp_dir: + p = os.path.join(tmp_dir, "collection.json") + collection.save_object(include_self_link=False, dest_href=p) + with open(p) as f: + col_json = json.load(f) + assert "test" in col_json + assert col_json["test"] == "extra" + + read_col = pystac.Collection.from_file(p) + assert "test" in read_col.extra_fields + assert read_col.extra_fields["test"] == "extra" + + +def test_update_extents() -> None: + catalog = TestCases.case_2() + base_collection = catalog.get_child("1a8c1632-fa91-4a62-b33e-3a87c2ebdf16") + assert isinstance(base_collection, Collection) + base_extent = base_collection.extent + collection = base_collection.clone() + + item1 = Item( + id="test-item-1", + geometry=ARBITRARY_GEOM, + bbox=[-180, -90, 180, 90], + datetime=TEST_DATETIME, + properties={"key": "one"}, + stac_extensions=["eo", "commons"], + ) + + item2 = Item( + id="test-item-1", + geometry=ARBITRARY_GEOM, + bbox=[-180, -90, 180, 90], + datetime=None, + properties={ + "start_datetime": datetime_to_str(datetime(2000, 1, 1, 12, 0, 0, 0)), + "end_datetime": datetime_to_str(datetime(2000, 2, 1, 12, 0, 0, 0)), + }, + stac_extensions=["eo", "commons"], + ) + + collection.add_item(item1) + + collection.update_extent_from_items() + assert [[-180, -90, 180, 90]] == collection.extent.spatial.bboxes + assert len(base_extent.spatial.bboxes[0]) == len( + collection.extent.spatial.bboxes[0] + ) + assert base_extent.temporal.intervals != collection.extent.temporal.intervals + + collection.remove_item("test-item-1") + collection.update_extent_from_items() + assert [[-180, -90, 180, 90]] != collection.extent.spatial.bboxes + collection.add_item(item2) + + collection.update_extent_from_items() + + assert [ + [ + item2.common_metadata.start_datetime, + base_extent.temporal.intervals[0][1], + ] + ] == collection.extent.temporal.intervals + + +def test_supplying_href_in_init_does_not_fail() -> None: + test_href = "http://example.com/collection.json" + spatial_extent = SpatialExtent(bboxes=[ARBITRARY_BBOX]) + temporal_extent = TemporalExtent(intervals=[[TEST_DATETIME, None]]) + + collection_extent = Extent(spatial=spatial_extent, temporal=temporal_extent) + collection = Collection( + id="test", description="test desc", extent=collection_extent, href=test_href + ) + + assert collection.get_self_href() == test_href + + +def test_collection_with_href_caches_by_href() -> None: + collection = pystac.Collection.from_file( + TestCases.get_path("data-files/examples/hand-0.8.1/collection.json") + ) + cache = collection._resolved_objects + + # Since all of our STAC objects have HREFs, everything should be + # cached only by HREF + assert len(cache.id_keys_to_objects) == 0 + + +@pytest.mark.block_network +def test_assets() -> None: + path = TestCases.get_path("data-files/collections/with-assets.json") + with open(path) as f: + data = json.load(f) + collection = pystac.Collection.from_dict(data) + collection.validate() + + +def test_get_assets() -> None: + collection = pystac.Collection.from_file( + TestCases.get_path("data-files/collections/with-assets.json") + ) + + media_type_filter = collection.get_assets(media_type=pystac.MediaType.PNG) + assert list(media_type_filter.keys()) == ["thumbnail"] + role_filter = collection.get_assets(role="thumbnail") + assert list(role_filter.keys()) == ["thumbnail"] + multi_filter = collection.get_assets( + media_type=pystac.MediaType.PNG, role="thumbnail" + ) + assert list(multi_filter.keys()) == ["thumbnail"] + + no_filter = collection.get_assets() + assert no_filter is not collection.assets + assert list(no_filter.keys()) == ["thumbnail"] + no_filter["thumbnail"].description = "foo" + assert collection.assets["thumbnail"].description != "foo" + + no_assets = collection.get_assets(media_type=pystac.MediaType.HDF) + assert no_assets == {} + + +def test_removing_optional_attributes() -> None: + path = TestCases.get_path("data-files/collections/with-assets.json") + with open(path) as file: + data = json.load(file) + data["title"] = "dummy title" + data["stac_extensions"] = ["dummy extension"] + data["keywords"] = ["key", "word"] + data["providers"] = [{"name": "pystac"}] + collection = pystac.Collection.from_dict(data) + + # Assert we have everything set + assert collection.title + assert collection.stac_extensions + assert collection.keywords + assert collection.providers + assert collection.summaries + assert collection.assets + + # Remove all of the optional stuff + collection.title = None + collection.stac_extensions = [] + collection.keywords = [] + collection.providers = [] + collection.summaries = pystac.Summaries({}) + collection.assets = {} + + collection_as_dict = collection.to_dict() + for key in ( + "title", + "stac_extensions", + "keywords", + "providers", + "summaries", + "assets", + ): + assert key not in collection_as_dict + + +def test_from_dict_preserves_dict() -> None: + path = TestCases.get_path("data-files/collections/with-assets.json") + with open(path) as f: + collection_dict = json.load(f) + param_dict = deepcopy(collection_dict) + + # test that the parameter is preserved + _ = Collection.from_dict(param_dict) + assert param_dict == collection_dict + + # assert that the parameter is not preserved with + # non-default parameter + _ = Collection.from_dict(param_dict, preserve_dict=False, migrate=False) + assert param_dict != collection_dict + + +def test_from_dict_set_root() -> None: + path = TestCases.get_path("data-files/examples/hand-0.8.1/collection.json") + with open(path) as f: + collection_dict = json.load(f) + catalog = pystac.Catalog(id="test", description="test desc") + collection = Collection.from_dict(collection_dict, root=catalog) + assert collection.get_root() is catalog + + +def test_schema_summary() -> None: + collection = pystac.Collection.from_file( + TestCases.get_path( + "data-files/examples/1.0.0/collection-only/collection-with-schemas.json" ) - d = collection.to_dict(include_self_link=False) - Collection.from_dict(d) - - -class ExtentTest(unittest.TestCase): - def setUp(self) -> None: - self.maxDiff = None - - def test_temporal_extent_init_typing(self) -> None: - # This test exists purely to test the typing of the intervals argument to - # TemporalExtent - start_datetime = str_to_datetime("2022-01-01T00:00:00Z") - end_datetime = str_to_datetime("2022-01-31T23:59:59Z") - - _ = TemporalExtent([[start_datetime, end_datetime]]) - - @pytest.mark.block_network() - def test_temporal_extent_allows_single_interval(self) -> None: - start_datetime = str_to_datetime("2022-01-01T00:00:00Z") - end_datetime = str_to_datetime("2022-01-31T23:59:59Z") - - interval = [start_datetime, end_datetime] - temporal_extent = TemporalExtent(intervals=interval) # type: ignore - - self.assertEqual(temporal_extent.intervals, [interval]) - - @pytest.mark.block_network() - def test_temporal_extent_allows_single_interval_open_start(self) -> None: - end_datetime = str_to_datetime("2022-01-31T23:59:59Z") - - interval = [None, end_datetime] - temporal_extent = TemporalExtent(intervals=interval) - - self.assertEqual(temporal_extent.intervals, [interval]) - - @pytest.mark.block_network() - def test_temporal_extent_non_list_intervals_fails(self) -> None: - with pytest.raises(TypeError): - # Pass in non-list intervals - _ = TemporalExtent(intervals=1) # type: ignore - - @pytest.mark.block_network() - def test_spatial_allows_single_bbox(self) -> None: - temporal_extent = TemporalExtent(intervals=[[TEST_DATETIME, None]]) - - # Pass in a single BBOX - spatial_extent = SpatialExtent(bboxes=ARBITRARY_BBOX) + ) + instruments_schema = get_required( + collection.summaries.get_schema("instruments"), + collection.summaries, + "instruments", + ) - collection_extent = Extent(spatial=spatial_extent, temporal=temporal_extent) + assert isinstance(instruments_schema, dict) - collection = Collection( - id="test", description="test desc", extent=collection_extent - ) - # HREF required by validation - collection.set_self_href("https://example.com/collection.json") +def test_from_invalid_dict_raises_exception() -> None: + stac_io = pystac.StacIO.default() + catalog_dict = stac_io.read_json( + TestCases.get_path("data-files/catalogs/test-case-1/catalog.json") + ) + with pytest.raises(pystac.STACTypeError): + _ = pystac.Collection.from_dict(catalog_dict) - collection.validate() - @pytest.mark.block_network() - def test_spatial_extent_non_list_bboxes_fails(self) -> None: - with pytest.raises(TypeError): - # Pass in non-list bboxes - _ = SpatialExtent(bboxes=1) # type: ignore +def test_clone_preserves_assets() -> None: + path = TestCases.get_path("data-files/collections/with-assets.json") + original_collection = Collection.from_file(path) + assert len(original_collection.assets) > 0 + assert all( + asset.owner is original_collection + for asset in original_collection.assets.values() + ) - def test_from_items(self) -> None: - item1 = Item( - id="test-item-1", - geometry=ARBITRARY_GEOM, - bbox=[-10, -20, 0, -10], - datetime=datetime(2000, 2, 1, 12, 0, 0, 0, tzinfo=tz.UTC), - properties={}, - ) + cloned_collection = original_collection.clone() - item2 = Item( - id="test-item-2", - geometry=ARBITRARY_GEOM, - bbox=[0, -9, 10, 1], - datetime=None, - properties={ - "start_datetime": datetime_to_str( - datetime(2000, 1, 1, 12, 0, 0, 0, tzinfo=tz.UTC) - ), - "end_datetime": datetime_to_str( - datetime(2000, 7, 1, 12, 0, 0, 0, tzinfo=tz.UTC) - ), - }, - ) + for key in original_collection.assets: + assert key in cloned_collection.assets, f"Failed to Preserve {key} asset" + cloned_asset = cloned_collection.assets.get(key) + if cloned_asset is not None: + assert ( + cloned_asset.owner is cloned_collection + ), f"Failed to set owner for {key}" - item3 = Item( - id="test-item-2", - geometry=ARBITRARY_GEOM, - bbox=[-5, -20, 5, 0], - datetime=None, - properties={ - "start_datetime": datetime_to_str( - datetime(2000, 12, 1, 12, 0, 0, 0, tzinfo=tz.UTC) - ), - "end_datetime": datetime_to_str( - datetime(2001, 1, 1, 12, 0, 0, 0, tzinfo=tz.UTC), - ), - }, - ) - extent = Extent.from_items([item1, item2, item3]) - - self.assertEqual(len(extent.spatial.bboxes), 1) - self.assertEqual(extent.spatial.bboxes[0], [-10, -20, 10, 1]) - - self.assertEqual(len(extent.temporal.intervals), 1) - interval = extent.temporal.intervals[0] - - self.assertEqual(interval[0], datetime(2000, 1, 1, 12, 0, 0, 0, tzinfo=tz.UTC)) - self.assertEqual(interval[1], datetime(2001, 1, 1, 12, 0, 0, 0, tzinfo=tz.UTC)) - - def test_to_from_dict(self) -> None: - spatial_dict = { - "bbox": [ - [ - 172.91173669923782, - 1.3438851951615003, - 172.95469614953714, - 1.3690476620161975, - ] - ], - "extension:field": "spatial value", - } - temporal_dict = { - "interval": [ - ["2020-12-11T22:38:32.125000Z", "2020-12-14T18:02:31.437000Z"] - ], - "extension:field": "temporal value", - } - extent_dict = { - "spatial": spatial_dict, - "temporal": temporal_dict, - "extension:field": "extent value", - } - expected_extent_extra_fields = { - "extension:field": extent_dict["extension:field"], - } - expected_spatial_extra_fields = { - "extension:field": spatial_dict["extension:field"], - } - expected_temporal_extra_fields = { - "extension:field": temporal_dict["extension:field"], - } - - extent = Extent.from_dict(extent_dict) - - self.assertDictEqual(expected_extent_extra_fields, extent.extra_fields) - self.assertDictEqual(expected_spatial_extra_fields, extent.spatial.extra_fields) - self.assertDictEqual( - expected_temporal_extra_fields, extent.temporal.extra_fields - ) +def test_to_dict_no_self_href() -> None: + temporal_extent = TemporalExtent(intervals=[[TEST_DATETIME, None]]) + spatial_extent = SpatialExtent(bboxes=ARBITRARY_BBOX) + extent = Extent(spatial=spatial_extent, temporal=temporal_extent) + collection = Collection(id="an-id", description="A test Collection", extent=extent) + d = collection.to_dict(include_self_link=False) + Collection.from_dict(d) + + +def test_temporal_extent_init_typing() -> None: + # This test exists purely to test the typing of the intervals argument to + # TemporalExtent + start_datetime = str_to_datetime("2022-01-01T00:00:00Z") + end_datetime = str_to_datetime("2022-01-31T23:59:59Z") + + _ = TemporalExtent([[start_datetime, end_datetime]]) + + +@pytest.mark.block_network() +def test_temporal_extent_allows_single_interval() -> None: + start_datetime = str_to_datetime("2022-01-01T00:00:00Z") + end_datetime = str_to_datetime("2022-01-31T23:59:59Z") + + interval = [start_datetime, end_datetime] + temporal_extent = TemporalExtent(intervals=interval) # type: ignore + + assert temporal_extent.intervals == [interval] + + +@pytest.mark.block_network() +def test_temporal_extent_allows_single_interval_open_start() -> None: + end_datetime = str_to_datetime("2022-01-31T23:59:59Z") + + interval = [None, end_datetime] + temporal_extent = TemporalExtent(intervals=interval) + + assert temporal_extent.intervals == [interval] + + +@pytest.mark.block_network() +def test_temporal_extent_non_list_intervals_fails() -> None: + with pytest.raises(TypeError): + # Pass in non-list intervals + _ = TemporalExtent(intervals=1) # type: ignore + + +@pytest.mark.block_network() +def test_spatial_allows_single_bbox() -> None: + temporal_extent = TemporalExtent(intervals=[[TEST_DATETIME, None]]) + + # Pass in a single BBOX + spatial_extent = SpatialExtent(bboxes=ARBITRARY_BBOX) + + collection_extent = Extent(spatial=spatial_extent, temporal=temporal_extent) + + collection = Collection( + id="test", description="test desc", extent=collection_extent + ) + + # HREF required by validation + collection.set_self_href("https://example.com/collection.json") + + collection.validate() + + +@pytest.mark.block_network() +def test_spatial_extent_non_list_bboxes_fails() -> None: + with pytest.raises(TypeError): + # Pass in non-list bboxes + _ = SpatialExtent(bboxes=1) # type: ignore + + +def test_extent_from_items() -> None: + item1 = Item( + id="test-item-1", + geometry=ARBITRARY_GEOM, + bbox=[-10, -20, 0, -10], + datetime=datetime(2000, 2, 1, 12, 0, 0, 0, tzinfo=tz.UTC), + properties={}, + ) + + item2 = Item( + id="test-item-2", + geometry=ARBITRARY_GEOM, + bbox=[0, -9, 10, 1], + datetime=None, + properties={ + "start_datetime": datetime_to_str( + datetime(2000, 1, 1, 12, 0, 0, 0, tzinfo=tz.UTC) + ), + "end_datetime": datetime_to_str( + datetime(2000, 7, 1, 12, 0, 0, 0, tzinfo=tz.UTC) + ), + }, + ) - self.assertDictEqual(extent_dict, extent.to_dict()) + item3 = Item( + id="test-item-2", + geometry=ARBITRARY_GEOM, + bbox=[-5, -20, 5, 0], + datetime=None, + properties={ + "start_datetime": datetime_to_str( + datetime(2000, 12, 1, 12, 0, 0, 0, tzinfo=tz.UTC) + ), + "end_datetime": datetime_to_str( + datetime(2001, 1, 1, 12, 0, 0, 0, tzinfo=tz.UTC), + ), + }, + ) + extent = Extent.from_items([item1, item2, item3]) + assert len(extent.spatial.bboxes) == 1 + assert extent.spatial.bboxes[0] == [-10, -20, 10, 1] + assert len(extent.temporal.intervals) == 1 -class CollectionSubClassTest(unittest.TestCase): + interval = extent.temporal.intervals[0] + assert interval[0] == datetime(2000, 1, 1, 12, 0, 0, 0, tzinfo=tz.UTC) + assert interval[1] == datetime(2001, 1, 1, 12, 0, 0, 0, tzinfo=tz.UTC) + + +def test_extent_to_from_dict() -> None: + spatial_dict = { + "bbox": [ + [ + 172.91173669923782, + 1.3438851951615003, + 172.95469614953714, + 1.3690476620161975, + ] + ], + "extension:field": "spatial value", + } + temporal_dict = { + "interval": [["2020-12-11T22:38:32.125000Z", "2020-12-14T18:02:31.437000Z"]], + "extension:field": "temporal value", + } + extent_dict = { + "spatial": spatial_dict, + "temporal": temporal_dict, + "extension:field": "extent value", + } + expected_extent_extra_fields = { + "extension:field": extent_dict["extension:field"], + } + expected_spatial_extra_fields = { + "extension:field": spatial_dict["extension:field"], + } + expected_temporal_extra_fields = { + "extension:field": temporal_dict["extension:field"], + } + + extent = Extent.from_dict(extent_dict) + + assert expected_extent_extra_fields == extent.extra_fields + assert expected_spatial_extra_fields == extent.spatial.extra_fields + assert expected_temporal_extra_fields == extent.temporal.extra_fields + + assert extent_dict == extent.to_dict() + + +class TestCollectionSubClass: """This tests cases related to creating classes inheriting from pystac.Catalog to ensure that inheritance, class methods, etc. function as expected.""" @@ -537,25 +549,23 @@ def get_items(self) -> Iterator[Item]: # type: ignore # backwards compatibility of inherited classes return super().get_items() - def setUp(self) -> None: - self.stac_io = pystac.StacIO.default() - def test_from_dict_returns_subclass(self) -> None: - collection_dict = self.stac_io.read_json(self.MULTI_EXTENT) + stac_io = pystac.StacIO.default() + collection_dict = stac_io.read_json(self.MULTI_EXTENT) custom_collection = self.BasicCustomCollection.from_dict(collection_dict) - self.assertIsInstance(custom_collection, self.BasicCustomCollection) + assert isinstance(custom_collection, self.BasicCustomCollection) def test_from_file_returns_subclass(self) -> None: custom_collection = self.BasicCustomCollection.from_file(self.MULTI_EXTENT) - self.assertIsInstance(custom_collection, self.BasicCustomCollection) + assert isinstance(custom_collection, self.BasicCustomCollection) def test_clone(self) -> None: custom_collection = self.BasicCustomCollection.from_file(self.MULTI_EXTENT) cloned_collection = custom_collection.clone() - self.assertIsInstance(cloned_collection, self.BasicCustomCollection) + assert isinstance(cloned_collection, self.BasicCustomCollection) def test_collection_get_item_works(self) -> None: path = TestCases.get_path( @@ -567,7 +577,7 @@ def test_collection_get_item_works(self) -> None: collection.get_item("area-1-1-imagery") -class CollectionPartialSubClassTest(unittest.TestCase): +def test_collection_get_item_raises_type_error() -> None: class BasicCustomCollection(pystac.Collection): def get_items( # type: ignore self, *, recursive: bool = False @@ -575,14 +585,13 @@ def get_items( # type: ignore # This get_items does not allow ids as args. return super().get_items(recursive=recursive) - def test_collection_get_item_raises_type_error(self) -> None: - path = TestCases.get_path( - "data-files/catalogs/test-case-1/country-1/area-1-1/collection.json" - ) - custom_collection = self.BasicCustomCollection.from_file(path) - collection = custom_collection.clone() - with pytest.raises(TypeError, match="takes 1 positional argument"): - collection.get_item("area-1-1-imagery") + path = TestCases.get_path( + "data-files/catalogs/test-case-1/country-1/area-1-1/collection.json" + ) + custom_collection = BasicCustomCollection.from_file(path) + collection = custom_collection.clone() + with pytest.raises(TypeError, match="takes 1 positional argument"): + collection.get_item("area-1-1-imagery") def test_custom_collection_from_dict(collection: Collection) -> None: diff --git a/tests/test_item_assets.py b/tests/test_item_assets.py index 0e5ac895f..167f1bd26 100644 --- a/tests/test_item_assets.py +++ b/tests/test_item_assets.py @@ -1,5 +1,3 @@ -import unittest - import pytest from pystac import Collection @@ -13,43 +11,18 @@ ) -class TestItemAssets(unittest.TestCase): - def setUp(self) -> None: - self.maxDiff = None - self.collection = Collection.from_file( - TestCases.get_path("data-files/item-assets/example-landsat8.json") - ) - - def test_example(self) -> None: - collection = self.collection.clone() - - self.assertEqual(len(collection.item_assets), 13) - - self.assertEqual( - collection.item_assets["B1"], - ItemAssetDefinition( - { - "type": "image/tiff; application=geotiff", - "eo:bands": [ - { - "name": "B1", - "common_name": "coastal", - "center_wavelength": 0.44, - "full_width_half_max": 0.02, - } - ], - "title": "Coastal Band (B1)", - "description": "Coastal Band Top Of the Atmosphere", - } - ), - ) +@pytest.fixture +def landsat8_collection() -> Collection: + return Collection.from_file( + TestCases.get_path("data-files/item-assets/example-landsat8.json") + ) - def test_set_using_dict(self) -> None: - collection = self.collection.clone() - self.assertEqual(len(collection.item_assets), 13) +def test_example(landsat8_collection: Collection) -> None: + assert len(landsat8_collection.item_assets) == 13 - collection.item_assets["Bx"] = { + assert landsat8_collection.item_assets["B1"] == ItemAssetDefinition( + { "type": "image/tiff; application=geotiff", "eo:bands": [ { @@ -61,20 +34,35 @@ def test_set_using_dict(self) -> None: ], "title": "Coastal Band (B1)", "description": "Coastal Band Top Of the Atmosphere", - } # type:ignore + } + ) - self.assertEqual(collection.item_assets["B1"], collection.item_assets["Bx"]) +def test_set_using_dict(landsat8_collection: Collection) -> None: + assert len(landsat8_collection.item_assets) == 13 -class TestAssetDefinition(unittest.TestCase): - def setUp(self) -> None: - self.maxDiff = None - self.collection = Collection.from_file( - TestCases.get_path("data-files/item-assets/example-landsat8.json") - ) + landsat8_collection.item_assets["Bx"] = { + "type": "image/tiff; application=geotiff", + "eo:bands": [ + { + "name": "B1", + "common_name": "coastal", + "center_wavelength": 0.44, + "full_width_half_max": 0.02, + } + ], + "title": "Coastal Band (B1)", + "description": "Coastal Band Top Of the Atmosphere", + } # type:ignore + + assert ( + landsat8_collection.item_assets["B1"] == landsat8_collection.item_assets["Bx"] + ) - def test_eq(self) -> None: - assert self.collection.item_assets["B1"] != {"title": "Coastal Band (B1)"} + +class TestAssetDefinition: + def test_eq(self, landsat8_collection: Collection) -> None: + assert landsat8_collection.item_assets["B1"] != {"title": "Coastal Band (B1)"} def test_create(self) -> None: title = "Coastal Band (B1)" @@ -84,10 +72,12 @@ def test_create(self) -> None: asset_defn = ItemAssetDefinition.create( title=title, description=description, media_type=media_type, roles=roles ) - self.assertEqual(asset_defn.title, title) - self.assertEqual(asset_defn.description, description) - self.assertEqual(asset_defn.media_type, media_type) - self.assertEqual(asset_defn.roles, roles) + assert ( + asset_defn.title, + asset_defn.description, + asset_defn.media_type, + asset_defn.roles, + ) == (title, description, media_type, roles) def test_title(self) -> None: asset_defn = ItemAssetDefinition({}) @@ -95,8 +85,7 @@ def test_title(self) -> None: asset_defn.title = title - self.assertEqual(asset_defn.title, title) - self.assertEqual(asset_defn.to_dict()["title"], title) + assert asset_defn.title == asset_defn.to_dict()["title"] == title def test_description(self) -> None: asset_defn = ItemAssetDefinition({}) @@ -104,8 +93,9 @@ def test_description(self) -> None: asset_defn.description = description - self.assertEqual(asset_defn.description, description) - self.assertEqual(asset_defn.to_dict()["description"], description) + assert ( + asset_defn.description == asset_defn.to_dict()["description"] == description + ) def test_media_type(self) -> None: asset_defn = ItemAssetDefinition({}) @@ -113,8 +103,7 @@ def test_media_type(self) -> None: asset_defn.media_type = media_type - self.assertEqual(asset_defn.media_type, media_type) - self.assertEqual(asset_defn.to_dict()["type"], media_type) + assert asset_defn.media_type == asset_defn.to_dict()["type"] == media_type def test_roles(self) -> None: asset_defn = ItemAssetDefinition({}) @@ -122,10 +111,9 @@ def test_roles(self) -> None: asset_defn.roles = roles - self.assertEqual(asset_defn.roles, roles) - self.assertEqual(asset_defn.to_dict()["roles"], roles) + assert asset_defn.roles == asset_defn.to_dict()["roles"] == roles - def test_set_owner(self) -> None: + def test_set_owner(self, landsat8_collection: Collection) -> None: asset_definition = ItemAssetDefinition( { "type": "image/tiff; application=geotiff", @@ -141,8 +129,8 @@ def test_set_owner(self) -> None: "description": "Coastal Band Top Of the Atmosphere", } ) - asset_definition.set_owner(self.collection) - assert asset_definition.owner == self.collection + asset_definition.set_owner(landsat8_collection) + assert asset_definition.owner == landsat8_collection def test_extra_fields(collection: Collection) -> None: diff --git a/tests/test_item_collection.py b/tests/test_item_collection.py index 71dab23b1..56fc96fbb 100644 --- a/tests/test_item_collection.py +++ b/tests/test_item_collection.py @@ -1,186 +1,208 @@ import json -import unittest from copy import deepcopy from os.path import relpath +from typing import Any, cast + +import pytest import pystac +from pystac import Item, StacIO from pystac.item_collection import ItemCollection from tests.utils import TestCases from tests.utils.stac_io_mock import MockDefaultStacIO +SIMPLE_ITEM = TestCases.get_path("data-files/examples/1.0.0-RC1/simple-item.json") +CORE_ITEM = TestCases.get_path("data-files/examples/1.0.0-RC1/core-item.json") +EXTENDED_ITEM = TestCases.get_path("data-files/examples/1.0.0-RC1/extended-item.json") + +ITEM_COLLECTION = TestCases.get_path( + "data-files/item-collection/sample-item-collection.json" +) + + +@pytest.fixture +def item_collection_dict() -> dict[str, Any]: + with open(ITEM_COLLECTION) as src: + return cast(dict[str, Any], json.load(src)) + + +@pytest.fixture +def items(item_collection_dict: dict[str, Any]) -> list[Item]: + return [Item.from_dict(f) for f in item_collection_dict["features"]] + + +@pytest.fixture +def stac_io() -> StacIO: + return StacIO.default() + + +def test_item_collection_length( + item_collection_dict: dict[str, Any], items: list[Item] +) -> None: + item_collection = pystac.ItemCollection(items=items) + + assert len(item_collection) == len(items) + + +def test_item_collection_iter(items: list[Item]) -> None: + expected_ids = [item.id for item in items] + item_collection = pystac.ItemCollection(items=items) + + actual_ids = [item.id for item in item_collection] + + assert expected_ids == actual_ids + + +def test_item_collection_get_item_by_index(items: list[Item]) -> None: + expected_id = items[0].id + item_collection = pystac.ItemCollection(items=items) + + assert item_collection[0].id == expected_id + + +def test_item_collection_contains() -> None: + item = pystac.Item.from_file(SIMPLE_ITEM) + item_collection = pystac.ItemCollection(items=[item], clone_items=False) -class TestItemCollection(unittest.TestCase): - SIMPLE_ITEM = TestCases.get_path("data-files/examples/1.0.0-RC1/simple-item.json") - CORE_ITEM = TestCases.get_path("data-files/examples/1.0.0-RC1/core-item.json") - EXTENDED_ITEM = TestCases.get_path( - "data-files/examples/1.0.0-RC1/extended-item.json" + assert item in item_collection + + +def test_item_collection_extra_fields(items: list[Item]) -> None: + item_collection = pystac.ItemCollection( + items=items, extra_fields={"custom_field": "My value"} ) - ITEM_COLLECTION = TestCases.get_path( - "data-files/item-collection/sample-item-collection.json" + assert item_collection.extra_fields.get("custom_field") == "My value" + + +def test_item_collection_to_dict(items: list[Item]) -> None: + item_collection = pystac.ItemCollection( + items=items, extra_fields={"custom_field": "My value"} ) - def setUp(self) -> None: - self.maxDiff = None - with open(self.ITEM_COLLECTION) as src: - self.item_collection_dict = json.load(src) - self.items = [ - pystac.Item.from_dict(f) for f in self.item_collection_dict["features"] - ] - self.stac_io = pystac.StacIO.default() + d = item_collection.to_dict() - def test_item_collection_length(self) -> None: - item_collection = pystac.ItemCollection(items=self.items) + assert len(d["features"]) == len(items) + assert d.get("custom_field") == "My value" - self.assertEqual(len(item_collection), len(self.items)) - def test_item_collection_iter(self) -> None: - expected_ids = [item.id for item in self.items] - item_collection = pystac.ItemCollection(items=self.items) +def test_item_collection_from_dict(items: list[Item]) -> None: + features = [item.to_dict(transform_hrefs=False) for item in items] + d = { + "type": "FeatureCollection", + "features": features, + "custom_field": "My value", + } + item_collection = pystac.ItemCollection.from_dict(d) + expected = len(features) + assert expected == len(item_collection.items) + assert item_collection.extra_fields.get("custom_field") == "My value" - actual_ids = [item.id for item in item_collection] - self.assertListEqual(expected_ids, actual_ids) +def test_clone_item_collection() -> None: + item_collection_1 = pystac.ItemCollection.from_file(ITEM_COLLECTION) + item_collection_2 = item_collection_1.clone() - def test_item_collection_get_item_by_index(self) -> None: - expected_id = self.items[0].id - item_collection = pystac.ItemCollection(items=self.items) + item_ids_1 = [item.id for item in item_collection_1] + item_ids_2 = [item.id for item in item_collection_2] - self.assertEqual(item_collection[0].id, expected_id) + # All items from the original collection should be in the clone... + assert item_ids_1 == item_ids_2 + # ... but they should not be the same objects + assert item_collection_1[0] is not item_collection_2[0] - def test_item_collection_contains(self) -> None: - item = pystac.Item.from_file(self.SIMPLE_ITEM) - item_collection = pystac.ItemCollection(items=[item], clone_items=False) - self.assertIn(item, item_collection) +def test_raise_error_for_invalid_object(stac_io: StacIO) -> None: + item_dict = stac_io.read_json(SIMPLE_ITEM) - def test_item_collection_extra_fields(self) -> None: - item_collection = pystac.ItemCollection( - items=self.items, extra_fields={"custom_field": "My value"} - ) + with pytest.raises(pystac.STACTypeError): + _ = pystac.ItemCollection.from_dict(item_dict) - self.assertEqual(item_collection.extra_fields.get("custom_field"), "My value") - def test_item_collection_to_dict(self) -> None: - item_collection = pystac.ItemCollection( - items=self.items, extra_fields={"custom_field": "My value"} +def test_from_relative_path() -> None: + _ = pystac.ItemCollection.from_file( + relpath( + TestCases.get_path("data-files/item-collection/sample-item-collection.json") ) + ) - d = item_collection.to_dict() - - self.assertEqual(len(d["features"]), len(self.items)) - self.assertEqual(d.get("custom_field"), "My value") - - def test_item_collection_from_dict(self) -> None: - features = [item.to_dict(transform_hrefs=False) for item in self.items] - d = { - "type": "FeatureCollection", - "features": features, - "custom_field": "My value", - } - item_collection = pystac.ItemCollection.from_dict(d) - expected = len(features) - self.assertEqual(expected, len(item_collection.items)) - self.assertEqual(item_collection.extra_fields.get("custom_field"), "My value") - - def test_clone_item_collection(self) -> None: - item_collection_1 = pystac.ItemCollection.from_file(self.ITEM_COLLECTION) - item_collection_2 = item_collection_1.clone() - - item_ids_1 = [item.id for item in item_collection_1] - item_ids_2 = [item.id for item in item_collection_2] - - # All items from the original collection should be in the clone... - self.assertListEqual(item_ids_1, item_ids_2) - # ... but they should not be the same objects - self.assertIsNot(item_collection_1[0], item_collection_2[0]) - - def test_raise_error_for_invalid_object(self) -> None: - item_dict = self.stac_io.read_json(self.SIMPLE_ITEM) - - with self.assertRaises(pystac.STACTypeError): - _ = pystac.ItemCollection.from_dict(item_dict) - - def test_from_relative_path(self) -> None: - _ = pystac.ItemCollection.from_file( - relpath( - TestCases.get_path( - "data-files/item-collection/sample-item-collection.json" - ) - ) - ) - def test_from_list_of_dicts(self) -> None: - item_dict = self.stac_io.read_json(self.SIMPLE_ITEM) - item_collection = pystac.ItemCollection(items=[item_dict], clone_items=True) +def test_from_list_of_dicts(stac_io: StacIO) -> None: + item_dict = stac_io.read_json(SIMPLE_ITEM) + item_collection = pystac.ItemCollection(items=[item_dict], clone_items=True) - self.assertEqual(item_collection[0].id, item_dict.get("id")) + assert item_collection[0].id == item_dict.get("id") - def test_add_item_collections(self) -> None: - item_1 = pystac.Item.from_file(self.SIMPLE_ITEM) - item_2 = pystac.Item.from_file(self.EXTENDED_ITEM) - item_3 = pystac.Item.from_file(self.CORE_ITEM) - item_collection_1 = pystac.ItemCollection(items=[item_1, item_2]) - item_collection_2 = pystac.ItemCollection(items=[item_2, item_3]) +def test_add_item_collections() -> None: + item_1 = pystac.Item.from_file(SIMPLE_ITEM) + item_2 = pystac.Item.from_file(EXTENDED_ITEM) + item_3 = pystac.Item.from_file(CORE_ITEM) - combined = item_collection_1 + item_collection_2 + item_collection_1 = pystac.ItemCollection(items=[item_1, item_2]) + item_collection_2 = pystac.ItemCollection(items=[item_2, item_3]) - self.assertEqual(len(combined), 4) + combined = item_collection_1 + item_collection_2 - def test_add_other_raises_error(self) -> None: - item_collection = pystac.ItemCollection.from_file(self.ITEM_COLLECTION) + assert len(combined) == 4 - with self.assertRaises(TypeError): - _ = item_collection + 2 - def test_identify_0_8_itemcollection_type(self) -> None: - itemcollection_path = TestCases.get_path( - "data-files/examples/0.8.1/item-spec/" - "examples/itemcollection-sample-full.json" - ) - itemcollection_dict = pystac.StacIO.default().read_json(itemcollection_path) +def test_add_other_raises_error() -> None: + item_collection = pystac.ItemCollection.from_file(ITEM_COLLECTION) - self.assertTrue( - pystac.ItemCollection.is_item_collection(itemcollection_dict), - msg="Did not correctly identify valid STAC 0.8 ItemCollection.", - ) + with pytest.raises(TypeError): + _ = item_collection + 2 - def test_identify_0_9_itemcollection(self) -> None: - itemcollection_path = TestCases.get_path( - "data-files/examples/0.9.0/item-spec/" - "examples/itemcollection-sample-full.json" - ) - itemcollection_dict = pystac.StacIO.default().read_json(itemcollection_path) - self.assertTrue( - pystac.ItemCollection.is_item_collection(itemcollection_dict), - msg="Did not correctly identify valid STAC 0.9 ItemCollection.", - ) +def test_identify_0_8_itemcollection_type(stac_io: StacIO) -> None: + itemcollection_path = TestCases.get_path( + "data-files/examples/0.8.1/item-spec/" + "examples/itemcollection-sample-full.json" + ) + itemcollection_dict = stac_io.read_json(itemcollection_path) + + assert pystac.ItemCollection.is_item_collection( + itemcollection_dict + ), "Did not correctly identify valid STAC 0.8 ItemCollection." + + +def test_identify_0_9_itemcollection(stac_io: StacIO) -> None: + itemcollection_path = TestCases.get_path( + "data-files/examples/0.9.0/item-spec/" + "examples/itemcollection-sample-full.json" + ) + itemcollection_dict = stac_io.read_json(itemcollection_path) + + assert pystac.ItemCollection.is_item_collection( + itemcollection_dict + ), "Did not correctly identify valid STAC 0.9 ItemCollection." + + +def test_from_dict_preserves_dict(item_collection_dict: dict[str, Any]) -> None: + param_dict = deepcopy(item_collection_dict) + + # test that the parameter is preserved + _ = ItemCollection.from_dict(param_dict) + assert param_dict == item_collection_dict - def test_from_dict_preserves_dict(self) -> None: - param_dict = deepcopy(self.item_collection_dict) + # assert that the parameter is preserved regardless of + # preserve_dict + _ = ItemCollection.from_dict(param_dict, preserve_dict=False) + assert param_dict == item_collection_dict - # test that the parameter is preserved - _ = ItemCollection.from_dict(param_dict) - self.assertEqual(param_dict, self.item_collection_dict) - # assert that the parameter is preserved regardless of - # preserve_dict - _ = ItemCollection.from_dict(param_dict, preserve_dict=False) - self.assertEqual(param_dict, self.item_collection_dict) +def test_from_dict_sets_root(item_collection_dict: dict[str, Any]) -> None: + param_dict = deepcopy(item_collection_dict) + catalog = pystac.Catalog(id="test", description="test desc") + item_collection = ItemCollection.from_dict(param_dict, root=catalog) + for item in item_collection.items: + assert item.get_root() == catalog - def test_from_dict_sets_root(self) -> None: - param_dict = deepcopy(self.item_collection_dict) - catalog = pystac.Catalog(id="test", description="test desc") - item_collection = ItemCollection.from_dict(param_dict, root=catalog) - for item in item_collection.items: - self.assertEqual(item.get_root(), catalog) - def test_to_dict_does_not_read_root_link_of_items(self) -> None: - with MockDefaultStacIO() as mock_stac_io: - item_collection = pystac.ItemCollection.from_file(self.ITEM_COLLECTION) +def test_to_dict_does_not_read_root_link_of_items() -> None: + with MockDefaultStacIO() as mock_stac_io: + item_collection = pystac.ItemCollection.from_file(ITEM_COLLECTION) - item_collection.to_dict() + item_collection.to_dict() - self.assertEqual(mock_stac_io.mock.read_text.call_count, 1) + assert mock_stac_io.mock.read_text.call_count == 1