From fd15e522abb7213475dbf1ca404225d19604fb56 Mon Sep 17 00:00:00 2001 From: Zezhi Shao <864453277@qq.com> Date: Sat, 31 Aug 2024 06:31:00 +0000 Subject: [PATCH] update basicts --- basicts/data/base_dataset.py | 65 ++++++++++++++++++-------- basicts/data/simple_tsf_dataset.py | 66 +++++++++++++-------------- basicts/launcher.py | 2 +- basicts/metrics/mae.py | 28 ++++++++---- basicts/metrics/mape.py | 52 +++++++++++++-------- basicts/metrics/mse.py | 31 ++++++++----- basicts/metrics/rmse.py | 18 +++++--- basicts/metrics/wape.py | 19 ++++---- basicts/runners/base_tsf_runner.py | 2 +- basicts/scaler/min_max_scaler.py | 56 +++++++++++++++-------- basicts/scaler/z_score_scaler.py | 62 +++++++++++++++---------- basicts/utils/adjacent_matrix_norm.py | 1 + basicts/utils/early_stopping.py | 0 basicts/utils/misc.py | 1 + basicts/utils/serialization.py | 3 ++ 15 files changed, 255 insertions(+), 151 deletions(-) delete mode 100644 basicts/utils/early_stopping.py diff --git a/basicts/data/base_dataset.py b/basicts/data/base_dataset.py index da056848..ffb27d9c 100644 --- a/basicts/data/base_dataset.py +++ b/basicts/data/base_dataset.py @@ -7,18 +7,21 @@ @dataclass class BaseDataset(Dataset): """ - An abstract base class for custom time series datasets. + An abstract base class for creating datasets for time series forecasting in PyTorch. + + This class provides a structured template for defining custom datasets by specifying methods + to load data and descriptions, and to access individual samples. It is designed to be subclassed + with specific implementations for different types of time series data. Attributes: - dataset_name (str): The name of the dataset. - train_val_test_ratio (List[float]): Ratios for splitting the dataset into train, validation, and test sets. - mode (str): The mode of the dataset. Must be one of ["train", "valid", "test"]. - input_len (int): The length of the input sequence (history). - output_len (int): The length of the output sequence (future). - overlap (bool): Whether to allow overlapping between the training, validation, and test sets. - It is usually set to True, but can be set to False in certain special cases, such as - when the training, validation, and test sets are sampled from different periods of time or - are from different source domains. + dataset_name (str): The name of the dataset which is used for identifying the dataset uniquely. + train_val_test_ratio (List[float]): Ratios for splitting the dataset into training, validation, + and testing sets respectively. Each value in the list should sum to 1.0. + mode (str): Operational mode of the dataset. Valid values are "train", "valid", or "test". + input_len (int): The length of the input sequence, i.e., the number of historical data points used. + output_len (int): The length of the output sequence, i.e., the number of future data points predicted. + overlap (bool): Flag to indicate whether the splits between training, validation, and testing can overlap. + Defaults to True but can be set to False to enforce non-overlapping data in different sets. """ dataset_name: str @@ -30,42 +33,68 @@ class BaseDataset(Dataset): def _load_description(self) -> dict: """ - Load the dataset's description from a file. + Abstract method to load a dataset's description from a file or source. + + This method should be implemented by subclasses to load and return the dataset's metadata, + such as its shape, range, or other relevant properties, typically from a JSON or similar file. Returns: - dict: The dataset description. 
+ dict: A dictionary containing the dataset's metadata. + + Raises: + NotImplementedError: If the method has not been implemented by a subclass. """ raise NotImplementedError("Subclasses must implement this method.") def _load_data(self) -> np.ndarray: """ - Load the dataset and split it according to the mode. + Abstract method to load the dataset and organize it based on the specified mode. + + This method should be implemented by subclasses to load actual time series data into an array, + handling any necessary preprocessing and partitioning according to the specified `mode`. Returns: - np.ndarray: The loaded data. + np.ndarray: The loaded and appropriately split dataset array. + + Raises: + NotImplementedError: If the method has not been implemented by a subclass. """ raise NotImplementedError("Subclasses must implement this method.") def __len__(self) -> int: """ - Get the length of the dataset. + Abstract method to get the total number of samples available in the dataset. + + This method should be implemented by subclasses to calculate and return the total number of valid + samples available for training, validation, or testing based on the configuration and dataset size. Returns: int: The total number of samples. + + Raises: + NotImplementedError: If the method has not been implemented by a subclass. """ + raise NotImplementedError("Subclasses must implement this method.") def __getitem__(self, idx: int) -> dict: """ - Retrieve a single data sample. + Abstract method to retrieve a single sample from the dataset. + + This method should be implemented by subclasses to access and return a specific sample from the dataset, + given an index. It should handle the slicing of input and output sequences according to the defined + `input_len` and `output_len`. Args: - idx (int): Index of the data sample to retrieve. + idx (int): The index of the sample to retrieve. Returns: - dict: A dictionary containing the input and output data. + dict: A dictionary containing the input sequence ('inputs') and output sequence ('target'). + + Raises: + NotImplementedError: If the method has not been implemented by a subclass. """ raise NotImplementedError("Subclasses must implement this method.") diff --git a/basicts/data/simple_tsf_dataset.py b/basicts/data/simple_tsf_dataset.py index 6224e907..231696eb 100644 --- a/basicts/data/simple_tsf_dataset.py +++ b/basicts/data/simple_tsf_dataset.py @@ -8,52 +8,56 @@ class TimeSeriesForecastingDataset(BaseDataset): """ - A PyTorch Dataset for time series forecasting. - - This dataset handles splitting the data into train, validation, and test sets - based on the provided mode. It provides sequences of historical data (input) - and corresponding future data (target) for training. + A dataset class for time series forecasting problems, handling the loading, parsing, and partitioning + of time series data into training, validation, and testing sets based on provided ratios. + + This class supports configurations where sequences may or may not overlap, accommodating scenarios + where time series data is drawn from continuous periods or distinct episodes, affecting how + the data is split into batches for model training or evaluation. + + Attributes: + data_file_path (str): Path to the file containing the time series data. + description_file_path (str): Path to the JSON file containing the description of the dataset. + data (np.ndarray): The loaded time series data array, split according to the specified mode. 
+ description (dict): Metadata about the dataset, such as shape and other properties. """ def __init__(self, dataset_name: str, train_val_test_ratio: List[float], mode: str, input_len: int, output_len: int, overlap: bool = True) -> None: """ - Initialize the dataset by loading the data and description. + Initializes the TimeSeriesForecastingDataset by setting up paths, loading data, and + preparing it according to the specified configurations. Args: dataset_name (str): The name of the dataset. train_val_test_ratio (List[float]): Ratios for splitting the dataset into train, validation, and test sets. - mode (str): The mode of the dataset. Must be one of ['train', 'valid', 'test']. - input_len (int): The length of the input sequence (history). - output_len (int): The length of the output sequence (future). - overlap (bool): Whether to allow overlapping between the training, validation, and test sets. - It is usually set to True, but can be set to False in certain special cases, such as - when the training, validation, and test sets are sampled from different periods of time or - are from different source domains. + Each value should be a float between 0 and 1, and their sum should ideally be 1. + mode (str): The operation mode of the dataset. Valid values are 'train', 'valid', or 'test'. + input_len (int): The length of the input sequence (number of historical points). + output_len (int): The length of the output sequence (number of future points to predict). + overlap (bool): Flag to determine if training/validation/test splits should overlap. + Defaults to True. Set to False for strictly non-overlapping periods. Raises: AssertionError: If `mode` is not one of ['train', 'valid', 'test']. """ - assert mode in ['train', 'valid', 'test'], f"Invalid mode: {mode}. Must be one of ['train', 'valid', 'test']." super().__init__(dataset_name, train_val_test_ratio, mode, input_len, output_len, overlap) self.data_file_path = f'datasets/{dataset_name}/data.dat' self.description_file_path = f'datasets/{dataset_name}/desc.json' - - # Load description and data self.description = self._load_description() self.data = self._load_data() def _load_description(self) -> dict: """ - Load the dataset description from a JSON file. + Loads the description of the dataset from a JSON file. Returns: - dict: The dataset description loaded from the JSON file. + dict: A dictionary containing metadata about the dataset, such as its shape and other properties. Raises: FileNotFoundError: If the description file is not found. - json.JSONDecodeError: If the JSON file cannot be decoded. + json.JSONDecodeError: If there is an error decoding the JSON data. """ try: @@ -66,29 +70,27 @@ def _load_description(self) -> dict: def _load_data(self) -> np.ndarray: """ - Load the data file and split it based on the mode. + Loads the time series data from a file and splits it according to the selected mode. Returns: - np.ndarray: The data for the selected mode (train, validation, or test). + np.ndarray: The data array for the specified mode (train, validation, or test). Raises: - ValueError: If the data file cannot be loaded or the shape is incorrect. + ValueError: If there is an issue with loading the data file or if the data shape is not as expected. 
""" try: - # Load data using memory-mapped file for efficiency data = np.memmap(self.data_file_path, dtype='float32', mode='r', shape=tuple(self.description['shape'])) except (FileNotFoundError, ValueError) as e: raise ValueError(f'Error loading data file: {self.data_file_path}') from e - # Split data based on train/val/test ratios total_len = len(data) train_len = int(total_len * self.train_val_test_ratio[0]) valid_len = int(total_len * self.train_val_test_ratio[1]) if self.mode == 'train': offset = self.output_len if self.overlap else 0 - return data[:train_len + offset].copy() # Consider overlapping + return data[:train_len + offset].copy() elif self.mode == 'valid': offset_left = self.input_len - 1 if self.overlap else 0 offset_right = self.output_len if self.overlap else 0 @@ -99,26 +101,24 @@ def _load_data(self) -> np.ndarray: def __getitem__(self, index: int) -> dict: """ - Get a sample from the dataset. + Retrieves a sample from the dataset at the specified index, considering both the input and output lengths. Args: - index (int): The index of the sample. + index (int): The index of the desired sample in the dataset. Returns: - dict: A dictionary containing 'inputs' (history_data) and 'target' (future_data), - where the shape of each is (L, N, C). + dict: A dictionary containing 'inputs' and 'target', where both are slices of the dataset corresponding to + the historical input data and future prediction data, respectively. """ - history_data = self.data[index:index + self.input_len] future_data = self.data[index + self.input_len:index + self.input_len + self.output_len] return {'inputs': history_data, 'target': future_data} def __len__(self) -> int: """ - Get the number of samples in the dataset. + Calculates the total number of samples available in the dataset, adjusted for the lengths of input and output sequences. Returns: - int: The number of samples available based on the input and output lengths. + int: The number of valid samples that can be drawn from the dataset, based on the configurations of input and output lengths. """ - return len(self.data) - self.input_len - self.output_len + 1 diff --git a/basicts/launcher.py b/basicts/launcher.py index b8fbdec9..7b61237c 100644 --- a/basicts/launcher.py +++ b/basicts/launcher.py @@ -39,7 +39,7 @@ def evaluation_func(cfg: Dict, runner.init_logger(logger_name='easytorch-evaluation', log_file_name='evaluation_log') try: - # Set batch size if provided + # set batch size if provided if batch_size is not None: cfg.TEST.DATA.BATCH_SIZE = batch_size else: diff --git a/basicts/metrics/mae.py b/basicts/metrics/mae.py index cf15e1c3..d9163b86 100644 --- a/basicts/metrics/mae.py +++ b/basicts/metrics/mae.py @@ -3,28 +3,36 @@ def masked_mae(prediction: torch.Tensor, target: torch.Tensor, null_val: float = np.nan) -> torch.Tensor: """ - Calculate the Masked Mean Absolute Error (MAE). + Calculate the Masked Mean Absolute Error (MAE) between the predicted and target values, + while ignoring the entries in the target tensor that match the specified null value. + + This function is particularly useful for scenarios where the dataset contains missing or irrelevant + values (denoted by `null_val`) that should not contribute to the loss calculation. It effectively + masks these values to ensure they do not skew the error metrics. Args: - prediction (torch.Tensor): The predicted values. - target (torch.Tensor): The ground truth values (labels). - null_val (float, optional): The value considered as null. Defaults to np.nan. 
+ prediction (torch.Tensor): The predicted values as a tensor. + target (torch.Tensor): The ground truth values as a tensor with the same shape as `prediction`. + null_val (float, optional): The value considered as null or missing in the `target` tensor. + Default is `np.nan`. The function will mask all `NaN` values in the target. Returns: - torch.Tensor: The masked mean absolute error. + torch.Tensor: A scalar tensor representing the masked mean absolute error. + """ + if np.isnan(null_val): mask = ~torch.isnan(target) else: eps = 5e-5 - mask = ~torch.isclose(target, torch.tensor(null_val).to(target.device), atol=eps) + mask = ~torch.isclose(target, torch.tensor(null_val).expand_as(target).to(target.device), atol=eps, rtol=0.0) mask = mask.float() - mask /= torch.mean(mask) - mask = torch.nan_to_num(mask) + mask /= torch.mean(mask) # Normalize mask to avoid bias in the loss due to the number of valid entries + mask = torch.nan_to_num(mask) # Replace any NaNs in the mask with zero loss = torch.abs(prediction - target) - loss *= mask - loss = torch.nan_to_num(loss) + loss = loss * mask # Apply the mask to the loss + loss = torch.nan_to_num(loss) # Replace any NaNs in the loss with zero return torch.mean(loss) diff --git a/basicts/metrics/mape.py b/basicts/metrics/mape.py index 800a7eb0..5a9ce6ce 100644 --- a/basicts/metrics/mape.py +++ b/basicts/metrics/mape.py @@ -1,35 +1,48 @@ import torch import numpy as np - -def masked_mape(prediction: torch.Tensor, target: torch.Tensor, null_val: float = 0.0) -> torch.Tensor: +def masked_mape(prediction: torch.Tensor, target: torch.Tensor, null_val: float = np.nan) -> torch.Tensor: """ - Calculate the Masked Mean Absolute Percentage Error (MAPE). + Calculate the Masked Mean Absolute Percentage Error (MAPE) between predicted and target values, + ignoring entries that are either zero or match the specified null value in the target tensor. - Note: - The null_val is set to 0.0 or np.nan for MAPE by default, and should not be changed. + This function is particularly useful for time series or regression tasks where the target values may + contain zeros or missing values, which could otherwise distort the error calculation. The function + applies a mask to ensure these entries do not affect the resulting MAPE. Args: - prediction (torch.Tensor): The predicted values. - target (torch.Tensor): The ground truth values (labels). - null_val (float, optional): The value considered as null. Defaults to 0.0. + prediction (torch.Tensor): The predicted values as a tensor. + target (torch.Tensor): The ground truth values as a tensor with the same shape as `prediction`. + null_val (float, optional): The value considered as null or missing in the `target` tensor. + Defaults to `np.nan`. The function will mask all `NaN` values in the target. Returns: - torch.Tensor: The masked mean absolute percentage error. - - Raises: - AssertionError: If null_val is not 0.0 or np.nan. + torch.Tensor: A scalar tensor representing the masked mean absolute percentage error. + + Details: + - The function creates two masks: + 1. `zero_mask`: This mask excludes entries in the `target` tensor that are close to zero, + since division by zero or near-zero values would result in extremely large or undefined errors. + 2. `null_mask`: This mask excludes entries in the `target` tensor that match the specified `null_val`. + If `null_val` is `np.nan`, the mask will exclude `NaN` values using `torch.isnan`. 
+ + - The final mask is the intersection of `zero_mask` and `null_mask`, ensuring that only valid, non-zero, + and non-null values contribute to the MAPE calculation. """ - assert null_val == 0.0 or np.isnan(null_val), ( - "In MAPE, null_val must be 0.0 or np.nan by default. " - "This parameter is kept for consistency, but it cannot be changed." - ) - # Create mask for non-NaN and non-null values - nan_mask = ~torch.isnan(target) + # mask to exclude zero values in the target zero_mask = ~torch.isclose(target, torch.tensor(0.0).to(target.device), atol=5e-5) - mask = (nan_mask & zero_mask).float() + # mask to exclude null values in the target + if np.isnan(null_val): + null_mask = ~torch.isnan(target) + else: + eps = 5e-5 + null_mask = ~torch.isclose(target, torch.tensor(null_val).to(target.device), atol=eps) + + # combine zero and null masks + mask = (zero_mask & null_mask).float() + mask /= torch.mean(mask) mask = torch.nan_to_num(mask) @@ -38,4 +51,3 @@ def masked_mape(prediction: torch.Tensor, target: torch.Tensor, null_val: float loss = torch.nan_to_num(loss) return torch.mean(loss) - diff --git a/basicts/metrics/mse.py b/basicts/metrics/mse.py index 83ab958a..dde0fdad 100644 --- a/basicts/metrics/mse.py +++ b/basicts/metrics/mse.py @@ -1,19 +1,26 @@ import torch import numpy as np - def masked_mse(prediction: torch.Tensor, target: torch.Tensor, null_val: float = np.nan) -> torch.Tensor: """ - Calculate the Masked Mean Squared Error (MSE). + Calculate the Masked Mean Squared Error (MSE) between predicted and target values, + while ignoring the entries in the target tensor that match the specified null value. + + This function is useful for scenarios where the dataset contains missing or irrelevant values + (denoted by `null_val`) that should not contribute to the loss calculation. The function applies + a mask to these values, ensuring they do not affect the error metric. Args: - prediction (torch.Tensor): The predicted values. - target (torch.Tensor): The ground truth values (labels). - null_val (float, optional): The value considered as null. Defaults to np.nan. + prediction (torch.Tensor): The predicted values as a tensor. + target (torch.Tensor): The ground truth values as a tensor with the same shape as `prediction`. + null_val (float, optional): The value considered as null or missing in the `target` tensor. + Defaults to `np.nan`. The function will mask all `NaN` values in the target. Returns: - torch.Tensor: The masked mean squared error. + torch.Tensor: A scalar tensor representing the masked mean squared error. 
+ """ + if np.isnan(null_val): mask = ~torch.isnan(target) else: @@ -21,11 +28,11 @@ def masked_mse(prediction: torch.Tensor, target: torch.Tensor, null_val: float = mask = ~torch.isclose(target, torch.tensor(null_val).to(target.device), atol=eps) mask = mask.float() - mask /= torch.mean(mask) - mask = torch.nan_to_num(mask) + mask /= torch.mean(mask) # Normalize mask to maintain unbiased MSE calculation + mask = torch.nan_to_num(mask) # Replace any NaNs in the mask with zero - loss = (prediction - target) ** 2 - loss *= mask - loss = torch.nan_to_num(loss) + loss = (prediction - target) ** 2 # Compute squared error + loss *= mask # Apply mask to the loss + loss = torch.nan_to_num(loss) # Replace any NaNs in the loss with zero - return torch.mean(loss) + return torch.mean(loss) # Return the mean of the masked loss diff --git a/basicts/metrics/rmse.py b/basicts/metrics/rmse.py index 84397b16..06c62fb3 100644 --- a/basicts/metrics/rmse.py +++ b/basicts/metrics/rmse.py @@ -3,17 +3,23 @@ from .mse import masked_mse - def masked_rmse(prediction: torch.Tensor, target: torch.Tensor, null_val: float = np.nan) -> torch.Tensor: """ - Calculate the Masked Root Mean Squared Error (RMSE). + Calculate the Masked Root Mean Squared Error (RMSE) between predicted and target values, + ignoring entries in the target tensor that match the specified null value. + + This function is useful for evaluating model performance on datasets where some target values + may be missing or irrelevant (denoted by `null_val`). The RMSE provides a measure of the average + magnitude of errors, accounting only for the valid, non-null entries. Args: - prediction (torch.Tensor): The predicted values. - target (torch.Tensor): The ground truth values (labels). - null_val (float, optional): The value considered as null. Defaults to np.nan. + prediction (torch.Tensor): The predicted values as a tensor. + target (torch.Tensor): The ground truth values as a tensor with the same shape as `prediction`. + null_val (float, optional): The value considered as null or missing in the `target` tensor. + Defaults to `np.nan`. The function will ignore all `NaN` values in the target. Returns: - torch.Tensor: The masked root mean squared error. + torch.Tensor: A scalar tensor representing the masked root mean squared error. """ + return torch.sqrt(masked_mse(prediction=prediction, target=target, null_val=null_val)) diff --git a/basicts/metrics/wape.py b/basicts/metrics/wape.py index dd3df0ff..ddaf4b8e 100644 --- a/basicts/metrics/wape.py +++ b/basicts/metrics/wape.py @@ -1,19 +1,24 @@ import torch import numpy as np - def masked_wape(prediction: torch.Tensor, target: torch.Tensor, null_val: float = np.nan) -> torch.Tensor: """ - Calculate the Masked Weighted Absolute Percentage Error (WAPE). + Calculate the Masked Weighted Absolute Percentage Error (WAPE) between predicted and target values, + ignoring entries in the target tensor that match the specified null value. + + WAPE is a useful metric for measuring the average error relative to the magnitude of the target values, + making it particularly suitable for comparing errors across datasets or time series with different scales. Args: - prediction (torch.Tensor): The predicted values. - target (torch.Tensor): The ground truth values (labels). - null_val (float, optional): The value considered as null. Defaults to np.nan. + prediction (torch.Tensor): The predicted values as a tensor. + target (torch.Tensor): The ground truth values as a tensor with the same shape as `prediction`. 
+ null_val (float, optional): The value considered as null or missing in the `target` tensor. + Defaults to `np.nan`. The function will mask all `NaN` values in the target. Returns: - torch.Tensor: The masked weighted absolute percentage error. + torch.Tensor: A scalar tensor representing the masked weighted absolute percentage error. """ + if np.isnan(null_val): mask = ~torch.isnan(target) else: @@ -21,8 +26,6 @@ def masked_wape(prediction: torch.Tensor, target: torch.Tensor, null_val: float mask = ~torch.isclose(target, torch.tensor(null_val).to(target.device), atol=eps) mask = mask.float() - - # Apply mask to predictions and targets prediction, target = prediction * mask, target * mask prediction = torch.nan_to_num(prediction) diff --git a/basicts/runners/base_tsf_runner.py b/basicts/runners/base_tsf_runner.py index 239f7c93..a42f8bca 100644 --- a/basicts/runners/base_tsf_runner.py +++ b/basicts/runners/base_tsf_runner.py @@ -88,7 +88,7 @@ def __init__(self, cfg: Dict): self.null_val = cfg.get('METRICS', {}).get('NULL_VAL', np.nan) # support early stopping - # NOTE: If the project is stopped early and its configuration is rerun, + # NOTE: If the project has been stopped early and its configuration is rerun, # training will resume from the last saved checkpoint. # This feature is designed primarily for the convenience of users, # allowing them to continue training seamlessly after an interruption. diff --git a/basicts/scaler/min_max_scaler.py b/basicts/scaler/min_max_scaler.py index 437d3d90..b060ce08 100644 --- a/basicts/scaler/min_max_scaler.py +++ b/basicts/scaler/min_max_scaler.py @@ -1,45 +1,56 @@ import json -import numpy as np + import torch +import numpy as np + from .base_scaler import BaseScaler class MinMaxScaler(BaseScaler): """ - MinMaxScaler performs min-max normalization on the dataset. + MinMaxScaler performs min-max normalization on the dataset, scaling the data to a specified range + (typically [0, 1] or [-1, 1]). Attributes: - min (np.ndarray): Minimum values of the training data. - max (np.ndarray): Maximum values of the training data. - target_channel (int): The channel to apply normalization to. + min (np.ndarray): The minimum values of the training data used for normalization. + If `norm_each_channel` is True, this is an array of minimum values, one for each channel. Otherwise, it's a single scalar. + max (np.ndarray): The maximum values of the training data used for normalization. + If `norm_each_channel` is True, this is an array of maximum values, one for each channel. Otherwise, it's a single scalar. + target_channel (int): The specific channel (feature) to which normalization is applied. + By default, it is set to 0, indicating the first channel. """ - def __init__(self, dataset_name: str, train_ratio: float, norm_each_channel: bool, rescale: bool): + def __init__(self, dataset_name: str, train_ratio: float, norm_each_channel: bool = True, rescale: bool = True): """ - Initialize MinMaxScaler by loading and fitting the scaler to the training data. + Initialize the MinMaxScaler by loading the dataset and fitting the scaler to the training data. + + The scaler computes the minimum and maximum values from the training data, which are then used + to normalize the data during the `transform` operation. Args: - dataset_name (str): The name of the dataset. - train_ratio (float): Ratio of the data to be used for training. - norm_each_channel (bool): Whether to normalize each channel separately. Defaults to True. - rescale (bool): Whether to apply rescaling. 
Defaults to True. + dataset_name (str): The name of the dataset used to load the data. + train_ratio (float): The ratio of the dataset to be used for training. The scaler is fitted on this portion of the data. + norm_each_channel (bool): Flag indicating whether to normalize each channel separately. + If True, the min and max values are computed for each channel independently. Defaults to True. + rescale (bool): Flag indicating whether to apply rescaling after normalization. + This flag is included for consistency with the base class but is typically True in min-max scaling. """ super().__init__(dataset_name, train_ratio, norm_each_channel, rescale) - self.target_channel = 0 # Assuming normalization on the first channel + self.target_channel = 0 # assuming normalization on the first channel - # Load dataset description and data + # load dataset description and data description_file_path = f'datasets/{dataset_name}/desc.json' with open(description_file_path, 'r') as f: description = json.load(f) data_file_path = f'datasets/{dataset_name}/data.dat' data = np.memmap(data_file_path, dtype='float32', mode='r', shape=tuple(description['shape'])) - # Split data into training set + # split data into training set based on the train_ratio train_size = int(len(data) * train_ratio) train_data = data[:train_size, :, self.target_channel].copy() - # Compute min and max values + # compute minimum and maximum values for normalization if norm_each_channel: self.min = np.min(train_data, axis=0, keepdims=True) self.max = np.max(train_data, axis=0, keepdims=True) @@ -51,11 +62,14 @@ def transform(self, input_data: torch.Tensor) -> torch.Tensor: """ Apply min-max normalization to the input data. + This method normalizes the input data using the minimum and maximum values computed from the training data. + The normalization is applied only to the specified `target_channel`. + Args: - input_data (torch.Tensor): Input data to be normalized. + input_data (torch.Tensor): The input data to be normalized. Returns: - torch.Tensor: Normalized data. + torch.Tensor: The normalized data with the same shape as the input. """ input_data[..., self.target_channel] = (input_data[..., self.target_channel] - self.min) / (self.max - self.min) @@ -65,11 +79,15 @@ def inverse_transform(self, input_data: torch.Tensor) -> torch.Tensor: """ Reverse the min-max normalization to recover the original data scale. + This method transforms the normalized data back to its original scale using the minimum and maximum + values computed from the training data. This is useful for interpreting model outputs or for further analysis + in the original data scale. + Args: - input_data (torch.Tensor): Normalized data to be transformed back. + input_data (torch.Tensor): The normalized data to be transformed back. Returns: - torch.Tensor: Data in its original scale. + torch.Tensor: The data transformed back to its original scale. """ input_data[..., self.target_channel] = input_data[..., self.target_channel] * (self.max - self.min) + self.min diff --git a/basicts/scaler/z_score_scaler.py b/basicts/scaler/z_score_scaler.py index 2ad341e2..bdb58782 100644 --- a/basicts/scaler/z_score_scaler.py +++ b/basicts/scaler/z_score_scaler.py @@ -1,66 +1,79 @@ import json -import numpy as np import torch +import numpy as np from .base_scaler import BaseScaler class ZScoreScaler(BaseScaler): """ - ZScoreScaler performs Z-score normalization on the dataset. 
+ ZScoreScaler performs Z-score normalization on the dataset, transforming the data to have a mean of zero + and a standard deviation of one. This is commonly used in preprocessing to normalize data, ensuring that + each feature contributes equally to the model. Attributes: - mean (np.ndarray): Mean of the training data. - std (np.ndarray): Standard deviation of the training data. - target_channel (int): The channel to apply normalization to. + mean (np.ndarray): The mean of the training data used for normalization. + If `norm_each_channel` is True, this is an array of means, one for each channel. Otherwise, it's a single scalar. + std (np.ndarray): The standard deviation of the training data used for normalization. + If `norm_each_channel` is True, this is an array of standard deviations, one for each channel. Otherwise, it's a single scalar. + target_channel (int): The specific channel (feature) to which normalization is applied. + By default, it is set to 0, indicating the first channel. """ def __init__(self, dataset_name: str, train_ratio: float, norm_each_channel: bool, rescale: bool): """ - Initialize ZScoreScaler by loading and fitting the scaler to the training data. + Initialize the ZScoreScaler by loading the dataset and fitting the scaler to the training data. + + The scaler computes the mean and standard deviation from the training data, which is then used to + normalize the data during the `transform` operation. Args: - dataset_name (str): The name of the dataset. - train_ratio (float): Ratio of the data to be used for training. - norm_each_channel (bool): Whether to normalize each channel separately. - rescale (bool): Whether to apply rescaling. + dataset_name (str): The name of the dataset used to load the data. + train_ratio (float): The ratio of the dataset to be used for training. The scaler is fitted on this portion of the data. + norm_each_channel (bool): Flag indicating whether to normalize each channel separately. + If True, the mean and standard deviation are computed for each channel independently. + rescale (bool): Flag indicating whether to apply rescaling after normalization. This flag is included for consistency with + the base class but is not directly used in Z-score normalization. 
""" super().__init__(dataset_name, train_ratio, norm_each_channel, rescale) - self.target_channel = 0 # Assuming normalization on the first channel + self.target_channel = 0 # assuming normalization on the first channel - # Load dataset description and data + # load dataset description and data description_file_path = f'datasets/{dataset_name}/desc.json' with open(description_file_path, 'r') as f: description = json.load(f) data_file_path = f'datasets/{dataset_name}/data.dat' data = np.memmap(data_file_path, dtype='float32', mode='r', shape=tuple(description['shape'])) - # Split data into training set + # split data into training set based on the train_ratio train_size = int(len(data) * train_ratio) train_data = data[:train_size, :, self.target_channel].copy() - # Compute mean and standard deviation + # compute mean and standard deviation if norm_each_channel: self.mean = np.mean(train_data, axis=0, keepdims=True) self.std = np.std(train_data, axis=0, keepdims=True) - self.std[self.std == 0] = 1.0 # Prevent division by zero + self.std[self.std == 0] = 1.0 # prevent division by zero by setting std to 1 where it's 0 else: self.mean = np.mean(train_data) self.std = np.std(train_data) if self.std == 0: - self.std = 1.0 # Prevent division by zero + self.std = 1.0 # prevent division by zero by setting std to 1 where it's 0 def transform(self, input_data: torch.Tensor) -> torch.Tensor: """ Apply Z-score normalization to the input data. + This method normalizes the input data using the mean and standard deviation computed from the training data. + The normalization is applied only to the specified `target_channel`. + Args: - input_data (torch.Tensor): Input data to be normalized. + input_data (torch.Tensor): The input data to be normalized. Returns: - torch.Tensor: Normalized data. + torch.Tensor: The normalized data with the same shape as the input. """ input_data[..., self.target_channel] = (input_data[..., self.target_channel] - self.mean) / self.std @@ -70,17 +83,20 @@ def inverse_transform(self, input_data: torch.Tensor) -> torch.Tensor: """ Reverse the Z-score normalization to recover the original data scale. + This method transforms the normalized data back to its original scale using the mean and standard deviation + computed from the training data. This is useful for interpreting model outputs or for further analysis in the original data scale. + Args: - input_data (torch.Tensor): Normalized data to be transformed back. + input_data (torch.Tensor): The normalized data to be transformed back. Returns: - torch.Tensor: Data in its original scale. + torch.Tensor: The data transformed back to its original scale. 
""" if isinstance(self.mean, np.ndarray): - self.mean = torch.tensor(self.mean) - self.std = torch.tensor(self.std) - # prevent in-place modification via clone (forbidden in PyTorch) + self.mean = torch.tensor(self.mean, device=input_data.device) + self.std = torch.tensor(self.std, device=input_data.device) + # Clone the input data to prevent in-place modification (which is not allowed in PyTorch) input_data = input_data.clone() input_data[..., self.target_channel] = input_data[..., self.target_channel] * self.std + self.mean return input_data diff --git a/basicts/utils/adjacent_matrix_norm.py b/basicts/utils/adjacent_matrix_norm.py index 072f5ff4..36d13218 100644 --- a/basicts/utils/adjacent_matrix_norm.py +++ b/basicts/utils/adjacent_matrix_norm.py @@ -2,6 +2,7 @@ import scipy.sparse as sp from scipy.sparse import linalg + def calculate_symmetric_normalized_laplacian(adj: np.ndarray) -> np.matrix: """ Calculate the symmetric normalized Laplacian. diff --git a/basicts/utils/early_stopping.py b/basicts/utils/early_stopping.py deleted file mode 100644 index e69de29b..00000000 diff --git a/basicts/utils/misc.py b/basicts/utils/misc.py index a8e705a2..dc74959a 100644 --- a/basicts/utils/misc.py +++ b/basicts/utils/misc.py @@ -1,5 +1,6 @@ import time from functools import partial + import torch diff --git a/basicts/utils/serialization.py b/basicts/utils/serialization.py index c503fabb..1c57a916 100644 --- a/basicts/utils/serialization.py +++ b/basicts/utils/serialization.py @@ -19,6 +19,7 @@ def get_regular_settings(dataset_name: str) -> dict: Returns: dict: Regular settings for the dataset. """ + # read json file: datasets/dataset_name/desc.json desc = load_dataset_desc(dataset_name) regular_settings = desc['regular_settings'] @@ -34,6 +35,7 @@ def load_dataset_desc(dataset_name: str) -> str: Returns: str: Description of the dataset. """ + # read json file: datasets/dataset_name/desc.json with open(f'datasets/{dataset_name}/desc.json', 'r') as f: desc = json.load(f) @@ -49,6 +51,7 @@ def load_dataset_data(dataset_name: str) -> np.ndarray: Returns: np.ndarray: Loaded data. """ + shape = load_dataset_desc(dataset_name)['shape'] dat_file_path = f'datasets/{dataset_name}/data.dat' data = np.memmap(dat_file_path, mode='r', dtype=np.float32, shape=tuple(shape)).copy()