Coverage for tests/test_models.py: 100.00%

261 statements  

« prev     ^ index     » next       coverage.py v7.10.2, created at 2025-08-11 13:44 +0200

1"""Tests for the `dataclasses` module.""" 

2 

3from __future__ import annotations 

4 

5import sys 

6from copy import deepcopy 

7from textwrap import dedent 

8 

9import pytest 

10 

11from griffe import ( 

12 Attribute, 

13 Class, 

14 Docstring, 

15 Function, 

16 GriffeLoader, 

17 Module, 

18 NameResolutionError, 

19 Parameter, 

20 ParameterKind, 

21 Parameters, 

22 TypeParameter, 

23 TypeParameterKind, 

24 TypeParameters, 

25 module_vtree, 

26 temporary_inspected_module, 

27 temporary_pypackage, 

28 temporary_visited_module, 

29 temporary_visited_package, 

30) 

31 

32 

def test_submodule_exports() -> None:
    """Check wildcard exposure of a submodule and of a private attribute.

    A submodule only becomes wildcard-exposed once its parent also imports it,
    and a private attribute only becomes exposed once the parent lists it in
    its exports.
    """
    parent = Module("root")
    submodule = Module("sub")
    hidden_attr = Attribute("_private")
    parent["sub"] = submodule
    parent["_private"] = hidden_attr

    # Being a member is not enough for a submodule...
    assert not submodule.is_wildcard_exposed
    # ...it must also be imported by the parent module.
    parent.imports["sub"] = "root.sub"
    assert submodule.is_wildcard_exposed

    # A private name stays hidden until the parent exports it explicitly.
    assert not hidden_attr.is_wildcard_exposed
    parent.exports = ["_private"]
    assert hidden_attr.is_wildcard_exposed

48 

49 

def test_has_docstrings() -> None:
    """Assert `.has_docstrings` looks into members, not only at the object itself."""
    source = "class A:\n    '''Hello.'''"
    with temporary_visited_module(source) as module:
        # The module has no docstring of its own; only the nested class is documented.
        assert module.has_docstrings

54 

55 

def test_has_docstrings_submodules() -> None:
    """Assert `.has_docstrings` descends through nested submodules."""
    root = module_vtree("a.b.c.d")
    deepest = root["b.c.d"]
    deepest.docstring = Docstring("Hello.")
    # Only the leaf module is documented, yet the root must report docstrings.
    assert root.has_docstrings

61 

62 

def test_handle_aliases_chain_in_has_docstrings() -> None:
    """Assert `.has_docstrings` copes with chains of aliases among members."""
    with temporary_pypackage("package", ["mod_a.py", "mod_b.py"]) as tmp_package:
        # `mod_a.someobj` aliases `mod_b.someobj`, which itself aliases `somelib.someobj`
        # (`somelib` is not part of the temporary package).
        for filename, contents in (
            ("mod_a.py", "from .mod_b import someobj"),
            ("mod_b.py", "from somelib import someobj"),
        ):
            (tmp_package.path / filename).write_text(contents)

        loader = GriffeLoader(search_paths=[tmp_package.tmpdir])
        package = loader.load(tmp_package.name)
        assert not package.has_docstrings

        # Resolving aliases (implicit ones included) must not change the answer.
        loader.resolve_aliases(implicit=True)
        assert not package.has_docstrings

76 

77 

def test_has_docstrings_does_not_trigger_alias_resolution() -> None:
    """Assert that computing `.has_docstrings` leaves aliases unresolved."""
    with temporary_pypackage("package", ["mod_a.py", "mod_b.py"]) as tmp_package:
        sources = {
            "mod_a.py": "from .mod_b import someobj",
            "mod_b.py": "from somelib import someobj",
        }
        for filename, contents in sources.items():
            (tmp_package.path / filename).write_text(contents)

        loader = GriffeLoader(search_paths=[tmp_package.tmpdir])
        package = loader.load(tmp_package.name)
        assert not package.has_docstrings
        # The property must not have resolved the alias as a side effect.
        assert not package["mod_a.someobj"].resolved

90 

91 

def test_deepcopy() -> None:
    """Assert object trees, and their dictionary form, can be deep-copied."""
    griffe_module = GriffeLoader().load("griffe")

    # The test passes as long as neither copy raises.
    deepcopy(griffe_module)
    deepcopy(griffe_module.as_dict())

99 

100 

def test_dataclass_properties_and_class_variables() -> None:
    """Properties and `ClassVar` annotations must not become dataclass parameters."""
    code = """
        from dataclasses import dataclass
        from functools import cached_property
        from typing import ClassVar

        @dataclass
        class Point:
            x: float
            y: float

            # These definitions create class variables.
            r: ClassVar[float]
            s: float = 3
            t: ClassVar[float] = 3

            @property
            def a(self):
                return 0

            @cached_property
            def b(self):
                return 0
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        parameter_names = [parameter.name for parameter in module["Point"].parameters]
        # `r`/`t` (ClassVar) and `a`/`b` (properties) are excluded; `s` is a regular field.
        assert parameter_names == ["self", "x", "y", "s"]

129 

130 

@pytest.mark.parametrize(
    "code",
    [
        """
        @dataclass
        class Dataclass:
            x: float
            y: float = field(kw_only=True)

        class Class:
            def __init__(self, x: float, *, y: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            x: float = field(kw_only=True)
            y: float

        class Class:
            def __init__(self, y: float, *, x: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            x: float
            _: KW_ONLY
            y: float

        class Class:
            def __init__(self, x: float, *, y: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            _: KW_ONLY
            x: float
            y: float

        class Class:
            def __init__(self, *, x: float, y: float): ...
        """,
        """
        @dataclass(kw_only=True)
        class Dataclass:
            x: float
            y: float

        class Class:
            def __init__(self, *, x: float, y: float): ...
        """,
    ],
)
def test_dataclass_parameter_kinds(code: str) -> None:
    """Check dataclass and equivalent non-dataclass parameters.

    Each pair must expose the same parameters, in the same order and with the
    same kinds (parameters are compared with `==`).

    Parameters:
        code: Python code to visit.
    """
    code = f"from dataclasses import dataclass, field, KW_ONLY\n\n{dedent(code)}"
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        dataclass_params = list(module["Dataclass"].parameters)
        regular_params = list(module["Class"].parameters)
        # Compare whole lists instead of `zip`-ing them: `zip` stops at the shorter
        # sequence, so a missing or extra parameter would go unnoticed.
        assert dataclass_params == regular_params

195 

196 

def test_regular_class_inheriting_dataclass_dont_get_its_own_params() -> None:
    """A regular class inheriting from a dataclass doesn't get its own attributes added to `__init__`."""
    code = """
        from dataclasses import dataclass

        @dataclass
        class Base:
            a: int
            b: str

        @dataclass
        class Derived1(Base):
            c: float

        class Derived2(Base):
            d: float
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        names = {
            class_name: [param.name for param in module[class_name].parameters]
            for class_name in ("Derived1", "Derived2")
        }
        # `Derived1` is itself a dataclass, so its field `c` follows the inherited ones...
        assert names["Derived1"] == ["self", "a", "b", "c"]
        # ...whereas `Derived2` keeps the inherited parameters without adding `d`.
        assert names["Derived2"] == ["self", "a", "b"]

219 

220 

def test_regular_class_inheriting_dataclass_is_labelled_dataclass() -> None:
    """A class without `@dataclass` still gets the "dataclass" label when its base is a dataclass."""
    code = """
        from dataclasses import dataclass

        @dataclass
        class Base:
            pass

        class Derived(Base):
            pass
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        assert "dataclass" in module["Derived"].labels

236 

237 

def test_fields_with_init_false() -> None:
    """`init=False`, on a field or on the whole dataclass, keeps fields out of `__init__`."""
    code = """
        from dataclasses import dataclass, field

        @dataclass
        class PointA:
            x: float
            y: float
            z: float = field(init=False)

        @dataclass(init=False)
        class PointB:
            x: float
            y: float

        @dataclass(init=False)
        class PointC:
            x: float
            y: float = field(init=True)  # `init=True` has no effect.
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        # Field-level opt-out.
        assert "z" not in module["PointA"].parameters
        # Class-level opt-out wins over any field-level `init=True`.
        for class_name in ("PointB", "PointC"):
            params = module[class_name].parameters
            for field_name in ("x", "y"):
                assert field_name not in params

269 

270 

def test_parameters_are_reorderd_to_match_their_kind() -> None:
    """Keyword-only parameters inherited from a base class move to the end of the signature."""
    code = """
        from dataclasses import dataclass

        @dataclass(kw_only=True)
        class Base:
            a: int
            b: str

        @dataclass
        class Reordered(Base):
            b: float
            c: float
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        base_names = [param.name for param in module["Base"].parameters]
        reordered = module["Reordered"].parameters
        assert base_names == ["self", "a", "b"]
        # `b` is redeclared in the subclass and lands before the inherited keyword-only `a`.
        assert [param.name for param in reordered] == ["self", "b", "c", "a"]
        # The redeclared `b` carries the subclass annotation.
        assert str(reordered["b"].annotation) == "float"

292 

293 

def test_parameters_annotated_as_initvar() -> None:
    """`InitVar` fields are `__init__` parameters but not class members.

    When the class defines its own `__init__`, `InitVar` has no effect.
    """
    code = """
        from dataclasses import dataclass, InitVar

        @dataclass
        class PointA:
            x: float
            y: float
            z: InitVar[float]

        @dataclass
        class PointB:
            x: float
            y: float
            z: InitVar[float]

            def __init__(self, r: float): ...
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        generated = module["PointA"]
        assert [param.name for param in generated.parameters] == ["self", "x", "y", "z"]
        assert list(generated.members) == ["x", "y", "__init__"]

        handwritten = module["PointB"]
        assert [param.name for param in handwritten.parameters] == ["self", "r"]
        assert list(handwritten.members) == ["x", "y", "z", "__init__"]

325 

326 

def test_visited_module_source() -> None:
    """A visited module exposes its full source code."""
    expected_source = "print('hello')\nprint('world')"
    with temporary_visited_package("package", {"__init__.py": expected_source}) as module:
        assert module.source == expected_source

332 

333 

def test_visited_class_source() -> None:
    """A visited class exposes exactly the lines of its definition."""
    code = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    expected = dedent(code).strip()
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        assert module["A"].source == expected

343 

344 

def test_visited_object_source_with_missing_line_number() -> None:
    """A visited object missing its start or end line number has no source."""
    code = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        class_a = module["A"]

        # Missing end line.
        class_a.endlineno = None
        assert not class_a.source

        # Missing start line (the end line is set back to a number first).
        class_a.endlineno = 3
        class_a.lineno = None
        assert not class_a.source

358 

359 

def test_inspected_module_source() -> None:
    """An inspected module exposes its full source code."""
    expected_source = "print('hello')\nprint('world')"
    with temporary_inspected_module(expected_source) as module:
        assert module.source == expected_source

365 

366 

def test_inspected_class_source() -> None:
    """An inspected class exposes exactly the lines of its definition."""
    code = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    expected = dedent(code).strip()
    with temporary_inspected_module(code) as module:
        assert module["A"].source == expected

376 

377 

def test_inspected_object_source_with_missing_line_number() -> None:
    """An inspected object missing its start or end line number has no source."""
    code = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    with temporary_inspected_module(code) as module:
        class_a = module["A"]

        # Missing end line.
        class_a.endlineno = None
        assert not class_a.source

        # Missing start line (the end line is set back to a number first).
        class_a.endlineno = 3
        class_a.lineno = None
        assert not class_a.source

391 

392 

def test_dataclass_parameter_docstrings() -> None:
    """Dataclass parameters carry the docstring of the field they come from."""
    code = """
        from dataclasses import dataclass, InitVar

        @dataclass
        class Base:
            a: int
            "Parameter a"
            b: InitVar[int] = 3
            "Parameter b"

        @dataclass
        class Derived(Base):
            c: float
            d: InitVar[float]
            "Parameter d"
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        base_self, base_a, base_b = list(module["Base"].parameters)[:3]
        assert base_self.docstring is None
        assert base_a.docstring.value == "Parameter a"
        assert base_b.docstring.value == "Parameter b"

        # Inherited parameters keep their docstrings; the undocumented `c` has none.
        derived_self, derived_a, derived_b, derived_c, derived_d = list(module["Derived"].parameters)[:5]
        assert derived_self.docstring is None
        assert derived_a.docstring.value == "Parameter a"
        assert derived_b.docstring.value == "Parameter b"
        assert derived_c.docstring is None
        assert derived_d.docstring.value == "Parameter d"

432 

433 

def test_attributes_that_have_no_annotations() -> None:
    """Dataclass attributes that have no annotations are not parameters."""
    code = """
        from dataclasses import dataclass, field

        @dataclass
        class Base:
            a: int
            b: str = field(init=False)
            c = 3  # Class attribute.

        @dataclass
        class Derived(Base):
            a = 1  # No effect on the parameter status of `a`.
            b = "b"  # Inherited non-parameter.
            d: float = 4
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:

        def names_of(class_name: str) -> list[str]:
            # Parameter names of the given class, `self` included.
            return [param.name for param in module[class_name].parameters]

        assert names_of("Base") == ["self", "a"]
        assert names_of("Derived") == ["self", "a", "d"]

456 

457 

def test_name_resolution() -> None:
    """Names are correctly resolved in the scope of an object."""
    code = """
        module_attribute = 0

        class Class:
            import imported

            class_attribute = 0

            def __init__(self):
                self.instance_attribute = 0

            def method(self):
                local_variable = 0
    """
    with temporary_visited_module(code) as module:
        # Module-level scope.
        assert module.resolve("module_attribute") == "module.module_attribute"
        assert module.resolve("Class") == "module.Class"

        # A module attribute sees module-level names, but not names local to `Class`.
        module_attribute = module["module_attribute"]
        assert module_attribute.resolve("Class") == "module.Class"
        with pytest.raises(NameResolutionError):
            module_attribute.resolve("class_attribute")

        # From the class and from each of its members, every name visible in the class
        # body resolves to the same path. A scope is not asked to resolve its own name.
        resolvable = {
            "module_attribute": "module.module_attribute",
            "Class": "module.Class",
            "imported": "imported",
            "class_attribute": "module.Class.class_attribute",
            "instance_attribute": "module.Class.instance_attribute",
            "method": "module.Class.method",
        }
        for scope in ("Class", "Class.class_attribute", "Class.instance_attribute", "Class.method"):
            own_name = scope.rsplit(".", 1)[-1]
            for name, expected in resolvable.items():
                if name != own_name:
                    assert module[scope].resolve(name) == expected, (scope, name)

505 

506 

def test_set_parameters() -> None:
    """Parameters can be added and replaced, by name or by position."""
    parameters = Parameters()

    # Unknown name: the parameter is added.
    parameters["x"] = Parameter(name="x")
    assert "x" in parameters

    # Known name: the parameter is replaced, not duplicated.
    parameters["x"] = Parameter(name="x")
    assert "x" in parameters
    assert len(parameters) == 1

    # Existing position: the parameter at that index is replaced.
    parameters[0] = Parameter(name="y")
    assert "y" in parameters
    assert len(parameters) == 1

521 

522 

def test_delete_parameters() -> None:
    """Parameters can be removed either by name or by index."""
    parameters = Parameters()
    for key in ("x", 0):
        # The parameter is always added by name, then deleted through `key`.
        parameters["x"] = Parameter(name="x")
        del parameters[key]
        assert "x" not in parameters
        assert len(parameters) == 0

536 

537 

def test_not_resolving_attribute_value_to_itself() -> None:
    """An attribute whose value is a same-named local does not resolve to itself."""
    code = """
        class A:
            def __init__(self):
                x = "something"
                self.x = x
    """
    with temporary_visited_module(code) as module:
        # `self.x = x` reads the local `x`; the value must not point back to `module.A.x`.
        assert module["A.x"].value.canonical_path == "x"

549 

550 

def test_resolving_never_raises_alias_errors() -> None:
    """Name resolution returns the alias path instead of raising alias errors."""
    files = {
        "__init__.py": """
            from package.mod import pd

            class A:
                def __init__(self):
                    pass
        """,
        "mod.py": "import pandas as pd",
    }
    with temporary_visited_package("package", files) as module:
        # `pd` chains to `pandas`, which lies outside the visited package.
        assert module["A.__init__"].resolve("pd") == "package.mod.pd"

567 

568 

def test_building_function_and_class_signatures() -> None:
    """Check the signatures rendered for functions and classes."""
    # Simple function: plain parameters plus a return annotation.
    simple_params = Parameters(
        Parameter("x", annotation="int"),
        Parameter("y", annotation="int", default="0"),
    )
    simple_func = Function("simple_function", parameters=simple_params, returns="int")
    assert simple_func.signature() == "simple_function(x: int, y: int = 0) -> int"

    # Class: rendered from `__init__`, without the `-> None` return annotation.
    test_class = Class("TestClass")
    test_class.set_member("__init__", Function("__init__", parameters=simple_params, returns="None"))
    assert test_class.signature() == "TestClass(x: int, y: int = 0)"

    # One parameter of every kind, which requires the `/` separator after positional-only ones.
    all_kinds = Parameters(
        Parameter("a", kind=ParameterKind.positional_only),
        Parameter("b", kind=ParameterKind.positional_only, annotation="int", default="0"),
        Parameter("c", kind=ParameterKind.positional_or_keyword),
        Parameter("d", kind=ParameterKind.positional_or_keyword, annotation="str", default="''"),
        Parameter("args", kind=ParameterKind.var_positional),
        Parameter("e", kind=ParameterKind.keyword_only),
        Parameter("f", kind=ParameterKind.keyword_only, annotation="bool", default="False"),
        Parameter("kwargs", kind=ParameterKind.var_keyword),
    )
    complex_func = Function("test_function", parameters=all_kinds, returns="None")
    assert complex_func.signature() == (
        "test_function(a, b: int = 0, /, c, d: str = '', *args, e, f: bool = False, **kwargs) -> None"
    )

600 

601 

def test_set_type_parameters() -> None:
    """We can set type parameters by name or by index, replacing existing ones."""
    type_parameters = TypeParameters()
    # Does not exist yet.
    type_parameters["x"] = TypeParameter(name="x", kind=TypeParameterKind.type_var)
    assert "x" in type_parameters
    # Already exists, by name.
    type_parameters["x"] = TypeParameter(name="x", kind=TypeParameterKind.type_var)
    assert "x" in type_parameters
    assert len(type_parameters) == 1
    # Already exists, by name, with different kind: the new definition must replace the old one.
    type_parameters["x"] = TypeParameter(name="x", kind=TypeParameterKind.param_spec)
    assert "x" in type_parameters
    assert len(type_parameters) == 1
    # Without this check, a setter that silently kept the old `type_var` would still pass.
    assert type_parameters["x"].kind is TypeParameterKind.param_spec
    # Already exists, by index: `y` takes the place of `x`.
    type_parameters[0] = TypeParameter(name="y", kind=TypeParameterKind.type_var)
    assert "y" in type_parameters
    assert "x" not in type_parameters
    assert len(type_parameters) == 1

620 

621 

def test_delete_type_parameters() -> None:
    """Type parameters can be removed either by name or by index."""
    type_parameters = TypeParameters()
    for key in ("x", 0):
        # The type parameter is always added by name, then deleted through `key`.
        type_parameters["x"] = TypeParameter(name="x", kind=TypeParameterKind.type_var)
        del type_parameters[key]
        assert "x" not in type_parameters
        assert len(type_parameters) == 0

635 

636 

# YORE: EOL 3.11: Remove line.
@pytest.mark.skipif(sys.version_info < (3, 12), reason="Python less than 3.12 does not have PEP 695 generics")
def test_annotation_resolution() -> None:
    """Type parameters resolve within the annotation scope of the object declaring them."""
    code = """
        class C[T]:
            class D[T]:
                def func[Y](self, arg1: T, arg2: Y): pass
            def func[Z](arg1: T, arg2: Y): pass
    """
    with temporary_visited_module(code) as module:
        # Inside `D`, its own `T` shadows the one declared by `C`.
        assert module["C.D"].resolve("T") == "module.C.D[T]"
        assert module["C.D.func"].resolve("T") == "module.C.D[T]"
        assert module["C.D.func"].resolve("Y") == "module.C.D.func[Y]"

        assert module["C"].resolve("T") == "module.C[T]"
        assert module["C.func"].resolve("T") == "module.C[T]"
        # `Y` is only declared by `C.D.func`, so it is not visible from `C.func`.
        with pytest.raises(NameResolutionError):
            module["C.func"].resolve("Y")