@@ -74,7 +74,7 @@ public class VectorServer {
74
74
@ Getter
75
75
private Map <Integer , TreeMap <String , Integer >> preparedStatementParameters ;
76
76
@ Getter
77
- private TreeMap <Integer , String > preparedQueryString ;
77
+ private TreeMap <Integer , ServingPreparedStatement > orderedServingPreparedStatements ;
78
78
@ Getter
79
79
@ Setter
80
80
private HashSet <String > servingKeys ;
@@ -83,7 +83,7 @@ public class VectorServer {
83
83
private Schema .Parser parser = new Schema .Parser ();
84
84
private FeatureViewApi featureViewApi = new FeatureViewApi ();
85
85
86
- private Map <String , DatumReader <Object >> datumReadersComplexFeatures ;
86
+ private Map <Integer , Map < String , DatumReader <Object >>> featureGroupDatumReaders ;
87
87
private ExecutorService executorService = Executors .newCachedThreadPool ();
88
88
private boolean isBatch = false ;
89
89
private VariablesApi variablesApi = new VariablesApi ();
@@ -110,7 +110,7 @@ public List<Object> getFeatureVector(Map<String, Object> entry)
110
110
List <Object > servingVector = new ArrayList <>();
111
111
List <Future <List <Object >>> queryFutures = new ArrayList <>();
112
112
113
- for (Integer preparedStatementIndex : preparedQueryString .keySet ()) {
113
+ for (Integer preparedStatementIndex : orderedServingPreparedStatements .keySet ()) {
114
114
queryFutures .add (executorService .submit (() -> {
115
115
try {
116
116
return processQuery (entry , preparedStatementIndex );
@@ -136,8 +136,10 @@ private List<Object> processQuery(Map<String, Object> entry, int preparedStateme
136
136
List <Object > servingVector = new ArrayList <>();
137
137
try (Connection connection = hikariDataSource .getConnection ()) {
138
138
// Create the prepared statement
139
- PreparedStatement preparedStatement =
140
- connection .prepareStatement (preparedQueryString .get (preparedStatementIndex ));
139
+ ServingPreparedStatement servingPreparedStatement = orderedServingPreparedStatements .get (preparedStatementIndex );
140
+
141
+ System .out .println (servingPreparedStatement .getQueryOnline ());
142
+ PreparedStatement preparedStatement = connection .prepareStatement (servingPreparedStatement .getQueryOnline ());
141
143
142
144
// Set the parameters base do the entry object
143
145
Map <String , Integer > parameterIndexInStatement = preparedStatementParameters .get (preparedStatementIndex );
@@ -154,12 +156,20 @@ private List<Object> processQuery(Map<String, Object> entry, int preparedStateme
154
156
}
155
157
//Get column count
156
158
int columnCount = results .getMetaData ().getColumnCount ();
159
+
160
+ // get the complex schema datum readers for this feature group
161
+ Map <String , DatumReader <Object >> featuresDatumReaders =
162
+ featureGroupDatumReaders .get (servingPreparedStatement .getFeatureGroupId ());
163
+
157
164
//append results to servingVector
158
165
while (results .next ()) {
159
166
int index = 1 ;
160
167
while (index <= columnCount ) {
161
- if (datumReadersComplexFeatures .containsKey (results .getMetaData ().getColumnName (index ))) {
162
- servingVector .add (deserializeComplexFeature (datumReadersComplexFeatures , results , index ));
168
+ if (featuresDatumReaders != null
169
+ && featuresDatumReaders .containsKey (results .getMetaData ().getColumnName (index ))) {
170
+ servingVector .add (
171
+ deserializeComplexFeature (
172
+ featuresDatumReaders .get (results .getMetaData ().getColumnName (index )), results , index ));
163
173
} else {
164
174
servingVector .add (results .getObject (index ));
165
175
}
@@ -185,8 +195,8 @@ public List<List<Object>> getFeatureVectors(Map<String, List<Object>> entry)
185
195
throws SQLException , FeatureStoreException , IOException {
186
196
checkPrimaryKeys (entry .keySet ());
187
197
List <String > queries = Lists .newArrayList ();
188
- for (Integer fgId : preparedQueryString .keySet ()) {
189
- String query = preparedQueryString .get (fgId );
198
+ for (Integer fgId : orderedServingPreparedStatements .keySet ()) {
199
+ String query = orderedServingPreparedStatements .get (fgId ). getQueryOnline ( );
190
200
String zippedTupleString =
191
201
zipArraysToTupleString (preparedStatementParameters .get (fgId )
192
202
.entrySet ()
@@ -211,6 +221,8 @@ private List<List<Object>> getFeatureVectors(List<String> queries)
211
221
212
222
try (Connection connection = hikariDataSource .getConnection ()) {
213
223
try (Statement stmt = connection .createStatement ()) {
224
+ // Used to reference the ServingPreparedStatement for deserialization
225
+ int statementOrder = 0 ;
214
226
for (String query : queries ) {
215
227
int orderInBatch = 0 ;
216
228
@@ -224,12 +236,21 @@ private List<List<Object>> getFeatureVectors(List<String> queries)
224
236
}
225
237
//Get column count
226
238
int columnCount = results .getMetaData ().getColumnCount ();
239
+
240
+ // get the complex schema datum readers for this feature group
241
+ ServingPreparedStatement servingPreparedStatement = orderedServingPreparedStatements .get (statementOrder );
242
+ Map <String , DatumReader <Object >> featuresDatumReaders =
243
+ featureGroupDatumReaders .get (servingPreparedStatement .getFeatureGroupId ());
244
+
227
245
//append results to servingVector
228
246
while (results .next ()) {
229
247
int index = 1 ;
230
248
while (index <= columnCount ) {
231
- if (datumReadersComplexFeatures .containsKey (results .getMetaData ().getColumnName (index ))) {
232
- servingVector .add (deserializeComplexFeature (datumReadersComplexFeatures , results , index ));
249
+ if (featuresDatumReaders != null
250
+ && featuresDatumReaders .containsKey (results .getMetaData ().getColumnName (index ))) {
251
+ servingVector .add (
252
+ deserializeComplexFeature (
253
+ featuresDatumReaders .get (results .getMetaData ().getColumnName (index )), results , index ));
233
254
} else {
234
255
servingVector .add (results .getObject (index ));
235
256
}
@@ -246,6 +267,7 @@ private List<List<Object>> getFeatureVectors(List<String> queries)
246
267
orderInBatch ++;
247
268
}
248
269
}
270
+ statementOrder ++;
249
271
}
250
272
}
251
273
}
@@ -295,13 +317,13 @@ public void initPreparedStatement(FeatureStoreBase featureStoreBase,
295
317
Map <Integer , TreeMap <String , Integer >> preparedStatementParameters = new HashMap <>();
296
318
297
319
// in case its batch serving then we need to save sql string only
298
- TreeMap <Integer , String > preparedQueryString = new TreeMap <>();
320
+ TreeMap <Integer , ServingPreparedStatement > orderedServingPreparedStatements = new TreeMap <>();
299
321
300
322
// save unique primary key names that will be used by user to retrieve serving vector
301
323
HashSet <String > servingVectorKeys = new HashSet <>();
302
324
for (ServingPreparedStatement servingPreparedStatement : servingPreparedStatements ) {
303
- preparedQueryString .put (servingPreparedStatement .getPreparedStatementIndex (),
304
- servingPreparedStatement . getQueryOnline () );
325
+ orderedServingPreparedStatements .put (servingPreparedStatement .getPreparedStatementIndex (),
326
+ servingPreparedStatement );
305
327
TreeMap <String , Integer > parameterIndices = new TreeMap <>();
306
328
servingPreparedStatement .getPreparedStatementParameters ().forEach (preparedStatementParameter -> {
307
329
servingVectorKeys .add (preparedStatementParameter .getName ());
@@ -312,8 +334,8 @@ public void initPreparedStatement(FeatureStoreBase featureStoreBase,
312
334
this .servingKeys = servingVectorKeys ;
313
335
314
336
this .preparedStatementParameters = preparedStatementParameters ;
315
- this .preparedQueryString = preparedQueryString ;
316
- this .datumReadersComplexFeatures = getComplexFeatureSchemas (features );
337
+ this .orderedServingPreparedStatements = orderedServingPreparedStatements ;
338
+ this .featureGroupDatumReaders = getComplexFeatureSchemas (features );
317
339
}
318
340
319
341
@ VisibleForTesting
@@ -360,25 +382,32 @@ private String zipArraysToTupleString(List<List<Object>> lists) {
360
382
return "(" + String .join ("," , zippedTuples ) + ")" ;
361
383
}
362
384
363
- private Object deserializeComplexFeature (Map < String , DatumReader <Object >> complexFeatureSchemas , ResultSet results ,
385
+ private Object deserializeComplexFeature (DatumReader <Object > featureDatumReader , ResultSet results ,
364
386
int index ) throws SQLException , IOException {
365
387
if (results .getBytes (index ) != null ) {
366
388
Decoder decoder = DecoderFactory .get ().binaryDecoder (results .getBytes (index ), null );
367
- return complexFeatureSchemas . get ( results . getMetaData (). getColumnName ( index )) .read (null , decoder );
389
+ return featureDatumReader .read (null , decoder );
368
390
} else {
369
391
return null ;
370
392
}
371
393
}
372
394
373
395
@ VisibleForTesting
374
- public Map <String , DatumReader <Object >> getComplexFeatureSchemas (List <TrainingDatasetFeature > features )
396
+ public Map <Integer , Map < String , DatumReader <Object > >> getComplexFeatureSchemas (List <TrainingDatasetFeature > features )
375
397
throws FeatureStoreException , IOException {
376
- Map <String , DatumReader <Object >> featureSchemaMap = new HashMap <>();
398
+ Map <Integer , Map < String , DatumReader <Object > >> featureSchemaMap = new HashMap <>();
377
399
for (TrainingDatasetFeature f : features ) {
378
400
if (f .isComplex ()) {
401
+ Map <String , DatumReader <Object >> featureGroupMap = featureSchemaMap .get (f .getFeaturegroup ().getId ());
402
+ if (featureGroupMap == null ) {
403
+ featureGroupMap = new HashMap <>();
404
+ }
405
+
379
406
DatumReader <Object > datumReader =
380
407
new GenericDatumReader <>(parser .parse (f .getFeaturegroup ().getFeatureAvroSchema (f .getName ())));
381
- featureSchemaMap .put (f .getName (), datumReader );
408
+ featureGroupMap .put (f .getName (), datumReader );
409
+
410
+ featureSchemaMap .put (f .getFeaturegroup ().getId (), featureGroupMap );
382
411
}
383
412
}
384
413
return featureSchemaMap ;
0 commit comments