From 00aa735d854f56d9aaad1e41807bb75f51f28304 Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Mon, 17 Nov 2025 11:49:10 -0500 Subject: [PATCH 01/22] Adds some tool-related requirements. --- mellea/stdlib/reqlib/tools.py | 82 +++++++++++++++++++++++++ test/stdlib_basics/test_reqlib_tools.py | 10 +++ 2 files changed, 92 insertions(+) create mode 100644 mellea/stdlib/reqlib/tools.py create mode 100644 test/stdlib_basics/test_reqlib_tools.py diff --git a/mellea/stdlib/reqlib/tools.py b/mellea/stdlib/reqlib/tools.py new file mode 100644 index 00000000..c83d35bf --- /dev/null +++ b/mellea/stdlib/reqlib/tools.py @@ -0,0 +1,82 @@ +from typing import Callable, Optional +from mellea.stdlib.requirement import Context, ValidationResult, Requirement + + +def _name2str(tool_name: str | Callable) -> str: + match tool_name: + case Callable(): + return tool_name.__name__ + case str(): + return tool_name + case _: + raise TypeError(f"Expected Callable or str but found: {type(tool_name)}") + + +def uses_tool(tool_name: str | Callable, check_only=False): + """Forces the model to call a givne tool. + + Args: + tool_name: The tool that must be called; this can be either the name of the tool or the Callable for the tool. + check_only: Propagates to the Requirement. + + Use `tool_choice` if the OpenAI `tool_choice` model option is supported by your model and inference engine. + """ + tool_name = _name2str(tool_name) + + def _validate(ctx: Context): + output = ctx.last_output() + return ValidationResult(result=tool_name in output.tool_calls) + + return Requirement( + description=f"Use the {tool_name} tool.", + validation_fn=_validate, + check_only=check_only + ) + + +def tool_arg_validator( + description: str, + tool_name: Optional[str | Callable], + arg_name: str, + validation_fn: Callable, + check_only: bool=False +) -> Requirement: + """A requirement that passes only if `validation_fn` returns a True value for the *value* of the `arg_name` argument to `tool_name`. + If `tool_name` is not specified, then this requirement is enforced for *every* tool that + + Args: + description: The Requirement description. + tool_name: The (optional) tool name for . + arg_name: The argument to check. + validation_fn: A validation function for validating the value of the `arg_name` argument. + + TODO: + 1. should this be a requirement? + 2. should this be done automatically when the user provides asserts in their function body? + """ + if tool_name: + tool_name = _name2str(tool_name) + + def _validate(ctx: Context): + output = ctx.last_output() + if tool_name: + if tool_name not in output.tool_calls: + return ValidationResult(result=False, reason=f"Tool {tool_name} was not called.") + if arg_name not in output.tool_calls[tool_name].args: + return ValidationResult(result=False, reason=f"Tool {tool_name} did not call argument {arg_name}") + arg_value = output.tool_calls[tool_name].args[arg_name] + validate_result = validation_fn(arg_value) + if validate_result: + return ValidationResult(result=True) + else: + return ValidationResult(result=False, reason=f"Valiudation did not pass for {tool_name}.{arg_name}. Arg value: {arg_value}. Argument validation result: {validate_result}") + else: + for tool in ctx.last_output().tool_calls.keys(): + if arg_name in output.tool_calls[tool_name].args: + arg_value = output.tool_calls[tool_name].args[arg_name] + validate_result = validation_fn(arg_value) + if not validate_result: + return ValidationResult(result=False, reason=f"Valiudation did not pass for {tool_name}.{arg_name}. Arg value: {arg_value}. Argument validation result: {validate_result}") + return ValidationResult(result=True) + + return Requirement(description=description, validation_fn=_validate, check_only=check_only) \ No newline at end of file diff --git a/test/stdlib_basics/test_reqlib_tools.py b/test/stdlib_basics/test_reqlib_tools.py new file mode 100644 index 00000000..b7f771f2 --- /dev/null +++ b/test/stdlib_basics/test_reqlib_tools.py @@ -0,0 +1,10 @@ +import pytest +from mellea.stdlib.reqlib.tools import _name2str + +def test_name2str(): + """Test handling when no Python code is present.""" + def test123(): + pass + assert _name2str(test123) == "test123" + assert _name2str("test1234") == "test1234" + From 7d2125a9b9eecc6c94af5fe312d63603d9e979ec Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Mon, 17 Nov 2025 12:33:51 -0500 Subject: [PATCH 02/22] Fixes some bugs in requirement checkers. --- mellea/stdlib/reqlib/tools.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mellea/stdlib/reqlib/tools.py b/mellea/stdlib/reqlib/tools.py index c83d35bf..7b04e941 100644 --- a/mellea/stdlib/reqlib/tools.py +++ b/mellea/stdlib/reqlib/tools.py @@ -1,10 +1,11 @@ from typing import Callable, Optional -from mellea.stdlib.requirement import Context, ValidationResult, Requirement +from mellea.stdlib.base import Context +from mellea.stdlib.requirement import Requirement, ValidationResult def _name2str(tool_name: str | Callable) -> str: match tool_name: - case Callable(): + case tool_name if callable(tool_name): return tool_name.__name__ case str(): return tool_name @@ -72,11 +73,11 @@ def _validate(ctx: Context): return ValidationResult(result=False, reason=f"Valiudation did not pass for {tool_name}.{arg_name}. Arg value: {arg_value}. Argument validation result: {validate_result}") else: for tool in ctx.last_output().tool_calls.keys(): - if arg_name in output.tool_calls[tool_name].args: - arg_value = output.tool_calls[tool_name].args[arg_name] + if arg_name in output.tool_calls[tool].args: + arg_value = output.tool_calls[tool].args[arg_name] validate_result = validation_fn(arg_value) if not validate_result: return ValidationResult(result=False, reason=f"Valiudation did not pass for {tool_name}.{arg_name}. Arg value: {arg_value}. Argument validation result: {validate_result}") return ValidationResult(result=True) - return Requirement(description=description, validation_fn=_validate, check_only=check_only) \ No newline at end of file + return Requirement(description=description, validation_fn=_validate, check_only=check_only) From 6fa86ef3f7361007f4b5f7f803f34f5db967331e Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Tue, 18 Nov 2025 12:53:29 -0500 Subject: [PATCH 03/22] Adds a timeout argument to call_func. --- mellea/stdlib/base.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/mellea/stdlib/base.py b/mellea/stdlib/base.py index 111d44f6..58bbcace 100644 --- a/mellea/stdlib/base.py +++ b/mellea/stdlib/base.py @@ -17,6 +17,7 @@ from PIL import Image as PILImage from mellea.helpers.fancy_logger import FancyLogger +import multiprocessing.pool class CBlock: @@ -653,6 +654,14 @@ class ModelToolCall: func: Callable args: Mapping[str, Any] - def call_func(self) -> Any: - """A helper function for calling the function/tool represented by this object.""" - return self.func(**self.args) + def call_func(self, timeout: float | None=None) -> Any: + """A helper function for calling the function/tool represented by this object. + + Args: + timeout: if set, the tool call will time-out and an exception will be thrown after `timeout` seconds.""" + if timeout: + with multiprocessing.pool.ThreadPool() as pool: + result = pool.apply_async(self.func, kwds=self.args).get(timeout=timeout) + return result + else: + return self.func(**self.args) \ No newline at end of file From 523f3f20076f24874782cc317a2a584a9341ecd0 Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Thu, 20 Nov 2025 09:17:48 -0500 Subject: [PATCH 04/22] Moves Alex's code interpreters into a `tools` module. --- mellea/stdlib/reqlib/python.py | 195 --------------------- mellea/stdlib/tools/__init__.py | 1 + mellea/stdlib/tools/code_interpreter.py | 223 ++++++++++++++++++++++++ 3 files changed, 224 insertions(+), 195 deletions(-) create mode 100644 mellea/stdlib/tools/__init__.py create mode 100644 mellea/stdlib/tools/code_interpreter.py diff --git a/mellea/stdlib/reqlib/python.py b/mellea/stdlib/reqlib/python.py index 2bcab2c5..6a163033 100644 --- a/mellea/stdlib/reqlib/python.py +++ b/mellea/stdlib/reqlib/python.py @@ -15,201 +15,6 @@ logger = FancyLogger.get_logger() -# region execution backends - - -@dataclass -class ExecutionResult: - """Result of code execution.""" - - success: bool - message: str | None = None - error: str | None = None - skipped: bool = False - - -class ExecutionEnvironment(ABC): - """Abstract environment for executing Python code.""" - - def __init__(self, allowed_imports: list[str] | None = None): - """Initialize with optional import restrictions. - - Args: - allowed_imports: List of allowed import modules. None means any import is allowed. - """ - self.allowed_imports = allowed_imports - - @abstractmethod - def execute(self, code: str, timeout: int) -> ExecutionResult: - """Execute code and return result.""" - - -class SafeEnvironment(ExecutionEnvironment): - """Safe environment that validates but does not execute code.""" - - def execute(self, code: str, timeout: int) -> ExecutionResult: - """Validate code syntax and imports without executing.""" - try: - ast.parse(code) - except SyntaxError as e: - return ExecutionResult(success=False, error=str(e)) - - if self.allowed_imports: - unauthorized = _get_unauthorized_imports(code, self.allowed_imports) - if unauthorized: - return ExecutionResult( - success=False, - error=f"Unauthorized imports detected: {', '.join(unauthorized)}", - ) - - return ExecutionResult( - success=True, - skipped=True, - message="Code validated but not executed (safe mode)", - ) - - -class UnsafeEnvironment(ExecutionEnvironment): - """Unsafe environment that executes code directly with subprocess.""" - - def execute(self, code: str, timeout: int) -> ExecutionResult: - """Execute code with subprocess after checking imports.""" - if self.allowed_imports: - unauthorized = _get_unauthorized_imports(code, self.allowed_imports) - if unauthorized: - return ExecutionResult( - success=False, - error=f"Unauthorized imports detected: {', '.join(unauthorized)}", - ) - - return self._execute_subprocess(code, timeout) - - def _execute_subprocess(self, code: str, timeout: int) -> ExecutionResult: - with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: - f.write(code) - temp_file = f.name - - try: - # Execute code using the same Python interpreter and environment as the current process - # This ensures the code has access to all installed packages and dependencies - result = subprocess.run( - [sys.executable, temp_file], - capture_output=True, - text=True, - timeout=timeout, - ) - - if result.returncode == 0: - message = "Code executed successfully" - if result.stdout.strip(): - message += f"\nOutput: {result.stdout.strip()}" - return ExecutionResult(success=True, message=message) - else: - return ExecutionResult( - success=False, - error=f"Execution failed with error: {result.stderr[:200]}", - ) - except subprocess.TimeoutExpired: - return ExecutionResult( - success=False, error=f"Execution timed out after {timeout} seconds" - ) - except Exception as e: - return ExecutionResult(success=False, error=f"Execution error: {e!s}") - finally: - try: - Path(temp_file).unlink() - except Exception: - pass - - -class LLMSandboxEnvironment(ExecutionEnvironment): - """Environment using llm-sandbox for secure Docker-based execution.""" - - def execute(self, code: str, timeout: int) -> ExecutionResult: - """Execute code using llm-sandbox.""" - if self.allowed_imports: - unauthorized = _get_unauthorized_imports(code, self.allowed_imports) - if unauthorized: - return ExecutionResult( - success=False, - error=f"Unauthorized imports detected: {', '.join(unauthorized)}", - ) - - try: - from llm_sandbox import SandboxSession - except ImportError: - return ExecutionResult( - success=False, - error="llm-sandbox not installed. Install with: uv add 'llm-sandbox[docker]'", - ) - - try: - with SandboxSession( - lang="python", verbose=False, keep_template=False - ) as session: - result = session.run(code, timeout=timeout) - - if result.exit_code == 0: - message = "Code executed successfully in sandbox" - if ( - hasattr(result, "stdout") - and result.stdout - and result.stdout.strip() - ): - message += f"\nOutput: {result.stdout.strip()}" - return ExecutionResult(success=True, message=message) - else: - if result.stderr: - error_msg = f"Sandbox execution failed: {result.stderr[:200]}" - else: - # Log unknown error details for debugging - logger.warning( - f"Sandbox execution failed without stderr. Exit code: {result.exit_code}, " - f"Available attributes: {[attr for attr in dir(result) if not attr.startswith('_')]}" - ) - error_msg = f"Sandbox execution failed with exit code {result.exit_code} (no error details available)" - return ExecutionResult(success=False, error=error_msg) - - except Exception as e: - return ExecutionResult( - success=False, error=f"Sandbox execution error: {e!s}" - ) - - -def _get_unauthorized_imports(code: str, allowed_imports: list[str]) -> list[str]: - """Get list of unauthorized imports used in code.""" - unauthorized: list[str] = [] - try: - tree = ast.parse(code) - except SyntaxError: - return unauthorized - - for node in ast.walk(tree): - if isinstance(node, ast.Import): - for alias in node.names: - base_module = alias.name.split(".")[0] - if ( - base_module not in allowed_imports - and base_module not in unauthorized - ): - unauthorized.append(base_module) - elif isinstance(node, ast.ImportFrom): - if node.module: - base_module = node.module.split(".")[0] - if ( - base_module not in allowed_imports - and base_module not in unauthorized - ): - unauthorized.append(base_module) - return unauthorized - - -def _check_allowed_imports(code: str, allowed_imports: list[str]) -> bool: - """Check if code only uses allowed imports.""" - return len(_get_unauthorized_imports(code, allowed_imports)) == 0 - - -# endregion # region code extraction diff --git a/mellea/stdlib/tools/__init__.py b/mellea/stdlib/tools/__init__.py new file mode 100644 index 00000000..b6afdefd --- /dev/null +++ b/mellea/stdlib/tools/__init__.py @@ -0,0 +1 @@ +from .code_interpreter import code_interpreter, local_code_interpreter \ No newline at end of file diff --git a/mellea/stdlib/tools/code_interpreter.py b/mellea/stdlib/tools/code_interpreter.py new file mode 100644 index 00000000..ac3ffa5e --- /dev/null +++ b/mellea/stdlib/tools/code_interpreter.py @@ -0,0 +1,223 @@ +from dataclasses import dataclass +from abc import ABC, abstractmethod +import ast +import tempfile +import subprocess +import sys +from pathlib import Path +from mellea.helpers.fancy_logger import FancyLogger +from mellea.stdlib.base import Context +from mellea.stdlib.requirement import Requirement, ValidationResult + +logger = FancyLogger.get_logger() + + +@dataclass +class ExecutionResult: + """Result of code execution.""" + + success: bool + message: str | None = None + error: str | None = None + skipped: bool = False + + +class ExecutionEnvironment(ABC): + """Abstract environment for executing Python code.""" + + def __init__(self, allowed_imports: list[str] | None = None): + """Initialize with optional import restrictions. + + Args: + allowed_imports: List of allowed import modules. None means any import is allowed. + """ + self.allowed_imports = allowed_imports + + @abstractmethod + def execute(self, code: str, timeout: int) -> ExecutionResult: + """Execute code and return result.""" + + +class SafeEnvironment(ExecutionEnvironment): + """Safe environment that validates but does not execute code.""" + + def execute(self, code: str, timeout: int) -> ExecutionResult: + """Validate code syntax and imports without executing.""" + try: + ast.parse(code) + except SyntaxError as e: + return ExecutionResult(success=False, error=str(e)) + + if self.allowed_imports: + unauthorized = _get_unauthorized_imports(code, self.allowed_imports) + if unauthorized: + return ExecutionResult( + success=False, + error=f"Unauthorized imports detected: {', '.join(unauthorized)}", + ) + + return ExecutionResult( + success=True, + skipped=True, + message="Code validated but not executed (safe mode)", + ) + + +class UnsafeEnvironment(ExecutionEnvironment): + """Unsafe environment that executes code directly with subprocess.""" + + def execute(self, code: str, timeout: int) -> ExecutionResult: + """Execute code with subprocess after checking imports.""" + if self.allowed_imports: + unauthorized = _get_unauthorized_imports(code, self.allowed_imports) + if unauthorized: + return ExecutionResult( + success=False, + error=f"Unauthorized imports detected: {', '.join(unauthorized)}", + ) + + return self._execute_subprocess(code, timeout) + + def _execute_subprocess(self, code: str, timeout: int) -> ExecutionResult: + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write(code) + temp_file = f.name + + try: + # Execute code using the same Python interpreter and environment as the current process + # This ensures the code has access to all installed packages and dependencies + result = subprocess.run( + [sys.executable, temp_file], + capture_output=True, + text=True, + timeout=timeout, + ) + + if result.returncode == 0: + message = "Code executed successfully" + if result.stdout.strip(): + message += f"\nOutput: {result.stdout.strip()}" + return ExecutionResult(success=True, message=message) + else: + return ExecutionResult( + success=False, + error=f"Execution failed with error: {result.stderr[:200]}", + ) + except subprocess.TimeoutExpired: + return ExecutionResult( + success=False, error=f"Execution timed out after {timeout} seconds" + ) + except Exception as e: + return ExecutionResult(success=False, error=f"Execution error: {e!s}") + finally: + try: + Path(temp_file).unlink() + except Exception: + pass + + +class LLMSandboxEnvironment(ExecutionEnvironment): + """Environment using llm-sandbox for secure Docker-based execution.""" + + def execute(self, code: str, timeout: int) -> ExecutionResult: + """Execute code using llm-sandbox.""" + if self.allowed_imports: + unauthorized = _get_unauthorized_imports(code, self.allowed_imports) + if unauthorized: + return ExecutionResult( + success=False, + error=f"Unauthorized imports detected: {', '.join(unauthorized)}", + ) + + try: + from llm_sandbox import SandboxSession + except ImportError: + return ExecutionResult( + success=False, + error="llm-sandbox not installed. Install with: uv add 'llm-sandbox[docker]'", + ) + + try: + with SandboxSession( + lang="python", verbose=False, keep_template=False + ) as session: + result = session.run(code, timeout=timeout) + + if result.exit_code == 0: + message = "Code executed successfully in sandbox" + if ( + hasattr(result, "stdout") + and result.stdout + and result.stdout.strip() + ): + message += f"\nOutput: {result.stdout.strip()}" + return ExecutionResult(success=True, message=message) + else: + if result.stderr: + error_msg = f"Sandbox execution failed: {result.stderr[:200]}" + else: + # Log unknown error details for debugging + logger.warning( + f"Sandbox execution failed without stderr. Exit code: {result.exit_code}, " + f"Available attributes: {[attr for attr in dir(result) if not attr.startswith('_')]}" + ) + error_msg = f"Sandbox execution failed with exit code {result.exit_code} (no error details available)" + return ExecutionResult(success=False, error=error_msg) + + except Exception as e: + return ExecutionResult( + success=False, error=f"Sandbox execution error: {e!s}" + ) + + +def _get_unauthorized_imports(code: str, allowed_imports: list[str]) -> list[str]: + """Get list of unauthorized imports used in code.""" + unauthorized: list[str] = [] + try: + tree = ast.parse(code) + except SyntaxError: + return unauthorized + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + base_module = alias.name.split(".")[0] + if ( + base_module not in allowed_imports + and base_module not in unauthorized + ): + unauthorized.append(base_module) + elif isinstance(node, ast.ImportFrom): + if node.module: + base_module = node.module.split(".")[0] + if ( + base_module not in allowed_imports + and base_module not in unauthorized + ): + unauthorized.append(base_module) + return unauthorized + + +def _check_allowed_imports(code: str, allowed_imports: list[str]) -> bool: + """Check if code only uses allowed imports.""" + return len(_get_unauthorized_imports(code, allowed_imports)) == 0 + + +def code_interpreter(code: str): + """Executes python code. + + Args: + code: The Python code to execute. + """ + exec_env = LLMSandboxEnvironment(allowed_imports=None) + exec_env.execute(code, 60) + + +def local_code_interpreter(code: str): + """Executes python code in the cwd + + Args: + code: The Python code to execute. + """ + exec_env = UnsafeEnvironment(allowed_imports=None) + exec_env.execute(code, 60) \ No newline at end of file From 180cf12a4f002dde70ab5767ea845c302f5a7e96 Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Thu, 20 Nov 2025 09:19:39 -0500 Subject: [PATCH 05/22] Fix imports broken by previous commit. --- mellea/stdlib/reqlib/python.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mellea/stdlib/reqlib/python.py b/mellea/stdlib/reqlib/python.py index 6a163033..fd11df2d 100644 --- a/mellea/stdlib/reqlib/python.py +++ b/mellea/stdlib/reqlib/python.py @@ -12,6 +12,7 @@ from mellea.helpers.fancy_logger import FancyLogger from mellea.stdlib.base import Context from mellea.stdlib.requirement import Requirement, ValidationResult +from mellea.stdlib.tools.code_interpreter import LLMSandboxEnvironment, UnsafeEnvironment, ExecutionEnvironment, SafeEnvironment logger = FancyLogger.get_logger() From b72ac1b6bbee71041673906eeb351f7ea1336288 Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Thu, 20 Nov 2025 09:23:33 -0500 Subject: [PATCH 06/22] Changes name of the static analysis environment. --- mellea/stdlib/reqlib/python.py | 4 ++-- mellea/stdlib/tools/code_interpreter.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mellea/stdlib/reqlib/python.py b/mellea/stdlib/reqlib/python.py index fd11df2d..5605ec68 100644 --- a/mellea/stdlib/reqlib/python.py +++ b/mellea/stdlib/reqlib/python.py @@ -12,7 +12,7 @@ from mellea.helpers.fancy_logger import FancyLogger from mellea.stdlib.base import Context from mellea.stdlib.requirement import Requirement, ValidationResult -from mellea.stdlib.tools.code_interpreter import LLMSandboxEnvironment, UnsafeEnvironment, ExecutionEnvironment, SafeEnvironment +from mellea.stdlib.tools.code_interpreter import LLMSandboxEnvironment, UnsafeEnvironment, ExecutionEnvironment, StaticAnalysisEnvironment logger = FancyLogger.get_logger() @@ -134,7 +134,7 @@ def _python_executes_without_error( elif allow_unsafe: environment = UnsafeEnvironment(allowed_imports=allowed_imports) else: - environment = SafeEnvironment(allowed_imports=allowed_imports) + environment = StaticAnalysisEnvironment(allowed_imports=allowed_imports) result = environment.execute(code, timeout) return ValidationResult( diff --git a/mellea/stdlib/tools/code_interpreter.py b/mellea/stdlib/tools/code_interpreter.py index ac3ffa5e..a038f77f 100644 --- a/mellea/stdlib/tools/code_interpreter.py +++ b/mellea/stdlib/tools/code_interpreter.py @@ -38,7 +38,7 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: """Execute code and return result.""" -class SafeEnvironment(ExecutionEnvironment): +class StaticAnalysisEnvironment(ExecutionEnvironment): """Safe environment that validates but does not execute code.""" def execute(self, code: str, timeout: int) -> ExecutionResult: From 6b5f7160f3ed6c34e462193122a4b1f4796e1a93 Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Thu, 20 Nov 2025 10:00:06 -0500 Subject: [PATCH 07/22] Refactors ExecutionResult. --- mellea/stdlib/reqlib/python.py | 2 +- mellea/stdlib/tools/code_interpreter.py | 94 ++++++++++++++----------- 2 files changed, 53 insertions(+), 43 deletions(-) diff --git a/mellea/stdlib/reqlib/python.py b/mellea/stdlib/reqlib/python.py index 5605ec68..873d2046 100644 --- a/mellea/stdlib/reqlib/python.py +++ b/mellea/stdlib/reqlib/python.py @@ -138,7 +138,7 @@ def _python_executes_without_error( result = environment.execute(code, timeout) return ValidationResult( - result=result.success, reason=result.message or result.error + result=result.success, reason=result.stdout if result.success else result.stderr # TODO should we pass back both? ) diff --git a/mellea/stdlib/tools/code_interpreter.py b/mellea/stdlib/tools/code_interpreter.py index a038f77f..f90ef4fc 100644 --- a/mellea/stdlib/tools/code_interpreter.py +++ b/mellea/stdlib/tools/code_interpreter.py @@ -8,19 +8,42 @@ from mellea.helpers.fancy_logger import FancyLogger from mellea.stdlib.base import Context from mellea.stdlib.requirement import Requirement, ValidationResult +from typing import Any + logger = FancyLogger.get_logger() @dataclass class ExecutionResult: - """Result of code execution.""" + """Result of code execution. + + Code execution can be aborted prior to spinning up an interpreter (e.g., if prohibited imports are used). + In these cases, the `success` flag is set to False and the `skipped` flag is set to True. + + If code is executed, then `success` is set to true iff the exit code is 0, and the `stdout` and `stderr` outputs + are set to non-None values. + + We also use the `ExecutionResult` object to communicate the result of static and dynamic analyses. Those are passed back + using the `analysis_result` field. + + TODO: should we also be trying to pass back the value of the final expression evaluated, or the value of locals() and globals()?""" success: bool - message: str | None = None - error: str | None = None + + stdout: str | None + + stderr: str | None + + """ Indicates whether execution was skipped. """ skipped: bool = False + """ If execution is skipped, this message indicates why. """ + skip_message: str | None = None + + """ Used for returning results from static analyses. """ + analysis_result : Any | None = None + class ExecutionEnvironment(ABC): """Abstract environment for executing Python code.""" @@ -44,7 +67,7 @@ class StaticAnalysisEnvironment(ExecutionEnvironment): def execute(self, code: str, timeout: int) -> ExecutionResult: """Validate code syntax and imports without executing.""" try: - ast.parse(code) + parse_tree = ast.parse(code) except SyntaxError as e: return ExecutionResult(success=False, error=str(e)) @@ -58,8 +81,11 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: return ExecutionResult( success=True, + stdout=None, + stderr=None, skipped=True, - message="Code validated but not executed (safe mode)", + skip_message="The static analysis execution environment does not execute code. To execute code, use one of the other execution environments.", + analysis_result=parse_tree ) @@ -92,17 +118,7 @@ def _execute_subprocess(self, code: str, timeout: int) -> ExecutionResult: text=True, timeout=timeout, ) - - if result.returncode == 0: - message = "Code executed successfully" - if result.stdout.strip(): - message += f"\nOutput: {result.stdout.strip()}" - return ExecutionResult(success=True, message=message) - else: - return ExecutionResult( - success=False, - error=f"Execution failed with error: {result.stderr[:200]}", - ) + return ExecutionResult(success=result.returncode == 0, stdout=result.stdout.strip(), stderr=result.stderr.strip()) except subprocess.TimeoutExpired: return ExecutionResult( success=False, error=f"Execution timed out after {timeout} seconds" @@ -126,7 +142,10 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: if unauthorized: return ExecutionResult( success=False, - error=f"Unauthorized imports detected: {', '.join(unauthorized)}", + stdout=None, + stderr=None, + skipped=True, + skip_message=f"Unauthorized imports detected: {', '.join(unauthorized)}", ) try: @@ -134,7 +153,10 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: except ImportError: return ExecutionResult( success=False, - error="llm-sandbox not installed. Install with: uv add 'llm-sandbox[docker]'", + stdout=None, + stderr=None, + skipped=True, + skip_message="llm-sandbox not installed. Install with: uv add 'llm-sandbox[docker]'", ) try: @@ -143,30 +165,18 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: ) as session: result = session.run(code, timeout=timeout) - if result.exit_code == 0: - message = "Code executed successfully in sandbox" - if ( - hasattr(result, "stdout") - and result.stdout - and result.stdout.strip() - ): - message += f"\nOutput: {result.stdout.strip()}" - return ExecutionResult(success=True, message=message) - else: - if result.stderr: - error_msg = f"Sandbox execution failed: {result.stderr[:200]}" - else: - # Log unknown error details for debugging - logger.warning( - f"Sandbox execution failed without stderr. Exit code: {result.exit_code}, " - f"Available attributes: {[attr for attr in dir(result) if not attr.startswith('_')]}" - ) - error_msg = f"Sandbox execution failed with exit code {result.exit_code} (no error details available)" - return ExecutionResult(success=False, error=error_msg) - + return ExecutionResult( + success=result.exit_code == 0, + stdout=result.stdout.strip(), + stderr=result.stderr.strip() + ) except Exception as e: return ExecutionResult( - success=False, error=f"Sandbox execution error: {e!s}" + success=False, + stdout=None, + stderr=None, + skipped=True, + skip_message=f"Sandbox execution error: {e!s}" ) @@ -203,7 +213,7 @@ def _check_allowed_imports(code: str, allowed_imports: list[str]) -> bool: return len(_get_unauthorized_imports(code, allowed_imports)) == 0 -def code_interpreter(code: str): +def code_interpreter(code: str) -> ExecutionResult: """Executes python code. Args: @@ -213,7 +223,7 @@ def code_interpreter(code: str): exec_env.execute(code, 60) -def local_code_interpreter(code: str): +def local_code_interpreter(code: str) -> ExecutionResult: """Executes python code in the cwd Args: From 7512fd4a3460112e8e3fa616b761b819d7d9cceb Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Thu, 20 Nov 2025 11:17:46 -0500 Subject: [PATCH 08/22] Fixes bug in the reqlib.tools uses_tool validator. --- mellea/stdlib/reqlib/tools.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mellea/stdlib/reqlib/tools.py b/mellea/stdlib/reqlib/tools.py index 7b04e941..5784534d 100644 --- a/mellea/stdlib/reqlib/tools.py +++ b/mellea/stdlib/reqlib/tools.py @@ -26,6 +26,8 @@ def uses_tool(tool_name: str | Callable, check_only=False): def _validate(ctx: Context): output = ctx.last_output() + if output.tool_calls is None: + return ValidationResult(result=False, reason="There were no tool calls.") return ValidationResult(result=tool_name in output.tool_calls) return Requirement( From 65dd35bb50e60e26a9d9798769896dc3a697ec7d Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Thu, 20 Nov 2025 11:18:03 -0500 Subject: [PATCH 09/22] More work on the code_interpreter tool. --- docs/examples/tools/interpreter.py | 34 +++++++++++++++++++ mellea/stdlib/reqlib/python.py | 2 +- mellea/stdlib/tools/__init__.py | 2 +- .../{code_interpreter.py => interpreter.py} | 4 +-- 4 files changed, 38 insertions(+), 4 deletions(-) create mode 100644 docs/examples/tools/interpreter.py rename mellea/stdlib/tools/{code_interpreter.py => interpreter.py} (99%) diff --git a/docs/examples/tools/interpreter.py b/docs/examples/tools/interpreter.py new file mode 100644 index 00000000..5321937b --- /dev/null +++ b/docs/examples/tools/interpreter.py @@ -0,0 +1,34 @@ +from mellea.stdlib.tools import code_interpreter, local_code_interpreter +from mellea import start_session +from mellea.backends.types import ModelOption + +m = start_session() + +# # First, let's see how the code interpreter function works without an LLM in the loop: +# result = code_interpreter("print(1+1)") +# print(result) + +# # Now let's ask the LLM to make a plot. +# +# plot_output = m.instruct( +# description="Make a plot of y=x^2", +# model_options={ +# ModelOption.TOOLS: [local_code_interpreter] +# } +# ) +# print(plot_output) + +# Notice that the model did not actually generate a plot. Let's force tool use: + +from mellea.stdlib.reqlib.tools import uses_tool + +plot_output = m.instruct( + description="Make a plot of y=x^2", + requirements=[ + uses_tool(local_code_interpreter) + ], + model_options={ + ModelOption.TOOLS: [local_code_interpreter] + } +) +print(plot_output) diff --git a/mellea/stdlib/reqlib/python.py b/mellea/stdlib/reqlib/python.py index 873d2046..8b390e44 100644 --- a/mellea/stdlib/reqlib/python.py +++ b/mellea/stdlib/reqlib/python.py @@ -12,7 +12,7 @@ from mellea.helpers.fancy_logger import FancyLogger from mellea.stdlib.base import Context from mellea.stdlib.requirement import Requirement, ValidationResult -from mellea.stdlib.tools.code_interpreter import LLMSandboxEnvironment, UnsafeEnvironment, ExecutionEnvironment, StaticAnalysisEnvironment +from mellea.stdlib.tools.interpreter import LLMSandboxEnvironment, UnsafeEnvironment, ExecutionEnvironment, StaticAnalysisEnvironment logger = FancyLogger.get_logger() diff --git a/mellea/stdlib/tools/__init__.py b/mellea/stdlib/tools/__init__.py index b6afdefd..0b6b76bd 100644 --- a/mellea/stdlib/tools/__init__.py +++ b/mellea/stdlib/tools/__init__.py @@ -1 +1 @@ -from .code_interpreter import code_interpreter, local_code_interpreter \ No newline at end of file +from mellea.stdlib.tools.interpreter import code_interpreter, local_code_interpreter \ No newline at end of file diff --git a/mellea/stdlib/tools/code_interpreter.py b/mellea/stdlib/tools/interpreter.py similarity index 99% rename from mellea/stdlib/tools/code_interpreter.py rename to mellea/stdlib/tools/interpreter.py index f90ef4fc..51ddb602 100644 --- a/mellea/stdlib/tools/code_interpreter.py +++ b/mellea/stdlib/tools/interpreter.py @@ -220,7 +220,7 @@ def code_interpreter(code: str) -> ExecutionResult: code: The Python code to execute. """ exec_env = LLMSandboxEnvironment(allowed_imports=None) - exec_env.execute(code, 60) + return exec_env.execute(code, 60) def local_code_interpreter(code: str) -> ExecutionResult: @@ -230,4 +230,4 @@ def local_code_interpreter(code: str) -> ExecutionResult: code: The Python code to execute. """ exec_env = UnsafeEnvironment(allowed_imports=None) - exec_env.execute(code, 60) \ No newline at end of file + return exec_env.execute(code, 60) \ No newline at end of file From b05e7b217c40392d76b98922008ad000a72adedf Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Thu, 20 Nov 2025 11:44:18 -0500 Subject: [PATCH 10/22] Fixes the code interpreter example (need tool_calls=True). Tracking this ergonomic issue in #234. --- docs/examples/tools/interpreter.py | 18 ++++++++++++++---- mellea/stdlib/tools/interpreter.py | 2 +- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/docs/examples/tools/interpreter.py b/docs/examples/tools/interpreter.py index 5321937b..052356f3 100644 --- a/docs/examples/tools/interpreter.py +++ b/docs/examples/tools/interpreter.py @@ -1,7 +1,10 @@ from mellea.stdlib.tools import code_interpreter, local_code_interpreter from mellea import start_session from mellea.backends.types import ModelOption +from mellea.backends.model_ids import OPENAI_GPT_OSS_20B +from mellea.stdlib.reqlib.tools import uses_tool +# m = start_session(backend_name="ollama", model_id=OPENAI_GPT_OSS_20B) m = start_session() # # First, let's see how the code interpreter function works without an LLM in the loop: @@ -20,15 +23,22 @@ # Notice that the model did not actually generate a plot. Let's force tool use: -from mellea.stdlib.reqlib.tools import uses_tool plot_output = m.instruct( - description="Make a plot of y=x^2", + description="Use the code interpreter tool to make a plot of y=x^2.", requirements=[ uses_tool(local_code_interpreter) ], model_options={ ModelOption.TOOLS: [local_code_interpreter] - } + }, + tool_calls=True ) -print(plot_output) + +code = plot_output.tool_calls['local_code_interpreter'].args['code'] +print(f"Going to execute the following code:\n```python\n{code}\n```") + +# Call the tool. +exec_result = plot_output.tool_calls['local_code_interpreter'].call_func() + +print(exec_result) diff --git a/mellea/stdlib/tools/interpreter.py b/mellea/stdlib/tools/interpreter.py index 51ddb602..9bef837d 100644 --- a/mellea/stdlib/tools/interpreter.py +++ b/mellea/stdlib/tools/interpreter.py @@ -230,4 +230,4 @@ def local_code_interpreter(code: str) -> ExecutionResult: code: The Python code to execute. """ exec_env = UnsafeEnvironment(allowed_imports=None) - return exec_env.execute(code, 60) \ No newline at end of file + return exec_env.execute(code, 60) \ No newline at end of file From 986e940f6bc026030044ec22b650cd0c339e4517 Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Thu, 20 Nov 2025 11:54:40 -0500 Subject: [PATCH 11/22] More cleanup for the python interpreter example. --- docs/examples/tools/interpreter.py | 44 ------------ docs/examples/tools/interpreter_example.py | 79 ++++++++++++++++++++++ 2 files changed, 79 insertions(+), 44 deletions(-) delete mode 100644 docs/examples/tools/interpreter.py create mode 100644 docs/examples/tools/interpreter_example.py diff --git a/docs/examples/tools/interpreter.py b/docs/examples/tools/interpreter.py deleted file mode 100644 index 052356f3..00000000 --- a/docs/examples/tools/interpreter.py +++ /dev/null @@ -1,44 +0,0 @@ -from mellea.stdlib.tools import code_interpreter, local_code_interpreter -from mellea import start_session -from mellea.backends.types import ModelOption -from mellea.backends.model_ids import OPENAI_GPT_OSS_20B -from mellea.stdlib.reqlib.tools import uses_tool - -# m = start_session(backend_name="ollama", model_id=OPENAI_GPT_OSS_20B) -m = start_session() - -# # First, let's see how the code interpreter function works without an LLM in the loop: -# result = code_interpreter("print(1+1)") -# print(result) - -# # Now let's ask the LLM to make a plot. -# -# plot_output = m.instruct( -# description="Make a plot of y=x^2", -# model_options={ -# ModelOption.TOOLS: [local_code_interpreter] -# } -# ) -# print(plot_output) - -# Notice that the model did not actually generate a plot. Let's force tool use: - - -plot_output = m.instruct( - description="Use the code interpreter tool to make a plot of y=x^2.", - requirements=[ - uses_tool(local_code_interpreter) - ], - model_options={ - ModelOption.TOOLS: [local_code_interpreter] - }, - tool_calls=True -) - -code = plot_output.tool_calls['local_code_interpreter'].args['code'] -print(f"Going to execute the following code:\n```python\n{code}\n```") - -# Call the tool. -exec_result = plot_output.tool_calls['local_code_interpreter'].call_func() - -print(exec_result) diff --git a/docs/examples/tools/interpreter_example.py b/docs/examples/tools/interpreter_example.py new file mode 100644 index 00000000..c7a1cd83 --- /dev/null +++ b/docs/examples/tools/interpreter_example.py @@ -0,0 +1,79 @@ +from mellea.stdlib.tools import code_interpreter, local_code_interpreter +from mellea import start_session, MelleaSession +from mellea.backends.types import ModelOption +from mellea.backends.model_ids import OPENAI_GPT_OSS_20B +from mellea.stdlib.reqlib.tools import uses_tool, tool_arg_validator + +def example_1(m: MelleaSession): + # First, let's see how the code interpreter function works without an LLM in the loop: + result = code_interpreter("print(1+1)") + print(result) + +# Now let's ask the LLM to make a plot. + +def example_2(m: MelleaSession): + plot_output = m.instruct( + description="Make a plot of y=x^2", + model_options={ + ModelOption.TOOLS: [local_code_interpreter] + } + ) + print(plot_output) + +# Notice that the model did not actually generate a plot. Let's force tool use: + +def example_3(m: MelleaSession): + plot_output = m.instruct( + description="Use the code interpreter tool to make a plot of y=x^2.", + requirements=[ + uses_tool(local_code_interpreter) + ], + model_options={ + ModelOption.TOOLS: [local_code_interpreter] + }, + tool_calls=True + ) + + code = plot_output.tool_calls['local_code_interpreter'].args['code'] + print(f"Going to execute the following code:\n```python\n{code}\n```") + + # Call the tool. + exec_result = plot_output.tool_calls['local_code_interpreter'].call_func() + + print(exec_result) + + +# Notice that the model did make a plot, but it just "showed" the plot. +# We would actually like this to be written out to a file. + +def example_4(m: MelleaSession): + plot_output = m.instruct( + description="Use the code interpreter tool to make a plot of y=x^2.", + requirements=[ + uses_tool(local_code_interpreter), + tool_arg_validator( + "The plot should be written to /tmp/output.png", + tool_name=local_code_interpreter, + arg_name="code", + validation_fn=lambda code_snippet: "/tmp/output.png" in code_snippet and "plt.show()" not in code_snippet + ) + ], + model_options={ + ModelOption.TOOLS: [local_code_interpreter] + }, + tool_calls=True + ) + + code = plot_output.tool_calls['local_code_interpreter'].args['code'] + print(f"Going to execute the following code:\n```python\n{code}\n```") + + # Call the tool. + exec_result = plot_output.tool_calls['local_code_interpreter'].call_func() + + print(exec_result) + + + +# m = start_session(backend_name="ollama", model_id=OPENAI_GPT_OSS_20B) +m = start_session() +example_4(m) \ No newline at end of file From 1c2a5eacefa1be05b611bd5e2768c611ed6df6d8 Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Thu, 20 Nov 2025 12:07:48 -0500 Subject: [PATCH 12/22] Reverts thread pool for tool calling. We should still think about this. --- mellea/stdlib/base.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/mellea/stdlib/base.py b/mellea/stdlib/base.py index 58bbcace..4c751ce5 100644 --- a/mellea/stdlib/base.py +++ b/mellea/stdlib/base.py @@ -654,14 +654,6 @@ class ModelToolCall: func: Callable args: Mapping[str, Any] - def call_func(self, timeout: float | None=None) -> Any: - """A helper function for calling the function/tool represented by this object. - - Args: - timeout: if set, the tool call will time-out and an exception will be thrown after `timeout` seconds.""" - if timeout: - with multiprocessing.pool.ThreadPool() as pool: - result = pool.apply_async(self.func, kwds=self.args).get(timeout=timeout) - return result - else: - return self.func(**self.args) \ No newline at end of file + def call_func(self) -> Any: + """A helper function for calling the function/tool represented by this object.""" + return self.func(**self.args) \ No newline at end of file From 3e2a00bee22393fdefa1f778f65fa72cb9e8ff7e Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Thu, 20 Nov 2025 12:10:07 -0500 Subject: [PATCH 13/22] Revmoes unused imports is base and adds \n --- mellea/stdlib/base.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mellea/stdlib/base.py b/mellea/stdlib/base.py index 4c751ce5..86389d1f 100644 --- a/mellea/stdlib/base.py +++ b/mellea/stdlib/base.py @@ -16,9 +16,6 @@ from PIL import Image as PILImage -from mellea.helpers.fancy_logger import FancyLogger -import multiprocessing.pool - class CBlock: """A `CBlock` is a block of content that can serve as input to or output from an LLM.""" @@ -656,4 +653,4 @@ class ModelToolCall: def call_func(self) -> Any: """A helper function for calling the function/tool represented by this object.""" - return self.func(**self.args) \ No newline at end of file + return self.func(**self.args) From 9ec1020fd9ee3514af614d6beb8500178cb2c2f1 Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Thu, 20 Nov 2025 12:10:40 -0500 Subject: [PATCH 14/22] Adds \n and removes unused import in base.py --- mellea/stdlib/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mellea/stdlib/base.py b/mellea/stdlib/base.py index 86389d1f..111d44f6 100644 --- a/mellea/stdlib/base.py +++ b/mellea/stdlib/base.py @@ -16,6 +16,8 @@ from PIL import Image as PILImage +from mellea.helpers.fancy_logger import FancyLogger + class CBlock: """A `CBlock` is a block of content that can serve as input to or output from an LLM.""" From fdf88d5d8cdd21302c7ddb968bf14adf483ab9ea Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Thu, 20 Nov 2025 12:22:30 -0500 Subject: [PATCH 15/22] Fixes some pre-commit issues. --- docs/examples/tools/interpreter_example.py | 42 +++++++-------- mellea/stdlib/reqlib/python.py | 12 ++++- mellea/stdlib/reqlib/tools.py | 61 ++++++++++++++-------- mellea/stdlib/tools/__init__.py | 2 +- mellea/stdlib/tools/interpreter.py | 48 ++++++++++------- 5 files changed, 99 insertions(+), 66 deletions(-) diff --git a/docs/examples/tools/interpreter_example.py b/docs/examples/tools/interpreter_example.py index c7a1cd83..ccde0ada 100644 --- a/docs/examples/tools/interpreter_example.py +++ b/docs/examples/tools/interpreter_example.py @@ -4,41 +4,40 @@ from mellea.backends.model_ids import OPENAI_GPT_OSS_20B from mellea.stdlib.reqlib.tools import uses_tool, tool_arg_validator + def example_1(m: MelleaSession): # First, let's see how the code interpreter function works without an LLM in the loop: result = code_interpreter("print(1+1)") print(result) + # Now let's ask the LLM to make a plot. + def example_2(m: MelleaSession): plot_output = m.instruct( description="Make a plot of y=x^2", - model_options={ - ModelOption.TOOLS: [local_code_interpreter] - } + model_options={ModelOption.TOOLS: [local_code_interpreter]}, ) print(plot_output) + # Notice that the model did not actually generate a plot. Let's force tool use: + def example_3(m: MelleaSession): plot_output = m.instruct( description="Use the code interpreter tool to make a plot of y=x^2.", - requirements=[ - uses_tool(local_code_interpreter) - ], - model_options={ - ModelOption.TOOLS: [local_code_interpreter] - }, - tool_calls=True + requirements=[uses_tool(local_code_interpreter)], + model_options={ModelOption.TOOLS: [local_code_interpreter]}, + tool_calls=True, ) - code = plot_output.tool_calls['local_code_interpreter'].args['code'] + code = plot_output.tool_calls["local_code_interpreter"].args["code"] print(f"Going to execute the following code:\n```python\n{code}\n```") # Call the tool. - exec_result = plot_output.tool_calls['local_code_interpreter'].call_func() + exec_result = plot_output.tool_calls["local_code_interpreter"].call_func() print(exec_result) @@ -46,6 +45,7 @@ def example_3(m: MelleaSession): # Notice that the model did make a plot, but it just "showed" the plot. # We would actually like this to be written out to a file. + def example_4(m: MelleaSession): plot_output = m.instruct( description="Use the code interpreter tool to make a plot of y=x^2.", @@ -55,25 +55,23 @@ def example_4(m: MelleaSession): "The plot should be written to /tmp/output.png", tool_name=local_code_interpreter, arg_name="code", - validation_fn=lambda code_snippet: "/tmp/output.png" in code_snippet and "plt.show()" not in code_snippet - ) + validation_fn=lambda code_snippet: "/tmp/output.png" in code_snippet + and "plt.show()" not in code_snippet, + ), ], - model_options={ - ModelOption.TOOLS: [local_code_interpreter] - }, - tool_calls=True + model_options={ModelOption.TOOLS: [local_code_interpreter]}, + tool_calls=True, ) - code = plot_output.tool_calls['local_code_interpreter'].args['code'] + code = plot_output.tool_calls["local_code_interpreter"].args["code"] print(f"Going to execute the following code:\n```python\n{code}\n```") # Call the tool. - exec_result = plot_output.tool_calls['local_code_interpreter'].call_func() + exec_result = plot_output.tool_calls["local_code_interpreter"].call_func() print(exec_result) - # m = start_session(backend_name="ollama", model_id=OPENAI_GPT_OSS_20B) m = start_session() -example_4(m) \ No newline at end of file +example_4(m) diff --git a/mellea/stdlib/reqlib/python.py b/mellea/stdlib/reqlib/python.py index 8b390e44..db24e1b5 100644 --- a/mellea/stdlib/reqlib/python.py +++ b/mellea/stdlib/reqlib/python.py @@ -12,7 +12,12 @@ from mellea.helpers.fancy_logger import FancyLogger from mellea.stdlib.base import Context from mellea.stdlib.requirement import Requirement, ValidationResult -from mellea.stdlib.tools.interpreter import LLMSandboxEnvironment, UnsafeEnvironment, ExecutionEnvironment, StaticAnalysisEnvironment +from mellea.stdlib.tools.interpreter import ( + ExecutionEnvironment, + LLMSandboxEnvironment, + StaticAnalysisEnvironment, + UnsafeEnvironment, +) logger = FancyLogger.get_logger() @@ -138,7 +143,10 @@ def _python_executes_without_error( result = environment.execute(code, timeout) return ValidationResult( - result=result.success, reason=result.stdout if result.success else result.stderr # TODO should we pass back both? + result=result.success, + reason=result.stdout + if result.success + else result.stderr, # TODO should we pass back both? ) diff --git a/mellea/stdlib/reqlib/tools.py b/mellea/stdlib/reqlib/tools.py index 5784534d..7d735524 100644 --- a/mellea/stdlib/reqlib/tools.py +++ b/mellea/stdlib/reqlib/tools.py @@ -1,4 +1,6 @@ -from typing import Callable, Optional +from collections.abc import Callable +from typing import Optional + from mellea.stdlib.base import Context from mellea.stdlib.requirement import Requirement, ValidationResult @@ -14,72 +16,85 @@ def _name2str(tool_name: str | Callable) -> str: def uses_tool(tool_name: str | Callable, check_only=False): - """Forces the model to call a givne tool. - + """Forces the model to call a given tool. + Args: tool_name: The tool that must be called; this can be either the name of the tool or the Callable for the tool. check_only: Propagates to the Requirement. - + Use `tool_choice` if the OpenAI `tool_choice` model option is supported by your model and inference engine. """ tool_name = _name2str(tool_name) - + def _validate(ctx: Context): output = ctx.last_output() if output.tool_calls is None: return ValidationResult(result=False, reason="There were no tool calls.") return ValidationResult(result=tool_name in output.tool_calls) - + return Requirement( description=f"Use the {tool_name} tool.", validation_fn=_validate, - check_only=check_only + check_only=check_only, ) def tool_arg_validator( - description: str, - tool_name: Optional[str | Callable], - arg_name: str, - validation_fn: Callable, - check_only: bool=False + description: str, + tool_name: str | Callable | None, + arg_name: str, + validation_fn: Callable, + check_only: bool = False, ) -> Requirement: """A requirement that passes only if `validation_fn` returns a True value for the *value* of the `arg_name` argument to `tool_name`. - If `tool_name` is not specified, then this requirement is enforced for *every* tool that - + If `tool_name` is not specified, then this requirement is enforced for *every* tool that + Args: description: The Requirement description. tool_name: The (optional) tool name for . arg_name: The argument to check. validation_fn: A validation function for validating the value of the `arg_name` argument. - - TODO: + + Todo: 1. should this be a requirement? 2. should this be done automatically when the user provides asserts in their function body? """ if tool_name: tool_name = _name2str(tool_name) - + def _validate(ctx: Context): output = ctx.last_output() if tool_name: if tool_name not in output.tool_calls: - return ValidationResult(result=False, reason=f"Tool {tool_name} was not called.") + return ValidationResult( + result=False, reason=f"Tool {tool_name} was not called." + ) if arg_name not in output.tool_calls[tool_name].args: - return ValidationResult(result=False, reason=f"Tool {tool_name} did not call argument {arg_name}") + return ValidationResult( + result=False, + reason=f"Tool {tool_name} did not call argument {arg_name}", + ) arg_value = output.tool_calls[tool_name].args[arg_name] validate_result = validation_fn(arg_value) if validate_result: return ValidationResult(result=True) else: - return ValidationResult(result=False, reason=f"Valiudation did not pass for {tool_name}.{arg_name}. Arg value: {arg_value}. Argument validation result: {validate_result}") + return ValidationResult( + result=False, + reason=f"Valiudation did not pass for {tool_name}.{arg_name}. Arg value: {arg_value}. Argument validation result: {validate_result}", + ) else: for tool in ctx.last_output().tool_calls.keys(): if arg_name in output.tool_calls[tool].args: arg_value = output.tool_calls[tool].args[arg_name] validate_result = validation_fn(arg_value) if not validate_result: - return ValidationResult(result=False, reason=f"Valiudation did not pass for {tool_name}.{arg_name}. Arg value: {arg_value}. Argument validation result: {validate_result}") + return ValidationResult( + result=False, + reason=f"Valiudation did not pass for {tool_name}.{arg_name}. Arg value: {arg_value}. Argument validation result: {validate_result}", + ) return ValidationResult(result=True) - - return Requirement(description=description, validation_fn=_validate, check_only=check_only) + + return Requirement( + description=description, validation_fn=_validate, check_only=check_only + ) diff --git a/mellea/stdlib/tools/__init__.py b/mellea/stdlib/tools/__init__.py index 0b6b76bd..c61b6f42 100644 --- a/mellea/stdlib/tools/__init__.py +++ b/mellea/stdlib/tools/__init__.py @@ -1 +1 @@ -from mellea.stdlib.tools.interpreter import code_interpreter, local_code_interpreter \ No newline at end of file +from mellea.stdlib.tools.interpreter import code_interpreter, local_code_interpreter diff --git a/mellea/stdlib/tools/interpreter.py b/mellea/stdlib/tools/interpreter.py index 9bef837d..70f81615 100644 --- a/mellea/stdlib/tools/interpreter.py +++ b/mellea/stdlib/tools/interpreter.py @@ -1,15 +1,15 @@ -from dataclasses import dataclass -from abc import ABC, abstractmethod import ast -import tempfile import subprocess import sys +import tempfile +from abc import ABC, abstractmethod +from dataclasses import dataclass from pathlib import Path +from typing import Any + from mellea.helpers.fancy_logger import FancyLogger from mellea.stdlib.base import Context from mellea.stdlib.requirement import Requirement, ValidationResult -from typing import Any - logger = FancyLogger.get_logger() @@ -26,8 +26,9 @@ class ExecutionResult: We also use the `ExecutionResult` object to communicate the result of static and dynamic analyses. Those are passed back using the `analysis_result` field. - - TODO: should we also be trying to pass back the value of the final expression evaluated, or the value of locals() and globals()?""" + + TODO: should we also be trying to pass back the value of the final expression evaluated, or the value of locals() and globals()? + """ success: bool @@ -42,7 +43,7 @@ class ExecutionResult: skip_message: str | None = None """ Used for returning results from static analyses. """ - analysis_result : Any | None = None + analysis_result: Any | None = None class ExecutionEnvironment(ABC): @@ -69,7 +70,14 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: try: parse_tree = ast.parse(code) except SyntaxError as e: - return ExecutionResult(success=False, error=str(e)) + return ExecutionResult( + success=False, + stdout=None, + stderr=None, + skipped=True, + skip_message="Parse failed.", + analysis_result=e, + ) if self.allowed_imports: unauthorized = _get_unauthorized_imports(code, self.allowed_imports) @@ -85,7 +93,7 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: stderr=None, skipped=True, skip_message="The static analysis execution environment does not execute code. To execute code, use one of the other execution environments.", - analysis_result=parse_tree + analysis_result=parse_tree, ) @@ -118,7 +126,11 @@ def _execute_subprocess(self, code: str, timeout: int) -> ExecutionResult: text=True, timeout=timeout, ) - return ExecutionResult(success=result.returncode == 0, stdout=result.stdout.strip(), stderr=result.stderr.strip()) + return ExecutionResult( + success=result.returncode == 0, + stdout=result.stdout.strip(), + stderr=result.stderr.strip(), + ) except subprocess.TimeoutExpired: return ExecutionResult( success=False, error=f"Execution timed out after {timeout} seconds" @@ -168,15 +180,15 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: return ExecutionResult( success=result.exit_code == 0, stdout=result.stdout.strip(), - stderr=result.stderr.strip() + stderr=result.stderr.strip(), ) except Exception as e: return ExecutionResult( - success=False, + success=False, stdout=None, stderr=None, - skipped=True, - skip_message=f"Sandbox execution error: {e!s}" + skipped=True, + skip_message=f"Sandbox execution error: {e!s}", ) @@ -215,7 +227,7 @@ def _check_allowed_imports(code: str, allowed_imports: list[str]) -> bool: def code_interpreter(code: str) -> ExecutionResult: """Executes python code. - + Args: code: The Python code to execute. """ @@ -225,9 +237,9 @@ def code_interpreter(code: str) -> ExecutionResult: def local_code_interpreter(code: str) -> ExecutionResult: """Executes python code in the cwd - + Args: code: The Python code to execute. """ exec_env = UnsafeEnvironment(allowed_imports=None) - return exec_env.execute(code, 60) \ No newline at end of file + return exec_env.execute(code, 60) From c63ff78fc4eebc4a112bcfb568346abd2a84a08f Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Thu, 20 Nov 2025 12:32:32 -0500 Subject: [PATCH 16/22] Fixes some pre-commit issues. --- mellea/stdlib/reqlib/tools.py | 3 +++ mellea/stdlib/tools/interpreter.py | 12 +++++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/mellea/stdlib/reqlib/tools.py b/mellea/stdlib/reqlib/tools.py index 7d735524..f007ad6b 100644 --- a/mellea/stdlib/reqlib/tools.py +++ b/mellea/stdlib/reqlib/tools.py @@ -1,3 +1,5 @@ +"""Requirements for tool-use workflows.""" + from collections.abc import Callable from typing import Optional @@ -54,6 +56,7 @@ def tool_arg_validator( tool_name: The (optional) tool name for . arg_name: The argument to check. validation_fn: A validation function for validating the value of the `arg_name` argument. + check_only: propagates the `check_only` flag to the requirement. Todo: 1. should this be a requirement? diff --git a/mellea/stdlib/tools/interpreter.py b/mellea/stdlib/tools/interpreter.py index 70f81615..2c7ddd68 100644 --- a/mellea/stdlib/tools/interpreter.py +++ b/mellea/stdlib/tools/interpreter.py @@ -84,7 +84,10 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: if unauthorized: return ExecutionResult( success=False, - error=f"Unauthorized imports detected: {', '.join(unauthorized)}", + stdout=None, + stderr=None, + skipped=True, + skip_message=f"Unauthorized imports detected: {', '.join(unauthorized)}", ) return ExecutionResult( @@ -107,7 +110,10 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: if unauthorized: return ExecutionResult( success=False, - error=f"Unauthorized imports detected: {', '.join(unauthorized)}", + stdout=None, + stderr=None, + skipped=True, + skip_message=f"Unauthorized imports detected: {', '.join(unauthorized)}", ) return self._execute_subprocess(code, timeout) @@ -236,7 +242,7 @@ def code_interpreter(code: str) -> ExecutionResult: def local_code_interpreter(code: str) -> ExecutionResult: - """Executes python code in the cwd + """Executes python code in the cwd. Args: code: The Python code to execute. From 0fd1a576fbbdf02dbe90a888a89ff6247e9bcbca Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Thu, 20 Nov 2025 12:37:34 -0500 Subject: [PATCH 17/22] Fixes some pre-commit errors. --- mellea/stdlib/reqlib/tools.py | 3 +++ mellea/stdlib/tools/__init__.py | 1 + mellea/stdlib/tools/interpreter.py | 15 +++++++++++++-- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/mellea/stdlib/reqlib/tools.py b/mellea/stdlib/reqlib/tools.py index f007ad6b..6da35069 100644 --- a/mellea/stdlib/reqlib/tools.py +++ b/mellea/stdlib/reqlib/tools.py @@ -30,6 +30,7 @@ def uses_tool(tool_name: str | Callable, check_only=False): def _validate(ctx: Context): output = ctx.last_output() + assert output is not None if output.tool_calls is None: return ValidationResult(result=False, reason="There were no tool calls.") return ValidationResult(result=tool_name in output.tool_calls) @@ -49,6 +50,7 @@ def tool_arg_validator( check_only: bool = False, ) -> Requirement: """A requirement that passes only if `validation_fn` returns a True value for the *value* of the `arg_name` argument to `tool_name`. + If `tool_name` is not specified, then this requirement is enforced for *every* tool that Args: @@ -67,6 +69,7 @@ def tool_arg_validator( def _validate(ctx: Context): output = ctx.last_output() + assert output is not None if tool_name: if tool_name not in output.tool_calls: return ValidationResult( diff --git a/mellea/stdlib/tools/__init__.py b/mellea/stdlib/tools/__init__.py index c61b6f42..0a015c25 100644 --- a/mellea/stdlib/tools/__init__.py +++ b/mellea/stdlib/tools/__init__.py @@ -1 +1,2 @@ +""" Implementations of tools. """ from mellea.stdlib.tools.interpreter import code_interpreter, local_code_interpreter diff --git a/mellea/stdlib/tools/interpreter.py b/mellea/stdlib/tools/interpreter.py index 2c7ddd68..66d417df 100644 --- a/mellea/stdlib/tools/interpreter.py +++ b/mellea/stdlib/tools/interpreter.py @@ -1,3 +1,4 @@ +""" Code interpreter tool. """ import ast import subprocess import sys @@ -139,10 +140,20 @@ def _execute_subprocess(self, code: str, timeout: int) -> ExecutionResult: ) except subprocess.TimeoutExpired: return ExecutionResult( - success=False, error=f"Execution timed out after {timeout} seconds" + success=False, + stdout=None, + stderr=None, + skipped=True, + skip_message="Execution timed out." ) except Exception as e: - return ExecutionResult(success=False, error=f"Execution error: {e!s}") + return ExecutionResult( + success=False, + stdout=None, + stderr=None, + skipped=True, + skip_message=f"Exception encountered in Mellea process (*not* the code interpreter process) when trying to run code_interpreter: {e!s}" + ) finally: try: Path(temp_file).unlink() From f40556064121395b0c7416f5da4a8c976af4c6a3 Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Thu, 20 Nov 2025 12:37:51 -0500 Subject: [PATCH 18/22] Ruff formatter pass. --- mellea/stdlib/reqlib/tools.py | 2 +- mellea/stdlib/tools/__init__.py | 3 ++- mellea/stdlib/tools/interpreter.py | 7 ++++--- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/mellea/stdlib/reqlib/tools.py b/mellea/stdlib/reqlib/tools.py index 6da35069..1b062ba4 100644 --- a/mellea/stdlib/reqlib/tools.py +++ b/mellea/stdlib/reqlib/tools.py @@ -50,7 +50,7 @@ def tool_arg_validator( check_only: bool = False, ) -> Requirement: """A requirement that passes only if `validation_fn` returns a True value for the *value* of the `arg_name` argument to `tool_name`. - + If `tool_name` is not specified, then this requirement is enforced for *every* tool that Args: diff --git a/mellea/stdlib/tools/__init__.py b/mellea/stdlib/tools/__init__.py index 0a015c25..24ca99aa 100644 --- a/mellea/stdlib/tools/__init__.py +++ b/mellea/stdlib/tools/__init__.py @@ -1,2 +1,3 @@ -""" Implementations of tools. """ +"""Implementations of tools.""" + from mellea.stdlib.tools.interpreter import code_interpreter, local_code_interpreter diff --git a/mellea/stdlib/tools/interpreter.py b/mellea/stdlib/tools/interpreter.py index 66d417df..1fbae33f 100644 --- a/mellea/stdlib/tools/interpreter.py +++ b/mellea/stdlib/tools/interpreter.py @@ -1,4 +1,5 @@ -""" Code interpreter tool. """ +"""Code interpreter tool.""" + import ast import subprocess import sys @@ -144,7 +145,7 @@ def _execute_subprocess(self, code: str, timeout: int) -> ExecutionResult: stdout=None, stderr=None, skipped=True, - skip_message="Execution timed out." + skip_message="Execution timed out.", ) except Exception as e: return ExecutionResult( @@ -152,7 +153,7 @@ def _execute_subprocess(self, code: str, timeout: int) -> ExecutionResult: stdout=None, stderr=None, skipped=True, - skip_message=f"Exception encountered in Mellea process (*not* the code interpreter process) when trying to run code_interpreter: {e!s}" + skip_message=f"Exception encountered in Mellea process (*not* the code interpreter process) when trying to run code_interpreter: {e!s}", ) finally: try: From 646d1271577341599e20ad7317a2e67b68d2dc50 Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Thu, 20 Nov 2025 12:40:10 -0500 Subject: [PATCH 19/22] Handle edge-case in tool requirement checker. --- mellea/stdlib/reqlib/tools.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mellea/stdlib/reqlib/tools.py b/mellea/stdlib/reqlib/tools.py index 1b062ba4..090d7c6d 100644 --- a/mellea/stdlib/reqlib/tools.py +++ b/mellea/stdlib/reqlib/tools.py @@ -69,7 +69,14 @@ def tool_arg_validator( def _validate(ctx: Context): output = ctx.last_output() + assert output is not None + + if output.tool_calls is None: + return ValidationResult( + result=False, reason=f"Expected {tool_name} to be called but no tools were called." + ) + if tool_name: if tool_name not in output.tool_calls: return ValidationResult( From 6745e8d7550749bd9fb064937188b2a524a92372 Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Thu, 20 Nov 2025 12:40:29 -0500 Subject: [PATCH 20/22] ruff --- mellea/stdlib/reqlib/tools.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mellea/stdlib/reqlib/tools.py b/mellea/stdlib/reqlib/tools.py index 090d7c6d..2f6ff90f 100644 --- a/mellea/stdlib/reqlib/tools.py +++ b/mellea/stdlib/reqlib/tools.py @@ -69,14 +69,15 @@ def tool_arg_validator( def _validate(ctx: Context): output = ctx.last_output() - + assert output is not None if output.tool_calls is None: return ValidationResult( - result=False, reason=f"Expected {tool_name} to be called but no tools were called." + result=False, + reason=f"Expected {tool_name} to be called but no tools were called.", ) - + if tool_name: if tool_name not in output.tool_calls: return ValidationResult( From aaf460e14f8e65d01d52a8fd0223524a03640990 Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Thu, 20 Nov 2025 12:41:44 -0500 Subject: [PATCH 21/22] Fixes final pre-commit error. --- mellea/stdlib/reqlib/tools.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mellea/stdlib/reqlib/tools.py b/mellea/stdlib/reqlib/tools.py index 2f6ff90f..6b64f18a 100644 --- a/mellea/stdlib/reqlib/tools.py +++ b/mellea/stdlib/reqlib/tools.py @@ -69,7 +69,6 @@ def tool_arg_validator( def _validate(ctx: Context): output = ctx.last_output() - assert output is not None if output.tool_calls is None: @@ -98,7 +97,7 @@ def _validate(ctx: Context): reason=f"Valiudation did not pass for {tool_name}.{arg_name}. Arg value: {arg_value}. Argument validation result: {validate_result}", ) else: - for tool in ctx.last_output().tool_calls.keys(): + for tool in output.tool_calls.keys(): if arg_name in output.tool_calls[tool].args: arg_value = output.tool_calls[tool].args[arg_name] validate_result = validation_fn(arg_value) From 976839748d6e9e4dfe7b33da359305eaccafee2b Mon Sep 17 00:00:00 2001 From: Nathan Fulton Date: Thu, 20 Nov 2025 16:11:43 -0500 Subject: [PATCH 22/22] Fixes failing tests caused by changes to ExecutionResult. --- mellea/stdlib/reqlib/python.py | 5 +---- mellea/stdlib/tools/interpreter.py | 21 ++++++++++++++++++++- test/stdlib_basics/test_reqlib_python.py | 4 ---- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/mellea/stdlib/reqlib/python.py b/mellea/stdlib/reqlib/python.py index db24e1b5..1c83a330 100644 --- a/mellea/stdlib/reqlib/python.py +++ b/mellea/stdlib/reqlib/python.py @@ -143,10 +143,7 @@ def _python_executes_without_error( result = environment.execute(code, timeout) return ValidationResult( - result=result.success, - reason=result.stdout - if result.success - else result.stderr, # TODO should we pass back both? + result=result.success, reason=result.to_validationresult_reason() ) diff --git a/mellea/stdlib/tools/interpreter.py b/mellea/stdlib/tools/interpreter.py index 1fbae33f..9f144057 100644 --- a/mellea/stdlib/tools/interpreter.py +++ b/mellea/stdlib/tools/interpreter.py @@ -47,6 +47,25 @@ class ExecutionResult: """ Used for returning results from static analyses. """ analysis_result: Any | None = None + def to_validationresult_reason(self): + """Maps an ExecutionResult to a ValidationResult reason. + + TODO: Downstream use of this method is really hacky. A far better solution is for `ExecutionResult` to implement the `ValidationResult` interface. + """ + assert self.skip_message is not None or ( + self.stderr is not None and self.stdout is not None + ), ( + "Every ExecutionResult should have either a skip_message or a stdout/stderr stream." + ) + if self.skip_message: + reason = self.skip_message + else: + if self.success: + reason = self.stdout + else: + reason = self.stderr + return reason + class ExecutionEnvironment(ABC): """Abstract environment for executing Python code.""" @@ -97,7 +116,7 @@ def execute(self, code: str, timeout: int) -> ExecutionResult: stdout=None, stderr=None, skipped=True, - skip_message="The static analysis execution environment does not execute code. To execute code, use one of the other execution environments.", + skip_message="Code parses successful; the parse result is in the analysis_result field of the ExecutionResult object. The static analysis execution environment does not execute code. To execute code, use one of the other execution environments.", analysis_result=parse_tree, ) diff --git a/test/stdlib_basics/test_reqlib_python.py b/test/stdlib_basics/test_reqlib_python.py index 153a981f..b3f9211a 100644 --- a/test/stdlib_basics/test_reqlib_python.py +++ b/test/stdlib_basics/test_reqlib_python.py @@ -128,7 +128,6 @@ def test_safe_mode_default(): req = PythonExecutionReq() result = req.validation_fn(VALID_PYTHON_CTX) assert result.as_bool() is True - assert "safe mode" in result.reason def test_safe_mode_syntax_error(): @@ -143,7 +142,6 @@ def test_safe_mode_no_execution(): req = PythonExecutionReq(timeout=1) result = req.validation_fn(PYTHON_INFINITE_LOOP_CTX) assert result.as_bool() is True # Should pass because it's not actually executed - assert "safe mode" in result.reason # endregion @@ -225,7 +223,6 @@ def test_sandbox_execution_valid(): req = PythonExecutionReq(use_sandbox=True, timeout=10) result = req.validation_fn(VALID_PYTHON_CTX) assert result.as_bool() is True - assert "sandbox" in result.reason.lower() @pytest.mark.skipif( @@ -313,7 +310,6 @@ def test_direct_validation_function(): VALID_PYTHON_CTX, timeout=5, allow_unsafe=False, use_sandbox=False ) assert result.as_bool() is True - assert "safe mode" in result.reason result = _python_executes_without_error( SYNTAX_ERROR_CTX, timeout=5, allow_unsafe=False, use_sandbox=False