-
Notifications
You must be signed in to change notification settings - Fork 246
feat(compass-assistant): add basic tool calling & tool cards COMPASS-10144 #7668
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6778c80
f3adac8
4faf3a0
cd0ee5a
68070fa
da44bec
d7f812f
38ff9e1
c900693
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,8 @@ import AddStage from '../../add-stage'; | |
| import UseCaseDroppableArea from '../../use-case-droppable-area'; | ||
| import type { StageIdAndType } from '../../../modules/pipeline-builder/stage-editor'; | ||
| import PipelineBuilderDndWrapper from './dnd-wrapper'; | ||
| import { prettify } from '../../../modules/pipeline-builder/pipeline-parser/utils'; | ||
| import { useSyncAssistantGlobalState } from '@mongodb-js/compass-assistant'; | ||
|
|
||
| const pipelineWorkspaceContainerStyles = css({ | ||
| position: 'relative', | ||
|
|
@@ -36,6 +38,7 @@ const pipelineWorkspaceStyles = css({ | |
|
|
||
| export type PipelineBuilderUIWorkspaceProps = { | ||
| stagesIdAndType: StageIdAndType[]; | ||
| pipelineText: string; | ||
| isSidePanelOpen: boolean; | ||
| onStageMoveEnd: (from: number, to: number) => void; | ||
| onStageAddAfterEnd: (after?: number) => void; | ||
|
|
@@ -50,11 +53,14 @@ export const PipelineBuilderUIWorkspace: React.FunctionComponent< | |
| PipelineBuilderUIWorkspaceProps | ||
| > = ({ | ||
| stagesIdAndType, | ||
| pipelineText, | ||
| isSidePanelOpen, | ||
| onStageMoveEnd, | ||
| onStageAddAfterEnd, | ||
| onUseCaseDropped, | ||
| }) => { | ||
| useSyncAssistantGlobalState('currentAggregation', pipelineText); | ||
|
|
||
| return ( | ||
| <PipelineBuilderDndWrapper | ||
| stagesIdAndType={stagesIdAndType} | ||
|
|
@@ -99,9 +105,34 @@ export const PipelineBuilderUIWorkspace: React.FunctionComponent< | |
| ); | ||
| }; | ||
|
|
||
| type Stage = { | ||
| disabled?: boolean; | ||
| syntaxError: Error | null; | ||
| stageOperator: string | null; | ||
| value: string | null; | ||
| }; | ||
|
|
||
| function stageToString(stage: Stage): string { | ||
| return `{ ${stage.stageOperator}: ${stage.value} }`; | ||
| } | ||
|
|
||
| function getPipelineTextFromStages(stages: Stage[]): string { | ||
| const code = `[${stages | ||
| .filter( | ||
| (stage) => | ||
| stage.stageOperator !== null && stage.value !== null && !stage.disabled | ||
| ) | ||
| .map((stage) => stageToString(stage)) | ||
| .join(',\n')}\n]`; | ||
| return prettify(code); | ||
| } | ||
|
|
||
| const mapState = (state: RootState) => { | ||
| return { | ||
| stagesIdAndType: state.pipelineBuilder.stageEditor.stagesIdAndType, | ||
| pipelineText: getPipelineTextFromStages( | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this case I can't think of a better way to calculate the pipeline text. The usual machinery is only accessible from redux land because it uses And then the builder has all these invalid / partial states. I'm attempting to just skip those stages. Maybe we want just the most recent valid pipeline, maybe we want the partial states and every syntax error so the model can do something with it? Maybe I'm overthinking it and this is fine for now. |
||
| state.pipelineBuilder.stageEditor.stages | ||
| ), | ||
| isSidePanelOpen: state.sidePanel.isPanelOpen, | ||
| }; | ||
| }; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,6 +26,8 @@ import { | |
| type ProactiveInsightsContext, | ||
| } from './prompts'; | ||
| import { | ||
| type PreferencesAccess, | ||
| preferencesLocator, | ||
| useIsAIFeatureEnabled, | ||
| usePreference, | ||
| } from 'compass-preferences-model/provider'; | ||
|
|
@@ -39,14 +41,24 @@ import { | |
| type TrackFunction, | ||
| useTelemetry, | ||
| } from '@mongodb-js/compass-telemetry/provider'; | ||
| import type { AtlasAiService } from '@mongodb-js/compass-generative-ai/provider'; | ||
| import { atlasAiServiceLocator } from '@mongodb-js/compass-generative-ai/provider'; | ||
| import type { | ||
| AtlasAiService, | ||
| ToolsController, | ||
| } from '@mongodb-js/compass-generative-ai/provider'; | ||
| import { | ||
| atlasAiServiceLocator, | ||
| toolsControllerLocator, | ||
| } from '@mongodb-js/compass-generative-ai/provider'; | ||
| import { buildConversationInstructionsPrompt } from './prompts'; | ||
| import { createOpenAI } from '@ai-sdk/openai'; | ||
| import { | ||
| AssistantGlobalStateProvider, | ||
| useAssistantGlobalState, | ||
| } from './assistant-global-state'; | ||
| import { | ||
| lastAssistantMessageIsCompleteWithApprovalResponses, | ||
| type ToolSet, | ||
| } from 'ai'; | ||
|
|
||
| export const ASSISTANT_DRAWER_ID = 'compass-assistant-drawer'; | ||
|
|
||
|
|
@@ -200,8 +212,10 @@ export const AssistantProvider: React.FunctionComponent< | |
| appNameForPrompt: string; | ||
| chat: Chat<AssistantMessage>; | ||
| atlasAiService: AtlasAiService; | ||
| toolsController: ToolsController; | ||
| preferences: PreferencesAccess; | ||
| }> | ||
| > = ({ chat, atlasAiService, children }) => { | ||
| > = ({ chat, atlasAiService, toolsController, preferences, children }) => { | ||
| const { openDrawer } = useDrawerActions(); | ||
| const track = useTelemetry(); | ||
|
|
||
|
|
@@ -272,6 +286,22 @@ export const AssistantProvider: React.FunctionComponent< | |
| chat.messages = [...chat.messages, contextPrompt]; | ||
| } | ||
|
|
||
| const { enableToolCalling } = preferences.getPreferences(); | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is something I'm probably gonna have to check with Sergey when he's back, but the preferences hook isn't working well here - it caches the initial value so it doesn't respond if you toggle the flag from the preferences modal. This way at least works for now. |
||
|
|
||
| if (enableToolCalling) { | ||
| toolsController.setActiveTools(new Set(['compass-ui'])); | ||
| toolsController.setContext({ | ||
| query: assistantGlobalStateRef.current.currentQuery || undefined, | ||
| aggregation: | ||
| assistantGlobalStateRef.current.currentAggregation || undefined, | ||
| }); | ||
| } else { | ||
| toolsController.setActiveTools(new Set([])); | ||
| toolsController.setContext({ | ||
| query: undefined, | ||
| aggregation: undefined, | ||
| }); | ||
| } | ||
| await chat.sendMessage(message, options); | ||
| }; | ||
| }); | ||
|
|
@@ -340,25 +370,37 @@ export const CompassAssistantProvider = registerCompassPlugin( | |
| appNameForPrompt, | ||
| chat, | ||
| atlasAiService, | ||
| toolsController, | ||
| preferences, | ||
| children, | ||
| }: PropsWithChildren<{ | ||
| appNameForPrompt: string; | ||
| originForPrompt: string; | ||
| chat?: Chat<AssistantMessage>; | ||
| atlasAiService?: AtlasAiService; | ||
| toolsController?: ToolsController; | ||
| preferences?: PreferencesAccess; | ||
| }>) => { | ||
| if (!chat) { | ||
| throw new Error('Chat was not provided by the state'); | ||
| } | ||
| if (!atlasAiService) { | ||
| throw new Error('atlasAiService was not provided by the state'); | ||
| } | ||
| if (!toolsController) { | ||
| throw new Error('toolsController was not provided by the state'); | ||
| } | ||
| if (!preferences) { | ||
| throw new Error('preferences was not provided by the state'); | ||
| } | ||
| return ( | ||
| <AssistantGlobalStateProvider> | ||
| <AssistantProvider | ||
| appNameForPrompt={appNameForPrompt} | ||
| chat={chat} | ||
| atlasAiService={atlasAiService} | ||
| toolsController={toolsController} | ||
| preferences={preferences} | ||
| > | ||
| {children} | ||
| </AssistantProvider> | ||
|
|
@@ -367,7 +409,14 @@ export const CompassAssistantProvider = registerCompassPlugin( | |
| }, | ||
| activate: ( | ||
| { chat: initialChat, originForPrompt, appNameForPrompt }, | ||
| { atlasService, atlasAiService, logger, track } | ||
| { | ||
| atlasService, | ||
| atlasAiService, | ||
| toolsController, | ||
| preferences, | ||
| logger, | ||
| track, | ||
| } | ||
| ) => { | ||
| const chat = | ||
| initialChat ?? | ||
|
|
@@ -377,10 +426,13 @@ export const CompassAssistantProvider = registerCompassPlugin( | |
| atlasService, | ||
| logger, | ||
| track, | ||
| getTools: () => toolsController.getActiveTools(), | ||
| }); | ||
|
|
||
| return { | ||
| store: { state: { chat, atlasAiService } }, | ||
| store: { | ||
| state: { chat, atlasAiService, toolsController, preferences }, | ||
| }, | ||
| deactivate: () => {}, | ||
| }; | ||
| }, | ||
|
|
@@ -389,8 +441,10 @@ export const CompassAssistantProvider = registerCompassPlugin( | |
| atlasService: atlasServiceLocator, | ||
| atlasAiService: atlasAiServiceLocator, | ||
| atlasAuthService: atlasAuthServiceLocator, | ||
| toolsController: toolsControllerLocator, | ||
| track: telemetryLocator, | ||
| logger: createLoggerLocator('COMPASS-ASSISTANT'), | ||
| preferences: preferencesLocator, | ||
| } | ||
| ); | ||
|
|
||
|
|
@@ -401,6 +455,7 @@ export function createDefaultChat({ | |
| logger, | ||
| track, | ||
| options, | ||
| getTools, | ||
| }: { | ||
| originForPrompt: string; | ||
| appNameForPrompt: string; | ||
|
|
@@ -410,13 +465,15 @@ export function createDefaultChat({ | |
| options?: { | ||
| transport: Chat<AssistantMessage>['transport']; | ||
| }; | ||
| getTools?: () => ToolSet; | ||
| }): Chat<AssistantMessage> { | ||
| const initialBaseUrl = 'http://PLACEHOLDER_BASE_URL_TO_BE_REPLACED.invalid'; | ||
| return new Chat({ | ||
| return new Chat<AssistantMessage>({ | ||
| transport: | ||
| options?.transport ?? | ||
| new DocsProviderTransport({ | ||
| origin: originForPrompt, | ||
| getTools, | ||
| instructions: buildConversationInstructionsPrompt({ | ||
| target: appNameForPrompt, | ||
| }), | ||
|
|
@@ -453,6 +510,7 @@ export function createDefaultChat({ | |
| }, | ||
| }).responses('mongodb-chat-latest'), | ||
| }), | ||
| sendAutomaticallyWhen: lastAssistantMessageIsCompleteWithApprovalResponses, | ||
| onError: (err: Error) => { | ||
| logger.log.error( | ||
| logger.mongoLogId(1_001_000_370), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the pipeline as text editor case we happen to already have the pipeline text. So I think this should at least be performant enough?