Coverage for tests/test_models.py: 100.00%

161 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-15 16:47 +0200

"""Tests for the `models` module."""

2 

3from __future__ import annotations 

4 

5from copy import deepcopy 

6from textwrap import dedent 

7 

8import pytest 

9 

10from griffe import ( 

11 Attribute, 

12 Docstring, 

13 GriffeLoader, 

14 Module, 

15 module_vtree, 

16 temporary_inspected_module, 

17 temporary_pypackage, 

18 temporary_visited_module, 

19 temporary_visited_package, 

20) 

21 

22 

def test_submodule_exports() -> None:
    """Verify wildcard exposure of a submodule and of a private member.

    A submodule only becomes wildcard-exposed once its parent imports it,
    and a private attribute only once the parent lists it in its exports.
    """
    parent = Module("root")
    child = Module("sub")
    hidden = Attribute("_private")
    parent["sub"] = child
    parent["_private"] = hidden

    # Submodule: hidden until the parent module imports it.
    assert not child.is_wildcard_exposed
    parent.imports["sub"] = "root.sub"
    assert child.is_wildcard_exposed

    # Private attribute: hidden until explicitly exported.
    assert not hidden.is_wildcard_exposed
    parent.exports = {"_private"}
    assert hidden.is_wildcard_exposed

38 

39 

def test_has_docstrings() -> None:
    """Check that the `has_docstrings` property also looks at a module's members."""
    source = "class A:\n    '''Hello.'''"
    with temporary_visited_module(source) as visited:
        # Only the class is documented; the module itself has no docstring.
        assert visited.has_docstrings

44 

45 

def test_has_docstrings_submodules() -> None:
    """Check that the `has_docstrings` property descends into nested submodules."""
    root = module_vtree("a.b.c.d")
    # Only the deepest submodule receives a docstring.
    deepest = root["b.c.d"]
    deepest.docstring = Docstring("Hello.")
    assert root.has_docstrings

51 

52 

def test_handle_aliases_chain_in_has_docstrings() -> None:
    """Check that the `has_docstrings` property copes with chains of aliases.

    `mod_a` re-imports `someobj` from `mod_b`, which itself imports it from
    `somelib`, a name that is not part of the temporary package.
    """
    with temporary_pypackage("package", ["mod_a.py", "mod_b.py"]) as tmp_package:
        contents = {
            "mod_a.py": "from .mod_b import someobj",
            "mod_b.py": "from somelib import someobj",
        }
        for filename, text in contents.items():
            (tmp_package.path / filename).write_text(text)

        loader = GriffeLoader(search_paths=[tmp_package.tmpdir])
        package = loader.load(tmp_package.name)
        # Nothing is documented, whether or not aliases have been resolved.
        assert not package.has_docstrings
        loader.resolve_aliases(implicit=True)
        assert not package.has_docstrings

66 

67 

def test_has_docstrings_does_not_trigger_alias_resolution() -> None:
    """Check that reading the `has_docstrings` property leaves aliases unresolved."""
    with temporary_pypackage("package", ["mod_a.py", "mod_b.py"]) as tmp_package:
        first = tmp_package.path / "mod_a.py"
        second = tmp_package.path / "mod_b.py"
        first.write_text("from .mod_b import someobj")
        second.write_text("from somelib import someobj")

        loader = GriffeLoader(search_paths=[tmp_package.tmpdir])
        package = loader.load(tmp_package.name)
        assert not package.has_docstrings

        # The property must not have resolved the alias as a side effect.
        alias = package["mod_a.someobj"]
        assert not alias.resolved

80 

81 

def test_deepcopy() -> None:
    """Check that a loaded object tree, and its dictionary form, can be deep-copied."""
    loader = GriffeLoader()
    griffe_module = loader.load("griffe")

    # The test passes as long as neither copy raises.
    deepcopy(griffe_module)
    serialized = griffe_module.as_dict()
    deepcopy(serialized)

89 

90 

def test_dataclass_properties_and_class_variables() -> None:
    """Check that properties and `ClassVar` fields are not dataclass parameters."""
    code = """
        from dataclasses import dataclass
        from functools import cached_property
        from typing import ClassVar

        @dataclass
        class Point:
            x: float
            y: float

            # These definitions create class variables
            r: ClassVar[float]
            s: float = 3
            t: ClassVar[float] = 3

            @property
            def a(self):
                return 0

            @cached_property
            def b(self):
                return 0
    """
    files = {"__init__.py": code}
    with temporary_visited_package("package", files) as module:
        names = [parameter.name for parameter in module["Point"].parameters]
        # `s` has a plain (non-ClassVar) annotation with a default, so it stays a parameter.
        assert names == ["self", "x", "y", "s"]

119 

120 

@pytest.mark.parametrize(
    "code",
    [
        """
        @dataclass
        class Dataclass:
            x: float
            y: float = field(kw_only=True)

        class Class:
            def __init__(self, x: float, *, y: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            x: float = field(kw_only=True)
            y: float

        class Class:
            def __init__(self, y: float, *, x: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            x: float
            _: KW_ONLY
            y: float

        class Class:
            def __init__(self, x: float, *, y: float): ...
        """,
        """
        @dataclass
        class Dataclass:
            _: KW_ONLY
            x: float
            y: float

        class Class:
            def __init__(self, *, x: float, y: float): ...
        """,
        """
        @dataclass(kw_only=True)
        class Dataclass:
            x: float
            y: float

        class Class:
            def __init__(self, *, x: float, y: float): ...
        """,
    ],
)
def test_dataclass_parameter_kinds(code: str) -> None:
    """Check dataclass and equivalent non-dataclass parameters.

    Each pair must expose the same number of parameters, and the parameters
    must compare equal position by position (so kinds and ordering match).

    Parameters:
        code: Python code to visit.
    """
    code = f"from dataclasses import dataclass, field, KW_ONLY\n\n{dedent(code)}"
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        dataclass_params = list(module["Dataclass"].parameters)
        regular_params = list(module["Class"].parameters)
        # `zip` stops silently at the shorter sequence, so without this check
        # a missing or extra parameter on either side would go unnoticed.
        assert len(dataclass_params) == len(regular_params)
        for dataclass_param, regular_param in zip(dataclass_params, regular_params):
            assert dataclass_param == regular_param

185 

186 

def test_regular_class_inheriting_dataclass_dont_get_its_own_params() -> None:
    """Check how dataclass fields propagate to subclasses' `__init__`.

    A dataclass subclass appends its own fields after the inherited ones,
    whereas a regular (non-dataclass) subclass keeps the inherited signature
    and doesn't get its own annotated attributes added.
    """
    code = """
        from dataclasses import dataclass

        @dataclass
        class Base:
            a: int
            b: str

        @dataclass
        class Derived1(Base):
            c: float

        class Derived2(Base):
            d: float
    """
    expected = {
        "Derived1": ["self", "a", "b", "c"],
        "Derived2": ["self", "a", "b"],
    }
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        for class_name, names in expected.items():
            assert [param.name for param in module[class_name].parameters] == names

209 

210 

def test_regular_class_inheriting_dataclass_is_labelled_dataclass() -> None:
    """Check that a plain subclass of a dataclass also carries the `dataclass` label."""
    code = """
        from dataclasses import dataclass

        @dataclass
        class Base:
            pass

        class Derived(Base):
            pass
    """
    files = {"__init__.py": code}
    with temporary_visited_package("package", files) as module:
        assert "dataclass" in module["Derived"].labels

226 

227 

def test_fields_with_init_false() -> None:
    """Check that `init=False` keeps fields out of the `__init__` parameters.

    This applies per field via `field(init=False)`, or to the whole class via
    `@dataclass(init=False)`, in which case `field(init=True)` has no effect.
    """
    code = """
        from dataclasses import dataclass, field

        @dataclass
        class PointA:
            x: float
            y: float
            z: float = field(init=False)

        @dataclass(init=False)
        class PointB:
            x: float
            y: float

        @dataclass(init=False)
        class PointC:
            x: float
            y: float = field(init=True)  # init=True has no effect
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        params_by_class = {name: module[name].parameters for name in ("PointA", "PointB", "PointC")}

        assert "z" not in params_by_class["PointA"]
        # With `init=False` on the decorator, no field becomes a parameter.
        for class_name in ("PointB", "PointC"):
            for field_name in ("x", "y"):
                assert field_name not in params_by_class[class_name]

259 

260 

def test_parameters_are_reorderd_to_match_their_kind() -> None:
    """Check that inherited keyword-only parameters move to the end of the signature.

    A field redeclared in the subclass takes the subclass's annotation.
    """
    code = """
        from dataclasses import dataclass

        @dataclass(kw_only=True)
        class Base:
            a: int
            b: str

        @dataclass
        class Reordered(Base):
            b: float
            c: float
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        base_names = [param.name for param in module["Base"].parameters]
        reordered = module["Reordered"].parameters
        reordered_names = [param.name for param in reordered]

        assert base_names == ["self", "a", "b"]
        # `b` is redeclared without `kw_only`, so only `a` stays keyword-only and goes last.
        assert reordered_names == ["self", "b", "c", "a"]
        assert str(reordered["b"].annotation) == "float"

282 

283 

def test_parameters_annotated_as_initvar() -> None:
    """Check that `InitVar` fields are parameters but not class members.

    When the class defines its own `__init__`, `InitVar` has no effect:
    the explicit signature is used and the field remains a member.
    """
    code = """
        from dataclasses import dataclass, InitVar

        @dataclass
        class PointA:
            x: float
            y: float
            z: InitVar[float]

        @dataclass
        class PointB:
            x: float
            y: float
            z: InitVar[float]

            def __init__(self, r: float): ...
    """

    with temporary_visited_package("package", {"__init__.py": code}) as module:
        # Generated `__init__`: `z` is a parameter, but not a member.
        generated = module["PointA"]
        generated_names = [param.name for param in generated.parameters]
        assert generated_names == ["self", "x", "y", "z"]
        assert list(generated.members) == ["x", "y", "__init__"]

        # Explicit `__init__`: its own signature wins and `z` stays a member.
        explicit = module["PointB"]
        explicit_names = [param.name for param in explicit.parameters]
        assert explicit_names == ["self", "r"]
        assert list(explicit.members) == ["x", "y", "z", "__init__"]

315 

316 

def test_visited_module_source() -> None:
    """Check that a visited module's `source` is the full text of its file."""
    code = "print('hello')\nprint('world')"
    files = {"__init__.py": code}
    with temporary_visited_package("package", files) as visited:
        assert visited.source == code

322 

323 

def test_visited_class_source() -> None:
    """Check that a visited class's `source` spans exactly its definition."""
    code = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    expected = dedent(code).strip()
    with temporary_visited_package("package", {"__init__.py": code}) as visited:
        assert visited["A"].source == expected

333 

334 

def test_visited_object_source_with_missing_line_number() -> None:
    """Check that a visited object's `source` is empty when a line number is missing."""
    code = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    with temporary_visited_package("package", {"__init__.py": code}) as visited:
        class_a = visited["A"]

        # No end line number: no source.
        class_a.endlineno = None
        assert not class_a.source

        # End line number restored but no start line number: still no source.
        class_a.endlineno = 3
        class_a.lineno = None
        assert not class_a.source

348 

349 

def test_inspected_module_source() -> None:
    """Check that an inspected module's `source` is the full text it was built from."""
    code = "print('hello')\nprint('world')"
    with temporary_inspected_module(code) as inspected:
        assert inspected.source == code

355 

356 

def test_inspected_class_source() -> None:
    """Check that an inspected class's `source` spans exactly its definition."""
    code = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    expected = dedent(code).strip()
    with temporary_inspected_module(code) as inspected:
        assert inspected["A"].source == expected

366 

367 

def test_inspected_object_source_with_missing_line_number() -> None:
    """Check that an inspected object's `source` is empty when a line number is missing."""
    code = """
        class A:
            def __init__(self, x: int):
                self.x = x
    """
    with temporary_inspected_module(code) as inspected:
        class_a = inspected["A"]

        # No end line number: no source.
        class_a.endlineno = None
        assert not class_a.source

        # End line number restored but no start line number: still no source.
        class_a.endlineno = 3
        class_a.lineno = None
        assert not class_a.source

381 

382 

def test_dataclass_parameter_docstrings() -> None:
    """Check that dataclass parameters carry the docstrings of their fields.

    This includes `InitVar` fields and fields inherited from a base dataclass.
    `self` and undocumented fields have no docstring.
    """
    code = """
        from dataclasses import dataclass, InitVar

        @dataclass
        class Base:
            a: int
            "Parameter a"
            b: InitVar[int] = 3
            "Parameter b"

        @dataclass
        class Derived(Base):
            c: float
            d: InitVar[float]
            "Parameter d"
    """

    with temporary_visited_package("package", {"__init__.py": code}) as module:
        base = module["Base"]
        base_self, base_a, base_b = (base.parameters[index] for index in range(3))
        assert base_self.docstring is None
        assert base_a.docstring.value == "Parameter a"
        assert base_b.docstring.value == "Parameter b"

        derived = module["Derived"]
        derived_self, derived_a, derived_b, derived_c, derived_d = (
            derived.parameters[index] for index in range(5)
        )
        assert derived_self.docstring is None
        # Inherited parameters keep the docstrings declared in the base class.
        assert derived_a.docstring.value == "Parameter a"
        assert derived_b.docstring.value == "Parameter b"
        assert derived_c.docstring is None
        assert derived_d.docstring.value == "Parameter d"

422 

423 

def test_attributes_that_have_no_annotations() -> None:
    """Check that dataclass attributes without annotations are not parameters.

    Unannotated reassignments in a subclass do not change whether an
    inherited field is a parameter.
    """
    code = """
        from dataclasses import dataclass, field

        @dataclass
        class Base:
            a: int
            b: str = field(init=False)
            c = 3  # class attribute

        @dataclass
        class Derived(Base):
            a = 1  # no effect on the parameter status of a
            b = "b"  # inherited non-parameter
            d: float = 4
    """
    with temporary_visited_package("package", {"__init__.py": code}) as module:
        base_names, derived_names = (
            [param.name for param in module[class_name].parameters] for class_name in ("Base", "Derived")
        )
        assert base_names == ["self", "a"]
        assert derived_names == ["self", "a", "d"]