Coverage for src/griffe_pydantic/common.py: 100.00%
26 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-18 01:11 +0100
1"""Griffe extension for Pydantic."""
3from __future__ import annotations
5import json
6from functools import partial
7from typing import TYPE_CHECKING
9if TYPE_CHECKING:
10 from collections.abc import Sequence
12 from griffe import Attribute, Class, Function
13 from pydantic import BaseModel
# Key under which this extension stores its data in a Griffe object's `extra` mapping.
self_namespace = "griffe_pydantic"
# Key under which mkdocstrings-specific data (e.g. the template name) is stored in `extra`.
mkdocstrings_namespace = "mkdocstrings"

# Names of the constraint keyword arguments of `pydantic.Field`.
# Spelling must match Pydantic exactly, otherwise the constraint is never
# recognised (note the plural `decimal_places`).
field_constraints = {
    "gt",
    "ge",
    "lt",
    "le",
    "multiple_of",
    "min_length",
    "max_length",
    "pattern",
    "allow_inf_nan",
    "max_digits",
    "decimal_places",
}
def _model_fields(cls: Class) -> dict[str, Attribute]:
    """Return the members of `cls` labelled `pydantic-field`, keyed by member name.

    Members keep the iteration order of `cls.members`.
    """
    fields: dict[str, Attribute] = {}
    for member_name, member in cls.members.items():
        if "pydantic-field" in member.labels:
            fields[member_name] = member  # type: ignore[assignment]
    return fields
def _model_validators(cls: Class) -> dict[str, Function]:
    """Return the members of `cls` labelled `pydantic-validator`, keyed by member name.

    Members keep the iteration order of `cls.members`.
    """
    validators: dict[str, Function] = {}
    for member_name, member in cls.members.items():
        if "pydantic-validator" not in member.labels:
            continue
        validators[member_name] = member  # type: ignore[assignment]
    return validators
def json_schema(model: type[BaseModel]) -> str:
    """Serialize a Pydantic model's JSON schema to a string.

    Parameters:
        model: A Pydantic model.

    Returns:
        The result of `model.model_json_schema()` dumped as JSON with an
        indentation of two spaces.
    """
    schema = model.model_json_schema()
    return json.dumps(schema, indent=2)
def process_class(cls: Class) -> None:
    """Mark a Griffe class as a Pydantic model and attach its metadata.

    Adds the `pydantic-model` label and, in this extension's namespace of
    `cls.extra`, stores `fields` and `validators` as partials. The partials
    filter `cls.members` only when they are called. It also selects the
    `pydantic_model.html.jinja` template in the mkdocstrings namespace.

    Parameters:
        cls: The Griffe class representing the Pydantic model.
    """
    cls.labels.add("pydantic-model")

    extension_data = cls.extra[self_namespace]
    # Partials defer member filtering until the accessor is invoked.
    extension_data["fields"] = partial(_model_fields, cls)
    extension_data["validators"] = partial(_model_validators, cls)

    cls.extra[mkdocstrings_namespace]["template"] = "pydantic_model.html.jinja"
def process_function(func: Function, cls: Class, fields: Sequence[str]) -> None:
    """Set metadata on a Pydantic validator.

    Replaces the validator's labels with `{"pydantic-validator"}` and links
    the validator and its target fields in both directions, inside this
    extension's namespace of each object's `extra` mapping:
    - the validator's `targets` list is extended with the targeted field members;
    - each targeted field's `validators` list gets the validator appended.

    Parameters:
        func: A Griffe function representing the Pydantic validator.
        cls: The Griffe class whose members contain the targeted fields.
        fields: Names of the fields the validator applies to. Each name is
            looked up directly in `cls.members`.

    Raises:
        KeyError: If a name in `fields` is not a key of `cls.members`.
    """
    # Assignment (not `.add`): any labels the function previously had are dropped.
    func.labels = {"pydantic-validator"}
    targets = [cls.members[field] for field in fields]

    # `setdefault` returns the stored list, so it can be extended/appended in place.
    func.extra[self_namespace].setdefault("targets", []).extend(targets)
    for target in targets:
        target.extra[self_namespace].setdefault("validators", []).append(func)