Coverage for src/_griffe/agents/visitor.py: 98.34%

261 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-15 16:47 +0200

1# This module contains our static analysis agent, 

2# capable of parsing and visiting sources, statically. 

3 

4from __future__ import annotations 

5 

6import ast 

7from contextlib import suppress 

8from typing import TYPE_CHECKING, Any 

9 

10from _griffe.agents.nodes.assignments import get_instance_names, get_names 

11from _griffe.agents.nodes.ast import ( 

12 ast_children, 

13 ast_kind, 

14 ast_next, 

15) 

16from _griffe.agents.nodes.docstrings import get_docstring 

17from _griffe.agents.nodes.exports import safe_get__all__ 

18from _griffe.agents.nodes.imports import relative_to_absolute 

19from _griffe.agents.nodes.parameters import get_parameters 

20from _griffe.collections import LinesCollection, ModulesCollection 

21from _griffe.enumerations import Kind 

22from _griffe.exceptions import AliasResolutionError, CyclicAliasError, LastNodeError 

23from _griffe.expressions import ( 

24 Expr, 

25 ExprName, 

26 safe_get_annotation, 

27 safe_get_base_class, 

28 safe_get_condition, 

29 safe_get_expression, 

30) 

31from _griffe.extensions.base import Extensions, load_extensions 

32from _griffe.models import Alias, Attribute, Class, Decorator, Docstring, Function, Module, Parameter, Parameters 

33 

34if TYPE_CHECKING: 

35 from pathlib import Path 

36 

37 from _griffe.enumerations import Parser 

38 

39 

# Builtin decorators are labelled with their own name.
builtin_decorators = {name: name for name in ("property", "staticmethod", "classmethod")}
"""Mapping of builtin decorators to labels."""

# Each value is a distinct set: a decorator can contribute several labels.
stdlib_decorators = dict(
    [
        ("abc.abstractmethod", {"abstractmethod"}),
        ("functools.cache", {"cached"}),
        ("functools.cached_property", {"cached", "property"}),
        ("cached_property.cached_property", {"cached", "property"}),
        ("functools.lru_cache", {"cached"}),
        ("dataclasses.dataclass", {"dataclass"}),
    ],
)
"""Mapping of standard library decorators to labels."""

typing_overload = {
    "typing.overload",
    "typing_extensions.overload",
}
"""Set of recognized typing overload decorators.

When such a decorator is found, the decorated function becomes an overload.
"""

62 

63 

def visit(
    module_name: str,
    filepath: Path,
    code: str,
    *,
    extensions: Extensions | None = None,
    parent: Module | None = None,
    docstring_parser: Parser | None = None,
    docstring_options: dict[str, Any] | None = None,
    lines_collection: LinesCollection | None = None,
    modules_collection: ModulesCollection | None = None,
) -> Module:
    """Parse and visit a module file.

    This is the entry point for static analysis. It instantiates a
    [`Visitor`][griffe.Visitor] (a [`NodeVisitor`][ast.NodeVisitor]-like class)
    with the given arguments and returns the module it builds: the visitor
    compiles the code to an AST (using [`compile`][]) and then walks that tree.

    Important:
        This function is generally not used directly.
        In most cases, users can rely on the [`GriffeLoader`][griffe.GriffeLoader]
        and its accompanying [`load`][griffe.load] shortcut and their respective options
        to load modules using static analysis.

    Parameters:
        module_name: The module name (as when importing [from] it).
        filepath: The module file path.
        code: The module contents.
        extensions: The extensions to use when visiting the AST.
            When not provided (or falsy), `load_extensions()` is used instead.
        parent: The optional parent of this module.
        docstring_parser: The docstring parser to use. By default, no parsing is done.
        docstring_options: Additional docstring parsing options.
        lines_collection: A collection of source code lines.
        modules_collection: A collection of modules.

    Returns:
        The module, with its members populated.
    """
    visitor = Visitor(
        module_name,
        filepath,
        code,
        extensions or load_extensions(),
        parent,
        docstring_parser=docstring_parser,
        docstring_options=docstring_options,
        lines_collection=lines_collection,
        modules_collection=modules_collection,
    )
    return visitor.get_module()

113 

114 

class Visitor:
    """This class is used to instantiate a visitor.

    Visitors iterate on AST nodes to extract data from them.
    """

    def __init__(
        self,
        module_name: str,
        filepath: Path,
        code: str,
        extensions: Extensions,
        parent: Module | None = None,
        docstring_parser: Parser | None = None,
        docstring_options: dict[str, Any] | None = None,
        lines_collection: LinesCollection | None = None,
        modules_collection: ModulesCollection | None = None,
    ) -> None:
        """Initialize the visitor.

        Parameters:
            module_name: The module name.
            filepath: The module filepath.
            code: The module source code.
            extensions: The extensions to use when visiting.
            parent: An optional parent for the final module object.
            docstring_parser: The docstring parser to use.
            docstring_options: The docstring parsing options.
            lines_collection: A collection of source code lines.
            modules_collection: A collection of modules.
        """
        super().__init__()

        self.module_name: str = module_name
        """The module name."""

        self.filepath: Path = filepath
        """The module filepath."""

        self.code: str = code
        """The module source code."""

        self.extensions: Extensions = extensions
        """The extensions to use when visiting the AST."""

        self.parent: Module | None = parent
        """An optional parent for the final module object."""

        # Set for real by `visit_module`, before any other node is visited.
        self.current: Module | Class = None  # type: ignore[assignment]
        """The current object being visited."""

        self.docstring_parser: Parser | None = docstring_parser
        """The docstring parser to use."""

        self.docstring_options: dict[str, Any] = docstring_options or {}
        """The docstring parsing options."""

        # Compare against None explicitly: a collection supplied by the caller must be
        # kept (it may be shared with other objects) even if it evaluates as falsy.
        self.lines_collection: LinesCollection = (
            lines_collection if lines_collection is not None else LinesCollection()
        )
        """A collection of source code lines."""

        self.modules_collection: ModulesCollection = (
            modules_collection if modules_collection is not None else ModulesCollection()
        )
        """A collection of modules."""

        self.type_guarded: bool = False
        """Whether the current code branch is type-guarded."""

    def _get_docstring(self, node: ast.AST, *, strict: bool = False) -> Docstring | None:
        """Build a docstring object for a node, using this visitor's parser settings.

        Parameters:
            node: The node to get the docstring from.
            strict: Forwarded as-is to `get_docstring`.

        Returns:
            A docstring, or `None` when `get_docstring` finds no value.
        """
        value, lineno, endlineno = get_docstring(node, strict=strict)
        if value is None:
            return None
        return Docstring(
            value,
            lineno=lineno,
            endlineno=endlineno,
            parser=self.docstring_parser,
            parser_options=self.docstring_options,
        )

    def get_module(self) -> Module:
        """Build and return the object representing the module attached to this visitor.

        This method triggers a complete visit of the module nodes.

        Returns:
            A module instance.
        """
        # optimization: equivalent to ast.parse, but with optimize=1 to remove assert statements
        # TODO: with options, could use optimize=2 to remove docstrings
        top_node = compile(self.code, mode="exec", filename=str(self.filepath), flags=ast.PyCF_ONLY_AST, optimize=1)
        self.visit(top_node)
        return self.current.module

    def visit(self, node: ast.AST) -> None:
        """Dispatch a node to its dedicated `visit_<kind>` method.

        The kind is given by `ast_kind(node)`. When this visitor has no method
        for that kind, the node is handed to `generic_visit` instead.

        Parameters:
            node: The node to visit.
        """
        getattr(self, f"visit_{ast_kind(node)}", self.generic_visit)(node)

    def generic_visit(self, node: ast.AST) -> None:
        """Visit every child of a node, as returned by `ast_children`.

        Parameters:
            node: The node to visit.
        """
        for child in ast_children(node):
            self.visit(child)

    def visit_module(self, node: ast.Module) -> None:
        """Visit a module node.

        Creates the module object, makes it the current object, visits its
        children, and calls the extension hooks before and after each step.

        Parameters:
            node: The node to visit.
        """
        self.extensions.call("on_node", node=node, agent=self)
        self.extensions.call("on_module_node", node=node, agent=self)
        self.current = module = Module(
            name=self.module_name,
            filepath=self.filepath,
            parent=self.parent,
            docstring=self._get_docstring(node),
            lines_collection=self.lines_collection,
            modules_collection=self.modules_collection,
        )
        self.extensions.call("on_instance", node=node, obj=module, agent=self)
        self.extensions.call("on_module_instance", node=node, mod=module, agent=self)
        self.generic_visit(node)
        self.extensions.call("on_members", node=node, obj=module, agent=self)
        self.extensions.call("on_module_members", node=node, mod=module, agent=self)

    def visit_classdef(self, node: ast.ClassDef) -> None:
        """Visit a class definition node.

        Creates the class object, registers it as a member of the current object,
        visits the class body with the class as current object, then restores
        the previous current object.

        Parameters:
            node: The node to visit.
        """
        self.extensions.call("on_node", node=node, agent=self)
        self.extensions.call("on_class_node", node=node, agent=self)

        # handle decorators
        decorators = []
        if node.decorator_list:
            # a decorated class starts at its first decorator
            lineno = node.decorator_list[0].lineno
            for decorator_node in node.decorator_list:
                decorator_value = safe_get_expression(decorator_node, parent=self.current, parse_strings=False)
                if decorator_value is None:
                    # same policy as in `handle_function`: skip decorators without an expression
                    continue
                decorators.append(
                    Decorator(
                        decorator_value,
                        lineno=decorator_node.lineno,
                        endlineno=decorator_node.end_lineno,
                    ),
                )
        else:
            lineno = node.lineno

        # handle base classes
        bases = [safe_get_base_class(base, parent=self.current) for base in node.bases]

        class_ = Class(
            name=node.name,
            lineno=lineno,
            endlineno=node.end_lineno,
            docstring=self._get_docstring(node),
            decorators=decorators,
            bases=bases,  # type: ignore[arg-type]
            runtime=not self.type_guarded,
        )
        class_.labels |= self.decorators_to_labels(decorators)
        self.current.set_member(node.name, class_)
        self.current = class_
        self.extensions.call("on_instance", node=node, obj=class_, agent=self)
        self.extensions.call("on_class_instance", node=node, cls=class_, agent=self)
        self.generic_visit(node)
        self.extensions.call("on_members", node=node, obj=class_, agent=self)
        self.extensions.call("on_class_members", node=node, cls=class_, agent=self)
        self.current = self.current.parent  # type: ignore[assignment]

    def decorators_to_labels(self, decorators: list[Decorator]) -> set[str]:
        """Build and return a set of labels based on decorators.

        Labels come from the `builtin_decorators` and `stdlib_decorators` mappings,
        looked up with each decorator's callable path. Unknown decorators add nothing.

        Parameters:
            decorators: The decorators to check.

        Returns:
            A set of labels.
        """
        labels = set()
        for decorator in decorators:
            callable_path = decorator.callable_path
            if callable_path in builtin_decorators:
                labels.add(builtin_decorators[callable_path])
            elif callable_path in stdlib_decorators:
                labels |= stdlib_decorators[callable_path]
        return labels

    def get_base_property(self, decorators: list[Decorator], function: Function) -> str | None:
        """Check decorators to tell whether a function is a property setter or deleter.

        A decorator matches when its callable path is `<function path>.setter`
        or `<function path>.deleter`, and the current object's member with the
        function's name has the `property` label.

        Parameters:
            decorators: The decorators to check.
            function: The decorated function.

        Returns:
            Either `"setter"` or `"deleter"` for the first matching decorator,
            or `None` when no decorator matches.
        """
        for decorator in decorators:
            try:
                path, prop_function = decorator.callable_path.rsplit(".", 1)
            except ValueError:
                # no dot in the path: cannot be `<property>.setter` / `<property>.deleter`
                continue
            property_setter_or_deleter = (
                prop_function in {"setter", "deleter"}
                and path == function.path
                and self.current.get_member(function.name).has_labels("property")
            )
            if property_setter_or_deleter:
                return prop_function
        return None

    def handle_function(self, node: ast.AsyncFunctionDef | ast.FunctionDef, labels: set | None = None) -> None:
        """Handle a function definition node.

        Depending on its decorators, the definition becomes an attribute (property),
        an overload, the setter/deleter of an existing property, or a regular
        function member of the current object.

        Parameters:
            node: The node to visit.
            labels: Labels to add to the data object.
        """
        self.extensions.call("on_node", node=node, agent=self)
        self.extensions.call("on_function_node", node=node, agent=self)

        # copy: the in-place unions below must never modify the caller's set
        labels = set(labels) if labels else set()

        # handle decorators
        decorators = []
        overload = False
        if node.decorator_list:
            # a decorated function starts at its first decorator
            lineno = node.decorator_list[0].lineno
            for decorator_node in node.decorator_list:
                decorator_value = safe_get_expression(decorator_node, parent=self.current, parse_strings=False)
                if decorator_value is None:
                    continue
                decorator = Decorator(
                    decorator_value,
                    lineno=decorator_node.lineno,
                    endlineno=decorator_node.end_lineno,
                )
                decorators.append(decorator)
                overload |= decorator.callable_path in typing_overload
        else:
            lineno = node.lineno

        labels |= self.decorators_to_labels(decorators)

        if "property" in labels:
            # properties are stored as attributes, annotated with the return annotation
            attribute = Attribute(
                name=node.name,
                value=None,
                annotation=safe_get_annotation(node.returns, parent=self.current),
                lineno=node.lineno,
                endlineno=node.end_lineno,
                docstring=self._get_docstring(node),
                runtime=not self.type_guarded,
            )
            attribute.labels |= labels
            self.current.set_member(node.name, attribute)
            self.extensions.call("on_instance", node=node, obj=attribute, agent=self)
            self.extensions.call("on_attribute_instance", node=node, attr=attribute, agent=self)
            return

        # handle parameters
        parameters = Parameters(
            *[
                Parameter(
                    name,
                    kind=kind,
                    annotation=safe_get_annotation(annotation, parent=self.current),
                    default=default
                    if isinstance(default, str)
                    else safe_get_expression(default, parent=self.current, parse_strings=False),
                )
                for name, annotation, kind, default in get_parameters(node.args)
            ],
        )

        function = Function(
            name=node.name,
            lineno=lineno,
            endlineno=node.end_lineno,
            parameters=parameters,
            returns=safe_get_annotation(node.returns, parent=self.current),
            decorators=decorators,
            docstring=self._get_docstring(node),
            runtime=not self.type_guarded,
            parent=self.current,
        )

        property_function = self.get_base_property(decorators, function)

        if overload:
            # overloads are collected until the actual implementation is found
            self.current.overloads[function.name].append(function)
        elif property_function:
            base_property: Attribute = self.current.members[node.name]  # type: ignore[assignment]
            if property_function == "setter":
                base_property.setter = function
                base_property.labels.add("writable")
            elif property_function == "deleter":
                base_property.deleter = function
                base_property.labels.add("deletable")
        else:
            self.current.set_member(node.name, function)
            if self.current.kind in {Kind.MODULE, Kind.CLASS} and self.current.overloads[function.name]:
                # attach the collected overloads to the implementation
                function.overloads = self.current.overloads[function.name]
                del self.current.overloads[function.name]

        function.labels |= labels

        self.extensions.call("on_instance", node=node, obj=function, agent=self)
        self.extensions.call("on_function_instance", node=node, func=function, agent=self)
        if self.current.kind is Kind.CLASS and function.name == "__init__":
            # visit the body of `__init__` to collect instance attributes
            self.current = function  # type: ignore[assignment]  # temporary assign a function
            self.generic_visit(node)
            self.current = self.current.parent  # type: ignore[assignment]

    def visit_functiondef(self, node: ast.FunctionDef) -> None:
        """Visit a function definition node.

        Parameters:
            node: The node to visit.
        """
        self.handle_function(node)

    def visit_asyncfunctiondef(self, node: ast.AsyncFunctionDef) -> None:
        """Visit an async function definition node.

        The resulting object gets the `async` label.

        Parameters:
            node: The node to visit.
        """
        self.handle_function(node, labels={"async"})

    def visit_import(self, node: ast.Import) -> None:
        """Visit an import node.

        Each imported name is recorded in the current object's imports
        and added to it as an alias member.

        Parameters:
            node: The node to visit.
        """
        for name in node.names:
            # `import a.b.c` binds `a`, whereas `import a.b.c as d` binds `d` to `a.b.c`
            alias_path = name.name if name.asname else name.name.split(".", 1)[0]
            alias_name = name.asname or alias_path.split(".", 1)[0]
            self.current.imports[alias_name] = alias_path
            self.current.set_member(
                alias_name,
                Alias(
                    alias_name,
                    alias_path,
                    lineno=node.lineno,
                    endlineno=node.end_lineno,
                    runtime=not self.type_guarded,
                ),
            )

    def visit_importfrom(self, node: ast.ImportFrom) -> None:
        """Visit an "import from" node.

        Each imported name is recorded in the current object's imports
        and, unless it would point to itself, added to it as an alias member.

        Parameters:
            node: The node to visit.
        """
        for name in node.names:
            if not node.module and node.level == 1 and not name.asname and self.current.module.is_init_module:
                # special case: when being in `a/__init__.py` and doing `from . import b`,
                # we are effectively creating a member `b` in `a` that is pointing to `a.b`
                # -> cyclic alias! in that case, we just skip it, as both the member and module
                # have the same name and can be accessed the same way
                continue

            alias_path = relative_to_absolute(node, name, self.current.module)
            if name.name == "*":
                # wildcard imports are stored under a name derived from the path
                alias_name = alias_path.replace(".", "/")
                alias_path = alias_path.replace(".*", "")
            else:
                alias_name = name.asname or name.name
                self.current.imports[alias_name] = alias_path
            # Do not create aliases pointing to themselves (it happens with
            # `from package.current_module import Thing as Thing` or
            # `from . import thing as thing`).
            if alias_path != f"{self.current.path}.{alias_name}":
                self.current.set_member(
                    alias_name,
                    Alias(
                        alias_name,
                        alias_path,
                        lineno=node.lineno,
                        endlineno=node.end_lineno,
                        runtime=not self.type_guarded,
                    ),
                )

    def handle_attribute(
        self,
        node: ast.Assign | ast.AnnAssign,
        annotation: str | Expr | None = None,
    ) -> None:
        """Handle an attribute (assignment) node.

        Attributes are created in the current module or class, or in the parent
        class when the assignment is found in the body of an `__init__` method.

        Parameters:
            node: The node to visit.
            annotation: A potential annotation.
        """
        self.extensions.call("on_node", node=node, agent=self)
        self.extensions.call("on_attribute_node", node=node, agent=self)
        parent = self.current
        labels = set()

        if parent.kind is Kind.MODULE:
            try:
                names = get_names(node)
            except KeyError:  # unsupported nodes, like subscript
                return
            labels.add("module-attribute")
        elif parent.kind is Kind.CLASS:
            try:
                names = get_names(node)
            except KeyError:  # unsupported nodes, like subscript
                return

            if isinstance(annotation, Expr) and annotation.is_classvar:
                # explicit classvar: class attribute only
                annotation = annotation.slice  # type: ignore[attr-defined]
                labels.add("class-attribute")
            elif node.value:
                # attribute assigned at class-level: available in instances as well
                labels.add("class-attribute")
                labels.add("instance-attribute")
            else:
                # annotated attribute only: not available at class-level
                labels.add("instance-attribute")

        elif parent.kind is Kind.FUNCTION:
            if parent.name != "__init__":
                return
            try:
                names = get_instance_names(node)
            except KeyError:  # unsupported nodes, like subscript
                return
            # instance attributes belong to the class, not to `__init__`
            parent = parent.parent  # type: ignore[assignment]
            labels.add("instance-attribute")
        else:
            # no other kind of parent can hold attributes: nothing to do
            return

        if not names:
            return

        value = safe_get_expression(node.value, parent=self.current, parse_strings=False)

        try:
            docstring = self._get_docstring(ast_next(node), strict=True)
        except (LastNodeError, AttributeError):
            docstring = None

        for name in names:
            # TODO: handle assigns like x.y = z
            # we need to resolve x.y and add z in its member
            if "." in name:
                continue

            # Work on per-name copies, so that data forwarded from an existing member
            # does not leak into the other targets of the same assignment.
            attr_labels = set(labels)
            attr_docstring = docstring
            attr_annotation = annotation

            if name in parent.members:
                # assigning multiple times
                # TODO: might be better to inspect
                if isinstance(node.parent, (ast.If, ast.ExceptHandler)):  # type: ignore[union-attr]
                    continue  # prefer "no-exception" case

                existing_member = parent.members[name]
                with suppress(AliasResolutionError, CyclicAliasError):
                    attr_labels |= existing_member.labels
                    # forward previous docstring and annotation instead of erasing them
                    if existing_member.docstring and not attr_docstring:
                        attr_docstring = existing_member.docstring
                    with suppress(AttributeError):
                        if existing_member.annotation and not attr_annotation:  # type: ignore[union-attr]
                            attr_annotation = existing_member.annotation  # type: ignore[union-attr]

            attribute = Attribute(
                name=name,
                value=value,
                annotation=attr_annotation,
                lineno=node.lineno,
                endlineno=node.end_lineno,
                docstring=attr_docstring,
                runtime=not self.type_guarded,
            )
            attribute.labels |= attr_labels
            parent.set_member(name, attribute)

            if name == "__all__":
                with suppress(AttributeError):
                    parent.exports = [
                        export if isinstance(export, str) else ExprName(export.name, parent=export.parent)
                        for export in safe_get__all__(node, self.current)  # type: ignore[arg-type]
                    ]
            self.extensions.call("on_instance", node=node, obj=attribute, agent=self)
            self.extensions.call("on_attribute_instance", node=node, attr=attribute, agent=self)

    def visit_assign(self, node: ast.Assign) -> None:
        """Visit an assignment node.

        Parameters:
            node: The node to visit.
        """
        self.handle_attribute(node)

    def visit_annassign(self, node: ast.AnnAssign) -> None:
        """Visit an annotated assignment node.

        Parameters:
            node: The node to visit.
        """
        self.handle_attribute(node, safe_get_annotation(node.annotation, parent=self.current))

    def visit_augassign(self, node: ast.AugAssign) -> None:
        """Visit an augmented assignment node.

        Only `__all__ += ...` at module level is handled: the added names
        extend the exports of the current module.

        Parameters:
            node: The node to visit.
        """
        # AttributeError is raised when the target is not a simple name,
        # or when exports were not set before being augmented: both cases are ignored.
        with suppress(AttributeError):
            all_augment = (
                node.target.id == "__all__"  # type: ignore[union-attr]
                and self.current.is_module
                and isinstance(node.op, ast.Add)
            )
            if all_augment:
                # we assume exports is not None at this point
                self.current.exports.extend(  # type: ignore[union-attr]
                    [
                        export if isinstance(export, str) else ExprName(export.name, parent=export.parent)
                        for export in safe_get__all__(node, self.current)  # type: ignore[arg-type]
                    ],
                )

    def visit_if(self, node: ast.If) -> None:
        """Visit an "if" node.

        Objects found under a module-level or class-level `if TYPE_CHECKING:`
        (or `if typing.TYPE_CHECKING:`) are marked as type-guarded.

        Parameters:
            node: The node to visit.
        """
        # Remember the state we were called with: nested `if` statements found in
        # a type-guarded block must not un-guard the statements that follow them.
        previously_guarded = self.type_guarded
        if isinstance(node.parent, (ast.Module, ast.ClassDef)):  # type: ignore[attr-defined]
            condition = safe_get_condition(node.test, parent=self.current, log_level=None)
            if str(condition) in {"typing.TYPE_CHECKING", "TYPE_CHECKING"}:
                self.type_guarded = True
        # NOTE: the `else` branch of a type-checking condition is visited
        # with the same state, so it is considered type-guarded as well.
        try:
            self.generic_visit(node)
        finally:
            self.type_guarded = previously_guarded