Skip to content

Commit 3cb39a6

Browse files
pshikoCopilotzastrowm
authored
feat: add gemini_tools field to GeminiModel with validation and tests (#1050)
Add support for Gemini-specific tools like GoogleSearch and CodeExecution, with validation to prevent FunctionDeclarations. Due to the fundamental differences in how Gemini's built-in tools operate (server-side execution without explicit tool call/result blocks), we don't implement history tracking for gemini_tools - that would require additional design work and a longer discussion on how to normalize this across all model providers. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Mackenzie Zastrow <zastrowm@users.noreply.github.com>
1 parent 3d03a35 commit 3cb39a6

File tree

4 files changed

+155
-2
lines changed

4 files changed

+155
-2
lines changed
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
"""Steering handler implementations."""
22

3-
__all__ = []
3+
from typing import Sequence
4+
5+
__all__: Sequence[str] = []

src/strands/models/gemini.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,16 @@ class GeminiConfig(TypedDict, total=False):
4040
params: Additional model parameters (e.g., temperature).
4141
For a complete list of supported parameters, see
4242
https://ai.google.dev/api/generate-content#generationconfig.
43+
gemini_tools: Gemini-specific tools that are not FunctionDeclarations
44+
(e.g., GoogleSearch, CodeExecution, ComputerUse, UrlContext, FileSearch).
45+
Use the standard tools interface for function calling tools.
46+
For a complete list of supported tools, see
47+
https://ai.google.dev/api/caching#Tool
4348
"""
4449

4550
model_id: Required[str]
4651
params: dict[str, Any]
52+
gemini_tools: list[genai.types.Tool]
4753

4854
def __init__(
4955
self,
@@ -61,6 +67,10 @@ def __init__(
6167
validate_config_keys(model_config, GeminiModel.GeminiConfig)
6268
self.config = GeminiModel.GeminiConfig(**model_config)
6369

70+
# Validate gemini_tools if provided
71+
if "gemini_tools" in self.config:
72+
self._validate_gemini_tools(self.config["gemini_tools"])
73+
6474
logger.debug("config=<%s> | initializing", self.config)
6575

6676
self.client_args = client_args or {}
@@ -72,6 +82,10 @@ def update_config(self, **model_config: Unpack[GeminiConfig]) -> None: # type:
7282
Args:
7383
**model_config: Configuration overrides.
7484
"""
85+
# Validate gemini_tools if provided
86+
if "gemini_tools" in model_config:
87+
self._validate_gemini_tools(model_config["gemini_tools"])
88+
7589
self.config.update(model_config)
7690

7791
@override
@@ -181,7 +195,7 @@ def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[ge
181195
Return:
182196
Gemini tool list.
183197
"""
184-
return [
198+
tools = [
185199
genai.types.Tool(
186200
function_declarations=[
187201
genai.types.FunctionDeclaration(
@@ -193,6 +207,9 @@ def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[ge
193207
],
194208
),
195209
]
210+
if self.config.get("gemini_tools"):
211+
tools.extend(self.config["gemini_tools"])
212+
return tools
196213

197214
def _format_request_config(
198215
self,
@@ -451,3 +468,27 @@ async def structured_output(
451468
client = genai.Client(**self.client_args).aio
452469
response = await client.models.generate_content(**request)
453470
yield {"output": output_model.model_validate(response.parsed)}
471+
472+
@staticmethod
473+
def _validate_gemini_tools(gemini_tools: list[genai.types.Tool]) -> None:
474+
"""Validate that gemini_tools does not contain FunctionDeclarations.
475+
476+
Gemini-specific tools should only include tools that cannot be represented
477+
as FunctionDeclarations (e.g., GoogleSearch, CodeExecution, ComputerUse).
478+
Standard function calling tools should use the tools interface instead.
479+
480+
Args:
481+
gemini_tools: List of Gemini tools to validate
482+
483+
Raises:
484+
ValueError: If any tool contains function_declarations
485+
"""
486+
for tool in gemini_tools:
487+
# Check if the tool has function_declarations attribute and it's not empty
488+
if hasattr(tool, "function_declarations") and tool.function_declarations:
489+
raise ValueError(
490+
"gemini_tools should not contain FunctionDeclarations. "
491+
"Use the standard tools interface for function calling tools. "
492+
"gemini_tools is reserved for Gemini-specific tools like "
493+
"GoogleSearch, CodeExecution, ComputerUse, UrlContext, and FileSearch."
494+
)

tests/strands/models/test_gemini.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,89 @@ async def test_structured_output(gemini_client, model, messages, model_id, weath
624624
gemini_client.aio.models.generate_content.assert_called_with(**exp_request)
625625

626626

627+
def test_gemini_tools_validation_rejects_function_declarations(model_id):
628+
tool_with_function_declarations = genai.types.Tool(
629+
function_declarations=[
630+
genai.types.FunctionDeclaration(
631+
name="test_function",
632+
description="A test function",
633+
)
634+
]
635+
)
636+
637+
with pytest.raises(ValueError, match="gemini_tools should not contain FunctionDeclarations"):
638+
GeminiModel(model_id=model_id, gemini_tools=[tool_with_function_declarations])
639+
640+
641+
def test_gemini_tools_validation_allows_non_function_tools(model_id):
642+
tool_with_google_search = genai.types.Tool(google_search=genai.types.GoogleSearch())
643+
644+
model = GeminiModel(model_id=model_id, gemini_tools=[tool_with_google_search])
645+
assert "gemini_tools" in model.config
646+
647+
648+
def test_gemini_tools_validation_on_update_config(model):
649+
tool_with_function_declarations = genai.types.Tool(
650+
function_declarations=[
651+
genai.types.FunctionDeclaration(
652+
name="test_function",
653+
description="A test function",
654+
)
655+
]
656+
)
657+
658+
with pytest.raises(ValueError, match="gemini_tools should not contain FunctionDeclarations"):
659+
model.update_config(gemini_tools=[tool_with_function_declarations])
660+
661+
662+
@pytest.mark.asyncio
663+
async def test_stream_request_with_gemini_tools(gemini_client, messages, model_id):
664+
google_search_tool = genai.types.Tool(google_search=genai.types.GoogleSearch())
665+
model = GeminiModel(model_id=model_id, gemini_tools=[google_search_tool])
666+
667+
await anext(model.stream(messages))
668+
669+
exp_request = {
670+
"config": {
671+
"tools": [
672+
{"function_declarations": []},
673+
{"google_search": {}},
674+
]
675+
},
676+
"contents": [{"parts": [{"text": "test"}], "role": "user"}],
677+
"model": model_id,
678+
}
679+
gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request)
680+
681+
682+
@pytest.mark.asyncio
683+
async def test_stream_request_with_gemini_tools_and_function_tools(gemini_client, messages, tool_spec, model_id):
684+
code_execution_tool = genai.types.Tool(code_execution=genai.types.ToolCodeExecution())
685+
model = GeminiModel(model_id=model_id, gemini_tools=[code_execution_tool])
686+
687+
await anext(model.stream(messages, tool_specs=[tool_spec]))
688+
689+
exp_request = {
690+
"config": {
691+
"tools": [
692+
{
693+
"function_declarations": [
694+
{
695+
"description": tool_spec["description"],
696+
"name": tool_spec["name"],
697+
"parameters_json_schema": tool_spec["inputSchema"]["json"],
698+
}
699+
]
700+
},
701+
{"code_execution": {}},
702+
]
703+
},
704+
"contents": [{"parts": [{"text": "test"}], "role": "user"}],
705+
"model": model_id,
706+
}
707+
gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request)
708+
709+
627710
@pytest.mark.asyncio
628711
async def test_stream_handles_non_json_error(gemini_client, model, messages, caplog, alist):
629712
error_message = "Invalid API key"

tests_integ/models/test_model_gemini.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pydantic
44
import pytest
5+
from google import genai
56

67
import strands
78
from strands import Agent
@@ -21,6 +22,16 @@ def model():
2122
)
2223

2324

25+
@pytest.fixture
26+
def gemini_tool_model():
27+
return GeminiModel(
28+
client_args={"api_key": os.getenv("GOOGLE_API_KEY")},
29+
model_id="gemini-2.5-flash",
30+
params={"temperature": 0.15}, # Lower temperature for consistent test behavior
31+
gemini_tools=[genai.types.Tool(code_execution=genai.types.ToolCodeExecution())],
32+
)
33+
34+
2435
@pytest.fixture
2536
def tools():
2637
@strands.tool
@@ -175,3 +186,19 @@ def test_agent_structured_output_image_input(assistant_agent, yellow_img, yellow
175186
tru_color = assistant_agent.structured_output(type(yellow_color), content)
176187
exp_color = yellow_color
177188
assert tru_color == exp_color
189+
190+
191+
def test_agent_with_gemini_code_execution_tool(gemini_tool_model):
192+
system_prompt = "Generate and run code for all calculations"
193+
agent = Agent(model=gemini_tool_model, system_prompt=system_prompt)
194+
# sample prompt taken from https://ai.google.dev/gemini-api/docs/code-execution
195+
result_turn1 = agent(
196+
"What is the sum of the first 50 prime numbers? Generate and run code for the calculation, "
197+
"and make sure you get all 50."
198+
)
199+
200+
# NOTE: We don't verify tool history because built-in tools are currently represented in message history
201+
assert "5117" in str(result_turn1)
202+
203+
result_turn2 = agent("Summarize that into a single number")
204+
assert "5117" in str(result_turn2)

0 commit comments

Comments
 (0)