diff --git a/djorm_pgfulltext/fields.py b/djorm_pgfulltext/fields.py index 429cc54..060bc85 100644 --- a/djorm_pgfulltext/fields.py +++ b/djorm_pgfulltext/fields.py @@ -4,13 +4,13 @@ from django.db import models from psycopg2.extensions import adapt + class VectorField(models.Field): def __init__(self, *args, **kwargs): kwargs['null'] = True kwargs['default'] = '' kwargs['editable'] = False kwargs['serialize'] = False - kwargs['db_index'] = True super(VectorField, self).__init__(*args, **kwargs) def db_type(self, *args, **kwargs): diff --git a/djorm_pgfulltext/models.py b/djorm_pgfulltext/models.py index 2ea2b0f..c107ca7 100644 --- a/djorm_pgfulltext/models.py +++ b/djorm_pgfulltext/models.py @@ -3,6 +3,9 @@ from itertools import repeat from django.db import models, connections from django.db.models.query import QuerySet +from django.db.models.sql.aggregates import Aggregate as SQLAggregate +from django.db.models import Aggregate +import re # Compatibility import and fixes section. @@ -47,6 +50,20 @@ def auto_update_search_field_handler(sender, instance, *args, **kwargs): instance.update_search_field() +class SQLStringAgg(SQLAggregate): + sql_template = '%(function)s(%(field)s, \' \')' + sql_function = 'string_agg' + + +class StringAgg(Aggregate): + name = 'StringAgg' + + def add_to_query(self, query, alias, col, source, is_summary): + aggregate = SQLStringAgg( + col, source=source, is_summary=is_summary, **self.extra) + query.aggregates[alias] = aggregate + + class SearchManagerMixIn(object): """ A mixin to create a Manager with a 'search' method that may do a full text search @@ -119,6 +136,9 @@ def get_queryset(self): def search(self, *args, **kwargs): return self.get_queryset().search(*args, **kwargs) + def word_tree_search(self, *args, **kwargs): + return self.get_queryset().word_tree_search(*args, **kwargs) + def update_search_field(self, pk=None, config=None, using=None): ''' Update the search_field of one instance, or a list of instances, or @@ -185,8 +205,9 @@ def _parse_fields(self, fields): parsed_fields.update([(x, None) for x in fields]) # Does not support field.attname. - field_names = set(field.name for field in self.model._meta.fields if not field.primary_key) - non_model_fields = set(x[0] for x in parsed_fields).difference(field_names) + to_search = (self.model._meta.fields + self.model._meta.many_to_many) + field_names = set(field.name for field in to_search if not field.primary_key) + non_model_fields = set(x[0].split('__')[0] for x in parsed_fields).difference(field_names) if non_model_fields: raise ValueError("The following fields do not exist in this" " model: {0}".format(", ".join(x for x in non_model_fields))) @@ -213,16 +234,36 @@ def _get_vector_for_field(self, field_name, weight=None, config=None, using=None if not config: config = self.config + return "setweight(to_tsvector('%s', coalesce(%s, '')), '%s')" % \ + (config, self._get_field_value_query(field_name, using), weight) + + def _get_field_value_query(self, field_name, using=None): if using is None: using = self.db - - field = self.model._meta.get_field(field_name) - connection = connections[using] qn = connection.ops.quote_name + if '__' not in field_name: + field = self.model._meta.get_field(field_name) + return '%s.%s' % (qn(self.model._meta.db_table), qn(field.column)) + + from_field, to_field = field_name.split('__') + field = self.model._meta.get_field(from_field) + + model_pk = '%s.%s' % (qn(self.model._meta.db_table), + qn(self.model._meta.pk.column)) + agg_name = '%s_agg' % to_field + + q = field.rel.to.objects\ + .filter(**{'%s__pk' % field.rel.related_name: 1})\ + .annotate(**{agg_name: StringAgg(to_field)})\ + .values(agg_name) - return "setweight(to_tsvector('%s', coalesce(%s.%s, '')), '%s')" % \ - (config, qn(self.model._meta.db_table), qn(field.column), weight) + sql = re.sub( + r' GROUP(.*?)\)', + ')', + '(%s)' % q.query.sql_with_params()[0] % model_pk) + + return sql class SearchQuerySet(QuerySet): @@ -319,6 +360,14 @@ def search(self, query, rank_field=None, rank_function='ts_rank', config=None, return qs + def word_tree_search(self, query, **kwargs): + if query: + kwargs['raw'] = True + query = re.sub('[^a-zA-Z0-9 ]+', '', query) + query = re.sub('[ ]+', ' & ', query.strip()) + query = '%s:*' % query + return self.search(query, **kwargs) + class SearchManager(SearchManagerMixIn, models.Manager): pass