Skip to content

Add support for combination two CTE queries #59

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 48 additions & 2 deletions django_cte/cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from __future__ import unicode_literals

from django.db.models import Manager
from django.db.models.query import Q, QuerySet, ValuesIterable
from django.db.models.query import (
Q, QuerySet, ValuesIterable, EmptyQuerySet)
from django.db.models.sql.datastructures import BaseTable

from .join import QJoin, INNER
Expand Down Expand Up @@ -86,7 +87,7 @@ def join(self, model_or_queryset, *filter_q, **filter_kw):
query.demote_joins(existing_inner)

parent = query.get_initial_alias()
query.join(QJoin(parent, self.name, self.name, on_clause, join_type))
query.join(QJoin(parent, self, self.name, on_clause, join_type))
return queryset

def queryset(self):
Expand Down Expand Up @@ -124,6 +125,11 @@ def _resolve_ref(self, name):
class CTEQuerySet(QuerySet):
"""QuerySet with support for Common Table Expressions"""

def _check_operator_queryset(self, other, operator_):
if self.query.combinator or other.query.combinator:
raise TypeError(
f"Cannot use {operator_} operator with combined queryset.")

def __init__(self, model=None, query=None, using=None, hints=None):
# Only create an instance of a Query if this is the first invocation in
# a query chain.
Expand All @@ -142,6 +148,46 @@ def with_cte(self, cte):
qs.query._with_ctes.append(cte)
return qs

def union(self, *other_qs, all=False):
# If the query is an EmptyQuerySet, combine all nonempty querysets.
if isinstance(self, EmptyQuerySet):
qs = [q for q in other_qs if not isinstance(q, EmptyQuerySet)]
if not qs:
return self
if len(qs) == 1:
return qs[0]
qs[0]._combinator_cte_query(*qs[1:])
return qs[0]._combinator_query("union", *qs[1:], all=all)
extra_qs = [q._chain() for q in other_qs]
qs = self._clone()
qs._combinator_cte_query(*extra_qs)
return qs._combinator_query("union", *extra_qs, all=all)

def intersection(self, *other_qs):
# If any query is an EmptyQuerySet, return it.
if isinstance(self, EmptyQuerySet):
return self
for other in other_qs:
if isinstance(other, EmptyQuerySet):
return other
other_qs = [q._chain() for q in other_qs]
qs = self._clone()
qs._combinator_cte_query(*other_qs)
return qs._combinator_query("intersection", *other_qs)

def difference(self, *other_qs):
# If the query is an EmptyQuerySet, return it.
if isinstance(self, EmptyQuerySet):
return self
other_qs = [q._chain() for q in other_qs]
qs = self._clone()
qs._combinator_cte_query(*other_qs)
return qs._combinator_query("difference", *other_qs)

def _combinator_cte_query(self, *other_qs):
for other in other_qs:
other.query = self.query.combine_cte(other.query)

def as_manager(cls):
# Address the circular dependency between
# `CTEQuerySet` and `CTEManager`.
Expand Down
23 changes: 19 additions & 4 deletions django_cte/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ class QJoin(object):

filtered_relation = None

def __init__(self, parent_alias, table_name, table_alias,
def __init__(self, parent_alias, cte, table_alias,
on_clause, join_type=INNER, nullable=None):
self._cte = cte
self.parent_alias = parent_alias
self.table_name = table_name
self.table_alias = table_alias
self._table_alias = table_alias
self.on_clause = on_clause
self.join_type = join_type # LOUTER or INNER
self.nullable = join_type != INNER if nullable is None else nullable
Expand All @@ -45,6 +45,21 @@ def __eq__(self, other):
def equals(self, other):
return self.identity == other.identity

@property
def table_alias(self):
return self._table_alias or self._cte.name

@table_alias.setter
def table_alias(self, value):
if value == self._cte.name:
self._table_alias = None
else:
self._table_alias = value

@property
def table_name(self):
return self._cte.name

def as_sql(self, compiler, connection):
"""Generate join clause SQL"""
on_clause_sql, params = self.on_clause.as_sql(compiler, connection)
Expand All @@ -64,7 +79,7 @@ def as_sql(self, compiler, connection):
def relabeled_clone(self, change_map):
return self.__class__(
parent_alias=change_map.get(self.parent_alias, self.parent_alias),
table_name=self.table_name,
cte=self._cte,
table_alias=change_map.get(self.table_alias, self.table_alias),
on_clause=self.on_clause.relabeled_clone(change_map),
join_type=self.join_type,
Expand Down
15 changes: 7 additions & 8 deletions django_cte/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class CTEColumn(Expression):

def __init__(self, cte, name, output_field=None):
self._cte = cte
self.table_alias = cte.name
self._table_alias = cte.name
self.name = self.alias = name
self._output_field = output_field

Expand All @@ -43,6 +43,12 @@ def _ref(self):
raise ValueError("Circular reference: {} = {}".format(self, ref))
return ref

@property
def table_alias(self):
if self._cte.query is None:
raise AttributeError
return self._cte.name

@property
def target(self):
return self._ref.target
Expand All @@ -67,13 +73,6 @@ def as_sql(self, compiler, connection):
column = self.name
return "%s.%s" % (qn(self.table_alias), qn(column)), []

def relabeled_clone(self, relabels):
if self.table_alias is not None and self.table_alias in relabels:
clone = self.copy()
clone.table_alias = relabels[self.table_alias]
return clone
return self


class CTEColumnRef(Expression):

Expand Down
32 changes: 32 additions & 0 deletions django_cte/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,37 @@ def combine(self, other, connector):
self._with_ctes = other._with_ctes[:]
return super(CTEQuery, self).combine(other, connector)

def combine_cte(self, other):
relabel = {}
_names = []
for cte in self._with_ctes:
_names.append(cte.name)
for cte in other._with_ctes[:]:
if cte.name in _names:
if cte.name not in relabel.keys():
# FIXME: we need a better way to generate names
# also this prevent CTE reuse? should we identify where
# they are the same? this will give performance improvement
# in the materialization of the CTE
relabel[cte.name] = '%s%s' % (cte.name, '1')
_names.append(cte.name)
if relabel:
other = other.relabeled_clone(relabel)
for cte in other._with_ctes[:]:
if cte not in self._with_ctes:
self._with_ctes.append(cte)

other = other.clone()
other._with_ctes = []
return other

def relabeled_clone(self, change_map):
obj = super().relabeled_clone(change_map)
for cte in self._with_ctes[:]:
if cte.name in change_map:
cte.name = change_map[cte.name]
return obj

def get_compiler(self, using=None, connection=None, *args, **kwargs):
""" Overrides the Query method get_compiler in order to return
a CTECompiler.
Expand All @@ -51,6 +82,7 @@ def add_annotation(self, annotation, *args, **kw):
def __chain(self, _name, klass=None, *args, **kwargs):
klass = QUERY_TYPES.get(klass, self.__class__)
clone = getattr(super(CTEQuery, self), _name)(klass, *args, **kwargs)
# Should we clone the cte here?
clone._with_ctes = self._with_ctes[:]
return clone

Expand Down
Loading