Skip to content

Commit 32bee81

Browse files
committed
Merge branch 'main' into aws-mcp-integration-tests
2 parents 88ab7c3 + 92078b8 commit 32bee81

File tree

11 files changed

+364
-12
lines changed

11 files changed

+364
-12
lines changed

CHANGELOG.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,22 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## v1.1.4 (2025-12-04)
9+
10+
### Fix
11+
12+
- do not call initialize for q dev cli / kiro cli
13+
- patch fastmcp lowlevel session method
14+
- connect remote mcp client immediately in the initialize middleware
15+
16+
## v1.1.3 (2025-12-04)
17+
18+
### Fix
19+
20+
- avoid infinite recursion (#111)
21+
- init client and show errors in all clients (#108)
22+
- set the fastmcp log level to be the same as the proxy (#101)
23+
824
## v1.1.2 (2025-11-27)
925

1026
### Fix

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,3 +315,6 @@ Licensed under the Apache License, Version 2.0 (the "License").
315315
## Disclaimer
316316

317317
LLMs are non-deterministic and they make mistakes, we advise you to always thoroughly test and follow the best practices of your organization before using these tools on customer facing accounts. Users of this package are solely responsible for implementing proper security controls and MUST use AWS Identity and Access Management (IAM) to manage access to AWS resources. You are responsible for configuring appropriate IAM policies, roles, and permissions, and any security vulnerabilities resulting from improper IAM configuration are your sole responsibility. By using this package, you acknowledge that you have read and understood this disclaimer and agree to use the package at your own risk.
318+
319+
<!-- mcp-name: io.github.aws/mcp-proxy-for-aws -->
320+
<!-- mcp-name: io.github.aws/aws-mcp -->

mcp_proxy_for_aws/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
from importlib.metadata import version as _metadata_version
1818

19+
import mcp_proxy_for_aws.fastmcp_patch as _fastmcp_patch
20+
1921

2022
__all__ = ['__version__']
2123
__version__ = _metadata_version('mcp-proxy-for-aws')

mcp_proxy_for_aws/fastmcp_patch.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import fastmcp.server.low_level as low_level_module
2+
import mcp.types
3+
from functools import wraps
4+
from mcp import McpError
5+
from mcp.server.stdio import stdio_server as stdio_server
6+
from mcp.shared.session import RequestResponder
7+
8+
9+
original_receive_request = low_level_module.MiddlewareServerSession._received_request
10+
11+
12+
@wraps(original_receive_request)
13+
async def _received_request(
14+
self,
15+
responder: RequestResponder[mcp.types.ClientRequest, mcp.types.ServerResult],
16+
):
17+
"""Monkey patch fastmcp so that the initialize error from the middleware can be send back to the client.
18+
19+
https://github.com/jlowin/fastmcp/pull/2531
20+
"""
21+
if isinstance(responder.request.root, mcp.types.InitializeRequest):
22+
try:
23+
return await original_receive_request(self, responder)
24+
except McpError as e:
25+
if not responder._completed:
26+
with responder:
27+
return await responder.respond(e.error)
28+
29+
raise e
30+
else:
31+
return await original_receive_request(self, responder)
32+
33+
34+
low_level_module.MiddlewareServerSession._received_request = _received_request

mcp_proxy_for_aws/middleware/initialize_middleware.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,20 @@ async def on_initialize(
2525
try:
2626
logger.debug('Received initialize request %s.', context.message)
2727
self._client_factory.set_init_params(context.message)
28+
client = await self._client_factory.get_client()
29+
# connect the http client, fail and don't succeed the stdio connect
30+
# if remote client cannot be connected
31+
client_name = context.message.params.clientInfo.name.lower()
32+
if 'kiro cli' not in client_name and 'q dev cli' not in client_name:
33+
# q cli / kiro cli uses the rust SDK which does not handle json rpc error
34+
# properly during initialization.
35+
# https://github.com/modelcontextprotocol/rust-sdk/pull/569
36+
# if calling _connect below raise mcp error, the q cli will skip the message
37+
# and continue wait for a json rpc response message which will never come.
38+
# Luckily, q cli calls list tool immediately after being connected to a mcp server
39+
# the list_tool call will require the client to be connected again, so the mcp error
40+
# will be displayed in the q cli logs.
41+
await client._connect()
2842
return await call_next(context)
2943
except Exception:
3044
logger.exception('Initialize failed in middleware.')

mcp_proxy_for_aws/proxy.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,13 @@ def __init__(
7878
class AWSMCPProxyClient(_ProxyClient):
7979
"""Proxy client that handles HTTP errors when connection fails."""
8080

81-
def __init__(self, transport: ClientTransport, **kwargs):
81+
def __init__(self, transport: ClientTransport, max_connect_retry=3, **kwargs):
8282
"""Constructor of AutoRefreshProxyCilent."""
8383
super().__init__(transport, **kwargs)
84+
self._max_connect_retry = max_connect_retry
8485

8586
@override
86-
async def _connect(self):
87+
async def _connect(self, retry=0):
8788
"""Enter as normal && initialize only once."""
8889
logger.debug('Connecting %s', self)
8990
try:
@@ -96,27 +97,36 @@ async def _connect(self):
9697
try:
9798
body = await response.aread()
9899
jsonrpc_msg = JSONRPCMessage.model_validate_json(body).root
99-
except Exception:
100-
logger.debug('HTTP error is not a valid MCP message.')
100+
except Exception as e:
101+
logger.debug('HTTP error is not a valid MCP message.', exc_info=e)
101102
raise http_error
102103

103104
if isinstance(jsonrpc_msg, JSONRPCError):
104-
logger.debug('Converting HTTP error to MCP error %s', http_error)
105+
logger.debug('Converting HTTP error to MCP error', exc_info=http_error)
105106
# raising McpError so that the sdk can handle the exception properly
106107
raise McpError(error=jsonrpc_msg.error) from http_error
107108
else:
108109
raise http_error
109-
except RuntimeError:
110+
except RuntimeError as e:
111+
if isinstance(e.__cause__, McpError):
112+
raise e.__cause__
113+
114+
if retry > self._max_connect_retry:
115+
raise e
116+
110117
try:
111-
logger.warning('encountered runtime error, try force disconnect.')
118+
logger.warning('encountered runtime error, try force disconnect.', exc_info=e)
112119
await self._disconnect(force=True)
113-
except Exception:
120+
except httpx.TimeoutException:
114121
# _disconnect awaits on the session_task,
115122
# which raises the timeout error that caused the client session to be terminated.
116123
# the error is ignored as long as the counter is force set to 0.
117124
# TODO: investigate how timeout error is handled by fastmcp and httpx
118-
logger.exception('encountered another error, ignoring.')
119-
return await self._connect()
125+
logger.exception(
126+
'Session was terminated due to timeout error, ignore and reconnect'
127+
)
128+
129+
return await self._connect(retry + 1)
120130

121131
async def __aexit__(self, exc_type, exc_val, exc_tb):
122132
"""The MCP Proxy for AWS project is a proxy from stdio to http (sigv4).

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ members = [
1010
name = "mcp-proxy-for-aws"
1111

1212
# NOTE: "Patch"=9223372036854775807 bumps next release to zero.
13-
version = "1.1.2"
13+
version = "1.1.4"
1414

1515
description = "MCP Proxy for AWS"
1616
readme = "README.md"

tests/unit/test_fastmcp_patch.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import mcp.types as mt
2+
import pytest
3+
from mcp import McpError
4+
from mcp.shared.session import RequestResponder
5+
from unittest.mock import AsyncMock, Mock, patch
6+
7+
8+
@pytest.mark.asyncio
9+
async def test_patched_received_request_initialize_success():
10+
"""Test that patched _received_request calls original for successful initialize."""
11+
# Import after patching is applied
12+
import fastmcp.server.low_level as low_level_module
13+
from mcp_proxy_for_aws import fastmcp_patch
14+
15+
mock_self = Mock()
16+
mock_self.fastmcp = Mock()
17+
18+
mock_request = Mock()
19+
mock_request.root = Mock(spec=mt.InitializeRequest)
20+
21+
mock_responder = Mock(spec=RequestResponder)
22+
mock_responder.request = mock_request
23+
24+
with patch.object(
25+
fastmcp_patch, 'original_receive_request', new_callable=AsyncMock
26+
) as mock_original:
27+
await low_level_module.MiddlewareServerSession._received_request(mock_self, mock_responder)
28+
mock_original.assert_called_once_with(mock_self, mock_responder)
29+
30+
31+
@pytest.mark.asyncio
32+
async def test_patched_received_request_initialize_mcp_error_not_completed():
33+
"""Test that patched _received_request handles McpError when responder not completed."""
34+
import fastmcp.server.low_level as low_level_module
35+
from mcp_proxy_for_aws import fastmcp_patch
36+
37+
mock_self = Mock()
38+
mock_self.fastmcp = Mock()
39+
40+
mock_request = Mock()
41+
mock_request.root = Mock(spec=mt.InitializeRequest)
42+
43+
mock_responder = Mock(spec=RequestResponder)
44+
mock_responder.request = mock_request
45+
mock_responder._completed = False
46+
mock_responder.__enter__ = Mock(return_value=mock_responder)
47+
mock_responder.__exit__ = Mock(return_value=False)
48+
mock_responder.respond = AsyncMock()
49+
50+
error = mt.ErrorData(code=1, message='test error')
51+
mcp_error = McpError(error=error)
52+
53+
with patch.object(
54+
fastmcp_patch, 'original_receive_request', new_callable=AsyncMock, side_effect=mcp_error
55+
):
56+
await low_level_module.MiddlewareServerSession._received_request(mock_self, mock_responder)
57+
mock_responder.respond.assert_called_once_with(error)
58+
59+
60+
@pytest.mark.asyncio
61+
async def test_patched_received_request_initialize_mcp_error_completed():
62+
"""Test that patched _received_request re-raises McpError when responder completed."""
63+
import fastmcp.server.low_level as low_level_module
64+
from mcp_proxy_for_aws import fastmcp_patch
65+
66+
mock_self = Mock()
67+
mock_self.fastmcp = Mock()
68+
69+
mock_request = Mock()
70+
mock_request.root = Mock(spec=mt.InitializeRequest)
71+
72+
mock_responder = Mock(spec=RequestResponder)
73+
mock_responder.request = mock_request
74+
mock_responder._completed = True
75+
76+
error = mt.ErrorData(code=1, message='test error')
77+
mcp_error = McpError(error=error)
78+
79+
with patch.object(
80+
fastmcp_patch, 'original_receive_request', new_callable=AsyncMock, side_effect=mcp_error
81+
):
82+
with pytest.raises(McpError):
83+
await low_level_module.MiddlewareServerSession._received_request(
84+
mock_self, mock_responder
85+
)
86+
87+
88+
@pytest.mark.asyncio
89+
async def test_patched_received_request_non_initialize():
90+
"""Test that patched _received_request calls original for non-initialize requests."""
91+
import fastmcp.server.low_level as low_level_module
92+
from mcp_proxy_for_aws import fastmcp_patch
93+
94+
mock_self = Mock()
95+
96+
mock_request = Mock()
97+
mock_request.root = Mock(spec=mt.CallToolRequest)
98+
99+
mock_responder = Mock(spec=RequestResponder)
100+
mock_responder.request = mock_request
101+
102+
with patch.object(
103+
fastmcp_patch, 'original_receive_request', new_callable=AsyncMock
104+
) as mock_original:
105+
await low_level_module.MiddlewareServerSession._received_request(mock_self, mock_responder)
106+
mock_original.assert_called_once_with(mock_self, mock_responder)
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import mcp.types as mt
2+
import pytest
3+
from mcp_proxy_for_aws.middleware.initialize_middleware import InitializeMiddleware
4+
from unittest.mock import AsyncMock, Mock
5+
6+
7+
def create_initialize_request(client_name: str) -> mt.InitializeRequest:
8+
"""Create a real InitializeRequest object."""
9+
return mt.InitializeRequest(
10+
method='initialize',
11+
params=mt.InitializeRequestParams(
12+
protocolVersion='2024-11-05',
13+
capabilities=mt.ClientCapabilities(),
14+
clientInfo=mt.Implementation(name=client_name, version='1.0'),
15+
),
16+
)
17+
18+
19+
@pytest.mark.asyncio
20+
async def test_on_initialize_connects_client():
21+
"""Test that on_initialize calls client._connect()."""
22+
mock_client = Mock()
23+
mock_client._connect = AsyncMock()
24+
25+
mock_factory = Mock()
26+
mock_factory.set_init_params = Mock()
27+
mock_factory.get_client = AsyncMock(return_value=mock_client)
28+
29+
middleware = InitializeMiddleware(mock_factory)
30+
31+
mock_context = Mock()
32+
mock_context.message = create_initialize_request('test-client')
33+
34+
mock_call_next = AsyncMock()
35+
36+
await middleware.on_initialize(mock_context, mock_call_next)
37+
38+
mock_factory.set_init_params.assert_called_once_with(mock_context.message)
39+
mock_factory.get_client.assert_called_once()
40+
mock_client._connect.assert_called_once()
41+
mock_call_next.assert_called_once_with(mock_context)
42+
43+
44+
@pytest.mark.asyncio
45+
async def test_on_initialize_fails_if_connect_fails():
46+
"""Test that on_initialize raises exception if _connect() fails."""
47+
mock_client = Mock()
48+
mock_client._connect = AsyncMock(side_effect=Exception('Connection failed'))
49+
50+
mock_factory = Mock()
51+
mock_factory.set_init_params = Mock()
52+
mock_factory.get_client = AsyncMock(return_value=mock_client)
53+
54+
middleware = InitializeMiddleware(mock_factory)
55+
56+
mock_context = Mock()
57+
mock_context.message = create_initialize_request('test-client')
58+
59+
mock_call_next = AsyncMock()
60+
61+
with pytest.raises(Exception, match='Connection failed'):
62+
await middleware.on_initialize(mock_context, mock_call_next)
63+
64+
mock_call_next.assert_not_called()
65+
66+
67+
@pytest.mark.asyncio
68+
@pytest.mark.parametrize(
69+
'client_name',
70+
[
71+
'Kiro CLI',
72+
'kiro cli',
73+
'KIRO CLI',
74+
'Amazon Q Dev CLI',
75+
'amazon q dev cli',
76+
'Q DEV CLI',
77+
],
78+
)
79+
async def test_on_initialize_skips_connect_for_special_clients(client_name):
80+
"""Test that on_initialize skips _connect() for Kiro CLI and Q Dev CLI."""
81+
mock_client = Mock()
82+
mock_client._connect = AsyncMock()
83+
84+
mock_factory = Mock()
85+
mock_factory.set_init_params = Mock()
86+
mock_factory.get_client = AsyncMock(return_value=mock_client)
87+
88+
middleware = InitializeMiddleware(mock_factory)
89+
90+
mock_context = Mock()
91+
mock_context.message = create_initialize_request(client_name)
92+
93+
mock_call_next = AsyncMock()
94+
95+
await middleware.on_initialize(mock_context, mock_call_next)
96+
97+
mock_client._connect.assert_not_called()
98+
mock_call_next.assert_called_once_with(mock_context)

0 commit comments

Comments
 (0)