15 changes: 15 additions & 0 deletions pytorch_forecasting/__init__.py
@@ -4,13 +4,28 @@

__version__ = "1.5.0"

import torch

from pytorch_forecasting.data import (
EncoderNormalizer,
GroupNormalizer,
MultiNormalizer,
NaNLabelEncoder,
TimeSeriesDataSet,
)
from pytorch_forecasting.data.encoders import TorchNormalizer

# Register custom classes as safe globals for PyTorch 2.6+ safe unpickling
# This allows checkpoints containing these classes to be loaded with weights_only=True
torch.serialization.add_safe_globals(
[
EncoderNormalizer,
GroupNormalizer,
MultiNormalizer,
NaNLabelEncoder,
TorchNormalizer,
]
)
from pytorch_forecasting.metrics import (
MAE,
MAPE,
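For context, a minimal sketch of what this registration enables (the checkpoint path below is hypothetical): importing `pytorch_forecasting` runs `add_safe_globals`, after which PyTorch 2.6+ can unpickle checkpoints that embed these normalizer classes under its `weights_only=True` default:

```python
import torch

import pytorch_forecasting  # noqa: F401  # importing registers the safe globals

# Hypothetical path; any checkpoint that pickles one of the registered
# normalizers (e.g. inside its saved hyperparameters) behaves the same way.
checkpoint = torch.load("checkpoints/model.ckpt", weights_only=True)
```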
96 changes: 82 additions & 14 deletions pytorch_forecasting/data/data_module.py
@@ -293,19 +293,11 @@ def metadata(self):
            self._metadata = self._prepare_metadata()
        return self._metadata

-    def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]:
-        """Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset.
+    def _preprocess_data(self, series_idx: torch.Tensor) -> dict[str, Any]:
+        """Preprocess data before feeding into the dataset.

-        Preprocessing steps
-        --------------------
-
-        * Converts target (`y`) and features (`x`) to `torch.float32`.
-        * Masks time points that are at or before the cutoff time.
-        * Splits features into categorical and continuous subsets based on
-          predefined indices.
-
-
-        TODO: add scalers, target normalizers etc.
+        Handles type conversion, masking, feature splitting, and optional
+        normalization/scaling of targets and continuous features.
        """
        sample = self.time_series_dataset[series_idx]

@@ -316,6 +308,7 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]:

        time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool)

+        # Convert to float32 tensors
        if isinstance(target, torch.Tensor):
            target = target.float()
        else:
@@ -326,8 +319,26 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]:
        else:
            features = torch.tensor(features, dtype=torch.float32)

-        # TODO: add scalers, target normalizers etc.
+        # Normalize target if a normalizer is provided
+        target_scale = None
+        # Handle both target_normalizer and _target_normalizer (for "auto" case)
+        target_normalizer = (
+            getattr(self, "_target_normalizer", None) or self.target_normalizer
+        )
+        if target_normalizer is not None and target_normalizer != "auto":
+            # Only process TorchNormalizer instances (they have get_parameters)
+            if hasattr(target_normalizer, "get_parameters"):
+                # Fit if not already fitted
+                if (
+                    not hasattr(target_normalizer, "center_")
+                    or target_normalizer.center_ is None
+                ):
+                    target_normalizer.fit(target)
+
+                target = target_normalizer.transform(target)
+                target_scale = target_normalizer.get_parameters()
+
        # Split into categorical and continuous features
        categorical = (
            features[:, self.categorical_indices]
            if self.categorical_indices
@@ -339,7 +350,58 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]:
            else torch.zeros((features.shape[0], 0))
        )

-        return {
+        # Scale continuous features if scalers are provided
+        if (
+            self.scalers is not None
+            and len(self.scalers) > 0
+            and continuous.shape[1] > 0
+        ):
+            from sklearn.utils.validation import check_is_fitted
+
+            scaled_continuous = continuous.clone()
+            feature_names = self.time_series_metadata.get("cols", {}).get("x", [])
+
+            # Apply scaler to each continuous feature
+            for idx, feat_idx in enumerate(self.continuous_indices):
+                feature_name = None
+                scaler = None
+
+                # Get feature name from metadata if available
+                if feat_idx < len(feature_names):
+                    feature_name = feature_names[feat_idx]
+
+                # Look up scaler by name or index
+                if feature_name and feature_name in self.scalers:
+                    scaler = self.scalers[feature_name]
+                elif str(feat_idx) in self.scalers:
+                    scaler = self.scalers[str(feat_idx)]
+
+                if scaler is not None:
+                    feature_col = continuous[:, idx : idx + 1]
+
+                    # sklearn scalers return numpy arrays
+                    if isinstance(scaler, (StandardScaler, RobustScaler)):
+                        try:
+                            check_is_fitted(scaler)
+                        except Exception:
+                            scaler.fit(feature_col.numpy())
+
+                        scaled_col = scaler.transform(feature_col.numpy())
+                        scaled_continuous[:, idx : idx + 1] = torch.from_numpy(
+                            scaled_col
+                        ).float()
+
+                    # torch normalizers work directly with tensors
+                    elif isinstance(scaler, (TorchNormalizer, EncoderNormalizer)):
+                        if not hasattr(scaler, "center_") or scaler.center_ is None:
+                            scaler.fit(feature_col)
+
+                        scaled_col = scaler.transform(feature_col)
+                        scaled_continuous[:, idx : idx + 1] = scaled_col.float()
+
+            continuous = scaled_continuous
+
+        result = {
            "features": {"categorical": categorical, "continuous": continuous},
            "target": target,
            "static": sample.get("st", None),
@@ -350,6 +412,12 @@ def _preprocess_data(self, series_idx: torch.Tensor) -> list[dict[str, Any]]:
"cutoff_time": cutoff_time,
}

# Store normalization params for inverse transform during prediction
if target_scale is not None:
result["target_scale"] = target_scale

return result

class _ProcessedEncoderDecoderDataset(Dataset):
"""PyTorch Dataset for processed encoder-decoder time series data.

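As a usage note, a minimal sketch of the `scalers` mapping this code consumes; the feature names are hypothetical, and keys may be either names from `metadata["cols"]["x"]` or stringified column indices, per the lookup above:

```python
from sklearn.preprocessing import RobustScaler, StandardScaler

from pytorch_forecasting.data.encoders import TorchNormalizer

# Hypothetical feature names. sklearn scalers are fitted and applied on numpy
# views of each column; torch normalizers operate on the tensors directly.
scalers = {
    "price": StandardScaler(),
    "volume": RobustScaler(),
    "temperature": TorchNormalizer(),
}
```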
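On the prediction side, the stored `target_scale` exists so normalized model outputs can be mapped back to the original scale. A minimal round-trip sketch, assuming `TorchNormalizer`'s usual `fit`/`transform`/`inverse_transform`/`get_parameters` semantics:

```python
import torch

from pytorch_forecasting.data.encoders import TorchNormalizer

normalizer = TorchNormalizer()  # standard (center/scale) normalization by default
y = torch.tensor([10.0, 12.0, 11.0, 13.0])

normalizer.fit(y)
y_scaled = normalizer.transform(y)    # what _preprocess_data feeds downstream
params = normalizer.get_parameters()  # what lands in result["target_scale"]

# Undo the normalization on (hypothetical) model outputs
y_restored = normalizer.inverse_transform(y_scaled)
assert torch.allclose(y, y_restored, atol=1e-5)
```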