@WoosukKwon
Collaborator

No description provided.

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@mergify mergify bot added the v1 label Dec 6, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request adds functionality to count the number of NaN values in the logits, which is a useful feature for monitoring model stability. The implementation introduces a new Triton kernel for this purpose.

My review includes two main points:

  1. The NaN counting is currently performed on processed logits. For more accurate model health monitoring, it should be done on the raw logits before any modifications.
  2. The new Triton kernel for counting NaNs can be replaced with a much simpler and more maintainable one-line PyTorch implementation, which is preferable for a feature that is primarily for debugging and disabled by default.

Comment on lines +1 to +42
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch._inductor.runtime.triton_helpers import libdevice

from vllm.triton_utils import tl, triton


@triton.jit
def _num_nans_kernel(
    logits_ptr,
    logits_stride,
    num_nans_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    num_nans = 0
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < vocab_size
        logits = tl.load(
            logits_ptr + req_idx * logits_stride + block, mask=mask, other=0
        )
        logits = logits.to(tl.float32)
        is_nan = libdevice.isnan(logits).to(tl.int1)
        num_nans += tl.sum(is_nan).to(tl.int32)
    tl.store(num_nans_ptr + req_idx, num_nans)


def get_num_nans(logits: torch.Tensor) -> torch.Tensor:
    num_reqs, vocab_size = logits.shape
    BLOCK_SIZE = 8192
    num_nans = torch.empty(num_reqs, dtype=torch.int32, device=logits.device)
    _num_nans_kernel[(num_reqs,)](
        logits,
        logits.stride(0),
        num_nans,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return num_nans
Contributor

high

The Triton kernel for counting NaNs adds complexity. A simpler and more maintainable approach would be to use the built-in PyTorch function torch.isnan. A one-line implementation like torch.sum(torch.isnan(logits), dim=1, dtype=torch.int32) would achieve the same result. While the Triton kernel might offer a slight performance advantage by avoiding the materialization of an intermediate tensor, this feature is disabled by default and likely used for debugging, where simplicity and maintainability are often more important. Consider replacing the Triton kernel with the simpler PyTorch implementation.

Suggested change
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch


def get_num_nans(logits: torch.Tensor) -> torch.Tensor:
    """Counts the number of NaNs in the logits tensor per request."""
    return torch.sum(torch.isnan(logits), dim=1, dtype=torch.int32)

else:
    logprobs_tensors = None

num_nans = get_num_nans(processed_logits) if self.compute_nans else None
Contributor

high

The number of NaNs is calculated on processed_logits. To better detect model instability, it would be more effective to calculate this on the raw logits before any processing is applied (e.g., at the beginning of the __call__ method). The processing steps could potentially hide or introduce NaNs, making the metric less reliable for monitoring the model's direct output.
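
To make the concern concrete, here is a small sketch of both failure modes. The masking and temperature steps are hypothetical stand-ins for whatever processing the sampler applies; they are not taken from the vLLM code.

```python
import torch


def count_nans(logits: torch.Tensor) -> torch.Tensor:
    """Per-row NaN count, same shape of result as get_num_nans."""
    return torch.isnan(logits).sum(dim=1, dtype=torch.int32)


# Case 1: processing hides a NaN.
# A hypothetical token mask overwrites the NaN position with -inf.
raw = torch.tensor([[1.0, float("nan"), 3.0]])
banned = torch.tensor([[False, True, False]])
processed = raw.masked_fill(banned, float("-inf"))
hidden_before = count_nans(raw)
hidden_after = count_nans(processed)

# Case 2: processing introduces a NaN.
# A hypothetical degenerate temperature of 0 turns 0/0 into NaN (and 1/0 into inf).
clean = torch.tensor([[0.0, 1.0]])
temperature = torch.tensor([[0.0]])
scaled = clean / temperature
introduced_before = count_nans(clean)
introduced_after = count_nans(scaled)

print(hidden_before.tolist(), hidden_after.tolist())
print(introduced_before.tolist(), introduced_after.tolist())
```

In the first case the count drops even though the model emitted a NaN; in the second it rises even though the model output was clean. Counting on the raw logits avoids both effects.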

else:
    self.logprobs_tensors = None
if sampler_output.num_nans is not None:
    self.num_nans = sampler_output.num_nans.to("cpu", non_blocking=True)
Member

As discussed, we could call .numpy() here as well; the same applies to num_sampled_tokens_np.
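
One possible reading of this suggestion, sketched under assumptions: `copy_counts_to_host` is a hypothetical helper, not code from the PR. Calling `.numpy()` on a CPU tensor returns a view that shares memory with it, so it can be taken right after the copy is issued. With a real GPU source and `non_blocking=True`, the values should only be read after synchronizing.

```python
import torch


def copy_counts_to_host(num_nans: torch.Tensor):
    """Issue the device-to-host copy and take a NumPy view of the result."""
    host = num_nans.to("cpu", non_blocking=True)
    # Zero-copy view: shares memory with `host`, no extra transfer.
    host_np = host.numpy()
    return host, host_np


device = "cuda" if torch.cuda.is_available() else "cpu"
num_nans = torch.tensor([0, 2, 1], dtype=torch.int32, device=device)

host, host_np = copy_counts_to_host(num_nans)

# On GPU, wait for the asynchronous copy to finish before reading values.
if device == "cuda":
    torch.cuda.synchronize()

total = int(host_np.sum())
print(total)
```

Taking the view early keeps the later consumer code free of tensor-to-array conversions, at the cost of having to remember the synchronization point.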
