Commit 9a049ea

RahulC7 authored and meta-codesync[bot] committed

Changing logic to deal with graphs with derived quantization spec

Differential Revision: D88955761

1 parent 9d7a808 commit 9a049ea
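
For context: the "derived quantization spec" in the title refers to DerivedQuantizationSpec from the PT2E quantization workflow, a qspec whose scale and zero-point are computed from other tensors' observers instead of being observed directly. Below is a minimal sketch of the common addmm/linear bias pattern, assuming the stock torch.ao.quantization.quantizer API; the helper derive_bias_qparams and the toy graph are illustrative, not the Cadence quantizer's actual code.

import torch
from torch.ao.quantization.quantizer import DerivedQuantizationSpec

def derive_bias_qparams(obs_or_fqs):
    # bias_scale = act_scale * weight_scale, zero-point fixed at 0.
    act_obs, weight_obs = obs_or_fqs
    act_scale, _ = act_obs.calculate_qparams()
    weight_scale, _ = weight_obs.calculate_qparams()
    scale = (act_scale * weight_scale).to(torch.float32)
    return scale, torch.zeros_like(scale, dtype=torch.int32)

# Toy graph, only to obtain real fx.Node handles for the input edges.
class M(torch.nn.Module):
    def forward(self, bias, mat1, mat2):
        return torch.addmm(bias, mat1, mat2)

gm = torch.fx.symbolic_trace(M())
addmm_node = next(n for n in gm.graph.nodes if n.op == "call_function")
_, mat1_node, mat2_node = addmm_node.args

# Bias qparams are derived from the (activation, weight) input edges.
bias_qspec = DerivedQuantizationSpec(
    derived_from=[(mat1_node, addmm_node), (mat2_node, addmm_node)],
    derive_qparams_fn=derive_bias_qparams,
    dtype=torch.int32,
    quant_min=-(2**31),
    quant_max=2**31 - 1,
    qscheme=torch.per_tensor_affine,
)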

File tree

1 file changed (+55, -16)

backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 55 additions & 16 deletions
@@ -61,14 +61,15 @@
 # Test case definitions for quantizer annotation tests.
 # Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
 # Adding a new quantizer test only requires adding a tuple to this list.
+# Note: Use None in expected_input_qspecs to skip comparison for that input (e.g., for DerivedQuantizationSpec).
 QUANTIZER_ANNOTATION_TEST_CASES: list[
     tuple[
         str,
         GraphBuilderFn,
         CadenceQuantizer,
         OpOverload,
         QuantizationSpec,
-        list[QuantizationSpec],
+        list[QuantizationSpec | None],
     ]
 ] = [
     (
@@ -189,6 +190,16 @@
         # For relu: only input_activation
         [qconfig_A8W8.input_activation],
     ),
+    (
+        "default_addmm_A8W8",
+        lambda self: self._build_addmm_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.addmm.default,
+        qconfig_A8W8.output_activation,
+        # For addmm: [bias (DerivedQuantizationSpec), mat1, mat2]
+        # Use None to skip comparison for bias since it's a DerivedQuantizationSpec
+        [None, qconfig_A8W8.input_activation, qconfig_A8W8.weight],
+    ),
 ]

 # Derive the set of tested quantizer classes from the test cases.
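
Why None instead of an expected bias spec: a DerivedQuantizationSpec references graph nodes and a derive_qparams_fn that only exist once the quantizer annotates the graph, and the frozen dataclass compares that callable by identity, so no spec built up front in this table could ever compare equal. A minimal illustration of the pitfall (assumed behavior of the stock torch.ao dataclass, not code from this test):

import torch
from torch.ao.quantization.quantizer import DerivedQuantizationSpec

def make_spec() -> DerivedQuantizationSpec:
    return DerivedQuantizationSpec(
        derived_from=[],
        derive_qparams_fn=lambda obs: (torch.tensor(1.0), torch.tensor(0)),
        dtype=torch.int32,
    )

# Each call creates a fresh lambda; functions compare by identity, so two
# otherwise-identical specs are never equal.
assert make_spec() != make_spec()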
@@ -405,6 +416,31 @@ def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
         return gm, relu_nodes[0]

+    def _build_addmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with an addmm operation."""
+        builder = GraphBuilder()
+        # addmm: bias + (mat1 @ mat2)
+        # args: (bias, mat1, mat2)
+        bias = builder.placeholder("bias", torch.randn(5))
+        mat1 = builder.placeholder("mat1", torch.randn(1, 10))
+        mat2 = builder.placeholder("mat2", torch.randn(10, 5))
+        addmm = builder.call_operator(
+            op=torch.ops.aten.addmm.default,
+            args=(bias, mat1, mat2),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("addmm", torch.ops.aten.addmm.default)]}
+            ),
+        )
+        builder.output([addmm])
+        gm = builder.get_graph_module()
+
+        addmm_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.addmm.default,
+        )
+        self.assertEqual(len(addmm_nodes), 1, "Should find exactly one addmm node")
+        return gm, addmm_nodes[0]
+
     @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
     def test_quantizer_annotation(
         self,
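
For reference, the parameterized test consumes this helper roughly as follows: the quantizer's annotate pass stores a QuantizationAnnotation under node.meta["quantization_annotation"], whose output_qspec and input_qspec_map the assertions then inspect. A sketch assuming the standard PT2E meta key, not an excerpt from the file:

# Inside a test method of this class (illustrative only).
gm, addmm_node = self._build_addmm_graph()
CadenceDefaultQuantizer().annotate(gm)

annotation = addmm_node.meta["quantization_annotation"]
# One output qspec plus one map entry per quantized input: bias, mat1, mat2.
assert annotation.output_qspec is not None
assert len(annotation.input_qspec_map) == 3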
@@ -413,7 +449,7 @@ def test_quantizer_annotation(
         quantizer: CadenceQuantizer,
         target: OpOverload,
         expected_output_qspec: QuantizationSpec,
-        expected_input_qspecs: list[QuantizationSpec],
+        expected_input_qspecs: list[QuantizationSpec | None],
     ) -> None:
         """Parameterized test for quantizer annotations."""
         gm, op_node = graph_builder_fn(self)
@@ -428,21 +464,24 @@

         # Verify input annotations
         self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
-        for i, (input_node, input_qspec) in enumerate(
-            annotation.input_qspec_map.items()
-        ):
-            expected_arg = op_node.args[i]
-            assert isinstance(expected_arg, torch.fx.Node)
-            self.assertEqual(
-                input_node,
-                expected_arg,
-                f"Input node mismatch at index {i}",
-            )
-            self.assertEqual(
-                input_qspec,
-                expected_input_qspecs[i],
-                f"Input qspec mismatch at index {i}",
+        for input_node, input_qspec in annotation.input_qspec_map.items():
+            # Find the index of this input node in the op's args
+            arg_index = None
+            for i, arg in enumerate(op_node.args):
+                if arg is input_node:
+                    arg_index = i
+                    break
+            self.assertIsNotNone(
+                arg_index,
+                f"Input node {input_node} not found in op_node.args",
             )
+            # Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec)
+            if expected_input_qspecs[arg_index] is not None:
+                self.assertEqual(
+                    input_qspec,
+                    expected_input_qspecs[arg_index],
+                    f"Input qspec mismatch at arg index {arg_index}",
+                )

     def test_all_quantizers_have_annotation_tests(self) -> None:
         """Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""
