From 4a18a7e89968fd94276ca30f3d85d5027113e9ce Mon Sep 17 00:00:00 2001 From: philsv <35413649+philsv@users.noreply.github.com> Date: Fri, 23 Aug 2024 00:08:53 +0200 Subject: [PATCH] Fix date issue for get_series() --- .pre-commit-config.yaml | 1 - README.md | 2 +- myeia/api.py | 33 +++++++++++++++------------------ myeia/version.py | 2 +- tests/test_myeia.py | 17 ++++++++++------- 5 files changed, 27 insertions(+), 28 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 304a3a0..13e27ef 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: - - id: end-of-file-fixer - id: check-docstring-first - id: check-yaml diff --git a/README.md b/README.md index 147fc19..ce4b0b8 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # myeia -[![PyPI version](https://d25lcipzij17d.cloudfront.net/badge.svg?id=py&r=r&ts=1683906897&type=6e&v=0.4.1&x2=0)](https://badge.fury.io/py/myeia) +[![PyPI version](https://d25lcipzij17d.cloudfront.net/badge.svg?id=py&r=r&ts=1683906897&type=6e&v=0.4.2&x2=0)](https://badge.fury.io/py/myeia) [![License: MIT](https://img.shields.io/badge/License-MIT-red.svg)](https://github.com/philsv/myeia/blob/main/LICENSE) [![Weekly Downloads](https://static.pepy.tech/personalized-badge/myeia?period=week&units=international_system&left_color=grey&right_color=blue&left_text=downloads/week)](https://pepy.tech/project/myeia) [![Monthly Downloads](https://static.pepy.tech/personalized-badge/myeia?period=month&units=international_system&left_color=grey&right_color=blue&left_text=downloads/month)](https://pepy.tech/project/myeia) diff --git a/myeia/api.py b/myeia/api.py index 3fa134b..f05e415 100644 --- a/myeia/api.py +++ b/myeia/api.py @@ -45,9 +45,7 @@ def get_response( url: str, headers: dict, ) -> pd.DataFrame: - """ - Helper function to get the response from the EIA API and return it as a dataframe. - """ + """Helper function to get the response from the EIA API and return it as a dataframe.""" time.sleep(0.25) response = requests.get(url, headers=headers) response.raise_for_status() @@ -58,9 +56,7 @@ def format_date( self, df: pd.DataFrame, ) -> pd.DataFrame: - """ - Helper function to format date. - """ + """Helper function to format date.""" if "period" in df.columns: df = df.rename(columns={"period": "Date"}) df = df.set_index("Date") @@ -84,7 +80,7 @@ def get_series( Args: series_id (str): The series ID. - + data_identifier (str, optional): The data identifier. Defaults to "value". start_date (str, optional): The start date of the series. end_date (str, optional): The end date of the series. @@ -94,21 +90,20 @@ def get_series( >>> eia.get_series("NG.RNGC1.W") """ api_endpoint = f"seriesid/{series_id}?api_key={self.token}" - - if start_date and end_date: - api_endpoint += f"&start={start_date}&end={end_date}" - url = f"{self.base_url}{api_endpoint}" df = self.get_response(url, self.header) df = self.format_date(df) + + # Filter the DataFrame by the specified date range + df = df[(df.index >= start_date) & (df.index <= end_date)] + df = df.sort_index(ascending=False) df[data_identifier] = df[data_identifier].astype(float) - + for col in df.columns: if "name" in col.lower() or "description" in col.lower(): df = df.rename(columns={data_identifier: df[col][0]}) df = df[df[col][0]].to_frame() - return df def get_series_via_route( @@ -130,7 +125,7 @@ def get_series_via_route( route (str): The route to the series. series (str, list): The series. frequency (str): The frequency of the series. - + facet (str, list, optional): The facet of the series. Defaults to "series". rename_to (str, optional): The rename of the series. Defaults to "value". start_date (str, optional): The start date of the series. Defaults to str(date.today() - relativedelta(months=2)). @@ -148,7 +143,7 @@ def get_series_via_route( if start_date and end_date: base_api_endpoint += f"&start={start_date}&end={end_date}" - + # Filter by multiple facets if isinstance(facet, list) and isinstance(series, list): for f, s in zip(facet, series): @@ -161,7 +156,10 @@ def get_series_via_route( f"Ensure that facet and series are of the same type (either str or list). Received facet: {facet} and series: {series}." ) - api_endpoint = base_api_endpoint + f"&sort[0][column]=period&sort[0][direction]=desc&offset={offset}&length={limit}" + api_endpoint = ( + base_api_endpoint + + f"&sort[0][column]=period&sort[0][direction]=desc&offset={offset}&length={limit}" + ) url = f"{self.base_url}{api_endpoint}" df = self.get_response(url, self.header) @@ -174,7 +172,7 @@ def get_series_via_route( df = self.format_date(df) df = df.sort_index(ascending=False) df[data_identifier] = df[data_identifier].astype(float) - + if isinstance(facet, str) and isinstance(series, str): for col in df.columns: if "name" in col.lower() or "description" in col.lower(): @@ -188,5 +186,4 @@ def get_series_via_route( facet.append(df[col][0]) df = df[facet] break - return df diff --git a/myeia/version.py b/myeia/version.py index 3d26edf..df12433 100644 --- a/myeia/version.py +++ b/myeia/version.py @@ -1 +1 @@ -__version__ = "0.4.1" +__version__ = "0.4.2" diff --git a/tests/test_myeia.py b/tests/test_myeia.py index 8f3af9e..0b78be2 100644 --- a/tests/test_myeia.py +++ b/tests/test_myeia.py @@ -7,20 +7,22 @@ @pytest.mark.parametrize( - "series_id", + "series_id, start_date, end_date", [ - ("NG.RNGC1.D"), - ("PET.WCESTUS1.W"), - ("INTL.29-12-HKG-BKWH.A"), + ("NG.RNGC1.D", "2024-01-01", "2024-02-01"), + ("PET.WCESTUS1.W", "2024-01-01", "2024-02-01"), + ("INTL.29-12-HKG-BKWH.A", "2024-01-01", "2024-02-01"), + ("STEO.PATC_WORLD.M", "2024-01-01", "2024-02-01"), ], ) -def test_get_series(series_id): - df = eia.get_series(series_id) +def test_get_series(series_id, start_date, end_date): + """Test get_series method.""" + df = eia.get_series(series_id, start_date=start_date, end_date=end_date) assert isinstance(df, pd.DataFrame) @pytest.mark.parametrize( - "route,series,frequency,facet", + "route, series, frequency, facet", [ ("steo", "PADI_OPEC", "monthly", "seriesId"), ("natural-gas/pri/fut", "RNGC1", "daily", "series"), @@ -31,5 +33,6 @@ def test_get_series(series_id): ], ) def test_get_series_via_route(route, series, frequency, facet): + """Test get_series_via_route method.""" df = eia.get_series_via_route(route, series, frequency, facet) assert isinstance(df, pd.DataFrame)