Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
465b651
use instantiate model to load data and model from ckpt
aditya0by0 Nov 16, 2025
2acd166
update readme
aditya0by0 Nov 16, 2025
cfbf392
set no grad for predict
aditya0by0 Nov 17, 2025
82b365c
predict pipeline in dm and lm
aditya0by0 Nov 26, 2025
fa6f1b5
there is no need that predict func must depend on trainer
aditya0by0 Nov 26, 2025
ae47608
model hparams for data predict pipeline and vice versa
aditya0by0 Nov 27, 2025
517a5a2
fix reader ident error
aditya0by0 Nov 27, 2025
40491e5
fix label None error
aditya0by0 Nov 27, 2025
7b7e48f
fix cli predict_from_file error
aditya0by0 Nov 27, 2025
6a38317
update readme for new prediction method
aditya0by0 Nov 27, 2025
31b12db
Revert "fix reader ident error"
aditya0by0 Nov 27, 2025
c6e8b61
modify pred logic to store model and dm as instance var
aditya0by0 Nov 28, 2025
63670dd
fix for unwanted args to predict_smiles
aditya0by0 Dec 6, 2025
d906ad4
avoid non_null_labels key in loss kwargs
aditya0by0 Dec 6, 2025
ba5884a
compile model
aditya0by0 Dec 6, 2025
5a17f72
revert the comment line for splits file path
aditya0by0 Dec 6, 2025
bd59d59
handle augment electra and old ckpt files
aditya0by0 Dec 12, 2025
6b8ae09
remove unnec device
aditya0by0 Dec 13, 2025
c029395
raise error for invalid smiles and return None
aditya0by0 Dec 15, 2025
6761079
rectify test as to return None for invalid strings
aditya0by0 Dec 15, 2025
bbb8f10
Merge branch 'dev' into fix/read_data
aditya0by0 Dec 15, 2025
d5e362b
pin rdkit version - https://github.com/ChEB-AI/python-chebai/issues/83
aditya0by0 Dec 16, 2025
5d302d0
handle None returns for invalid smiles
aditya0by0 Dec 16, 2025
b75209e
Merge branch 'dev' into fix/generalize_predict_func
aditya0by0 Dec 18, 2025
ead78b6
Merge branch 'fix/read_data' into fix/generalize_predict_func
aditya0by0 Dec 18, 2025
e86f03a
remove instanstior key from hparams as its causing unnecessary error
aditya0by0 Dec 18, 2025
994de55
return none for token limit too
aditya0by0 Dec 18, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,19 @@ python -m chebai fit --config=[path-to-your-esol-config] --trainer.callbacks=con

### Predicting classes given SMILES strings
```
python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path={path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-to-model] ----smiles_file_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
```
The input files should contain a list of line-separated SMILES strings. This generates a CSV file that contains the
one row for each SMILES string and one column for each class.
The `classes_path` is the path to the dataset's `raw/classes.txt` file that contains the relationship between model output and ChEBI-IDs.

* **`--checkpoint_path`**: Path to the Lightning checkpoint file (must end with `.ckpt`).

* **`--smiles_file_path`**: Path to a text file containing one SMILES string per line.

* **`--save_to`** *(optional)*: Predictions will be saved to the path as CSV file. The CSV will contain one row per SMILES string and one column per predicted class.

* **`--classes_path`** *(optional)*: Path to the dataset’s `raw/classes.txt` file, which maps model output indices to ChEBI IDs.

* If provided, the CSV columns will be named using the ChEBI IDs.
* If omitted, then script will located the file automatically. If unable to locate then the columns will be numbered sequentially.

## Evaluation

Expand Down
1 change: 0 additions & 1 deletion chebai/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def subcommands() -> Dict[str, Set[str]]:
"validate": {"model", "dataloaders", "datamodule"},
"test": {"model", "dataloaders", "datamodule"},
"predict": {"model", "dataloaders", "datamodule"},
"predict_from_file": {"model"},
}


Expand Down
10 changes: 9 additions & 1 deletion chebai/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,15 @@ def predict_step(
Returns:
Dict[str, Union[torch.Tensor, Any]]: The result of the prediction step.
"""
return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False)
assert isinstance(batch, XYData)
batch = batch.to(self.device)
data = self._process_batch(batch, batch_idx)
model_output = self(data, **data.get("model_kwargs", dict()))

# Dummy labels to avoid errors in _get_prediction_and_labels
labels = torch.zeros((len(batch), self.out_dim)).to(self.device)
pr, _ = self._get_prediction_and_labels(data, labels, model_output)
return pr

def _execute(
self,
Expand Down
98 changes: 83 additions & 15 deletions chebai/preprocessing/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,18 +340,19 @@ def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]:
for d in tqdm.tqdm(self._load_dict(path), total=lines)
if d["features"] is not None
]
# filter for missing features in resulting data, keep features length below token limit
data = [
val
for val in data
if val["features"] is not None
and (
self.n_token_limit is None or len(val["features"]) <= self.n_token_limit
)
]

data = [val for val in data if self._filter_to_token_limit(val)]
return data

def _filter_to_token_limit(self, data_instance: dict) -> bool:
# filter for missing features in resulting data, keep features length below token limit
if data_instance["features"] is not None and (
self.n_token_limit is None
or len(data_instance["features"]) <= self.n_token_limit
):
return True
return False

def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
"""
Returns the train DataLoader.
Expand Down Expand Up @@ -401,22 +402,77 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]
Returns:
Union[DataLoader, List[DataLoader]]: A DataLoader object for test data.
"""

return self.dataloader("test", shuffle=False, **kwargs)

def predict_dataloader(
self, *args, **kwargs
) -> Union[DataLoader, List[DataLoader]]:
self,
smiles_list: List[str],
model_hparams: Optional[dict] = None,
**kwargs,
) -> tuple[DataLoader, list[int]]:
"""
Returns the predict DataLoader.

Args:
*args: Additional positional arguments (unused).
smiles_list (List[str]): List of SMILES strings to predict.
model_hparams (Optional[dict]): Model hyperparameters.
Some prediction pre-processing pipelines may require these.
**kwargs: Additional keyword arguments, passed to dataloader().

Returns:
Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data.
tuple[DataLoader, list[int]]: A DataLoader object for prediction data and a list of valid indices.
"""
return self.dataloader(self.prediction_kind, shuffle=False, **kwargs)

data, valid_indices = self._process_input_for_prediction(
smiles_list, model_hparams
)
return (
DataLoader(
data,
collate_fn=self.reader.collator,
batch_size=self.batch_size,
**kwargs,
),
valid_indices,
)

def _process_input_for_prediction(
self, smiles_list: list[str], model_hparams: Optional[dict] = None
) -> tuple[list, list]:
"""
Process input data for prediction.

Args:
smiles_list (List[str]): List of SMILES strings.
model_hparams (Optional[dict]): Model hyperparameters.
Some prediction pre-processing pipelines may require these.

Returns:
tuple[list, list]: Processed input data and valid indices.
"""
data, valid_indices = [], []
for idx, smiles in enumerate(smiles_list):
result = self._preprocess_smiles_for_pred(idx, smiles, model_hparams)
if result is None or result["features"] is None:
continue
if not self._filter_to_token_limit(result):
continue
data.append(result)
valid_indices.append(idx)

return data, valid_indices

def _preprocess_smiles_for_pred(
self, idx, smiles: str, model_hparams: Optional[dict] = None
) -> dict:
"""Preprocess prediction data."""
# Add dummy labels because the collate function requires them.
# Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
# which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
return self.reader.to_data(
{"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]}
)

def prepare_data(self, *args, **kwargs) -> None:
if self._prepare_data_flag != 1:
Expand Down Expand Up @@ -1191,7 +1247,8 @@ def _retrieve_splits_from_csv(self) -> None:
print(f"Applying label filter from {self.apply_label_filter}...")
with open(self.apply_label_filter, "r") as f:
label_filter = [line.strip() for line in f]
with open(os.path.join(self.processed_dir_main, "classes.txt"), "r") as cf:

with open(self.classes_txt_file_path, "r") as cf:
classes = [line.strip() for line in cf]
# reorder labels
old_labels = np.stack(df_data["labels"])
Expand Down Expand Up @@ -1316,3 +1373,14 @@ def processed_file_names_dict(self) -> dict:
if self.n_token_limit is not None:
return {"data": f"data_maxlen{self.n_token_limit}.pt"}
return {"data": "data.pt"}

@property
def classes_txt_file_path(self) -> str:
"""
Returns the filename for the classes text file.

Returns:
str: The filename for the classes text file.
"""
# This property also used in custom trainer `chebai/trainer/CustomTrainer.py`
return os.path.join(self.processed_dir_main, "classes.txt")
5 changes: 4 additions & 1 deletion chebai/preprocessing/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,13 @@ def _read_data(self, raw_data: str) -> List[int]:
print(f"RDKit failed to process {raw_data}")
print(f"\t{e}")
try:
mol = Chem.MolFromSmiles(raw_data.strip())
if mol is None:
raise ValueError(f"Invalid SMILES: {raw_data}")
return [self._get_token_index(v[1]) for v in _tokenize(raw_data)]
except ValueError as e:
print(f"could not process {raw_data}")
print(f"\t{e}")
print(f"\tError: {e}")
return None

def _back_to_smiles(self, smiles_encoded):
Expand Down
171 changes: 171 additions & 0 deletions chebai/result/prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import os
from typing import List, Optional

import pandas as pd
import torch
from jsonargparse import CLI
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.cli import instantiate_module

from chebai.models.base import ChebaiBaseNet
from chebai.preprocessing.datasets.base import XYBaseDataModule


class Predictor:
def __init__(
self,
checkpoint_path: _PATH,
batch_size: Optional[int] = None,
compile_model: bool = True,
):
"""Initializes the Predictor with a model loaded from the checkpoint.

Args:
checkpoint_path: Path to the model checkpoint.
batch_size: Optional batch size for the DataLoader. If not provided,
the default from the datamodule will be used.
compile_model: Whether to compile the model using torch.compile. Default is True.
"""
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt_file = torch.load(
checkpoint_path, map_location=self.device, weights_only=False
)
assert (
"_class_path" in ckpt_file["datamodule_hyper_parameters"]
and "_class_path" in ckpt_file["hyper_parameters"]
), (
"Datamodule and Model hyperparameters must include a '_class_path' key.\n"
"Hence, either the checkpoint is corrupted or "
"it was not saved properly with latest lightning version"
)

print("-" * 50)
print(f"For Loaded checkpoint from: {checkpoint_path}")
print("Below are the modules loaded from the checkpoint:")

self._dm_hparams = ckpt_file["datamodule_hyper_parameters"]
self._dm_hparams.pop("splits_file_path")
self._dm_hparams.pop("augment_smiles", None)
self._dm_hparams.pop("aug_smiles_variations", None)
self._dm_hparams.pop("_instantiator", None)
self._dm: XYBaseDataModule = instantiate_module(
XYBaseDataModule, self._dm_hparams
)
if batch_size is not None and int(batch_size) > 0:
self._dm.batch_size = int(batch_size)
print("*" * 10, f"Loaded datamodule class: {self._dm.__class__.__name__}")

self._model_hparams = ckpt_file["hyper_parameters"]
self._model_hparams.pop("_instantiator", None)
self._model: ChebaiBaseNet = instantiate_module(
ChebaiBaseNet, self._model_hparams
)
print("*" * 10, f"Loaded model class: {self._model.__class__.__name__}")

if compile_model:
self._model = torch.compile(self._model)
self._model.eval()
print("-" * 50)

def predict_from_file(
self,
smiles_file_path: _PATH,
save_to: _PATH = "predictions.csv",
classes_path: Optional[_PATH] = None,
) -> None:
"""
Loads a model from a checkpoint and makes predictions on input data from a file.

Args:
smiles_file_path: Path to the input file containing SMILES strings.
save_to: Path to save the predictions CSV file.
classes_path: Optional path to a file containing class names:
if no class names are provided, code will try to get the class path
from the datamodule, else the columns will be numbered.
"""
with open(smiles_file_path, "r") as input:
smiles_strings = [inp.strip() for inp in input.readlines()]

preds: torch.Tensor = self.predict_smiles(smiles=smiles_strings)

predictions_df = pd.DataFrame(preds.detach().cpu().numpy())

def _add_class_columns(class_file_path: _PATH):
with open(class_file_path, "r") as f:
predictions_df.columns = [cls.strip() for cls in f.readlines()]

if classes_path is not None:
_add_class_columns(classes_path)
elif os.path.exists(self._dm.classes_txt_file_path):
_add_class_columns(self._dm.classes_txt_file_path)

predictions_df.index = smiles_strings
predictions_df.to_csv(save_to)

@torch.inference_mode()
def predict_smiles(
self,
smiles: List[str],
) -> list[torch.Tensor | None]:
"""
Predicts the output for a list of SMILES strings using the model.

Args:
smiles: A list of SMILES strings.

Returns:
A tensor containing the predictions.
"""
# For certain data prediction piplines, we may need model hyperparameters
pred_dl, valid_indices = self._dm.predict_dataloader(
smiles_list=smiles, model_hparams=self._model_hparams
)

preds = []
for batch_idx, batch in enumerate(pred_dl):
# For certain model prediction pipelines, we may need data module hyperparameters
preds.append(
self._model.predict_step(batch, batch_idx, dm_hparams=self._dm_hparams)
)
preds = torch.cat(preds)

# Initialize output with None
output: list[torch.Tensor | None] = [None] * len(smiles)

# Scatter predictions back
for pred, idx in zip(preds, valid_indices):
output[idx] = pred

return output


class MainPredictor:
@staticmethod
def predict_from_file(
checkpoint_path: _PATH,
smiles_file_path: _PATH,
save_to: _PATH = "predictions.csv",
classes_path: Optional[_PATH] = None,
batch_size: Optional[int] = None,
) -> None:
predictor = Predictor(checkpoint_path, batch_size)
predictor.predict_from_file(
smiles_file_path,
save_to,
classes_path,
)

@staticmethod
def predict_smiles(
checkpoint_path: _PATH,
smiles: List[str],
batch_size: Optional[int] = None,
) -> torch.Tensor:
predictor = Predictor(checkpoint_path, batch_size)
return predictor.predict_smiles(smiles)


if __name__ == "__main__":
# python chebai/result/prediction.py predict_from_file --help
# python chebai/result/prediction.py predict_smiles --help
CLI(MainPredictor, as_positional=False)
Loading