From 1e407ce30b1cdcd5b0fd91e55df485d2c94b4e3c Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Tue, 16 Dec 2025 15:40:47 -0800 Subject: [PATCH 01/10] Client-side polling via `Last-Event-ID` --- .../Client/HttpClientTransportOptions.cs | 13 ++ .../StreamableHttpClientSessionTransport.cs | 182 +++++++++++++++--- 2 files changed, 164 insertions(+), 31 deletions(-) diff --git a/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs b/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs index 624a14aa1..582a6f8e0 100644 --- a/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs +++ b/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs @@ -106,4 +106,17 @@ public required Uri Endpoint /// Gets sor sets the authorization provider to use for authentication. /// public ClientOAuthOptions? OAuth { get; set; } + + /// + /// Gets or sets the maximum number of reconnection attempts when an SSE stream is disconnected. + /// + /// + /// The maximum number of reconnection attempts. The default is 2. + /// + /// + /// When an SSE stream is disconnected (e.g., due to a network issue), the client will attempt to + /// reconnect using the Last-Event-ID header to resume from where it left off. This property controls + /// how many reconnection attempts are made before giving up. + /// + public int MaxReconnectionAttempts { get; set; } = 2; } diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs index 534249038..e36d3e966 100644 --- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs @@ -16,6 +16,8 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa private static readonly MediaTypeWithQualityHeaderValue s_applicationJsonMediaType = new("application/json"); private static readonly MediaTypeWithQualityHeaderValue s_textEventStreamMediaType = new("text/event-stream"); + private static readonly TimeSpan s_defaultReconnectionDelay = TimeSpan.FromSeconds(1); + private readonly McpHttpClient _httpClient; private readonly HttpClientTransportOptions _options; private readonly CancellationTokenSource _connectionCts; @@ -105,8 +107,18 @@ internal async Task SendHttpRequestAsync(JsonRpcMessage mes } else if (response.Content.Headers.ContentType?.MediaType == "text/event-stream") { - using var responseBodyStream = await response.Content.ReadAsStreamAsync(cancellationToken); - rpcResponseOrError = await ProcessSseResponseAsync(responseBodyStream, rpcRequest, cancellationToken).ConfigureAwait(false); + var sseState = new SseStreamState(); + using var responseBodyStream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + var sseResponse = await ProcessSseResponseAsync(responseBodyStream, rpcRequest, sseState, cancellationToken).ConfigureAwait(false); + rpcResponseOrError = sseResponse.Response; + + // Resumability: If POST SSE stream ended without a response but we have a Last-Event-ID (from priming), + // attempt to resume by sending a GET request with Last-Event-ID header. The server will replay + // events from the event store, allowing us to receive the pending response. + if (rpcResponseOrError is null && rpcRequest is not null && sseState.LastEventId is not null) + { + rpcResponseOrError = await SendGetSseRequestWithRetriesAsync(rpcRequest, sseState, cancellationToken).ConfigureAwait(false); + } } if (rpcRequest is null) @@ -188,56 +200,140 @@ public override async ValueTask DisposeAsync() private async Task ReceiveUnsolicitedMessagesAsync() { - // Send a GET request to handle any unsolicited messages not sent over a POST response. - using var request = new HttpRequestMessage(HttpMethod.Get, _options.Endpoint); - request.Headers.Accept.Add(s_textEventStreamMediaType); - CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion); + var state = new SseStreamState(); - // Server support for the GET request is optional. If it fails, we don't care. It just means we won't receive unsolicited messages. - HttpResponseMessage response; - try + // Continuously receive unsolicited messages until canceled + while (!_connectionCts.Token.IsCancellationRequested) { - response = await _httpClient.SendAsync(request, message: null, _connectionCts.Token).ConfigureAwait(false); - } - catch (HttpRequestException) - { - return; - } + await SendGetSseRequestWithRetriesAsync( + relatedRpcRequest: null, + state, + _connectionCts.Token).ConfigureAwait(false); - using (response) - { - if (!response.IsSuccessStatusCode) + // If we exhausted retries without receiving any events, stop trying + if (state.LastEventId is null) { return; } - - using var responseStream = await response.Content.ReadAsStreamAsync(_connectionCts.Token).ConfigureAwait(false); - await ProcessSseResponseAsync(responseStream, relatedRpcRequest: null, _connectionCts.Token).ConfigureAwait(false); } } - private async Task ProcessSseResponseAsync(Stream responseStream, JsonRpcRequest? relatedRpcRequest, CancellationToken cancellationToken) + /// + /// Sends a GET request for SSE with retry logic and resumability support. + /// + private async Task SendGetSseRequestWithRetriesAsync( + JsonRpcRequest? relatedRpcRequest, + SseStreamState state, + CancellationToken cancellationToken) { - await foreach (SseItem sseEvent in SseParser.Create(responseStream).EnumerateAsync(cancellationToken).ConfigureAwait(false)) + int attempt = 0; + + // Delay before first attempt if we're reconnecting (have a Last-Event-ID) + bool shouldDelay = state.LastEventId is not null; + + while (attempt < _options.MaxReconnectionAttempts) { - if (sseEvent.EventType != "message") + cancellationToken.ThrowIfCancellationRequested(); + + if (shouldDelay) { - continue; + var delay = state.RetryInterval ?? s_defaultReconnectionDelay; + await Task.Delay(delay, cancellationToken).ConfigureAwait(false); } + shouldDelay = true; - var rpcResponseOrError = await ProcessMessageAsync(sseEvent.Data, relatedRpcRequest, cancellationToken).ConfigureAwait(false); + using var request = new HttpRequestMessage(HttpMethod.Get, _options.Endpoint); + request.Headers.Accept.Add(s_textEventStreamMediaType); + CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion, state.LastEventId); - // The server SHOULD end the HTTP response body here anyway, but we won't leave it to chance. This transport makes - // a GET request for any notifications that might need to be sent after the completion of each POST. - if (rpcResponseOrError is not null) + HttpResponseMessage response; + try { - return rpcResponseOrError; + response = await _httpClient.SendAsync(request, message: null, cancellationToken).ConfigureAwait(false); + } + catch (HttpRequestException) + { + attempt++; + continue; + } + + using (response) + { + if (!response.IsSuccessStatusCode) + { + // If the server could be reached but returned a non-success status code, + // retrying likely won't change that. + return null; + } + + using var responseStream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + var sseResponse = await ProcessSseResponseAsync(responseStream, relatedRpcRequest, state, cancellationToken).ConfigureAwait(false); + + if (sseResponse.Response is { } rpcResponseOrError) + { + return rpcResponseOrError; + } + + // If we reach here, then the stream closed without the response. + + if (sseResponse.IsNetworkError || state.LastEventId is null) + { + // No event ID means server may not support resumability; don't retry indefinitely. + attempt++; + } + else + { + // We have an event ID, so we continue polling to receive more events. + // The server should eventually send a response or return an error. + attempt = 0; + } } } return null; } + private async Task ProcessSseResponseAsync( + Stream responseStream, + JsonRpcRequest? relatedRpcRequest, + SseStreamState state, + CancellationToken cancellationToken) + { + try + { + await foreach (SseItem sseEvent in SseParser.Create(responseStream).EnumerateAsync(cancellationToken).ConfigureAwait(false)) + { + // Track event ID and retry interval for resumability + if (!string.IsNullOrEmpty(sseEvent.EventId)) + { + state.LastEventId = sseEvent.EventId; + } + if (sseEvent.ReconnectionInterval.HasValue) + { + state.RetryInterval = sseEvent.ReconnectionInterval.Value; + } + + // Skip events with empty data + if (string.IsNullOrEmpty(sseEvent.Data)) + { + continue; + } + + var rpcResponseOrError = await ProcessMessageAsync(sseEvent.Data, relatedRpcRequest, cancellationToken).ConfigureAwait(false); + if (rpcResponseOrError is not null) + { + return new() { Response = rpcResponseOrError }; + } + } + } + catch (Exception ex) when (ex is IOException or HttpRequestException) + { + return new() { IsNetworkError = true }; + } + + return default; + } + private async Task ProcessMessageAsync(string data, JsonRpcRequest? relatedRpcRequest, CancellationToken cancellationToken) { LogTransportReceivedMessageSensitive(Name, data); @@ -292,7 +388,8 @@ internal static void CopyAdditionalHeaders( HttpRequestHeaders headers, IDictionary? additionalHeaders, string? sessionId, - string? protocolVersion) + string? protocolVersion, + string? lastEventId = null) { if (sessionId is not null) { @@ -304,6 +401,11 @@ internal static void CopyAdditionalHeaders( headers.Add("MCP-Protocol-Version", protocolVersion); } + if (lastEventId is not null) + { + headers.Add("Last-Event-ID", lastEventId); + } + if (additionalHeaders is null) { return; @@ -317,4 +419,22 @@ internal static void CopyAdditionalHeaders( } } } + + /// + /// Tracks state across SSE stream connections. + /// + private sealed class SseStreamState + { + public string? LastEventId { get; set; } + public TimeSpan? RetryInterval { get; set; } + } + + /// + /// Represents the result of processing an SSE response. + /// + private readonly struct SseResponse + { + public JsonRpcMessageWithId? Response { get; init; } + public bool IsNetworkError { get; init; } + } } From 677d9ea6966aa3772d12264ac6355f673a3283ef Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Tue, 16 Dec 2025 15:56:13 -0800 Subject: [PATCH 02/10] `ISseEventStreamStore` and test implementation --- .../Server/ISseEventStreamReader.cs | 36 + .../Server/ISseEventStreamStore.cs | 23 + .../Server/ISseEventStreamWriter.cs | 43 ++ .../Server/SseEventStreamMode.cs | 20 + .../Server/SseEventStreamOptions.cs | 22 + .../Server/SseEventStreamReaderExtensions.cs | 50 ++ .../SseEventStreamStoreTests.cs | 637 ++++++++++++++++++ .../Utils/TestSseEventStreamStore.cs | 246 +++++++ 8 files changed, 1077 insertions(+) create mode 100644 src/ModelContextProtocol.Core/Server/ISseEventStreamReader.cs create mode 100644 src/ModelContextProtocol.Core/Server/ISseEventStreamStore.cs create mode 100644 src/ModelContextProtocol.Core/Server/ISseEventStreamWriter.cs create mode 100644 src/ModelContextProtocol.Core/Server/SseEventStreamMode.cs create mode 100644 src/ModelContextProtocol.Core/Server/SseEventStreamOptions.cs create mode 100644 src/ModelContextProtocol.Core/Server/SseEventStreamReaderExtensions.cs create mode 100644 tests/ModelContextProtocol.AspNetCore.Tests/SseEventStreamStoreTests.cs create mode 100644 tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestSseEventStreamStore.cs diff --git a/src/ModelContextProtocol.Core/Server/ISseEventStreamReader.cs b/src/ModelContextProtocol.Core/Server/ISseEventStreamReader.cs new file mode 100644 index 000000000..01c642355 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/ISseEventStreamReader.cs @@ -0,0 +1,36 @@ +using ModelContextProtocol.Protocol; +using System.Net.ServerSentEvents; + +namespace ModelContextProtocol.Server; + +/// +/// Provides read access to an SSE event stream, allowing events to be consumed asynchronously. +/// +public interface ISseEventStreamReader +{ + /// + /// Gets the session ID associated with the stream being read. + /// + string SessionId { get; } + + /// + /// Gets the ID of the stream. + /// + /// + /// This value is guaranteed to be unique on a per-session basis. + /// + string StreamId { get; } + + /// + /// Gets the messages from the stream as an . + /// + /// A token to cancel the operation. + /// An of containing JSON-RPC messages. + /// + /// If the stream's mode is set to , the returned + /// messages will only include the currently-available events starting at the last event ID specified + /// when the reader was created. Otherwise, the returned messages will continue until the associated + /// is disposed. + /// + IAsyncEnumerable> ReadEventsAsync(CancellationToken cancellationToken = default); +} diff --git a/src/ModelContextProtocol.Core/Server/ISseEventStreamStore.cs b/src/ModelContextProtocol.Core/Server/ISseEventStreamStore.cs new file mode 100644 index 000000000..723afab07 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/ISseEventStreamStore.cs @@ -0,0 +1,23 @@ +namespace ModelContextProtocol.Server; + +/// +/// Provides storage and retrieval of SSE event streams, enabling resumability and redelivery of events. +/// +public interface ISseEventStreamStore +{ + /// + /// Creates a new SSE event stream with the specified options. + /// + /// The configuration options for the new stream. + /// A token to cancel the operation. + /// A writer for the newly created event stream. + ValueTask CreateStreamAsync(SseEventStreamOptions options, CancellationToken cancellationToken = default); + + /// + /// Gets a reader for an existing event stream based on the last event ID. + /// + /// The ID of the last event received by the client, used to resume from that point. + /// A token to cancel the operation. + /// A reader for the event stream, or null if no matching stream is found. + ValueTask GetStreamReaderAsync(string lastEventId, CancellationToken cancellationToken); +} diff --git a/src/ModelContextProtocol.Core/Server/ISseEventStreamWriter.cs b/src/ModelContextProtocol.Core/Server/ISseEventStreamWriter.cs new file mode 100644 index 000000000..8cb99af90 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/ISseEventStreamWriter.cs @@ -0,0 +1,43 @@ +using ModelContextProtocol.Protocol; +using System.Net.ServerSentEvents; + +namespace ModelContextProtocol.Server; + +/// +/// Provides write access to an SSE event stream, allowing events to be written and tracked with unique IDs. +/// +public interface ISseEventStreamWriter : IAsyncDisposable +{ + /// + /// Gets the ID of the stream. + /// + /// + /// This value is guaranteed to be unique on a per-session basis. + /// + string StreamId { get; } + + /// + /// Gets the current mode of the event stream. + /// + SseEventStreamMode Mode { get; } + + /// + /// Sets the mode of the event stream. + /// + /// The new mode to set for the event stream. + /// A token to cancel the operation. + /// A task that represents the asynchronous operation. + ValueTask SetModeAsync(SseEventStreamMode mode, CancellationToken cancellationToken = default); + + /// + /// Writes an event to the stream. + /// + /// The original . + /// A token to cancel the operation. + /// A new with a populated event ID. + /// + /// If the provided already has an event ID, this method skips writing the event. + /// Otherwise, an event ID unique to all sessions and streams is generated and assigned to the event. + /// + ValueTask> WriteEventAsync(SseItem sseItem, CancellationToken cancellationToken = default); +} diff --git a/src/ModelContextProtocol.Core/Server/SseEventStreamMode.cs b/src/ModelContextProtocol.Core/Server/SseEventStreamMode.cs new file mode 100644 index 000000000..cd5c7f4bf --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/SseEventStreamMode.cs @@ -0,0 +1,20 @@ +namespace ModelContextProtocol.Server; + +/// +/// Represents the mode of an SSE event stream. +/// +public enum SseEventStreamMode +{ + /// + /// Causes the event stream returned by to only end when + /// the associated gets disposed. + /// + Default = 0, + + /// + /// Causes the event stream returned by to end + /// after the most recent event has been consumed. This forces clients to keep making new requests in order to receive + /// the latest messages. + /// + Polling = 1, +} diff --git a/src/ModelContextProtocol.Core/Server/SseEventStreamOptions.cs b/src/ModelContextProtocol.Core/Server/SseEventStreamOptions.cs new file mode 100644 index 000000000..6eca0e0b8 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/SseEventStreamOptions.cs @@ -0,0 +1,22 @@ +namespace ModelContextProtocol.Server; + +/// +/// Configuration options for creating an SSE event stream. +/// +public sealed class SseEventStreamOptions +{ + /// + /// Gets or sets the session ID associated with the event stream. + /// + public required string SessionId { get; set; } + + /// + /// Gets or sets the stream ID that uniquely identifies this stream within a session. + /// + public required string StreamId { get; set; } + + /// + /// Gets or sets the mode of the event stream. Defaults to . + /// + public SseEventStreamMode Mode { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Server/SseEventStreamReaderExtensions.cs b/src/ModelContextProtocol.Core/Server/SseEventStreamReaderExtensions.cs new file mode 100644 index 000000000..ff15ba98c --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/SseEventStreamReaderExtensions.cs @@ -0,0 +1,50 @@ +using ModelContextProtocol.Protocol; +using System.Buffers; +using System.Net.ServerSentEvents; +using System.Text.Json; + +namespace ModelContextProtocol.Server; + +/// +/// Provides extension methods for . +/// +public static class SseEventStreamReaderExtensions +{ + /// + /// Copies all events from the reader to the destination stream in SSE format. + /// + /// The event stream reader to copy events from. + /// The destination stream to write SSE-formatted events to. + /// A token to cancel the operation. + /// A task that represents the asynchronous copy operation. + /// Thrown when or is null. + public static async Task CopyToAsync(this ISseEventStreamReader reader, Stream destination, CancellationToken cancellationToken = default) + { + Throw.IfNull(reader); + Throw.IfNull(destination); + + Utf8JsonWriter? jsonWriter = null; + + var events = reader.ReadEventsAsync(cancellationToken); + await SseFormatter.WriteAsync(events, destination, FormatEvent, cancellationToken); + + void FormatEvent(SseItem item, IBufferWriter writer) + { + if (item.Data is null) + { + return; + } + + if (jsonWriter is null) + { + jsonWriter = new Utf8JsonWriter(writer); + } + else + { + jsonWriter.Reset(writer); + } + + JsonSerializer.Serialize(jsonWriter, item.Data, McpJsonUtilities.JsonContext.Default.JsonRpcMessage!); + } + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseEventStreamStoreTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseEventStreamStoreTests.cs new file mode 100644 index 000000000..52391f351 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseEventStreamStoreTests.cs @@ -0,0 +1,637 @@ +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; +using System.Net.ServerSentEvents; + +namespace ModelContextProtocol.AspNetCore.Tests; + +/// +/// Tests for the interface and implementation. +/// +public class SseEventStreamStoreTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) +{ + private CancellationToken CancellationToken => TestContext.Current.CancellationToken; + + #region CreateStreamAsync Tests + + [Fact] + public async Task CreateStreamAsync_ReturnsWriter_WithCorrectStreamId() + { + var store = new TestSseEventStreamStore(); + var options = new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }; + + var writer = await store.CreateStreamAsync(options, CancellationToken); + + Assert.NotNull(writer); + Assert.Equal("stream-1", writer.StreamId); + } + + [Fact] + public async Task CreateStreamAsync_ReturnsWriter_WithCorrectMode() + { + var store = new TestSseEventStreamStore(); + + var defaultModeOptions = new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }; + var defaultWriter = await store.CreateStreamAsync(defaultModeOptions, CancellationToken); + Assert.Equal(SseEventStreamMode.Default, defaultWriter.Mode); + + var pollingModeOptions = new SseEventStreamOptions + { + SessionId = "session-2", + StreamId = "stream-2", + Mode = SseEventStreamMode.Polling + }; + var pollingWriter = await store.CreateStreamAsync(pollingModeOptions, CancellationToken); + Assert.Equal(SseEventStreamMode.Polling, pollingWriter.Mode); + } + + [Fact] + public async Task CreateStreamAsync_MultipleStreams_CreatesDistinctWriters() + { + var store = new TestSseEventStreamStore(); + + var writer1 = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + var writer2 = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-2", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + Assert.NotSame(writer1, writer2); + Assert.Equal("stream-1", writer1.StreamId); + Assert.Equal("stream-2", writer2.StreamId); + } + + #endregion + + #region WriteEventAsync Tests + + [Fact] + public async Task WriteEventAsync_AssignsEventId_WhenNotPresent() + { + var store = new TestSseEventStreamStore(); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + var message = new JsonRpcRequest { Method = "test", Id = new RequestId("1") }; + var item = new SseItem(message, "message"); + + var result = await writer.WriteEventAsync(item, CancellationToken); + + Assert.NotNull(result.EventId); + Assert.Equal("1", result.EventId); + } + + [Fact] + public async Task WriteEventAsync_SkipsAssigningEventId_WhenAlreadyPresent() + { + var store = new TestSseEventStreamStore(); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + var message = new JsonRpcRequest { Method = "test", Id = new RequestId("1") }; + var item = new SseItem(message, "message") { EventId = "existing-id" }; + + var result = await writer.WriteEventAsync(item, CancellationToken); + + Assert.Equal("existing-id", result.EventId); + } + + [Fact] + public async Task WriteEventAsync_GeneratesSequentialEventIds() + { + var store = new TestSseEventStreamStore(); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + var message1 = new JsonRpcRequest { Method = "test1", Id = new RequestId("1") }; + var message2 = new JsonRpcRequest { Method = "test2", Id = new RequestId("2") }; + var message3 = new JsonRpcRequest { Method = "test3", Id = new RequestId("3") }; + + var result1 = await writer.WriteEventAsync(new SseItem(message1, "message"), CancellationToken); + var result2 = await writer.WriteEventAsync(new SseItem(message2, "message"), CancellationToken); + var result3 = await writer.WriteEventAsync(new SseItem(message3, "message"), CancellationToken); + + Assert.Equal("1", result1.EventId); + Assert.Equal("2", result2.EventId); + Assert.Equal("3", result3.EventId); + } + + [Fact] + public async Task WriteEventAsync_TracksStoredEventIds() + { + var store = new TestSseEventStreamStore(); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + var message1 = new JsonRpcRequest { Method = "test1", Id = new RequestId("1") }; + var message2 = new JsonRpcRequest { Method = "test2", Id = new RequestId("2") }; + + await writer.WriteEventAsync(new SseItem(message1, "message"), CancellationToken); + await writer.WriteEventAsync(new SseItem(message2, "message"), CancellationToken); + + Assert.Equal(2, store.StoreEventCallCount); + Assert.Equal(["1", "2"], store.StoredEventIds); + } + + [Fact] + public async Task WriteEventAsync_PreservesEventData() + { + var store = new TestSseEventStreamStore(); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + var message = new JsonRpcRequest { Method = "test-method", Id = new RequestId("req-1") }; + var item = new SseItem(message, "custom-event"); + + var result = await writer.WriteEventAsync(item, CancellationToken); + + Assert.Same(message, result.Data); + Assert.Equal("custom-event", result.EventType); + } + + [Fact] + public async Task WriteEventAsync_HandlesNullData() + { + var store = new TestSseEventStreamStore(); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + var item = new SseItem(null, "priming"); + + var result = await writer.WriteEventAsync(item, CancellationToken); + + Assert.NotNull(result.EventId); + Assert.Null(result.Data); + } + + #endregion + + #region SetModeAsync Tests + + [Fact] + public async Task SetModeAsync_ChangesMode_FromDefaultToPolling() + { + var store = new TestSseEventStreamStore(); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + Assert.Equal(SseEventStreamMode.Default, writer.Mode); + + await writer.SetModeAsync(SseEventStreamMode.Polling, CancellationToken); + + Assert.Equal(SseEventStreamMode.Polling, writer.Mode); + } + + [Fact] + public async Task SetModeAsync_ChangesMode_FromPollingToDefault() + { + var store = new TestSseEventStreamStore(); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + Assert.Equal(SseEventStreamMode.Polling, writer.Mode); + + await writer.SetModeAsync(SseEventStreamMode.Default, CancellationToken); + + Assert.Equal(SseEventStreamMode.Default, writer.Mode); + } + + #endregion + + #region GetStreamReaderAsync Tests + + [Fact] + public async Task GetStreamReaderAsync_ReturnsNull_WhenEventIdNotFound() + { + var store = new TestSseEventStreamStore(); + + var reader = await store.GetStreamReaderAsync("nonexistent-id", CancellationToken); + + Assert.Null(reader); + } + + [Fact] + public async Task GetStreamReaderAsync_ReturnsReader_WhenEventIdExists() + { + var store = new TestSseEventStreamStore(); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var message = new JsonRpcRequest { Method = "test", Id = new RequestId("1") }; + var result = await writer.WriteEventAsync(new SseItem(message, "message"), CancellationToken); + + var reader = await store.GetStreamReaderAsync(result.EventId!, CancellationToken); + + Assert.NotNull(reader); + Assert.Equal("stream-1", reader.StreamId); + } + + [Fact] + public async Task GetStreamReaderAsync_Throws_WhenStreamReplaced() + { + var store = new TestSseEventStreamStore(); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + // Create a new stream with the same key, effectively replacing the old one + await Assert.ThrowsAsync(async () => await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken)); + } + + #endregion + + #region ReadEventsAsync Tests + + [Fact] + public async Task ReadEventsAsync_ReturnsEventsAfterLastEventId() + { + var store = new TestSseEventStreamStore(); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var message1 = new JsonRpcRequest { Method = "test1", Id = new RequestId("1") }; + var message2 = new JsonRpcRequest { Method = "test2", Id = new RequestId("2") }; + var message3 = new JsonRpcRequest { Method = "test3", Id = new RequestId("3") }; + + var result1 = await writer.WriteEventAsync(new SseItem(message1, "message"), CancellationToken); + await writer.WriteEventAsync(new SseItem(message2, "message"), CancellationToken); + await writer.WriteEventAsync(new SseItem(message3, "message"), CancellationToken); + + var reader = await store.GetStreamReaderAsync(result1.EventId!, CancellationToken); + Assert.NotNull(reader); + + var events = await reader.ReadEventsAsync(CancellationToken).ToListAsync(CancellationToken); + + Assert.Equal(2, events.Count); + Assert.Equal("test2", ((JsonRpcRequest)events[0].Data!).Method); + Assert.Equal("test3", ((JsonRpcRequest)events[1].Data!).Method); + } + + [Fact] + public async Task ReadEventsAsync_IncludesNullDataEvents() + { + var store = new TestSseEventStreamStore(); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var message1 = new JsonRpcRequest { Method = "test1", Id = new RequestId("1") }; + var message2 = new JsonRpcRequest { Method = "test2", Id = new RequestId("2") }; + + var result1 = await writer.WriteEventAsync(new SseItem(message1, "message"), CancellationToken); + await writer.WriteEventAsync(new SseItem(null, "priming"), CancellationToken); // null data event + await writer.WriteEventAsync(new SseItem(message2, "message"), CancellationToken); + + var reader = await store.GetStreamReaderAsync(result1.EventId!, CancellationToken); + Assert.NotNull(reader); + + var events = await reader.ReadEventsAsync(CancellationToken).ToListAsync(CancellationToken); + + Assert.Equal(2, events.Count); + Assert.Null(events[0].Data); + Assert.Equal("priming", events[0].EventType); + Assert.Equal("test2", ((JsonRpcRequest)events[1].Data!).Method); + } + + [Fact] + public async Task ReadEventsAsync_ReturnsEmpty_WhenNoEventsAfterLastEventId() + { + var store = new TestSseEventStreamStore(); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var message = new JsonRpcRequest { Method = "test", Id = new RequestId("1") }; + var result = await writer.WriteEventAsync(new SseItem(message, "message"), CancellationToken); + + var reader = await store.GetStreamReaderAsync(result.EventId!, CancellationToken); + Assert.NotNull(reader); + + var events = await reader.ReadEventsAsync(CancellationToken).ToListAsync(CancellationToken); + + Assert.Empty(events); + } + + [Fact] + public async Task ReadEventsAsync_InPollingMode_CompletesAfterStoredEvents() + { + var store = new TestSseEventStreamStore(); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var message1 = new JsonRpcRequest { Method = "test1", Id = new RequestId("1") }; + var message2 = new JsonRpcRequest { Method = "test2", Id = new RequestId("2") }; + + var result1 = await writer.WriteEventAsync(new SseItem(message1, "message"), CancellationToken); + await writer.WriteEventAsync(new SseItem(message2, "message"), CancellationToken); + + var reader = await store.GetStreamReaderAsync(result1.EventId!, CancellationToken); + Assert.NotNull(reader); + + // In polling mode, ReadEventsAsync should complete immediately after returning stored events + using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(1)); + var events = await reader.ReadEventsAsync(cts.Token).ToListAsync(cts.Token); + + Assert.Single(events); + } + + [Fact] + public async Task ReadEventsAsync_InDefaultMode_WaitsForNewEvents() + { + var store = new TestSseEventStreamStore(); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + var message1 = new JsonRpcRequest { Method = "test1", Id = new RequestId("1") }; + var result1 = await writer.WriteEventAsync(new SseItem(message1, "message"), CancellationToken); + + var reader = await store.GetStreamReaderAsync(result1.EventId!, CancellationToken); + Assert.NotNull(reader); + + var readTask = Task.Run(async () => + { + var events = new List>(); + await foreach (var e in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(e); + if (events.Count >= 2) + { + break; + } + } + return events; + }, CancellationToken); + + // Give the read task time to start waiting + await Task.Delay(50, CancellationToken); + + // Write new events + var message2 = new JsonRpcRequest { Method = "test2", Id = new RequestId("2") }; + var message3 = new JsonRpcRequest { Method = "test3", Id = new RequestId("3") }; + await writer.WriteEventAsync(new SseItem(message2, "message"), CancellationToken); + await writer.WriteEventAsync(new SseItem(message3, "message"), CancellationToken); + + using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(5)); + var events = await readTask.WaitAsync(cts.Token); + + Assert.Equal(2, events.Count); + Assert.Equal("test2", ((JsonRpcRequest)events[0].Data!).Method); + Assert.Equal("test3", ((JsonRpcRequest)events[1].Data!).Method); + } + + [Fact] + public async Task ReadEventsAsync_InDefaultMode_CompletesWhenWriterDisposed() + { + var store = new TestSseEventStreamStore(); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + var message1 = new JsonRpcRequest { Method = "test1", Id = new RequestId("1") }; + var result1 = await writer.WriteEventAsync(new SseItem(message1, "message"), CancellationToken); + + var reader = await store.GetStreamReaderAsync(result1.EventId!, CancellationToken); + Assert.NotNull(reader); + + var readTask = reader.ReadEventsAsync(CancellationToken).ToListAsync(CancellationToken).AsTask(); + + // Give the read task time to start waiting + await Task.Delay(50, CancellationToken); + + // Dispose the writer + await writer.DisposeAsync(); + + using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(5)); + var events = await readTask.WaitAsync(cts.Token); + + Assert.Empty(events); + } + + [Fact] + public async Task ReadEventsAsync_RespectsCancellation() + { + var store = new TestSseEventStreamStore(); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + var message1 = new JsonRpcRequest { Method = "test1", Id = new RequestId("1") }; + var result1 = await writer.WriteEventAsync(new SseItem(message1, "message"), CancellationToken); + + var reader = await store.GetStreamReaderAsync(result1.EventId!, CancellationToken); + Assert.NotNull(reader); + + using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); + cts.CancelAfter(TimeSpan.FromMilliseconds(100)); + + await Assert.ThrowsAsync(async () => + { + await reader.ReadEventsAsync(cts.Token).ToListAsync(cts.Token); + }); + } + + #endregion + + #region Cross-Session Tests + + [Fact] + public async Task EventIds_AreUniqueAcrossSessions() + { + var store = new TestSseEventStreamStore(); + + var writer1 = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + var writer2 = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-2", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + var message1 = new JsonRpcRequest { Method = "test1", Id = new RequestId("1") }; + var message2 = new JsonRpcRequest { Method = "test2", Id = new RequestId("2") }; + + var result1 = await writer1.WriteEventAsync(new SseItem(message1, "message"), CancellationToken); + var result2 = await writer2.WriteEventAsync(new SseItem(message2, "message"), CancellationToken); + + Assert.NotEqual(result1.EventId, result2.EventId); + } + + [Fact] + public async Task GetStreamReaderAsync_ReturnsCorrectStream_ForDifferentSessions() + { + var store = new TestSseEventStreamStore(); + + var writer1 = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var writer2 = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-2", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var message1 = new JsonRpcRequest { Method = "session1-test", Id = new RequestId("1") }; + var message2 = new JsonRpcRequest { Method = "session2-test", Id = new RequestId("2") }; + var message1b = new JsonRpcRequest { Method = "session1-test2", Id = new RequestId("3") }; + var message2b = new JsonRpcRequest { Method = "session2-test2", Id = new RequestId("4") }; + + var result1 = await writer1.WriteEventAsync(new SseItem(message1, "message"), CancellationToken); + var result2 = await writer2.WriteEventAsync(new SseItem(message2, "message"), CancellationToken); + await writer1.WriteEventAsync(new SseItem(message1b, "message"), CancellationToken); + await writer2.WriteEventAsync(new SseItem(message2b, "message"), CancellationToken); + + var reader1 = await store.GetStreamReaderAsync(result1.EventId!, CancellationToken); + var reader2 = await store.GetStreamReaderAsync(result2.EventId!, CancellationToken); + + Assert.NotNull(reader1); + Assert.NotNull(reader2); + + var events1 = await reader1.ReadEventsAsync(CancellationToken).ToListAsync(CancellationToken); + var events2 = await reader2.ReadEventsAsync(CancellationToken).ToListAsync(CancellationToken); + + Assert.Single(events1); + Assert.Equal("session1-test2", ((JsonRpcRequest)events1[0].Data!).Method); + + Assert.Single(events2); + Assert.Equal("session2-test2", ((JsonRpcRequest)events2[0].Data!).Method); + } + + #endregion + + #region DisposeAsync Tests + + [Fact] + public async Task DisposeAsync_CompletesWithoutError() + { + var store = new TestSseEventStreamStore(); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + await writer.DisposeAsync(); + // Should not throw + } + + [Fact] + public async Task DisposeAsync_CanBeCalledMultipleTimes() + { + var store = new TestSseEventStreamStore(); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + await writer.DisposeAsync(); + await writer.DisposeAsync(); + await writer.DisposeAsync(); + // Should not throw + } + + #endregion +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestSseEventStreamStore.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestSseEventStreamStore.cs new file mode 100644 index 000000000..e4058eacd --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestSseEventStreamStore.cs @@ -0,0 +1,246 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.Collections.Concurrent; +using System.Net.ServerSentEvents; +using System.Runtime.CompilerServices; +using System.Threading.Channels; + +namespace ModelContextProtocol.AspNetCore.Tests.Utils; + +/// +/// In-memory event store for testing resumability. +/// This is a simple implementation intended for testing, not for production use. +/// +public sealed class TestSseEventStreamStore : ISseEventStreamStore +{ + private readonly ConcurrentDictionary _streams = new(); + private readonly ConcurrentDictionary _eventLookup = new(); + private readonly List _storedEventIds = []; + private readonly object _storedEventIdsLock = new(); + private int _storeEventCallCount; + private long _globalSequence; + + /// + /// Gets the number of times events have been stored. + /// + public int StoreEventCallCount => _storeEventCallCount; + + /// + /// Gets the list of stored event IDs in order. + /// + public IReadOnlyList StoredEventIds + { + get + { + lock (_storedEventIdsLock) + { + return [.. _storedEventIds]; + } + } + } + + /// + public ValueTask CreateStreamAsync(SseEventStreamOptions options, CancellationToken cancellationToken = default) + { + var streamKey = GetStreamKey(options.SessionId, options.StreamId); + var state = new StreamState(options.SessionId, options.StreamId, options.Mode); + if (!_streams.TryAdd(streamKey, state)) + { + throw new InvalidOperationException($"A stream with key '{streamKey}' has already been created."); + } + var writer = new InMemoryEventStreamWriter(this, state); + return new ValueTask(writer); + } + + /// + public ValueTask GetStreamReaderAsync(string lastEventId, CancellationToken cancellationToken = default) + { + // Look up the event by its ID to find which stream it belongs to + if (!_eventLookup.TryGetValue(lastEventId, out var lookup)) + { + return new ValueTask((ISseEventStreamReader?)null); + } + + var reader = new InMemoryEventStreamReader(lookup.Stream, lookup.Sequence); + return new ValueTask(reader); + } + + private string GenerateEventId() => Interlocked.Increment(ref _globalSequence).ToString(); + + private void TrackEvent(string eventId, StreamState stream, long sequence) + { + _eventLookup[eventId] = (stream, sequence); + lock (_storedEventIdsLock) + { + _storedEventIds.Add(eventId); + } + Interlocked.Increment(ref _storeEventCallCount); + } + + private static string GetStreamKey(string sessionId, string streamId) => $"{sessionId}:{streamId}"; + + /// + /// Holds the state for a single stream. + /// + private sealed class StreamState + { + private readonly Channel<(SseItem Item, long Sequence)> _channel; + private readonly List<(SseItem Item, long Sequence)> _events = []; + private readonly object _lock = new(); + private long _sequence; + + public StreamState(string sessionId, string streamId, SseEventStreamMode mode) + { + SessionId = sessionId; + StreamId = streamId; + Mode = mode; + _channel = Channel.CreateUnbounded<(SseItem, long)>(); + } + + public string SessionId { get; } + public string StreamId { get; } + public SseEventStreamMode Mode { get; set; } + public bool IsCompleted { get; private set; } + + public long NextSequence() => Interlocked.Increment(ref _sequence); + + public void AddEvent(SseItem item, long sequence) + { + if (IsCompleted) + { + throw new InvalidOperationException("Cannot add events to a completed stream."); + } + + lock (_lock) + { + _events.Add((item, sequence)); + } + _channel.Writer.TryWrite((item, sequence)); + } + + public List> GetEventsAfter(long sequence) + { + lock (_lock) + { + var result = new List>(); + foreach (var (item, seq) in _events) + { + if (seq > sequence) + { + result.Add(item); + } + } + return result; + } + } + + public ChannelReader<(SseItem Item, long Sequence)> Reader => _channel.Reader; + + public void Complete() + { + IsCompleted = true; + _channel.Writer.TryComplete(); + } + } + + private sealed class InMemoryEventStreamWriter : ISseEventStreamWriter + { + private readonly TestSseEventStreamStore _store; + private readonly StreamState _state; + private bool _disposed; + + public InMemoryEventStreamWriter(TestSseEventStreamStore store, StreamState state) + { + _store = store; + _state = state; + } + + public string StreamId => _state.StreamId; + public SseEventStreamMode Mode => _state.Mode; + + public ValueTask SetModeAsync(SseEventStreamMode mode, CancellationToken cancellationToken = default) + { + _state.Mode = mode; + return default; + } + + public ValueTask> WriteEventAsync(SseItem sseItem, CancellationToken cancellationToken = default) + { + // Skip if already has an event ID + if (sseItem.EventId is not null) + { + return new ValueTask>(sseItem); + } + + var sequence = _state.NextSequence(); + var eventId = _store.GenerateEventId(); + var newItem = sseItem with { EventId = eventId }; + + _state.AddEvent(newItem, sequence); + _store.TrackEvent(eventId, _state, sequence); + + return new ValueTask>(newItem); + } + + public ValueTask DisposeAsync() + { + if (_disposed) + { + return default; + } + + _disposed = true; + _state.Complete(); + return default; + } + } + + private sealed class InMemoryEventStreamReader : ISseEventStreamReader + { + private readonly StreamState _state; + private readonly long _startSequence; + + public InMemoryEventStreamReader(StreamState state, long startSequence) + { + _state = state; + _startSequence = startSequence; + } + + public string SessionId => _state.SessionId; + public string StreamId => _state.StreamId; + + public async IAsyncEnumerable> ReadEventsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + // First, return any events that were already written after the start sequence + var existingEvents = _state.GetEventsAfter(_startSequence); + long lastSeenSequence = _startSequence; + foreach (var evt in existingEvents) + { + yield return evt; + } + + // If in polling mode, stop after returning currently available events + if (_state.Mode == SseEventStreamMode.Polling) + { + yield break; + } + + // If the stream is already completed, stop + if (_state.IsCompleted) + { + yield break; + } + + // Wait for new events from the channel + await foreach (var (item, sequence) in _state.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) + { + // Only yield events we haven't seen yet + if (sequence > lastSeenSequence) + { + lastSeenSequence = sequence; + yield return item; + } + } + } + } +} From 63fb0223d3726a0f24145cde0dc2fce1f873f15a Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Tue, 16 Dec 2025 16:07:59 -0800 Subject: [PATCH 03/10] Server-side resumability via `Last-Event-ID` --- .../HttpServerTransportOptions.cs | 28 +++ .../McpEndpointRouteBuilderExtensions.cs | 8 +- .../StreamableHttpHandler.cs | 113 ++++++++++-- .../McpSessionHandler.cs | 18 +- .../Server/SseResponseStreamTransport.cs | 2 +- .../Server/SseWriter.cs | 42 +++-- .../Server/StreamableHttpPostTransport.cs | 81 ++++++--- .../Server/StreamableHttpServerTransport.cs | 166 ++++++++++++++++-- .../StatelessServerTests.cs | 5 +- .../Client/McpClientTests.cs | 24 +-- 10 files changed, 411 insertions(+), 76 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs index 67f4f4e1d..bbbcf6beb 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs @@ -43,6 +43,34 @@ public class HttpServerTransportOptions /// public bool Stateless { get; set; } + /// + /// Gets or sets the event store for resumability support. + /// When set, events are stored and can be replayed when clients reconnect with a Last-Event-ID header. + /// + /// + /// When configured, the server will: + /// + /// Generate unique event IDs for each SSE message + /// Store events for later replay + /// Replay missed events when a client reconnects with a Last-Event-ID header + /// Send priming events to establish resumability before any actual messages + /// + /// + public ISseEventStreamStore? EventStreamStore { get; set; } + + /// + /// Gets or sets the retry interval to suggest to clients in SSE retry field. + /// + /// + /// The retry interval. The default is 1 second. + /// + /// + /// When is set, the server will include a retry field in priming events. + /// This value suggests to clients how long to wait before attempting to reconnect after a connection is lost. + /// Clients may use this value to implement polling behavior during long-running operations. + /// + public TimeSpan RetryInterval { get; set; } = TimeSpan.FromSeconds(1); + /// /// Gets or sets a value that indicates whether the server uses a single execution context for the entire session. /// diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs index 7c05ac102..99f881b71 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs @@ -38,12 +38,12 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])) .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted)); + streamableHttpGroup.MapGet("", streamableHttpHandler.HandleGetRequestAsync) + .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])); + if (!streamableHttpHandler.HttpServerTransportOptions.Stateless) { - // The GET and DELETE endpoints are not mapped in Stateless mode since there's no way to send unsolicited messages - // for the GET to handle, and there is no server-side state for the DELETE to clean up. - streamableHttpGroup.MapGet("", streamableHttpHandler.HandleGetRequestAsync) - .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])); + // The DELETE endpoints are not mapped in Stateless mode since there's no server-side state to clean up. streamableHttpGroup.MapDelete("", streamableHttpHandler.HandleDeleteRequestAsync); // Map legacy HTTP with SSE endpoints only if not in Stateless mode, because we cannot guarantee the /message requests diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 9f4af7ea5..838e4e2da 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -23,6 +23,7 @@ internal sealed class StreamableHttpHandler( ILoggerFactory loggerFactory) { private const string McpSessionIdHeaderName = "Mcp-Session-Id"; + private const string LastEventIdHeaderName = "Last-Event-ID"; private static readonly JsonTypeInfo s_messageTypeInfo = GetRequiredJsonTypeInfo(); private static readonly JsonTypeInfo s_errorTypeInfo = GetRequiredJsonTypeInfo(); @@ -81,17 +82,80 @@ await WriteJsonRpcErrorAsync(context, return; } + StreamableHttpSession? session = null; + ISseEventStreamReader? eventStreamReader = null; + var sessionId = context.Request.Headers[McpSessionIdHeaderName].ToString(); - var session = await GetSessionAsync(context, sessionId); + var lastEventId = context.Request.Headers[LastEventIdHeaderName].ToString(); + + if (!string.IsNullOrEmpty(sessionId)) + { + session = await GetSessionAsync(context, sessionId); + if (session is null) + { + // There was an error obtaining the session; consider the request failed. + return; + } + } + + if (!string.IsNullOrEmpty(lastEventId)) + { + if (HttpServerTransportOptions.Stateless) + { + await WriteJsonRpcErrorAsync(context, + "Bad Request: The Last-Event-ID header is not supported in stateless mode.", + StatusCodes.Status400BadRequest); + return; + } + + eventStreamReader = await GetEventStreamReaderAsync(context, lastEventId); + if (eventStreamReader is null) + { + // There was an error obtaining the event stream; consider the request failed. + return; + } + } + + if (session is not null && eventStreamReader is not null && !string.Equals(session.Id, eventStreamReader.SessionId, StringComparison.Ordinal)) + { + await WriteJsonRpcErrorAsync(context, + "Bad Request: The Last-Event-ID header refers to a session with a different session ID.", + StatusCodes.Status400BadRequest); + return; + } + + if (eventStreamReader is null || string.Equals(eventStreamReader.StreamId, StreamableHttpServerTransport.UnsolicitedMessageStreamId, StringComparison.Ordinal)) + { + await HandleUnsolicitedMessageStreamAsync(context, session, eventStreamReader); + } + else + { + await HandleResumePostResponseStreamAsync(context, eventStreamReader); + } + } + + private async Task HandleUnsolicitedMessageStreamAsync(HttpContext context, StreamableHttpSession? session, ISseEventStreamReader? eventStreamReader) + { + if (HttpServerTransportOptions.Stateless) + { + await WriteJsonRpcErrorAsync(context, + "Bad Request: Unsolicited messages are not supported in stateless mode.", + StatusCodes.Status400BadRequest); + return; + } + if (session is null) { + await WriteJsonRpcErrorAsync(context, + "Bad Request: Mcp-Session-Id header is required", + StatusCodes.Status400BadRequest); return; } - if (!session.TryStartGetRequest()) + if (eventStreamReader is null && !session.TryStartGetRequest()) { await WriteJsonRpcErrorAsync(context, - "Bad Request: This server does not support multiple GET requests. Start a new session to get a new GET SSE response.", + "Bad Request: This server does not support multiple GET requests. Use Last-Event-ID header to resume or start a new session.", StatusCodes.Status400BadRequest); return; } @@ -111,7 +175,7 @@ await WriteJsonRpcErrorAsync(context, // will be sent in response to a different POST request. It might be a while before we send a message // over this response body. await context.Response.Body.FlushAsync(cancellationToken); - await session.Transport.HandleGetRequestAsync(context.Response.Body, cancellationToken); + await session.Transport.HandleGetRequestAsync(context.Response.Body, eventStreamReader, cancellationToken); } catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) { @@ -120,6 +184,12 @@ await WriteJsonRpcErrorAsync(context, } } + private static async Task HandleResumePostResponseStreamAsync(HttpContext context, ISseEventStreamReader eventStreamReader) + { + InitializeSseResponse(context); + await eventStreamReader.CopyToAsync(context.Response.Body, context.RequestAborted); + } + public async Task HandleDeleteRequestAsync(HttpContext context) { var sessionId = context.Request.Headers[McpSessionIdHeaderName].ToString(); @@ -131,14 +201,7 @@ public async Task HandleDeleteRequestAsync(HttpContext context) private async ValueTask GetSessionAsync(HttpContext context, string sessionId) { - StreamableHttpSession? session; - - if (string.IsNullOrEmpty(sessionId)) - { - await WriteJsonRpcErrorAsync(context, "Bad Request: Mcp-Session-Id header is required", StatusCodes.Status400BadRequest); - return null; - } - else if (!sessionManager.TryGetValue(sessionId, out session)) + if (!sessionManager.TryGetValue(sessionId, out var session)) { // -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does. // One of the few other usages I found was from some Ethereum JSON-RPC documentation and this @@ -194,12 +257,16 @@ private async ValueTask StartNewSessionAsync(HttpContext { SessionId = sessionId, FlowExecutionContextFromRequests = !HttpServerTransportOptions.PerSessionExecutionContext, + EventStreamStore = HttpServerTransportOptions.EventStreamStore, + RetryInterval = HttpServerTransportOptions.RetryInterval, }; context.Response.Headers[McpSessionIdHeaderName] = sessionId; } else { // In stateless mode, each request is independent. Don't set any session ID on the transport. + // If in the future we support resuming stateless requests, we should populate + // the event stream store and retry interval here as well. sessionId = ""; transport = new() { @@ -246,6 +313,28 @@ private async ValueTask CreateSessionAsync( return session; } + private async ValueTask GetEventStreamReaderAsync(HttpContext context, string lastEventId) + { + if (HttpServerTransportOptions.EventStreamStore is not { } eventStreamStore) + { + await WriteJsonRpcErrorAsync(context, + "Bad Request: This server does not support resuming streams.", + StatusCodes.Status400BadRequest); + return null; + } + + var eventStreamReader = await eventStreamStore.GetStreamReaderAsync(lastEventId, context.RequestAborted); + if (eventStreamReader is null) + { + await WriteJsonRpcErrorAsync(context, + "Bad Request: The specified Last-Event-ID is either invalid or expired.", + StatusCodes.Status400BadRequest); + return null; + } + + return eventStreamReader; + } + private static Task WriteJsonRpcErrorAsync(HttpContext context, string errorMessage, int statusCode, int errorCode = -32000) { var jsonRpcError = new JsonRpcError diff --git a/src/ModelContextProtocol.Core/McpSessionHandler.cs b/src/ModelContextProtocol.Core/McpSessionHandler.cs index dd6814640..683e1d18d 100644 --- a/src/ModelContextProtocol.Core/McpSessionHandler.cs +++ b/src/ModelContextProtocol.Core/McpSessionHandler.cs @@ -29,16 +29,32 @@ internal sealed partial class McpSessionHandler : IAsyncDisposable "mcp.server.operation.duration", "Measures the duration of inbound message processing.", longBuckets: false); /// The latest version of the protocol supported by this implementation. - internal const string LatestProtocolVersion = "2025-06-18"; + internal const string LatestProtocolVersion = "2025-11-25"; /// All protocol versions supported by this implementation. internal static readonly string[] SupportedProtocolVersions = [ "2024-11-05", "2025-03-26", + "2025-06-18", LatestProtocolVersion, ]; + /// + /// Checks if the given protocol version supports priming events. + /// + /// The protocol version to check. + /// True if the protocol version supports resumability. + /// + /// Priming events are only supported in protocol version >= 2025-11-25. + /// Older clients may crash when receiving SSE events with empty data. + /// + internal static bool SupportsPrimingEvent(string? protocolVersion) + { + const string MinResumabilityProtocolVersion = "2025-11-25"; + return string.Compare(protocolVersion, MinResumabilityProtocolVersion, StringComparison.Ordinal) >= 0; + } + private readonly bool _isServer; private readonly string _transportKind; private readonly ITransport _transport; diff --git a/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs b/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs index 5d7315241..4c784712f 100644 --- a/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs +++ b/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs @@ -67,7 +67,7 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can { Throw.IfNull(message); // If the underlying writer has been disposed, just drop the message. - await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); + await _sseWriter.SendMessageAsync(message, eventStreamWriter: null, cancellationToken).ConfigureAwait(false); } /// diff --git a/src/ModelContextProtocol.Core/Server/SseWriter.cs b/src/ModelContextProtocol.Core/Server/SseWriter.cs index a2314e623..7ccd45282 100644 --- a/src/ModelContextProtocol.Core/Server/SseWriter.cs +++ b/src/ModelContextProtocol.Core/Server/SseWriter.cs @@ -22,8 +22,6 @@ internal sealed class SseWriter(string? messageEndpoint = null, BoundedChannelOp private readonly SemaphoreSlim _disposeLock = new(1, 1); private bool _disposed; - public Func>, CancellationToken, IAsyncEnumerable>>? MessageFilter { get; set; } - public Task WriteAllAsync(Stream sseResponseStream, CancellationToken cancellationToken) { Throw.IfNull(sseResponseStream); @@ -38,21 +36,40 @@ public Task WriteAllAsync(Stream sseResponseStream, CancellationToken cancellati _writeCancellationToken = cancellationToken; var messages = _messages.Reader.ReadAllAsync(cancellationToken); - if (MessageFilter is not null) - { - messages = MessageFilter(messages, cancellationToken); - } - _writeTask = SseFormatter.WriteAsync(messages, sseResponseStream, WriteJsonRpcMessageToBuffer, cancellationToken); return _writeTask; } - public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + public Task SendPrimingEventAsync(TimeSpan retryInterval, ISseEventStreamWriter eventStreamWriter, CancellationToken cancellationToken = default) + { + // Create a priming event: empty data with an event ID + var primingItem = new SseItem(null, "prime") + { + ReconnectionInterval = retryInterval, + }; + + return SendMessageAsync(primingItem, eventStreamWriter, cancellationToken); + } + + public Task SendMessageAsync(JsonRpcMessage message, ISseEventStreamWriter? eventStreamWriter, CancellationToken cancellationToken = default) { Throw.IfNull(message); + // Emit redundant "event: message" lines for better compatibility with other SDKs. + return SendMessageAsync(new SseItem(message, SseParser.EventTypeDefault), eventStreamWriter, cancellationToken); + } + + private async Task SendMessageAsync(SseItem item, ISseEventStreamWriter? eventStreamWriter, CancellationToken cancellationToken = default) + { using var _ = await _disposeLock.LockAsync(cancellationToken).ConfigureAwait(false); + if (eventStreamWriter is not null && item.EventId is null) + { + // Store the event first, even if the underlying writer has completed, so that + // messages can still be retrieved from the event store. + item = await eventStreamWriter.WriteEventAsync(item, cancellationToken: cancellationToken).ConfigureAwait(false); + } + if (_disposed) { // Don't throw ObjectDisposedException here; just return false to indicate the message wasn't sent. @@ -60,8 +77,7 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationTok return false; } - // Emit redundant "event: message" lines for better compatibility with other SDKs. - await _messages.Writer.WriteAsync(new SseItem(message, SseParser.EventTypeDefault), cancellationToken).ConfigureAwait(false); + await _messages.Writer.WriteAsync(item, cancellationToken).ConfigureAwait(false); return true; } @@ -101,7 +117,10 @@ private void WriteJsonRpcMessageToBuffer(SseItem item, IBufferW return; } - JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.JsonContext.Default.JsonRpcMessage!); + if (item.Data is not null) + { + JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.JsonContext.Default.JsonRpcMessage!); + } } private Utf8JsonWriter GetUtf8JsonWriter(IBufferWriter writer) @@ -118,3 +137,4 @@ private Utf8JsonWriter GetUtf8JsonWriter(IBufferWriter writer) return _jsonWriter; } } + diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs index 1109c2b2b..755eb3d99 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs @@ -1,9 +1,5 @@ using ModelContextProtocol.Protocol; using System.Diagnostics; -using System.IO.Pipelines; -using System.Net.ServerSentEvents; -using System.Runtime.CompilerServices; -using System.Security.Claims; using System.Text.Json; using System.Threading.Channels; @@ -17,13 +13,15 @@ internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport { private readonly SseWriter _sseWriter = new(); private RequestId _pendingRequest; + private ISseEventStreamWriter? _eventStreamWriter; + private SemaphoreSlim _eventStreamLock = new(1, 1); public ChannelReader MessageReader => throw new NotSupportedException("JsonRpcMessage.Context.RelatedTransport should only be used for sending messages."); string? ITransport.SessionId => parentTransport.SessionId; /// - /// True, if data was written to the respond body. + /// True, if data was written to the response body. /// False, if nothing was written because the request body did not contain any messages to respond to. /// The HTTP application should typically respond with an empty "202 Accepted" response in this scenario. /// @@ -35,11 +33,11 @@ public async ValueTask HandlePostAsync(JsonRpcMessage message, Cancellatio { _pendingRequest = request.Id; - // Invoke the initialize request callback if applicable. - if (parentTransport.OnInitRequestReceived is { } onInitRequest && request.Method == RequestMethods.Initialize) + // Invoke the initialize request handler if applicable. + if (request.Method == RequestMethods.Initialize) { var initializeRequest = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams); - await onInitRequest(initializeRequest).ConfigureAwait(false); + await parentTransport.HandleInitRequestAsync(initializeRequest).ConfigureAwait(false); } } @@ -58,8 +56,17 @@ public async ValueTask HandlePostAsync(JsonRpcMessage message, Cancellatio return false; } - _sseWriter.MessageFilter = StopOnFinalResponseFilter; - await _sseWriter.WriteAllAsync(responseStream, cancellationToken).ConfigureAwait(false); + // Start the write task immediately so that we don't risk filling up the channel with + // messages before they start being consumed. + var writeTask = _sseWriter.WriteAllAsync(responseStream, cancellationToken); + + var eventStreamWriter = await GetOrCreateEventStreamAsync(cancellationToken).ConfigureAwait(false); + if (eventStreamWriter is not null) + { + await _sseWriter.SendPrimingEventAsync(parentTransport.RetryInterval, eventStreamWriter, cancellationToken).ConfigureAwait(false); + } + + await writeTask.ConfigureAwait(false); return true; } @@ -72,31 +79,63 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can throw new InvalidOperationException("Server to client requests are not supported in stateless mode."); } - bool isAccepted = await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); - if (!isAccepted) + var eventStreamWriter = await GetOrCreateEventStreamAsync(cancellationToken).ConfigureAwait(false); + + var isAccepted = await _sseWriter.SendMessageAsync(message, eventStreamWriter, cancellationToken).ConfigureAwait(false); + if (!isAccepted && eventStreamWriter is null) { - // The underlying writer didn't accept the message because the underlying request has completed. + // The underlying writer didn't accept the message because the underlying request has completed, + // and there isn't a fallback event stream writer. // Rather than drop the message, fall back to sending it via the parent transport. await parentTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); } + + if (message is JsonRpcResponse or JsonRpcError && ((JsonRpcMessageWithId)message).Id == _pendingRequest) + { + // Complete the SSE response stream and SSE event stream writer now that all pending requests have been processed. + await _sseWriter.DisposeAsync().ConfigureAwait(false); + + if (_eventStreamWriter is not null) + { + await _eventStreamWriter.DisposeAsync().ConfigureAwait(false); + } + } } public async ValueTask DisposeAsync() { await _sseWriter.DisposeAsync().ConfigureAwait(false); + + // Don't dispose the event stream writer here, as we may continue to write to the event store + // after disposal. } - private async IAsyncEnumerable> StopOnFinalResponseFilter(IAsyncEnumerable> messages, [EnumeratorCancellation] CancellationToken cancellationToken) + private async ValueTask GetOrCreateEventStreamAsync(CancellationToken cancellationToken) { - await foreach (var message in messages.WithCancellation(cancellationToken)) + using var _ = await _eventStreamLock.LockAsync(cancellationToken).ConfigureAwait(false); + + if (_eventStreamWriter is not null) { - yield return message; + return _eventStreamWriter; + } - if (message.Data is JsonRpcResponse or JsonRpcError && ((JsonRpcMessageWithId)message.Data).Id == _pendingRequest) - { - // Complete the SSE response stream now that all pending requests have been processed. - break; - } + if (parentTransport.EventStreamStore is not { } eventStreamStore || _pendingRequest.Id is null || !McpSessionHandler.SupportsPrimingEvent(parentTransport.NegotiatedProtocolVersion)) + { + return null; } + + // We use the 'Default' stream mode so that in the case of an unexpected network disconnection, + // the client can continue reading the remaining messages in a single, streamed response. + // This may be changed to 'Polling' if the transport is later explicitly switched to polling mode. + const SseEventStreamMode Mode = SseEventStreamMode.Default; + + _eventStreamWriter = await eventStreamStore.CreateStreamAsync(options: new() + { + SessionId = parentTransport.SessionId ?? Guid.NewGuid().ToString("N"), + StreamId = _pendingRequest.Id.ToString()!, + Mode = Mode, + }, cancellationToken).ConfigureAwait(false); + + return _eventStreamWriter; } } diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs index 8a64094e4..486bd62bf 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs @@ -1,5 +1,4 @@ using ModelContextProtocol.Protocol; -using System.IO.Pipelines; using System.Security.Claims; using System.Threading.Channels; @@ -21,21 +20,30 @@ namespace ModelContextProtocol.Server; /// public sealed class StreamableHttpServerTransport : ITransport { + /// + /// The stream ID used for unsolicited messages sent via the standalone GET SSE stream. + /// + public static readonly string UnsolicitedMessageStreamId = "__get__"; + // For JsonRpcMessages without a RelatedTransport, we don't want to block just because the client didn't make a GET request to handle unsolicited messages. - private readonly SseWriter _sseWriter = new(channelOptions: new BoundedChannelOptions(1) + private static readonly BoundedChannelOptions _sseWriterChannelOptions = new(1) { SingleReader = true, SingleWriter = false, FullMode = BoundedChannelFullMode.DropOldest, - }); + }; private readonly Channel _incomingChannel = Channel.CreateBounded(new BoundedChannelOptions(1) { SingleReader = true, SingleWriter = false, }); private readonly CancellationTokenSource _disposeCts = new(); + private readonly SemaphoreSlim _sendLock = new(1, 1); - private int _getRequestStarted; + private SseWriter _sseWriter = new(channelOptions: _sseWriterChannelOptions); + private ISseEventStreamWriter? _eventStreamWriter; + private bool _getRequestStarted; + private bool _disposed; /// public string? SessionId { get; set; } @@ -63,20 +71,73 @@ public sealed class StreamableHttpServerTransport : ITransport /// public Func? OnInitRequestReceived { get; set; } + /// + /// Gets or sets the event store for resumability support. + /// When set, events are stored and can be replayed when clients reconnect with a Last-Event-ID header. + /// + public ISseEventStreamStore? EventStreamStore { get; set; } + + /// + /// Gets or sets the retry interval to suggest to clients in SSE retry field. + /// When is set, the server will include a retry field in priming events. + /// + /// + /// The default value is 1 second. + /// + public TimeSpan RetryInterval { get; set; } = TimeSpan.FromSeconds(1); + + /// + /// Gets or sets the negotiated protocol version for this session. + /// + internal string? NegotiatedProtocolVersion { get; private set; } + /// public ChannelReader MessageReader => _incomingChannel.Reader; internal ChannelWriter MessageWriter => _incomingChannel.Writer; + /// + /// Handles the initialize request by capturing the protocol version and invoking the user callback. + /// + internal async ValueTask HandleInitRequestAsync(InitializeRequestParams? initParams) + { + // Capture the negotiated protocol version for resumability checks + NegotiatedProtocolVersion = initParams?.ProtocolVersion; + + // Invoke user-provided callback if specified + if (OnInitRequestReceived is { } callback) + { + await callback(initParams).ConfigureAwait(false); + } + } + + /// + /// Handles an optional SSE GET request a client using the Streamable HTTP transport might make by + /// writing any unsolicited JSON-RPC messages sent via + /// to the SSE response stream until cancellation is requested or the transport is disposed. + /// + /// The response stream to write MCP JSON-RPC messages as SSE events to. + /// The to monitor for cancellation requests. The default is . + /// A task representing the send loop that writes JSON-RPC messages to the SSE response stream. + public Task HandleGetRequestAsync(Stream sseResponseStream, CancellationToken cancellationToken = default) + => HandleGetRequestAsync(sseResponseStream, eventStreamReader: null, cancellationToken); + /// /// Handles an optional SSE GET request a client using the Streamable HTTP transport might make by /// writing any unsolicited JSON-RPC messages sent via /// to the SSE response stream until cancellation is requested or the transport is disposed. /// /// The response stream to write MCP JSON-RPC messages as SSE events to. + /// The to replay events from before writing this transport's messages to the response stream. /// The to monitor for cancellation requests. The default is . /// A task representing the send loop that writes JSON-RPC messages to the SSE response stream. - public async Task HandleGetRequestAsync(Stream sseResponseStream, CancellationToken cancellationToken = default) + public async Task HandleGetRequestAsync(Stream sseResponseStream, ISseEventStreamReader? eventStreamReader, CancellationToken cancellationToken = default) + { + var writeTask = await StartGetRequestAsync(sseResponseStream, eventStreamReader, cancellationToken).ConfigureAwait(false); + await writeTask.ConfigureAwait(false); + } + + private async Task StartGetRequestAsync(Stream sseResponseStream, ISseEventStreamReader? eventStreamReader, CancellationToken cancellationToken) { Throw.IfNull(sseResponseStream); @@ -85,13 +146,41 @@ public async Task HandleGetRequestAsync(Stream sseResponseStream, CancellationTo throw new InvalidOperationException("GET requests are not supported in stateless mode."); } - if (Interlocked.Exchange(ref _getRequestStarted, 1) == 1) + using var _ = await _sendLock.LockAsync(cancellationToken); + + ThrowIfDisposed(); + + if (_getRequestStarted) { - throw new InvalidOperationException("Session resumption is not yet supported. Please start a new session."); + await _sseWriter.DisposeAsync().ConfigureAwait(false); + _sseWriter = new(); + } + + _getRequestStarted = true; + + if (eventStreamReader is not null) + { + if (eventStreamReader.SessionId != SessionId) + { + throw new InvalidOperationException("The provided SSE event stream reader relates to a different session."); + } + + if (eventStreamReader.StreamId != UnsolicitedMessageStreamId) + { + throw new InvalidOperationException("The event stream reader does not relate to the unsolicited message stream."); + } + + await eventStreamReader.CopyToAsync(sseResponseStream, cancellationToken); + } + + var eventStreamWriter = await GetOrCreateEventStreamAsync(cancellationToken).ConfigureAwait(false); + if (eventStreamWriter is not null) + { + await _sseWriter.SendPrimingEventAsync(RetryInterval, eventStreamWriter, cancellationToken).ConfigureAwait(false); } // We do not need to reference _disposeCts like in HandlePostRequest, because the session ending completes the _sseWriter gracefully. - await _sseWriter.WriteAllAsync(sseResponseStream, cancellationToken).ConfigureAwait(false); + return _sseWriter.WriteAllAsync(sseResponseStream, cancellationToken); } /// @@ -131,13 +220,26 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can throw new InvalidOperationException("Unsolicited server to client messages are not supported in stateless mode."); } - // If the underlying writer has been disposed, just drop the message. - await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); + using var _ = await _sendLock.LockAsync(cancellationToken); + + ThrowIfDisposed(); + + // If the underlying writer has been disposed, rely on the event stream writer, if present. + // Otherwise, just drop the message. + var eventStreamWriter = await GetOrCreateEventStreamAsync(cancellationToken).ConfigureAwait(false); + await _sseWriter.SendMessageAsync(message, eventStreamWriter, cancellationToken).ConfigureAwait(false); } /// public async ValueTask DisposeAsync() { + using var _ = await _sendLock.LockAsync(); + + if (_disposed) + { + return; + } + try { _incomingChannel.Writer.TryComplete(); @@ -148,11 +250,55 @@ public async ValueTask DisposeAsync() try { await _sseWriter.DisposeAsync().ConfigureAwait(false); + + if (_eventStreamWriter is not null) + { + await _eventStreamWriter.DisposeAsync().ConfigureAwait(false); + } } finally { _disposeCts.Dispose(); + _disposed = true; } } } + + private async ValueTask GetOrCreateEventStreamAsync(CancellationToken cancellationToken) + { + if (_eventStreamWriter is not null) + { + return _eventStreamWriter; + } + + if (EventStreamStore is null || !McpSessionHandler.SupportsPrimingEvent(NegotiatedProtocolVersion)) + { + return null; + } + + // We set the mode to 'Polling' so that the transport can take over writing to the response stream after + // messages have been replayed. + const SseEventStreamMode Mode = SseEventStreamMode.Polling; + + _eventStreamWriter = await EventStreamStore.CreateStreamAsync(options: new() + { + SessionId = SessionId ?? Guid.NewGuid().ToString("N"), + StreamId = UnsolicitedMessageStreamId, + Mode = Mode, + }, cancellationToken).ConfigureAwait(false); + + return _eventStreamWriter; + } + + private void ThrowIfDisposed() + { +#if NET + ObjectDisposedException.ThrowIf(_disposed, this); +#else + if (_disposed) + { + throw new ObjectDisposedException(GetType().Name); + } +#endif + } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs index a843e2975..91af874ef 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs @@ -85,13 +85,10 @@ public async Task EnablingStatelessMode_Disables_SseEndpoints() } [Fact] - public async Task EnablingStatelessMode_Disables_GetAndDeleteEndpoints() + public async Task EnablingStatelessMode_Disables_DeleteEndpoint() { await StartAsync(); - using var getResponse = await HttpClient.GetAsync("/", TestContext.Current.CancellationToken); - Assert.Equal(HttpStatusCode.MethodNotAllowed, getResponse.StatusCode); - using var deleteResponse = await HttpClient.DeleteAsync("/", TestContext.Current.CancellationToken); Assert.Equal(HttpStatusCode.MethodNotAllowed, deleteResponse.StatusCode); } diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs index b400c6b0b..e5f55685b 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs @@ -107,7 +107,7 @@ public async Task CreateSamplingHandler_ShouldHandleTextMessages(float? temperat { Messages = [ - new SamplingMessage + new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "Hello" }] @@ -157,7 +157,7 @@ public async Task CreateSamplingHandler_ShouldHandleImageMessages() { Messages = [ - new SamplingMessage + new SamplingMessage { Role = Role.User, Content = [new ImageContentBlock @@ -492,7 +492,7 @@ public async Task AsClientLoggerProvider_MessagesSentToClient() public async Task ReturnsNegotiatedProtocolVersion(string? protocolVersion) { await using McpClient client = await CreateMcpClientForServer(new() { ProtocolVersion = protocolVersion }); - Assert.Equal(protocolVersion ?? "2025-06-18", client.NegotiatedProtocolVersion); + Assert.Equal(protocolVersion ?? "2025-11-25", client.NegotiatedProtocolVersion); } [Fact] @@ -500,7 +500,7 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn { int getWeatherToolCallCount = 0; int askClientToolCallCount = 0; - + Server.ServerOptions.ToolCollection?.Add(McpServerTool.Create( async (McpServer server, string query, CancellationToken cancellationToken) => { @@ -513,14 +513,14 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn return $"Weather in {location}: sunny, 22°C"; }, "get_weather", "Gets the weather for a location"); - + var response = await server .AsSamplingChatClient() .AsBuilder() .UseFunctionInvocation() .Build() .GetResponseAsync(query, new ChatOptions { Tools = [weatherTool] }, cancellationToken); - + return response.Text ?? "No response"; }, new() { Name = "ask_client", Description = "Asks the client a question using sampling" })); @@ -530,7 +530,7 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn { int currentCall = samplingCallCount++; var lastMessage = messages.LastOrDefault(); - + // First call: Return a tool call request for get_weather if (currentCall == 0) { @@ -552,7 +552,7 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn string resultText = toolResult.Result?.ToString() ?? string.Empty; Assert.Contains("Weather in Paris: sunny", resultText); - + return Task.FromResult(new([ new ChatMessage(ChatRole.User, messages.First().Contents), new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call_weather_123", "get_weather", new Dictionary { ["location"] = "Paris" })]), @@ -577,7 +577,7 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Null(result.IsError); - + var textContent = result.Content.OfType().FirstOrDefault(); Assert.NotNull(textContent); Assert.Contains("Weather in Paris: sunny, 22", textContent.Text); @@ -585,7 +585,7 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn Assert.Equal(1, askClientToolCallCount); Assert.Equal(2, samplingCallCount); } - + /// Simple test IChatClient implementation for testing. private sealed class TestChatClient(Func, ChatOptions?, CancellationToken, Task> getResponse) : IChatClient { @@ -594,7 +594,7 @@ public Task GetResponseAsync( ChatOptions? options = null, CancellationToken cancellationToken = default) => getResponse(messages, options, cancellationToken); - + async IAsyncEnumerable IChatClient.GetStreamingResponseAsync( IEnumerable messages, ChatOptions? options, @@ -606,7 +606,7 @@ async IAsyncEnumerable IChatClient.GetStreamingResponseAsync yield return update; } } - + object? IChatClient.GetService(Type serviceType, object? serviceKey) => null; void IDisposable.Dispose() { } } From 11494e1e5808109976bec2ebeb63368c1c4d416d Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Tue, 16 Dec 2025 16:11:27 -0800 Subject: [PATCH 04/10] Server-side disconnect --- .../Server/RequestContext.cs | 17 +++++++++++++++++ .../Server/StreamableHttpPostTransport.cs | 19 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/src/ModelContextProtocol.Core/Server/RequestContext.cs b/src/ModelContextProtocol.Core/Server/RequestContext.cs index 6f1bc8566..c77e7fb6c 100644 --- a/src/ModelContextProtocol.Core/Server/RequestContext.cs +++ b/src/ModelContextProtocol.Core/Server/RequestContext.cs @@ -81,4 +81,21 @@ public McpServer Server /// including the method name, parameters, request ID, and associated transport and user information. /// public JsonRpcRequest JsonRpcRequest { get; } + + /// + /// Ends the current response and enables polling for updates from the server. + /// + /// The interval at which the client should poll for updates. + /// The cancellation token. + /// A that completes when polling has been enabled. + public async ValueTask EnablePollingAsync(TimeSpan retryInterval, CancellationToken cancellationToken = default) + { + if (JsonRpcRequest.Context?.RelatedTransport is not StreamableHttpPostTransport transport) + { + // Polling is only supported for Streamable HTTP POST transports. + return; + } + + await transport.EnablePollingAsync(retryInterval, cancellationToken); + } } diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs index 755eb3d99..56a36e717 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs @@ -102,6 +102,25 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can } } + public async ValueTask EnablePollingAsync(TimeSpan retryInterval, CancellationToken cancellationToken) + { + var eventStreamWriter = await GetOrCreateEventStreamAsync(cancellationToken).ConfigureAwait(false); + if (eventStreamWriter is null) + { + return; + } + + // Set the mode to 'Polling' so that the replay stream ends as soon as all available messages have been sent. + // This prevents the client from immediately establishing another long-lived connection. + await eventStreamWriter.SetModeAsync(SseEventStreamMode.Polling, cancellationToken).ConfigureAwait(false); + + // Send the priming event with the new retry interval. + await _sseWriter.SendPrimingEventAsync(retryInterval, eventStreamWriter, cancellationToken).ConfigureAwait(false); + + // Dispose the writer to close it and force future writes to only apply to the SSE event store. + await _sseWriter.DisposeAsync(); + } + public async ValueTask DisposeAsync() { await _sseWriter.DisposeAsync().ConfigureAwait(false); From 493062a737735dbd3babd497955225738c9bc3b5 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Tue, 16 Dec 2025 16:11:50 -0800 Subject: [PATCH 05/10] Integration tests --- .../ResumabilityIntegrationTests.cs | 459 ++++++++++++++++++ .../Utils/FaultingStreamHandler.cs | 157 ++++++ .../Utils/KestrelInMemoryTest.cs | 30 +- 3 files changed, 645 insertions(+), 1 deletion(-) create mode 100644 tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs create mode 100644 tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs new file mode 100644 index 000000000..2f6b31779 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs @@ -0,0 +1,459 @@ +using System.ComponentModel; +using System.Diagnostics; +using System.Net; +using System.Net.ServerSentEvents; +using System.Text; +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.AspNetCore.Tests; + +/// +/// Integration tests for SSE resumability with full client-server flow. +/// These tests use McpClient for end-to-end testing and only use raw HTTP +/// for SSE format verification where McpClient abstracts away the details. +/// +public class ResumabilityIntegrationTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) +{ + private const string InitializeRequest = """ + {"jsonrpc":"2.0","id":"1","method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{},"clientInfo":{"name":"TestClient","version":"1.0.0"}}} + """; + + [Fact] + public async Task Server_StoresEvents_WhenEventStoreConfigured() + { + // Arrange + var eventStreamStore = new TestSseEventStreamStore(); + await using var app = await CreateServerAsync(eventStreamStore); + await using var client = await ConnectClientAsync(); + + // Act - Make a tool call which generates events + var result = await client.CallToolAsync("echo", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + // Assert - Events were stored + Assert.NotNull(result); + Assert.True(eventStreamStore.StoreEventCallCount > 0, "Expected events to be stored when EventStore is configured"); + } + + [Fact] + public async Task Server_StoresMultipleEvents_ForMultipleToolCalls() + { + // Arrange + var eventStreamStore = new TestSseEventStreamStore(); + await using var app = await CreateServerAsync(eventStreamStore); + await using var client = await ConnectClientAsync(); + + // Act - Make multiple tool calls + var initialCount = eventStreamStore.StoreEventCallCount; + + await client.CallToolAsync("echo", + new Dictionary { ["message"] = "test1" }, + cancellationToken: TestContext.Current.CancellationToken); + + var countAfterFirst = eventStreamStore.StoreEventCallCount; + + await client.CallToolAsync("echo", + new Dictionary { ["message"] = "test2" }, + cancellationToken: TestContext.Current.CancellationToken); + + var countAfterSecond = eventStreamStore.StoreEventCallCount; + + // Assert - More events were stored for each call + Assert.True(countAfterFirst > initialCount, "Expected more events after first call"); + Assert.True(countAfterSecond > countAfterFirst, "Expected more events after second call"); + } + + [Fact] + public async Task Client_CanMakeMultipleRequests_WithResumabilityEnabled() + { + // Arrange + var eventStreamStore = new TestSseEventStreamStore(); + await using var app = await CreateServerAsync(eventStreamStore); + await using var client = await ConnectClientAsync(); + + // Act - Make many requests to verify stability + for (int i = 0; i < 5; i++) + { + var result = await client.CallToolAsync("echo", + new Dictionary { ["message"] = $"test{i}" }, + cancellationToken: TestContext.Current.CancellationToken); + + var textContent = Assert.Single(result.Content.OfType()); + Assert.Equal($"Echo: test{i}", textContent.Text); + } + + // Assert - All requests succeeded and events were stored + Assert.True(eventStreamStore.StoreEventCallCount >= 5, "Expected events to be stored for each request"); + } + + [Fact] + public async Task Ping_WorksWithResumabilityEnabled() + { + // Arrange + var eventStreamStore = new TestSseEventStreamStore(); + await using var app = await CreateServerAsync(eventStreamStore); + await using var client = await ConnectClientAsync(); + + // Act & Assert - Ping should work + await client.PingAsync(cancellationToken: TestContext.Current.CancellationToken); + } + + [Fact] + public async Task ListTools_WorksWithResumabilityEnabled() + { + // Arrange + var eventStreamStore = new TestSseEventStreamStore(); + await using var app = await CreateServerAsync(eventStreamStore); + await using var client = await ConnectClientAsync(); + + // Act + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(tools); + Assert.Single(tools); + } + + [Fact] + public async Task Server_IncludesEventIdAndRetry_InSseResponse() + { + // Arrange + var expectedRetryInterval = TimeSpan.FromSeconds(5); + var eventStreamStore = new TestSseEventStreamStore(); + await using var app = await CreateServerAsync(eventStreamStore, retryInterval: expectedRetryInterval); + + // Act + var sseResponse = await SendInitializeAndReadSseResponseAsync(InitializeRequest); + + // Assert - Event IDs and retry field should be present in the response + Assert.True(sseResponse.LastEventId is not null, "Expected SSE response to contain event IDs"); + Assert.Equal(expectedRetryInterval, sseResponse.RetryInterval); + } + + [Fact] + public async Task Server_WithoutEventStore_DoesNotIncludeEventIdAndRetry() + { + // Arrange - Server without event store + await using var app = await CreateServerAsync(); + + // Act + var sseResponse = await SendInitializeAndReadSseResponseAsync(InitializeRequest); + + // Assert - No event IDs or retry field when EventStore is not configured + Assert.True(sseResponse.LastEventId is null, "Did not expect event IDs when EventStore is not configured"); + Assert.True(sseResponse.RetryInterval is null, "Did not expect retry field when EventStore is not configured"); + } + + [Fact] + public async Task Server_DoesNotSendPrimingEvents_ToOlderProtocolVersionClients() + { + // Arrange - Server with resumability enabled + var eventStreamStore = new TestSseEventStreamStore(); + await using var app = await CreateServerAsync(eventStreamStore, retryInterval: TimeSpan.FromSeconds(5)); + + // Use an older protocol version that doesn't support resumability + const string OldProtocolInitRequest = """ + {"jsonrpc":"2.0","id":"1","method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"OldClient","version":"1.0.0"}}} + """; + + var sseResponse = await SendInitializeAndReadSseResponseAsync(OldProtocolInitRequest); + + // Assert - Old clients should not receive event IDs or retry fields (no priming events) + Assert.True(sseResponse.LastEventId is null, "Old protocol clients should not receive event IDs"); + Assert.True(sseResponse.RetryInterval is null, "Old protocol clients should not receive retry field"); + + // Event store should not have been called for old clients + Assert.Equal(0, eventStreamStore.StoreEventCallCount); + } + + [Fact] + public async Task Client_ReceivesRetryInterval_FromServer() + { + // Arrange - Server with specific retry interval + var expectedRetry = TimeSpan.FromMilliseconds(3000); + var eventStreamStore = new TestSseEventStreamStore(); + await using var app = await CreateServerAsync(eventStreamStore, retryInterval: expectedRetry); + + // Act - Send initialize and read the retry field + var sseItem = await SendInitializeAndReadSseResponseAsync(InitializeRequest); + + // Assert - Client receives the retry interval from server + Assert.Equal(expectedRetry, sseItem.RetryInterval); + } + + [Fact] + public async Task Client_CanPollResponse_FromServer() + { + const string ProgressToolName = "progress_tool"; + var clientReceivedInitialValueTcs = new TaskCompletionSource(); + var clientReceivedPolledValueTcs = new TaskCompletionSource(); + var progressTool = McpServerTool.Create(async (RequestContext context, IProgress progress) => + { + progress.Report(new() { Progress = 0, Message = "Initial value" }); + + await clientReceivedInitialValueTcs.Task; + + await context.EnablePollingAsync(retryInterval: TimeSpan.FromSeconds(1)); + + progress.Report(new() { Progress = 50, Message = "Polled value" }); + + await clientReceivedPolledValueTcs.Task; + + return "Complete"; + }, options: new() { Name = ProgressToolName }); + var eventStreamStore = new TestSseEventStreamStore(); + await using var app = await CreateServerAsync(eventStreamStore, configureServer: builder => + { + builder.WithTools([progressTool]); + }); + await using var client = await ConnectClientAsync(); + + var progressHandler = new Progress(value => + { + switch (value.Message) + { + case "Initial value": + Assert.True(clientReceivedInitialValueTcs.TrySetResult(), "Received the initial value more than once."); + break; + case "Polled value": + Assert.True(clientReceivedPolledValueTcs.TrySetResult(), "Received the polled value more than once."); + break; + default: + throw new UnreachableException($"Unknown progress message '{value.Message}'"); + } + }); + + var result = await client.CallToolAsync(ProgressToolName, progress: progressHandler, cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(result.IsError is true); + Assert.Equal("Complete", result.Content.OfType().Single().Text); + } + + [Fact] + public async Task Client_CanResumePostResponseStream_AfterDisconnection() + { + var manualStreamHandler = new FaultingStreamHandler(); + SetHttpMessageHandler(manualStreamHandler); + + const string ProgressToolName = "progress_tool"; + var clientReceivedInitialValueTcs = new TaskCompletionSource(); + var clientReceivedReconnectValueTcs = new TaskCompletionSource(); + var progressTool = McpServerTool.Create(async (RequestContext context, IProgress progress, CancellationToken cancellationToken) => + { + progress.Report(new() { Progress = 0, Message = "Initial value" }); + + // Make sure the client receives one message before we disconnect. + await clientReceivedInitialValueTcs.Task; + + // Simulate a network disconnection by faulting the response stream. + var reconnectAttempt = await manualStreamHandler.TriggerFaultAsync(TestContext.Current.CancellationToken); + + // Send another message that the client should receive after reconnecting. + progress.Report(new() { Progress = 50, Message = "Reconnect value" }); + + reconnectAttempt.Continue(); + + // Wait for the client to receive the message via replay. + await clientReceivedReconnectValueTcs.Task; + + // Return the final result with the client still connected. + return "Complete"; + }, options: new() { Name = ProgressToolName }); + var eventStreamStore = new TestSseEventStreamStore(); + await using var app = await CreateServerAsync(eventStreamStore, configureServer: builder => + { + builder.WithTools([progressTool]); + }); + await using var client = await ConnectClientAsync(); + + var progressHandler = new Progress(value => + { + switch (value.Message) + { + case "Initial value": + clientReceivedInitialValueTcs.SetResult(); + break; + case "Reconnect value": + clientReceivedReconnectValueTcs.SetResult(); + break; + default: + throw new UnreachableException($"Unknown progress message '{value.Message}'"); + } + }); + + var result = await client.CallToolAsync(ProgressToolName, progress: progressHandler, cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(result.IsError is true); + Assert.Equal("Complete", result.Content.OfType().Single().Text); + } + + [Fact] + public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection() + { + var faultingStreamHandler = new FaultingStreamHandler(); + SetHttpMessageHandler(faultingStreamHandler); + + var eventStreamStore = new TestSseEventStreamStore(); + + // Capture the server instance via RunSessionHandler + var serverTcs = new TaskCompletionSource(); + + await using var app = await CreateServerAsync(eventStreamStore, configureTransport: options => + { + options.RunSessionHandler = (httpContext, mcpServer, cancellationToken) => + { + serverTcs.TrySetResult(mcpServer); + return mcpServer.RunAsync(cancellationToken); + }; + }); + + await using var client = await ConnectClientAsync(); + + // Get the server instance + var server = await serverTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + + // Set up notification tracking + var clientReceivedInitialNotificationTcs = new TaskCompletionSource(); + var clientReceivedReplayedNotificationTcs = new TaskCompletionSource(); + var clientReceivedReconnectNotificationTcs = new TaskCompletionSource(); + + const string CustomNotificationMethod = "test/custom_notification"; + + await using var _ = client.RegisterNotificationHandler(CustomNotificationMethod, (notification, cancellationToken) => + { + // First notification completes initial TCS, second completes replay TCS, thirdly completes reconnect TCS. + if (clientReceivedInitialNotificationTcs.TrySetResult()) + { + return default; + } + + if (clientReceivedReplayedNotificationTcs.TrySetResult()) + { + return default; + } + + if (clientReceivedReconnectNotificationTcs.TrySetResult()) + { + return default; + } + + return default; + }); + + // Send a custom notification to the client on the unsolicited message stream + await server.SendNotificationAsync(CustomNotificationMethod, TestContext.Current.CancellationToken); + + // Wait for client to receive the first notification + await clientReceivedInitialNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + + // Fault the unsolicited message stream (GET SSE) + var reconnectAttempt = await faultingStreamHandler.TriggerFaultAsync(TestContext.Current.CancellationToken); + + // Send another notification while the client is disconnected - this should be stored + await server.SendNotificationAsync(CustomNotificationMethod, TestContext.Current.CancellationToken); + + // Allow the client to reconnect + reconnectAttempt.Continue(); + + // Wait for client to receive the notification via replay + await clientReceivedReplayedNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + + // Send a final notification while the client has reconnected - this should be handled by the transport + await server.SendNotificationAsync(CustomNotificationMethod, TestContext.Current.CancellationToken); + + // Wait for the client to receive the final notification + await clientReceivedReconnectNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + } + + [McpServerToolType] + private class ResumabilityTestTools + { + [McpServerTool(Name = "echo"), Description("Echoes the message back")] + public static string Echo(string message) => $"Echo: {message}"; + } + + private async Task CreateServerAsync( + ISseEventStreamStore? eventStreamStore = null, + TimeSpan? retryInterval = null, + Action? configureServer = null, + Action? configureTransport = null) + { + var serverBuilder = Builder.Services.AddMcpServer() + .WithHttpTransport(options => + { + options.EventStreamStore = eventStreamStore; + if (retryInterval.HasValue) + { + options.RetryInterval = retryInterval.Value; + } + configureTransport?.Invoke(options); + }) + .WithTools(); + + configureServer?.Invoke(serverBuilder); + + var app = Builder.Build(); + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + return app; + } + + private async Task ConnectClientAsync() + { + var transport = new HttpClientTransport(new HttpClientTransportOptions + { + Endpoint = new Uri("http://localhost:5000/"), + TransportMode = HttpTransportMode.StreamableHttp, + }, HttpClient, LoggerFactory); + + return await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, + cancellationToken: TestContext.Current.CancellationToken); + } + + private async Task SendInitializeAndReadSseResponseAsync(string initializeRequest) + { + using var requestContent = new StringContent(initializeRequest, Encoding.UTF8, "application/json"); + using var request = new HttpRequestMessage(HttpMethod.Post, "/") + { + Headers = + { + Accept = { new("application/json"), new("text/event-stream") } + }, + Content = requestContent, + }; + + var response = await HttpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, + TestContext.Current.CancellationToken); + + response.EnsureSuccessStatusCode(); + + var sseResponse = new SseResponse(); + await using var stream = await response.Content.ReadAsStreamAsync(TestContext.Current.CancellationToken); + await foreach (var sseItem in SseParser.Create(stream).EnumerateAsync(TestContext.Current.CancellationToken)) + { + if (!string.IsNullOrEmpty(sseItem.EventId)) + { + sseResponse.LastEventId = sseItem.EventId; + } + if (sseItem.ReconnectionInterval.HasValue) + { + sseResponse.RetryInterval = sseItem.ReconnectionInterval.Value; + } + } + + return sseResponse; + } + + private struct SseResponse + { + public string? LastEventId { get; set; } + public TimeSpan? RetryInterval { get; set; } + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs new file mode 100644 index 000000000..c42fe082d --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs @@ -0,0 +1,157 @@ +using System.Diagnostics; +using System.Net; + +namespace ModelContextProtocol.AspNetCore.Tests.Utils; + +/// +/// A message handler that wraps SSE response streams and can trigger faults mid-stream +/// to simulate network disconnections during SSE streaming. +/// +internal sealed class FaultingStreamHandler : DelegatingHandler +{ + private FaultingStream? _lastStream; + private TaskCompletionSource? _reconnectTcs; + + public async Task TriggerFaultAsync(CancellationToken cancellationToken) + { + if (_lastStream is null or { IsDisposed: true }) + { + throw new InvalidOperationException("There is no active response stream to fault."); + } + + if (_reconnectTcs is not null) + { + throw new InvalidOperationException("Cannot trigger a fault while already waiting for reconnection."); + } + + _reconnectTcs = new(); + await _lastStream.TriggerFaultAsync(cancellationToken); + + return new(_reconnectTcs); + } + + public sealed class ReconnectAttempt(TaskCompletionSource reconnectTcs) + { + public void Continue() + => reconnectTcs.SetResult(); + } + + protected override async Task SendAsync( + HttpRequestMessage request, CancellationToken cancellationToken) + { + if (_reconnectTcs is not null && request.Headers.Accept.Contains(new("text/event-stream"))) + { + // If we're blocking reconnection, wait until we're allowed to continue. + await _reconnectTcs.Task.WaitAsync(cancellationToken); + _reconnectTcs = null; + } + + var response = await base.SendAsync(request, cancellationToken); + + // Only wrap SSE streams (text/event-stream) + if (response.Content.Headers.ContentType?.MediaType == "text/event-stream") + { + var originalStream = await response.Content.ReadAsStreamAsync(cancellationToken); + _lastStream = new FaultingStream(originalStream); + var faultingContent = new FaultingStreamContent(_lastStream); + + // Copy headers from original content + var newContent = faultingContent; + foreach (var header in response.Content.Headers) + { + newContent.Headers.TryAddWithoutValidation(header.Key, header.Value); + } + + response.Content = newContent; + } + + return response; + } + + private sealed class FaultingStreamContent(FaultingStream stream) : HttpContent + { + private readonly FaultingStream _manualStream = new(stream); + + protected override Task SerializeToStreamAsync(Stream stream, TransportContext? context) + => throw new NotSupportedException(); + + protected override Task CreateContentReadStreamAsync() + => Task.FromResult(stream); + + protected override bool TryComputeLength(out long length) + { + length = -1; + return false; + } + } + + private sealed class FaultingStream(Stream innerStream) : Stream + { + private readonly CancellationTokenSource _cts = new(); + private TaskCompletionSource? _faultTcs; + private bool _disposed; + + public bool IsDisposed => _disposed; + + public async Task TriggerFaultAsync(CancellationToken cancellationToken) + { + if (_faultTcs is not null) + { + throw new InvalidOperationException("Only one fault can be triggered per stream."); + } + + _faultTcs = new TaskCompletionSource(); + + await _cts.CancelAsync(); + await _faultTcs.Task.WaitAsync(cancellationToken); + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + try + { + _cts.Token.ThrowIfCancellationRequested(); + + using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _cts.Token); + return await innerStream.ReadAsync(buffer, linkedCts.Token); + } + catch (OperationCanceledException) when (_cts.IsCancellationRequested) + { + Debug.Assert(_faultTcs is not null); + + if (!_faultTcs.TrySetResult()) + { + throw new InvalidOperationException("Attempted to read an already-faulted stream."); + } + + throw new IOException("Simulated network disconnection."); + } + } + + public override int Read(byte[] buffer, int offset, int count) + => throw new NotSupportedException("Synchronous reads are not supported."); + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask(); + + public override bool CanRead => innerStream.CanRead; + public override bool CanSeek => innerStream.CanSeek; + public override bool CanWrite => innerStream.CanWrite; + public override long Length => innerStream.Length; + public override long Position { get => innerStream.Position; set => innerStream.Position = value; } + public override void Flush() => innerStream.Flush(); + public override long Seek(long offset, SeekOrigin origin) => innerStream.Seek(offset, origin); + public override void SetLength(long value) => innerStream.SetLength(value); + public override void Write(byte[] buffer, int offset, int count) => innerStream.Write(buffer, offset, count); + protected override void Dispose(bool disposing) + { + if (!disposing || _disposed) + { + return; + } + + _disposed = true; + innerStream.Dispose(); + } + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs index 0045e7bc2..43324fb95 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs @@ -8,6 +8,8 @@ namespace ModelContextProtocol.AspNetCore.Tests.Utils; public class KestrelInMemoryTest : LoggedTest { + private readonly TestHttpMessageHandler _httpMessageHandler; + public KestrelInMemoryTest(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { @@ -24,7 +26,9 @@ public KestrelInMemoryTest(ITestOutputHelper testOutputHelper) return new(connection.ClientStream); }; - HttpClient = new HttpClient(SocketsHttpHandler) + _httpMessageHandler = new(SocketsHttpHandler); + + HttpClient = new HttpClient(_httpMessageHandler) { BaseAddress = new Uri("http://localhost:5000/"), Timeout = TimeSpan.FromSeconds(10), @@ -39,9 +43,33 @@ public KestrelInMemoryTest(ITestOutputHelper testOutputHelper) public KestrelInMemoryTransport KestrelInMemoryTransport { get; } = new(); + protected void SetHttpMessageHandler(DelegatingHandler? handler) + { + if (handler is null) + { + _httpMessageHandler.InnerHandler = SocketsHttpHandler; + } + else + { + _httpMessageHandler.InnerHandler = handler; + handler.InnerHandler = SocketsHttpHandler; + } + } + public override void Dispose() { HttpClient.Dispose(); base.Dispose(); } + + private sealed class TestHttpMessageHandler : DelegatingHandler + { + public TestHttpMessageHandler(HttpMessageHandler innerHandler) + { + InnerHandler = innerHandler; + } + + protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + => base.SendAsync(request, cancellationToken); + } } From fe71286913a967b328de1783cc14796685d6cd54 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Tue, 16 Dec 2025 16:28:02 -0800 Subject: [PATCH 06/10] Make `EnablePollingAsync` no-op in stateless mode --- .../Server/StreamableHttpPostTransport.cs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs index 56a36e717..f5f6672b7 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs @@ -104,6 +104,12 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can public async ValueTask EnablePollingAsync(TimeSpan retryInterval, CancellationToken cancellationToken) { + if (parentTransport.Stateless) + { + // Polling is currently not supported in stateless mode. + return; + } + var eventStreamWriter = await GetOrCreateEventStreamAsync(cancellationToken).ConfigureAwait(false); if (eventStreamWriter is null) { From 66fcdb13dae0e02237df494087c79c18d2fcc256 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Thu, 18 Dec 2025 11:10:42 -0800 Subject: [PATCH 07/10] Update src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs Co-authored-by: Stephen Halter --- .../Client/HttpClientTransportOptions.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs b/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs index feff5bf37..c6a6e546d 100644 --- a/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs +++ b/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs @@ -110,7 +110,7 @@ public required Uri Endpoint public ClientOAuthOptions? OAuth { get; set; } /// - /// Gets or sets the maximum number of reconnection attempts when an SSE stream is disconnected. + /// Gets or sets the maximum number of consecutive reconnection attempts when an SSE stream is disconnected. /// /// /// The maximum number of reconnection attempts. The default is 2. From c9044c8d33b0d8828a720bcb48d6a6eeec74db04 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Thu, 18 Dec 2025 13:32:49 -0800 Subject: [PATCH 08/10] PR feedback: Don't dispose `_sseWriter` on final message --- .../StreamableHttpHandler.cs | 2 +- .../Server/SseWriter.cs | 7 +++++ .../Server/StreamableHttpPostTransport.cs | 31 ++++++++++++++----- 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index a89f5a542..7925341f5 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -140,7 +140,7 @@ private async Task HandleUnsolicitedMessageStreamAsync(HttpContext context, Stre { await WriteJsonRpcErrorAsync(context, "Bad Request: Unsolicited messages are not supported in stateless mode.", - StatusCodes.Status400BadRequest); + StatusCodes.Status405MethodNotAllowed); return; } diff --git a/src/ModelContextProtocol.Core/Server/SseWriter.cs b/src/ModelContextProtocol.Core/Server/SseWriter.cs index 7ccd45282..4d9803788 100644 --- a/src/ModelContextProtocol.Core/Server/SseWriter.cs +++ b/src/ModelContextProtocol.Core/Server/SseWriter.cs @@ -22,6 +22,8 @@ internal sealed class SseWriter(string? messageEndpoint = null, BoundedChannelOp private readonly SemaphoreSlim _disposeLock = new(1, 1); private bool _disposed; + public Func>, CancellationToken, IAsyncEnumerable>>? MessageFilter { get; set; } + public Task WriteAllAsync(Stream sseResponseStream, CancellationToken cancellationToken) { Throw.IfNull(sseResponseStream); @@ -36,6 +38,11 @@ public Task WriteAllAsync(Stream sseResponseStream, CancellationToken cancellati _writeCancellationToken = cancellationToken; var messages = _messages.Reader.ReadAllAsync(cancellationToken); + if (MessageFilter is not null) + { + messages = MessageFilter(messages, cancellationToken); + } + _writeTask = SseFormatter.WriteAsync(messages, sseResponseStream, WriteJsonRpcMessageToBuffer, cancellationToken); return _writeTask; } diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs index f5f6672b7..1e8fc0901 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs @@ -1,5 +1,7 @@ using ModelContextProtocol.Protocol; using System.Diagnostics; +using System.Net.ServerSentEvents; +using System.Runtime.CompilerServices; using System.Text.Json; using System.Threading.Channels; @@ -58,6 +60,7 @@ public async ValueTask HandlePostAsync(JsonRpcMessage message, Cancellatio // Start the write task immediately so that we don't risk filling up the channel with // messages before they start being consumed. + _sseWriter.MessageFilter = StopOnFinalResponseFilter; var writeTask = _sseWriter.WriteAllAsync(responseStream, cancellationToken); var eventStreamWriter = await GetOrCreateEventStreamAsync(cancellationToken).ConfigureAwait(false); @@ -90,15 +93,10 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can await parentTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); } - if (message is JsonRpcResponse or JsonRpcError && ((JsonRpcMessageWithId)message).Id == _pendingRequest) + if (_eventStreamWriter is not null && IsFinalResponse(message)) { - // Complete the SSE response stream and SSE event stream writer now that all pending requests have been processed. - await _sseWriter.DisposeAsync().ConfigureAwait(false); - - if (_eventStreamWriter is not null) - { - await _eventStreamWriter.DisposeAsync().ConfigureAwait(false); - } + // Complete the SSE event stream writer now that all pending requests have been processed. + await _eventStreamWriter.DisposeAsync().ConfigureAwait(false); } } @@ -135,6 +133,23 @@ public async ValueTask DisposeAsync() // after disposal. } + private async IAsyncEnumerable> StopOnFinalResponseFilter(IAsyncEnumerable> messages, [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (var message in messages.WithCancellation(cancellationToken)) + { + yield return message; + + if (IsFinalResponse(message.Data)) + { + // Complete the SSE response stream now that all pending requests have been processed. + break; + } + } + } + + private bool IsFinalResponse(JsonRpcMessage? message) + => (message is JsonRpcResponse or JsonRpcError) && ((JsonRpcMessageWithId)message).Id == _pendingRequest; + private async ValueTask GetOrCreateEventStreamAsync(CancellationToken cancellationToken) { using var _ = await _eventStreamLock.LockAsync(cancellationToken).ConfigureAwait(false); From 2f19044f3e4de2691eb03b6698561abcd9df2c91 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Thu, 18 Dec 2025 14:09:57 -0800 Subject: [PATCH 09/10] PR feedback: Use `SseEventStreamMode.Default` for unsolicited messages --- .../StreamableHttpHandler.cs | 86 +++++++++---------- .../Server/StreamableHttpServerTransport.cs | 84 ++++-------------- 2 files changed, 59 insertions(+), 111 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 7925341f5..20cce477b 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -82,59 +82,59 @@ await WriteJsonRpcErrorAsync(context, return; } - StreamableHttpSession? session = null; - ISseEventStreamReader? eventStreamReader = null; - var sessionId = context.Request.Headers[McpSessionIdHeaderName].ToString(); - var lastEventId = context.Request.Headers[LastEventIdHeaderName].ToString(); - - if (!string.IsNullOrEmpty(sessionId)) + var session = await GetSessionAsync(context, sessionId); + if (session is null) { - session = await GetSessionAsync(context, sessionId); - if (session is null) - { - // There was an error obtaining the session; consider the request failed. - return; - } + return; } + var lastEventId = context.Request.Headers[LastEventIdHeaderName].ToString(); if (!string.IsNullOrEmpty(lastEventId)) { - if (HttpServerTransportOptions.Stateless) - { - await WriteJsonRpcErrorAsync(context, - "Bad Request: The Last-Event-ID header is not supported in stateless mode.", - StatusCodes.Status400BadRequest); - return; - } - - eventStreamReader = await GetEventStreamReaderAsync(context, lastEventId); - if (eventStreamReader is null) - { - // There was an error obtaining the event stream; consider the request failed. - return; - } + await HandleResumedStreamAsync(context, session, lastEventId); + } + else + { + await HandleUnsolicitedMessageStreamAsync(context, session); } + } - if (session is not null && eventStreamReader is not null && !string.Equals(session.Id, eventStreamReader.SessionId, StringComparison.Ordinal)) + private async Task HandleResumedStreamAsync(HttpContext context, StreamableHttpSession session, string lastEventId) + { + if (HttpServerTransportOptions.Stateless) { await WriteJsonRpcErrorAsync(context, - "Bad Request: The Last-Event-ID header refers to a session with a different session ID.", + "Bad Request: The Last-Event-ID header is not supported in stateless mode.", StatusCodes.Status400BadRequest); return; } - if (eventStreamReader is null || string.Equals(eventStreamReader.StreamId, StreamableHttpServerTransport.UnsolicitedMessageStreamId, StringComparison.Ordinal)) + var eventStreamReader = await GetEventStreamReaderAsync(context, lastEventId); + if (eventStreamReader is null) { - await HandleUnsolicitedMessageStreamAsync(context, session, eventStreamReader); + // There was an error obtaining the event stream; consider the request failed. + return; } - else + + if (!string.Equals(session.Id, eventStreamReader.SessionId, StringComparison.Ordinal)) { - await HandleResumePostResponseStreamAsync(context, eventStreamReader); + await WriteJsonRpcErrorAsync(context, + "Bad Request: The Last-Event-ID header refers to a session with a different session ID.", + StatusCodes.Status400BadRequest); + return; } + + using var sseCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, hostApplicationLifetime.ApplicationStopping); + var cancellationToken = sseCts.Token; + + await using var _ = await session.AcquireReferenceAsync(cancellationToken); + + InitializeSseResponse(context); + await eventStreamReader.CopyToAsync(context.Response.Body, context.RequestAborted); } - private async Task HandleUnsolicitedMessageStreamAsync(HttpContext context, StreamableHttpSession? session, ISseEventStreamReader? eventStreamReader) + private async Task HandleUnsolicitedMessageStreamAsync(HttpContext context, StreamableHttpSession session) { if (HttpServerTransportOptions.Stateless) { @@ -144,18 +144,10 @@ await WriteJsonRpcErrorAsync(context, return; } - if (session is null) - { - await WriteJsonRpcErrorAsync(context, - "Bad Request: Mcp-Session-Id header is required", - StatusCodes.Status400BadRequest); - return; - } - - if (eventStreamReader is null && !session.TryStartGetRequest()) + if (!session.TryStartGetRequest()) { await WriteJsonRpcErrorAsync(context, - "Bad Request: This server does not support multiple GET requests. Use Last-Event-ID header to resume or start a new session.", + "Bad Request: This server does not support multiple GET requests. Start a new session or use Last-Event-ID header to resume.", StatusCodes.Status400BadRequest); return; } @@ -175,7 +167,7 @@ await WriteJsonRpcErrorAsync(context, // will be sent in response to a different POST request. It might be a while before we send a message // over this response body. await context.Response.Body.FlushAsync(cancellationToken); - await session.Transport.HandleGetRequestAsync(context.Response.Body, eventStreamReader, cancellationToken); + await session.Transport.HandleGetRequestAsync(context.Response.Body, cancellationToken); } catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) { @@ -201,6 +193,12 @@ public async Task HandleDeleteRequestAsync(HttpContext context) private async ValueTask GetSessionAsync(HttpContext context, string sessionId) { + if (string.IsNullOrEmpty(sessionId)) + { + await WriteJsonRpcErrorAsync(context, "Bad Request: Mcp-Session-Id header is required", StatusCodes.Status400BadRequest); + return null; + } + if (!sessionManager.TryGetValue(sessionId, out var session)) { // -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does. diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs index e505c35b1..3731102f9 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs @@ -123,29 +123,7 @@ internal async ValueTask HandleInitRequestAsync(InitializeRequestParams? initPar /// /// is and GET requests are not supported in stateless mode. /// - public Task HandleGetRequestAsync(Stream sseResponseStream, CancellationToken cancellationToken = default) - => HandleGetRequestAsync(sseResponseStream, eventStreamReader: null, cancellationToken); - - /// - /// Handles an optional SSE GET request a client using the Streamable HTTP transport might make by - /// writing any unsolicited JSON-RPC messages sent via - /// to the SSE response stream until cancellation is requested or the transport is disposed. - /// - /// The response stream to write MCP JSON-RPC messages as SSE events to. - /// The to replay events from before writing this transport's messages to the response stream. - /// The to monitor for cancellation requests. The default is . - /// A task representing the send loop that writes JSON-RPC messages to the SSE response stream. - /// is . - /// - /// is and GET requests are not supported in stateless mode. - /// - public async Task HandleGetRequestAsync(Stream sseResponseStream, ISseEventStreamReader? eventStreamReader, CancellationToken cancellationToken = default) - { - var writeTask = await StartGetRequestAsync(sseResponseStream, eventStreamReader, cancellationToken).ConfigureAwait(false); - await writeTask.ConfigureAwait(false); - } - - private async Task StartGetRequestAsync(Stream sseResponseStream, ISseEventStreamReader? eventStreamReader, CancellationToken cancellationToken) + public async Task HandleGetRequestAsync(Stream sseResponseStream, CancellationToken cancellationToken = default) { Throw.IfNull(sseResponseStream); @@ -154,41 +132,27 @@ private async Task StartGetRequestAsync(Stream sseResponseStream, ISseEven throw new InvalidOperationException("GET requests are not supported in stateless mode."); } - using var _ = await _sendLock.LockAsync(cancellationToken); - - ThrowIfDisposed(); - - if (_getRequestStarted) + Task writeTask; + using (await _sendLock.LockAsync(cancellationToken)) { - await _sseWriter.DisposeAsync().ConfigureAwait(false); - _sseWriter = new(); - } - - _getRequestStarted = true; - - if (eventStreamReader is not null) - { - if (eventStreamReader.SessionId != SessionId) + if (_getRequestStarted) { - throw new InvalidOperationException("The provided SSE event stream reader relates to a different session."); + throw new InvalidOperationException("Session resumption is not yet supported. Please start a new session."); } - if (eventStreamReader.StreamId != UnsolicitedMessageStreamId) - { - throw new InvalidOperationException("The event stream reader does not relate to the unsolicited message stream."); - } + _getRequestStarted = true; - await eventStreamReader.CopyToAsync(sseResponseStream, cancellationToken); - } + // We do not need to reference _disposeCts like in HandlePostRequest, because the session ending completes the _sseWriter gracefully. + writeTask = _sseWriter.WriteAllAsync(sseResponseStream, cancellationToken); - var eventStreamWriter = await GetOrCreateEventStreamAsync(cancellationToken).ConfigureAwait(false); - if (eventStreamWriter is not null) - { - await _sseWriter.SendPrimingEventAsync(RetryInterval, eventStreamWriter, cancellationToken).ConfigureAwait(false); + var eventStreamWriter = await GetOrCreateEventStreamAsync(cancellationToken).ConfigureAwait(false); + if (eventStreamWriter is not null) + { + await _sseWriter.SendPrimingEventAsync(RetryInterval, eventStreamWriter, cancellationToken).ConfigureAwait(false); + } } - // We do not need to reference _disposeCts like in HandlePostRequest, because the session ending completes the _sseWriter gracefully. - return _sseWriter.WriteAllAsync(sseResponseStream, cancellationToken); + await writeTask.ConfigureAwait(false); } /// @@ -231,8 +195,6 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can using var _ = await _sendLock.LockAsync(cancellationToken); - ThrowIfDisposed(); - // If the underlying writer has been disposed, rely on the event stream writer, if present. // Otherwise, just drop the message. var eventStreamWriter = await GetOrCreateEventStreamAsync(cancellationToken).ConfigureAwait(false); @@ -285,9 +247,9 @@ public async ValueTask DisposeAsync() return null; } - // We set the mode to 'Polling' so that the transport can take over writing to the response stream after - // messages have been replayed. - const SseEventStreamMode Mode = SseEventStreamMode.Polling; + // We use the 'Default' stream mode so that in the case of an unexpected network disconnection, + // the client can continue reading the remaining messages in a single, streamed response. + const SseEventStreamMode Mode = SseEventStreamMode.Default; _eventStreamWriter = await EventStreamStore.CreateStreamAsync(options: new() { @@ -298,16 +260,4 @@ public async ValueTask DisposeAsync() return _eventStreamWriter; } - - private void ThrowIfDisposed() - { -#if NET - ObjectDisposedException.ThrowIf(_disposed, this); -#else - if (_disposed) - { - throw new ObjectDisposedException(GetType().Name); - } -#endif - } } From f6858ab36e4c9078773b1c4c594aa93b3f331f9c Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Thu, 18 Dec 2025 14:34:11 -0800 Subject: [PATCH 10/10] Return 405 on GET in statless mode --- .../McpEndpointRouteBuilderExtensions.cs | 10 ++++++---- .../StreamableHttpHandler.cs | 8 -------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs index c6e04f867..ad87f7a4c 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs @@ -39,12 +39,14 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])) .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted)); - streamableHttpGroup.MapGet("", streamableHttpHandler.HandleGetRequestAsync) - .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])); - if (!streamableHttpHandler.HttpServerTransportOptions.Stateless) { - // The DELETE endpoints are not mapped in Stateless mode since there's no server-side state to clean up. + // The GET endpoint is not mapped in Stateless mode since there's no way to send unsolicited messages. + // Resuming streams via GET is currently not supported in Stateless mode. + streamableHttpGroup.MapGet("", streamableHttpHandler.HandleGetRequestAsync) + .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])); + + // The DELETE endpoint is not mapped in Stateless mode since there is no server-side state for the DELETE to clean up. streamableHttpGroup.MapDelete("", streamableHttpHandler.HandleDeleteRequestAsync); // Map legacy HTTP with SSE endpoints only if not in Stateless mode, because we cannot guarantee the /message requests diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 20cce477b..a47e66b1f 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -136,14 +136,6 @@ await WriteJsonRpcErrorAsync(context, private async Task HandleUnsolicitedMessageStreamAsync(HttpContext context, StreamableHttpSession session) { - if (HttpServerTransportOptions.Stateless) - { - await WriteJsonRpcErrorAsync(context, - "Bad Request: Unsolicited messages are not supported in stateless mode.", - StatusCodes.Status405MethodNotAllowed); - return; - } - if (!session.TryStartGetRequest()) { await WriteJsonRpcErrorAsync(context,