diff --git a/docs/overview.rst b/docs/overview.rst index 8e2002d7..bd0d004a 100644 --- a/docs/overview.rst +++ b/docs/overview.rst @@ -45,6 +45,19 @@ Custom Operations/Nodes QONNX uses many custom operations (op_type in ONNX NodeProto) that are not defined in the ONNX operator schema. These custom nodes are marked with domain="qonnx.*" in the protobuf to identify them as such. These nodes can represent specific operations that we need for low-bit networks, or operations that are specific to a particular hardware backend. To get more familiar with custom operations and how they are created, please take a look in the Jupyter notebook about CustomOps (see chapter :ref:`tutorials` for details) or directly in the module :py:mod:`qonnx.custom_op`. +Custom ops are automatically discovered through Python module namespaces. +Simply import your CustomOp subclass in the appropriate domain module +(e.g., ``qonnx.custom_op.general`` for general ops) and it will be automatically +available through ``getCustomOp``. + +For dynamic registration and querying, use the registry functions: + +* ``getCustomOp(node)`` - Get a custom op instance from an ONNX node +* ``is_custom_op(domain, op_type=None)`` - Check if a specific op or domain has custom ops +* ``add_op_to_domain(domain, op_class)`` - Register an op at runtime (for testing) +* ``get_ops_in_domain(domain)`` - List all ops available in a domain +* ``add_domain_alias(domain, module_path)`` - Map a domain to a different module path + Custom ONNX Execution Flow ========================== diff --git a/notebooks/3_custom_op.ipynb b/notebooks/3_custom_op.ipynb index d0cd10fd..1b822163 100644 --- a/notebooks/3_custom_op.ipynb +++ b/notebooks/3_custom_op.ipynb @@ -129,35 +129,26 @@ { "cell_type": "markdown", "metadata": {}, - "source": [ - "To make sure our custom op is available, it needs to be registered. The best practice for this is to create a submodule under `qonnx.custom_op` which includes a `custom_op` dictionary that maps strings (op names) to classes (op implementations). Since we're in a Jupyter notebook we'll just hijack it at runtime like this:" - ] + "source": "To make sure our custom op is available, we need to add it to the domain. For production code, you would place your CustomOp class directly in the appropriate module file (e.g., in a file under `qonnx/custom_op/general/`). For testing and experimentation like in this notebook, we can use the `add_op_to_domain` function:" }, { "cell_type": "code", - "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "import qonnx.custom_op.general as general\n", - "general.custom_op[\"MyPythonPowerOp\"] = MyPythonPowerOp" - ] + "source": "from qonnx.custom_op.registry import add_op_to_domain\n\n# Add our custom op to the general domain\nadd_op_to_domain(\"qonnx.custom_op.general\", MyPythonPowerOp)", + "execution_count": null }, { "cell_type": "markdown", "metadata": {}, - "source": [ - "We can see which custom ops are registered under this submodule by looking at the dictionary:" - ] + "source": "We can see which custom ops are available in a domain by using the registry function:" }, { "cell_type": "code", - "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "general.custom_op" - ] + "source": "from qonnx.custom_op.registry import get_ops_in_domain, is_custom_op\n\n# See all ops in the general domain\nops = get_ops_in_domain(\"qonnx.custom_op.general\")\nprint(f\"Available ops: {[op[0] for op in ops]}\")\n\n# Check if our op is there\nprint(f\"MyPythonPowerOp available: {is_custom_op('qonnx.custom_op.general', 'MyPythonPowerOp')}\")", + "execution_count": null }, { "cell_type": "markdown", @@ -462,17 +453,10 @@ }, { "cell_type": "code", - "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "# register our new op\n", - "general.custom_op[\"MyMixedPowerOp\"] = MyMixedPowerOp\n", - "\n", - "# make graph with new op\n", - "mixedop_graph = make_graph(input_shape, 2, op_type = \"MyMixedPowerOp\")\n", - "mixedop_graph.graph.node" - ] + "source": "# register our new op\nadd_op_to_domain(\"qonnx.custom_op.general\", MyMixedPowerOp)\n\n# make graph with new op\nmixedop_graph = make_graph(input_shape, 2, op_type = \"MyMixedPowerOp\")\nmixedop_graph.graph.node", + "execution_count": null }, { "cell_type": "markdown", @@ -744,4 +728,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 4d1ea8fc..6255eea1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -94,6 +94,10 @@ console_scripts = qonnx-tensor-stats = qonnx.analysis.tensor_stats:main pytest_randomly.random_seeder = qonnx = qonnx.util.random_reseed:reseed +# entry points for custom op modules +qonnx_custom_ops = + qonnx = qonnx.custom_op.general + qonnx_channels_last = qonnx.custom_op.channels_last # Add here console scripts like: # console_scripts = # script_name = qonnx.module:function diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py index 3d448dc3..a7356a8f 100644 --- a/src/qonnx/custom_op/general/quant.py +++ b/src/qonnx/custom_op/general/quant.py @@ -26,12 +26,12 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from qonnx.custom_op.general.intquant import IntQuant as Quant +# Import IntQuant to create alias +from qonnx.custom_op.general.intquant import IntQuant + +# Re-export functions from intquant for backward compatibility from qonnx.custom_op.general.intquant import int_quant as quant from qonnx.custom_op.general.intquant import max_int, min_int, resolve_rounding_mode -Quant = Quant -quant = quant -max_int = max_int -min_int = min_int -resolve_rounding_mode = resolve_rounding_mode +# Create alias for backward compatibility - Quant is just IntQuant +Quant = IntQuant \ No newline at end of file diff --git a/src/qonnx/transformation/batchnorm_to_affine.py b/src/qonnx/transformation/batchnorm_to_affine.py index c89d2bdc..6190f867 100644 --- a/src/qonnx/transformation/batchnorm_to_affine.py +++ b/src/qonnx/transformation/batchnorm_to_affine.py @@ -32,7 +32,7 @@ from qonnx.transformation.base import Transformation from qonnx.transformation.infer_shapes import InferShapes -from qonnx.util.basic import get_by_name +from qonnx.util.basic import copy_metadata_props, get_by_name class BatchNormToAffine(Transformation): @@ -89,6 +89,9 @@ def apply(self, model): # create Mul and Add nodes to replace the batchnorm mul_node = oh.make_node("Mul", [bn_input, mul_const.name], [mul_output.name]) add_node = oh.make_node("Add", [mul_output.name, add_const.name], [bn_output]) + # preserve metadata from original batchnorm node + copy_metadata_props(n, mul_node) + copy_metadata_props(n, add_node) # insert where the batchnorm is to preserve topological ordering graph.node.insert(node_ind, mul_node) graph.node.insert(node_ind + 1, add_node) diff --git a/src/qonnx/transformation/bipolar_to_xnor.py b/src/qonnx/transformation/bipolar_to_xnor.py index 37f939a2..0b764ef8 100644 --- a/src/qonnx/transformation/bipolar_to_xnor.py +++ b/src/qonnx/transformation/bipolar_to_xnor.py @@ -36,7 +36,7 @@ from qonnx.transformation.base import Transformation from qonnx.transformation.infer_datatypes import InferDataTypes from qonnx.transformation.infer_shapes import InferShapes -from qonnx.util.basic import get_by_name +from qonnx.util.basic import copy_metadata_props, get_by_name class ConvertBipolarMatMulToXnorPopcount(Transformation): @@ -132,6 +132,9 @@ def find_prod_mt(x): # create Mul and Add nodes to replace the batchnorm mul_node = oh.make_node("Mul", [xnorpcout.name, mul_const.name], [mul_output.name]) add_node = oh.make_node("Add", [mul_output.name, add_const.name], [mm_output]) + # preserve metadata from original MatMul node + copy_metadata_props(n, mul_node) + copy_metadata_props(n, add_node) # insert where the batchnorm is to preserve topological ordering graph.node.insert(node_ind, mul_node) graph.node.insert(node_ind + 1, add_node) diff --git a/src/qonnx/transformation/change_datalayout.py b/src/qonnx/transformation/change_datalayout.py index 7b73e4bf..62e6140b 100644 --- a/src/qonnx/transformation/change_datalayout.py +++ b/src/qonnx/transformation/change_datalayout.py @@ -30,7 +30,7 @@ from qonnx.transformation.base import Transformation from qonnx.transformation.infer_shapes import InferShapes -from qonnx.util.basic import get_by_name +from qonnx.util.basic import copy_metadata_props, get_by_name class ChangeDataLayoutQuantAvgPool2d(Transformation): @@ -78,6 +78,7 @@ def apply(self, model): graph.value_info.append(quantavg_out) quantavg_out = quantavg_out.name inp_trans_node = helper.make_node("Transpose", [node_input], [inp_trans_out], perm=[0, 2, 3, 1]) + copy_metadata_props(n, inp_trans_node) quantavg_node = helper.make_node( "QuantAvgPool2d", [inp_trans_out], @@ -90,8 +91,10 @@ def apply(self, model): signed=signed, data_layout="NHWC", ) + copy_metadata_props(n, quantavg_node) # NHWC -> NCHW out_trans_node = helper.make_node("Transpose", [quantavg_out], [node_output], perm=[0, 3, 1, 2]) + copy_metadata_props(n, out_trans_node) # insert nodes graph.node.insert(node_ind, inp_trans_node) graph.node.insert(node_ind + 1, quantavg_node) diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index a00f8a9c..f9ca62bb 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -40,7 +40,7 @@ from qonnx.transformation.infer_shapes import InferShapes from qonnx.transformation.make_input_chanlast import MakeInputChannelsLast from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit -from qonnx.util.basic import get_by_name +from qonnx.util.basic import copy_metadata_props, get_by_name from qonnx.util.onnx import is_eltwise_optype # Standard ONNX nodes which require a ChannelsLast data format to function properly @@ -96,6 +96,7 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe new_t_inp = model.make_new_valueinfo_name() inv_perm = np.argsort(perm) new_transpose_node = helper.make_node("Transpose", [eltwise_inp], [new_t_inp], perm=inv_perm) + copy_metadata_props(transpose_node, new_transpose_node) t_shape = np.transpose(np.empty(inp_shape), axes=inv_perm).shape model.set_tensor_shape(new_t_inp, t_shape) eltwise_node.input[ind] = new_t_inp @@ -107,6 +108,7 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe model.set_initializer(unsqueeze_param_name, np.asarray(list(range(ndim_inp - ndim)), dtype=np.int64)) unsqueeze_out_name = model.make_new_valueinfo_name() new_unsqueeze_node = helper.make_node("Unsqueeze", [eltwise_inp, unsqueeze_param_name], [unsqueeze_out_name]) + copy_metadata_props(eltwise_inp, new_unsqueeze_node) unsqueeze_out_shape = np.expand_dims(np.empty(inp_shape), axis=tuple(range(ndim_inp - ndim))).shape model.set_tensor_shape(unsqueeze_out_name, unsqueeze_out_shape) model.graph.node.append(new_unsqueeze_node) @@ -114,6 +116,7 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe new_t_inp = model.make_new_valueinfo_name() inv_perm = np.argsort(perm) new_transpose_node = helper.make_node("Transpose", [unsqueeze_out_name], [new_t_inp], perm=inv_perm) + copy_metadata_props(transpose_node, new_transpose_node) t_shape = np.transpose(np.empty(unsqueeze_out_shape), axes=inv_perm).shape model.set_tensor_shape(new_t_inp, t_shape) eltwise_node.input[ind] = new_t_inp @@ -239,6 +242,7 @@ def apply(self, model): # channels last transpose inp_trans_node = helper.make_node("Transpose", [inp], [inp_trans_out], perm=to_channels_last_args(ndim)) graph.node.insert(running_node_index, inp_trans_node) + copy_metadata_props(n, inp_trans_node) running_node_index += 1 # Attach to original node @@ -265,6 +269,7 @@ def apply(self, model): "Transpose", [outp_trans_in], [outp], perm=to_channels_first_args(ndim) ) graph.node.insert(running_node_index, outp_trans_node) + copy_metadata_props(n, outp_trans_node) running_node_index += 1 # Attach to original node @@ -567,7 +572,8 @@ def apply(self, model): axis=1, ) graph.node.insert(node_ind, flat_node) - + copy_metadata_props(n, flat_node) + graph_modified = True else: warnings.warn( diff --git a/src/qonnx/transformation/extract_conv_bias.py b/src/qonnx/transformation/extract_conv_bias.py index bf2cf8b4..34b017bd 100644 --- a/src/qonnx/transformation/extract_conv_bias.py +++ b/src/qonnx/transformation/extract_conv_bias.py @@ -30,6 +30,7 @@ from onnx import helper from qonnx.transformation.base import Transformation +from qonnx.util.basic import copy_metadata_props class ExtractBiasFromConv(Transformation): @@ -75,6 +76,7 @@ def apply(self, model): [act_add_tensor.name, n.input[2]], [n.output[0]], ) + copy_metadata_props(n, add_node) graph.node.insert(node_ind, add_node) # Repoint Conv output and remove bias tensor diff --git a/src/qonnx/transformation/extract_quant_scale_zeropt.py b/src/qonnx/transformation/extract_quant_scale_zeropt.py index 58863f08..f76e5555 100644 --- a/src/qonnx/transformation/extract_quant_scale_zeropt.py +++ b/src/qonnx/transformation/extract_quant_scale_zeropt.py @@ -33,6 +33,7 @@ from qonnx.transformation.base import Transformation from qonnx.transformation.general import GiveUniqueParameterTensors, SortGraph from qonnx.transformation.remove import RemoveIdentityOps +from qonnx.util.basic import copy_metadata_props class ExtractQuantScaleZeroPt(Transformation): @@ -69,6 +70,7 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(inp_scaled) inp_scale_node = helper.make_node("Div", [running_input, scale_nm], [inp_scaled_nm]) + copy_metadata_props(node, inp_scale_node) graph.node.append(inp_scale_node) # create new Mul node # remove scale from Quant node @@ -87,6 +89,7 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(inp_zeropt) inp_zeropt_node = helper.make_node("Add", [running_input, zeropt_nm], [inp_zeropt_nm]) + copy_metadata_props(node, inp_zeropt_node) graph.node.append(inp_zeropt_node) # remove zeropt from Quant node new_zeropt_nm = model.make_new_valueinfo_name() @@ -108,6 +111,7 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(out_zeropt) out_zeropt_node = helper.make_node("Sub", [out_zeropt_nm, zeropt_nm], [final_output]) + copy_metadata_props(node, out_zeropt_node) last_node.output[0] = out_zeropt_nm graph.node.append(out_zeropt_node) # important: when tracking a pointer to newly added nodes, @@ -127,6 +131,7 @@ def apply(self, model: ModelWrapper): last_node.output[0] = out_scale_nm graph.value_info.append(out_scale) out_scale_node = helper.make_node("Mul", [out_scale_nm, scale_nm], [final_output]) + copy_metadata_props(node, out_scale_node) graph.node.append(out_scale_node) if extract_scale or extract_zeropt: diff --git a/src/qonnx/transformation/gemm_to_matmul.py b/src/qonnx/transformation/gemm_to_matmul.py index 5396a7d6..245a0a2a 100644 --- a/src/qonnx/transformation/gemm_to_matmul.py +++ b/src/qonnx/transformation/gemm_to_matmul.py @@ -32,7 +32,7 @@ from qonnx.core.datatype import DataType from qonnx.transformation.base import Transformation from qonnx.transformation.remove import RemoveIdentityOps -from qonnx.util.basic import get_by_name +from qonnx.util.basic import copy_metadata_props, get_by_name class GemmToMatMul(Transformation): @@ -76,6 +76,7 @@ def apply(self, model): ) graph.value_info.append(inp_trans_out) inp_trans_node = helper.make_node("Transpose", [n.input[0]], [inp_trans_out.name]) + copy_metadata_props(n, inp_trans_node) graph.node.insert(running_node_index, inp_trans_node) running_node_index += 1 dt = model.get_tensor_datatype(n.input[0]) @@ -98,6 +99,7 @@ def apply(self, model): ) graph.value_info.append(inp_trans_out) inp_trans_node = helper.make_node("Transpose", [n.input[1]], [inp_trans_out.name]) + copy_metadata_props(n, inp_trans_node) graph.node.insert(running_node_index, inp_trans_node) running_node_index += 1 # Copy over the datatype @@ -109,6 +111,7 @@ def apply(self, model): # Insert MatMul: A * B matMul_node = helper.make_node("MatMul", [n.input[0], n.input[1]], [n.output[0]]) + copy_metadata_props(n, matMul_node) graph.node.insert(running_node_index, matMul_node) matMul_node = graph.node[running_node_index] running_node_index += 1 @@ -144,6 +147,7 @@ def apply(self, model): [act_mul_tensor.name, mul_tensor.name], [n.output[0]], ) + copy_metadata_props(n, mul_node) graph.node.insert(running_node_index, mul_node) mul_node_main_branch = graph.node[running_node_index] running_node_index += 1 @@ -175,6 +179,7 @@ def apply(self, model): [n.input[2], mul_tensor.name], [act_mul_tensor.name], ) + copy_metadata_props(n, mul_node) graph.node.insert(running_node_index, mul_node) running_node_index += 1 dt = model.get_tensor_datatype(n.input[2]) @@ -196,7 +201,7 @@ def apply(self, model): [act_add_tensor.name, n.input[2]], [n.output[0]], ) - + copy_metadata_props(n, add_node) graph.node.insert(running_node_index, add_node) running_node_index += 1 diff --git a/src/qonnx/transformation/general.py b/src/qonnx/transformation/general.py index f60b92b6..e9b54800 100644 --- a/src/qonnx/transformation/general.py +++ b/src/qonnx/transformation/general.py @@ -117,15 +117,24 @@ def apply(self, model): class GiveUniqueNodeNames(Transformation): """Give unique names to each node in the graph using enumeration, starting - with given prefix (if specified in the constructor).""" + with given prefix (if specified in the constructor). - def __init__(self, prefix=""): + If only_empty=True, only renames nodes that have empty names, preserving + existing node names. This is useful after transforms that insert nodes + without names, to avoid stripping prefixes from existing nodes.""" + + def __init__(self, prefix="", only_empty=False): super().__init__() self.prefix = prefix + self.only_empty = only_empty def apply(self, model): optype_count = {} for n in model.graph.node: + # Skip nodes that already have names if only_empty=True + if self.only_empty and n.name != "": + continue + if n.op_type not in optype_count.keys(): optype_count[n.op_type] = 0 n.name = "%s%s_%d" % (self.prefix, n.op_type, optype_count[n.op_type]) diff --git a/src/qonnx/transformation/lower_convs_to_matmul.py b/src/qonnx/transformation/lower_convs_to_matmul.py index 81f0b713..f0981b34 100644 --- a/src/qonnx/transformation/lower_convs_to_matmul.py +++ b/src/qonnx/transformation/lower_convs_to_matmul.py @@ -32,7 +32,7 @@ from qonnx.transformation.base import Transformation from qonnx.transformation.extract_conv_bias import ExtractBiasFromConv -from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name +from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name class LowerConvsToMatMul(Transformation): @@ -152,6 +152,7 @@ def apply(self, model): # create new nodes # NCHW -> NHWC inp_trans_node = helper.make_node("Transpose", [cnv_input], [inp_trans_out], perm=[0, 2, 3, 1]) + copy_metadata_props(node, inp_trans_node) nodes_to_insert = [inp_trans_node] if need_im2col: @@ -174,12 +175,15 @@ def apply(self, model): dilations=dilation, ) nodes_to_insert.append(im2col_node) + copy_metadata_props(node, im2col_node) matmul_input = im2col_out if need_im2col else inp_trans_out # do matmul matmul_node = helper.make_node("MatMul", [matmul_input, conv_weight_inp_name], [matmul_out]) + copy_metadata_props(node, matmul_node) # NHWC -> NCHW out_trans_node = helper.make_node("Transpose", [matmul_out], [cnv_output], perm=[0, 3, 1, 2]) + copy_metadata_props(node, out_trans_node) nodes_to_insert.extend([matmul_node, out_trans_node]) diff --git a/src/qonnx/transformation/qcdq_to_qonnx.py b/src/qonnx/transformation/qcdq_to_qonnx.py index b7e35c0d..b4e18f25 100644 --- a/src/qonnx/transformation/qcdq_to_qonnx.py +++ b/src/qonnx/transformation/qcdq_to_qonnx.py @@ -34,7 +34,7 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.base import Transformation -from qonnx.util.basic import get_by_name +from qonnx.util.basic import copy_metadata_props, get_by_name def extract_elem_type(elem_type: int, clip_range=None) -> Tuple[int, int, bool]: @@ -203,6 +203,8 @@ def apply(self, model: ModelWrapper) -> Tuple[ModelWrapper, bool]: rounding_mode="ROUND", # round-to-even signed=signed, ) + # Preserve metadata from all nodes being fused + copy_metadata_props(node, fused_node) model.graph.node.insert(dequant_node_index, fused_node) for node_to_remove in nodes_to_remove: model.graph.node.remove(node_to_remove) diff --git a/src/qonnx/transformation/rebalance_conv.py b/src/qonnx/transformation/rebalance_conv.py index ecb2b5e4..0107a62a 100644 --- a/src/qonnx/transformation/rebalance_conv.py +++ b/src/qonnx/transformation/rebalance_conv.py @@ -31,6 +31,7 @@ from qonnx.custom_op.registry import getCustomOp from qonnx.transformation.base import Transformation +from qonnx.util.basic import copy_metadata_props class RebalanceIm2Col(Transformation): @@ -103,6 +104,7 @@ def apply(self, model): inp_reshape_node = helper.make_node( "Reshape", [node.input[0], inp_shapedata.name], [inp_reshape_out.name] ) + copy_metadata_props(node, inp_reshape_node) graph.node.insert(running_node_index, inp_reshape_node) # rewire Im2Col input node.input[0] = inp_reshape_out.name diff --git a/src/qonnx/transformation/resize_conv_to_deconv.py b/src/qonnx/transformation/resize_conv_to_deconv.py index 0dd40972..7eda4fa7 100644 --- a/src/qonnx/transformation/resize_conv_to_deconv.py +++ b/src/qonnx/transformation/resize_conv_to_deconv.py @@ -33,7 +33,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.general.quant import quant, resolve_rounding_mode from qonnx.transformation.base import Transformation -from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name +from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name def _weight_convolution(cnv_weights: np.ndarray, scale: int) -> np.ndarray: @@ -242,6 +242,7 @@ def apply(self, model): group=group, dilations=dilation, ) + copy_metadata_props(conv, deconv_node) W_deconv_init = weight_name if weight_prod is not None: W_deconv_init = q_w_name diff --git a/src/qonnx/transformation/subpixel_to_deconv.py b/src/qonnx/transformation/subpixel_to_deconv.py index 3f330c99..73ef3f8f 100644 --- a/src/qonnx/transformation/subpixel_to_deconv.py +++ b/src/qonnx/transformation/subpixel_to_deconv.py @@ -31,7 +31,7 @@ from onnx import helper from qonnx.transformation.base import Transformation -from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name +from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name def _weight_shuffle(cnv_weights: np.ndarray, block_size: int) -> np.ndarray: @@ -197,6 +197,7 @@ def apply(self, model): group=group, dilations=dilation, ) + copy_metadata_props(n, deconv_node) W_deconv_init = weight_name if weight_prod is not None: W_deconv_init = q_w_name diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index cef4f67b..6559bd8c 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -360,3 +360,124 @@ def auto_pad_to_explicit_padding(autopad_str, idim_h, idim_w, k_h, k_w, stride_h return [pad_half_large_h, pad_half_large_w, pad_half_small_h, pad_half_small_w] else: raise Exception("Unsupported auto_pad: " + autopad_str) + + +def copy_metadata_props(source_node, target_node, mode="overwrite"): + """Copy metadata properties from source node(s) to target node. + + Parameters + ---------- + source_node : onnx.NodeProto or list of onnx.NodeProto + Source node(s) from which to copy metadata_props. If a list is provided, + metadata from all nodes will be merged into the target node. + target_node : onnx.NodeProto + Target node to which metadata_props will be copied. + mode : str, optional + Mode for handling existing metadata properties in the target node. + Options are: + - "overwrite": Existing properties in the target node will be overwritten + by those from the source node(s) if they share the same key. + - "keep_existing": Existing properties in the target node will be kept, + and only new properties from the source node(s) will be added. + Default is "overwrite". + + Returns + ------- + None + Modifies target_node in place by extending its metadata_props. + + Examples + -------- + >>> # Copy from single node + >>> copy_metadata_props(old_node, new_node) + >>> + >>> # Copy from multiple nodes (e.g., when fusing) + >>> copy_metadata_props([quant_node, dequant_node], fused_node) + """ + assert mode in ["overwrite", "keep_existing"], "Copy Metadata Mode must be either 'overwrite' or 'keep_existing'." + + # Handle both single node and list of nodes + source_nodes = source_node if isinstance(source_node, list) else [source_node] + + for node in source_nodes: + if hasattr(node, "metadata_props"): + # check for existing keys in target_node to avoid duplicates + if hasattr(target_node, "metadata_props"): + existing_keys = {prop.key for prop in target_node.metadata_props} + else: + existing_keys = set() + + for prop in node.metadata_props: + if prop.key in existing_keys: + if mode == "overwrite": + # Overwrite existing metadata property + for existing_prop in target_node.metadata_props: + if existing_prop.key == prop.key: + existing_prop.value = prop.value + break + else: + target_node.metadata_props.append(prop) + + +def get_tensor_metadata_prop(model, tensor_name, key): + """Get metadata property from a tensor (input/output/initializer/value_info). + + Args: + model: ModelWrapper instance + tensor_name: Name of the tensor + key: Metadata key to retrieve + + Returns: + str: Metadata value if found, None otherwise + """ + # Search all possible tensor locations + tensor = get_by_name(model.graph.input, tensor_name) + if tensor is None: + tensor = get_by_name(model.graph.output, tensor_name) + if tensor is None: + tensor = get_by_name(model.graph.initializer, tensor_name) + if tensor is None: + tensor = get_by_name(model.graph.value_info, tensor_name) + + if tensor is not None: + meta = get_by_name(tensor.metadata_props, key, "key") + return meta.value if meta is not None else None + return None + + +def set_tensor_metadata_prop(model, tensor_name, key, value): + """Set metadata property on a tensor (input/output/initializer/value_info). + + Args: + model: ModelWrapper instance + tensor_name: Name of the tensor + key: Metadata key to set + value: Metadata value (will be converted to string) + + Returns: + bool: True if successful, False if tensor not found + """ + import onnx + + # Search all possible tensor locations + tensor = get_by_name(model.graph.input, tensor_name) + if tensor is None: + tensor = get_by_name(model.graph.output, tensor_name) + if tensor is None: + tensor = get_by_name(model.graph.initializer, tensor_name) + if tensor is None: + tensor = get_by_name(model.graph.value_info, tensor_name) + + if tensor is not None: + meta = get_by_name(tensor.metadata_props, key, "key") + if meta is None: + # Create new metadata entry + meta = onnx.StringStringEntryProto() + meta.key = key + meta.value = str(value) + tensor.metadata_props.append(meta) + else: + # Update existing entry + meta.value = str(value) + return True + return False diff --git a/tests/custom_op/test_attr.py b/tests/custom_op/test_attr.py index d1d32546..e5ec9237 100644 --- a/tests/custom_op/test_attr.py +++ b/tests/custom_op/test_attr.py @@ -86,6 +86,8 @@ def test_attr(): """ model = oprs.parse_model(input) model = ModelWrapper(model) + + # Now getCustomOp should find it through the manual registry inst = getCustomOp(model.graph.node[0]) w_prod = inst.get_nodeattr("tensor_attr") diff --git a/tests/util/test_copy_metadata.py b/tests/util/test_copy_metadata.py new file mode 100644 index 00000000..1cc913b9 --- /dev/null +++ b/tests/util/test_copy_metadata.py @@ -0,0 +1,116 @@ +import pytest + +import onnx +import onnxscript +from onnxscript import FLOAT +from onnxscript import opset17 as op +from onnxscript import script +from onnxscript.ir.passes.common import LiftConstantsToInitializersPass + +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.transformation.gemm_to_matmul import GemmToMatMul +from qonnx.util.basic import copy_metadata_props + + +def add_metadata(key, value): + return onnx.StringStringEntryProto(key=key, value=value) + + +def test_copy_metadata_props(): + # Create source node with metadata + src_node = onnx.NodeProto(metadata_props=[add_metadata("key1", "value1"), add_metadata("key2", "value2")]) + dst_node = onnx.NodeProto() + + copy_metadata_props(src_node, dst_node) + + assert len(dst_node.metadata_props) == 2 + assert dst_node.metadata_props[0].key == "key1" + assert dst_node.metadata_props[0].value == "value1" + assert dst_node.metadata_props[1].key == "key2" + assert dst_node.metadata_props[1].value == "value2" + + +@pytest.mark.parametrize("mode", ["keep_existing", "overwrite"]) +def test_copy_metadata_props_existing_target_md(mode): + # Create source node with metadata + src_node = onnx.NodeProto(metadata_props=[add_metadata("key1", "value1"), add_metadata("key2", "value2")]) + # Create destination node with existing metadata + dst_node = onnx.NodeProto(metadata_props=[add_metadata("key1", "value3")]) + + copy_metadata_props(src_node, dst_node, mode=mode) + + assert len(dst_node.metadata_props) == 2 + assert dst_node.metadata_props[0].key == "key1" + + if mode == "keep_existing": + assert dst_node.metadata_props[0].value == "value3" # Should keep existing + elif mode == "overwrite": + assert dst_node.metadata_props[0].value == "value1" # Should be overwritten + + assert dst_node.metadata_props[1].key == "key2" + assert dst_node.metadata_props[1].value == "value2" + + +def test_copy_metadata_props_bad_mode(): + src_node = onnx.NodeProto(metadata_props=[add_metadata("key1", "value1")]) + dst_node = onnx.NodeProto() + + with pytest.raises(AssertionError): + copy_metadata_props(src_node, dst_node, mode="invalid_mode") + + +def test_copy_metadata_props_gemm2matmul(): + @script() + def MyGemm(A: FLOAT[4, 5], B: FLOAT[5, 4], C: FLOAT[4, 4]) -> FLOAT[4, 4]: + return op.Gemm(A, B, C) + + model_proto = MyGemm.to_model_proto() + gemm_node = model_proto.graph.node[0] + gemm_node.metadata_props.extend([add_metadata("key1", "value1"), add_metadata("key2", "value2")]) + + # Create Model Wrapper + mw = ModelWrapper(model_proto) + + transformed_mw = mw.transform(GemmToMatMul()) + + for node in transformed_mw.graph.node: + assert node.metadata_props[0].key == "key1" + assert node.metadata_props[0].value == "value1" + assert node.metadata_props[1].key == "key2" + assert node.metadata_props[1].value == "value2" + + +def test_copy_metadata_props_batchnorm2affine(): + @script() + def MyBatchNorm(X: FLOAT[1, 3, 4, 4]) -> FLOAT[1, 3, 4, 4]: + scale = op.Constant(value=[[1.0, 1.0, 1.0]]) + B = op.Constant(value=[[0.0, 0.0, 0.0]]) + var = op.Constant(value=[[1.0, 1.0, 1.0]]) + mean = op.Constant(value=[[0.0, 0.0, 0.0]]) + return op.BatchNormalization(X, scale, B, mean, var, epsilon=1e-5, momentum=0.9) + + # remove cast-like nodes + model_proto = onnxscript.optimizer.optimize(MyBatchNorm.to_model_proto()) + + # batchnorm_to_affine requires initializers for scale/mean/var/bias + model_ir = onnxscript.ir.serde.deserialize_model(model_proto) + pass_ = LiftConstantsToInitializersPass(lift_all_constants=True, size_limit=1) + PassResult = pass_.call(model_ir) + model_proto = onnxscript.ir.serde.serialize_model(PassResult.model) + + # Add metadata to BatchNorm node + bn_node = model_proto.graph.node[0] + bn_node.metadata_props.extend([add_metadata("key1", "value1"), add_metadata("key2", "value2")]) + + # Create Model Wrapper + mw = ModelWrapper(model_proto) + from qonnx.transformation.batchnorm_to_affine import BatchNormToAffine + + transformed_mw = mw.transform(BatchNormToAffine()) + + # Check that metadata was copied + for node in transformed_mw.graph.node: + assert node.metadata_props[0].key == "key1" + assert node.metadata_props[0].value == "value1" + assert node.metadata_props[1].key == "key2" + assert node.metadata_props[1].value == "value2"