Skip to content

Commit

Permalink
Merge branch 'main' into dbt_setup
Browse files Browse the repository at this point in the history
  • Loading branch information
zaneselvans committed Jan 14, 2025
2 parents 784cf96 + 454ca96 commit 48a16e1
Show file tree
Hide file tree
Showing 6 changed files with 277 additions and 326 deletions.
74 changes: 30 additions & 44 deletions test/unit/extract/csv_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
"""Unit tests for pudl.extract.csv module."""

from unittest.mock import MagicMock, patch

import pandas as pd
import pytest

Expand All @@ -16,16 +12,16 @@


class FakeExtractor(CsvExtractor):
def __init__(self):
def __init__(self, mocker):
# TODO: Make these tests independent of the eia176 implementation
self.METADATA = GenericMetadata("eia176")
super().__init__(ds=MagicMock())
super().__init__(ds=mocker.MagicMock())


@pytest.fixture
def extractor():
# Create an instance of the CsvExtractor class
return FakeExtractor()
def extractor(mocker):
# Create an instance of the CsvExtractor class with mocker
return FakeExtractor(mocker)


def test_source_filename_valid_partition(extractor):
Expand All @@ -45,8 +41,8 @@ def test_source_filename_multiple_selections(extractor):
extractor.source_filename(PAGE, **multiple_selections)


@patch("pudl.extract.csv.pd")
def test_load_source(mock_pd, extractor):
def test_load_source(mocker, extractor):
mock_pd = mocker.patch("pudl.extract.csv.pd")
assert extractor.load_source(PAGE, **PARTITION) == mock_pd.read_csv.return_value
extractor.ds.get_zipfile_resource.assert_called_once_with(DATASET, **PARTITION)
zipfile = extractor.ds.get_zipfile_resource.return_value.__enter__.return_value
Expand All @@ -55,7 +51,7 @@ def test_load_source(mock_pd, extractor):
mock_pd.read_csv.assert_called_once_with(file)


def test_extract(extractor):
def test_extract(mocker, extractor):
# Create a sample of data we could expect from an EIA CSV
company_field = "company"
company_data = "Total of All Companies"
Expand All @@ -64,34 +60,30 @@ def test_extract(extractor):

# TODO: Once FakeExtractor is independent of eia176, mock out populating _column_map for PARTITION_SELECTION;
# Also include negative tests, i.e., for partition selections not in the _column_map
with (
patch.object(CsvExtractor, "load_source", return_value=df),
patch.object(
# Testing the rename
GenericMetadata,
"get_column_map",
return_value={"company_rename": company_field},
),
patch.object(
# Transposing the df here to get the orientation we expect get_page_cols to return
CsvExtractor,
"get_page_cols",
return_value=df.T.index,
),
):
res = extractor.extract(**PARTITION)
mocker.patch.object(CsvExtractor, "load_source", return_value=df)
# Testing the rename
mocker.patch.object(
GenericMetadata,
"get_column_map",
return_value={"company_rename": company_field},
)
# Transposing the df here to get the orientation we expect get_page_cols to return
mocker.patch.object(
CsvExtractor,
"get_page_cols",
return_value=df.T.index,
)
res = extractor.extract(**PARTITION)
assert len(res) == 1 # Assert only one page extracted
assert list(res.keys()) == [PAGE] # Assert it is named correctly
assert (
res[PAGE][company_field][0] == company_data
) # Assert that column correctly renamed and data is there.


@patch.object(FakeExtractor, "METADATA")
def test_validate_exact_columns(mock_metadata, extractor):
def test_validate_exact_columns(mocker, extractor):
# Mock the partition selection and page columns
# mock_metadata._get_partition_selection.return_value = "partition1"
extractor.get_page_cols = MagicMock(return_value={"col1", "col2"})
extractor.get_page_cols = mocker.MagicMock(return_value={"col1", "col2"})

# Create a DataFrame with the exact expected columns
df = pd.DataFrame(columns=["col1", "col2"])
Expand All @@ -100,11 +92,9 @@ def test_validate_exact_columns(mock_metadata, extractor):
extractor.validate(df, "page1", partition="partition1")


@patch.object(FakeExtractor, "METADATA")
def test_validate_extra_columns(mock_metadata, extractor):
def test_validate_extra_columns(mocker, extractor):
# Mock the partition selection and page columns
mock_metadata._get_partition_selection.return_value = "partition1"
extractor.get_page_cols = MagicMock(return_value={"col1", "col2"})
extractor.get_page_cols = mocker.MagicMock(return_value={"col1", "col2"})

# Create a DataFrame with extra columns
df = pd.DataFrame(columns=["col1", "col2", "col3"])
Expand All @@ -114,11 +104,9 @@ def test_validate_extra_columns(mock_metadata, extractor):
extractor.validate(df, "page1", partition="partition1")


@patch.object(FakeExtractor, "METADATA")
def test_validate_missing_columns(mock_metadata, extractor):
def test_validate_missing_columns(mocker, extractor):
# Mock the partition selection and page columns
mock_metadata._get_partition_selection.return_value = "partition1"
extractor.get_page_cols = MagicMock(return_value={"col1", "col2"})
extractor.get_page_cols = mocker.MagicMock(return_value={"col1", "col2"})

# Create a DataFrame with missing columns
df = pd.DataFrame(columns=["col1"])
Expand All @@ -130,11 +118,9 @@ def test_validate_missing_columns(mock_metadata, extractor):
extractor.validate(df, "page1", partition="partition1")


@patch.object(FakeExtractor, "METADATA")
def test_validate_extra_and_missing_columns(mock_metadata, extractor):
def test_validate_extra_and_missing_columns(mocker, extractor):
# Mock the partition selection and page columns
mock_metadata._get_partition_selection.return_value = "partition1"
extractor.get_page_cols = MagicMock(return_value={"col1", "col2"})
extractor.get_page_cols = mocker.MagicMock(return_value={"col1", "col2"})

# Create a DataFrame with both extra and missing columns
df = pd.DataFrame(columns=["col1", "col3"])
Expand Down
54 changes: 30 additions & 24 deletions test/unit/extract/excel_test.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,42 @@
"""Unit tests for pudl.extract.excel module."""

import unittest
from unittest import mock as mock

import pandas as pd
import pytest

from pudl.extract import excel


class TestMetadata(unittest.TestCase):
class TestMetadata:
"""Tests basic operation of the excel.Metadata object."""

@pytest.fixture(autouse=True)
def setUp(self):
"""Cosntructs test metadata instance for testing."""
"""Constructs test metadata instance for testing."""
self._metadata = excel.ExcelMetadata("test")

def test_basics(self):
"""Test that basic API method return expected results."""
self.assertEqual("test", self._metadata.get_dataset_name())
self.assertListEqual(
["books", "boxes", "shoes"], self._metadata.get_all_pages()
)
self.assertListEqual(
["author", "pages", "title"], self._metadata.get_all_columns("books")
)
self.assertDictEqual(
{"book_title": "title", "name": "author", "pages": "pages"},
self._metadata.get_column_map("books", year=2010),
)
self.assertEqual(10, self._metadata.get_skiprows("boxes", year=2011))
self.assertEqual(1, self._metadata.get_sheet_name("boxes", year=2011))
assert self._metadata.get_dataset_name() == "test"
assert self._metadata.get_all_pages() == ["books", "boxes", "shoes"]
assert self._metadata.get_all_columns("books") == ["author", "pages", "title"]
assert self._metadata.get_column_map("books", year=2010) == {
"book_title": "title",
"name": "author",
"pages": "pages",
}
assert self._metadata.get_skiprows("boxes", year=2011) == 10
assert self._metadata.get_sheet_name("boxes", year=2011) == 1

def test_metadata_methods(self):
"""Test various metadata methods."""
assert self._metadata.get_all_columns("books") == ["author", "pages", "title"]
assert self._metadata.get_column_map("books", year=2010) == {
"book_title": "title",
"name": "author",
"pages": "pages",
}
assert self._metadata.get_skiprows("boxes", year=2011) == 10
assert self._metadata.get_sheet_name("boxes", year=2011) == 1


class FakeExtractor(excel.ExcelExtractor):
Expand Down Expand Up @@ -77,11 +84,10 @@ def _fake_data_frames(page_name, **kwargs):
return fake_data[page_name]


class TestExtractor(unittest.TestCase):
class TestExtractor:
"""Test operation of the excel.Extractor class."""

@staticmethod
def test_extract():
def test_extract(self):
extractor = FakeExtractor()
res = extractor.extract(year=[2010, 2011])
expected_books = {
Expand All @@ -103,7 +109,7 @@ def test_extract():
# def test_resulting_dataframes(self):
# """Checks that pages across years are merged and columns are translated."""
# dfs = FakeExtractor().extract([2010, 2011], testing=True)
# self.assertEqual(set(['books', 'boxes']), set(dfs.keys()))
# assert set(['books', 'boxes']) == set(dfs.keys())
# pd.testing.assert_frame_equal(
# pd.DataFrame(data={
# 'author': ['Laozi', 'Benjamin Hoff'],
Expand All @@ -118,5 +124,5 @@ def test_extract():
# }),
# dfs['boxes'])

# TODO([email protected]): need to figure out how to test process_$x methods.
# TODO([email protected]): we should test that empty columns are properly added.
# TODO: need to figure out how to test process_$x methods.
# TODO: we should test that empty columns are properly added.
28 changes: 13 additions & 15 deletions test/unit/extract/phmsagas_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from unittest.mock import MagicMock, patch

import pandas as pd
import pytest

Expand All @@ -8,20 +6,20 @@


class FakeExtractor(Extractor):
def __init__(self):
def __init__(self, mocker):
self.METADATA = ExcelMetadata("phmsagas")
super().__init__(ds=MagicMock())
self._metadata = MagicMock()
super().__init__(ds=mocker.Mock())
self._metadata = mocker.Mock()


@pytest.fixture
def extractor():
def extractor(mocker):
# Create an instance of the CsvExtractor class
return FakeExtractor()
return FakeExtractor(mocker)


@patch("pudl.extract.phmsagas.logger")
def test_process_renamed_drop_columns(mock_logger, extractor):
def test_process_renamed_drop_columns(mocker, extractor):
mock_logger = mocker.patch("pudl.extract.phmsagas.logger")
# Mock metadata methods
extractor._metadata.get_form.return_value = "gas_transmission_gathering"
extractor._metadata.get_all_columns.return_value = ["col1", "col2"]
Expand All @@ -38,8 +36,8 @@ def test_process_renamed_drop_columns(mock_logger, extractor):
mock_logger.info.assert_called_once()


@patch("pudl.extract.phmsagas.logger")
def test_process_renamed_keep_columns(mock_logger, extractor):
def test_process_renamed_keep_columns(mocker, extractor):
mock_logger = mocker.patch("pudl.extract.phmsagas.logger")
# Mock metadata methods
extractor._metadata.get_form.return_value = "gas_transmission_gathering"
extractor._metadata.get_all_columns.return_value = ["col1", "col2"]
Expand All @@ -56,8 +54,8 @@ def test_process_renamed_keep_columns(mock_logger, extractor):
mock_logger.info.assert_not_called()


@patch("pudl.extract.phmsagas.logger")
def test_process_renamed_drop_unnamed_columns(mock_logger, extractor):
def test_process_renamed_drop_unnamed_columns(mocker, extractor):
mock_logger = mocker.patch("pudl.extract.phmsagas.logger")
# Mock metadata methods
extractor._metadata.get_form.return_value = "some_form"
extractor._metadata.get_all_columns.return_value = ["col1", "col2"]
Expand All @@ -74,8 +72,8 @@ def test_process_renamed_drop_unnamed_columns(mock_logger, extractor):
mock_logger.warning.assert_not_called()


@patch("pudl.extract.phmsagas.logger")
def test_process_renamed_warn_unnamed_columns(mock_logger, extractor):
def test_process_renamed_warn_unnamed_columns(mocker, extractor):
mock_logger = mocker.patch("pudl.extract.phmsagas.logger")
# Mock metadata methods
extractor._metadata.get_form.return_value = "some_form"
extractor._metadata.get_all_columns.return_value = ["col1", "col2"]
Expand Down
13 changes: 4 additions & 9 deletions test/unit/output/ferc1_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
"""

import logging
import unittest
from io import StringIO

import networkx as nx
Expand All @@ -37,11 +36,7 @@
logger = logging.getLogger(__name__)


class TestForestSetup(unittest.TestCase):
def setUp(self):
# this is where you add nodes you want to use
pass

class TestForestSetup:
def _exploded_calcs_from_edges(self, edges: list[tuple[NodeId, NodeId]]):
records = []
for parent, child in edges:
Expand Down Expand Up @@ -89,8 +84,8 @@ def build_forest_and_annotated_tags(
return annotated_tags


class TestPrunnedNode(TestForestSetup):
def setUp(self):
class TestPrunedNode(TestForestSetup):
def setup_method(self):
self.root = NodeId(
table_name="table_1",
xbrl_factoid="reported_1",
Expand Down Expand Up @@ -133,7 +128,7 @@ def test_pruned_nodes(self):


class TestTagPropagation(TestForestSetup):
def setUp(self):
def setup_method(self):
self.parent = NodeId(
table_name="table_1",
xbrl_factoid="reported_1",
Expand Down
Loading

0 comments on commit 48a16e1

Please sign in to comment.