Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 59 additions & 14 deletions monai/losses/focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,21 @@ def __init__(
include_background: bool = True,
to_onehot_y: bool = False,
gamma: float = 2.0,
alpha: float | None = None,
alpha: float | Sequence[float] | None = None,
weight: Sequence[float] | float | int | torch.Tensor | None = None,
reduction: LossReduction | str = LossReduction.MEAN,
use_softmax: bool = False,
) -> None:
"""
Args:
include_background: if False, channel index 0 (background category) is excluded from the loss calculation.
If False, `alpha` is invalid when using softmax.
If False, `alpha` is invalid when using softmax unless `alpha` is a sequence (explicit class weights).
to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.
gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
alpha: value of the alpha in the definition of the alpha-balanced Focal loss.
The value should be in [0, 1]. Defaults to None.
The value should be in [0, 1].
If a sequence is provided, its length must match the number of classes (after excluding the background channel if `include_background` is False).
Defaults to None.
weight: weights to apply to the voxels of each class. If None no weights are applied.
The input can be a single value (same weight for all classes), a sequence of values (the length
of the sequence should be the same as the number of classes. If not ``include_background``,
Expand Down Expand Up @@ -110,9 +112,15 @@ def __init__(
self.include_background = include_background
self.to_onehot_y = to_onehot_y
self.gamma = gamma
self.alpha = alpha
self.weight = weight
self.use_softmax = use_softmax
self.alpha: float | torch.Tensor | None
if alpha is None:
self.alpha = None
elif isinstance(alpha, (float, int)):
self.alpha = float(alpha)
else:
self.alpha = torch.as_tensor(alpha)
weight = torch.as_tensor(weight) if weight is not None else None
self.register_buffer("class_weight", weight)
self.class_weight: None | torch.Tensor
Expand Down Expand Up @@ -156,13 +164,19 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
loss: Optional[torch.Tensor] = None
input = input.float()
target = target.float()

alpha_arg: float | torch.Tensor | None = self.alpha
if isinstance(alpha_arg, torch.Tensor):
alpha_arg = alpha_arg.to(input.device)

if self.use_softmax:
if not self.include_background and self.alpha is not None:
self.alpha = None
warnings.warn("`include_background=False`, `alpha` ignored when using softmax.")
loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
if isinstance(self.alpha, (float, int)):
alpha_arg = None
warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Add stacklevel=2 to warning.

Set explicit stacklevel=2 so the warning points to the caller rather than this internal line.

🔎 Proposed fix
-                    warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.")
+                    warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.", stacklevel=2)
🧰 Tools
🪛 Ruff (0.14.8)

167-167: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)

🤖 Prompt for AI Agents
In monai/losses/focal_loss.py around line 167, pass stacklevel=2 to the
warnings.warn call (e.g., warnings.warn("`include_background=False`, scalar
`alpha` ignored when using softmax.", stacklevel=2)) so that the warning is
attributed to the caller's line rather than to this internal line.

loss = softmax_focal_loss(input, target, self.gamma, alpha_arg)
else:
loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
loss = sigmoid_focal_loss(input, target, self.gamma, alpha_arg)

num_of_classes = target.shape[1]
if self.class_weight is not None and num_of_classes != 1:
Expand Down Expand Up @@ -203,7 +217,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:


def softmax_focal_loss(
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None
) -> torch.Tensor:
"""
FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
Expand All @@ -215,8 +229,22 @@ def softmax_focal_loss(
loss: torch.Tensor = -(1 - input_ls.exp()).pow(gamma) * input_ls * target

if alpha is not None:
# (1-alpha) for the background class and alpha for the other classes
alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
if isinstance(alpha, torch.Tensor):
alpha_t = alpha.to(device=input.device, dtype=input.dtype)
else:
alpha_t = torch.tensor(alpha, device=input.device, dtype=input.dtype)

if alpha_t.ndim == 0: # scalar
alpha_val = alpha_t.item()
# (1-alpha) for the background class and alpha for the other classes
alpha_fac = torch.tensor([1 - alpha_val] + [alpha_val] * (target.shape[1] - 1)).to(loss)
else: # tensor (sequence)
if alpha_t.shape[0] != target.shape[1]:
raise ValueError(
f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
)
alpha_fac = alpha_t

broadcast_dims = [-1] + [1] * len(target.shape[2:])
alpha_fac = alpha_fac.view(broadcast_dims)
loss = alpha_fac * loss
Expand All @@ -225,7 +253,7 @@ def softmax_focal_loss(


def sigmoid_focal_loss(
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None
) -> torch.Tensor:
"""
FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
Expand All @@ -248,8 +276,25 @@ def sigmoid_focal_loss(
loss = (invprobs * gamma).exp() * loss

if alpha is not None:
# alpha if t==1; (1-alpha) if t==0
alpha_factor = target * alpha + (1 - target) * (1 - alpha)
if isinstance(alpha, torch.Tensor):
alpha_t = alpha.to(device=input.device, dtype=input.dtype)
else:
alpha_t = torch.tensor(alpha, device=input.device, dtype=input.dtype)

if alpha_t.ndim == 0: # scalar
# alpha if t==1; (1-alpha) if t==0
alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)
else: # tensor (sequence)
if alpha_t.shape[0] != target.shape[1]:
raise ValueError(
f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
)
# Reshape alpha to (C, 1, 1, ...) so it broadcasts over the batch and spatial dims
broadcast_dims = [-1] + [1] * len(target.shape[2:])
alpha_t = alpha_t.view(broadcast_dims)
# Apply alpha_c if t==1, (1-alpha_c) if t==0 for channel c
alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)

loss = alpha_factor * loss

return loss
Loading