diff --git a/scenarios/debate/Dockerfile.adk-debate-judge b/scenarios/debate/Dockerfile.adk-debate-judge index 34cfa15..5d493ee 100644 --- a/scenarios/debate/Dockerfile.adk-debate-judge +++ b/scenarios/debate/Dockerfile.adk-debate-judge @@ -11,7 +11,7 @@ RUN \ --mount=type=cache,target=/home/agentbeats/.cache/uv,uid=1000 \ uv sync --locked -COPY scenarios scenarios +COPY scenarios/debate scenarios/debate ENTRYPOINT ["uv", "run", "scenarios/debate/adk_debate_judge.py"] CMD ["--host", "0.0.0.0"] diff --git a/scenarios/debate/Dockerfile.debate-judge b/scenarios/debate/Dockerfile.debate-judge index 72e9226..e3e0f73 100644 --- a/scenarios/debate/Dockerfile.debate-judge +++ b/scenarios/debate/Dockerfile.debate-judge @@ -11,7 +11,7 @@ RUN \ --mount=type=cache,target=/home/agentbeats/.cache/uv,uid=1000 \ uv sync --locked -COPY scenarios scenarios +COPY scenarios/debate scenarios/debate ENTRYPOINT ["uv", "run", "scenarios/debate/debate_judge.py"] CMD ["--host", "0.0.0.0"] diff --git a/scenarios/debate/Dockerfile.debater b/scenarios/debate/Dockerfile.debater index f2877b9..89ebc30 100644 --- a/scenarios/debate/Dockerfile.debater +++ b/scenarios/debate/Dockerfile.debater @@ -11,7 +11,7 @@ RUN \ --mount=type=cache,target=/home/agentbeats/.cache/uv,uid=1000 \ uv sync --locked -COPY scenarios scenarios +COPY scenarios/debate scenarios/debate ENTRYPOINT ["uv", "run", "scenarios/debate/debater.py"] CMD ["--host", "0.0.0.0"] diff --git a/scenarios/tau2/Dockerfile.tau2-agent b/scenarios/tau2/Dockerfile.tau2-agent index a542ac5..4699eda 100644 --- a/scenarios/tau2/Dockerfile.tau2-agent +++ b/scenarios/tau2/Dockerfile.tau2-agent @@ -12,7 +12,7 @@ RUN \ --mount=type=cache,target=/home/agentbeats/.cache/uv,uid=1000 \ uv sync --locked -COPY scenarios scenarios +COPY scenarios/tau2 scenarios/tau2 ENTRYPOINT ["uv", "run", "scenarios/tau2/tau2_agent.py"] CMD ["--host", "0.0.0.0"] diff --git a/scenarios/tau2/Dockerfile.tau2-env b/scenarios/tau2/Dockerfile.tau2-env new file mode 100644 index 0000000..68f8816 --- /dev/null +++ b/scenarios/tau2/Dockerfile.tau2-env @@ -0,0 +1,30 @@ +# https://github.com/meta-pytorch/OpenEnv/blob/fb169f8c660df722f538160b3ce636de3312a756/src/envs/README.md + +# Accept base image as build argument for CI/CD flexibility +ARG BASE_IMAGE=openenv-base:latest +FROM ${BASE_IMAGE} + +# Install dependencies +COPY scenarios/tau2/requirements.txt /tmp/requirements.txt +RUN \ + apt-get update && \ + apt-get install -y git && \ + pip install --no-cache-dir -r /tmp/requirements.txt && rm /tmp/requirements.txt + +# Download tau2 data +RUN git clone --depth 1 --filter=blob:none --sparse https://github.com/sierra-research/tau2-bench.git /app/scenarios/tau2/tau2-bench && \ + cd /app/scenarios/tau2/tau2-bench && \ + git sparse-checkout set data + +ENV TAU2_DATA_DIR=/app/scenarios/tau2/tau2-bench/data + +# Copy environment code +COPY scenarios/tau2/ /app/scenarios/tau2/ + +# Health check +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8000/health || exit 1 + +# Run server +# https://github.com/meta-pytorch/OpenEnv/issues/244 +CMD ["python", "-m", "uvicorn", "tau2_server:app", "--host", "0.0.0.0", "--port", "8000", "--app-dir", "scenarios/tau2"] diff --git a/scenarios/tau2/Dockerfile.tau2-evaluator b/scenarios/tau2/Dockerfile.tau2-evaluator index dec15ee..a061288 100644 --- a/scenarios/tau2/Dockerfile.tau2-evaluator +++ b/scenarios/tau2/Dockerfile.tau2-evaluator @@ -17,18 +17,7 @@ RUN \ --mount=type=cache,target=/home/agentbeats/.cache/uv,uid=1000 \ uv pip install "tau2 @ git+https://github.com/sierra-research/tau2-bench.git" -# Download tau2 data -USER root -RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* -USER agentbeats - -RUN git clone --depth 1 --filter=blob:none --sparse https://github.com/sierra-research/tau2-bench.git /home/agentbeats/tau2-bench && \ - cd /home/agentbeats/tau2-bench && \ - git sparse-checkout set data - -ENV TAU2_DATA_DIR=/home/agentbeats/tau2-bench/data - -COPY scenarios scenarios +COPY scenarios/tau2 scenarios/tau2 ENTRYPOINT ["uv", "run", "scenarios/tau2/tau2_evaluator.py"] CMD ["--host", "0.0.0.0"] diff --git a/scenarios/tau2/requirements.txt b/scenarios/tau2/requirements.txt new file mode 100644 index 0000000..c91947a --- /dev/null +++ b/scenarios/tau2/requirements.txt @@ -0,0 +1,2 @@ +openenv-core>=0.1.1 +tau2 @ git+https://github.com/sierra-research/tau2-bench.git diff --git a/scenarios/tau2/scenario.toml b/scenarios/tau2/scenario.toml index d71cf66..43c1f10 100644 --- a/scenarios/tau2/scenario.toml +++ b/scenarios/tau2/scenario.toml @@ -7,6 +7,16 @@ role = "agent" endpoint = "http://127.0.0.1:9019" cmd = "python scenarios/tau2/tau2_agent.py --host 127.0.0.1 --port 9019" +[[environments]] +name = "tau2" +endpoint = "http://127.0.0.1:8000" +image = "tau2-env:latest" +publishes = ["127.0.0.1:8000:8000"] +[environments.env] +TAU2_DOMAIN = "airline" +TAU2_TASK_ID = "0" +TAU2_ENV_ARGS_JSON = "{}" + [config] domain = "airline" num_tasks = 3 diff --git a/scenarios/tau2/tau2_client.py b/scenarios/tau2/tau2_client.py new file mode 100644 index 0000000..4192dac --- /dev/null +++ b/scenarios/tau2/tau2_client.py @@ -0,0 +1,25 @@ +from typing import Any + +from openenv_core.client_types import StepResult +from openenv_core.http_env_client import HTTPEnvClient + +from tau2_models import Tau2Action, Tau2Observation, Tau2State + + +# https://github.com/meta-pytorch/OpenEnv/blob/fb169f8c660df722f538160b3ce636de3312a756/src/envs/README.md + + +class Tau2Env(HTTPEnvClient[Tau2Action, Tau2Observation]): + def _step_payload(self, action: Tau2Action) -> dict[str, Any]: + return {"action": action.action} + + def _parse_result(self, payload: dict[str, Any]) -> StepResult[Tau2Observation]: + obs = Tau2Observation(**payload["observation"]) + return StepResult( + observation=obs, + reward=payload.get("reward"), + done=payload.get("done", False), + ) + + def _parse_state(self, payload: dict[str, Any]) -> Tau2State: + return Tau2State(**payload) diff --git a/scenarios/tau2/tau2_env.py b/scenarios/tau2/tau2_env.py new file mode 100644 index 0000000..1c3d33d --- /dev/null +++ b/scenarios/tau2/tau2_env.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass, field +from typing import Any +import uuid + +import gymnasium as gym + +from openenv_core.env_server import Action, Environment, Observation, State +from tau2.gym import TAU_BENCH_ENV_ID, register_gym_agent + +from tau2_models import Tau2Action, Tau2Observation, Tau2State + + +# https://github.com/sierra-research/tau2-bench/blob/main/src/tau2/gym/README.md +# https://github.com/meta-pytorch/OpenEnv/blob/fb169f8c660df722f538160b3ce636de3312a756/src/envs/README.md + + +register_gym_agent() + + +class Tau2Environment(Environment): + def __init__( + self, + domain: str, + task_id: str, + env_args: Any, + ): + super().__init__() + self._state = Tau2State() + self._gym_env: gym.Env[str, str] = gym.make( + TAU_BENCH_ENV_ID, + domain=domain, + task_id=task_id, + **env_args, + ) + + def reset(self) -> Tau2Observation: + self._state = Tau2State(episode_id=str(uuid.uuid4())) + observation, info = self._gym_env.reset() + self._state.info = info + return Tau2Observation(observation=observation) + + def step(self, action: Action) -> Tau2Observation: + assert isinstance(action, Tau2Action) + self._state.step_count += 1 + observation, reward, terminated, truncated, info = self._gym_env.step(action.action) + self._state.info = info + return Tau2Observation( + observation=observation, + done=terminated or truncated, + reward=float(reward), + ) + + @property + def state(self) -> Tau2State: + return self._state diff --git a/scenarios/tau2/tau2_evaluator.py b/scenarios/tau2/tau2_evaluator.py index 42d6a00..d4160f4 100644 --- a/scenarios/tau2/tau2_evaluator.py +++ b/scenarios/tau2/tau2_evaluator.py @@ -14,7 +14,6 @@ import time from typing import Any, Optional -import gymnasium as gym import uvicorn from dotenv import load_dotenv @@ -40,17 +39,16 @@ from tau2.data_model.simulation import RewardInfo from tau2.environment.tool import Tool -from tau2.gym import TAU_BENCH_ENV_ID, register_gym_agent from tau2.run import get_tasks +from tau2_client import Tau2Env +from tau2_models import Tau2Action + logging.basicConfig(level=logging.INFO) logger = logging.getLogger("tau2_evaluator") RESPOND_ACTION_NAME = "respond" -# Register tau-bench gym environments -register_gym_agent() - def tools_to_str(tools: list[Tool]) -> str: """Convert tau-bench tools to JSON schema format.""" @@ -93,19 +91,19 @@ def validate_request(self, request: EvalRequest) -> tuple[bool, str]: return False, f"Missing config keys: {missing_config_keys}" return True, "ok" - async def run_eval(self, req: EvalRequest, updater: TaskUpdater) -> None: - logger.info(f"Starting tau2 evaluation: {req}") + async def run_eval(self, request: EvalRequest, updater: TaskUpdater) -> None: + logger.info(f"Starting tau2 evaluation: {request}") start_time = time.time() - domain = req.config["domain"] - task_ids = req.config.get("task_ids", None) - num_tasks = req.config.get("num_tasks", None) - max_steps = req.config.get("max_steps", 200) - user_llm = req.config.get("user_llm", "openai/gpt-4o") - user_llm_args = req.config.get("user_llm_args", {"temperature": 0.0}) + domain = request.config["domain"] + task_ids = request.config.get("task_ids", None) + num_tasks = request.config.get("num_tasks", None) + max_steps = request.config.get("max_steps", 200) + user_llm = request.config.get("user_llm", "openai/gpt-4o") + user_llm_args = request.config.get("user_llm_args", {"temperature": 0.0}) # Get the purple agent URL - agent_url = str(req.participants["agent"]) + agent_url = str(request.participants["agent"]) # Get task IDs resolved_task_ids = get_task_ids(domain, task_ids, num_tasks) @@ -146,7 +144,7 @@ async def run_eval(self, req: EvalRequest, updater: TaskUpdater) -> None: num_completed = len(metrics["tasks"]) pass_rate = (total_reward / num_completed * 100) if num_completed > 0 else 0 - result_data = { + result_data: dict[str, Any] = { "domain": domain, "score": total_reward, "max_score": num_completed, @@ -188,35 +186,26 @@ async def _run_single_task( task_id: str, max_steps: int, user_llm: str, - user_llm_args: dict, + user_llm_args: dict[Any, Any], ) -> float: """Run a single tau-bench task and return the reward.""" - env = gym.make( - TAU_BENCH_ENV_ID, - domain=domain, - task_id=task_id, - max_steps=max_steps, - user_llm=user_llm, - user_llm_args=user_llm_args, - all_messages_as_observation=False, - ) + env = Tau2Env("http://localhost:8000") - terminated = False - observation, info = env.reset() + observation_sr = env.reset() # Build the initial task description for the purple agent - task_description = self._build_task_prompt(info, observation) + task_description = self._build_task_prompt(env.state.info, observation_sr.observation.observation) # Start a new conversation with the purple agent next_message = task_description is_first_message = True - while not terminated: + while not observation_sr.done: logger.debug(f"Sending to purple agent: {next_message[:200]}...") # Send message to purple agent - response = await self._tool_provider.talk_to_agent( + response: str = await self._tool_provider.talk_to_agent( message=next_message, url=agent_url, new_conversation=is_first_message, @@ -227,28 +216,28 @@ async def _run_single_task( # Parse the purple agent's action try: - action = self._parse_agent_response(response) + action = Tau2Action(action=self._parse_agent_response(response)) except Exception as e: logger.error(f"Failed to parse agent response: {e}") # When parsing fails, respond with error as plain text (not a tool call) - action = "I encountered an error processing the request." + action = Tau2Action(action="I encountered an error processing the request.") # Step the environment with either a JSON string (tool call) or plain text (user response) - observation, reward, terminated, truncated, info = env.step(action) - logger.debug(f"Environment step: reward={reward}, terminated={terminated}") + observation_sr = env.step(action) + logger.debug(f"Environment step: reward={observation_sr.reward}, done={observation_sr.done}") - if terminated: + if observation_sr.done: break - next_message = observation + next_message = observation_sr.observation.observation # Extract final reward - if info.get("reward_info"): - reward_info = RewardInfo.model_validate_json(info["reward_info"]) + if env.state.info.get("reward_info"): + reward_info = RewardInfo.model_validate_json(env.state.info["reward_info"]) return reward_info.reward - return float(reward) + return 0. if observation_sr.reward is None else float(observation_sr.reward) - def _build_task_prompt(self, info: dict, observation: str) -> str: + def _build_task_prompt(self, info: dict[Any, Any], observation: str) -> str: """Build the initial task prompt for the purple agent.""" return f""" {info["policy"]} diff --git a/scenarios/tau2/tau2_models.py b/scenarios/tau2/tau2_models.py new file mode 100644 index 0000000..ad8450a --- /dev/null +++ b/scenarios/tau2/tau2_models.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass, field +from typing import Any + +from openenv_core.env_server import Action, Observation, State + + +@dataclass +class Tau2Action(Action): + action: str + + +@dataclass +class Tau2Observation(Observation): + observation: str + + +@dataclass +class Tau2State(State): + info: dict[str, Any] = field(default_factory=dict[str, Any]) diff --git a/scenarios/tau2/tau2_server.py b/scenarios/tau2/tau2_server.py new file mode 100644 index 0000000..b373c7e --- /dev/null +++ b/scenarios/tau2/tau2_server.py @@ -0,0 +1,17 @@ +import os +import json +from openenv_core.env_server import create_fastapi_app + +from tau2_models import Tau2Action, Tau2Observation +from tau2_env import Tau2Environment + + +# https://github.com/meta-pytorch/OpenEnv/blob/fb169f8c660df722f538160b3ce636de3312a756/src/envs/README.md + + +env = Tau2Environment( + domain=os.environ.get("TAU2_DOMAIN", "airline"), + task_id=os.environ.get("TAU2_TASK_ID", "0"), + env_args=json.loads(os.environ.get("TAU2_ENV_ARGS_JSON", "{}")), +) +app = create_fastapi_app(env, Tau2Action, Tau2Observation) diff --git a/src/agentbeats/run_scenario.py b/src/agentbeats/run_scenario.py index bc937a8..fd397c1 100644 --- a/src/agentbeats/run_scenario.py +++ b/src/agentbeats/run_scenario.py @@ -4,12 +4,54 @@ from pathlib import Path import tomllib import httpx -from dotenv import load_dotenv +from dotenv import find_dotenv, load_dotenv from a2a.client import A2ACardResolver -load_dotenv(override=True) +dotenv_path = find_dotenv() +if dotenv_path: + load_dotenv(dotenv_path, override=True) + + +async def wait_for_environments(cfg: dict, timeout: int = 30) -> bool: + """Wait for all environments to be healthy and responding.""" + endpoints = [] + + # Collect environment endpoints to check + for e in cfg.get("environments", []): + if e.get("image"): # Only check if there's an image (container to start) + endpoints.append(f"http://{e['host']}:{e['port']}") + + if not endpoints: + return True # No environments to wait for + + print(f"Waiting for {len(endpoints)} environment(s) to be ready...") + start_time = time.time() + + async def check_endpoint(endpoint: str) -> bool: + """Check if an environment server is responding via /health endpoint.""" + try: + async with httpx.AsyncClient(timeout=2) as client: + response = await client.get(f"{endpoint}/health") + return response.status_code == 200 + except Exception: + return False + + while time.time() - start_time < timeout: + ready_count = 0 + for endpoint in endpoints: + if await check_endpoint(endpoint): + ready_count += 1 + + if ready_count == len(endpoints): + return True + + print(f" {ready_count}/{len(endpoints)} environments ready, waiting...") + await asyncio.sleep(1) + + print(f"Timeout: Only {ready_count}/{len(endpoints)} environments became ready after {timeout}s") + return False async def wait_for_agents(cfg: dict, timeout: int = 30) -> bool: @@ -87,8 +129,22 @@ def host_port(ep: str): "cmd": p.get("cmd", "") }) + envs = [] + for e in data.get("environments", []): + if isinstance(e, dict) and "endpoint" in e: + h, pt = host_port(e["endpoint"]) + envs.append({ + "name": str(e.get("name", "")), + "host": h, + "port": pt, + "image": e.get("image", ""), + "publishes": e.get("publishes", []), + "env": e.get("env", {}), + }) + cfg = data.get("config", {}) return { + "environments": envs, "green_agent": {"host": g_host, "port": g_port, "cmd": green_cmd}, "participants": parts, "config": cfg, @@ -112,7 +168,38 @@ def main(): base_env["PATH"] = parent_bin + os.pathsep + base_env.get("PATH", "") procs = [] + containers = [] try: + # start environments + for e in cfg.get("environments", []): + if e.get("image", ""): + container_name = f"agentbeats-{e['name']}-{int(time.time())}" + cmd_args = [ + "docker", "run", + "--rm", + "--name", container_name, + ] + for publish in e["publishes"]: + cmd_args.extend(["-p", publish]) + if dotenv_path: + cmd_args.extend(["--env-file", dotenv_path]) + for key, value in e["env"].items(): + cmd_args.extend(["-e", f"{key}={value}"]) + cmd_args.append(e["image"]) + print(f"Starting environment {e['name']} at {e['host']}:{e['port']} (container: {container_name})") + procs.append(subprocess.Popen( + cmd_args, + stdout=sink, stderr=sink, + text=True, + start_new_session=True, + )) + containers.append(container_name) + + # Wait for all environments to be ready + if not asyncio.run(wait_for_environments(cfg)): + print("Error: Not all environments became ready. Exiting.") + return + # start participant agents for p in cfg["participants"]: cmd_args = shlex.split(p.get("cmd", "")) @@ -148,7 +235,7 @@ def main(): while True: for proc in procs: if proc.poll() is not None: - print(f"Agent exited with code {proc.returncode}") + print(f"Process exited with code {proc.returncode}") break time.sleep(0.5) else: @@ -165,6 +252,14 @@ def main(): finally: print("\nShutting down...") + for c in containers: + try: + subprocess.run( + ["docker", "stop", c], + stdout=sink, stderr=sink, + ) + except Exception as e: + print(f"Error in docker stop {c}: {e}") for p in procs: if p.poll() is None: try: