diff --git a/doc/VectorCode-cli.txt b/doc/VectorCode-cli.txt
index 510510fa..92aa5117 100644
--- a/doc/VectorCode-cli.txt
+++ b/doc/VectorCode-cli.txt
@@ -771,7 +771,8 @@ will use that as the default project root for this process;
Note that:
1. For easier parsing, `--pipe` is assumed to be enabled in LSP mode;
-2. At the time this only work with vectorcode setup that uses a **standalone ChromaDB server**, which is not difficult to setup using docker;
+2. A `vectorcode.lock` file will be created in your `db_path` directory **if you’re using the bundled chromadb server**. Please do not delete it while a
+vectorcode process is running;
3. The LSP server supports `vectorise`, `query` and `ls` subcommands. The other
subcommands may be added in the future.
@@ -789,9 +790,7 @@ features:
- `vectorise`vectorise files into a given project.
To try it out, install the `vectorcode[mcp]` dependency group and the MCP
-server is available in the shell as `vectorcode-mcp-server`, and make sure
-you’re using a |VectorCode-cli-standalone-chromadb-server| configured in the
-|VectorCode-cli-json| via the `host` and `port` options.
+server is available in the shell as `vectorcode-mcp-server`.
The MCP server entry point (`vectorcode-mcp-server`) provides some CLI options
that you can use to customise the default behaviour of the server. To view the
diff --git a/doc/VectorCode.txt b/doc/VectorCode.txt
index 22aeeec1..78410ecb 100644
--- a/doc/VectorCode.txt
+++ b/doc/VectorCode.txt
@@ -372,9 +372,9 @@ path to the executable) by calling `vim.lsp.config('vectorcode_server', opts)`.
minimal extra config required loading/unloading embedding models;
Progress reports.
- Cons Heavy IO overhead because the Requires vectorcode-server; Only
- embedding model and database works if you’re using a standalone
- client need to be initialised ChromaDB server.
+ Cons Heavy IO overhead because the Requires vectorcode-server
+ embedding model and database
+ client need to be initialised
for every query.
-------------------------------------------------------------------------------
You may choose which backend to use by setting the |VectorCode-`setup`| option
diff --git a/docs/cli.md b/docs/cli.md
index 3b386cdb..f6dc17c2 100644
--- a/docs/cli.md
+++ b/docs/cli.md
@@ -696,8 +696,9 @@ will:
Note that:
1. For easier parsing, `--pipe` is assumed to be enabled in LSP mode;
-2. At the time this only work with vectorcode setup that uses a **standalone
- ChromaDB server**, which is not difficult to setup using docker;
+2. A `vectorcode.lock` file will be created in your `db_path` directory __if
+ you're using the bundled chromadb server__. Please do not delete it while a
+ vectorcode process is running;
3. The LSP server supports `vectorise`, `query` and `ls` subcommands. The other
subcommands may be added in the future.
@@ -714,9 +715,7 @@ features:
- `vectorise`: vectorise files into a given project.
To try it out, install the `vectorcode[mcp]` dependency group and the MCP server
-is available in the shell as `vectorcode-mcp-server`, and make sure you're using
-a [standalone chromadb server](#chromadb) configured in the [JSON](#configuring-vectorcode)
-via the `host` and `port` options.
+is available in the shell as `vectorcode-mcp-server`.
The MCP server entry point (`vectorcode-mcp-server`) provides some CLI options
that you can use to customise the default behaviour of the server. To view the
diff --git a/docs/neovim.md b/docs/neovim.md
index 0f61566c..a2a3981c 100644
--- a/docs/neovim.md
+++ b/docs/neovim.md
@@ -332,7 +332,7 @@ interface:
| Features | `default` | `lsp` |
|----------|-----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------|
| **Pros** | Fully backward compatible with minimal extra config required | Less IO overhead for loading/unloading embedding models; Progress reports. |
-| **Cons** | Heavy IO overhead because the embedding model and database client need to be initialised for every query. | Requires `vectorcode-server`; Only works if you're using a standalone ChromaDB server. |
+| **Cons** | Heavy IO overhead because the embedding model and database client need to be initialised for every query. | Requires `vectorcode-server` |
You may choose which backend to use by setting the [`setup`](#setupopts) option `async_backend`,
and acquire the corresponding backend by the following API:
diff --git a/pyproject.toml b/pyproject.toml
index dc05e2c6..45ec151f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,6 +21,7 @@ dependencies = [
"charset-normalizer>=3.4.1",
"json5",
"posthog<6.0.0",
+ "filelock>=3.15.0",
]
requires-python = ">=3.11,<3.14"
readme = "README.md"
diff --git a/src/vectorcode/cli_utils.py b/src/vectorcode/cli_utils.py
index 38fd74ef..e49b4b24 100644
--- a/src/vectorcode/cli_utils.py
+++ b/src/vectorcode/cli_utils.py
@@ -12,6 +12,7 @@
import json5
import shtab
+from filelock import AsyncFileLock
from vectorcode import __version__
@@ -610,3 +611,31 @@ def config_logging(
handlers=handlers,
level=level,
)
+
+
+class LockManager:
+ """
+ A class that manages file locks that protects the database files in daemon processes (LSP, MCP).
+ """
+
+ __locks: dict[str, AsyncFileLock]
+ singleton: Optional["LockManager"] = None
+
+ def __new__(cls) -> "LockManager":
+ if cls.singleton is None:
+ cls.singleton = super().__new__(cls)
+ cls.singleton.__locks = {}
+ return cls.singleton
+
+ def get_lock(self, path: str | os.PathLike) -> AsyncFileLock:
+ path = str(expand_path(str(path), True))
+ if os.path.isdir(path):
+ lock_file = os.path.join(path, "vectorcode.lock")
+ logger.info(f"Creating {lock_file} for locking.")
+ if not os.path.isfile(lock_file):
+ with open(lock_file, mode="w") as fin:
+ fin.write("")
+ path = lock_file
+ if self.__locks.get(path) is None:
+ self.__locks[path] = AsyncFileLock(path) # pyright: ignore[reportArgumentType]
+ return self.__locks[path]
diff --git a/src/vectorcode/common.py b/src/vectorcode/common.py
index 65ce495e..a297c25f 100644
--- a/src/vectorcode/common.py
+++ b/src/vectorcode/common.py
@@ -1,11 +1,14 @@
import asyncio
+import contextlib
import hashlib
import logging
import os
import socket
import subprocess
import sys
-from typing import Any, AsyncGenerator
+from asyncio.subprocess import Process
+from dataclasses import dataclass
+from typing import Any, AsyncGenerator, Optional
from urllib.parse import urlparse
import chromadb
@@ -16,7 +19,7 @@
from chromadb.config import APIVersion, Settings
from chromadb.utils import embedding_functions
-from vectorcode.cli_utils import Config, expand_path
+from vectorcode.cli_utils import Config, LockManager, expand_path
logger = logging.getLogger(name=__name__)
@@ -112,32 +115,6 @@ async def start_server(configs: Config):
return process
-__CLIENT_CACHE: dict[str, AsyncClientAPI] = {}
-
-
-async def get_client(configs: Config) -> AsyncClientAPI:
- client_entry = configs.db_url
- if __CLIENT_CACHE.get(client_entry) is None:
- settings: dict[str, Any] = {"anonymized_telemetry": False}
- if isinstance(configs.db_settings, dict):
- valid_settings = {
- k: v for k, v in configs.db_settings.items() if k in Settings.__fields__
- }
- settings.update(valid_settings)
- parsed_url = urlparse(configs.db_url)
- settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1"
- settings["chroma_server_http_port"] = parsed_url.port or 8000
- settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https"
- settings["chroma_server_api_default_path"] = parsed_url.path or APIVersion.V2
- settings_obj = Settings(**settings)
- __CLIENT_CACHE[client_entry] = await chromadb.AsyncHttpClient(
- settings=settings_obj,
- host=str(settings_obj.chroma_server_host),
- port=int(settings_obj.chroma_server_http_port or 8000),
- )
- return __CLIENT_CACHE[client_entry]
-
-
def get_collection_name(full_path: str) -> str:
full_path = str(expand_path(full_path, absolute=True))
hasher = hashlib.sha256()
@@ -261,3 +238,80 @@ async def list_collection_files(collection: AsyncCollection) -> list[str]:
or []
)
)
+
+
+# Cache record stored by ClientManager for one project root.
+@dataclass
+class _ClientModel:
+ # The chromadb client that ClientManager.get_client yields to callers.
+ client: AsyncClientAPI
+ # True when ClientManager started the server itself via start_server();
+ # get_client only takes the db_path file lock for such bundled servers.
+ is_bundled: bool = False
+ # Handle of the server process returned by start_server(); None when an
+ # already-running server was reachable. Used by get_processes/kill_servers.
+ process: Optional[Process] = None
+
+
+class ClientManager:
+ singleton: Optional["ClientManager"] = None
+ __clients: dict[str, _ClientModel]
+
+ def __new__(cls) -> "ClientManager":
+ if cls.singleton is None:
+ cls.singleton = super().__new__(cls)
+ cls.singleton.__clients = {}
+ return cls.singleton
+
+ @contextlib.asynccontextmanager
+ async def get_client(self, configs: Config, need_lock: bool = True):
+ project_root = str(expand_path(str(configs.project_root), True))
+ is_bundled = False
+ if self.__clients.get(project_root) is None:
+ process = None
+ if not await try_server(configs.db_url):
+ logger.info(f"Starting a new server at {configs.db_url}")
+ process = await start_server(configs)
+ is_bundled = True
+
+ self.__clients[project_root] = _ClientModel(
+ client=await self._create_client(configs),
+ is_bundled=is_bundled,
+ process=process,
+ )
+ lock = None
+ if self.__clients[project_root].is_bundled and need_lock:
+ lock = LockManager().get_lock(str(configs.db_path))
+ logger.debug(f"Locking {configs.db_path}")
+ await lock.acquire()
+ yield self.__clients[project_root].client
+ if lock is not None:
+ logger.debug(f"Unlocking {configs.db_path}")
+ await lock.release()
+
+ def get_processes(self) -> list[Process]:
+ return [i.process for i in self.__clients.values() if i.process is not None]
+
+ async def kill_servers(self):
+ termination_tasks: list[asyncio.Task] = []
+ for p in self.get_processes():
+ logger.info(f"Killing bundled chroma server with PID: {p.pid}")
+ p.terminate()
+ termination_tasks.append(asyncio.create_task(p.wait()))
+ await asyncio.gather(*termination_tasks)
+
+ async def _create_client(self, configs: Config) -> AsyncClientAPI:
+ settings: dict[str, Any] = {"anonymized_telemetry": False}
+ if isinstance(configs.db_settings, dict):
+ valid_settings = {
+ k: v for k, v in configs.db_settings.items() if k in Settings.__fields__
+ }
+ settings.update(valid_settings)
+ parsed_url = urlparse(configs.db_url)
+ settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1"
+ settings["chroma_server_http_port"] = parsed_url.port or 8000
+ settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https"
+ settings["chroma_server_api_default_path"] = parsed_url.path or APIVersion.V2
+ settings_obj = Settings(**settings)
+ return await chromadb.AsyncHttpClient(
+ settings=settings_obj,
+ host=str(settings_obj.chroma_server_host),
+ port=int(settings_obj.chroma_server_http_port or 8000),
+ )
+
+ def clear(self):
+ self.__clients.clear()
diff --git a/src/vectorcode/lsp_main.py b/src/vectorcode/lsp_main.py
index f9b883f9..bd78854a 100644
--- a/src/vectorcode/lsp_main.py
+++ b/src/vectorcode/lsp_main.py
@@ -35,7 +35,6 @@
from vectorcode import __version__
from vectorcode.cli_utils import (
CliAction,
- Config,
cleanup_path,
config_logging,
expand_globs,
@@ -43,28 +42,14 @@
get_project_config,
parse_cli_args,
)
-from vectorcode.common import get_client, get_collection, try_server
+from vectorcode.common import ClientManager, get_collection
from vectorcode.subcommands.ls import get_collection_list
from vectorcode.subcommands.query import build_query_results
-cached_project_configs: dict[str, Config] = {}
DEFAULT_PROJECT_ROOT: str | None = None
logger = logging.getLogger(__name__)
-async def make_caches(project_root: str):
- assert os.path.isabs(project_root)
- if cached_project_configs.get(project_root) is None:
- cached_project_configs[project_root] = await get_project_config(project_root)
- config = cached_project_configs[project_root]
- config.project_root = project_root
- db_url = config.db_url
- if not await try_server(db_url): # pragma: nocover
- raise ConnectionError(
- "Failed to find an existing ChromaDB server, which is a hard requirement for LSP mode!"
- )
-
-
def get_arg_parser():
parser = argparse.ArgumentParser(
"vectorcode-server", description="VectorCode LSP daemon."
@@ -109,134 +94,140 @@ async def execute_command(ls: LanguageServer, args: list[str]):
collection = None
if parsed_args.project_root is not None:
parsed_args.project_root = os.path.abspath(str(parsed_args.project_root))
- await make_caches(parsed_args.project_root)
- final_configs = await cached_project_configs[
- parsed_args.project_root
- ].merge_from(parsed_args)
+
+ final_configs = await (
+ await get_project_config(parsed_args.project_root)
+ ).merge_from(parsed_args)
final_configs.pipe = True
- client = await get_client(final_configs)
+ else:
+ final_configs = parsed_args
+ logger.info("Merged final configs: %s", final_configs)
+ async with ClientManager().get_client(final_configs) as client:
+ progress_token = str(uuid.uuid4())
+
if final_configs.action in {CliAction.vectorise, CliAction.query}:
collection = await get_collection(
client=client,
configs=final_configs,
make_if_missing=final_configs.action in {CliAction.vectorise},
)
- else:
- final_configs = parsed_args
- client = await get_client(parsed_args)
- collection = None
- logger.info("Merged final configs: %s", final_configs)
- progress_token = str(uuid.uuid4())
-
- await ls.progress.create_async(progress_token)
- match final_configs.action:
- case CliAction.query:
- ls.progress.begin(
- progress_token,
- types.WorkDoneProgressBegin(
- "VectorCode",
- message=f"Querying {cleanup_path(str(final_configs.project_root))}",
- ),
- )
- final_results = []
- try:
- assert collection is not None, (
- "Failed to find the correct collection."
- )
- final_results.extend(
- await build_query_results(collection, final_configs)
- )
- finally:
- log_message = f"Retrieved {len(final_results)} result{'s' if len(final_results) > 1 else ''} in {round(time.time() - start_time, 2)}s."
- ls.progress.end(
+ await ls.progress.create_async(progress_token)
+ match final_configs.action:
+ case CliAction.query:
+ ls.progress.begin(
progress_token,
- types.WorkDoneProgressEnd(message=log_message),
+ types.WorkDoneProgressBegin(
+ "VectorCode",
+ message=f"Querying {cleanup_path(str(final_configs.project_root))}",
+ ),
)
- logger.info(log_message)
- return final_results
- case CliAction.ls:
- ls.progress.begin(
- progress_token,
- types.WorkDoneProgressBegin(
- "VectorCode",
- message="Looking for available projects indexed by VectorCode",
- ),
- )
- projects: list[dict] = []
- try:
- projects.extend(await get_collection_list(client))
- finally:
- ls.progress.end(
+ final_results = []
+ try:
+ assert collection is not None, (
+ "Failed to find the correct collection."
+ )
+ final_results.extend(
+ await build_query_results(collection, final_configs)
+ )
+ finally:
+ log_message = f"Retrieved {len(final_results)} result{'s' if len(final_results) > 1 else ''} in {round(time.time() - start_time, 2)}s."
+ ls.progress.end(
+ progress_token,
+ types.WorkDoneProgressEnd(message=log_message),
+ )
+ logger.info(log_message)
+ return final_results
+ case CliAction.ls:
+ ls.progress.begin(
progress_token,
- types.WorkDoneProgressEnd(message="List retrieved."),
+ types.WorkDoneProgressBegin(
+ "VectorCode",
+ message="Looking for available projects indexed by VectorCode",
+ ),
)
- logger.info(f"Retrieved {len(projects)} project(s).")
- return projects
- case CliAction.vectorise:
- assert collection is not None, "Failed to find the correct collection."
- ls.progress.begin(
- progress_token,
- types.WorkDoneProgressBegin(
- title="VectorCode", message="Vectorising files...", percentage=0
- ),
- )
- files = await expand_globs(
- final_configs.files
- or load_files_from_include(str(final_configs.project_root)),
- recursive=final_configs.recursive,
- include_hidden=final_configs.include_hidden,
- )
- if not final_configs.force: # pragma: nocover
- # tested in 'vectorise.py'
- for spec in find_exclude_specs(final_configs):
- if os.path.isfile(spec):
- logger.info(f"Loading ignore specs from {spec}.")
- files = exclude_paths_by_spec((str(i) for i in files), spec)
- stats = VectoriseStats()
- collection_lock = asyncio.Lock()
- stats_lock = asyncio.Lock()
- max_batch_size = await client.get_max_batch_size()
- semaphore = asyncio.Semaphore(os.cpu_count() or 1)
- tasks = [
- asyncio.create_task(
- chunked_add(
- str(file),
- collection,
- collection_lock,
- stats,
- stats_lock,
- final_configs,
- max_batch_size,
- semaphore,
+ projects: list[dict] = []
+ try:
+ projects.extend(await get_collection_list(client))
+ finally:
+ ls.progress.end(
+ progress_token,
+ types.WorkDoneProgressEnd(message="List retrieved."),
)
+ logger.info(f"Retrieved {len(projects)} project(s).")
+ return projects
+ case CliAction.vectorise:
+ assert collection is not None, (
+ "Failed to find the correct collection."
)
- for file in files
- ]
- for i, task in enumerate(asyncio.as_completed(tasks), start=1):
- await task
- ls.progress.report(
+ ls.progress.begin(
progress_token,
- types.WorkDoneProgressReport(
+ types.WorkDoneProgressBegin(
+ title="VectorCode",
message="Vectorising files...",
- percentage=int(100 * i / len(tasks)),
+ percentage=0,
),
)
+ files = await expand_globs(
+ final_configs.files
+ or load_files_from_include(str(final_configs.project_root)),
+ recursive=final_configs.recursive,
+ include_hidden=final_configs.include_hidden,
+ )
+ if not final_configs.force: # pragma: nocover
+ # tested in 'vectorise.py'
+ for spec in find_exclude_specs(final_configs):
+ if os.path.isfile(spec):
+ logger.info(f"Loading ignore specs from {spec}.")
+ files = exclude_paths_by_spec(
+ (str(i) for i in files), spec
+ )
+ stats = VectoriseStats()
+ collection_lock = asyncio.Lock()
+ stats_lock = asyncio.Lock()
+ max_batch_size = await client.get_max_batch_size()
+ semaphore = asyncio.Semaphore(os.cpu_count() or 1)
+ tasks = [
+ asyncio.create_task(
+ chunked_add(
+ str(file),
+ collection,
+ collection_lock,
+ stats,
+ stats_lock,
+ final_configs,
+ max_batch_size,
+ semaphore,
+ )
+ )
+ for file in files
+ ]
+ for i, task in enumerate(asyncio.as_completed(tasks), start=1):
+ await task
+ ls.progress.report(
+ progress_token,
+ types.WorkDoneProgressReport(
+ message="Vectorising files...",
+ percentage=int(100 * i / len(tasks)),
+ ),
+ )
- await remove_orphanes(collection, collection_lock, stats, stats_lock)
+ await remove_orphanes(
+ collection, collection_lock, stats, stats_lock
+ )
- ls.progress.end(
- progress_token,
- types.WorkDoneProgressEnd(
- message=f"Vectorised {stats.add + stats.update} files."
- ),
- )
- return stats.to_dict()
- case _ as c: # pragma: nocover
- error_message = f"Unsupported vectorcode subcommand: {str(c)}"
- logger.error(
- error_message,
- )
- raise JsonRpcInvalidRequest(error_message)
+ ls.progress.end(
+ progress_token,
+ types.WorkDoneProgressEnd(
+ message=f"Vectorised {stats.add + stats.update} files."
+ ),
+ )
+ return stats.to_dict()
+ case _ as c: # pragma: nocover
+ error_message = f"Unsupported vectorcode subcommand: {str(c)}"
+ logger.error(
+ error_message,
+ )
+ raise JsonRpcInvalidRequest(error_message)
except Exception as e: # pragma: nocover
if isinstance(e, JsonRpcException):
# pygls exception. raise it as is.
@@ -266,9 +257,11 @@ async def lsp_start() -> int:
logger.info(f"{DEFAULT_PROJECT_ROOT=}")
logger.info("Parsed LSP server CLI arguments: %s", args)
- await asyncio.to_thread(server.start_io)
-
- return 0
+ try:
+ await asyncio.to_thread(server.start_io)
+ finally:
+ await ClientManager().kill_servers()
+ return 0
def main(): # pragma: nocover
diff --git a/src/vectorcode/main.py b/src/vectorcode/main.py
index 3ea8eefa..70cc1aba 100644
--- a/src/vectorcode/main.py
+++ b/src/vectorcode/main.py
@@ -12,6 +12,7 @@
get_project_config,
parse_cli_args,
)
+from vectorcode.common import ClientManager
logger = logging.getLogger(name=__name__)
@@ -63,12 +64,6 @@ async def async_main():
return await chunks(final_configs)
- from vectorcode.common import start_server, try_server
-
- server_process = None
- if not await try_server(final_configs.db_url):
- server_process = await start_server(final_configs)
-
if final_configs.pipe: # pragma: nocover
# NOTE: NNCF (intel GPU acceleration for sentence transformer) keeps showing logs.
# This disables logs below ERROR so that it doesn't hurt the `pipe` output.
@@ -105,10 +100,7 @@ async def async_main():
return_val = 1
logger.error(traceback.format_exc())
finally:
- if server_process is not None:
- logger.info("Shutting down the bundled Chromadb instance.")
- server_process.terminate()
- await server_process.wait()
+ await ClientManager().kill_servers()
return return_val
diff --git a/src/vectorcode/mcp_main.py b/src/vectorcode/mcp_main.py
index 86d9989c..72f57d92 100644
--- a/src/vectorcode/mcp_main.py
+++ b/src/vectorcode/mcp_main.py
@@ -8,9 +8,6 @@
from typing import Optional
import shtab
-from chromadb.api import AsyncClientAPI
-from chromadb.api.models.AsyncCollection import AsyncCollection
-from chromadb.errors import InvalidCollectionException
from vectorcode.subcommands.vectorise import (
VectoriseStats,
@@ -32,6 +29,7 @@
from vectorcode.cli_utils import (
Config,
+ LockManager,
cleanup_path,
config_logging,
expand_globs,
@@ -39,11 +37,12 @@
get_project_config,
load_config_file,
)
-from vectorcode.common import get_client, get_collection, get_collections
+from vectorcode.common import ClientManager, get_collection, get_collections
from vectorcode.subcommands.prompt import prompt_by_categories
from vectorcode.subcommands.query import get_query_result_files
logger = logging.getLogger(name=__name__)
+locks = LockManager()
@dataclass
@@ -79,23 +78,20 @@ def get_arg_parser():
return parser
+default_project_root: Optional[str] = None
default_config: Optional[Config] = None
-default_client: Optional[AsyncClientAPI] = None
-default_collection: Optional[AsyncCollection] = None
async def list_collections() -> list[str]:
- global default_config, default_client, default_collection
names: list[str] = []
- client = default_client
- if client is None:
- # load from global config when failed to detect a project-local config.
- client = await get_client(await load_config_file())
- async for col in get_collections(client):
- if col.metadata is not None:
- names.append(cleanup_path(str(col.metadata.get("path"))))
- logger.info("Retrieved the following collections: %s", names)
- return names
+ async with ClientManager().get_client(
+ await load_config_file(default_project_root)
+ ) as client:
+ async for col in get_collections(client):
+ if col.metadata is not None:
+ names.append(cleanup_path(str(col.metadata.get("path"))))
+ logger.info("Retrieved the following collections: %s", names)
+ return names
async def vectorise_files(paths: list[str], project_root: str) -> dict[str, int]:
@@ -110,8 +106,53 @@ async def vectorise_files(paths: list[str], project_root: str) -> dict[str, int]
)
config = await get_project_config(project_root)
try:
- client = await get_client(config)
- collection = await get_collection(client, config, True)
+ async with ClientManager().get_client(config) as client:
+ collection = await get_collection(client, config, True)
+ if collection is None: # pragma: nocover
+ raise McpError(
+ ErrorData(
+ code=1,
+ message=f"Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.",
+ )
+ )
+ paths = [os.path.expanduser(i) for i in await expand_globs(paths)]
+ final_config = await config.merge_from(
+ Config(
+ files=[i for i in paths if os.path.isfile(i)],
+ project_root=project_root,
+ )
+ )
+ for ignore_spec in find_exclude_specs(final_config):
+ if os.path.isfile(ignore_spec):
+ logger.info(f"Loading ignore specs from {ignore_spec}.")
+ paths = exclude_paths_by_spec((str(i) for i in paths), ignore_spec)
+
+ stats = VectoriseStats()
+ collection_lock = asyncio.Lock()
+ stats_lock = asyncio.Lock()
+ max_batch_size = await client.get_max_batch_size()
+ semaphore = asyncio.Semaphore(os.cpu_count() or 1)
+ tasks = [
+ asyncio.create_task(
+ chunked_add(
+ str(file),
+ collection,
+ collection_lock,
+ stats,
+ stats_lock,
+ final_config,
+ max_batch_size,
+ semaphore,
+ )
+ )
+ for file in paths
+ ]
+ for i, task in enumerate(asyncio.as_completed(tasks), start=1):
+ await task
+
+ await remove_orphanes(collection, collection_lock, stats, stats_lock)
+
+ return stats.to_dict()
except Exception as e:
logger.error("Failed to access collection at %s", project_root)
raise McpError(
@@ -120,48 +161,6 @@ async def vectorise_files(paths: list[str], project_root: str) -> dict[str, int]
message=f"{e.__class__.__name__}: Failed to create the collection at {project_root}.",
)
)
- if collection is None: # pragma: nocover
- raise McpError(
- ErrorData(
- code=1,
- message=f"Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.",
- )
- )
-
- paths = [os.path.expanduser(i) for i in await expand_globs(paths)]
- final_config = await config.merge_from(
- Config(files=[i for i in paths if os.path.isfile(i)], project_root=project_root)
- )
- for ignore_spec in find_exclude_specs(final_config):
- if os.path.isfile(ignore_spec):
- logger.info(f"Loading ignore specs from {ignore_spec}.")
- paths = exclude_paths_by_spec((str(i) for i in paths), ignore_spec)
- stats = VectoriseStats()
- collection_lock = asyncio.Lock()
- stats_lock = asyncio.Lock()
- max_batch_size = await client.get_max_batch_size()
- semaphore = asyncio.Semaphore(os.cpu_count() or 1)
- tasks = [
- asyncio.create_task(
- chunked_add(
- str(file),
- collection,
- collection_lock,
- stats,
- stats_lock,
- final_config,
- max_batch_size,
- semaphore,
- )
- )
- for file in paths
- ]
- for i, task in enumerate(asyncio.as_completed(tasks), start=1):
- await task
-
- await remove_orphanes(collection, collection_lock, stats, stats_lock)
-
- return stats.to_dict()
async def query_tool(
@@ -184,78 +183,80 @@ async def query_tool(
message="Use `list_collections` tool to get a list of valid paths for this field.",
)
)
- else:
- config = await get_project_config(project_root)
- try:
- client = await get_client(config)
+ config = await get_project_config(project_root)
+ try:
+ async with ClientManager().get_client(config) as client:
collection = await get_collection(client, config, False)
- except Exception as e:
- logger.error("Failed to access collection at %s", project_root)
- raise McpError(
- ErrorData(
- code=1,
- message=f"{e.__class__.__name__}: Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.",
+
+ if collection is None: # pragma: nocover
+ raise McpError(
+ ErrorData(
+ code=1,
+ message=f"Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.",
+ )
)
+ query_config = await config.merge_from(
+ Config(n_result=n_query, query=query_messages)
+ )
+ logger.info("Built the final config: %s", query_config)
+ result_paths = await get_query_result_files(
+ collection=collection,
+ configs=query_config,
)
- if collection is None:
+ results: list[str] = []
+ for path in result_paths:
+ if os.path.isfile(path):
+ with open(path) as fin:
+ rel_path = os.path.relpath(path, config.project_root)
+ results.append(
+ f"{rel_path}\n{fin.read()}",
+ )
+ logger.info("Retrieved the following files: %s", result_paths)
+ return results
+
+ except Exception as e:
+ logger.error("Failed to access collection at %s", project_root)
raise McpError(
ErrorData(
code=1,
- message=f"Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.",
+ message=f"{e.__class__.__name__}: Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.",
)
)
- query_config = await config.merge_from(
- Config(n_result=n_query, query=query_messages)
- )
- logger.info("Built the final config: %s", query_config)
- result_paths = await get_query_result_files(
- collection=collection,
- configs=query_config,
- )
- results: list[str] = []
- for path in result_paths:
- if os.path.isfile(path):
- with open(path) as fin:
- rel_path = os.path.relpath(path, config.project_root)
- results.append(
- f"{rel_path}\n{fin.read()}",
- )
- logger.info("Retrieved the following files: %s", result_paths)
- return results
async def mcp_server():
- global default_config, default_client, default_collection
+ global default_config, default_project_root
local_config_dir = await find_project_config_dir(".")
+ default_instructions = "\n".join(
+ "\n".join(i) for i in prompt_by_categories.values()
+ )
if local_config_dir is not None:
logger.info("Found project config: %s", local_config_dir)
project_root = str(Path(local_config_dir).parent.resolve())
+ default_project_root = project_root
default_config = await get_project_config(project_root)
default_config.project_root = project_root
- default_client = await get_client(default_config)
- try:
- default_collection = await get_collection(default_client, default_config)
+ async with ClientManager().get_client(default_config) as client:
logger.info("Collection initialised for %s.", project_root)
- except InvalidCollectionException: # pragma: nocover
- default_collection = None
- default_instructions = "\n".join(
- "\n".join(i) for i in prompt_by_categories.values()
- )
- if default_client is None:
- if mcp_config.ls_on_start: # pragma: nocover
- logger.warning(
- "Failed to initialise a chromadb client. Ignoring --ls-on-start flag."
- )
- else:
- if mcp_config.ls_on_start:
- logger.info("Adding available collections to the server instructions.")
- default_instructions += "\nYou have access to the following collections:\n"
- for name in await list_collections():
- default_instructions += f"{name}"
+ if client is None:
+ if mcp_config.ls_on_start: # pragma: nocover
+ logger.warning(
+ "Failed to initialise a chromadb client. Ignoring --ls-on-start flag."
+ )
+ else:
+ if mcp_config.ls_on_start:
+ logger.info(
+ "Adding available collections to the server instructions."
+ )
+ default_instructions += (
+ "\nYou have access to the following collections:\n"
+ )
+ for name in await list_collections():
+ default_instructions += f"{name}"
mcp = FastMCP("VectorCode", instructions=default_instructions)
mcp.add_tool(
@@ -292,9 +293,12 @@ def parse_cli_args(args: Optional[list[str]] = None) -> MCPConfig:
async def run_server(): # pragma: nocover
- mcp = await mcp_server()
- await mcp.run_stdio_async()
- return 0
+ try:
+ mcp = await mcp_server()
+ await mcp.run_stdio_async()
+ finally:
+ await ClientManager().kill_servers()
+ return 0
def main(): # pragma: nocover
diff --git a/src/vectorcode/subcommands/clean.py b/src/vectorcode/subcommands/clean.py
index 4a58aeb9..bae7ed48 100644
--- a/src/vectorcode/subcommands/clean.py
+++ b/src/vectorcode/subcommands/clean.py
@@ -4,7 +4,7 @@
from chromadb.api import AsyncClientAPI
from vectorcode.cli_utils import Config
-from vectorcode.common import get_client, get_collections
+from vectorcode.common import ClientManager, get_collections
logger = logging.getLogger(name=__name__)
@@ -21,5 +21,6 @@ async def run_clean_on_client(client: AsyncClientAPI, pipe_mode: bool):
async def clean(configs: Config) -> int:
- await run_clean_on_client(await get_client(configs), configs.pipe)
- return 0
+ async with ClientManager().get_client(configs) as client:
+ await run_clean_on_client(client, configs.pipe)
+ return 0
diff --git a/src/vectorcode/subcommands/drop.py b/src/vectorcode/subcommands/drop.py
index 08fbbbae..155c303f 100644
--- a/src/vectorcode/subcommands/drop.py
+++ b/src/vectorcode/subcommands/drop.py
@@ -3,22 +3,22 @@
from chromadb.errors import InvalidCollectionException
from vectorcode.cli_utils import Config
-from vectorcode.common import get_client, get_collection
+from vectorcode.common import ClientManager, get_collection
logger = logging.getLogger(name=__name__)
async def drop(config: Config) -> int:
- client = await get_client(config)
- try:
- collection = await get_collection(client, config)
- collection_path = collection.metadata["path"]
- await client.delete_collection(collection.name)
- print(f"Collection for {collection_path} has been deleted.")
- logger.info(f"Deteted collection at {collection_path}.")
- return 0
- except (ValueError, InvalidCollectionException) as e:
- logger.error(
- f"{e.__class__.__name__}: There's no existing collection for {config.project_root}"
- )
- return 1
+ async with ClientManager().get_client(config) as client:
+ try:
+ collection = await get_collection(client, config)
+ collection_path = collection.metadata["path"]
+ await client.delete_collection(collection.name)
+ print(f"Collection for {collection_path} has been deleted.")
+ logger.info(f"Deleted collection at {collection_path}.")
+ return 0
+ except (ValueError, InvalidCollectionException) as e:
+ logger.error(
+ f"{e.__class__.__name__}: There's no existing collection for {config.project_root}"
+ )
+ return 1
diff --git a/src/vectorcode/subcommands/ls.py b/src/vectorcode/subcommands/ls.py
index 246eb85b..c78d82ac 100644
--- a/src/vectorcode/subcommands/ls.py
+++ b/src/vectorcode/subcommands/ls.py
@@ -8,7 +8,7 @@
from chromadb.api.types import IncludeEnum
from vectorcode.cli_utils import Config, cleanup_path
-from vectorcode.common import get_client, get_collections
+from vectorcode.common import ClientManager, get_collections
logger = logging.getLogger(name=__name__)
@@ -36,34 +36,34 @@ async def get_collection_list(client: AsyncClientAPI) -> list[dict]:
async def ls(configs: Config) -> int:
- client = await get_client(configs)
- result: list[dict] = await get_collection_list(client)
- logger.info(f"Found the following collections: {result}")
+ async with ClientManager().get_client(configs) as client:
+ result: list[dict] = await get_collection_list(client)
+ logger.info(f"Found the following collections: {result}")
- if configs.pipe:
- print(json.dumps(result))
- else:
- table = []
- for meta in result:
- project_root = meta["project-root"]
- if os.environ.get("HOME"):
- project_root = project_root.replace(os.environ["HOME"], "~")
- row = [
- project_root,
- meta["size"],
- meta["num_files"],
- meta["embedding_function"],
- ]
- table.append(row)
- print(
- tabulate.tabulate(
- table,
- headers=[
- "Project Root",
- "Collection Size",
- "Number of Files",
- "Embedding Function",
- ],
+ if configs.pipe:
+ print(json.dumps(result))
+ else:
+ table = []
+ for meta in result:
+ project_root = meta["project-root"]
+ if os.environ.get("HOME"):
+ project_root = project_root.replace(os.environ["HOME"], "~")
+ row = [
+ project_root,
+ meta["size"],
+ meta["num_files"],
+ meta["embedding_function"],
+ ]
+ table.append(row)
+ print(
+ tabulate.tabulate(
+ table,
+ headers=[
+ "Project Root",
+ "Collection Size",
+ "Number of Files",
+ "Embedding Function",
+ ],
+ )
)
- )
- return 0
+ return 0
diff --git a/src/vectorcode/subcommands/query/__init__.py b/src/vectorcode/subcommands/query/__init__.py
index 51c3a550..e1c0fc2f 100644
--- a/src/vectorcode/subcommands/query/__init__.py
+++ b/src/vectorcode/subcommands/query/__init__.py
@@ -17,7 +17,7 @@
expand_path,
)
from vectorcode.common import (
- get_client,
+ ClientManager,
get_collection,
verify_ef,
)
@@ -160,52 +160,52 @@ async def query(configs: Config) -> int:
"Having both chunk and document in the output is not supported!",
)
return 1
- client = await get_client(configs)
- try:
- collection = await get_collection(client, configs, False)
- if not verify_ef(collection, configs):
+ async with ClientManager().get_client(configs) as client:
+ try:
+ collection = await get_collection(client, configs, False)
+ if not verify_ef(collection, configs):
+ return 1
+ except (ValueError, InvalidCollectionException) as e:
+ logger.error(
+ f"{e.__class__.__name__}: There's no existing collection for {configs.project_root}",
+ )
+ return 1
+ except InvalidDimensionException as e:
+ logger.error(
+ f"{e.__class__.__name__}: The collection was embedded with a different embedding model.",
+ )
+ return 1
+ except IndexError as e: # pragma: nocover
+ logger.error(
+ f"{e.__class__.__name__}: Failed to get the collection. Please check your config."
+ )
return 1
- except (ValueError, InvalidCollectionException) as e:
- logger.error(
- f"{e.__class__.__name__}: There's no existing collection for {configs.project_root}",
- )
- return 1
- except InvalidDimensionException as e:
- logger.error(
- f"{e.__class__.__name__}: The collection was embedded with a different embedding model.",
- )
- return 1
- except IndexError as e: # pragma: nocover
- logger.error(
- f"{e.__class__.__name__}: Failed to get the collection. Please check your config."
- )
- return 1
- if not configs.pipe:
- print("Starting querying...")
+ if not configs.pipe:
+ print("Starting querying...")
- if QueryInclude.chunk in configs.include:
- if len((await collection.get(where={"start": {"$gte": 0}}))["ids"]) == 0:
- logger.warning(
- """
-This collection doesn't contain line range metadata. Falling back to `--include path document`.
-Please re-vectorise it to use `--include chunk`.""",
- )
- configs.include = [QueryInclude.path, QueryInclude.document]
+ if QueryInclude.chunk in configs.include:
+ if len((await collection.get(where={"start": {"$gte": 0}}))["ids"]) == 0:
+ logger.warning(
+ """
+ This collection doesn't contain line range metadata. Falling back to `--include path document`.
+ Please re-vectorise it to use `--include chunk`.""",
+ )
+ configs.include = [QueryInclude.path, QueryInclude.document]
- try:
- structured_result = await build_query_results(collection, configs)
- except RerankerError as e: # pragma: nocover
- # error logs should be handled where they're raised
- logger.error(f"{e.__class__.__name__}")
- return 1
+ try:
+ structured_result = await build_query_results(collection, configs)
+ except RerankerError as e: # pragma: nocover
+ # error logs should be handled where they're raised
+ logger.error(f"{e.__class__.__name__}")
+ return 1
- if configs.pipe:
- print(json.dumps(structured_result))
- else:
- for idx, result in enumerate(structured_result):
- for include_item in configs.include:
- print(f"{include_item.to_header()}{result.get(include_item.value)}")
- if idx != len(structured_result) - 1:
- print()
- return 0
+ if configs.pipe:
+ print(json.dumps(structured_result))
+ else:
+ for idx, result in enumerate(structured_result):
+ for include_item in configs.include:
+ print(f"{include_item.to_header()}{result.get(include_item.value)}")
+ if idx != len(structured_result) - 1:
+ print()
+ return 0
diff --git a/src/vectorcode/subcommands/update.py b/src/vectorcode/subcommands/update.py
index 2d7d4322..1416a7b8 100644
--- a/src/vectorcode/subcommands/update.py
+++ b/src/vectorcode/subcommands/update.py
@@ -9,78 +9,85 @@
from chromadb.errors import InvalidCollectionException
from vectorcode.cli_utils import Config
-from vectorcode.common import get_client, get_collection, verify_ef
+from vectorcode.common import ClientManager, get_collection, verify_ef
from vectorcode.subcommands.vectorise import VectoriseStats, chunked_add, show_stats
logger = logging.getLogger(name=__name__)
async def update(configs: Config) -> int:
- client = await get_client(configs)
- try:
- collection = await get_collection(client, configs, False)
- except IndexError as e:
- print(
- f"{e.__class__.__name__}: Failed to get/create the collection. Please check your config."
- )
- return 1
- except (ValueError, InvalidCollectionException) as e:
- print(
- f"{e.__class__.__name__}: There's no existing collection for {configs.project_root}",
- file=sys.stderr,
- )
- return 1
- if collection is None or not verify_ef(collection, configs):
- return 1
+ async with ClientManager().get_client(configs) as client:
+ try:
+ collection = await get_collection(client, configs, False)
+ except IndexError as e:
+ print(
+ f"{e.__class__.__name__}: Failed to get/create the collection. Please check your config."
+ )
+ return 1
+ except (ValueError, InvalidCollectionException) as e:
+ print(
+ f"{e.__class__.__name__}: There's no existing collection for {configs.project_root}",
+ file=sys.stderr,
+ )
+ return 1
+ if collection is None: # pragma: nocover
+ logger.error(
+ f"Failed to find a collection at {configs.project_root} from {configs.db_url}"
+ )
+ return 1
+ if not verify_ef(collection, configs): # pragma: nocover
+ return 1
- metas = (await collection.get(include=[IncludeEnum.metadatas]))["metadatas"]
- if metas is None:
- return 0
- files_gen = (str(meta.get("path", "")) for meta in metas)
- files = set()
- orphanes = set()
- for file in files_gen:
- if os.path.isfile(file):
- files.add(file)
- else:
- orphanes.add(file)
+ metas = (await collection.get(include=[IncludeEnum.metadatas]))["metadatas"]
+ if metas is None or len(metas) == 0: # pragma: nocover
+ logger.debug("Empty collection.")
+ return 0
- stats = VectoriseStats(removed=len(orphanes))
- collection_lock = Lock()
- stats_lock = Lock()
- max_batch_size = await client.get_max_batch_size()
- semaphore = asyncio.Semaphore(os.cpu_count() or 1)
+ files_gen = (str(meta.get("path", "")) for meta in metas)
+ files = set()
+ orphanes = set()
+ for file in files_gen:
+ if os.path.isfile(file):
+ files.add(file)
+ else:
+ orphanes.add(file)
- with tqdm.tqdm(
- total=len(files), desc="Vectorising files...", disable=configs.pipe
- ) as bar:
- logger.info(f"Updating embeddings for {len(files)} file(s).")
- try:
- tasks = [
- asyncio.create_task(
- chunked_add(
- str(file),
- collection,
- collection_lock,
- stats,
- stats_lock,
- configs,
- max_batch_size,
- semaphore,
+ stats = VectoriseStats(removed=len(orphanes))
+ collection_lock = Lock()
+ stats_lock = Lock()
+ max_batch_size = await client.get_max_batch_size()
+ semaphore = asyncio.Semaphore(os.cpu_count() or 1)
+
+ with tqdm.tqdm(
+ total=len(files), desc="Vectorising files...", disable=configs.pipe
+ ) as bar:
+ logger.info(f"Updating embeddings for {len(files)} file(s).")
+ try:
+ tasks = [
+ asyncio.create_task(
+ chunked_add(
+ str(file),
+ collection,
+ collection_lock,
+ stats,
+ stats_lock,
+ configs,
+ max_batch_size,
+ semaphore,
+ )
)
- )
- for file in files
- ]
- for task in asyncio.as_completed(tasks):
- await task
- bar.update(1)
- except asyncio.CancelledError: # pragma: nocover
- print("Abort.", file=sys.stderr)
- return 1
+ for file in files
+ ]
+ for task in asyncio.as_completed(tasks):
+ await task
+ bar.update(1)
+ except asyncio.CancelledError: # pragma: nocover
+ print("Abort.", file=sys.stderr)
+ return 1
- if len(orphanes):
- logger.info(f"Removing {len(orphanes)} orphaned files from database.")
- await collection.delete(where={"path": {"$in": list(orphanes)}})
+ if len(orphanes):
+ logger.info(f"Removing {len(orphanes)} orphaned files from database.")
+ await collection.delete(where={"path": {"$in": list(orphanes)}})
- show_stats(configs, stats)
- return 0
+ show_stats(configs, stats)
+ return 0
diff --git a/src/vectorcode/subcommands/vectorise.py b/src/vectorcode/subcommands/vectorise.py
index 40ef1619..a0bea88f 100644
--- a/src/vectorcode/subcommands/vectorise.py
+++ b/src/vectorcode/subcommands/vectorise.py
@@ -24,7 +24,7 @@
expand_path,
)
from vectorcode.common import (
- get_client,
+ ClientManager,
get_collection,
list_collection_files,
verify_ef,
@@ -251,64 +251,64 @@ def find_exclude_specs(configs: Config) -> list[str]:
async def vectorise(configs: Config) -> int:
assert configs.project_root is not None
- client = await get_client(configs)
- try:
- collection = await get_collection(client, configs, True)
- except IndexError as e:
- print(
- f"{e.__class__.__name__}: Failed to get/create the collection. Please check your config."
- )
- return 1
- if not verify_ef(collection, configs):
- return 1
-
- files = await expand_globs(
- configs.files or load_files_from_include(str(configs.project_root)),
- recursive=configs.recursive,
- include_hidden=configs.include_hidden,
- )
-
- if not configs.force:
- for spec_path in find_exclude_specs(configs):
- if os.path.isfile(spec_path):
- logger.info(f"Loading ignore specs from {spec_path}.")
- files = exclude_paths_by_spec((str(i) for i in files), spec_path)
- else: # pragma: nocover
- logger.info("Ignoring exclude specs.")
-
- stats = VectoriseStats()
- collection_lock = Lock()
- stats_lock = Lock()
- max_batch_size = await client.get_max_batch_size()
- semaphore = asyncio.Semaphore(os.cpu_count() or 1)
-
- with tqdm.tqdm(
- total=len(files), desc="Vectorising files...", disable=configs.pipe
- ) as bar:
+ async with ClientManager().get_client(configs) as client:
try:
- tasks = [
- asyncio.create_task(
- chunked_add(
- str(file),
- collection,
- collection_lock,
- stats,
- stats_lock,
- configs,
- max_batch_size,
- semaphore,
- )
- )
- for file in files
- ]
- for task in asyncio.as_completed(tasks):
- await task
- bar.update(1)
- except asyncio.CancelledError:
- print("Abort.", file=sys.stderr)
+ collection = await get_collection(client, configs, True)
+ except IndexError as e:
+ print(
+ f"{e.__class__.__name__}: Failed to get/create the collection. Please check your config."
+ )
+ return 1
+ if not verify_ef(collection, configs):
return 1
- await remove_orphanes(collection, collection_lock, stats, stats_lock)
+ files = await expand_globs(
+ configs.files or load_files_from_include(str(configs.project_root)),
+ recursive=configs.recursive,
+ include_hidden=configs.include_hidden,
+ )
- show_stats(configs=configs, stats=stats)
- return 0
+ if not configs.force:
+ for spec_path in find_exclude_specs(configs):
+ if os.path.isfile(spec_path):
+ logger.info(f"Loading ignore specs from {spec_path}.")
+ files = exclude_paths_by_spec((str(i) for i in files), spec_path)
+ else: # pragma: nocover
+ logger.info("Ignoring exclude specs.")
+
+ stats = VectoriseStats()
+ collection_lock = Lock()
+ stats_lock = Lock()
+ max_batch_size = await client.get_max_batch_size()
+ semaphore = asyncio.Semaphore(os.cpu_count() or 1)
+
+ with tqdm.tqdm(
+ total=len(files), desc="Vectorising files...", disable=configs.pipe
+ ) as bar:
+ try:
+ tasks = [
+ asyncio.create_task(
+ chunked_add(
+ str(file),
+ collection,
+ collection_lock,
+ stats,
+ stats_lock,
+ configs,
+ max_batch_size,
+ semaphore,
+ )
+ )
+ for file in files
+ ]
+ for task in asyncio.as_completed(tasks):
+ await task
+ bar.update(1)
+ except asyncio.CancelledError:
+ print("Abort.", file=sys.stderr)
+ return 1
+
+ await remove_orphanes(collection, collection_lock, stats, stats_lock)
+
+ show_stats(configs=configs, stats=stats)
+ return 0
diff --git a/tests/subcommands/query/test_query.py b/tests/subcommands/query/test_query.py
index 4a54de9d..9f8e4078 100644
--- a/tests/subcommands/query/test_query.py
+++ b/tests/subcommands/query/test_query.py
@@ -355,7 +355,7 @@ async def test_query_success(mock_config):
mock_collection = AsyncMock()
with (
- patch("vectorcode.subcommands.query.get_client", return_value=mock_client),
+ patch("vectorcode.subcommands.query.ClientManager") as MockClientManager,
patch(
"vectorcode.subcommands.query.get_collection", return_value=mock_collection
),
@@ -367,6 +367,7 @@ async def test_query_success(mock_config):
patch("os.path.relpath", return_value="rel/path.py"),
patch("os.path.abspath", return_value="/abs/path.py"),
):
+ MockClientManager.return_value._create_client.return_value = mock_client
# Set up the mock file paths and contents
mock_get_files.return_value = ["file1.py", "file2.py"]
mock_file_handle = MagicMock()
@@ -396,7 +397,7 @@ async def test_query_pipe_mode(mock_config):
mock_collection = AsyncMock()
with (
- patch("vectorcode.subcommands.query.get_client", return_value=mock_client),
+ patch("vectorcode.subcommands.query.ClientManager") as MockClientManager,
patch(
"vectorcode.subcommands.query.get_collection", return_value=mock_collection
),
@@ -408,6 +409,7 @@ async def test_query_pipe_mode(mock_config):
patch("os.path.relpath", return_value="rel/path.py"),
patch("os.path.abspath", return_value="/abs/path.py"),
):
+ MockClientManager.return_value._create_client.return_value = mock_client
# Set up the mock file paths and contents
mock_get_files.return_value = ["file1.py", "file2.py"]
mock_file_handle = MagicMock()
@@ -434,7 +436,7 @@ async def test_query_absolute_path(mock_config):
mock_collection = AsyncMock()
with (
- patch("vectorcode.subcommands.query.get_client", return_value=mock_client),
+ patch("vectorcode.subcommands.query.ClientManager") as MockClientManager,
patch(
"vectorcode.subcommands.query.get_collection", return_value=mock_collection
),
@@ -445,6 +447,7 @@ async def test_query_absolute_path(mock_config):
patch("os.path.relpath", return_value="rel/path.py"),
patch("os.path.abspath", return_value="/abs/path.py"),
):
+ MockClientManager.return_value._create_client.return_value = mock_client
# Set up the mock file paths and contents
mock_get_files.return_value = ["file1.py"]
mock_file_handle = MagicMock()
@@ -463,7 +466,7 @@ async def test_query_collection_not_found():
config = Config(project_root="/test/project")
with (
- patch("vectorcode.subcommands.query.get_client"),
+ patch("vectorcode.subcommands.query.ClientManager"),
patch("vectorcode.subcommands.query.get_collection") as mock_get_collection,
patch("sys.stderr"),
):
@@ -482,7 +485,7 @@ async def test_query_invalid_collection():
config = Config(project_root="/test/project")
with (
- patch("vectorcode.subcommands.query.get_client"),
+ patch("vectorcode.subcommands.query.ClientManager"),
patch("vectorcode.subcommands.query.get_collection") as mock_get_collection,
patch("sys.stderr"),
):
@@ -503,7 +506,7 @@ async def test_query_invalid_dimension():
config = Config(project_root="/test/project")
with (
- patch("vectorcode.subcommands.query.get_client"),
+ patch("vectorcode.subcommands.query.ClientManager"),
patch("vectorcode.subcommands.query.get_collection") as mock_get_collection,
patch("sys.stderr"),
):
@@ -524,7 +527,7 @@ async def test_query_invalid_file(mock_config):
mock_collection = AsyncMock()
with (
- patch("vectorcode.subcommands.query.get_client", return_value=mock_client),
+ patch("vectorcode.subcommands.query.ClientManager") as MockClientManager,
patch(
"vectorcode.subcommands.query.get_collection", return_value=mock_collection
),
@@ -532,6 +535,7 @@ async def test_query_invalid_file(mock_config):
patch("vectorcode.subcommands.query.get_query_result_files") as mock_get_files,
patch("os.path.isfile", return_value=False),
):
+ MockClientManager.return_value._create_client.return_value = mock_client
# Set up the mock file paths
mock_get_files.return_value = ["invalid_file.py"]
@@ -549,12 +553,13 @@ async def test_query_invalid_ef(mock_config):
mock_collection = AsyncMock()
with (
- patch("vectorcode.subcommands.query.get_client", return_value=mock_client),
+ patch("vectorcode.subcommands.query.ClientManager") as MockClientManager,
patch(
"vectorcode.subcommands.query.get_collection", return_value=mock_collection
),
patch("vectorcode.subcommands.query.verify_ef", return_value=False),
):
+ MockClientManager.return_value._create_client.return_value = mock_client
# Call the function
result = await query(mock_config)
@@ -580,13 +585,14 @@ async def test_query_chunk_mode_no_metadata_fallback(mock_config):
mock_collection.get.return_value = {"ids": []}
with (
- patch("vectorcode.subcommands.query.get_client", return_value=mock_client),
+ patch("vectorcode.subcommands.query.ClientManager") as MockClientManager,
patch(
"vectorcode.subcommands.query.get_collection", return_value=mock_collection
),
patch("vectorcode.subcommands.query.verify_ef", return_value=True),
patch("vectorcode.subcommands.query.build_query_results") as mock_build_results,
):
+ MockClientManager.return_value._create_client.return_value = mock_client
mock_build_results.return_value = [] # Return empty results for simplicity
result = await query(mock_config)
diff --git a/tests/subcommands/test_clean.py b/tests/subcommands/test_clean.py
index 1fc345fd..8c79fd7f 100644
--- a/tests/subcommands/test_clean.py
+++ b/tests/subcommands/test_clean.py
@@ -73,10 +73,10 @@ async def mock_get_collections(client):
@pytest.mark.asyncio
async def test_clean():
- mock_client = AsyncMock(spec=AsyncClientAPI)
+ AsyncMock(spec=AsyncClientAPI)
mock_config = Config(pipe=False)
- with patch("vectorcode.subcommands.clean.get_client", return_value=mock_client):
+ with patch("vectorcode.subcommands.clean.ClientManager"):
result = await clean(mock_config)
assert result == 0
diff --git a/tests/subcommands/test_drop.py b/tests/subcommands/test_drop.py
index dcf4b1f0..15b990d8 100644
--- a/tests/subcommands/test_drop.py
+++ b/tests/subcommands/test_drop.py
@@ -1,3 +1,4 @@
+from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, patch
import pytest
@@ -31,19 +32,31 @@ def mock_collection():
async def test_drop_success(mock_config, mock_client, mock_collection):
mock_client.get_collection.return_value = mock_collection
mock_client.delete_collection = AsyncMock()
- with patch("vectorcode.subcommands.drop.get_client", return_value=mock_client):
- with patch(
+ with (
+ patch("vectorcode.subcommands.drop.ClientManager") as MockClientManager,
+ patch(
"vectorcode.subcommands.drop.get_collection", return_value=mock_collection
- ):
- result = await drop(mock_config)
- assert result == 0
- mock_client.delete_collection.assert_called_once_with(mock_collection.name)
+ ),
+ ):
+ mock_client = AsyncMock()
+
+ @asynccontextmanager
+ async def _get_client(self, config=None, need_lock=True):
+ yield mock_client
+
+ mock_client_manager = MockClientManager.return_value
+ mock_client_manager._create_client = AsyncMock(return_value=mock_client)
+ mock_client_manager.get_client = _get_client
+
+ result = await drop(mock_config)
+ assert result == 0
+ mock_client.delete_collection.assert_called_once_with(mock_collection.name)
@pytest.mark.asyncio
async def test_drop_collection_not_found(mock_config, mock_client):
mock_client.get_collection.side_effect = ValueError("Collection not found")
- with patch("vectorcode.subcommands.drop.get_client", return_value=mock_client):
+ with patch("vectorcode.subcommands.drop.ClientManager"):
with patch(
"vectorcode.subcommands.drop.get_collection",
side_effect=ValueError("Collection not found"),
diff --git a/tests/subcommands/test_ls.py b/tests/subcommands/test_ls.py
index 36f67469..bbc674eb 100644
--- a/tests/subcommands/test_ls.py
+++ b/tests/subcommands/test_ls.py
@@ -1,6 +1,6 @@
import json
import socket
-from unittest.mock import AsyncMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import tabulate
@@ -77,7 +77,7 @@ async def mock_get_collections(client):
yield mock_collection
with (
- patch("vectorcode.subcommands.ls.get_client", return_value=mock_client),
+ patch("vectorcode.subcommands.ls.ClientManager") as MockClientManager,
patch(
"vectorcode.subcommands.ls.get_collection_list",
return_value=[
@@ -90,6 +90,10 @@ async def mock_get_collections(client):
],
),
):
+ mock_client = MagicMock()
+ mock_client_manager = MockClientManager.return_value
+ mock_client_manager._create_client = AsyncMock(return_value=mock_client)
+
config = Config(pipe=True)
await ls(config)
captured = capsys.readouterr()
@@ -126,7 +130,7 @@ async def mock_get_collections(client):
yield mock_collection
with (
- patch("vectorcode.subcommands.ls.get_client", return_value=mock_client),
+ patch("vectorcode.subcommands.ls.ClientManager") as MockClientManager,
patch(
"vectorcode.subcommands.ls.get_collection_list",
return_value=[
@@ -139,6 +143,10 @@ async def mock_get_collections(client):
],
),
):
+ mock_client = MagicMock()
+ mock_client_manager = MockClientManager.return_value
+ mock_client_manager._create_client = AsyncMock(return_value=mock_client)
+
config = Config(pipe=False)
await ls(config)
captured = capsys.readouterr()
@@ -159,7 +167,7 @@ async def mock_get_collections(client):
# Test with HOME environment variable set
monkeypatch.setenv("HOME", "/test")
with (
- patch("vectorcode.subcommands.ls.get_client", return_value=mock_client),
+ patch("vectorcode.subcommands.ls.ClientManager") as MockClientManager,
patch(
"vectorcode.subcommands.ls.get_collection_list",
return_value=[
@@ -172,6 +180,9 @@ async def mock_get_collections(client):
],
),
):
+ mock_client = MagicMock()
+ mock_client_manager = MockClientManager.return_value
+ mock_client_manager._create_client = AsyncMock(return_value=mock_client)
config = Config(pipe=False)
await ls(config)
captured = capsys.readouterr()
diff --git a/tests/subcommands/test_update.py b/tests/subcommands/test_update.py
index febfa405..314f7c2a 100644
--- a/tests/subcommands/test_update.py
+++ b/tests/subcommands/test_update.py
@@ -19,7 +19,7 @@ async def test_update_success():
mock_client.get_max_batch_size.return_value = 100
with (
- patch("vectorcode.subcommands.update.get_client", return_value=mock_client),
+ patch("vectorcode.subcommands.update.ClientManager"),
patch(
"vectorcode.subcommands.update.get_collection", return_value=mock_collection
),
@@ -50,7 +50,7 @@ async def test_update_with_orphans():
mock_client.get_max_batch_size.return_value = 100
with (
- patch("vectorcode.subcommands.update.get_client", return_value=mock_client),
+ patch("vectorcode.subcommands.update.ClientManager"),
patch(
"vectorcode.subcommands.update.get_collection", return_value=mock_collection
),
@@ -78,10 +78,11 @@ async def test_update_index_error():
# mock_collection = AsyncMock()
with (
- patch("vectorcode.subcommands.update.get_client", return_value=mock_client),
+ patch("vectorcode.subcommands.update.ClientManager") as MockClientManager,
patch("vectorcode.subcommands.update.get_collection", side_effect=IndexError),
patch("sys.stderr"),
):
+ MockClientManager.return_value._create_client.return_value = mock_client
config = Config(project_root="/test/project", pipe=False)
result = await update(config)
@@ -94,10 +95,11 @@ async def test_update_value_error():
# mock_collection = AsyncMock()
with (
- patch("vectorcode.subcommands.update.get_client", return_value=mock_client),
+ patch("vectorcode.subcommands.update.ClientManager") as MockClientManager,
patch("vectorcode.subcommands.update.get_collection", side_effect=ValueError),
patch("sys.stderr"),
):
+ MockClientManager.return_value._create_client.return_value = mock_client
config = Config(project_root="/test/project", pipe=False)
result = await update(config)
@@ -110,13 +112,14 @@ async def test_update_invalid_collection_exception():
# mock_collection = AsyncMock()
with (
- patch("vectorcode.subcommands.update.get_client", return_value=mock_client),
+ patch("vectorcode.subcommands.update.ClientManager") as MockClientManager,
patch(
"vectorcode.subcommands.update.get_collection",
side_effect=InvalidCollectionException,
),
patch("sys.stderr"),
):
+ MockClientManager.return_value._create_client.return_value = mock_client
config = Config(project_root="/test/project", pipe=False)
result = await update(config)
diff --git a/tests/subcommands/test_vectorise.py b/tests/subcommands/test_vectorise.py
index 2f363a8b..6b2287bf 100644
--- a/tests/subcommands/test_vectorise.py
+++ b/tests/subcommands/test_vectorise.py
@@ -370,9 +370,7 @@ async def test_vectorise(capsys):
with ExitStack() as stack:
stack.enter_context(
- patch(
- "vectorcode.subcommands.vectorise.get_client", return_value=mock_client
- )
+ patch("vectorcode.subcommands.vectorise.ClientManager"),
)
stack.enter_context(patch("os.path.isfile", return_value=False))
stack.enter_context(
@@ -427,7 +425,7 @@ async def mock_chunked_add(*args, **kwargs):
"vectorcode.subcommands.vectorise.chunked_add", side_effect=mock_chunked_add
) as mock_add,
patch("sys.stderr") as mock_stderr,
- patch("vectorcode.subcommands.vectorise.get_client", return_value=mock_client),
+ patch("vectorcode.subcommands.vectorise.ClientManager") as MockClientManager,
patch(
"vectorcode.subcommands.vectorise.get_collection",
return_value=mock_collection,
@@ -438,6 +436,7 @@ async def mock_chunked_add(*args, **kwargs):
lambda x: not (x.endswith("gitignore") or x.endswith("vectorcode.exclude")),
),
):
+ MockClientManager.return_value._create_client.return_value = mock_client
result = await vectorise(configs)
assert result == 1
mock_add.assert_called_once()
@@ -458,7 +457,7 @@ async def test_vectorise_orphaned_files():
pipe=False,
)
- mock_client = AsyncMock()
+ AsyncMock()
mock_collection = AsyncMock()
# Define a mock response for collection.get in vectorise
@@ -494,7 +493,7 @@ def is_file_side_effect(path):
"vectorcode.subcommands.vectorise.TreeSitterChunker",
return_value=mock_chunker,
),
- patch("vectorcode.subcommands.vectorise.get_client", return_value=mock_client),
+ patch("vectorcode.subcommands.vectorise.ClientManager"),
patch(
"vectorcode.subcommands.vectorise.get_collection",
return_value=mock_collection,
@@ -532,10 +531,11 @@ async def test_vectorise_collection_index_error():
mock_client = AsyncMock()
with (
- patch("vectorcode.subcommands.vectorise.get_client", return_value=mock_client),
+ patch("vectorcode.subcommands.vectorise.ClientManager") as MockClientManager,
patch("vectorcode.subcommands.vectorise.get_collection") as mock_get_collection,
patch("os.path.isfile", return_value=False),
):
+ MockClientManager.return_value._create_client.return_value = mock_client
mock_get_collection.side_effect = IndexError("Collection not found")
result = await vectorise(configs)
assert result == 1
@@ -558,7 +558,7 @@ async def test_vectorise_verify_ef_false():
mock_collection = AsyncMock()
with (
- patch("vectorcode.subcommands.vectorise.get_client", return_value=mock_client),
+ patch("vectorcode.subcommands.vectorise.ClientManager") as MockClientManager,
patch(
"vectorcode.subcommands.vectorise.get_collection",
return_value=mock_collection,
@@ -566,6 +566,7 @@ async def test_vectorise_verify_ef_false():
patch("vectorcode.subcommands.vectorise.verify_ef", return_value=False),
patch("os.path.isfile", return_value=False),
):
+ MockClientManager.return_value._create_client.return_value = mock_client
result = await vectorise(configs)
assert result == 1
@@ -588,7 +589,7 @@ async def test_vectorise_gitignore():
mock_collection.get.return_value = {"metadatas": []}
with (
- patch("vectorcode.subcommands.vectorise.get_client", return_value=mock_client),
+ patch("vectorcode.subcommands.vectorise.ClientManager") as MockClientManager,
patch(
"vectorcode.subcommands.vectorise.get_collection",
return_value=mock_collection,
@@ -608,6 +609,7 @@ async def test_vectorise_gitignore():
"vectorcode.subcommands.vectorise.exclude_paths_by_spec"
) as mock_exclude_paths,
):
+ MockClientManager.return_value._create_client.return_value = mock_client
await vectorise(configs)
mock_exclude_paths.assert_called_once()
@@ -635,7 +637,7 @@ async def test_vectorise_exclude_file(tmpdir):
mock_collection.get.return_value = {"ids": []}
with (
- patch("vectorcode.subcommands.vectorise.get_client", return_value=mock_client),
+ patch("vectorcode.subcommands.vectorise.ClientManager") as MockClientManager,
patch(
"vectorcode.subcommands.vectorise.get_collection",
return_value=mock_collection,
@@ -652,6 +654,7 @@ async def test_vectorise_exclude_file(tmpdir):
),
patch("vectorcode.subcommands.vectorise.chunked_add") as mock_chunked_add,
):
+ MockClientManager.return_value._create_client.return_value = mock_client
await vectorise(configs)
# Assert that chunked_add is only called for test_file.py, not excluded_file.py
call_args = [call[0][0] for call in mock_chunked_add.call_args_list]
@@ -664,7 +667,6 @@ async def test_vectorise_exclude_file(tmpdir):
@pytest.mark.asyncio
-@patch("vectorcode.subcommands.vectorise.get_client", new_callable=AsyncMock)
@patch("vectorcode.subcommands.vectorise.get_collection", new_callable=AsyncMock)
@patch("vectorcode.subcommands.vectorise.expand_globs", new_callable=AsyncMock)
@patch("vectorcode.subcommands.vectorise.chunked_add", new_callable=AsyncMock)
@@ -681,7 +683,6 @@ async def test_vectorise_uses_global_exclude_when_local_missing(
mock_chunked_add,
mock_expand_globs,
mock_get_collection,
- mock_get_client,
tmp_path,
):
"""
@@ -712,14 +713,20 @@ def isfile_side_effect(p):
global_exclude_content = "*.bin"
m_open = mock_open(read_data=global_exclude_content)
- with patch("builtins.open", m_open):
+ with (
+ patch("builtins.open", m_open),
+ patch("vectorcode.subcommands.vectorise.ClientManager") as MockClientManager,
+ ):
mock_spec_instance = MagicMock()
mock_spec_instance.match_file = lambda path: str(path).endswith(".bin")
mock_gitignore_spec.from_lines.return_value = mock_spec_instance
mock_client_instance = AsyncMock()
mock_client_instance.get_max_batch_size = AsyncMock(return_value=100)
- mock_get_client.return_value = mock_client_instance
+
+ MockClientManager.return_value._create_client.return_value = (
+ mock_client_instance
+ )
mock_collection_instance = AsyncMock()
mock_collection_instance.get = AsyncMock(
diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py
index 3252683f..655ef19f 100644
--- a/tests/test_cli_utils.py
+++ b/tests/test_cli_utils.py
@@ -11,6 +11,7 @@
from vectorcode.cli_utils import (
CliAction,
Config,
+ LockManager,
PromptCategory,
QueryInclude,
cleanup_path,
@@ -553,3 +554,11 @@ def test_shtab():
.stderr.read()
.decode()
) == ""
+
+
+@pytest.mark.asyncio
+async def test_filelock():
+ manager = LockManager()
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ manager.get_lock(tmp_dir)
+ assert os.path.isfile(os.path.join(tmp_dir, "vectorcode.lock"))
diff --git a/tests/test_common.py b/tests/test_common.py
index 98f1370b..40b51fd6 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -3,7 +3,7 @@
import subprocess
import sys
import tempfile
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
@@ -13,7 +13,7 @@
from vectorcode.cli_utils import Config
from vectorcode.common import (
- get_client,
+ ClientManager,
get_collection,
get_collection_name,
get_collections,
@@ -150,82 +150,6 @@ async def test_try_server_versions():
assert await try_server("http://localhost:8300") is False
-@pytest.mark.asyncio
-async def test_get_client():
- # Patch chromadb.AsyncHttpClient to avoid actual network calls
- with patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient:
- mock_client = MagicMock(spec=AsyncClientAPI)
- MockAsyncHttpClient.return_value = mock_client
-
- config = Config(db_url="https://test_host:1234", db_path="test_db")
- client = await get_client(config)
-
- assert isinstance(client, AsyncClientAPI)
- MockAsyncHttpClient.assert_called_once()
- assert (
- MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_host
- == "test_host"
- )
- assert (
- MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_http_port
- == 1234
- )
- assert (
- MockAsyncHttpClient.call_args.kwargs["settings"].anonymized_telemetry
- is False
- )
- assert (
- MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_ssl_enabled
- is True
- )
-
- # Test with valid db_settings (only anonymized_telemetry)
- config = Config(
- db_url="http://test_host1:1234",
- db_path="test_db",
- db_settings={"anonymized_telemetry": True},
- )
- client = await get_client(config)
-
- assert isinstance(client, AsyncClientAPI)
- MockAsyncHttpClient.assert_called()
- assert (
- MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_host
- == "test_host1"
- )
- assert (
- MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_http_port
- == 1234
- )
- assert (
- MockAsyncHttpClient.call_args.kwargs["settings"].anonymized_telemetry
- is True
- )
-
- # Test with multiple db_settings, including an invalid one. The invalid one
- # should be filtered out inside get_client.
- config = Config(
- db_url="http://test_host2:1234",
- db_path="test_db",
- db_settings={"anonymized_telemetry": True, "other_setting": "value"},
- )
- client = await get_client(config)
- assert isinstance(client, AsyncClientAPI)
- MockAsyncHttpClient.assert_called()
- assert (
- MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_host
- == "test_host2"
- )
- assert (
- MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_http_port
- == 1234
- )
- assert (
- MockAsyncHttpClient.call_args.kwargs["settings"].anonymized_telemetry
- is True
- )
-
-
def test_verify_ef():
# Mocking AsyncCollection and Config
mock_collection = MagicMock()
@@ -581,3 +505,139 @@ async def test_wait_for_server_timeout():
# Verify try_server was called multiple times (due to retries)
assert mock_try_server.call_count > 1
+
+
+@pytest.mark.asyncio
+async def test_client_manager_get_client():
+ config = Config(
+ db_url="https://test_host:1234", db_path="test_db", project_root="test_proj"
+ )
+ config1 = Config(
+ db_url="http://test_host1:1234",
+ db_path="test_db",
+ project_root="test_proj1",
+ db_settings={"anonymized_telemetry": True},
+ )
+ config1_alt = Config(
+ db_url="http://test_host1:1234",
+ db_path="test_db",
+ project_root="test_proj1",
+ db_settings={"anonymized_telemetry": True, "other_setting": "value"},
+ )
+ # Patch chromadb.AsyncHttpClient to avoid actual network calls
+ with (
+ patch("chromadb.AsyncHttpClient") as MockAsyncHttpClient,
+ patch("vectorcode.common.try_server", return_value=True),
+ ):
+ mock_client = MagicMock(spec=AsyncClientAPI)
+ MockAsyncHttpClient.return_value = mock_client
+
+ async with (
+ ClientManager().get_client(config) as client,
+ ):
+ assert isinstance(client, AsyncClientAPI)
+ MockAsyncHttpClient.assert_called()
+ assert (
+ MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_host
+ == "test_host"
+ )
+ assert (
+ MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_http_port
+ == 1234
+ )
+ assert (
+ MockAsyncHttpClient.call_args.kwargs["settings"].anonymized_telemetry
+ is False
+ )
+ assert (
+ MockAsyncHttpClient.call_args.kwargs[
+ "settings"
+ ].chroma_server_ssl_enabled
+ is True
+ )
+
+ async with (
+ ClientManager().get_client(config1) as client1,
+ ClientManager().get_client(config1_alt) as client1_alt,
+ ):
+ assert isinstance(client1, AsyncClientAPI)
+ MockAsyncHttpClient.assert_called()
+ assert (
+ MockAsyncHttpClient.call_args.kwargs["settings"].chroma_server_host
+ == "test_host1"
+ )
+ assert (
+ MockAsyncHttpClient.call_args.kwargs[
+ "settings"
+ ].chroma_server_http_port
+ == 1234
+ )
+ assert (
+ MockAsyncHttpClient.call_args.kwargs[
+ "settings"
+ ].anonymized_telemetry
+ is True
+ )
+
+ # config1_alt adds an invalid db_setting on top of config1. The invalid one should be
+ # filtered out inside get_client; both contexts are expected to yield the same client object.
+ assert id(client1_alt) == id(client1)
+
+
+@pytest.mark.asyncio
+async def test_client_manager_list_server_processes():
+ async def _try_server(url):
+ return "127.0.0.1" in url or "localhost" in url
+
+ async def _start_server(cfg):
+ return AsyncMock()
+
+ with (
+ tempfile.TemporaryDirectory() as temp_dir,
+ patch("vectorcode.common.start_server", side_effect=_start_server),
+ patch("vectorcode.common.try_server", side_effect=_try_server),
+ ):
+ db_path = os.path.join(temp_dir, "db")
+ os.makedirs(db_path, exist_ok=True)
+
+ ClientManager._create_client = AsyncMock()
+ async with ClientManager().get_client(
+ Config(
+ db_url="http://test_host:8001",
+ project_root="proj1",
+ db_path=db_path,
+ )
+ ):
+ print(ClientManager().get_processes())
+ async with ClientManager().get_client(
+ Config(
+ db_url="http://test_host:8002",
+ project_root="proj2",
+ db_path=db_path,
+ )
+ ):
+ pass
+ assert len(ClientManager().get_processes()) == 2
+
+
+@pytest.mark.asyncio
+async def test_client_manager_kill_servers():
+ manager = ClientManager()
+ manager.clear()
+
+ async def _try_server(url):
+ return "127.0.0.1" in url or "localhost" in url
+
+ mock_process = AsyncMock()
+ mock_process.terminate = MagicMock()
+ with (
+ patch("vectorcode.common.start_server", return_value=mock_process),
+ patch("vectorcode.common.try_server", side_effect=_try_server),
+ ):
+ manager._create_client = AsyncMock(return_value=AsyncMock())
+ async with manager.get_client(Config(db_url="http://test_host:1081")):
+ pass
+ assert len(manager.get_processes()) == 1
+ await manager.kill_servers()
+ mock_process.terminate.assert_called_once()
+ mock_process.wait.assert_awaited()
diff --git a/tests/test_lsp.py b/tests/test_lsp.py
index 46bcf7eb..18f999ff 100644
--- a/tests/test_lsp.py
+++ b/tests/test_lsp.py
@@ -1,3 +1,4 @@
+from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -9,7 +10,6 @@
from vectorcode.lsp_main import (
execute_command,
lsp_start,
- make_caches,
)
@@ -39,66 +39,20 @@ def mock_config():
return config
-@pytest.mark.asyncio
-async def test_make_caches(tmp_path):
- project_root = str(tmp_path)
- config_file = tmp_path / ".vectorcode" / "config.json"
- config_file.parent.mkdir(exist_ok=True)
- config_file.write_text('{"host": "test_host", "port": 9999}')
- from vectorcode.lsp_main import cached_project_configs
-
- with (
- patch(
- "vectorcode.lsp_main.get_project_config", new_callable=AsyncMock
- ) as mock_get_project_config,
- patch(
- "vectorcode.lsp_main.try_server", new_callable=AsyncMock
- ) as mock_try_server,
- ):
- mock_try_server.return_value = True
- await make_caches(project_root)
-
- mock_get_project_config.assert_called_once_with(project_root)
- assert project_root in cached_project_configs
-
-
-@pytest.mark.asyncio
-async def test_make_caches_server_unavailable(tmp_path):
- project_root = str(tmp_path)
- config_file = tmp_path / ".vectorcode" / "config.json"
- config_file.parent.mkdir(exist_ok=True)
- config_file.write_text('{"host": "test_host", "port": 9999}')
-
- with (
- patch("vectorcode.lsp_main.get_project_config", new_callable=AsyncMock),
- patch(
- "vectorcode.lsp_main.try_server", new_callable=AsyncMock
- ) as mock_try_server,
- ):
- mock_try_server.return_value = False
- with pytest.raises(ConnectionError):
- await make_caches(project_root)
-
-
@pytest.mark.asyncio
async def test_execute_command_query(mock_language_server, mock_config):
with (
patch(
"vectorcode.lsp_main.parse_cli_args", new_callable=AsyncMock
) as mock_parse_cli_args,
- patch("vectorcode.lsp_main.get_client", new_callable=AsyncMock),
+ patch("vectorcode.lsp_main.ClientManager"),
patch("vectorcode.lsp_main.get_collection", new_callable=AsyncMock),
patch(
"vectorcode.lsp_main.build_query_results", new_callable=AsyncMock
) as mock_get_query_result_files,
patch("os.path.isfile", return_value=True),
- patch("vectorcode.lsp_main.try_server", return_value=True),
patch("builtins.open", MagicMock()) as mock_open,
- patch("vectorcode.lsp_main.cached_project_configs", {}),
):
- from vectorcode.lsp_main import cached_project_configs
-
- cached_project_configs.clear()
mock_parse_cli_args.return_value = mock_config
mock_get_query_result_files.return_value = ["/test/file.txt"]
@@ -110,9 +64,6 @@ async def test_execute_command_query(mock_language_server, mock_config):
# Ensure parsed_args.project_root is not None
mock_config.project_root = "/test/project"
- # Add a mock config to cached_project_configs
- cached_project_configs["/test/project"] = mock_config
-
# Mock the merge_from method
mock_config.merge_from = AsyncMock(return_value=mock_config)
@@ -131,22 +82,17 @@ async def test_execute_command_query_default_proj_root(
patch(
"vectorcode.lsp_main.parse_cli_args", new_callable=AsyncMock
) as mock_parse_cli_args,
- patch("vectorcode.lsp_main.get_client", new_callable=AsyncMock),
+ patch("vectorcode.lsp_main.ClientManager"),
patch("vectorcode.lsp_main.get_collection", new_callable=AsyncMock),
patch(
"vectorcode.lsp_main.build_query_results", new_callable=AsyncMock
) as mock_get_query_result_files,
patch("os.path.isfile", return_value=True),
- patch("vectorcode.lsp_main.try_server", return_value=True),
patch("builtins.open", MagicMock()) as mock_open,
- patch("vectorcode.lsp_main.cached_project_configs", {}),
):
- from vectorcode.lsp_main import cached_project_configs
-
global DEFAULT_PROJECT_ROOT
mock_config.project_root = None
- cached_project_configs.clear()
mock_parse_cli_args.return_value = mock_config
mock_get_query_result_files.return_value = ["/test/file.txt"]
@@ -158,9 +104,6 @@ async def test_execute_command_query_default_proj_root(
# Ensure parsed_args.project_root is not None
DEFAULT_PROJECT_ROOT = "/test/project"
- # Add a mock config to cached_project_configs
- cached_project_configs["/test/project"] = mock_config
-
# Mock the merge_from method
mock_config.merge_from = AsyncMock(return_value=mock_config)
@@ -183,26 +126,18 @@ async def test_execute_command_ls(mock_language_server, mock_config):
patch(
"vectorcode.lsp_main.parse_cli_args", new_callable=AsyncMock
) as mock_parse_cli_args,
- patch("vectorcode.lsp_main.get_client", new_callable=AsyncMock),
+ patch("vectorcode.lsp_main.ClientManager"),
patch(
"vectorcode.lsp_main.get_collection_list", new_callable=AsyncMock
) as mock_get_collection_list,
- patch("vectorcode.lsp_main.cached_project_configs", {}),
patch("vectorcode.common.get_embedding_function") as mock_embedding_function,
patch("vectorcode.common.get_collection") as mock_get_collection,
- patch("vectorcode.lsp_main.try_server", return_value=True),
):
- from vectorcode.lsp_main import cached_project_configs
-
- cached_project_configs.clear()
mock_parse_cli_args.return_value = mock_config
# Ensure parsed_args.project_root is not None
mock_config.project_root = "/test/project"
- # Add a mock config to cached_project_configs
- cached_project_configs["/test/project"] = mock_config
-
# Mock the merge_from method
mock_config.merge_from = AsyncMock(return_value=mock_config)
@@ -236,9 +171,7 @@ async def test_execute_command_vectorise(mock_language_server, mock_config: Conf
patch(
"vectorcode.lsp_main.parse_cli_args", new_callable=AsyncMock
) as mock_parse_cli_args,
- patch(
- "vectorcode.lsp_main.get_client", new_callable=AsyncMock
- ) as mock_get_client,
+ patch("vectorcode.lsp_main.ClientManager") as MockClientManager,
patch(
"vectorcode.lsp_main.get_collection", new_callable=AsyncMock
) as mock_get_collection,
@@ -255,16 +188,11 @@ async def test_execute_command_vectorise(mock_language_server, mock_config: Conf
patch(
"vectorcode.lsp_main.chunked_add", new_callable=AsyncMock
) as mock_chunked_add,
- patch("vectorcode.lsp_main.try_server", return_value=True),
- patch("vectorcode.lsp_main.cached_project_configs", {}),
patch(
"vectorcode.lsp_main.load_files_from_include",
return_value=dummy_initial_files,
) as mock_load_files_from_include,
patch("os.cpu_count", return_value=1), # For asyncio.Semaphore
- patch(
- "vectorcode.lsp_main.make_caches", new_callable=AsyncMock
- ), # Mock make_caches to avoid actual file system ops
patch(
"vectorcode.lsp_main.remove_orphanes", new_callable=AsyncMock
) as mock_remove_orphanes,
@@ -273,15 +201,14 @@ async def test_execute_command_vectorise(mock_language_server, mock_config: Conf
from lsprotocol import types
- from vectorcode.lsp_main import cached_project_configs
-
- cached_project_configs.clear()
- cached_project_configs["/test/project"] = mock_config # Add config to cache
+ @asynccontextmanager
+ async def _get_client(*args):
+ yield mock_client
# Set return values for mocks
mock_parse_cli_args.return_value = mock_config
mock_client = AsyncMock()
- mock_get_client.return_value = mock_client
+ MockClientManager.return_value.get_client.side_effect = _get_client
mock_collection = AsyncMock()
mock_get_collection.return_value = mock_collection
mock_client.get_max_batch_size.return_value = 100 # Mock batch size
@@ -319,7 +246,7 @@ async def test_execute_command_vectorise(mock_language_server, mock_config: Conf
recursive=mock_config.recursive,
include_hidden=mock_config.include_hidden,
)
- mock_find_exclude_specs.assert_called_once_with(mock_config)
+ mock_find_exclude_specs.assert_called_once()
mock_exclude_paths_by_spec.assert_not_called() # Because mock_find_exclude_specs returns empty list (no specs to exclude by)
mock_client.get_max_batch_size.assert_called_once()
@@ -332,7 +259,7 @@ async def test_execute_command_vectorise(mock_language_server, mock_config: Conf
ANY, # asyncio.Lock object
ANY, # stats dict
ANY, # stats_lock
- mock_config,
+ ANY,
100, # max_batch_size
ANY, # semaphore
)
@@ -362,16 +289,9 @@ async def test_execute_command_unsupported_action(
patch(
"vectorcode.lsp_main.get_collection", new_callable=AsyncMock
) as mock_get_collection,
- patch("vectorcode.lsp_main.cached_project_configs", {}),
- patch("vectorcode.lsp_main.try_server", return_value=True),
):
- from vectorcode.lsp_main import cached_project_configs
-
- cached_project_configs.clear()
mock_parse_cli_args.return_value = mock_config
- # Add a mock config to cached_project_configs
- cached_project_configs["/test/project"] = mock_config
mock_collection = MagicMock()
mock_get_collection.return_value = mock_collection
@@ -449,7 +369,6 @@ async def test_execute_command_no_default_project_root(
patch(
"vectorcode.lsp_main.parse_cli_args", new_callable=AsyncMock
) as mock_parse_cli_args,
- patch("vectorcode.lsp_main.get_client", new_callable=AsyncMock),
):
mock_parse_cli_args.return_value = mock_config
with pytest.raises((AssertionError, JsonRpcInternalError)):
diff --git a/tests/test_main.py b/tests/test_main.py
index d6eadd0b..c9f9b718 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -4,7 +4,7 @@
from vectorcode import __version__
from vectorcode.cli_utils import CliAction
-from vectorcode.main import async_main, main
+from vectorcode.main import async_main
@pytest.mark.asyncio
@@ -140,32 +140,6 @@ async def test_async_main_cli_action_prompts(monkeypatch):
mock_prompts.assert_called_once()
-@pytest.mark.asyncio
-async def test_async_main_try_server_unavailable(monkeypatch):
- mock_cli_args = MagicMock(no_stderr=False, project_root=".", action=CliAction.query)
- monkeypatch.setattr(
- "vectorcode.main.parse_cli_args", AsyncMock(return_value=mock_cli_args)
- )
- mock_final_configs = MagicMock(host="test_host", port=1234, action=CliAction.query)
- monkeypatch.setattr(
- "vectorcode.main.get_project_config",
- AsyncMock(
- return_value=MagicMock(
- merge_from=AsyncMock(return_value=mock_final_configs)
- )
- ),
- )
- monkeypatch.setattr("vectorcode.common.try_server", AsyncMock(return_value=False))
- mock_start_server = AsyncMock()
- monkeypatch.setattr("vectorcode.common.start_server", mock_start_server)
- monkeypatch.setattr("vectorcode.subcommands.query", AsyncMock(return_value=0))
- mock_start_server.return_value.wait = AsyncMock()
- mock_start_server.return_value.terminate = MagicMock()
-
- await async_main()
- mock_start_server.assert_called_once_with(mock_final_configs)
-
-
@pytest.mark.asyncio
async def test_async_main_cli_action_query(monkeypatch):
mock_cli_args = MagicMock(no_stderr=False, project_root=".", action=CliAction.query)
@@ -343,41 +317,3 @@ async def test_async_main_exception_handling(monkeypatch):
with patch("vectorcode.main.logger") as mock_logger:
assert await async_main() == 1
mock_logger.error.assert_called_once()
-
-
-@pytest.mark.asyncio
-async def test_async_main_server_process_termination(monkeypatch):
- mock_cli_args = MagicMock(no_stderr=False, project_root=".", action=CliAction.query)
- monkeypatch.setattr(
- "vectorcode.main.parse_cli_args", AsyncMock(return_value=mock_cli_args)
- )
- mock_final_configs = MagicMock(host="test_host", port=1234, action=CliAction.query)
- monkeypatch.setattr(
- "vectorcode.main.get_project_config",
- AsyncMock(
- return_value=MagicMock(
- merge_from=AsyncMock(return_value=mock_final_configs)
- )
- ),
- )
- monkeypatch.setattr("vectorcode.common.try_server", AsyncMock(return_value=False))
- mock_server_process = AsyncMock()
- mock_start_server = AsyncMock(return_value=mock_server_process)
- monkeypatch.setattr("vectorcode.common.start_server", mock_start_server)
- monkeypatch.setattr("vectorcode.subcommands.query", AsyncMock(return_value=0))
- mock_server_process.terminate = MagicMock()
- mock_server_process.wait = AsyncMock()
-
- await async_main()
-
- mock_server_process.terminate.assert_called_once()
- await mock_server_process.wait()
-
-
-def test_main(monkeypatch):
- mock_async_main = AsyncMock(return_value=0)
- monkeypatch.setattr("vectorcode.main.async_main", mock_async_main)
- monkeypatch.setattr("asyncio.run", MagicMock(return_value=0))
-
- result = main()
- assert result == 0
diff --git a/tests/test_mcp.py b/tests/test_mcp.py
index a9be2f11..43b9eac9 100644
--- a/tests/test_mcp.py
+++ b/tests/test_mcp.py
@@ -20,11 +20,13 @@
@pytest.mark.asyncio
async def test_list_collections_success():
with (
- patch("vectorcode.mcp_main.get_client") as mock_get_client,
patch("vectorcode.mcp_main.get_collections") as mock_get_collections,
+ patch("vectorcode.common.try_server", return_value=True),
):
+ from vectorcode.mcp_main import ClientManager
+
mock_client = AsyncMock()
- mock_get_client.return_value = mock_client
+ ClientManager._create_client = AsyncMock(return_value=mock_client)
mock_collection1 = AsyncMock()
mock_collection1.metadata = {"path": "path1"}
@@ -44,11 +46,14 @@ async def async_generator():
@pytest.mark.asyncio
async def test_list_collections_no_metadata():
with (
- patch("vectorcode.mcp_main.get_client") as mock_get_client,
patch("vectorcode.mcp_main.get_collections") as mock_get_collections,
+ patch("vectorcode.common.try_server", return_value=True),
):
+ from vectorcode.mcp_main import ClientManager
+
mock_client = AsyncMock()
- mock_get_client.return_value = mock_client
+ ClientManager._create_client = AsyncMock(return_value=mock_client)
+
mock_collection1 = AsyncMock()
mock_collection1.metadata = {"path": "path1"}
mock_collection2 = AsyncMock()
@@ -82,23 +87,28 @@ async def test_query_tool_invalid_project_root():
@pytest.mark.asyncio
async def test_query_tool_success():
with (
+ tempfile.TemporaryDirectory() as temp_dir,
patch("os.path.isdir", return_value=True),
patch("vectorcode.mcp_main.get_project_config") as mock_get_project_config,
- patch("vectorcode.mcp_main.get_client") as mock_get_client,
patch("vectorcode.mcp_main.get_collection") as mock_get_collection,
patch(
"vectorcode.subcommands.query.get_query_result_files"
) as mock_get_query_result_files,
+ patch("vectorcode.common.try_server", return_value=True),
patch("builtins.open", create=True) as mock_open,
patch("os.path.isfile", return_value=True),
patch("os.path.relpath", return_value="rel/path.py"),
patch("vectorcode.cli_utils.load_config_file") as mock_load_config_file,
):
- mock_config = Config(chunk_size=100, overlap_ratio=0.1, reranker=None)
+ from vectorcode.mcp_main import ClientManager
+
+ mock_config = Config(
+ chunk_size=100, overlap_ratio=0.1, reranker=None, project_root=temp_dir
+ )
mock_load_config_file.return_value = mock_config
mock_get_project_config.return_value = mock_config
mock_client = AsyncMock()
- mock_get_client.return_value = mock_client
+ ClientManager._create_client = AsyncMock(return_value=mock_client)
# Mock the collection's query method to return a valid QueryResult
mock_collection = AsyncMock()
@@ -119,7 +129,7 @@ async def test_query_tool_success():
mock_open.return_value = mock_file_handle
result = await query_tool(
- n_query=2, query_messages=["keyword1"], project_root="/valid/path"
+ n_query=2, query_messages=["keyword1"], project_root=temp_dir
)
assert len(result) == 2
@@ -131,11 +141,14 @@ async def test_query_tool_collection_access_failure():
with (
patch("os.path.isdir", return_value=True),
patch("vectorcode.mcp_main.get_project_config"),
- patch("vectorcode.mcp_main.get_client") as mock_get_client,
- patch("vectorcode.mcp_main.get_collection") as mock_get_collection,
+ patch("vectorcode.mcp_main.get_collection"), # Still mock get_collection
):
- mock_get_client.side_effect = Exception("Failed to connect")
- mock_get_collection.side_effect = Exception("Failed to connect")
+ from vectorcode.mcp_main import ClientManager
+
+ async def failing_get_client(*args, **kwargs):
+ raise Exception("Failed to connect")
+
+ ClientManager._create_client = AsyncMock(side_effect=failing_get_client)
with pytest.raises(McpError) as exc_info:
await query_tool(
@@ -154,9 +167,13 @@ async def test_query_tool_no_collection():
with (
patch("os.path.isdir", return_value=True),
patch("vectorcode.mcp_main.get_project_config"),
- patch("vectorcode.mcp_main.get_client"),
- patch("vectorcode.mcp_main.get_collection") as mock_get_collection,
+ patch(
+ "vectorcode.mcp_main.get_collection"
+ ) as mock_get_collection, # Still mock get_collection
+ patch("vectorcode.common.ClientManager") as MockClientManager,
):
+ mock_client = AsyncMock()
+ MockClientManager.return_value._create_client.return_value = mock_client
mock_get_collection.return_value = None
with pytest.raises(McpError) as exc_info:
@@ -166,8 +183,8 @@ async def test_query_tool_no_collection():
assert exc_info.value.error.code == 1
assert (
- exc_info.value.error.message
- == "Failed to access the collection at /valid/path. Use `list_collections` tool to get a list of valid paths for this field."
+ "Failed to access the collection at /valid/path. Use `list_collections` tool to get a list of valid paths for this field."
+ in exc_info.value.error.message
)
@@ -190,17 +207,22 @@ async def test_vectorise_files_success():
with (
patch("os.path.isdir", return_value=True),
patch("vectorcode.mcp_main.get_project_config") as mock_get_project_config,
- patch("vectorcode.mcp_main.get_client") as mock_get_client,
patch("vectorcode.mcp_main.get_collection") as mock_get_collection,
patch("vectorcode.subcommands.vectorise.chunked_add"),
patch(
"vectorcode.subcommands.vectorise.hash_file", return_value="test_hash"
),
+ patch("vectorcode.common.try_server", return_value=True),
):
+ from vectorcode.mcp_main import ClientManager
+
mock_config = Config(project_root=temp_dir)
mock_get_project_config.return_value = mock_config
mock_client = AsyncMock()
- mock_get_client.return_value = mock_client
+
+ # Ensure ClientManager's internal client creation method returns our mock.
+ ClientManager._create_client = AsyncMock(return_value=mock_client)
+
mock_collection = AsyncMock()
mock_collection.get.return_value = {"ids": [], "metadatas": []}
mock_get_collection.return_value = mock_collection
@@ -210,8 +232,8 @@ async def test_vectorise_files_success():
assert result["add"] == 1
mock_get_project_config.assert_called_once_with(temp_dir)
- mock_get_client.assert_called_once_with(mock_config)
- mock_get_collection.assert_called_once_with(mock_client, mock_config, True)
+ # Assert that the mocked get_collection was called exactly once (its arguments are not checked).
+ mock_get_collection.assert_called_once()
@pytest.mark.asyncio
@@ -219,9 +241,13 @@ async def test_vectorise_files_collection_access_failure():
with (
patch("os.path.isdir", return_value=True),
patch("vectorcode.mcp_main.get_project_config"),
- patch("vectorcode.mcp_main.get_client", side_effect=Exception("Client error")),
+ patch("vectorcode.common.ClientManager"), # Patch ClientManager class
patch("vectorcode.mcp_main.get_collection"),
):
+ from vectorcode.mcp_main import ClientManager
+
+ ClientManager._create_client = AsyncMock(side_effect=Exception("Client error"))
+
with pytest.raises(McpError) as exc_info:
await vectorise_files(paths=["file.py"], project_root="/valid/path")
@@ -257,7 +283,6 @@ def mock_open_side_effect(filename, *args, **kwargs):
with (
patch("os.path.isdir", return_value=True),
patch("vectorcode.mcp_main.get_project_config") as mock_get_project_config,
- patch("vectorcode.mcp_main.get_client") as mock_get_client,
patch("vectorcode.mcp_main.get_collection") as mock_get_collection,
patch("vectorcode.subcommands.vectorise.chunked_add") as mock_chunked_add,
patch(
@@ -270,11 +295,15 @@ def mock_open_side_effect(filename, *args, **kwargs):
"os.path.isfile",
side_effect=lambda x: x in [file1, excluded_file, exclude_spec_file],
),
+ patch("vectorcode.common.try_server", return_value=True),
):
+ from vectorcode.mcp_main import ClientManager
+
mock_config = Config(project_root=temp_dir)
mock_get_project_config.return_value = mock_config
mock_client = AsyncMock()
- mock_get_client.return_value = mock_client
+ ClientManager._create_client = AsyncMock(return_value=mock_client)
+
mock_collection = AsyncMock()
mock_collection.get.return_value = {"ids": [], "metadatas": []}
mock_get_collection.return_value = mock_collection
@@ -297,14 +326,18 @@ async def test_mcp_server():
"vectorcode.mcp_main.find_project_config_dir"
) as mock_find_project_config_dir,
patch("vectorcode.mcp_main.load_config_file") as mock_load_config_file,
- patch("vectorcode.mcp_main.get_client") as mock_get_client,
+ # get_client is no longer patched here; the mock client is injected via ClientManager._create_client below.
patch("vectorcode.mcp_main.get_collection") as mock_get_collection,
patch("mcp.server.fastmcp.FastMCP.add_tool") as mock_add_tool,
+ patch("vectorcode.common.try_server", return_value=True),
):
+ from vectorcode.mcp_main import ClientManager
+
mock_find_project_config_dir.return_value = "/path/to/config"
mock_load_config_file.return_value = Config(project_root="/path/to/project")
mock_client = AsyncMock()
- mock_get_client.return_value = mock_client
+
+ ClientManager._create_client = AsyncMock(return_value=mock_client)
mock_collection = AsyncMock()
mock_get_collection.return_value = mock_collection
@@ -315,26 +348,29 @@ async def test_mcp_server():
@pytest.mark.asyncio
async def test_mcp_server_ls_on_start():
+ mock_collection = AsyncMock()
+
with (
patch(
"vectorcode.mcp_main.find_project_config_dir"
) as mock_find_project_config_dir,
patch("vectorcode.mcp_main.load_config_file") as mock_load_config_file,
- patch("vectorcode.mcp_main.get_client") as mock_get_client,
patch("vectorcode.mcp_main.get_collection") as mock_get_collection,
patch(
"vectorcode.mcp_main.get_collections", spec=AsyncMock
) as mock_get_collections,
patch("mcp.server.fastmcp.FastMCP.add_tool") as mock_add_tool,
+ patch("vectorcode.common.try_server", return_value=True),
):
- from vectorcode.mcp_main import mcp_config
+ from vectorcode.mcp_main import ClientManager, mcp_config
mcp_config.ls_on_start = True
mock_find_project_config_dir.return_value = "/path/to/config"
mock_load_config_file.return_value = Config(project_root="/path/to/project")
mock_client = AsyncMock()
- mock_get_client.return_value = mock_client
- mock_collection = AsyncMock()
+
+ ClientManager._create_client = AsyncMock(return_value=mock_client)
+
mock_collection.metadata = {"path": "/path/to/project"}
mock_get_collection.return_value = mock_collection