diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index cd96a7566..cd199052a 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -467,6 +467,10 @@ async def _initialize(self) -> None: # pragma: no cover """Load stored tokens and client info.""" self.context.current_tokens = await self.context.storage.get_tokens() self.context.client_info = await self.context.storage.get_client_info() + + if self.context.current_tokens and self.context.current_tokens.expires_in is not None: + self.context.update_token_expiry(self.context.current_tokens) + self._initialized = True def _add_auth_header(self, request: httpx.Request) -> None: diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 593d5cfe0..2589fe83c 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -82,6 +82,17 @@ def valid_tokens(): ) +@pytest.fixture +def expired_tokens(): + return OAuthToken( + access_token="test_access_token", + token_type="Bearer", + expires_in=-100, # Expired 100 seconds ago + refresh_token="test_refresh_token", + scope="read write", + ) + + @pytest.fixture def oauth_provider(client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage): async def redirect_handler(url: str) -> None: @@ -259,6 +270,100 @@ def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: O assert context.token_expiry_time is None +class TestTokenInitialization: + """Test token loading from storage during initialization.""" + + @pytest.mark.anyio + async def test_initialize_sets_token_expiry_from_stored_tokens( + self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken + ): + """Test _initialize() sets token_expiry_time when loading tokens from storage.""" + context = oauth_provider.context + await context.storage.set_tokens(valid_tokens) + + # Before initialization + assert oauth_provider._initialized is False + assert context.current_tokens is None + assert context.token_expiry_time is None + + # Trigger initialization by starting auth flow + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = oauth_provider.async_auth_flow(test_request) + + # First request calls _initialize() + request = await auth_flow.__anext__() + + # After first request, verify tokens were loaded + assert oauth_provider._initialized is True + assert oauth_provider.context.current_tokens is not None + assert oauth_provider.context.current_tokens.access_token == "test_access_token" + + # token_expiry_time should be set by update_token_expiry() + assert oauth_provider.context.token_expiry_time is not None + + # Verify token is considered valid + assert oauth_provider.context.is_token_valid() is True + + # Request should have auth header added + assert request.headers["Authorization"] == "Bearer test_access_token" + + # Complete the flow + response = httpx.Response(200, request=request) + try: + await auth_flow.asend(response) + except StopAsyncIteration: + pass + + @pytest.mark.anyio + async def test_initialize_with_expired_tokens_detects_expiry( + self, oauth_provider: OAuthClientProvider, expired_tokens: OAuthToken + ): + """Test that expired tokens loaded from storage are detected as invalid.""" + context = oauth_provider.context + await context.storage.set_tokens(expired_tokens) + await context.storage.set_client_info( + OAuthClientInformationFull( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + ) + + # First request + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = oauth_provider.async_auth_flow(test_request) + + # This should trigger a refresh attempt, not the original request + refresh_request = await auth_flow.__anext__() + + # Verify tokens were loaded + assert context.current_tokens is not None + + # token_expiry_time should be set by update_token_expiry() + assert context.token_expiry_time is not None + + # Token should be detected as invalid (expired) + assert context.is_token_valid() is False + + # Should be able to refresh + assert context.can_refresh_token() is True + + # Complete the flow + refresh_response = httpx.Response( + 200, + content=b'{"access_token": "new_token", "token_type": "Bearer", "expires_in": 3600}', + request=refresh_request, + ) + try: + original_request = await auth_flow.asend(refresh_response) + # Should retry original request with new token + assert original_request.headers["Authorization"] == "Bearer new_token" + final_response = httpx.Response(200, request=original_request) + await auth_flow.asend(final_response) + except StopAsyncIteration: + pass + + class TestOAuthFlow: """Test OAuth flow methods."""