From 88b1182166707ed4ea7e86789d06c11765eea015 Mon Sep 17 00:00:00 2001
From: ytl0623 
Date: Fri, 19 Dec 2025 14:21:55 +0800
Subject: [PATCH 1/6] Weights in alpha for FocalLoss

Signed-off-by: ytl0623 
---
 monai/losses/focal_loss.py | 54 +++++++++++++++++++++++++++++---
 1 file changed, 41 insertions(+), 13 deletions(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index 28d1c0cdc9..ede08b970d 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -70,7 +70,7 @@ def __init__(
         include_background: bool = True,
         to_onehot_y: bool = False,
         gamma: float = 2.0,
-        alpha: float | None = None,
+        alpha: float | Sequence[float] | None = None,
         weight: Sequence[float] | float | int | torch.Tensor | None = None,
         reduction: LossReduction | str = LossReduction.MEAN,
         use_softmax: bool = False,
@@ -78,11 +78,13 @@
         """
         Args:
             include_background: if False, channel index 0 (background category) is excluded from the loss calculation.
-                If False, `alpha` is invalid when using softmax.
+                If False, `alpha` is invalid when using softmax unless `alpha` is a sequence (explicit class weights).
             to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.
             gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
             alpha: value of the alpha in the definition of the alpha-balanced Focal loss.
-                The value should be in [0, 1]. Defaults to None.
+                The value should be in [0, 1].
+                If a sequence is provided, it must match the number of classes (after excluding background if set).
+                Defaults to None.
             weight: weights to apply to the voxels of each class. If None no weights are applied.
                 The input can be a single value (same weight for all classes), a sequence of values (the length
                 of the sequence should be the same as the number of classes. If not ``include_background``,
@@ -156,13 +158,16 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         loss: Optional[torch.Tensor] = None
         input = input.float()
         target = target.float()
+
+        alpha_arg = self.alpha
         if self.use_softmax:
             if not self.include_background and self.alpha is not None:
-                self.alpha = None
-                warnings.warn("`include_background=False`, `alpha` ignored when using softmax.")
-            loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
+                if isinstance(self.alpha, (float, int)):
+                    alpha_arg = None
+                    warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.")
+            loss = softmax_focal_loss(input, target, self.gamma, self.alpha_arg)
         else:
-            loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
+            loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha_arg)
 
         num_of_classes = target.shape[1]
         if self.class_weight is not None and num_of_classes != 1:
@@ -203,7 +208,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 
 
 def softmax_focal_loss(
-    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | Sequence[float] | None = None
 ) -> torch.Tensor:
     """
     FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -215,8 +220,18 @@
     loss: torch.Tensor = -(1 - input_ls.exp()).pow(gamma) * input_ls * target
 
     if alpha is not None:
-        # (1-alpha) for the background class and alpha for the other classes
-        alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
+        alpha_t = torch.as_tensor(alpha, device=input.device, dtype=input.dtype)
+
+        if alpha_t.ndim == 0:  # scalar
+            # (1-alpha) for the background class and alpha for the other classes
+            alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
+        else:  # sequence
+            if alpha_t.shape[0] != target.shape[1]:
+                raise ValueError(
+                    f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
+                )
+            alpha_fac = alpha_t
+
         broadcast_dims = [-1] + [1] * len(target.shape[2:])
         alpha_fac = alpha_fac.view(broadcast_dims)
         loss = alpha_fac * loss
@@ -225,7 +240,7 @@
 
 
 def sigmoid_focal_loss(
-    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | Sequence[float] | None = None
 ) -> torch.Tensor:
     """
     FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -248,8 +263,21 @@
     loss = (invprobs * gamma).exp() * loss
 
     if alpha is not None:
-        # alpha if t==1; (1-alpha) if t==0
-        alpha_factor = target * alpha + (1 - target) * (1 - alpha)
+        alpha_t = torch.as_tensor(alpha, device=input.device, dtype=input.dtype)
+        if alpha_t.ndim == 0:  # scalar
+            # alpha if t==1; (1-alpha) if t==0
+            alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)
+        else:  # sequence / per-channel alpha
+            if alpha_t.shape[0] != target.shape[1]:
+                raise ValueError(
+                    f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
+                )
+            # Reshape alpha for broadcasting: (1, C, 1, 1...)
+            broadcast_dims = [-1] + [1] * len(target.shape[2:])
+            alpha_t = alpha_t.view(broadcast_dims)
+            # Apply alpha_c if t==1, (1-alpha_c) if t==0 for channel c
+            alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)
+
         loss = alpha_factor * loss
 
     return loss

From 820ce947bd3d4b7192afc951590faafc7dbea837 Mon Sep 17 00:00:00 2001
From: ytl0623 
Date: Fri, 19 Dec 2025 14:29:05 +0800
Subject: [PATCH 2/6] fix Local variable alpha_arg is assigned to but never used

Signed-off-by: ytl0623 
---
 monai/losses/focal_loss.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index ede08b970d..b000a7f8fb 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -165,9 +165,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             if isinstance(self.alpha, (float, int)):
                 alpha_arg = None
                 warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.")
-            loss = softmax_focal_loss(input, target, self.gamma, self.alpha_arg)
+            loss = softmax_focal_loss(input, target, self.gamma, alpha_arg)
         else:
-            loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha_arg)
+            loss = sigmoid_focal_loss(input, target, self.gamma, alpha_arg)
 
         num_of_classes = target.shape[1]
         if self.class_weight is not None and num_of_classes != 1:

From 50cc7e97bd2103a97c048c785dbffbdcebeb2353 Mon Sep 17 00:00:00 2001
From: ytl0623 
Date: Fri, 19 Dec 2025 14:47:14 +0800
Subject: [PATCH 3/6] fix mypy error

Signed-off-by: ytl0623 
---
 monai/losses/focal_loss.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index b000a7f8fb..fcd31a80f6 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -223,8 +223,9 @@
         alpha_t = torch.as_tensor(alpha, device=input.device, dtype=input.dtype)
 
         if alpha_t.ndim == 0:  # scalar
+            alpha_val = alpha_t.item()
             # (1-alpha) for the background class and alpha for the other classes
-            alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
+            alpha_fac = torch.tensor([1 - alpha_val] + [alpha_val] * (target.shape[1] - 1)).to(loss)
         else:  # sequence
             if alpha_t.shape[0] != target.shape[1]:
                 raise ValueError(

From 1b2483441dccbc0ee117dfdec15f4663bd4a8e73 Mon Sep 17 00:00:00 2001
From: ytl0623 
Date: Fri, 19 Dec 2025 16:03:28 +0800
Subject: [PATCH 4/6] fix undefined type error

Signed-off-by: ytl0623 
---
 monai/losses/focal_loss.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index fcd31a80f6..e463ee9e8d 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -112,12 +112,17 @@
         self.include_background = include_background
         self.to_onehot_y = to_onehot_y
         self.gamma = gamma
-        self.alpha = alpha
         self.weight = weight
         self.use_softmax = use_softmax
         weight = torch.as_tensor(weight) if weight is not None else None
         self.register_buffer("class_weight", weight)
         self.class_weight: None | torch.Tensor
+        self.alpha: float | torch.Tensor | None
+
+        if isinstance(alpha, (list, tuple)):
+            self.alpha = torch.tensor(alpha)
+        else:
+            self.alpha = alpha
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -159,7 +164,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         input = input.float()
         target = target.float()
 
-        alpha_arg = self.alpha
+        alpha_arg: float | torch.Tensor | None = self.alpha
+        if isinstance(alpha_arg, torch.Tensor):
+            alpha_arg = alpha_arg.to(input.device)
+
         if self.use_softmax:
             if not self.include_background and self.alpha is not None:
                 if isinstance(self.alpha, (float, int)):
@@ -208,7 +216,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 
 
 def softmax_focal_loss(
-    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | Sequence[float] | None = None
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None
 ) -> torch.Tensor:
     """
     FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -241,7 +249,7 @@
 
 
 def sigmoid_focal_loss(
-    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | Sequence[float] | None = None
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None
 ) -> torch.Tensor:
     """
     FL(pt) = -alpha * (1 - pt)**gamma * log(pt)

From 258b79dc8785292dce42e98d8a1eac80f96ac283 Mon Sep 17 00:00:00 2001
From: ytl0623 
Date: Fri, 19 Dec 2025 16:19:34 +0800
Subject: [PATCH 5/6] fix alpha type bugs

Signed-off-by: ytl0623 
---
 monai/losses/focal_loss.py | 23 ++++++++++++++---------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index e463ee9e8d..e3fc246c0f 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -114,15 +114,15 @@ def __init__(
         self.gamma = gamma
         self.weight = weight
         self.use_softmax = use_softmax
-        weight = torch.as_tensor(weight) if weight is not None else None
-        self.register_buffer("class_weight", weight)
-        self.class_weight: None | torch.Tensor
         self.alpha: float | torch.Tensor | None
 
         if isinstance(alpha, (list, tuple)):
             self.alpha = torch.tensor(alpha)
         else:
             self.alpha = alpha
+        weight = torch.as_tensor(weight) if weight is not None else None
+        self.register_buffer("class_weight", weight)
+        self.class_weight: None | torch.Tensor
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -165,8 +165,6 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         target = target.float()
 
         alpha_arg: float | torch.Tensor | None = self.alpha
-        if isinstance(alpha_arg, torch.Tensor):
-            alpha_arg = alpha_arg.to(input.device)
 
         if self.use_softmax:
             if not self.include_background and self.alpha is not None:
@@ -228,13 +226,16 @@ def softmax_focal_loss(
     loss: torch.Tensor = -(1 - input_ls.exp()).pow(gamma) * input_ls * target
 
     if alpha is not None:
-        alpha_t = torch.as_tensor(alpha, device=input.device, dtype=input.dtype)
+        if isinstance(alpha, torch.Tensor):
+            alpha_t = alpha.to(device=input.device, dtype=input.dtype)
+        else:
+            alpha_t = torch.tensor(alpha, device=input.device, dtype=input.dtype)
 
         if alpha_t.ndim == 0:  # scalar
             alpha_val = alpha_t.item()
             # (1-alpha) for the background class and alpha for the other classes
             alpha_fac = torch.tensor([1 - alpha_val] + [alpha_val] * (target.shape[1] - 1)).to(loss)
-        else:  # sequence
+        else:  # tensor (sequence)
             if alpha_t.shape[0] != target.shape[1]:
                 raise ValueError(
                     f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
@@ -272,11 +273,15 @@ def sigmoid_focal_loss(
     loss = (invprobs * gamma).exp() * loss
 
     if alpha is not None:
-        alpha_t = torch.as_tensor(alpha, device=input.device, dtype=input.dtype)
+        if isinstance(alpha, torch.Tensor):
+            alpha_t = alpha.to(device=input.device, dtype=input.dtype)
+        else:
+            alpha_t = torch.tensor(alpha, device=input.device, dtype=input.dtype)
+
         if alpha_t.ndim == 0:  # scalar
             # alpha if t==1; (1-alpha) if t==0
             alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)
-        else:  # sequence / per-channel alpha
+        else:  # tensor (sequence)
             if alpha_t.shape[0] != target.shape[1]:
                 raise ValueError(
                     f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."

From a574d7ba2af3a9f6c539ac8c237485208fc3066a Mon Sep 17 00:00:00 2001
From: ytl0623 
Date: Fri, 19 Dec 2025 16:36:07 +0800
Subject: [PATCH 6/6] fix alpha type bugs

Signed-off-by: ytl0623 
---
 monai/losses/focal_loss.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
index e3fc246c0f..a5d3748814 100644
--- a/monai/losses/focal_loss.py
+++ b/monai/losses/focal_loss.py
@@ -115,11 +115,12 @@ def __init__(
         self.weight = weight
         self.use_softmax = use_softmax
         self.alpha: float | torch.Tensor | None
-
-        if isinstance(alpha, (list, tuple)):
-            self.alpha = torch.tensor(alpha)
+        if alpha is None:
+            self.alpha = None
+        elif isinstance(alpha, (float, int)):
+            self.alpha = float(alpha)
         else:
-            self.alpha = alpha
+            self.alpha = torch.as_tensor(alpha)
         weight = torch.as_tensor(weight) if weight is not None else None
         self.register_buffer("class_weight", weight)
         self.class_weight: None | torch.Tensor
@@ -165,6 +166,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         target = target.float()
 
         alpha_arg: float | torch.Tensor | None = self.alpha
+        if isinstance(alpha_arg, torch.Tensor):
+            alpha_arg = alpha_arg.to(input.device)
 
         if self.use_softmax:
             if not self.include_background and self.alpha is not None:
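
Usage sketch (reviewer note, not part of the patch series above): with the final state of this branch, `alpha` is intended to accept either the original scalar or a per-class sequence whose length matches the number of channels seen by the loss. The class count, tensor shapes and alpha values below are invented for illustration:

    import torch
    from monai.losses import FocalLoss

    batch, n_classes = 2, 3
    logits = torch.randn(batch, n_classes, 32, 32)             # raw network output, shape (B, C, H, W)
    labels = torch.randint(0, n_classes, (batch, 1, 32, 32))   # integer class labels, shape (B, 1, H, W)

    # Scalar alpha (existing behaviour): alpha weights the foreground channels,
    # 1 - alpha weights the background channel in the softmax path.
    scalar_fl = FocalLoss(to_onehot_y=True, use_softmax=True, gamma=2.0, alpha=0.25)

    # Sequence alpha (behaviour added by this series): one weight per output channel;
    # its length must equal target.shape[1] or the loss raises ValueError.
    per_class_fl = FocalLoss(to_onehot_y=True, use_softmax=True, gamma=2.0, alpha=[0.25, 0.5, 0.75])

    print(scalar_fl(logits, labels).item())
    print(per_class_fl(logits, labels).item())

With `include_background=False` and `use_softmax=True`, a scalar `alpha` is still dropped with a warning, while a sequence `alpha` is kept and checked against the remaining foreground-only channels, which is the distinction the reworked warning in PATCH 1 documents.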