import os
+ from typing import Any, Dict, List

- from fastapi import HTTPException
- from langchain_community.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceInferenceAPIEmbeddings
+ import numpy as np
+ import torch.nn as nn
+ import torchvision.transforms as T
+ from decord import VideoReader, cpu
+ from einops import rearrange
+ from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+ from langchain_core.embeddings import Embeddings
from langchain_vdms.vectorstores import VDMS, VDMS_Client
+ from pydantic import BaseModel, model_validator
+ from torch import cat as torch_cat
+ from transformers import AutoProcessor, AutoTokenizer, CLIPModel

from comps import CustomLogger, EmbedDoc, OpeaComponent, OpeaComponentRegistry, ServiceType

logger = CustomLogger("vdms_retrievers")
logflag = os.getenv("LOGFLAG", False)
+ toPIL = T.ToPILImage()


@OpeaComponentRegistry.register("OPEA_RETRIEVER_VDMS")
class OpeaVDMsRetriever(OpeaComponent):
    """A specialized retriever component derived from OpeaComponent for vdms retriever services.

    Attributes:
-         client (VDMs): An instance of the vdms client for vector database operations.
+         client (VDMS): An instance of the vdms client for vector database operations.
    """

    def __init__(self, name: str, description: str, config: dict = None):
        super().__init__(name, ServiceType.RETRIEVER.name.lower(), description, config)

        self.embedder = self._initialize_embedder()
-         self.client = VDMS_Client(VDMS_HOST, VDMS_PORT)
+         self.client = VDMS_Client(host=VDMS_HOST, port=VDMS_PORT)
        self.vector_db = self._initialize_vector_db()
        health_status = self.check_health()
        if not health_status:
            logger.error("OpeaVDMsRetriever health check failed.")

    def _initialize_embedder(self):
        if VDMS_USE_CLIP:
-             from comps.third_parties.clip.src.clip_embedding import vCLIP
+             meanclip_cfg = {
+                 "model_name": "openai/clip-vit-base-patch32",
+                 "num_frm": 64,
+             }
+             video_retriever_model = vCLIP(meanclip_cfg)  # , device="cpu")
+             embeddings = vCLIPEmbeddings(model=video_retriever_model)

-             embeddings = vCLIP({"model_name": "openai/clip-vit-base-patch32", "num_frm": 64})
        elif TEI_EMBEDDING_ENDPOINT:
            # create embeddings using TEI endpoint service
            if logflag:
                logger.info(f"[ init embedder ] TEI_EMBEDDING_ENDPOINT:{TEI_EMBEDDING_ENDPOINT}")
-             if not HUGGINGFACEHUB_API_TOKEN:
-                 raise HTTPException(
-                     status_code=400,
-                     detail="You MUST offer the `HUGGINGFACEHUB_API_TOKEN` when using `TEI_EMBEDDING_ENDPOINT`.",
-                 )
-             import requests
-
-             response = requests.get(TEI_EMBEDDING_ENDPOINT + "/info")
-             if response.status_code != 200:
-                 raise HTTPException(
-                     status_code=400, detail=f"TEI embedding endpoint {TEI_EMBEDDING_ENDPOINT} is not available."
-                 )
-             model_id = response.json()["model_id"]
-             embeddings = HuggingFaceInferenceAPIEmbeddings(
-                 api_key=HUGGINGFACEHUB_API_TOKEN, model_name=model_id, api_url=TEI_EMBEDDING_ENDPOINT
-             )
+             from langchain_huggingface import HuggingFaceEndpointEmbeddings
+
+             embeddings = HuggingFaceEndpointEmbeddings(model=TEI_EMBEDDING_ENDPOINT)
        else:
            # create embeddings using local embedding model
            if logflag:
@@ -78,24 +79,13 @@ def _initialize_embedder(self):

    def _initialize_vector_db(self) -> VDMS:
        """Initializes the vdms client."""
-         if VDMS_USE_CLIP:
-             dimensions = self.embedder.get_embedding_length()
-             vector_db = VDMS(
-                 client=self.client,
-                 embedding=self.embedder,
-                 collection_name=VDMS_INDEX_NAME,
-                 embedding_dimensions=dimensions,
-                 distance_strategy=DISTANCE_STRATEGY,
-                 engine=SEARCH_ENGINE,
-             )
-         else:
-             vector_db = VDMS(
-                 client=self.client,
-                 embedding=self.embedder,
-                 collection_name=VDMS_INDEX_NAME,
-                 distance_strategy=DISTANCE_STRATEGY,
-                 engine=SEARCH_ENGINE,
-             )
+         vector_db = VDMS(
+             client=self.client,
+             embedding=self.embedder,
+             collection_name=VDMS_INDEX_NAME,
+             distance_strategy=DISTANCE_STRATEGY,
+             engine=SEARCH_ENGINE,
+         )
        return vector_db

    def check_health(self) -> bool:
@@ -154,8 +144,127 @@ async def invoke(self, input: EmbedDoc) -> list:
                lambda_mult=input.lambda_mult,
                filter=input.constraints,
            )
+         else:
+             raise ValueError(f"{input.search_type} not valid")

        if logflag:
            logger.info(f"retrieve result: {search_res}")

        return search_res
+
+
+ class vCLIPEmbeddings(BaseModel, Embeddings):
+     """MeanCLIP Embeddings model."""
+
+     model: Any
+
+     def get_embedding_length(self):
+         text_features = self.embed_query("sample_text")
+         t_len = len(text_features)
+         logger.info(f"text_features: {t_len}")
+         return t_len
+
+     @model_validator(mode="before")
+     def validate_environment(cls, values: Dict) -> Dict:
+         """Validate that open_clip and torch libraries are installed."""
+         try:
+             # Use the provided model if present
+             if "model" not in values:
+                 raise ValueError("Model must be provided during initialization.")
+
+         except ImportError:
+             raise ImportError("Please ensure CLIP model is loaded")
+         return values
+
+     def embed_documents(self, texts: List[str]) -> List[List[float]]:
+         model_device = next(self.model.clip.parameters()).device
+         text_features = self.model.get_text_embeddings(texts)
+
+         return text_features.detach().numpy()
+
+     def embed_query(self, text: str) -> List[float]:
+         return self.embed_documents([text])[0]
+
+     def embed_video(self, paths: List[str], **kwargs: Any) -> List[List[float]]:
+         # Open images directly as PIL images
+
+         video_features = []
+         for vid_path in sorted(paths):
+             # Encode the video to get the embeddings
+             model_device = next(self.model.parameters()).device
+             # Preprocess the video for the model
+             clip_images = self.load_video_for_vclip(
+                 vid_path,
+                 num_frm=self.model.num_frm,
+                 max_img_size=224,
+                 start_time=kwargs.get("start_time", None),
+                 clip_duration=kwargs.get("clip_duration", None),
+             )
+             embeddings_tensor = self.model.get_video_embeddings([clip_images])
+
+             # Convert tensor to list and add to the video_features list
+             embeddings_list = embeddings_tensor.tolist()
+
+             video_features.append(embeddings_list)
+
+         return video_features
+
+     def load_video_for_vclip(self, vid_path, num_frm=4, max_img_size=224, **kwargs):
+         # Load video with VideoReader
+         import decord
+
+         decord.bridge.set_bridge("torch")
+         vr = VideoReader(vid_path, ctx=cpu(0))
+         fps = vr.get_avg_fps()
+         num_frames = len(vr)
+         start_idx = int(fps * kwargs.get("start_time", [0])[0])
+         end_idx = start_idx + int(fps * kwargs.get("clip_duration", [num_frames])[0])
+
+         frame_idx = np.linspace(start_idx, end_idx, num=num_frm, endpoint=False, dtype=int)  # Uniform sampling
+         clip_images = []
+
+         # read images
+         temp_frms = vr.get_batch(frame_idx.astype(int).tolist())
+         for idx in range(temp_frms.shape[0]):
+             im = temp_frms[idx]  # H W C
+             clip_images.append(toPIL(im.permute(2, 0, 1)))
+
+         return clip_images
+
+
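A minimal sketch of the uniform frame sampling performed by `load_video_for_vclip` above, using hypothetical values (a 30 fps video, 4 frames drawn from an 8-second window starting at t = 4 s); the numbers are illustrative, not part of the change:

```python
import numpy as np

# Hypothetical inputs, not taken from the change above.
fps, num_frm = 30.0, 4
start_time, clip_duration = [4], [8]

start_idx = int(fps * start_time[0])               # 120
end_idx = start_idx + int(fps * clip_duration[0])  # 360
frame_idx = np.linspace(start_idx, end_idx, num=num_frm, endpoint=False, dtype=int)
print(frame_idx)  # [120 180 240 300] -> frames evenly spaced across the clip window
```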
+ class vCLIP(nn.Module):
+     def __init__(self, cfg):
+         super().__init__()
+
+         self.num_frm = cfg["num_frm"]
+         self.model_name = cfg["model_name"]
+
+         self.clip = CLIPModel.from_pretrained(self.model_name)
+         self.processor = AutoProcessor.from_pretrained(self.model_name)
+         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+     def get_text_embeddings(self, texts):
+         """Input is list of texts."""
+         text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt")
+         text_features = self.clip.get_text_features(**text_inputs)
+         return text_features
+
+     def get_image_embeddings(self, images):
+         """Input is list of images."""
+         image_inputs = self.processor(images=images, return_tensors="pt")
+         image_features = self.clip.get_image_features(**image_inputs)
+         return image_features
+
+     def get_video_embeddings(self, frames_batch):
+         """Input is list of list of frames in video."""
+         self.batch_size = len(frames_batch)
+         vid_embs = []
+         for frames in frames_batch:
+             frame_embeddings = self.get_image_embeddings(frames)
+             frame_embeddings = rearrange(frame_embeddings, "(b n) d -> b n d", b=len(frames_batch))
+             # Normalize, mean aggregate and return normalized video_embeddings
+             frame_embeddings = frame_embeddings / frame_embeddings.norm(dim=-1, keepdim=True)
+             video_embeddings = frame_embeddings.mean(dim=1)
+             video_embeddings = video_embeddings / video_embeddings.norm(dim=-1, keepdim=True)
+             vid_embs.append(video_embeddings)
+         return torch_cat(vid_embs, dim=0)
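A minimal usage sketch of the CLIP-based embedder introduced here; the video path and the `start_time`/`clip_duration` values are illustrative assumptions, not part of the change:

```python
# Assumed usage: wire up vCLIP + vCLIPEmbeddings, then embed a text query and a video clip.
meanclip_cfg = {"model_name": "openai/clip-vit-base-patch32", "num_frm": 64}
embedder = vCLIPEmbeddings(model=vCLIP(meanclip_cfg))

# Text query -> one vector in CLIP's shared text/image space.
query_emb = embedder.embed_query("a person riding a bike")
print(len(query_emb))  # e.g. 512 for clip-vit-base-patch32

# Video -> one embedding per path; start_time/clip_duration are single-element lists of
# seconds and must be supplied, since load_video_for_vclip indexes them with [0].
video_embs = embedder.embed_video(
    ["sample_video.mp4"],  # hypothetical local path
    start_time=[0],
    clip_duration=[10],
)
```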