diff --git a/docs/source/tutorials/pruning/unet_depth_reduction.ipynb b/docs/source/tutorials/pruning/unet_depth_reduction.ipynb
new file mode 100644
index 0000000000..a45adc86a1
--- /dev/null
+++ b/docs/source/tutorials/pruning/unet_depth_reduction.ipynb
@@ -0,0 +1,321 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": [],
+      "gpuType": "T4"
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    },
+    "accelerator": "GPU"
+  },
+  "cells": [
+    {
+      "cell_type": "code",
+      "execution_count": 1,
+      "metadata": {
+        "id": "wx5N8OB-dbDV"
+      },
+      "outputs": [],
+      "source": [
+        "!pip install -q monai torch-pruning\n"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# Depth Reduction of a U-Net for Medical Image Segmentation\n",
+        "\n",
+        "This tutorial demonstrates how reducing the depth of a MONAI U-Net (dropping encoder/decoder levels) shrinks model size and computation while keeping the input/output interface intact. Depth reduction is a coarse form of structural simplification: unlike masking-based pruning, it removes parameters outright rather than zeroing them behind unchanged tensor shapes.\n"
+      ],
+      "metadata": {
+        "id": "iQoMzAlfd4ZH"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import torch\n",
+        "import numpy as np\n",
+        "from monai.networks.nets import UNet\n",
+        "import torch_pruning as tp  # used below for a MACs/params comparison\n"
+      ],
+      "metadata": {
+        "id": "-uKJxHXQd5cX"
+      },
+      "execution_count": 2,
+      "outputs": []
+    },
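+    {
+      "cell_type": "markdown",
+      "source": [
+        "For reproducibility it can help to record the library versions the notebook was run against. The optional cell below is a small addition; it assumes only the standard `__version__` attributes on `torch` and `monai`.\n"
+      ],
+      "metadata": {}
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import monai  # imported here only to report its version\n",
+        "\n",
+        "print(\"torch:\", torch.__version__)\n",
+        "print(\"monai:\", monai.__version__)\n"
+      ],
+      "metadata": {},
+      "execution_count": null,
+      "outputs": []
+    },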
":1301: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "torch.manual_seed(0)\n", + "np.random.seed(0)\n" + ], + "metadata": { + "id": "QPj1Qficd8v4" + }, + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "images = torch.rand(4, 1, 128, 128)\n", + "labels = (images > 0.5).float()\n" + ], + "metadata": { + "id": "APIhPidJd9zv" + }, + "execution_count": 4, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "def count_params(model):\n", + " return sum(p.numel() for p in model.parameters())\n" + ], + "metadata": { + "id": "sjrMXVgxfO8P" + }, + "execution_count": 14, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "baseline_unet = UNet(\n", + " spatial_dims=2,\n", + " in_channels=1,\n", + " out_channels=1,\n", + " channels=(16, 32, 64, 128, 256), # 5 levels\n", + " strides=(2, 2, 2, 2),\n", + ")\n" + ], + "metadata": { + "id": "s3MFPtpkfRa_" + }, + "execution_count": 15, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(\"Baseline parameters:\", count_params(baseline_unet))\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9-QPnHhUfTSv", + "outputId": "c91f6ee7-8ad2-4a52-e20a-97c185e28ea0" + }, + "execution_count": 16, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Baseline parameters: 659993\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "reduced_unet = UNet(\n", + " spatial_dims=2,\n", + " in_channels=1,\n", + " out_channels=1,\n", + " channels=(16, 32, 64), # only 3 levels\n", + " strides=(2, 2),\n", + ")\n" + ], + "metadata": { + "id": "tbf-UY9GfVVf" + }, + "execution_count": 17, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(\"Depth-reduced parameters:\", count_params(reduced_unet))\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9Wi3dwetfXRw", + "outputId": "cb8a334b-ece6-45e9-c74b-0a8ec01cac92" + }, + "execution_count": 18, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Depth-reduced parameters: 37429\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "baseline_out = baseline_unet(images)\n", + "reduced_out = reduced_unet(images)\n", + "\n", + "print(\"Baseline output shape:\", baseline_out.shape)\n", + "print(\"Reduced output shape:\", reduced_out.shape)\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "iDmYo0HrfZ4A", + "outputId": "6d9eb4e2-c7ab-443c-a2e1-5ec77b1b1cc2" + }, + "execution_count": 19, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Baseline output shape: torch.Size([4, 1, 128, 128])\n", + "Reduced output shape: torch.Size([4, 1, 128, 128])\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "baseline_params = count_params(baseline_unet)\n", + "reduced_params = count_params(reduced_unet)\n", + "\n", + "reduction = 100 * (baseline_params - reduced_params) / baseline_params\n", + "print(f\"Parameter reduction: {reduction:.2f}%\")\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "qt4TAf-efc3g", + "outputId": "e5d71e43-14ac-49ba-f1a4-6d63f904275e" + }, + "execution_count": 20, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Parameter reduction: 94.33%\n" + ] + } + ] 
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Discussion\n",
+        "\n",
+        "Reducing the depth of a U-Net yields a true reduction in the number of learnable parameters, unlike masking-based pruning, which zeroes weights but preserves tensor shapes. A minimal masking sketch at the end of this notebook makes the contrast concrete.\n",
+        "\n",
+        "Depth reduction also shrinks representational capacity and receptive field size, which may hurt segmentation accuracy. However, for many medical imaging applications, especially those targeting edge devices or real-time inference, this trade-off is acceptable and often desirable.\n",
+        "\n",
+        "This approach provides a simple, stable, and reproducible strategy for building lightweight medical imaging models.\n"
+      ],
+      "metadata": {
+        "id": "f1Csi3BJfe3w"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import time\n",
+        "\n",
+        "def inference_time(model, x, runs=20):\n",
+        "    model.eval()\n",
+        "    with torch.no_grad():\n",
+        "        _ = model(x)  # warm-up run, excluded from timing\n",
+        "        start = time.perf_counter()\n",
+        "        for _ in range(runs):\n",
+        "            _ = model(x)\n",
+        "        end = time.perf_counter()\n",
+        "    return (end - start) / runs\n",
+        "\n",
+        "print(\"Baseline avg inference time:\", inference_time(baseline_unet, images))\n",
+        "print(\"Reduced avg inference time:\", inference_time(reduced_unet, images))\n"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "2cGD6jwsfvli",
+        "outputId": "b7a66f85-4247-4548-dc2c-782c08425e93"
+      },
+      "execution_count": 21,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Baseline avg inference time: 0.01868886947631836\n",
+            "Reduced avg inference time: 0.013485324382781983\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## When to Use Depth-Reduced Models\n",
+        "\n",
+        "Depth-reduced architectures are well suited for:\n",
+        "- Edge and embedded medical devices\n",
+        "- Real-time or near-real-time inference\n",
+        "- Rapid prototyping and experimentation\n",
+        "- Scenarios with limited memory or compute budgets\n",
+        "\n",
+        "For tasks requiring fine-grained segmentation accuracy, deeper architectures may still be preferable.\n"
+      ],
+      "metadata": {
+        "id": "eBW8lxpTfyLC"
+      }
+    },
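+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Contrast: Masking-Based Pruning\n",
+        "\n",
+        "To make the earlier discussion concrete, the minimal sketch below applies PyTorch's built-in unstructured magnitude pruning (`torch.nn.utils.prune`) to a copy of the baseline U-Net. Half of the convolution weights are zeroed, yet the parameter count and tensor shapes are unchanged. This cell is illustrative only and not part of the depth-reduction workflow.\n"
+      ],
+      "metadata": {}
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import copy\n",
+        "import torch.nn.utils.prune as prune\n",
+        "\n",
+        "# Work on a copy so the baseline model stays intact.\n",
+        "masked_unet = copy.deepcopy(baseline_unet)\n",
+        "\n",
+        "# Zero the 50% smallest-magnitude weights in every Conv2d layer.\n",
+        "for module in masked_unet.modules():\n",
+        "    if isinstance(module, torch.nn.Conv2d):\n",
+        "        prune.l1_unstructured(module, name=\"weight\", amount=0.5)\n",
+        "        prune.remove(module, \"weight\")  # bake the mask into the weights\n",
+        "\n",
+        "# The parameter count is unchanged; only the values are zeroed.\n",
+        "zeroed = sum((p == 0).sum().item() for p in masked_unet.parameters())\n",
+        "print(\"Masked parameters:\", count_params(masked_unet))\n",
+        "print(\"Zeroed entries:\", zeroed)\n"
+      ],
+      "metadata": {},
+      "execution_count": null,
+      "outputs": []
+    }
+  ]
+}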