diff --git a/src/vectorcode/cli_utils.py b/src/vectorcode/cli_utils.py index 0131a5e2..438a491d 100644 --- a/src/vectorcode/cli_utils.py +++ b/src/vectorcode/cli_utils.py @@ -127,6 +127,7 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config": default_config = Config() db_path = config_dict.get("db_path") + expand_envs_in_dict(config_dict) if db_path is None: db_path = os.path.expanduser("~/.local/share/vectorcode/chromadb/") elif not os.path.isdir(db_path): @@ -470,7 +471,7 @@ async def parse_cli_args(args: Optional[Sequence[str]] = None): def expand_envs_in_dict(d: dict): - if not isinstance(d, dict): + if not isinstance(d, dict): # pragma: nocover return stack = [d] while stack: @@ -482,31 +483,43 @@ def expand_envs_in_dict(d: dict): stack.append(curr[k]) -async def load_config_file(path: Optional[Union[str, Path]] = None): - """Load config file from ~/.config/vectorcode/config.json(5)""" - if path is None: - for name in ("config.json5", "config.json"): - p = os.path.join(GLOBAL_CONFIG_DIR, name) - if os.path.isfile(p): - path = str(p) - break - if path and os.path.isfile(path): - logger.debug(f"Loading config from {path}") - with open(path) as fin: - content = fin.read() - if content: - config = json5.loads(content) - if isinstance(config, dict): - expand_envs_in_dict(config) - return await Config.import_from(config) - else: - logger.error("Invalid configuration format!") - raise ValueError("Invalid configuration format!") - else: - logger.debug("Skipping empty json file.") - else: - logger.warning("Loading default config.") - return Config() +async def load_config_file(path: str | Path | None = None) -> Config: + """ + Load config object by merging the project-local and the global config files. + `path` can be a _file path_ or a _project-root_ path. + + Raises `ValueError` if the config file is not a valid json dictionary. + """ + valid_config_paths = [] + # default to load from the global config + for name in ("config.json5", "config.json"): + p = os.path.join(GLOBAL_CONFIG_DIR, name) + if os.path.isfile(p): + valid_config_paths.append(str(p)) + break + + if path: + if os.path.isfile((path)): + valid_config_paths.append(path) + elif os.path.isdir(path): + for name in ("config.json5", "config.json"): + p = os.path.join(path, ".vectorcode", name) + if os.path.isfile(p): + valid_config_paths.append(str(p)) + break + + final_config = Config() + + for p in valid_config_paths: + with open(p) as fin: + content = json5.load(fin) + logger.info(f"Loaded config from {p}") + if not isinstance(content, dict): + raise ValueError("Invalid configuration format!") + final_config = await final_config.merge_from(await Config.import_from(content)) + logger.debug(f"Merged config: {final_config}") + + return final_config async def find_project_config_dir(start_from: Union[str, Path] = "."): @@ -543,13 +556,12 @@ def find_project_root( start_from = start_from.parent -async def get_project_config(project_root: Union[str, Path]) -> Config: +async def get_project_config(project_root: str | Path) -> Config: """ Load config file for `project_root`. Fallback to global config, and then default config. """ - if not os.path.isabs(project_root): - project_root = os.path.abspath(project_root) + project_root = os.path.abspath(os.path.expanduser(project_root)) exts = ("json5", "json") config = None for ext in exts: diff --git a/src/vectorcode/lsp_main.py b/src/vectorcode/lsp_main.py index 80820f11..2a032f5a 100644 --- a/src/vectorcode/lsp_main.py +++ b/src/vectorcode/lsp_main.py @@ -46,7 +46,7 @@ expand_globs, expand_path, find_project_root, - get_project_config, + load_config_file, parse_cli_args, ) from vectorcode.common import ClientManager, get_collection, list_collection_files @@ -113,7 +113,7 @@ async def execute_command(ls: LanguageServer, args: list[str]): parsed_args.project_root = os.path.abspath(str(parsed_args.project_root)) final_configs = await ( - await get_project_config(parsed_args.project_root) + await load_config_file(parsed_args.project_root) ).merge_from(parsed_args) final_configs.pipe = True else: diff --git a/src/vectorcode/main.py b/src/vectorcode/main.py index 345aedc5..158b50ad 100644 --- a/src/vectorcode/main.py +++ b/src/vectorcode/main.py @@ -11,7 +11,7 @@ CliAction, config_logging, find_project_root, - get_project_config, + load_config_file, parse_cli_args, ) from vectorcode.common import ClientManager @@ -24,7 +24,7 @@ async def async_main(): if cli_args.no_stderr: sys.stderr = open(os.devnull, "w") - if cli_args.debug: + if cli_args.debug: # pragma: nocover from vectorcode import debugging debugging.enable() @@ -43,7 +43,7 @@ async def async_main(): try: final_configs = await ( - await get_project_config(cli_args.project_root) + await load_config_file(cli_args.project_root) ).merge_from(cli_args) except IOError as e: traceback.print_exception(e, file=sys.stderr) diff --git a/src/vectorcode/subcommands/init.py b/src/vectorcode/subcommands/init.py index bba9af63..4ae28eee 100644 --- a/src/vectorcode/subcommands/init.py +++ b/src/vectorcode/subcommands/init.py @@ -114,8 +114,6 @@ async def init(configs: Config) -> int: else: os.makedirs(project_config_dir, exist_ok=True) for item in ( - "config.json5", - "config.json", "vectorcode.include", "vectorcode.exclude", ): diff --git a/tests/subcommands/test_init.py b/tests/subcommands/test_init.py index fa07ce7b..53a13da0 100644 --- a/tests/subcommands/test_init.py +++ b/tests/subcommands/test_init.py @@ -88,7 +88,11 @@ async def test_init_copies_global_config(capsys): # Assert files were copied assert return_code == 0 - assert copyfile_mock.call_count == len(config_items) + assert copyfile_mock.call_count == sum( + # not copying `json`s. + "json" not in i + for i in config_items.keys() + ) # Check output messages captured = capsys.readouterr() diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py index bd10efc5..3f0f2179 100644 --- a/tests/test_cli_utils.py +++ b/tests/test_cli_utils.py @@ -207,6 +207,52 @@ async def test_load_config_file_invalid_json(): await load_config_file(config_path) +@pytest.mark.asyncio +async def test_load_config_file_merging(): + with tempfile.TemporaryDirectory() as dummy_home: + global_config_dir = os.path.join(dummy_home, ".config", "vectorcode") + os.makedirs(global_config_dir, exist_ok=True) + with open(os.path.join(global_config_dir, "config.json"), mode="w") as fin: + fin.writelines(['{"embedding_function": "DummyEmbeddingFunction"}']) + + with tempfile.TemporaryDirectory(dir=dummy_home) as proj_root: + os.makedirs(os.path.join(proj_root, ".vectorcode"), exist_ok=True) + with open( + os.path.join(proj_root, ".vectorcode", "config.json"), mode="w" + ) as fin: + fin.writelines( + ['{"embedding_function": "AnotherDummyEmbeddingFunction"}'] + ) + + with patch( + "vectorcode.cli_utils.GLOBAL_CONFIG_DIR", new=str(global_config_dir) + ): + assert ( + await load_config_file() + ).embedding_function == "DummyEmbeddingFunction" + assert ( + await load_config_file(proj_root) + ).embedding_function == "AnotherDummyEmbeddingFunction" + + +@pytest.mark.asyncio +async def test_load_config_file_with_envs(): + with tempfile.TemporaryDirectory() as proj_root: + os.makedirs(os.path.join(proj_root, ".vectorcode"), exist_ok=True) + with ( + open( + os.path.join(proj_root, ".vectorcode", "config.json"), mode="w" + ) as fin, + ): + fin.writelines(['{"embedding_function": "$DUMMY_EMBEDDING_FUNCTION"}']) + with patch.dict( + os.environ, {"DUMMY_EMBEDDING_FUNCTION": "DummyEmbeddingFunction"} + ): + assert ( + await load_config_file(proj_root) + ).embedding_function == "DummyEmbeddingFunction" + + @pytest.mark.asyncio async def test_load_from_default_config(): for name in ("config.json5", "config.json"): @@ -261,7 +307,8 @@ async def test_load_config_file_empty_file(): with open(config_path, "w") as f: f.write("") - assert await load_config_file(config_path) == Config() + with pytest.raises(ValueError): + await load_config_file(config_path) @pytest.mark.asyncio diff --git a/tests/test_main.py b/tests/test_main.py index b46c9989..ea08ab81 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -43,7 +43,7 @@ async def test_async_main_ioerror(monkeypatch): "vectorcode.main.parse_cli_args", AsyncMock(return_value=mock_cli_args) ) monkeypatch.setattr( - "vectorcode.main.get_project_config", + "vectorcode.main.load_config_file", AsyncMock(side_effect=IOError("Test Error")), ) @@ -62,7 +62,7 @@ async def test_async_main_cli_action_check(monkeypatch): mock_check = AsyncMock(return_value=0) monkeypatch.setattr("vectorcode.subcommands.check", mock_check) monkeypatch.setattr( - "vectorcode.main.get_project_config", + "vectorcode.main.load_config_file", AsyncMock(return_value=MagicMock(merge_from=AsyncMock())), ) @@ -79,7 +79,7 @@ async def test_async_main_cli_action_init(monkeypatch): ) mock_init = AsyncMock(return_value=0) monkeypatch.setattr("vectorcode.subcommands.init", mock_init) - monkeypatch.setattr("vectorcode.main.get_project_config", AsyncMock()) + monkeypatch.setattr("vectorcode.main.load_config_file", AsyncMock()) return_code = await async_main() assert return_code == 0 @@ -95,7 +95,7 @@ async def test_async_main_cli_action_chunks(monkeypatch): mock_chunks = AsyncMock(return_value=0) monkeypatch.setattr("vectorcode.subcommands.chunks", mock_chunks) monkeypatch.setattr( - "vectorcode.main.get_project_config", AsyncMock(return_value=Config()) + "vectorcode.main.load_config_file", AsyncMock(return_value=Config()) ) monkeypatch.setattr("vectorcode.common.try_server", AsyncMock(return_value=True)) @@ -126,7 +126,7 @@ async def test_async_main_cli_action_prompts(monkeypatch): mock_prompts = MagicMock(return_value=0) monkeypatch.setattr("vectorcode.subcommands.prompts", mock_prompts) monkeypatch.setattr( - "vectorcode.main.get_project_config", AsyncMock(return_value=Config()) + "vectorcode.main.load_config_file", AsyncMock(return_value=Config()) ) return_code = await async_main() @@ -144,7 +144,7 @@ async def test_async_main_cli_action_query(monkeypatch): db_url="http://test_host:1234", action=CliAction.query, pipe=False ) monkeypatch.setattr( - "vectorcode.main.get_project_config", + "vectorcode.main.load_config_file", AsyncMock( return_value=AsyncMock( merge_from=AsyncMock(return_value=mock_final_configs) @@ -175,7 +175,7 @@ async def test_async_main_cli_action_vectorise(monkeypatch): db_url="http://test_host:1234", action=CliAction.vectorise, include_hidden=True ) monkeypatch.setattr( - "vectorcode.main.get_project_config", + "vectorcode.main.load_config_file", AsyncMock( return_value=AsyncMock( merge_from=AsyncMock(return_value=mock_final_configs) @@ -199,7 +199,7 @@ async def test_async_main_cli_action_drop(monkeypatch): ) mock_final_configs = Config(db_url="http://test_host:1234", action=CliAction.drop) monkeypatch.setattr( - "vectorcode.main.get_project_config", + "vectorcode.main.load_config_file", AsyncMock( return_value=AsyncMock( merge_from=AsyncMock(return_value=mock_final_configs) @@ -223,7 +223,7 @@ async def test_async_main_cli_action_ls(monkeypatch): ) mock_final_configs = Config(db_url="http://test_host:1234", action=CliAction.ls) monkeypatch.setattr( - "vectorcode.main.get_project_config", + "vectorcode.main.load_config_file", AsyncMock( return_value=AsyncMock( merge_from=AsyncMock(return_value=mock_final_configs) @@ -259,7 +259,7 @@ async def test_async_main_cli_action_update(monkeypatch): ) mock_final_configs = Config(db_url="http://test_host:1234", action=CliAction.update) monkeypatch.setattr( - "vectorcode.main.get_project_config", + "vectorcode.main.load_config_file", AsyncMock( return_value=AsyncMock( merge_from=AsyncMock(return_value=mock_final_configs) @@ -283,7 +283,7 @@ async def test_async_main_cli_action_clean(monkeypatch): ) mock_final_configs = Config(db_url="http://test_host:1234", action=CliAction.clean) monkeypatch.setattr( - "vectorcode.main.get_project_config", + "vectorcode.main.load_config_file", AsyncMock( return_value=AsyncMock( merge_from=AsyncMock(return_value=mock_final_configs) @@ -307,7 +307,7 @@ async def test_async_main_exception_handling(monkeypatch): ) mock_final_configs = Config(db_url="http://test_host:1234", action=CliAction.query) monkeypatch.setattr( - "vectorcode.main.get_project_config", + "vectorcode.main.load_config_file", AsyncMock( return_value=AsyncMock( merge_from=AsyncMock(return_value=mock_final_configs)