Skip to content

Commit

Permalink
Revert "fix: Moved device parameter from DRLAgent initialization to g…
Browse files Browse the repository at this point in the history
…et_model…" (#1280)

This reverts commit e11f729.
  • Loading branch information
zhumingpassional authored Oct 7, 2024
1 parent e11f729 commit 690fc61
Show file tree
Hide file tree
Showing 6 changed files with 496 additions and 2,127 deletions.
199 changes: 199 additions & 0 deletions examples/FinRL_PaperTrading_Demo_refactored.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
# Disclaimer: Nothing herein is financial advice, and NOT a recommendation to trade real money. Many platforms exist for simulated trading (paper trading) which can be used for building and developing the methods discussed. Please use common sense and always first consult a professional before trading or investing.
# install finrl library
# %pip install --upgrade git+https://github.com/AI4Finance-Foundation/FinRL.git
# Alpaca keys
from __future__ import annotations

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("data_key", help="data source api key")
parser.add_argument("data_secret", help="data source api secret")
parser.add_argument("data_url", help="data source api base url")
parser.add_argument("trading_key", help="trading api key")
parser.add_argument("trading_secret", help="trading api secret")
parser.add_argument("trading_url", help="trading api base url")
args = parser.parse_args()
DATA_API_KEY = args.data_key
DATA_API_SECRET = args.data_secret
DATA_API_BASE_URL = args.data_url
TRADING_API_KEY = args.trading_key
TRADING_API_SECRET = args.trading_secret
TRADING_API_BASE_URL = args.trading_url

print("DATA_API_KEY: ", DATA_API_KEY)
print("DATA_API_SECRET: ", DATA_API_SECRET)
print("DATA_API_BASE_URL: ", DATA_API_BASE_URL)
print("TRADING_API_KEY: ", TRADING_API_KEY)
print("TRADING_API_SECRET: ", TRADING_API_SECRET)
print("TRADING_API_BASE_URL: ", TRADING_API_BASE_URL)

from finrl.meta.env_stock_trading.env_stocktrading_np import StockTradingEnv
from finrl.meta.paper_trading.alpaca import PaperTradingAlpaca
from finrl.meta.paper_trading.common import train, test, alpaca_history, DIA_history
from finrl.config import INDICATORS

# Import Dow Jones 30 Symbols
from finrl.config_tickers import DOW_30_TICKER

ticker_list = DOW_30_TICKER
env = StockTradingEnv
# if you want to use larger datasets (change to longer period), and it raises error, please try to increase "target_step". It should be larger than the episode steps.
ERL_PARAMS = {
"learning_rate": 3e-6,
"batch_size": 2048,
"gamma": 0.985,
"seed": 312,
"net_dimension": [128, 64],
"target_step": 5000,
"eval_gap": 30,
"eval_times": 1,
}

# Set up sliding window of 6 days training and 2 days testing
import datetime
from pandas.tseries.offsets import BDay # BDay is business day, not birthday...

today = datetime.datetime.today()

TEST_END_DATE = (today - BDay(1)).to_pydatetime().date()
TEST_START_DATE = (TEST_END_DATE - BDay(1)).to_pydatetime().date()
TRAIN_END_DATE = (TEST_START_DATE - BDay(1)).to_pydatetime().date()
TRAIN_START_DATE = (TRAIN_END_DATE - BDay(5)).to_pydatetime().date()
TRAINFULL_START_DATE = TRAIN_START_DATE
TRAINFULL_END_DATE = TEST_END_DATE

TRAIN_START_DATE = str(TRAIN_START_DATE)
TRAIN_END_DATE = str(TRAIN_END_DATE)
TEST_START_DATE = str(TEST_START_DATE)
TEST_END_DATE = str(TEST_END_DATE)
TRAINFULL_START_DATE = str(TRAINFULL_START_DATE)
TRAINFULL_END_DATE = str(TRAINFULL_END_DATE)

print("TRAIN_START_DATE: ", TRAIN_START_DATE)
print("TRAIN_END_DATE: ", TRAIN_END_DATE)
print("TEST_START_DATE: ", TEST_START_DATE)
print("TEST_END_DATE: ", TEST_END_DATE)
print("TRAINFULL_START_DATE: ", TRAINFULL_START_DATE)
print("TRAINFULL_END_DATE: ", TRAINFULL_END_DATE)

train(
start_date=TRAIN_START_DATE,
end_date=TRAIN_END_DATE,
ticker_list=ticker_list,
data_source="alpaca",
time_interval="1Min",
technical_indicator_list=INDICATORS,
drl_lib="elegantrl",
env=env,
model_name="ppo",
if_vix=True,
API_KEY=DATA_API_KEY,
API_SECRET=DATA_API_SECRET,
API_BASE_URL=DATA_API_BASE_URL,
erl_params=ERL_PARAMS,
cwd="./papertrading_erl", # current_working_dir
break_step=1e5,
)

account_value_erl = test(
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
ticker_list=ticker_list,
data_source="alpaca",
time_interval="1Min",
technical_indicator_list=INDICATORS,
drl_lib="elegantrl",
env=env,
model_name="ppo",
if_vix=True,
API_KEY=DATA_API_KEY,
API_SECRET=DATA_API_SECRET,
API_BASE_URL=DATA_API_BASE_URL,
cwd="./papertrading_erl",
net_dimension=ERL_PARAMS["net_dimension"],
)

train(
start_date=TRAINFULL_START_DATE, # After tuning well, retrain on the training and testing sets
end_date=TRAINFULL_END_DATE,
ticker_list=ticker_list,
data_source="alpaca",
time_interval="1Min",
technical_indicator_list=INDICATORS,
drl_lib="elegantrl",
env=env,
model_name="ppo",
if_vix=True,
API_KEY=DATA_API_KEY,
API_SECRET=DATA_API_SECRET,
API_BASE_URL=DATA_API_BASE_URL,
erl_params=ERL_PARAMS,
cwd="./papertrading_erl_retrain",
break_step=2e5,
)

action_dim = len(DOW_30_TICKER)
state_dim = (
1 + 2 + 3 * action_dim + len(INDICATORS) * action_dim
) # Calculate the DRL state dimension manually for paper trading. amount + (turbulence, turbulence_bool) + (price, shares, cd (holding time)) * stock_dim + tech_dim

paper_trading_erl = PaperTradingAlpaca(
ticker_list=DOW_30_TICKER,
time_interval="1Min",
drl_lib="elegantrl",
agent="ppo",
cwd="./papertrading_erl_retrain",
net_dim=ERL_PARAMS["net_dimension"],
state_dim=state_dim,
action_dim=action_dim,
API_KEY=TRADING_API_KEY,
API_SECRET=TRADING_API_SECRET,
API_BASE_URL=TRADING_API_BASE_URL,
tech_indicator_list=INDICATORS,
turbulence_thresh=30,
max_stock=1e2,
)

paper_trading_erl.run()

# Check Portfolio Performance
# ## Get cumulative return
df_erl, cumu_erl = alpaca_history(
key=DATA_API_KEY,
secret=DATA_API_SECRET,
url=DATA_API_BASE_URL,
start="2022-09-01", # must be within 1 month
end="2022-09-12",
) # change the date if error occurs

df_djia, cumu_djia = DIA_history(start="2022-09-01")
returns_erl = cumu_erl - 1
returns_dia = cumu_djia - 1
returns_dia = returns_dia[: returns_erl.shape[0]]

# plot and save
import matplotlib.pyplot as plt

plt.figure(dpi=1000)
plt.grid()
plt.grid(which="minor", axis="y")
plt.title("Stock Trading (Paper trading)", fontsize=20)
plt.plot(returns_erl, label="ElegantRL Agent", color="red")
# plt.plot(returns_sb3, label = 'Stable-Baselines3 Agent', color = 'blue' )
# plt.plot(returns_rllib, label = 'RLlib Agent', color = 'green')
plt.plot(returns_dia, label="DJIA", color="grey")
plt.ylabel("Return", fontsize=16)
plt.xlabel("Year 2021", fontsize=16)
plt.xticks(size=14)
plt.yticks(size=14)
ax = plt.gca()
ax.xaxis.set_major_locator(ticker.MultipleLocator(78))
ax.xaxis.set_minor_locator(ticker.MultipleLocator(6))
ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.005))
ax.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1, decimals=2))
ax.xaxis.set_major_formatter(
ticker.FixedFormatter(["", "10-19", "", "10-20", "", "10-21", "", "10-22"])
)
plt.legend(fontsize=10.5)
plt.savefig("papertrading_stock.png")
Loading

0 comments on commit 690fc61

Please sign in to comment.