Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/browser/stores/WorkspaceStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,15 @@ export class WorkspaceStore {
data: WorkspaceChatMessage
) => void
> = {
"stream-pending": (workspaceId, aggregator, data) => {
aggregator.handleStreamPending(data as never);
if (this.onModelUsed) {
this.onModelUsed((data as { model: string }).model);
}
this.states.bump(workspaceId);
// Bump usage store so liveUsage can show the current model even before streaming starts
this.usageStore.bump(workspaceId);
},
"stream-start": (workspaceId, aggregator, data) => {
aggregator.handleStreamStart(data as never);
if (this.onModelUsed) {
Expand Down Expand Up @@ -480,7 +489,7 @@ export class WorkspaceStore {
name: metadata?.name ?? workspaceId, // Fall back to ID if metadata missing
messages: aggregator.getDisplayedMessages(),
queuedMessage: this.queuedMessages.get(workspaceId) ?? null,
canInterrupt: activeStreams.length > 0,
canInterrupt: activeStreams.length > 0 || aggregator.hasConnectingStreams(),
isCompacting: aggregator.isCompacting(),
awaitingUserQuestion: aggregator.hasAwaitingUserQuestion(),
loading: !hasMessages && !isCaughtUp,
Expand Down Expand Up @@ -965,7 +974,8 @@ export class WorkspaceStore {
// Check if there's an active stream in buffered events (reconnection scenario)
const pendingEvents = this.pendingStreamEvents.get(workspaceId) ?? [];
const hasActiveStream = pendingEvents.some(
(event) => "type" in event && event.type === "stream-start"
(event) =>
"type" in event && (event.type === "stream-start" || event.type === "stream-pending")
);

// Load historical messages first
Expand Down
78 changes: 51 additions & 27 deletions src/browser/utils/messages/StreamingMessageAggregator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import type {
} from "@/common/types/message";
import { createMuxMessage } from "@/common/types/message";
import type {
StreamPendingEvent,
StreamStartEvent,
StreamDeltaEvent,
UsageDeltaEvent,
Expand Down Expand Up @@ -75,6 +76,9 @@ function hasFailureResult(result: unknown): boolean {
}

export class StreamingMessageAggregator {
// Streams that have been registered/started in the backend but haven't emitted stream-start yet.
// This is the "connecting" phase: abort should work, but no deltas have started.
private connectingStreams = new Map<string, { startTime: number; model: string }>();
private messages = new Map<string, MuxMessage>();
private activeStreams = new Map<string, StreamingContext>();

Expand Down Expand Up @@ -283,6 +287,7 @@ export class StreamingMessageAggregator {
*/
private cleanupStreamState(messageId: string): void {
this.activeStreams.delete(messageId);
this.connectingStreams.delete(messageId);
// Clear todos when stream ends - they're stream-scoped state
// On reload, todos will be reconstructed from completed tool_write calls in history
this.currentTodos = [];
Expand Down Expand Up @@ -391,6 +396,9 @@ export class StreamingMessageAggregator {
this.pendingStreamStartTime = time;
}

/** True while at least one stream is registered in the backend but has not emitted stream-start yet. */
hasConnectingStreams(): boolean {
  return this.connectingStreams.size !== 0;
}
/** Snapshot of all currently streaming contexts (excludes streams still connecting). */
getActiveStreams(): StreamingContext[] {
  return [...this.activeStreams.values()];
}
Expand Down Expand Up @@ -418,6 +426,11 @@ export class StreamingMessageAggregator {
return context.model;
}

// If we're connecting (stream-pending), return that model
for (const context of this.connectingStreams.values()) {
return context.model;
}

// Otherwise, return the model from the most recent assistant message
const messages = this.getAllMessages();
for (let i = messages.length - 1; i >= 0; i--) {
Expand All @@ -437,6 +450,7 @@ export class StreamingMessageAggregator {
/** Reset all aggregator state, including streams still in the connecting phase. */
clear(): void {
  this.connectingStreams.clear();
  this.activeStreams.clear();
  this.messages.clear();
  this.invalidateCache();
}

Expand All @@ -459,9 +473,18 @@ export class StreamingMessageAggregator {
}

// Unified event handlers that encapsulate all complex logic
handleStreamPending(data: StreamPendingEvent): void {
  // The backend has accepted the request, so the optimistic pending timestamp is no longer needed.
  this.setPendingStreamStartTime(null);

  // Track the stream as "connecting": abortable, but no deltas have started yet.
  this.connectingStreams.set(data.messageId, {
    startTime: Date.now(),
    model: data.model,
  });
  this.invalidateCache();
}

handleStreamStart(data: StreamStartEvent): void {
// Clear pending stream start timestamp - stream has started
// Clear pending/connecting state - stream has started.
this.setPendingStreamStartTime(null);
this.connectingStreams.delete(data.messageId);

// NOTE: We do NOT clear agentStatus or currentTodos here.
// They are cleared when a new user message arrives (see handleMessage),
Expand Down Expand Up @@ -596,10 +619,10 @@ export class StreamingMessageAggregator {
}

handleStreamError(data: StreamErrorMessage): void {
// Direct lookup by messageId
const activeStream = this.activeStreams.get(data.messageId);
const isTrackedStream =
this.activeStreams.has(data.messageId) || this.connectingStreams.has(data.messageId);

if (activeStream) {
if (isTrackedStream) {
// Mark the message with error metadata
const message = this.messages.get(data.messageId);
if (message?.metadata) {
Expand All @@ -608,32 +631,33 @@ export class StreamingMessageAggregator {
message.metadata.errorType = data.errorType;
}

// Clean up stream-scoped state (active stream tracking, TODOs)
// Clean up stream-scoped state (active/connecting tracking, TODOs)
this.cleanupStreamState(data.messageId);
this.invalidateCache();
} else {
// Pre-stream error (e.g., API key not configured before streaming starts)
// Create a synthetic error message since there's no active stream to attach to
// Get the highest historySequence from existing messages so this appears at the end
const maxSequence = Math.max(
0,
...Array.from(this.messages.values()).map((m) => m.metadata?.historySequence ?? 0)
);
const errorMessage: MuxMessage = {
id: data.messageId,
role: "assistant",
parts: [],
metadata: {
partial: true,
error: data.error,
errorType: data.errorType,
timestamp: Date.now(),
historySequence: maxSequence + 1,
},
};
this.messages.set(data.messageId, errorMessage);
this.invalidateCache();
return;
}

// Pre-stream error (e.g., API key not configured before streaming starts)
// Create a synthetic error message since there's no tracked stream to attach to.
// Get the highest historySequence from existing messages so this appears at the end.
const maxSequence = Math.max(
0,
...Array.from(this.messages.values()).map((m) => m.metadata?.historySequence ?? 0)
);
const errorMessage: MuxMessage = {
id: data.messageId,
role: "assistant",
parts: [],
metadata: {
partial: true,
error: data.error,
errorType: data.errorType,
timestamp: Date.now(),
historySequence: maxSequence + 1,
},
};
this.messages.set(data.messageId, errorMessage);
this.invalidateCache();
}

handleToolCallStart(data: ToolCallStartEvent): void {
Expand Down
29 changes: 29 additions & 0 deletions src/browser/utils/messages/retryEligibility.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,35 @@ describe("hasInterruptedStream", () => {
expect(hasInterruptedStream(messages, null)).toBe(true);
});

it("returns false when pendingStreamStartTime is null but last user message timestamp is recent (replay/reload)", () => {
const justSentTimestamp = Date.now() - (PENDING_STREAM_START_GRACE_PERIOD_MS - 500);
const messages: DisplayedMessage[] = [
{
type: "user",
id: "user-1",
historyId: "user-1",
content: "Hello",
historySequence: 1,
timestamp: justSentTimestamp,
},
];
expect(hasInterruptedStream(messages, null)).toBe(false);
});

it("returns true when pendingStreamStartTime is null and last user message timestamp is old (replay/reload)", () => {
const longAgoTimestamp = Date.now() - (PENDING_STREAM_START_GRACE_PERIOD_MS + 1000);
const messages: DisplayedMessage[] = [
{
type: "user",
id: "user-1",
historyId: "user-1",
content: "Hello",
historySequence: 1,
timestamp: longAgoTimestamp,
},
];
expect(hasInterruptedStream(messages, null)).toBe(true);
});
it("returns false when user message just sent (within grace period)", () => {
const messages: DisplayedMessage[] = [
{
Expand Down
20 changes: 13 additions & 7 deletions src/browser/utils/messages/retryEligibility.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,22 @@ export function hasInterruptedStream(
): boolean {
if (messages.length === 0) return false;

// Don't show retry barrier if user message was sent very recently (within the grace period)
// This prevents flash during normal send flow while stream-start event arrives
// After the grace period, assume something is wrong and show the barrier
if (pendingStreamStartTime !== null) {
const elapsed = Date.now() - pendingStreamStartTime;
const lastMessage = messages[messages.length - 1];

// Don't show retry barrier if the last user message was sent very recently (within the grace period).
//
// We prefer the explicit pendingStreamStartTime (set during the live send flow).
// But during history replay / app reload, pendingStreamStartTime can be null even when the last
// message is a fresh user message. In that case, fall back to the user message timestamp.
const graceStartTime =
pendingStreamStartTime ??
(lastMessage.type === "user" ? (lastMessage.timestamp ?? null) : null);

if (graceStartTime !== null) {
const elapsed = Date.now() - graceStartTime;
if (elapsed < PENDING_STREAM_START_GRACE_PERIOD_MS) return false;
}

const lastMessage = messages[messages.length - 1];

return (
lastMessage.type === "stream-error" || // Stream errored out (show UI for ALL error types)
lastMessage.type === "user" || // No response received yet (app restart during slow model)
Expand Down
1 change: 1 addition & 0 deletions src/common/orpc/schemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ export {
StreamDeltaEventSchema,
StreamEndEventSchema,
StreamErrorMessageSchema,
StreamPendingEventSchema,
StreamStartEventSchema,
ToolCallDeltaEventSchema,
ToolCallEndEventSchema,
Expand Down
12 changes: 12 additions & 0 deletions src/common/orpc/schemas/stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ export const DeleteMessageSchema = z.object({
historySequences: z.array(z.number()),
});

// Emitted when a stream has been registered and is abortable, but before streaming begins.
// This prevents RetryBarrier flash during slow provider connection/setup.
// Consumers treat this as the "connecting" phase: no deltas have arrived yet, but the
// model is already known so UIs can display it before the first token.
export const StreamPendingEventSchema = z.object({
  type: z.literal("stream-pending"),
  workspaceId: z.string(),
  messageId: z.string(),
  // Model chosen for this turn; shown by liveUsage before streaming starts.
  model: z.string(),
  historySequence: z.number().meta({
    description: "Backend assigns global message ordering",
  }),
});
export const StreamStartEventSchema = z.object({
type: z.literal("stream-start"),
workspaceId: z.string(),
Expand Down Expand Up @@ -261,6 +272,7 @@ export const WorkspaceChatMessageSchema = z.discriminatedUnion("type", [
CaughtUpMessageSchema,
StreamErrorMessageSchema,
DeleteMessageSchema,
StreamPendingEventSchema,
StreamStartEventSchema,
StreamDeltaEventSchema,
StreamEndEventSchema,
Expand Down
5 changes: 5 additions & 0 deletions src/common/orpc/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import type { z } from "zod";
import type * as schemas from "./schemas";

import type {
StreamPendingEvent,
StreamStartEvent,
StreamDeltaEvent,
StreamEndEvent,
Expand Down Expand Up @@ -43,6 +44,10 @@ export function isStreamError(msg: WorkspaceChatMessage): msg is StreamErrorMess
return (msg as { type?: string }).type === "stream-error";
}

/** Type guard: narrows a workspace chat message to the stream-pending event. */
export function isStreamPending(msg: WorkspaceChatMessage): msg is StreamPendingEvent {
  const tagged = msg as { type?: string };
  return tagged.type === "stream-pending";
}

/** Type guard: narrows a workspace chat message to the delete message. */
export function isDeleteMessage(msg: WorkspaceChatMessage): msg is DeleteMessage {
  const tagged = msg as { type?: string };
  return tagged.type === "delete";
}
Expand Down
3 changes: 3 additions & 0 deletions src/common/types/stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import type {
StreamAbortEventSchema,
StreamDeltaEventSchema,
StreamEndEventSchema,
StreamPendingEventSchema,
StreamStartEventSchema,
ToolCallDeltaEventSchema,
ToolCallEndEventSchema,
Expand All @@ -22,6 +23,7 @@ import type {
* Completed message part (reasoning, text, or tool) suitable for serialization
* Used in StreamEndEvent and partial message storage
*/
export type StreamPendingEvent = z.infer<typeof StreamPendingEventSchema>;
export type CompletedMessagePart = MuxReasoningPart | MuxTextPart | MuxToolPart;

export type StreamStartEvent = z.infer<typeof StreamStartEventSchema>;
Expand All @@ -45,6 +47,7 @@ export type ReasoningEndEvent = z.infer<typeof ReasoningEndEventSchema>;
export type UsageDeltaEvent = z.infer<typeof UsageDeltaEventSchema>;

export type AIServiceEvent =
| StreamPendingEvent
| StreamStartEvent
| StreamDeltaEvent
| StreamEndEvent
Expand Down
42 changes: 42 additions & 0 deletions src/common/utils/streamLifecycle.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Stream lifecycle events are emitted during an in-flight assistant response.
//
// Keeping the event list centralized makes it harder to accidentally forget to forward/buffer a
// newly introduced lifecycle event.

// Canonical list of lifecycle events emitted while an assistant response is in flight.
// Centralizing it makes it harder to forget to forward/buffer a newly introduced event.
export const STREAM_LIFECYCLE_EVENTS = [
  "stream-pending",
  "stream-start",
  "stream-delta",
  "stream-abort",
  "stream-end",
] as const;

export type StreamLifecycleEventName = (typeof STREAM_LIFECYCLE_EVENTS)[number];

// Subset forwarded 1:1 from StreamManager -> AIService.
// `stream-abort` is excluded: AIService does extra bookkeeping for it.
export const STREAM_LIFECYCLE_EVENTS_DIRECT_FORWARD = [
  "stream-pending",
  "stream-start",
  "stream-delta",
  "stream-end",
] as const satisfies readonly StreamLifecycleEventName[];

// Subset forwarded 1:1 from AIService -> AgentSession -> renderer.
// `stream-end` is excluded: the session performs additional behavior on it.
export const STREAM_LIFECYCLE_EVENTS_SIMPLE_FORWARD = [
  "stream-pending",
  "stream-start",
  "stream-delta",
  "stream-abort",
] as const satisfies readonly StreamLifecycleEventName[];

/**
 * Wires a pass-through: for each event name, subscribes via `listen` and re-emits
 * the received payload unchanged via `emit`.
 */
export function forwardStreamLifecycleEvents(params: {
  events: readonly StreamLifecycleEventName[];
  listen: (event: StreamLifecycleEventName, handler: (payload: unknown) => void) => void;
  emit: (event: StreamLifecycleEventName, payload: unknown) => void;
}): void {
  const { events, listen, emit } = params;
  events.forEach((event) => {
    listen(event, (payload) => {
      emit(event, payload);
    });
  });
}
Loading