diff --git a/chat-with-context/index.html b/chat-with-context/index.html
index 095f16d..0ee3a06 100644
--- a/chat-with-context/index.html
+++ b/chat-with-context/index.html
@@ -59,8 +59,10 @@

+
+
diff --git a/chat-with-context/src/chat-with-context.tsx b/chat-with-context/src/chat-with-context.tsx
index f952a0d..1884454 100644
--- a/chat-with-context/src/chat-with-context.tsx
+++ b/chat-with-context/src/chat-with-context.tsx
@@ -13,7 +13,7 @@ import "./style.css";
 // https://github.com/solidjs/solid/blob/main/packages/solid-element/README.md
 // https://github.com/solidjs/templates/tree/main/ts-tailwindcss

-type GenContext = {
+type RefenceDocument = {
   score: number;
   payload: {
     doc_type: string;
@@ -33,7 +33,7 @@ type Message = {
   role: "assistant" | "user";
   content: Accessor;
   setContent: Setter;
-  sources: GenContext[];
+  docs: RefenceDocument[];
   links: Accessor;
   setLinks: Setter;
 };
@@ -52,7 +52,9 @@ customElement("chat-with-context", {api: "", examples: "", apiKey: ""}, props =>
   const [messages, setMessages] = createSignal([]);
   const [warningMsg, setWarningMsg] = createSignal("");
   const [loading, setLoading] = createSignal(false);
+  const [abortController, setAbortController] = createSignal(new AbortController());
   const [feedbackSent, setFeedbackSent] = createSignal(false);
+  const [selectedTab, setSelectedTab] = createSignal("");

   const apiUrl = props.api.endsWith("/") ? props.api : props.api + "/";
   if (props.api === "") setWarningMsg("Please provide an API URL for the chat component to work.");
@@ -65,13 +67,10 @@ customElement("chat-with-context", {api: "", examples: "", apiKey: ""}, props =>
   let chatContainerEl!: HTMLDivElement;
   let inputTextEl!: HTMLTextAreaElement;

-  // NOTE: the 2 works but for now we want to always run validation
-  const streamResponse = false;
-
-  const appendMessage = (msgContent: string, role: "assistant" | "user" = "assistant", sources: GenContext[] = []) => {
+  const appendMessage = (msgContent: string, role: "assistant" | "user" = "assistant", docs: RefenceDocument[] = []) => {
     const [content, setContent] = createSignal(msgContent);
     const [links, setLinks] = createSignal([]);
-    const newMsg: Message = {content, setContent, role, sources, links, setLinks};
+    const newMsg: Message = {content, setContent, role, docs, links, setLinks};
     const query = extractSparqlQuery(msgContent);
     if (query) newMsg.setLinks([{url: query, ...queryLinkLabels}]);
     setMessages(messages => [...messages, newMsg]);
@@ -88,7 +87,7 @@ customElement("chat-with-context", {api: "", examples: "", apiKey: ""}, props =>
   async function submitInput(question: string) {
     if (!question.trim()) return;
     if (loading()) {
-      setWarningMsg("⏳ Thinking...");
+      // setWarningMsg("⏳ Thinking...");
       return;
     }
     inputTextEl.value = "";
@@ -96,15 +95,19 @@ customElement("chat-with-context", {api: "", examples: "", apiKey: ""}, props =>
     appendMessage(question, "user");
     setTimeout(() => fixInputHeight(), 0);
     try {
+      const startTime = Date.now();
+      // NOTE: the 2 works but for now we want to always run validation
+      const streamResponse = true;
       const response = await fetch(`${apiUrl}chat`, {
         method: "POST",
         headers: {
           "Content-Type": "application/json",
-          // 'Authorization': `Bearer ${apiKey}`,
         },
+        signal: abortController().signal,
         body: JSON.stringify({
           messages: messages().map(({content, role}) => ({content: content(), role})),
           model: "gpt-4o",
+          // model: "azure_ai/mistral-large",
           max_tokens: 50,
           stream: streamResponse,
           api_key: props.apiKey,
@@ -116,7 +119,6 @@ customElement("chat-with-context", {api: "", examples: "", apiKey: ""}, props =>
         const reader = response.body?.getReader();
         const decoder = new TextDecoder("utf-8");
         let buffer = "";
-        console.log(reader);

         // Iterate stream response
         while (true) {
           if (reader) {
@@ -124,12 +126,13 @@ customElement("chat-with-context", {api: "", examples: "", apiKey: ""}, props =>
             if (done) break;
             buffer += decoder.decode(value, {stream: true});
             let boundary = buffer.lastIndexOf("\n");
+            // Add a small artificial delay to make streaming feel more natural
+            // if (boundary > 10000) await new Promise(resolve => setTimeout(resolve, 10));
+            console.log(boundary, buffer)
             if (boundary !== -1) {
               const chunk = buffer.slice(0, boundary);
               buffer = buffer.slice(boundary + 1);
-
-              const lines = chunk.split("\n").filter(line => line.trim() !== "");
-              for (const line of lines) {
+              for (const line of chunk.split("\n").filter(line => line.trim() !== "")) {
                 if (line === "data: [DONE]") {
                   return;
                 }
@@ -172,7 +175,6 @@ customElement("chat-with-context", {api: "", examples: "", apiKey: ""}, props =>
         }
         // Extract query once message complete
         const query = extractSparqlQuery(lastMsg.content());
-        console.log(query);
         if (query) lastMsg.setLinks([{url: query, ...queryLinkLabels}]);
       } else {
         // Don't stream, await full response with additional checks done on the server
@@ -187,9 +189,13 @@ customElement("chat-with-context", {api: "", examples: "", apiKey: ""}, props =>
           setWarningMsg("An error occurred. Please try again.");
         }
       }
+      console.log(`Request completed in ${(Date.now() - startTime) / 1000} seconds`);
     } catch (error) {
-      console.error("Failed to send message", error);
-      setWarningMsg("An error occurred when querying the API. Please try again or contact an admin.");
+      if (error instanceof Error && error.name !== 'AbortError') {
+        console.error("An error occurred when querying the API", error);
+        setWarningMsg("An error occurred when querying the API. Please try again or contact an admin.");
+      }
+      // setWarningMsg("An error occurred when querying the API. Please try again or contact an admin.");
     }
     setLoading(false);
     setFeedbackSent(false);
@@ -231,14 +237,15 @@ customElement("chat-with-context", {api: "", examples: "", apiKey: ""}, props =>
                     innerHTML={DOMPurify.sanitize(marked.parse(msg.content()) as string)}
                   />
-                  {/* Add sources references dialog */}
-                  {msg.sources.length > 0 && (
+                  {/* Add reference docs dialog */}
+                  {msg.docs.length > 0 && (
                     <>
- - {(source, iSource) => ( - <> -

- - {iSource() + 1} - {Math.round(source.score * 1000) / 1000} +

+ doc.payload.doc_type)))}> + {(docType) => + + } +
+ doc.payload.doc_type === selectedTab())}> + {(doc, iDoc) => ( + <> +

+ + {iDoc() + 1} - {Math.round(doc.score * 1000) / 1000} + + {doc.payload.question} ( + + {doc.payload.endpoint_url} + + ) +

+ {getLangForDocType(doc.payload.doc_type).startsWith("language-") ? ( +
+                                
+                                  {doc.payload.answer}
                                 
-                                {source.payload.question} (
-                                
-                                  {source.payload.endpoint_url}
-                                
-                                )
-                              

- {getLangForDocType(source.payload.doc_type).startsWith("language-") ? ( -
-                                  
-                                    {source.payload.answer}
-                                  
-                                
- ) : ( -

{source.payload.answer}

- )} - - )} - +
+ ) : ( +

{doc.payload.answer}

+ )} + + )}
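For reference, the new tabbed dialog above boils down to grouping the retrieved docs by doc_type, keeping one tab per type, and only rendering the docs whose type matches the selected tab. A minimal TypeScript sketch of that grouping (the ReferenceDoc shape and helper names here are illustrative, not part of the patch):

// Sketch: group reference docs by doc_type to drive the tabs (illustrative names).
type ReferenceDoc = {score: number; payload: {doc_type: string; question: string; endpoint_url: string; answer: string}};

const docTypes = (docs: ReferenceDoc[]): string[] =>
  Array.from(new Set(docs.map(doc => doc.payload.doc_type)));

const docsForTab = (docs: ReferenceDoc[], selectedTab: string): ReferenceDoc[] =>
  docs.filter(doc => doc.payload.doc_type === selectedTab);

// Usage: render one tab button per entry of docTypes(msg.docs), store the clicked
// type in the selectedTab signal, and list docsForTab(msg.docs, selectedTab()).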
@@ -352,6 +378,11 @@ customElement("chat-with-context", {api: "", examples: "", apiKey: ""}, props =>
           class="p-2 flex"
           onSubmit={event => {
             event.preventDefault();
+            // Only abort if it's a click event (not from pressing Enter)
+            if (event.type === 'submit' && event.submitter && loading()) {
+              abortController().abort();
+              setAbortController(new AbortController());
+            }
             submitInput(inputTextEl.value);
           }}
         >
@@ -372,18 +403,19 @@ customElement("chat-with-context", {api: "", examples: "", apiKey: ""}, props =>
           />
          {/*
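The submit handler above cancels an in-flight request before starting a new one. A minimal sketch of that abort-and-restart pattern, assuming a fetch wired to the controller's signal as in the earlier hunk (the helper name is illustrative):

// Sketch: cancel the current streaming request and issue a fresh one.
let controller = new AbortController();

async function resubmit(apiUrl: string, body: unknown): Promise<Response> {
  controller.abort();                 // rejects the pending fetch with an AbortError
  controller = new AbortController(); // an aborted controller stays aborted, so create a new one
  return fetch(`${apiUrl}chat`, {
    method: "POST",
    headers: {"Content-Type": "application/json"},
    signal: controller.signal,
    body: JSON.stringify(body),
  });
}

// Callers then ignore rejections where error.name === "AbortError",
// which is what the updated catch block in the component does.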
=7.1.3", +# "pytest-cov >=3.0.0", +# ] [project.urls] Homepage = "https://github.com/sib-swiss/sparql-llm" @@ -102,6 +107,7 @@ post-install-commands = [] # uv venv # uv pip install ".[chat,gpu]" +# uv pip compile pyproject.toml -o requirements.txt # uv run python src/sparql_llm/embed_entities.py [tool.hatch.envs.default.scripts] fmt = [ @@ -122,7 +128,7 @@ cov-check = [ "python -c 'import webbrowser; webbrowser.open(\"http://0.0.0.0:3000\")'", "python -m http.server 3000 --directory ./htmlcov", ] -compile = "pip-compile -o requirements.txt pyproject.toml" +# compile = "pip-compile -o requirements.txt pyproject.toml" ## TOOLS @@ -211,4 +217,7 @@ ignore = [ [tool.ruff.lint.per-file-ignores] "__init__.py" = ["I", "F401"] # module imported but unused # Tests can use magic values, assertions, and relative imports: -"tests/**/*" = ["PLR2004", "S101", "S105", "TID252"] \ No newline at end of file +"tests/**/*" = ["PLR2004", "S101", "S105", "TID252"] + +# [tool.uv.workspace] +# members = ["mcp-sparql"] diff --git a/src/sparql_llm/api.py b/src/sparql_llm/api.py index 853873a..d780540 100644 --- a/src/sparql_llm/api.py +++ b/src/sparql_llm/api.py @@ -9,6 +9,7 @@ from fastapi.responses import HTMLResponse, StreamingResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates +from litellm import completion from openai import OpenAI, Stream from openai.types.chat import ChatCompletion, ChatCompletionChunk from pydantic import BaseModel @@ -17,8 +18,8 @@ from rdflib.plugins.sparql.parser import parseQuery from starlette.middleware.cors import CORSMiddleware -from sparql_llm.config import get_llm_client, get_prefixes_dict, settings -from sparql_llm.embed import get_embedding_model, get_vectordb +from sparql_llm.config import get_prefixes_dict, settings +from sparql_llm.index import get_embedding_model, get_vectordb from sparql_llm.utils import get_prefix_converter from sparql_llm.validate_sparql import add_missing_prefixes, extract_sparql_queries, validate_sparql_with_void @@ -34,8 +35,10 @@ STARTUP_PROMPT = "Here is a list of reference questions and query answers relevant to the user question that will help you answer the user question accurately:" INTRO_USER_QUESTION_PROMPT = "The question from the user is:" -llm_model = "gpt-4o" +# llm_model = "gpt-4o" +llm_model = "azure_ai/mistral-large" # llm_model: str = "gpt-4o-mini" + # Models from glhf: # llm_model: str = "hf:meta-llama/Meta-Llama-3.1-8B-Instruct" # llm_model: str = "hf:mistralai/Mixtral-8x22B-Instruct-v0.1" @@ -87,17 +90,24 @@ class ChatCompletionRequest(BaseModel): api_key: Optional[str] = None -async def stream_openai(response: Stream[ChatCompletionChunk], docs, full_prompt) -> AsyncGenerator[str, None]: - """Stream the response from OpenAI""" +async def stream_response(response: Stream[ChatCompletionChunk], request: ChatCompletionRequest) -> AsyncGenerator[str, None]: + """Stream the response from the LLM provider to the client""" first_chunk = { - "docs": [hit.dict() for hit in docs], - "full_prompt": full_prompt, + "docs": response.docs, + "full_prompt": response.full_prompt, } yield f"data: {json.dumps(first_chunk)}\n\n" + full_msg = "" for chunk in response: if chunk.choices[0].finish_reason: + # TODO: Send message about validation running + # Then send a "fix" msg to let the client know the message needs to be replaced + # In this case the client will still enable showing the old wrong msg by clicking on a card + # if request.validate_output: + # messages.append(Message(role="assistant", 
content=full_msg)) + # validate_and_fix_sparql(resp, response, request.model) break # print(chunk) # ChatCompletionChunk(id='chatcmpl-9UxmYAx6E5Y3BXdI7YEVDmbOh9S2X', @@ -110,6 +120,7 @@ async def stream_openai(response: Stream[ChatCompletionChunk], docs, full_prompt "model": chunk.model, "choices": [{"delta": {"content": chunk.choices[0].delta.content}}], } + full_msg += chunk.choices[0].delta.content yield f"data: {json.dumps(resp_chunk)}\n\n" @@ -162,16 +173,17 @@ async def chat(request: ChatCompletionRequest): if settings.expasy_api_key and request.api_key != settings.expasy_api_key: raise ValueError("Invalid API key") - client = get_llm_client(request.model) # print(client.models.list()) + # Remove any system prompt from the client request + request.messages = [msg for msg in request.messages if msg.role != "system"] + request.messages = [Message(role="system", content=settings.system_prompt), *request.messages] question: str = request.messages[-1].content if request.messages else "" - + query_embeddings = next(iter(embedding_model.embed([question]))) logging.info(f"User question: {question}") - query_embeddings = next(iter(embedding_model.embed([question]))) # We build a big prompt with the most relevant queries retrieved from similarity search engine (could be increased) - prompt_with_context = f"{STARTUP_PROMPT}\n\n" + prompt = f"{STARTUP_PROMPT}\n\n" # # We also provide the example queries as previous messages to the LLM # system_msg: list[Message] = [{"role": "system", "content": settings.system_prompt}] @@ -184,15 +196,15 @@ async def chat(request: ChatCompletionRequest): must=[ FieldCondition( key="doc_type", - match=MatchValue(value="sparql_query"), + match=MatchValue(value="SPARQL endpoints query examples"), ) ] ), limit=settings.retrieved_queries_count, ) for query_hit in query_hits: - prompt_with_context += f"{query_hit.payload['question']}:\n\n```sparql\n# {query_hit.payload['endpoint_url']}\n{query_hit.payload['answer']}\n```\n\n" - # prompt_with_context += f"{query_hit.payload['question']}\nQuery to run in SPARQL endpoint {query_hit.payload['endpoint_url']}\n\n{query_hit.payload['answer']}\n\n" + prompt += f"{query_hit.payload['question']}:\n\n```sparql\n# {query_hit.payload['endpoint_url']}\n{query_hit.payload['answer']}\n```\n\n" + # prompt += f"{query_hit.payload['question']}\nQuery to run in SPARQL endpoint {query_hit.payload['endpoint_url']}\n\n{query_hit.payload['answer']}\n\n" # 2. Get the most relevant documents other than SPARQL query examples from the search engine (ShEx shapes, general infos) docs_hits = vectordb.search( @@ -202,16 +214,16 @@ async def chat(request: ChatCompletionRequest): should=[ FieldCondition( key="doc_type", - match=MatchValue(value="shex"), + match=MatchValue(value="SPARQL endpoints classes schema"), ), FieldCondition( key="doc_type", - match=MatchValue(value="schemaorg_description"), + match=MatchValue(value="General information"), ), # NOTE: we don't add ontology documents yet, not clean enough # FieldCondition( # key="doc_type", - # match=MatchValue(value="ontology"), + # match=MatchValue(value="Ontology"), # ), ] ), @@ -221,64 +233,70 @@ async def chat(request: ChatCompletionRequest): ) # TODO: vectordb.search_groups( # https://qdrant.tech/documentation/concepts/search/#search-groups - # TODO: hybrid search? 
https://qdrant.github.io/fastembed/examples/Hybrid_Search/#about-qdrant # we might want to group by iri for shex docs https://qdrant.tech/documentation/concepts/hybrid-queries/?q=hybrid+se#grouping # https://qdrant.tech/documentation/concepts/search/#search-groups - prompt_with_context += "Here is some additional information that could be useful to answer the user question:\n\n" + prompt += "Here is some additional information that could be useful to answer the user question:\n\n" # for docs_hit in docs_hits.groups: for docs_hit in docs_hits: - if docs_hit.payload["doc_type"] == "shex": - prompt_with_context += f"ShEx shape for {docs_hit.payload['question']} in {docs_hit.payload['endpoint_url']}:\n```\n{docs_hit.payload['answer']}\n```\n\n" - # elif docs_hit.payload["doc_type"] == "ontology": - # prompt_with_context += f"Relevant part of the ontology for {docs_hit.payload['endpoint_url']}:\n```turtle\n{docs_hit.payload['question']}\n```\n\n" + if docs_hit.payload["doc_type"] == "SPARQL endpoints classes schema": + prompt += f"ShEx shape for {docs_hit.payload['question']} in {docs_hit.payload['endpoint_url']}:\n```\n{docs_hit.payload['answer']}\n```\n\n" + # elif docs_hit.payload["doc_type"] == "Ontology": + # prompt += f"Relevant part of the ontology for {docs_hit.payload['endpoint_url']}:\n```turtle\n{docs_hit.payload['question']}\n```\n\n" else: - prompt_with_context += f"Information about: {docs_hit.payload['question']}\nRelated to SPARQL endpoint {docs_hit.payload['endpoint_url']}\n\n{docs_hit.payload['answer']}\n\n" - - # 3. Extract potential entities from the user question - entities_list = extract_entities(question) - for entity in entities_list: - prompt_with_context += f'\n\nEntities found in the user question for "{" ".join(entity["term"])}":\n\n' - for match in entity["matchs"]: - prompt_with_context += f"- {match.payload['label']} with IRI <{match.payload['uri']}> in endpoint {match.payload['endpoint_url']}\n\n" - - if len(entities_list) == 0: - prompt_with_context += "\nNo entities found in the user question that matches entities in the endpoints. " - - prompt_with_context += "\nIf the user is asking for a named entity, and this entity cannot be found in the endpoint, warn them about the fact we could not find it in the endpoints.\n\n" - - prompt_with_context += f"\n{INTRO_USER_QUESTION_PROMPT}\n{question}" - print(prompt_with_context) + prompt += f"Information about: {docs_hit.payload['question']}\nRelated to SPARQL endpoint {docs_hit.payload['endpoint_url']}\n\n{docs_hit.payload['answer']}\n\n" + + # 3. Extract potential entities from the user question (experimental) + # entities_list = extract_entities(question) + # for entity in entities_list: + # prompt += f'\n\nEntities found in the user question for "{" ".join(entity["term"])}":\n\n' + # for match in entity["matchs"]: + # prompt += f"- {match.payload['label']} with IRI <{match.payload['uri']}> in endpoint {match.payload['endpoint_url']}\n\n" + # if len(entities_list) == 0: + # prompt += "\nNo entities found in the user question that matches entities in the endpoints. " + # prompt += "\nIf the user is asking for a named entity, and this entity cannot be found in the endpoint, warn them about the fact we could not find it in the endpoints.\n\n" + + # 4. 
Add the user question to the prompt + prompt += f"\n{INTRO_USER_QUESTION_PROMPT}\n{question}" + print(prompt) # Use messages from the request to keep memory of previous messages sent by the client # Replace the question asked by the user with the big prompt with all contextual infos - request.messages[-1].content = prompt_with_context - all_messages = [Message(role="system", content=settings.system_prompt), *request.messages] + request.messages[-1].content = prompt # Send the prompt to OpenAI to get a response - response = client.chat.completions.create( + # response = client.chat.completions.create( + # model=request.model, + # messages=request.messages, + # stream=request.stream, + # temperature=request.temperature, + # # response_format={ "type": "json_object" }, + # ) + # NOTE: to get response as JSON object check https://github.com/jxnl/instructor or https://github.com/outlines-dev/outlines + + # TODO: litellm + response = completion( model=request.model, - messages=all_messages, + messages=request.messages, stream=request.stream, - temperature=request.temperature, - # response_format={ "type": "json_object" }, ) - # NOTE: to get response as JSON object check https://github.com/jxnl/instructor or https://github.com/outlines-dev/outlines + + # NOTE: the response is similar to OpenAI API, but we add the list of hits and the full prompt used to ask the question + response.docs = [hit.model_dump() for hit in query_hits + docs_hits] + response.full_prompt = prompt if request.stream: return StreamingResponse( - stream_openai(response, query_hits + docs_hits, prompt_with_context), media_type="application/x-ndjson" + stream_response(response, request), media_type="application/x-ndjson" ) + request.messages.append(Message(role="assistant", content=response.choices[0].message.content)) # print(response) # print(response.choices[0].message.content) response: ChatCompletion = ( - validate_and_fix_sparql(response, all_messages, client, request.model) if request.validate_output else response + validate_and_fix_sparql(request, response) if request.validate_output else response ) - # NOTE: the response is similar to OpenAI API, but we add the list of hits and the full prompt used to ask the question - response.docs = query_hits + docs_hits - response.full_prompt = prompt_with_context return response # return { # "id": response.id, @@ -287,22 +305,22 @@ async def chat(request: ChatCompletionRequest): # "model": response.model, # "choices": [{"message": Message(role="assistant", content=response.choices[0].message.content)}], # "docs": query_hits + docs_hits, - # "full_prompt": prompt_with_context, + # "full_prompt": prompt, # "usage": response.usage, # } def validate_and_fix_sparql( - resp: ChatCompletion, messages: list[Message], client: OpenAI, llm_model: str, try_count: int = 0 + request: ChatCompletionRequest, resp: ChatCompletion | None = None, try_count: int = 0 ) -> ChatCompletion: """Recursive function to validate the SPARQL queries in the chat response and fix them if needed.""" - if try_count >= settings.max_try_fix_sparql: resp.choices[ 0 ].message.content = f"{resp.choices[0].message.content}\n\nThe SPARQL query could not be fixed after multiple tries. Please do it yourself!" 
return resp - generated_sparqls = extract_sparql_queries(resp.choices[0].message.content) + # generated_sparqls = extract_sparql_queries(resp.choices[0].message.content) + generated_sparqls = extract_sparql_queries(request.messages[-1].content) # print("generated_sparqls", generated_sparqls) error_detected = False for gen_query in generated_sparqls: @@ -336,10 +354,16 @@ def validate_and_fix_sparql( """ # Which is part of this answer: # {md_resp} - messages.append({"role": "assistant", "content": fix_prompt}) - fixing_resp = client.chat.completions.create( - model=llm_model, - messages=messages, + # messages.append({"role": "assistant", "content": fix_prompt}) + request.messages.append(Message(role="assistant", content=fix_prompt)) + # fixing_resp = client.chat.completions.create( + # model=llm_model, + # messages=messages, + # stream=False, + # ) + fixing_resp = completion( + model=request.model, + messages=request.messages, stream=False, ) # md_resp = response.choices[0].message.content @@ -353,7 +377,7 @@ def validate_and_fix_sparql( ) if error_detected: # Check again the fixed query - return validate_and_fix_sparql(resp, messages, client, llm_model, try_count) + return validate_and_fix_sparql(request, resp, try_count) return resp diff --git a/src/sparql_llm/config.py b/src/sparql_llm/config.py index e508623..f33313c 100644 --- a/src/sparql_llm/config.py +++ b/src/sparql_llm/config.py @@ -58,7 +58,7 @@ # print(" Total tokens:", response.usage.total_tokens) # print(" Completion tokens:", response.usage.completion_tokens) - +# NOTE: still in use by tests, to be replaced with litellm def get_llm_client(model: str) -> OpenAI: if model.startswith("hf:"): # Automatically use glhf API key if the model starts with "hf:" diff --git a/src/sparql_llm/index.py b/src/sparql_llm/index.py index b8e7855..0c5b5bc 100644 --- a/src/sparql_llm/index.py +++ b/src/sparql_llm/index.py @@ -56,8 +56,10 @@ def load_schemaorg_description(endpoint: dict[str, str]) -> list[Document]: metadata={ "question": question, "answer": json_ld_content, + # "answer": f"```json\n{json_ld_content}\n```", + "iri": endpoint["homepage"], "endpoint_url": endpoint["endpoint_url"], - "doc_type": "schemaorg_jsonld", + "doc_type": "General information", }, ) ) @@ -81,7 +83,7 @@ def load_schemaorg_description(endpoint: dict[str, str]) -> list[Document]: "answer": "\n".join(descs), "endpoint_url": endpoint["endpoint_url"], "iri": endpoint["homepage"], - "doc_type": "schemaorg_description", + "doc_type": "General information", }, ) ) @@ -133,7 +135,8 @@ def init_vectordb(vectordb_host: str = settings.vectordb_host) -> None: """Initialize the vectordb with example queries and ontology descriptions from the SPARQL endpoints""" vectordb = get_vectordb(vectordb_host) - # if not vectordb.collection_exists(settings.docs_collection_name): + if vectordb.collection_exists(settings.docs_collection_name): + vectordb.delete_collection(settings.docs_collection_name) vectordb.create_collection( collection_name=settings.docs_collection_name, vectors_config=VectorParams(size=settings.embedding_dimensions, distance=Distance.COSINE), @@ -177,7 +180,7 @@ def init_vectordb(vectordb_host: str = settings.vectordb_host) -> None: """, "endpoint_url": "https://sparql.uniprot.org/sparql/", "iri": "http://www.uniprot.org/help/about", - "doc_type": "schemaorg_description", + "doc_type": "General information", }, ) ) diff --git a/src/sparql_llm/sparql_examples_loader.py b/src/sparql_llm/sparql_examples_loader.py index d9c41da..5d47511 100644 --- 
a/src/sparql_llm/sparql_examples_loader.py +++ b/src/sparql_llm/sparql_examples_loader.py @@ -75,7 +75,7 @@ def _create_document(self, row: Any, prefix_map: dict[str, str]) -> Document: "answer": query, "endpoint_url": self.endpoint_url, "query_type": parsed_query.algebra.name, - "doc_type": "sparql_query", + "doc_type": "SPARQL endpoints query examples", }, ) diff --git a/src/sparql_llm/sparql_void_shapes_loader.py b/src/sparql_llm/sparql_void_shapes_loader.py index fee751b..d8ca1ae 100644 --- a/src/sparql_llm/sparql_void_shapes_loader.py +++ b/src/sparql_llm/sparql_void_shapes_loader.py @@ -41,7 +41,7 @@ def load(self) -> list[Document]: "answer": shex_shape["shex"], "endpoint_url": self.endpoint_url, "iri": cls_uri, - "doc_type": "shex", + "doc_type": "SPARQL endpoints classes schema", } if "label" in shex_shape: docs.append( diff --git a/src/sparql_llm/templates/index.html b/src/sparql_llm/templates/index.html index c1c65b7..dd80035 100644 --- a/src/sparql_llm/templates/index.html +++ b/src/sparql_llm/templates/index.html @@ -225,13 +225,13 @@

        let sourcesHtml = ""
        if (source_documents && source_documents.length > 0) {
          for (const [i, doc] of source_documents.entries()) {
-           if (doc.payload.doc_type === "ontology") {
+           if (doc.payload.doc_type === "Ontology") {
              sourcesMd += `\`${i + 1} - ${Math.round(doc.score*1000)/1000}\` Ontology for ${doc.payload.endpoint_url}\n\n\`\`\`turtle\n${doc.payload.question.trim()}\n\`\`\`\n\n`;
-           } else if (doc.payload.doc_type === "sparql_query") {
+           } else if (doc.payload.doc_type === "SPARQL endpoints query examples") {
              sourcesMd += `\`${i + 1} - ${Math.round(doc.score*1000)/1000}\` ${doc.payload.question} (${doc.payload.endpoint_url})\n\n\`\`\`sparql\n${doc.payload.answer}\n\`\`\`\n\n`;
-           } else if (doc.payload.doc_type === "schemaorg_jsonld") {
+           } else if (doc.payload.doc_type === "General information") {
              sourcesMd += `\`${i + 1} - ${Math.round(doc.score*1000)/1000}\` ${doc.payload.question} (${doc.payload.endpoint_url})\n\n\`\`\`json\n${doc.payload.answer}\n\`\`\`\n\n`;
-           } else if (doc.payload.doc_type === "shex") {
+           } else if (doc.payload.doc_type === "SPARQL endpoints classes schema") {
              sourcesMd += `\`${i + 1} - ${Math.round(doc.score*1000)/1000}\` ${doc.payload.question} (${doc.payload.endpoint_url})\n\n\`\`\`ttl\n${doc.payload.answer}\n\`\`\`\n\n`;
            } else {
              sourcesMd += `\`${i + 1} - ${Math.round(doc.score*1000)/1000}\` ${doc.payload.question} (${doc.payload.endpoint_url})\n\n${doc.payload.answer}\n\n`;
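Both this template script and the chat-with-context component consume the same wire format from POST ${apiUrl}chat when stream is true: a first "data:" line carrying docs and full_prompt, then OpenAI-style delta chunks, terminated by "data: [DONE]". A minimal standalone TypeScript client for that format (field names are taken from the handlers above; the URL handling and callback are illustrative):

// Sketch: consume the /chat streaming response (newline-separated "data: {...}" chunks).
async function streamChat(apiUrl: string, question: string, onDelta: (text: string) => void) {
  const response = await fetch(`${apiUrl}chat`, {
    method: "POST",
    headers: {"Content-Type": "application/json"},
    body: JSON.stringify({messages: [{role: "user", content: question}], model: "gpt-4o", stream: true}),
  });
  const reader = response.body!.getReader();
  const decoder = new TextDecoder("utf-8");
  let buffer = "";
  let docs: unknown[] = [];
  while (true) {
    const {done, value} = await reader.read();
    if (done) break;
    buffer += decoder.decode(value, {stream: true});
    const boundary = buffer.lastIndexOf("\n");
    if (boundary === -1) continue;
    const lines = buffer.slice(0, boundary).split("\n").filter(line => line.trim() !== "");
    buffer = buffer.slice(boundary + 1);
    for (const line of lines) {
      if (line === "data: [DONE]") return docs;
      if (!line.startsWith("data: ")) continue;
      const chunk = JSON.parse(line.slice("data: ".length));
      if (chunk.docs) docs = chunk.docs;                 // first chunk: retrieved docs + full prompt
      const delta = chunk.choices?.[0]?.delta?.content;
      if (delta) onDelta(delta);                         // later chunks: streamed answer tokens
    }
  }
  return docs;
}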