Coverage for tests/test_models.py: 100.00%
189 statements
« prev ^ index » next coverage.py v7.6.2, created at 2024-10-12 01:34 +0200
« prev ^ index » next coverage.py v7.6.2, created at 2024-10-12 01:34 +0200
1"""Tests for the `dataclasses` module."""
3from __future__ import annotations
5from copy import deepcopy
6from textwrap import dedent
8import pytest
10from griffe import (
11 Attribute,
12 Docstring,
13 GriffeLoader,
14 Module,
15 NameResolutionError,
16 module_vtree,
17 temporary_inspected_module,
18 temporary_pypackage,
19 temporary_visited_module,
20 temporary_visited_package,
21)
def test_submodule_exports() -> None:
    """Verify wildcard exposure of a submodule and of a private attribute.

    A submodule member is not wildcard-exposed until its parent imports it,
    and an underscore-prefixed attribute is not wildcard-exposed until the
    parent lists it in its exports.
    """
    parent = Module("root")
    child_module = Module("sub")
    hidden_attr = Attribute("_private")
    parent["sub"] = child_module
    parent["_private"] = hidden_attr

    # The submodule becomes exposed once the parent records an import of it.
    assert not child_module.is_wildcard_exposed
    parent.imports["sub"] = "root.sub"
    assert child_module.is_wildcard_exposed

    # The private attribute becomes exposed once it is explicitly exported.
    assert not hidden_attr.is_wildcard_exposed
    parent.exports = {"_private"}
    assert hidden_attr.is_wildcard_exposed
def test_has_docstrings() -> None:
    """Check that `.has_docstrings` looks inside nested members (it is recursive)."""
    # Only the class carries a docstring; the module itself has none.
    source = "class A:\n    '''Hello.'''"
    with temporary_visited_module(source) as visited:
        assert visited.has_docstrings
def test_has_docstrings_submodules() -> None:
    """Check that `.has_docstrings` descends through a chain of submodules."""
    root = module_vtree("a.b.c.d")
    # Only the deepest submodule is given a docstring.
    root["b.c.d"].docstring = Docstring("Hello.")
    assert root.has_docstrings
def test_handle_aliases_chain_in_has_docstrings() -> None:
    """Check that `.has_docstrings` copes with chains of aliases among members.

    `mod_a` re-imports `someobj` from `mod_b`, which imports it from the
    external name `somelib` (presumably not installed), forming an alias chain.
    """
    with temporary_pypackage("package", ["mod_a.py", "mod_b.py"]) as tmp_package:
        module_sources = {
            "mod_a.py": "from .mod_b import someobj",
            "mod_b.py": "from somelib import someobj",
        }
        for filename, text in module_sources.items():
            (tmp_package.path / filename).write_text(text)

        loader = GriffeLoader(search_paths=[tmp_package.tmpdir])
        package = loader.load(tmp_package.name)
        assert not package.has_docstrings
        # Resolving aliases (implicit ones included) must not change the answer.
        loader.resolve_aliases(implicit=True)
        assert not package.has_docstrings
def test_has_docstrings_does_not_trigger_alias_resolution() -> None:
    """Check that evaluating `.has_docstrings` leaves aliases unresolved."""
    with temporary_pypackage("package", ["mod_a.py", "mod_b.py"]) as tmp_package:
        package_dir = tmp_package.path
        (package_dir / "mod_a.py").write_text("from .mod_b import someobj")
        (package_dir / "mod_b.py").write_text("from somelib import someobj")

        loader = GriffeLoader(search_paths=[tmp_package.tmpdir])
        package = loader.load(tmp_package.name)
        assert not package.has_docstrings
        # The alias in `mod_a` must still be unresolved after the check above.
        assert not package["mod_a.someobj"].resolved
def test_deepcopy() -> None:
    """Check that a loaded object tree, and its dict form, can be deep-copied."""
    loader = GriffeLoader()
    griffe_module = loader.load("griffe")
    # The object tree itself must survive `deepcopy`...
    deepcopy(griffe_module)
    # ...and so must its dictionary serialization.
    serialized = griffe_module.as_dict()
    deepcopy(serialized)
def test_dataclass_properties_and_class_variables() -> None:
    """Don't return properties or class variables as parameters of dataclasses.

    Of the names defined in `Point`, only `x`, `y` and `s` are dataclass
    fields: `r` and `t` are `ClassVar` annotations, and `a` / `b` are
    (cached) properties, so none of those may show up in `__init__`.
    """
    code = """
    from dataclasses import dataclass
    from functools import cached_property
    from typing import ClassVar

    @dataclass
    class Point:
        x: float
        y: float

        # `r` and `t` are class variables; `s` is a regular field with a default
        r: ClassVar[float]
        s: float = 3
        t: ClassVar[float] = 3

        @property
        def a(self):
            return 0

        @cached_property
        def b(self):
            return 0
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        params = module["Point"].parameters
        assert [p.name for p in params] == ["self", "x", "y", "s"]
@pytest.mark.parametrize(
    "code",
    [
        """
        @dataclass
        class Dataclass:
            x: float
            y: float = field(kw_only=True)

        class Class:
            def __init__(self, x: float, *, y: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            x: float = field(kw_only=True)
            y: float

        class Class:
            def __init__(self, y: float, *, x: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            x: float
            _: KW_ONLY
            y: float

        class Class:
            def __init__(self, x: float, *, y: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            _: KW_ONLY
            x: float
            y: float

        class Class:
            def __init__(self, *, x: float, y: float): ...
        """,
        """
        @dataclass(kw_only=True)
        class Dataclass:
            x: float
            y: float

        class Class:
            def __init__(self, *, x: float, y: float): ...
        """,
    ],
)
def test_dataclass_parameter_kinds(code: str) -> None:
    """Check dataclass and equivalent non-dataclass parameters.

    Each parameter generated for `Dataclass` must equal the parameter at the
    same position in the hand-written `Class.__init__`, and both signatures
    must have the same parameters (so the kinds for each pair are the same).

    Parameters:
        code: Python code to visit.
    """
    code = f"from dataclasses import dataclass, field, KW_ONLY\n\n{dedent(code)}"
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        dataclass_params = list(module["Dataclass"].parameters)
        regular_params = list(module["Class"].parameters)
        # `zip` silently stops at the shorter sequence: compare the name lists
        # first so a missing, extra or misplaced parameter cannot go unnoticed.
        assert [p.name for p in dataclass_params] == [p.name for p in regular_params]
        for dataclass_param, regular_param in zip(dataclass_params, regular_params):
            assert dataclass_param == regular_param
def test_regular_class_inheriting_dataclass_dont_get_its_own_params() -> None:
    """A plain subclass of a dataclass doesn't add its own attributes to `__init__`.

    `Derived1` is itself a dataclass, so its field `c` is appended to the
    inherited parameters. `Derived2` is not decorated: it keeps the inherited
    `__init__` parameters, and its own annotation `d` is not added.
    """
    code = """
    from dataclasses import dataclass

    @dataclass
    class Base:
        a: int
        b: str

    @dataclass
    class Derived1(Base):
        c: float

    class Derived2(Base):
        d: float
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        params1 = module["Derived1"].parameters
        params2 = module["Derived2"].parameters
        assert [p.name for p in params1] == ["self", "a", "b", "c"]
        assert [p.name for p in params2] == ["self", "a", "b"]
def test_regular_class_inheriting_dataclass_is_labelled_dataclass() -> None:
    """Check that an undecorated subclass of a dataclass also gets the `dataclass` label."""
    source = """
    from dataclasses import dataclass

    @dataclass
    class Base:
        pass

    class Derived(Base):
        pass
    """
    with temporary_visited_package("package", {"__init__.py": source}) as package:
        assert "dataclass" in package["Derived"].labels
def test_fields_with_init_false() -> None:
    """Check that fields are left out of `__init__` when `init=False` applies.

    `init=False` can be set on a single field (`field(init=False)`) or on the
    whole class (`@dataclass(init=False)`); in the class-wide case a per-field
    `init=True` does not bring the field back.
    """
    code = """
    from dataclasses import dataclass, field

    @dataclass
    class PointA:
        x: float
        y: float
        z: float = field(init=False)

    @dataclass(init=False)
    class PointB:
        x: float
        y: float

    @dataclass(init=False)
    class PointC:
        x: float
        y: float = field(init=True)  # init=True has no effect
    """
    # Class name -> field names that must not appear among its parameters.
    excluded = {
        "PointA": ("z",),
        "PointB": ("x", "y"),
        "PointC": ("x", "y"),
    }
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        parameters = {class_name: module[class_name].parameters for class_name in excluded}
        for class_name, field_names in excluded.items():
            for field_name in field_names:
                assert field_name not in parameters[class_name]
def test_parameters_are_reorderd_to_match_their_kind() -> None:
    """Check that inherited keyword-only fields move to the end of `__init__`.

    `Reordered` redefines `b` (now annotated `float`) in a dataclass without
    `kw_only`, so only `a` remains keyword-only and is placed last.
    """
    code = """
    from dataclasses import dataclass

    @dataclass(kw_only=True)
    class Base:
        a: int
        b: str

    @dataclass
    class Reordered(Base):
        b: float
        c: float
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        base_parameters = module["Base"].parameters
        reordered_parameters = module["Reordered"].parameters
        base_names = [param.name for param in base_parameters]
        reordered_names = [param.name for param in reordered_parameters]
        assert base_names == ["self", "a", "b"]
        assert reordered_names == ["self", "b", "c", "a"]
        # The redefinition in the subclass wins over the base annotation.
        assert str(reordered_parameters["b"].annotation) == "float"
def test_parameters_annotated_as_initvar() -> None:
    """Check that `InitVar` fields become parameters but not class members.

    When the dataclass defines its own `__init__`, `InitVar` has no effect:
    the explicit signature is used and the annotation stays a member.
    """
    code = """
    from dataclasses import dataclass, InitVar

    @dataclass
    class PointA:
        x: float
        y: float
        z: InitVar[float]

    @dataclass
    class PointB:
        x: float
        y: float
        z: InitVar[float]

        def __init__(self, r: float): ...
    """

    expectations = (
        # (class name, expected parameter names, expected member names)
        ("PointA", ["self", "x", "y", "z"], ["x", "y", "__init__"]),
        ("PointB", ["self", "r"], ["x", "y", "z", "__init__"]),
    )
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        for class_name, expected_params, expected_members in expectations:
            point = module[class_name]
            assert [p.name for p in point.parameters] == expected_params
            assert list(point.members) == expected_members
def test_visited_module_source() -> None:
    """Check that a visited module exposes its complete source text."""
    expected_source = "print('hello')\nprint('world')"
    with temporary_visited_package("package", {"__init__.py": expected_source}) as package:
        assert package.source == expected_source
def test_visited_class_source() -> None:
    """Check that a visited class exposes exactly its own source lines."""
    code = """
    class A:
        def __init__(self, x: int):
            self.x = x
    """
    expected_source = dedent(code).strip()
    with temporary_visited_package("package", {"__init__.py": code}) as package:
        assert package["A"].source == expected_source
def test_visited_object_source_with_missing_line_number() -> None:
    """Check that a visited object's `.source` is empty when a line number is missing."""
    code = """
    class A:
        def __init__(self, x: int):
            self.x = x
    """
    with temporary_visited_package("package", {"__init__.py": code}) as package:
        class_a = package["A"]
        # Missing end line number: no source can be extracted.
        class_a.endlineno = None
        assert not class_a.source
        # Missing start line number (end line restored): still no source.
        class_a.endlineno = 3
        class_a.lineno = None
        assert not class_a.source
def test_inspected_module_source() -> None:
    """Check that an inspected module exposes its complete source text."""
    expected_source = "print('hello')\nprint('world')"
    with temporary_inspected_module(expected_source) as inspected:
        assert inspected.source == expected_source
def test_inspected_class_source() -> None:
    """Check that an inspected class exposes exactly its own source lines."""
    code = """
    class A:
        def __init__(self, x: int):
            self.x = x
    """
    expected_source = dedent(code).strip()
    with temporary_inspected_module(code) as inspected:
        assert inspected["A"].source == expected_source
def test_inspected_object_source_with_missing_line_number() -> None:
    """Check that an inspected object's `.source` is empty when a line number is missing."""
    code = """
    class A:
        def __init__(self, x: int):
            self.x = x
    """
    with temporary_inspected_module(code) as inspected:
        class_a = inspected["A"]
        # Missing end line number: no source can be extracted.
        class_a.endlineno = None
        assert not class_a.source
        # Missing start line number (end line restored): still no source.
        class_a.endlineno = 3
        class_a.lineno = None
        assert not class_a.source
def test_dataclass_parameter_docstrings() -> None:
    """Check that dataclass parameters carry the docstrings of their fields.

    Docstrings are inherited along with the fields (`Derived` gets those of
    `a` and `b`), and parameters without a documented field have none.
    """
    code = """
    from dataclasses import dataclass, InitVar

    @dataclass
    class Base:
        a: int
        "Parameter a"
        b: InitVar[int] = 3
        "Parameter b"

    @dataclass
    class Derived(Base):
        c: float
        d: InitVar[float]
        "Parameter d"
    """

    # Expected docstring value of each parameter, by position (None: no docstring).
    expectations = {
        "Base": [None, "Parameter a", "Parameter b"],
        "Derived": [None, "Parameter a", "Parameter b", None, "Parameter d"],
    }
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        for class_name, expected_docstrings in expectations.items():
            parameters = module[class_name].parameters
            for position, expected in enumerate(expected_docstrings):
                docstring = parameters[position].docstring
                if expected is None:
                    assert docstring is None
                else:
                    assert docstring.value == expected
def test_attributes_that_have_no_annotations() -> None:
    """Dataclass attributes that have no annotations are not parameters.

    An unannotated assignment in a dataclass body is a plain class attribute,
    not a field. Re-assigning an inherited field without an annotation does
    not change whether that field is a parameter.
    """
    code = """
    from dataclasses import dataclass, field

    @dataclass
    class Base:
        a: int
        b: str = field(init=False)
        c = 3  # class attribute

    @dataclass
    class Derived(Base):
        a = 1  # no effect on the parameter status of a
        b = "b"  # inherited non-parameter
        d: float = 4
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        base_params = [p.name for p in module["Base"].parameters]
        derived_params = [p.name for p in module["Derived"].parameters]
        assert base_params == ["self", "a"]
        assert derived_params == ["self", "a", "d"]
449def test_name_resolution() -> None:
450 """Name are correctly resolved in the scope of an object."""
451 code = """
452 module_attribute = 0
454 class Class:
455 import imported
457 class_attribute = 0
459 def __init__(self):
460 self.instance_attribute = 0
462 def method(self):
463 local_variable = 0
464 """
465 with temporary_visited_module(code) as module:
466 assert module.resolve("module_attribute") == "module.module_attribute"
467 assert module.resolve("Class") == "module.Class"
469 assert module["module_attribute"].resolve("Class") == "module.Class"
470 with pytest.raises(NameResolutionError):
471 module["module_attribute"].resolve("class_attribute")
473 assert module["Class"].resolve("module_attribute") == "module.module_attribute"
474 assert module["Class"].resolve("imported") == "imported"
475 assert module["Class"].resolve("class_attribute") == "module.Class.class_attribute"
476 assert module["Class"].resolve("instance_attribute") == "module.Class.instance_attribute"
477 assert module["Class"].resolve("method") == "module.Class.method"
479 assert module["Class.class_attribute"].resolve("module_attribute") == "module.module_attribute"
480 assert module["Class.class_attribute"].resolve("Class") == "module.Class"
481 assert module["Class.class_attribute"].resolve("imported") == "imported"
482 assert module["Class.class_attribute"].resolve("instance_attribute") == "module.Class.instance_attribute"
483 assert module["Class.class_attribute"].resolve("method") == "module.Class.method"
485 assert module["Class.instance_attribute"].resolve("module_attribute") == "module.module_attribute"
486 assert module["Class.instance_attribute"].resolve("Class") == "module.Class"
487 assert module["Class.instance_attribute"].resolve("imported") == "imported"
488 assert module["Class.instance_attribute"].resolve("class_attribute") == "module.Class.class_attribute"
489 assert module["Class.instance_attribute"].resolve("method") == "module.Class.method"
491 assert module["Class.method"].resolve("module_attribute") == "module.module_attribute"
492 assert module["Class.method"].resolve("Class") == "module.Class"
493 assert module["Class.method"].resolve("imported") == "imported"
494 assert module["Class.method"].resolve("class_attribute") == "module.Class.class_attribute"
495 assert module["Class.method"].resolve("instance_attribute") == "module.Class.instance_attribute"