4
4
from itertools import repeat
5
5
from django .db import models , connections
6
6
from django .db .models .query import QuerySet
7
+ from django .db .models .sql .aggregates import Aggregate as SQLAggregate
8
+ from django .db .models import Aggregate
9
+ import re
7
10
8
11
from djorm_pgfulltext .utils import adapt
9
12
@@ -50,6 +53,20 @@ def auto_update_search_field_handler(sender, instance, *args, **kwargs):
50
53
instance .update_search_field ()
51
54
52
55
56
+ class SQLStringAgg (SQLAggregate ):
57
+ sql_template = '%(function)s(%(field)s, \' \' )'
58
+ sql_function = 'string_agg'
59
+
60
+
61
+ class StringAgg (Aggregate ):
62
+ name = 'StringAgg'
63
+
64
+ def add_to_query (self , query , alias , col , source , is_summary ):
65
+ aggregate = SQLStringAgg (
66
+ col , source = source , is_summary = is_summary , ** self .extra )
67
+ query .aggregates [alias ] = aggregate
68
+
69
+
53
70
class SearchManagerMixIn (object ):
54
71
"""
55
72
A mixin to create a Manager with a 'search' method that may do a full text search
@@ -130,6 +147,9 @@ def get_queryset(self):
130
147
def search (self , * args , ** kwargs ):
131
148
return self .get_queryset ().search (* args , ** kwargs )
132
149
150
+ def word_tree_search (self , * args , ** kwargs ):
151
+ return self .get_queryset ().word_tree_search (* args , ** kwargs )
152
+
133
153
def update_search_field (self , pk = None , config = None , using = None ):
134
154
'''
135
155
Update the search_field of one instance, or a list of instances, or
@@ -196,8 +216,9 @@ def _parse_fields(self, fields):
196
216
parsed_fields .update ([(x , None ) for x in fields ])
197
217
198
218
# Does not support field.attname.
199
- field_names = set (field .name for field in self .model ._meta .fields if not field .primary_key )
200
- non_model_fields = set (x [0 ] for x in parsed_fields ).difference (field_names )
219
+ to_search = (self .model ._meta .fields + self .model ._meta .many_to_many )
220
+ field_names = set (field .name for field in to_search if not field .primary_key )
221
+ non_model_fields = set (x [0 ].split ('__' )[0 ] for x in parsed_fields ).difference (field_names )
201
222
if non_model_fields :
202
223
raise ValueError ("The following fields do not exist in this"
203
224
" model: {0}" .format (", " .join (x for x in non_model_fields )))
@@ -228,16 +249,36 @@ def _get_vector_for_field(self, field_name, weight=None, config=None, using=None
228
249
if not config :
229
250
config = self .config
230
251
252
+ return "setweight(to_tsvector('%s', coalesce(%s, '')), '%s')" % \
253
+ (config , self ._get_field_value_query (field_name , using ), weight )
254
+
255
+ def _get_field_value_query (self , field_name , using = None ):
231
256
if using is None :
232
257
using = self .db
233
-
234
- field = self .model ._meta .get_field (field_name )
235
-
236
258
connection = connections [using ]
237
259
qn = connection .ops .quote_name
260
+ if '__' not in field_name :
261
+ field = self .model ._meta .get_field (field_name )
262
+ return '%s.%s' % (qn (self .model ._meta .db_table ), qn (field .column ))
238
263
239
- return "setweight(to_tsvector('%s', coalesce(%s.%s, '')), '%s')" % \
240
- (config , qn (self .model ._meta .db_table ), qn (field .column ), weight )
264
+ from_field , to_field = field_name .split ('__' )
265
+ field = self .model ._meta .get_field (from_field )
266
+
267
+ model_pk = '%s.%s' % (qn (self .model ._meta .db_table ),
268
+ qn (self .model ._meta .pk .column ))
269
+ agg_name = '%s_agg' % to_field
270
+
271
+ q = field .rel .to .objects \
272
+ .filter (** {'%s__pk' % field .rel .related_name : 1 })\
273
+ .annotate (** {agg_name : StringAgg (to_field )})\
274
+ .values (agg_name )
275
+
276
+ sql = re .sub (
277
+ r' GROUP(.*?)\)' ,
278
+ ')' ,
279
+ '(%s)' % q .query .sql_with_params ()[0 ] % model_pk )
280
+
281
+ return sql
241
282
242
283
243
284
class SearchQuerySet (QuerySet ):
@@ -289,10 +330,10 @@ def search(self, query, rank_field=None, rank_function='ts_rank', config=None,
289
330
290
331
if query :
291
332
function = "to_tsquery" if raw else "plainto_tsquery"
292
- ts_query = "%s('%s', %s )" % (
333
+ ts_query = "%s('%s', '%s' )" % (
293
334
function ,
294
335
config ,
295
- adapt ( query )
336
+ query
296
337
)
297
338
298
339
full_search_field = "%s.%s" % (
@@ -304,14 +345,16 @@ def search(self, query, rank_field=None, rank_function='ts_rank', config=None,
304
345
# these fields. In other case, intent use of search_field if
305
346
# exists.
306
347
if fields :
307
- search_vector = self .manager ._get_search_vector (config , using , fields = fields )
348
+ search_vector = self .manager ._get_search_vector (
349
+ config , using , fields = fields
350
+ )
308
351
else :
309
352
if not self .manager .search_field :
310
353
raise ValueError ("search_field is not specified" )
311
-
312
354
search_vector = full_search_field
313
355
314
356
where = " (%s) @@ (%s)" % (search_vector , ts_query )
357
+
315
358
select_dict , order = {}, []
316
359
317
360
if rank_field :
@@ -334,6 +377,14 @@ def search(self, query, rank_field=None, rank_function='ts_rank', config=None,
334
377
335
378
return qs
336
379
380
+ def word_tree_search (self , query , ** kwargs ):
381
+ if query :
382
+ kwargs ['raw' ] = True
383
+ query = re .sub ('[^a-zA-Z0-9 ]+' , '' , query )
384
+ query = re .sub ('[ ]+' , ' & ' , query .strip ())
385
+ query = '%s:*' % query
386
+ return self .search (query , ** kwargs )
387
+
337
388
338
389
class SearchManager (SearchManagerMixIn , models .Manager ):
339
390
pass
0 commit comments