Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,26 @@ on:

jobs:
tests:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-14]
version:
- { python: "3.10", resolution: highest }
- { python: "3.12", resolution: lowest-direct }
runs-on: ${{ matrix.os }}

steps:
- name: Check out repo
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: 3.9
cache: pip
cache-dependency-path: pyproject.toml
python-version: ${{ matrix.version.python }}

- name: Install uv
run: pip install uv
- name: Set up uv
uses: astral-sh/setup-uv@v2

- name: Install dependencies
run: |
Expand Down
16 changes: 8 additions & 8 deletions aviary/wren/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@

import numpy as np
import torch
from pymatgen.analysis.prototypes import (
RE_SUBST_ONE_PREFIX,
RE_WYCKOFF_NO_PREFIX,
WYCKOFF_MULTIPLICITY_DICT,
WYCKOFF_POSITION_RELAB_DICT,
)
from torch import LongTensor, Tensor
from torch.utils.data import Dataset

from aviary import PKG_DIR
from aviary.wren.utils import (
RE_SUBST_ONE_PREFIX,
RE_WYCKOFF_NO_PREFIX,
relab_dict,
wyckoff_multiplicity_dict,
)

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -300,13 +300,13 @@ def parse_protostructure_label(
elements.extend([el] * mult)
wyckoff_set.extend([letter] * mult)
wyckoff_site_multiplicities.extend(
[float(wyckoff_multiplicity_dict[spg_num][letter])] * mult
[float(WYCKOFF_MULTIPLICITY_DICT[spg_num][letter])] * mult
)

# Create augmented Wyckoff set
augmented_wyckoff_set = {
tuple(",".join(wyckoff_set).translate(str.maketrans(trans)).split(","))
for trans in relab_dict[spg_num]
for trans in WYCKOFF_POSITION_RELAB_DICT[spg_num]
}

return spg_num, wyckoff_site_multiplicities, elements, list(augmented_wyckoff_set)
Loading