diff options
-rw-r--r-- | TurtleArt/taprimitive.py | 7 | ||||
-rw-r--r-- | TurtleArt/tatype.py | 20 | ||||
-rw-r--r-- | util/codegen.py | 1 |
3 files changed, 24 insertions, 4 deletions
diff --git a/TurtleArt/taprimitive.py b/TurtleArt/taprimitive.py index ca0e73a..6b0c515 100644 --- a/TurtleArt/taprimitive.py +++ b/TurtleArt/taprimitive.py @@ -31,7 +31,7 @@ from taturtle import (Turtle, Turtles) from tatype import (ACTION_AST, BOX_AST, convert, get_call_ast, get_converter, get_type, is_bound_instancemethod, is_instancemethod, is_staticmethod, TATypeError, Type, TypeDisjunction, - TYPE_COLOR, TYPE_FLOAT, TYPE_OBJECT) + TypedLambda, TYPE_COLOR, TYPE_FLOAT, TYPE_OBJECT) from tautils import debug_output from tawindow import (global_objects, TurtleArtWindow) from util import ast_extensions @@ -850,9 +850,8 @@ class ArgSlot(object): # don't call and pass on the callable if convert_to_ast: lambda_body = func_prim.get_ast() - called_argument = ast.Lambda( - body=lambda_body, args=[], - vararg=None, kwarg=None, defaults=[]) + called_argument = TypedLambda(body=lambda_body, + return_type=lambda_body.return_type) else: called_argument = func_prim diff --git a/TurtleArt/tatype.py b/TurtleArt/tatype.py index 9cb5433..add26e3 100644 --- a/TurtleArt/tatype.py +++ b/TurtleArt/tatype.py @@ -342,6 +342,26 @@ class TypedCall(ast.Call): return self._return_type +class TypedLambda(ast.Lambda): + """ Like a Lambda AST, but with a return type """ + + def __init__(self, body=None, args=None, return_type=None): + + if args is None: + args = ast.arguments(args=[], vararg=None, kwarg=None, defaults=[]) + + ast.Lambda.__init__(self, body=body, args=args) + + self._return_type = return_type + + @property + def return_type(self): + if self._return_type is None: + return get_type(self.func) + else: + return self._return_type + + def get_call_ast(func_name, args=None, kwargs=None, return_type=None): """ Return an AST representing the call to a function with the name func_name, passing it the arguments args (given as a list) and the diff --git a/util/codegen.py b/util/codegen.py index f892308..43e9707 100644 --- a/util/codegen.py +++ b/util/codegen.py @@ -503,6 +503,7 @@ class SourceGenerator(NodeVisitor): self.signature(node.args) self.write(': ') self.visit(node.body) + visit_TypedLambda = visit_Lambda def visit_Ellipsis(self, node): self.write('Ellipsis') |