Skip to content

Commit

Permalink
Fix date issue for get_series()
Browse files Browse the repository at this point in the history
  • Loading branch information
philsv committed Aug 22, 2024
1 parent f485cb3 commit 4a18a7e
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 28 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: end-of-file-fixer
- id: check-docstring-first
- id: check-yaml

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# myeia

[![PyPI version](https://d25lcipzij17d.cloudfront.net/badge.svg?id=py&r=r&ts=1683906897&type=6e&v=0.4.1&x2=0)](https://badge.fury.io/py/myeia)
[![PyPI version](https://d25lcipzij17d.cloudfront.net/badge.svg?id=py&r=r&ts=1683906897&type=6e&v=0.4.2&x2=0)](https://badge.fury.io/py/myeia)
[![License: MIT](https://img.shields.io/badge/License-MIT-red.svg)](https://github.com/philsv/myeia/blob/main/LICENSE)
[![Weekly Downloads](https://static.pepy.tech/personalized-badge/myeia?period=week&units=international_system&left_color=grey&right_color=blue&left_text=downloads/week)](https://pepy.tech/project/myeia)
[![Monthly Downloads](https://static.pepy.tech/personalized-badge/myeia?period=month&units=international_system&left_color=grey&right_color=blue&left_text=downloads/month)](https://pepy.tech/project/myeia)
Expand Down
33 changes: 15 additions & 18 deletions myeia/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ def get_response(
url: str,
headers: dict,
) -> pd.DataFrame:
"""
Helper function to get the response from the EIA API and return it as a dataframe.
"""
"""Helper function to get the response from the EIA API and return it as a dataframe."""
time.sleep(0.25)
response = requests.get(url, headers=headers)
response.raise_for_status()
Expand All @@ -58,9 +56,7 @@ def format_date(
self,
df: pd.DataFrame,
) -> pd.DataFrame:
"""
Helper function to format date.
"""
"""Helper function to format date."""
if "period" in df.columns:
df = df.rename(columns={"period": "Date"})
df = df.set_index("Date")
Expand All @@ -84,7 +80,7 @@ def get_series(
Args:
series_id (str): The series ID.
data_identifier (str, optional): The data identifier. Defaults to "value".
start_date (str, optional): The start date of the series.
end_date (str, optional): The end date of the series.
Expand All @@ -94,21 +90,20 @@ def get_series(
>>> eia.get_series("NG.RNGC1.W")
"""
api_endpoint = f"seriesid/{series_id}?api_key={self.token}"

if start_date and end_date:
api_endpoint += f"&start={start_date}&end={end_date}"

url = f"{self.base_url}{api_endpoint}"
df = self.get_response(url, self.header)
df = self.format_date(df)

# Filter the DataFrame by the specified date range
df = df[(df.index >= start_date) & (df.index <= end_date)]

df = df.sort_index(ascending=False)
df[data_identifier] = df[data_identifier].astype(float)

for col in df.columns:
if "name" in col.lower() or "description" in col.lower():
df = df.rename(columns={data_identifier: df[col][0]})
df = df[df[col][0]].to_frame()

return df

def get_series_via_route(
Expand All @@ -130,7 +125,7 @@ def get_series_via_route(
route (str): The route to the series.
series (str, list): The series.
frequency (str): The frequency of the series.
facet (str, list, optional): The facet of the series. Defaults to "series".
rename_to (str, optional): The rename of the series. Defaults to "value".
start_date (str, optional): The start date of the series. Defaults to str(date.today() - relativedelta(months=2)).
Expand All @@ -148,7 +143,7 @@ def get_series_via_route(

if start_date and end_date:
base_api_endpoint += f"&start={start_date}&end={end_date}"

# Filter by multiple facets
if isinstance(facet, list) and isinstance(series, list):
for f, s in zip(facet, series):
Expand All @@ -161,7 +156,10 @@ def get_series_via_route(
f"Ensure that facet and series are of the same type (either str or list). Received facet: {facet} and series: {series}."
)

api_endpoint = base_api_endpoint + f"&sort[0][column]=period&sort[0][direction]=desc&offset={offset}&length={limit}"
api_endpoint = (
base_api_endpoint
+ f"&sort[0][column]=period&sort[0][direction]=desc&offset={offset}&length={limit}"
)
url = f"{self.base_url}{api_endpoint}"

df = self.get_response(url, self.header)
Expand All @@ -174,7 +172,7 @@ def get_series_via_route(
df = self.format_date(df)
df = df.sort_index(ascending=False)
df[data_identifier] = df[data_identifier].astype(float)

if isinstance(facet, str) and isinstance(series, str):
for col in df.columns:
if "name" in col.lower() or "description" in col.lower():
Expand All @@ -188,5 +186,4 @@ def get_series_via_route(
facet.append(df[col][0])
df = df[facet]
break

return df
2 changes: 1 addition & 1 deletion myeia/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.4.1"
__version__ = "0.4.2"
17 changes: 10 additions & 7 deletions tests/test_myeia.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,22 @@


@pytest.mark.parametrize(
"series_id",
"series_id, start_date, end_date",
[
("NG.RNGC1.D"),
("PET.WCESTUS1.W"),
("INTL.29-12-HKG-BKWH.A"),
("NG.RNGC1.D", "2024-01-01", "2024-02-01"),
("PET.WCESTUS1.W", "2024-01-01", "2024-02-01"),
("INTL.29-12-HKG-BKWH.A", "2024-01-01", "2024-02-01"),
("STEO.PATC_WORLD.M", "2024-01-01", "2024-02-01"),
],
)
def test_get_series(series_id):
df = eia.get_series(series_id)
def test_get_series(series_id, start_date, end_date):
"""Test get_series method."""
df = eia.get_series(series_id, start_date=start_date, end_date=end_date)
assert isinstance(df, pd.DataFrame)


@pytest.mark.parametrize(
"route,series,frequency,facet",
"route, series, frequency, facet",
[
("steo", "PADI_OPEC", "monthly", "seriesId"),
("natural-gas/pri/fut", "RNGC1", "daily", "series"),
Expand All @@ -31,5 +33,6 @@ def test_get_series(series_id):
],
)
def test_get_series_via_route(route, series, frequency, facet):
"""Test get_series_via_route method."""
df = eia.get_series_via_route(route, series, frequency, facet)
assert isinstance(df, pd.DataFrame)

0 comments on commit 4a18a7e

Please sign in to comment.