22
22
import io .hops .hopsworks .common .featurestore .featuregroup .EmbeddingDTO ;
23
23
import io .hops .hopsworks .common .featurestore .featuregroup .FeaturegroupController ;
24
24
import io .hops .hopsworks .common .hdfs .Utils ;
25
+ import io .hops .hopsworks .common .models .ModelFacade ;
26
+ import io .hops .hopsworks .common .models .version .ModelVersionFacade ;
25
27
import io .hops .hopsworks .common .util .Settings ;
26
28
import io .hops .hopsworks .exceptions .FeaturestoreException ;
27
29
import io .hops .hopsworks .persistence .entity .featurestore .featuregroup .Embedding ;
28
30
import io .hops .hopsworks .persistence .entity .featurestore .featuregroup .EmbeddingFeature ;
29
31
import io .hops .hopsworks .persistence .entity .featurestore .featuregroup .Featuregroup ;
32
+ import io .hops .hopsworks .persistence .entity .models .Model ;
33
+ import io .hops .hopsworks .persistence .entity .models .version .ModelVersion ;
30
34
import io .hops .hopsworks .persistence .entity .project .Project ;
31
35
import io .hops .hopsworks .restutils .RESTCodes ;
32
36
import io .hops .hopsworks .vectordb .Index ;
@@ -55,6 +59,10 @@ public class EmbeddingController {
55
59
private VectorDatabaseClient vectorDatabaseClient ;
56
60
@ EJB
57
61
private FeaturegroupController featuregroupController ;
62
+ @ EJB
63
+ private ModelVersionFacade modelVersionFacade ;
64
+ @ EJB
65
+ private ModelFacade modelFacade ;
58
66
59
67
public void createVectorDbIndex (Project project , Featuregroup featureGroup )
60
68
throws FeaturestoreException {
@@ -72,6 +80,11 @@ public void createVectorDbIndex(Project project, Featuregroup featureGroup)
72
80
}
73
81
}
74
82
83
+ private ModelVersion getModel (Integer projectId , String modelName , Integer modelVersion ) {
84
+ Model model = modelFacade .findByProjectIdAndName (projectId , modelName );
85
+ return modelVersionFacade .findByProjectAndMlId (model .getId (), modelVersion );
86
+ }
87
+
75
88
public Embedding getEmbedding (Project project , EmbeddingDTO embeddingDTO , Featuregroup featuregroup )
76
89
throws FeaturestoreException {
77
90
Embedding embedding = new Embedding ();
@@ -94,8 +107,19 @@ public Embedding getEmbedding(Project project, EmbeddingDTO embeddingDTO, Featur
94
107
embedding .setEmbeddingFeatures (
95
108
embeddingDTO .getFeatures ()
96
109
.stream ()
97
- .map (mapping -> new EmbeddingFeature (embedding , mapping .getName (), mapping .getDimension (),
98
- mapping .getSimilarityFunctionType ()))
110
+ .map (mapping -> {
111
+ if (mapping .getModel () != null ) {
112
+ return new EmbeddingFeature (embedding , mapping .getName (), mapping .getDimension (),
113
+ mapping .getSimilarityFunctionType (),
114
+ getModel (mapping .getModel ().getModelRegistryId (),
115
+ mapping .getModel ().getModelName (),
116
+ mapping .getModel ().getModelVersion ()));
117
+ } else {
118
+ return new EmbeddingFeature (embedding , mapping .getName (), mapping .getDimension (),
119
+ mapping .getSimilarityFunctionType ());
120
+ }
121
+ }
122
+ )
99
123
.collect (Collectors .toList ())
100
124
);
101
125
return embedding ;
0 commit comments