import re

from rope.base import ast, codeanalyze
from rope.base.change import ChangeSet, ChangeContents
from rope.base.exceptions import RefactoringError
from rope.refactor import (sourceutils, similarfinder,
                           patchedast, suites, usefunction)


# Extract refactoring has lots of special cases.  I tried to split it
# into smaller parts to make it more manageable:
#
# _ExtractInfo: holds information about the refactoring; it is passed
# to the parts that need to have information about the refactoring
#
# _ExtractCollector: merely saves all of the information necessary for
# performing the refactoring.
#
# _DefinitionLocationFinder: finds where to insert the definition.
#
# _ExceptionalConditionChecker: checks for exceptional conditions in
# which the refactoring cannot be applied.
#
# _ExtractMethodParts: generates the pieces of code (like definition)
# needed for performing extract method.
#
# _ExtractVariableParts: like _ExtractMethodParts for variables.
#
# _ExtractPerformer: Uses the above classes to collect refactoring
# changes.
#
# There are a few more helper functions and classes used by the above
# classes.
class _ExtractRefactoring(object):
    """Common base of the extract refactorings.

    Subclasses provide a class attribute `kind` (``'method'`` or
    ``'variable'``) that `get_changes()` reads.
    """

    def __init__(self, project, resource, start_offset, end_offset,
                 variable=False):
        """Store the target and trim whitespace off the selected region.

        `variable` is accepted for interface compatibility; it is not
        read here (the kind of extraction comes from `self.kind`).
        """
        self.project = project
        self.pycore = project.pycore
        self.resource = resource
        # Read the resource once and use the same text for both ends.
        source = resource.read()
        self.start_offset = self._fix_start(source, start_offset)
        self.end_offset = self._fix_end(source, end_offset)

    def _fix_start(self, source, offset):
        """Move `offset` forward past any whitespace characters."""
        while offset < len(source) and source[offset].isspace():
            offset += 1
        return offset

    def _fix_end(self, source, offset):
        """Move `offset` backward past any whitespace preceding it."""
        while offset > 0 and source[offset - 1].isspace():
            offset -= 1
        return offset

    def get_changes(self, extracted_name, similar=False, global_=False):
        """Get the changes this refactoring makes

        :parameters:
            - `extracted_name`: the name of the new method/variable.
            - `similar`: if `True`, similar expressions/statements are also
              replaced.
            - `global_`: if `True`, the extracted method/variable will
              be global.

        Returns a `ChangeSet` holding one `ChangeContents` for the
        resource.  `RefactoringError` is raised (while the performer is
        constructed) when the selected region cannot be extracted.

        """
        info = _ExtractInfo(
            self.project, self.resource, self.start_offset,
            self.end_offset, extracted_name,
            variable=self.kind == 'variable',
            similar=similar, make_global=global_)
        new_contents = _ExtractPerformer(info).extract()
        changes = ChangeSet('Extract %s <%s>' % (self.kind,
                                                 extracted_name))
        changes.add_change(ChangeContents(self.resource, new_contents))
        return changes


class ExtractMethod(_ExtractRefactoring):
    """Extract the selected region into a new function or method."""

    def __init__(self, *args, **kwds):
        super(ExtractMethod, self).__init__(*args, **kwds)

    kind = 'method'


class ExtractVariable(_ExtractRefactoring):
    """Extract the selected expression into a new variable."""

    def __init__(self, *args, **kwds):
        # Copy so the caller's keyword dict is not modified.
        kwds = dict(kwds)
        kwds['variable'] = True
        super(ExtractVariable, self).__init__(*args, **kwds)

    kind = 'variable'


class _ExtractInfo(object):
    """Holds information about the extract to be performed"""

    def __init__(self, project, resource, start, end, new_name,
                 variable, similar, make_global):
        self.pycore = project.pycore
        self.resource = resource
        self.pymodule = self.pycore.resource_to_pyobject(resource)
        self.global_scope = self.pymodule.get_scope()
        self.source = self.pymodule.source_code
        self.lines = self.pymodule.lines
        self.new_name = new_name
        self.variable = variable
        self.similar = similar
        # _init_scope() relies on the region lines set by _init_parts().
        self._init_parts(start, end)
        self._init_scope()
        self.make_global = make_global

    def _init_parts(self, start, end):
        """Compute `region`, `region_lines` and `lines_region`.

        - `region`: (start, end) offsets after snapping to line
          boundaries with `_choose_closest_line_end()`.
        - `region_lines`: first line of the logical line holding the
          region start and last line of the logical line holding the
          region end.
        - `lines_region`: offsets of the start of the first and the end
          of the last of those lines.
        """
        self.region = (self._choose_closest_line_end(start),
                       self._choose_closest_line_end(end, end=True))

        first_line = self.logical_lines.logical_line_in(
            self.lines.get_line_number(self.region[0]))[0]
        last_line = self.logical_lines.logical_line_in(
            self.lines.get_line_number(self.region[1]))[1]
        self.region_lines = (first_line, last_line)

        self.lines_region = (
            self.lines.get_line_start(self.region_lines[0]),
            self.lines.get_line_end(self.region_lines[1]))

    @property
    def logical_lines(self):
        return self.pymodule.logical_lines

    def _init_scope(self):
        """Find the scope holding the region and that scope's offsets."""
        start_line = self.region_lines[0]
        scope = self.global_scope.get_inner_scope_for_line(start_line)
        # When the region starts on the scope's own first line, use the
        # enclosing scope instead.
        if scope.get_kind() != 'Module' and scope.get_start() == start_line:
            scope = scope.parent
        self.scope = scope
        self.scope_region = self._get_scope_region(self.scope)

    def _get_scope_region(self, scope):
        """Return offsets of the scope's first line start and one past
        its last line end."""
        return (self.lines.get_line_start(scope.get_start()),
                self.lines.get_line_end(scope.get_end()) + 1)

    def _choose_closest_line_end(self, offset, end=False):
        """Snap `offset` to a line boundary if only whitespace separates
        them.

        If only whitespace lies between the line start and `offset`, the
        line start is returned (or, for an end offset, the position just
        before it).  Otherwise, if only whitespace lies between `offset`
        and the line end, the line end (capped at the source length) is
        returned.  In all other cases `offset` is returned unchanged.
        """
        lineno = self.lines.get_line_number(offset)
        line_start = self.lines.get_line_start(lineno)
        line_end = self.lines.get_line_end(lineno)
        if self.source[line_start:offset].strip() == '':
            if end:
                return line_start - 1
            return line_start
        if self.source[offset:line_end].strip() == '':
            return min(line_end, len(self.source))
        return offset

    @property
    def one_line(self):
        """True if the region is part of a single logical line.

        That is, the region does not cover whole lines and both of its
        boundary lines belong to the same logical line.
        """
        return self.region != self.lines_region and \
            (self.logical_lines.logical_line_in(self.region_lines[0]) ==
             self.logical_lines.logical_line_in(self.region_lines[1]))

    @property
    def global_(self):
        """True if the region's scope has no parent scope."""
        return self.scope.parent is None

    @property
    def method(self):
        """True if the region's scope is directly inside a class."""
        return self.scope.parent is not None and \
            self.scope.parent.get_kind() == 'Class'

    @property
    def indents(self):
        return sourceutils.get_indents(self.pymodule.lines,
                                       self.region_lines[0])

    @property
    def scope_indents(self):
        if self.global_:
            return 0
        return sourceutils.get_indents(self.pymodule.lines,
                                       self.scope.get_start())

    @property
    def extracted(self):
        """The source text of the region."""
        return self.source[self.region[0]:self.region[1]]

    # Cache for `returned`; None means "not computed yet".
    _returned = None

    @property
    def returned(self):
        """Does the extracted piece contain return statement"""
        if self._returned is None:
            node = _parse_text(self.extracted)
            self._returned = usefunction._returns_last(node)
        return self._returned


class _ExtractCollector(object):
    """Collects information needed for performing the extract"""

    def __init__(self, info):
        # `info` is accepted but not used; every field is filled in
        # later by _ExtractPerformer.
        self.definition = None
        self.body_pattern = None
        self.checks = {}
        self.replacement_pattern = None
        self.matches = None
        self.replacements = None
        self.definition_location = None


class _ExtractPerformer(object):
    """Uses the helper classes of this module to build the new source."""

    def __init__(self, info):
        """Keep `info` and validate it.

        Raises `RefactoringError` if `_ExceptionalConditionChecker`
        rejects the selected region.
        """
        self.info = info
        _ExceptionalConditionChecker()(self.info)

    def extract(self):
        """Return the source with the definition inserted and every
        match replaced."""
        extract_info = self._collect_info()
        content = codeanalyze.ChangeCollector(self.info.source)
        definition = extract_info.definition
        lineno, indents = extract_info.definition_location
        offset = self.info.lines.get_line_start(lineno)
        indented = sourceutils.fix_indentation(definition, indents)
        content.add_change(offset, offset, indented)
        self._replace_occurrences(content, extract_info)
        return content.get_changed()

    def _replace_occurrences(self, content, extract_info):
        """Replace each match with the instantiated replacement pattern."""
        for match in extract_info.matches:
            replacement = similarfinder.CodeTemplate(
                extract_info.replacement_pattern)
            mapping = {}
            for name in replacement.get_names():
                node = match.get_ast(name)
                if node:
                    # Use the source text the name was matched against.
                    start, end = patchedast.node_region(node)
                    mapping[name] = self.info.source[start:end]
                else:
                    mapping[name] = name
            region = match.get_region()
            content.add_change(region[0], region[1],
                               replacement.substitute(mapping))

    def _collect_info(self):
        """Fill and return an `_ExtractCollector`.

        The definition (and its patterns) must be found first: the
        matches depend on the patterns and the definition location
        depends on the matches.
        """
        extract_collector = _ExtractCollector(self.info)
        self._find_definition(extract_collector)
        self._find_matches(extract_collector)
        self._find_definition_location(extract_collector)
        return extract_collector

    def _find_matches(self, collector):
        """Store in `collector.matches` the matches of the body pattern
        in every searched region."""
        finder = similarfinder.SimilarFinder(self.info.pymodule)
        matches = []
        for start, end in self._where_to_search():
            matches.extend(finder.get_matches(collector.body_pattern,
                                              collector.checks, start, end))
        collector.matches = matches

    def _where_to_search(self):
        """Return the list of (start, end) offsets to search for matches."""
        info = self.info
        if not info.similar:
            return [info.region]
        if info.make_global or info.global_:
            return [(0, len(info.pymodule.source_code))]
        if info.method and not info.variable:
            class_scope = info.scope.parent
            regions = []
            method_kind = _get_function_kind(info.scope)
            for scope in class_scope.get_scopes():
                # When extracting from a plain method, skip the scopes
                # of the class that are not plain methods.
                if method_kind == 'method' and \
                   _get_function_kind(scope) != 'method':
                    continue
                start = info.lines.get_line_start(scope.get_start())
                end = info.lines.get_line_end(scope.get_end())
                regions.append((start, end))
            return regions
        if info.variable:
            return [info.scope_region]
        return [info._get_scope_region(info.scope.parent)]

    def _find_definition_location(self, collector):
        """Store the (lineno, indents) where the definition is inserted."""
        matched_lines = []
        for match in collector.matches:
            start = self.info.lines.get_line_number(match.get_region()[0])
            start_line = self.info.logical_lines.logical_line_in(start)[0]
            matched_lines.append(start_line)
        location_finder = _DefinitionLocationFinder(self.info, matched_lines)
        collector.definition_location = (location_finder.find_lineno(),
                                         location_finder.find_indents())

    def _find_definition(self, collector):
        """Store the definition, patterns and checks in `collector`."""
        if self.info.variable:
            parts = _ExtractVariableParts(self.info)
        else:
            parts = _ExtractMethodParts(self.info)
        collector.definition = parts.get_definition()
        collector.body_pattern = parts.get_body_pattern()
        collector.replacement_pattern = parts.get_replacement_pattern()
        collector.checks = parts.get_checks()


class _DefinitionLocationFinder(object):
    """Finds the line and indentation at which to insert the definition."""

    def __init__(self, info, matched_lines):
        self.info = info
        # Work on a copy so the caller's list is never modified.
        self.matched_lines = list(matched_lines)
        # This only happens when subexpressions cannot be matched
        if not self.matched_lines:
            self.matched_lines.append(self.info.region_lines[0])

    def find_lineno(self):
        """Return the line number at which the definition is inserted."""
        if self.info.variable and not self.info.make_global:
            return self._get_before_line()
        if self.info.make_global or self.info.global_:
            toplevel = self._find_toplevel(self.info.scope)
            module_ast = self.info.pymodule.get_ast()
            newlines = sorted(self.matched_lines + [toplevel.get_end() + 1])
            return suites.find_visible(module_ast, newlines)
        return self._get_after_scope()

    def _find_toplevel(self, scope):
        """Return the ancestor of `scope` whose parent has no parent.

        `scope` itself is returned when it has no parent.
        """
        toplevel = scope
        if toplevel.parent is not None:
            while toplevel.parent.parent is not None:
                toplevel = toplevel.parent
        return toplevel

    def find_indents(self):
        """Return the indentation to use for the inserted definition."""
        if self.info.variable and not self.info.make_global:
            return sourceutils.get_indents(self.info.lines,
                                           self._get_before_line())
        if self.info.global_ or self.info.make_global:
            return 0
        return self.info.scope_indents

    def _get_before_line(self):
        scope_ast = self.info.scope.pyobject.get_ast()
        return suites.find_visible(scope_ast, self.matched_lines)

    def _get_after_scope(self):
        return self.info.scope.get_end() + 1


class _ExceptionalConditionChecker(object):
    """Raises `RefactoringError` for regions that cannot be extracted."""

    def __call__(self, info):
        self.base_conditions(info)
        if info.one_line:
            self.one_line_conditions(info)
        else:
            self.multi_line_conditions(info)

    def base_conditions(self, info):
        """Checks that apply to both one-line and multi-line regions."""
        if info.region[1] > info.scope_region[1]:
            raise RefactoringError('Bad region selected for extract method')
        end_line = info.region_lines[1]
        end_scope = info.global_scope.get_inner_scope_for_line(end_line)
        if end_scope != info.scope and end_scope.get_end() != end_line:
            raise RefactoringError('Bad region selected for extract method')
        try:
            extracted = info.source[info.region[0]:info.region[1]]
            if info.one_line:
                # Parenthesize so that an expression spanning several
                # physical lines still parses.
                extracted = '(%s)' % extracted
            if _UnmatchedBreakOrContinueFinder.has_errors(extracted):
                raise RefactoringError('A break/continue without having a '
                                       'matching for/while loop.')
        except SyntaxError:
            raise RefactoringError('Extracted piece should '
                                   'contain complete statements.')

    def one_line_conditions(self, info):
        if self._is_region_on_a_word(info):
            raise RefactoringError('Should extract complete statements.')
        # NOTE(review): __call__ only invokes this method when
        # info.one_line is true, so this check looks unreachable from
        # there; kept as is.
        if info.variable and not info.one_line:
            raise RefactoringError('Extract variable should not '
                                   'span multiple lines.')

    def multi_line_conditions(self, info):
        node = _parse_text(info.source[info.region[0]:info.region[1]])
        count = usefunction._return_count(node)
        if count > 1:
            raise RefactoringError('Extracted piece can have only one '
                                   'return statement.')
        if usefunction._yield_count(node):
            raise RefactoringError('Extracted piece cannot '
                                   'have yield statements.')
        if count == 1 and not usefunction._returns_last(node):
            raise RefactoringError('Return should be the last statement.')
        if info.region != info.lines_region:
            raise RefactoringError('Extracted piece should '
                                   'contain complete statements.')

    def _is_region_on_a_word(self, info):
        """Return True if either end of the region splits a word."""
        start, end = info.region[0], info.region[1]
        if start > 0 and self._is_on_a_word(info, start - 1):
            return True
        return self._is_on_a_word(info, end - 1)

    def _is_on_a_word(self, info, offset):
        """Return True if the characters at `offset` and `offset + 1`
        are both word characters (alphanumeric or ``_``)."""
        prev = info.source[offset]
        if not (prev.isalnum() or prev == '_') or \
           offset + 1 == len(info.source):
            return False
        following = info.source[offset + 1]
        return following.isalnum() or following == '_'


class _ExtractMethodParts(object):
    """Generates the pieces of code needed for extract method."""

    def __init__(self, info):
        self.info = info
        self.info_collector = self._create_info_collector()

    def get_definition(self):
        """Return the text of the new function, preceded by a blank line
        (and followed by one more when extracting at module level)."""
        if self.info.global_:
            return '\n%s\n' % self._get_function_definition()
        return '\n%s' % self._get_function_definition()

    def get_replacement_pattern(self):
        """Return the pattern of the call that replaces each match."""
        variables = []
        variables.extend(self._find_function_arguments())
        variables.extend(self._find_function_returns())
        return similarfinder.make_pattern(self._get_call(), variables)

    def get_body_pattern(self):
        """Return the pattern used for finding pieces to replace."""
        variables = []
        variables.extend(self._find_function_arguments())
        variables.extend(self._find_function_returns())
        variables.extend(self._find_temps())
        return similarfinder.make_pattern(self._get_body(), variables)

    def _get_body(self):
        result = sourceutils.fix_indentation(self.info.extracted, 0)
        if self.info.one_line:
            result = '(%s)' % result
        return result

    def _find_temps(self):
        return usefunction.find_temps(self.info.pycore.project,
                                      self._get_body())

    def get_checks(self):
        """Return the checks passed to the similar finder.

        When extracting from a plain method (and not making the result
        global), the self name is constrained to the enclosing class.
        """
        if self.info.method and not self.info.make_global:
            if _get_function_kind(self.info.scope) == 'method':
                class_name = similarfinder._pydefined_to_str(
                    self.info.scope.parent.pyobject)
                return {self._get_self_name(): 'type=' + class_name}
        return {}

    def _create_info_collector(self):
        """Walk the enclosing scope and collect variable reads/writes."""
        # Only the scope's own source is parsed, so region lines are
        # shifted to be relative to the first line of the scope.
        zero = self.info.scope.get_start() - 1
        start_line = self.info.region_lines[0] - zero
        end_line = self.info.region_lines[1] - zero
        info_collector = _FunctionInformationCollector(start_line, end_line,
                                                       self.info.global_)
        body = self.info.source[self.info.scope_region[0]:
                                self.info.scope_region[1]]
        node = _parse_text(body)
        ast.walk(node, info_collector)
        return info_collector

    def _get_function_definition(self):
        args = self._find_function_arguments()
        returns = self._find_function_returns()
        result = []
        # Extracting from a non-plain method inside a class (for
        # instance a static method) yields a static method.
        if self.info.method and not self.info.make_global and \
           _get_function_kind(self.info.scope) != 'method':
            result.append('@staticmethod\n')
        result.append('def %s:\n' % self._get_function_signature(args))
        unindented_body = self._get_unindented_function_body(returns)
        indents = sourceutils.get_indent(self.info.pycore)
        function_body = sourceutils.indent_lines(unindented_body, indents)
        result.append(function_body)
        definition = ''.join(result)

        return definition + '\n'

    def _get_function_signature(self, args):
        """Return ``name(arg, ...)``; the self name comes first when a
        method is extracted.

        Raises `RefactoringError` if a method is extracted from a
        function that has no parameters.
        """
        args = list(args)
        prefix = ''
        if self._extracting_method():
            self_name = self._get_self_name()
            if self_name is None:
                raise RefactoringError('Extracting a method from a function '
                                       'with no self argument.')
            if self_name in args:
                args.remove(self_name)
            args.insert(0, self_name)
        return prefix + self.info.new_name + \
            '(%s)' % self._get_comma_form(args)

    def _extracting_method(self):
        return self.info.method and not self.info.make_global and \
            _get_function_kind(self.info.scope) == 'method'

    def _get_self_name(self):
        """Return the first parameter name of the scope, or None."""
        param_names = self.info.scope.pyobject.get_param_names()
        if param_names:
            return param_names[0]
        return None

    def _get_function_call(self, args):
        """Return the text of the call to the extracted function."""
        # Copy: the self name may be removed below and the caller's
        # sequence must stay untouched.
        args = list(args)
        prefix = ''
        if self.info.method and not self.info.make_global:
            if _get_function_kind(self.info.scope) == 'method':
                self_name = self._get_self_name()
                if self_name in args:
                    args.remove(self_name)
                prefix = self_name + '.'
            else:
                prefix = self.info.scope.parent.pyobject.get_name() + '.'
        return prefix + '%s(%s)' % (self.info.new_name,
                                    self._get_comma_form(args))

    def _get_comma_form(self, names):
        """Join `names` with ``', '``."""
        return ', '.join(names)

    def _get_call(self):
        """Return the statement/expression that replaces the region."""
        args = self._find_function_arguments()
        if self.info.one_line:
            return self._get_function_call(args)
        returns = self._find_function_returns()
        call_prefix = ''
        if returns:
            call_prefix = self._get_comma_form(returns) + ' = '
        if self.info.returned:
            call_prefix = 'return '
        return call_prefix + self._get_function_call(args)

    def _find_function_arguments(self):
        """Return the names that must be passed to the new function."""
        # if not make_global, do not pass any global names; they are
        # all visible.
        if self.info.global_ and not self.info.make_global:
            return ()
        collector = self.info_collector
        if not self.info.one_line:
            result = collector.prewritten & collector.read
            # Names that are only conditionally written in the region
            # but are used afterwards need their previous value too.
            result |= (collector.prewritten & collector.postread &
                       (collector.maybe_written - collector.written))
            return list(result)
        start = self.info.region[0]
        if start == self.info.lines_region[0]:
            # Skip the leading whitespace of the line.
            start = start + re.search(r'\S', self.info.extracted).start()
        function_definition = self.info.source[start:self.info.region[1]]
        read = _VariableReadsAndWritesFinder.find_reads_for_one_liners(
            function_definition)
        return list(collector.prewritten.intersection(read))

    def _find_function_returns(self):
        """Return the names the new function has to return."""
        if self.info.one_line or self.info.returned:
            return []
        written = self.info_collector.written | \
            self.info_collector.maybe_written
        return list(written & self.info_collector.postread)

    def _get_unindented_function_body(self, returns):
        if self.info.one_line:
            return 'return ' + _join_lines(self.info.extracted)
        extracted_body = self.info.extracted
        unindented_body = sourceutils.fix_indentation(extracted_body, 0)
        if returns:
            unindented_body += '\nreturn %s' % self._get_comma_form(returns)
        return unindented_body


class _ExtractVariableParts(object):
    """Like `_ExtractMethodParts` but for extract variable."""

    def __init__(self, info):
        self.info = info

    def get_definition(self):
        """Return ``name = <extracted expression joined on one line>``."""
        result = self.info.new_name + ' = ' + \
            _join_lines(self.info.extracted) + '\n'
        return result

    def get_body_pattern(self):
        return '(%s)' % self.info.extracted.strip()

    def get_replacement_pattern(self):
        return self.info.new_name

    def get_checks(self):
        return {}


class _FunctionInformationCollector(object):
    """AST walker recording variable usage relative to a line range.

    With `start`/`end` the first and last line of the region:

    - `prewritten`: names written on lines before `start`
    - `written`: names written inside the region outside any
      if/while/for
    - `maybe_written`: names written inside the region within an
      if/while/for
    - `read`: names read inside the region that are not yet in
      `written`
    - `postwritten`: names written on lines after `end`
    - `postread`: names read after `end` that are not yet in
      `postwritten`
    """

    def __init__(self, start, end, is_global):
        self.start = start
        self.end = end
        self.is_global = is_global
        self.prewritten = set()
        self.maybe_written = set()
        self.written = set()
        self.read = set()
        self.postread = set()
        self.postwritten = set()
        # True until the first function definition has been entered.
        self.host_function = True
        # True while walking the children of an if/while/for node.
        self.conditional = False

    def _read_variable(self, name, lineno):
        if self.start <= lineno <= self.end:
            if name not in self.written:
                self.read.add(name)
        if self.end < lineno:
            if name not in self.postwritten:
                self.postread.add(name)

    def _written_variable(self, name, lineno):
        if self.start <= lineno <= self.end:
            if self.conditional:
                self.maybe_written.add(name)
            else:
                self.written.add(name)
        if self.start > lineno:
            self.prewritten.add(name)
        if self.end < lineno:
            self.postwritten.add(name)

    def _FunctionDef(self, node):
        if not self.is_global and self.host_function:
            # The function holding the region: its arguments count as
            # writes and its body is walked by this collector.
            self.host_function = False
            for name in _get_argnames(node.args):
                self._written_variable(name, node.lineno)
            for child in node.body:
                ast.walk(child, self)
        else:
            # Any other function: its name is written, and the names it
            # reads without writing count as reads on its `def` line.
            self._written_variable(node.name, node.lineno)
            visitor = _VariableReadsAndWritesFinder()
            for child in node.body:
                ast.walk(child, visitor)
            for name in visitor.read - visitor.written:
                self._read_variable(name, node.lineno)

    def _Name(self, node):
        if isinstance(node.ctx, (ast.Store, ast.AugStore)):
            self._written_variable(node.id, node.lineno)
        if not isinstance(node.ctx, ast.Store):
            self._read_variable(node.id, node.lineno)

    def _Assign(self, node):
        # Walk the value first so its reads are recorded before the
        # targets are marked as written.
        ast.walk(node.value, self)
        for child in node.targets:
            ast.walk(child, self)

    def _ClassDef(self, node):
        self._written_variable(node.name, node.lineno)

    def _handle_conditional_node(self, node):
        """Walk the children of `node` with `conditional` set."""
        # Restore the previous value rather than False, so that code
        # following a nested conditional inside an outer one is still
        # treated as conditional.
        previous = self.conditional
        self.conditional = True
        try:
            for child in ast.get_child_nodes(node):
                ast.walk(child, self)
        finally:
            self.conditional = previous

    def _If(self, node):
        self._handle_conditional_node(node)

    def _While(self, node):
        self._handle_conditional_node(node)

    def _For(self, node):
        self._handle_conditional_node(node)


def _get_argnames(arguments):
    """Return the argument names of an ``arguments`` AST node.

    Only positional arguments that are `ast.Name` nodes are included,
    followed by the vararg and kwarg entries when present.
    """
    result = [node.id for node in arguments.args
              if isinstance(node, ast.Name)]
    if arguments.vararg:
        result.append(arguments.vararg)
    if arguments.kwarg:
        result.append(arguments.kwarg)
    return result


class _VariableReadsAndWritesFinder(object):
    """AST walker collecting the names read and written in a piece of
    code."""

    def __init__(self):
        self.written = set()
        self.read = set()

    def _Name(self, node):
        if isinstance(node.ctx, (ast.Store, ast.AugStore)):
            self.written.add(node.id)
        # The expression context decides whether the name is read; the
        # Name node itself is never an `ast.Store`.
        if not isinstance(node.ctx, ast.Store):
            self.read.add(node.id)

    def _FunctionDef(self, node):
        self.written.add(node.name)
        visitor = _VariableReadsAndWritesFinder()
        for child in ast.get_child_nodes(node):
            ast.walk(child, visitor)
        # Names the nested function reads without writing them are
        # reads of this piece of code as well.
        self.read.update(visitor.read - visitor.written)

    # NOTE(review): the other handlers are named after AST node classes
    # (`_FunctionDef`, `_Name`); this one presumably was meant to be
    # `_ClassDef` and may never be dispatched -- confirm before
    # renaming, since a `_ClassDef` handler would also stop class
    # bodies from being walked.
    def _Class(self, node):
        self.written.add(node.name)

    @staticmethod
    def find_reads_and_writes(code):
        """Return the pair of sets (read names, written names)."""
        if code.strip() == '':
            return set(), set()
        if isinstance(code, unicode):
            code = code.encode('utf-8')
        node = _parse_text(code)
        visitor = _VariableReadsAndWritesFinder()
        ast.walk(node, visitor)
        return visitor.read, visitor.written

    @staticmethod
    def find_reads_for_one_liners(code):
        """Return the set of names read in `code`."""
        if code.strip() == '':
            # A single set, like the non-empty case; callers intersect
            # the result with other sets.
            return set()
        node = _parse_text(code)
        visitor = _VariableReadsAndWritesFinder()
        ast.walk(node, visitor)
        return visitor.read


class _UnmatchedBreakOrContinueFinder(object):
    """AST walker detecting break/continue outside a for/while body.

    Nested function and class definitions are not inspected.
    """

    def __init__(self):
        self.error = False
        self.loop_count = 0

    def _For(self, node):
        self.loop_encountered(node)

    def _While(self, node):
        self.loop_encountered(node)

    def loop_encountered(self, node):
        self.loop_count += 1
        for child in node.body:
            ast.walk(child, self)
        self.loop_count -= 1
        # The `else` clause is walked outside the loop count: a
        # break/continue there does not belong to this loop.
        if node.orelse:
            if isinstance(node.orelse, (list, tuple)):
                # `orelse` holds a sequence of statements; walk each.
                for child in node.orelse:
                    ast.walk(child, self)
            else:
                ast.walk(node.orelse, self)

    def _Break(self, node):
        self.check_loop()

    def _Continue(self, node):
        self.check_loop()

    def check_loop(self):
        if self.loop_count < 1:
            self.error = True

    def _FunctionDef(self, node):
        pass

    def _ClassDef(self, node):
        pass

    @staticmethod
    def has_errors(code):
        """Return True if `code` has a break/continue outside a loop.

        Blank code has no errors.  `SyntaxError` propagates when `code`
        cannot be parsed.
        """
        if code.strip() == '':
            return False
        node = _parse_text(code)
        visitor = _UnmatchedBreakOrContinueFinder()
        ast.walk(node, visitor)
        return visitor.error


def _get_function_kind(scope):
    """Return the kind reported by the scope's pyobject."""
    return scope.pyobject.get_kind()


def _parse_text(body):
    """Parse `body` after re-indenting it to indentation level 0."""
    body = sourceutils.fix_indentation(body, 0)
    node = ast.parse(body)
    return node


def _join_lines(code):
    """Join the lines of `code` with single spaces.

    Each line is stripped; a trailing backslash (line continuation) is
    removed before stripping.
    """
    lines = []
    for line in code.splitlines():
        if line.endswith('\\'):
            lines.append(line[:-1].strip())
        else:
            lines.append(line.strip())
    return ' '.join(lines)