Skip to content

Commit 16c97de

Browse files
committed
🏗️(project) migrate to pydantic v2 and switch tests to polyfactory
Migrating to `pydantic` v2 should speed up processing and allow interoperability with projects such as `warren`. This migration makes the hypothesis package used in tests obsolete, which is why we introduce `polyfactory`.
1 parent 02b0d47 commit 16c97de

File tree

157 files changed

+2629
-2404
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

157 files changed

+2629
-2404
lines changed

CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ and this project adheres to
88

99
## [Unreleased]
1010

11+
### Changed
12+
13+
- Upgrade `pydantic` to `2.7.0`
14+
- Migrate model tests from hypothesis strategies to polyfactory
15+
1116
## [4.2.0] - 2024-04-08
1217

1318
### Added

pyproject.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ dependencies = [
3232
# library (mostly models).
3333
"importlib-metadata>=7.0.1, <8.0",
3434
"langcodes>=3.2.0",
35-
"pydantic[dotenv,email]>=1.10.0, <2.0",
35+
"pydantic[email]>=2.5.3,<3.0",
36+
"pydantic_settings>=2.1.0,<3.0",
3637
"rfc3987>=1.3.0",
3738
]
3839
dynamic = ["version"]
@@ -92,7 +93,6 @@ dev = [
9293
"black==24.3.0",
9394
"cryptography==42.0.5",
9495
"factory-boy==3.3.0",
95-
"hypothesis<6.92.0", # pin as hypothesis 6.92.0 observability feature seems broken
9696
"logging-gelf==0.0.32",
9797
"mike==2.0.0",
9898
"mkdocs==1.5.3",
@@ -102,6 +102,7 @@ dev = [
102102
"moto==5.0.5",
103103
"mypy==1.9.0",
104104
"neoteroi-mkdocs==1.0.5",
105+
"polyfactory==2.15.0",
105106
"pyfakefs==5.4.0",
106107
"pymdown-extensions==10.7.1",
107108
"pytest<8.0.0", # pin as pytest-httpx<0.23.0 is not compatible with pytest 8.0.0

src/ralph/api/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,7 @@ async def whoami(
5151
user: AuthenticatedUser = Depends(get_authenticated_user),
5252
) -> Dict[str, Any]:
5353
"""Return the current user's username along with their scopes."""
54-
return {"agent": user.agent, "scopes": user.scopes}
54+
return {
55+
"agent": user.agent.model_dump(mode="json", exclude_none=True),
56+
"scopes": user.scopes,
57+
}

src/ralph/api/auth/basic.py

+16-16
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Basic authentication & authorization related tools for the Ralph API."""
22

33
import logging
4+
import os
45
from functools import lru_cache
56
from pathlib import Path
67
from threading import Lock
@@ -10,7 +11,7 @@
1011
from cachetools import TTLCache, cached
1112
from fastapi import Depends, HTTPException, status
1213
from fastapi.security import HTTPBasic, HTTPBasicCredentials
13-
from pydantic import BaseModel, root_validator
14+
from pydantic import RootModel, model_validator
1415
from starlette.authentication import AuthenticationError
1516

1617
from ralph.api.auth.user import AuthenticatedUser
@@ -40,45 +41,42 @@ class UserCredentials(AuthenticatedUser):
4041
username: str
4142

4243

43-
class ServerUsersCredentials(BaseModel):
44+
class ServerUsersCredentials(RootModel[List[UserCredentials]]):
4445
"""Custom root pydantic model.
4546
4647
Describe expected list of all server users credentials as stored in
4748
the credentials file.
4849
4950
Attributes:
50-
__root__ (List): Custom root consisting of the
51+
root (List): Custom root consisting of the
5152
list of all server users credentials.
5253
"""
5354

54-
__root__: List[UserCredentials]
55-
5655
def __add__(self, other) -> Any: # noqa: D105
57-
return ServerUsersCredentials.parse_obj(self.__root__ + other.__root__)
56+
return ServerUsersCredentials.model_validate(self.root + other.root)
5857

5958
def __getitem__(self, item: int) -> UserCredentials: # noqa: D105
60-
return self.__root__[item]
59+
return self.root[item]
6160

6261
def __len__(self) -> int: # noqa: D105
63-
return len(self.__root__)
62+
return len(self.root)
6463

6564
def __iter__(self) -> Iterator[UserCredentials]: # noqa: D105
66-
return iter(self.__root__)
65+
return iter(self.root)
6766

68-
@root_validator
69-
@classmethod
70-
def ensure_unique_username(cls, values: Any) -> Any:
67+
@model_validator(mode="after")
68+
def ensure_unique_username(self) -> Any:
7169
"""Every username should be unique among registered users."""
72-
usernames = [entry.username for entry in values.get("__root__")]
70+
usernames = [entry.username for entry in self.root]
7371
if len(usernames) != len(set(usernames)):
7472
raise ValueError(
7573
"You cannot create multiple credentials with the same username"
7674
)
77-
return values
75+
return self
7876

7977

8078
@lru_cache()
81-
def get_stored_credentials(auth_file: Path) -> ServerUsersCredentials:
79+
def get_stored_credentials(auth_file: os.PathLike) -> ServerUsersCredentials:
8280
"""Helper to read the credentials/scopes file.
8381
8482
Read credentials from JSON file and stored them to avoid reloading them with every
@@ -96,7 +94,9 @@ def get_stored_credentials(auth_file: Path) -> ServerUsersCredentials:
9694
msg = "Credentials file <%s> not found."
9795
logger.warning(msg, auth_file)
9896
raise AuthenticationError(msg.format(auth_file))
99-
return ServerUsersCredentials.parse_file(auth_file)
97+
98+
with open(auth_file, encoding=settings.LOCALE_ENCODING) as f:
99+
return ServerUsersCredentials.model_validate_json(f.read())
100100

101101

102102
@cached(

src/ralph/api/auth/oidc.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from fastapi.security import HTTPBearer, OpenIdConnect
1010
from jose import ExpiredSignatureError, JWTError, jwt
1111
from jose.exceptions import JWTClaimsError
12-
from pydantic import AnyUrl, BaseModel, Extra
12+
from pydantic import AnyUrl, BaseModel, ConfigDict
1313
from typing_extensions import Annotated
1414

1515
from ralph.api.auth.user import AuthenticatedUser, UserScopes
@@ -45,14 +45,13 @@ class IDToken(BaseModel):
4545

4646
iss: str
4747
sub: str
48-
aud: Optional[str]
48+
aud: Optional[str] = None
4949
exp: int
5050
iat: int
51-
scope: Optional[str]
52-
target: Optional[str]
51+
scope: Optional[str] = None
52+
target: Optional[str] = None
5353

54-
class Config: # noqa: D106
55-
extra = Extra.ignore
54+
model_config = ConfigDict(extra="ignore")
5655

5756

5857
@lru_cache()
@@ -144,7 +143,7 @@ def get_oidc_user(
144143
headers={"WWW-Authenticate": "Bearer"},
145144
) from exc
146145

147-
id_token = IDToken.parse_obj(decoded_token)
146+
id_token = IDToken.model_validate(decoded_token)
148147

149148
user = AuthenticatedUser(
150149
agent={"openid": f"{id_token.iss}/{id_token.sub}"},

src/ralph/api/auth/user.py

+8-14
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""Authenticated user for the Ralph API."""
22

3-
from typing import Dict, FrozenSet, Literal, Optional
3+
from typing import FrozenSet, Literal, Optional
44

5-
from pydantic import BaseModel
5+
from pydantic import BaseModel, RootModel
6+
7+
from ralph.models.xapi.base.agents import BaseXapiAgent
68

79
Scope = Literal[
810
"statements/write",
@@ -18,7 +20,7 @@
1820
]
1921

2022

21-
class UserScopes(FrozenSet[Scope]):
23+
class UserScopes(RootModel[FrozenSet[Scope]]):
2224
"""Scopes available to users."""
2325

2426
def is_authorized(self, requested_scope: Scope):
@@ -47,19 +49,11 @@ def is_authorized(self, requested_scope: Scope):
4749
}
4850

4951
expanded_user_scopes = set()
50-
for scope in self:
52+
for scope in self.root:
5153
expanded_user_scopes.update(expanded_scopes.get(scope, {scope}))
5254

5355
return requested_scope in expanded_user_scopes
5456

55-
@classmethod
56-
def __get_validators__(cls): # noqa: D105
57-
def validate(value: FrozenSet[Scope]):
58-
"""Transform value to an instance of UserScopes."""
59-
return cls(value)
60-
61-
yield validate
62-
6357

6458
class AuthenticatedUser(BaseModel):
6559
"""Pydantic model for user authentication.
@@ -70,6 +64,6 @@ class AuthenticatedUser(BaseModel):
7064
target (str or None): The target index or database to store statements into.
7165
"""
7266

73-
agent: Dict
67+
agent: BaseXapiAgent
7468
scopes: UserScopes
75-
target: Optional[str]
69+
target: Optional[str] = None

src/ralph/api/forwarding.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,16 @@ async def forward_xapi_statements(
4242
try:
4343
# NB: post or put
4444
req = await getattr(client, method)(
45-
forwarding.url,
45+
str(forwarding.url),
4646
json=statements,
4747
auth=(forwarding.basic_username, forwarding.basic_password),
4848
timeout=forwarding.timeout,
4949
)
5050
req.raise_for_status()
5151
msg = "Forwarded %s statements to %s with success."
5252
if isinstance(statements, list):
53-
logger.debug(msg, len(statements), forwarding.url)
53+
logger.debug(msg, len(statements), str(forwarding.url))
5454
else:
55-
logger.debug(msg, 1, forwarding.url)
55+
logger.debug(msg, 1, str(forwarding.url))
5656
except (RequestError, HTTPStatusError) as error:
5757
logger.error("Failed to forward xAPI statements. %s", error)

src/ralph/api/models.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Optional, Union
88
from uuid import UUID
99

10-
from pydantic import AnyUrl, BaseModel, Extra
10+
from pydantic import AnyUrl, BaseModel, ConfigDict
1111

1212
from ..models.xapi.base.agents import BaseXapiAgent
1313
from ..models.xapi.base.groups import BaseXapiGroup
@@ -30,13 +30,7 @@ class BaseModelWithLaxConfig(BaseModel):
3030
we receive statements through the API.
3131
"""
3232

33-
class Config:
34-
"""Enable extra properties.
35-
36-
Useful for not having to perform comprehensive validation.
37-
"""
38-
39-
extra = Extra.allow
33+
model_config = ConfigDict(extra="allow", coerce_numbers_to_str=True)
4034

4135

4236
class LaxObjectField(BaseModelWithLaxConfig):
@@ -65,6 +59,6 @@ class LaxStatement(BaseModelWithLaxConfig):
6559
"""
6660

6761
actor: Union[BaseXapiAgent, BaseXapiGroup]
68-
id: Optional[UUID]
62+
id: Optional[UUID] = None
6963
object: LaxObjectField
7064
verb: LaxVerbField

src/ralph/api/routers/health.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ async def heartbeat(response: Response) -> Heartbeat:
4747
4848
Return a 200 if all checks are successful.
4949
"""
50-
statuses = Heartbeat.construct(
50+
statuses = Heartbeat.model_construct(
5151
database=await await_if_coroutine(BACKEND_CLIENT.status())
5252
)
5353
if not statuses.is_alive:

0 commit comments

Comments
 (0)