update basicts

zezhishao committed Aug 31, 2024
1 parent ea194e9 commit fd15e52
Showing 15 changed files with 255 additions and 151 deletions.
65 changes: 47 additions & 18 deletions basicts/data/base_dataset.py
@@ -7,18 +7,21 @@
@dataclass
class BaseDataset(Dataset):
"""
An abstract base class for custom time series datasets.
An abstract base class for creating datasets for time series forecasting in PyTorch.
This class provides a structured template for defining custom datasets by specifying methods
to load data and descriptions, and to access individual samples. It is designed to be subclassed
with specific implementations for different types of time series data.
Attributes:
dataset_name (str): The name of the dataset.
train_val_test_ratio (List[float]): Ratios for splitting the dataset into train, validation, and test sets.
mode (str): The mode of the dataset. Must be one of ["train", "valid", "test"].
input_len (int): The length of the input sequence (history).
output_len (int): The length of the output sequence (future).
overlap (bool): Whether to allow overlapping between the training, validation, and test sets.
It is usually set to True, but can be set to False in certain special cases, such as
when the training, validation, and test sets are sampled from different periods of time or
are from different source domains.
dataset_name (str): The name of the dataset, used to uniquely identify it.
train_val_test_ratio (List[float]): Ratios for splitting the dataset into training, validation,
and testing sets, respectively. The values in the list should sum to 1.0.
mode (str): Operational mode of the dataset. Valid values are "train", "valid", or "test".
input_len (int): The length of the input sequence, i.e., the number of historical data points used.
output_len (int): The length of the output sequence, i.e., the number of future data points predicted.
overlap (bool): Flag to indicate whether the splits between training, validation, and testing can overlap.
Defaults to True but can be set to False to enforce non-overlapping data in different sets.
"""

dataset_name: str
@@ -30,42 +33,68 @@ class BaseDataset(Dataset):

def _load_description(self) -> dict:
"""
Load the dataset's description from a file.
Abstract method to load a dataset's description from a file or source.
This method should be implemented by subclasses to load and return the dataset's metadata,
such as its shape, range, or other relevant properties, typically from a JSON or similar file.
Returns:
dict: The dataset description.
dict: A dictionary containing the dataset's metadata.
Raises:
NotImplementedError: If the method has not been implemented by a subclass.
"""

raise NotImplementedError("Subclasses must implement this method.")

def _load_data(self) -> np.ndarray:
"""
Load the dataset and split it according to the mode.
Abstract method to load the dataset and organize it based on the specified mode.
This method should be implemented by subclasses to load actual time series data into an array,
handling any necessary preprocessing and partitioning according to the specified `mode`.
Returns:
np.ndarray: The loaded data.
np.ndarray: The loaded and appropriately split dataset array.
Raises:
NotImplementedError: If the method has not been implemented by a subclass.
"""

raise NotImplementedError("Subclasses must implement this method.")

def __len__(self) -> int:
"""
Get the length of the dataset.
Abstract method to get the total number of samples available in the dataset.
This method should be implemented by subclasses to calculate and return the total number of valid
samples available for training, validation, or testing based on the configuration and dataset size.
Returns:
int: The total number of samples.
Raises:
NotImplementedError: If the method has not been implemented by a subclass.
"""

raise NotImplementedError("Subclasses must implement this method.")

def __getitem__(self, idx: int) -> dict:
"""
Retrieve a single data sample.
Abstract method to retrieve a single sample from the dataset.
This method should be implemented by subclasses to access and return a specific sample from the dataset,
given an index. It should handle the slicing of input and output sequences according to the defined
`input_len` and `output_len`.
Args:
idx (int): Index of the data sample to retrieve.
idx (int): The index of the sample to retrieve.
Returns:
dict: A dictionary containing the input and output data.
dict: A dictionary containing the input sequence ('inputs') and output sequence ('target').
Raises:
NotImplementedError: If the method has not been implemented by a subclass.
"""

raise NotImplementedError("Subclasses must implement this method.")
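
For orientation, here is a rough sketch of what a concrete subclass could look like. It is not part of this commit — the `.npy` file layout and the class name are assumptions for illustration, and the real implementation shipped with this change is the `TimeSeriesForecastingDataset` in the next file:

import json
import numpy as np
from basicts.data.base_dataset import BaseDataset

class NpyDataset(BaseDataset):
    """Hypothetical subclass backed by a plain .npy array of shape (T, N, C)."""

    def __init__(self, dataset_name, train_val_test_ratio, mode, input_len, output_len, overlap=True):
        super().__init__(dataset_name, train_val_test_ratio, mode, input_len, output_len, overlap)
        self.description = self._load_description()
        self.data = self._load_data()

    def _load_description(self) -> dict:
        # assumed metadata location; adjust to your own layout
        with open(f'datasets/{self.dataset_name}/desc.json', 'r') as f:
            return json.load(f)

    def _load_data(self) -> np.ndarray:
        data = np.load(f'datasets/{self.dataset_name}/data.npy')  # assumed file
        train_len = int(len(data) * self.train_val_test_ratio[0])
        valid_len = int(len(data) * self.train_val_test_ratio[1])
        if self.mode == 'train':
            return data[:train_len]
        if self.mode == 'valid':
            return data[train_len:train_len + valid_len]
        return data[train_len + valid_len:]  # 'test'

    def __len__(self) -> int:
        # one sample per complete (input, output) window
        return len(self.data) - self.input_len - self.output_len + 1

    def __getitem__(self, idx: int) -> dict:
        history = self.data[idx:idx + self.input_len]
        future = self.data[idx + self.input_len:idx + self.input_len + self.output_len]
        return {'inputs': history, 'target': future}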
66 changes: 33 additions & 33 deletions basicts/data/simple_tsf_dataset.py
@@ -8,52 +8,56 @@

class TimeSeriesForecastingDataset(BaseDataset):
"""
A PyTorch Dataset for time series forecasting.
This dataset handles splitting the data into train, validation, and test sets
based on the provided mode. It provides sequences of historical data (input)
and corresponding future data (target) for training.
A dataset class for time series forecasting problems, handling the loading, parsing, and partitioning
of time series data into training, validation, and testing sets based on provided ratios.
This class supports configurations where the train/validation/test splits may or may not overlap,
accommodating scenarios where the data is drawn from one continuous period or from distinct
episodes, which affects how the series is partitioned into training, validation, and test sets.
Attributes:
data_file_path (str): Path to the file containing the time series data.
description_file_path (str): Path to the JSON file containing the description of the dataset.
data (np.ndarray): The loaded time series data array, split according to the specified mode.
description (dict): Metadata about the dataset, such as shape and other properties.
"""

def __init__(self, dataset_name: str, train_val_test_ratio: List[float], mode: str, input_len: int, output_len: int, overlap: bool = True) -> None:
"""
Initialize the dataset by loading the data and description.
Initializes the TimeSeriesForecastingDataset by setting up paths, loading data, and
preparing it according to the specified configurations.
Args:
dataset_name (str): The name of the dataset.
train_val_test_ratio (List[float]): Ratios for splitting the dataset into train, validation, and test sets.
mode (str): The mode of the dataset. Must be one of ['train', 'valid', 'test'].
input_len (int): The length of the input sequence (history).
output_len (int): The length of the output sequence (future).
overlap (bool): Whether to allow overlapping between the training, validation, and test sets.
It is usually set to True, but can be set to False in certain special cases, such as
when the training, validation, and test sets are sampled from different periods of time or
are from different source domains.
Each value should be a float between 0 and 1, and the values should sum to 1.0.
mode (str): The operation mode of the dataset. Valid values are 'train', 'valid', or 'test'.
input_len (int): The length of the input sequence (number of historical points).
output_len (int): The length of the output sequence (number of future points to predict).
overlap (bool): Flag to determine if training/validation/test splits should overlap.
Defaults to True. Set to False for strictly non-overlapping periods.
Raises:
AssertionError: If `mode` is not one of ['train', 'valid', 'test'].
"""

assert mode in ['train', 'valid', 'test'], f"Invalid mode: {mode}. Must be one of ['train', 'valid', 'test']."
super().__init__(dataset_name, train_val_test_ratio, mode, input_len, output_len, overlap)

self.data_file_path = f'datasets/{dataset_name}/data.dat'
self.description_file_path = f'datasets/{dataset_name}/desc.json'

# Load description and data
self.description = self._load_description()
self.data = self._load_data()

def _load_description(self) -> dict:
"""
Load the dataset description from a JSON file.
Loads the description of the dataset from a JSON file.
Returns:
dict: The dataset description loaded from the JSON file.
dict: A dictionary containing metadata about the dataset, such as its shape and other properties.
Raises:
FileNotFoundError: If the description file is not found.
json.JSONDecodeError: If the JSON file cannot be decoded.
json.JSONDecodeError: If there is an error decoding the JSON data.
"""

try:
Expand All @@ -66,29 +70,27 @@ def _load_description(self) -> dict:

def _load_data(self) -> np.ndarray:
"""
Load the data file and split it based on the mode.
Loads the time series data from a file and splits it according to the selected mode.
Returns:
np.ndarray: The data for the selected mode (train, validation, or test).
np.ndarray: The data array for the specified mode (train, validation, or test).
Raises:
ValueError: If the data file cannot be loaded or the shape is incorrect.
ValueError: If there is an issue with loading the data file or if the data shape is not as expected.
"""

try:
# Load data using memory-mapped file for efficiency
data = np.memmap(self.data_file_path, dtype='float32', mode='r', shape=tuple(self.description['shape']))
except (FileNotFoundError, ValueError) as e:
raise ValueError(f'Error loading data file: {self.data_file_path}') from e

# Split data based on train/val/test ratios
total_len = len(data)
train_len = int(total_len * self.train_val_test_ratio[0])
valid_len = int(total_len * self.train_val_test_ratio[1])

if self.mode == 'train':
offset = self.output_len if self.overlap else 0
return data[:train_len + offset].copy() # Consider overlapping
return data[:train_len + offset].copy()
elif self.mode == 'valid':
offset_left = self.input_len - 1 if self.overlap else 0
offset_right = self.output_len if self.overlap else 0
Expand All @@ -99,26 +101,24 @@ def _load_data(self) -> np.ndarray:

def __getitem__(self, index: int) -> dict:
"""
Get a sample from the dataset.
Retrieves a sample from the dataset at the specified index, considering both the input and output lengths.
Args:
index (int): The index of the sample.
index (int): The index of the desired sample in the dataset.
Returns:
dict: A dictionary containing 'inputs' (history_data) and 'target' (future_data),
where the shape of each is (L, N, C).
dict: A dictionary containing 'inputs' and 'target', where both are slices of the dataset corresponding to
the historical input data and future prediction data, respectively.
"""

history_data = self.data[index:index + self.input_len]
future_data = self.data[index + self.input_len:index + self.input_len + self.output_len]
return {'inputs': history_data, 'target': future_data}

def __len__(self) -> int:
"""
Get the number of samples in the dataset.
Calculates the total number of samples available in the dataset, adjusted for the lengths of input and output sequences.
Returns:
int: The number of samples available based on the input and output lengths.
int: The number of valid samples that can be drawn from the dataset, given the configured input and output lengths.
"""

return len(self.data) - self.input_len - self.output_len + 1
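
As a usage sketch (not part of the diff): the constructor arguments map one-to-one onto the attributes above, and samples collate naturally with a standard DataLoader. The dataset name, ratios, and sequence lengths below are illustrative assumptions; any dataset prepared under datasets/<name>/ with data.dat and desc.json should behave the same.

from torch.utils.data import DataLoader
from basicts.data.simple_tsf_dataset import TimeSeriesForecastingDataset

# 'ETTh1' is an illustrative name, not something this commit prescribes
train_set = TimeSeriesForecastingDataset(
    dataset_name='ETTh1',
    train_val_test_ratio=[0.6, 0.2, 0.2],
    mode='train',
    input_len=336,
    output_len=336,
)

sample = train_set[0]
print(sample['inputs'].shape, sample['target'].shape)  # (L, N, C) each

# number of sliding windows: len(data) - input_len - output_len + 1
loader = DataLoader(train_set, batch_size=32, shuffle=True)
batch = next(iter(loader))  # dict of tensors shaped (B, L, N, C)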
2 changes: 1 addition & 1 deletion basicts/launcher.py
@@ -39,7 +39,7 @@ def evaluation_func(cfg: Dict,
runner.init_logger(logger_name='easytorch-evaluation', log_file_name='evaluation_log')

try:
# Set batch size if provided
# set batch size if provided
if batch_size is not None:
cfg.TEST.DATA.BATCH_SIZE = batch_size
else:
28 changes: 18 additions & 10 deletions basicts/metrics/mae.py
@@ -3,28 +3,36 @@

def masked_mae(prediction: torch.Tensor, target: torch.Tensor, null_val: float = np.nan) -> torch.Tensor:
"""
Calculate the Masked Mean Absolute Error (MAE).
Calculate the Masked Mean Absolute Error (MAE) between the predicted and target values,
while ignoring the entries in the target tensor that match the specified null value.
This function is particularly useful for scenarios where the dataset contains missing or irrelevant
values (denoted by `null_val`) that should not contribute to the loss calculation. It effectively
masks these values to ensure they do not skew the error metrics.
Args:
prediction (torch.Tensor): The predicted values.
target (torch.Tensor): The ground truth values (labels).
null_val (float, optional): The value considered as null. Defaults to np.nan.
prediction (torch.Tensor): The predicted values as a tensor.
target (torch.Tensor): The ground truth values as a tensor with the same shape as `prediction`.
null_val (float, optional): The value considered as null or missing in the `target` tensor.
Default is `np.nan`. The function will mask all `NaN` values in the target.
Returns:
torch.Tensor: The masked mean absolute error.
torch.Tensor: A scalar tensor representing the masked mean absolute error.
"""

if np.isnan(null_val):
mask = ~torch.isnan(target)
else:
eps = 5e-5
mask = ~torch.isclose(target, torch.tensor(null_val).to(target.device), atol=eps)
mask = ~torch.isclose(target, torch.tensor(null_val).expand_as(target).to(target.device), atol=eps, rtol=0.0)

mask = mask.float()
mask /= torch.mean(mask)
mask = torch.nan_to_num(mask)
mask /= torch.mean(mask) # Normalize mask to avoid bias in the loss due to the number of valid entries
mask = torch.nan_to_num(mask) # Replace any NaNs in the mask with zero

loss = torch.abs(prediction - target)
loss *= mask
loss = torch.nan_to_num(loss)
loss = loss * mask # Apply the mask to the loss
loss = torch.nan_to_num(loss) # Replace any NaNs in the loss with zero

return torch.mean(loss)
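
A quick, self-contained check of the masking behavior described above (not from the commit; the import path simply mirrors the basicts/metrics/mae.py file location): the NaN entry is excluded, and mask normalization makes the result the mean over the three valid positions.

import torch
from basicts.metrics.mae import masked_mae  # path mirrors basicts/metrics/mae.py

target = torch.tensor([1.0, 2.0, float('nan'), 4.0])
prediction = torch.tensor([1.5, 2.0, 0.0, 3.0])

# The NaN target is masked out; the surviving absolute errors are
# 0.5, 0.0, and 1.0, so the result is their mean over the 3 valid
# entries: (0.5 + 0.0 + 1.0) / 3 = 0.5.
print(masked_mae(prediction, target))  # tensor(0.5000)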
52 changes: 32 additions & 20 deletions basicts/metrics/mape.py
@@ -1,35 +1,48 @@
import torch
import numpy as np


def masked_mape(prediction: torch.Tensor, target: torch.Tensor, null_val: float = 0.0) -> torch.Tensor:
def masked_mape(prediction: torch.Tensor, target: torch.Tensor, null_val: float = np.nan) -> torch.Tensor:
"""
Calculate the Masked Mean Absolute Percentage Error (MAPE).
Calculate the Masked Mean Absolute Percentage Error (MAPE) between predicted and target values,
ignoring entries that are either zero or match the specified null value in the target tensor.
Note:
The null_val is set to 0.0 or np.nan for MAPE by default, and should not be changed.
This function is particularly useful for time series or regression tasks where the target values may
contain zeros or missing values, which could otherwise distort the error calculation. The function
applies a mask to ensure these entries do not affect the resulting MAPE.
Args:
prediction (torch.Tensor): The predicted values.
target (torch.Tensor): The ground truth values (labels).
null_val (float, optional): The value considered as null. Defaults to 0.0.
prediction (torch.Tensor): The predicted values as a tensor.
target (torch.Tensor): The ground truth values as a tensor with the same shape as `prediction`.
null_val (float, optional): The value considered as null or missing in the `target` tensor.
Defaults to `np.nan`. The function will mask all `NaN` values in the target.
Returns:
torch.Tensor: The masked mean absolute percentage error.
Raises:
AssertionError: If null_val is not 0.0 or np.nan.
torch.Tensor: A scalar tensor representing the masked mean absolute percentage error.
Details:
- The function creates two masks:
1. `zero_mask`: This mask excludes entries in the `target` tensor that are close to zero,
since division by zero or near-zero values would result in extremely large or undefined errors.
2. `null_mask`: This mask excludes entries in the `target` tensor that match the specified `null_val`.
If `null_val` is `np.nan`, the mask will exclude `NaN` values using `torch.isnan`.
- The final mask is the intersection of `zero_mask` and `null_mask`, ensuring that only valid, non-zero,
and non-null values contribute to the MAPE calculation.
"""
assert null_val == 0.0 or np.isnan(null_val), (
"In MAPE, null_val must be 0.0 or np.nan by default. "
"This parameter is kept for consistency, but it cannot be changed."
)

# Create mask for non-NaN and non-null values
nan_mask = ~torch.isnan(target)
# mask to exclude zero values in the target
zero_mask = ~torch.isclose(target, torch.tensor(0.0).to(target.device), atol=5e-5)

mask = (nan_mask & zero_mask).float()
# mask to exclude null values in the target
if np.isnan(null_val):
null_mask = ~torch.isnan(target)
else:
eps = 5e-5
null_mask = ~torch.isclose(target, torch.tensor(null_val).to(target.device), atol=eps)

# combine zero and null masks
mask = (zero_mask & null_mask).float()

mask /= torch.mean(mask)
mask = torch.nan_to_num(mask)

@@ -38,4 +51,3 @@ def masked_mape(prediction: torch.Tensor, target: torch.Tensor, null_val: float
loss = torch.nan_to_num(loss)

return torch.mean(loss)
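
And a matching sketch for the MAPE variant (again not from the commit, and assuming the part of the function hidden by the hunk computes the usual |prediction - target| / target ratio): zero and NaN targets are both excluded before averaging.

import torch
from basicts.metrics.mape import masked_mape  # path mirrors basicts/metrics/mape.py

target = torch.tensor([2.0, 0.0, 4.0, float('nan')])
prediction = torch.tensor([1.0, 5.0, 5.0, 1.0])

# The zero and NaN targets are masked; the surviving relative errors are
# |1 - 2| / 2 = 0.5 and |5 - 4| / 4 = 0.25, so the result is their mean.
print(masked_mape(prediction, target))  # tensor(0.3750)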
