Skip to content

Commit a03e618

Browse files
author
Richard Liang
committed
WIP: fixing a lot of mypy errors.
1 parent bcb7921 commit a03e618

File tree

14 files changed

+122
-139
lines changed

14 files changed

+122
-139
lines changed

.devcontainer/Dockerfile

Lines changed: 0 additions & 23 deletions
This file was deleted.

.github/workflows/test.yml

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,15 @@ jobs:
4747
- name: Run tests
4848
run: uv run pytest --junitxml=pytest.xml
4949

50-
# TODO: Look into github actions, these are out of date
51-
# - name: Upload coverage data
52-
# uses: actions/upload-artifact@v3
53-
# with:
54-
# name: coverage-data
55-
# path: coverage.xml
56-
57-
# - name: Publish Test Report
58-
# uses: mikepenz/action-junit-report@v3
59-
# if: success() || failure()
60-
# with:
61-
# report_paths: unit_test.xml
50+
# TODO: Look into GitHub Actions; the steps below are out of date
51+
# - name: Upload coverage data
52+
# uses: actions/upload-artifact@v3
53+
# with:
54+
# name: coverage-data
55+
# path: coverage.xml
56+
57+
# - name: Publish Test Report
58+
# uses: mikepenz/action-junit-report@v3
59+
# if: success() || failure()
60+
# with:
61+
# report_paths: unit_test.xml

.yamllint.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
ignore:
22
- .git/*
33
- .venv/*
4+
- src/easyhla/default_data/hla_standards.yaml
45

56
extends: default
67

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,4 @@ match = "src/**/*.py"
156156
[tool.mypy]
157157
plugins = ["numpy.typing.mypy_plugin"]
158158
ignore_missing_imports = true
159+
exclude = ["scripts/"]

src/easyhla/bblab.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,19 @@
55
from pathlib import Path
66
from typing import Any, Optional
77

8-
import Bio
98
import typer
9+
from Bio.Seq import Seq
10+
from Bio.SeqIO import parse
1011

1112
from .bblab_lib import (
1213
EXON_AND_OTHER_EXON,
1314
HLAInterpretationRow,
1415
HLAMismatchRow,
1516
pair_exons,
1617
)
17-
from .easyhla import DATE_FORMAT, EXON_NAME, EasyHLA
18+
from .easyhla import DATE_FORMAT, EasyHLA
1819
from .models import HLAInterpretation, HLASequence
20+
from .utils import EXON_NAME
1921

2022
logger = logging.Logger(__name__, logging.ERROR)
2123

@@ -49,21 +51,21 @@ def log_and_print(
4951

5052

5153
def report_unmatched_sequences(
52-
unmatched: dict[EXON_NAME, dict[str, Bio.SeqIO.SeqRecord]],
54+
unmatched: dict[EXON_NAME, dict[str, Seq]],
5355
to_stdout: bool = False,
5456
) -> None:
5557
"""
5658
Report exon sequences that did not have a matching exon.
5759
5860
:param unmatched: unmatched exon sequences, grouped by which exon they represent
59-
:type unmatched: dict[EXON_NAME, dict[str, Bio.SeqIO.SeqRecord]]
61+
:type unmatched: dict[EXON_NAME, dict[str, Seq]]
6062
:param to_stdout: forwarded to log_and_print for each message, defaults to False
6163
:type to_stdout: bool, optional
6264
"""
6365
for exon, other_exon in EXON_AND_OTHER_EXON:
64-
for entry in unmatched[exon]:
66+
for sequence_id in unmatched[exon].keys():
6567
log_and_print(
66-
f"No matching {other_exon} for {entry.description}",
68+
f"No matching {other_exon} for {sequence_id}",
6769
to_stdout=to_stdout,
6870
)
6971

@@ -79,6 +81,8 @@ def process_from_file_to_files(
7981
):
8082
if threshold and threshold < 0:
8183
raise RuntimeError("Threshold must be >=0 or None!")
84+
elif threshold is None:
85+
threshold = 0
8286

8387
rows: list[HLAInterpretationRow] = []
8488
mismatch_rows: list[HLAMismatchRow] = []
@@ -93,13 +97,13 @@ def process_from_file_to_files(
9397
)
9498

9599
matched_sequences: list[HLASequence]
96-
unmatched: dict[EXON_NAME, dict[str, Bio.SeqIO.SeqRecord]]
100+
unmatched: dict[EXON_NAME, dict[str, Seq]]
97101

98102
with open(filename, "r", encoding="utf-8") as f:
99103
matched_sequences, unmatched = pair_exons(
100-
Bio.SeqIO.parse(f, "fasta"),
104+
parse(f, "fasta"),
101105
locus.value,
102-
list(hla_alg.standards.values())[0],
106+
list(hla_alg.hla_standards[locus.value].values())[0],
103107
)
104108

105109
for hla_sequence in matched_sequences:
@@ -133,10 +137,10 @@ def process_from_file_to_files(
133137
row: HLAInterpretationRow = HLAInterpretationRow.summary_row(result)
134138
rows.append(row)
135139

136-
mismatch_rows.extend(result.mismatch_rows())
140+
mismatch_rows.extend(HLAMismatchRow.mismatch_rows(result))
137141

138142
npats += 1
139-
nseqs += hla_sequence.num_seqs
143+
nseqs += hla_sequence.num_sequences_used
140144

141145
report_unmatched_sequences(unmatched, to_stdout=to_stdout)
142146

@@ -171,11 +175,11 @@ def process_from_file_to_files(
171175
),
172176
)
173177
mismatch_csv.writeheader()
174-
mismatch_csv.writerows([dict[row] for row in mismatch_rows])
178+
mismatch_csv.writerows([dict(row) for row in mismatch_rows])
175179

176180
log_and_print(
177181
f"{npats} patients, {nseqs} sequences processed.",
178-
log_level=logger.INFO,
182+
log_level=logging.INFO,
179183
to_stdout=to_stdout,
180184
)
181185

src/easyhla/clinical_hla.py

Lines changed: 39 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import logging
77
import os
88
from datetime import datetime
9-
from typing import Final, Optional, TypedDict
9+
from typing import Final, Literal, Optional, TypedDict, cast
1010

1111
from sqlalchemy import create_engine, event
1212
from sqlalchemy.engine import Engine
@@ -36,38 +36,15 @@
3636
)
3737

3838
# Database connection parameters:
39-
HLA_DB_USER: Final[str] = os.environ.get("HLA_DB_USER")
40-
HLA_DB_PASSWORD: Final[str] = os.environ.get("HLA_DB_PASSWORD")
39+
HLA_DB_USER: Final[Optional[str]] = os.environ.get("HLA_DB_USER")
40+
HLA_DB_PASSWORD: Final[Optional[str]] = os.environ.get("HLA_DB_PASSWORD")
4141
HLA_DB_HOST: Final[str] = os.environ.get("HLA_DB_HOST", "192.168.67.7")
42-
HLA_DB_PORT: Final[int] = os.environ.get("HLA_DB_PORT", 1521)
42+
HLA_DB_PORT: Final[int] = int(os.environ.get("HLA_DB_PORT", 1521))
4343
HLA_DB_SERVICE_NAME: Final[str] = os.environ.get("HLA_DB_SERVICE_NAME", "cfe")
4444

45-
HLA_ORACLE_LIB_PATH: Final[str] = os.environ.get("HLA_ORACLE_LIB_PATH")
46-
47-
# These are the "configuration files" that the algorithm uses; these are or may
48-
# be updated, in which case you specify the path to the new version in the
49-
# environment.
50-
HLA_STANDARDS: Final[str] = os.environ.get("HLA_STANDARDS")
51-
HLA_FREQUENCIES: Final[str] = os.environ.get("HLA_FREQUENCIES")
52-
53-
54-
def prepare_interpretation_for_serialization(
55-
interpretation: HLAInterpretation,
56-
locus: HLA_LOCUS,
57-
processing_datetime: datetime,
58-
) -> HLASequenceA | HLASequenceB | HLASequenceC:
59-
"""
60-
Prepare an HLA interpretation for output.
61-
"""
62-
if locus == "A":
63-
return HLASequenceA.build_from_interpretation(
64-
interpretation, processing_datetime
65-
)
66-
elif locus == "B":
67-
return HLASequenceB.build_from_interpretation(
68-
interpretation, processing_datetime
69-
)
70-
return HLASequenceC.build_from_interpretation(interpretation, processing_datetime)
45+
HLA_ORACLE_LIB_PATH: Final[str] = os.environ.get(
46+
"HLA_ORACLE_LIB_PATH", "/opt/oracle/instant_client"
47+
)
7148

7249

7350
class SequencesByLocus(TypedDict):
@@ -91,10 +68,10 @@ def interpret_sequences(
9168

9269
def clinical_hla_driver(
9370
input_dir: str,
71+
hla_a_results: str,
72+
hla_b_results: str,
73+
hla_c_results: str,
9474
db_engine: Optional[Engine] = None,
95-
hla_a_results: Optional[str] = None,
96-
hla_b_results: Optional[str] = None,
97-
hla_c_results: Optional[str] = None,
9875
standards_path: Optional[str] = None,
9976
frequencies_path: Optional[str] = None,
10077
) -> None:
@@ -105,7 +82,8 @@ def clinical_hla_driver(
10582
"C": [],
10683
}
10784
for locus in ("B", "C"):
108-
sequences[locus] = read_bc_sequences(input_dir, locus, logger)
85+
b_or_c: Literal["B", "C"] = cast(Literal["B", "C"], locus)
86+
sequences[b_or_c] = read_bc_sequences(input_dir, b_or_c, logger)
10987

11088
# Perform interpretations:
11189
interpretations: dict[HLA_LOCUS, list[HLAInterpretation]] = {
@@ -116,25 +94,30 @@ def clinical_hla_driver(
11694
processing_datetime: datetime = datetime.now()
11795
easyhla: EasyHLA = EasyHLA.use_config(standards_path, frequencies_path)
11896
for locus in ("A", "B", "C"):
119-
interpretations[locus] = interpret_sequences(easyhla, sequences[locus])
97+
interpretations[cast(HLA_LOCUS, locus)] = interpret_sequences(
98+
easyhla, sequences[cast(HLA_LOCUS, locus)]
99+
)
120100

121101
# Prepare the interpretations for output:
122102
seqs_for_db: SequencesByLocus = {
123103
"A": [],
124104
"B": [],
125105
"C": [],
126106
}
127-
for locus in ("A", "B", "C"):
128-
# Each locus has a slightly different schema in the database, so we
129-
# customize for each one.
130-
for interp in interpretations[locus]:
131-
seqs_for_db[locus].append(
132-
prepare_interpretation_for_serialization(
133-
interp,
134-
locus,
135-
processing_datetime,
136-
)
137-
)
107+
# This next bit looks repetitive but mypy didn't like my solution for doing
108+
# this in a loop (because each one is a different type).
109+
for interp in interpretations["A"]:
110+
seqs_for_db["A"].append(
111+
HLASequenceA.build_from_interpretation(interp, processing_datetime)
112+
)
113+
for interp in interpretations["B"]:
114+
seqs_for_db["B"].append(
115+
HLASequenceB.build_from_interpretation(interp, processing_datetime)
116+
)
117+
for interp in interpretations["C"]:
118+
seqs_for_db["C"].append(
119+
HLASequenceC.build_from_interpretation(interp, processing_datetime)
120+
)
138121

139122
# First, write to the output files:
140123
output_files: dict[HLA_LOCUS, str] = {
@@ -148,19 +131,23 @@ def clinical_hla_driver(
148131
"C": HLASequenceC.CSV_HEADER,
149132
}
150133
for locus in ("A", "B", "C"):
151-
if len(seqs_for_db[locus]) > 0:
152-
with open(output_files[locus], "w") as f:
134+
if len(seqs_for_db[cast(HLA_LOCUS, locus)]) > 0:
135+
with open(output_files[cast(HLA_LOCUS, locus)], "w") as f:
153136
result_csv: csv.DictWriter = csv.DictWriter(
154-
f, fieldnames=csv_headers[locus], extrasaction="ignore"
137+
f,
138+
fieldnames=csv_headers[cast(HLA_LOCUS, locus)],
139+
extrasaction="ignore",
155140
)
156141
result_csv.writeheader()
157-
result_csv.writerows(dataclasses.asdict(x) for x in seqs_for_db[locus])
142+
result_csv.writerows(
143+
dataclasses.asdict(x) for x in seqs_for_db[cast(HLA_LOCUS, locus)]
144+
)
158145

159146
# Finally, write to the DB.
160147
if db_engine is not None:
161148
with Session(db_engine) as session:
162149
for locus in ("A", "B", "C"):
163-
session.add_all(seqs_for_db[locus])
150+
session.add_all(seqs_for_db[cast(HLA_LOCUS, locus)])
164151
session.commit()
165152

166153

@@ -246,10 +233,10 @@ def schema_workaround(dbapi_connection, _):
246233

247234
clinical_hla_driver(
248235
args.input_dir,
249-
db_engine,
250236
args.hla_a_results,
251237
args.hla_b_results,
252238
args.hla_c_results,
239+
db_engine,
253240
args.hla_standards,
254241
args.hla_frequencies,
255242
)

src/easyhla/clinical_hla_lib.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def get_common_serialization_fields(
7676
"alleles_all": ap.stringify(),
7777
"ambiguous": str(ap.is_ambiguous()),
7878
"homozygous": str(ap.is_homozygous()),
79-
"mismatch_count": interpretation.lowest_mismatch_count(),
79+
"mismatch_count": mismatch_count,
8080
"mismatches": mismatches_str,
8181
"enterdate": processing_datetime,
8282
}
@@ -94,7 +94,7 @@ class HLASequenceA(HLADBBase):
9494
alleles_all: Mapped[Optional[str]] = mapped_column(String)
9595
ambiguous: Mapped[Optional[str]] = mapped_column(String)
9696
homozygous: Mapped[Optional[str]] = mapped_column(String)
97-
mismatch_count: Mapped[Optional[str]] = mapped_column(Integer)
97+
mismatch_count: Mapped[Optional[int]] = mapped_column(Integer)
9898
mismatches: Mapped[Optional[str]] = mapped_column(String)
9999
seq: Mapped[Optional[str]] = mapped_column(String)
100100
enterdate: Mapped[Optional[datetime]] = mapped_column(DateTime)
@@ -140,7 +140,7 @@ class HLASequenceB(HLADBBase):
140140
alleles_all: Mapped[Optional[str]] = mapped_column(String)
141141
ambiguous: Mapped[Optional[str]] = mapped_column(String)
142142
homozygous: Mapped[Optional[str]] = mapped_column(String)
143-
mismatch_count: Mapped[Optional[str]] = mapped_column(Integer)
143+
mismatch_count: Mapped[Optional[int]] = mapped_column(Integer)
144144
mismatches: Mapped[Optional[str]] = mapped_column(String)
145145
b5701: Mapped[Optional[str]] = mapped_column(String)
146146
b5701_dist: Mapped[Optional[int]] = mapped_column(Integer)
@@ -201,7 +201,7 @@ class HLASequenceC(HLADBBase):
201201
alleles_all: Mapped[Optional[str]] = mapped_column(String)
202202
ambiguous: Mapped[Optional[str]] = mapped_column(String)
203203
homozygous: Mapped[Optional[str]] = mapped_column(String)
204-
mismatch_count: Mapped[Optional[str]] = mapped_column(Integer)
204+
mismatch_count: Mapped[Optional[int]] = mapped_column(Integer)
205205
mismatches: Mapped[Optional[str]] = mapped_column(String)
206206
seqa: Mapped[Optional[str]] = mapped_column(String)
207207
seqb: Mapped[Optional[str]] = mapped_column(String)

src/easyhla/interpret_from_json.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,7 @@ def main():
3838
hla_input.hla_std_path,
3939
hla_input.hla_freq_path,
4040
)
41-
interp: HLAInterpretation = easyhla.interpret(
42-
hla_input.hla_sequence(),
43-
hla_input.locus,
44-
)
41+
interp: HLAInterpretation = easyhla.interpret(hla_input.hla_sequence())
4542
print(HLAResult.build_from_interpretation(interp).model_dump_json())
4643

4744

src/easyhla/models.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import re
22
from collections.abc import Iterable
33
from operator import itemgetter
4-
from typing import ClassVar, Final, Optional, Self
4+
from typing import Final, Optional, Self
55

66
import numpy as np
77
from pydantic import BaseModel, ConfigDict
@@ -134,8 +134,10 @@ def __lt__(self, other: "HLAProteinPair") -> bool:
134134
)
135135
return me_tuple < other_tuple
136136

137-
UNMAPPED: ClassVar[Final[str]] = "unmapped"
138-
DEPRECATED: ClassVar[Final[str]] = "deprecated"
137+
# Note: originally these were annotated as ClassVar[Final[str]] but this
138+
# isn't supported in versions of Python prior to 3.13.
139+
UNMAPPED: Final[str] = "unmapped"
140+
DEPRECATED: Final[str] = "deprecated"
139141

140142
class NonAlleleException(Exception):
141143
def __init__(

0 commit comments

Comments
 (0)