
Commit 641d0c5

Add tutorial for depth-reduced U-Net model
Signed-off-by: Vishal Dave <vdave8633@gmail.com>
1 parent d388d1c commit 641d0c5

1 file changed: +321 -0 lines
@@ -0,0 +1,321 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wx5N8OB-dbDV",
"outputId": "237b8a09-14e3-4d7b-ccfa-3892d17c9f26"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"   2.7/2.7 MB 45.0 MB/s eta 0:00:00\n",
"   70.2/70.2 kB 4.8 MB/s eta 0:00:00\n"
]
}
],
"source": [
"!pip install -q monai torch-pruning\n"
]
},
{
"cell_type": "markdown",
"source": [
"# Depth Reduction of a U-Net for Medical Image Segmentation\n",
"\n",
"This tutorial demonstrates how reducing the depth of a MONAI U-Net cuts model size and computation while maintaining segmentation capability. A baseline five-level U-Net is compared against a three-level variant in terms of parameter count, output shape, and inference time.\n"
],
"metadata": {
"id": "iQoMzAlfd4ZH"
}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import numpy as np\n",
"from monai.networks.nets import UNet\n",
"import torch_pruning as tp\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-uKJxHXQd5cX",
"outputId": "da495e06-22b6-4309-8b74-69f920c226c3"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"<frozen importlib._bootstrap_external>:1301: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"torch.manual_seed(0)\n",
"np.random.seed(0)\n"
],
"metadata": {
"id": "QPj1Qficd8v4"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Toy data: random images, with a simple threshold mask as segmentation labels\n",
"images = torch.rand(4, 1, 128, 128)\n",
"labels = (images > 0.5).float()\n"
],
"metadata": {
"id": "APIhPidJd9zv"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def count_params(model):\n",
"    return sum(p.numel() for p in model.parameters())\n"
],
"metadata": {
"id": "sjrMXVgxfO8P"
},
"execution_count": 14,
"outputs": []
},
{
"cell_type": "code",
"source": [
"baseline_unet = UNet(\n",
"    spatial_dims=2,\n",
"    in_channels=1,\n",
"    out_channels=1,\n",
"    channels=(16, 32, 64, 128, 256), # 5 levels\n",
"    strides=(2, 2, 2, 2),\n",
")\n"
],
"metadata": {
"id": "s3MFPtpkfRa_"
},
"execution_count": 15,
"outputs": []
},
{
"cell_type": "code",
"source": [
"print(\"Baseline parameters:\", count_params(baseline_unet))\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9-QPnHhUfTSv",
"outputId": "c91f6ee7-8ad2-4a52-e20a-97c185e28ea0"
},
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Baseline parameters: 659993\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"reduced_unet = UNet(\n",
"    spatial_dims=2,\n",
"    in_channels=1,\n",
"    out_channels=1,\n",
"    channels=(16, 32, 64), # only 3 levels\n",
"    strides=(2, 2),\n",
")\n"
],
"metadata": {
"id": "tbf-UY9GfVVf"
},
"execution_count": 17,
"outputs": []
},
{
"cell_type": "code",
"source": [
"print(\"Depth-reduced parameters:\", count_params(reduced_unet))\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9Wi3dwetfXRw",
"outputId": "cb8a334b-ece6-45e9-c74b-0a8ec01cac92"
},
"execution_count": 18,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Depth-reduced parameters: 37429\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"baseline_out = baseline_unet(images)\n",
"reduced_out = reduced_unet(images)\n",
"\n",
"print(\"Baseline output shape:\", baseline_out.shape)\n",
"print(\"Reduced output shape:\", reduced_out.shape)\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "iDmYo0HrfZ4A",
"outputId": "6d9eb4e2-c7ab-443c-a2e1-5ec77b1b1cc2"
},
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Baseline output shape: torch.Size([4, 1, 128, 128])\n",
"Reduced output shape: torch.Size([4, 1, 128, 128])\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"baseline_params = count_params(baseline_unet)\n",
"reduced_params = count_params(reduced_unet)\n",
"\n",
"reduction = 100 * (baseline_params - reduced_params) / baseline_params\n",
"print(f\"Parameter reduction: {reduction:.2f}%\")\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qt4TAf-efc3g",
"outputId": "e5d71e43-14ac-49ba-f1a4-6d63f904275e"
},
"execution_count": 20,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Parameter reduction: 94.33%\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Discussion\n",
"\n",
"Reducing the depth of a U-Net architecture leads to a true reduction in the number of learnable parameters, unlike masking-based pruning approaches that preserve tensor shapes; the sketches in the following cells illustrate this contrast.\n",
"\n",
"Depth reduction decreases representational capacity and receptive field size, which may affect segmentation accuracy. For many medical imaging applications, especially those targeting edge devices or real-time inference, this trade-off is acceptable.\n",
"\n",
"This approach provides a simple, stable, and reproducible strategy for building lightweight medical imaging models.\n"
],
"metadata": {
"id": "f1Csi3BJfe3w"
}
},
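{
"cell_type": "markdown",
"source": [
"For contrast, here is a minimal, hedged sketch of masking-based pruning using PyTorch's built-in `torch.nn.utils.prune` (it is not part of this tutorial's method and was not executed above): an `ln_structured` mask zeroes half of one convolution's output channels, yet every tensor keeps its shape, so the parameter count does not drop.\n"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import torch.nn.utils.prune as prune\n",
"\n",
"# Sketch: masking-based structured pruning on a throwaway copy of the\n",
"# baseline architecture. The L2-norm mask zeroes 50% of the output\n",
"# channels (dim=0) of the first conv, but shapes are preserved, so\n",
"# count_params reports the same number as before.\n",
"masked_unet = UNet(\n",
"    spatial_dims=2,\n",
"    in_channels=1,\n",
"    out_channels=1,\n",
"    channels=(16, 32, 64, 128, 256),\n",
"    strides=(2, 2, 2, 2),\n",
")\n",
"first_conv = next(m for m in masked_unet.modules() if isinstance(m, torch.nn.Conv2d))\n",
"prune.ln_structured(first_conv, name=\"weight\", amount=0.5, n=2, dim=0)\n",
"print(\"Params after masking:\", count_params(masked_unet))\n"
],
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"The `torch-pruning` package installed at the top can instead remove channels physically, which, like depth reduction, yields a real parameter reduction. The cell below is an illustrative sketch only: `tp.pruner.MagnitudePruner` and the `pruning_ratio` argument follow recent `torch-pruning` releases (older releases used `ch_sparsity`), and dependency tracing of a MONAI U-Net may need adjustment across library versions.\n"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Sketch: physically remove roughly half the channels with torch-pruning.\n",
"prune_unet = UNet(\n",
"    spatial_dims=2,\n",
"    in_channels=1,\n",
"    out_channels=1,\n",
"    channels=(16, 32, 64, 128, 256),\n",
"    strides=(2, 2, 2, 2),\n",
")\n",
"example_inputs = torch.rand(1, 1, 128, 128)\n",
"\n",
"# Keep the layers that produce the 1-channel output intact so the\n",
"# network still emits a single segmentation map.\n",
"ignored = [\n",
"    m for m in prune_unet.modules()\n",
"    if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d))\n",
"    and m.out_channels == 1\n",
"]\n",
"\n",
"pruner = tp.pruner.MagnitudePruner(\n",
"    prune_unet,\n",
"    example_inputs,\n",
"    importance=tp.importance.MagnitudeImportance(p=2),\n",
"    pruning_ratio=0.5,\n",
"    ignored_layers=ignored,\n",
")\n",
"pruner.step()\n",
"print(\"Channel-pruned parameters:\", count_params(prune_unet))\n"
],
"metadata": {},
"execution_count": null,
"outputs": []
},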
{
"cell_type": "code",
"source": [
"import time\n",
"\n",
"# Simple wall-clock timing; the tensors here live on the CPU, so this\n",
"# measures CPU inference (CUDA timing would also need synchronization).\n",
"def inference_time(model, x, runs=20):\n",
"    model.eval()\n",
"    with torch.no_grad():\n",
"        start = time.time()\n",
"        for _ in range(runs):\n",
"            _ = model(x)\n",
"        end = time.time()\n",
"    return (end - start) / runs\n",
"\n",
"print(\"Baseline avg inference time:\", inference_time(baseline_unet, images))\n",
"print(\"Reduced avg inference time:\", inference_time(reduced_unet, images))\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2cGD6jwsfvli",
"outputId": "b7a66f85-4247-4548-dc2c-782c08425e93"
},
"execution_count": 21,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Baseline avg inference time: 0.01868886947631836\n",
"Reduced avg inference time: 0.013485324382781983\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## When to Use Depth-Reduced Models\n",
"\n",
"Depth-reduced architectures are well suited for:\n",
"- Edge and embedded medical devices\n",
"- Real-time or near-real-time inference\n",
"- Rapid prototyping and experimentation\n",
"- Scenarios with limited memory or compute budgets (see the footprint estimate after this section)\n",
"\n",
"For tasks requiring fine-grained segmentation accuracy, deeper architectures may still be preferable.\n"
],
"metadata": {
"id": "eBW8lxpTfyLC"
}
},
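{
"cell_type": "markdown",
"source": [
"To make the memory point above concrete, the cell below is a small sketch (not executed above) estimating each model's parameter footprint in megabytes from `numel()` and `element_size()`. It counts learnable parameters only; activations, buffers, and optimizer state are excluded.\n"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"def param_megabytes(model):\n",
"    # Bytes held by learnable parameters (float32 here), reported in MB.\n",
"    return sum(p.numel() * p.element_size() for p in model.parameters()) / 1e6\n",
"\n",
"print(f\"Baseline parameter footprint: {param_megabytes(baseline_unet):.2f} MB\")\n",
"print(f\"Reduced parameter footprint: {param_megabytes(reduced_unet):.2f} MB\")\n"
],
"metadata": {},
"execution_count": null,
"outputs": []
}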
]
}
