Skip to content

Commit 3759295

Browse files
RahulC7 authored and facebook-github-bot committed
Adding test for CadenceWithSoftmaxQuantizer (#16206)
Summary: Add annotation tests for CadenceWithSoftmaxQuantizer. See: https://www.internalfb.com/code/fbsource/[01c566b03c670b1869136cbb64f25d16d730c8d4]/fbcode/executorch/backends/cadence/aot/quantizer/quantizer.py?lines=360-369. Reviewed By: hsharma35. Differential Revision: D88896712.
1 parent d89c44a commit 3759295

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 30 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,6 @@
5757
CadenceRmsNormNopQuantizer, # No-op quantizer, doesn't annotate anything, preserves rms_norm from decomposition
5858
CadenceWakeWordQuantizer, # TODO: T247438162 Add test coverage
5959
CadenceWithLayerNormQuantizer, # TODO: T247438410 Add test coverage
60-
CadenceWithSoftmaxQuantizer, # TODO: T247438418 Add test coverage
6160
}
6261

6362

@@ -110,6 +109,15 @@
110109
# For conv2d: [input_activation, weight]
111110
[qconfig_A16.input_activation, qconfig_A16.weight],
112111
),
112+
(
113+
"softmax_A16",
114+
lambda self: self._build_softmax_graph(),
115+
CadenceWithSoftmaxQuantizer(),
116+
torch.ops.aten._softmax.default,
117+
qconfig_A16.output_activation,
118+
# For softmax: only input_activation
119+
[qconfig_A16.input_activation],
120+
),
113121
]
114122

115123
# Derive the set of tested quantizer classes from the test cases.
@@ -214,6 +222,27 @@ def _build_conv2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
214222
self.assertEqual(len(conv2d_nodes), 1, "Should find exactly one conv2d node")
215223
return gm, conv2d_nodes[0]
216224

225+
def _build_softmax_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
226+
"""Build a simple graph with a softmax operation."""
227+
builder = GraphBuilder()
228+
x = builder.placeholder("x", torch.randn(1, 10))
229+
softmax = builder.call_operator(
230+
op=torch.ops.aten._softmax.default,
231+
args=(x, -1, False), # dim=-1, half_to_float=False
232+
meta=NodeMetadata(
233+
{"source_fn_stack": [("softmax", torch.ops.aten._softmax.default)]}
234+
),
235+
)
236+
builder.output([softmax])
237+
gm = builder.get_graph_module()
238+
239+
softmax_nodes = gm.graph.find_nodes(
240+
op="call_function",
241+
target=torch.ops.aten._softmax.default,
242+
)
243+
self.assertEqual(len(softmax_nodes), 1, "Should find exactly one softmax node")
244+
return gm, softmax_nodes[0]
245+
217246
@parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
218247
def test_quantizer_annotation(
219248
self,

0 commit comments

Comments
 (0)