Skip to content

Commit

Permalink
Merge pull request linuxlewis#47 from ministryofjustice/master
Browse files Browse the repository at this point in the history
  • Loading branch information
Paweł Kowalski committed Nov 12, 2015
2 parents 833133a + a6dbbd8 commit 906851b
Showing 1 changed file with 62 additions and 11 deletions.
73 changes: 62 additions & 11 deletions djorm_pgfulltext/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from itertools import repeat
from django.db import models, connections
from django.db.models.query import QuerySet
from django.db.models.sql.aggregates import Aggregate as SQLAggregate
from django.db.models import Aggregate
import re

from djorm_pgfulltext.utils import adapt

Expand Down Expand Up @@ -50,6 +53,20 @@ def auto_update_search_field_handler(sender, instance, *args, **kwargs):
instance.update_search_field()


class SQLStringAgg(SQLAggregate):
sql_template = '%(function)s(%(field)s, \' \')'
sql_function = 'string_agg'


class StringAgg(Aggregate):
name = 'StringAgg'

def add_to_query(self, query, alias, col, source, is_summary):
aggregate = SQLStringAgg(
col, source=source, is_summary=is_summary, **self.extra)
query.aggregates[alias] = aggregate


class SearchManagerMixIn(object):
"""
A mixin to create a Manager with a 'search' method that may do a full text search
Expand Down Expand Up @@ -130,6 +147,9 @@ def get_queryset(self):
def search(self, *args, **kwargs):
return self.get_queryset().search(*args, **kwargs)

def word_tree_search(self, *args, **kwargs):
return self.get_queryset().word_tree_search(*args, **kwargs)

def update_search_field(self, pk=None, config=None, using=None):
'''
Update the search_field of one instance, or a list of instances, or
Expand Down Expand Up @@ -196,8 +216,9 @@ def _parse_fields(self, fields):
parsed_fields.update([(x, None) for x in fields])

# Does not support field.attname.
field_names = set(field.name for field in self.model._meta.fields if not field.primary_key)
non_model_fields = set(x[0] for x in parsed_fields).difference(field_names)
to_search = (self.model._meta.fields + self.model._meta.many_to_many)
field_names = set(field.name for field in to_search if not field.primary_key)
non_model_fields = set(x[0].split('__')[0] for x in parsed_fields).difference(field_names)
if non_model_fields:
raise ValueError("The following fields do not exist in this"
" model: {0}".format(", ".join(x for x in non_model_fields)))
Expand Down Expand Up @@ -228,16 +249,36 @@ def _get_vector_for_field(self, field_name, weight=None, config=None, using=None
if not config:
config = self.config

return "setweight(to_tsvector('%s', coalesce(%s, '')), '%s')" % \
(config, self._get_field_value_query(field_name, using), weight)

def _get_field_value_query(self, field_name, using=None):
if using is None:
using = self.db

field = self.model._meta.get_field(field_name)

connection = connections[using]
qn = connection.ops.quote_name
if '__' not in field_name:
field = self.model._meta.get_field(field_name)
return '%s.%s' % (qn(self.model._meta.db_table), qn(field.column))

return "setweight(to_tsvector('%s', coalesce(%s.%s, '')), '%s')" % \
(config, qn(self.model._meta.db_table), qn(field.column), weight)
from_field, to_field = field_name.split('__')
field = self.model._meta.get_field(from_field)

model_pk = '%s.%s' % (qn(self.model._meta.db_table),
qn(self.model._meta.pk.column))
agg_name = '%s_agg' % to_field

q = field.rel.to.objects\
.filter(**{'%s__pk' % field.rel.related_name: 1})\
.annotate(**{agg_name: StringAgg(to_field)})\
.values(agg_name)

sql = re.sub(
r' GROUP(.*?)\)',
')',
'(%s)' % q.query.sql_with_params()[0] % model_pk)

return sql


class SearchQuerySet(QuerySet):
Expand Down Expand Up @@ -289,10 +330,10 @@ def search(self, query, rank_field=None, rank_function='ts_rank', config=None,

if query:
function = "to_tsquery" if raw else "plainto_tsquery"
ts_query = "%s('%s', %s)" % (
ts_query = "%s('%s', '%s')" % (
function,
config,
adapt(query)
query
)

full_search_field = "%s.%s" % (
Expand All @@ -304,14 +345,16 @@ def search(self, query, rank_field=None, rank_function='ts_rank', config=None,
# these fields. In other case, intent use of search_field if
# exists.
if fields:
search_vector = self.manager._get_search_vector(config, using, fields=fields)
search_vector = self.manager._get_search_vector(
config, using, fields=fields
)
else:
if not self.manager.search_field:
raise ValueError("search_field is not specified")

search_vector = full_search_field

where = " (%s) @@ (%s)" % (search_vector, ts_query)

select_dict, order = {}, []

if rank_field:
Expand All @@ -334,6 +377,14 @@ def search(self, query, rank_field=None, rank_function='ts_rank', config=None,

return qs

def word_tree_search(self, query, **kwargs):
if query:
kwargs['raw'] = True
query = re.sub('[^a-zA-Z0-9 ]+', '', query)
query = re.sub('[ ]+', ' & ', query.strip())
query = '%s:*' % query
return self.search(query, **kwargs)


class SearchManager(SearchManagerMixIn, models.Manager):
pass

0 comments on commit 906851b

Please sign in to comment.