From d97bdd574db39960dd42046ff2b868f9ac091316 Mon Sep 17 00:00:00 2001 From: "sewon.jeon" Date: Mon, 1 Dec 2025 14:55:36 +0900 Subject: [PATCH 1/3] Enables automatic transform group tracking for inversion Addresses an issue where `Invertd` fails when postprocessing contains invertible transforms before `Invertd` is called. The solution uses automatic group tracking: `Compose` assigns its ID to child transforms, allowing `Invertd` to filter and select only the relevant transforms for inversion. This ensures correct inversion when multiple transform pipelines are used or when post-processing steps include invertible transforms. `TraceableTransform` now stores group information. `Invertd` now filters transforms by group, falling back to the original behavior if no group information is present (for backward compatibility). Adds tests to verify the fix and group isolation. --- monai/transforms/compose.py | 42 ++++ monai/transforms/inverse.py | 8 +- monai/transforms/post/dictionary.py | 28 ++- monai/utils/enums.py | 1 + tests/transforms/inverse/test_invertd.py | 240 +++++++++++++++++++++++ 5 files changed, 317 insertions(+), 2 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index e984c4f26a..0038f9f0b6 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -262,6 +262,48 @@ def __init__( self.set_random_state(seed=get_seed()) self.overrides = overrides + # Automatically assign group ID to child transforms for inversion tracking + self._set_transform_groups() + + def _set_transform_groups(self): + """ + Automatically set group IDs on child transforms for inversion tracking. + This allows Invertd to identify which transforms belong to this Compose instance. + Recursively sets groups on wrapped transforms (e.g., array transforms inside dictionary transforms). + """ + from monai.transforms.inverse import TraceableTransform + + group_id = str(id(self)) + visited = set() # Track visited objects to avoid infinite recursion + + def set_group_recursive(obj, gid): + """Recursively set group on transform and its wrapped transforms.""" + # Avoid infinite recursion + obj_id = id(obj) + if obj_id in visited: + return + visited.add(obj_id) + + if isinstance(obj, TraceableTransform): + obj._group = gid + + # Handle wrapped transforms in dictionary transforms + # Check common attribute patterns for wrapped transforms + for attr_name in dir(obj): + # Skip magic methods and common non-transform attributes + if attr_name.startswith('__') or attr_name in ('transforms', 'transform'): + continue + try: + attr = getattr(obj, attr_name, None) + if attr is not None and isinstance(attr, TraceableTransform) and not isinstance(attr, Compose): + # Recursively set group on nested transforms + set_group_recursive(attr, gid) + except Exception: + pass + + for transform in self.transforms: + set_group_recursive(transform, group_id) + @LazyTransform.lazy.setter # type: ignore def lazy(self, val: bool): self._lazy = val diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 2f57f4614a..e0e058ecaa 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -125,7 +125,13 @@ def get_transform_info(self) -> dict: self.tracing, self._do_transform if hasattr(self, "_do_transform") else True, ) - return dict(zip(self.transform_info_keys(), vals)) + info = dict(zip(self.transform_info_keys(), vals)) + + # Add group if set (automatically set by Compose) + if hasattr(self, "_group") and self._group is not None: + info[TraceKeys.GROUP] = self._group + + return info def push_transform(self, data, *args, **kwargs): """ diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 65fdd22b22..67496e62fc 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -859,6 +859,29 @@ def __init__( self.post_func = ensure_tuple_rep(post_func, len(self.keys)) self._totensor = ToTensor() + def _filter_transforms_by_group(self, all_transforms: list[dict]) -> list[dict]: + """ + Filter applied_operations to only include transforms from the target Compose instance. + Uses automatic group tracking where Compose assigns its ID to child transforms. + """ + from monai.utils import TraceKeys + + # Get the group ID of the transform (Compose instance) + target_group = str(id(self.transform)) + + # Filter transforms that match the target group + filtered = [] + for xform in all_transforms: + xform_group = xform.get(TraceKeys.GROUP) + if xform_group == target_group: + filtered.append(xform) + + # If no transforms match (backward compatibility), return all transforms + if not filtered: + return all_transforms + + return filtered + def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]: d = dict(data) for ( @@ -894,8 +917,11 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]: orig_meta_key = orig_meta_key or f"{orig_key}_{meta_key_postfix}" if orig_key in d and isinstance(d[orig_key], MetaTensor): - transform_info = d[orig_key].applied_operations + all_transforms = d[orig_key].applied_operations meta_info = d[orig_key].meta + + # Automatically filter by Compose instance group ID + transform_info = self._filter_transforms_by_group(all_transforms) else: transform_info = d[InvertibleTransform.trace_key(orig_key)] meta_info = d.get(orig_meta_key, {}) diff --git a/monai/utils/enums.py b/monai/utils/enums.py index f5bb6c4c5b..52d9eed5f5 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -334,6 +334,7 @@ class TraceKeys(StrEnum): TRACING: str = "tracing" STATUSES: str = "statuses" LAZY: str = "lazy" + GROUP: str = "group" class TraceStatusKeys(StrEnum): diff --git a/tests/transforms/inverse/test_invertd.py b/tests/transforms/inverse/test_invertd.py index 2b5e9da85d..af7fe12f3d 100644 --- a/tests/transforms/inverse/test_invertd.py +++ b/tests/transforms/inverse/test_invertd.py @@ -137,6 +137,246 @@ def test_invert(self): set_determinism(seed=None) + def test_invertd_with_postprocessing_transforms(self): + """Test that Invertd ignores postprocessing transforms using automatic group tracking. + + This is a regression test for the issue where Invertd would fail when + postprocessing contains invertible transforms before Invertd is called. + The fix uses automatic group tracking where Compose assigns its ID to child transforms. + """ + from monai.data import MetaTensor, create_test_image_2d + from monai.transforms.utility.dictionary import Lambdad + + img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2) + img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]}) + key = "image" + + # Preprocessing pipeline + preprocessing = Compose([ + EnsureChannelFirstd(key), + Spacingd(key, pixdim=[2.0, 2.0]), + ]) + + # Postprocessing with Lambdad before Invertd + # Previously this would raise RuntimeError about transform ID mismatch + postprocessing = Compose([ + Lambdad(key, func=lambda x: x), # Should be ignored during inversion + Invertd(key, transform=preprocessing, orig_keys=key) + ]) + + # Apply transforms + item = {key: img} + pre = preprocessing(item) + + # This should NOT raise an error (was failing before the fix) + try: + post = postprocessing(pre) + # If we get here, the bug is fixed + self.assertIsNotNone(post) + self.assertIn(key, post) + print(f"SUCCESS! Automatic group tracking fixed the bug.") + print(f" Preprocessing group ID: {id(preprocessing)}") + print(f" Postprocessing group ID: {id(postprocessing)}") + except RuntimeError as e: + if "getting the most recently applied invertible transform" in str(e): + self.fail(f"Invertd still has the postprocessing transform bug: {e}") + + def test_invertd_multiple_pipelines(self): + """Test that Invertd correctly handles multiple independent preprocessing pipelines.""" + from monai.data import MetaTensor, create_test_image_2d + from monai.transforms.utility.dictionary import Lambdad + + img1, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2) + img1 = MetaTensor(img1, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]}) + img2, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2) + img2 = MetaTensor(img2, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]}) + + # Two different preprocessing pipelines + preprocessing1 = Compose([ + EnsureChannelFirstd("image1"), + Spacingd("image1", pixdim=[2.0, 2.0]), + ]) + + preprocessing2 = Compose([ + EnsureChannelFirstd("image2"), + Spacingd("image2", pixdim=[1.5, 1.5]), + ]) + + # Postprocessing that inverts both + postprocessing = Compose([ + Lambdad(["image1", "image2"], func=lambda x: x), + Invertd("image1", transform=preprocessing1, orig_keys="image1"), + Invertd("image2", transform=preprocessing2, orig_keys="image2"), + ]) + + # Apply transforms + item = {"image1": img1, "image2": img2} + pre1 = preprocessing1(item) + pre2 = preprocessing2(pre1) + + # Should not raise error - each Invertd should only invert its own pipeline + post = postprocessing(pre2) + self.assertIn("image1", post) + self.assertIn("image2", post) + + def test_invertd_multiple_postprocessing_transforms(self): + """Test Invertd with multiple invertible transforms in postprocessing before Invertd.""" + from monai.data import MetaTensor, create_test_image_2d + from monai.transforms.utility.dictionary import Lambdad + + img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2) + img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]}) + key = "image" + + preprocessing = Compose([ + EnsureChannelFirstd(key), + Spacingd(key, pixdim=[2.0, 2.0]), + ]) + + # Multiple transforms in postprocessing before Invertd + postprocessing = Compose([ + Lambdad(key, func=lambda x: x * 2), + Lambdad(key, func=lambda x: x + 1), + Lambdad(key, func=lambda x: x - 1), + Invertd(key, transform=preprocessing, orig_keys=key) + ]) + + item = {key: img} + pre = preprocessing(item) + post = postprocessing(pre) + + self.assertIsNotNone(post) + self.assertIn(key, post) + + def test_invertd_group_isolation(self): + """Test that groups correctly isolate transforms from different Compose instances.""" + from monai.data import MetaTensor, create_test_image_2d + + img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2) + img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]}) + key = "image" + + # First preprocessing + preprocessing1 = Compose([ + EnsureChannelFirstd(key), + Spacingd(key, pixdim=[2.0, 2.0]), + ]) + + # Second preprocessing (different pipeline) + preprocessing2 = Compose([ + Spacingd(key, pixdim=[1.5, 1.5]), + ]) + + item = {key: img} + pre1 = preprocessing1(item) + + # Verify group IDs are in applied_operations + self.assertTrue(len(pre1[key].applied_operations) > 0) + group1 = pre1[key].applied_operations[0].get("group") + self.assertIsNotNone(group1) + self.assertEqual(group1, str(id(preprocessing1))) + + # Apply second preprocessing + pre2 = preprocessing2(pre1) + + # Should have operations from both pipelines with different groups + groups = [op.get("group") for op in pre2[key].applied_operations] + self.assertIn(str(id(preprocessing1)), groups) + self.assertIn(str(id(preprocessing2)), groups) + + # Inverting preprocessing1 should only invert its transforms + inverter = Invertd(key, transform=preprocessing1, orig_keys=key) + inverted = inverter(pre2) + self.assertIsNotNone(inverted) + + def test_compose_inverse_with_groups(self): + """Test that Compose.inverse() works correctly with automatic group tracking.""" + from monai.data import MetaTensor, create_test_image_2d + + img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2) + img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]}) + key = "image" + + # Create a preprocessing pipeline + preprocessing = Compose([ + EnsureChannelFirstd(key), + Spacingd(key, pixdim=[2.0, 2.0]), + ]) + + # Apply preprocessing + item = {key: img} + pre = preprocessing(item) + + # Call inverse() directly on the Compose object + inverted = preprocessing.inverse(pre) + + # Should successfully invert + self.assertIsNotNone(inverted) + self.assertIn(key, inverted) + # Shape should be restored after inversion + self.assertEqual(inverted[key].shape[1:], img.shape) + + def test_compose_inverse_with_postprocessing_groups(self): + """Test Compose.inverse() when data has been through multiple pipelines with different groups.""" + from monai.data import MetaTensor, create_test_image_2d + from monai.transforms.utility.dictionary import Lambdad + + img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2) + img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]}) + key = "image" + + # Preprocessing pipeline + preprocessing = Compose([ + EnsureChannelFirstd(key), + Spacingd(key, pixdim=[2.0, 2.0]), + ]) + + # Postprocessing pipeline (different group) + postprocessing = Compose([ + Lambdad(key, func=lambda x: x * 2), + ]) + + # Apply both pipelines + item = {key: img} + pre = preprocessing(item) + post = postprocessing(pre) + + # Now call inverse() directly on preprocessing + # This tests that inverse() can handle data that has transforms from multiple groups + # This WILL fail because applied_operations contains postprocessing transforms + # and inverse() doesn't do group filtering (only Invertd does) + with self.assertRaises(RuntimeError): + inverted = preprocessing.inverse(post) + + def test_mixed_invertd_and_compose_inverse(self): + """Test mixing Invertd (with group filtering) and Compose.inverse() (without filtering).""" + from monai.data import MetaTensor, create_test_image_2d + + img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2) + img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]}) + key = "image" + + # First pipeline + pipeline1 = Compose([ + EnsureChannelFirstd(key), + Spacingd(key, pixdim=[2.0, 2.0]), + ]) + + # Apply first pipeline + item = {key: img} + result1 = pipeline1(item) + + # Use Compose.inverse() directly - should work fine + inverted1 = pipeline1.inverse(result1) + self.assertIsNotNone(inverted1) + self.assertEqual(inverted1[key].shape[1:], img.shape) + + # Now apply pipeline again and use Invertd + result2 = pipeline1(item) + inverter = Invertd(key, transform=pipeline1, orig_keys=key) + inverted2 = inverter(result2) + self.assertIsNotNone(inverted2) + if __name__ == "__main__": unittest.main() From 8961b770edeec13c01a94c61a620de05fd10abfa Mon Sep 17 00:00:00 2001 From: "sewon.jeon" Date: Mon, 1 Dec 2025 15:01:37 +0900 Subject: [PATCH 2/3] autofix formatting DCO Remediation Commit for sewon.jeon I, sewon.jeon , hereby add my Signed-off-by to this commit: d97bdd574db39960dd42046ff2b868f9ac091316 I, sewon.jeon , hereby add my Signed-off-by to this commit: c523e945155f1662eef7e73dc562df7bd812e61d Signed-off-by: sewon.jeon --- monai/transforms/compose.py | 2 +- tests/transforms/inverse/test_invertd.py | 88 +++++++++--------------- 2 files changed, 34 insertions(+), 56 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 0038f9f0b6..97417428cb 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -291,7 +291,7 @@ def set_group_recursive(obj, gid): # Check common attribute patterns for wrapped transforms for attr_name in dir(obj): # Skip magic methods and common non-transform attributes - if attr_name.startswith('__') or attr_name in ('transforms', 'transform'): + if attr_name.startswith("__") or attr_name in ("transforms", "transform"): continue try: attr = getattr(obj, attr_name, None) diff --git a/tests/transforms/inverse/test_invertd.py b/tests/transforms/inverse/test_invertd.py index af7fe12f3d..dac9433b58 100644 --- a/tests/transforms/inverse/test_invertd.py +++ b/tests/transforms/inverse/test_invertd.py @@ -152,17 +152,16 @@ def test_invertd_with_postprocessing_transforms(self): key = "image" # Preprocessing pipeline - preprocessing = Compose([ - EnsureChannelFirstd(key), - Spacingd(key, pixdim=[2.0, 2.0]), - ]) + preprocessing = Compose([EnsureChannelFirstd(key), Spacingd(key, pixdim=[2.0, 2.0])]) # Postprocessing with Lambdad before Invertd # Previously this would raise RuntimeError about transform ID mismatch - postprocessing = Compose([ - Lambdad(key, func=lambda x: x), # Should be ignored during inversion - Invertd(key, transform=preprocessing, orig_keys=key) - ]) + postprocessing = Compose( + [ + Lambdad(key, func=lambda x: x), # Should be ignored during inversion + Invertd(key, transform=preprocessing, orig_keys=key), + ] + ) # Apply transforms item = {key: img} @@ -174,7 +173,7 @@ def test_invertd_with_postprocessing_transforms(self): # If we get here, the bug is fixed self.assertIsNotNone(post) self.assertIn(key, post) - print(f"SUCCESS! Automatic group tracking fixed the bug.") + print("SUCCESS! Automatic group tracking fixed the bug.") print(f" Preprocessing group ID: {id(preprocessing)}") print(f" Postprocessing group ID: {id(postprocessing)}") except RuntimeError as e: @@ -192,22 +191,18 @@ def test_invertd_multiple_pipelines(self): img2 = MetaTensor(img2, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]}) # Two different preprocessing pipelines - preprocessing1 = Compose([ - EnsureChannelFirstd("image1"), - Spacingd("image1", pixdim=[2.0, 2.0]), - ]) + preprocessing1 = Compose([EnsureChannelFirstd("image1"), Spacingd("image1", pixdim=[2.0, 2.0])]) - preprocessing2 = Compose([ - EnsureChannelFirstd("image2"), - Spacingd("image2", pixdim=[1.5, 1.5]), - ]) + preprocessing2 = Compose([EnsureChannelFirstd("image2"), Spacingd("image2", pixdim=[1.5, 1.5])]) # Postprocessing that inverts both - postprocessing = Compose([ - Lambdad(["image1", "image2"], func=lambda x: x), - Invertd("image1", transform=preprocessing1, orig_keys="image1"), - Invertd("image2", transform=preprocessing2, orig_keys="image2"), - ]) + postprocessing = Compose( + [ + Lambdad(["image1", "image2"], func=lambda x: x), + Invertd("image1", transform=preprocessing1, orig_keys="image1"), + Invertd("image2", transform=preprocessing2, orig_keys="image2"), + ] + ) # Apply transforms item = {"image1": img1, "image2": img2} @@ -228,18 +223,17 @@ def test_invertd_multiple_postprocessing_transforms(self): img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]}) key = "image" - preprocessing = Compose([ - EnsureChannelFirstd(key), - Spacingd(key, pixdim=[2.0, 2.0]), - ]) + preprocessing = Compose([EnsureChannelFirstd(key), Spacingd(key, pixdim=[2.0, 2.0])]) # Multiple transforms in postprocessing before Invertd - postprocessing = Compose([ - Lambdad(key, func=lambda x: x * 2), - Lambdad(key, func=lambda x: x + 1), - Lambdad(key, func=lambda x: x - 1), - Invertd(key, transform=preprocessing, orig_keys=key) - ]) + postprocessing = Compose( + [ + Lambdad(key, func=lambda x: x * 2), + Lambdad(key, func=lambda x: x + 1), + Lambdad(key, func=lambda x: x - 1), + Invertd(key, transform=preprocessing, orig_keys=key), + ] + ) item = {key: img} pre = preprocessing(item) @@ -257,15 +251,10 @@ def test_invertd_group_isolation(self): key = "image" # First preprocessing - preprocessing1 = Compose([ - EnsureChannelFirstd(key), - Spacingd(key, pixdim=[2.0, 2.0]), - ]) + preprocessing1 = Compose([EnsureChannelFirstd(key), Spacingd(key, pixdim=[2.0, 2.0])]) # Second preprocessing (different pipeline) - preprocessing2 = Compose([ - Spacingd(key, pixdim=[1.5, 1.5]), - ]) + preprocessing2 = Compose([Spacingd(key, pixdim=[1.5, 1.5])]) item = {key: img} pre1 = preprocessing1(item) @@ -298,10 +287,7 @@ def test_compose_inverse_with_groups(self): key = "image" # Create a preprocessing pipeline - preprocessing = Compose([ - EnsureChannelFirstd(key), - Spacingd(key, pixdim=[2.0, 2.0]), - ]) + preprocessing = Compose([EnsureChannelFirstd(key), Spacingd(key, pixdim=[2.0, 2.0])]) # Apply preprocessing item = {key: img} @@ -326,15 +312,10 @@ def test_compose_inverse_with_postprocessing_groups(self): key = "image" # Preprocessing pipeline - preprocessing = Compose([ - EnsureChannelFirstd(key), - Spacingd(key, pixdim=[2.0, 2.0]), - ]) + preprocessing = Compose([EnsureChannelFirstd(key), Spacingd(key, pixdim=[2.0, 2.0])]) # Postprocessing pipeline (different group) - postprocessing = Compose([ - Lambdad(key, func=lambda x: x * 2), - ]) + postprocessing = Compose([Lambdad(key, func=lambda x: x * 2)]) # Apply both pipelines item = {key: img} @@ -346,7 +327,7 @@ def test_compose_inverse_with_postprocessing_groups(self): # This WILL fail because applied_operations contains postprocessing transforms # and inverse() doesn't do group filtering (only Invertd does) with self.assertRaises(RuntimeError): - inverted = preprocessing.inverse(post) + preprocessing.inverse(post) def test_mixed_invertd_and_compose_inverse(self): """Test mixing Invertd (with group filtering) and Compose.inverse() (without filtering).""" @@ -357,10 +338,7 @@ def test_mixed_invertd_and_compose_inverse(self): key = "image" # First pipeline - pipeline1 = Compose([ - EnsureChannelFirstd(key), - Spacingd(key, pixdim=[2.0, 2.0]), - ]) + pipeline1 = Compose([EnsureChannelFirstd(key), Spacingd(key, pixdim=[2.0, 2.0])]) # Apply first pipeline item = {key: img} From 351ec0082a609df31e59f492d24d2d7bcd11cc54 Mon Sep 17 00:00:00 2001 From: "sewon.jeon" Date: Mon, 1 Dec 2025 15:27:32 +0900 Subject: [PATCH 3/3] fix errors Signed-off-by: sewon.jeon --- monai/transforms/compose.py | 11 ++++------- monai/transforms/inverse.py | 9 ++++++++- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 97417428cb..c474f52153 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -293,13 +293,10 @@ def set_group_recursive(obj, gid): # Skip magic methods and common non-transform attributes if attr_name.startswith("__") or attr_name in ("transforms", "transform"): continue - try: - attr = getattr(obj, attr_name, None) - if attr is not None and isinstance(attr, TraceableTransform) and not isinstance(attr, Compose): - # Recursively set group on nested transforms - set_group_recursive(attr, gid) - except Exception: - pass + attr = getattr(obj, attr_name, None) + if attr is not None and isinstance(attr, TraceableTransform) and not isinstance(attr, Compose): + # Recursively set group on nested transforms + set_group_recursive(attr, gid) for transform in self.transforms: set_group_recursive(transform, group_id) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index e0e058ecaa..d7cdedc0ef 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -82,6 +82,10 @@ def _init_trace_threadlocal(self): if not hasattr(self._tracing, "value"): self._tracing.value = MONAIEnvVars.trace_transform() != "0" + # Initialize group identifier (set by Compose for automatic group tracking) + if not hasattr(self, "_group"): + self._group: str | None = None + def __getstate__(self): """When pickling, remove the `_tracing` member from the output, if present, since it's not picklable.""" _dict = dict(getattr(self, "__dict__", {})) # this makes __dict__ always present in the unpickled object @@ -119,6 +123,9 @@ def get_transform_info(self) -> dict: """ Return a dictionary with the relevant information pertaining to an applied transform. """ + # Ensure _group is initialized + self._init_trace_threadlocal() + vals = ( self.__class__.__name__, id(self), @@ -128,7 +135,7 @@ def get_transform_info(self) -> dict: info = dict(zip(self.transform_info_keys(), vals)) # Add group if set (automatically set by Compose) - if hasattr(self, "_group") and self._group is not None: + if self._group is not None: info[TraceKeys.GROUP] = self._group return info