"""Map AST nodes back to the exact source text that produced them.

`patch_ast` walks a parsed tree and gives every node a ``region`` field
holding ``(start, end)`` offsets into the original source and, optionally,
a ``sorted_children`` field that interleaves child nodes with the
whitespace, comments and tokens found between them.
"""
import collections
import re
import warnings

from rope.base import ast, codeanalyze, exceptions


def get_patched_ast(source, sorted_children=False):
    """Parse `source` and return its AST patched by `patch_ast`.

    Every node gets a ``region`` field.  The ``sorted_children`` field is
    added as well, but only if `sorted_children` is true.
    """
    return patch_ast(ast.parse(source), source, sorted_children)


def patch_ast(node, source, sorted_children=False):
    """Patch `node` and all of its descendants in place and return it.

    After calling, each node in `node` has a ``region`` field: a tuple of
    the start and end offsets in `source` of the code that generated it.

    If `sorted_children` is true, every node also gets a
    ``sorted_children`` field: a sequence (a ``collections.deque`` for
    handled nodes) of its child nodes interleaved with the whitespace,
    comments and literal tokens that occur between them.

    A node that already has a ``region`` is assumed to be patched and is
    returned unchanged.
    """
    if hasattr(node, 'region'):
        return node
    walker = _PatchingASTWalker(source, children=sorted_children)
    ast.call_for_nodes(node, walker)
    return node


def node_region(patched_ast_node):
    """Return the ``(start, end)`` source offsets of a patched AST node."""
    return patched_ast_node.region


def write_ast(patched_ast_node):
    """Rebuild source text from a node patched with ``sorted_children``.

    If the node was patched with `sorted_children` turned off, use
    `node_region` to slice the module source instead.
    """
    result = []
    for child in patched_ast_node.sorted_children:
        if isinstance(child, ast.AST):
            result.append(write_ast(child))
        else:
            result.append(child)
    return ''.join(result)


class MismatchedTokenError(exceptions.RopeError):
    """Raised when a token expected from the AST is not found in the source."""


class _PatchingASTWalker(object):
    """Visitor that assigns ``region`` (and ``sorted_children``) to nodes.

    The walker keeps a cursor (``self.source.offset``) into the source.
    For each node it consumes the node's tokens in the order they are
    listed, moving the cursor forward.
    """

    def __init__(self, source, children=False):
        self.source = _Source(source)
        self.children = children
        self.lines = codeanalyze.SourceLinesAdapter(source)
        # One deque of not-yet-consumed children per node currently being
        # handled (innermost last); read by `_find_next_statement_start`.
        self.children_stack = []

    # Sentinels put in a node's child list to request consuming a numeric
    # or string literal, whose exact spelling the AST does not preserve.
    Number = object()
    String = object()

    def __call__(self, node):
        """Dispatch `node` to its ``_<NodeClassName>`` handler method.

        Nodes without a handler trigger a `RuntimeWarning` and receive an
        empty region at the current cursor position.
        """
        method = getattr(self, '_' + node.__class__.__name__, None)
        if method is not None:
            return method(node)
        # ???: Unknown node; what should we do here?
        warnings.warn('Unknown node type <%s>; please report!'
                      % node.__class__.__name__, RuntimeWarning)
        node.region = (self.source.offset, self.source.offset)
        if self.children:
            node.sorted_children = ast.get_children(node)

    def _handle(self, node, base_children, eat_parens=False, eat_spaces=False):
        """Consume `node`'s tokens and set its ``region`` (and children).

        `base_children` lists, in source order, what makes up `node`:
        child AST nodes (visited recursively), literal token strings, the
        `Number`/`String` sentinels, or None entries (skipped).

        When `eat_parens` is true, an opening parenthesis found between the
        starting cursor position and the node's first token, together with
        the next closing parenthesis, is included in the region.  When
        `eat_spaces` is true, the region is widened to span the whole
        source, including leading and trailing whitespace and comments.
        """
        if hasattr(node, 'region'):
            # ???: The same node was seen twice; what should we do?
            warnings.warn(
                'Node <%s> has been already patched; please report!' %
                node.__class__.__name__, RuntimeWarning)
            return
        base_children = collections.deque(base_children)
        self.children_stack.append(base_children)
        children = collections.deque()
        # Text found between consecutive tokens; scanned later for
        # unbalanced parentheses.
        formats = []
        suspected_start = self.source.offset
        start = suspected_start
        first_token = True
        while base_children:
            child = base_children.popleft()
            if child is None:
                continue
            offset = self.source.offset
            if isinstance(child, ast.AST):
                ast.call_for_nodes(child, self)
                token_start = child.region[0]
            else:
                if child is self.String:
                    # Bound the search so adjacent-string matching cannot
                    # run into the next pending statement.
                    region = self.source.consume_string(
                        end=self._find_next_statement_start())
                elif child is self.Number:
                    region = self.source.consume_number()
                elif child == '!=':
                    # INFO: This has been added to handle deprecated ``<>``
                    region = self.source.consume_not_equal()
                else:
                    region = self.source.consume(child)
                child = self.source[region[0]:region[1]]
                token_start = region[0]
            if not first_token:
                formats.append(self.source[offset:token_start])
                if self.children:
                    children.append(self.source[offset:token_start])
            else:
                first_token = False
                start = token_start
            if self.children:
                children.append(child)
        start = self._handle_parens(children, start, formats)
        if eat_parens:
            start = self._eat_surrounding_parens(
                children, suspected_start, start)
        if eat_spaces:
            if self.children:
                children.appendleft(self.source[0:start])
            end_spaces = self.source[self.source.offset:]
            self.source.consume(end_spaces)
            if self.children:
                children.append(end_spaces)
            start = 0
        if self.children:
            node.sorted_children = children
        node.region = (start, self.source.offset)
        self.children_stack.pop()

    def _handle_parens(self, children, start, formats):
        """Absorb parentheses left unbalanced by the gaps in `formats`.

        Consumes the ``)`` tokens still owed after the node and searches
        backwards from `start` for the ``(`` tokens owed before it.  Updates
        `children` when sorted children are tracked and returns the new,
        possibly earlier, start offset.

        Raises `MismatchedTokenError` if a required parenthesis is missing.
        """
        opens, closes = self._count_needed_parens(formats)
        old_end = self.source.offset
        new_end = None
        for _ in range(closes):
            new_end = self.source.consume(')')[1]
        if new_end is not None:
            if self.children:
                children.append(self.source[old_end:new_end])
        new_start = start
        for _ in range(opens):
            new_start = self.source.rfind_token('(', 0, new_start)
            if new_start is None:
                # Without this check None would be fed back into rindex()
                # (searching to the end of the source) and into slicing.
                raise MismatchedTokenError(
                    'Token <(> before offset %s cannot be matched' % start)
        if new_start != start:
            if self.children:
                children.appendleft(self.source[new_start:start])
            start = new_start
        return start

    def _eat_surrounding_parens(self, children, suspected_start, start):
        """Include a ``(`` before `start` and the next ``)`` in the region.

        Searches ``[suspected_start, start)`` for a ``(`` outside comments.
        If one is found, consumes the next ``)`` and returns the offset of
        the ``(``; otherwise returns `start` unchanged.
        """
        index = self.source.rfind_token('(', suspected_start, start)
        if index is not None:
            old_start = start
            old_offset = self.source.offset
            start = index
            if self.children:
                children.appendleft(self.source[start + 1:old_start])
                children.appendleft('(')
            token_start, token_end = self.source.consume(')')
            if self.children:
                children.append(self.source[old_offset:token_start])
                children.append(')')
        return start

    def _count_needed_parens(self, children):
        """Count parentheses left unbalanced by the text between tokens.

        `children` holds the source text found between consecutive tokens
        of a node.  Returns ``(unmatched_closes, unmatched_opens)``:
        ``unmatched_closes`` is the number of ``)`` with no preceding ``(``,
        so that many ``(`` must be found before the node.
        ``unmatched_opens`` is the number of ``(`` never closed, so that
        many ``)`` must be consumed after the node.  ``#`` comments and
        entries that start with a quote are ignored.
        """
        unmatched_closes = 0
        unmatched_opens = 0
        for child in children:
            if not isinstance(child, basestring):
                continue
            if child == '' or child[0] in '\'"':
                continue
            index = 0
            while index < len(child):
                char = child[index]
                if char == ')':
                    if unmatched_opens > 0:
                        unmatched_opens -= 1
                    else:
                        unmatched_closes += 1
                if char == '(':
                    unmatched_opens += 1
                if char == '#':
                    # Skip the rest of the comment line.
                    try:
                        index = child.index('\n', index)
                    except ValueError:
                        break
                index += 1
        return unmatched_closes, unmatched_opens

    def _find_next_statement_start(self):
        """Return the offset of the next pending statement.

        Scans the not-yet-consumed children of the nodes being handled,
        innermost first, for the first `ast.stmt`.  Falls back to the end
        of the source.
        """
        for children in reversed(self.children_stack):
            for child in children:
                if isinstance(child, ast.stmt):
                    return child.col_offset \
                        + self.lines.get_line_start(child.lineno)
        return len(self.source.source)

    # Source spelling of each operator node class; multi-word operators
    # are split into separate tokens by `_get_op`.
    _operators = {'And': 'and', 'Or': 'or', 'Add': '+', 'Sub': '-',
                  'Mult': '*', 'Div': '/', 'Mod': '%', 'Pow': '**',
                  'LShift': '<<', 'RShift': '>>', 'BitOr': '|',
                  'BitAnd': '&', 'BitXor': '^', 'FloorDiv': '//',
                  'Invert': '~', 'Not': 'not', 'UAdd': '+', 'USub': '-',
                  'Eq': '==', 'NotEq': '!=', 'Lt': '<', 'LtE': '<=',
                  'Gt': '>', 'GtE': '>=', 'Is': 'is', 'IsNot': 'is not',
                  'In': 'in', 'NotIn': 'not in'}

    def _get_op(self, node):
        """Return a fresh list of the source tokens for operator `node`."""
        return self._operators[node.__class__.__name__].split(' ')

    # ----------------------------------------------------------------
    # Node handlers.  Each builds the list of tokens and child nodes that
    # make up one node type, in source order, and passes it to `_handle`.
    # ----------------------------------------------------------------

    def _Attribute(self, node):
        self._handle(node, [node.value, '.', node.attr])

    def _Assert(self, node):
        children = ['assert', node.test]
        if node.msg:
            children.append(',')
            children.append(node.msg)
        self._handle(node, children)

    def _Assign(self, node):
        children = self._child_nodes(node.targets, '=')
        children.append('=')
        children.append(node.value)
        self._handle(node, children)

    def _AugAssign(self, node):
        children = [node.target]
        children.extend(self._get_op(node.op))
        children.extend(['=', node.value])
        self._handle(node, children)

    def _Repr(self, node):
        self._handle(node, ['`', node.value, '`'])

    def _BinOp(self, node):
        children = [node.left] + self._get_op(node.op) + [node.right]
        self._handle(node, children)

    def _BoolOp(self, node):
        self._handle(node, self._child_nodes(node.values,
                                             self._get_op(node.op)[0]))

    def _Break(self, node):
        self._handle(node, ['break'])

    def _Call(self, node):
        children = [node.func, '(']
        args = list(node.args) + node.keywords
        children.extend(self._child_nodes(args, ','))
        if node.starargs is not None:
            if args:
                children.append(',')
            children.extend(['*', node.starargs])
        if node.kwargs is not None:
            if args or node.starargs is not None:
                children.append(',')
            children.extend(['**', node.kwargs])
        children.append(')')
        self._handle(node, children)

    def _ClassDef(self, node):
        children = []
        if getattr(node, 'decorator_list', None):
            for decorator in node.decorator_list:
                children.append('@')
                children.append(decorator)
        children.extend(['class', node.name])
        if node.bases:
            children.append('(')
            children.extend(self._child_nodes(node.bases, ','))
            children.append(')')
        children.append(':')
        children.extend(node.body)
        self._handle(node, children)

    def _Compare(self, node):
        children = []
        children.append(node.left)
        for op, expr in zip(node.ops, node.comparators):
            children.extend(self._get_op(op))
            children.append(expr)
        self._handle(node, children)

    def _Delete(self, node):
        self._handle(node, ['del'] + self._child_nodes(node.targets, ','))

    def _Num(self, node):
        self._handle(node, [self.Number])

    def _Str(self, node):
        self._handle(node, [self.String])

    def _Continue(self, node):
        self._handle(node, ['continue'])

    def _Dict(self, node):
        children = []
        children.append('{')
        if node.keys:
            for index, (key, value) in enumerate(zip(node.keys,
                                                     node.values)):
                children.extend([key, ':', value])
                if index < len(node.keys) - 1:
                    children.append(',')
        children.append('}')
        self._handle(node, children)

    def _Ellipsis(self, node):
        self._handle(node, ['...'])

    def _Expr(self, node):
        self._handle(node, [node.value])

    def _Exec(self, node):
        children = []
        children.extend(['exec', node.body])
        if node.globals:
            children.extend(['in', node.globals])
        if node.locals:
            children.extend([',', node.locals])
        self._handle(node, children)

    def _ExtSlice(self, node):
        children = []
        for index, dim in enumerate(node.dims):
            if index > 0:
                children.append(',')
            children.append(dim)
        self._handle(node, children)

    def _For(self, node):
        children = ['for', node.target, 'in', node.iter, ':']
        children.extend(node.body)
        if node.orelse:
            children.extend(['else', ':'])
            children.extend(node.orelse)
        self._handle(node, children)

    def _ImportFrom(self, node):
        children = ['from']
        if node.level:
            children.append('.' * node.level)
        children.extend([node.module or '',  # see comment at rope.base.ast.walk
                         'import'])
        children.extend(self._child_nodes(node.names, ','))
        self._handle(node, children)

    def _alias(self, node):
        children = [node.name]
        if node.asname:
            children.extend(['as', node.asname])
        self._handle(node, children)

    def _FunctionDef(self, node):
        children = []
        # Newer ASTs call the field ``decorator_list``; older ones used
        # ``decorators``.
        try:
            decorators = getattr(node, 'decorator_list')
        except AttributeError:
            decorators = getattr(node, 'decorators', None)
        if decorators:
            for decorator in decorators:
                children.append('@')
                children.append(decorator)
        children.extend(['def', node.name, '(', node.args])
        children.extend([')', ':'])
        children.extend(node.body)
        self._handle(node, children)

    def _arguments(self, node):
        children = []
        args = list(node.args)
        # Defaults belong to the last parameters; pad the front with None.
        defaults = [None] * (len(args) - len(node.defaults)) + \
            list(node.defaults)
        for index, (arg, default) in enumerate(zip(args, defaults)):
            if index > 0:
                children.append(',')
            self._add_args_to_children(children, arg, default)
        if node.vararg is not None:
            if args:
                children.append(',')
            children.extend(['*', node.vararg])
        if node.kwarg is not None:
            if args or node.vararg is not None:
                children.append(',')
            children.extend(['**', node.kwarg])
        self._handle(node, children)

    def _add_args_to_children(self, children, arg, default):
        """Append one parameter (possibly a tuple parameter) and default."""
        if isinstance(arg, (list, tuple)):
            self._add_tuple_parameter(children, arg)
        else:
            children.append(arg)
        if default is not None:
            children.append('=')
            children.append(default)

    def _add_tuple_parameter(self, children, arg):
        """Append a parenthesized, possibly nested, tuple parameter."""
        children.append('(')
        for index, token in enumerate(arg):
            if index > 0:
                children.append(',')
            if isinstance(token, (list, tuple)):
                self._add_tuple_parameter(children, token)
            else:
                children.append(token)
        children.append(')')

    def _GeneratorExp(self, node):
        children = [node.elt]
        children.extend(node.generators)
        self._handle(node, children, eat_parens=True)

    def _comprehension(self, node):
        children = ['for', node.target, 'in', node.iter]
        if node.ifs:
            for if_ in node.ifs:
                children.append('if')
                children.append(if_)
        self._handle(node, children)

    def _Global(self, node):
        children = self._child_nodes(node.names, ',')
        children.insert(0, 'global')
        self._handle(node, children)

    def _If(self, node):
        if self._is_elif(node):
            children = ['elif']
        else:
            children = ['if']
        children.extend([node.test, ':'])
        children.extend(node.body)
        if node.orelse:
            if len(node.orelse) == 1 and self._is_elif(node.orelse[0]):
                # The nested ``If`` consumes its own ``elif`` keyword.
                pass
            else:
                children.extend(['else', ':'])
            children.extend(node.orelse)
        self._handle(node, children)

    def _is_elif(self, node):
        """Tell whether `node` is an ``If`` written as an ``elif`` clause."""
        if not isinstance(node, ast.If):
            return False
        offset = self.lines.get_line_start(node.lineno) + node.col_offset
        word = self.source[offset:offset + 4]
        # XXX: The recorded position of an ``elif`` branch may not point at
        # the ``elif`` keyword itself (e.g. it may point at the test that
        # follows), so also check the four characters ending one character
        # before it; this assumes a single space after ``elif``.
        alt_word = self.source[offset - 5:offset - 1]
        return 'elif' in (word, alt_word)

    def _IfExp(self, node):
        return self._handle(node, [node.body, 'if', node.test,
                                   'else', node.orelse])

    def _Import(self, node):
        children = ['import']
        children.extend(self._child_nodes(node.names, ','))
        self._handle(node, children)

    def _keyword(self, node):
        self._handle(node, [node.arg, '=', node.value])

    def _Lambda(self, node):
        self._handle(node, ['lambda', node.args, ':', node.body])

    def _List(self, node):
        self._handle(node, ['['] + self._child_nodes(node.elts, ',') + [']'])

    def _ListComp(self, node):
        children = ['[', node.elt]
        children.extend(node.generators)
        children.append(']')
        self._handle(node, children)

    def _Module(self, node):
        self._handle(node, list(node.body), eat_spaces=True)

    def _Name(self, node):
        self._handle(node, [node.id])

    def _Pass(self, node):
        self._handle(node, ['pass'])

    def _Print(self, node):
        children = ['print']
        if node.dest:
            children.extend(['>>', node.dest])
            # ``print >>f, x``: a comma separates the stream from values.
            if node.values:
                children.append(',')
        children.extend(self._child_nodes(node.values, ','))
        if not node.nl:
            # A trailing comma suppresses the newline.
            children.append(',')
        self._handle(node, children)

    def _Raise(self, node):
        children = ['raise']
        if node.type:
            children.append(node.type)
        if node.inst:
            children.append(',')
            children.append(node.inst)
        if node.tback:
            children.append(',')
            children.append(node.tback)
        self._handle(node, children)

    def _Return(self, node):
        children = ['return']
        if node.value:
            children.append(node.value)
        self._handle(node, children)

    def _Sliceobj(self, node):
        children = []
        for index, slice in enumerate(node.nodes):
            if index > 0:
                children.append(':')
            if slice:
                children.append(slice)
        self._handle(node, children)

    def _Index(self, node):
        self._handle(node, [node.value])

    def _Subscript(self, node):
        self._handle(node, [node.value, '[', node.slice, ']'])

    def _Slice(self, node):
        children = []
        if node.lower:
            children.append(node.lower)
        children.append(':')
        if node.upper:
            children.append(node.upper)
        if node.step:
            children.append(':')
            children.append(node.step)
        self._handle(node, children)

    def _TryFinally(self, node):
        children = []
        # ``try/except/finally`` is a TryFinally wrapping a single
        # TryExcept, which consumes the ``try:`` itself.
        if len(node.body) != 1 or not isinstance(node.body[0], ast.TryExcept):
            children.extend(['try', ':'])
        children.extend(node.body)
        children.extend(['finally', ':'])
        children.extend(node.finalbody)
        self._handle(node, children)

    def _TryExcept(self, node):
        children = ['try', ':']
        children.extend(node.body)
        children.extend(node.handlers)
        if node.orelse:
            children.extend(['else', ':'])
            children.extend(node.orelse)
        self._handle(node, children)

    def _ExceptHandler(self, node):
        self._excepthandler(node)

    def _excepthandler(self, node):
        children = ['except']
        if node.type:
            children.append(node.type)
        if node.name:
            children.extend([',', node.name])
        children.append(':')
        children.extend(node.body)
        self._handle(node, children)

    def _Tuple(self, node):
        if node.elts:
            self._handle(node, self._child_nodes(node.elts, ','),
                         eat_parens=True)
        else:
            self._handle(node, ['(', ')'])

    def _UnaryOp(self, node):
        children = self._get_op(node.op)
        children.append(node.operand)
        self._handle(node, children)

    def _Yield(self, node):
        children = ['yield']
        if node.value:
            children.append(node.value)
        self._handle(node, children)

    def _While(self, node):
        children = ['while', node.test, ':']
        children.extend(node.body)
        if node.orelse:
            children.extend(['else', ':'])
            children.extend(node.orelse)
        self._handle(node, children)

    def _With(self, node):
        children = ['with', node.context_expr]
        if node.optional_vars:
            children.extend(['as', node.optional_vars])
        children.append(':')
        children.extend(node.body)
        self._handle(node, children)

    def _child_nodes(self, nodes, separator):
        """Return `nodes` as a new list with `separator` between them."""
        children = []
        for index, child in enumerate(nodes):
            children.append(child)
            if index < len(nodes) - 1:
                children.append(separator)
        return children


class _Source(object):
    """Source text plus a forward-moving cursor used to consume tokens."""

    def __init__(self, source):
        self.source = source
        self.offset = 0

    def consume(self, token):
        """Move past the next occurrence of `token` outside comments.

        Returns the ``(start, end)`` offsets of the consumed token.

        Raises `MismatchedTokenError` when no such occurrence exists.
        """
        try:
            while True:
                new_offset = self.source.index(token, self.offset)
                if self._good_token(token, new_offset):
                    break
                else:
                    self._skip_comment()
        except (ValueError, TypeError):
            raise MismatchedTokenError(
                'Token <%s> at %s cannot be matched' %
                (token, self._get_location()))
        self.offset = new_offset + len(token)
        return (new_offset, self.offset)

    def consume_string(self, end=None):
        """Consume a string literal, searching no further than `end`.

        Adjacent literals separated by whitespace, backslash-newlines or
        comment lines are consumed together as one region.
        """
        if _Source._string_pattern is None:
            original = codeanalyze.get_string_pattern()
            pattern = r'(%s)((\s|\\\n|#[^\n]*\n)*(%s))*' % \
                      (original, original)
            _Source._string_pattern = re.compile(pattern)
        repattern = _Source._string_pattern
        return self._consume_pattern(repattern, end)

    def consume_number(self):
        """Consume a numeric literal (see `_get_number_pattern`)."""
        if _Source._number_pattern is None:
            _Source._number_pattern = re.compile(
                self._get_number_pattern())
        repattern = _Source._number_pattern
        return self._consume_pattern(repattern)

    def consume_not_equal(self):
        """Consume an inequality operator spelled ``!=`` or ``<>``."""
        if _Source._not_equals_pattern is None:
            _Source._not_equals_pattern = re.compile(r'<>|!=')
        repattern = _Source._not_equals_pattern
        return self._consume_pattern(repattern)

    def _good_token(self, token, offset, start=None):
        """Return True unless `offset` lies inside a ``#`` comment.

        Only text between `start` (default: the cursor) and `offset` is
        examined: the token is rejected when the last ``#`` there is not
        followed by a newline.  A ``#`` inside a string literal is treated
        as a comment too.
        """
        if start is None:
            start = self.offset
        try:
            comment_index = self.source.rindex('#', start, offset)
        except ValueError:
            return True
        try:
            new_line_index = self.source.rindex('\n', start, offset)
        except ValueError:
            return False
        return comment_index < new_line_index

    def _skip_comment(self):
        """Move the cursor to the next newline after the current position.

        Raises `ValueError` if no newline follows.
        """
        self.offset = self.source.index('\n', self.offset + 1)

    def _get_location(self):
        """Return the cursor as ``(line, column)``: 1-based line, 0-based column."""
        lines = self.source[:self.offset].split('\n')
        return (len(lines), len(lines[-1]))

    def _consume_pattern(self, repattern, end=None):
        """Consume the next match of `repattern` that is not in a comment.

        Searches from the cursor up to `end` (default: end of source) and
        returns the ``(start, end)`` offsets of the match.

        Raises `MismatchedTokenError` if there is no acceptable match.
        """
        if end is None:
            end = len(self.source)
        while True:
            match = repattern.search(self.source, self.offset, end)
            if match is None:
                # Previously this crashed with AttributeError on None.group().
                raise MismatchedTokenError(
                    'Pattern <%s> at %s cannot be matched' %
                    (repattern.pattern, self._get_location()))
            if self._good_token(match.group(), match.start()):
                break
            try:
                self._skip_comment()
            except ValueError:
                # The comment runs to the end of the source, so nothing after
                # it can match; report it like `consume` does.
                raise MismatchedTokenError(
                    'Pattern <%s> at %s cannot be matched' %
                    (repattern.pattern, self._get_location()))
        self.offset = match.end()
        return match.start(), match.end()

    def till_token(self, token):
        """Return the text from the cursor up to the next `token`."""
        new_offset = self.source.index(token, self.offset)
        return self[self.offset:new_offset]

    def rfind_token(self, token, start, end):
        """Return the offset of the last `token` in ``[start, end)``.

        Occurrences inside comments are skipped.  Returns None if there is
        no such occurrence.
        """
        while True:
            try:
                index = self.source.rindex(token, start, end)
                if self._good_token(token, index, start=start):
                    return index
                else:
                    end = index
            except ValueError:
                return None

    def from_offset(self, offset):
        """Return the text between `offset` and the cursor."""
        return self[offset:self.offset]

    def find_backwards(self, pattern, offset):
        """Return the offset of the last `pattern` before `offset`."""
        return self.source.rindex(pattern, 0, offset)

    def __getitem__(self, index):
        return self.source[index]

    def __getslice__(self, i, j):
        # Python 2 slicing protocol; Python 3 slicing uses __getitem__.
        return self.source[i:j]

    def _get_number_pattern(self):
        # HACK: It is merely an approximation and does the job
        integer = r'(0|0x)?[\da-fA-F]+[lL]?'
        return r'(%s(\.\d*)?|(\.\d+))([eE][-+]?\d*)?[jJ]?' % integer

    # Compiled lazily on first use and shared by all instances.
    _string_pattern = None
    _number_pattern = None
    _not_equals_pattern = None