6161# Test case definitions for quantizer annotation tests.
6262# Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
6363# Adding a new quantizer test only requires adding a tuple to this list.
64+ # Note: Use None in expected_input_qspecs to skip comparison for that input (e.g., for DerivedQuantizationSpec).
6465QUANTIZER_ANNOTATION_TEST_CASES : list [
6566 tuple [
6667 str ,
6768 GraphBuilderFn ,
6869 CadenceQuantizer ,
6970 OpOverload ,
7071 QuantizationSpec ,
71- list [QuantizationSpec ],
72+ list [QuantizationSpec | None ],
7273 ]
7374] = [
7475 (
189190 # For relu: only input_activation
190191 [qconfig_A8W8 .input_activation ],
191192 ),
193+ (
194+ "default_addmm_A8W8" ,
195+ lambda self : self ._build_addmm_graph (),
196+ CadenceDefaultQuantizer (),
197+ torch .ops .aten .addmm .default ,
198+ qconfig_A8W8 .output_activation ,
199+ # For addmm: [bias (DerivedQuantizationSpec), mat1, mat2]
200+ # Use None to skip comparison for bias since it's a DerivedQuantizationSpec
201+ [None , qconfig_A8W8 .input_activation , qconfig_A8W8 .weight ],
202+ ),
192203]
193204
194205# Derive the set of tested quantizer classes from the test cases.
@@ -406,6 +417,31 @@ def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
406417 self .assertEqual (len (relu_nodes ), 1 , "Should find exactly one relu node" )
407418 return gm , relu_nodes [0 ]
408419
420+ def _build_addmm_graph (self ) -> tuple [torch .fx .GraphModule , torch .fx .Node ]:
421+ """Build a simple graph with an addmm operation."""
422+ builder = GraphBuilder ()
423+ # addmm: bias + (mat1 @ mat2)
424+ # args: (bias, mat1, mat2)
425+ bias = builder .placeholder ("bias" , torch .randn (5 ))
426+ mat1 = builder .placeholder ("mat1" , torch .randn (1 , 10 ))
427+ mat2 = builder .placeholder ("mat2" , torch .randn (10 , 5 ))
428+ addmm = builder .call_operator (
429+ op = torch .ops .aten .addmm .default ,
430+ args = (bias , mat1 , mat2 ),
431+ meta = NodeMetadata (
432+ {"source_fn_stack" : [("addmm" , torch .ops .aten .addmm .default )]}
433+ ),
434+ )
435+ builder .output ([addmm ])
436+ gm = builder .get_graph_module ()
437+
438+ addmm_nodes = gm .graph .find_nodes (
439+ op = "call_function" ,
440+ target = torch .ops .aten .addmm .default ,
441+ )
442+ self .assertEqual (len (addmm_nodes ), 1 , "Should find exactly one addmm node" )
443+ return gm , addmm_nodes [0 ]
444+
409445 @parameterized .expand (QUANTIZER_ANNOTATION_TEST_CASES )
410446 def test_quantizer_annotation (
411447 self ,
@@ -414,7 +450,7 @@ def test_quantizer_annotation(
414450 quantizer : CadenceQuantizer ,
415451 target : OpOverload ,
416452 expected_output_qspec : QuantizationSpec ,
417- expected_input_qspecs : list [QuantizationSpec ],
453+ expected_input_qspecs : list [QuantizationSpec | None ],
418454 ) -> None :
419455 """Parameterized test for quantizer annotations."""
420456 gm , op_node = graph_builder_fn (self )
@@ -429,21 +465,24 @@ def test_quantizer_annotation(
429465
430466 # Verify input annotations
431467 self .assertEqual (len (annotation .input_qspec_map ), len (expected_input_qspecs ))
432- for i , (input_node , input_qspec ) in enumerate (
433- annotation .input_qspec_map .items ()
434- ):
435- expected_arg = op_node .args [i ]
436- assert isinstance (expected_arg , torch .fx .Node )
437- self .assertEqual (
438- input_node ,
439- expected_arg ,
440- f"Input node mismatch at index { i } " ,
441- )
442- self .assertEqual (
443- input_qspec ,
444- expected_input_qspecs [i ],
445- f"Input qspec mismatch at index { i } " ,
468+ for input_node , input_qspec in annotation .input_qspec_map .items ():
469+ # Find the index of this input node in the op's args
470+ arg_index = None
471+ for i , arg in enumerate (op_node .args ):
472+ if arg is input_node :
473+ arg_index = i
474+ break
475+ self .assertIsNotNone (
476+ arg_index ,
477+ f"Input node { input_node } not found in op_node.args" ,
446478 )
479+ # Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec)
480+ if expected_input_qspecs [arg_index ] is not None :
481+ self .assertEqual (
482+ input_qspec ,
483+ expected_input_qspecs [arg_index ],
484+ f"Input qspec mismatch at arg index { arg_index } " ,
485+ )
447486
448487 def test_all_quantizers_have_annotation_tests (self ) -> None :
449488 """Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""
0 commit comments