Skip to content

feat: Custom base class detection #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/griffe_pydantic/_internal/common.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
from __future__ import annotations

import importlib
import json
import sys
from functools import partial
from typing import TYPE_CHECKING

from griffe import get_logger

if TYPE_CHECKING:
from collections.abc import Sequence

from griffe import Attribute, Class, Function
from pydantic import BaseModel

_DEFAULT_BASES = (
"pydantic.BaseModel",
"pydantic.main.BaseModel",
"pydantic_settings.BaseSettings",
"pydantic_settings.main.BaseSettings",
"sqlmodel.SQLModel",
"sqlmodel.main.SQLModel",
)


_self_namespace = "griffe_pydantic"
_mkdocstrings_namespace = "mkdocstrings"
_logger = get_logger(__name__)

_field_constraints = {
"gt",
Expand Down Expand Up @@ -77,3 +92,30 @@ def _process_function(func: Function, cls: Class, fields: Sequence[str]) -> None
for target in targets:
target.extra[_self_namespace].setdefault("validators", [])
target.extra[_self_namespace]["validators"].append(func)


def _import_from_name(name: str) -> type[BaseModel]:
"""Given a fully-qualified `package.module.Class` name, return the imported class."""
module_name, _, class_name = name.rpartition(".")
module = sys.modules.get(module_name, importlib.import_module(module_name))
try:
return getattr(module, class_name)
except AttributeError as e:
raise AttributeError(f"No class {class_name} in module {module}") from e


def _import_bases(names: tuple[str, ...]) -> tuple[type[BaseModel], ...]:
"""Import a set of bases from fully-qualified `package.module.Class` names.

Does not raise for import errors,
since we don't expect all possible bases to be present.
"""
bases = []
for name in names:
try:
bases.append(_import_from_name(name))
except ImportError:
# fine, we expect some of the defaults to fail, we only care if we have none
_logger.debug("Could not import %s", name)

return tuple(bases)
30 changes: 22 additions & 8 deletions src/griffe_pydantic/_internal/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
get_logger,
)

from griffe_pydantic._internal import dynamic, static
from griffe_pydantic._internal import common, dynamic, static

if TYPE_CHECKING:
from griffe import ObjectNode
Expand All @@ -22,14 +22,26 @@
class PydanticExtension(Extension):
"""Griffe extension for Pydantic."""

def __init__(self, *, schema: bool = False) -> None:
def __init__(
self,
*,
schema: bool = False,
bases: tuple[str, ...] | list[str] = common._DEFAULT_BASES,
include_bases: tuple[str, ...] | list[str] | None = None,
) -> None:
"""Initialize the extension.

Parameters:
schema: Whether to compute and store the JSON schema of models.
bases: Tuple of complete `package.module.Class` references to base classes that should be considered
pydantic models. Declaring this *replaces* the default bases.
include_bases: *Additional* base classes to consider as pydantic models, including the defaults.
"""
super().__init__()
self._schema = schema
self._bases = tuple(bases)
if include_bases:
self._bases += tuple(include_bases)
self._processed: set[str] = set()
self._recorded: list[tuple[ObjectNode, Class]] = []

Expand All @@ -38,19 +50,21 @@ def on_package_loaded(self, *, pkg: Module, **kwargs: Any) -> None: # noqa: ARG
for node, cls in self._recorded:
self._processed.add(cls.canonical_path)
dynamic._process_class(node.obj, cls, processed=self._processed, schema=self._schema)
static._process_module(pkg, processed=self._processed, schema=self._schema)
static._process_module(pkg, processed=self._processed, schema=self._schema, bases=self._bases)

def on_class_instance(self, *, node: ast.AST | ObjectNode, cls: Class, **kwargs: Any) -> None: # noqa: ARG002
"""Detect and prepare Pydantic models."""
# Prevent running during static analysis.
if isinstance(node, ast.AST):
return

try:
import pydantic
except ImportError:
_logger.warning("could not import pydantic - models will not be detected")
bases = common._import_bases(self._bases)
if not bases:
_logger.warning(
"could not import any expected model base - models will not be detected. \nexpected: %s",
self._bases,
)
return

if issubclass(node.obj, pydantic.BaseModel):
if issubclass(node.obj, bases):
self._recorded.append((node, cls))
21 changes: 14 additions & 7 deletions src/griffe_pydantic/_internal/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
_logger = get_logger(__name__)


def _inherits_pydantic(cls: Class) -> bool:
def _inherits_pydantic(cls: Class, bases: tuple[str, ...] = common._DEFAULT_BASES) -> bool:
"""Tell whether a class inherits from a Pydantic model.

Parameters:
Expand All @@ -41,10 +41,10 @@ def _inherits_pydantic(cls: Class) -> bool:
for base in cls.bases:
if isinstance(base, (ExprName, Expr)):
base = base.canonical_path # noqa: PLW2901
if base in {"pydantic.BaseModel", "pydantic.main.BaseModel"}:
if base in bases:
return True

return any(_inherits_pydantic(parent_class) for parent_class in cls.mro())
return any(_inherits_pydantic(parent_class, bases) for parent_class in cls.mro())


def _pydantic_validator(func: Function) -> ExprCall | None:
Expand Down Expand Up @@ -141,12 +141,18 @@ def _process_function(func: Function, cls: Class, *, processed: set[str]) -> Non
common._process_function(func, cls, fields)


def _process_class(cls: Class, *, processed: set[str], schema: bool = False) -> None:
def _process_class(
cls: Class,
*,
processed: set[str],
schema: bool = False,
bases: tuple[str, ...] = common._DEFAULT_BASES,
) -> None:
"""Finalize the Pydantic model data."""
if cls.canonical_path in processed:
return

if not _inherits_pydantic(cls):
if not _inherits_pydantic(cls, bases):
return

processed.add(cls.canonical_path)
Expand Down Expand Up @@ -182,6 +188,7 @@ def _process_module(
*,
processed: set[str],
schema: bool = False,
bases: tuple[str, ...] = common._DEFAULT_BASES,
) -> None:
"""Handle Pydantic models in a module."""
if mod.canonical_path in processed:
Expand All @@ -191,7 +198,7 @@ def _process_module(
for cls in mod.classes.values():
# Don't process aliases, real classes will be processed at some point anyway.
if not cls.is_alias:
_process_class(cls, processed=processed, schema=schema)
_process_class(cls, processed=processed, schema=schema, bases=bases)

for submodule in mod.modules.values():
_process_module(submodule, processed=processed, schema=schema)
_process_module(submodule, processed=processed, schema=schema, bases=bases)
59 changes: 59 additions & 0 deletions tests/test_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import logging
from textwrap import dedent
from typing import TYPE_CHECKING

import pytest
Expand Down Expand Up @@ -232,3 +233,61 @@ class B(BaseModel, A):
extensions=Extensions(PydanticExtension(schema=False)),
) as package:
assert "pydantic-field" in package["B.a"].labels


@pytest.mark.parametrize("analysis", ["static", "dynamic"])
@pytest.mark.parametrize("base_mode", ["bases", "include_bases"])
def test_detect_custom_bases(analysis: str, base_mode: str) -> None:
"""We can detect pydantic models with non-standard bases as specified by config."""
package_name = "package"
module_name = "builtins"
class_name = "object"
code = dedent(f"""
from {module_name} import {class_name}

class ExampleParentModel({class_name}):
pass

class ExampleModel(ExampleParentModel):
pass
""")

fake_module = dedent(f"""
class {class_name}:
pass
""")

extension_kwargs = {base_mode: [".".join([module_name, class_name])]}

loader = {
"static": temporary_visited_package,
"dynamic": temporary_inspected_package,
}[analysis]
with loader(
package_name,
modules={"__init__.py": code, module_name + ".py": fake_module},
extensions=Extensions(PydanticExtension(**extension_kwargs)), # type: ignore[arg-type]
) as package:
assert "ExampleParentModel" in package.classes
assert "ExampleModel" in package.classes
assert package.classes["ExampleParentModel"].labels == {"pydantic-model"}
assert package.classes["ExampleModel"].labels == {"pydantic-model"}


@pytest.mark.parametrize("analysis", ["static", "dynamic"])
def test_replace_default_bases(analysis: str) -> None:
"""When we replace the bases, pydantic models should no longer be annotated."""
loader = {
"static": temporary_visited_package,
"dynamic": temporary_inspected_package,
}[analysis]

with loader(
"package",
modules={"__init__.py": code},
extensions=Extensions(PydanticExtension(bases=("fake.fakepackage.NoExisty",))),
) as package:
assert "ExampleParentModel" in package.classes
assert "ExampleModel" in package.classes
assert not package.classes["ExampleParentModel"].labels
assert not package.classes["ExampleModel"].labels
Loading