diff --git a/examples/mcp-client/langchain/main.py b/examples/mcp-client/langchain/main.py index 88b00d0..ac54cf6 100644 --- a/examples/mcp-client/langchain/main.py +++ b/examples/mcp-client/langchain/main.py @@ -26,12 +26,6 @@ - MCP_REGION: AWS region where the MCP server is hosted (e.g., "us-west-2") 3. Run: `uv run main.py` -Example .env file: -================== -MCP_SERVER_URL=https://example.gateway.bedrock-agentcore.us-west-2.amazonaws.com/mcp -MCP_SERVER_AWS_SERVICE=bedrock-agentcore -MCP_SERVER_REGION=us-west-2 - Example .env file: ================== MCP_SERVER_URL=https://example.gateway.bedrock-agentcore.us-west-2.amazonaws.com/mcp diff --git a/mcp_proxy_for_aws/middleware/initialize_middleware.py b/mcp_proxy_for_aws/middleware/initialize_middleware.py index 5d74def..3a314f2 100644 --- a/mcp_proxy_for_aws/middleware/initialize_middleware.py +++ b/mcp_proxy_for_aws/middleware/initialize_middleware.py @@ -9,7 +9,7 @@ class InitializeMiddleware(Middleware): - """Intecept MCP initialize request and initialize the proxy client.""" + """Intercept MCP initialize request and initialize the proxy client.""" def __init__(self, client_factory: AWSMCPProxyClientFactory) -> None: """Create a middleware with client factory.""" diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index 441d1ce..d03feaa 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -1,3 +1,17 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import logging import os import pytest @@ -90,3 +104,15 @@ def _build_endpoint_environment_remote_configuration(): endpoint=remote_endpoint_url, region_name=region_name, ) + + +@pytest_asyncio.fixture(loop_scope='module', scope='module') +async def aws_mcp_client(): + """Create MCP Client for AWS MCP Server.""" + client = build_mcp_client( + endpoint='https://aws-mcp.us-east-1.api.aws/mcp', + region_name='us-east-1', + ) + + async with client: + yield client diff --git a/tests/integ/mcp/simple_mcp_client.py b/tests/integ/mcp/simple_mcp_client.py index 6657c63..db2c133 100644 --- a/tests/integ/mcp/simple_mcp_client.py +++ b/tests/integ/mcp/simple_mcp_client.py @@ -1,3 +1,17 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import boto3 import fastmcp import logging @@ -27,7 +41,7 @@ def build_mcp_client( **_build_mcp_config(endpoint=endpoint, region_name=region_name, metadata=metadata) ), elicitation_handler=_basic_elicitation_handler, - timeout=10.0, # seconds + timeout=20.0, # seconds ) diff --git a/tests/integ/test_aws_mcp_server_happy_path.py b/tests/integ/test_aws_mcp_server_happy_path.py new file mode 100644 index 0000000..ac1ca31 --- /dev/null +++ b/tests/integ/test_aws_mcp_server_happy_path.py @@ -0,0 +1,154 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Happy path integration tests for AWS MCP Server at https://aws-mcp.us-east-1.api.aws/mcp.""" + +import fastmcp +import json +import logging +import pytest +from fastmcp.client.client import CallToolResult +from tests.integ.test_proxy_simple_mcp_server import get_text_content + + +logger = logging.getLogger(__name__) + + +@pytest.mark.asyncio(loop_scope='module') +async def test_aws_mcp_ping(aws_mcp_client: fastmcp.Client): + """Test ping to AWS MCP Server.""" + await aws_mcp_client.ping() + + +@pytest.mark.asyncio(loop_scope='module') +async def test_aws_mcp_list_tools(aws_mcp_client: fastmcp.Client): + """Test list tools from AWS MCP Server.""" + tools = await aws_mcp_client.list_tools() + + assert len(tools) > 0, f'AWS MCP Server should have tools (got {len(tools)})' + + +def verify_response_content(response: CallToolResult): + """Verify that a tool call response is successful and contains text content. + + Args: + response: The CallToolResult from an MCP tool call + + Returns: + str: The extracted text content from the response + + Raises: + AssertionError: If response indicates an error or has empty content + """ + assert response.is_error is False, ( + f'is_error returned true. Returned response body: {response}.' + ) + assert len(response.content) > 0, f'Empty result list in response. Response: {response}' + + response_text = get_text_content(response) + assert len(response_text) > 0, f'Empty response text. Response: {response}' + + return response_text + + +def verify_json_response(response: CallToolResult): + """Verify that a tool call response is successful and contains valid JSON data. + + Args: + response: The CallToolResult from an MCP tool call + + Raises: + AssertionError: If response indicates an error, has empty content, + contains invalid JSON, or JSON data is empty + """ + response_text = verify_response_content(response) + + # Verify response_text is valid JSON + try: + response_data = json.loads(response_text) + except json.JSONDecodeError: + raise AssertionError(f'Response text is not valid JSON. Response text: {response_text}') + + assert len(response_data) > 0, f'Empty JSON content in response. Response: {response}' + + +@pytest.mark.parametrize( + 'tool_name,params', + [ + ('aws___list_regions', {}), + ('aws___suggest_aws_commands', {'query': 'how to list my lambda functions'}), + ('aws___search_documentation', {'search_phrase': 'S3 bucket versioning'}), + ( + 'aws___recommend', + {'url': 'https://docs.aws.amazon.com/lambda/latest/dg/lambda-invocation.html'}, + ), + ( + 'aws___read_documentation', + {'url': 'https://docs.aws.amazon.com/lambda/latest/dg/lambda-invocation.html'}, + ), + ( + 'aws___get_regional_availability', + {'resource_type': 'cfn', 'region': 'us-east-1'}, + ), + ('aws___call_aws', {'cli_command': 'aws s3 ls', 'max_results': 10}), + ], + ids=[ + 'list_regions', + 'suggest_aws_commands', + 'search_documentation', + 'recommend', + 'read_documentation', + 'get_regional_availability', + 'call_aws', + ], +) +@pytest.mark.asyncio(loop_scope='module') +async def test_aws_mcp_tools(aws_mcp_client: fastmcp.Client, tool_name: str, params: dict): + """Test AWS MCP tools with minimal valid params.""" + response = await aws_mcp_client.call_tool(tool_name, params) + verify_json_response(response) + + +@pytest.mark.asyncio(loop_scope='module') +async def test_aws_mcp_tools_retrieve_agent_sop(aws_mcp_client: fastmcp.Client): + """Test aws___retrieve_agent_sop by retrieving the list of available SOPs.""" + # Step 1: Call retrieve_agent_sop with empty params to get list of available SOPs + list_sops_response = await aws_mcp_client.call_tool('aws___retrieve_agent_sop') + + list_sops_response_text = verify_response_content(list_sops_response) + + # Parse SOP names from text (format: "* sop_name : description") + sop_names = [] + for line in list_sops_response_text.split('\n'): + line = line.strip() + if line.startswith('*') and ':' in line: + # Extract the SOP name between '*' and ':' + sop_name = line.split('*', 1)[1].split(':', 1)[0].strip() + if sop_name: + sop_names.append(sop_name) + + assert len(sop_names) > 0, ( + f'No SOPs found in response. Response: {list_sops_response_text[:200]}...' + ) + logger.info('Found %d SOPs: %s', len(sop_names), sop_names) + + # Step 2: Test retrieving the first SOP + test_script = sop_names[0] + logger.info('Testing with SOP: %s', test_script) + + response = await aws_mcp_client.call_tool( + 'aws___retrieve_agent_sop', {'sop_name': test_script} + ) + + verify_response_content(response) diff --git a/tests/integ/test_aws_mcp_server_negative.py b/tests/integ/test_aws_mcp_server_negative.py new file mode 100644 index 0000000..f07a79c --- /dev/null +++ b/tests/integ/test_aws_mcp_server_negative.py @@ -0,0 +1,101 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Negative integration tests for AWS MCP Server at https://aws-mcp.us-east-1.api.aws/mcp.""" + +import boto3 +import fastmcp +import logging +import pytest +from fastmcp.client import StdioTransport + + +logger = logging.getLogger(__name__) + + +@pytest.mark.asyncio(loop_scope='module') +async def test_expired_credentials(): + """Test that expired credentials are properly rejected. + + This test uses real AWS credentials but modifies the session token to simulate + an expired token, which should result in an 'expired token' error message. + + This test will: + - PASS when expired credentials are rejected with appropriate error + - FAIL if the modified credentials somehow work + """ + # Get real credentials from boto3 + session = boto3.Session() + creds = session.get_credentials() + + # Use real access key and secret, but modify the token to simulate expiration by changing a few characters + expired_token = 'EXPIRED_TOKEN_12345' + + expired_client = fastmcp.Client( + StdioTransport( + command='mcp-proxy-for-aws', + args=[ + 'https://aws-mcp.us-east-1.api.aws/mcp', + '--log-level', + 'DEBUG', + '--region', + 'us-east-1', + ], + env={ + 'AWS_REGION': 'us-east-1', + 'AWS_ACCESS_KEY_ID': creds.access_key, + 'AWS_SECRET_ACCESS_KEY': creds.secret_key, + 'AWS_SESSION_TOKEN': expired_token, + }, + ), + timeout=30.0, + ) + + exception_raised = False + exception_message = '' + + try: + async with expired_client: + response = await expired_client.call_tool('aws___list_regions') + logger.info('Tool call completed without exception. Response: %s', response) + except Exception as e: + exception_raised = True + exception_message = str(e) + logger.info('Exception raised as expected: %s: %s', type(e).__name__, exception_message) + + # Assert that an exception was raised (credentials are invalid) + assert exception_raised, ( + 'Expected authentication exception with invalid credentials, but tool call succeeded.' + ) + + # Verify the exception is related to authentication/credentials + error_message_lower = exception_message.lower() + auth_error_patterns = [ + 'credential', + 'authentication', + 'authorization', + 'access denied', + 'unauthorized', + 'invalid', + 'expired', + 'signature', + '401', + ] + + assert any(pattern in error_message_lower for pattern in auth_error_patterns), ( + f"Exception was raised but doesn't appear to be authentication-related. " + f'Expected one of {auth_error_patterns}, but got: {exception_message[:200]}' + ) + + logger.info('Test passed: Invalid credentials correctly rejected')