=7.1.3",
+# "pytest-cov >=3.0.0",
+# ]
[project.urls]
Homepage = "https://github.com/sib-swiss/sparql-llm"
@@ -102,6 +107,7 @@ post-install-commands = []
# uv venv
# uv pip install ".[chat,gpu]"
+# uv pip compile pyproject.toml -o requirements.txt
# uv run python src/sparql_llm/embed_entities.py
[tool.hatch.envs.default.scripts]
fmt = [
@@ -122,7 +128,7 @@ cov-check = [
"python -c 'import webbrowser; webbrowser.open(\"http://0.0.0.0:3000\")'",
"python -m http.server 3000 --directory ./htmlcov",
]
-compile = "pip-compile -o requirements.txt pyproject.toml"
+# compile = "pip-compile -o requirements.txt pyproject.toml"
## TOOLS
@@ -211,4 +217,7 @@ ignore = [
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["I", "F401"] # module imported but unused
# Tests can use magic values, assertions, and relative imports:
-"tests/**/*" = ["PLR2004", "S101", "S105", "TID252"]
\ No newline at end of file
+"tests/**/*" = ["PLR2004", "S101", "S105", "TID252"]
+
+# [tool.uv.workspace]
+# members = ["mcp-sparql"]
diff --git a/src/sparql_llm/api.py b/src/sparql_llm/api.py
index 853873a..d780540 100644
--- a/src/sparql_llm/api.py
+++ b/src/sparql_llm/api.py
@@ -9,6 +9,7 @@
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
+from litellm import completion
from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from pydantic import BaseModel
@@ -17,8 +18,8 @@
from rdflib.plugins.sparql.parser import parseQuery
from starlette.middleware.cors import CORSMiddleware
-from sparql_llm.config import get_llm_client, get_prefixes_dict, settings
-from sparql_llm.embed import get_embedding_model, get_vectordb
+from sparql_llm.config import get_prefixes_dict, settings
+from sparql_llm.index import get_embedding_model, get_vectordb
from sparql_llm.utils import get_prefix_converter
from sparql_llm.validate_sparql import add_missing_prefixes, extract_sparql_queries, validate_sparql_with_void
@@ -34,8 +35,10 @@
STARTUP_PROMPT = "Here is a list of reference questions and query answers relevant to the user question that will help you answer the user question accurately:"
INTRO_USER_QUESTION_PROMPT = "The question from the user is:"
-llm_model = "gpt-4o"
+# llm_model = "gpt-4o"
+llm_model = "azure_ai/mistral-large"
# llm_model: str = "gpt-4o-mini"
+
# Models from glhf:
# llm_model: str = "hf:meta-llama/Meta-Llama-3.1-8B-Instruct"
# llm_model: str = "hf:mistralai/Mixtral-8x22B-Instruct-v0.1"
@@ -87,17 +90,24 @@ class ChatCompletionRequest(BaseModel):
api_key: Optional[str] = None
-async def stream_openai(response: Stream[ChatCompletionChunk], docs, full_prompt) -> AsyncGenerator[str, None]:
- """Stream the response from OpenAI"""
+async def stream_response(response: Stream[ChatCompletionChunk], request: ChatCompletionRequest) -> AsyncGenerator[str, None]:
+ """Stream the response from the LLM provider to the client"""
first_chunk = {
- "docs": [hit.dict() for hit in docs],
- "full_prompt": full_prompt,
+ "docs": response.docs,
+ "full_prompt": response.full_prompt,
}
yield f"data: {json.dumps(first_chunk)}\n\n"
+ full_msg = ""
for chunk in response:
if chunk.choices[0].finish_reason:
+            # TODO: Send a message telling the client that validation is running,
+            # then send a "fix" message so the client knows the previous message should be replaced.
+            # The client can still show the old, incorrect message by clicking on a card.
+ # if request.validate_output:
+ # messages.append(Message(role="assistant", content=full_msg))
+ # validate_and_fix_sparql(resp, response, request.model)
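+            # A minimal sketch of how such a "fix" event could be framed over SSE
+            # (the "event" field and fixed_msg below are assumptions, nothing consumes them yet):
+            # yield f"data: {json.dumps({'event': 'fix', 'choices': [{'delta': {'content': fixed_msg}}]})}\n\n"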
break
# print(chunk)
# ChatCompletionChunk(id='chatcmpl-9UxmYAx6E5Y3BXdI7YEVDmbOh9S2X',
@@ -110,6 +120,7 @@ async def stream_openai(response: Stream[ChatCompletionChunk], docs, full_prompt
"model": chunk.model,
"choices": [{"delta": {"content": chunk.choices[0].delta.content}}],
}
+            full_msg += chunk.choices[0].delta.content or ""
yield f"data: {json.dumps(resp_chunk)}\n\n"
@@ -162,16 +173,17 @@ async def chat(request: ChatCompletionRequest):
if settings.expasy_api_key and request.api_key != settings.expasy_api_key:
raise ValueError("Invalid API key")
- client = get_llm_client(request.model)
# print(client.models.list())
+    # Drop any system prompt sent by the client and prepend the server-defined system prompt
+ request.messages = [msg for msg in request.messages if msg.role != "system"]
+ request.messages = [Message(role="system", content=settings.system_prompt), *request.messages]
question: str = request.messages[-1].content if request.messages else ""
-
+ query_embeddings = next(iter(embedding_model.embed([question])))
logging.info(f"User question: {question}")
- query_embeddings = next(iter(embedding_model.embed([question])))
# We build a big prompt with the most relevant queries retrieved from similarity search engine (could be increased)
- prompt_with_context = f"{STARTUP_PROMPT}\n\n"
+ prompt = f"{STARTUP_PROMPT}\n\n"
# # We also provide the example queries as previous messages to the LLM
# system_msg: list[Message] = [{"role": "system", "content": settings.system_prompt}]
@@ -184,15 +196,15 @@ async def chat(request: ChatCompletionRequest):
must=[
FieldCondition(
key="doc_type",
- match=MatchValue(value="sparql_query"),
+ match=MatchValue(value="SPARQL endpoints query examples"),
)
]
),
limit=settings.retrieved_queries_count,
)
for query_hit in query_hits:
- prompt_with_context += f"{query_hit.payload['question']}:\n\n```sparql\n# {query_hit.payload['endpoint_url']}\n{query_hit.payload['answer']}\n```\n\n"
- # prompt_with_context += f"{query_hit.payload['question']}\nQuery to run in SPARQL endpoint {query_hit.payload['endpoint_url']}\n\n{query_hit.payload['answer']}\n\n"
+ prompt += f"{query_hit.payload['question']}:\n\n```sparql\n# {query_hit.payload['endpoint_url']}\n{query_hit.payload['answer']}\n```\n\n"
+ # prompt += f"{query_hit.payload['question']}\nQuery to run in SPARQL endpoint {query_hit.payload['endpoint_url']}\n\n{query_hit.payload['answer']}\n\n"
# 2. Get the most relevant documents other than SPARQL query examples from the search engine (ShEx shapes, general infos)
docs_hits = vectordb.search(
@@ -202,16 +214,16 @@ async def chat(request: ChatCompletionRequest):
should=[
FieldCondition(
key="doc_type",
- match=MatchValue(value="shex"),
+ match=MatchValue(value="SPARQL endpoints classes schema"),
),
FieldCondition(
key="doc_type",
- match=MatchValue(value="schemaorg_description"),
+ match=MatchValue(value="General information"),
),
# NOTE: we don't add ontology documents yet, not clean enough
# FieldCondition(
# key="doc_type",
- # match=MatchValue(value="ontology"),
+ # match=MatchValue(value="Ontology"),
# ),
]
),
@@ -221,64 +233,70 @@ async def chat(request: ChatCompletionRequest):
)
# TODO: vectordb.search_groups(
# https://qdrant.tech/documentation/concepts/search/#search-groups
-
# TODO: hybrid search? https://qdrant.github.io/fastembed/examples/Hybrid_Search/#about-qdrant
# we might want to group by iri for shex docs https://qdrant.tech/documentation/concepts/hybrid-queries/?q=hybrid+se#grouping
# https://qdrant.tech/documentation/concepts/search/#search-groups
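+    # A possible grouped search, sketched only as illustration (not wired in; the group_by
+    # field, limit and group_size below are assumptions):
+    # docs_hits = vectordb.search_groups(
+    #     collection_name=settings.docs_collection_name,
+    #     query_vector=query_embeddings,
+    #     query_filter=...,  # same doc_type filter as above
+    #     group_by="iri",
+    #     limit=10,
+    #     group_size=1,
+    # )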
- prompt_with_context += "Here is some additional information that could be useful to answer the user question:\n\n"
+ prompt += "Here is some additional information that could be useful to answer the user question:\n\n"
# for docs_hit in docs_hits.groups:
for docs_hit in docs_hits:
- if docs_hit.payload["doc_type"] == "shex":
- prompt_with_context += f"ShEx shape for {docs_hit.payload['question']} in {docs_hit.payload['endpoint_url']}:\n```\n{docs_hit.payload['answer']}\n```\n\n"
- # elif docs_hit.payload["doc_type"] == "ontology":
- # prompt_with_context += f"Relevant part of the ontology for {docs_hit.payload['endpoint_url']}:\n```turtle\n{docs_hit.payload['question']}\n```\n\n"
+ if docs_hit.payload["doc_type"] == "SPARQL endpoints classes schema":
+ prompt += f"ShEx shape for {docs_hit.payload['question']} in {docs_hit.payload['endpoint_url']}:\n```\n{docs_hit.payload['answer']}\n```\n\n"
+ # elif docs_hit.payload["doc_type"] == "Ontology":
+ # prompt += f"Relevant part of the ontology for {docs_hit.payload['endpoint_url']}:\n```turtle\n{docs_hit.payload['question']}\n```\n\n"
else:
- prompt_with_context += f"Information about: {docs_hit.payload['question']}\nRelated to SPARQL endpoint {docs_hit.payload['endpoint_url']}\n\n{docs_hit.payload['answer']}\n\n"
-
- # 3. Extract potential entities from the user question
- entities_list = extract_entities(question)
- for entity in entities_list:
- prompt_with_context += f'\n\nEntities found in the user question for "{" ".join(entity["term"])}":\n\n'
- for match in entity["matchs"]:
- prompt_with_context += f"- {match.payload['label']} with IRI <{match.payload['uri']}> in endpoint {match.payload['endpoint_url']}\n\n"
-
- if len(entities_list) == 0:
- prompt_with_context += "\nNo entities found in the user question that matches entities in the endpoints. "
-
- prompt_with_context += "\nIf the user is asking for a named entity, and this entity cannot be found in the endpoint, warn them about the fact we could not find it in the endpoints.\n\n"
-
- prompt_with_context += f"\n{INTRO_USER_QUESTION_PROMPT}\n{question}"
- print(prompt_with_context)
+ prompt += f"Information about: {docs_hit.payload['question']}\nRelated to SPARQL endpoint {docs_hit.payload['endpoint_url']}\n\n{docs_hit.payload['answer']}\n\n"
+
+ # 3. Extract potential entities from the user question (experimental)
+ # entities_list = extract_entities(question)
+ # for entity in entities_list:
+ # prompt += f'\n\nEntities found in the user question for "{" ".join(entity["term"])}":\n\n'
+ # for match in entity["matchs"]:
+ # prompt += f"- {match.payload['label']} with IRI <{match.payload['uri']}> in endpoint {match.payload['endpoint_url']}\n\n"
+ # if len(entities_list) == 0:
+    #     prompt += "\nNo entities found in the user question that match entities in the endpoints. "
+    # prompt += "\nIf the user asks for a named entity that cannot be found in the endpoint, warn them that we could not find it in the endpoints.\n\n"
+
+ # 4. Add the user question to the prompt
+ prompt += f"\n{INTRO_USER_QUESTION_PROMPT}\n{question}"
+ print(prompt)
# Use messages from the request to keep memory of previous messages sent by the client
# Replace the question asked by the user with the big prompt with all contextual infos
- request.messages[-1].content = prompt_with_context
- all_messages = [Message(role="system", content=settings.system_prompt), *request.messages]
+ request.messages[-1].content = prompt
# Send the prompt to OpenAI to get a response
- response = client.chat.completions.create(
+ # response = client.chat.completions.create(
+ # model=request.model,
+ # messages=request.messages,
+ # stream=request.stream,
+ # temperature=request.temperature,
+ # # response_format={ "type": "json_object" },
+ # )
+ # NOTE: to get response as JSON object check https://github.com/jxnl/instructor or https://github.com/outlines-dev/outlines
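+    # A minimal instructor-based sketch for illustration (not used here; the SparqlAnswer
+    # model and its fields are assumptions):
+    # import instructor
+    # from pydantic import BaseModel
+    # class SparqlAnswer(BaseModel):
+    #     explanation: str
+    #     sparql_query: str
+    # structured_client = instructor.from_openai(OpenAI())
+    # structured_resp = structured_client.chat.completions.create(
+    #     model=request.model, messages=request.messages, response_model=SparqlAnswer
+    # )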
+
+    # TODO: finish the migration to litellm (get_llm_client in config.py is still used by the tests)
+ response = completion(
model=request.model,
- messages=all_messages,
+ messages=request.messages,
stream=request.stream,
- temperature=request.temperature,
- # response_format={ "type": "json_object" },
)
- # NOTE: to get response as JSON object check https://github.com/jxnl/instructor or https://github.com/outlines-dev/outlines
+
+    # NOTE: the response is similar to the OpenAI API response, but we also attach the list of retrieved hits and the full prompt used to ask the question
+ response.docs = [hit.model_dump() for hit in query_hits + docs_hits]
+ response.full_prompt = prompt
if request.stream:
return StreamingResponse(
- stream_openai(response, query_hits + docs_hits, prompt_with_context), media_type="application/x-ndjson"
+ stream_response(response, request), media_type="application/x-ndjson"
)
+ request.messages.append(Message(role="assistant", content=response.choices[0].message.content))
# print(response)
# print(response.choices[0].message.content)
response: ChatCompletion = (
- validate_and_fix_sparql(response, all_messages, client, request.model) if request.validate_output else response
+ validate_and_fix_sparql(request, response) if request.validate_output else response
)
- # NOTE: the response is similar to OpenAI API, but we add the list of hits and the full prompt used to ask the question
- response.docs = query_hits + docs_hits
- response.full_prompt = prompt_with_context
return response
# return {
# "id": response.id,
@@ -287,22 +305,22 @@ async def chat(request: ChatCompletionRequest):
# "model": response.model,
# "choices": [{"message": Message(role="assistant", content=response.choices[0].message.content)}],
# "docs": query_hits + docs_hits,
- # "full_prompt": prompt_with_context,
+ # "full_prompt": prompt,
# "usage": response.usage,
# }
def validate_and_fix_sparql(
- resp: ChatCompletion, messages: list[Message], client: OpenAI, llm_model: str, try_count: int = 0
+    request: ChatCompletionRequest, resp: ChatCompletion, try_count: int = 0
) -> ChatCompletion:
"""Recursive function to validate the SPARQL queries in the chat response and fix them if needed."""
-
if try_count >= settings.max_try_fix_sparql:
resp.choices[
0
].message.content = f"{resp.choices[0].message.content}\n\nThe SPARQL query could not be fixed after multiple tries. Please do it yourself!"
return resp
- generated_sparqls = extract_sparql_queries(resp.choices[0].message.content)
+ # generated_sparqls = extract_sparql_queries(resp.choices[0].message.content)
+ generated_sparqls = extract_sparql_queries(request.messages[-1].content)
# print("generated_sparqls", generated_sparqls)
error_detected = False
for gen_query in generated_sparqls:
@@ -336,10 +354,16 @@ def validate_and_fix_sparql(
"""
# Which is part of this answer:
# {md_resp}
- messages.append({"role": "assistant", "content": fix_prompt})
- fixing_resp = client.chat.completions.create(
- model=llm_model,
- messages=messages,
+ # messages.append({"role": "assistant", "content": fix_prompt})
+ request.messages.append(Message(role="assistant", content=fix_prompt))
+ # fixing_resp = client.chat.completions.create(
+ # model=llm_model,
+ # messages=messages,
+ # stream=False,
+ # )
+ fixing_resp = completion(
+ model=request.model,
+ messages=request.messages,
stream=False,
)
# md_resp = response.choices[0].message.content
@@ -353,7 +377,7 @@ def validate_and_fix_sparql(
)
if error_detected:
# Check again the fixed query
- return validate_and_fix_sparql(resp, messages, client, llm_model, try_count)
+ return validate_and_fix_sparql(request, resp, try_count)
return resp
diff --git a/src/sparql_llm/config.py b/src/sparql_llm/config.py
index e508623..f33313c 100644
--- a/src/sparql_llm/config.py
+++ b/src/sparql_llm/config.py
@@ -58,7 +58,7 @@
# print(" Total tokens:", response.usage.total_tokens)
# print(" Completion tokens:", response.usage.completion_tokens)
-
+# NOTE: still used by the tests; to be replaced with litellm
def get_llm_client(model: str) -> OpenAI:
if model.startswith("hf:"):
# Automatically use glhf API key if the model starts with "hf:"
diff --git a/src/sparql_llm/index.py b/src/sparql_llm/index.py
index b8e7855..0c5b5bc 100644
--- a/src/sparql_llm/index.py
+++ b/src/sparql_llm/index.py
@@ -56,8 +56,10 @@ def load_schemaorg_description(endpoint: dict[str, str]) -> list[Document]:
metadata={
"question": question,
"answer": json_ld_content,
+ # "answer": f"```json\n{json_ld_content}\n```",
+ "iri": endpoint["homepage"],
"endpoint_url": endpoint["endpoint_url"],
- "doc_type": "schemaorg_jsonld",
+ "doc_type": "General information",
},
)
)
@@ -81,7 +83,7 @@ def load_schemaorg_description(endpoint: dict[str, str]) -> list[Document]:
"answer": "\n".join(descs),
"endpoint_url": endpoint["endpoint_url"],
"iri": endpoint["homepage"],
- "doc_type": "schemaorg_description",
+ "doc_type": "General information",
},
)
)
@@ -133,7 +135,8 @@ def init_vectordb(vectordb_host: str = settings.vectordb_host) -> None:
"""Initialize the vectordb with example queries and ontology descriptions from the SPARQL endpoints"""
vectordb = get_vectordb(vectordb_host)
- # if not vectordb.collection_exists(settings.docs_collection_name):
+ if vectordb.collection_exists(settings.docs_collection_name):
+ vectordb.delete_collection(settings.docs_collection_name)
vectordb.create_collection(
collection_name=settings.docs_collection_name,
vectors_config=VectorParams(size=settings.embedding_dimensions, distance=Distance.COSINE),
@@ -177,7 +180,7 @@ def init_vectordb(vectordb_host: str = settings.vectordb_host) -> None:
""",
"endpoint_url": "https://sparql.uniprot.org/sparql/",
"iri": "http://www.uniprot.org/help/about",
- "doc_type": "schemaorg_description",
+ "doc_type": "General information",
},
)
)
diff --git a/src/sparql_llm/sparql_examples_loader.py b/src/sparql_llm/sparql_examples_loader.py
index d9c41da..5d47511 100644
--- a/src/sparql_llm/sparql_examples_loader.py
+++ b/src/sparql_llm/sparql_examples_loader.py
@@ -75,7 +75,7 @@ def _create_document(self, row: Any, prefix_map: dict[str, str]) -> Document:
"answer": query,
"endpoint_url": self.endpoint_url,
"query_type": parsed_query.algebra.name,
- "doc_type": "sparql_query",
+ "doc_type": "SPARQL endpoints query examples",
},
)
diff --git a/src/sparql_llm/sparql_void_shapes_loader.py b/src/sparql_llm/sparql_void_shapes_loader.py
index fee751b..d8ca1ae 100644
--- a/src/sparql_llm/sparql_void_shapes_loader.py
+++ b/src/sparql_llm/sparql_void_shapes_loader.py
@@ -41,7 +41,7 @@ def load(self) -> list[Document]:
"answer": shex_shape["shex"],
"endpoint_url": self.endpoint_url,
"iri": cls_uri,
- "doc_type": "shex",
+ "doc_type": "SPARQL endpoints classes schema",
}
if "label" in shex_shape:
docs.append(
diff --git a/src/sparql_llm/templates/index.html b/src/sparql_llm/templates/index.html
index c1c65b7..dd80035 100644
--- a/src/sparql_llm/templates/index.html
+++ b/src/sparql_llm/templates/index.html
@@ -225,13 +225,13 @@