We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 820ce94 commit 50cc7e9Copy full SHA for 50cc7e9
monai/losses/focal_loss.py
@@ -223,8 +223,9 @@ def softmax_focal_loss(
223
alpha_t = torch.as_tensor(alpha, device=input.device, dtype=input.dtype)
224
225
if alpha_t.ndim == 0: # scalar
226
+ alpha_val = alpha_t.item()
227
# (1-alpha) for the background class and alpha for the other classes
- alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
228
+ alpha_fac = torch.tensor([1 - alpha_val] + [alpha_val] * (target.shape[1] - 1)).to(loss)
229
else: # sequence
230
if alpha_t.shape[0] != target.shape[1]:
231
raise ValueError(
0 commit comments