187 changes: 121 additions & 66 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -4541,80 +4541,135 @@ def aten_index_put(
See implementation of `torch.onnx.symbolic_opset11.index_put
<https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
"""

def _make_reshape_list_broadcastable(reshape_list, values_shape):
# Remove ones until the rank of reshape_list matches values_shape.
while len(reshape_list) > len(values_shape) and 1 in reshape_list:
reshape_list.remove(1)

# Now ensure each dimension is broadcastable:
# This is mandatory when mixing basic and advanced indexing.
# Example: data((10, 3, 4)), indices([[0, 1], :, [0, 1]]), values(2, 3):
# the reshape lists should be: [[2, 1], [1, 3], [2, 1]]
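# (Illustration: the dim-0 index is reshaped to (2, 1), the dim-1 range to (1, 3),
# and the dim-2 index to (2, 1); each is then expanded to the values shape (2, 3).)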
for i, r in enumerate(reshape_list):
if r not in (1, values_shape[i]):
value_index = values_shape.index(r)
# Swap elements
# For the example above, the current reshape list for the last dim is [1, 2];
# to make it broadcastable, we swap the elements.
reshape_list[value_index], reshape_list[i] = r, 1

return reshape_list

# Ensure the number of indices matches the tensor rank.
# Ensure the number of indices matches the tensor rank by appending trailing Nones.
self_rank = len(self.shape)
if len(indices) < self_rank:
indices = list(indices) + [None] * (self_rank - len(indices))

# Get values shape
values_shape = tuple(values.shape)
# The behavior of the op depends on whether there are advanced indices (i.e., non-None
# tensor indices) and whether these advanced indices are contiguous.

# Identify advanced indices.
def is_advanced_index(index):
# Note: In this function, the index is assumed to be either None or an int64 Tensor.
return index is not None

advanced_indices: list[int] = []
none_indices: list[int] = []
num_advanced_indices = 0
num_none_indices = 0

for i, index in enumerate(indices):
if is_advanced_index(index):
advanced_indices.append(i)
num_advanced_indices += 1
elif index is None:
none_indices.append(i)
num_none_indices += 1
else:
raise ValueError(f"Unhandled index at position {i}: {index}")
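# For example, for x[idx0, :, idx1] the indices list is [idx0, None, idx1],
# so advanced_indices == [0, 2] and none_indices == [1].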

index_vectors = []
for i in range(self_rank):
if indices[i] is None:
# For a full slice along dim i, create a range index [0, self.shape[i]).
idx = op.Range(0, self.shape[i], 1)
reshape_update = self.shape[i]
self_shape = op.Shape(self)
if num_advanced_indices == 0:
return op.Expand(values, self_shape)
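# (With no advanced indices, e.g. x[:, :] = values, every element is overwritten,
# so values only needs to be broadcast to the shape of self.)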

# More than one advanced index may require broadcasting of index values
if num_advanced_indices > 1:
# Check for the special case where all advanced indices have the same shape.
# We also need to ensure none of the shapes has None as a dimension, which
# would invalidate the equality-based check.
first_shape = indices[advanced_indices[0]].shape

def same_shape(other_shape: ir.Shape) -> bool:
return (not any(d is None for d in other_shape)) and other_shape == first_shape

all_same_shape = all(same_shape(indices[i].shape) for i in advanced_indices)
if not all_same_shape:
# Broadcast advanced indices to a common shape.
advanced_index_rank = max(len(indices[i].shape) for i in advanced_indices)
shapes = []
for i in advanced_indices:
index = indices[i]
index_rank = len(index.shape)
index_shape = op.Shape(index)
if index_rank < advanced_index_rank:
padding = op.Constant(
value_ints=[1 for _ in range(advanced_index_rank - index_rank)]
)
index_shape = op.Concat(padding, index_shape, axis=0)
shapes.append(index_shape)
advanced_indices_shape = op.Max(*shapes)
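# (The element-wise Max of the one-padded shapes is the broadcast shape,
# e.g. index shapes (3,) and (2, 1) pad to (1, 3) and (2, 1) and give (2, 3).)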
indices = [
op.Expand(index, advanced_indices_shape) if is_advanced_index(index) else index
for index in indices
]
else:
idx = indices[i]
reshape_update = math.prod(idx.shape)
# When the index is more than 1-D, flatten it and also the values shape.
# Example: self shape (10, 3), indices[i] shape (2, 4), values shape (2, 4, 3):
# the index becomes (2*4,) and the values shape becomes (2*4, 3).
if len(idx.shape) > 1:
values_shape = (reshape_update, *values_shape[len(idx.shape) :])

# Flatten index (always working with 1D index in each dim)
idx = op.Reshape(idx, [-1])

# Create a reshape pattern: one value per index dimension,
# with the current dimension set to the update size.
reshape_list = [1] * len(indices)
reshape_list[i] = reshape_update

# Adjust the reshape list to match the values shape.
reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape)

# Reshape and expand the index.
idx = op.Reshape(idx, reshape_list, allowzero=True)
idx = op.Expand(idx, values_shape)

# Flatten the index to 1D and unsqueeze to form a column vector.
idx = op.Reshape(idx, [-1])
idx = op.Unsqueeze(idx, axes=[1])
index_vectors.append(idx)

# Concatenate the index vectors along axis=1 to form the final indices.
new_index = op.Concat(*index_vectors, axis=1)
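# new_index has shape (N, rank(self)), where N is the number of updated elements;
# row k holds the full coordinate into self that receives flat_values[k].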

# Flatten values to match the indices
flat_values = op.Reshape(values, [-1])

if accumulate:
result = op.ScatterND(self, new_index, flat_values, reduction="add")
advanced_indices_shape = op.Shape(indices[advanced_indices[0]])
advanced_index_rank = len(indices[advanced_indices[0]].shape)
else:
result = op.ScatterND(self, new_index, flat_values)
advanced_indices_shape = op.Shape(indices[advanced_indices[0]])
advanced_index_rank = len(indices[advanced_indices[0]].shape)

# ONNX ScatterND supports only the case where all advanced indices appear first,
# followed by None indices. So, we need to transpose self and values so that the
# advanced indices appear first, and then transpose the result back to the original
# order at the end.

none_indices_constant = op.Constant(value_ints=none_indices)
none_indices_shape = op.Gather(self_shape, none_indices_constant, axis=0)
target_shape = op.Concat(advanced_indices_shape, none_indices_shape, axis=0)
target_rank = advanced_index_rank + num_none_indices

# Generate indices tensor required by ONNX ScatterND by unsqueezing an extra dimension and
# concatenating all advanced indices along this new dimension.
minus_one = op.Constant(value_ints=[-1])
advanced_index_values = [op.Unsqueeze(indices[i], minus_one) for i in advanced_indices]
onnx_index = op.Concat(*advanced_index_values, axis=-1)
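# onnx_index has shape (*advanced_indices_shape, num_advanced_indices): the last axis
# holds one coordinate per advanced dimension of the transposed tensor, as ScatterND expects.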

# Check if advanced indices are contiguous:
contiguous = True
if advanced_indices:
if advanced_indices[-1] - advanced_indices[0] + 1 != len(advanced_indices):
contiguous = False
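# (For example, x[idx0, idx1, :] has contiguous advanced indices (dims 0 and 1), while
# x[idx0, :, idx1] does not. In the non-contiguous case PyTorch/NumPy already place the
# advanced-index dimensions first in values, which matches the transposed layout below.)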

# Bring advanced indices to front:
perm = advanced_indices + none_indices
transposed = op.Transpose(self, perm=perm)

# Expand values to match the target shape.
# First, transpose values if necessary to match the advanced-indices order.
if contiguous:
# values may need to be transposed before expanding to target shape
num_padded_dims = target_rank - len(values.shape)
if num_padded_dims > 0:
unsqueezed_dims = op.Constant(value_ints=list(range(num_padded_dims)))
values = op.Unsqueeze(values, unsqueezed_dims)
initial_none_index_positions = list(range(advanced_indices[0]))
advanced_index_replacement_positions = list(
range(advanced_indices[0], advanced_indices[0] + advanced_index_rank)
)
final_none_index_positions = list(
range(advanced_indices[0] + advanced_index_rank, target_rank)
)
values_perm = (
advanced_index_replacement_positions
+ initial_none_index_positions
+ final_none_index_positions
)
values = op.Transpose(values, perm=values_perm)
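# (For example, for x[:, idx, :] = v with a 1-D idx and a rank-3 self:
# advanced_indices == [1] and values_perm == [1, 0, 2], which moves the indexed
# dimension of v to the front to line up with the transposed self.)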

expanded_values = op.Expand(values, target_shape)

updated = op.ScatterND(
transposed, onnx_index, expanded_values, reduction="add" if accumulate else None
)

# Inverse transpose to restore original dimension order:

inverse_perm = [0] * self_rank
for i, p in enumerate(perm):
inverse_perm[p] = i
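# (For example, perm == [2, 0, 1] gives inverse_perm == [1, 2, 0].)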
result = op.Transpose(updated, perm=inverse_perm)
return result
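
For reference, a minimal sketch (not part of this diff) of the eager-mode semantics the
lowering has to reproduce. It uses plain advanced-indexing assignment, which is what
aten::index_put implements, with non-contiguous advanced indices so that the broadcast
advanced dimension lands in front of the values/result shape:

import torch

x = torch.zeros(4, 3, 5)
idx0 = torch.tensor([0, 2])  # advanced index for dim 0
idx2 = torch.tensor([1, 4])  # advanced index for dim 2

# idx0 and idx2 are separated by a full slice, so the broadcast advanced dimension
# (size 2) comes first in the expected values shape: (2, 3), not (3, 2).
values = torch.arange(6, dtype=x.dtype).reshape(2, 3)
x[idx0, :, idx2] = values  # roughly aten::index_put(x, [idx0, None, idx2], values, False)

# The same update written as an explicit element-wise scatter, for comparison.
expected = torch.zeros(4, 3, 5)
for k in range(2):
    for j in range(3):
        expected[idx0[k], j, idx2[k]] = values[k, j]

assert torch.equal(x, expected)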

