From d41c2043f9259ad327969f0930abf1f072246c46 Mon Sep 17 00:00:00 2001 From: ravjot07 Date: Sun, 30 Nov 2025 19:13:37 +0530 Subject: [PATCH 1/3] Add preprocessing steps (scaling/transformations) to D2 layer - Implement target normalization with TorchNormalizer - Add continuous feature scaling support (StandardScaler, RobustScaler, TorchNormalizer) - Store target_scale parameters for inverse transformation - Update return type from list[dict] to dict - Maintain backward compatibility (scalers are optional) --- pytorch_forecasting/data/data_module.py | 87 ++++++++++++++++++++----- 1 file changed, 72 insertions(+), 15 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 34aa145e7..bf7d88254 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -293,19 +293,11 @@ def metadata(self): self._metadata = self._prepare_metadata() return self._metadata - def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]: - """Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset. + def _preprocess_data(self, series_idx: torch.Tensor) -> dict[str, Any]: + """Preprocess data before feeding into the dataset. - Preprocessing steps - -------------------- - - * Converts target (`y`) and features (`x`) to `torch.float32`. - * Masks time points that are at or before the cutoff time. - * Splits features into categorical and continuous subsets based on - predefined indices. - - - TODO: add scalers, target normalizers etc. + Handles type conversion, masking, feature splitting, and optional + normalization/scaling of targets and continuous features. """ sample = self.time_series_dataset[series_idx] @@ -316,6 +308,7 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]: time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool) + # Convert to float32 tensors if isinstance(target, torch.Tensor): target = target.float() else: @@ -326,8 +319,21 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]: else: features = torch.tensor(features, dtype=torch.float32) - # TODO: add scalers, target normalizers etc. 
- + # Normalize target if a normalizer is provided + target_scale = None + # Handle both target_normalizer and _target_normalizer (for "auto" case) + target_normalizer = getattr(self, "_target_normalizer", None) or self.target_normalizer + if target_normalizer is not None and target_normalizer != "auto": + # Only process TorchNormalizer instances (they have get_parameters) + if hasattr(target_normalizer, "get_parameters"): + # Fit if not already fitted + if not hasattr(target_normalizer, "center_") or target_normalizer.center_ is None: + target_normalizer.fit(target) + + target = target_normalizer.transform(target) + target_scale = target_normalizer.get_parameters() + + # Split into categorical and continuous features categorical = ( features[:, self.categorical_indices] if self.categorical_indices @@ -339,7 +345,52 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]: else torch.zeros((features.shape[0], 0)) ) - return { + # Scale continuous features if scalers are provided + if self.scalers is not None and len(self.scalers) > 0 and continuous.shape[1] > 0: + from sklearn.utils.validation import check_is_fitted + + scaled_continuous = continuous.clone() + feature_names = self.time_series_metadata.get("cols", {}).get("x", []) + + # Apply scaler to each continuous feature + for idx, feat_idx in enumerate(self.continuous_indices): + feature_name = None + scaler = None + + # Get feature name from metadata if available + if feat_idx < len(feature_names): + feature_name = feature_names[feat_idx] + + # Look up scaler by name or index + if feature_name and feature_name in self.scalers: + scaler = self.scalers[feature_name] + elif str(feat_idx) in self.scalers: + scaler = self.scalers[str(feat_idx)] + + if scaler is not None: + feature_col = continuous[:, idx : idx + 1] + + # sklearn scalers return numpy arrays + if isinstance(scaler, (StandardScaler, RobustScaler)): + try: + check_is_fitted(scaler) + except Exception: + scaler.fit(feature_col.numpy()) + + scaled_col = scaler.transform(feature_col.numpy()) + scaled_continuous[:, idx : idx + 1] = torch.from_numpy(scaled_col).float() + + # torch normalizers work directly with tensors + elif isinstance(scaler, (TorchNormalizer, EncoderNormalizer)): + if not hasattr(scaler, "center_") or scaler.center_ is None: + scaler.fit(feature_col) + + scaled_col = scaler.transform(feature_col) + scaled_continuous[:, idx : idx + 1] = scaled_col.float() + + continuous = scaled_continuous + + result = { "features": {"categorical": categorical, "continuous": continuous}, "target": target, "static": sample.get("st", None), @@ -350,6 +401,12 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]: "cutoff_time": cutoff_time, } + # Store normalization params for inverse transform during prediction + if target_scale is not None: + result["target_scale"] = target_scale + + return result + class _ProcessedEncoderDecoderDataset(Dataset): """PyTorch Dataset for processed encoder-decoder time series data. 
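Note (illustration, not part of the patch): the `target_scale` stored above is what lets callers undo the scaling at prediction time. A minimal sketch of that round-trip with a default `TorchNormalizer` (standard scaling assumed; toy values; the `[center, scale]` layout of `get_parameters()` is the convention this sketch relies on):

    import torch
    from pytorch_forecasting.data.encoders import TorchNormalizer

    y = torch.tensor([10.0, 12.0, 11.5, 13.0])  # toy target series

    normalizer = TorchNormalizer()
    normalizer.fit(y)                       # sets center_ / scale_, the attributes the diff checks
    y_scaled = normalizer.transform(y)      # what _preprocess_data now returns as "target"
    target_scale = normalizer.get_parameters()  # what lands in result["target_scale"]

    # Invert manually from the stored parameters, as a downstream
    # prediction step could do with result["target_scale"]:
    center, scale = target_scale[..., 0], target_scale[..., 1]
    y_restored = y_scaled * scale + center
    assert torch.allclose(y_restored, y, atol=1e-5)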
From 95f7cf7a3a18004d3b46d5d9fb524dca9cca608c Mon Sep 17 00:00:00 2001
From: ravjot07
Date: Sun, 30 Nov 2025 19:33:48 +0530
Subject: [PATCH 2/3] Reformat long lines in _preprocess_data for readability

Signed-off-by: ravjot07
---
 pytorch_forecasting/data/data_module.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py
index bf7d88254..7a2a00ef8 100644
--- a/pytorch_forecasting/data/data_module.py
+++ b/pytorch_forecasting/data/data_module.py
@@ -322,12 +322,17 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> dict[str, Any]:
         # Normalize target if a normalizer is provided
         target_scale = None
         # Handle both target_normalizer and _target_normalizer (for "auto" case)
-        target_normalizer = getattr(self, "_target_normalizer", None) or self.target_normalizer
+        target_normalizer = (
+            getattr(self, "_target_normalizer", None) or self.target_normalizer
+        )
         if target_normalizer is not None and target_normalizer != "auto":
             # Only process TorchNormalizer instances (they have get_parameters)
             if hasattr(target_normalizer, "get_parameters"):
                 # Fit if not already fitted
-                if not hasattr(target_normalizer, "center_") or target_normalizer.center_ is None:
+                if (
+                    not hasattr(target_normalizer, "center_")
+                    or target_normalizer.center_ is None
+                ):
                     target_normalizer.fit(target)
 
                 target = target_normalizer.transform(target)
@@ -346,7 +351,11 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> dict[str, Any]:
         )
 
         # Scale continuous features if scalers are provided
-        if self.scalers is not None and len(self.scalers) > 0 and continuous.shape[1] > 0:
+        if (
+            self.scalers is not None
+            and len(self.scalers) > 0
+            and continuous.shape[1] > 0
+        ):
             from sklearn.utils.validation import check_is_fitted
 
             scaled_continuous = continuous.clone()
@@ -378,7 +387,9 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> dict[str, Any]:
                             scaler.fit(feature_col.numpy())
 
                         scaled_col = scaler.transform(feature_col.numpy())
-                        scaled_continuous[:, idx : idx + 1] = torch.from_numpy(scaled_col).float()
+                        scaled_continuous[:, idx : idx + 1] = torch.from_numpy(
+                            scaled_col
+                        ).float()
 
                     # torch normalizers work directly with tensors
                     elif isinstance(scaler, (TorchNormalizer, EncoderNormalizer)):

From b320aeed5f51b932758d0caa06ee308216dcd49a Mon Sep 17 00:00:00 2001
From: ravjot07
Date: Sun, 30 Nov 2025 20:16:27 +0530
Subject: [PATCH 3/3] Register normalizers as safe globals for torch
 serialization

Signed-off-by: ravjot07
---
 pytorch_forecasting/__init__.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/pytorch_forecasting/__init__.py b/pytorch_forecasting/__init__.py
index fa5906bec..4624ccad6 100644
--- a/pytorch_forecasting/__init__.py
+++ b/pytorch_forecasting/__init__.py
@@ -4,6 +4,8 @@
 
 __version__ = "1.5.0"
 
+import torch
+
 from pytorch_forecasting.data import (
     EncoderNormalizer,
     GroupNormalizer,
@@ -11,6 +13,19 @@
     NaNLabelEncoder,
     TimeSeriesDataSet,
 )
+from pytorch_forecasting.data.encoders import TorchNormalizer
+
+# Register custom classes as safe globals for PyTorch 2.6+ safe unpickling
+# This allows checkpoints containing these classes to be loaded with weights_only=True
+torch.serialization.add_safe_globals(
+    [
+        EncoderNormalizer,
+        GroupNormalizer,
+        MultiNormalizer,
+        NaNLabelEncoder,
+        TorchNormalizer,
+    ]
+)
 from pytorch_forecasting.metrics import (
     MAE,
     MAPE,
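Note (illustration, not part of the patch): the registration in PATCH 3/3 matters because PyTorch 2.6 changed the `torch.load` default to `weights_only=True`, which only unpickles allowlisted classes; checkpoints that embed fitted normalizers would otherwise fail to load. A minimal check of the effect (requires torch >= 2.4 for the safe-globals API; the checkpoint path in the final comment is hypothetical):

    import torch
    import pytorch_forecasting  # importing the package runs add_safe_globals(...)
    from pytorch_forecasting.data.encoders import TorchNormalizer

    # The normalizer classes are now on the weights_only allowlist:
    assert TorchNormalizer in torch.serialization.get_safe_globals()

    # So a checkpoint that pickles a fitted normalizer can load without
    # falling back to weights_only=False:
    # checkpoint = torch.load("model.ckpt", weights_only=True)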