23
23
from sqlmesh .core .engine_adapter .shared import CatalogSupport
24
24
from sqlmesh .core .engine_adapter import EngineAdapter
25
25
from sqlmesh .utils .errors import ConfigError
26
- from sqlmesh .utils .pydantic import (
27
- field_validator ,
28
- model_validator ,
29
- model_validator_v1_args ,
30
- field_validator_v1_args ,
31
- )
26
+ from sqlmesh .utils .pydantic import ValidationInfo , field_validator , model_validator
32
27
from sqlmesh .utils .aws import validate_s3_uri
33
28
29
+ if t .TYPE_CHECKING :
30
+ from sqlmesh .core ._typing import Self
31
+
34
32
logger = logging .getLogger (__name__ )
35
33
36
34
RECOMMENDED_STATE_SYNC_ENGINES = {"postgres" , "gcp_postgres" , "mysql" , "mssql" }
@@ -163,19 +161,20 @@ class BaseDuckDBConnectionConfig(ConnectionConfig):
163
161
_data_file_to_adapter : t .ClassVar [t .Dict [str , EngineAdapter ]] = {}
164
162
165
163
@model_validator (mode = "before" )
166
- @ model_validator_v1_args
167
- def _validate_database_catalogs (
168
- cls , values : t . Dict [ str , t . Optional [ str ]]
169
- ) -> t . Dict [ str , t . Optional [ str ]]:
170
- if db_path := values .get ("database" ) and values .get ("catalogs" ):
164
+ def _validate_database_catalogs ( cls , data : t . Any ) -> t . Any :
165
+ if not isinstance ( data , dict ):
166
+ return data
167
+
168
+ if db_path := data .get ("database" ) and data .get ("catalogs" ):
171
169
raise ConfigError (
172
170
"Cannot specify both `database` and `catalogs`. Define all your catalogs in `catalogs` and have the first entry be the default catalog"
173
171
)
174
172
if isinstance (db_path , str ) and db_path .startswith ("md:" ):
175
173
raise ConfigError (
176
174
"Please use connection type 'motherduck' without the `md:` prefix if you want to use a MotherDuck database as the single `database`."
177
175
)
178
- return values
176
+
177
+ return data
179
178
180
179
@property
181
180
def _engine_adapter (self ) -> t .Type [EngineAdapter ]:
@@ -430,29 +429,29 @@ class SnowflakeConnectionConfig(ConnectionConfig):
430
429
_concurrent_tasks_validator = concurrent_tasks_validator
431
430
432
431
@model_validator (mode = "before" )
433
- @model_validator_v1_args
434
- def _validate_authenticator (
435
- cls , values : t .Dict [str , t .Optional [str ]]
436
- ) -> t .Dict [str , t .Optional [str ]]:
437
- from snowflake .connector .network import (
438
- DEFAULT_AUTHENTICATOR ,
439
- OAUTH_AUTHENTICATOR ,
440
- )
432
+ def _validate_authenticator (cls , data : t .Any ) -> t .Any :
433
+ if not isinstance (data , dict ):
434
+ return data
441
435
442
- auth = values .get ("authenticator" )
436
+ from snowflake .connector .network import DEFAULT_AUTHENTICATOR , OAUTH_AUTHENTICATOR
437
+
438
+ auth = data .get ("authenticator" )
443
439
auth = auth .upper () if auth else DEFAULT_AUTHENTICATOR
444
- user = values .get ("user" )
445
- password = values .get ("password" )
446
- values ["private_key" ] = cls ._get_private_key (values , auth ) # type: ignore
440
+ user = data .get ("user" )
441
+ password = data .get ("password" )
442
+ data ["private_key" ] = cls ._get_private_key (data , auth ) # type: ignore
443
+
447
444
if (
448
445
auth == DEFAULT_AUTHENTICATOR
449
- and not values .get ("private_key" )
446
+ and not data .get ("private_key" )
450
447
and (not user or not password )
451
448
):
452
449
raise ConfigError ("User and password must be provided if using default authentication" )
453
- if auth == OAUTH_AUTHENTICATOR and not values .get ("token" ):
450
+
451
+ if auth == OAUTH_AUTHENTICATOR and not data .get ("token" ):
454
452
raise ConfigError ("Token must be provided if using oauth authentication" )
455
- return values
453
+
454
+ return data
456
455
457
456
@classmethod
458
457
def _get_private_key (cls , values : t .Dict [str , t .Optional [str ]], auth : str ) -> t .Optional [bytes ]:
@@ -621,26 +620,28 @@ class DatabricksConnectionConfig(ConnectionConfig):
621
620
_http_headers_validator = http_headers_validator
622
621
623
622
@model_validator (mode = "before" )
624
- @model_validator_v1_args
625
- def _databricks_connect_validator (cls , values : t .Dict [str , t .Any ]) -> t .Dict [str , t .Any ]:
623
+ def _databricks_connect_validator (cls , data : t .Any ) -> t .Any :
624
+ if not isinstance (data , dict ):
625
+ return data
626
+
626
627
from sqlmesh .core .engine_adapter .databricks import DatabricksEngineAdapter
627
628
628
629
if DatabricksEngineAdapter .can_access_spark_session (
629
- bool (values .get ("disable_spark_session" ))
630
+ bool (data .get ("disable_spark_session" ))
630
631
):
631
- return values
632
+ return data
632
633
633
- databricks_connect_use_serverless = values .get ("databricks_connect_use_serverless" )
634
+ databricks_connect_use_serverless = data .get ("databricks_connect_use_serverless" )
634
635
server_hostname , http_path , access_token , auth_type = (
635
- values .get ("server_hostname" ),
636
- values .get ("http_path" ),
637
- values .get ("access_token" ),
638
- values .get ("auth_type" ),
636
+ data .get ("server_hostname" ),
637
+ data .get ("http_path" ),
638
+ data .get ("access_token" ),
639
+ data .get ("auth_type" ),
639
640
)
640
641
641
642
if databricks_connect_use_serverless :
642
- values ["force_databricks_connect" ] = True
643
- values ["disable_databricks_connect" ] = False
643
+ data ["force_databricks_connect" ] = True
644
+ data ["disable_databricks_connect" ] = False
644
645
645
646
if (not server_hostname or not http_path or not access_token ) and (
646
647
not databricks_connect_use_serverless and not auth_type
@@ -651,35 +652,35 @@ def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str
651
652
if (
652
653
databricks_connect_use_serverless
653
654
and not server_hostname
654
- and not values .get ("databricks_connect_server_hostname" )
655
+ and not data .get ("databricks_connect_server_hostname" )
655
656
):
656
657
raise ValueError (
657
658
"`server_hostname` or `databricks_connect_server_hostname` is required when `databricks_connect_use_serverless` is set"
658
659
)
659
660
if DatabricksEngineAdapter .can_access_databricks_connect (
660
- bool (values .get ("disable_databricks_connect" ))
661
+ bool (data .get ("disable_databricks_connect" ))
661
662
):
662
- if not values .get ("databricks_connect_access_token" ):
663
- values ["databricks_connect_access_token" ] = access_token
664
- if not values .get ("databricks_connect_server_hostname" ):
665
- values ["databricks_connect_server_hostname" ] = f"https://{ server_hostname } "
663
+ if not data .get ("databricks_connect_access_token" ):
664
+ data ["databricks_connect_access_token" ] = access_token
665
+ if not data .get ("databricks_connect_server_hostname" ):
666
+ data ["databricks_connect_server_hostname" ] = f"https://{ server_hostname } "
666
667
if not databricks_connect_use_serverless :
667
- if not values .get ("databricks_connect_cluster_id" ):
668
+ if not data .get ("databricks_connect_cluster_id" ):
668
669
if t .TYPE_CHECKING :
669
670
assert http_path is not None
670
- values ["databricks_connect_cluster_id" ] = http_path .split ("/" )[- 1 ]
671
+ data ["databricks_connect_cluster_id" ] = http_path .split ("/" )[- 1 ]
671
672
672
673
if auth_type :
673
674
from databricks .sql .auth .auth import AuthType
674
675
675
- all_values = [m .value for m in AuthType ]
676
- if auth_type not in all_values :
676
+ all_data = [m .value for m in AuthType ]
677
+ if auth_type not in all_data :
677
678
raise ValueError (
678
- f"`auth_type` { auth_type } does not match a valid option: { all_values } "
679
+ f"`auth_type` { auth_type } does not match a valid option: { all_data } "
679
680
)
680
681
681
- client_id = values .get ("oauth_client_id" )
682
- client_secret = values .get ("oauth_client_secret" )
682
+ client_id = data .get ("oauth_client_id" )
683
+ client_secret = data .get ("oauth_client_secret" )
683
684
684
685
if client_secret and not client_id :
685
686
raise ValueError (
@@ -689,7 +690,7 @@ def _databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str
689
690
if not http_path :
690
691
raise ValueError ("`http_path` is still required when using `auth_type`" )
691
692
692
- return values
693
+ return data
693
694
694
695
@property
695
696
def _connection_kwargs_keys (self ) -> t .Set [str ]:
@@ -866,26 +867,24 @@ class BigQueryConnectionConfig(ConnectionConfig):
866
867
type_ : t .Literal ["bigquery" ] = Field (alias = "type" , default = "bigquery" )
867
868
868
869
@field_validator ("execution_project" )
869
- @field_validator_v1_args
870
870
def validate_execution_project (
871
871
cls ,
872
872
v : t .Optional [str ],
873
- values : t . Dict [ str , t . Any ] ,
873
+ info : ValidationInfo ,
874
874
) -> t .Optional [str ]:
875
- if v and not values .get ("project" ):
875
+ if v and not info . data .get ("project" ):
876
876
raise ConfigError (
877
877
"If the `execution_project` field is specified, you must also specify the `project` field to provide a default object location."
878
878
)
879
879
return v
880
880
881
881
@field_validator ("quota_project" )
882
- @field_validator_v1_args
883
882
def validate_quota_project (
884
883
cls ,
885
884
v : t .Optional [str ],
886
- values : t . Dict [ str , t . Any ] ,
885
+ info : ValidationInfo ,
887
886
) -> t .Optional [str ]:
888
- if v and not values .get ("project" ):
887
+ if v and not info . data .get ("project" ):
889
888
raise ConfigError (
890
889
"If the `quota_project` field is specified, you must also specify the `project` field to provide a default object location."
891
890
)
@@ -998,12 +997,13 @@ class GCPPostgresConnectionConfig(ConnectionConfig):
998
997
pre_ping : bool = True
999
998
1000
999
@model_validator (mode = "before" )
1001
- @model_validator_v1_args
1002
- def _validate_auth_method (
1003
- cls , values : t .Dict [str , t .Optional [str ]]
1004
- ) -> t .Dict [str , t .Optional [str ]]:
1005
- password = values .get ("password" )
1006
- enable_iam_auth = values .get ("enable_iam_auth" )
1000
+ def _validate_auth_method (cls , data : t .Any ) -> t .Any :
1001
+ if not isinstance (data , dict ):
1002
+ return data
1003
+
1004
+ password = data .get ("password" )
1005
+ enable_iam_auth = data .get ("enable_iam_auth" )
1006
+
1007
1007
if password and enable_iam_auth :
1008
1008
raise ConfigError (
1009
1009
"Invalid GCP Postgres connection configuration - both password and"
@@ -1016,7 +1016,8 @@ def _validate_auth_method(
1016
1016
" for a postgres user account or enable_iam_auth set to 'True'"
1017
1017
" for an IAM user account."
1018
1018
)
1019
- return values
1019
+
1020
+ return data
1020
1021
1021
1022
@property
1022
1023
def _connection_kwargs_keys (self ) -> t .Set [str ]:
@@ -1437,40 +1438,37 @@ class TrinoConnectionConfig(ConnectionConfig):
1437
1438
type_ : t .Literal ["trino" ] = Field (alias = "type" , default = "trino" )
1438
1439
1439
1440
@model_validator (mode = "after" )
1440
- @model_validator_v1_args
1441
- def _root_validator (cls , values : t .Dict [str , t .Any ]) -> t .Dict [str , t .Any ]:
1442
- port = values .get ("port" )
1443
- if (
1444
- values ["http_scheme" ] == "http"
1445
- and not values ["method" ].is_no_auth
1446
- and not values ["method" ].is_basic
1447
- ):
1441
+ def _root_validator (self ) -> Self :
1442
+ port = self .port
1443
+ if self .http_scheme == "http" and not self .method .is_no_auth and not self .method .is_basic :
1448
1444
raise ConfigError ("HTTP scheme can only be used with no-auth or basic method" )
1445
+
1449
1446
if port is None :
1450
- values ["port" ] = 80 if values ["http_scheme" ] == "http" else 443
1451
- if (values ["method" ].is_ldap or values ["method" ].is_basic ) and (
1452
- not values ["password" ] or not values ["user" ]
1453
- ):
1447
+ self .port = 80 if self .http_scheme == "http" else 443
1448
+
1449
+ if (self .method .is_ldap or self .method .is_basic ) and (not self .password or not self .user ):
1454
1450
raise ConfigError (
1455
- f"Username and Password must be provided if using { values [ ' method' ] .value } authentication"
1451
+ f"Username and Password must be provided if using { self . method .value } authentication"
1456
1452
)
1457
- if values ["method" ].is_kerberos and (
1458
- not values ["principal" ] or not values ["keytab" ] or not values ["krb5_config" ]
1453
+
1454
+ if self .method .is_kerberos and (
1455
+ not self .principal or not self .keytab or not self .krb5_config
1459
1456
):
1460
1457
raise ConfigError (
1461
1458
"Kerberos requires the following fields: principal, keytab, and krb5_config"
1462
1459
)
1463
- if values ["method" ].is_jwt and not values ["jwt_token" ]:
1460
+
1461
+ if self .method .is_jwt and not self .jwt_token :
1464
1462
raise ConfigError ("JWT requires `jwt_token` to be set" )
1465
- if values ["method" ].is_certificate and (
1466
- not values ["cert" ]
1467
- or not values ["client_certificate" ]
1468
- or not values ["client_private_key" ]
1463
+
1464
+ if self .method .is_certificate and (
1465
+ not self .cert or not self .client_certificate or not self .client_private_key
1469
1466
):
1470
1467
raise ConfigError (
1471
1468
"Certificate requires the following fields: cert, client_certificate, and client_private_key"
1472
1469
)
1473
- return values
1470
+
1471
+ return self
1474
1472
1475
1473
@property
1476
1474
def _connection_kwargs_keys (self ) -> t .Set [str ]:
@@ -1677,26 +1675,23 @@ class AthenaConnectionConfig(ConnectionConfig):
1677
1675
type_ : t .Literal ["athena" ] = Field (alias = "type" , default = "athena" )
1678
1676
1679
1677
@model_validator (mode = "after" )
1680
- @model_validator_v1_args
1681
- def _root_validator (cls , values : t .Dict [str , t .Any ]) -> t .Dict [str , t .Any ]:
1682
- work_group = values .get ("work_group" )
1683
- s3_staging_dir = values .get ("s3_staging_dir" )
1684
- s3_warehouse_location = values .get ("s3_warehouse_location" )
1678
+ def _root_validator (self ) -> Self :
1679
+ work_group = self .work_group
1680
+ s3_staging_dir = self .s3_staging_dir
1681
+ s3_warehouse_location = self .s3_warehouse_location
1685
1682
1686
1683
if not work_group and not s3_staging_dir :
1687
1684
raise ConfigError ("At least one of work_group or s3_staging_dir must be set" )
1688
1685
1689
1686
if s3_staging_dir :
1690
- values ["s3_staging_dir" ] = validate_s3_uri (
1691
- s3_staging_dir , base = True , error_type = ConfigError
1692
- )
1687
+ self .s3_staging_dir = validate_s3_uri (s3_staging_dir , base = True , error_type = ConfigError )
1693
1688
1694
1689
if s3_warehouse_location :
1695
- values [ " s3_warehouse_location" ] = validate_s3_uri (
1690
+ self . s3_warehouse_location = validate_s3_uri (
1696
1691
s3_warehouse_location , base = True , error_type = ConfigError
1697
1692
)
1698
1693
1699
- return values
1694
+ return self
1700
1695
1701
1696
@property
1702
1697
def _connection_kwargs_keys (self ) -> t .Set [str ]:
0 commit comments