3 changes: 3 additions & 0 deletions .gitignore
@@ -36,3 +36,6 @@ examples/**/job-logs/
examples/**/artifacts/
examples/**/*.csv
wandb/

# profiling
*.prof
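The new `*.prof` ignore pattern implies profiler dumps are expected in the working tree. As an assumption (the diff itself doesn't say which profiler), Python's built-in cProfile is the usual producer of such files:

```python
import cProfile

# Dump stats to a *.prof file (now git-ignored); inspect with pstats or snakeviz
cProfile.run("sum(i * i for i in range(100_000))", "profile_run.prof")
```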
6 changes: 5 additions & 1 deletion README.md
@@ -7,7 +7,7 @@
[![GitHub last commit](https://img.shields.io/github/last-commit/comprhys/aviary?label=Last+Commit)](https://github.com/comprhys/aviary/commits)
[![Tests](https://github.com/CompRhys/aviary/actions/workflows/test.yml/badge.svg)](https://github.com/CompRhys/aviary/actions/workflows/test.yml)
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/CompRhys/aviary/main.svg)](https://results.pre-commit.ci/latest/github/CompRhys/aviary/main)
[![This project supports Python 3.9+](https://img.shields.io/badge/Python-3.9+-blue.svg?logo=python&logoColor=white)](https://python.org/downloads)
[![This project supports Python 3.10+](https://img.shields.io/badge/Python-3.10+-blue.svg?logo=python&logoColor=white)](https://python.org/downloads)

</h4>

@@ -50,6 +50,10 @@ python examples/roost-example.py --train --evaluate --data-path examples/inputs/
python examples/wren-example.py --train --evaluate --data-path examples/inputs/examples.csv --targets E_f --tasks regression --losses L1 --robust --epoch 10
```

```sh
python examples/wrenformer-example.py --train --evaluate --data-path examples/inputs/examples.csv --targets E_f --tasks regression --losses L1 --robust --epoch 10
```

```sh
python examples/cgcnn-example.py --train --evaluate --data-path examples/inputs/examples.json --targets E_f --tasks regression --losses L1 --robust --epoch 10
```
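A recurring change in the Python files below is adding `strict=False` to `zip()` calls. The `strict` keyword only exists on Python ≥ 3.10, which presumably is why it lands alongside the badge bump above: `strict=False` keeps zip's historical behavior of silently truncating to the shortest iterable, while `strict=True` raises on a length mismatch. A quick illustration:

```python
a, b = [1, 2, 3], ["x", "y"]

# strict=False (the default behavior) silently truncates to the shorter input
print(list(zip(a, b, strict=False)))  # [(1, 'x'), (2, 'y')]

# strict=True turns a silent length mismatch into a loud error
try:
    list(zip(a, b, strict=True))
except ValueError as exc:
    print(exc)  # zip() argument 2 is shorter than argument 1
```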
24 changes: 11 additions & 13 deletions aviary/cgcnn/data.py
@@ -1,24 +1,19 @@
from __future__ import annotations

import itertools
import json
from collections.abc import Sequence
from functools import cache
from typing import TYPE_CHECKING, Any
from typing import Any

import numpy as np
import pandas as pd
import torch
from pymatgen.core import Structure
from torch import LongTensor, Tensor
from torch.utils.data import Dataset
from tqdm import tqdm

from aviary import PKG_DIR

if TYPE_CHECKING:
from collections.abc import Sequence

import pandas as pd
from pymatgen.core import Structure


class CrystalGraphData(Dataset):
"""Dataset class for the CGCNN structure model."""
@@ -253,9 +248,10 @@ def collate_batch(
return (
(atom_fea, nbr_dist, self_idx, nbr_idx, cry_idx),
tuple(
torch.stack(b_target, dim=0).to(device) for b_target in zip(*batch_targets)
torch.stack(b_target, dim=0).to(device)
for b_target in zip(*batch_targets, strict=False)
),
*zip(*batch_identifiers),
*zip(*batch_identifiers, strict=False),
)


@@ -332,10 +328,12 @@ def get_structure_neighbor_info(
_neighbor_dists: list[float] = []

for _, idx_group in itertools.groupby( # group by site index
zip(site_indices, neighbor_indices, neighbor_dists), key=lambda x: x[0]
zip(site_indices, neighbor_indices, neighbor_dists, strict=False),
key=lambda x: x[0],
):
site_indices, neighbor_idx, neighbor_dist = zip(
*sorted(idx_group, key=lambda x: x[2]) # sort by distance
*sorted(idx_group, key=lambda x: x[2]),  # sort by distance
strict=False,
)
_center_indices.extend(site_indices[:max_num_nbr])
_neighbor_indices.extend(neighbor_idx[:max_num_nbr])
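The loop above groups a flat neighbor list by central site, orders each group by distance, and truncates to `max_num_nbr`. A self-contained sketch of that pattern (values are illustrative):

```python
import itertools

site_indices = [0, 0, 0, 1, 1]     # already ordered by site, as groupby requires
neighbor_indices = [3, 1, 2, 0, 2]
neighbor_dists = [2.1, 1.0, 3.5, 1.2, 0.9]
max_num_nbr = 2

center, nbrs, dists = [], [], []
triples = zip(site_indices, neighbor_indices, neighbor_dists)
for _, group in itertools.groupby(triples, key=lambda x: x[0]):
    # keep only the max_num_nbr nearest neighbors of this site
    for site, nbr, dist in sorted(group, key=lambda x: x[2])[:max_num_nbr]:
        center.append(site)
        nbrs.append(nbr)
        dists.append(dist)

print(center, nbrs, dists)
# [0, 0, 1, 1] [1, 3, 2, 0] [1.0, 2.1, 0.9, 1.2]
```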
9 changes: 2 additions & 7 deletions aviary/cgcnn/model.py
@@ -1,6 +1,4 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from collections.abc import Sequence

import torch
import torch.nn.functional as F
@@ -11,9 +9,6 @@
from aviary.networks import SimpleNetwork
from aviary.scatter import scatter_reduce

if TYPE_CHECKING:
from collections.abc import Sequence


@due.dcite(Doi("10.1103/PhysRevLett.120.145301"), description="CGCNN model")
class CrystalGraphConvNet(BaseModelClass):
@@ -215,7 +210,7 @@ def forward(
Args:
atom_in_fea (Tensor): Atom hidden features before convolution
nbr_fea (Tensor): Bond features of each atom's neighbors
self_idx (LongTensor): _description_
self_idx (LongTensor): Indices of the central atom in each neighbor pair
nbr_idx (LongTensor): Indices of M neighbors of each atom

Returns:
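For readers new to the indexing scheme documented above: `self_idx` and `nbr_idx` form parallel arrays of (central atom, neighbor) pairs. A self-contained stand-in for the skeleton those indices drive (not the real CGCNN gating, just the gather-and-aggregate step):

```python
import torch

atom_fea = torch.randn(4, 8)           # 4 atoms with 8 features each
self_idx = torch.tensor([0, 0, 1, 2])  # central atom of each pair
nbr_idx = torch.tensor([1, 2, 0, 3])   # neighbor atom of each pair

# gather both ends of every pair and form a toy "message" ...
messages = atom_fea[self_idx] * atom_fea[nbr_idx]
# ... then sum the messages back onto each central atom
out = torch.zeros_like(atom_fea).index_add_(0, self_idx, messages)
print(out.shape)  # torch.Size([4, 8])
```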
135 changes: 40 additions & 95 deletions aviary/core.py
@@ -1,29 +1,23 @@
from __future__ import annotations

import gc
import os
import shutil
from abc import ABC
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Literal
from collections.abc import Callable, Mapping
from typing import Any, Literal

import numpy as np
import torch
import wandb
from sklearn.metrics import f1_score
from torch import BoolTensor, Tensor, nn
from torch.nn.functional import softmax
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from aviary import ROOT

if TYPE_CHECKING:
from collections.abc import Mapping

from torch.utils.data import DataLoader

from aviary.data import InMemoryDataLoader
from aviary.data import InMemoryDataLoader, Normalizer

TaskType = Literal["regression", "classification"]

@@ -129,6 +123,14 @@ def fit(
for metric, val in metrics.items():
writer.add_scalar(f"{task}/train/{metric}", val, epoch)

if writer == "wandb":
flat_train_metrics = {}
for task, metrics in train_metrics.items():
for metric, val in metrics.items():
flat_train_metrics[f"train_{task}_{metric.lower()}"] = val
flat_train_metrics["epoch"] = epoch
wandb.log(flat_train_metrics)

# Validation
if val_loader is not None:
with torch.no_grad():
@@ -149,6 +151,14 @@
f"{task}/validation/{metric}", val, epoch
)

if writer == "wandb":
flat_val_metrics = {}
for task, metrics in val_metrics.items():
for metric, val in metrics.items():
flat_val_metrics[f"val_{task}_{metric.lower()}"] = val
flat_val_metrics["epoch"] = epoch
wandb.log(flat_val_metrics)

# TODO test all tasks to see if they are best,
# save a best model if any is best.
# TODO what are the costs of this approach.
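The two blocks added above log flattened per-epoch metrics to wandb, replacing the single end-of-training `wandb.log` call removed further down. The transformation they share could be one helper; a self-contained sketch (function name and sample values are illustrative):

```python
def flatten_metrics(split: str, metrics: dict, epoch: int) -> dict:
    """Flatten {task: {metric: value}} into scalar keys wandb can chart."""
    flat = {
        f"{split}_{task}_{metric.lower()}": val
        for task, task_metrics in metrics.items()
        for metric, val in task_metrics.items()
    }
    flat["epoch"] = epoch
    return flat

print(flatten_metrics("val", {"E_f": {"MAE": 0.031, "Loss": 0.045}}, epoch=7))
# {'val_E_f_mae': 0.031, 'val_E_f_loss': 0.045, 'epoch': 7}
```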
@@ -207,9 +217,6 @@ def fit(
# catch memory leak
gc.collect()

if writer == "wandb":
wandb.log({"train": train_metrics, "validation": val_metrics})

except KeyboardInterrupt:
pass

@@ -271,7 +278,11 @@ def evaluate(
mixed_loss: Tensor = 0 # type: ignore[assignment]

for target_name, targets, output, normalizer in zip(
self.target_names, targets_list, outputs, normalizer_dict.values()
self.target_names,
targets_list,
outputs,
normalizer_dict.values(),
strict=False,
):
task, loss_func = loss_dict[target_name]
target_metrics = epoch_metrics[target_name]
@@ -318,7 +329,7 @@ def evaluate(
else:
raise ValueError(f"invalid task: {task}")

epoch_metrics[target_name]["Loss"].append(loss.cpu().item())
target_metrics["Loss"].append(loss.cpu().item())

# NOTE multitasking currently just uses a direct sum of individual
# target losses; this should be okay but is perhaps sub-optimal
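The note above is the whole multitask weighting story: per-target losses are summed unweighted. A sketch of what that means, plus the per-task weighting it hints could replace it (loss values and weights here are made up):

```python
import torch

task_losses = {"E_f": torch.tensor(0.12), "band_gap": torch.tensor(0.40)}

# direct sum, as in the loop above: every task contributes equally
mixed_loss = sum(task_losses.values())

# one possible refinement: fixed per-task weights (illustrative values)
weights = {"E_f": 1.0, "band_gap": 0.5}
weighted_loss = sum(w * task_losses[task] for task, w in weights.items())
```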
@@ -396,11 +407,13 @@ def predict(
# for multitask learning
targets = tuple(
torch.cat(targets, dim=0).view(-1).cpu().numpy()
for targets in zip(*test_targets)
for targets in zip(*test_targets, strict=False)
)
predictions = tuple(
torch.cat(preds, dim=0) for preds in zip(*test_preds, strict=False)
)
predictions = tuple(torch.cat(preds, dim=0) for preds in zip(*test_preds))
# identifier columns
ids = tuple(np.concatenate(x) for x in zip(*test_ids))
ids = tuple(np.concatenate(x) for x in zip(*test_ids, strict=False))
return targets, predictions, ids

@torch.no_grad()
Expand Down Expand Up @@ -445,83 +458,6 @@ def __repr__(self) -> str:
return f"{cls_name} with {n_params:,} trainable params at {n_epochs:,} epochs"


class Normalizer:
"""Normalize a Tensor and restore it later."""

def __init__(self) -> None:
"""Initialize Normalizer with mean 0 and std 1."""
self.mean = torch.tensor(0)
self.std = torch.tensor(1)

def fit(self, tensor: Tensor, dim: int = 0, keepdim: bool = False) -> None:
"""Compute the mean and standard deviation of the given tensor.

Args:
tensor (Tensor): Tensor to determine the mean and standard deviation over.
dim (int, optional): Which dimension to take mean and standard deviation
over. Defaults to 0.
keepdim (bool, optional): Whether to keep the reduced dimension in Tensor.
Defaults to False.
"""
self.mean = torch.mean(tensor, dim, keepdim)
self.std = torch.std(tensor, dim, keepdim)

def norm(self, tensor: Tensor) -> Tensor:
"""Normalize a Tensor.

Args:
tensor (Tensor): Tensor to be normalized

Returns:
Tensor: Normalized Tensor
"""
return (tensor - self.mean) / self.std

def denorm(self, normed_tensor: Tensor) -> Tensor:
"""Restore normalized Tensor to original.

Args:
normed_tensor (Tensor): Tensor to be restored

Returns:
Tensor: Restored Tensor
"""
return normed_tensor * self.std + self.mean

def state_dict(self) -> dict[str, Tensor]:
"""Get Normalizer parameters mean and std.

Returns:
dict[str, Tensor]: Dictionary storing Normalizer parameters.
"""
return {"mean": self.mean, "std": self.std}

def load_state_dict(self, state_dict: dict[str, Tensor]) -> None:
"""Overwrite Normalizer parameters given a new state_dict.

Args:
state_dict (dict[str, Tensor]): Dictionary storing Normalizer parameters.
"""
self.mean = state_dict["mean"].cpu()
self.std = state_dict["std"].cpu()

@classmethod
def from_state_dict(cls, state_dict: dict[str, Tensor]) -> Normalizer:
"""Create a new Normalizer given a state_dict.

Args:
state_dict (dict[str, Tensor]): Dictionary storing Normalizer parameters.

Returns:
Normalizer
"""
instance = cls()
instance.mean = state_dict["mean"].cpu()
instance.std = state_dict["std"].cpu()

return instance
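These deleted lines remove `Normalizer` from core.py; per the import hunk at the top of this file, it is now imported from `aviary.data` instead. Its round-trip contract is easy to show — a minimal sketch assuming the class as defined above is importable:

```python
import torch

normalizer = Normalizer()  # the class shown above, now in aviary.data
targets = torch.tensor([1.0, 2.0, 3.0, 4.0])
normalizer.fit(targets)

normed = normalizer.norm(targets)      # zero mean, unit std
restored = normalizer.denorm(normed)   # inverse transform
assert torch.allclose(restored, targets)

# state round-trips through plain tensors, e.g. for checkpointing
clone = Normalizer.from_state_dict(normalizer.state_dict())
```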


def save_checkpoint(
state: dict[str, Any], is_best: bool, model_name: str, run_id: int
) -> None:
@@ -662,3 +598,12 @@ def masked_min(x: Tensor, mask: BoolTensor, dim: int = 0) -> Tensor:
x_inf = x.float().masked_fill(~mask, float("inf"))
x_min, _ = x_inf.min(dim=dim)
return x_min


AGGREGATORS: dict[str, Callable[[Tensor, BoolTensor, int], Tensor]] = {
"mean": masked_mean,
"std": masked_std,
"max": masked_max,
"min": masked_min,
"sum": lambda x, mask, dim: (x * mask).sum(dim=dim),
}
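A hedged usage sketch for a dispatch table of this shape. The `masked_mean` below is a stand-in with the same signature re-implemented so the snippet runs standalone; the real aggregators live above in this module:

```python
import torch
from torch import BoolTensor, Tensor

def masked_mean(x: Tensor, mask: BoolTensor, dim: int = 0) -> Tensor:
    # stand-in: average over unmasked entries only
    return (x * mask).sum(dim=dim) / mask.sum(dim=dim)

AGGREGATORS = {
    "mean": masked_mean,
    "sum": lambda x, mask, dim: (x * mask).sum(dim=dim),
}

x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
mask = torch.tensor([[True], [True], [False]])  # drop the padded 3rd row

print(AGGREGATORS["mean"](x, mask, 0))  # tensor([2., 3.])
print(AGGREGATORS["sum"](x, mask, 0))   # tensor([4., 6.])
```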