
Commit e90c52a

Update VDMS related components
Signed-off-by: Lacewell, Chaunte W <chaunte.w.lacewell@intel.com>
1 parent eadfe2f commit e90c52a

11 files changed: 265 additions, 86 deletions

comps/dataprep/src/integrations/utils/store_embeddings.py

Lines changed: 9 additions & 6 deletions
@@ -6,10 +6,9 @@
 import numpy as np
 import torchvision.transforms as T
 from decord import VideoReader, cpu
-from langchain.pydantic_v1 import BaseModel, root_validator
-from langchain_community.vectorstores import VDMS
-from langchain_community.vectorstores.vdms import VDMS_Client
 from langchain_core.embeddings import Embeddings
+from langchain_vdms.vectorstores import VDMS, VDMS_Client
+from pydantic import BaseModel, model_validator

 toPIL = T.ToPILImage()

@@ -21,7 +20,7 @@ class vCLIPEmbeddings(BaseModel, Embeddings):

     model: Any

-    @root_validator(allow_reuse=True)
+    @model_validator(mode="before")
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that open_clip and torch libraries are installed."""
         try:
@@ -99,6 +98,8 @@ def __init__(
         collection_name,
         embedding_dimensions: int = 512,
         chosen_video_search_type="similarity",
+        engine: str = "FaissFlat",
+        distance_strategy: str = "IP",
     ):

         self.host = host
@@ -110,6 +111,8 @@ def __init__(
         self.video_embedder = vCLIPEmbeddings(model=video_retriever_model)
         self.chosen_video_search_type = chosen_video_search_type
         self.embedding_dimensions = embedding_dimensions
+        self.engine = engine
+        self.distance_strategy = distance_strategy

         # initialize_db
         self.get_db_client()
@@ -128,7 +131,7 @@ def init_db(self):
             client=self.client,
             embedding=self.video_embedder,
             collection_name=self.video_collection,
-            engine="FaissFlat",
-            distance_strategy="IP",
+            engine=self.engine,
+            distance_strategy=self.distance_strategy,
             embedding_dimensions=self.embedding_dimensions,
         )
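
The swap from root_validator(allow_reuse=True) to model_validator(mode="before") above is the standard Pydantic v1 to v2 migration for whole-model validators. A minimal standalone sketch of that pattern, assuming Pydantic v2 (the class and field names below are illustrative, not part of this commit):

from typing import Any, Dict

from pydantic import BaseModel, model_validator


class EmbedderSettings(BaseModel):  # hypothetical example class
    model: Any

    @model_validator(mode="before")
    @classmethod
    def check_model(cls, values: Dict) -> Dict:
        # mode="before" validators receive the raw input dict, much like v1's root_validator(pre=True)
        if "model" not in values:
            raise ValueError("Model must be provided during initialization.")
        return values


EmbedderSettings(model="clip-wrapper")  # passes validation
# EmbedderSettings()                    # raises a ValidationError wrapping the ValueError above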

comps/dataprep/src/integrations/vdms.py

Lines changed: 1 addition & 1 deletion
@@ -8,8 +8,8 @@
 from fastapi import Body, File, Form, HTTPException, UploadFile
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceInferenceAPIEmbeddings
-from langchain_community.vectorstores.vdms import VDMS, VDMS_Client
 from langchain_text_splitters import HTMLHeaderTextSplitter
+from langchain_vdms.vectorstores import VDMS, VDMS_Client

 from comps import CustomLogger, DocPath, OpeaComponent, OpeaComponentRegistry, ServiceType
 from comps.dataprep.src.utils import (
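
With the VDMS imports now coming from the langchain-vdms package (added to requirements.txt below) rather than langchain_community, the client and store are wired together the same way as before. A minimal sketch under that assumption, using the constructor arguments that appear elsewhere in this commit (the host, port, collection, and embedding-model names are placeholders):

from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_vdms.vectorstores import VDMS, VDMS_Client

embedder = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5")  # placeholder embedding model
client = VDMS_Client(host="localhost", port=55555)  # placeholder host/port
vector_db = VDMS(
    client=client,
    embedding=embedder,
    collection_name="rag-vdms",
    engine="FaissFlat",  # defaults used elsewhere in this commit
    distance_strategy="IP",
)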

comps/dataprep/src/integrations/vdms_multimodal.py

Lines changed: 12 additions & 2 deletions
@@ -23,6 +23,8 @@
 VECTORDB_SERVICE_HOST_IP = os.getenv("VDMS_HOST", "0.0.0.0")
 VECTORDB_SERVICE_PORT = os.getenv("VDMS_PORT", 55555)
 collection_name = os.getenv("INDEX_NAME", "rag-vdms")
+SEARCH_ENGINE = os.getenv("SEARCH_ENGINE", "FaissFlat")
+DISTANCE_STRATEGY = os.getenv("DISTANCE_STRATEGY", "IP")

 logger = CustomLogger("opea_dataprep_vdms_multimodal")
 logflag = os.getenv("LOGFLAG", False)
@@ -72,6 +74,7 @@ def store_into_vectordb(self, vs, metadata_file_path, dimensions):
             metadata_list = [data]
             if vs.selected_db == "vdms":
                 vs.video_db.add_videos(
+                    texts=video_name_list,
                     paths=video_name_list,
                     metadatas=metadata_list,
                     start_time=[data["timestamp"]],
@@ -145,14 +148,21 @@ async def ingest_videos(self, files: List[UploadFile] = File(None)):
         # init meanclip model
         model = self.setup_vclip_model(meanclip_cfg, device="cpu")
         vs = store_embeddings.VideoVS(
-            host, port, selected_db, model, collection_name, embedding_dimensions=vector_dimensions
+            host,
+            port,
+            selected_db,
+            model,
+            collection_name,
+            embedding_dimensions=vector_dimensions,
+            engine=SEARCH_ENGINE,
+            distance_strategy=DISTANCE_STRATEGY,
         )
         logger.info("done creating DB, sleep 5s")
         await asyncio.sleep(5)

         self.generate_embeddings(config, vector_dimensions, vs)

-        return {"message": "Videos ingested successfully"}
+        return {"status": 200, "message": "Videos ingested successfully"}

     async def get_videos(self):
         """Returns list of names of uploaded videos saved on the server."""

comps/dataprep/src/requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -9,7 +9,7 @@ einops
 elasticsearch
 fastapi
 future
-graspologic
+graspologic
 html2text
 huggingface_hub
 ipython
@@ -21,6 +21,7 @@ langchain-openai
 langchain-pinecone
 langchain-redis
 langchain-text-splitters
+langchain-vdms
 langchain_huggingface
 langchain_milvus
 llama-index

comps/retrievers/deployment/docker_compose/compose.yaml

Lines changed: 12 additions & 0 deletions
@@ -179,6 +179,18 @@ services:
       tei-embedding-serving:
         condition: service_healthy

+  retriever-vdms-multimodal:
+    extends: retriever
+    container_name: retriever-vdms-multimodal
+    environment:
+      RETRIEVER_COMPONENT_NAME: "OPEA_RETRIEVER_VDMS"
+      VDMS_INDEX_NAME: ${INDEX_NAME}
+      VDMS_HOST: ${host_ip}
+      VDMS_PORT: ${VDMS_PORT}
+      VDMS_USE_CLIP: ${VDMS_USE_CLIP}
+    depends_on:
+      vdms-vector-db:
+        condition: service_healthy

 networks:
   default:

comps/retrievers/src/integrations/config.py

Lines changed: 2 additions & 2 deletions
@@ -184,5 +184,5 @@ def format_opensearch_conn_from_env():
 VDMS_PORT = int(os.getenv("VDMS_PORT", 55555))
 VDMS_INDEX_NAME = os.getenv("VDMS_INDEX_NAME", "rag_vdms")
 VDMS_USE_CLIP = int(os.getenv("VDMS_USE_CLIP", 0))
-SEARCH_ENGINE = "FaissFlat"
-DISTANCE_STRATEGY = "IP"
+SEARCH_ENGINE = os.getenv("SEARCH_ENGINE", "FaissFlat")
+DISTANCE_STRATEGY = os.getenv("DISTANCE_STRATEGY", "IP")

comps/retrievers/src/integrations/vdms.py

Lines changed: 150 additions & 41 deletions
@@ -3,10 +3,19 @@


 import os
+from typing import Any, Dict, List

-from fastapi import HTTPException
-from langchain_community.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceInferenceAPIEmbeddings
-from langchain_community.vectorstores.vdms import VDMS, VDMS_Client
+import numpy as np
+import torch.nn as nn
+import torchvision.transforms as T
+from decord import VideoReader, cpu
+from einops import rearrange
+from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+from langchain_core.embeddings import Embeddings
+from langchain_vdms.vectorstores import VDMS, VDMS_Client
+from pydantic import BaseModel, model_validator
+from torch import cat as torch_cat
+from transformers import AutoProcessor, AutoTokenizer, CLIPModel

 from comps import CustomLogger, EmbedDoc, OpeaComponent, OpeaComponentRegistry, ServiceType

@@ -24,51 +33,43 @@

 logger = CustomLogger("vdms_retrievers")
 logflag = os.getenv("LOGFLAG", False)
+toPIL = T.ToPILImage()


 @OpeaComponentRegistry.register("OPEA_RETRIEVER_VDMS")
 class OpeaVDMsRetriever(OpeaComponent):
     """A specialized retriever component derived from OpeaComponent for vdms retriever services.

     Attributes:
-        client (VDMs): An instance of the vdms client for vector database operations.
+        client (VDMS): An instance of the vdms client for vector database operations.
     """

     def __init__(self, name: str, description: str, config: dict = None):
         super().__init__(name, ServiceType.RETRIEVER.name.lower(), description, config)

         self.embedder = self._initialize_embedder()
-        self.client = VDMS_Client(VDMS_HOST, VDMS_PORT)
+        self.client = VDMS_Client(host=VDMS_HOST, port=VDMS_PORT)
         self.vector_db = self._initialize_vector_db()
         health_status = self.check_health()
         if not health_status:
             logger.error("OpeaVDMsRetriever health check failed.")

     def _initialize_embedder(self):
         if VDMS_USE_CLIP:
-            from comps.third_parties.clip.src.clip_embedding import vCLIP
+            meanclip_cfg = {
+                "model_name": "openai/clip-vit-base-patch32",
+                "num_frm": 64,
+            }
+            video_retriever_model = vCLIP(meanclip_cfg)  # , device="cpu")
+            embeddings = vCLIPEmbeddings(model=video_retriever_model)

-            embeddings = vCLIP({"model_name": "openai/clip-vit-base-patch32", "num_frm": 64})
         elif TEI_EMBEDDING_ENDPOINT:
             # create embeddings using TEI endpoint service
             if logflag:
                 logger.info(f"[ init embedder ] TEI_EMBEDDING_ENDPOINT:{TEI_EMBEDDING_ENDPOINT}")
-            if not HUGGINGFACEHUB_API_TOKEN:
-                raise HTTPException(
-                    status_code=400,
-                    detail="You MUST offer the `HUGGINGFACEHUB_API_TOKEN` when using `TEI_EMBEDDING_ENDPOINT`.",
-                )
-            import requests
-
-            response = requests.get(TEI_EMBEDDING_ENDPOINT + "/info")
-            if response.status_code != 200:
-                raise HTTPException(
-                    status_code=400, detail=f"TEI embedding endpoint {TEI_EMBEDDING_ENDPOINT} is not available."
-                )
-            model_id = response.json()["model_id"]
-            embeddings = HuggingFaceInferenceAPIEmbeddings(
-                api_key=HUGGINGFACEHUB_API_TOKEN, model_name=model_id, api_url=TEI_EMBEDDING_ENDPOINT
-            )
+            from langchain_huggingface import HuggingFaceEndpointEmbeddings
+
+            embeddings = HuggingFaceEndpointEmbeddings(model=TEI_EMBEDDING_ENDPOINT)
         else:
             # create embeddings using local embedding model
             if logflag:
@@ -78,24 +79,13 @@ def _initialize_embedder(self):

     def _initialize_vector_db(self) -> VDMS:
         """Initializes the vdms client."""
-        if VDMS_USE_CLIP:
-            dimensions = self.embedder.get_embedding_length()
-            vector_db = VDMS(
-                client=self.client,
-                embedding=self.embedder,
-                collection_name=VDMS_INDEX_NAME,
-                embedding_dimensions=dimensions,
-                distance_strategy=DISTANCE_STRATEGY,
-                engine=SEARCH_ENGINE,
-            )
-        else:
-            vector_db = VDMS(
-                client=self.client,
-                embedding=self.embedder,
-                collection_name=VDMS_INDEX_NAME,
-                distance_strategy=DISTANCE_STRATEGY,
-                engine=SEARCH_ENGINE,
-            )
+        vector_db = VDMS(
+            client=self.client,
+            embedding=self.embedder,
+            collection_name=VDMS_INDEX_NAME,
+            distance_strategy=DISTANCE_STRATEGY,
+            engine=SEARCH_ENGINE,
+        )
         return vector_db

     def check_health(self) -> bool:
@@ -154,8 +144,127 @@ async def invoke(self, input: EmbedDoc) -> list:
                 lambda_mult=input.lambda_mult,
                 filter=input.constraints,
             )
+        else:
+            raise ValueError(f"{input.search_type} not valid")

         if logflag:
             logger.info(f"retrieve result: {search_res}")

         return search_res
+
+
+class vCLIPEmbeddings(BaseModel, Embeddings):
+    """MeanCLIP Embeddings model."""
+
+    model: Any
+
+    def get_embedding_length(self):
+        text_features = self.embed_query("sample_text")
+        t_len = len(text_features)
+        logger.info(f"text_features: {t_len}")
+        return t_len
+
+    @model_validator(mode="before")
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that open_clip and torch libraries are installed."""
+        try:
+            # Use the provided model if present
+            if "model" not in values:
+                raise ValueError("Model must be provided during initialization.")
+
+        except ImportError:
+            raise ImportError("Please ensure CLIP model is loaded")
+        return values
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        model_device = next(self.model.clip.parameters()).device
+        text_features = self.model.get_text_embeddings(texts)
+
+        return text_features.detach().numpy()
+
+    def embed_query(self, text: str) -> List[float]:
+        return self.embed_documents([text])[0]
+
+    def embed_video(self, paths: List[str], **kwargs: Any) -> List[List[float]]:
+        # Open images directly as PIL images
+
+        video_features = []
+        for vid_path in sorted(paths):
+            # Encode the video to get the embeddings
+            model_device = next(self.model.parameters()).device
+            # Preprocess the video for the model
+            clip_images = self.load_video_for_vclip(
+                vid_path,
+                num_frm=self.model.num_frm,
+                max_img_size=224,
+                start_time=kwargs.get("start_time", None),
+                clip_duration=kwargs.get("clip_duration", None),
+            )
+            embeddings_tensor = self.model.get_video_embeddings([clip_images])
+
+            # Convert tensor to list and add to the video_features list
+            embeddings_list = embeddings_tensor.tolist()
+
+            video_features.append(embeddings_list)
+
+        return video_features
+
+    def load_video_for_vclip(self, vid_path, num_frm=4, max_img_size=224, **kwargs):
+        # Load video with VideoReader
+        import decord
+
+        decord.bridge.set_bridge("torch")
+        vr = VideoReader(vid_path, ctx=cpu(0))
+        fps = vr.get_avg_fps()
+        num_frames = len(vr)
+        start_idx = int(fps * kwargs.get("start_time", [0])[0])
+        end_idx = start_idx + int(fps * kwargs.get("clip_duration", [num_frames])[0])
+
+        frame_idx = np.linspace(start_idx, end_idx, num=num_frm, endpoint=False, dtype=int)  # Uniform sampling
+        clip_images = []
+
+        # read images
+        temp_frms = vr.get_batch(frame_idx.astype(int).tolist())
+        for idx in range(temp_frms.shape[0]):
+            im = temp_frms[idx]  # H W C
+            clip_images.append(toPIL(im.permute(2, 0, 1)))
+
+        return clip_images
+
+
+class vCLIP(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+
+        self.num_frm = cfg["num_frm"]
+        self.model_name = cfg["model_name"]
+
+        self.clip = CLIPModel.from_pretrained(self.model_name)
+        self.processor = AutoProcessor.from_pretrained(self.model_name)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+    def get_text_embeddings(self, texts):
+        """Input is list of texts."""
+        text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt")
+        text_features = self.clip.get_text_features(**text_inputs)
+        return text_features
+
+    def get_image_embeddings(self, images):
+        """Input is list of images."""
+        image_inputs = self.processor(images=images, return_tensors="pt")
+        image_features = self.clip.get_image_features(**image_inputs)
+        return image_features
+
+    def get_video_embeddings(self, frames_batch):
+        """Input is list of list of frames in video."""
+        self.batch_size = len(frames_batch)
+        vid_embs = []
+        for frames in frames_batch:
+            frame_embeddings = self.get_image_embeddings(frames)
+            frame_embeddings = rearrange(frame_embeddings, "(b n) d -> b n d", b=len(frames_batch))
+            # Normalize, mean aggregate and return normalized video_embeddings
+            frame_embeddings = frame_embeddings / frame_embeddings.norm(dim=-1, keepdim=True)
+            video_embeddings = frame_embeddings.mean(dim=1)
+            video_embeddings = video_embeddings / video_embeddings.norm(dim=-1, keepdim=True)
+            vid_embs.append(video_embeddings)
+        return torch_cat(vid_embs, dim=0)
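
For reference, an illustrative usage sketch of the vCLIP / vCLIPEmbeddings classes added above (not part of the commit; the query string is a placeholder, and the 512-dimension note assumes the openai/clip-vit-base-patch32 checkpoint):

cfg = {"model_name": "openai/clip-vit-base-patch32", "num_frm": 64}
model = vCLIP(cfg)  # wraps a Hugging Face CLIP checkpoint
embedder = vCLIPEmbeddings(model=model)  # LangChain-compatible Embeddings wrapper

query_vec = embedder.embed_query("a person riding a bicycle")
assert len(query_vec) == embedder.get_embedding_length()  # 512 for clip-vit-base-patch32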
