Skip to content

Commit f5558b6

Browse files
authored
Chore!: remove pydantic v1 validator arg helpers (#3615)
1 parent a702570 commit f5558b6

File tree

19 files changed

+331
-392
lines changed

19 files changed

+331
-392
lines changed

sqlmesh/core/_typing.py

+6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import sys
34
import typing as t
45

56
from sqlglot import exp
@@ -9,3 +10,8 @@
910
SchemaName = t.Union[str, exp.Table]
1011
SessionProperties = t.Dict[str, t.Union[exp.Expression, str, int, float, bool]]
1112
CustomMaterializationProperties = t.Dict[str, t.Union[exp.Expression, str, int, float, bool]]
13+
14+
if sys.version_info >= (3, 11):
15+
from typing import Self as Self
16+
else:
17+
from typing_extensions import Self as Self

sqlmesh/core/audit/definition.py

+6-12
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,10 @@
2828
extract_macro_references_and_variables,
2929
)
3030
from sqlmesh.utils.metaprogramming import Executable
31-
from sqlmesh.utils.pydantic import (
32-
PydanticModel,
33-
field_validator,
34-
model_validator,
35-
model_validator_v1_args,
36-
)
31+
from sqlmesh.utils.pydantic import PydanticModel, field_validator, model_validator
3732

3833
if t.TYPE_CHECKING:
34+
from sqlmesh.core._typing import Self
3935
from sqlmesh.core.snapshot import DeployabilityIndex, Node, Snapshot
4036

4137

@@ -175,12 +171,10 @@ class StandaloneAudit(_Node, AuditMixin):
175171
_depends_on_validator = depends_on_validator
176172

177173
@model_validator(mode="after")
178-
@model_validator_v1_args
179-
def _node_root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
180-
if values.get("blocking"):
181-
name = values.get("name")
182-
raise AuditConfigError(f"Standalone audits cannot be blocking: '{name}'.")
183-
return values
174+
def _node_root_validator(self) -> Self:
175+
if self.blocking:
176+
raise AuditConfigError(f"Standalone audits cannot be blocking: '{self.name}'.")
177+
return self
184178

185179
def render_audit_query(
186180
self,

sqlmesh/core/config/connection.py

+91-96
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,12 @@
2323
from sqlmesh.core.engine_adapter.shared import CatalogSupport
2424
from sqlmesh.core.engine_adapter import EngineAdapter
2525
from sqlmesh.utils.errors import ConfigError
26-
from sqlmesh.utils.pydantic import (
27-
field_validator,
28-
model_validator,
29-
model_validator_v1_args,
30-
field_validator_v1_args,
31-
)
26+
from sqlmesh.utils.pydantic import ValidationInfo, field_validator, model_validator
3227
from sqlmesh.utils.aws import validate_s3_uri
3328

29+
if t.TYPE_CHECKING:
30+
from sqlmesh.core._typing import Self
31+
3432
logger = logging.getLogger(__name__)
3533

3634
RECOMMENDED_STATE_SYNC_ENGINES = {"postgres", "gcp_postgres", "mysql", "mssql"}
@@ -163,19 +161,20 @@ class BaseDuckDBConnectionConfig(ConnectionConfig):
163161
_data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {}
164162

165163
@model_validator(mode="before")
166-
@model_validator_v1_args
167-
def _validate_database_catalogs(
168-
cls, values: t.Dict[str, t.Optional[str]]
169-
) -> t.Dict[str, t.Optional[str]]:
170-
if db_path := values.get("database") and values.get("catalogs"):
164+
def _validate_database_catalogs(cls, data: t.Any) -> t.Any:
165+
if not isinstance(data, dict):
166+
return data
167+
168+
if db_path := data.get("database") and data.get("catalogs"):
171169
raise ConfigError(
172170
"Cannot specify both `database` and `catalogs`. Define all your catalogs in `catalogs` and have the first entry be the default catalog"
173171
)
174172
if isinstance(db_path, str) and db_path.startswith("md:"):
175173
raise ConfigError(
176174
"Please use connection type 'motherduck' without the `md:` prefix if you want to use a MotherDuck database as the single `database`."
177175
)
178-
return values
176+
177+
return data
179178

180179
@property
181180
def _engine_adapter(self) -> t.Type[EngineAdapter]:
@@ -430,29 +429,29 @@ class SnowflakeConnectionConfig(ConnectionConfig):
430429
_concurrent_tasks_validator = concurrent_tasks_validator
431430

432431
@model_validator(mode="before")
433-
@model_validator_v1_args
434-
def _validate_authenticator(
435-
cls, values: t.Dict[str, t.Optional[str]]
436-
) -> t.Dict[str, t.Optional[str]]:
437-
from snowflake.connector.network import (
438-
DEFAULT_AUTHENTICATOR,
439-
OAUTH_AUTHENTICATOR,
440-
)
432+
def _validate_authenticator(cls, data: t.Any) -> t.Any:
433+
if not isinstance(data, dict):
434+
return data
441435

442-
auth = values.get("authenticator")
436+
from snowflake.connector.network import DEFAULT_AUTHENTICATOR, OAUTH_AUTHENTICATOR
437+
438+
auth = data.get("authenticator")
443439
auth = auth.upper() if auth else DEFAULT_AUTHENTICATOR
444-
user = values.get("user")
445-
password = values.get("password")
446-
values["private_key"] = cls._get_private_key(values, auth) # type: ignore
440+
user = data.get("user")
441+
password = data.get("password")
442+
data["private_key"] = cls._get_private_key(data, auth) # type: ignore
443+
447444
if (
448445
auth == DEFAULT_AUTHENTICATOR
449-
and not values.get("private_key")
446+
and not data.get("private_key")
450447
and (not user or not password)
451448
):
452449
raise ConfigError("User and password must be provided if using default authentication")
453-
if auth == OAUTH_AUTHENTICATOR and not values.get("token"):
450+
451+
if auth == OAUTH_AUTHENTICATOR and not data.get("token"):
454452
raise ConfigError("Token must be provided if using oauth authentication")
455-
return values
453+
454+
return data
456455

457456
@classmethod
458457
def _get_private_key(cls, values: t.Dict[str, t.Optional[str]], auth: str) -> t.Optional[bytes]:
@@ -621,26 +620,28 @@ class DatabricksConnectionConfig(ConnectionConfig):
621620
_http_headers_validator = http_headers_validator
622621

623622
@model_validator(mode="before")
624-
@model_validator_v1_args
625-
def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
623+
def _databricks_connect_validator(cls, data: t.Any) -> t.Any:
624+
if not isinstance(data, dict):
625+
return data
626+
626627
from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter
627628

628629
if DatabricksEngineAdapter.can_access_spark_session(
629-
bool(values.get("disable_spark_session"))
630+
bool(data.get("disable_spark_session"))
630631
):
631-
return values
632+
return data
632633

633-
databricks_connect_use_serverless = values.get("databricks_connect_use_serverless")
634+
databricks_connect_use_serverless = data.get("databricks_connect_use_serverless")
634635
server_hostname, http_path, access_token, auth_type = (
635-
values.get("server_hostname"),
636-
values.get("http_path"),
637-
values.get("access_token"),
638-
values.get("auth_type"),
636+
data.get("server_hostname"),
637+
data.get("http_path"),
638+
data.get("access_token"),
639+
data.get("auth_type"),
639640
)
640641

641642
if databricks_connect_use_serverless:
642-
values["force_databricks_connect"] = True
643-
values["disable_databricks_connect"] = False
643+
data["force_databricks_connect"] = True
644+
data["disable_databricks_connect"] = False
644645

645646
if (not server_hostname or not http_path or not access_token) and (
646647
not databricks_connect_use_serverless and not auth_type
@@ -651,35 +652,35 @@ def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str
651652
if (
652653
databricks_connect_use_serverless
653654
and not server_hostname
654-
and not values.get("databricks_connect_server_hostname")
655+
and not data.get("databricks_connect_server_hostname")
655656
):
656657
raise ValueError(
657658
"`server_hostname` or `databricks_connect_server_hostname` is required when `databricks_connect_use_serverless` is set"
658659
)
659660
if DatabricksEngineAdapter.can_access_databricks_connect(
660-
bool(values.get("disable_databricks_connect"))
661+
bool(data.get("disable_databricks_connect"))
661662
):
662-
if not values.get("databricks_connect_access_token"):
663-
values["databricks_connect_access_token"] = access_token
664-
if not values.get("databricks_connect_server_hostname"):
665-
values["databricks_connect_server_hostname"] = f"https://{server_hostname}"
663+
if not data.get("databricks_connect_access_token"):
664+
data["databricks_connect_access_token"] = access_token
665+
if not data.get("databricks_connect_server_hostname"):
666+
data["databricks_connect_server_hostname"] = f"https://{server_hostname}"
666667
if not databricks_connect_use_serverless:
667-
if not values.get("databricks_connect_cluster_id"):
668+
if not data.get("databricks_connect_cluster_id"):
668669
if t.TYPE_CHECKING:
669670
assert http_path is not None
670-
values["databricks_connect_cluster_id"] = http_path.split("/")[-1]
671+
data["databricks_connect_cluster_id"] = http_path.split("/")[-1]
671672

672673
if auth_type:
673674
from databricks.sql.auth.auth import AuthType
674675

675-
all_values = [m.value for m in AuthType]
676-
if auth_type not in all_values:
676+
all_data = [m.value for m in AuthType]
677+
if auth_type not in all_data:
677678
raise ValueError(
678-
f"`auth_type` {auth_type} does not match a valid option: {all_values}"
679+
f"`auth_type` {auth_type} does not match a valid option: {all_data}"
679680
)
680681

681-
client_id = values.get("oauth_client_id")
682-
client_secret = values.get("oauth_client_secret")
682+
client_id = data.get("oauth_client_id")
683+
client_secret = data.get("oauth_client_secret")
683684

684685
if client_secret and not client_id:
685686
raise ValueError(
@@ -689,7 +690,7 @@ def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str
689690
if not http_path:
690691
raise ValueError("`http_path` is still required when using `auth_type`")
691692

692-
return values
693+
return data
693694

694695
@property
695696
def _connection_kwargs_keys(self) -> t.Set[str]:
@@ -866,26 +867,24 @@ class BigQueryConnectionConfig(ConnectionConfig):
866867
type_: t.Literal["bigquery"] = Field(alias="type", default="bigquery")
867868

868869
@field_validator("execution_project")
869-
@field_validator_v1_args
870870
def validate_execution_project(
871871
cls,
872872
v: t.Optional[str],
873-
values: t.Dict[str, t.Any],
873+
info: ValidationInfo,
874874
) -> t.Optional[str]:
875-
if v and not values.get("project"):
875+
if v and not info.data.get("project"):
876876
raise ConfigError(
877877
"If the `execution_project` field is specified, you must also specify the `project` field to provide a default object location."
878878
)
879879
return v
880880

881881
@field_validator("quota_project")
882-
@field_validator_v1_args
883882
def validate_quota_project(
884883
cls,
885884
v: t.Optional[str],
886-
values: t.Dict[str, t.Any],
885+
info: ValidationInfo,
887886
) -> t.Optional[str]:
888-
if v and not values.get("project"):
887+
if v and not info.data.get("project"):
889888
raise ConfigError(
890889
"If the `quota_project` field is specified, you must also specify the `project` field to provide a default object location."
891890
)
@@ -998,12 +997,13 @@ class GCPPostgresConnectionConfig(ConnectionConfig):
998997
pre_ping: bool = True
999998

1000999
@model_validator(mode="before")
1001-
@model_validator_v1_args
1002-
def _validate_auth_method(
1003-
cls, values: t.Dict[str, t.Optional[str]]
1004-
) -> t.Dict[str, t.Optional[str]]:
1005-
password = values.get("password")
1006-
enable_iam_auth = values.get("enable_iam_auth")
1000+
def _validate_auth_method(cls, data: t.Any) -> t.Any:
1001+
if not isinstance(data, dict):
1002+
return data
1003+
1004+
password = data.get("password")
1005+
enable_iam_auth = data.get("enable_iam_auth")
1006+
10071007
if password and enable_iam_auth:
10081008
raise ConfigError(
10091009
"Invalid GCP Postgres connection configuration - both password and"
@@ -1016,7 +1016,8 @@ def _validate_auth_method(
10161016
" for a postgres user account or enable_iam_auth set to 'True'"
10171017
" for an IAM user account."
10181018
)
1019-
return values
1019+
1020+
return data
10201021

10211022
@property
10221023
def _connection_kwargs_keys(self) -> t.Set[str]:
@@ -1437,40 +1438,37 @@ class TrinoConnectionConfig(ConnectionConfig):
14371438
type_: t.Literal["trino"] = Field(alias="type", default="trino")
14381439

14391440
@model_validator(mode="after")
1440-
@model_validator_v1_args
1441-
def _root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
1442-
port = values.get("port")
1443-
if (
1444-
values["http_scheme"] == "http"
1445-
and not values["method"].is_no_auth
1446-
and not values["method"].is_basic
1447-
):
1441+
def _root_validator(self) -> Self:
1442+
port = self.port
1443+
if self.http_scheme == "http" and not self.method.is_no_auth and not self.method.is_basic:
14481444
raise ConfigError("HTTP scheme can only be used with no-auth or basic method")
1445+
14491446
if port is None:
1450-
values["port"] = 80 if values["http_scheme"] == "http" else 443
1451-
if (values["method"].is_ldap or values["method"].is_basic) and (
1452-
not values["password"] or not values["user"]
1453-
):
1447+
self.port = 80 if self.http_scheme == "http" else 443
1448+
1449+
if (self.method.is_ldap or self.method.is_basic) and (not self.password or not self.user):
14541450
raise ConfigError(
1455-
f"Username and Password must be provided if using {values['method'].value} authentication"
1451+
f"Username and Password must be provided if using {self.method.value} authentication"
14561452
)
1457-
if values["method"].is_kerberos and (
1458-
not values["principal"] or not values["keytab"] or not values["krb5_config"]
1453+
1454+
if self.method.is_kerberos and (
1455+
not self.principal or not self.keytab or not self.krb5_config
14591456
):
14601457
raise ConfigError(
14611458
"Kerberos requires the following fields: principal, keytab, and krb5_config"
14621459
)
1463-
if values["method"].is_jwt and not values["jwt_token"]:
1460+
1461+
if self.method.is_jwt and not self.jwt_token:
14641462
raise ConfigError("JWT requires `jwt_token` to be set")
1465-
if values["method"].is_certificate and (
1466-
not values["cert"]
1467-
or not values["client_certificate"]
1468-
or not values["client_private_key"]
1463+
1464+
if self.method.is_certificate and (
1465+
not self.cert or not self.client_certificate or not self.client_private_key
14691466
):
14701467
raise ConfigError(
14711468
"Certificate requires the following fields: cert, client_certificate, and client_private_key"
14721469
)
1473-
return values
1470+
1471+
return self
14741472

14751473
@property
14761474
def _connection_kwargs_keys(self) -> t.Set[str]:
@@ -1677,26 +1675,23 @@ class AthenaConnectionConfig(ConnectionConfig):
16771675
type_: t.Literal["athena"] = Field(alias="type", default="athena")
16781676

16791677
@model_validator(mode="after")
1680-
@model_validator_v1_args
1681-
def _root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
1682-
work_group = values.get("work_group")
1683-
s3_staging_dir = values.get("s3_staging_dir")
1684-
s3_warehouse_location = values.get("s3_warehouse_location")
1678+
def _root_validator(self) -> Self:
1679+
work_group = self.work_group
1680+
s3_staging_dir = self.s3_staging_dir
1681+
s3_warehouse_location = self.s3_warehouse_location
16851682

16861683
if not work_group and not s3_staging_dir:
16871684
raise ConfigError("At least one of work_group or s3_staging_dir must be set")
16881685

16891686
if s3_staging_dir:
1690-
values["s3_staging_dir"] = validate_s3_uri(
1691-
s3_staging_dir, base=True, error_type=ConfigError
1692-
)
1687+
self.s3_staging_dir = validate_s3_uri(s3_staging_dir, base=True, error_type=ConfigError)
16931688

16941689
if s3_warehouse_location:
1695-
values["s3_warehouse_location"] = validate_s3_uri(
1690+
self.s3_warehouse_location = validate_s3_uri(
16961691
s3_warehouse_location, base=True, error_type=ConfigError
16971692
)
16981693

1699-
return values
1694+
return self
17001695

17011696
@property
17021697
def _connection_kwargs_keys(self) -> t.Set[str]:

0 commit comments

Comments
 (0)