@@ -41,14 +41,14 @@ public static void main(String[] args) throws IOException {
41
41
// https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings
42
42
String endpoint = "us-central1-aiplatform.googleapis.com:443" ;
43
43
String project = "YOUR_PROJECT_ID" ;
44
- String model = "text -embedding-005 " ;
44
+ String model = "gemini -embedding-001 " ;
45
45
predictTextEmbeddings (
46
46
endpoint ,
47
47
project ,
48
48
model ,
49
49
List .of ("banana bread?" , "banana muffins?" ),
50
50
"QUESTION_ANSWERING" ,
51
- OptionalInt .of (256 ));
51
+ OptionalInt .of (3072 ));
52
52
}
53
53
54
54
// Gets text embeddings from a pretrained, foundational model.
@@ -67,37 +67,40 @@ public static List<List<Float>> predictTextEmbeddings(
67
67
EndpointName endpointName =
68
68
EndpointName .ofProjectLocationPublisherModelName (project , location , "google" , model );
69
69
70
+ List <List <Float >> floats = new ArrayList <>();
70
71
// You can use this prediction service client for multiple requests.
71
72
try (PredictionServiceClient client = PredictionServiceClient .create (settings )) {
72
- PredictRequest .Builder request =
73
- PredictRequest .newBuilder ().setEndpoint (endpointName .toString ());
74
- if (outputDimensionality .isPresent ()) {
75
- request .setParameters (
76
- Value .newBuilder ()
77
- .setStructValue (
78
- Struct .newBuilder ()
79
- .putFields ("outputDimensionality" , valueOf (outputDimensionality .getAsInt ()))
80
- .build ()));
81
- }
73
+ // gemini-embedding-001 takes one input at a time.
82
74
for (int i = 0 ; i < texts .size (); i ++) {
75
+ PredictRequest .Builder request =
76
+ PredictRequest .newBuilder ().setEndpoint (endpointName .toString ());
77
+ if (outputDimensionality .isPresent ()) {
78
+ request .setParameters (
79
+ Value .newBuilder ()
80
+ .setStructValue (
81
+ Struct .newBuilder ()
82
+ .putFields (
83
+ "outputDimensionality" , valueOf (outputDimensionality .getAsInt ()))
84
+ .build ()));
85
+ }
83
86
request .addInstances (
84
87
Value .newBuilder ()
85
88
.setStructValue (
86
89
Struct .newBuilder ()
87
90
.putFields ("content" , valueOf (texts .get (i )))
88
91
.putFields ("task_type" , valueOf (task ))
89
92
.build ()));
90
- }
91
- PredictResponse response = client . predict ( request . build ());
92
- List < List < Float >> floats = new ArrayList <>();
93
- for ( Value prediction : response . getPredictionsList ()) {
94
- Value embeddings = prediction .getStructValue ().getFieldsOrThrow ("embeddings " );
95
- Value values = embeddings . getStructValue (). getFieldsOrThrow ( "values" );
96
- floats . add (
97
- values . getListValue (). getValuesList (). stream ( )
98
- .map (Value :: getNumberValue )
99
- . map ( Double :: floatValue )
100
- . collect ( toList ()));
93
+ PredictResponse response = client . predict ( request . build ());
94
+
95
+ for ( Value prediction : response . getPredictionsList ()) {
96
+ Value embeddings = prediction . getStructValue (). getFieldsOrThrow ( "embeddings" );
97
+ Value values = embeddings .getStructValue ().getFieldsOrThrow ("values " );
98
+ floats . add (
99
+ values . getListValue (). getValuesList (). stream ()
100
+ . map ( Value :: getNumberValue )
101
+ .map (Double :: floatValue )
102
+ . collect ( toList ()));
103
+ }
101
104
}
102
105
return floats ;
103
106
}
0 commit comments