Skip to content

Commit ffa786c

Browse files
SirOibaf
authored and
davitbzh
committed
Fix bug with deserialization of complex features with prefix
1 parent c37f249 commit ffa786c

File tree

3 files changed

+67
-25
lines changed

3 files changed

+67
-25
lines changed

java/hsfs/src/main/java/com/logicalclocks/hsfs/constructor/ServingPreparedStatement.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
@NoArgsConstructor
3333
@JsonIgnoreProperties(ignoreUnknown = true)
3434
public class ServingPreparedStatement extends RestDto<ServingPreparedStatement> {
35+
@Getter
36+
@Setter
37+
private Integer featureGroupId;
3538
@Getter
3639
@Setter
3740
private Integer preparedStatementIndex;
@@ -41,4 +44,14 @@ public class ServingPreparedStatement extends RestDto<ServingPreparedStatement>
4144
@Getter
4245
@Setter
4346
private String queryOnline;
47+
@Getter
48+
@Setter
49+
private String prefix;
50+
51+
public ServingPreparedStatement(Integer preparedStatementIndex,
52+
List<PreparedStatementParameter> preparedStatementParameters, String queryOnline) {
53+
this.preparedStatementIndex = preparedStatementIndex;
54+
this.preparedStatementParameters = preparedStatementParameters;
55+
this.queryOnline = queryOnline;
56+
}
4457
}

java/hsfs/src/main/java/com/logicalclocks/hsfs/engine/VectorServer.java

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ public class VectorServer {
7474
@Getter
7575
private Map<Integer, TreeMap<String, Integer>> preparedStatementParameters;
7676
@Getter
77-
private TreeMap<Integer, String> preparedQueryString;
77+
private TreeMap<Integer, ServingPreparedStatement> orderedServingPreparedStatements;
7878
@Getter
7979
@Setter
8080
private HashSet<String> servingKeys;
@@ -83,7 +83,7 @@ public class VectorServer {
8383
private Schema.Parser parser = new Schema.Parser();
8484
private FeatureViewApi featureViewApi = new FeatureViewApi();
8585

86-
private Map<String, DatumReader<Object>> datumReadersComplexFeatures;
86+
private Map<Integer, Map<String, DatumReader<Object>>> featureGroupDatumReaders;
8787
private ExecutorService executorService = Executors.newCachedThreadPool();
8888
private boolean isBatch = false;
8989
private VariablesApi variablesApi = new VariablesApi();
@@ -110,7 +110,7 @@ public List<Object> getFeatureVector(Map<String, Object> entry)
110110
List<Object> servingVector = new ArrayList<>();
111111
List<Future<List<Object>>> queryFutures = new ArrayList<>();
112112

113-
for (Integer preparedStatementIndex : preparedQueryString.keySet()) {
113+
for (Integer preparedStatementIndex : orderedServingPreparedStatements.keySet()) {
114114
queryFutures.add(executorService.submit(() -> {
115115
try {
116116
return processQuery(entry, preparedStatementIndex);
@@ -136,8 +136,10 @@ private List<Object> processQuery(Map<String, Object> entry, int preparedStateme
136136
List<Object> servingVector = new ArrayList<>();
137137
try (Connection connection = hikariDataSource.getConnection()) {
138138
// Create the prepared statement
139-
PreparedStatement preparedStatement =
140-
connection.prepareStatement(preparedQueryString.get(preparedStatementIndex));
139+
ServingPreparedStatement servingPreparedStatement = orderedServingPreparedStatements.get(preparedStatementIndex);
140+
141+
System.out.println(servingPreparedStatement.getQueryOnline());
142+
PreparedStatement preparedStatement = connection.prepareStatement(servingPreparedStatement.getQueryOnline());
141143

142144
// Set the parameters based on the entry object
143145
Map<String, Integer> parameterIndexInStatement = preparedStatementParameters.get(preparedStatementIndex);
@@ -154,12 +156,20 @@ private List<Object> processQuery(Map<String, Object> entry, int preparedStateme
154156
}
155157
//Get column count
156158
int columnCount = results.getMetaData().getColumnCount();
159+
160+
// get the complex schema datum readers for this feature group
161+
Map<String, DatumReader<Object>> featuresDatumReaders =
162+
featureGroupDatumReaders.get(servingPreparedStatement.getFeatureGroupId());
163+
157164
//append results to servingVector
158165
while (results.next()) {
159166
int index = 1;
160167
while (index <= columnCount) {
161-
if (datumReadersComplexFeatures.containsKey(results.getMetaData().getColumnName(index))) {
162-
servingVector.add(deserializeComplexFeature(datumReadersComplexFeatures, results, index));
168+
if (featuresDatumReaders != null
169+
&& featuresDatumReaders.containsKey(results.getMetaData().getColumnName(index))) {
170+
servingVector.add(
171+
deserializeComplexFeature(
172+
featuresDatumReaders.get(results.getMetaData().getColumnName(index)), results, index));
163173
} else {
164174
servingVector.add(results.getObject(index));
165175
}
@@ -185,8 +195,8 @@ public List<List<Object>> getFeatureVectors(Map<String, List<Object>> entry)
185195
throws SQLException, FeatureStoreException, IOException {
186196
checkPrimaryKeys(entry.keySet());
187197
List<String> queries = Lists.newArrayList();
188-
for (Integer fgId : preparedQueryString.keySet()) {
189-
String query = preparedQueryString.get(fgId);
198+
for (Integer fgId : orderedServingPreparedStatements.keySet()) {
199+
String query = orderedServingPreparedStatements.get(fgId).getQueryOnline();
190200
String zippedTupleString =
191201
zipArraysToTupleString(preparedStatementParameters.get(fgId)
192202
.entrySet()
@@ -211,6 +221,8 @@ private List<List<Object>> getFeatureVectors(List<String> queries)
211221

212222
try (Connection connection = hikariDataSource.getConnection()) {
213223
try (Statement stmt = connection.createStatement()) {
224+
// Used to reference the ServingPreparedStatement for deserialization
225+
int statementOrder = 0;
214226
for (String query : queries) {
215227
int orderInBatch = 0;
216228

@@ -224,12 +236,21 @@ private List<List<Object>> getFeatureVectors(List<String> queries)
224236
}
225237
//Get column count
226238
int columnCount = results.getMetaData().getColumnCount();
239+
240+
// get the complex schema datum readers for this feature group
241+
ServingPreparedStatement servingPreparedStatement = orderedServingPreparedStatements.get(statementOrder);
242+
Map<String, DatumReader<Object>> featuresDatumReaders =
243+
featureGroupDatumReaders.get(servingPreparedStatement.getFeatureGroupId());
244+
227245
//append results to servingVector
228246
while (results.next()) {
229247
int index = 1;
230248
while (index <= columnCount) {
231-
if (datumReadersComplexFeatures.containsKey(results.getMetaData().getColumnName(index))) {
232-
servingVector.add(deserializeComplexFeature(datumReadersComplexFeatures, results, index));
249+
if (featuresDatumReaders != null
250+
&& featuresDatumReaders.containsKey(results.getMetaData().getColumnName(index))) {
251+
servingVector.add(
252+
deserializeComplexFeature(
253+
featuresDatumReaders.get(results.getMetaData().getColumnName(index)), results, index));
233254
} else {
234255
servingVector.add(results.getObject(index));
235256
}
@@ -246,6 +267,7 @@ private List<List<Object>> getFeatureVectors(List<String> queries)
246267
orderInBatch++;
247268
}
248269
}
270+
statementOrder++;
249271
}
250272
}
251273
}
@@ -295,13 +317,13 @@ public void initPreparedStatement(FeatureStoreBase featureStoreBase,
295317
Map<Integer, TreeMap<String, Integer>> preparedStatementParameters = new HashMap<>();
296318

297319
// in case it's batch serving then we need to save the SQL string only
298-
TreeMap<Integer, String> preparedQueryString = new TreeMap<>();
320+
TreeMap<Integer, ServingPreparedStatement> orderedServingPreparedStatements = new TreeMap<>();
299321

300322
// save unique primary key names that will be used by user to retrieve serving vector
301323
HashSet<String> servingVectorKeys = new HashSet<>();
302324
for (ServingPreparedStatement servingPreparedStatement : servingPreparedStatements) {
303-
preparedQueryString.put(servingPreparedStatement.getPreparedStatementIndex(),
304-
servingPreparedStatement.getQueryOnline());
325+
orderedServingPreparedStatements.put(servingPreparedStatement.getPreparedStatementIndex(),
326+
servingPreparedStatement);
305327
TreeMap<String, Integer> parameterIndices = new TreeMap<>();
306328
servingPreparedStatement.getPreparedStatementParameters().forEach(preparedStatementParameter -> {
307329
servingVectorKeys.add(preparedStatementParameter.getName());
@@ -312,8 +334,8 @@ public void initPreparedStatement(FeatureStoreBase featureStoreBase,
312334
this.servingKeys = servingVectorKeys;
313335

314336
this.preparedStatementParameters = preparedStatementParameters;
315-
this.preparedQueryString = preparedQueryString;
316-
this.datumReadersComplexFeatures = getComplexFeatureSchemas(features);
337+
this.orderedServingPreparedStatements = orderedServingPreparedStatements;
338+
this.featureGroupDatumReaders = getComplexFeatureSchemas(features);
317339
}
318340

319341
@VisibleForTesting
@@ -360,25 +382,32 @@ private String zipArraysToTupleString(List<List<Object>> lists) {
360382
return "(" + String.join(",", zippedTuples) + ")";
361383
}
362384

363-
private Object deserializeComplexFeature(Map<String, DatumReader<Object>> complexFeatureSchemas, ResultSet results,
385+
private Object deserializeComplexFeature(DatumReader<Object> featureDatumReader, ResultSet results,
364386
int index) throws SQLException, IOException {
365387
if (results.getBytes(index) != null) {
366388
Decoder decoder = DecoderFactory.get().binaryDecoder(results.getBytes(index), null);
367-
return complexFeatureSchemas.get(results.getMetaData().getColumnName(index)).read(null, decoder);
389+
return featureDatumReader.read(null, decoder);
368390
} else {
369391
return null;
370392
}
371393
}
372394

373395
@VisibleForTesting
374-
public Map<String, DatumReader<Object>> getComplexFeatureSchemas(List<TrainingDatasetFeature> features)
396+
public Map<Integer, Map<String, DatumReader<Object>>> getComplexFeatureSchemas(List<TrainingDatasetFeature> features)
375397
throws FeatureStoreException, IOException {
376-
Map<String, DatumReader<Object>> featureSchemaMap = new HashMap<>();
398+
Map<Integer, Map<String, DatumReader<Object>>> featureSchemaMap = new HashMap<>();
377399
for (TrainingDatasetFeature f : features) {
378400
if (f.isComplex()) {
401+
Map<String, DatumReader<Object>> featureGroupMap = featureSchemaMap.get(f.getFeaturegroup().getId());
402+
if (featureGroupMap == null) {
403+
featureGroupMap = new HashMap<>();
404+
}
405+
379406
DatumReader<Object> datumReader =
380407
new GenericDatumReader<>(parser.parse(f.getFeaturegroup().getFeatureAvroSchema(f.getName())));
381-
featureSchemaMap.put(f.getName(), datumReader);
408+
featureGroupMap.put(f.getName(), datumReader);
409+
410+
featureSchemaMap.put(f.getFeaturegroup().getId(), featureGroupMap);
382411
}
383412
}
384413
return featureSchemaMap;

java/hsfs/src/test/java/com/logicalclocks/hsfs/engine/TestVectorServer.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public void testQueryOrder() throws Exception {
4343
VectorServer vectorServer = Mockito.mock(VectorServer.class);
4444
Mockito.doCallRealMethod().when(vectorServer).initPreparedStatement(Mockito.any(), Mockito.any(),
4545
Mockito.any(), Mockito.anyBoolean(), Mockito.anyBoolean());
46-
Mockito.when(vectorServer.getPreparedQueryString()).thenCallRealMethod();
46+
Mockito.when(vectorServer.getOrderedServingPreparedStatements()).thenCallRealMethod();
4747
Mockito.when(vectorServer.getPreparedStatementParameters()).thenCallRealMethod();
4848

4949
FeatureStoreBase featureStoreBase = Mockito.mock(FeatureStoreBase.class);
@@ -61,9 +61,9 @@ public void testQueryOrder() throws Exception {
6161

6262
vectorServer.initPreparedStatement(featureStoreBase, null, servingPreparedStatements, false, true);
6363

64-
Map<Integer,String> queries = vectorServer.getPreparedQueryString();
65-
Assert.assertEquals("SELECT * FROM table_0 WHERE pk=? AND id=?", queries.get(0));
66-
Assert.assertEquals("SELECT * FROM table_1 WHERE id=?", queries.get(1));
64+
Map<Integer, ServingPreparedStatement> queries = vectorServer.getOrderedServingPreparedStatements();
65+
Assert.assertEquals("SELECT * FROM table_0 WHERE pk=? AND id=?", queries.get(0).getQueryOnline());
66+
Assert.assertEquals("SELECT * FROM table_1 WHERE id=?", queries.get(1).getQueryOnline());
6767

6868
TreeMap<String, Integer> preparedStatementParameters = vectorServer.getPreparedStatementParameters().get(0);
6969
Assert.assertEquals((Integer)0, preparedStatementParameters.get("pk"));

0 commit comments

Comments
 (0)