Coverage for tests / test_models.py: 100.00%

261 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-11 11:48 +0100

"""Tests for the `models` module."""

2 

3from __future__ import annotations 

4 

5import sys 

6from copy import deepcopy 

7from textwrap import dedent 

8 

9import pytest 

10 

11from griffe import ( 

12 Attribute, 

13 Class, 

14 Docstring, 

15 Function, 

16 GriffeLoader, 

17 Module, 

18 NameResolutionError, 

19 Parameter, 

20 ParameterKind, 

21 Parameters, 

22 TypeParameter, 

23 TypeParameterKind, 

24 TypeParameters, 

25 module_vtree, 

26 temporary_inspected_module, 

27 temporary_pypackage, 

28 temporary_visited_module, 

29 temporary_visited_package, 

30) 

31 

32 

def test_submodule_exports() -> None:
    """Check when submodules and private members become exposed to wildcard imports."""
    root_module = Module("root")
    submodule = Module("sub")
    private_attr = Attribute("_private")
    root_module["sub"] = submodule
    root_module["_private"] = private_attr

    # A submodule is only wildcard-exposed once its parent module imports it.
    assert not submodule.is_wildcard_exposed
    root_module.imports["sub"] = "root.sub"
    assert submodule.is_wildcard_exposed

    # A private member is only wildcard-exposed once it is listed in the module's exports.
    assert not private_attr.is_wildcard_exposed
    root_module.exports = ["_private"]
    assert private_attr.is_wildcard_exposed

48 

49 

def test_has_docstrings() -> None:
    """A docstring on a member class is enough for the module to report `has_docstrings`."""
    source = "class A:\n    '''Hello.'''"
    with temporary_visited_module(source) as module:
        # The module itself has no docstring: the check must recurse into members.
        assert module.has_docstrings

54 

55 

def test_has_docstrings_submodules() -> None:
    """A docstring deep in a submodule tree makes the root module report `has_docstrings`."""
    root = module_vtree("a.b.c.d")
    deepest = root["b.c.d"]
    deepest.docstring = Docstring("Hello.")
    assert root.has_docstrings

61 

62 

def test_handle_aliases_chain_in_has_docstrings() -> None:
    """`has_docstrings` copes with chains of aliases in members, before and after resolution."""
    # `mod_a.someobj` -> `mod_b.someobj` -> `somelib.someobj` (an external, unloaded package).
    contents = {
        "mod_a.py": "from .mod_b import someobj",
        "mod_b.py": "from somelib import someobj",
    }
    with temporary_pypackage("package", list(contents)) as tmp_package:
        for filename, text in contents.items():
            (tmp_package.path / filename).write_text(text, encoding="utf8")

        loader = GriffeLoader(search_paths=[tmp_package.tmpdir])
        package = loader.load(tmp_package.name)
        assert not package.has_docstrings
        loader.resolve_aliases(implicit=True)
        assert not package.has_docstrings

76 

77 

def test_has_docstrings_does_not_trigger_alias_resolution() -> None:
    """Computing `has_docstrings` leaves aliases unresolved."""
    contents = {
        "mod_a.py": "from .mod_b import someobj",
        "mod_b.py": "from somelib import someobj",
    }
    with temporary_pypackage("package", list(contents)) as tmp_package:
        for filename, text in contents.items():
            (tmp_package.path / filename).write_text(text, encoding="utf8")

        loader = GriffeLoader(search_paths=[tmp_package.tmpdir])
        package = loader.load(tmp_package.name)
        assert not package.has_docstrings
        # The alias must still be unresolved after the docstring check.
        alias = package["mod_a.someobj"]
        assert not alias.resolved

90 

91 

def test_deepcopy() -> None:
    """Deep-copying a loaded object tree, and its dictionary form, does not raise."""
    griffe_module = GriffeLoader().load("griffe")
    deepcopy(griffe_module)
    deepcopy(griffe_module.as_dict())

99 

100 

def test_dataclass_properties_and_class_variables() -> None:
    """Don't return properties or class variables as parameters of dataclasses."""
    code = """
        from dataclasses import dataclass
        from functools import cached_property
        from typing import ClassVar

        @dataclass
        class Point:
            x: float
            y: float

            # `r` and `t` are class variables (`ClassVar`), so they are not parameters.
            # `s` is a regular field with a default value, so it IS a parameter.
            r: ClassVar[float]
            s: float = 3
            t: ClassVar[float] = 3

            @property
            def a(self):
                return 0

            @cached_property
            def b(self):
                return 0
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        params = module["Point"].parameters
        # Properties (`a`, `b`) and class variables (`r`, `t`) are excluded.
        assert [p.name for p in params] == ["self", "x", "y", "s"]

129 

130 

@pytest.mark.parametrize(
    "code",
    [
        """
        @dataclass
        class Dataclass:
            x: float
            y: float = field(kw_only=True)

        class Class:
            def __init__(self, x: float, *, y: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            x: float = field(kw_only=True)
            y: float

        class Class:
            def __init__(self, y: float, *, x: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            x: float
            _: KW_ONLY
            y: float

        class Class:
            def __init__(self, x: float, *, y: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            _: KW_ONLY
            x: float
            y: float

        class Class:
            def __init__(self, *, x: float, y: float): ...
        """,
        """
        @dataclass(kw_only=True)
        class Dataclass:
            x: float
            y: float

        class Class:
            def __init__(self, *, x: float, y: float): ...
        """,
    ],
)
def test_dataclass_parameter_kinds(code: str) -> None:
    """Check dataclass and equivalent non-dataclass parameters.

    The parameters of each pair should be the same, in the same order,
    and both classes should have the same number of parameters.

    Parameters:
        code: Python code to visit.
    """
    code = f"from dataclasses import dataclass, field, KW_ONLY\n\n{dedent(code)}"
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        # `strict=True` makes a parameter-count mismatch fail the test instead of
        # silently comparing only the common prefix of both signatures.
        for dataclass_param, regular_param in zip(
            module["Dataclass"].parameters,
            module["Class"].parameters,
            strict=True,
        ):
            assert dataclass_param == regular_param

199 

200 

def test_regular_class_inheriting_dataclass_dont_get_its_own_params() -> None:
    """A regular class inheriting from a dataclass doesn't get its own attributes added to `__init__`.

    It still exposes the parameters inherited from its dataclass base.
    """
    code = """
        from dataclasses import dataclass

        @dataclass
        class Base:
            a: int
            b: str

        @dataclass
        class Derived1(Base):
            c: float

        class Derived2(Base):
            d: float
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        params1 = module["Derived1"].parameters
        params2 = module["Derived2"].parameters
        # A dataclass subclass appends its own fields...
        assert [p.name for p in params1] == ["self", "a", "b", "c"]
        # ...while a regular subclass only keeps the inherited ones (no `d`).
        assert [p.name for p in params2] == ["self", "a", "b"]

223 

224 

def test_regular_class_inheriting_dataclass_is_labelled_dataclass() -> None:
    """A plain class that subclasses a dataclass still carries the `dataclass` label."""
    code = """
        from dataclasses import dataclass

        @dataclass
        class Base:
            pass

        class Derived(Base):
            pass
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        assert "dataclass" in module["Derived"].labels

240 

241 

def test_fields_with_init_false() -> None:
    """Fields marked with `init=False` are not added to the `__init__` method."""
    code = """
        from dataclasses import dataclass, field

        @dataclass
        class PointA:
            x: float
            y: float
            z: float = field(init=False)

        @dataclass(init=False)
        class PointB:
            x: float
            y: float

        @dataclass(init=False)
        class PointC:
            x: float
            y: float = field(init=True)  # `init=True` has no effect.
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        params_a = module["PointA"].parameters
        params_b = module["PointB"].parameters
        params_c = module["PointC"].parameters

        # Only the field excluded with `field(init=False)` is dropped. Pinning the
        # remaining names keeps the `not in` check from passing vacuously.
        assert "z" not in params_a
        assert [p.name for p in params_a] == ["self", "x", "y"]
        # With `@dataclass(init=False)`, no field becomes a parameter,
        # even one explicitly marked `field(init=True)`.
        assert "x" not in params_b
        assert "y" not in params_b
        assert "x" not in params_c
        assert "y" not in params_c

273 

274 

def test_parameters_are_reorderd_to_match_their_kind() -> None:
    """Keyword-only parameters inherited from a base class move to the end of the signature."""
    code = """
        from dataclasses import dataclass

        @dataclass(kw_only=True)
        class Base:
            a: int
            b: str

        @dataclass
        class Reordered(Base):
            b: float
            c: float
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        base_names = [param.name for param in module["Base"].parameters]
        reordered = module["Reordered"].parameters
        assert base_names == ["self", "a", "b"]
        assert [param.name for param in reordered] == ["self", "b", "c", "a"]
        # `b` is redefined in the subclass, so its annotation comes from there.
        assert str(reordered["b"].annotation) == "float"

296 

297 

def test_parameters_annotated_as_initvar() -> None:
    """`InitVar` fields become `__init__` parameters but not class members.

    When `__init__` is defined explicitly, `InitVar` has no effect.
    """
    code = """
        from dataclasses import dataclass, InitVar

        @dataclass
        class PointA:
            x: float
            y: float
            z: InitVar[float]

        @dataclass
        class PointB:
            x: float
            y: float
            z: InitVar[float]

            def __init__(self, r: float): ...
    """

    with temporary_visited_package("package", {"__init__.py": code}) as module:
        generated = module["PointA"]
        assert [param.name for param in generated.parameters] == ["self", "x", "y", "z"]
        # `z` is a parameter of the generated `__init__`, but not a member.
        assert list(generated.members) == ["x", "y", "__init__"]

        explicit = module["PointB"]
        # The explicit `__init__` wins: its own signature is used, and `z` stays a member.
        assert [param.name for param in explicit.parameters] == ["self", "r"]
        assert list(explicit.members) == ["x", "y", "z", "__init__"]

329 

330 

def test_visited_module_source() -> None:
    """A visited module exposes its full source text."""
    lines = ["print('hello')", "print('world')"]
    code = "\n".join(lines)
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        assert module.source == code

336 

337 

def test_visited_class_source() -> None:
    """A visited class exposes exactly the source lines of its definition."""
    code = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    expected = dedent(code).strip()
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        assert module["A"].source == expected

347 

348 

def test_visited_object_source_with_missing_line_number() -> None:
    """A visited object without a start or end line number has empty source."""
    code = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        class_a = module["A"]
        # Missing end line.
        class_a.endlineno = None
        assert not class_a.source
        # End line restored, start line missing.
        class_a.endlineno = 3
        class_a.lineno = None
        assert not class_a.source

362 

363 

def test_inspected_module_source() -> None:
    """An inspected module exposes its full source text."""
    lines = ["print('hello')", "print('world')"]
    code = "\n".join(lines)
    with temporary_inspected_module(code) as module:
        assert module.source == code

369 

370 

def test_inspected_class_source() -> None:
    """An inspected class exposes exactly the source lines of its definition."""
    code = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    expected = dedent(code).strip()
    with temporary_inspected_module(code) as module:
        assert module["A"].source == expected

380 

381 

def test_inspected_object_source_with_missing_line_number() -> None:
    """An inspected object without a start or end line number has empty source."""
    code = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    with temporary_inspected_module(code) as module:
        class_a = module["A"]
        # Missing end line.
        class_a.endlineno = None
        assert not class_a.source
        # End line restored, start line missing.
        class_a.endlineno = 3
        class_a.lineno = None
        assert not class_a.source

395 

396 

def test_dataclass_parameter_docstrings() -> None:
    """Dataclass parameters carry the docstrings of their fields, including inherited ones."""
    code = """
        from dataclasses import dataclass, InitVar

        @dataclass
        class Base:
            a: int
            "Parameter a"
            b: InitVar[int] = 3
            "Parameter b"

        @dataclass
        class Derived(Base):
            c: float
            d: InitVar[float]
            "Parameter d"
    """

    with temporary_visited_package("package", {"__init__.py": code}) as module:
        base_params = module["Base"].parameters
        # Positions: self, a, b.
        assert base_params[0].docstring is None
        assert base_params[1].docstring.value == "Parameter a"
        assert base_params[2].docstring.value == "Parameter b"

        derived_params = module["Derived"].parameters
        # Positions: self, a, b (inherited), c (undocumented), d.
        assert derived_params[0].docstring is None
        assert derived_params[1].docstring.value == "Parameter a"
        assert derived_params[2].docstring.value == "Parameter b"
        assert derived_params[3].docstring is None
        assert derived_params[4].docstring.value == "Parameter d"

436 

437 

def test_attributes_that_have_no_annotations() -> None:
    """Dataclass attributes that have no annotations are not parameters."""
    code = """
        from dataclasses import dataclass, field

        @dataclass
        class Base:
            a: int
            b: str = field(init=False)
            c = 3  # Class attribute.

        @dataclass
        class Derived(Base):
            a = 1  # No effect on the parameter status of `a`.
            b = "b"  # Inherited non-parameter.
            d: float = 4
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        base_params = [p.name for p in module["Base"].parameters]
        derived_params = [p.name for p in module["Derived"].parameters]
        assert base_params == ["self", "a"]
        # Unannotated reassignments in the subclass neither add nor remove parameters.
        assert derived_params == ["self", "a", "d"]

460 

461 

def test_name_resolution() -> None:
    """Names are correctly resolved in the scope of an object."""
    code = """
        module_attribute = 0

        class Class:
            import imported

            class_attribute = 0

            def __init__(self):
                self.instance_attribute = 0

            def method(self):
                local_variable = 0
    """
    with temporary_visited_module(code) as module:
        assert module.resolve("module_attribute") == "module.module_attribute"
        assert module.resolve("Class") == "module.Class"

        # A module-level attribute cannot see names defined inside `Class`.
        assert module["module_attribute"].resolve("Class") == "module.Class"
        with pytest.raises(NameResolutionError):
            module["module_attribute"].resolve("class_attribute")

        assert module["Class"].resolve("module_attribute") == "module.module_attribute"
        assert module["Class"].resolve("imported") == "imported"
        assert module["Class"].resolve("class_attribute") == "module.Class.class_attribute"
        assert module["Class"].resolve("instance_attribute") == "module.Class.instance_attribute"
        assert module["Class"].resolve("method") == "module.Class.method"

        assert module["Class.class_attribute"].resolve("module_attribute") == "module.module_attribute"
        assert module["Class.class_attribute"].resolve("Class") == "module.Class"
        assert module["Class.class_attribute"].resolve("imported") == "imported"
        assert module["Class.class_attribute"].resolve("instance_attribute") == "module.Class.instance_attribute"
        assert module["Class.class_attribute"].resolve("method") == "module.Class.method"

        assert module["Class.instance_attribute"].resolve("module_attribute") == "module.module_attribute"
        assert module["Class.instance_attribute"].resolve("Class") == "module.Class"
        assert module["Class.instance_attribute"].resolve("imported") == "imported"
        assert module["Class.instance_attribute"].resolve("class_attribute") == "module.Class.class_attribute"
        assert module["Class.instance_attribute"].resolve("method") == "module.Class.method"

        assert module["Class.method"].resolve("module_attribute") == "module.module_attribute"
        assert module["Class.method"].resolve("Class") == "module.Class"
        assert module["Class.method"].resolve("imported") == "imported"
        assert module["Class.method"].resolve("class_attribute") == "module.Class.class_attribute"
        assert module["Class.method"].resolve("instance_attribute") == "module.Class.instance_attribute"

509 

510 

def test_set_parameters() -> None:
    """Parameters can be set by name (add or replace) and by index (replace)."""
    parameters = Parameters()

    # First by a new name, then by the same name again: no duplicate is created.
    for _ in range(2):
        parameters["x"] = Parameter(name="x")
        assert "x" in parameters
    assert len(parameters) == 1

    # By an existing index: the parameter at that position is replaced.
    parameters[0] = Parameter(name="y")
    assert "y" in parameters
    assert len(parameters) == 1

525 

526 

def test_delete_parameters() -> None:
    """Parameters can be deleted by name or by index."""
    parameters = Parameters()
    for key in ("x", 0):
        parameters["x"] = Parameter(name="x")
        del parameters[key]
        assert "x" not in parameters
        assert len(parameters) == 0

540 

541 

def test_not_resolving_attribute_value_to_itself() -> None:
    """An attribute whose value is a same-named local variable doesn't resolve to itself."""
    code = """
        class A:
            def __init__(self):
                x = "something"
                self.x = x
    """
    with temporary_visited_module(code) as module:
        value = module["A.x"].value
        # The right-hand `x` is the local variable, not `module.A.x`.
        assert value.canonical_path == "x"

553 

554 

def test_resolving_never_raises_alias_errors() -> None:
    """Resolving a name that points to an alias returns the alias path instead of raising."""
    files = {
        "__init__.py": """
            from package.mod import pd

            class A:
                def __init__(self):
                    pass
        """,
        "mod.py": "import pandas as pd",
    }
    with temporary_visited_package("package", files) as module:
        # `pd` is an alias chain ending in `pandas`, which presumably isn't loaded here;
        # resolution stops at the alias path rather than raising.
        assert module["A.__init__"].resolve("pd") == "package.mod.pd"

571 

572 

def test_building_function_and_class_signatures() -> None:
    """Function and class signatures are rendered from their parameters."""
    # Simple function: annotated parameters, one default value, a return annotation.
    simple_params = Parameters(
        Parameter("x", annotation="int"),
        Parameter("y", annotation="int", default="0"),
    )
    simple_func = Function("simple_function", parameters=simple_params, returns="int")
    assert simple_func.signature() == "simple_function(x: int, y: int = 0) -> int"

    # Class: the signature is taken from `__init__`, without its return annotation.
    cls = Class("TestClass")
    cls.set_member("__init__", Function("__init__", parameters=simple_params, returns="None"))
    assert cls.signature() == "TestClass(x: int, y: int = 0)"

    # Every parameter kind, so that the `/` and `*args` separators appear in the rendering.
    kind = ParameterKind
    complex_func = Function(
        "test_function",
        parameters=Parameters(
            Parameter("a", kind=kind.positional_only),
            Parameter("b", kind=kind.positional_only, annotation="int", default="0"),
            Parameter("c", kind=kind.positional_or_keyword),
            Parameter("d", kind=kind.positional_or_keyword, annotation="str", default="''"),
            Parameter("args", kind=kind.var_positional),
            Parameter("e", kind=kind.keyword_only),
            Parameter("f", kind=kind.keyword_only, annotation="bool", default="False"),
            Parameter("kwargs", kind=kind.var_keyword),
        ),
        returns="None",
    )
    expected = "test_function(a, b: int = 0, /, c, d: str = '', *args, e, f: bool = False, **kwargs) -> None"
    assert complex_func.signature() == expected

604 

605 

def test_set_type_parameters() -> None:
    """Type parameters can be set by name (add or replace) and by index (replace)."""
    kind = TypeParameterKind
    type_parameters = TypeParameters()

    # New name: the type parameter is added.
    type_parameters["x"] = TypeParameter(name="x", kind=kind.type_var)
    assert "x" in type_parameters

    # Same name again, first with the same kind, then with a different kind:
    # the existing entry is replaced each time.
    for new_kind in (kind.type_var, kind.param_spec):
        type_parameters["x"] = TypeParameter(name="x", kind=new_kind)
        assert "x" in type_parameters
        assert len(type_parameters) == 1

    # Existing index: the type parameter at that position is replaced.
    type_parameters[0] = TypeParameter(name="y", kind=kind.type_var)
    assert "y" in type_parameters
    assert len(type_parameters) == 1

624 

625 

def test_delete_type_parameters() -> None:
    """Type parameters can be deleted by name or by index."""
    type_parameters = TypeParameters()
    for key in ("x", 0):
        type_parameters["x"] = TypeParameter(name="x", kind=TypeParameterKind.type_var)
        del type_parameters[key]
        assert "x" not in type_parameters
        assert len(type_parameters) == 0

639 

640 

# YORE: EOL 3.11: Remove line.
@pytest.mark.skipif(sys.version_info < (3, 12), reason="Python less than 3.12 does not have PEP 695 generics")
def test_annotation_resolution() -> None:
    """Type parameters resolve within the annotation scope that declares them."""
    code = """
        class C[T]:
            class D[T]:
                def func[Y](self, arg1: T, arg2: Y): pass
            def func[Z](arg1: T, arg2: Y): pass
    """
    with temporary_visited_module(code) as module:
        # Inside `D`, its own `T` shadows the `T` of `C`.
        assert module["C.D"].resolve("T") == "module.C.D[T]"

        assert module["C.D.func"].resolve("T") == "module.C.D[T]"
        assert module["C.D.func"].resolve("Y") == "module.C.D.func[Y]"

        assert module["C"].resolve("T") == "module.C[T]"

        assert module["C.func"].resolve("T") == "module.C[T]"
        # `Y` belongs to `C.D.func` only, so it is out of scope in `C.func`.
        with pytest.raises(NameResolutionError):
            module["C.func"].resolve("Y")