diff --git a/package-lock.json b/package-lock.json index aec4857235e..128d7c8adc2 100644 --- a/package-lock.json +++ b/package-lock.json @@ -50501,6 +50501,7 @@ "@mongodb-js/atlas-service": "^0.74.1", "@mongodb-js/compass-app-registry": "^9.5.0", "@mongodb-js/compass-app-stores": "^7.76.1", + "@mongodb-js/compass-assistant": "^1.21.1", "@mongodb-js/compass-collection": "^4.91.1", "@mongodb-js/compass-components": "^1.60.0", "@mongodb-js/compass-connections": "^1.90.1", @@ -63096,6 +63097,7 @@ "@mongodb-js/atlas-service": "^0.74.1", "@mongodb-js/compass-app-registry": "^9.5.0", "@mongodb-js/compass-app-stores": "^7.76.1", + "@mongodb-js/compass-assistant": "^1.21.1", "@mongodb-js/compass-collection": "^4.91.1", "@mongodb-js/compass-components": "^1.60.0", "@mongodb-js/compass-connections": "^1.90.1", diff --git a/packages/compass-aggregations/src/components/pipeline-builder-workspace/pipeline-as-text-workspace/pipeline-editor.tsx b/packages/compass-aggregations/src/components/pipeline-builder-workspace/pipeline-as-text-workspace/pipeline-editor.tsx index 00c7639812f..d2a5f1ce8ca 100644 --- a/packages/compass-aggregations/src/components/pipeline-builder-workspace/pipeline-as-text-workspace/pipeline-editor.tsx +++ b/packages/compass-aggregations/src/components/pipeline-builder-workspace/pipeline-as-text-workspace/pipeline-editor.tsx @@ -23,6 +23,7 @@ import type { PipelineParserError } from '../../../modules/pipeline-builder/pipe import { useAutocompleteFields } from '@mongodb-js/compass-field-store'; import { useTelemetry } from '@mongodb-js/compass-telemetry/provider'; import { useConnectionInfoRef } from '@mongodb-js/compass-connections/provider'; +import { useSyncAssistantGlobalState } from '@mongodb-js/compass-assistant'; const containerStyles = css({ position: 'relative', @@ -86,6 +87,8 @@ export const PipelineEditor: React.FunctionComponent = ({ const editorInitialValueRef = useRef(pipelineText); const editorCurrentValueRef = useCurrentValueRef(pipelineText); + useSyncAssistantGlobalState('currentAggregation', pipelineText); + const { utmSource, utmMedium } = useRequiredURLSearchParams(); const completer = useMemo(() => { diff --git a/packages/compass-aggregations/src/components/pipeline-builder-workspace/pipeline-builder-ui-workspace/index.tsx b/packages/compass-aggregations/src/components/pipeline-builder-workspace/pipeline-builder-ui-workspace/index.tsx index 3a0e9cc3961..2b9e597e2f3 100644 --- a/packages/compass-aggregations/src/components/pipeline-builder-workspace/pipeline-builder-ui-workspace/index.tsx +++ b/packages/compass-aggregations/src/components/pipeline-builder-workspace/pipeline-builder-ui-workspace/index.tsx @@ -17,6 +17,8 @@ import AddStage from '../../add-stage'; import UseCaseDroppableArea from '../../use-case-droppable-area'; import type { StageIdAndType } from '../../../modules/pipeline-builder/stage-editor'; import PipelineBuilderDndWrapper from './dnd-wrapper'; +import { prettify } from '../../../modules/pipeline-builder/pipeline-parser/utils'; +import { useSyncAssistantGlobalState } from '@mongodb-js/compass-assistant'; const pipelineWorkspaceContainerStyles = css({ position: 'relative', @@ -36,6 +38,7 @@ const pipelineWorkspaceStyles = css({ export type PipelineBuilderUIWorkspaceProps = { stagesIdAndType: StageIdAndType[]; + pipelineText: string; isSidePanelOpen: boolean; onStageMoveEnd: (from: number, to: number) => void; onStageAddAfterEnd: (after?: number) => void; @@ -50,11 +53,14 @@ export const PipelineBuilderUIWorkspace: React.FunctionComponent< PipelineBuilderUIWorkspaceProps > = ({ stagesIdAndType, + pipelineText, isSidePanelOpen, onStageMoveEnd, onStageAddAfterEnd, onUseCaseDropped, }) => { + useSyncAssistantGlobalState('currentAggregation', pipelineText); + return ( + stage.stageOperator !== null && stage.value !== null && !stage.disabled + ) + .map((stage) => stageToString(stage)) + .join(',\n')}\n]`; + return prettify(code); +} + const mapState = (state: RootState) => { return { stagesIdAndType: state.pipelineBuilder.stageEditor.stagesIdAndType, + pipelineText: getPipelineTextFromStages( + state.pipelineBuilder.stageEditor.stages + ), isSidePanelOpen: state.sidePanel.isPanelOpen, }; }; diff --git a/packages/compass-assistant/src/@ai-sdk/react/use-chat.ts b/packages/compass-assistant/src/@ai-sdk/react/use-chat.ts index 2e630cb46f3..aa9349f9b9b 100644 --- a/packages/compass-assistant/src/@ai-sdk/react/use-chat.ts +++ b/packages/compass-assistant/src/@ai-sdk/react/use-chat.ts @@ -41,6 +41,7 @@ export type UseChatHelpers = { | 'stop' | 'resumeStream' | 'addToolResult' + | 'addToolApprovalResponse' | 'status' | 'messages' | 'clearError' @@ -124,5 +125,6 @@ export function useChat({ resumeStream: chatRef.current.resumeStream, status, addToolResult: chatRef.current.addToolResult, + addToolApprovalResponse: chatRef.current.addToolApprovalResponse, }; } diff --git a/packages/compass-assistant/src/assistant-global-state.tsx b/packages/compass-assistant/src/assistant-global-state.tsx index e69f0e181a4..6c8623e2fcc 100644 --- a/packages/compass-assistant/src/assistant-global-state.tsx +++ b/packages/compass-assistant/src/assistant-global-state.tsx @@ -10,8 +10,8 @@ export type GlobalState = { activeConnections: ConnectionInfo[]; activeWorkspace: WorkspaceTab | null; activeCollectionMetadata: CollectionMetadata | null; - currentQuery: object | null; - currentAggregation: object | null; + currentQuery: string | null; + currentAggregation: string | null; activeCollectionSubTab: CollectionSubtab | null; }; diff --git a/packages/compass-assistant/src/compass-assistant-provider.tsx b/packages/compass-assistant/src/compass-assistant-provider.tsx index b8d75aec37b..3727c6def36 100644 --- a/packages/compass-assistant/src/compass-assistant-provider.tsx +++ b/packages/compass-assistant/src/compass-assistant-provider.tsx @@ -26,6 +26,8 @@ import { type ProactiveInsightsContext, } from './prompts'; import { + type PreferencesAccess, + preferencesLocator, useIsAIFeatureEnabled, usePreference, } from 'compass-preferences-model/provider'; @@ -39,14 +41,24 @@ import { type TrackFunction, useTelemetry, } from '@mongodb-js/compass-telemetry/provider'; -import type { AtlasAiService } from '@mongodb-js/compass-generative-ai/provider'; -import { atlasAiServiceLocator } from '@mongodb-js/compass-generative-ai/provider'; +import type { + AtlasAiService, + ToolsController, +} from '@mongodb-js/compass-generative-ai/provider'; +import { + atlasAiServiceLocator, + toolsControllerLocator, +} from '@mongodb-js/compass-generative-ai/provider'; import { buildConversationInstructionsPrompt } from './prompts'; import { createOpenAI } from '@ai-sdk/openai'; import { AssistantGlobalStateProvider, useAssistantGlobalState, } from './assistant-global-state'; +import { + lastAssistantMessageIsCompleteWithApprovalResponses, + type ToolSet, +} from 'ai'; export const ASSISTANT_DRAWER_ID = 'compass-assistant-drawer'; @@ -200,8 +212,10 @@ export const AssistantProvider: React.FunctionComponent< appNameForPrompt: string; chat: Chat; atlasAiService: AtlasAiService; + toolsController: ToolsController; + preferences: PreferencesAccess; }> -> = ({ chat, atlasAiService, children }) => { +> = ({ chat, atlasAiService, toolsController, preferences, children }) => { const { openDrawer } = useDrawerActions(); const track = useTelemetry(); @@ -272,6 +286,22 @@ export const AssistantProvider: React.FunctionComponent< chat.messages = [...chat.messages, contextPrompt]; } + const { enableToolCalling } = preferences.getPreferences(); + + if (enableToolCalling) { + toolsController.setActiveTools(new Set(['compass-ui'])); + toolsController.setContext({ + query: assistantGlobalStateRef.current.currentQuery || undefined, + aggregation: + assistantGlobalStateRef.current.currentAggregation || undefined, + }); + } else { + toolsController.setActiveTools(new Set([])); + toolsController.setContext({ + query: undefined, + aggregation: undefined, + }); + } await chat.sendMessage(message, options); }; }); @@ -340,12 +370,16 @@ export const CompassAssistantProvider = registerCompassPlugin( appNameForPrompt, chat, atlasAiService, + toolsController, + preferences, children, }: PropsWithChildren<{ appNameForPrompt: string; originForPrompt: string; chat?: Chat; atlasAiService?: AtlasAiService; + toolsController?: ToolsController; + preferences?: PreferencesAccess; }>) => { if (!chat) { throw new Error('Chat was not provided by the state'); @@ -353,12 +387,20 @@ export const CompassAssistantProvider = registerCompassPlugin( if (!atlasAiService) { throw new Error('atlasAiService was not provided by the state'); } + if (!toolsController) { + throw new Error('toolsController was not provided by the state'); + } + if (!preferences) { + throw new Error('preferences was not provided by the state'); + } return ( {children} @@ -367,7 +409,14 @@ export const CompassAssistantProvider = registerCompassPlugin( }, activate: ( { chat: initialChat, originForPrompt, appNameForPrompt }, - { atlasService, atlasAiService, logger, track } + { + atlasService, + atlasAiService, + toolsController, + preferences, + logger, + track, + } ) => { const chat = initialChat ?? @@ -377,10 +426,13 @@ export const CompassAssistantProvider = registerCompassPlugin( atlasService, logger, track, + getTools: () => toolsController.getActiveTools(), }); return { - store: { state: { chat, atlasAiService } }, + store: { + state: { chat, atlasAiService, toolsController, preferences }, + }, deactivate: () => {}, }; }, @@ -389,8 +441,10 @@ export const CompassAssistantProvider = registerCompassPlugin( atlasService: atlasServiceLocator, atlasAiService: atlasAiServiceLocator, atlasAuthService: atlasAuthServiceLocator, + toolsController: toolsControllerLocator, track: telemetryLocator, logger: createLoggerLocator('COMPASS-ASSISTANT'), + preferences: preferencesLocator, } ); @@ -401,6 +455,7 @@ export function createDefaultChat({ logger, track, options, + getTools, }: { originForPrompt: string; appNameForPrompt: string; @@ -410,13 +465,15 @@ export function createDefaultChat({ options?: { transport: Chat['transport']; }; + getTools?: () => ToolSet; }): Chat { const initialBaseUrl = 'http://PLACEHOLDER_BASE_URL_TO_BE_REPLACED.invalid'; - return new Chat({ + return new Chat({ transport: options?.transport ?? new DocsProviderTransport({ origin: originForPrompt, + getTools, instructions: buildConversationInstructionsPrompt({ target: appNameForPrompt, }), @@ -453,6 +510,7 @@ export function createDefaultChat({ }, }).responses('mongodb-chat-latest'), }), + sendAutomaticallyWhen: lastAssistantMessageIsCompleteWithApprovalResponses, onError: (err: Error) => { logger.log.error( logger.mongoLogId(1_001_000_370), diff --git a/packages/compass-assistant/src/components/assistant-chat.tsx b/packages/compass-assistant/src/components/assistant-chat.tsx index 59a9dfed02b..6f703fa1512 100644 --- a/packages/compass-assistant/src/components/assistant-chat.tsx +++ b/packages/compass-assistant/src/components/assistant-chat.tsx @@ -18,6 +18,8 @@ import { Icon, } from '@mongodb-js/compass-components'; import { ConfirmationMessage } from './confirmation-message'; +import { ToolCallMessage } from './tool-call-message'; +import type { ToolCallPart } from './tool-call-message'; import { useTelemetry } from '@mongodb-js/compass-telemetry/provider'; import { NON_GENUINE_WARNING_MESSAGE } from '../preset-messages'; import { SuggestedPrompts } from './suggested-prompts'; @@ -203,7 +205,14 @@ export const AssistantChat: React.FunctionComponent = ({ chat.messages[chat.messages.length - 1] ?? {}; const { ensureOptInAndSend } = useContext(AssistantActionsContext); - const { messages, status, error, clearError, setMessages } = useChat({ + const { + messages, + status, + error, + clearError, + setMessages, + addToolApprovalResponse, + } = useChat({ chat, }); @@ -350,6 +359,22 @@ export const AssistantChat: React.FunctionComponent = ({ [ensureOptInAndSend, setMessages, track] ); + const handleToolApproval = useCallback( + (approvalId: string, approved: boolean) => { + void addToolApprovalResponse({ + id: approvalId, + approved, + }); + + // TODO: Add telemetry event for tool approvals when it's added to the telemetry schema + // track('Assistant Tool Call Approval', { + // approved, + // approval_id: approvalId, + // }); + }, + [addToolApprovalResponse] + ); + const visibleMessages = messages.filter( (message) => !message.metadata?.isSystemContext ); @@ -375,6 +400,8 @@ export const AssistantChat: React.FunctionComponent = ({ const { id, role, metadata, parts } = message; const seenTitles = new Set(); const sources = []; + const toolCalls: ToolCallPart[] = []; + for (const part of parts) { // Related sources are type source-url. We want to only // include url_citation (has url and title), not file_citation @@ -389,7 +416,14 @@ export const AssistantChat: React.FunctionComponent = ({ }); } } + + // Detect tool call parts (they have a "tool-" prefix or a toolCallId) + if (part.type.startsWith('tool-') || 'toolCallId' in part) { + toolCalls.push(part as ToolCallPart); + } } + + // Handle confirmation messages if (metadata?.confirmation) { const { description, state } = metadata.confirmation; const isLastMessage = index === visibleMessages.length - 1; @@ -420,32 +454,56 @@ export const AssistantChat: React.FunctionComponent = ({ const isSender = role === 'user'; + // Render tool calls and text content together return ( - - {isSender === false && ( - - handleFeedback({ message, state }) - } - onSubmitFeedback={(event, state) => - handleFeedback({ message, state }) - } - className={noWrapFixesStyles} - /> - )} - {sources.length > 0 && ( - + + {/* Show tool calls if present */} + {toolCalls.map((toolCall) => { + const toolCallId = + toolCall.toolCallId || `${id}-${toolCall.type}`; + + return ( + + handleToolApproval(approvalId, true) + } + onDeny={(approvalId) => + handleToolApproval(approvalId, false) + } + /> + ); + })} + {/* Show text message if there's text content */} + {displayText && ( + + {isSender === false && ( + + handleFeedback({ message, state }) + } + onSubmitFeedback={(event, state) => + handleFeedback({ message, state }) + } + className={noWrapFixesStyles} + /> + )} + {sources.length > 0 && ( + + )} + )} - + ); })} @@ -479,7 +537,7 @@ export const AssistantChat: React.FunctionComponent = ({ void handleMessageSend({ text })} - state={status === 'submitted' ? 'loading' : undefined} + state={status !== 'ready' ? 'loading' : undefined} textareaProps={inputBarTextareaProps} /> diff --git a/packages/compass-assistant/src/components/tool-call-message.tsx b/packages/compass-assistant/src/components/tool-call-message.tsx new file mode 100644 index 00000000000..872da9c256b --- /dev/null +++ b/packages/compass-assistant/src/components/tool-call-message.tsx @@ -0,0 +1,329 @@ +import React, { useState } from 'react'; +import { + Icon, + Body, + Button, + ButtonVariant, + spacing, + css, + cx, + palette, + useDarkMode, +} from '@mongodb-js/compass-components'; + +const toolCallCardStyles = css({ + padding: spacing[200], + borderRadius: spacing[200], + backgroundColor: palette.gray.light3, + border: `1px solid ${palette.gray.light2}`, +}); + +const toolCallCardDarkModeStyles = css({ + backgroundColor: palette.gray.dark3, + borderColor: palette.gray.dark2, +}); + +const toolHeaderStyles = css({ + display: 'flex', + alignItems: 'center', + gap: spacing[200], + marginBottom: spacing[100], +}); + +const toolIconContainerStyles = css({ + display: 'flex', + alignItems: 'center', + justifyContent: 'center', + width: spacing[400], + height: spacing[400], + borderRadius: '4px', + backgroundColor: palette.white, +}); + +const toolIconContainerDarkModeStyles = css({ + backgroundColor: palette.gray.dark2, +}); + +const toolNameStyles = css({ + fontFamily: 'Source Code Pro, monospace', + fontWeight: 600, + fontSize: '12px', + lineHeight: '16px', +}); + +const inputLabelStyles = css({ + fontWeight: 600, + fontSize: '12px', +}); + +const codeBlockStyles = css({ + maxHeight: '200px', + overflowY: 'auto', + fontSize: '11px', + backgroundColor: palette.gray.light2, + padding: spacing[200], + borderRadius: '4px', + fontFamily: 'Source Code Pro, monospace', + whiteSpace: 'pre-wrap', + wordBreak: 'break-all', +}); + +const codeBlockDarkModeStyles = css({ + backgroundColor: palette.gray.dark2, +}); + +const expandableHeaderStyles = css({ + display: 'flex', + alignItems: 'center', + justifyContent: 'space-between', + cursor: 'pointer', + padding: spacing[200], + marginTop: spacing[100], + borderRadius: '4px', + backgroundColor: palette.gray.light2, + transition: 'background-color color 0.2s ease-in-out', + '&:hover': { + backgroundColor: palette.gray.light1, + color: palette.black, + }, +}); + +const expandableHeaderDarkModeStyles = css({ + backgroundColor: palette.gray.dark2, + '&:hover': { + backgroundColor: palette.gray.dark1, + }, +}); + +const expandableContentStyles = css({ + marginTop: spacing[100], + padding: spacing[100], + borderRadius: '4px', + backgroundColor: palette.gray.light2, +}); + +const expandableContentDarkModeStyles = css({ + backgroundColor: palette.gray.dark2, +}); + +const buttonGroupStyles = css({ + display: 'flex', + gap: spacing[200], + marginTop: spacing[300], + '> button': { + flex: 1, + }, +}); + +const statusStyles = css({ + display: 'flex', + alignItems: 'center', + gap: spacing[100], + marginTop: spacing[200], +}); + +const statusTextStyles = css({ + color: palette.gray.dark1, + fontWeight: 500, + fontSize: '12px', +}); + +export interface ToolCallPart { + type: string; + toolCallId?: string; + state?: + | 'approval-requested' + | 'approval-responded' + | 'input-available' + | 'output-available' + | 'output-error' + | 'output-denied'; + input?: Record; + output?: { + content?: Array<{ + type: string; + text?: string; + }>; + }; + approval?: { + id: string; + approved?: boolean; + reason?: string; + }; +} + +interface ToolCallMessageProps { + toolCall: ToolCallPart; + onApprove?: (approvalId: string) => void; + onDeny?: (approvalId: string) => void; +} + +// Extract tool name from type (e.g., "tool-list-databases" -> "list-databases") +function getToolDisplayName(type: string): string { + return type.replace(/^tool-/, ''); +} + +export const ToolCallMessage: React.FunctionComponent = ({ + toolCall, + onApprove, + onDeny, +}) => { + const darkMode = useDarkMode(); + const [isInputExpanded, setIsInputExpanded] = useState(false); + const [isOutputExpanded, setIsOutputExpanded] = useState(false); + + const toolName = getToolDisplayName(toolCall.type); + const inputJSON = JSON.stringify(toolCall.input || {}, null, 2); + + // Extract output text if available + const outputText = toolCall.output + ? JSON.stringify(toolCall.output, null, 2) + : 'No output available'; + const hasOutput = toolCall.state === 'output-available'; + + const isAwaitingApproval = toolCall.state === 'approval-requested'; + const isApprovalResponded = toolCall.state === 'approval-responded'; + const isDenied = toolCall.state === 'output-denied'; + const wasApproved = toolCall.approval?.approved === true; + + return ( +
+
+
+ +
+
+ {toolName} +
+
+ + {/* Input Section */} +
+
setIsInputExpanded(!isInputExpanded)} + onKeyDown={(e) => { + if (e.key === 'Enter' || e.key === ' ') { + e.preventDefault(); + setIsInputExpanded(!isInputExpanded); + } + }} + > + Input + +
+ {isInputExpanded && ( +
+
+ {inputJSON} +
+
+ )} +
+ + {/* Output Section - only show if there's output */} + {hasOutput && ( +
+
setIsOutputExpanded(!isOutputExpanded)} + onKeyDown={(e) => { + if (e.key === 'Enter' || e.key === ' ') { + e.preventDefault(); + setIsOutputExpanded(!isOutputExpanded); + } + }} + > + Output + +
+ {isOutputExpanded && ( +
+
+ {outputText} +
+
+ )} +
+ )} + + {/* Approval buttons - show when approval is requested */} + {isAwaitingApproval && toolCall.approval && ( +
+ + +
+ )} + + {/* Status indicator for responded/denied states */} + {(isApprovalResponded || isDenied) && toolCall.approval && ( +
+ + + {wasApproved ? 'Tool call approved' : 'Tool call denied'} + {toolCall.approval.reason && ` - ${toolCall.approval.reason}`} + +
+ )} +
+ ); +}; diff --git a/packages/compass-assistant/src/docs-provider-transport.ts b/packages/compass-assistant/src/docs-provider-transport.ts index 7bb20232742..a5913135af2 100644 --- a/packages/compass-assistant/src/docs-provider-transport.ts +++ b/packages/compass-assistant/src/docs-provider-transport.ts @@ -1,6 +1,7 @@ import { type ChatTransport, type LanguageModel, + type ToolSet, type UIMessageChunk, convertToModelMessages, streamText, @@ -18,20 +19,24 @@ export function shouldExcludeMessage({ metadata }: AssistantMessage) { export class DocsProviderTransport implements ChatTransport { private model: LanguageModel; private origin: string; + private getTools: () => ToolSet; private instructions: string; constructor({ - instructions, model, origin, + getTools, + instructions, }: { - instructions: string; model: LanguageModel; origin: string; + getTools?: () => ToolSet; + instructions: string; }) { - this.instructions = instructions; this.model = model; this.origin = origin; + this.getTools = getTools ?? (() => ({})); + this.instructions = instructions; } static emptyStream = new ReadableStream({ @@ -71,6 +76,7 @@ export class DocsProviderTransport implements ChatTransport { headers: { 'X-Request-Origin': this.origin, }, + tools: this.getTools(), providerOptions: { openai: { store: false, diff --git a/packages/compass-assistant/src/prompts.ts b/packages/compass-assistant/src/prompts.ts index a9995f296c6..131fa6e2763 100644 --- a/packages/compass-assistant/src/prompts.ts +++ b/packages/compass-assistant/src/prompts.ts @@ -36,20 +36,21 @@ You should: - Encourage the user to understand what they are doing before they act, e.g. by reading the official documentation or other related resources. - Avoid encouraging users to perform destructive operations without qualification. Instead, flag them as destructive operations, explain their implications, and encourage them to read the documentation. 4. Always call the 'search_content' tool. +5. Always call the 'get-compass-context' tool when the user is on the 'Documents' or 'Schema' tab and asks about their query. +5. Always call the 'get-compass-context' tool when the user is on the 'Aggregations' tab and asks about their aggregation or pipeline. You are able to: 1. Answer technical questions +2. Use the 'get-compass-context' tool to get the current query from the query bar (if applicable) or the aggregation pipeline from the aggregation builder (if applicable). You CANNOT: -1. Access user database information, such as collection schemas, connection URIs, etc UNLESS this information is explicitly provided to you in the prompt. -2. Query MongoDB directly or execute code. -3. Access the current state of the UI +1. Query MongoDB directly or execute code. `; }; diff --git a/packages/compass-generative-ai/src/provider.tsx b/packages/compass-generative-ai/src/provider.tsx index 9a6d6dbfb56..2c9e07c7f0e 100644 --- a/packages/compass-generative-ai/src/provider.tsx +++ b/packages/compass-generative-ai/src/provider.tsx @@ -1,5 +1,6 @@ import React, { createContext, useContext, useMemo } from 'react'; import { AtlasAiService } from './atlas-ai-service'; +import { ToolsController } from './tools-controller'; import { preferencesLocator } from 'compass-preferences-model/provider'; import { useLogger } from '@mongodb-js/compass-logging/provider'; import { atlasServiceLocator } from '@mongodb-js/atlas-service/provider'; @@ -49,3 +50,40 @@ export const atlasAiServiceLocator = createServiceLocator( 'atlasAiServiceLocator' ); export { AtlasAiService } from './atlas-ai-service'; + +const ToolsControllerContext = createContext(null); + +export const ToolsControllerProvider: React.FC = createServiceProvider( + function ToolsControllerProvider({ children }) { + const logger = useLogger('TOOLS-CONTROLLER'); + + const toolsController = useMemo(() => { + return new ToolsController({ + logger, + }); + }, [logger]); + + return ( + + {children} + + ); + } +); + +function useToolsControllerContext(): ToolsController { + const service = useContext(ToolsControllerContext); + if (!service) { + throw new Error('No ToolsController available in this context'); + } + return service; +} + +export const toolsControllerLocator = createServiceLocator( + useToolsControllerContext, + 'toolsControllerLocator' +); +export { ToolsController } from './tools-controller'; + +// Export the hook for direct use in components +export const useToolsController = useToolsControllerContext; diff --git a/packages/compass-generative-ai/src/tools-controller.ts b/packages/compass-generative-ai/src/tools-controller.ts new file mode 100644 index 00000000000..3947a19cc76 --- /dev/null +++ b/packages/compass-generative-ai/src/tools-controller.ts @@ -0,0 +1,58 @@ +import type { ToolSet } from 'ai'; +import type { Logger } from '@mongodb-js/compass-logging'; +import z from 'zod'; + +// TODO: add readonly-db +type ToolGroup = 'compass-ui'; + +type CompassContext = { + query?: string; + aggregation?: string; +}; + +// TODO: add connection info +type ToolsContext = CompassContext; + +export class ToolsController { + private logger: Logger; + private toolGroups: Set = new Set(); + private context: ToolsContext = Object.create(null); + + constructor({ logger }: { logger: Logger }) { + this.logger = logger; + } + + setActiveTools(toolGroups: Set): void { + this.toolGroups = toolGroups; + } + + getActiveTools(): ToolSet { + const tools = Object.create(null); + + if (this.toolGroups.has('compass-ui')) { + tools['get-compass-context'] = { + description: 'Get the current Compass query or aggregation.', + inputSchema: z.object({}), + needsApproval: true, + strict: false, + execute: (): Promise => { + this.logger.log.info( + this.logger.mongoLogId(1_001_000_386), + 'ToolsController', + 'Executing get-compass-context tool' + ); + return Promise.resolve(this.context); + }, + // TODO: toModelOutput function to format this? + }; + } + + return tools; + } + + setContext(context: ToolsContext): void { + // TODO: we'll also disconnect if the active connection is not the intended + // one and start connecting if necessary + this.context = context; + } +} diff --git a/packages/compass-query-bar/package.json b/packages/compass-query-bar/package.json index 09388e1bc1c..66e53921e93 100644 --- a/packages/compass-query-bar/package.json +++ b/packages/compass-query-bar/package.json @@ -66,6 +66,7 @@ }, "dependencies": { "@mongodb-js/atlas-service": "^0.74.1", + "@mongodb-js/compass-assistant": "^1.21.1", "@mongodb-js/compass-app-registry": "^9.5.0", "@mongodb-js/compass-app-stores": "^7.76.1", "@mongodb-js/compass-collection": "^4.91.1", diff --git a/packages/compass-query-bar/src/components/query-bar.tsx b/packages/compass-query-bar/src/components/query-bar.tsx index b6c5e201447..69e6854d7b7 100644 --- a/packages/compass-query-bar/src/components/query-bar.tsx +++ b/packages/compass-query-bar/src/components/query-bar.tsx @@ -45,6 +45,9 @@ import { useFavoriteQueryStorageAccess, useRecentQueryStorageAccess, } from '@mongodb-js/my-queries-storage/provider'; +import { useQueryBarQuery } from './hooks'; +import { useSyncAssistantGlobalState } from '@mongodb-js/compass-assistant'; +import { toJSString } from 'mongodb-query-parser'; const queryBarFormStyles = css({ display: 'flex', @@ -203,6 +206,9 @@ export const QueryBar: React.FunctionComponent = ({ recentQueryStorageAvailable && isMyQueriesEnabled; + const query = useQueryBarQuery(); + useSyncAssistantGlobalState('currentQuery', toJSString(query) || null); + return (
- - { - return Promise.resolve([{}, null] as [ - Record, - null - ]); - }} - onAutoconnectInfoRequest={(connectionStore) => { - if (autoconnectId) { - return connectionStore.loadAll().then( - (connections) => { - return connections.find( - (connectionInfo) => - connectionInfo.id === autoconnectId - ); - }, - (err) => { - const { log, mongoLogId } = logger; - log.warn( - mongoLogId(1_001_000_329), - 'Compass Web', - 'Could not load connections when trying to autoconnect', - { err: err.message } - ); - return undefined; - } - ); - } - return Promise.resolve(undefined); - }} + + - - - - { + return Promise.resolve([{}, null] as [ + Record, + null + ]); + }} + onAutoconnectInfoRequest={(connectionStore) => { + if (autoconnectId) { + return connectionStore.loadAll().then( + (connections) => { + return connections.find( + (connectionInfo) => + connectionInfo.id === autoconnectId + ); + }, + (err) => { + const { log, mongoLogId } = logger; + log.warn( + mongoLogId(1_001_000_329), + 'Compass Web', + 'Could not load connections when trying to autoconnect', + { err: err.message } + ); + return undefined; } - > - - - - - - + ); + } + return Promise.resolve(undefined); + }} + > + + + + + + + + + + + diff --git a/packages/compass/src/app/components/home.tsx b/packages/compass/src/app/components/home.tsx index c7dc0dba28c..547f25997f5 100644 --- a/packages/compass/src/app/components/home.tsx +++ b/packages/compass/src/app/components/home.tsx @@ -28,6 +28,7 @@ import { CompassInstanceStorePlugin } from '@mongodb-js/compass-app-stores'; import FieldStorePlugin from '@mongodb-js/compass-field-store'; import { AtlasAuthPlugin } from '@mongodb-js/atlas-service/renderer'; import { CompassGenerativeAIPlugin } from '@mongodb-js/compass-generative-ai'; +import { ToolsControllerProvider } from '@mongodb-js/compass-generative-ai/provider'; import { ConnectionStorageProvider } from '@mongodb-js/connection-storage/provider'; import { ConnectionImportExportProvider } from '@mongodb-js/compass-connection-import-export'; import { useTelemetry } from '@mongodb-js/compass-telemetry/provider'; @@ -114,26 +115,28 @@ function HomeWithConnections({ return ( - - { - openToast('failed-to-load-connections', { - title: 'Failed to load connections', - description: error.message, - variant: 'warning', - }); - }} + + - - - + { + openToast('failed-to-load-connections', { + title: 'Failed to load connections', + description: error.message, + variant: 'warning', + }); + }} + > + + + + );