|
2 | 2 |
|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
5 |
| -from typing import TYPE_CHECKING |
| 5 | +from typing import Any, Callable |
6 | 6 |
|
7 | 7 | from griffe import (
|
8 | 8 | Attribute,
|
9 | 9 | Class,
|
10 | 10 | Docstring,
|
11 | 11 | Function,
|
| 12 | + Kind, |
12 | 13 | get_logger,
|
13 | 14 | )
|
14 | 15 | from pydantic.fields import FieldInfo
|
15 | 16 |
|
16 | 17 | from griffe_pydantic import common
|
17 | 18 |
|
18 |
| -if TYPE_CHECKING: |
19 |
| - from griffe import ObjectNode |
20 |
| - |
21 | 19 | logger = get_logger(__name__)
|
22 | 20 |
|
23 | 21 |
|
24 |
| -def process_attribute(node: ObjectNode, attr: Attribute, cls: Class) -> None: |
| 22 | +def process_attribute(obj: Any, attr: Attribute, cls: Class, *, processed: set[str]) -> None: |
25 | 23 | """Handle Pydantic fields."""
|
| 24 | + if attr.canonical_path in processed: |
| 25 | + return |
| 26 | + processed.add(attr.canonical_path) |
26 | 27 | if attr.name == "model_config":
|
27 |
| - cls.extra[common.self_namespace]["config"] = node.obj |
| 28 | + cls.extra[common.self_namespace]["config"] = obj |
28 | 29 | return
|
29 | 30 |
|
30 |
| - if not isinstance(node.obj, FieldInfo): |
| 31 | + if not isinstance(obj, FieldInfo): |
31 | 32 | return
|
32 | 33 |
|
33 | 34 | attr.labels = {"pydantic-field"}
|
34 |
| - attr.value = node.obj.default |
| 35 | + attr.value = obj.default |
35 | 36 | constraints = {}
|
36 | 37 | for constraint in common.field_constraints:
|
37 |
| - if (value := getattr(node.obj, constraint, None)) is not None: |
| 38 | + if (value := getattr(obj, constraint, None)) is not None: |
38 | 39 | constraints[constraint] = value
|
39 | 40 | attr.extra[common.self_namespace]["constraints"] = constraints
|
40 | 41 |
|
41 | 42 | # Populate docstring from the field's `description` argument.
|
42 |
| - if not attr.docstring and (docstring := node.obj.description): |
| 43 | + if not attr.docstring and (docstring := obj.description): |
43 | 44 | attr.docstring = Docstring(docstring, parent=attr)
|
44 | 45 |
|
45 | 46 |
|
46 |
| -def process_function(node: ObjectNode, func: Function, cls: Class) -> None: |
| 47 | +def process_function(obj: Callable, func: Function, cls: Class, *, processed: set[str]) -> None: |
47 | 48 | """Handle Pydantic field validators."""
|
48 |
| - if dec_info := getattr(node.obj, "decorator_info", None): |
| 49 | + if func.canonical_path in processed: |
| 50 | + return |
| 51 | + processed.add(func.canonical_path) |
| 52 | + if dec_info := getattr(obj, "decorator_info", None): |
49 | 53 | common.process_function(func, cls, dec_info.fields)
|
50 | 54 |
|
51 | 55 |
|
52 |
| -def process_class(node: ObjectNode, cls: Class) -> None: |
| 56 | +def process_class(obj: type, cls: Class, *, processed: set[str], schema: bool = False) -> None: |
53 | 57 | """Detect and prepare Pydantic models."""
|
54 | 58 | common.process_class(cls)
|
55 |
| - cls.extra[common.self_namespace]["schema"] = common.json_schema(node.obj) |
| 59 | + if schema: |
| 60 | + cls.extra[common.self_namespace]["schema"] = common.json_schema(obj) |
| 61 | + for member in cls.all_members.values(): |
| 62 | + kind = member.kind |
| 63 | + if kind is Kind.ATTRIBUTE: |
| 64 | + process_attribute(getattr(obj, member.name), member, cls, processed=processed) # type: ignore[arg-type] |
| 65 | + elif kind is Kind.FUNCTION: |
| 66 | + process_function(getattr(obj, member.name), member, cls, processed=processed) # type: ignore[arg-type] |
0 commit comments