Skip to content

Commit e9fc9d2

Browse files
erindruizeigerman
authored andcommitted
Fix: Fix logging, grain resolution and source/target object names in table_diff (#3921)
1 parent 44e0180 commit e9fc9d2

File tree

4 files changed

+137
-15
lines changed

4 files changed

+137
-15
lines changed

sqlmesh/core/console.py

+46-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from sqlglot.dialects.dialect import DialectType
5050
from sqlmesh.core.context_diff import ContextDiff
5151
from sqlmesh.core.plan import Plan, EvaluatablePlan, PlanBuilder, SnapshotIntervals
52-
from sqlmesh.core.table_diff import RowDiff, SchemaDiff
52+
from sqlmesh.core.table_diff import TableDiff, RowDiff, SchemaDiff
5353

5454
LayoutWidget = t.TypeVar("LayoutWidget", bound=t.Union[widgets.VBox, widgets.HBox])
5555

@@ -290,6 +290,10 @@ def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID:
290290
def loading_stop(self, id: uuid.UUID) -> None:
291291
"""Stop loading for the given id."""
292292

293+
@abc.abstractmethod
294+
def show_table_diff_summary(self, table_diff: TableDiff) -> None:
295+
"""Display information about the tables being diffed and how they are being joined"""
296+
293297
@abc.abstractmethod
294298
def show_schema_diff(self, schema_diff: SchemaDiff) -> None:
295299
"""Show table schema diff."""
@@ -459,6 +463,9 @@ def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID:
459463
def loading_stop(self, id: uuid.UUID) -> None:
460464
pass
461465

466+
def show_table_diff_summary(self, table_diff: TableDiff) -> None:
467+
pass
468+
462469
def show_schema_diff(self, schema_diff: SchemaDiff) -> None:
463470
pass
464471

@@ -1274,6 +1281,44 @@ def loading_stop(self, id: uuid.UUID) -> None:
12741281
self.loading_status[id].stop()
12751282
del self.loading_status[id]
12761283

1284+
def show_table_diff_summary(self, table_diff: TableDiff) -> None:
1285+
tree = Tree("\n[b]Table Diff")
1286+
1287+
if table_diff.model_name:
1288+
model = Tree("Model:")
1289+
model.add(f"[blue]{table_diff.model_name}[/blue]")
1290+
1291+
tree.add(model)
1292+
1293+
envs = Tree("Environment:")
1294+
source = Tree(
1295+
f"Source: [{self.TABLE_DIFF_SOURCE_BLUE}]{table_diff.source_alias}[/{self.TABLE_DIFF_SOURCE_BLUE}]"
1296+
)
1297+
envs.add(source)
1298+
1299+
target = Tree(f"Target: [green]{table_diff.target_alias}[/green]")
1300+
envs.add(target)
1301+
1302+
tree.add(envs)
1303+
1304+
tables = Tree("Tables:")
1305+
1306+
tables.add(
1307+
f"Source: [{self.TABLE_DIFF_SOURCE_BLUE}]{table_diff.source}[/{self.TABLE_DIFF_SOURCE_BLUE}]"
1308+
)
1309+
tables.add(f"Target: [green]{table_diff.target}[/green]")
1310+
1311+
tree.add(tables)
1312+
1313+
join = Tree("Join On:")
1314+
_, _, key_column_names = table_diff.key_columns
1315+
for col_name in key_column_names:
1316+
join.add(f"[yellow]{col_name}[/yellow]")
1317+
1318+
tree.add(join)
1319+
1320+
self._print(tree)
1321+
12771322
def show_schema_diff(self, schema_diff: SchemaDiff) -> None:
12781323
source_name = schema_diff.source
12791324
if schema_diff.source_alias:

sqlmesh/core/context.py

+17-12
Original file line numberDiff line numberDiff line change
@@ -1518,33 +1518,37 @@ def table_diff(
15181518
if not target_env:
15191519
raise SQLMeshError(f"Could not find environment '{target}')")
15201520

1521+
# Compare the virtual layer instead of the physical layer because the virtual layer is guaranteed to point
1522+
# to the correct/active snapshot for the model in the specified environment, taking into account things like dev previews
15211523
source = next(
15221524
snapshot for snapshot in source_env.snapshots if snapshot.name == model.fqn
1523-
).table_name()
1525+
).qualified_view_name.for_environment(source_env.naming_info, adapter.dialect)
1526+
15241527
target = next(
15251528
snapshot for snapshot in target_env.snapshots if snapshot.name == model.fqn
1526-
).table_name()
1529+
).qualified_view_name.for_environment(target_env.naming_info, adapter.dialect)
1530+
15271531
source_alias = source_env.name
15281532
target_alias = target_env.name
15291533

15301534
if not on:
1531-
for ref in model.all_references:
1532-
if ref.unique:
1533-
expr = ref.expression
1534-
1535-
if isinstance(expr, exp.Tuple):
1536-
on = [key.this.sql() for key in expr.expressions]
1537-
else:
1538-
# Handle a single Column or Paren expression
1539-
on = [expr.this.sql()]
1535+
on = []
1536+
for expr in [ref.expression for ref in model.all_references if ref.unique]:
1537+
if isinstance(expr, exp.Tuple):
1538+
on.extend(
1539+
[key.this.sql(dialect=adapter.dialect) for key in expr.expressions]
1540+
)
1541+
else:
1542+
# Handle a single Column or Paren expression
1543+
on.append(expr.this.sql(dialect=adapter.dialect))
15401544

15411545
if not on:
15421546
raise SQLMeshError(
15431547
"SQLMesh doesn't know how to join the two tables. Specify the `grains` in each model definition or pass join column names in separate `-o` flags."
15441548
)
15451549

15461550
table_diff = TableDiff(
1547-
adapter=adapter,
1551+
adapter=adapter.with_log_level(logger.getEffectiveLevel()),
15481552
source=source,
15491553
target=target,
15501554
on=on,
@@ -1558,6 +1562,7 @@ def table_diff(
15581562
decimals=decimals,
15591563
)
15601564
if show:
1565+
self.console.show_table_diff_summary(table_diff)
15611566
self.console.show_schema_diff(table_diff.schema_diff())
15621567
self.console.show_row_diff(
15631568
table_diff.row_diff(temp_schema=temp_schema, skip_grain_check=skip_grain_check),

tests/core/test_table_diff.py

+62-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
@pytest.mark.slow
13-
def test_data_diff(sushi_context_fixed_date):
13+
def test_data_diff(sushi_context_fixed_date, capsys, caplog):
1414
model = sushi_context_fixed_date.models['"memory"."sushi"."customer_revenue_by_day"']
1515

1616
model.query.select(exp.cast("'1'", "VARCHAR").as_("modified_col"), "1 AS y", copy=False)
@@ -74,6 +74,11 @@ def test_data_diff(sushi_context_fixed_date):
7474
model_or_snapshot="sushi.customer_revenue_by_day",
7575
)
7676

77+
# verify queries were actually logged to the log file, this helps immensely with debugging
78+
console_output = capsys.readouterr()
79+
assert "__sqlmesh_join_key" not in console_output # they should not go to the console
80+
assert "__sqlmesh_join_key" in caplog.text
81+
7782
schema_diff = diff.schema_diff()
7883
assert schema_diff.added == [("z", exp.DataType.build("int"))]
7984
assert schema_diff.modified == {
@@ -272,6 +277,12 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture)
272277
sample_query_sql = 'SELECT "s_exists", "t_exists", "row_joined", "row_full_match", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "key_matches" = 0 OR "value_matches" = 0 ORDER BY "s__key" NULLS FIRST, "t__key" NULLS FIRST LIMIT 20'
273278
drop_sql = 'DROP TABLE IF EXISTS "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh"'
274279

280+
# make with_log_level() return the current instance of engine_adapter so we can still spy on _execute
281+
mocker.patch.object(
282+
engine_adapter, "with_log_level", new_callable=lambda: lambda _: engine_adapter
283+
)
284+
assert engine_adapter.with_log_level(1) == engine_adapter
285+
275286
spy_execute = mocker.spy(engine_adapter, "_execute")
276287
mocker.patch("sqlmesh.core.engine_adapter.base.random_id", return_value="abcdefgh")
277288

@@ -302,3 +313,53 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture)
302313

303314
query_sql_where = 'CREATE TABLE IF NOT EXISTS "memory"."sqlmesh_temp"."__temp_diff_abcdefgh" AS WITH "__source" AS (SELECT "key", "value", "key" AS "__sqlmesh_join_key" FROM "table_diff_source" WHERE "key" = 2), "__target" AS (SELECT "key", "value", "key" AS "__sqlmesh_join_key" FROM "table_diff_target" WHERE "key" = 2), "__stats" AS (SELECT "s"."key" AS "s__key", "s"."value" AS "s__value", "s"."__sqlmesh_join_key" AS "s____sqlmesh_join_key", "t"."key" AS "t__key", "t"."value" AS "t__value", "t"."__sqlmesh_join_key" AS "t____sqlmesh_join_key", CASE WHEN NOT "s"."key" IS NULL THEN 1 ELSE 0 END AS "s_exists", CASE WHEN NOT "t"."key" IS NULL THEN 1 ELSE 0 END AS "t_exists", CASE WHEN "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key" AND (NOT "s"."key" IS NULL AND NOT "t"."key" IS NULL) THEN 1 ELSE 0 END AS "row_joined", CASE WHEN "s"."key" IS NULL AND "t"."key" IS NULL THEN 1 ELSE 0 END AS "null_grain", CASE WHEN "s"."key" = "t"."key" THEN 1 WHEN ("s"."key" IS NULL) AND ("t"."key" IS NULL) THEN 1 WHEN ("s"."key" IS NULL) OR ("t"."key" IS NULL) THEN 0 ELSE 0 END AS "key_matches", CASE WHEN ROUND("s"."value", 3) = ROUND("t"."value", 3) THEN 1 WHEN ("s"."value" IS NULL) AND ("t"."value" IS NULL) THEN 1 WHEN ("s"."value" IS NULL) OR ("t"."value" IS NULL) THEN 0 ELSE 0 END AS "value_matches" FROM "__source" AS "s" FULL JOIN "__target" AS "t" ON "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key") SELECT *, CASE WHEN "key_matches" = 1 AND "value_matches" = 1 THEN 1 ELSE 0 END AS "row_full_match" FROM "__stats"'
304315
spy_execute.assert_any_call(query_sql_where)
316+
317+
318+
@pytest.mark.slow
319+
def test_tables_and_grain_inferred_from_model(sushi_context_fixed_date: Context):
320+
(sushi_context_fixed_date.path / "models" / "waiter_revenue_by_day.sql").write_text("""
321+
MODEL (
322+
name sushi.waiter_revenue_by_day,
323+
kind incremental_by_time_range (
324+
time_column event_date,
325+
batch_size 10,
326+
),
327+
owner jen,
328+
cron '@daily',
329+
audits (
330+
NUMBER_OF_ROWS(threshold := 0)
331+
),
332+
grain (waiter_id, event_date)
333+
);
334+
335+
SELECT
336+
o.waiter_id::INT + 1 AS waiter_id, /* Waiter id */
337+
SUM(oi.quantity * i.price)::DOUBLE AS revenue, /* Revenue from orders taken by this waiter */
338+
o.event_date::DATE AS event_date /* Date */
339+
FROM sushi.orders AS o
340+
LEFT JOIN sushi.order_items AS oi
341+
ON o.id = oi.order_id AND o.event_date = oi.event_date
342+
LEFT JOIN sushi.items AS i
343+
ON oi.item_id = i.id AND oi.event_date = i.event_date
344+
WHERE
345+
o.event_date BETWEEN @start_date AND @end_date
346+
GROUP BY
347+
o.waiter_id,
348+
o.event_date
349+
""")
350+
# this creates a dev preview of "sushi.waiter_revenue_by_day"
351+
sushi_context_fixed_date.refresh()
352+
sushi_context_fixed_date.auto_categorize_changes = CategorizerConfig(
353+
sql=AutoCategorizationMode.FULL
354+
)
355+
sushi_context_fixed_date.plan(environment="unit_test", auto_apply=True, include_unmodified=True)
356+
357+
table_diff = sushi_context_fixed_date.table_diff(
358+
source="unit_test", target="prod", model_or_snapshot="sushi.waiter_revenue_by_day"
359+
)
360+
361+
assert table_diff.source == "memory.sushi__unit_test.waiter_revenue_by_day"
362+
assert table_diff.target == "memory.sushi.waiter_revenue_by_day"
363+
364+
_, _, col_names = table_diff.key_columns
365+
assert col_names == ["waiter_id", "event_date"]

tests/integrations/jupyter/test_magics.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -650,8 +650,19 @@ def test_table_diff(notebook, loaded_sushi_context, convert_all_html_output_to_t
650650

651651
assert not output.stdout
652652
assert not output.stderr
653-
assert len(output.outputs) == 4
653+
assert len(output.outputs) == 5
654654
assert convert_all_html_output_to_text(output) == [
655+
"""Table Diff
656+
├── Model:
657+
│ └── sushi.top_waiters
658+
├── Environment:
659+
│ ├── Source: dev
660+
│ └── Target: prod
661+
├── Tables:
662+
│ ├── Source: memory.sushi__dev.top_waiters
663+
│ └── Target: memory.sushi.top_waiters
664+
└── Join On:
665+
└── waiter_id""",
655666
"""Schema Diff Between 'DEV' and 'PROD' environments for model 'sushi.top_waiters':
656667
└── Schemas match""",
657668
"""Row Counts:

0 commit comments

Comments
 (0)