Coverage for src/_griffe/extensions/dataclasses.py: 95.65%
118 statements
« prev ^ index » next coverage.py v7.6.2, created at 2024-10-12 01:34 +0200
« prev ^ index » next coverage.py v7.6.2, created at 2024-10-12 01:34 +0200
1# Built-in extension adding support for dataclasses.
2#
3# This extension re-creates `__init__` methods of dataclasses
4# during static analysis.
6from __future__ import annotations
8import ast
9from contextlib import suppress
10from functools import cache
11from typing import Any, cast
13from _griffe.enumerations import ParameterKind
14from _griffe.expressions import (
15 Expr,
16 ExprAttribute,
17 ExprCall,
18 ExprDict,
19)
20from _griffe.extensions.base import Extension
21from _griffe.logger import logger
22from _griffe.models import Attribute, Class, Decorator, Function, Module, Parameter, Parameters
25def _dataclass_decorator(decorators: list[Decorator]) -> Expr | None:
26 for decorator in decorators:
27 if isinstance(decorator.value, Expr) and decorator.value.canonical_path == "dataclasses.dataclass":
28 return decorator.value
29 return None
def _expr_args(expr: Expr) -> dict[str, str | Expr]:
    """Collect keyword arguments from a call expression or a dict literal.

    For a call, named arguments are read directly. An argument without a name
    (presumably a `**variable` unpacking) is resolved best effort: the variable
    is looked up in the modules collection and its value is processed
    recursively; any failure during that lookup is ignored.

    For a dict literal, each key is evaluated with `ast.literal_eval`. Keys that
    cannot be evaluated statically (e.g. a variable name) and `None` keys are
    skipped individually instead of discarding the whole dict.

    Parameters:
        expr: The expression to inspect.

    Returns:
        A mapping from argument name to its value (a string or an expression).
        Empty for any other kind of expression.
    """
    args: dict[str, str | Expr] = {}
    if isinstance(expr, ExprCall):
        for argument in expr.arguments:
            try:
                args[argument.name] = argument.value  # type: ignore[union-attr]
            except AttributeError:
                # Argument is an unpacked variable: resolve it, best effort.
                with suppress(Exception):
                    collection = expr.function.parent.modules_collection  # type: ignore[attr-defined]
                    var = collection[argument.value.canonical_path]  # type: ignore[union-attr]
                    args.update(_expr_args(var.value))
    elif isinstance(expr, ExprDict):
        for key, value in zip(expr.keys, expr.values):
            if key is None:
                # No key to name the entry (presumably a `**mapping` entry): skip it.
                continue
            try:
                name = ast.literal_eval(str(key))
            except (ValueError, TypeError, SyntaxError):
                # Not a literal: the key cannot be known statically, skip only this entry.
                continue
            args[name] = value
    return args
def _dataclass_arguments(decorators: list[Decorator]) -> dict[str, Any]:
    """Return the arguments passed to the `@dataclass(...)` decorator.

    Parameters:
        decorators: The decorators of a class.

    Returns:
        The decorator's arguments when it is used as a call, otherwise
        (no dataclass decorator, or a bare `@dataclass`) an empty dict.
    """
    decorator_expr = _dataclass_decorator(decorators)
    if not decorator_expr or not isinstance(decorator_expr, ExprCall):
        return {}
    return _expr_args(decorator_expr)
55def _field_arguments(attribute: Attribute) -> dict[str, Any]:
56 if attribute.value:
57 value = attribute.value
58 if isinstance(value, ExprAttribute):
59 value = value.last
60 if isinstance(value, ExprCall) and value.canonical_path == "dataclasses.field":
61 return _expr_args(value)
62 return {}
@cache
def _dataclass_parameters(class_: Class) -> list[Parameter]:
    """Build the `__init__` parameters declared directly by one dataclass.

    Only the members of `class_` itself are inspected; the caller combines the
    parameters of every dataclass in the MRO. Results are memoized per class
    object by `functools.cache`, so later calls for the same class return the
    list computed by the first call.

    Parameters:
        class_: A class decorated with `@dataclass`.

    Returns:
        One parameter per annotated instance attribute taking part in
        `__init__`, in member order. Empty when the decorator passes `init=False`.
    """
    # Fetch `@dataclass` arguments if any.
    dec_args = _dataclass_arguments(class_.decorators)

    # Parameters not added to `__init__`, return empty list.
    if dec_args.get("init") == "False":
        return []

    # Class-wide keyword-only flag: set by `kw_only=True` on the decorator,
    # or switched on by a `KW_ONLY` sentinel met while iterating.
    kw_only = dec_args.get("kw_only") == "True"

    parameters = []
    for member in class_.members.values():
        if not member.is_attribute:
            continue
        member = cast(Attribute, member)

        # All dataclass parameters have annotations.
        if member.annotation is None:
            continue

        # Attributes that have labels for these characteristics are
        # not class parameters:
        # - @property
        # - @cached_property
        # - ClassVar annotation
        if "property" in member.labels or (
            # TODO: It is better to explicitly check for ClassVar, but
            # Visitor.handle_attribute unwraps it from the annotation.
            # Maybe create internal_labels and store classvar in there.
            "class-attribute" in member.labels and "instance-attribute" not in member.labels
        ):
            continue

        # Start of keyword-only parameters.
        if isinstance(member.annotation, Expr) and member.annotation.canonical_path == "dataclasses.KW_ONLY":
            kw_only = True
            continue

        # Fetch `field` arguments if any. Also remember whether the value is a
        # `field(...)` call at all: a bare `field()` yields no arguments, just
        # like a plain default value does, but must not be used as a default.
        field_args = _field_arguments(member)
        value = member.value.last if isinstance(member.value, ExprAttribute) else member.value
        is_field_call = isinstance(value, ExprCall) and value.canonical_path == "dataclasses.field"

        # Parameter not added to `__init__`, skip it.
        if field_args.get("init") == "False":
            continue

        # An explicit `field(kw_only=...)` overrides the class-wide setting in
        # both directions (as in CPython's dataclasses); otherwise the
        # class-wide setting applies.
        field_kw_only = field_args.get("kw_only")
        if field_kw_only == "True":
            is_kw_only = True
        elif field_kw_only == "False":
            is_kw_only = False
        else:
            is_kw_only = kw_only
        kind = ParameterKind.keyword_only if is_kw_only else ParameterKind.positional_or_keyword

        # Determine parameter default.
        if "default_factory" in field_args:
            default = ExprCall(function=field_args["default_factory"], arguments=[])
        elif is_field_call:
            # `field(...)` without `default=` (including a bare `field()`)
            # leaves the parameter without a default.
            default = field_args.get("default")
        else:
            default = member.value

        parameters.append(
            Parameter(
                member.name,
                annotation=member.annotation,
                kind=kind,
                default=default,
                docstring=member.docstring,
            ),
        )

    return parameters
139def _reorder_parameters(parameters: list[Parameter]) -> list[Parameter]:
140 # De-duplicate, overwriting previous parameters.
141 params_dict = {param.name: param for param in parameters}
143 # Re-order, putting positional-only in front and keyword-only at the end.
144 pos_only = []
145 pos_kw = []
146 kw_only = []
147 for param in params_dict.values():
148 if param.kind is ParameterKind.positional_only: 148 ↛ 149line 148 didn't jump to line 149 because the condition on line 148 was never true
149 pos_only.append(param)
150 elif param.kind is ParameterKind.keyword_only:
151 kw_only.append(param)
152 else:
153 pos_kw.append(param)
154 return pos_only + pos_kw + kw_only
def _set_dataclass_init(class_: Class) -> None:
    """Attach a synthesized `__init__` method to a dataclass.

    Parameters of every `@dataclass`-decorated class in the MRO are collected,
    most distant ancestor first. Having such a parent labels the class as
    "dataclass", like `dataclasses.is_dataclass` would report. If the MRO cannot
    be computed (`ValueError`), parents are ignored. When `class_` itself is not
    decorated, nothing else happens. Otherwise an `__init__` taking `self` plus
    the deduplicated, reordered parameters is set as a member of `class_`.

    Parameters:
        class_: The class to process.
    """
    try:
        ancestors = class_.mro()
    except ValueError:
        ancestors = ()  # type: ignore[assignment]

    inherited: list[Parameter] = []
    for ancestor in reversed(ancestors):
        if not _dataclass_decorator(ancestor.decorators):
            continue
        inherited.extend(_dataclass_parameters(ancestor))
        # At least one parent dataclass makes the current class a dataclass.
        class_.labels.add("dataclass")

    if not _dataclass_decorator(class_.decorators):
        return

    logger.debug("Handling dataclass: %s", class_.path)

    all_params = [*inherited, *_dataclass_parameters(class_)]
    self_param = Parameter(name="self", annotation=None, kind=ParameterKind.positional_or_keyword, default=None)
    init_method = Function(
        "__init__",
        lineno=0,
        endlineno=0,
        parent=class_,
        parameters=Parameters(self_param, *_reorder_parameters(all_params)),
        returns="None",
    )
    class_.set_member("__init__", init_method)
195def _del_members_annotated_as_initvar(class_: Class) -> None:
196 # Definitions annotated as InitVar are not class members
197 attributes = [member for member in class_.members.values() if isinstance(member, Attribute)]
198 for attribute in attributes:
199 if isinstance(attribute.annotation, Expr) and attribute.annotation.canonical_path == "dataclasses.InitVar":
200 class_.del_member(attribute.name)
203def _apply_recursively(mod_cls: Module | Class, processed: set[str]) -> None:
204 if mod_cls.canonical_path in processed: 204 ↛ 205line 204 didn't jump to line 205 because the condition on line 204 was never true
205 return
206 processed.add(mod_cls.canonical_path)
207 if isinstance(mod_cls, Class):
208 if "__init__" not in mod_cls.members:
209 _set_dataclass_init(mod_cls)
210 _del_members_annotated_as_initvar(mod_cls)
211 for member in mod_cls.members.values():
212 if not member.is_alias and member.is_class:
213 _apply_recursively(member, processed) # type: ignore[arg-type]
214 elif isinstance(mod_cls, Module): 214 ↛ exitline 214 didn't return from function '_apply_recursively' because the condition on line 214 was always true
215 for member in mod_cls.members.values():
216 if not member.is_alias and (member.is_module or member.is_class):
217 _apply_recursively(member, processed) # type: ignore[arg-type]
class DataclassesExtension(Extension):
    """Built-in extension adding support for dataclasses.

    This extension creates the `__init__` method of each dataclass
    when the class does not already define one.
    """

    def on_package_loaded(self, *, pkg: Module, **kwargs: Any) -> None:  # noqa: ARG002
        """Hook for loaded packages.

        Parameters:
            pkg: The loaded package.
        """
        visited: set[str] = set()
        _apply_recursively(pkg, visited)