-
Notifications
You must be signed in to change notification settings - Fork 134
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d5b30a2
commit 69afc6a
Showing
22 changed files
with
691 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .autox import AutoX | ||
from .autots import AutoTS | ||
from .autoxserver import AutoXServer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
from autox.autox_ts.feature_engineer import fe_rolling_stat | ||
from autox.autox_ts.feature_engineer import fe_lag | ||
from autox.autox_ts.feature_engineer import fe_diff | ||
from autox.autox_ts.feature_engineer import fe_time | ||
from autox.autox_ts.feature_engineer import fe_time_add | ||
from autox.autox_ts.feature_selection import feature_filter | ||
from autox.autox_ts.models import ts_lgb_model | ||
from autox.autox_ts.util import feature_combination | ||
from autox.autox_ts.util import construct_data | ||
# from autox.autox_ts.baseline import prophet_predict | ||
from autox.autox_competition.util import log | ||
from autox.autox_competition.process_data import clip_label | ||
|
||
class AutoTS():
    """End-to-end time-series forecasting pipeline.

    The pipeline builds rolling/lag/diff/calendar features from ``df``,
    reshapes them into train/test sets with ``construct_data``, selects
    features with ``feature_filter`` and fits ``ts_lgb_model``.

    Args:
        df: Input DataFrame holding the id, time, target and feature columns.
        id_col: Column used as the group key when building per-series
            features.
        time_col: Name of the timestamp column.
        target_col: Name of the column to forecast.
        time_varying_cols: Columns for which rolling, lag and diff features
            are generated.
        time_interval_num: Forwarded to ``construct_data`` / ``ts_lgb_model``
            (presumably the size of one time step -- confirm against those
            helpers).
        time_interval_unit: Forwarded alongside ``time_interval_num``
            (presumably the unit of one time step -- confirm).
        forecast_period: Forwarded to ``construct_data`` / ``ts_lgb_model``
            (presumably the number of steps to forecast -- confirm).
        mode: ``'auto'`` or ``'prophet'``. Only ``'auto'`` is currently
            runnable; the prophet baseline is disabled.
        metric: Metric name forwarded to ``ts_lgb_model``.
    """

    def __init__(self,
                 df,
                 id_col,
                 time_col,
                 target_col,
                 time_varying_cols,
                 time_interval_num,
                 time_interval_unit,
                 forecast_period,
                 mode='auto',
                 metric='rmse'):

        # Kept as an assert so callers that expect AssertionError on an
        # unknown mode keep working.
        assert (mode in ['auto', 'prophet'])

        self.df = df
        self.id_col = id_col
        self.time_col = time_col
        self.target_col = target_col
        self.time_varying_cols = time_varying_cols
        self.time_interval_num = time_interval_num
        self.time_interval_unit = time_interval_unit
        self.forecast_period = forecast_period
        self.mode = mode
        self.metric = metric

    def get_result_(self):
        """Run the pipeline for the configured mode and return the forecast.

        Returns:
            The prediction frame produced by ``kdata_lgb`` with the target
            column passed through ``clip_label`` using the minimum and
            maximum target values observed in ``self.df``.

        Raises:
            NotImplementedError: If ``mode`` is not ``'auto'``. The prophet
                baseline is disabled, so previously this path failed with an
                ``UnboundLocalError`` on ``sub``.
        """
        if self.mode == 'auto':
            sub = self.kdata_lgb()
        else:
            # The prophet baseline (autox.autox_ts.baseline.prophet_predict)
            # is currently commented out, so there is nothing to dispatch to.
            raise NotImplementedError(
                f"mode '{self.mode}' is not available: the prophet baseline "
                "is currently disabled, use mode='auto'")

        log('[+] post process')
        sub[self.target_col] = clip_label(sub[self.target_col],
                                          self.df[self.target_col].min(),
                                          self.df[self.target_col].max())

        return sub

    def kdata_lgb(self):
        """Build features, construct train/test data and fit the model.

        Stores the feature importances returned by ``ts_lgb_model`` on
        ``self.feature_importances``.

        Returns:
            The prediction frame returned by ``ts_lgb_model``.
        """
        log('[+] feature engineer')

        # Rolling-window features.
        df_rolling_stat = fe_rolling_stat(self.df,
                                          id_col=self.id_col,
                                          time_col=self.time_col,
                                          time_varying_cols=self.time_varying_cols,
                                          window_size=[4, 16, 64, 256])
        # Lag features.
        df_lag = fe_lag(self.df,
                        id_col=self.id_col,
                        time_col=self.time_col,
                        time_varying_cols=self.time_varying_cols,
                        lag=[1, 2, 3])

        # Diff features.
        df_diff = fe_diff(self.df,
                          id_col=self.id_col,
                          time_col=self.time_col,
                          time_varying_cols=self.time_varying_cols,
                          lag=[1, 2, 3])

        # Calendar features derived from the timestamp column.
        df_time = fe_time(self.df, time_col=self.time_col)

        # Combine the raw frame with every feature block.
        df_all = feature_combination([self.df, df_rolling_stat, df_lag, df_diff, df_time])

        # Construct the train/test data.
        new_target_col = 'y'
        add_time_col = 't2'
        train, test = construct_data(df_all,
                                     id_col=self.id_col,
                                     time_col=self.time_col,
                                     target_col=self.target_col,
                                     time_interval_num=self.time_interval_num,
                                     time_interval_unit=self.time_interval_unit,
                                     forecast_period=self.forecast_period,
                                     new_target_col=new_target_col,
                                     add_time_col=add_time_col)

        # Supplementary calendar features on the added time column
        # (fe_time_add mutates the frames in place).
        fe_time_add(train, add_time_col)
        fe_time_add(test, add_time_col)

        # Feature selection.
        used_features = feature_filter(train, test, self.time_col, target_col=new_target_col)
        category_cols = [self.id_col, 'k_step']

        log('[+] train model')
        # Model training and prediction.
        sub, feature_importances = \
            ts_lgb_model(train, test,
                         id_col=self.id_col,
                         time_col=add_time_col,
                         target_col=new_target_col,
                         used_features=used_features,
                         category_cols=category_cols,
                         time_interval_num=self.time_interval_num,
                         time_interval_unit=self.time_interval_unit,
                         forecast_period=self.forecast_period,
                         label_log=False,
                         metric=self.metric)
        self.feature_importances = feature_importances

        return sub
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .facebook_prophet import prophet_predict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# import pandas as pd | ||
# from fbprophet import Prophet | ||
# from tqdm import tqdm | ||
# | ||
# def prophet_predict(df, id_col, time_col, target_col, time_interval_num, time_interval_unit, forecast_period): | ||
# result = pd.DataFrame() | ||
# for cur_TurbID in tqdm(df[id_col].unique(), total=df[id_col].nunique()): | ||
# Prophet_df = df.loc[df[id_col] == cur_TurbID, [time_col, target_col]] | ||
# Prophet_df.columns = ['ds', 'y'] | ||
# Prophet_df.index = range(len(Prophet_df)) | ||
# Prophet_df['ds'] = pd.to_datetime(Prophet_df['ds']) | ||
# | ||
# m = Prophet() | ||
# m.fit(Prophet_df) | ||
# | ||
# freq = str(time_interval_num) + time_interval_unit | ||
# future = m.make_future_dataframe(periods=forecast_period, freq=freq) | ||
# forecast = m.predict(future) | ||
# | ||
# cur_forecast = forecast.loc[forecast['ds'] > df[time_col].max(), ['ds', 'yhat']] | ||
# cur_forecast[id_col] = cur_TurbID | ||
# result = result.append(cur_forecast) | ||
# result.index = range(len(result)) | ||
# | ||
# result.rename({'ds': time_col, 'yhat': target_col}, axis=1, inplace=True) | ||
# return result[[id_col, time_col, target_col]] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .fe_diff import fe_diff | ||
from .fe_lag import fe_lag | ||
from .fe_rolling_stat import fe_rolling_stat | ||
from .fe_time import fe_time | ||
from .fe_time_add import fe_time_add |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from autox.autox_competition.util import log | ||
|
||
def fe_diff(df, id_col, time_col, time_varying_cols, lag):
    """Build per-id difference features for each column and lag.

    For every column in ``time_varying_cols`` and every step in ``lag`` the
    feature ``{id_col}__{column}__diff__{step}`` is the current value minus
    the value ``step`` rows earlier within the same ``id_col`` group, with
    rows ordered by ``time_col``.

    Args:
        df: Input DataFrame; it is not modified (work happens on the sorted
            copy).
        id_col: Group key column.
        time_col: Column used to order the rows.
        time_varying_cols: Columns to difference.
        lag: Iterable of shift distances.

    Returns:
        DataFrame containing only the new feature columns, in the row order
        of ``df`` (the result of a left merge on ``id_col``/``time_col``).
    """
    log('[+] fe_diff')
    join_keys = [id_col, time_col]
    skeleton = df[join_keys]
    ordered = df.sort_values(by=time_col)

    new_cols = []
    for col in time_varying_cols:
        for step in lag:
            feat = f'{id_col}__{col}__diff__{step}'
            new_cols.append(feat)
            previous = ordered.groupby(id_col)[col].shift(step)
            ordered[feat] = ordered[col] - previous

    # Merge back onto the original key order, then keep only the features.
    merged = skeleton.merge(ordered[join_keys + new_cols], on=join_keys, how='left')
    return merged[new_cols]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from autox.autox_competition.util import log | ||
|
||
def fe_lag(df, id_col, time_col, time_varying_cols, lag):
    """Build per-id lag features for each column and lag.

    For every column in ``time_varying_cols`` and every step in ``lag`` the
    feature ``{id_col}__{column}__lag__{step}`` is the value ``step`` rows
    earlier within the same ``id_col`` group, with rows ordered by
    ``time_col``.

    Args:
        df: Input DataFrame; it is not modified (work happens on the sorted
            copy).
        id_col: Group key column.
        time_col: Column used to order the rows.
        time_varying_cols: Columns to shift.
        lag: Iterable of shift distances.

    Returns:
        DataFrame containing only the new feature columns, in the row order
        of ``df`` (the result of a left merge on ``id_col``/``time_col``).
    """
    log('[+] fe_lag')
    join_keys = [id_col, time_col]
    skeleton = df[join_keys]
    ordered = df.sort_values(by=time_col)

    new_cols = []
    for col in time_varying_cols:
        for step in lag:
            feat = f'{id_col}__{col}__lag__{step}'
            new_cols.append(feat)
            ordered[feat] = ordered.groupby(id_col)[col].shift(step)

    # Merge back onto the original key order, then keep only the features.
    merged = skeleton.merge(ordered[join_keys + new_cols], on=join_keys, how='left')
    return merged[new_cols]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from tqdm import tqdm | ||
import warnings | ||
warnings.filterwarnings('ignore') | ||
from autox.autox_competition.util import log | ||
|
||
def fe_rolling_stat(df, id_col, time_col, time_varying_cols, window_size):
    """Build per-id rolling-window statistics.

    For every window in ``window_size``, every column in
    ``time_varying_cols`` and each of mean/std/median/max/min/kurt/skew, the
    feature ``{id_col}__{column}__{window}__{stat}`` is that rolling
    statistic computed within each ``id_col`` group, with rows ordered by
    ``time_col``.

    Args:
        df: Input DataFrame; it is not modified (work happens on the sorted
            copy).
        id_col: Group key column.
        time_col: Column used to order the rows.
        time_varying_cols: Columns to aggregate.
        window_size: Iterable of rolling window lengths.

    Returns:
        DataFrame containing only the new feature columns, in the row order
        of ``df`` (the result of a left merge on ``id_col``/``time_col``).
    """
    log('[+] fe_rolling_stat')
    join_keys = [id_col, time_col]
    skeleton = df[join_keys]
    ordered = df.sort_values(by=time_col)

    # Each name is both the feature suffix and the Rolling method to call.
    stat_names = ('mean', 'std', 'median', 'max', 'min', 'kurt', 'skew')

    new_cols = []
    for width in tqdm(window_size):
        for col in time_varying_cols:
            for stat in stat_names:
                feat = f'{id_col}__{col}__{width}__{stat}'
                new_cols.append(feat)
                # The lambda is consumed immediately by transform, so it
                # sees the current width/stat of this iteration.
                ordered[feat] = ordered.groupby(id_col)[col].transform(
                    lambda s: getattr(s.rolling(window=width), stat)())

    merged = skeleton.merge(ordered[join_keys + new_cols], on=join_keys, how='left')
    return merged[new_cols]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import pandas as pd | ||
from autox.autox_competition.util import log | ||
|
||
def fe_time(df, time_col):
    """Derive calendar features from a timestamp column.

    Note:
        ``df[time_col]`` is converted to datetime **in place**; this side
        effect is preserved from the original implementation because later
        pipeline steps may rely on it (confirm before removing).

    Args:
        df: Input DataFrame containing ``time_col``.
        time_col: Name of the timestamp column; also used as the prefix of
            every generated feature (``{time_col}_year`` and so on).

    Returns:
        A new DataFrame, indexed like ``df``, with the year, month, day,
        hour, weekofyear, dayofweek, is_wknd, quarter, is_month_start and
        is_month_end features.
    """
    log('[+] fe_time')
    result = pd.DataFrame()
    prefix = time_col + "_"

    df[time_col] = pd.to_datetime(df[time_col])
    stamps = df[time_col].dt

    # Series.dt.weekofyear was deprecated in pandas 1.1 and removed in 2.0;
    # isocalendar().week is the replacement. It returns a nullable UInt32
    # column, so cast it to a plain numpy dtype (float when missing values
    # are present) to keep the feature a regular int/float column.
    if hasattr(stamps, 'isocalendar'):
        week = stamps.isocalendar().week
        week = week.astype('float64' if week.isna().any() else 'int64')
    else:
        week = stamps.weekofyear

    result[prefix + 'year'] = stamps.year
    result[prefix + 'month'] = stamps.month
    result[prefix + 'day'] = stamps.day
    result[prefix + 'hour'] = stamps.hour
    result[prefix + 'weekofyear'] = week
    result[prefix + 'dayofweek'] = stamps.dayofweek
    # dayofweek is 0-6, so integer division by 5 yields 1 for values 5 and 6.
    result[prefix + 'is_wknd'] = stamps.dayofweek // 5
    result[prefix + 'quarter'] = stamps.quarter
    result[prefix + 'is_month_start'] = stamps.is_month_start.astype(int)
    result[prefix + 'is_month_end'] = stamps.is_month_end.astype(int)

    return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import pandas as pd | ||
from autox.autox_competition.util import log | ||
|
||
def fe_time_add(df, time_col):
    """Add calendar features for ``time_col`` to ``df`` in place.

    ``df[time_col]`` is converted to datetime and the columns
    ``{time_col}_year``, ``_month``, ``_day``, ``_hour``, ``_weekofyear``,
    ``_dayofweek``, ``_is_wknd``, ``_quarter``, ``_is_month_start`` and
    ``_is_month_end`` are written onto ``df``.

    Args:
        df: DataFrame to extend; it is modified in place.
        time_col: Name of the timestamp column; also used as the feature
            prefix.

    Returns:
        None.
    """
    log('[+] fe_time_add')
    prefix = time_col + "_"

    df[time_col] = pd.to_datetime(df[time_col])
    stamps = df[time_col].dt

    # Series.dt.weekofyear was deprecated in pandas 1.1 and removed in 2.0;
    # isocalendar().week is the replacement. It returns a nullable UInt32
    # column, so cast it to a plain numpy dtype (float when missing values
    # are present) to keep the feature a regular int/float column.
    if hasattr(stamps, 'isocalendar'):
        week = stamps.isocalendar().week
        week = week.astype('float64' if week.isna().any() else 'int64')
    else:
        week = stamps.weekofyear

    df[prefix + 'year'] = stamps.year
    df[prefix + 'month'] = stamps.month
    df[prefix + 'day'] = stamps.day
    df[prefix + 'hour'] = stamps.hour
    df[prefix + 'weekofyear'] = week
    df[prefix + 'dayofweek'] = stamps.dayofweek
    # dayofweek is 0-6, so integer division by 5 yields 1 for values 5 and 6.
    df[prefix + 'is_wknd'] = stamps.dayofweek // 5
    df[prefix + 'quarter'] = stamps.quarter
    df[prefix + 'is_month_start'] = stamps.is_month_start.astype(int)
    df[prefix + 'is_month_end'] = stamps.is_month_end.astype(int)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .feature_filter import feature_filter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from tqdm import tqdm | ||
from autox.autox_competition.util import log | ||
|
||
def feature_filter(train, test, time_col, target_col):
    """Select the columns of ``test`` that are usable as model features.

    A column is rejected when any of the following holds:

    * its dtype name in ``test`` does not start with ``int`` or ``float``;
    * it has at most one distinct value in ``train``;
    * its value ranges in ``train`` and ``test`` do not overlap (train
      minimum above test maximum, or train maximum below test minimum).

    ``time_col`` and ``target_col`` are always rejected.

    Args:
        train: Training DataFrame.
        test: Test DataFrame; its columns define the candidates.
        time_col: Timestamp column name, never used as a feature.
        target_col: Target column name, never used as a feature.

    Returns:
        List of accepted column names, in the order of ``test.columns``.
    """
    log('[+] feature_filter')
    rejected = []

    for name in tqdm(test.columns):
        dtype_name = str(test[name].dtype)

        # Not an int*/float* column.
        if not dtype_name.startswith(('int', 'float')):
            rejected.append(name)
            continue

        # Constant (or entirely missing) in train.
        if train[name].nunique() <= 1:
            rejected.append(name)
            continue

        # Train and test value ranges are disjoint.
        if train[name].min() > test[name].max() or train[name].max() < test[name].min():
            rejected.append(name)

    rejected = list(set(rejected + [time_col, target_col]))
    print(f'not_used: {rejected}')

    return [name for name in test.columns if name not in rejected]
Oops, something went wrong.