# Copyright 2016 Grist Labs, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import ast import collections import io import sys import token import tokenize from abc import ABCMeta from ast import Module, expr, AST from functools import lru_cache from typing import ( Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union, cast, Any, TYPE_CHECKING, Type, ) if TYPE_CHECKING: # pragma: no cover from .astroid_compat import NodeNG # Type class used to expand out the definition of AST to include fields added by this library # It's not actually used for anything other than type checking though! class EnhancedAST(AST): # Additional attributes set by mark_tokens first_token = None # type: Token last_token = None # type: Token lineno = 0 # type: int AstNode = Union[EnhancedAST, NodeNG] TokenInfo = tokenize.TokenInfo def token_repr(tok_type, string): # type: (int, Optional[str]) -> str """Returns a human-friendly representation of a token with the given type and string.""" # repr() prefixes unicode with 'u' on Python2 but not Python3; strip it out for consistency. return '%s:%s' % (token.tok_name[tok_type], repr(string).lstrip('u')) class Token(collections.namedtuple('Token', 'type string start end line index startpos endpos')): """ TokenInfo is an 8-tuple containing the same 5 fields as the tokens produced by the tokenize module, and 3 additional ones useful for this module: - [0] .type Token type (see token.py) - [1] .string Token (a string) - [2] .start Starting (row, column) indices of the token (a 2-tuple of ints) - [3] .end Ending (row, column) indices of the token (a 2-tuple of ints) - [4] .line Original line (string) - [5] .index Index of the token in the list of tokens that it belongs to. - [6] .startpos Starting character offset into the input text. - [7] .endpos Ending character offset into the input text. """ def __str__(self): # type: () -> str return token_repr(self.type, self.string) def match_token(token, tok_type, tok_str=None): # type: (Token, int, Optional[str]) -> bool """Returns true if token is of the given type and, if a string is given, has that string.""" return token.type == tok_type and (tok_str is None or token.string == tok_str) def expect_token(token, tok_type, tok_str=None): # type: (Token, int, Optional[str]) -> None """ Verifies that the given token is of the expected type. If tok_str is given, the token string is verified too. If the token doesn't match, raises an informative ValueError. """ if not match_token(token, tok_type, tok_str): raise ValueError("Expected token %s, got %s on line %s col %s" % ( token_repr(tok_type, tok_str), str(token), token.start[0], token.start[1] + 1)) def is_non_coding_token(token_type): # type: (int) -> bool """ These are considered non-coding tokens, as they don't affect the syntax tree. """ return token_type in (token.NL, token.COMMENT, token.ENCODING) def generate_tokens(text): # type: (str) -> Iterator[TokenInfo] """ Generates standard library tokens for the given code. """ # tokenize.generate_tokens is technically an undocumented API for Python3, but allows us to use the same API as for # Python2. See http://stackoverflow.com/a/4952291/328565. # FIXME: Remove cast once https://github.com/python/typeshed/issues/7003 gets fixed return tokenize.generate_tokens(cast(Callable[[], str], io.StringIO(text).readline)) def iter_children_func(node): # type: (AST) -> Callable """ Returns a function which yields all direct children of a AST node, skipping children that are singleton nodes. The function depends on whether ``node`` is from ``ast`` or from the ``astroid`` module. """ return iter_children_astroid if hasattr(node, 'get_children') else iter_children_ast def iter_children_astroid(node, include_joined_str=False): # type: (NodeNG, bool) -> Union[Iterator, List] if not include_joined_str and is_joined_str(node): return [] return node.get_children() SINGLETONS = {c for n, c in ast.__dict__.items() if isinstance(c, type) and issubclass(c, (ast.expr_context, ast.boolop, ast.operator, ast.unaryop, ast.cmpop))} def iter_children_ast(node, include_joined_str=False): # type: (AST, bool) -> Iterator[Union[AST, expr]] if not include_joined_str and is_joined_str(node): return if isinstance(node, ast.Dict): # override the iteration order: instead of <all keys>, <all values>, # yield keys and values in source order (key1, value1, key2, value2, ...) for (key, value) in zip(node.keys, node.values): if key is not None: yield key yield value return for child in ast.iter_child_nodes(node): # Skip singleton children; they don't reflect particular positions in the code and break the # assumptions about the tree consisting of distinct nodes. Note that collecting classes # beforehand and checking them in a set is faster than using isinstance each time. if child.__class__ not in SINGLETONS: yield child stmt_class_names = {n for n, c in ast.__dict__.items() if isinstance(c, type) and issubclass(c, ast.stmt)} expr_class_names = ({n for n, c in ast.__dict__.items() if isinstance(c, type) and issubclass(c, ast.expr)} | {'AssignName', 'DelName', 'Const', 'AssignAttr', 'DelAttr'}) # These feel hacky compared to isinstance() but allow us to work with both ast and astroid nodes # in the same way, and without even importing astroid. def is_expr(node): # type: (AstNode) -> bool """Returns whether node is an expression node.""" return node.__class__.__name__ in expr_class_names def is_stmt(node): # type: (AstNode) -> bool """Returns whether node is a statement node.""" return node.__class__.__name__ in stmt_class_names def is_module(node): # type: (AstNode) -> bool """Returns whether node is a module node.""" return node.__class__.__name__ == 'Module' def is_joined_str(node): # type: (AstNode) -> bool """Returns whether node is a JoinedStr node, used to represent f-strings.""" # At the moment, nodes below JoinedStr have wrong line/col info, and trying to process them only # leads to errors. return node.__class__.__name__ == 'JoinedStr' def is_expr_stmt(node): # type: (AstNode) -> bool """Returns whether node is an `Expr` node, which is a statement that is an expression.""" return node.__class__.__name__ == 'Expr' CONSTANT_CLASSES: Tuple[Type, ...] = (ast.Constant,) try: from astroid import Const CONSTANT_CLASSES += (Const,) except ImportError: # pragma: no cover # astroid is not available pass def is_constant(node): # type: (AstNode) -> bool """Returns whether node is a Constant node.""" return isinstance(node, CONSTANT_CLASSES) def is_ellipsis(node): # type: (AstNode) -> bool """Returns whether node is an Ellipsis node.""" return is_constant(node) and node.value is Ellipsis # type: ignore def is_starred(node): # type: (AstNode) -> bool """Returns whether node is a starred expression node.""" return node.__class__.__name__ == 'Starred' def is_slice(node): # type: (AstNode) -> bool """Returns whether node represents a slice, e.g. `1:2` in `x[1:2]`""" # Before 3.9, a tuple containing a slice is an ExtSlice, # but this was removed in https://bugs.python.org/issue34822 return ( node.__class__.__name__ in ('Slice', 'ExtSlice') or ( node.__class__.__name__ == 'Tuple' and any(map(is_slice, cast(ast.Tuple, node).elts)) ) ) def is_empty_astroid_slice(node): # type: (AstNode) -> bool return ( node.__class__.__name__ == "Slice" and not isinstance(node, ast.AST) and node.lower is node.upper is node.step is None ) # Sentinel value used by visit_tree(). _PREVISIT = object() def visit_tree(node, previsit, postvisit): # type: (Module, Callable[[AstNode, Optional[Token]], Tuple[Optional[Token], Optional[Token]]], Optional[Callable[[AstNode, Optional[Token], Optional[Token]], None]]) -> None """ Scans the tree under the node depth-first using an explicit stack. It avoids implicit recursion via the function call stack to avoid hitting 'maximum recursion depth exceeded' error. It calls ``previsit()`` and ``postvisit()`` as follows: * ``previsit(node, par_value)`` - should return ``(par_value, value)`` ``par_value`` is as returned from ``previsit()`` of the parent. * ``postvisit(node, par_value, value)`` - should return ``value`` ``par_value`` is as returned from ``previsit()`` of the parent, and ``value`` is as returned from ``previsit()`` of this node itself. The return ``value`` is ignored except the one for the root node, which is returned from the overall ``visit_tree()`` call. For the initial node, ``par_value`` is None. ``postvisit`` may be None. """ if not postvisit: postvisit = lambda node, pvalue, value: None iter_children = iter_children_func(node) done = set() ret = None stack = [(node, None, _PREVISIT)] # type: List[Tuple[AstNode, Optional[Token], Union[Optional[Token], object]]] while stack: current, par_value, value = stack.pop() if value is _PREVISIT: assert current not in done # protect againt infinite loop in case of a bad tree. done.add(current) pvalue, post_value = previsit(current, par_value) stack.append((current, par_value, post_value)) # Insert all children in reverse order (so that first child ends up on top of the stack). ins = len(stack) for n in iter_children(current): stack.insert(ins, (n, pvalue, _PREVISIT)) else: ret = postvisit(current, par_value, cast(Optional[Token], value)) return ret def walk(node, include_joined_str=False): # type: (AST, bool) -> Iterator[Union[Module, AstNode]] """ Recursively yield all descendant nodes in the tree starting at ``node`` (including ``node`` itself), using depth-first pre-order traversal (yieling parents before their children). This is similar to ``ast.walk()``, but with a different order, and it works for both ``ast`` and ``astroid`` trees. Also, as ``iter_children()``, it skips singleton nodes generated by ``ast``. By default, ``JoinedStr`` (f-string) nodes and their contents are skipped because they previously couldn't be handled. Set ``include_joined_str`` to True to include them. """ iter_children = iter_children_func(node) done = set() stack = [node] while stack: current = stack.pop() assert current not in done # protect againt infinite loop in case of a bad tree. done.add(current) yield current # Insert all children in reverse order (so that first child ends up on top of the stack). # This is faster than building a list and reversing it. ins = len(stack) for c in iter_children(current, include_joined_str): stack.insert(ins, c) def replace(text, replacements): # type: (str, List[Tuple[int, int, str]]) -> str """ Replaces multiple slices of text with new values. This is a convenience method for making code modifications of ranges e.g. as identified by ``ASTTokens.get_text_range(node)``. Replacements is an iterable of ``(start, end, new_text)`` tuples. For example, ``replace("this is a test", [(0, 4, "X"), (8, 9, "THE")])`` produces ``"X is THE test"``. """ p = 0 parts = [] for (start, end, new_text) in sorted(replacements): parts.append(text[p:start]) parts.append(new_text) p = end parts.append(text[p:]) return ''.join(parts) class NodeMethods: """ Helper to get `visit_{node_type}` methods given a node's class and cache the results. """ def __init__(self): # type: () -> None self._cache = {} # type: Dict[Union[ABCMeta, type], Callable[[AstNode, Token, Token], Tuple[Token, Token]]] def get(self, obj, cls): # type: (Any, Union[ABCMeta, type]) -> Callable """ Using the lowercase name of the class as node_type, returns `obj.visit_{node_type}`, or `obj.visit_default` if the type-specific method is not found. """ method = self._cache.get(cls) if not method: name = "visit_" + cls.__name__.lower() method = getattr(obj, name, obj.visit_default) self._cache[cls] = method return method def patched_generate_tokens(original_tokens): # type: (Iterable[TokenInfo]) -> Iterator[TokenInfo] """ Fixes tokens yielded by `tokenize.generate_tokens` to handle more non-ASCII characters in identifiers. Workaround for https://github.com/python/cpython/issues/68382. Should only be used when tokenizing a string that is known to be valid syntax, because it assumes that error tokens are not actually errors. Combines groups of consecutive NAME, NUMBER, and/or ERRORTOKEN tokens into a single NAME token. """ group = [] # type: List[tokenize.TokenInfo] for tok in original_tokens: if ( tok.type in (tokenize.NAME, tokenize.ERRORTOKEN, tokenize.NUMBER) # Only combine tokens if they have no whitespace in between and (not group or group[-1].end == tok.start) ): group.append(tok) else: for combined_token in combine_tokens(group): yield combined_token group = [] yield tok for combined_token in combine_tokens(group): yield combined_token def combine_tokens(group): # type: (List[tokenize.TokenInfo]) -> List[tokenize.TokenInfo] if not any(tok.type == tokenize.ERRORTOKEN for tok in group) or len({tok.line for tok in group}) != 1: return group return [ tokenize.TokenInfo( type=tokenize.NAME, string="".join(t.string for t in group), start=group[0].start, end=group[-1].end, line=group[0].line, ) ] def last_stmt(node): # type: (ast.AST) -> ast.AST """ If the given AST node contains multiple statements, return the last one. Otherwise, just return the node. """ child_stmts = [ child for child in iter_children_func(node)(node) if is_stmt(child) or type(child).__name__ in ( "excepthandler", "ExceptHandler", "match_case", "MatchCase", "TryExcept", "TryFinally", ) ] if child_stmts: return last_stmt(child_stmts[-1]) return node @lru_cache(maxsize=None) def fstring_positions_work(): # type: () -> bool """ The positions attached to nodes inside f-string FormattedValues have some bugs that were fixed in Python 3.9.7 in https://github.com/python/cpython/pull/27729. This checks for those bugs more concretely without relying on the Python version. Specifically this checks: - Values with a format spec or conversion - Repeated (i.e. identical-looking) expressions - f-strings implicitly concatenated over multiple lines. - Multiline, triple-quoted f-strings. """ source = """( f"a {b}{b} c {d!r} e {f:g} h {i:{j}} k {l:{m:n}}" f"a {b}{b} c {d!r} e {f:g} h {i:{j}} k {l:{m:n}}" f"{x + y + z} {x} {y} {z} {z} {z!a} {z:z}" f''' {s} {t} {u} {v} ''' )""" tree = ast.parse(source) name_nodes = [node for node in ast.walk(tree) if isinstance(node, ast.Name)] name_positions = [(node.lineno, node.col_offset) for node in name_nodes] positions_are_unique = len(set(name_positions)) == len(name_positions) correct_source_segments = all( ast.get_source_segment(source, node) == node.id for node in name_nodes ) return positions_are_unique and correct_source_segments def annotate_fstring_nodes(tree): # type: (ast.AST) -> None """ Add a special attribute `_broken_positions` to nodes inside f-strings if the lineno/col_offset cannot be trusted. """ if sys.version_info >= (3, 12): # f-strings were weirdly implemented until https://peps.python.org/pep-0701/ # In Python 3.12, inner nodes have sensible positions. return for joinedstr in walk(tree, include_joined_str=True): if not isinstance(joinedstr, ast.JoinedStr): continue for part in joinedstr.values: # The ast positions of the FormattedValues/Constant nodes span the full f-string, which is weird. setattr(part, '_broken_positions', True) # use setattr for mypy if isinstance(part, ast.FormattedValue): if not fstring_positions_work(): for child in walk(part.value): setattr(child, '_broken_positions', True) if part.format_spec: # this is another JoinedStr # Again, the standard positions span the full f-string. setattr(part.format_spec, '_broken_positions', True)