Coverage for src/_griffe/agents/nodes/exports.py: 93.10%

48 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-15 16:47 +0200

1# This module contains utilities for extracting exports from `__all__` assignments. 

2 

3from __future__ import annotations 

4 

5import ast 

6from contextlib import suppress 

7from dataclasses import dataclass 

8from typing import TYPE_CHECKING, Any, Callable 

9 

10from _griffe.agents.nodes.values import get_value 

11from _griffe.enumerations import LogLevel 

12from _griffe.logger import logger 

13 

14if TYPE_CHECKING: 

15 from _griffe.models import Module 

16 

17 

@dataclass()
class ExportedName:
    """Intermediate holder for an identifier referenced in `__all__`, paired with its module."""

    name: str
    """The exported name."""
    parent: Module
    """The parent module."""

26 

27 

28def _extract_constant(node: ast.Constant, parent: Module) -> list[str | ExportedName]: 

29 return [node.value] 

30 

31 

def _extract_name(node: ast.Name, parent: Module) -> list[str | ExportedName]:
    """Wrap a bare identifier in an `ExportedName` bound to the given parent module."""
    exported = ExportedName(name=node.id, parent=parent)
    return [exported]

34 

35 

def _extract_starred(node: ast.Starred, parent: Module) -> list[str | ExportedName]:
    """Unwrap a starred expression (`*names`) and extract from the expression it unpacks."""
    unpacked = node.value
    return _extract(unpacked, parent)

38 

39 

40def _extract_sequence(node: ast.List | ast.Set | ast.Tuple, parent: Module) -> list[str | ExportedName]: 

41 sequence = [] 

42 for elt in node.elts: 

43 sequence.extend(_extract(elt, parent)) 

44 return sequence 

45 

46 

47def _extract_binop(node: ast.BinOp, parent: Module) -> list[str | ExportedName]: 

48 left = _extract(node.left, parent) 

49 right = _extract(node.right, parent) 

50 return left + right 

51 

52 

# Dispatch table used by `_extract`: exact AST node class -> extractor.
# Node classes missing from this table are unsupported (the lookup raises KeyError).
_node_map: dict[type, Callable[[Any, Module], list[str | ExportedName]]] = {
    ast.Constant: _extract_constant,
    ast.Name: _extract_name,
    ast.Starred: _extract_starred,
    # All three collection literals expose `.elts`, so they share one extractor.
    **dict.fromkeys((ast.List, ast.Set, ast.Tuple), _extract_sequence),
    ast.BinOp: _extract_binop,
}

62 

63 

def _extract(node: ast.AST, parent: Module) -> list[str | ExportedName]:
    """Dispatch `node` to the extractor registered for its exact type.

    Raises:
        KeyError: If no extractor is registered for `type(node)`.
    """
    extractor = _node_map[type(node)]
    return extractor(node, parent)

66 

67 

def get__all__(node: ast.Assign | ast.AnnAssign | ast.AugAssign, parent: Module) -> list[str | ExportedName]:
    """Get the values declared in `__all__`.

    Parameters:
        node: The assignment node.
        parent: The parent module.

    Returns:
        A list of names. Constant entries appear as their literal values.
        Identifiers appear as `ExportedName` instances bound to `parent`.
        An annotated assignment without a value (e.g. `__all__: list[str]`)
        yields an empty list.
    """
    # `AnnAssign.value` is None for a bare annotation with no right-hand side.
    if node.value is None:
        return []
    return _extract(node.value, parent)

81 

82 

def safe_get__all__(
    node: ast.Assign | ast.AnnAssign | ast.AugAssign,
    parent: Module,
    log_level: LogLevel = LogLevel.debug,  # TODO: set to error when we handle more things
) -> list[str | ExportedName]:
    """Safely (no exception) extract values in `__all__`.

    Any exception raised during extraction is caught and logged at `log_level`,
    and an empty list is returned instead.

    Parameters:
        node: The `__all__` assignment node.
        parent: The parent used to resolve the names.
        log_level: Log level to use to log a message.

    Returns:
        A list of strings or resolvable names (empty if extraction failed).
    """
    try:
        return get__all__(node, parent)
    except Exception as error:  # noqa: BLE001
        # Building the log message must not raise either; otherwise the
        # "no exception" promise would be broken from inside this handler.
        value = None
        with suppress(Exception):
            value = get_value(node.value)
        message = f"Failed to extract `__all__` value: {value}"
        with suppress(Exception):
            message += f" at {parent.relative_filepath}:{node.lineno}"
        if isinstance(error, KeyError):
            # The `_node_map` lookup in `_extract` raises KeyError for node types it does not handle.
            message += f": unsupported node {error}"
        else:
            message += f": {error}"
        getattr(logger, log_level.value)(message)
        return []