From 465b651fbfcf346569674002799879436c3a176a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 16 Nov 2025 23:29:17 +0100 Subject: [PATCH 01/24] use instantiate model to load data and model from ckpt --- chebai/trainer/CustomTrainer.py | 69 +++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 20 deletions(-) diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py index f7fbce26..e84aad4c 100644 --- a/chebai/trainer/CustomTrainer.py +++ b/chebai/trainer/CustomTrainer.py @@ -1,4 +1,5 @@ import logging +import os from typing import Any, List, Optional, Tuple import pandas as pd @@ -6,13 +7,17 @@ from lightning import LightningModule, Trainer from lightning.fabric.utilities.data import _set_sampler_epoch from lightning.fabric.utilities.types import _PATH +from lightning.pytorch.cli import instantiate_module from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.loops.fit_loop import _FitLoop from lightning.pytorch.trainer import call from torch.nn.utils.rnn import pad_sequence +from build.lib.chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.loggers.custom import CustomLogger -from chebai.preprocessing.reader import CLS_TOKEN, ChemDataReader +from chebai.models.base import ChebaiBaseNet +from chebai.preprocessing.datasets.base import XYBaseDataModule +from chebai.preprocessing.reader import CLS_TOKEN log = logging.getLogger(__name__) @@ -44,6 +49,7 @@ def __init__(self, *args, **kwargs): # use custom fit loop (https://lightning.ai/docs/pytorch/LTS/extensions/loops.html#overriding-the-default-loops) self.fit_loop = LoadDataLaterFitLoop(self, self.min_epochs, self.max_epochs) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]: """ @@ -76,12 +82,10 @@ def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]: def predict_from_file( self, - model: LightningModule, 
checkpoint_path: _PATH, input_path: _PATH, save_to: _PATH = "predictions.csv", classes_path: Optional[_PATH] = None, - **kwargs, ) -> None: """ Loads a model from a checkpoint and makes predictions on input data from a file. @@ -93,20 +97,21 @@ def predict_from_file( save_to: Path to save the predictions CSV file. classes_path: Optional path to a file containing class names (if no class names are provided, columns will be numbered). """ - loaded_model = model.__class__.load_from_checkpoint(checkpoint_path) with open(input_path, "r") as input: smiles_strings = [inp.strip() for inp in input.readlines()] - loaded_model.eval() - predictions = self._predict_smiles(loaded_model, smiles_strings) - predictions_df = pd.DataFrame(predictions.detach().cpu().numpy()) - if classes_path is not None: - with open(classes_path, "r") as f: - predictions_df.columns = [cls.strip() for cls in f.readlines()] - predictions_df.index = smiles_strings - predictions_df.to_csv(save_to) + self._predict_smiles( + checkpoint_path, + smiles=smiles_strings, + classes_path=classes_path, + save_to=save_to, + ) def _predict_smiles( - self, model: LightningModule, smiles: List[str] + self, + checkpoint_path: _PATH, + smiles: List[str], + classes_path: Optional[_PATH] = None, + save_to: _PATH = "predictions.csv", ) -> torch.Tensor: """ Predicts the output for a list of SMILES strings using the model. @@ -118,22 +123,47 @@ def _predict_smiles( Returns: A tensor containing the predictions. 
""" - reader = ChemDataReader() - parsed_smiles = [reader._read_data(s) for s in smiles] + ckpt_file = torch.load( + checkpoint_path, map_location=self.device, weights_only=False + ) + + ckpt_file["datamodule_hyper_parameters"].pop("splits_file_path") + dm: XYBaseDataModule = instantiate_module( + XYBaseDataModule, ckpt_file["datamodule_hyper_parameters"] + ) + + model: ChebaiBaseNet = instantiate_module( + ChebaiBaseNet, ckpt_file["hyper_parameters"] + ) + model.to(self.device) + model.eval() + + parsed_smiles = [dm.reader._read_data(s) for s in smiles] x = pad_sequence( - [torch.tensor(a, device=model.device) for a in parsed_smiles], + [torch.tensor(a, device=self.device) for a in parsed_smiles], batch_first=True, ) cls_tokens = ( - torch.ones(x.shape[0], dtype=torch.int, device=model.device).unsqueeze(-1) + torch.ones(x.shape[0], dtype=torch.int, device=self.device).unsqueeze(-1) * CLS_TOKEN ) features = torch.cat((cls_tokens, x), dim=1) model_output = model({"features": features}) preds = torch.sigmoid(model_output["logits"]) - print(preds.shape) - return preds + predictions_df = pd.DataFrame(preds.detach().cpu().numpy()) + + def _add_class_columns(class_file_path: _PATH): + with open(class_file_path, "r") as f: + predictions_df.columns = [cls.strip() for cls in f.readlines()] + + if classes_path is not None: + _add_class_columns(classes_path) + elif os.path.exists(dm.classes_txt_file_path): + _add_class_columns(dm.classes_txt_file_path) + + predictions_df.index = smiles + predictions_df.to_csv(save_to) @property def log_dir(self) -> Optional[str]: @@ -157,7 +187,6 @@ def log_dir(self) -> Optional[str]: class LoadDataLaterFitLoop(_FitLoop): - def on_advance_start(self) -> None: """Calls the hook ``on_train_epoch_start`` **before** the dataloaders are setup. This is necessary so that the dataloaders can get information from the model. 
For example: The on_train_epoch_start From 2acd166b87f03df14839fa9c5c381b1669bc6e84 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 16 Nov 2025 23:29:38 +0100 Subject: [PATCH 02/24] update readme --- README.md | 2 +- chebai/preprocessing/datasets/base.py | 38 ++++++++++++++++++--------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index eeecd714..2555c0a6 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ python -m chebai fit --config=[path-to-your-tox21-config] --trainer.callbacks=co ### Predicting classes given SMILES strings ``` -python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path={path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]] +python3 -m chebai predict_from_file --checkpoint_path=[path-to-model] --input_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]] ``` The input files should contain a list of line-separated SMILES strings. This generates a CSV file that contains the one row for each SMILES string and one column for each class. 
diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 68254007..f1357c88 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -96,9 +96,9 @@ def __init__( self.prediction_kind = prediction_kind self.data_limit = data_limit self.label_filter = label_filter - assert (balance_after_filter is not None) or ( - self.label_filter is None - ), "Filter balancing requires a filter" + assert (balance_after_filter is not None) or (self.label_filter is None), ( + "Filter balancing requires a filter" + ) self.balance_after_filter = balance_after_filter self.num_workers = num_workers self.persistent_workers: bool = bool(persistent_workers) @@ -108,13 +108,13 @@ def __init__( self.use_inner_cross_validation = ( inner_k_folds > 1 ) # only use cv if there are at least 2 folds - assert ( - fold_index is None or self.use_inner_cross_validation is not None - ), "fold_index can only be set if cross validation is used" + assert fold_index is None or self.use_inner_cross_validation is not None, ( + "fold_index can only be set if cross validation is used" + ) if fold_index is not None and self.inner_k_folds is not None: - assert ( - fold_index < self.inner_k_folds - ), "fold_index can't be larger than the total number of folds" + assert fold_index < self.inner_k_folds, ( + "fold_index can't be larger than the total number of folds" + ) self.fold_index = fold_index self._base_dir = base_dir self.n_token_limit = n_token_limit @@ -137,9 +137,9 @@ def num_of_labels(self): @property def feature_vector_size(self): - assert ( - self._feature_vector_size is not None - ), "size of feature vector must be set" + assert self._feature_vector_size is not None, ( + "size of feature vector must be set" + ) return self._feature_vector_size @property @@ -1190,7 +1190,8 @@ def _retrieve_splits_from_csv(self) -> None: print(f"Applying label filter from {self.apply_label_filter}...") with open(self.apply_label_filter, "r") 
as f: label_filter = [line.strip() for line in f] - with open(os.path.join(self.processed_dir_main, "classes.txt"), "r") as cf: + + with open(self.classes_txt_file_path, "r") as cf: classes = [line.strip() for line in cf] # reorder labels old_labels = np.stack(df_data["labels"]) @@ -1315,3 +1316,14 @@ def processed_file_names_dict(self) -> dict: if self.n_token_limit is not None: return {"data": f"data_maxlen{self.n_token_limit}.pt"} return {"data": "data.pt"} + + @property + def classes_txt_file_path(self) -> str: + """ + Returns the filename for the classes text file. + + Returns: + str: The filename for the classes text file. + """ + # This property also used in custom trainer `chebai/trainer/CustomTrainer.py` + return os.path.join(self.processed_dir_main, "classes.txt") From cfbf392f13773c0ce92719b24a83ef231cea2e32 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 17 Nov 2025 14:50:16 +0100 Subject: [PATCH 03/24] set no grad for predict --- chebai/trainer/CustomTrainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py index e84aad4c..6f1a542c 100644 --- a/chebai/trainer/CustomTrainer.py +++ b/chebai/trainer/CustomTrainer.py @@ -80,6 +80,7 @@ def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]: else: return key, value + @torch.no_grad() def predict_from_file( self, checkpoint_path: _PATH, @@ -106,6 +107,7 @@ def predict_from_file( save_to=save_to, ) + @torch.no_grad() def _predict_smiles( self, checkpoint_path: _PATH, From 82b365ca31698da387f4077c13ea345d9572aa04 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 26 Nov 2025 16:36:15 +0100 Subject: [PATCH 04/24] predict pipeline in dm and lm --- chebai/models/base.py | 8 ++++++- chebai/preprocessing/datasets/base.py | 31 ++++++++++++++++++++++----- chebai/trainer/CustomTrainer.py | 4 +++- 3 files changed, 36 insertions(+), 7 deletions(-) diff --git a/chebai/models/base.py b/chebai/models/base.py index 
7653f13c..808ea59e 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -232,7 +232,13 @@ def predict_step( Returns: Dict[str, Union[torch.Tensor, Any]]: The result of the prediction step. """ - return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False) + assert isinstance(batch, XYData) + batch = batch.to(self.device) + data = self._process_batch(batch, batch_idx) + labels = data["labels"] + model_output = self(data, **data.get("model_kwargs", dict())) + pr, _ = self._get_prediction_and_labels(data, labels, model_output) + return pr def _execute( self, diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index f1357c88..e2df794d 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -339,8 +339,14 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: for d in tqdm.tqdm(self._load_dict(path), total=lines) if d["features"] is not None ] + + return self._filter_to_token_limit(data) + + def _filter_to_token_limit( + self, data: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: # filter for missing features in resulting data, keep features length below token limit - data = [ + return [ val for val in data if val["features"] is not None @@ -349,8 +355,6 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: ) ] - return data - def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: """ Returns the train DataLoader. @@ -400,10 +404,13 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader] Returns: Union[DataLoader, List[DataLoader]]: A DataLoader object for test data. """ + return self.dataloader("test", shuffle=False, **kwargs) def predict_dataloader( - self, *args, **kwargs + self, + smiles_list: List[str], + **kwargs, ) -> Union[DataLoader, List[DataLoader]]: """ Returns the predict DataLoader. 
@@ -415,7 +422,21 @@ def predict_dataloader( Returns: Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data. """ - return self.dataloader(self.prediction_kind, shuffle=False, **kwargs) + + data = [ + self.reader.to_data( + {"id": f"smiles_{idx}", "features": smiles, "labels": None} + ) + for idx, smiles in enumerate(smiles_list) + ] + data = self._filter_to_token_limit(data) + + return DataLoader( + data, + collate_fn=self.reader.collator, + batch_size=self.batch_size, + **kwargs, + ) def prepare_data(self, *args, **kwargs) -> None: if self._prepare_data_flag != 1: diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py index 6f1a542c..acd468f2 100644 --- a/chebai/trainer/CustomTrainer.py +++ b/chebai/trainer/CustomTrainer.py @@ -4,7 +4,7 @@ import pandas as pd import torch -from lightning import LightningModule, Trainer +from lightning import Trainer from lightning.fabric.utilities.data import _set_sampler_epoch from lightning.fabric.utilities.types import _PATH from lightning.pytorch.cli import instantiate_module @@ -87,6 +87,7 @@ def predict_from_file( input_path: _PATH, save_to: _PATH = "predictions.csv", classes_path: Optional[_PATH] = None, + **kwargs, ) -> None: """ Loads a model from a checkpoint and makes predictions on input data from a file. @@ -114,6 +115,7 @@ def _predict_smiles( smiles: List[str], classes_path: Optional[_PATH] = None, save_to: _PATH = "predictions.csv", + **kwargs, ) -> torch.Tensor: """ Predicts the output for a list of SMILES strings using the model. 
From fa6f1b521d05d38261973f602ff53232ec5eabb3 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 26 Nov 2025 18:16:36 +0100 Subject: [PATCH 05/24] there is no need that predict func must depend on trainer --- chebai/result/prediction.py | 111 ++++++++++++++++++++++++++++++++ chebai/trainer/CustomTrainer.py | 107 ++++-------------------------- 2 files changed, 122 insertions(+), 96 deletions(-) create mode 100644 chebai/result/prediction.py diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py new file mode 100644 index 00000000..cb3e1415 --- /dev/null +++ b/chebai/result/prediction.py @@ -0,0 +1,111 @@ +import os +from typing import List, Optional + +import pandas as pd +import torch +from jsonargparse import CLI +from lightning.fabric.utilities.types import _PATH +from lightning.pytorch.cli import instantiate_module +from torch.utils.data import DataLoader + +from chebai.models.base import ChebaiBaseNet +from chebai.preprocessing.datasets.base import XYBaseDataModule + + +class Predictor: + def __init__(self): + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + super().__init__() + + def predict_from_file( + self, + checkpoint_path: _PATH, + input_path: _PATH, + save_to: _PATH = "predictions.csv", + classes_path: Optional[_PATH] = None, + ) -> None: + """ + Loads a model from a checkpoint and makes predictions on input data from a file. + + Args: + model: The model to use for predictions. + checkpoint_path: Path to the model checkpoint. + input_path: Path to the input file containing SMILES strings. + save_to: Path to save the predictions CSV file. + classes_path: Optional path to a file containing class names (if no class names are provided, columns will be numbered). 
+ """ + with open(input_path, "r") as input: + smiles_strings = [inp.strip() for inp in input.readlines()] + + self.predict_smiles( + checkpoint_path, + smiles=smiles_strings, + classes_path=classes_path, + save_to=save_to, + ) + + @torch.inference_mode() + def predict_smiles( + self, + checkpoint_path: _PATH, + smiles: List[str], + classes_path: Optional[_PATH] = None, + save_to: Optional[_PATH] = None, + **kwargs, + ) -> torch.Tensor | None: + """ + Predicts the output for a list of SMILES strings using the model. + + Args: + model: The model to use for predictions. + smiles: A list of SMILES strings. + + Returns: + A tensor containing the predictions. + """ + ckpt_file = torch.load( + checkpoint_path, map_location=self.device, weights_only=False + ) + + ckpt_file["datamodule_hyper_parameters"].pop("splits_file_path") + dm: XYBaseDataModule = instantiate_module( + XYBaseDataModule, ckpt_file["datamodule_hyper_parameters"] + ) + print(f"Loaded datamodule class: {dm.__class__.__name__}") + + pred_dl: DataLoader = dm.predict_dataloader(smiles_list=smiles) + + model: ChebaiBaseNet = instantiate_module( + ChebaiBaseNet, ckpt_file["hyper_parameters"] + ) + model.eval() + # model = torch.compile(model) + + print(f"Loaded model class: {model.__class__.__name__}") + + preds = [] + for batch_idx, batch in enumerate(pred_dl): + preds.append(model.predict_step(batch, batch_idx)) + + if not save_to: + # If no save path is provided, return the predictions tensor + return torch.cat(preds) + + predictions_df = pd.DataFrame(torch.cat(preds).detach().cpu().numpy()) + + def _add_class_columns(class_file_path: _PATH): + with open(class_file_path, "r") as f: + predictions_df.columns = [cls.strip() for cls in f.readlines()] + + if classes_path is not None: + _add_class_columns(classes_path) + elif os.path.exists(dm.classes_txt_file_path): + _add_class_columns(dm.classes_txt_file_path) + + predictions_df.index = smiles + predictions_df.to_csv(save_to) + + +if __name__ == "__main__": + 
# python chebai/result/prediction.py predict_from_file --help + CLI(Predictor, as_positional=False) diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py index acd468f2..11ade921 100644 --- a/chebai/trainer/CustomTrainer.py +++ b/chebai/trainer/CustomTrainer.py @@ -1,23 +1,14 @@ import logging -import os -from typing import Any, List, Optional, Tuple +from typing import Any, Optional, Tuple -import pandas as pd import torch from lightning import Trainer from lightning.fabric.utilities.data import _set_sampler_epoch -from lightning.fabric.utilities.types import _PATH -from lightning.pytorch.cli import instantiate_module from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.loops.fit_loop import _FitLoop from lightning.pytorch.trainer import call -from torch.nn.utils.rnn import pad_sequence -from build.lib.chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.loggers.custom import CustomLogger -from chebai.models.base import ChebaiBaseNet -from chebai.preprocessing.datasets.base import XYBaseDataModule -from chebai.preprocessing.reader import CLS_TOKEN log = logging.getLogger(__name__) @@ -80,94 +71,18 @@ def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]: else: return key, value - @torch.no_grad() - def predict_from_file( + def predict( self, - checkpoint_path: _PATH, - input_path: _PATH, - save_to: _PATH = "predictions.csv", - classes_path: Optional[_PATH] = None, - **kwargs, - ) -> None: - """ - Loads a model from a checkpoint and makes predictions on input data from a file. - - Args: - model: The model to use for predictions. - checkpoint_path: Path to the model checkpoint. - input_path: Path to the input file containing SMILES strings. - save_to: Path to save the predictions CSV file. - classes_path: Optional path to a file containing class names (if no class names are provided, columns will be numbered). 
- """ - with open(input_path, "r") as input: - smiles_strings = [inp.strip() for inp in input.readlines()] - self._predict_smiles( - checkpoint_path, - smiles=smiles_strings, - classes_path=classes_path, - save_to=save_to, - ) - - @torch.no_grad() - def _predict_smiles( - self, - checkpoint_path: _PATH, - smiles: List[str], - classes_path: Optional[_PATH] = None, - save_to: _PATH = "predictions.csv", - **kwargs, - ) -> torch.Tensor: - """ - Predicts the output for a list of SMILES strings using the model. - - Args: - model: The model to use for predictions. - smiles: A list of SMILES strings. - - Returns: - A tensor containing the predictions. - """ - ckpt_file = torch.load( - checkpoint_path, map_location=self.device, weights_only=False - ) - - ckpt_file["datamodule_hyper_parameters"].pop("splits_file_path") - dm: XYBaseDataModule = instantiate_module( - XYBaseDataModule, ckpt_file["datamodule_hyper_parameters"] - ) - - model: ChebaiBaseNet = instantiate_module( - ChebaiBaseNet, ckpt_file["hyper_parameters"] - ) - model.to(self.device) - model.eval() - - parsed_smiles = [dm.reader._read_data(s) for s in smiles] - x = pad_sequence( - [torch.tensor(a, device=self.device) for a in parsed_smiles], - batch_first=True, - ) - cls_tokens = ( - torch.ones(x.shape[0], dtype=torch.int, device=self.device).unsqueeze(-1) - * CLS_TOKEN + model=None, + dataloaders=None, + datamodule=None, + return_predictions=None, + ckpt_path=None, + ): + raise NotImplementedError( + "CustomTrainer.predict is not implemented." + "Use `Prediction.predict_from_file` or `Prediction.predict_smiles` from `chebai/result/prediction.py` instead." 
) - features = torch.cat((cls_tokens, x), dim=1) - model_output = model({"features": features}) - preds = torch.sigmoid(model_output["logits"]) - - predictions_df = pd.DataFrame(preds.detach().cpu().numpy()) - - def _add_class_columns(class_file_path: _PATH): - with open(class_file_path, "r") as f: - predictions_df.columns = [cls.strip() for cls in f.readlines()] - - if classes_path is not None: - _add_class_columns(classes_path) - elif os.path.exists(dm.classes_txt_file_path): - _add_class_columns(dm.classes_txt_file_path) - - predictions_df.index = smiles - predictions_df.to_csv(save_to) @property def log_dir(self) -> Optional[str]: From ae47608eb26918d497ef2cd60e750c4c2d585782 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 27 Nov 2025 13:14:18 +0100 Subject: [PATCH 06/24] model hparams for data predict pipeline and vice versa --- README.md | 2 +- chebai/preprocessing/datasets/base.py | 53 +++++++++++++++++---------- chebai/result/prediction.py | 23 ++++++------ 3 files changed, 47 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 2555c0a6..713a9d42 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ The `classes_path` is the path to the dataset's `raw/classes.txt` file that cont You can evaluate a model trained on the ontology extension task in one of two ways: ### 1. Using the Jupyter Notebook -An example notebook is provided at `tutorials/eval_model_basic.ipynb`. +An example notebook is provided at `tutorials/eval_model_basic.ipynb`. - Load your finetuned model and run the evaluation cells to compute metrics on the test set. ### 2. 
Using the Lightning CLI diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index e2df794d..5e3064b3 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -96,9 +96,9 @@ def __init__( self.prediction_kind = prediction_kind self.data_limit = data_limit self.label_filter = label_filter - assert (balance_after_filter is not None) or (self.label_filter is None), ( - "Filter balancing requires a filter" - ) + assert (balance_after_filter is not None) or ( + self.label_filter is None + ), "Filter balancing requires a filter" self.balance_after_filter = balance_after_filter self.num_workers = num_workers self.persistent_workers: bool = bool(persistent_workers) @@ -108,13 +108,13 @@ def __init__( self.use_inner_cross_validation = ( inner_k_folds > 1 ) # only use cv if there are at least 2 folds - assert fold_index is None or self.use_inner_cross_validation is not None, ( - "fold_index can only be set if cross validation is used" - ) + assert ( + fold_index is None or self.use_inner_cross_validation is not None + ), "fold_index can only be set if cross validation is used" if fold_index is not None and self.inner_k_folds is not None: - assert fold_index < self.inner_k_folds, ( - "fold_index can't be larger than the total number of folds" - ) + assert ( + fold_index < self.inner_k_folds + ), "fold_index can't be larger than the total number of folds" self.fold_index = fold_index self._base_dir = base_dir self.n_token_limit = n_token_limit @@ -137,9 +137,9 @@ def num_of_labels(self): @property def feature_vector_size(self): - assert self._feature_vector_size is not None, ( - "size of feature vector must be set" - ) + assert ( + self._feature_vector_size is not None + ), "size of feature vector must be set" return self._feature_vector_size @property @@ -410,6 +410,7 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader] def predict_dataloader( self, smiles_list: 
List[str], + model_hparams: Optional[dict] = None, **kwargs, ) -> Union[DataLoader, List[DataLoader]]: """ @@ -423,6 +424,26 @@ def predict_dataloader( Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data. """ + data = self._process_input_for_prediction(smiles_list, model_hparams) + return DataLoader( + data, + collate_fn=self.reader.collator, + batch_size=self.batch_size, + **kwargs, + ) + + def _process_input_for_prediction( + self, smiles_list: list[str], model_hparams: Optional[dict] = None + ) -> list: + """ + Process input data for prediction. + + Args: + smiles_list (List[str]): List of SMILES strings. + + Returns: + List[Dict[str, Any]]: Processed input data. + """ data = [ self.reader.to_data( {"id": f"smiles_{idx}", "features": smiles, "labels": None} @@ -430,13 +451,7 @@ def predict_dataloader( for idx, smiles in enumerate(smiles_list) ] data = self._filter_to_token_limit(data) - - return DataLoader( - data, - collate_fn=self.reader.collator, - batch_size=self.batch_size, - **kwargs, - ) + return data def prepare_data(self, *args, **kwargs) -> None: if self._prepare_data_flag != 1: diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index cb3e1415..250d68b6 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -67,25 +67,26 @@ def predict_smiles( checkpoint_path, map_location=self.device, weights_only=False ) - ckpt_file["datamodule_hyper_parameters"].pop("splits_file_path") - dm: XYBaseDataModule = instantiate_module( - XYBaseDataModule, ckpt_file["datamodule_hyper_parameters"] - ) + dm_hparams = ckpt_file["datamodule_hyper_parameters"] + dm_hparams.pop("splits_file_path") + dm: XYBaseDataModule = instantiate_module(XYBaseDataModule, dm_hparams) print(f"Loaded datamodule class: {dm.__class__.__name__}") - pred_dl: DataLoader = dm.predict_dataloader(smiles_list=smiles) - - model: ChebaiBaseNet = instantiate_module( - ChebaiBaseNet, ckpt_file["hyper_parameters"] - ) + model_hparams = 
ckpt_file["hyper_parameters"] + model: ChebaiBaseNet = instantiate_module(ChebaiBaseNet, model_hparams) model.eval() # model = torch.compile(model) - print(f"Loaded model class: {model.__class__.__name__}") + # For certain data prediction piplines, we may need model hyperparameters + pred_dl: DataLoader = dm.predict_dataloader( + smiles_list=smiles, model_hparams=model_hparams + ) + preds = [] for batch_idx, batch in enumerate(pred_dl): - preds.append(model.predict_step(batch, batch_idx)) + # For certain model prediction pipelines, we may need data module hyperparameters + preds.append(model.predict_step(batch, batch_idx, dm_hparams=dm_hparams)) if not save_to: # If no save path is provided, return the predictions tensor From 517a5a2061927f95a6e2ceb4661d1948e3d1defd Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 27 Nov 2025 20:19:55 +0100 Subject: [PATCH 07/24] fix reader ident error --- chebai/preprocessing/reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 22b91a0e..7e41510c 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -51,7 +51,7 @@ def _get_raw_label(self, row: Dict[str, Any]) -> Any: def _get_raw_id(self, row: Dict[str, Any]) -> Any: """Get raw ID from the row.""" - return row.get("ident", row["features"]) + return row.get("ident", row["id"]) def _get_raw_group(self, row: Dict[str, Any]) -> Any: """Get raw group from the row.""" From 40491e55f97d94d75e2f987ffd99f4f68e02f616 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 27 Nov 2025 20:23:37 +0100 Subject: [PATCH 08/24] fix label None error --- chebai/models/base.py | 4 +++- chebai/result/prediction.py | 32 ++++++++++++++++++++++++-------- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/chebai/models/base.py b/chebai/models/base.py index 808ea59e..c6c347a4 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -235,8 +235,10 @@ def 
predict_step( assert isinstance(batch, XYData) batch = batch.to(self.device) data = self._process_batch(batch, batch_idx) - labels = data["labels"] model_output = self(data, **data.get("model_kwargs", dict())) + + # Dummy labels to avoid errors in _get_prediction_and_labels + labels = torch.zeros((len(batch), self.out_dim)).to(self.device) pr, _ = self._get_prediction_and_labels(data, labels, model_output) return pr diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index 250d68b6..6f8e41c1 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -20,21 +20,25 @@ def __init__(self): def predict_from_file( self, checkpoint_path: _PATH, - input_path: _PATH, + smiles_file_path: _PATH, save_to: _PATH = "predictions.csv", classes_path: Optional[_PATH] = None, + batch_size: Optional[int] = None, ) -> None: """ Loads a model from a checkpoint and makes predictions on input data from a file. Args: - model: The model to use for predictions. checkpoint_path: Path to the model checkpoint. - input_path: Path to the input file containing SMILES strings. + smiles_file_path: Path to the input file containing SMILES strings. save_to: Path to save the predictions CSV file. - classes_path: Optional path to a file containing class names (if no class names are provided, columns will be numbered). + classes_path: Optional path to a file containing class names: + if no class names are provided, code will try to get the class path + from the datamodule, else the columns will be numbered. + batch_size: Optional batch size for the DataLoader. If not provided, + the default from the datamodule will be used. 
""" - with open(input_path, "r") as input: + with open(smiles_file_path, "r") as input: smiles_strings = [inp.strip() for inp in input.readlines()] self.predict_smiles( @@ -42,6 +46,7 @@ def predict_from_file( smiles=smiles_strings, classes_path=classes_path, save_to=save_to, + batch_size=batch_size, ) @torch.inference_mode() @@ -51,16 +56,24 @@ def predict_smiles( smiles: List[str], classes_path: Optional[_PATH] = None, save_to: Optional[_PATH] = None, + batch_size: Optional[int] = None, **kwargs, ) -> torch.Tensor | None: """ Predicts the output for a list of SMILES strings using the model. Args: - model: The model to use for predictions. + checkpoint_path: Path to the model checkpoint. smiles: A list of SMILES strings. - - Returns: + classes_path: Optional path to a file containing class names. If no class + names are provided, code will try to get the class path from the datamodule, + else the columns will be numbered. + save_to: Optional path to save the predictions CSV file. If not provided, + predictions will be returned as a tensor. + batch_size: Optional batch size for the DataLoader. If not provided, the default + from the datamodule will be used. + + Returns: (if save_to is None) A tensor containing the predictions. 
""" ckpt_file = torch.load( @@ -71,10 +84,13 @@ def predict_smiles( dm_hparams.pop("splits_file_path") dm: XYBaseDataModule = instantiate_module(XYBaseDataModule, dm_hparams) print(f"Loaded datamodule class: {dm.__class__.__name__}") + if batch_size is not None: + dm.batch_size = batch_size model_hparams = ckpt_file["hyper_parameters"] model: ChebaiBaseNet = instantiate_module(ChebaiBaseNet, model_hparams) model.eval() + # TODO: Enable torch.compile when supported # model = torch.compile(model) print(f"Loaded model class: {model.__class__.__name__}") From 7b7e48f7bb581b87f89c162b0b3d4931a134872b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 27 Nov 2025 20:50:16 +0100 Subject: [PATCH 09/24] fix cli predict_from_file error --- chebai/cli.py | 1 - 1 file changed, 1 deletion(-) diff --git a/chebai/cli.py b/chebai/cli.py index 96262447..de48f615 100644 --- a/chebai/cli.py +++ b/chebai/cli.py @@ -83,7 +83,6 @@ def subcommands() -> Dict[str, Set[str]]: "validate": {"model", "dataloaders", "datamodule"}, "test": {"model", "dataloaders", "datamodule"}, "predict": {"model", "dataloaders", "datamodule"}, - "predict_from_file": {"model"}, } From 6a383177eff88ba768076bd094c660aeeba27102 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 27 Nov 2025 20:50:32 +0100 Subject: [PATCH 10/24] update readme for new prediction method --- README.md | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 713a9d42..7af8aad4 100644 --- a/README.md +++ b/README.md @@ -70,11 +70,19 @@ python -m chebai fit --config=[path-to-your-tox21-config] --trainer.callbacks=co ### Predicting classes given SMILES strings ``` -python3 -m chebai predict_from_file --checkpoint_path=[path-to-model] --input_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]] +python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-to-model] 
----smiles_file_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]] ``` -The input files should contain a list of line-separated SMILES strings. This generates a CSV file that contains the -one row for each SMILES string and one column for each class. -The `classes_path` is the path to the dataset's `raw/classes.txt` file that contains the relationship between model output and ChEBI-IDs. + +* **`--checkpoint_path`**: Path to the Lightning checkpoint file (must end with `.ckpt`). + +* **`--smiles_file_path`**: Path to a text file containing one SMILES string per line. + +* **`--save_to`** *(optional)*: Predictions will be saved to this path as a CSV file. The CSV will contain one row per SMILES string and one column per predicted class. + +* **`--classes_path`** *(optional)*: Path to the dataset’s `raw/classes.txt` file, which maps model output indices to ChEBI IDs. + + * If provided, the CSV columns will be named using the ChEBI IDs. + * If omitted, the script will try to locate the file automatically. If it cannot be located, the columns will be numbered sequentially. ## Evaluation From 31b12dbf2ce85b685d2ee8bd721bf4f97e995db6 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 27 Nov 2025 23:26:10 +0100 Subject: [PATCH 11/24] Revert "fix reader ident error" This reverts commit 517a5a2061927f95a6e2ceb4661d1948e3d1defd. 
--- chebai/preprocessing/reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 7e41510c..22b91a0e 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -51,7 +51,7 @@ def _get_raw_label(self, row: Dict[str, Any]) -> Any: def _get_raw_id(self, row: Dict[str, Any]) -> Any: """Get raw ID from the row.""" - return row.get("ident", row["id"]) + return row.get("ident", row["features"]) def _get_raw_group(self, row: Dict[str, Any]) -> Any: """Get raw group from the row.""" From c6e8b6137c6bc0ca0e97481587361a57d3ccbcda Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 28 Nov 2025 15:42:48 +0100 Subject: [PATCH 12/24] modify pred logic to store model and dm as instance var --- chebai/result/prediction.py | 140 ++++++++++++++++++++---------------- 1 file changed, 78 insertions(+), 62 deletions(-) diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index 6f8e41c1..b84e59ad 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -13,116 +13,132 @@ class Predictor: - def __init__(self): + def __init__(self, checkpoint_path: _PATH, batch_size: Optional[int] = None): + """Initializes the Predictor with a model loaded from the checkpoint. + + Args: + checkpoint_path: Path to the model checkpoint. + batch_size: Optional batch size for the DataLoader. If not provided, + the default from the datamodule will be used. 
+ """ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - super().__init__() + ckpt_file = torch.load( + checkpoint_path, map_location=self.device, weights_only=False + ) + + self._dm_hparams = ckpt_file["datamodule_hyper_parameters"] + self._dm_hparams.pop("splits_file_path") + self._dm: XYBaseDataModule = instantiate_module( + XYBaseDataModule, self._dm_hparams + ) + print(f"Loaded datamodule class: {self._dm.__class__.__name__}") + if batch_size is not None and int(batch_size) > 0: + self._dm.batch_size = int(batch_size) + + self._model_hparams = ckpt_file["hyper_parameters"] + self._model: ChebaiBaseNet = instantiate_module( + ChebaiBaseNet, self._model_hparams + ) + self._model.eval() + # TODO: Enable torch.compile when supported + # model = torch.compile(model) + print(f"Loaded model class: {self._model.__class__.__name__}") def predict_from_file( self, - checkpoint_path: _PATH, smiles_file_path: _PATH, save_to: _PATH = "predictions.csv", classes_path: Optional[_PATH] = None, - batch_size: Optional[int] = None, ) -> None: """ Loads a model from a checkpoint and makes predictions on input data from a file. Args: - checkpoint_path: Path to the model checkpoint. smiles_file_path: Path to the input file containing SMILES strings. save_to: Path to save the predictions CSV file. classes_path: Optional path to a file containing class names: if no class names are provided, code will try to get the class path from the datamodule, else the columns will be numbered. - batch_size: Optional batch size for the DataLoader. If not provided, - the default from the datamodule will be used. 
""" with open(smiles_file_path, "r") as input: smiles_strings = [inp.strip() for inp in input.readlines()] - self.predict_smiles( - checkpoint_path, + preds: torch.Tensor = self.predict_smiles( smiles=smiles_strings, classes_path=classes_path, save_to=save_to, - batch_size=batch_size, ) + predictions_df = pd.DataFrame(torch.cat(preds).detach().cpu().numpy()) + + def _add_class_columns(class_file_path: _PATH): + with open(class_file_path, "r") as f: + predictions_df.columns = [cls.strip() for cls in f.readlines()] + + if classes_path is not None: + _add_class_columns(classes_path) + elif os.path.exists(self._dm.classes_txt_file_path): + _add_class_columns(self._dm.classes_txt_file_path) + + predictions_df.index = smiles_strings + predictions_df.to_csv(save_to) + @torch.inference_mode() def predict_smiles( self, - checkpoint_path: _PATH, smiles: List[str], - classes_path: Optional[_PATH] = None, - save_to: Optional[_PATH] = None, - batch_size: Optional[int] = None, - **kwargs, - ) -> torch.Tensor | None: + ) -> torch.Tensor: """ Predicts the output for a list of SMILES strings using the model. Args: - checkpoint_path: Path to the model checkpoint. smiles: A list of SMILES strings. - classes_path: Optional path to a file containing class names. If no class - names are provided, code will try to get the class path from the datamodule, - else the columns will be numbered. - save_to: Optional path to save the predictions CSV file. If not provided, - predictions will be returned as a tensor. - batch_size: Optional batch size for the DataLoader. If not provided, the default - from the datamodule will be used. - - Returns: (if save_to is None) + + Returns: A tensor containing the predictions. 
""" - ckpt_file = torch.load( - checkpoint_path, map_location=self.device, weights_only=False - ) - - dm_hparams = ckpt_file["datamodule_hyper_parameters"] - dm_hparams.pop("splits_file_path") - dm: XYBaseDataModule = instantiate_module(XYBaseDataModule, dm_hparams) - print(f"Loaded datamodule class: {dm.__class__.__name__}") - if batch_size is not None: - dm.batch_size = batch_size - - model_hparams = ckpt_file["hyper_parameters"] - model: ChebaiBaseNet = instantiate_module(ChebaiBaseNet, model_hparams) - model.eval() - # TODO: Enable torch.compile when supported - # model = torch.compile(model) - print(f"Loaded model class: {model.__class__.__name__}") - # For certain data prediction piplines, we may need model hyperparameters - pred_dl: DataLoader = dm.predict_dataloader( - smiles_list=smiles, model_hparams=model_hparams + pred_dl: DataLoader = self._dm.predict_dataloader( + smiles_list=smiles, model_hparams=self._model_hparams ) preds = [] for batch_idx, batch in enumerate(pred_dl): # For certain model prediction pipelines, we may need data module hyperparameters - preds.append(model.predict_step(batch, batch_idx, dm_hparams=dm_hparams)) + preds.append( + self._model.predict_step(batch, batch_idx, dm_hparams=self._dm_hparams) + ) - if not save_to: - # If no save path is provided, return the predictions tensor - return torch.cat(preds) + return torch.cat(preds) - predictions_df = pd.DataFrame(torch.cat(preds).detach().cpu().numpy()) - - def _add_class_columns(class_file_path: _PATH): - with open(class_file_path, "r") as f: - predictions_df.columns = [cls.strip() for cls in f.readlines()] - if classes_path is not None: - _add_class_columns(classes_path) - elif os.path.exists(dm.classes_txt_file_path): - _add_class_columns(dm.classes_txt_file_path) +class MainPredictor: + @staticmethod + def predict_from_file( + checkpoint_path: _PATH, + smiles_file_path: _PATH, + save_to: _PATH = "predictions.csv", + classes_path: Optional[_PATH] = None, + batch_size: 
Optional[int] = None, + ) -> None: + predictor = Predictor(checkpoint_path, batch_size) + predictor.predict_from_file( + smiles_file_path, + save_to, + classes_path, + ) - predictions_df.index = smiles - predictions_df.to_csv(save_to) + @staticmethod + def predict_smiles( + checkpoint_path: _PATH, + smiles: List[str], + batch_size: Optional[int] = None, + ) -> torch.Tensor: + predictor = Predictor(checkpoint_path, batch_size) + return predictor.predict_smiles(smiles) if __name__ == "__main__": # python chebai/result/prediction.py predict_from_file --help - CLI(Predictor, as_positional=False) + # python chebai/result/prediction.py predict_smiles --help + CLI(MainPredictor, as_positional=False) From 63670ddada766d84c706752319841e2b7d5d4c60 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 6 Dec 2025 19:32:14 +0100 Subject: [PATCH 13/24] fix for unwanted args to predict_smiles --- chebai/result/prediction.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index b84e59ad..c558d9c7 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -27,7 +27,7 @@ def __init__(self, checkpoint_path: _PATH, batch_size: Optional[int] = None): ) self._dm_hparams = ckpt_file["datamodule_hyper_parameters"] - self._dm_hparams.pop("splits_file_path") + # self._dm_hparams.pop("splits_file_path") self._dm: XYBaseDataModule = instantiate_module( XYBaseDataModule, self._dm_hparams ) @@ -63,11 +63,7 @@ def predict_from_file( with open(smiles_file_path, "r") as input: smiles_strings = [inp.strip() for inp in input.readlines()] - preds: torch.Tensor = self.predict_smiles( - smiles=smiles_strings, - classes_path=classes_path, - save_to=save_to, - ) + preds: torch.Tensor = self.predict_smiles(smiles=smiles_strings) predictions_df = pd.DataFrame(torch.cat(preds).detach().cpu().numpy()) From d906ad404ecac72fe1236193de462391f9da607b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 6 Dec 
2025 20:22:18 +0100 Subject: [PATCH 14/24] avoid non_null_labels key in loss kwargs --- chebai/preprocessing/datasets/base.py | 5 ++++- chebai/result/prediction.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 5e3064b3..ee60de99 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -444,9 +444,12 @@ def _process_input_for_prediction( Returns: List[Dict[str, Any]]: Processed input data. """ + # Add dummy labels because the collate function requires them. + # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`, + # which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty. data = [ self.reader.to_data( - {"id": f"smiles_{idx}", "features": smiles, "labels": None} + {"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]} ) for idx, smiles in enumerate(smiles_list) ] diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index c558d9c7..a0a01050 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -65,7 +65,7 @@ def predict_from_file( preds: torch.Tensor = self.predict_smiles(smiles=smiles_strings) - predictions_df = pd.DataFrame(torch.cat(preds).detach().cpu().numpy()) + predictions_df = pd.DataFrame(preds.detach().cpu().numpy()) def _add_class_columns(class_file_path: _PATH): with open(class_file_path, "r") as f: From ba5884a996ed663ff2784d7dd3779c697df1540f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 6 Dec 2025 20:37:58 +0100 Subject: [PATCH 15/24] compile model --- chebai/result/prediction.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index a0a01050..212a95e9 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -13,36 +13,47 @@ class 
Predictor: - def __init__(self, checkpoint_path: _PATH, batch_size: Optional[int] = None): + def __init__( + self, + checkpoint_path: _PATH, + batch_size: Optional[int] = None, + compile_model: bool = True, + ): """Initializes the Predictor with a model loaded from the checkpoint. Args: checkpoint_path: Path to the model checkpoint. batch_size: Optional batch size for the DataLoader. If not provided, the default from the datamodule will be used. + compile_model: Whether to compile the model using torch.compile. Default is True. """ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ckpt_file = torch.load( checkpoint_path, map_location=self.device, weights_only=False ) + print("-" * 50) + print(f"For Loaded checkpoint from: {checkpoint_path}") + print("Below are the modules loaded from the checkpoint:") self._dm_hparams = ckpt_file["datamodule_hyper_parameters"] # self._dm_hparams.pop("splits_file_path") self._dm: XYBaseDataModule = instantiate_module( XYBaseDataModule, self._dm_hparams ) - print(f"Loaded datamodule class: {self._dm.__class__.__name__}") if batch_size is not None and int(batch_size) > 0: self._dm.batch_size = int(batch_size) + print("*" * 10, f"Loaded datamodule class: {self._dm.__class__.__name__}") self._model_hparams = ckpt_file["hyper_parameters"] self._model: ChebaiBaseNet = instantiate_module( ChebaiBaseNet, self._model_hparams ) + print("*" * 10, f"Loaded model class: {self._model.__class__.__name__}") + + if compile_model: + self._model = torch.compile(self._model) self._model.eval() - # TODO: Enable torch.compile when supported - # model = torch.compile(model) - print(f"Loaded model class: {self._model.__class__.__name__}") + print("-" * 50) def predict_from_file( self, From 5a17f722df61ce6e51b24c475a002e4cfbc75ffc Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 6 Dec 2025 20:43:48 +0100 Subject: [PATCH 16/24] revert the comment line for splits file path --- chebai/result/prediction.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index 212a95e9..e904c57d 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -36,7 +36,7 @@ def __init__( print("Below are the modules loaded from the checkpoint:") self._dm_hparams = ckpt_file["datamodule_hyper_parameters"] - # self._dm_hparams.pop("splits_file_path") + self._dm_hparams.pop("splits_file_path") self._dm: XYBaseDataModule = instantiate_module( XYBaseDataModule, self._dm_hparams ) From bd59d5924f1de6fae1d4a2d54e2cf6a4bb3a67a1 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 13 Dec 2025 00:16:42 +0100 Subject: [PATCH 17/24] handle augment electra and old ckpt files --- chebai/result/prediction.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index e904c57d..a3b82364 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -37,6 +37,13 @@ def __init__( self._dm_hparams = ckpt_file["datamodule_hyper_parameters"] self._dm_hparams.pop("splits_file_path") + self._dm_hparams.pop("augment_smiles", None) + self._dm_hparams.pop("aug_smiles_variations", None) + assert "_class_path" in self._dm_hparams, ( + "Datamodule hyperparameters must include a '_class_path' key.\n" + "Hence, either the checkpoint is corrupted or " + "it was not saved properly with latest lightning version" + ) self._dm: XYBaseDataModule = instantiate_module( XYBaseDataModule, self._dm_hparams ) From 6b8ae0992e96a4ae77756521b9d5b1c54877d093 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 13 Dec 2025 19:58:48 +0100 Subject: [PATCH 18/24] remove unnec device --- chebai/trainer/CustomTrainer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py index 11ade921..b9e4b0f3 100644 --- a/chebai/trainer/CustomTrainer.py +++ b/chebai/trainer/CustomTrainer.py @@ -1,7 +1,6 @@ import logging from typing import Any, 
Optional, Tuple -import torch from lightning import Trainer from lightning.fabric.utilities.data import _set_sampler_epoch from lightning.pytorch.loggers import WandbLogger @@ -40,7 +39,6 @@ def __init__(self, *args, **kwargs): # use custom fit loop (https://lightning.ai/docs/pytorch/LTS/extensions/loops.html#overriding-the-default-loops) self.fit_loop = LoadDataLaterFitLoop(self, self.min_epochs, self.max_epochs) - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]: """ From c0293959238b038165bb005e377c428373a46583 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 15 Dec 2025 15:07:22 +0100 Subject: [PATCH 19/24] raise error for invalid smiles and return None --- chebai/preprocessing/reader.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 22b91a0e..d63671f7 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -203,10 +203,13 @@ def _read_data(self, raw_data: str) -> List[int]: print(f"RDKit failed to process {raw_data}") print(f"\t{e}") try: + mol = Chem.MolFromSmiles(raw_data.strip()) + if mol is None: + raise ValueError(f"Invalid SMILES: {raw_data}") return [self._get_token_index(v[1]) for v in _tokenize(raw_data)] except ValueError as e: print(f"could not process {raw_data}") - print(f"\t{e}") + print(f"\tError: {e}") return None From 676107936d557c70c55b55e8c33481915e1b1a70 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 15 Dec 2025 22:48:05 +0100 Subject: [PATCH 20/24] rectify test as to return None for invalid strings --- tests/unit/readers/testChemDataReader.py | 32 +++++++++++++----------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/tests/unit/readers/testChemDataReader.py b/tests/unit/readers/testChemDataReader.py index ec018f00..9d322f27 100644 --- a/tests/unit/readers/testChemDataReader.py +++ 
b/tests/unit/readers/testChemDataReader.py @@ -42,19 +42,22 @@ def test_read_data(self) -> None: """ Test the _read_data method with a SMILES string to ensure it correctly tokenizes the string. """ - raw_data = "CC(=O)NC1[Mg-2]" + raw_data = "CC(=O)NC1CC1[Mg-2]" # Expected output as per the tokens already in the cache, and ")" getting added to it. expected_output: List[int] = [ EMBEDDING_OFFSET + 0, # C EMBEDDING_OFFSET + 0, # C - EMBEDDING_OFFSET + 5, # = - EMBEDDING_OFFSET + 3, # O - EMBEDDING_OFFSET + 1, # N - EMBEDDING_OFFSET + len(self.reader.cache), # ( - EMBEDDING_OFFSET + 2, # C + EMBEDDING_OFFSET + 5, # ( + EMBEDDING_OFFSET + 3, # = + EMBEDDING_OFFSET + 1, # O + EMBEDDING_OFFSET + len(self.reader.cache), # ) - new token + EMBEDDING_OFFSET + 2, # N EMBEDDING_OFFSET + 0, # C EMBEDDING_OFFSET + 4, # 1 - EMBEDDING_OFFSET + len(self.reader.cache) + 1, # [Mg-2] + EMBEDDING_OFFSET + 0, # C + EMBEDDING_OFFSET + 0, # C + EMBEDDING_OFFSET + 4, # 1 + EMBEDDING_OFFSET + len(self.reader.cache) + 1, # [Mg-2] - new token ] result = self.reader._read_data(raw_data) self.assertEqual( @@ -99,13 +102,14 @@ def test_read_data_with_invalid_input(self) -> None: Test the _read_data method with an invalid input. 
The invalid token should prompt a return value None """ - raw_data = "%INVALID%" - - result = self.reader._read_data(raw_data) - self.assertIsNone( - result, - "The output for invalid token '%INVALID%' should be None.", - ) + # see https://github.com/ChEB-AI/python-chebai/issues/137 + raw_datas = ["%INVALID%", "ADADAD", "ADASDAD", "CC(=O)NC1[Mg-2]"] + for raw_data in raw_datas: + result = self.reader._read_data(raw_data) + self.assertIsNone( + result, + f"The output for invalid token '{raw_data}' should be None.", + ) @patch("builtins.open", new_callable=mock_open) def test_finish_method_for_new_tokens(self, mock_file: mock_open) -> None: From d5e362b7a812da012235d52cbf29b4686263a51b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 16 Dec 2025 21:32:19 +0100 Subject: [PATCH 21/24] pin rdkit version - https://github.com/ChEB-AI/python-chebai/issues/83 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index eb75643d..4ba71f8f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "torch", "transformers", "pysmiles==1.1.2", - "rdkit", + "rdkit==2024.3.6", "lightning==2.5.1", ] From 5d302d0c03687c70c187afd9608a4db8f55aa1a0 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 16 Dec 2025 11:53:24 +0100 Subject: [PATCH 22/24] handle None returns for invalid smiles --- chebai/preprocessing/datasets/base.py | 57 ++++++++++++++++++--------- chebai/result/prediction.py | 15 +++++-- 2 files changed, 49 insertions(+), 23 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index ee60de99..abda9471 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -412,49 +412,68 @@ def predict_dataloader( smiles_list: List[str], model_hparams: Optional[dict] = None, **kwargs, - ) -> Union[DataLoader, List[DataLoader]]: + ) -> tuple[DataLoader, list[int]]: """ Returns the predict DataLoader. 
Args: - *args: Additional positional arguments (unused). + smiles_list (List[str]): List of SMILES strings to predict. + model_hparams (Optional[dict]): Model hyperparameters. + Some prediction pre-processing pipelines may require these. **kwargs: Additional keyword arguments, passed to dataloader(). Returns: - Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data. + tuple[DataLoader, list[int]]: A DataLoader object for prediction data and a list of valid indices. """ - data = self._process_input_for_prediction(smiles_list, model_hparams) - return DataLoader( - data, - collate_fn=self.reader.collator, - batch_size=self.batch_size, - **kwargs, + data, valid_indices = self._process_input_for_prediction( + smiles_list, model_hparams + ) + return ( + DataLoader( + data, + collate_fn=self.reader.collator, + batch_size=self.batch_size, + **kwargs, + ), + valid_indices, ) def _process_input_for_prediction( self, smiles_list: list[str], model_hparams: Optional[dict] = None - ) -> list: + ) -> tuple[list, list]: """ Process input data for prediction. Args: smiles_list (List[str]): List of SMILES strings. + model_hparams (Optional[dict]): Model hyperparameters. + Some prediction pre-processing pipelines may require these. Returns: - List[Dict[str, Any]]: Processed input data. + tuple[list, list]: Processed input data and valid indices. """ + data, valid_indices = [], [] + for idx, smiles in enumerate(smiles_list): + result = self._preprocess_smiles_for_pred(idx, smiles, model_hparams) + if result is None or result["features"] is None: + continue + data.append(result) + valid_indices.append(idx) + + data = self._filter_to_token_limit(data) + return data, valid_indices + + def _preprocess_smiles_for_pred( + self, idx, smiles: str, model_hparams: Optional[dict] = None + ) -> dict: + """Preprocess prediction data.""" # Add dummy labels because the collate function requires them. 
# Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`, # which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty. - data = [ - self.reader.to_data( - {"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]} - ) - for idx, smiles in enumerate(smiles_list) - ] - data = self._filter_to_token_limit(data) - return data + return self.reader.to_data( + {"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]} + ) def prepare_data(self, *args, **kwargs) -> None: if self._prepare_data_flag != 1: diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index a3b82364..ad5775a1 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -6,7 +6,6 @@ from jsonargparse import CLI from lightning.fabric.utilities.types import _PATH from lightning.pytorch.cli import instantiate_module -from torch.utils.data import DataLoader from chebai.models.base import ChebaiBaseNet from chebai.preprocessing.datasets.base import XYBaseDataModule @@ -101,7 +100,7 @@ def _add_class_columns(class_file_path: _PATH): def predict_smiles( self, smiles: List[str], - ) -> torch.Tensor: + ) -> list[torch.Tensor | None]: """ Predicts the output for a list of SMILES strings using the model. @@ -112,7 +111,7 @@ def predict_smiles( A tensor containing the predictions. 
""" # For certain data prediction piplines, we may need model hyperparameters - pred_dl: DataLoader = self._dm.predict_dataloader( + pred_dl, valid_indices = self._dm.predict_dataloader( smiles_list=smiles, model_hparams=self._model_hparams ) @@ -122,8 +121,16 @@ def predict_smiles( preds.append( self._model.predict_step(batch, batch_idx, dm_hparams=self._dm_hparams) ) + preds = torch.cat(preds) + + # Initialize output with None + output: list[torch.Tensor | None] = [None] * len(smiles) + + # Scatter predictions back + for pred, idx in zip(preds, valid_indices): + output[idx] = pred - return torch.cat(preds) + return output class MainPredictor: From e86f03ae8d5869bee3bd715691478710c4ed1258 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 18 Dec 2025 13:45:21 +0100 Subject: [PATCH 23/24] remove instanstior key from hparams as its causing unnecessary error --- chebai/result/prediction.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/chebai/result/prediction.py b/chebai/result/prediction.py index ad5775a1..1410a6c5 100644 --- a/chebai/result/prediction.py +++ b/chebai/result/prediction.py @@ -30,6 +30,15 @@ def __init__( ckpt_file = torch.load( checkpoint_path, map_location=self.device, weights_only=False ) + assert ( + "_class_path" in ckpt_file["datamodule_hyper_parameters"] + and "_class_path" in ckpt_file["hyper_parameters"] + ), ( + "Datamodule and Model hyperparameters must include a '_class_path' key.\n" + "Hence, either the checkpoint is corrupted or " + "it was not saved properly with latest lightning version" + ) + print("-" * 50) print(f"For Loaded checkpoint from: {checkpoint_path}") print("Below are the modules loaded from the checkpoint:") @@ -38,11 +47,7 @@ def __init__( self._dm_hparams.pop("splits_file_path") self._dm_hparams.pop("augment_smiles", None) self._dm_hparams.pop("aug_smiles_variations", None) - assert "_class_path" in self._dm_hparams, ( - "Datamodule hyperparameters must include a '_class_path' 
key.\n" - "Hence, either the checkpoint is corrupted or " - "it was not saved properly with latest lightning version" - ) + self._dm_hparams.pop("_instantiator", None) self._dm: XYBaseDataModule = instantiate_module( XYBaseDataModule, self._dm_hparams ) @@ -51,6 +56,7 @@ def __init__( print("*" * 10, f"Loaded datamodule class: {self._dm.__class__.__name__}") self._model_hparams = ckpt_file["hyper_parameters"] + self._model_hparams.pop("_instantiator", None) self._model: ChebaiBaseNet = instantiate_module( ChebaiBaseNet, self._model_hparams ) From 994de55469a1bab66e2c9ffdbaabb78a4d42ac7b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 18 Dec 2025 13:48:38 +0100 Subject: [PATCH 24/24] return none for token limit too --- chebai/preprocessing/datasets/base.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 19cc222f..d47f38f7 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -341,20 +341,17 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: if d["features"] is not None ] - return self._filter_to_token_limit(data) + data = [val for val in data if self._filter_to_token_limit(val)] + return data - def _filter_to_token_limit( - self, data: List[Dict[str, Any]] - ) -> List[Dict[str, Any]]: + def _filter_to_token_limit(self, data_instance: dict) -> bool: # filter for missing features in resulting data, keep features length below token limit - return [ - val - for val in data - if val["features"] is not None - and ( - self.n_token_limit is None or len(val["features"]) <= self.n_token_limit - ) - ] + if data_instance["features"] is not None and ( + self.n_token_limit is None + or len(data_instance["features"]) <= self.n_token_limit + ): + return True + return False def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: """ @@ -459,10 +456,11 @@ def 
_process_input_for_prediction( result = self._preprocess_smiles_for_pred(idx, smiles, model_hparams) if result is None or result["features"] is None: continue + if not self._filter_to_token_limit(result): + continue data.append(result) valid_indices.append(idx) - data = self._filter_to_token_limit(data) return data, valid_indices def _preprocess_smiles_for_pred(