Skip to content

Commit c02b8e4

Browse files
authored
Merge pull request #92 from camuthig/fix-empty-result-set
Handle empty result sets in CTEs
2 parents 744d745 + bd6fc3f commit c02b8e4

File tree

3 files changed

+76
-3
lines changed

3 files changed

+76
-3
lines changed

django_cte/query.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,19 @@
22
from __future__ import unicode_literals
33

44
import django
5+
from django.core.exceptions import EmptyResultSet
56
from django.db import connections
67
from django.db.models.sql import DeleteQuery, Query, UpdateQuery
78
from django.db.models.sql.compiler import (
89
SQLCompiler,
910
SQLDeleteCompiler,
1011
SQLUpdateCompiler,
1112
)
13+
from django.db.models.sql.constants import LOUTER
14+
from django.db.models.sql.where import ExtraWhere, WhereNode
1215

1316
from .expressions import CTESubqueryResolver
17+
from .join import QJoin
1418

1519

1620
class CTEQuery(Query):
@@ -75,9 +79,40 @@ def generate_sql(cls, connection, query, as_sql):
7579
for cte in query._with_ctes:
7680
if django.VERSION > (4, 2):
7781
_ignore_with_col_aliases(cte.query)
78-
compiler = cte.query.get_compiler(connection=connection)
82+
83+
alias = query.alias_map.get(cte.name)
84+
should_elide_empty = (
85+
not isinstance(alias, QJoin) or alias.join_type != LOUTER
86+
)
87+
88+
if django.VERSION >= (4, 0):
89+
compiler = cte.query.get_compiler(
90+
connection=connection, elide_empty=should_elide_empty
91+
)
92+
else:
93+
compiler = cte.query.get_compiler(connection=connection)
94+
7995
qn = compiler.quote_name_unless_alias
80-
cte_sql, cte_params = compiler.as_sql()
96+
try:
97+
cte_sql, cte_params = compiler.as_sql()
98+
except EmptyResultSet:
99+
if django.VERSION < (4, 0) and not should_elide_empty:
100+
# elide_empty is not available prior to Django 4.0. The
101+
# below behavior emulates the logic of it, rebuilding
102+
# the CTE query with a WHERE clause that is always false
103+
# but that the SqlCompiler cannot optimize away. This is
104+
# only required for left outer joins, as standard inner
105+
# joins should be optimized and raise the EmptyResultSet
106+
query = cte.query.copy()
107+
query.where = WhereNode([ExtraWhere(["1 = 0"], [])])
108+
compiler = query.get_compiler(connection=connection)
109+
cte_sql, cte_params = compiler.as_sql()
110+
else:
111+
# If the CTE raises an EmptyResultSet the SqlCompiler still
112+
# needs to know the information about this base compiler
113+
# like, col_count and klass_info.
114+
as_sql()
115+
raise
81116
template = cls.get_cte_query_template(cte)
82117
ctes.append(template.format(name=qn(cte.name), query=cte_sql))
83118
params.extend(cte_params)

django_cte/raw.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def quote_name_unless_alias(self, name):
3232
class raw_cte_queryset(object):
3333
class query(object):
3434
@staticmethod
35-
def get_compiler(connection):
35+
def get_compiler(connection, *, elide_empty=None):
3636
return raw_cte_compiler(connection)
3737

3838
@staticmethod

tests/test_cte.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,3 +562,41 @@ def test_explain(self):
562562
)
563563

564564
self.assertIsInstance(orders.explain(), str)
565+
566+
def test_empty_result_set_cte(self):
567+
"""
568+
Verifies that the CTEQueryCompiler can handle empty result sets in the
569+
related CTEs
570+
"""
571+
totals = With(
572+
Order.objects
573+
.filter(id__in=[])
574+
.values("region_id")
575+
.annotate(total=Sum("amount")),
576+
name="totals",
577+
)
578+
orders = (
579+
totals.join(Order, region=totals.col.region_id)
580+
.with_cte(totals)
581+
.annotate(region_total=totals.col.total)
582+
.order_by("amount")
583+
)
584+
585+
self.assertEqual(len(orders), 0)
586+
587+
def test_left_outer_join_on_empty_result_set_cte(self):
588+
totals = With(
589+
Order.objects
590+
.filter(id__in=[])
591+
.values("region_id")
592+
.annotate(total=Sum("amount")),
593+
name="totals",
594+
)
595+
orders = (
596+
totals.join(Order, region=totals.col.region_id, _join_type=LOUTER)
597+
.with_cte(totals)
598+
.annotate(region_total=totals.col.total)
599+
.order_by("amount")
600+
)
601+
602+
self.assertEqual(len(orders), 22)

0 commit comments

Comments
 (0)