|
30 | 30 | CadenceWithSoftmaxQuantizer, |
31 | 31 | qconfig_A16, |
32 | 32 | qconfig_A8W8, |
| 33 | + qconfig_A8W8sym, |
33 | 34 | ) |
34 | 35 | from executorch.exir.pass_base import NodeMetadata |
35 | 36 | from parameterized import parameterized |
|
50 | 51 | # Quantizers intentionally excluded from annotation testing. |
51 | 52 | # These should be explicitly justified when added. |
52 | 53 | EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = { |
53 | | - CadenceDefaultQuantizer, # TODO: T247438143 Add test coverage |
54 | 54 | CadenceFusedConvReluQuantizer, # TODO: T247438151 Add test coverage |
55 | 55 | CadenceNopQuantizer, # No-op quantizer, doesn't annotate anything |
56 | 56 | CadenceW8A32MixedQuantizer, # TODO: T247438158 Add test coverage |
|
134 | 134 | # For add: both inputs are activations |
135 | 135 | [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation], |
136 | 136 | ), |
| 137 | + # CadenceDefaultQuantizer test cases |
| 138 | + ( |
| 139 | + "default_matmul_A8W8", |
| 140 | + lambda self: self._build_matmul_graph(), |
| 141 | + CadenceDefaultQuantizer(), |
| 142 | + torch.ops.aten.matmul.default, |
| 143 | + qconfig_A8W8.output_activation, |
| 144 | + # For matmul: both inputs are activations |
| 145 | + [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation], |
| 146 | + ), |
| 147 | + ( |
| 148 | + "default_linear_A8W8", |
| 149 | + lambda self: self._build_linear_graph(), |
| 150 | + CadenceDefaultQuantizer(), |
| 151 | + torch.ops.aten.linear.default, |
| 152 | + qconfig_A8W8.output_activation, |
| 153 | + # For linear: [input_activation, weight] |
| 154 | + [qconfig_A8W8.input_activation, qconfig_A8W8.weight], |
| 155 | + ), |
| 156 | + ( |
| 157 | + "default_conv1d_A8W8sym", |
| 158 | + lambda self: self._build_conv1d_graph(), |
| 159 | + CadenceDefaultQuantizer(), |
| 160 | + torch.ops.aten.conv1d.default, |
| 161 | + qconfig_A8W8sym.output_activation, |
| 162 | + # For conv1d: [input_activation, weight] with symmetric weights |
| 163 | + [qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight], |
| 164 | + ), |
| 165 | + ( |
| 166 | + "default_conv2d_A8W8sym", |
| 167 | + lambda self: self._build_conv2d_graph(), |
| 168 | + CadenceDefaultQuantizer(), |
| 169 | + torch.ops.aten.conv2d.default, |
| 170 | + qconfig_A8W8sym.output_activation, |
| 171 | + # For conv2d: [input_activation, weight] with symmetric weights |
| 172 | + [qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight], |
| 173 | + ), |
| 174 | + ( |
| 175 | + "default_bmm_A8W8", |
| 176 | + lambda self: self._build_bmm_graph(), |
| 177 | + CadenceDefaultQuantizer(), |
| 178 | + torch.ops.aten.bmm.default, |
| 179 | + qconfig_A8W8.output_activation, |
| 180 | + # For bmm: both inputs are activations |
| 181 | + [qconfig_A8W8.input_activation, qconfig_A8W8.input_activation], |
| 182 | + ), |
| 183 | + ( |
| 184 | + "default_relu_A8W8", |
| 185 | + lambda self: self._build_relu_graph(), |
| 186 | + CadenceDefaultQuantizer(), |
| 187 | + torch.ops.aten.relu.default, |
| 188 | + qconfig_A8W8.output_activation, |
| 189 | + # For relu: only input_activation |
| 190 | + [qconfig_A8W8.input_activation], |
| 191 | + ), |
137 | 192 | ] |
138 | 193 |
|
139 | 194 | # Derive the set of tested quantizer classes from the test cases. |
@@ -307,6 +362,50 @@ def _build_add_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]: |
307 | 362 | self.assertEqual(len(add_nodes), 1, "Should find exactly one add node") |
308 | 363 | return gm, add_nodes[0] |
309 | 364 |
|
| 365 | + def _build_bmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]: |
| 366 | + """Build a simple graph with a bmm (batch matrix multiply) operation.""" |
| 367 | + builder = GraphBuilder() |
| 368 | + # BMM requires 3D tensors: (batch, n, m) @ (batch, m, p) -> (batch, n, p) |
| 369 | + x = builder.placeholder("x", torch.randn(2, 4, 8)) |
| 370 | + y = builder.placeholder("y", torch.randn(2, 8, 4)) |
| 371 | + bmm = builder.call_operator( |
| 372 | + op=torch.ops.aten.bmm.default, |
| 373 | + args=(x, y), |
| 374 | + meta=NodeMetadata( |
| 375 | + {"source_fn_stack": [("bmm", torch.ops.aten.bmm.default)]} |
| 376 | + ), |
| 377 | + ) |
| 378 | + builder.output([bmm]) |
| 379 | + gm = builder.get_graph_module() |
| 380 | + |
| 381 | + bmm_nodes = gm.graph.find_nodes( |
| 382 | + op="call_function", |
| 383 | + target=torch.ops.aten.bmm.default, |
| 384 | + ) |
| 385 | + self.assertEqual(len(bmm_nodes), 1, "Should find exactly one bmm node") |
| 386 | + return gm, bmm_nodes[0] |
| 387 | + |
| 388 | + def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]: |
| 389 | + """Build a simple graph with a relu operation.""" |
| 390 | + builder = GraphBuilder() |
| 391 | + x = builder.placeholder("x", torch.randn(1, 10)) |
| 392 | + relu = builder.call_operator( |
| 393 | + op=torch.ops.aten.relu.default, |
| 394 | + args=(x,), |
| 395 | + meta=NodeMetadata( |
| 396 | + {"source_fn_stack": [("relu", torch.ops.aten.relu.default)]} |
| 397 | + ), |
| 398 | + ) |
| 399 | + builder.output([relu]) |
| 400 | + gm = builder.get_graph_module() |
| 401 | + |
| 402 | + relu_nodes = gm.graph.find_nodes( |
| 403 | + op="call_function", |
| 404 | + target=torch.ops.aten.relu.default, |
| 405 | + ) |
| 406 | + self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node") |
| 407 | + return gm, relu_nodes[0] |
| 408 | + |
310 | 409 | @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES) |
311 | 410 | def test_quantizer_annotation( |
312 | 411 | self, |
|
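For orientation, below is a rough sketch of how test_quantizer_annotation presumably consumes each tuple in QUANTIZER_ANNOTATION_TEST_CASES: (test name, graph builder, quantizer, target op, expected output qspec, expected input qspecs). The test body is not part of this diff, so the helper name check_annotation and the reliance on node.meta["quantization_annotation"] with output_qspec / input_qspec_map are assumptions based on the common torchao PT2E annotation convention, not a description of the file's actual implementation.

    # Hypothetical sketch, not part of the diff: what the parameterized test is
    # assumed to do with each test-case tuple.
    import torch

    def check_annotation(test_case, build_graph, quantizer, target_op,
                         expected_output, expected_inputs):
        # Build the single-op graph and locate the node under test.
        gm, node = build_graph(test_case)
        test_case.assertEqual(node.target, target_op)
        # Run the quantizer's annotation pass over the graph module.
        quantizer.annotate(gm)
        # PT2E-style quantizers record their decisions in the node's metadata
        # (assumption: the Cadence quantizers follow this convention).
        annotation = node.meta["quantization_annotation"]
        test_case.assertEqual(annotation.output_qspec, expected_output)
        # Collect the qspecs assigned to each fx.Node argument, in order.
        actual_inputs = [
            annotation.input_qspec_map[arg]
            for arg in node.args
            if isinstance(arg, torch.fx.Node)
        ]
        test_case.assertEqual(actual_inputs, expected_inputs)

With that shape in mind, the new cases simply pair CadenceDefaultQuantizer with each graph builder, using qconfig_A8W8 for matmul/linear/bmm/relu and qconfig_A8W8sym for the convolutions.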