Skip to content

Commit bb9826a

Browse files
Feat(dbt): Add support for on-run-start and on-run-end hooks (#4044)
1 parent d015490 commit bb9826a

File tree

15 files changed

+217
-10
lines changed

15 files changed

+217
-10
lines changed

docs/concepts/macros/macro_variables.md

+10-4
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ This example used one of SQLMesh's predefined variables, but you can also define
4444

4545
We describe SQLMesh's predefined variables below; user-defined macro variables are discussed in the [SQLMesh macros](./sqlmesh_macros.md#user-defined-variables) and [Jinja macros](./jinja_macros.md#user-defined-variables) pages.
4646

47-
## Predefined Variables
47+
## Predefined variables
4848
SQLMesh comes with predefined variables that can be used in your queries. They are automatically set by the SQLMesh runtime.
4949

5050
Most predefined variables are related to time and use a combination of prefixes (start, end, etc.) and postfixes (date, ds, ts, etc.). They are described in the next section; [other predefined variables](#runtime-variables) are discussed in the following section.
@@ -120,7 +120,7 @@ All predefined temporal macro variables:
120120

121121
### Runtime variables
122122

123-
SQLMesh provides two other predefined variables used to modify model behavior based on information available at runtime.
123+
SQLMesh provides additional predefined variables used to modify model behavior based on information available at runtime.
124124

125125
* @runtime_stage - A string value denoting the current stage of the SQLMesh runtime. Typically used in models to conditionally execute pre/post-statements (learn more [here](../models/sql_models.md#optional-prepost-statements)). It returns one of these values:
126126
* 'loading' - The project is being loaded into SQLMesh's runtime context.
@@ -133,5 +133,11 @@ SQLMesh provides two other predefined variables used to modify model behavior ba
133133
* @this_model - A string value containing the name of the physical table the model view selects from. Typically used to create [generic audits](../audits.md#generic-audits). In the case of [on_virtual_update statements](../models/sql_models.md#optional-on-virtual-update-statements) it contains the qualified view name instead.
134134
* Can be used in model definitions when SQLGlot cannot fully parse a statement and you need to reference the model's underlying physical table directly.
135135
* Can be passed as an argument to macros that access or interact with the underlying physical table.
136-
* @this_env - A string value containing the name of the current [environment](../environments.md). Only available in [`before_all` and `after_all` statements](../../guides/configuration.md#before_all-and-after_all-statements), as well as in macros invoked within them.
137-
* @model_kind_name - A string value containing the name of the current model kind. Intended to be used in scenarios where you need to control the [physical properties in model defaults](../../reference/model_configuration.md#model-defaults).
136+
* @model_kind_name - A string value containing the name of the current model kind. Intended to be used in scenarios where you need to control the [physical properties in model defaults](../../reference/model_configuration.md#model-defaults).
137+
138+
#### Before all and after all variables
139+
140+
The following variables are also available in [`before_all` and `after_all` statements](../../guides/configuration.md#before_all-and-after_all-statements), as well as in macros invoked within them.
141+
142+
* @this_env - A string value containing the name of the current [environment](../environments.md).
143+
* @schemas - A list of the schema names of the [virtual layer](../../concepts/glossary.md#virtual-layer) of the current environment.

docs/integrations/dbt.md

-1
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,6 @@ The dbt jinja methods that are not currently supported are:
324324
* selected_sources
325325
* adapter.expand_target_column_types
326326
* adapter.rename_relation
327-
* schemas
328327
* graph.nodes.values
329328
* graph.metrics.values
330329

examples/multi_dbt/bronze/dbt_project.yml

+3
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ require-dbt-version: [">=1.0.0", "<2.0.0"]
1919
models:
2020
start: "2024-01-01"
2121
+materialized: table
22+
23+
on-run-start:
24+
- 'CREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);'

examples/multi_dbt/silver/dbt_project.yml

+3
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ require-dbt-version: [">=1.0.0", "<2.0.0"]
1919
models:
2020
start: "2024-01-01"
2121
+materialized: table
22+
23+
on-run-end:
24+
- '{{ store_schemas(schemas) }}'
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{% macro store_schemas(schemas) %}
2+
create or replace table schema_table as select {{schemas}} as all_schemas;
3+
{% endmacro %}

sqlmesh/core/environment.py

+3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sqlmesh.core.snapshot import SnapshotId, SnapshotTableInfo, Snapshot
1515
from sqlmesh.utils import word_characters_only
1616
from sqlmesh.utils.date import TimeLike, now_timestamp
17+
from sqlmesh.utils.jinja import JinjaMacroRegistry
1718
from sqlmesh.utils.metaprogramming import Executable
1819
from sqlmesh.utils.pydantic import PydanticModel, field_validator
1920

@@ -218,6 +219,7 @@ class EnvironmentStatements(PydanticModel):
218219
before_all: t.List[str]
219220
after_all: t.List[str]
220221
python_env: t.Dict[str, Executable]
222+
jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry()
221223

222224

223225
def execute_environment_statements(
@@ -239,6 +241,7 @@ def execute_environment_statements(
239241
dialect=adapter.dialect,
240242
default_catalog=default_catalog,
241243
python_env=statements.python_env,
244+
jinja_macros=statements.jinja_macros,
242245
snapshots=snapshots,
243246
start=start,
244247
end=end,

sqlmesh/core/renderer.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,18 @@ def _render(
107107

108108
if environment_naming_info := kwargs.get("environment_naming_info", None):
109109
kwargs["this_env"] = getattr(environment_naming_info, "name")
110+
if snapshots and (
111+
schemas := set(
112+
[
113+
s.qualified_view_name.schema_for_environment(
114+
environment_naming_info, dialect=self._dialect
115+
)
116+
for s in snapshots.values()
117+
if s.is_model and not s.is_symbolic
118+
]
119+
)
120+
):
121+
kwargs["schemas"] = list(schemas)
110122

111123
this_model = kwargs.pop("this_model", None)
112124

@@ -411,19 +423,21 @@ def render(
411423

412424
def render_statements(
413425
statements: t.List[str],
414-
dialect: DialectType = None,
426+
dialect: str,
415427
default_catalog: t.Optional[str] = None,
416428
python_env: t.Optional[t.Dict[str, Executable]] = None,
429+
jinja_macros: t.Optional[JinjaMacroRegistry] = None,
417430
**render_kwargs: t.Any,
418431
) -> t.List[str]:
419432
rendered_statements: t.List[str] = []
420433
for statement in statements:
421-
for expression in parse(statement, dialect=dialect):
434+
for expression in d.parse(statement, default_dialect=dialect):
422435
if expression:
423436
rendered = ExpressionRenderer(
424437
expression,
425438
dialect,
426439
[],
440+
jinja_macro_registry=jinja_macros or JinjaMacroRegistry(),
427441
python_env=python_env,
428442
default_catalog=default_catalog,
429443
quote_identifiers=False,

sqlmesh/dbt/loader.py

+63-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import logging
44
import sys
55
import typing as t
6+
import sqlmesh.core.dialect as d
7+
from sqlglot.optimizer.simplify import gen
68
from pathlib import Path
79
from sqlmesh.core import constants as c
810
from sqlmesh.core.config import (
@@ -11,9 +13,11 @@
1113
GatewayConfig,
1214
ModelDefaultsConfig,
1315
)
16+
from sqlmesh.core.environment import EnvironmentStatements
1417
from sqlmesh.core.loader import CacheBase, LoadedProject, Loader
1518
from sqlmesh.core.macros import MacroRegistry, macro
1619
from sqlmesh.core.model import Model, ModelCache
20+
from sqlmesh.core.model.common import make_python_env
1721
from sqlmesh.core.signal import signal
1822
from sqlmesh.dbt.basemodel import BMC, BaseModelConfig
1923
from sqlmesh.dbt.context import DbtContext
@@ -23,7 +27,11 @@
2327
from sqlmesh.dbt.target import TargetConfig
2428
from sqlmesh.utils import UniqueKeyDict
2529
from sqlmesh.utils.errors import ConfigError
26-
from sqlmesh.utils.jinja import JinjaMacroRegistry
30+
from sqlmesh.utils.jinja import (
31+
JinjaMacroRegistry,
32+
MacroInfo,
33+
extract_macro_references_and_variables,
34+
)
2735

2836
if sys.version_info >= (3, 12):
2937
from importlib import metadata
@@ -230,6 +238,60 @@ def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]:
230238

231239
return requirements, excluded_requirements
232240

241+
def _load_environment_statements(self, macros: MacroRegistry) -> EnvironmentStatements | None:
242+
"""Loads dbt's on_run_start, on_run_end hooks into sqlmesh's before_all, after_all statements respectively."""
243+
244+
on_run_start: t.List[str] = []
245+
on_run_end: t.List[str] = []
246+
jinja_root_macros: t.Dict[str, MacroInfo] = {}
247+
variables: t.Dict[str, t.Any] = self._get_variables()
248+
dialect = self.config.dialect
249+
for project in self._load_projects():
250+
context = project.context.copy()
251+
if manifest := context._manifest:
252+
on_run_start.extend(manifest._on_run_start or [])
253+
on_run_end.extend(manifest._on_run_end or [])
254+
255+
if root_package := context.jinja_macros.root_package_name:
256+
if root_macros := context.jinja_macros.packages.get(root_package):
257+
jinja_root_macros |= root_macros
258+
context.set_and_render_variables(context.variables, root_package)
259+
variables |= context.variables
260+
261+
if statements := on_run_start + on_run_end:
262+
jinja_macro_references, used_variables = extract_macro_references_and_variables(
263+
*(gen(stmt) for stmt in statements)
264+
)
265+
jinja_macros = context.jinja_macros
266+
jinja_macros.root_macros = jinja_root_macros
267+
jinja_macros = (
268+
jinja_macros.trim(jinja_macro_references)
269+
if not jinja_macros.trimmed
270+
else jinja_macros
271+
)
272+
273+
python_env = make_python_env(
274+
[s for stmt in statements for s in d.parse(stmt, default_dialect=dialect)],
275+
jinja_macro_references=jinja_macro_references,
276+
module_path=self.config_path,
277+
macros=macros,
278+
variables=variables,
279+
used_variables=used_variables,
280+
path=self.config_path,
281+
)
282+
283+
return EnvironmentStatements(
284+
before_all=[
285+
d.jinja_statement(stmt).sql(dialect=dialect) for stmt in on_run_start or []
286+
],
287+
after_all=[
288+
d.jinja_statement(stmt).sql(dialect=dialect) for stmt in on_run_end or []
289+
],
290+
python_env=python_env,
291+
jinja_macros=jinja_macros,
292+
)
293+
return None
294+
233295
def _compute_yaml_max_mtime_per_subfolder(self, root: Path) -> t.Dict[Path, float]:
234296
if not root.is_dir():
235297
return {}

sqlmesh/dbt/manifest.py

+8
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ def __init__(
9494
self.project_path / c.CACHE, "jinja_calls"
9595
)
9696

97+
self._on_run_start: t.Optional[t.List[str]] = None
98+
self._on_run_end: t.Optional[t.List[str]] = None
99+
97100
def tests(self, package_name: t.Optional[str] = None) -> TestConfigs:
98101
self._load_all()
99102
return self._tests_per_package[package_name or self._project_name]
@@ -312,6 +315,11 @@ def _load_manifest(self) -> Manifest:
312315

313316
runtime_config = RuntimeConfig.from_parts(project, profile, args)
314317

318+
if runtime_config.on_run_start:
319+
self._on_run_start = runtime_config.on_run_start
320+
if runtime_config.on_run_end:
321+
self._on_run_end = runtime_config.on_run_end
322+
315323
self._project_name = project.project_name
316324

317325
if DBT_VERSION >= (1, 8):

tests/core/test_context.py

+19
Original file line numberDiff line numberDiff line change
@@ -1417,6 +1417,7 @@ def test_environment_statements(tmp_path: pathlib.Path):
14171417
after_all=[
14181418
"@grant_schema_usage()",
14191419
"@grant_select_privileges()",
1420+
"@grant_usage_role(@schemas, 'admin')",
14201421
],
14211422
)
14221423

@@ -1481,6 +1482,22 @@ def grant_schema_usage(evaluator):
14811482
""",
14821483
)
14831484

1485+
create_temp_file(
1486+
tmp_path,
1487+
pathlib.Path(macros_dir, "grant_usage_file.py"),
1488+
"""
1489+
from sqlmesh import macro
1490+
1491+
@macro()
1492+
def grant_usage_role(evaluator, schemas, role):
1493+
if evaluator._environment_naming_info:
1494+
return [
1495+
f"GRANT USAGE ON SCHEMA {schema} TO {role};"
1496+
for schema in schemas
1497+
]
1498+
""",
1499+
)
1500+
14841501
context = Context(paths=tmp_path, config=config)
14851502
snapshots = {s.name: s for s in context.snapshots.values()}
14861503

@@ -1515,6 +1532,7 @@ def grant_schema_usage(evaluator):
15151532
assert after_all_rendered == [
15161533
"GRANT USAGE ON SCHEMA db TO user_role",
15171534
"GRANT SELECT ON VIEW memory.db.test_after_model TO ROLE admin_role",
1535+
'GRANT USAGE ON SCHEMA "db" TO "admin"',
15181536
]
15191537

15201538
after_all_rendered_dev = render_statements(
@@ -1529,6 +1547,7 @@ def grant_schema_usage(evaluator):
15291547
assert after_all_rendered_dev == [
15301548
"GRANT USAGE ON SCHEMA db__dev TO user_role",
15311549
"GRANT SELECT ON VIEW memory.db__dev.test_after_model TO ROLE admin_role",
1550+
'GRANT USAGE ON SCHEMA "db__dev" TO "admin"',
15321551
]
15331552

15341553

tests/core/test_integration.py

+18
Original file line numberDiff line numberDiff line change
@@ -4502,6 +4502,24 @@ def test_multi_dbt(mocker):
45024502
context.apply(plan)
45034503
validate_apply_basics(context, c.PROD, plan.snapshots.values())
45044504

4505+
environment_statements = context.state_sync.get_environment_statements(c.PROD)
4506+
assert len(environment_statements) == 2
4507+
bronze_statements = environment_statements[0]
4508+
assert bronze_statements.before_all == [
4509+
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;"
4510+
]
4511+
assert not bronze_statements.after_all
4512+
silver_statements = environment_statements[1]
4513+
assert not silver_statements.before_all
4514+
assert silver_statements.after_all == [
4515+
"JINJA_STATEMENT_BEGIN;\n{{ store_schemas(schemas) }}\nJINJA_END;"
4516+
]
4517+
assert "store_schemas" in silver_statements.jinja_macros.root_macros
4518+
analytics_table = context.fetchdf("select * from analytic_stats;")
4519+
assert sorted(analytics_table.columns) == sorted(["physical_table", "evaluation_time"])
4520+
schema_table = context.fetchdf("select * from schema_table;")
4521+
assert sorted(schema_table.all_schemas[0]) == sorted(["bronze", "silver"])
4522+
45054523

45064524
def test_multi_hybrid(mocker):
45074525
context = Context(

tests/dbt/test_adapter.py

+48
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313

1414
from sqlmesh import Context
1515
from sqlmesh.core.dialect import schema_
16+
from sqlmesh.core.environment import EnvironmentNamingInfo
17+
from sqlmesh.core.macros import RuntimeStage
18+
from sqlmesh.core.renderer import render_statements
1619
from sqlmesh.core.snapshot import SnapshotId
1720
from sqlmesh.dbt.adapter import ParsetimeAdapter
1821
from sqlmesh.dbt.project import Project
@@ -270,3 +273,48 @@ def test_quote_as_configured():
270273
adapter.quote_as_configured("foo", "identifier") == '"foo"'
271274
adapter.quote_as_configured("foo", "schema") == "foo"
272275
adapter.quote_as_configured("foo", "database") == "foo"
276+
277+
278+
def test_on_run_start_end(copy_to_temp_path):
279+
project_root = "tests/fixtures/dbt/sushi_test"
280+
sushi_context = Context(paths=copy_to_temp_path(project_root))
281+
assert len(sushi_context._environment_statements) == 1
282+
environment_statements = sushi_context._environment_statements[0]
283+
284+
assert environment_statements.before_all == [
285+
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;"
286+
]
287+
assert environment_statements.after_all == [
288+
"JINJA_STATEMENT_BEGIN;\n{{ create_tables(schemas) }}\nJINJA_END;"
289+
]
290+
assert "create_tables" in environment_statements.jinja_macros.root_macros
291+
292+
rendered_before_all = render_statements(
293+
environment_statements.before_all,
294+
dialect=sushi_context.default_dialect,
295+
python_env=environment_statements.python_env,
296+
jinja_macros=environment_statements.jinja_macros,
297+
runtime_stage=RuntimeStage.BEFORE_ALL,
298+
)
299+
300+
rendered_after_all = render_statements(
301+
environment_statements.after_all,
302+
dialect=sushi_context.default_dialect,
303+
python_env=environment_statements.python_env,
304+
jinja_macros=environment_statements.jinja_macros,
305+
snapshots=sushi_context.snapshots,
306+
runtime_stage=RuntimeStage.AFTER_ALL,
307+
environment_naming_info=EnvironmentNamingInfo(name="dev"),
308+
)
309+
310+
assert rendered_before_all == [
311+
"CREATE TABLE IF NOT EXISTS analytic_stats (physical_table TEXT, evaluation_time TEXT)"
312+
]
313+
314+
# The jinja macro should have resolved the schemas for this environment and generated corresponding statements
315+
assert sorted(rendered_after_all) == sorted(
316+
[
317+
"CREATE OR REPLACE TABLE schema_table_snapshots__dev AS SELECT 'snapshots__dev' AS schema",
318+
"CREATE OR REPLACE TABLE schema_table_sushi__dev AS SELECT 'sushi__dev' AS schema",
319+
]
320+
)

tests/dbt/test_transformation.py

+10
Original file line numberDiff line numberDiff line change
@@ -997,6 +997,16 @@ def test_dbt_version(sushi_test_project: Project):
997997
assert context.render("{{ dbt_version }}").startswith("1.")
998998

999999

1000+
@pytest.mark.xdist_group("dbt_manifest")
1001+
def test_dbt_on_run_start_end(sushi_test_project: Project):
1002+
context = sushi_test_project.context
1003+
assert context._manifest
1004+
assert context._manifest._on_run_start == [
1005+
"CREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);"
1006+
]
1007+
assert context._manifest._on_run_end == ["{{ create_tables(schemas) }}"]
1008+
1009+
10001010
@pytest.mark.xdist_group("dbt_manifest")
10011011
def test_parsetime_adapter_call(
10021012
assert_exp_eq, sushi_test_project: Project, sushi_test_dbt_context: Context

0 commit comments

Comments
 (0)