Skip to content

Commit 0bb482a

Browse files
Fix: Pass runtime_stage and engine_adapter in environment_statements (#4021)
1 parent 6d6b644 commit 0bb482a

File tree

3 files changed

+56
-7
lines changed

3 files changed

+56
-7
lines changed

sqlmesh/core/environment.py

+2
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,8 @@ def execute_environment_statements(
244244
end=end,
245245
execution_time=execution_time,
246246
environment_naming_info=environment_naming_info,
247+
runtime_stage=runtime_stage,
248+
engine_adapter=adapter,
247249
)
248250
]:
249251
with adapter.transaction():

tests/core/test_context.py

+40-5
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from sqlmesh.core.dialect import parse, schema_
3232
from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter
3333
from sqlmesh.core.environment import Environment, EnvironmentNamingInfo, EnvironmentStatements
34-
from sqlmesh.core.macros import MacroEvaluator
34+
from sqlmesh.core.macros import MacroEvaluator, RuntimeStage
3535
from sqlmesh.core.model import load_sql_based_model, model, SqlModel, Model
3636
from sqlmesh.core.model.cache import OptimizedQueryCache
3737
from sqlmesh.core.renderer import render_statements
@@ -1444,7 +1444,7 @@ def test_environment_statements(tmp_path: pathlib.Path):
14441444
14451445
@macro()
14461446
def grant_select_privileges(evaluator):
1447-
if evaluator._environment_naming_info:
1447+
if evaluator._environment_naming_info and evaluator.runtime_stage == 'before_all':
14481448
mapping = to_view_mapping(
14491449
evaluator._snapshots.values(), evaluator._environment_naming_info
14501450
)
@@ -1493,7 +1493,10 @@ def grant_schema_usage(evaluator):
14931493
assert isinstance(python_env["grant_select_privileges"], Executable)
14941494

14951495
before_all_rendered = render_statements(
1496-
statements=before_all, dialect=dialect, python_env=python_env
1496+
statements=before_all,
1497+
dialect=dialect,
1498+
python_env=python_env,
1499+
runtime_stage=RuntimeStage.BEFORE_ALL,
14971500
)
14981501

14991502
assert before_all_rendered == [
@@ -1506,6 +1509,7 @@ def grant_schema_usage(evaluator):
15061509
python_env=python_env,
15071510
snapshots=snapshots,
15081511
environment_naming_info=EnvironmentNamingInfo(name="prod"),
1512+
runtime_stage=RuntimeStage.BEFORE_ALL,
15091513
)
15101514

15111515
assert after_all_rendered == [
@@ -1519,6 +1523,7 @@ def grant_schema_usage(evaluator):
15191523
python_env=python_env,
15201524
snapshots=snapshots,
15211525
environment_naming_info=EnvironmentNamingInfo(name="dev"),
1526+
runtime_stage=RuntimeStage.BEFORE_ALL,
15221527
)
15231528

15241529
assert after_all_rendered_dev == [
@@ -1534,7 +1539,7 @@ def test_plan_environment_statements(tmp_path: pathlib.Path):
15341539

15351540
config = Config(
15361541
model_defaults=ModelDefaultsConfig(dialect=dialect),
1537-
before_all=["@create_stats_table()"],
1542+
before_all=["@create_stats_table()", "@access_adapter()"],
15381543
after_all=["CREATE TABLE IF NOT EXISTS after_table AS SELECT @some_var"],
15391544
variables={"some_var": 5},
15401545
)
@@ -1578,9 +1583,34 @@ def create_stats_table(evaluator):
15781583
""",
15791584
)
15801585

1586+
create_temp_file(
1587+
tmp_path,
1588+
pathlib.Path(macros_dir, "access_adapter.py"),
1589+
"""
1590+
from sqlmesh.core.macros import macro
1591+
1592+
@macro()
1593+
def access_adapter(evaluator):
1594+
if evaluator.runtime_stage == 'before_all':
1595+
engine_adapter = evaluator.engine_adapter
1596+
for i in range(10):
1597+
try:
1598+
sql_inside_macro = f"CREATE TABLE IF NOT EXISTS db_connect AS SELECT {i} as 'access_attempt'"
1599+
engine_adapter.execute(sql_inside_macro)
1600+
return None
1601+
except Exception as e:
1602+
sleep(10)
1603+
raise Exception(f"Failed to connect to the database")
1604+
""",
1605+
)
1606+
15811607
context = Context(paths=tmp_path, config=config)
15821608

1583-
assert context._environment_statements[0].before_all == ["@create_stats_table()"]
1609+
assert context._environment_statements[0].before_all == [
1610+
"@create_stats_table()",
1611+
"@access_adapter()",
1612+
]
1613+
15841614
assert context._environment_statements[0].after_all == [
15851615
"CREATE TABLE IF NOT EXISTS after_table AS SELECT @some_var"
15861616
]
@@ -1619,6 +1649,11 @@ def create_stats_table(evaluator):
16191649
assert state_table[0].after_all == context._environment_statements[0].after_all
16201650
assert state_table[0].python_env == context._environment_statements[0].python_env
16211651

1652+
# This table will be created inside the macro by accessing the engine_adapter directly
1653+
inside_macro_execute = context.fetchdf("select * from memory.db_connect").to_dict()
1654+
assert (attempt_column := inside_macro_execute.get("access_attempt"))
1655+
assert isinstance(attempt_column, dict) and attempt_column[0] < 10
1656+
16221657

16231658
def test_environment_statements_dialect(tmp_path: Path):
16241659
before_all = [

tests/core/test_integration.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -4924,9 +4924,13 @@ def test_plan_production_environment_statements(tmp_path: Path):
49244924
f.write(defn)
49254925

49264926
before_all = [
4927-
"CREATE TABLE IF NOT EXISTS schema_names_for_@this_env (physical_schema_name VARCHAR)"
4927+
"CREATE TABLE IF NOT EXISTS schema_names_for_@this_env (physical_schema_name VARCHAR)",
4928+
"@IF(@runtime_stage = 'before_all', CREATE TABLE IF NOT EXISTS should_create AS SELECT @runtime_stage)",
4929+
]
4930+
after_all = [
4931+
"@IF(@this_env = 'prod', CREATE TABLE IF NOT EXISTS after_t AS SELECT @var_5)",
4932+
"@IF(@runtime_stage = 'before_all', CREATE TABLE IF NOT EXISTS not_create AS SELECT @runtime_stage)",
49284933
]
4929-
after_all = ["@IF(@this_env = 'prod', CREATE TABLE IF NOT EXISTS after_t AS SELECT @var_5)"]
49304934
config = Config(
49314935
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
49324936
before_all=before_all,
@@ -4948,6 +4952,14 @@ def test_plan_production_environment_statements(tmp_path: Path):
49484952
assert environment_statements[0].python_env.keys() == {"__sqlmesh__vars__"}
49494953
assert environment_statements[0].python_env["__sqlmesh__vars__"].payload == "{'var_5': 5}"
49504954

4955+
should_create = ctx.fetchdf("select * from should_create").to_dict()
4956+
assert should_create["before_all"][0] == "before_all"
4957+
4958+
with pytest.raises(
4959+
Exception, match=r"Catalog Error: Table with name not_create does not exist!"
4960+
):
4961+
ctx.fetchdf("select * from not_create")
4962+
49514963

49524964
@time_machine.travel("2025-03-08 00:00:00 UTC")
49534965
def test_tz(init_and_plan_context):

0 commit comments

Comments
 (0)