From 302ec0f739ef7614fe6c7f67302c311380c42c07 Mon Sep 17 00:00:00 2001 From: JATAYU000 Date: Fri, 28 Nov 2025 20:21:01 +0530 Subject: [PATCH 1/6] Initial Implementation of iTransformer --- .../models/itransformer/__init__.py | 13 ++ .../itransformer/_itransformer_pkg_v2.py | 118 +++++++++++ .../models/itransformer/_itransformer_v2.py | 187 ++++++++++++++++++ 3 files changed, 318 insertions(+) create mode 100644 pytorch_forecasting/models/itransformer/__init__.py create mode 100644 pytorch_forecasting/models/itransformer/_itransformer_pkg_v2.py create mode 100644 pytorch_forecasting/models/itransformer/_itransformer_v2.py diff --git a/pytorch_forecasting/models/itransformer/__init__.py b/pytorch_forecasting/models/itransformer/__init__.py new file mode 100644 index 000000000..eaa9f79f0 --- /dev/null +++ b/pytorch_forecasting/models/itransformer/__init__.py @@ -0,0 +1,13 @@ +""" +iTransformer model for forecasting time series. +""" + +from pytorch_forecasting.models.itransformer._itransformer_pkg_v2 import ( + iTransformer_pkg_v2, +) +from pytorch_forecasting.models.itransformer._itransformer_v2 import iTransformer + +__all__ = [ + "iTransformer", + "iTransformer_pkg_v2", +] diff --git a/pytorch_forecasting/models/itransformer/_itransformer_pkg_v2.py b/pytorch_forecasting/models/itransformer/_itransformer_pkg_v2.py new file mode 100644 index 000000000..b37e81de9 --- /dev/null +++ b/pytorch_forecasting/models/itransformer/_itransformer_pkg_v2.py @@ -0,0 +1,118 @@ +"""iTransformer package container v2.""" + +from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2 + + +class iTransformer_pkg_v2(_BasePtForecasterV2): + """iTransformer metadata container.""" + + _tags = { + "info:name": "iTransformer", + "authors": ["JATAYU000"], + "capability:exogenous": True, + "capability:multivariate": True, + "capability:pred_int": False, + "capability:flexible_history_length": False, + "capability:cold_start": False, + } + + @classmethod + def get_cls(cls): + """Get model class.""" + from pytorch_forecasting.models.itransformer._itransformer_v2 import ( + iTransformer, + ) + + return iTransformer + + @classmethod + def _get_test_datamodule_from(cls, trainer_kwargs): + """Create test dataloaders from trainer_kwargs - following v1 pattern.""" + from pytorch_forecasting.data._tslib_data_module import TslibDataModule + from pytorch_forecasting.tests._data_scenarios import ( + data_with_covariates_v2, + make_datasets_v2, + ) + + data_with_covariates = data_with_covariates_v2() + + data_loader_default_kwargs = dict( + target="target", + group_ids=["agency_encoded", "sku_encoded"], + add_relative_time_idx=True, + ) + + data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {}) + data_loader_default_kwargs.update(data_loader_kwargs) + + datasets_info = make_datasets_v2( + data_with_covariates, **data_loader_default_kwargs + ) + + training_dataset = datasets_info["training_dataset"] + validation_dataset = datasets_info["validation_dataset"] + + context_length = data_loader_kwargs.get("context_length", 12) + prediction_length = data_loader_kwargs.get("prediction_length", 4) + batch_size = data_loader_kwargs.get("batch_size", 2) + + train_datamodule = TslibDataModule( + time_series_dataset=training_dataset, + context_length=context_length, + prediction_length=prediction_length, + add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True), + batch_size=batch_size, + train_val_test_split=(0.8, 0.2, 0.0), + ) + + val_datamodule = TslibDataModule( + time_series_dataset=validation_dataset, + context_length=context_length, + prediction_length=prediction_length, + add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True), + batch_size=batch_size, + train_val_test_split=(0.0, 1.0, 0.0), + ) + + test_datamodule = TslibDataModule( + time_series_dataset=validation_dataset, + context_length=context_length, + prediction_length=prediction_length, + add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True), + batch_size=1, + train_val_test_split=(0.0, 0.0, 1.0), + ) + + train_datamodule.setup("fit") + val_datamodule.setup("fit") + test_datamodule.setup("test") + + train_dataloader = train_datamodule.train_dataloader() + val_dataloader = val_datamodule.val_dataloader() + test_dataloader = test_datamodule.test_dataloader() + + return { + "train": train_dataloader, + "val": val_dataloader, + "test": test_dataloader, + "data_module": train_datamodule, + } + + @classmethod + def get_test_train_params(cls): + """Get test train params.""" + # todo: expand test parameters + return [ + {}, + dict(d_model=16, n_heads=2, e_layers=2, d_ff=64), + dict( + d_model=32, + n_heads=4, + e_layers=3, + d_ff=128, + dropout=0.1, + data_loader_kwargs=dict( + batch_size=4, context_length=16, prediction_length=8 + ), + ), + ] diff --git a/pytorch_forecasting/models/itransformer/_itransformer_v2.py b/pytorch_forecasting/models/itransformer/_itransformer_v2.py new file mode 100644 index 000000000..10d614ea8 --- /dev/null +++ b/pytorch_forecasting/models/itransformer/_itransformer_v2.py @@ -0,0 +1,187 @@ +from typing import Any, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.optim import Optimizer + +from pytorch_forecasting.models.base._tslib_base_model_v2 import TslibBaseModel + + +class iTransformer(TslibBaseModel): + """ + An implementation of iTransformer model for v2 of pytorch-forecasting. + + Parameters + ---------- + + References + ---------- + [1] https://arxiv.org/pdf/2310.06625 + [2] https://github.com/thuml/iTransformer/blob/main/model/iTransformer.py + + Notes + ----- + + """ + + @classmethod + def _pkg(cls): + """Package containing the model.""" + from pytorch_forecasting.models.itransformer._itransformer_pkg_v2 import ( + iTransformer_pkg_v2, + ) + + return iTransformer_pkg_v2 + + def __init__( + self, + loss: nn.Module, + output_attention: bool = False, + use_norm: bool = False, + factor: int = 5, + d_model: int = 512, + d_ff: int = 2048, + activation: str = "relu", + dropout: float = 0.1, + n_heads: int = 8, + e_layers: int = 3, + embed: str = "fixed", + logging_metrics: Optional[list[nn.Module]] = None, + optimizer: Optional[Union[Optimizer, str]] = "adam", + optimizer_params: Optional[dict] = None, + lr_scheduler: Optional[str] = None, + lr_scheduler_params: Optional[dict] = None, + metadata: Optional[dict] = None, + **kwargs, + ): + super().__init__( + loss=loss, + logging_metrics=logging_metrics, + optimizer=optimizer, + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + lr_scheduler_params=lr_scheduler_params, + metadata=metadata, + ) + + self.output_attention = output_attention + self.use_norm = use_norm + self.factor = factor + self.d_model = d_model + self.d_ff = d_ff + self.activation = activation + self.dropout = dropout + self.n_heads = n_heads + self.e_layers = e_layers + self.embed = embed + self.freq = self.metadata.get("freq", "h") + + self.save_hyperparameters(ignore=["loss", "logging_metrics", "metadata"]) + + self._init_network() + + def _init_network(self): + """ + Initialize the network for iTransformer's architecture. + """ + from pytorch_forecasting.layers import ( + AttentionLayer, + DataEmbedding_inverted, + Encoder, + EncoderLayer, + FullAttention, + ) + + self.enc_embedding = DataEmbedding_inverted( + self.context_length, self.d_model, self.embed, self.freq, self.dropout + ) + + self.encoder = Encoder( + [ + EncoderLayer( + AttentionLayer( + FullAttention( + False, + self.factor, + attention_dropout=self.dropout, + output_attention=self.output_attention, + ), + self.d_model, + self.n_heads, + ), + self.d_model, + self.d_ff, + dropout=self.dropout, + activation=self.activation, + ) + for _ in range(self.e_layers) + ], + norm_layer=torch.nn.LayerNorm(self.d_model), + ) + self.projector = nn.Linear(self.d_model, self.prediction_length, bias=True) + + def _forecast(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Forward pass of the iTransformer model. + Args: + x (dict[str, torch.Tensor]): Input data. + Returns: + dict[str, torch.Tensor]: Model predictions. + """ + x_enc = x["history_target"] + x_mark_enc = x["history_cont"] + + if self.use_norm: + # Normalization from Non-stationary Transformer + means = x_enc.mean(1, keepdim=True).detach() + x_enc = x_enc - means + stdev = torch.sqrt( + torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5 + ) + x_enc /= stdev + + _, _, N = x_enc.shape # B L N + # Embedding + # B L N -> B N E + enc_out = self.enc_embedding(x_enc, x_mark_enc) # covariates (e.g timestamp) + # B N E -> B N E + # the dimensions of embedded time series has been inverted + enc_out, attns = self.encoder(enc_out, attn_mask=None) + + # B N E -> B N S -> B S N + dec_out = self.projector(enc_out).permute(0, 2, 1)[ + :, :, :N + ] # filter covariates + + if self.use_norm: + # De-Normalization from Non-stationary Transformer + dec_out = dec_out * ( + stdev[:, 0, :].unsqueeze(1).repeat(1, self.prediction_length, 1) + ) + dec_out = dec_out + ( + means[:, 0, :].unsqueeze(1).repeat(1, self.prediction_length, 1) + ) + + return dec_out, attns + + def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Forward pass of the iTransformer model. + Args: + x (dict[str, torch.Tensor]): Input data. + Returns: + dict[str, torch.Tensor]: Model predictions. + """ + dec_out, attns = self._forecast(x) + prediction = dec_out[:, -self.prediction_length :, :] + if prediction.shape[-1] == 1: + prediction = prediction.squeeze(-1) + + if "target_scale" in x: + prediction = self.transform_output(prediction, x["target_scale"]) + if self.output_attention: + return {"prediction": prediction, "attention_weights": attns} + else: + return {"prediction": prediction} From abeb908e8a22b041def4e59acba0eb85489262bd Mon Sep 17 00:00:00 2001 From: JATAYU000 Date: Fri, 28 Nov 2025 21:36:46 +0530 Subject: [PATCH 2/6] Import modules from submodules --- .../models/itransformer/__init__.py | 12 ++ .../itransformer/_itransformer_pkg_v2.py | 2 +- .../models/itransformer/_itransformer_v2.py | 10 +- .../models/itransformer/submodules.py | 171 ++++++++++++++++++ 4 files changed, 188 insertions(+), 7 deletions(-) create mode 100644 pytorch_forecasting/models/itransformer/submodules.py diff --git a/pytorch_forecasting/models/itransformer/__init__.py b/pytorch_forecasting/models/itransformer/__init__.py index eaa9f79f0..f4b15b0a9 100644 --- a/pytorch_forecasting/models/itransformer/__init__.py +++ b/pytorch_forecasting/models/itransformer/__init__.py @@ -6,8 +6,20 @@ iTransformer_pkg_v2, ) from pytorch_forecasting.models.itransformer._itransformer_v2 import iTransformer +from pytorch_forecasting.models.itransformer.submodules import ( + AttentionLayer, + DataEmbedding_inverted, + Encoder, + EncoderLayer, + FullAttention, +) __all__ = [ "iTransformer", "iTransformer_pkg_v2", + "AttentionLayer", + "DataEmbedding_inverted", + "Encoder", + "EncoderLayer", + "FullAttention", ] diff --git a/pytorch_forecasting/models/itransformer/_itransformer_pkg_v2.py b/pytorch_forecasting/models/itransformer/_itransformer_pkg_v2.py index b37e81de9..6734e8900 100644 --- a/pytorch_forecasting/models/itransformer/_itransformer_pkg_v2.py +++ b/pytorch_forecasting/models/itransformer/_itransformer_pkg_v2.py @@ -112,7 +112,7 @@ def get_test_train_params(cls): d_ff=128, dropout=0.1, data_loader_kwargs=dict( - batch_size=4, context_length=16, prediction_length=8 + batch_size=4, context_length=8, prediction_length=4 ), ), ] diff --git a/pytorch_forecasting/models/itransformer/_itransformer_v2.py b/pytorch_forecasting/models/itransformer/_itransformer_v2.py index 10d614ea8..bb989e5af 100644 --- a/pytorch_forecasting/models/itransformer/_itransformer_v2.py +++ b/pytorch_forecasting/models/itransformer/_itransformer_v2.py @@ -47,7 +47,6 @@ def __init__( dropout: float = 0.1, n_heads: int = 8, e_layers: int = 3, - embed: str = "fixed", logging_metrics: Optional[list[nn.Module]] = None, optimizer: Optional[Union[Optimizer, str]] = "adam", optimizer_params: Optional[dict] = None, @@ -75,7 +74,6 @@ def __init__( self.dropout = dropout self.n_heads = n_heads self.e_layers = e_layers - self.embed = embed self.freq = self.metadata.get("freq", "h") self.save_hyperparameters(ignore=["loss", "logging_metrics", "metadata"]) @@ -86,7 +84,7 @@ def _init_network(self): """ Initialize the network for iTransformer's architecture. """ - from pytorch_forecasting.layers import ( + from pytorch_forecasting.models.itransformer.submodules import ( AttentionLayer, DataEmbedding_inverted, Encoder, @@ -95,7 +93,7 @@ def _init_network(self): ) self.enc_embedding = DataEmbedding_inverted( - self.context_length, self.d_model, self.embed, self.freq, self.dropout + self.context_length, self.d_model, self.dropout ) self.encoder = Encoder( @@ -176,8 +174,8 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ dec_out, attns = self._forecast(x) prediction = dec_out[:, -self.prediction_length :, :] - if prediction.shape[-1] == 1: - prediction = prediction.squeeze(-1) + # if prediction.shape[-1] == 1: + # prediction = prediction.squeeze(-1) if "target_scale" in x: prediction = self.transform_output(prediction, x["target_scale"]) diff --git a/pytorch_forecasting/models/itransformer/submodules.py b/pytorch_forecasting/models/itransformer/submodules.py new file mode 100644 index 000000000..87e256ec9 --- /dev/null +++ b/pytorch_forecasting/models/itransformer/submodules.py @@ -0,0 +1,171 @@ +""" +Implementation of `nn.Modules` for iTransformer model. +""" + +from math import sqrt + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class TriangularCausalMask: + """ + Triangular causal mask for attention mechanism. + """ + + def __init__(self, B, L, device="cpu"): + mask_shape = [B, 1, L, L] + with torch.no_grad(): + self._mask = torch.triu( + torch.ones(mask_shape, dtype=torch.bool), diagonal=1 + ).to(device) + + @property + def mask(self): + return self._mask + + +class Encoder(nn.Module): + def __init__(self, attn_layers, conv_layers=None, norm_layer=None): + super().__init__() + self.attn_layers = nn.ModuleList(attn_layers) + self.conv_layers = ( + nn.ModuleList(conv_layers) if conv_layers is not None else None + ) + self.norm = norm_layer + + def forward(self, x, attn_mask=None, tau=None, delta=None): + # x [B, L, D] + attns = [] + if self.conv_layers is not None: + for i, (attn_layer, conv_layer) in enumerate( + zip(self.attn_layers, self.conv_layers) + ): + delta = delta if i == 0 else None + x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) + x = conv_layer(x) + attns.append(attn) + x, attn = self.attn_layers[-1](x, tau=tau, delta=None) + attns.append(attn) + else: + for attn_layer in self.attn_layers: + x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) + attns.append(attn) + + if self.norm is not None: + x = self.norm(x) + + return x, attns + + +class EncoderLayer(nn.Module): + def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): + super().__init__() + d_ff = d_ff or 4 * d_model + self.attention = attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, attn_mask=None, tau=None, delta=None): + new_x, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta) + x = x + self.dropout(new_x) + + y = x = self.norm1(x) + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + + return self.norm2(x + y), attn + + +class FullAttention(nn.Module): + def __init__( + self, + mask_flag=True, + factor=5, + scale=None, + attention_dropout=0.1, + output_attention=False, + ): + super().__init__() + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): + B, L, H, E = queries.shape + _, S, _, D = values.shape + scale = self.scale or 1.0 / sqrt(E) + + scores = torch.einsum("blhe,bshe->bhls", queries, keys) + + if self.mask_flag: + if attn_mask is None: + attn_mask = TriangularCausalMask(B, L, device=queries.device) + + scores.masked_fill_(attn_mask.mask, -np.inf) + + A = self.dropout(torch.softmax(scale * scores, dim=-1)) + V = torch.einsum("bhls,bshd->blhd", A, values) + + if self.output_attention: + return (V.contiguous(), A) + else: + return (V.contiguous(), None) + + +class AttentionLayer(nn.Module): + def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None): + super().__init__() + + d_keys = d_keys or (d_model // n_heads) + d_values = d_values or (d_model // n_heads) + + self.inner_attention = attention + self.query_projection = nn.Linear(d_model, d_keys * n_heads) + self.key_projection = nn.Linear(d_model, d_keys * n_heads) + self.value_projection = nn.Linear(d_model, d_values * n_heads) + self.out_projection = nn.Linear(d_values * n_heads, d_model) + self.n_heads = n_heads + + def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): + B, L, _ = queries.shape + _, S, _ = keys.shape + H = self.n_heads + + queries = self.query_projection(queries).view(B, L, H, -1) + keys = self.key_projection(keys).view(B, S, H, -1) + values = self.value_projection(values).view(B, S, H, -1) + + out, attn = self.inner_attention( + queries, keys, values, attn_mask, tau=tau, delta=delta + ) + out = out.view(B, L, -1) + + return self.out_projection(out), attn + + +class DataEmbedding_inverted(nn.Module): + def __init__(self, c_in, d_model, dropout=0.1): + super().__init__() + self.value_embedding = nn.Linear(c_in, d_model) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x, x_mark): + x = x.permute(0, 2, 1) + # x: [Batch Variate Time] + if x_mark is None: + x = self.value_embedding(x) + else: + # the potential to take covariates (e.g. timestamps) as tokens + # If they differ, convert x_mark: + x_mark = x_mark.float() + x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) + # x: [Batch Variate d_model] + return self.dropout(x) From 4d9d8090b058bfde81ecde176d314af0b5018c46 Mon Sep 17 00:00:00 2001 From: JATAYU000 Date: Sun, 30 Nov 2025 14:29:10 +0530 Subject: [PATCH 3/6] Quantile preds --- .../itransformer/_itransformer_pkg_v2.py | 2 +- .../models/itransformer/_itransformer_v2.py | 22 ++++++++++++++++--- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/pytorch_forecasting/models/itransformer/_itransformer_pkg_v2.py b/pytorch_forecasting/models/itransformer/_itransformer_pkg_v2.py index 6734e8900..a20974d4e 100644 --- a/pytorch_forecasting/models/itransformer/_itransformer_pkg_v2.py +++ b/pytorch_forecasting/models/itransformer/_itransformer_pkg_v2.py @@ -11,7 +11,7 @@ class iTransformer_pkg_v2(_BasePtForecasterV2): "authors": ["JATAYU000"], "capability:exogenous": True, "capability:multivariate": True, - "capability:pred_int": False, + "capability:pred_int": True, "capability:flexible_history_length": False, "capability:cold_start": False, } diff --git a/pytorch_forecasting/models/itransformer/_itransformer_v2.py b/pytorch_forecasting/models/itransformer/_itransformer_v2.py index bb989e5af..a80984e90 100644 --- a/pytorch_forecasting/models/itransformer/_itransformer_v2.py +++ b/pytorch_forecasting/models/itransformer/_itransformer_v2.py @@ -96,6 +96,11 @@ def _init_network(self): self.context_length, self.d_model, self.dropout ) + self.n_quantiles = None + + if hasattr(self.loss, "quantiles") and self.loss.quantiles is not None: + self.n_quantiles = len(self.loss.quantiles) + self.encoder = Encoder( [ EncoderLayer( @@ -118,7 +123,12 @@ def _init_network(self): ], norm_layer=torch.nn.LayerNorm(self.d_model), ) - self.projector = nn.Linear(self.d_model, self.prediction_length, bias=True) + if self.n_quantiles is not None: + self.projector = nn.Linear( + self.d_model, self.prediction_length * self.n_quantiles, bias=True + ) + else: + self.projector = nn.Linear(self.d_model, self.prediction_length, bias=True) def _forecast(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ @@ -173,12 +183,18 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: dict[str, torch.Tensor]: Model predictions. """ dec_out, attns = self._forecast(x) + + if self.n_quantiles is not None: + batch_size = dec_out.shape[0] + dec_out = dec_out.reshape( + batch_size, self.prediction_length, self.n_quantiles + ) + prediction = dec_out[:, -self.prediction_length :, :] - # if prediction.shape[-1] == 1: - # prediction = prediction.squeeze(-1) if "target_scale" in x: prediction = self.transform_output(prediction, x["target_scale"]) + if self.output_attention: return {"prediction": prediction, "attention_weights": attns} else: From 48cb4edb8cd031cc10dcde2dd6b2d06d0aa6d856 Mon Sep 17 00:00:00 2001 From: JATAYU000 Date: Sun, 30 Nov 2025 23:25:47 +0530 Subject: [PATCH 4/6] Added Docstrings, removed use_norm --- .../models/itransformer/__init__.py | 6 - .../models/itransformer/_itransformer_v2.py | 77 +++++++---- .../models/itransformer/submodules.py | 121 +++--------------- 3 files changed, 68 insertions(+), 136 deletions(-) diff --git a/pytorch_forecasting/models/itransformer/__init__.py b/pytorch_forecasting/models/itransformer/__init__.py index f4b15b0a9..a9d9d3da4 100644 --- a/pytorch_forecasting/models/itransformer/__init__.py +++ b/pytorch_forecasting/models/itransformer/__init__.py @@ -7,19 +7,13 @@ ) from pytorch_forecasting.models.itransformer._itransformer_v2 import iTransformer from pytorch_forecasting.models.itransformer.submodules import ( - AttentionLayer, - DataEmbedding_inverted, Encoder, EncoderLayer, - FullAttention, ) __all__ = [ "iTransformer", "iTransformer_pkg_v2", - "AttentionLayer", - "DataEmbedding_inverted", "Encoder", "EncoderLayer", - "FullAttention", ] diff --git a/pytorch_forecasting/models/itransformer/_itransformer_v2.py b/pytorch_forecasting/models/itransformer/_itransformer_v2.py index a80984e90..f4e7a7e34 100644 --- a/pytorch_forecasting/models/itransformer/_itransformer_v2.py +++ b/pytorch_forecasting/models/itransformer/_itransformer_v2.py @@ -1,9 +1,7 @@ -from typing import Any, Optional, Union +from typing import Optional, Union -import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F from torch.optim import Optimizer from pytorch_forecasting.models.base._tslib_base_model_v2 import TslibBaseModel @@ -13,8 +11,50 @@ class iTransformer(TslibBaseModel): """ An implementation of iTransformer model for v2 of pytorch-forecasting. + iTransformer repurposes the Transformer architecture by applying attention + and feed-forward networks on inverted dimensions. Instead of treating + timestamps as tokens (like traditional Transformers), iTransformer embeds + individual time series as variate tokens. The attention mechanism captures + multivariate correlations, while the feed-forward network learns nonlinear + representations for each variate. This inversion enables better handling + of long lookback windows, improved generalization across different variates, + and state-of-the-art performance on real-world forecasting tasks without + modifying the basic Transformer components. + Parameters ---------- + loss: nn.Module + Loss function to use for training. + output_attention: bool, default=False + Whether to output attention weights. + factor: int, default=5 + Factor for the attention mechanism, controlling keys and values. + d_model: int, default=512 + Dimension of the model embeddings and hidden representations. + d_ff: int, default=2048 + Dimension of the feed-forward network. + activation: str, default='relu' + Activation function to use in the feed-forward network. + dropout: float, default=0.1 + Dropout rate for regularization. + n_heads: int, default=8 + Number of attention heads in the multi-head attention mechanism. + e_layers: int, default=3 + Number of encoder layers in the transformer architecture. + logging_metrics: Optional[list[nn.Module]], default=None + List of metrics to log during training, validation, and testing. + optimizer: Optional[Union[Optimizer, str]], default='adam' + Optimizer to use for training. Can be a string name or an instance of an optimizer. + optimizer_params: Optional[dict], default=None + Parameters for the optimizer. If None, default parameters for the optimizer will be used. + lr_scheduler: Optional[str], default=None + Learning rate scheduler to use. If None, no scheduler is used. + lr_scheduler_params: Optional[dict], default=None + Parameters for the learning rate scheduler. If None, default parameters for the scheduler will be used. + metadata: Optional[dict], default=None + Metadata for the model from TslibDataModule. This can include information about the dataset, + such as the number of time steps, number of features, etc. It is used to initialize the model + and ensure it is compatible with the data being used. References ---------- @@ -23,8 +63,9 @@ class iTransformer(TslibBaseModel): Notes ----- - - """ + [1] The `iTransformer` model obtains many of its attributes from the `TslibBaseModel` class, which is a base class + where a lot of the boiler plate code for metadata handling and model initialization is implemented. + """ # noqa: E501 @classmethod def _pkg(cls): @@ -39,7 +80,6 @@ def __init__( self, loss: nn.Module, output_attention: bool = False, - use_norm: bool = False, factor: int = 5, d_model: int = 512, d_ff: int = 2048, @@ -66,7 +106,6 @@ def __init__( ) self.output_attention = output_attention - self.use_norm = use_norm self.factor = factor self.d_model = d_model self.d_ff = d_ff @@ -84,12 +123,14 @@ def _init_network(self): """ Initialize the network for iTransformer's architecture. """ - from pytorch_forecasting.models.itransformer.submodules import ( + from pytorch_forecasting.layers import ( AttentionLayer, DataEmbedding_inverted, + FullAttention, + ) + from pytorch_forecasting.models.itransformer.submodules import ( Encoder, EncoderLayer, - FullAttention, ) self.enc_embedding = DataEmbedding_inverted( @@ -141,15 +182,6 @@ def _forecast(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: x_enc = x["history_target"] x_mark_enc = x["history_cont"] - if self.use_norm: - # Normalization from Non-stationary Transformer - means = x_enc.mean(1, keepdim=True).detach() - x_enc = x_enc - means - stdev = torch.sqrt( - torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5 - ) - x_enc /= stdev - _, _, N = x_enc.shape # B L N # Embedding # B L N -> B N E @@ -163,15 +195,6 @@ def _forecast(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: :, :, :N ] # filter covariates - if self.use_norm: - # De-Normalization from Non-stationary Transformer - dec_out = dec_out * ( - stdev[:, 0, :].unsqueeze(1).repeat(1, self.prediction_length, 1) - ) - dec_out = dec_out + ( - means[:, 0, :].unsqueeze(1).repeat(1, self.prediction_length, 1) - ) - return dec_out, attns def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: diff --git a/pytorch_forecasting/models/itransformer/submodules.py b/pytorch_forecasting/models/itransformer/submodules.py index 87e256ec9..622801ac4 100644 --- a/pytorch_forecasting/models/itransformer/submodules.py +++ b/pytorch_forecasting/models/itransformer/submodules.py @@ -10,23 +10,6 @@ import torch.nn.functional as F -class TriangularCausalMask: - """ - Triangular causal mask for attention mechanism. - """ - - def __init__(self, B, L, device="cpu"): - mask_shape = [B, 1, L, L] - with torch.no_grad(): - self._mask = torch.triu( - torch.ones(mask_shape, dtype=torch.bool), diagonal=1 - ).to(device) - - @property - def mask(self): - return self._mask - - class Encoder(nn.Module): def __init__(self, attn_layers, conv_layers=None, norm_layer=None): super().__init__() @@ -83,89 +66,21 @@ def forward(self, x, attn_mask=None, tau=None, delta=None): return self.norm2(x + y), attn -class FullAttention(nn.Module): - def __init__( - self, - mask_flag=True, - factor=5, - scale=None, - attention_dropout=0.1, - output_attention=False, - ): - super().__init__() - self.scale = scale - self.mask_flag = mask_flag - self.output_attention = output_attention - self.dropout = nn.Dropout(attention_dropout) - - def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): - B, L, H, E = queries.shape - _, S, _, D = values.shape - scale = self.scale or 1.0 / sqrt(E) - - scores = torch.einsum("blhe,bshe->bhls", queries, keys) - - if self.mask_flag: - if attn_mask is None: - attn_mask = TriangularCausalMask(B, L, device=queries.device) - - scores.masked_fill_(attn_mask.mask, -np.inf) - - A = self.dropout(torch.softmax(scale * scores, dim=-1)) - V = torch.einsum("bhls,bshd->blhd", A, values) - - if self.output_attention: - return (V.contiguous(), A) - else: - return (V.contiguous(), None) - - -class AttentionLayer(nn.Module): - def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None): - super().__init__() - - d_keys = d_keys or (d_model // n_heads) - d_values = d_values or (d_model // n_heads) - - self.inner_attention = attention - self.query_projection = nn.Linear(d_model, d_keys * n_heads) - self.key_projection = nn.Linear(d_model, d_keys * n_heads) - self.value_projection = nn.Linear(d_model, d_values * n_heads) - self.out_projection = nn.Linear(d_values * n_heads, d_model) - self.n_heads = n_heads - - def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): - B, L, _ = queries.shape - _, S, _ = keys.shape - H = self.n_heads - - queries = self.query_projection(queries).view(B, L, H, -1) - keys = self.key_projection(keys).view(B, S, H, -1) - values = self.value_projection(values).view(B, S, H, -1) - - out, attn = self.inner_attention( - queries, keys, values, attn_mask, tau=tau, delta=delta - ) - out = out.view(B, L, -1) - - return self.out_projection(out), attn - - -class DataEmbedding_inverted(nn.Module): - def __init__(self, c_in, d_model, dropout=0.1): - super().__init__() - self.value_embedding = nn.Linear(c_in, d_model) - self.dropout = nn.Dropout(p=dropout) - - def forward(self, x, x_mark): - x = x.permute(0, 2, 1) - # x: [Batch Variate Time] - if x_mark is None: - x = self.value_embedding(x) - else: - # the potential to take covariates (e.g. timestamps) as tokens - # If they differ, convert x_mark: - x_mark = x_mark.float() - x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) - # x: [Batch Variate d_model] - return self.dropout(x) +# class DataEmbedding_inverted(nn.Module): +# def __init__(self, c_in, d_model, dropout=0.1): +# super().__init__() +# self.value_embedding = nn.Linear(c_in, d_model) +# self.dropout = nn.Dropout(p=dropout) + +# def forward(self, x, x_mark): +# x = x.permute(0, 2, 1) +# # x: [Batch Variate Time] +# if x_mark is None: +# x = self.value_embedding(x) +# else: +# # the potential to take covariates (e.g. timestamps) as tokens +# # If they differ, convert x_mark: +# x_mark = x_mark.float() +# x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) +# # x: [Batch Variate d_model] +# return self.dropout(x) From 171326f6ad289ef833458acc3c57d4382ac47a41 Mon Sep 17 00:00:00 2001 From: JATAYU000 Date: Mon, 1 Dec 2025 20:21:26 +0530 Subject: [PATCH 5/6] Unused imports --- .../models/itransformer/submodules.py | 24 ------------------- 1 file changed, 24 deletions(-) diff --git a/pytorch_forecasting/models/itransformer/submodules.py b/pytorch_forecasting/models/itransformer/submodules.py index 622801ac4..4f5af5320 100644 --- a/pytorch_forecasting/models/itransformer/submodules.py +++ b/pytorch_forecasting/models/itransformer/submodules.py @@ -2,10 +2,6 @@ Implementation of `nn.Modules` for iTransformer model. """ -from math import sqrt - -import numpy as np -import torch import torch.nn as nn import torch.nn.functional as F @@ -64,23 +60,3 @@ def forward(self, x, attn_mask=None, tau=None, delta=None): y = self.dropout(self.conv2(y).transpose(-1, 1)) return self.norm2(x + y), attn - - -# class DataEmbedding_inverted(nn.Module): -# def __init__(self, c_in, d_model, dropout=0.1): -# super().__init__() -# self.value_embedding = nn.Linear(c_in, d_model) -# self.dropout = nn.Dropout(p=dropout) - -# def forward(self, x, x_mark): -# x = x.permute(0, 2, 1) -# # x: [Batch Variate Time] -# if x_mark is None: -# x = self.value_embedding(x) -# else: -# # the potential to take covariates (e.g. timestamps) as tokens -# # If they differ, convert x_mark: -# x_mark = x_mark.float() -# x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) -# # x: [Batch Variate d_model] -# return self.dropout(x) From 8eccc2b5aed483eabe291fa7e068b201bb9b0583 Mon Sep 17 00:00:00 2001 From: JATAYU000 Date: Thu, 11 Dec 2025 19:07:46 +0530 Subject: [PATCH 6/6] Update Encoder layer and output attention --- .../layers/_encoders/_encoder.py | 41 +++++++---- .../layers/_encoders/_encoder_layer.py | 69 +++++++++++-------- .../models/itransformer/__init__.py | 6 -- .../itransformer/_itransformer_pkg_v2.py | 17 ++++- .../models/itransformer/_itransformer_v2.py | 22 +++--- .../models/itransformer/submodules.py | 62 ----------------- 6 files changed, 91 insertions(+), 126 deletions(-) delete mode 100644 pytorch_forecasting/models/itransformer/submodules.py diff --git a/pytorch_forecasting/layers/_encoders/_encoder.py b/pytorch_forecasting/layers/_encoders/_encoder.py index 3b54a0838..1d2c0e773 100644 --- a/pytorch_forecasting/layers/_encoders/_encoder.py +++ b/pytorch_forecasting/layers/_encoders/_encoder.py @@ -2,39 +2,50 @@ Implementation of encoder layers from `nn.Module`. """ -import math -from math import sqrt - -import numpy as np -import torch import torch.nn as nn -import torch.nn.functional as F class Encoder(nn.Module): """ - Encoder module for the TimeXer model. + Encoder module for Tslib models. Args: layers (list): List of encoder layers. norm_layer (nn.Module, optional): Normalization layer. Defaults to None. projection (nn.Module, optional): Projection layer. Defaults to None. - """ + output_attention (Boolean, optional): Whether to output attention weights. Defaults to False. + """ # noqa: E501 - def __init__(self, layers, norm_layer=None, projection=None): + def __init__( + self, layers, norm_layer=None, projection=None, output_attention=False + ): super().__init__() self.layers = nn.ModuleList(layers) self.norm = norm_layer self.projection = projection - - def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): - for layer in self.layers: - x = layer( - x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta - ) + self.output_attention = output_attention + + def forward( + self, x, cross=None, x_mask=None, cross_mask=None, tau=None, delta=None + ): + if self.output_attention: + attns = [] + for layer in self.layers: + x, attn = layer( + x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta + ) + attns.append(attn) + else: + for layer in self.layers: + x = layer( + x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta + ) if self.norm is not None: x = self.norm(x) if self.projection is not None: x = self.projection(x) + + if self.output_attention: + return x, attns return x diff --git a/pytorch_forecasting/layers/_encoders/_encoder_layer.py b/pytorch_forecasting/layers/_encoders/_encoder_layer.py index a246edc91..0e03d437c 100644 --- a/pytorch_forecasting/layers/_encoders/_encoder_layer.py +++ b/pytorch_forecasting/layers/_encoders/_encoder_layer.py @@ -2,10 +2,6 @@ Implementation of EncoderLayer for encoder-decoder architectures from `nn.Module`. """ -import math -from math import sqrt - -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -13,25 +9,27 @@ class EncoderLayer(nn.Module): """ - Encoder layer for the TimeXer model. + Encoder layer for TsLib models. Args: self_attention (nn.Module): Self-attention mechanism. - cross_attention (nn.Module): Cross-attention mechanism. + cross_attention (nn.Module, optional): Cross-attention mechanism. d_model (int): Dimension of the model. d_ff (int, optional): Dimension of the feedforward layer. Defaults to 4 * d_model. dropout (float): Dropout rate. Defaults to 0.1. activation (str): Activation function. Defaults to "relu". - """ + output_attention (Boolean, optional): Whether to output attention weights. Defaults to False. + """ # noqa: E501 def __init__( self, self_attention, - cross_attention, - d_model, + cross_attention=None, + d_model=512, d_ff=None, dropout=0.1, activation="relu", + output_attention=False, ): super().__init__() d_ff = d_ff or 4 * d_model @@ -40,34 +38,45 @@ def __init__( self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) + if self.cross_attention is not None: + self.norm2 = nn.LayerNorm(d_model) self.norm3 = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) self.activation = F.relu if activation == "relu" else F.gelu + self.output_attention = output_attention - def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): - B, L, D = cross.shape - x = x + self.dropout( - self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0] - ) - x = self.norm1(x) - - x_glb_ori = x[:, -1, :].unsqueeze(1) - x_glb = torch.reshape(x_glb_ori, (B, -1, D)) - x_glb_attn = self.dropout( - self.cross_attention( - x_glb, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta - )[0] - ) - x_glb_attn = torch.reshape( - x_glb_attn, (x_glb_attn.shape[0] * x_glb_attn.shape[1], x_glb_attn.shape[2]) - ).unsqueeze(1) - x_glb = x_glb_ori + x_glb_attn - x_glb = self.norm2(x_glb) + def forward( + self, x, cross=None, x_mask=None, cross_mask=None, tau=None, delta=None + ): + if self.output_attention: + x, attn = self.self_attention( + x, x, x, attn_mask=x_mask, tau=tau, delta=None + ) + else: + x = self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0] + x = x + self.dropout(x) + y = x = self.norm1(x) + if self.cross_attention is not None: + B, L, D = cross.shape + x_glb_ori = x[:, -1, :].unsqueeze(1) + x_glb = torch.reshape(x_glb_ori, (B, -1, D)) + x_glb_attn = self.dropout( + self.cross_attention( + x_glb, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta + )[0] + ) + x_glb_attn = torch.reshape( + x_glb_attn, + (x_glb_attn.shape[0] * x_glb_attn.shape[1], x_glb_attn.shape[2]), + ).unsqueeze(1) + x_glb = x_glb_ori + x_glb_attn + x_glb = self.norm2(x_glb) - y = x = torch.cat([x[:, :-1, :], x_glb], dim=1) + y = x = torch.cat([x[:, :-1, :], x_glb], dim=1) y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) y = self.dropout(self.conv2(y).transpose(-1, 1)) + if self.output_attention: + return self.norm3(x + y), attn return self.norm3(x + y) diff --git a/pytorch_forecasting/models/itransformer/__init__.py b/pytorch_forecasting/models/itransformer/__init__.py index a9d9d3da4..eaa9f79f0 100644 --- a/pytorch_forecasting/models/itransformer/__init__.py +++ b/pytorch_forecasting/models/itransformer/__init__.py @@ -6,14 +6,8 @@ iTransformer_pkg_v2, ) from pytorch_forecasting.models.itransformer._itransformer_v2 import iTransformer -from pytorch_forecasting.models.itransformer.submodules import ( - Encoder, - EncoderLayer, -) __all__ = [ "iTransformer", "iTransformer_pkg_v2", - "Encoder", - "EncoderLayer", ] diff --git a/pytorch_forecasting/models/itransformer/_itransformer_pkg_v2.py b/pytorch_forecasting/models/itransformer/_itransformer_pkg_v2.py index a20974d4e..dc923a313 100644 --- a/pytorch_forecasting/models/itransformer/_itransformer_pkg_v2.py +++ b/pytorch_forecasting/models/itransformer/_itransformer_pkg_v2.py @@ -101,7 +101,8 @@ def _get_test_datamodule_from(cls, trainer_kwargs): @classmethod def get_test_train_params(cls): """Get test train params.""" - # todo: expand test parameters + from pytorch_forecasting.metrics import QuantileLoss + return [ {}, dict(d_model=16, n_heads=2, e_layers=2, d_ff=64), @@ -115,4 +116,18 @@ def get_test_train_params(cls): batch_size=4, context_length=8, prediction_length=4 ), ), + dict( + hidden_size=32, + n_heads=2, + e_layers=1, + d_ff=64, + factor=2, + activation="relu", + dropout=0.05, + data_loader_kwargs=dict( + context_length=16, + prediction_length=4, + ), + loss=QuantileLoss(quantiles=[0.1, 0.5, 0.9]), + ), ] diff --git a/pytorch_forecasting/models/itransformer/_itransformer_v2.py b/pytorch_forecasting/models/itransformer/_itransformer_v2.py index f4e7a7e34..d1d1427cb 100644 --- a/pytorch_forecasting/models/itransformer/_itransformer_v2.py +++ b/pytorch_forecasting/models/itransformer/_itransformer_v2.py @@ -126,11 +126,9 @@ def _init_network(self): from pytorch_forecasting.layers import ( AttentionLayer, DataEmbedding_inverted, - FullAttention, - ) - from pytorch_forecasting.models.itransformer.submodules import ( Encoder, EncoderLayer, + FullAttention, ) self.enc_embedding = DataEmbedding_inverted( @@ -145,7 +143,7 @@ def _init_network(self): self.encoder = Encoder( [ EncoderLayer( - AttentionLayer( + self_attention=AttentionLayer( FullAttention( False, self.factor, @@ -155,8 +153,8 @@ def _init_network(self): self.d_model, self.n_heads, ), - self.d_model, - self.d_ff, + d_model=self.d_model, + d_ff=self.d_ff, dropout=self.dropout, activation=self.activation, ) @@ -188,14 +186,15 @@ def _forecast(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: enc_out = self.enc_embedding(x_enc, x_mark_enc) # covariates (e.g timestamp) # B N E -> B N E # the dimensions of embedded time series has been inverted - enc_out, attns = self.encoder(enc_out, attn_mask=None) + enc_out, attns = self.encoder(enc_out, x_mask=None) # B N E -> B N S -> B S N dec_out = self.projector(enc_out).permute(0, 2, 1)[ :, :, :N ] # filter covariates - - return dec_out, attns + if self.output_attention: + return dec_out, attns + return dec_out def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ @@ -219,6 +218,5 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: prediction = self.transform_output(prediction, x["target_scale"]) if self.output_attention: - return {"prediction": prediction, "attention_weights": attns} - else: - return {"prediction": prediction} + return {"prediction": prediction, "attention": attns} + return {"prediction": prediction} diff --git a/pytorch_forecasting/models/itransformer/submodules.py b/pytorch_forecasting/models/itransformer/submodules.py deleted file mode 100644 index 4f5af5320..000000000 --- a/pytorch_forecasting/models/itransformer/submodules.py +++ /dev/null @@ -1,62 +0,0 @@ -""" -Implementation of `nn.Modules` for iTransformer model. -""" - -import torch.nn as nn -import torch.nn.functional as F - - -class Encoder(nn.Module): - def __init__(self, attn_layers, conv_layers=None, norm_layer=None): - super().__init__() - self.attn_layers = nn.ModuleList(attn_layers) - self.conv_layers = ( - nn.ModuleList(conv_layers) if conv_layers is not None else None - ) - self.norm = norm_layer - - def forward(self, x, attn_mask=None, tau=None, delta=None): - # x [B, L, D] - attns = [] - if self.conv_layers is not None: - for i, (attn_layer, conv_layer) in enumerate( - zip(self.attn_layers, self.conv_layers) - ): - delta = delta if i == 0 else None - x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) - x = conv_layer(x) - attns.append(attn) - x, attn = self.attn_layers[-1](x, tau=tau, delta=None) - attns.append(attn) - else: - for attn_layer in self.attn_layers: - x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) - attns.append(attn) - - if self.norm is not None: - x = self.norm(x) - - return x, attns - - -class EncoderLayer(nn.Module): - def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): - super().__init__() - d_ff = d_ff or 4 * d_model - self.attention = attention - self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) - self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout = nn.Dropout(dropout) - self.activation = F.relu if activation == "relu" else F.gelu - - def forward(self, x, attn_mask=None, tau=None, delta=None): - new_x, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta) - x = x + self.dropout(new_x) - - y = x = self.norm1(x) - y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) - y = self.dropout(self.conv2(y).transpose(-1, 1)) - - return self.norm2(x + y), attn