Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/hla_algorithm/hla_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
BIN2NUC,
HLA_LOCUS,
StoredHLAStandards,
allele_coordinates_sort_key,
count_strict_mismatches,
nuc2bin,
sort_allele_pairs,
)

DATE_FORMAT = "%a %b %d %H:%M:%S %Z %Y"
Expand Down Expand Up @@ -277,7 +279,13 @@ def combine_standards_stepper(
# that looks like what you get when you sequence HLA.
std_bin = np.array(std_b.sequence) | np.array(std_a.sequence)
allele_pair: tuple[str, str] = cast(
tuple[str, str], tuple(sorted((std_a.allele, std_b.allele)))
tuple[str, str],
tuple(
sorted(
(std_a.allele, std_b.allele),
key=allele_coordinates_sort_key,
)
),
)

# There could be more than one combined standard with the
Expand Down Expand Up @@ -363,7 +371,7 @@ def combine_standards(
if mismatch_count <= cutoff:
combined_std: HLACombinedStandard = HLACombinedStandard(
standard_bin=combined_std_bin,
possible_allele_pairs=tuple(sorted(pair_list)),
possible_allele_pairs=tuple(sort_allele_pairs(pair_list)),
)
result[combined_std] = mismatch_count

Expand Down
3 changes: 2 additions & 1 deletion src/hla_algorithm/interpret_from_json_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
check_bases,
check_length,
nuc2bin,
sort_allele_pairs,
)


Expand Down Expand Up @@ -143,7 +144,7 @@ def build_from_interpretation(

return HLAResult(
seqs=seqs,
alleles_all=[f"{x[0]} - {x[1]}" for x in aps.sort_pairs()],
alleles_all=[f"{x[0]} - {x[1]}" for x in sort_allele_pairs(aps.allele_pairs)],
alleles_clean=alleles_clean,
alleles_for_mismatches=f"{rep_ap[0]} - {rep_ap[1]}",
mismatches=[str(x) for x in match_details.mismatches],
Expand Down
201 changes: 150 additions & 51 deletions src/hla_algorithm/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from collections.abc import Iterable
from dataclasses import dataclass, field
from operator import itemgetter
from typing import Final, Optional

Expand All @@ -14,6 +15,7 @@
bin2nuc,
count_forgiving_mismatches,
nuc2bin,
sort_allele_pairs,
)


Expand Down Expand Up @@ -212,16 +214,17 @@ def from_frequency_entry(
)


GeneCoord = tuple[str, ...]


class AllelePairs(BaseModel):
allele_pairs: list[tuple[str, str]]

def is_homozygous(self) -> bool:
"""
Determine the homozygousness of alleles.

Homozygousity meaning a pair is matching on both sides, ex:
`Cw*0722 - Cw*0722`
Determine the homozygousness of these allele pairs.

A pair is homozygous if both elements match, e.g. C*07:22 - C*07:22.
If *any* pair of alleles matches, then we declare the whole set to be
homozygous.

Expand Down Expand Up @@ -287,7 +290,7 @@ def get_protein_pairs(self) -> set[HLAProteinPair]:
for e in self.get_paired_gene_coordinates(True)
}

def get_unambiguous_allele_pairs(
def _get_unambiguous_allele_pairs(
self,
frequencies: dict[HLAProteinPair, int],
) -> list[tuple[str, str]]:
Expand Down Expand Up @@ -333,6 +336,115 @@ def get_unambiguous_allele_pairs(

return reduced_set

@dataclass
class CleanPrefixIntermediateResult:
common_prefix: GeneCoord = ()
second_prefix: Optional[GeneCoord] = None
remaining_prefixes: list[GeneCoord] = field(default_factory=list)

@staticmethod
def _identify_clean_prefix_in_pairs(
unambiguous_pairs: list[tuple[GeneCoord, GeneCoord]],
) -> CleanPrefixIntermediateResult:
"""
Identify a "clean" gene coordinate "prefix" in the given unambiguous pairs.

This prefix can occur in either element of a given pair. For example,
if the pairs are
- B*01:01:01 - B*01:01:02:110G
- B*01:01:02:99 - B*01:22
then the longest common prefix is ("B*01", "01", "02").

If we happen to find an "exact" allele that occurs in all the pairs, then
that's a "clean" allele and we report it back, even if it's shorter than
the longest common prefix.

A precondition is that the input must be an unambiguous collection of
pairs. The algorithm may not return cogent values if not.

Return a tuple containing this clean prefix, a second clean prefix
if one is found, and a list containing all the remaining
alleles in the pairs if such a second prefix is *not* found (if a
second prefix is found, this list will be empty).
"""
if len(unambiguous_pairs) == 0:
return AllelePairs.CleanPrefixIntermediateResult()

common_prefix: GeneCoord = ()
second_prefix: Optional[GeneCoord] = None
remaining_prefixes: list[GeneCoord] = []

max_length: int = max(
[max(len(pair[0]), len(pair[1])) for pair in unambiguous_pairs]
)
for i in range(max_length, 0, -1):
# Note that this may not "cut down" some pairs if they're shorter
# than max_length.
curr_pairs = [(pair[0][0:i], pair[1][0:i]) for pair in unambiguous_pairs]

# On the first iteration, we might "accidentally" find exact matches
# which are shorter (or equal to) than max_length; if so, great
# ¯\_(ツ)_/¯
common_prefixes: set[GeneCoord] = set(curr_pairs[0])
for curr_pair in curr_pairs[1:]:
common_prefixes = common_prefixes & set(curr_pair)

if len(common_prefixes) == 0:
continue

# Having reached here, we know that we found at least one common
# prefix.
common_prefix = common_prefixes.pop()
if len(common_prefixes) == 1:
# The other prefix is good too.
second_prefix = common_prefixes.pop()

else:
# Having reached here, we know that we found exactly one common
# prefix, and will look for the best prefix in what remains.
for curr_pair in curr_pairs:
curr_unique_prefixes: set[GeneCoord] = set(curr_pair)
if len(curr_unique_prefixes) != 1:
# There were two distinct alleles in this pair, one of which
# was longest_prefix, so we retain the other one.
# (If there had only been one, then it must have been a
# homozygous pair "[longest_prefix] - [longest_prefix]",
# so we want to retain one "copy" for the next stage.)
curr_unique_prefixes.remove(common_prefix)

remaining_prefixes.append(curr_unique_prefixes.pop())
if i > 1:
# This is unnecessary but it gets us 100% test coverage
# ¯\_(ツ)_/¯
break

return AllelePairs.CleanPrefixIntermediateResult(
common_prefix, second_prefix, remaining_prefixes
)

@staticmethod
def _identify_longest_prefix(allele_prefixes: list[GeneCoord]) -> GeneCoord:
"""
Identify the longest gene coordinate "prefix" in the given allele prefixes.

Precondition: that the input must all share at least the same first
coordinate. The algorithm may not return cogent values if not.

Precondition: the specified allele prefixes do not all perfectly match,
so we lose nothing by trimming one coordinate off the end of all of
them.
"""
longest_prefix: GeneCoord = ()
if len(allele_prefixes) > 0:
max_length: int = max([len(allele) for allele in allele_prefixes])
for i in range(max_length - 1, 0, -1):
curr_prefixes: set[GeneCoord] = {allele[0:i] for allele in allele_prefixes}
if len(curr_prefixes) == 1:
longest_prefix = curr_prefixes.pop()
if i > 1:
break
return longest_prefix

def best_common_allele_pair_str(
self,
frequencies: dict[HLAProteinPair, int],
Expand All @@ -342,16 +454,14 @@ def best_common_allele_pair_str(

The allele pairs are filtered to an unambiguous set (using the specified
frequencies to determine which ones to retain). Then, the "best common
coordinates" for all the remaining allele allele pairs are used to build
coordinates" for all the remaining allele pairs are used to build
a string representation of the set.

Example: if, after filtering, the allele pairs remaining are:
```
[ [A*11:02:01, A*12:01],
[A*11:02:02, A*12:02],
[A*11:02:03, A*12:03] ]
```
we expect to get `A*11:02 - A*12`.
- A*11:02:01 - A*12:01
- A*11:02:02 - A*12:02
- A*11:02:03 - A*12:03
we expect to get "A*11:02 - A*12".

:return: A string representing the best common allele pair, and the
unambiguous set this string represents.
Expand All @@ -360,47 +470,36 @@ def best_common_allele_pair_str(
# Starting with an unambiguous set assures that we will definitely get
# a result.
unambiguous_aps: AllelePairs = AllelePairs(
allele_pairs=self.get_unambiguous_allele_pairs(frequencies)
allele_pairs=self._get_unambiguous_allele_pairs(frequencies)
)
paired_gene_coordinates: list[tuple[list[str], list[str]]] = (
unambiguous_aps.get_paired_gene_coordinates()
unambiguous_aps.get_paired_gene_coordinates(digits_only=False)
)

clean_allele: list[str] = []
for n in [0, 1]:
for i in [4, 3, 2, 1]:
all_leading_coordinates = {
":".join(a[n][0:i]) for a in paired_gene_coordinates
}
if len(all_leading_coordinates) == 1:
best_common_coords = all_leading_coordinates.pop()
clean_allele.append(
re.sub(
r"[A-Z]$",
"",
best_common_coords,
)
)
if i > 1:
# This branch is unnecessary but it gets us 100% code
# coverage ¯\_(ツ)_/¯
break
# Look for the longest common prefix present in all pairs.
curr_pairs: list[tuple[GeneCoord, GeneCoord]] = [
(tuple(pair[0]), tuple(pair[1])) for pair in paired_gene_coordinates
]

clean_allele_pair_str: str = " - ".join(clean_allele)
return (clean_allele_pair_str, set(unambiguous_aps.allele_pairs))
intermediate_data: AllelePairs.CleanPrefixIntermediateResult = (
self._identify_clean_prefix_in_pairs(curr_pairs)
)

def sort_pairs(self) -> list[tuple[str, str]]:
"""
Sort the pairs according to "coordinate order".
second_prefix: GeneCoord = (
intermediate_data.second_prefix or self._identify_longest_prefix(
intermediate_data.remaining_prefixes
)
)

If there's a tie, a last letter is used to attempt to break the tie.
"""
return sorted(
self.allele_pairs,
key=lambda pair: (
allele_coordinates_sort_key(pair[0]),
allele_coordinates_sort_key(pair[1]),
),
# Turn the two prefixes we found into strings and strip any trailing
# letters.
clean_allele_pair: list[str] = [
re.sub(r"[A-Z]$", "", ":".join(allele))
for allele in (intermediate_data.common_prefix, second_prefix)
]
return (
" - ".join(sorted(clean_allele_pair, key=allele_coordinates_sort_key)),
set(unambiguous_aps.allele_pairs),
)

def stringify(self, sorted=True, max_length: int = 3900) -> str:
Expand All @@ -415,7 +514,7 @@ def stringify(self, sorted=True, max_length: int = 3900) -> str:
"""
allele_pairs: list[tuple[str, str]] = self.allele_pairs
if sorted:
allele_pairs = self.sort_pairs()
allele_pairs = sort_allele_pairs(self.allele_pairs)
summary_str: str = ";".join([f"{_a[0]} - {_a[1]}" for _a in allele_pairs])
if len(summary_str) > max_length:
summary_str = re.sub(
Expand All @@ -426,7 +525,7 @@ def stringify(self, sorted=True, max_length: int = 3900) -> str:
return summary_str

@classmethod
def get_allele_pairs(
def combine_allele_pairs(
cls,
combined_standards: Iterable[HLACombinedStandard],
) -> "AllelePairs":
Expand All @@ -441,7 +540,7 @@ def get_allele_pairs(
all_allele_pairs: list[tuple[str, str]] = []
for combined_std in combined_standards:
all_allele_pairs.extend(combined_std.possible_allele_pairs)
all_allele_pairs.sort()
all_allele_pairs = sort_allele_pairs(all_allele_pairs)
return cls(allele_pairs=all_allele_pairs)

def contains_allele(self, allele_name: str) -> bool:
Expand Down Expand Up @@ -474,7 +573,7 @@ def best_matches(self) -> set[HLACombinedStandard]:
}

def best_matching_allele_pairs(self) -> AllelePairs:
return AllelePairs.get_allele_pairs(self.best_matches())
return AllelePairs.combine_allele_pairs(self.best_matches())

def best_common_allele_pair(
self,
Expand All @@ -491,7 +590,7 @@ def best_common_allele_pair(
ap_to_cs[ap] = cs

# Get an unambiguous set of allele pairs from the best matches:
best_aps: AllelePairs = AllelePairs.get_allele_pairs(best_matches)
best_aps: AllelePairs = AllelePairs.combine_allele_pairs(best_matches)
clean_ap_str: str
best_unambiguous: set[tuple[str, str]]
clean_ap_str, best_unambiguous = best_aps.best_common_allele_pair_str(
Expand Down
22 changes: 22 additions & 0 deletions src/hla_algorithm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,28 @@ def allele_coordinates_sort_key(allele: str) -> tuple[tuple[int, ...], str]:
return (integer_part, letters_at_end)


def allele_pair_sort_key(pair: tuple[str, str]) -> tuple[
tuple[int, ...], str, tuple[int, ...], str
]:
"""
Produce a sortable key for an allele pair.

Pairs should be sorted according to "coordinate order".
If there's a tie, a last letter is used to attempt to break the tie.
"""
return (
allele_coordinates_sort_key(pair[0])
+ allele_coordinates_sort_key(pair[1])
)


def sort_allele_pairs(allele_pairs: Iterable[tuple[str, str]]) -> list[tuple[str, str]]:
"""
Sort the pairs according to "coordinate order".
"""
return sorted(allele_pairs, key=allele_pair_sort_key)


class HLARawStandard(BaseModel):
allele: str
exon2: str
Expand Down
Loading