Commit b2893dd

RahulC7 authored and meta-codesync[bot] committed
Changing logic to deal with graphs with derived quantization spec
Differential Revision: D88955761
1 parent 74805a1 commit b2893dd
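
The new addmm test case exercises a graph where the bias input is annotated with a DerivedQuantizationSpec rather than a fixed QuantizationSpec: its quantization parameters are computed at observation time from the activation and weight inputs, so there is no static expected value the test table could compare against. As a minimal sketch of why that is, assuming the torch.ao PT2E quantizer API (the node names act_node, weight_node, and addmm_node are hypothetical stand-ins for the fx nodes the real quantizer supplies during annotation):

    import torch
    from torch.ao.quantization.quantizer import DerivedQuantizationSpec

    # Hypothetical sketch: derive bias qparams from the already-observed
    # activation and weight inputs (bias_scale = act_scale * weight_scale).
    def derive_bias_qparams_fn(obs_or_fqs):
        act_obs, weight_obs = obs_or_fqs
        act_scale, _ = act_obs.calculate_qparams()
        weight_scale, _ = weight_obs.calculate_qparams()
        return (
            act_scale * weight_scale,
            torch.zeros_like(weight_scale, dtype=torch.int32),  # zero_point = 0
        )

    # act_node, weight_node, addmm_node are placeholders for graph nodes.
    bias_qspec = DerivedQuantizationSpec(
        derived_from=[(act_node, addmm_node), (weight_node, addmm_node)],
        derive_qparams_fn=derive_bias_qparams_fn,
        dtype=torch.int32,
        quant_min=-(2**31),
        quant_max=2**31 - 1,
        qscheme=torch.per_tensor_affine,
    )

Because such a spec holds references to graph-specific nodes and a callable, equality against a value constructed outside the test graph would always fail, which is what the None sentinel in the test table works around.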

File tree

1 file changed: +55 -16 lines changed


backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 55 additions & 16 deletions
@@ -61,14 +61,15 @@
 # Test case definitions for quantizer annotation tests.
 # Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
 # Adding a new quantizer test only requires adding a tuple to this list.
+# Note: Use None in expected_input_qspecs to skip comparison for that input (e.g., for DerivedQuantizationSpec).
 QUANTIZER_ANNOTATION_TEST_CASES: list[
     tuple[
         str,
         GraphBuilderFn,
         CadenceQuantizer,
         OpOverload,
         QuantizationSpec,
-        list[QuantizationSpec],
+        list[QuantizationSpec | None],
     ]
 ] = [
     (
@@ -189,6 +190,16 @@
         # For relu: only input_activation
         [qconfig_A8W8.input_activation],
     ),
+    (
+        "default_addmm_A8W8",
+        lambda self: self._build_addmm_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.addmm.default,
+        qconfig_A8W8.output_activation,
+        # For addmm: [bias (DerivedQuantizationSpec), mat1, mat2]
+        # Use None to skip comparison for bias since it's a DerivedQuantizationSpec
+        [None, qconfig_A8W8.input_activation, qconfig_A8W8.weight],
+    ),
 ]
 
 # Derive the set of tested quantizer classes from the test cases.
@@ -406,6 +417,31 @@ def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
         return gm, relu_nodes[0]
 
+    def _build_addmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with an addmm operation."""
+        builder = GraphBuilder()
+        # addmm: bias + (mat1 @ mat2)
+        # args: (bias, mat1, mat2)
+        bias = builder.placeholder("bias", torch.randn(5))
+        mat1 = builder.placeholder("mat1", torch.randn(1, 10))
+        mat2 = builder.placeholder("mat2", torch.randn(10, 5))
+        addmm = builder.call_operator(
+            op=torch.ops.aten.addmm.default,
+            args=(bias, mat1, mat2),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("addmm", torch.ops.aten.addmm.default)]}
+            ),
+        )
+        builder.output([addmm])
+        gm = builder.get_graph_module()
+
+        addmm_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.addmm.default,
+        )
+        self.assertEqual(len(addmm_nodes), 1, "Should find exactly one addmm node")
+        return gm, addmm_nodes[0]
+
     @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
     def test_quantizer_annotation(
         self,
@@ -414,7 +450,7 @@ def test_quantizer_annotation(
         quantizer: CadenceQuantizer,
         target: OpOverload,
         expected_output_qspec: QuantizationSpec,
-        expected_input_qspecs: list[QuantizationSpec],
+        expected_input_qspecs: list[QuantizationSpec | None],
     ) -> None:
         """Parameterized test for quantizer annotations."""
         gm, op_node = graph_builder_fn(self)
@@ -429,21 +465,24 @@
 
         # Verify input annotations
         self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
-        for i, (input_node, input_qspec) in enumerate(
-            annotation.input_qspec_map.items()
-        ):
-            expected_arg = op_node.args[i]
-            assert isinstance(expected_arg, torch.fx.Node)
-            self.assertEqual(
-                input_node,
-                expected_arg,
-                f"Input node mismatch at index {i}",
-            )
-            self.assertEqual(
-                input_qspec,
-                expected_input_qspecs[i],
-                f"Input qspec mismatch at index {i}",
+        for input_node, input_qspec in annotation.input_qspec_map.items():
+            # Find the index of this input node in the op's args
+            arg_index = None
+            for i, arg in enumerate(op_node.args):
+                if arg is input_node:
+                    arg_index = i
+                    break
+            self.assertIsNotNone(
+                arg_index,
+                f"Input node {input_node} not found in op_node.args",
             )
+            # Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec)
+            if expected_input_qspecs[arg_index] is not None:
+                self.assertEqual(
+                    input_qspec,
+                    expected_input_qspecs[arg_index],
+                    f"Input qspec mismatch at arg index {arg_index}",
+                )
 
     def test_all_quantizers_have_annotation_tests(self) -> None:
         """Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""
