diff --git a/docs/examples/notebooks/tool_selection.ipynb b/docs/examples/notebooks/tool_selection.ipynb new file mode 100644 index 00000000..8eacb25e --- /dev/null +++ b/docs/examples/notebooks/tool_selection.ipynb @@ -0,0 +1,380 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "IByYhgKy9WCo" + }, + "source": [ + "# Simple Example\n", + "This Jupyter notebook runs on Colab and shows a simple example of Tooling selection using Top-K and Double Round Robin" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZIu6B1Ht927Z" + }, + "source": [ + "## Install Ollama\n", + "\n", + "Before we get started with Mellea, we download, install and serve ollama. We define set_css to wrap Colab output." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VDaTfltQY3Fl" + }, + "outputs": [], + "source": [ + "!curl -fsSL https://ollama.com/install.sh | sh > /dev/null\n", + "!nohup ollama serve >/dev/null 2>&1 &\n", + "\n", + "from IPython.display import HTML, display\n", + "\n", + "\n", + "def set_css():\n", + " display(HTML(\"\\n\\n\"))\n", + "\n", + "\n", + "get_ipython().events.register(\"pre_run_cell\", set_css)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jEl3nAk696mI" + }, + "source": [ + "## Install Mellea\n", + "We run `uv pip install mellea` to install Mellea." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9EurAUSz_1yl" + }, + "outputs": [], + "source": [ + "!uv pip install mellea -q" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NU7bZqKA0djW" + }, + "source": [ + "### **Import top_k and double_round_robin libraries**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Jt3UC-Sa0djY" + }, + "outputs": [], + "source": [ + "!git clone https://github.com/generative-computing/mellea-contribs.git" + ] + }, + { + "cell_type": "code", + "source": [ + "%cd mellea-contribs\n", + "!pip install -e . -qq" + ], + "metadata": { + "id": "Pf8560taKn7x" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## **Defining sample set of tools**" + ], + "metadata": { + "id": "W0rC23c6Lb6f" + } + }, + { + "cell_type": "code", + "source": [ + "from mellea import start_session\n", + "import mellea\n", + "\n", + "# Start a Mellea session\n", + "m = start_session()\n", + "\n", + "# Define a sample set of tools\n", + "TOOLS = [\n", + " {\n", + " \"name\": \"EnterpriseScheduler\",\n", + " \"short_description\": \"Schedules internal meetings and events.\",\n", + " \"long_description\": (\n", + " \"EnterpriseScheduler is an internal tool used to schedule meetings and events across teams. \"\n", + " \"It integrates with company calendars, room booking systems, and employee availability data \"\n", + " \"to resolve conflicts and enforce scheduling policies. The tool focuses strictly on time \"\n", + " \"and resource coordination and does not perform financial analysis or budgeting.\"\n", + " ),\n", + " \"requirements\": {\n", + " \"accepts\": [\"date\", \"participants\", \"location\"],\n", + " \"does_not_support\": [\"budget_calculation\", \"cost_estimation\"]\n", + " }\n", + " },\n", + " {\n", + " \"name\": \"BudgetOptimizer\",\n", + " \"short_description\": \"Performs internal budget forecasting and cost analysis.\",\n", + " \"long_description\": (\n", + " \"BudgetOptimizer is a proprietary financial planning tool used for estimating costs, \"\n", + " \"forecasting budgets, and optimizing spend for internal initiatives. It integrates with \"\n", + " \"enterprise ERP systems and historical financial data to produce projections and summaries. \"\n", + " \"The tool does not schedule meetings or manage calendars.\"\n", + " ),\n", + " \"requirements\": {\n", + " \"accepts\": [\"cost_inputs\", \"financial_constraints\"],\n", + " \"does_not_support\": [\"meeting_scheduling\", \"calendar_management\"]\n", + " }\n", + " },\n", + " {\n", + " \"name\": \"DocSearchPro\",\n", + " \"short_description\": \"Searches internal documents and knowledge bases.\",\n", + " \"long_description\": (\n", + " \"DocSearchPro provides semantic search over internal documents, policies, and project files. \"\n", + " \"It is designed to help employees retrieve institutional knowledge and past decisions quickly. \"\n", + " \"The tool does not execute actions, perform calculations, or modify schedules.\"\n", + " ),\n", + " \"requirements\": {\n", + " \"accepts\": [\"search_query\", \"filters\"],\n", + " \"does_not_support\": [\"scheduling\", \"budgeting\", \"resource_allocation\"]\n", + " }\n", + " },\n", + " {\n", + " \"name\": \"ResourceAllocator\",\n", + " \"short_description\": \"Allocates people and resources to projects.\",\n", + " \"long_description\": (\n", + " \"ResourceAllocator manages internal project resources by assigning people, equipment, and \"\n", + " \"timelines based on availability and priority constraints. It integrates with HR and project \"\n", + " \"management systems to track utilization and resolve conflicts. The tool does not handle \"\n", + " \"meeting scheduling or financial budgeting.\"\n", + " ),\n", + " \"requirements\": {\n", + " \"accepts\": [\"project_id\", \"resource_constraints\"],\n", + " \"does_not_support\": [\"meeting_scheduling\", \"budget_calculation\"]\n", + " }\n", + " }\n", + "]\n", + "\n", + "USER_QUERY = (\"Plan a team kickoff for next week by scheduling a meeting and estimating the expected cost based on headcount and duration.\")" + ], + "metadata": { + "id": "oSyRPp5bKpz7" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## **Step 1: Top-K Tool Shortlisting**" + ], + "metadata": { + "id": "nY9JSd6HMfZr" + } + }, + { + "cell_type": "code", + "source": [ + "from mellea_contribs.tools.top_k import top_k\n", + "from pydantic import RootModel\n", + "from mellea.stdlib.requirement import check, req, simple_validate\n", + "from mellea.stdlib.sampling import RejectionSamplingStrategy\n", + "\n", + "def shortlist_tools(m, user_query, tools, k):\n", + " items = [\n", + " {\"name\": t[\"name\"], \"description\": t[\"short_description\"]}\n", + " for t in tools\n", + " ]\n", + "\n", + " ranked = top_k(\n", + " items=items,\n", + " comparison_prompt=f\"\"\"\n", + " Select the tools most relevant to:\n", + " '{user_query}'\n", + " \"\"\",\n", + " m=m,\n", + " k=k,\n", + " )\n", + "\n", + " ranked_names = {t[\"name\"] for t in ranked}\n", + " return [t for t in tools if t[\"name\"] in ranked_names]" + ], + "metadata": { + "id": "gtITUkKlLsM0" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## **Step 2: High-Confidence Ranking with Double Round Robin**" + ], + "metadata": { + "id": "NM3w5OXEMt3p" + } + }, + { + "cell_type": "code", + "source": [ + "from mellea_contribs.tools.double_round_robin import double_round_robin\n", + "\n", + "def rank_tools(m, user_query, tools):\n", + " items = [\n", + " {\n", + " \"name\": t[\"name\"],\n", + " \"description\": t[\"long_description\"],\n", + " \"constraints\": t[\"requirements\"],\n", + " }\n", + " for t in tools\n", + " ]\n", + "\n", + " return double_round_robin(\n", + " items=items,\n", + " comparison_prompt=f\"\"\"\n", + " Given the user query:\n", + " '{user_query}'\n", + "\n", + " Compare tools strictly based on their described capabilities\n", + " and stated constraints. Do not assume unsupported functionality.\n", + " \"\"\",\n", + " m=m,\n", + " )\n" + ], + "metadata": { + "id": "Hf66hB_fMqpi" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## **Final Tool selection**" + ], + "metadata": { + "id": "s9J7WdURNXeB" + } + }, + { + "cell_type": "code", + "source": [ + "class ToolChoice(RootModel[str]):\n", + " pass\n", + "\n", + "def choose_final_tool(\n", + " m: mellea.MelleaSession,\n", + " user_query: str,\n", + " ranked_tools: list[dict],\n", + "):\n", + " tool_names = [t[\"name\"] for t in ranked_tools]\n", + "\n", + " tool_context = \"\\n\\n\".join(\n", + " f\"\"\"\n", + " Tool: {t['name']}\n", + " Description: {t['long_description']}\n", + " Requirements: {t['requirements']}\n", + " \"\"\"\n", + " for t in ranked_tools\n", + " )\n", + "\n", + " response = m.instruct(\n", + " f\"\"\"\n", + " Given the user query:\n", + " '{user_query}'\n", + "\n", + " Choose the single best tool.\n", + " The tool must satisfy the user's intent and must not violate its stated requirements.\n", + " \"\"\",\n", + " grounding_context={\n", + " \"tools\": tool_context\n", + " },\n", + " requirements=[\n", + " req(\n", + " \"Tool must be one of the provided candidates\",\n", + " validation_fn=simple_validate(lambda s: s in tool_names),\n", + " ),\n", + " req(\"Return only the tool name\"),\n", + " ],\n", + " strategy=RejectionSamplingStrategy(loop_budget=3),\n", + " format=ToolChoice,\n", + " )\n", + "\n", + " return response.value\n" + ], + "metadata": { + "id": "WYfUea4YMzN9" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "m = start_session()\n", + "\n", + "top_k_tools = shortlist_tools(m, USER_QUERY, TOOLS, k=2)\n", + "ranked = rank_tools(m, USER_QUERY, top_k_tools)\n", + "\n", + "ranked_tools = []\n", + "tool_by_name = {t[\"name\"]: t for t in TOOLS}\n", + "\n", + "for item, score in ranked:\n", + " ranked_tools.append(tool_by_name[item[\"name\"]])\n", + "\n", + "final_tool = choose_final_tool(m, USER_QUERY, ranked_tools)\n", + "\n", + "print(\"Top-K shortlisted tools:\", [t[\"name\"] for t in top_k_tools])\n", + "print(\"DRR ranking:\", [(t[\"name\"], score) for t, score in ranked])\n", + "print(\"Final selected tool:\", final_tool)" + ], + "metadata": { + "id": "8iCOuJ66NTu8" + }, + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.5" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file