Skip to content

Commit 67ba4f6

Browse files
authored
Merge pull request #15 from cfe-lab/ReformatBBLabAlleles
Incorporated David's suggestion that allows us to eliminate the shim …
2 parents 64343e5 + 80e8b2c commit 67ba4f6

File tree

2 files changed

+8
-33
lines changed

2 files changed

+8
-33
lines changed

src/hla_algorithm/hla_algorithm.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import csv
2-
import os
32
from collections.abc import Generator, Iterable, Sequence
43
from datetime import datetime
54
from io import TextIOBase
65
from operator import attrgetter
6+
from pathlib import Path
77
from typing import Final, Optional, TypedDict, cast
88

99
import numpy as np
@@ -49,6 +49,8 @@ class HLAAlgorithm:
4949

5050
COLUMN_IDS: Final[dict[str, int]] = {"A": 0, "B": 2, "C": 4}
5151

52+
DEFAULT_CONFIG_DIR: Final[Path] = Path(__file__).parent / "default_data"
53+
5254
def __init__(
5355
self,
5456
loaded_standards: Optional[LoadedStandards] = None,
@@ -136,12 +138,7 @@ def load_default_hla_standards() -> LoadedStandards:
136138
:return: List of known HLA standards
137139
:rtype: list[HLAStandard]
138140
"""
139-
standards_filename: str = HLAAlgorithm._path_join_shim(
140-
os.path.dirname(__file__),
141-
"default_data",
142-
"hla_standards.yaml",
143-
)
144-
with open(standards_filename) as standards_file:
141+
with open(HLAAlgorithm.DEFAULT_CONFIG_DIR / "hla_standards.yaml") as standards_file:
145142
return HLAAlgorithm.read_hla_standards(standards_file)
146143

147144
FREQUENCY_LOCUS_COLUMNS: dict[HLA_LOCUS, tuple[str, str]] = {
@@ -194,13 +191,6 @@ def read_hla_frequencies(
194191
hla_freqs[locus][protein_pair] += 1
195192
return hla_freqs
196193

197-
@staticmethod
198-
def _path_join_shim(*args) -> str:
199-
"""
200-
A shim for os.path.join which allows us to mock out the method easily in testing.
201-
"""
202-
return os.path.join(*args)
203-
204194
@staticmethod
205195
def load_default_hla_frequencies() -> dict[HLA_LOCUS, dict[HLAProteinPair, int]]:
206196
"""
@@ -210,12 +200,7 @@ def load_default_hla_frequencies() -> dict[HLA_LOCUS, dict[HLAProteinPair, int]]
210200
:rtype: dict[HLA_LOCUS, dict[HLAProteinPair, int]]
211201
"""
212202
hla_freqs: dict[HLA_LOCUS, dict[HLAProteinPair, int]]
213-
default_frequencies_filename: str = HLAAlgorithm._path_join_shim(
214-
os.path.dirname(__file__),
215-
"default_data",
216-
"hla_frequencies.csv",
217-
)
218-
with open(default_frequencies_filename, "r") as f:
203+
with open(HLAAlgorithm.DEFAULT_CONFIG_DIR / "hla_frequencies.csv") as f:
219204
hla_freqs = HLAAlgorithm.read_hla_frequencies(f)
220205
return hla_freqs
221206

tests/hla_algorithm_test.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1781,12 +1781,6 @@ def test_interpret_error_cases(
17811781
}
17821782

17831783

1784-
def test_path_join_shim():
1785-
expected_result: str = "/foo/bar/baz"
1786-
result: str = HLAAlgorithm._path_join_shim("/foo/bar", "baz")
1787-
assert expected_result == result
1788-
1789-
17901784
@pytest.mark.parametrize(
17911785
"raw_standards, raw_expected_result",
17921786
[
@@ -1900,7 +1894,7 @@ def test_read_hla_standards(
19001894
# Also try reading it from a file.
19011895
p = tmp_path / "hla_standards.yaml"
19021896
p.write_text(standards_file_str)
1903-
mocker.patch.object(HLAAlgorithm, "_path_join_shim", return_value=str(p))
1897+
mocker.patch.object(HLAAlgorithm, "DEFAULT_CONFIG_DIR", tmp_path)
19041898
load_result: LoadedStandards = HLAAlgorithm.load_default_hla_standards()
19051899
assert load_result == expected_result
19061900

@@ -2233,7 +2227,7 @@ def test_read_hla_frequencies(
22332227
# Now try loading these from a file.
22342228
p = tmp_path / "hla_frequencies.csv"
22352229
p.write_text(frequencies_str)
2236-
mocker.patch.object(HLAAlgorithm, "_path_join_shim", return_value=str(p))
2230+
mocker.patch.object(HLAAlgorithm, "DEFAULT_CONFIG_DIR", tmp_path)
22372231
load_result: dict[HLA_LOCUS, dict[HLAProteinPair, int]] = (
22382232
HLAAlgorithm.load_default_hla_frequencies()
22392233
)
@@ -2328,11 +2322,7 @@ def test_use_config_all_defaults(
23282322
freq_path: Path = tmp_path / "hla_frequencies.csv"
23292323
freq_path.write_text(fake_frequencies_str)
23302324

2331-
mocker.patch.object(
2332-
HLAAlgorithm,
2333-
"_path_join_shim",
2334-
side_effect=[os.fspath(standards_path), os.fspath(freq_path)],
2335-
)
2325+
mocker.patch.object(HLAAlgorithm, "DEFAULT_CONFIG_DIR", tmp_path)
23362326

23372327
hla_algorithm: HLAAlgorithm = HLAAlgorithm.use_config()
23382328

0 commit comments

Comments
 (0)