Skip to content

Commit fd92112

Browse files
RahulC7 and meta-codesync[bot]
authored and committed
Adding Tests for CadenceFusedConvReluQuantizer (#16358)
Summary: Pull Request resolved: #16358 A fused pattern is when the quantizer recognizes a sequence of operations and treats it as a single unit for quantization purposes. So for example, for a Conv2D + ReLU fusion, rather than having something like this: ``` input → [quantize] → conv2d → [dequantize] → [quantize] → relu → [dequantize] → output ``` a fused pattern quantizes them together like so: ``` input → [quantize] → conv2d → relu → [dequantize] → output ``` We need to make a few changes in our framework to test this. # Change 1: We allow graph builders to return a 3rd element for fused patterns For fused patterns like conv+relu, the quantization annotations are split across two nodes: - Output annotation is on the relu node (the final output of the fused pattern) - Input annotations are on the conv node (where the quantized inputs enter) The existing graph builders return (gm, target_node), which works for single-op patterns where both annotations are on the same node. For fused patterns, we need to know both nodes, so graph builders can now optionally return (gm, output_node, input_source_node). # Change 2: We check annotations on the correct nodes for fused patterns The test previously assumed output_qspec and input_qspec_map were both on the same node. For fused patterns, they're on different nodes: - output_qspec is checked on the output node (relu) - input_qspec_map is checked on the input source node (conv) This change is backwards-compatible: for non-fused patterns, both nodes are the same. Reviewed By: hsharma35 Differential Revision: D89630759
1 parent 9a049ea commit fd92112

File tree

1 file changed

+90
-14
lines changed

1 file changed

+90
-14
lines changed

backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 90 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -43,15 +43,18 @@
4343

4444
# Type alias for graph builder functions.
4545
# These functions take a test instance and return a graph module and the target op node.
46+
# For fused patterns (e.g., conv+relu), an optional third element specifies the node
47+
# whose args contain the quantized inputs (e.g., conv node for conv+relu fusion).
4648
GraphBuilderFn = Callable[
47-
["QuantizerAnnotationTest"], tuple[torch.fx.GraphModule, torch.fx.Node]
49+
["QuantizerAnnotationTest"],
50+
tuple[torch.fx.GraphModule, torch.fx.Node]
51+
| tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node],
4852
]
4953

5054

5155
# Quantizers intentionally excluded from annotation testing.
5256
# These should be explicitly justified when added.
5357
EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
54-
CadenceFusedConvReluQuantizer, # TODO: T247438151 Add test coverage
5558
CadenceNopQuantizer, # No-op quantizer, doesn't annotate anything
5659
CadenceW8A32MixedQuantizer, # TODO: T247438158 Add test coverage
5760
CadenceRmsNormNopQuantizer, # No-op quantizer, doesn't annotate anything, preserves rms_norm from decomposition
@@ -200,6 +203,16 @@
200203
# Use None to skip comparison for bias since it's a DerivedQuantizationSpec
201204
[None, qconfig_A8W8.input_activation, qconfig_A8W8.weight],
202205
),
206+
# CadenceFusedConvReluQuantizer test cases
207+
(
208+
"fused_conv2d_relu_A8W8sym",
209+
lambda self: self._build_conv2d_relu_graph(),
210+
CadenceFusedConvReluQuantizer(),
211+
torch.ops.aten.relu.default,
212+
qconfig_A8W8sym.output_activation,
213+
# For fused conv2d+relu: [input_activation, weight] from conv2d node
214+
[qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight],
215+
),
203216
]
204217

205218
# Derive the set of tested quantizer classes from the test cases.
@@ -441,6 +454,52 @@ def _build_addmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
441454
self.assertEqual(len(addmm_nodes), 1, "Should find exactly one addmm node")
442455
return gm, addmm_nodes[0]
443456

457+
def _build_conv2d_relu_graph(
458+
self,
459+
) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]:
460+
"""Build a graph with a conv2d followed by relu (fused pattern).
461+
462+
Returns:
463+
A tuple of (graph_module, relu_node, conv_node).
464+
The relu_node is the target node where the annotation is placed.
465+
The conv_node is the input source node whose args contain the quantized inputs.
466+
"""
467+
builder = GraphBuilder()
468+
# Input shape: (batch, in_channels, height, width)
469+
x = builder.placeholder("x", torch.randn(1, 3, 8, 8))
470+
# Weight shape: (out_channels, in_channels, kernel_h, kernel_w)
471+
weight = builder.placeholder("weight", torch.randn(6, 3, 3, 3))
472+
conv2d = builder.call_operator(
473+
op=torch.ops.aten.conv2d.default,
474+
args=(x, weight),
475+
meta=NodeMetadata(
476+
{"source_fn_stack": [("conv2d", torch.ops.aten.conv2d.default)]}
477+
),
478+
)
479+
relu = builder.call_operator(
480+
op=torch.ops.aten.relu.default,
481+
args=(conv2d,),
482+
meta=NodeMetadata(
483+
{"source_fn_stack": [("relu", torch.ops.aten.relu.default)]}
484+
),
485+
)
486+
builder.output([relu])
487+
gm = builder.get_graph_module()
488+
489+
relu_nodes = gm.graph.find_nodes(
490+
op="call_function",
491+
target=torch.ops.aten.relu.default,
492+
)
493+
self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
494+
495+
conv2d_nodes = gm.graph.find_nodes(
496+
op="call_function",
497+
target=torch.ops.aten.conv2d.default,
498+
)
499+
self.assertEqual(len(conv2d_nodes), 1, "Should find exactly one conv2d node")
500+
501+
return gm, relu_nodes[0], conv2d_nodes[0]
502+
444503
@parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
445504
def test_quantizer_annotation(
446505
self,
@@ -452,28 +511,45 @@ def test_quantizer_annotation(
452511
expected_input_qspecs: list[QuantizationSpec | None],
453512
) -> None:
454513
"""Parameterized test for quantizer annotations."""
455-
gm, op_node = graph_builder_fn(self)
514+
result = graph_builder_fn(self)
515+
# Handle both 2-element and 3-element returns from graph builders.
516+
# For fused patterns, the 3rd element specifies the node whose args
517+
# contain the quantized inputs (e.g., conv node for conv+relu fusion).
518+
if len(result) == 3:
519+
gm = result[0]
520+
output_node = result[1]
521+
input_source_node = result[2]
522+
else:
523+
gm = result[0]
524+
output_node = result[1]
525+
input_source_node = output_node
456526

457527
quantizer.annotate(gm)
458528

459-
annotation: QuantizationAnnotation = op_node.meta[Q_ANNOTATION_KEY]
460-
self.assertTrue(annotation._annotated)
461-
462-
# Verify output annotation
463-
self.assertEqual(annotation.output_qspec, expected_output_qspec)
529+
# Verify output annotation (always on the output node)
530+
output_annotation: QuantizationAnnotation = output_node.meta[Q_ANNOTATION_KEY]
531+
self.assertTrue(output_annotation._annotated)
532+
self.assertEqual(output_annotation.output_qspec, expected_output_qspec)
464533

465-
# Verify input annotations
466-
self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
467-
for input_node, input_qspec in annotation.input_qspec_map.items():
468-
# Find the index of this input node in the op's args
534+
# Verify input annotations (on the input source node, which may differ for fused patterns)
535+
input_annotation: QuantizationAnnotation = input_source_node.meta[
536+
Q_ANNOTATION_KEY
537+
]
538+
self.assertEqual(
539+
len(input_annotation.input_qspec_map), len(expected_input_qspecs)
540+
)
541+
for input_node, input_qspec in input_annotation.input_qspec_map.items():
542+
# Find the index of this input node in the input source node's args
469543
arg_index = None
470-
for i, arg in enumerate(op_node.args):
544+
args = input_source_node.args
545+
assert isinstance(args, tuple)
546+
for i, arg in enumerate(args):
471547
if arg is input_node:
472548
arg_index = i
473549
break
474550
self.assertIsNotNone(
475551
arg_index,
476-
f"Input node {input_node} not found in op_node.args",
552+
f"Input node {input_node} not found in input_source_node.args",
477553
)
478554
# Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec)
479555
if expected_input_qspecs[arg_index] is not None:

0 commit comments

Comments (0)