Coverage for src/_griffe/extensions/base.py: 80.98%

113 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-15 16:47 +0200

1# This module contains the base class for extensions 

2# and the functions to load them. 

3 

4from __future__ import annotations 

5 

6import os 

7import sys 

8from importlib.util import module_from_spec, spec_from_file_location 

9from inspect import isclass 

10from pathlib import Path 

11from typing import TYPE_CHECKING, Any, Dict, Type, Union 

12 

13from _griffe.agents.nodes.ast import ast_children, ast_kind 

14from _griffe.exceptions import ExtensionNotLoadedError 

15from _griffe.importer import dynamic_import 

16 

17if TYPE_CHECKING: 

18 import ast 

19 from types import ModuleType 

20 

21 from _griffe.agents.inspector import Inspector 

22 from _griffe.agents.nodes.runtime import ObjectNode 

23 from _griffe.agents.visitor import Visitor 

24 from _griffe.models import Attribute, Class, Function, Module, Object 

25 

26 

class Extension:
    """Base class for Griffe extensions.

    Subclasses override the hooks they are interested in. Every `on_*` hook
    defined here does nothing. Hooks are called with keyword arguments only,
    and also accept extra keyword arguments (`**kwargs`), which these base
    implementations ignore.
    """

    def visit(self, node: ast.AST) -> None:
        """Visit a node.

        Dispatches to the `visit_<kind>` method matching the node's kind
        (as computed by `ast_kind`), if the extension defines one;
        otherwise does nothing.

        Parameters:
            node: The node to visit.
        """
        getattr(self, f"visit_{ast_kind(node)}", lambda _: None)(node)

    def generic_visit(self, node: ast.AST) -> None:
        """Visit children nodes.

        Calls `visit` on each child returned by `ast_children`.

        Parameters:
            node: The node to visit the children of.
        """
        for child in ast_children(node):
            self.visit(child)

    def inspect(self, node: ObjectNode) -> None:
        """Inspect a node.

        Dispatches to the `inspect_<kind>` method matching `node.kind`,
        if the extension defines one; otherwise does nothing.

        Parameters:
            node: The node to inspect.
        """
        getattr(self, f"inspect_{node.kind}", lambda _: None)(node)

    def generic_inspect(self, node: ObjectNode) -> None:
        """Inspect the children of a node.

        Calls `inspect` on each child, skipping children that have an
        `alias_target_path` (presumably aliases to objects defined elsewhere).

        Parameters:
            node: The node whose children to inspect.
        """
        for child in node.children:
            if not child.alias_target_path:
                self.inspect(child)

    def on_node(self, *, node: ast.AST | ObjectNode, agent: Visitor | Inspector, **kwargs: Any) -> None:
        """Run when visiting a new node during static/dynamic analysis.

        Parameters:
            node: The currently visited node.
            agent: The analysis agent: a `Visitor` (static) or an `Inspector` (dynamic).
            **kwargs: Additional arguments, ignored by this base implementation.
        """

    def on_instance(
        self,
        *,
        node: ast.AST | ObjectNode,
        obj: Object,
        agent: Visitor | Inspector,
        **kwargs: Any,
    ) -> None:
        """Run when an Object has been created.

        Parameters:
            node: The currently visited node.
            obj: The object instance.
            agent: The analysis agent: a `Visitor` (static) or an `Inspector` (dynamic).
            **kwargs: Additional arguments, ignored by this base implementation.
        """

    def on_members(self, *, node: ast.AST | ObjectNode, obj: Object, agent: Visitor | Inspector, **kwargs: Any) -> None:
        """Run when members of an Object have been loaded.

        Parameters:
            node: The currently visited node.
            obj: The object instance.
            agent: The analysis agent: a `Visitor` (static) or an `Inspector` (dynamic).
            **kwargs: Additional arguments, ignored by this base implementation.
        """

    def on_module_node(self, *, node: ast.AST | ObjectNode, agent: Visitor | Inspector, **kwargs: Any) -> None:
        """Run when visiting a new module node during static/dynamic analysis.

        Parameters:
            node: The currently visited node.
            agent: The analysis agent: a `Visitor` (static) or an `Inspector` (dynamic).
            **kwargs: Additional arguments, ignored by this base implementation.
        """

    def on_module_instance(
        self,
        *,
        node: ast.AST | ObjectNode,
        mod: Module,
        agent: Visitor | Inspector,
        **kwargs: Any,
    ) -> None:
        """Run when a Module has been created.

        Parameters:
            node: The currently visited node.
            mod: The module instance.
            agent: The analysis agent: a `Visitor` (static) or an `Inspector` (dynamic).
            **kwargs: Additional arguments, ignored by this base implementation.
        """

    def on_module_members(
        self,
        *,
        node: ast.AST | ObjectNode,
        mod: Module,
        agent: Visitor | Inspector,
        **kwargs: Any,
    ) -> None:
        """Run when members of a Module have been loaded.

        Parameters:
            node: The currently visited node.
            mod: The module instance.
            agent: The analysis agent: a `Visitor` (static) or an `Inspector` (dynamic).
            **kwargs: Additional arguments, ignored by this base implementation.
        """

    def on_class_node(self, *, node: ast.AST | ObjectNode, agent: Visitor | Inspector, **kwargs: Any) -> None:
        """Run when visiting a new class node during static/dynamic analysis.

        Parameters:
            node: The currently visited node.
            agent: The analysis agent: a `Visitor` (static) or an `Inspector` (dynamic).
            **kwargs: Additional arguments, ignored by this base implementation.
        """

    def on_class_instance(
        self,
        *,
        node: ast.AST | ObjectNode,
        cls: Class,
        agent: Visitor | Inspector,
        **kwargs: Any,
    ) -> None:
        """Run when a Class has been created.

        Parameters:
            node: The currently visited node.
            cls: The class instance.
            agent: The analysis agent: a `Visitor` (static) or an `Inspector` (dynamic).
            **kwargs: Additional arguments, ignored by this base implementation.
        """

    def on_class_members(
        self,
        *,
        node: ast.AST | ObjectNode,
        cls: Class,
        agent: Visitor | Inspector,
        **kwargs: Any,
    ) -> None:
        """Run when members of a Class have been loaded.

        Parameters:
            node: The currently visited node.
            cls: The class instance.
            agent: The analysis agent: a `Visitor` (static) or an `Inspector` (dynamic).
            **kwargs: Additional arguments, ignored by this base implementation.
        """

    def on_function_node(self, *, node: ast.AST | ObjectNode, agent: Visitor | Inspector, **kwargs: Any) -> None:
        """Run when visiting a new function node during static/dynamic analysis.

        Parameters:
            node: The currently visited node.
            agent: The analysis agent: a `Visitor` (static) or an `Inspector` (dynamic).
            **kwargs: Additional arguments, ignored by this base implementation.
        """

    def on_function_instance(
        self,
        *,
        node: ast.AST | ObjectNode,
        func: Function,
        agent: Visitor | Inspector,
        **kwargs: Any,
    ) -> None:
        """Run when a Function has been created.

        Parameters:
            node: The currently visited node.
            func: The function instance.
            agent: The analysis agent: a `Visitor` (static) or an `Inspector` (dynamic).
            **kwargs: Additional arguments, ignored by this base implementation.
        """

    def on_attribute_node(self, *, node: ast.AST | ObjectNode, agent: Visitor | Inspector, **kwargs: Any) -> None:
        """Run when visiting a new attribute node during static/dynamic analysis.

        Parameters:
            node: The currently visited node.
            agent: The analysis agent: a `Visitor` (static) or an `Inspector` (dynamic).
            **kwargs: Additional arguments, ignored by this base implementation.
        """

    def on_attribute_instance(
        self,
        *,
        node: ast.AST | ObjectNode,
        attr: Attribute,
        agent: Visitor | Inspector,
        **kwargs: Any,
    ) -> None:
        """Run when an Attribute has been created.

        Parameters:
            node: The currently visited node.
            attr: The attribute instance.
            agent: The analysis agent: a `Visitor` (static) or an `Inspector` (dynamic).
            **kwargs: Additional arguments, ignored by this base implementation.
        """

    def on_package_loaded(self, *, pkg: Module, **kwargs: Any) -> None:
        """Run when a package has been completely loaded.

        Parameters:
            pkg: The package (Module) instance.
            **kwargs: Additional arguments, ignored by this base implementation.
        """

219 

220 

LoadableExtensionType = Union[
    str,
    Dict[str, Any],
    Extension,
    Type[Extension],
]
"""Any value accepted by `load_extensions`: an import or file path, a `{path: options}` mapping, an extension instance, or an extension class."""

223 

224 

class Extensions:
    """Ordered collection of extensions that dispatches hook calls to each of them."""

    def __init__(self, *extensions: Extension) -> None:
        """Create the container, optionally pre-filled.

        Parameters:
            *extensions: Extensions to register right away, in order.
        """
        self._extensions: list[Extension] = []
        self.add(*extensions)

    def add(self, *extensions: Extension) -> None:
        """Register more extensions after the ones already registered.

        Parameters:
            *extensions: Extensions to append, in order.
        """
        self._extensions.extend(extensions)

    def call(self, event: str, **kwargs: Any) -> None:
        """Invoke the hook named `event` on every registered extension, in registration order.

        Parameters:
            event: Name of the hook method to call (for example `on_class_instance`).
            **kwargs: Keyword arguments forwarded unchanged to each hook.
        """
        for ext in self._extensions:
            hook = getattr(ext, event)
            hook(**kwargs)

255 

256 

builtin_extensions: set[str] = {"dataclasses"}
"""Names of the extensions shipped with Griffe, which can be loaded by their short name."""

261 

262 

def _load_extension_path(path: str) -> ModuleType:
    """Import a Python module directly from a file path.

    The module is named after the file's base name without its extension, and
    is registered in `sys.modules` under that name before being executed.

    Parameters:
        path: Path to the Python file to import.

    Raises:
        ExtensionNotLoadedError: When no import spec (or no loader) can be
            created for the given path.
        Exception: Any exception raised while executing the module, including
            `ImportError`, is propagated. In that case `sys.modules` is restored
            to its previous state for that name.

    Returns:
        The imported module.
    """
    module_name = os.path.basename(path).rsplit(".", 1)[0]
    spec = spec_from_file_location(module_name, path)
    # A spec without a loader cannot execute the module; report it like a missing spec
    # instead of failing later with an AttributeError on `None.exec_module`.
    if spec is None or spec.loader is None:
        raise ExtensionNotLoadedError(f"Could not import module from path '{path}'")
    module = module_from_spec(spec)
    # Register before executing, as in the importlib "import a source file directly"
    # recipe, so code in the module that looks itself up in `sys.modules` works.
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialized module behind, and do not clobber a module
        # that was registered under the same name before this call.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module

272 

273 

def _load_extension(
    extension: str | dict[str, Any] | Extension | type[Extension],
) -> Extension | list[Extension]:
    """Load a configured extension.

    Parameters:
        extension: An extension, with potential configuration options. Accepted forms:
            an extension instance, an extension class, an import path or file path
            (optionally suffixed with `:ClassName`), or a single-key dictionary
            mapping such a path to a dictionary of options.

    Raises:
        ExtensionNotLoadedError: When the extension cannot be loaded. This happens
            when the module is not found, when importing it fails, when it does
            not expose the requested attribute, or when a dictionary specification
            is empty. Import errors are chained (`raise ... from error`), so the
            original traceback stays visible.

    Returns:
        An extension instance. When a module is given without a class name, a list
        with one instance of each extension class found in that module is returned
        instead.
    """
    # If it's already an extension instance, return it.
    if isinstance(extension, Extension):
        return extension

    # If it's an extension class, instantiate it (without options) and return it.
    if isclass(extension) and issubclass(extension, Extension):
        return extension()

    # If it's a dictionary, we expect the only key to be an import path
    # and the value to be a dictionary of options.
    if isinstance(extension, dict):
        if not extension:
            raise ExtensionNotLoadedError("Empty extension specification: expected a mapping of import path to options")
        import_path, options = next(iter(extension.items()))
        # Force path to be a string, as it could have been passed from `mkdocs.yml`,
        # using the custom YAML tag `!relative`, which gives an instance of MkDocs
        # path placeholder classes, which are not iterable.
        import_path = str(import_path)
        # Options may be missing (`None`), e.g. a key written without a value in YAML.
        options = options or {}

    # Otherwise we consider it's an import path, without options.
    else:
        import_path = str(extension)
        options = {}

    # If the import path contains a colon, we split into path and class name.
    # A single colon that belongs to a Windows drive (`C:\...`) is not a separator.
    colons = import_path.count(":")
    if colons > 1 or (colons and ":" not in Path(import_path).drive):
        import_path, extension_name = import_path.rsplit(":", 1)
    else:
        extension_name = None

    ext_object: Any = None

    # If the import path corresponds to a built-in extension, expand it.
    if import_path in builtin_extensions:
        import_path = f"_griffe.extensions.{import_path}"
    # If the import path is a path to an existing file, load it.
    elif os.path.exists(import_path):
        try:
            ext_object = _load_extension_path(import_path)
        except ImportError as error:
            # The file exists, so the failure comes from importing it, not from finding it.
            raise ExtensionNotLoadedError(f"Error while importing extension '{import_path}': {error}") from error

    # If the extension wasn't loaded yet, we consider the import path
    # to be a Python dotted path like `package.module` or `package.module.Extension`.
    if ext_object is None:
        try:
            ext_object = dynamic_import(import_path)
        except ModuleNotFoundError as error:
            raise ExtensionNotLoadedError(f"Extension module '{import_path}' could not be found") from error
        except ImportError as error:
            raise ExtensionNotLoadedError(f"Error while importing extension '{import_path}': {error}") from error

    # If the loaded object is an extension class, instantiate it with options and return it.
    if isclass(ext_object) and issubclass(ext_object, Extension):
        return ext_object(**options)

    # Otherwise the loaded object is a module, so we get the extension class by name,
    # instantiate it with options and return it.
    if extension_name:
        try:
            ext_class = getattr(ext_object, extension_name)
        except AttributeError as error:
            raise ExtensionNotLoadedError(
                f"Extension module '{import_path}' has no '{extension_name}' attribute",
            ) from error
        # Instantiate outside the `try`: an AttributeError raised by the extension's own
        # constructor must propagate as-is, not be reported as a missing attribute.
        return ext_class(**options)

    # No class name was specified so we search all extension classes in the module,
    # instantiate each with the same options, and return them.
    ext_classes = [
        obj for obj in vars(ext_object).values() if isclass(obj) and issubclass(obj, Extension) and obj is not Extension
    ]
    return [ext_class(**options) for ext_class in ext_classes]

364 

365 

def load_extensions(*exts: LoadableExtensionType) -> Extensions:
    """Build an extensions container from configured extensions.

    Each argument is resolved with `_load_extension`. The built-in dataclasses
    extension is then appended, unless an instance of exactly that class was
    already loaded.

    Parameters:
        exts: Extensions with potential configuration options.

    Returns:
        An extensions container.
    """
    container = Extensions()

    for spec in exts:
        loaded = _load_extension(spec)
        # `_load_extension` yields either one extension or a list of them.
        container.add(*(loaded if isinstance(loaded, list) else [loaded]))

    # TODO: Deprecate and remove at some point?
    # Always add our built-in dataclasses extension.
    from _griffe.extensions.dataclasses import DataclassesExtension

    has_dataclasses = any(type(ext) is DataclassesExtension for ext in container._extensions)
    if not has_dataclasses:
        container.add(*_load_extension("dataclasses"))  # type: ignore[misc]

    return container