Skip to content

Commit 204d027

Browse files
author
Richard Liang
committed
Changed combined_standards_helper into a generator; added tests.
1 parent eea16fa commit 204d027

File tree

2 files changed

+502
-66
lines changed

2 files changed

+502
-66
lines changed

src/easyhla/easyhla.py

Lines changed: 58 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import csv
22
import os
3-
from collections.abc import Iterable, Sequence
3+
from collections.abc import Generator, Iterable, Sequence
44
from datetime import datetime
55
from io import TextIOBase
66
from operator import attrgetter
@@ -232,58 +232,67 @@ def get_matching_standards(
232232
return matching_stds
233233

234234
@staticmethod
235-
def combine_standards_helper(
235+
def combine_standards_stepper(
236236
matching_stds: Sequence[HLAStandardMatch],
237237
seq: Sequence[int],
238238
mismatch_threshold: int = 0,
239-
) -> dict[tuple[int, ...], tuple[int, list[tuple[str, str]]]]:
239+
) -> Generator[tuple[tuple[int, ...], int, tuple[str, str]], None, None]:
240240
"""
241-
Helper to identify "good" combined standards for the specified sequence.
242-
243-
Returns a mapping:
244-
binary sequence tuple -|-> (mismatch count, allele pair list)
245-
246-
This mapping will contain "good" combined standards. It will always
247-
contain the best-matching combined standard(s). If mismatch_threshold
248-
is 0, then we only care about the best match; if mismatch_threshold is a
249-
positive integer, it will also contain any combined standards which have
250-
fewer mismatches than the threshold.
251-
252-
The result may also contain other combined standards, which will be
253-
winnowed out by the calling function.
241+
Identifies "good" combined standards for the specified sequence.
242+
243+
On each iteration, it continues checking combined standards until it
244+
finds a "match", and yields a tuple containing the details of that
245+
match:
246+
- the combined standard, as a tuple of integers 0-15;
247+
- the number of mismatches identified; and
248+
- the allele pair (i.e. names of the two alleles in the combination).
249+
250+
A "match" is defined by the number of mismatches between the combined
251+
standard and the sequence:
252+
- this is the best-matching combined standard found so far (may
253+
be above our mismatch threshold) or as good as the best-matching one
254+
found so far; or
255+
- this is below our mismatch threshold.
256+
If the mismatch threshold is 0, then we will only ever get the former.
254257
"""
255-
combos: dict[tuple[int, ...], tuple[int, list[tuple[str, str]]]] = {}
258+
# Keep track of matches we've already found:
259+
combos: dict[tuple[int, ...], int] = {}
256260

257261
current_rejection_threshold: int = float("inf")
258262
for std_ai, std_a in enumerate(matching_stds):
259263
if std_a.mismatch > current_rejection_threshold:
260264
continue
261-
for std_bi, std_b in enumerate(matching_stds):
262-
if std_ai < std_bi:
263-
break
265+
for std_b in matching_stds[: (std_ai + 1)]:
264266
if std_b.mismatch > current_rejection_threshold:
265267
continue
266268

267269
# "Mush" the two standards together to produce something
268270
# that looks like what you get when you sequence HLA.
269271
std_bin = np.array(std_b.sequence) | np.array(std_a.sequence)
270-
seq_mask = np.full_like(std_bin, fill_value=15)
271-
# Note that seq is implicitly cast to a NumPy array:
272-
mismatches: int = np.count_nonzero((std_bin ^ seq) & seq_mask != 0)
273-
274-
if mismatches > current_rejection_threshold:
275-
continue
272+
allele_pair: tuple[str, str] = tuple(
273+
sorted((std_a.allele, std_b.allele))
274+
)
276275

277276
# There could be more than one combined standard with the
278-
# same sequence, so keep track of all the possible combinations.
277+
# same sequence, so check if this one's already been found.
279278
combined_std_bin: tuple[int, ...] = tuple(int(s) for s in std_bin)
280-
if combined_std_bin not in combos:
281-
combos[combined_std_bin] = (mismatches, [])
282-
combos[combined_std_bin][1].append(sorted((std_a.allele, std_b.allele)))
283279

284-
if mismatches < current_rejection_threshold:
280+
mismatches: int = -1
281+
if combined_std_bin in combos:
282+
mismatches = combos[combined_std_bin]
283+
284+
else:
285+
seq_mask = np.full_like(std_bin, fill_value=15)
286+
# Note that seq is implicitly cast to a NumPy array:
287+
mismatches: int = np.count_nonzero((std_bin ^ seq) & seq_mask != 0)
288+
combos[combined_std_bin] = mismatches # cache this value
289+
290+
if mismatches > current_rejection_threshold:
291+
continue
292+
elif mismatches < current_rejection_threshold:
285293
current_rejection_threshold = max(mismatches, mismatch_threshold)
286-
return combos
294+
295+
yield (combined_std_bin, mismatches, allele_pair)
287296

288297
@staticmethod
289298
def combine_standards(
@@ -307,10 +316,10 @@ def combine_standards(
307316
PRECONDITION: matching_stds should contain no duplicates.
308317
309318
Returns a dictionary mapping HLACombinedStandards to their mismatch
310-
counts. If mismatch_threshold is None, then the result contains only
311-
the best-matching combined standard(s); otherwise, the result contains
312-
all combined standards with mismatch counts up to and including the
313-
threshold. All of the HLACombinedStandards have their
319+
counts. If mismatch_threshold is None or 0, then the result contains
320+
only the best-matching combined standard(s); otherwise, the result
321+
contains all combined standards with mismatch counts up to and including
322+
the threshold. All of the HLACombinedStandards have their
314323
`possible_allele_pairs` value sorted.
315324
"""
316325
if mismatch_threshold is None:
@@ -319,21 +328,25 @@ def combine_standards(
319328
# the current best match.
320329
mismatch_threshold = 0
321330

322-
combos: dict[tuple[int, ...], tuple[int, list[tuple[str, str]]]] = (
323-
EasyHLA.combine_standards_helper(
324-
matching_stds,
325-
seq,
326-
mismatch_threshold,
327-
)
328-
)
331+
combos: dict[tuple[int, ...], tuple[int, list[tuple[str, str]]]] = {}
332+
333+
fewest_mismatches: int = float("inf")
334+
for (
335+
combined_std_bin,
336+
mismatches,
337+
allele_pair,
338+
) in EasyHLA.combine_standards_stepper(matching_stds, seq, mismatch_threshold):
339+
if combined_std_bin not in combos:
340+
combos[combined_std_bin] = (mismatches, [])
341+
combos[combined_std_bin][1].append(allele_pair)
342+
if mismatches < fewest_mismatches:
343+
fewest_mismatches = mismatches
329344

330345
# Winnow out any extraneous combined standards that don't match our
331346
# criteria.
332347
result: dict[HLACombinedStandard, int] = {}
333348

334-
fewest_mismatches: int = min([x[0] for x in combos.values()])
335349
cutoff: int = max(fewest_mismatches, mismatch_threshold)
336-
337350
for combined_std_bin, mismatch_count_and_pair_list in combos.items():
338351
mismatch_count: int
339352
pair_list: list[tuple[str, str]]

0 commit comments

Comments
 (0)