Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow many to many and foreign key fields in search #47

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion djorm_pgfulltext/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from django.db import models
from psycopg2.extensions import adapt


class VectorField(models.Field):
def __init__(self, *args, **kwargs):
kwargs['null'] = True
kwargs['default'] = ''
kwargs['editable'] = False
kwargs['serialize'] = False
kwargs['db_index'] = True
super(VectorField, self).__init__(*args, **kwargs)

def db_type(self, *args, **kwargs):
Expand Down
63 changes: 56 additions & 7 deletions djorm_pgfulltext/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from itertools import repeat
from django.db import models, connections
from django.db.models.query import QuerySet
from django.db.models.sql.aggregates import Aggregate as SQLAggregate
from django.db.models import Aggregate
import re

# Compatibility import and fixes section.

Expand Down Expand Up @@ -47,6 +50,20 @@ def auto_update_search_field_handler(sender, instance, *args, **kwargs):
instance.update_search_field()


class SQLStringAgg(SQLAggregate):
sql_template = '%(function)s(%(field)s, \' \')'
sql_function = 'string_agg'


class StringAgg(Aggregate):
name = 'StringAgg'

def add_to_query(self, query, alias, col, source, is_summary):
aggregate = SQLStringAgg(
col, source=source, is_summary=is_summary, **self.extra)
query.aggregates[alias] = aggregate


class SearchManagerMixIn(object):
"""
A mixin to create a Manager with a 'search' method that may do a full text search
Expand Down Expand Up @@ -119,6 +136,9 @@ def get_queryset(self):
def search(self, *args, **kwargs):
return self.get_queryset().search(*args, **kwargs)

def word_tree_search(self, *args, **kwargs):
return self.get_queryset().word_tree_search(*args, **kwargs)

def update_search_field(self, pk=None, config=None, using=None):
'''
Update the search_field of one instance, or a list of instances, or
Expand Down Expand Up @@ -185,8 +205,9 @@ def _parse_fields(self, fields):
parsed_fields.update([(x, None) for x in fields])

# Does not support field.attname.
field_names = set(field.name for field in self.model._meta.fields if not field.primary_key)
non_model_fields = set(x[0] for x in parsed_fields).difference(field_names)
to_search = (self.model._meta.fields + self.model._meta.many_to_many)
field_names = set(field.name for field in to_search if not field.primary_key)
non_model_fields = set(x[0].split('__')[0] for x in parsed_fields).difference(field_names)
if non_model_fields:
raise ValueError("The following fields do not exist in this"
" model: {0}".format(", ".join(x for x in non_model_fields)))
Expand All @@ -213,16 +234,36 @@ def _get_vector_for_field(self, field_name, weight=None, config=None, using=None
if not config:
config = self.config

return "setweight(to_tsvector('%s', coalesce(%s, '')), '%s')" % \
(config, self._get_field_value_query(field_name, using), weight)

def _get_field_value_query(self, field_name, using=None):
if using is None:
using = self.db

field = self.model._meta.get_field(field_name)

connection = connections[using]
qn = connection.ops.quote_name
if '__' not in field_name:
field = self.model._meta.get_field(field_name)
return '%s.%s' % (qn(self.model._meta.db_table), qn(field.column))

from_field, to_field = field_name.split('__')
field = self.model._meta.get_field(from_field)

model_pk = '%s.%s' % (qn(self.model._meta.db_table),
qn(self.model._meta.pk.column))
agg_name = '%s_agg' % to_field

q = field.rel.to.objects\
.filter(**{'%s__pk' % field.rel.related_name: 1})\
.annotate(**{agg_name: StringAgg(to_field)})\
.values(agg_name)

return "setweight(to_tsvector('%s', coalesce(%s.%s, '')), '%s')" % \
(config, qn(self.model._meta.db_table), qn(field.column), weight)
sql = re.sub(
r' GROUP(.*?)\)',
')',
'(%s)' % q.query.sql_with_params()[0] % model_pk)

return sql


class SearchQuerySet(QuerySet):
Expand Down Expand Up @@ -319,6 +360,14 @@ def search(self, query, rank_field=None, rank_function='ts_rank', config=None,

return qs

def word_tree_search(self, query, **kwargs):
if query:
kwargs['raw'] = True
query = re.sub('[^a-zA-Z0-9 ]+', '', query)
query = re.sub('[ ]+', ' & ', query.strip())
query = '%s:*' % query
return self.search(query, **kwargs)


class SearchManager(SearchManagerMixIn, models.Manager):
pass
Expand Down