"""Module for parsing input file and walking ast."""
import ast
import re
import sys
from collections.abc import Iterator
from typing import Optional, Union, get_args, overload
from typing_extensions import TypeGuard
from .const import DEFAULT_EXCEPTION
from .types import (
BodyTypes,
ClassDocstring,
DocstringInfo,
ElementDocstring,
FixerSettings,
FunctionBody,
FunctionDocstring,
FunctionSignature,
ModuleDocstring,
NodeOfInterest,
Parameter,
ReturnValue,
)
__author__ = "J-E. Nitschke"
__copyright__ = "Copyright 2023-2024"
__licence__ = "MIT"
__version__ = "1.1.0"
__maintainer__ = "J-E. Nitschke"
@overload
def ast_unparse(node: None) -> None: ...


@overload
def ast_unparse(node: ast.AST) -> str: ...


def ast_unparse(node: Optional[ast.AST]) -> Optional[str]:
    """Turn an AST node back into source code, passing `None` through.

    Parameters
    ----------
    node : Optional[ast.AST]
        Node to convert back to source text. May be `None`.

    Returns
    -------
    Optional[str]
        The source representation produced by `ast.unparse`, or `None`
        when no node was given.
    """
    return ast.unparse(node) if node is not None else None
[docs]
class FunctionNodeVisitor:  # pylint: disable=too-few-public-methods
    """Visit all subnodes of the function."""

    def __init__(
        self, start_node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> None:
        """Visit all subnodes of the function.

        Collect returns, yields and raises.
        Discard returns and yields from nested functions.

        Parameters
        ----------
        start_node : Union[ast.FunctionDef, ast.AsyncFunctionDef]
            Node to start traversal from.
        """
        self.name = start_node.name
        self.returns: set[tuple[str, ...]] = set()
        self.returns_value = False
        self.yields: set[tuple[str, ...]] = set()
        self.yields_value = False
        self.raises: list[str] = []
        # Remember the root by identity so that a nested function that merely
        # shares the outer function's name is still recognised as nested.
        self._start_node = start_node
        # Depth counter: 0 while directly inside the start function,
        # > 0 while inside any function definition nested in it.
        self._inside_nested_function = 0
        self._visit(start_node)

    def _visit(self, node: ast.AST) -> None:
        """Dispatch to `_visit_<NodeClassName>` or fall back to the generic visit."""
        method = "_visit_" + node.__class__.__name__
        visitor = getattr(self, method, self._generic_visit)
        visitor(node)

    def _generic_visit(self, node: ast.AST) -> None:
        """Call if no explicit visitor function exists for a node."""
        for _, value in ast.iter_fields(node):
            if isinstance(value, list):
                for item in value:  # pyright: ignore[reportUnknownVariableType]
                    if isinstance(item, ast.AST):
                        self._visit(item)
            elif isinstance(value, ast.AST):
                self._visit(value)

    def _track_nested_function(
        self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> None:
        """Visit a function definition while keeping track of nesting depth.

        Every function definition other than the start node counts as nested.
        The depth is raised while its children are visited and restored
        afterwards.

        Parameters
        ----------
        node : Union[ast.FunctionDef, ast.AsyncFunctionDef]
            Current node in the traversal.
        """
        nested_function = node is not self._start_node
        self._inside_nested_function += int(nested_function)
        self._generic_visit(node)
        self._inside_nested_function -= int(nested_function)

    def _visit_FunctionDef(self, node: ast.FunctionDef) -> None:  # noqa: N802 # pylint: disable=invalid-name
        """Keep track of nested function depth.

        Parameters
        ----------
        node : ast.FunctionDef
            Current node in the traversal.
        """
        self._track_nested_function(node)

    def _visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:  # noqa: N802 # pylint: disable=invalid-name
        """Keep track of nested function depth.

        Parameters
        ----------
        node : ast.AsyncFunctionDef
            Current node in the traversal.
        """
        self._track_nested_function(node)

    def _visit_Return(self, node: ast.Return) -> None:  # noqa: N802 # pylint: disable=invalid-name
        """Do not process returns from nested functions.

        A bare `return` does not count as returning a value. Tuples that
        consist only of plain names are additionally recorded by name.

        Parameters
        ----------
        node : ast.Return
            Current node in the traversal.
        """
        if not self._inside_nested_function and node.value is not None:
            self.returns_value = True
            if isinstance(node.value, ast.Tuple) and all(
                isinstance(value, ast.Name) for value in node.value.elts
            ):
                self.returns.add(AstAnalyzer.get_ids_from_returns(node.value.elts))
        # Keep descending: the returned expression may itself contain nodes
        # of interest.
        self._generic_visit(node)

    def _visit_Yield(self, node: ast.Yield) -> None:  # noqa: N802 # pylint: disable=invalid-name
        """Do not process yields from nested functions.

        Parameters
        ----------
        node : ast.Yield
            Current node in the traversal.
        """
        if not self._inside_nested_function:
            self.yields_value = True
            if isinstance(node.value, ast.Tuple) and all(
                isinstance(value, ast.Name) for value in node.value.elts
            ):
                self.yields.add(AstAnalyzer.get_ids_from_returns(node.value.elts))
        self._generic_visit(node)

    def _visit_YieldFrom(self, node: ast.YieldFrom) -> None:  # noqa: N802 # pylint: disable=invalid-name
        """Do not process yields from nested functions.

        Parameters
        ----------
        node : ast.YieldFrom
            Current node in the traversal.
        """
        if not self._inside_nested_function:
            self.yields_value = True
        self._generic_visit(node)

    def _visit_Raise(self, node: ast.Raise) -> None:  # noqa: N802 # pylint: disable=invalid-name
        """Do process raises from nested functions.

        The exception name is recorded when it is a PascalCase name that is
        either raised directly or called. Everything else, including a bare
        `raise`, is recorded as `DEFAULT_EXCEPTION`.

        Parameters
        ----------
        node : ast.Raise
            Current node in the traversal.
        """
        pascal_case_regex = r"^(?:[A-Z][a-z]+)+$"
        if not node.exc:
            self.raises.append(DEFAULT_EXCEPTION)
        elif isinstance(node.exc, ast.Name) and re.match(
            pascal_case_regex, node.exc.id
        ):
            self.raises.append(node.exc.id)
        elif (
            isinstance(node.exc, ast.Call)
            and isinstance(node.exc.func, ast.Name)
            and re.match(pascal_case_regex, node.exc.func.id)
        ):
            self.raises.append(node.exc.func.id)
        else:
            self.raises.append(DEFAULT_EXCEPTION)
        self._generic_visit(node)
[docs]
class AstAnalyzer:
    """Walk ast and extract module, class and function information."""

    def __init__(self, file_content: str, *, settings: FixerSettings) -> None:
        """Initialize the Analyzer with the file contents.

        The only reason this is a class is to have the raw
        file_contents available at any point of the analysis to double check
        something. Currently used for the module docstring and docstring
        modifiers.

        Parameters
        ----------
        file_content : str
            File contents to store.
        settings : FixerSettings
            Settings for what to fix and when.
        """
        self.file_content = file_content
        self.settings = settings

    @staticmethod
    def func_decorators(
        node: Union[ast.FunctionDef, ast.AsyncFunctionDef],
    ) -> Iterator[str]:
        """Get the names of the decorators of a function node.

        Only decorators written as a plain name (`@name`) are reported.
        Attribute access (`@a.b`) and calls (`@name(...)`) are skipped.

        Parameters
        ----------
        node : Union[ast.FunctionDef, ast.AsyncFunctionDef]
            Node representing a function definition.

        Yields
        ------
        str
            Name of each plain-name decorator, in source order.
        """
        for name in node.decorator_list:
            if isinstance(name, ast.Name):
                yield name.id

    def parse_from_ast(
        self,
    ) -> list[ElementDocstring]:
        """Walk AST of the input file extract info about module, classes and functions.

        For the module and classes, the raw docstring
        and its line numbers are extracted.

        For functions the raw docstring and its line numbers are extracted.
        Additionally the signature is parsed for parameters and return value.

        Returns
        -------
        list[ElementDocstring]
            List of information about module, classes and functions.

        Raises
        ------
        AssertionError
            If the source file content could not be parsed into an ast.
        """
        nodes_of_interest: list[ElementDocstring] = []
        try:
            file_ast = ast.parse(self.file_content)
        except Exception as exc:
            msg = f"Failed to parse source file AST: {exc}\n"
            raise AssertionError(msg) from exc
        for node in ast.walk(file_ast):
            if isinstance(node, ast.Module):
                nodes_of_interest.append(self.handle_module(node))
            elif isinstance(node, ast.ClassDef):
                if node.name in self.settings.ignored_classes:
                    continue
                nodes_of_interest.append(self.handle_class(node))
            elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                if (
                    any(
                        name in self.settings.ignored_decorators
                        for name in self.func_decorators(node)
                    )
                    or node.name in self.settings.ignored_functions
                ):
                    continue
                nodes_of_interest.append(self.handle_function(node))
        return nodes_of_interest

    def handle_module(self, module: ast.Module) -> ModuleDocstring:
        """Extract information about module.

        Parameters
        ----------
        module : ast.Module
            Node representing the full module.

        Returns
        -------
        ModuleDocstring
            Docstring representation for the module.
        """
        docstring_info = self.get_docstring_info(module)
        if docstring_info is None:
            # No docstring yet: point at the line where one should be placed.
            docstring_line = self._get_docstring_line()
            return ModuleDocstring(
                "Module",
                docstring="",
                lines=(docstring_line, docstring_line),
                modifier="",
                issues=[],
                had_docstring=False,
            )
        return ModuleDocstring(
            name=docstring_info.name,
            docstring=docstring_info.docstring,
            lines=docstring_info.lines,
            modifier=docstring_info.modifier,
            issues=docstring_info.issues,
            had_docstring=docstring_info.had_docstring,
        )

    def handle_class(self, cls: ast.ClassDef) -> ClassDocstring:
        """Extract information about class docstring.

        Parameters
        ----------
        cls : ast.ClassDef
            Node representing a class definition.

        Returns
        -------
        ClassDocstring
            Docstring representation for a class.
        """
        docstring = self.handle_elem_docstring(cls)
        attributes, methods = self.handle_class_body(cls)
        return ClassDocstring(
            name=docstring.name,
            docstring=docstring.docstring,
            lines=docstring.lines,
            modifier=docstring.modifier,
            issues=docstring.issues,
            attributes=attributes,
            methods=methods,
            had_docstring=docstring.had_docstring,
        )

    def handle_function(
        self,
        func: Union[ast.FunctionDef, ast.AsyncFunctionDef],
    ) -> FunctionDocstring:
        """Extract information from signature and docstring.

        Parameters
        ----------
        func : Union[ast.FunctionDef, ast.AsyncFunctionDef]
            Node representing a function definition.

        Returns
        -------
        FunctionDocstring
            Docstring representation of a function.
        """
        docstring = self.handle_elem_docstring(func)
        signature = self.handle_function_signature(func)
        body = self.handle_function_body(func)
        # Minus one because the function counts the passed node itself
        # Which is correct for each nested node but not the main one.
        length = self._get_block_length(func) - 1
        return FunctionDocstring(
            name=docstring.name,
            docstring=docstring.docstring,
            lines=docstring.lines,
            modifier=docstring.modifier,
            issues=docstring.issues,
            signature=signature,
            body=body,
            length=length,
            had_docstring=docstring.had_docstring,
        )

    def handle_elem_docstring(
        self,
        elem: Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef],
    ) -> DocstringInfo:
        """Extract information about the docstring of the function.

        Parameters
        ----------
        elem : Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef]
            Element representing a function or class definition.

        Returns
        -------
        DocstringInfo
            Return general information about the docstring of the element.

        Raises
        ------
        ValueError
            If the element did not have a body at all. This should not happen
            for valid functions or classes.
        ValueError
            If the indent of the function body is not
            one level deeper than the definition.
        """
        docstring_info = self.get_docstring_info(elem)
        if docstring_info is not None:
            return docstring_info

        if not elem.body:
            msg = "Function body was unexpectedly completely empty."
            raise ValueError(msg)
        body_elem = elem.body[0]
        # Ideally we would use one line after the end of the actual function
        # definition. But this does not exist. So we need to use the body.
        # However that can start at the same line as the function definition.
        # In that case we cant place the docstring between definition and body.
        # The col offsets are unlikely to match so try to detect this with a
        # good error message.
        if body_elem.col_offset != (elem.col_offset + self.settings.indent):
            msg = (
                "Function body did not start one indentation level"
                " deeper than the function body. Can not properly place docstring."
            )
            raise ValueError(msg)
        lineno = body_elem.lineno
        if (
            isinstance(
                body_elem, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
            )
            and body_elem.decorator_list
        ):
            # `lineno` of a decorated definition is the `def`/`class` line.
            # The docstring has to go above the first decorator instead.
            # Use that decorator's own line number rather than assuming that
            # every decorator occupies exactly one line, so multi-line
            # decorators and comments between decorators are handled too.
            lineno = min(lineno, body_elem.decorator_list[0].lineno)
        return DocstringInfo(
            name=elem.name,
            docstring="",
            lines=(lineno, lineno),
            modifier="",
            issues=[],
            had_docstring=False,
        )

    def get_docstring_info(self, node: NodeOfInterest) -> Optional[DocstringInfo]:
        """Get docstring and line number if available.

        Parameters
        ----------
        node : NodeOfInterest
            Get general information about the docstring of any node
            of interest.

        Returns
        -------
        Optional[DocstringInfo]
            Information about the docstring if the element contains one.
            Or `None` if there was no docstring at all.

        Raises
        ------
        ValueError
            If the first element of the body is not a docstring after
            `ast.get_docstring()` returned one.
        """
        # An empty docstring is falsy here and treated like a missing one.
        if not ast.get_docstring(node):
            return None
        if not (
            node.body
            and isinstance(first_element := node.body[0], ast.Expr)
            and isinstance(docnode := first_element.value, ast.Constant)
            and isinstance(docnode.value, str)
        ):
            msg = (
                "Expected first entry in body to be the "
                "docstring, but found nothing or something else."
            )
            raise ValueError(msg)
        # The AST does not retain string prefixes, so read the raw source line.
        modifier = self._get_modifier(
            self.file_content.splitlines()[docnode.lineno - 1]
        )
        return DocstringInfo(
            # Can not use DefinitionNodes in isinstance checks before 3.10
            name=(
                node.name
                if isinstance(
                    node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
                )
                else "Module"
            ),
            docstring=str(docnode.value),
            lines=(docnode.lineno, docnode.end_lineno),
            modifier=modifier,
            issues=[],
            had_docstring=True,
        )

    def _get_modifier(self, line: str) -> str:
        """Get the string modifier from the start of a docstring.

        Only a single `r`/`u` prefix (either case) directly in front of a
        triple quoted string is recognised.

        Parameters
        ----------
        line : str
            Line to check

        Returns
        -------
        str
            Modifier of the string, or an empty string if there is none.
        """
        line = line.strip()
        delimiters = ['"""', "'''"]
        modifiers = ["r", "u"]
        if not line:
            return ""
        if line[:3] in delimiters:
            return ""
        if line[0].lower() in modifiers and line[1:4] in delimiters:
            return line[0]
        return ""

    def _get_docstring_line(self) -> int:
        """Get the line where the module docstring should start.

        Returns
        -------
        int
            Starting line (starts at 1) of the docstring.
        """
        # Only the first two lines are checked for shebang/encoding.
        shebang_encoding_lines = 2
        lines_of_interest = self.file_content.splitlines()[:shebang_encoding_lines]
        if not lines_of_interest:
            return 1
        for index, line in enumerate(lines_of_interest):
            if not self.is_shebang_or_pragma(line):
                # List indices start at 0 but file lines are counted from 1
                return index + 1
        return shebang_encoding_lines + 1

    def _has_body(self, node: ast.AST) -> TypeGuard[BodyTypes]:
        """Check that the node is one of those that have a body."""
        return isinstance(
            node,
            (get_args(BodyTypes)),
        ) and hasattr(node, "body")

    def _get_block_length(self, node: ast.AST) -> int:
        """Get the number of statements in a block.

        Recursively count the number of statements in a blocks body.

        Parameters
        ----------
        node : ast.AST
            Node representing to count the number of statements for.

        Returns
        -------
        int
            Total number of (nested) statements in the block.
        """
        # pylint: disable=no-member
        if sys.version_info >= (3, 11):
            try_nodes = (ast.Try, ast.TryStar)
        else:
            try_nodes = (ast.Try,)
        # The node itself counts as one statement.
        length = 1
        if self._has_body(node) and node.body:
            length += sum(self._get_block_length(child) for child in node.body)
        # Decorators add complexity, so lets count them for now
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            length += len(node.decorator_list)
        elif isinstance(node, (ast.For, ast.AsyncFor, ast.While, ast.If, *try_nodes)):
            length += sum(self._get_block_length(child) for child in node.orelse)
            if isinstance(node, try_nodes):
                length += sum(self._get_block_length(child) for child in node.finalbody)
                length += sum(self._get_block_length(child) for child in node.handlers)
        elif sys.version_info >= (3, 10) and isinstance(node, ast.Match):
            # Each case counts itself + its body.
            # This is intended for now as compared to if/else there is a lot
            # of logic actually still happening in the case matching.
            length += sum(self._get_block_length(child) for child in node.cases)
        # We do not want to count the docstring
        if isinstance(
            node,
            (ast.AsyncFunctionDef, ast.FunctionDef, ast.ClassDef, ast.Module),
        ) and ast.get_docstring(node):
            length -= 1
        return length

    def handle_class_body(self, cls: ast.ClassDef) -> tuple[list[Parameter], list[str]]:
        """Extract attributes and methods from class body.

        Will walk the AST of the ClassDef node and add each function encountered
        as a method.

        If the `__init__` method is encountered walk its body for attribute
        definitions.

        Parameters
        ----------
        cls : ast.ClassDef
            Node representing a class definition.

        Returns
        -------
        attributes : list[Parameter]
            List of the parameters that make up the classes attributes.
        methods : list[str]
            List of the method names in the class.
        """
        attributes: list[Parameter] = []
        methods: list[str] = []
        for node in cls.body:
            if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            # Extract attributes from init method.
            if node.name == "__init__":
                attributes.extend(self._get_attributes_from_init(node))
            # Skip dunder methods for method extraction
            if node.name.startswith("__") and node.name.endswith("__"):
                continue
            # Optionally skip private methods.
            if self.settings.ignore_privates and node.name.startswith("_"):
                continue
            # Handle properties as attributes
            if "property" in self.func_decorators(node):
                return_value = self.get_return_value_sig(node)
                attributes.append(Parameter(node.name, return_value.type_name, None))
            # Handle normal methods except for those with some specific decorators
            # Like staticmethod, classmethod, property or getters/setters.
            elif not self._has_excluding_decorator(node):
                methods.append(self._get_method_signature(node))
        # Remove duplicates from attributes while maintaining order
        return list(Parameter.uniquefy(attributes)), methods

    def handle_function_signature(
        self,
        func: Union[ast.FunctionDef, ast.AsyncFunctionDef],
    ) -> FunctionSignature:
        """Extract information about the signature of the function.

        A leading `self` parameter is dropped, as is a leading `cls`
        parameter of a function decorated with `classmethod`.

        Parameters
        ----------
        func : Union[ast.FunctionDef, ast.AsyncFunctionDef]
            Node representing a function definition

        Returns
        -------
        FunctionSignature
            Information extracted from the function signature
        """
        parameters = self.get_parameters_sig(func)
        if parameters and (
            parameters[0].arg_name == "self"
            or (
                parameters[0].arg_name == "cls"
                and "classmethod" in self.func_decorators(func)
            )
        ):
            parameters.pop(0)
        return_value = self.get_return_value_sig(func)
        return FunctionSignature(parameters, return_value)

    def handle_function_body(
        self,
        func: Union[ast.FunctionDef, ast.AsyncFunctionDef],
    ) -> FunctionBody:
        """Check the function body for yields, raises and value returns.

        Parameters
        ----------
        func : Union[ast.FunctionDef, ast.AsyncFunctionDef]
            Node representing a function definition

        Returns
        -------
        FunctionBody
            Information extracted from the function body.
        """
        visitor = FunctionNodeVisitor(func)
        return FunctionBody(
            returns_value=visitor.returns_value,
            returns=visitor.returns,
            yields_value=visitor.yields_value,
            yields=visitor.yields,
            raises=visitor.raises,
        )

    def get_return_value_sig(
        self, func: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> ReturnValue:
        """Get information about return value from signature.

        Parameters
        ----------
        func : Union[ast.FunctionDef, ast.AsyncFunctionDef]
            Node representing a function definition

        Returns
        -------
        ReturnValue
            Return information extracted from the function signature.
        """
        return ReturnValue(type_name=ast_unparse(func.returns))

    def get_parameters_sig(
        self, func: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> list[Parameter]:
        """Get information about function parameters.

        Parameters are listed in signature order: positional-only, general,
        `*args`, keyword-only, `**kwargs`.

        Parameters
        ----------
        func : Union[ast.FunctionDef, ast.AsyncFunctionDef]
            Node representing a function definition

        Returns
        -------
        list[Parameter]
            Parameter information from the function signature.
        """
        # `args.defaults` is shared between positional-only and general
        # arguments, so both groups are matched against one padded list.
        positional_args = [*func.args.posonlyargs, *func.args.args]
        arguments: list[Parameter] = [
            Parameter(arg.arg, ast_unparse(arg.annotation), default)
            for arg, default in zip(positional_args, self._positional_defaults(func))
        ]
        if vararg := func.args.vararg:
            arguments.append(
                Parameter(f"*{vararg.arg}", ast_unparse(vararg.annotation), None)
            )
        # `kw_defaults` has one entry per keyword-only argument, `None` where
        # the argument has no default.
        arguments += [
            Parameter(
                arg.arg,
                ast_unparse(arg.annotation),
                ast_unparse(default),
            )
            for arg, default in zip(func.args.kwonlyargs, func.args.kw_defaults)
        ]
        if kwarg := func.args.kwarg:
            arguments.append(
                Parameter(f"**{kwarg.arg}", ast_unparse(kwarg.annotation), None)
            )
        # Filter out unused arguments.
        if self.settings.ignore_unused_arguments:
            return [
                argument
                for argument in arguments
                if not argument.arg_name.startswith("_")
            ]
        return arguments

    @staticmethod
    def is_shebang_or_pragma(line: str) -> bool:
        """Check if a given line contains encoding or shebang.

        Parameters
        ----------
        line : str
            Line to check

        Returns
        -------
        bool
            Whether the given line contains encoding or shebang
        """
        shebang_regex = r"^#!(.*)"
        if re.search(shebang_regex, line) is not None:
            return True
        pragma_regex = r"^#.*coding[=:]\s*([-\w.]+)"
        return re.search(pragma_regex, line) is not None

    @staticmethod
    def _positional_defaults(
        func: Union[ast.FunctionDef, ast.AsyncFunctionDef],
    ) -> list[Optional[str]]:
        """Left-pad the defaults of all positionally passable arguments.

        Parameters
        ----------
        func : Union[ast.FunctionDef, ast.AsyncFunctionDef]
            Node representing a function definition

        Returns
        -------
        list[Optional[str]]
            One entry per positional-only and general argument (in that
            order). `None` for arguments without a default.
        """
        # Entries of `args.defaults` are always nodes, never `None`.
        defaults: list[Optional[str]] = [
            ast.unparse(default) for default in func.args.defaults
        ]
        total = len(func.args.posonlyargs) + len(func.args.args)
        # Defaults belong to the LAST n positional arguments.
        return [None] * (total - len(defaults)) + defaults

    def get_padded_args_defaults(
        self,
        func: Union[ast.FunctionDef, ast.AsyncFunctionDef],
    ) -> list[Optional[str]]:
        """Left-Pad the general args defaults to the length of the args.

        Defaults that belong to positional-only arguments are not part of
        the result. The returned list always has exactly one entry per
        general argument.

        Parameters
        ----------
        func : Union[ast.FunctionDef, ast.AsyncFunctionDef]
            Node representing a function definition

        Returns
        -------
        list[Optional[str]]
            Left padded (with `None`) list of function arguments.
        """
        return AstAnalyzer._positional_defaults(func)[len(func.args.posonlyargs) :]

    def _has_excluding_decorator(
        self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> bool:
        """Exclude function with some decorators.

        Currently excluded:
            staticmethod
            classmethod
            property (and related)

        Parameters
        ----------
        node : Union[ast.FunctionDef, ast.AsyncFunctionDef]
            Node representing a function definition

        Returns
        -------
        bool
            Whether the function has any decorators that exclude it from
            being recognized as a standard method.
        """
        decorators = node.decorator_list
        excluded_decorators = {"staticmethod", "classmethod", "property"}
        for decorator in decorators:
            if isinstance(decorator, ast.Name) and decorator.id in excluded_decorators:
                return True
            # Handle property related decorators like in
            # @x.setter
            # def x(self, value):
            #     self._x = value  # noqa: ERA001
            # @x.deleter
            # def x(self):
            #     del self._x
            if (
                isinstance(decorator, ast.Attribute)
                and isinstance(decorator.value, ast.Name)
                and decorator.value.id == node.name
            ):
                return True
        return False

    def _check_if_node_is_self_attributes(
        self, node: ast.expr
    ) -> TypeGuard[ast.Attribute]:
        """Check whether the node represents a public attribute of self (self.abc).

        Private attributes (leading underscore) are only rejected when the
        `ignore_privates` setting is enabled.

        Parameters
        ----------
        node : ast.expr
            Node representing the expression to be checked.

        Returns
        -------
        TypeGuard[ast.Attribute]
            True if the node represents a public attribute of self.
        """
        return (
            isinstance(node, ast.Attribute)
            and isinstance(node.value, ast.Name)
            and node.value.id == "self"
            and not (self.settings.ignore_privates and node.attr.startswith("_"))
        )

    def _check_and_handle_assign_node(
        self, target: ast.expr, attributes: list[Parameter]
    ) -> None:
        """Check if the assignment node contains assignments to self.X.

        Add it to the list of attributes if that is the case.

        Parameters
        ----------
        target : ast.expr
            Node representing an assignment
        attributes : list[Parameter]
            List of attributes the node attribute should be added to.
        """
        # Unpacking targets like `self.a, self.b = ...` are checked per element.
        if isinstance(target, (ast.Tuple, ast.List)):
            for node in target.elts:
                if self._check_if_node_is_self_attributes(node):
                    attributes.append(Parameter(node.attr, "_type_", None))
        elif self._check_if_node_is_self_attributes(target):
            attributes.append(Parameter(target.attr, "_type_", None))

    def _get_attributes_from_init(
        self, init: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> list[Parameter]:
        """Iterate over body and grab every assignment `self.abc = XYZ`.

        Only top-level statements of the body are inspected.

        Parameters
        ----------
        init : Union[ast.FunctionDef, ast.AsyncFunctionDef]
            Init function node to extract attributes from.

        Returns
        -------
        list[Parameter]
            List of attributes extracted from the init function.
        """
        attributes: list[Parameter] = []
        for node in init.body:
            if isinstance(node, ast.Assign):
                # Targets is a list in case of multiple assignment
                # a = b = 3  # noqa: ERA001
                for target in node.targets:
                    self._check_and_handle_assign_node(target, attributes)
            # Also handle annotated assignments
            # c: int = "Test"  # noqa: ERA001
            elif isinstance(node, ast.AnnAssign):
                self._check_and_handle_assign_node(node.target, attributes)
        return attributes

    def _get_method_signature(
        self, func: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> str:
        """Remove self from signature and return the unparsed string.

        The passed node is left untouched. The same AST is analysed again
        later, so the signature is rendered from a separate arguments node.

        Parameters
        ----------
        func : Union[ast.FunctionDef, ast.AsyncFunctionDef]
            Node representing a function definition.

        Returns
        -------
        str
            String of the method signature with `self` removed.
        """
        original = func.args
        posonlyargs = original.posonlyargs
        args = original.args
        # `self` is looked for in the positional-only arguments if there are
        # any and in the general arguments otherwise.
        if posonlyargs:
            posonlyargs = [arg for arg in posonlyargs if arg.arg != "self"]
        elif args:
            args = [arg for arg in args if arg.arg != "self"]
        without_self = ast.arguments(
            posonlyargs=posonlyargs,
            args=args,
            vararg=original.vararg,
            kwonlyargs=original.kwonlyargs,
            kw_defaults=original.kw_defaults,
            kwarg=original.kwarg,
            defaults=original.defaults,
        )
        return f"{func.name}({ast.unparse(without_self)})"

    @staticmethod
    def get_ids_from_returns(values: list[ast.expr]) -> tuple[str, ...]:
        """Get the ids/names for all the expressions in the list.

        Expressions that are not plain names are skipped.

        Parameters
        ----------
        values : list[ast.expr]
            List of expressions to extract the ids from.

        Returns
        -------
        tuple[str, ...]
            Tuple of ids of the original expressions.
        """
        return tuple(
            value.id
            for value in values
            # Needed again for type checker
            if isinstance(value, ast.Name)
        )