Skip to content

Commit 50cc7e9

Browse files
committed
fix mypy error
Signed-off-by: ytl0623 <david89062388@gmail.com>
1 parent 820ce94 commit 50cc7e9

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

monai/losses/focal_loss.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,9 @@ def softmax_focal_loss(
223223
alpha_t = torch.as_tensor(alpha, device=input.device, dtype=input.dtype)
224224

225225
if alpha_t.ndim == 0: # scalar
226+
alpha_val = alpha_t.item()
226227
# (1-alpha) for the background class and alpha for the other classes
227-
alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
228+
alpha_fac = torch.tensor([1 - alpha_val] + [alpha_val] * (target.shape[1] - 1)).to(loss)
228229
else: # sequence
229230
if alpha_t.shape[0] != target.shape[1]:
230231
raise ValueError(

0 commit comments

Comments
 (0)