From e3c52176f78932b5168d48d86df14490de701102 Mon Sep 17 00:00:00 2001
From: rahulc7
Date: Mon, 22 Dec 2025 07:23:03 -0800
Subject: [PATCH 1/6] Adding test for CadenceWith16BitConvActivationsQuantizer

Summary: Add annotation tests for CadenceWith16BitConvActivationsQuantizer covering both conv1d and conv2d operations.

Differential Revision: D88895865
---
 .../cadence/aot/tests/test_quantizer_ops.py   | 67 ++++++++++++++++++-
 1 file changed, 66 insertions(+), 1 deletion(-)

diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py
index 99953346b05..871857196a6 100644
--- a/backends/cadence/aot/tests/test_quantizer_ops.py
+++ b/backends/cadence/aot/tests/test_quantizer_ops.py
@@ -56,7 +56,6 @@
     CadenceW8A32MixedQuantizer,  # TODO: T247438158 Add test coverage
     CadenceRmsNormNopQuantizer,  # No-op quantizer, doesn't annotate anything, preserves rms_norm from decomposition
     CadenceWakeWordQuantizer,  # TODO: T247438162 Add test coverage
-    CadenceWith16BitConvActivationsQuantizer,  # TODO: T247438221 Add test coverage
     CadenceWithLayerNormQuantizer,  # TODO: T247438410 Add test coverage
     CadenceWithSoftmaxQuantizer,  # TODO: T247438418 Add test coverage
 }
@@ -93,6 +92,24 @@
         # For linear: [input_activation, weight]
         [qconfig_A16.input_activation, qconfig_A16.weight],
     ),
+    (
+        "conv1d_A16",
+        lambda self: self._build_conv1d_graph(),
+        CadenceWith16BitConvActivationsQuantizer(),
+        torch.ops.aten.conv1d.default,
+        qconfig_A16.output_activation,
+        # For conv1d: [input_activation, weight]
+        [qconfig_A16.input_activation, qconfig_A16.weight],
+    ),
+    (
+        "conv2d_A16",
+        lambda self: self._build_conv2d_graph(),
+        CadenceWith16BitConvActivationsQuantizer(),
+        torch.ops.aten.conv2d.default,
+        qconfig_A16.output_activation,
+        # For conv2d: [input_activation, weight]
+        [qconfig_A16.input_activation, qconfig_A16.weight],
+    ),
 ]
 
 # Derive the set of tested quantizer classes from the test cases.
@@ -149,6 +166,54 @@ def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node")
         return gm, linear_nodes[0]
 
+    def _build_conv1d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with a conv1d operation (no bias)."""
+        builder = GraphBuilder()
+        # Input shape: (batch, in_channels, length)
+        x = builder.placeholder("x", torch.randn(1, 3, 10))
+        # Weight shape: (out_channels, in_channels, kernel_size)
+        weight = builder.placeholder("weight", torch.randn(6, 3, 3))
+        conv1d = builder.call_operator(
+            op=torch.ops.aten.conv1d.default,
+            args=(x, weight),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("conv1d", torch.ops.aten.conv1d.default)]}
+            ),
+        )
+        builder.output([conv1d])
+        gm = builder.get_graph_module()
+
+        conv1d_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.conv1d.default,
+        )
+        self.assertEqual(len(conv1d_nodes), 1, "Should find exactly one conv1d node")
+        return gm, conv1d_nodes[0]
+
+    def _build_conv2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with a conv2d operation (no bias)."""
+        builder = GraphBuilder()
+        # Input shape: (batch, in_channels, height, width)
+        x = builder.placeholder("x", torch.randn(1, 3, 8, 8))
+        # Weight shape: (out_channels, in_channels, kernel_h, kernel_w)
+        weight = builder.placeholder("weight", torch.randn(6, 3, 3, 3))
+        conv2d = builder.call_operator(
+            op=torch.ops.aten.conv2d.default,
+            args=(x, weight),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("conv2d", torch.ops.aten.conv2d.default)]}
+            ),
+        )
+        builder.output([conv2d])
+        gm = builder.get_graph_module()
+
+        conv2d_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.conv2d.default,
+        )
+        self.assertEqual(len(conv2d_nodes), 1, "Should find exactly one conv2d node")
+        return gm, conv2d_nodes[0]
+
     @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
     def test_quantizer_annotation(
         self,

From 6ce0bec28930245ba3847e847ed1106ba53448fa Mon Sep 17 00:00:00 2001
From: rahulc7
Date: Mon, 22 Dec 2025 07:23:03 -0800
Subject: [PATCH 2/6] Adding test for CadenceWithSoftmaxQuantizer

Differential Revision: D88896712
---
 .../cadence/aot/tests/test_quantizer_ops.py   | 31 ++++++++++++++++++-
 1 file changed, 30 insertions(+), 1 deletion(-)

diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py
index 871857196a6..9c4bca00965 100644
--- a/backends/cadence/aot/tests/test_quantizer_ops.py
+++ b/backends/cadence/aot/tests/test_quantizer_ops.py
@@ -57,7 +57,6 @@
     CadenceRmsNormNopQuantizer,  # No-op quantizer, doesn't annotate anything, preserves rms_norm from decomposition
     CadenceWakeWordQuantizer,  # TODO: T247438162 Add test coverage
     CadenceWithLayerNormQuantizer,  # TODO: T247438410 Add test coverage
-    CadenceWithSoftmaxQuantizer,  # TODO: T247438418 Add test coverage
 }
 
 
@@ -110,6 +109,15 @@
         # For conv2d: [input_activation, weight]
         [qconfig_A16.input_activation, qconfig_A16.weight],
     ),
+    (
+        "softmax_A16",
+        lambda self: self._build_softmax_graph(),
+        CadenceWithSoftmaxQuantizer(),
+        torch.ops.aten._softmax.default,
+        qconfig_A16.output_activation,
+        # For softmax: only input_activation
+        [qconfig_A16.input_activation],
+    ),
 ]
 
 # Derive the set of tested quantizer classes from the test cases.
@@ -214,6 +222,27 @@ def _build_conv2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(conv2d_nodes), 1, "Should find exactly one conv2d node")
         return gm, conv2d_nodes[0]
 
+    def _build_softmax_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with a softmax operation."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 10))
+        softmax = builder.call_operator(
+            op=torch.ops.aten._softmax.default,
+            args=(x, -1, False),  # dim=-1, half_to_float=False
+            meta=NodeMetadata(
+                {"source_fn_stack": [("softmax", torch.ops.aten._softmax.default)]}
+            ),
+        )
+        builder.output([softmax])
+        gm = builder.get_graph_module()
+
+        softmax_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten._softmax.default,
+        )
+        self.assertEqual(len(softmax_nodes), 1, "Should find exactly one softmax node")
+        return gm, softmax_nodes[0]
+
     @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
     def test_quantizer_annotation(
         self,

From 399f502c8614bafaeb89ebfd368be3118fa4f8ec Mon Sep 17 00:00:00 2001
From: rahulc7
Date: Mon, 22 Dec 2025 07:23:03 -0800
Subject: [PATCH 3/6] Adding Test for CadenceWithLayerNormQuantizer

Differential Revision: D88898823
---
 .../cadence/aot/tests/test_quantizer_ops.py   | 36 ++++++++++++++++++-
 1 file changed, 35 insertions(+), 1 deletion(-)

diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py
index 9c4bca00965..01053fd9d5a 100644
--- a/backends/cadence/aot/tests/test_quantizer_ops.py
+++ b/backends/cadence/aot/tests/test_quantizer_ops.py
@@ -56,7 +56,6 @@
     CadenceW8A32MixedQuantizer,  # TODO: T247438158 Add test coverage
     CadenceRmsNormNopQuantizer,  # No-op quantizer, doesn't annotate anything, preserves rms_norm from decomposition
     CadenceWakeWordQuantizer,  # TODO: T247438162 Add test coverage
-    CadenceWithLayerNormQuantizer,  # TODO: T247438410 Add test coverage
 }
 
 
@@ -118,6 +117,15 @@
         # For softmax: only input_activation
         [qconfig_A16.input_activation],
     ),
+    (
+        "layer_norm_A8W8",
+        lambda self: self._build_layer_norm_graph(),
+        CadenceWithLayerNormQuantizer(),
+        torch.ops.aten.layer_norm.default,
+        qconfig_A8W8.output_activation,
+        # For layer_norm: only input_activation (weights/bias are passed as others)
+        [qconfig_A8W8.input_activation],
+    ),
 ]
 
 # Derive the set of tested quantizer classes from the test cases.
@@ -243,6 +251,32 @@ def _build_softmax_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(softmax_nodes), 1, "Should find exactly one softmax node")
         return gm, softmax_nodes[0]
 
+    def _build_layer_norm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with a layer_norm operation."""
+        builder = GraphBuilder()
+        # Input shape: (batch, features)
+        x = builder.placeholder("x", torch.randn(1, 10))
+        # normalized_shape must match the last dimension(s) of input
+        normalized_shape = [10]
+        layer_norm = builder.call_operator(
+            op=torch.ops.aten.layer_norm.default,
+            args=(x, normalized_shape),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("layer_norm", torch.ops.aten.layer_norm.default)]}
+            ),
+        )
+        builder.output([layer_norm])
+        gm = builder.get_graph_module()
+
+        layer_norm_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.layer_norm.default,
+        )
+        self.assertEqual(
+            len(layer_norm_nodes), 1, "Should find exactly one layer_norm node"
+        )
+        return gm, layer_norm_nodes[0]
+
     @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
     def test_quantizer_annotation(
         self,

From fd9071e007d16257f0457176aa63d26769035a16 Mon Sep 17 00:00:00 2001
From: rahulc7
Date: Mon, 22 Dec 2025 07:23:03 -0800
Subject: [PATCH 4/6] Adding test for CadenceWakeWordQuantizer

Differential Revision: D88898933
---
 .../cadence/aot/tests/test_quantizer_ops.py   | 32 ++++++++++++++++++-
 1 file changed, 31 insertions(+), 1 deletion(-)

diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py
index 01053fd9d5a..4e05a33959c 100644
--- a/backends/cadence/aot/tests/test_quantizer_ops.py
+++ b/backends/cadence/aot/tests/test_quantizer_ops.py
@@ -55,7 +55,6 @@
     CadenceNopQuantizer,  # No-op quantizer, doesn't annotate anything
     CadenceW8A32MixedQuantizer,  # TODO: T247438158 Add test coverage
     CadenceRmsNormNopQuantizer,  # No-op quantizer, doesn't annotate anything, preserves rms_norm from decomposition
-    CadenceWakeWordQuantizer,  # TODO: T247438162 Add test coverage
 }
 
 
@@ -126,6 +125,15 @@
         # For layer_norm: only input_activation (weights/bias are passed as others)
         [qconfig_A8W8.input_activation],
     ),
+    (
+        "add_A8W8",
+        lambda self: self._build_add_graph(),
+        CadenceWakeWordQuantizer(),
+        torch.ops.aten.add.Tensor,
+        qconfig_A8W8.output_activation,
+        # For add: both inputs are activations
+        [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
+    ),
 ]
 
 # Derive the set of tested quantizer classes from the test cases.
@@ -277,6 +285,28 @@ def _build_layer_norm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         )
         return gm, layer_norm_nodes[0]
 
+    def _build_add_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with an add operation."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 10))
+        y = builder.placeholder("y", torch.randn(1, 10))
+        add = builder.call_operator(
+            op=torch.ops.aten.add.Tensor,
+            args=(x, y),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("add", torch.ops.aten.add.Tensor)]}
+            ),
+        )
+        builder.output([add])
+        gm = builder.get_graph_module()
+
+        add_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.add.Tensor,
+        )
+        self.assertEqual(len(add_nodes), 1, "Should find exactly one add node")
+        return gm, add_nodes[0]
+
     @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
     def test_quantizer_annotation(
         self,

From 74805a1164083c5fd9f1e011b627e19fd9f009ee Mon Sep 17 00:00:00 2001
From: rahulc7
Date: Mon, 22 Dec 2025 07:23:03 -0800
Subject: [PATCH 5/6] Adding Tests for CadenceDefaultQuantizer

Differential Revision: D88899457
---
 .../cadence/aot/tests/test_quantizer_ops.py   | 101 +++++++++++++++++-
 1 file changed, 100 insertions(+), 1 deletion(-)

diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py
index 4e05a33959c..37c9733cd3c 100644
--- a/backends/cadence/aot/tests/test_quantizer_ops.py
+++ b/backends/cadence/aot/tests/test_quantizer_ops.py
@@ -30,6 +30,7 @@
     CadenceWithSoftmaxQuantizer,
     qconfig_A16,
     qconfig_A8W8,
+    qconfig_A8W8sym,
 )
 from executorch.exir.pass_base import NodeMetadata
 from parameterized import parameterized
@@ -50,7 +51,6 @@ # Quantizers intentionally excluded from annotation testing.
 # These should be explicitly justified when added.
 EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
-    CadenceDefaultQuantizer,  # TODO: T247438143 Add test coverage
     CadenceFusedConvReluQuantizer,  # TODO: T247438151 Add test coverage
     CadenceNopQuantizer,  # No-op quantizer, doesn't annotate anything
     CadenceW8A32MixedQuantizer,  # TODO: T247438158 Add test coverage
     CadenceRmsNormNopQuantizer,  # No-op quantizer, doesn't annotate anything, preserves rms_norm from decomposition
@@ -134,6 +134,61 @@
         # For add: both inputs are activations
         [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
     ),
+    # CadenceDefaultQuantizer test cases
+    (
+        "default_matmul_A8W8",
+        lambda self: self._build_matmul_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.matmul.default,
+        qconfig_A8W8.output_activation,
+        # For matmul: both inputs are activations
+        [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
+    ),
+    (
+        "default_linear_A8W8",
+        lambda self: self._build_linear_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.linear.default,
+        qconfig_A8W8.output_activation,
+        # For linear: [input_activation, weight]
+        [qconfig_A8W8.input_activation, qconfig_A8W8.weight],
+    ),
+    (
+        "default_conv1d_A8W8sym",
+        lambda self: self._build_conv1d_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.conv1d.default,
+        qconfig_A8W8sym.output_activation,
+        # For conv1d: [input_activation, weight] with symmetric weights
+        [qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight],
+    ),
+    (
+        "default_conv2d_A8W8sym",
+        lambda self: self._build_conv2d_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.conv2d.default,
+        qconfig_A8W8sym.output_activation,
+        # For conv2d: [input_activation, weight] with symmetric weights
+        [qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight],
+    ),
+    (
+        "default_bmm_A8W8",
+        lambda self: self._build_bmm_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.bmm.default,
+        qconfig_A8W8.output_activation,
+        # For bmm: both inputs are activations
+        [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
+    ),
+    (
+        "default_relu_A8W8",
+        lambda self: self._build_relu_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.relu.default,
+        qconfig_A8W8.output_activation,
+        # For relu: only input_activation
+        [qconfig_A8W8.input_activation],
+    ),
 ]
 
 # Derive the set of tested quantizer classes from the test cases.
@@ -307,6 +362,50 @@ def _build_add_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(add_nodes), 1, "Should find exactly one add node")
         return gm, add_nodes[0]
 
+    def _build_bmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with a bmm (batch matrix multiply) operation."""
+        builder = GraphBuilder()
+        # BMM requires 3D tensors: (batch, n, m) @ (batch, m, p) -> (batch, n, p)
+        x = builder.placeholder("x", torch.randn(2, 4, 8))
+        y = builder.placeholder("y", torch.randn(2, 8, 4))
+        bmm = builder.call_operator(
+            op=torch.ops.aten.bmm.default,
+            args=(x, y),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("bmm", torch.ops.aten.bmm.default)]}
+            ),
+        )
+        builder.output([bmm])
+        gm = builder.get_graph_module()
+
+        bmm_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.bmm.default,
+        )
+        self.assertEqual(len(bmm_nodes), 1, "Should find exactly one bmm node")
+        return gm, bmm_nodes[0]
+
+    def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with a relu operation."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 10))
+        relu = builder.call_operator(
+            op=torch.ops.aten.relu.default,
+            args=(x,),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("relu", torch.ops.aten.relu.default)]}
+            ),
+        )
+        builder.output([relu])
+        gm = builder.get_graph_module()
+
+        relu_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.relu.default,
+        )
+        self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
+        return gm, relu_nodes[0]
+
     @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
     def test_quantizer_annotation(
         self,

From bafed325f39bf817b80471d537cbd53dca5303ba Mon Sep 17 00:00:00 2001
From: Rahul Chandra
Date: Mon, 22 Dec 2025 07:44:34 -0800
Subject: [PATCH 6/6] Changing logic to deal with graphs with derived quantization spec (#16357)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/16357

We want to add a test for `default_addmm_A8W8` to fully finish testing `CadenceDefaultQuantizer`. However, there are a couple of changes we need to make to the testing function.

## Change 1: We allow passing `None` in the list of `QuantizationSpec`s

This is because the addmm op has 3 inputs: `bias`, `mat1`, `mat2`. The bias uses a `DerivedQuantizationSpec`, which is dynamically constructed with references to the actual graph nodes (`mat1` and `mat2`). We can't construct an identical `DerivedQuantizationSpec` in the test because we'd need to reference the exact same node objects that the quantizer creates internally. Since we can't compare it directly, we use `None` to skip validation for that input. If `mat1` and `mat2` are quantized correctly, the derived bias spec will be correct too.

https://www.internalfb.com/code/fbsource/[2cfdb40fd8b628da2f46366115516408cfb9f50f]/xplat/executorch/backends/cadence/aot/quantizer/patterns.py?lines=91-103
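
For intuition, here is a minimal sketch of what such a derived bias spec looks like. This is not the verbatim `patterns.py` code (see the link above for that); the helper names and exact qparams below are illustrative assumptions:

```
# Illustrative sketch only -- names and parameters are assumptions, not the
# actual patterns.py implementation.
import torch
from torch.ao.quantization.quantizer import DerivedQuantizationSpec


def derive_bias_qparams(obs_or_fqs):
    # The bias scale is derived from the observed input and weight scales;
    # the zero_point stays 0 for a symmetric int32 bias.
    act_scale, _ = obs_or_fqs[0].calculate_qparams()
    weight_scale, _ = obs_or_fqs[1].calculate_qparams()
    bias_scale = act_scale * weight_scale
    return bias_scale, torch.zeros_like(bias_scale, dtype=torch.int64)


def make_bias_qspec(addmm_node, mat1_node, mat2_node):
    # derived_from stores (producer, consumer) *graph node* edges, so an
    # equal spec cannot be rebuilt outside the quantizer that owns the nodes.
    return DerivedQuantizationSpec(
        derived_from=[(mat1_node, addmm_node), (mat2_node, addmm_node)],
        derive_qparams_fn=derive_bias_qparams,
        dtype=torch.int32,
        quant_min=-(2**31),
        quant_max=2**31 - 1,
    )
```

Because equality on `DerivedQuantizationSpec` compares the `derived_from` node references, a spec built in the test from different node objects would never compare equal, which is why the bias entry is skipped with `None`.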

## Change 2: We changed how we iterate through `input_qspec_map`

`input_qspec_map` is a dictionary mapping input nodes to their qspecs. The iteration order depends on insertion order, which follows how the quantizer processes `PartitionAnchors`. Each `QuantizationPattern` implements a `get_anchors()` method that returns a `PartitionAnchors` describing which arguments are inputs, weights, and biases. This is relevant because for `addmm`, the `PartitionAnchors` lists them as `inputs=[(node, 1)], weights=[(node, 2)], biases=[(node, 0, ...)]`. So the map might iterate in order `mat1, mat2, bias` (args indices 1, 2, 0) rather than `bias, mat1, mat2` (args indices 0, 1, 2). This means the previous index-based iteration no longer works, so we now iterate as follows:

```
for input_node, input_qspec in annotation.input_qspec_map.items():
    # Find the index of this input node in the op's args
    arg_index = None
    for i, arg in enumerate(op_node.args):
        if arg is input_node:
            arg_index = i
            break
    self.assertIsNotNone(
        arg_index,
        f"Input node {input_node} not found in op_node.args",
    )
    # Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec)
    if expected_input_qspecs[arg_index] is not None:
        self.assertEqual(
            input_qspec,
            expected_input_qspecs[arg_index],
            f"Input qspec mismatch at arg index {arg_index}",
        )
```

The new code looks up which argument index each `input_node` corresponds to by searching through `op_node.args`, rather than assuming the enumeration index `i` matches the argument position.

Reviewed By: hsharma35

Differential Revision: D88955761
---
 .../cadence/aot/tests/test_quantizer_ops.py   | 71 ++++++++++++++-----
 1 file changed, 55 insertions(+), 16 deletions(-)

diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py
index 37c9733cd3c..2f66f5fbaa4 100644
--- a/backends/cadence/aot/tests/test_quantizer_ops.py
+++ b/backends/cadence/aot/tests/test_quantizer_ops.py
@@ -61,6 +61,7 @@
 # Test case definitions for quantizer annotation tests.
 # Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
 # Adding a new quantizer test only requires adding a tuple to this list.
+# Note: Use None in expected_input_qspecs to skip comparison for that input (e.g., for DerivedQuantizationSpec).
 QUANTIZER_ANNOTATION_TEST_CASES: list[
     tuple[
         str,
@@ -68,7 +69,7 @@
         CadenceQuantizer,
         OpOverload,
         QuantizationSpec,
-        list[QuantizationSpec],
+        list[QuantizationSpec | None],
     ]
 ] = [
     (
@@ -189,6 +190,16 @@
         # For relu: only input_activation
         [qconfig_A8W8.input_activation],
     ),
+    (
+        "default_addmm_A8W8",
+        lambda self: self._build_addmm_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.addmm.default,
+        qconfig_A8W8.output_activation,
+        # For addmm: [bias (DerivedQuantizationSpec), mat1, mat2]
+        # Use None to skip comparison for bias since it's a DerivedQuantizationSpec
+        [None, qconfig_A8W8.input_activation, qconfig_A8W8.weight],
+    ),
 ]
 
 # Derive the set of tested quantizer classes from the test cases.
@@ -406,6 +417,31 @@ def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
         return gm, relu_nodes[0]
 
+    def _build_addmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with an addmm operation."""
+        builder = GraphBuilder()
+        # addmm: bias + (mat1 @ mat2)
+        # args: (bias, mat1, mat2)
+        bias = builder.placeholder("bias", torch.randn(5))
+        mat1 = builder.placeholder("mat1", torch.randn(1, 10))
+        mat2 = builder.placeholder("mat2", torch.randn(10, 5))
+        addmm = builder.call_operator(
+            op=torch.ops.aten.addmm.default,
+            args=(bias, mat1, mat2),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("addmm", torch.ops.aten.addmm.default)]}
+            ),
+        )
+        builder.output([addmm])
+        gm = builder.get_graph_module()
+
+        addmm_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.addmm.default,
+        )
+        self.assertEqual(len(addmm_nodes), 1, "Should find exactly one addmm node")
+        return gm, addmm_nodes[0]
+
     @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
     def test_quantizer_annotation(
         self,
@@ -414,7 +450,7 @@ def test_quantizer_annotation(
         quantizer: CadenceQuantizer,
         target: OpOverload,
         expected_output_qspec: QuantizationSpec,
-        expected_input_qspecs: list[QuantizationSpec],
+        expected_input_qspecs: list[QuantizationSpec | None],
     ) -> None:
         """Parameterized test for quantizer annotations."""
         gm, op_node = graph_builder_fn(self)
@@ -429,21 +465,24 @@
 
         # Verify input annotations
         self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
-        for i, (input_node, input_qspec) in enumerate(
-            annotation.input_qspec_map.items()
-        ):
-            expected_arg = op_node.args[i]
-            assert isinstance(expected_arg, torch.fx.Node)
-            self.assertEqual(
-                input_node,
-                expected_arg,
-                f"Input node mismatch at index {i}",
-            )
-            self.assertEqual(
-                input_qspec,
-                expected_input_qspecs[i],
-                f"Input qspec mismatch at index {i}",
+        for input_node, input_qspec in annotation.input_qspec_map.items():
+            # Find the index of this input node in the op's args
+            arg_index = None
+            for i, arg in enumerate(op_node.args):
+                if arg is input_node:
+                    arg_index = i
+                    break
+            self.assertIsNotNone(
+                arg_index,
+                f"Input node {input_node} not found in op_node.args",
             )
+            # Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec)
+            if expected_input_qspecs[arg_index] is not None:
+                self.assertEqual(
+                    input_qspec,
+                    expected_input_qspecs[arg_index],
+                    f"Input qspec mismatch at arg index {arg_index}",
+                )
 
     def test_all_quantizers_have_annotation_tests(self) -> None:
         """Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""