Coverage for src/_griffe/extensions/dataclasses.py: 95.83%
116 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-15 16:47 +0200
1# Built-in extension adding support for dataclasses.
2#
3# This extension re-creates `__init__` methods of dataclasses
4# during static analysis.
6from __future__ import annotations
8import ast
9from contextlib import suppress
10from functools import lru_cache
11from typing import Any, cast
13from _griffe.enumerations import ParameterKind
14from _griffe.expressions import (
15 Expr,
16 ExprAttribute,
17 ExprCall,
18 ExprDict,
19)
20from _griffe.extensions.base import Extension
21from _griffe.models import Attribute, Class, Decorator, Function, Module, Parameter, Parameters
def _dataclass_decorator(decorators: list[Decorator]) -> Expr | None:
    """Find the `@dataclasses.dataclass` decorator among a class' decorators.

    Parameters:
        decorators: The decorators applied to a class.

    Returns:
        The expression of the first decorator resolving to `dataclasses.dataclass`
        (it may be a plain reference or a call expression), or `None` if there is none.
    """
    matching = (
        dec.value
        for dec in decorators
        if isinstance(dec.value, Expr) and dec.value.canonical_path == "dataclasses.dataclass"
    )
    return next(matching, None)
def _expr_args(expr: Expr) -> dict[str, str | Expr]:
    """Collect the keyword arguments of a call expression, or the items of a dict literal.

    Parameters:
        expr: Typically the call of `@dataclass(...)` or `field(...)`. Dict literals
            (and nested calls) are reached when resolving `**kwargs` unpacking of a
            variable found in the modules collection.

    Returns:
        A mapping of argument names to their (unevaluated) values. Anything that
        cannot be resolved statically is skipped.
    """
    args: dict[str, str | Expr] = {}
    if isinstance(expr, ExprCall):
        for argument in expr.arguments:
            try:
                args[argument.name] = argument.value  # type: ignore[union-attr]
            except AttributeError:
                # Argument is an unpacked variable: look it up in the modules
                # collection and merge its own arguments/items. This is a
                # deliberate best-effort: unresolvable variables are ignored.
                with suppress(Exception):
                    collection = expr.function.parent.modules_collection  # type: ignore[attr-defined]
                    var = collection[argument.value.canonical_path]  # type: ignore[union-attr]
                    args.update(_expr_args(var.value))
    elif isinstance(expr, ExprDict):
        for key, value in zip(expr.keys, expr.values):
            # `str(key)` is expected to be the key's source text (e.g. `'init'`).
            # Evaluate keys one by one so that a single non-literal key (a variable,
            # a `**spread`, ...) does not discard every other entry of the dict.
            with suppress(ValueError, SyntaxError, TypeError):
                args[ast.literal_eval(str(key))] = value
    return args
def _dataclass_arguments(decorators: list[Decorator]) -> dict[str, Any]:
    """Return the arguments passed to `@dataclass(...)`.

    Parameters:
        decorators: The decorators applied to a class.

    Returns:
        The decorator call's arguments, or an empty dict when the class has no
        `@dataclass` decorator or uses it without calling it.
    """
    decorator = _dataclass_decorator(decorators)
    if not decorator or not isinstance(decorator, ExprCall):
        return {}
    return _expr_args(decorator)
def _field_arguments(attribute: Attribute) -> dict[str, Any]:
    """Return the arguments of a `dataclasses.field(...)` assignment.

    Parameters:
        attribute: A class attribute.

    Returns:
        The arguments of the `field(...)` call assigned to the attribute, or an
        empty dict if the attribute has no value or its value is not such a call.
    """
    value = attribute.value
    if not value:
        return {}
    # A qualified call is presumably represented as an attribute expression
    # whose last element is the call itself.
    if isinstance(value, ExprAttribute):
        value = value.last
    is_field_call = isinstance(value, ExprCall) and value.canonical_path == "dataclasses.field"
    return _expr_args(value) if is_field_call else {}
def _is_dataclass_field(attribute: Attribute) -> bool:
    """Tell whether an attribute is assigned a `dataclasses.field(...)` call, with or without arguments."""
    value = attribute.value
    if isinstance(value, ExprAttribute):
        value = value.last
    return isinstance(value, ExprCall) and value.canonical_path == "dataclasses.field"


# NOTE: besides avoiding recomputation, the cache appears to matter for correctness:
# `_apply_recursively` deletes `InitVar` attributes from a class right after building
# its `__init__`, while subclasses later call this function on each parent dataclass.
# Caching keeps the parameters computed from the parent's original members.
@lru_cache(maxsize=None)
def _dataclass_parameters(class_: Class) -> list[Parameter]:
    """Compute the `__init__` parameters contributed by one dataclass.

    The caller combines these with the parameters of parent dataclasses.

    Parameters:
        class_: A class decorated with `@dataclass`.

    Returns:
        The parameters (without `self`), in member order.
    """
    # Fetch `@dataclass` arguments if any.
    dec_args = _dataclass_arguments(class_.decorators)

    # `@dataclass(init=False)`: parameters are not added to `__init__`.
    if dec_args.get("init") == "False":
        return []

    # Class-level keyword-only default; a `KW_ONLY` sentinel turns it on
    # for every field that follows it.
    kw_only = dec_args.get("kw_only") == "True"

    # Iterate on current attributes to find parameters.
    parameters = []
    for member in class_.members.values():
        if not member.is_attribute:
            continue
        member = cast(Attribute, member)

        # All dataclass parameters have annotations.
        if member.annotation is None:
            continue

        # Attributes that have labels for these characteristics are
        # not class parameters:
        # - @property
        # - @cached_property
        # - ClassVar annotation
        if "property" in member.labels or (
            # TODO: It is better to explicitly check for ClassVar, but
            # Visitor.handle_attribute unwraps it from the annotation.
            # Maybe create internal_labels and store classvar in there.
            "class-attribute" in member.labels and "instance-attribute" not in member.labels
        ):
            continue

        # Start of keyword-only parameters.
        if isinstance(member.annotation, Expr) and member.annotation.canonical_path == "dataclasses.KW_ONLY":
            kw_only = True
            continue

        # Fetch `field` arguments if any.
        field_args = _field_arguments(member)

        # Parameter not added to `__init__`, skip it.
        if field_args.get("init") == "False":
            continue

        # Determine parameter kind. As in `dataclasses`, an explicit
        # `field(kw_only=...)` overrides the class-level / `KW_ONLY` default:
        # `kw_only=False` keeps the field positional.
        field_kw_only = field_args.get("kw_only")
        if field_kw_only == "True" or (field_kw_only != "False" and kw_only):
            kind = ParameterKind.keyword_only
        else:
            kind = ParameterKind.positional_or_keyword

        # Determine parameter default.
        if "default_factory" in field_args:
            default = ExprCall(function=field_args["default_factory"], arguments=[])
        elif "default" in field_args:
            default = field_args["default"]
        elif _is_dataclass_field(member):
            # A `field(...)` call without `default`/`default_factory` (e.g. a bare
            # `field()`, or arguments that could not be resolved): do not render
            # the call itself as the default value.
            default = None
        else:
            # Plain assignment: the assigned value is the default.
            default = member.value

        # Add parameter to the list.
        parameters.append(
            Parameter(
                member.name,
                annotation=member.annotation,
                kind=kind,
                default=default,
                docstring=member.docstring,
            ),
        )

    return parameters
def _reorder_parameters(parameters: list[Parameter]) -> list[Parameter]:
    """De-duplicate parameters by name, then group them by kind.

    When several parameters share a name, the last one wins but takes the
    position of the first one (dict insertion order). The result lists
    positional-only parameters first, keyword-only ones last, and everything
    else in between, keeping the relative order within each group.

    Parameters:
        parameters: Parameters gathered from parent dataclasses and the class itself.

    Returns:
        A new, re-ordered list of parameters.
    """
    deduplicated = list({param.name: param for param in parameters}.values())
    leading = [p for p in deduplicated if p.kind is ParameterKind.positional_only]
    trailing = [p for p in deduplicated if p.kind is ParameterKind.keyword_only]
    middle = [
        p
        for p in deduplicated
        if p.kind is not ParameterKind.positional_only and p.kind is not ParameterKind.keyword_only
    ]
    return leading + middle + trailing
def _set_dataclass_init(class_: Class) -> None:
    """Attach a synthesized `__init__` method to a dataclass.

    Every parent decorated with `@dataclass`, visited in reversed MRO order,
    contributes its parameters and marks `class_` with the `dataclass` label.
    If `class_` itself is decorated, its own parameters are appended and an
    `__init__` method (with `self` first) is set on it.

    Parameters:
        class_: The class to process.
    """
    try:
        mro = class_.mro()
    except ValueError:
        # The MRO could not be computed: parent dataclasses are ignored.
        mro = ()  # type: ignore[assignment]

    inherited: list[Parameter] = []
    for ancestor in reversed(mro):
        if not _dataclass_decorator(ancestor.decorators):
            continue
        inherited.extend(_dataclass_parameters(ancestor))
        # Like `dataclasses.is_dataclass`, having a dataclass parent is
        # enough for the class to count as a dataclass.
        class_.labels.add("dataclass")

    # Without its own `@dataclass` decorator, no `__init__` is generated.
    if not _dataclass_decorator(class_.decorators):
        return

    own = _dataclass_parameters(class_)
    self_param = Parameter(name="self", annotation=None, kind=ParameterKind.positional_or_keyword, default=None)
    init = Function(
        "__init__",
        lineno=0,
        endlineno=0,
        parent=class_,
        parameters=Parameters(self_param, *_reorder_parameters([*inherited, *own])),
        returns="None",
    )
    class_.set_member("__init__", init)
def _del_members_annotated_as_initvar(class_: Class) -> None:
    """Remove attributes annotated with `dataclasses.InitVar` from a class.

    `InitVar` definitions only feed `__init__`; they are not class members.

    Parameters:
        class_: The class to clean up (modified in place).
    """
    # Collect names first: members cannot be deleted while iterating over them.
    initvar_names = [
        member.name
        for member in class_.members.values()
        if isinstance(member, Attribute)
        and isinstance(member.annotation, Expr)
        and member.annotation.canonical_path == "dataclasses.InitVar"
    ]
    for name in initvar_names:
        class_.del_member(name)
def _apply_recursively(mod_cls: Module | Class, processed: set[str]) -> None:
    """Walk modules and classes depth-first, applying dataclass support.

    Classes without an explicit `__init__` get one synthesized, and `InitVar`
    attributes are removed from every visited class.

    Parameters:
        mod_cls: The module or class to process.
        processed: Canonical paths already visited; updated in place so that
            no object is processed twice.
    """
    path = mod_cls.canonical_path
    if path in processed:
        return
    processed.add(path)

    if isinstance(mod_cls, Class):
        if "__init__" not in mod_cls.members:
            _set_dataclass_init(mod_cls)
        _del_members_annotated_as_initvar(mod_cls)
        for member in mod_cls.members.values():
            # Check `is_alias` first: aliases are skipped before any other
            # property of theirs is accessed.
            if member.is_alias or not member.is_class:
                continue
            _apply_recursively(member, processed)  # type: ignore[arg-type]
    elif isinstance(mod_cls, Module):
        for member in mod_cls.members.values():
            if member.is_alias:
                continue
            if member.is_module or member.is_class:
                _apply_recursively(member, processed)  # type: ignore[arg-type]
class DataclassesExtension(Extension):
    """Built-in extension adding support for dataclasses.

    Once a package is loaded, every class decorated with `@dataclass` that
    does not define `__init__` itself receives a synthesized `__init__`
    method built from its fields (and those of its parent dataclasses).
    """

    def on_package_loaded(self, *, pkg: Module, **kwargs: Any) -> None:  # noqa: ARG002
        """Hook for loaded packages.

        Parameters:
            pkg: The loaded package.
        """
        visited: set[str] = set()
        _apply_recursively(pkg, visited)