diff --git a/src/vectorcode/cli_utils.py b/src/vectorcode/cli_utils.py index c6b3cb5a..83b63552 100644 --- a/src/vectorcode/cli_utils.py +++ b/src/vectorcode/cli_utils.py @@ -2,6 +2,7 @@ import atexit import glob import logging +from optparse import Option import os import sys from dataclasses import dataclass, field, fields @@ -62,6 +63,11 @@ class CliAction(Enum): hooks = "hooks" +class DbType(StrEnum): + local = "local" # Local ChromaDB instance + chromadb = "chromadb" # Remote ChromaDB instance + + @dataclass class Config: no_stderr: bool = False @@ -74,6 +80,7 @@ class Config: project_root: Optional[Union[str, Path]] = None query: Optional[list[str]] = None db_url: str = "http://127.0.0.1:8000" + db_type: DbType = DbType.local # falls back to a local instance embedding_function: str = "SentenceTransformerEmbeddingFunction" # This should fallback to whatever the default is. embedding_params: dict[str, Any] = field(default_factory=(lambda: {})) n_result: int = 1 @@ -106,6 +113,8 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config": default_config = Config() db_path = config_dict.get("db_path") db_url = config_dict.get("db_url") + db_type = config_dict.get("db_type", default_config.db_type) + if db_url is None: host = config_dict.get("host") port = config_dict.get("port") @@ -135,6 +144,7 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config": "embedding_params", default_config.embedding_params ), "db_url": db_url, + "db_type": db_type, "db_path": db_path, "db_log_path": os.path.expanduser( config_dict.get("db_log_path", default_config.db_log_path) @@ -521,6 +531,9 @@ async def get_project_config(project_root: Union[str, Path]) -> Config: if config is None: config = await load_config_file() config.project_root = project_root + + if config.db_type is None: + config.db_type = "local" return config diff --git a/src/vectorcode/common.py b/src/vectorcode/common.py index f4fff1a6..62f4b104 100644 --- a/src/vectorcode/common.py +++ b/src/vectorcode/common.py @@ -6,16 +6,15 @@ import subprocess import sys from typing import Any, AsyncGenerator -from urllib.parse import urlparse import chromadb import httpx from chromadb.api import AsyncClientAPI from chromadb.api.models.AsyncCollection import AsyncCollection -from chromadb.config import APIVersion, Settings from chromadb.utils import embedding_functions from vectorcode.cli_utils import Config, expand_path +from vectorcode.db.base import VectorStore logger = logging.getLogger(name=__name__) @@ -169,11 +168,36 @@ def get_embedding_function(configs: Config) -> chromadb.EmbeddingFunction | None raise +def build_collection_metadata(configs: Config) -> dict[str, str | int]: + assert configs.project_root is not None + full_path = str(expand_path(str(configs.project_root), absolute=True)) + + collection_meta: dict[str, str | int] = { + "path": full_path, + "hostname": socket.gethostname(), + "created-by": "VectorCode", + "username": os.environ.get("USER", os.environ.get("USERNAME", "DEFAULT_USER")), + "embedding_function": configs.embedding_function, + } + + if configs.hnsw: + for key in configs.hnsw.keys(): + target_key = key + if not key.startswith("hnsw:"): + target_key = f"hnsw:{key}" + collection_meta[target_key] = configs.hnsw[key] + logger.debug( + f"Getting/Creating collection with the following metadata: {collection_meta}" + ) + + return collection_meta + + __COLLECTION_CACHE: dict[str, AsyncCollection] = {} async def get_collection( - client: AsyncClientAPI, configs: Config, make_if_missing: bool = False + db: VectorStore, configs: Config, make_if_missing: bool = False ): """ Raise ValueError when make_if_missing is False and no collection is found; @@ -205,11 +229,11 @@ async def get_collection( f"Getting/Creating collection with the following metadata: {collection_meta}" ) if not make_if_missing: - __COLLECTION_CACHE[full_path] = await client.get_collection( + __COLLECTION_CACHE[full_path] = await db.get_collection( collection_name, embedding_function ) else: - collection = await client.get_or_create_collection( + collection = await db.get_or_create_collection( collection_name, metadata=collection_meta, embedding_function=embedding_function, diff --git a/src/vectorcode/db/base.py b/src/vectorcode/db/base.py new file mode 100644 index 00000000..6be6739d --- /dev/null +++ b/src/vectorcode/db/base.py @@ -0,0 +1,127 @@ +from abc import ABC, abstractmethod +from typing import Any +from urllib.parse import urlparse + +from vectorcode.cli_utils import Config + + +class VectorStoreConnectionError(Exception): + pass + + +class VectorStore(ABC): + """Base class for vector database implementations. + + This abstract class defines the interface that all vector database implementations + must follow. It provides methods for common vector database operations like + querying, adding, and deleting vectors. + """ + + _configs: Config + + def __init__(self, configs: Config): + self.__COLLECTION_CACHE: dict[str, Any] = {} + self._configs = configs + + assert configs.project_root is not None + + @abstractmethod + async def connect(self) -> None: + """Establish connection to the vector database.""" + pass + + @abstractmethod + async def disconnect(self) -> None: + """Close connection to the vector database.""" + pass + + # @abstractmethod + # async def check_health(self) -> None: + # """Check if the database is healthy and accessible. Raises a VectorStoreConnectionError if not.""" + # pass + + @abstractmethod + async def get_collection( + self, + collection_name: str, + collection_meta: dict[str, Any] | None = None, + make_if_missing: bool = False, + ) -> Any: + """Get an existing collection.""" + pass + + # @abstractmethod + # async def get_or_create_collection( + # self, + # collection_name: str, + # metadata: Optional[Dict[str, Any]] = None, + # embedding_function: Optional[Any] = None, + # ) -> Any: + # """Get an existing collection or create a new one if it doesn't exist.""" + # pass + + # @abstractmethod + # async def query( + # self, + # collection: Any, + # query_texts: List[str], + # n_results: int, + # where: Optional[Dict[str, Any]] = None, + # include: Optional[List[str]] = None, + # ) -> Dict[str, Any]: + # """Query the vector database for similar vectors.""" + # pass + # + # @abstractmethod + # async def add( + # self, + # collection: Any, + # documents: List[str], + # metadatas: List[Dict[str, Any]], + # ids: Optional[List[str]] = None, + # ) -> None: + # """Add documents to the vector database.""" + # pass + # + # @abstractmethod + # async def delete( + # self, + # collection: Any, + # where: Optional[Dict[str, Any]] = None, + # ) -> None: + # """Delete documents from the vector database.""" + # pass + # + # @abstractmethod + # async def count( + # self, + # collection: Any, + # ) -> int: + # """Get the number of documents in the collection.""" + # pass + # + # @abstractmethod + # async def get( + # self, + # collection: Any, + # ids: Union[str, List[str]], + # include: Optional[List[str]] = None, + # ) -> Dict[str, Any]: + # """Get documents by their IDs.""" + # pass + + def print_config(self) -> None: + """Print the current database configuration.""" + parsed_url = urlparse(self._configs.db_url) + + print(f"{self._configs.db_type.title()} Configuration:") + print(f" URL: {self._configs.db_url}") + print(f" Host: {parsed_url.hostname or 'localhost'}") + print( + f" Port: {parsed_url.port or (8000 if self._configs.db_type == 'chroma' else 6333)}" + ) + print(f" SSL: {parsed_url.scheme == 'https'}") + if self._configs.db_settings: + print(" Settings:") + for key, value in self._configs.db_settings.items(): + print(f" {key}: {value}") diff --git a/src/vectorcode/db/chroma.py b/src/vectorcode/db/chroma.py new file mode 100644 index 00000000..07bc6bae --- /dev/null +++ b/src/vectorcode/db/chroma.py @@ -0,0 +1,110 @@ +import logging +from typing import Any, Dict, override +from urllib.parse import urlparse + +import chromadb +from chromadb.api import AsyncClientAPI +from chromadb.api.models.AsyncCollection import AsyncCollection +from chromadb.config import Settings +from chromadb.utils import embedding_functions + +from vectorcode.cli_utils import Config +from vectorcode.db.base import VectorStore, VectorStoreConnectionError + +logger = logging.getLogger(__name__) + + +class ChromaVectorStore(VectorStore): + """ChromaDB implementation of the vector store.""" + + _client: AsyncClientAPI | None = None + _chroma_settings: Settings + _embedding_function: chromadb.EmbeddingFunction | None + + def __init__(self, configs: Config): + super().__init__(configs) + settings: Dict[str, Any] = {"anonymized_telemetry": False} + if isinstance(self._configs.db_settings, dict): + valid_settings = { + k: v + for k, v in self._configs.db_settings.items() + if k in Settings.__fields__ + } + settings.update(valid_settings) + + parsed_url = urlparse(self._configs.db_url) + settings["chroma_server_host"] = parsed_url.hostname or "127.0.0.1" + settings["chroma_server_http_port"] = parsed_url.port or 8000 + settings["chroma_server_ssl_enabled"] = parsed_url.scheme == "https" + settings["chroma_server_api_default_path"] = "/api/v2" + + self._chroma_settings = Settings(**settings) + + try: + self._embedding_function = getattr( + embedding_functions, configs.embedding_function + )(**configs.embedding_params) + except AttributeError: + logger.warning( + f"Failed to use {configs.embedding_function}. Falling back to Sentence Transformer.", + ) + self._embedding_function = ( + embedding_functions.SentenceTransformerEmbeddingFunction() # type:ignore + ) + except Exception as e: + e.add_note( + "\nFor errors caused by missing dependency, consult the documentation of pipx (or whatever package manager that you installed VectorCode with) for instructions to inject libraries into the virtual environment." + ) + logger.error( + f"Failed to use {configs.embedding_function} with following error.", + ) + raise + + @override + async def connect(self) -> None: + """Establish connection to ChromaDB.""" + try: + if self._client is None: + logger.debug( + f"Connecting to ChromaDB at {self._chroma_settings.chroma_server_host}:{self._chroma_settings.chroma_server_http_port}." + ) + self._client = await chromadb.AsyncHttpClient( + settings=self._chroma_settings, + host=str(self._chroma_settings.chroma_server_host), + port=int(self._chroma_settings.chroma_server_http_port or 8000), + ) + + await self._client.heartbeat() + except Exception as e: + logger.error(f"Could not connect to ChromaDB: {e}") + raise VectorStoreConnectionError(e) + + @override + async def disconnect(self) -> None: + """Not required for non local chromadb.""" + pass + + @override + async def get_collection( + self, + collection_name: str, + collection_meta: dict[str, Any] | None = None, + make_if_missing: bool = False, + ) -> AsyncCollection: + """ + Raise ValueError when make_if_missing is False and no collection is found; + Raise IndexError on hash collision. + """ + if not self._client: + await self.connect() + + assert self._client is not None, "Chroma client is not connected." + + if not make_if_missing: + return await self._client.get_collection(collection_name) + else: + return await self._client.get_or_create_collection( + collection_name, + metadata=collection_meta, + embedding_function=self._embedding_function, + ) diff --git a/src/vectorcode/db/factory.py b/src/vectorcode/db/factory.py new file mode 100644 index 00000000..5cd5577f --- /dev/null +++ b/src/vectorcode/db/factory.py @@ -0,0 +1,25 @@ +from typing import Dict, Type + +from vectorcode.cli_utils import Config, DbType +from vectorcode.db.base import VectorStore +from vectorcode.db.chroma import ChromaVectorStore +from vectorcode.db.local import LocalChromaVectorStore + + +class VectorStoreFactory: + """Factory for creating vector store instances.""" + + _stores: Dict[DbType, Type[VectorStore]] = { + DbType.chromadb: ChromaVectorStore, + DbType.local: LocalChromaVectorStore, + } + + @classmethod + def create_store(cls, configs: Config) -> VectorStore: + """Create a vector store instance based on configuration.""" + store_type = configs.db_type + if store_type not in cls._stores: + raise ValueError(f"Unsupported vector store type: {store_type}") + + store_class = cls._stores[store_type] + return store_class(configs) diff --git a/src/vectorcode/db/local.py b/src/vectorcode/db/local.py new file mode 100644 index 00000000..88339106 --- /dev/null +++ b/src/vectorcode/db/local.py @@ -0,0 +1,102 @@ +import asyncio +from asyncio.subprocess import Process +import logging +import subprocess +import os +import socket +import sys +from typing import override + + +from vectorcode.cli_utils import Config, expand_path +from vectorcode.db.chroma import ChromaVectorStore + +logger = logging.getLogger(__name__) + + +class LocalChromaVectorStore(ChromaVectorStore): + """ChromaDB implementation of the vector store.""" + + _process: Process | None = None + _full_path: str + + def __init__(self, configs: Config): + super().__init__(configs) + self._full_path = str(expand_path(str(configs.project_root), absolute=True)) + + async def _start_chroma_process(self) -> None: + if self._process is not None: + return + + assert self._configs.db_path is not None, "ChromaDB db_path must be set." + db_path = os.path.expanduser(self._configs.db_path) + self._configs.db_log_path = os.path.expanduser(self._configs.db_log_path) + if not os.path.isdir(self._configs.db_log_path): + os.makedirs(self._configs.db_log_path) + if not os.path.isdir(db_path): + logger.warning( + f"Using local database at {os.path.expanduser('~/.local/share/vectorcode/chromadb/')}.", + ) + db_path = os.path.expanduser("~/.local/share/vectorcode/chromadb/") + env = os.environ.copy() + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) # OS selects a free ephemeral port + self._chroma_settings.chroma_server_http_port = int(s.getsockname()[1]) + + server_url = f"http://127.0.0.1:{self._chroma_settings.chroma_server_http_port}" + logger.info(f"Starting bundled ChromaDB server at {server_url}.") + env.update({"ANONYMIZED_TELEMETRY": "False"}) + + self._process = await asyncio.create_subprocess_exec( + sys.executable, + "-m", + "chromadb.cli.cli", + "run", + "--host", + "localhost", + "--port", + str(self._chroma_settings.chroma_server_http_port), + "--path", + db_path, + "--log-path", + os.path.join(str(self._configs.db_log_path), "chroma.log"), + stdout=subprocess.DEVNULL, + stderr=sys.stderr, + env=env, + ) + + @override + async def connect(self) -> None: + """Establish connection to ChromaDB.""" + if self._process is None: + await self._start_chroma_process() + # Wait for server to start up + await asyncio.sleep(2) + + # we have to wait until the local chroma server is ready + # Retry connection with exponential backoff + max_retries = 5 + retry_delay = 0.5 + + for attempt in range(max_retries): + try: + await super().connect() + return + except Exception as e: + if attempt == max_retries - 1: + raise + logger.debug( + f"Connection attempt {attempt + 1} failed, retrying in {retry_delay}s: {e}" + ) + await asyncio.sleep(retry_delay) + retry_delay *= 2 + + @override + async def disconnect(self) -> None: + """Close connection to ChromaDB.""" + if self._process is None: + return + + logger.info("Shutting down the bundled Chromadb instance.") + self._process.terminate() + await self._process.wait() diff --git a/src/vectorcode/main.py b/src/vectorcode/main.py index 3b64cff4..e9ad4f26 100644 --- a/src/vectorcode/main.py +++ b/src/vectorcode/main.py @@ -71,11 +71,10 @@ async def async_main(): return await hooks(cli_args) - from vectorcode.common import start_server, try_server + from vectorcode.db.factory import VectorStoreFactory - server_process = None - if not await try_server(final_configs.db_url): - server_process = await start_server(final_configs) + db = VectorStoreFactory.create_store(final_configs) + await db.connect() if final_configs.pipe: # NOTE: NNCF (intel GPU acceleration for sentence transformer) keeps showing logs. @@ -92,7 +91,7 @@ async def async_main(): case CliAction.vectorise: from vectorcode.subcommands import vectorise - return_val = await vectorise(final_configs) + return_val = await vectorise(db, final_configs) case CliAction.drop: from vectorcode.subcommands import drop @@ -113,10 +112,7 @@ async def async_main(): return_val = 1 logger.error(traceback.format_exc()) finally: - if server_process is not None: - logger.info("Shutting down the bundled Chromadb instance.") - server_process.terminate() - await server_process.wait() + await db.disconnect() return return_val diff --git a/src/vectorcode/subcommands/vectorise.py b/src/vectorcode/subcommands/vectorise.py index a838124f..33df3dd3 100644 --- a/src/vectorcode/subcommands/vectorise.py +++ b/src/vectorcode/subcommands/vectorise.py @@ -14,6 +14,7 @@ from chromadb.api.models.AsyncCollection import AsyncCollection from chromadb.api.types import IncludeEnum +from vectorcode.db.base import VectorStore from vectorcode.chunking import Chunk, TreeSitterChunker from vectorcode.cli_utils import ( GLOBAL_EXCLUDE_SPEC, @@ -22,7 +23,7 @@ expand_globs, expand_path, ) -from vectorcode.common import get_client, get_collection, verify_ef +from vectorcode.common import verify_ef logger = logging.getLogger(name=__name__) @@ -158,11 +159,9 @@ def load_files_from_include(project_root: str) -> list[str]: return [] -async def vectorise(configs: Config) -> int: - assert configs.project_root is not None - client = await get_client(configs) +async def vectorise(db: VectorStore, configs: Config) -> int: try: - collection = await get_collection(client, configs, True) + collection = await db.get_collection(True) except IndexError: print("Failed to get/create the collection. Please check your config.") return 1 @@ -180,6 +179,7 @@ async def vectorise(configs: Config) -> int: specs = [ gitignore_path, ] + assert configs.project_root is not None exclude_spec_path = os.path.join( configs.project_root, ".vectorcode", "vectorcode.exclude" )