Skip to content

Commit 5818990

Browse files
author
David Rickett
committed
More tests
1 parent 81aeb1e commit 5818990

File tree

2 files changed

+151
-68
lines changed

2 files changed

+151
-68
lines changed

src/easyhla/easyhla.py

Lines changed: 97 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ def interpret(
579579
# DR 2023-02-24: To whomever made this comment, great shoutout!
580580
all_combos = self.combine_stds(matching_stds, seq, threshold)
581581

582-
self.get_mismatches(
582+
self.report_mismatches(
583583
letter=self.letter,
584584
all_combos=all_combos,
585585
seq=seq,
@@ -591,48 +591,11 @@ def interpret(
591591
best_matches: List[HLACombinedStandardResult] = min(all_combos.items())[1]
592592
mismatch_count: int = min(all_combos.items())[0]
593593

594-
mishash: Dict[int, List[int]] = {}
595-
596-
for cons in best_matches:
597-
_seq = [int(nuc) for nuc in cons.standard.split("-")]
598-
for i in range(len(_seq)):
599-
base = EasyHLA.BIN2NUC[seq[i]]
600-
if _seq[i] ^ seq[i] != 0:
601-
correct_base = EasyHLA.BIN2NUC[_seq[i]]
602-
if letter == "A" and i > 270:
603-
dex = i + 242
604-
else:
605-
dex = i + 1
606-
if not i in mishash:
607-
mishash[i] = []
608-
if not _seq[i] in mishash[i]:
609-
mishash[i].append(_seq[i])
610-
611-
mislist: List[str] = []
612-
613-
for m, mlist in mishash.items():
614-
if letter == "A" and m > 270:
615-
dex = m + 241
616-
else:
617-
dex = m + 1
618-
619-
base = EasyHLA.BIN2NUC[seq[m]]
620-
correct_bases = ""
621-
for correct_bin in mlist:
622-
if not correct_bases:
623-
correct_bases = EasyHLA.BIN2NUC[correct_bin]
624-
else:
625-
correct_bases += "/" + EasyHLA.BIN2NUC[correct_bin]
626-
mislist.append(f"{dex}:{base}->{correct_bases}")
627-
628-
# mislist = mislist.sort_by{|b| b.split(":")[0].to_i}
629-
# mismatches = mislist.join(";")
630-
mislist.sort(key=lambda item: item.split(":")[0])
631-
mismatches = ";".join(mislist)
632-
633594
# Clean the alleles
634595

635-
fcnt = EasyHLA.COLUMN_IDS[letter]
596+
mismatches = self.get_mismatches(
597+
self.letter, best_matches=best_matches, seq=seq
598+
)
636599

637600
alleles = self.get_all_alleles(best_matches=best_matches)
638601
ambig = alleles.is_ambiguous()
@@ -690,6 +653,9 @@ def run(
690653
threshold: Optional[int] = None,
691654
to_stdout: Optional[bool] = None,
692655
):
656+
if threshold and threshold < 0:
657+
raise RuntimeError("Threshold must be >=0 or None!")
658+
693659
rows = []
694660
npats = 0
695661
nseqs = 0
@@ -760,14 +726,15 @@ def filter_reportable_alleles(
760726
self, letter: HLA_TYPES, alleles: Alleles
761727
) -> List[Tuple[str, str]]:
762728
"""
763-
_summary_
729+
In case we have an ambiguous set of alleles, remove ambiguous alleles
730+
using HLA Freq standards.
764731
765-
:param letter: _description_
732+
:param letter: ...
766733
:type letter: HLA_TYPES
767-
:param best_matches: _description_
768-
:type best_matches: List[HLACombinedStandardResult]
769-
:return: _description_
770-
:rtype: Tuple[bool, List[Tuple[str,str]]]
734+
:param alleles: ...
735+
:type alleles: Alleles
736+
:return: List of alleles filtered by HLA frequency.
737+
:rtype: List[Tuple[str,str]]
771738
"""
772739

773740
collection_ambig = alleles.get_ambiguous_collection()
@@ -776,8 +743,6 @@ def filter_reportable_alleles(
776743
if freq.startswith(k):
777744
collection_ambig[k] = self.hla_freqs.get(freq, 0)
778745

779-
# TODO: Implement like the following commented ruby
780-
# Easier if we made things a model.
781746
def sort_allele(item: Tuple[str, int]):
782747
"""
783748
Produce a tuple that the sort function will use to determine the maximum allele.
@@ -823,15 +788,85 @@ def sort_allele(item: Tuple[str, int]):
823788
return _alleles
824789

825790
def get_mismatches(
    self,
    letter: HLA_TYPES,
    best_matches: List[HLACombinedStandardResult],
    seq: np.ndarray,
) -> str:
    """
    Report mismatched bases and their location versus a standard reference.

    Each mismatch is rendered as "$LOC:$SEQ_BASE->$STANDARD_BASE". If several
    standards disagree with the sequence at the same location using different
    bases, the alternatives are joined with `/`. Multiple mismatches are
    delimited with `;` and ordered by ascending numeric location.

    :param letter: HLA type; for "A", indices greater than 270 are reported
        with a different location offset than the rest of the sequence.
    :type letter: HLA_TYPES
    :param best_matches: List of the "best matched" standards to the sequence.
    :type best_matches: List[HLACombinedStandardResult]
    :param seq: The sequence being interpreted, as an array of integer-encoded
        bases (each value is looked up in ``EasyHLA.BIN2NUC``).
    :type seq: np.ndarray
    :return: A `;`-joined string of mismatch descriptions; empty if there are
        no mismatches.
    :rtype: str
    """
    correct_bases_at_pos: Dict[int, List[int]] = {}

    for hla_csr in best_matches:
        _seq = np.array([int(nuc) for nuc in hla_csr.standard.split("-")])
        # A non-zero XOR marks a position where the standard differs from seq.
        for idx in np.flatnonzero(_seq ^ seq):
            bases = correct_bases_at_pos.setdefault(idx, [])
            if _seq[idx] not in bases:
                bases.append(_seq[idx])

    mislist: List[str] = []

    # Walk the positions in ascending numeric order. The reported location is
    # a monotonic function of the index, so this yields a numerically sorted
    # result (sorting the rendered strings would place "100:..." before
    # "20:...").
    for index in sorted(correct_bases_at_pos):
        # NOTE(review): the +241 offset for letter "A" past index 270 is kept
        # exactly as before; presumably it accounts for a gap between the
        # sequenced regions -- confirm against the reference coordinates.
        if letter == "A" and index > 270:
            dex = index + 241
        else:
            dex = index + 1

        base = EasyHLA.BIN2NUC[seq[index]]
        _correct_bases = "/".join(
            EasyHLA.BIN2NUC[correct_bin]
            for correct_bin in correct_bases_at_pos[index]
        )
        mislist.append(f"{dex}:{base}->{_correct_bases}")

    return ";".join(mislist)
840+
841+
def report_mismatches(
826842
self,
827843
letter: HLA_TYPES,
828844
all_combos: Dict[int, List[HLACombinedStandardResult]],
829845
seq: np.ndarray,
830-
threshold: Optional[int],
846+
threshold: Optional[int] = None,
831847
sequence_components: Optional[Exon] = None,
832848
to_stdout: Optional[bool] = None,
833849
) -> None:
850+
"""
851+
Report mismatches to log/stdout (if applicable).
852+
853+
:param letter: ...
854+
:type letter: HLA_TYPES
855+
:param all_combos: All possible combos
856+
:type all_combos: Dict[int, List[HLACombinedStandardResult]]
857+
:param seq: Sequence currently being interpreted.
858+
:type seq: np.ndarray
859+
:param threshold: Maximum allowed mismatches in a sequence compared to a standard, must be non-negative or None, defaults to None
860+
:type threshold: Optional[int], optional
861+
:param sequence_components: Components of a sequence, ex: Exon2, Intron, Exon3, defaults to None
862+
:type sequence_components: Optional[Exon], optional
863+
:param to_stdout: Print to STDOUT if true, defaults to None
864+
:type to_stdout: Optional[bool], optional
865+
:raises RuntimeError: Raised if threshold is < 0
866+
"""
834867
if threshold:
868+
if threshold < 0:
869+
raise RuntimeError("Threshold must be >=0 or None!")
835870
for i, combos in sorted(all_combos.items()):
836871
if i > threshold:
837872
if i == 0:
@@ -843,21 +878,15 @@ def get_mismatches(
843878
to_stdout=to_stdout,
844879
)
845880
break
846-
for cons in combos:
847-
misstrings = []
848-
_seq = np.array([int(nuc) for nuc in cons.standard.split("-")])
849-
# TODO: replace with https://stackoverflow.com/questions/16094563/numpy-get-index-where-value-is-true
850-
for n in np.flatnonzero(_seq ^ seq):
851-
base = EasyHLA.BIN2NUC[seq[n]]
852-
correct_base = EasyHLA.BIN2NUC[_seq[n]]
853-
if letter == "A" and n > 270:
854-
dex = n + 242
855-
else:
856-
dex = n + 1
857-
misstrings.append(f"{dex}:{base}->{correct_base}")
858-
self.print(
859-
";".join(misstrings) + ","
860-
f"{sequence_components},{sequence_components},{sequence_components}",
861-
log_level=logging.INFO,
862-
to_stdout=to_stdout,
863-
)
881+
# We can reuse get_mismatches here, as instead of the "best" match, we have "a match"
882+
mismatches = self.get_mismatches(
883+
letter=letter, best_matches=combos, seq=seq
884+
)
885+
_seq_str = ""
886+
if sequence_components:
887+
_seq_str = f"{sequence_components.two},{sequence_components.intron},{sequence_components.three}"
888+
self.print(
889+
f"{mismatches},{_seq_str}",
890+
log_level=logging.INFO,
891+
to_stdout=to_stdout,
892+
)

tests/easyhla_test.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,60 @@ def test_get_all_alleles(
239239
assert result_alleles.is_ambiguous() == exp_ambig
240240
assert result_alleles.alleles == exp_alleles
241241

242+
@pytest.mark.parametrize(
    "alleles, exp_result",
    [
        (
            [
                ("A*11:01", "A*26:01"),
                ("A*11:01", "A*26:01"),
                ("A*11:19", "A*26:13"),
            ],
            [
                ("A*11:01", "A*26:01"),
                ("A*11:01", "A*26:01"),
                ("A*11:19", "A*26:13"),
            ],
        ),
        (
            [
                ("A*11:01", "A*26:01"),
                ("A*11:40", "A*23:01"),
            ],
            [("A*11:01", "A*26:01")],
        ),
        (
            [
                ("A*11:01", "A*12:01"),
                ("A*11:01", "A*12:01"),
                ("A*11:40", "A*13:01"),
            ],
            [("A*11:40", "A*13:01")],
        ),
        (
            [
                ("A*11:01", "A*12:01"),
                ("A*13:01", "A*12:44"),
                ("A*13:40", "A*12:01"),
            ],
            [("A*13:01", "A*12:44"), ("A*13:40", "A*12:01")],
        ),
    ],
)
def test_filter_reportable_alleles(
    self,
    easyhla: EasyHLA,
    alleles: List[Tuple[str, str]],
    exp_result: List[Tuple[str, str]],
):
    """
    Check that filter_reportable_alleles returns the expected allele pairs.

    :param easyhla: EasyHLA fixture under test.
    :type easyhla: EasyHLA
    :param alleles: Candidate allele pairs wrapped in an Alleles object
        before being passed to the method.
    :type alleles: List[Tuple[str, str]]
    :param exp_result: Allele pairs expected back from the filter.
    :type exp_result: List[Tuple[str, str]]
    """
    candidate_alleles = Alleles(alleles=alleles)
    filtered = easyhla.filter_reportable_alleles(
        letter=easyhla.letter, alleles=candidate_alleles
    )

    # Printed so the filtered pairs appear in pytest's captured output.
    print(filtered)

    assert filtered == exp_result
295+
242296

243297
@pytest.mark.parametrize("easyhla", ["B"], indirect=True)
244298
class TestEasyHLADiscreteHLATypeB:

0 commit comments

Comments
 (0)