Coverage for tests/test_models.py: 100.00%
261 statements
« prev ^ index » next coverage.py v7.10.2, created at 2025-08-11 13:44 +0200
1"""Tests for the `dataclasses` module."""
3from __future__ import annotations
5import sys
6from copy import deepcopy
7from textwrap import dedent
9import pytest
11from griffe import (
12 Attribute,
13 Class,
14 Docstring,
15 Function,
16 GriffeLoader,
17 Module,
18 NameResolutionError,
19 Parameter,
20 ParameterKind,
21 Parameters,
22 TypeParameter,
23 TypeParameterKind,
24 TypeParameters,
25 module_vtree,
26 temporary_inspected_module,
27 temporary_pypackage,
28 temporary_visited_module,
29 temporary_visited_package,
30)
def test_submodule_exports() -> None:
    """Check that a module is exported depending on whether it was also imported.

    Also check that a private attribute only becomes wildcard-exposed once it is
    listed in the parent module's exports.
    """
    root = Module("root")
    submodule = Module("sub")
    private_attr = Attribute("_private")
    root["sub"] = submodule
    root["_private"] = private_attr

    # Not imported by its parent yet: hidden from wildcard imports.
    assert not submodule.is_wildcard_exposed
    # Once the parent imports it, it becomes exposed.
    root.imports["sub"] = "root.sub"
    assert submodule.is_wildcard_exposed

    # Private names stay hidden unless explicitly exported.
    assert not private_attr.is_wildcard_exposed
    root.exports = ["_private"]
    assert private_attr.is_wildcard_exposed
def test_has_docstrings() -> None:
    """Assert the `.has_docstrings` method is recursive."""
    source = "class A:\n    '''Hello.'''"
    with temporary_visited_module(source) as visited:
        # The module has no docstring of its own; only its member class does.
        assert visited.has_docstrings
def test_has_docstrings_submodules() -> None:
    """Assert the `.has_docstrings` method descends into submodules."""
    root = module_vtree("a.b.c.d")
    innermost = root["b.c.d"]
    # Only the deepest submodule receives a docstring.
    innermost.docstring = Docstring("Hello.")
    assert root.has_docstrings
def test_handle_aliases_chain_in_has_docstrings() -> None:
    """Assert the `.has_docstrings` method can handle aliases chains in members."""
    with temporary_pypackage("package", ["mod_a.py", "mod_b.py"]) as tmp_package:
        # Alias chain: `mod_a.someobj` -> `mod_b.someobj` -> `somelib.someobj`
        # (`somelib` is presumably not installed, so the chain cannot fully resolve).
        (tmp_package.path / "mod_a.py").write_text("from .mod_b import someobj")
        (tmp_package.path / "mod_b.py").write_text("from somelib import someobj")

        loader = GriffeLoader(search_paths=[tmp_package.tmpdir])
        package = loader.load(tmp_package.name)
        # Must not fail before alias resolution...
        assert not package.has_docstrings
        loader.resolve_aliases(implicit=True)
        # ...nor after it.
        assert not package.has_docstrings
def test_has_docstrings_does_not_trigger_alias_resolution() -> None:
    """Assert the `.has_docstrings` method does not trigger alias resolution."""
    with temporary_pypackage("package", ["mod_a.py", "mod_b.py"]) as tmp_package:
        file_contents = {
            "mod_a.py": "from .mod_b import someobj",
            "mod_b.py": "from somelib import someobj",
        }
        for filename, text in file_contents.items():
            (tmp_package.path / filename).write_text(text)

        loader = GriffeLoader(search_paths=[tmp_package.tmpdir])
        package = loader.load(tmp_package.name)
        assert not package.has_docstrings
        # Querying docstrings must leave the alias untouched (still unresolved).
        assert not package["mod_a.someobj"].resolved
def test_deepcopy() -> None:
    """Assert we can deep-copy object trees, and their dictionary form."""
    griffe_module = GriffeLoader().load("griffe")
    deepcopy(griffe_module)
    serialized = griffe_module.as_dict()
    deepcopy(serialized)
def test_dataclass_properties_and_class_variables() -> None:
    """Don't return properties or class variables as parameters of dataclasses."""
    code = """
        from dataclasses import dataclass
        from functools import cached_property
        from typing import ClassVar

        @dataclass
        class Point:
            x: float
            y: float

            # These definitions create class variables.
            r: ClassVar[float]
            s: float = 3
            t: ClassVar[float] = 3

            @property
            def a(self):
                return 0

            @cached_property
            def b(self):
                return 0
    """
    with temporary_visited_package("package", {"__init__.py": code}) as package:
        names = [parameter.name for parameter in package["Point"].parameters]
        # `r`/`t` are `ClassVar`s and `a`/`b` are properties: all excluded.
        # `s` is a plain annotated field with a default, so it stays.
        assert names == ["self", "x", "y", "s"]
@pytest.mark.parametrize(
    "code",
    [
        """
        @dataclass
        class Dataclass:
            x: float
            y: float = field(kw_only=True)

        class Class:
            def __init__(self, x: float, *, y: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            x: float = field(kw_only=True)
            y: float

        class Class:
            def __init__(self, y: float, *, x: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            x: float
            _: KW_ONLY
            y: float

        class Class:
            def __init__(self, x: float, *, y: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            _: KW_ONLY
            x: float
            y: float

        class Class:
            def __init__(self, *, x: float, y: float): ...
        """,
        """
        @dataclass(kw_only=True)
        class Dataclass:
            x: float
            y: float

        class Class:
            def __init__(self, *, x: float, y: float): ...
        """,
    ],
)
def test_dataclass_parameter_kinds(code: str) -> None:
    """Check dataclass and equivalent non-dataclass parameters.

    The parameters of each pair should compare equal, one by one and in order.
    The two lists are compared as a whole, so a missing or extra parameter on
    either side fails the test. Pairwise `zip` would silently ignore it, because
    `zip` stops at the shorter input.

    Parameters:
        code: Python code to visit.
    """
    code = f"from dataclasses import dataclass, field, KW_ONLY\n\n{dedent(code)}"
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        dataclass_params = list(module["Dataclass"].parameters)
        regular_params = list(module["Class"].parameters)
        assert dataclass_params == regular_params
def test_regular_class_inheriting_dataclass_dont_get_its_own_params() -> None:
    """A regular class inheriting from a dataclass doesn't get its own attributes added to `__init__`."""
    code = """
        from dataclasses import dataclass

        @dataclass
        class Base:
            a: int
            b: str

        @dataclass
        class Derived1(Base):
            c: float

        class Derived2(Base):
            d: float
    """
    expected_names = {
        # A dataclass subclass extends the inherited fields with its own.
        "Derived1": ["self", "a", "b", "c"],
        # A regular subclass keeps the inherited signature; `d` is not a parameter.
        "Derived2": ["self", "a", "b"],
    }
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        for class_name, names in expected_names.items():
            assert [param.name for param in module[class_name].parameters] == names
def test_regular_class_inheriting_dataclass_is_labelled_dataclass() -> None:
    """A regular class inheriting from a dataclass is labelled as a dataclass too."""
    code = """
        from dataclasses import dataclass

        @dataclass
        class Base:
            pass

        class Derived(Base):
            pass
    """
    with temporary_visited_package("package", {"__init__.py": code}) as package:
        derived_labels = package["Derived"].labels
        assert "dataclass" in derived_labels
def test_fields_with_init_false() -> None:
    """Fields marked with `init=False` are not added to the `__init__` method.

    With `@dataclass(init=False)`, no field is added at all, even one explicitly
    marked `init=True`.
    """
    code = """
        from dataclasses import dataclass, field

        @dataclass
        class PointA:
            x: float
            y: float
            z: float = field(init=False)

        @dataclass(init=False)
        class PointB:
            x: float
            y: float

        @dataclass(init=False)
        class PointC:
            x: float
            y: float = field(init=True)  # `init=True` has no effect.
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        params_a = module["PointA"].parameters
        params_b = module["PointB"].parameters
        params_c = module["PointC"].parameters

        assert "z" not in params_a
        for params in (params_b, params_c):
            for field_name in ("x", "y"):
                assert field_name not in params
def test_parameters_are_reorderd_to_match_their_kind() -> None:
    """Keyword-only parameters in base class are pushed back to the end of the signature."""
    code = """
        from dataclasses import dataclass

        @dataclass(kw_only=True)
        class Base:
            a: int
            b: str

        @dataclass
        class Reordered(Base):
            b: float
            c: float
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        base_names = [param.name for param in module["Base"].parameters]
        reordered = module["Reordered"].parameters

        assert base_names == ["self", "a", "b"]
        # `b` is redeclared in the subclass, so it takes the subclass position and
        # annotation; the keyword-only `a` inherited from `Base` moves to the end.
        assert [param.name for param in reordered] == ["self", "b", "c", "a"]
        assert str(reordered["b"].annotation) == "float"
def test_parameters_annotated_as_initvar() -> None:
    """`InitVar` fields are `__init__` parameters but are not returned as class members.

    When the class defines `__init__` itself, `InitVar` has no effect.
    """
    code = """
        from dataclasses import dataclass, InitVar

        @dataclass
        class PointA:
            x: float
            y: float
            z: InitVar[float]

        @dataclass
        class PointB:
            x: float
            y: float
            z: InitVar[float]

            def __init__(self, r: float): ...
    """

    with temporary_visited_package("package", {"__init__.py": code}) as module:
        # Generated `__init__`: `z` is a parameter, not a member.
        generated = module["PointA"]
        assert [param.name for param in generated.parameters] == ["self", "x", "y", "z"]
        assert list(generated.members) == ["x", "y", "__init__"]

        # Explicit `__init__`: its signature wins and `z` remains a member.
        explicit = module["PointB"]
        assert [param.name for param in explicit.parameters] == ["self", "r"]
        assert list(explicit.members) == ["x", "y", "z", "__init__"]
def test_visited_module_source() -> None:
    """Check the source property of a visited module."""
    code = "print('hello')\nprint('world')"
    with temporary_visited_package("package", {"__init__.py": code}) as package:
        # A module's source is its whole file content.
        assert package.source == code
def test_visited_class_source() -> None:
    """Check the source property of a visited class."""
    code = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    expected_source = dedent(code).strip()
    with temporary_visited_package("package", {"__init__.py": code}) as package:
        assert package["A"].source == expected_source
def test_visited_object_source_with_missing_line_number() -> None:
    """Check the source property of a visited object with missing line number."""
    code = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    with temporary_visited_package("package", {"__init__.py": code}) as package:
        class_a = package["A"]
        # Missing end line number: no source.
        class_a.endlineno = None
        assert not class_a.source
        # Missing start line number: no source either.
        class_a.endlineno = 3
        class_a.lineno = None
        assert not class_a.source
def test_inspected_module_source() -> None:
    """Check the source property of an inspected module."""
    code = "print('hello')\nprint('world')"
    with temporary_inspected_module(code) as inspected:
        # A module's source is its whole file content.
        assert inspected.source == code
def test_inspected_class_source() -> None:
    """Check the source property of an inspected class."""
    code = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    expected_source = dedent(code).strip()
    with temporary_inspected_module(code) as inspected:
        assert inspected["A"].source == expected_source
def test_inspected_object_source_with_missing_line_number() -> None:
    """Check the source property of an inspected object with missing line number."""
    code = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    with temporary_inspected_module(code) as inspected:
        class_a = inspected["A"]
        # Missing end line number: no source.
        class_a.endlineno = None
        assert not class_a.source
        # Missing start line number: no source either.
        class_a.endlineno = 3
        class_a.lineno = None
        assert not class_a.source
def test_dataclass_parameter_docstrings() -> None:
    """Class parameters should have a docstring attribute, inherited ones included."""
    code = """
        from dataclasses import dataclass, InitVar

        @dataclass
        class Base:
            a: int
            "Parameter a"
            b: InitVar[int] = 3
            "Parameter b"

        @dataclass
        class Derived(Base):
            c: float
            d: InitVar[float]
            "Parameter d"
    """

    # Expected docstring value per parameter position; `None` means no docstring
    # (`self`, and the undocumented `c`).
    expected_docstrings = {
        "Base": [None, "Parameter a", "Parameter b"],
        "Derived": [None, "Parameter a", "Parameter b", None, "Parameter d"],
    }
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        for class_name, expected_values in expected_docstrings.items():
            parameters = module[class_name].parameters
            for position, expected in enumerate(expected_values):
                docstring = parameters[position].docstring
                if expected is None:
                    assert docstring is None
                else:
                    assert docstring.value == expected
def test_attributes_that_have_no_annotations() -> None:
    """Dataclass attributes that have no annotations are not parameters."""
    code = """
        from dataclasses import dataclass, field

        @dataclass
        class Base:
            a: int
            b: str = field(init=False)
            c = 3  # Class attribute.

        @dataclass
        class Derived(Base):
            a = 1  # No effect on the parameter status of `a`.
            b = "b"  # Inherited non-parameter.
            d: float = 4
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        names_by_class = {
            class_name: [param.name for param in module[class_name].parameters]
            for class_name in ("Base", "Derived")
        }
        assert names_by_class["Base"] == ["self", "a"]
        assert names_by_class["Derived"] == ["self", "a", "d"]
def test_name_resolution() -> None:
    """Names are correctly resolved in the scope of an object."""
    code = """
        module_attribute = 0

        class Class:
            import imported

            class_attribute = 0

            def __init__(self):
                self.instance_attribute = 0

            def method(self):
                local_variable = 0
    """
    # What each name resolves to from inside `Class` (or any of its members).
    class_scope = {
        "module_attribute": "module.module_attribute",
        "Class": "module.Class",
        "imported": "imported",
        "class_attribute": "module.Class.class_attribute",
        "instance_attribute": "module.Class.instance_attribute",
        "method": "module.Class.method",
    }
    with temporary_visited_module(code) as module:
        # Module scope.
        assert module.resolve("module_attribute") == "module.module_attribute"
        assert module.resolve("Class") == "module.Class"

        # A module-level attribute sees module names, but not class internals.
        module_level = module["module_attribute"]
        assert module_level.resolve("Class") == "module.Class"
        with pytest.raises(NameResolutionError):
            module_level.resolve("class_attribute")

        # The class and its members see module names, the class-level import,
        # and every class member. Each object is not asked to resolve its own name.
        for path in ("Class", "Class.class_attribute", "Class.instance_attribute", "Class.method"):
            own_name = path.rsplit(".", 1)[-1]
            for name, target in class_scope.items():
                if name == own_name:
                    continue
                assert module[path].resolve(name) == target, f"{path} -> {name}"
def test_set_parameters() -> None:
    """We can set parameters, by name or by index, without creating duplicates."""
    params = Parameters()

    params["x"] = Parameter(name="x")  # New name: inserted.
    assert "x" in params

    params["x"] = Parameter(name="x")  # Existing name: replaced in place.
    assert "x" in params
    assert len(params) == 1

    params[0] = Parameter(name="y")  # Existing index: replaced in place.
    assert "y" in params
    assert len(params) == 1
def test_delete_parameters() -> None:
    """We can delete parameters, by name and by index."""
    params = Parameters()
    # First delete by name, then by position.
    for key in ("x", 0):
        params["x"] = Parameter(name="x")
        del params[key]
        assert "x" not in params
        assert len(params) == 0
def test_not_resolving_attribute_value_to_itself() -> None:
    """Attribute values with same name don't resolve to themselves."""
    code = """
        class A:
            def __init__(self):
                x = "something"
                self.x = x
    """
    with temporary_visited_module(code) as module:
        value = module["A.x"].value
        # The right-hand `x` is the local variable, not `module.A.x`.
        assert value.canonical_path == "x"
def test_resolving_never_raises_alias_errors() -> None:
    """Resolving never raises alias errors."""
    init_code = """
        from package.mod import pd

        class A:
            def __init__(self):
                pass
    """
    modules = {"__init__.py": init_code, "mod.py": "import pandas as pd"}
    with temporary_visited_package("package", modules) as package:
        # `pd` is an alias chain to an external package; resolving the name should
        # give its path in this package instead of raising.
        assert package["A.__init__"].resolve("pd") == "package.mod.pd"
def test_building_function_and_class_signatures() -> None:
    """Test the construction of a class/function signature."""
    # Two simple parameters, shared by a function and by a class `__init__`.
    x_y_params = Parameters(
        Parameter("x", annotation="int"),
        Parameter("y", annotation="int", default="0"),
    )
    simple_function = Function("simple_function", parameters=x_y_params, returns="int")
    assert simple_function.signature() == "simple_function(x: int, y: int = 0) -> int"

    # A class signature comes from its `__init__`, without the return annotation.
    init_method = Function("__init__", parameters=x_y_params, returns="None")
    test_class = Class("TestClass")
    test_class.set_member("__init__", init_method)
    assert test_class.signature() == "TestClass(x: int, y: int = 0)"

    # Every parameter kind, with and without annotations/defaults.
    all_kinds_params = Parameters(
        Parameter("a", kind=ParameterKind.positional_only),
        Parameter("b", kind=ParameterKind.positional_only, annotation="int", default="0"),
        Parameter("c", kind=ParameterKind.positional_or_keyword),
        Parameter("d", kind=ParameterKind.positional_or_keyword, annotation="str", default="''"),
        Parameter("args", kind=ParameterKind.var_positional),
        Parameter("e", kind=ParameterKind.keyword_only),
        Parameter("f", kind=ParameterKind.keyword_only, annotation="bool", default="False"),
        Parameter("kwargs", kind=ParameterKind.var_keyword),
    )
    complex_function = Function("test_function", parameters=all_kinds_params, returns="None")
    assert complex_function.signature() == (
        "test_function(a, b: int = 0, /, c, d: str = '', *args, e, f: bool = False, **kwargs) -> None"
    )
def test_set_type_parameters() -> None:
    """We can set type parameters, by name or by index, without creating duplicates."""
    type_params = TypeParameters()

    # New name: inserted.
    type_params["x"] = TypeParameter(name="x", kind=TypeParameterKind.type_var)
    assert "x" in type_params

    # Existing name, with the same kind and then a different kind: replaced in place.
    for kind in (TypeParameterKind.type_var, TypeParameterKind.param_spec):
        type_params["x"] = TypeParameter(name="x", kind=kind)
        assert "x" in type_params
        assert len(type_params) == 1

    # Existing index: replaced in place.
    type_params[0] = TypeParameter(name="y", kind=TypeParameterKind.type_var)
    assert "y" in type_params
    assert len(type_params) == 1
def test_delete_type_parameters() -> None:
    """We can delete type parameters, by name and by index."""
    type_params = TypeParameters()
    # First delete by name, then by position.
    for key in ("x", 0):
        type_params["x"] = TypeParameter(name="x", kind=TypeParameterKind.type_var)
        del type_params[key]
        assert "x" not in type_params
        assert len(type_params) == 0
# YORE: EOL 3.11: Remove line.
@pytest.mark.skipif(sys.version_info < (3, 12), reason="Python less than 3.12 does not have PEP 695 generics")
def test_annotation_resolution() -> None:
    """Names are correctly resolved in the annotation scope of an object.

    Type parameters resolve to the innermost generic scope that declares them.
    A function's type parameter is not visible from a sibling function.
    """
    code = """
        class C[T]:
            class D[T]:
                def func[Y](self, arg1: T, arg2: Y): pass
            def func[Z](arg1: T, arg2: Y): pass
    """
    with temporary_visited_module(code) as module:
        inner_class = module["C.D"]
        inner_method = module["C.D.func"]
        outer_class = module["C"]
        outer_function = module["C.func"]

        # `D` redeclares `T`, shadowing the outer one.
        assert inner_class.resolve("T") == "module.C.D[T]"
        assert inner_method.resolve("T") == "module.C.D[T]"
        assert inner_method.resolve("Y") == "module.C.D.func[Y]"

        assert outer_class.resolve("T") == "module.C[T]"
        assert outer_function.resolve("T") == "module.C[T]"
        # `Y` belongs to `D.func` only, so `C.func` cannot resolve it.
        with pytest.raises(NameResolutionError):
            outer_function.resolve("Y")