diff --git a/synapse_net/training/supervised_training.py b/synapse_net/training/supervised_training.py index 3a2cebc0..1c463238 100644 --- a/synapse_net/training/supervised_training.py +++ b/synapse_net/training/supervised_training.py @@ -201,6 +201,7 @@ def supervised_training( in_channels: int = 1, out_channels: int = 2, mask_channel: bool = False, + checkpoint_path: Optional[str] = None, **loader_kwargs, ): """Run supervised segmentation training. @@ -243,6 +244,7 @@ def supervised_training( out_channels: The number of output channels of the UNet. mask_channel: Whether the last channels in the labels should be used for masking the loss. This can be used to implement more complex masking operations and is not compatible with `ignore_label`. + checkpoint_path: Path to the directory where 'best.pt' resides; continue training this model. loader_kwargs: Additional keyword arguments for the dataloader. """ train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size, @@ -265,6 +267,9 @@ def supervised_training( model = get_2d_model(out_channels=out_channels, in_channels=in_channels) else: model = get_3d_model(out_channels=out_channels, in_channels=in_channels) + + if checkpoint_path: + model = torch_em.util.load_model(checkpoint=checkpoint_path) loss, metric = None, None # No ignore label -> we can use default loss.