 from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType
 from comps.cores.proto.api_protocol import ChatCompletionRequest, EmbeddingRequest
-from comps.cores.proto.docarray import LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc
+from comps.cores.proto.docarray import LLMParams, LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc
 from fastapi import Request
-from fastapi.responses import StreamingResponse

 MEGA_SERVICE_PORT = os.getenv("MEGA_SERVICE_PORT", 8889)
 EMBEDDING_SERVICE_HOST_IP = os.getenv("EMBEDDING_SERVICE_HOST_IP", "0.0.0.0")


 def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
-    print(f"Inputs to {cur_node}: {inputs}")
+    print(f"*** Inputs to {cur_node}:\n{inputs}")
+    print("--" * 50)
     for key, value in kwargs.items():
         print(f"{key}: {value}")
+    if self.services[cur_node].service_type == ServiceType.EMBEDDING:
+        inputs["input"] = inputs["text"]
+        del inputs["text"]
+    elif self.services[cur_node].service_type == ServiceType.RETRIEVER:
+        # input is EmbedDoc
+        """class EmbedDoc(BaseDoc):
+
+        text: Union[str, List[str]]
+        embedding: Union[conlist(float, min_length=0), List[conlist(float, min_length=0)]]
+        search_type: str = "similarity"
+        k: int = 4
+        distance_threshold: Optional[float] = None
+        fetch_k: int = 20
+        lambda_mult: float = 0.5
+        score_threshold: float = 0.2
+        constraints: Optional[Union[Dict[str, Any], List[Dict[str, Any]], None]] = None
+        index_name: Optional[str] = None
+        """
+        # prepare the retriever params
+        retriever_parameters = kwargs.get("retriever_parameters", None)
+        if retriever_parameters:
+            inputs.update(retriever_parameters.dict())
+    elif self.services[cur_node].service_type == ServiceType.RERANK:
+        # input is SearchedDoc
+        """class SearchedDoc(BaseDoc):
+
+        retrieved_docs: DocList[TextDoc]
+        initial_query: str
+        top_n: int = 1
+        """
+        # prepare the reranker params
+        reranker_parameters = kwargs.get("reranker_parameters", None)
+        if reranker_parameters:
+            inputs.update(reranker_parameters.dict())
+    print(f"*** Formatted Inputs to {cur_node}:\n{inputs}")
+    print("--" * 50)
     return inputs
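`align_inputs` is attached as a hook on the orchestrator, so every edge in the graph passes through it before the downstream microservice is called. Below is a minimal sketch of the wiring and of the embedding-hop rewrite it performs; the hook assignment mirrors the pattern used elsewhere in GenAIExamples, and the sample payload is purely illustrative:

```python
from comps import ServiceOrchestrator

# Attach the hooks before building the graph so schedule() applies them on every hop.
ServiceOrchestrator.align_inputs = align_inputs
ServiceOrchestrator.align_outputs = align_outputs

# The embedding hop renames "text" to "input", matching the OpenAI-style
# EmbeddingRequest schema the embedding microservice expects.
inputs = {"text": "What is OPEA?"}
inputs["input"] = inputs.pop("text")
assert inputs == {"input": "What is OPEA?"}
```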


 def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs):
-    next_data = {}
-    if self.services[cur_node].service_type == ServiceType.EMBEDDING:
-        # turn into chat completion request
-        # next_data = {"text": inputs["input"], "embedding": [item["embedding"] for item in data["data"]]}
-        print("Assembing output from Embedding for next node...")
-        print("Inputs to Embedding: ", inputs)
-        print("Keyword arguments: ")
-        for key, value in kwargs.items():
-            print(f"{key}: {value}")
-
-        next_data = {
-            "input": inputs["input"],
-            "messages": inputs["input"],
-            "embedding": [item["embedding"] for item in data["data"]],
-            "k": kwargs["k"] if "k" in kwargs else 4,
-            "search_type": kwargs["search_type"] if "search_type" in kwargs else "similarity",
-            "distance_threshold": kwargs["distance_threshold"] if "distance_threshold" in kwargs else None,
-            "fetch_k": kwargs["fetch_k"] if "fetch_k" in kwargs else 20,
-            "lambda_mult": kwargs["lambda_mult"] if "lambda_mult" in kwargs else 0.5,
-            "score_threshold": kwargs["score_threshold"] if "score_threshold" in kwargs else 0.2,
-            "top_n": kwargs["top_n"] if "top_n" in kwargs else 1,
-        }
-
-        print("Output from Embedding for next node:\n", next_data)
+    print(f"*** Direct Outputs from {cur_node}:\n{data}")
+    print("--" * 50)

+    if self.services[cur_node].service_type == ServiceType.EMBEDDING:
+        # direct output from the Embedding microservice is an EmbeddingResponse
+        """
+        class EmbeddingResponse(BaseModel):
+            object: str = "list"
+            model: Optional[str] = None
+            data: List[EmbeddingResponseData]
+            usage: Optional[UsageInfo] = None
+
+        class EmbeddingResponseData(BaseModel):
+            index: int
+            object: str = "embedding"
+            embedding: Union[List[float], str]
+        """
+        # turn it into an EmbedDoc
+        assert isinstance(data["data"], list)
+        next_data = {"text": inputs["input"], "embedding": data["data"][0]["embedding"]}  # EmbedDoc
     else:
         next_data = data

+    print(f"*** Formatted Output from {cur_node} for next node:\n{next_data}")
+    print("--" * 50)
     return next_data
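To make the reshaping concrete, here is the EMBEDDING branch applied to a hand-written `EmbeddingResponse`-shaped payload; the model name and vector values are made up for illustration:

```python
# Response shaped like the EmbeddingResponse documented above.
data = {
    "object": "list",
    "model": "BAAI/bge-base-en-v1.5",  # illustrative
    "data": [{"index": 0, "object": "embedding", "embedding": [0.12, -0.03, 0.88]}],
}
inputs = {"input": "What is OPEA?"}

# Same reshaping as the EMBEDDING branch: keep the query text and the first
# embedding vector, producing an EmbedDoc-shaped dict for the retriever.
next_data = {"text": inputs["input"], "embedding": data["data"][0]["embedding"]}
assert next_data == {"text": "What is OPEA?", "embedding": [0.12, -0.03, 0.88]}
```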


@@ -100,54 +133,41 @@ def add_remote_service(self):
         self.megaservice.flow_to(retriever, rerank)

     async def handle_request(self, request: Request):
-        def parser_input(data, TypeClass, key):
-            chat_request = None
-            try:
-                chat_request = TypeClass.parse_obj(data)
-                query = getattr(chat_request, key)
-            except:
-                query = None
-            return query, chat_request
-
         data = await request.json()
-        query = None
-        for key, TypeClass in zip(["text", "input", "messages"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]):
-            query, chat_request = parser_input(data, TypeClass, key)
-            if query is not None:
-                break
-        if query is None:
-            raise ValueError(f"Unknown request type: {data}")
-        if chat_request is None:
-            raise ValueError(f"Unknown request type: {data}")
-
-        if isinstance(chat_request, ChatCompletionRequest):
-            initial_inputs = {
-                "messages": query,
-                "input": query,  # has to be input due to embedding expects either input or text
-                "search_type": chat_request.search_type if chat_request.search_type else "similarity",
-                "k": chat_request.k if chat_request.k else 4,
-                "distance_threshold": chat_request.distance_threshold if chat_request.distance_threshold else None,
-                "fetch_k": chat_request.fetch_k if chat_request.fetch_k else 20,
-                "lambda_mult": chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
-                "score_threshold": chat_request.score_threshold if chat_request.score_threshold else 0.2,
-                "top_n": chat_request.top_n if chat_request.top_n else 1,
-            }
-
-            kwargs = {
-                "search_type": chat_request.search_type if chat_request.search_type else "similarity",
-                "k": chat_request.k if chat_request.k else 4,
-                "distance_threshold": chat_request.distance_threshold if chat_request.distance_threshold else None,
-                "fetch_k": chat_request.fetch_k if chat_request.fetch_k else 20,
-                "lambda_mult": chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
-                "score_threshold": chat_request.score_threshold if chat_request.score_threshold else 0.2,
-                "top_n": chat_request.top_n if chat_request.top_n else 1,
-            }
-            result_dict, runtime_graph = await self.megaservice.schedule(
-                initial_inputs=initial_inputs,
-                **kwargs,
-            )
-        else:
-            result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"input": query})
+        chat_request = ChatCompletionRequest.parse_obj(data)
+
+        prompt = chat_request.messages
+
+        # dummy LLM params; no LLM node in this graph consumes them
+        parameters = LLMParams(
+            max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+            top_k=chat_request.top_k if chat_request.top_k else 10,
+            top_p=chat_request.top_p if chat_request.top_p else 0.95,
+            temperature=chat_request.temperature if chat_request.temperature else 0.01,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
+            chat_template=chat_request.chat_template if chat_request.chat_template else None,
+            model=chat_request.model if chat_request.model else None,
+        )
+
+        retriever_parameters = RetrieverParms(
+            search_type=chat_request.search_type if chat_request.search_type else "similarity",
+            k=chat_request.k if chat_request.k else 4,
+            distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
+            fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
+            lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
+            score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
+        )
+        reranker_parameters = RerankerParms(
+            top_n=chat_request.top_n if chat_request.top_n else 1,
+        )
+        result_dict, runtime_graph = await self.megaservice.schedule(
+            initial_inputs={"text": prompt},
+            llm_parameters=parameters,
+            retriever_parameters=retriever_parameters,
+            reranker_parameters=reranker_parameters,
+        )

         last_node = runtime_graph.all_leaves()[-1]
         response = result_dict[last_node]
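End to end, `handle_request` now accepts a plain ChatCompletions-style body and fans the optional retriever and reranker knobs out to the right hops via `schedule()`. A sketch of a client call follows; the port comes from the `MEGA_SERVICE_PORT` default above, but the `/v1/retrievaltool` path is an assumption, since the `MegaServiceEndpoint` registration is outside this hunk:

```python
import requests

url = "http://localhost:8889/v1/retrievaltool"  # path assumed; see note above
payload = {
    "messages": "What is the revenue of Nike in 2023?",
    # consumed by RetrieverParms in handle_request
    "k": 4,
    "search_type": "similarity",
    # consumed by RerankerParms
    "top_n": 1,
}
resp = requests.post(url, json=payload, timeout=60)
print(resp.json())
```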