Skip to content

Commit 8ed2114

Browse files
author
Richard Liang
committed
WIP: some polish on how the standards are stored and loaded.
Tests still need to be updated, as does the default standards YAML.
1 parent a139acc commit 8ed2114

File tree

4 files changed

+152
-35
lines changed

4 files changed

+152
-35
lines changed

src/easyhla/easyhla.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from .utils import (
2222
BIN2NUC,
2323
HLA_LOCUS,
24-
GroupedAllele,
2524
StoredHLAStandards,
2625
count_strict_mismatches,
2726
nuc2bin,
@@ -76,7 +75,7 @@ def __init__(
7675
@classmethod
7776
def use_config(
7877
cls,
79-
standards_path: Optional[str],
78+
standards_path: Optional[str] = None,
8079
frequencies_path: Optional[str] = None,
8180
) -> "EasyHLA":
8281
"""
@@ -112,13 +111,8 @@ def read_hla_standards(standards_io: TextIOBase) -> LoadedStandards:
112111
"B": {},
113112
"C": {},
114113
}
115-
stored_grouped_alleles: dict[HLA_LOCUS, list[GroupedAllele]] = {
116-
"A": stored_stds.A,
117-
"B": stored_stds.B,
118-
"C": stored_stds.C,
119-
}
120114
for locus in ("A", "B", "C"):
121-
for grouped_allele in stored_grouped_alleles[locus]:
115+
for grouped_allele in stored_stds.standards[locus]:
122116
hla_stds[locus][grouped_allele.name] = HLAStandard(
123117
allele=grouped_allele.name,
124118
two=nuc2bin(grouped_allele.exon2),

src/easyhla/models.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import re
22
from collections.abc import Iterable
3-
from datetime import datetime
43
from operator import itemgetter
5-
from typing import Optional, TypedDict
4+
from typing import Optional
65

76
import numpy as np
87
from pydantic import BaseModel, ConfigDict

src/easyhla/update_alleles.py

Lines changed: 108 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
import argparse
44
import hashlib
5+
import json
56
import logging
67
import os
78
import time
89
from datetime import datetime
910
from io import StringIO
10-
from typing import Final
11+
from typing import Final, Optional, TypedDict
1112

1213
import Bio
1314
import requests
@@ -81,9 +82,13 @@
8182

8283
# Find all releases (and their corresponding tags) of the HLA data at
8384
# https://github.com/ANHIG/IMGTHLA/releases
84-
REPO_PATH: Final[str] = os.environ.get(
85-
"EASYHLA_REPO_PATH",
86-
"https://raw.githubusercontent.com/ANHIG/IMGTHLA",
85+
# Owner and name of the GitHub repository the alleles are retrieved from.
# Both can be overridden through the environment; the defaults point at
# https://github.com/ANHIG/IMGTHLA.
REPO_OWNER: Final[str] = os.getenv("EASYHLA_REPO_OWNER", "ANHIG")
REPO_NAME: Final[str] = os.getenv("EASYHLA_REPO_NAME", "IMGTHLA")
8893
HLA_ALLELES_FILENAME: Final[str] = os.environ.get(
8994
"EASYHLA_REPO_ALLELES_FILENAME",
@@ -95,21 +100,109 @@ class RetrieveAllelesError(Exception):
95100
pass
96101

97102

103+
class RetrieveCommitHashError(Exception):
    """Signals that the commit hash of a tag could not be retrieved."""
105+
106+
98107
def get_alleles_file(
    tag: str,
    repo_owner: str = REPO_OWNER,
    repo_name: str = REPO_NAME,
    alleles_filename: str = HLA_ALLELES_FILENAME,
    timeout: float = 60.0,
) -> str:
    """
    Retrieve the HLA alleles file from the specified tag.

    The file is fetched through the GitHub REST "contents" endpoint, asking
    for the raw media type so that the response body is the file itself.

    :param tag: the tag (git ref) to read the file from.
    :param repo_owner: owner of the GitHub repository.
    :param repo_name: name of the GitHub repository.
    :param alleles_filename: path of the alleles file inside the repository.
    :param timeout: seconds to wait for the server before giving up.
    :return: the text of the alleles file.
    :raises RetrieveAllelesError: if the request cannot be completed or the
        server answers with anything other than HTTP 200.
    """
    url: str = (
        f"https://api.github.com/repos/{repo_owner}/{repo_name}/"
        f"contents/{alleles_filename}"
    )
    try:
        response: requests.Response = requests.get(
            url,
            # Passing the ref as a parameter lets requests URL-encode it.
            params={"ref": tag},
            headers={
                "Accept": "application/vnd.github.raw+json",
                "X-GitHub-Api-Version": "2022-11-28",
            },
            # Without a timeout a stalled connection would block forever.
            timeout=timeout,
        )
    except requests.RequestException as err:
        # Surface transport failures as the error type callers retry on.
        raise RetrieveAllelesError(
            f"Could not retrieve {url} at ref {tag}: {err}"
        ) from err

    if response.status_code != requests.codes.ok:
        raise RetrieveAllelesError(
            f"Retrieving {url} at ref {tag} failed with HTTP status "
            f"{response.status_code}"
        )
    return response.text
111130

112131

132+
class CommitInfo(TypedDict):
    """Shape of the ``commit`` entry inside a tag listed by the GitHub API."""

    # Hash of the commit.
    sha: str
    # URL associated with the commit.
    url: str
135+
136+
137+
class TagInfo(TypedDict):
    """Shape of one entry in the tag list returned by the GitHub API."""

    # Name of the tag; this is what callers look tags up by.
    name: str
    # The commit the tag refers to.
    commit: CommitInfo
    zipball_url: str
    tarball_url: str
    node_id: str
143+
144+
145+
def get_commit_hash(
    tag_name: str,
    repo_owner: str = REPO_OWNER,
    repo_name: str = REPO_NAME,
) -> Optional[str]:
    """
    Retrieve the commit hash of the specified tag.

    The repository's tags are listed through the GitHub REST API.  That
    listing is paginated, so pages are requested one after another until
    the tag is found or the listing is exhausted.

    :param tag_name: name of the tag to look up.
    :param repo_owner: owner of the GitHub repository.
    :param repo_name: name of the GitHub repository.
    :return: the SHA of the commit the tag points at, or None if the
        repository has no tag with that name.
    :raises RetrieveCommitHashError: if a request cannot be completed or the
        server answers with anything other than HTTP 200.
    """
    url: str = f"https://api.github.com/repos/{repo_owner}/{repo_name}/tags"
    # Ask for the largest page size the API allows to minimise round trips.
    per_page: int = 100
    page: int = 1

    while True:
        try:
            response: requests.Response = requests.get(
                url,
                params={"per_page": per_page, "page": page},
                headers={
                    "Accept": "application/vnd.github+json",
                    "X-GitHub-Api-Version": "2022-11-28",
                },
                # Without a timeout a stalled connection would block forever.
                timeout=60,
            )
        except requests.RequestException as err:
            # Surface transport failures as the error type callers retry on.
            raise RetrieveCommitHashError(
                f"Could not retrieve the tags from {url}: {err}"
            ) from err

        if response.status_code != requests.codes.ok:
            raise RetrieveCommitHashError(
                f"Retrieving the tags from {url} failed with HTTP status "
                f"{response.status_code}"
            )

        tags: list[TagInfo] = json.loads(response.text)
        for tag in tags:
            if tag["name"] == tag_name:
                return tag["commit"]["sha"]

        # A short (or empty) page means there are no further pages to check.
        if len(tags) < per_page:
            return None
        page += 1
170+
171+
172+
def get_from_git(tag: str) -> tuple[str, datetime, str]:
    """
    Retrieve the alleles file and the commit hash for the specified tag.

    Each of the two retrievals is attempted up to five times, pausing 20
    seconds between attempts; the error from the final attempt propagates.

    :param tag: the tag to retrieve the alleles from.
    :return: a tuple of (alleles file contents, the time at which the
        successful alleles retrieval was started, commit hash of the tag).
    :raises RetrieveAllelesError: if the alleles file could not be retrieved.
    :raises RetrieveCommitHashError: if the commit hash could not be
        retrieved, or no tag with the given name was found.
    """
    max_attempts: int = 5
    retry_delay_seconds: int = 20

    alleles_str: str
    retrieval_datetime: datetime
    for attempt in range(1, max_attempts + 1):
        try:
            retrieval_datetime = datetime.now()
            alleles_str = get_alleles_file(tag)
        except RetrieveAllelesError:
            if attempt == max_attempts:
                raise
            logger.info("Failed to retrieve alleles; retrying in 20 seconds....")
            time.sleep(retry_delay_seconds)
        else:
            break

    commit_hash: Optional[str] = None
    for attempt in range(1, max_attempts + 1):
        try:
            commit_hash = get_commit_hash(tag)
        except RetrieveCommitHashError:
            if attempt == max_attempts:
                raise
            logger.info(
                "Failed to retrieve the commit hash; retrying in 20 seconds...."
            )
            time.sleep(retry_delay_seconds)
        else:
            break

    # A missing tag is reported as None rather than as an exception; fail here
    # instead of handing a None "hash" to the caller, which expects a string.
    if commit_hash is None:
        raise RetrieveCommitHashError(f"No commit hash found for tag {tag}")

    return alleles_str, retrieval_datetime, commit_hash
204+
205+
113206
def main():
114207
parser: argparse.ArgumentParser = argparse.ArgumentParser(
115208
"Retrieve HLA alleles from IPD-IMGT/HLA."
@@ -178,18 +271,12 @@ def main():
178271
logger.info(f"Retrieving alleles from tag {args.tag}....")
179272
alleles_str: str
180273
retrieval_datetime: datetime
181-
for i in range(5):
182-
try:
183-
retrieval_datetime = datetime.now()
184-
alleles_str = get_alleles_file(args.tag)
185-
except RetrieveAllelesError:
186-
if i < 4:
187-
logger.info("Failed to retrieve alleles; retrying in 20 seconds....")
188-
time.sleep(20)
189-
else:
190-
raise
191-
else:
192-
break
274+
commit_hash: str
275+
alleles_str, retrieval_datetime, commit_hash = get_from_git(args.tag)
276+
logger.info(
277+
f"Alleles (version {args.tag}, commit hash {commit_hash}) retrieved at "
278+
f"{retrieval_datetime}."
279+
)
193280

194281
if args.dump_full_fasta_to != "":
195282
logger.info(f"Dumping the full FASTA file to {args.dump_full_fasta_to}.")
@@ -214,8 +301,9 @@ def main():
214301
logger.info("Identifying identical HLA alleles....")
215302
standards_for_saving: StoredHLAStandards = StoredHLAStandards(
216303
tag=args.tag,
304+
commit_hash=commit_hash,
217305
last_updated=retrieval_datetime,
218-
**{
306+
standards={
219307
locus: group_identical_alleles(raw_standards[locus])
220308
for locus in ("A", "B", "C")
221309
},

src/easyhla/utils.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1+
import hashlib
12
import logging
23
import re
34
from collections import defaultdict
45
from collections.abc import Iterable, Sequence
56
from datetime import datetime
67
from operator import attrgetter
7-
from typing import Final, Literal, Optional
8+
from typing import Final, Literal, Optional, Self
89

910
import numpy as np
1011
from Bio.SeqIO import SeqRecord
11-
from pydantic import BaseModel, computed_field
12+
from pydantic import BaseModel, computed_field, model_validator
1213

1314
# A lookup table of translations from ambiguous nucleotides to unambiguous
1415
# nucleotides.
@@ -490,9 +491,44 @@ def group_identical_alleles(
490491
return sorted(grouped_alleles, key=attrgetter("name"))
491492

492493

494+
def compute_stored_standard_checksum(
    tag: str,
    commit_hash: str,
    last_updated: datetime,
    alleles: dict[HLA_LOCUS, list[GroupedAllele]],
) -> str:
    """
    Compute a checksum for the stored data.

    The digest covers a three-line header (tag, commit hash, last-updated
    timestamp) followed by one comma-separated line per grouped allele, taken
    locus by locus in the order A, B, C.

    :return: the hexadecimal SHA-256 digest of that text.
    """
    header: str = f"{tag}\n{commit_hash}\n{last_updated}\n"
    allele_lines: list[str] = [
        f"{grouped.name},{grouped.exon2},{grouped.exon3},"
        f"{';'.join(grouped.alleles)}\n"
        for locus in ("A", "B", "C")
        for grouped in alleles[locus]
    ]
    payload: str = header + "".join(allele_lines)
    return hashlib.sha256(payload.encode()).hexdigest()
512+
513+
493514
class StoredHLAStandards(BaseModel):
    """
    The HLA standards as they are stored, with a checksum over the contents.

    When no checksum is supplied, one is computed from the other fields and
    stored on the model; when one is supplied, it must match the computed
    value or validation fails.
    """

    # Tag and commit hash identifying the release the standards came from.
    tag: str
    commit_hash: str
    # When the standards were retrieved.
    last_updated: datetime
    # Grouped alleles, keyed by locus.
    standards: dict[HLA_LOCUS, list[GroupedAllele]]
    # SHA-256 hex digest of the fields above; filled in when omitted.
    checksum: Optional[str] = None

    @model_validator(mode="after")
    def compute_compare_checksum(self) -> Self:
        """
        Fill in a missing checksum, or verify a supplied one.

        :raises ValueError: if a checksum was supplied and does not match the
            one computed from the model's contents.
        """
        checksum: str = compute_stored_standard_checksum(
            self.tag,
            self.commit_hash,
            self.last_updated,
            self.standards,
        )

        if self.checksum is None:
            self.checksum = checksum
        elif self.checksum != checksum:
            raise ValueError("Checksum mismatch")

        # An "after" model validator must hand the validated instance back.
        return self

0 commit comments

Comments
 (0)