Commit ccdb8e8

RahulC7 authored and meta-codesync[bot] committed

Adding Tests for CadenceFusedConvReluQuantizer (#16358)

Summary: Pull Request resolved: #16358

A fused pattern is one in which the quantizer recognizes a sequence of operations and treats it as a single unit for quantization purposes. For example, for a Conv2D + ReLU fusion, rather than quantizing each op separately:

```
input → [quantize] → conv2d → [dequantize] → [quantize] → relu → [dequantize] → output
```

a fused pattern quantizes them together:

```
input → [quantize] → conv2d → relu → [dequantize] → output
```

We need to make a few changes in our test framework to support this.

# Change 1: We allow graph builders to return a 3rd element for fused patterns

For fused patterns like conv+relu, the quantization annotations are split across two nodes:

- The output annotation is on the relu node (the final output of the fused pattern).
- The input annotations are on the conv node (where the quantized inputs enter).

The existing graph builders return (gm, target_node), which works for single-op patterns where both annotations live on the same node. For fused patterns we need to know both nodes, so graph builders can now optionally return (gm, output_node, input_source_node).

# Change 2: We check annotations on the correct nodes for fused patterns

The test previously assumed output_qspec and input_qspec_map were both on the same node. For fused patterns, they are on different nodes:

- output_qspec is checked on the output node (relu).
- input_qspec_map is checked on the input source node (conv).

This change is backwards-compatible: for non-fused patterns, both nodes are the same.

Reviewed By: hsharma35

Differential Revision: D89630759
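For illustration, here is a minimal sketch of the annotation split described above. It is not part of this diff: `annotate_fused_conv_relu` is a hypothetical helper, and the import path and `Q_ANNOTATION_KEY` value assume PyTorch's PT2E quantizer API as used by the test file.

```python
import torch.fx
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    QuantizationSpecBase,
)

# Meta key assumed to match Q_ANNOTATION_KEY in the test file.
Q_ANNOTATION_KEY = "quantization_annotation"


def annotate_fused_conv_relu(
    conv_node: torch.fx.Node,
    relu_node: torch.fx.Node,
    act_qspec: QuantizationSpecBase,
    weight_qspec: QuantizationSpecBase,
    out_qspec: QuantizationSpecBase,
) -> None:
    """Hypothetical helper: annotate conv2d + relu as one fused unit."""
    x_node, weight_node = conv_node.args[0], conv_node.args[1]
    # Quantized inputs enter the pattern at the conv node, so the
    # input_qspec_map lives in conv's annotation.
    conv_node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map={x_node: act_qspec, weight_node: weight_qspec},
        _annotated=True,
    )
    # The fused output is observed after relu, so the output_qspec lives
    # there; no dequantize/quantize pair appears between conv and relu.
    relu_node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
        output_qspec=out_qspec,
        _annotated=True,
    )
```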
1 parent: b2893dd

File tree: 1 file changed (+85 −14 lines)


backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 85 additions & 14 deletions
```diff
@@ -43,15 +43,18 @@
 
 # Type alias for graph builder functions.
 # These functions take a test instance and return a graph module and the target op node.
+# For fused patterns (e.g., conv+relu), an optional third element specifies the node
+# whose args contain the quantized inputs (e.g., conv node for conv+relu fusion).
 GraphBuilderFn = Callable[
-    ["QuantizerAnnotationTest"], tuple[torch.fx.GraphModule, torch.fx.Node]
+    ["QuantizerAnnotationTest"],
+    tuple[torch.fx.GraphModule, torch.fx.Node]
+    | tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node],
 ]
 
 
 # Quantizers intentionally excluded from annotation testing.
 # These should be explicitly justified when added.
 EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
-    CadenceFusedConvReluQuantizer,  # TODO: T247438151 Add test coverage
     CadenceNopQuantizer,  # No-op quantizer, doesn't annotate anything
     CadenceW8A32MixedQuantizer,  # TODO: T247438158 Add test coverage
     CadenceRmsNormNopQuantizer,  # No-op quantizer, doesn't annotate anything, preserves rms_norm from decomposition
@@ -200,6 +203,16 @@
         # Use None to skip comparison for bias since it's a DerivedQuantizationSpec
         [None, qconfig_A8W8.input_activation, qconfig_A8W8.weight],
     ),
+    # CadenceFusedConvReluQuantizer test cases
+    (
+        "fused_conv2d_relu_A8W8sym",
+        lambda self: self._build_conv2d_relu_graph(),
+        CadenceFusedConvReluQuantizer(),
+        torch.ops.aten.relu.default,
+        qconfig_A8W8sym.output_activation,
+        # For fused conv2d+relu: [input_activation, weight] from conv2d node
+        [qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight],
+    ),
 ]
 
 # Derive the set of tested quantizer classes from the test cases.
@@ -442,6 +455,52 @@ def _build_addmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(addmm_nodes), 1, "Should find exactly one addmm node")
         return gm, addmm_nodes[0]
 
+    def _build_conv2d_relu_graph(
+        self,
+    ) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]:
+        """Build a graph with a conv2d followed by relu (fused pattern).
+
+        Returns:
+            A tuple of (graph_module, relu_node, conv_node).
+            The relu_node is the target node where the annotation is placed.
+            The conv_node is the input source node whose args contain the quantized inputs.
+        """
+        builder = GraphBuilder()
+        # Input shape: (batch, in_channels, height, width)
+        x = builder.placeholder("x", torch.randn(1, 3, 8, 8))
+        # Weight shape: (out_channels, in_channels, kernel_h, kernel_w)
+        weight = builder.placeholder("weight", torch.randn(6, 3, 3, 3))
+        conv2d = builder.call_operator(
+            op=torch.ops.aten.conv2d.default,
+            args=(x, weight),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("conv2d", torch.ops.aten.conv2d.default)]}
+            ),
+        )
+        relu = builder.call_operator(
+            op=torch.ops.aten.relu.default,
+            args=(conv2d,),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("relu", torch.ops.aten.relu.default)]}
+            ),
+        )
+        builder.output([relu])
+        gm = builder.get_graph_module()
+
+        relu_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.relu.default,
+        )
+        self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
+
+        conv2d_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.conv2d.default,
+        )
+        self.assertEqual(len(conv2d_nodes), 1, "Should find exactly one conv2d node")
+
+        return gm, relu_nodes[0], conv2d_nodes[0]
+
     @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
     def test_quantizer_annotation(
         self,
@@ -453,28 +512,40 @@ def test_quantizer_annotation(
         expected_input_qspecs: list[QuantizationSpec | None],
     ) -> None:
         """Parameterized test for quantizer annotations."""
-        gm, op_node = graph_builder_fn(self)
+        result = graph_builder_fn(self)
+        # Handle both 2-element and 3-element returns from graph builders.
+        # For fused patterns, the 3rd element specifies the node whose args
+        # contain the quantized inputs (e.g., conv node for conv+relu fusion).
+        if len(result) == 3:
+            gm, output_node, input_source_node = result
+        else:
+            gm, output_node = result
+            input_source_node = output_node
 
         quantizer.annotate(gm)
 
-        annotation: QuantizationAnnotation = op_node.meta[Q_ANNOTATION_KEY]
-        self.assertTrue(annotation._annotated)
-
-        # Verify output annotation
-        self.assertEqual(annotation.output_qspec, expected_output_qspec)
+        # Verify output annotation (always on the output node)
+        output_annotation: QuantizationAnnotation = output_node.meta[Q_ANNOTATION_KEY]
+        self.assertTrue(output_annotation._annotated)
+        self.assertEqual(output_annotation.output_qspec, expected_output_qspec)
 
-        # Verify input annotations
-        self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
-        for input_node, input_qspec in annotation.input_qspec_map.items():
-            # Find the index of this input node in the op's args
+        # Verify input annotations (on the input source node, which may differ for fused patterns)
+        input_annotation: QuantizationAnnotation = input_source_node.meta[
+            Q_ANNOTATION_KEY
+        ]
+        self.assertEqual(
+            len(input_annotation.input_qspec_map), len(expected_input_qspecs)
+        )
+        for input_node, input_qspec in input_annotation.input_qspec_map.items():
+            # Find the index of this input node in the input source node's args
             arg_index = None
-            for i, arg in enumerate(op_node.args):
+            for i, arg in enumerate(input_source_node.args):
                 if arg is input_node:
                     arg_index = i
                     break
             self.assertIsNotNone(
                 arg_index,
-                f"Input node {input_node} not found in op_node.args",
+                f"Input node {input_node} not found in input_source_node.args",
             )
             # Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec)
             if expected_input_qspecs[arg_index] is not None:
```
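Condensing the diff above, the fused case boils down to the following flow. This is a sketch rather than test code from the commit: it assumes the test-class context shown in the diff (`self`, the `qconfig_A8W8sym` fixture, `Q_ANNOTATION_KEY`, and the imported `CadenceFusedConvReluQuantizer`).

```python
# Condensed sketch of what the fused-pattern test exercises, assuming the
# test-class context from the diff above.
gm, relu_node, conv_node = self._build_conv2d_relu_graph()
CadenceFusedConvReluQuantizer().annotate(gm)

# The output qspec sits on the relu node, the end of the fused pattern...
out_ann = relu_node.meta[Q_ANNOTATION_KEY]
assert out_ann.output_qspec == qconfig_A8W8sym.output_activation

# ...while the input qspec map sits on the conv node, keyed by its arg
# nodes (the activation x and the weight).
in_ann = conv_node.meta[Q_ANNOTATION_KEY]
assert set(in_ann.input_qspec_map) == {conv_node.args[0], conv_node.args[1]}
```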

0 commit comments