Coverage for src/_griffe/extensions/dataclasses.py: 95.65%

118 statements  

« prev     ^ index     » next       coverage.py v7.6.2, created at 2024-10-12 01:34 +0200

1# Built-in extension adding support for dataclasses. 

2# 

3# This extension re-creates `__init__` methods of dataclasses 

4# during static analysis. 

5 

6from __future__ import annotations 

7 

8import ast 

9from contextlib import suppress 

10from functools import cache 

11from typing import Any, cast 

12 

13from _griffe.enumerations import ParameterKind 

14from _griffe.expressions import ( 

15 Expr, 

16 ExprAttribute, 

17 ExprCall, 

18 ExprDict, 

19) 

20from _griffe.extensions.base import Extension 

21from _griffe.logger import logger 

22from _griffe.models import Attribute, Class, Decorator, Function, Module, Parameter, Parameters 

23 

24 

25def _dataclass_decorator(decorators: list[Decorator]) -> Expr | None: 

26 for decorator in decorators: 

27 if isinstance(decorator.value, Expr) and decorator.value.canonical_path == "dataclasses.dataclass": 

28 return decorator.value 

29 return None 

30 

31 

def _expr_args(expr: Expr) -> dict[str, str | Expr]:
    """Extract keyword arguments from a call expression or a dict expression.

    For a call, keyword arguments are collected by name. Unpacked arguments
    (e.g. `**options`) are resolved through the modules collection on a
    best-effort basis, and their items are merged in. For a dict, keys that
    are string literals are mapped to their values.

    Parameters:
        expr: The expression to inspect.

    Returns:
        A mapping of argument names to values; empty for any other expression.
    """
    args: dict[str, str | Expr] = {}
    if isinstance(expr, ExprCall):
        for argument in expr.arguments:
            try:
                args[argument.name] = argument.value  # type: ignore[union-attr]
            except AttributeError:
                # Argument is an unpacked variable: try to resolve it and merge its items.
                # Resolution is best-effort; any failure leaves `args` unchanged.
                with suppress(Exception):
                    collection = expr.function.parent.modules_collection  # type: ignore[attr-defined]
                    var = collection[argument.value.canonical_path]  # type: ignore[union-attr]
                    args.update(_expr_args(var.value))
    elif isinstance(expr, ExprDict):
        for key, value in zip(expr.keys, expr.values):
            # Keys are rendered back to source and parsed as literals. Skip keys that are
            # not string literals (e.g. names of constants) individually, instead of
            # letting one bad key discard every other item of the dict.
            try:
                name = ast.literal_eval(str(key))
            except (ValueError, TypeError, SyntaxError):
                continue
            if isinstance(name, str):
                args[name] = value
    return args

47 

48 

def _dataclass_arguments(decorators: list[Decorator]) -> dict[str, Any]:
    """Return the arguments passed to a `@dataclass(...)` decorator.

    Parameters:
        decorators: The decorators of a class.

    Returns:
        The arguments of the `@dataclass(...)` call, or an empty dict when
        the decorator is absent or is not a call (bare `@dataclass`).
    """
    decorator = _dataclass_decorator(decorators)
    if not decorator or not isinstance(decorator, ExprCall):
        return {}
    return _expr_args(decorator)

53 

54 

55def _field_arguments(attribute: Attribute) -> dict[str, Any]: 

56 if attribute.value: 

57 value = attribute.value 

58 if isinstance(value, ExprAttribute): 

59 value = value.last 

60 if isinstance(value, ExprCall) and value.canonical_path == "dataclasses.field": 

61 return _expr_args(value) 

62 return {} 

63 

64 

65@cache 

66def _dataclass_parameters(class_: Class) -> list[Parameter]: 

67 # Fetch `@dataclass` arguments if any. 

68 dec_args = _dataclass_arguments(class_.decorators) 

69 

70 # Parameters not added to `__init__`, return empty list. 

71 if dec_args.get("init") == "False": 

72 return [] 

73 

74 # All parameters marked as keyword-only. 

75 kw_only = dec_args.get("kw_only") == "True" 

76 

77 # Iterate on current attributes to find parameters. 

78 parameters = [] 

79 for member in class_.members.values(): 

80 if member.is_attribute: 

81 member = cast(Attribute, member) 

82 

83 # All dataclass parameters have annotations 

84 if member.annotation is None: 

85 continue 

86 

87 # Attributes that have labels for these characteristics are 

88 # not class parameters: 

89 # - @property 

90 # - @cached_property 

91 # - ClassVar annotation 

92 if "property" in member.labels or ( 

93 # TODO: It is better to explicitly check for ClassVar, but 

94 # Visitor.handle_attribute unwraps it from the annotation. 

95 # Maybe create internal_labels and store classvar in there. 

96 "class-attribute" in member.labels and "instance-attribute" not in member.labels 

97 ): 

98 continue 

99 

100 # Start of keyword-only parameters. 

101 if isinstance(member.annotation, Expr) and member.annotation.canonical_path == "dataclasses.KW_ONLY": 

102 kw_only = True 

103 continue 

104 

105 # Fetch `field` arguments if any. 

106 field_args = _field_arguments(member) 

107 

108 # Parameter not added to `__init__`, skip it. 

109 if field_args.get("init") == "False": 

110 continue 

111 

112 # Determine parameter kind. 

113 kind = ( 

114 ParameterKind.keyword_only 

115 if kw_only or field_args.get("kw_only") == "True" 

116 else ParameterKind.positional_or_keyword 

117 ) 

118 

119 # Determine parameter default. 

120 if "default_factory" in field_args: 

121 default = ExprCall(function=field_args["default_factory"], arguments=[]) 

122 else: 

123 default = field_args.get("default", None if field_args else member.value) 

124 

125 # Add parameter to the list. 

126 parameters.append( 

127 Parameter( 

128 member.name, 

129 annotation=member.annotation, 

130 kind=kind, 

131 default=default, 

132 docstring=member.docstring, 

133 ), 

134 ) 

135 

136 return parameters 

137 

138 

139def _reorder_parameters(parameters: list[Parameter]) -> list[Parameter]: 

140 # De-duplicate, overwriting previous parameters. 

141 params_dict = {param.name: param for param in parameters} 

142 

143 # Re-order, putting positional-only in front and keyword-only at the end. 

144 pos_only = [] 

145 pos_kw = [] 

146 kw_only = [] 

147 for param in params_dict.values(): 

148 if param.kind is ParameterKind.positional_only: 148 ↛ 149line 148 didn't jump to line 149 because the condition on line 148 was never true

149 pos_only.append(param) 

150 elif param.kind is ParameterKind.keyword_only: 

151 kw_only.append(param) 

152 else: 

153 pos_kw.append(param) 

154 return pos_only + pos_kw + kw_only 

155 

156 

157def _set_dataclass_init(class_: Class) -> None: 

158 # Retrieve parameters from all parent dataclasses. 

159 parameters = [] 

160 try: 

161 mro = class_.mro() 

162 except ValueError: 

163 mro = () # type: ignore[assignment] 

164 for parent in reversed(mro): 

165 if _dataclass_decorator(parent.decorators): 

166 parameters.extend(_dataclass_parameters(parent)) 

167 # At least one parent dataclass makes the current class a dataclass: 

168 # that's how `dataclasses.is_dataclass` works. 

169 class_.labels.add("dataclass") 

170 

171 # If the class is not decorated with `@dataclass`, skip it. 

172 if not _dataclass_decorator(class_.decorators): 

173 return 

174 

175 logger.debug("Handling dataclass: %s", class_.path) 

176 

177 # Add current class parameters. 

178 parameters.extend(_dataclass_parameters(class_)) 

179 

180 # Create `__init__` method with re-ordered parameters. 

181 init = Function( 

182 "__init__", 

183 lineno=0, 

184 endlineno=0, 

185 parent=class_, 

186 parameters=Parameters( 

187 Parameter(name="self", annotation=None, kind=ParameterKind.positional_or_keyword, default=None), 

188 *_reorder_parameters(parameters), 

189 ), 

190 returns="None", 

191 ) 

192 class_.set_member("__init__", init) 

193 

194 

195def _del_members_annotated_as_initvar(class_: Class) -> None: 

196 # Definitions annotated as InitVar are not class members 

197 attributes = [member for member in class_.members.values() if isinstance(member, Attribute)] 

198 for attribute in attributes: 

199 if isinstance(attribute.annotation, Expr) and attribute.annotation.canonical_path == "dataclasses.InitVar": 

200 class_.del_member(attribute.name) 

201 

202 

203def _apply_recursively(mod_cls: Module | Class, processed: set[str]) -> None: 

204 if mod_cls.canonical_path in processed: 204 ↛ 205line 204 didn't jump to line 205 because the condition on line 204 was never true

205 return 

206 processed.add(mod_cls.canonical_path) 

207 if isinstance(mod_cls, Class): 

208 if "__init__" not in mod_cls.members: 

209 _set_dataclass_init(mod_cls) 

210 _del_members_annotated_as_initvar(mod_cls) 

211 for member in mod_cls.members.values(): 

212 if not member.is_alias and member.is_class: 

213 _apply_recursively(member, processed) # type: ignore[arg-type] 

214 elif isinstance(mod_cls, Module): 214 ↛ exitline 214 didn't return from function '_apply_recursively' because the condition on line 214 was always true

215 for member in mod_cls.members.values(): 

216 if not member.is_alias and (member.is_module or member.is_class): 

217 _apply_recursively(member, processed) # type: ignore[arg-type] 

218 

219 

class DataclassesExtension(Extension):
    """Built-in extension adding support for dataclasses.

    This extension creates `__init__` methods of dataclasses
    if they don't already exist.
    """

    def on_package_loaded(self, *, pkg: Module, **kwargs: Any) -> None:  # noqa: ARG002
        """Hook for loaded packages.

        Parameters:
            pkg: The loaded package.
            **kwargs: Additional hook arguments (unused).
        """
        # A fresh set per package: canonical paths already visited during this walk.
        visited: set[str] = set()
        _apply_recursively(pkg, visited)