Coverage for tests/test_models.py: 100.00%

189 statements  

« prev     ^ index     » next       coverage.py v7.6.2, created at 2024-10-12 01:34 +0200

1"""Tests for the `dataclasses` module.""" 

2 

3from __future__ import annotations 

4 

5from copy import deepcopy 

6from textwrap import dedent 

7 

8import pytest 

9 

10from griffe import ( 

11 Attribute, 

12 Docstring, 

13 GriffeLoader, 

14 Module, 

15 NameResolutionError, 

16 module_vtree, 

17 temporary_inspected_module, 

18 temporary_pypackage, 

19 temporary_visited_module, 

20 temporary_visited_package, 

21) 

22 

23 

def test_submodule_exports() -> None:
    """Verify when a submodule and a private attribute become wildcard-exposed.

    The submodule is expected to be exposed only after the parent module
    records an import of it, and the private attribute only after the parent
    lists it in its exports.
    """
    parent = Module("root")
    child = Module("sub")
    hidden = Attribute("_private")
    parent["sub"] = child
    parent["_private"] = hidden

    # Submodule: not exposed until the parent also imports it.
    assert not child.is_wildcard_exposed
    parent.imports["sub"] = "root.sub"
    assert child.is_wildcard_exposed

    # Private attribute: not exposed until explicitly exported.
    assert not hidden.is_wildcard_exposed
    parent.exports = {"_private"}
    assert hidden.is_wildcard_exposed

39 

40 

def test_has_docstrings() -> None:
    """Check that `has_docstrings` takes the docstrings of members into account."""
    # Only the class carries a docstring here, not the module itself.
    source = "class A:\n '''Hello.'''"
    with temporary_visited_module(source) as visited:
        assert visited.has_docstrings

45 

46 

def test_has_docstrings_submodules() -> None:
    """Check that `has_docstrings` looks into nested submodules."""
    tree = module_vtree("a.b.c.d")
    # Attach a docstring to the deepest submodule only.
    deepest = tree["b.c.d"]
    deepest.docstring = Docstring("Hello.")
    assert tree.has_docstrings

52 

53 

def test_handle_aliases_chain_in_has_docstrings() -> None:
    """Check that `has_docstrings` copes with chains of aliases among members.

    `mod_a` re-imports `someobj` from `mod_b`, which itself imports it from
    `somelib`. The package must report no docstrings both before and after
    alias resolution.
    """
    module_sources = {
        "mod_a.py": "from .mod_b import someobj",
        "mod_b.py": "from somelib import someobj",
    }
    with temporary_pypackage("package", ["mod_a.py", "mod_b.py"]) as tmp_package:
        for filename, content in module_sources.items():
            (tmp_package.path / filename).write_text(content)

        loader = GriffeLoader(search_paths=[tmp_package.tmpdir])
        loaded = loader.load(tmp_package.name)
        assert not loaded.has_docstrings
        loader.resolve_aliases(implicit=True)
        assert not loaded.has_docstrings

67 

68 

def test_has_docstrings_does_not_trigger_alias_resolution() -> None:
    """Check that reading `has_docstrings` leaves aliases unresolved."""
    with temporary_pypackage("package", ["mod_a.py", "mod_b.py"]) as tmp_package:
        package_dir = tmp_package.path
        (package_dir / "mod_a.py").write_text("from .mod_b import someobj")
        (package_dir / "mod_b.py").write_text("from somelib import someobj")

        loaded = GriffeLoader(search_paths=[tmp_package.tmpdir]).load(tmp_package.name)
        assert not loaded.has_docstrings
        # The alias in `mod_a` must still be unresolved after the check above.
        alias = loaded["mod_a.someobj"]
        assert not alias.resolved

81 

82 

def test_deepcopy() -> None:
    """Check that a loaded object tree and its dict form can be deep-copied."""
    griffe_module = GriffeLoader().load("griffe")

    # There is nothing to compare: the test passes when neither copy raises.
    deepcopy(griffe_module)
    deepcopy(griffe_module.as_dict())

90 

91 

def test_dataclass_properties_and_class_variables() -> None:
    """Properties and `ClassVar` attributes must not appear as dataclass parameters."""
    source = """
        from dataclasses import dataclass
        from functools import cached_property
        from typing import ClassVar

        @dataclass
        class Point:
            x: float
            y: float

            # These definitions create class variables
            r: ClassVar[float]
            s: float = 3
            t: ClassVar[float] = 3

            @property
            def a(self):
                return 0

            @cached_property
            def b(self):
                return 0
    """
    with temporary_visited_package("package", {"__init__.py": source}) as module:
        # Only the plain annotated fields are expected, after `self`.
        parameter_names = [param.name for param in module["Point"].parameters]
        assert parameter_names == ["self", "x", "y", "s"]

120 

121 

@pytest.mark.parametrize(
    "code",
    [
        """
        @dataclass
        class Dataclass:
            x: float
            y: float = field(kw_only=True)

        class Class:
            def __init__(self, x: float, *, y: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            x: float = field(kw_only=True)
            y: float

        class Class:
            def __init__(self, y: float, *, x: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            x: float
            _: KW_ONLY
            y: float

        class Class:
            def __init__(self, x: float, *, y: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            _: KW_ONLY
            x: float
            y: float

        class Class:
            def __init__(self, *, x: float, y: float): ...
        """,
        """
        @dataclass(kw_only=True)
        class Dataclass:
            x: float
            y: float

        class Class:
            def __init__(self, *, x: float, y: float): ...
        """,
    ],
)
def test_dataclass_parameter_kinds(code: str) -> None:
    """Check dataclass and equivalent non-dataclass parameters.

    Each snippet defines a `Dataclass` and a regular `Class` whose explicit
    `__init__` spells out the expected signature. Both must expose the same
    number of parameters, and each pair of parameters must compare equal.

    Parameters:
        code: Python code to visit.
    """
    code = f"from dataclasses import dataclass, field, KW_ONLY\n\n{dedent(code)}"
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        dataclass_params = list(module["Dataclass"].parameters)
        regular_params = list(module["Class"].parameters)
        # `zip` stops at the shorter input, so a missing or extra parameter
        # would go unnoticed without comparing the lengths first.
        assert len(dataclass_params) == len(regular_params)
        for dataclass_param, regular_param in zip(dataclass_params, regular_params):
            assert dataclass_param == regular_param

186 

187 

def test_regular_class_inheriting_dataclass_dont_get_its_own_params() -> None:
    """A regular class inheriting from a dataclass doesn't have its own attributes added to `__init__`.

    A subclass that is itself decorated with `@dataclass` does get its
    fields appended to the inherited parameters.
    """
    source = """
        from dataclasses import dataclass

        @dataclass
        class Base:
            a: int
            b: str

        @dataclass
        class Derived1(Base):
            c: float

        class Derived2(Base):
            d: float
    """
    expected_names = {
        "Derived1": ["self", "a", "b", "c"],
        "Derived2": ["self", "a", "b"],
    }
    with temporary_visited_package("package", {"__init__.py": source}) as module:
        for class_name, names in expected_names.items():
            assert [param.name for param in module[class_name].parameters] == names

210 

211 

def test_regular_class_inheriting_dataclass_is_labelled_dataclass() -> None:
    """A plain subclass of a dataclass must carry the `dataclass` label as well."""
    source = """
        from dataclasses import dataclass

        @dataclass
        class Base:
            pass

        class Derived(Base):
            pass
    """
    with temporary_visited_package("package", {"__init__.py": source}) as module:
        derived_labels = module["Derived"].labels
        assert "dataclass" in derived_labels

227 

228 

def test_fields_with_init_false() -> None:
    """Fields excluded through `init=False` must be absent from the parameters.

    This covers both a single field declared with `field(init=False)` and
    whole classes decorated with `@dataclass(init=False)`.
    """
    source = """
        from dataclasses import dataclass, field

        @dataclass
        class PointA:
            x: float
            y: float
            z: float = field(init=False)

        @dataclass(init=False)
        class PointB:
            x: float
            y: float

        @dataclass(init=False)
        class PointC:
            x: float
            y: float = field(init=True)  # init=True has no effect
    """
    with temporary_visited_package("package", {"__init__.py": source}) as module:
        point_a_params = module["PointA"].parameters
        point_b_params = module["PointB"].parameters
        point_c_params = module["PointC"].parameters

        # Pairs of (parameters, name that must not be among them).
        excluded = (
            (point_a_params, "z"),
            (point_b_params, "x"),
            (point_b_params, "y"),
            (point_c_params, "x"),
            (point_c_params, "y"),
        )
        for parameters, name in excluded:
            assert name not in parameters

260 

261 

def test_parameters_are_reorderd_to_match_their_kind() -> None:
    """Keyword-only parameters from a base class must end up last in the signature.

    The subclass redefines `b` as a regular field, so `b` moves to the front
    with its new annotation while the inherited keyword-only `a` goes last.
    """
    source = """
        from dataclasses import dataclass

        @dataclass(kw_only=True)
        class Base:
            a: int
            b: str

        @dataclass
        class Reordered(Base):
            b: float
            c: float
    """
    with temporary_visited_package("package", {"__init__.py": source}) as module:
        base_parameters = module["Base"].parameters
        reordered_parameters = module["Reordered"].parameters

        base_names = [param.name for param in base_parameters]
        assert base_names == ["self", "a", "b"]

        reordered_names = [param.name for param in reordered_parameters]
        assert reordered_names == ["self", "b", "c", "a"]

        overridden_annotation = reordered_parameters["b"].annotation
        assert str(overridden_annotation) == "float"

283 

284 

def test_parameters_annotated_as_initvar() -> None:
    """`InitVar` fields are parameters, not class members.

    When the class defines its own `__init__`, `InitVar` has no effect:
    the field stays a member and the explicit signature is used.
    """
    source = """
        from dataclasses import dataclass, InitVar

        @dataclass
        class PointA:
            x: float
            y: float
            z: InitVar[float]

        @dataclass
        class PointB:
            x: float
            y: float
            z: InitVar[float]

            def __init__(self, r: float): ...
    """

    with temporary_visited_package("package", {"__init__.py": source}) as module:
        # Generated `__init__`: `z` is a parameter but not a member.
        generated = module["PointA"]
        generated_names = [param.name for param in generated.parameters]
        assert generated_names == ["self", "x", "y", "z"]
        assert list(generated.members) == ["x", "y", "__init__"]

        # Explicit `__init__`: its own signature wins and `z` stays a member.
        explicit = module["PointB"]
        explicit_names = [param.name for param in explicit.parameters]
        assert explicit_names == ["self", "r"]
        assert list(explicit.members) == ["x", "y", "z", "__init__"]

316 

317 

def test_visited_module_source() -> None:
    """The `source` of a visited module must equal the code it was created from."""
    expected = "print('hello')\nprint('world')"
    with temporary_visited_package("package", {"__init__.py": expected}) as visited:
        assert visited.source == expected

323 

324 

def test_visited_class_source() -> None:
    """The `source` of a visited class must equal its dedented, stripped definition."""
    snippet = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    expected = dedent(snippet).strip()
    with temporary_visited_package("package", {"__init__.py": snippet}) as module:
        assert module["A"].source == expected

334 

335 

def test_visited_object_source_with_missing_line_number() -> None:
    """The `source` of a visited object must be empty when a line number is missing."""
    snippet = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    with temporary_visited_package("package", {"__init__.py": snippet}) as module:
        klass = module["A"]

        # Missing end line.
        klass.endlineno = None
        assert not klass.source

        # End line restored, start line missing.
        klass.endlineno = 3
        klass.lineno = None
        assert not klass.source

349 

350 

def test_inspected_module_source() -> None:
    """The `source` of an inspected module must equal the code it was created from."""
    expected = "print('hello')\nprint('world')"
    with temporary_inspected_module(expected) as inspected:
        assert inspected.source == expected

356 

357 

def test_inspected_class_source() -> None:
    """The `source` of an inspected class must equal its dedented, stripped definition."""
    snippet = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    expected = dedent(snippet).strip()
    with temporary_inspected_module(snippet) as module:
        assert module["A"].source == expected

367 

368 

def test_inspected_object_source_with_missing_line_number() -> None:
    """The `source` of an inspected object must be empty when a line number is missing."""
    snippet = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    with temporary_inspected_module(snippet) as module:
        klass = module["A"]

        # Missing end line.
        klass.endlineno = None
        assert not klass.source

        # End line restored, start line missing.
        klass.endlineno = 3
        klass.lineno = None
        assert not klass.source

382 

383 

def test_dataclass_parameter_docstrings() -> None:
    """Dataclass parameters must expose the docstring written under their field.

    `self` and fields without a docstring have none; inherited parameters
    keep the docstring from the base class.
    """
    source = """
        from dataclasses import dataclass, InitVar

        @dataclass
        class Base:
            a: int
            "Parameter a"
            b: InitVar[int] = 3
            "Parameter b"

        @dataclass
        class Derived(Base):
            c: float
            d: InitVar[float]
            "Parameter d"
    """

    with temporary_visited_package("package", {"__init__.py": source}) as module:
        # Expected layout: self, a, b.
        base_parameters = module["Base"].parameters
        assert base_parameters[0].docstring is None
        for index, text in ((1, "Parameter a"), (2, "Parameter b")):
            assert base_parameters[index].docstring.value == text

        # Expected layout: self, a, b, c, d.
        derived_parameters = module["Derived"].parameters
        assert derived_parameters[0].docstring is None
        for index, text in ((1, "Parameter a"), (2, "Parameter b")):
            assert derived_parameters[index].docstring.value == text
        assert derived_parameters[3].docstring is None
        assert derived_parameters[4].docstring.value == "Parameter d"

423 

424 

def test_attributes_that_have_no_annotations() -> None:
    """Dataclass attributes that have no annotations are not parameters."""
    source = """
        from dataclasses import dataclass, field

        @dataclass
        class Base:
            a: int
            b: str = field(init=False)
            c = 3  # class attribute

        @dataclass
        class Derived(Base):
            a = 1  # no effect on the parameter status of a
            b = "b"  # inherited non-parameter
            d: float = 4
    """
    expected_names = {
        "Base": ["self", "a"],
        "Derived": ["self", "a", "d"],
    }
    with temporary_visited_package("package", {"__init__.py": source}) as module:
        actual_names = {
            class_name: [param.name for param in module[class_name].parameters]
            for class_name in ("Base", "Derived")
        }
        assert actual_names["Base"] == expected_names["Base"]
        assert actual_names["Derived"] == expected_names["Derived"]

447 

448 

449def test_name_resolution() -> None: 

450 """Name are correctly resolved in the scope of an object.""" 

451 code = """ 

452 module_attribute = 0 

453 

454 class Class: 

455 import imported 

456 

457 class_attribute = 0 

458 

459 def __init__(self): 

460 self.instance_attribute = 0 

461 

462 def method(self): 

463 local_variable = 0 

464 """ 

465 with temporary_visited_module(code) as module: 

466 assert module.resolve("module_attribute") == "module.module_attribute" 

467 assert module.resolve("Class") == "module.Class" 

468 

469 assert module["module_attribute"].resolve("Class") == "module.Class" 

470 with pytest.raises(NameResolutionError): 

471 module["module_attribute"].resolve("class_attribute") 

472 

473 assert module["Class"].resolve("module_attribute") == "module.module_attribute" 

474 assert module["Class"].resolve("imported") == "imported" 

475 assert module["Class"].resolve("class_attribute") == "module.Class.class_attribute" 

476 assert module["Class"].resolve("instance_attribute") == "module.Class.instance_attribute" 

477 assert module["Class"].resolve("method") == "module.Class.method" 

478 

479 assert module["Class.class_attribute"].resolve("module_attribute") == "module.module_attribute" 

480 assert module["Class.class_attribute"].resolve("Class") == "module.Class" 

481 assert module["Class.class_attribute"].resolve("imported") == "imported" 

482 assert module["Class.class_attribute"].resolve("instance_attribute") == "module.Class.instance_attribute" 

483 assert module["Class.class_attribute"].resolve("method") == "module.Class.method" 

484 

485 assert module["Class.instance_attribute"].resolve("module_attribute") == "module.module_attribute" 

486 assert module["Class.instance_attribute"].resolve("Class") == "module.Class" 

487 assert module["Class.instance_attribute"].resolve("imported") == "imported" 

488 assert module["Class.instance_attribute"].resolve("class_attribute") == "module.Class.class_attribute" 

489 assert module["Class.instance_attribute"].resolve("method") == "module.Class.method" 

490 

491 assert module["Class.method"].resolve("module_attribute") == "module.module_attribute" 

492 assert module["Class.method"].resolve("Class") == "module.Class" 

493 assert module["Class.method"].resolve("imported") == "imported" 

494 assert module["Class.method"].resolve("class_attribute") == "module.Class.class_attribute" 

495 assert module["Class.method"].resolve("instance_attribute") == "module.Class.instance_attribute"