Coverage for src/_griffe/agents/visitor.py: 98.34%
261 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-15 16:47 +0200
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-15 16:47 +0200
1# This module contains our static analysis agent,
2# capable of parsing and visiting sources, statically.
4from __future__ import annotations
6import ast
7from contextlib import suppress
8from typing import TYPE_CHECKING, Any
10from _griffe.agents.nodes.assignments import get_instance_names, get_names
11from _griffe.agents.nodes.ast import (
12 ast_children,
13 ast_kind,
14 ast_next,
15)
16from _griffe.agents.nodes.docstrings import get_docstring
17from _griffe.agents.nodes.exports import safe_get__all__
18from _griffe.agents.nodes.imports import relative_to_absolute
19from _griffe.agents.nodes.parameters import get_parameters
20from _griffe.collections import LinesCollection, ModulesCollection
21from _griffe.enumerations import Kind
22from _griffe.exceptions import AliasResolutionError, CyclicAliasError, LastNodeError
23from _griffe.expressions import (
24 Expr,
25 ExprName,
26 safe_get_annotation,
27 safe_get_base_class,
28 safe_get_condition,
29 safe_get_expression,
30)
31from _griffe.extensions.base import Extensions, load_extensions
32from _griffe.models import Alias, Attribute, Class, Decorator, Docstring, Function, Module, Parameter, Parameters
34if TYPE_CHECKING:
35 from pathlib import Path
37 from _griffe.enumerations import Parser
# Builtin decorators are labeled with their own name.
builtin_decorators = {name: name for name in ("property", "staticmethod", "classmethod")}
"""Mapping of builtin decorators to labels."""

stdlib_decorators = dict(
    [
        ("abc.abstractmethod", {"abstractmethod"}),
        ("functools.cache", {"cached"}),
        ("functools.cached_property", {"cached", "property"}),
        # NOTE(review): presumably the third-party `cached-property` package, not the stdlib.
        ("cached_property.cached_property", {"cached", "property"}),
        ("functools.lru_cache", {"cached"}),
        ("dataclasses.dataclass", {"dataclass"}),
    ],
)
"""Mapping of standard library decorators to labels."""

# Same decorator name, exported by both `typing` and `typing_extensions`.
typing_overload = {f"{module}.overload" for module in ("typing", "typing_extensions")}
"""Set of recognized typing overload decorators.

When such a decorator is found, the decorated function becomes an overload.
"""
def visit(
    module_name: str,
    filepath: Path,
    code: str,
    *,
    extensions: Extensions | None = None,
    parent: Module | None = None,
    docstring_parser: Parser | None = None,
    docstring_options: dict[str, Any] | None = None,
    lines_collection: LinesCollection | None = None,
    modules_collection: ModulesCollection | None = None,
) -> Module:
    """Parse and visit a module file.

    This is the static-analysis entry point: it builds a [`Visitor`][griffe.Visitor]
    (a [`NodeVisitor`][ast.NodeVisitor]-like class), which compiles the code into an
    AST (Abstract Syntax Tree) with [`compile`][] and then walks it.

    Important:
        This function is generally not used directly.
        In most cases, users can rely on the [`GriffeLoader`][griffe.GriffeLoader]
        and its accompanying [`load`][griffe.load] shortcut and their respective options
        to load modules using static analysis.

    Parameters:
        module_name: The module name (as when importing [from] it).
        filepath: The module file path.
        code: The module contents.
        extensions: The extensions to use when visiting the AST.
            When not provided (or falsy), `load_extensions()` is used.
        parent: The optional parent of this module.
        docstring_parser: The docstring parser to use. By default, no parsing is done.
        docstring_options: Additional docstring parsing options.
        lines_collection: A collection of source code lines.
        modules_collection: A collection of modules.

    Returns:
        The module, with its members populated.
    """
    active_extensions = extensions or load_extensions()
    visitor = Visitor(
        module_name,
        filepath,
        code,
        active_extensions,
        parent,
        docstring_parser=docstring_parser,
        docstring_options=docstring_options,
        lines_collection=lines_collection,
        modules_collection=modules_collection,
    )
    return visitor.get_module()
class Visitor:
    """This class is used to instantiate a visitor.

    Visitors iterate on AST nodes to extract data from them.
    The object currently being populated (module, class, or temporarily
    an `__init__` method) is tracked in `current`.
    """

    def __init__(
        self,
        module_name: str,
        filepath: Path,
        code: str,
        extensions: Extensions,
        parent: Module | None = None,
        docstring_parser: Parser | None = None,
        docstring_options: dict[str, Any] | None = None,
        lines_collection: LinesCollection | None = None,
        modules_collection: ModulesCollection | None = None,
    ) -> None:
        """Initialize the visitor.

        Parameters:
            module_name: The module name.
            filepath: The module filepath.
            code: The module source code.
            extensions: The extensions to use when visiting.
            parent: An optional parent for the final module object.
            docstring_parser: The docstring parser to use.
            docstring_options: The docstring parsing options.
            lines_collection: A collection of source code lines.
            modules_collection: A collection of modules.
        """
        super().__init__()

        self.module_name: str = module_name
        """The module name."""

        self.filepath: Path = filepath
        """The module filepath."""

        self.code: str = code
        """The module source code."""

        self.extensions: Extensions = extensions
        """The extensions to use when visiting the AST."""

        self.parent: Module | None = parent
        """An optional parent for the final module object."""

        self.current: Module | Class = None  # type: ignore[assignment]
        """The current object being visited."""

        self.docstring_parser: Parser | None = docstring_parser
        """The docstring parser to use."""

        self.docstring_options: dict[str, Any] = docstring_options or {}
        """The docstring parsing options."""

        self.lines_collection: LinesCollection = lines_collection or LinesCollection()
        """A collection of source code lines."""

        self.modules_collection: ModulesCollection = modules_collection or ModulesCollection()
        """A collection of modules."""

        self.type_guarded: bool = False
        """Whether the current code branch is type-guarded."""

    def _get_docstring(self, node: ast.AST, *, strict: bool = False) -> Docstring | None:
        """Build a docstring object for the given node, if it has one.

        Parameters:
            node: The node to extract a docstring from.
            strict: Forwarded as-is to `get_docstring`.

        Returns:
            A docstring configured with this visitor's parser and parser options,
            or `None` when no docstring value was found.
        """
        value, lineno, endlineno = get_docstring(node, strict=strict)
        if value is None:
            return None
        return Docstring(
            value,
            lineno=lineno,
            endlineno=endlineno,
            parser=self.docstring_parser,
            parser_options=self.docstring_options,
        )

    def get_module(self) -> Module:
        """Build and return the object representing the module attached to this visitor.

        This method triggers a complete visit of the module nodes.

        Returns:
            A module instance.
        """
        # optimization: equivalent to ast.parse, but with optimize=1 to remove assert statements
        # TODO: with options, could use optimize=2 to remove docstrings
        top_node = compile(self.code, mode="exec", filename=str(self.filepath), flags=ast.PyCF_ONLY_AST, optimize=1)
        self.visit(top_node)
        return self.current.module

    def visit(self, node: ast.AST) -> None:
        """Visit a node.

        The node is dispatched to the `visit_<kind>` method matching its kind
        (as returned by `ast_kind`), or to `generic_visit` when no such method exists.

        Parameters:
            node: The node to visit.
        """
        getattr(self, f"visit_{ast_kind(node)}", self.generic_visit)(node)

    def generic_visit(self, node: ast.AST) -> None:
        """Visit each child of a node.

        Parameters:
            node: The node whose children (as yielded by `ast_children`) are visited.
        """
        for child in ast_children(node):
            self.visit(child)

    def visit_module(self, node: ast.Module) -> None:
        """Visit a module node.

        Creates the module object, makes it the current object, then visits its body.

        Parameters:
            node: The node to visit.
        """
        self.extensions.call("on_node", node=node, agent=self)
        self.extensions.call("on_module_node", node=node, agent=self)
        self.current = module = Module(
            name=self.module_name,
            filepath=self.filepath,
            parent=self.parent,
            docstring=self._get_docstring(node),
            lines_collection=self.lines_collection,
            modules_collection=self.modules_collection,
        )
        self.extensions.call("on_instance", node=node, obj=module, agent=self)
        self.extensions.call("on_module_instance", node=node, mod=module, agent=self)
        self.generic_visit(node)
        self.extensions.call("on_members", node=node, obj=module, agent=self)
        self.extensions.call("on_module_members", node=node, mod=module, agent=self)

    def visit_classdef(self, node: ast.ClassDef) -> None:
        """Visit a class definition node.

        The class is set as a member of the current object and becomes the current
        object while its body is visited. The current object then goes back to the
        class's parent.

        Parameters:
            node: The node to visit.
        """
        self.extensions.call("on_node", node=node, agent=self)
        self.extensions.call("on_class_node", node=node, agent=self)

        # handle decorators: when present, the class starts at its first decorator
        decorators: list[Decorator] = []
        if node.decorator_list:
            lineno = node.decorator_list[0].lineno
            for decorator_node in node.decorator_list:
                decorator_value = safe_get_expression(decorator_node, parent=self.current, parse_strings=False)
                if decorator_value is None:
                    # same as for functions: skip decorators that could not be turned into expressions
                    continue
                decorators.append(
                    Decorator(
                        decorator_value,
                        lineno=decorator_node.lineno,
                        endlineno=decorator_node.end_lineno,
                    ),
                )
        else:
            lineno = node.lineno

        # handle base classes
        bases = [safe_get_base_class(base, parent=self.current) for base in node.bases]

        class_ = Class(
            name=node.name,
            lineno=lineno,
            endlineno=node.end_lineno,
            docstring=self._get_docstring(node),
            decorators=decorators,
            bases=bases,  # type: ignore[arg-type]
            runtime=not self.type_guarded,
        )
        class_.labels |= self.decorators_to_labels(decorators)
        self.current.set_member(node.name, class_)
        self.current = class_
        self.extensions.call("on_instance", node=node, obj=class_, agent=self)
        self.extensions.call("on_class_instance", node=node, cls=class_, agent=self)
        self.generic_visit(node)
        self.extensions.call("on_members", node=node, obj=class_, agent=self)
        self.extensions.call("on_class_members", node=node, cls=class_, agent=self)
        self.current = self.current.parent  # type: ignore[assignment]

    def decorators_to_labels(self, decorators: list[Decorator]) -> set[str]:
        """Build and return a set of labels based on decorators.

        Only decorators whose callable path is listed in `builtin_decorators`
        or `stdlib_decorators` contribute labels.

        Parameters:
            decorators: The decorators to check.

        Returns:
            A set of labels.
        """
        labels = set()
        for decorator in decorators:
            callable_path = decorator.callable_path
            if callable_path in builtin_decorators:
                labels.add(builtin_decorators[callable_path])
            elif callable_path in stdlib_decorators:
                labels |= stdlib_decorators[callable_path]
        return labels

    def get_base_property(self, decorators: list[Decorator], function: Function) -> str | None:
        """Check decorators to find out whether a function is the setter or deleter of a property.

        Parameters:
            decorators: The decorators to check.
            function: The decorated function.

        Returns:
            `"setter"` or `"deleter"` when one of the decorators is `<path>.setter` or
            `<path>.deleter`, where `<path>` is the function's own path and the current
            object's member with the function's name is labeled `property`.
            `None` otherwise.
        """
        for decorator in decorators:
            try:
                path, prop_function = decorator.callable_path.rsplit(".", 1)
            except ValueError:
                # no dot in the callable path: cannot be `<property>.setter/deleter`
                continue
            property_setter_or_deleter = (
                prop_function in {"setter", "deleter"}
                and path == function.path
                and self.current.get_member(function.name).has_labels("property")
            )
            if property_setter_or_deleter:
                return prop_function
        return None

    def handle_function(self, node: ast.AsyncFunctionDef | ast.FunctionDef, labels: set | None = None) -> None:
        """Handle a function definition node.

        Depending on its decorators, the function is recorded as:

        - an attribute, when it is labeled as a property;
        - an overload stored on the current object, when decorated with `typing.overload`;
        - the setter or deleter of an existing property;
        - a regular member of the current object otherwise.

        Parameters:
            node: The node to visit.
            labels: Labels to add to the data object. The given set is not modified.
        """
        self.extensions.call("on_node", node=node, agent=self)
        self.extensions.call("on_function_node", node=node, agent=self)

        # copy, so that adding decorator labels below never mutates the caller's set
        labels = set(labels) if labels else set()

        # handle decorators: when present, the function starts at its first decorator
        decorators: list[Decorator] = []
        overload = False
        if node.decorator_list:
            lineno = node.decorator_list[0].lineno
            for decorator_node in node.decorator_list:
                decorator_value = safe_get_expression(decorator_node, parent=self.current, parse_strings=False)
                if decorator_value is None:
                    continue
                decorator = Decorator(
                    decorator_value,
                    lineno=decorator_node.lineno,
                    endlineno=decorator_node.end_lineno,
                )
                decorators.append(decorator)
                overload |= decorator.callable_path in typing_overload
        else:
            lineno = node.lineno

        labels |= self.decorators_to_labels(decorators)

        if "property" in labels:
            # properties are recorded as attributes; note the attribute starts
            # at the `def` line, not at the first decorator
            attribute = Attribute(
                name=node.name,
                value=None,
                annotation=safe_get_annotation(node.returns, parent=self.current),
                lineno=node.lineno,
                endlineno=node.end_lineno,
                docstring=self._get_docstring(node),
                runtime=not self.type_guarded,
            )
            attribute.labels |= labels
            self.current.set_member(node.name, attribute)
            self.extensions.call("on_instance", node=node, obj=attribute, agent=self)
            self.extensions.call("on_attribute_instance", node=node, attr=attribute, agent=self)
            return

        # handle parameters
        parameters = Parameters(
            *[
                Parameter(
                    name,
                    kind=kind,
                    annotation=safe_get_annotation(annotation, parent=self.current),
                    default=default
                    if isinstance(default, str)
                    else safe_get_expression(default, parent=self.current, parse_strings=False),
                )
                for name, annotation, kind, default in get_parameters(node.args)
            ],
        )

        function = Function(
            name=node.name,
            lineno=lineno,
            endlineno=node.end_lineno,
            parameters=parameters,
            returns=safe_get_annotation(node.returns, parent=self.current),
            decorators=decorators,
            docstring=self._get_docstring(node),
            runtime=not self.type_guarded,
            parent=self.current,
        )

        property_function = self.get_base_property(decorators, function)

        if overload:
            self.current.overloads[function.name].append(function)
        elif property_function:
            base_property: Attribute = self.current.members[node.name]  # type: ignore[assignment]
            if property_function == "setter":
                base_property.setter = function
                base_property.labels.add("writable")
            else:  # "deleter": the only other value `get_base_property` returns
                base_property.deleter = function
                base_property.labels.add("deletable")
        else:
            self.current.set_member(node.name, function)
            # attach previously collected overloads to the implementation;
            # `.get` avoids inserting an empty entry for every non-overloaded function
            if self.current.kind in {Kind.MODULE, Kind.CLASS} and self.current.overloads.get(function.name):
                function.overloads = self.current.overloads[function.name]
                del self.current.overloads[function.name]

        function.labels |= labels

        self.extensions.call("on_instance", node=node, obj=function, agent=self)
        self.extensions.call("on_function_instance", node=node, func=function, agent=self)
        if self.current.kind is Kind.CLASS and function.name == "__init__":
            # visit the body of `__init__` to collect instance attributes (see `handle_attribute`)
            self.current = function  # type: ignore[assignment] # temporary assign a function
            self.generic_visit(node)
            self.current = self.current.parent  # type: ignore[assignment]

    def visit_functiondef(self, node: ast.FunctionDef) -> None:
        """Visit a function definition node.

        Parameters:
            node: The node to visit.
        """
        self.handle_function(node)

    def visit_asyncfunctiondef(self, node: ast.AsyncFunctionDef) -> None:
        """Visit an async function definition node.

        The resulting object gets the `async` label.

        Parameters:
            node: The node to visit.
        """
        self.handle_function(node, labels={"async"})

    def visit_import(self, node: ast.Import) -> None:
        """Visit an import node.

        Each imported name is recorded in the current object's imports and as an alias member.

        Parameters:
            node: The node to visit.
        """
        for name in node.names:
            if name.asname:
                # `import a.b as c` binds `c` to `a.b`
                alias_name, alias_path = name.asname, name.name
            else:
                # `import a.b` only binds the top-level package `a`
                alias_name = alias_path = name.name.split(".", 1)[0]
            self.current.imports[alias_name] = alias_path
            self.current.set_member(
                alias_name,
                Alias(
                    alias_name,
                    alias_path,
                    lineno=node.lineno,
                    endlineno=node.end_lineno,
                    runtime=not self.type_guarded,
                ),
            )

    def visit_importfrom(self, node: ast.ImportFrom) -> None:
        """Visit an "import from" node.

        Parameters:
            node: The node to visit.
        """
        for name in node.names:
            if not node.module and node.level == 1 and not name.asname and self.current.module.is_init_module:
                # special case: when being in `a/__init__.py` and doing `from . import b`,
                # we are effectively creating a member `b` in `a` that is pointing to `a.b`
                # -> cyclic alias! in that case, we just skip it, as both the member and module
                # have the same name and can be accessed the same way
                continue

            alias_path = relative_to_absolute(node, name, self.current.module)
            if name.name == "*":
                # wildcard imports get a placeholder member name derived from the path
                alias_name = alias_path.replace(".", "/")
                alias_path = alias_path.replace(".*", "")
            else:
                alias_name = name.asname or name.name
                self.current.imports[alias_name] = alias_path
            # Do not create aliases pointing to themselves (it happens with
            # `from package.current_module import Thing as Thing` or
            # `from . import thing as thing`).
            if alias_path != f"{self.current.path}.{alias_name}":
                self.current.set_member(
                    alias_name,
                    Alias(
                        alias_name,
                        alias_path,
                        lineno=node.lineno,
                        endlineno=node.end_lineno,
                        runtime=not self.type_guarded,
                    ),
                )

    def handle_attribute(
        self,
        node: ast.Assign | ast.AnnAssign,
        annotation: str | Expr | None = None,
    ) -> None:
        """Handle an attribute (assignment) node.

        Assignments are recorded as attributes of modules and classes, and as
        instance attributes of the parent class when they appear in an `__init__`
        method. Assignments in other functions are ignored.

        Parameters:
            node: The node to visit.
            annotation: A potential annotation.
        """
        self.extensions.call("on_node", node=node, agent=self)
        self.extensions.call("on_attribute_node", node=node, agent=self)
        parent = self.current
        labels: set[str] = set()

        if parent.kind is Kind.MODULE:
            try:
                names = get_names(node)
            except KeyError:  # unsupported nodes, like subscript
                return
            labels.add("module-attribute")
        elif parent.kind is Kind.CLASS:
            try:
                names = get_names(node)
            except KeyError:  # unsupported nodes, like subscript
                return

            if isinstance(annotation, Expr) and annotation.is_classvar:
                # explicit classvar: class attribute only
                annotation = annotation.slice  # type: ignore[attr-defined]
                labels.add("class-attribute")
            elif node.value:
                # attribute assigned at class-level: available in instances as well
                labels.add("class-attribute")
                labels.add("instance-attribute")
            else:
                # annotated attribute only: not available at class-level
                labels.add("instance-attribute")

        elif parent.kind is Kind.FUNCTION:
            if parent.name != "__init__":
                return
            try:
                names = get_instance_names(node)
            except KeyError:  # unsupported nodes, like subscript
                return
            # instance attributes belong to the class owning `__init__`
            parent = parent.parent  # type: ignore[assignment]
            labels.add("instance-attribute")
        else:
            # no other kind of object collects attributes; avoids using `names` unbound
            return

        if not names:
            return

        value = safe_get_expression(node.value, parent=self.current, parse_strings=False)

        try:
            docstring = self._get_docstring(ast_next(node), strict=True)
        except (LastNodeError, AttributeError):
            docstring = None

        for name in names:
            # TODO: handle assigns like x.y = z
            # we need to resolve x.y and add z in its member
            if "." in name:
                continue

            # per-name copies: what is forwarded from an existing member below
            # must not leak to the other names of the same assignment (`x, y = ...`)
            attr_labels = set(labels)
            attr_docstring = docstring
            attr_annotation = annotation

            if name in parent.members:
                # assigning multiple times
                # TODO: might be better to inspect
                if isinstance(node.parent, (ast.If, ast.ExceptHandler)):  # type: ignore[union-attr]
                    continue  # prefer "no-exception" case

                existing_member = parent.members[name]
                with suppress(AliasResolutionError, CyclicAliasError):
                    attr_labels |= existing_member.labels
                    # forward previous docstring and annotation instead of erasing them
                    if existing_member.docstring and not attr_docstring:
                        attr_docstring = existing_member.docstring
                    with suppress(AttributeError):
                        if existing_member.annotation and not attr_annotation:  # type: ignore[union-attr]
                            attr_annotation = existing_member.annotation  # type: ignore[union-attr]

            attribute = Attribute(
                name=name,
                value=value,
                annotation=attr_annotation,
                lineno=node.lineno,
                endlineno=node.end_lineno,
                docstring=attr_docstring,
                runtime=not self.type_guarded,
            )
            attribute.labels |= attr_labels
            parent.set_member(name, attribute)

            if name == "__all__":
                with suppress(AttributeError):
                    parent.exports = [
                        export if isinstance(export, str) else ExprName(export.name, parent=export.parent)
                        for export in safe_get__all__(node, self.current)  # type: ignore[arg-type]
                    ]
            self.extensions.call("on_instance", node=node, obj=attribute, agent=self)
            self.extensions.call("on_attribute_instance", node=node, attr=attribute, agent=self)

    def visit_assign(self, node: ast.Assign) -> None:
        """Visit an assignment node.

        Parameters:
            node: The node to visit.
        """
        self.handle_attribute(node)

    def visit_annassign(self, node: ast.AnnAssign) -> None:
        """Visit an annotated assignment node.

        Parameters:
            node: The node to visit.
        """
        self.handle_attribute(node, safe_get_annotation(node.annotation, parent=self.current))

    def visit_augassign(self, node: ast.AugAssign) -> None:
        """Visit an augmented assignment node.

        Only module-level `__all__ += ...` is handled: the added names are appended
        to the module's exports.

        Parameters:
            node: The node to visit.
        """
        # AttributeError is suppressed for targets without an `id` (e.g. `self.x += 1`)
        # and when `exports` cannot be extended (e.g. it is still `None`)
        with suppress(AttributeError):
            all_augment = (
                node.target.id == "__all__"  # type: ignore[union-attr]
                and self.current.is_module
                and isinstance(node.op, ast.Add)
            )
            if all_augment:
                # we assume exports is not None at this point
                self.current.exports.extend(  # type: ignore[union-attr]
                    [
                        export if isinstance(export, str) else ExprName(export.name, parent=export.parent)
                        for export in safe_get__all__(node, self.current)  # type: ignore[arg-type]
                    ],
                )

    def visit_if(self, node: ast.If) -> None:
        """Visit an "if" node.

        Module- and class-level `if TYPE_CHECKING:` (or `if typing.TYPE_CHECKING:`)
        blocks mark the objects defined inside them as non-runtime.

        Parameters:
            node: The node to visit.
        """
        # Restore the enclosing state afterwards instead of resetting to False,
        # so a nested `if` inside a type-guarded block does not turn guarding
        # off for the statements that follow it in that block.
        was_type_guarded = self.type_guarded
        if isinstance(node.parent, (ast.Module, ast.ClassDef)):  # type: ignore[attr-defined]
            condition = safe_get_condition(node.test, parent=self.current, log_level=None)
            if str(condition) in {"typing.TYPE_CHECKING", "TYPE_CHECKING"}:
                self.type_guarded = True
        # NOTE(review): the `else` branch of a TYPE_CHECKING `if` is visited as
        # type-guarded too, although it runs at runtime — confirm this is intended.
        self.generic_visit(node)
        self.type_guarded = was_type_guarded