diff --git a/synapse_net/training/supervised_training.py b/synapse_net/training/supervised_training.py index 3a2cebc0..c5cb6973 100644 --- a/synapse_net/training/supervised_training.py +++ b/synapse_net/training/supervised_training.py @@ -1,6 +1,6 @@ import os from glob import glob -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch import torch_em @@ -201,6 +201,7 @@ def supervised_training( in_channels: int = 1, out_channels: int = 2, mask_channel: bool = False, + checkpoint_path: Optional[Union[os.PathLike, str]] = None, **loader_kwargs, ): """Run supervised segmentation training. @@ -303,7 +304,7 @@ def supervised_training( loss=loss, metric=metric, ) - trainer.fit(n_iterations) + trainer.fit(n_iterations, load_from_checkpoint=checkpoint_path) def _derive_key_from_files(files, key):