
Commit af22ffe

RahulC7 authored and facebook-github-bot committed
Adding Tests for CadenceFusedConvReluQuantizer
Summary:

A fused pattern is one in which the quantizer recognizes a sequence of operations and treats it as a single unit for quantization purposes. For example, for a Conv2D + ReLU fusion, rather than quantizing each op separately like this:

```
input → [quantize] → conv2d → [dequantize] → [quantize] → relu → [dequantize] → output
```

a fused pattern quantizes them together, like so:

```
input → [quantize] → conv2d → relu → [dequantize] → output
```

We need to make a few changes in our framework to test this.

# Change 1: We allow graph builders to return a 3rd element for fused patterns

For fused patterns like conv+relu, the quantization annotations are split across two nodes:

- The output annotation is on the relu node (the final output of the fused pattern).
- The input annotations are on the conv node (where the quantized inputs enter).

The existing graph builders return (gm, target_node), which works for single-op patterns where both annotations are on the same node. For fused patterns we need to know both nodes, so graph builders can now optionally return (gm, output_node, input_source_node).

# Change 2: We check annotations on the correct nodes for fused patterns

The test previously assumed output_qspec and input_qspec_map were both on the same node. For fused patterns, they live on different nodes:

- output_qspec is checked on the output node (relu).
- input_qspec_map is checked on the input source node (conv).

This change is backwards-compatible: for non-fused patterns, both nodes are the same.

Differential Revision: D89630759
1 parent 0e22a37 commit af22ffe
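
To make the annotation split concrete, here is a minimal sketch of what the new test effectively asserts for the fused pattern. The node variables (`conv_node`, `relu_node`, `x_node`, `weight_node`) are illustrative stand-ins for the nodes produced by the graph builder in the diff below; the annotation fields are the ones the test reads:

```
# Illustrative sketch: conv_node / relu_node / x_node / weight_node are
# placeholder names for the nodes built by _build_conv2d_relu_graph below.
quantizer = CadenceFusedConvReluQuantizer()
quantizer.annotate(gm)

# Input qspecs land on the conv node, keyed by its input nodes (x, weight).
conv_annotation = conv_node.meta[Q_ANNOTATION_KEY]
assert set(conv_annotation.input_qspec_map) == {x_node, weight_node}

# The single output qspec lands on the relu node, the end of the fused unit.
relu_annotation = relu_node.meta[Q_ANNOTATION_KEY]
assert relu_annotation._annotated
assert relu_annotation.output_qspec == qconfig_A8W8sym.output_activation
```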

File tree: 1 file changed (+80, −14 lines)

backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 80 additions & 14 deletions
```
@@ -43,15 +43,17 @@

 # Type alias for graph builder functions.
 # These functions take a test instance and return a graph module and the target op node.
+# For fused patterns (e.g., conv+relu), an optional third element specifies the node
+# whose args contain the quantized inputs (e.g., conv node for conv+relu fusion).
 GraphBuilderFn = Callable[
-    ["QuantizerAnnotationTest"], tuple[torch.fx.GraphModule, torch.fx.Node]
+    ["QuantizerAnnotationTest"],
+    tuple[torch.fx.GraphModule, torch.fx.Node] | tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node],
 ]


 # Quantizers intentionally excluded from annotation testing.
 # These should be explicitly justified when added.
 EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
-    CadenceFusedConvReluQuantizer,  # TODO: T247438151 Add test coverage
     CadenceNopQuantizer,  # No-op quantizer, doesn't annotate anything
     CadenceW8A32MixedQuantizer,  # TODO: T247438158 Add test coverage
     CadenceRmsNormNopQuantizer,  # No-op quantizer, doesn't annotate anything, preserves rms_norm from decomposition
@@ -200,6 +202,16 @@
         # Use None to skip comparison for bias since it's a DerivedQuantizationSpec
         [None, qconfig_A8W8.input_activation, qconfig_A8W8.weight],
     ),
+    # CadenceFusedConvReluQuantizer test cases
+    (
+        "fused_conv2d_relu_A8W8sym",
+        lambda self: self._build_conv2d_relu_graph(),
+        CadenceFusedConvReluQuantizer(),
+        torch.ops.aten.relu.default,
+        qconfig_A8W8sym.output_activation,
+        # For fused conv2d+relu: [input_activation, weight] from conv2d node
+        [qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight],
+    ),
 ]

 # Derive the set of tested quantizer classes from the test cases.
@@ -442,6 +454,52 @@ def _build_addmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(addmm_nodes), 1, "Should find exactly one addmm node")
         return gm, addmm_nodes[0]

+    def _build_conv2d_relu_graph(
+        self,
+    ) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]:
+        """Build a graph with a conv2d followed by relu (fused pattern).
+
+        Returns:
+            A tuple of (graph_module, relu_node, conv_node).
+            The relu_node is the target node where the annotation is placed.
+            The conv_node is the input source node whose args contain the quantized inputs.
+        """
+        builder = GraphBuilder()
+        # Input shape: (batch, in_channels, height, width)
+        x = builder.placeholder("x", torch.randn(1, 3, 8, 8))
+        # Weight shape: (out_channels, in_channels, kernel_h, kernel_w)
+        weight = builder.placeholder("weight", torch.randn(6, 3, 3, 3))
+        conv2d = builder.call_operator(
+            op=torch.ops.aten.conv2d.default,
+            args=(x, weight),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("conv2d", torch.ops.aten.conv2d.default)]}
+            ),
+        )
+        relu = builder.call_operator(
+            op=torch.ops.aten.relu.default,
+            args=(conv2d,),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("relu", torch.ops.aten.relu.default)]}
+            ),
+        )
+        builder.output([relu])
+        gm = builder.get_graph_module()
+
+        relu_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.relu.default,
+        )
+        self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
+
+        conv2d_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.conv2d.default,
+        )
+        self.assertEqual(len(conv2d_nodes), 1, "Should find exactly one conv2d node")
+
+        return gm, relu_nodes[0], conv2d_nodes[0]
+
     @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
     def test_quantizer_annotation(
         self,
@@ -453,28 +511,36 @@ def test_quantizer_annotation(
         expected_input_qspecs: list[QuantizationSpec | None],
     ) -> None:
         """Parameterized test for quantizer annotations."""
-        gm, op_node = graph_builder_fn(self)
+        result = graph_builder_fn(self)
+        # Handle both 2-element and 3-element returns from graph builders.
+        # For fused patterns, the 3rd element specifies the node whose args
+        # contain the quantized inputs (e.g., conv node for conv+relu fusion).
+        if len(result) == 3:
+            gm, output_node, input_source_node = result
+        else:
+            gm, output_node = result
+            input_source_node = output_node

         quantizer.annotate(gm)

-        annotation: QuantizationAnnotation = op_node.meta[Q_ANNOTATION_KEY]
-        self.assertTrue(annotation._annotated)
-
-        # Verify output annotation
-        self.assertEqual(annotation.output_qspec, expected_output_qspec)
+        # Verify output annotation (always on the output node)
+        output_annotation: QuantizationAnnotation = output_node.meta[Q_ANNOTATION_KEY]
+        self.assertTrue(output_annotation._annotated)
+        self.assertEqual(output_annotation.output_qspec, expected_output_qspec)

-        # Verify input annotations
-        self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
-        for input_node, input_qspec in annotation.input_qspec_map.items():
-            # Find the index of this input node in the op's args
+        # Verify input annotations (on the input source node, which may differ for fused patterns)
+        input_annotation: QuantizationAnnotation = input_source_node.meta[Q_ANNOTATION_KEY]
+        self.assertEqual(len(input_annotation.input_qspec_map), len(expected_input_qspecs))
+        for input_node, input_qspec in input_annotation.input_qspec_map.items():
+            # Find the index of this input node in the input source node's args
             arg_index = None
-            for i, arg in enumerate(op_node.args):
+            for i, arg in enumerate(input_source_node.args):
                 if arg is input_node:
                     arg_index = i
                     break
             self.assertIsNotNone(
                 arg_index,
-                f"Input node {input_node} not found in op_node.args",
+                f"Input node {input_node} not found in input_source_node.args",
             )
             # Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec)
             if expected_input_qspecs[arg_index] is not None:
```
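
The (gm, output_node, input_source_node) convention is not conv+relu specific. As a hypothetical sketch (not part of this commit), a builder for another fused pattern such as linear + relu would follow the same shape, returning the final node of the fused unit plus the node whose args carry the quantized inputs:

```
# Hypothetical sketch, not in this commit: a builder for a fused
# linear + relu pattern following the same return convention. All names
# are illustrative; only the GraphBuilder/NodeMetadata usage mirrors
# the conv2d+relu builder above.
def _build_linear_relu_graph(self):
    builder = GraphBuilder()
    x = builder.placeholder("x", torch.randn(1, 16))
    weight = builder.placeholder("weight", torch.randn(8, 16))
    linear = builder.call_operator(
        op=torch.ops.aten.linear.default,
        args=(x, weight),
        meta=NodeMetadata(
            {"source_fn_stack": [("linear", torch.ops.aten.linear.default)]}
        ),
    )
    relu = builder.call_operator(
        op=torch.ops.aten.relu.default,
        args=(linear,),
        meta=NodeMetadata(
            {"source_fn_stack": [("relu", torch.ops.aten.relu.default)]}
        ),
    )
    builder.output([relu])
    gm = builder.get_graph_module()
    relu_node = gm.graph.find_nodes(
        op="call_function", target=torch.ops.aten.relu.default
    )[0]
    linear_node = gm.graph.find_nodes(
        op="call_function", target=torch.ops.aten.linear.default
    )[0]
    # Output annotation is checked on relu_node; input qspecs on linear_node.
    return gm, relu_node, linear_node
```

Because the test unpacks 2- and 3-element returns uniformly, such a builder would slot into QUANTIZER_ANNOTATION_TEST_CASES without further test changes.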
