diff --git a/test/unit/extract/csv_test.py b/test/unit/extract/csv_test.py index 8f47bf653c..cb328851bb 100644 --- a/test/unit/extract/csv_test.py +++ b/test/unit/extract/csv_test.py @@ -1,7 +1,5 @@
 """Unit tests for pudl.extract.csv module."""
 
-from unittest.mock import MagicMock, patch
-
 import pandas as pd
 import pytest
 
@@ -16,16 +14,16 @@
 class FakeExtractor(CsvExtractor):
-    def __init__(self):
+    def __init__(self, mocker):
         # TODO: Make these tests independent of the eia176 implementation
         self.METADATA = GenericMetadata("eia176")
-        super().__init__(ds=MagicMock())
+        super().__init__(ds=mocker.MagicMock())
 
 
 @pytest.fixture
-def extractor():
-    # Create an instance of the CsvExtractor class
-    return FakeExtractor()
+def extractor(mocker):
+    # Create an instance of the CsvExtractor class with mocker
+    return FakeExtractor(mocker)
 
 
 def test_source_filename_valid_partition(extractor):
@@ -45,8 +43,8 @@ def test_source_filename_multiple_selections(extractor):
         extractor.source_filename(PAGE, **multiple_selections)
 
 
-@patch("pudl.extract.csv.pd")
-def test_load_source(mock_pd, extractor):
+def test_load_source(mocker, extractor):
+    mock_pd = mocker.patch("pudl.extract.csv.pd")
     assert extractor.load_source(PAGE, **PARTITION) == mock_pd.read_csv.return_value
     extractor.ds.get_zipfile_resource.assert_called_once_with(DATASET, **PARTITION)
     zipfile = extractor.ds.get_zipfile_resource.return_value.__enter__.return_value
@@ -55,7 +53,7 @@
     mock_pd.read_csv.assert_called_once_with(file)
 
 
-def test_extract(extractor):
+def test_extract(mocker, extractor):
     # Create a sample of data we could expect from an EIA CSV
     company_field = "company"
     company_data = "Total of All Companies"
@@ -64,22 +62,20 @@
 
     # TODO: Once FakeExtractor is independent of eia176, mock out populating _column_map for PARTITION_SELECTION;
     # Also include negative tests, i.e., for partition selections not in the _column_map
-    with (
-        patch.object(CsvExtractor, "load_source", return_value=df),
-        patch.object(
-            # Testing the rename
-            GenericMetadata,
-            "get_column_map",
-            return_value={"company_rename": company_field},
-        ),
-        patch.object(
-            # Transposing the df here to get the orientation we expect get_page_cols to return
-            CsvExtractor,
-            "get_page_cols",
-            return_value=df.T.index,
-        ),
-    ):
-        res = extractor.extract(**PARTITION)
+    mocker.patch.object(CsvExtractor, "load_source", return_value=df)
+    # Testing the rename
+    mocker.patch.object(
+        GenericMetadata,
+        "get_column_map",
+        return_value={"company_rename": company_field},
+    )
+    # Transposing the df here to get the orientation we expect get_page_cols to return
+    mocker.patch.object(
+        CsvExtractor,
+        "get_page_cols",
+        return_value=df.T.index,
+    )
+    res = extractor.extract(**PARTITION)
     assert len(res) == 1  # Assert only one page extracted
     assert list(res.keys()) == [PAGE]  # Assert it is named correctly
     assert (
@@ -87,11 +83,9 @@
     )  # Assert that column correctly renamed and data is there.
 
 
-@patch.object(FakeExtractor, "METADATA")
-def test_validate_exact_columns(mock_metadata, extractor):
+def test_validate_exact_columns(mocker, extractor):
     # Mock the partition selection and page columns
-    # mock_metadata._get_partition_selection.return_value = "partition1"
-    extractor.get_page_cols = MagicMock(return_value={"col1", "col2"})
+    extractor.get_page_cols = mocker.MagicMock(return_value={"col1", "col2"})
 
     # Create a DataFrame with the exact expected columns
     df = pd.DataFrame(columns=["col1", "col2"])
@@ -100,11 +94,9 @@
     extractor.validate(df, "page1", partition="partition1")
 
 
-@patch.object(FakeExtractor, "METADATA")
-def test_validate_extra_columns(mock_metadata, extractor):
+def test_validate_extra_columns(mocker, extractor):
     # Mock the partition selection and page columns
-    mock_metadata._get_partition_selection.return_value = "partition1"
-    extractor.get_page_cols = MagicMock(return_value={"col1", "col2"})
+    extractor.get_page_cols = mocker.MagicMock(return_value={"col1", "col2"})
 
     # Create a DataFrame with extra columns
     df = pd.DataFrame(columns=["col1", "col2", "col3"])
@@ -114,11 +106,9 @@
     extractor.validate(df, "page1", partition="partition1")
 
 
-@patch.object(FakeExtractor, "METADATA")
-def test_validate_missing_columns(mock_metadata, extractor):
+def test_validate_missing_columns(mocker, extractor):
     # Mock the partition selection and page columns
-    mock_metadata._get_partition_selection.return_value = "partition1"
-    extractor.get_page_cols = MagicMock(return_value={"col1", "col2"})
+    extractor.get_page_cols = mocker.MagicMock(return_value={"col1", "col2"})
 
     # Create a DataFrame with missing columns
     df = pd.DataFrame(columns=["col1"])
@@ -130,11 +120,9 @@
     extractor.validate(df, "page1", partition="partition1")
 
 
-@patch.object(FakeExtractor, "METADATA")
-def test_validate_extra_and_missing_columns(mock_metadata, extractor):
+def test_validate_extra_and_missing_columns(mocker, extractor):
     # Mock the partition selection and page columns
-    mock_metadata._get_partition_selection.return_value = "partition1"
-    extractor.get_page_cols = MagicMock(return_value={"col1", "col2"})
+    extractor.get_page_cols = mocker.MagicMock(return_value={"col1", "col2"})
 
     # Create a DataFrame with both extra and missing columns
     df = pd.DataFrame(columns=["col1", "col3"])
diff --git a/test/unit/extract/excel_test.py b/test/unit/extract/excel_test.py index f9a85f0ceb..2d2aa7de24 100644 --- a/test/unit/extract/excel_test.py +++ b/test/unit/extract/excel_test.py @@ -1,35 +1,31 @@
 """Unit tests for pudl.extract.excel module."""
 
-import unittest
-from unittest import mock as mock
-
 import pandas as pd
+import pytest
 
 from pudl.extract import excel
 
 
-class TestMetadata(unittest.TestCase):
+class TestMetadata:
     """Tests basic operation of the excel.Metadata object."""
 
+    @pytest.fixture(autouse=True)
     def setUp(self):
-        """Cosntructs test metadata instance for testing."""
+        """Constructs test metadata instance for testing."""
         self._metadata = excel.ExcelMetadata("test")
 
     def test_basics(self):
         """Test that basic API method return expected results."""
-        self.assertEqual("test", self._metadata.get_dataset_name())
-        self.assertListEqual(
-            ["books", "boxes", "shoes"], self._metadata.get_all_pages()
-        )
-        self.assertListEqual(
-            ["author", "pages", "title"], self._metadata.get_all_columns("books")
-        )
-        self.assertDictEqual(
-            {"book_title": "title", "name": "author", "pages": "pages"},
-            self._metadata.get_column_map("books", year=2010),
-        )
-        self.assertEqual(10, self._metadata.get_skiprows("boxes", year=2011))
-        self.assertEqual(1, self._metadata.get_sheet_name("boxes", year=2011))
+        assert self._metadata.get_dataset_name() == "test"
+        assert self._metadata.get_all_pages() == ["books", "boxes", "shoes"]
+        assert self._metadata.get_all_columns("books") == ["author", "pages", "title"]
+        assert self._metadata.get_column_map("books", year=2010) == {
+            "book_title": "title",
+            "name": "author",
+            "pages": "pages",
+        }
+        assert self._metadata.get_skiprows("boxes", year=2011) == 10
+        assert self._metadata.get_sheet_name("boxes", year=2011) == 1
 
 
 class FakeExtractor(excel.ExcelExtractor):
@@ -77,11 +73,10 @@ def _fake_data_frames(page_name, **kwargs):
     return fake_data[page_name]
 
 
-class TestExtractor(unittest.TestCase):
+class TestExtractor:
     """Test operation of the excel.Extractor class."""
 
-    @staticmethod
-    def test_extract():
+    def test_extract(self):
         extractor = FakeExtractor()
         res = extractor.extract(year=[2010, 2011])
         expected_books = {
@@ -103,7 +98,7 @@ def test_extract():
     # def test_resulting_dataframes(self):
    #     """Checks that pages across years are merged and columns are translated."""
     #     dfs = FakeExtractor().extract([2010, 2011], testing=True)
-    #     self.assertEqual(set(['books', 'boxes']), set(dfs.keys()))
+    #     assert set(['books', 'boxes']) == set(dfs.keys())
     #     pd.testing.assert_frame_equal(
     #         pd.DataFrame(data={
     #             'author': ['Laozi', 'Benjamin Hoff'],
@@ -118,5 +113,5 @@
     #         }),
     #         dfs['boxes'])
 
-    # TODO(rousik@gmail.com): need to figure out how to test process_$x methods.
-    # TODO(rousik@gmail.com): we should test that empty columns are properly added.
+    # TODO: need to figure out how to test process_$x methods.
+    # TODO: we should test that empty columns are properly added.
diff --git a/test/unit/extract/phmsagas_test.py b/test/unit/extract/phmsagas_test.py index ace2b9d966..2a6b21f1fc 100644 --- a/test/unit/extract/phmsagas_test.py +++ b/test/unit/extract/phmsagas_test.py @@ -1,5 +1,3 @@ -from unittest.mock import MagicMock, patch - import pandas as pd import pytest @@ -8,20 +6,20 @@ class FakeExtractor(Extractor): - def __init__(self): + def __init__(self, mocker): self.METADATA = ExcelMetadata("phmsagas") - super().__init__(ds=MagicMock()) - self._metadata = MagicMock() + super().__init__(ds=mocker.Mock()) + self._metadata = mocker.Mock() @pytest.fixture -def extractor(): +def extractor(mocker): # Create an instance of the CsvExtractor class - return FakeExtractor() + return FakeExtractor(mocker) -@patch("pudl.extract.phmsagas.logger") -def test_process_renamed_drop_columns(mock_logger, extractor): +def test_process_renamed_drop_columns(mocker, extractor): + mock_logger = mocker.patch("pudl.extract.phmsagas.logger") # Mock metadata methods extractor._metadata.get_form.return_value = "gas_transmission_gathering" extractor._metadata.get_all_columns.return_value = ["col1", "col2"] @@ -38,8 +36,8 @@ def test_process_renamed_drop_columns(mock_logger, extractor): mock_logger.info.assert_called_once() -@patch("pudl.extract.phmsagas.logger") -def test_process_renamed_keep_columns(mock_logger, extractor): +def test_process_renamed_keep_columns(mocker, extractor): + mock_logger = mocker.patch("pudl.extract.phmsagas.logger") # Mock metadata methods extractor._metadata.get_form.return_value = "gas_transmission_gathering" extractor._metadata.get_all_columns.return_value = ["col1", "col2"] @@ -56,8 +54,8 @@ def test_process_renamed_keep_columns(mock_logger, extractor): mock_logger.info.assert_not_called() -@patch("pudl.extract.phmsagas.logger") -def test_process_renamed_drop_unnamed_columns(mock_logger, extractor): +def test_process_renamed_drop_unnamed_columns(mocker, extractor): + mock_logger = mocker.patch("pudl.extract.phmsagas.logger") # Mock metadata methods extractor._metadata.get_form.return_value = "some_form" extractor._metadata.get_all_columns.return_value = ["col1", "col2"] @@ -74,8 +72,8 @@ def test_process_renamed_drop_unnamed_columns(mock_logger, extractor): mock_logger.warning.assert_not_called() -@patch("pudl.extract.phmsagas.logger") -def test_process_renamed_warn_unnamed_columns(mock_logger, extractor): +def test_process_renamed_warn_unnamed_columns(mocker, extractor): + mock_logger = mocker.patch("pudl.extract.phmsagas.logger") # Mock metadata methods extractor._metadata.get_form.return_value = "some_form" extractor._metadata.get_all_columns.return_value = ["col1", "col2"] diff --git a/test/unit/output/ferc1_test.py b/test/unit/output/ferc1_test.py index 7b87401f1d..c32478746f 100644 --- a/test/unit/output/ferc1_test.py +++ b/test/unit/output/ferc1_test.py @@ -19,7 +19,6 @@ """ import logging -import unittest from io import StringIO import networkx as nx @@ -37,11 +36,7 @@ logger = logging.getLogger(__name__) -class TestForestSetup(unittest.TestCase): - def setUp(self): - # this is where you add nodes you want to use - pass - +class TestForestSetup: def _exploded_calcs_from_edges(self, edges: list[tuple[NodeId, NodeId]]): records = [] for parent, child in edges: @@ -89,8 +84,8 @@ def build_forest_and_annotated_tags( return annotated_tags -class TestPrunnedNode(TestForestSetup): - def setUp(self): +class TestPrunedNode(TestForestSetup): + def setup_method(self): self.root = NodeId( table_name="table_1", xbrl_factoid="reported_1", @@ -133,7 +128,7 @@ 
def test_pruned_nodes(self): class TestTagPropagation(TestForestSetup): - def setUp(self): + def setup_method(self): self.parent = NodeId( table_name="table_1", xbrl_factoid="reported_1", diff --git a/test/unit/workspace/datastore_test.py b/test/unit/workspace/datastore_test.py index d6840ad1a6..06c6972246 100644 --- a/test/unit/workspace/datastore_test.py +++ b/test/unit/workspace/datastore_test.py @@ -3,7 +3,6 @@ import io import json import re -import unittest import zipfile from typing import Any @@ -46,154 +45,121 @@ def _make_descriptor( ) -class TestDatapackageDescriptor(unittest.TestCase): - """Unit tests for the DatapackageDescriptor class.""" +def test_get_partition_filters(): + desc = _make_descriptor( + "blabla", + "doi-123", + _make_resource("foo", group="first", color="red"), + _make_resource("bar", group="first", color="blue"), + _make_resource("baz", group="second", color="black", order=1), + ) + assert list(desc.get_partition_filters()) == [ + {"group": "first", "color": "red"}, + {"group": "first", "color": "blue"}, + {"group": "second", "color": "black", "order": 1}, + ] + assert list(desc.get_partition_filters(group="first")) == [ + {"group": "first", "color": "red"}, + {"group": "first", "color": "blue"}, + ] + assert list(desc.get_partition_filters(color="blue")) == [ + {"group": "first", "color": "blue"}, + ] + assert list(desc.get_partition_filters(color="blue", group="second")) == [] + + +def test_get_resource_path(): + """Check that get_resource_path returns correct paths.""" + desc = _make_descriptor( + "blabla", + "doi-123", + _make_resource("foo", group="first", color="red"), + _make_resource("bar", group="first", color="blue"), + ) + assert desc.get_resource_path("foo") == "http://localhost/foo" + assert desc.get_resource_path("bar") == "http://localhost/bar" + with pytest.raises(KeyError): + desc.get_resource_path("other") - def test_get_partition_filters(self): - desc = _make_descriptor( - "blabla", - "doi-123", - _make_resource("foo", group="first", color="red"), - _make_resource("bar", group="first", color="blue"), - _make_resource("baz", group="second", color="black", order=1), - ) - self.assertEqual( - [ - {"group": "first", "color": "red"}, - {"group": "first", "color": "blue"}, - {"group": "second", "color": "black", "order": 1}, - ], - list(desc.get_partition_filters()), - ) - self.assertEqual( - [ - {"group": "first", "color": "red"}, - {"group": "first", "color": "blue"}, - ], - list(desc.get_partition_filters(group="first")), - ) - self.assertEqual( - [ - {"group": "first", "color": "blue"}, - ], - list(desc.get_partition_filters(color="blue")), - ) - self.assertEqual( - [], list(desc.get_partition_filters(color="blue", group="second")) - ) - def test_get_resource_path(self): - """Check that get_resource_path returns correct paths.""" - desc = _make_descriptor( - "blabla", - "doi-123", - _make_resource("foo", group="first", color="red"), - _make_resource("bar", group="first", color="blue"), - ) - self.assertEqual("http://localhost/foo", desc.get_resource_path("foo")) - self.assertEqual("http://localhost/bar", desc.get_resource_path("bar")) - # The following resource does not exist and should throw KeyError - self.assertRaises(KeyError, desc.get_resource_path, "other") - - def test_modernize_zenodo_legacy_api_url(self): - legacy_url = "https://zenodo.org/api/files/082e4932-c772-4e9c-a670-376a1acc3748/datapackage.json" - - descriptor = datastore.DatapackageDescriptor( - {"resources": [{"name": "datapackage.json", "path": legacy_url}]}, - 
dataset="test", - doi="10.5281/zenodo.123123", - ) +def test_modernize_zenodo_legacy_api_url(): + legacy_url = "https://zenodo.org/api/files/082e4932-c772-4e9c-a670-376a1acc3748/datapackage.json" - assert ( - descriptor.get_resource_path("datapackage.json") - == "https://zenodo.org/records/123123/files/datapackage.json" - ) + descriptor = datastore.DatapackageDescriptor( + {"resources": [{"name": "datapackage.json", "path": legacy_url}]}, + dataset="test", + doi="10.5281/zenodo.123123", + ) - def test_get_resources_filtering(self): - """Verifies correct operation of get_resources().""" - desc = _make_descriptor( - "data", - "doi-123", - _make_resource("foo", group="first", color="red"), - _make_resource("bar", group="first", color="blue", rank=5), - _make_resource( - "baz", group="second", color="blue", rank=5, mood="VeryHappy" - ), - ) - self.assertEqual( - [ - PudlResourceKey("data", "doi-123", "foo"), - PudlResourceKey("data", "doi-123", "bar"), - PudlResourceKey("data", "doi-123", "baz"), - ], - list(desc.get_resources()), - ) - # Simple filtering by one attribute. - self.assertEqual( - [ - PudlResourceKey("data", "doi-123", "foo"), - PudlResourceKey("data", "doi-123", "bar"), - ], - list(desc.get_resources(group="first")), - ) - # Filter by two attributes - self.assertEqual( - [ - PudlResourceKey("data", "doi-123", "bar"), - ], - list(desc.get_resources(group="first", rank=5)), - ) - # Attributes that do not match anything - self.assertEqual( - [], - list(desc.get_resources(group="second", shape="square")), - ) - # Search attribute values are cast to lowercase strings - self.assertEqual( - [ - PudlResourceKey("data", "doi-123", "baz"), - ], - list(desc.get_resources(rank="5", mood="VERYhappy")), - ) - # Test lookup by name - self.assertEqual( - [ - PudlResourceKey("data", "doi-123", "foo"), - ], - list(desc.get_resources("foo")), - ) + assert ( + descriptor.get_resource_path("datapackage.json") + == "https://zenodo.org/records/123123/files/datapackage.json" + ) - def test_json_string_representation(self): - """Checks that json representation parses to the same dict.""" - desc = _make_descriptor( - "data", - "doi-123", - _make_resource("foo", group="first"), - _make_resource("bar", group="second"), - _make_resource("baz"), - ) - self.assertEqual( + +def test_get_resources_filtering(): + """Verifies correct operation of get_resources().""" + desc = _make_descriptor( + "data", + "doi-123", + _make_resource("foo", group="first", color="red"), + _make_resource("bar", group="first", color="blue", rank=5), + _make_resource("baz", group="second", color="blue", rank=5, mood="VeryHappy"), + ) + assert list(desc.get_resources()) == [ + PudlResourceKey("data", "doi-123", "foo"), + PudlResourceKey("data", "doi-123", "bar"), + PudlResourceKey("data", "doi-123", "baz"), + ] + # Simple filtering by one attribute. 
+ assert list(desc.get_resources(group="first")) == [ + PudlResourceKey("data", "doi-123", "foo"), + PudlResourceKey("data", "doi-123", "bar"), + ] + # Filter by two attributes + assert list(desc.get_resources(group="first", rank=5)) == [ + PudlResourceKey("data", "doi-123", "bar"), + ] + # Attributes that do not match anything + assert list(desc.get_resources(group="second", shape="square")) == [] + # Search attribute values are cast to lowercase strings + assert list(desc.get_resources(rank="5", mood="VERYhappy")) == [ + PudlResourceKey("data", "doi-123", "baz"), + ] + # Test lookup by name + assert list(desc.get_resources("foo")) == [ + PudlResourceKey("data", "doi-123", "foo"), + ] + + +def test_json_string_representation(): + """Checks that json representation parses to the same dict.""" + desc = _make_descriptor( + "data", + "doi-123", + _make_resource("foo", group="first"), + _make_resource("bar", group="second"), + _make_resource("baz"), + ) + assert json.loads(desc.get_json_string()) == { + "resources": [ { - "resources": [ - { - "name": "foo", - "path": "http://localhost/foo", - "parts": {"group": "first"}, - }, - { - "name": "bar", - "path": "http://localhost/bar", - "parts": {"group": "second"}, - }, - { - "name": "baz", - "path": "http://localhost/baz", - "parts": {}, - }, - ], + "name": "foo", + "path": "http://localhost/foo", + "parts": {"group": "first"}, }, - json.loads(desc.get_json_string()), - ) + { + "name": "bar", + "path": "http://localhost/bar", + "parts": {"group": "second"}, + }, + { + "name": "baz", + "path": "http://localhost/baz", + "parts": {}, + }, + ], + } class MockableZenodoFetcher(datastore.ZenodoFetcher): @@ -210,7 +176,7 @@ def __init__( self._descriptor_cache = descriptors -class TestZenodoFetcher(unittest.TestCase): +class TestZenodoFetcher: """Unit tests for ZenodoFetcher class.""" MOCK_EPACEMS_DEPOSITION = { @@ -243,7 +209,8 @@ class TestZenodoFetcher(unittest.TestCase): r"^10\.(5072|5281)/zenodo\.(\d+)$", PROD_EPACEMS_DOI ).group(2) - def setUp(self): + @pytest.fixture(autouse=True) + def setup(self): """Constructs mockable Zenodo fetcher based on MOCK_EPACEMS_DATAPACKAGE.""" self.fetcher = MockableZenodoFetcher( descriptors={ @@ -263,38 +230,36 @@ def test_doi_format_is_correct(self): identified. 
""" zf = datastore.ZenodoFetcher() - self.assertTrue(zf.get_known_datasets()) + assert zf.get_known_datasets() for dataset, doi in zf.zenodo_dois: - self.assertTrue( - zf.get_doi(dataset) == doi, - msg=f"Zenodo DOI for {dataset} matches result of get_doi()", + assert zf.get_doi(dataset) == doi, ( + f"Zenodo DOI for {dataset} matches result of get_doi()" ) - self.assertFalse( - re.fullmatch(r"10\.5072/zenodo\.[0-9]{5,10}", doi), - msg=f"Zenodo sandbox DOI found for {dataset}: {doi}", + assert not re.fullmatch(r"10\.5072/zenodo\.[0-9]{5,10}", doi), ( + f"Zenodo sandbox DOI found for {dataset}: {doi}" ) - self.assertTrue( - re.fullmatch(r"10\.5281/zenodo\.[0-9]{5,10}", doi), - msg=f"Zenodo production DOI for {dataset} is {doi}", + assert re.fullmatch(r"10\.5281/zenodo\.[0-9]{5,10}", doi), ( + f"Zenodo production DOI for {dataset} is {doi}" ) def test_get_known_datasets(self): """Call to get_known_datasets() produces the expected results.""" - self.assertEqual( - sorted(name for name, doi in datastore.ZenodoFetcher().zenodo_dois), - self.fetcher.get_known_datasets(), + assert ( + sorted(name for name, doi in datastore.ZenodoFetcher().zenodo_dois) + == self.fetcher.get_known_datasets() ) def test_get_unknown_dataset(self): """Ensure that we get a failure when attempting to access an invalid dataset.""" - self.assertRaises(AttributeError, self.fetcher.get_doi, "unknown") + with pytest.raises(AttributeError): + self.fetcher.get_doi("unknown") def test_doi_of_prod_epacems_matches(self): """Most of the tests assume specific DOI for production epacems dataset. This test verifies that the expected value is in use. """ - self.assertEqual(self.PROD_EPACEMS_DOI, self.fetcher.get_doi("epacems")) + assert self.fetcher.get_doi("epacems") == self.PROD_EPACEMS_DOI @responses.activate def test_get_descriptor_http_calls(self): @@ -311,8 +276,7 @@ def test_get_descriptor_http_calls(self): json=self.MOCK_EPACEMS_DATAPACKAGE, ) desc = fetcher.get_descriptor("epacems") - self.assertEqual(self.MOCK_EPACEMS_DATAPACKAGE, desc.datapackage_json) - # self.assertTrue(responses.assert_call_count("http://localhost/my/datapackage.json", 1)) + assert desc.datapackage_json == self.MOCK_EPACEMS_DATAPACKAGE @responses.activate def test_get_resource(self): @@ -321,21 +285,21 @@ def test_get_resource(self): res = self.fetcher.get_resource( PudlResourceKey("epacems", self.PROD_EPACEMS_DOI, "first") ) - self.assertEqual(b"blah", res) + assert res == b"blah" @responses.activate def test_get_resource_with_invalid_checksum(self): """Test that resource with bad checksum raises ChecksumMismatchError.""" responses.add(responses.GET, "http://localhost/first", body="wrongContent") res = PudlResourceKey("epacems", self.PROD_EPACEMS_DOI, "first") - self.assertRaises( - datastore.ChecksumMismatchError, self.fetcher.get_resource, res - ) + with pytest.raises(datastore.ChecksumMismatchError): + self.fetcher.get_resource(res) def test_get_resource_with_nonexistent_resource_fails(self): """If resource does not exist, get_resource() throws KeyError.""" res = PudlResourceKey("epacems", self.PROD_EPACEMS_DOI, "nonexistent") - self.assertRaises(KeyError, self.fetcher.get_resource, res) + with pytest.raises(KeyError): + self.fetcher.get_resource(res) def test_get_zipfile_resource_failure(mocker): @@ -415,4 +379,4 @@ def test_get_zipfile_resources_eventual_success(mocker): assert test_file.read().decode(encoding="utf-8") == file_contents -# TODO(rousik): add unit tests for Datasource class as well +# TODO: add unit tests for Datasource class as well 
diff --git a/test/unit/workspace/resource_cache_test.py b/test/unit/workspace/resource_cache_test.py index 95d582d680..feba668eed 100644 --- a/test/unit/workspace/resource_cache_test.py +++ b/test/unit/workspace/resource_cache_test.py @@ -2,7 +2,6 @@ import shutil import tempfile -import unittest from pathlib import Path import requests.exceptions as requests_exceptions @@ -13,7 +12,7 @@ from pudl.workspace.resource_cache import PudlResourceKey, extend_gcp_retry_predicate -class TestGoogleCloudStorageCache(unittest.TestCase): +class TestGoogleCloudStorageCache: """Unit tests for the GoogleCloudStorageCache class.""" def test_bad_request_predicate(self): @@ -21,77 +20,77 @@ def test_bad_request_predicate(self): bad_request_predicate = extend_gcp_retry_predicate(_should_retry, BadRequest) # Check default exceptions. - self.assertFalse(_should_retry(BadRequest(message="Bad request!"))) - self.assertTrue(_should_retry(requests_exceptions.Timeout())) + assert not _should_retry(BadRequest(message="Bad request!")) + assert _should_retry(requests_exceptions.Timeout()) - # Check extended predicate handles default exceptionss and BadRequest. - self.assertTrue(bad_request_predicate(requests_exceptions.Timeout())) - self.assertTrue(bad_request_predicate(BadRequest(message="Bad request!"))) + # Check extended predicate handles default exceptions and BadRequest. + assert bad_request_predicate(requests_exceptions.Timeout()) + assert bad_request_predicate(BadRequest(message="Bad request!")) -class TestLocalFileCache(unittest.TestCase): +class TestLocalFileCache: """Unit tests for the LocalFileCache class.""" - def setUp(self): + def setup_method(self): """Prepares temporary directory for storing cache contents.""" self.test_dir = tempfile.mkdtemp() self.cache = resource_cache.LocalFileCache(Path(self.test_dir)) - def tearDown(self): + def teardown_method(self): """Deletes content of the temporary directories.""" shutil.rmtree(self.test_dir) def test_add_single_resource(self): """Adding resource has expected effect on later get() and contains() calls.""" res = PudlResourceKey("ds", "doi", "file.txt") - self.assertFalse(self.cache.contains(res)) + assert not self.cache.contains(res) self.cache.add(res, b"blah") - self.assertTrue(self.cache.contains(res)) - self.assertEqual(b"blah", self.cache.get(res)) + assert self.cache.contains(res) + assert self.cache.get(res) == b"blah" def test_that_two_cache_objects_share_storage(self): """Two LocalFileCache instances with the same path share the object storage.""" second_cache = resource_cache.LocalFileCache(Path(self.test_dir)) res = PudlResourceKey("dataset", "doi", "file.txt") - self.assertFalse(self.cache.contains(res)) - self.assertFalse(second_cache.contains(res)) + assert not self.cache.contains(res) + assert not second_cache.contains(res) self.cache.add(res, b"testContents") - self.assertTrue(self.cache.contains(res)) - self.assertTrue(second_cache.contains(res)) - self.assertEqual(b"testContents", second_cache.get(res)) + assert self.cache.contains(res) + assert second_cache.contains(res) + assert second_cache.get(res) == b"testContents" def test_deletion(self): """Deleting resources has expected effect on later get() / contains() calls.""" res = PudlResourceKey("a", "b", "c") - self.assertFalse(self.cache.contains(res)) + assert not self.cache.contains(res) self.cache.add(res, b"sampleContents") - self.assertTrue(self.cache.contains(res)) + assert self.cache.contains(res) self.cache.delete(res) - self.assertFalse(self.cache.contains(res)) + assert not 
self.cache.contains(res)
 
     def test_read_only_add_and_delete_do_nothing(self):
         """Test that in read_only mode, add() and delete() calls are ignored."""
         res = PudlResourceKey("a", "b", "c")
         ro_cache = resource_cache.LocalFileCache(Path(self.test_dir), read_only=True)
-        self.assertTrue(ro_cache.is_read_only())
+        assert ro_cache.is_read_only()
 
         ro_cache.add(res, b"sample")
-        self.assertFalse(ro_cache.contains(res))
+        assert not ro_cache.contains(res)
 
         # Use read-write cache to insert resource
         self.cache.add(res, b"sample")
-        self.assertFalse(self.cache.is_read_only())
-        self.assertTrue(ro_cache.contains(res))
+        assert not self.cache.is_read_only()
+        assert ro_cache.contains(res)
 
         # Deleting via ro cache should not happen
         ro_cache.delete(res)
-        self.assertTrue(ro_cache.contains(res))
+        assert ro_cache.contains(res)
 
 
-class TestLayeredCache(unittest.TestCase):
+class TestLayeredCache:
     """Unit tests for LayeredCache class."""
 
-    def setUp(self):
+    def setup_method(self):
         """Constructs two LocalFileCache layers pointed at temporary directories."""
         self.layered_cache = resource_cache.LayeredCache()
         self.test_dir_1 = tempfile.mkdtemp()
@@ -106,53 +105,53 @@ def tearDown(self):
 
     def test_add_caching_layers(self):
         """Adding layers has expected effect on the subsequent num_layers() calls."""
-        self.assertEqual(0, self.layered_cache.num_layers())
+        assert self.layered_cache.num_layers() == 0
         self.layered_cache.add_cache_layer(self.cache_1)
-        self.assertEqual(1, self.layered_cache.num_layers())
+        assert self.layered_cache.num_layers() == 1
         self.layered_cache.add_cache_layer(self.cache_2)
-        self.assertEqual(2, self.layered_cache.num_layers())
+        assert self.layered_cache.num_layers() == 2
 
     def test_add_to_first_layer(self):
         """Adding to layered cache by default stores entires in the first layer."""
         self.layered_cache.add_cache_layer(self.cache_1)
         self.layered_cache.add_cache_layer(self.cache_2)
         res = PudlResourceKey("a", "b", "x.txt")
-        self.assertFalse(self.layered_cache.contains(res))
+        assert not self.layered_cache.contains(res)
         self.layered_cache.add(res, b"sampleContent")
-        self.assertTrue(self.layered_cache.contains(res))
-        self.assertTrue(self.cache_1.contains(res))
-        self.assertFalse(self.cache_2.contains(res))
+        assert self.layered_cache.contains(res)
+        assert self.cache_1.contains(res)
+        assert not self.cache_2.contains(res)
 
     def test_get_uses_innermost_layer(self):
         """Resource is retrieved from the leftmost layer that contains it."""
         res = PudlResourceKey("a", "b", "x.txt")
         self.layered_cache.add_cache_layer(self.cache_1)
         self.layered_cache.add_cache_layer(self.cache_2)
         # self.cache_1.add(res, "firstLayer")
         self.cache_2.add(res, b"secondLayer")
-        self.assertEqual(b"secondLayer", self.layered_cache.get(res))
+        assert self.layered_cache.get(res) == b"secondLayer"
         self.cache_1.add(res, b"firstLayer")
-        self.assertEqual(b"firstLayer", self.layered_cache.get(res))
+        assert self.layered_cache.get(res) == b"firstLayer"
         # Set on layered cache updates innermost layer
         self.layered_cache.add(res, b"newContents")
-        self.assertEqual(b"newContents", self.layered_cache.get(res))
-        self.assertEqual(b"newContents", self.cache_1.get(res))
-        self.assertEqual(b"secondLayer", self.cache_2.get(res))
+        assert self.layered_cache.get(res) == b"newContents"
+        assert self.cache_1.get(res) == b"newContents"
+        assert self.cache_2.get(res) == b"secondLayer"
         # Deletion also only affects innermost layer
         self.layered_cache.delete(res)
-        self.assertTrue(self.layered_cache.contains(res))
-        self.assertFalse(self.cache_1.contains(res))
-        self.assertTrue(self.cache_2.contains(res))
-        self.assertEqual(b"secondLayer", self.layered_cache.get(res))
+        assert self.layered_cache.contains(res)
+        assert not self.cache_1.contains(res)
+        assert self.cache_2.contains(res)
+        assert self.layered_cache.get(res) == b"secondLayer"
 
     def test_add_with_no_layers_does_nothing(self):
         """When add() is called on cache with no layers nothing happens."""
         res = PudlResourceKey("a", "b", "c")
-        self.assertFalse(self.layered_cache.contains(res))
+        assert not self.layered_cache.contains(res)
         self.layered_cache.add(res, b"sample")
-        self.assertFalse(self.layered_cache.contains(res))
+        assert not self.layered_cache.contains(res)
         self.layered_cache.delete(res)
 
     def test_read_only_layers_skipped_when_adding(self):
@@ -163,19 +162,19 @@
 
         res = PudlResourceKey("a", "b", "c")
 
-        self.assertFalse(lc.contains(res))
-        self.assertFalse(c1.contains(res))
-        self.assertFalse(c2.contains(res))
+        assert not lc.contains(res)
+        assert not c1.contains(res)
+        assert not c2.contains(res)
 
         lc.add(res, b"test")
-        self.assertTrue(lc.contains(res))
-        self.assertFalse(c1.contains(res))
-        self.assertTrue(c2.contains(res))
+        assert lc.contains(res)
+        assert not c1.contains(res)
+        assert c2.contains(res)
 
         lc.delete(res)
-        self.assertFalse(lc.contains(res))
-        self.assertFalse(c1.contains(res))
-        self.assertFalse(c2.contains(res))
+        assert not lc.contains(res)
+        assert not c1.contains(res)
+        assert not c2.contains(res)
 
     def test_read_only_cache_ignores_modifications(self):
         """When cache is marked as read_only, add() and delete() calls are ignored."""
@@ -183,22 +182,22 @@
         r2 = PudlResourceKey("a", "b", "r2")
         self.cache_1.add(r1, b"xxx")
         self.cache_2.add(r2, b"yyy")
-        self.assertTrue(self.cache_1.contains(r1))
-        self.assertTrue(self.cache_2.contains(r2))
+        assert self.cache_1.contains(r1)
+        assert self.cache_2.contains(r2)
 
         lc = resource_cache.LayeredCache(self.cache_1, self.cache_2, read_only=True)
-        self.assertTrue(lc.contains(r1))
-        self.assertTrue(lc.contains(r2))
+        assert lc.contains(r1)
+        assert lc.contains(r2)
 
         lc.delete(r1)
         lc.delete(r2)
-        self.assertTrue(lc.contains(r1))
-        self.assertTrue(lc.contains(r2))
-        self.assertTrue(self.cache_1.contains(r1))
-        self.assertTrue(self.cache_2.contains(r2))
+        assert lc.contains(r1)
+        assert lc.contains(r2)
+        assert self.cache_1.contains(r1)
+        assert self.cache_2.contains(r2)
 
         r_new = PudlResourceKey("a", "b", "new")
         lc.add(r_new, b"xyz")
-        self.assertFalse(lc.contains(r_new))
-        self.assertFalse(self.cache_1.contains(r_new))
-        self.assertFalse(self.cache_2.contains(r_new))
+        assert not lc.contains(r_new)
+        assert not self.cache_1.contains(r_new)
+        assert not self.cache_2.contains(r_new)