|
10 | 10 | # *************************************************************
|
11 | 11 |
|
12 | 12 | ### Standard packages ###
|
13 |
| -from typing import Literal, Optional, Set, Tuple |
| 13 | +from typing import Optional, List, Set, Tuple, Union, get_args, get_origin |
14 | 14 |
|
15 | 15 | ### Third-party packages ###
|
16 |
| -from pydantic import ( |
17 |
| - BaseModel, |
18 |
| - StrictBool, |
19 |
| - StrictInt, |
20 |
| - StrictStr, |
21 |
| - model_validator, |
22 |
| -) |
23 |
| - |
24 |
| - |
25 |
| -class LoadConfig(BaseModel): |
26 |
| - cookie_key: Optional[StrictStr] = "fastapi-csrf-token" |
27 |
| - cookie_path: Optional[StrictStr] = "/" |
28 |
| - cookie_domain: Optional[StrictStr] = None |
29 |
| - cookie_samesite: Optional[Literal["lax", "none", "strict"]] = "lax" |
30 |
| - cookie_secure: Optional[StrictBool] = False |
31 |
| - header_name: Optional[StrictStr] = "X-CSRF-Token" |
32 |
| - header_type: Optional[StrictStr] = None |
33 |
| - httponly: Optional[StrictBool] = True |
34 |
| - max_age: Optional[StrictInt] = 3600 |
35 |
| - methods: Optional[Set[Literal["DELETE", "GET", "OPTIONS", "PATCH", "POST", "PUT"]]] = None |
36 |
| - secret_key: Optional[StrictStr] = None |
37 |
| - token_location: Optional[Literal["body", "header"]] = "header" |
38 |
| - token_key: Optional[StrictStr] = None |
39 |
| - |
40 |
| - @model_validator(mode="after") |
41 |
| - def validate_cookie_samesite_none_secure(self) -> "LoadConfig": |
| 16 | +from dataclasses import dataclass |
| 17 | + |
| 18 | + |
| 19 | +@dataclass |
| 20 | +class LoadConfig: |
| 21 | + cookie_key: Optional[str] = "fastapi-csrf-token" |
| 22 | + cookie_path: Optional[str] = "/" |
| 23 | + cookie_domain: Optional[str] = None |
| 24 | + cookie_samesite: Optional[str] = "lax" |
| 25 | + cookie_secure: Optional[bool] = False |
| 26 | + header_name: Optional[str] = "X-CSRF-Token" |
| 27 | + header_type: Optional[str] = None |
| 28 | + httponly: Optional[bool] = True |
| 29 | + max_age: Optional[int] = 3600 |
| 30 | + methods: Optional[Set[str]] = None |
| 31 | + secret_key: Optional[str] = None |
| 32 | + token_location: Optional[str] = "header" |
| 33 | + token_key: Optional[str] = None |
| 34 | + |
| 35 | + def __post_init__(self) -> None: |
| 36 | + self.validate_attribute_types() |
| 37 | + self.validate_cookie_samesite() |
| 38 | + self.validate_cookie_samesite_none_secure() |
| 39 | + self.validate_methods() |
| 40 | + self.validate_token_key() |
| 41 | + self.validate_token_location() |
| 42 | + |
| 43 | + def validate_attribute_types(self) -> None: |
| 44 | + for name, field_type in self.__annotations__.items(): |
| 45 | + origin = get_origin(field_type) |
| 46 | + if origin == Union: |
| 47 | + types = get_args(field_type) |
| 48 | + typed: List[bool] = [] |
| 49 | + for current_type in types: |
| 50 | + if get_origin(current_type) is None: |
| 51 | + typed.append(isinstance(getattr(self, name), current_type)) |
| 52 | + else: |
| 53 | + subscripted = get_origin(current_type) |
| 54 | + typed.append(isinstance(getattr(self, name), subscripted)) |
| 55 | + # TODO: subtypes |
| 56 | + if not any(typed): |
| 57 | + raise TypeError(f"The field `{name}` was not correctly assigned as `{field_type}`.") |
| 58 | + elif not isinstance(getattr(self, name), field_type): |
| 59 | + current_type = type(getattr(self, name)) |
| 60 | + raise TypeError( |
| 61 | + f"The field `{name}` was assigned by `{current_type}` instead of `{field_type}`" |
| 62 | + ) |
| 63 | + |
| 64 | + def validate_methods(self) -> None: |
| 65 | + if self.methods is not None and isinstance(self.methods, set): |
| 66 | + for method in self.methods: |
| 67 | + if method not in {"DELETE", "GET", "PATCH", "POST", "PUT"}: |
| 68 | + raise TypeError("lol") |
| 69 | + |
| 70 | + def validate_cookie_samesite(self) -> None: |
| 71 | + if self.cookie_samesite is not None and self.cookie_samesite not in {"lax", "none", "strict"}: |
| 72 | + raise TypeError("lol") |
| 73 | + |
| 74 | + def validate_cookie_samesite_none_secure(self) -> None: |
42 | 75 | if self.cookie_samesite in {None, "none"} and self.cookie_secure is not True:
|
43 |
| - raise ValueError('The "cookie_secure" must be True if "cookie_samesite" set to "none".') |
44 |
| - return self |
| 76 | + raise TypeError('The "cookie_secure" must be True if "cookie_samesite" set to "none".') |
45 | 77 |
|
46 |
| - @model_validator(mode="after") |
47 |
| - def validate_token_key(self) -> "LoadConfig": |
| 78 | + def validate_token_key(self) -> None: |
48 | 79 | token_location: str = self.token_location if self.token_location is not None else "header"
|
49 |
| - if token_location == "body": |
50 |
| - if self.token_key is None: |
51 |
| - raise ValueError('The "token_key" must be present when "token_location" is "body"') |
52 |
| - return self |
| 80 | + if token_location == "body" and self.token_key is None: |
| 81 | + raise TypeError('The "token_key" must be present when "token_location" is "body"') |
| 82 | + |
| 83 | + def validate_token_location(self) -> None: |
| 84 | + if self.token_location not in {"body", "header"}: |
| 85 | + raise TypeError("lol") |
53 | 86 |
|
54 | 87 |
|
55 | 88 | __all__: Tuple[str, ...] = ("LoadConfig",)
|
0 commit comments