
Commit 74805a1

RahulC7 authored and meta-codesync[bot] committed

Adding Tests for CadenceDefaultQuantizer

Differential Revision: D88899457

1 parent: fd9071e

File tree: 1 file changed, +100 -1 lines

backends/cadence/aot/tests/test_quantizer_ops.py (100 additions, 1 deletion)
@@ -30,6 +30,7 @@
     CadenceWithSoftmaxQuantizer,
     qconfig_A16,
     qconfig_A8W8,
+    qconfig_A8W8sym,
 )
 from executorch.exir.pass_base import NodeMetadata
 from parameterized import parameterized
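
The newly imported qconfig_A8W8sym presumably pairs 8-bit activations with symmetrically quantized 8-bit weights, matching its use in the conv1d/conv2d cases below. For orientation only, a per-tensor symmetric int8 weight spec in the torch.ao quantizer API looks roughly like the following; the observer choice and quantization range here are assumptions, not the actual Cadence definition, which lives in the quantizer module outside this diff:

import torch
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.quantizer import QuantizationSpec

# Assumed shape of a symmetric int8 weight spec; illustrative only,
# not the real qconfig_A8W8sym definition.
sym_weight_spec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    observer_or_fake_quant_ctr=MinMaxObserver,
)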
@@ -50,7 +51,6 @@
 # Quantizers intentionally excluded from annotation testing.
 # These should be explicitly justified when added.
 EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
-    CadenceDefaultQuantizer,  # TODO: T247438143 Add test coverage
     CadenceFusedConvReluQuantizer,  # TODO: T247438151 Add test coverage
     CadenceNopQuantizer,  # No-op quantizer, doesn't annotate anything
     CadenceW8A32MixedQuantizer,  # TODO: T247438158 Add test coverage
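
Removing CadenceDefaultQuantizer from this set is the point of the commit: the file appears to enforce that every quantizer class is either covered by a test case or explicitly excluded (see the "Derive the set of tested quantizer classes" line in the next hunk). A self-contained toy of that invariant, with hypothetical class names:

# Toy illustration of the exclude-or-test invariant (hypothetical names,
# not the actual Cadence quantizer classes).
class BaseQuantizer: ...
class TestedQuantizer(BaseQuantizer): ...
class ExcludedQuantizer(BaseQuantizer): ...

TESTED: set[type[BaseQuantizer]] = {TestedQuantizer}
EXCLUDED: set[type[BaseQuantizer]] = {ExcludedQuantizer}

ALL = {TestedQuantizer, ExcludedQuantizer}
# Every quantizer must be tested or explicitly excluded, never both.
assert ALL == TESTED | EXCLUDED
assert not (TESTED & EXCLUDED)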
@@ -134,6 +134,61 @@
         # For add: both inputs are activations
         [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
     ),
+    # CadenceDefaultQuantizer test cases
+    (
+        "default_matmul_A8W8",
+        lambda self: self._build_matmul_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.matmul.default,
+        qconfig_A8W8.output_activation,
+        # For matmul: both inputs are activations
+        [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
+    ),
+    (
+        "default_linear_A8W8",
+        lambda self: self._build_linear_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.linear.default,
+        qconfig_A8W8.output_activation,
+        # For linear: [input_activation, weight]
+        [qconfig_A8W8.input_activation, qconfig_A8W8.weight],
+    ),
+    (
+        "default_conv1d_A8W8sym",
+        lambda self: self._build_conv1d_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.conv1d.default,
+        qconfig_A8W8sym.output_activation,
+        # For conv1d: [input_activation, weight] with symmetric weights
+        [qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight],
+    ),
+    (
+        "default_conv2d_A8W8sym",
+        lambda self: self._build_conv2d_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.conv2d.default,
+        qconfig_A8W8sym.output_activation,
+        # For conv2d: [input_activation, weight] with symmetric weights
+        [qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight],
+    ),
+    (
+        "default_bmm_A8W8",
+        lambda self: self._build_bmm_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.bmm.default,
+        qconfig_A8W8.output_activation,
+        # For bmm: both inputs are activations
+        [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
+    ),
+    (
+        "default_relu_A8W8",
+        lambda self: self._build_relu_graph(),
+        CadenceDefaultQuantizer(),
+        torch.ops.aten.relu.default,
+        qconfig_A8W8.output_activation,
+        # For relu: only input_activation
+        [qconfig_A8W8.input_activation],
+    ),
 ]

 # Derive the set of tested quantizer classes from the test cases.
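
Each case tuple follows the same visible shape: a name for parameterized.expand, a graph-builder callback, the quantizer instance under test, the target aten op, the expected output-activation spec, and the expected input specs in argument order. As a self-contained reminder of how parameterized.expand consumes such tuples (toy data, not the real quantizer cases: the first string element becomes the test-name suffix, and every element is passed as a positional argument):

import unittest

from parameterized import parameterized


class ExpandExample(unittest.TestCase):
    @parameterized.expand([
        ("doubles", 2, 4),
        ("triples", 3, 9),
    ])
    def test_square(self, _name: str, base: int, expected: int) -> None:
        # Generates test_square_0_doubles and test_square_1_triples.
        self.assertEqual(base * base, expected)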
@@ -307,6 +362,50 @@ def _build_add_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(add_nodes), 1, "Should find exactly one add node")
         return gm, add_nodes[0]

+    def _build_bmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with a bmm (batch matrix multiply) operation."""
+        builder = GraphBuilder()
+        # BMM requires 3D tensors: (batch, n, m) @ (batch, m, p) -> (batch, n, p)
+        x = builder.placeholder("x", torch.randn(2, 4, 8))
+        y = builder.placeholder("y", torch.randn(2, 8, 4))
+        bmm = builder.call_operator(
+            op=torch.ops.aten.bmm.default,
+            args=(x, y),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("bmm", torch.ops.aten.bmm.default)]}
+            ),
+        )
+        builder.output([bmm])
+        gm = builder.get_graph_module()
+
+        bmm_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.bmm.default,
+        )
+        self.assertEqual(len(bmm_nodes), 1, "Should find exactly one bmm node")
+        return gm, bmm_nodes[0]
+
+    def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with a relu operation."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 10))
+        relu = builder.call_operator(
+            op=torch.ops.aten.relu.default,
+            args=(x,),
+            meta=NodeMetadata(
+                {"source_fn_stack": [("relu", torch.ops.aten.relu.default)]}
+            ),
+        )
+        builder.output([relu])
+        gm = builder.get_graph_module()
+
+        relu_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten.relu.default,
+        )
+        self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
+        return gm, relu_nodes[0]
+
     @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
     def test_quantizer_annotation(
         self,
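
For orientation, here is a rough standalone sketch of the flow these cases exercise: export a tiny module, let CadenceDefaultQuantizer annotate the graph, and read back the specs it attached. This is an illustration built on assumptions, not code from the commit: the import path is inferred from the repository layout, "quantization_annotation" is the standard torch.ao quantizer metadata key rather than anything this diff shows, and a real PT2E flow would normally drive annotate() via prepare_pt2e.

import torch

# Assumed import path, inferred from backends/cadence/aot/ in this repo.
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceDefaultQuantizer,
)


class TinyRelu(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)


gm = torch.export.export(TinyRelu(), (torch.randn(1, 10),)).module()
CadenceDefaultQuantizer().annotate(gm)

# Annotated nodes carry the standard torch.ao metadata key
# (assumption: the Cadence quantizer follows this convention).
for node in gm.graph.nodes:
    if "quantization_annotation" in node.meta:
        print(node.target, node.meta["quantization_annotation"])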
