diff --git a/src/hla_algorithm/hla_algorithm.py b/src/hla_algorithm/hla_algorithm.py index 5171503..369eadd 100644 --- a/src/hla_algorithm/hla_algorithm.py +++ b/src/hla_algorithm/hla_algorithm.py @@ -1,9 +1,9 @@ import csv -import os from collections.abc import Generator, Iterable, Sequence from datetime import datetime from io import TextIOBase from operator import attrgetter +from pathlib import Path from typing import Final, Optional, TypedDict, cast import numpy as np @@ -49,6 +49,8 @@ class HLAAlgorithm: COLUMN_IDS: Final[dict[str, int]] = {"A": 0, "B": 2, "C": 4} + DEFAULT_CONFIG_DIR: Final[Path] = Path(__file__).parent / "default_data" + def __init__( self, loaded_standards: Optional[LoadedStandards] = None, @@ -136,12 +138,7 @@ def load_default_hla_standards() -> LoadedStandards: :return: List of known HLA standards :rtype: list[HLAStandard] """ - standards_filename: str = HLAAlgorithm._path_join_shim( - os.path.dirname(__file__), - "default_data", - "hla_standards.yaml", - ) - with open(standards_filename) as standards_file: + with open(HLAAlgorithm.DEFAULT_CONFIG_DIR / "hla_standards.yaml") as standards_file: return HLAAlgorithm.read_hla_standards(standards_file) FREQUENCY_LOCUS_COLUMNS: dict[HLA_LOCUS, tuple[str, str]] = { @@ -194,13 +191,6 @@ def read_hla_frequencies( hla_freqs[locus][protein_pair] += 1 return hla_freqs - @staticmethod - def _path_join_shim(*args) -> str: - """ - A shim for os.path.join which allows us to mock out the method easily in testing. - """ - return os.path.join(*args) - @staticmethod def load_default_hla_frequencies() -> dict[HLA_LOCUS, dict[HLAProteinPair, int]]: """ @@ -210,12 +200,7 @@ def load_default_hla_frequencies() -> dict[HLA_LOCUS, dict[HLAProteinPair, int]] :rtype: dict[HLA_LOCUS, dict[HLAProteinPair, int]] """ hla_freqs: dict[HLA_LOCUS, dict[HLAProteinPair, int]] - default_frequencies_filename: str = HLAAlgorithm._path_join_shim( - os.path.dirname(__file__), - "default_data", - "hla_frequencies.csv", - ) - with open(default_frequencies_filename, "r") as f: + with open(HLAAlgorithm.DEFAULT_CONFIG_DIR / "hla_frequencies.csv") as f: hla_freqs = HLAAlgorithm.read_hla_frequencies(f) return hla_freqs diff --git a/tests/hla_algorithm_test.py b/tests/hla_algorithm_test.py index 6c8e9c7..52249f9 100644 --- a/tests/hla_algorithm_test.py +++ b/tests/hla_algorithm_test.py @@ -1781,12 +1781,6 @@ def test_interpret_error_cases( } -def test_path_join_shim(): - expected_result: str = "/foo/bar/baz" - result: str = HLAAlgorithm._path_join_shim("/foo/bar", "baz") - assert expected_result == result - - @pytest.mark.parametrize( "raw_standards, raw_expected_result", [ @@ -1900,7 +1894,7 @@ def test_read_hla_standards( # Also try reading it from a file. p = tmp_path / "hla_standards.yaml" p.write_text(standards_file_str) - mocker.patch.object(HLAAlgorithm, "_path_join_shim", return_value=str(p)) + mocker.patch.object(HLAAlgorithm, "DEFAULT_CONFIG_DIR", tmp_path) load_result: LoadedStandards = HLAAlgorithm.load_default_hla_standards() assert load_result == expected_result @@ -2233,7 +2227,7 @@ def test_read_hla_frequencies( # Now try loading these from a file. p = tmp_path / "hla_frequencies.csv" p.write_text(frequencies_str) - mocker.patch.object(HLAAlgorithm, "_path_join_shim", return_value=str(p)) + mocker.patch.object(HLAAlgorithm, "DEFAULT_CONFIG_DIR", tmp_path) load_result: dict[HLA_LOCUS, dict[HLAProteinPair, int]] = ( HLAAlgorithm.load_default_hla_frequencies() ) @@ -2328,11 +2322,7 @@ def test_use_config_all_defaults( freq_path: Path = tmp_path / "hla_frequencies.csv" freq_path.write_text(fake_frequencies_str) - mocker.patch.object( - HLAAlgorithm, - "_path_join_shim", - side_effect=[os.fspath(standards_path), os.fspath(freq_path)], - ) + mocker.patch.object(HLAAlgorithm, "DEFAULT_CONFIG_DIR", tmp_path) hla_algorithm: HLAAlgorithm = HLAAlgorithm.use_config()