python/llm/agents/agent-mastery-course/backend/main.py (195 changes: 146 additions & 49 deletions)
@@ -3,6 +3,8 @@
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional, List, Dict, Any
from uuid import uuid4
from contextlib import nullcontext
import os
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv())
@@ -12,7 +14,8 @@
from arize.otel import register
from openinference.instrumentation.langchain import LangChainInstrumentor
from openinference.instrumentation.litellm import LiteLLMInstrumentor
from openinference.instrumentation import using_prompt_template
from openinference.instrumentation import using_prompt_template, using_metadata, using_attributes
from opentelemetry import trace
_TRACING = True
except Exception:
def using_prompt_template(**kwargs): # type: ignore
@@ -21,12 +24,24 @@ def using_prompt_template(**kwargs): # type: ignore
def _noop():
yield
return _noop()
def using_metadata(*args, **kwargs): # type: ignore
from contextlib import contextmanager
@contextmanager
def _noop():
yield
return _noop()
def using_attributes(*args, **kwargs): # type: ignore
from contextlib import contextmanager
@contextmanager
def _noop():
yield
return _noop()
_TRACING = False
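Reviewer note: `nullcontext` is imported in the hunk above, yet the fallback branch hand-rolls three identical no-op context managers. Since the call sites below only ever use these names as context managers, the fallbacks could collapse to one helper; a minimal sketch under that assumption:

def _noop_ctx(*args, **kwargs):  # type: ignore
    # Returns a context manager that does nothing, matching the real decorators' usage
    return nullcontext()

using_prompt_template = using_metadata = using_attributes = _noop_ctx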

# LangGraph + LangChain
from langgraph.graph import StateGraph, END, START
from langgraph.prebuilt import ToolNode
from typing_extensions import TypedDict, Annotated
from typing_extensions import TypedDict, Annotated, NotRequired
import operator
import json
from pathlib import Path
@@ -52,6 +67,10 @@ class TripRequest(BaseModel):
budget: Optional[str] = None
interests: Optional[str] = None
travel_style: Optional[str] = None
user_input: Optional[str] = None
session_id: Optional[str] = None
user_id: Optional[str] = None
turn_index: Optional[int] = None


class TripResponse(BaseModel):
@@ -601,24 +620,31 @@ class TripState(TypedDict):
local_context: Optional[str]
final: Optional[str]
tool_calls: Annotated[List[Dict[str, Any]], operator.add]
session_id: NotRequired[str]
user_id: NotRequired[Optional[str]]
turn_index: NotRequired[int]


def research_agent(state: TripState) -> TripState:
req = state["trip_request"]
destination = req["destination"]
user_input = (req.get("user_input") or "").strip()
if ENABLE_MCP:
prompt_t = (
"You are a research assistant.\n"
"First, call the mcp_weather tool to get weather for {destination}.\n"
"Then use other tools as needed for additional information."
)
prompt_lines = [
"You are a research assistant.",
"First, call the mcp_weather tool to get weather for {destination}.",
"Then use other tools as needed for additional information.",
]
else:
prompt_t = (
"You are a research assistant.\n"
"Gather essential information about {destination}.\n"
"Use at most one tool if needed."
)
vars_ = {"destination": destination}
prompt_lines = [
"You are a research assistant.",
"Gather essential information about {destination}.",
"Use at most one tool if needed.",
]
if user_input:
prompt_lines.append("User input: {user_input}")
prompt_t = "\n".join(prompt_lines)
vars_ = {"destination": destination, "user_input": user_input}
tools = [essential_info, weather_brief, visa_brief]
if ENABLE_MCP:
tools.append(mcp_weather)
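A small note on the pattern introduced in this hunk: `user_input` is always included in `vars_`, even when the `{user_input}` line is not appended to the template. That is harmless, because `str.format` simply ignores unused keyword arguments, and it keeps the variable set recorded by `using_prompt_template` stable across requests. For example:

"You are a research assistant.\nGather essential information about {destination}.".format(
    destination="Kyoto", user_input=""
)
# -> "You are a research assistant.\nGather essential information about Kyoto."
#    (the unused user_input kwarg is ignored)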
@@ -644,17 +670,26 @@ def budget_agent(state: TripState) -> TripState:
req = state["trip_request"]
destination, duration = req["destination"], req["duration"]
budget = req.get("budget", "moderate")
prompt_t = (
"You are a budget analyst.\n"
"Analyze costs for {destination} over {duration} with budget: {budget}.\n"
"Use tools to get pricing information, then provide a detailed breakdown."
)
vars_ = {"destination": destination, "duration": duration, "budget": budget}

user_input = (req.get("user_input") or "").strip()
prompt_lines = [
"You are a budget analyst.",
"Analyze costs for {destination} over {duration} with budget: {budget}.",
"Use tools to get pricing information, then provide a detailed breakdown.",
]
if user_input:
prompt_lines.append("User input: {user_input}")
prompt_t = "\n".join(prompt_lines)
vars_ = {
"destination": destination,
"duration": duration,
"budget": budget,
"user_input": user_input,
}

messages = [SystemMessage(content=prompt_t.format(**vars_))]
tools = [budget_basics, attraction_prices]
agent = llm.bind_tools(tools)

calls: List[Dict[str, Any]] = []

with using_prompt_template(template=prompt_t, variables=vars_, version="v1"):
@@ -670,7 +705,12 @@ def budget_agent(state: TripState) -> TripState:
# Add tool results and ask for synthesis
messages.append(res)
messages.extend(tr["messages"])
messages.append(SystemMessage(content=f"Create a detailed budget breakdown for {duration} in {destination} with a {budget} budget."))
follow_up = (
f"Create a detailed budget breakdown for {duration} in {destination} with a {budget} budget."
)
if user_input:
follow_up += f" Address this user input as well: {user_input}."
messages.append(SystemMessage(content=follow_up))

final_res = llm.invoke(messages)
out = final_res.content
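For readers with the surrounding lines collapsed: `tr` above comes from the usual LangGraph bind_tools/ToolNode round trip. A condensed sketch of the flow this hunk extends (names mirror the diff; the exact hidden lines are assumed, not quoted):

agent = llm.bind_tools(tools)
res = agent.invoke(messages)                          # model may request tool calls
if res.tool_calls:
    tr = ToolNode(tools).invoke({"messages": [res]})  # execute the requested tools
    messages.append(res)                              # assistant message carrying tool_calls
    messages.extend(tr["messages"])                   # ToolMessage results
    messages.append(SystemMessage(content=follow_up))  # the follow-up added in this hunk
    final_res = llm.invoke(messages)                  # synthesis pass without tools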
@@ -683,9 +723,12 @@ def budget_agent(state: TripState) -> TripState:
def local_agent(state: TripState) -> TripState:
req = state["trip_request"]
destination = req["destination"]
interests = req.get("interests", "local culture")
user_input = (req.get("user_input") or "").strip()
interests_raw = (req.get("interests") or "").strip()
interests = interests_raw or "local culture"
# Pull semantic matches from curated dataset when the flag allows it.
retrieved = LOCAL_GUIDE_RETRIEVER.retrieve(destination, interests)
retrieval_focus = interests_raw or (user_input if user_input else None)
retrieved = LOCAL_GUIDE_RETRIEVER.retrieve(destination, retrieval_focus)
context_lines = []
citation_lines = []
for idx, item in enumerate(retrieved, start=1):
Expand All @@ -703,11 +746,21 @@ def local_agent(state: TripState) -> TripState:

prompt_t = (
"You are a local guide.\n"
"Use the retrieved travel notes to suggest authentic experiences in {destination} for interests: {interests}.\n"
"Use the retrieved travel notes to suggest authentic experiences in {destination}.\n"
)
if user_input:
prompt_t += "User input: {user_input}\n"
prompt_t += (
"Focus interests: {interests}.\n"
"Context:\n{context}\n"
"Cite the numbered items when you rely on them."
)
vars_ = {"destination": destination, "interests": interests, "context": context_text}
vars_ = {
"destination": destination,
"interests": interests,
"context": context_text,
"user_input": user_input,
}
with using_prompt_template(template=prompt_t, variables=vars_, version="v1"):
agent = llm.bind_tools([local_flavor, local_customs, hidden_gems])
res = agent.invoke([SystemMessage(content=prompt_t.format(**vars_))])
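Minor simplification opportunity for the retrieval focus above: the conditional expression is equivalent to chained `or`, so the same behavior can be written as

retrieval_focus = interests_raw or user_input or None  # interests first, then free-text input, else no focus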
@@ -758,44 +811,64 @@ def itinerary_agent(state: TripState) -> TripState:
destination = req["destination"]
duration = req["duration"]
travel_style = req.get("travel_style", "standard")
prompt_t = (
"Create a {duration} itinerary for {destination} ({travel_style}).\n\n"
"Inputs:\nResearch: {research}\nBudget: {budget}\nLocal: {local}\n"
)
user_input = (req.get("user_input") or "").strip()
prompt_parts = [
"Create a {duration} itinerary for {destination} ({travel_style}).",
"",
"Inputs:",
"Research: {research}",
"Budget: {budget}",
"Local: {local}",
]
if user_input:
prompt_parts.append("User input: {user_input}")
prompt_t = "\n".join(prompt_parts)
vars_ = {
"duration": duration,
"destination": destination,
"travel_style": travel_style,
"research": (state.get("research") or "")[:400],
"budget": (state.get("budget") or "")[:400],
"local": (state.get("local") or "")[:400],
"user_input": user_input,
}
with using_prompt_template(template=prompt_t, variables=vars_, version="v1"):
res = llm.invoke([SystemMessage(content=prompt_t.format(**vars_))])
with using_attributes(tags=["itinerary", "final_agent"]):
if _TRACING:
current_span = trace.get_current_span()
if current_span:
current_span.set_attribute("metadata.itinerary", "true")
current_span.set_attribute("metadata.agent_type", "itinerary")
current_span.set_attribute("metadata.agent_node", "itinerary_agent")
if user_input:
current_span.set_attribute("metadata.user_input", user_input)
res = llm.invoke([SystemMessage(content=prompt_t.format(**vars_))])
return {"messages": [SystemMessage(content=res.content)], "final": res.content}


def build_graph():
g = StateGraph(TripState)
g.add_node("research_node", research_agent)
g.add_node("budget_node", budget_agent)
g.add_node("local_node", local_agent)
g.add_node("itinerary_node", itinerary_agent)
g.add_node("research_agent", research_agent)
g.add_node("budget_agent", budget_agent)
g.add_node("local_agent", local_agent)
g.add_node("itinerary_agent", itinerary_agent)

# Run research, budget, and local agents in parallel
g.add_edge(START, "research_node")
g.add_edge(START, "budget_node")
g.add_edge(START, "local_node")
g.add_edge(START, "research_agent")
g.add_edge(START, "budget_agent")
g.add_edge(START, "local_agent")

# All three agents feed into the itinerary agent
g.add_edge("research_node", "itinerary_node")
g.add_edge("budget_node", "itinerary_node")
g.add_edge("local_node", "itinerary_node")
g.add_edge("research_agent", "itinerary_agent")
g.add_edge("budget_agent", "itinerary_agent")
g.add_edge("local_agent", "itinerary_agent")

g.add_edge("itinerary_node", END)
g.add_edge("itinerary_agent", END)

# Compile without checkpointer to avoid state persistence issues
return g.compile()
compiled = g.compile()
compiled.name = "TripAgentGraph"
return compiled
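With the nodes renamed to match the agent functions and no checkpointer attached, the compiled graph can be exercised directly. An illustrative invocation (trip_request values are made up; the state keys mirror what plan_trip builds below):

graph = build_graph()
out = graph.invoke({
    "messages": [],
    "trip_request": {"destination": "Lisbon", "duration": "4 days", "budget": "moderate",
                     "interests": "food", "travel_style": "relaxed", "user_input": ""},
    "tool_calls": [],
    "session_id": "local-test",
})
print(out["final"])            # itinerary text from itinerary_agent
print(len(out["tool_calls"]))  # merged across the three parallel branches via the operator.add reducer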


app = FastAPI(title="AI Trip Planner")
@@ -838,17 +911,41 @@ def health():

@app.post("/plan-trip", response_model=TripResponse)
def plan_trip(req: TripRequest):

graph = build_graph()
# Only include necessary fields in initial state
# Agent outputs (research, budget, local, final) will be added during execution
state = {

session_id = req.session_id or str(uuid4())
user_id = req.user_id
turn_idx = req.turn_index

state: Dict[str, Any] = {
"messages": [],
"trip_request": req.model_dump(),
"tool_calls": [],
"session_id": session_id,
}
# No config needed without checkpointer
out = graph.invoke(state)
if user_id:
state["user_id"] = user_id
if turn_idx is not None:
state["turn_index"] = turn_idx

# Build attributes for session and user tracking
# Note: using_attributes only accepts session_id and user_id as kwargs
attrs_kwargs = {}
if session_id:
attrs_kwargs["session_id"] = session_id
if user_id:
attrs_kwargs["user_id"] = user_id

# Add turn_index as a custom span attribute if provided
if turn_idx is not None and _TRACING:
with using_attributes(**attrs_kwargs):
current_span = trace.get_current_span()
if current_span:
current_span.set_attribute("turn_index", turn_idx)
out = graph.invoke(state)
else:
with using_attributes(**attrs_kwargs):
out = graph.invoke(state)
return TripResponse(result=out.get("final", ""), tool_calls=out.get("tool_calls", []))
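An end-to-end request exercising the new session fields might look like the following (values are illustrative; host and port assume a default local uvicorn run):

import requests

resp = requests.post(
    "http://localhost:8000/plan-trip",
    json={
        "destination": "Kyoto",
        "duration": "5 days",
        "budget": "moderate",
        "user_input": "prefer quiet neighborhoods",
        "session_id": "abc-123",   # omit this to let the server mint one with uuid4
        "user_id": "u-42",
        "turn_index": 0,
    },
    timeout=120,
)
print(resp.json()["result"])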

