Skip to content

Commit bafed32

Browse files
RahulC7 and meta-codesync[bot]
authored and committed
Changing logic to deal with graphs with derived quantization spec (#16357)
Summary: Pull Request resolved: #16357 We want to add a test for `default_addmm_A8W8` to fully finish testing `CadenceDefaultQuantizer`. However there are a couple changes we need to make to the testing function. ## Change 1: We allow passing `None` in the vec of `QuantizationSpec` This is because the addmm op has 3 inputs: `bias`, `mat1`, `mat2`. The bias uses a `DerivedQuantizationSpec`, which is dynamically constructed with references to the actual graph nodes (`mat1` and `mat2`). We can't construct an identical `DerivedQuantizationSpec` in the test because we'd need to reference the exact same node objects that the quantizer creates internally. Since we can't compare it directly, we use `None` to skip validation for that input. If `mat1` and `mat2` are quantized correctly, the derived bias spec will be correct too. https://www.internalfb.com/code/fbsource/[2cfdb40fd8b628da2f46366115516408cfb9f50f]/xplat/executorch/backends/cadence/aot/quantizer/patterns.py?lines=91-103 ## Change 2: We changed how we iterate through `input_qspec_map` `input_qspec_map` is a dictionary mapping input nodes to their `qspecs`. The iteration order depends on insertion order, which follows how the quantizer processes `PartitionAnchors`. Each `QuantizationPattern` implements a `get_anchors()` method that returns a `PartitionAnchors` describing which arguments are inputs, weights, biases and nodes. This is relevant because for `addmm`, the `PartitionAnchors` lists them as `inputs=[(node, 1)], weights=[(node, 2)], biases=[(node, 0, ...)]. ` So the map might iterate in order `mat1, mat2, bias` (args indices 1, 2, 0) rather than `bias, mat1, mat2` (args indices 0, 1, 2). This means that our previous way of iterating wouldn't work. 
Thus, we now use the following way to iterate: ``` for input_node, input_qspec in annotation.input_qspec_map.items(): # Find the index of this input node in the op's args arg_index = None for i, arg in enumerate(op_node.args): if arg is input_node: arg_index = i break self.assertIsNotNone( arg_index, f"Input node {input_node} not found in op_node.args", ) # Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec) if expected_input_qspecs[arg_index] is not None: self.assertEqual( input_qspec, expected_input_qspecs[arg_index], f"Input qspec mismatch at arg index {arg_index}", ) ``` The new code looks up which argument index each `input_node` corresponds to by searching through `op_node.args`, rather than assuming the enumeration index `i` matches the argument position. Reviewed By: hsharma35 Differential Revision: D88955761
1 parent 74805a1 commit bafed32

File tree

1 file changed

+55
-16
lines changed

1 file changed

+55
-16
lines changed

backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,15 @@
6161
# Test case definitions for quantizer annotation tests.
6262
# Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
6363
# Adding a new quantizer test only requires adding a tuple to this list.
64+
# Note: Use None in expected_input_qspecs to skip comparison for that input (e.g., for DerivedQuantizationSpec).
6465
QUANTIZER_ANNOTATION_TEST_CASES: list[
6566
tuple[
6667
str,
6768
GraphBuilderFn,
6869
CadenceQuantizer,
6970
OpOverload,
7071
QuantizationSpec,
71-
list[QuantizationSpec],
72+
list[QuantizationSpec | None],
7273
]
7374
] = [
7475
(
@@ -189,6 +190,16 @@
189190
# For relu: only input_activation
190191
[qconfig_A8W8.input_activation],
191192
),
193+
(
194+
"default_addmm_A8W8",
195+
lambda self: self._build_addmm_graph(),
196+
CadenceDefaultQuantizer(),
197+
torch.ops.aten.addmm.default,
198+
qconfig_A8W8.output_activation,
199+
# For addmm: [bias (DerivedQuantizationSpec), mat1, mat2]
200+
# Use None to skip comparison for bias since it's a DerivedQuantizationSpec
201+
[None, qconfig_A8W8.input_activation, qconfig_A8W8.weight],
202+
),
192203
]
193204

194205
# Derive the set of tested quantizer classes from the test cases.
@@ -406,6 +417,31 @@ def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
406417
self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
407418
return gm, relu_nodes[0]
408419

420+
def _build_addmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
421+
"""Build a simple graph with an addmm operation."""
422+
builder = GraphBuilder()
423+
# addmm: bias + (mat1 @ mat2)
424+
# args: (bias, mat1, mat2)
425+
bias = builder.placeholder("bias", torch.randn(5))
426+
mat1 = builder.placeholder("mat1", torch.randn(1, 10))
427+
mat2 = builder.placeholder("mat2", torch.randn(10, 5))
428+
addmm = builder.call_operator(
429+
op=torch.ops.aten.addmm.default,
430+
args=(bias, mat1, mat2),
431+
meta=NodeMetadata(
432+
{"source_fn_stack": [("addmm", torch.ops.aten.addmm.default)]}
433+
),
434+
)
435+
builder.output([addmm])
436+
gm = builder.get_graph_module()
437+
438+
addmm_nodes = gm.graph.find_nodes(
439+
op="call_function",
440+
target=torch.ops.aten.addmm.default,
441+
)
442+
self.assertEqual(len(addmm_nodes), 1, "Should find exactly one addmm node")
443+
return gm, addmm_nodes[0]
444+
409445
@parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
410446
def test_quantizer_annotation(
411447
self,
@@ -414,7 +450,7 @@ def test_quantizer_annotation(
414450
quantizer: CadenceQuantizer,
415451
target: OpOverload,
416452
expected_output_qspec: QuantizationSpec,
417-
expected_input_qspecs: list[QuantizationSpec],
453+
expected_input_qspecs: list[QuantizationSpec | None],
418454
) -> None:
419455
"""Parameterized test for quantizer annotations."""
420456
gm, op_node = graph_builder_fn(self)
@@ -429,21 +465,24 @@ def test_quantizer_annotation(
429465

430466
# Verify input annotations
431467
self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
432-
for i, (input_node, input_qspec) in enumerate(
433-
annotation.input_qspec_map.items()
434-
):
435-
expected_arg = op_node.args[i]
436-
assert isinstance(expected_arg, torch.fx.Node)
437-
self.assertEqual(
438-
input_node,
439-
expected_arg,
440-
f"Input node mismatch at index {i}",
441-
)
442-
self.assertEqual(
443-
input_qspec,
444-
expected_input_qspecs[i],
445-
f"Input qspec mismatch at index {i}",
468+
for input_node, input_qspec in annotation.input_qspec_map.items():
469+
# Find the index of this input node in the op's args
470+
arg_index = None
471+
for i, arg in enumerate(op_node.args):
472+
if arg is input_node:
473+
arg_index = i
474+
break
475+
self.assertIsNotNone(
476+
arg_index,
477+
f"Input node {input_node} not found in op_node.args",
446478
)
479+
# Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec)
480+
if expected_input_qspecs[arg_index] is not None:
481+
self.assertEqual(
482+
input_qspec,
483+
expected_input_qspecs[arg_index],
484+
f"Input qspec mismatch at arg index {arg_index}",
485+
)
447486

448487
def test_all_quantizers_have_annotation_tests(self) -> None:
449488
"""Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""

0 commit comments

Comments (0)