Skip to content

Commit

Permalink
Refactor code / import changes
Browse files Browse the repository at this point in the history
  • Loading branch information
philsv committed Oct 8, 2024
1 parent 0767138 commit 2f10b1a
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 26 deletions.
30 changes: 15 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ To find all EIA Datasets visit [API Dashboard](https://www.eia.gov/opendata/brow
## How to use

```python
from myeia.api import API
from myeia import API

eia = API()
```
Expand Down Expand Up @@ -133,7 +133,7 @@ Output Example:

```ini
countryRegionId productId Crude oil, NGPL, and other liquids
Date
Date
2024-03-01 ARE 55 4132.394334
2024-02-01 ARE 55 4132.394334
2024-01-01 ARE 55 4142.394334
Expand Down Expand Up @@ -165,12 +165,12 @@ Output Example:

```ini
Natural Gas Futures Contract 1 (Dollars per Million Btu) Natural Gas Futures Contract 2 (Dollars per Million Btu)
Date
2023-08-29 2.556 2.662
2023-08-28 2.579 2.665
2023-08-25 2.540 2.657
2023-08-24 2.519 2.636
2023-08-23 2.497 2.592
Date
2023-08-29 2.556 2.662
2023-08-28 2.579 2.665
2023-08-25 2.540 2.657
2023-08-24 2.519 2.636
2023-08-23 2.497 2.592
... ... ...
```

Expand All @@ -192,13 +192,13 @@ Output Example:

```ini
Natural Gas Futures Contract 1 (Dollars per Million Btu)
Date
2021-01-29 2.564
2021-01-28 2.664
2021-01-27 2.760
2021-01-26 2.656
2021-01-25 2.602
... ...
Date
2021-01-29 2.564
2021-01-28 2.664
2021-01-27 2.760
2021-01-26 2.656
2021-01-25 2.602
... ...
```

This also works for the `get_series_via_route` method.
Expand Down
3 changes: 3 additions & 0 deletions myeia/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .api import API

__all__ = ["API"]
25 changes: 18 additions & 7 deletions myeia/api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import logging
import os
import time
import warnings
from datetime import date
from typing import Optional, Union

import backoff
import numpy as np
import pandas as pd
import requests
from dateutil.relativedelta import relativedelta
Expand Down Expand Up @@ -40,7 +40,7 @@ def __init__(
max_tries=10,
raise_on_giveup=True,
jitter=backoff.full_jitter,
giveup=lambda e: e.response.status_code == 403,
giveup=lambda e: hasattr(e, "response") and e.response.status_code == 403,
)
def get_response(
self,
Expand All @@ -57,8 +57,8 @@ def get_response(
json_response = response.json()
return pd.DataFrame(json_response["response"]["data"])

@staticmethod
def format_date(
self,
df: pd.DataFrame,
) -> pd.DataFrame:
"""Helper function to format date."""
Expand Down Expand Up @@ -103,18 +103,23 @@ def get_series(
df = df[(df.index >= start_date) & (df.index <= end_date)]

df = df.sort_index(ascending=False)

if "NA" in df[data_identifier].values:
df[data_identifier] = df[data_identifier].replace("NA", np.nan)

df[data_identifier] = df[data_identifier].astype(float)

# Filtering the DataFrame by the specified date range can result in an empty DataFrame
if df.empty:
return df

descriptions = ["series-description", "seriesDescription", "productName"]

for col in df.columns:
if "name" in col.lower() or "description" in col.lower():
if col in descriptions:
df = df.rename(columns={data_identifier: df[col][0]})
df = df[df[col][0]].to_frame()
break

return df

def get_series_via_route(
Expand Down Expand Up @@ -182,17 +187,23 @@ def get_series_via_route(

df = self.format_date(df)
df = df.sort_index(ascending=False)

if "NA" in df[data_identifier].values:
df[data_identifier] = df[data_identifier].replace("NA", np.nan)

df[data_identifier] = df[data_identifier].astype(float)

descriptions = ["series-description", "seriesDescription", "productName"]

if isinstance(facet, str) and isinstance(series, str):
for col in df.columns:
if "name" in col.lower() or "description" in col.lower():
if col in descriptions:
df = df.rename(columns={data_identifier: df[col][0]})
df = df[df[col][0]].to_frame()
break
elif isinstance(facet, list) and isinstance(series, list):
for col in df.columns:
if "name" in col.lower() or "description" in col.lower():
if col in descriptions:
df = df.rename(columns={data_identifier: df[col][0]})
facet.append(df[col][0])
df = df[facet]
Expand Down
10 changes: 6 additions & 4 deletions tests/test_myeia.py → tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
import pandas as pd
import pytest

from myeia.api import API
from myeia import API

eia = API()


@pytest.mark.parametrize(
"series_id, start_date, end_date",
[
("NG.RNGC1.D", "2024-01-01", "2024-02-01"),
("PET.WCESTUS1.W", "2024-01-01", "2024-02-01"),
("INTL.29-12-HKG-BKWH.A", "2024-01-01", "2024-02-01"),
("NG.RNGC1.D", "2020-01-01", "2024-02-01"),
("PET.WCESTUS1.W", "2020-01-01", "2024-02-01"),
("INTL.29-12-HKG-BKWH.A", "2020-01-01", "2024-02-01"),
("STEO.PATC_WORLD.M", "2024-01-01", "2024-02-01"),
],
)
def test_get_series(series_id, start_date, end_date):
"""Test get_series method."""
df = eia.get_series(series_id, start_date=start_date, end_date=end_date)
assert not df.empty
assert isinstance(df, pd.DataFrame)


Expand All @@ -35,4 +36,5 @@ def test_get_series(series_id, start_date, end_date):
def test_get_series_via_route(route, series, frequency, facet):
"""Test get_series_via_route method."""
df = eia.get_series_via_route(route, series, frequency, facet)
assert not df.empty
assert isinstance(df, pd.DataFrame)

0 comments on commit 2f10b1a

Please sign in to comment.