Skip to content

Commit 5fecad2

Browse files
authored
[DBT] Mimic dbt behavior for config() jinja within sql models (#634)
1 parent 33f8cc0 commit 5fecad2

File tree

9 files changed

+159
-54
lines changed

9 files changed

+159
-54
lines changed

sqlmesh/core/config/base.py

+6-1
Original file line number | Diff line number | Diff line change
@@ -123,4 +123,9 @@ def update_with(self: T, other: t.Union[t.Dict[str, t.Any], T]) -> T:
123123
else:
124124
updated_fields[field] = getattr(other, field)
125125

126-
return self.copy(update=updated_fields)
126+
# Assign each field to trigger assignment validators
127+
updated = self.copy()
128+
for field, value in updated_fields.items():
129+
setattr(updated, field, value)
130+
131+
return updated

sqlmesh/dbt/basemodel.py

+55-15
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@
77

88
from dbt.adapters.base import BaseRelation
99
from dbt.contracts.relation import RelationType
10+
from jinja2 import nodes
11+
from jinja2.exceptions import UndefinedError
1012
from pydantic import Field, validator
1113
from sqlglot.helper import ensure_list
1214

@@ -140,6 +142,14 @@ def _validate_columns(cls, v: t.Any) -> t.Dict[str, ColumnConfig]:
140142

141143
@property
142144
def all_sql(self) -> SqlStr:
145+
return SqlStr("\n".join(self.pre_hook + [self.sql_no_config] + self.post_hook))
146+
147+
@property
148+
def sql_no_config(self) -> SqlStr:
149+
return SqlStr("")
150+
151+
@property
152+
def sql_embedded_config(self) -> SqlStr:
143153
return SqlStr("")
144154

145155
@property
@@ -190,13 +200,17 @@ def relation_info(self) -> AttributeDict[str, t.Any]:
190200
}
191201
)
192202

203+
def attribute_dict(self) -> AttributeDict[str, t.Any]:
204+
return AttributeDict(self.dict())
205+
193206
def sqlmesh_model_kwargs(self, model_context: DbtContext) -> t.Dict[str, t.Any]:
194207
"""Get common sqlmesh model parameters"""
195208
jinja_macros = model_context.jinja_macros.trim(self._dependencies.macros)
196209
jinja_macros.global_objs.update(
197210
{
198211
"this": self.relation_info,
199212
"schema": self.table_schema,
213+
"config": self.attribute_dict(),
200214
**model_context.jinja_globals, # type: ignore
201215
}
202216
)
@@ -220,7 +234,6 @@ def sqlmesh_model_kwargs(self, model_context: DbtContext) -> t.Dict[str, t.Any]:
220234

221235
def render_config(self: BMC, context: DbtContext) -> BMC:
222236
rendered = super().render_config(context)
223-
rendered._dependencies = Dependencies(macros=extract_macro_references(rendered.all_sql))
224237
rendered = ModelSqlRenderer(context, rendered).enriched_config
225238

226239
rendered_dependencies = rendered._dependencies
@@ -275,7 +288,7 @@ def __init__(self, context: DbtContext, config: BMC):
275288
jinja_globals={
276289
**context.jinja_globals,
277290
**date_dict(c.EPOCH, c.EPOCH, c.EPOCH),
278-
"config": self._config,
291+
"config": lambda *args, **kwargs: "",
279292
"ref": self._ref,
280293
"var": self._var,
281294
"source": self._source,
@@ -293,9 +306,15 @@ def __init__(self, context: DbtContext, config: BMC):
293306
dialect=context.engine_adapter.dialect if context.engine_adapter else "",
294307
)
295308

309+
self.jinja_env = self.context.jinja_macros.build_environment(**self._jinja_globals)
310+
296311
@property
297312
def enriched_config(self) -> BMC:
298313
if self._rendered_sql is None:
314+
self._enriched_config = self._update_with_sql_config(self._enriched_config)
315+
self._enriched_config._dependencies = Dependencies(
316+
macros=extract_macro_references(self._enriched_config.all_sql)
317+
)
299318
self.render()
300319
self._enriched_config._dependencies = self._enriched_config._dependencies.union(
301320
self._captured_dependencies
@@ -304,14 +323,42 @@ def enriched_config(self) -> BMC:
304323

305324
def render(self) -> str:
306325
if self._rendered_sql is None:
307-
registry = self.context.jinja_macros
308-
self._rendered_sql = (
309-
registry.build_environment(**self._jinja_globals)
310-
.from_string(self.config.all_sql)
311-
.render()
312-
)
326+
try:
327+
self._rendered_sql = self.jinja_env.from_string(
328+
self._enriched_config.all_sql
329+
).render()
330+
except UndefinedError as e:
331+
raise ConfigError(e.message)
313332
return self._rendered_sql
314333

334+
def _update_with_sql_config(self, config: BMC) -> BMC:
335+
def _extract_value(node: t.Any) -> t.Any:
336+
if not isinstance(node, nodes.Node):
337+
return node
338+
if isinstance(node, nodes.Const):
339+
return _extract_value(node.value)
340+
if isinstance(node, nodes.TemplateData):
341+
return _extract_value(node.data)
342+
if isinstance(node, nodes.List):
343+
return [_extract_value(val) for val in node.items]
344+
if isinstance(node, nodes.Dict):
345+
return {_extract_value(pair.key): _extract_value(pair.value) for pair in node.items}
346+
if isinstance(node, nodes.Tuple):
347+
return tuple(_extract_value(val) for val in node.items)
348+
349+
return self.jinja_env.from_string(nodes.Template([nodes.Output([node])])).render()
350+
351+
for call in self.jinja_env.parse(self._enriched_config.sql_embedded_config).find_all(
352+
nodes.Call
353+
):
354+
if not isinstance(call.node, nodes.Name) or call.node.name != "config":
355+
continue
356+
config = config.update_with(
357+
{kwarg.key: _extract_value(kwarg.value) for kwarg in call.kwargs}
358+
)
359+
360+
return config
361+
315362
def _ref(self, package_name: str, model_name: t.Optional[str] = None) -> BaseRelation:
316363
if package_name in self.context.models:
317364
relation = BaseRelation.create(**self.context.models[package_name].relation_info)
@@ -341,13 +388,6 @@ def _source(self, source_name: str, table_name: str) -> BaseRelation:
341388
self._captured_dependencies.sources.add(full_name)
342389
return BaseRelation.create(**self.context.sources[full_name].relation_info)
343390

344-
def _config(self, *args: t.Any, **kwargs: t.Any) -> str:
345-
if args and isinstance(args[0], dict):
346-
self._enriched_config = self._enriched_config.update_with(args[0])
347-
if kwargs:
348-
self._enriched_config = self._enriched_config.update_with(kwargs)
349-
return ""
350-
351391
class TrackingAdapter(ParsetimeAdapter):
352392
def __init__(self, outer_self: ModelSqlRenderer, *args: t.Any, **kwargs: t.Any):
353393
super().__init__(*args, **kwargs)

sqlmesh/dbt/builtin.py

-5
Original file line number | Diff line number | Diff line change
@@ -143,10 +143,6 @@ def no_log(msg: str, info: bool = False) -> str:
143143
return ""
144144

145145

146-
def config(*args: t.Any, **kwargs: t.Any) -> str:
147-
return ""
148-
149-
150146
def generate_var(variables: t.Dict[str, t.Any]) -> t.Callable:
151147
def var(name: str, default: t.Optional[str] = None) -> str:
152148
return variables.get(name, default)
@@ -252,7 +248,6 @@ def _try_literal_eval(value: str) -> t.Any:
252248

253249
BUILTIN_GLOBALS = {
254250
"api": Api(),
255-
"config": config,
256251
"env_var": env_var,
257252
"exceptions": Exceptions(),
258253
"flags": Flags(),

sqlmesh/dbt/common.py

+4-1
Original file line number | Diff line number | Diff line change
@@ -190,6 +190,7 @@ class DbtConfig(PydanticModel):
190190
class Config:
191191
extra = "allow"
192192
allow_mutation = True
193+
validate_assignment = True
193194

194195

195196
class GeneralConfig(DbtConfig, BaseConfig):
@@ -285,7 +286,9 @@ def render_value(val: t.Any) -> t.Any:
285286

286287
rendered = self.copy(deep=True)
287288
for name in rendered.__fields__:
288-
setattr(rendered, name, render_value(getattr(rendered, name)))
289+
value = getattr(rendered, name)
290+
if value is not None:
291+
setattr(rendered, name, render_value(value))
289292

290293
return rendered
291294

sqlmesh/dbt/model.py

+41-16
Original file line number | Diff line number | Diff line change
@@ -76,6 +76,10 @@ class ModelConfig(BaseModelConfig):
7676
# redshift
7777
bind: t.Optional[bool] = None
7878

79+
# Private fields
80+
_sql_embedded_config: t.Optional[SqlStr] = None
81+
_sql_no_config: t.Optional[SqlStr] = None
82+
7983
@validator(
8084
"unique_key",
8185
"cluster_by",
@@ -157,26 +161,47 @@ def model_kind(self, target: TargetConfig) -> ModelKind:
157161
raise ConfigError(f"{materialization.value} materialization not supported.")
158162

159163
@property
160-
def sql_no_config(self) -> str:
161-
matches = re.findall(r"{{\s*config\(", self.sql)
162-
if matches:
163-
config_macro_start = self.sql.index(matches[0])
164-
cursor = config_macro_start
164+
def sql_no_config(self) -> SqlStr:
165+
if self._sql_no_config is None:
166+
self._sql_no_config = SqlStr("")
167+
self._extract_sql_config()
168+
return self._sql_no_config
169+
170+
@property
171+
def sql_embedded_config(self) -> SqlStr:
172+
if self._sql_embedded_config is None:
173+
self._sql_embedded_config = SqlStr("")
174+
self._extract_sql_config()
175+
return self._sql_embedded_config
176+
177+
def _extract_sql_config(self) -> None:
178+
def jinja_end(sql: str, start: int) -> int:
179+
cursor = start
165180
quote = None
166-
while cursor < len(self.sql):
167-
if self.sql[cursor] in ('"', "'"):
181+
while cursor < len(sql):
182+
if sql[cursor] in ('"', "'"):
168183
if quote is None:
169-
quote = self.sql[cursor]
170-
elif quote == self.sql[cursor]:
184+
quote = sql[cursor]
185+
elif quote == sql[cursor]:
171186
quote = None
172-
if self.sql[cursor : cursor + 2] == "}}" and quote is None:
173-
return "".join([self.sql[:config_macro_start], self.sql[cursor + 2 :]])
187+
if sql[cursor : cursor + 2] == "}}" and quote is None:
188+
return cursor + 2
174189
cursor += 1
175-
return self.sql
176-
177-
@property
178-
def all_sql(self) -> SqlStr:
179-
return SqlStr(";\n".join(self.pre_hook + [self.sql] + self.post_hook))
190+
return cursor
191+
192+
self._sql_no_config = self.sql
193+
matches = re.findall(r"{{\s*config\s*\(", self._sql_no_config)
194+
for match in matches:
195+
start = self._sql_no_config.find(match)
196+
if start == -1:
197+
continue
198+
extracted = self._sql_no_config[start : jinja_end(self._sql_no_config, start)]
199+
self._sql_embedded_config = SqlStr(
200+
"\n".join([self._sql_embedded_config, extracted])
201+
if self._sql_embedded_config
202+
else extracted
203+
)
204+
self._sql_no_config = SqlStr(self._sql_no_config.replace(extracted, "").strip())
180205

181206
def to_sqlmesh(self, context: DbtContext) -> Model:
182207
"""Converts the dbt model into a SQLMesh model."""

sqlmesh/dbt/source.py

+3
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,9 @@ def _validate_quoting(cls, v: t.Dict[str, t.Any]) -> t.Dict[str, bool]:
5050

5151
@validator("columns", pre=True)
5252
def _validate_columns(cls, v: t.Any) -> t.Dict[str, ColumnConfig]:
53+
if not isinstance(v, dict) or all(isinstance(col, ColumnConfig) for col in v.values()):
54+
return v
55+
5356
return yaml_to_columns(v)
5457

5558
_FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = {

sqlmesh/utils/jinja.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -121,7 +121,7 @@ def call_name(node: nodes.Expr) -> t.Tuple[str, ...]:
121121

122122

123123
def render_jinja(query: str, methods: t.Optional[t.Dict[str, t.Any]] = None) -> str:
124-
return ENVIRONMENT.from_string(query).render(methods)
124+
return ENVIRONMENT.from_string(query).render(methods or {})
125125

126126

127127
def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[t.Tuple[str, ...]]:

tests/dbt/test_config.py

+23-15
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
import pytest
55

66
from sqlmesh.core.model import SqlModel
7-
from sqlmesh.dbt.basemodel import Dependencies
87
from sqlmesh.dbt.common import DbtContext
98
from sqlmesh.dbt.model import Materialization, ModelConfig
109
from sqlmesh.dbt.project import Project
@@ -130,6 +129,7 @@ def test_to_sqlmesh_fields(sushi_test_project: Project):
130129

131130

132131
def test_model_config_sql_no_config():
132+
context = DbtContext()
133133
assert (
134134
ModelConfig(
135135
sql="""{{
@@ -139,64 +139,72 @@ def test_model_config_sql_no_config():
139139
)
140140
}}
141141
query"""
142-
).sql_no_config.strip()
142+
)
143+
.render_config(context)
144+
.sql_no_config.strip()
143145
== "query"
144146
)
145147

148+
context.variables = {"new": "old"}
146149
assert (
147150
ModelConfig(
148151
sql="""{{
149152
config(
150-
materialized='"table"',
153+
materialized='table',
151154
incremental_strategy='delete+insert',
152-
post_hook=" '{{ macro_call(this) }}' "
155+
post_hook=" '{{ var('new') }}' "
153156
)
154157
}}
155158
query"""
156-
).sql_no_config.strip()
159+
)
160+
.render_config(context)
161+
.sql_no_config.strip()
157162
== "query"
158163
)
159164

160165
assert (
161166
ModelConfig(
162-
sql="""before {{config(materialized='table', post_hook=" {{ macro_call(this) }} ")}} after"""
163-
).sql_no_config
167+
sql="""before {{config(materialized='table', post_hook=" {{ var('new') }} ")}} after"""
168+
)
169+
.render_config(context)
170+
.sql_no_config.strip()
164171
== "before after"
165172
)
166173

167174

168175
def test_variables(assert_exp_eq, sushi_test_project):
169176
# Case 1: using an undefined variable without a default value
170177
defined_variables = {}
171-
model_variables = {"foo"}
172-
173-
model_config = ModelConfig(alias="test", sql="SELECT {{ var('foo') }}")
174-
model_config._dependencies = Dependencies(variables=model_variables)
175178

176179
context = sushi_test_project.context
177180
context.variables = defined_variables
178181

182+
model_config = ModelConfig(alias="test", sql="SELECT {{ var('foo') }}")
183+
179184
kwargs = {"context": context}
180185

181186
with pytest.raises(ConfigError, match=r".*Variable 'foo' was not found.*"):
182-
model_config = model_config.render_config(context)
187+
rendered = model_config.render_config(context)
183188
model_config.to_sqlmesh(**kwargs)
184189

185190
# Case 2: using a defined variable without a default value
186191
defined_variables["foo"] = 6
187192
context.variables = defined_variables
188-
assert_exp_eq(model_config.to_sqlmesh(**kwargs).render_query(), 'SELECT 6 AS "6"')
193+
rendered = model_config.render_config(context)
194+
assert_exp_eq(rendered.to_sqlmesh(**kwargs).render_query(), 'SELECT 6 AS "6"')
189195

190196
# Case 3: using a defined variable with a default value
191197
model_config.sql = "SELECT {{ var('foo', 5) }}"
192198

193-
assert_exp_eq(model_config.to_sqlmesh(**kwargs).render_query(), 'SELECT 6 AS "6"')
199+
rendered = model_config.render_config(context)
200+
assert_exp_eq(rendered.to_sqlmesh(**kwargs).render_query(), 'SELECT 6 AS "6"')
194201

195202
# Case 4: using an undefined variable with a default value
196203
del defined_variables["foo"]
197204
context.variables = defined_variables
198205

199-
assert_exp_eq(model_config.to_sqlmesh(**kwargs).render_query(), 'SELECT 5 AS "5"')
206+
rendered = model_config.render_config(context)
207+
assert_exp_eq(rendered.to_sqlmesh(**kwargs).render_query(), 'SELECT 5 AS "5"')
200208

201209
# Finally, check that variable scoping & overwriting (some_var) works as expected
202210
expected_sushi_variables = {

0 commit comments

Comments
 (0)