Skip to content

Commit 10d873b

Browse files
authored
[ML] InferenceService support aliases (elastic#128584) (elastic#128595)
"elser" is an alias for "elasticsearch", and "sagemaker" is an alias for "amazon_sagemaker". Users can continue to create and use providers by their alias. Elasticsearch will continue to support the alias when it reads the configuration from the internal index.
1 parent 2934488 commit 10d873b

File tree

10 files changed

+60
-21
lines changed

10 files changed

+60
-21
lines changed

docs/changelog/128584.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 128584
2+
summary: '`InferenceService` support aliases'
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ default void init(Client client) {}
2727

2828
String name();
2929

30+
/**
31+
* The aliases that map to {@link #name()}. {@link InferenceServiceRegistry} allows users to create and use inference services by one
32+
* of their aliases.
33+
*/
34+
default List<String> aliases() {
35+
return List.of();
36+
}
37+
3038
/**
3139
* Parse model configuration from the {@code config map} from a request and return
3240
* the parsed {@link Model}. This requires that both the secrets and service settings be contained in the

server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,22 @@
2424
public class InferenceServiceRegistry implements Closeable {
2525

2626
private final Map<String, InferenceService> services;
27+
private final Map<String, String> aliases;
2728
private final List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
2829

2930
public InferenceServiceRegistry(
3031
List<InferenceServiceExtension> inferenceServicePlugins,
3132
InferenceServiceExtension.InferenceServiceFactoryContext factoryContext
3233
) {
33-
// TODO check names are unique
34+
// toMap verifies that the names and aliases are unique
3435
services = inferenceServicePlugins.stream()
3536
.flatMap(r -> r.getInferenceServiceFactories().stream())
3637
.map(factory -> factory.create(factoryContext))
3738
.collect(Collectors.toMap(InferenceService::name, Function.identity()));
39+
aliases = services.values()
40+
.stream()
41+
.flatMap(service -> service.aliases().stream().distinct().map(alias -> Map.entry(alias, service.name())))
42+
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
3843
}
3944

4045
public void init(Client client) {
@@ -56,13 +61,8 @@ public Map<String, InferenceService> getServices() {
5661
}
5762

5863
public Optional<InferenceService> getService(String serviceName) {
59-
60-
if ("elser".equals(serviceName)) { // ElserService.NAME before removal
61-
// here we are aliasing the elser service to use the elasticsearch service instead
62-
return Optional.ofNullable(services.get("elasticsearch")); // ElasticsearchInternalService.NAME
63-
} else {
64-
return Optional.ofNullable(services.get(serviceName));
65-
}
64+
var serviceKey = aliases.getOrDefault(serviceName, serviceName);
65+
return Optional.ofNullable(services.get(serviceKey));
6666
}
6767

6868
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ public void testDefaultModels() throws IOException {
6565
var rerankModel = getModel(ElasticsearchInternalService.DEFAULT_RERANK_ID);
6666
assertDefaultRerankConfig(rerankModel);
6767

68-
putModel("my-model", mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING));
68+
putModel("my-model", mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING, "streaming_completion_test_service"));
6969
var registeredModels = getMinimalConfigs();
7070
assertThat(registeredModels.size(), equalTo(1));
7171
assertTrue(registeredModels.containsKey("my-model"));

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,12 @@ static String updateConfig(@Nullable TaskType taskTypeInBody, String apiKey, int
119119
""", taskType, apiKey, temperature);
120120
}
121121

122-
static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody) {
122+
static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody, String service) {
123123
var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\",";
124124
return Strings.format("""
125125
{
126126
%s
127-
"service": "streaming_completion_test_service",
127+
"service": "%s",
128128
"service_settings": {
129129
"model": "my_model",
130130
"api_key": "abc64"
@@ -133,7 +133,7 @@ static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody
133133
"temperature": 3
134134
}
135135
}
136-
""", taskType);
136+
""", taskType, service);
137137
}
138138

139139
static String mockSparseServiceModelConfig(@Nullable TaskType taskTypeInBody, boolean shouldReturnHiddenField) {

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ public void testDeleteEndpointWhileReferencedBySemanticTextAndPipeline() throws
305305

306306
public void testUnsupportedStream() throws Exception {
307307
String modelId = "streaming";
308-
putModel(modelId, mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING));
308+
putModel(modelId, mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING, "streaming_completion_test_service"));
309309
var singleModel = getModel(modelId);
310310
assertEquals(modelId, singleModel.get("inference_id"));
311311
assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get("task_type"));
@@ -326,8 +326,16 @@ public void testUnsupportedStream() throws Exception {
326326
}
327327

328328
public void testSupportedStream() throws Exception {
329+
testSupportedStream("streaming_completion_test_service");
330+
}
331+
332+
public void testSupportedStreamForAlias() throws Exception {
333+
testSupportedStream("streaming_completion_test_service_alias");
334+
}
335+
336+
private void testSupportedStream(String serviceName) throws Exception {
329337
String modelId = "streaming";
330-
putModel(modelId, mockCompletionServiceModelConfig(TaskType.COMPLETION));
338+
putModel(modelId, mockCompletionServiceModelConfig(TaskType.COMPLETION, serviceName));
331339
var singleModel = getModel(modelId);
332340
assertEquals(modelId, singleModel.get("inference_id"));
333341
assertEquals(TaskType.COMPLETION.toString(), singleModel.get("task_type"));
@@ -352,7 +360,7 @@ public void testSupportedStream() throws Exception {
352360

353361
public void testUnifiedCompletionInference() throws Exception {
354362
String modelId = "streaming";
355-
putModel(modelId, mockCompletionServiceModelConfig(TaskType.CHAT_COMPLETION));
363+
putModel(modelId, mockCompletionServiceModelConfig(TaskType.CHAT_COMPLETION, "streaming_completion_test_service"));
356364
var singleModel = getModel(modelId);
357365
assertEquals(modelId, singleModel.get("inference_id"));
358366
assertEquals(TaskType.CHAT_COMPLETION.toString(), singleModel.get("task_type"));

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
5454
"text_embedding_test_service",
5555
"voyageai",
5656
"watsonxai",
57-
"sagemaker"
57+
"amazon_sagemaker"
5858
).toArray()
5959
)
6060
);
@@ -93,7 +93,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
9393
"text_embedding_test_service",
9494
"voyageai",
9595
"watsonxai",
96-
"sagemaker"
96+
"amazon_sagemaker"
9797
).toArray()
9898
)
9999
);
@@ -143,7 +143,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
143143
"openai",
144144
"streaming_completion_test_service",
145145
"hugging_face",
146-
"sagemaker"
146+
"amazon_sagemaker"
147147
).toArray()
148148
)
149149
);
@@ -158,7 +158,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
158158
assertThat(
159159
providers,
160160
containsInAnyOrder(
161-
List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face", "sagemaker").toArray()
161+
List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face", "amazon_sagemaker").toArray()
162162
)
163163
);
164164
}

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ public List<Factory> getInferenceServiceFactories() {
6060

6161
public static class TestInferenceService extends AbstractTestInferenceService {
6262
private static final String NAME = "streaming_completion_test_service";
63+
private static final String ALIAS = "streaming_completion_test_service_alias";
6364
private static final Set<TaskType> supportedStreamingTasks = Set.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
6465

6566
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
@@ -75,6 +76,11 @@ public String name() {
7576
return NAME;
7677
}
7778

79+
@Override
80+
public List<String> aliases() {
81+
return List.of(ALIAS);
82+
}
83+
7884
@Override
7985
protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap) {
8086
return TestServiceSettings.fromMap(serviceSettingsMap);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,11 @@ public String name() {
778778
return NAME;
779779
}
780780

781+
@Override
782+
public List<String> aliases() {
783+
return List.of(OLD_ELSER_SERVICE_NAME);
784+
}
785+
781786
private RankedDocsResults textSimilarityResultsToRankedDocs(
782787
List<? extends InferenceResults> results,
783788
Function<Integer, String> inputSupplier,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@
4545
import static org.elasticsearch.xpack.inference.services.ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails;
4646

4747
public class SageMakerService implements InferenceService {
48-
public static final String NAME = "sagemaker";
48+
public static final String NAME = "amazon_sagemaker";
49+
private static final String DISPLAY_NAME = "Amazon SageMaker";
50+
private static final List<String> ALIASES = List.of("sagemaker", "amazonsagemaker");
4951
private static final int DEFAULT_BATCH_SIZE = 256;
5052
private static final TimeValue DEFAULT_TIMEOUT = TimeValue.THIRTY_SECONDS;
5153
private final SageMakerModelBuilder modelBuilder;
@@ -67,7 +69,7 @@ public SageMakerService(
6769
this.threadPool = threadPool;
6870
this.configuration = new LazyInitializable<>(
6971
() -> new InferenceServiceConfiguration.Builder().setService(NAME)
70-
.setName("Amazon SageMaker")
72+
.setName(DISPLAY_NAME)
7173
.setTaskTypes(supportedTaskTypes())
7274
.setConfigurations(configurationMap.get())
7375
.build()
@@ -79,6 +81,11 @@ public String name() {
7981
return NAME;
8082
}
8183

84+
@Override
85+
public List<String> aliases() {
86+
return ALIASES;
87+
}
88+
8289
@Override
8390
public void parseRequestConfig(
8491
String modelId,

0 commit comments

Comments
 (0)