Showing 5 changed files with 235 additions and 17 deletions.
@@ -0,0 +1,135 @@
from langchain_core.language_models import BaseChatModel

# question = "What are the rat orthologs of human TP53?"

def load_chat_model(model: str) -> BaseChatModel:
    """Load a chat model from a provider/model string such as 'groq/llama-3.3-70b-versatile'."""
    provider, model_name = model.split("/", maxsplit=1)
    if provider == "groq":
        # https://python.langchain.com/docs/integrations/chat/groq/
        from langchain_groq import ChatGroq
        return ChatGroq(model_name=model_name, temperature=0)
    if provider == "openai":
        # https://python.langchain.com/docs/integrations/chat/openai/
        from langchain_openai import ChatOpenAI
        return ChatOpenAI(model_name=model_name, temperature=0)
    if provider == "ollama":
        # https://python.langchain.com/docs/integrations/chat/ollama/
        from langchain_ollama import ChatOllama
        return ChatOllama(model=model_name, temperature=0)
    raise ValueError(f"Unknown provider: {provider}")

llm = load_chat_model("groq/llama-3.3-70b-versatile")
# llm = load_chat_model("openai/gpt-4o-mini")
# llm = load_chat_model("ollama/mistral")

from langchain_qdrant import QdrantVectorStore
from langchain_community.embeddings import FastEmbedEmbeddings

# Connect to the local Qdrant vector store built by the indexing script
vectordb = QdrantVectorStore.from_existing_collection(
    path="data/qdrant",
    collection_name="sparql-docs",
    embedding=FastEmbedEmbeddings(model_name="BAAI/bge-small-en-v1.5"),
)

retriever = vectordb.as_retriever()
number_of_docs_retrieved = 5

# retrieved_docs = retriever.invoke(question, k=number_of_docs_retrieved)

from qdrant_client.models import FieldCondition, Filter, MatchValue
from langchain_core.documents import Document

def retrieve_docs(question: str) -> list[Document]:
    """Retrieve documents relevant to the question: query examples first, then all other document types (e.g. classes shapes)."""
    retrieved_docs = retriever.invoke(
        question,
        k=number_of_docs_retrieved,
        filter=Filter(
            must=[
                FieldCondition(
                    key="metadata.doc_type",
                    match=MatchValue(value="SPARQL endpoints query examples"),
                )
            ]
        ),
    )
    retrieved_docs += retriever.invoke(
        question,
        k=number_of_docs_retrieved,
        filter=Filter(
            must_not=[
                FieldCondition(
                    key="metadata.doc_type",
                    match=MatchValue(value="SPARQL endpoints query examples"),
                )
            ]
        ),
    )
    return retrieved_docs

# retrieved_docs = retrieve_docs(question)

# print(f"📚️ Retrieved {len(retrieved_docs)} documents")
# # print(retrieved_docs)
# for doc in retrieved_docs:
#     print(f"{doc.metadata.get('doc_type')} - {doc.metadata.get('endpoint_url')} - {doc.page_content}")
# Build the prompt template combining the system prompt with the retrieved documents
from langchain_core.prompts import ChatPromptTemplate

def _format_doc(doc: Document) -> str:
    """Format a question/answer document to be provided as context to the model."""
    doc_lang = (
        "sparql" if "query" in doc.metadata.get("doc_type", "")
        else "shex" if "schema" in doc.metadata.get("doc_type", "")
        else ""
    )
    return f"<document>\n{doc.page_content} ({doc.metadata.get('endpoint_url')}):\n\n```{doc_lang}\n{doc.metadata.get('answer')}\n```\n</document>"

SYSTEM_PROMPT = """You are an assistant that helps users write SPARQL queries.
Put the SPARQL query inside a markdown codeblock with the "sparql" language tag, and always add the URL of the endpoint on which the query should be executed in a comment at the start of the query inside the codeblock.
Use the query examples and class shapes provided in the prompt to derive your answer; don't try to create a query from nothing, and do not provide a generic query.
Try to always answer with one query; if the answer lies in different endpoints, provide a federated query.
Briefly explain the query.
Here is a list of documents (reference questions with query answers, and class schemas) relevant to the user question that will help you answer it accurately:
{retrieved_docs}
"""
prompt_template = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_PROMPT),
    ("placeholder", "{messages}"),
])

# formatted_docs = "\n".join(doc.page_content + "\n" + doc.metadata.get("answer") for doc in retrieved_docs)
# formatted_docs = f"<documents>\n{'\n'.join(_format_doc(doc) for doc in retrieved_docs)}\n</documents>"
# prompt_with_context = prompt_template.invoke({
#     "messages": [("human", question)],
#     "retrieved_docs": formatted_docs,
# })

# print(str("\n".join(prompt_with_context.messages)))

# resp = llm.invoke("What are the rat orthologs of human TP53?")
# print(resp)

# for msg in llm.stream(prompt_with_context):
#     print(msg.content, end='')
# Chainlit chat UI: for each user message, retrieve relevant docs, build the prompt, and stream the answer
import chainlit as cl

@cl.on_message
async def on_message(msg: cl.Message):
    retrieved_docs = retrieve_docs(msg.content)
    formatted_docs = f"<documents>\n{'\n'.join(_format_doc(doc) for doc in retrieved_docs)}\n</documents>"
    async with cl.Step(name=f"{len(retrieved_docs)} relevant documents") as step:
        # step.input = msg.content
        step.output = formatted_docs

    prompt_with_context = prompt_template.invoke({
        "messages": [("human", msg.content)],
        "retrieved_docs": formatted_docs,
    })
    final_answer = cl.Message(content="")
    for resp in llm.stream(prompt_with_context):
        await final_answer.stream_token(resp.content)
    await final_answer.send()
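Usage note (not part of the commit): assuming the Chainlit script above is saved as app.py (the diff does not show the file names) and the vector store below has already been built, the chat UI can typically be started with the Chainlit CLI, e.g. "chainlit run app.py".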
@@ -0,0 +1,56 @@
import os

from langchain_qdrant import QdrantVectorStore
from langchain_community.embeddings import FastEmbedEmbeddings
from langchain_core.documents import Document
from sparql_llm import SparqlExamplesLoader, SparqlVoidShapesLoader

# List of endpoints that will be used
endpoints: list[dict[str, str]] = [
    {
        # The URL of the SPARQL endpoint from which most information will be extracted
        "endpoint_url": "https://sparql.uniprot.org/sparql/",
        # If the VoID description or SPARQL query examples are not available in the endpoint, you can provide a VoID file (local path or remote URL)
        "void_file": "../packages/sparql-llm/tests/void_uniprot.ttl",
        # "examples_file": "uniprot_examples.ttl",
    },
    {
        "endpoint_url": "https://www.bgee.org/sparql/",
    },
    {
        "endpoint_url": "https://sparql.omabrowser.org/sparql/",
    },
]
# Get documents from the SPARQL endpoints
docs: list[Document] = []
for endpoint in endpoints:
    print(f"\n 🔎 Getting metadata for {endpoint['endpoint_url']}")
    queries_loader = SparqlExamplesLoader(
        endpoint["endpoint_url"],
        examples_file=endpoint.get("examples_file"),
        verbose=True,
    )
    docs += queries_loader.load()

    void_loader = SparqlVoidShapesLoader(
        endpoint["endpoint_url"],
        void_file=endpoint.get("void_file"),
        verbose=True,
    )
    docs += void_loader.load()

os.makedirs("data", exist_ok=True)

vectordb = QdrantVectorStore.from_documents(
    docs,
    path="data/qdrant",
    collection_name="sparql-docs",
    force_recreate=True,
    embedding=FastEmbedEmbeddings(
        model_name="BAAI/bge-small-en-v1.5",
        # providers=["CUDAExecutionProvider"],  # Uncomment this line to use your GPUs
    ),
)
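A quick sanity check, not part of the commit, sketched under the assumption that the indexing script above has been run and data/qdrant exists; it only reuses the loading pattern and embedding model shown above plus LangChain's standard similarity_search method:

# Hypothetical check (assumption, not in the commit): reload the persisted collection and run a test search
from langchain_qdrant import QdrantVectorStore
from langchain_community.embeddings import FastEmbedEmbeddings

check_db = QdrantVectorStore.from_existing_collection(
    path="data/qdrant",
    collection_name="sparql-docs",
    embedding=FastEmbedEmbeddings(model_name="BAAI/bge-small-en-v1.5"),
)
# Print the doc type and source endpoint of the 3 most similar documents
for doc in check_db.similarity_search("What are the rat orthologs of human TP53?", k=3):
    print(doc.metadata.get("doc_type"), "-", doc.metadata.get("endpoint_url"))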
@@ -0,0 +1,19 @@
[project]
name = "tutorial-sparql-agent"
version = "0.0.1"
requires-python = "==3.12.*"

dependencies = [
    "sparql-llm >=0.0.4",
    "langchain >=0.3.14",
    "langchain-community >=0.3.17",
    "langchain-openai >=0.3.6",
    "langchain-groq >=0.2.4",
    "langchain-ollama >=0.2.3",
    "langchain-qdrant >=0.2.0",
    "qdrant-client >=1.13.0",
    "fastembed >=0.5.1",
    # "fastembed-gpu >=0.5.1", # Optional GPU support
    "chainlit",
    "langgraph >=0.2.73",
]
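A hedged note, not spelled out in the commit: with this pyproject.toml at the project root, the dependencies listed above can be installed with a standard Python dependency manager (for example "uv sync", or installing the listed packages with pip) before running the indexing script and launching the Chainlit app.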