Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion contributing/samples/gepa/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from tau_bench.types import EnvRunResult
from tau_bench.types import RunConfig
import tau_bench_agent as tau_bench_agent_lib

import utils


Expand Down
1 change: 0 additions & 1 deletion contributing/samples/gepa/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from absl import flags
import experiment
from google.genai import types

import utils

_OUTPUT_DIR = flags.DEFINE_string(
Expand Down
52 changes: 28 additions & 24 deletions src/google/adk/auth/auth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,6 @@
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING
Expand All @@ -22,6 +17,7 @@
from .auth_schemes import AuthSchemeType
from .auth_schemes import OpenIdConnectWithConfig
from .auth_tool import AuthConfig
from .credential_manager import CredentialManager
from .exchanger.oauth2_credential_exchanger import OAuth2CredentialExchanger

if TYPE_CHECKING:
Expand All @@ -48,10 +44,14 @@ async def exchange_auth_token(
self,
) -> AuthCredential:
exchanger = OAuth2CredentialExchanger()
exchange_result = await exchanger.exchange(
self.auth_config.exchanged_auth_credential, self.auth_config.auth_scheme
)
return exchange_result.credential

# Restore secret if needed
credential = self.auth_config.exchanged_auth_credential

with CredentialManager.restore_client_secret(credential):
res = await exchanger.exchange(credential, self.auth_config.auth_scheme)
return res.credential


async def parse_and_store_auth_response(self, state: State) -> None:

Expand Down Expand Up @@ -183,21 +183,25 @@ def generate_auth_uri(
)
scopes = list(scopes.keys())

client = OAuth2Session(
auth_credential.oauth2.client_id,
auth_credential.oauth2.client_secret,
scope=" ".join(scopes),
redirect_uri=auth_credential.oauth2.redirect_uri,
)
params = {
"access_type": "offline",
"prompt": "consent",
}
if auth_credential.oauth2.audience:
params["audience"] = auth_credential.oauth2.audience
uri, state = client.create_authorization_url(
url=authorization_endpoint, **params
)
client_id = auth_credential.oauth2.client_id

with CredentialManager.restore_client_secret(auth_credential):
client_secret = auth_credential.oauth2.client_secret
client = OAuth2Session(
client_id,
client_secret,
scope=" ".join(scopes),
redirect_uri=auth_credential.oauth2.redirect_uri,
)
params = {
"access_type": "offline",
"prompt": "consent",
}
if auth_credential.oauth2.audience:
params["audience"] = auth_credential.oauth2.audience
uri, state = client.create_authorization_url(
url=authorization_endpoint, **params
)

exchanged_auth_credential = auth_credential.model_copy(deep=True)
exchanged_auth_credential.oauth2.auth_uri = uri
Expand Down
11 changes: 9 additions & 2 deletions src/google/adk/auth/auth_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,16 @@ def get_credential_key(self):
)

auth_credential = self.raw_auth_credential
if auth_credential and auth_credential.model_extra:
if auth_credential and (
auth_credential.model_extra or auth_credential.oauth2
):
auth_credential = auth_credential.model_copy(deep=True)
auth_credential.model_extra.clear()
if auth_credential.model_extra:
auth_credential.model_extra.clear()
# Normalize secret to ensure stable key regardless of redaction
if auth_credential.oauth2:
auth_credential.oauth2.client_secret = None

credential_name = (
f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
if auth_credential
Expand Down
98 changes: 91 additions & 7 deletions src/google/adk/auth/credential_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

from __future__ import annotations

import contextlib
import logging
from typing import Optional

from fastapi.openapi.models import OAuth2

from ..agents.callback_context import CallbackContext
from ..tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
from ..utils.feature_decorator import experimental
from .auth_credential import AuthCredential
from .auth_credential import AuthCredentialTypes
Expand Down Expand Up @@ -76,11 +76,23 @@ class CredentialManager:
```
"""

# A map to store client secrets in memory. Key is client_id, value is client_secret
_CLIENT_SECRETS: dict[str, str] = {}

def __init__(
self,
auth_config: AuthConfig,
):
self._auth_config = auth_config
# We deep copy the auth_config to avoid modifying the original object passed
# by the user. This allows for safe redaction of sensitive information without
# causing side effects.

self._auth_config = auth_config.model_copy(deep=True)

# Secure the client secret
self._secure_client_secret(self._auth_config.raw_auth_credential)
self._secure_client_secret(self._auth_config.exchanged_auth_credential)

self._exchanger_registry = CredentialExchangerRegistry()
self._refresher_registry = CredentialRefresherRegistry()
self._discovery_manager = OAuth2DiscoveryManager()
Expand All @@ -98,6 +110,8 @@ def __init__(
)

# TODO: Move ServiceAccountCredentialExchanger to the auth module
from ..tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger

self._exchanger_registry.register(
AuthCredentialTypes.SERVICE_ACCOUNT,
ServiceAccountCredentialExchanger(),
Expand All @@ -111,6 +125,36 @@ def __init__(
AuthCredentialTypes.OPEN_ID_CONNECT, oauth2_refresher
)

def _secure_client_secret(self, credential: Optional[AuthCredential]):
"""Extracts client secret to memory and redacts it from the credential."""
if (
credential
and credential.oauth2
and credential.oauth2.client_id
and credential.oauth2.client_secret
and credential.oauth2.client_secret != "<redacted>"
):
logger.info(
f"Securing client secret for client_id: {credential.oauth2.client_id}"
)
# Store in memory map
CredentialManager._CLIENT_SECRETS[credential.oauth2.client_id] = (
credential.oauth2.client_secret
)
# Redact from config
credential.oauth2.client_secret = "<redacted>"
else:
if credential and credential.oauth2:
logger.debug(
f"Not securing secret for client_id {credential.oauth2.client_id}:"
f" secret is {credential.oauth2.client_secret}"
)

@staticmethod
def get_client_secret(client_id: str) -> Optional[str]:
"""Retrieves the client secret for a given client_id."""
return CredentialManager._CLIENT_SECRETS.get(client_id)

def register_credential_exchanger(
self,
credential_type: AuthCredentialTypes,
Expand All @@ -125,6 +169,9 @@ def register_credential_exchanger(
self._exchanger_registry.register(credential_type, exchanger_instance)

async def request_credential(self, callback_context: CallbackContext) -> None:
# We send the auth_config (which is already redacted in __init__) to the client
# Note: we need to ensure we don't send any stale exchanged credentials if they are not valid
# But usually CredentialManager manages that.
callback_context.request_credential(self._auth_config)

async def get_auth_credential(
Expand Down Expand Up @@ -206,6 +253,40 @@ async def _load_from_auth_response(
"""Load credential from auth response in callback context."""
return callback_context.get_auth_response(self._auth_config)

@staticmethod
@contextlib.contextmanager
def restore_client_secret(credential: AuthCredential, secret: str = None):
"""Context manager to temporarily restore client secret in a credential.

Args:
credential: The credential to restore secret for.
secret: Optional secret to use. If not provided, looks up by client_id.
"""
if not credential or not credential.oauth2:
yield
return

restored = False
if secret:
credential.oauth2.client_secret = secret
restored = True
elif (
credential.oauth2.client_id
and credential.oauth2.client_secret == "<redacted>"
):
stored_secret = CredentialManager.get_client_secret(
credential.oauth2.client_id
)
if stored_secret:
credential.oauth2.client_secret = stored_secret
restored = True

try:
yield
finally:
if restored:
credential.oauth2.client_secret = "<redacted>"

async def _exchange_credential(
self, credential: AuthCredential
) -> tuple[AuthCredential, bool]:
Expand All @@ -214,18 +295,21 @@ async def _exchange_credential(
if not exchanger:
return credential, False

from ..tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger

if isinstance(exchanger, ServiceAccountCredentialExchanger):
return (
exchanger.exchange_credential(
self._auth_config.auth_scheme, credential
),
True,
)

exchange_result = await exchanger.exchange(
credential, self._auth_config.auth_scheme
)
return exchange_result.credential, exchange_result.was_exchanged
else:
with self.restore_client_secret(credential):
exchanged_credential = await exchanger.exchange(
credential, self._auth_config.auth_scheme
)
return exchanged_credential, True

async def _refresh_credential(
self, credential: AuthCredential
Expand Down
Loading