Skip to content

Commit 69cd787

Browse files
authored
feat: Update model name to gemini-embedding-001 (#10098)
* feat: Update model name to gemini-embedding-001 * Update to process one text per API call. * fix lints
1 parent a6afe6d commit 69cd787

File tree

2 files changed

+27
-24
lines changed

2 files changed

+27
-24
lines changed

aiplatform/src/main/java/aiplatform/PredictTextEmbeddingsSample.java

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,14 @@ public static void main(String[] args) throws IOException {
4141
// https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings
4242
String endpoint = "us-central1-aiplatform.googleapis.com:443";
4343
String project = "YOUR_PROJECT_ID";
44-
String model = "text-embedding-005";
44+
String model = "gemini-embedding-001";
4545
predictTextEmbeddings(
4646
endpoint,
4747
project,
4848
model,
4949
List.of("banana bread?", "banana muffins?"),
5050
"QUESTION_ANSWERING",
51-
OptionalInt.of(256));
51+
OptionalInt.of(3072));
5252
}
5353

5454
// Gets text embeddings from a pretrained, foundational model.
@@ -67,37 +67,40 @@ public static List<List<Float>> predictTextEmbeddings(
6767
EndpointName endpointName =
6868
EndpointName.ofProjectLocationPublisherModelName(project, location, "google", model);
6969

70+
List<List<Float>> floats = new ArrayList<>();
7071
// You can use this prediction service client for multiple requests.
7172
try (PredictionServiceClient client = PredictionServiceClient.create(settings)) {
72-
PredictRequest.Builder request =
73-
PredictRequest.newBuilder().setEndpoint(endpointName.toString());
74-
if (outputDimensionality.isPresent()) {
75-
request.setParameters(
76-
Value.newBuilder()
77-
.setStructValue(
78-
Struct.newBuilder()
79-
.putFields("outputDimensionality", valueOf(outputDimensionality.getAsInt()))
80-
.build()));
81-
}
73+
// gemini-embedding-001 takes one input at a time.
8274
for (int i = 0; i < texts.size(); i++) {
75+
PredictRequest.Builder request =
76+
PredictRequest.newBuilder().setEndpoint(endpointName.toString());
77+
if (outputDimensionality.isPresent()) {
78+
request.setParameters(
79+
Value.newBuilder()
80+
.setStructValue(
81+
Struct.newBuilder()
82+
.putFields(
83+
"outputDimensionality", valueOf(outputDimensionality.getAsInt()))
84+
.build()));
85+
}
8386
request.addInstances(
8487
Value.newBuilder()
8588
.setStructValue(
8689
Struct.newBuilder()
8790
.putFields("content", valueOf(texts.get(i)))
8891
.putFields("task_type", valueOf(task))
8992
.build()));
90-
}
91-
PredictResponse response = client.predict(request.build());
92-
List<List<Float>> floats = new ArrayList<>();
93-
for (Value prediction : response.getPredictionsList()) {
94-
Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings");
95-
Value values = embeddings.getStructValue().getFieldsOrThrow("values");
96-
floats.add(
97-
values.getListValue().getValuesList().stream()
98-
.map(Value::getNumberValue)
99-
.map(Double::floatValue)
100-
.collect(toList()));
93+
PredictResponse response = client.predict(request.build());
94+
95+
for (Value prediction : response.getPredictionsList()) {
96+
Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings");
97+
Value values = embeddings.getStructValue().getFieldsOrThrow("values");
98+
floats.add(
99+
values.getListValue().getValuesList().stream()
100+
.map(Value::getNumberValue)
101+
.map(Double::floatValue)
102+
.collect(toList()));
103+
}
101104
}
102105
return floats;
103106
}

aiplatform/src/test/java/aiplatform/PredictTextEmbeddingsSampleTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public void testPredictTextEmbeddings() throws IOException {
5252
PredictTextEmbeddingsSample.predictTextEmbeddings(
5353
APIS_ENDPOINT,
5454
PROJECT,
55-
"text-embedding-005",
55+
"gemini-embedding-001",
5656
texts,
5757
"QUESTION_ANSWERING",
5858
OptionalInt.of(5));

0 commit comments

Comments
 (0)