Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 107 additions & 4 deletions metagpt/actions/run_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
5. Merged the `Config` class of send18:dev branch to take over the set/get operations of the Environment
class.
"""
import json
import subprocess
import sys
import textwrap
from pathlib import Path
from typing import Tuple

Expand Down Expand Up @@ -74,20 +77,120 @@
```
"""

# Wrapper script executed in the sandboxed subprocess. Reads the code to execute
# from stdin and prints a JSON object with ``result`` and ``error`` keys.
_SANDBOX_WRAPPER = textwrap.dedent("""\
import ast, json, operator, sys

_BIN_OPS = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.FloorDiv: operator.floordiv,
ast.Mod: operator.mod,
ast.Pow: operator.pow,
}
_UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg, ast.Not: operator.not_}
_COMPARE_OPS = {
ast.Eq: operator.eq,
ast.NotEq: operator.ne,
ast.Lt: operator.lt,
ast.LtE: operator.le,
ast.Gt: operator.gt,
ast.GtE: operator.ge,
}

def _safe_eval(node):
if isinstance(node, ast.Expression):
return _safe_eval(node.body)
if isinstance(node, ast.Constant):
return node.value
if isinstance(node, ast.List):
return [_safe_eval(elt) for elt in node.elts]
if isinstance(node, ast.Tuple):
return tuple(_safe_eval(elt) for elt in node.elts)
if isinstance(node, ast.Set):
return {_safe_eval(elt) for elt in node.elts}
if isinstance(node, ast.Dict):
return {_safe_eval(key): _safe_eval(value) for key, value in zip(node.keys, node.values)}
if isinstance(node, ast.BinOp) and type(node.op) in _BIN_OPS:
return _BIN_OPS[type(node.op)](_safe_eval(node.left), _safe_eval(node.right))
if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
return _UNARY_OPS[type(node.op)](_safe_eval(node.operand))
if isinstance(node, ast.BoolOp):
values = [_safe_eval(value) for value in node.values]
return all(values) if isinstance(node.op, ast.And) else any(values)
if isinstance(node, ast.Compare):
left = _safe_eval(node.left)
for op, comparator in zip(node.ops, node.comparators):
right = _safe_eval(comparator)
if type(op) not in _COMPARE_OPS or not _COMPARE_OPS[type(op)](left, right):
return False
left = right
return True
raise ValueError("unsupported expression")

def _evaluate_result(code):
tree = ast.parse(code, mode="exec")
result = ""
for statement in tree.body:
if not (
isinstance(statement, ast.Assign)
and len(statement.targets) == 1
and isinstance(statement.targets[0], ast.Name)
and statement.targets[0].id == "result"
):
raise ValueError("only literal or arithmetic assignments to 'result' are supported")
result = _safe_eval(statement.value)
return result

try:
result = str(_evaluate_result(sys.stdin.read()))
error = ""
except Exception as e:
result = ""
error = str(e)

sys.stdout.write(json.dumps({"result": result, "error": error}))
""")


class RunCode(Action):
name: str = "RunCode"
i_context: RunCodeContext = Field(default_factory=RunCodeContext)

@classmethod
async def run_text(cls, code) -> Tuple[str, str]:
"""Execute *code* in an isolated subprocess and return ``(result, error)``.

The code is **not** executed via ``exec()``. Instead it is parsed in a
short-lived Python subprocess and limited to simple assignments to the
``result`` variable using literal values and arithmetic expressions.
"""
try:
# We will document_store the result in this dictionary
namespace = {}
exec(code, namespace)
process = subprocess.run(
[sys.executable, "-c", _SANDBOX_WRAPPER],
input=code,
capture_output=True,
text=True,
timeout=30,
)
# Parse the JSON payload written by the wrapper on stdout.
stdout = process.stdout.strip()
if stdout:
payload = json.loads(stdout)
result = payload.get("result", "")
error = payload.get("error", "")
else:
# Wrapper produced no output – likely a crash.
result = ""
error = process.stderr.strip() or "No output from sandbox subprocess"
except subprocess.TimeoutExpired:
return "", "Code execution timed out"
except Exception as e:
return "", str(e)
return namespace.get("result", ""), ""
return result, error

async def run_script(self, working_directory, additional_python_paths=[], command=[]) -> Tuple[str, str]:
working_directory = str(working_directory)
Expand Down
2 changes: 1 addition & 1 deletion tests/metagpt/actions/test_run_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
@pytest.mark.asyncio
async def test_run_text():
out, err = await RunCode.run_text("result = 1 + 1")
assert out == 2
assert out == "2"
assert err == ""

out, err = await RunCode.run_text("result = 1 / 0")
Expand Down
63 changes: 63 additions & 0 deletions tests/metagpt/actions/test_run_code_sandbox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Security tests for CWE-95: Verify that RunCode.run_text() does not execute
untrusted Python statements from the host process.
"""

import pytest

from metagpt.actions.run_code import RunCode


@pytest.mark.asyncio
async def test_run_text_rejects_stacked_statements():
"""Code execution primitives such as imports/calls must be rejected."""
malicious_code = """
import os
os.environ["_METAGPT_MUTATION_TEST"] = "MUTATED"
result = "done"
"""
out, err = await RunCode.run_text(malicious_code)
assert out == ""
assert "only literal or arithmetic assignments" in err


@pytest.mark.asyncio
async def test_run_text_rejects_call_expressions():
"""A result assignment cannot invoke functions or builtins."""
out, err = await RunCode.run_text("result = __import__('os').getcwd()")
assert out == ""
assert "unsupported expression" in err


@pytest.mark.asyncio
async def test_run_text_rejects_attribute_access():
"""Attribute access must not be evaluated."""
out, err = await RunCode.run_text("result = ().__class__.__mro__")
assert out == ""
assert "unsupported expression" in err


@pytest.mark.asyncio
async def test_run_text_basic_functionality():
"""Basic run_text functionality should still work after the fix."""
out, err = await RunCode.run_text("result = 1 + 1")
assert out == "2"
assert err == ""

out, err = await RunCode.run_text("result = 'helloworld'")
assert out == "helloworld"
assert err == ""

out, err = await RunCode.run_text("result = 1 / 0")
assert out == ""
assert "division by zero" in err


@pytest.mark.asyncio
async def test_run_text_returns_string():
"""After sandboxing, run_text returns string representations of results."""
out, err = await RunCode.run_text("result = [1, 2, 3]")
assert out == "[1, 2, 3]"
assert err == ""
Loading