From 00532dfff9cf2d86fe276411a37aae68072e31c6 Mon Sep 17 00:00:00 2001 From: Warren He Date: Tue, 9 Dec 2025 11:26:31 -0800 Subject: [PATCH 1/3] scenarios/tau2: wrap in openenv --- scenarios/tau2/tau2_env.py | 68 +++++++++++++++++++++++++++++++ scenarios/tau2/tau2_evaluator.py | 70 +++++++++++++++----------------- 2 files changed, 101 insertions(+), 37 deletions(-) create mode 100644 scenarios/tau2/tau2_env.py diff --git a/scenarios/tau2/tau2_env.py b/scenarios/tau2/tau2_env.py new file mode 100644 index 0000000..5972ffc --- /dev/null +++ b/scenarios/tau2/tau2_env.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass, field +from typing import Any +import uuid + +import gymnasium as gym + +from openenv_core.env_server import Action, Environment, Observation, State +from tau2.gym import TAU_BENCH_ENV_ID, register_gym_agent + + +# https://github.com/sierra-research/tau2-bench/blob/main/src/tau2/gym/README.md +# https://github.com/meta-pytorch/OpenEnv/blob/fb169f8c660df722f538160b3ce636de3312a756/src/envs/README.md + + +register_gym_agent() + + +@dataclass +class Tau2Action(Action): + action: str + + +@dataclass +class Tau2Observation(Observation): + observation: str + + +@dataclass +class Tau2State(State): + info: dict[str, Any] = field(default_factory=dict[str, Any]) + + +class Tau2Environment(Environment): + def __init__( + self, + domain: str, + task_id: str, + env_args: Any, + ): + super().__init__() + self._state = Tau2State() + self._gym_env: gym.Env[str, str] = gym.make( + TAU_BENCH_ENV_ID, + domain=domain, + task_id=task_id, + **env_args, + ) + + def reset(self) -> Tau2Observation: + self._state = Tau2State(episode_id=str(uuid.uuid4())) + observation, info = self._gym_env.reset() + self._state.info = info + return Tau2Observation(observation=observation) + + def step(self, action: Action) -> Tau2Observation: + assert isinstance(action, Tau2Action) + self._state.step_count += 1 + observation, reward, terminated, truncated, info = self._gym_env.step(action.action) + self._state.info = info + return Tau2Observation( + observation=observation, + done=terminated or truncated, + reward=float(reward), + ) + + @property + def state(self) -> Tau2State: + return self._state diff --git a/scenarios/tau2/tau2_evaluator.py b/scenarios/tau2/tau2_evaluator.py index 42d6a00..e18d5b3 100644 --- a/scenarios/tau2/tau2_evaluator.py +++ b/scenarios/tau2/tau2_evaluator.py @@ -14,7 +14,6 @@ import time from typing import Any, Optional -import gymnasium as gym import uvicorn from dotenv import load_dotenv @@ -40,17 +39,15 @@ from tau2.data_model.simulation import RewardInfo from tau2.environment.tool import Tool -from tau2.gym import TAU_BENCH_ENV_ID, register_gym_agent from tau2.run import get_tasks +from tau2_env import Tau2Environment, Tau2Action + logging.basicConfig(level=logging.INFO) logger = logging.getLogger("tau2_evaluator") RESPOND_ACTION_NAME = "respond" -# Register tau-bench gym environments -register_gym_agent() - def tools_to_str(tools: list[Tool]) -> str: """Convert tau-bench tools to JSON schema format.""" @@ -93,19 +90,19 @@ def validate_request(self, request: EvalRequest) -> tuple[bool, str]: return False, f"Missing config keys: {missing_config_keys}" return True, "ok" - async def run_eval(self, req: EvalRequest, updater: TaskUpdater) -> None: - logger.info(f"Starting tau2 evaluation: {req}") + async def run_eval(self, request: EvalRequest, updater: TaskUpdater) -> None: + logger.info(f"Starting tau2 evaluation: {request}") start_time = time.time() - domain = req.config["domain"] - task_ids = req.config.get("task_ids", None) - num_tasks = req.config.get("num_tasks", None) - max_steps = req.config.get("max_steps", 200) - user_llm = req.config.get("user_llm", "openai/gpt-4o") - user_llm_args = req.config.get("user_llm_args", {"temperature": 0.0}) + domain = request.config["domain"] + task_ids = request.config.get("task_ids", None) + num_tasks = request.config.get("num_tasks", None) + max_steps = request.config.get("max_steps", 200) + user_llm = request.config.get("user_llm", "openai/gpt-4o") + user_llm_args = request.config.get("user_llm_args", {"temperature": 0.0}) # Get the purple agent URL - agent_url = str(req.participants["agent"]) + agent_url = str(request.participants["agent"]) # Get task IDs resolved_task_ids = get_task_ids(domain, task_ids, num_tasks) @@ -146,7 +143,7 @@ async def run_eval(self, req: EvalRequest, updater: TaskUpdater) -> None: num_completed = len(metrics["tasks"]) pass_rate = (total_reward / num_completed * 100) if num_completed > 0 else 0 - result_data = { + result_data: dict[str, Any] = { "domain": domain, "score": total_reward, "max_score": num_completed, @@ -188,35 +185,34 @@ async def _run_single_task( task_id: str, max_steps: int, user_llm: str, - user_llm_args: dict, + user_llm_args: dict[Any, Any], ) -> float: """Run a single tau-bench task and return the reward.""" - env = gym.make( - TAU_BENCH_ENV_ID, + env = Tau2Environment( domain=domain, task_id=task_id, - max_steps=max_steps, - user_llm=user_llm, - user_llm_args=user_llm_args, - all_messages_as_observation=False, + env_args={ + "max_steps": max_steps, + "user_llm": user_llm, + "user_llm_args": user_llm_args, + }, ) - terminated = False - observation, info = env.reset() + observation = env.reset() # Build the initial task description for the purple agent - task_description = self._build_task_prompt(info, observation) + task_description = self._build_task_prompt(env.state.info, observation.observation) # Start a new conversation with the purple agent next_message = task_description is_first_message = True - while not terminated: + while not observation.done: logger.debug(f"Sending to purple agent: {next_message[:200]}...") # Send message to purple agent - response = await self._tool_provider.talk_to_agent( + response: str = await self._tool_provider.talk_to_agent( message=next_message, url=agent_url, new_conversation=is_first_message, @@ -227,28 +223,28 @@ async def _run_single_task( # Parse the purple agent's action try: - action = self._parse_agent_response(response) + action = Tau2Action(action=self._parse_agent_response(response)) except Exception as e: logger.error(f"Failed to parse agent response: {e}") # When parsing fails, respond with error as plain text (not a tool call) - action = "I encountered an error processing the request." + action = Tau2Action(action="I encountered an error processing the request.") # Step the environment with either a JSON string (tool call) or plain text (user response) - observation, reward, terminated, truncated, info = env.step(action) - logger.debug(f"Environment step: reward={reward}, terminated={terminated}") + observation = env.step(action) + logger.debug(f"Environment step: reward={observation.reward}, done={observation.done}") - if terminated: + if observation.done: break - next_message = observation + next_message = observation.observation # Extract final reward - if info.get("reward_info"): - reward_info = RewardInfo.model_validate_json(info["reward_info"]) + if env.state.info.get("reward_info"): + reward_info = RewardInfo.model_validate_json(env.state.info["reward_info"]) return reward_info.reward - return float(reward) + return 0. if observation.reward is None else float(observation.reward) - def _build_task_prompt(self, info: dict, observation: str) -> str: + def _build_task_prompt(self, info: dict[Any, Any], observation: str) -> str: """Build the initial task prompt for the purple agent.""" return f""" {info["policy"]} From 49b2b6cf3a18e746d96818b0d56dd772edffd108 Mon Sep 17 00:00:00 2001 From: Warren He Date: Wed, 10 Dec 2025 15:58:00 -0800 Subject: [PATCH 2/3] use openenv client+server --- scenarios/debate/Dockerfile.adk-debate-judge | 2 +- scenarios/debate/Dockerfile.debate-judge | 2 +- scenarios/debate/Dockerfile.debater | 2 +- scenarios/tau2/Dockerfile.tau2-agent | 2 +- scenarios/tau2/Dockerfile.tau2-env | 30 ++++++++ scenarios/tau2/Dockerfile.tau2-evaluator | 13 +--- scenarios/tau2/requirements.txt | 2 + scenarios/tau2/scenario.toml | 5 ++ scenarios/tau2/tau2_client.py | 25 +++++++ scenarios/tau2/tau2_env.py | 17 +---- scenarios/tau2/tau2_evaluator.py | 29 +++----- scenarios/tau2/tau2_models.py | 19 ++++++ scenarios/tau2/tau2_server.py | 17 +++++ src/agentbeats/run_scenario.py | 72 +++++++++++++++++++- 14 files changed, 187 insertions(+), 50 deletions(-) create mode 100644 scenarios/tau2/Dockerfile.tau2-env create mode 100644 scenarios/tau2/requirements.txt create mode 100644 scenarios/tau2/tau2_client.py create mode 100644 scenarios/tau2/tau2_models.py create mode 100644 scenarios/tau2/tau2_server.py diff --git a/scenarios/debate/Dockerfile.adk-debate-judge b/scenarios/debate/Dockerfile.adk-debate-judge index 34cfa15..5d493ee 100644 --- a/scenarios/debate/Dockerfile.adk-debate-judge +++ b/scenarios/debate/Dockerfile.adk-debate-judge @@ -11,7 +11,7 @@ RUN \ --mount=type=cache,target=/home/agentbeats/.cache/uv,uid=1000 \ uv sync --locked -COPY scenarios scenarios +COPY scenarios/debate scenarios/debate ENTRYPOINT ["uv", "run", "scenarios/debate/adk_debate_judge.py"] CMD ["--host", "0.0.0.0"] diff --git a/scenarios/debate/Dockerfile.debate-judge b/scenarios/debate/Dockerfile.debate-judge index 72e9226..e3e0f73 100644 --- a/scenarios/debate/Dockerfile.debate-judge +++ b/scenarios/debate/Dockerfile.debate-judge @@ -11,7 +11,7 @@ RUN \ --mount=type=cache,target=/home/agentbeats/.cache/uv,uid=1000 \ uv sync --locked -COPY scenarios scenarios +COPY scenarios/debate scenarios/debate ENTRYPOINT ["uv", "run", "scenarios/debate/debate_judge.py"] CMD ["--host", "0.0.0.0"] diff --git a/scenarios/debate/Dockerfile.debater b/scenarios/debate/Dockerfile.debater index f2877b9..89ebc30 100644 --- a/scenarios/debate/Dockerfile.debater +++ b/scenarios/debate/Dockerfile.debater @@ -11,7 +11,7 @@ RUN \ --mount=type=cache,target=/home/agentbeats/.cache/uv,uid=1000 \ uv sync --locked -COPY scenarios scenarios +COPY scenarios/debate scenarios/debate ENTRYPOINT ["uv", "run", "scenarios/debate/debater.py"] CMD ["--host", "0.0.0.0"] diff --git a/scenarios/tau2/Dockerfile.tau2-agent b/scenarios/tau2/Dockerfile.tau2-agent index a542ac5..4699eda 100644 --- a/scenarios/tau2/Dockerfile.tau2-agent +++ b/scenarios/tau2/Dockerfile.tau2-agent @@ -12,7 +12,7 @@ RUN \ --mount=type=cache,target=/home/agentbeats/.cache/uv,uid=1000 \ uv sync --locked -COPY scenarios scenarios +COPY scenarios/tau2 scenarios/tau2 ENTRYPOINT ["uv", "run", "scenarios/tau2/tau2_agent.py"] CMD ["--host", "0.0.0.0"] diff --git a/scenarios/tau2/Dockerfile.tau2-env b/scenarios/tau2/Dockerfile.tau2-env new file mode 100644 index 0000000..68f8816 --- /dev/null +++ b/scenarios/tau2/Dockerfile.tau2-env @@ -0,0 +1,30 @@ +# https://github.com/meta-pytorch/OpenEnv/blob/fb169f8c660df722f538160b3ce636de3312a756/src/envs/README.md + +# Accept base image as build argument for CI/CD flexibility +ARG BASE_IMAGE=openenv-base:latest +FROM ${BASE_IMAGE} + +# Install dependencies +COPY scenarios/tau2/requirements.txt /tmp/requirements.txt +RUN \ + apt-get update && \ + apt-get install -y git && \ + pip install --no-cache-dir -r /tmp/requirements.txt && rm /tmp/requirements.txt + +# Download tau2 data +RUN git clone --depth 1 --filter=blob:none --sparse https://github.com/sierra-research/tau2-bench.git /app/scenarios/tau2/tau2-bench && \ + cd /app/scenarios/tau2/tau2-bench && \ + git sparse-checkout set data + +ENV TAU2_DATA_DIR=/app/scenarios/tau2/tau2-bench/data + +# Copy environment code +COPY scenarios/tau2/ /app/scenarios/tau2/ + +# Health check +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8000/health || exit 1 + +# Run server +# https://github.com/meta-pytorch/OpenEnv/issues/244 +CMD ["python", "-m", "uvicorn", "tau2_server:app", "--host", "0.0.0.0", "--port", "8000", "--app-dir", "scenarios/tau2"] diff --git a/scenarios/tau2/Dockerfile.tau2-evaluator b/scenarios/tau2/Dockerfile.tau2-evaluator index dec15ee..a061288 100644 --- a/scenarios/tau2/Dockerfile.tau2-evaluator +++ b/scenarios/tau2/Dockerfile.tau2-evaluator @@ -17,18 +17,7 @@ RUN \ --mount=type=cache,target=/home/agentbeats/.cache/uv,uid=1000 \ uv pip install "tau2 @ git+https://github.com/sierra-research/tau2-bench.git" -# Download tau2 data -USER root -RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* -USER agentbeats - -RUN git clone --depth 1 --filter=blob:none --sparse https://github.com/sierra-research/tau2-bench.git /home/agentbeats/tau2-bench && \ - cd /home/agentbeats/tau2-bench && \ - git sparse-checkout set data - -ENV TAU2_DATA_DIR=/home/agentbeats/tau2-bench/data - -COPY scenarios scenarios +COPY scenarios/tau2 scenarios/tau2 ENTRYPOINT ["uv", "run", "scenarios/tau2/tau2_evaluator.py"] CMD ["--host", "0.0.0.0"] diff --git a/scenarios/tau2/requirements.txt b/scenarios/tau2/requirements.txt new file mode 100644 index 0000000..c91947a --- /dev/null +++ b/scenarios/tau2/requirements.txt @@ -0,0 +1,2 @@ +openenv-core>=0.1.1 +tau2 @ git+https://github.com/sierra-research/tau2-bench.git diff --git a/scenarios/tau2/scenario.toml b/scenarios/tau2/scenario.toml index d71cf66..1467cb8 100644 --- a/scenarios/tau2/scenario.toml +++ b/scenarios/tau2/scenario.toml @@ -7,6 +7,11 @@ role = "agent" endpoint = "http://127.0.0.1:9019" cmd = "python scenarios/tau2/tau2_agent.py --host 127.0.0.1 --port 9019" +[[environments]] +name = "tau2" +endpoint = "http://127.0.0.1:8000" +cmd = "uvicorn tau2_server:app --host 127.0.0.1 --port 8000 --app-dir scenarios/tau2" + [config] domain = "airline" num_tasks = 3 diff --git a/scenarios/tau2/tau2_client.py b/scenarios/tau2/tau2_client.py new file mode 100644 index 0000000..4192dac --- /dev/null +++ b/scenarios/tau2/tau2_client.py @@ -0,0 +1,25 @@ +from typing import Any + +from openenv_core.client_types import StepResult +from openenv_core.http_env_client import HTTPEnvClient + +from tau2_models import Tau2Action, Tau2Observation, Tau2State + + +# https://github.com/meta-pytorch/OpenEnv/blob/fb169f8c660df722f538160b3ce636de3312a756/src/envs/README.md + + +class Tau2Env(HTTPEnvClient[Tau2Action, Tau2Observation]): + def _step_payload(self, action: Tau2Action) -> dict[str, Any]: + return {"action": action.action} + + def _parse_result(self, payload: dict[str, Any]) -> StepResult[Tau2Observation]: + obs = Tau2Observation(**payload["observation"]) + return StepResult( + observation=obs, + reward=payload.get("reward"), + done=payload.get("done", False), + ) + + def _parse_state(self, payload: dict[str, Any]) -> Tau2State: + return Tau2State(**payload) diff --git a/scenarios/tau2/tau2_env.py b/scenarios/tau2/tau2_env.py index 5972ffc..1c3d33d 100644 --- a/scenarios/tau2/tau2_env.py +++ b/scenarios/tau2/tau2_env.py @@ -7,6 +7,8 @@ from openenv_core.env_server import Action, Environment, Observation, State from tau2.gym import TAU_BENCH_ENV_ID, register_gym_agent +from tau2_models import Tau2Action, Tau2Observation, Tau2State + # https://github.com/sierra-research/tau2-bench/blob/main/src/tau2/gym/README.md # https://github.com/meta-pytorch/OpenEnv/blob/fb169f8c660df722f538160b3ce636de3312a756/src/envs/README.md @@ -15,21 +17,6 @@ register_gym_agent() -@dataclass -class Tau2Action(Action): - action: str - - -@dataclass -class Tau2Observation(Observation): - observation: str - - -@dataclass -class Tau2State(State): - info: dict[str, Any] = field(default_factory=dict[str, Any]) - - class Tau2Environment(Environment): def __init__( self, diff --git a/scenarios/tau2/tau2_evaluator.py b/scenarios/tau2/tau2_evaluator.py index e18d5b3..d4160f4 100644 --- a/scenarios/tau2/tau2_evaluator.py +++ b/scenarios/tau2/tau2_evaluator.py @@ -41,7 +41,8 @@ from tau2.environment.tool import Tool from tau2.run import get_tasks -from tau2_env import Tau2Environment, Tau2Action +from tau2_client import Tau2Env +from tau2_models import Tau2Action logging.basicConfig(level=logging.INFO) logger = logging.getLogger("tau2_evaluator") @@ -189,26 +190,18 @@ async def _run_single_task( ) -> float: """Run a single tau-bench task and return the reward.""" - env = Tau2Environment( - domain=domain, - task_id=task_id, - env_args={ - "max_steps": max_steps, - "user_llm": user_llm, - "user_llm_args": user_llm_args, - }, - ) + env = Tau2Env("http://localhost:8000") - observation = env.reset() + observation_sr = env.reset() # Build the initial task description for the purple agent - task_description = self._build_task_prompt(env.state.info, observation.observation) + task_description = self._build_task_prompt(env.state.info, observation_sr.observation.observation) # Start a new conversation with the purple agent next_message = task_description is_first_message = True - while not observation.done: + while not observation_sr.done: logger.debug(f"Sending to purple agent: {next_message[:200]}...") # Send message to purple agent @@ -230,19 +223,19 @@ async def _run_single_task( action = Tau2Action(action="I encountered an error processing the request.") # Step the environment with either a JSON string (tool call) or plain text (user response) - observation = env.step(action) - logger.debug(f"Environment step: reward={observation.reward}, done={observation.done}") + observation_sr = env.step(action) + logger.debug(f"Environment step: reward={observation_sr.reward}, done={observation_sr.done}") - if observation.done: + if observation_sr.done: break - next_message = observation.observation + next_message = observation_sr.observation.observation # Extract final reward if env.state.info.get("reward_info"): reward_info = RewardInfo.model_validate_json(env.state.info["reward_info"]) return reward_info.reward - return 0. if observation.reward is None else float(observation.reward) + return 0. if observation_sr.reward is None else float(observation_sr.reward) def _build_task_prompt(self, info: dict[Any, Any], observation: str) -> str: """Build the initial task prompt for the purple agent.""" diff --git a/scenarios/tau2/tau2_models.py b/scenarios/tau2/tau2_models.py new file mode 100644 index 0000000..ad8450a --- /dev/null +++ b/scenarios/tau2/tau2_models.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass, field +from typing import Any + +from openenv_core.env_server import Action, Observation, State + + +@dataclass +class Tau2Action(Action): + action: str + + +@dataclass +class Tau2Observation(Observation): + observation: str + + +@dataclass +class Tau2State(State): + info: dict[str, Any] = field(default_factory=dict[str, Any]) diff --git a/scenarios/tau2/tau2_server.py b/scenarios/tau2/tau2_server.py new file mode 100644 index 0000000..b373c7e --- /dev/null +++ b/scenarios/tau2/tau2_server.py @@ -0,0 +1,17 @@ +import os +import json +from openenv_core.env_server import create_fastapi_app + +from tau2_models import Tau2Action, Tau2Observation +from tau2_env import Tau2Environment + + +# https://github.com/meta-pytorch/OpenEnv/blob/fb169f8c660df722f538160b3ce636de3312a756/src/envs/README.md + + +env = Tau2Environment( + domain=os.environ.get("TAU2_DOMAIN", "airline"), + task_id=os.environ.get("TAU2_TASK_ID", "0"), + env_args=json.loads(os.environ.get("TAU2_ENV_ARGS_JSON", "{}")), +) +app = create_fastapi_app(env, Tau2Action, Tau2Observation) diff --git a/src/agentbeats/run_scenario.py b/src/agentbeats/run_scenario.py index bc937a8..834d63f 100644 --- a/src/agentbeats/run_scenario.py +++ b/src/agentbeats/run_scenario.py @@ -12,6 +12,46 @@ load_dotenv(override=True) +async def wait_for_environments(cfg: dict, timeout: int = 30) -> bool: + """Wait for all environments to be healthy and responding.""" + endpoints = [] + + # Collect environment endpoints to check + for e in cfg.get("environments", []): + if e.get("cmd"): # Only check if there's a command (server to start) + endpoints.append(f"http://{e['host']}:{e['port']}") + + if not endpoints: + return True # No environments to wait for + + print(f"Waiting for {len(endpoints)} environment(s) to be ready...") + start_time = time.time() + + async def check_endpoint(endpoint: str) -> bool: + """Check if an environment server is responding via /health endpoint.""" + try: + async with httpx.AsyncClient(timeout=2) as client: + response = await client.get(f"{endpoint}/health") + return response.status_code == 200 + except Exception: + return False + + while time.time() - start_time < timeout: + ready_count = 0 + for endpoint in endpoints: + if await check_endpoint(endpoint): + ready_count += 1 + + if ready_count == len(endpoints): + return True + + print(f" {ready_count}/{len(endpoints)} environments ready, waiting...") + await asyncio.sleep(1) + + print(f"Timeout: Only {ready_count}/{len(endpoints)} environments became ready after {timeout}s") + return False + + async def wait_for_agents(cfg: dict, timeout: int = 30) -> bool: """Wait for all agents to be healthy and responding.""" endpoints = [] @@ -87,8 +127,20 @@ def host_port(ep: str): "cmd": p.get("cmd", "") }) + envs = [] + for e in data.get("environments", []): + if isinstance(e, dict) and "endpoint" in e: + h, pt = host_port(e["endpoint"]) + envs.append({ + "name": str(e.get("name", "")), + "host": h, + "port": pt, + "cmd": e.get("cmd", "") + }) + cfg = data.get("config", {}) return { + "environments": envs, "green_agent": {"host": g_host, "port": g_port, "cmd": green_cmd}, "participants": parts, "config": cfg, @@ -113,6 +165,24 @@ def main(): procs = [] try: + # start environments + for e in cfg.get("environments", []): + cmd_args = shlex.split(e.get("cmd", "")) + if cmd_args: + print(f"Starting environment {e.get('name', '')} at {e['host']}:{e['port']}") + procs.append(subprocess.Popen( + cmd_args, + env=base_env, + stdout=sink, stderr=sink, + text=True, + start_new_session=True, + )) + + # Wait for all environments to be ready + if not asyncio.run(wait_for_environments(cfg)): + print("Error: Not all environments became ready. Exiting.") + return + # start participant agents for p in cfg["participants"]: cmd_args = shlex.split(p.get("cmd", "")) @@ -148,7 +218,7 @@ def main(): while True: for proc in procs: if proc.poll() is not None: - print(f"Agent exited with code {proc.returncode}") + print(f"Process exited with code {proc.returncode}") break time.sleep(0.5) else: From a54c47b2def48244e5142924c9dd38b117fe16e0 Mon Sep 17 00:00:00 2001 From: Warren He Date: Thu, 11 Dec 2025 14:14:11 -0800 Subject: [PATCH 3/3] use openenv container --- scenarios/tau2/scenario.toml | 7 +++++- src/agentbeats/run_scenario.py | 41 +++++++++++++++++++++++++++------- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/scenarios/tau2/scenario.toml b/scenarios/tau2/scenario.toml index 1467cb8..43c1f10 100644 --- a/scenarios/tau2/scenario.toml +++ b/scenarios/tau2/scenario.toml @@ -10,7 +10,12 @@ cmd = "python scenarios/tau2/tau2_agent.py --host 127.0.0.1 --port 9019" [[environments]] name = "tau2" endpoint = "http://127.0.0.1:8000" -cmd = "uvicorn tau2_server:app --host 127.0.0.1 --port 8000 --app-dir scenarios/tau2" +image = "tau2-env:latest" +publishes = ["127.0.0.1:8000:8000"] +[environments.env] +TAU2_DOMAIN = "airline" +TAU2_TASK_ID = "0" +TAU2_ENV_ARGS_JSON = "{}" [config] domain = "airline" diff --git a/src/agentbeats/run_scenario.py b/src/agentbeats/run_scenario.py index 834d63f..fd397c1 100644 --- a/src/agentbeats/run_scenario.py +++ b/src/agentbeats/run_scenario.py @@ -4,12 +4,14 @@ from pathlib import Path import tomllib import httpx -from dotenv import load_dotenv +from dotenv import find_dotenv, load_dotenv from a2a.client import A2ACardResolver -load_dotenv(override=True) +dotenv_path = find_dotenv() +if dotenv_path: + load_dotenv(dotenv_path, override=True) async def wait_for_environments(cfg: dict, timeout: int = 30) -> bool: @@ -18,7 +20,7 @@ async def wait_for_environments(cfg: dict, timeout: int = 30) -> bool: # Collect environment endpoints to check for e in cfg.get("environments", []): - if e.get("cmd"): # Only check if there's a command (server to start) + if e.get("image"): # Only check if there's an image (container to start) endpoints.append(f"http://{e['host']}:{e['port']}") if not endpoints: @@ -135,7 +137,9 @@ def host_port(ep: str): "name": str(e.get("name", "")), "host": h, "port": pt, - "cmd": e.get("cmd", "") + "image": e.get("image", ""), + "publishes": e.get("publishes", []), + "env": e.get("env", {}), }) cfg = data.get("config", {}) @@ -164,19 +168,32 @@ def main(): base_env["PATH"] = parent_bin + os.pathsep + base_env.get("PATH", "") procs = [] + containers = [] try: # start environments for e in cfg.get("environments", []): - cmd_args = shlex.split(e.get("cmd", "")) - if cmd_args: - print(f"Starting environment {e.get('name', '')} at {e['host']}:{e['port']}") + if e.get("image", ""): + container_name = f"agentbeats-{e['name']}-{int(time.time())}" + cmd_args = [ + "docker", "run", + "--rm", + "--name", container_name, + ] + for publish in e["publishes"]: + cmd_args.extend(["-p", publish]) + if dotenv_path: + cmd_args.extend(["--env-file", dotenv_path]) + for key, value in e["env"].items(): + cmd_args.extend(["-e", f"{key}={value}"]) + cmd_args.append(e["image"]) + print(f"Starting environment {e['name']} at {e['host']}:{e['port']} (container: {container_name})") procs.append(subprocess.Popen( cmd_args, - env=base_env, stdout=sink, stderr=sink, text=True, start_new_session=True, )) + containers.append(container_name) # Wait for all environments to be ready if not asyncio.run(wait_for_environments(cfg)): @@ -235,6 +252,14 @@ def main(): finally: print("\nShutting down...") + for c in containers: + try: + subprocess.run( + ["docker", "stop", c], + stdout=sink, stderr=sink, + ) + except Exception as e: + print(f"Error in docker stop {c}: {e}") for p in procs: if p.poll() is None: try: