Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
44db71c
implement additional cvcuda infra for all branches to avoid duplicate…
justincdavis Nov 25, 2025
e3dd700
update make_image_cvcuda to have default batch dim
justincdavis Nov 25, 2025
c035df1
add stanardized setup to main for easier updating of PRs and branches
justincdavis Dec 2, 2025
98d7dfb
update is_cvcuda_tensor
justincdavis Dec 2, 2025
ddc116d
add cvcuda to pil compatible to transforms by default
justincdavis Dec 2, 2025
e51dc7e
remove cvcuda from transform class
justincdavis Dec 2, 2025
e14e210
merge with main
justincdavis Dec 4, 2025
4939355
resolve more formatting naming
justincdavis Dec 4, 2025
fbea584
update is cvcuda tensor impl
justincdavis Dec 4, 2025
0afb9cd
stash wip
justincdavis Nov 25, 2025
6691339
implement additional cvcuda infra for all branches to avoid duplicate…
justincdavis Nov 25, 2025
6521570
stash wip
justincdavis Nov 25, 2025
5b451f9
wip
justincdavis Nov 25, 2025
b8c468c
rotate passing tests
justincdavis Nov 25, 2025
8861042
update rotate to use correct logic
justincdavis Dec 2, 2025
550656f
cvcuda rotate verified correct visualizly and passing all tests
justincdavis Dec 2, 2025
5fbeac3
move transformed type check to Rotate transform
justincdavis Dec 2, 2025
ea0bdec
update rotate to main standards
justincdavis Dec 4, 2025
5aaea08
remove unneeed cvcuda refs
justincdavis Dec 4, 2025
1fc4d6d
refacotr interp into helper
justincdavis Dec 4, 2025
ec2a97c
drop more unused refs to cvcuda
justincdavis Dec 4, 2025
cfc4412
update to resize interp func
justincdavis Dec 5, 2025
1a2d572
refactor interp setup
justincdavis Dec 5, 2025
5fe3394
minimize diff
justincdavis Dec 12, 2025
06a2e2c
duplicate comment for verbosity
justincdavis Dec 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 60 additions & 7 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torchvision.transforms.v2 as transforms

from common_utils import (
assert_close,
assert_equal,
cache,
cpu_and_cuda,
Expand All @@ -41,7 +42,6 @@
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
Expand Down Expand Up @@ -2130,6 +2130,9 @@ def test_kernel_video(self):
make_segmentation_mask,
make_video,
make_keypoints,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
def test_functional(self, make_input):
Expand All @@ -2144,9 +2147,16 @@ def test_functional(self, make_input):
(F.rotate_mask, tv_tensors.Mask),
(F.rotate_video, tv_tensors.Video),
(F.rotate_keypoints, tv_tensors.KeyPoints),
pytest.param(
F._geometry._rotate_image_cvcuda,
None,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
),
],
)
def test_functional_signature(self, kernel, input_type):
if kernel is F._geometry._rotate_image_cvcuda:
input_type = _import_cvcuda().Tensor
check_functional_kernel_signature_match(F.rotate, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize(
Expand All @@ -2159,6 +2169,9 @@ def test_functional_signature(self, kernel, input_type):
make_segmentation_mask,
make_video,
make_keypoints,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
@pytest.mark.parametrize("device", cpu_and_cuda())
Expand All @@ -2174,20 +2187,40 @@ def test_transform(self, make_input, device):
)
@pytest.mark.parametrize("expand", [False, True])
@pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
def test_functional_image_correctness(self, angle, center, interpolation, expand, fill):
image = make_image(dtype=torch.uint8, device="cpu")
@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
def test_functional_image_correctness(self, angle, center, interpolation, expand, fill, make_input):
image = make_input(dtype=torch.uint8, device="cpu")

fill = adapt_fill(fill, dtype=torch.uint8)

actual = F.rotate(image, angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill)

if make_input is make_image_cvcuda:
actual = F.cvcuda_to_tensor(actual)[0].cpu()
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = F.to_image(
F.rotate(
F.to_pil_image(image), angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill
)
)

mae = (actual.float() - expected.float()).abs().mean()
assert mae < 1 if interpolation is transforms.InterpolationMode.NEAREST else 6
if make_input is make_image_cvcuda:
# CV-CUDA Interp.NEAREST is actually NEAREST_EXACT, it has no direct match
# Interp handler uses Interp.NEAREST (NEAREST_EXACT) for InterpolationMode.NEAREST
# As such, we compound error, mostly on borders where padding is added
assert mae < 122.5 if interpolation is transforms.InterpolationMode.NEAREST else 6, f"MAE: {mae}"
else:
assert mae < 1 if interpolation is transforms.InterpolationMode.NEAREST else 6, f"MAE: {mae}"

@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
@pytest.mark.parametrize(
Expand All @@ -2196,8 +2229,17 @@ def test_functional_image_correctness(self, angle, center, interpolation, expand
@pytest.mark.parametrize("expand", [False, True])
@pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
@pytest.mark.parametrize("seed", list(range(5)))
def test_transform_image_correctness(self, center, interpolation, expand, fill, seed):
image = make_image(dtype=torch.uint8, device="cpu")
@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
),
],
)
def test_transform_image_correctness(self, center, interpolation, expand, fill, seed, make_input):
image = make_input(dtype=torch.uint8, device="cpu")

fill = adapt_fill(fill, dtype=torch.uint8)

Expand All @@ -2213,10 +2255,21 @@ def test_transform_image_correctness(self, center, interpolation, expand, fill,
actual = transform(image)

torch.manual_seed(seed)

if make_input is make_image_cvcuda:
actual = F.cvcuda_to_tensor(actual)[0].cpu()
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = F.to_image(transform(F.to_pil_image(image)))

mae = (actual.float() - expected.float()).abs().mean()
assert mae < 1 if interpolation is transforms.InterpolationMode.NEAREST else 6
if make_input is make_image_cvcuda:
# CV-CUDA Interp.NEAREST is actually NEAREST_EXACT, it has no direct match
# Interp handler uses Interp.NEAREST (NEAREST_EXACT) for InterpolationMode.NEAREST
# As such, we compound error, mostly on borders where padding is added
assert mae < (122.5) if interpolation is transforms.InterpolationMode.NEAREST else 6, f"MAE: {mae}"
else:
assert mae < 1 if interpolation is transforms.InterpolationMode.NEAREST else 6, f"MAE: {mae}"

def _compute_output_canvas_size(self, *, expand, canvas_size, affine_matrix):
if not expand:
Expand Down
2 changes: 2 additions & 0 deletions torchvision/transforms/v2/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,8 @@ class RandomRotation(Transform):

_v1_transform_cls = _transforms.RandomRotation

_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)

def __init__(
self,
degrees: Union[numbers.Number, Sequence],
Expand Down
85 changes: 85 additions & 0 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from ._utils import (
_FillTypeJIT,
_get_cvcuda_interp,
_get_kernel,
_import_cvcuda,
_is_cvcuda_available,
Expand Down Expand Up @@ -1535,6 +1536,90 @@ def rotate_video(
return rotate_image(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)


def _rotate_image_cvcuda(
inpt: "cvcuda.Tensor",
angle: float,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[list[float]] = None,
fill: _FillTypeJIT = None,
) -> "cvcuda.Tensor":
cvcuda = _import_cvcuda()

angle = angle % 360

if angle == 0:
return inpt

if angle == 180:
return cvcuda.flip(inpt, flipCode=-1)

interp = _get_cvcuda_interp(interpolation)

input_height, input_width = inpt.shape[1], inpt.shape[2]
num_channels = inpt.shape[3]

if fill is None:
fill_value = [0.0] * num_channels
elif isinstance(fill, (int, float)):
fill_value = [float(fill)] * num_channels
else:
fill_value = [float(f) for f in fill]

# Determine the rotation center
# torchvision uses image center by default, cvcuda rotates around upper-left (0,0)
# We need to calculate a shift to effectively rotate around the desired center
if center is None:
cx, cy = input_width / 2.0, input_height / 2.0
center_f = [0.0, 0.0]
else:
cx, cy = float(center[0]), float(center[1])
# Convert to image-center-relative coordinates (same as torchvision)
center_f = [cx - input_width * 0.5, cy - input_height * 0.5]

angle_rad = math.radians(angle)
cos_angle = math.cos(angle_rad)
sin_angle = math.sin(angle_rad)

# if we are not expanding, simple case
if not expand:
shift_x = (1 - cos_angle) * cx - sin_angle * cy
shift_y = sin_angle * cx + (1 - cos_angle) * cy

return cvcuda.rotate(inpt, angle_deg=angle, shift=(shift_x, shift_y), interpolation=interp)

# if we need to expand, use much of the same logic as torchvision, for output size/pad
# Use center_f (image-center-relative coords) to match torchvision's output size calculation
matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
output_width, output_height = _compute_affine_output_size(matrix, input_width, input_height)

pad_left = (output_width - input_width) // 2
pad_right = output_width - input_width - pad_left
pad_top = (output_height - input_height) // 2
pad_bottom = output_height - input_height - pad_top

padded = cvcuda.copymakeborder(
inpt,
top=pad_top,
left=pad_left,
bottom=pad_bottom,
right=pad_right,
border_mode=cvcuda.Border.CONSTANT,
border_value=fill_value,
)

new_cx = pad_left + cx
new_cy = pad_top + cy
shift_x = (1 - cos_angle) * new_cx - sin_angle * new_cy
shift_y = sin_angle * new_cx + (1 - cos_angle) * new_cy

return cvcuda.rotate(padded, angle_deg=angle, shift=(shift_x, shift_y), interpolation=interp)


if CVCUDA_AVAILABLE:
_register_kernel_internal(rotate, _import_cvcuda().Tensor)(_rotate_image_cvcuda)


def pad(
inpt: torch.Tensor,
padding: list[int],
Expand Down
48 changes: 47 additions & 1 deletion torchvision/transforms/v2/functional/_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import functools
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union

import torch
from torchvision import tv_tensors
from torchvision.transforms.functional import InterpolationMode

if TYPE_CHECKING:
import cvcuda # type: ignore[import-not-found]

_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[list[float]]
Expand Down Expand Up @@ -177,3 +181,45 @@ def _is_cvcuda_tensor(inpt: Any) -> bool:
return isinstance(inpt, cvcuda.Tensor)
except ImportError:
return False


_interpolation_mode_to_cvcuda_interp: dict[InterpolationMode | str | int, "cvcuda.Interp"] = {}


def _get_cvcuda_interp(interpolation: InterpolationMode | str | int) -> "cvcuda.Interp":
"""
Get the CV-CUDA interpolation mode for a given interpolation mode.

CV-CUDA has the two following differences (evaluated in tests) comapred to TorchVision/PIL:
1. CV-CUDA does not have a match for NEAREST, its Interp.NEAREST is actually NEAREST_EXACT
Since we need to do interpolation, we will map NEAREST to Interp.NEAREST (which is NEAREST_EXACT)
2. BICUBIC interpolation method is different compared to TorchVision/PIL, algorithmic difference
"""
if len(_interpolation_mode_to_cvcuda_interp) == 0:
cvcuda = _import_cvcuda()
_interpolation_mode_to_cvcuda_interp[InterpolationMode.NEAREST] = cvcuda.Interp.NEAREST
_interpolation_mode_to_cvcuda_interp[InterpolationMode.NEAREST_EXACT] = cvcuda.Interp.NEAREST
_interpolation_mode_to_cvcuda_interp[InterpolationMode.BILINEAR] = cvcuda.Interp.LINEAR
_interpolation_mode_to_cvcuda_interp[InterpolationMode.BICUBIC] = cvcuda.Interp.CUBIC
_interpolation_mode_to_cvcuda_interp[InterpolationMode.BOX] = cvcuda.Interp.BOX
_interpolation_mode_to_cvcuda_interp[InterpolationMode.HAMMING] = cvcuda.Interp.HAMMING
_interpolation_mode_to_cvcuda_interp[InterpolationMode.LANCZOS] = cvcuda.Interp.LANCZOS
_interpolation_mode_to_cvcuda_interp["nearest"] = cvcuda.Interp.NEAREST
_interpolation_mode_to_cvcuda_interp["nearest-exact"] = cvcuda.Interp.NEAREST
_interpolation_mode_to_cvcuda_interp["bilinear"] = cvcuda.Interp.LINEAR
_interpolation_mode_to_cvcuda_interp["bicubic"] = cvcuda.Interp.CUBIC
_interpolation_mode_to_cvcuda_interp["box"] = cvcuda.Interp.BOX
_interpolation_mode_to_cvcuda_interp["hamming"] = cvcuda.Interp.HAMMING
_interpolation_mode_to_cvcuda_interp["lanczos"] = cvcuda.Interp.LANCZOS
_interpolation_mode_to_cvcuda_interp[0] = cvcuda.Interp.NEAREST
_interpolation_mode_to_cvcuda_interp[2] = cvcuda.Interp.LINEAR
_interpolation_mode_to_cvcuda_interp[3] = cvcuda.Interp.CUBIC
_interpolation_mode_to_cvcuda_interp[4] = cvcuda.Interp.BOX
_interpolation_mode_to_cvcuda_interp[5] = cvcuda.Interp.HAMMING
_interpolation_mode_to_cvcuda_interp[1] = cvcuda.Interp.LANCZOS

interp = _interpolation_mode_to_cvcuda_interp.get(interpolation)
if interp is None:
raise ValueError(f"Interpolation mode {interpolation} is not supported with CV-CUDA")

return interp