diff --git a/pyproject.toml b/pyproject.toml index 09bc7e6..cd0f10a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pgmq" -version = "1.0.2" +version = "1.0.3" description = "Python client for the PGMQ Postgres extension." readme = "README.md" license = "Apache-2.0" diff --git a/src/pgmq/queue.py b/src/pgmq/queue.py index 8883567..26b2c2c 100644 --- a/src/pgmq/queue.py +++ b/src/pgmq/queue.py @@ -1,6 +1,8 @@ from dataclasses import dataclass, field -from typing import Optional, List, Union +from collections.abc import Callable +from typing import Optional, List, Union, Any from psycopg.types.json import Jsonb +from psycopg.conninfo import make_conninfo from psycopg_pool import ConnectionPool import os from pgmq.messages import Message, QueueMetrics @@ -21,7 +23,7 @@ class PGMQueue: delay: int = 0 vt: int = 30 pool_size: int = 10 - kwargs: dict = field(default_factory=dict) + kwargs: Union[dict, Callable[[], dict[str, Any]]] = field(default_factory=dict) verbose: bool = False log_filename: Optional[str] = None init_extension: bool = True @@ -37,7 +39,23 @@ def __post_init__(self) -> None: user={self.username} password={self.password} """ - self.pool = ConnectionPool(conninfo, open=True, **self.kwargs) + if callable(self.kwargs): + # When kwargs is callable, create a callable conninfo that merges + # the base connection string with dynamic values (e.g., IAM auth tokens). + # psycopg_pool calls this each time a new connection is needed. + kwargs_callable = self.kwargs + + def get_conninfo() -> str: + extra = kwargs_callable() # e.g., {"password": "fresh_token"} + return make_conninfo(conninfo, **extra) + + self.pool = ConnectionPool(get_conninfo, open=True) + else: + if "kwargs" in self.kwargs: + raise TypeError( + "The 'kwargs' key is reserved for callables and cannot be used in the kwargs dictionary." + ) + self.pool = ConnectionPool(conninfo, open=True, **self.kwargs) self._initialize_logging() if self.init_extension: self._initialize_extensions() diff --git a/uv.lock b/uv.lock index 4f4b499..82dfe30 100644 --- a/uv.lock +++ b/uv.lock @@ -1428,7 +1428,7 @@ wheels = [ [[package]] name = "pgmq" -version = "1.0.1" +version = "1.0.3" source = { editable = "." } dependencies = [ { name = "orjson" },