Skip to content

Commit cd79b03

Browse files
authored
fix: make sure to migrate previous versions (#683)
* fix: make sure to migrate previous versions * format
1 parent 3bbf353 commit cd79b03

File tree

2 files changed

+68
-18
lines changed

2 files changed

+68
-18
lines changed

sqlmesh/core/snapshot/definition.py

+3
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ class SnapshotDataVersion(PydanticModel, frozen=True):
8989
version: str
9090
change_category: t.Optional[SnapshotChangeCategory]
9191

92+
def snapshot_id(self, name: str) -> SnapshotId:
93+
return SnapshotId(name=name, identifier=self.fingerprint.to_identifier())
94+
9295
@property
9396
def data_version(self) -> SnapshotDataVersion:
9497
return self

sqlmesh/core/state_sync/engine_adapter.py

+65-18
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,26 @@
1919
import json
2020
import logging
2121
import typing as t
22+
from copy import deepcopy
2223

2324
from sqlglot import __version__ as SQLGLOT_VERSION
2425
from sqlglot import exp
2526

27+
from sqlmesh.core.audit import Audit
2628
from sqlmesh.core.dialect import select_from_values
2729
from sqlmesh.core.engine_adapter import EngineAdapter, TransactionType
2830
from sqlmesh.core.environment import Environment
31+
from sqlmesh.core.model import Model
2932
from sqlmesh.core.snapshot import (
3033
Snapshot,
34+
SnapshotDataVersion,
35+
SnapshotFingerprint,
3136
SnapshotId,
3237
SnapshotIdLike,
3338
SnapshotNameVersionLike,
39+
fingerprint_from_model,
3440
)
41+
from sqlmesh.core.snapshot.definition import _parents_from_model
3542
from sqlmesh.core.state_sync.base import SCHEMA_VERSION, StateSync, Versions
3643
from sqlmesh.core.state_sync.common import CommonStateSyncMixin, transactional
3744
from sqlmesh.utils.date import now_timestamp
@@ -404,10 +411,10 @@ def _migrate_rows(self) -> None:
404411
for snapshot in all_snapshots.values():
405412
seen = set()
406413
queue = {snapshot.snapshot_id}
407-
env: t.Dict[str, t.Dict] = {
408-
"models": {},
409-
"audits": {},
410-
}
414+
model = snapshot.model
415+
models: t.Dict[str, Model] = {}
416+
audits: t.Dict[str, Audit] = {}
417+
env: t.Dict[str, t.Dict] = {"models": models, "audits": audits}
411418

412419
while queue:
413420
snapshot_id = queue.pop()
@@ -426,37 +433,77 @@ def _migrate_rows(self) -> None:
426433
cached_env = cache.get(snapshot_id)
427434

428435
if cached_env:
429-
env["models"].update(cached_env["models"])
430-
env["audits"].update(cached_env["audits"])
436+
models.update(cached_env["models"])
437+
audits.update(cached_env["audits"])
431438
else:
432-
env["models"][s.name] = s.model
439+
models[s.name] = s.model
433440

434441
for audit in s.audits:
435-
env["audits"][audit.name] = audit
442+
audits[audit.name] = audit
436443

437444
cache[snapshot_id] = env
438445

439-
new_snapshot = Snapshot.from_model(
440-
snapshot.model,
446+
new_snapshot = deepcopy(snapshot)
447+
448+
fingerprint_cache: t.Dict[str, SnapshotFingerprint] = {}
449+
450+
new_snapshot.fingerprint = fingerprint_from_model(
451+
model,
441452
physical_schema=snapshot.physical_schema,
442-
models=env["models"],
443-
ttl=snapshot.ttl,
444-
version=snapshot.version,
445-
audits=env["audits"],
453+
models=models,
454+
audits=audits,
446455
)
447456

448-
if new_snapshot == snapshot or new_snapshot in all_snapshots:
449-
logger.debug(f"{snapshot.snapshot_id} is unchaged")
457+
new_snapshot.parents = tuple(
458+
SnapshotId(
459+
name=name,
460+
identifier=fingerprint_from_model(
461+
models[name],
462+
physical_schema=snapshot.physical_schema,
463+
models=models,
464+
audits=audits,
465+
cache=fingerprint_cache,
466+
).to_identifier(),
467+
)
468+
for name in _parents_from_model(model, models)
469+
)
470+
471+
if new_snapshot == snapshot:
472+
logger.debug(f"{new_snapshot.snapshot_id} is unchanged.")
473+
continue
474+
if new_snapshot.snapshot_id in all_snapshots:
475+
logger.debug(f"{new_snapshot.snapshot_id} exists.")
450476
continue
451477

452-
new_snapshot.merge_intervals(snapshot)
453478
snapshot_mapping[snapshot.snapshot_id] = new_snapshot
454-
logger.debug(f"{snapshot.snapshot_id} mapped to {new_snapshot.snapshot_id}")
479+
logger.debug(f"{snapshot.snapshot_id} mapped to {new_snapshot.snapshot_id}.")
455480

456481
if not snapshot_mapping:
457482
logger.debug("No changes to snapshots detected.")
458483
return
459484

485+
def map_data_versions(
486+
name: str, versions: t.Sequence[SnapshotDataVersion]
487+
) -> t.Tuple[SnapshotDataVersion, ...]:
488+
version_ids = ((version.snapshot_id(name), version) for version in versions)
489+
490+
return tuple(
491+
snapshot_mapping[version_id].data_version
492+
if version_id in snapshot_mapping
493+
else version
494+
for version_id, version in version_ids
495+
)
496+
497+
for from_snapshot_id, to_snapshot in snapshot_mapping.items():
498+
from_snapshot = all_snapshots[from_snapshot_id]
499+
to_snapshot.previous_versions = map_data_versions(
500+
from_snapshot.name, from_snapshot.previous_versions
501+
)
502+
to_snapshot.indirect_versions = {
503+
name: map_data_versions(name, versions)
504+
for name, versions in from_snapshot.indirect_versions.items()
505+
}
506+
460507
self.delete_snapshots(snapshot_mapping)
461508
self._push_snapshots(snapshot_mapping.values())
462509

0 commit comments

Comments
 (0)