
Commit e100fbf

Fix: Batch intervals in DAG order to make sure that parent snapshot signals are respected by their children (#3956)
1 parent d745fc3 commit e100fbf

File tree

2 files changed: +112 -27 lines changed


sqlmesh/core/scheduler.py (+25 -26)
@@ -21,6 +21,7 @@
     earliest_start_date,
     missing_intervals,
     merge_intervals,
+    snapshots_to_dag,
     Intervals,
 )
 from sqlmesh.core.snapshot.definition import (
@@ -344,35 +345,26 @@ def run(
 
         return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS
 
-    def batch_intervals(
-        self,
-        merged_intervals: SnapshotToIntervals,
-        start: t.Optional[TimeLike] = None,
-        end: t.Optional[TimeLike] = None,
-        execution_time: t.Optional[TimeLike] = None,
-    ) -> t.Dict[Snapshot, Intervals]:
-        def expand_range_as_interval(
-            start_ts: int, end_ts: int, interval_unit: IntervalUnit
-        ) -> t.List[Interval]:
-            values = expand_range(start_ts, end_ts, interval_unit)
-            return [(values[i], values[i + 1]) for i in range(len(values) - 1)]
-
-        dag = DAG[str]()
-
-        for snapshot in merged_intervals:
-            dag.add(snapshot.name, [p.name for p in snapshot.parents])
-
-        snapshot_intervals = {
-            snapshot: [
-                i
-                for interval in intervals
-                for i in expand_range_as_interval(*interval, snapshot.node.interval_unit)
-            ]
+    def batch_intervals(self, merged_intervals: SnapshotToIntervals) -> t.Dict[Snapshot, Intervals]:
+        dag = snapshots_to_dag(merged_intervals)
+
+        snapshot_intervals: t.Dict[SnapshotId, t.Tuple[Snapshot, t.List[Interval]]] = {
+            snapshot.snapshot_id: (
+                snapshot,
+                [
+                    i
+                    for interval in intervals
+                    for i in _expand_range_as_interval(*interval, snapshot.node.interval_unit)
+                ],
+            )
             for snapshot, intervals in merged_intervals.items()
         }
         snapshot_batches = {}
         all_unready_intervals: t.Dict[str, set[Interval]] = {}
-        for snapshot, intervals in snapshot_intervals.items():
+        for snapshot_id in dag:
+            if snapshot_id not in snapshot_intervals:
+                continue
+            snapshot, intervals = snapshot_intervals[snapshot_id]
             unready = set(intervals)
             intervals = snapshot.check_ready_intervals(intervals)
             unready -= set(intervals)
@@ -429,7 +421,7 @@ def run_merged_intervals(
         """
         execution_time = execution_time or now_timestamp()
 
-        batched_intervals = self.batch_intervals(merged_intervals, start, end, execution_time)
+        batched_intervals = self.batch_intervals(merged_intervals)
 
         self.console.start_evaluation_progress(
             {snapshot: len(intervals) for snapshot, intervals in batched_intervals.items()},
@@ -686,3 +678,10 @@ def _resolve_one_snapshot_per_version(
                 snapshot_per_version[key] = snapshot
 
     return snapshot_per_version
+
+
+def _expand_range_as_interval(
+    start_ts: int, end_ts: int, interval_unit: IntervalUnit
+) -> t.List[Interval]:
+    values = expand_range(start_ts, end_ts, interval_unit)
+    return [(values[i], values[i + 1]) for i in range(len(values) - 1)]
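
The heart of the change is visible in the hunk above: batch_intervals now walks snapshots in topological (DAG) order via snapshots_to_dag, instead of iterating merged_intervals in whatever order its keys happen to be in, so a parent's signal check runs before any of its children are batched. The sketch below illustrates the ordering idea only; the helper names (topological_order, batch_in_dag_order, check_ready) are hypothetical stand-ins, not the SQLMesh API.

# Minimal sketch (not SQLMesh code): batch in DAG order so that intervals a
# parent's signal marks unready are also excluded from its children.
import typing as t

Interval = t.Tuple[int, int]


def topological_order(parents: t.Dict[str, t.Set[str]]) -> t.List[str]:
    # Depth-first walk that emits every parent before any of its children.
    ordered: t.List[str] = []
    seen: t.Set[str] = set()

    def visit(name: str) -> None:
        if name in seen:
            return
        seen.add(name)
        for parent in parents.get(name, set()):
            visit(parent)
        ordered.append(name)

    for name in parents:
        visit(name)
    return ordered


def batch_in_dag_order(
    intervals_by_name: t.Dict[str, t.List[Interval]],
    parents: t.Dict[str, t.Set[str]],
    check_ready: t.Callable[[str, t.List[Interval]], t.List[Interval]],
) -> t.Dict[str, t.List[Interval]]:
    unready_by_name: t.Dict[str, t.Set[Interval]] = {}
    batches: t.Dict[str, t.List[Interval]] = {}
    for name in topological_order(parents):
        if name not in intervals_by_name:
            continue
        intervals = intervals_by_name[name]
        ready = check_ready(name, intervals)
        # Anything the signal rejected, plus anything an upstream snapshot
        # rejected, is unavailable to this snapshot and its descendants.
        unready = set(intervals) - set(ready)
        for parent in parents.get(name, set()):
            unready |= unready_by_name.get(parent, set())
        unready_by_name[name] = unready
        batches[name] = [i for i in intervals if i not in unready]
    return batches


# Example: "child" depends on "parent"; the parent's signal readies only its
# first interval, so the child is also limited to that interval.
parents = {"parent": set(), "child": {"parent"}}
intervals = {"parent": [(1, 2), (2, 3)], "child": [(1, 2), (2, 3)]}
print(batch_in_dag_order(intervals, parents, lambda name, iv: iv[:1] if name == "parent" else iv))
# -> {'parent': [(1, 2)], 'child': [(1, 2)]}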

tests/core/test_scheduler.py (+87 -1)
@@ -76,7 +76,7 @@ def _get_batched_missing_intervals(
         execution_time: t.Optional[TimeLike] = None,
     ) -> SnapshotToIntervals:
         merged_intervals = scheduler.merged_missing_intervals(start, end, execution_time)
-        return scheduler.batch_intervals(merged_intervals, start, end, execution_time)
+        return scheduler.batch_intervals(merged_intervals)
 
     return _get_batched_missing_intervals
 
@@ -722,3 +722,89 @@ def signal_b(batch: DatetimeRanges):
         c: [],
         d: [],
     }
+
+
+def test_signals_snapshots_out_of_order(
+    mocker: MockerFixture, make_snapshot, get_batched_missing_intervals
+):
+    @signal()
+    def signal_base(batch: DatetimeRanges):
+        return [batch[0]]
+
+    signals = signal.get_registry()
+
+    snapshot_a = make_snapshot(
+        load_sql_based_model(
+            parse( # type: ignore
+                """
+                MODEL (
+                    name a,
+                    kind INCREMENTAL_BY_TIME_RANGE(
+                        lookback 1,
+                        time_column dt,
+                    ),
+                    start '2023-01-01',
+                    signals SIGNAL_BASE(),
+                );
+                SELECT @start_date AS dt;
+                """
+            ),
+            signal_definitions=signals,
+        ),
+    )
+
+    snapshot_b = make_snapshot(
+        load_sql_based_model(
+            parse( # type: ignore
+                """
+                MODEL (
+                    name b,
+                    kind INCREMENTAL_BY_TIME_RANGE(
+                        lookback 1,
+                        time_column dt,
+                    ),
+                    start '2023-01-01'
+                );
+                SELECT @start_date AS dt;
+                """
+            ),
+            signal_definitions=signals,
+        )
+    )
+
+    snapshot_c = make_snapshot(
+        load_sql_based_model(
+            parse( # type: ignore
+                """
+                MODEL (
+                    name c,
+                    kind INCREMENTAL_BY_TIME_RANGE(
+                        lookback 1,
+                        time_column dt,
+                    ),
+                    start '2023-01-01',
+                );
+                SELECT * FROM a UNION SELECT * FROM b
+                """
+            ),
+            signal_definitions=signals,
+        ),
+        nodes={snapshot_a.name: snapshot_a.model, snapshot_b.name: snapshot_b.model},
+    )
+
+    snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1)
+    scheduler = Scheduler(
+        snapshots=[snapshot_c, snapshot_b, snapshot_a], # reverse order
+        snapshot_evaluator=snapshot_evaluator,
+        state_sync=mocker.MagicMock(),
+        max_workers=2,
+        default_catalog=None,
+    )
+
+    batches = get_batched_missing_intervals(scheduler, "2023-01-01", "2023-01-03", None)
+
+    assert batches == {
+        snapshot_a: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))],
+        snapshot_b: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))],
+        snapshot_c: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))],
+    }
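
In this test the Scheduler receives the snapshots in reverse dependency order (c, b, a). Because signal_base returns only batch[0], model a is limited to the first day of the run (2023-01-01 to 2023-01-02); b has no signal, so its batch spans the full requested window through 2023-01-03 (exclusive end timestamp 2023-01-04); and c, which selects from both a and b, is batched only for the interval its parent a can actually provide. Before this fix, iterating in insertion order could have batched c before a's signal was consulted.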
