Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@
from .replace_scalar_with_tensor_pass import ( # noqa
ReplaceScalarWithTensorByProfilePass,
)
from .rewrite_bool_to_fp32_cast_via_int8_pass import ( # noqa
RewriteBoolToFp32CastViaInt8Pass,
)
from .rewrite_conv_pass import RewriteConvPass # noqa
from .rewrite_matmul import RewriteMatmulPass # noqa
from .rewrite_upsample import RewriteUpsamplePass # noqa
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
RemoveNoopPass,
ReplaceInfAndLimitValuesPass,
ReplaceScalarWithTensorByProfilePass,
RewriteBoolToFp32CastViaInt8Pass,
RewriteConvPass,
RewriteMatmulPass,
RewriteUpsamplePass,
Expand Down Expand Up @@ -221,6 +222,7 @@ def _tosa_pipeline(
self.add_passes(
[
FuseQuantizedActivationPass(),
RewriteBoolToFp32CastViaInt8Pass(),
ConvertToClampPass(),
DecomposeTOSAUnsupportedClampPass(),
DecomposeGroupNormPass(),
Expand Down
77 changes: 77 additions & 0 deletions backends/arm/_passes/rewrite_bool_to_fp32_cast_via_int8_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

import torch

from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_first_fake_tensor,
set_node_arg,
)
from executorch.backends.arm.tosa.specification import get_context_spec
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class RewriteBoolToFp32CastViaInt8Pass(ArmPass):
    """
    Split bool->fp32 ``_to_dim_order_copy`` casts into two chained casts.

    The rewrite only happens when the active TOSA specification reports
    support for both integer and float; otherwise the graph is returned
    untouched. A matching cast is retargeted to produce int8, and a new
    int8->fp32 cast is inserted right after it, so the original consumers
    still receive an fp32 value.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    targeted_ops = {exir_ops.edge.dim_order_ops._to_dim_order_copy.default}

    def _is_bool_to_fp32_cast(self, candidate: torch.fx.Node) -> bool:
        """Return True if ``candidate`` is a targeted cast from bool to fp32."""
        if candidate.op != "call_function":
            return False
        if candidate.target not in self.targeted_ops:
            return False

        source = candidate.all_input_nodes[0]
        # The output dtype is only inspected once the input is known to be bool.
        return (
            get_first_fake_tensor(source).dtype == torch.bool
            and get_first_fake_tensor(candidate).dtype == torch.float32
        )

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Rewrite every bool->fp32 cast in ``graph_module``.

        Returns a ``PassResult`` whose flag tells whether any cast was split.
        """
        rewritten = False

        spec = get_context_spec()
        if not spec.support_integer() or not spec.support_float():
            # Only the combined INT+FP configuration is rewritten.
            return PassResult(graph_module, rewritten)

        graph = graph_module.graph
        for cast_node in graph.nodes:
            if not self._is_bool_to_fp32_cast(cast_node):
                continue

            # The existing cast now stops at int8 ...
            set_node_arg(cast_node, "dtype", torch.int8)

            # ... and its consumers are snapshotted *before* the follow-up
            # cast is created, because that new node also consumes
            # ``cast_node`` and must not be rewired onto itself.
            consumers = list(cast_node.users)
            with graph.inserting_after(cast_node):
                fp32_cast = create_node(
                    graph,
                    cast_node.target,
                    args=(cast_node,),
                    kwargs={"dtype": torch.float32},
                )
            for consumer in consumers:
                consumer.replace_input_with(cast_node, fp32_cast)
            rewritten = True

        if rewritten:
            graph_module.recompile()
            # Hand the recompiled module to the base ArmPass implementation
            # (presumably to refresh node metadata -- confirm against ArmPass).
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, rewritten)
7 changes: 7 additions & 0 deletions backends/arm/operator_support/to_dim_order_copy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ def _merge_supported_types(
torch.float32,
],
}
SUPPORTED_INT_FP_PROFILE_DTYPES: SupportedTypeDict = {
torch.bool: [torch.float32],
}

def is_node_tosa_supported(
self, node: fx.Node, tosa_spec: TosaSpecification
Expand All @@ -137,6 +140,10 @@ def is_node_tosa_supported(
supported_dtypes = self._merge_supported_types(
self.SUPPORTED_FP_PROFILE_DTYPES, supported_dtypes
)
if tosa_spec.support_integer() and tosa_spec.support_float():
supported_dtypes = self._merge_supported_types(
self.SUPPORTED_INT_FP_PROFILE_DTYPES, supported_dtypes
)

if len(node.all_input_nodes) != 1:
self.reporter.report_reject(
Expand Down
36 changes: 36 additions & 0 deletions backends/arm/test/ops/test_to_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,39 @@ def test_to_u55_INT(test_data: Tuple):
non_delegated_ops={}, # These are removed outside of the Arm backend so the graph is empty
)
pipeline.run()


_TO_COPY_TEST_DATA_INT_FP = {
"bool_fp32": lambda: (
torch.tensor([True, False], dtype=torch.bool),
torch.float32,
),
}


@common.parametrize("test_data", _TO_COPY_TEST_DATA_INT_FP)
@common.SkipIfNoModelConverter
def test_to_vgf_no_quant_bool_fp32(test_data: Tuple):
    """Run the bool->fp32 cast module through VgfPipeline with quantize=False."""
    input_tensor, target_dtype = test_data()
    VgfPipeline[input_t1](
        Cast(target_dtype),
        (input_tensor,),
        aten_op=[],
        exir_op=[],
        quantize=False,
    ).run()


@common.parametrize("test_data", _TO_COPY_TEST_DATA_INT_FP)
@common.SkipIfNoModelConverter
def test_to_vgf_quant_bool_fp32(test_data: Tuple):
    """Run the bool->fp32 cast module through VgfPipeline with quantize=True."""
    input_tensor, target_dtype = test_data()
    VgfPipeline[input_t1](
        Cast(target_dtype),
        (input_tensor,),
        aten_op=[],
        exir_op=[],
        quantize=True,
    ).run()
Loading