Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ ci:

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.6
rev: v0.11.4
hooks:
- id: ruff
args: [--fix]
Expand All @@ -23,14 +23,14 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/codespell-project/codespell
rev: v2.3.0
rev: v2.4.1
hooks:
- id: codespell
exclude_types: [json]
args: [--check-filenames]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.14.1
rev: v1.15.0
hooks:
- id: mypy
exclude: (tests|examples)/
Expand Down
6 changes: 3 additions & 3 deletions aviary/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,9 @@ def predict_from_wandb_checkpoints(
print(f"Using checkpoints from {len(runs)} run(s):")

run_target = runs[0].config["target"]
assert all(
run_target == run.config["target"] for run in runs
), f"Runs have differing targets, first {run_target=}"
assert all(run_target == run.config["target"] for run in runs), (
f"Runs have differing targets, first {run_target=}"
)

target_col = kwargs.get("target_col")
if target_col and target_col != run_target:
Expand Down
12 changes: 6 additions & 6 deletions aviary/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,9 +433,9 @@ def checkpoint_model(
torch.save(checkpoint_dict, checkpoint_path)

if checkpoint_endpoint == "wandb":
assert (
wandb.run is not None
), "can't save model checkpoint to Weights and Biases, wandb.run is None"
assert wandb.run is not None, (
"can't save model checkpoint to Weights and Biases, wandb.run is None"
)
torch.save(
checkpoint_dict,
f"{wandb.run.dir}/{timestamp + '-' if timestamp else ''}{run_name}-{epochs}.pth",
Expand Down Expand Up @@ -584,9 +584,9 @@ def df_train_test_split(
if folds:
n_folds, test_fold_idx = folds
assert 1 < n_folds <= 10, f"{n_folds = } must be between 2 and 10"
assert (
0 <= test_fold_idx < n_folds
), f"{test_fold_idx = } must be between 0 and {n_folds - 1}"
assert 0 <= test_fold_idx < n_folds, (
f"{test_fold_idx = } must be between 0 and {n_folds - 1}"
)

df_splits: list[pd.DataFrame] = np.array_split(df, n_folds)
test_df = df_splits.pop(test_fold_idx)
Expand Down
3 changes: 1 addition & 2 deletions aviary/wren/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,7 @@ def parse_protostructure_label(

if len(elems) != len(wyckoff_letters):
raise ValueError(
f"Chemical system {chemsys} does not match Wyckoff letters "
f"{wyckoff_letters}"
f"Chemical system {chemsys} does not match Wyckoff letters {wyckoff_letters}"
)

wyckoff_site_multiplicities = []
Expand Down
12 changes: 6 additions & 6 deletions tests/test_wyckoff_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,9 +349,9 @@ def test_get_protostructure_label_from_aflow(structure, expected):
@pytest.mark.parametrize("structure, expected", zip(TEST_STRUCTS, TEST_PROTOSTRUCTURES))
def test_get_protostructure_label_from_moyopy(structure, expected):
"""Check that moyopy gives correct protostructure label simple cases."""
assert (
get_protostructure_label_from_moyopy(structure) == expected
), f"unexpected moyopy protostructure for {structure=}"
assert get_protostructure_label_from_moyopy(structure) == expected, (
f"unexpected moyopy protostructure for {structure=}"
)


@pytest.mark.parametrize(
Expand All @@ -365,9 +365,9 @@ def test_moyopy_spglib_consistency(protostructure):
moyopy_label = get_protostructure_label_from_moyopy(struct)
spglib_label = get_protostructure_label_from_spglib(struct)

assert (
moyopy_label == spglib_label
), f"spglib moyopy protostructure mismatch for {protostructure}"
assert moyopy_label == spglib_label, (
f"spglib moyopy protostructure mismatch for {protostructure}"
)


@pytest.mark.skipif(pyxtal is None, reason="pyxtal not installed")
Expand Down