6161# Test case definitions for quantizer annotation tests.
6262# Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
6363# Adding a new quantizer test only requires adding a tuple to this list.
64+ # Note: Use None in expected_input_qspecs to skip comparison for that input (e.g., for DerivedQuantizationSpec).
6465QUANTIZER_ANNOTATION_TEST_CASES : list [
6566 tuple [
6667 str ,
6768 GraphBuilderFn ,
6869 CadenceQuantizer ,
6970 OpOverload ,
7071 QuantizationSpec ,
71- list [QuantizationSpec ],
72+ list [QuantizationSpec | None ],
7273 ]
7374] = [
7475 (
189190 # For relu: only input_activation
190191 [qconfig_A8W8 .input_activation ],
191192 ),
193+ (
194+ "default_addmm_A8W8" ,
195+ lambda self : self ._build_addmm_graph (),
196+ CadenceDefaultQuantizer (),
197+ torch .ops .aten .addmm .default ,
198+ qconfig_A8W8 .output_activation ,
199+ # For addmm: [bias (DerivedQuantizationSpec), mat1, mat2]
200+ # Use None to skip comparison for bias since it's a DerivedQuantizationSpec
201+ [None , qconfig_A8W8 .input_activation , qconfig_A8W8 .weight ],
202+ ),
192203]
193204
194205# Derive the set of tested quantizer classes from the test cases.
@@ -405,6 +416,31 @@ def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
405416 self .assertEqual (len (relu_nodes ), 1 , "Should find exactly one relu node" )
406417 return gm , relu_nodes [0 ]
407418
419+ def _build_addmm_graph (self ) -> tuple [torch .fx .GraphModule , torch .fx .Node ]:
420+ """Build a simple graph with an addmm operation."""
421+ builder = GraphBuilder ()
422+ # addmm: bias + (mat1 @ mat2)
423+ # args: (bias, mat1, mat2)
424+ bias = builder .placeholder ("bias" , torch .randn (5 ))
425+ mat1 = builder .placeholder ("mat1" , torch .randn (1 , 10 ))
426+ mat2 = builder .placeholder ("mat2" , torch .randn (10 , 5 ))
427+ addmm = builder .call_operator (
428+ op = torch .ops .aten .addmm .default ,
429+ args = (bias , mat1 , mat2 ),
430+ meta = NodeMetadata (
431+ {"source_fn_stack" : [("addmm" , torch .ops .aten .addmm .default )]}
432+ ),
433+ )
434+ builder .output ([addmm ])
435+ gm = builder .get_graph_module ()
436+
437+ addmm_nodes = gm .graph .find_nodes (
438+ op = "call_function" ,
439+ target = torch .ops .aten .addmm .default ,
440+ )
441+ self .assertEqual (len (addmm_nodes ), 1 , "Should find exactly one addmm node" )
442+ return gm , addmm_nodes [0 ]
443+
408444 @parameterized .expand (QUANTIZER_ANNOTATION_TEST_CASES )
409445 def test_quantizer_annotation (
410446 self ,
@@ -413,7 +449,7 @@ def test_quantizer_annotation(
413449 quantizer : CadenceQuantizer ,
414450 target : OpOverload ,
415451 expected_output_qspec : QuantizationSpec ,
416- expected_input_qspecs : list [QuantizationSpec ],
452+ expected_input_qspecs : list [QuantizationSpec | None ],
417453 ) -> None :
418454 """Parameterized test for quantizer annotations."""
419455 gm , op_node = graph_builder_fn (self )
@@ -428,21 +464,24 @@ def test_quantizer_annotation(
428464
429465 # Verify input annotations
430466 self .assertEqual (len (annotation .input_qspec_map ), len (expected_input_qspecs ))
431- for i , (input_node , input_qspec ) in enumerate (
432- annotation .input_qspec_map .items ()
433- ):
434- expected_arg = op_node .args [i ]
435- assert isinstance (expected_arg , torch .fx .Node )
436- self .assertEqual (
437- input_node ,
438- expected_arg ,
439- f"Input node mismatch at index { i } " ,
440- )
441- self .assertEqual (
442- input_qspec ,
443- expected_input_qspecs [i ],
444- f"Input qspec mismatch at index { i } " ,
467+ for input_node , input_qspec in annotation .input_qspec_map .items ():
468+ # Find the index of this input node in the op's args
469+ arg_index = None
470+ for i , arg in enumerate (op_node .args ):
471+ if arg is input_node :
472+ arg_index = i
473+ break
474+ self .assertIsNotNone (
475+ arg_index ,
476+ f"Input node { input_node } not found in op_node.args" ,
445477 )
478+ # Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec)
479+ if expected_input_qspecs [arg_index ] is not None :
480+ self .assertEqual (
481+ input_qspec ,
482+ expected_input_qspecs [arg_index ],
483+ f"Input qspec mismatch at arg index { arg_index } " ,
484+ )
446485
447486 def test_all_quantizers_have_annotation_tests (self ) -> None :
448487 """Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""
0 commit comments