11import csv
22import os
3- from collections .abc import Iterable , Sequence
3+ from collections .abc import Generator , Iterable , Sequence
44from datetime import datetime
55from io import TextIOBase
66from operator import attrgetter
@@ -232,58 +232,67 @@ def get_matching_standards(
232232 return matching_stds
233233
@staticmethod
def combine_standards_stepper(
    matching_stds: Sequence["HLAStandardMatch"],
    seq: Sequence[int],
    mismatch_threshold: int = 0,
) -> Generator[tuple[tuple[int, ...], int, tuple[str, str]], None, None]:
    """
    Identifies "good" combined standards for the specified sequence.

    Every pair of standards (including each standard paired with itself)
    is combined by bitwise-OR of their sequences and compared against
    `seq`. On each iteration, the generator keeps checking combined
    standards until it finds a "match", and yields a tuple containing the
    details of that match:
    - the combined standard, as a tuple of integers 0-15;
    - the number of mismatches identified; and
    - the allele pair (i.e. the names of the two alleles in the
      combination, sorted).

    A "match" is defined by the number of mismatches between the combined
    standard and the sequence:
    - this is the best-matching combined standard found so far (it may
      be above our mismatch threshold), or as good as the best-matching
      one found so far; or
    - this is at or below our mismatch threshold.
    If the mismatch threshold is 0, then we will only ever get the former.

    Because the best match is only known once everything has been seen,
    earlier yields may turn out to be worse than later ones; the caller
    is expected to winnow those out.

    The same combined standard can be yielded several times, once per
    allele pair that produces it.

    :param matching_stds: candidate standards; each must expose
        `sequence`, `allele` and `mismatch` attributes.
    :param seq: the sequence being matched, one integer per position.
    :param mismatch_threshold: combined standards with at most this many
        mismatches are always yielded, even if a better match is known.
    """
    # Mismatch counts we've already computed, keyed by combined standard,
    # so that a sequence produced by several allele pairs is only scored
    # once.
    combos: dict[tuple[int, ...], int] = {}

    # seq never changes, so convert it to a NumPy array a single time
    # rather than relying on an implicit cast for every pair.
    seq_arr = np.asarray(seq)

    # Anything with more mismatches than this is rejected. It only ever
    # tightens: it starts unbounded and drops to
    # max(best mismatch count so far, mismatch_threshold).
    current_rejection_threshold: float = float("inf")
    for std_ai, std_a in enumerate(matching_stds):
        if std_a.mismatch > current_rejection_threshold:
            continue
        # Invariant across the inner loop, so build it once per std_a.
        std_a_bin = np.array(std_a.sequence)

        # Only pair std_a with itself and the standards before it, so
        # each unordered pair is visited exactly once.
        for std_b in matching_stds[: (std_ai + 1)]:
            if std_b.mismatch > current_rejection_threshold:
                continue

            # "Mush" the two standards together to produce something
            # that looks like what you get when you sequence HLA.
            std_bin = np.array(std_b.sequence) | std_a_bin
            first_allele, second_allele = sorted((std_a.allele, std_b.allele))
            allele_pair: tuple[str, str] = (first_allele, second_allele)

            # There could be more than one combined standard with the
            # same sequence, so check if this one's already been scored.
            combined_std_bin: tuple[int, ...] = tuple(int(s) for s in std_bin)

            mismatches = combos.get(combined_std_bin)
            if mismatches is None:
                # A position mismatches when the low four bits of the
                # combined standard and the sequence differ.
                mismatches = int(
                    np.count_nonzero(((std_bin ^ seq_arr) & 0b1111) != 0)
                )
                combos[combined_std_bin] = mismatches  # cache this value

            if mismatches > current_rejection_threshold:
                continue
            if mismatches < current_rejection_threshold:
                current_rejection_threshold = max(mismatches, mismatch_threshold)

            yield (combined_std_bin, mismatches, allele_pair)
287296
288297 @staticmethod
289298 def combine_standards (
@@ -307,10 +316,10 @@ def combine_standards(
307316 PRECONDITION: matching_stds should contain no duplicates.
308317
309318 Returns a dictionary mapping HLACombinedStandards to their mismatch
310- counts. If mismatch_threshold is None, then the result contains only
311- the best-matching combined standard(s); otherwise, the result contains
312- all combined standards with mismatch counts up to and including the
313- threshold. All of the HLACombinedStandards have their
319+ counts. If mismatch_threshold is None or 0 , then the result contains
320+ only the best-matching combined standard(s); otherwise, the result
321+ contains all combined standards with mismatch counts up to and including
322+ the threshold. All of the HLACombinedStandards have their
314323 `possible_allele_pairs` value sorted.
315324 """
316325 if mismatch_threshold is None :
@@ -319,21 +328,25 @@ def combine_standards(
319328 # the current best match.
320329 mismatch_threshold = 0
321330
322- combos : dict [tuple [int , ...], tuple [int , list [tuple [str , str ]]]] = (
323- EasyHLA .combine_standards_helper (
324- matching_stds ,
325- seq ,
326- mismatch_threshold ,
327- )
328- )
331+ combos : dict [tuple [int , ...], tuple [int , list [tuple [str , str ]]]] = {}
332+
333+ fewest_mismatches : int = float ("inf" )
334+ for (
335+ combined_std_bin ,
336+ mismatches ,
337+ allele_pair ,
338+ ) in EasyHLA .combine_standards_stepper (matching_stds , seq , mismatch_threshold ):
339+ if combined_std_bin not in combos :
340+ combos [combined_std_bin ] = (mismatches , [])
341+ combos [combined_std_bin ][1 ].append (allele_pair )
342+ if mismatches < fewest_mismatches :
343+ fewest_mismatches = mismatches
329344
330345 # Winnow out any extraneous combined standards that don't match our
331346 # criteria.
332347 result : dict [HLACombinedStandard , int ] = {}
333348
334- fewest_mismatches : int = min ([x [0 ] for x in combos .values ()])
335349 cutoff : int = max (fewest_mismatches , mismatch_threshold )
336-
337350 for combined_std_bin , mismatch_count_and_pair_list in combos .items ():
338351 mismatch_count : int
339352 pair_list : list [tuple [str , str ]]
0 commit comments