From 69ac4c3d485122ffde60e53b4e80ef2764fa8c15 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 17 Nov 2025 21:34:38 -0800 Subject: [PATCH 1/6] First version of aten index put Signed-off-by: Ganesan Ramalingam --- .../function_libs/torch_lib/ops/core.py | 247 +++++++++++++----- 1 file changed, 181 insertions(+), 66 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 326075b2fe..6a25cc17b6 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4432,81 +4432,196 @@ def aten_index_put( See implementation of `torch.onnx.symbolic_opset11.index_put `_. """ - - def _make_reshape_list_broadcastable(reshape_list, values_shape): - # Remove ones until the rank of reshape_list matches values_shape. - while len(reshape_list) > len(values_shape) and 1 in reshape_list: - reshape_list.remove(1) - - # Now ensure each dimension is broadcastable: - # This is mandatory when mixing basic and advanced indexing - # Example: data((10, 3, 4)), indices([[0, 1], :, [0, 1]]) values(2, 3) - # the reshape list should be : [[2, 1], [1, 3], [2, 1]] - for i, r in enumerate(reshape_list): - if r not in (1, values_shape[i]): - value_index = values_shape.index(r) - # Swap elements - # For the example above the current reshape list is [1, 2] for last dim, - # to make it broadcastable, we swap the elements - reshape_list[value_index], reshape_list[i] = r, 1 - - return reshape_list - # Ensure the number of indices matches the tensor rank. self_rank = len(self.shape) if len(indices) < self_rank: indices = list(indices) + [None] * (self_rank - len(indices)) - # Get values shape - values_shape = tuple(values.shape) + # Identify advanced indices. + def is_advanced_index(index): + return index is not None and len(index.shape) > 0 + + def is_scalar_index(index): + return index is not None and len(index.shape) == 0 + + scalar_indices : list[int] = [] + advanced_indices : list[int] = [] + none_indices : list[int] = [] + num_advanced_indices = 0 + num_scalar_indices = 0 + num_none_indices = 0 + + for i, index in enumerate(indices): + if is_advanced_index(index): + advanced_indices.append(i) + num_advanced_indices += 1 + elif is_scalar_index(index): + scalar_indices.append(i) + num_scalar_indices += 1 + elif index is None: + none_indices.append(i) + num_none_indices += 1 + else: + raise ValueError(f"Unhandled index at position {i}: {index}") + + if num_scalar_indices > 0: + raise NotImplementedError("Scalar indices not yet supported in aten_index_put.") - index_vectors = [] - for i in range(self_rank): - if indices[i] is None: - # For a full slice along dim i, create a range index [0, self.shape[i]). - idx = op.Range(0, self.shape[i], 1) - reshape_update = self.shape[i] + self_shape = op.Shape(self) + if num_advanced_indices == 0: + return op.Expand(values, self_shape) + + # More than one advanced index may require broadcasting of index values + if num_advanced_indices > 1: + # Check for special case where all advanced indices have same shape. + # But need to ensure none of the shapes have None as a dimension, which + # will invalidate equality-based check. + first_shape = indices[advanced_indices[0]].shape + def same_shape(other_shape: ir.Shape) -> bool: + return (not any(d is None for d in other_shape)) and other_shape == first_shape + all_same_shape = all(same_shape(indices[i].shape) for i in advanced_indices) + if not all_same_shape: + # Broadcast advanced indices to a common shape. + advanced_index_rank = max(len(indices[i].shape) for i in advanced_indices) + shapes = [] + for i in advanced_indices: + index = indices[i] + index_rank = len(index.shape) + index_shape = op.Shape(index) + if index_rank < advanced_index_rank: + padding = op.Constant(value_ints=[1 for _ in range(advanced_index_rank - index_rank)]) + index_shape = op.Concat(padding, index_shape, axis=0) + shapes.append(index_shape) + advanced_indices_shape = op.Max(*shapes) + indices = [op.Expand(index, advanced_indices_shape) if is_advanced_index(index) else index for index in indices] else: - idx = indices[i] - reshape_update = math.prod(idx.shape) - # when Index is more than 1D, flatten it and also the values shape - # Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3) - # Indices -> (2*4,) and values shape (2*4, 32) - if len(idx.shape) > 1: - values_shape = (reshape_update, *values_shape[len(idx.shape) :]) - - # Flatten index (always working with 1D index in each dim) - idx = op.Reshape(idx, [-1]) - - # Create a reshape pattern: one value per index dimension, - # with the current dimension set to the update size. - reshape_list = [1] * len(indices) - reshape_list[i] = reshape_update - - # Adjust the reshape list to match the values shape. - reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape) - - # Reshape and expand the index. - idx = op.Reshape(idx, reshape_list, allowzero=True) - idx = op.Expand(idx, values_shape) - - # Flatten the index to 1D and unsqueeze to form a column vector. - idx = op.Reshape(idx, [-1]) - idx = op.Unsqueeze(idx, axes=[1]) - index_vectors.append(idx) - - # Concatenate the index vectors along axis=1 to form the final indices. - new_index = op.Concat(*index_vectors, axis=1) - - # Flatten values to match the indices - flat_values = op.Reshape(values, [-1]) - - if accumulate: - result = op.ScatterND(self, new_index, flat_values, reduction="add") + advanced_indices_shape = op.Shape(indices[advanced_indices[0]]) + advanced_index_rank = len(indices[advanced_indices[0]].shape) else: - result = op.ScatterND(self, new_index, flat_values) - + advanced_indices_shape = op.Shape(indices[advanced_indices[0]]) + advanced_index_rank = len(indices[advanced_indices[0]].shape) + + none_indices_constant = op.Constant(value_ints=none_indices) + none_indices_shape = op.Gather(self_shape, none_indices_constant, axis=0) + target_shape = op.Concat(advanced_indices_shape, none_indices_shape, axis=0) + target_rank = advanced_index_rank + num_none_indices + + # Generate indices tensor required by ONNX ScatterND by unsqueezing an extra dimension and + # concatenating all advanced indices along this new dimension. + minus_one = op.Constant(value_ints=[-1]) + advanced_index_values = [op.Unsqueeze(indices[i], minus_one) for i in advanced_indices] + onnx_index = op.Concat(*advanced_index_values, axis=-1) + + # Check if advanced indices are contiguous: + non_contiguous = False + if advanced_indices: + if advanced_indices[-1] - advanced_indices[0] + 1 != len(advanced_indices): + non_contiguous = True + + if non_contiguous: + raise NotImplementedError("Non-contiguous advanced indices not yet supported in aten_index_put.") + + # Bring advanced indices to front: + perm = advanced_indices + none_indices + transposed = op.Transpose(self, perm=perm) + + # Expand values to match target shape: + # First, transpose values if necessary to match advanced indices order! + num_padded_dims = target_rank - len(values.shape) + if num_padded_dims > 0: + unsqueezed_dims = op.Constant(value_ints=list(range(num_padded_dims))) + values = op.Unsqueeze(values, unsqueezed_dims) + initial_none_index_positions = list(range(0, advanced_indices[0])) + advanced_index_replacement_positions = list(range(advanced_indices[0], advanced_indices[0] + advanced_index_rank)) + final_none_index_positions = list(range(advanced_indices[0] + advanced_index_rank, target_rank)) + values_perm = advanced_index_replacement_positions + initial_none_index_positions + final_none_index_positions + transposed_values = op.Transpose(values, perm=values_perm) + expanded_values = op.Expand(transposed_values, target_shape) + + updated = op.ScatterND(transposed, onnx_index, expanded_values, reduction="add" if accumulate else None) + + # Inverse transpose to restore original dimension order: + inverse_perm = [0] * self_rank + for i, p in enumerate(perm): + inverse_perm[p] = i + result = op.Transpose(updated, perm=inverse_perm) return result + + + # def _make_reshape_list_broadcastable(reshape_list, values_shape): + # # Remove ones until the rank of reshape_list matches values_shape. + # while len(reshape_list) > len(values_shape) and 1 in reshape_list: + # reshape_list.remove(1) + + # # Now ensure each dimension is broadcastable: + # # This is mandatory when mixing basic and advanced indexing + # # Example: data((10, 3, 4)), indices([[0, 1], :, [0, 1]]) values(2, 3) + # # the reshape list should be : [[2, 1], [1, 3], [2, 1]] + # for i, r in enumerate(reshape_list): + # if r not in (1, values_shape[i]): + # value_index = values_shape.index(r) + # # Swap elements + # # For the example above the current reshape list is [1, 2] for last dim, + # # to make it broadcastable, we swap the elements + # reshape_list[value_index], reshape_list[i] = r, 1 + + # return reshape_list + + # # Ensure the number of indices matches the tensor rank. + # self_rank = len(self.shape) + # if len(indices) < self_rank: + # indices = list(indices) + [None] * (self_rank - len(indices)) + + # # Get values shape + # values_shape = tuple(values.shape) + + # index_vectors = [] + # for i in range(self_rank): + # if indices[i] is None: + # # For a full slice along dim i, create a range index [0, self.shape[i]). + # idx = op.Range(0, self.shape[i], 1) + # reshape_update = self.shape[i] + # else: + # idx = indices[i] + # reshape_update = math.prod(idx.shape) + # # when Index is more than 1D, flatten it and also the values shape + # # Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3) + # # Indices -> (2*4,) and values shape (2*4, 32) + # if len(idx.shape) > 1: + # values_shape = (reshape_update, *values_shape[len(idx.shape) :]) + + # # Flatten index (always working with 1D index in each dim) + # idx = op.Reshape(idx, [-1]) + + # # Create a reshape pattern: one value per index dimension, + # # with the current dimension set to the update size. + # reshape_list = [1] * len(indices) + # reshape_list[i] = reshape_update + + # # Adjust the reshape list to match the values shape. + # reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape) + + # # Reshape and expand the index. + # idx = op.Reshape(idx, reshape_list, allowzero=True) + # idx = op.Expand(idx, values_shape) + + # # Flatten the index to 1D and unsqueeze to form a column vector. + # idx = op.Reshape(idx, [-1]) + # idx = op.Unsqueeze(idx, axes=[1]) + # index_vectors.append(idx) + + # # Concatenate the index vectors along axis=1 to form the final indices. + # new_index = op.Concat(*index_vectors, axis=1) + + # # Flatten values to match the indices + # flat_values = op.Reshape(values, [-1]) + + # if accumulate: + # result = op.ScatterND(self, new_index, flat_values, reduction="add") + # else: + # result = op.ScatterND(self, new_index, flat_values) + + # return result @torch_op("aten::index_put", trace_only=True) From 1c902a47b9ef1c3086c73eb021b8340c2611ef0a Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 17 Nov 2025 22:15:45 -0800 Subject: [PATCH 2/6] Aten index put Signed-off-by: Ganesan Ramalingam --- .../function_libs/torch_lib/ops/core.py | 69 ++++++++++++------- 1 file changed, 43 insertions(+), 26 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 6a25cc17b6..7e2a242c43 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4437,16 +4437,16 @@ def aten_index_put( if len(indices) < self_rank: indices = list(indices) + [None] * (self_rank - len(indices)) - # Identify advanced indices. + # Identify advanced indices. def is_advanced_index(index): return index is not None and len(index.shape) > 0 - + def is_scalar_index(index): return index is not None and len(index.shape) == 0 - scalar_indices : list[int] = [] - advanced_indices : list[int] = [] - none_indices : list[int] = [] + scalar_indices: list[int] = [] + advanced_indices: list[int] = [] + none_indices: list[int] = [] num_advanced_indices = 0 num_scalar_indices = 0 num_none_indices = 0 @@ -4470,15 +4470,17 @@ def is_scalar_index(index): self_shape = op.Shape(self) if num_advanced_indices == 0: return op.Expand(values, self_shape) - + # More than one advanced index may require broadcasting of index values if num_advanced_indices > 1: # Check for special case where all advanced indices have same shape. # But need to ensure none of the shapes have None as a dimension, which # will invalidate equality-based check. first_shape = indices[advanced_indices[0]].shape + def same_shape(other_shape: ir.Shape) -> bool: return (not any(d is None for d in other_shape)) and other_shape == first_shape + all_same_shape = all(same_shape(indices[i].shape) for i in advanced_indices) if not all_same_shape: # Broadcast advanced indices to a common shape. @@ -4489,11 +4491,16 @@ def same_shape(other_shape: ir.Shape) -> bool: index_rank = len(index.shape) index_shape = op.Shape(index) if index_rank < advanced_index_rank: - padding = op.Constant(value_ints=[1 for _ in range(advanced_index_rank - index_rank)]) + padding = op.Constant( + value_ints=[1 for _ in range(advanced_index_rank - index_rank)] + ) index_shape = op.Concat(padding, index_shape, axis=0) shapes.append(index_shape) advanced_indices_shape = op.Max(*shapes) - indices = [op.Expand(index, advanced_indices_shape) if is_advanced_index(index) else index for index in indices] + indices = [ + op.Expand(index, advanced_indices_shape) if is_advanced_index(index) else index + for index in indices + ] else: advanced_indices_shape = op.Shape(indices[advanced_indices[0]]) advanced_index_rank = len(indices[advanced_indices[0]].shape) @@ -4513,13 +4520,10 @@ def same_shape(other_shape: ir.Shape) -> bool: onnx_index = op.Concat(*advanced_index_values, axis=-1) # Check if advanced indices are contiguous: - non_contiguous = False + contiguous = True if advanced_indices: if advanced_indices[-1] - advanced_indices[0] + 1 != len(advanced_indices): - non_contiguous = True - - if non_contiguous: - raise NotImplementedError("Non-contiguous advanced indices not yet supported in aten_index_put.") + contiguous = False # Bring advanced indices to front: perm = advanced_indices + none_indices @@ -4527,26 +4531,39 @@ def same_shape(other_shape: ir.Shape) -> bool: # Expand values to match target shape: # First, transpose values if necessary to match advanced indices order! - num_padded_dims = target_rank - len(values.shape) - if num_padded_dims > 0: - unsqueezed_dims = op.Constant(value_ints=list(range(num_padded_dims))) - values = op.Unsqueeze(values, unsqueezed_dims) - initial_none_index_positions = list(range(0, advanced_indices[0])) - advanced_index_replacement_positions = list(range(advanced_indices[0], advanced_indices[0] + advanced_index_rank)) - final_none_index_positions = list(range(advanced_indices[0] + advanced_index_rank, target_rank)) - values_perm = advanced_index_replacement_positions + initial_none_index_positions + final_none_index_positions - transposed_values = op.Transpose(values, perm=values_perm) - expanded_values = op.Expand(transposed_values, target_shape) - - updated = op.ScatterND(transposed, onnx_index, expanded_values, reduction="add" if accumulate else None) + if contiguous: + # values may need to be transposed before expanding to target shape + num_padded_dims = target_rank - len(values.shape) + if num_padded_dims > 0: + unsqueezed_dims = op.Constant(value_ints=list(range(num_padded_dims))) + values = op.Unsqueeze(values, unsqueezed_dims) + initial_none_index_positions = list(range(advanced_indices[0])) + advanced_index_replacement_positions = list( + range(advanced_indices[0], advanced_indices[0] + advanced_index_rank) + ) + final_none_index_positions = list( + range(advanced_indices[0] + advanced_index_rank, target_rank) + ) + values_perm = ( + advanced_index_replacement_positions + + initial_none_index_positions + + final_none_index_positions + ) + values = op.Transpose(values, perm=values_perm) + + expanded_values = op.Expand(values, target_shape) + + updated = op.ScatterND( + transposed, onnx_index, expanded_values, reduction="add" if accumulate else None + ) # Inverse transpose to restore original dimension order: + inverse_perm = [0] * self_rank for i, p in enumerate(perm): inverse_perm[p] = i result = op.Transpose(updated, perm=inverse_perm) return result - # def _make_reshape_list_broadcastable(reshape_list, values_shape): # # Remove ones until the rank of reshape_list matches values_shape. From c9db17164645756bd97a0bfee558c9a697314c1a Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 18 Nov 2025 10:13:27 -0800 Subject: [PATCH 3/6] Add test cases Signed-off-by: Ganesan Ramalingam --- .../function_libs/torch_lib/ops/core.py | 86 +++---------------- .../function_libs/torch_lib/e2e_ops_tests.py | 66 ++++++++++++-- 2 files changed, 71 insertions(+), 81 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 7e2a242c43..08445db803 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4432,11 +4432,14 @@ def aten_index_put( See implementation of `torch.onnx.symbolic_opset11.index_put `_. """ - # Ensure the number of indices matches the tensor rank. + # Ensure the number of indices matches the tensor rank by appending trailing Nones. self_rank = len(self.shape) if len(indices) < self_rank: indices = list(indices) + [None] * (self_rank - len(indices)) + # The behavior of the op is dependent on whether there are advanced indices (i.e., non-scalar tensors) + # and whether these advanced indices are contiguous. + # Identify advanced indices. def is_advanced_index(index): return index is not None and len(index.shape) > 0 @@ -4465,6 +4468,7 @@ def is_scalar_index(index): raise ValueError(f"Unhandled index at position {i}: {index}") if num_scalar_indices > 0: + # TODO: handle scalar indices raise NotImplementedError("Scalar indices not yet supported in aten_index_put.") self_shape = op.Shape(self) @@ -4508,6 +4512,11 @@ def same_shape(other_shape: ir.Shape) -> bool: advanced_indices_shape = op.Shape(indices[advanced_indices[0]]) advanced_index_rank = len(indices[advanced_indices[0]].shape) + # ONNX ScatterND supports only the case where all advanced indices appear first, + # followed by None indices. So, we need to transpose self and values so that the + # advanced indices appear first, and then transpose the result back to original + # order at the end. + none_indices_constant = op.Constant(value_ints=none_indices) none_indices_shape = op.Gather(self_shape, none_indices_constant, axis=0) target_shape = op.Concat(advanced_indices_shape, none_indices_shape, axis=0) @@ -4565,81 +4574,6 @@ def same_shape(other_shape: ir.Shape) -> bool: result = op.Transpose(updated, perm=inverse_perm) return result - # def _make_reshape_list_broadcastable(reshape_list, values_shape): - # # Remove ones until the rank of reshape_list matches values_shape. - # while len(reshape_list) > len(values_shape) and 1 in reshape_list: - # reshape_list.remove(1) - - # # Now ensure each dimension is broadcastable: - # # This is mandatory when mixing basic and advanced indexing - # # Example: data((10, 3, 4)), indices([[0, 1], :, [0, 1]]) values(2, 3) - # # the reshape list should be : [[2, 1], [1, 3], [2, 1]] - # for i, r in enumerate(reshape_list): - # if r not in (1, values_shape[i]): - # value_index = values_shape.index(r) - # # Swap elements - # # For the example above the current reshape list is [1, 2] for last dim, - # # to make it broadcastable, we swap the elements - # reshape_list[value_index], reshape_list[i] = r, 1 - - # return reshape_list - - # # Ensure the number of indices matches the tensor rank. - # self_rank = len(self.shape) - # if len(indices) < self_rank: - # indices = list(indices) + [None] * (self_rank - len(indices)) - - # # Get values shape - # values_shape = tuple(values.shape) - - # index_vectors = [] - # for i in range(self_rank): - # if indices[i] is None: - # # For a full slice along dim i, create a range index [0, self.shape[i]). - # idx = op.Range(0, self.shape[i], 1) - # reshape_update = self.shape[i] - # else: - # idx = indices[i] - # reshape_update = math.prod(idx.shape) - # # when Index is more than 1D, flatten it and also the values shape - # # Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3) - # # Indices -> (2*4,) and values shape (2*4, 32) - # if len(idx.shape) > 1: - # values_shape = (reshape_update, *values_shape[len(idx.shape) :]) - - # # Flatten index (always working with 1D index in each dim) - # idx = op.Reshape(idx, [-1]) - - # # Create a reshape pattern: one value per index dimension, - # # with the current dimension set to the update size. - # reshape_list = [1] * len(indices) - # reshape_list[i] = reshape_update - - # # Adjust the reshape list to match the values shape. - # reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape) - - # # Reshape and expand the index. - # idx = op.Reshape(idx, reshape_list, allowzero=True) - # idx = op.Expand(idx, values_shape) - - # # Flatten the index to 1D and unsqueeze to form a column vector. - # idx = op.Reshape(idx, [-1]) - # idx = op.Unsqueeze(idx, axes=[1]) - # index_vectors.append(idx) - - # # Concatenate the index vectors along axis=1 to form the final indices. - # new_index = op.Concat(*index_vectors, axis=1) - - # # Flatten values to match the indices - # flat_values = op.Reshape(values, [-1]) - - # if accumulate: - # result = op.ScatterND(self, new_index, flat_values, reduction="add") - # else: - # result = op.ScatterND(self, new_index, flat_values) - - # return result - @torch_op("aten::index_put", trace_only=True) def aten_index_put_bool( diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 1546de59bd..b62ab9239f 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -1,10 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - -# TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo +from __future__ import annotations import unittest +# TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo import torch from torch.onnx._internal.exporter import _testing @@ -520,6 +520,62 @@ def forward(self, x): ) _testing.assert_onnx_program(onnx_program) - -if __name__ == "__main__": - unittest.main() + def test_my_index_put(self): + def test(x_shape, index_list, update_shape, testname): + with self.subTest(testname=testname): + indices = [ + (torch.tensor(index, dtype=torch.int64) if index is not None else None) + for index in index_list + ] + + class Model(torch.nn.Module): + def forward(self, x, update): + return torch.ops.aten.index_put(x, indices, update, accumulate=True) + + x = torch.zeros(x_shape, dtype=torch.float32) + update = torch.randn(update_shape, dtype=torch.float32) + onnx_program = torch.onnx.export( + Model(), + (x, update), + input_names=["x", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + # Test cases + shape_6x6x6 = (6, 6, 6) + + # Multiple advanced indices, all 1D tensors. + # Non-contiguous advanced indices: updates must be broadcastable to (2, 6) + base = "non_contiguous_non_broadcast_indices_" + test(shape_6x6x6, [[0, 1], None, [2, 3]], (2, 6), base + "no_value_broadcast") + test(shape_6x6x6, [[0, 1], None, [2, 3]], (2, 1), base + "expand_dim2") + test(shape_6x6x6, [[0, 1], None, [2, 3]], (1, 6), base + "expand_dim1") + test(shape_6x6x6, [[0, 1], None, [2, 3]], (6,), base + "new_dim1") + test(shape_6x6x6, [[0, 1], None, [2, 3]], (), base + "scalar") + + # Contiguous advanced indices versions of above tests: updates must be broadcastable to (6, 2) + base = "contiguous_non_broadcast_indices_" + test(shape_6x6x6, [None, [0, 1], [2, 3]], (6, 2), base + "no_value_broadcast") + test(shape_6x6x6, [None, [0, 1], [2, 3]], (6, 1), base + "expand_dim2") + test(shape_6x6x6, [None, [0, 1], [2, 3]], (1, 2), base + "expand_dim1") + test(shape_6x6x6, [None, [0, 1], [2, 3]], (2,), base + "new_dim1") + test(shape_6x6x6, [None, [0, 1], [2, 3]], (), base + "scalar") + + # Multiple advanced indices, with broadcasting among indices. + # Contiguous advanced indices: + # This produces index tuples [(0,2), (0, 3), (1,2), (1,3)] in shape (2,2) + # The update values must be broadcastable to (6,2,2) + base = "contiguous_broadcast_indices_" + test(shape_6x6x6, [None, [[0], [1]], [2, 3]], (6, 2, 2), base + "no_value_broadcast") + test(shape_6x6x6, [None, [[0], [1]], [2, 3]], (6, 1, 1), base + "expand_dim2_dim3") + test(shape_6x6x6, [None, [[0], [1]], [2, 3]], (2,), base + "extend_dim1_dim2") + + # Non-contiguous advanced indices versions of above tests: + # Here, update values must be broadcastable to (2,2,6) + base = "non_contiguous_broadcast_indices_" + test(shape_6x6x6, [[[0], [1]], None, [2, 3]], (2, 2, 6), base + "no_value_broadcast") + test(shape_6x6x6, [[[0], [1]], None, [2, 3]], (1, 1, 6), base + "expand_dim1_dim2") + test(shape_6x6x6, [[[0], [1]], None, [2, 3]], (6,), base + "extend_dim1_dim2") From c8578cf1bc3a1a3292d977d49b91466f27f2e74f Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 18 Nov 2025 13:47:00 -0800 Subject: [PATCH 4/6] Handling scalar tensors Signed-off-by: Ganesan Ramalingam --- .../function_libs/torch_lib/ops/core.py | 15 ++------------- .../function_libs/torch_lib/e2e_ops_tests.py | 19 ++++++++++++++++++- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 08445db803..c5c1fc3cd6 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4442,35 +4442,24 @@ def aten_index_put( # Identify advanced indices. def is_advanced_index(index): - return index is not None and len(index.shape) > 0 + # Note: In this function, the index is assumed to be either None or an int64 Tensor. + return index is not None - def is_scalar_index(index): - return index is not None and len(index.shape) == 0 - - scalar_indices: list[int] = [] advanced_indices: list[int] = [] none_indices: list[int] = [] num_advanced_indices = 0 - num_scalar_indices = 0 num_none_indices = 0 for i, index in enumerate(indices): if is_advanced_index(index): advanced_indices.append(i) num_advanced_indices += 1 - elif is_scalar_index(index): - scalar_indices.append(i) - num_scalar_indices += 1 elif index is None: none_indices.append(i) num_none_indices += 1 else: raise ValueError(f"Unhandled index at position {i}: {index}") - if num_scalar_indices > 0: - # TODO: handle scalar indices - raise NotImplementedError("Scalar indices not yet supported in aten_index_put.") - self_shape = op.Shape(self) if num_advanced_indices == 0: return op.Expand(values, self_shape) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index b62ab9239f..057947cec1 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -522,7 +522,7 @@ def forward(self, x): def test_my_index_put(self): def test(x_shape, index_list, update_shape, testname): - with self.subTest(testname=testname): + with self.subTest(testname="my_index_put_" + testname): indices = [ (torch.tensor(index, dtype=torch.int64) if index is not None else None) for index in index_list @@ -534,6 +534,7 @@ def forward(self, x, update): x = torch.zeros(x_shape, dtype=torch.float32) update = torch.randn(update_shape, dtype=torch.float32) + onnx_program = torch.onnx.export( Model(), (x, update), @@ -546,6 +547,7 @@ def forward(self, x, update): # Test cases shape_6x6x6 = (6, 6, 6) + shape_4x4x4x4 = (4, 4, 4, 4) # Multiple advanced indices, all 1D tensors. # Non-contiguous advanced indices: updates must be broadcastable to (2, 6) @@ -579,3 +581,18 @@ def forward(self, x, update): test(shape_6x6x6, [[[0], [1]], None, [2, 3]], (2, 2, 6), base + "no_value_broadcast") test(shape_6x6x6, [[[0], [1]], None, [2, 3]], (1, 1, 6), base + "expand_dim1_dim2") test(shape_6x6x6, [[[0], [1]], None, [2, 3]], (6,), base + "extend_dim1_dim2") + + test( + shape_4x4x4x4, [None, [0, 1], None, [2, 3]], (2, 4, 4), "non_contiguous_non_first" + ) + test(shape_6x6x6, [0, None, None], (6, 6), "single_scalar_index") + test( + shape_6x6x6, [0, None, [0, 1]], (2, 6), "non_contiguous_scalar_index_and_1d_index" + ) + test(shape_6x6x6, [None, 0, [0, 1]], (6, 2), "contiguous_scalar_index_and_1d_index") + # (TODO): Exporter doesn't yet support all None indices + # test(shape_6x6x6, [None, None, None], shape_6x6x6, "all_none_indices") + + +if __name__ == "__main__": + unittest.main() From 4cea2911c6a8cffcf46ab8b729fdf218cff9d96b Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 18 Nov 2025 13:48:21 -0800 Subject: [PATCH 5/6] Fix test name Signed-off-by: Ganesan Ramalingam --- tests/function_libs/torch_lib/e2e_ops_tests.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 057947cec1..5348ac7a9a 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -520,9 +520,9 @@ def forward(self, x): ) _testing.assert_onnx_program(onnx_program) - def test_my_index_put(self): + def test_index_put(self): def test(x_shape, index_list, update_shape, testname): - with self.subTest(testname="my_index_put_" + testname): + with self.subTest(testname="index_put_" + testname): indices = [ (torch.tensor(index, dtype=torch.int64) if index is not None else None) for index in index_list From 3641a1edd2be5e08277dd772cd44b08fddf03a34 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 11 Dec 2025 11:46:15 -0800 Subject: [PATCH 6/6] Tests Signed-off-by: Justin Chu --- .../function_libs/torch_lib/e2e_ops_tests.py | 296 ++++++++++++++---- 1 file changed, 234 insertions(+), 62 deletions(-) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index c7a8abc661..d344723408 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -2,8 +2,11 @@ # Licensed under the MIT License. from __future__ import annotations +import math import unittest +import parameterized + # TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo import torch from torch.onnx._internal.exporter import _testing @@ -626,78 +629,247 @@ def forward(self, x): ) _testing.assert_onnx_program(onnx_program) - def test_index_put(self): - def test(x_shape, index_list, update_shape, testname): - with self.subTest(testname="index_put_" + testname): - indices = [ - (torch.tensor(index, dtype=torch.int64) if index is not None else None) - for index in index_list - ] + @parameterized.parameterized.expand( + [ + # Multiple advanced indices, all 1D tensors. + # Non-contiguous advanced indices: updates must be broadcastable to (2, 6) + ( + (6, 6, 6), + [[0, 1], None, [2, 3]], + (2, 6), + "non_contiguous_non_broadcast_indices_no_value_broadcast", + ), + ( + (6, 6, 6), + [[0, 1], None, [2, 3]], + (2, 1), + "non_contiguous_non_broadcast_indices_expand_dim2", + ), + ( + (6, 6, 6), + [[0, 1], None, [2, 3]], + (1, 6), + "non_contiguous_non_broadcast_indices_expand_dim1", + ), + ( + (6, 6, 6), + [[0, 1], None, [2, 3]], + (6,), + "non_contiguous_non_broadcast_indices_new_dim1", + ), + ( + (6, 6, 6), + [[0, 1], None, [2, 3]], + (), + "non_contiguous_non_broadcast_indices_scalar", + ), + # Contiguous advanced indices versions of above tests: updates must be broadcastable to (6, 2) + ( + (6, 6, 6), + [None, [0, 1], [2, 3]], + (6, 2), + "contiguous_non_broadcast_indices_no_value_broadcast", + ), + ( + (6, 6, 6), + [None, [0, 1], [2, 3]], + (6, 1), + "contiguous_non_broadcast_indices_expand_dim2", + ), + ( + (6, 6, 6), + [None, [0, 1], [2, 3]], + (1, 2), + "contiguous_non_broadcast_indices_expand_dim1", + ), + ( + (6, 6, 6), + [None, [0, 1], [2, 3]], + (2,), + "contiguous_non_broadcast_indices_new_dim1", + ), + ((6, 6, 6), [None, [0, 1], [2, 3]], (), "contiguous_non_broadcast_indices_scalar"), + # Multiple advanced indices, with broadcasting among indices. + # Contiguous advanced indices: + # This produces index tuples [(0,2), (0, 3), (1,2), (1,3)] in shape (2,2) + # The update values must be broadcastable to (6,2,2) + ( + (6, 6, 6), + [None, [[0], [1]], [2, 3]], + (6, 2, 2), + "contiguous_broadcast_indices_no_value_broadcast", + ), + ( + (6, 6, 6), + [None, [[0], [1]], [2, 3]], + (6, 1, 1), + "contiguous_broadcast_indices_expand_dim2_dim3", + ), + ( + (6, 6, 6), + [None, [[0], [1]], [2, 3]], + (2,), + "contiguous_broadcast_indices_extend_dim1_dim2", + ), + # Non-contiguous advanced indices versions of above tests: + # Here, update values must be broadcastable to (2,2,6) + ( + (6, 6, 6), + [[[0], [1]], None, [2, 3]], + (2, 2, 6), + "non_contiguous_broadcast_indices_no_value_broadcast", + ), + ( + (6, 6, 6), + [[[0], [1]], None, [2, 3]], + (1, 1, 6), + "non_contiguous_broadcast_indices_expand_dim1_dim2", + ), + ( + (6, 6, 6), + [[[0], [1]], None, [2, 3]], + (6,), + "non_contiguous_broadcast_indices_extend_dim1_dim2", + ), + # Other test cases + ( + (4, 4, 4, 4), + [None, [0, 1], None, [2, 3]], + (2, 4, 4), + "non_contiguous_non_first", + ), + ((6, 6, 6), [0, None, None], (6, 6), "single_scalar_index"), + ((6, 6, 6), [0, None, [0, 1]], (2, 6), "non_contiguous_scalar_index_and_1d_index"), + ((6, 6, 6), [None, 0, [0, 1]], (6, 2), "contiguous_scalar_index_and_1d_index"), + # (TODO): Exporter doesn't yet support all None indices + # ((6, 6, 6), [None, None, None], (6, 6, 6), "all_none_indices"), + ] + ) + def test_index_put(self, x_shape, index_list, update_shape, _: str): + indices = [ + (torch.tensor(index, dtype=torch.int64) if index is not None else None) + for index in index_list + ] + + class Model(torch.nn.Module): + def forward(self, x, update): + return torch.ops.aten.index_put(x, indices, update, accumulate=True) - class Model(torch.nn.Module): - def forward(self, x, update): - return torch.ops.aten.index_put(x, indices, update, accumulate=True) + x = torch.zeros(x_shape, dtype=torch.float32) + update = torch.randn(update_shape, dtype=torch.float32) + + onnx_program = torch.onnx.export( + Model(), + (x, update), + input_names=["x", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) - x = torch.zeros(x_shape, dtype=torch.float32) - update = torch.randn(update_shape, dtype=torch.float32) + def test_index_put_dynamic(self): + for dimension in [3, 4, 2]: + with self.subTest(dimension=dimension): + class Model(torch.nn.Module): + def __init__(self, dimension): + super().__init__() + self.params = torch.zeros( + (4, 5) + if dimension == 2 + else ((2, 4, 5) if dimension == 3 else (1, 1, 4, 5)) + ) + self.dimension = dimension + + def forward(self, update, index1, index2): + copy = self.params.clone() + if self.dimension == 2: + copy[index1, index2] = update + elif self.dimension == 3: + copy[:, index1, index2] = update + else: + copy[:, :, index1, index2] = update + return copy + + update = (torch.arange(2) + 10).reshape((2,)).to(torch.float32) + index1 = torch.tensor([1, 2], dtype=torch.int64) + index2 = torch.tensor([3, 4], dtype=torch.int64) + feeds = dict(zip(["update", "index1", "index2"], (update, index1, index2))) onnx_program = torch.onnx.export( - Model(), - (x, update), - input_names=["x", "update"], + Model(dimension), + tuple(feeds.values()), + input_names=["update", "index1", "index2"], output_names=["output"], opset_version=18, dynamo=True, + dynamic_shapes={ + "update": {0: "dn"}, + "index1": {0: "dn"}, + "index2": {0: "dn"}, + }, ) _testing.assert_onnx_program(onnx_program) - # Test cases - shape_6x6x6 = (6, 6, 6) - shape_4x4x4x4 = (4, 4, 4, 4) - - # Multiple advanced indices, all 1D tensors. - # Non-contiguous advanced indices: updates must be broadcastable to (2, 6) - base = "non_contiguous_non_broadcast_indices_" - test(shape_6x6x6, [[0, 1], None, [2, 3]], (2, 6), base + "no_value_broadcast") - test(shape_6x6x6, [[0, 1], None, [2, 3]], (2, 1), base + "expand_dim2") - test(shape_6x6x6, [[0, 1], None, [2, 3]], (1, 6), base + "expand_dim1") - test(shape_6x6x6, [[0, 1], None, [2, 3]], (6,), base + "new_dim1") - test(shape_6x6x6, [[0, 1], None, [2, 3]], (), base + "scalar") - - # Contiguous advanced indices versions of above tests: updates must be broadcastable to (6, 2) - base = "contiguous_non_broadcast_indices_" - test(shape_6x6x6, [None, [0, 1], [2, 3]], (6, 2), base + "no_value_broadcast") - test(shape_6x6x6, [None, [0, 1], [2, 3]], (6, 1), base + "expand_dim2") - test(shape_6x6x6, [None, [0, 1], [2, 3]], (1, 2), base + "expand_dim1") - test(shape_6x6x6, [None, [0, 1], [2, 3]], (2,), base + "new_dim1") - test(shape_6x6x6, [None, [0, 1], [2, 3]], (), base + "scalar") - - # Multiple advanced indices, with broadcasting among indices. - # Contiguous advanced indices: - # This produces index tuples [(0,2), (0, 3), (1,2), (1,3)] in shape (2,2) - # The update values must be broadcastable to (6,2,2) - base = "contiguous_broadcast_indices_" - test(shape_6x6x6, [None, [[0], [1]], [2, 3]], (6, 2, 2), base + "no_value_broadcast") - test(shape_6x6x6, [None, [[0], [1]], [2, 3]], (6, 1, 1), base + "expand_dim2_dim3") - test(shape_6x6x6, [None, [[0], [1]], [2, 3]], (2,), base + "extend_dim1_dim2") - - # Non-contiguous advanced indices versions of above tests: - # Here, update values must be broadcastable to (2,2,6) - base = "non_contiguous_broadcast_indices_" - test(shape_6x6x6, [[[0], [1]], None, [2, 3]], (2, 2, 6), base + "no_value_broadcast") - test(shape_6x6x6, [[[0], [1]], None, [2, 3]], (1, 1, 6), base + "expand_dim1_dim2") - test(shape_6x6x6, [[[0], [1]], None, [2, 3]], (6,), base + "extend_dim1_dim2") - - test( - shape_4x4x4x4, [None, [0, 1], None, [2, 3]], (2, 4, 4), "non_contiguous_non_first" - ) - test(shape_6x6x6, [0, None, None], (6, 6), "single_scalar_index") - test( - shape_6x6x6, [0, None, [0, 1]], (2, 6), "non_contiguous_scalar_index_and_1d_index" - ) - test(shape_6x6x6, [None, 0, [0, 1]], (6, 2), "contiguous_scalar_index_and_1d_index") - # (TODO): Exporter doesn't yet support all None indices - # test(shape_6x6x6, [None, None, None], shape_6x6x6, "all_none_indices") + def test_index_put_55_12_25(self): + class Model(torch.nn.Module): + def forward(self, x, index, update): + return torch.ops.aten.index_put(x, [index], update) + + x = torch.zeros((6, 5), dtype=torch.float32) + index = torch.tensor([[2, 1]], dtype=torch.int64) + update = (torch.arange(10) + 10).reshape((2, -1)).to(torch.float32) + onnx_program = torch.onnx.export( + Model(), + (x, index, update), + input_names=["x", "index", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_index_put_55_2_25(self): + class Model(torch.nn.Module): + def forward(self, x, index, update): + return torch.ops.aten.index_put(x, [index], update, accumulate=True) + + x = torch.ones((6, 5), dtype=torch.float32) + index = torch.tensor([4, 3], dtype=torch.int64) + update = (torch.arange(10) + 10).reshape((2, -1)).to(torch.float32) + onnx_program = torch.onnx.export( + Model(), + (x, index, update), + input_names=["x", "index", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_index_put_scatter_nd(self): + class Model(torch.nn.Module): + def forward(self, x, index, update): + x = x.clone() + return torch.ops.aten.index_put(x, [None, index, None], update) + + shape = (2, 3, 2) + N = math.prod(shape) + x = torch.arange(N, dtype=torch.float32).reshape(shape) + update = (torch.arange(N, dtype=torch.float32).reshape(shape) + 1) * 100 + index = ((torch.arange(shape[-2])).to(torch.int64) + 1) % shape[-2] + + feeds = dict(zip(["x", "index", "update"], (x, index, update))) + onnx_program = torch.onnx.export( + Model(), + tuple(feeds.values()), + input_names=["x", "index", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + dynamic_shapes=({0: "a", 1: "b", 2: "c"}, {0: "d"}, {0: "e", 1: "f", 2: "g"}), + ) + _testing.assert_onnx_program(onnx_program) if __name__ == "__main__":