diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 3ce603c3ed2..546879c3e86 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -21,6 +21,7 @@
 import torchvision.transforms.v2 as transforms
 
 from common_utils import (
+    assert_close,
     assert_equal,
     cache,
     cpu_and_cuda,
@@ -41,7 +42,6 @@
 )
 
 from torch import nn
-from torch.testing import assert_close
 from torch.utils._pytree import tree_flatten, tree_map
 from torch.utils.data import DataLoader, default_collate
 from torchvision import tv_tensors
@@ -3778,17 +3778,17 @@ def test_kernel_image(self, dtype, device):
     @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_kernel_image_inplace(self, dtype, device):
-        input = make_image(self.INPUT_SIZE, dtype=dtype, device=device)
-        input_version = input._version
+        inpt = make_image(self.INPUT_SIZE, dtype=dtype, device=device)
+        input_version = inpt._version
 
-        output_out_of_place = F.erase_image(input, **self.FUNCTIONAL_KWARGS)
-        assert output_out_of_place.data_ptr() != input.data_ptr()
-        assert output_out_of_place is not input
+        output_out_of_place = F.erase_image(inpt, **self.FUNCTIONAL_KWARGS)
+        assert output_out_of_place.data_ptr() != inpt.data_ptr()
+        assert output_out_of_place is not inpt
 
-        output_inplace = F.erase_image(input, **self.FUNCTIONAL_KWARGS, inplace=True)
-        assert output_inplace.data_ptr() == input.data_ptr()
+        output_inplace = F.erase_image(inpt, **self.FUNCTIONAL_KWARGS, inplace=True)
+        assert output_inplace.data_ptr() == inpt.data_ptr()
         assert output_inplace._version > input_version
-        assert output_inplace is input
+        assert output_inplace is inpt
 
         assert_equal(output_inplace, output_out_of_place)
 
@@ -3797,7 +3797,15 @@ def test_kernel_video(self):
 
     @pytest.mark.parametrize(
         "make_input",
-        [make_image_tensor, make_image_pil, make_image, make_video],
+        [
+            make_image_tensor,
+            make_image_pil,
+            make_image,
+            make_video,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
     )
     def test_functional(self, make_input):
         check_functional(F.erase, make_input(), **self.FUNCTIONAL_KWARGS)
@@ -3809,25 +3817,48 @@ def test_functional(self, make_input):
             (F._augment._erase_image_pil, PIL.Image.Image),
             (F.erase_image, tv_tensors.Image),
             (F.erase_video, tv_tensors.Video),
+            pytest.param(
+                F._augment._erase_image_cvcuda,
+                None,
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
+            ),
         ],
     )
     def test_functional_signature(self, kernel, input_type):
+        if kernel is F._augment._erase_image_cvcuda:
+            input_type = _import_cvcuda().Tensor
         check_functional_kernel_signature_match(F.erase, kernel=kernel, input_type=input_type)
 
     @pytest.mark.parametrize(
         "make_input",
-        [make_image_tensor, make_image_pil, make_image, make_video],
+        [
+            make_image_tensor,
+            make_image_pil,
+            make_image,
+            make_video,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
    )
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_transform(self, make_input, device):
-        input = make_input(device=device)
+        inpt = make_input(device=device)
 
-        with pytest.warns(UserWarning, match="currently passing through inputs of type"):
+        # shouldn't get a warning for cvcuda
+        if make_input is make_image_cvcuda:
             check_transform(
                 transforms.RandomErasing(p=1),
-                input,
-                check_v1_compatibility=not isinstance(input, PIL.Image.Image),
+                inpt,
+                check_v1_compatibility=False,
             )
+        else:
+            with pytest.warns(UserWarning, match="currently passing through inputs of type"):
+                check_transform(
+                    transforms.RandomErasing(p=1),
+                    inpt,
+                    check_v1_compatibility=not isinstance(inpt, PIL.Image.Image),
+                )
 
     def _reference_erase_image(self, image, *, i, j, h, w, v):
         mask = torch.zeros_like(image, dtype=torch.bool)
@@ -3842,16 +3873,38 @@ def _reference_erase_image(self, image, *, i, j, h, w, v):
 
         return erased_image
 
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
+    )
     @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
     @pytest.mark.parametrize("device", cpu_and_cuda())
-    def test_functional_image_correctness(self, dtype, device):
-        image = make_image(dtype=dtype, device=device)
+    def test_functional_image_correctness(self, make_input, dtype, device):
+        image = make_input(dtype=dtype, device=device)
 
         actual = F.erase(image, **self.FUNCTIONAL_KWARGS)
+
+        if make_input is make_image_cvcuda:
+            image = F.cvcuda_to_tensor(image)[0].cpu()
+
         expected = self._reference_erase_image(image, **self.FUNCTIONAL_KWARGS)
 
         assert_equal(actual, expected)
 
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
+            ),
+        ],
+    )
     @param_value_parametrization(
         scale=[(0.1, 0.2), [0.0, 1.0]],
         ratio=[(0.3, 0.7), [0.1, 5.0]],
@@ -3860,10 +3913,10 @@ def test_functional_image_correctness(self, dtype, device):
     @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("seed", list(range(5)))
-    def test_transform_image_correctness(self, param, value, dtype, device, seed):
+    def test_transform_image_correctness(self, make_input, param, value, dtype, device, seed):
         transform = transforms.RandomErasing(**{param: value}, p=1)
 
-        image = make_image(dtype=dtype, device=device)
+        image = make_input(dtype=dtype, device=device)
 
         with freeze_rng_state():
             torch.manual_seed(seed)
@@ -3874,9 +3927,18 @@ def test_transform_image_correctness(self, param, value, dtype, device, seed):
             torch.manual_seed(seed)
             actual = transform(image)
 
+        if make_input is make_image_cvcuda:
+            image = F.cvcuda_to_tensor(image)[0].cpu()
+
         expected = self._reference_erase_image(image, **params)
 
-        assert_equal(actual, expected)
+        if make_input is make_image_cvcuda and value == "random":
+            # CV-CUDA doesn't have the same random distribution as torchvision:
+            # it does its own seeding, but we keep determinism by setting the
+            # seed with torch.randint in the kernel
+            assert_close(actual, expected, rtol=0, atol=256)
+        else:
+            assert_equal(actual, expected)
 
     def test_transform_errors(self):
         with pytest.raises(TypeError, match="Argument value should be either a number or str or a sequence"):
diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py
index c6da9aba98b..ccb61e57069 100644
--- a/torchvision/transforms/v2/_augment.py
+++ b/torchvision/transforms/v2/_augment.py
@@ -10,11 +10,15 @@
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision import transforms as _transforms, tv_tensors
 from torchvision.transforms.v2 import functional as F
+from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor
 
 from ._transform import _RandomApplyTransform, Transform
 from ._utils import _check_sequence_input, _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size
 
+CVCUDA_AVAILABLE = _is_cvcuda_available()
+
+
 class RandomErasing(_RandomApplyTransform):
     """Randomly select a rectangle region in the input image or video and erase its pixels.
 
@@ -48,6 +52,9 @@ class RandomErasing(_RandomApplyTransform):
 
     _v1_transform_cls = _transforms.RandomErasing
 
+    if CVCUDA_AVAILABLE:
+        _transformed_types = _RandomApplyTransform._transformed_types + (_is_cvcuda_tensor,)
+
     def _extract_params_for_v1_transform(self) -> dict[str, Any]:
         return dict(
             super()._extract_params_for_v1_transform(),
diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py
index bb6051b4e61..e803aa49c60 100644
--- a/torchvision/transforms/v2/_utils.py
+++ b/torchvision/transforms/v2/_utils.py
@@ -16,7 +16,7 @@
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
 from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
-from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
+from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor
 
 
 def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]:
@@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
     chws = {
         tuple(get_dimensions(inpt))
         for inpt in flat_inputs
-        if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
+        if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor))
     }
     if not chws:
         raise TypeError("No image or video was found in the sample")
@@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
                 tv_tensors.Mask,
                 tv_tensors.BoundingBoxes,
                 tv_tensors.KeyPoints,
+                _is_cvcuda_tensor,
             ),
         )
     }
diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py
index a904d8d7cbd..5d51c243fb4 100644
--- a/torchvision/transforms/v2/functional/_augment.py
+++ b/torchvision/transforms/v2/functional/_augment.py
@@ -1,4 +1,6 @@
 import io
+from types import SimpleNamespace
+from typing import TYPE_CHECKING
 
 import PIL.Image
 
@@ -8,7 +10,13 @@
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from torchvision.utils import _log_api_usage_once
 
-from ._utils import _get_kernel, _register_kernel_internal
+from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal
+
+
+CVCUDA_AVAILABLE = _is_cvcuda_available()
+
+if TYPE_CHECKING:
+    import cvcuda  # type: ignore[import-not-found]
 
 
 def erase(
@@ -58,6 +66,97 @@ def erase_video(
     return erase_image(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
 
 
+def _erase_image_cvcuda(
+    image: "cvcuda.Tensor",
+    i: int,
+    j: int,
+    h: int,
+    w: int,
+    v: torch.Tensor,
+    inplace: bool = False,
+) -> "cvcuda.Tensor":
+    cvcuda = _import_cvcuda()
+
+    if inplace:
+        raise ValueError("inplace is not supported for cvcuda.Tensor")
+
+    # the v tensor is a random fill if it has spatial dimensions larger than 1x1
+    is_random_fill = v.shape[-2:] != (1, 1)
+
+    # allocate the erase parameters as regular torch tensors;
+    # the mask is a bitmask selecting every channel of the image for erasing
+    mask = (1 << image.shape[3]) - 1
+    src_anchor = torch.tensor([[j, i]], dtype=torch.int32, device="cuda")
+    src_erasing = torch.tensor([[w, h, mask]], dtype=torch.int32, device="cuda")
+    src_idx = torch.tensor([0], dtype=torch.int32, device="cuda")
+
+    # allocate the fill values based on whether the fill is random or not;
+    # use zeros for random fill since we have to pass the tensor to the kernel anyway
+    if is_random_fill:
+        src_vals = torch.zeros(4, device="cuda", dtype=torch.float32)
+    # CV-CUDA requires that the fill values be a flat size-4 tensor,
+    # so we need to flatten the fill values and pad with zeros if needed
+    else:
+        v_flat = v.flatten().to(dtype=torch.float32, device="cuda")
+        if v_flat.numel() == 1:
+            src_vals = v_flat.expand(4).contiguous()
+        else:
+            if v_flat.numel() >= 4:
+                src_vals = v_flat[:4]
+            else:
+                pad_len = 4 - v_flat.numel()
+                src_vals = torch.cat([v_flat, torch.zeros(pad_len, device="cuda", dtype=torch.float32)])
+            src_vals = src_vals.contiguous()
+
+    # the simple tensors can be read directly by CV-CUDA
+    cv_imgIdx = cvcuda.as_tensor(
+        src_idx.reshape(
+            1,
+        ),
+        "N",
+    )
+    cv_values = cvcuda.as_tensor(
+        src_vals.reshape(
+            1 * 4,
+        ),
+        "N",
+    )
+
+    # packed types (_2S32, _3S32) need to be copied into pre-allocated tensors;
+    # torch does not support these packed types directly, so we create a helper function
+    # which lets torch copy into the data directly (by overriding the type/strides info)
+    def _to_torch(cv_tensor: cvcuda.Tensor, shape: tuple[int, ...], typestr: str) -> torch.Tensor:
+        iface = cv_tensor.cuda().__cuda_array_interface__
+        iface.update(shape=shape, typestr=typestr, strides=None)
+        return torch.as_tensor(SimpleNamespace(__cuda_array_interface__=iface), device="cuda")
+
+    # allocate the data for packed types
+    cv_anchor = cvcuda.Tensor((1,), cvcuda.Type._2S32, "N")
+    cv_erasing = cvcuda.Tensor((1,), cvcuda.Type._3S32, "N")
+
+    # do a memcpy with torch, pretending the data is a contiguous scalar type
+    _to_torch(cv_anchor, (1, 2), "<i4").copy_(src_anchor)
+    _to_torch(cv_erasing, (1, 3), "<i4").copy_(src_erasing)
+
+    # seed CV-CUDA's RNG from torch so that random fills stay deterministic
+    # under torch.manual_seed
+    seed = int(torch.randint(0, 2**31 - 1, size=(1,)).item()) if is_random_fill else 0
+    return cvcuda.erase(image, cv_anchor, cv_erasing, cv_values, cv_imgIdx, random=is_random_fill, seed=seed)
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(erase, _import_cvcuda().Tensor)(_erase_image_cvcuda)
 
 
 def jpeg(image: torch.Tensor, quality: int) -> torch.Tensor:
     """See :class:`~torchvision.transforms.v2.JPEG` for details."""
     if torch.jit.is_scripting():
diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py
index 6b8f19f12f4..af03ad018d4 100644
--- a/torchvision/transforms/v2/functional/_meta.py
+++ b/torchvision/transforms/v2/functional/_meta.py
@@ -51,6 +51,16 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]:
     return get_dimensions_image(video)
 
 
+def get_dimensions_image_cvcuda(image: "cvcuda.Tensor") -> list[int]:
+    # CV-CUDA tensor is always in NHWC layout
+    # get_dimensions returns CHW
+    return [image.shape[3], image.shape[1], image.shape[2]]
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(get_dimensions, cvcuda.Tensor)(get_dimensions_image_cvcuda)
+
+
 def get_num_channels(inpt: torch.Tensor) -> int:
     if torch.jit.is_scripting():
         return get_num_channels_image(inpt)
@@ -87,6 +97,16 @@ def get_num_channels_video(video: torch.Tensor) -> int:
 
 get_image_num_channels = get_num_channels
 
 
+def get_num_channels_image_cvcuda(image: "cvcuda.Tensor") -> int:
+    # CV-CUDA tensor is always in NHWC layout
+    # get_num_channels returns C
+    return image.shape[3]
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(get_num_channels, cvcuda.Tensor)(get_num_channels_image_cvcuda)
+
+
 def get_size(inpt: torch.Tensor) -> list[int]:
     if torch.jit.is_scripting():
         return get_size_image(inpt)
@@ -125,7 +145,7 @@ def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]:
 
 
 if CVCUDA_AVAILABLE:
-    _get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda)
+    _register_kernel_internal(get_size, _import_cvcuda().Tensor)(get_size_image_cvcuda)
 
 
 @_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False)
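
Usage sketch (not part of the patch): a minimal example of how the new CV-CUDA erase path would be exercised, assuming CV-CUDA is installed and a CUDA device is present. `to_cvcuda_tensor` is assumed to be the conversion helper added earlier in this series alongside `cvcuda_to_tensor`; the exact name may differ.

    import torch
    import torchvision.transforms.v2.functional as F

    # CHW uint8 image on the GPU, converted to an NHWC cvcuda.Tensor
    image = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8, device="cuda")
    cv_image = F.to_cvcuda_tensor(image)  # assumed helper from this series

    # same functional API as tensor/PIL inputs; dispatches to _erase_image_cvcuda.
    # a v with trailing (1, 1) spatial dims selects the constant-fill path
    erased = F.erase(cv_image, i=4, j=4, h=8, w=8, v=torch.zeros(3, 1, 1))

    back = F.cvcuda_to_tensor(erased)[0]  # back to a CHW torch.Tensor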