Skip to content

Commit

Permalink
Update processor_alpaca.py (#1185)
Browse files Browse the repository at this point in the history
* Update processor_alpaca.py

Update processor_alpaca.py to use the new alpaca-py SDK instead of the deprecated alpaca-trade-api package.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Omers5 and pre-commit-ci[bot] authored Jan 10, 2025
1 parent 3049279 commit 1c07162
Showing 1 changed file with 94 additions and 16 deletions.
110 changes: 94 additions & 16 deletions finrl/meta/data_processors/processor_alpaca.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,40 @@

from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from datetime import timedelta as td

import alpaca_trade_api as tradeapi
import exchange_calendars as tc
import numpy as np
import pandas as pd
import pytz
from alpaca.data.historical import StockHistoricalDataClient
from alpaca.data.requests import StockBarsRequest
from alpaca.data.timeframe import TimeFrame
from stockstats import StockDataFrame as Sdf

# import alpaca_trade_api as tradeapi


class AlpacaProcessor:
def __init__(self, API_KEY=None, API_SECRET=None, API_BASE_URL=None, api=None):
if api is None:
def __init__(self, API_KEY=None, API_SECRET=None, API_BASE_URL=None, client=None):
if client is None:
try:
self.api = tradeapi.REST(API_KEY, API_SECRET, API_BASE_URL, "v2")
self.client = StockHistoricalDataClient(API_KEY, API_SECRET)
except BaseException:
raise ValueError("Wrong Account Info!")
else:
self.api = api
self.client = client

def _fetch_data_for_ticker(self, ticker, start_date, end_date, time_interval):
bars = self.api.get_bars(
ticker,
time_interval,
start=start_date.isoformat(),
end=end_date.isoformat(),
).df
bars["symbol"] = ticker
request_params = StockBarsRequest(
symbol_or_symbols=ticker,
timeframe=TimeFrame.Minute,
start=start_date,
end=end_date,
)
bars = self.client.get_stock_bars(request_params).df

return bars

def download_data(
Expand All @@ -53,7 +60,7 @@ def download_data(
NY = "America/New_York"
start_date = pd.Timestamp(start_date + " 09:30:00", tz=NY)
end_date = pd.Timestamp(end_date + " 15:59:00", tz=NY)

data_list = []
# Use ThreadPoolExecutor to fetch data for multiple tickers concurrently
with ThreadPoolExecutor(max_workers=10) as executor:
futures = [
Expand All @@ -66,7 +73,42 @@ def download_data(
)
for ticker in ticker_list
]
data_list = [future.result() for future in futures]
for future in futures:

bars = future.result()
# fix start
# Reorganize the dataframes to be in original alpaca_trade_api structure
# Rename the existing 'symbol' column if it exists
if not bars.empty:

# Now reset the index
bars.reset_index(inplace=True)

# Set 'timestamp' as the new index
if "level_1" in bars.columns:
bars.rename(columns={"level_1": "timestamp"}, inplace=True)
if "level_0" in bars.columns:
bars.rename(columns={"level_0": "symbol"}, inplace=True)

bars.set_index("timestamp", inplace=True)

# Reorder and rename columns as needed
bars = bars[
[
"close",
"high",
"low",
"trade_count",
"open",
"volume",
"vwap",
"symbol",
]
]

data_list.append(bars)
else:
print("empty")

# Combine the data
data_df = pd.concat(data_list, axis=0)
Expand Down Expand Up @@ -371,7 +413,40 @@ def fetch_latest_data(
) -> pd.DataFrame:
data_df = pd.DataFrame()
for tic in ticker_list:
barset = self.api.get_bars([tic], time_interval, limit=limit).df # [tic]
# Request the latest bars for this ticker through the alpaca-py client.
# NOTE(review): the timeframe is fixed to TimeFrame.Minute, so the
# time_interval value used by the previous get_bars() call is no longer
# applied to this request — confirm that is intended.
request_params = StockBarsRequest(
    symbol_or_symbols=[tic], timeframe=TimeFrame.Minute, limit=limit
)

barset = self.client.get_stock_bars(request_params).df
# Reorganize the dataframe into the original alpaca_trade_api structure.
# Move any pre-existing 'symbol' column aside so it cannot collide with
# the 'symbol' column that reset_index() creates from the index.
if "symbol" in barset.columns:
    barset.rename(columns={"symbol": "symbol_old"}, inplace=True)

# Turn the index levels into ordinary columns.
barset.reset_index(inplace=True)

# Unnamed index levels come back as level_0 / level_1; give them the
# names the rest of the pipeline expects.
if "level_0" in barset.columns:
    barset.rename(columns={"level_0": "symbol"}, inplace=True)
if "level_1" in barset.columns:
    barset.rename(columns={"level_1": "timestamp"}, inplace=True)

# Set 'timestamp' as the new index.
barset.set_index("timestamp", inplace=True)

# Keep only the expected columns, in the legacy order. Select from
# `barset` (this method has no `bars` variable) and take a copy so the
# 'tic' column assignment that follows writes to an independent frame.
barset = barset[
    [
        "close",
        "high",
        "low",
        "trade_count",
        "open",
        "volume",
        "vwap",
        "symbol",
    ]
].copy()

barset["tic"] = tic
barset = barset.reset_index()
data_df = pd.concat([data_df, barset])
Expand Down Expand Up @@ -451,6 +526,9 @@ def fetch_latest_data(
)
latest_price = price_array[-1]
latest_tech = tech_array[-1]
turb_df = self.api.get_bars(["VIXY"], time_interval, limit=1).df
request_params = StockBarsRequest(
symbol_or_symbols="VIXY", timeframe=TimeFrame.Minute, limit=1
)
turb_df = self.client.get_stock_bars(request_params).df
latest_turb = turb_df["close"].values
return latest_price, latest_tech, latest_turb

0 comments on commit 1c07162

Please sign in to comment.