From fe7f3a066050a827e660462cb51e711bd8c48b11 Mon Sep 17 00:00:00 2001 From: Arnav Singhvi Date: Thu, 26 Sep 2024 16:40:07 -0700 Subject: [PATCH] patches python interpreter to support additional syntax --- dspy/predict/program_of_thought.py | 1 - dspy/primitives/python_interpreter.py | 237 ++++++++++++++++++++++---- 2 files changed, 207 insertions(+), 31 deletions(-) diff --git a/dspy/predict/program_of_thought.py b/dspy/predict/program_of_thought.py index 52d6ff1d5..4497af12e 100644 --- a/dspy/predict/program_of_thought.py +++ b/dspy/predict/program_of_thought.py @@ -156,7 +156,6 @@ def execute_code(self, code): interpreter = PythonInterpreter(action_space={"print": print}, import_white_list=self.import_white_list) try: output = str(code_prompt.execute(interpreter=interpreter)[0]) - print return code, output, None except Exception as e: return code, None, str(e) diff --git a/dspy/primitives/python_interpreter.py b/dspy/primitives/python_interpreter.py index 1b7456c7b..134b02c28 100644 --- a/dspy/primitives/python_interpreter.py +++ b/dspy/primitives/python_interpreter.py @@ -102,7 +102,7 @@ def __init__(self, action_space: Dict[str, Any], self.action_space = action_space self.state = self.action_space.copy() self.fuzz_state: Dict[str, Any] = {} - self.import_white_list = import_white_list or [] + self.import_white_list = import_white_list or ["math", "random", "datetime", "time", "string", "collections", "itertools", "functools", "typing", "enum", "json", "ast"] #default imports def execute(self, code: str, state: Optional[Dict[str, Any]] = None, fuzz_state: Optional[Dict[str, Any]] = None, @@ -183,6 +183,8 @@ def _execute_ast(self, expression: ast.AST) -> Any: elif isinstance(expression, ast.BinOp): # Binary Operator -> return the result value return self._execute_binop(expression) + elif isinstance(expression, ast.BoolOp): + return self._execute_condition(expression) elif isinstance(expression, ast.Call): # Function call -> return the value of the function call return self._execute_call(expression) @@ -212,9 +214,13 @@ def _execute_ast(self, expression: ast.AST) -> Any: elif isinstance(expression, ast.FunctionDef): self.state[expression.name] = expression return None + elif isinstance(expression, ast.GeneratorExp): + return self._execute_generatorexp(expression) elif isinstance(expression, ast.If): # If -> execute the right branch return self._execute_if(expression) + elif isinstance(expression, ast.IfExp): + return self._execute_ifexp(expression) elif isinstance(expression, ast.Import): # Import -> add imported names in self.state and return None. self._execute_import(expression) @@ -228,6 +234,8 @@ def _execute_ast(self, expression: ast.AST) -> Any: elif isinstance(expression, ast.JoinedStr): return "".join( [str(self._execute_ast(v)) for v in expression.values]) + elif isinstance(expression, ast.Lambda): + return self._execute_lambda(expression) elif isinstance(expression, ast.List): # List -> evaluate all elements return [self._execute_ast(elt) for elt in expression.elts] @@ -242,8 +250,27 @@ def _execute_ast(self, expression: ast.AST) -> Any: elif isinstance(expression, ast.Tuple): return tuple([self._execute_ast(elt) for elt in expression.elts]) elif isinstance(expression, ast.UnaryOp): - # Binary Operator -> return the result value return self._execute_unaryop(expression) + elif isinstance(expression, ast.While): + return self._execute_while(expression) + elif isinstance(expression, ast.ListComp): + return self._execute_listcomp(expression) + elif isinstance(expression, ast.DictComp): + return self._execute_dictcomp(expression) + elif isinstance(expression, ast.SetComp): + return self._execute_setcomp(expression) + elif isinstance(expression, ast.Break): + raise BreakException() + elif isinstance(expression, ast.Continue): + raise ContinueException() + elif isinstance(expression, ast.Try): + return self._execute_try(expression) + elif isinstance(expression, ast.Raise): + return self._execute_raise(expression) + elif isinstance(expression, ast.Pass): + return None + elif isinstance(expression, ast.Assert): + return self._execute_assert(expression) else: # For now we refuse anything else. Let's add things as we need # them. @@ -353,39 +380,47 @@ def _execute_condition(self, condition): elif isinstance(condition.op, ast.Or): results = [self._execute_ast(value) for value in condition.values] return any(results) - else: #TODO - add any other BoolOps missing + else: raise InterpreterError(f"Boolean operator {condition.op} is not supported") elif isinstance(condition, ast.Compare): if len(condition.ops) > 1: raise InterpreterError("Cannot evaluate conditions with multiple operators") - if len(condition.ops) > 1: - raise InterpreterError( - "Cannot evaluate conditions with multiple operators") - left = self._execute_ast(condition.left) - comparator = condition.ops[0] - right = self._execute_ast(condition.comparators[0]) - if isinstance(comparator, ast.Eq): - return left == right - elif isinstance(comparator, ast.NotEq): - return left != right - elif isinstance(comparator, ast.Lt): - return left < right - elif isinstance(comparator, ast.LtE): - return left <= right - elif isinstance(comparator, ast.Gt): - return left > right - elif isinstance(comparator, ast.GtE): - return left >= right - elif isinstance(comparator, ast.Is): - return left is right - elif isinstance(comparator, ast.IsNot): - return left is not right - elif isinstance(comparator, ast.In): - return left in right - elif isinstance(comparator, ast.NotIn): - return left not in right + left = self._execute_ast(condition.left) + comparator = condition.ops[0] + right = self._execute_ast(condition.comparators[0]) + if isinstance(comparator, ast.Eq): + return left == right + elif isinstance(comparator, ast.NotEq): + return left != right + elif isinstance(comparator, ast.Lt): + return left < right + elif isinstance(comparator, ast.LtE): + return left <= right + elif isinstance(comparator, ast.Gt): + return left > right + elif isinstance(comparator, ast.GtE): + return left >= right + elif isinstance(comparator, ast.Is): + return left is right + elif isinstance(comparator, ast.IsNot): + return left is not right + elif isinstance(comparator, ast.In): + return left in right + elif isinstance(comparator, ast.NotIn): + return left not in right + else: + raise InterpreterError("Unsupported comparison operator") + elif isinstance(condition, ast.UnaryOp): + return self._execute_unaryop(condition) + elif isinstance(condition, ast.Name): + return bool(self._execute_ast(condition)) + elif isinstance(condition, ast.Call): + return bool(self._execute_ast(condition)) + elif isinstance(condition, ast.Constant): + return bool(condition.value) else: - raise InterpreterError("Unsupported condition type") + raise InterpreterError(f"Unsupported condition type: {type(condition).__name__}") + def _execute_if(self, if_statement: ast.If): result = None @@ -400,6 +435,13 @@ def _execute_if(self, if_statement: ast.If): if line_result is not None: result = line_result return result + + def _execute_ifexp(self, ifexp: ast.IfExp) -> Any: + test_result = self._execute_condition(ifexp.test) + if test_result: + return self._execute_ast(ifexp.body) + else: + return self._execute_ast(ifexp.orelse) def _execute_for(self, for_statement: ast.For): result = None @@ -427,6 +469,16 @@ def _execute_import_from(self, import_from: ast.ImportFrom): imported_module = importlib.import_module(import_from.module) alias = import_name.asname or import_name.name self.state[alias] = getattr(imported_module, import_name.name) + + def _execute_lambda(self, lambda_node: ast.Lambda) -> Any: + def lambda_function(*args): + old_state = self.state.copy() + for param, arg in zip(lambda_node.args.args, args): + self.state[param.arg] = arg + result = self._execute_ast(lambda_node.body) + self.state = old_state # Restore the state + return result + return lambda_function def _validate_import(self, full_name: str): tmp_name = "" @@ -465,6 +517,12 @@ def _execute_binop(self, binop: ast.BinOp): return left << right elif isinstance(operator, ast.RShift): return left >> right + elif isinstance(operator, ast.BitAnd): + return left & right + elif isinstance(operator, ast.BitOr): + return left | right + elif isinstance(operator, ast.BitXor): + return left ^ right elif isinstance(operator, ast.MatMult): return left @ right else: @@ -480,8 +538,127 @@ def _execute_unaryop(self, unaryop: ast.UnaryOp): return -operand elif isinstance(operator, ast.Not): return not operand + elif isinstance(operator, ast.Invert): + return ~operand else: raise InterpreterError(f"Operator not supported: {operator}") + + def _execute_listcomp(self, comp: ast.ListComp): + return [self._execute_comp(comp.elt, comp.generators)] + + def _execute_dictcomp(self, comp: ast.DictComp): + return {self._execute_comp(comp.key, comp.generators): self._execute_comp(comp.value, comp.generators)} + + def _execute_setcomp(self, comp: ast.SetComp): + return {self._execute_comp(comp.elt, comp.generators)} + + def _execute_comp(self, elt, generators): + if not generators: + return self._execute_ast(elt) + gen = generators[0] + result = [] + for value in self._execute_ast(gen.iter): + self._assign(gen.target, value) + if all(self._execute_condition(if_cond) for if_cond in gen.ifs): + result.extend(self._execute_comp(elt, generators[1:])) + return result + + def _execute_generatorexp(self, genexp: ast.GeneratorExp): + def generator(): + for value in self._execute_comp(genexp.elt, genexp.generators): + yield value + return generator() + + def _execute_while(self, while_statement: ast.While): + result = None + while self._execute_condition(while_statement.test): + for line in while_statement.body: + line_result = self._execute_ast(line) + if line_result is not None: + result = line_result + if isinstance(line, (ast.Break, ast.Continue)): + break + else: + continue + break + return result + + def _execute_for(self, for_statement: ast.For): + class BreakException(Exception): + pass + + class ContinueException(Exception): + pass + result = None + try: + for value in self._execute_ast(for_statement.iter): + self._assign(for_statement.target, value) + try: + for line in for_statement.body: + line_result = self._execute_ast(line) + if line_result is not None: + result = line_result + except ContinueException: + continue + except BreakException: + pass + return result + + def _execute_while(self, while_statement: ast.While): + class BreakException(Exception): + pass + + class ContinueException(Exception): + pass + result = None + try: + while self._execute_condition(while_statement.test): + try: + for line in while_statement.body: + line_result = self._execute_ast(line) + if line_result is not None: + result = line_result + except ContinueException: + continue + except BreakException: + pass + return result + + def _execute_try(self, try_statement: ast.Try): + try: + for line in try_statement.body: + self._execute_ast(line) + except Exception as e: + handled = False + for handler in try_statement.handlers: + if handler.type is None or isinstance(e, self._execute_ast(handler.type)): + if handler.name: + self.state[handler.name.id] = e + for line in handler.body: + self._execute_ast(line) + handled = True + break + if not handled: + raise + finally: + for line in try_statement.finalbody: + self._execute_ast(line) + + def _execute_raise(self, raise_statement: ast.Raise): + if raise_statement.exc: + exception = self._execute_ast(raise_statement.exc) + raise exception + else: + raise + + def _execute_assert(self, assert_statement: ast.Assert): + test_result = self._execute_condition(assert_statement.test) + if not test_result: + if assert_statement.msg: + msg = self._execute_ast(assert_statement.msg) + raise AssertionError(msg) + else: + raise AssertionError def _get_value_from_state(self, key: str) -> Any: if key in self.state: