Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of __getitem__ of TimeSeriesDataSet #806

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions pytorch_forecasting/data/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,8 +1247,7 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFra
len(df_index) > 0
), "filters should not remove all entries - check encoder/decoder lengths and lags"

return df_index

return df_index.to_records(index=True)
def filter(self, filter_func: Callable, copy: bool = True) -> "TimeSeriesDataSet":
"""
Filter subsequences in dataset.
Expand Down Expand Up @@ -1292,8 +1291,8 @@ def decoded_index(self) -> pd.DataFrame:
pd.DataFrame: index that can be understood in terms of original data
"""
# get dataframe to filter
index_start = self.index["index_start"].to_numpy()
index_last = self.index["index_end"].to_numpy()
index_start = self.index["index_start"]
index_last = self.index["index_end"]
index = (
# get group ids in order of index
pd.DataFrame(self.data["groups"][index_start].numpy(), columns=self.group_ids)
Expand Down Expand Up @@ -1404,25 +1403,30 @@ def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
Returns:
Tuple[Dict[str, torch.Tensor], torch.Tensor]: x and y for model
"""
index = self.index.iloc[idx]
index = self.index[idx]
# get index data
index_start = index.index_start
index_end = index.index_end
index_sequence_length = index.sequence_length

# get index data
data_cont = self.data["reals"][index.index_start : index.index_end + 1].clone()
data_cat = self.data["categoricals"][index.index_start : index.index_end + 1].clone()
time = self.data["time"][index.index_start : index.index_end + 1].clone()
target = [d[index.index_start : index.index_end + 1].clone() for d in self.data["target"]]
groups = self.data["groups"][index.index_start].clone()
data_cont = self.data["reals"][index_start : index_end + 1].clone()
data_cat = self.data["categoricals"][index_start : index_end + 1].clone()
time = self.data["time"][index_start : index_end + 1].clone()
target = [d[index_start : index_end + 1].clone() for d in self.data["target"]]
groups = self.data["groups"][index_start].clone()
if self.data["weight"] is None:
weight = None
else:
weight = self.data["weight"][index.index_start : index.index_end + 1].clone()
weight = self.data["weight"][index_start : index_end + 1].clone()
# get target scale in the form of a list
target_scale = self.target_normalizer.get_parameters(groups, self.group_ids)
if not isinstance(self.target_normalizer, MultiNormalizer):
target_scale = [target_scale]

# fill in missing values (if not all time indices are specified)
sequence_length = len(time)
if sequence_length < index.sequence_length:
if sequence_length < index_sequence_length:
assert self.allow_missing_timesteps, "allow_missing_timesteps should be True if sequences have gaps"
repetitions = torch.cat([time[1:] - time[:-1], torch.ones(1, dtype=time.dtype)])
indices = torch.repeat_interleave(torch.arange(len(time)), repetitions)
Expand Down