Skip to content

Commit 6e3ab4f

Browse files
committed
refactor: Run dynamic analysis once package is loaded
This allows us to use `all_members` when looking up fields and validators, which fixes an issue with models inherting from base ones. Issue-19: #19
1 parent c41a776 commit 6e3ab4f

File tree

5 files changed

+50
-59
lines changed

5 files changed

+50
-59
lines changed

src/griffe_pydantic/common.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@
3131

3232

3333
def _model_fields(cls: Class) -> dict[str, Attribute]:
34-
return {name: attr for name, attr in cls.members.items() if "pydantic-field" in attr.labels} # type: ignore[misc]
34+
return {name: attr for name, attr in cls.all_members.items() if "pydantic-field" in attr.labels} # type: ignore[misc]
3535

3636

3737
def _model_validators(cls: Class) -> dict[str, Function]:
38-
return {name: func for name, func in cls.members.items() if "pydantic-validator" in func.labels} # type: ignore[misc]
38+
return {name: func for name, func in cls.all_members.items() if "pydantic-validator" in func.labels} # type: ignore[misc]
3939

4040

4141
def json_schema(model: type[BaseModel]) -> str:
@@ -69,7 +69,7 @@ def process_function(func: Function, cls: Class, fields: Sequence[str]) -> None:
6969
cls: A Griffe function representing the Pydantic validator.
7070
"""
7171
func.labels = {"pydantic-validator"}
72-
targets = [cls.members[field] for field in fields]
72+
targets = [cls.all_members[field] for field in fields]
7373

7474
func.extra[self_namespace].setdefault("targets", [])
7575
func.extra[self_namespace]["targets"].extend(targets)

src/griffe_pydantic/dynamic.py

+25-14
Original file line numberDiff line numberDiff line change
@@ -2,54 +2,65 @@
22

33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING
5+
from typing import Any, Callable
66

77
from griffe import (
88
Attribute,
99
Class,
1010
Docstring,
1111
Function,
12+
Kind,
1213
get_logger,
1314
)
1415
from pydantic.fields import FieldInfo
1516

1617
from griffe_pydantic import common
1718

18-
if TYPE_CHECKING:
19-
from griffe import ObjectNode
20-
2119
logger = get_logger(__name__)
2220

2321

24-
def process_attribute(node: ObjectNode, attr: Attribute, cls: Class) -> None:
22+
def process_attribute(obj: Any, attr: Attribute, cls: Class, *, processed: set[str]) -> None:
2523
"""Handle Pydantic fields."""
24+
if attr.canonical_path in processed:
25+
return
26+
processed.add(attr.canonical_path)
2627
if attr.name == "model_config":
27-
cls.extra[common.self_namespace]["config"] = node.obj
28+
cls.extra[common.self_namespace]["config"] = obj
2829
return
2930

30-
if not isinstance(node.obj, FieldInfo):
31+
if not isinstance(obj, FieldInfo):
3132
return
3233

3334
attr.labels = {"pydantic-field"}
34-
attr.value = node.obj.default
35+
attr.value = obj.default
3536
constraints = {}
3637
for constraint in common.field_constraints:
37-
if (value := getattr(node.obj, constraint, None)) is not None:
38+
if (value := getattr(obj, constraint, None)) is not None:
3839
constraints[constraint] = value
3940
attr.extra[common.self_namespace]["constraints"] = constraints
4041

4142
# Populate docstring from the field's `description` argument.
42-
if not attr.docstring and (docstring := node.obj.description):
43+
if not attr.docstring and (docstring := obj.description):
4344
attr.docstring = Docstring(docstring, parent=attr)
4445

4546

46-
def process_function(node: ObjectNode, func: Function, cls: Class) -> None:
47+
def process_function(obj: Callable, func: Function, cls: Class, *, processed: set[str]) -> None:
4748
"""Handle Pydantic field validators."""
48-
if dec_info := getattr(node.obj, "decorator_info", None):
49+
if func.canonical_path in processed:
50+
return
51+
processed.add(func.canonical_path)
52+
if dec_info := getattr(obj, "decorator_info", None):
4953
common.process_function(func, cls, dec_info.fields)
5054

5155

52-
def process_class(node: ObjectNode, cls: Class) -> None:
56+
def process_class(obj: type, cls: Class, *, processed: set[str], schema: bool = False) -> None:
5357
"""Detect and prepare Pydantic models."""
5458
common.process_class(cls)
55-
cls.extra[common.self_namespace]["schema"] = common.json_schema(node.obj)
59+
if schema:
60+
cls.extra[common.self_namespace]["schema"] = common.json_schema(obj)
61+
for member in cls.all_members.values():
62+
kind = member.kind
63+
if kind is Kind.ATTRIBUTE:
64+
process_attribute(getattr(obj, member.name), member, cls, processed=processed) # type: ignore[arg-type]
65+
elif kind is Kind.FUNCTION:
66+
process_function(getattr(obj, member.name), member, cls, processed=processed) # type: ignore[arg-type]

src/griffe_pydantic/extension.py

+5-36
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@
66
from typing import TYPE_CHECKING, Any
77

88
from griffe import (
9-
Attribute,
109
Class,
1110
Extension,
12-
Function,
1311
Module,
1412
get_logger,
1513
)
@@ -34,11 +32,14 @@ def __init__(self, *, schema: bool = False) -> None:
3432
"""
3533
super().__init__()
3634
self.schema = schema
37-
self.in_model: list[Class] = []
3835
self.processed: set[str] = set()
36+
self.recorded: list[tuple[ObjectNode, Class]] = []
3937

4038
def on_package_loaded(self, *, pkg: Module, **kwargs: Any) -> None: # noqa: ARG002
4139
"""Detect models once the whole package is loaded."""
40+
for node, cls in self.recorded:
41+
self.processed.add(cls.canonical_path)
42+
dynamic.process_class(node.obj, cls, processed=self.processed, schema=self.schema)
4243
static.process_module(pkg, processed=self.processed, schema=self.schema)
4344

4445
def on_class_instance(self, *, node: ast.AST | ObjectNode, cls: Class, **kwargs: Any) -> None: # noqa: ARG002
@@ -54,36 +55,4 @@ def on_class_instance(self, *, node: ast.AST | ObjectNode, cls: Class, **kwargs:
5455
return
5556

5657
if issubclass(node.obj, pydantic.BaseModel):
57-
self.in_model.append(cls)
58-
dynamic.process_class(node, cls)
59-
self.processed.add(cls.canonical_path)
60-
61-
def on_attribute_instance(self, *, node: ast.AST | ObjectNode, attr: Attribute, **kwargs: Any) -> None: # noqa: ARG002
62-
"""Handle Pydantic fields."""
63-
# Prevent running during static analysis.
64-
if isinstance(node, ast.AST):
65-
return
66-
if self.in_model:
67-
cls = self.in_model[-1]
68-
dynamic.process_attribute(node, attr, cls)
69-
self.processed.add(attr.canonical_path)
70-
71-
def on_function_instance(self, *, node: ast.AST | ObjectNode, func: Function, **kwargs: Any) -> None: # noqa: ARG002
72-
"""Handle Pydantic field validators."""
73-
# Prevent running during static analysis.
74-
if isinstance(node, ast.AST):
75-
return
76-
if self.in_model:
77-
cls = self.in_model[-1]
78-
dynamic.process_function(node, func, cls)
79-
self.processed.add(func.canonical_path)
80-
81-
def on_class_members(self, *, node: ast.AST | ObjectNode, cls: Class, **kwargs: Any) -> None: # noqa: ARG002
82-
"""Finalize the Pydantic model data."""
83-
# Prevent running during static analysis.
84-
if isinstance(node, ast.AST):
85-
return
86-
87-
if self.in_model and cls is self.in_model[-1]:
88-
# Pop last class from the heap.
89-
self.in_model.pop()
58+
self.recorded.append((node, cls))

src/griffe_pydantic/static.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,16 @@ def process_attribute(attr: Attribute, cls: Class, *, processed: set[str]) -> No
9696
kwargs["default"] = attr.value
9797

9898
if attr.name == "model_config":
99-
cls.extra[common.self_namespace]["config"] = kwargs
99+
config = {}
100+
for key, value in kwargs.items():
101+
if isinstance(value, str):
102+
try:
103+
config[key] = ast.literal_eval(value)
104+
except ValueError:
105+
config[key] = value
106+
else:
107+
config[key] = value
108+
cls.extra[common.self_namespace]["config"] = config
100109
return
101110

102111
attr.labels.add("pydantic-field")

tests/test_extension.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import logging
66
from typing import TYPE_CHECKING
77

8-
from griffe import Extensions, temporary_visited_package
8+
import pytest
9+
from griffe import Extensions, temporary_inspected_package, temporary_visited_package
910

1011
from griffe_pydantic.extension import PydanticExtension
1112

1213
if TYPE_CHECKING:
13-
import pytest
1414
from mkdocstrings_handlers.python.handler import PythonHandler
1515

1616

@@ -58,9 +58,11 @@ class RegularClass(object):
5858
"""
5959

6060

61-
def test_extension() -> None:
61+
@pytest.mark.parametrize("analysis", ["static", "dynamic"])
62+
def test_extension(analysis: str) -> None:
6263
"""Test the extension."""
63-
with temporary_visited_package(
64+
loader = {"static": temporary_visited_package, "dynamic": temporary_inspected_package}[analysis]
65+
with loader(
6466
"package",
6567
modules={"__init__.py": code},
6668
extensions=Extensions(PydanticExtension(schema=True)),
@@ -74,7 +76,7 @@ def test_extension() -> None:
7476
assert package.classes["ExampleModel"].labels == {"pydantic-model"}
7577

7678
config = package.classes["ExampleModel"].extra["griffe_pydantic"]["config"]
77-
assert config == {"frozen": "False"}
79+
assert config == {"frozen": False}
7880

7981
schema = package.classes["ExampleModel"].extra["griffe_pydantic"]["schema"]
8082
assert schema.startswith('{\n "description"')

0 commit comments

Comments
 (0)