Coverage for src/_griffe/agents/visitor.py: 98.25%

261 statements  

« prev     ^ index     » next       coverage.py v7.6.2, created at 2024-10-12 01:34 +0200

1# This module contains our static analysis agent, 

2# capable of parsing and visiting sources, statically. 

3 

4from __future__ import annotations 

5 

6import ast 

7from contextlib import suppress 

8from typing import TYPE_CHECKING, Any 

9 

10from _griffe.agents.nodes.assignments import get_instance_names, get_names 

11from _griffe.agents.nodes.ast import ( 

12 ast_children, 

13 ast_kind, 

14 ast_next, 

15) 

16from _griffe.agents.nodes.docstrings import get_docstring 

17from _griffe.agents.nodes.exports import safe_get__all__ 

18from _griffe.agents.nodes.imports import relative_to_absolute 

19from _griffe.agents.nodes.parameters import get_parameters 

20from _griffe.collections import LinesCollection, ModulesCollection 

21from _griffe.enumerations import Kind 

22from _griffe.exceptions import AliasResolutionError, CyclicAliasError, LastNodeError 

23from _griffe.expressions import ( 

24 Expr, 

25 ExprName, 

26 safe_get_annotation, 

27 safe_get_base_class, 

28 safe_get_condition, 

29 safe_get_expression, 

30) 

31from _griffe.extensions.base import Extensions, load_extensions 

32from _griffe.models import Alias, Attribute, Class, Decorator, Docstring, Function, Module, Parameter, Parameters 

33 

34if TYPE_CHECKING: 

35 from pathlib import Path 

36 

37 from _griffe.enumerations import Parser 

38 

39 

builtin_decorators: dict[str, str] = {
    decorator_name: decorator_name for decorator_name in ("property", "staticmethod", "classmethod")
}
"""Mapping of builtin decorators to labels (each builtin maps to a label of the same name)."""

stdlib_decorators: dict[str, set[str]] = {
    "abc.abstractmethod": {"abstractmethod"},
    "functools.cache": {"cached"},
    "functools.cached_property": {"cached", "property"},
    "cached_property.cached_property": {"cached", "property"},
    "functools.lru_cache": {"cached"},
    "dataclasses.dataclass": {"dataclass"},
}
"""Mapping of standard library decorators (by dotted path) to the labels they bring."""

typing_overload: set[str] = {"typing.overload", "typing_extensions.overload"}
"""Set of recognized typing overload decorators.

When such a decorator is found, the decorated function becomes an overload.
"""

62 

63 

def visit(
    module_name: str,
    filepath: Path,
    code: str,
    *,
    extensions: Extensions | None = None,
    parent: Module | None = None,
    docstring_parser: Parser | None = None,
    docstring_options: dict[str, Any] | None = None,
    lines_collection: LinesCollection | None = None,
    modules_collection: ModulesCollection | None = None,
) -> Module:
    """Parse and visit a module file.

    This is the entry point for static analysis. It instantiates a
    [`Visitor`][griffe.Visitor] (a [`NodeVisitor`][ast.NodeVisitor]-like class)
    which parses the code with [`compile`][] and then walks the resulting
    AST (Abstract Syntax Tree) to build the module object.

    Important:
        This function is generally not used directly.
        In most cases, users can rely on the [`GriffeLoader`][griffe.GriffeLoader]
        and its accompanying [`load`][griffe.load] shortcut and their respective options
        to load modules using static analysis.

    Parameters:
        module_name: The module name (as when importing [from] it).
        filepath: The module file path.
        code: The module contents.
        extensions: The extensions to use when visiting the AST.
            When not given (or falsy), `load_extensions()` is used instead.
        parent: The optional parent of this module.
        docstring_parser: The docstring parser to use. By default, no parsing is done.
        docstring_options: Additional docstring parsing options.
        lines_collection: A collection of source code lines.
        modules_collection: A collection of modules.

    Returns:
        The module, with its members populated.
    """
    visitor = Visitor(
        module_name,
        filepath,
        code,
        extensions or load_extensions(),
        parent,
        docstring_parser=docstring_parser,
        docstring_options=docstring_options,
        lines_collection=lines_collection,
        modules_collection=modules_collection,
    )
    return visitor.get_module()

113 

114 

115class Visitor: 

116 """This class is used to instantiate a visitor. 

117 

118 Visitors iterate on AST nodes to extract data from them. 

119 """ 

120 

def __init__(
    self,
    module_name: str,
    filepath: Path,
    code: str,
    extensions: Extensions,
    parent: Module | None = None,
    docstring_parser: Parser | None = None,
    docstring_options: dict[str, Any] | None = None,
    lines_collection: LinesCollection | None = None,
    modules_collection: ModulesCollection | None = None,
) -> None:
    """Initialize the visitor.

    Parameters:
        module_name: The module name.
        filepath: The module filepath.
        code: The module source code.
        extensions: The extensions to use when visiting.
        parent: An optional parent for the final module object.
        docstring_parser: The docstring parser to use.
        docstring_options: The docstring parsing options.
            A new empty dictionary is used when `None`.
        lines_collection: A collection of source code lines.
            A new collection is created when `None`.
        modules_collection: A collection of modules.
            A new collection is created when `None`.
    """
    super().__init__()

    self.module_name: str = module_name
    """The module name."""

    self.filepath: Path = filepath
    """The module filepath."""

    self.code: str = code
    """The module source code."""

    self.extensions: Extensions = extensions
    """The extensions to use when visiting the AST."""

    self.parent: Module | None = parent
    """An optional parent for the final module object."""

    # Only set once a module node has been visited (see `visit_module`).
    self.current: Module | Class = None  # type: ignore[assignment]
    """The current object being visited."""

    self.docstring_parser: Parser | None = docstring_parser
    """The docstring parser to use."""

    # Explicit `is None` checks (rather than `x or default`): an object supplied
    # by the caller must be kept even if it is currently empty/falsy, otherwise
    # the caller and the visitor would silently stop sharing it.
    self.docstring_options: dict[str, Any] = {} if docstring_options is None else docstring_options
    """The docstring parsing options."""

    self.lines_collection: LinesCollection = LinesCollection() if lines_collection is None else lines_collection
    """A collection of source code lines."""

    self.modules_collection: ModulesCollection = (
        ModulesCollection() if modules_collection is None else modules_collection
    )
    """A collection of modules."""

    self.type_guarded: bool = False
    """Whether the current code branch is type-guarded."""

180 

def _get_docstring(self, node: ast.AST, *, strict: bool = False) -> Docstring | None:
    """Build a docstring object from the docstring found on a node, if any.

    Parameters:
        node: The node to extract the docstring from.
        strict: Forwarded unchanged to `get_docstring`.

    Returns:
        A docstring configured with this visitor's parser and parser options,
        or `None` when `get_docstring` reports no value.
    """
    text, start_line, end_line = get_docstring(node, strict=strict)
    if text is not None:
        return Docstring(
            text,
            lineno=start_line,
            endlineno=end_line,
            parser=self.docstring_parser,
            parser_options=self.docstring_options,
        )
    return None

192 

def get_module(self) -> Module:
    """Build and return the object representing the module attached to this visitor.

    Calling this method parses the source code and triggers a complete visit
    of the resulting tree.

    Returns:
        A module instance.
    """
    # Equivalent to `ast.parse`, except that `optimize=1` is passed along; the
    # original intent is to get rid of assert statements.
    # NOTE(review): confirm the optimization level has an effect when only an AST is requested.
    # TODO: with options, could use optimize=2 to remove docstrings
    tree = compile(
        self.code,
        str(self.filepath),
        "exec",
        flags=ast.PyCF_ONLY_AST,
        optimize=1,
    )
    self.visit(tree)
    return self.current.module

206 

def visit(self, node: ast.AST) -> None:
    """Dispatch a node to its dedicated visiting method.

    The method looked up is `visit_<kind>`, where `<kind>` is the value
    returned by `ast_kind(node)`. Nodes without a dedicated method are
    handed to `generic_visit`.

    Parameters:
        node: The node to visit.
    """
    handler = getattr(self, f"visit_{ast_kind(node)}", self.generic_visit)
    handler(node)

214 

def generic_visit(self, node: ast.AST) -> None:
    """Visit every child of a node, in the order given by `ast_children`.

    Parameters:
        node: The node whose children must be visited.
    """
    for child_node in ast_children(node):
        self.visit(child_node)

223 

def visit_module(self, node: ast.Module) -> None:
    """Visit a module node.

    Creates the module object, makes it the current object, then visits the
    module's children. Extension hooks are called before the object is
    created, right after it is created, and once its members were visited.

    Parameters:
        node: The node to visit.
    """
    for node_event in ("on_node", "on_module_node"):
        self.extensions.call(node_event, node=node, agent=self)

    module = Module(
        name=self.module_name,
        filepath=self.filepath,
        parent=self.parent,
        docstring=self._get_docstring(node),
        lines_collection=self.lines_collection,
        modules_collection=self.modules_collection,
    )
    self.current = module

    self.extensions.call("on_instance", node=node, obj=module, agent=self)
    self.extensions.call("on_module_instance", node=node, mod=module, agent=self)

    self.generic_visit(node)

    self.extensions.call("on_members", node=node, obj=module, agent=self)
    self.extensions.call("on_module_members", node=node, mod=module, agent=self)

245 

def visit_classdef(self, node: ast.ClassDef) -> None:
    """Visit a class definition node.

    Creates the class object, registers it as a member of the current object,
    visits the class body with the class as current object, then restores the
    previous current object.

    Parameters:
        node: The node to visit.
    """
    self.extensions.call("on_node", node=node, agent=self)
    self.extensions.call("on_class_node", node=node, agent=self)

    # handle decorators
    decorators: list[Decorator] = []
    if node.decorator_list:
        # a decorated class starts at its first decorator, not at the `class` keyword
        lineno = node.decorator_list[0].lineno
        for decorator_node in node.decorator_list:
            decorator_value = safe_get_expression(decorator_node, parent=self.current, parse_strings=False)
            if decorator_value is None:
                # The expression could not be obtained: skip the decorator, exactly like
                # function decorators are skipped, instead of wrapping `None` in a `Decorator`.
                continue
            decorators.append(
                Decorator(
                    decorator_value,
                    lineno=decorator_node.lineno,
                    endlineno=decorator_node.end_lineno,
                ),
            )
    else:
        lineno = node.lineno

    # handle base classes
    bases = [safe_get_base_class(base, parent=self.current) for base in node.bases]

    class_ = Class(
        name=node.name,
        lineno=lineno,
        endlineno=node.end_lineno,
        docstring=self._get_docstring(node),
        decorators=decorators,
        bases=bases,  # type: ignore[arg-type]
        runtime=not self.type_guarded,
    )
    class_.labels |= self.decorators_to_labels(decorators)
    self.current.set_member(node.name, class_)
    # the class body is visited with the class itself as current object
    self.current = class_
    self.extensions.call("on_instance", node=node, obj=class_, agent=self)
    self.extensions.call("on_class_instance", node=node, cls=class_, agent=self)
    self.generic_visit(node)
    self.extensions.call("on_members", node=node, obj=class_, agent=self)
    self.extensions.call("on_class_members", node=node, cls=class_, agent=self)
    # step back to the object that was current before entering the class
    self.current = self.current.parent  # type: ignore[assignment]

291 

def decorators_to_labels(self, decorators: list[Decorator]) -> set[str]:
    """Build and return a set of labels based on decorators.

    Each decorator's callable path is looked up in `builtin_decorators` first,
    then in `stdlib_decorators`. Unknown decorators contribute no label.

    Parameters:
        decorators: The decorators to check.

    Returns:
        A set of labels.
    """
    found: set[str] = set()
    for decorator in decorators:
        path = decorator.callable_path
        builtin_label = builtin_decorators.get(path)
        if builtin_label is not None:
            found.add(builtin_label)
            continue
        found.update(stdlib_decorators.get(path, ()))
    return found

309 

def get_base_property(self, decorators: list[Decorator], function: Function) -> str | None:
    """Tell whether a function is the setter or the deleter of an existing property.

    A decorator matches when its callable path reads `<function path>.setter`
    or `<function path>.deleter`, and the member of the current object named
    like the function carries the `property` label.

    Parameters:
        decorators: The decorators to check.
        function: The decorated function.

    Returns:
        `"setter"` or `"deleter"` for the first matching decorator, `None` when there is no match.
    """
    for decorator in decorators:
        try:
            owner_path, accessor = decorator.callable_path.rsplit(".", 1)
        except ValueError:
            # no dot in the path: it cannot be `<property>.setter` or `<property>.deleter`
            continue
        if accessor not in {"setter", "deleter"}:
            continue
        if owner_path != function.path:
            continue
        if self.current.get_member(function.name).has_labels("property"):
            return accessor
    return None

333 

def handle_function(self, node: ast.AsyncFunctionDef | ast.FunctionDef, labels: set | None = None) -> None:
    """Handle a function definition node.

    Depending on its decorators, the definition is recorded as:

    - an attribute, when the labels contain `property`;
    - an overload, kept aside until the actual implementation is found;
    - the setter or deleter of an already recorded property;
    - a regular function, member of the current object.

    Parameters:
        node: The node to visit.
        labels: Labels to add to the data object. The given set is copied, never modified.
    """
    self.extensions.call("on_node", node=node, agent=self)
    self.extensions.call("on_function_node", node=node, agent=self)

    # Work on a copy: the set is extended in place below,
    # and the set passed by the caller must stay untouched.
    labels = set(labels) if labels else set()

    # handle decorators
    decorators: list[Decorator] = []
    overload = False
    if node.decorator_list:
        # a decorated function starts at its first decorator
        lineno = node.decorator_list[0].lineno
        for decorator_node in node.decorator_list:
            decorator_value = safe_get_expression(decorator_node, parent=self.current, parse_strings=False)
            if decorator_value is None:
                continue
            decorator = Decorator(
                decorator_value,
                lineno=decorator_node.lineno,
                endlineno=decorator_node.end_lineno,
            )
            decorators.append(decorator)
            overload |= decorator.callable_path in typing_overload
    else:
        lineno = node.lineno

    labels |= self.decorators_to_labels(decorators)

    if "property" in labels:
        # properties are recorded as attributes, annotated with the return annotation
        attribute = Attribute(
            name=node.name,
            value=None,
            annotation=safe_get_annotation(node.returns, parent=self.current),
            lineno=node.lineno,
            endlineno=node.end_lineno,
            docstring=self._get_docstring(node),
            runtime=not self.type_guarded,
        )
        attribute.labels |= labels
        self.current.set_member(node.name, attribute)
        self.extensions.call("on_instance", node=node, obj=attribute, agent=self)
        self.extensions.call("on_attribute_instance", node=node, attr=attribute, agent=self)
        return

    # handle parameters
    parameters = Parameters(
        *[
            Parameter(
                name,
                kind=kind,
                annotation=safe_get_annotation(annotation, parent=self.current),
                default=default
                if isinstance(default, str)
                else safe_get_expression(default, parent=self.current, parse_strings=False),
            )
            for name, annotation, kind, default in get_parameters(node.args)
        ],
    )

    function = Function(
        name=node.name,
        lineno=lineno,
        endlineno=node.end_lineno,
        parameters=parameters,
        returns=safe_get_annotation(node.returns, parent=self.current),
        decorators=decorators,
        docstring=self._get_docstring(node),
        runtime=not self.type_guarded,
        parent=self.current,
    )

    property_function = self.get_base_property(decorators, function)

    if overload:
        # overloads are stored aside, they do not become members themselves
        self.current.overloads[function.name].append(function)
    elif property_function:
        # attach the setter/deleter to the property recorded earlier under the same name
        base_property: Attribute = self.current.members[node.name]  # type: ignore[assignment]
        if property_function == "setter":
            base_property.setter = function
            base_property.labels.add("writable")
        elif property_function == "deleter":
            base_property.deleter = function
            base_property.labels.add("deletable")
    else:
        self.current.set_member(node.name, function)
        if self.current.kind in {Kind.MODULE, Kind.CLASS} and self.current.overloads[function.name]:
            # the implementation takes ownership of the overloads collected so far
            function.overloads = self.current.overloads[function.name]
            del self.current.overloads[function.name]

    function.labels |= labels

    self.extensions.call("on_instance", node=node, obj=function, agent=self)
    self.extensions.call("on_function_instance", node=node, func=function, agent=self)
    if self.current.kind is Kind.CLASS and function.name == "__init__":
        # only the body of `__init__` is visited, with the function as temporary current object
        self.current = function  # type: ignore[assignment]  # temporary assign a function
        self.generic_visit(node)
        self.current = self.current.parent  # type: ignore[assignment]

436 

def visit_functiondef(self, node: ast.FunctionDef) -> None:
    """Visit a (synchronous) function definition node.

    The node is handed to `handle_function` without any extra label.

    Parameters:
        node: The node to visit.
    """
    handle = self.handle_function
    handle(node)

444 

def visit_asyncfunctiondef(self, node: ast.AsyncFunctionDef) -> None:
    """Visit an async function definition node.

    The node is handed to `handle_function` with the `async` label.

    Parameters:
        node: The node to visit.
    """
    async_labels = {"async"}
    self.handle_function(node, labels=async_labels)

452 

def visit_import(self, node: ast.Import) -> None:
    """Visit an import node.

    Each imported name is recorded in the current object's imports and added
    as an alias member. Without `as`, only the top-level package is bound
    (`import a.b.c` binds `a`); with `as`, the given name points to the full path.

    Parameters:
        node: The node to visit.
    """
    for imported in node.names:
        if imported.asname:
            alias_name = imported.asname
            alias_path = imported.name
        else:
            # `import a.b.c` only makes `a` available in the namespace
            alias_name = alias_path = imported.name.split(".", 1)[0]
        self.current.imports[alias_name] = alias_path
        alias = Alias(
            alias_name,
            alias_path,
            lineno=node.lineno,
            endlineno=node.end_lineno,
            runtime=not self.type_guarded,
        )
        self.current.set_member(alias_name, alias)
        self.extensions.call("on_alias", alias=alias, node=node, agent=self)

472 

def visit_importfrom(self, node: ast.ImportFrom) -> None:
    """Visit an "import from" node.

    Each imported name is recorded in the current object's imports and,
    unless it would point to itself, added as an alias member.

    Parameters:
        node: The node to visit.
    """
    for imported in node.names:
        # special case: when being in `a/__init__.py` and doing `from . import b`,
        # we are effectively creating a member `b` in `a` that is pointing to `a.b`
        # -> cyclic alias! in that case, we just skip it, as both the member and module
        # have the same name and can be accessed the same way
        sibling_import_in_init = (
            not node.module and node.level == 1 and not imported.asname and self.current.module.is_init_module
        )
        if sibling_import_in_init:
            continue

        target_path = relative_to_absolute(node, imported, self.current.module)
        if imported.name == "*":
            # wildcard imports are stored under a name built from the path, with slashes
            alias_name = target_path.replace(".", "/")
            alias_path = target_path.replace(".*", "")
        else:
            alias_name = imported.asname or imported.name
            alias_path = target_path
        self.current.imports[alias_name] = alias_path

        # Do not create aliases pointing to themselves (it happens with
        # `from package.current_module import Thing as Thing` or
        # `from . import thing as thing`).
        if alias_path == f"{self.current.path}.{alias_name}":
            continue

        alias = Alias(
            alias_name,
            alias_path,
            lineno=node.lineno,
            endlineno=node.end_lineno,
            runtime=not self.type_guarded,
        )
        self.current.set_member(alias_name, alias)
        self.extensions.call("on_alias", alias=alias, node=node, agent=self)

507 

def handle_attribute(
    self,
    node: ast.Assign | ast.AnnAssign,
    annotation: str | Expr | None = None,
) -> None:
    """Handle an attribute (assignment) node.

    The assigned names become attributes of the current module or class, or,
    for assignments found in `__init__`, of the class the method belongs to.
    Assignments found in any other function are ignored.

    Parameters:
        node: The node to visit.
        annotation: A potential annotation.
    """
    self.extensions.call("on_node", node=node, agent=self)
    self.extensions.call("on_attribute_node", node=node, agent=self)
    parent = self.current
    labels: set[str] = set()

    if parent.kind is Kind.MODULE:
        try:
            names = get_names(node)
        except KeyError:  # unsupported nodes, like subscript
            return
        labels.add("module-attribute")
    elif parent.kind is Kind.CLASS:
        try:
            names = get_names(node)
        except KeyError:  # unsupported nodes, like subscript
            return

        if isinstance(annotation, Expr) and annotation.is_classvar:
            # explicit classvar: class attribute only
            annotation = annotation.slice  # type: ignore[attr-defined]
            labels.add("class-attribute")
        elif node.value:
            # attribute assigned at class-level: available in instances as well
            labels.add("class-attribute")
            labels.add("instance-attribute")
        else:
            # annotated attribute only: not available at class-level
            labels.add("instance-attribute")

    elif parent.kind is Kind.FUNCTION:
        if parent.name != "__init__":
            return
        try:
            names = get_instance_names(node)
        except KeyError:  # unsupported nodes, like subscript
            return
        # instance attributes belong to the class, not to `__init__`
        parent = parent.parent  # type: ignore[assignment]
        labels.add("instance-attribute")
    else:
        # Any other kind of current object is not handled here:
        # return early rather than hit an unbound `names` below.
        return

    if not names:
        return

    value = safe_get_expression(node.value, parent=self.current, parse_strings=False)

    try:
        docstring = self._get_docstring(ast_next(node), strict=True)
    except (LastNodeError, AttributeError):
        docstring = None

    for name in names:
        # TODO: handle assigns like x.y = z
        # we need to resolve x.y and add z in its member
        if "." in name:
            continue

        # Per-target copies: what is forwarded from an existing member of one
        # target must not leak into the other targets of the same assignment
        # (as in `a = b = value`).
        member_labels = set(labels)
        member_docstring = docstring
        member_annotation = annotation

        if name in parent.members:
            # assigning multiple times
            # TODO: might be better to inspect
            if isinstance(node.parent, (ast.If, ast.ExceptHandler)):  # type: ignore[union-attr]
                continue  # prefer "no-exception" case

            existing_member = parent.members[name]
            with suppress(AliasResolutionError, CyclicAliasError):
                member_labels |= existing_member.labels
                # forward previous docstring and annotation instead of erasing them
                if existing_member.docstring and not member_docstring:
                    member_docstring = existing_member.docstring
                with suppress(AttributeError):
                    if existing_member.annotation and not member_annotation:  # type: ignore[union-attr]
                        member_annotation = existing_member.annotation  # type: ignore[union-attr]

        attribute = Attribute(
            name=name,
            value=value,
            annotation=member_annotation,
            lineno=node.lineno,
            endlineno=node.end_lineno,
            docstring=member_docstring,
            runtime=not self.type_guarded,
        )
        attribute.labels |= member_labels
        parent.set_member(name, attribute)

        if name == "__all__":
            with suppress(AttributeError):
                parent.exports = [
                    export if isinstance(export, str) else ExprName(export.name, parent=export.parent)
                    for export in safe_get__all__(node, self.current)  # type: ignore[arg-type]
                ]
        self.extensions.call("on_instance", node=node, obj=attribute, agent=self)
        self.extensions.call("on_attribute_instance", node=node, attr=attribute, agent=self)

610 

def visit_assign(self, node: ast.Assign) -> None:
    """Visit an assignment node.

    The node is handed to `handle_attribute` without annotation.

    Parameters:
        node: The node to visit.
    """
    handle = self.handle_attribute
    handle(node)

618 

def visit_annassign(self, node: ast.AnnAssign) -> None:
    """Visit an annotated assignment node.

    The annotation expression is obtained with `safe_get_annotation`, then
    passed to `handle_attribute` together with the node.

    Parameters:
        node: The node to visit.
    """
    annotation = safe_get_annotation(node.annotation, parent=self.current)
    self.handle_attribute(node, annotation)

626 

def visit_augassign(self, node: ast.AugAssign) -> None:
    """Visit an augmented assignment node.

    Only `__all__ += ...` at module level is handled: the names obtained from
    the node are appended to the module exports. Everything else is ignored,
    and so is any `AttributeError` raised while processing the node.

    Parameters:
        node: The node to visit.
    """
    with suppress(AttributeError):
        if node.target.id != "__all__":  # type: ignore[union-attr]
            return
        if not self.current.is_module:
            return
        if not isinstance(node.op, ast.Add):
            return
        # we assume exports is not None at this point
        # (looked up before the names are computed, as in a direct `exports.extend(...)` call)
        extend_exports = self.current.exports.extend  # type: ignore[union-attr]
        extend_exports(
            [
                export if isinstance(export, str) else ExprName(export.name, parent=export.parent)
                for export in safe_get__all__(node, self.current)  # type: ignore[arg-type]
            ],
        )

647 

def visit_if(self, node: ast.If) -> None:
    """Visit an "if" node.

    When the node sits directly in a module or class body and its condition is
    `TYPE_CHECKING` (or `typing.TYPE_CHECKING`), everything visited below it is
    flagged as type-guarded. The guard state of the enclosing block is restored
    once the node has been visited, so that a nested `if` does not clear the
    guard for the statements that follow it.

    Parameters:
        node: The node to visit.
    """
    # remember the state of the enclosing block, to restore it afterwards
    outer_type_guarded = self.type_guarded
    if isinstance(node.parent, (ast.Module, ast.ClassDef)):  # type: ignore[attr-defined]
        condition = safe_get_condition(node.test, parent=self.current, log_level=None)
        if str(condition) in {"typing.TYPE_CHECKING", "TYPE_CHECKING"}:
            self.type_guarded = True
    try:
        self.generic_visit(node)
    finally:
        self.type_guarded = outer_type_guarded