
Commit c477648

Update deprecated Pydantic functions, fix the dataprep vdms_multimodal output, add a VDMS retriever path that works without TEI, make the search engine and distance strategy configurable, fix VDMS retrieval with CLIP, fix the VDMS dataprep test scripts, and fix the VDMS retrieval test script to cover retrieval both with and without CLIP. These changes should help fix opea-project/GenAIExamples#1476.
Signed-off-by: Lacewell, Chaunte W <chaunte.w.lacewell@intel.com>
1 parent 5dba6a4 commit c477648

9 files changed: +258 -81 lines changed

comps/dataprep/src/integrations/utils/store_embeddings.py

Lines changed: 8 additions & 4 deletions
@@ -6,9 +6,9 @@
 import numpy as np
 import torchvision.transforms as T
 from decord import VideoReader, cpu
-from langchain.pydantic_v1 import BaseModel, root_validator
 from langchain_core.embeddings import Embeddings
 from langchain_vdms.vectorstores import VDMS, VDMS_Client
+from pydantic import BaseModel, model_validator
 
 toPIL = T.ToPILImage()
 
@@ -20,7 +20,7 @@ class vCLIPEmbeddings(BaseModel, Embeddings):
 
     model: Any
 
-    @root_validator(allow_reuse=True)
+    @model_validator(mode="before")
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that open_clip and torch libraries are installed."""
         try:
@@ -98,6 +98,8 @@ def __init__(
         collection_name,
         embedding_dimensions: int = 512,
         chosen_video_search_type="similarity",
+        engine: str = "FaissFlat",
+        distance_strategy: str = "IP",
     ):
 
         self.host = host
@@ -109,6 +111,8 @@ def __init__(
         self.video_embedder = vCLIPEmbeddings(model=video_retriever_model)
         self.chosen_video_search_type = chosen_video_search_type
         self.embedding_dimensions = embedding_dimensions
+        self.engine = engine
+        self.distance_strategy = distance_strategy
 
         # initialize_db
         self.get_db_client()
@@ -127,7 +131,7 @@ def init_db(self):
             client=self.client,
             embedding=self.video_embedder,
             collection_name=self.video_collection,
-            engine="FaissFlat",
-            distance_strategy="IP",
+            engine=self.engine,
+            distance_strategy=self.distance_strategy,
             embedding_dimensions=self.embedding_dimensions,
         )
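
The key API change in this file is the move from LangChain's deprecated langchain.pydantic_v1 shim to Pydantic v2: @root_validator(allow_reuse=True) becomes @model_validator(mode="before"). A minimal, self-contained sketch of that pattern (the class name and body here are trimmed illustrations, not the full class from the repo):

from typing import Any, Dict

from pydantic import BaseModel, model_validator


class EmbeddingsConfig(BaseModel):
    """Trimmed illustration of the validator migration shown in the diff."""

    model: Any = None

    # Pydantic v2 replacement for @root_validator(allow_reuse=True): with
    # mode="before", the validator sees the raw input dict before field parsing.
    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Dict:
        if "model" not in values:
            raise ValueError("Model must be provided during initialization.")
        return values


print(EmbeddingsConfig(model=object()).model is not None)  # True

The constructor additions (engine, distance_strategy) are then forwarded to the VDMS vector store in init_db, replacing the previously hard-coded "FaissFlat" and "IP".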

comps/dataprep/src/integrations/vdms_multimodal.py

Lines changed: 11 additions & 2 deletions
@@ -23,6 +23,8 @@
 VECTORDB_SERVICE_HOST_IP = os.getenv("VDMS_HOST", "0.0.0.0")
 VECTORDB_SERVICE_PORT = os.getenv("VDMS_PORT", 55555)
 collection_name = os.getenv("INDEX_NAME", "rag-vdms")
+SEARCH_ENGINE = os.getenv("SEARCH_ENGINE", "FaissFlat")
+DISTANCE_STRATEGY = os.getenv("DISTANCE_STRATEGY", "IP")
 
 logger = CustomLogger("opea_dataprep_vdms_multimodal")
 logflag = os.getenv("LOGFLAG", False)
@@ -145,14 +147,21 @@ async def ingest_videos(self, files: List[UploadFile] = File(None)):
         # init meanclip model
         model = self.setup_vclip_model(meanclip_cfg, device="cpu")
         vs = store_embeddings.VideoVS(
-            host, port, selected_db, model, collection_name, embedding_dimensions=vector_dimensions
+            host,
+            port,
+            selected_db,
+            model,
+            collection_name,
+            embedding_dimensions=vector_dimensions,
+            engine=SEARCH_ENGINE,
+            distance_strategy=DISTANCE_STRATEGY,
         )
         logger.info("done creating DB, sleep 5s")
         await asyncio.sleep(5)
 
         self.generate_embeddings(config, vector_dimensions, vs)
 
-        return {"message": "Videos ingested successfully"}
+        return {"status": 200, "message": "Videos ingested successfully"}
 
     async def get_videos(self):
         """Returns list of names of uploaded videos saved on the server."""

comps/retrievers/deployment/docker_compose/compose.yaml

Lines changed: 12 additions & 0 deletions
@@ -179,6 +179,18 @@ services:
       tei-embedding-serving:
         condition: service_healthy
 
+  retriever-vdms-multimodal:
+    extends: retriever
+    container_name: retriever-vdms-multimodal
+    environment:
+      RETRIEVER_COMPONENT_NAME: "OPEA_RETRIEVER_VDMS"
+      VDMS_INDEX_NAME: ${INDEX_NAME}
+      VDMS_HOST: ${host_ip}
+      VDMS_PORT: ${VDMS_PORT}
+      VDMS_USE_CLIP: ${VDMS_USE_CLIP}
+    depends_on:
+      vdms-vector-db:
+        condition: service_healthy
 
 networks:
   default:
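
Once the new retriever-vdms-multimodal service is up (with VDMS_USE_CLIP controlling whether the CLIP path is used), it can be smoke-tested like any other OPEA retriever. The snippet below is a hypothetical client call: the endpoint path (/v1/retrieval), port (7000), and payload fields are assumptions based on the standard OPEA retriever microservice and should be adjusted to the actual deployment.

import requests

# Placeholder 512-dim query embedding; with VDMS_USE_CLIP=1 the collection is
# built with CLIP ViT-B/32 embeddings, which are 512-dimensional.
payload = {
    "text": "man holding a red basket",
    "embedding": [0.0] * 512,
}

# Assumed endpoint and port for the retriever microservice; adjust as needed.
resp = requests.post("http://localhost:7000/v1/retrieval", json=payload, timeout=30)
print(resp.status_code)
print(resp.json())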

comps/retrievers/src/integrations/config.py

Lines changed: 2 additions & 2 deletions
@@ -184,5 +184,5 @@ def format_opensearch_conn_from_env():
 VDMS_PORT = int(os.getenv("VDMS_PORT", 55555))
 VDMS_INDEX_NAME = os.getenv("VDMS_INDEX_NAME", "rag_vdms")
 VDMS_USE_CLIP = int(os.getenv("VDMS_USE_CLIP", 0))
-SEARCH_ENGINE = "FaissFlat"
-DISTANCE_STRATEGY = "IP"
+SEARCH_ENGINE = os.getenv("SEARCH_ENGINE", "FaissFlat")
+DISTANCE_STRATEGY = os.getenv("DISTANCE_STRATEGY", "IP")

comps/retrievers/src/integrations/vdms.py

Lines changed: 149 additions & 40 deletions
@@ -3,10 +3,19 @@
 
 
 import os
+from typing import Any, Dict, List
 
-from fastapi import HTTPException
-from langchain_community.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceInferenceAPIEmbeddings
+import numpy as np
+import torch.nn as nn
+import torchvision.transforms as T
+from decord import VideoReader, cpu
+from einops import rearrange
+from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+from langchain_core.embeddings import Embeddings
 from langchain_vdms.vectorstores import VDMS, VDMS_Client
+from pydantic import BaseModel, model_validator
+from torch import cat as torch_cat
+from transformers import AutoProcessor, AutoTokenizer, CLIPModel
 
 from comps import CustomLogger, EmbedDoc, OpeaComponent, OpeaComponentRegistry, ServiceType
 
@@ -24,51 +33,43 @@
 
 logger = CustomLogger("vdms_retrievers")
 logflag = os.getenv("LOGFLAG", False)
+toPIL = T.ToPILImage()
 
 
 @OpeaComponentRegistry.register("OPEA_RETRIEVER_VDMS")
 class OpeaVDMsRetriever(OpeaComponent):
     """A specialized retriever component derived from OpeaComponent for vdms retriever services.
 
     Attributes:
-        client (VDMs): An instance of the vdms client for vector database operations.
+        client (VDMS): An instance of the vdms client for vector database operations.
     """
 
     def __init__(self, name: str, description: str, config: dict = None):
         super().__init__(name, ServiceType.RETRIEVER.name.lower(), description, config)
 
         self.embedder = self._initialize_embedder()
-        self.client = VDMS_Client(VDMS_HOST, VDMS_PORT)
+        self.client = VDMS_Client(host=VDMS_HOST, port=VDMS_PORT)
         self.vector_db = self._initialize_vector_db()
         health_status = self.check_health()
         if not health_status:
             logger.error("OpeaVDMsRetriever health check failed.")
 
     def _initialize_embedder(self):
         if VDMS_USE_CLIP:
-            from comps.third_parties.clip.src.clip_embedding import vCLIP
+            meanclip_cfg = {
+                "model_name": "openai/clip-vit-base-patch32",
+                "num_frm": 64,
+            }
+            video_retriever_model = vCLIP(meanclip_cfg)  # , device="cpu")
+            embeddings = vCLIPEmbeddings(model=video_retriever_model)
 
-            embeddings = vCLIP({"model_name": "openai/clip-vit-base-patch32", "num_frm": 64})
         elif TEI_EMBEDDING_ENDPOINT:
             # create embeddings using TEI endpoint service
             if logflag:
                 logger.info(f"[ init embedder ] TEI_EMBEDDING_ENDPOINT:{TEI_EMBEDDING_ENDPOINT}")
-            if not HUGGINGFACEHUB_API_TOKEN:
-                raise HTTPException(
-                    status_code=400,
-                    detail="You MUST offer the `HUGGINGFACEHUB_API_TOKEN` when using `TEI_EMBEDDING_ENDPOINT`.",
-                )
-            import requests
-
-            response = requests.get(TEI_EMBEDDING_ENDPOINT + "/info")
-            if response.status_code != 200:
-                raise HTTPException(
-                    status_code=400, detail=f"TEI embedding endpoint {TEI_EMBEDDING_ENDPOINT} is not available."
-                )
-            model_id = response.json()["model_id"]
-            embeddings = HuggingFaceInferenceAPIEmbeddings(
-                api_key=HUGGINGFACEHUB_API_TOKEN, model_name=model_id, api_url=TEI_EMBEDDING_ENDPOINT
-            )
+            from langchain_huggingface import HuggingFaceEndpointEmbeddings
+
+            embeddings = HuggingFaceEndpointEmbeddings(model=TEI_EMBEDDING_ENDPOINT)
         else:
            # create embeddings using local embedding model
            if logflag:
@@ -78,24 +79,13 @@ def _initialize_embedder(self):
 
     def _initialize_vector_db(self) -> VDMS:
         """Initializes the vdms client."""
-        if VDMS_USE_CLIP:
-            dimensions = self.embedder.get_embedding_length()
-            vector_db = VDMS(
-                client=self.client,
-                embedding=self.embedder,
-                collection_name=VDMS_INDEX_NAME,
-                embedding_dimensions=dimensions,
-                distance_strategy=DISTANCE_STRATEGY,
-                engine=SEARCH_ENGINE,
-            )
-        else:
-            vector_db = VDMS(
-                client=self.client,
-                embedding=self.embedder,
-                collection_name=VDMS_INDEX_NAME,
-                distance_strategy=DISTANCE_STRATEGY,
-                engine=SEARCH_ENGINE,
-            )
+        vector_db = VDMS(
+            client=self.client,
+            embedding=self.embedder,
+            collection_name=VDMS_INDEX_NAME,
+            distance_strategy=DISTANCE_STRATEGY,
+            engine=SEARCH_ENGINE,
+        )
         return vector_db
 
     def check_health(self) -> bool:
@@ -154,8 +144,127 @@ async def invoke(self, input: EmbedDoc) -> list:
                 lambda_mult=input.lambda_mult,
                 filter=input.constraints,
             )
+        else:
+            raise ValueError(f"{input.search_type} not valid")
 
         if logflag:
             logger.info(f"retrieve result: {search_res}")
 
         return search_res
+
+
+class vCLIPEmbeddings(BaseModel, Embeddings):
+    """MeanCLIP Embeddings model."""
+
+    model: Any
+
+    def get_embedding_length(self):
+        text_features = self.embed_query("sample_text")
+        t_len = len(text_features)
+        logger.info(f"text_features: {t_len}")
+        return t_len
+
+    @model_validator(mode="before")
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that open_clip and torch libraries are installed."""
+        try:
+            # Use the provided model if present
+            if "model" not in values:
+                raise ValueError("Model must be provided during initialization.")
+
+        except ImportError:
+            raise ImportError("Please ensure CLIP model is loaded")
+        return values
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        model_device = next(self.model.clip.parameters()).device
+        text_features = self.model.get_text_embeddings(texts)
+
+        return text_features.detach().numpy()
+
+    def embed_query(self, text: str) -> List[float]:
+        return self.embed_documents([text])[0]
+
+    def embed_video(self, paths: List[str], **kwargs: Any) -> List[List[float]]:
+        # Open images directly as PIL images
+
+        video_features = []
+        for vid_path in sorted(paths):
+            # Encode the video to get the embeddings
+            model_device = next(self.model.parameters()).device
+            # Preprocess the video for the model
+            clip_images = self.load_video_for_vclip(
+                vid_path,
+                num_frm=self.model.num_frm,
+                max_img_size=224,
+                start_time=kwargs.get("start_time", None),
+                clip_duration=kwargs.get("clip_duration", None),
+            )
+            embeddings_tensor = self.model.get_video_embeddings([clip_images])
+
+            # Convert tensor to list and add to the video_features list
+            embeddings_list = embeddings_tensor.tolist()
+
+            video_features.append(embeddings_list)
+
+        return video_features
+
+    def load_video_for_vclip(self, vid_path, num_frm=4, max_img_size=224, **kwargs):
+        # Load video with VideoReader
+        import decord
+
+        decord.bridge.set_bridge("torch")
+        vr = VideoReader(vid_path, ctx=cpu(0))
+        fps = vr.get_avg_fps()
+        num_frames = len(vr)
+        start_idx = int(fps * kwargs.get("start_time", [0])[0])
+        end_idx = start_idx + int(fps * kwargs.get("clip_duration", [num_frames])[0])
+
+        frame_idx = np.linspace(start_idx, end_idx, num=num_frm, endpoint=False, dtype=int)  # Uniform sampling
+        clip_images = []
+
+        # read images
+        temp_frms = vr.get_batch(frame_idx.astype(int).tolist())
+        for idx in range(temp_frms.shape[0]):
+            im = temp_frms[idx]  # H W C
+            clip_images.append(toPIL(im.permute(2, 0, 1)))
+
+        return clip_images
+
+
+class vCLIP(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+
+        self.num_frm = cfg["num_frm"]
+        self.model_name = cfg["model_name"]
+
+        self.clip = CLIPModel.from_pretrained(self.model_name)
+        self.processor = AutoProcessor.from_pretrained(self.model_name)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+    def get_text_embeddings(self, texts):
+        """Input is list of texts."""
+        text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt")
+        text_features = self.clip.get_text_features(**text_inputs)
+        return text_features
+
+    def get_image_embeddings(self, images):
+        """Input is list of images."""
+        image_inputs = self.processor(images=images, return_tensors="pt")
+        image_features = self.clip.get_image_features(**image_inputs)
+        return image_features
+
+    def get_video_embeddings(self, frames_batch):
+        """Input is list of list of frames in video."""
+        self.batch_size = len(frames_batch)
+        vid_embs = []
+        for frames in frames_batch:
+            frame_embeddings = self.get_image_embeddings(frames)
+            frame_embeddings = rearrange(frame_embeddings, "(b n) d -> b n d", b=len(frames_batch))
+            # Normalize, mean aggregate and return normalized video_embeddings
+            frame_embeddings = frame_embeddings / frame_embeddings.norm(dim=-1, keepdim=True)
+            video_embeddings = frame_embeddings.mean(dim=1)
+            video_embeddings = video_embeddings / video_embeddings.norm(dim=-1, keepdim=True)
+            vid_embs.append(video_embeddings)
+        return torch_cat(vid_embs, dim=0)
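
The retriever now carries its own copies of the vCLIP wrapper and the vCLIPEmbeddings adapter, so the CLIP path no longer imports from comps.third_parties.clip, and both the CLIP and text-embedding paths share the same collapsed _initialize_vector_db call. A short sketch of exercising those classes directly (the dotted import path is an assumption derived from the file location in this diff; the CLIP weights are downloaded from Hugging Face on first use):

# Assumed import path based on comps/retrievers/src/integrations/vdms.py
from comps.retrievers.src.integrations.vdms import vCLIP, vCLIPEmbeddings

# Same configuration the retriever uses when VDMS_USE_CLIP is set.
clip_model = vCLIP({"model_name": "openai/clip-vit-base-patch32", "num_frm": 64})
embedder = vCLIPEmbeddings(model=clip_model)

# embed_query -> embed_documents -> CLIP text tower; ViT-B/32 text features are
# 512-dimensional, which is what get_embedding_length() reports when sizing the
# VDMS collection.
query_vector = embedder.embed_query("a person riding a bicycle")
print(len(query_vector))  # expected: 512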

comps/retrievers/src/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 aiofiles
 bs4
 cairosvg
+decord
 docarray[full]
 docx2txt
 easyocr

tests/dataprep/test_dataprep_vdms.sh

Lines changed: 2 additions & 2 deletions
@@ -29,8 +29,8 @@ function start_service() {
     export VDMS_HOST=$ip_address
     export VDMS_PORT=55555
     export COLLECTION_NAME="test-comps"
-    export QDRANT_HOST=$ip_address
-    export QDRANT_PORT=$QDRANT_PORT
+    export VDMS_HOST=$ip_address
+    export VDMS_PORT=$VDMS_PORT
     service_name="vdms-vector-db dataprep-vdms"
     cd $WORKPATH/comps/dataprep/deployment/docker_compose/
     docker compose up ${service_name} -d

tests/dataprep/test_dataprep_vdms_multimodal.sh

Lines changed: 2 additions & 2 deletions
@@ -28,8 +28,8 @@ function start_service() {
     export VDMS_HOST=$ip_address
     export VDMS_PORT=55555
     export COLLECTION_NAME="test-comps"
-    export QDRANT_HOST=$ip_address
-    export QDRANT_PORT=$QDRANT_PORT
+    export VDMS_HOST=$ip_address
+    export VDMS_PORT=$VDMS_PORT
     export TAG="comps"
     service_name="vdms-vector-db dataprep-vdms-multimodal"
     cd $WORKPATH/comps/dataprep/deployment/docker_compose/
