diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index 798c25c9ee0..52f14d326dd 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -113,6 +113,9 @@
 from .replace_scalar_with_tensor_pass import (  # noqa
     ReplaceScalarWithTensorByProfilePass,
 )
+from .rewrite_bool_to_fp32_cast_via_int8_pass import (  # noqa
+    RewriteBoolToFp32CastViaInt8Pass,
+)
 from .rewrite_conv_pass import RewriteConvPass  # noqa
 from .rewrite_matmul import RewriteMatmulPass  # noqa
 from .rewrite_upsample import RewriteUpsamplePass  # noqa
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 7c77d779cfa..b18bd39c28d 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -103,6 +103,7 @@
     RemoveNoopPass,
     ReplaceInfAndLimitValuesPass,
     ReplaceScalarWithTensorByProfilePass,
+    RewriteBoolToFp32CastViaInt8Pass,
     RewriteConvPass,
     RewriteMatmulPass,
     RewriteUpsamplePass,
@@ -221,6 +222,7 @@ def _tosa_pipeline(
         self.add_passes(
             [
                 FuseQuantizedActivationPass(),
+                RewriteBoolToFp32CastViaInt8Pass(),
                 ConvertToClampPass(),
                 DecomposeTOSAUnsupportedClampPass(),
                 DecomposeGroupNormPass(),
diff --git a/backends/arm/_passes/rewrite_bool_to_fp32_cast_via_int8_pass.py b/backends/arm/_passes/rewrite_bool_to_fp32_cast_via_int8_pass.py
new file mode 100644
index 00000000000..4f8eeb8d5ac
--- /dev/null
+++ b/backends/arm/_passes/rewrite_bool_to_fp32_cast_via_int8_pass.py
@@ -0,0 +1,77 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Set, Type
+
+import torch
+
+from executorch.backends.arm._passes.arm_pass import ArmPass
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_first_fake_tensor,
+    set_node_arg,
+)
+from executorch.backends.arm.tosa.specification import get_context_spec
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class RewriteBoolToFp32CastViaInt8Pass(ArmPass):
+    """
+    Legalizes unsupported bool->fp32 to_dim_order_copy casts for the Arm TOSA
+    backend when both integer and float TOSA profiles are enabled.
+
+    For the combined INT+FP profile, this pass rewrites a single bool->fp32 cast
+    into a bool->int8 cast followed by an int8->fp32 cast, so that each cast
+    is individually supported by the TOSA INT and FP profiles. For other
+    profiles (INT-only or FP-only) the pass is a no-op.
+    """
+
+    _passes_required_after: Set[Type[ExportPass]] = set()
+
+    targeted_ops = {exir_ops.edge.dim_order_ops._to_dim_order_copy.default}
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified = False
+
+        tosa_spec = get_context_spec()
+        if not (tosa_spec.support_integer() and tosa_spec.support_float()):
+            return PassResult(graph_module, modified)
+
+        graph = graph_module.graph
+        for node in graph.nodes:
+            if node.op != "call_function" or node.target not in self.targeted_ops:
+                continue
+
+            input_node = node.all_input_nodes[0]
+            input_dtype = get_first_fake_tensor(input_node).dtype
+            if input_dtype != torch.bool:
+                continue
+
+            output_dtype = get_first_fake_tensor(node).dtype
+            if output_dtype != torch.float32:
+                continue
+
+            set_node_arg(node, "dtype", torch.int8)
+
+            users = list(node.users)
+            with graph.inserting_after(node):
+                cast_after = create_node(
+                    graph,
+                    node.target,
+                    args=(node,),
+                    kwargs={
+                        "dtype": torch.float32,
+                    },
+                )
+            for user in users:
+                user.replace_input_with(node, cast_after)
+            modified = True
+
+        if modified:
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, modified)
diff --git a/backends/arm/operator_support/to_dim_order_copy_support.py b/backends/arm/operator_support/to_dim_order_copy_support.py
index 48f0c4d8604..bd600a4df2c 100644
--- a/backends/arm/operator_support/to_dim_order_copy_support.py
+++ b/backends/arm/operator_support/to_dim_order_copy_support.py
@@ -117,6 +117,9 @@ def _merge_supported_types(
             torch.float32,
         ],
     }
+    SUPPORTED_INT_FP_PROFILE_DTYPES: SupportedTypeDict = {
+        torch.bool: [torch.float32],
+    }
 
     def is_node_tosa_supported(
         self, node: fx.Node, tosa_spec: TosaSpecification
@@ -137,6 +140,10 @@ def is_node_tosa_supported(
             supported_dtypes = self._merge_supported_types(
                 self.SUPPORTED_FP_PROFILE_DTYPES, supported_dtypes
             )
+        if tosa_spec.support_integer() and tosa_spec.support_float():
+            supported_dtypes = self._merge_supported_types(
+                self.SUPPORTED_INT_FP_PROFILE_DTYPES, supported_dtypes
+            )
 
         if len(node.all_input_nodes) != 1:
             self.reporter.report_reject(
diff --git a/backends/arm/test/ops/test_to_copy.py b/backends/arm/test/ops/test_to_copy.py
index 17db2c3f226..114051d3877 100644
--- a/backends/arm/test/ops/test_to_copy.py
+++ b/backends/arm/test/ops/test_to_copy.py
@@ -258,3 +258,39 @@ def test_to_u55_INT(test_data: Tuple):
         non_delegated_ops={},  # These are removed outside of the Arm backend so the graph is empty
     )
     pipeline.run()
+
+
+_TO_COPY_TEST_DATA_INT_FP = {
+    "bool_fp32": lambda: (
+        torch.tensor([True, False], dtype=torch.bool),
+        torch.float32,
+    ),
+}
+
+
+@common.parametrize("test_data", _TO_COPY_TEST_DATA_INT_FP)
+@common.SkipIfNoModelConverter
+def test_to_vgf_no_quant_bool_fp32(test_data: Tuple):
+    test_tensor, new_dtype = test_data()
+    pipeline = VgfPipeline[input_t1](
+        Cast(new_dtype),
+        (test_tensor,),
+        aten_op=[],
+        exir_op=[],
+        quantize=False,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", _TO_COPY_TEST_DATA_INT_FP)
+@common.SkipIfNoModelConverter
+def test_to_vgf_quant_bool_fp32(test_data: Tuple):
+    test_tensor, new_dtype = test_data()
+    pipeline = VgfPipeline[input_t1](
+        Cast(new_dtype),
+        (test_tensor,),
+        aten_op=[],
+        exir_op=[],
+        quantize=True,
+    )
+    pipeline.run()