diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs
index 67f4f4e1d..bbbcf6beb 100644
--- a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs
+++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs
@@ -43,6 +43,34 @@ public class HttpServerTransportOptions
///
public bool Stateless { get; set; }
+ ///
+ /// Gets or sets the event store for resumability support.
+ /// When set, events are stored and can be replayed when clients reconnect with a Last-Event-ID header.
+ ///
+ ///
+ /// When configured, the server will:
+ ///
+ /// - Generate unique event IDs for each SSE message
+ /// - Store events for later replay
+ /// - Replay missed events when a client reconnects with a Last-Event-ID header
+ /// - Send priming events to establish resumability before any actual messages
+ ///
+ ///
+ public ISseEventStreamStore? EventStreamStore { get; set; }
+
+ ///
+ /// Gets or sets the retry interval to suggest to clients in SSE retry field.
+ ///
+ ///
+ /// The retry interval. The default is 1 second.
+ ///
+ ///
+ /// When is set, the server will include a retry field in priming events.
+ /// This value suggests to clients how long to wait before attempting to reconnect after a connection is lost.
+ /// Clients may use this value to implement polling behavior during long-running operations.
+ ///
+ public TimeSpan RetryInterval { get; set; } = TimeSpan.FromSeconds(1);
+
///
/// Gets or sets a value that indicates whether the server uses a single execution context for the entire session.
///
diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs
index 8c78d7516..ad87f7a4c 100644
--- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs
+++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs
@@ -41,10 +41,12 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo
if (!streamableHttpHandler.HttpServerTransportOptions.Stateless)
{
- // The GET and DELETE endpoints are not mapped in Stateless mode since there's no way to send unsolicited messages
- // for the GET to handle, and there is no server-side state for the DELETE to clean up.
+ // The GET endpoint is not mapped in Stateless mode since there's no way to send unsolicited messages.
+ // Resuming streams via GET is currently not supported in Stateless mode.
streamableHttpGroup.MapGet("", streamableHttpHandler.HandleGetRequestAsync)
.WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"]));
+
+ // The DELETE endpoint is not mapped in Stateless mode since there is no server-side state for the DELETE to clean up.
streamableHttpGroup.MapDelete("", streamableHttpHandler.HandleDeleteRequestAsync);
// Map legacy HTTP with SSE endpoints only if not in Stateless mode, because we cannot guarantee the /message requests
diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs
index c0f59363a..a47e66b1f 100644
--- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs
+++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs
@@ -23,6 +23,7 @@ internal sealed class StreamableHttpHandler(
ILoggerFactory loggerFactory)
{
private const string McpSessionIdHeaderName = "Mcp-Session-Id";
+ private const string LastEventIdHeaderName = "Last-Event-ID";
private static readonly JsonTypeInfo s_messageTypeInfo = GetRequiredJsonTypeInfo();
private static readonly JsonTypeInfo s_errorTypeInfo = GetRequiredJsonTypeInfo();
@@ -88,10 +89,57 @@ await WriteJsonRpcErrorAsync(context,
return;
}
+ var lastEventId = context.Request.Headers[LastEventIdHeaderName].ToString();
+ if (!string.IsNullOrEmpty(lastEventId))
+ {
+ await HandleResumedStreamAsync(context, session, lastEventId);
+ }
+ else
+ {
+ await HandleUnsolicitedMessageStreamAsync(context, session);
+ }
+ }
+
+ private async Task HandleResumedStreamAsync(HttpContext context, StreamableHttpSession session, string lastEventId)
+ {
+ if (HttpServerTransportOptions.Stateless)
+ {
+ await WriteJsonRpcErrorAsync(context,
+ "Bad Request: The Last-Event-ID header is not supported in stateless mode.",
+ StatusCodes.Status400BadRequest);
+ return;
+ }
+
+ var eventStreamReader = await GetEventStreamReaderAsync(context, lastEventId);
+ if (eventStreamReader is null)
+ {
+ // There was an error obtaining the event stream; consider the request failed.
+ return;
+ }
+
+ if (!string.Equals(session.Id, eventStreamReader.SessionId, StringComparison.Ordinal))
+ {
+ await WriteJsonRpcErrorAsync(context,
+ "Bad Request: The Last-Event-ID header refers to a session with a different session ID.",
+ StatusCodes.Status400BadRequest);
+ return;
+ }
+
+ using var sseCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, hostApplicationLifetime.ApplicationStopping);
+ var cancellationToken = sseCts.Token;
+
+ await using var _ = await session.AcquireReferenceAsync(cancellationToken);
+
+ InitializeSseResponse(context);
+ await eventStreamReader.CopyToAsync(context.Response.Body, context.RequestAborted);
+ }
+
+ private async Task HandleUnsolicitedMessageStreamAsync(HttpContext context, StreamableHttpSession session)
+ {
if (!session.TryStartGetRequest())
{
await WriteJsonRpcErrorAsync(context,
- "Bad Request: This server does not support multiple GET requests. Start a new session to get a new GET SSE response.",
+ "Bad Request: This server does not support multiple GET requests. Start a new session or use Last-Event-ID header to resume.",
StatusCodes.Status400BadRequest);
return;
}
@@ -120,6 +168,12 @@ await WriteJsonRpcErrorAsync(context,
}
}
+ private static async Task HandleResumePostResponseStreamAsync(HttpContext context, ISseEventStreamReader eventStreamReader)
+ {
+ InitializeSseResponse(context);
+ await eventStreamReader.CopyToAsync(context.Response.Body, context.RequestAborted);
+ }
+
public async Task HandleDeleteRequestAsync(HttpContext context)
{
var sessionId = context.Request.Headers[McpSessionIdHeaderName].ToString();
@@ -131,14 +185,13 @@ public async Task HandleDeleteRequestAsync(HttpContext context)
private async ValueTask GetSessionAsync(HttpContext context, string sessionId)
{
- StreamableHttpSession? session;
-
if (string.IsNullOrEmpty(sessionId))
{
await WriteJsonRpcErrorAsync(context, "Bad Request: Mcp-Session-Id header is required", StatusCodes.Status400BadRequest);
return null;
}
- else if (!sessionManager.TryGetValue(sessionId, out session))
+
+ if (!sessionManager.TryGetValue(sessionId, out var session))
{
// -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does.
// One of the few other usages I found was from some Ethereum JSON-RPC documentation and this
@@ -194,12 +247,16 @@ private async ValueTask StartNewSessionAsync(HttpContext
{
SessionId = sessionId,
FlowExecutionContextFromRequests = !HttpServerTransportOptions.PerSessionExecutionContext,
+ EventStreamStore = HttpServerTransportOptions.EventStreamStore,
+ RetryInterval = HttpServerTransportOptions.RetryInterval,
};
context.Response.Headers[McpSessionIdHeaderName] = sessionId;
}
else
{
// In stateless mode, each request is independent. Don't set any session ID on the transport.
+ // If in the future we support resuming stateless requests, we should populate
+ // the event stream store and retry interval here as well.
sessionId = "";
transport = new()
{
@@ -246,6 +303,28 @@ private async ValueTask CreateSessionAsync(
return session;
}
+ private async ValueTask GetEventStreamReaderAsync(HttpContext context, string lastEventId)
+ {
+ if (HttpServerTransportOptions.EventStreamStore is not { } eventStreamStore)
+ {
+ await WriteJsonRpcErrorAsync(context,
+ "Bad Request: This server does not support resuming streams.",
+ StatusCodes.Status400BadRequest);
+ return null;
+ }
+
+ var eventStreamReader = await eventStreamStore.GetStreamReaderAsync(lastEventId, context.RequestAborted);
+ if (eventStreamReader is null)
+ {
+ await WriteJsonRpcErrorAsync(context,
+ "Bad Request: The specified Last-Event-ID is either invalid or expired.",
+ StatusCodes.Status400BadRequest);
+ return null;
+ }
+
+ return eventStreamReader;
+ }
+
private static Task WriteJsonRpcErrorAsync(HttpContext context, string errorMessage, int statusCode, int errorCode = -32000)
{
var jsonRpcError = new JsonRpcError
diff --git a/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs b/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs
index 43b6ef30d..c6a6e546d 100644
--- a/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs
+++ b/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs
@@ -108,4 +108,17 @@ public required Uri Endpoint
/// Gets sor sets the authorization provider to use for authentication.
///
public ClientOAuthOptions? OAuth { get; set; }
+
+ ///
+ /// Gets or sets the maximum number of consecutive reconnection attempts when an SSE stream is disconnected.
+ ///
+ ///
+ /// The maximum number of reconnection attempts. The default is 2.
+ ///
+ ///
+ /// When an SSE stream is disconnected (e.g., due to a network issue), the client will attempt to
+ /// reconnect using the Last-Event-ID header to resume from where it left off. This property controls
+ /// how many reconnection attempts are made before giving up.
+ ///
+ public int MaxReconnectionAttempts { get; set; } = 2;
}
diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs
index 534249038..e36d3e966 100644
--- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs
+++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs
@@ -16,6 +16,8 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa
private static readonly MediaTypeWithQualityHeaderValue s_applicationJsonMediaType = new("application/json");
private static readonly MediaTypeWithQualityHeaderValue s_textEventStreamMediaType = new("text/event-stream");
+ private static readonly TimeSpan s_defaultReconnectionDelay = TimeSpan.FromSeconds(1);
+
private readonly McpHttpClient _httpClient;
private readonly HttpClientTransportOptions _options;
private readonly CancellationTokenSource _connectionCts;
@@ -105,8 +107,18 @@ internal async Task SendHttpRequestAsync(JsonRpcMessage mes
}
else if (response.Content.Headers.ContentType?.MediaType == "text/event-stream")
{
- using var responseBodyStream = await response.Content.ReadAsStreamAsync(cancellationToken);
- rpcResponseOrError = await ProcessSseResponseAsync(responseBodyStream, rpcRequest, cancellationToken).ConfigureAwait(false);
+ var sseState = new SseStreamState();
+ using var responseBodyStream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
+ var sseResponse = await ProcessSseResponseAsync(responseBodyStream, rpcRequest, sseState, cancellationToken).ConfigureAwait(false);
+ rpcResponseOrError = sseResponse.Response;
+
+ // Resumability: If POST SSE stream ended without a response but we have a Last-Event-ID (from priming),
+ // attempt to resume by sending a GET request with Last-Event-ID header. The server will replay
+ // events from the event store, allowing us to receive the pending response.
+ if (rpcResponseOrError is null && rpcRequest is not null && sseState.LastEventId is not null)
+ {
+ rpcResponseOrError = await SendGetSseRequestWithRetriesAsync(rpcRequest, sseState, cancellationToken).ConfigureAwait(false);
+ }
}
if (rpcRequest is null)
@@ -188,56 +200,140 @@ public override async ValueTask DisposeAsync()
private async Task ReceiveUnsolicitedMessagesAsync()
{
- // Send a GET request to handle any unsolicited messages not sent over a POST response.
- using var request = new HttpRequestMessage(HttpMethod.Get, _options.Endpoint);
- request.Headers.Accept.Add(s_textEventStreamMediaType);
- CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion);
+ var state = new SseStreamState();
- // Server support for the GET request is optional. If it fails, we don't care. It just means we won't receive unsolicited messages.
- HttpResponseMessage response;
- try
+ // Continuously receive unsolicited messages until canceled
+ while (!_connectionCts.Token.IsCancellationRequested)
{
- response = await _httpClient.SendAsync(request, message: null, _connectionCts.Token).ConfigureAwait(false);
- }
- catch (HttpRequestException)
- {
- return;
- }
+ await SendGetSseRequestWithRetriesAsync(
+ relatedRpcRequest: null,
+ state,
+ _connectionCts.Token).ConfigureAwait(false);
- using (response)
- {
- if (!response.IsSuccessStatusCode)
+ // If we exhausted retries without receiving any events, stop trying
+ if (state.LastEventId is null)
{
return;
}
-
- using var responseStream = await response.Content.ReadAsStreamAsync(_connectionCts.Token).ConfigureAwait(false);
- await ProcessSseResponseAsync(responseStream, relatedRpcRequest: null, _connectionCts.Token).ConfigureAwait(false);
}
}
- private async Task ProcessSseResponseAsync(Stream responseStream, JsonRpcRequest? relatedRpcRequest, CancellationToken cancellationToken)
+ ///
+ /// Sends a GET request for SSE with retry logic and resumability support.
+ ///
+ private async Task SendGetSseRequestWithRetriesAsync(
+ JsonRpcRequest? relatedRpcRequest,
+ SseStreamState state,
+ CancellationToken cancellationToken)
{
- await foreach (SseItem sseEvent in SseParser.Create(responseStream).EnumerateAsync(cancellationToken).ConfigureAwait(false))
+ int attempt = 0;
+
+ // Delay before first attempt if we're reconnecting (have a Last-Event-ID)
+ bool shouldDelay = state.LastEventId is not null;
+
+ while (attempt < _options.MaxReconnectionAttempts)
{
- if (sseEvent.EventType != "message")
+ cancellationToken.ThrowIfCancellationRequested();
+
+ if (shouldDelay)
{
- continue;
+ var delay = state.RetryInterval ?? s_defaultReconnectionDelay;
+ await Task.Delay(delay, cancellationToken).ConfigureAwait(false);
}
+ shouldDelay = true;
- var rpcResponseOrError = await ProcessMessageAsync(sseEvent.Data, relatedRpcRequest, cancellationToken).ConfigureAwait(false);
+ using var request = new HttpRequestMessage(HttpMethod.Get, _options.Endpoint);
+ request.Headers.Accept.Add(s_textEventStreamMediaType);
+ CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion, state.LastEventId);
- // The server SHOULD end the HTTP response body here anyway, but we won't leave it to chance. This transport makes
- // a GET request for any notifications that might need to be sent after the completion of each POST.
- if (rpcResponseOrError is not null)
+ HttpResponseMessage response;
+ try
{
- return rpcResponseOrError;
+ response = await _httpClient.SendAsync(request, message: null, cancellationToken).ConfigureAwait(false);
+ }
+ catch (HttpRequestException)
+ {
+ attempt++;
+ continue;
+ }
+
+ using (response)
+ {
+ if (!response.IsSuccessStatusCode)
+ {
+ // If the server could be reached but returned a non-success status code,
+ // retrying likely won't change that.
+ return null;
+ }
+
+ using var responseStream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
+ var sseResponse = await ProcessSseResponseAsync(responseStream, relatedRpcRequest, state, cancellationToken).ConfigureAwait(false);
+
+ if (sseResponse.Response is { } rpcResponseOrError)
+ {
+ return rpcResponseOrError;
+ }
+
+ // If we reach here, then the stream closed without the response.
+
+ if (sseResponse.IsNetworkError || state.LastEventId is null)
+ {
+ // No event ID means server may not support resumability; don't retry indefinitely.
+ attempt++;
+ }
+ else
+ {
+ // We have an event ID, so we continue polling to receive more events.
+ // The server should eventually send a response or return an error.
+ attempt = 0;
+ }
}
}
return null;
}
+ private async Task ProcessSseResponseAsync(
+ Stream responseStream,
+ JsonRpcRequest? relatedRpcRequest,
+ SseStreamState state,
+ CancellationToken cancellationToken)
+ {
+ try
+ {
+ await foreach (SseItem sseEvent in SseParser.Create(responseStream).EnumerateAsync(cancellationToken).ConfigureAwait(false))
+ {
+ // Track event ID and retry interval for resumability
+ if (!string.IsNullOrEmpty(sseEvent.EventId))
+ {
+ state.LastEventId = sseEvent.EventId;
+ }
+ if (sseEvent.ReconnectionInterval.HasValue)
+ {
+ state.RetryInterval = sseEvent.ReconnectionInterval.Value;
+ }
+
+ // Skip events with empty data
+ if (string.IsNullOrEmpty(sseEvent.Data))
+ {
+ continue;
+ }
+
+ var rpcResponseOrError = await ProcessMessageAsync(sseEvent.Data, relatedRpcRequest, cancellationToken).ConfigureAwait(false);
+ if (rpcResponseOrError is not null)
+ {
+ return new() { Response = rpcResponseOrError };
+ }
+ }
+ }
+ catch (Exception ex) when (ex is IOException or HttpRequestException)
+ {
+ return new() { IsNetworkError = true };
+ }
+
+ return default;
+ }
+
private async Task ProcessMessageAsync(string data, JsonRpcRequest? relatedRpcRequest, CancellationToken cancellationToken)
{
LogTransportReceivedMessageSensitive(Name, data);
@@ -292,7 +388,8 @@ internal static void CopyAdditionalHeaders(
HttpRequestHeaders headers,
IDictionary? additionalHeaders,
string? sessionId,
- string? protocolVersion)
+ string? protocolVersion,
+ string? lastEventId = null)
{
if (sessionId is not null)
{
@@ -304,6 +401,11 @@ internal static void CopyAdditionalHeaders(
headers.Add("MCP-Protocol-Version", protocolVersion);
}
+ if (lastEventId is not null)
+ {
+ headers.Add("Last-Event-ID", lastEventId);
+ }
+
if (additionalHeaders is null)
{
return;
@@ -317,4 +419,22 @@ internal static void CopyAdditionalHeaders(
}
}
}
+
+ ///
+ /// Tracks state across SSE stream connections.
+ ///
+ private sealed class SseStreamState
+ {
+ public string? LastEventId { get; set; }
+ public TimeSpan? RetryInterval { get; set; }
+ }
+
+ ///
+ /// Represents the result of processing an SSE response.
+ ///
+ private readonly struct SseResponse
+ {
+ public JsonRpcMessageWithId? Response { get; init; }
+ public bool IsNetworkError { get; init; }
+ }
}
diff --git a/src/ModelContextProtocol.Core/McpSessionHandler.cs b/src/ModelContextProtocol.Core/McpSessionHandler.cs
index 0b915c9f1..1dcb11f65 100644
--- a/src/ModelContextProtocol.Core/McpSessionHandler.cs
+++ b/src/ModelContextProtocol.Core/McpSessionHandler.cs
@@ -29,16 +29,32 @@ internal sealed partial class McpSessionHandler : IAsyncDisposable
"mcp.server.operation.duration", "Measures the duration of inbound message processing.", longBuckets: false);
/// The latest version of the protocol supported by this implementation.
- internal const string LatestProtocolVersion = "2025-06-18";
+ internal const string LatestProtocolVersion = "2025-11-25";
/// All protocol versions supported by this implementation.
internal static readonly string[] SupportedProtocolVersions =
[
"2024-11-05",
"2025-03-26",
+ "2025-06-18",
LatestProtocolVersion,
];
+ ///
+ /// Checks if the given protocol version supports priming events.
+ ///
+ /// The protocol version to check.
+ /// True if the protocol version supports resumability.
+ ///
+ /// Priming events are only supported in protocol version >= 2025-11-25.
+ /// Older clients may crash when receiving SSE events with empty data.
+ ///
+ internal static bool SupportsPrimingEvent(string? protocolVersion)
+ {
+ const string MinResumabilityProtocolVersion = "2025-11-25";
+ return string.Compare(protocolVersion, MinResumabilityProtocolVersion, StringComparison.Ordinal) >= 0;
+ }
+
private readonly bool _isServer;
private readonly string _transportKind;
private readonly ITransport _transport;
diff --git a/src/ModelContextProtocol.Core/Server/ISseEventStreamReader.cs b/src/ModelContextProtocol.Core/Server/ISseEventStreamReader.cs
new file mode 100644
index 000000000..01c642355
--- /dev/null
+++ b/src/ModelContextProtocol.Core/Server/ISseEventStreamReader.cs
@@ -0,0 +1,36 @@
+using ModelContextProtocol.Protocol;
+using System.Net.ServerSentEvents;
+
+namespace ModelContextProtocol.Server;
+
+///
+/// Provides read access to an SSE event stream, allowing events to be consumed asynchronously.
+///
+public interface ISseEventStreamReader
+{
+ ///
+ /// Gets the session ID associated with the stream being read.
+ ///
+ string SessionId { get; }
+
+ ///
+ /// Gets the ID of the stream.
+ ///
+ ///
+ /// This value is guaranteed to be unique on a per-session basis.
+ ///
+ string StreamId { get; }
+
+ ///
+ /// Gets the messages from the stream as an .
+ ///
+ /// A token to cancel the operation.
+ /// An of containing JSON-RPC messages.
+ ///
+ /// If the stream's mode is set to , the returned
+ /// messages will only include the currently-available events starting at the last event ID specified
+ /// when the reader was created. Otherwise, the returned messages will continue until the associated
+ /// is disposed.
+ ///
+ IAsyncEnumerable> ReadEventsAsync(CancellationToken cancellationToken = default);
+}
diff --git a/src/ModelContextProtocol.Core/Server/ISseEventStreamStore.cs b/src/ModelContextProtocol.Core/Server/ISseEventStreamStore.cs
new file mode 100644
index 000000000..723afab07
--- /dev/null
+++ b/src/ModelContextProtocol.Core/Server/ISseEventStreamStore.cs
@@ -0,0 +1,23 @@
+namespace ModelContextProtocol.Server;
+
+///
+/// Provides storage and retrieval of SSE event streams, enabling resumability and redelivery of events.
+///
+public interface ISseEventStreamStore
+{
+ ///
+ /// Creates a new SSE event stream with the specified options.
+ ///
+ /// The configuration options for the new stream.
+ /// A token to cancel the operation.
+ /// A writer for the newly created event stream.
+ ValueTask CreateStreamAsync(SseEventStreamOptions options, CancellationToken cancellationToken = default);
+
+ ///
+ /// Gets a reader for an existing event stream based on the last event ID.
+ ///
+ /// The ID of the last event received by the client, used to resume from that point.
+ /// A token to cancel the operation.
+ /// A reader for the event stream, or null if no matching stream is found.
+ ValueTask GetStreamReaderAsync(string lastEventId, CancellationToken cancellationToken);
+}
diff --git a/src/ModelContextProtocol.Core/Server/ISseEventStreamWriter.cs b/src/ModelContextProtocol.Core/Server/ISseEventStreamWriter.cs
new file mode 100644
index 000000000..8cb99af90
--- /dev/null
+++ b/src/ModelContextProtocol.Core/Server/ISseEventStreamWriter.cs
@@ -0,0 +1,43 @@
+using ModelContextProtocol.Protocol;
+using System.Net.ServerSentEvents;
+
+namespace ModelContextProtocol.Server;
+
+///
+/// Provides write access to an SSE event stream, allowing events to be written and tracked with unique IDs.
+///
+public interface ISseEventStreamWriter : IAsyncDisposable
+{
+ ///
+ /// Gets the ID of the stream.
+ ///
+ ///
+ /// This value is guaranteed to be unique on a per-session basis.
+ ///
+ string StreamId { get; }
+
+ ///
+ /// Gets the current mode of the event stream.
+ ///
+ SseEventStreamMode Mode { get; }
+
+ ///
+ /// Sets the mode of the event stream.
+ ///
+ /// The new mode to set for the event stream.
+ /// A token to cancel the operation.
+ /// A task that represents the asynchronous operation.
+ ValueTask SetModeAsync(SseEventStreamMode mode, CancellationToken cancellationToken = default);
+
+ ///
+ /// Writes an event to the stream.
+ ///
+ /// The original .
+ /// A token to cancel the operation.
+ /// A new with a populated event ID.
+ ///
+ /// If the provided already has an event ID, this method skips writing the event.
+ /// Otherwise, an event ID unique to all sessions and streams is generated and assigned to the event.
+ ///
+ ValueTask> WriteEventAsync(SseItem sseItem, CancellationToken cancellationToken = default);
+}
diff --git a/src/ModelContextProtocol.Core/Server/RequestContext.cs b/src/ModelContextProtocol.Core/Server/RequestContext.cs
index a8d1f66c9..a5ef73840 100644
--- a/src/ModelContextProtocol.Core/Server/RequestContext.cs
+++ b/src/ModelContextProtocol.Core/Server/RequestContext.cs
@@ -82,4 +82,21 @@ public McpServer Server
/// including the method name, parameters, request ID, and associated transport and user information.
///
public JsonRpcRequest JsonRpcRequest { get; }
+
+ ///
+ /// Ends the current response and enables polling for updates from the server.
+ ///
+ /// The interval at which the client should poll for updates.
+ /// The cancellation token.
+ /// A that completes when polling has been enabled.
+ public async ValueTask EnablePollingAsync(TimeSpan retryInterval, CancellationToken cancellationToken = default)
+ {
+ if (JsonRpcRequest.Context?.RelatedTransport is not StreamableHttpPostTransport transport)
+ {
+ // Polling is only supported for Streamable HTTP POST transports.
+ return;
+ }
+
+ await transport.EnablePollingAsync(retryInterval, cancellationToken);
+ }
}
diff --git a/src/ModelContextProtocol.Core/Server/SseEventStreamMode.cs b/src/ModelContextProtocol.Core/Server/SseEventStreamMode.cs
new file mode 100644
index 000000000..cd5c7f4bf
--- /dev/null
+++ b/src/ModelContextProtocol.Core/Server/SseEventStreamMode.cs
@@ -0,0 +1,20 @@
+namespace ModelContextProtocol.Server;
+
+///
+/// Represents the mode of an SSE event stream.
+///
+public enum SseEventStreamMode
+{
+ ///
+ /// Causes the event stream returned by to only end when
+ /// the associated gets disposed.
+ ///
+ Default = 0,
+
+ ///
+ /// Causes the event stream returned by to end
+ /// after the most recent event has been consumed. This forces clients to keep making new requests in order to receive
+ /// the latest messages.
+ ///
+ Polling = 1,
+}
diff --git a/src/ModelContextProtocol.Core/Server/SseEventStreamOptions.cs b/src/ModelContextProtocol.Core/Server/SseEventStreamOptions.cs
new file mode 100644
index 000000000..6eca0e0b8
--- /dev/null
+++ b/src/ModelContextProtocol.Core/Server/SseEventStreamOptions.cs
@@ -0,0 +1,22 @@
+namespace ModelContextProtocol.Server;
+
+///
+/// Configuration options for creating an SSE event stream.
+///
+public sealed class SseEventStreamOptions
+{
+ ///
+ /// Gets or sets the session ID associated with the event stream.
+ ///
+ public required string SessionId { get; set; }
+
+ ///
+ /// Gets or sets the stream ID that uniquely identifies this stream within a session.
+ ///
+ public required string StreamId { get; set; }
+
+ ///
+ /// Gets or sets the mode of the event stream. Defaults to .
+ ///
+ public SseEventStreamMode Mode { get; set; }
+}
diff --git a/src/ModelContextProtocol.Core/Server/SseEventStreamReaderExtensions.cs b/src/ModelContextProtocol.Core/Server/SseEventStreamReaderExtensions.cs
new file mode 100644
index 000000000..ff15ba98c
--- /dev/null
+++ b/src/ModelContextProtocol.Core/Server/SseEventStreamReaderExtensions.cs
@@ -0,0 +1,50 @@
+using ModelContextProtocol.Protocol;
+using System.Buffers;
+using System.Net.ServerSentEvents;
+using System.Text.Json;
+
+namespace ModelContextProtocol.Server;
+
+///
+/// Provides extension methods for .
+///
+public static class SseEventStreamReaderExtensions
+{
+ ///
+ /// Copies all events from the reader to the destination stream in SSE format.
+ ///
+ /// The event stream reader to copy events from.
+ /// The destination stream to write SSE-formatted events to.
+ /// A token to cancel the operation.
+ /// A task that represents the asynchronous copy operation.
+ /// Thrown when or is null.
+ public static async Task CopyToAsync(this ISseEventStreamReader reader, Stream destination, CancellationToken cancellationToken = default)
+ {
+ Throw.IfNull(reader);
+ Throw.IfNull(destination);
+
+ Utf8JsonWriter? jsonWriter = null;
+
+ var events = reader.ReadEventsAsync(cancellationToken);
+ await SseFormatter.WriteAsync(events, destination, FormatEvent, cancellationToken);
+
+ void FormatEvent(SseItem item, IBufferWriter writer)
+ {
+ if (item.Data is null)
+ {
+ return;
+ }
+
+ if (jsonWriter is null)
+ {
+ jsonWriter = new Utf8JsonWriter(writer);
+ }
+ else
+ {
+ jsonWriter.Reset(writer);
+ }
+
+ JsonSerializer.Serialize(jsonWriter, item.Data, McpJsonUtilities.JsonContext.Default.JsonRpcMessage!);
+ }
+ }
+}
diff --git a/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs b/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs
index afdf29943..c537917ac 100644
--- a/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs
+++ b/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs
@@ -67,7 +67,7 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can
{
Throw.IfNull(message);
// If the underlying writer has been disposed, just drop the message.
- await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
+ await _sseWriter.SendMessageAsync(message, eventStreamWriter: null, cancellationToken).ConfigureAwait(false);
}
///
diff --git a/src/ModelContextProtocol.Core/Server/SseWriter.cs b/src/ModelContextProtocol.Core/Server/SseWriter.cs
index a2314e623..4d9803788 100644
--- a/src/ModelContextProtocol.Core/Server/SseWriter.cs
+++ b/src/ModelContextProtocol.Core/Server/SseWriter.cs
@@ -47,12 +47,36 @@ public Task WriteAllAsync(Stream sseResponseStream, CancellationToken cancellati
return _writeTask;
}
- public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
+ public Task SendPrimingEventAsync(TimeSpan retryInterval, ISseEventStreamWriter eventStreamWriter, CancellationToken cancellationToken = default)
+ {
+ // Create a priming event: empty data with an event ID
+ var primingItem = new SseItem(null, "prime")
+ {
+ ReconnectionInterval = retryInterval,
+ };
+
+ return SendMessageAsync(primingItem, eventStreamWriter, cancellationToken);
+ }
+
+ public Task SendMessageAsync(JsonRpcMessage message, ISseEventStreamWriter? eventStreamWriter, CancellationToken cancellationToken = default)
{
Throw.IfNull(message);
+ // Emit redundant "event: message" lines for better compatibility with other SDKs.
+ return SendMessageAsync(new SseItem(message, SseParser.EventTypeDefault), eventStreamWriter, cancellationToken);
+ }
+
+ private async Task SendMessageAsync(SseItem item, ISseEventStreamWriter? eventStreamWriter, CancellationToken cancellationToken = default)
+ {
using var _ = await _disposeLock.LockAsync(cancellationToken).ConfigureAwait(false);
+ if (eventStreamWriter is not null && item.EventId is null)
+ {
+ // Store the event first, even if the underlying writer has completed, so that
+ // messages can still be retrieved from the event store.
+ item = await eventStreamWriter.WriteEventAsync(item, cancellationToken: cancellationToken).ConfigureAwait(false);
+ }
+
if (_disposed)
{
// Don't throw ObjectDisposedException here; just return false to indicate the message wasn't sent.
@@ -60,8 +84,7 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationTok
return false;
}
- // Emit redundant "event: message" lines for better compatibility with other SDKs.
- await _messages.Writer.WriteAsync(new SseItem(message, SseParser.EventTypeDefault), cancellationToken).ConfigureAwait(false);
+ await _messages.Writer.WriteAsync(item, cancellationToken).ConfigureAwait(false);
return true;
}
@@ -101,7 +124,10 @@ private void WriteJsonRpcMessageToBuffer(SseItem item, IBufferW
return;
}
- JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.JsonContext.Default.JsonRpcMessage!);
+ if (item.Data is not null)
+ {
+ JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.JsonContext.Default.JsonRpcMessage!);
+ }
}
private Utf8JsonWriter GetUtf8JsonWriter(IBufferWriter writer)
@@ -118,3 +144,4 @@ private Utf8JsonWriter GetUtf8JsonWriter(IBufferWriter writer)
return _jsonWriter;
}
}
+
diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs
index 1109c2b2b..1e8fc0901 100644
--- a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs
+++ b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs
@@ -1,9 +1,7 @@
using ModelContextProtocol.Protocol;
using System.Diagnostics;
-using System.IO.Pipelines;
using System.Net.ServerSentEvents;
using System.Runtime.CompilerServices;
-using System.Security.Claims;
using System.Text.Json;
using System.Threading.Channels;
@@ -17,13 +15,15 @@ internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport
{
private readonly SseWriter _sseWriter = new();
private RequestId _pendingRequest;
+ private ISseEventStreamWriter? _eventStreamWriter;
+ private SemaphoreSlim _eventStreamLock = new(1, 1);
public ChannelReader MessageReader => throw new NotSupportedException("JsonRpcMessage.Context.RelatedTransport should only be used for sending messages.");
string? ITransport.SessionId => parentTransport.SessionId;
///
- /// True, if data was written to the respond body.
+ /// True, if data was written to the response body.
/// False, if nothing was written because the request body did not contain any messages to respond to.
/// The HTTP application should typically respond with an empty "202 Accepted" response in this scenario.
///
@@ -35,11 +35,11 @@ public async ValueTask HandlePostAsync(JsonRpcMessage message, Cancellatio
{
_pendingRequest = request.Id;
- // Invoke the initialize request callback if applicable.
- if (parentTransport.OnInitRequestReceived is { } onInitRequest && request.Method == RequestMethods.Initialize)
+ // Invoke the initialize request handler if applicable.
+ if (request.Method == RequestMethods.Initialize)
{
var initializeRequest = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams);
- await onInitRequest(initializeRequest).ConfigureAwait(false);
+ await parentTransport.HandleInitRequestAsync(initializeRequest).ConfigureAwait(false);
}
}
@@ -58,8 +58,18 @@ public async ValueTask HandlePostAsync(JsonRpcMessage message, Cancellatio
return false;
}
+ // Start the write task immediately so that we don't risk filling up the channel with
+ // messages before they start being consumed.
_sseWriter.MessageFilter = StopOnFinalResponseFilter;
- await _sseWriter.WriteAllAsync(responseStream, cancellationToken).ConfigureAwait(false);
+ var writeTask = _sseWriter.WriteAllAsync(responseStream, cancellationToken);
+
+ var eventStreamWriter = await GetOrCreateEventStreamAsync(cancellationToken).ConfigureAwait(false);
+ if (eventStreamWriter is not null)
+ {
+ await _sseWriter.SendPrimingEventAsync(parentTransport.RetryInterval, eventStreamWriter, cancellationToken).ConfigureAwait(false);
+ }
+
+ await writeTask.ConfigureAwait(false);
return true;
}
@@ -72,18 +82,55 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can
throw new InvalidOperationException("Server to client requests are not supported in stateless mode.");
}
- bool isAccepted = await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
- if (!isAccepted)
+ var eventStreamWriter = await GetOrCreateEventStreamAsync(cancellationToken).ConfigureAwait(false);
+
+ var isAccepted = await _sseWriter.SendMessageAsync(message, eventStreamWriter, cancellationToken).ConfigureAwait(false);
+ if (!isAccepted && eventStreamWriter is null)
{
- // The underlying writer didn't accept the message because the underlying request has completed.
+ // The underlying writer didn't accept the message because the underlying request has completed,
+ // and there isn't a fallback event stream writer.
// Rather than drop the message, fall back to sending it via the parent transport.
await parentTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
}
+
+ if (_eventStreamWriter is not null && IsFinalResponse(message))
+ {
+ // Complete the SSE event stream writer now that all pending requests have been processed.
+ await _eventStreamWriter.DisposeAsync().ConfigureAwait(false);
+ }
+ }
+
+ public async ValueTask EnablePollingAsync(TimeSpan retryInterval, CancellationToken cancellationToken)
+ {
+ if (parentTransport.Stateless)
+ {
+ // Polling is currently not supported in stateless mode.
+ return;
+ }
+
+ var eventStreamWriter = await GetOrCreateEventStreamAsync(cancellationToken).ConfigureAwait(false);
+ if (eventStreamWriter is null)
+ {
+ return;
+ }
+
+ // Set the mode to 'Polling' so that the replay stream ends as soon as all available messages have been sent.
+ // This prevents the client from immediately establishing another long-lived connection.
+ await eventStreamWriter.SetModeAsync(SseEventStreamMode.Polling, cancellationToken).ConfigureAwait(false);
+
+ // Send the priming event with the new retry interval.
+ await _sseWriter.SendPrimingEventAsync(retryInterval, eventStreamWriter, cancellationToken).ConfigureAwait(false);
+
+ // Dispose the writer to close it and force future writes to only apply to the SSE event store.
+ await _sseWriter.DisposeAsync();
}
public async ValueTask DisposeAsync()
{
await _sseWriter.DisposeAsync().ConfigureAwait(false);
+
+ // Don't dispose the event stream writer here, as we may continue to write to the event store
+ // after disposal.
}
private async IAsyncEnumerable> StopOnFinalResponseFilter(IAsyncEnumerable> messages, [EnumeratorCancellation] CancellationToken cancellationToken)
@@ -92,11 +139,43 @@ public async ValueTask DisposeAsync()
{
yield return message;
- if (message.Data is JsonRpcResponse or JsonRpcError && ((JsonRpcMessageWithId)message.Data).Id == _pendingRequest)
+ if (IsFinalResponse(message.Data))
{
// Complete the SSE response stream now that all pending requests have been processed.
break;
}
}
}
+
+ private bool IsFinalResponse(JsonRpcMessage? message)
+ => (message is JsonRpcResponse or JsonRpcError) && ((JsonRpcMessageWithId)message).Id == _pendingRequest;
+
+ private async ValueTask GetOrCreateEventStreamAsync(CancellationToken cancellationToken)
+ {
+ using var _ = await _eventStreamLock.LockAsync(cancellationToken).ConfigureAwait(false);
+
+ if (_eventStreamWriter is not null)
+ {
+ return _eventStreamWriter;
+ }
+
+ if (parentTransport.EventStreamStore is not { } eventStreamStore || _pendingRequest.Id is null || !McpSessionHandler.SupportsPrimingEvent(parentTransport.NegotiatedProtocolVersion))
+ {
+ return null;
+ }
+
+ // We use the 'Default' stream mode so that in the case of an unexpected network disconnection,
+ // the client can continue reading the remaining messages in a single, streamed response.
+ // This may be changed to 'Polling' if the transport is later explicitly switched to polling mode.
+ const SseEventStreamMode Mode = SseEventStreamMode.Default;
+
+ _eventStreamWriter = await eventStreamStore.CreateStreamAsync(options: new()
+ {
+ SessionId = parentTransport.SessionId ?? Guid.NewGuid().ToString("N"),
+ StreamId = _pendingRequest.Id.ToString()!,
+ Mode = Mode,
+ }, cancellationToken).ConfigureAwait(false);
+
+ return _eventStreamWriter;
+ }
}
diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs
index c99b1fa39..3731102f9 100644
--- a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs
+++ b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs
@@ -1,5 +1,4 @@
using ModelContextProtocol.Protocol;
-using System.IO.Pipelines;
using System.Security.Claims;
using System.Threading.Channels;
@@ -21,21 +20,30 @@ namespace ModelContextProtocol.Server;
///
public sealed class StreamableHttpServerTransport : ITransport
{
+ ///
+ /// The stream ID used for unsolicited messages sent via the standalone GET SSE stream.
+ ///
+ public static readonly string UnsolicitedMessageStreamId = "__get__";
+
// For JsonRpcMessages without a RelatedTransport, we don't want to block just because the client didn't make a GET request to handle unsolicited messages.
- private readonly SseWriter _sseWriter = new(channelOptions: new BoundedChannelOptions(1)
+ private static readonly BoundedChannelOptions _sseWriterChannelOptions = new(1)
{
SingleReader = true,
SingleWriter = false,
FullMode = BoundedChannelFullMode.DropOldest,
- });
+ };
private readonly Channel _incomingChannel = Channel.CreateBounded(new BoundedChannelOptions(1)
{
SingleReader = true,
SingleWriter = false,
});
private readonly CancellationTokenSource _disposeCts = new();
+ private readonly SemaphoreSlim _sendLock = new(1, 1);
- private int _getRequestStarted;
+ private SseWriter _sseWriter = new(channelOptions: _sseWriterChannelOptions);
+ private ISseEventStreamWriter? _eventStreamWriter;
+ private bool _getRequestStarted;
+ private bool _disposed;
///
public string? SessionId { get; set; }
@@ -63,11 +71,46 @@ public sealed class StreamableHttpServerTransport : ITransport
///
public Func? OnInitRequestReceived { get; set; }
+ ///
+ /// Gets or sets the event store for resumability support.
+ /// When set, events are stored and can be replayed when clients reconnect with a Last-Event-ID header.
+ ///
+ public ISseEventStreamStore? EventStreamStore { get; set; }
+
+ ///
+ /// Gets or sets the retry interval to suggest to clients in SSE retry field.
+ /// When is set, the server will include a retry field in priming events.
+ ///
+ ///
+ /// The default value is 1 second.
+ ///
+ public TimeSpan RetryInterval { get; set; } = TimeSpan.FromSeconds(1);
+
+ ///
+ /// Gets or sets the negotiated protocol version for this session.
+ ///
+ internal string? NegotiatedProtocolVersion { get; private set; }
+
///
public ChannelReader MessageReader => _incomingChannel.Reader;
internal ChannelWriter MessageWriter => _incomingChannel.Writer;
+ ///
+ /// Handles the initialize request by capturing the protocol version and invoking the user callback.
+ ///
+ internal async ValueTask HandleInitRequestAsync(InitializeRequestParams? initParams)
+ {
+ // Capture the negotiated protocol version for resumability checks
+ NegotiatedProtocolVersion = initParams?.ProtocolVersion;
+
+ // Invoke user-provided callback if specified
+ if (OnInitRequestReceived is { } callback)
+ {
+ await callback(initParams).ConfigureAwait(false);
+ }
+ }
+
///
/// Handles an optional SSE GET request a client using the Streamable HTTP transport might make by
/// writing any unsolicited JSON-RPC messages sent via
@@ -78,8 +121,7 @@ public sealed class StreamableHttpServerTransport : ITransport
/// A task representing the send loop that writes JSON-RPC messages to the SSE response stream.
/// is .
///
- /// is and GET requests are not supported in stateless mode,
- /// or a GET request has already been started for this session.
+ /// is and GET requests are not supported in stateless mode.
///
public async Task HandleGetRequestAsync(Stream sseResponseStream, CancellationToken cancellationToken = default)
{
@@ -90,13 +132,27 @@ public async Task HandleGetRequestAsync(Stream sseResponseStream, CancellationTo
throw new InvalidOperationException("GET requests are not supported in stateless mode.");
}
- if (Interlocked.Exchange(ref _getRequestStarted, 1) == 1)
+ Task writeTask;
+ using (await _sendLock.LockAsync(cancellationToken))
{
- throw new InvalidOperationException("Session resumption is not yet supported. Please start a new session.");
+ if (_getRequestStarted)
+ {
+ throw new InvalidOperationException("Session resumption is not yet supported. Please start a new session.");
+ }
+
+ _getRequestStarted = true;
+
+ // We do not need to reference _disposeCts like in HandlePostRequest, because the session ending completes the _sseWriter gracefully.
+ writeTask = _sseWriter.WriteAllAsync(sseResponseStream, cancellationToken);
+
+ var eventStreamWriter = await GetOrCreateEventStreamAsync(cancellationToken).ConfigureAwait(false);
+ if (eventStreamWriter is not null)
+ {
+ await _sseWriter.SendPrimingEventAsync(RetryInterval, eventStreamWriter, cancellationToken).ConfigureAwait(false);
+ }
}
- // We do not need to reference _disposeCts like in HandlePostRequest, because the session ending completes the _sseWriter gracefully.
- await _sseWriter.WriteAllAsync(sseResponseStream, cancellationToken).ConfigureAwait(false);
+ await writeTask.ConfigureAwait(false);
}
///
@@ -137,13 +193,24 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can
throw new InvalidOperationException("Unsolicited server to client messages are not supported in stateless mode.");
}
- // If the underlying writer has been disposed, just drop the message.
- await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
+ using var _ = await _sendLock.LockAsync(cancellationToken);
+
+ // If the underlying writer has been disposed, rely on the event stream writer, if present.
+ // Otherwise, just drop the message.
+ var eventStreamWriter = await GetOrCreateEventStreamAsync(cancellationToken).ConfigureAwait(false);
+ await _sseWriter.SendMessageAsync(message, eventStreamWriter, cancellationToken).ConfigureAwait(false);
}
///
public async ValueTask DisposeAsync()
{
+ using var _ = await _sendLock.LockAsync();
+
+ if (_disposed)
+ {
+ return;
+ }
+
try
{
_incomingChannel.Writer.TryComplete();
@@ -154,11 +221,43 @@ public async ValueTask DisposeAsync()
try
{
await _sseWriter.DisposeAsync().ConfigureAwait(false);
+
+ if (_eventStreamWriter is not null)
+ {
+ await _eventStreamWriter.DisposeAsync().ConfigureAwait(false);
+ }
}
finally
{
_disposeCts.Dispose();
+ _disposed = true;
}
}
}
+
+ private async ValueTask GetOrCreateEventStreamAsync(CancellationToken cancellationToken)
+ {
+ if (_eventStreamWriter is not null)
+ {
+ return _eventStreamWriter;
+ }
+
+ if (EventStreamStore is null || !McpSessionHandler.SupportsPrimingEvent(NegotiatedProtocolVersion))
+ {
+ return null;
+ }
+
+ // We use the 'Default' stream mode so that in the case of an unexpected network disconnection,
+ // the client can continue reading the remaining messages in a single, streamed response.
+ const SseEventStreamMode Mode = SseEventStreamMode.Default;
+
+ _eventStreamWriter = await EventStreamStore.CreateStreamAsync(options: new()
+ {
+ SessionId = SessionId ?? Guid.NewGuid().ToString("N"),
+ StreamId = UnsolicitedMessageStreamId,
+ Mode = Mode,
+ }, cancellationToken).ConfigureAwait(false);
+
+ return _eventStreamWriter;
+ }
}
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs
new file mode 100644
index 000000000..2f6b31779
--- /dev/null
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs
@@ -0,0 +1,459 @@
+using System.ComponentModel;
+using System.Diagnostics;
+using System.Net;
+using System.Net.ServerSentEvents;
+using System.Text;
+using Microsoft.AspNetCore.Builder;
+using Microsoft.Extensions.DependencyInjection;
+using ModelContextProtocol.AspNetCore.Tests.Utils;
+using ModelContextProtocol.Client;
+using ModelContextProtocol.Protocol;
+using ModelContextProtocol.Server;
+
+namespace ModelContextProtocol.AspNetCore.Tests;
+
+///
+/// Integration tests for SSE resumability with full client-server flow.
+/// These tests use McpClient for end-to-end testing and only use raw HTTP
+/// for SSE format verification where McpClient abstracts away the details.
+///
+public class ResumabilityIntegrationTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper)
+{
+ private const string InitializeRequest = """
+ {"jsonrpc":"2.0","id":"1","method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{},"clientInfo":{"name":"TestClient","version":"1.0.0"}}}
+ """;
+
+ [Fact]
+ public async Task Server_StoresEvents_WhenEventStoreConfigured()
+ {
+ // Arrange
+ var eventStreamStore = new TestSseEventStreamStore();
+ await using var app = await CreateServerAsync(eventStreamStore);
+ await using var client = await ConnectClientAsync();
+
+ // Act - Make a tool call which generates events
+ var result = await client.CallToolAsync("echo",
+ new Dictionary { ["message"] = "test" },
+ cancellationToken: TestContext.Current.CancellationToken);
+
+ // Assert - Events were stored
+ Assert.NotNull(result);
+ Assert.True(eventStreamStore.StoreEventCallCount > 0, "Expected events to be stored when EventStore is configured");
+ }
+
+ [Fact]
+ public async Task Server_StoresMultipleEvents_ForMultipleToolCalls()
+ {
+ // Arrange
+ var eventStreamStore = new TestSseEventStreamStore();
+ await using var app = await CreateServerAsync(eventStreamStore);
+ await using var client = await ConnectClientAsync();
+
+ // Act - Make multiple tool calls
+ var initialCount = eventStreamStore.StoreEventCallCount;
+
+ await client.CallToolAsync("echo",
+ new Dictionary { ["message"] = "test1" },
+ cancellationToken: TestContext.Current.CancellationToken);
+
+ var countAfterFirst = eventStreamStore.StoreEventCallCount;
+
+ await client.CallToolAsync("echo",
+ new Dictionary { ["message"] = "test2" },
+ cancellationToken: TestContext.Current.CancellationToken);
+
+ var countAfterSecond = eventStreamStore.StoreEventCallCount;
+
+ // Assert - More events were stored for each call
+ Assert.True(countAfterFirst > initialCount, "Expected more events after first call");
+ Assert.True(countAfterSecond > countAfterFirst, "Expected more events after second call");
+ }
+
+ [Fact]
+ public async Task Client_CanMakeMultipleRequests_WithResumabilityEnabled()
+ {
+ // Arrange
+ var eventStreamStore = new TestSseEventStreamStore();
+ await using var app = await CreateServerAsync(eventStreamStore);
+ await using var client = await ConnectClientAsync();
+
+ // Act - Make many requests to verify stability
+ for (int i = 0; i < 5; i++)
+ {
+ var result = await client.CallToolAsync("echo",
+ new Dictionary { ["message"] = $"test{i}" },
+ cancellationToken: TestContext.Current.CancellationToken);
+
+ var textContent = Assert.Single(result.Content.OfType());
+ Assert.Equal($"Echo: test{i}", textContent.Text);
+ }
+
+ // Assert - All requests succeeded and events were stored
+ Assert.True(eventStreamStore.StoreEventCallCount >= 5, "Expected events to be stored for each request");
+ }
+
+ [Fact]
+ public async Task Ping_WorksWithResumabilityEnabled()
+ {
+ // Arrange
+ var eventStreamStore = new TestSseEventStreamStore();
+ await using var app = await CreateServerAsync(eventStreamStore);
+ await using var client = await ConnectClientAsync();
+
+ // Act & Assert - Ping should work
+ await client.PingAsync(cancellationToken: TestContext.Current.CancellationToken);
+ }
+
+ [Fact]
+ public async Task ListTools_WorksWithResumabilityEnabled()
+ {
+ // Arrange
+ var eventStreamStore = new TestSseEventStreamStore();
+ await using var app = await CreateServerAsync(eventStreamStore);
+ await using var client = await ConnectClientAsync();
+
+ // Act
+ var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
+
+ // Assert
+ Assert.NotNull(tools);
+ Assert.Single(tools);
+ }
+
+ [Fact]
+ public async Task Server_IncludesEventIdAndRetry_InSseResponse()
+ {
+ // Arrange
+ var expectedRetryInterval = TimeSpan.FromSeconds(5);
+ var eventStreamStore = new TestSseEventStreamStore();
+ await using var app = await CreateServerAsync(eventStreamStore, retryInterval: expectedRetryInterval);
+
+ // Act
+ var sseResponse = await SendInitializeAndReadSseResponseAsync(InitializeRequest);
+
+ // Assert - Event IDs and retry field should be present in the response
+ Assert.True(sseResponse.LastEventId is not null, "Expected SSE response to contain event IDs");
+ Assert.Equal(expectedRetryInterval, sseResponse.RetryInterval);
+ }
+
+ [Fact]
+ public async Task Server_WithoutEventStore_DoesNotIncludeEventIdAndRetry()
+ {
+ // Arrange - Server without event store
+ await using var app = await CreateServerAsync();
+
+ // Act
+ var sseResponse = await SendInitializeAndReadSseResponseAsync(InitializeRequest);
+
+ // Assert - No event IDs or retry field when EventStore is not configured
+ Assert.True(sseResponse.LastEventId is null, "Did not expect event IDs when EventStore is not configured");
+ Assert.True(sseResponse.RetryInterval is null, "Did not expect retry field when EventStore is not configured");
+ }
+
+ [Fact]
+ public async Task Server_DoesNotSendPrimingEvents_ToOlderProtocolVersionClients()
+ {
+ // Arrange - Server with resumability enabled
+ var eventStreamStore = new TestSseEventStreamStore();
+ await using var app = await CreateServerAsync(eventStreamStore, retryInterval: TimeSpan.FromSeconds(5));
+
+ // Use an older protocol version that doesn't support resumability
+ const string OldProtocolInitRequest = """
+ {"jsonrpc":"2.0","id":"1","method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"OldClient","version":"1.0.0"}}}
+ """;
+
+ var sseResponse = await SendInitializeAndReadSseResponseAsync(OldProtocolInitRequest);
+
+ // Assert - Old clients should not receive event IDs or retry fields (no priming events)
+ Assert.True(sseResponse.LastEventId is null, "Old protocol clients should not receive event IDs");
+ Assert.True(sseResponse.RetryInterval is null, "Old protocol clients should not receive retry field");
+
+ // Event store should not have been called for old clients
+ Assert.Equal(0, eventStreamStore.StoreEventCallCount);
+ }
+
+ [Fact]
+ public async Task Client_ReceivesRetryInterval_FromServer()
+ {
+ // Arrange - Server with specific retry interval
+ var expectedRetry = TimeSpan.FromMilliseconds(3000);
+ var eventStreamStore = new TestSseEventStreamStore();
+ await using var app = await CreateServerAsync(eventStreamStore, retryInterval: expectedRetry);
+
+ // Act - Send initialize and read the retry field
+ var sseItem = await SendInitializeAndReadSseResponseAsync(InitializeRequest);
+
+ // Assert - Client receives the retry interval from server
+ Assert.Equal(expectedRetry, sseItem.RetryInterval);
+ }
+
+ [Fact]
+ public async Task Client_CanPollResponse_FromServer()
+ {
+ const string ProgressToolName = "progress_tool";
+ var clientReceivedInitialValueTcs = new TaskCompletionSource();
+ var clientReceivedPolledValueTcs = new TaskCompletionSource();
+ var progressTool = McpServerTool.Create(async (RequestContext context, IProgress progress) =>
+ {
+ progress.Report(new() { Progress = 0, Message = "Initial value" });
+
+ await clientReceivedInitialValueTcs.Task;
+
+ await context.EnablePollingAsync(retryInterval: TimeSpan.FromSeconds(1));
+
+ progress.Report(new() { Progress = 50, Message = "Polled value" });
+
+ await clientReceivedPolledValueTcs.Task;
+
+ return "Complete";
+ }, options: new() { Name = ProgressToolName });
+ var eventStreamStore = new TestSseEventStreamStore();
+ await using var app = await CreateServerAsync(eventStreamStore, configureServer: builder =>
+ {
+ builder.WithTools([progressTool]);
+ });
+ await using var client = await ConnectClientAsync();
+
+ var progressHandler = new Progress(value =>
+ {
+ switch (value.Message)
+ {
+ case "Initial value":
+ Assert.True(clientReceivedInitialValueTcs.TrySetResult(), "Received the initial value more than once.");
+ break;
+ case "Polled value":
+ Assert.True(clientReceivedPolledValueTcs.TrySetResult(), "Received the polled value more than once.");
+ break;
+ default:
+ throw new UnreachableException($"Unknown progress message '{value.Message}'");
+ }
+ });
+
+ var result = await client.CallToolAsync(ProgressToolName, progress: progressHandler, cancellationToken: TestContext.Current.CancellationToken);
+
+ Assert.False(result.IsError is true);
+ Assert.Equal("Complete", result.Content.OfType().Single().Text);
+ }
+
+ [Fact]
+ public async Task Client_CanResumePostResponseStream_AfterDisconnection()
+ {
+ var manualStreamHandler = new FaultingStreamHandler();
+ SetHttpMessageHandler(manualStreamHandler);
+
+ const string ProgressToolName = "progress_tool";
+ var clientReceivedInitialValueTcs = new TaskCompletionSource();
+ var clientReceivedReconnectValueTcs = new TaskCompletionSource();
+ var progressTool = McpServerTool.Create(async (RequestContext context, IProgress progress, CancellationToken cancellationToken) =>
+ {
+ progress.Report(new() { Progress = 0, Message = "Initial value" });
+
+ // Make sure the client receives one message before we disconnect.
+ await clientReceivedInitialValueTcs.Task;
+
+ // Simulate a network disconnection by faulting the response stream.
+ var reconnectAttempt = await manualStreamHandler.TriggerFaultAsync(TestContext.Current.CancellationToken);
+
+ // Send another message that the client should receive after reconnecting.
+ progress.Report(new() { Progress = 50, Message = "Reconnect value" });
+
+ reconnectAttempt.Continue();
+
+ // Wait for the client to receive the message via replay.
+ await clientReceivedReconnectValueTcs.Task;
+
+ // Return the final result with the client still connected.
+ return "Complete";
+ }, options: new() { Name = ProgressToolName });
+ var eventStreamStore = new TestSseEventStreamStore();
+ await using var app = await CreateServerAsync(eventStreamStore, configureServer: builder =>
+ {
+ builder.WithTools([progressTool]);
+ });
+ await using var client = await ConnectClientAsync();
+
+ var progressHandler = new Progress(value =>
+ {
+ switch (value.Message)
+ {
+ case "Initial value":
+ clientReceivedInitialValueTcs.SetResult();
+ break;
+ case "Reconnect value":
+ clientReceivedReconnectValueTcs.SetResult();
+ break;
+ default:
+ throw new UnreachableException($"Unknown progress message '{value.Message}'");
+ }
+ });
+
+ var result = await client.CallToolAsync(ProgressToolName, progress: progressHandler, cancellationToken: TestContext.Current.CancellationToken);
+
+ Assert.False(result.IsError is true);
+ Assert.Equal("Complete", result.Content.OfType().Single().Text);
+ }
+
+ [Fact]
+ public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection()
+ {
+ var faultingStreamHandler = new FaultingStreamHandler();
+ SetHttpMessageHandler(faultingStreamHandler);
+
+ var eventStreamStore = new TestSseEventStreamStore();
+
+ // Capture the server instance via RunSessionHandler
+ var serverTcs = new TaskCompletionSource();
+
+ await using var app = await CreateServerAsync(eventStreamStore, configureTransport: options =>
+ {
+ options.RunSessionHandler = (httpContext, mcpServer, cancellationToken) =>
+ {
+ serverTcs.TrySetResult(mcpServer);
+ return mcpServer.RunAsync(cancellationToken);
+ };
+ });
+
+ await using var client = await ConnectClientAsync();
+
+ // Get the server instance
+ var server = await serverTcs.Task.WaitAsync(TestContext.Current.CancellationToken);
+
+ // Set up notification tracking
+ var clientReceivedInitialNotificationTcs = new TaskCompletionSource();
+ var clientReceivedReplayedNotificationTcs = new TaskCompletionSource();
+ var clientReceivedReconnectNotificationTcs = new TaskCompletionSource();
+
+ const string CustomNotificationMethod = "test/custom_notification";
+
+ await using var _ = client.RegisterNotificationHandler(CustomNotificationMethod, (notification, cancellationToken) =>
+ {
+ // First notification completes initial TCS, second completes replay TCS, thirdly completes reconnect TCS.
+ if (clientReceivedInitialNotificationTcs.TrySetResult())
+ {
+ return default;
+ }
+
+ if (clientReceivedReplayedNotificationTcs.TrySetResult())
+ {
+ return default;
+ }
+
+ if (clientReceivedReconnectNotificationTcs.TrySetResult())
+ {
+ return default;
+ }
+
+ return default;
+ });
+
+ // Send a custom notification to the client on the unsolicited message stream
+ await server.SendNotificationAsync(CustomNotificationMethod, TestContext.Current.CancellationToken);
+
+ // Wait for client to receive the first notification
+ await clientReceivedInitialNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken);
+
+ // Fault the unsolicited message stream (GET SSE)
+ var reconnectAttempt = await faultingStreamHandler.TriggerFaultAsync(TestContext.Current.CancellationToken);
+
+ // Send another notification while the client is disconnected - this should be stored
+ await server.SendNotificationAsync(CustomNotificationMethod, TestContext.Current.CancellationToken);
+
+ // Allow the client to reconnect
+ reconnectAttempt.Continue();
+
+ // Wait for client to receive the notification via replay
+ await clientReceivedReplayedNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken);
+
+ // Send a final notification while the client has reconnected - this should be handled by the transport
+ await server.SendNotificationAsync(CustomNotificationMethod, TestContext.Current.CancellationToken);
+
+ // Wait for the client to receive the final notification
+ await clientReceivedReconnectNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken);
+ }
+
+ [McpServerToolType]
+ private class ResumabilityTestTools
+ {
+ [McpServerTool(Name = "echo"), Description("Echoes the message back")]
+ public static string Echo(string message) => $"Echo: {message}";
+ }
+
+ private async Task CreateServerAsync(
+ ISseEventStreamStore? eventStreamStore = null,
+ TimeSpan? retryInterval = null,
+ Action? configureServer = null,
+ Action? configureTransport = null)
+ {
+ var serverBuilder = Builder.Services.AddMcpServer()
+ .WithHttpTransport(options =>
+ {
+ options.EventStreamStore = eventStreamStore;
+ if (retryInterval.HasValue)
+ {
+ options.RetryInterval = retryInterval.Value;
+ }
+ configureTransport?.Invoke(options);
+ })
+ .WithTools();
+
+ configureServer?.Invoke(serverBuilder);
+
+ var app = Builder.Build();
+ app.MapMcp();
+
+ await app.StartAsync(TestContext.Current.CancellationToken);
+ return app;
+ }
+
+ private async Task ConnectClientAsync()
+ {
+ var transport = new HttpClientTransport(new HttpClientTransportOptions
+ {
+ Endpoint = new Uri("http://localhost:5000/"),
+ TransportMode = HttpTransportMode.StreamableHttp,
+ }, HttpClient, LoggerFactory);
+
+ return await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory,
+ cancellationToken: TestContext.Current.CancellationToken);
+ }
+
+ private async Task SendInitializeAndReadSseResponseAsync(string initializeRequest)
+ {
+ using var requestContent = new StringContent(initializeRequest, Encoding.UTF8, "application/json");
+ using var request = new HttpRequestMessage(HttpMethod.Post, "/")
+ {
+ Headers =
+ {
+ Accept = { new("application/json"), new("text/event-stream") }
+ },
+ Content = requestContent,
+ };
+
+ var response = await HttpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead,
+ TestContext.Current.CancellationToken);
+
+ response.EnsureSuccessStatusCode();
+
+ var sseResponse = new SseResponse();
+ await using var stream = await response.Content.ReadAsStreamAsync(TestContext.Current.CancellationToken);
+ await foreach (var sseItem in SseParser.Create(stream).EnumerateAsync(TestContext.Current.CancellationToken))
+ {
+ if (!string.IsNullOrEmpty(sseItem.EventId))
+ {
+ sseResponse.LastEventId = sseItem.EventId;
+ }
+ if (sseItem.ReconnectionInterval.HasValue)
+ {
+ sseResponse.RetryInterval = sseItem.ReconnectionInterval.Value;
+ }
+ }
+
+ return sseResponse;
+ }
+
+ private struct SseResponse
+ {
+ public string? LastEventId { get; set; }
+ public TimeSpan? RetryInterval { get; set; }
+ }
+}
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseEventStreamStoreTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseEventStreamStoreTests.cs
new file mode 100644
index 000000000..52391f351
--- /dev/null
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseEventStreamStoreTests.cs
@@ -0,0 +1,637 @@
+using ModelContextProtocol.AspNetCore.Tests.Utils;
+using ModelContextProtocol.Protocol;
+using ModelContextProtocol.Server;
+using ModelContextProtocol.Tests.Utils;
+using System.Net.ServerSentEvents;
+
+namespace ModelContextProtocol.AspNetCore.Tests;
+
+///
+/// Tests for the interface and implementation.
+///
+public class SseEventStreamStoreTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper)
+{
+ private CancellationToken CancellationToken => TestContext.Current.CancellationToken;
+
+ #region CreateStreamAsync Tests
+
+ [Fact]
+ public async Task CreateStreamAsync_ReturnsWriter_WithCorrectStreamId()
+ {
+ var store = new TestSseEventStreamStore();
+ var options = new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Default
+ };
+
+ var writer = await store.CreateStreamAsync(options, CancellationToken);
+
+ Assert.NotNull(writer);
+ Assert.Equal("stream-1", writer.StreamId);
+ }
+
+ [Fact]
+ public async Task CreateStreamAsync_ReturnsWriter_WithCorrectMode()
+ {
+ var store = new TestSseEventStreamStore();
+
+ var defaultModeOptions = new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Default
+ };
+ var defaultWriter = await store.CreateStreamAsync(defaultModeOptions, CancellationToken);
+ Assert.Equal(SseEventStreamMode.Default, defaultWriter.Mode);
+
+ var pollingModeOptions = new SseEventStreamOptions
+ {
+ SessionId = "session-2",
+ StreamId = "stream-2",
+ Mode = SseEventStreamMode.Polling
+ };
+ var pollingWriter = await store.CreateStreamAsync(pollingModeOptions, CancellationToken);
+ Assert.Equal(SseEventStreamMode.Polling, pollingWriter.Mode);
+ }
+
+ [Fact]
+ public async Task CreateStreamAsync_MultipleStreams_CreatesDistinctWriters()
+ {
+ var store = new TestSseEventStreamStore();
+
+ var writer1 = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Default
+ }, CancellationToken);
+
+ var writer2 = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-2",
+ Mode = SseEventStreamMode.Default
+ }, CancellationToken);
+
+ Assert.NotSame(writer1, writer2);
+ Assert.Equal("stream-1", writer1.StreamId);
+ Assert.Equal("stream-2", writer2.StreamId);
+ }
+
+ #endregion
+
+ #region WriteEventAsync Tests
+
+ [Fact]
+ public async Task WriteEventAsync_AssignsEventId_WhenNotPresent()
+ {
+ var store = new TestSseEventStreamStore();
+ var writer = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Default
+ }, CancellationToken);
+
+ var message = new JsonRpcRequest { Method = "test", Id = new RequestId("1") };
+ var item = new SseItem(message, "message");
+
+ var result = await writer.WriteEventAsync(item, CancellationToken);
+
+ Assert.NotNull(result.EventId);
+ Assert.Equal("1", result.EventId);
+ }
+
+ [Fact]
+ public async Task WriteEventAsync_SkipsAssigningEventId_WhenAlreadyPresent()
+ {
+ var store = new TestSseEventStreamStore();
+ var writer = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Default
+ }, CancellationToken);
+
+ var message = new JsonRpcRequest { Method = "test", Id = new RequestId("1") };
+ var item = new SseItem(message, "message") { EventId = "existing-id" };
+
+ var result = await writer.WriteEventAsync(item, CancellationToken);
+
+ Assert.Equal("existing-id", result.EventId);
+ }
+
+ [Fact]
+ public async Task WriteEventAsync_GeneratesSequentialEventIds()
+ {
+ var store = new TestSseEventStreamStore();
+ var writer = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Default
+ }, CancellationToken);
+
+ var message1 = new JsonRpcRequest { Method = "test1", Id = new RequestId("1") };
+ var message2 = new JsonRpcRequest { Method = "test2", Id = new RequestId("2") };
+ var message3 = new JsonRpcRequest { Method = "test3", Id = new RequestId("3") };
+
+ var result1 = await writer.WriteEventAsync(new SseItem(message1, "message"), CancellationToken);
+ var result2 = await writer.WriteEventAsync(new SseItem(message2, "message"), CancellationToken);
+ var result3 = await writer.WriteEventAsync(new SseItem(message3, "message"), CancellationToken);
+
+ Assert.Equal("1", result1.EventId);
+ Assert.Equal("2", result2.EventId);
+ Assert.Equal("3", result3.EventId);
+ }
+
+ [Fact]
+ public async Task WriteEventAsync_TracksStoredEventIds()
+ {
+ var store = new TestSseEventStreamStore();
+ var writer = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Default
+ }, CancellationToken);
+
+ var message1 = new JsonRpcRequest { Method = "test1", Id = new RequestId("1") };
+ var message2 = new JsonRpcRequest { Method = "test2", Id = new RequestId("2") };
+
+ await writer.WriteEventAsync(new SseItem(message1, "message"), CancellationToken);
+ await writer.WriteEventAsync(new SseItem(message2, "message"), CancellationToken);
+
+ Assert.Equal(2, store.StoreEventCallCount);
+ Assert.Equal(["1", "2"], store.StoredEventIds);
+ }
+
+ [Fact]
+ public async Task WriteEventAsync_PreservesEventData()
+ {
+ var store = new TestSseEventStreamStore();
+ var writer = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Default
+ }, CancellationToken);
+
+ var message = new JsonRpcRequest { Method = "test-method", Id = new RequestId("req-1") };
+ var item = new SseItem(message, "custom-event");
+
+ var result = await writer.WriteEventAsync(item, CancellationToken);
+
+ Assert.Same(message, result.Data);
+ Assert.Equal("custom-event", result.EventType);
+ }
+
+ [Fact]
+ public async Task WriteEventAsync_HandlesNullData()
+ {
+ var store = new TestSseEventStreamStore();
+ var writer = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Default
+ }, CancellationToken);
+
+ var item = new SseItem(null, "priming");
+
+ var result = await writer.WriteEventAsync(item, CancellationToken);
+
+ Assert.NotNull(result.EventId);
+ Assert.Null(result.Data);
+ }
+
+ #endregion
+
+ #region SetModeAsync Tests
+
+ [Fact]
+ public async Task SetModeAsync_ChangesMode_FromDefaultToPolling()
+ {
+ var store = new TestSseEventStreamStore();
+ var writer = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Default
+ }, CancellationToken);
+
+ Assert.Equal(SseEventStreamMode.Default, writer.Mode);
+
+ await writer.SetModeAsync(SseEventStreamMode.Polling, CancellationToken);
+
+ Assert.Equal(SseEventStreamMode.Polling, writer.Mode);
+ }
+
+ [Fact]
+ public async Task SetModeAsync_ChangesMode_FromPollingToDefault()
+ {
+ var store = new TestSseEventStreamStore();
+ var writer = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Polling
+ }, CancellationToken);
+
+ Assert.Equal(SseEventStreamMode.Polling, writer.Mode);
+
+ await writer.SetModeAsync(SseEventStreamMode.Default, CancellationToken);
+
+ Assert.Equal(SseEventStreamMode.Default, writer.Mode);
+ }
+
+ #endregion
+
+ #region GetStreamReaderAsync Tests
+
+ [Fact]
+ public async Task GetStreamReaderAsync_ReturnsNull_WhenEventIdNotFound()
+ {
+ var store = new TestSseEventStreamStore();
+
+ var reader = await store.GetStreamReaderAsync("nonexistent-id", CancellationToken);
+
+ Assert.Null(reader);
+ }
+
+ [Fact]
+ public async Task GetStreamReaderAsync_ReturnsReader_WhenEventIdExists()
+ {
+ var store = new TestSseEventStreamStore();
+ var writer = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Polling
+ }, CancellationToken);
+
+ var message = new JsonRpcRequest { Method = "test", Id = new RequestId("1") };
+ var result = await writer.WriteEventAsync(new SseItem(message, "message"), CancellationToken);
+
+ var reader = await store.GetStreamReaderAsync(result.EventId!, CancellationToken);
+
+ Assert.NotNull(reader);
+ Assert.Equal("stream-1", reader.StreamId);
+ }
+
+ [Fact]
+ public async Task GetStreamReaderAsync_Throws_WhenStreamReplaced()
+ {
+ var store = new TestSseEventStreamStore();
+ var writer = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Default
+ }, CancellationToken);
+
+ // Create a new stream with the same key, effectively replacing the old one
+ await Assert.ThrowsAsync(async () => await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Default
+ }, CancellationToken));
+ }
+
+ #endregion
+
+ #region ReadEventsAsync Tests
+
+ [Fact]
+ public async Task ReadEventsAsync_ReturnsEventsAfterLastEventId()
+ {
+ var store = new TestSseEventStreamStore();
+ var writer = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Polling
+ }, CancellationToken);
+
+ var message1 = new JsonRpcRequest { Method = "test1", Id = new RequestId("1") };
+ var message2 = new JsonRpcRequest { Method = "test2", Id = new RequestId("2") };
+ var message3 = new JsonRpcRequest { Method = "test3", Id = new RequestId("3") };
+
+ var result1 = await writer.WriteEventAsync(new SseItem(message1, "message"), CancellationToken);
+ await writer.WriteEventAsync(new SseItem(message2, "message"), CancellationToken);
+ await writer.WriteEventAsync(new SseItem(message3, "message"), CancellationToken);
+
+ var reader = await store.GetStreamReaderAsync(result1.EventId!, CancellationToken);
+ Assert.NotNull(reader);
+
+ var events = await reader.ReadEventsAsync(CancellationToken).ToListAsync(CancellationToken);
+
+ Assert.Equal(2, events.Count);
+ Assert.Equal("test2", ((JsonRpcRequest)events[0].Data!).Method);
+ Assert.Equal("test3", ((JsonRpcRequest)events[1].Data!).Method);
+ }
+
+ [Fact]
+ public async Task ReadEventsAsync_IncludesNullDataEvents()
+ {
+ var store = new TestSseEventStreamStore();
+ var writer = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Polling
+ }, CancellationToken);
+
+ var message1 = new JsonRpcRequest { Method = "test1", Id = new RequestId("1") };
+ var message2 = new JsonRpcRequest { Method = "test2", Id = new RequestId("2") };
+
+ var result1 = await writer.WriteEventAsync(new SseItem(message1, "message"), CancellationToken);
+ await writer.WriteEventAsync(new SseItem(null, "priming"), CancellationToken); // null data event
+ await writer.WriteEventAsync(new SseItem(message2, "message"), CancellationToken);
+
+ var reader = await store.GetStreamReaderAsync(result1.EventId!, CancellationToken);
+ Assert.NotNull(reader);
+
+ var events = await reader.ReadEventsAsync(CancellationToken).ToListAsync(CancellationToken);
+
+ Assert.Equal(2, events.Count);
+ Assert.Null(events[0].Data);
+ Assert.Equal("priming", events[0].EventType);
+ Assert.Equal("test2", ((JsonRpcRequest)events[1].Data!).Method);
+ }
+
+ [Fact]
+ public async Task ReadEventsAsync_ReturnsEmpty_WhenNoEventsAfterLastEventId()
+ {
+ var store = new TestSseEventStreamStore();
+ var writer = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Polling
+ }, CancellationToken);
+
+ var message = new JsonRpcRequest { Method = "test", Id = new RequestId("1") };
+ var result = await writer.WriteEventAsync(new SseItem(message, "message"), CancellationToken);
+
+ var reader = await store.GetStreamReaderAsync(result.EventId!, CancellationToken);
+ Assert.NotNull(reader);
+
+ var events = await reader.ReadEventsAsync(CancellationToken).ToListAsync(CancellationToken);
+
+ Assert.Empty(events);
+ }
+
+ [Fact]
+ public async Task ReadEventsAsync_InPollingMode_CompletesAfterStoredEvents()
+ {
+ var store = new TestSseEventStreamStore();
+ var writer = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Polling
+ }, CancellationToken);
+
+ var message1 = new JsonRpcRequest { Method = "test1", Id = new RequestId("1") };
+ var message2 = new JsonRpcRequest { Method = "test2", Id = new RequestId("2") };
+
+ var result1 = await writer.WriteEventAsync(new SseItem(message1, "message"), CancellationToken);
+ await writer.WriteEventAsync(new SseItem(message2, "message"), CancellationToken);
+
+ var reader = await store.GetStreamReaderAsync(result1.EventId!, CancellationToken);
+ Assert.NotNull(reader);
+
+ // In polling mode, ReadEventsAsync should complete immediately after returning stored events
+ using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken);
+ cts.CancelAfter(TimeSpan.FromSeconds(1));
+ var events = await reader.ReadEventsAsync(cts.Token).ToListAsync(cts.Token);
+
+ Assert.Single(events);
+ }
+
+ [Fact]
+ public async Task ReadEventsAsync_InDefaultMode_WaitsForNewEvents()
+ {
+ var store = new TestSseEventStreamStore();
+ var writer = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Default
+ }, CancellationToken);
+
+ var message1 = new JsonRpcRequest { Method = "test1", Id = new RequestId("1") };
+ var result1 = await writer.WriteEventAsync(new SseItem(message1, "message"), CancellationToken);
+
+ var reader = await store.GetStreamReaderAsync(result1.EventId!, CancellationToken);
+ Assert.NotNull(reader);
+
+ var readTask = Task.Run(async () =>
+ {
+ var events = new List>();
+ await foreach (var e in reader.ReadEventsAsync(CancellationToken))
+ {
+ events.Add(e);
+ if (events.Count >= 2)
+ {
+ break;
+ }
+ }
+ return events;
+ }, CancellationToken);
+
+ // Give the read task time to start waiting
+ await Task.Delay(50, CancellationToken);
+
+ // Write new events
+ var message2 = new JsonRpcRequest { Method = "test2", Id = new RequestId("2") };
+ var message3 = new JsonRpcRequest { Method = "test3", Id = new RequestId("3") };
+ await writer.WriteEventAsync(new SseItem(message2, "message"), CancellationToken);
+ await writer.WriteEventAsync(new SseItem(message3, "message"), CancellationToken);
+
+ using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken);
+ cts.CancelAfter(TimeSpan.FromSeconds(5));
+ var events = await readTask.WaitAsync(cts.Token);
+
+ Assert.Equal(2, events.Count);
+ Assert.Equal("test2", ((JsonRpcRequest)events[0].Data!).Method);
+ Assert.Equal("test3", ((JsonRpcRequest)events[1].Data!).Method);
+ }
+
+ [Fact]
+ public async Task ReadEventsAsync_InDefaultMode_CompletesWhenWriterDisposed()
+ {
+ var store = new TestSseEventStreamStore();
+ var writer = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Default
+ }, CancellationToken);
+
+ var message1 = new JsonRpcRequest { Method = "test1", Id = new RequestId("1") };
+ var result1 = await writer.WriteEventAsync(new SseItem(message1, "message"), CancellationToken);
+
+ var reader = await store.GetStreamReaderAsync(result1.EventId!, CancellationToken);
+ Assert.NotNull(reader);
+
+ var readTask = reader.ReadEventsAsync(CancellationToken).ToListAsync(CancellationToken).AsTask();
+
+ // Give the read task time to start waiting
+ await Task.Delay(50, CancellationToken);
+
+ // Dispose the writer
+ await writer.DisposeAsync();
+
+ using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken);
+ cts.CancelAfter(TimeSpan.FromSeconds(5));
+ var events = await readTask.WaitAsync(cts.Token);
+
+ Assert.Empty(events);
+ }
+
+ [Fact]
+ public async Task ReadEventsAsync_RespectsCancellation()
+ {
+ var store = new TestSseEventStreamStore();
+ var writer = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Default
+ }, CancellationToken);
+
+ var message1 = new JsonRpcRequest { Method = "test1", Id = new RequestId("1") };
+ var result1 = await writer.WriteEventAsync(new SseItem(message1, "message"), CancellationToken);
+
+ var reader = await store.GetStreamReaderAsync(result1.EventId!, CancellationToken);
+ Assert.NotNull(reader);
+
+ using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken);
+ cts.CancelAfter(TimeSpan.FromMilliseconds(100));
+
+ await Assert.ThrowsAsync(async () =>
+ {
+ await reader.ReadEventsAsync(cts.Token).ToListAsync(cts.Token);
+ });
+ }
+
+ #endregion
+
+ #region Cross-Session Tests
+
+ [Fact]
+ public async Task EventIds_AreUniqueAcrossSessions()
+ {
+ var store = new TestSseEventStreamStore();
+
+ var writer1 = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Default
+ }, CancellationToken);
+
+ var writer2 = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-2",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Default
+ }, CancellationToken);
+
+ var message1 = new JsonRpcRequest { Method = "test1", Id = new RequestId("1") };
+ var message2 = new JsonRpcRequest { Method = "test2", Id = new RequestId("2") };
+
+ var result1 = await writer1.WriteEventAsync(new SseItem(message1, "message"), CancellationToken);
+ var result2 = await writer2.WriteEventAsync(new SseItem(message2, "message"), CancellationToken);
+
+ Assert.NotEqual(result1.EventId, result2.EventId);
+ }
+
+ [Fact]
+ public async Task GetStreamReaderAsync_ReturnsCorrectStream_ForDifferentSessions()
+ {
+ var store = new TestSseEventStreamStore();
+
+ var writer1 = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Polling
+ }, CancellationToken);
+
+ var writer2 = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-2",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Polling
+ }, CancellationToken);
+
+ var message1 = new JsonRpcRequest { Method = "session1-test", Id = new RequestId("1") };
+ var message2 = new JsonRpcRequest { Method = "session2-test", Id = new RequestId("2") };
+ var message1b = new JsonRpcRequest { Method = "session1-test2", Id = new RequestId("3") };
+ var message2b = new JsonRpcRequest { Method = "session2-test2", Id = new RequestId("4") };
+
+ var result1 = await writer1.WriteEventAsync(new SseItem(message1, "message"), CancellationToken);
+ var result2 = await writer2.WriteEventAsync(new SseItem(message2, "message"), CancellationToken);
+ await writer1.WriteEventAsync(new SseItem(message1b, "message"), CancellationToken);
+ await writer2.WriteEventAsync(new SseItem(message2b, "message"), CancellationToken);
+
+ var reader1 = await store.GetStreamReaderAsync(result1.EventId!, CancellationToken);
+ var reader2 = await store.GetStreamReaderAsync(result2.EventId!, CancellationToken);
+
+ Assert.NotNull(reader1);
+ Assert.NotNull(reader2);
+
+ var events1 = await reader1.ReadEventsAsync(CancellationToken).ToListAsync(CancellationToken);
+ var events2 = await reader2.ReadEventsAsync(CancellationToken).ToListAsync(CancellationToken);
+
+ Assert.Single(events1);
+ Assert.Equal("session1-test2", ((JsonRpcRequest)events1[0].Data!).Method);
+
+ Assert.Single(events2);
+ Assert.Equal("session2-test2", ((JsonRpcRequest)events2[0].Data!).Method);
+ }
+
+ #endregion
+
+ #region DisposeAsync Tests
+
+ [Fact]
+ public async Task DisposeAsync_CompletesWithoutError()
+ {
+ var store = new TestSseEventStreamStore();
+ var writer = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Default
+ }, CancellationToken);
+
+ await writer.DisposeAsync();
+ // Should not throw
+ }
+
+ [Fact]
+ public async Task DisposeAsync_CanBeCalledMultipleTimes()
+ {
+ var store = new TestSseEventStreamStore();
+ var writer = await store.CreateStreamAsync(new SseEventStreamOptions
+ {
+ SessionId = "session-1",
+ StreamId = "stream-1",
+ Mode = SseEventStreamMode.Default
+ }, CancellationToken);
+
+ await writer.DisposeAsync();
+ await writer.DisposeAsync();
+ await writer.DisposeAsync();
+ // Should not throw
+ }
+
+ #endregion
+}
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs
index a843e2975..91af874ef 100644
--- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs
@@ -85,13 +85,10 @@ public async Task EnablingStatelessMode_Disables_SseEndpoints()
}
[Fact]
- public async Task EnablingStatelessMode_Disables_GetAndDeleteEndpoints()
+ public async Task EnablingStatelessMode_Disables_DeleteEndpoint()
{
await StartAsync();
- using var getResponse = await HttpClient.GetAsync("/", TestContext.Current.CancellationToken);
- Assert.Equal(HttpStatusCode.MethodNotAllowed, getResponse.StatusCode);
-
using var deleteResponse = await HttpClient.DeleteAsync("/", TestContext.Current.CancellationToken);
Assert.Equal(HttpStatusCode.MethodNotAllowed, deleteResponse.StatusCode);
}
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs
new file mode 100644
index 000000000..c42fe082d
--- /dev/null
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs
@@ -0,0 +1,157 @@
+using System.Diagnostics;
+using System.Net;
+
+namespace ModelContextProtocol.AspNetCore.Tests.Utils;
+
+///
+/// A message handler that wraps SSE response streams and can trigger faults mid-stream
+/// to simulate network disconnections during SSE streaming.
+///
+internal sealed class FaultingStreamHandler : DelegatingHandler
+{
+ private FaultingStream? _lastStream;
+ private TaskCompletionSource? _reconnectTcs;
+
+ public async Task TriggerFaultAsync(CancellationToken cancellationToken)
+ {
+ if (_lastStream is null or { IsDisposed: true })
+ {
+ throw new InvalidOperationException("There is no active response stream to fault.");
+ }
+
+ if (_reconnectTcs is not null)
+ {
+ throw new InvalidOperationException("Cannot trigger a fault while already waiting for reconnection.");
+ }
+
+ _reconnectTcs = new();
+ await _lastStream.TriggerFaultAsync(cancellationToken);
+
+ return new(_reconnectTcs);
+ }
+
+ public sealed class ReconnectAttempt(TaskCompletionSource reconnectTcs)
+ {
+ public void Continue()
+ => reconnectTcs.SetResult();
+ }
+
+ protected override async Task SendAsync(
+ HttpRequestMessage request, CancellationToken cancellationToken)
+ {
+ if (_reconnectTcs is not null && request.Headers.Accept.Contains(new("text/event-stream")))
+ {
+ // If we're blocking reconnection, wait until we're allowed to continue.
+ await _reconnectTcs.Task.WaitAsync(cancellationToken);
+ _reconnectTcs = null;
+ }
+
+ var response = await base.SendAsync(request, cancellationToken);
+
+ // Only wrap SSE streams (text/event-stream)
+ if (response.Content.Headers.ContentType?.MediaType == "text/event-stream")
+ {
+ var originalStream = await response.Content.ReadAsStreamAsync(cancellationToken);
+ _lastStream = new FaultingStream(originalStream);
+ var faultingContent = new FaultingStreamContent(_lastStream);
+
+ // Copy headers from original content
+ var newContent = faultingContent;
+ foreach (var header in response.Content.Headers)
+ {
+ newContent.Headers.TryAddWithoutValidation(header.Key, header.Value);
+ }
+
+ response.Content = newContent;
+ }
+
+ return response;
+ }
+
+ private sealed class FaultingStreamContent(FaultingStream stream) : HttpContent
+ {
+ private readonly FaultingStream _manualStream = new(stream);
+
+ protected override Task SerializeToStreamAsync(Stream stream, TransportContext? context)
+ => throw new NotSupportedException();
+
+ protected override Task CreateContentReadStreamAsync()
+ => Task.FromResult(stream);
+
+ protected override bool TryComputeLength(out long length)
+ {
+ length = -1;
+ return false;
+ }
+ }
+
+ private sealed class FaultingStream(Stream innerStream) : Stream
+ {
+ private readonly CancellationTokenSource _cts = new();
+ private TaskCompletionSource? _faultTcs;
+ private bool _disposed;
+
+ public bool IsDisposed => _disposed;
+
+ public async Task TriggerFaultAsync(CancellationToken cancellationToken)
+ {
+ if (_faultTcs is not null)
+ {
+ throw new InvalidOperationException("Only one fault can be triggered per stream.");
+ }
+
+ _faultTcs = new TaskCompletionSource();
+
+ await _cts.CancelAsync();
+ await _faultTcs.Task.WaitAsync(cancellationToken);
+ }
+
+ public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default)
+ {
+ try
+ {
+ _cts.Token.ThrowIfCancellationRequested();
+
+ using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _cts.Token);
+ return await innerStream.ReadAsync(buffer, linkedCts.Token);
+ }
+ catch (OperationCanceledException) when (_cts.IsCancellationRequested)
+ {
+ Debug.Assert(_faultTcs is not null);
+
+ if (!_faultTcs.TrySetResult())
+ {
+ throw new InvalidOperationException("Attempted to read an already-faulted stream.");
+ }
+
+ throw new IOException("Simulated network disconnection.");
+ }
+ }
+
+ public override int Read(byte[] buffer, int offset, int count)
+ => throw new NotSupportedException("Synchronous reads are not supported.");
+
+ public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+ => ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
+
+ public override bool CanRead => innerStream.CanRead;
+ public override bool CanSeek => innerStream.CanSeek;
+ public override bool CanWrite => innerStream.CanWrite;
+ public override long Length => innerStream.Length;
+ public override long Position { get => innerStream.Position; set => innerStream.Position = value; }
+ public override void Flush() => innerStream.Flush();
+ public override long Seek(long offset, SeekOrigin origin) => innerStream.Seek(offset, origin);
+ public override void SetLength(long value) => innerStream.SetLength(value);
+ public override void Write(byte[] buffer, int offset, int count) => innerStream.Write(buffer, offset, count);
+ protected override void Dispose(bool disposing)
+ {
+ if (!disposing || _disposed)
+ {
+ return;
+ }
+
+ _disposed = true;
+ innerStream.Dispose();
+ }
+ }
+}
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs
index fe70c2fa5..8bab71290 100644
--- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs
@@ -9,6 +9,8 @@ namespace ModelContextProtocol.AspNetCore.Tests.Utils;
public class KestrelInMemoryTest : LoggedTest
{
+ private readonly TestHttpMessageHandler _httpMessageHandler;
+
public KestrelInMemoryTest(ITestOutputHelper testOutputHelper)
: base(testOutputHelper)
{
@@ -27,7 +29,9 @@ public KestrelInMemoryTest(ITestOutputHelper testOutputHelper)
return new(connection.ClientStream);
};
- HttpClient = new HttpClient(SocketsHttpHandler)
+ _httpMessageHandler = new(SocketsHttpHandler);
+
+ HttpClient = new HttpClient(_httpMessageHandler)
{
BaseAddress = new Uri("http://localhost:5000/"),
Timeout = TimeSpan.FromSeconds(10),
@@ -42,9 +46,33 @@ public KestrelInMemoryTest(ITestOutputHelper testOutputHelper)
public KestrelInMemoryTransport KestrelInMemoryTransport { get; } = new();
+ protected void SetHttpMessageHandler(DelegatingHandler? handler)
+ {
+ if (handler is null)
+ {
+ _httpMessageHandler.InnerHandler = SocketsHttpHandler;
+ }
+ else
+ {
+ _httpMessageHandler.InnerHandler = handler;
+ handler.InnerHandler = SocketsHttpHandler;
+ }
+ }
+
public override void Dispose()
{
HttpClient.Dispose();
base.Dispose();
}
+
+ private sealed class TestHttpMessageHandler : DelegatingHandler
+ {
+ public TestHttpMessageHandler(HttpMessageHandler innerHandler)
+ {
+ InnerHandler = innerHandler;
+ }
+
+ protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
+ => base.SendAsync(request, cancellationToken);
+ }
}
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestSseEventStreamStore.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestSseEventStreamStore.cs
new file mode 100644
index 000000000..e4058eacd
--- /dev/null
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestSseEventStreamStore.cs
@@ -0,0 +1,246 @@
+using ModelContextProtocol.Protocol;
+using ModelContextProtocol.Server;
+using System.Collections.Concurrent;
+using System.Net.ServerSentEvents;
+using System.Runtime.CompilerServices;
+using System.Threading.Channels;
+
+namespace ModelContextProtocol.AspNetCore.Tests.Utils;
+
+///
+/// In-memory event store for testing resumability.
+/// This is a simple implementation intended for testing, not for production use.
+///
+public sealed class TestSseEventStreamStore : ISseEventStreamStore
+{
+ private readonly ConcurrentDictionary _streams = new();
+ private readonly ConcurrentDictionary _eventLookup = new();
+ private readonly List _storedEventIds = [];
+ private readonly object _storedEventIdsLock = new();
+ private int _storeEventCallCount;
+ private long _globalSequence;
+
+ ///
+ /// Gets the number of times events have been stored.
+ ///
+ public int StoreEventCallCount => _storeEventCallCount;
+
+ ///
+ /// Gets the list of stored event IDs in order.
+ ///
+ public IReadOnlyList StoredEventIds
+ {
+ get
+ {
+ lock (_storedEventIdsLock)
+ {
+ return [.. _storedEventIds];
+ }
+ }
+ }
+
+ ///
+ public ValueTask CreateStreamAsync(SseEventStreamOptions options, CancellationToken cancellationToken = default)
+ {
+ var streamKey = GetStreamKey(options.SessionId, options.StreamId);
+ var state = new StreamState(options.SessionId, options.StreamId, options.Mode);
+ if (!_streams.TryAdd(streamKey, state))
+ {
+ throw new InvalidOperationException($"A stream with key '{streamKey}' has already been created.");
+ }
+ var writer = new InMemoryEventStreamWriter(this, state);
+ return new ValueTask(writer);
+ }
+
+ ///
+ public ValueTask GetStreamReaderAsync(string lastEventId, CancellationToken cancellationToken = default)
+ {
+ // Look up the event by its ID to find which stream it belongs to
+ if (!_eventLookup.TryGetValue(lastEventId, out var lookup))
+ {
+ return new ValueTask((ISseEventStreamReader?)null);
+ }
+
+ var reader = new InMemoryEventStreamReader(lookup.Stream, lookup.Sequence);
+ return new ValueTask(reader);
+ }
+
+ private string GenerateEventId() => Interlocked.Increment(ref _globalSequence).ToString();
+
+ private void TrackEvent(string eventId, StreamState stream, long sequence)
+ {
+ _eventLookup[eventId] = (stream, sequence);
+ lock (_storedEventIdsLock)
+ {
+ _storedEventIds.Add(eventId);
+ }
+ Interlocked.Increment(ref _storeEventCallCount);
+ }
+
+ private static string GetStreamKey(string sessionId, string streamId) => $"{sessionId}:{streamId}";
+
+ ///
+ /// Holds the state for a single stream.
+ ///
+ private sealed class StreamState
+ {
+ private readonly Channel<(SseItem Item, long Sequence)> _channel;
+ private readonly List<(SseItem Item, long Sequence)> _events = [];
+ private readonly object _lock = new();
+ private long _sequence;
+
+ public StreamState(string sessionId, string streamId, SseEventStreamMode mode)
+ {
+ SessionId = sessionId;
+ StreamId = streamId;
+ Mode = mode;
+ _channel = Channel.CreateUnbounded<(SseItem, long)>();
+ }
+
+ public string SessionId { get; }
+ public string StreamId { get; }
+ public SseEventStreamMode Mode { get; set; }
+ public bool IsCompleted { get; private set; }
+
+ public long NextSequence() => Interlocked.Increment(ref _sequence);
+
+ public void AddEvent(SseItem item, long sequence)
+ {
+ if (IsCompleted)
+ {
+ throw new InvalidOperationException("Cannot add events to a completed stream.");
+ }
+
+ lock (_lock)
+ {
+ _events.Add((item, sequence));
+ }
+ _channel.Writer.TryWrite((item, sequence));
+ }
+
+ public List> GetEventsAfter(long sequence)
+ {
+ lock (_lock)
+ {
+ var result = new List>();
+ foreach (var (item, seq) in _events)
+ {
+ if (seq > sequence)
+ {
+ result.Add(item);
+ }
+ }
+ return result;
+ }
+ }
+
+ public ChannelReader<(SseItem Item, long Sequence)> Reader => _channel.Reader;
+
+ public void Complete()
+ {
+ IsCompleted = true;
+ _channel.Writer.TryComplete();
+ }
+ }
+
+ private sealed class InMemoryEventStreamWriter : ISseEventStreamWriter
+ {
+ private readonly TestSseEventStreamStore _store;
+ private readonly StreamState _state;
+ private bool _disposed;
+
+ public InMemoryEventStreamWriter(TestSseEventStreamStore store, StreamState state)
+ {
+ _store = store;
+ _state = state;
+ }
+
+ public string StreamId => _state.StreamId;
+ public SseEventStreamMode Mode => _state.Mode;
+
+ public ValueTask SetModeAsync(SseEventStreamMode mode, CancellationToken cancellationToken = default)
+ {
+ _state.Mode = mode;
+ return default;
+ }
+
+ public ValueTask> WriteEventAsync(SseItem sseItem, CancellationToken cancellationToken = default)
+ {
+ // Skip if already has an event ID
+ if (sseItem.EventId is not null)
+ {
+ return new ValueTask>(sseItem);
+ }
+
+ var sequence = _state.NextSequence();
+ var eventId = _store.GenerateEventId();
+ var newItem = sseItem with { EventId = eventId };
+
+ _state.AddEvent(newItem, sequence);
+ _store.TrackEvent(eventId, _state, sequence);
+
+ return new ValueTask>(newItem);
+ }
+
+ public ValueTask DisposeAsync()
+ {
+ if (_disposed)
+ {
+ return default;
+ }
+
+ _disposed = true;
+ _state.Complete();
+ return default;
+ }
+ }
+
+ private sealed class InMemoryEventStreamReader : ISseEventStreamReader
+ {
+ private readonly StreamState _state;
+ private readonly long _startSequence;
+
+ public InMemoryEventStreamReader(StreamState state, long startSequence)
+ {
+ _state = state;
+ _startSequence = startSequence;
+ }
+
+ public string SessionId => _state.SessionId;
+ public string StreamId => _state.StreamId;
+
+ public async IAsyncEnumerable> ReadEventsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ // First, return any events that were already written after the start sequence
+ var existingEvents = _state.GetEventsAfter(_startSequence);
+ long lastSeenSequence = _startSequence;
+ foreach (var evt in existingEvents)
+ {
+ yield return evt;
+ }
+
+ // If in polling mode, stop after returning currently available events
+ if (_state.Mode == SseEventStreamMode.Polling)
+ {
+ yield break;
+ }
+
+ // If the stream is already completed, stop
+ if (_state.IsCompleted)
+ {
+ yield break;
+ }
+
+ // Wait for new events from the channel
+ await foreach (var (item, sequence) in _state.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false))
+ {
+ // Only yield events we haven't seen yet
+ if (sequence > lastSeenSequence)
+ {
+ lastSeenSequence = sequence;
+ yield return item;
+ }
+ }
+ }
+ }
+}
diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs
index 86cefcf10..9aeec6b1e 100644
--- a/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs
+++ b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs
@@ -107,7 +107,7 @@ public async Task CreateSamplingHandler_ShouldHandleTextMessages(float? temperat
{
Messages =
[
- new SamplingMessage
+ new SamplingMessage
{
Role = Role.User,
Content = [new TextContentBlock { Text = "Hello" }]
@@ -157,7 +157,7 @@ public async Task CreateSamplingHandler_ShouldHandleImageMessages()
{
Messages =
[
- new SamplingMessage
+ new SamplingMessage
{
Role = Role.User,
Content = [new ImageContentBlock
@@ -492,7 +492,7 @@ public async Task AsClientLoggerProvider_MessagesSentToClient()
public async Task ReturnsNegotiatedProtocolVersion(string? protocolVersion)
{
await using McpClient client = await CreateMcpClientForServer(new() { ProtocolVersion = protocolVersion });
- Assert.Equal(protocolVersion ?? "2025-06-18", client.NegotiatedProtocolVersion);
+ Assert.Equal(protocolVersion ?? "2025-11-25", client.NegotiatedProtocolVersion);
}
[Fact]
@@ -500,7 +500,7 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn
{
int getWeatherToolCallCount = 0;
int askClientToolCallCount = 0;
-
+
Server.ServerOptions.ToolCollection?.Add(McpServerTool.Create(
async (McpServer server, string query, CancellationToken cancellationToken) =>
{
@@ -513,14 +513,14 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn
return $"Weather in {location}: sunny, 22°C";
},
"get_weather", "Gets the weather for a location");
-
+
var response = await server
.AsSamplingChatClient()
.AsBuilder()
.UseFunctionInvocation()
.Build()
.GetResponseAsync(query, new ChatOptions { Tools = [weatherTool] }, cancellationToken);
-
+
return response.Text ?? "No response";
},
new() { Name = "ask_client", Description = "Asks the client a question using sampling" }));
@@ -530,7 +530,7 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn
{
int currentCall = samplingCallCount++;
var lastMessage = messages.LastOrDefault();
-
+
// First call: Return a tool call request for get_weather
if (currentCall == 0)
{
@@ -552,7 +552,7 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn
string resultText = toolResult.Result?.ToString() ?? string.Empty;
Assert.Contains("Weather in Paris: sunny", resultText);
-
+
return Task.FromResult(new([
new ChatMessage(ChatRole.User, messages.First().Contents),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call_weather_123", "get_weather", new Dictionary { ["location"] = "Paris" })]),
@@ -577,7 +577,7 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn
cancellationToken: TestContext.Current.CancellationToken);
Assert.NotNull(result);
Assert.Null(result.IsError);
-
+
var textContent = result.Content.OfType().FirstOrDefault();
Assert.NotNull(textContent);
Assert.Contains("Weather in Paris: sunny, 22", textContent.Text);
@@ -585,7 +585,7 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn
Assert.Equal(1, askClientToolCallCount);
Assert.Equal(2, samplingCallCount);
}
-
+
/// Simple test IChatClient implementation for testing.
private sealed class TestChatClient(Func, ChatOptions?, CancellationToken, Task> getResponse) : IChatClient
{
@@ -594,7 +594,7 @@ public Task GetResponseAsync(
ChatOptions? options = null,
CancellationToken cancellationToken = default) =>
getResponse(messages, options, cancellationToken);
-
+
async IAsyncEnumerable IChatClient.GetStreamingResponseAsync(
IEnumerable messages,
ChatOptions? options,
@@ -606,7 +606,7 @@ async IAsyncEnumerable IChatClient.GetStreamingResponseAsync
yield return update;
}
}
-
+
object? IChatClient.GetService(Type serviceType, object? serviceKey) => null;
void IDisposable.Dispose() { }
}