Skip to content

Commit 3c3e0ad

Browse files
committed
Fix: Batch intervals in DAG order to make sure that parent snapshot signals are respected by their children (#3956)
1 parent 3cd933a commit 3c3e0ad

File tree

2 files changed

+112
-27
lines changed

2 files changed

+112
-27
lines changed

sqlmesh/core/scheduler.py

+25-26
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
earliest_start_date,
2121
missing_intervals,
2222
merge_intervals,
23+
snapshots_to_dag,
2324
Intervals,
2425
)
2526
from sqlmesh.core.snapshot.definition import (
@@ -341,35 +342,26 @@ def run(
341342

342343
return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS
343344

344-
def batch_intervals(
345-
self,
346-
merged_intervals: SnapshotToIntervals,
347-
start: t.Optional[TimeLike] = None,
348-
end: t.Optional[TimeLike] = None,
349-
execution_time: t.Optional[TimeLike] = None,
350-
) -> t.Dict[Snapshot, Intervals]:
351-
def expand_range_as_interval(
352-
start_ts: int, end_ts: int, interval_unit: IntervalUnit
353-
) -> t.List[Interval]:
354-
values = expand_range(start_ts, end_ts, interval_unit)
355-
return [(values[i], values[i + 1]) for i in range(len(values) - 1)]
356-
357-
dag = DAG[str]()
358-
359-
for snapshot in merged_intervals:
360-
dag.add(snapshot.name, [p.name for p in snapshot.parents])
361-
362-
snapshot_intervals = {
363-
snapshot: [
364-
i
365-
for interval in intervals
366-
for i in expand_range_as_interval(*interval, snapshot.node.interval_unit)
367-
]
345+
def batch_intervals(self, merged_intervals: SnapshotToIntervals) -> t.Dict[Snapshot, Intervals]:
346+
dag = snapshots_to_dag(merged_intervals)
347+
348+
snapshot_intervals: t.Dict[SnapshotId, t.Tuple[Snapshot, t.List[Interval]]] = {
349+
snapshot.snapshot_id: (
350+
snapshot,
351+
[
352+
i
353+
for interval in intervals
354+
for i in _expand_range_as_interval(*interval, snapshot.node.interval_unit)
355+
],
356+
)
368357
for snapshot, intervals in merged_intervals.items()
369358
}
370359
snapshot_batches = {}
371360
all_unready_intervals: t.Dict[str, set[Interval]] = {}
372-
for snapshot, intervals in snapshot_intervals.items():
361+
for snapshot_id in dag:
362+
if snapshot_id not in snapshot_intervals:
363+
continue
364+
snapshot, intervals = snapshot_intervals[snapshot_id]
373365
unready = set(intervals)
374366
intervals = snapshot.check_ready_intervals(intervals)
375367
unready -= set(intervals)
@@ -425,7 +417,7 @@ def run_merged_intervals(
425417
"""
426418
execution_time = execution_time or now_timestamp()
427419

428-
batched_intervals = self.batch_intervals(merged_intervals, start, end, execution_time)
420+
batched_intervals = self.batch_intervals(merged_intervals)
429421

430422
self.console.start_evaluation_progress(
431423
{snapshot: len(intervals) for snapshot, intervals in batched_intervals.items()},
@@ -646,3 +638,10 @@ def _resolve_one_snapshot_per_version(
646638
snapshot_per_version[key] = snapshot
647639

648640
return snapshot_per_version
641+
642+
643+
def _expand_range_as_interval(
644+
start_ts: int, end_ts: int, interval_unit: IntervalUnit
645+
) -> t.List[Interval]:
646+
values = expand_range(start_ts, end_ts, interval_unit)
647+
return [(values[i], values[i + 1]) for i in range(len(values) - 1)]

tests/core/test_scheduler.py

+87-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def _get_batched_missing_intervals(
7676
execution_time: t.Optional[TimeLike] = None,
7777
) -> SnapshotToIntervals:
7878
merged_intervals = scheduler.merged_missing_intervals(start, end, execution_time)
79-
return scheduler.batch_intervals(merged_intervals, start, end, execution_time)
79+
return scheduler.batch_intervals(merged_intervals)
8080

8181
return _get_batched_missing_intervals
8282

@@ -722,3 +722,89 @@ def signal_b(batch: DatetimeRanges):
722722
c: [],
723723
d: [],
724724
}
725+
726+
727+
def test_signals_snapshots_out_of_order(
728+
mocker: MockerFixture, make_snapshot, get_batched_missing_intervals
729+
):
730+
@signal()
731+
def signal_base(batch: DatetimeRanges):
732+
return [batch[0]]
733+
734+
signals = signal.get_registry()
735+
736+
snapshot_a = make_snapshot(
737+
load_sql_based_model(
738+
parse( # type: ignore
739+
"""
740+
MODEL (
741+
name a,
742+
kind INCREMENTAL_BY_TIME_RANGE(
743+
lookback 1,
744+
time_column dt,
745+
),
746+
start '2023-01-01',
747+
signals SIGNAL_BASE(),
748+
);
749+
SELECT @start_date AS dt;
750+
"""
751+
),
752+
signal_definitions=signals,
753+
),
754+
)
755+
756+
snapshot_b = make_snapshot(
757+
load_sql_based_model(
758+
parse( # type: ignore
759+
"""
760+
MODEL (
761+
name b,
762+
kind INCREMENTAL_BY_TIME_RANGE(
763+
lookback 1,
764+
time_column dt,
765+
),
766+
start '2023-01-01'
767+
);
768+
SELECT @start_date AS dt;
769+
"""
770+
),
771+
signal_definitions=signals,
772+
)
773+
)
774+
775+
snapshot_c = make_snapshot(
776+
load_sql_based_model(
777+
parse( # type: ignore
778+
"""
779+
MODEL (
780+
name c,
781+
kind INCREMENTAL_BY_TIME_RANGE(
782+
lookback 1,
783+
time_column dt,
784+
),
785+
start '2023-01-01',
786+
);
787+
SELECT * FROM a UNION SELECT * FROM b
788+
"""
789+
),
790+
signal_definitions=signals,
791+
),
792+
nodes={snapshot_a.name: snapshot_a.model, snapshot_b.name: snapshot_b.model},
793+
)
794+
795+
snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1)
796+
scheduler = Scheduler(
797+
snapshots=[snapshot_c, snapshot_b, snapshot_a], # reverse order
798+
snapshot_evaluator=snapshot_evaluator,
799+
state_sync=mocker.MagicMock(),
800+
max_workers=2,
801+
default_catalog=None,
802+
)
803+
804+
batches = get_batched_missing_intervals(scheduler, "2023-01-01", "2023-01-03", None)
805+
806+
assert batches == {
807+
snapshot_a: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))],
808+
snapshot_b: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))],
809+
snapshot_c: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))],
810+
}

0 commit comments

Comments
 (0)