@@ -57,5 +57,4 @@
     CadenceRmsNormNopQuantizer,  # No-op quantizer, doesn't annotate anything, preserves rms_norm from decomposition
     CadenceWakeWordQuantizer,  # TODO: T247438162 Add test coverage
     CadenceWithLayerNormQuantizer,  # TODO: T247438410 Add test coverage
-    CadenceWithSoftmaxQuantizer,  # TODO: T247438418 Add test coverage
 }
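Removing `CadenceWithSoftmaxQuantizer` from this set of not-yet-tested quantizers pairs with the new test case added below. Judging by the "Derive the set of tested quantizer classes from the test cases" comment in the next hunk, the suite presumably cross-checks this set against the test cases. A minimal sketch of such a check, where `UNTESTED_QUANTIZERS` is an assumed name for the set edited above, not necessarily the file's:

```python
# Sketch only: each test-case tuple holds a quantizer instance at index 2,
# so type() recovers the class under test.
tested = {type(case[2]) for case in QUANTIZER_ANNOTATION_TEST_CASES}
# UNTESTED_QUANTIZERS is an assumed name for the set edited in the hunk above.
assert tested.isdisjoint(UNTESTED_QUANTIZERS)
```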
@@ -110,6 +109,15 @@
         # For conv2d: [input_activation, weight]
         [qconfig_A16.input_activation, qconfig_A16.weight],
     ),
+    (
+        "softmax_A16",
+        lambda self: self._build_softmax_graph(),
+        CadenceWithSoftmaxQuantizer(),
+        torch.ops.aten._softmax.default,
+        qconfig_A16.output_activation,
+        # For softmax: only input_activation
+        [qconfig_A16.input_activation],
+    ),
 ]

 # Derive the set of tested quantizer classes from the test cases.
@@ -214,6 +222,27 @@ def _build_conv2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(conv2d_nodes), 1, "Should find exactly one conv2d node")
         return gm, conv2d_nodes[0]

+    def _build_softmax_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
+        """Build a simple graph with a softmax operation."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 10))
+        softmax = builder.call_operator(
+            op=torch.ops.aten._softmax.default,
+            args=(x, -1, False),  # dim=-1, half_to_float=False
+            meta=NodeMetadata(
+                {"source_fn_stack": [("softmax", torch.ops.aten._softmax.default)]}
+            ),
+        )
+        builder.output([softmax])
+        gm = builder.get_graph_module()
+
+        softmax_nodes = gm.graph.find_nodes(
+            op="call_function",
+            target=torch.ops.aten._softmax.default,
+        )
+        self.assertEqual(len(softmax_nodes), 1, "Should find exactly one softmax node")
+        return gm, softmax_nodes[0]
+
     @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
     def test_quantizer_annotation(
         self,
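The diff cuts off inside `test_quantizer_annotation`'s signature. For orientation, here is a minimal sketch of how such a parameterized test plausibly consumes each tuple, assuming the Cadence quantizers follow the pt2e convention of attaching a `QuantizationAnnotation` (with `output_qspec` and `input_qspec_map` fields) to `node.meta["quantization_annotation"]`; this body is illustrative, not the file's actual implementation:

```python
@parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
def test_quantizer_annotation(
    self, name, build_graph, quantizer, target_op, output_qspec, input_qspecs
):
    gm, node = build_graph(self)  # e.g. self._build_softmax_graph()
    quantizer.annotate(gm)  # pt2e quantizers annotate node.meta in place
    annotation = node.meta["quantization_annotation"]
    self.assertEqual(node.target, target_op)
    self.assertEqual(annotation.output_qspec, output_qspec)
    # Only torch.fx.Node args land in input_qspec_map; scalar args such as
    # softmax's dim=-1 and half_to_float=False do not.
    node_args = [a for a in node.args if isinstance(a, torch.fx.Node)]
    self.assertEqual(
        [annotation.input_qspec_map[a] for a in node_args], input_qspecs
    )
```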