Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 5 additions & 20 deletions src/hla_algorithm/hla_algorithm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import csv
import os
from collections.abc import Generator, Iterable, Sequence
from datetime import datetime
from io import TextIOBase
from operator import attrgetter
from pathlib import Path
from typing import Final, Optional, TypedDict, cast

import numpy as np
Expand Down Expand Up @@ -49,6 +49,8 @@ class HLAAlgorithm:

COLUMN_IDS: Final[dict[str, int]] = {"A": 0, "B": 2, "C": 4}

DEFAULT_CONFIG_DIR: Final[Path] = Path(__file__).parent / "default_data"

def __init__(
self,
loaded_standards: Optional[LoadedStandards] = None,
Expand Down Expand Up @@ -136,12 +138,7 @@ def load_default_hla_standards() -> LoadedStandards:
:return: List of known HLA standards
:rtype: list[HLAStandard]
"""
standards_filename: str = HLAAlgorithm._path_join_shim(
os.path.dirname(__file__),
"default_data",
"hla_standards.yaml",
)
with open(standards_filename) as standards_file:
with open(HLAAlgorithm.DEFAULT_CONFIG_DIR / "hla_standards.yaml") as standards_file:
return HLAAlgorithm.read_hla_standards(standards_file)

FREQUENCY_LOCUS_COLUMNS: dict[HLA_LOCUS, tuple[str, str]] = {
Expand Down Expand Up @@ -194,13 +191,6 @@ def read_hla_frequencies(
hla_freqs[locus][protein_pair] += 1
return hla_freqs

@staticmethod
def _path_join_shim(*args) -> str:
"""
A shim for os.path.join which allows us to mock out the method easily in testing.
"""
return os.path.join(*args)

@staticmethod
def load_default_hla_frequencies() -> dict[HLA_LOCUS, dict[HLAProteinPair, int]]:
"""
Expand All @@ -210,12 +200,7 @@ def load_default_hla_frequencies() -> dict[HLA_LOCUS, dict[HLAProteinPair, int]]
:rtype: dict[HLA_LOCUS, dict[HLAProteinPair, int]]
"""
hla_freqs: dict[HLA_LOCUS, dict[HLAProteinPair, int]]
default_frequencies_filename: str = HLAAlgorithm._path_join_shim(
os.path.dirname(__file__),
"default_data",
"hla_frequencies.csv",
)
with open(default_frequencies_filename, "r") as f:
with open(HLAAlgorithm.DEFAULT_CONFIG_DIR / "hla_frequencies.csv") as f:
hla_freqs = HLAAlgorithm.read_hla_frequencies(f)
return hla_freqs

Expand Down
16 changes: 3 additions & 13 deletions tests/hla_algorithm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1781,12 +1781,6 @@ def test_interpret_error_cases(
}


def test_path_join_shim():
expected_result: str = "/foo/bar/baz"
result: str = HLAAlgorithm._path_join_shim("/foo/bar", "baz")
assert expected_result == result


@pytest.mark.parametrize(
"raw_standards, raw_expected_result",
[
Expand Down Expand Up @@ -1900,7 +1894,7 @@ def test_read_hla_standards(
# Also try reading it from a file.
p = tmp_path / "hla_standards.yaml"
p.write_text(standards_file_str)
mocker.patch.object(HLAAlgorithm, "_path_join_shim", return_value=str(p))
mocker.patch.object(HLAAlgorithm, "DEFAULT_CONFIG_DIR", tmp_path)
load_result: LoadedStandards = HLAAlgorithm.load_default_hla_standards()
assert load_result == expected_result

Expand Down Expand Up @@ -2233,7 +2227,7 @@ def test_read_hla_frequencies(
# Now try loading these from a file.
p = tmp_path / "hla_frequencies.csv"
p.write_text(frequencies_str)
mocker.patch.object(HLAAlgorithm, "_path_join_shim", return_value=str(p))
mocker.patch.object(HLAAlgorithm, "DEFAULT_CONFIG_DIR", tmp_path)
load_result: dict[HLA_LOCUS, dict[HLAProteinPair, int]] = (
HLAAlgorithm.load_default_hla_frequencies()
)
Expand Down Expand Up @@ -2328,11 +2322,7 @@ def test_use_config_all_defaults(
freq_path: Path = tmp_path / "hla_frequencies.csv"
freq_path.write_text(fake_frequencies_str)

mocker.patch.object(
HLAAlgorithm,
"_path_join_shim",
side_effect=[os.fspath(standards_path), os.fspath(freq_path)],
)
mocker.patch.object(HLAAlgorithm, "DEFAULT_CONFIG_DIR", tmp_path)

hla_algorithm: HLAAlgorithm = HLAAlgorithm.use_config()

Expand Down