From d6d95c57a5aa160afcb5adcde2a6383385277c58 Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Sun, 15 Jun 2025 17:12:15 -0700 Subject: [PATCH 01/29] Fix channels_last transformation for new registry --- docs/overview.rst | 7 ++ setup.cfg | 4 ++ src/qonnx/__init__.py | 26 +++++++ src/qonnx/custom_op/channels_last/__init__.py | 6 +- .../channels_last/batch_normalization.py | 2 + src/qonnx/custom_op/channels_last/conv.py | 2 + src/qonnx/custom_op/channels_last/max_pool.py | 2 + src/qonnx/custom_op/general/__init__.py | 24 +++---- src/qonnx/custom_op/general/bipolar_quant.py | 2 + src/qonnx/custom_op/general/debugmarker.py | 2 + .../custom_op/general/genericpartition.py | 2 + src/qonnx/custom_op/general/im2col.py | 2 + src/qonnx/custom_op/general/maxpoolnhwc.py | 2 + src/qonnx/custom_op/general/multithreshold.py | 2 + src/qonnx/custom_op/general/quant.py | 2 + src/qonnx/custom_op/general/quantavgpool2d.py | 2 + src/qonnx/custom_op/general/trunc.py | 2 + src/qonnx/custom_op/general/xnorpopcount.py | 2 + src/qonnx/custom_op/registry.py | 70 ++++++++++++++++--- src/qonnx/transformation/channels_last.py | 3 +- tests/custom_op/test_attr.py | 5 +- 21 files changed, 142 insertions(+), 29 deletions(-) diff --git a/docs/overview.rst b/docs/overview.rst index 8e2002d7..935ef4d9 100644 --- a/docs/overview.rst +++ b/docs/overview.rst @@ -45,6 +45,13 @@ Custom Operations/Nodes QONNX uses many custom operations (op_type in ONNX NodeProto) that are not defined in the ONNX operator schema. These custom nodes are marked with domain="qonnx.*" in the protobuf to identify them as such. These nodes can represent specific operations that we need for low-bit networks, or operations that are specific to a particular hardware backend. To get more familiar with custom operations and how they are created, please take a look in the Jupyter notebook about CustomOps (see chapter :ref:`tutorials` for details) or directly in the module :py:mod:`qonnx.custom_op`. +Custom ops can be registered automatically via Python entry points using the +``qonnx_custom_ops`` group. Each operator class should be decorated with +``@register_op(domain="...", op_type="...")`` from +``qonnx.custom_op.registry``. Packages installed with such an entry point will +be discovered on import and their ops made available through +``getCustomOp``. 
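+For example, a hypothetical package ``my_ops`` could register an operator
+with an explicit domain and op_type::
+
+    # my_ops/__init__.py
+    from qonnx.custom_op.base import CustomOp
+    from qonnx.custom_op.registry import register_op
+
+    @register_op(domain="my_ops", op_type="MyOp")
+    class MyOp(CustomOp):
+        ...
+
+and advertise the module in its ``setup.cfg`` so that it is imported (and its
+ops registered) when QONNX loads::
+
+    [options.entry_points]
+    qonnx_custom_ops =
+        my_ops = my_ops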
+ Custom ONNX Execution Flow ========================== diff --git a/setup.cfg b/setup.cfg index 9b71bb56..a3038f13 100644 --- a/setup.cfg +++ b/setup.cfg @@ -98,6 +98,10 @@ console_scripts = qonnx-tensor-stats = qonnx.analysis.tensor_stats:main pytest_randomly.random_seeder = qonnx = qonnx.util.random_reseed:reseed +# entry points for custom op modules +qonnx_custom_ops = + qonnx = qonnx.custom_op.general + qonnx_channels_last = qonnx.custom_op.channels_last # Add here console scripts like: # console_scripts = # script_name = qonnx.module:function diff --git a/src/qonnx/__init__.py b/src/qonnx/__init__.py index e69de29b..bb2c88d0 100644 --- a/src/qonnx/__init__.py +++ b/src/qonnx/__init__.py @@ -0,0 +1,26 @@ +"""QONNX package initialization.""" + +import warnings +from importlib import metadata + + +def _load_custom_op_entry_points(): + """Import modules registered under the ``qonnx_custom_ops`` entry point.""" + + try: + eps = metadata.entry_points() + if hasattr(eps, "select"): + eps = eps.select(group="qonnx_custom_ops") + else: + eps = eps.get("qonnx_custom_ops", []) + for ep in eps: + try: + ep.load() + except Exception as e: # pragma: no cover - import failure warning + warnings.warn(f"Failed to load custom op entry point {ep.name}: {e}") + except Exception as e: # pragma: no cover - metadata failure warning + warnings.warn(f"Failed to query custom op entry points: {e}") + + +_load_custom_op_entry_points() + diff --git a/src/qonnx/custom_op/channels_last/__init__.py b/src/qonnx/custom_op/channels_last/__init__.py index f1d7c39b..1b2ebe01 100644 --- a/src/qonnx/custom_op/channels_last/__init__.py +++ b/src/qonnx/custom_op/channels_last/__init__.py @@ -2,8 +2,4 @@ from qonnx.custom_op.channels_last.conv import Conv from qonnx.custom_op.channels_last.max_pool import MaxPool -custom_op = dict() - -custom_op["Conv"] = Conv -custom_op["MaxPool"] = MaxPool -custom_op["BatchNormalization"] = BatchNormalization +__all__ = ["Conv", "MaxPool", "BatchNormalization"] diff --git a/src/qonnx/custom_op/channels_last/batch_normalization.py b/src/qonnx/custom_op/channels_last/batch_normalization.py index f3b3f872..bd5d3b60 100644 --- a/src/qonnx/custom_op/channels_last/batch_normalization.py +++ b/src/qonnx/custom_op/channels_last/batch_normalization.py @@ -30,8 +30,10 @@ from onnx import TensorProto, helper from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp +from qonnx.custom_op.registry import register_op +@register_op(domain="qonnx.custom_op.channels_last", op_type="BatchNormalization") class BatchNormalization(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: diff --git a/src/qonnx/custom_op/channels_last/conv.py b/src/qonnx/custom_op/channels_last/conv.py index b0ff237b..06a25508 100644 --- a/src/qonnx/custom_op/channels_last/conv.py +++ b/src/qonnx/custom_op/channels_last/conv.py @@ -31,8 +31,10 @@ from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp from qonnx.custom_op.general.im2col import compute_conv_output_dim +from qonnx.custom_op.registry import register_op +@register_op(domain="qonnx.custom_op.channels_last", op_type="Conv") class Conv(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: diff --git a/src/qonnx/custom_op/channels_last/max_pool.py b/src/qonnx/custom_op/channels_last/max_pool.py index 383f3008..1bb9a1ce 100644 --- a/src/qonnx/custom_op/channels_last/max_pool.py +++ 
b/src/qonnx/custom_op/channels_last/max_pool.py @@ -30,9 +30,11 @@ from onnx import TensorProto, helper from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp +from qonnx.custom_op.registry import register_op from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim +@register_op(domain="qonnx.custom_op.channels_last", op_type="MaxPool") class MaxPool(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: diff --git a/src/qonnx/custom_op/general/__init__.py b/src/qonnx/custom_op/general/__init__.py index a656d4a5..09b9380c 100644 --- a/src/qonnx/custom_op/general/__init__.py +++ b/src/qonnx/custom_op/general/__init__.py @@ -37,15 +37,15 @@ from qonnx.custom_op.general.trunc import Trunc from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul -custom_op = dict() - -custom_op["DebugMarker"] = DebugMarker -custom_op["QuantAvgPool2d"] = QuantAvgPool2d -custom_op["MaxPoolNHWC"] = MaxPoolNHWC -custom_op["GenericPartition"] = GenericPartition -custom_op["MultiThreshold"] = MultiThreshold -custom_op["XnorPopcountMatMul"] = XnorPopcountMatMul -custom_op["Im2Col"] = Im2Col -custom_op["Quant"] = Quant -custom_op["Trunc"] = Trunc -custom_op["BipolarQuant"] = BipolarQuant +__all__ = [ + "DebugMarker", + "QuantAvgPool2d", + "MaxPoolNHWC", + "GenericPartition", + "MultiThreshold", + "XnorPopcountMatMul", + "Im2Col", + "Quant", + "Trunc", + "BipolarQuant", +] diff --git a/src/qonnx/custom_op/general/bipolar_quant.py b/src/qonnx/custom_op/general/bipolar_quant.py index 986a7082..e6a72486 100644 --- a/src/qonnx/custom_op/general/bipolar_quant.py +++ b/src/qonnx/custom_op/general/bipolar_quant.py @@ -31,6 +31,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op def binary_quant(inp_tensor, scale): @@ -47,6 +48,7 @@ def binary_quant(inp_tensor, scale): return out_tensor +@register_op(domain="qonnx.custom_op.general", op_type="BipolarQuant") class BipolarQuant(CustomOp): """Bipolar quantization operation for QONNX. 
Takes four inputs: - input tensor to quantize diff --git a/src/qonnx/custom_op/general/debugmarker.py b/src/qonnx/custom_op/general/debugmarker.py index ae8cbce5..15e88d8e 100644 --- a/src/qonnx/custom_op/general/debugmarker.py +++ b/src/qonnx/custom_op/general/debugmarker.py @@ -29,8 +29,10 @@ from onnx import helper from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op +@register_op(domain="qonnx.custom_op.general", op_type="DebugMarker") class DebugMarker(CustomOp): def get_nodeattr_types(self): return {"export_debug_name": ("s", True, "")} diff --git a/src/qonnx/custom_op/general/genericpartition.py b/src/qonnx/custom_op/general/genericpartition.py index 841e4e9b..0f6fa104 100755 --- a/src/qonnx/custom_op/general/genericpartition.py +++ b/src/qonnx/custom_op/general/genericpartition.py @@ -29,8 +29,10 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.core.onnx_exec import execute_onnx from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op +@register_op(domain="qonnx.custom_op.general", op_type="GenericPartition") class GenericPartition(CustomOp): """Class that corresponds to the meta/container node GenericPartition which is a placeholder for a group of nodes that have been separated diff --git a/src/qonnx/custom_op/general/im2col.py b/src/qonnx/custom_op/general/im2col.py index 42477832..22b08a5a 100644 --- a/src/qonnx/custom_op/general/im2col.py +++ b/src/qonnx/custom_op/general/im2col.py @@ -31,6 +31,7 @@ import qonnx.util.basic as util from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op # adapted from A. Karpathy's CS231 im2col code # utilities to generate a patch matrix from a multichannel image @@ -140,6 +141,7 @@ def im2col_indices_nchw( # oh/ow and kh/kw will also be 1 in this case +@register_op(domain="qonnx.custom_op.general", op_type="Im2Col") class Im2Col(CustomOp): def get_nodeattr_types(self): return { diff --git a/src/qonnx/custom_op/general/maxpoolnhwc.py b/src/qonnx/custom_op/general/maxpoolnhwc.py index eb964fc4..81a6c4cb 100644 --- a/src/qonnx/custom_op/general/maxpoolnhwc.py +++ b/src/qonnx/custom_op/general/maxpoolnhwc.py @@ -33,6 +33,7 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op from qonnx.util.basic import qonnx_make_model @@ -44,6 +45,7 @@ def compute_pool_output_dim(ifm_dim, k, stride, pad=0, ceil_mode=0): return int(np.floor(((ifm_dim + 2 * pad - k) / stride) + 1)) +@register_op(domain="qonnx.custom_op.general", op_type="MaxPoolNHWC") class MaxPoolNHWC(CustomOp): # a MaxPool node, but using the NHWC data layout diff --git a/src/qonnx/custom_op/general/multithreshold.py b/src/qonnx/custom_op/general/multithreshold.py index 6df58f95..0a5ec596 100644 --- a/src/qonnx/custom_op/general/multithreshold.py +++ b/src/qonnx/custom_op/general/multithreshold.py @@ -31,6 +31,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op def multithreshold(v, thresholds, out_scale=None, out_bias=None): @@ -84,6 +85,7 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None): return out_scale * ret.reshape(v.shape) + out_bias +@register_op(domain="qonnx.custom_op.general", op_type="MultiThreshold") class MultiThreshold(CustomOp): """Class that corresponds to a multithresholding node.""" diff --git 
a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py index f81495d2..39c9f0f4 100644 --- a/src/qonnx/custom_op/general/quant.py +++ b/src/qonnx/custom_op/general/quant.py @@ -31,6 +31,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op def min_int(signed: bool, narrow_range: bool, bit_width: int) -> int: @@ -165,6 +166,7 @@ def round_half_down(x): raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}") +@register_op(domain="qonnx.custom_op.general", op_type="Quant") class Quant(CustomOp): """Generic quantization operation for QONNX. Takes four inputs: - input tensor to quantize diff --git a/src/qonnx/custom_op/general/quantavgpool2d.py b/src/qonnx/custom_op/general/quantavgpool2d.py index c0e24071..344d999b 100644 --- a/src/qonnx/custom_op/general/quantavgpool2d.py +++ b/src/qonnx/custom_op/general/quantavgpool2d.py @@ -32,10 +32,12 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim from qonnx.util.basic import qonnx_make_model +@register_op(domain="qonnx.custom_op.general", op_type="QuantAvgPool2d") class QuantAvgPool2d(CustomOp): """CustomOp that corresponds to the quantized average pooling layer from Brevitas""" diff --git a/src/qonnx/custom_op/general/trunc.py b/src/qonnx/custom_op/general/trunc.py index 8e2eaa19..ca2310b0 100644 --- a/src/qonnx/custom_op/general/trunc.py +++ b/src/qonnx/custom_op/general/trunc.py @@ -31,6 +31,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op from qonnx.custom_op.general.quant import resolve_rounding_mode @@ -58,6 +59,7 @@ def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding return y +@register_op(domain="qonnx.custom_op.general", op_type="Trunc") class Trunc(CustomOp): """Generic truncation operation for QONNX. Takes four inputs: - input tensor to truncate diff --git a/src/qonnx/custom_op/general/xnorpopcount.py b/src/qonnx/custom_op/general/xnorpopcount.py index 9a640599..a91d412b 100644 --- a/src/qonnx/custom_op/general/xnorpopcount.py +++ b/src/qonnx/custom_op/general/xnorpopcount.py @@ -31,6 +31,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op def xnorpopcountmatmul(inp0, inp1): @@ -60,6 +61,7 @@ def xnorpopcountmatmul(inp0, inp1): return (out + K) * 0.5 +@register_op(domain="qonnx.custom_op.general", op_type="XnorPopcountMatMul") class XnorPopcountMatMul(CustomOp): """Class that corresponds to a XNOR-popcount matrix multiplication node.""" diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index 3540bb5a..a3403918 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -27,24 +27,78 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import importlib +import warnings +from importlib import metadata from qonnx.util.basic import get_preferred_onnx_opset +# global registry mapping (domain, op_type) -> CustomOp subclass +CUSTOM_OP_REGISTRY = {} + + +def register_op(domain, op_type): + """Decorator for registering CustomOp classes.""" + + def decorator(cls): + CUSTOM_OP_REGISTRY[(domain, op_type)] = cls + return cls + + return decorator + + +def _load_entry_points(): + """Load custom op modules registered via entry points.""" + + try: + eps = metadata.entry_points() + # compatibility between Python versions + if hasattr(eps, "select"): + eps = eps.select(group="qonnx_custom_ops") + else: + eps = eps.get("qonnx_custom_ops", []) + for ep in eps: + try: + ep.load() + except Exception as e: # pragma: no cover - import failure warning + warnings.warn(f"Failed to load custom op entry point {ep.name}: {e}") + except Exception as e: # pragma: no cover - metadata failure warning + warnings.warn(f"Failed to query custom op entry points: {e}") + + +# load entry points on module import +_load_entry_points() + def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True): - "Return a QONNX CustomOp instance for the given ONNX node, if it exists." + """Return a QONNX CustomOp instance for the given ONNX node, if it exists.""" + op_type = node.op_type domain = node.domain if brevitas_exception: # transparently resolve Brevitas domain ops to qonnx ones domain = domain.replace("onnx.brevitas", "qonnx.custom_op.general") + + key = (domain, op_type) + cls = CUSTOM_OP_REGISTRY.get(key) + if cls is not None: + return cls(node, onnx_opset_version=onnx_opset_version) + try: opset_module = importlib.import_module(domain) - assert type(opset_module.custom_op) is dict, "custom_op dict not found in Python module %s" % domain - inst_wrapper = opset_module.custom_op[op_type] - inst = inst_wrapper(node, onnx_opset_version=onnx_opset_version) - return inst except ModuleNotFoundError: - raise Exception("Could not load custom opset %s, check your PYTHONPATH" % domain) - except KeyError: - raise Exception("Op %s not found in custom opset %s" % (op_type, domain)) + raise Exception(f"Could not load custom opset {domain}, check your PYTHONPATH") + + # op may have registered itself on import + cls = CUSTOM_OP_REGISTRY.get(key) + if cls is not None: + return cls(node, onnx_opset_version=onnx_opset_version) + + # fallback to legacy custom_op dictionary + if hasattr(opset_module, "custom_op") and isinstance(opset_module.custom_op, dict): + try: + inst_wrapper = opset_module.custom_op[op_type] + return inst_wrapper(node, onnx_opset_version=onnx_opset_version) + except KeyError: + pass + + raise Exception(f"Op {op_type} not found in custom opset {domain}") diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index 175af058..5d585d0c 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -44,7 +44,8 @@ from qonnx.util.onnx import is_eltwise_optype # Standard ONNX nodes which require a ChannelsLast data format to function properly -_channelsLast_node_types = list(channels_last.custom_op.keys()) +# use the list of exported op names from the channels_last package +_channelsLast_node_types = list(channels_last.__all__) # Nodes, which do not modify the shape of the tensor # And modify all values in the same way. 
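As a sketch of the resolution order introduced by this patch (the node below is
illustrative; ``MultiThreshold`` is registered above via ``@register_op``),
``getCustomOp`` now consults the decorator-populated registry first, then
imports the domain module, and only then falls back to a legacy ``custom_op``
dict::

    from onnx import helper
    from qonnx.custom_op.registry import getCustomOp

    node = helper.make_node(
        "MultiThreshold", ["v", "thresholds"], ["out"],
        domain="qonnx.custom_op.general",
    )
    # resolved directly from CUSTOM_OP_REGISTRY; no custom_op dict is needed
    inst = getCustomOp(node)
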
diff --git a/tests/custom_op/test_attr.py b/tests/custom_op/test_attr.py index cde5a321..4ea18d30 100644 --- a/tests/custom_op/test_attr.py +++ b/tests/custom_op/test_attr.py @@ -29,12 +29,12 @@ import numpy as np import onnx.parser as oprs -import qonnx.custom_op.general as general from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import getCustomOp +from qonnx.custom_op.registry import getCustomOp, register_op +@register_op(domain="qonnx.custom_op.general", op_type="AttrTestOp") class AttrTestOp(CustomOp): def get_nodeattr_types(self): my_attrs = {"tensor_attr": ("t", True, np.asarray([])), "strings_attr": ("strings", True, [""])} @@ -60,7 +60,6 @@ def verify_node(self): def test_attr(): - general.custom_op["AttrTestOp"] = AttrTestOp ishp = (1, 10) wshp = (1, 3) oshp = wshp From 858cf562508f8e70d0600dcab59a00823e905609 Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Sun, 15 Jun 2025 18:17:15 -0700 Subject: [PATCH 02/29] Add legacy domain fallback test --- tests/custom_op/legacy_custom_op.py | 21 ++++++++++++++++++ tests/custom_op/test_old_domain.py | 33 +++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) create mode 100644 tests/custom_op/legacy_custom_op.py create mode 100644 tests/custom_op/test_old_domain.py diff --git a/tests/custom_op/legacy_custom_op.py b/tests/custom_op/legacy_custom_op.py new file mode 100644 index 00000000..95a1302b --- /dev/null +++ b/tests/custom_op/legacy_custom_op.py @@ -0,0 +1,21 @@ +from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op + +@register_op(domain="legacy_custom_op", op_type="LegacyAdd") +class LegacyAdd(CustomOp): + def get_nodeattr_types(self): + return {} + + def make_shape_compatible_op(self, model): + return super().make_const_shape_op([1]) + + def infer_node_datatype(self, model): + pass + + def execute_node(self, context, graph): + a = context[self.onnx_node.input[0]] + b = context[self.onnx_node.input[1]] + context[self.onnx_node.output[0]] = a + b + + def verify_node(self): + pass diff --git a/tests/custom_op/test_old_domain.py b/tests/custom_op/test_old_domain.py new file mode 100644 index 00000000..eb3479e9 --- /dev/null +++ b/tests/custom_op/test_old_domain.py @@ -0,0 +1,33 @@ +import sys +from onnx import helper, TensorProto + +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.custom_op.registry import getCustomOp +from qonnx.util.basic import qonnx_make_model + + +def test_get_custom_op_old_domain(): + print('sys.path0', sys.path[0]) + assert "legacy_custom_op" not in sys.modules + + node = helper.make_node( + "LegacyAdd", + ["a", "b"], + ["c"], + domain="legacy_custom_op", + ) + + graph = helper.make_graph( + [node], + "legacy_graph", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("b", TensorProto.FLOAT, [1]), + ], + outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, [1])], + ) + model = qonnx_make_model(graph, producer_name="legacy-test") + model = ModelWrapper(model) + + inst = getCustomOp(model.graph.node[0]) + assert inst.__class__.__name__ == "LegacyAdd" From 5036a7af98c63c1e02e557a474629ff63f530d35 Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Sun, 15 Jun 2025 18:54:16 -0700 Subject: [PATCH 03/29] Remove debug output from old domain test --- tests/custom_op/test_old_domain.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/custom_op/test_old_domain.py b/tests/custom_op/test_old_domain.py index 
eb3479e9..5f9ec57d 100644 --- a/tests/custom_op/test_old_domain.py +++ b/tests/custom_op/test_old_domain.py @@ -7,7 +7,6 @@ def test_get_custom_op_old_domain(): - print('sys.path0', sys.path[0]) assert "legacy_custom_op" not in sys.modules node = helper.make_node( From e59e558c75d154e7eafdb79671373032a0c17106 Mon Sep 17 00:00:00 2001 From: tafk7 Date: Fri, 20 Jun 2025 17:54:52 +0000 Subject: [PATCH 04/29] Added passthrough Quant class --- src/qonnx/custom_op/general/__init__.py | 1 + src/qonnx/custom_op/general/quant.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/qonnx/custom_op/general/__init__.py b/src/qonnx/custom_op/general/__init__.py index 0e5d9f53..c2bb7a82 100644 --- a/src/qonnx/custom_op/general/__init__.py +++ b/src/qonnx/custom_op/general/__init__.py @@ -35,6 +35,7 @@ from qonnx.custom_op.general.maxpoolnhwc import MaxPoolNHWC from qonnx.custom_op.general.multithreshold import MultiThreshold from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d +from qonnx.custom_op.general.quant import Quant from qonnx.custom_op.general.trunc import Trunc from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py index 3d448dc3..d204bce6 100644 --- a/src/qonnx/custom_op/general/quant.py +++ b/src/qonnx/custom_op/general/quant.py @@ -26,11 +26,18 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from qonnx.custom_op.general.intquant import IntQuant as Quant +from qonnx.custom_op.general.intquant import IntQuant from qonnx.custom_op.general.intquant import int_quant as quant from qonnx.custom_op.general.intquant import max_int, min_int, resolve_rounding_mode +from qonnx.custom_op.registry import register_op -Quant = Quant +# Create alias and register it separately for "Quant" op_type +@register_op(domain="qonnx.custom_op.general", op_type="Quant") +class Quant(IntQuant): + """Alias for IntQuant to support legacy \"Quant\" op_type.""" + pass + +# Re-export functions quant = quant max_int = max_int min_int = min_int From 30df133a2c11a2660d4b23bd8d42a05672566c97 Mon Sep 17 00:00:00 2001 From: auphelia Date: Mon, 23 Jun 2025 10:06:54 +0100 Subject: [PATCH 05/29] Bring back lost changes from custom/brainsmith branch --- setup.cfg | 2 +- .../transformation/extract_quant_scale_zeropt.py | 8 ++++++++ src/qonnx/transformation/gemm_to_matmul.py | 13 ++++++++++++- src/qonnx/util/basic.py | 2 +- 4 files changed, 22 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index a8b8f915..2fc28b09 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,7 +47,7 @@ install_requires = importlib-metadata attrs>=22.2.0 clize>=5.0.1 - protobuf==3.20.3 + protobuf>=3.20.3 bitstring>=3.1.7 numpy>=1.24.1 onnx>=1.13.0 diff --git a/src/qonnx/transformation/extract_quant_scale_zeropt.py b/src/qonnx/transformation/extract_quant_scale_zeropt.py index 58863f08..614df416 100644 --- a/src/qonnx/transformation/extract_quant_scale_zeropt.py +++ b/src/qonnx/transformation/extract_quant_scale_zeropt.py @@ -69,6 +69,8 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(inp_scaled) inp_scale_node = helper.make_node("Div", [running_input, scale_nm], [inp_scaled_nm]) + if hasattr(node, "metadata_props"): + inp_scale_node.metadata_props.extend(node.metadata_props) graph.node.append(inp_scale_node) # create new Mul node # remove scale from Quant node @@ -87,6 +89,8 @@ def 
apply(self, model: ModelWrapper): ) graph.value_info.append(inp_zeropt) inp_zeropt_node = helper.make_node("Add", [running_input, zeropt_nm], [inp_zeropt_nm]) + if hasattr(node, "metadata_props"): + inp_zeropt_node.metadata_props.extend(node.metadata_props) graph.node.append(inp_zeropt_node) # remove zeropt from Quant node new_zeropt_nm = model.make_new_valueinfo_name() @@ -108,6 +112,8 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(out_zeropt) out_zeropt_node = helper.make_node("Sub", [out_zeropt_nm, zeropt_nm], [final_output]) + if hasattr(node, "metadata_props"): + out_zeropt_node.metadata_props.extend(node.metadata_props) last_node.output[0] = out_zeropt_nm graph.node.append(out_zeropt_node) # important: when tracking a pointer to newly added nodes, @@ -127,6 +133,8 @@ def apply(self, model: ModelWrapper): last_node.output[0] = out_scale_nm graph.value_info.append(out_scale) out_scale_node = helper.make_node("Mul", [out_scale_nm, scale_nm], [final_output]) + if hasattr(node, "metadata_props"): + out_scale_node.metadata_props.extend(node.metadata_props) graph.node.append(out_scale_node) if extract_scale or extract_zeropt: diff --git a/src/qonnx/transformation/gemm_to_matmul.py b/src/qonnx/transformation/gemm_to_matmul.py index 5396a7d6..1298f3d6 100644 --- a/src/qonnx/transformation/gemm_to_matmul.py +++ b/src/qonnx/transformation/gemm_to_matmul.py @@ -76,6 +76,8 @@ def apply(self, model): ) graph.value_info.append(inp_trans_out) inp_trans_node = helper.make_node("Transpose", [n.input[0]], [inp_trans_out.name]) + if hasattr(n, "metadata_props"): + inp_trans_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, inp_trans_node) running_node_index += 1 dt = model.get_tensor_datatype(n.input[0]) @@ -98,6 +100,8 @@ def apply(self, model): ) graph.value_info.append(inp_trans_out) inp_trans_node = helper.make_node("Transpose", [n.input[1]], [inp_trans_out.name]) + if hasattr(n, "metadata_props"): + inp_trans_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, inp_trans_node) running_node_index += 1 # Copy over the datatype @@ -109,6 +113,8 @@ def apply(self, model): # Insert MatMul: A * B matMul_node = helper.make_node("MatMul", [n.input[0], n.input[1]], [n.output[0]]) + if hasattr(n, "metadata_props"): + matMul_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, matMul_node) matMul_node = graph.node[running_node_index] running_node_index += 1 @@ -144,6 +150,8 @@ def apply(self, model): [act_mul_tensor.name, mul_tensor.name], [n.output[0]], ) + if hasattr(n, "metadata_props"): + mul_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, mul_node) mul_node_main_branch = graph.node[running_node_index] running_node_index += 1 @@ -175,6 +183,8 @@ def apply(self, model): [n.input[2], mul_tensor.name], [act_mul_tensor.name], ) + if hasattr(n, "metadata_props"): + mul_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, mul_node) running_node_index += 1 dt = model.get_tensor_datatype(n.input[2]) @@ -196,7 +206,8 @@ def apply(self, model): [act_add_tensor.name, n.input[2]], [n.output[0]], ) - + if hasattr(n, "metadata_props"): + add_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, add_node) running_node_index += 1 diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 303331c5..a7ee197b 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -64,7 +64,7 @@ def 
qonnx_make_model(graph_proto, **kwargs): def is_finn_op(op_type): "Return whether given op_type string is a QONNX or FINN custom op" - return op_type.startswith("finn") or op_type.startswith("qonnx.custom_op") or op_type.startswith("onnx.brevitas") + return op_type.startswith("finn") or op_type.startswith("qonnx.custom_op") or op_type.startswith("onnx.brevitas") or op_type.startswith("brainsmith") def get_num_default_workers(): From dad06c748d933b2807cb7b7f1077b2ced1abcc8d Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Tue, 8 Jul 2025 00:09:15 +0000 Subject: [PATCH 06/29] Refined domain-based registration --- src/qonnx/custom_op/__init__.py | 35 +++++++++++++ src/qonnx/custom_op/registry.py | 90 ++++++++++++++++++++++++--------- src/qonnx/util/basic.py | 3 +- 3 files changed, 103 insertions(+), 25 deletions(-) diff --git a/src/qonnx/custom_op/__init__.py b/src/qonnx/custom_op/__init__.py index e69de29b..95378696 100644 --- a/src/qonnx/custom_op/__init__.py +++ b/src/qonnx/custom_op/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2020 Xilinx, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of Xilinx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
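+# NOTE: importing qonnx.custom_op pre-registers the domains below, so
+# is_custom_op_domain() (and hence qonnx.util.basic.is_finn_op) recognizes
+# them before any individual op module has been imported.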
+ +from qonnx.custom_op.registry import register_custom_domain + +# Pre-register known custom op domains +register_custom_domain("qonnx.custom_op") +register_custom_domain("finn") +register_custom_domain("brainsmith") +register_custom_domain("onnx.brevitas") \ No newline at end of file diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index a3403918..f7ad3ad8 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -35,38 +35,80 @@ # global registry mapping (domain, op_type) -> CustomOp subclass CUSTOM_OP_REGISTRY = {} - -def register_op(domain, op_type): - """Decorator for registering CustomOp classes.""" +# global registry for custom op domains +_CUSTOM_DOMAINS = set() + +# global registry for custom op metadata +_OP_METADATA = {} + + +def register_custom_domain(domain): + """Register a domain as containing custom ops.""" + _CUSTOM_DOMAINS.add(domain) + + +def is_custom_op_domain(domain): + """Check if domain is registered for custom ops.""" + return any(domain.startswith(d) for d in _CUSTOM_DOMAINS) + + +def hasCustomOp(domain, op_type): + """Check if a custom op exists without creating an instance. + + Args: + domain: The domain of the custom op + op_type: The op_type of the custom op + + Returns: + bool: True if the op is registered, False otherwise + """ + return (domain, op_type) in CUSTOM_OP_REGISTRY + + +def get_ops_in_domain(domain): + """Get all registered ops in a domain. + + Args: + domain: The domain to query + + Returns: + List[Tuple[str, Type[CustomOp]]]: List of (op_type, class) tuples + """ + return [(op_type, cls) for (d, op_type), cls in CUSTOM_OP_REGISTRY.items() + if d == domain] + + +def register_op(domain, op_type, metadata=None): + """Decorator for registering CustomOp classes. + + Args: + domain: The domain for the custom op + op_type: The op_type for the custom op + metadata: Optional dict of metadata about the op (backend, version, etc.) + """ def decorator(cls): + # Auto-register the domain when an op is registered + register_custom_domain(domain) CUSTOM_OP_REGISTRY[(domain, op_type)] = cls + if metadata is not None: + _OP_METADATA[(domain, op_type)] = metadata return cls return decorator -def _load_entry_points(): - """Load custom op modules registered via entry points.""" - - try: - eps = metadata.entry_points() - # compatibility between Python versions - if hasattr(eps, "select"): - eps = eps.select(group="qonnx_custom_ops") - else: - eps = eps.get("qonnx_custom_ops", []) - for ep in eps: - try: - ep.load() - except Exception as e: # pragma: no cover - import failure warning - warnings.warn(f"Failed to load custom op entry point {ep.name}: {e}") - except Exception as e: # pragma: no cover - metadata failure warning - warnings.warn(f"Failed to query custom op entry points: {e}") - - -# load entry points on module import -_load_entry_points() +def get_op_metadata(domain, op_type): + """Get metadata for a registered custom op. 
+ + Args: + domain: The domain of the custom op + op_type: The op_type of the custom op + + Returns: + dict: The metadata dict if available, None otherwise + """ + return _OP_METADATA.get((domain, op_type)) def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True): diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index a7ee197b..abc5449f 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -64,7 +64,8 @@ def qonnx_make_model(graph_proto, **kwargs): def is_finn_op(op_type): "Return whether given op_type string is a QONNX or FINN custom op" - return op_type.startswith("finn") or op_type.startswith("qonnx.custom_op") or op_type.startswith("onnx.brevitas") or op_type.startswith("brainsmith") + from qonnx.custom_op.registry import is_custom_op_domain + return is_custom_op_domain(op_type) def get_num_default_workers(): From f6806f6186955e7665b2edbe603cc9d1e2d9e28a Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Wed, 16 Jul 2025 05:25:28 +0000 Subject: [PATCH 07/29] Refined custom_op registration --- docs/overview.rst | 6 +- src/qonnx/custom_op/__init__.py | 17 ++- .../channels_last/batch_normalization.py | 4 +- src/qonnx/custom_op/channels_last/conv.py | 4 +- src/qonnx/custom_op/channels_last/max_pool.py | 4 +- src/qonnx/custom_op/general/bipolar_quant.py | 4 +- src/qonnx/custom_op/general/debugmarker.py | 4 +- src/qonnx/custom_op/general/floatquant.py | 4 +- .../custom_op/general/genericpartition.py | 4 +- src/qonnx/custom_op/general/im2col.py | 4 +- src/qonnx/custom_op/general/intquant.py | 4 +- src/qonnx/custom_op/general/maxpoolnhwc.py | 4 +- src/qonnx/custom_op/general/multithreshold.py | 4 +- src/qonnx/custom_op/general/quant.py | 4 +- src/qonnx/custom_op/general/quantavgpool2d.py | 4 +- src/qonnx/custom_op/general/trunc.py | 4 +- src/qonnx/custom_op/general/xnorpopcount.py | 4 +- src/qonnx/custom_op/registry.py | 129 ++++++++++++++---- src/qonnx/util/basic.py | 6 +- tests/custom_op/legacy_custom_op.py | 22 --- tests/custom_op/test_attr.py | 4 +- tests/custom_op/test_old_domain.py | 32 ----- 22 files changed, 148 insertions(+), 128 deletions(-) delete mode 100644 tests/custom_op/legacy_custom_op.py delete mode 100644 tests/custom_op/test_old_domain.py diff --git a/docs/overview.rst b/docs/overview.rst index 935ef4d9..5fd87de9 100644 --- a/docs/overview.rst +++ b/docs/overview.rst @@ -47,9 +47,9 @@ QONNX uses many custom operations (op_type in ONNX NodeProto) that are not defin Custom ops can be registered automatically via Python entry points using the ``qonnx_custom_ops`` group. Each operator class should be decorated with -``@register_op(domain="...", op_type="...")`` from -``qonnx.custom_op.registry``. Packages installed with such an entry point will -be discovered on import and their ops made available through +``@register_custom_op`` from ``qonnx.custom_op.registry``, which automatically +infers the domain from the module path. Packages installed with such an entry +point will be discovered on import and their ops made available through ``getCustomOp``. diff --git a/src/qonnx/custom_op/__init__.py b/src/qonnx/custom_op/__init__.py index 95378696..be4f2162 100644 --- a/src/qonnx/custom_op/__init__.py +++ b/src/qonnx/custom_op/__init__.py @@ -26,10 +26,15 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-from qonnx.custom_op.registry import register_custom_domain +from qonnx.custom_op.registry import register_domain -# Pre-register known custom op domains -register_custom_domain("qonnx.custom_op") -register_custom_domain("finn") -register_custom_domain("brainsmith") -register_custom_domain("onnx.brevitas") \ No newline at end of file +# Register QONNX domains (module path defaults to domain name) +register_domain("qonnx.custom_op.general") +register_domain("qonnx.custom_op.channels_last") + +# Register parent domain for hierarchy checking +register_domain("qonnx.custom_op") + +# Special case: Brevitas compatibility domain +# (QONNX handles Brevitas ops for backward compatibility) +register_domain("onnx.brevitas") \ No newline at end of file diff --git a/src/qonnx/custom_op/channels_last/batch_normalization.py b/src/qonnx/custom_op/channels_last/batch_normalization.py index bd5d3b60..b97ab1a8 100644 --- a/src/qonnx/custom_op/channels_last/batch_normalization.py +++ b/src/qonnx/custom_op/channels_last/batch_normalization.py @@ -30,10 +30,10 @@ from onnx import TensorProto, helper from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op -@register_op(domain="qonnx.custom_op.channels_last", op_type="BatchNormalization") +@register_custom_op class BatchNormalization(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: diff --git a/src/qonnx/custom_op/channels_last/conv.py b/src/qonnx/custom_op/channels_last/conv.py index 06a25508..dc78a0fd 100644 --- a/src/qonnx/custom_op/channels_last/conv.py +++ b/src/qonnx/custom_op/channels_last/conv.py @@ -31,10 +31,10 @@ from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp from qonnx.custom_op.general.im2col import compute_conv_output_dim -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op -@register_op(domain="qonnx.custom_op.channels_last", op_type="Conv") +@register_custom_op class Conv(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: diff --git a/src/qonnx/custom_op/channels_last/max_pool.py b/src/qonnx/custom_op/channels_last/max_pool.py index aec2c908..53a7b617 100644 --- a/src/qonnx/custom_op/channels_last/max_pool.py +++ b/src/qonnx/custom_op/channels_last/max_pool.py @@ -31,10 +31,10 @@ from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op -@register_op(domain="qonnx.custom_op.channels_last", op_type="MaxPool") +@register_custom_op class MaxPool(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: diff --git a/src/qonnx/custom_op/general/bipolar_quant.py b/src/qonnx/custom_op/general/bipolar_quant.py index e6a72486..102f5210 100644 --- a/src/qonnx/custom_op/general/bipolar_quant.py +++ b/src/qonnx/custom_op/general/bipolar_quant.py @@ -31,7 +31,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op def binary_quant(inp_tensor, scale): @@ -48,7 +48,7 @@ def binary_quant(inp_tensor, scale): return out_tensor 
-@register_op(domain="qonnx.custom_op.general", op_type="BipolarQuant") +@register_custom_op class BipolarQuant(CustomOp): """Bipolar quantization operation for QONNX. Takes four inputs: - input tensor to quantize diff --git a/src/qonnx/custom_op/general/debugmarker.py b/src/qonnx/custom_op/general/debugmarker.py index 15e88d8e..3da80521 100644 --- a/src/qonnx/custom_op/general/debugmarker.py +++ b/src/qonnx/custom_op/general/debugmarker.py @@ -29,10 +29,10 @@ from onnx import helper from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op -@register_op(domain="qonnx.custom_op.general", op_type="DebugMarker") +@register_custom_op class DebugMarker(CustomOp): def get_nodeattr_types(self): return {"export_debug_name": ("s", True, "")} diff --git a/src/qonnx/custom_op/general/floatquant.py b/src/qonnx/custom_op/general/floatquant.py index 56698efb..ab74b1df 100644 --- a/src/qonnx/custom_op/general/floatquant.py +++ b/src/qonnx/custom_op/general/floatquant.py @@ -33,7 +33,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp from qonnx.custom_op.general.quant import resolve_rounding_mode -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op def compute_default_exponent_bias(exponent_bitwidth): @@ -121,7 +121,7 @@ def inf_nan_clamp(X, inf_mask, p_max_val_mask, n_max_val_mask): return x_q * scale # , self.saturating, self.inf_values, self.nan_values -@register_op(domain="qonnx.custom_op.general", op_type="FloatQuant") +@register_custom_op class FloatQuant(CustomOp): """Floating point quantization operation for QONNX. diff --git a/src/qonnx/custom_op/general/genericpartition.py b/src/qonnx/custom_op/general/genericpartition.py index 0f6fa104..3418f9a5 100755 --- a/src/qonnx/custom_op/general/genericpartition.py +++ b/src/qonnx/custom_op/general/genericpartition.py @@ -29,10 +29,10 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.core.onnx_exec import execute_onnx from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op -@register_op(domain="qonnx.custom_op.general", op_type="GenericPartition") +@register_custom_op class GenericPartition(CustomOp): """Class that corresponds to the meta/container node GenericPartition which is a placeholder for a group of nodes that have been separated diff --git a/src/qonnx/custom_op/general/im2col.py b/src/qonnx/custom_op/general/im2col.py index 22b08a5a..276caf7d 100644 --- a/src/qonnx/custom_op/general/im2col.py +++ b/src/qonnx/custom_op/general/im2col.py @@ -31,7 +31,7 @@ import qonnx.util.basic as util from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op # adapted from A. 
Karpathy's CS231 im2col code # utilities to generate a patch matrix from a multichannel image @@ -141,7 +141,7 @@ def im2col_indices_nchw( # oh/ow and kh/kw will also be 1 in this case -@register_op(domain="qonnx.custom_op.general", op_type="Im2Col") +@register_custom_op class Im2Col(CustomOp): def get_nodeattr_types(self): return { diff --git a/src/qonnx/custom_op/general/intquant.py b/src/qonnx/custom_op/general/intquant.py index 7663e95f..a053e7ef 100644 --- a/src/qonnx/custom_op/general/intquant.py +++ b/src/qonnx/custom_op/general/intquant.py @@ -31,7 +31,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op def min_int(signed: bool, narrow_range: bool, bit_width: int) -> int: @@ -166,7 +166,7 @@ def round_half_down(x): raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}") -@register_op(domain="qonnx.custom_op.general", op_type="IntQuant") +@register_custom_op class IntQuant(CustomOp): """Generic quantization operation for QONNX. Takes four inputs: - input tensor to quantize diff --git a/src/qonnx/custom_op/general/maxpoolnhwc.py b/src/qonnx/custom_op/general/maxpoolnhwc.py index 81a6c4cb..44aa04bc 100644 --- a/src/qonnx/custom_op/general/maxpoolnhwc.py +++ b/src/qonnx/custom_op/general/maxpoolnhwc.py @@ -33,7 +33,7 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op from qonnx.util.basic import qonnx_make_model @@ -45,7 +45,7 @@ def compute_pool_output_dim(ifm_dim, k, stride, pad=0, ceil_mode=0): return int(np.floor(((ifm_dim + 2 * pad - k) / stride) + 1)) -@register_op(domain="qonnx.custom_op.general", op_type="MaxPoolNHWC") +@register_custom_op class MaxPoolNHWC(CustomOp): # a MaxPool node, but using the NHWC data layout diff --git a/src/qonnx/custom_op/general/multithreshold.py b/src/qonnx/custom_op/general/multithreshold.py index 0a5ec596..76df4752 100644 --- a/src/qonnx/custom_op/general/multithreshold.py +++ b/src/qonnx/custom_op/general/multithreshold.py @@ -31,7 +31,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op def multithreshold(v, thresholds, out_scale=None, out_bias=None): @@ -85,7 +85,7 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None): return out_scale * ret.reshape(v.shape) + out_bias -@register_op(domain="qonnx.custom_op.general", op_type="MultiThreshold") +@register_custom_op class MultiThreshold(CustomOp): """Class that corresponds to a multithresholding node.""" diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py index d204bce6..858eafc2 100644 --- a/src/qonnx/custom_op/general/quant.py +++ b/src/qonnx/custom_op/general/quant.py @@ -29,10 +29,10 @@ from qonnx.custom_op.general.intquant import IntQuant from qonnx.custom_op.general.intquant import int_quant as quant from qonnx.custom_op.general.intquant import max_int, min_int, resolve_rounding_mode -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op # Create alias and register it separately for "Quant" op_type -@register_op(domain="qonnx.custom_op.general", op_type="Quant") +@register_custom_op class Quant(IntQuant): """Alias for IntQuant to 
support legacy \"Quant\" op_type.""" pass diff --git a/src/qonnx/custom_op/general/quantavgpool2d.py b/src/qonnx/custom_op/general/quantavgpool2d.py index b152171f..7e51f3e3 100644 --- a/src/qonnx/custom_op/general/quantavgpool2d.py +++ b/src/qonnx/custom_op/general/quantavgpool2d.py @@ -33,11 +33,11 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op from qonnx.util.basic import qonnx_make_model -@register_op(domain="qonnx.custom_op.general", op_type="QuantAvgPool2d") +@register_custom_op class QuantAvgPool2d(CustomOp): """CustomOp that corresponds to the quantized average pooling layer from Brevitas""" diff --git a/src/qonnx/custom_op/general/trunc.py b/src/qonnx/custom_op/general/trunc.py index 6a59e91b..d2921262 100644 --- a/src/qonnx/custom_op/general/trunc.py +++ b/src/qonnx/custom_op/general/trunc.py @@ -32,7 +32,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp from qonnx.custom_op.general.quant import resolve_rounding_mode -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode): @@ -59,7 +59,7 @@ def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding return y -@register_op(domain="qonnx.custom_op.general", op_type="Trunc") +@register_custom_op class Trunc(CustomOp): """Generic truncation operation for QONNX. Takes four inputs: - input tensor to truncate diff --git a/src/qonnx/custom_op/general/xnorpopcount.py b/src/qonnx/custom_op/general/xnorpopcount.py index a91d412b..c068cb9d 100644 --- a/src/qonnx/custom_op/general/xnorpopcount.py +++ b/src/qonnx/custom_op/general/xnorpopcount.py @@ -31,7 +31,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op def xnorpopcountmatmul(inp0, inp1): @@ -61,7 +61,7 @@ def xnorpopcountmatmul(inp0, inp1): return (out + K) * 0.5 -@register_op(domain="qonnx.custom_op.general", op_type="XnorPopcountMatMul") +@register_custom_op class XnorPopcountMatMul(CustomOp): """Class that corresponds to a XNOR-popcount matrix multiplication node.""" diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index f7ad3ad8..045f4c70 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -35,21 +35,20 @@ # global registry mapping (domain, op_type) -> CustomOp subclass CUSTOM_OP_REGISTRY = {} -# global registry for custom op domains -_CUSTOM_DOMAINS = set() - # global registry for custom op metadata _OP_METADATA = {} +# global registry mapping domains to their module paths +# Structure: DOMAIN_REGISTRY[domain] = module_path (or None if module_path == domain) +DOMAIN_REGISTRY = {} + -def register_custom_domain(domain): - """Register a domain as containing custom ops.""" - _CUSTOM_DOMAINS.add(domain) def is_custom_op_domain(domain): """Check if domain is registered for custom ops.""" - return any(domain.startswith(d) for d in _CUSTOM_DOMAINS) + # Check if domain is directly registered or starts with a registered domain + return domain in DOMAIN_REGISTRY or any(domain.startswith(d) for d in DOMAIN_REGISTRY) def hasCustomOp(domain, op_type): @@ -88,8 +87,6 @@ def 
register_op(domain, op_type, metadata=None): """ def decorator(cls): - # Auto-register the domain when an op is registered - register_custom_domain(domain) CUSTOM_OP_REGISTRY[(domain, op_type)] = cls if metadata is not None: _OP_METADATA[(domain, op_type)] = metadata @@ -111,36 +108,108 @@ def get_op_metadata(domain, op_type): return _OP_METADATA.get((domain, op_type)) -def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True): - """Return a QONNX CustomOp instance for the given ONNX node, if it exists.""" +def register_domain(domain, module_path=None): + """Register a domain with its associated module path. + + This function registers the domain and its module path, allowing classes + defined in any direct child module of this path to use @register_custom_op. + Subfolders/subpackages must be registered separately. + + Args: + domain: The domain to register (e.g., "finn.custom_op.fpgadataflow") + module_path: The Python module path. If None, uses the domain as the path. + """ + DOMAIN_REGISTRY[domain] = module_path + + +# Keep register_domain_path as deprecated alias for backward compatibility +def register_domain_path(module_path, domain): + """Deprecated: Use register_domain instead.""" + return register_domain(domain, module_path) + +def register_custom_op(cls=None, *, op_type=None): + """Register a custom op, inferring domain from parent module path. + + Can be used as @register_custom_op or @register_custom_op(op_type="CustomName"). + Domain is inferred from registered module paths. Op type defaults to class name. + + Args: + cls: The class to register (when used without parentheses) + op_type: Optional custom op_type (defaults to class name) + + Returns: + Decorated class or decorator function + """ + def decorator(cls): + # Get module path + module = cls.__module__ + + # Check if module is a direct child of any registered domain's module path + domain = None + for registered_domain, module_path in DOMAIN_REGISTRY.items(): + # Use domain as module path if not specified + if module_path is None: + module_path = registered_domain + # Check if module is direct child of registered path + if module.startswith(module_path + "."): + # Ensure it's a direct child, not nested deeper + remainder = module[len(module_path) + 1:] + if "." not in remainder: # No more dots = direct child + domain = registered_domain + break + # Also check exact match (for __init__.py files) + elif module == module_path: + domain = registered_domain + break + + if domain is None: + raise ValueError( + f"Module '{module}' is not in a registered domain path. " + f"Either:\n" + f"1. Use @register_op(domain='...', op_type='{cls.__name__}')\n" + f"2. 
Register the domain: register_domain('your.domain', '{'.'.join(module.split('.')[:-1])}')" + ) + + # Use class name as op_type if not specified + final_op_type = op_type or cls.__name__ + + # Register using the standard mechanism + return register_op(domain=domain, op_type=final_op_type)(cls) + + # Handle both @register_custom_op and @register_custom_op() + if cls is None: + return decorator + return decorator(cls) + + +def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True): + """Return a QONNX CustomOp instance for the given ONNX node.""" op_type = node.op_type domain = node.domain + if brevitas_exception: # transparently resolve Brevitas domain ops to qonnx ones domain = domain.replace("onnx.brevitas", "qonnx.custom_op.general") - + key = (domain, op_type) cls = CUSTOM_OP_REGISTRY.get(key) if cls is not None: return cls(node, onnx_opset_version=onnx_opset_version) - - try: - opset_module = importlib.import_module(domain) - except ModuleNotFoundError: - raise Exception(f"Could not load custom opset {domain}, check your PYTHONPATH") - - # op may have registered itself on import - cls = CUSTOM_OP_REGISTRY.get(key) - if cls is not None: - return cls(node, onnx_opset_version=onnx_opset_version) - - # fallback to legacy custom_op dictionary - if hasattr(opset_module, "custom_op") and isinstance(opset_module.custom_op, dict): + + # Check if we need to import the module to trigger registration + if domain.startswith("finn.custom_op"): try: - inst_wrapper = opset_module.custom_op[op_type] - return inst_wrapper(node, onnx_opset_version=onnx_opset_version) - except KeyError: + importlib.import_module(domain) + # Check again after import + cls = CUSTOM_OP_REGISTRY.get(key) + if cls is not None: + return cls(node, onnx_opset_version=onnx_opset_version) + except ImportError: pass - - raise Exception(f"Op {op_type} not found in custom opset {domain}") + + available_domains = sorted(DOMAIN_REGISTRY.keys()) + raise Exception( + f"Op '{op_type}' not found in domain '{domain}'. 
" + f"Available domains: {available_domains}" + ) diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index abc5449f..68d691e1 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -62,10 +62,10 @@ def qonnx_make_model(graph_proto, **kwargs): return make_model(graph_proto, **kwargs) -def is_finn_op(op_type): - "Return whether given op_type string is a QONNX or FINN custom op" +def is_finn_op(domain): + "Return whether given domain string is a QONNX or FINN custom op domain" from qonnx.custom_op.registry import is_custom_op_domain - return is_custom_op_domain(op_type) + return is_custom_op_domain(domain) def get_num_default_workers(): diff --git a/tests/custom_op/legacy_custom_op.py b/tests/custom_op/legacy_custom_op.py deleted file mode 100644 index 95b689b9..00000000 --- a/tests/custom_op/legacy_custom_op.py +++ /dev/null @@ -1,22 +0,0 @@ -from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_op - - -@register_op(domain="legacy_custom_op", op_type="LegacyAdd") -class LegacyAdd(CustomOp): - def get_nodeattr_types(self): - return {} - - def make_shape_compatible_op(self, model): - return super().make_const_shape_op([1]) - - def infer_node_datatype(self, model): - pass - - def execute_node(self, context, graph): - a = context[self.onnx_node.input[0]] - b = context[self.onnx_node.input[1]] - context[self.onnx_node.output[0]] = a + b - - def verify_node(self): - pass diff --git a/tests/custom_op/test_attr.py b/tests/custom_op/test_attr.py index 4ea18d30..592a99ae 100644 --- a/tests/custom_op/test_attr.py +++ b/tests/custom_op/test_attr.py @@ -31,10 +31,10 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import getCustomOp, register_op +from qonnx.custom_op.registry import getCustomOp, register_custom_op -@register_op(domain="qonnx.custom_op.general", op_type="AttrTestOp") +@register_custom_op class AttrTestOp(CustomOp): def get_nodeattr_types(self): my_attrs = {"tensor_attr": ("t", True, np.asarray([])), "strings_attr": ("strings", True, [""])} diff --git a/tests/custom_op/test_old_domain.py b/tests/custom_op/test_old_domain.py deleted file mode 100644 index 88ec226f..00000000 --- a/tests/custom_op/test_old_domain.py +++ /dev/null @@ -1,32 +0,0 @@ -import sys -from onnx import TensorProto, helper - -from qonnx.core.modelwrapper import ModelWrapper -from qonnx.custom_op.registry import getCustomOp -from qonnx.util.basic import qonnx_make_model - - -def test_get_custom_op_old_domain(): - assert "legacy_custom_op" not in sys.modules - - node = helper.make_node( - "LegacyAdd", - ["a", "b"], - ["c"], - domain="legacy_custom_op", - ) - - graph = helper.make_graph( - [node], - "legacy_graph", - inputs=[ - helper.make_tensor_value_info("a", TensorProto.FLOAT, [1]), - helper.make_tensor_value_info("b", TensorProto.FLOAT, [1]), - ], - outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, [1])], - ) - model = qonnx_make_model(graph, producer_name="legacy-test") - model = ModelWrapper(model) - - inst = getCustomOp(model.graph.node[0]) - assert inst.__class__.__name__ == "LegacyAdd" From f7ab4b5cb8e5e56dc9dce0c22acdb85e95de737e Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Wed, 16 Jul 2025 21:50:57 +0000 Subject: [PATCH 08/29] Dependency resolution --- src/qonnx/custom_op/registry.py | 228 +++++++++++++++++++++++--------- 1 file changed, 163 insertions(+), 65 deletions(-) diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index 
045f4c70..e8be6f22 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -42,7 +42,73 @@ # Structure: DOMAIN_REGISTRY[domain] = module_path (or None if module_path == domain) DOMAIN_REGISTRY = {} +# Track which domains have been loaded +_LOADED_DOMAINS = set() +# Track domain dependencies discovered through inheritance +_DOMAIN_DEPENDENCIES = {} # domain -> set of dependency domains + + + + +def _ensure_domain_loaded(domain): + """Ensure a domain and its dependencies are loaded.""" + if domain in _LOADED_DOMAINS: + return + + # Mark as loaded first to prevent infinite recursion + _LOADED_DOMAINS.add(domain) + + # First load any known dependencies + if domain in _DOMAIN_DEPENDENCIES: + for dep_domain in _DOMAIN_DEPENDENCIES[domain]: + if dep_domain != domain: # Avoid self-dependencies + _ensure_domain_loaded(dep_domain) + + # Try to import the domain module + if domain in DOMAIN_REGISTRY: + module_path = DOMAIN_REGISTRY[domain] or domain + try: + importlib.import_module(module_path) + except ImportError as e: + # Remove from loaded if import failed + _LOADED_DOMAINS.discard(domain) + # Continue without raising - domain might still work + elif domain.startswith(("finn.", "qonnx.")): + # Try importing even if not in registry + try: + importlib.import_module(domain) + except ImportError: + # Remove from loaded if import failed + _LOADED_DOMAINS.discard(domain) + + +def _register_op_with_dependencies(domain, op_type, cls, metadata=None): + """Register an op and track its inheritance dependencies.""" + # Register the op + CUSTOM_OP_REGISTRY[(domain, op_type)] = cls + if metadata is not None: + _OP_METADATA[(domain, op_type)] = metadata + + # Detect dependencies from inheritance + for base in cls.__bases__: + # Skip abstract base classes and non-custom ops + if base.__name__ in ('CustomOp', 'ABC', 'object', 'HWCustomOp', 'HLSBackend', 'RTLBackend'): + continue + + # Check if base class is a registered custom op + for (reg_domain, reg_op), reg_cls in CUSTOM_OP_REGISTRY.items(): + if reg_cls == base: + # Found a dependency - track it + if domain not in _DOMAIN_DEPENDENCIES: + _DOMAIN_DEPENDENCIES[domain] = set() + _DOMAIN_DEPENDENCIES[domain].add(reg_domain) + + # Immediately ensure the dependency is loaded + _ensure_domain_loaded(reg_domain) + break + + return cls def is_custom_op_domain(domain): @@ -61,6 +127,8 @@ def hasCustomOp(domain, op_type): Returns: bool: True if the op is registered, False otherwise """ + # Ensure domain is loaded first + _ensure_domain_loaded(domain) return (domain, op_type) in CUSTOM_OP_REGISTRY @@ -77,24 +145,6 @@ def get_ops_in_domain(domain): if d == domain] -def register_op(domain, op_type, metadata=None): - """Decorator for registering CustomOp classes. - - Args: - domain: The domain for the custom op - op_type: The op_type for the custom op - metadata: Optional dict of metadata about the op (backend, version, etc.) - """ - - def decorator(cls): - CUSTOM_OP_REGISTRY[(domain, op_type)] = cls - if metadata is not None: - _OP_METADATA[(domain, op_type)] = metadata - return cls - - return decorator - - def get_op_metadata(domain, op_type): """Get metadata for a registered custom op. @@ -128,59 +178,102 @@ def register_domain_path(module_path, domain): return register_domain(domain, module_path) -def register_custom_op(cls=None, *, op_type=None): - """Register a custom op, inferring domain from parent module path. 
+ +def register_custom_op(domain=None, op_type=None, *, metadata=None): + """Register a custom op with flexible domain and op_type specification. - Can be used as @register_custom_op or @register_custom_op(op_type="CustomName"). - Domain is inferred from registered module paths. Op type defaults to class name. + Can be used in three ways: + 1. @register_custom_op("domain", "OpType") - Explicit domain and op_type + 2. @register_custom_op("domain") - Explicit domain, class name as op_type + 3. @register_custom_op - Automatic domain inference, class name as op_type Args: - cls: The class to register (when used without parentheses) - op_type: Optional custom op_type (defaults to class name) + domain: The domain for the custom op (optional) + op_type: The op_type for the custom op (optional) + metadata: Optional dict of metadata about the op (backend, version, etc.) Returns: Decorated class or decorator function """ - def decorator(cls): - # Get module path - module = cls.__module__ - - # Check if module is a direct child of any registered domain's module path - domain = None - for registered_domain, module_path in DOMAIN_REGISTRY.items(): - # Use domain as module path if not specified - if module_path is None: - module_path = registered_domain - # Check if module is direct child of registered path - if module.startswith(module_path + "."): - # Ensure it's a direct child, not nested deeper - remainder = module[len(module_path) + 1:] - if "." not in remainder: # No more dots = direct child - domain = registered_domain + # Determine which mode we're in based on arguments + if domain is not None and isinstance(domain, str): + # Mode 1 or 2: Explicit domain provided + if op_type is not None and isinstance(op_type, str): + # Mode 1: Both domain and op_type provided + def decorator(cls): + return _register_op_with_dependencies(domain, op_type, cls, metadata) + return decorator + else: + # Mode 2: Only domain provided, use class name as op_type + def decorator(cls): + final_op_type = cls.__name__ + return _register_op_with_dependencies(domain, final_op_type, cls, metadata) + return decorator + else: + # Mode 3: No domain provided, or called without arguments + # Handle the case where it's used as @register_custom_op (no parentheses) + if domain is not None and not isinstance(domain, str): + # This means domain is actually the class (decorator without parentheses) + cls = domain + module = cls.__module__ + + # Find domain from registered domains + inferred_domain = None + for registered_domain, module_path in DOMAIN_REGISTRY.items(): + if module_path is None: + module_path = registered_domain + if module.startswith(module_path + "."): + remainder = module[len(module_path) + 1:] + if "." not in remainder: + inferred_domain = registered_domain + break + elif module == module_path: + inferred_domain = registered_domain break - # Also check exact match (for __init__.py files) - elif module == module_path: - domain = registered_domain - break - - if domain is None: - raise ValueError( - f"Module '{module}' is not in a registered domain path. " - f"Either:\n" - f"1. Use @register_op(domain='...', op_type='{cls.__name__}')\n" - f"2. 
Register the domain: register_domain('your.domain', '{'.'.join(module.split('.')[:-1])}')" - ) - - # Use class name as op_type if not specified - final_op_type = op_type or cls.__name__ - - # Register using the standard mechanism - return register_op(domain=domain, op_type=final_op_type)(cls) - - # Handle both @register_custom_op and @register_custom_op() - if cls is None: - return decorator - return decorator(cls) + + if inferred_domain is None: + raise ValueError( + f"Module '{module}' is not in a registered domain path. " + f"Either:\n" + f"1. Use @register_custom_op('domain', 'OpType')\n" + f"2. Use @register_custom_op('domain') to use class name as op_type\n" + f"3. Register the domain: register_domain('your.domain', '{'.'.join(module.split('.')[:-1])}')" + ) + + final_op_type = cls.__name__ + return _register_op_with_dependencies(inferred_domain, final_op_type, cls, metadata) + else: + # Decorator called with parentheses but no domain + def decorator(cls): + module = cls.__module__ + + # Find domain from registered domains + inferred_domain = None + for registered_domain, module_path in DOMAIN_REGISTRY.items(): + if module_path is None: + module_path = registered_domain + if module.startswith(module_path + "."): + remainder = module[len(module_path) + 1:] + if "." not in remainder: + inferred_domain = registered_domain + break + elif module == module_path: + inferred_domain = registered_domain + break + + if inferred_domain is None: + raise ValueError( + f"Module '{module}' is not in a registered domain path. " + f"Either:\n" + f"1. Use @register_custom_op('domain', 'OpType')\n" + f"2. Use @register_custom_op('domain') to use class name as op_type\n" + f"3. Register the domain: register_domain('your.domain', '{'.'.join(module.split('.')[:-1])}')" + ) + + # Use provided op_type or default to class name + final_op_type = op_type or cls.__name__ + return _register_op_with_dependencies(inferred_domain, final_op_type, cls, metadata) + return decorator def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True): @@ -192,15 +285,20 @@ def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_ex # transparently resolve Brevitas domain ops to qonnx ones domain = domain.replace("onnx.brevitas", "qonnx.custom_op.general") + # Ensure the domain is loaded (will load dependencies automatically) + _ensure_domain_loaded(domain) + key = (domain, op_type) cls = CUSTOM_OP_REGISTRY.get(key) if cls is not None: return cls(node, onnx_opset_version=onnx_opset_version) - # Check if we need to import the module to trigger registration - if domain.startswith("finn.custom_op"): + # If not found and domain starts with finn, try explicit import as fallback + # This handles cases where domain isn't registered but module exists + if domain.startswith("finn.custom_op") and domain not in _LOADED_DOMAINS: try: importlib.import_module(domain) + _LOADED_DOMAINS.add(domain) # Check again after import cls = CUSTOM_OP_REGISTRY.get(key) if cls is not None: From d08c33dba97f2ed9ec0c9426c284696c40dd5270 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Wed, 16 Jul 2025 23:00:47 +0000 Subject: [PATCH 09/29] help multithreshold handle 3-dim more efficiently --- src/qonnx/custom_op/general/multithreshold.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/qonnx/custom_op/general/multithreshold.py b/src/qonnx/custom_op/general/multithreshold.py index 0a5ec596..f54fb408 100644 --- a/src/qonnx/custom_op/general/multithreshold.py +++ 
b/src/qonnx/custom_op/general/multithreshold.py
@@ -145,6 +145,10 @@ def execute_node(self, context, graph):
         # TODO: Seems like a rather sketchy solution to support arbitrary data
         # layouts. This does not even validate the assumption of channel last
         # layout.
+        if v.ndim == 3:
+            orig_shape = v.shape
+            v = np.expand_dims(v, axis=0)
+
         if v.ndim not in {2, 4}:
             # Remember the original shape to be restored later
             orig_shape = v.shape

From d76507a66b87c6b7fbc7566a457ed8bb05b68c40 Mon Sep 17 00:00:00 2001
From: Joshua Monson
Date: Thu, 17 Jul 2025 22:53:02 +0000
Subject: [PATCH 10/29] update extract model config to export config for subgraphs

---
 src/qonnx/util/config.py | 32 ++++++++++++++++++++++----------
 1 file changed, 22 insertions(+), 10 deletions(-)

diff --git a/src/qonnx/util/config.py b/src/qonnx/util/config.py
index 63661862..2f6383d3 100644
--- a/src/qonnx/util/config.py
+++ b/src/qonnx/util/config.py
@@ -27,13 +27,15 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import json
+import onnx
 
 from qonnx.custom_op.registry import getCustomOp
 
-
-def extract_model_config_to_json(model, json_filename, attr_names_to_extract):
-    """Create a json file with layer name -> attribute mappings extracted from the
-    model. The created json file can be later applied on a model with
+# Also handles export configs for subgraphs: a subgraph is found in a node's
+# attribute as a graph-type attribute and is recursed into below.
+def extract_model_config(model, attr_names_to_extract):
+    """Create a dictionary with layer name -> attribute mappings extracted from the
+    model. The created dictionary can be later applied on a model with
     qonnx.transform.general.ApplyConfig."""
 
     cfg = dict()
@@ -41,12 +43,22 @@
     for n in model.graph.node:
         oi = getCustomOp(n)
         layer_dict = dict()
-        for attr in attr_names_to_extract:
-            try:
-                layer_dict[attr] = oi.get_nodeattr(attr)
-            except AttributeError:
-                pass
+        for attr in n.attribute:
+            if attr.type == onnx.AttributeProto.GRAPH:
+                # If the attribute is a graph, extract the attributes from the subgraph
+                cfg.update(extract_model_config(model.make_subgraph_modelwrapper(attr.g), attr_names_to_extract))
+            elif attr.name in attr_names_to_extract:
+                # If the attribute name is in the list, we can add it directly
+                layer_dict[attr.name] = oi.get_nodeattr(attr.name)
        if len(layer_dict) > 0:
            cfg[n.name] = layer_dict
 
+    return cfg
+
+
+def extract_model_config_to_json(model, json_filename, attr_names_to_extract):
+    """Create a json file with layer name -> attribute mappings extracted from the
+    model. 
The created json file can be later applied on a model with + qonnx.transform.general.ApplyConfig.""" + with open(json_filename, "w") as f: - json.dump(cfg, f, indent=2) + json.dump(extract_model_config(model, attr_names_to_extract), f, indent=2) From fa3e0a81adb933847d0655140d15374bceaa6498 Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Fri, 18 Jul 2025 02:00:42 +0000 Subject: [PATCH 11/29] Removed decorators in favor of pure domain --- docs/overview.rst | 18 +- notebooks/3_custom_op.ipynb | 31 +- src/qonnx/custom_op/__init__.py | 14 +- .../channels_last/batch_normalization.py | 2 - src/qonnx/custom_op/channels_last/conv.py | 2 - src/qonnx/custom_op/channels_last/max_pool.py | 2 - src/qonnx/custom_op/general/bipolar_quant.py | 2 - src/qonnx/custom_op/general/debugmarker.py | 2 - src/qonnx/custom_op/general/floatquant.py | 2 - .../custom_op/general/genericpartition.py | 2 - src/qonnx/custom_op/general/im2col.py | 2 - src/qonnx/custom_op/general/intquant.py | 2 - src/qonnx/custom_op/general/maxpoolnhwc.py | 2 - src/qonnx/custom_op/general/multithreshold.py | 2 - src/qonnx/custom_op/general/quant.py | 17 +- src/qonnx/custom_op/general/quantavgpool2d.py | 2 - src/qonnx/custom_op/general/trunc.py | 2 - src/qonnx/custom_op/general/xnorpopcount.py | 2 - src/qonnx/custom_op/registry.py | 395 +++++++----------- src/qonnx/util/basic.py | 16 +- tests/custom_op/test_attr.py | 8 +- 21 files changed, 184 insertions(+), 343 deletions(-) diff --git a/docs/overview.rst b/docs/overview.rst index 5fd87de9..2f0f3577 100644 --- a/docs/overview.rst +++ b/docs/overview.rst @@ -45,12 +45,18 @@ Custom Operations/Nodes QONNX uses many custom operations (op_type in ONNX NodeProto) that are not defined in the ONNX operator schema. These custom nodes are marked with domain="qonnx.*" in the protobuf to identify them as such. These nodes can represent specific operations that we need for low-bit networks, or operations that are specific to a particular hardware backend. To get more familiar with custom operations and how they are created, please take a look in the Jupyter notebook about CustomOps (see chapter :ref:`tutorials` for details) or directly in the module :py:mod:`qonnx.custom_op`. -Custom ops can be registered automatically via Python entry points using the -``qonnx_custom_ops`` group. Each operator class should be decorated with -``@register_custom_op`` from ``qonnx.custom_op.registry``, which automatically -infers the domain from the module path. Packages installed with such an entry -point will be discovered on import and their ops made available through -``getCustomOp``. +Custom ops are automatically discovered through Python module namespaces. +Simply define your CustomOp subclass in the appropriate domain module +(e.g., ``qonnx.custom_op.general`` for general ops) and it will be automatically +available through ``getCustomOp``. 
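+
+A minimal sketch of such an op (``ScaleByTwo`` is a hypothetical example, not
+part of QONNX; the method set mirrors the ``CustomOp`` abstract interface)::
+
+    from qonnx.custom_op.base import CustomOp
+
+    # placed inside the qonnx.custom_op.general module, nodes with
+    # domain="qonnx.custom_op.general" and op_type="ScaleByTwo"
+    # resolve to this class via getCustomOp(node)
+    class ScaleByTwo(CustomOp):
+        def get_nodeattr_types(self):
+            return {}
+
+        def make_shape_compatible_op(self, model):
+            return super().make_const_shape_op([1])
+
+        def infer_node_datatype(self, model):
+            pass
+
+        def execute_node(self, context, graph):
+            inp = context[self.onnx_node.input[0]]
+            context[self.onnx_node.output[0]] = 2 * inp
+
+        def verify_node(self):
+            pass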
+ +For dynamic registration (e.g., in tests), use the registry functions: + +* ``getCustomOp(node)`` - Get a custom op instance from an ONNX node +* ``add_op_to_domain(domain, op_type, op_class)`` - Add an op to a domain's namespace +* ``add_domain_alias(domain, module_path)`` - Map a domain to a different module path +* ``hasCustomOp(domain, op_type)`` - Check if an op exists in a domain +* ``get_ops_in_domain(domain)`` - List all ops available in a domain Custom ONNX Execution Flow diff --git a/notebooks/3_custom_op.ipynb b/notebooks/3_custom_op.ipynb index d0cd10fd..cd01686c 100644 --- a/notebooks/3_custom_op.ipynb +++ b/notebooks/3_custom_op.ipynb @@ -129,35 +129,24 @@ { "cell_type": "markdown", "metadata": {}, - "source": [ - "To make sure our custom op is available, it needs to be registered. The best practice for this is to create a submodule under `qonnx.custom_op` which includes a `custom_op` dictionary that maps strings (op names) to classes (op implementations). Since we're in a Jupyter notebook we'll just hijack it at runtime like this:" - ] + "source": "To make sure our custom op is available, we need to add it to the domain's namespace. For production code, you would place your CustomOp class directly in the appropriate module (e.g., in a file under `qonnx/custom_op/general/`). For testing and experimentation like in this notebook, we can use the `add_op_to_domain` function:" }, { "cell_type": "code", - "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "import qonnx.custom_op.general as general\n", - "general.custom_op[\"MyPythonPowerOp\"] = MyPythonPowerOp" - ] + "source": "from qonnx.custom_op.registry import add_op_to_domain\n\n# Add our custom op to the general domain namespace\nadd_op_to_domain(\"qonnx.custom_op.general\", \"MyPythonPowerOp\", MyPythonPowerOp)" }, { "cell_type": "markdown", "metadata": {}, - "source": [ - "We can see which custom ops are registered under this submodule by looking at the dictionary:" - ] + "source": "We can see which custom ops are available in a domain by using the registry function:" }, { "cell_type": "code", - "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "general.custom_op" - ] + "source": "from qonnx.custom_op.registry import get_ops_in_domain\n\n# See all ops in the general domain\nops = get_ops_in_domain(\"qonnx.custom_op.general\")\nprint(f\"Available ops: {[op[0] for op in ops]}\")\n\n# Check if our op is there\nfrom qonnx.custom_op.registry import hasCustomOp\nprint(f\"MyPythonPowerOp available: {hasCustomOp('qonnx.custom_op.general', 'MyPythonPowerOp')}\")" }, { "cell_type": "markdown", @@ -462,17 +451,9 @@ }, { "cell_type": "code", - "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "# register our new op\n", - "general.custom_op[\"MyMixedPowerOp\"] = MyMixedPowerOp\n", - "\n", - "# make graph with new op\n", - "mixedop_graph = make_graph(input_shape, 2, op_type = \"MyMixedPowerOp\")\n", - "mixedop_graph.graph.node" - ] + "source": "# register our new op\nadd_op_to_domain(\"qonnx.custom_op.general\", \"MyMixedPowerOp\", MyMixedPowerOp)\n\n# make graph with new op\nmixedop_graph = make_graph(input_shape, 2, op_type = \"MyMixedPowerOp\")\nmixedop_graph.graph.node" }, { "cell_type": "markdown", @@ -744,4 +725,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/src/qonnx/custom_op/__init__.py b/src/qonnx/custom_op/__init__.py index be4f2162..7c38a8df 100644 --- a/src/qonnx/custom_op/__init__.py +++ 
b/src/qonnx/custom_op/__init__.py @@ -26,15 +26,5 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from qonnx.custom_op.registry import register_domain - -# Register QONNX domains (module path defaults to domain name) -register_domain("qonnx.custom_op.general") -register_domain("qonnx.custom_op.channels_last") - -# Register parent domain for hierarchy checking -register_domain("qonnx.custom_op") - -# Special case: Brevitas compatibility domain -# (QONNX handles Brevitas ops for backward compatibility) -register_domain("onnx.brevitas") \ No newline at end of file +# Domain aliases are automatically handled by the registry +# The onnx.brevitas -> qonnx.custom_op.general mapping is built into the registry \ No newline at end of file diff --git a/src/qonnx/custom_op/channels_last/batch_normalization.py b/src/qonnx/custom_op/channels_last/batch_normalization.py index b97ab1a8..f3b3f872 100644 --- a/src/qonnx/custom_op/channels_last/batch_normalization.py +++ b/src/qonnx/custom_op/channels_last/batch_normalization.py @@ -30,10 +30,8 @@ from onnx import TensorProto, helper from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp -from qonnx.custom_op.registry import register_custom_op -@register_custom_op class BatchNormalization(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: diff --git a/src/qonnx/custom_op/channels_last/conv.py b/src/qonnx/custom_op/channels_last/conv.py index dc78a0fd..b0ff237b 100644 --- a/src/qonnx/custom_op/channels_last/conv.py +++ b/src/qonnx/custom_op/channels_last/conv.py @@ -31,10 +31,8 @@ from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp from qonnx.custom_op.general.im2col import compute_conv_output_dim -from qonnx.custom_op.registry import register_custom_op -@register_custom_op class Conv(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: diff --git a/src/qonnx/custom_op/channels_last/max_pool.py b/src/qonnx/custom_op/channels_last/max_pool.py index 53a7b617..383f3008 100644 --- a/src/qonnx/custom_op/channels_last/max_pool.py +++ b/src/qonnx/custom_op/channels_last/max_pool.py @@ -31,10 +31,8 @@ from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim -from qonnx.custom_op.registry import register_custom_op -@register_custom_op class MaxPool(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: diff --git a/src/qonnx/custom_op/general/bipolar_quant.py b/src/qonnx/custom_op/general/bipolar_quant.py index 102f5210..986a7082 100644 --- a/src/qonnx/custom_op/general/bipolar_quant.py +++ b/src/qonnx/custom_op/general/bipolar_quant.py @@ -31,7 +31,6 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_custom_op def binary_quant(inp_tensor, scale): @@ -48,7 +47,6 @@ def binary_quant(inp_tensor, scale): return out_tensor -@register_custom_op class BipolarQuant(CustomOp): """Bipolar quantization operation for QONNX. 
Takes four inputs: - input tensor to quantize diff --git a/src/qonnx/custom_op/general/debugmarker.py b/src/qonnx/custom_op/general/debugmarker.py index 3da80521..ae8cbce5 100644 --- a/src/qonnx/custom_op/general/debugmarker.py +++ b/src/qonnx/custom_op/general/debugmarker.py @@ -29,10 +29,8 @@ from onnx import helper from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_custom_op -@register_custom_op class DebugMarker(CustomOp): def get_nodeattr_types(self): return {"export_debug_name": ("s", True, "")} diff --git a/src/qonnx/custom_op/general/floatquant.py b/src/qonnx/custom_op/general/floatquant.py index ab74b1df..a34f6c01 100644 --- a/src/qonnx/custom_op/general/floatquant.py +++ b/src/qonnx/custom_op/general/floatquant.py @@ -33,7 +33,6 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp from qonnx.custom_op.general.quant import resolve_rounding_mode -from qonnx.custom_op.registry import register_custom_op def compute_default_exponent_bias(exponent_bitwidth): @@ -121,7 +120,6 @@ def inf_nan_clamp(X, inf_mask, p_max_val_mask, n_max_val_mask): return x_q * scale # , self.saturating, self.inf_values, self.nan_values -@register_custom_op class FloatQuant(CustomOp): """Floating point quantization operation for QONNX. diff --git a/src/qonnx/custom_op/general/genericpartition.py b/src/qonnx/custom_op/general/genericpartition.py index 3418f9a5..841e4e9b 100755 --- a/src/qonnx/custom_op/general/genericpartition.py +++ b/src/qonnx/custom_op/general/genericpartition.py @@ -29,10 +29,8 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.core.onnx_exec import execute_onnx from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_custom_op -@register_custom_op class GenericPartition(CustomOp): """Class that corresponds to the meta/container node GenericPartition which is a placeholder for a group of nodes that have been separated diff --git a/src/qonnx/custom_op/general/im2col.py b/src/qonnx/custom_op/general/im2col.py index 276caf7d..42477832 100644 --- a/src/qonnx/custom_op/general/im2col.py +++ b/src/qonnx/custom_op/general/im2col.py @@ -31,7 +31,6 @@ import qonnx.util.basic as util from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_custom_op # adapted from A. Karpathy's CS231 im2col code # utilities to generate a patch matrix from a multichannel image @@ -141,7 +140,6 @@ def im2col_indices_nchw( # oh/ow and kh/kw will also be 1 in this case -@register_custom_op class Im2Col(CustomOp): def get_nodeattr_types(self): return { diff --git a/src/qonnx/custom_op/general/intquant.py b/src/qonnx/custom_op/general/intquant.py index a053e7ef..69920b97 100644 --- a/src/qonnx/custom_op/general/intquant.py +++ b/src/qonnx/custom_op/general/intquant.py @@ -31,7 +31,6 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_custom_op def min_int(signed: bool, narrow_range: bool, bit_width: int) -> int: @@ -166,7 +165,6 @@ def round_half_down(x): raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}") -@register_custom_op class IntQuant(CustomOp): """Generic quantization operation for QONNX. 
Takes four inputs: - input tensor to quantize diff --git a/src/qonnx/custom_op/general/maxpoolnhwc.py b/src/qonnx/custom_op/general/maxpoolnhwc.py index 44aa04bc..eb964fc4 100644 --- a/src/qonnx/custom_op/general/maxpoolnhwc.py +++ b/src/qonnx/custom_op/general/maxpoolnhwc.py @@ -33,7 +33,6 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_custom_op from qonnx.util.basic import qonnx_make_model @@ -45,7 +44,6 @@ def compute_pool_output_dim(ifm_dim, k, stride, pad=0, ceil_mode=0): return int(np.floor(((ifm_dim + 2 * pad - k) / stride) + 1)) -@register_custom_op class MaxPoolNHWC(CustomOp): # a MaxPool node, but using the NHWC data layout diff --git a/src/qonnx/custom_op/general/multithreshold.py b/src/qonnx/custom_op/general/multithreshold.py index 76df4752..6df58f95 100644 --- a/src/qonnx/custom_op/general/multithreshold.py +++ b/src/qonnx/custom_op/general/multithreshold.py @@ -31,7 +31,6 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_custom_op def multithreshold(v, thresholds, out_scale=None, out_bias=None): @@ -85,7 +84,6 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None): return out_scale * ret.reshape(v.shape) + out_bias -@register_custom_op class MultiThreshold(CustomOp): """Class that corresponds to a multithresholding node.""" diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py index 858eafc2..a7356a8f 100644 --- a/src/qonnx/custom_op/general/quant.py +++ b/src/qonnx/custom_op/general/quant.py @@ -26,19 +26,12 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
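+# Backward-compatibility sketch: imports that previously worked, e.g.
+#     from qonnx.custom_op.general.quant import Quant, quant, min_int
+# keep resolving, since this module re-exports the IntQuant implementation
+# under its legacy names below.
+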
+# Import IntQuant to create alias from qonnx.custom_op.general.intquant import IntQuant + +# Re-export functions from intquant for backward compatibility from qonnx.custom_op.general.intquant import int_quant as quant from qonnx.custom_op.general.intquant import max_int, min_int, resolve_rounding_mode -from qonnx.custom_op.registry import register_custom_op - -# Create alias and register it separately for "Quant" op_type -@register_custom_op -class Quant(IntQuant): - """Alias for IntQuant to support legacy \"Quant\" op_type.""" - pass -# Re-export functions -quant = quant -max_int = max_int -min_int = min_int -resolve_rounding_mode = resolve_rounding_mode +# Create alias for backward compatibility - Quant is just IntQuant +Quant = IntQuant \ No newline at end of file diff --git a/src/qonnx/custom_op/general/quantavgpool2d.py b/src/qonnx/custom_op/general/quantavgpool2d.py index 7e51f3e3..c0e24071 100644 --- a/src/qonnx/custom_op/general/quantavgpool2d.py +++ b/src/qonnx/custom_op/general/quantavgpool2d.py @@ -33,11 +33,9 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim -from qonnx.custom_op.registry import register_custom_op from qonnx.util.basic import qonnx_make_model -@register_custom_op class QuantAvgPool2d(CustomOp): """CustomOp that corresponds to the quantized average pooling layer from Brevitas""" diff --git a/src/qonnx/custom_op/general/trunc.py b/src/qonnx/custom_op/general/trunc.py index d2921262..8e2eaa19 100644 --- a/src/qonnx/custom_op/general/trunc.py +++ b/src/qonnx/custom_op/general/trunc.py @@ -32,7 +32,6 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp from qonnx.custom_op.general.quant import resolve_rounding_mode -from qonnx.custom_op.registry import register_custom_op def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode): @@ -59,7 +58,6 @@ def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding return y -@register_custom_op class Trunc(CustomOp): """Generic truncation operation for QONNX. Takes four inputs: - input tensor to truncate diff --git a/src/qonnx/custom_op/general/xnorpopcount.py b/src/qonnx/custom_op/general/xnorpopcount.py index c068cb9d..9a640599 100644 --- a/src/qonnx/custom_op/general/xnorpopcount.py +++ b/src/qonnx/custom_op/general/xnorpopcount.py @@ -31,7 +31,6 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_custom_op def xnorpopcountmatmul(inp0, inp1): @@ -61,7 +60,6 @@ def xnorpopcountmatmul(inp0, inp1): return (out + K) * 0.5 -@register_custom_op class XnorPopcountMatMul(CustomOp): """Class that corresponds to a XNOR-popcount matrix multiplication node.""" diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index e8be6f22..c7d964f6 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -27,287 +27,172 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
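+# Resolution model implemented below (sketch): node.domain names a Python
+# module (after alias lookup in DOMAIN_MODULES) and node.op_type names a
+# CustomOp subclass exported by that module, roughly:
+#     module = importlib.import_module(DOMAIN_MODULES.get(domain, domain))
+#     op_cls = getattr(module, op_type)
+# plus a legacy custom_op-dict fallback for older opset modules.
+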
import importlib -import warnings -from importlib import metadata +import inspect +from typing import Dict +from qonnx.custom_op.base import CustomOp from qonnx.util.basic import get_preferred_onnx_opset -# global registry mapping (domain, op_type) -> CustomOp subclass -CUSTOM_OP_REGISTRY = {} +# Domain to module path mapping (only when different) +DOMAIN_MODULES: Dict[str, str] = { + "onnx.brevitas": "qonnx.custom_op.general", # Built-in compatibility +} -# global registry for custom op metadata -_OP_METADATA = {} -# global registry mapping domains to their module paths -# Structure: DOMAIN_REGISTRY[domain] = module_path (or None if module_path == domain) -DOMAIN_REGISTRY = {} - -# Track which domains have been loaded -_LOADED_DOMAINS = set() - -# Track domain dependencies discovered through inheritance -_DOMAIN_DEPENDENCIES = {} # domain -> set of dependency domains - - - - -def _ensure_domain_loaded(domain): - """Ensure a domain and its dependencies are loaded.""" - if domain in _LOADED_DOMAINS: - return - - # Mark as loaded first to prevent infinite recursion - _LOADED_DOMAINS.add(domain) - - # First load any known dependencies - if domain in _DOMAIN_DEPENDENCIES: - for dep_domain in _DOMAIN_DEPENDENCIES[domain]: - if dep_domain != domain: # Avoid self-dependencies - _ensure_domain_loaded(dep_domain) - - # Try to import the domain module - if domain in DOMAIN_REGISTRY: - module_path = DOMAIN_REGISTRY[domain] or domain - try: - importlib.import_module(module_path) - except ImportError as e: - # Remove from loaded if import failed - _LOADED_DOMAINS.discard(domain) - # Continue without raising - domain might still work - elif domain.startswith(("finn.", "qonnx.")): - # Try importing even if not in registry - try: - importlib.import_module(domain) - except ImportError: - # Remove from loaded if import failed - _LOADED_DOMAINS.discard(domain) - - -def _register_op_with_dependencies(domain, op_type, cls, metadata=None): - """Register an op and track its inheritance dependencies.""" - # Register the op - CUSTOM_OP_REGISTRY[(domain, op_type)] = cls - if metadata is not None: - _OP_METADATA[(domain, op_type)] = metadata - - # Detect dependencies from inheritance - for base in cls.__bases__: - # Skip abstract base classes and non-custom ops - if base.__name__ in ('CustomOp', 'ABC', 'object', 'HWCustomOp', 'HLSBackend', 'RTLBackend'): - continue - - # Check if base class is a registered custom op - for (reg_domain, reg_op), reg_cls in CUSTOM_OP_REGISTRY.items(): - if reg_cls == base: - # Found a dependency - track it - if domain not in _DOMAIN_DEPENDENCIES: - _DOMAIN_DEPENDENCIES[domain] = set() - _DOMAIN_DEPENDENCIES[domain].add(reg_domain) - - # Immediately ensure the dependency is loaded - _ensure_domain_loaded(reg_domain) - break - - return cls - - -def is_custom_op_domain(domain): - """Check if domain is registered for custom ops.""" - # Check if domain is directly registered or starts with a registered domain - return domain in DOMAIN_REGISTRY or any(domain.startswith(d) for d in DOMAIN_REGISTRY) - - -def hasCustomOp(domain, op_type): - """Check if a custom op exists without creating an instance. +def add_domain_alias(domain: str, module_path: str) -> None: + """Map a domain name to a different module path. 
Args: - domain: The domain of the custom op - op_type: The op_type of the custom op + domain: The ONNX domain name + module_path: The Python module path to use instead - Returns: - bool: True if the op is registered, False otherwise + Example: + add_domain_alias("finn.custom_op.fpgadataflow", "finn_custom_ops.fpgadataflow") """ - # Ensure domain is loaded first - _ensure_domain_loaded(domain) - return (domain, op_type) in CUSTOM_OP_REGISTRY + DOMAIN_MODULES[domain] = module_path -def get_ops_in_domain(domain): - """Get all registered ops in a domain. +def add_op_to_domain(domain: str, op_type: str, op_class: type) -> None: + """Add a custom op directly to a domain's module namespace. - Args: - domain: The domain to query - - Returns: - List[Tuple[str, Type[CustomOp]]]: List of (op_type, class) tuples - """ - return [(op_type, cls) for (d, op_type), cls in CUSTOM_OP_REGISTRY.items() - if d == domain] - - -def get_op_metadata(domain, op_type): - """Get metadata for a registered custom op. + This function dynamically adds custom ops to module namespaces at runtime. + Useful for test cases or dynamic op registration. Args: - domain: The domain of the custom op - op_type: The op_type of the custom op + domain: The ONNX domain name (e.g., "qonnx.custom_op.general") + op_type: The operation type name (e.g., "MyCustomOp") + op_class: The CustomOp subclass to add - Returns: - dict: The metadata dict if available, None otherwise + Example: + add_op_to_domain("qonnx.custom_op.general", "TestOp", TestOp) """ - return _OP_METADATA.get((domain, op_type)) - - -def register_domain(domain, module_path=None): - """Register a domain with its associated module path. + if not inspect.isclass(op_class) or not issubclass(op_class, CustomOp): + raise ValueError(f"{op_class} must be a subclass of CustomOp") - This function registers the domain and its module path, allowing classes - defined in any direct child module of this path to use @register_custom_op. - Subfolders/subpackages must be registered separately. + # Get the actual module path + module_path = DOMAIN_MODULES.get(domain, domain) - Args: - domain: The domain to register (e.g., "finn.custom_op.fpgadataflow") - module_path: The Python module path. If None, uses the domain as the path. - """ - DOMAIN_REGISTRY[domain] = module_path - - -# Keep register_domain_path as deprecated alias for backward compatibility -def register_domain_path(module_path, domain): - """Deprecated: Use register_domain instead.""" - return register_domain(domain, module_path) - + try: + # Import the module and add the op to its namespace + module = importlib.import_module(module_path) + setattr(module, op_type, op_class) + except ModuleNotFoundError: + raise ValueError(f"Could not find module for domain '{domain}' (tried: {module_path})") -def register_custom_op(domain=None, op_type=None, *, metadata=None): - """Register a custom op with flexible domain and op_type specification. +def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset()): + """Get a custom op instance for an ONNX node. - Can be used in three ways: - 1. @register_custom_op("domain", "OpType") - Explicit domain and op_type - 2. @register_custom_op("domain") - Explicit domain, class name as op_type - 3. @register_custom_op - Automatic domain inference, class name as op_type - - Args: - domain: The domain for the custom op (optional) - op_type: The op_type for the custom op (optional) - metadata: Optional dict of metadata about the op (backend, version, etc.) 
- - Returns: - Decorated class or decorator function + Lookup order: + 1. Direct attribute lookup in module namespace + 2. Legacy custom_op dictionary (backward compatibility) + 3. Search all CustomOp subclasses (fallback) """ - # Determine which mode we're in based on arguments - if domain is not None and isinstance(domain, str): - # Mode 1 or 2: Explicit domain provided - if op_type is not None and isinstance(op_type, str): - # Mode 1: Both domain and op_type provided - def decorator(cls): - return _register_op_with_dependencies(domain, op_type, cls, metadata) - return decorator - else: - # Mode 2: Only domain provided, use class name as op_type - def decorator(cls): - final_op_type = cls.__name__ - return _register_op_with_dependencies(domain, final_op_type, cls, metadata) - return decorator - else: - # Mode 3: No domain provided, or called without arguments - # Handle the case where it's used as @register_custom_op (no parentheses) - if domain is not None and not isinstance(domain, str): - # This means domain is actually the class (decorator without parentheses) - cls = domain - module = cls.__module__ - - # Find domain from registered domains - inferred_domain = None - for registered_domain, module_path in DOMAIN_REGISTRY.items(): - if module_path is None: - module_path = registered_domain - if module.startswith(module_path + "."): - remainder = module[len(module_path) + 1:] - if "." not in remainder: - inferred_domain = registered_domain - break - elif module == module_path: - inferred_domain = registered_domain - break - - if inferred_domain is None: - raise ValueError( - f"Module '{module}' is not in a registered domain path. " - f"Either:\n" - f"1. Use @register_custom_op('domain', 'OpType')\n" - f"2. Use @register_custom_op('domain') to use class name as op_type\n" - f"3. Register the domain: register_domain('your.domain', '{'.'.join(module.split('.')[:-1])}')" - ) - - final_op_type = cls.__name__ - return _register_op_with_dependencies(inferred_domain, final_op_type, cls, metadata) - else: - # Decorator called with parentheses but no domain - def decorator(cls): - module = cls.__module__ - - # Find domain from registered domains - inferred_domain = None - for registered_domain, module_path in DOMAIN_REGISTRY.items(): - if module_path is None: - module_path = registered_domain - if module.startswith(module_path + "."): - remainder = module[len(module_path) + 1:] - if "." not in remainder: - inferred_domain = registered_domain - break - elif module == module_path: - inferred_domain = registered_domain - break - - if inferred_domain is None: - raise ValueError( - f"Module '{module}' is not in a registered domain path. " - f"Either:\n" - f"1. Use @register_custom_op('domain', 'OpType')\n" - f"2. Use @register_custom_op('domain') to use class name as op_type\n" - f"3. 
Register the domain: register_domain('your.domain', '{'.'.join(module.split('.')[:-1])}')" - ) - - # Use provided op_type or default to class name - final_op_type = op_type or cls.__name__ - return _register_op_with_dependencies(inferred_domain, final_op_type, cls, metadata) - return decorator - - -def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True): - """Return a QONNX CustomOp instance for the given ONNX node.""" op_type = node.op_type domain = node.domain - if brevitas_exception: - # transparently resolve Brevitas domain ops to qonnx ones - domain = domain.replace("onnx.brevitas", "qonnx.custom_op.general") + # Get module path (handles brevitas via DOMAIN_MODULES mapping) + module_path = DOMAIN_MODULES.get(domain, domain) - # Ensure the domain is loaded (will load dependencies automatically) - _ensure_domain_loaded(domain) - - key = (domain, op_type) - cls = CUSTOM_OP_REGISTRY.get(key) - if cls is not None: - return cls(node, onnx_opset_version=onnx_opset_version) - - # If not found and domain starts with finn, try explicit import as fallback - # This handles cases where domain isn't registered but module exists - if domain.startswith("finn.custom_op") and domain not in _LOADED_DOMAINS: - try: - importlib.import_module(domain) - _LOADED_DOMAINS.add(domain) - # Check again after import - cls = CUSTOM_OP_REGISTRY.get(key) - if cls is not None: + try: + # Import the domain module + module = importlib.import_module(module_path) + + # Strategy 1: Direct namespace lookup (preferred) + if hasattr(module, op_type): + obj = getattr(module, op_type) + if inspect.isclass(obj) and issubclass(obj, CustomOp): + return obj(node, onnx_opset_version=onnx_opset_version) + + # Strategy 2: Legacy custom_op dict (backward compatibility) + if hasattr(module, 'custom_op') and isinstance(module.custom_op, dict): + if op_type in module.custom_op: + cls = module.custom_op[op_type] return cls(node, onnx_opset_version=onnx_opset_version) - except ImportError: + + # Strategy 3: Search module for CustomOp subclasses (fallback) + # Useful for debugging and error messages + custom_ops = {} + for name, obj in inspect.getmembers(module): + if (inspect.isclass(obj) and + issubclass(obj, CustomOp) and + obj is not CustomOp and + not name.startswith('_')): # Skip private classes + custom_ops[name] = obj + + # Try case-insensitive match as last resort + for name, cls in custom_ops.items(): + if name.lower() == op_type.lower(): + return cls(node, onnx_opset_version=onnx_opset_version) + + # Not found - provide helpful error + available = list(custom_ops.keys()) + raise KeyError( + f"Op '{op_type}' not found in domain '{domain}' (module: {module_path}). " + f"Available ops: {available}" + ) + + except ModuleNotFoundError: + raise Exception( + f"Could not load module '{module_path}' for domain '{domain}'. " + f"Ensure the module is installed and on your PYTHONPATH." 
+ ) + + +# Legacy functions for backward compatibility +def hasCustomOp(domain, op_type): + """Check if a custom op exists in the domain's module namespace.""" + try: + # Create a dummy node to test + class DummyNode: pass + node = DummyNode() + node.op_type = op_type + node.domain = domain + + # Try to get the op class + module_path = DOMAIN_MODULES.get(domain, domain) + module = importlib.import_module(module_path) + + # Check namespace first + if hasattr(module, op_type): + obj = getattr(module, op_type) + if inspect.isclass(obj) and issubclass(obj, CustomOp): + return True + + # Check legacy dict + if hasattr(module, 'custom_op') and isinstance(module.custom_op, dict): + return op_type in module.custom_op + + return False + except: + return False + + +def get_ops_in_domain(domain): + """Get all ops in a domain by inspecting the module namespace.""" + ops = [] - available_domains = sorted(DOMAIN_REGISTRY.keys()) - raise Exception( - f"Op '{op_type}' not found in domain '{domain}'. " - f"Available domains: {available_domains}" - ) + try: + module_path = DOMAIN_MODULES.get(domain, domain) + module = importlib.import_module(module_path) + + # Check module namespace + for name, obj in inspect.getmembers(module): + if (inspect.isclass(obj) and + issubclass(obj, CustomOp) and + obj is not CustomOp and + not name.startswith('_')): + ops.append((name, obj)) + + # Also check legacy dict if present + if hasattr(module, 'custom_op') and isinstance(module.custom_op, dict): + for name, cls in module.custom_op.items(): + if not any(op[0] == name for op in ops): + ops.append((name, cls)) + + return ops + except: + return [] + + diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 68d691e1..92ba5d2e 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -33,6 +33,7 @@ import warnings from qonnx.core.datatype import DataType +from qonnx.custom_op.registry import get_ops_in_domain # TODO solve by moving onnx-dependent fxns to onnx.py # finn-examples uses parts of qonnx without having @@ -63,9 +64,18 @@ def qonnx_make_model(graph_proto, **kwargs): def is_finn_op(domain): - "Return whether given domain string is a QONNX or FINN custom op domain" - from qonnx.custom_op.registry import is_custom_op_domain - return is_custom_op_domain(domain) + """Return whether given domain string is a QONNX or FINN custom op domain. + + Validates that: + 1. The domain starts with known custom op prefixes (qonnx., finn., onnx.brevitas) + 2. 
The domain exists and contains at least one CustomOp + """ + # Check if domain has known custom op prefix + if not domain.startswith(("qonnx.", "finn.", "onnx.brevitas")): + return False + + # Validate that the domain actually exists and has CustomOps + return len(get_ops_in_domain(domain)) > 0 def get_num_default_workers(): diff --git a/tests/custom_op/test_attr.py b/tests/custom_op/test_attr.py index 592a99ae..906e154a 100644 --- a/tests/custom_op/test_attr.py +++ b/tests/custom_op/test_attr.py @@ -31,10 +31,9 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import getCustomOp, register_custom_op +from qonnx.custom_op.registry import getCustomOp, add_op_to_domain -@register_custom_op class AttrTestOp(CustomOp): def get_nodeattr_types(self): my_attrs = {"tensor_attr": ("t", True, np.asarray([])), "strings_attr": ("strings", True, [""])} @@ -60,6 +59,9 @@ def verify_node(self): def test_attr(): + # Add the test op to the domain + add_op_to_domain("qonnx.custom_op.general", "AttrTestOp", AttrTestOp) + ishp = (1, 10) wshp = (1, 3) oshp = wshp @@ -86,6 +88,8 @@ def test_attr(): """ model = oprs.parse_model(input) model = ModelWrapper(model) + + # Now getCustomOp should find it through the manual registry inst = getCustomOp(model.graph.node[0]) w_prod = inst.get_nodeattr("tensor_attr") From 68346e3879936b0003d98c224bd56d4bebe267a5 Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Fri, 18 Jul 2025 04:54:40 +0000 Subject: [PATCH 12/29] Circular import fix --- src/qonnx/util/basic.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 92ba5d2e..dae5fbd4 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -33,7 +33,6 @@ import warnings from qonnx.core.datatype import DataType -from qonnx.custom_op.registry import get_ops_in_domain # TODO solve by moving onnx-dependent fxns to onnx.py # finn-examples uses parts of qonnx without having @@ -75,6 +74,8 @@ def is_finn_op(domain): return False # Validate that the domain actually exists and has CustomOps + # Lazy import to avoid circular dependency + from qonnx.custom_op.registry import get_ops_in_domain return len(get_ops_in_domain(domain)) > 0 From 93fd8d004e398371dcdf212b9ccb41dbd048df93 Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Fri, 18 Jul 2025 17:54:15 +0000 Subject: [PATCH 13/29] Added brainsmith to hide finn ops --- src/qonnx/util/basic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index dae5fbd4..72fe18c2 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -63,14 +63,14 @@ def qonnx_make_model(graph_proto, **kwargs): def is_finn_op(domain): - """Return whether given domain string is a QONNX or FINN custom op domain. + """Return whether given domain string is a QONNX, FINN, or Brainsmith custom op domain. Validates that: - 1. The domain starts with known custom op prefixes (qonnx., finn., onnx.brevitas) + 1. The domain starts with known custom op prefixes (qonnx., finn., onnx.brevitas, brainsmith.) 2. 
The domain exists and contains at least one CustomOp """ # Check if domain has known custom op prefix - if not domain.startswith(("qonnx.", "finn.", "onnx.brevitas")): + if not domain.startswith(("qonnx.", "finn.", "onnx.brevitas", "brainsmith.")): return False # Validate that the domain actually exists and has CustomOps From f2c4ccd3e71795c9f116ee5a0c87a7dfd590c6d0 Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Sat, 18 Oct 2025 22:55:02 -0700 Subject: [PATCH 14/29] refactor: migrate registry to thread-safe, cache-based architecture Replace namespace-based custom op registration with centralized registry using (domain, op_type) keys. Add thread-safe operations with RLock, lazy discovery, and caching. Deprecate is_finn_op and hasCustomOp in favor of is_custom_op. Simplify add_op_to_domain signature to derive op_type from class name. Update all call sites across codebase. --- docs/overview.rst | 7 +- notebooks/3_custom_op.ipynb | 11 +- src/qonnx/core/modelwrapper.py | 5 +- src/qonnx/core/onnx_exec.py | 4 +- src/qonnx/custom_op/general/__init__.py | 29 +- src/qonnx/custom_op/registry.py | 318 +++++++++++------- .../transformation/infer_data_layouts.py | 7 +- src/qonnx/transformation/infer_datatypes.py | 5 +- src/qonnx/transformation/infer_shapes.py | 6 +- src/qonnx/util/basic.py | 29 +- tests/custom_op/test_attr.py | 2 +- tests/transformation/test_channelslast.py | 4 +- 12 files changed, 249 insertions(+), 178 deletions(-) diff --git a/docs/overview.rst b/docs/overview.rst index 2f0f3577..161d1e49 100644 --- a/docs/overview.rst +++ b/docs/overview.rst @@ -50,13 +50,14 @@ Simply define your CustomOp subclass in the appropriate domain module (e.g., ``qonnx.custom_op.general`` for general ops) and it will be automatically available through ``getCustomOp``. -For dynamic registration (e.g., in tests), use the registry functions: +For dynamic registration and querying, use the registry functions: * ``getCustomOp(node)`` - Get a custom op instance from an ONNX node -* ``add_op_to_domain(domain, op_type, op_class)`` - Add an op to a domain's namespace +* ``is_custom_op(domain, op_type=None)`` - Check if a specific op or domain has custom ops +* ``add_op_to_domain(domain, op_class)`` - Register an op at runtime (for testing) +* ``get_ops_in_domain(domain)`` - List all ops available in a domain * ``add_domain_alias(domain, module_path)`` - Map a domain to a different module path * ``hasCustomOp(domain, op_type)`` - Check if an op exists in a domain -* ``get_ops_in_domain(domain)`` - List all ops available in a domain Custom ONNX Execution Flow diff --git a/notebooks/3_custom_op.ipynb b/notebooks/3_custom_op.ipynb index cd01686c..1b822163 100644 --- a/notebooks/3_custom_op.ipynb +++ b/notebooks/3_custom_op.ipynb @@ -129,13 +129,14 @@ { "cell_type": "markdown", "metadata": {}, - "source": "To make sure our custom op is available, we need to add it to the domain's namespace. For production code, you would place your CustomOp class directly in the appropriate module (e.g., in a file under `qonnx/custom_op/general/`). For testing and experimentation like in this notebook, we can use the `add_op_to_domain` function:" + "source": "To make sure our custom op is available, we need to add it to the domain. For production code, you would place your CustomOp class directly in the appropriate module file (e.g., in a file under `qonnx/custom_op/general/`). 
For testing and experimentation like in this notebook, we can use the `add_op_to_domain` function:" }, { "cell_type": "code", "metadata": {}, "outputs": [], - "source": "from qonnx.custom_op.registry import add_op_to_domain\n\n# Add our custom op to the general domain namespace\nadd_op_to_domain(\"qonnx.custom_op.general\", \"MyPythonPowerOp\", MyPythonPowerOp)" + "source": "from qonnx.custom_op.registry import add_op_to_domain\n\n# Add our custom op to the general domain\nadd_op_to_domain(\"qonnx.custom_op.general\", MyPythonPowerOp)", + "execution_count": null }, { "cell_type": "markdown", @@ -146,7 +147,8 @@ "cell_type": "code", "metadata": {}, "outputs": [], - "source": "from qonnx.custom_op.registry import get_ops_in_domain\n\n# See all ops in the general domain\nops = get_ops_in_domain(\"qonnx.custom_op.general\")\nprint(f\"Available ops: {[op[0] for op in ops]}\")\n\n# Check if our op is there\nfrom qonnx.custom_op.registry import hasCustomOp\nprint(f\"MyPythonPowerOp available: {hasCustomOp('qonnx.custom_op.general', 'MyPythonPowerOp')}\")" + "source": "from qonnx.custom_op.registry import get_ops_in_domain, is_custom_op\n\n# See all ops in the general domain\nops = get_ops_in_domain(\"qonnx.custom_op.general\")\nprint(f\"Available ops: {[op[0] for op in ops]}\")\n\n# Check if our op is there\nprint(f\"MyPythonPowerOp available: {is_custom_op('qonnx.custom_op.general', 'MyPythonPowerOp')}\")", + "execution_count": null }, { "cell_type": "markdown", @@ -453,7 +455,8 @@ "cell_type": "code", "metadata": {}, "outputs": [], - "source": "# register our new op\nadd_op_to_domain(\"qonnx.custom_op.general\", \"MyMixedPowerOp\", MyMixedPowerOp)\n\n# make graph with new op\nmixedop_graph = make_graph(input_shape, 2, op_type = \"MyMixedPowerOp\")\nmixedop_graph.graph.node" + "source": "# register our new op\nadd_op_to_domain(\"qonnx.custom_op.general\", MyMixedPowerOp)\n\n# make graph with new op\nmixedop_graph = make_graph(input_shape, 2, op_type = \"MyMixedPowerOp\")\nmixedop_graph.graph.node", + "execution_count": null }, { "cell_type": "markdown", diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index 8890e24a..de289ef3 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -38,6 +38,7 @@ import qonnx.util.basic as util import qonnx.util.onnx as onnxutil from qonnx.core.datatype import DataType +from qonnx.custom_op.registry import is_custom_op from qonnx.transformation.double_to_single_float import DoubleToSingleFloat from qonnx.transformation.general import ( RemoveStaticGraphInputs, @@ -624,11 +625,11 @@ def get_nodes_by_op_type(self, op_type): def get_finn_nodes(self): """Returns a list of nodes where domain == 'qonnx.*'.""" - return list(filter(lambda x: util.is_finn_op(x.domain), self.graph.node)) + return list(filter(lambda x: is_custom_op(x.domain), self.graph.node)) def get_non_finn_nodes(self): """Returns a list of nodes where domain != 'qonnx.*'.""" - return list(filter(lambda x: not util.is_finn_op(x.domain), self.graph.node)) + return list(filter(lambda x: not is_custom_op(x.domain), self.graph.node)) def get_node_index(self, node): """Returns current index of given node, or None if not found.""" diff --git a/src/qonnx/core/onnx_exec.py b/src/qonnx/core/onnx_exec.py index a8f4774c..3a686f7e 100644 --- a/src/qonnx/core/onnx_exec.py +++ b/src/qonnx/core/onnx_exec.py @@ -35,10 +35,10 @@ import qonnx.analysis.topology as ta import qonnx.core.execute_custom_node as ex_cu_node +from qonnx.custom_op.registry import is_custom_op 
from qonnx.util.basic import ( get_preferred_onnx_opset, get_sanitize_quant_tensors, - is_finn_op, qonnx_make_model, sanitize_quant_values, ) @@ -49,7 +49,7 @@ def execute_node(node, context, graph, return_full_exec_context=False, opset_ver Input/output provided via context.""" - if is_finn_op(node.domain): + if is_custom_op(node.domain, node.op_type): ex_cu_node.execute_custom_node(node, context, graph, onnx_opset_version=opset_version) else: # onnxruntime unfortunately does not implement run_node as defined by ONNX, diff --git a/src/qonnx/custom_op/general/__init__.py b/src/qonnx/custom_op/general/__init__.py index c2bb7a82..e859d860 100644 --- a/src/qonnx/custom_op/general/__init__.py +++ b/src/qonnx/custom_op/general/__init__.py @@ -39,17 +39,18 @@ from qonnx.custom_op.general.trunc import Trunc from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul -__all__ = [ - "DebugMarker", - "QuantAvgPool2d", - "MaxPoolNHWC", - "GenericPartition", - "MultiThreshold", - "XnorPopcountMatMul", - "Im2Col", - "IntQuant", - "Quant", - "Trunc", - "BipolarQuant", - "FloatQuant", -] +# Legacy dictionary for backward compatibility +custom_op = { + "DebugMarker": DebugMarker, + "QuantAvgPool2d": QuantAvgPool2d, + "MaxPoolNHWC": MaxPoolNHWC, + "GenericPartition": GenericPartition, + "MultiThreshold": MultiThreshold, + "XnorPopcountMatMul": XnorPopcountMatMul, + "Im2Col": Im2Col, + "IntQuant": IntQuant, + "Quant": IntQuant, # Alias + "Trunc": Trunc, + "BipolarQuant": BipolarQuant, + "FloatQuant": FloatQuant, +} diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index c7d964f6..8eb1b378 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -28,171 +28,233 @@ import importlib import inspect -from typing import Dict +from threading import RLock +from typing import Dict, List, Optional, Tuple, Type from qonnx.custom_op.base import CustomOp from qonnx.util.basic import get_preferred_onnx_opset -# Domain to module path mapping (only when different) -DOMAIN_MODULES: Dict[str, str] = { - "onnx.brevitas": "qonnx.custom_op.general", # Built-in compatibility +# Registry keyed by original ONNX domain: (domain, op_type) -> CustomOp class +_OP_REGISTRY: Dict[Tuple[str, str], Type[CustomOp]] = {} + +_REGISTRY_LOCK = RLock() + +# Maps ONNX domain names to Python module paths (used for imports only) +_DOMAIN_ALIASES: Dict[str, str] = { + "onnx.brevitas": "qonnx.custom_op.general", } def add_domain_alias(domain: str, module_path: str) -> None: """Map a domain name to a different module path. - + + Args: + domain: The ONNX domain name (e.g., "finn.custom_op.fpgadataflow") + module_path: The Python module path to use instead (e.g., "finn_custom_ops.fpgadataflow") + """ + with _REGISTRY_LOCK: + _DOMAIN_ALIASES[domain] = module_path + + +def resolve_domain(domain: str) -> str: + """Resolve a domain to its actual module path, handling aliases. + Args: domain: The ONNX domain name - module_path: The Python module path to use instead - - Example: - add_domain_alias("finn.custom_op.fpgadataflow", "finn_custom_ops.fpgadataflow") + + Returns: + Resolved module path """ - DOMAIN_MODULES[domain] = module_path + return _DOMAIN_ALIASES.get(domain, domain) + + +def add_op_to_domain(domain: str, op_class: Type[CustomOp]) -> None: + """Register a custom op directly to a domain at runtime. + The op_type is automatically derived from the class name. + Useful for testing and experimentation. For production, define CustomOps + in the appropriate module file. 
-def add_op_to_domain(domain: str, op_type: str, op_class: type) -> None: - """Add a custom op directly to a domain's module namespace. - - This function dynamically adds custom ops to module namespaces at runtime. - Useful for test cases or dynamic op registration. - Args: - domain: The ONNX domain name (e.g., "qonnx.custom_op.general") - op_type: The operation type name (e.g., "MyCustomOp") - op_class: The CustomOp subclass to add - + domain: ONNX domain name (e.g., "qonnx.custom_op.general") + op_class: CustomOp subclass + Example: - add_op_to_domain("qonnx.custom_op.general", "TestOp", TestOp) + add_op_to_domain("qonnx.custom_op.general", MyTestOp) """ if not inspect.isclass(op_class) or not issubclass(op_class, CustomOp): raise ValueError(f"{op_class} must be a subclass of CustomOp") - - # Get the actual module path - module_path = DOMAIN_MODULES.get(domain, domain) - + + op_type = op_class.__name__ + + with _REGISTRY_LOCK: + _OP_REGISTRY[(domain, op_type)] = op_class + + +def _discover_custom_op(domain: str, op_type: str) -> bool: + """Discover and register a single custom op. + + Args: + domain: The ONNX domain name + op_type: The specific op type to discover + + Returns: + True if op was found and registered, False otherwise + """ + module_path = resolve_domain(domain) + try: - # Import the module and add the op to its namespace module = importlib.import_module(module_path) - setattr(module, op_type, op_class) except ModuleNotFoundError: - raise ValueError(f"Could not find module for domain '{domain}' (tried: {module_path})") + return False + + # Try namespace lookup + op_class = getattr(module, op_type, None) + if inspect.isclass(op_class) and issubclass(op_class, CustomOp): + _OP_REGISTRY[(domain, op_type)] = op_class + return True + + # Try legacy dict + custom_op_dict = getattr(module, 'custom_op', None) + if isinstance(custom_op_dict, dict): + op_class = custom_op_dict.get(op_type) + if inspect.isclass(op_class) and issubclass(op_class, CustomOp): + _OP_REGISTRY[(domain, op_type)] = op_class + return True + + return False def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset()): """Get a custom op instance for an ONNX node. - - Lookup order: - 1. Direct attribute lookup in module namespace - 2. Legacy custom_op dictionary (backward compatibility) - 3. 
Search all CustomOp subclasses (fallback) + + Args: + node: ONNX node with domain and op_type attributes + onnx_opset_version: ONNX opset version to use + + Returns: + CustomOp instance for the node + + Raises: + KeyError: If op_type not found in domain """ op_type = node.op_type domain = node.domain - - # Get module path (handles brevitas via DOMAIN_MODULES mapping) - module_path = DOMAIN_MODULES.get(domain, domain) - - try: - # Import the domain module - module = importlib.import_module(module_path) - - # Strategy 1: Direct namespace lookup (preferred) - if hasattr(module, op_type): - obj = getattr(module, op_type) - if inspect.isclass(obj) and issubclass(obj, CustomOp): - return obj(node, onnx_opset_version=onnx_opset_version) - - # Strategy 2: Legacy custom_op dict (backward compatibility) - if hasattr(module, 'custom_op') and isinstance(module.custom_op, dict): - if op_type in module.custom_op: - cls = module.custom_op[op_type] - return cls(node, onnx_opset_version=onnx_opset_version) - - # Strategy 3: Search module for CustomOp subclasses (fallback) - # Useful for debugging and error messages - custom_ops = {} - for name, obj in inspect.getmembers(module): - if (inspect.isclass(obj) and - issubclass(obj, CustomOp) and - obj is not CustomOp and - not name.startswith('_')): # Skip private classes - custom_ops[name] = obj - - # Try case-insensitive match as last resort - for name, cls in custom_ops.items(): - if name.lower() == op_type.lower(): - return cls(node, onnx_opset_version=onnx_opset_version) - - # Not found - provide helpful error - available = list(custom_ops.keys()) + key = (domain, op_type) + + with _REGISTRY_LOCK: + if key in _OP_REGISTRY: + return _OP_REGISTRY[key](node, onnx_opset_version=onnx_opset_version) + + if _discover_custom_op(domain, op_type): + return _OP_REGISTRY[key](node, onnx_opset_version=onnx_opset_version) + + module_path = resolve_domain(domain) raise KeyError( f"Op '{op_type}' not found in domain '{domain}' (module: {module_path}). " - f"Available ops: {available}" - ) - - except ModuleNotFoundError: - raise Exception( - f"Could not load module '{module_path}' for domain '{domain}'. " - f"Ensure the module is installed and on your PYTHONPATH." + f"Ensure it's exported in the module namespace or in the custom_op dict." ) -# Legacy functions for backward compatibility -def hasCustomOp(domain, op_type): - """Check if a custom op exists in the domain's module namespace.""" - try: - # Create a dummy node to test - class DummyNode: - pass - node = DummyNode() - node.op_type = op_type - node.domain = domain - - # Try to get the op class - module_path = DOMAIN_MODULES.get(domain, domain) - module = importlib.import_module(module_path) - - # Check namespace first - if hasattr(module, op_type): - obj = getattr(module, op_type) - if inspect.isclass(obj) and issubclass(obj, CustomOp): - return True - - # Check legacy dict - if hasattr(module, 'custom_op') and isinstance(module.custom_op, dict): - return op_type in module.custom_op - - return False - except: +def is_custom_op(domain: str, op_type: Optional[str] = None) -> bool: + """Check if a custom op exists or if a domain has any custom ops. + + Args: + domain: The ONNX domain name + op_type: Optional operation type name. If None, checks if domain has any ops. 
+ + Returns: + True if the specific op exists (when op_type given) or + if any ops exist for the domain (when op_type=None), False otherwise + """ + # Empty domain means standard ONNX op + if not domain: return False + with _REGISTRY_LOCK: + if op_type is not None: + # Check for specific op + key = (domain, op_type) + if key in _OP_REGISTRY: + return True + return _discover_custom_op(domain, op_type) + else: + # Check if domain has any registered ops + if any(d == domain for d, _ in _OP_REGISTRY.keys()): + return True + # Try to import the domain module as fallback + module_path = resolve_domain(domain) + try: + importlib.import_module(module_path) + return True + except (ModuleNotFoundError, ValueError): + return False + + +def hasCustomOp(domain: str, op_type: str) -> bool: + """Deprecated: Use is_custom_op instead. + + Check if a custom op exists. -def get_ops_in_domain(domain): - """Get all ops in a domain by inspecting the module namespace.""" + Args: + domain: The ONNX domain name + op_type: The operation type name + + Returns: + True if the op exists, False otherwise + """ + import warnings + warnings.warn( + "hasCustomOp is deprecated and will be removed in QONNX v1.0. " + "Use is_custom_op instead.", + DeprecationWarning, + stacklevel=2 + ) + return is_custom_op(domain, op_type) + + +def get_ops_in_domain(domain: str) -> List[Tuple[str, Type[CustomOp]]]: + """Get all CustomOp classes available in a domain. + + Args: + domain: ONNX domain name (e.g., "qonnx.custom_op.general") + + Returns: + List of (op_type, op_class) tuples + + Example: + ops = get_ops_in_domain("qonnx.custom_op.general") + for op_name, op_class in ops: + print(f"{op_name}: {op_class}") + """ ops = [] - - try: - module_path = DOMAIN_MODULES.get(domain, domain) - module = importlib.import_module(module_path) - - # Check module namespace - for name, obj in inspect.getmembers(module): - if (inspect.isclass(obj) and - issubclass(obj, CustomOp) and - obj is not CustomOp and - not name.startswith('_')): - ops.append((name, obj)) - - # Also check legacy dict if present - if hasattr(module, 'custom_op') and isinstance(module.custom_op, dict): - for name, cls in module.custom_op.items(): - if not any(op[0] == name for op in ops): - ops.append((name, cls)) - - return ops - except: - return [] + module_path = resolve_domain(domain) + + with _REGISTRY_LOCK: + # Strategy 1: Get cached ops (fast path) + for (d, op_type), op_class in _OP_REGISTRY.items(): + if d == domain: + ops.append((op_type, op_class)) + + # Strategy 2: Discover from module (for uncached ops) + try: + module = importlib.import_module(module_path) + + # Check namespace exports + for name, obj in inspect.getmembers(module): + if (inspect.isclass(obj) and + issubclass(obj, CustomOp) and + obj is not CustomOp and + not name.startswith('_') and + not any(op[0] == name for op in ops)): + ops.append((name, obj)) + # Check legacy custom_op dict + if hasattr(module, 'custom_op') and isinstance(module.custom_op, dict): + for name, cls in module.custom_op.items(): + if not any(op[0] == name for op in ops): + ops.append((name, cls)) + except ModuleNotFoundError: + pass # Domain doesn't exist as module, return cached ops only + return ops diff --git a/src/qonnx/transformation/infer_data_layouts.py b/src/qonnx/transformation/infer_data_layouts.py index 81143e45..2e23d771 100644 --- a/src/qonnx/transformation/infer_data_layouts.py +++ b/src/qonnx/transformation/infer_data_layouts.py @@ -30,15 +30,16 @@ import qonnx.core.data_layout as DataLayout import 
qonnx.custom_op.registry as registry +from qonnx.custom_op.registry import is_custom_op from qonnx.transformation.base import Transformation -from qonnx.util.basic import get_by_name, is_finn_op +from qonnx.util.basic import get_by_name def _dims_to_layout(model, node, ndims): if ndims == 2: return DataLayout.NC else: - if is_finn_op(node.domain): + if is_custom_op(node.domain): if node.op_type == "MultiThreshold" or node.op_type == "QuantAvgPool2d": mt_inst = registry.getCustomOp(node) layout = mt_inst.get_nodeattr("data_layout") @@ -72,7 +73,7 @@ def _infer_node_data_layout(model, node): Returns True if any changes were made.""" old_layouts = list(map(lambda x: model.get_tensor_layout(x), node.output)) try: - if is_finn_op(node.domain): + if is_custom_op(node.domain): # try to guess based on number of output dims for o in node.output: ndims = len(model.get_tensor_shape(o)) diff --git a/src/qonnx/transformation/infer_datatypes.py b/src/qonnx/transformation/infer_datatypes.py index d54fd34f..167e0c3e 100644 --- a/src/qonnx/transformation/infer_datatypes.py +++ b/src/qonnx/transformation/infer_datatypes.py @@ -28,9 +28,10 @@ import qonnx.custom_op.registry as registry from qonnx.core.datatype import DataType, ScaledIntType +from qonnx.custom_op.registry import is_custom_op from qonnx.transformation.base import Transformation from qonnx.transformation.qcdq_to_qonnx import extract_elem_type -from qonnx.util.basic import get_by_name, is_finn_op +from qonnx.util.basic import get_by_name def is_scaled_int(x): @@ -82,7 +83,7 @@ def _infer_node_datatype(model, node, allow_scaledint_dtypes): idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input)) odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output)) op_type = node.op_type - if is_finn_op(node.domain): + if is_custom_op(node.domain): # handle DataType inference for CustomOp try: # lookup op_type in registry of CustomOps diff --git a/src/qonnx/transformation/infer_shapes.py b/src/qonnx/transformation/infer_shapes.py index 87fbf0ee..3e532abf 100644 --- a/src/qonnx/transformation/infer_shapes.py +++ b/src/qonnx/transformation/infer_shapes.py @@ -30,14 +30,14 @@ import qonnx.custom_op.registry as registry from qonnx.core.modelwrapper import ModelWrapper +from qonnx.custom_op.registry import is_custom_op from qonnx.transformation.base import Transformation -from qonnx.util.basic import is_finn_op def _make_shape_compatible_op(node, model): """Return a shape-compatible non-QONNX op for a given QONNX op. Used for shape inference with custom ops.""" - assert is_finn_op(node.domain), "Node domain is not set to qonnx.*" + assert is_custom_op(node.domain), "Node domain is not a registered custom op domain" op_type = node.op_type try: # lookup op_type in registry of CustomOps @@ -56,7 +56,7 @@ def _hide_finn_ops(model): node_ind = 0 for node in model.graph.node: node_ind += 1 - if is_finn_op(node.domain): + if is_custom_op(node.domain): new_node = _make_shape_compatible_op(node, model) # keep old node name to help debug shape inference issues new_node.name = node.name diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 72fe18c2..3253d873 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -62,21 +62,22 @@ def qonnx_make_model(graph_proto, **kwargs): return make_model(graph_proto, **kwargs) -def is_finn_op(domain): - """Return whether given domain string is a QONNX, FINN, or Brainsmith custom op domain. - - Validates that: - 1. 
The domain starts with known custom op prefixes (qonnx., finn., onnx.brevitas, brainsmith.)
-    2. The domain exists and contains at least one CustomOp
+def is_finn_op(domain):
+    """Deprecated: Use is_custom_op from qonnx.custom_op.registry instead.
+
+    Return whether the given domain string is a QONNX or FINN custom op domain.
+    This thin compatibility wrapper delegates to the registry-based is_custom_op
+    and will be removed in QONNX v1.0.
     """
-    # Check if domain has known custom op prefix
-    if not domain.startswith(("qonnx.", "finn.", "onnx.brevitas", "brainsmith.")):
-        return False
-
-    # Validate that the domain actually exists and has CustomOps
-    # Lazy import to avoid circular dependency
-    from qonnx.custom_op.registry import get_ops_in_domain
-    return len(get_ops_in_domain(domain)) > 0
+    import warnings
+    warnings.warn(
+        "is_finn_op is deprecated and will be removed in QONNX v1.0. "
+        "Use 'from qonnx.custom_op.registry import is_custom_op' instead.",
+        DeprecationWarning,
+        stacklevel=2
+    )
+    from qonnx.custom_op.registry import is_custom_op
+    return is_custom_op(domain)
 
 
 def get_num_default_workers():
diff --git a/tests/custom_op/test_attr.py b/tests/custom_op/test_attr.py
index 906e154a..ac4f7a5c 100644
--- a/tests/custom_op/test_attr.py
+++ b/tests/custom_op/test_attr.py
@@ -60,7 +60,7 @@ def verify_node(self):
 
 def test_attr():
     # Add the test op to the domain
-    add_op_to_domain("qonnx.custom_op.general", "AttrTestOp", AttrTestOp)
+    add_op_to_domain("qonnx.custom_op.general", AttrTestOp)
 
     ishp = (1, 10)
     wshp = (1, 3)
diff --git a/tests/transformation/test_channelslast.py b/tests/transformation/test_channelslast.py
index 24e64b4f..30382c64 100644
--- a/tests/transformation/test_channelslast.py
+++ b/tests/transformation/test_channelslast.py
@@ -43,11 +43,11 @@
     MoveTransposePastFork,
     RemoveConsecutiveChanFirstAndChanLastTrafos,
 )
+from qonnx.custom_op.registry import is_custom_op
 from qonnx.transformation.general import GiveUniqueNodeNames
 from qonnx.transformation.infer_shapes import InferShapes
 from qonnx.transformation.make_input_chanlast import MakeInputChannelsLast
 from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit
-from qonnx.util.basic import is_finn_op
 from qonnx.util.test import download_model, get_golden_in_and_output, test_model_details
 from qonnx.util.to_channels_last import to_channels_last
@@ -126,7 +126,7 @@ def analysis_test_for_left_transposes(model, test_model, make_input_channels_las
 def verify_all_nodes(model):
     result = dict()
     for n in model.graph.node:
-        if is_finn_op(n.domain):
+        if is_custom_op(n.domain):
             n_instance = getCustomOp(n)
             verify_result = n_instance.verify_node()
             result[n.name] = verify_result

From ba24ecccf410680dea0ec09d872502a530b613a4 Mon Sep 17 00:00:00 2001
From: Joshua Monson
Date: Thu, 30 Oct 2025 20:56:41 +0000
Subject: [PATCH 15/29] copy in metadata preservation

---
 .../transformation/extract_quant_scale_zeropt.py | 8 ++++++++
 src/qonnx/transformation/gemm_to_matmul.py       | 13 ++++++++++++-
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/src/qonnx/transformation/extract_quant_scale_zeropt.py b/src/qonnx/transformation/extract_quant_scale_zeropt.py
index 58863f08..614df416 100644
--- a/src/qonnx/transformation/extract_quant_scale_zeropt.py
+++ b/src/qonnx/transformation/extract_quant_scale_zeropt.py
@@ -69,6 +69,8 @@ def apply(self, model: ModelWrapper):
             )
             graph.value_info.append(inp_scaled)
             inp_scale_node = helper.make_node("Div", [running_input, scale_nm],
[inp_scaled_nm]) + if hasattr(node, "metadata_props"): + inp_scale_node.metadata_props.extend(node.metadata_props) graph.node.append(inp_scale_node) # create new Mul node # remove scale from Quant node @@ -87,6 +89,8 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(inp_zeropt) inp_zeropt_node = helper.make_node("Add", [running_input, zeropt_nm], [inp_zeropt_nm]) + if hasattr(node, "metadata_props"): + inp_zeropt_node.metadata_props.extend(node.metadata_props) graph.node.append(inp_zeropt_node) # remove zeropt from Quant node new_zeropt_nm = model.make_new_valueinfo_name() @@ -108,6 +112,8 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(out_zeropt) out_zeropt_node = helper.make_node("Sub", [out_zeropt_nm, zeropt_nm], [final_output]) + if hasattr(node, "metadata_props"): + out_zeropt_node.metadata_props.extend(node.metadata_props) last_node.output[0] = out_zeropt_nm graph.node.append(out_zeropt_node) # important: when tracking a pointer to newly added nodes, @@ -127,6 +133,8 @@ def apply(self, model: ModelWrapper): last_node.output[0] = out_scale_nm graph.value_info.append(out_scale) out_scale_node = helper.make_node("Mul", [out_scale_nm, scale_nm], [final_output]) + if hasattr(node, "metadata_props"): + out_scale_node.metadata_props.extend(node.metadata_props) graph.node.append(out_scale_node) if extract_scale or extract_zeropt: diff --git a/src/qonnx/transformation/gemm_to_matmul.py b/src/qonnx/transformation/gemm_to_matmul.py index 5396a7d6..1298f3d6 100644 --- a/src/qonnx/transformation/gemm_to_matmul.py +++ b/src/qonnx/transformation/gemm_to_matmul.py @@ -76,6 +76,8 @@ def apply(self, model): ) graph.value_info.append(inp_trans_out) inp_trans_node = helper.make_node("Transpose", [n.input[0]], [inp_trans_out.name]) + if hasattr(n, "metadata_props"): + inp_trans_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, inp_trans_node) running_node_index += 1 dt = model.get_tensor_datatype(n.input[0]) @@ -98,6 +100,8 @@ def apply(self, model): ) graph.value_info.append(inp_trans_out) inp_trans_node = helper.make_node("Transpose", [n.input[1]], [inp_trans_out.name]) + if hasattr(n, "metadata_props"): + inp_trans_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, inp_trans_node) running_node_index += 1 # Copy over the datatype @@ -109,6 +113,8 @@ def apply(self, model): # Insert MatMul: A * B matMul_node = helper.make_node("MatMul", [n.input[0], n.input[1]], [n.output[0]]) + if hasattr(n, "metadata_props"): + matMul_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, matMul_node) matMul_node = graph.node[running_node_index] running_node_index += 1 @@ -144,6 +150,8 @@ def apply(self, model): [act_mul_tensor.name, mul_tensor.name], [n.output[0]], ) + if hasattr(n, "metadata_props"): + mul_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, mul_node) mul_node_main_branch = graph.node[running_node_index] running_node_index += 1 @@ -175,6 +183,8 @@ def apply(self, model): [n.input[2], mul_tensor.name], [act_mul_tensor.name], ) + if hasattr(n, "metadata_props"): + mul_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, mul_node) running_node_index += 1 dt = model.get_tensor_datatype(n.input[2]) @@ -196,7 +206,8 @@ def apply(self, model): [act_add_tensor.name, n.input[2]], [n.output[0]], ) - + if hasattr(n, "metadata_props"): + add_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, 
add_node) running_node_index += 1 From da632b9f9b94817a6e9dba51f2ac665818a55f84 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Thu, 30 Oct 2025 21:59:38 +0000 Subject: [PATCH 16/29] expand metadata copy coverage to other transforms --- src/qonnx/transformation/change_datalayout.py | 6 ++++++ src/qonnx/transformation/channels_last.py | 6 ++++++ src/qonnx/transformation/extract_conv_bias.py | 2 ++ src/qonnx/transformation/lower_convs_to_matmul.py | 4 ++++ src/qonnx/transformation/qcdq_to_qonnx.py | 4 ++++ src/qonnx/transformation/rebalance_conv.py | 2 ++ src/qonnx/transformation/resize_conv_to_deconv.py | 3 +++ src/qonnx/transformation/subpixel_to_deconv.py | 3 +++ 8 files changed, 30 insertions(+) diff --git a/src/qonnx/transformation/change_datalayout.py b/src/qonnx/transformation/change_datalayout.py index 7b73e4bf..07fbe400 100644 --- a/src/qonnx/transformation/change_datalayout.py +++ b/src/qonnx/transformation/change_datalayout.py @@ -78,6 +78,8 @@ def apply(self, model): graph.value_info.append(quantavg_out) quantavg_out = quantavg_out.name inp_trans_node = helper.make_node("Transpose", [node_input], [inp_trans_out], perm=[0, 2, 3, 1]) + if hasattr(n, "metadata_props"): + inp_trans_node.metadata_props.extend(n.metadata_props) quantavg_node = helper.make_node( "QuantAvgPool2d", [inp_trans_out], @@ -90,8 +92,12 @@ def apply(self, model): signed=signed, data_layout="NHWC", ) + if hasattr(n, "metadata_props"): + quantavg_node.metadata_props.extend(n.metadata_props) # NHWC -> NCHW out_trans_node = helper.make_node("Transpose", [quantavg_out], [node_output], perm=[0, 3, 1, 2]) + if hasattr(n, "metadata_props"): + out_trans_node.metadata_props.extend(n.metadata_props) # insert nodes graph.node.insert(node_ind, inp_trans_node) graph.node.insert(node_ind + 1, quantavg_node) diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index 175af058..8c934190 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -96,6 +96,8 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe new_t_inp = model.make_new_valueinfo_name() inv_perm = np.argsort(perm) new_transpose_node = helper.make_node("Transpose", [eltwise_inp], [new_t_inp], perm=inv_perm) + if hasattr(transpose_node, "metadata_props"): + new_transpose_node.metadata_props.extend(transpose_node.metadata_props) t_shape = np.transpose(np.empty(inp_shape), axes=inv_perm).shape model.set_tensor_shape(new_t_inp, t_shape) eltwise_node.input[ind] = new_t_inp @@ -107,6 +109,8 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe model.set_initializer(unsqueeze_param_name, np.asarray(list(range(ndim_inp - ndim)), dtype=np.int64)) unsqueeze_out_name = model.make_new_valueinfo_name() new_unsqueeze_node = helper.make_node("Unsqueeze", [eltwise_inp, unsqueeze_param_name], [unsqueeze_out_name]) + if hasattr(eltwise_inp, "metadata_props"): + new_unsqueeze_node.metadata_props.extend(eltwise_inp.metadata_props) unsqueeze_out_shape = np.expand_dims(np.empty(inp_shape), axis=tuple(range(ndim_inp - ndim))).shape model.set_tensor_shape(unsqueeze_out_name, unsqueeze_out_shape) model.graph.node.append(new_unsqueeze_node) @@ -114,6 +118,8 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe new_t_inp = model.make_new_valueinfo_name() inv_perm = np.argsort(perm) new_transpose_node = helper.make_node("Transpose", [unsqueeze_out_name], [new_t_inp], perm=inv_perm) + if 
hasattr(transpose_node, "metadata_props"): + new_transpose_node.metadata_props.extend(transpose_node.metadata_props) t_shape = np.transpose(np.empty(unsqueeze_out_shape), axes=inv_perm).shape model.set_tensor_shape(new_t_inp, t_shape) eltwise_node.input[ind] = new_t_inp diff --git a/src/qonnx/transformation/extract_conv_bias.py b/src/qonnx/transformation/extract_conv_bias.py index bf2cf8b4..1bf264e3 100644 --- a/src/qonnx/transformation/extract_conv_bias.py +++ b/src/qonnx/transformation/extract_conv_bias.py @@ -75,6 +75,8 @@ def apply(self, model): [act_add_tensor.name, n.input[2]], [n.output[0]], ) + if hasattr(n, "metadata_props"): + add_node.metadata_props.extend(n.metadata_props) graph.node.insert(node_ind, add_node) # Repoint Conv output and remove bias tensor diff --git a/src/qonnx/transformation/lower_convs_to_matmul.py b/src/qonnx/transformation/lower_convs_to_matmul.py index 81f0b713..d864d8d2 100644 --- a/src/qonnx/transformation/lower_convs_to_matmul.py +++ b/src/qonnx/transformation/lower_convs_to_matmul.py @@ -178,8 +178,12 @@ def apply(self, model): matmul_input = im2col_out if need_im2col else inp_trans_out # do matmul matmul_node = helper.make_node("MatMul", [matmul_input, conv_weight_inp_name], [matmul_out]) + if hasattr(node, "metadata_props"): + matmul_node.metadata_props.extend(node.metadata_props) # NHWC -> NCHW out_trans_node = helper.make_node("Transpose", [matmul_out], [cnv_output], perm=[0, 3, 1, 2]) + if hasattr(node, "metadata_props"): + out_trans_node.metadata_props.extend(node.metadata_props) nodes_to_insert.extend([matmul_node, out_trans_node]) diff --git a/src/qonnx/transformation/qcdq_to_qonnx.py b/src/qonnx/transformation/qcdq_to_qonnx.py index b7e35c0d..b122d840 100644 --- a/src/qonnx/transformation/qcdq_to_qonnx.py +++ b/src/qonnx/transformation/qcdq_to_qonnx.py @@ -203,6 +203,10 @@ def apply(self, model: ModelWrapper) -> Tuple[ModelWrapper, bool]: rounding_mode="ROUND", # round-to-even signed=signed, ) + # Pass on metadata from DequantizeLinear node since it's the only node that + # must be present to be able to perform this transformation. 
+ if hasattr(node, "metadata_props"): + fused_node.metadata_props.extend(node.metadata_props) model.graph.node.insert(dequant_node_index, fused_node) for node_to_remove in nodes_to_remove: model.graph.node.remove(node_to_remove) diff --git a/src/qonnx/transformation/rebalance_conv.py b/src/qonnx/transformation/rebalance_conv.py index ecb2b5e4..098bff20 100644 --- a/src/qonnx/transformation/rebalance_conv.py +++ b/src/qonnx/transformation/rebalance_conv.py @@ -103,6 +103,8 @@ def apply(self, model): inp_reshape_node = helper.make_node( "Reshape", [node.input[0], inp_shapedata.name], [inp_reshape_out.name] ) + if hasattr(node, "metadata_props"): + inp_reshape_node.metadata_props.extend(node.metadata_props) graph.node.insert(running_node_index, inp_reshape_node) # rewire Im2Col input node.input[0] = inp_reshape_out.name diff --git a/src/qonnx/transformation/resize_conv_to_deconv.py b/src/qonnx/transformation/resize_conv_to_deconv.py index 0dd40972..30bc6c3c 100644 --- a/src/qonnx/transformation/resize_conv_to_deconv.py +++ b/src/qonnx/transformation/resize_conv_to_deconv.py @@ -242,6 +242,9 @@ def apply(self, model): group=group, dilations=dilation, ) + # Save metadata from the convolution node + if hasattr(conv, "metadata_props"): + deconv_node.metadata_props.extend(conv.metadata_props) W_deconv_init = weight_name if weight_prod is not None: W_deconv_init = q_w_name diff --git a/src/qonnx/transformation/subpixel_to_deconv.py b/src/qonnx/transformation/subpixel_to_deconv.py index 3f330c99..241422c8 100644 --- a/src/qonnx/transformation/subpixel_to_deconv.py +++ b/src/qonnx/transformation/subpixel_to_deconv.py @@ -197,6 +197,9 @@ def apply(self, model): group=group, dilations=dilation, ) + # Save metadata from the original convolution node + if hasattr(n, "metadata_props"): + deconv_node.metadata_props.extend(n.metadata_props) W_deconv_init = weight_name if weight_prod is not None: W_deconv_init = q_w_name From 0d9d3e56ad6cc899872f5730b8be0af10f7ead02 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Thu, 30 Oct 2025 22:06:39 +0000 Subject: [PATCH 17/29] add copy metadata props function --- src/qonnx/transformation/change_datalayout.py | 11 +++---- .../extract_quant_scale_zeropt.py | 13 +++----- src/qonnx/transformation/gemm_to_matmul.py | 20 ++++-------- .../transformation/lower_convs_to_matmul.py | 8 ++--- src/qonnx/transformation/qcdq_to_qonnx.py | 8 ++--- .../transformation/resize_conv_to_deconv.py | 6 ++-- .../transformation/subpixel_to_deconv.py | 6 ++-- src/qonnx/util/basic.py | 32 +++++++++++++++++++ 8 files changed, 58 insertions(+), 46 deletions(-) diff --git a/src/qonnx/transformation/change_datalayout.py b/src/qonnx/transformation/change_datalayout.py index 07fbe400..62e6140b 100644 --- a/src/qonnx/transformation/change_datalayout.py +++ b/src/qonnx/transformation/change_datalayout.py @@ -30,7 +30,7 @@ from qonnx.transformation.base import Transformation from qonnx.transformation.infer_shapes import InferShapes -from qonnx.util.basic import get_by_name +from qonnx.util.basic import copy_metadata_props, get_by_name class ChangeDataLayoutQuantAvgPool2d(Transformation): @@ -78,8 +78,7 @@ def apply(self, model): graph.value_info.append(quantavg_out) quantavg_out = quantavg_out.name inp_trans_node = helper.make_node("Transpose", [node_input], [inp_trans_out], perm=[0, 2, 3, 1]) - if hasattr(n, "metadata_props"): - inp_trans_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, inp_trans_node) quantavg_node = helper.make_node( "QuantAvgPool2d", [inp_trans_out], @@ -92,12 
+91,10 @@ def apply(self, model): signed=signed, data_layout="NHWC", ) - if hasattr(n, "metadata_props"): - quantavg_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, quantavg_node) # NHWC -> NCHW out_trans_node = helper.make_node("Transpose", [quantavg_out], [node_output], perm=[0, 3, 1, 2]) - if hasattr(n, "metadata_props"): - out_trans_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, out_trans_node) # insert nodes graph.node.insert(node_ind, inp_trans_node) graph.node.insert(node_ind + 1, quantavg_node) diff --git a/src/qonnx/transformation/extract_quant_scale_zeropt.py b/src/qonnx/transformation/extract_quant_scale_zeropt.py index 614df416..f76e5555 100644 --- a/src/qonnx/transformation/extract_quant_scale_zeropt.py +++ b/src/qonnx/transformation/extract_quant_scale_zeropt.py @@ -33,6 +33,7 @@ from qonnx.transformation.base import Transformation from qonnx.transformation.general import GiveUniqueParameterTensors, SortGraph from qonnx.transformation.remove import RemoveIdentityOps +from qonnx.util.basic import copy_metadata_props class ExtractQuantScaleZeroPt(Transformation): @@ -69,8 +70,7 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(inp_scaled) inp_scale_node = helper.make_node("Div", [running_input, scale_nm], [inp_scaled_nm]) - if hasattr(node, "metadata_props"): - inp_scale_node.metadata_props.extend(node.metadata_props) + copy_metadata_props(node, inp_scale_node) graph.node.append(inp_scale_node) # create new Mul node # remove scale from Quant node @@ -89,8 +89,7 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(inp_zeropt) inp_zeropt_node = helper.make_node("Add", [running_input, zeropt_nm], [inp_zeropt_nm]) - if hasattr(node, "metadata_props"): - inp_zeropt_node.metadata_props.extend(node.metadata_props) + copy_metadata_props(node, inp_zeropt_node) graph.node.append(inp_zeropt_node) # remove zeropt from Quant node new_zeropt_nm = model.make_new_valueinfo_name() @@ -112,8 +111,7 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(out_zeropt) out_zeropt_node = helper.make_node("Sub", [out_zeropt_nm, zeropt_nm], [final_output]) - if hasattr(node, "metadata_props"): - out_zeropt_node.metadata_props.extend(node.metadata_props) + copy_metadata_props(node, out_zeropt_node) last_node.output[0] = out_zeropt_nm graph.node.append(out_zeropt_node) # important: when tracking a pointer to newly added nodes, @@ -133,8 +131,7 @@ def apply(self, model: ModelWrapper): last_node.output[0] = out_scale_nm graph.value_info.append(out_scale) out_scale_node = helper.make_node("Mul", [out_scale_nm, scale_nm], [final_output]) - if hasattr(node, "metadata_props"): - out_scale_node.metadata_props.extend(node.metadata_props) + copy_metadata_props(node, out_scale_node) graph.node.append(out_scale_node) if extract_scale or extract_zeropt: diff --git a/src/qonnx/transformation/gemm_to_matmul.py b/src/qonnx/transformation/gemm_to_matmul.py index 1298f3d6..245a0a2a 100644 --- a/src/qonnx/transformation/gemm_to_matmul.py +++ b/src/qonnx/transformation/gemm_to_matmul.py @@ -32,7 +32,7 @@ from qonnx.core.datatype import DataType from qonnx.transformation.base import Transformation from qonnx.transformation.remove import RemoveIdentityOps -from qonnx.util.basic import get_by_name +from qonnx.util.basic import copy_metadata_props, get_by_name class GemmToMatMul(Transformation): @@ -76,8 +76,7 @@ def apply(self, model): ) graph.value_info.append(inp_trans_out) inp_trans_node = helper.make_node("Transpose", [n.input[0]], 
[inp_trans_out.name]) - if hasattr(n, "metadata_props"): - inp_trans_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, inp_trans_node) graph.node.insert(running_node_index, inp_trans_node) running_node_index += 1 dt = model.get_tensor_datatype(n.input[0]) @@ -100,8 +99,7 @@ def apply(self, model): ) graph.value_info.append(inp_trans_out) inp_trans_node = helper.make_node("Transpose", [n.input[1]], [inp_trans_out.name]) - if hasattr(n, "metadata_props"): - inp_trans_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, inp_trans_node) graph.node.insert(running_node_index, inp_trans_node) running_node_index += 1 # Copy over the datatype @@ -113,8 +111,7 @@ def apply(self, model): # Insert MatMul: A * B matMul_node = helper.make_node("MatMul", [n.input[0], n.input[1]], [n.output[0]]) - if hasattr(n, "metadata_props"): - matMul_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, matMul_node) graph.node.insert(running_node_index, matMul_node) matMul_node = graph.node[running_node_index] running_node_index += 1 @@ -150,8 +147,7 @@ def apply(self, model): [act_mul_tensor.name, mul_tensor.name], [n.output[0]], ) - if hasattr(n, "metadata_props"): - mul_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, mul_node) graph.node.insert(running_node_index, mul_node) mul_node_main_branch = graph.node[running_node_index] running_node_index += 1 @@ -183,8 +179,7 @@ def apply(self, model): [n.input[2], mul_tensor.name], [act_mul_tensor.name], ) - if hasattr(n, "metadata_props"): - mul_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, mul_node) graph.node.insert(running_node_index, mul_node) running_node_index += 1 dt = model.get_tensor_datatype(n.input[2]) @@ -206,8 +201,7 @@ def apply(self, model): [act_add_tensor.name, n.input[2]], [n.output[0]], ) - if hasattr(n, "metadata_props"): - add_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, add_node) graph.node.insert(running_node_index, add_node) running_node_index += 1 diff --git a/src/qonnx/transformation/lower_convs_to_matmul.py b/src/qonnx/transformation/lower_convs_to_matmul.py index d864d8d2..5140b71d 100644 --- a/src/qonnx/transformation/lower_convs_to_matmul.py +++ b/src/qonnx/transformation/lower_convs_to_matmul.py @@ -32,7 +32,7 @@ from qonnx.transformation.base import Transformation from qonnx.transformation.extract_conv_bias import ExtractBiasFromConv -from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name +from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name class LowerConvsToMatMul(Transformation): @@ -178,12 +178,10 @@ def apply(self, model): matmul_input = im2col_out if need_im2col else inp_trans_out # do matmul matmul_node = helper.make_node("MatMul", [matmul_input, conv_weight_inp_name], [matmul_out]) - if hasattr(node, "metadata_props"): - matmul_node.metadata_props.extend(node.metadata_props) + copy_metadata_props(node, matmul_node) # NHWC -> NCHW out_trans_node = helper.make_node("Transpose", [matmul_out], [cnv_output], perm=[0, 3, 1, 2]) - if hasattr(node, "metadata_props"): - out_trans_node.metadata_props.extend(node.metadata_props) + copy_metadata_props(node, out_trans_node) nodes_to_insert.extend([matmul_node, out_trans_node]) diff --git a/src/qonnx/transformation/qcdq_to_qonnx.py b/src/qonnx/transformation/qcdq_to_qonnx.py index b122d840..7aaf9271 100644 --- a/src/qonnx/transformation/qcdq_to_qonnx.py +++ b/src/qonnx/transformation/qcdq_to_qonnx.py @@ -34,7 +34,7 @@ 
from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.base import Transformation -from qonnx.util.basic import get_by_name +from qonnx.util.basic import copy_metadata_props, get_by_name def extract_elem_type(elem_type: int, clip_range=None) -> Tuple[int, int, bool]: @@ -203,10 +203,8 @@ def apply(self, model: ModelWrapper) -> Tuple[ModelWrapper, bool]: rounding_mode="ROUND", # round-to-even signed=signed, ) - # Pass on metadata from DequantizeLinear node since it's the only node that - # must be present to be able to perform this transformation. - if hasattr(node, "metadata_props"): - fused_node.metadata_props.extend(node.metadata_props) + # Preserve metadata from all nodes being fused + copy_metadata_props(nodes_to_remove, fused_node) model.graph.node.insert(dequant_node_index, fused_node) for node_to_remove in nodes_to_remove: model.graph.node.remove(node_to_remove) diff --git a/src/qonnx/transformation/resize_conv_to_deconv.py b/src/qonnx/transformation/resize_conv_to_deconv.py index 30bc6c3c..7eda4fa7 100644 --- a/src/qonnx/transformation/resize_conv_to_deconv.py +++ b/src/qonnx/transformation/resize_conv_to_deconv.py @@ -33,7 +33,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.general.quant import quant, resolve_rounding_mode from qonnx.transformation.base import Transformation -from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name +from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name def _weight_convolution(cnv_weights: np.ndarray, scale: int) -> np.ndarray: @@ -242,9 +242,7 @@ def apply(self, model): group=group, dilations=dilation, ) - # Save metadata from the convolution node - if hasattr(conv, "metadata_props"): - deconv_node.metadata_props.extend(conv.metadata_props) + copy_metadata_props(conv, deconv_node) W_deconv_init = weight_name if weight_prod is not None: W_deconv_init = q_w_name diff --git a/src/qonnx/transformation/subpixel_to_deconv.py b/src/qonnx/transformation/subpixel_to_deconv.py index 241422c8..73ef3f8f 100644 --- a/src/qonnx/transformation/subpixel_to_deconv.py +++ b/src/qonnx/transformation/subpixel_to_deconv.py @@ -31,7 +31,7 @@ from onnx import helper from qonnx.transformation.base import Transformation -from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name +from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name def _weight_shuffle(cnv_weights: np.ndarray, block_size: int) -> np.ndarray: @@ -197,9 +197,7 @@ def apply(self, model): group=group, dilations=dilation, ) - # Save metadata from the original convolution node - if hasattr(n, "metadata_props"): - deconv_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, deconv_node) W_deconv_init = weight_name if weight_prod is not None: W_deconv_init = q_w_name diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 4e300dd1..1696c42b 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -350,3 +350,35 @@ def auto_pad_to_explicit_padding(autopad_str, idim_h, idim_w, k_h, k_w, stride_h return [pad_half_large_h, pad_half_large_w, pad_half_small_h, pad_half_small_w] else: raise Exception("Unsupported auto_pad: " + autopad_str) + + +def copy_metadata_props(source_node, target_node): + """Copy metadata properties from source node(s) to target node. + + Parameters + ---------- + source_node : onnx.NodeProto or list of onnx.NodeProto + Source node(s) from which to copy metadata_props. 
If a list is provided, + metadata from all nodes will be merged into the target node. + target_node : onnx.NodeProto + Target node to which metadata_props will be copied. + + Returns + ------- + None + Modifies target_node in place by extending its metadata_props. + + Examples + -------- + >>> # Copy from single node + >>> copy_metadata_props(old_node, new_node) + >>> + >>> # Copy from multiple nodes (e.g., when fusing) + >>> copy_metadata_props([quant_node, dequant_node], fused_node) + """ + # Handle both single node and list of nodes + source_nodes = source_node if isinstance(source_node, list) else [source_node] + + for node in source_nodes: + if hasattr(node, "metadata_props"): + target_node.metadata_props.extend(node.metadata_props) From 6f3a631abc80d1a9bed7865fcd485d5a9bb212ad Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Thu, 30 Oct 2025 22:19:46 +0000 Subject: [PATCH 18/29] convert missed functions --- src/qonnx/transformation/channels_last.py | 11 ++++------- src/qonnx/transformation/extract_conv_bias.py | 4 ++-- src/qonnx/transformation/rebalance_conv.py | 4 ++-- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index 8c934190..444a326c 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -40,7 +40,7 @@ from qonnx.transformation.infer_shapes import InferShapes from qonnx.transformation.make_input_chanlast import MakeInputChannelsLast from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit -from qonnx.util.basic import get_by_name +from qonnx.util.basic import copy_metadata_props, get_by_name from qonnx.util.onnx import is_eltwise_optype # Standard ONNX nodes which require a ChannelsLast data format to function properly @@ -96,8 +96,7 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe new_t_inp = model.make_new_valueinfo_name() inv_perm = np.argsort(perm) new_transpose_node = helper.make_node("Transpose", [eltwise_inp], [new_t_inp], perm=inv_perm) - if hasattr(transpose_node, "metadata_props"): - new_transpose_node.metadata_props.extend(transpose_node.metadata_props) + copy_metadata_props(transpose_node, new_transpose_node) t_shape = np.transpose(np.empty(inp_shape), axes=inv_perm).shape model.set_tensor_shape(new_t_inp, t_shape) eltwise_node.input[ind] = new_t_inp @@ -109,8 +108,7 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe model.set_initializer(unsqueeze_param_name, np.asarray(list(range(ndim_inp - ndim)), dtype=np.int64)) unsqueeze_out_name = model.make_new_valueinfo_name() new_unsqueeze_node = helper.make_node("Unsqueeze", [eltwise_inp, unsqueeze_param_name], [unsqueeze_out_name]) - if hasattr(eltwise_inp, "metadata_props"): - new_unsqueeze_node.metadata_props.extend(eltwise_inp.metadata_props) + copy_metadata_props(eltwise_inp, new_unsqueeze_node) unsqueeze_out_shape = np.expand_dims(np.empty(inp_shape), axis=tuple(range(ndim_inp - ndim))).shape model.set_tensor_shape(unsqueeze_out_name, unsqueeze_out_shape) model.graph.node.append(new_unsqueeze_node) @@ -118,8 +116,7 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe new_t_inp = model.make_new_valueinfo_name() inv_perm = np.argsort(perm) new_transpose_node = helper.make_node("Transpose", [unsqueeze_out_name], [new_t_inp], perm=inv_perm) - if hasattr(transpose_node, "metadata_props"): - 
new_transpose_node.metadata_props.extend(transpose_node.metadata_props) + copy_metadata_props(transpose_node, new_transpose_node) t_shape = np.transpose(np.empty(unsqueeze_out_shape), axes=inv_perm).shape model.set_tensor_shape(new_t_inp, t_shape) eltwise_node.input[ind] = new_t_inp diff --git a/src/qonnx/transformation/extract_conv_bias.py b/src/qonnx/transformation/extract_conv_bias.py index 1bf264e3..34b017bd 100644 --- a/src/qonnx/transformation/extract_conv_bias.py +++ b/src/qonnx/transformation/extract_conv_bias.py @@ -30,6 +30,7 @@ from onnx import helper from qonnx.transformation.base import Transformation +from qonnx.util.basic import copy_metadata_props class ExtractBiasFromConv(Transformation): @@ -75,8 +76,7 @@ def apply(self, model): [act_add_tensor.name, n.input[2]], [n.output[0]], ) - if hasattr(n, "metadata_props"): - add_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, add_node) graph.node.insert(node_ind, add_node) # Repoint Conv output and remove bias tensor diff --git a/src/qonnx/transformation/rebalance_conv.py b/src/qonnx/transformation/rebalance_conv.py index 098bff20..0107a62a 100644 --- a/src/qonnx/transformation/rebalance_conv.py +++ b/src/qonnx/transformation/rebalance_conv.py @@ -31,6 +31,7 @@ from qonnx.custom_op.registry import getCustomOp from qonnx.transformation.base import Transformation +from qonnx.util.basic import copy_metadata_props class RebalanceIm2Col(Transformation): @@ -103,8 +104,7 @@ def apply(self, model): inp_reshape_node = helper.make_node( "Reshape", [node.input[0], inp_shapedata.name], [inp_reshape_out.name] ) - if hasattr(node, "metadata_props"): - inp_reshape_node.metadata_props.extend(node.metadata_props) + copy_metadata_props(node, inp_reshape_node) graph.node.insert(running_node_index, inp_reshape_node) # rewire Im2Col input node.input[0] = inp_reshape_out.name From 59ca168ebd4fa744b1ee4de15f8ac849f43464e1 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Thu, 30 Oct 2025 22:23:11 +0000 Subject: [PATCH 19/29] correct fused node source mistake --- src/qonnx/transformation/qcdq_to_qonnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qonnx/transformation/qcdq_to_qonnx.py b/src/qonnx/transformation/qcdq_to_qonnx.py index 7aaf9271..b4e18f25 100644 --- a/src/qonnx/transformation/qcdq_to_qonnx.py +++ b/src/qonnx/transformation/qcdq_to_qonnx.py @@ -204,7 +204,7 @@ def apply(self, model: ModelWrapper) -> Tuple[ModelWrapper, bool]: signed=signed, ) # Preserve metadata from all nodes being fused - copy_metadata_props(nodes_to_remove, fused_node) + copy_metadata_props(node, fused_node) model.graph.node.insert(dequant_node_index, fused_node) for node_to_remove in nodes_to_remove: model.graph.node.remove(node_to_remove) From 159060670cf5692ffb6b79d88b65e66bec39327c Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Sat, 22 Nov 2025 00:41:31 +0000 Subject: [PATCH 20/29] add metadata preservation to batchnorm transform. 
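BatchNormToAffine lowers each BatchNormalization node into a Mul/Add
pair; both replacement nodes now inherit the original node's
metadata_props via copy_metadata_props, so node-level annotations
survive the lowering.

A minimal sketch of the intended behaviour (node and tensor names are
illustrative, not taken from the QONNX test suite):

    import onnx
    from onnx import helper
    from qonnx.util.basic import copy_metadata_props

    # original node carrying a provenance tag in its metadata_props
    bn = helper.make_node("BatchNormalization", ["x", "s", "b", "m", "v"], ["y"])
    bn.metadata_props.append(onnx.StringStringEntryProto(key="origin", value="block3"))

    # stand-ins for the Mul/Add pair produced by the transform
    mul = helper.make_node("Mul", ["x", "mul_const"], ["mul_out"])
    add = helper.make_node("Add", ["mul_out", "add_const"], ["y"])

    # copy the tag onto both replacement nodes
    copy_metadata_props(bn, mul)
    copy_metadata_props(bn, add)
    assert mul.metadata_props[0].value == add.metadata_props[0].value == "block3"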
--- src/qonnx/transformation/batchnorm_to_affine.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/qonnx/transformation/batchnorm_to_affine.py b/src/qonnx/transformation/batchnorm_to_affine.py index c89d2bdc..d63dd178 100644 --- a/src/qonnx/transformation/batchnorm_to_affine.py +++ b/src/qonnx/transformation/batchnorm_to_affine.py @@ -32,7 +32,7 @@ from qonnx.transformation.base import Transformation from qonnx.transformation.infer_shapes import InferShapes -from qonnx.util.basic import get_by_name +from qonnx.util.basic import get_by_name, copy_metadata_props class BatchNormToAffine(Transformation): @@ -89,6 +89,9 @@ def apply(self, model): # create Mul and Add nodes to replace the batchnorm mul_node = oh.make_node("Mul", [bn_input, mul_const.name], [mul_output.name]) add_node = oh.make_node("Add", [mul_output.name, add_const.name], [bn_output]) + # preserve metadata from original batchnorm node + copy_metadata_props(n, mul_node) + copy_metadata_props(n, add_node) # insert where the batchnorm is to preserve topological ordering graph.node.insert(node_ind, mul_node) graph.node.insert(node_ind + 1, add_node) From 5690a790291a0cc9a54fd9fb96d8a25080d63101 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Sat, 22 Nov 2025 01:00:42 +0000 Subject: [PATCH 21/29] adding more copy metadata nodes. --- src/qonnx/transformation/batchnorm_to_affine.py | 2 +- src/qonnx/transformation/bipolar_to_xnor.py | 5 ++++- src/qonnx/transformation/channels_last.py | 5 ++++- src/qonnx/transformation/lower_convs_to_matmul.py | 2 ++ 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/qonnx/transformation/batchnorm_to_affine.py b/src/qonnx/transformation/batchnorm_to_affine.py index d63dd178..6190f867 100644 --- a/src/qonnx/transformation/batchnorm_to_affine.py +++ b/src/qonnx/transformation/batchnorm_to_affine.py @@ -32,7 +32,7 @@ from qonnx.transformation.base import Transformation from qonnx.transformation.infer_shapes import InferShapes -from qonnx.util.basic import get_by_name, copy_metadata_props +from qonnx.util.basic import copy_metadata_props, get_by_name class BatchNormToAffine(Transformation): diff --git a/src/qonnx/transformation/bipolar_to_xnor.py b/src/qonnx/transformation/bipolar_to_xnor.py index 37f939a2..0b764ef8 100644 --- a/src/qonnx/transformation/bipolar_to_xnor.py +++ b/src/qonnx/transformation/bipolar_to_xnor.py @@ -36,7 +36,7 @@ from qonnx.transformation.base import Transformation from qonnx.transformation.infer_datatypes import InferDataTypes from qonnx.transformation.infer_shapes import InferShapes -from qonnx.util.basic import get_by_name +from qonnx.util.basic import copy_metadata_props, get_by_name class ConvertBipolarMatMulToXnorPopcount(Transformation): @@ -132,6 +132,9 @@ def find_prod_mt(x): # create Mul and Add nodes to replace the batchnorm mul_node = oh.make_node("Mul", [xnorpcout.name, mul_const.name], [mul_output.name]) add_node = oh.make_node("Add", [mul_output.name, add_const.name], [mm_output]) + # preserve metadata from original MatMul node + copy_metadata_props(n, mul_node) + copy_metadata_props(n, add_node) # insert where the batchnorm is to preserve topological ordering graph.node.insert(node_ind, mul_node) graph.node.insert(node_ind + 1, add_node) diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index 1a2a0dcd..f9ca62bb 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -242,6 +242,7 @@ def apply(self, model): # channels last 
transpose inp_trans_node = helper.make_node("Transpose", [inp], [inp_trans_out], perm=to_channels_last_args(ndim)) graph.node.insert(running_node_index, inp_trans_node) + copy_metadata_props(n, inp_trans_node) running_node_index += 1 # Attach to original node @@ -268,6 +269,7 @@ def apply(self, model): "Transpose", [outp_trans_in], [outp], perm=to_channels_first_args(ndim) ) graph.node.insert(running_node_index, outp_trans_node) + copy_metadata_props(n, outp_trans_node) running_node_index += 1 # Attach to original node @@ -570,7 +572,8 @@ def apply(self, model): axis=1, ) graph.node.insert(node_ind, flat_node) - + copy_metadata_props(n, flat_node) + graph_modified = True else: warnings.warn( diff --git a/src/qonnx/transformation/lower_convs_to_matmul.py b/src/qonnx/transformation/lower_convs_to_matmul.py index 5140b71d..f0981b34 100644 --- a/src/qonnx/transformation/lower_convs_to_matmul.py +++ b/src/qonnx/transformation/lower_convs_to_matmul.py @@ -152,6 +152,7 @@ def apply(self, model): # create new nodes # NCHW -> NHWC inp_trans_node = helper.make_node("Transpose", [cnv_input], [inp_trans_out], perm=[0, 2, 3, 1]) + copy_metadata_props(node, inp_trans_node) nodes_to_insert = [inp_trans_node] if need_im2col: @@ -174,6 +175,7 @@ def apply(self, model): dilations=dilation, ) nodes_to_insert.append(im2col_node) + copy_metadata_props(node, im2col_node) matmul_input = im2col_out if need_im2col else inp_trans_out # do matmul From ae6154bc1c1a1d9de8dbcef7aff75f719f46e3c6 Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Mon, 24 Nov 2025 13:19:23 -0800 Subject: [PATCH 22/29] Experimental metadata functions --- src/qonnx/transformation/general.py | 13 +++++- src/qonnx/util/basic.py | 64 +++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 2 deletions(-) diff --git a/src/qonnx/transformation/general.py b/src/qonnx/transformation/general.py index 5cfb907f..a381499e 100644 --- a/src/qonnx/transformation/general.py +++ b/src/qonnx/transformation/general.py @@ -119,15 +119,24 @@ def apply(self, model): class GiveUniqueNodeNames(Transformation): """Give unique names to each node in the graph using enumeration, starting - with given prefix (if specified in the constructor).""" + with given prefix (if specified in the constructor). - def __init__(self, prefix=""): + If only_empty=True, only renames nodes that have empty names, preserving + existing node names. This is useful after transforms that insert nodes + without names, to avoid stripping prefixes from existing nodes.""" + + def __init__(self, prefix="", only_empty=False): super().__init__() self.prefix = prefix + self.only_empty = only_empty def apply(self, model): optype_count = {} for n in model.graph.node: + # Skip nodes that already have names if only_empty=True + if self.only_empty and n.name != "": + continue + if n.op_type not in optype_count.keys(): optype_count[n.op_type] = 0 n.name = "%s%s_%d" % (self.prefix, n.op_type, optype_count[n.op_type]) diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 3253d873..67bfc2a9 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -354,3 +354,67 @@ def auto_pad_to_explicit_padding(autopad_str, idim_h, idim_w, k_h, k_w, stride_h return [pad_half_large_h, pad_half_large_w, pad_half_small_h, pad_half_small_w] else: raise Exception("Unsupported auto_pad: " + autopad_str) + + +def get_tensor_metadata_prop(model, tensor_name, key): + """Get metadata property from a tensor (input/output/initializer/value_info). 
+ + Args: + model: ModelWrapper instance + tensor_name: Name of the tensor + key: Metadata key to retrieve + + Returns: + str: Metadata value if found, None otherwise + """ + # Search all possible tensor locations + tensor = get_by_name(model.graph.input, tensor_name) + if tensor is None: + tensor = get_by_name(model.graph.output, tensor_name) + if tensor is None: + tensor = get_by_name(model.graph.initializer, tensor_name) + if tensor is None: + tensor = get_by_name(model.graph.value_info, tensor_name) + + if tensor is not None: + meta = get_by_name(tensor.metadata_props, key, "key") + return meta.value if meta is not None else None + return None + + +def set_tensor_metadata_prop(model, tensor_name, key, value): + """Set metadata property on a tensor (input/output/initializer/value_info). + + Args: + model: ModelWrapper instance + tensor_name: Name of the tensor + key: Metadata key to set + value: Metadata value (will be converted to string) + + Returns: + bool: True if successful, False if tensor not found + """ + import onnx + + # Search all possible tensor locations + tensor = get_by_name(model.graph.input, tensor_name) + if tensor is None: + tensor = get_by_name(model.graph.output, tensor_name) + if tensor is None: + tensor = get_by_name(model.graph.initializer, tensor_name) + if tensor is None: + tensor = get_by_name(model.graph.value_info, tensor_name) + + if tensor is not None: + meta = get_by_name(tensor.metadata_props, key, "key") + if meta is None: + # Create new metadata entry + meta = onnx.StringStringEntryProto() + meta.key = key + meta.value = str(value) + tensor.metadata_props.append(meta) + else: + # Update existing entry + meta.value = str(value) + return True + return False From 7c04726f8bb9c51fef07ca4a05dd647053a70e2a Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Tue, 25 Nov 2025 23:42:20 +0000 Subject: [PATCH 23/29] added overwrite mode flag and basic unit tests --- src/qonnx/util/basic.py | 22 ++++++++++-- tests/util/test_copy_metadata.py | 61 ++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 2 deletions(-) create mode 100644 tests/util/test_copy_metadata.py diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 73f4cca2..2c52285a 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -362,7 +362,7 @@ def auto_pad_to_explicit_padding(autopad_str, idim_h, idim_w, k_h, k_w, stride_h raise Exception("Unsupported auto_pad: " + autopad_str) -def copy_metadata_props(source_node, target_node): +def copy_metadata_props(source_node, target_node, mode="overwrite"): """Copy metadata properties from source node(s) to target node. Parameters @@ -386,9 +386,27 @@ def copy_metadata_props(source_node, target_node): >>> # Copy from multiple nodes (e.g., when fusing) >>> copy_metadata_props([quant_node, dequant_node], fused_node) """ + assert mode in ["overwrite", "keep_existing"], "Copy Metadata Mode must be either 'overwrite' or 'keep_existing'." 
+
     # Handle both single node and list of nodes
     source_nodes = source_node if isinstance(source_node, list) else [source_node]
 
     for node in source_nodes:
         if hasattr(node, "metadata_props"):
-            target_node.metadata_props.extend(node.metadata_props)
+
+            # check for existing keys in target_node to avoid duplicates
+            if hasattr(target_node, "metadata_props"):
+                existing_keys = {prop.key for prop in target_node.metadata_props}
+            else:
+                existing_keys = set()
+
+            for prop in node.metadata_props:
+                if prop.key in existing_keys:
+                    if mode == "overwrite":
+                        # Overwrite existing metadata property
+                        for existing_prop in target_node.metadata_props:
+                            if existing_prop.key == prop.key:
+                                existing_prop.value = prop.value
+                                break
+                else:
+                    target_node.metadata_props.append(prop)
\ No newline at end of file
diff --git a/tests/util/test_copy_metadata.py b/tests/util/test_copy_metadata.py
new file mode 100644
index 00000000..f976eabb
--- /dev/null
+++ b/tests/util/test_copy_metadata.py
@@ -0,0 +1,61 @@
+
+import onnx
+import pytest
+from qonnx.util.basic import copy_metadata_props
+
+
+def add_metadata(key, value):
+    return onnx.StringStringEntryProto(key=key, value=value)
+
+
+def test_copy_metadata_props():
+
+    # Create source node with metadata
+    src_node = onnx.NodeProto(
+        metadata_props=[add_metadata("key1", "value1"), add_metadata("key2", "value2")]
+    )
+    dst_node = onnx.NodeProto()
+
+    copy_metadata_props(src_node, dst_node)
+
+    assert len(dst_node.metadata_props) == 2
+    assert dst_node.metadata_props[0].key == "key1"
+    assert dst_node.metadata_props[0].value == "value1"
+    assert dst_node.metadata_props[1].key == "key2"
+    assert dst_node.metadata_props[1].value == "value2"
+
+
+@pytest.mark.parametrize("mode", ["keep_existing", "overwrite"])
+def test_copy_metadata_props_existing_target_md(mode):
+
+    # Create source node with metadata
+    src_node = onnx.NodeProto(
+        metadata_props=[add_metadata("key1", "value1"), add_metadata("key2", "value2")]
+    )
+    # Create destination node with existing metadata
+    dst_node = onnx.NodeProto(
+        metadata_props=[add_metadata("key1", "value3")]
+    )
+
+    copy_metadata_props(src_node, dst_node, mode=mode)
+
+    assert len(dst_node.metadata_props) == 2
+    assert dst_node.metadata_props[0].key == "key1"
+
+    if mode == "keep_existing":
+        assert dst_node.metadata_props[0].value == "value3"  # Should keep existing
+    elif mode == "overwrite":
+        assert dst_node.metadata_props[0].value == "value1"  # Should be overwritten
+
+    assert dst_node.metadata_props[1].key == "key2"
+    assert dst_node.metadata_props[1].value == "value2"
+
+
+def test_copy_metadata_props_bad_mode():
+    src_node = onnx.NodeProto(
+        metadata_props=[add_metadata("key1", "value1")]
+    )
+    dst_node = onnx.NodeProto()
+
+    with pytest.raises(AssertionError):
+        copy_metadata_props(src_node, dst_node, mode="invalid_mode")
\ No newline at end of file
From 3085eee25c656cca08eed48b89f9e61be45e562f Mon Sep 17 00:00:00 2001
From: Joshua Monson
Date: Tue, 25 Nov 2025 23:43:15 +0000
Subject: [PATCH 24/29] update documentation for copy_metadata_props

---
 src/qonnx/util/basic.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py
index 2c52285a..2752212b 100644
--- a/src/qonnx/util/basic.py
+++ b/src/qonnx/util/basic.py
@@ -372,6 +372,14 @@ def copy_metadata_props(source_node, target_node, mode="overwrite"):
         metadata from all nodes will be merged into the target node.
     target_node : onnx.NodeProto
         Target node to which metadata_props will be copied.
+ mode : str, optional + Mode for handling existing metadata properties in the target node. + Options are: + - "overwrite": Existing properties in the target node will be overwritten + by those from the source node(s) if they share the same key. + - "keep_existing": Existing properties in the target node will be kept, + and only new properties from the source node(s) will be added. + Default is "overwrite". Returns ------- From 43793044915aa32ab47c2430a11f7dd80b74d987 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Wed, 26 Nov 2025 17:25:56 +0000 Subject: [PATCH 25/29] add gemm2matmul test --- tests/util/test_copy_metadata.py | 34 +++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/tests/util/test_copy_metadata.py b/tests/util/test_copy_metadata.py index f976eabb..81f39a09 100644 --- a/tests/util/test_copy_metadata.py +++ b/tests/util/test_copy_metadata.py @@ -1,6 +1,7 @@ import onnx import pytest +from qonnx.transformation.infer_shapes import InferShapes from qonnx.util.basic import copy_metadata_props @@ -58,4 +59,35 @@ def test_copy_metadata_props_bad_mode(): dst_node = onnx.NodeProto() with pytest.raises(AssertionError): - copy_metadata_props(src_node, dst_node, mode="invalid_mode") \ No newline at end of file + copy_metadata_props(src_node, dst_node, mode="invalid_mode") + + +from onnxscript import script +from onnxscript import opset9 as op +from onnxscript import FLOAT +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.transformation.gemm_to_matmul import GemmToMatMul + +def test_copy_metadata_props_gemm2matmul(): + + @script() + def MyGemm(A: FLOAT[4, 5], B: FLOAT[5, 4], C: FLOAT[4, 4]) -> FLOAT[4, 4]: + return op.Gemm(A, B, C) + + model_proto = MyGemm.to_model_proto() + gemm_node = model_proto.graph.node[0] + gemm_node.metadata_props.extend([ + add_metadata("key1", "value1"), + add_metadata("key2", "value2") + ]) + + # Create Model Wrapper + mw = ModelWrapper(model_proto) + + transformed_mw = mw.transform(GemmToMatMul()) + + for node in transformed_mw.graph.node: + assert node.metadata_props[0].key == 'key1' + assert node.metadata_props[0].value == 'value1' + assert node.metadata_props[1].key == 'key2' + assert node.metadata_props[1].value == 'value2' \ No newline at end of file From 9cbd803969fe0d844aef65ef96e9754130a63af8 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Wed, 26 Nov 2025 18:55:18 +0000 Subject: [PATCH 26/29] add batchnorm to affine test --- tests/util/test_copy_metadata.py | 46 +++++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/tests/util/test_copy_metadata.py b/tests/util/test_copy_metadata.py index 81f39a09..32a9ed1f 100644 --- a/tests/util/test_copy_metadata.py +++ b/tests/util/test_copy_metadata.py @@ -63,7 +63,7 @@ def test_copy_metadata_props_bad_mode(): from onnxscript import script -from onnxscript import opset9 as op +from onnxscript import opset17 as op from onnxscript import FLOAT from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.gemm_to_matmul import GemmToMatMul @@ -86,6 +86,50 @@ def MyGemm(A: FLOAT[4, 5], B: FLOAT[5, 4], C: FLOAT[4, 4]) -> FLOAT[4, 4]: transformed_mw = mw.transform(GemmToMatMul()) + for node in transformed_mw.graph.node: + assert node.metadata_props[0].key == 'key1' + assert node.metadata_props[0].value == 'value1' + assert node.metadata_props[1].key == 'key2' + assert node.metadata_props[1].value == 'value2' + + +from onnx import helper as oh +import numpy as np +import onnxscript +from 
onnxscript.ir.passes.common import LiftConstantsToInitializersPass + + +def test_copy_metadata_props_batchnorm2affine(): + @script() + def MyBatchNorm(X: FLOAT[1, 3, 4, 4]) -> FLOAT[1, 3, 4, 4]: + scale = op.Constant(value=[[1.0, 1.0, 1.0]]) + B = op.Constant(value=[[0.0, 0.0, 0.0]]) + var = op.Constant(value=[[1.0, 1.0, 1.0]]) + mean = op.Constant(value=[[0.0, 0.0, 0.0]]) + return op.BatchNormalization(X, scale, B, mean, var, epsilon=1e-5, momentum=0.9) + + # remove cast-like nodes + model_proto = onnxscript.optimizer.optimize(MyBatchNorm.to_model_proto()) + + # batchnorm_to_affine requires initializers for scale/mean/var/bias + model_ir = onnxscript.ir.serde.deserialize_model(model_proto) + pass_ = LiftConstantsToInitializersPass(lift_all_constants=True, size_limit=1) + PassResult = pass_.call(model_ir) + model_proto = onnxscript.ir.serde.serialize_model(PassResult.model) + + # Add metadata to BatchNorm node + bn_node = model_proto.graph.node[0] + bn_node.metadata_props.extend([ + add_metadata("key1", "value1"), + add_metadata("key2", "value2") + ]) + + # Create Model Wrapper + mw = ModelWrapper(model_proto) + from qonnx.transformation.batchnorm_to_affine import BatchNormToAffine + transformed_mw = mw.transform(BatchNormToAffine()) + + # Check that metadata was copied for node in transformed_mw.graph.node: assert node.metadata_props[0].key == 'key1' assert node.metadata_props[0].value == 'value1' From 47c9bb855417e97bbda776c527f1a67a155811e1 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Wed, 26 Nov 2025 19:05:44 +0000 Subject: [PATCH 27/29] force precommit run --- tests/util/test_copy_metadata.py | 109 +++++++++++++------------------ 1 file changed, 44 insertions(+), 65 deletions(-) diff --git a/tests/util/test_copy_metadata.py b/tests/util/test_copy_metadata.py index 32a9ed1f..1cc913b9 100644 --- a/tests/util/test_copy_metadata.py +++ b/tests/util/test_copy_metadata.py @@ -1,7 +1,14 @@ +import pytest import onnx -import pytest -from qonnx.transformation.infer_shapes import InferShapes +import onnxscript +from onnxscript import FLOAT +from onnxscript import opset17 as op +from onnxscript import script +from onnxscript.ir.passes.common import LiftConstantsToInitializersPass + +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.transformation.gemm_to_matmul import GemmToMatMul from qonnx.util.basic import copy_metadata_props @@ -10,15 +17,12 @@ def add_metadata(key, value): def test_copy_metadata_props(): - # Create source node with metadata - src_node = onnx.NodeProto( - metadata_props=[add_metadata("key1", "value1"), add_metadata("key2", "value2")] - ) + src_node = onnx.NodeProto(metadata_props=[add_metadata("key1", "value1"), add_metadata("key2", "value2")]) dst_node = onnx.NodeProto() - + copy_metadata_props(src_node, dst_node) - + assert len(dst_node.metadata_props) == 2 assert dst_node.metadata_props[0].key == "key1" assert dst_node.metadata_props[0].value == "value1" @@ -28,77 +32,54 @@ def test_copy_metadata_props(): @pytest.mark.parametrize("mode", ["keep_existing", "overwrite"]) def test_copy_metadata_props_existing_target_md(mode): - # Create source node with metadata - src_node = onnx.NodeProto( - metadata_props=[add_metadata("key1", "value1"), add_metadata("key2", "value2")] - ) + src_node = onnx.NodeProto(metadata_props=[add_metadata("key1", "value1"), add_metadata("key2", "value2")]) # Create destination node with existing metadata - dst_node = onnx.NodeProto( - metadata_props=[add_metadata("key1", "value3")] - ) - + dst_node = 
onnx.NodeProto(metadata_props=[add_metadata("key1", "value3")]) + copy_metadata_props(src_node, dst_node, mode=mode) - + assert len(dst_node.metadata_props) == 2 assert dst_node.metadata_props[0].key == "key1" - + if mode == "keep_existing": assert dst_node.metadata_props[0].value == "value3" # Should keep existing elif mode == "overwrite": assert dst_node.metadata_props[0].value == "value1" # Should be overwritten - + assert dst_node.metadata_props[1].key == "key2" assert dst_node.metadata_props[1].value == "value2" - - + + def test_copy_metadata_props_bad_mode(): - src_node = onnx.NodeProto( - metadata_props=[add_metadata("key1", "value1")] - ) + src_node = onnx.NodeProto(metadata_props=[add_metadata("key1", "value1")]) dst_node = onnx.NodeProto() - + with pytest.raises(AssertionError): copy_metadata_props(src_node, dst_node, mode="invalid_mode") - -from onnxscript import script -from onnxscript import opset17 as op -from onnxscript import FLOAT -from qonnx.core.modelwrapper import ModelWrapper -from qonnx.transformation.gemm_to_matmul import GemmToMatMul - -def test_copy_metadata_props_gemm2matmul(): +def test_copy_metadata_props_gemm2matmul(): @script() def MyGemm(A: FLOAT[4, 5], B: FLOAT[5, 4], C: FLOAT[4, 4]) -> FLOAT[4, 4]: return op.Gemm(A, B, C) model_proto = MyGemm.to_model_proto() gemm_node = model_proto.graph.node[0] - gemm_node.metadata_props.extend([ - add_metadata("key1", "value1"), - add_metadata("key2", "value2") - ]) + gemm_node.metadata_props.extend([add_metadata("key1", "value1"), add_metadata("key2", "value2")]) # Create Model Wrapper mw = ModelWrapper(model_proto) - + transformed_mw = mw.transform(GemmToMatMul()) - + for node in transformed_mw.graph.node: - assert node.metadata_props[0].key == 'key1' - assert node.metadata_props[0].value == 'value1' - assert node.metadata_props[1].key == 'key2' - assert node.metadata_props[1].value == 'value2' - - -from onnx import helper as oh -import numpy as np -import onnxscript -from onnxscript.ir.passes.common import LiftConstantsToInitializersPass - - + assert node.metadata_props[0].key == "key1" + assert node.metadata_props[0].value == "value1" + assert node.metadata_props[1].key == "key2" + assert node.metadata_props[1].value == "value2" + + def test_copy_metadata_props_batchnorm2affine(): @script() def MyBatchNorm(X: FLOAT[1, 3, 4, 4]) -> FLOAT[1, 3, 4, 4]: @@ -107,31 +88,29 @@ def MyBatchNorm(X: FLOAT[1, 3, 4, 4]) -> FLOAT[1, 3, 4, 4]: var = op.Constant(value=[[1.0, 1.0, 1.0]]) mean = op.Constant(value=[[0.0, 0.0, 0.0]]) return op.BatchNormalization(X, scale, B, mean, var, epsilon=1e-5, momentum=0.9) - + # remove cast-like nodes - model_proto = onnxscript.optimizer.optimize(MyBatchNorm.to_model_proto()) - + model_proto = onnxscript.optimizer.optimize(MyBatchNorm.to_model_proto()) + # batchnorm_to_affine requires initializers for scale/mean/var/bias model_ir = onnxscript.ir.serde.deserialize_model(model_proto) pass_ = LiftConstantsToInitializersPass(lift_all_constants=True, size_limit=1) PassResult = pass_.call(model_ir) model_proto = onnxscript.ir.serde.serialize_model(PassResult.model) - + # Add metadata to BatchNorm node bn_node = model_proto.graph.node[0] - bn_node.metadata_props.extend([ - add_metadata("key1", "value1"), - add_metadata("key2", "value2") - ]) - + bn_node.metadata_props.extend([add_metadata("key1", "value1"), add_metadata("key2", "value2")]) + # Create Model Wrapper mw = ModelWrapper(model_proto) from qonnx.transformation.batchnorm_to_affine import BatchNormToAffine + transformed_mw = 
mw.transform(BatchNormToAffine()) - + # Check that metadata was copied for node in transformed_mw.graph.node: - assert node.metadata_props[0].key == 'key1' - assert node.metadata_props[0].value == 'value1' - assert node.metadata_props[1].key == 'key2' - assert node.metadata_props[1].value == 'value2' \ No newline at end of file + assert node.metadata_props[0].key == "key1" + assert node.metadata_props[0].value == "value1" + assert node.metadata_props[1].key == "key2" + assert node.metadata_props[1].value == "value2" From 9e1ea73c25c98f572007e5e796707a89cdde858f Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Mon, 1 Dec 2025 16:44:30 +0000 Subject: [PATCH 28/29] revert to old version of extract model config --- src/qonnx/util/config.py | 32 +++++++++----------------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/src/qonnx/util/config.py b/src/qonnx/util/config.py index 0937227f..d9733984 100644 --- a/src/qonnx/util/config.py +++ b/src/qonnx/util/config.py @@ -27,39 +27,25 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import json -import onnx from qonnx.custom_op.registry import getCustomOp -# update this code to handle export configs from subgraphs -# where the subgraph is found in a node's attribute as a graph type -def extract_model_config(model, attr_names_to_extract): - """Create a dictionary with layer name -> attribute mappings extracted from the - model. The created dictionary can be later applied on a model with - qonnx.transform.general.ApplyConfig.""" +def extract_model_config_to_json(model, json_filename, attr_names_to_extract): + """Create a json file with layer name -> attribute mappings extracted from the + model.""" cfg = dict() cfg["Defaults"] = dict() for n in model.graph.node: oi = getCustomOp(n) layer_dict = dict() - for attr in n.attribute: - if attr.type == onnx.AttributeProto.GRAPH: # Graph type - # If the attribute is a graph, we need to extract the attributes from the subgraph - cfg.update(extract_model_config(model.make_subgraph_modelwrapper(attr.g), attr_names_to_extract)) - elif attr.name in attr_names_to_extract: - # If the attribute name is in the list, we can add it directly - layer_dict[attr.name] = oi.get_nodeattr(attr.name) + for attr in attr_names_to_extract: + try: + layer_dict[attr] = oi.get_nodeattr(attr) + except AttributeError: + pass if len(layer_dict) > 0: cfg[n.name] = layer_dict - return cfg - - -def extract_model_config_to_json(model, json_filename, attr_names_to_extract): - """Create a json file with layer name -> attribute mappings extracted from the - model. The created json file can be later applied on a model with - qonnx.transform.general.ApplyConfig.""" - with open(json_filename, "w") as f: - json.dump(extract_model_config(model, attr_names_to_extract), f, indent=2) + json.dump(cfg, f, indent=2) From fcd803a018cb1afb8173ee65b2fdf8bdf97efaca Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Wed, 3 Dec 2025 13:28:45 -0800 Subject: [PATCH 29/29] Cleanup comments and unused entrypoint logic --- docs/overview.rst | 3 +-- src/qonnx/__init__.py | 25 ------------------------- src/qonnx/custom_op/__init__.py | 30 ------------------------------ 3 files changed, 1 insertion(+), 57 deletions(-) diff --git a/docs/overview.rst b/docs/overview.rst index 161d1e49..bd0d004a 100644 --- a/docs/overview.rst +++ b/docs/overview.rst @@ -46,7 +46,7 @@ Custom Operations/Nodes QONNX uses many custom operations (op_type in ONNX NodeProto) that are not defined in the ONNX operator schema. 
These custom nodes are marked with domain="qonnx.*" in the protobuf to identify them as such. These nodes can represent specific operations that we need for low-bit networks, or operations that are specific to a particular hardware backend. To get more familiar with custom operations and how they are created, please take a look in the Jupyter notebook about CustomOps (see chapter :ref:`tutorials` for details) or directly in the module :py:mod:`qonnx.custom_op`. Custom ops are automatically discovered through Python module namespaces. -Simply define your CustomOp subclass in the appropriate domain module +Simply import your CustomOp subclass in the appropriate domain module (e.g., ``qonnx.custom_op.general`` for general ops) and it will be automatically available through ``getCustomOp``. @@ -57,7 +57,6 @@ For dynamic registration and querying, use the registry functions: * ``add_op_to_domain(domain, op_class)`` - Register an op at runtime (for testing) * ``get_ops_in_domain(domain)`` - List all ops available in a domain * ``add_domain_alias(domain, module_path)`` - Map a domain to a different module path -* ``hasCustomOp(domain, op_type)`` - Check if an op exists in a domain Custom ONNX Execution Flow diff --git a/src/qonnx/__init__.py b/src/qonnx/__init__.py index 217648b8..e69de29b 100644 --- a/src/qonnx/__init__.py +++ b/src/qonnx/__init__.py @@ -1,25 +0,0 @@ -"""QONNX package initialization.""" - -import warnings -from importlib import metadata - - -def _load_custom_op_entry_points(): - """Import modules registered under the ``qonnx_custom_ops`` entry point.""" - - try: - eps = metadata.entry_points() - if hasattr(eps, "select"): - eps = eps.select(group="qonnx_custom_ops") - else: - eps = eps.get("qonnx_custom_ops", []) - for ep in eps: - try: - ep.load() - except Exception as e: # pragma: no cover - import failure warning - warnings.warn(f"Failed to load custom op entry point {ep.name}: {e}") - except Exception as e: # pragma: no cover - metadata failure warning - warnings.warn(f"Failed to query custom op entry points: {e}") - - -_load_custom_op_entry_points() diff --git a/src/qonnx/custom_op/__init__.py b/src/qonnx/custom_op/__init__.py index 7c38a8df..e69de29b 100644 --- a/src/qonnx/custom_op/__init__.py +++ b/src/qonnx/custom_op/__init__.py @@ -1,30 +0,0 @@ -# Copyright (c) 2020 Xilinx, Inc. -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# * Neither the name of Xilinx nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -# Domain aliases are automatically handled by the registry -# The onnx.brevitas -> qonnx.custom_op.general mapping is built into the registry \ No newline at end of file
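
A minimal usage sketch of the mode flag introduced in PATCH 23 (the keys and
values below are illustrative only; copy_metadata_props itself comes from
qonnx.util.basic as shown in the diffs above):

    import onnx

    from qonnx.util.basic import copy_metadata_props

    src = onnx.NodeProto()
    src.metadata_props.append(onnx.StringStringEntryProto(key="origin", value="quant"))
    dst = onnx.NodeProto()
    dst.metadata_props.append(onnx.StringStringEntryProto(key="origin", value="stale"))

    # default mode="overwrite": a key shared with the source takes the source value
    copy_metadata_props(src, dst)
    assert dst.metadata_props[0].value == "quant"

    # mode="keep_existing": the target's value for a shared key is preserved
    dst.metadata_props[0].value = "stale"
    copy_metadata_props(src, dst, mode="keep_existing")
    assert dst.metadata_props[0].value == "stale"

Keys present only on the source node are appended to the target in both modes.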