From 97b945e1e088c2e51d9a6233be49c401040883e9 Mon Sep 17 00:00:00 2001 From: Marion Date: Wed, 04 Sep 2013 15:08:07 +0000 Subject: Merge branch 'type-system' into type-system-while-until Conflicts: TurtleArt/taprimitive.py -- preserve the semantics of both changes --- diff --git a/TurtleArt/tabasics.py b/TurtleArt/tabasics.py index db27eaf..84a374d 100644 --- a/TurtleArt/tabasics.py +++ b/TurtleArt/tabasics.py @@ -132,13 +132,6 @@ class Palettes(): self.tw = turtle_window self.prim_cache = { - "check_number": Primitive(self.check_number, - return_type=TYPE_NUMBER, # TODO make this function obsolete (the return type is actually nonsense) - arg_descs=[ArgSlot(TYPE_OBJECT)], export_me=False), - "convert_for_cmp": Primitive(Primitive.convert_for_cmp, - constant_args={'decimal_point': self.tw.decimal_point}), - "convert_to_number": Primitive(Primitive.convert_to_number, - constant_args={'decimal_point': self.tw.decimal_point}), "minus": Primitive(Primitive.minus, return_type=TYPE_NUMBER, arg_descs=[ArgSlot(TYPE_NUMBER)]) @@ -980,9 +973,12 @@ buttons')) default=_('action'), logo_command='to action', help_string=_('top of nameable action stack')) - self.tw.lc.def_prim('nop3', 1, Primitive(self.tw.lc.prim_define_stack)) + self.tw.lc.def_prim('nop3', 1, + Primitive(self.tw.lc.prim_define_stack, + arg_descs=[ArgSlot(TYPE_STRING)])) - primitive_dictionary['stack'] = Primitive(self.tw.lc.prim_invoke_stack) + primitive_dictionary['stack'] = Primitive(self.tw.lc.prim_invoke_stack, + arg_descs=[ArgSlot(TYPE_STRING)]) palette.add_block('stack', style='basic-style-1arg', label=_('action'), @@ -991,10 +987,8 @@ buttons')) logo_command='action', default=_('action'), help_string=_('invokes named action stack')) - self.tw.lc.def_prim('stack', 1, - Primitive(self.tw.lc.prim_invoke_stack), True) + self.tw.lc.def_prim('stack', 1, primitive_dictionary['stack'], True) - primitive_dictionary['setbox'] = Primitive(self.tw.lc.prim_set_box) palette.add_block('storeinbox1', hidden=True, style='basic-style-1arg', @@ -1005,7 +999,8 @@ buttons')) logo_command='make "box1', help_string=_('stores numeric value in Variable 1')) self.tw.lc.def_prim('storeinbox1', 1, - Primitive(self.tw.lc.prim_set_box, constant_args={0: 'box1'})) + Primitive(self.tw.lc.prim_set_box, + arg_descs=[ConstantArg('box1'), ArgSlot(TYPE_OBJECT)])) palette.add_block('storeinbox2', hidden=True, @@ -1017,7 +1012,8 @@ buttons')) logo_command='make "box2', help_string=_('stores numeric value in Variable 2')) self.tw.lc.def_prim('storeinbox2', 1, - Primitive(self.tw.lc.prim_set_box, constant_args={0: 'box2'})) + Primitive(self.tw.lc.prim_set_box, + arg_descs=[ConstantArg('box2'), ArgSlot(TYPE_OBJECT)])) palette.add_block('box1', hidden=True, @@ -1028,7 +1024,8 @@ buttons')) help_string=_('Variable 1 (numeric value)'), value_block=True) self.tw.lc.def_prim('box1', 0, - Primitive(self.tw.lc.prim_get_box, constant_args={0: 'box1'})) + Primitive(self.tw.lc.prim_get_box, + arg_descs=[ConstantArg('box1')])) palette.add_block('box2', hidden=True, @@ -1039,8 +1036,11 @@ buttons')) help_string=_('Variable 2 (numeric value)'), value_block=True) self.tw.lc.def_prim('box2', 0, - Primitive(self.tw.lc.prim_get_box, constant_args={0: 'box2'})) + Primitive(self.tw.lc.prim_get_box, + arg_descs=[ConstantArg('box2')])) + primitive_dictionary['setbox'] = Primitive(self.tw.lc.prim_set_box, + arg_descs=[ArgSlot(TYPE_STRING), ArgSlot(TYPE_OBJECT)]) palette.add_block('storein', style='basic-style-2arg', label=[_('store in'), _('box'), _('value')], @@ -1050,10 +1050,12 @@ buttons')) default=[_('my box'), 100], help_string=_('stores numeric value in named \ variable')) - self.tw.lc.def_prim('storeinbox', 2, - Primitive(self.tw.lc.prim_set_box)) + self.tw.lc.def_prim('storeinbox', 2, primitive_dictionary['setbox']) - primitive_dictionary['box'] = Primitive(self.tw.lc.prim_get_box) + primitive_dictionary['box'] = Primitive(self.tw.lc.prim_get_box, + return_type=or_(TYPE_OBJECT, TYPE_STRING, TYPE_NUMBER, TYPE_FLOAT, + TYPE_INT, TYPE_NUMERIC_STRING, TYPE_CHAR, TYPE_COLOR), + arg_descs=[ArgSlot(TYPE_STRING)]) palette.add_block('box', style='number-style-1strarg', hidden=True, @@ -1064,8 +1066,7 @@ variable')) logo_command='box', value_block=True, help_string=_('named variable (numeric value)')) - self.tw.lc.def_prim('box', 1, - Primitive(self.tw.lc.prim_get_box)) + self.tw.lc.def_prim('box', 1, primitive_dictionary['box']) palette.add_block('hat1', hidden=True, @@ -1076,7 +1077,7 @@ variable')) help_string=_('top of Action 1 stack')) self.tw.lc.def_prim('nop1', 0, Primitive(self.tw.lc.prim_define_stack, - constant_args={0: 'stack1'})) + arg_descs=[ConstantArg('stack1')])) palette.add_block('hat2', hidden=True, @@ -1087,7 +1088,7 @@ variable')) help_string=_('top of Action 2 stack')) self.tw.lc.def_prim('nop2', 0, Primitive(self.tw.lc.prim_define_stack, - constant_args={0: 'stack2'})) + arg_descs=[ConstantArg('stack2')])) palette.add_block('stack1', hidden=True, @@ -1098,7 +1099,7 @@ variable')) help_string=_('invokes Action 1 stack')) self.tw.lc.def_prim('stack1', 0, Primitive(self.tw.lc.prim_invoke_stack, - constant_args={0: 'stack1'}), + arg_descs=[ConstantArg('stack1')]), True) palette.add_block('stack2', @@ -1110,7 +1111,7 @@ variable')) help_string=_('invokes Action 2 stack')) self.tw.lc.def_prim('stack2', 0, Primitive(self.tw.lc.prim_invoke_stack, - constant_args={0: 'stack2'}), + arg_descs=[ConstantArg('stack2')]), True) def _trash_palette(self): @@ -1207,29 +1208,6 @@ variable')) break self.tw.lc.ireturn() yield True - - def check_number(self, value): - ''' Check if value is a number. If yes, return the value. If no, - raise a logoerror. ''' - if not _num_type(value): - raise logoerror("#notanumber") - return value - - def check_non_negative(self, x, msg="#negroot"): - ''' Raise a logoerror iff x is negative. Otherwise, return x - unchanged. - msg -- the name of the logoerror message ''' - if x < 0: - raise logoerror(msg) - return x - - def check_non_zero(self, x, msg="#zerodivide"): - ''' Raise a logoerror iff x is zero. Otherwise, return x - unchanged. - msg -- the name of the logoerror message ''' - if x == 0: - raise logoerror(msg) - return x def after_right(self, *ignored_args): if self.tw.lc.update_values: diff --git a/TurtleArt/taexportpython.py b/TurtleArt/taexportpython.py index 71c4f9b..0d1e27a 100644 --- a/TurtleArt/taexportpython.py +++ b/TurtleArt/taexportpython.py @@ -30,7 +30,8 @@ import util.codegen as codegen #from ast_pprint import * # only used for debugging, safe to comment out from talogo import LogoCode -from taprimitive import (Primitive, PyExportError, value_to_ast) +from taprimitive import (ast_yield_true, Primitive, PyExportError, + value_to_ast) from tautils import (debug_output, find_group, find_top_block, get_stack_name) @@ -106,7 +107,8 @@ def _action_stack_to_python(block, lc, name="start"): name -- the name of the action stack (defaults to "start") """ # traverse the block stack and get the AST for every block ast_list = _walk_action_stack(block, lc) - ast_list.append(_ast_yield_true()) + if not isinstance(ast_list[-1], ast.Yield): + ast_list.append(ast_yield_true()) action_stack_ast = ast.Module(body=ast_list) #debug_output(str(action_stack_ast)) @@ -207,8 +209,9 @@ def _walk_action_stack(top_block, lc, convert_me=True): # body of conditional or loop new_arg_asts = _walk_action_stack(conn, lc, convert_me=convert_me) - if prim == LogoCode.prim_loop: - new_arg_asts.append(_ast_yield_true()) + if (prim == LogoCode.prim_loop and + not isinstance(new_arg_asts[-1], ast.Yield)): + new_arg_asts.append(ast_yield_true()) arg_asts.append(new_arg_asts) else: # argument block @@ -239,7 +242,4 @@ def _indent(code, num_levels=1): new_line_list.append(indentation + line) return linesep.join(new_line_list) -def _ast_yield_true(): - return ast.Yield(value=ast.Name(id='True', ctx=ast.Load)) - diff --git a/TurtleArt/taprimitive.py b/TurtleArt/taprimitive.py index 6e23ad0..ca0e73a 100644 --- a/TurtleArt/taprimitive.py +++ b/TurtleArt/taprimitive.py @@ -28,10 +28,10 @@ from tacanvas import TurtleGraphics from taconstants import (Color, CONSTANTS) from talogo import (LogoCode, logoerror) from taturtle import (Turtle, Turtles) -from tatype import (convert, get_call_ast, get_converter, get_type, - is_bound_instancemethod, is_instancemethod, +from tatype import (ACTION_AST, BOX_AST, convert, get_call_ast, get_converter, + get_type, is_bound_instancemethod, is_instancemethod, is_staticmethod, TATypeError, Type, TypeDisjunction, - TYPE_FLOAT, TYPE_OBJECT) + TYPE_COLOR, TYPE_FLOAT, TYPE_OBJECT) from tautils import debug_output from tawindow import (global_objects, TurtleArtWindow) from util import ast_extensions @@ -187,12 +187,11 @@ class Primitive(object): if isinstance(slot, ArgSlot): filler = filler_list.pop(0) try: - value = slot.fill(filler,convert_to_ast=convert_to_ast) + const = slot.fill(filler,convert_to_ast=convert_to_ast) except TATypeError as error: break else: - new_slot_list.append(ConstantArg(value, - call_arg=slot.call_arg)) + new_slot_list.append(const) else: new_slot_list.append(slot) if error is None: @@ -205,10 +204,9 @@ class Primitive(object): for key in keywords: kwarg_desc = new_prim.kwarg_descs[key] if isinstance(kwarg_desc, ArgSlot): - value = kwarg_desc.fill(keywords[key], + const = kwarg_desc.fill(keywords[key], convert_to_ast=convert_to_ast) - # TODO don't we need the ConstantArg constructor here as well? - new_prim.kwarg_descs[key] = value + new_prim.kwarg_descs[key] = const return new_prim @@ -297,7 +295,7 @@ class Primitive(object): debug_output(" arg_asts: " + repr(arg_asts)) new_prim = self.fill_slots(arg_asts, kwarg_asts, convert_to_ast=True) if not new_prim.are_slots_filled(): - raise PyExportError("not enough arguments") + raise PyExportError("not enough arguments") # TODO better msg if Primitive._DEBUG: debug_output(" new_prim.arg_descs: " + repr(new_prim.arg_descs)) @@ -354,23 +352,21 @@ class Primitive(object): # boxes elif self == LogoCode.prim_set_box: - id_str = 'BOX[%s]' % (repr(ast_to_value(new_arg_asts[0]))) - target_ast = ast.Name(id=id_str, ctx=ast.Store) - value_ast = new_arg_asts[1] - assign_ast = ast.Assign(targets=[target_ast], value=value_ast) - return assign_ast + target_ast = ast.Subscript(value=BOX_AST, + slice=ast.Index(value=new_arg_asts[0]), ctx=ast.Store) + return ast.Assign(targets=[target_ast], value=new_arg_asts[1]) elif self == LogoCode.prim_get_box: - id_str = 'BOX[%s]' % (repr(ast_to_value(new_arg_asts[0]))) - return ast.Name(id=id_str, ctx=ast.Load) + return ast.Subscript(value=BOX_AST, + slice=ast.Index(value=new_arg_asts[0]), ctx=ast.Load) # action stacks elif self == LogoCode.prim_define_stack: return elif self == LogoCode.prim_invoke_stack: - stack_name = ast_to_value(new_arg_asts[0]) - stack_func_name = 'ACTION[%s]' % (repr(stack_name)) - stack_func = ast.Name(id=stack_func_name, ctx=ast.Load) - return get_call_ast('logo.icall', [stack_func]) + stack_func = ast.Subscript(value=ACTION_AST, + slice=ast.Index(value=new_arg_asts[0]), ctx=ast.Load) + call_ast = get_call_ast('logo.icall', [stack_func]) + return [call_ast, ast_yield_true()] # standard operators elif self.func.__name__ in Primitive.STANDARD_OPERATORS: @@ -381,7 +377,8 @@ class Primitive(object): return get_type(x)[0] == TYPE_FLOAT if ( not _is_float(new_arg_asts[0]) and not _is_float(new_arg_asts[1])): - new_arg_asts[0] = get_call_ast('float', [new_arg_asts[0]]) + new_arg_asts[0] = get_call_ast('float', [new_arg_asts[0]], + return_type=TYPE_FLOAT) if len(new_arg_asts) == 1: if isinstance(op, tuple): op = op[0] @@ -406,11 +403,8 @@ class Primitive(object): # square root elif self == Primitive.square_root: - return get_call_ast('sqrt', new_arg_asts, new_kwarg_asts) - - # type conversion # TODO remove when obsolete - elif self in (Primitive.convert_for_cmp, Primitive.convert_to_number): - return self.func(*new_arg_asts, **new_kwarg_asts) + return get_call_ast('sqrt', new_arg_asts, new_kwarg_asts, + return_type=self.return_type) # identity elif self == Primitive.identity: @@ -440,7 +434,8 @@ class Primitive(object): else: func_name = self.get_name_for_export() - return get_call_ast(func_name, new_arg_asts, new_kwarg_asts) + return get_call_ast(func_name, new_arg_asts, new_kwarg_asts, + return_type=self.return_type) def __eq__(self, other): """ Two Primitives are equal iff their all their properties are equal. @@ -594,77 +589,6 @@ class Primitive(object): return arg1 + arg2 @staticmethod - def convert_to_number(value, decimal_point='.'): - """ Convert value to a number. If value is an AST, another AST is - wrapped around it to represent the conversion, e.g., - Str(s='1.2') -> Call(func=Name('float'), args=[Str(s='1.2')]) - 1. Return all numbers (float, int, long) unchanged. - 2. Convert a string containing a number into a float. - 3. Convert a single character to its ASCII integer value. - 4. Extract the first element of a list and convert it to a number. - 5. Convert a Color to a float. - If the value cannot be converted to a number and the value is not - an AST, return None. If it is an AST, return an AST representing - `float(value)'. """ # TODO find a better solution - # 1. number - if isinstance(value, (float, int, long, ast.Num)): - return value - - converted = None - conversion_ast = None - convert_to_ast = False - if isinstance(value, ast.AST): - convert_to_ast = True - value_ast = value - value = ast_to_value(value_ast) - if isinstance(decimal_point, ast.AST): - decimal_point = ast_to_value(decimal_point) - - # 2./3. string - if isinstance(value, basestring): - if convert_to_ast: - conversion_ast = Primitive.convert_for_cmp(value_ast, - decimal_point) - if not isinstance(conversion_ast, ast.Num): - converted = None - else: - converted = Primitive.convert_for_cmp(value, decimal_point) - if not isinstance(converted, (float, int, long)): - converted = None - # 4. list - elif isinstance(value, list): - if value: - number = Primitive.convert_to_number(value[0]) - if convert_to_ast: - conversion_ast = number - else: - converted = number - else: - converted = None - if convert_to_ast: - conversion_ast = get_call_ast('float', [value_ast]) - # 5. Color - elif isinstance(value, Color): - converted = float(value) - if convert_to_ast: - conversion_ast = get_call_ast('float', [value_ast]) - else: - converted = None - if convert_to_ast: - conversion_ast = get_call_ast('float', [value_ast]) - - if convert_to_ast: - if conversion_ast is None: - return value_ast - else: - return conversion_ast - else: - if converted is None: - return value - else: - return converted - - @staticmethod def minus(arg1, arg2=None): """ If only one argument is given, change its sign. If two arguments are given, subtract the second from the first. """ @@ -725,59 +649,6 @@ class Primitive(object): return not arg @staticmethod - def convert_for_cmp(value, decimal_point='.'): - """ Convert value such that it can be compared to something else. If - value is an AST, another AST is wrapped around it to represent the - conversion, e.g., - Str(s='a') -> Call(func=Name('ord'), args=[Str(s='a')]) - 1. Convert a string containing a number into a float. - 2. Convert a single character to its ASCII integer value. - 3. Return all other values unchanged. """ - converted = None - conversion_ast = None - convert_to_ast = False - if isinstance(value, ast.AST): - convert_to_ast = True - value_ast = value - value = ast_to_value(value_ast) - if isinstance(decimal_point, ast.AST): - decimal_point = ast_to_value(decimal_point) - - if isinstance(value, basestring): - # 1. string containing a number - replaced = value.replace(decimal_point, '.') - try: - converted = float(replaced) - except ValueError: - pass - else: - if convert_to_ast: - conversion_ast = get_call_ast('float', [value_ast]) - - # 2. single character - if converted is None: - try: - converted = ord(value) - except TypeError: - pass - else: - if convert_to_ast: - conversion_ast = get_call_ast('ord', [value_ast]) - - # 3. normal string or other type of value (nothing to do) - - if convert_to_ast: - if conversion_ast is None: - return value_ast - else: - return conversion_ast - else: - if converted is None: - return value - else: - return converted - - @staticmethod def equals(arg1, arg2): """ Return arg1 == arg2 """ return arg1 == arg2 @@ -820,11 +691,6 @@ class Disjunction(tuple): return self -# make TypeDisjunction 'inherit' the methods of the abstract Disjunction class -TypeDisjunction.__repr__ = Disjunction.__repr__ -TypeDisjunction.get_alternatives = Disjunction.get_alternatives - - class PrimitiveDisjunction(Disjunction,Primitive): """ Disjunction of two or more Primitives. PrimitiveDisjunctions may not be nested. """ @@ -899,8 +765,9 @@ class ArgSlot(object): return (self, ) def fill(self, argument, convert_to_ast=False): - """ Try to fill this argument slot with the given argument. If there - is a type problem, raise a TATypeError. """ + """ Try to fill this argument slot with the given argument. Return + a ConstantArg containing the result. If there is a type problem, + raise a TATypeError. """ if isinstance(argument, ast.AST): convert_to_ast = True @@ -935,23 +802,26 @@ class ArgSlot(object): # check if the argument can fill this slot (type-wise) if wrapper is not None: - arg_type = get_type(wrapper)[0] + arg_types = get_type(wrapper)[0] bad_value = wrapper elif func is not None: - arg_type = get_type(func)[0] + arg_types = get_type(func)[0] bad_value = func else: - arg_type = get_type(argument)[0] + arg_types = get_type(argument)[0] bad_value = argument converter = None + if not isinstance(arg_types, TypeDisjunction): + arg_types = TypeDisjunction((arg_types, )) if isinstance(slot.type, TypeDisjunction): - for type_ in slot.type: - converter = get_converter(arg_type, type_) + slot_types = slot.type + else: + slot_types = TypeDisjunction((slot.type, )) + for old_type in arg_types: + for new_type in slot_types: + converter = get_converter(old_type, new_type) if converter is not None: break - else: - type_ = slot.type - converter = get_converter(arg_type, type_) # unable to convert, try next wrapper/ slot/ func if converter is None: continue @@ -1010,22 +880,23 @@ class ArgSlot(object): # 3. check the type and convert the argument if necessary try: - converted_argument = convert(wrapped_argument, type_, - converter=converter) + converted_argument = convert(wrapped_argument, + new_type, old_type=old_type, converter=converter) except TATypeError as error: # on failure, try next wrapper/ slot/ func bad_value = wrapped_argument continue else: # on success, return the result - return converted_argument + return ConstantArg(converted_argument, + value_type=new_type, call_arg=slot.call_arg) # if we haven't returned anything yet, then all alternatives failed if error is not None: raise error else: - raise TATypeError(bad_value=bad_value, bad_type=arg_type, - req_type=type_, + raise TATypeError(bad_value=bad_value, bad_type=old_type, + req_type=new_type, message="filling slot " + repr(self)) @@ -1038,9 +909,13 @@ class ConstantArg(object): """ A constant argument or keyword argument to a Primitive. It is independent of the block program structure. """ - def __init__(self, value, call_arg=True): + def __init__(self, value, call_arg=True, value_type=None): + """ call_arg -- call the value before returning it? + value_type -- the type of the value (from the TA type system). This + is useful to store e.g., the return type of call ASTs. """ self.value = value self.call_arg = call_arg + self.value_type = value_type def get(self, convert_to_ast=False): """ If call_arg is True and the value is callable, call the value @@ -1057,6 +932,14 @@ class ConstantArg(object): else: return self.value + def get_value_type(self): + """ If this ConstantArg has stored the type of its value, return + that. Else, use get_type(...) to guess the type of the value. """ + if self.value_type is None: + return get_type(self.value)[0] + else: + return self.value_type + def __repr__(self): return "ConstantArg(%s)" % (repr(self.value)) @@ -1113,7 +996,8 @@ def value_to_ast(value, *args_for_prim, **kwargs_for_prim): # call to the Color constructor with this object's values, # e.g., Color('red', 0, 50, 100) return get_call_ast('Color', [value.name, value.color, - value.shade, value.gray]) + value.shade, value.gray], + return_type=TYPE_COLOR) else: raise ValueError("unknown type of raw value: " + repr(type(value))) @@ -1121,7 +1005,9 @@ def ast_to_value(ast_object): """ Retrieve the value out of a value AST. Supported AST types: Num, Str, Name, List, Tuple, Set If no value can be extracted, return None. """ - if isinstance(ast_object, ast.Num): + if not isinstance(ast_object, ast.AST): + return ast_object + elif isinstance(ast_object, ast.Num): return ast_object.n elif isinstance(ast_object, ast.Str): return ast_object.s @@ -1135,6 +1021,9 @@ def ast_to_value(ast_object): else: return None +def ast_yield_true(): + return ast.Yield(value=ast.Name(id='True', ctx=ast.Load)) + def export_me(something): """ Return True iff this is not a Primitive or its export_me attribute diff --git a/TurtleArt/tatype.py b/TurtleArt/tatype.py index 164ba2a..9cb5433 100644 --- a/TurtleArt/tatype.py +++ b/TurtleArt/tatype.py @@ -49,7 +49,19 @@ class Type(object): class TypeDisjunction(tuple,Type): """ Disjunction of two or more Types (from the type hierarchy) """ - pass + + def __init__(self, iterable): + self = tuple(iterable) + + + def __str__(self): + s = ["("] + for disj in self: + s.append(str(disj)) + s.append(" or ") + s.pop() + s.append(")") + return "".join(s) TYPE_OBJECT = Type('object', 0) @@ -65,6 +77,9 @@ TYPE_STRING = Type('string', 9) # TODO add list types +BOX_AST = ast.Name(id='BOX', ctx=ast.Load) +ACTION_AST = ast.Name(id='ACTION', ctx=ast.Load) + def get_type(x): """ Return the most specific type in the type hierarchy that applies to x and a boolean indicating whether x is an AST. If the type cannot be @@ -101,6 +116,11 @@ def get_type(x): return (TYPE_OBJECT, True) else: return (get_type(value)[0], True) + elif isinstance(x, ast.Subscript): + if x.value == BOX_AST: + return (TypeDisjunction((TYPE_OBJECT, TYPE_STRING, TYPE_NUMBER, + TYPE_FLOAT, TYPE_INT, TYPE_NUMERIC_STRING, TYPE_CHAR, + TYPE_COLOR)), True) elif isinstance(x, ast.Call): if isinstance(x.func, ast.Name): if x.func.id == 'float': @@ -163,24 +183,6 @@ TYPE_CONVERTERS = { } -def get_call_ast(func_name, args=None, kwargs=None): - """ Return an AST representing the call to a function with the name - func_name, passing it the arguments args (given as a list) and the - keyword arguments kwargs (given as a dictionary). """ - if args is None: - args = [] - keywords = [] - if kwargs is not None: - for (key, value) in kwargs.iteritems(): - keywords.append(ast.keyword(arg=key, value=value)) - return ast.Call(func=ast.Name(id=func_name, - ctx=ast.Load), - args=args, - keywords=keywords, - starargs=None, - kwargs=None) - - class TATypeError(BaseException): """ TypeError with the types from the hierarchy, not with Python types """ @@ -299,11 +301,7 @@ def convert(x, new_type, old_type=None, converter=None): func = ast.Attribute(value=y, attr=converter.im_func.__name__, ctx=ast.Load) - return ast.Call(func=func, - args=[], - keywords={}, - starargs=None, - kwargs=None) + return get_call_ast(func) else: func_name = converter.__name__ return get_call_ast(func_name, [y]) @@ -320,3 +318,57 @@ def convert(x, new_type, old_type=None, converter=None): return _apply_converter(converter, x) +class TypedCall(ast.Call): + """ Like a Call AST, but with a return type """ + + def __init__(self, func, args=None, keywords=None, starargs=None, + kwargs=None, return_type=None): + + if args is None: + args = [] + if keywords is None: + keywords = [] + + ast.Call.__init__(self, func=func, args=args, keywords=keywords, + starargs=starargs, kwargs=kwargs) + + self._return_type = return_type + + @property + def return_type(self): + if self._return_type is None: + return get_type(self.func) + else: + return self._return_type + + +def get_call_ast(func_name, args=None, kwargs=None, return_type=None): + """ Return an AST representing the call to a function with the name + func_name, passing it the arguments args (given as a list) and the + keyword arguments kwargs (given as a dictionary). + func_name -- either the name of a callable as a string, or an AST + representing a callable expression + return_type -- if this is not None, return a TypedCall object with this + return type instead """ + if args is None: + args = [] + # convert keyword argument dict to a list of (key, value) pairs + keywords = [] + if kwargs is not None: + for (key, value) in kwargs.iteritems(): + keywords.append(ast.keyword(arg=key, value=value)) + # get or generate the AST representing the callable + if isinstance(func_name, ast.AST): + func_ast = func_name + else: + func_ast = ast.Name(id=func_name, ctx=ast.Load) + # if no return type is given, return a simple Call AST + if return_type is None: + return ast.Call(func=func_ast, args=args, keywords=keywords, + starargs=None, kwargs=None) + # if a return type is given, return a TypedCall AST + else: + return TypedCall(func=func_ast, args=args, keywords=keywords, + return_type=return_type) + + diff --git a/util/codegen.py b/util/codegen.py index 1bcc42f..f892308 100644 --- a/util/codegen.py +++ b/util/codegen.py @@ -393,6 +393,7 @@ class SourceGenerator(NodeVisitor): self.write('**') self.visit(node.kwargs) self.write(')') + visit_TypedCall = visit_Call def visit_Name(self, node): self.write(node.id) -- cgit v0.9.1