backends/cadence/aot/tests/test_quantizer_ops.py (338 changes: 317 additions, 21 deletions)
@@ -30,6 +30,7 @@
CadenceWithSoftmaxQuantizer,
qconfig_A16,
qconfig_A8W8,
qconfig_A8W8sym,
)
from executorch.exir.pass_base import NodeMetadata
from parameterized import parameterized
@@ -50,29 +51,25 @@
# Quantizers intentionally excluded from annotation testing.
# These should be explicitly justified when added.
EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
CadenceDefaultQuantizer, # TODO: T247438143 Add test coverage
CadenceFusedConvReluQuantizer, # TODO: T247438151 Add test coverage
CadenceNopQuantizer, # No-op quantizer, doesn't annotate anything
CadenceW8A32MixedQuantizer, # TODO: T247438158 Add test coverage
CadenceRmsNormNopQuantizer, # No-op quantizer; doesn't annotate anything and keeps rms_norm from being decomposed
CadenceWakeWordQuantizer, # TODO: T247438162 Add test coverage
CadenceWith16BitConvActivationsQuantizer, # TODO: T247438221 Add test coverage
CadenceWithLayerNormQuantizer, # TODO: T247438410 Add test coverage
CadenceWithSoftmaxQuantizer, # TODO: T247438418 Add test coverage
}


# Test case definitions for quantizer annotation tests.
# Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
# Adding a new quantizer test only requires adding a tuple to this list.
# Note: Use None in expected_input_qspecs to skip comparison for that input (e.g., for DerivedQuantizationSpec).
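# A hypothetical entry, for illustration only (the quantizer, builder, and
# op names below do not exist in this file):
#     (
#         "my_op_A8W8",
#         lambda self: self._build_my_op_graph(),
#         CadenceMyOpQuantizer(),
#         torch.ops.aten.my_op.default,
#         qconfig_A8W8.output_activation,
#         [qconfig_A8W8.input_activation],
#     ),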
QUANTIZER_ANNOTATION_TEST_CASES: list[
tuple[
str,
GraphBuilderFn,
CadenceQuantizer,
OpOverload,
QuantizationSpec,
list[QuantizationSpec],
list[QuantizationSpec | None],
]
] = [
(
Expand All @@ -93,6 +90,116 @@
# For linear: [input_activation, weight]
[qconfig_A16.input_activation, qconfig_A16.weight],
),
(
"conv1d_A16",
lambda self: self._build_conv1d_graph(),
CadenceWith16BitConvActivationsQuantizer(),
torch.ops.aten.conv1d.default,
qconfig_A16.output_activation,
# For conv1d: [input_activation, weight]
[qconfig_A16.input_activation, qconfig_A16.weight],
),
(
"conv2d_A16",
lambda self: self._build_conv2d_graph(),
CadenceWith16BitConvActivationsQuantizer(),
torch.ops.aten.conv2d.default,
qconfig_A16.output_activation,
# For conv2d: [input_activation, weight]
[qconfig_A16.input_activation, qconfig_A16.weight],
),
(
"softmax_A16",
lambda self: self._build_softmax_graph(),
CadenceWithSoftmaxQuantizer(),
torch.ops.aten._softmax.default,
qconfig_A16.output_activation,
# For softmax: only input_activation
[qconfig_A16.input_activation],
),
(
"layer_norm_A8W8",
lambda self: self._build_layer_norm_graph(),
CadenceWithLayerNormQuantizer(),
torch.ops.aten.layer_norm.default,
qconfig_A8W8.output_activation,
# For layer_norm: only input_activation (weights/bias are passed as others)
[qconfig_A8W8.input_activation],
),
(
"add_A8W8",
lambda self: self._build_add_graph(),
CadenceWakeWordQuantizer(),
torch.ops.aten.add.Tensor,
qconfig_A8W8.output_activation,
# For add: both inputs are activations
[qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
),
# CadenceDefaultQuantizer test cases
(
"default_matmul_A8W8",
lambda self: self._build_matmul_graph(),
CadenceDefaultQuantizer(),
torch.ops.aten.matmul.default,
qconfig_A8W8.output_activation,
# For matmul: both inputs are activations
[qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
),
(
"default_linear_A8W8",
lambda self: self._build_linear_graph(),
CadenceDefaultQuantizer(),
torch.ops.aten.linear.default,
qconfig_A8W8.output_activation,
# For linear: [input_activation, weight]
[qconfig_A8W8.input_activation, qconfig_A8W8.weight],
),
(
"default_conv1d_A8W8sym",
lambda self: self._build_conv1d_graph(),
CadenceDefaultQuantizer(),
torch.ops.aten.conv1d.default,
qconfig_A8W8sym.output_activation,
# For conv1d: [input_activation, weight] with symmetric weights
[qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight],
),
(
"default_conv2d_A8W8sym",
lambda self: self._build_conv2d_graph(),
CadenceDefaultQuantizer(),
torch.ops.aten.conv2d.default,
qconfig_A8W8sym.output_activation,
# For conv2d: [input_activation, weight] with symmetric weights
[qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight],
),
(
"default_bmm_A8W8",
lambda self: self._build_bmm_graph(),
CadenceDefaultQuantizer(),
torch.ops.aten.bmm.default,
qconfig_A8W8.output_activation,
# For bmm: both inputs are activations
[qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
),
(
"default_relu_A8W8",
lambda self: self._build_relu_graph(),
CadenceDefaultQuantizer(),
torch.ops.aten.relu.default,
qconfig_A8W8.output_activation,
# For relu: only input_activation
[qconfig_A8W8.input_activation],
),
(
"default_addmm_A8W8",
lambda self: self._build_addmm_graph(),
CadenceDefaultQuantizer(),
torch.ops.aten.addmm.default,
qconfig_A8W8.output_activation,
# For addmm: [bias (DerivedQuantizationSpec), mat1, mat2]
# Use None to skip comparison for bias: its qparams are a
# DerivedQuantizationSpec, derived from the qparams of the other inputs
# (typically the activation and weight scales), so there is no static
# qspec to compare against.
[None, qconfig_A8W8.input_activation, qconfig_A8W8.weight],
),
]

# Derive the set of tested quantizer classes from the test cases.
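# A minimal sketch of that derivation, assuming the tuple layout documented
# above (the actual code is collapsed in this diff; the variable name is
# illustrative):
#     TESTED_QUANTIZER_CLASSES = {
#         type(case[2]) for case in QUANTIZER_ANNOTATION_TEST_CASES
#     }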
@@ -149,6 +256,192 @@ def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node")
return gm, linear_nodes[0]

def _build_conv1d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
"""Build a simple graph with a conv1d operation (no bias)."""
builder = GraphBuilder()
# Input shape: (batch, in_channels, length)
x = builder.placeholder("x", torch.randn(1, 3, 10))
# Weight shape: (out_channels, in_channels, kernel_size)
weight = builder.placeholder("weight", torch.randn(6, 3, 3))
conv1d = builder.call_operator(
op=torch.ops.aten.conv1d.default,
args=(x, weight),
meta=NodeMetadata(
{"source_fn_stack": [("conv1d", torch.ops.aten.conv1d.default)]}
),
)
builder.output([conv1d])
gm = builder.get_graph_module()

conv1d_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.conv1d.default,
)
self.assertEqual(len(conv1d_nodes), 1, "Should find exactly one conv1d node")
return gm, conv1d_nodes[0]

def _build_conv2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
"""Build a simple graph with a conv2d operation (no bias)."""
builder = GraphBuilder()
# Input shape: (batch, in_channels, height, width)
x = builder.placeholder("x", torch.randn(1, 3, 8, 8))
# Weight shape: (out_channels, in_channels, kernel_h, kernel_w)
weight = builder.placeholder("weight", torch.randn(6, 3, 3, 3))
conv2d = builder.call_operator(
op=torch.ops.aten.conv2d.default,
args=(x, weight),
meta=NodeMetadata(
{"source_fn_stack": [("conv2d", torch.ops.aten.conv2d.default)]}
),
)
builder.output([conv2d])
gm = builder.get_graph_module()

conv2d_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.conv2d.default,
)
self.assertEqual(len(conv2d_nodes), 1, "Should find exactly one conv2d node")
return gm, conv2d_nodes[0]

def _build_softmax_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
"""Build a simple graph with a softmax operation."""
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(1, 10))
softmax = builder.call_operator(
op=torch.ops.aten._softmax.default,
args=(x, -1, False), # dim=-1, half_to_float=False
meta=NodeMetadata(
{"source_fn_stack": [("softmax", torch.ops.aten._softmax.default)]}
),
)
builder.output([softmax])
gm = builder.get_graph_module()

softmax_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten._softmax.default,
)
self.assertEqual(len(softmax_nodes), 1, "Should find exactly one softmax node")
return gm, softmax_nodes[0]

def _build_layer_norm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
"""Build a simple graph with a layer_norm operation."""
builder = GraphBuilder()
# Input shape: (batch, features)
x = builder.placeholder("x", torch.randn(1, 10))
# normalized_shape must match the last dimension(s) of input
normalized_shape = [10]
layer_norm = builder.call_operator(
op=torch.ops.aten.layer_norm.default,
args=(x, normalized_shape),
meta=NodeMetadata(
{"source_fn_stack": [("layer_norm", torch.ops.aten.layer_norm.default)]}
),
)
builder.output([layer_norm])
gm = builder.get_graph_module()

layer_norm_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.layer_norm.default,
)
self.assertEqual(
len(layer_norm_nodes), 1, "Should find exactly one layer_norm node"
)
return gm, layer_norm_nodes[0]

def _build_add_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
"""Build a simple graph with an add operation."""
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(1, 10))
y = builder.placeholder("y", torch.randn(1, 10))
add = builder.call_operator(
op=torch.ops.aten.add.Tensor,
args=(x, y),
meta=NodeMetadata(
{"source_fn_stack": [("add", torch.ops.aten.add.Tensor)]}
),
)
builder.output([add])
gm = builder.get_graph_module()

add_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.add.Tensor,
)
self.assertEqual(len(add_nodes), 1, "Should find exactly one add node")
return gm, add_nodes[0]

def _build_bmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
"""Build a simple graph with a bmm (batch matrix multiply) operation."""
builder = GraphBuilder()
# BMM requires 3D tensors: (batch, n, m) @ (batch, m, p) -> (batch, n, p)
x = builder.placeholder("x", torch.randn(2, 4, 8))
y = builder.placeholder("y", torch.randn(2, 8, 4))
bmm = builder.call_operator(
op=torch.ops.aten.bmm.default,
args=(x, y),
meta=NodeMetadata(
{"source_fn_stack": [("bmm", torch.ops.aten.bmm.default)]}
),
)
builder.output([bmm])
gm = builder.get_graph_module()

bmm_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.bmm.default,
)
self.assertEqual(len(bmm_nodes), 1, "Should find exactly one bmm node")
return gm, bmm_nodes[0]

def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
"""Build a simple graph with a relu operation."""
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(1, 10))
relu = builder.call_operator(
op=torch.ops.aten.relu.default,
args=(x,),
meta=NodeMetadata(
{"source_fn_stack": [("relu", torch.ops.aten.relu.default)]}
),
)
builder.output([relu])
gm = builder.get_graph_module()

relu_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.relu.default,
)
self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
return gm, relu_nodes[0]

def _build_addmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
"""Build a simple graph with an addmm operation."""
builder = GraphBuilder()
# addmm computes bias + (mat1 @ mat2), with args (bias, mat1, mat2).
# Here: (5,) + (1, 10) @ (10, 5) -> (1, 5), with bias broadcast.
bias = builder.placeholder("bias", torch.randn(5))
mat1 = builder.placeholder("mat1", torch.randn(1, 10))
mat2 = builder.placeholder("mat2", torch.randn(10, 5))
addmm = builder.call_operator(
op=torch.ops.aten.addmm.default,
args=(bias, mat1, mat2),
meta=NodeMetadata(
{"source_fn_stack": [("addmm", torch.ops.aten.addmm.default)]}
),
)
builder.output([addmm])
gm = builder.get_graph_module()

addmm_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.addmm.default,
)
self.assertEqual(len(addmm_nodes), 1, "Should find exactly one addmm node")
return gm, addmm_nodes[0]

@parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
def test_quantizer_annotation(
self,
@@ -157,7 +450,7 @@ def test_quantizer_annotation(
quantizer: CadenceQuantizer,
target: OpOverload,
expected_output_qspec: QuantizationSpec,
expected_input_qspecs: list[QuantizationSpec],
expected_input_qspecs: list[QuantizationSpec | None],
) -> None:
"""Parameterized test for quantizer annotations."""
gm, op_node = graph_builder_fn(self)
@@ -172,21 +465,24 @@

# Verify input annotations
self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
for i, (input_node, input_qspec) in enumerate(
annotation.input_qspec_map.items()
):
expected_arg = op_node.args[i]
assert isinstance(expected_arg, torch.fx.Node)
self.assertEqual(
input_node,
expected_arg,
f"Input node mismatch at index {i}",
)
self.assertEqual(
input_qspec,
expected_input_qspecs[i],
f"Input qspec mismatch at index {i}",
for input_node, input_qspec in annotation.input_qspec_map.items():
# Find the index of this input node in the op's args
arg_index = None
for i, arg in enumerate(op_node.args):
if arg is input_node:
arg_index = i
break
self.assertIsNotNone(
arg_index,
f"Input node {input_node} not found in op_node.args",
)
# Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec)
if expected_input_qspecs[arg_index] is not None:
self.assertEqual(
input_qspec,
expected_input_qspecs[arg_index],
f"Input qspec mismatch at arg index {arg_index}",
)
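
# Note: parameterized.expand generates one test method per tuple, named
# test_quantizer_annotation_<index>_<name>, so a single case can be selected
# with e.g. `pytest backends/cadence/aot/tests/test_quantizer_ops.py -k
# matmul_A16` (assuming the suite is run under pytest).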

def test_all_quantizers_have_annotation_tests(self) -> None:
"""Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""