Skip to content

Commit 6aadde2

Browse files
authored
[FSTORE-1190] Attach model to embedding feature (#1481)
* save model to embedding feature * remove modelId * address comment * fix NPE (cherry picked from commit 6e36379)
1 parent 79a7139 commit 6aadde2

File tree

6 files changed

+113
-7
lines changed

6 files changed

+113
-7
lines changed

hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/EmbeddingController.java

+26-2
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,15 @@
2222
import io.hops.hopsworks.common.featurestore.featuregroup.EmbeddingDTO;
2323
import io.hops.hopsworks.common.featurestore.featuregroup.FeaturegroupController;
2424
import io.hops.hopsworks.common.hdfs.Utils;
25+
import io.hops.hopsworks.common.models.ModelFacade;
26+
import io.hops.hopsworks.common.models.version.ModelVersionFacade;
2527
import io.hops.hopsworks.common.util.Settings;
2628
import io.hops.hopsworks.exceptions.FeaturestoreException;
2729
import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Embedding;
2830
import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.EmbeddingFeature;
2931
import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Featuregroup;
32+
import io.hops.hopsworks.persistence.entity.models.Model;
33+
import io.hops.hopsworks.persistence.entity.models.version.ModelVersion;
3034
import io.hops.hopsworks.persistence.entity.project.Project;
3135
import io.hops.hopsworks.restutils.RESTCodes;
3236
import io.hops.hopsworks.vectordb.Index;
@@ -55,6 +59,10 @@ public class EmbeddingController {
5559
private VectorDatabaseClient vectorDatabaseClient;
5660
@EJB
5761
private FeaturegroupController featuregroupController;
62+
@EJB
63+
private ModelVersionFacade modelVersionFacade;
64+
@EJB
65+
private ModelFacade modelFacade;
5866

5967
public void createVectorDbIndex(Project project, Featuregroup featureGroup)
6068
throws FeaturestoreException {
@@ -72,6 +80,11 @@ public void createVectorDbIndex(Project project, Featuregroup featureGroup)
7280
}
7381
}
7482

83+
private ModelVersion getModel(Integer projectId, String modelName, Integer modelVersion) {
84+
Model model = modelFacade.findByProjectIdAndName(projectId, modelName);
85+
return modelVersionFacade.findByProjectAndMlId(model.getId(), modelVersion);
86+
}
87+
7588
public Embedding getEmbedding(Project project, EmbeddingDTO embeddingDTO, Featuregroup featuregroup)
7689
throws FeaturestoreException {
7790
Embedding embedding = new Embedding();
@@ -94,8 +107,19 @@ public Embedding getEmbedding(Project project, EmbeddingDTO embeddingDTO, Featur
94107
embedding.setEmbeddingFeatures(
95108
embeddingDTO.getFeatures()
96109
.stream()
97-
.map(mapping -> new EmbeddingFeature(embedding, mapping.getName(), mapping.getDimension(),
98-
mapping.getSimilarityFunctionType()))
110+
.map(mapping -> {
111+
if (mapping.getModel() != null) {
112+
return new EmbeddingFeature(embedding, mapping.getName(), mapping.getDimension(),
113+
mapping.getSimilarityFunctionType(),
114+
getModel(mapping.getModel().getModelRegistryId(),
115+
mapping.getModel().getModelName(),
116+
mapping.getModel().getModelVersion()));
117+
} else {
118+
return new EmbeddingFeature(embedding, mapping.getName(), mapping.getDimension(),
119+
mapping.getSimilarityFunctionType());
120+
}
121+
}
122+
)
99123
.collect(Collectors.toList())
100124
);
101125
return embedding;

hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/EmbeddingFeatureDTO.java

+10
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,20 @@ public class EmbeddingFeatureDTO {
3333
private String similarityFunctionType;
3434
@Getter
3535
private Integer dimension;
36+
@Getter
37+
private ModelDto model;
38+
3639

3740
public EmbeddingFeatureDTO(EmbeddingFeature feature) {
3841
name = feature.getName();
3942
similarityFunctionType = feature.getSimilarityFunctionType();
4043
dimension = feature.getDimension();
44+
if (feature.getModelVersion() != null) {
45+
model = new ModelDto(
46+
// model registry id is same as project id
47+
feature.getModelVersion().getModel().getProject().getId(),
48+
feature.getModelVersion().getModel().getName(),
49+
feature.getModelVersion().getModelVersionPK().getVersion());
50+
}
4151
}
4252
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* This file is part of Hopsworks
3+
* Copyright (C) 2024, Hopsworks AB. All rights reserved
4+
*
5+
* Hopsworks is free software: you can redistribute it and/or modify it under the terms of
6+
* the GNU Affero General Public License as published by the Free Software Foundation,
7+
* either version 3 of the License, or (at your option) any later version.
8+
*
9+
* Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
10+
* without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11+
* PURPOSE. See the GNU Affero General Public License for more details.
12+
*
13+
* You should have received a copy of the GNU Affero General Public License along with this program.
14+
* If not, see <https://www.gnu.org/licenses/>.
15+
*/
16+
17+
package io.hops.hopsworks.common.featurestore.featuregroup;
18+
19+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
20+
import lombok.AllArgsConstructor;
21+
import lombok.Getter;
22+
import lombok.NoArgsConstructor;
23+
24+
@NoArgsConstructor
25+
@AllArgsConstructor
26+
@JsonIgnoreProperties(ignoreUnknown = true)
27+
public class ModelDto {
28+
29+
@Getter
30+
private Integer modelRegistryId;
31+
@Getter
32+
private String modelName;
33+
@Getter
34+
private Integer modelVersion;
35+
36+
}

hopsworks-common/src/main/java/io/hops/hopsworks/common/models/ModelFacade.java

+10
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,16 @@ public Model findByProjectAndName(Project project, String name) {
8585
}
8686
}
8787

88+
public Model findByProjectIdAndName(Integer projectId, String name) {
89+
TypedQuery<Model> query = em.createNamedQuery("Model.findByProjectIdAndName", Model.class);
90+
query.setParameter("name", name).setParameter("projectId", projectId);
91+
try {
92+
return query.getSingleResult();
93+
} catch (NoResultException e) {
94+
return null;
95+
}
96+
}
97+
8898
public CollectionInfo findByProject(Integer offset, Integer limit,
8999
Set<? extends AbstractFacade.FilterBy> filters,
90100
Set<? extends AbstractFacade.SortBy> sorts, Project project) {

hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/EmbeddingFeature.java

+23
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,23 @@
1616

1717
package io.hops.hopsworks.persistence.entity.featurestore.featuregroup;
1818

19+
import io.hops.hopsworks.persistence.entity.models.version.ModelVersion;
20+
1921
import javax.persistence.Basic;
2022
import javax.persistence.Column;
2123
import javax.persistence.Entity;
2224
import javax.persistence.GeneratedValue;
2325
import javax.persistence.GenerationType;
2426
import javax.persistence.Id;
2527
import javax.persistence.JoinColumn;
28+
import javax.persistence.JoinColumns;
29+
import javax.persistence.OneToOne;
2630
import javax.persistence.Table;
31+
import javax.xml.bind.annotation.XmlRootElement;
2732

2833
@Entity
2934
@Table(name = "embedding_feature", catalog = "hopsworks")
35+
@XmlRootElement
3036
public class EmbeddingFeature {
3137
@Id
3238
@GeneratedValue(strategy = GenerationType.IDENTITY)
@@ -41,6 +47,11 @@ public class EmbeddingFeature {
4147
private Integer dimension;
4248
@Column(name = "similarity_function_type")
4349
private String similarityFunctionType;
50+
@JoinColumns({
51+
@JoinColumn(name = "hsml_model_version", referencedColumnName = "version"),
52+
@JoinColumn(name = "hsml_model_id", referencedColumnName = "model_id")})
53+
@OneToOne
54+
private ModelVersion modelVersion;
4455

4556
public EmbeddingFeature() {
4657
}
@@ -53,6 +64,15 @@ public EmbeddingFeature(Embedding embedding, String name, Integer dimension,
5364
this.similarityFunctionType = similarityFunctionType;
5465
}
5566

67+
public EmbeddingFeature(Embedding embedding, String name, Integer dimension,
68+
String similarityFunctionType, ModelVersion modelVersion) {
69+
this.embedding = embedding;
70+
this.name = name;
71+
this.dimension = dimension;
72+
this.similarityFunctionType = similarityFunctionType;
73+
this.modelVersion = modelVersion;
74+
}
75+
5676
public EmbeddingFeature(Integer id, Embedding embedding, String name, Integer dimension,
5777
String similarityFunctionType) {
5878
this.id = id;
@@ -82,4 +102,7 @@ public String getSimilarityFunctionType() {
82102
return similarityFunctionType;
83103
}
84104

105+
public ModelVersion getModelVersion() {
106+
return modelVersion;
107+
}
85108
}

hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/Model.java

+8-5
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,14 @@
4545
@Table(name = "model", catalog = "hopsworks")
4646
@XmlRootElement
4747
@NamedQueries({
48-
@NamedQuery(name = "Model.findAll",
49-
query = "SELECT m FROM Model m"),
50-
@NamedQuery(name = "Model.findByProjectAndName",
51-
query
52-
= "SELECT m FROM Model m WHERE m.name = :name AND m.project = :project"),})
48+
@NamedQuery(name = "Model.findAll",
49+
query = "SELECT m FROM Model m"),
50+
@NamedQuery(name = "Model.findByProjectAndName",
51+
query
52+
= "SELECT m FROM Model m WHERE m.name = :name AND m.project = :project"),
53+
@NamedQuery(name = "Model.findByProjectIdAndName",
54+
query
55+
= "SELECT m FROM Model m WHERE m.name = :name AND m.project.id = :projectId"),})
5356
public class Model implements Serializable {
5457
private static final long serialVersionUID = 1L;
5558

0 commit comments

Comments
 (0)