-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Weights in alpha for FocalLoss #8665
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ytl0623
wants to merge
6
commits into
Project-MONAI:dev
Choose a base branch
from
ytl0623:fix-issue-8601
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+59
−14
Open
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
88b1182
Weights in alpha for FocalLoss
ytl0623 820ce94
fix Local variable lpha_arg is assigned to but never used
ytl0623 50cc7e9
fix mypy error
ytl0623 1b24834
fix undefined type error
ytl0623 258b79d
fix alpha type bugs
ytl0623 a574d7b
fix alpha type bugs
ytl0623 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -70,19 +70,21 @@ def __init__( | |
| include_background: bool = True, | ||
| to_onehot_y: bool = False, | ||
| gamma: float = 2.0, | ||
| alpha: float | None = None, | ||
| alpha: float | Sequence[float] | None = None, | ||
| weight: Sequence[float] | float | int | torch.Tensor | None = None, | ||
| reduction: LossReduction | str = LossReduction.MEAN, | ||
| use_softmax: bool = False, | ||
| ) -> None: | ||
| """ | ||
| Args: | ||
| include_background: if False, channel index 0 (background category) is excluded from the loss calculation. | ||
| If False, `alpha` is invalid when using softmax. | ||
| If False, `alpha` is invalid when using softmax unless `alpha` is a sequence (explicit class weights). | ||
| to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False. | ||
| gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2. | ||
| alpha: value of the alpha in the definition of the alpha-balanced Focal loss. | ||
| The value should be in [0, 1]. Defaults to None. | ||
| The value should be in [0, 1]. | ||
| If a sequence is provided, it must match the number of classes (after excluding background if set). | ||
| Defaults to None. | ||
| weight: weights to apply to the voxels of each class. If None no weights are applied. | ||
| The input can be a single value (same weight for all classes), a sequence of values (the length | ||
| of the sequence should be the same as the number of classes. If not ``include_background``, | ||
|
|
@@ -110,9 +112,15 @@ def __init__( | |
| self.include_background = include_background | ||
| self.to_onehot_y = to_onehot_y | ||
| self.gamma = gamma | ||
| self.alpha = alpha | ||
| self.weight = weight | ||
| self.use_softmax = use_softmax | ||
| self.alpha: float | torch.Tensor | None | ||
| if alpha is None: | ||
| self.alpha = None | ||
| elif isinstance(alpha, (float, int)): | ||
| self.alpha = float(alpha) | ||
| else: | ||
| self.alpha = torch.as_tensor(alpha) | ||
| weight = torch.as_tensor(weight) if weight is not None else None | ||
| self.register_buffer("class_weight", weight) | ||
| self.class_weight: None | torch.Tensor | ||
|
|
@@ -156,13 +164,19 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |
| loss: Optional[torch.Tensor] = None | ||
| input = input.float() | ||
| target = target.float() | ||
|
|
||
| alpha_arg: float | torch.Tensor | None = self.alpha | ||
| if isinstance(alpha_arg, torch.Tensor): | ||
| alpha_arg = alpha_arg.to(input.device) | ||
|
|
||
| if self.use_softmax: | ||
| if not self.include_background and self.alpha is not None: | ||
| self.alpha = None | ||
| warnings.warn("`include_background=False`, `alpha` ignored when using softmax.") | ||
| loss = softmax_focal_loss(input, target, self.gamma, self.alpha) | ||
| if isinstance(self.alpha, (float, int)): | ||
| alpha_arg = None | ||
| warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add Set explicit 🔎 Proposed fix- warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.")
+ warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.", stacklevel=2)🧰 Tools🪛 Ruff (0.14.8)167-167: No explicit Set (B028) 🤖 Prompt for AI Agents |
||
| loss = softmax_focal_loss(input, target, self.gamma, alpha_arg) | ||
| else: | ||
| loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha) | ||
| loss = sigmoid_focal_loss(input, target, self.gamma, alpha_arg) | ||
|
|
||
| num_of_classes = target.shape[1] | ||
| if self.class_weight is not None and num_of_classes != 1: | ||
|
|
@@ -203,7 +217,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |
|
|
||
|
|
||
| def softmax_focal_loss( | ||
| input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None | ||
| input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None | ||
| ) -> torch.Tensor: | ||
| """ | ||
| FL(pt) = -alpha * (1 - pt)**gamma * log(pt) | ||
|
|
@@ -215,8 +229,22 @@ def softmax_focal_loss( | |
| loss: torch.Tensor = -(1 - input_ls.exp()).pow(gamma) * input_ls * target | ||
|
|
||
| if alpha is not None: | ||
| # (1-alpha) for the background class and alpha for the other classes | ||
| alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss) | ||
| if isinstance(alpha, torch.Tensor): | ||
| alpha_t = alpha.to(device=input.device, dtype=input.dtype) | ||
| else: | ||
| alpha_t = torch.tensor(alpha, device=input.device, dtype=input.dtype) | ||
|
|
||
| if alpha_t.ndim == 0: # scalar | ||
| alpha_val = alpha_t.item() | ||
| # (1-alpha) for the background class and alpha for the other classes | ||
| alpha_fac = torch.tensor([1 - alpha_val] + [alpha_val] * (target.shape[1] - 1)).to(loss) | ||
| else: # tensor (sequence) | ||
| if alpha_t.shape[0] != target.shape[1]: | ||
| raise ValueError( | ||
| f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})." | ||
| ) | ||
| alpha_fac = alpha_t | ||
|
|
||
| broadcast_dims = [-1] + [1] * len(target.shape[2:]) | ||
| alpha_fac = alpha_fac.view(broadcast_dims) | ||
| loss = alpha_fac * loss | ||
|
|
@@ -225,7 +253,7 @@ def softmax_focal_loss( | |
|
|
||
|
|
||
| def sigmoid_focal_loss( | ||
| input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None | ||
| input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None | ||
| ) -> torch.Tensor: | ||
| """ | ||
| FL(pt) = -alpha * (1 - pt)**gamma * log(pt) | ||
|
|
@@ -248,8 +276,25 @@ def sigmoid_focal_loss( | |
| loss = (invprobs * gamma).exp() * loss | ||
|
|
||
| if alpha is not None: | ||
| # alpha if t==1; (1-alpha) if t==0 | ||
| alpha_factor = target * alpha + (1 - target) * (1 - alpha) | ||
| if isinstance(alpha, torch.Tensor): | ||
| alpha_t = alpha.to(device=input.device, dtype=input.dtype) | ||
| else: | ||
| alpha_t = torch.tensor(alpha, device=input.device, dtype=input.dtype) | ||
|
|
||
| if alpha_t.ndim == 0: # scalar | ||
| # alpha if t==1; (1-alpha) if t==0 | ||
| alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t) | ||
| else: # tensor (sequence) | ||
| if alpha_t.shape[0] != target.shape[1]: | ||
| raise ValueError( | ||
| f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})." | ||
| ) | ||
| # Reshape alpha for broadcasting: (1, C, 1, 1...) | ||
| broadcast_dims = [-1] + [1] * len(target.shape[2:]) | ||
| alpha_t = alpha_t.view(broadcast_dims) | ||
| # Apply alpha_c if t==1, (1-alpha_c) if t==0 for channel c | ||
| alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t) | ||
|
|
||
| loss = alpha_factor * loss | ||
|
|
||
| return loss | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.