Skip to content

Commit c533935

Browse files
author
Richard Liang
committed
WIP: refactoring the stored HLA standard data to be in a YAML file rather than three CSV files.
1 parent 60a8366 commit c533935

File tree

9 files changed

+101205
-977
lines changed

9 files changed

+101205
-977
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ dependencies = [
3030
"pydantic>=2.10.6",
3131
"pydantic-numpy>=8.0.1",
3232
"pytz>=2025.1",
33+
"pyyaml>=6.0.2",
3334
"requests>=2.32.3",
3435
"typer>=0.15.2",
3536
]
@@ -61,6 +62,7 @@ Source = "https://github.com/unknown/easyhla"
6162
clinical_hla = "easyhla.clinical_hla:main"
6263
interpret_from_json = "easyhla.interpret_from_json:main"
6364
bblab = "easyhla.bblab:main"
65+
update_alleles = "easyhla.update_alleles:main"
6466

6567
[project.optional-dependencies]
6668
database = [
@@ -113,6 +115,7 @@ omit = [
113115
"src/easyhla/bblab.py",
114116
"src/easyhla/clinical_hla.py",
115117
"src/easyhla/interpret_from_json.py",
118+
"src/easyhla/update_alleles.py",
116119
"src/scripts/*.py",
117120
]
118121

src/easyhla/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +0,0 @@
1-
# from scripts.bblab import run as run
2-
3-
from .easyhla import EasyHLA as EasyHLA

src/easyhla/default_data/hla_standards.yaml

Lines changed: 100164 additions & 0 deletions
Large diffs are not rendered by default.

src/easyhla/easyhla.py

Lines changed: 109 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
from datetime import datetime
44
from io import TextIOBase
55
from operator import attrgetter
6-
from typing import Final, Optional
6+
from typing import Final, Optional, TypedDict
77

88
import numpy as np
9+
import yaml
910

1011
from .models import (
1112
HLACombinedStandard,
@@ -20,13 +21,21 @@
2021
from .utils import (
2122
BIN2NUC,
2223
HLA_LOCUS,
24+
GroupedAllele,
25+
StoredHLAStandards,
2326
count_strict_mismatches,
2427
nuc2bin,
2528
)
2629

2730
DATE_FORMAT = "%a %b %d %H:%M:%S %Z %Y"
2831

2932

33+
class ProcessedStoredStandards(TypedDict):
    """
    Stored HLA standards after they have been read and converted for use by
    EasyHLA.
    """

    # Tag carried over from the stored standards data.
    tag: str
    # Last-modified timestamp carried over from the stored standards data.
    last_modified: datetime
    # For each locus, the HLA standards keyed by allele name.
    standards: dict[HLA_LOCUS, dict[str, HLAStandard]]
37+
38+
3039
class EasyHLA:
3140
# For HLA-B interpretations, these alleles are the ones we use to determine
3241
# how close a sequence is to "B*57:01".
@@ -40,61 +49,89 @@ class EasyHLA:
4049

4150
def __init__(
    self,
    hla_standards: Optional[ProcessedStoredStandards] = None,
    hla_frequencies: Optional[dict[HLA_LOCUS, dict[HLAProteinPair, int]]] = None,
):
    """
    Initialize an EasyHLA class.

    :param hla_standards: Processed HLA standards (tag, last-modified time,
        and per-locus standards) to interpret against; if None, the default
        standards are loaded via load_default_hla_standards().
    :type hla_standards: Optional[ProcessedStoredStandards], optional
    :param hla_frequencies: Per-locus lookup table of HLA frequencies; if
        None, the defaults are loaded via load_default_hla_frequencies().
    :type hla_frequencies: Optional[dict[HLA_LOCUS, dict[HLAProteinPair, int]]],
        optional
    """
    if hla_standards is None:
        hla_standards = self.load_default_hla_standards()

    # Unpack the processed standards into individual attributes.
    self.hla_standards: dict[HLA_LOCUS, dict[str, HLAStandard]] = hla_standards[
        "standards"
    ]
    self.last_modified: datetime = hla_standards["last_modified"]
    self.tag: str = hla_standards["tag"]

    self.hla_frequencies: dict[HLA_LOCUS, dict[HLAProteinPair, int]]
    if hla_frequencies is not None:
        self.hla_frequencies = hla_frequencies
    else:
        self.hla_frequencies = self.load_default_hla_frequencies()
76+
@classmethod
def use_config(
    cls,
    standards_path: Optional[str] = None,
    frequencies_path: Optional[str] = None,
) -> "EasyHLA":
    """
    An alternate constructor that accepts file paths for the configuration.

    :param standards_path: Path to a file of stored HLA standards, read with
        read_hla_standards(); if None, no standards are passed to the
        constructor, which then loads the defaults.
    :type standards_path: Optional[str], optional
    :param frequencies_path: Path to a file of HLA frequencies, read with
        read_hla_frequencies(); if None, no frequencies are passed to the
        constructor, which then loads the defaults.
    :type frequencies_path: Optional[str], optional
    :return: A new instance configured from the given files.
    :rtype: EasyHLA
    """
    processed_stds: Optional[ProcessedStoredStandards] = None
    frequencies: Optional[dict[HLA_LOCUS, dict[HLAProteinPair, int]]] = None

    if standards_path is not None:
        with open(standards_path) as standards_file:
            processed_stds = cls.read_hla_standards(standards_file)

    if frequencies_path is not None:
        with open(frequencies_path) as frequencies_file:
            frequencies = cls.read_hla_frequencies(frequencies_file)

    return cls(processed_stds, frequencies)
7997

8098
@staticmethod
def read_hla_standards(
    standards_io: TextIOBase,
) -> ProcessedStoredStandards:
    """
    Read HLA standards from a specified file-like object.

    The stream is parsed as YAML (with yaml.safe_load) and validated as a
    StoredHLAStandards model; each grouped allele listed under loci A, B,
    and C is then converted to an HLAStandard, with its exon2 and exon3
    sequences passed through nuc2bin.

    :param standards_io: Text stream containing the stored standards.
    :type standards_io: TextIOBase
    :return: The stored tag and last-modified time, plus the HLA standards
        for each locus keyed by allele name.
    :rtype: ProcessedStoredStandards
    """
    stored_stds: StoredHLAStandards = StoredHLAStandards.model_validate(
        yaml.safe_load(standards_io)
    )

    stored_grouped_alleles: dict[HLA_LOCUS, list[GroupedAllele]] = {
        "A": stored_stds.A,
        "B": stored_stds.B,
        "C": stored_stds.C,
    }
    # If an allele name is repeated within a locus, the last entry wins.
    hla_stds: dict[HLA_LOCUS, dict[str, HLAStandard]] = {
        locus: {
            grouped_allele.name: HLAStandard(
                allele=grouped_allele.name,
                two=nuc2bin(grouped_allele.exon2),
                three=nuc2bin(grouped_allele.exon3),
            )
            for grouped_allele in grouped_alleles
        }
        for locus, grouped_alleles in stored_grouped_alleles.items()
    }

    return {
        "tag": stored_stds.tag,
        "last_modified": stored_stds.last_modified,
        "standards": hla_stds,
    }
98135

99136
def load_default_hla_standards(self) -> dict[str, HLAStandard]:
100137
"""
@@ -106,44 +143,50 @@ def load_default_hla_standards(self) -> dict[str, HLAStandard]:
106143
standards_filename: str = os.path.join(
107144
os.path.dirname(__file__),
108145
"default_data",
109-
f"hla_{self.locus.lower()}_std_reduced.csv",
146+
"hla_standards.yaml",
110147
)
111148
with open(standards_filename) as standards_file:
112149
return self.read_hla_standards(standards_file)
113150

114151
@staticmethod
def read_hla_frequencies(
    frequencies_io: TextIOBase,
) -> dict[HLA_LOCUS, dict[HLAProteinPair, int]]:
    """
    Load HLA frequencies from a specified file-like object.

    Each line is split on commas.  For every locus, the two adjacent columns
    starting at EasyHLA.COLUMN_IDS[locus] are taken; the first four
    characters of each column are split into two 2-character fields, and
    the resulting HLAProteinPair is counted once for that locus.

    :param frequencies_io: Text stream of comma-separated frequency data.
    :type frequencies_io: TextIOBase
    :return: Lookup table of locus and HLA frequencies.
    :rtype: dict[HLA_LOCUS, dict[HLAProteinPair, int]]
    """
    hla_freqs: dict[HLA_LOCUS, dict[HLAProteinPair, int]] = {
        "A": {},
        "B": {},
        "C": {},
    }
    # The stream can only be consumed once, so each line is read a single
    # time and used to update every locus.  (Looping over the loci on the
    # outside and re-reading the stream for each one would leave every locus
    # after the first with no data.)
    for line in frequencies_io:
        columns = line.strip().split(",")
        for locus, locus_freqs in hla_freqs.items():
            column_id = EasyHLA.COLUMN_IDS[locus]
            line_array = columns[column_id : column_id + 2]

            protein_pair: HLAProteinPair = HLAProteinPair(
                first_field_1=line_array[0][0:2],
                first_field_2=line_array[0][2:4],
                second_field_1=line_array[1][0:2],
                second_field_2=line_array[1][2:4],
            )
            locus_freqs[protein_pair] = locus_freqs.get(protein_pair, 0) + 1
    return hla_freqs
145186

146-
def load_default_hla_frequencies(self) -> dict[HLAProteinPair, int]:
187+
def load_default_hla_frequencies(
188+
self,
189+
) -> dict[HLA_LOCUS, dict[HLAProteinPair, int]]:
147190
"""
148191
Load HLA frequencies from reference file.
149192
@@ -153,35 +196,18 @@ def load_default_hla_frequencies(self) -> dict[HLAProteinPair, int]:
153196
columns.
154197
155198
:return: Lookup table of HLA frequencies.
156-
:rtype: dict[HLAProteinPair, int]
199+
:rtype: dict[HLA_LOCUS, dict[HLAProteinPair, int]]
157200
"""
158-
hla_freqs: dict[HLAProteinPair, int]
201+
hla_freqs: dict[HLA_LOCUS, dict[HLAProteinPair, int]]
159202
default_frequencies_filename: str = os.path.join(
160203
os.path.dirname(__file__),
161204
"default_data",
162205
"hla_frequencies.csv",
163206
)
164207
with open(default_frequencies_filename, "r") as f:
165-
hla_freqs = self.read_hla_frequencies(self.locus, f)
208+
hla_freqs = self.read_hla_frequencies(f)
166209
return hla_freqs
167210

168-
@staticmethod
169-
def load_default_last_modified() -> datetime:
170-
"""
171-
Load a datetime object describing when standard definitions were last updated.
172-
173-
:return: Date and time representing when references were last updated.
174-
:rtype: datetime
175-
"""
176-
filename = os.path.join(
177-
os.path.dirname(__file__),
178-
"default_data",
179-
"hla_nuc.fasta.mtime",
180-
)
181-
with open(filename, "r", encoding="utf-8") as f:
182-
last_mod_date = "".join(f.readlines()).strip()
183-
return datetime.strptime(last_mod_date, DATE_FORMAT)
184-
185211
@staticmethod
186212
def get_matching_standards(
187213
seq: Sequence[int],
@@ -320,10 +346,11 @@ def combine_standards(
320346

321347
return result
322348

349+
@staticmethod
323350
def get_mismatches(
324-
self,
325351
standard_bin: Sequence[int],
326352
sequence_bin: Sequence[int],
353+
locus: HLA_LOCUS,
327354
) -> list[HLAMismatch]:
328355
"""
329356
Report mismatched bases and their location versus a standard.
@@ -352,7 +379,7 @@ def get_mismatches(
352379
mislist: list[HLAMismatch] = []
353380

354381
for index, correct_base_bin in correct_base_at_pos.items():
355-
if self.locus == "A" and index > 269: # i.e. > 270 in 1-based indices
382+
if locus == "A" and index > 269: # i.e. > 270 in 1-based indices
356383
# This is 241 + 1, where the 1 converts from 0-based to 1-based
357384
# indices.
358385
dex = index + 242
@@ -391,8 +418,11 @@ def interpret(
391418
:rtype: HLAInterpretation
392419
"""
393420
seq: tuple[int, ...] = hla_sequence.sequence_for_interpretation
421+
locus: HLA_LOCUS = hla_sequence.locus
394422

395-
matching_stds = self.get_matching_standards(seq, self.hla_standards.values())
423+
matching_stds = self.get_matching_standards(
424+
seq, self.hla_standards[locus].values()
425+
)
396426
if len(matching_stds) == 0:
397427
raise EasyHLA.NoMatchingStandards()
398428

@@ -405,7 +435,7 @@ def interpret(
405435
)
406436

407437
b5701_standards: Optional[list[HLAStandard]] = None
408-
if self.locus == "B":
438+
if locus == "B":
409439
b5701_standards = [
410440
self.hla_standards[allele] for allele in self.B5701_ALLELES
411441
]

src/easyhla/models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import re
22
from collections.abc import Iterable
3+
from datetime import datetime
34
from operator import itemgetter
4-
from typing import Optional
5+
from typing import Optional, TypedDict
56

67
import numpy as np
78
from pydantic import BaseModel, ConfigDict

0 commit comments

Comments
 (0)