diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index b287cec057..099b786d74 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -4541,80 +4541,135 @@ def aten_index_put(
     See implementation of `torch.onnx.symbolic_opset11.index_put `_.
     """
-
-    def _make_reshape_list_broadcastable(reshape_list, values_shape):
-        # Remove ones until the rank of reshape_list matches values_shape.
-        while len(reshape_list) > len(values_shape) and 1 in reshape_list:
-            reshape_list.remove(1)
-
-        # Now ensure each dimension is broadcastable:
-        # This is mandatory when mixing basic and advanced indexing
-        # Example: data((10, 3, 4)), indices([[0, 1], :, [0, 1]]) values(2, 3)
-        # the reshape list should be : [[2, 1], [1, 3], [2, 1]]
-        for i, r in enumerate(reshape_list):
-            if r not in (1, values_shape[i]):
-                value_index = values_shape.index(r)
-                # Swap elements
-                # For the example above the current reshape list is [1, 2] for last dim,
-                # to make it broadcastable, we swap the elements
-                reshape_list[value_index], reshape_list[i] = r, 1
-
-        return reshape_list
-
-    # Ensure the number of indices matches the tensor rank.
+    # Ensure the number of indices matches the tensor rank by appending trailing Nones.
     self_rank = len(self.shape)
     if len(indices) < self_rank:
         indices = list(indices) + [None] * (self_rank - len(indices))
 
-    # Get values shape
-    values_shape = tuple(values.shape)
+    # The behavior of the op depends on whether there are advanced indices (i.e., tensor
+    # indices rather than None) and whether these advanced indices are contiguous.
+
+    # Identify advanced indices.
+    def is_advanced_index(index):
+        # Note: In this function, the index is assumed to be either None or an int64 Tensor.
+        return index is not None
+
+    advanced_indices: list[int] = []
+    none_indices: list[int] = []
+    num_advanced_indices = 0
+    num_none_indices = 0
+
+    for i, index in enumerate(indices):
+        if is_advanced_index(index):
+            advanced_indices.append(i)
+            num_advanced_indices += 1
+        elif index is None:
+            none_indices.append(i)
+            num_none_indices += 1
+        else:
+            raise ValueError(f"Unhandled index at position {i}: {index}")
 
-    index_vectors = []
-    for i in range(self_rank):
-        if indices[i] is None:
-            # For a full slice along dim i, create a range index [0, self.shape[i]).
-            idx = op.Range(0, self.shape[i], 1)
-            reshape_update = self.shape[i]
+    self_shape = op.Shape(self)
+    if num_advanced_indices == 0:
+        return op.Expand(values, self_shape)
+
+    # More than one advanced index may require broadcasting of index values.
+    if num_advanced_indices > 1:
+        # Check for the special case where all advanced indices have the same shape.
+        # But we need to ensure none of the shapes has None as a dimension, which
+        # would invalidate the equality-based check.
+        first_shape = indices[advanced_indices[0]].shape
+
+        def same_shape(other_shape: ir.Shape) -> bool:
+            return (not any(d is None for d in other_shape)) and other_shape == first_shape
+
+        all_same_shape = all(same_shape(indices[i].shape) for i in advanced_indices)
+        if not all_same_shape:
+            # Broadcast advanced indices to a common shape.
+            advanced_index_rank = max(len(indices[i].shape) for i in advanced_indices)
+            shapes = []
+            for i in advanced_indices:
+                index = indices[i]
+                index_rank = len(index.shape)
+                index_shape = op.Shape(index)
+                if index_rank < advanced_index_rank:
+                    padding = op.Constant(
+                        value_ints=[1 for _ in range(advanced_index_rank - index_rank)]
+                    )
+                    index_shape = op.Concat(padding, index_shape, axis=0)
+                shapes.append(index_shape)
+            advanced_indices_shape = op.Max(*shapes)
+            indices = [
+                op.Expand(index, advanced_indices_shape) if is_advanced_index(index) else index
+                for index in indices
+            ]
         else:
-            idx = indices[i]
-            reshape_update = math.prod(idx.shape)
-            # when Index is more than 1D, flatten it and also the values shape
-            # Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3)
-            # Indices -> (2*4,) and values shape (2*4, 32)
-            if len(idx.shape) > 1:
-                values_shape = (reshape_update, *values_shape[len(idx.shape) :])
-
-            # Flatten index (always working with 1D index in each dim)
-            idx = op.Reshape(idx, [-1])
-
-            # Create a reshape pattern: one value per index dimension,
-            # with the current dimension set to the update size.
-            reshape_list = [1] * len(indices)
-            reshape_list[i] = reshape_update
-
-            # Adjust the reshape list to match the values shape.
-            reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape)
-
-            # Reshape and expand the index.
-            idx = op.Reshape(idx, reshape_list, allowzero=True)
-            idx = op.Expand(idx, values_shape)
-
-            # Flatten the index to 1D and unsqueeze to form a column vector.
-            idx = op.Reshape(idx, [-1])
-            idx = op.Unsqueeze(idx, axes=[1])
-            index_vectors.append(idx)
-
-    # Concatenate the index vectors along axis=1 to form the final indices.
-    new_index = op.Concat(*index_vectors, axis=1)
-
-    # Flatten values to match the indices
-    flat_values = op.Reshape(values, [-1])
-
-    if accumulate:
-        result = op.ScatterND(self, new_index, flat_values, reduction="add")
+            advanced_indices_shape = op.Shape(indices[advanced_indices[0]])
+            advanced_index_rank = len(indices[advanced_indices[0]].shape)
     else:
-        result = op.ScatterND(self, new_index, flat_values)
+        advanced_indices_shape = op.Shape(indices[advanced_indices[0]])
+        advanced_index_rank = len(indices[advanced_indices[0]].shape)
+
+    # ONNX ScatterND supports only the case where all advanced indices appear first,
+    # followed by None indices. So, we need to transpose self and values so that the
+    # advanced indices appear first, and then transpose the result back to the original
+    # order at the end.
+
+    none_indices_constant = op.Constant(value_ints=none_indices)
+    none_indices_shape = op.Gather(self_shape, none_indices_constant, axis=0)
+    target_shape = op.Concat(advanced_indices_shape, none_indices_shape, axis=0)
+    target_rank = advanced_index_rank + num_none_indices
+
+    # Generate the indices tensor required by ONNX ScatterND by unsqueezing an extra dimension
+    # and concatenating all advanced indices along this new dimension.
+    minus_one = op.Constant(value_ints=[-1])
+    advanced_index_values = [op.Unsqueeze(indices[i], minus_one) for i in advanced_indices]
+    onnx_index = op.Concat(*advanced_index_values, axis=-1)
+
+    # Check if advanced indices are contiguous:
+    contiguous = True
+    if advanced_indices:
+        if advanced_indices[-1] - advanced_indices[0] + 1 != len(advanced_indices):
+            contiguous = False
+
+    # Bring advanced indices to front:
+    perm = advanced_indices + none_indices
+    transposed = op.Transpose(self, perm=perm)
+
+    # Expand values to match target shape:
+    # First, transpose values if necessary to match advanced indices order!
+    if contiguous:
+        # values may need to be transposed before expanding to target shape
+        num_padded_dims = target_rank - len(values.shape)
+        if num_padded_dims > 0:
+            unsqueezed_dims = op.Constant(value_ints=list(range(num_padded_dims)))
+            values = op.Unsqueeze(values, unsqueezed_dims)
+        initial_none_index_positions = list(range(advanced_indices[0]))
+        advanced_index_replacement_positions = list(
+            range(advanced_indices[0], advanced_indices[0] + advanced_index_rank)
+        )
+        final_none_index_positions = list(
+            range(advanced_indices[0] + advanced_index_rank, target_rank)
+        )
+        values_perm = (
+            advanced_index_replacement_positions
+            + initial_none_index_positions
+            + final_none_index_positions
+        )
+        values = op.Transpose(values, perm=values_perm)
+
+    expanded_values = op.Expand(values, target_shape)
+
+    updated = op.ScatterND(
+        transposed, onnx_index, expanded_values, reduction="add" if accumulate else None
+    )
+
+    # Inverse transpose to restore original dimension order:
+    inverse_perm = [0] * self_rank
+    for i, p in enumerate(perm):
+        inverse_perm[p] = i
+    result = op.Transpose(updated, perm=inverse_perm)
     return result
diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py
index a2ced58c44..d344723408 100644
--- a/tests/function_libs/torch_lib/e2e_ops_tests.py
+++ b/tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -1,10 +1,13 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
+from __future__ import annotations
 
-# TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo
-
+import math
 import unittest
 
+import parameterized
+
+# TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo
 import torch
 from torch.onnx._internal.exporter import _testing
 
@@ -626,6 +629,248 @@ def forward(self, x):
         )
         _testing.assert_onnx_program(onnx_program)
 
+    @parameterized.parameterized.expand(
+        [
+            # Multiple advanced indices, all 1D tensors.
+            # Non-contiguous advanced indices: updates must be broadcastable to (2, 6)
+            (
+                (6, 6, 6),
+                [[0, 1], None, [2, 3]],
+                (2, 6),
+                "non_contiguous_non_broadcast_indices_no_value_broadcast",
+            ),
+            (
+                (6, 6, 6),
+                [[0, 1], None, [2, 3]],
+                (2, 1),
+                "non_contiguous_non_broadcast_indices_expand_dim2",
+            ),
+            (
+                (6, 6, 6),
+                [[0, 1], None, [2, 3]],
+                (1, 6),
+                "non_contiguous_non_broadcast_indices_expand_dim1",
+            ),
+            (
+                (6, 6, 6),
+                [[0, 1], None, [2, 3]],
+                (6,),
+                "non_contiguous_non_broadcast_indices_new_dim1",
+            ),
+            (
+                (6, 6, 6),
+                [[0, 1], None, [2, 3]],
+                (),
+                "non_contiguous_non_broadcast_indices_scalar",
+            ),
+            # Contiguous advanced indices versions of above tests: updates must be broadcastable to (6, 2)
+            (
+                (6, 6, 6),
+                [None, [0, 1], [2, 3]],
+                (6, 2),
+                "contiguous_non_broadcast_indices_no_value_broadcast",
+            ),
+            (
+                (6, 6, 6),
+                [None, [0, 1], [2, 3]],
+                (6, 1),
+                "contiguous_non_broadcast_indices_expand_dim2",
+            ),
+            (
+                (6, 6, 6),
+                [None, [0, 1], [2, 3]],
+                (1, 2),
+                "contiguous_non_broadcast_indices_expand_dim1",
+            ),
+            (
+                (6, 6, 6),
+                [None, [0, 1], [2, 3]],
+                (2,),
+                "contiguous_non_broadcast_indices_new_dim1",
+            ),
+            ((6, 6, 6), [None, [0, 1], [2, 3]], (), "contiguous_non_broadcast_indices_scalar"),
+            # Multiple advanced indices, with broadcasting among indices.
+            # Contiguous advanced indices:
+            # This produces index tuples [(0,2), (0, 3), (1,2), (1,3)] in shape (2,2)
+            # The update values must be broadcastable to (6,2,2)
+            (
+                (6, 6, 6),
+                [None, [[0], [1]], [2, 3]],
+                (6, 2, 2),
+                "contiguous_broadcast_indices_no_value_broadcast",
+            ),
+            (
+                (6, 6, 6),
+                [None, [[0], [1]], [2, 3]],
+                (6, 1, 1),
+                "contiguous_broadcast_indices_expand_dim2_dim3",
+            ),
+            (
+                (6, 6, 6),
+                [None, [[0], [1]], [2, 3]],
+                (2,),
+                "contiguous_broadcast_indices_extend_dim1_dim2",
+            ),
+            # Non-contiguous advanced indices versions of above tests:
+            # Here, update values must be broadcastable to (2,2,6)
+            (
+                (6, 6, 6),
+                [[[0], [1]], None, [2, 3]],
+                (2, 2, 6),
+                "non_contiguous_broadcast_indices_no_value_broadcast",
+            ),
+            (
+                (6, 6, 6),
+                [[[0], [1]], None, [2, 3]],
+                (1, 1, 6),
+                "non_contiguous_broadcast_indices_expand_dim1_dim2",
+            ),
+            (
+                (6, 6, 6),
+                [[[0], [1]], None, [2, 3]],
+                (6,),
+                "non_contiguous_broadcast_indices_extend_dim1_dim2",
+            ),
+            # Other test cases
+            (
+                (4, 4, 4, 4),
+                [None, [0, 1], None, [2, 3]],
+                (2, 4, 4),
+                "non_contiguous_non_first",
+            ),
+            ((6, 6, 6), [0, None, None], (6, 6), "single_scalar_index"),
+            ((6, 6, 6), [0, None, [0, 1]], (2, 6), "non_contiguous_scalar_index_and_1d_index"),
+            ((6, 6, 6), [None, 0, [0, 1]], (6, 2), "contiguous_scalar_index_and_1d_index"),
+            # (TODO): Exporter doesn't yet support all None indices
+            # ((6, 6, 6), [None, None, None], (6, 6, 6), "all_none_indices"),
+        ]
+    )
+    def test_index_put(self, x_shape, index_list, update_shape, _: str):
+        indices = [
+            (torch.tensor(index, dtype=torch.int64) if index is not None else None)
+            for index in index_list
+        ]
+
+        class Model(torch.nn.Module):
+            def forward(self, x, update):
+                return torch.ops.aten.index_put(x, indices, update, accumulate=True)
+
+        x = torch.zeros(x_shape, dtype=torch.float32)
+        update = torch.randn(update_shape, dtype=torch.float32)
+
+        onnx_program = torch.onnx.export(
+            Model(),
+            (x, update),
+            input_names=["x", "update"],
+            output_names=["output"],
+            opset_version=18,
+            dynamo=True,
+        )
+        _testing.assert_onnx_program(onnx_program)
+
+    def test_index_put_dynamic(self):
+        for dimension in [3, 4, 2]:
+            with self.subTest(dimension=dimension):
+
+                class Model(torch.nn.Module):
+                    def __init__(self, dimension):
+                        super().__init__()
+                        self.params = torch.zeros(
+                            (4, 5)
+                            if dimension == 2
+                            else ((2, 4, 5) if dimension == 3 else (1, 1, 4, 5))
+                        )
+                        self.dimension = dimension
+
+                    def forward(self, update, index1, index2):
+                        copy = self.params.clone()
+                        if self.dimension == 2:
+                            copy[index1, index2] = update
+                        elif self.dimension == 3:
+                            copy[:, index1, index2] = update
+                        else:
+                            copy[:, :, index1, index2] = update
+                        return copy
+
+                update = (torch.arange(2) + 10).reshape((2,)).to(torch.float32)
+                index1 = torch.tensor([1, 2], dtype=torch.int64)
+                index2 = torch.tensor([3, 4], dtype=torch.int64)
+                feeds = dict(zip(["update", "index1", "index2"], (update, index1, index2)))
+                onnx_program = torch.onnx.export(
+                    Model(dimension),
+                    tuple(feeds.values()),
+                    input_names=["update", "index1", "index2"],
+                    output_names=["output"],
+                    opset_version=18,
+                    dynamo=True,
+                    dynamic_shapes={
+                        "update": {0: "dn"},
+                        "index1": {0: "dn"},
+                        "index2": {0: "dn"},
+                    },
+                )
+                _testing.assert_onnx_program(onnx_program)
+
+    def test_index_put_55_12_25(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, index, update):
+                return torch.ops.aten.index_put(x, [index], update)
+
+        x = torch.zeros((6, 5), dtype=torch.float32)
+        index = torch.tensor([[2, 1]], dtype=torch.int64)
+        update = (torch.arange(10) + 10).reshape((2, -1)).to(torch.float32)
+        onnx_program = torch.onnx.export(
+            Model(),
+            (x, index, update),
+            input_names=["x", "index", "update"],
+            output_names=["output"],
+            opset_version=18,
+            dynamo=True,
+        )
+        _testing.assert_onnx_program(onnx_program)
+
+    def test_index_put_55_2_25(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, index, update):
+                return torch.ops.aten.index_put(x, [index], update, accumulate=True)
+
+        x = torch.ones((6, 5), dtype=torch.float32)
+        index = torch.tensor([4, 3], dtype=torch.int64)
+        update = (torch.arange(10) + 10).reshape((2, -1)).to(torch.float32)
+        onnx_program = torch.onnx.export(
+            Model(),
+            (x, index, update),
+            input_names=["x", "index", "update"],
+            output_names=["output"],
+            opset_version=18,
+            dynamo=True,
+        )
+        _testing.assert_onnx_program(onnx_program)
+
+    def test_index_put_scatter_nd(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, index, update):
+                x = x.clone()
+                return torch.ops.aten.index_put(x, [None, index, None], update)
+
+        shape = (2, 3, 2)
+        N = math.prod(shape)
+        x = torch.arange(N, dtype=torch.float32).reshape(shape)
+        update = (torch.arange(N, dtype=torch.float32).reshape(shape) + 1) * 100
+        index = ((torch.arange(shape[-2])).to(torch.int64) + 1) % shape[-2]
+
+        feeds = dict(zip(["x", "index", "update"], (x, index, update)))
+        onnx_program = torch.onnx.export(
+            Model(),
+            tuple(feeds.values()),
+            input_names=["x", "index", "update"],
+            output_names=["output"],
+            opset_version=18,
+            dynamo=True,
+            dynamic_shapes=({0: "a", 1: "b", 2: "c"}, {0: "d"}, {0: "e", 1: "f", 2: "g"}),
+        )
+        _testing.assert_onnx_program(onnx_program)
+
 
 if __name__ == "__main__":
     unittest.main()
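
Not part of the diff: a minimal eager-mode PyTorch sketch of the transpose-then-scatter reformulation the new `aten_index_put` comments describe (bring the advanced-index dimensions to the front, apply the update there, then invert the permutation). The tensor names `x`, `idx0`, `idx2`, and `values` are illustrative only; the reference call mirrors the pattern used by the new tests.

# Illustrative sketch only; not part of the PR.
import torch

x = torch.zeros(6, 6, 6)
idx0 = torch.tensor([0, 1])  # advanced index on dim 0
idx2 = torch.tensor([2, 3])  # advanced index on dim 2 (non-contiguous with dim 0)
values = torch.randn(2, 6)   # broadcastable to the (2, 6) result-of-indexing shape

# Reference semantics, as exercised by the new tests.
expected = torch.ops.aten.index_put(x, [idx0, None, idx2], values, accumulate=True)

# Reformulation: advanced dims first (perm = advanced_indices + none_indices),
# update there, then invert the permutation, as the new decomposition does with ScatterND.
perm = [0, 2, 1]
scattered = x.permute(perm).clone()
scattered[idx0, idx2] += values  # advanced indices are now leading and contiguous
inverse_perm = [0] * len(perm)
for i, p in enumerate(perm):
    inverse_perm[p] = i
assert torch.allclose(scattered.permute(inverse_perm), expected)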