Skip to content

Commit 8942c98

Browse files
aianch and JackYPCOnline
authored and committed
feat(bedrock): add guardrail_last_turn_only option
1 parent 894ba80 commit 8942c98

File tree

3 files changed

+149
-2
lines changed

3 files changed

+149
-2
lines changed

src/strands/models/bedrock.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ class BedrockConfig(TypedDict, total=False):
8282
guardrail_redact_input_message: If a Bedrock Input guardrail triggers, replace the input with this message.
8383
guardrail_redact_output: Flag to redact output if guardrail is triggered. Defaults to False.
8484
guardrail_redact_output_message: If a Bedrock Output guardrail triggers, replace output with this message.
85+
guardrail_last_turn_only: Flag to send only the last turn to guardrails instead of full conversation.
86+
Defaults to False.
8587
max_tokens: Maximum number of tokens to generate in the response
8688
model_id: The Bedrock model ID (e.g., "us.anthropic.claude-sonnet-4-20250514-v1:0")
8789
include_tool_result_status: Flag to include status field in tool results.
@@ -105,6 +107,7 @@ class BedrockConfig(TypedDict, total=False):
105107
guardrail_redact_input_message: Optional[str]
106108
guardrail_redact_output: Optional[bool]
107109
guardrail_redact_output_message: Optional[str]
110+
guardrail_last_turn_only: Optional[bool]
108111
max_tokens: Optional[int]
109112
model_id: str
110113
include_tool_result_status: Optional[Literal["auto"] | bool]
@@ -206,9 +209,19 @@ def _format_request(
206209
Returns:
207210
A Bedrock converse stream request.
208211
"""
212+
# Filter messages for guardrails if guardrail_last_turn_only is enabled
213+
messages_for_request = messages
214+
if (
215+
self.config.get("guardrail_last_turn_only", False)
216+
and self.config.get("guardrail_id")
217+
and self.config.get("guardrail_version")
218+
):
219+
messages_for_request = self._get_last_turn_messages(messages)
220+
209221
if not tool_specs:
210222
has_tool_content = any(
211-
any("toolUse" in block or "toolResult" in block for block in msg.get("content", [])) for msg in messages
223+
any("toolUse" in block or "toolResult" in block for block in msg.get("content", []))
224+
for msg in messages_for_request
212225
)
213226
if has_tool_content:
214227
tool_specs = [noop_tool.tool_spec]
@@ -224,7 +237,7 @@ def _format_request(
224237

225238
return {
226239
"modelId": self.config["model_id"],
227-
"messages": self._format_bedrock_messages(messages),
240+
"messages": self._format_bedrock_messages(messages_for_request),
228241
"system": system_blocks,
229242
**(
230243
{
@@ -295,6 +308,42 @@ def _format_request(
295308
),
296309
}
297310

311+
def _get_last_turn_messages(self, messages: Messages) -> Messages:
    """Extract the final conversation turn for guardrail evaluation.

    When guardrail_last_turn_only is enabled, only the most recent user
    message (plus the assistant message immediately preceding it, when one
    exists) is sent to guardrails instead of the full conversation history.

    Args:
        messages: Full conversation messages.

    Returns:
        Messages holding just the last turn, or an empty list when the
        conversation contains no user message.
    """
    # Locate the most recent user message, scanning from the end.
    last_user_index = next(
        (i for i in reversed(range(len(messages))) if messages[i]["role"] == "user"),
        None,
    )
    if last_user_index is None:
        # No user message anywhere (includes the empty-conversation case).
        return []

    last_turn: Messages = []
    preceding = last_user_index - 1
    if preceding >= 0 and messages[preceding]["role"] == "assistant":
        # Keep the assistant reply that immediately precedes the user message.
        last_turn.append(messages[preceding])
    last_turn.append(messages[last_user_index])
    return last_turn
346+
298347
def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]:
299348
"""Format messages for Bedrock API compatibility.
300349

tests/strands/models/test_bedrock.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2052,6 +2052,50 @@ def test_format_request_filters_output_schema(model, messages, model_id):
20522052
assert tool_spec["inputSchema"] == {"type": "object", "properties": {}}
20532053

20542054

2055+
def test_get_last_turn_messages(model):
    """Test _get_last_turn_messages helper method."""
    # An empty conversation yields an empty result.
    assert model._get_last_turn_messages([]) == []

    # A lone user message is returned unchanged.
    solo = [{"role": "user", "content": [{"text": "Hello"}]}]
    extracted = model._get_last_turn_messages(solo)
    assert len(extracted) == 1
    assert extracted[0]["role"] == "user"

    # With prior history, only the preceding assistant reply and the
    # latest user message survive the filtering.
    conversation = [
        {"role": "user", "content": [{"text": "Hello"}]},
        {"role": "assistant", "content": [{"text": "Hi"}]},
        {"role": "user", "content": [{"text": "How are you?"}]},
    ]
    extracted = model._get_last_turn_messages(conversation)
    assert len(extracted) == 2
    assert extracted[0]["role"] == "assistant"
    assert extracted[1]["role"] == "user"
    assert extracted[1]["content"][0]["text"] == "How are you?"
2077+
2078+
2079+
def test_format_request_with_guardrail_last_turn_only(model, model_id):
    """Test _format_request uses filtered messages when guardrail_last_turn_only=True."""
    model.update_config(guardrail_id="test-guardrail", guardrail_version="DRAFT", guardrail_last_turn_only=True)

    conversation = [
        {"role": "user", "content": [{"text": "First message"}]},
        {"role": "assistant", "content": [{"text": "First response"}]},
        {"role": "user", "content": [{"text": "Latest message"}]},
    ]

    request = model._format_request(conversation)

    # Only the final turn (assistant + user) should reach the request payload;
    # the opening exchange must have been filtered out.
    sent = request["messages"]
    assert [msg["role"] for msg in sent] == ["assistant", "user"]
    assert sent[-1]["content"][0]["text"] == "Latest message"
2097+
2098+
20552099
@pytest.mark.asyncio
20562100
async def test_stream_backward_compatibility_system_prompt(bedrock_client, model, messages, alist):
20572101
"""Test that system_prompt is converted to system_prompt_content when system_prompt_content is None."""

tests_integ/test_bedrock_guardrails.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,60 @@ def list_users() -> str:
289289
assert tool_result["content"][0]["text"] == INPUT_REDACT_MESSAGE
290290

291291

292+
def test_guardrail_last_turn_only(boto_session, bedrock_guardrail):
    """Test that guardrail_last_turn_only only sends the last turn to guardrails."""
    model = BedrockModel(
        guardrail_id=bedrock_guardrail,
        guardrail_version="DRAFT",
        guardrail_last_turn_only=True,
        boto_session=boto_session,
    )

    agent = Agent(model=model, system_prompt="You are a helpful assistant.", callback_handler=None)

    # A benign opening turn must pass through the guardrail untouched.
    benign_response = agent("Hello, how are you?")
    assert benign_response.stop_reason != "guardrail_intervened"

    # A follow-up containing the blocked word must still be caught: with
    # guardrail_last_turn_only=True only this message (plus the preceding
    # assistant reply) is evaluated, not the entire conversation history.
    blocked_response = agent("CACTUS")
    assert blocked_response.stop_reason == "guardrail_intervened"
    assert str(blocked_response).strip() == BLOCKED_INPUT
313+
314+
315+
def test_guardrail_last_turn_only_recovery_scenario(boto_session, bedrock_guardrail):
    """Verify recovery after a blocked turn when only the last turn is analyzed.

    This exercises the key benefit of guardrail_last_turn_only:
    a conversation that begins with blocked content should still allow a
    later benign question, because the guardrail never re-sees the earlier
    violation once the history is filtered down to the latest turn.
    """
    model = BedrockModel(
        guardrail_id=bedrock_guardrail,
        guardrail_version="DRAFT",
        guardrail_last_turn_only=True,
        boto_session=boto_session,
    )

    agent = Agent(model=model, system_prompt="You are a helpful assistant.", callback_handler=None)

    # Turn 1: the blocked word triggers the guardrail.
    first_response = agent("CACTUS")
    assert first_response.stop_reason == "guardrail_intervened"
    assert str(first_response).strip() == BLOCKED_INPUT

    # Turn 2: a benign question succeeds, since only the last turn is
    # evaluated — this is the recovery behavior under test.
    second_response = agent("What is the weather like today?")
    assert second_response.stop_reason != "guardrail_intervened"
    assert str(second_response).strip() != BLOCKED_INPUT

    # The full history is retained: 2 user + 2 assistant messages.
    assert len(agent.messages) == 4
344+
345+
292346
def test_guardrail_input_intervention_properly_redacts_in_session(boto_session, bedrock_guardrail, temp_dir):
293347
bedrock_model = BedrockModel(
294348
guardrail_id=bedrock_guardrail,

0 commit comments

Comments
 (0)