diff --git a/packages/sparql-llm/src/sparql_llm/sparql_info_loader.py b/packages/sparql-llm/src/sparql_llm/sparql_info_loader.py index 467ceb9..9c03a33 100644 --- a/packages/sparql-llm/src/sparql_llm/sparql_info_loader.py +++ b/packages/sparql-llm/src/sparql_llm/sparql_info_loader.py @@ -21,7 +21,7 @@ def load(self) -> list[Document]: """Load and return documents from the SPARQL endpoint.""" docs: list[Document] = [] - resources_summary_question = "Which resources are available through this system?" + resources_summary_question = "Which resources do you support?" metadata = { "question": resources_summary_question, "answer": f"This system helps to access the following SPARQL endpoints {self.org_label}:\n- " diff --git a/tutorial/app.py b/tutorial/app.py index 0dc8266..0155385 100644 --- a/tutorial/app.py +++ b/tutorial/app.py @@ -1,16 +1,7 @@ -from typing import Literal -from langchain_qdrant import QdrantVectorStore -from langchain_community.embeddings import FastEmbedEmbeddings -from langchain_core.documents import Document from langchain_core.language_models import BaseChatModel -from langchain_core.prompts import ChatPromptTemplate -from langgraph.graph import StateGraph -from langgraph.graph.message import MessagesState -from qdrant_client.models import FieldCondition, Filter, MatchValue +from qdrant_client.models import FieldCondition, Filter, MatchValue, ScoredPoint import chainlit as cl -# https://docs.chainlit.io/integrations/langchain - def load_chat_model(model: str) -> BaseChatModel: provider, model_name = model.split("/", maxsplit=1) @@ -20,6 +11,7 @@ def load_chat_model(model: str) -> BaseChatModel: return ChatGroq( model_name=model_name, temperature=0, + ) if provider == "openai": # https://python.langchain.com/docs/integrations/chat/openai/ @@ -38,49 +30,37 @@ def load_chat_model(model: str) -> BaseChatModel: raise ValueError(f"Unknown provider: {provider}") # llm = load_chat_model("groq/llama-3.3-70b-versatile") -llm = load_chat_model("openai/gpt-4o-mini") - +# llm = load_chat_model("openai/gpt-4o-mini") +llm = load_chat_model("ollama/mistral") -vectordb = QdrantVectorStore.from_existing_collection( - # path="data/qdrant", - host="localhost", - prefer_grpc=True, - collection_name="sparql-docs", - embedding=FastEmbedEmbeddings(model_name="BAAI/bge-small-en-v1.5"), -) -retriever = vectordb.as_retriever() - -class AgentState(MessagesState): - """State of the agent available inside each node.""" - relevant_docs: str - passed_validation: bool - try_count: int +from index import vectordb, embedding_model, collection_name retrieved_docs_count = 3 -async def retrieve_docs(state: AgentState) -> dict[str, str]: +async def retrieve_docs(question: str) -> str: """Retrieve documents relevant to the user's question.""" - last_msg = state["messages"][-1] - retriever = vectordb.as_retriever() - retrieved_docs = retriever.invoke( - last_msg.content, - k=retrieved_docs_count, - filter=Filter( + question_embeddings = next(iter(embedding_model.embed([question]))) + retrieved_docs = vectordb.search( + collection_name=collection_name, + query_vector=question_embeddings, + limit=retrieved_docs_count, + query_filter=Filter( must=[ FieldCondition( - key="metadata.doc_type", + key="doc_type", match=MatchValue(value="SPARQL endpoints query examples"), ) ] - ) + ), ) - retrieved_docs += retriever.invoke( - last_msg.content, - k=retrieved_docs_count, - filter=Filter( + retrieved_docs += vectordb.search( + collection_name=collection_name, + query_vector=question_embeddings, + 
limit=retrieved_docs_count, + query_filter=Filter( must_not=[ FieldCondition( - key="metadata.doc_type", + key="doc_type", match=MatchValue(value="SPARQL endpoints query examples"), ) ] @@ -89,18 +69,18 @@ async def retrieve_docs(state: AgentState) -> dict[str, str]: relevant_docs = f"\n{'\n'.join(_format_doc(doc) for doc in retrieved_docs)}\n" async with cl.Step(name=f"{len(retrieved_docs)} relevant documents πŸ“šοΈ") as step: step.output = relevant_docs - return {"relevant_docs": relevant_docs} + return relevant_docs -def _format_doc(doc: Document) -> str: +def _format_doc(doc: ScoredPoint) -> str: """Format a question/answer document to be provided as context to the model.""" doc_lang = ( - "sparql" if "query" in doc.metadata.get("doc_type", "") - else "shex" if "schema" in doc.metadata.get("doc_type", "") + "sparql" if "query" in doc.payload.get("doc_type", "") + else "shex" if "schema" in doc.payload.get("doc_type", "") else "" ) - return f"\n{doc.page_content} ({doc.metadata.get('endpoint_url', '')}):\n\n```{doc_lang}\n{doc.metadata.get('answer')}\n```\n" - # # Default formatting + return f"\n{doc.payload['question']} ({doc.payload.get('endpoint_url', '')}):\n\n```{doc_lang}\n{doc.payload.get('answer')}\n```\n" + # # Generic formatting: # meta = "".join(f" {k}={v!r}" for k, v in doc.metadata.items()) # if meta: # meta = f" {meta}" @@ -115,21 +95,6 @@ def _format_doc(doc: Document) -> str: Here is a list of documents (reference questions and query answers, classes schema) relevant to the user question that will help you answer the user question accurately: {relevant_docs} """ -prompt_template = ChatPromptTemplate.from_messages([ - ("system", SYSTEM_PROMPT), - ("placeholder", "{messages}"), -]) - -def call_model(state: AgentState): - """Call the model with the retrieved documents as context.""" - prompt_with_context = prompt_template.invoke({ - "messages": state["messages"], - "relevant_docs": state['relevant_docs'], - }) - response = llm.invoke(prompt_with_context) - # # Fix id of response to use the same as the rest of the messages - # response.id = state["messages"][-1].id - return {"messages": [response]} import logging @@ -144,101 +109,73 @@ def call_model(state: AgentState): from sparql_llm import validate_sparql_in_msg -async def validate_output(state: AgentState) -> dict[str, bool | list[tuple[str, str]] | int]: - """Node to validate the output of a LLM call, e.g. SPARQL queries generated.""" - recall_messages = [] - validation_outputs = validate_sparql_in_msg(state["messages"][-1].content, prefixes_map, endpoints_void_dict) +async def validate_output(last_msg: str) -> str | None: + """Validate the output of a LLM call, e.g. 
SPARQL queries generated.""" + validation_outputs = validate_sparql_in_msg(last_msg, prefixes_map, endpoints_void_dict) for validation_output in validation_outputs: if validation_output["fixed_query"]: async with cl.Step(name="missing prefixes correction βœ…") as step: step.output = f"Missing prefixes added to the generated query:\n```sparql\n{validation_output['fixed_query']}\n```" if validation_output["errors"]: - # errors_str = "- " + "\n- ".join(validation_output["errors"]) recall_msg = f"""Fix the SPARQL query helping yourself with the error message and context from previous messages in a way that it is a fully valid query.\n ### Error messages:\n- {'\n- '.join(validation_output['errors'])}\n -### Erroneous SPARQL query\n```sparql\n{validation_output['original_query']}\n```""" - # print(error_str, flush=True) +### Erroneous SPARQL query\n```sparql\n{validation_output.get('fixed_query', validation_output['original_query'])}\n```""" async with cl.Step(name=f"SPARQL query validation, got {len(validation_output['errors'])} errors to fix 🐞") as step: step.output = recall_msg - # Add a new message to ask the model to fix the error - recall_messages.append(("human", recall_msg)) - return { - "messages": recall_messages, - "try_count": state.get("try_count", 0) + 1, - "passed_validation": not recall_messages, - } - - - - -max_try_fix_sparql = 3 -def route_model_output(state: AgentState) -> Literal["__end__", "call_model"]: - """Determine the next node based on the model's output.""" - if state["try_count"] > max_try_fix_sparql: - return "__end__" - - # # Check for tool calls first - # if isinstance(last_msg, AIMessage) and state["messages"][-1].tool_calls: - # return "tools" - - # If validation failed, we need to call the model again - if not state["passed_validation"]: - return "call_model" - return "__end__" + return recall_msg -# Define the LangGraph graph -builder = StateGraph(AgentState) - -builder.add_node(retrieve_docs) -builder.add_node(call_model) -builder.add_node(validate_output) - -builder.add_edge("__start__", "retrieve_docs") -builder.add_edge("retrieve_docs", "call_model") -builder.add_edge("call_model", "validate_output") -# Add a conditional edge to determine the next step after `validate_output` -builder.add_conditional_edges("validate_output", route_model_output) - -graph = builder.compile() +# Setup chainlit web UI +import chainlit as cl +max_try_count = 3 -# Setup chainlit web UI @cl.on_message async def on_message(msg: cl.Message): - # config = {"configurable": {"thread_id": cl.context.session.id}} - # cb = cl.LangchainCallbackHandler() - print(cl.chat_context.to_openai()) - answer = cl.Message(content="") - async for msg, metadata in graph.astream( - # {"messages": [HumanMessage(content=msg.content)]}, - # {"messages": [("human", msg.content)]}, - {"messages": cl.chat_context.to_openai()}, - stream_mode="messages", - # config=RunnableConfig(callbacks=[cb], **config), - ): - if not msg.response_metadata: - # and msg.content and not isinstance(msg, HumanMessage) and metadata["langgraph_node"] == "call_model" - # print(msg, metadata) - await answer.stream_token(msg.content) + """Main function to handle when user send a message to the assistant.""" + relevant_docs = await retrieve_docs(msg.content) + messages = [ + ("system", SYSTEM_PROMPT.format(relevant_docs=relevant_docs)), + *cl.chat_context.to_openai(), + ] + # # NOTE: to fix issue with ollama only considering the last message: + # messages = [ + # ("human", SYSTEM_PROMPT.format(relevant_docs=relevant_docs) + 
f"\n\nHere is the user question:\n{msg.content}"), + # ] + + for _i in range(max_try_count): + answer = cl.Message(content="") + for resp in llm.stream(messages): + await answer.stream_token(resp.content) + if resp.usage_metadata: + print(resp.usage_metadata) + await answer.send() + + validation_msg = await validate_output(answer.content) + if validation_msg is None: + break else: - await answer.send() - answer = cl.Message(content="") + messages.append(("human", validation_msg)) + @cl.set_starters async def set_starters(): return [ + cl.Starter( + label="Supported resources", + message="Which resources do you support?", + # icon="/public/idea.svg", + ), cl.Starter( label="Rat orthologs", message="What are the rat orthologs of human TP53?", - # icon="/public/idea.svg", ), cl.Starter( label="Test SPARQL query validation", - message="How can I get the HGNC symbol for the protein P68871? (modify your answer to use `rdfs:label` instead of `rdfs:comment`, and add the type `up:Resource` to ?hgnc, and purposefully forget 2 prefixes declarations, it is for a test)", + message="How can I get the HGNC symbol for the protein P68871? (modify your answer to use `rdfs:label` instead of `rdfs:comment`, and add the type `up:Resource` to ?hgnc, and forget all prefixes declarations, it is for a test)", ), ] diff --git a/tutorial/chain.py b/tutorial/chain.py deleted file mode 100644 index 027dbe0..0000000 --- a/tutorial/chain.py +++ /dev/null @@ -1,149 +0,0 @@ -from langchain_core.language_models import BaseChatModel - -# question = "What are the rat orthologs of human TP53?" - -def load_chat_model(model: str) -> BaseChatModel: - provider, model_name = model.split("/", maxsplit=1) - if provider == "groq": - # https://python.langchain.com/docs/integrations/chat/groq/ - from langchain_groq import ChatGroq - return ChatGroq(model_name=model_name, temperature=0) - if provider == "openai": - # https://python.langchain.com/docs/integrations/chat/openai/ - from langchain_openai import ChatOpenAI - return ChatOpenAI(model_name=model_name, temperature=0) - if provider == "ollama": - # https://python.langchain.com/docs/integrations/chat/ollama/ - from langchain_ollama import ChatOllama - return ChatOllama(model=model_name, temperature=0) - raise ValueError(f"Unknown provider: {provider}") - -llm = load_chat_model("groq/llama-3.3-70b-versatile") -# llm = load_chat_model("openai/gpt-4o-mini") -# llm = load_chat_model("ollama/mistral") - -from langchain_qdrant import QdrantVectorStore -from langchain_community.embeddings import FastEmbedEmbeddings - -vectordb = QdrantVectorStore.from_existing_collection( - host="localhost", - prefer_grpc=True, - # path="data/qdrant", - collection_name="sparql-docs", - embedding=FastEmbedEmbeddings(model_name="BAAI/bge-small-en-v1.5"), -) - -retriever = vectordb.as_retriever() -docs_retrieved_count = 5 - -# retrieved_docs = retriever.invoke(question, k=docs_retrieved_count) - -from qdrant_client.models import FieldCondition, Filter, MatchValue -from langchain_core.documents import Document - -def retrieve_docs(question: str) -> str: - retrieved_docs = retriever.invoke( - question, - k=docs_retrieved_count, - filter=Filter( - must=[ - FieldCondition( - key="metadata.doc_type", - match=MatchValue(value="SPARQL endpoints query examples"), - ) - ] - ) - ) - retrieved_docs += retriever.invoke( - question, - k=docs_retrieved_count, - filter=Filter( - must_not=[ - FieldCondition( - key="metadata.doc_type", - match=MatchValue(value="SPARQL endpoints query examples"), - ) - ] - ), - ) - return 
f"\n{'\n'.join(_format_doc(doc) for doc in retrieved_docs)}\n" - -# relevant_docs = "\n".join(doc.page_content + "\n" + doc.metadata.get("answer") for doc in retrieved_docs) -# relevant_docs = retrieve_docs(question) - -# print(f"πŸ“šοΈ Retrieved {len(retrieved_docs)} documents") -# # print(retrieved_docs) -# for doc in retrieved_docs: -# print(f"{doc.metadata.get('doc_type')} - {doc.metadata.get('endpoint_url')} - {doc.page_content}") - - -from langchain_core.prompts import ChatPromptTemplate - -def _format_doc(doc: Document) -> str: - """Format a question/answer document to be provided as context to the model.""" - doc_lang = ( - "sparql" if "query" in doc.metadata.get("doc_type", "") - else "shex" if "schema" in doc.metadata.get("doc_type", "") - else "" - ) - return f"\n{doc.page_content} ({doc.metadata.get('endpoint_url')}):\n\n```{doc_lang}\n{doc.metadata.get('answer')}\n```\n" - - -SYSTEM_PROMPT = """You are an assistant that helps users to write SPARQL queries. -Put the SPARQL query inside a markdown codeblock with the "sparql" language tag, and always add the URL of the endpoint on which the query should be executed in a comment at the start of the query inside the codeblocks. -Use the queries examples and classes shapes provided in the prompt to derive your answer, don't try to create a query from nothing and do not provide a generic query. -Try to always answer with one query, if the answer lies in different endpoints, provide a federated query. -And briefly explain the query. -Here is a list of documents (reference questions and query answers, classes schema) relevant to the user question that will help you answer the user question accurately: -{relevant_docs} -""" -prompt_template = ChatPromptTemplate.from_messages([ - ("system", SYSTEM_PROMPT), - ("placeholder", "{messages}"), -]) - -# prompt_with_context = prompt_template.invoke({ -# "messages": [("human", question)], -# "relevant_docs": relevant_docs, -# }) - -# print(str("\n".join(prompt_with_context.messages))) - -# resp = llm.invoke("What are the rat orthologs of human TP53?") -# print(resp) - -# for msg in llm.stream(prompt_with_context): -# print(msg.content, end='') - - -import chainlit as cl - -@cl.on_message -async def on_message(msg: cl.Message): - relevant_docs = retrieve_docs(msg.content) - async with cl.Step(name="relevant documents") as step: - # step.input = msg.content - step.output = relevant_docs - - prompt_with_context = prompt_template.invoke({ - # "messages": [("human", msg.content)], - "messages": cl.chat_context.to_openai(), - "relevant_docs": relevant_docs, - }) - answer = cl.Message(content="") - for resp in llm.stream(prompt_with_context): - await answer.stream_token(resp.content) - await answer.send() - - -@cl.set_starters -async def set_starters(): - return [ - cl.Starter( - label="Rat orthologs", - message="What are the rat orthologs of human TP53?", - # icon="/public/idea.svg", - ), - ] - -# uv run chainlit run simple.py diff --git a/tutorial/graph.py b/tutorial/graph.py new file mode 100644 index 0000000..6cd1030 --- /dev/null +++ b/tutorial/graph.py @@ -0,0 +1,255 @@ +from typing import Literal +from langchain_qdrant import QdrantVectorStore +from langchain_community.embeddings import FastEmbedEmbeddings +from langchain_core.documents import Document +from langchain_core.language_models import BaseChatModel +from langgraph.graph import StateGraph +from langgraph.graph.message import MessagesState +from qdrant_client.models import FieldCondition, Filter, MatchValue +import chainlit as cl + + 
+def load_chat_model(model: str) -> BaseChatModel: + provider, model_name = model.split("/", maxsplit=1) + if provider == "groq": + # https://python.langchain.com/docs/integrations/chat/groq/ + from langchain_groq import ChatGroq + return ChatGroq( + model_name=model_name, + temperature=0, + ) + if provider == "openai": + # https://python.langchain.com/docs/integrations/chat/openai/ + from langchain_openai import ChatOpenAI + return ChatOpenAI( + model_name=model_name, + temperature=0, + ) + if provider == "ollama": + # https://python.langchain.com/docs/integrations/chat/ollama/ + from langchain_ollama import ChatOllama + return ChatOllama( + model=model_name, + temperature=0, + ) + raise ValueError(f"Unknown provider: {provider}") + +llm = load_chat_model("groq/llama-3.3-70b-versatile") +# llm = load_chat_model("openai/gpt-4o-mini") +# llm = load_chat_model("ollama/mistral") + + +vectordb = QdrantVectorStore.from_existing_collection( + # path="data/qdrant", + host="localhost", + prefer_grpc=True, + collection_name="sparql-docs", + embedding=FastEmbedEmbeddings(model_name="BAAI/bge-small-en-v1.5"), +) +retriever = vectordb.as_retriever() + + +class AgentState(MessagesState): + """State of the agent available inside each node.""" + relevant_docs: str + passed_validation: bool + try_count: int + +retrieved_docs_count = 3 +async def retrieve_docs(state: AgentState) -> dict[str, str]: + """Retrieve documents relevant to the user's question.""" + last_msg = state["messages"][-1] + retriever = vectordb.as_retriever() + retrieved_docs = retriever.invoke( + last_msg.content, + k=retrieved_docs_count, + filter=Filter( + must=[ + FieldCondition( + key="metadata.doc_type", + match=MatchValue(value="SPARQL endpoints query examples"), + ) + ] + ) + ) + retrieved_docs += retriever.invoke( + last_msg.content, + k=retrieved_docs_count, + filter=Filter( + must_not=[ + FieldCondition( + key="metadata.doc_type", + match=MatchValue(value="SPARQL endpoints query examples"), + ) + ] + ), + ) + relevant_docs = f"\n{'\n'.join(_format_doc(doc) for doc in retrieved_docs)}\n" + async with cl.Step(name=f"{len(retrieved_docs)} relevant documents πŸ“šοΈ") as step: + step.output = relevant_docs + return {"relevant_docs": relevant_docs} + + +def _format_doc(doc: Document) -> str: + """Format a question/answer document to be provided as context to the model.""" + doc_lang = ( + "sparql" if "query" in doc.metadata.get("doc_type", "") + else "shex" if "schema" in doc.metadata.get("doc_type", "") + else "" + ) + return f"\n{doc.page_content} ({doc.metadata.get('endpoint_url', '')}):\n\n```{doc_lang}\n{doc.metadata.get('answer')}\n```\n" + # # Default formatting + # meta = "".join(f" {k}={v!r}" for k, v in doc.metadata.items()) + # if meta: + # meta = f" {meta}" + # return f"\n{doc.page_content}\n" + + +SYSTEM_PROMPT = """You are an assistant that helps users to write SPARQL queries. +Put the SPARQL query inside a markdown codeblock with the "sparql" language tag, and always add the URL of the endpoint on which the query should be executed in a comment at the start of the query inside the codeblocks. +Use the queries examples and classes shapes provided in the prompt to derive your answer, don't try to create a query from nothing and do not provide a generic query. +Try to always answer with one query, if the answer lies in different endpoints, provide a federated query. +And briefly explain the query. 
+Here is a list of documents (reference questions and query answers, classes schema) relevant to the user question that will help you answer the user question accurately: +{relevant_docs} +""" + +def call_model(state: AgentState): + """Call the model with the retrieved documents as context.""" + response = llm.invoke([ + ("system", SYSTEM_PROMPT.format(relevant_docs=state["relevant_docs"])), + *state["messages"], + ]) + # NOTE: to fix issue with ollama ignoring system messages + # state["messages"][-1].content = SYSTEM_PROMPT.replace("{relevant_docs}", state['relevant_docs']) + "\n\nHere is the user question:\n" + state["messages"][-1].content + # response = llm.invoke(state["messages"]) + + # # Fix id of response to use the same as the rest of the messages + # response.id = state["messages"][-1].id + return {"messages": [response]} + + +import logging +from sparql_llm.utils import get_prefixes_and_schema_for_endpoints +from index import endpoints + +logging.getLogger("httpx").setLevel(logging.WARNING) +logging.info("Initializing endpoints metadata...") +# Retrieve the prefixes map and initialize VoID schema dictionary from the indexed endpoints +prefixes_map, endpoints_void_dict = get_prefixes_and_schema_for_endpoints(endpoints) + + +from sparql_llm import validate_sparql_in_msg +from langchain_core.messages import AIMessage + +async def validate_output(state: AgentState) -> dict[str, bool | list[tuple[str, str]] | int]: + """Node to validate the output of a LLM call, e.g. SPARQL queries generated.""" + recall_messages = [] + print(state["messages"]) + last_msg = next(msg.content for msg in reversed(state["messages"]) if isinstance(msg, AIMessage) and msg.content) + print(last_msg) + # last_msg = state["messages"][-1].content + validation_outputs = validate_sparql_in_msg(last_msg, prefixes_map, endpoints_void_dict) + for validation_output in validation_outputs: + if validation_output["fixed_query"]: + async with cl.Step(name="missing prefixes correction βœ…") as step: + step.output = f"Missing prefixes added to the generated query:\n```sparql\n{validation_output['fixed_query']}\n```" + if validation_output["errors"]: + # errors_str = "- " + "\n- ".join(validation_output["errors"]) + recall_msg = f"""Fix the SPARQL query helping yourself with the error message and context from previous messages in a way that it is a fully valid query.\n +### Error messages:\n- {'\n- '.join(validation_output['errors'])}\n +### Erroneous SPARQL query\n```sparql\n{validation_output.get('fixed_query', validation_output['original_query'])}\n```""" + # print(error_str, flush=True) + async with cl.Step(name=f"SPARQL query validation, got {len(validation_output['errors'])} errors to fix 🐞") as step: + step.output = recall_msg + # Add a new message to ask the model to fix the error + recall_messages.append(("human", recall_msg)) + return { + "messages": recall_messages, + "try_count": state.get("try_count", 0) + 1, + "passed_validation": not recall_messages, + } + + + + +max_try_count = 3 +def route_model_output(state: AgentState) -> Literal["__end__", "call_model"]: + """Determine the next node based on the model's output.""" + if state["try_count"] > max_try_count: + return "__end__" + + # # Check for tool calls first + # if isinstance(last_msg, AIMessage) and state["messages"][-1].tool_calls: + # return "tools" + + # If validation failed, we need to call the model again + if not state["passed_validation"]: + return "call_model" + return "__end__" + + + + +# Define the LangGraph graph +builder = 
StateGraph(AgentState) + +builder.add_node(retrieve_docs) +builder.add_node(call_model) +builder.add_node(validate_output) + +builder.add_edge("__start__", "retrieve_docs") +builder.add_edge("retrieve_docs", "call_model") +builder.add_edge("call_model", "validate_output") +# Add a conditional edge to determine the next step after `validate_output` +builder.add_conditional_edges("validate_output", route_model_output) + +graph = builder.compile() + + +# from langchain_core.runnables.graph import MermaidDrawMethod +# with open('data/sparql_workflow.png', 'wb') as f: +# f.write(graph.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.API)) + + +# Setup chainlit web UI +# https://docs.chainlit.io/integrations/langchain + +@cl.on_message +async def on_message(msg: cl.Message): + # config = {"configurable": {"thread_id": cl.context.session.id}} + # cb = cl.LangchainCallbackHandler() + answer = cl.Message(content="") + async for msg, metadata in graph.astream( + {"messages": cl.chat_context.to_openai()}, + stream_mode="messages", + # config=RunnableConfig(callbacks=[cb], **config), + ): + if not msg.response_metadata: + await answer.stream_token(msg.content) + else: + await answer.send() + print(msg.usage_metadata) + # print(metadata) + answer = cl.Message(content="") + + +@cl.set_starters +async def set_starters(): + return [ + cl.Starter( + label="Supported resources", + message="Which resources do you support?", + # icon="/public/idea.svg", + ), + cl.Starter( + label="Rat orthologs", + message="What are the rat orthologs of human TP53?", + ), + cl.Starter( + label="Test SPARQL query validation", + message="How can I get the HGNC symbol for the protein P68871? (modify your answer to use `rdfs:label` instead of `rdfs:comment`, and add the type `up:Resource` to ?hgnc, and forget all prefixes declarations, it is for a test)", + ), + ] + +# uv run chainlit run graph.py diff --git a/tutorial/index.py b/tutorial/index.py index 36578e0..ef67fd5 100644 --- a/tutorial/index.py +++ b/tutorial/index.py @@ -1,5 +1,3 @@ -from langchain_qdrant import QdrantVectorStore -from langchain_community.embeddings import FastEmbedEmbeddings from langchain_core.documents import Document from sparql_llm import SparqlExamplesLoader, SparqlVoidShapesLoader, SparqlInfoLoader @@ -24,6 +22,18 @@ ] +from fastembed import TextEmbedding +from qdrant_client import QdrantClient +from qdrant_client.http.models import Distance, VectorParams + +embedding_model = TextEmbedding( + "BAAI/bge-small-en-v1.5", + # providers=["CUDAExecutionProvider"], # Replace the fastembed dependency with fastembed-gpu to use your GPUs +) +embedding_dimensions = 384 +vectordb = QdrantClient(host="localhost", prefer_grpc=True) +collection_name = "sparql-docs" + def index_endpoints(): # Get documents from the SPARQL endpoints docs: list[Document] = [] @@ -41,20 +51,38 @@ def index_endpoints(): ).load() docs += SparqlInfoLoader(endpoints, source_iri="https://www.expasy.org/").load() - # os.makedirs('data', exist_ok=True) - - QdrantVectorStore.from_documents( - docs, - # path="data/qdrant", - host="localhost", - prefer_grpc=True, - collection_name="sparql-docs", - force_recreate=True, - embedding=FastEmbedEmbeddings( - model_name="BAAI/bge-small-en-v1.5", - # providers=["CUDAExecutionProvider"], # Uncomment this line to use your GPUs - ), + + + if vectordb.collection_exists(collection_name): + vectordb.delete_collection(collection_name) + vectordb.create_collection( + collection_name=collection_name, + 
vectors_config=VectorParams(size=embedding_dimensions, distance=Distance.COSINE), ) + embeddings = embedding_model.embed([q.page_content for q in docs]) + vectordb.upload_collection( + collection_name=collection_name, + vectors=[embed.tolist() for embed in embeddings], + payload=[doc.metadata for doc in docs], + ) + + # # Using LangChain VectorStore object + # from langchain_qdrant import QdrantVectorStore + # from langchain_community.embeddings import FastEmbedEmbeddings + # QdrantVectorStore.from_documents( + # docs, + # host="localhost", + # # location=":memory:", + # # path="data/qdrant", + # prefer_grpc=True, + # collection_name="sparql-docs", + # force_recreate=True, + # embedding=FastEmbedEmbeddings( + # model_name="BAAI/bge-small-en-v1.5", + # # providers=["CUDAExecutionProvider"], # Uncomment this line to use your GPUs + # ), + # ) + if __name__ == "__main__": index_endpoints() diff --git a/tutorial/public/style.css b/tutorial/public/style.css index 12297bf..becf3c5 100644 --- a/tutorial/public/style.css +++ b/tutorial/public/style.css @@ -1,2 +1,6 @@ -pre { padding: .5em; } -a.watermark { display: none !important; } +pre { + padding: .5em; +} +a.watermark { + display: none !important; +} diff --git a/tutorial/pyproject.toml b/tutorial/pyproject.toml index 9b51913..c3a7684 100644 --- a/tutorial/pyproject.toml +++ b/tutorial/pyproject.toml @@ -8,14 +8,14 @@ dependencies = [ # "sparql-llm @ git+https://github.com/sib-swiss/sparql-llm.git#subdirectory=packages/sparql-llm", # "sparql-llm @ file:///home/vemonet/dev/expasy/sparql-llm/packages/sparql-llm", "langchain >=0.3.19", - "langchain-community >=0.3.17", "langchain-openai >=0.3.6", "langchain-groq >=0.2.4", "langchain-ollama >=0.2.3", - "langchain-qdrant >=0.2.0", "qdrant-client >=1.13.0", "fastembed >=0.5.1", # "fastembed-gpu >=0.5.1", # Optional GPU support "chainlit", "langgraph >=0.2.73", + "langchain-qdrant >=0.2.0", + "langchain-community >=0.3.17", ] diff --git a/tutorial/slides/index.html b/tutorial/slides/index.html index 4d04e08..9ef81c5 100644 --- a/tutorial/slides/index.html +++ b/tutorial/slides/index.html @@ -12,7 +12,7 @@
@@ -32,13 +32,13 @@ hash: true, history: true, hashOneBasedIndex: true, - // slideNumber: true, + slideNumber: true, clipcode: { // https://www.npmjs.com/package/@edc4it/reveal.js-clipcode style: { copybg: 'silver', scale: 0.8, - radius: 1, + radius: 0.5, }, }, }); @@ -51,7 +51,9 @@ font-size: 0.6em; } .reveal section h2 { - font-size: 1em; + font-family: Lato, sans-serif; + font-size: .9em; + text-transform: none; } .reveal section pre code { border-radius: 8px; diff --git a/tutorial/slides/public/slides.md b/tutorial/slides/public/slides.md index 25d3532..8c3e396 100644 --- a/tutorial/slides/public/slides.md +++ b/tutorial/slides/public/slides.md @@ -12,14 +12,16 @@ As we progress, you'll be provided with code snippets to gradually construct the 2. Index documents 3. Use indexed documents as context 4. Add a web UI -5. Define a more complex agent workflow -6. Add SPARQL query validation +5. Add SPARQL query validation +6. Optional: use an agent framework --- ## Setup -[Install `uv`](https://docs.astral.sh/uv/getting-started/installation/) to easily handle dependencies and run scripts +[Install `uv`](https://docs.astral.sh/uv/getting-started/installation/) to easily handle dependencies and run scripts. + +If you use VSCode we recommend to have the [`Python` extension](https://marketplace.visualstudio.com/items?itemName=ms-python.python) installed. Create a new folder, you will be using this same folder along the tutorial. @@ -46,21 +48,18 @@ requires-python = "==3.12.*" dependencies = [ "sparql-llm >=0.0.8", "langchain >=0.3.19", - "langchain-community >=0.3.17", "langchain-openai >=0.3.6", "langchain-groq >=0.2.4", "langchain-ollama >=0.2.3", - "langchain-qdrant >=0.2.0", "qdrant-client >=1.13.0", "fastembed >=0.5.1", "chainlit >=2.2.1", - "langgraph >=0.2.73", ] ``` --- -## Call a LLM +## Programmatically query a LLM Create a `app.py` file in the same folder @@ -116,6 +115,8 @@ llm = load_chat_model("groq/llama-3.3-70b-versatile") # llm = load_chat_model("openai/gpt-4o-mini") ``` +> Alternatively you could replace LangChain by [LiteLLM](https://docs.litellm.ai/docs/) here + --- ## Use a local LLM @@ -139,6 +140,10 @@ Add the new provider: llm = load_chat_model("ollama/mistral") ``` +> Ollama is mainly a wrapper around [llama.cpp](https://python.langchain.com/docs/integrations/chat/llamacpp/), you can also [download `.gguf` files](https://huggingface.co/lmstudio-community/Mistral-7B-Instruct-v0.3-GGUF) and use them directly. + +> [vLLM](https://github.com/vllm-project/vllm) and [llamafile](https://github.com/Mozilla-Ocho/llamafile) are other solutions to serve LLMs locally. + --- ## Setup vector store @@ -149,7 +154,7 @@ Deploy a **[Qdrant](https://qdrant.tech/documentation/)** vector store using [do docker run -d -p 6333:6333 -p 6334:6334 -v $(pwd)/data/qdrant:/qdrant/storage qdrant/qdrant ``` -Or create a `compose.yml` file and start with `docker compose up` +Or create a `compose.yml` file and start with `docker compose up -d` ```yml services: @@ -211,7 +216,6 @@ def index_endpoints(): endpoint["endpoint_url"], examples_file=endpoint.get("examples_file"), ).load() - docs += SparqlVoidShapesLoader( endpoint["endpoint_url"], void_file=endpoint.get("void_file"), @@ -240,25 +244,65 @@ Finally we can load these documents in the **[Qdrant](https://qdrant.tech/docume We use **[FastEmbed](https://qdrant.github.io/fastembed/)** to generate embeddings locally with [open source embedding models](https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-text-embedding-models). 
```python -from langchain_qdrant import QdrantVectorStore -from langchain_community.embeddings import FastEmbedEmbeddings +from fastembed import TextEmbedding +from qdrant_client import QdrantClient +from qdrant_client.http.models import Distance, VectorParams -vectordb = QdrantVectorStore.from_documents( - docs, +embedding_model = TextEmbedding( + "BAAI/bge-small-en-v1.5", + # providers=["CUDAExecutionProvider"], # Replace the fastembed dependency with fastembed-gpu to use your GPUs +) +embedding_dimensions = 384 +collection_name = "sparql-docs" +vectordb = QdrantClient( host="localhost", prefer_grpc=True, - # path="data/qdrant", # if not using Qdrant as a service - collection_name="sparql-docs", - embedding=FastEmbedEmbeddings( - model_name="BAAI/bge-small-en-v1.5", - # providers=["CUDAExecutionProvider"], # Replace the fastembed dependency with fastembed-gpu to use your GPUs - ), - force_recreate=True, + # location=":memory:", # if not using Qdrant as a service ) +def index_endpoints(): + # [...] + if vectordb.collection_exists(collection_name): + vectordb.delete_collection(collection_name) + vectordb.create_collection( + collection_name=collection_name, + vectors_config=VectorParams(size=embedding_dimensions, distance=Distance.COSINE), + ) + embeddings = embedding_model.embed([q.page_content for q in docs]) + vectordb.upload_collection( + collection_name=collection_name, + vectors=[embed.tolist() for embed in embeddings], + payload=[doc.metadata for doc in docs], + ) ``` > Checkout indexed docs at http://localhost:6333/dashboard +---- + +Alternatively you could use a LangChain retriever instead of the Qdrant client directly + +```python +from langchain_qdrant import QdrantVectorStore +from langchain_community.embeddings import FastEmbedEmbeddings + +def index_endpoints(): + # [...] + QdrantVectorStore.from_documents( + docs, + host="localhost", + prefer_grpc=True, + # location=":memory:", # if not using Qdrant as a service + collection_name="sparql-docs", + embedding=FastEmbedEmbeddings( + model_name="BAAI/bge-small-en-v1.5", + # providers=["CUDAExecutionProvider"], # Replace the fastembed dependency with fastembed-gpu to use your GPUs + ), + force_recreate=True, + ) +``` + +> You will need to add the following dependencies to your `pyproject.toml`: `langchain-qdrant` and `langchain-community` + --- ## Provide context to the LLM @@ -267,6 +311,25 @@ Now we can go back to our `app.py` file. 
And retrieve documents related to the user question using the vector store +```python +from index import vectordb, embedding_model, collection_name + +question_embeddings = next(iter(embedding_model.embed([question]))) + +retrieved_docs_count = 3 +retrieved_docs = vectordb.search( + collection_name=collection_name, + query_vector=question_embeddings, + limit=retrieved_docs_count, +) +relevant_docs = "\n".join(doc.payload["question"] + "\n" + doc.payload["answer"] for doc in retrieved_docs) +print(f"πŸ“šοΈ Retrieved {len(retrieved_docs)} documents", retrieved_docs[0]) +``` + +---- + +If you are using LangChain retriever + ```python from langchain_qdrant import QdrantVectorStore from langchain_community.embeddings import FastEmbedEmbeddings @@ -282,11 +345,11 @@ retriever = vectordb.as_retriever() retrieved_docs_count = 3 retrieved_docs = retriever.invoke(question, k=retrieved_docs_count) relevant_docs = "\n".join(doc.page_content + "\n" + doc.metadata.get("answer") for doc in retrieved_docs) - -print(f"πŸ“šοΈ Retrieved {len(retrieved_docs)} documents") -print(retrieved_docs[0]) +print(f"πŸ“šοΈ Retrieved {len(retrieved_docs)} documents", retrieved_docs[0]) ``` +> LangChain retriever returns a list of `Document` instead of `ScoredPoint`, access the fields using `metadata` instead of `payload` + --- ## Provide context to the LLM @@ -294,43 +357,38 @@ print(retrieved_docs[0]) Customize the system prompt to provide the retrieved documents ```python -from langchain_core.prompts import ChatPromptTemplate - SYSTEM_PROMPT = """You are an assistant that helps users to write SPARQL queries. Put the SPARQL query inside a markdown codeblock with the "sparql" language tag, and always add the URL of the endpoint on which the query should be executed in a comment at the start of the query inside the codeblocks. Use the queries examples and classes shapes provided in the prompt to derive your answer, don't try to create a query from nothing and do not provide a generic query. Try to always answer with one query, if the answer lies in different endpoints, provide a federated query. And briefly explain the query. 
Here is a list of documents (reference questions and query answers, classes schema) relevant to the user question that will help you answer the user question accurately: -{relevant_docs} -""" -prompt_template = ChatPromptTemplate.from_messages([ - ("system", SYSTEM_PROMPT), - ("placeholder", "{messages}"), -]) -prompt_with_context = prompt_template.invoke({ - "messages": [("human", question)], - "relevant_docs": relevant_docs, -}) +{relevant_docs}""" +messages = [ + ("system", SYSTEM_PROMPT.format(relevant_docs=relevant_docs)), + ("human", question), +] ``` +> Try now to pass `messages` to `llm.stream()` + --- ## Provide context to the LLM -We can improve how the documents are formatted when passed to the LLM: +We can improve how the documents are formatted when passed to the LLM ```python -from langchain_core.documents import Document +from qdrant_client.models import ScoredPoint -def _format_doc(doc: Document) -> str: - """Format our question/answer document to be provided as context to the model.""" +def _format_doc(doc: ScoredPoint) -> str: + """Format a question/answer document to be provided as context to the model.""" doc_lang = ( - "sparql" if "query" in doc.metadata.get("doc_type", "") - else "shex" if "schema" in doc.metadata.get("doc_type", "") + "sparql" if "query" in doc.payload.get("doc_type", "") + else "shex" if "schema" in doc.payload.get("doc_type", "") else "" ) - return f"\n{doc.page_content} ({doc.metadata.get('endpoint_url', '')}):\n\n```{doc_lang}\n{doc.metadata.get('answer')}\n```\n" + return f"\n{doc.payload['question']} ({doc.payload.get('endpoint_url', '')}):\n\n```{doc_lang}\n{doc.payload.get('answer')}\n```\n" relevant_docs = f"\n{'\n'.join(_format_doc(doc) for doc in retrieved_docs)}\n" ``` @@ -339,7 +397,39 @@ relevant_docs = f"\n{'\n'.join(_format_doc(doc) for doc in retrieved_ ## Provide context to the LLM -We can retrieve documents related to query examples and classes shapes separately, to make sure we always get a number of examples and classes shapes. 
+We can retrieve documents related to query examples and classes shapes separately, to make sure we always get a number of examples and classes shapes + +```python +from qdrant_client.models import FieldCondition, Filter, MatchValue + +def retrieve_docs(question: str) -> str: + question_embeddings = next(iter(embedding_model.embed([question]))) + retrieved_docs = vectordb.search( + collection_name=collection_name, + query_vector=question_embeddings, + limit=retrieved_docs_count, + query_filter=Filter(must=[FieldCondition( + key="doc_type", + match=MatchValue(value="SPARQL endpoints query examples"), + )]), + ) + retrieved_docs += vectordb.search( + collection_name=collection_name, + query_vector=question_embeddings, + limit=retrieved_docs_count, + query_filter=Filter(must_not=[FieldCondition( + key="doc_type", + match=MatchValue(value="SPARQL endpoints query examples"), + )]), + ) + return f"\n{'\n'.join(_format_doc(doc) for doc in retrieved_docs)}\n" + +relevant_docs = retrieve_docs(question) +``` + +---- + +If using LangChain retriever: ```python from qdrant_client.models import FieldCondition, Filter, MatchValue @@ -377,16 +467,19 @@ import chainlit as cl @cl.on_message async def on_message(msg: cl.Message): + """Main function to handle when user send a message to the assistant.""" relevant_docs = retrieve_docs(msg.content) async with cl.Step(name="relevant documents πŸ“šοΈ") as step: step.output = relevant_docs - prompt_with_context = prompt_template.invoke({ - "messages": cl.chat_context.to_openai(), - "relevant_docs": relevant_docs, - }) + messages = [ + ("system", SYSTEM_PROMPT.format(relevant_docs=relevant_docs)), + *cl.chat_context.to_openai(), + ] answer = cl.Message(content="") - for resp in llm.stream(prompt_with_context): + for resp in llm.stream(messages): await answer.stream_token(resp.content) + if resp.usage_metadata: + print(resp.usage_metadata) await answer.send() ``` @@ -413,7 +506,7 @@ async def set_starters(): ] ``` -[Customize the UI](https://docs.chainlit.io/customisation/overview): +And [customize the UI](https://docs.chainlit.io/customisation/overview) - Change general settings in `.chainlit/config.toml` - e.g. set `custom_css= "/public/style.css"` containing: `pre { padding: .5em; } a.watermark { display: none !important; }` @@ -423,151 +516,183 @@ async def set_starters(): --- -## Define an agent workflow +## Deploy with a nice web UI -Create more complex agent workflow agent that can loop over themselves using [LangGraph](https://langchain-ai.github.io/langgraph/#): +You can also change `retrieve_docs()` to make it `async`, and directly define the chainlit step in the retrieval function -- To validate a generated query -- To use tools +```python +async def retrieve_docs(question: str) -> str: + # [...] + async with cl.Step(name=f"{len(retrieved_docs)} relevant documents πŸ“šοΈ") as step: + step.output = relevant_docs + return relevant_docs + +@cl.on_message +async def on_message(msg: cl.Message): + relevant_docs = await retrieve_docs(msg.content) + # [...] +``` --- -## Define an agent workflow +## Add SPARQL query validation -Define the state and update the retrieve function +
-```python -from langgraph.graph.message import MessagesState +Why do we add validation of the query generated: -class AgentState(MessagesState): - """State of the agent available inside each node.""" - relevant_docs: str +🧠 fix missing prefixes +πŸ„ detect use of a wrong predicate with a class -async def retrieve_docs(state: AgentState) -> dict[str, str]: - question = state["messages"][-1].content - # [...] - async with cl.Step(name=f"{len(retrieved_docs)} relevant documents πŸ“šοΈ") as step: - step.output = relevant_docs - # This will update relevant_docs in the state: - return {"relevant_docs": relevant_docs} -``` +
SPARQL agent workflow (sparql_workflow.png)

--- -## Define an agent workflow +## Add SPARQL query validation -Define the node to call the LLM +Initialize the prefixes map and VoID classes schema that will be used by validation ```python -def call_model(state: AgentState): - """Call the model with the retrieved documents as context.""" - prompt_with_context = prompt_template.invoke({ - "messages": state["messages"], - "relevant_docs": state['relevant_docs'], - }) - response = llm.invoke(prompt_with_context) - return {"messages": [response]} +import logging +from sparql_llm.utils import get_prefixes_and_schema_for_endpoints +from index import endpoints + +logging.getLogger("httpx").setLevel(logging.WARNING) +logging.info("Initializing endpoints metadata...") +prefixes_map, endpoints_void_dict = get_prefixes_and_schema_for_endpoints(endpoints) ``` --- -## Define an agent workflow +## Add SPARQL query validation -Define the workflow "graph" +Create the validation function ```python -from langgraph.graph import StateGraph - -builder = StateGraph(AgentState) - -builder.add_node(retrieve_docs) -builder.add_node(call_model) - -builder.add_edge("__start__", "retrieve_docs") -builder.add_edge("retrieve_docs", "call_model") -builder.add_edge("call_model", "__end__") +from sparql_llm import validate_sparql_in_msg +from langchain_core.messages import AIMessage -graph = builder.compile() +async def validate_output(last_msg: str) -> str | None: + """Validate the output of a LLM call, e.g. SPARQL queries generated.""" + validation_outputs = validate_sparql_in_msg(last_msg, prefixes_map, endpoints_void_dict) + for validation_output in validation_outputs: + # Add step when missing prefixes have been fixed + if validation_output["fixed_query"]: + async with cl.Step(name="missing prefixes correction βœ…") as step: + step.output = f"Missing prefixes added to the generated query:\n```sparql\n{validation_output['fixed_query']}\n```" + # Create a new message to ask the model to fix the errors + if validation_output["errors"]: + recall_msg = f"""Fix the SPARQL query helping yourself with the error message and context from previous messages in a way that it is a fully valid query.\n +### Error messages:\n- {'\n- '.join(validation_output['errors'])}\n +### Erroneous SPARQL query\n```sparql\n{validation_output.get('fixed_query', validation_output['original_query'])}\n```""" + async with cl.Step(name=f"SPARQL query validation, got {len(validation_output['errors'])} errors to fix 🐞") as step: + step.output = recall_msg + return recall_msg ``` --- -## Define an agent workflow +## Add SPARQL query validation -Update the UI +Update the main `on_message` function running the chat to add a loop that makes sure the validation passes, if not we recall the LLM asking to fix the wrong query ```python +max_try_count = 3 + @cl.on_message async def on_message(msg: cl.Message): - answer = cl.Message(content="") - async for msg, metadata in graph.astream( - {"messages": cl.chat_context.to_openai()}, - stream_mode="messages", - ): - if not msg.response_metadata: - await answer.stream_token(msg.content) + # [...] 
+ for _i in range(max_try_count): + answer = cl.Message(content="") + for resp in llm.stream(messages): + await answer.stream_token(resp.content) + await answer.send() + validation_msg = await validate_output(answer.content) + if validation_msg is None: + break else: - await answer.send() - answer = cl.Message(content="") + messages.append(("human", validation_msg)) ``` > Try running your agent again now --- -## Add SPARQL query validation +## Use an agent framework + +Optionally you can move to fully use an "agent framework" like [LangGraph](https://langchain-ai.github.io/langgraph/#): + +βœ… Give access to some nice features + +- switch between streaming and complete response +- parallel execution of nodes +- generate a visual diagram for your workflow + +βœ… Provide structure to build your workflow -Add fields to the state related to validation +⚠️ Can be slower at runtime than doing things yourself + +⚠️ Relies on more dependencies increasing the overall complexity of the system, some people might find it more confusing than just using good old loops + +--- + +## Use an agent framework + +Add the `langgraph` dependency to your `pyproject.toml` + +Define the state and update the retrieve function ```python +from langgraph.graph.message import MessagesState + class AgentState(MessagesState): - # [...] + """State of the agent available inside each node.""" + relevant_docs: str passed_validation: bool try_count: int + + +async def retrieve_docs(state: AgentState) -> dict[str, str]: + question = state["messages"][-1].content + # [...] + # This will update relevant_docs in the state: + return {"relevant_docs": relevant_docs} ``` --- -## Add SPARQL query validation +## Use an agent framework -Initialize the prefixes map and VoID classes schema that will be used by validation +Define the node to call the LLM ```python -import logging -from sparql_llm.utils import get_prefixes_and_schema_for_endpoints -from index import endpoints - -logging.getLogger("httpx").setLevel(logging.WARNING) -logging.info("Initializing endpoints metadata...") -prefixes_map, endpoints_void_dict = get_prefixes_and_schema_for_endpoints(endpoints) +def call_model(state: AgentState): + """Call the model with the retrieved documents as context.""" + response = llm.invoke([ + ("system", SYSTEM_PROMPT.format(relevant_docs=state["relevant_docs"])), + *state["messages"], + ]) + return {"messages": [response]} ``` --- -## Add SPARQL query validation +## Use an agent framework -Create the validation node +Update the function that does validation ```python -from sparql_llm import validate_sparql_in_msg - -async def validate_output(state: AgentState) -> dict[str, bool | list[tuple[str, str]] | int]: - """Node to validate the output of a LLM call, e.g. 
SPARQL queries generated.""" - recall_messages = [] - validation_outputs = validate_sparql_in_msg(state["messages"][-1].content, prefixes_map, endpoints_void_dict) - for validation_output in validation_outputs: - # Handle when missing prefixes have been fixed - if validation_output["fixed_query"]: - async with cl.Step(name="missing prefixes correction βœ…") as step: - step.output = f"Missing prefixes added to the generated query:\n```sparql\n{validation_output['fixed_query']}\n```" - # Add a new message to ask the model to fix the errors - if validation_output["errors"]: - recall_msg = f"""Fix the SPARQL query helping yourself with the error message and context from previous messages in a way that it is a fully valid query.\n -### Error messages:\n- {'\n- '.join(validation_output['errors'])}\n -### Erroneous SPARQL query\n```sparql\n{validation_output['original_query']}\n```""" - async with cl.Step(name=f"SPARQL query validation, got {len(validation_output['errors'])} errors to fix 🐞") as step: - step.output = recall_msg +async def validate_output(state) -> dict[str, bool | list[tuple[str, str]] | int]: + recall_messages = [] + last_msg = next(msg.content for msg in reversed(state["messages"]) if msg.content) + # [...] + # Add a new message to ask the model to fix the error recall_messages.append(("human", recall_msg)) return { "messages": recall_messages, @@ -576,19 +701,21 @@ async def validate_output(state: AgentState) -> dict[str, bool | list[tuple[str, } ``` + + --- -## Add SPARQL query validation +## Use an agent framework Create a conditional edge to route the workflow based on validation results ```python from typing import Literal -max_try_fix_sparql = 3 +max_try_count = 3 def route_model_output(state: AgentState) -> Literal["call_model", "__end__"]: """Determine the next node based on the model's output.""" - if state["try_count"] > max_try_fix_sparql: + if state["try_count"] > max_try_count: return "__end__" if not state["passed_validation"]: return "call_model" @@ -597,15 +724,67 @@ def route_model_output(state: AgentState) -> Literal["call_model", "__end__"]: --- -## Add SPARQL query validation +## Use an agent framework -Add this new edge to the workflow graph +Define the workflow "graph" ```python +from langgraph.graph import StateGraph + +builder = StateGraph(AgentState) + +builder.add_node(retrieve_docs) +builder.add_node(call_model) builder.add_node(validate_output) +builder.add_edge("__start__", "retrieve_docs") +builder.add_edge("retrieve_docs", "call_model") builder.add_edge("call_model", "validate_output") builder.add_conditional_edges("validate_output", route_model_output) + +graph = builder.compile() +``` + +--- + +## Use an agent framework + +Update the UI + +```python +@cl.on_message +async def on_message(msg: cl.Message): + answer = cl.Message(content="") + async for msg, metadata in graph.astream( + {"messages": cl.chat_context.to_openai()}, + stream_mode="messages", + ): + if not msg.response_metadata: + await answer.stream_token(msg.content) + else: + print(msg.usage_metadata) + await answer.send() + answer = cl.Message(content="") ``` > Try running your agent again now + +--- + +## Thank you + +[Complete script on GitHub](https://github.com/sib-swiss/sparql-llm/blob/main/tutorial/app.py) + +  + +Live deployment for SIB endpoints (UniProt, Bgee, OMA, Rhea…) + +[**chat.expasy.org**](https://chat.expasy.org) + +  + +Code: [**github.com/sib-swiss/sparql-llm**](https://github.com/sib-swiss/sparql-llm) + +Short paper: 
[arxiv.org/abs/2410.06062](https://arxiv.org/abs/2410.06062) + +Standalone components available as a pip package: [pypi.org/project/sparql-llm](https://pypi.org/project/sparql-llm) diff --git a/tutorial/slides/public/sparql_workflow.png b/tutorial/slides/public/sparql_workflow.png new file mode 100644 index 0000000..51b3955 Binary files /dev/null and b/tutorial/slides/public/sparql_workflow.png differ diff --git a/tutorial/uv.lock b/tutorial/uv.lock index 6901233..9707790 100644 --- a/tutorial/uv.lock +++ b/tutorial/uv.lock @@ -1590,7 +1590,8 @@ wheels = [ [[package]] name = "sparql-llm" -source = { directory = "../packages/sparql-llm" } +version = "0.0.8" +source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "beautifulsoup4" }, { name = "curies-rs" }, @@ -1598,22 +1599,9 @@ dependencies = [ { name = "langchain-core" }, { name = "rdflib" }, ] - -[package.metadata] -requires-dist = [ - { name = "beautifulsoup4", specifier = ">=4.13.0" }, - { name = "curies-rs", specifier = ">=0.1.3" }, - { name = "httpx", specifier = ">=0.27.2" }, - { name = "langchain-core", specifier = ">=0.3.34" }, - { name = "rdflib", specifier = ">=7.0.0" }, -] - -[package.metadata.requires-dev] -dev = [ - { name = "mypy", specifier = ">=1.15.0" }, - { name = "pytest", specifier = ">=8.3.4" }, - { name = "pytest-cov", specifier = ">=6.0.0" }, - { name = "ruff", specifier = ">=0.9.5" }, +sdist = { url = "https://files.pythonhosted.org/packages/96/d6/a7e418dbc29d1571cbbbc8bb40f000fc35b4735e845f1ce91b8c6d22945b/sparql_llm-0.0.8.tar.gz", hash = "sha256:8b567f021d9074f072cbda852f7570ffd6bd26e03868019f8eef0acdb3cea670", size = 193778 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/71/ce/a68cb0274aa211e97a4c4efb014abce4eb8e4c4938ba3de12b797f5449a4/sparql_llm-0.0.8-py3-none-any.whl", hash = "sha256:465bf8ffa9fc88497541a17d35b44b153312d1b786e580c084596864eb1de009", size = 85356 }, ] [[package]] @@ -1780,7 +1768,7 @@ requires-dist = [ { name = "langchain-qdrant", specifier = ">=0.2.0" }, { name = "langgraph", specifier = ">=0.2.73" }, { name = "qdrant-client", specifier = ">=1.13.0" }, - { name = "sparql-llm", directory = "../packages/sparql-llm" }, + { name = "sparql-llm", specifier = ">=0.0.8" }, ] [[package]]