diff --git a/packages/sparql-llm/src/sparql_llm/sparql_info_loader.py b/packages/sparql-llm/src/sparql_llm/sparql_info_loader.py
index 467ceb9..9c03a33 100644
--- a/packages/sparql-llm/src/sparql_llm/sparql_info_loader.py
+++ b/packages/sparql-llm/src/sparql_llm/sparql_info_loader.py
@@ -21,7 +21,7 @@ def load(self) -> list[Document]:
"""Load and return documents from the SPARQL endpoint."""
docs: list[Document] = []
- resources_summary_question = "Which resources are available through this system?"
+ resources_summary_question = "Which resources do you support?"
metadata = {
"question": resources_summary_question,
"answer": f"This system helps to access the following SPARQL endpoints {self.org_label}:\n- "
diff --git a/tutorial/app.py b/tutorial/app.py
index 0dc8266..0155385 100644
--- a/tutorial/app.py
+++ b/tutorial/app.py
@@ -1,16 +1,7 @@
-from typing import Literal
-from langchain_qdrant import QdrantVectorStore
-from langchain_community.embeddings import FastEmbedEmbeddings
-from langchain_core.documents import Document
from langchain_core.language_models import BaseChatModel
-from langchain_core.prompts import ChatPromptTemplate
-from langgraph.graph import StateGraph
-from langgraph.graph.message import MessagesState
-from qdrant_client.models import FieldCondition, Filter, MatchValue
+from qdrant_client.models import FieldCondition, Filter, MatchValue, ScoredPoint
import chainlit as cl
-# https://docs.chainlit.io/integrations/langchain
-
def load_chat_model(model: str) -> BaseChatModel:
provider, model_name = model.split("/", maxsplit=1)
@@ -20,6 +11,7 @@ def load_chat_model(model: str) -> BaseChatModel:
return ChatGroq(
model_name=model_name,
temperature=0,
+
)
if provider == "openai":
# https://python.langchain.com/docs/integrations/chat/openai/
@@ -38,49 +30,37 @@ def load_chat_model(model: str) -> BaseChatModel:
raise ValueError(f"Unknown provider: {provider}")
# llm = load_chat_model("groq/llama-3.3-70b-versatile")
-llm = load_chat_model("openai/gpt-4o-mini")
-
+# llm = load_chat_model("openai/gpt-4o-mini")
+llm = load_chat_model("ollama/mistral")
-vectordb = QdrantVectorStore.from_existing_collection(
- # path="data/qdrant",
- host="localhost",
- prefer_grpc=True,
- collection_name="sparql-docs",
- embedding=FastEmbedEmbeddings(model_name="BAAI/bge-small-en-v1.5"),
-)
-retriever = vectordb.as_retriever()
-
-class AgentState(MessagesState):
- """State of the agent available inside each node."""
- relevant_docs: str
- passed_validation: bool
- try_count: int
+from index import vectordb, embedding_model, collection_name
retrieved_docs_count = 3
-async def retrieve_docs(state: AgentState) -> dict[str, str]:
+async def retrieve_docs(question: str) -> str:
"""Retrieve documents relevant to the user's question."""
- last_msg = state["messages"][-1]
- retriever = vectordb.as_retriever()
- retrieved_docs = retriever.invoke(
- last_msg.content,
- k=retrieved_docs_count,
- filter=Filter(
+ question_embeddings = next(iter(embedding_model.embed([question])))
+ retrieved_docs = vectordb.search(
+ collection_name=collection_name,
+ query_vector=question_embeddings,
+ limit=retrieved_docs_count,
+ query_filter=Filter(
must=[
FieldCondition(
- key="metadata.doc_type",
+ key="doc_type",
match=MatchValue(value="SPARQL endpoints query examples"),
)
]
- )
+ ),
)
- retrieved_docs += retriever.invoke(
- last_msg.content,
- k=retrieved_docs_count,
- filter=Filter(
+ retrieved_docs += vectordb.search(
+ collection_name=collection_name,
+ query_vector=question_embeddings,
+ limit=retrieved_docs_count,
+ query_filter=Filter(
must_not=[
FieldCondition(
- key="metadata.doc_type",
+ key="doc_type",
match=MatchValue(value="SPARQL endpoints query examples"),
)
]
@@ -89,18 +69,18 @@ async def retrieve_docs(state: AgentState) -> dict[str, str]:
relevant_docs = f"\n{'\n'.join(_format_doc(doc) for doc in retrieved_docs)}\n"
async with cl.Step(name=f"{len(retrieved_docs)} relevant documents ποΈ") as step:
step.output = relevant_docs
- return {"relevant_docs": relevant_docs}
+ return relevant_docs
-def _format_doc(doc: Document) -> str:
+def _format_doc(doc: ScoredPoint) -> str:
"""Format a question/answer document to be provided as context to the model."""
doc_lang = (
- "sparql" if "query" in doc.metadata.get("doc_type", "")
- else "shex" if "schema" in doc.metadata.get("doc_type", "")
+ "sparql" if "query" in doc.payload.get("doc_type", "")
+ else "shex" if "schema" in doc.payload.get("doc_type", "")
else ""
)
- return f"\n{doc.page_content} ({doc.metadata.get('endpoint_url', '')}):\n\n```{doc_lang}\n{doc.metadata.get('answer')}\n```\n"
- # # Default formatting
+ return f"\n{doc.payload['question']} ({doc.payload.get('endpoint_url', '')}):\n\n```{doc_lang}\n{doc.payload.get('answer')}\n```\n"
+ # # Generic formatting:
# meta = "".join(f" {k}={v!r}" for k, v in doc.metadata.items())
# if meta:
# meta = f" {meta}"
@@ -115,21 +95,6 @@ def _format_doc(doc: Document) -> str:
Here is a list of documents (reference questions and query answers, classes schema) relevant to the user question that will help you answer the user question accurately:
{relevant_docs}
"""
-prompt_template = ChatPromptTemplate.from_messages([
- ("system", SYSTEM_PROMPT),
- ("placeholder", "{messages}"),
-])
-
-def call_model(state: AgentState):
- """Call the model with the retrieved documents as context."""
- prompt_with_context = prompt_template.invoke({
- "messages": state["messages"],
- "relevant_docs": state['relevant_docs'],
- })
- response = llm.invoke(prompt_with_context)
- # # Fix id of response to use the same as the rest of the messages
- # response.id = state["messages"][-1].id
- return {"messages": [response]}
import logging
@@ -144,101 +109,73 @@ def call_model(state: AgentState):
from sparql_llm import validate_sparql_in_msg
-async def validate_output(state: AgentState) -> dict[str, bool | list[tuple[str, str]] | int]:
- """Node to validate the output of a LLM call, e.g. SPARQL queries generated."""
- recall_messages = []
- validation_outputs = validate_sparql_in_msg(state["messages"][-1].content, prefixes_map, endpoints_void_dict)
+async def validate_output(last_msg: str) -> str | None:
+ """Validate the output of a LLM call, e.g. SPARQL queries generated."""
+ validation_outputs = validate_sparql_in_msg(last_msg, prefixes_map, endpoints_void_dict)
for validation_output in validation_outputs:
if validation_output["fixed_query"]:
async with cl.Step(name="missing prefixes correction β
") as step:
step.output = f"Missing prefixes added to the generated query:\n```sparql\n{validation_output['fixed_query']}\n```"
if validation_output["errors"]:
- # errors_str = "- " + "\n- ".join(validation_output["errors"])
recall_msg = f"""Fix the SPARQL query helping yourself with the error message and context from previous messages in a way that it is a fully valid query.\n
### Error messages:\n- {'\n- '.join(validation_output['errors'])}\n
-### Erroneous SPARQL query\n```sparql\n{validation_output['original_query']}\n```"""
- # print(error_str, flush=True)
+### Erroneous SPARQL query\n```sparql\n{validation_output.get('fixed_query', validation_output['original_query'])}\n```"""
async with cl.Step(name=f"SPARQL query validation, got {len(validation_output['errors'])} errors to fix π") as step:
step.output = recall_msg
- # Add a new message to ask the model to fix the error
- recall_messages.append(("human", recall_msg))
- return {
- "messages": recall_messages,
- "try_count": state.get("try_count", 0) + 1,
- "passed_validation": not recall_messages,
- }
-
-
-
-
-max_try_fix_sparql = 3
-def route_model_output(state: AgentState) -> Literal["__end__", "call_model"]:
- """Determine the next node based on the model's output."""
- if state["try_count"] > max_try_fix_sparql:
- return "__end__"
-
- # # Check for tool calls first
- # if isinstance(last_msg, AIMessage) and state["messages"][-1].tool_calls:
- # return "tools"
-
- # If validation failed, we need to call the model again
- if not state["passed_validation"]:
- return "call_model"
- return "__end__"
+ return recall_msg
-# Define the LangGraph graph
-builder = StateGraph(AgentState)
-
-builder.add_node(retrieve_docs)
-builder.add_node(call_model)
-builder.add_node(validate_output)
-
-builder.add_edge("__start__", "retrieve_docs")
-builder.add_edge("retrieve_docs", "call_model")
-builder.add_edge("call_model", "validate_output")
-# Add a conditional edge to determine the next step after `validate_output`
-builder.add_conditional_edges("validate_output", route_model_output)
-
-graph = builder.compile()
+# Setup chainlit web UI
+import chainlit as cl
+max_try_count = 3
-# Setup chainlit web UI
@cl.on_message
async def on_message(msg: cl.Message):
- # config = {"configurable": {"thread_id": cl.context.session.id}}
- # cb = cl.LangchainCallbackHandler()
- print(cl.chat_context.to_openai())
- answer = cl.Message(content="")
- async for msg, metadata in graph.astream(
- # {"messages": [HumanMessage(content=msg.content)]},
- # {"messages": [("human", msg.content)]},
- {"messages": cl.chat_context.to_openai()},
- stream_mode="messages",
- # config=RunnableConfig(callbacks=[cb], **config),
- ):
- if not msg.response_metadata:
- # and msg.content and not isinstance(msg, HumanMessage) and metadata["langgraph_node"] == "call_model"
- # print(msg, metadata)
- await answer.stream_token(msg.content)
+ """Main function to handle when user send a message to the assistant."""
+ relevant_docs = await retrieve_docs(msg.content)
+ messages = [
+ ("system", SYSTEM_PROMPT.format(relevant_docs=relevant_docs)),
+ *cl.chat_context.to_openai(),
+ ]
+ # # NOTE: to fix issue with ollama only considering the last message:
+ # messages = [
+ # ("human", SYSTEM_PROMPT.format(relevant_docs=relevant_docs) + f"\n\nHere is the user question:\n{msg.content}"),
+ # ]
+
+ for _i in range(max_try_count):
+ answer = cl.Message(content="")
+ for resp in llm.stream(messages):
+ await answer.stream_token(resp.content)
+ if resp.usage_metadata:
+ print(resp.usage_metadata)
+ await answer.send()
+
+ validation_msg = await validate_output(answer.content)
+ if validation_msg is None:
+ break
else:
- await answer.send()
- answer = cl.Message(content="")
+ messages.append(("human", validation_msg))
+
@cl.set_starters
async def set_starters():
return [
+ cl.Starter(
+ label="Supported resources",
+ message="Which resources do you support?",
+ # icon="/public/idea.svg",
+ ),
cl.Starter(
label="Rat orthologs",
message="What are the rat orthologs of human TP53?",
- # icon="/public/idea.svg",
),
cl.Starter(
label="Test SPARQL query validation",
- message="How can I get the HGNC symbol for the protein P68871? (modify your answer to use `rdfs:label` instead of `rdfs:comment`, and add the type `up:Resource` to ?hgnc, and purposefully forget 2 prefixes declarations, it is for a test)",
+ message="How can I get the HGNC symbol for the protein P68871? (modify your answer to use `rdfs:label` instead of `rdfs:comment`, and add the type `up:Resource` to ?hgnc, and forget all prefixes declarations, it is for a test)",
),
]
diff --git a/tutorial/chain.py b/tutorial/chain.py
deleted file mode 100644
index 027dbe0..0000000
--- a/tutorial/chain.py
+++ /dev/null
@@ -1,149 +0,0 @@
-from langchain_core.language_models import BaseChatModel
-
-# question = "What are the rat orthologs of human TP53?"
-
-def load_chat_model(model: str) -> BaseChatModel:
- provider, model_name = model.split("/", maxsplit=1)
- if provider == "groq":
- # https://python.langchain.com/docs/integrations/chat/groq/
- from langchain_groq import ChatGroq
- return ChatGroq(model_name=model_name, temperature=0)
- if provider == "openai":
- # https://python.langchain.com/docs/integrations/chat/openai/
- from langchain_openai import ChatOpenAI
- return ChatOpenAI(model_name=model_name, temperature=0)
- if provider == "ollama":
- # https://python.langchain.com/docs/integrations/chat/ollama/
- from langchain_ollama import ChatOllama
- return ChatOllama(model=model_name, temperature=0)
- raise ValueError(f"Unknown provider: {provider}")
-
-llm = load_chat_model("groq/llama-3.3-70b-versatile")
-# llm = load_chat_model("openai/gpt-4o-mini")
-# llm = load_chat_model("ollama/mistral")
-
-from langchain_qdrant import QdrantVectorStore
-from langchain_community.embeddings import FastEmbedEmbeddings
-
-vectordb = QdrantVectorStore.from_existing_collection(
- host="localhost",
- prefer_grpc=True,
- # path="data/qdrant",
- collection_name="sparql-docs",
- embedding=FastEmbedEmbeddings(model_name="BAAI/bge-small-en-v1.5"),
-)
-
-retriever = vectordb.as_retriever()
-docs_retrieved_count = 5
-
-# retrieved_docs = retriever.invoke(question, k=docs_retrieved_count)
-
-from qdrant_client.models import FieldCondition, Filter, MatchValue
-from langchain_core.documents import Document
-
-def retrieve_docs(question: str) -> str:
- retrieved_docs = retriever.invoke(
- question,
- k=docs_retrieved_count,
- filter=Filter(
- must=[
- FieldCondition(
- key="metadata.doc_type",
- match=MatchValue(value="SPARQL endpoints query examples"),
- )
- ]
- )
- )
- retrieved_docs += retriever.invoke(
- question,
- k=docs_retrieved_count,
- filter=Filter(
- must_not=[
- FieldCondition(
- key="metadata.doc_type",
- match=MatchValue(value="SPARQL endpoints query examples"),
- )
- ]
- ),
- )
- return f"\n{'\n'.join(_format_doc(doc) for doc in retrieved_docs)}\n"
-
-# relevant_docs = "\n".join(doc.page_content + "\n" + doc.metadata.get("answer") for doc in retrieved_docs)
-# relevant_docs = retrieve_docs(question)
-
-# print(f"ποΈ Retrieved {len(retrieved_docs)} documents")
-# # print(retrieved_docs)
-# for doc in retrieved_docs:
-# print(f"{doc.metadata.get('doc_type')} - {doc.metadata.get('endpoint_url')} - {doc.page_content}")
-
-
-from langchain_core.prompts import ChatPromptTemplate
-
-def _format_doc(doc: Document) -> str:
- """Format a question/answer document to be provided as context to the model."""
- doc_lang = (
- "sparql" if "query" in doc.metadata.get("doc_type", "")
- else "shex" if "schema" in doc.metadata.get("doc_type", "")
- else ""
- )
- return f"\n{doc.page_content} ({doc.metadata.get('endpoint_url')}):\n\n```{doc_lang}\n{doc.metadata.get('answer')}\n```\n"
-
-
-SYSTEM_PROMPT = """You are an assistant that helps users to write SPARQL queries.
-Put the SPARQL query inside a markdown codeblock with the "sparql" language tag, and always add the URL of the endpoint on which the query should be executed in a comment at the start of the query inside the codeblocks.
-Use the queries examples and classes shapes provided in the prompt to derive your answer, don't try to create a query from nothing and do not provide a generic query.
-Try to always answer with one query, if the answer lies in different endpoints, provide a federated query.
-And briefly explain the query.
-Here is a list of documents (reference questions and query answers, classes schema) relevant to the user question that will help you answer the user question accurately:
-{relevant_docs}
-"""
-prompt_template = ChatPromptTemplate.from_messages([
- ("system", SYSTEM_PROMPT),
- ("placeholder", "{messages}"),
-])
-
-# prompt_with_context = prompt_template.invoke({
-# "messages": [("human", question)],
-# "relevant_docs": relevant_docs,
-# })
-
-# print(str("\n".join(prompt_with_context.messages)))
-
-# resp = llm.invoke("What are the rat orthologs of human TP53?")
-# print(resp)
-
-# for msg in llm.stream(prompt_with_context):
-# print(msg.content, end='')
-
-
-import chainlit as cl
-
-@cl.on_message
-async def on_message(msg: cl.Message):
- relevant_docs = retrieve_docs(msg.content)
- async with cl.Step(name="relevant documents") as step:
- # step.input = msg.content
- step.output = relevant_docs
-
- prompt_with_context = prompt_template.invoke({
- # "messages": [("human", msg.content)],
- "messages": cl.chat_context.to_openai(),
- "relevant_docs": relevant_docs,
- })
- answer = cl.Message(content="")
- for resp in llm.stream(prompt_with_context):
- await answer.stream_token(resp.content)
- await answer.send()
-
-
-@cl.set_starters
-async def set_starters():
- return [
- cl.Starter(
- label="Rat orthologs",
- message="What are the rat orthologs of human TP53?",
- # icon="/public/idea.svg",
- ),
- ]
-
-# uv run chainlit run simple.py
diff --git a/tutorial/graph.py b/tutorial/graph.py
new file mode 100644
index 0000000..6cd1030
--- /dev/null
+++ b/tutorial/graph.py
@@ -0,0 +1,255 @@
+from typing import Literal
+from langchain_qdrant import QdrantVectorStore
+from langchain_community.embeddings import FastEmbedEmbeddings
+from langchain_core.documents import Document
+from langchain_core.language_models import BaseChatModel
+from langgraph.graph import StateGraph
+from langgraph.graph.message import MessagesState
+from qdrant_client.models import FieldCondition, Filter, MatchValue
+import chainlit as cl
+
+
+def load_chat_model(model: str) -> BaseChatModel:
+ provider, model_name = model.split("/", maxsplit=1)
+ if provider == "groq":
+ # https://python.langchain.com/docs/integrations/chat/groq/
+ from langchain_groq import ChatGroq
+ return ChatGroq(
+ model_name=model_name,
+ temperature=0,
+ )
+ if provider == "openai":
+ # https://python.langchain.com/docs/integrations/chat/openai/
+ from langchain_openai import ChatOpenAI
+ return ChatOpenAI(
+ model_name=model_name,
+ temperature=0,
+ )
+ if provider == "ollama":
+ # https://python.langchain.com/docs/integrations/chat/ollama/
+ from langchain_ollama import ChatOllama
+ return ChatOllama(
+ model=model_name,
+ temperature=0,
+ )
+ raise ValueError(f"Unknown provider: {provider}")
+
+llm = load_chat_model("groq/llama-3.3-70b-versatile")
+# llm = load_chat_model("openai/gpt-4o-mini")
+# llm = load_chat_model("ollama/mistral")
+
+
+vectordb = QdrantVectorStore.from_existing_collection(
+ # path="data/qdrant",
+ host="localhost",
+ prefer_grpc=True,
+ collection_name="sparql-docs",
+ embedding=FastEmbedEmbeddings(model_name="BAAI/bge-small-en-v1.5"),
+)
+retriever = vectordb.as_retriever()
+
+
+class AgentState(MessagesState):
+ """State of the agent available inside each node."""
+ relevant_docs: str
+ passed_validation: bool
+ try_count: int
+
+retrieved_docs_count = 3
+async def retrieve_docs(state: AgentState) -> dict[str, str]:
+ """Retrieve documents relevant to the user's question."""
+ last_msg = state["messages"][-1]
+ retriever = vectordb.as_retriever()
+ retrieved_docs = retriever.invoke(
+ last_msg.content,
+ k=retrieved_docs_count,
+ filter=Filter(
+ must=[
+ FieldCondition(
+ key="metadata.doc_type",
+ match=MatchValue(value="SPARQL endpoints query examples"),
+ )
+ ]
+ )
+ )
+ retrieved_docs += retriever.invoke(
+ last_msg.content,
+ k=retrieved_docs_count,
+ filter=Filter(
+ must_not=[
+ FieldCondition(
+ key="metadata.doc_type",
+ match=MatchValue(value="SPARQL endpoints query examples"),
+ )
+ ]
+ ),
+ )
+ relevant_docs = f"\n{'\n'.join(_format_doc(doc) for doc in retrieved_docs)}\n"
+ async with cl.Step(name=f"{len(retrieved_docs)} relevant documents ποΈ") as step:
+ step.output = relevant_docs
+ return {"relevant_docs": relevant_docs}
+
+
+def _format_doc(doc: Document) -> str:
+ """Format a question/answer document to be provided as context to the model."""
+ doc_lang = (
+ "sparql" if "query" in doc.metadata.get("doc_type", "")
+ else "shex" if "schema" in doc.metadata.get("doc_type", "")
+ else ""
+ )
+ return f"\n{doc.page_content} ({doc.metadata.get('endpoint_url', '')}):\n\n```{doc_lang}\n{doc.metadata.get('answer')}\n```\n"
+ # # Default formatting
+ # meta = "".join(f" {k}={v!r}" for k, v in doc.metadata.items())
+ # if meta:
+ # meta = f" {meta}"
+ # return f"\n{doc.page_content}\n"
+
+
+SYSTEM_PROMPT = """You are an assistant that helps users to write SPARQL queries.
+Put the SPARQL query inside a markdown codeblock with the "sparql" language tag, and always add the URL of the endpoint on which the query should be executed in a comment at the start of the query inside the codeblocks.
+Use the queries examples and classes shapes provided in the prompt to derive your answer, don't try to create a query from nothing and do not provide a generic query.
+Try to always answer with one query, if the answer lies in different endpoints, provide a federated query.
+And briefly explain the query.
+Here is a list of documents (reference questions and query answers, classes schema) relevant to the user question that will help you answer the user question accurately:
+{relevant_docs}
+"""
+
+def call_model(state: AgentState):
+ """Call the model with the retrieved documents as context."""
+ response = llm.invoke([
+ ("system", SYSTEM_PROMPT.format(relevant_docs=state["relevant_docs"])),
+ *state["messages"],
+ ])
+ # NOTE: to fix issue with ollama ignoring system messages
+ # state["messages"][-1].content = SYSTEM_PROMPT.replace("{relevant_docs}", state['relevant_docs']) + "\n\nHere is the user question:\n" + state["messages"][-1].content
+ # response = llm.invoke(state["messages"])
+
+ # # Fix id of response to use the same as the rest of the messages
+ # response.id = state["messages"][-1].id
+ return {"messages": [response]}
+
+
+import logging
+from sparql_llm.utils import get_prefixes_and_schema_for_endpoints
+from index import endpoints
+
+logging.getLogger("httpx").setLevel(logging.WARNING)
+logging.info("Initializing endpoints metadata...")
+# Retrieve the prefixes map and initialize VoID schema dictionary from the indexed endpoints
+prefixes_map, endpoints_void_dict = get_prefixes_and_schema_for_endpoints(endpoints)
+
+
+from sparql_llm import validate_sparql_in_msg
+from langchain_core.messages import AIMessage
+
+async def validate_output(state: AgentState) -> dict[str, bool | list[tuple[str, str]] | int]:
+ """Node to validate the output of a LLM call, e.g. SPARQL queries generated."""
+ recall_messages = []
+ print(state["messages"])
+ last_msg = next(msg.content for msg in reversed(state["messages"]) if isinstance(msg, AIMessage) and msg.content)
+ print(last_msg)
+ # last_msg = state["messages"][-1].content
+ validation_outputs = validate_sparql_in_msg(last_msg, prefixes_map, endpoints_void_dict)
+ for validation_output in validation_outputs:
+ if validation_output["fixed_query"]:
+ async with cl.Step(name="missing prefixes correction β
") as step:
+ step.output = f"Missing prefixes added to the generated query:\n```sparql\n{validation_output['fixed_query']}\n```"
+ if validation_output["errors"]:
+ # errors_str = "- " + "\n- ".join(validation_output["errors"])
+ recall_msg = f"""Fix the SPARQL query helping yourself with the error message and context from previous messages in a way that it is a fully valid query.\n
+### Error messages:\n- {'\n- '.join(validation_output['errors'])}\n
+### Erroneous SPARQL query\n```sparql\n{validation_output.get('fixed_query', validation_output['original_query'])}\n```"""
+ # print(error_str, flush=True)
+ async with cl.Step(name=f"SPARQL query validation, got {len(validation_output['errors'])} errors to fix π") as step:
+ step.output = recall_msg
+ # Add a new message to ask the model to fix the error
+ recall_messages.append(("human", recall_msg))
+ return {
+ "messages": recall_messages,
+ "try_count": state.get("try_count", 0) + 1,
+ "passed_validation": not recall_messages,
+ }
+
+
+
+
+max_try_count = 3
+def route_model_output(state: AgentState) -> Literal["__end__", "call_model"]:
+ """Determine the next node based on the model's output."""
+ if state["try_count"] > max_try_count:
+ return "__end__"
+
+ # # Check for tool calls first
+ # if isinstance(last_msg, AIMessage) and state["messages"][-1].tool_calls:
+ # return "tools"
+
+ # If validation failed, we need to call the model again
+ if not state["passed_validation"]:
+ return "call_model"
+ return "__end__"
+
+
+
+
+# Define the LangGraph graph
+builder = StateGraph(AgentState)
+
+builder.add_node(retrieve_docs)
+builder.add_node(call_model)
+builder.add_node(validate_output)
+
+builder.add_edge("__start__", "retrieve_docs")
+builder.add_edge("retrieve_docs", "call_model")
+builder.add_edge("call_model", "validate_output")
+# Add a conditional edge to determine the next step after `validate_output`
+builder.add_conditional_edges("validate_output", route_model_output)
+
+graph = builder.compile()
+
+
+# from langchain_core.runnables.graph import MermaidDrawMethod
+# with open('data/sparql_workflow.png', 'wb') as f:
+# f.write(graph.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.API))
+
+
+# Setup chainlit web UI
+# https://docs.chainlit.io/integrations/langchain
+
+@cl.on_message
+async def on_message(msg: cl.Message):
+ # config = {"configurable": {"thread_id": cl.context.session.id}}
+ # cb = cl.LangchainCallbackHandler()
+ answer = cl.Message(content="")
+ async for msg, metadata in graph.astream(
+ {"messages": cl.chat_context.to_openai()},
+ stream_mode="messages",
+ # config=RunnableConfig(callbacks=[cb], **config),
+ ):
+ if not msg.response_metadata:
+ await answer.stream_token(msg.content)
+ else:
+ await answer.send()
+ print(msg.usage_metadata)
+ # print(metadata)
+ answer = cl.Message(content="")
+
+
+@cl.set_starters
+async def set_starters():
+ return [
+ cl.Starter(
+ label="Supported resources",
+ message="Which resources do you support?",
+ # icon="/public/idea.svg",
+ ),
+ cl.Starter(
+ label="Rat orthologs",
+ message="What are the rat orthologs of human TP53?",
+ ),
+ cl.Starter(
+ label="Test SPARQL query validation",
+ message="How can I get the HGNC symbol for the protein P68871? (modify your answer to use `rdfs:label` instead of `rdfs:comment`, and add the type `up:Resource` to ?hgnc, and forget all prefixes declarations, it is for a test)",
+ ),
+ ]
+
+# uv run chainlit run graph.py
diff --git a/tutorial/index.py b/tutorial/index.py
index 36578e0..ef67fd5 100644
--- a/tutorial/index.py
+++ b/tutorial/index.py
@@ -1,5 +1,3 @@
-from langchain_qdrant import QdrantVectorStore
-from langchain_community.embeddings import FastEmbedEmbeddings
from langchain_core.documents import Document
from sparql_llm import SparqlExamplesLoader, SparqlVoidShapesLoader, SparqlInfoLoader
@@ -24,6 +22,18 @@
]
+from fastembed import TextEmbedding
+from qdrant_client import QdrantClient
+from qdrant_client.http.models import Distance, VectorParams
+
+embedding_model = TextEmbedding(
+ "BAAI/bge-small-en-v1.5",
+ # providers=["CUDAExecutionProvider"], # Replace the fastembed dependency with fastembed-gpu to use your GPUs
+)
+embedding_dimensions = 384
+vectordb = QdrantClient(host="localhost", prefer_grpc=True)
+collection_name = "sparql-docs"
+
def index_endpoints():
# Get documents from the SPARQL endpoints
docs: list[Document] = []
@@ -41,20 +51,38 @@ def index_endpoints():
).load()
docs += SparqlInfoLoader(endpoints, source_iri="https://www.expasy.org/").load()
- # os.makedirs('data', exist_ok=True)
-
- QdrantVectorStore.from_documents(
- docs,
- # path="data/qdrant",
- host="localhost",
- prefer_grpc=True,
- collection_name="sparql-docs",
- force_recreate=True,
- embedding=FastEmbedEmbeddings(
- model_name="BAAI/bge-small-en-v1.5",
- # providers=["CUDAExecutionProvider"], # Uncomment this line to use your GPUs
- ),
+
+
+ if vectordb.collection_exists(collection_name):
+ vectordb.delete_collection(collection_name)
+ vectordb.create_collection(
+ collection_name=collection_name,
+ vectors_config=VectorParams(size=embedding_dimensions, distance=Distance.COSINE),
)
+ embeddings = embedding_model.embed([q.page_content for q in docs])
+ vectordb.upload_collection(
+ collection_name=collection_name,
+ vectors=[embed.tolist() for embed in embeddings],
+ payload=[doc.metadata for doc in docs],
+ )
+
+ # # Using LangChain VectorStore object
+ # from langchain_qdrant import QdrantVectorStore
+ # from langchain_community.embeddings import FastEmbedEmbeddings
+ # QdrantVectorStore.from_documents(
+ # docs,
+ # host="localhost",
+ # # location=":memory:",
+ # # path="data/qdrant",
+ # prefer_grpc=True,
+ # collection_name="sparql-docs",
+ # force_recreate=True,
+ # embedding=FastEmbedEmbeddings(
+ # model_name="BAAI/bge-small-en-v1.5",
+ # # providers=["CUDAExecutionProvider"], # Uncomment this line to use your GPUs
+ # ),
+ # )
+
if __name__ == "__main__":
index_endpoints()
diff --git a/tutorial/public/style.css b/tutorial/public/style.css
index 12297bf..becf3c5 100644
--- a/tutorial/public/style.css
+++ b/tutorial/public/style.css
@@ -1,2 +1,6 @@
-pre { padding: .5em; }
-a.watermark { display: none !important; }
+pre {
+ padding: .5em;
+}
+a.watermark {
+ display: none !important;
+}
diff --git a/tutorial/pyproject.toml b/tutorial/pyproject.toml
index 9b51913..c3a7684 100644
--- a/tutorial/pyproject.toml
+++ b/tutorial/pyproject.toml
@@ -8,14 +8,14 @@ dependencies = [
# "sparql-llm @ git+https://github.com/sib-swiss/sparql-llm.git#subdirectory=packages/sparql-llm",
# "sparql-llm @ file:///home/vemonet/dev/expasy/sparql-llm/packages/sparql-llm",
"langchain >=0.3.19",
- "langchain-community >=0.3.17",
"langchain-openai >=0.3.6",
"langchain-groq >=0.2.4",
"langchain-ollama >=0.2.3",
- "langchain-qdrant >=0.2.0",
"qdrant-client >=1.13.0",
"fastembed >=0.5.1",
# "fastembed-gpu >=0.5.1", # Optional GPU support
"chainlit",
"langgraph >=0.2.73",
+ "langchain-qdrant >=0.2.0",
+ "langchain-community >=0.3.17",
]
diff --git a/tutorial/slides/index.html b/tutorial/slides/index.html
index 4d04e08..9ef81c5 100644
--- a/tutorial/slides/index.html
+++ b/tutorial/slides/index.html
@@ -12,7 +12,7 @@
@@ -32,13 +32,13 @@
hash: true,
history: true,
hashOneBasedIndex: true,
- // slideNumber: true,
+ slideNumber: true,
clipcode: {
// https://www.npmjs.com/package/@edc4it/reveal.js-clipcode
style: {
copybg: 'silver',
scale: 0.8,
- radius: 1,
+ radius: 0.5,
},
},
});
@@ -51,7 +51,9 @@
font-size: 0.6em;
}
.reveal section h2 {
- font-size: 1em;
+ font-family: Lato, sans-serif;
+ font-size: .9em;
+ text-transform: none;
}
.reveal section pre code {
border-radius: 8px;
diff --git a/tutorial/slides/public/slides.md b/tutorial/slides/public/slides.md
index 25d3532..8c3e396 100644
--- a/tutorial/slides/public/slides.md
+++ b/tutorial/slides/public/slides.md
@@ -12,14 +12,16 @@ As we progress, you'll be provided with code snippets to gradually construct the
2. Index documents
3. Use indexed documents as context
4. Add a web UI
-5. Define a more complex agent workflow
-6. Add SPARQL query validation
+5. Add SPARQL query validation
+6. Optional: use an agent framework
---
## Setup
-[Install `uv`](https://docs.astral.sh/uv/getting-started/installation/) to easily handle dependencies and run scripts
+[Install `uv`](https://docs.astral.sh/uv/getting-started/installation/) to easily handle dependencies and run scripts.
+
+If you use VSCode, we recommend installing the [`Python` extension](https://marketplace.visualstudio.com/items?itemName=ms-python.python).
Create a new folder; you will be using this same folder throughout the tutorial.
@@ -46,21 +48,18 @@ requires-python = "==3.12.*"
dependencies = [
"sparql-llm >=0.0.8",
"langchain >=0.3.19",
- "langchain-community >=0.3.17",
"langchain-openai >=0.3.6",
"langchain-groq >=0.2.4",
"langchain-ollama >=0.2.3",
- "langchain-qdrant >=0.2.0",
"qdrant-client >=1.13.0",
"fastembed >=0.5.1",
"chainlit >=2.2.1",
- "langgraph >=0.2.73",
]
```
---
-## Call a LLM
+## Programmatically query an LLM
Create a `app.py` file in the same folder
@@ -116,6 +115,8 @@ llm = load_chat_model("groq/llama-3.3-70b-versatile")
# llm = load_chat_model("openai/gpt-4o-mini")
```
+> Alternatively, you could replace LangChain with [LiteLLM](https://docs.litellm.ai/docs/) here
+
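+> For example, a minimal sketch of the same call through LiteLLM's OpenAI-style API (the model string and message content below are assumptions, adapt them to your provider):
+
+```python
+from litellm import completion
+
+response = completion(
+    # Same provider/model naming convention as above (assumption)
+    model="groq/llama-3.3-70b-versatile",
+    messages=[{"role": "user", "content": "What are the rat orthologs of human TP53?"}],
+    temperature=0,
+)
+print(response.choices[0].message.content)
+```
+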
---
## Use a local LLM
@@ -139,6 +140,10 @@ Add the new provider:
llm = load_chat_model("ollama/mistral")
```
+> Ollama is mainly a wrapper around [llama.cpp](https://python.langchain.com/docs/integrations/chat/llamacpp/); you can also [download `.gguf` files](https://huggingface.co/lmstudio-community/Mistral-7B-Instruct-v0.3-GGUF) and use them directly.
+
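+> For instance, a rough sketch using LangChain's `ChatLlamaCpp` with a downloaded `.gguf` file (the path below is a placeholder, and `llama-cpp-python` must be installed):
+
+```python
+from langchain_community.chat_models import ChatLlamaCpp
+
+llm = ChatLlamaCpp(
+    # Placeholder path to a locally downloaded .gguf model file
+    model_path="./models/Mistral-7B-Instruct-v0.3-Q4_K_M.gguf",
+    temperature=0,
+)
+```
+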
+> [vLLM](https://github.com/vllm-project/vllm) and [llamafile](https://github.com/Mozilla-Ocho/llamafile) are other solutions to serve LLMs locally.
+
---
## Setup vector store
@@ -149,7 +154,7 @@ Deploy a **[Qdrant](https://qdrant.tech/documentation/)** vector store using [do
docker run -d -p 6333:6333 -p 6334:6334 -v $(pwd)/data/qdrant:/qdrant/storage qdrant/qdrant
```
-Or create a `compose.yml` file and start with `docker compose up`
+Or create a `compose.yml` file and start with `docker compose up -d`
```yml
services:
@@ -211,7 +216,6 @@ def index_endpoints():
endpoint["endpoint_url"],
examples_file=endpoint.get("examples_file"),
).load()
-
docs += SparqlVoidShapesLoader(
endpoint["endpoint_url"],
void_file=endpoint.get("void_file"),
@@ -240,25 +244,65 @@ Finally we can load these documents in the **[Qdrant](https://qdrant.tech/docume
We use **[FastEmbed](https://qdrant.github.io/fastembed/)** to generate embeddings locally with [open source embedding models](https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-text-embedding-models).
```python
-from langchain_qdrant import QdrantVectorStore
-from langchain_community.embeddings import FastEmbedEmbeddings
+from fastembed import TextEmbedding
+from qdrant_client import QdrantClient
+from qdrant_client.http.models import Distance, VectorParams
-vectordb = QdrantVectorStore.from_documents(
- docs,
+embedding_model = TextEmbedding(
+ "BAAI/bge-small-en-v1.5",
+ # providers=["CUDAExecutionProvider"], # Replace the fastembed dependency with fastembed-gpu to use your GPUs
+)
+embedding_dimensions = 384
+collection_name = "sparql-docs"
+vectordb = QdrantClient(
host="localhost",
prefer_grpc=True,
- # path="data/qdrant", # if not using Qdrant as a service
- collection_name="sparql-docs",
- embedding=FastEmbedEmbeddings(
- model_name="BAAI/bge-small-en-v1.5",
- # providers=["CUDAExecutionProvider"], # Replace the fastembed dependency with fastembed-gpu to use your GPUs
- ),
- force_recreate=True,
+ # location=":memory:", # if not using Qdrant as a service
)
+def index_endpoints():
+ # [...]
+ if vectordb.collection_exists(collection_name):
+ vectordb.delete_collection(collection_name)
+ vectordb.create_collection(
+ collection_name=collection_name,
+ vectors_config=VectorParams(size=embedding_dimensions, distance=Distance.COSINE),
+ )
+ embeddings = embedding_model.embed([q.page_content for q in docs])
+ vectordb.upload_collection(
+ collection_name=collection_name,
+ vectors=[embed.tolist() for embed in embeddings],
+ payload=[doc.metadata for doc in docs],
+ )
```
> Check out the indexed docs at http://localhost:6333/dashboard
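+
+> Or run a quick sanity check programmatically with the client defined above:
+
+```python
+# Number of vectors stored in the collection (should be > 0 after indexing)
+print(vectordb.get_collection(collection_name).points_count)
+```
+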
+----
+
+Alternatively, you could use a LangChain retriever instead of the Qdrant client directly
+
+```python
+from langchain_qdrant import QdrantVectorStore
+from langchain_community.embeddings import FastEmbedEmbeddings
+
+def index_endpoints():
+ # [...]
+ QdrantVectorStore.from_documents(
+ docs,
+ host="localhost",
+ prefer_grpc=True,
+ # location=":memory:", # if not using Qdrant as a service
+ collection_name="sparql-docs",
+ embedding=FastEmbedEmbeddings(
+ model_name="BAAI/bge-small-en-v1.5",
+ # providers=["CUDAExecutionProvider"], # Replace the fastembed dependency with fastembed-gpu to use your GPUs
+ ),
+ force_recreate=True,
+ )
+```
+
+> You will need to add the following dependencies to your `pyproject.toml`: `langchain-qdrant` and `langchain-community`
+
---
## Provide context to the LLM
@@ -267,6 +311,25 @@ Now we can go back to our `app.py` file.
And retrieve documents related to the user question using the vector store
+```python
+from index import vectordb, embedding_model, collection_name
+
+question_embeddings = next(iter(embedding_model.embed([question])))
+
+retrieved_docs_count = 3
+retrieved_docs = vectordb.search(
+ collection_name=collection_name,
+ query_vector=question_embeddings,
+ limit=retrieved_docs_count,
+)
+relevant_docs = "\n".join(doc.payload["question"] + "\n" + doc.payload["answer"] for doc in retrieved_docs)
+print(f"ποΈ Retrieved {len(retrieved_docs)} documents", retrieved_docs[0])
+```
+
+----
+
+If you are using LangChain retriever
+
```python
from langchain_qdrant import QdrantVectorStore
from langchain_community.embeddings import FastEmbedEmbeddings
@@ -282,11 +345,11 @@ retriever = vectordb.as_retriever()
retrieved_docs_count = 3
retrieved_docs = retriever.invoke(question, k=retrieved_docs_count)
relevant_docs = "\n".join(doc.page_content + "\n" + doc.metadata.get("answer") for doc in retrieved_docs)
-
-print(f"ποΈ Retrieved {len(retrieved_docs)} documents")
-print(retrieved_docs[0])
+print(f"ποΈ Retrieved {len(retrieved_docs)} documents", retrieved_docs[0])
```
+> The LangChain retriever returns a list of `Document` instead of `ScoredPoint`; access the fields using `metadata` instead of `payload`
+
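+For example, the same field is accessed differently depending on which client returned the hit:
+
+```python
+# Qdrant client: hits are ScoredPoint objects, fields live in the payload
+print(retrieved_docs[0].payload.get("answer"))
+# LangChain retriever: hits are Document objects, fields live in the metadata
+print(retrieved_docs[0].metadata.get("answer"))
+```
+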
---
## Provide context to the LLM
@@ -294,43 +357,38 @@ print(retrieved_docs[0])
Customize the system prompt to provide the retrieved documents
```python
-from langchain_core.prompts import ChatPromptTemplate
-
SYSTEM_PROMPT = """You are an assistant that helps users to write SPARQL queries.
Put the SPARQL query inside a markdown codeblock with the "sparql" language tag, and always add the URL of the endpoint on which the query should be executed in a comment at the start of the query inside the codeblocks.
Use the queries examples and classes shapes provided in the prompt to derive your answer, don't try to create a query from nothing and do not provide a generic query.
Try to always answer with one query, if the answer lies in different endpoints, provide a federated query.
And briefly explain the query.
Here is a list of documents (reference questions and query answers, classes schema) relevant to the user question that will help you answer the user question accurately:
-{relevant_docs}
-"""
-prompt_template = ChatPromptTemplate.from_messages([
- ("system", SYSTEM_PROMPT),
- ("placeholder", "{messages}"),
-])
-prompt_with_context = prompt_template.invoke({
- "messages": [("human", question)],
- "relevant_docs": relevant_docs,
-})
+{relevant_docs}"""
+messages = [
+ ("system", SYSTEM_PROMPT.format(relevant_docs=relevant_docs)),
+ ("human", question),
+]
```
+> Now try passing `messages` to `llm.stream()`
+
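+For example, a minimal sketch to print the streamed answer in the terminal before wiring up the web UI:
+
+```python
+for resp in llm.stream(messages):
+    # Print each streamed chunk as it arrives
+    print(resp.content, end="", flush=True)
+```
+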
---
## Provide context to the LLM
-We can improve how the documents are formatted when passed to the LLM:
+We can improve how the documents are formatted when passed to the LLM
```python
-from langchain_core.documents import Document
+from qdrant_client.models import ScoredPoint
-def _format_doc(doc: Document) -> str:
- """Format our question/answer document to be provided as context to the model."""
+def _format_doc(doc: ScoredPoint) -> str:
+ """Format a question/answer document to be provided as context to the model."""
doc_lang = (
- "sparql" if "query" in doc.metadata.get("doc_type", "")
- else "shex" if "schema" in doc.metadata.get("doc_type", "")
+ "sparql" if "query" in doc.payload.get("doc_type", "")
+ else "shex" if "schema" in doc.payload.get("doc_type", "")
else ""
)
- return f"\n{doc.page_content} ({doc.metadata.get('endpoint_url', '')}):\n\n```{doc_lang}\n{doc.metadata.get('answer')}\n```\n"
+ return f"\n{doc.payload['question']} ({doc.payload.get('endpoint_url', '')}):\n\n```{doc_lang}\n{doc.payload.get('answer')}\n```\n"
relevant_docs = f"\n{'\n'.join(_format_doc(doc) for doc in retrieved_docs)}\n"
```
@@ -339,7 +397,39 @@ relevant_docs = f"\n{'\n'.join(_format_doc(doc) for doc in retrieved_
## Provide context to the LLM
-We can retrieve documents related to query examples and classes shapes separately, to make sure we always get a number of examples and classes shapes.
+We can retrieve documents related to query examples and classes shapes separately, to make sure we always get a few of each
+
+```python
+from qdrant_client.models import FieldCondition, Filter, MatchValue
+
+def retrieve_docs(question: str) -> str:
+ question_embeddings = next(iter(embedding_model.embed([question])))
+ retrieved_docs = vectordb.search(
+ collection_name=collection_name,
+ query_vector=question_embeddings,
+ limit=retrieved_docs_count,
+ query_filter=Filter(must=[FieldCondition(
+ key="doc_type",
+ match=MatchValue(value="SPARQL endpoints query examples"),
+ )]),
+ )
+ retrieved_docs += vectordb.search(
+ collection_name=collection_name,
+ query_vector=question_embeddings,
+ limit=retrieved_docs_count,
+ query_filter=Filter(must_not=[FieldCondition(
+ key="doc_type",
+ match=MatchValue(value="SPARQL endpoints query examples"),
+ )]),
+ )
+ return f"\n{'\n'.join(_format_doc(doc) for doc in retrieved_docs)}\n"
+
+relevant_docs = retrieve_docs(question)
+```
+
+----
+
+If using the LangChain retriever:
```python
from qdrant_client.models import FieldCondition, Filter, MatchValue
@@ -377,16 +467,19 @@ import chainlit as cl
@cl.on_message
async def on_message(msg: cl.Message):
+ """Main function to handle when user send a message to the assistant."""
relevant_docs = retrieve_docs(msg.content)
async with cl.Step(name="relevant documents ποΈ") as step:
step.output = relevant_docs
- prompt_with_context = prompt_template.invoke({
- "messages": cl.chat_context.to_openai(),
- "relevant_docs": relevant_docs,
- })
+ messages = [
+ ("system", SYSTEM_PROMPT.format(relevant_docs=relevant_docs)),
+ *cl.chat_context.to_openai(),
+ ]
answer = cl.Message(content="")
- for resp in llm.stream(prompt_with_context):
+ for resp in llm.stream(messages):
await answer.stream_token(resp.content)
+ if resp.usage_metadata:
+ print(resp.usage_metadata)
await answer.send()
```
@@ -413,7 +506,7 @@ async def set_starters():
]
```
-[Customize the UI](https://docs.chainlit.io/customisation/overview):
+And [customize the UI](https://docs.chainlit.io/customisation/overview)
- Change general settings in `.chainlit/config.toml`
- e.g. set `custom_css= "/public/style.css"` containing: `pre { padding: .5em; } a.watermark { display: none !important; }`
@@ -423,151 +516,183 @@ async def set_starters():
---
-## Define an agent workflow
+## Deploy with a nice web UI
-Create more complex agent workflow agent that can loop over themselves using [LangGraph](https://langchain-ai.github.io/langgraph/#):
+You can also change `retrieve_docs()` to make it `async`, and define the Chainlit step directly in the retrieval function
-- To validate a generated query
-- To use tools
+```python
+async def retrieve_docs(question: str) -> str:
+ # [...]
+ async with cl.Step(name=f"{len(retrieved_docs)} relevant documents ποΈ") as step:
+ step.output = relevant_docs
+ return relevant_docs
+
+@cl.on_message
+async def on_message(msg: cl.Message):
+ relevant_docs = await retrieve_docs(msg.content)
+ # [...]
+```
---
-## Define an agent workflow
+## Add SPARQL query validation
-Define the state and update the retrieve function
+
+
-```python
-from langgraph.graph.message import MessagesState
+Why do we add validation of the generated query:
-class AgentState(MessagesState):
- """State of the agent available inside each node."""
- relevant_docs: str
+🔧 fix missing prefixes
+🔍 detect use of a wrong predicate with a class
-async def retrieve_docs(state: AgentState) -> dict[str, str]:
- question = state["messages"][-1].content
- # [...]
- async with cl.Step(name=f"{len(retrieved_docs)} relevant documents ποΈ") as step:
- step.output = relevant_docs
- # This will update relevant_docs in the state:
- return {"relevant_docs": relevant_docs}
-```
+
+
+

+
+
---
-## Define an agent workflow
+## Add SPARQL query validation
-Define the node to call the LLM
+Initialize the prefixes map and VoID classes schema that will be used by validation
```python
-def call_model(state: AgentState):
- """Call the model with the retrieved documents as context."""
- prompt_with_context = prompt_template.invoke({
- "messages": state["messages"],
- "relevant_docs": state['relevant_docs'],
- })
- response = llm.invoke(prompt_with_context)
- return {"messages": [response]}
+import logging
+from sparql_llm.utils import get_prefixes_and_schema_for_endpoints
+from index import endpoints
+
+logging.getLogger("httpx").setLevel(logging.WARNING)
+logging.info("Initializing endpoints metadata...")
+prefixes_map, endpoints_void_dict = get_prefixes_and_schema_for_endpoints(endpoints)
```
---
-## Define an agent workflow
+## Add SPARQL query validation
-Define the workflow "graph"
+Create the validation function
```python
-from langgraph.graph import StateGraph
-
-builder = StateGraph(AgentState)
-
-builder.add_node(retrieve_docs)
-builder.add_node(call_model)
-
-builder.add_edge("__start__", "retrieve_docs")
-builder.add_edge("retrieve_docs", "call_model")
-builder.add_edge("call_model", "__end__")
+from sparql_llm import validate_sparql_in_msg
+from langchain_core.messages import AIMessage
-graph = builder.compile()
+async def validate_output(last_msg: str) -> str | None:
+ """Validate the output of a LLM call, e.g. SPARQL queries generated."""
+ validation_outputs = validate_sparql_in_msg(last_msg, prefixes_map, endpoints_void_dict)
+ for validation_output in validation_outputs:
+ # Add step when missing prefixes have been fixed
+ if validation_output["fixed_query"]:
+ async with cl.Step(name="missing prefixes correction β
") as step:
+ step.output = f"Missing prefixes added to the generated query:\n```sparql\n{validation_output['fixed_query']}\n```"
+ # Create a new message to ask the model to fix the errors
+ if validation_output["errors"]:
+ recall_msg = f"""Fix the SPARQL query helping yourself with the error message and context from previous messages in a way that it is a fully valid query.\n
+### Error messages:\n- {'\n- '.join(validation_output['errors'])}\n
+### Erroneous SPARQL query\n```sparql\n{validation_output.get('fixed_query', validation_output['original_query'])}\n```"""
+ async with cl.Step(name=f"SPARQL query validation, got {len(validation_output['errors'])} errors to fix π") as step:
+ step.output = recall_msg
+ return recall_msg
```
---
-## Define an agent workflow
+## Add SPARQL query validation
-Update the UI
+Update the main `on_message` function running the chat: add a loop that makes sure validation passes; if it does not, we call the LLM again, asking it to fix the faulty query
```python
+max_try_count = 3
+
@cl.on_message
async def on_message(msg: cl.Message):
- answer = cl.Message(content="")
- async for msg, metadata in graph.astream(
- {"messages": cl.chat_context.to_openai()},
- stream_mode="messages",
- ):
- if not msg.response_metadata:
- await answer.stream_token(msg.content)
+ # [...]
+ for _i in range(max_try_count):
+ answer = cl.Message(content="")
+ for resp in llm.stream(messages):
+ await answer.stream_token(resp.content)
+ await answer.send()
+ validation_msg = await validate_output(answer.content)
+ if validation_msg is None:
+ break
else:
- await answer.send()
- answer = cl.Message(content="")
+ messages.append(("human", validation_msg))
```
> Try running your agent again now
---
-## Add SPARQL query validation
+## Use an agent framework
+
+Optionally, you can switch to fully using an "agent framework" like [LangGraph](https://langchain-ai.github.io/langgraph/#):
+
+✅ Gives access to some nice features
+
+- switch between streaming and complete response
+- parallel execution of nodes
+- generate a visual diagram for your workflow
+
+✅ Provides structure to build your workflow
-Add fields to the state related to validation
+⚠️ Can be slower at runtime than doing things yourself
+
+⚠️ Relies on more dependencies, increasing the overall complexity of the system; some people might find it more confusing than just using good old loops
+
+---
+
+## Use an agent framework
+
+Add the `langgraph` dependency to your `pyproject.toml`
+
+Define the state and update the retrieve function
```python
+from langgraph.graph.message import MessagesState
+
class AgentState(MessagesState):
- # [...]
+ """State of the agent available inside each node."""
+ relevant_docs: str
passed_validation: bool
try_count: int
+
+
+async def retrieve_docs(state: AgentState) -> dict[str, str]:
+ question = state["messages"][-1].content
+ # [...]
+ # This will update relevant_docs in the state:
+ return {"relevant_docs": relevant_docs}
```
---
-## Add SPARQL query validation
+## Use an agent framework
-Initialize the prefixes map and VoID classes schema that will be used by validation
+Define the node to call the LLM
```python
-import logging
-from sparql_llm.utils import get_prefixes_and_schema_for_endpoints
-from index import endpoints
-
-logging.getLogger("httpx").setLevel(logging.WARNING)
-logging.info("Initializing endpoints metadata...")
-prefixes_map, endpoints_void_dict = get_prefixes_and_schema_for_endpoints(endpoints)
+def call_model(state: AgentState):
+ """Call the model with the retrieved documents as context."""
+ response = llm.invoke([
+ ("system", SYSTEM_PROMPT.format(relevant_docs=state["relevant_docs"])),
+ *state["messages"],
+ ])
+ return {"messages": [response]}
```
---
-## Add SPARQL query validation
+## Use an agent framework
-Create the validation node
+Update the function that does validation
```python
-from sparql_llm import validate_sparql_in_msg
-
-async def validate_output(state: AgentState) -> dict[str, bool | list[tuple[str, str]] | int]:
- """Node to validate the output of a LLM call, e.g. SPARQL queries generated."""
- recall_messages = []
- validation_outputs = validate_sparql_in_msg(state["messages"][-1].content, prefixes_map, endpoints_void_dict)
- for validation_output in validation_outputs:
- # Handle when missing prefixes have been fixed
- if validation_output["fixed_query"]:
- async with cl.Step(name="missing prefixes correction β
") as step:
- step.output = f"Missing prefixes added to the generated query:\n```sparql\n{validation_output['fixed_query']}\n```"
- # Add a new message to ask the model to fix the errors
- if validation_output["errors"]:
- recall_msg = f"""Fix the SPARQL query helping yourself with the error message and context from previous messages in a way that it is a fully valid query.\n
-### Error messages:\n- {'\n- '.join(validation_output['errors'])}\n
-### Erroneous SPARQL query\n```sparql\n{validation_output['original_query']}\n```"""
- async with cl.Step(name=f"SPARQL query validation, got {len(validation_output['errors'])} errors to fix π") as step:
- step.output = recall_msg
+async def validate_output(state) -> dict[str, bool | list[tuple[str, str]] | int]:
+ recall_messages = []
+ last_msg = next(msg.content for msg in reversed(state["messages"]) if msg.content)
+ # [...]
+ # Add a new message to ask the model to fix the error
recall_messages.append(("human", recall_msg))
return {
"messages": recall_messages,
@@ -576,19 +701,21 @@ async def validate_output(state: AgentState) -> dict[str, bool | list[tuple[str,
}
```
+
+
---
-## Add SPARQL query validation
+## Use an agent framework
Create a conditional edge to route the workflow based on validation results
```python
from typing import Literal
-max_try_fix_sparql = 3
+max_try_count = 3
def route_model_output(state: AgentState) -> Literal["call_model", "__end__"]:
"""Determine the next node based on the model's output."""
- if state["try_count"] > max_try_fix_sparql:
+ if state["try_count"] > max_try_count:
return "__end__"
if not state["passed_validation"]:
return "call_model"
@@ -597,15 +724,67 @@ def route_model_output(state: AgentState) -> Literal["call_model", "__end__"]:
---
-## Add SPARQL query validation
+## Use an agent framework
-Add this new edge to the workflow graph
+Define the workflow "graph"
```python
+from langgraph.graph import StateGraph
+
+builder = StateGraph(AgentState)
+
+builder.add_node(retrieve_docs)
+builder.add_node(call_model)
builder.add_node(validate_output)
+builder.add_edge("__start__", "retrieve_docs")
+builder.add_edge("retrieve_docs", "call_model")
builder.add_edge("call_model", "validate_output")
builder.add_conditional_edges("validate_output", route_model_output)
+
+graph = builder.compile()
+```
+
+---
+
+## Use an agent framework
+
+Update the UI
+
+```python
+@cl.on_message
+async def on_message(msg: cl.Message):
+ answer = cl.Message(content="")
+ async for msg, metadata in graph.astream(
+ {"messages": cl.chat_context.to_openai()},
+ stream_mode="messages",
+ ):
+ if not msg.response_metadata:
+ await answer.stream_token(msg.content)
+ else:
+ print(msg.usage_metadata)
+ await answer.send()
+ answer = cl.Message(content="")
```
> Try running your agent again now
+
+---
+
+## Thank you
+
+[Complete script on GitHub](https://github.com/sib-swiss/sparql-llm/blob/main/tutorial/app.py)
+
+
+
+Live deployment for SIB endpoints (UniProt, Bgee, OMA, Rhea…)
+
+[**chat.expasy.org**](https://chat.expasy.org)
+
+
+
+Code: [**github.com/sib-swiss/sparql-llm**](https://github.com/sib-swiss/sparql-llm)
+
+Short paper: [arxiv.org/abs/2410.06062](https://arxiv.org/abs/2410.06062)
+
+Standalone components available as a pip package: [pypi.org/project/sparql-llm](https://pypi.org/project/sparql-llm)
diff --git a/tutorial/slides/public/sparql_workflow.png b/tutorial/slides/public/sparql_workflow.png
new file mode 100644
index 0000000..51b3955
Binary files /dev/null and b/tutorial/slides/public/sparql_workflow.png differ
diff --git a/tutorial/uv.lock b/tutorial/uv.lock
index 6901233..9707790 100644
--- a/tutorial/uv.lock
+++ b/tutorial/uv.lock
@@ -1590,7 +1590,8 @@ wheels = [
[[package]]
name = "sparql-llm"
-source = { directory = "../packages/sparql-llm" }
+version = "0.0.8"
+source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "beautifulsoup4" },
{ name = "curies-rs" },
@@ -1598,22 +1599,9 @@ dependencies = [
{ name = "langchain-core" },
{ name = "rdflib" },
]
-
-[package.metadata]
-requires-dist = [
- { name = "beautifulsoup4", specifier = ">=4.13.0" },
- { name = "curies-rs", specifier = ">=0.1.3" },
- { name = "httpx", specifier = ">=0.27.2" },
- { name = "langchain-core", specifier = ">=0.3.34" },
- { name = "rdflib", specifier = ">=7.0.0" },
-]
-
-[package.metadata.requires-dev]
-dev = [
- { name = "mypy", specifier = ">=1.15.0" },
- { name = "pytest", specifier = ">=8.3.4" },
- { name = "pytest-cov", specifier = ">=6.0.0" },
- { name = "ruff", specifier = ">=0.9.5" },
+sdist = { url = "https://files.pythonhosted.org/packages/96/d6/a7e418dbc29d1571cbbbc8bb40f000fc35b4735e845f1ce91b8c6d22945b/sparql_llm-0.0.8.tar.gz", hash = "sha256:8b567f021d9074f072cbda852f7570ffd6bd26e03868019f8eef0acdb3cea670", size = 193778 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/71/ce/a68cb0274aa211e97a4c4efb014abce4eb8e4c4938ba3de12b797f5449a4/sparql_llm-0.0.8-py3-none-any.whl", hash = "sha256:465bf8ffa9fc88497541a17d35b44b153312d1b786e580c084596864eb1de009", size = 85356 },
]
[[package]]
@@ -1780,7 +1768,7 @@ requires-dist = [
{ name = "langchain-qdrant", specifier = ">=0.2.0" },
{ name = "langgraph", specifier = ">=0.2.73" },
{ name = "qdrant-client", specifier = ">=1.13.0" },
- { name = "sparql-llm", directory = "../packages/sparql-llm" },
+ { name = "sparql-llm", specifier = ">=0.0.8" },
]
[[package]]