Coverage for src/_griffe/extensions/dataclasses.py: 95.83%

116 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-15 16:47 +0200

1# Built-in extension adding support for dataclasses. 

2# 

3# This extension re-creates `__init__` methods of dataclasses 

4# during static analysis. 

5 

6from __future__ import annotations 

7 

8import ast 

9from contextlib import suppress 

10from functools import lru_cache 

11from typing import Any, cast 

12 

13from _griffe.enumerations import ParameterKind 

14from _griffe.expressions import ( 

15 Expr, 

16 ExprAttribute, 

17 ExprCall, 

18 ExprDict, 

19) 

20from _griffe.extensions.base import Extension 

21from _griffe.models import Attribute, Class, Decorator, Function, Module, Parameter, Parameters 

22 

23 

def _dataclass_decorator(decorators: list[Decorator]) -> Expr | None:
    """Find the `@dataclass` decorator among a class's decorators.

    The returned expression may be a plain reference (`@dataclass`) or a call
    (`@dataclass(...)`); callers distinguish the two themselves.

    Parameters:
        decorators: The decorators to search.

    Returns:
        The value of the first decorator whose canonical path is
        `dataclasses.dataclass`, or `None` if there is none.
    """
    matches = (
        decorator.value
        for decorator in decorators
        if isinstance(decorator.value, Expr) and decorator.value.canonical_path == "dataclasses.dataclass"
    )
    return next(matches, None)

29 

30 

def _expr_args(expr: Expr) -> dict[str, str | Expr]:
    """Collect the keyword arguments of a call expression.

    For an `ExprCall`, each keyword argument is mapped by name to its value.
    Arguments without a name are treated as `**mapping` unpacking. For those,
    the mapping's definition is looked up in the modules collection and its
    entries are collected recursively. Lookup failures of any kind are
    silently ignored.

    For an `ExprDict`, keys are rendered back to source and evaluated as
    Python literals. Within this module, dicts only reach this function
    through the unpacking path above.

    Any other expression yields an empty mapping.

    Parameters:
        expr: The expression to inspect.

    Returns:
        A mapping of argument names to their values.
    """
    collected: dict[str, str | Expr] = {}

    if isinstance(expr, ExprDict):
        pairs = zip(expr.keys, expr.values)
        collected.update({ast.literal_eval(str(key)): value for key, value in pairs})
        return collected

    if not isinstance(expr, ExprCall):
        return collected

    for arg in expr.arguments:
        try:
            collected[arg.name] = arg.value  # type: ignore[union-attr]
        except AttributeError:
            # No `name`: the argument is an unpacked variable (`**options`).
            # Best effort: resolve it and merge its own arguments/entries.
            with suppress(Exception):
                modules = expr.function.parent.modules_collection  # type: ignore[attr-defined]
                target = modules[arg.value.canonical_path]  # type: ignore[union-attr]
                collected.update(_expr_args(target.value))
    return collected

46 

47 

def _dataclass_arguments(decorators: list[Decorator]) -> dict[str, Any]:
    """Return the keyword arguments given to `@dataclass(...)`.

    Parameters:
        decorators: The decorators of a class.

    Returns:
        The arguments mapping. It is empty when the class has no `@dataclass`
        decorator, or when the decorator is used without a call.
    """
    decorator_expr = _dataclass_decorator(decorators)
    if not decorator_expr or not isinstance(decorator_expr, ExprCall):
        return {}
    return _expr_args(decorator_expr)

52 

53 

def _field_arguments(attribute: Attribute) -> dict[str, Any]:
    """Return the keyword arguments of a `dataclasses.field(...)` value.

    Parameters:
        attribute: The attribute whose assigned value is inspected.

    Returns:
        The arguments of the `field(...)` call. The mapping is empty when the
        attribute has no value or its value is not such a call.
    """
    value = attribute.value
    if not value:
        return {}
    if isinstance(value, ExprAttribute):
        # Keep only the last segment of a dotted expression.
        value = value.last
    if not isinstance(value, ExprCall) or value.canonical_path != "dataclasses.field":
        return {}
    return _expr_args(value)

62 

63 

def _is_field_call(value: str | Expr | None) -> bool:
    """Tell whether an attribute value is a call to `dataclasses.field`.

    Uses the same detection as `_field_arguments`. It also answers `True` for
    calls whose arguments could not be collected, such as a bare `field()`.
    """
    if isinstance(value, ExprAttribute):
        value = value.last
    return isinstance(value, ExprCall) and value.canonical_path == "dataclasses.field"


@lru_cache(maxsize=None)
def _dataclass_parameters(class_: Class) -> list[Parameter]:
    """Compute the `__init__` parameters contributed by one dataclass.

    Only `class_.members` are inspected. Parameters of parent dataclasses are
    gathered by the caller.

    The result is cached per class object, so repeated calls return the same
    list and the same `Parameter` instances; callers must not mutate them.
    Besides avoiding recomputation, the cache appears to preserve `InitVar`
    parameters of parent dataclasses: `_apply_recursively` removes `InitVar`
    members after a class's `__init__` has been computed, and subclasses
    processed later reuse the cached result.

    NOTE(review): with `@dataclass(init=False)` this returns an empty list, so
    those fields are also missing from subclasses' generated `__init__`. The
    standard library still includes inherited fields there, so confirm this is
    intended.

    Parameters:
        class_: The class to inspect.

    Returns:
        The parameters, in declaration order.
    """
    # Fetch `@dataclass(...)` arguments, if any.
    dec_args = _dataclass_arguments(class_.decorators)

    # `@dataclass(init=False)`: parameters are not added to `__init__`.
    if dec_args.get("init") == "False":
        return []

    # `@dataclass(kw_only=True)`: fields are keyword-only unless overridden.
    kw_only = dec_args.get("kw_only") == "True"

    parameters: list[Parameter] = []
    for member in class_.members.values():
        # Skip aliases, as `_apply_recursively` does. They are not annotated
        # field declarations of this class body (e.g. names imported there).
        if member.is_alias or not member.is_attribute:
            continue
        member = cast(Attribute, member)

        # All dataclass fields have annotations.
        if member.annotation is None:
            continue

        # Attributes with these characteristics are not fields:
        # - @property
        # - @cached_property
        # - ClassVar annotation
        # TODO: It is better to explicitly check for ClassVar, but
        # Visitor.handle_attribute unwraps it from the annotation.
        # Maybe create internal_labels and store classvar in there.
        if "property" in member.labels or (
            "class-attribute" in member.labels and "instance-attribute" not in member.labels
        ):
            continue

        # `_: KW_ONLY` sentinel: following fields default to keyword-only.
        if isinstance(member.annotation, Expr) and member.annotation.canonical_path == "dataclasses.KW_ONLY":
            kw_only = True
            continue

        # Fetch `field(...)` arguments, if any.
        field_args = _field_arguments(member)

        # `field(init=False)`: not an `__init__` parameter.
        if field_args.get("init") == "False":
            continue

        # An explicit `field(kw_only=...)` overrides the class-level default,
        # including the one set by `KW_ONLY`, as the standard library does.
        field_kw_only = field_args.get("kw_only")
        if field_kw_only == "True":
            is_kw_only = True
        elif field_kw_only == "False":
            is_kw_only = False
        else:
            is_kw_only = kw_only
        kind = ParameterKind.keyword_only if is_kw_only else ParameterKind.positional_or_keyword

        # Determine the parameter default.
        if "default_factory" in field_args:
            default = ExprCall(function=field_args["default_factory"], arguments=[])
        elif _is_field_call(member.value):
            # A `field(...)` call without `default=` (including a bare
            # `field()`) means "no default". The call itself is not the default.
            default = field_args.get("default")
        else:
            default = member.value

        parameters.append(
            Parameter(
                member.name,
                annotation=member.annotation,
                kind=kind,
                default=default,
                docstring=member.docstring,
            ),
        )

    return parameters

136 

137 

def _reorder_parameters(parameters: list[Parameter]) -> list[Parameter]:
    """De-duplicate parameters by name and group them by kind.

    A parameter that appears again later replaces the earlier one but keeps
    the earlier position. Positional-only parameters come first and
    keyword-only parameters last. The relative order within each group is
    preserved.

    Parameters:
        parameters: The parameters, possibly with duplicate names.

    Returns:
        A new, re-ordered list.
    """
    by_name: dict[str, Parameter] = {}
    for param in parameters:
        by_name[param.name] = param

    # Stable sort on a group index: 0 = positional-only, 2 = keyword-only,
    # 1 = every other kind (including `None`).
    group_of = {ParameterKind.positional_only: 0, ParameterKind.keyword_only: 2}
    return sorted(by_name.values(), key=lambda param: group_of.get(param.kind, 1))

154 

155 

def _set_dataclass_init(class_: Class) -> None:
    """Create the `__init__` method of a class decorated with `@dataclass`.

    Parameters are collected from every `@dataclass`-decorated class in the
    MRO, iterated in reverse, then from `class_` itself. They are then
    de-duplicated and re-ordered by `_reorder_parameters`. When at least one
    of those ancestors is a dataclass, `class_` gets the `"dataclass"` label,
    even if it is not decorated itself.

    Parameters:
        class_: The class to update in place.
    """
    # Retrieve parameters from all parent dataclasses.
    parameters: list[Parameter] = []
    try:
        mro = class_.mro()
    except ValueError:
        # `mro()` failed: fall back to considering no ancestors.
        mro = ()  # type: ignore[assignment]
    for parent in reversed(mro):
        if _dataclass_decorator(parent.decorators):
            parameters.extend(_dataclass_parameters(parent))
            # At least one parent dataclass makes the current class a dataclass:
            # that's how `dataclasses.is_dataclass` works.
            class_.labels.add("dataclass")

    # If the class is not decorated with `@dataclass`, skip it.
    if not _dataclass_decorator(class_.decorators):
        return

    # Add current class parameters.
    parameters.extend(_dataclass_parameters(class_))

    # `_dataclass_parameters` is cached, so a parent's `Parameter` instances are
    # handed to every subclass. Copy them so each generated `__init__` owns its
    # parameters, and attaching or modifying them on one method cannot affect
    # another.
    own_parameters = [
        Parameter(
            param.name,
            annotation=param.annotation,
            kind=param.kind,
            default=param.default,
            docstring=param.docstring,
        )
        for param in _reorder_parameters(parameters)
    ]

    # Create `__init__` method with re-ordered parameters.
    init = Function(
        "__init__",
        lineno=0,
        endlineno=0,
        parent=class_,
        parameters=Parameters(
            Parameter(name="self", annotation=None, kind=ParameterKind.positional_or_keyword, default=None),
            *own_parameters,
        ),
        returns="None",
    )
    class_.set_member("__init__", init)

190 

191 

def _del_members_annotated_as_initvar(class_: Class) -> None:
    """Remove attributes annotated with `dataclasses.InitVar` from a class.

    Such definitions are init-only pseudo-fields, not class members.

    Parameters:
        class_: The class to update in place.
    """
    # Snapshot first: members are deleted while iterating.
    snapshot = tuple(member for member in class_.members.values() if isinstance(member, Attribute))
    for attr in snapshot:
        annotation = attr.annotation
        if not isinstance(annotation, Expr):
            continue
        if annotation.canonical_path != "dataclasses.InitVar":
            continue
        class_.del_member(attr.name)

198 

199 

def _apply_recursively(mod_cls: Module | Class, processed: set[str]) -> None:
    """Apply the dataclass transformations to an object and the objects below it.

    Modules are traversed through their non-alias submodules and classes.
    Classes are traversed through their non-alias nested classes. Each object
    is visited at most once, tracked by canonical path.

    Parameters:
        mod_cls: The module or class to process.
        processed: Canonical paths already visited; updated in place.
    """
    # Do not visit the same object twice.
    if mod_cls.canonical_path in processed:
        return
    processed.add(mod_cls.canonical_path)
    if isinstance(mod_cls, Class):
        # Only synthesize `__init__` when the class does not already define one.
        if "__init__" not in mod_cls.members:
            _set_dataclass_init(mod_cls)
            # NOTE(review): source indentation was ambiguous here. Confirm
            # whether this call belongs inside the `if` (assumed here) or runs
            # for every class.
            _del_members_annotated_as_initvar(mod_cls)
        for member in mod_cls.members.values():
            if not member.is_alias and member.is_class:
                _apply_recursively(member, processed)  # type: ignore[arg-type]
    elif isinstance(mod_cls, Module):
        for member in mod_cls.members.values():
            if not member.is_alias and (member.is_module or member.is_class):
                _apply_recursively(member, processed)  # type: ignore[arg-type]

215 

216 

class DataclassesExtension(Extension):
    """Built-in extension adding support for dataclasses.

    Once a package is loaded, it walks the package's modules and (nested)
    classes. Every class decorated with `@dataclass` that does not define
    `__init__` itself receives a synthesized `__init__` method.
    """

    def on_package_loaded(self, *, pkg: Module, **kwargs: Any) -> None:  # noqa: ARG002
        """Hook for loaded packages.

        Parameters:
            pkg: The loaded package.
            **kwargs: Ignored.
        """
        visited: set[str] = set()
        _apply_recursively(pkg, visited)