Skip to content

Commit

Permalink
improve tutorial with complex agent graph
Browse files Browse the repository at this point in the history
  • Loading branch information
vemonet committed Feb 18, 2025
1 parent acbc9ff commit d0dbb03
Show file tree
Hide file tree
Showing 13 changed files with 609 additions and 105 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,5 @@ nohup.out
node_modules/

packages/expasy-agent/src/expasy_agent/webapp
tutorial/chainlit.md
tutorial/public
Binary file modified chat-with-context/demo/sib-logo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 3 additions & 3 deletions notebooks/compare_queries_examples_to_void.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -408,7 +408,7 @@
"source": [
"from sparql_llm.sparql_examples_loader import GET_SPARQL_EXAMPLES_QUERY\n",
"from sparql_llm.utils import query_sparql\n",
"from sparql_llm.validate_sparql import get_void_dict, sparql_query_to_dict\n",
"from sparql_llm.validate_sparql import get_schema_for_endpoint, sparql_query_to_dict\n",
"\n",
"check_endpoints = [\n",
" \"https://sparql.omabrowser.org/sparql/\",\n",
Expand Down Expand Up @@ -438,7 +438,7 @@
" all_preds = set()\n",
"\n",
" # Get all classes and predicates from the void description\n",
" unfiltered_void_dict = get_void_dict(endpoint_url)\n",
" unfiltered_void_dict = get_schema_for_endpoint(endpoint_url)\n",
" void_dict = {}\n",
" for cls, cls_dict in unfiltered_void_dict.items():\n",
" if ignore_namespaces(cls):\n",
Expand Down
4 changes: 3 additions & 1 deletion packages/expasy-agent/src/expasy_agent/nodes/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,6 @@ async def validate_output(state: State, config: RunnableConfig) -> dict[str, Any
return response


prefixes_map, endpoints_void_dict = get_prefixes_and_schema_for_endpoints(settings.endpoints)
prefixes_map, endpoints_void_dict = get_prefixes_and_schema_for_endpoints(
settings.endpoints
)
6 changes: 5 additions & 1 deletion packages/sparql-llm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ You can provide the examples as a file if it is not integrated in the endpoint,
loader = SparqlExamplesLoader("https://sparql.uniprot.org/sparql/", examples_file="uniprot_examples.ttl")
```

> Refer to the [LangChain documentation](https://python.langchain.com/v0.2/docs/) to figure out how to best integrate documents loaders to your stack.
> Refer to the [LangChain documentation](https://python.langchain.com/v0.2/docs/) to figure out how to best integrate document loaders into your system.

> [!NOTE]
>
> You can check the completeness of your examples against the endpoint schema using [this notebook](https://github.com/sib-swiss/sparql-llm/blob/main/notebooks/compare_queries_examples_to_void.ipynb).

### SPARQL endpoint schema loader

Expand Down
12 changes: 6 additions & 6 deletions packages/sparql-llm/src/sparql_llm/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, NotRequired, Optional, TypedDict
from typing import Any, Optional, TypedDict

import httpx
import rdflib
Expand All @@ -11,13 +11,13 @@
# Prefixes utilities


class SparqlEndpointInfo(TypedDict):
"""A dictionary to store information and links about a SPARQL endpoint."""
class SparqlEndpointInfo(TypedDict, total=False):
"""A dictionary to store links and filepaths about a SPARQL endpoint."""

endpoint_url: str
void_file: NotRequired[str] = None
examples_file: NotRequired[str] = None
homepage_url: NotRequired[str] = None
void_file: Optional[str]
examples_file: Optional[str]
homepage_url: Optional[str]


GET_PREFIXES_QUERY = """PREFIX sh: <http://www.w3.org/ns/shacl#>
Expand Down
6 changes: 3 additions & 3 deletions packages/sparql-llm/src/sparql_llm/validate_sparql.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,14 @@ def validate_triple_pattern(
continue
if subj_type not in void_dict and not subj_type.startswith("?"):
issues.add(
f"Type {prefix_converter.compress_list([subj_type])[0]} for subject {subj} in endpoint {endpoint} does not exist. Available classes are: {', '.join(prefix_converter.compress_list(list(void_dict.keys())))}"
f"Type {prefix_converter.compress_list([subj_type])[0]} for subject {subj} in endpoint {endpoint} does not exist. Available classes are: `{'`, `'.join(prefix_converter.compress_list(list(void_dict.keys())))}`"
)
elif pred not in void_dict.get(subj_type, {}) and not pred.startswith("?"):
# TODO: also check if object type matches? (if defined, because it's not always available)
# NOTE: we use compress_list for single values also because it has passthrough enabled by default for when there is no match in the converter
# print(subj_type, pred, list(void_dict.get(subj_type, {}).keys()), void_dict.get(subj_type, {}))
issues.add(
f"Subject {subj} with type {prefix_converter.compress_list([subj_type])[0]} in endpoint {endpoint} does not support the predicate {prefix_converter.compress_list([pred])[0]}. It can have the following predicates: {', '.join(prefix_converter.compress_list(list(void_dict.get(subj_type, {}).keys())))}"
f"Subject {subj} with type `{prefix_converter.compress_list([subj_type])[0]}` in endpoint {endpoint} does not support the predicate `{prefix_converter.compress_list([pred])[0]}`. It can have the following predicates: `{'`, `'.join(prefix_converter.compress_list(list(void_dict.get(subj_type, {}).keys())))}`"
)
for obj in pred_dict[pred]:
# Recursively validates objects that are variables
Expand Down Expand Up @@ -217,7 +217,7 @@ def validate_triple_pattern(
if missing_pred is not None:
# print(f"!!!! Subject {subj} {parent_type} {parent_pred} is not a valid {potential_types} !")
issues.add(
f"Subject {subj} in endpoint {endpoint} does not support the predicate {prefix_converter.compress_list([missing_pred])[0]}. Correct predicate might be one of the following: {', '.join(prefix_converter.compress_list(list(potential_preds)))} (we inferred this variable might be of the type {prefix_converter.compress_list([potential_type])[0]})"
f"Subject {subj} in endpoint {endpoint} does not support the predicate `{prefix_converter.compress_list([missing_pred])[0]}`. Correct predicate might be one of the following: `{'`, `'.join(prefix_converter.compress_list(list(potential_preds)))}` (we inferred this variable might be of the type `{prefix_converter.compress_list([potential_type])[0]}`)"
)

# TODO: when no type and no parent but more than 1 predicate is used, we could try to infer the type from the predicates
Expand Down
19 changes: 19 additions & 0 deletions tutorial/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Tutorial to build a SPARQL agent

## Deploy the slides locally

The slides will be served at http://localhost:3000/sparql-llm/

```sh
cd slides
npm i
npm run dev
```

## Deploy chat

The chat interface will be served at http://localhost:8000

```sh
uv run chainlit run app.py
```
46 changes: 29 additions & 17 deletions tutorial/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,25 @@ def load_chat_model(model: str) -> BaseChatModel:
from langchain_community.embeddings import FastEmbedEmbeddings

vectordb = QdrantVectorStore.from_existing_collection(
path="data/qdrant",
host="localhost",
prefer_grpc=True,
# path="data/qdrant",
collection_name="sparql-docs",
embedding=FastEmbedEmbeddings(model_name="BAAI/bge-small-en-v1.5"),
)

retriever = vectordb.as_retriever()
number_of_docs_retrieved = 5
docs_retrieved_count = 5

# retrieved_docs = retriever.invoke(question, k=number_of_docs_retrieved)
# retrieved_docs = retriever.invoke(question, k=docs_retrieved_count)

from qdrant_client.models import FieldCondition, Filter, MatchValue
from langchain_core.documents import Document

def retrieve_docs(question: str) -> list[Document]:
def retrieve_docs(question: str) -> str:
retrieved_docs = retriever.invoke(
question,
k=number_of_docs_retrieved,
k=docs_retrieved_count,
filter=Filter(
must=[
FieldCondition(
Expand All @@ -54,7 +56,7 @@ def retrieve_docs(question: str) -> list[Document]:
)
retrieved_docs += retriever.invoke(
question,
k=number_of_docs_retrieved,
k=docs_retrieved_count,
filter=Filter(
must_not=[
FieldCondition(
Expand All @@ -64,9 +66,10 @@ def retrieve_docs(question: str) -> list[Document]:
]
),
)
return retrieved_docs
return f"<documents>\n{'\n'.join(_format_doc(doc) for doc in retrieved_docs)}\n</documents>"

# retrieved_docs = retrieve_docs(question)
# relevant_docs = "\n".join(doc.page_content + "\n" + doc.metadata.get("answer") for doc in retrieved_docs)
# relevant_docs = retrieve_docs(question)

# print(f"📚️ Retrieved {len(retrieved_docs)} documents")
# # print(retrieved_docs)
Expand All @@ -92,18 +95,16 @@ def _format_doc(doc: Document) -> str:
Try to always answer with one query, if the answer lies in different endpoints, provide a federated query.
And briefly explain the query.
Here is a list of documents (reference questions and query answers, classes schema) relevant to the user question that will help you answer the user question accurately:
{retrieved_docs}
{relevant_docs}
"""
prompt_template = ChatPromptTemplate.from_messages([
("system", SYSTEM_PROMPT),
("placeholder", "{messages}"),
])

# formatted_docs = "\n".join(doc.page_content + "\n" + doc.metadata.get("answer") for doc in retrieved_docs)
# formatted_docs = f"<documents>\n{'\n'.join(_format_doc(doc) for doc in retrieved_docs)}\n</documents>"
# prompt_with_context = prompt_template.invoke({
# "messages": [("human", question)],
# "retrieved_docs": formatted_docs,
# "relevant_docs": relevant_docs,
# })

# print(str("\n".join(prompt_with_context.messages)))
Expand All @@ -119,17 +120,28 @@ def _format_doc(doc: Document) -> str:

@cl.on_message
async def on_message(msg: cl.Message):
retrieved_docs = retrieve_docs(msg.content)
formatted_docs = f"<documents>\n{'\n'.join(_format_doc(doc) for doc in retrieved_docs)}\n</documents>"
async with cl.Step(name=f"{len(retrieved_docs)} relevant documents") as step:
relevant_docs = retrieve_docs(msg.content)
async with cl.Step(name="relevant documents") as step:
# step.input = msg.content
step.output = formatted_docs
step.output = relevant_docs

prompt_with_context = prompt_template.invoke({
"messages": [("human", msg.content)],
"retrieved_docs": formatted_docs,
"relevant_docs": relevant_docs,
})
final_answer = cl.Message(content="")
for resp in llm.stream(prompt_with_context):
await final_answer.stream_token(resp.content)
await final_answer.send()

@cl.set_starters
async def set_starters():
    """Return the example starter prompts shown on the Chainlit welcome screen.

    Each ``cl.Starter`` pairs a button ``label`` with the ``message`` sent to
    the agent when the user clicks it; here both are the same question.
    """
    return [
        cl.Starter(
            label="What are the rat orthologs of human TP53?",
            message="What are the rat orthologs of human TP53?",
            # icon="/public/idea.svg",  # optional icon served from the public/ dir
        ),
    ]

# uv run chainlit run app.py
Loading

0 comments on commit d0dbb03

Please sign in to comment.