17 changes: 13 additions & 4 deletions bitsandbytes/cuda_specs.py
@@ -6,6 +6,12 @@
 from typing import Optional
 
 import torch
+import sys
+
+if (sys.platform == "win32"):
+    rocminfo = "hipinfo"
+else:
+    rocminfo = "rocminfo"
 
 
 @dataclasses.dataclass(frozen=True)
@@ -83,7 +89,7 @@ def get_rocm_gpu_arch() -> str:
     logger = logging.getLogger(__name__)
     try:
         if torch.version.hip:
-            result = subprocess.run(["rocminfo"], capture_output=True, text=True)
+            result = subprocess.run([rocminfo], capture_output=True, text=True)
             match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout)
             if match:
                 return "gfx" + match.group(1)
@@ -102,15 +108,18 @@ def get_rocm_gpu_arch() -> str:
         return "unknown"
 
 
+# Wavefront size (or warp size) in GPU computing is the number of threads that execute
+# together in lockstep on a GPU core, typically 32 or 64, depending on the architecture
+# (e.g., Nvidia is 32, older AMD GCN was 64, newer AMD RDNA can be 32 or 64).
 def get_rocm_warpsize() -> int:
     """Get ROCm warp size."""
     logger = logging.getLogger(__name__)
     try:
         if torch.version.hip:
-            result = subprocess.run(["rocminfo"], capture_output=True, text=True)
-            match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout)
+            result = subprocess.run([rocminfo], capture_output=True, text=True)
+            match = re.search(r"(wavefront\s|warp)size:\s+([0-9]{2})(\([x0-9]{4}\))?", result.stdout, re.IGNORECASE)
             if match:
-                return int(match.group(1))
+                return int(match.group(2))
             else:
                 # default to 64 to be safe
                 return 64
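A quick note on the updated wavefront-size parsing: the new pattern accepts both rocminfo's `Wavefront Size: 64(0x40)` style on Linux and a `warpSize:` line as hipinfo is assumed to print on Windows, with the numeric value now landing in group 2 instead of group 1. The sketch below runs the regex from this diff against two illustrative output lines; the sample strings and the 32/64 values are assumptions about each tool's format, not captured output.

```python
import re
import sys

# Binary selection mirrors the diff: hipinfo on Windows, rocminfo elsewhere.
rocminfo = "hipinfo" if sys.platform == "win32" else "rocminfo"
print(f"would invoke: {rocminfo}")

# Regex taken from the updated get_rocm_warpsize(); the size ends up in group 2.
warpsize_re = re.compile(r"(wavefront\s|warp)size:\s+([0-9]{2})(\([x0-9]{4}\))?", re.IGNORECASE)

# Illustrative output lines (assumed formats, not real tool captures).
samples = [
    "  Wavefront Size:          64(0x40)",  # rocminfo-style (Linux)
    "warpSize:                  32",        # hipinfo-style (Windows)
]

for line in samples:
    m = warpsize_re.search(line)
    # Prints 64, then 32; the real function falls back to 64 when nothing matches.
    print(int(m.group(2)) if m else 64)
```

Since group 1 now holds only the `wavefront `/`warp` prefix and the trailing `(0x40)` hex echo is an optional group 3, the return statement correctly moves from `match.group(1)` to `match.group(2)`.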