Skip to content

Commit

Permalink
Remove duplication of FieldType (#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexklibisz authored Jun 25, 2020
1 parent 859d594 commit cce765f
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ import io.circe.syntax._
import io.circe.{Json, JsonObject}
import org.apache.lucene.index.{IndexOptions, IndexableField, Term}
import org.apache.lucene.search.similarities.BooleanSimilarity
import org.apache.lucene.search.{DocValuesFieldExistsQuery, Query, TermInSetQuery, TermQuery}
import org.apache.lucene.document
import org.apache.lucene.search.{DocValuesFieldExistsQuery, Query, TermQuery}
import org.apache.lucene.util.BytesRef
import org.elasticsearch.common.xcontent.{ToXContent, XContentBuilder}
import org.elasticsearch.index.mapper.Mapper.TypeParser
Expand All @@ -32,9 +31,9 @@ object VectorMapper {
val sorted = vec.sorted() // Sort for faster intersections on the query side.
mapping match {
case Mapping.SparseBool(_) => Try(ExactQuery.index(field, sorted))
case Mapping.SparseIndexed(_) => Try(SparseIndexedQuery.index(field, sorted))
case m: Mapping.JaccardLsh => Try(LshQuery.index(field, sorted, m))
case m: Mapping.HammingLsh => Try(LshQuery.index(field, sorted, m))
case Mapping.SparseIndexed(_) => Try(SparseIndexedQuery.index(field, fieldType, sorted))
case m: Mapping.JaccardLsh => Try(LshQuery.index(field, fieldType, sorted, m))
case m: Mapping.HammingLsh => Try(LshQuery.index(field, fieldType, sorted, m))
case _ => Failure(incompatible(mapping, vec))
}
}
Expand All @@ -48,8 +47,8 @@ object VectorMapper {
else
mapping match {
case Mapping.DenseFloat(_) => Try(ExactQuery.index(field, vec))
case m: Mapping.AngularLsh => Try(LshQuery.index(field, vec, m))
case m: Mapping.L2Lsh => Try(LshQuery.index(field, vec, m))
case m: Mapping.AngularLsh => Try(LshQuery.index(field, fieldType, vec, m))
case m: Mapping.L2Lsh => Try(LshQuery.index(field, fieldType, vec, m))
case _ => Failure(incompatible(mapping, vec))
}
}
Expand All @@ -61,11 +60,11 @@ object VectorMapper {
class FieldType(typeName: String) extends MappedFieldType {

// We generally only care about the presence or absence of terms, not their counts or anything fancier.
this.setSimilarity(new SimilarityProvider("boolean", new BooleanSimilarity))
// TODO: Any way to dedup the settings here and in simpleTokenFieldType?
this.setOmitNorms(true)
this.setBoost(1f)
this.setTokenized(false)
setSimilarity(new SimilarityProvider("boolean", new BooleanSimilarity))
setOmitNorms(true)
setBoost(1f)
setTokenized(false)
setIndexOptions(IndexOptions.DOCS)

override def typeName(): String = typeName
override def clone(): FieldType = new FieldType(typeName)
Expand All @@ -79,23 +78,14 @@ object VectorMapper {
override def existsQuery(context: QueryShardContext): Query = new DocValuesFieldExistsQuery(name())
}

// Shared Lucene field type for the simple, untokenized term fields emitted at
// index time: terms are indexed as docs-only (no frequencies/positions), never
// tokenized, and norms are omitted. Frozen so it can be safely reused.
val simpleTokenFieldType: document.FieldType = {
  val fieldType = new document.FieldType
  fieldType.setIndexOptions(IndexOptions.DOCS)
  fieldType.setTokenized(false)
  fieldType.setOmitNorms(true)
  fieldType.freeze() // freeze last, after all settings are applied
  fieldType
}

}

abstract class VectorMapper[V <: Vec: ElasticsearchCodec] { self =>

val CONTENT_TYPE: String
def checkAndCreateFields(mapping: Mapping, field: String, vec: V): Try[Seq[IndexableField]]

private val fieldType = new VectorMapper.FieldType(CONTENT_TYPE)
protected val fieldType = new VectorMapper.FieldType(CONTENT_TYPE)

import com.klibisz.elastiknn.utils.CirceUtils.javaMapEncoder

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import org.apache.lucene.queries.mlt.MoreLikeThis
import org.apache.lucene.search._
import org.apache.lucene.util.BytesRef
import org.elasticsearch.common.lucene.search.function.{CombineFunction, FunctionScoreQuery, LeafScoreFunction, ScoreFunction}
import org.elasticsearch.index.mapper.MappedFieldType

object LshQuery {

Expand Down Expand Up @@ -98,10 +99,10 @@ object LshQuery {
new FunctionScoreQuery(isecQuery, func, CombineFunction.REPLACE, 0f, Float.MaxValue)
}

def index[M <: Mapping, V <: Vec: StoredVec.Encoder, S <: StoredVec](field: String, vec: V, mapping: M)(
/** Builds the Lucene fields used to index one vector for LSH-based search.
  *
  * Returns the exact-query fields for `vec` (via [[ExactQuery.index]]) plus one
  * term [[Field]] per hash produced by the cached LSH function for `mapping`,
  * each hash serialized with `UnsafeSerialization.writeInt` and indexed under
  * the given `fieldType`.
  *
  * @param field     name of the field the vector is stored under
  * @param fieldType Lucene field type applied to each hash term field
  * @param vec       the vector to index
  * @param mapping   mapping whose cached LSH function hashes the vector
  */
def index[M <: Mapping, V <: Vec: StoredVec.Encoder, S <: StoredVec](field: String, fieldType: MappedFieldType, vec: V, mapping: M)(
    implicit lshFunctionCache: LshFunctionCache[M, V, S]): Seq[IndexableField] = {
  ExactQuery.index(field, vec) ++ lshFunctionCache(mapping)(vec).map { h =>
    new Field(field, UnsafeSerialization.writeInt(h), fieldType)
  }
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import org.apache.lucene.index._
import org.apache.lucene.search._
import org.apache.lucene.util.BytesRef
import org.elasticsearch.common.lucene.search.function.{CombineFunction, FunctionScoreQuery, LeafScoreFunction, ScoreFunction}
import org.elasticsearch.index.mapper.MappedFieldType

object SparseIndexedQuery {

Expand Down Expand Up @@ -64,9 +65,9 @@ object SparseIndexedQuery {

def numTrueDocValueField(field: String): String = s"$field.num_true"

def index(field: String, vec: Vec.SparseBool): Seq[IndexableField] = {
/** Builds the Lucene fields used to index one sparse boolean vector.
  *
  * Emits one term [[Field]] per true index (serialized with
  * `UnsafeSerialization.writeInt` and indexed under `fieldType`), the
  * exact-query fields for `vec` (via [[ExactQuery.index]]), and a
  * [[NumericDocValuesField]] under [[numTrueDocValueField]] holding the
  * number of true indices.
  *
  * @param field     name of the field the vector is stored under
  * @param fieldType Lucene field type applied to each true-index term field
  * @param vec       the sparse boolean vector to index
  */
def index(field: String, fieldType: MappedFieldType, vec: Vec.SparseBool): Seq[IndexableField] = {
  vec.trueIndices.map { ti =>
    new Field(field, UnsafeSerialization.writeInt(ti), fieldType)
  } ++ ExactQuery.index(field, vec) :+ new NumericDocValuesField(numTrueDocValueField(field), vec.trueIndices.length)
}

Expand Down

0 comments on commit cce765f

Please sign in to comment.