diff --git a/packages/typescript/ai-client/src/chat-client.ts b/packages/typescript/ai-client/src/chat-client.ts index 1d1ba091..1fdbddd4 100644 --- a/packages/typescript/ai-client/src/chat-client.ts +++ b/packages/typescript/ai-client/src/chat-client.ts @@ -19,6 +19,7 @@ export class ChatClient { private connection: ConnectionAdapter private uniqueId: string private body: Record = {} + private context: unknown = undefined private isLoading = false private error: Error | undefined = undefined private abortController: AbortController | null = null @@ -43,6 +44,7 @@ export class ChatClient { constructor(options: ChatClientOptions) { this.uniqueId = options.id || this.generateUniqueId('chat') this.body = options.body || {} + this.context = options.context this.connection = options.connection this.events = new DefaultChatClientEventEmitter(this.uniqueId) @@ -136,7 +138,7 @@ export class ChatClient { const clientTool = this.clientToolsRef.current.get(args.toolName) if (clientTool?.execute) { try { - const output = await clientTool.execute(args.input) + const output = await clientTool.execute(args.input, { context: this.context }) await this.addToolResult({ toolCallId: args.toolCallId, tool: args.toolName, diff --git a/packages/typescript/ai-client/src/types.ts b/packages/typescript/ai-client/src/types.ts index 4f83debb..8e0fc15e 100644 --- a/packages/typescript/ai-client/src/types.ts +++ b/packages/typescript/ai-client/src/types.ts @@ -133,6 +133,7 @@ export interface UIMessage = any> { export interface ChatClientOptions< TTools extends ReadonlyArray = any, + TContext = unknown, > { /** * Connection adapter for streaming @@ -140,6 +141,11 @@ export interface ChatClientOptions< */ connection: ConnectionAdapter + /** + * Context object to pass to client tools during execution + */ + context?: TContext + /** * Initial messages to populate the chat */ diff --git a/packages/typescript/ai/src/activities/chat/index.ts b/packages/typescript/ai/src/activities/chat/index.ts index 95c01374..4fcdcb71 100644 --- a/packages/typescript/ai/src/activities/chat/index.ts +++ b/packages/typescript/ai/src/activities/chat/index.ts @@ -31,6 +31,7 @@ import type { TextOptions, Tool, ToolCall, + ToolOptions, } from '../../types' // =========================== @@ -218,6 +219,7 @@ class TextEngine< private earlyTermination = false private toolPhase: ToolPhaseResult = 'continue' private cyclePhase: CyclePhase = 'processText' + private readonly options: Partial> constructor(config: TextEngineConfig) { this.adapter = config.adapter @@ -235,6 +237,7 @@ class TextEngine< ? { signal: config.params.abortController.signal } : undefined this.effectiveSignal = config.params.abortController?.signal + this.options = { context: config.params.context } } /** Get the accumulated content after the chat loop completes */ @@ -579,6 +582,7 @@ class TextEngine< this.tools, approvals, clientToolResults, + this.options, ) if ( diff --git a/packages/typescript/ai/src/activities/chat/tools/tool-calls.ts b/packages/typescript/ai/src/activities/chat/tools/tool-calls.ts index 5096db42..b4f47b6f 100644 --- a/packages/typescript/ai/src/activities/chat/tools/tool-calls.ts +++ b/packages/typescript/ai/src/activities/chat/tools/tool-calls.ts @@ -4,6 +4,7 @@ import type { ModelMessage, Tool, ToolCall, + ToolOptions, ToolResultStreamChunk, } from '../../../types' @@ -111,6 +112,7 @@ export class ToolCallManager { */ async *executeTools( doneChunk: DoneStreamChunk, + options: Partial> = {}, ): AsyncGenerator, void> { const toolCallsArray = this.getToolCalls() const toolResults: Array = [] @@ -147,7 +149,9 @@ export class ToolCallManager { } // Execute the tool - let result = await tool.execute(args) + let result = await tool.execute(args, { + context: options.context, + }) // Validate output against outputSchema if provided (for Standard Schema compliant schemas) if ( @@ -260,6 +264,7 @@ export async function executeToolCalls( tools: ReadonlyArray, approvals: Map = new Map(), clientResults: Map = new Map(), + options: Partial> = {}, ): Promise { const results: Array = [] const needsApproval: Array = [] @@ -319,6 +324,10 @@ export async function executeToolCalls( } } + const toolOptions: ToolOptions = { + context: options.context, + } + // CASE 1: Client-side tool (no execute function) if (!tool.execute) { // Check if tool needs approval @@ -395,7 +404,7 @@ export async function executeToolCalls( // Execute after approval const startTime = Date.now() try { - let result = await tool.execute(input) + let result = await tool.execute(input, toolOptions) const duration = Date.now() - startTime // Validate output against outputSchema if provided (for Standard Schema compliant schemas) @@ -453,7 +462,7 @@ export async function executeToolCalls( // CASE 3: Normal server tool - execute immediately const startTime = Date.now() try { - let result = await tool.execute(input) + let result = await tool.execute(input, toolOptions) const duration = Date.now() - startTime // Validate output against outputSchema if provided (for Standard Schema compliant schemas) diff --git a/packages/typescript/ai/src/activities/chat/tools/tool-definition.ts b/packages/typescript/ai/src/activities/chat/tools/tool-definition.ts index 69633c49..6a9df2ce 100644 --- a/packages/typescript/ai/src/activities/chat/tools/tool-definition.ts +++ b/packages/typescript/ai/src/activities/chat/tools/tool-definition.ts @@ -4,6 +4,7 @@ import type { JSONSchema, SchemaInput, Tool, + ToolOptions, } from '../../../types' /** @@ -24,16 +25,12 @@ export interface ClientTool< TInput extends SchemaInput = SchemaInput, TOutput extends SchemaInput = SchemaInput, TName extends string = string, -> { + TContext extends unknown = unknown, +> extends Tool { __toolSide: 'client' - name: TName - description: string - inputSchema?: TInput - outputSchema?: TOutput - needsApproval?: boolean - metadata?: Record execute?: ( args: InferSchemaType, + options: ToolOptions, ) => Promise> | InferSchemaType } @@ -109,18 +106,20 @@ export interface ToolDefinition< /** * Create a server-side tool with execute function */ - server: ( + server: ( execute: ( args: InferSchemaType, + options: ToolOptions, ) => Promise> | InferSchemaType, ) => ServerTool /** * Create a client-side tool with optional execute function */ - client: ( + client: ( execute?: ( args: InferSchemaType, + options: ToolOptions, ) => Promise> | InferSchemaType, ) => ClientTool } @@ -190,27 +189,31 @@ export function toolDefinition< const definition: ToolDefinition = { __toolSide: 'definition', ...config, - server( + server ( execute: ( args: InferSchemaType, + options: ToolOptions, ) => Promise> | InferSchemaType, ): ServerTool { return { __toolSide: 'server', ...config, - execute, + execute: (args, options) => execute(args, options as ToolOptions), } }, - client( + client( execute?: ( args: InferSchemaType, + options: ToolOptions, ) => Promise> | InferSchemaType, ): ClientTool { return { __toolSide: 'client', ...config, - execute, + execute: execute + ? (args: any, options: any) => execute(args, options as ToolOptions) + : undefined } }, } diff --git a/packages/typescript/ai/src/types.ts b/packages/typescript/ai/src/types.ts index 9df621c6..3fde34c8 100644 --- a/packages/typescript/ai/src/types.ts +++ b/packages/typescript/ai/src/types.ts @@ -84,6 +84,15 @@ export type SchemaInput = StandardJSONSchemaV1 | JSONSchema export type InferSchemaType = T extends StandardJSONSchemaV1 ? TInput : unknown +/** + * Options object passed to tool execute functions + * @template TContext - The type of context object + */ +export interface ToolOptions { + /** Context object that can be accessed by tools during execution */ + context?: TContext +} + export interface ToolCall { id: string type: 'function' @@ -420,15 +429,17 @@ export interface Tool< * Can return any value - will be automatically stringified if needed. * * @param args - The arguments parsed from the model's tool call (validated against inputSchema) + * @param options - Options object containing context to pass to tool execute functions * @returns Result to send back to the model (validated against outputSchema if provided) * * @example - * execute: async (args) => { + * execute: async (args, options) => { + * const user = await options.context?.db.users.find({ id: options.context.userId }); // Can access context * const weather = await fetchWeather(args.location); * return weather; // Can return object or string * } */ - execute?: (args: any) => Promise | any + execute?: (args: any, options: any) => Promise | any /** If true, tool execution requires user approval before running. Works with both server and client tools. */ needsApproval?: boolean @@ -565,6 +576,7 @@ export type AgentLoopStrategy = (state: AgentLoopState) => boolean export interface TextOptions< TProviderOptionsSuperset extends Record = Record, TProviderOptionsForModel = TProviderOptionsSuperset, + TContext = unknown, > { model: string messages: Array @@ -619,6 +631,29 @@ export interface TextOptions< metadata?: Record modelOptions?: TProviderOptionsForModel request?: Request | RequestInit + /** + * Context object that is automatically passed to all tool execute functions. + * + * This allows tools to access shared context (like user ID, database connections, + * request metadata, etc.) without needing to capture them via closures. + * Works for both server and client tools. + * + * @example + * const stream = chat({ + * adapter: openai(), + * model: 'gpt-4o', + * messages, + * context: { userId: '123', db }, + * tools: [getUserData], + * }); + * + * // In tool definition: + * const getUserData = getUserDataDef.server(async (args, options) => { + * // options.context.userId and options.context.db are available + * return await options.context.db.users.find({ userId: options.context.userId }); + * }); + */ + context?: TContext /** * Schema for structured output. diff --git a/packages/typescript/ai/tests/ai-text.test.ts b/packages/typescript/ai/tests/ai-text.test.ts index 3eee78e8..4b710c40 100644 --- a/packages/typescript/ai/tests/ai-text.test.ts +++ b/packages/typescript/ai/tests/ai-text.test.ts @@ -380,7 +380,7 @@ describe('chat() - Comprehensive Logic Path Coverage', () => { inputSchema: z.object({ location: z.string().optional(), }), - execute: vi.fn(async (args: any) => + execute: vi.fn(async (args: any, _options?: any) => JSON.stringify({ temp: 72, location: args.location }), ), } @@ -445,7 +445,7 @@ describe('chat() - Comprehensive Logic Path Coverage', () => { }), ) - expect(tool.execute).toHaveBeenCalledWith({ location: 'Paris' }) + expect(tool.execute).toHaveBeenCalledWith({ location: 'Paris' }, { context: undefined }) expect(adapter.chatStreamCallCount).toBeGreaterThanOrEqual(2) const toolResultChunks = chunks.filter((c) => c.type === 'tool_result') @@ -469,7 +469,7 @@ describe('chat() - Comprehensive Logic Path Coverage', () => { a: z.number(), b: z.number(), }), - execute: vi.fn(async (args: any) => + execute: vi.fn(async (args: any, _options?: any) => JSON.stringify({ result: args.a + args.b }), ), } @@ -551,7 +551,7 @@ describe('chat() - Comprehensive Logic Path Coverage', () => { ) // Tool should be executed with complete arguments - expect(tool.execute).toHaveBeenCalledWith({ a: 10, b: 20 }) + expect(tool.execute).toHaveBeenCalledWith({ a: 10, b: 20 }, { context: undefined }) const toolResultChunks = chunks.filter((c) => c.type === 'tool_result') expect(toolResultChunks.length).toBeGreaterThan(0) }) @@ -561,14 +561,14 @@ describe('chat() - Comprehensive Logic Path Coverage', () => { name: 'tool1', description: 'Tool 1', inputSchema: z.object({}), - execute: vi.fn(async () => JSON.stringify({ result: 1 })), + execute: vi.fn(async (_args: any, _options?: any) => JSON.stringify({ result: 1 })), } const tool2: Tool = { name: 'tool2', description: 'Tool 2', inputSchema: z.object({}), - execute: vi.fn(async () => JSON.stringify({ result: 2 })), + execute: vi.fn(async (_args: any, _options?: any) => JSON.stringify({ result: 2 })), } class MultipleToolsAdapter extends MockAdapter { @@ -659,7 +659,7 @@ describe('chat() - Comprehensive Logic Path Coverage', () => { name: 'test_tool', description: 'Test', inputSchema: z.object({}), - execute: vi.fn(async () => JSON.stringify({ result: 'ok' })), + execute: vi.fn(async (_args: any, _options?: any) => JSON.stringify({ result: 'ok' })), } class ContentWithToolsAdapter extends MockAdapter { @@ -1469,7 +1469,7 @@ describe('chat() - Comprehensive Logic Path Coverage', () => { const chunks = await collectChunks(stream) expect(chunks[0]?.type).toBe('tool_result') - expect(toolExecute).toHaveBeenCalledWith({ path: '/tmp/test.txt' }) + expect(toolExecute).toHaveBeenCalledWith({ path: '/tmp/test.txt' }, { context: undefined }) expect(adapter.chatStreamCallCount).toBe(1) }) }) @@ -2558,7 +2558,7 @@ describe('chat() - Comprehensive Logic Path Coverage', () => { await collectChunks(stream2) // Tool should have been executed because approval was provided - expect(tool.execute).toHaveBeenCalledWith({ path: '/tmp/test.txt' }) + expect(tool.execute).toHaveBeenCalledWith({ path: '/tmp/test.txt' }, { context: undefined }) }) it('should extract client tool outputs from messages with parts', async () => { @@ -2822,7 +2822,7 @@ describe('chat() - Comprehensive Logic Path Coverage', () => { name: 'get_temperature', description: 'Get the current temperature in degrees', inputSchema: z.object({}), - execute: vi.fn(async (_args: any) => { + execute: vi.fn(async (_args: any, _options?: any) => { return '70' }), } diff --git a/packages/typescript/ai/tests/tool-call-manager.test.ts b/packages/typescript/ai/tests/tool-call-manager.test.ts index af117f08..b82706f8 100644 --- a/packages/typescript/ai/tests/tool-call-manager.test.ts +++ b/packages/typescript/ai/tests/tool-call-manager.test.ts @@ -18,7 +18,7 @@ describe('ToolCallManager', () => { inputSchema: z.object({ location: z.string().optional(), }), - execute: vi.fn((args: any) => { + execute: vi.fn((args: any, _options?: any) => { return JSON.stringify({ temp: 72, location: args.location }) }), } @@ -121,7 +121,7 @@ describe('ToolCallManager', () => { expect(finalResult[0]?.toolCallId).toBe('call_123') // Tool execute should have been called - expect(mockWeatherTool.execute).toHaveBeenCalledWith({ location: 'Paris' }) + expect(mockWeatherTool.execute).toHaveBeenCalledWith({ location: 'Paris' }, { context: undefined }) }) it('should handle tool execution errors gracefully', async () => { @@ -209,15 +209,15 @@ describe('ToolCallManager', () => { }) it('should handle multiple tool calls in same iteration', async () => { - const calculateTool: Tool = { - name: 'calculate', - description: 'Calculate', - inputSchema: z.object({ - expression: z.string(), - }), - execute: vi.fn((args: any) => { - return JSON.stringify({ result: eval(args.expression) }) - }), +const calculateTool: Tool = { + name: 'calculate', + description: 'Calculate', + inputSchema: z.object({ + expression: z.string(), + }), + execute: vi.fn((args: any, _options?: any) => { + return JSON.stringify({ result: eval(args.expression) }) + }), } const manager = new ToolCallManager([mockWeatherTool, calculateTool]) diff --git a/packages/typescript/ai/tests/tool-definition.test.ts b/packages/typescript/ai/tests/tool-definition.test.ts index 2d37178c..565dd725 100644 --- a/packages/typescript/ai/tests/tool-definition.test.ts +++ b/packages/typescript/ai/tests/tool-definition.test.ts @@ -1,6 +1,7 @@ -import { describe, it, expect, vi } from 'vitest' +import { describe, expect, it, vi } from 'vitest' import { z } from 'zod' import { toolDefinition } from '../src/activities/chat/tools/tool-definition' +import type { ToolOptions } from '../src' describe('toolDefinition', () => { it('should create a tool definition with basic properties', () => { @@ -46,7 +47,7 @@ describe('toolDefinition', () => { }), }) - const executeFn = vi.fn(async (_args: { location: string }) => { + const executeFn = vi.fn((_args: { location: string }, _options?: unknown) => { return { temperature: 72, conditions: 'sunny', @@ -60,9 +61,9 @@ describe('toolDefinition', () => { expect(serverTool.execute).toBeDefined() if (serverTool.execute) { - const result = await serverTool.execute({ location: 'Paris' }) + const result = await serverTool.execute({ location: 'Paris' }, { context: undefined }) expect(result).toEqual({ temperature: 72, conditions: 'sunny' }) - expect(executeFn).toHaveBeenCalledWith({ location: 'Paris' }) + expect(executeFn).toHaveBeenCalledWith({ location: 'Paris' }, { context: undefined }) } }) @@ -79,7 +80,7 @@ describe('toolDefinition', () => { }), }) - const executeFn = vi.fn(async (_args: { key: string; value: string }) => { + const executeFn = vi.fn(async (_args: { key: string; value: string }, _options?: unknown) => { return { success: true } }) @@ -90,9 +91,9 @@ describe('toolDefinition', () => { expect(clientTool.execute).toBeDefined() if (clientTool.execute) { - const result = await clientTool.execute({ key: 'test', value: 'data' }) + const result = await clientTool.execute({ key: 'test', value: 'data' }, { context: undefined }) expect(result).toEqual({ success: true }) - expect(executeFn).toHaveBeenCalledWith({ key: 'test', value: 'data' }) + expect(executeFn).toHaveBeenCalledWith({ key: 'test', value: 'data' }, { context: undefined }) } }) @@ -176,7 +177,7 @@ describe('toolDefinition', () => { }) if (serverTool.execute) { - const result = serverTool.execute({ value: 5 }) + const result = serverTool.execute({ value: 5 }, { context: undefined }) expect(result).toEqual({ doubled: 10 }) } }) @@ -222,7 +223,7 @@ describe('toolDefinition', () => { orderId: '123', items: [], shipping: { address: '123 Main St', method: 'standard' }, - }) + }, { context: undefined }) expect(serverTool.__toolSide).toBe('server') }) @@ -255,4 +256,106 @@ describe('toolDefinition', () => { expect(tool.__toolSide).toBe('definition') expect(tool.inputSchema).toBeDefined() }) + + it('should pass context to server tool execute function', async () => { + const tool = toolDefinition({ + name: 'getContextValue', + description: 'Get a value from context', + inputSchema: z.object({ + key: z.string(), + }), + outputSchema: z.object({ + exists: z.boolean(), + value: z.string().optional(), + }), + }) + + const contextValue = 'test-value' + const context = { testData: contextValue } + + const serverTool = tool.server( + (_: unknown, options: ToolOptions | undefined) => { + const exists = options?.context?.testData !== undefined + const value = exists ? options.context?.testData : undefined + return { exists, value } + } + ) + + if (serverTool.execute) { + const result = await serverTool.execute( + { key: 'testData' }, + { context } + ) + + expect(result.exists).toBe(true) + expect(result.value).toBe(contextValue) + } + }) + + it('should pass context to client tool execute function', async () => { + const tool = toolDefinition({ + name: 'getContextValue', + description: 'Get a value from context', + inputSchema: z.object({ + key: z.string(), + }), + outputSchema: z.object({ + exists: z.boolean(), + value: z.string().optional(), + }), + }) + + const contextValue = 'test-value' + const context = { testData: contextValue } + + const clientTool = tool.client( + (_: unknown, options: ToolOptions | undefined) => { + const exists = options?.context?.testData !== undefined + const value = exists ? options.context?.testData : undefined + return { exists, value } + } + ) + + if (clientTool.execute) { + const result = await clientTool.execute( + { key: 'testData' }, + { context } + ) + + expect(result.exists).toBe(true) + expect(result.value).toBe(contextValue) + } + }) + + it('should handle missing context gracefully', async () => { + const tool = toolDefinition({ + name: 'getContextValue', + description: 'Get a value from context', + inputSchema: z.object({ + key: z.string(), + }), + outputSchema: z.object({ + exists: z.boolean(), + value: z.string().optional(), + }), + }) + + const serverTool = tool.server( + (_: unknown, options: ToolOptions<{ testData?: string }> | undefined) => { + const exists = options?.context?.testData !== undefined + const value = exists ? options.context?.testData : undefined + return { exists, value } + } + ) + + if (serverTool.execute) { + const result = await serverTool.execute( + { key: 'testData' }, + { context: {} } // Empty context + ) + + expect(result.exists).toBe(false) + expect(result.value).toBeUndefined() + } + }) })