Skip to content

Commit c955c1f

Browse files
author
Paweł Kowalski
committed
Merge pull request linuxlewis#47 from ministryofjustice/master
2 parents 833133a + a6dbbd8 commit c955c1f

File tree

2 files changed

+63
-12
lines changed

2 files changed

+63
-12
lines changed

djorm_pgfulltext/models.py

+62-11
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from itertools import repeat
55
from django.db import models, connections
66
from django.db.models.query import QuerySet
7+
from django.db.models.sql.aggregates import Aggregate as SQLAggregate
8+
from django.db.models import Aggregate
9+
import re
710

811
from djorm_pgfulltext.utils import adapt
912

@@ -50,6 +53,20 @@ def auto_update_search_field_handler(sender, instance, *args, **kwargs):
5053
instance.update_search_field()
5154

5255

56+
class SQLStringAgg(SQLAggregate):
57+
sql_template = '%(function)s(%(field)s, \' \')'
58+
sql_function = 'string_agg'
59+
60+
61+
class StringAgg(Aggregate):
62+
name = 'StringAgg'
63+
64+
def add_to_query(self, query, alias, col, source, is_summary):
65+
aggregate = SQLStringAgg(
66+
col, source=source, is_summary=is_summary, **self.extra)
67+
query.aggregates[alias] = aggregate
68+
69+
5370
class SearchManagerMixIn(object):
5471
"""
5572
A mixin to create a Manager with a 'search' method that may do a full text search
@@ -130,6 +147,9 @@ def get_queryset(self):
130147
def search(self, *args, **kwargs):
131148
return self.get_queryset().search(*args, **kwargs)
132149

150+
def word_tree_search(self, *args, **kwargs):
151+
return self.get_queryset().word_tree_search(*args, **kwargs)
152+
133153
def update_search_field(self, pk=None, config=None, using=None):
134154
'''
135155
Update the search_field of one instance, or a list of instances, or
@@ -196,8 +216,9 @@ def _parse_fields(self, fields):
196216
parsed_fields.update([(x, None) for x in fields])
197217

198218
# Does not support field.attname.
199-
field_names = set(field.name for field in self.model._meta.fields if not field.primary_key)
200-
non_model_fields = set(x[0] for x in parsed_fields).difference(field_names)
219+
to_search = (self.model._meta.fields + self.model._meta.many_to_many)
220+
field_names = set(field.name for field in to_search if not field.primary_key)
221+
non_model_fields = set(x[0].split('__')[0] for x in parsed_fields).difference(field_names)
201222
if non_model_fields:
202223
raise ValueError("The following fields do not exist in this"
203224
" model: {0}".format(", ".join(x for x in non_model_fields)))
@@ -228,16 +249,36 @@ def _get_vector_for_field(self, field_name, weight=None, config=None, using=None
228249
if not config:
229250
config = self.config
230251

252+
return "setweight(to_tsvector('%s', coalesce(%s, '')), '%s')" % \
253+
(config, self._get_field_value_query(field_name, using), weight)
254+
255+
def _get_field_value_query(self, field_name, using=None):
231256
if using is None:
232257
using = self.db
233-
234-
field = self.model._meta.get_field(field_name)
235-
236258
connection = connections[using]
237259
qn = connection.ops.quote_name
260+
if '__' not in field_name:
261+
field = self.model._meta.get_field(field_name)
262+
return '%s.%s' % (qn(self.model._meta.db_table), qn(field.column))
238263

239-
return "setweight(to_tsvector('%s', coalesce(%s.%s, '')), '%s')" % \
240-
(config, qn(self.model._meta.db_table), qn(field.column), weight)
264+
from_field, to_field = field_name.split('__')
265+
field = self.model._meta.get_field(from_field)
266+
267+
model_pk = '%s.%s' % (qn(self.model._meta.db_table),
268+
qn(self.model._meta.pk.column))
269+
agg_name = '%s_agg' % to_field
270+
271+
q = field.rel.to.objects\
272+
.filter(**{'%s__pk' % field.rel.related_name: 1})\
273+
.annotate(**{agg_name: StringAgg(to_field)})\
274+
.values(agg_name)
275+
276+
sql = re.sub(
277+
r' GROUP(.*?)\)',
278+
')',
279+
'(%s)' % q.query.sql_with_params()[0] % model_pk)
280+
281+
return sql
241282

242283

243284
class SearchQuerySet(QuerySet):
@@ -289,10 +330,10 @@ def search(self, query, rank_field=None, rank_function='ts_rank', config=None,
289330

290331
if query:
291332
function = "to_tsquery" if raw else "plainto_tsquery"
292-
ts_query = "%s('%s', %s)" % (
333+
ts_query = "%s('%s', '%s')" % (
293334
function,
294335
config,
295-
adapt(query)
336+
query
296337
)
297338

298339
full_search_field = "%s.%s" % (
@@ -304,14 +345,16 @@ def search(self, query, rank_field=None, rank_function='ts_rank', config=None,
304345
# these fields. In other case, intent use of search_field if
305346
# exists.
306347
if fields:
307-
search_vector = self.manager._get_search_vector(config, using, fields=fields)
348+
search_vector = self.manager._get_search_vector(
349+
config, using, fields=fields
350+
)
308351
else:
309352
if not self.manager.search_field:
310353
raise ValueError("search_field is not specified")
311-
312354
search_vector = full_search_field
313355

314356
where = " (%s) @@ (%s)" % (search_vector, ts_query)
357+
315358
select_dict, order = {}, []
316359

317360
if rank_field:
@@ -334,6 +377,14 @@ def search(self, query, rank_field=None, rank_function='ts_rank', config=None,
334377

335378
return qs
336379

380+
def word_tree_search(self, query, **kwargs):
381+
if query:
382+
kwargs['raw'] = True
383+
query = re.sub('[^a-zA-Z0-9 ]+', '', query)
384+
query = re.sub('[ ]+', ' & ', query.strip())
385+
query = '%s:*' % query
386+
return self.search(query, **kwargs)
387+
337388

338389
class SearchManager(SearchManagerMixIn, models.Manager):
339390
pass

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[metadata]
22
name = djorm-ext-pgfulltext
3-
version = 0.9.4
3+
version = 2015.11.12.dev1
44
author = Lovely Team
55
author-email = engineering@livelovely.com
66
summary = PostgreSQL Full Text Search integration with django orm.

0 commit comments

Comments
 (0)