Coverage for tests/test_models.py: 100.00%
161 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-15 16:47 +0200
« prev ^ index » next coverage.py v7.6.1, created at 2024-08-15 16:47 +0200
"""Tests for the data models (modules, attributes, docstrings, object sources) and dataclass parameter handling."""
3from __future__ import annotations
5from copy import deepcopy
6from textwrap import dedent
8import pytest
10from griffe import (
11 Attribute,
12 Docstring,
13 GriffeLoader,
14 Module,
15 module_vtree,
16 temporary_inspected_module,
17 temporary_pypackage,
18 temporary_visited_module,
19 temporary_visited_package,
20)
def test_submodule_exports() -> None:
    """A submodule is wildcard-exposed once its parent imports it; a private member once it is exported."""
    parent = Module("root")
    submodule = Module("sub")
    private_attr = Attribute("_private")
    parent["sub"] = submodule
    parent["_private"] = private_attr

    # Being a member is not enough: the parent must also import the submodule.
    assert not submodule.is_wildcard_exposed
    parent.imports["sub"] = "root.sub"
    assert submodule.is_wildcard_exposed

    # The private attribute stays hidden until the parent lists it in `exports`.
    assert not private_attr.is_wildcard_exposed
    parent.exports = {"_private"}
    assert private_attr.is_wildcard_exposed
def test_has_docstrings() -> None:
    """A module whose only docstring sits on a nested class still reports `has_docstrings`."""
    source_code = "class A:\n '''Hello.'''"
    with temporary_visited_module(source_code) as visited:
        assert visited.has_docstrings
def test_has_docstrings_submodules() -> None:
    """A docstring set only on a deeply nested submodule makes the root report `has_docstrings`."""
    root = module_vtree("a.b.c.d")
    deepest = root["b.c.d"]
    deepest.docstring = Docstring("Hello.")
    assert root.has_docstrings
def test_handle_aliases_chain_in_has_docstrings() -> None:
    """Check that `.has_docstrings` copes with members that are chains of aliases.

    `mod_a` re-imports `someobj` from `mod_b`, which imports it from `somelib`,
    a module not provided by the temporary package. The property must stay
    falsy both before and after implicit alias resolution.
    """
    with temporary_pypackage("package", ["mod_a.py", "mod_b.py"]) as tmp_package:
        sources = (
            ("mod_a.py", "from .mod_b import someobj"),
            ("mod_b.py", "from somelib import someobj"),
        )
        for filename, text in sources:
            (tmp_package.path / filename).write_text(text)

        loader = GriffeLoader(search_paths=[tmp_package.tmpdir])
        package = loader.load(tmp_package.name)
        assert not package.has_docstrings
        loader.resolve_aliases(implicit=True)
        assert not package.has_docstrings
def test_has_docstrings_does_not_trigger_alias_resolution() -> None:
    """Check that computing `.has_docstrings` leaves aliases unresolved."""
    with temporary_pypackage("package", ["mod_a.py", "mod_b.py"]) as tmp_package:
        pkg_dir = tmp_package.path
        (pkg_dir / "mod_a.py").write_text("from .mod_b import someobj")
        (pkg_dir / "mod_b.py").write_text("from somelib import someobj")

        loader = GriffeLoader(search_paths=[tmp_package.tmpdir])
        package = loader.load(tmp_package.name)
        assert not package.has_docstrings
        # The alias must still be unresolved after the docstring check above.
        alias = package["mod_a.someobj"]
        assert not alias.resolved
def test_deepcopy() -> None:
    """Deep-copying a loaded object tree, and its dictionary form, must not raise."""
    loader = GriffeLoader()
    griffe_module = loader.load("griffe")

    deepcopy(griffe_module)
    serialized = griffe_module.as_dict()
    deepcopy(serialized)
def test_dataclass_properties_and_class_variables() -> None:
    """Properties, cached properties and `ClassVar` fields are excluded from dataclass parameters.

    Annotated fields with a default value (`s` below) are still parameters.
    """
    code = """
        from dataclasses import dataclass
        from functools import cached_property
        from typing import ClassVar

        @dataclass
        class Point:
            x: float
            y: float

            # These definitions create class variables
            r: ClassVar[float]
            s: float = 3
            t: ClassVar[float] = 3

            @property
            def a(self):
                return 0

            @cached_property
            def b(self):
                return 0
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        point = module["Point"]
        param_names = [param.name for param in point.parameters]
        assert param_names == ["self", "x", "y", "s"]
@pytest.mark.parametrize(
    "code",
    [
        """
        @dataclass
        class Dataclass:
            x: float
            y: float = field(kw_only=True)

        class Class:
            def __init__(self, x: float, *, y: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            x: float = field(kw_only=True)
            y: float

        class Class:
            def __init__(self, y: float, *, x: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            x: float
            _: KW_ONLY
            y: float

        class Class:
            def __init__(self, x: float, *, y: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            _: KW_ONLY
            x: float
            y: float

        class Class:
            def __init__(self, *, x: float, y: float): ...
        """,
        """
        @dataclass(kw_only=True)
        class Dataclass:
            x: float
            y: float

        class Class:
            def __init__(self, *, x: float, y: float): ...
        """,
    ],
)
def test_dataclass_parameter_kinds(code: str) -> None:
    """Check dataclass and equivalent non-dataclass parameters.

    The parameters of each pair should be the same: same count, same order,
    and, parameter by parameter, equal (including their kind).

    Parameters:
        code: Python code to visit.
    """
    code = f"from dataclasses import dataclass, field, KW_ONLY\n\n{dedent(code)}"
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        dataclass_params = list(module["Dataclass"].parameters)
        regular_params = list(module["Class"].parameters)
        # Compare whole lists rather than `zip`-ing them: `zip` stops at the
        # shorter sequence, so a missing or extra parameter would go unnoticed.
        assert dataclass_params == regular_params
def test_regular_class_inheriting_dataclass_dont_get_its_own_params() -> None:
    """A regular class inheriting from a dataclass doesn't have its own attributes added to `__init__`.

    It only gets the parameters of its dataclass base, whereas a subclass that
    is itself decorated with `@dataclass` adds its own fields after the base's.
    """
    code = """
        from dataclasses import dataclass

        @dataclass
        class Base:
            a: int
            b: str

        @dataclass
        class Derived1(Base):
            c: float

        class Derived2(Base):
            d: float
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        params1 = module["Derived1"].parameters
        params2 = module["Derived2"].parameters
        assert [p.name for p in params1] == ["self", "a", "b", "c"]
        # `d` is not a parameter: `Derived2` is not decorated with `@dataclass`.
        assert [p.name for p in params2] == ["self", "a", "b"]
def test_regular_class_inheriting_dataclass_is_labelled_dataclass() -> None:
    """An undecorated subclass of a dataclass still carries the `dataclass` label."""
    code = """
        from dataclasses import dataclass

        @dataclass
        class Base:
            pass

        class Derived(Base):
            pass
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        assert "dataclass" in module["Derived"].labels
def test_fields_with_init_false() -> None:
    """Fields marked with `init=False` are not added to the `__init__` method.

    With `@dataclass(init=False)`, no field becomes a parameter, not even one
    explicitly declared with `field(init=True)`.
    """
    code = """
        from dataclasses import dataclass, field

        @dataclass
        class PointA:
            x: float
            y: float
            z: float = field(init=False)

        @dataclass(init=False)
        class PointB:
            x: float
            y: float

        @dataclass(init=False)
        class PointC:
            x: float
            y: float = field(init=True)  # init=True has no effect
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        params_a = module["PointA"].parameters
        params_b = module["PointB"].parameters
        params_c = module["PointC"].parameters

        # Pin the full signature: only checking that `z` is absent would also
        # pass if `x` and `y` were wrongly dropped.
        assert [p.name for p in params_a] == ["self", "x", "y"]
        assert "z" not in params_a
        assert "x" not in params_b
        assert "y" not in params_b
        assert "x" not in params_c
        assert "y" not in params_c
def test_parameters_are_reorderd_to_match_their_kind() -> None:
    """Keyword-only parameters inherited from a base class are moved to the end of the signature.

    A field redefined in the subclass (`b`) takes the subclass's position and annotation.
    """
    code = """
        from dataclasses import dataclass

        @dataclass(kw_only=True)
        class Base:
            a: int
            b: str

        @dataclass
        class Reordered(Base):
            b: float
            c: float
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        base_params = module["Base"].parameters
        reordered_params = module["Reordered"].parameters
        base_names = [param.name for param in base_params]
        reordered_names = [param.name for param in reordered_params]
        assert base_names == ["self", "a", "b"]
        assert reordered_names == ["self", "b", "c", "a"]
        redefined_b = reordered_params["b"]
        assert str(redefined_b.annotation) == "float"
def test_parameters_annotated_as_initvar() -> None:
    """`InitVar` fields become parameters but not class members.

    When the class defines its own `__init__`, `InitVar` has no effect: the
    explicit signature wins and the field stays a regular member.
    """
    code = """
        from dataclasses import dataclass, InitVar

        @dataclass
        class PointA:
            x: float
            y: float
            z: InitVar[float]

        @dataclass
        class PointB:
            x: float
            y: float
            z: InitVar[float]

            def __init__(self, r: float): ...
    """

    # class name -> (expected parameter names, expected member names)
    expectations = {
        "PointA": (["self", "x", "y", "z"], ["x", "y", "__init__"]),
        "PointB": (["self", "r"], ["x", "y", "z", "__init__"]),
    }
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        for class_name, (param_names, member_names) in expectations.items():
            point = module[class_name]
            assert [param.name for param in point.parameters] == param_names
            assert list(point.members) == member_names
def test_visited_module_source() -> None:
    """A statically visited module exposes its full file contents as `source`."""
    expected = "print('hello')\nprint('world')"
    with temporary_visited_package("package", {"__init__.py": expected}) as package:
        assert package.source == expected
def test_visited_class_source() -> None:
    """A statically visited class exposes exactly its own definition lines as `source`."""
    code = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    expected = dedent(code).strip()
    with temporary_visited_package("package", {"__init__.py": code}) as package:
        assert package["A"].source == expected
def test_visited_object_source_with_missing_line_number() -> None:
    """A visited object's `source` is empty when either its start or end line number is unknown."""
    code = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    with temporary_visited_package("package", {"__init__.py": code}) as package:
        class_a = package["A"]
        # Missing end line.
        class_a.endlineno = None
        assert not class_a.source
        # End line restored, start line missing.
        class_a.endlineno = 3
        class_a.lineno = None
        assert not class_a.source
def test_inspected_module_source() -> None:
    """A dynamically inspected module exposes its full file contents as `source`."""
    expected = "print('hello')\nprint('world')"
    with temporary_inspected_module(expected) as inspected:
        assert inspected.source == expected
def test_inspected_class_source() -> None:
    """A dynamically inspected class exposes exactly its own definition lines as `source`."""
    code = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    expected = dedent(code).strip()
    with temporary_inspected_module(code) as inspected:
        assert inspected["A"].source == expected
def test_inspected_object_source_with_missing_line_number() -> None:
    """An inspected object's `source` is empty when either its start or end line number is unknown."""
    code = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    with temporary_inspected_module(code) as inspected:
        class_a = inspected["A"]
        # Missing end line.
        class_a.endlineno = None
        assert not class_a.source
        # End line restored, start line missing.
        class_a.endlineno = 3
        class_a.lineno = None
        assert not class_a.source
def test_dataclass_parameter_docstrings() -> None:
    """Dataclass parameters carry the docstring written under their field, if any.

    `self` and undocumented fields have no docstring. Inherited parameters keep
    the docstring written in the base class.
    """
    code = """
        from dataclasses import dataclass, InitVar

        @dataclass
        class Base:
            a: int
            "Parameter a"
            b: InitVar[int] = 3
            "Parameter b"

        @dataclass
        class Derived(Base):
            c: float
            d: InitVar[float]
            "Parameter d"
    """

    with temporary_visited_package("package", {"__init__.py": code}) as module:
        base_params = module["Base"].parameters
        assert base_params[0].docstring is None
        assert base_params[1].docstring.value == "Parameter a"
        assert base_params[2].docstring.value == "Parameter b"

        derived_params = module["Derived"].parameters
        assert derived_params[0].docstring is None
        assert derived_params[1].docstring.value == "Parameter a"
        assert derived_params[2].docstring.value == "Parameter b"
        assert derived_params[3].docstring is None
        assert derived_params[4].docstring.value == "Parameter d"
def test_attributes_that_have_no_annotations() -> None:
    """Dataclass attributes that have no annotations are not parameters.

    Re-assigning an inherited field without an annotation in a subclass does
    not change whether that field is a parameter.
    """
    code = """
        from dataclasses import dataclass, field

        @dataclass
        class Base:
            a: int
            b: str = field(init=False)
            c = 3  # class attribute

        @dataclass
        class Derived(Base):
            a = 1  # no effect on the parameter status of a
            b = "b"  # inherited non-parameter
            d: float = 4
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        base_params = [p.name for p in module["Base"].parameters]
        derived_params = [p.name for p in module["Derived"].parameters]
        assert base_params == ["self", "a"]
        assert derived_params == ["self", "a", "d"]