# Created March 16, 2023 21:08
# Gist bborn/4c6e769e74f3d6397452bec3c9f294e6 (can be saved locally or opened in GitHub Desktop).
# Python REPL tool using RestrictedPython
# This file contains hidden or bidirectional Unicode text that may be interpreted
# or compiled differently than what appears below. To review, open the file in an
# editor that reveals hidden Unicode characters.
# (Learn more about bidirectional Unicode characters.)
"""A tool for running python code in a REPL.""" | |
import ast | |
from io import StringIO | |
import sys | |
from typing import Dict, Optional | |
from pydantic import Field, root_validator | |
from langchain.tools.base import BaseTool | |
from AccessControl import ModuleSecurityInfo as MSI | |
from RestrictedPython import compile_restricted, safe_builtins | |
from AccessControl.ZopeGuards import guarded_import, guarded_iter | |
from AccessControl.SecurityInfo import allow_module | |
def default_guarded_getitem(ob, index):
    """Subscript ``ob`` with ``index`` and hand back the result unchecked.

    Used as the ``_getitem_`` entry of the REPL globals. No access policy
    is applied: the lookup is delegated directly to the object, so any
    ``KeyError``/``IndexError``/``TypeError`` it raises propagates as-is.
    """
    item = ob[index]
    return item
# Declare json.loads and json.dumps public in AccessControl's module
# security registry (MSI returns the same registered info object for a
# given module name, so one multi-name declaration covers both).
MSI('json').declarePublic('loads', 'dumps')

# Modules registered with AccessControl's allow_module() at import time.
allowed_modules = [
    # sub-packages
    "langchain.agents",
    "langchain.llms",
    "langchain.text_splitter",
    "langchain.document_loaders",
    # top-level package
    "langchain",
    # prompt modules
    "langchain.prompts.few_shot",
    "langchain.prompts.prompt",
]

for module in allowed_modules:
    allow_module(module)
class PythonAstREPLTool(BaseTool):
    """A tool for running python code in a REPL.

    The query is parsed with :mod:`ast`. Every statement except the last is
    executed; the last statement is evaluated and its value returned when it
    is an expression, or executed (returning ``""``) otherwise. Any exception
    raised while parsing or running the query is caught and its message is
    returned as the tool output.
    """

    name = "python_repl_ast"
    description = (
        "A Python shell. Use this to execute python commands. "
        "Input should be a valid python command. "
        "When using this tool, sometimes output is abbreviated - "
        "make sure it does not look abbreviated before using it in your answer."
    )
    # Declared for API compatibility but not read by ``_run``: a fresh
    # globals mapping is built on every call.
    globals: Optional[Dict] = Field(default_factory=dict)
    # Passed to exec/eval as the locals mapping, so names bound by a query
    # are stored here and remain visible to later ``_run`` calls.
    locals: Optional[Dict] = Field(default_factory=dict)

    @root_validator(pre=True)
    def validate_python_version(cls, values: Dict) -> Dict:
        """Validate valid python version.

        Raises:
            ValueError: when running on Python older than 3.9, which lacks
                ``ast.unparse`` (used by ``_run``).
        """
        if sys.version_info < (3, 9):
            raise ValueError(
                "This tool relies on Python 3.9 or higher "
                "(as it uses new functionality in the `ast` module), "
                f"you have Python version: {sys.version}"
            )
        return values

    def _run(self, query: str) -> str:
        """Use the tool.

        Args:
            query: Python source code to run.

        Returns:
            The value of the final statement when it is an expression (the
            raw object, not converted to ``str``), ``""`` when it is not, or
            ``str(exc)`` for any exception raised while parsing or running.
        """
        local_ns = self.locals
        # Copy instead of mutating RestrictedPython's module-level
        # ``safe_builtins`` dict, which is shared by every user in-process.
        builtins_ns = dict(safe_builtins)
        builtins_ns["__import__"] = guarded_import
        global_ns = {
            "__builtins__": builtins_ns,
            "str": str,
            "dict": dict,
            "len": len,
            "_getitem_": default_guarded_getitem,
            "_getiter_": guarded_iter,
        }
        # SECURITY NOTE(review): ``compile_restricted`` is imported by this
        # module but not used here -- the query runs through the plain
        # built-in ``exec``/``eval``. The only restriction in effect is the
        # reduced ``__builtins__`` mapping; the RestrictedPython AST
        # transformer never runs, so the ``_getitem_``/``_getiter_`` hooks
        # above are never invoked and attribute access (dunder attributes
        # included) is unguarded. Do not rely on this as a sandbox for
        # untrusted code until the query is compiled with compile_restricted
        # and the matching guard hooks are supplied.
        try:
            tree = ast.parse(query)
            leading = ast.Module(tree.body[:-1], type_ignores=[])
            exec(ast.unparse(leading), global_ns, local_ns)  # type: ignore
            final_stmts = tree.body[-1:]
            final_src = ast.unparse(  # type: ignore
                ast.Module(final_stmts, type_ignores=[])
            )
            if final_stmts and isinstance(final_stmts[0], ast.Expr):
                # Only genuine expressions are evaluated, so a runtime error
                # in the final statement is reported once rather than
                # triggering a second execution through ``exec``.
                return eval(final_src, global_ns, local_ns)
            exec(final_src, global_ns, local_ns)
            return ""
        except Exception as e:
            return str(e)

    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously.

        Raises:
            NotImplementedError: always; only the synchronous ``_run`` exists.
        """
        raise NotImplementedError("PythonReplTool does not support async")
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.