diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index b0cfc0ec..6400e11c 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -8,12 +8,17 @@ on:
     paths: ["**/*.py", .github/workflows/test.yml]
     branches: [main]
 
+concurrency:
+  # Cancel only on same PR number
+  group: ${{ github.workflow }}-pr-${{ github.event.pull_request.number }}
+  cancel-in-progress: true
+
 jobs:
   tests:
     strategy:
       fail-fast: false
       matrix:
-        os: [ubuntu-latest, macos-14]
+        os: [ubuntu-latest, macos-latest, windows-latest]
         version:
           - { python: "3.10", resolution: highest }
           - { python: "3.12", resolution: lowest-direct }
@@ -33,7 +38,7 @@ jobs:
 
       - name: Install dependencies
         run: |
-          pip install torch --index-url https://download.pytorch.org/whl/cpu
+          uv pip install torch --index-url https://download.pytorch.org/whl/cpu --system
          uv pip install .[test] --system
 
       - name: Run Tests
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index dc43aece..5f627e19 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -45,4 +45,4 @@ repos:
     rev: 0.8.1
     hooks:
       - id: nbstripout
-        args: [--drop-empty-cells, --keep-output]
+        args: [--drop-empty-cells]
diff --git a/aviary/cgcnn/data.py b/aviary/cgcnn/data.py
index d426e260..1cd7958e 100644
--- a/aviary/cgcnn/data.py
+++ b/aviary/cgcnn/data.py
@@ -1,5 +1,4 @@
 import itertools
-import json
 from collections.abc import Sequence
 from functools import cache
 from typing import Any
@@ -12,8 +11,6 @@
 from torch.utils.data import Dataset
 from tqdm import tqdm
 
-from aviary import PKG_DIR
-
 
 class CrystalGraphData(Dataset):
     """Dataset class for the CGCNN structure model."""
@@ -22,13 +19,10 @@ def __init__(
         self,
         df: pd.DataFrame,
         task_dict: dict[str, str],
-        elem_embedding: str = "cgcnn92",
         structure_col: str = "structure",
         identifiers: Sequence[str] = (),
-        radius: float = 5,
+        radius_cutoff: float = 5,
         max_num_nbr: int = 12,
-        dmin: float = 0,
-        step: float = 0.2,
     ):
         """Featurize crystal structures into neighborhood graphs with this data
         class for CGCNN.
@@ -36,44 +30,21 @@
         Args:
             df (pd.Dataframe): Pandas dataframe holding input and target values.
             task_dict ({target: task}): task dict for multi-task learning
-            elem_embedding (str, optional): One of matscholar200, cgcnn92, megnet16,
-                onehot112 or path to a file with custom element embeddings.
-                Defaults to matscholar200.
             structure_col (str, optional): df column holding pymatgen Structure objects
                 as input.
             identifiers (list[str], optional): df columns for distinguishing data
                 points. Will be copied over into the model's output CSV. Defaults to ().
-            radius (float, optional): Cut-off radius for neighborhood. Defaults to 5.
+            radius_cutoff (float, optional): Cut-off radius for neighborhood.
+                Defaults to 5.
             max_num_nbr (int, optional): maximum number of neighbors to consider.
                 Defaults to 12.
-            dmin (float, optional): minimum distance in Gaussian basis. Defaults to 0.
-            step (float, optional): increment size of Gaussian basis. Defaults to 0.2.
         """
         self.task_dict = task_dict
         self.identifiers = list(identifiers)
-        self.radius = radius
+        self.radius_cutoff = radius_cutoff
         self.max_num_nbr = max_num_nbr
 
-        if elem_embedding in ("matscholar200", "cgcnn92", "megnet16", "onehot112"):
-            elem_embedding = f"{PKG_DIR}/embeddings/element/{elem_embedding}.json"
-
-        with open(elem_embedding) as file:
-            self.elem_features = json.load(file)
-
-        for key, value in self.elem_features.items():
-            self.elem_features[key] = np.array(value, dtype=float)
-            if not hasattr(self, "elem_emb_len"):
-                self.elem_emb_len = len(value)
-            elif self.elem_emb_len != len(value):
-                raise ValueError(
-                    f"Element embedding length mismatch: len({key})="
-                    f"{len(value)}, expected {self.elem_emb_len}"
-                )
-
-        self.gaussian_dist_func = GaussianDistance(dmin=dmin, dmax=radius, step=step)
-        self.nbr_fea_dim = self.gaussian_dist_func.embedding_size
-
         self.df = df
         self.structure_col = structure_col
 
@@ -84,7 +55,7 @@ def __init__(
             self.df[structure_col].items(), total=len(df), desc=desc, disable=None
         ):
             self_idx, nbr_idx, _ = get_structure_neighbor_info(
-                struct, radius, max_num_nbr
+                struct, self.radius_cutoff, self.max_num_nbr
             )
             material_ids = [idx, *self.df.loc[idx][self.identifiers]]
             if 0 in (len(self_idx), len(nbr_idx)):
@@ -140,16 +111,10 @@ def __getitem__(self, idx: int):
         material_ids = [self.df.index[idx], *row[self.identifiers]]
 
         # atom features for disordered sites
-        site_atoms = [atom.species.as_dict() for atom in struct]
-        atom_features = np.vstack(
-            [
-                np.sum([self.elem_features[el] * amt for el, amt in site.items()], axis=0)
-                for site in site_atoms
-            ]
-        )
+        atom_features = [atom.specie.Z for atom in struct]
 
         self_idx, nbr_idx, nbr_dist = get_structure_neighbor_info(
-            struct, self.radius, self.max_num_nbr
+            struct, self.radius_cutoff, self.max_num_nbr
         )
 
         if len(self_idx) == 0:
@@ -161,9 +126,7 @@ def __getitem__(self, idx: int):
         if set(self_idx) != set(range(len(struct))):
             raise ValueError(f"At least one atom in {material_ids} is isolated")
 
-        nbr_dist = self.gaussian_dist_func.expand(nbr_dist)
-
-        atom_fea_t = Tensor(atom_features)
+        atom_fea_t = LongTensor(atom_features)
         nbr_dist_t = Tensor(nbr_dist)
         self_idx_t = LongTensor(self_idx)
         nbr_idx_t = LongTensor(nbr_idx)
@@ -278,7 +241,7 @@ def __init__(
                 "Max radii below minimum radii + step size - please increase dmax."
             )
 
-        self.filter = np.arange(dmin, dmax + step, step)
+        self.filter = torch.arange(dmin, dmax + step, step)
         self.embedding_size = len(self.filter)
 
         if var is None:
@@ -286,19 +249,17 @@ def __init__(
 
         self.var = var
 
-    def expand(self, distances: np.ndarray) -> np.ndarray:
+    def expand(self, distances: Tensor) -> Tensor:
         """Apply Gaussian distance filter to a numpy distance array.
 
         Args:
             distances (ArrayLike): A distance matrix of any shape.
 
         Returns:
-            np.ndarray: Expanded distance matrix with the last dimension of length
+            Tensor: Expanded distance matrix with the last dimension of length
                 len(self.filter)
         """
-        distances = np.array(distances)
-
-        return np.exp(-((distances[..., None] - self.filter) ** 2) / self.var**2)
+        return torch.exp(-((distances[..., None] - self.filter) ** 2) / self.var**2)
 
 
 def get_structure_neighbor_info(
diff --git a/aviary/cgcnn/model.py b/aviary/cgcnn/model.py
index e96389cb..ff62ef99 100644
--- a/aviary/cgcnn/model.py
+++ b/aviary/cgcnn/model.py
@@ -5,9 +5,11 @@
 from pymatgen.util.due import Doi, due
 from torch import LongTensor, Tensor, nn
 
+from aviary.cgcnn.data import GaussianDistance
 from aviary.core import BaseModelClass
 from aviary.networks import SimpleNetwork
 from aviary.scatter import scatter_reduce
+from aviary.utils import get_element_embedding
 
 
 @due.dcite(Doi("10.1103/PhysRevLett.120.145301"), description="CGCNN model")
@@ -25,8 +27,10 @@ def __init__(
         self,
         robust: bool,
         n_targets: Sequence[int],
-        elem_emb_len: int,
-        nbr_fea_len: int,
+        elem_embedding: str = "cgcnn92",
+        radius_cutoff: float = 5.0,
+        radius_min: float = 0.0,
+        radius_step: float = 0.2,
         elem_fea_len: int = 64,
         n_graph: int = 4,
         h_fea_len: int = 128,
@@ -42,8 +46,15 @@
                 (uncertainty inherent to the sample) which can be used with a robust
                 loss function to attenuate the weighting of uncertain samples.
             n_targets (list[int]): Number of targets to train on
-            elem_emb_len (int): Number of atom features in the input.
-            nbr_fea_len (int): Number of bond features.
+            elem_embedding (str, optional): One of matscholar200, cgcnn92, megnet16,
+                onehot112 or path to a file with custom element embeddings.
+                Defaults to cgcnn92.
+            radius_cutoff (float, optional): Cut-off radius for neighborhood.
+                Defaults to 5.
+            radius_min (float, optional): minimum distance in Gaussian basis.
+                Defaults to 0.
+            radius_step (float, optional): increment size of Gaussian basis.
+                Defaults to 0.2.
             elem_fea_len (int, optional): Number of hidden atom features in the
                 convolutional layers. Defaults to 64.
             n_graph (int, optional): Number of convolutional layers. Defaults to 4.
@@ -57,6 +68,14 @@
         """
         super().__init__(robust=robust, **kwargs)
 
+        self.elem_embedding = get_element_embedding(elem_embedding)
+        elem_emb_len = self.elem_embedding.weight.shape[1]
+
+        self.gaussian_dist_func = GaussianDistance(
+            dmin=radius_min, dmax=radius_cutoff, step=radius_step
+        )
+        nbr_fea_len = self.gaussian_dist_func.embedding_size
+
         desc_dict = {
             "elem_emb_len": elem_emb_len,
             "nbr_fea_len": nbr_fea_len,
@@ -107,6 +126,9 @@ def forward(
         Returns:
             tuple[Tensor, ...]: tuple of predictions for all targets
         """
+        nbr_fea = self.gaussian_dist_func.expand(nbr_fea)
+        atom_fea = self.elem_embedding(atom_fea)
+
         atom_fea = self.node_nn(atom_fea, nbr_fea, self_idx, nbr_idx)
 
         crys_fea = scatter_reduce(atom_fea, crystal_atom_idx, dim=0, reduce="mean")
diff --git a/aviary/core.py b/aviary/core.py
index 983d41a9..a023f022 100644
--- a/aviary/core.py
+++ b/aviary/core.py
@@ -32,6 +32,7 @@ def __init__(
         epoch: int = 0,
         device: str | None = None,
         best_val_scores: dict[str, float] | None = None,
+        **kwargs,
     ) -> None:
         """Store core model parameters.
 
@@ -47,6 +48,7 @@
             device (str, optional): Device to store the model parameters on.
             best_val_scores (dict[str, float], optional): Validation score to use for
                 early stopping. Defaults to None.
+            **kwargs: Additional keyword arguments.
""" super().__init__() self.task_dict = task_dict @@ -299,8 +301,9 @@ def evaluate( preds = output.squeeze(1) loss = loss_func(preds, targets) - z_scored_error = preds - targets - error = normalizer.std * z_scored_error.data.cpu() + denormed_preds = normalizer.denorm(preds) + denormed_targets = normalizer.denorm(targets) + error = denormed_preds - denormed_targets target_metrics["MAE"].append(float(error.abs().mean())) target_metrics["MSE"].append(float(error.pow(2).mean())) diff --git a/aviary/roost/data.py b/aviary/roost/data.py index 3f84abb0..e5803f92 100644 --- a/aviary/roost/data.py +++ b/aviary/roost/data.py @@ -1,4 +1,3 @@ -import json from collections.abc import Sequence from functools import cache from typing import Any @@ -10,8 +9,6 @@ from torch import LongTensor, Tensor from torch.utils.data import Dataset -from aviary import PKG_DIR - class CompositionData(Dataset): """Dataset class for the Roost composition model.""" @@ -20,7 +17,6 @@ def __init__( self, df: pd.DataFrame, task_dict: dict[str, str], - elem_embedding: str = "matscholar200", inputs: str = "composition", identifiers: Sequence[str] = ("material_id", "composition"), ): @@ -47,14 +43,6 @@ def __init__( self.identifiers = list(identifiers) self.df = df - if elem_embedding in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]: - elem_embedding = f"{PKG_DIR}/embeddings/element/{elem_embedding}.json" - - with open(elem_embedding) as file: - self.elem_features = json.load(file) - - self.elem_emb_len = len(next(iter(self.elem_features.values()))) - self.n_targets = [] for target, task in self.task_dict.items(): if task == "regression": @@ -88,24 +76,12 @@ def __getitem__(self, idx: int): composition = row[self.inputs] material_ids = row[self.identifiers].to_list() - comp_dict = Composition(composition).get_el_amt_dict() - elements = list(comp_dict) - + comp_dict = Composition(composition).fractional_composition weights = list(comp_dict.values()) weights = np.atleast_2d(weights).T / np.sum(weights) + elem_fea = [elem.Z for elem in comp_dict] - try: - elem_fea = np.vstack([self.elem_features[element] for element in elements]) - except AssertionError as exc: - raise AssertionError( - f"{material_ids} contains element types not in embedding" - ) from exc - except ValueError as exc: - raise ValueError( - f"{material_ids} composition cannot be parsed into elements" - ) from exc - - n_elems = len(elements) + n_elems = len(comp_dict) self_idx = [] nbr_idx = [] for elem_idx in range(n_elems): @@ -114,7 +90,7 @@ def __getitem__(self, idx: int): # convert all data to tensors elem_weights = Tensor(weights) - elem_fea = Tensor(elem_fea) + elem_fea = LongTensor(elem_fea) self_idx = LongTensor(self_idx) nbr_idx = LongTensor(nbr_idx) diff --git a/aviary/roost/model.py b/aviary/roost/model.py index fec22196..3b84f96e 100644 --- a/aviary/roost/model.py +++ b/aviary/roost/model.py @@ -8,6 +8,7 @@ from aviary.core import BaseModelClass from aviary.networks import ResidualNetwork, SimpleNetwork from aviary.segments import MessageLayer, WeightedAttentionPooling +from aviary.utils import get_element_embedding @due.dcite(Doi("10.1038/s41467-020-19964-7"), description="Roost model") @@ -25,7 +26,7 @@ def __init__( self, robust: bool, n_targets: Sequence[int], - elem_emb_len: int, + elem_embedding: str = "matscholar200", elem_fea_len: int = 64, n_graph: int = 3, elem_heads: int = 3, @@ -41,6 +42,8 @@ def __init__( """Composition-only model.""" super().__init__(robust=robust, **kwargs) + self.elem_embedding = 
get_element_embedding(elem_embedding) + elem_emb_len = self.elem_embedding.weight.shape[1] desc_dict = { "elem_emb_len": elem_emb_len, "elem_fea_len": elem_fea_len, @@ -60,6 +63,7 @@ def __init__( "n_targets": n_targets, "out_hidden": out_hidden, "trunk_hidden": trunk_hidden, + "elem_embedding": elem_embedding, **desc_dict, } self.model_params.update(model_params) @@ -83,6 +87,8 @@ def forward( cry_elem_idx: LongTensor, ) -> tuple[Tensor, ...]: """Forward pass through the material_nn and output_nn.""" + elem_fea = self.elem_embedding(elem_fea) + crys_fea = self.material_nn( elem_weights, elem_fea, self_idx, nbr_idx, cry_elem_idx ) diff --git a/aviary/utils.py b/aviary/utils.py index 1b924092..e1a096b9 100644 --- a/aviary/utils.py +++ b/aviary/utils.py @@ -1,3 +1,4 @@ +import json import os import sys import time @@ -13,6 +14,7 @@ import pandas as pd import torch import wandb +from pymatgen.core import Element from sklearn.metrics import ( accuracy_score, balanced_accuracy_score, @@ -22,13 +24,13 @@ roc_auc_score, ) from torch import LongTensor, Tensor -from torch.nn import CrossEntropyLoss, L1Loss, MSELoss, NLLLoss +from torch.nn import CrossEntropyLoss, Embedding, L1Loss, MSELoss, NLLLoss from torch.optim import SGD, Adam, AdamW, Optimizer from torch.optim.lr_scheduler import MultiStepLR, _LRScheduler from torch.utils.data import DataLoader, Subset from torch.utils.tensorboard import SummaryWriter -from aviary import ROOT +from aviary import PKG_DIR, ROOT from aviary.core import BaseModelClass, Normalizer, TaskType, sampled_softmax from aviary.data import InMemoryDataLoader from aviary.losses import robust_l1_loss, robust_l2_loss @@ -795,10 +797,78 @@ def get_metrics( metrics["F1"] = f1_score(targets, pred_labels) class1_probas = predictions[:, 1] metrics["ROCAUC"] = roc_auc_score(targets, class1_probas) + else: + raise ValueError(f"Invalid task type: {type}") return {key: round(float(val), prec) for key, val in metrics.items()} +def get_element_embedding(elem_embedding: str) -> Embedding: + """Get an element embedding from a file. + + Args: + elem_embedding (str): The path to the element embedding file. + + Returns: + Embedding: The element embedding. + """ + if os.path.isfile(elem_embedding): + pass + elif elem_embedding in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]: + elem_embedding = f"{PKG_DIR}/embeddings/element/{elem_embedding}.json" + else: + raise ValueError(f"Invalid element embedding: {elem_embedding}") + + with open(elem_embedding) as file: + elem_features = json.load(file) + + max_z = max(Element(elem).Z for elem in elem_features) + elem_emb_len = len(next(iter(elem_features.values()))) + elem_feature_matrix = torch.zeros((max_z + 1, elem_emb_len)) + for elem, feature in elem_features.items(): + elem_feature_matrix[Element(elem).Z] = torch.tensor(feature) + + embedding = Embedding(max_z + 1, elem_emb_len) + embedding.weight.data.copy_(elem_feature_matrix) + + return embedding + + +def get_sym_embedding(sym_embedding: str) -> Embedding: + """Get a symmetry embedding from a file. + + Args: + sym_embedding (str): The path to the symmetry embedding file. + + Returns: + Embedding: The symmetry embedding. 
+ """ + if os.path.isfile(sym_embedding): + pass + elif sym_embedding in ("bra-alg-off", "spg-alg-off"): + sym_embedding = f"{PKG_DIR}/embeddings/wyckoff/{sym_embedding}.json" + else: + raise ValueError(f"Invalid symmetry embedding: {sym_embedding}") + + with open(sym_embedding) as sym_file: + sym_features = json.load(sym_file) + + sym_emb_len = len(next(iter(next(iter(sym_features.values())).values()))) + + len_sym_features = sum(len(feature) for feature in sym_features.values()) + sym_feature_matrix = torch.zeros((len_sym_features, sym_emb_len)) + sym_idx = 0 + for embeddings in sym_features.values(): + for feature in embeddings.values(): + sym_feature_matrix[sym_idx] = torch.tensor(feature) + sym_idx += 1 + + embedding = Embedding(len_sym_features, sym_emb_len) + embedding.weight.data.copy_(sym_feature_matrix) + + return embedding + + def as_dict_handler(obj: Any) -> dict[str, Any] | None: """Pass this func as json.dump(handler=) or as pandas.to_json(default_handler=).""" try: diff --git a/aviary/wren/data.py b/aviary/wren/data.py index fe16937c..b38c7a21 100644 --- a/aviary/wren/data.py +++ b/aviary/wren/data.py @@ -1,4 +1,5 @@ import json +from collections import defaultdict from collections.abc import Sequence from functools import cache from itertools import groupby @@ -13,11 +14,23 @@ WYCKOFF_MULTIPLICITY_DICT, WYCKOFF_POSITION_RELAB_DICT, ) +from pymatgen.core import Element from torch import LongTensor, Tensor from torch.utils.data import Dataset from aviary import PKG_DIR +with open(f"{PKG_DIR}/embeddings/wyckoff/bra-alg-off.json") as f: + sym_embeddings = json.load(f) +WYCKOFF_SPG_LETTER_MAP: dict[str, dict[str, int]] = defaultdict(dict) +i = 0 +for spg_num, embeddings in sym_embeddings.items(): + for wyckoff_letter in embeddings: + WYCKOFF_SPG_LETTER_MAP[spg_num][wyckoff_letter] = i + i += 1 + +del sym_embeddings + class WyckoffData(Dataset): """Wyckoff dataset class for the Wren model.""" @@ -26,8 +39,6 @@ def __init__( self, df: pd.DataFrame, task_dict: dict[str, str], - elem_embedding: str = "matscholar200", - sym_emb: str = "bra-alg-off", inputs: str = "protostructure", identifiers: Sequence[str] = ("material_id", "composition", "protostructure"), ): @@ -37,11 +48,6 @@ def __init__( df (pd.DataFrame): Pandas dataframe holding input and target values. task_dict (dict[str, "regression" | "classification"]): Map from target names to task type for multi-task learning. - elem_embedding (str, optional): One of "matscholar200", "cgcnn92", - "megnet16", "onehot112" or path to a file with custom element - embeddings. Defaults to "matscholar200". - sym_emb (str): Symmetry embedding. One of "bra-alg-off" (default) or - "spg-alg-off" or path to a file with custom symmetry embeddings. inputs (str, optional): df columns to be used for featurization. Defaults to "protostructure". identifiers (list, optional): df columns for distinguishing data points. 
@@ -56,24 +62,6 @@ def __init__( self.identifiers = list(identifiers) self.df = df - if elem_embedding in ("matscholar200", "cgcnn92", "megnet16", "onehot112"): - elem_embedding = f"{PKG_DIR}/embeddings/element/{elem_embedding}.json" - - with open(elem_embedding) as emb_file: - self.elem_features = json.load(emb_file) - - self.elem_emb_len = len(next(iter(self.elem_features.values()))) - - if sym_emb in ("bra-alg-off", "spg-alg-off"): - sym_emb = f"{PKG_DIR}/embeddings/wyckoff/{sym_emb}.json" - - with open(sym_emb) as sym_file: - self.sym_features = json.load(sym_file) - - self.sym_emb_len = len( - next(iter(next(iter(self.sym_features.values())).values())) - ) - self.n_targets = [] for target, task in self.task_dict.items(): if task == "regression": @@ -113,23 +101,13 @@ def __getitem__(self, idx: int): wyk_site_multiplcities ) - try: - element_features = np.vstack([self.elem_features[el] for el in elements]) - except AssertionError: - print(f"Failed to process elements for {material_ids}") - raise - - try: - symmetry_features = np.vstack( - [ - self.sym_features[spg_num][wyk_site] - for wyckoff_sites in augmented_wyks - for wyk_site in wyckoff_sites - ] - ) - except AssertionError: - print(f"Failed to process Wyckoff positions for {material_ids}") - raise + element_features = [Element(el).Z for el in elements] + + symmetry_features = [ + WYCKOFF_SPG_LETTER_MAP[spg_num][wyk_site] + for wyckoff_sites in augmented_wyks + for wyk_site in wyckoff_sites + ] n_wyks = len(elements) self_idx = [] @@ -147,8 +125,8 @@ def __getitem__(self, idx: int): # convert all data to tensors wyckoff_weights = Tensor(wyk_site_multiplcities) - element_features = Tensor(element_features) - symmetry_features = Tensor(symmetry_features) + element_features = LongTensor(element_features) + symmetry_features = LongTensor(symmetry_features) self_idx = LongTensor(self_aug_fea_idx) nbr_idx = LongTensor(nbr_aug_fea_idx) @@ -198,7 +176,7 @@ def collate_batch( # batch the features together batch_mult_weights.append(mult_weights.repeat((n_aug, 1))) - batch_elem_fea.append(elem_fea.repeat((n_aug, 1))) + batch_elem_fea.append(elem_fea.repeat(n_aug)) batch_sym_fea.append(sym_fea) # mappings from bonds to atoms diff --git a/aviary/wren/model.py b/aviary/wren/model.py index 6e6da581..f89df0c6 100644 --- a/aviary/wren/model.py +++ b/aviary/wren/model.py @@ -9,6 +9,7 @@ from aviary.networks import ResidualNetwork, SimpleNetwork from aviary.scatter import scatter_reduce from aviary.segments import MessageLayer, WeightedAttentionPooling +from aviary.utils import get_element_embedding, get_sym_embedding @due.dcite(Doi("10.1126/sciadv.abn4117"), description="Wren model") @@ -26,8 +27,8 @@ def __init__( self, robust: bool, n_targets: Sequence[int], - elem_emb_len: int, - sym_emb_len: int, + elem_embedding: str = "matscholar200", + sym_embedding: str = "bra-alg-off", elem_fea_len: int = 32, sym_fea_len: int = 32, n_graph: int = 3, @@ -44,6 +45,12 @@ def __init__( """Protostructure based model.""" super().__init__(robust=robust, **kwargs) + self.elem_embedding = get_element_embedding(elem_embedding) + elem_emb_len = self.elem_embedding.weight.shape[1] + + self.sym_embedding = get_sym_embedding(sym_embedding) + sym_emb_len = self.sym_embedding.weight.shape[1] + desc_dict = { "elem_emb_len": elem_emb_len, "elem_fea_len": elem_fea_len, @@ -62,6 +69,8 @@ def __init__( model_params = { "robust": robust, + "elem_embedding": elem_embedding, + "sym_embedding": sym_embedding, "n_targets": n_targets, "out_hidden": out_hidden, "trunk_hidden": 
trunk_hidden, @@ -92,6 +101,8 @@ def forward( aug_cry_idx: LongTensor, ) -> tuple[Tensor, ...]: """Forward pass through the material_nn and output_nn.""" + elem_fea = self.elem_embedding(elem_fea) + sym_fea = self.sym_embedding(sym_fea) crys_fea = self.material_nn( elem_weights, elem_fea, diff --git a/examples/cgcnn-example.py b/examples/cgcnn-example.py index aeb4cf57..7b97d288 100644 --- a/examples/cgcnn-example.py +++ b/examples/cgcnn-example.py @@ -89,24 +89,18 @@ def main( task_dict = dict(zip(targets, tasks, strict=False)) loss_dict = dict(zip(targets, losses, strict=False)) - dist_dict = { - "radius": radius, - "max_num_nbr": max_num_nbr, - "dmin": dmin, - "step": step, - } - # NOTE make sure to use dense datasets, here do not use the default na # as they can clash with "NaN" which is a valid material df = pd.read_json(data_path) df["structure"] = df.structure.map(Structure.from_dict) dataset = CrystalGraphData( - df=df, elem_embedding=elem_embedding, task_dict=task_dict, **dist_dict + df=df, + task_dict=task_dict, + max_num_nbr=max_num_nbr, + radius_cutoff=radius, ) n_targets = dataset.n_targets - elem_emb_len = dataset.elem_emb_len - nbr_fea_len = dataset.nbr_fea_dim train_idx = list(range(len(dataset))) @@ -119,7 +113,10 @@ def main( print(f"using independent test set: {test_path}") test_set = CrystalGraphData( - df=df, elem_embedding=elem_embedding, task_dict=task_dict, **dist_dict + df=df, + task_dict=task_dict, + max_num_nbr=max_num_nbr, + radius_cutoff=radius, ) test_set = torch.utils.data.Subset(test_set, range(len(test_set))) elif test_size == 0: @@ -140,7 +137,10 @@ def main( print(f"using independent validation set: {val_path}") val_set = CrystalGraphData( - df=df, elem_embedding=elem_embedding, task_dict=task_dict, **dist_dict + df=df, + task_dict=task_dict, + max_num_nbr=max_num_nbr, + radius_cutoff=radius, ) val_set = torch.utils.data.Subset(val_set, range(len(val_set))) elif val_size == 0 and evaluate: @@ -192,8 +192,10 @@ def main( "task_dict": task_dict, "robust": robust, "n_targets": n_targets, - "elem_emb_len": elem_emb_len, - "nbr_fea_len": nbr_fea_len, + "elem_embedding": elem_embedding, + "radius_cutoff": radius, + "radius_min": dmin, + "radius_step": step, "elem_fea_len": elem_fea_len, "n_graph": n_graph, "h_fea_len": h_fea_len, diff --git a/examples/notebooks/Roost.ipynb b/examples/notebooks/Roost.ipynb index ec802b15..b95a83f2 100644 --- a/examples/notebooks/Roost.ipynb +++ b/examples/notebooks/Roost.ipynb @@ -113,7 +113,7 @@ "\n", "ensemble = 1\n", "run_id = 1\n", - "epochs = 100\n", + "epochs = 1\n", "log = False\n", "\n", "# NOTE setting workers to zero means that the data is loaded in the main\n", @@ -161,11 +161,9 @@ "\n", "dataset = CompositionData(\n", " df=df,\n", - " elem_embedding=elem_embedding,\n", " task_dict=task_dict,\n", ")\n", "n_targets = dataset.n_targets\n", - "elem_emb_len = dataset.elem_emb_len\n", "\n", "train_idx = list(range(len(dataset)))\n", "\n", @@ -192,7 +190,7 @@ " \"task_dict\": task_dict,\n", " \"robust\": robust,\n", " \"n_targets\": n_targets,\n", - " \"elem_emb_len\": elem_emb_len,\n", + " \"elem_embedding\": elem_embedding,\n", " \"elem_fea_len\": 64,\n", " \"n_graph\": 3,\n", " \"elem_heads\": 3,\n", @@ -211,16 +209,22 @@ " run_id=run_id,\n", " ensemble_folds=ensemble,\n", " epochs=epochs,\n", - " train_set=train_set,\n", - " val_set=val_set,\n", + " train_loader=train_loader,\n", + " val_loader=val_loader,\n", " log=log,\n", - " data_params=data_params,\n", " setup_params=setup_params,\n", " 
     "    restart_params=restart_params,\n",
     "    model_params=model_params,\n",
     "    loss_dict=loss_dict,\n",
-    ")\n",
-    "\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
     "test_loader = DataLoader(\n",
     "    test_set,\n",
     "    **{**data_params, \"batch_size\": 64 * data_params[\"batch_size\"], \"shuffle\": False},\n",
@@ -231,8 +235,7 @@
     "    model_name=model_name,\n",
     "    run_id=run_id,\n",
     "    ensemble_folds=ensemble,\n",
-    "    test_set=test_set,\n",
-    "    data_params=data_params,\n",
+    "    test_loader=test_loader,\n",
     "    robust=robust,\n",
     "    task_dict=task_dict,\n",
     "    device=device,\n",
diff --git a/examples/notebooks/Wren.ipynb b/examples/notebooks/Wren.ipynb
index 3a41df2c..01b32c32 100644
--- a/examples/notebooks/Wren.ipynb
+++ b/examples/notebooks/Wren.ipynb
@@ -75,7 +75,12 @@
     "df = df[df.protostructure.map(count_wyckoff_positions) < 16]\n",
     "df[\"n_sites\"] = df.final_structure.map(len)\n",
     "df = df[df.n_sites < 64]\n",
-    "df = df[df.volume_per_atom < 500]"
+    "df = df[df.volume_per_atom < 500]\n",
+    "\n",
+    "# NOTE we keep only the lowest lying structure for each protostructure\n",
+    "df = df.sort_values([\"protostructure\", \"E_vasp_per_atom\"]).drop_duplicates(\n",
+    "    \"protostructure\", keep=\"first\"\n",
+    ")"
    ]
   },
  {
@@ -108,7 +113,7 @@
     "\n",
     "ensemble = 1\n",
     "run_id = 1\n",
-    "epochs = 100\n",
+    "epochs = 1\n",
     "log = False\n",
     "\n",
     "# NOTE setting workers to zero means that the data is loaded in the main\n",
@@ -149,18 +154,14 @@
     "torch.manual_seed(0)  # ensure reproducible results\n",
     "\n",
     "elem_embedding = \"matscholar200\"\n",
-    "sym_emb = \"bra-alg-off\"\n",
+    "sym_embedding = \"bra-alg-off\"\n",
     "model_name = \"wren-reg-test\"\n",
     "\n",
     "data_params[\"collate_fn\"] = wren_cb\n",
     "data_params[\"shuffle\"] = True\n",
     "\n",
-    "dataset = WyckoffData(\n",
-    "    df=df, elem_embedding=elem_embedding, sym_emb=sym_emb, task_dict=task_dict\n",
-    ")\n",
+    "dataset = WyckoffData(df=df, task_dict=task_dict)\n",
     "n_targets = dataset.n_targets\n",
-    "elem_emb_len = dataset.elem_emb_len\n",
-    "sym_emb_len = dataset.sym_emb_len\n",
     "\n",
     "train_idx = list(range(len(dataset)))\n",
     "\n",
@@ -187,9 +188,9 @@
     "    \"task_dict\": task_dict,\n",
     "    \"robust\": robust,\n",
     "    \"n_targets\": n_targets,\n",
-    "    \"elem_emb_len\": elem_emb_len,\n",
+    "    \"elem_embedding\": elem_embedding,\n",
     "    \"elem_fea_len\": 32,\n",
-    "    \"sym_emb_len\": sym_emb_len,\n",
+    "    \"sym_embedding\": sym_embedding,\n",
     "    \"sym_fea_len\": 32,\n",
     "    \"n_graph\": 3,\n",
     "    \"elem_heads\": 1,\n",
@@ -215,8 +216,15 @@
     "    restart_params=restart_params,\n",
     "    model_params=model_params,\n",
     "    loss_dict=loss_dict,\n",
-    ")\n",
-    "\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
     "test_loader = DataLoader(\n",
     "    test_set,\n",
     "    **{**data_params, \"batch_size\": 64 * data_params[\"batch_size\"], \"shuffle\": False},\n",
@@ -253,7 +261,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.8"
+   "version": "3.12.9"
   },
   "vscode": {
    "interpreter": {
diff --git a/examples/notebooks/Wrenformer.ipynb b/examples/notebooks/Wrenformer.ipynb
index 86780193..11703c61 100644
--- a/examples/notebooks/Wrenformer.ipynb
+++ b/examples/notebooks/Wrenformer.ipynb
@@ -55,62 +55,7 @@
    "execution_count": null,
    "id": "2",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
-      "spglib: ssm_get_exact_positions failed.\n",
-      "spglib: get_bravais_exact_positions_and_lattice failed.\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "with gzip.open(\"taata.json.gz\", \"r\") as fin:\n",
     "    json_bytes = fin.read()\n",
@@ -129,7 +74,12 @@
     "df = df[df.protostructure.map(count_wyckoff_positions) < 16]\n",
     "df[\"n_sites\"] = df.final_structure.map(len)\n",
     "df = df[df.n_sites < 64]\n",
-    "df = df[df.volume_per_atom < 500]"
+    "df = df[df.volume_per_atom < 500]\n",
+    "\n",
+    "# NOTE we keep only the lowest lying structure for each protostructure\n",
+    "df = df.sort_values([\"protostructure\", \"E_vasp_per_atom\"]).drop_duplicates(\n",
+    "    \"protostructure\", keep=\"first\"\n",
+    ")"
    ]
   },
  {
@@ -162,7 +112,7 @@
     "\n",
     "ensemble = 1\n",
     "run_id = 1\n",
-    "epochs = 3\n",
+    "epochs = 1\n",
     "log = False\n",
     "\n",
     "# NOTE setting workers to zero means that the data is loaded in the main\n",
@@ -198,38 +148,7 @@
    "execution_count": null,
    "id": "4",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "using 0.2 of training set as test set\n",
-      "No validation set used, using test set for evaluation purposes\n",
-      "Total Number of Trainable Parameters: 5,166,658\n",
-      "Dummy MAE: 0.9223\n",
-      "Epoch: [0/2]\n",
-      " train: E_vasp_per_atom N 76 MAE 0.89 Loss 1.09 RMSE 1.13 \n",
-      " evaluate: E_vasp_per_atom N 2 MAE 0.71 Loss 0.83 RMSE 0.95 \n",
-      "Epoch: [1/2]\n",
-      " train: E_vasp_per_atom N 76 MAE 0.57 Loss 0.60 RMSE 0.78 \n",
-      " evaluate: E_vasp_per_atom N 2 MAE 0.53 Loss 0.51 RMSE 0.71 \n",
-      "Epoch: [2/2]\n",
-      " train: E_vasp_per_atom N 76 MAE 0.45 Loss 0.36 RMSE 0.62 \n",
-      " evaluate: E_vasp_per_atom N 2 MAE 0.37 Loss 0.19 RMSE 0.52 \n",
-      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
-      "------------Evaluate model on Test Set------------\n",
-      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
-      "\n",
-      "Evaluating Model\n",
-      "\n",
-      "Task: target_name='E_vasp_per_atom' on test set\n",
-      "Model Performance Metrics:\n",
-      "R2 Score: 0.7842 \n",
-      "MAE: 0.3913\n",
-      "RMSE: 0.5500\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "torch.manual_seed(0)  # ensure reproducible results\n",
     "\n",
@@ -303,8 +222,15 @@
     "    restart_params=restart_params,\n",
     "    model_params=model_params,\n",
     "    loss_dict=loss_dict,\n",
-    ")\n",
-    "\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
     "test_loader = df_to_in_mem_dataloader(\n",
     "    test_df,\n",
     "    batch_size=batch_size * 64,\n",
diff --git a/examples/roost-example.py b/examples/roost-example.py
index effd0716..c8410e77 100644
--- a/examples/roost-example.py
+++ b/examples/roost-example.py
@@ -86,9 +86,8 @@ def main(
     # NOTE do not use default_na as "NaN" is a valid material
     df = pd.read_csv(data_path, keep_default_na=False, na_values=[])
 
-    dataset = CompositionData(df=df, elem_embedding=elem_embedding, task_dict=task_dict)
+    dataset = CompositionData(df=df, task_dict=task_dict)
     n_targets = dataset.n_targets
-    elem_emb_len = dataset.elem_emb_len
 
     train_idx = list(range(len(dataset)))
 
@@ -99,9 +98,7 @@ def main(
         df = pd.read_csv(test_path, keep_default_na=False, na_values=[])
         print(f"using independent test set: {test_path}")
 
-        test_set = CompositionData(
-            df=df, elem_embedding=elem_embedding, task_dict=task_dict
-        )
+        test_set = CompositionData(df=df, task_dict=task_dict)
         test_set = torch.utils.data.Subset(test_set, range(len(test_set)))
     elif test_size == 0:
         raise ValueError("test-size must be non-zero to evaluate model")
@@ -119,9 +116,7 @@ def main(
         df = pd.read_csv(val_path, keep_default_na=False, na_values=[])
         print(f"using independent validation set: {val_path}")
 
-        val_set = CompositionData(
-            df=df, elem_embedding=elem_embedding, task_dict=task_dict
-        )
+        val_set = CompositionData(df=df, task_dict=task_dict)
         val_set = torch.utils.data.Subset(val_set, range(len(val_set)))
     elif val_size == 0 and evaluate:
         print("No validation set used, using test set for evaluation purposes")
@@ -172,7 +167,7 @@ def main(
         "task_dict": task_dict,
         "robust": robust,
         "n_targets": n_targets,
-        "elem_emb_len": elem_emb_len,
+        "elem_embedding": elem_embedding,
         "elem_fea_len": elem_fea_len,
         "n_graph": n_graph,
         "elem_heads": 3,
diff --git a/examples/wren-example.py b/examples/wren-example.py
index 73031596..de4ceba0 100644
--- a/examples/wren-example.py
+++ b/examples/wren-example.py
@@ -18,7 +18,7 @@ def main(
     losses,
     robust,
     elem_embedding="matscholar200",
-    sym_emb="bra-alg-off",
+    sym_embedding="bra-alg-off",
     model_name="wren",
     sym_fea_len=32,
     elem_fea_len=32,
@@ -90,12 +90,8 @@ def main(
     # NOTE do not use default_na as "NaN" is a valid material composition
     df = pd.read_csv(data_path, keep_default_na=False, na_values=[])
 
-    dataset = WyckoffData(
-        df=df, elem_embedding=elem_embedding, sym_emb=sym_emb, task_dict=task_dict
-    )
+    dataset = WyckoffData(df=df, task_dict=task_dict)
     n_targets = dataset.n_targets
-    elem_emb_len = dataset.elem_emb_len
-    sym_emb_len = dataset.sym_emb_len
 
     train_idx = list(range(len(dataset)))
 
@@ -108,8 +104,6 @@ def main(
         print(f"using independent test set: {test_path}")
         test_set = WyckoffData(
             df=df,
-            elem_embedding=elem_embedding,
-            sym_emb=sym_emb,
             task_dict=task_dict,
         )
         test_set = torch.utils.data.Subset(test_set, range(len(test_set)))
@@ -131,8 +125,6 @@ def main(
         print(f"using independent validation set: {val_path}")
         val_set = WyckoffData(
             df=df,
-            elem_embedding=elem_embedding,
-            sym_emb=sym_emb,
             task_dict=task_dict,
         )
         val_set = torch.utils.data.Subset(val_set, range(len(val_set)))
@@ -184,8 +176,8 @@ def main(
         "task_dict": task_dict,
         "robust": robust,
         "n_targets": n_targets,
-        "elem_emb_len": elem_emb_len,
-        "sym_emb_len": sym_emb_len,
+        "elem_embedding": elem_embedding,
+        "sym_embedding": sym_embedding,
         "elem_fea_len": elem_fea_len,
         "sym_fea_len": sym_fea_len,
         "n_graph": n_graph,
diff --git a/pyproject.toml b/pyproject.toml
index 96a75b8c..f2b8bd45 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [build-system]
-requires = ["setuptools>=61.2"]
-build-backend = "setuptools.build_meta"
+requires = ["hatchling"]
+build-backend = "hatchling.build"
 
 [project]
 name = "aviary"
@@ -8,7 +8,7 @@
 version = "1.2.0"
 description = "A collection of machine learning models for materials discovery"
 authors = [{ name = "Rhys Goodall", email = "rhys.goodall@outlook.com" }]
 readme = "README.md"
-license = { file = "license" }
+license = { file = "LICENSE" }
 keywords = [
     "Graph Neural Network",
     "Machine Learning",
@@ -52,14 +52,14 @@
 Repo = "https://github.com/CompRhys/aviary"
 
 test = ["matminer", "moyopy>=0.3.3", "pytest", "pytest-cov"]
 moyopy = ["moyopy>=0.3.3"]
 
-[tool.setuptools.packages]
-find = { include = ["aviary*"], exclude = ["tests*"] }
+[tool.hatch.build.targets.wheel]
+packages = ["aviary"]
 
-[tool.setuptools.package-data]
-aviary = ["**/**/*.json", "**/*.json"]
-
-[tool.distutils.bdist_wheel]
-universal = true
+[tool.hatch.build]
+include = [
+    "aviary/**/*.py",
+    "aviary/**/*.json",
+]
 
 [tool.pytest.ini_options]
 testpaths = ["tests"]
diff --git a/tests/test_cgcnn.py b/tests/test_cgcnn.py
index e79d2fb6..4f99cca8 100644
--- a/tests/test_cgcnn.py
+++ b/tests/test_cgcnn.py
@@ -20,6 +20,11 @@ def base_config():
         "log": False,
         "sample": 1,
         "test_size": 0.2,
+        "radius": 5,
+        "max_num_nbr": 12,
+        "dmin": 0,
+        "step": 0.2,
+        "patience": None,
     }
 
 
@@ -63,12 +68,11 @@ def test_cgcnn_regression(
     dataset = CrystalGraphData(
         df=df_matbench_phonons,
-        elem_embedding=base_config["elem_embedding"],
         task_dict=task_dict,
+        max_num_nbr=base_config["max_num_nbr"],
+        radius_cutoff=base_config["radius"],
     )
     n_targets = dataset.n_targets
-    elem_emb_len = dataset.elem_emb_len
-    nbr_fea_len = dataset.nbr_fea_dim
 
     train_idx = list(range(len(dataset)))
     train_idx, test_idx = split(
@@ -112,8 +116,10 @@
         "task_dict": task_dict,
         "robust": base_config["robust"],
         "n_targets": n_targets,
-        "elem_emb_len": elem_emb_len,
-        "nbr_fea_len": nbr_fea_len,
+        "elem_embedding": base_config["elem_embedding"],
+        "radius_cutoff": base_config["radius"],
+        "radius_min": base_config["dmin"],
+        "radius_step": base_config["step"],
         **model_architecture,
     }
 
@@ -123,6 +129,7 @@ def test_cgcnn_regression(
        run_id=base_config["run_id"],
        ensemble_folds=base_config["ensemble"],
        epochs=epochs,
+       patience=base_config["patience"],
        train_loader=train_loader,
        val_loader=val_loader,
        log=base_config["log"],
@@ -154,12 +161,12 @@ def test_cgcnn_regression(
     targets = results_dict[target_name]["targets"]
     y_ens = np.mean(preds, axis=0)
 
-    mae, rmse, r2 = get_metrics(targets, y_ens, task).values()
+    metrics = get_metrics(targets, y_ens, task)
 
     assert len(targets) == len(test_set) == len(test_idx)
-    assert r2 > 0.7
-    assert mae < 150
-    assert rmse < 300
+    assert metrics["R2"] > 0.7
+    assert metrics["MAE"] < 150
+    assert metrics["RMSE"] < 300
 
 
 def test_cgcnn_clf(df_matbench_phonons, base_config, model_architecture, training_config):
@@ -174,30 +181,20 @@
     dataset = CrystalGraphData(
         df=df_matbench_phonons,
-        elem_embedding=base_config["elem_embedding"],
         task_dict=task_dict,
+        max_num_nbr=base_config["max_num_nbr"],
+        radius_cutoff=base_config["radius"],
     )
     n_targets = dataset.n_targets
-    elem_emb_len = dataset.elem_emb_len
-    nbr_fea_len = dataset.nbr_fea_dim
 
     train_idx = list(range(len(dataset)))
-
-    print(f"using {base_config['test_size']} of training set as test set")
     train_idx, test_idx = split(
         train_idx,
         random_state=base_config["data_seed"],
         test_size=base_config["test_size"],
     )
     test_set = torch.utils.data.Subset(dataset, test_idx)
-
-    print("No validation set used, using test set for evaluation purposes")
-    # NOTE that when using this option care must be taken not to
-    # peak at the test-set. The only valid model to use is the one
-    # obtained after the final epoch where the epoch count is
-    # decided in advance of the experiment.
     val_set = test_set
-
     train_set = torch.utils.data.Subset(dataset, train_idx[0 :: base_config["sample"]])
 
     data_params = {
@@ -232,8 +229,10 @@
         "task_dict": task_dict,
         "robust": base_config["robust"],
         "n_targets": n_targets,
-        "elem_emb_len": elem_emb_len,
-        "nbr_fea_len": nbr_fea_len,
+        "elem_embedding": base_config["elem_embedding"],
+        "radius_cutoff": base_config["radius"],
+        "radius_min": base_config["dmin"],
+        "radius_step": base_config["step"],
         **model_architecture,
     }
 
@@ -243,6 +242,7 @@ def test_cgcnn_clf(
        run_id=base_config["run_id"],
        ensemble_folds=base_config["ensemble"],
        epochs=epochs,
+       patience=base_config["patience"],
        train_loader=train_loader,
        val_loader=val_loader,
        log=base_config["log"],
@@ -273,14 +273,12 @@
     logits = results_dict["phdos_clf"]["logits"]
     targets = results_dict["phdos_clf"]["targets"]
 
-    # calculate metrics and errors with associated errors for ensembles
     ens_logits = np.mean(logits, axis=0)
-
-    ens_acc, *_, ens_roc_auc = get_metrics(targets, ens_logits, task).values()
+    metrics = get_metrics(targets, ens_logits, task)
 
     assert len(targets) == len(test_set) == len(test_idx)
-    assert ens_acc > 0.85
-    assert ens_roc_auc > 0.9
+    assert metrics["accuracy"] > 0.85
+    assert metrics["ROCAUC"] > 0.9
 
 
 if __name__ == "__main__":
diff --git a/tests/test_roost.py b/tests/test_roost.py
index 7ff74ec4..8be8b594 100644
--- a/tests/test_roost.py
+++ b/tests/test_roost.py
@@ -69,11 +69,9 @@ def test_roost_regression(
     dataset = CompositionData(
         df=df_matbench_phonons,
-        elem_embedding=base_config["elem_embedding"],
         task_dict=task_dict,
     )
     n_targets = dataset.n_targets
-    elem_emb_len = dataset.elem_emb_len
 
     train_idx = list(range(len(dataset)))
     train_idx, test_idx = split(
@@ -111,7 +109,7 @@
         "task_dict": task_dict,
         "robust": base_config["robust"],
         "n_targets": n_targets,
-        "elem_emb_len": elem_emb_len,
+        "elem_embedding": base_config["elem_embedding"],
         **model_architecture,  # unpack all model architecture parameters
     }
 
@@ -179,11 +177,9 @@ def test_roost_clf(df_matbench_phonons, base_config, model_architecture, trainin
     dataset = CompositionData(
         df=df_matbench_phonons,
-        elem_embedding=base_config["elem_embedding"],
         task_dict=task_dict,
     )
     n_targets = dataset.n_targets
-    elem_emb_len = dataset.elem_emb_len
 
     train_idx = list(range(len(dataset)))
     train_idx, test_idx = split(
@@ -221,7 +217,7 @@
         "task_dict": task_dict,
         "robust": base_config["robust"],
         "n_targets": n_targets,
-        "elem_emb_len": elem_emb_len,
+        "elem_embedding": base_config["elem_embedding"],
         **model_architecture,  # unpack all model architecture parameters
     }
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 00000000..1a7623ea
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,139 @@
+import json
+
+import numpy as np
+import pandas as pd
+import pytest
+import torch
+
+from aviary.utils import get_element_embedding, get_metrics, get_sym_embedding
+
+
+@pytest.fixture
+def temp_element_embedding(tmp_path):
+    embedding_data = {
+        "H": [1.0, 2.0],
+        "He": [3.0, 4.0],
+        "Li": [5.0, 6.0],
+    }
+    path = tmp_path / "test_elem_embedding.json"
+    with open(path, "w") as f:
+        json.dump(embedding_data, f)
+    return str(path)
+
+
+@pytest.fixture
+def temp_sym_embedding(tmp_path):
+    embedding_data = {
+        "1": {"a": [1.0, 2.0], "b": [3.0, 4.0]},
+        "2": {"c": [5.0, 6.0]},
+    }
+    path = tmp_path / "test_sym_embedding.json"
+    with open(path, "w") as f:
+        json.dump(embedding_data, f)
+    return str(path)
+
+
+def test_get_element_embedding_custom(temp_element_embedding):
+    embedding = get_element_embedding(temp_element_embedding)
+    assert isinstance(embedding, torch.nn.Embedding)
+    assert embedding.weight.shape == (3 + 1, 2)  # max_Z + 1, embedding_dim
+    assert torch.allclose(embedding.weight[1], torch.tensor([1.0, 2.0]))  # H
+    assert torch.allclose(embedding.weight[2], torch.tensor([3.0, 4.0]))  # He
+
+
+def test_get_element_embedding_builtin():
+    embedding = get_element_embedding("matscholar200")
+    assert isinstance(embedding, torch.nn.Embedding)
+    assert embedding.weight.shape[1] == 200
+
+
+def test_get_element_embedding_invalid():
+    with pytest.raises(ValueError, match="Invalid element embedding: invalid_embedding"):
+        get_element_embedding("invalid_embedding")
+
+
+def test_get_sym_embedding_custom(temp_sym_embedding):
+    embedding = get_sym_embedding(temp_sym_embedding)
+    assert isinstance(embedding, torch.nn.Embedding)
+    assert embedding.weight.shape == (3, 2)  # total features, embedding_dim
+    assert torch.allclose(embedding.weight[0], torch.tensor([1.0, 2.0]))
+    assert torch.allclose(embedding.weight[1], torch.tensor([3.0, 4.0]))
+
+
+def test_get_sym_embedding_builtin():
+    embedding = get_sym_embedding("bra-alg-off")
+    assert isinstance(embedding, torch.nn.Embedding)
+    assert isinstance(embedding.weight, torch.Tensor)
+
+
+def test_get_sym_embedding_invalid():
+    with pytest.raises(ValueError, match="Invalid symmetry embedding: invalid_embedding"):
+        get_sym_embedding("invalid_embedding")
+
+
+def test_regression_metrics():
+    targets = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
+    predictions = np.array([1.1, 2.1, 3.1, 4.1, 5.1])
+
+    metrics = get_metrics(targets, predictions, "regression")
+
+    assert set(metrics.keys()) == {"MAE", "RMSE", "R2"}
+    assert metrics["MAE"] == pytest.approx(0.1, abs=1e-4)
+    assert metrics["RMSE"] == pytest.approx(0.1, abs=1e-4)
+    assert metrics["R2"] == pytest.approx(0.995, abs=1e-4)
+
+
+def test_classification_metrics():
+    targets = np.array([0, 1, 0, 1, 0])
+    # Probabilities for class 0 and 1
+    predictions = np.array([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8], [0.7, 0.3]])
+
+    metrics = get_metrics(targets, predictions, "classification")
+
+    assert set(metrics.keys()) == {"accuracy", "balanced_accuracy", "F1", "ROCAUC"}
+    assert metrics["accuracy"] == 1.0
+    assert metrics["balanced_accuracy"] == 1.0
+    assert metrics["F1"] == 1.0
+    assert metrics["ROCAUC"] == 1.0
+
+
+def test_nan_handling():
+    targets = np.array([1.0, np.nan, 3.0, 4.0])
+    predictions = np.array([1.1, 2.1, np.nan, 4.1])
+
+    metrics = get_metrics(targets, predictions, "regression")
+    assert not np.isnan(metrics["MAE"])
+    assert not np.isnan(metrics["RMSE"])
+    assert not np.isnan(metrics["R2"])
+
+
+def test_pandas_input():
+    targets = pd.Series([1.0, 2.0, 3.0])
+    predictions = pd.Series([1.1, 2.1, 3.1])
+
+    metrics = get_metrics(targets, predictions, "regression")
+    assert set(metrics.keys()) == {"MAE", "RMSE", "R2"}
+
+
+def test_precision():
+    targets = np.array([1.0, 2.0, 3.0])
+    predictions = np.array([1.12345, 2.12345, 3.12345])
+
+    metrics = get_metrics(targets, predictions, "regression", prec=2)
+    assert all(len(str(v).split(".")[-1]) <= 2 for v in metrics.values())
+
+
+def test_invalid_type():
+    targets = np.array([1.0, 2.0])
+    predictions = np.array([1.1, 2.1])
+
+    with pytest.raises(ValueError, match="Invalid task type: invalid_type"):
+        get_metrics(targets, predictions, "invalid_type")
+
+
+def test_mismatched_shapes():
+    targets = np.array([0, 1, 0])
+    predictions = np.array([[0.9, 0.1], [0.1, 0.9]])  # Wrong shape
+
+    with pytest.raises(ValueError):  # noqa: PT011
+        get_metrics(targets, predictions, "classification")
diff --git a/tests/test_wren.py b/tests/test_wren.py
index bd391391..dffd715b 100644
--- a/tests/test_wren.py
+++ b/tests/test_wren.py
@@ -13,7 +13,7 @@ def base_config():
     return {
         "elem_embedding": "matscholar200",
-        "sym_emb": "bra-alg-off",
+        "sym_embedding": "bra-alg-off",
         "robust": True,
         "ensemble": 2,
         "run_id": 1,
@@ -71,13 +71,9 @@ def test_wren_regression(
     dataset = WyckoffData(
         df=df_matbench_phonons_wyckoff,
-        elem_embedding=base_config["elem_embedding"],
-        sym_emb=base_config["sym_emb"],
         task_dict=task_dict,
     )
     n_targets = dataset.n_targets
-    elem_emb_len = dataset.elem_emb_len
-    sym_emb_len = dataset.sym_emb_len
 
     train_idx = list(range(len(dataset)))
     train_idx, test_idx = split(
@@ -122,8 +118,8 @@
         "task_dict": task_dict,
         "robust": base_config["robust"],
         "n_targets": n_targets,
-        "elem_emb_len": elem_emb_len,
-        "sym_emb_len": sym_emb_len,
+        "elem_embedding": base_config["elem_embedding"],
+        "sym_embedding": base_config["sym_embedding"],
         **model_architecture,
     }
 
@@ -186,13 +182,9 @@ def test_wren_clf(
     dataset = WyckoffData(
         df=df_matbench_phonons_wyckoff,
-        elem_embedding=base_config["elem_embedding"],
-        sym_emb=base_config["sym_emb"],
         task_dict=task_dict,
    )
     n_targets = dataset.n_targets
-    elem_emb_len = dataset.elem_emb_len
-    sym_emb_len = dataset.sym_emb_len
 
     train_idx = list(range(len(dataset)))
     train_idx, test_idx = split(
@@ -237,8 +229,8 @@
         "task_dict": task_dict,
         "robust": base_config["robust"],
         "n_targets": n_targets,
-        "elem_emb_len": elem_emb_len,
-        "sym_emb_len": sym_emb_len,
+        "elem_embedding": base_config["elem_embedding"],
+        "sym_embedding": base_config["sym_embedding"],
         **model_architecture,
     }