Skip to content

Simplified RRF Retriever #129659

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
5 changes: 5 additions & 0 deletions docs/changelog/129659.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 129659
summary: Simplified RRF Retriever
area: Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.features.FeatureSpecification;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder;
import org.elasticsearch.xpack.rank.rrf.RRFRetrieverBuilder;

import java.util.Set;

Expand All @@ -34,7 +35,8 @@ public Set<NodeFeature> getTestFeatures() {
LINEAR_RETRIEVER_MINMAX_SINGLE_DOC_FIX,
LINEAR_RETRIEVER_L2_NORM,
LINEAR_RETRIEVER_MINSCORE_FIX,
LinearRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT
LinearRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT,
RRFRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,27 @@
package org.elasticsearch.xpack.rank.rrf;

import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.XPackPlugin;
import org.elasticsearch.xpack.rank.MultiFieldsInnerRetrieverUtils;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -29,7 +37,6 @@
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

/**
Expand All @@ -40,11 +47,14 @@
* formula.
*/
public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetrieverBuilder> {
public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature("rrf_retriever.multi_fields_query_format_support");

public static final String NAME = "rrf";

public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers");
public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant");
public static final ParseField FIELDS_FIELD = new ParseField("fields");
public static final ParseField QUERY_FIELD = new ParseField("query");

public static final int DEFAULT_RANK_CONSTANT = 60;
@SuppressWarnings("unchecked")
Expand All @@ -53,22 +63,29 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
false,
args -> {
List<RetrieverBuilder> childRetrievers = (List<RetrieverBuilder>) args[0];
List<RetrieverSource> innerRetrievers = childRetrievers.stream().map(RetrieverSource::from).toList();
int rankWindowSize = args[1] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[1];
int rankConstant = args[2] == null ? DEFAULT_RANK_CONSTANT : (int) args[2];
return new RRFRetrieverBuilder(innerRetrievers, rankWindowSize, rankConstant);
List<String> fields = (List<String>) args[1];
String query = (String) args[2];
int rankWindowSize = args[3] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[3];
int rankConstant = args[4] == null ? DEFAULT_RANK_CONSTANT : (int) args[4];

List<RetrieverSource> innerRetrievers = childRetrievers != null
? childRetrievers.stream().map(RetrieverSource::from).toList()
: List.of();
return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant);
}
);

static {
PARSER.declareObjectArray(constructorArg(), (p, c) -> {
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> {
p.nextToken();
String name = p.currentName();
RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c);
c.trackRetrieverUsage(retrieverBuilder.getName());
p.nextToken();
return retrieverBuilder;
}, RETRIEVERS_FIELD);
PARSER.declareStringArray(optionalConstructorArg(), FIELDS_FIELD);
PARSER.declareString(optionalConstructorArg(), QUERY_FIELD);
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
PARSER.declareInt(optionalConstructorArg(), RANK_CONSTANT_FIELD);
RetrieverBuilder.declareBaseParserFields(PARSER);
Expand All @@ -81,25 +98,60 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP
return PARSER.apply(parser, context);
}

private final List<String> fields;
private final String query;
private final int rankConstant;

public RRFRetrieverBuilder(int rankWindowSize, int rankConstant) {
this(new ArrayList<>(), rankWindowSize, rankConstant);
public RRFRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, int rankConstant) {
this(childRetrievers, null, null, rankWindowSize, rankConstant);
}

RRFRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, int rankConstant) {
super(childRetrievers, rankWindowSize);
public RRFRetrieverBuilder(
List<RetrieverSource> childRetrievers,
List<String> fields,
String query,
int rankWindowSize,
int rankConstant
) {
// Use a mutable list for childRetrievers so that we can use addChild
super(childRetrievers == null ? new ArrayList<>() : new ArrayList<>(childRetrievers), rankWindowSize);
this.fields = fields == null ? List.of() : List.copyOf(fields);
this.query = query;
this.rankConstant = rankConstant;
}

public int rankConstant() {
return rankConstant;
}

@Override
public String getName() {
return NAME;
}

@Override
public ActionRequestValidationException validate(
SearchSourceBuilder source,
ActionRequestValidationException validationException,
boolean isScroll,
boolean allowPartialSearchResults
) {
validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
return MultiFieldsInnerRetrieverUtils.validateParams(
innerRetrievers,
fields,
query,
getName(),
RETRIEVERS_FIELD.getPreferredName(),
FIELDS_FIELD.getPreferredName(),
QUERY_FIELD.getPreferredName(),
validationException
);
}

@Override
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant);
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.fields, this.query, this.rankWindowSize, this.rankConstant);
clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
clone.retrieverName = retrieverName;
return clone;
Expand Down Expand Up @@ -162,17 +214,72 @@ protected RRFRankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults
return topResults;
}

@Override
protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
RetrieverBuilder rewritten = this;

ResolvedIndices resolvedIndices = ctx.getResolvedIndices();
if (resolvedIndices != null && query != null) {
// TODO: Refactor duplicate code
// Using the multi-fields query format
var localIndicesMetadata = resolvedIndices.getConcreteLocalIndicesMetadata();
if (localIndicesMetadata.size() > 1) {
throw new IllegalArgumentException(
"[" + NAME + "] cannot specify [" + QUERY_FIELD.getPreferredName() + "] when querying multiple indices"
);
} else if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) {
throw new IllegalArgumentException(
"[" + NAME + "] cannot specify [" + QUERY_FIELD.getPreferredName() + "] when querying remote indices"
);
}
Comment on lines +223 to +234
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kderusso I know you requested that we refactor this common code, can we handle that in a follow up along with refactoring the common test code?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, given the timing I'm fine with that happening in a followup


List<RetrieverSource> fieldsInnerRetrievers = MultiFieldsInnerRetrieverUtils.generateInnerRetrievers(
fields,
query,
localIndicesMetadata.values(),
r -> {
List<RetrieverSource> retrievers = r.stream()
.map(MultiFieldsInnerRetrieverUtils.WeightedRetrieverSource::retrieverSource)
.toList();
return new RRFRetrieverBuilder(retrievers, rankWindowSize, rankConstant);
},
w -> {
if (w != 1.0f) {
throw new IllegalArgumentException(
"[" + NAME + "] does not support per-field weights in [" + FIELDS_FIELD.getPreferredName() + "]"
);
}
}
).stream().map(RetrieverSource::from).toList();

if (fieldsInnerRetrievers.isEmpty() == false) {
// TODO: This is a incomplete solution as it does not address other incomplete copy issues
// (such as dropping the retriever name and min score)
rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, rankWindowSize, rankConstant);
rewritten.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
} else {
// Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices
rewritten = new StandardRetrieverBuilder(new MatchNoneQueryBuilder());
}
}

return rewritten;
}

// ---- FOR TESTING XCONTENT PARSING ----

@Override
public boolean doEquals(Object o) {
RRFRetrieverBuilder that = (RRFRetrieverBuilder) o;
return super.doEquals(o) && rankConstant == that.rankConstant;
return super.doEquals(o)
&& Objects.equals(fields, that.fields)
&& Objects.equals(query, that.query)
&& rankConstant == that.rankConstant;
}

@Override
public int doHashCode() {
return Objects.hash(super.doHashCode(), rankConstant);
return Objects.hash(super.doHashCode(), fields, query, rankConstant);
}

@Override
Expand All @@ -186,6 +293,17 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
builder.endArray();
}

if (fields.isEmpty() == false) {
builder.startArray(FIELDS_FIELD.getPreferredName());
for (String field : fields) {
builder.value(field);
}
builder.endArray();
}
if (query != null) {
builder.field(QUERY_FIELD.getPreferredName(), query);
}

builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
builder.field(RANK_CONSTANT_FIELD.getPreferredName(), rankConstant);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.common.Strings;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
Expand Down Expand Up @@ -45,13 +46,22 @@ public static RRFRetrieverBuilder createRandomRRFRetrieverBuilder() {
if (randomBoolean()) {
rankConstant = randomIntBetween(1, 1000000);
}
var ret = new RRFRetrieverBuilder(rankWindowSize, rankConstant);

List<String> fields = null;
String query = null;
if (randomBoolean()) {
fields = randomList(1, 10, () -> randomAlphaOfLengthBetween(1, 10));
query = randomAlphaOfLengthBetween(1, 10);
}

int retrieverCount = randomIntBetween(2, 50);
List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers = new ArrayList<>(retrieverCount);
while (retrieverCount > 0) {
ret.addChild(TestRetrieverBuilder.createRandomTestRetrieverBuilder());
innerRetrievers.add(CompoundRetrieverBuilder.RetrieverSource.from(TestRetrieverBuilder.createRandomTestRetrieverBuilder()));
--retrieverCount;
}
return ret;

return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant);
}

@Override
Expand Down Expand Up @@ -94,28 +104,32 @@ protected NamedXContentRegistry xContentRegistry() {
}

public void testRRFRetrieverParsing() throws IOException {
String restContent = "{"
+ " \"retriever\": {"
+ " \"rrf\": {"
+ " \"retrievers\": ["
+ " {"
+ " \"test\": {"
+ " \"value\": \"foo\""
+ " }"
+ " },"
+ " {"
+ " \"test\": {"
+ " \"value\": \"bar\""
+ " }"
+ " }"
+ " ],"
+ " \"rank_window_size\": 100,"
+ " \"rank_constant\": 10,"
+ " \"min_score\": 20.0,"
+ " \"_name\": \"foo_rrf\""
+ " }"
+ " }"
+ "}";
String restContent = """
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we parse this twice, once with retrievers and once with field/query?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is only for XContent parsing purposes, the resulting retriever does not need to pass SearchRequest validation

{
"retriever": {
"rrf": {
"retrievers": [
{
"test": {
"value": "foo"
}
},
{
"test": {
"value": "bar"
}
}
],
"fields": ["field1", "field2"],
"query": "baz",
"rank_window_size": 100,
"rank_constant": 10,
"min_score": 20.0,
"_name": "foo_rrf"
}
}
}
""";
SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {
SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);
Expand Down
Loading
Loading