33from datetime import datetime
44from io import TextIOBase
55from operator import attrgetter
6- from typing import Final , Optional
6+ from typing import Final , Optional , TypedDict
77
88import numpy as np
9+ import yaml
910
1011from .models import (
1112 HLACombinedStandard ,
2021from .utils import (
2122 BIN2NUC ,
2223 HLA_LOCUS ,
24+ GroupedAllele ,
25+ StoredHLAStandards ,
2326 count_strict_mismatches ,
2427 nuc2bin ,
2528)
2629
2730DATE_FORMAT = "%a %b %d %H:%M:%S %Z %Y"
2831
2932
33+ class ProcessedStoredStandards (TypedDict ):
34+ tag : str
35+ last_modified : datetime
36+ standards : dict [HLA_LOCUS , dict [str , HLAStandard ]]
37+
38+
3039class EasyHLA :
3140 # For HLA-B interpretations, these alleles are the ones we use to determine
3241 # how close a sequence is to "B*57:01".
@@ -40,61 +49,89 @@ class EasyHLA:
4049
4150 def __init__ (
4251 self ,
43- locus : HLA_LOCUS ,
44- hla_standards : Optional [dict [str , HLAStandard ]] = None ,
45- hla_frequencies : Optional [dict [HLAProteinPair , int ]] = None ,
46- last_modified : Optional [datetime ] = None ,
52+ hla_standards : Optional [ProcessedStoredStandards ] = None ,
53+ hla_frequencies : Optional [dict [HLA_LOCUS , dict [HLAProteinPair , int ]]] = None ,
4754 ):
4855 """
4956 Initialize an EasyHLA class.
5057
51- :param locus: HLA subtype that this object will be performing
52- interpretation against.
53- :type locus: "A", "B", or "C"
5458 :param logger: Python logger object, defaults to None
5559 :type logger: Optional[logging.Logger], optional
56- :raises ValueError: Raised if locus != "A"/"B"/"C"
5760 """
58- if locus not in ["A" , "B" , "C" ]:
59- raise ValueError ("Invalid HLA locus specified; must be A, B, or C" )
60- self .locus : HLA_LOCUS = locus
61+ if hla_standards is None :
62+ hla_standards = self .load_default_hla_standards ()
6163
62- self .hla_standards : dict [str , HLAStandard ]
63- if hla_standards is not None :
64- self . hla_standards = hla_standards
65- else :
66- self .hla_standards = self . load_default_hla_standards ()
64+ self .hla_standards : dict [HLA_LOCUS , dict [ str , HLAStandard ]] = hla_standards [
65+ "standards"
66+ ]
67+ self . last_modified : datetime = hla_standards [ "last_modified" ]
68+ self .tag : str = hla_standards [ "tag" ]
6769
68- self .hla_frequencies : dict [HLAProteinPair , int ]
70+ self .hla_frequencies : dict [HLA_LOCUS , dict [ HLAProteinPair , int ] ]
6971 if hla_frequencies is not None :
7072 self .hla_frequencies = hla_frequencies
7173 else :
7274 self .hla_frequencies = self .load_default_hla_frequencies ()
7375
74- self .last_modified : datetime
75- if last_modified is not None :
76- self .last_modified = last_modified
77- else :
78- self .last_modified = self .load_default_last_modified ()
76+ @classmethod
77+ def use_config (
78+ cls ,
79+ standards_path : Optional [str ],
80+ frequencies_path : Optional [str ] = None ,
81+ ) -> "EasyHLA" :
82+ """
83+ An alternate constructor that accepts file paths for the configuration.
84+ """
85+ processed_stds : Optional [ProcessedStoredStandards ] = None
86+ frequencies : Optional [dict [HLA_LOCUS , dict [HLAProteinPair , int ]]] = None
87+
88+ if standards_path is not None :
89+ with open (standards_path ) as f :
90+ processed_stds = cls .read_hla_standards (f )
91+
92+ if frequencies_path is not None :
93+ with open (frequencies_path ) as f :
94+ frequencies = cls .read_hla_frequencies (f )
95+
96+ return cls (processed_stds , frequencies )
7997
8098 @staticmethod
81- def read_hla_standards (standards_io : TextIOBase ) -> dict [str , HLAStandard ]:
99+ def read_hla_standards (
100+ standards_io : TextIOBase ,
101+ ) -> ProcessedStoredStandards :
82102 """
83103 Read HLA standards from a specified file-like object.
84104
85105 :return: Dictionary of known HLA standards keyed by their name
86106 :rtype: dict[str, HLAStandard]
87107 """
88- hla_stds : dict [str , HLAStandard ] = {}
89- for line in standards_io .readlines ():
90- line_array = line .strip ().split ("," )
91- allele : str = line_array [0 ]
92- hla_stds [allele ] = HLAStandard (
93- allele = line_array [0 ],
94- two = nuc2bin (line_array [1 ]),
95- three = nuc2bin (line_array [2 ]),
96- )
97- return hla_stds
108+ stored_stds : StoredHLAStandards = StoredHLAStandards .model_validate (
109+ yaml .safe_load (standards_io )
110+ )
111+
112+ hla_stds : dict [HLA_LOCUS , dict [str , HLAStandard ]] = {
113+ "A" : {},
114+ "B" : {},
115+ "C" : {},
116+ }
117+ stored_grouped_alleles : dict [HLA_LOCUS , list [GroupedAllele ]] = {
118+ "A" : stored_stds .A ,
119+ "B" : stored_stds .B ,
120+ "C" : stored_stds .C ,
121+ }
122+ for locus in ("A" , "B" , "C" ):
123+ for grouped_allele in stored_grouped_alleles [locus ]:
124+ hla_stds [locus ][grouped_allele .name ] = HLAStandard (
125+ allele = grouped_allele .name ,
126+ two = nuc2bin (grouped_allele .exon2 ),
127+ three = nuc2bin (grouped_allele .exon3 ),
128+ )
129+
130+ return {
131+ "tag" : stored_stds .tag ,
132+ "last_modified" : stored_stds .last_modified ,
133+ "standards" : hla_stds ,
134+ }
98135
99136 def load_default_hla_standards (self ) -> dict [str , HLAStandard ]:
100137 """
@@ -106,44 +143,50 @@ def load_default_hla_standards(self) -> dict[str, HLAStandard]:
106143 standards_filename : str = os .path .join (
107144 os .path .dirname (__file__ ),
108145 "default_data" ,
109- f"hla_ { self . locus . lower () } _std_reduced.csv " ,
146+ "hla_standards.yaml " ,
110147 )
111148 with open (standards_filename ) as standards_file :
112149 return self .read_hla_standards (standards_file )
113150
114151 @staticmethod
115152 def read_hla_frequencies (
116- locus : HLA_LOCUS ,
117153 frequencies_io : TextIOBase ,
118- ) -> dict [HLAProteinPair , int ]:
154+ ) -> dict [HLA_LOCUS , dict [ HLAProteinPair , int ] ]:
119155 """
120156 Load HLA frequencies from a specified file-like object.
121157
122- This takes two columns AAAA,BBBB out of 6 (...FFFF), and then uses a
158+ This takes each two columns AAAA,BBBB out of 6 (...FFFF), and then uses a
123159 subset of these two columns (AABB,CCDD) to use as the key, in this case
124160 "AA|BB,CC|DD", we then count the number of times this key appears in our
125161 columns.
126162
127- :return: Lookup table of HLA frequencies.
128- :rtype: dict[HLAProteinPair, int]
163+ :return: Lookup table of locus and HLA frequencies.
164+ :rtype: dict[HLA_LOCUS, dict[ HLAProteinPair, int] ]
129165 """
130- hla_freqs : dict [HLAProteinPair , int ] = {}
131- for line in frequencies_io .readlines ():
132- column_id = EasyHLA .COLUMN_IDS [locus ]
133- line_array = line .strip ().split ("," )[column_id : column_id + 2 ]
134-
135- protein_pair : HLAProteinPair = HLAProteinPair (
136- first_field_1 = line_array [0 ][0 :2 ],
137- first_field_2 = line_array [0 ][2 :4 ],
138- second_field_1 = line_array [1 ][0 :2 ],
139- second_field_2 = line_array [1 ][2 :4 ],
140- )
141- if hla_freqs .get (protein_pair , None ) is None :
142- hla_freqs [protein_pair ] = 0
143- hla_freqs [protein_pair ] += 1
166+ hla_freqs : dict [HLA_LOCUS , dict [HLAProteinPair , int ]] = {
167+ "A" : {},
168+ "B" : {},
169+ "C" : {},
170+ }
171+ for locus in ("A" , "B" , "C" ):
172+ for line in frequencies_io .readlines ():
173+ column_id = EasyHLA .COLUMN_IDS [locus ]
174+ line_array = line .strip ().split ("," )[column_id : column_id + 2 ]
175+
176+ protein_pair : HLAProteinPair = HLAProteinPair (
177+ first_field_1 = line_array [0 ][0 :2 ],
178+ first_field_2 = line_array [0 ][2 :4 ],
179+ second_field_1 = line_array [1 ][0 :2 ],
180+ second_field_2 = line_array [1 ][2 :4 ],
181+ )
182+ if hla_freqs [locus ].get (protein_pair , None ) is None :
183+ hla_freqs [locus ][protein_pair ] = 0
184+ hla_freqs [locus ][protein_pair ] += 1
144185 return hla_freqs
145186
146- def load_default_hla_frequencies (self ) -> dict [HLAProteinPair , int ]:
187+ def load_default_hla_frequencies (
188+ self ,
189+ ) -> dict [HLA_LOCUS , dict [HLAProteinPair , int ]]:
147190 """
148191 Load HLA frequencies from reference file.
149192
@@ -153,35 +196,18 @@ def load_default_hla_frequencies(self) -> dict[HLAProteinPair, int]:
153196 columns.
154197
155198 :return: Lookup table of HLA frequencies.
156- :rtype: dict[HLAProteinPair, int]
199+ :rtype: dict[HLA_LOCUS, dict[ HLAProteinPair, int] ]
157200 """
158- hla_freqs : dict [HLAProteinPair , int ]
201+ hla_freqs : dict [HLA_LOCUS , dict [ HLAProteinPair , int ] ]
159202 default_frequencies_filename : str = os .path .join (
160203 os .path .dirname (__file__ ),
161204 "default_data" ,
162205 "hla_frequencies.csv" ,
163206 )
164207 with open (default_frequencies_filename , "r" ) as f :
165- hla_freqs = self .read_hla_frequencies (self . locus , f )
208+ hla_freqs = self .read_hla_frequencies (f )
166209 return hla_freqs
167210
168- @staticmethod
169- def load_default_last_modified () -> datetime :
170- """
171- Load a datetime object describing when standard definitions were last updated.
172-
173- :return: Date and time representing when references were last updated.
174- :rtype: datetime
175- """
176- filename = os .path .join (
177- os .path .dirname (__file__ ),
178- "default_data" ,
179- "hla_nuc.fasta.mtime" ,
180- )
181- with open (filename , "r" , encoding = "utf-8" ) as f :
182- last_mod_date = "" .join (f .readlines ()).strip ()
183- return datetime .strptime (last_mod_date , DATE_FORMAT )
184-
185211 @staticmethod
186212 def get_matching_standards (
187213 seq : Sequence [int ],
@@ -320,10 +346,11 @@ def combine_standards(
320346
321347 return result
322348
349+ @staticmethod
323350 def get_mismatches (
324- self ,
325351 standard_bin : Sequence [int ],
326352 sequence_bin : Sequence [int ],
353+ locus : HLA_LOCUS ,
327354 ) -> list [HLAMismatch ]:
328355 """
329356 Report mismatched bases and their location versus a standard.
@@ -352,7 +379,7 @@ def get_mismatches(
352379 mislist : list [HLAMismatch ] = []
353380
354381 for index , correct_base_bin in correct_base_at_pos .items ():
355- if self . locus == "A" and index > 269 : # i.e. > 270 in 1-based indices
382+ if locus == "A" and index > 269 : # i.e. > 270 in 1-based indices
356383 # This is 241 + 1, where the 1 converts from 0-based to 1-based
357384 # indices.
358385 dex = index + 242
@@ -391,8 +418,11 @@ def interpret(
391418 :rtype: HLAInterpretation
392419 """
393420 seq : tuple [int , ...] = hla_sequence .sequence_for_interpretation
421+ locus : HLA_LOCUS = hla_sequence .locus
394422
395- matching_stds = self .get_matching_standards (seq , self .hla_standards .values ())
423+ matching_stds = self .get_matching_standards (
424+ seq , self .hla_standards [locus ].values ()
425+ )
396426 if len (matching_stds ) == 0 :
397427 raise EasyHLA .NoMatchingStandards ()
398428
@@ -405,7 +435,7 @@ def interpret(
405435 )
406436
407437 b5701_standards : Optional [list [HLAStandard ]] = None
408- if self . locus == "B" :
438+ if locus == "B" :
409439 b5701_standards = [
410440 self .hla_standards [allele ] for allele in self .B5701_ALLELES
411441 ]
0 commit comments