7
7
8
8
from dbt .adapters .base import BaseRelation
9
9
from dbt .contracts .relation import RelationType
10
+ from jinja2 import nodes
11
+ from jinja2 .exceptions import UndefinedError
10
12
from pydantic import Field , validator
11
13
from sqlglot .helper import ensure_list
12
14
@@ -140,6 +142,14 @@ def _validate_columns(cls, v: t.Any) -> t.Dict[str, ColumnConfig]:
140
142
141
143
@property
142
144
def all_sql (self ) -> SqlStr :
145
+ return SqlStr ("\n " .join (self .pre_hook + [self .sql_no_config ] + self .post_hook ))
146
+
147
+ @property
148
+ def sql_no_config (self ) -> SqlStr :
149
+ return SqlStr ("" )
150
+
151
+ @property
152
+ def sql_embedded_config (self ) -> SqlStr :
143
153
return SqlStr ("" )
144
154
145
155
@property
@@ -190,13 +200,17 @@ def relation_info(self) -> AttributeDict[str, t.Any]:
190
200
}
191
201
)
192
202
203
+ def attribute_dict (self ) -> AttributeDict [str , t .Any ]:
204
+ return AttributeDict (self .dict ())
205
+
193
206
def sqlmesh_model_kwargs (self , model_context : DbtContext ) -> t .Dict [str , t .Any ]:
194
207
"""Get common sqlmesh model parameters"""
195
208
jinja_macros = model_context .jinja_macros .trim (self ._dependencies .macros )
196
209
jinja_macros .global_objs .update (
197
210
{
198
211
"this" : self .relation_info ,
199
212
"schema" : self .table_schema ,
213
+ "config" : self .attribute_dict (),
200
214
** model_context .jinja_globals , # type: ignore
201
215
}
202
216
)
@@ -220,7 +234,6 @@ def sqlmesh_model_kwargs(self, model_context: DbtContext) -> t.Dict[str, t.Any]:
220
234
221
235
def render_config (self : BMC , context : DbtContext ) -> BMC :
222
236
rendered = super ().render_config (context )
223
- rendered ._dependencies = Dependencies (macros = extract_macro_references (rendered .all_sql ))
224
237
rendered = ModelSqlRenderer (context , rendered ).enriched_config
225
238
226
239
rendered_dependencies = rendered ._dependencies
@@ -275,7 +288,7 @@ def __init__(self, context: DbtContext, config: BMC):
275
288
jinja_globals = {
276
289
** context .jinja_globals ,
277
290
** date_dict (c .EPOCH , c .EPOCH , c .EPOCH ),
278
- "config" : self . _config ,
291
+ "config" : lambda * args , ** kwargs : "" ,
279
292
"ref" : self ._ref ,
280
293
"var" : self ._var ,
281
294
"source" : self ._source ,
@@ -293,9 +306,15 @@ def __init__(self, context: DbtContext, config: BMC):
293
306
dialect = context .engine_adapter .dialect if context .engine_adapter else "" ,
294
307
)
295
308
309
+ self .jinja_env = self .context .jinja_macros .build_environment (** self ._jinja_globals )
310
+
296
311
@property
297
312
def enriched_config (self ) -> BMC :
298
313
if self ._rendered_sql is None :
314
+ self ._enriched_config = self ._update_with_sql_config (self ._enriched_config )
315
+ self ._enriched_config ._dependencies = Dependencies (
316
+ macros = extract_macro_references (self ._enriched_config .all_sql )
317
+ )
299
318
self .render ()
300
319
self ._enriched_config ._dependencies = self ._enriched_config ._dependencies .union (
301
320
self ._captured_dependencies
@@ -304,14 +323,42 @@ def enriched_config(self) -> BMC:
304
323
305
324
def render (self ) -> str :
306
325
if self ._rendered_sql is None :
307
- registry = self . context . jinja_macros
308
- self ._rendered_sql = (
309
- registry . build_environment ( ** self ._jinja_globals )
310
- . from_string ( self . config . all_sql )
311
- . render ()
312
- )
326
+ try :
327
+ self ._rendered_sql = self . jinja_env . from_string (
328
+ self ._enriched_config . all_sql
329
+ ). render ( )
330
+ except UndefinedError as e :
331
+ raise ConfigError ( e . message )
313
332
return self ._rendered_sql
314
333
334
+ def _update_with_sql_config (self , config : BMC ) -> BMC :
335
+ def _extract_value (node : t .Any ) -> t .Any :
336
+ if not isinstance (node , nodes .Node ):
337
+ return node
338
+ if isinstance (node , nodes .Const ):
339
+ return _extract_value (node .value )
340
+ if isinstance (node , nodes .TemplateData ):
341
+ return _extract_value (node .data )
342
+ if isinstance (node , nodes .List ):
343
+ return [_extract_value (val ) for val in node .items ]
344
+ if isinstance (node , nodes .Dict ):
345
+ return {_extract_value (pair .key ): _extract_value (pair .value ) for pair in node .items }
346
+ if isinstance (node , nodes .Tuple ):
347
+ return tuple (_extract_value (val ) for val in node .items )
348
+
349
+ return self .jinja_env .from_string (nodes .Template ([nodes .Output ([node ])])).render ()
350
+
351
+ for call in self .jinja_env .parse (self ._enriched_config .sql_embedded_config ).find_all (
352
+ nodes .Call
353
+ ):
354
+ if not isinstance (call .node , nodes .Name ) or call .node .name != "config" :
355
+ continue
356
+ config = config .update_with (
357
+ {kwarg .key : _extract_value (kwarg .value ) for kwarg in call .kwargs }
358
+ )
359
+
360
+ return config
361
+
315
362
def _ref (self , package_name : str , model_name : t .Optional [str ] = None ) -> BaseRelation :
316
363
if package_name in self .context .models :
317
364
relation = BaseRelation .create (** self .context .models [package_name ].relation_info )
@@ -341,13 +388,6 @@ def _source(self, source_name: str, table_name: str) -> BaseRelation:
341
388
self ._captured_dependencies .sources .add (full_name )
342
389
return BaseRelation .create (** self .context .sources [full_name ].relation_info )
343
390
344
- def _config (self , * args : t .Any , ** kwargs : t .Any ) -> str :
345
- if args and isinstance (args [0 ], dict ):
346
- self ._enriched_config = self ._enriched_config .update_with (args [0 ])
347
- if kwargs :
348
- self ._enriched_config = self ._enriched_config .update_with (kwargs )
349
- return ""
350
-
351
391
class TrackingAdapter (ParsetimeAdapter ):
352
392
def __init__ (self , outer_self : ModelSqlRenderer , * args : t .Any , ** kwargs : t .Any ):
353
393
super ().__init__ (* args , ** kwargs )
0 commit comments