Coverage for src/_griffe/agents/visitor.py: 98.25%
261 statements
« prev ^ index » next coverage.py v7.6.2, created at 2024-10-12 01:34 +0200
« prev ^ index » next coverage.py v7.6.2, created at 2024-10-12 01:34 +0200
1# This module contains our static analysis agent,
2# capable of parsing and visiting sources, statically.
4from __future__ import annotations
6import ast
7from contextlib import suppress
8from typing import TYPE_CHECKING, Any
10from _griffe.agents.nodes.assignments import get_instance_names, get_names
11from _griffe.agents.nodes.ast import (
12 ast_children,
13 ast_kind,
14 ast_next,
15)
16from _griffe.agents.nodes.docstrings import get_docstring
17from _griffe.agents.nodes.exports import safe_get__all__
18from _griffe.agents.nodes.imports import relative_to_absolute
19from _griffe.agents.nodes.parameters import get_parameters
20from _griffe.collections import LinesCollection, ModulesCollection
21from _griffe.enumerations import Kind
22from _griffe.exceptions import AliasResolutionError, CyclicAliasError, LastNodeError
23from _griffe.expressions import (
24 Expr,
25 ExprName,
26 safe_get_annotation,
27 safe_get_base_class,
28 safe_get_condition,
29 safe_get_expression,
30)
31from _griffe.extensions.base import Extensions, load_extensions
32from _griffe.models import Alias, Attribute, Class, Decorator, Docstring, Function, Module, Parameter, Parameters
34if TYPE_CHECKING:
35 from pathlib import Path
37 from _griffe.enumerations import Parser
builtin_decorators = {label: label for label in ("property", "staticmethod", "classmethod")}
"""Mapping of builtin decorators to labels.

Each builtin decorator gives the decorated object a label with the same name.
"""

stdlib_decorators = {
    "abc.abstractmethod": {"abstractmethod"},
    "functools.cache": {"cached"},
    "functools.cached_property": {"cached", "property"},
    "cached_property.cached_property": {"cached", "property"},
    "functools.lru_cache": {"cached"},
    "dataclasses.dataclass": {"dataclass"},
}
"""Mapping of well-known decorators (by canonical path) to the labels they add.

Mostly standard library decorators; the third-party `cached_property` package
is included too, labelled like `functools.cached_property`.
"""

typing_overload = {f"{module}.overload" for module in ("typing", "typing_extensions")}
"""Set of recognized typing overload decorators.

When such a decorator is found, the decorated function becomes an overload.
"""
def visit(
    module_name: str,
    filepath: Path,
    code: str,
    *,
    extensions: Extensions | None = None,
    parent: Module | None = None,
    docstring_parser: Parser | None = None,
    docstring_options: dict[str, Any] | None = None,
    lines_collection: LinesCollection | None = None,
    modules_collection: ModulesCollection | None = None,
) -> Module:
    """Parse and visit a module file.

    This is the static-analysis entry point: it creates a [`Visitor`][griffe.Visitor]
    (a [`NodeVisitor`][ast.NodeVisitor]-like class), which turns the code into an AST
    (Abstract Syntax Tree) with [`compile`][] and then walks that tree.

    Important:
        This function is generally not used directly.
        In most cases, users can rely on the [`GriffeLoader`][griffe.GriffeLoader]
        and its accompanying [`load`][griffe.load] shortcut and their respective options
        to load modules using static analysis.

    Parameters:
        module_name: The module name (as when importing [from] it).
        filepath: The module file path.
        code: The module contents.
        extensions: The extensions to use when visiting the AST.
            When not given (or falsy), the result of `load_extensions()` is used.
        parent: The optional parent of this module.
        docstring_parser: The docstring parser to use. By default, no parsing is done.
        docstring_options: Additional docstring parsing options.
        lines_collection: A collection of source code lines.
        modules_collection: A collection of modules.

    Returns:
        The module, with its members populated.
    """
    agent = Visitor(
        module_name=module_name,
        filepath=filepath,
        code=code,
        extensions=extensions or load_extensions(),
        parent=parent,
        docstring_parser=docstring_parser,
        docstring_options=docstring_options,
        lines_collection=lines_collection,
        modules_collection=modules_collection,
    )
    return agent.get_module()
class Visitor:
    """Static-analysis agent walking a module's AST to build Griffe objects.

    Each `visit_*` method handles one kind of AST node and populates the
    object tree rooted at the module returned by `get_module`.
    """

    def __init__(
        self,
        module_name: str,
        filepath: Path,
        code: str,
        extensions: Extensions,
        parent: Module | None = None,
        docstring_parser: Parser | None = None,
        docstring_options: dict[str, Any] | None = None,
        lines_collection: LinesCollection | None = None,
        modules_collection: ModulesCollection | None = None,
    ) -> None:
        """Store the visiting configuration; nothing is parsed until `get_module` is called.

        Parameters:
            module_name: The module name.
            filepath: The module filepath.
            code: The module source code.
            extensions: The extensions to use when visiting.
            parent: An optional parent for the final module object.
            docstring_parser: The docstring parser to use.
            docstring_options: The docstring parsing options (a fresh empty dict when falsy).
            lines_collection: A collection of source code lines (a new one when falsy).
            modules_collection: A collection of modules (a new one when falsy).
        """
        super().__init__()

        # What is being visited.
        self.module_name: str = module_name
        """The module name."""

        self.filepath: Path = filepath
        """The module filepath."""

        self.code: str = code
        """The module source code."""

        # How it is visited.
        self.extensions: Extensions = extensions
        """The extensions to use when visiting the AST."""

        self.parent: Module | None = parent
        """An optional parent for the final module object."""

        self.docstring_parser: Parser | None = docstring_parser
        """The docstring parser to use."""

        self.docstring_options: dict[str, Any] = docstring_options or {}
        """The docstring parsing options."""

        self.lines_collection: LinesCollection = lines_collection or LinesCollection()
        """A collection of source code lines."""

        self.modules_collection: ModulesCollection = modules_collection or ModulesCollection()
        """A collection of modules."""

        # Mutable state updated while walking the tree.
        self.current: Module | Class = None  # type: ignore[assignment]
        """The current object being visited."""

        self.type_guarded: bool = False
        """Whether the current code branch is type-guarded."""
181 def _get_docstring(self, node: ast.AST, *, strict: bool = False) -> Docstring | None:
182 value, lineno, endlineno = get_docstring(node, strict=strict)
183 if value is None:
184 return None
185 return Docstring(
186 value,
187 lineno=lineno,
188 endlineno=endlineno,
189 parser=self.docstring_parser,
190 parser_options=self.docstring_options,
191 )
193 def get_module(self) -> Module:
194 """Build and return the object representing the module attached to this visitor.
196 This method triggers a complete visit of the module nodes.
198 Returns:
199 A module instance.
200 """
201 # optimization: equivalent to ast.parse, but with optimize=1 to remove assert statements
202 # TODO: with options, could use optimize=2 to remove docstrings
203 top_node = compile(self.code, mode="exec", filename=str(self.filepath), flags=ast.PyCF_ONLY_AST, optimize=1)
204 self.visit(top_node)
205 return self.current.module
207 def visit(self, node: ast.AST) -> None:
208 """Extend the base visit with extensions.
210 Parameters:
211 node: The node to visit.
212 """
213 getattr(self, f"visit_{ast_kind(node)}", self.generic_visit)(node)
215 def generic_visit(self, node: ast.AST) -> None:
216 """Extend the base generic visit with extensions.
218 Parameters:
219 node: The node to visit.
220 """
221 for child in ast_children(node):
222 self.visit(child)
224 def visit_module(self, node: ast.Module) -> None:
225 """Visit a module node.
227 Parameters:
228 node: The node to visit.
229 """
230 self.extensions.call("on_node", node=node, agent=self)
231 self.extensions.call("on_module_node", node=node, agent=self)
232 self.current = module = Module(
233 name=self.module_name,
234 filepath=self.filepath,
235 parent=self.parent,
236 docstring=self._get_docstring(node),
237 lines_collection=self.lines_collection,
238 modules_collection=self.modules_collection,
239 )
240 self.extensions.call("on_instance", node=node, obj=module, agent=self)
241 self.extensions.call("on_module_instance", node=node, mod=module, agent=self)
242 self.generic_visit(node)
243 self.extensions.call("on_members", node=node, obj=module, agent=self)
244 self.extensions.call("on_module_members", node=node, mod=module, agent=self)
246 def visit_classdef(self, node: ast.ClassDef) -> None:
247 """Visit a class definition node.
249 Parameters:
250 node: The node to visit.
251 """
252 self.extensions.call("on_node", node=node, agent=self)
253 self.extensions.call("on_class_node", node=node, agent=self)
255 # handle decorators
256 decorators: list[Decorator] = []
257 if node.decorator_list:
258 lineno = node.decorator_list[0].lineno
259 decorators.extend(
260 Decorator(
261 safe_get_expression(decorator_node, parent=self.current, parse_strings=False), # type: ignore[arg-type]
262 lineno=decorator_node.lineno,
263 endlineno=decorator_node.end_lineno,
264 )
265 for decorator_node in node.decorator_list
266 )
267 else:
268 lineno = node.lineno
270 # handle base classes
271 bases = [safe_get_base_class(base, parent=self.current) for base in node.bases]
273 class_ = Class(
274 name=node.name,
275 lineno=lineno,
276 endlineno=node.end_lineno,
277 docstring=self._get_docstring(node),
278 decorators=decorators,
279 bases=bases, # type: ignore[arg-type]
280 runtime=not self.type_guarded,
281 )
282 class_.labels |= self.decorators_to_labels(decorators)
283 self.current.set_member(node.name, class_)
284 self.current = class_
285 self.extensions.call("on_instance", node=node, obj=class_, agent=self)
286 self.extensions.call("on_class_instance", node=node, cls=class_, agent=self)
287 self.generic_visit(node)
288 self.extensions.call("on_members", node=node, obj=class_, agent=self)
289 self.extensions.call("on_class_members", node=node, cls=class_, agent=self)
290 self.current = self.current.parent # type: ignore[assignment]
292 def decorators_to_labels(self, decorators: list[Decorator]) -> set[str]:
293 """Build and return a set of labels based on decorators.
295 Parameters:
296 decorators: The decorators to check.
298 Returns:
299 A set of labels.
300 """
301 labels = set()
302 for decorator in decorators:
303 callable_path = decorator.callable_path
304 if callable_path in builtin_decorators:
305 labels.add(builtin_decorators[callable_path])
306 elif callable_path in stdlib_decorators:
307 labels |= stdlib_decorators[callable_path]
308 return labels
310 def get_base_property(self, decorators: list[Decorator], function: Function) -> str | None:
311 """Check decorators to return the base property in case of setters and deleters.
313 Parameters:
314 decorators: The decorators to check.
316 Returns:
317 base_property: The property for which the setter/deleted is set.
318 property_function: Either `"setter"` or `"deleter"`.
319 """
320 for decorator in decorators:
321 try:
322 path, prop_function = decorator.callable_path.rsplit(".", 1)
323 except ValueError:
324 continue
325 property_setter_or_deleter = (
326 prop_function in {"setter", "deleter"}
327 and path == function.path
328 and self.current.get_member(function.name).has_labels("property")
329 )
330 if property_setter_or_deleter:
331 return prop_function
332 return None
334 def handle_function(self, node: ast.AsyncFunctionDef | ast.FunctionDef, labels: set | None = None) -> None:
335 """Handle a function definition node.
337 Parameters:
338 node: The node to visit.
339 labels: Labels to add to the data object.
340 """
341 self.extensions.call("on_node", node=node, agent=self)
342 self.extensions.call("on_function_node", node=node, agent=self)
344 labels = labels or set()
346 # handle decorators
347 decorators = []
348 overload = False
349 if node.decorator_list:
350 lineno = node.decorator_list[0].lineno
351 for decorator_node in node.decorator_list:
352 decorator_value = safe_get_expression(decorator_node, parent=self.current, parse_strings=False)
353 if decorator_value is None: 353 ↛ 354line 353 didn't jump to line 354 because the condition on line 353 was never true
354 continue
355 decorator = Decorator(
356 decorator_value,
357 lineno=decorator_node.lineno,
358 endlineno=decorator_node.end_lineno,
359 )
360 decorators.append(decorator)
361 overload |= decorator.callable_path in typing_overload
362 else:
363 lineno = node.lineno
365 labels |= self.decorators_to_labels(decorators)
367 if "property" in labels:
368 attribute = Attribute(
369 name=node.name,
370 value=None,
371 annotation=safe_get_annotation(node.returns, parent=self.current),
372 lineno=node.lineno,
373 endlineno=node.end_lineno,
374 docstring=self._get_docstring(node),
375 runtime=not self.type_guarded,
376 )
377 attribute.labels |= labels
378 self.current.set_member(node.name, attribute)
379 self.extensions.call("on_instance", node=node, obj=attribute, agent=self)
380 self.extensions.call("on_attribute_instance", node=node, attr=attribute, agent=self)
381 return
383 # handle parameters
384 parameters = Parameters(
385 *[
386 Parameter(
387 name,
388 kind=kind,
389 annotation=safe_get_annotation(annotation, parent=self.current),
390 default=default
391 if isinstance(default, str)
392 else safe_get_expression(default, parent=self.current, parse_strings=False),
393 )
394 for name, annotation, kind, default in get_parameters(node.args)
395 ],
396 )
398 function = Function(
399 name=node.name,
400 lineno=lineno,
401 endlineno=node.end_lineno,
402 parameters=parameters,
403 returns=safe_get_annotation(node.returns, parent=self.current),
404 decorators=decorators,
405 docstring=self._get_docstring(node),
406 runtime=not self.type_guarded,
407 parent=self.current,
408 )
410 property_function = self.get_base_property(decorators, function)
412 if overload:
413 self.current.overloads[function.name].append(function)
414 elif property_function:
415 base_property: Attribute = self.current.members[node.name] # type: ignore[assignment]
416 if property_function == "setter":
417 base_property.setter = function
418 base_property.labels.add("writable")
419 elif property_function == "deleter": 419 ↛ 428line 419 didn't jump to line 428 because the condition on line 419 was always true
420 base_property.deleter = function
421 base_property.labels.add("deletable")
422 else:
423 self.current.set_member(node.name, function)
424 if self.current.kind in {Kind.MODULE, Kind.CLASS} and self.current.overloads[function.name]:
425 function.overloads = self.current.overloads[function.name]
426 del self.current.overloads[function.name]
428 function.labels |= labels
430 self.extensions.call("on_instance", node=node, obj=function, agent=self)
431 self.extensions.call("on_function_instance", node=node, func=function, agent=self)
432 if self.current.kind is Kind.CLASS and function.name == "__init__":
433 self.current = function # type: ignore[assignment] # temporary assign a function
434 self.generic_visit(node)
435 self.current = self.current.parent # type: ignore[assignment]
437 def visit_functiondef(self, node: ast.FunctionDef) -> None:
438 """Visit a function definition node.
440 Parameters:
441 node: The node to visit.
442 """
443 self.handle_function(node)
445 def visit_asyncfunctiondef(self, node: ast.AsyncFunctionDef) -> None:
446 """Visit an async function definition node.
448 Parameters:
449 node: The node to visit.
450 """
451 self.handle_function(node, labels={"async"})
453 def visit_import(self, node: ast.Import) -> None:
454 """Visit an import node.
456 Parameters:
457 node: The node to visit.
458 """
459 for name in node.names:
460 alias_path = name.name if name.asname else name.name.split(".", 1)[0]
461 alias_name = name.asname or alias_path.split(".", 1)[0]
462 self.current.imports[alias_name] = alias_path
463 alias = Alias(
464 alias_name,
465 alias_path,
466 lineno=node.lineno,
467 endlineno=node.end_lineno,
468 runtime=not self.type_guarded,
469 )
470 self.current.set_member(alias_name, alias)
471 self.extensions.call("on_alias", alias=alias, node=node, agent=self)
473 def visit_importfrom(self, node: ast.ImportFrom) -> None:
474 """Visit an "import from" node.
476 Parameters:
477 node: The node to visit.
478 """
479 for name in node.names:
480 if not node.module and node.level == 1 and not name.asname and self.current.module.is_init_module:
481 # special case: when being in `a/__init__.py` and doing `from . import b`,
482 # we are effectively creating a member `b` in `a` that is pointing to `a.b`
483 # -> cyclic alias! in that case, we just skip it, as both the member and module
484 # have the same name and can be accessed the same way
485 continue
487 alias_path = relative_to_absolute(node, name, self.current.module)
488 if name.name == "*":
489 alias_name = alias_path.replace(".", "/")
490 alias_path = alias_path.replace(".*", "")
491 else:
492 alias_name = name.asname or name.name
493 self.current.imports[alias_name] = alias_path
494 # Do not create aliases pointing to themselves (it happens with
495 # `from package.current_module import Thing as Thing` or
496 # `from . import thing as thing`).
497 if alias_path != f"{self.current.path}.{alias_name}":
498 alias = Alias(
499 alias_name,
500 alias_path,
501 lineno=node.lineno,
502 endlineno=node.end_lineno,
503 runtime=not self.type_guarded,
504 )
505 self.current.set_member(alias_name, alias)
506 self.extensions.call("on_alias", alias=alias, node=node, agent=self)
508 def handle_attribute(
509 self,
510 node: ast.Assign | ast.AnnAssign,
511 annotation: str | Expr | None = None,
512 ) -> None:
513 """Handle an attribute (assignment) node.
515 Parameters:
516 node: The node to visit.
517 annotation: A potential annotation.
518 """
519 self.extensions.call("on_node", node=node, agent=self)
520 self.extensions.call("on_attribute_node", node=node, agent=self)
521 parent = self.current
522 labels = set()
524 if parent.kind is Kind.MODULE:
525 try:
526 names = get_names(node)
527 except KeyError: # unsupported nodes, like subscript
528 return
529 labels.add("module-attribute")
530 elif parent.kind is Kind.CLASS:
531 try:
532 names = get_names(node)
533 except KeyError: # unsupported nodes, like subscript
534 return
536 if isinstance(annotation, Expr) and annotation.is_classvar:
537 # explicit classvar: class attribute only
538 annotation = annotation.slice # type: ignore[attr-defined]
539 labels.add("class-attribute")
540 elif node.value:
541 # attribute assigned at class-level: available in instances as well
542 labels.add("class-attribute")
543 labels.add("instance-attribute")
544 else:
545 # annotated attribute only: not available at class-level
546 labels.add("instance-attribute")
548 elif parent.kind is Kind.FUNCTION: 548 ↛ 558line 548 didn't jump to line 558 because the condition on line 548 was always true
549 if parent.name != "__init__": 549 ↛ 550line 549 didn't jump to line 550 because the condition on line 549 was never true
550 return
551 try:
552 names = get_instance_names(node)
553 except KeyError: # unsupported nodes, like subscript
554 return
555 parent = parent.parent # type: ignore[assignment]
556 labels.add("instance-attribute")
558 if not names:
559 return
561 value = safe_get_expression(node.value, parent=self.current, parse_strings=False)
563 try:
564 docstring = self._get_docstring(ast_next(node), strict=True)
565 except (LastNodeError, AttributeError):
566 docstring = None
568 for name in names:
569 # TODO: handle assigns like x.y = z
570 # we need to resolve x.y and add z in its member
571 if "." in name:
572 continue
574 if name in parent.members:
575 # assigning multiple times
576 # TODO: might be better to inspect
577 if isinstance(node.parent, (ast.If, ast.ExceptHandler)): # type: ignore[union-attr]
578 continue # prefer "no-exception" case
580 existing_member = parent.members[name]
581 with suppress(AliasResolutionError, CyclicAliasError):
582 labels |= existing_member.labels
583 # forward previous docstring and annotation instead of erasing them
584 if existing_member.docstring and not docstring:
585 docstring = existing_member.docstring
586 with suppress(AttributeError):
587 if existing_member.annotation and not annotation: # type: ignore[union-attr]
588 annotation = existing_member.annotation # type: ignore[union-attr]
590 attribute = Attribute(
591 name=name,
592 value=value,
593 annotation=annotation,
594 lineno=node.lineno,
595 endlineno=node.end_lineno,
596 docstring=docstring,
597 runtime=not self.type_guarded,
598 )
599 attribute.labels |= labels
600 parent.set_member(name, attribute)
602 if name == "__all__":
603 with suppress(AttributeError):
604 parent.exports = [
605 name if isinstance(name, str) else ExprName(name.name, parent=name.parent)
606 for name in safe_get__all__(node, self.current) # type: ignore[arg-type]
607 ]
608 self.extensions.call("on_instance", node=node, obj=attribute, agent=self)
609 self.extensions.call("on_attribute_instance", node=node, attr=attribute, agent=self)
611 def visit_assign(self, node: ast.Assign) -> None:
612 """Visit an assignment node.
614 Parameters:
615 node: The node to visit.
616 """
617 self.handle_attribute(node)
619 def visit_annassign(self, node: ast.AnnAssign) -> None:
620 """Visit an annotated assignment node.
622 Parameters:
623 node: The node to visit.
624 """
625 self.handle_attribute(node, safe_get_annotation(node.annotation, parent=self.current))
627 def visit_augassign(self, node: ast.AugAssign) -> None:
628 """Visit an augmented assignment node.
630 Parameters:
631 node: The node to visit.
632 """
633 with suppress(AttributeError):
634 all_augment = (
635 node.target.id == "__all__" # type: ignore[union-attr]
636 and self.current.is_module
637 and isinstance(node.op, ast.Add)
638 )
639 if all_augment:
640 # we assume exports is not None at this point
641 self.current.exports.extend( # type: ignore[union-attr]
642 [
643 name if isinstance(name, str) else ExprName(name.name, parent=name.parent)
644 for name in safe_get__all__(node, self.current) # type: ignore[arg-type]
645 ],
646 )
648 def visit_if(self, node: ast.If) -> None:
649 """Visit an "if" node.
651 Parameters:
652 node: The node to visit.
653 """
654 if isinstance(node.parent, (ast.Module, ast.ClassDef)): # type: ignore[attr-defined]
655 condition = safe_get_condition(node.test, parent=self.current, log_level=None)
656 if str(condition) in {"typing.TYPE_CHECKING", "TYPE_CHECKING"}:
657 self.type_guarded = True
658 self.generic_visit(node)
659 self.type_guarded = False