Commit 9836332

fix bugs in DocIndexRetriever (opea-project#1770)
Signed-off-by: minmin-intel <minmin.hou@intel.com>
Signed-off-by: Chingis Yundunov <c.yundunov@datamonsters.com>
1 parent 54a6525 commit 9836332

12 files changed (+155, -114 lines)

DocIndexRetriever/README.md

Lines changed: 14 additions & 1 deletion
````diff
@@ -80,9 +80,22 @@ Example usage:
 ```python
 url = "http://{host_ip}:{port}/v1/retrievaltool".format(host_ip=host_ip, port=port)
 payload = {
-    "messages": query,
+    "messages": query,  # must be a string, this is a required field
     "k": 5,  # retriever top k
     "top_n": 2,  # reranker top n
 }
 response = requests.post(url, json=payload)
 ```
+
+**Note**: `messages` is the required field. You can also pass in parameters for the retriever and reranker in the request. The parameters that can be changed are listed below.
+
+1. retriever
+   * search_type: str = "similarity"
+   * k: int = 4
+   * distance_threshold: Optional[float] = None
+   * fetch_k: int = 20
+   * lambda_mult: float = 0.5
+   * score_threshold: float = 0.2
+
+2. reranker
+   * top_n: int = 1
````

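For reference, a request that overrides both the retriever and the reranker could look like the sketch below — untested, with the host and the override values chosen purely for illustration; the endpoint and the parameter names come from the excerpt above:

```python
import requests

host_ip = "localhost"  # assumption: the megaservice gateway runs locally
port = 8889

url = f"http://{host_ip}:{port}/v1/retrievaltool"
payload = {
    "messages": "Explain the OPEA project?",  # required; must be a string
    # retriever overrides (defaults per the list above)
    "search_type": "similarity",
    "k": 8,
    "fetch_k": 20,
    # reranker override
    "top_n": 3,
}
response = requests.post(url, json=payload, timeout=60)
print(response.json())
```
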
DocIndexRetriever/docker_compose/intel/cpu/xeon/README.md

Lines changed: 1 addition & 4 deletions
````diff
@@ -97,9 +97,6 @@ Retrieval from KnowledgeBase
 curl http://${host_ip}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{
     "messages": "Explain the OPEA project?"
     }'
-
-# expected output
-{"id":"354e62c703caac8c547b3061433ec5e8","reranked_docs":[{"id":"06d5a5cefc06cf9a9e0b5fa74a9f233c","text":"Close SearchsearchMenu WikiNewsCommunity Daysx-twitter linkedin github searchStreamlining implementation of enterprise-grade Generative AIEfficiently integrate secure, performant, and cost-effective Generative AI workflows into business value.TODAYOPEA..."}],"initial_query":"Explain the OPEA project?"}
 ```

 **Note**: `messages` is the required field. You can also pass in parameters for the retriever and reranker in the request. The parameters that can be changed are listed below.
@@ -128,7 +125,7 @@ curl http://${host_ip}:8889/v1/retrievaltool -X POST -H "Content-Type: applicati
 # embedding microservice
 curl http://${host_ip}:6000/v1/embeddings \
   -X POST \
-  -d '{"text":"Explain the OPEA project"}' \
+  -d '{"messages":"Explain the OPEA project"}' \
   -H 'Content-Type: application/json' > query
 docker container logs embedding-server
````

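The payload-key fix in the second hunk (`text` → `messages`) can be exercised in isolation; a Python equivalent of the corrected curl call, assuming the embedding wrapper is reachable on port 6000 as in the compose file:

```python
import requests

host_ip = "localhost"  # assumption: same host as in the curl example
resp = requests.post(
    f"http://{host_ip}:6000/v1/embeddings",
    json={"messages": "Explain the OPEA project"},  # key renamed from "text" in this commit
    headers={"Content-Type": "application/json"},
    timeout=60,
)
print(resp.json())
```
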
DocIndexRetriever/docker_compose/intel/cpu/xeon/compose.yaml

Lines changed: 5 additions & 8 deletions
````diff
@@ -13,10 +13,11 @@ services:
   dataprep-redis-service:
     image: ${REGISTRY:-opea}/dataprep:${TAG:-latest}
     container_name: dataprep-redis-server
-    # volumes:
-    #   - $WORKDIR/GenAIExamples/DocIndexRetriever/docker_image_build/GenAIComps/comps:/home/user/comps
     depends_on:
-      - redis-vector-db
+      redis-vector-db:
+        condition: service_started
+      tei-embedding-service:
+        condition: service_healthy
     ports:
       - "6007:5000"
       - "6008:6008"
@@ -28,7 +29,7 @@ services:
       REDIS_URL: ${REDIS_URL}
       REDIS_HOST: ${REDIS_HOST}
       INDEX_NAME: ${INDEX_NAME}
-      TEI_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT}
+      TEI_EMBEDDING_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT}
       HUGGINGFACEHUB_API_TOKEN: ${HUGGINGFACEHUB_API_TOKEN}
       LOGFLAG: ${LOGFLAG}
   tei-embedding-service:
@@ -54,8 +55,6 @@ services:
   embedding:
     image: ${REGISTRY:-opea}/embedding:${TAG:-latest}
     container_name: embedding-server
-    # volumes:
-    #   - $WORKDIR/GenAIExamples/DocIndexRetriever/docker_image_build/GenAIComps/comps:/home/comps
     ports:
       - "6000:6000"
     ipc: host
@@ -114,8 +113,6 @@ services:
   reranking:
     image: ${REGISTRY:-opea}/reranking:${TAG:-latest}
     container_name: reranking-tei-xeon-server
-    # volumes:
-    #   - $WORKDIR/GenAIExamples/DocIndexRetriever/docker_image_build/GenAIComps/comps:/home/user/comps
     depends_on:
       tei-reranking-service:
         condition: service_healthy
````

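Two fixes are visible here besides dropping the commented-out development volume mounts: `depends_on` moves from the list form to the mapping form, which is what allows per-dependency conditions (`service_healthy` relies on a `healthcheck` on `tei-embedding-service`, the same mechanism the existing `tei-reranking-service` dependency already uses), and `TEI_ENDPOINT` is renamed to `TEI_EMBEDDING_ENDPOINT`, presumably the variable name the current dataprep component actually reads.
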
DocIndexRetriever/docker_compose/intel/hpu/gaudi/README.md

Lines changed: 1 addition & 4 deletions
````diff
@@ -87,9 +87,6 @@ Retrieval from KnowledgeBase
 curl http://${host_ip}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{
     "messages": "Explain the OPEA project?"
     }'
-
-# expected output
-{"id":"354e62c703caac8c547b3061433ec5e8","reranked_docs":[{"id":"06d5a5cefc06cf9a9e0b5fa74a9f233c","text":"Close SearchsearchMenu WikiNewsCommunity Daysx-twitter linkedin github searchStreamlining implementation of enterprise-grade Generative AIEfficiently integrate secure, performant, and cost-effective Generative AI workflows into business value.TODAYOPEA..."}],"initial_query":"Explain the OPEA project?"}
 ```

 **Note**: `messages` is the required field. You can also pass in parameters for the retriever and reranker in the request. The parameters that can be changed are listed below.
@@ -118,7 +115,7 @@ curl http://${host_ip}:8889/v1/retrievaltool -X POST -H "Content-Type: applicati
 # embedding microservice
 curl http://${host_ip}:6000/v1/embeddings \
   -X POST \
-  -d '{"text":"Explain the OPEA project"}' \
+  -d '{"messages":"Explain the OPEA project"}' \
   -H 'Content-Type: application/json' > query
 docker container logs embedding-server
````

DocIndexRetriever/docker_compose/intel/hpu/gaudi/compose.yaml

Lines changed: 7 additions & 3 deletions
````diff
@@ -15,8 +15,10 @@ services:
     image: ${REGISTRY:-opea}/dataprep:${TAG:-latest}
     container_name: dataprep-redis-server
     depends_on:
-      - redis-vector-db
-      - tei-embedding-service
+      redis-vector-db:
+        condition: service_started
+      tei-embedding-service:
+        condition: service_healthy
     ports:
       - "6007:5000"
     environment:
@@ -25,7 +27,7 @@ services:
       https_proxy: ${https_proxy}
       REDIS_URL: ${REDIS_URL}
       INDEX_NAME: ${INDEX_NAME}
-      TEI_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT}
+      TEI_EMBEDDING_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT}
       HUGGINGFACEHUB_API_TOKEN: ${HUGGINGFACEHUB_API_TOKEN}
   tei-embedding-service:
     image: ghcr.io/huggingface/tei-gaudi:1.5.0
@@ -87,6 +89,8 @@ services:
       INDEX_NAME: ${INDEX_NAME}
       LOGFLAG: ${LOGFLAG}
       RETRIEVER_COMPONENT_NAME: "OPEA_RETRIEVER_REDIS"
+      TEI_EMBEDDING_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT}
+      HUGGINGFACEHUB_API_TOKEN: ${HUGGINGFACEHUB_API_TOKEN}
     restart: unless-stopped
   tei-reranking-service:
     image: ghcr.io/huggingface/text-embeddings-inference:cpu-1.6
````

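On Gaudi, the retriever service (identifiable by the `RETRIEVER_COMPONENT_NAME: "OPEA_RETRIEVER_REDIS"` context lines) additionally receives the TEI embedding endpoint and the HuggingFace token, which the redis retriever component presumably needs in order to reach the embedding server on its own.
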
DocIndexRetriever/retrieval_tool.py

Lines changed: 94 additions & 74 deletions
````diff
@@ -8,9 +8,8 @@

 from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType
 from comps.cores.proto.api_protocol import ChatCompletionRequest, EmbeddingRequest
-from comps.cores.proto.docarray import LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc
+from comps.cores.proto.docarray import LLMParams, LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc
 from fastapi import Request
-from fastapi.responses import StreamingResponse

 MEGA_SERVICE_PORT = os.getenv("MEGA_SERVICE_PORT", 8889)
 EMBEDDING_SERVICE_HOST_IP = os.getenv("EMBEDDING_SERVICE_HOST_IP", "0.0.0.0")
@@ -22,41 +21,75 @@


 def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
-    print(f"Inputs to {cur_node}: {inputs}")
+    print(f"*** Inputs to {cur_node}:\n{inputs}")
+    print("--" * 50)
     for key, value in kwargs.items():
         print(f"{key}: {value}")
+    if self.services[cur_node].service_type == ServiceType.EMBEDDING:
+        inputs["input"] = inputs["text"]
+        del inputs["text"]
+    elif self.services[cur_node].service_type == ServiceType.RETRIEVER:
+        # input is EmbedDoc
+        """Class EmbedDoc(BaseDoc):
+
+        text: Union[str, List[str]]
+        embedding: Union[conlist(float, min_length=0), List[conlist(float, min_length=0)]]
+        search_type: str = "similarity"
+        k: int = 4
+        distance_threshold: Optional[float] = None
+        fetch_k: int = 20
+        lambda_mult: float = 0.5
+        score_threshold: float = 0.2
+        constraints: Optional[Union[Dict[str, Any], List[Dict[str, Any]], None]] = None
+        index_name: Optional[str] = None
+        """
+        # prepare the retriever params
+        retriever_parameters = kwargs.get("retriever_parameters", None)
+        if retriever_parameters:
+            inputs.update(retriever_parameters.dict())
+    elif self.services[cur_node].service_type == ServiceType.RERANK:
+        # input is SearchedDoc
+        """Class SearchedDoc(BaseDoc):
+
+        retrieved_docs: DocList[TextDoc]
+        initial_query: str
+        top_n: int = 1
+        """
+        # prepare the reranker params
+        reranker_parameters = kwargs.get("reranker_parameters", None)
+        if reranker_parameters:
+            inputs.update(reranker_parameters.dict())
+    print(f"*** Formatted Inputs to {cur_node}:\n{inputs}")
+    print("--" * 50)
     return inputs


 def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs):
-    next_data = {}
-    if self.services[cur_node].service_type == ServiceType.EMBEDDING:
-        # turn into chat completion request
-        # next_data = {"text": inputs["input"], "embedding": [item["embedding"] for item in data["data"]]}
-        print("Assembing output from Embedding for next node...")
-        print("Inputs to Embedding: ", inputs)
-        print("Keyword arguments: ")
-        for key, value in kwargs.items():
-            print(f"{key}: {value}")
-
-        next_data = {
-            "input": inputs["input"],
-            "messages": inputs["input"],
-            "embedding": [item["embedding"] for item in data["data"]],
-            "k": kwargs["k"] if "k" in kwargs else 4,
-            "search_type": kwargs["search_type"] if "search_type" in kwargs else "similarity",
-            "distance_threshold": kwargs["distance_threshold"] if "distance_threshold" in kwargs else None,
-            "fetch_k": kwargs["fetch_k"] if "fetch_k" in kwargs else 20,
-            "lambda_mult": kwargs["lambda_mult"] if "lambda_mult" in kwargs else 0.5,
-            "score_threshold": kwargs["score_threshold"] if "score_threshold" in kwargs else 0.2,
-            "top_n": kwargs["top_n"] if "top_n" in kwargs else 1,
-        }
-
-        print("Output from Embedding for next node:\n", next_data)
+    print(f"*** Direct Outputs from {cur_node}:\n{data}")
+    print("--" * 50)

+    if self.services[cur_node].service_type == ServiceType.EMBEDDING:
+        # direct output from Embedding microservice is EmbeddingResponse
+        """
+        class EmbeddingResponse(BaseModel):
+            object: str = "list"
+            model: Optional[str] = None
+            data: List[EmbeddingResponseData]
+            usage: Optional[UsageInfo] = None

+        class EmbeddingResponseData(BaseModel):
+            index: int
+            object: str = "embedding"
+            embedding: Union[List[float], str]
+        """
+        # turn it into EmbedDoc
+        assert isinstance(data["data"], list)
+        next_data = {"text": inputs["input"], "embedding": data["data"][0]["embedding"]}  # EmbedDoc
     else:
         next_data = data

+    print(f"*** Formatted Output from {cur_node} for next node:\n", next_data)
+    print("--" * 50)
     return next_data


@@ -100,54 +133,41 @@ def add_remote_service(self):
         self.megaservice.flow_to(retriever, rerank)

     async def handle_request(self, request: Request):
-        def parser_input(data, TypeClass, key):
-            chat_request = None
-            try:
-                chat_request = TypeClass.parse_obj(data)
-                query = getattr(chat_request, key)
-            except:
-                query = None
-            return query, chat_request
-
         data = await request.json()
-        query = None
-        for key, TypeClass in zip(["text", "input", "messages"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]):
-            query, chat_request = parser_input(data, TypeClass, key)
-            if query is not None:
-                break
-        if query is None:
-            raise ValueError(f"Unknown request type: {data}")
-        if chat_request is None:
-            raise ValueError(f"Unknown request type: {data}")
-
-        if isinstance(chat_request, ChatCompletionRequest):
-            initial_inputs = {
-                "messages": query,
-                "input": query,  # has to be input due to embedding expects either input or text
-                "search_type": chat_request.search_type if chat_request.search_type else "similarity",
-                "k": chat_request.k if chat_request.k else 4,
-                "distance_threshold": chat_request.distance_threshold if chat_request.distance_threshold else None,
-                "fetch_k": chat_request.fetch_k if chat_request.fetch_k else 20,
-                "lambda_mult": chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
-                "score_threshold": chat_request.score_threshold if chat_request.score_threshold else 0.2,
-                "top_n": chat_request.top_n if chat_request.top_n else 1,
-            }
-
-            kwargs = {
-                "search_type": chat_request.search_type if chat_request.search_type else "similarity",
-                "k": chat_request.k if chat_request.k else 4,
-                "distance_threshold": chat_request.distance_threshold if chat_request.distance_threshold else None,
-                "fetch_k": chat_request.fetch_k if chat_request.fetch_k else 20,
-                "lambda_mult": chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
-                "score_threshold": chat_request.score_threshold if chat_request.score_threshold else 0.2,
-                "top_n": chat_request.top_n if chat_request.top_n else 1,
-            }
-            result_dict, runtime_graph = await self.megaservice.schedule(
-                initial_inputs=initial_inputs,
-                **kwargs,
-            )
-        else:
-            result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"input": query})
+        chat_request = ChatCompletionRequest.parse_obj(data)
+
+        prompt = chat_request.messages
+
+        # dummy llm params
+        parameters = LLMParams(
+            max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+            top_k=chat_request.top_k if chat_request.top_k else 10,
+            top_p=chat_request.top_p if chat_request.top_p else 0.95,
+            temperature=chat_request.temperature if chat_request.temperature else 0.01,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
+            chat_template=chat_request.chat_template if chat_request.chat_template else None,
+            model=chat_request.model if chat_request.model else None,
+        )
+
+        retriever_parameters = RetrieverParms(
+            search_type=chat_request.search_type if chat_request.search_type else "similarity",
+            k=chat_request.k if chat_request.k else 4,
+            distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
+            fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
+            lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
+            score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
+        )
+        reranker_parameters = RerankerParms(
+            top_n=chat_request.top_n if chat_request.top_n else 1,
+        )
+        result_dict, runtime_graph = await self.megaservice.schedule(
+            initial_inputs={"text": prompt},
+            llm_parameters=parameters,
+            retriever_parameters=retriever_parameters,
+            reranker_parameters=reranker_parameters,
+        )

         last_node = runtime_graph.all_leaves()[-1]
         response = result_dict[last_node]
````

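The core of the `align_inputs` rewrite is a per-node-type merge: the embedding node's `text` field is renamed to `input`, and any scheduled retriever/reranker parameters are folded into the inputs of the matching node. Below is a self-contained sketch of that merge, where plain dataclasses stand in for the `comps` Pydantic models (fields abbreviated) and the service-type strings are illustrative rather than the library's enum:

```python
from dataclasses import dataclass, asdict


@dataclass
class RetrieverParms:  # stand-in for comps' RetrieverParms (hypothetical simplification)
    search_type: str = "similarity"
    k: int = 4
    fetch_k: int = 20
    lambda_mult: float = 0.5
    score_threshold: float = 0.2


@dataclass
class RerankerParms:  # stand-in for comps' RerankerParms
    top_n: int = 1


def align_inputs(service_type: str, inputs: dict, **kwargs) -> dict:
    # Mirror of the commit's logic: rename the embedding key, then fold
    # per-stage parameters into the payload for the retriever/rerank node.
    if service_type == "embedding":
        inputs["input"] = inputs.pop("text")
    elif service_type == "retriever" and kwargs.get("retriever_parameters"):
        inputs.update(asdict(kwargs["retriever_parameters"]))
    elif service_type == "rerank" and kwargs.get("reranker_parameters"):
        inputs.update(asdict(kwargs["reranker_parameters"]))
    return inputs


# The retriever node receives the EmbedDoc fields plus the caller's k override.
doc = {"text": "Explain the OPEA project?", "embedding": [0.1, 0.2]}
print(align_inputs("retriever", doc, retriever_parameters=RetrieverParms(k=8)))
```
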
DocIndexRetriever/tests/test.py

Lines changed: 16 additions & 5 deletions
````diff
@@ -2,16 +2,17 @@
 # SPDX-License-Identifier: Apache-2.0

 import os
+from typing import Any

 import requests


-def search_knowledge_base(query: str) -> str:
+def search_knowledge_base(query: str, args: Any) -> str:
     """Search the knowledge base for a specific query."""
-    url = os.environ.get("RETRIEVAL_TOOL_URL")
+    url = os.environ.get("RETRIEVAL_TOOL_URL", "http://localhost:8889/v1/retrievaltool")
     print(url)
     proxies = {"http": ""}
-    payload = {"messages": query, "k": 5, "top_n": 2}
+    payload = {"messages": query, "k": args.k, "top_n": args.top_n}
     response = requests.post(url, json=payload, proxies=proxies)
     print(response)
     if "documents" in response.json():
@@ -33,6 +34,16 @@ def search_knowledge_base(query: str) -> str:


 if __name__ == "__main__":
-    resp = search_knowledge_base("What is OPEA?")
-    # resp = search_knowledge_base("Thriller")
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Test the knowledge base search.")
+    parser.add_argument("--k", type=int, default=5, help="retriever top k")
+    parser.add_argument("--top_n", type=int, default=2, help="reranker top n")
+    args = parser.parse_args()
+
+    resp = search_knowledge_base("What is OPEA?", args)
+
     print(resp)
+
+    if not resp.startswith("Error"):
+        print("Test successful!")
````

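With the argparse change, the smoke test accepts non-default retrieval settings from the command line, e.g. `python tests/test.py --k 8 --top_n 3`, and falls back to the new `http://localhost:8889/v1/retrievaltool` default when `RETRIEVAL_TOOL_URL` is unset.
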
DocIndexRetriever/tests/test_compose_milvus_on_gaudi.sh

Lines changed: 2 additions & 2 deletions
````diff
@@ -88,10 +88,10 @@ function validate_megaservice() {
     fi

     # Curl the Mega Service
-    echo "================Testing retriever service: Text Request ================"
+    echo "================Testing retriever service ================"
     cd $WORKPATH/tests
     local CONTENT=$(http_proxy="" curl http://${ip_address}:8889/v1/retrievaltool -X POST -H "Content-Type: application/json" -d '{
-     "text": "Explain the OPEA project?"
+     "messages": "Explain the OPEA project?"
     }')

     local EXIT_CODE=$(validate "$CONTENT" "OPEA" "doc-index-retriever-service-gaudi")
````

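This test-script change mirrors the gateway rewrite in `retrieval_tool.py`: since `handle_request` now parses every request as a `ChatCompletionRequest`, the payload key must be `messages` rather than `text`.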