Skip to content

Improve HNSW filtered search speed through new heuristic #126876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/126876.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 126876
summary: Improve HNSW filtered search speed through new heuristic
area: Vector Search
type: enhancement
issues: []
6 changes: 6 additions & 0 deletions docs/reference/elasticsearch/index-settings/index-modules.md
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,9 @@ $$$index-final-pipeline$$$

$$$index-hidden$$$ `index.hidden`
: Indicates whether the index should be hidden by default. Hidden indices are not returned by default when using a wildcard expression. This behavior is controlled per request through the use of the `expand_wildcards` parameter. Possible values are `true` and `false` (default).

$$$index-dense-vector-hnsw-filter-heuristic$$$ `index.dense_vector.hnsw_filter_heuristic`
: The heuristic to utilize when executing a filtered search against vectors in an HNSW graph. This setting is in technical preview and may be changed or removed in a future release. It can be set to:

* `acorn` (default) - Only vectors that match the filter criteria are searched. This is the fastest option and generally provides recall similar to `fanout`, but `num_candidates` might need to be increased for exceptionally high recall requirements.
* `fanout` - All vectors are compared with the query vector, but only those passing the criteria are added to the search results. Can be slower than `acorn`, but may yield higher recall.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.elasticsearch.index.mapper.IgnoredSourceFieldMapper;
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.similarity.SimilarityService;
import org.elasticsearch.index.store.FsDirectoryFactory;
import org.elasticsearch.index.store.Store;
Expand Down Expand Up @@ -157,6 +158,7 @@ public final class IndexScopedSettings extends AbstractScopedSettings {
IndexSettings.INDEX_TRANSLOG_RETENTION_AGE_SETTING,
IndexSettings.INDEX_TRANSLOG_RETENTION_SIZE_SETTING,
IndexSettings.INDEX_SEARCH_IDLE_AFTER,
DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC,
IndexFieldDataService.INDEX_FIELDDATA_CACHE_KEY,
IndexSettings.IGNORE_ABOVE_SETTING,
FieldMapper.IGNORE_MALFORMED_SETTING,
Expand Down
16 changes: 16 additions & 0 deletions server/src/main/java/org/elasticsearch/index/IndexSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.index.mapper.IgnoredSourceFieldMapper;
import org.elasticsearch.index.mapper.Mapper;
import org.elasticsearch.index.mapper.SourceFieldMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.translog.Translog;
import org.elasticsearch.indices.recovery.RecoverySettings;
import org.elasticsearch.ingest.IngestService;
Expand Down Expand Up @@ -896,6 +897,7 @@ private void setRetentionLeaseMillis(final TimeValue retentionLease) {
private volatile int maxTokenCount;
private volatile int maxNgramDiff;
private volatile int maxShingleDiff;
private volatile DenseVectorFieldMapper.FilterHeuristic hnswFilterHeuristic;
private volatile TimeValue searchIdleAfter;
private volatile int maxAnalyzedOffset;
private volatile boolean weightMatchesEnabled;
Expand Down Expand Up @@ -1091,6 +1093,7 @@ public IndexSettings(final IndexMetadata indexMetadata, final Settings nodeSetti
logsdbAddHostNameField = scopedSettings.get(LOGSDB_ADD_HOST_NAME_FIELD);
skipIgnoredSourceWrite = scopedSettings.get(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_WRITE_SETTING);
skipIgnoredSourceRead = scopedSettings.get(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING);
hnswFilterHeuristic = scopedSettings.get(DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC);
indexMappingSourceMode = scopedSettings.get(INDEX_MAPPER_SOURCE_MODE_SETTING);
recoverySourceEnabled = RecoverySettings.INDICES_RECOVERY_SOURCE_ENABLED_SETTING.get(nodeSettings);
recoverySourceSyntheticEnabled = DiscoveryNode.isStateless(nodeSettings) == false
Expand Down Expand Up @@ -1203,6 +1206,7 @@ public IndexSettings(final IndexMetadata indexMetadata, final Settings nodeSetti
this::setSkipIgnoredSourceWrite
);
scopedSettings.addSettingsUpdateConsumer(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING, this::setSkipIgnoredSourceRead);
scopedSettings.addSettingsUpdateConsumer(DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC, this::setHnswFilterHeuristic);
}

private void setSearchIdleAfter(TimeValue searchIdleAfter) {
Expand Down Expand Up @@ -1821,4 +1825,16 @@ public TimestampBounds getTimestampBounds() {
public IndexRouting getIndexRouting() {
return indexRouting;
}

/**
 * Returns the heuristic currently configured for filtered searches over vectors
 * stored in HNSW format.
 *
 * @return the value last read from, or pushed by an update of,
 *         {@code DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC}
 */
public DenseVectorFieldMapper.FilterHeuristic getHnswFilterHeuristic() {
    return hnswFilterHeuristic;
}

/**
 * Replaces the stored HNSW filter heuristic with the supplied value.
 *
 * @param heuristic the heuristic to store
 */
private void setHnswFilterHeuristic(DenseVectorFieldMapper.FilterHeuristic heuristic) {
    hnswFilterHeuristic = heuristic;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.IndexVersion;
Expand Down Expand Up @@ -109,6 +111,46 @@ public static boolean isNotUnitVector(float magnitude) {
return Math.abs(magnitude - 1.0f) > EPS;
}

/**
 * Selects how a filtered search is executed against vectors indexed in an HNSW graph.
 * Every constant carries the {@link KnnSearchStrategy} that is handed to the kNN query.
 */
public enum FilterHeuristic {
    /**
     * Walks the entire graph and compares the query against all immediate neighbors,
     * while collecting only the vectors that satisfy the filtering criteria.
     */
    // NOTE(review): the Hnsw argument looks like Lucene's filtered-search threshold; confirm against Lucene docs.
    FANOUT(new KnnSearchStrategy.Hnsw(0)),
    /**
     * Compares the query only against vectors that satisfy the filtering criteria.
     */
    ACORN(new KnnSearchStrategy.Hnsw(50));

    // Built once per constant and reused for every query.
    private final KnnSearchStrategy knnSearchStrategy;

    FilterHeuristic(KnnSearchStrategy knnSearchStrategy) {
        this.knnSearchStrategy = knnSearchStrategy;
    }

    /**
     * @return the search strategy bound to this heuristic; the same instance on every call
     */
    public KnnSearchStrategy getKnnSearchStrategy() {
        return knnSearchStrategy;
    }
}

/**
 * Index-scoped, dynamic setting {@code index.dense_vector.hnsw_filter_heuristic} that selects
 * the {@link FilterHeuristic} used for filtered HNSW searches. Defaults to
 * {@link FilterHeuristic#ACORN}.
 */
public static final Setting<FilterHeuristic> HNSW_FILTER_HEURISTIC = Setting.enumSetting(
FilterHeuristic.class,
"index.dense_vector.hnsw_filter_heuristic",
FilterHeuristic.ACORN,
// No-op callback that accepts any value (presumably the setting's validator; confirm against Setting.enumSetting).
fh -> {},
Setting.Property.IndexScope,
Setting.Property.Dynamic
);

public static final IndexVersion MAGNITUDE_STORED_INDEX_VERSION = IndexVersions.V_7_5_0;
public static final IndexVersion INDEXED_BY_DEFAULT_INDEX_VERSION = IndexVersions.FIRST_DETACHED_INDEX_VERSION;
public static final IndexVersion NORMALIZE_COSINE = IndexVersions.NORMALIZED_VECTOR_COSINE;
Expand Down Expand Up @@ -2159,25 +2201,44 @@ public Query createKnnQuery(
Float oversample,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
BitSetProducer parentFilter,
DenseVectorFieldMapper.FilterHeuristic heuristic
) {
if (isIndexed() == false) {
throw new IllegalArgumentException(
"to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]"
);
}
KnnSearchStrategy knnSearchStrategy = heuristic.getKnnSearchStrategy();
return switch (getElementType()) {
case BYTE -> createKnnByteQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter);
case BYTE -> createKnnByteQuery(
queryVector.asByteVector(),
k,
numCands,
filter,
similarityThreshold,
parentFilter,
knnSearchStrategy
);
case FLOAT -> createKnnFloatQuery(
queryVector.asFloatVector(),
k,
numCands,
oversample,
filter,
similarityThreshold,
parentFilter
parentFilter,
knnSearchStrategy
);
case BIT -> createKnnBitQuery(
queryVector.asByteVector(),
k,
numCands,
filter,
similarityThreshold,
parentFilter,
knnSearchStrategy
);
case BIT -> createKnnBitQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter);
};
}

Expand All @@ -2195,12 +2256,13 @@ private Query createKnnBitQuery(
int numCands,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
BitSetProducer parentFilter,
KnnSearchStrategy searchStrategy
) {
elementType.checkDimensions(dims, queryVector.length);
Query knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter);
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
knnQuery,
Expand All @@ -2217,7 +2279,8 @@ private Query createKnnByteQuery(
int numCands,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
BitSetProducer parentFilter,
KnnSearchStrategy searchStrategy
) {
elementType.checkDimensions(dims, queryVector.length);

Expand All @@ -2226,8 +2289,8 @@ private Query createKnnByteQuery(
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
}
Query knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter);
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
knnQuery,
Expand All @@ -2245,7 +2308,8 @@ private Query createKnnFloatQuery(
Float queryOversample,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
BitSetProducer parentFilter,
KnnSearchStrategy knnSearchStrategy
) {
elementType.checkDimensions(dims, queryVector.length);
elementType.checkVectorBounds(queryVector);
Expand Down Expand Up @@ -2279,8 +2343,16 @@ && isNotUnitVector(squaredMagnitude)) {
numCands = Math.max(adjustedK, numCands);
}
Query knnQuery = parentFilter != null
? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, adjustedK, numCands, parentFilter)
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter);
? new ESDiversifyingChildrenFloatKnnVectorQuery(
name(),
queryVector,
filter,
adjustedK,
numCands,
parentFilter,
knnSearchStrategy
)
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy);
if (rescore) {
knnQuery = new RescoreKnnVectorQuery(
name(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.elasticsearch.search.profile.query.QueryProfiler;

public class ESDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildrenByteKnnVectorQuery implements QueryProfilerProvider {
Expand All @@ -25,9 +26,10 @@ public ESDiversifyingChildrenByteKnnVectorQuery(
Query childFilter,
Integer k,
int numCands,
BitSetProducer parentsFilter
BitSetProducer parentsFilter,
KnnSearchStrategy strategy
) {
super(field, query, childFilter, numCands, parentsFilter);
super(field, query, childFilter, numCands, parentsFilter, strategy);
this.kParam = k;
}

Expand All @@ -42,4 +44,8 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
public void profile(QueryProfiler queryProfiler) {
// Pass the current value of vectorOpsCount to the profiler.
queryProfiler.addVectorOpsCount(vectorOpsCount);
}

/**
 * @return the {@link KnnSearchStrategy} held in this query's {@code searchStrategy} field
 */
public KnnSearchStrategy getStrategy() {
    return this.searchStrategy;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.elasticsearch.search.profile.query.QueryProfiler;

public class ESDiversifyingChildrenFloatKnnVectorQuery extends DiversifyingChildrenFloatKnnVectorQuery implements QueryProfilerProvider {
Expand All @@ -25,9 +26,10 @@ public ESDiversifyingChildrenFloatKnnVectorQuery(
Query childFilter,
Integer k,
int numCands,
BitSetProducer parentsFilter
BitSetProducer parentsFilter,
KnnSearchStrategy strategy
) {
super(field, query, childFilter, numCands, parentsFilter);
super(field, query, childFilter, numCands, parentsFilter, strategy);
this.kParam = k;
}

Expand All @@ -42,4 +44,8 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
public void profile(QueryProfiler queryProfiler) {
// Pass the current value of vectorOpsCount to the profiler.
queryProfiler.addVectorOpsCount(vectorOpsCount);
}

/**
 * @return the {@link KnnSearchStrategy} held in this query's {@code searchStrategy} field
 */
public KnnSearchStrategy getStrategy() {
    return this.searchStrategy;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.elasticsearch.search.profile.query.QueryProfiler;

public class ESKnnByteVectorQuery extends KnnByteVectorQuery implements QueryProfilerProvider {
private final Integer kParam;
private long vectorOpsCount;

public ESKnnByteVectorQuery(String field, byte[] target, Integer k, int numCands, Query filter) {
super(field, target, numCands, filter);
public ESKnnByteVectorQuery(String field, byte[] target, Integer k, int numCands, Query filter, KnnSearchStrategy strategy) {
super(field, target, numCands, filter, strategy);
this.kParam = k;
}

Expand All @@ -39,4 +40,8 @@ public void profile(QueryProfiler queryProfiler) {
/**
 * @return the {@code k} value that was passed to this query's constructor
 */
public Integer kParam() {
    return this.kParam;
}

/**
 * @return the {@link KnnSearchStrategy} held in this query's {@code searchStrategy} field
 */
public KnnSearchStrategy getStrategy() {
    return this.searchStrategy;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.elasticsearch.search.profile.query.QueryProfiler;

public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements QueryProfilerProvider {
private final Integer kParam;
private long vectorOpsCount;

public ESKnnFloatVectorQuery(String field, float[] target, Integer k, int numCands, Query filter) {
super(field, target, numCands, filter);
public ESKnnFloatVectorQuery(String field, float[] target, Integer k, int numCands, Query filter, KnnSearchStrategy strategy) {
super(field, target, numCands, filter, strategy);
this.kParam = k;
}

Expand All @@ -39,4 +40,8 @@ public void profile(QueryProfiler queryProfiler) {
/**
 * @return the {@code k} value that was passed to this query's constructor
 */
public Integer kParam() {
    return this.kParam;
}

/**
 * @return the {@link KnnSearchStrategy} held in this query's {@code searchStrategy} field
 */
public KnnSearchStrategy getStrategy() {
    return this.searchStrategy;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -552,8 +552,17 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet);
}
}

return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, oversample, filterQuery, vectorSimilarity, parentBitSet);
DenseVectorFieldMapper.FilterHeuristic heuristic = context.getIndexSettings().getHnswFilterHeuristic();
return vectorFieldType.createKnnQuery(
queryVector,
k,
adjustedNumCands,
oversample,
filterQuery,
vectorSimilarity,
parentBitSet,
heuristic
);
}

@Override
Expand Down
Loading
Loading