[Model Runner V2] Support num NaNs in logits #30187
Conversation
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Code Review
This pull request adds functionality to count the number of NaN values in the logits, which is a useful feature for monitoring model stability. The implementation introduces a new Triton kernel for this purpose.
My review includes two main points:
- The NaN counting is currently performed on processed logits. For more accurate model health monitoring, it should be done on the raw logits before any modifications.
- The new Triton kernel for counting NaNs can be replaced with a much simpler and more maintainable one-line PyTorch implementation, which is preferable for a feature that is primarily for debugging and disabled by default.
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch._inductor.runtime.triton_helpers import libdevice

from vllm.triton_utils import tl, triton


@triton.jit
def _num_nans_kernel(
    logits_ptr,
    logits_stride,
    num_nans_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per request: walk the request's row of logits in
    # BLOCK_SIZE-sized chunks and accumulate the NaN count.
    req_idx = tl.program_id(0)
    num_nans = 0
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < vocab_size
        logits = tl.load(
            logits_ptr + req_idx * logits_stride + block, mask=mask, other=0
        )
        logits = logits.to(tl.float32)
        is_nan = libdevice.isnan(logits).to(tl.int1)
        num_nans += tl.sum(is_nan).to(tl.int32)
    tl.store(num_nans_ptr + req_idx, num_nans)


def get_num_nans(logits: torch.Tensor) -> torch.Tensor:
    # Returns an int32 tensor of shape (num_reqs,) holding the per-request
    # NaN count, launching one kernel program per request.
    num_reqs, vocab_size = logits.shape
    BLOCK_SIZE = 8192
    num_nans = torch.empty(num_reqs, dtype=torch.int32, device=logits.device)
    _num_nans_kernel[(num_reqs,)](
        logits,
        logits.stride(0),
        num_nans,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return num_nans
```
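For reference, a minimal usage sketch of this helper (the tensor shape and injected NaNs are illustrative, and a CUDA device is assumed since the kernel requires one):

```python
import torch

# Hypothetical example: 4 requests over a 32k vocabulary, with NaNs injected
# into the first row to exercise the counter.
logits = torch.randn(4, 32_000, device="cuda", dtype=torch.float32)
logits[0, :10] = float("nan")

num_nans = get_num_nans(logits)
print(num_nans)  # tensor([10, 0, 0, 0], device='cuda:0', dtype=torch.int32)
```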
The Triton kernel for counting NaNs adds complexity. A simpler and more maintainable approach would be to use the built-in PyTorch function torch.isnan. A one-line implementation like torch.sum(torch.isnan(logits), dim=1, dtype=torch.int32) would achieve the same result. While the Triton kernel might offer a slight performance advantage by avoiding the materialization of an intermediate tensor, this feature is disabled by default and likely used for debugging, where simplicity and maintainability are often more important. Consider replacing the Triton kernel with the simpler PyTorch implementation.
Suggested change:

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch


def get_num_nans(logits: torch.Tensor) -> torch.Tensor:
    """Counts the number of NaNs in the logits tensor per request."""
    return torch.sum(torch.isnan(logits), dim=1, dtype=torch.int32)
```
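If the Triton kernel is kept instead, a parity check against the proposed one-liner is cheap to add in a test. A sketch, again assuming a CUDA device:

```python
import torch

# Scatter a few NaNs at random positions, then compare the two counters.
logits = torch.randn(8, 50_000, device="cuda")
logits[torch.rand_like(logits) < 1e-4] = float("nan")

triton_counts = get_num_nans(logits)  # Triton kernel version
torch_counts = torch.isnan(logits).sum(dim=1, dtype=torch.int32)  # one-liner
assert torch.equal(triton_counts, torch_counts)
```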
```python
else:
    logprobs_tensors = None

num_nans = get_num_nans(processed_logits) if self.compute_nans else None
```
The number of NaNs is calculated on processed_logits. To better detect model instability, it would be more effective to calculate this on the raw logits before any processing is applied (e.g., at the beginning of the __call__ method). The processing steps could potentially hide or introduce NaNs, making the metric less reliable for monitoring the model's direct output.
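A minimal sketch of the suggested ordering; the processing and sampling steps below are stand-ins for illustration, not vLLM's actual sampler code:

```python
import torch

def sample_with_nan_check(
    logits: torch.Tensor, compute_nans: bool
) -> tuple[torch.Tensor, torch.Tensor | None]:
    # Count NaNs on the raw model output, before any logits processing can
    # mask existing NaNs or introduce new ones.
    num_nans = get_num_nans(logits) if compute_nans else None
    processed = torch.log_softmax(logits.float(), dim=-1)  # stand-in processor
    sampled = processed.argmax(dim=-1)                     # greedy stand-in
    return sampled, num_nans
```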
```python
else:
    self.logprobs_tensors = None
if sampler_output.num_nans is not None:
    self.num_nans = sampler_output.num_nans.to("cpu", non_blocking=True)
```
As discussed, we could do the .numpy() conversion here; same for num_sampled_tokens_np.
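A sketch of what that might look like (the num_nans_np field name is an assumption, not from this diff):

```python
if sampler_output.num_nans is not None:
    # Copy to the CPU asynchronously; note that after a non_blocking copy the
    # data is only safe to read once the copying stream has synchronized.
    num_nans_cpu = sampler_output.num_nans.to("cpu", non_blocking=True)
    self.num_nans_np = num_nans_cpu.numpy()  # zero-copy view of the CPU tensor
```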