Coverage for src/_griffe/extensions/base.py: 83.33%
112 statements
« prev ^ index » next — coverage.py v7.6.2, created at 2024-10-12 01:34 +0200
1# This module contains the base class for extensions
2# and the functions to load them.
4from __future__ import annotations
6import os
7import sys
8from importlib.util import module_from_spec, spec_from_file_location
9from inspect import isclass
10from pathlib import Path
11from typing import TYPE_CHECKING, Any, Union
13from _griffe.agents.nodes.ast import ast_children, ast_kind
14from _griffe.exceptions import ExtensionNotLoadedError
15from _griffe.importer import dynamic_import
17if TYPE_CHECKING:
18 import ast
19 from types import ModuleType
21 from _griffe.agents.inspector import Inspector
22 from _griffe.agents.nodes.runtime import ObjectNode
23 from _griffe.agents.visitor import Visitor
24 from _griffe.loader import GriffeLoader
25 from _griffe.models import Alias, Attribute, Class, Function, Module, Object
class Extension:
    """Base class for Griffe extensions.

    Every `on_*` hook defined here is a no-op, so subclasses only need to
    override the events they care about.
    """

    def visit(self, node: ast.AST) -> None:
        """Visit a node.

        Dispatches to the `visit_<kind>` method matching the node kind computed
        by `ast_kind`, if this extension defines one; otherwise does nothing.

        Parameters:
            node: The node to visit.
        """
        getattr(self, f"visit_{ast_kind(node)}", lambda _: None)(node)

    def generic_visit(self, node: ast.AST) -> None:
        """Visit children nodes.

        Calls `visit` on each child yielded by `ast_children`.

        Parameters:
            node: The node to visit the children of.
        """
        for child in ast_children(node):
            self.visit(child)

    def inspect(self, node: ObjectNode) -> None:
        """Inspect a node.

        Dispatches to the `inspect_<kind>` method matching `node.kind`, if this
        extension defines one; otherwise does nothing.

        Parameters:
            node: The node to inspect.
        """
        getattr(self, f"inspect_{node.kind}", lambda _: None)(node)

    def generic_inspect(self, node: ObjectNode) -> None:
        """Extend the base generic inspection with extensions.

        Calls `inspect` on each child of the node, skipping children whose
        `alias_target_path` is truthy (presumably aliases to objects defined elsewhere).

        Parameters:
            node: The node to inspect.
        """
        for child in node.children:
            if not child.alias_target_path:
                self.inspect(child)

    def on_node(self, *, node: ast.AST | ObjectNode, agent: Visitor | Inspector, **kwargs: Any) -> None:
        """Run when visiting a new node during static/dynamic analysis.

        Parameters:
            node: The currently visited node.
            agent: The analysis agent currently running.
            **kwargs: For forward-compatibility.
        """

    def on_instance(
        self,
        *,
        node: ast.AST | ObjectNode,
        obj: Object,
        agent: Visitor | Inspector,
        **kwargs: Any,
    ) -> None:
        """Run when an Object has been created.

        Parameters:
            node: The currently visited node.
            obj: The object instance.
            agent: The analysis agent currently running.
            **kwargs: For forward-compatibility.
        """

    def on_members(self, *, node: ast.AST | ObjectNode, obj: Object, agent: Visitor | Inspector, **kwargs: Any) -> None:
        """Run when members of an Object have been loaded.

        Parameters:
            node: The currently visited node.
            obj: The object instance.
            agent: The analysis agent currently running.
            **kwargs: For forward-compatibility.
        """

    def on_module_node(self, *, node: ast.AST | ObjectNode, agent: Visitor | Inspector, **kwargs: Any) -> None:
        """Run when visiting a new module node during static/dynamic analysis.

        Parameters:
            node: The currently visited node.
            agent: The analysis agent currently running.
            **kwargs: For forward-compatibility.
        """

    def on_module_instance(
        self,
        *,
        node: ast.AST | ObjectNode,
        mod: Module,
        agent: Visitor | Inspector,
        **kwargs: Any,
    ) -> None:
        """Run when a Module has been created.

        Parameters:
            node: The currently visited node.
            mod: The module instance.
            agent: The analysis agent currently running.
            **kwargs: For forward-compatibility.
        """

    def on_module_members(
        self,
        *,
        node: ast.AST | ObjectNode,
        mod: Module,
        agent: Visitor | Inspector,
        **kwargs: Any,
    ) -> None:
        """Run when members of a Module have been loaded.

        Parameters:
            node: The currently visited node.
            mod: The module instance.
            agent: The analysis agent currently running.
            **kwargs: For forward-compatibility.
        """

    def on_class_node(self, *, node: ast.AST | ObjectNode, agent: Visitor | Inspector, **kwargs: Any) -> None:
        """Run when visiting a new class node during static/dynamic analysis.

        Parameters:
            node: The currently visited node.
            agent: The analysis agent currently running.
            **kwargs: For forward-compatibility.
        """

    def on_class_instance(
        self,
        *,
        node: ast.AST | ObjectNode,
        cls: Class,
        agent: Visitor | Inspector,
        **kwargs: Any,
    ) -> None:
        """Run when a Class has been created.

        Parameters:
            node: The currently visited node.
            cls: The class instance.
            agent: The analysis agent currently running.
            **kwargs: For forward-compatibility.
        """

    def on_class_members(
        self,
        *,
        node: ast.AST | ObjectNode,
        cls: Class,
        agent: Visitor | Inspector,
        **kwargs: Any,
    ) -> None:
        """Run when members of a Class have been loaded.

        Parameters:
            node: The currently visited node.
            cls: The class instance.
            agent: The analysis agent currently running.
            **kwargs: For forward-compatibility.
        """

    def on_function_node(self, *, node: ast.AST | ObjectNode, agent: Visitor | Inspector, **kwargs: Any) -> None:
        """Run when visiting a new function node during static/dynamic analysis.

        Parameters:
            node: The currently visited node.
            agent: The analysis agent currently running.
            **kwargs: For forward-compatibility.
        """

    def on_function_instance(
        self,
        *,
        node: ast.AST | ObjectNode,
        func: Function,
        agent: Visitor | Inspector,
        **kwargs: Any,
    ) -> None:
        """Run when a Function has been created.

        Parameters:
            node: The currently visited node.
            func: The function instance.
            agent: The analysis agent currently running.
            **kwargs: For forward-compatibility.
        """

    def on_attribute_node(self, *, node: ast.AST | ObjectNode, agent: Visitor | Inspector, **kwargs: Any) -> None:
        """Run when visiting a new attribute node during static/dynamic analysis.

        Parameters:
            node: The currently visited node.
            agent: The analysis agent currently running.
            **kwargs: For forward-compatibility.
        """

    def on_attribute_instance(
        self,
        *,
        node: ast.AST | ObjectNode,
        attr: Attribute,
        agent: Visitor | Inspector,
        **kwargs: Any,
    ) -> None:
        """Run when an Attribute has been created.

        Parameters:
            node: The currently visited node.
            attr: The attribute instance.
            agent: The analysis agent currently running.
            **kwargs: For forward-compatibility.
        """

    def on_alias(
        self,
        *,
        node: ast.AST | ObjectNode,
        alias: Alias,
        agent: Visitor | Inspector,
        **kwargs: Any,
    ) -> None:
        """Run when an Alias has been created.

        Parameters:
            node: The currently visited node.
            alias: The alias instance.
            agent: The analysis agent currently running.
            **kwargs: For forward-compatibility.
        """

    def on_package_loaded(self, *, pkg: Module, loader: GriffeLoader, **kwargs: Any) -> None:
        """Run when a package has been completely loaded.

        Parameters:
            pkg: The package (Module) instance.
            loader: The loader currently in use.
            **kwargs: For forward-compatibility.
        """

    def on_wildcard_expansion(
        self,
        *,
        alias: Alias,
        loader: GriffeLoader,
        **kwargs: Any,
    ) -> None:
        """Run when wildcard imports are expanded into aliases.

        Parameters:
            alias: The alias instance.
            loader: The loader currently in use.
            **kwargs: For forward-compatibility.
        """
LoadableExtensionType = Union[
    str,  # import path, path to a file, or name of a built-in extension
    dict[str, Any],  # {import path: options}
    Extension,  # ready-made instance, used as is
    type[Extension],  # class, instantiated without options
]
"""All the types that can be passed to `load_extensions`."""
class Extensions:
    """Ordered collection of extensions that dispatches hook calls to each of them."""

    def __init__(self, *extensions: Extension) -> None:
        """Create the container, optionally pre-filled.

        Parameters:
            *extensions: The extensions to add.
        """
        self._extensions: list[Extension] = []
        self.add(*extensions)

    def add(self, *extensions: Extension) -> None:
        """Append extensions to this container, preserving the given order.

        Parameters:
            *extensions: The extensions to add.
        """
        self._extensions.extend(extensions)

    def call(self, event: str, **kwargs: Any) -> None:
        """Trigger a hook on every registered extension, in registration order.

        Parameters:
            event: Name of the hook method to call (for example `"on_package_loaded"`).
            **kwargs: Keyword arguments forwarded unchanged to each hook.
        """
        hooks = (getattr(registered, event) for registered in self._extensions)
        for hook in hooks:
            hook(**kwargs)
# Short names accepted by `_load_extension` and expanded to `_griffe.extensions.<name>`.
builtin_extensions: set[str] = {"dataclasses"}
"""The names of built-in Griffe extensions."""
def _load_extension_path(path: str) -> ModuleType:
    """Import a Python module directly from a file path.

    The module is named after the file's base name (without its extension) and
    registered in `sys.modules` under that name before its body is executed.

    Parameters:
        path: Filesystem path to the extension module.

    Raises:
        ExtensionNotLoadedError: When no import spec or loader can be built for the path.

    Any exception raised while executing the module body propagates unchanged,
    after the `sys.modules` entry has been rolled back.

    Returns:
        The executed module.
    """
    module_name = os.path.basename(path).rsplit(".", 1)[0]  # noqa: PTH119
    spec = spec_from_file_location(module_name, path)
    # A spec without a loader cannot execute anything: treat it like a missing spec.
    if spec is None or spec.loader is None:
        raise ExtensionNotLoadedError(f"Could not import module from path '{path}'")
    module = module_from_spec(spec)
    # NOTE(review): the bare file name is used as module name, so this replaces any
    # existing `sys.modules` entry of that name (e.g. a file called `json.py`).
    previous = sys.modules.get(module_name)
    # Register before executing so code running in the module body can find it.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialized module behind; restore what was there.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module
def _load_extension(
    extension: str | dict[str, Any] | Extension | type[Extension],
) -> Extension | list[Extension]:
    """Load a configured extension.

    Parameters:
        extension: An extension, with potential configuration options. Accepted forms:
            an extension instance (returned as is); an extension class (instantiated
            without options); a string (built-in extension name, path to an existing
            file, or Python dotted path, optionally suffixed with `:ClassName`); or a
            dictionary whose first key is such a string and whose value is a
            dictionary of options (`None` is accepted and means "no options").

    Raises:
        ExtensionNotLoadedError: When the extension cannot be loaded: the dictionary
            configuration is empty, the module is not found or raises an
            `ImportError` while being imported (the original error is chained as
            `__cause__`, so its traceback stays visible), or the module does not
            expose the requested attribute. Other exceptions raised while executing
            the extension module propagate unchanged.

    Returns:
        An extension instance, or a list of instances (one per extension class
        found in the module) when a module is given without a class name.
    """
    ext_object: Any = None

    # If it's already an extension instance, return it.
    if isinstance(extension, Extension):
        return extension

    # If it's an extension class, instantiate it (without options) and return it.
    if isclass(extension) and issubclass(extension, Extension):
        return extension()

    # If it's a dictionary, we expect the only key to be an import path
    # and the value to be a dictionary of options.
    if isinstance(extension, dict):
        if not extension:
            raise ExtensionNotLoadedError("Empty extension configuration: expected an import path as key")
        # NOTE(review): only the first item is used; any additional keys are ignored.
        import_path, options = next(iter(extension.items()))
        # Force path to be a string, as it could have been passed from `mkdocs.yml`,
        # using the custom YAML tag `!relative`, which gives an instance of MkDocs
        # path placeholder classes, which are not iterable.
        import_path = str(import_path)
        # A YAML mapping entry without a value (`- my_extension:`) loads as `None`.
        if options is None:
            options = {}

    # Otherwise we consider it's an import path, without options.
    else:
        import_path = str(extension)
        options = {}

    # If the import path contains a colon, we split into path and class name.
    colons = import_path.count(":")
    # Special case for The Annoying Operating System: a single colon may be part
    # of a Windows drive (`C:`) rather than a class-name separator.
    if colons > 1 or (colons and ":" not in Path(import_path).drive):
        import_path, extension_name = import_path.rsplit(":", 1)
    else:
        extension_name = None

    # If the import path corresponds to a built-in extension, expand it.
    if import_path in builtin_extensions:
        import_path = f"_griffe.extensions.{import_path}"
    # If the import path is a path to an existing file, load it.
    elif os.path.exists(import_path):  # noqa: PTH110
        try:
            ext_object = _load_extension_path(import_path)
        except ImportError as error:
            # The file exists, so the failure comes from importing its contents.
            raise ExtensionNotLoadedError(f"Error while importing extension '{import_path}': {error}") from error

    # If the extension wasn't loaded yet, we consider the import path
    # to be a Python dotted path like `package.module` or `package.module.Extension`.
    if ext_object is None:
        try:
            ext_object = dynamic_import(import_path)
        except ModuleNotFoundError as error:
            raise ExtensionNotLoadedError(f"Extension module '{import_path}' could not be found") from error
        except ImportError as error:
            raise ExtensionNotLoadedError(f"Error while importing extension '{import_path}': {error}") from error

    # If the loaded object is an extension class, instantiate it with options and return it.
    if isclass(ext_object) and issubclass(ext_object, Extension):
        return ext_object(**options)

    # Otherwise the loaded object is a module, so we get the extension class by name,
    # instantiate it with options and return it.
    if extension_name:
        try:
            ext_class = getattr(ext_object, extension_name)
        except AttributeError as error:
            raise ExtensionNotLoadedError(
                f"Extension module '{import_path}' has no '{extension_name}' attribute",
            ) from error
        # Instantiate outside the `try` so an AttributeError raised by the
        # extension's own constructor is not misreported as a missing attribute.
        return ext_class(**options)

    # No class name was specified so we search all extension classes in the module,
    # instantiate each with the same options, and return them.
    extensions = [
        obj for obj in vars(ext_object).values() if isclass(obj) and issubclass(obj, Extension) and obj is not Extension
    ]
    return [ext(**options) for ext in extensions]
def load_extensions(*exts: LoadableExtensionType) -> Extensions:
    """Load configured extensions.

    The built-in dataclasses extension is always appended after the configured
    extensions, unless one of them already is a `DataclassesExtension`
    (or an instance of a subclass of it).

    Parameters:
        exts: Extensions with potential configuration options.

    Returns:
        An extensions container.
    """
    extensions = Extensions()

    for extension in exts:
        loaded = _load_extension(extension)
        if isinstance(loaded, list):
            extensions.add(*loaded)
        else:
            extensions.add(loaded)

    # TODO: Deprecate and remove at some point?
    # Always add our built-in dataclasses extension.
    from _griffe.extensions.dataclasses import DataclassesExtension

    # `isinstance` rather than an exact type check, so that a user-provided
    # subclass is not followed by a second, base dataclasses extension.
    if not any(isinstance(ext, DataclassesExtension) for ext in extensions._extensions):
        extensions.add(*_load_extension("dataclasses"))  # type: ignore[misc]

    return extensions