
Commit

validate_sparql_in_msg and improve validation node
vemonet committed Feb 18, 2025
1 parent d29a9d6 commit 44ccf6f
Showing 7 changed files with 119 additions and 78 deletions.
1 change: 1 addition & 0 deletions packages/expasy-agent/src/expasy_agent/graph.py
@@ -22,6 +22,7 @@

# How can I get the HGNC symbol for the protein P68871? Purposefully forget 2 prefixes declarations to test my validation step
# How can I get the HGNC symbol for the protein P68871? (modify your answer to use rdfs:label instead of rdfs:comment, and add the type up:Resource to ?hgnc, it is for a test)
# How can I get the HGNC symbol for the protein P68871? (modify your answer to use rdfs:label instead of rdfs:comment, and add the type up:Resource to ?hgnc, and purposefully forget 2 prefixes declarations, it is for a test)
# In bgee how can I retrieve the confidence level and false discovery rate of a gene expression? Use genex:confidence as predicate for the confidence level (do not use the one provided in documents), and do not put prefixes declarations, and add a rdf:type for the main subject. It's for testing
def route_model_output(
state: State, config: RunnableConfig
91 changes: 32 additions & 59 deletions packages/expasy-agent/src/expasy_agent/nodes/validation.py
@@ -5,18 +5,12 @@

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from rdflib.plugins.sparql import prepareQuery
from sparql_llm.utils import (
SparqlTriplesDict,
get_prefix_converter,
EndpointsSchemaDict,
get_prefixes_for_endpoints,
get_void_for_endpoint,
)
from sparql_llm.validate_sparql import (
add_missing_prefixes,
extract_sparql_queries,
validate_sparql_with_void,
get_schema_for_endpoint,
)
from sparql_llm.validate_sparql import validate_sparql_in_msg

from expasy_agent.config import Configuration, settings
from expasy_agent.state import State, StepOutput
@@ -41,50 +35,28 @@ async def validate_output(state: State, config: RunnableConfig) -> dict[str, Any
)
validation_steps: list[StepOutput] = []
recall_messages = []
generated_sparqls = extract_sparql_queries(last_msg)
for gen_query in generated_sparqls:
errors = []
# 1. Check if the query is syntactically valid, auto fix prefixes when possible
try:
# Try to parse, to fix prefixes and structural issues
prepareQuery(gen_query["query"])
except Exception as e:
if "Unknown namespace prefix" in str(e):
# Automatically fix missing prefixes
fixed_query = add_missing_prefixes(gen_query["query"], prefixes_map)
fixed_msg = last_msg.replace(gen_query["query"], fixed_query)
gen_query["query"] = fixed_query
# Pass the fixed msg to the client
validation_steps.append(
StepOutput(
type="fix-message",
label="✅ Fixed the prefixes of the generated SPARQL query automatically",
details=f"Prefixes corrected from the query generated in the original response.\n### Original response\n{last_msg}",
fixed_message=fixed_msg,
)
)
# Check if other errors are present
errors = [
line
for line in str(e).splitlines()
if "Unknown namespace prefix" not in line
]

# 2. Validate the SPARQL query based on schema from VoID description
if gen_query["endpoint_url"] and not errors:
errors = list(
validate_sparql_with_void(
gen_query["query"],
gen_query["endpoint_url"],
prefix_converter,
endpoints_void_dict,
validation_outputs = validate_sparql_in_msg(
last_msg, prefixes_map, endpoints_void_dict
)
for validation_output in validation_outputs:
if validation_output["fixed_query"]:
# Pass the fixed msg to the client
validation_steps.append(
StepOutput(
type="fix-message",
label="✅ Fixed the prefixes of the generated SPARQL query automatically",
details=f"Prefixes corrected from the query generated in the original response.\n### Original response\n{last_msg}",
fixed_message=last_msg.replace(
validation_output["original_query"],
validation_output["fixed_query"],
),
)
)

# 3. Recall the LLM to try to fix errors
if errors:
error_str = "- " + "\n- ".join(errors)
validation_msg = f"The query generated in the original response is not valid according to the endpoints schema.\n### Validation results\n{error_str}\n### Erroneous SPARQL query\n```sparql\n{gen_query['query']}\n```\n### Original response\n{last_msg}\n"
if validation_output["errors"]:
# Recall the LLM to try to fix errors
error_str = "- " + "\n- ".join(validation_output["errors"])
validation_msg = f"The query generated in the original response is not valid according to the endpoints schema.\n### Validation results\n{error_str}\n### Erroneous SPARQL query\n```sparql\n{validation_output['original_query']}\n```\n### Original response\n{last_msg}\n"
validation_steps.append(
StepOutput(
type="recall",
@@ -95,7 +67,7 @@ async def validate_output(state: State, config: RunnableConfig) -> dict[str, Any
# Add a new message to ask the model to fix the error
recall_messages.append(
HumanMessage(
content=f"Fix the SPARQL query helping yourself with the error message and context from previous messages in a way that it is a fully valid query.\n\nSPARQL query: {gen_query['query']}\n\nError messages:\n{error_str}",
content=f"Fix the SPARQL query helping yourself with the error message and context from previous messages in a way that it is a fully valid query.\n\nSPARQL query: {validation_output['original_query']}\n\nError messages:\n{error_str}",
# name="recall",
# additional_kwargs={"validation_results": error_str},
)
@@ -109,11 +81,13 @@ async def validate_output(state: State, config: RunnableConfig) -> dict[str, Any
}
extracted = {}
# Add structured output if a valid query was generated
if generated_sparqls:
if generated_sparqls[-1]["query"]:
extracted["sparql_query"] = generated_sparqls[-1]["query"]
if generated_sparqls[-1]["endpoint_url"]:
extracted["sparql_endpoint_url"] = generated_sparqls[-1]["endpoint_url"]
if validation_outputs:
if validation_outputs[-1].get("fixed_query"):
extracted["sparql_query"] = validation_outputs[-1]["fixed_query"]
else:
extracted["sparql_query"] = validation_outputs[-1]["original_query"]
if validation_outputs[-1]["endpoint_url"]:
extracted["sparql_endpoint_url"] = validation_outputs[-1]["endpoint_url"]
response["structured_output"] = extracted
return response

@@ -122,11 +96,10 @@ async def validate_output(state: State, config: RunnableConfig) -> dict[str, Any
prefixes_map = get_prefixes_for_endpoints(
[endpoint["endpoint_url"] for endpoint in settings.endpoints]
)
prefix_converter = get_prefix_converter(prefixes_map)

# Initialize VoID dictionary for the endpoints
endpoints_void_dict: SparqlTriplesDict = {}
endpoints_void_dict: EndpointsSchemaDict = {}
for endpoint in settings.endpoints:
endpoints_void_dict[endpoint["endpoint_url"]] = get_void_for_endpoint(
endpoints_void_dict[endpoint["endpoint_url"]] = get_schema_for_endpoint(
endpoint["endpoint_url"], endpoint.get("void_file")
)
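
For orientation, the module-level setup that the node now relies on can be sketched compactly as below. This is an illustrative rewrite, not the committed code: the endpoint URL is a placeholder, and entries in `settings.endpoints` are assumed to carry an `endpoint_url` plus an optional `void_file`, as the diff above suggests.

```python
# Minimal sketch (assumed setup, not the committed code): build the prefix map and
# the per-endpoint schema dictionary once, then reuse them inside the validation node.
from sparql_llm.utils import (
    EndpointsSchemaDict,
    get_prefixes_for_endpoints,
    get_schema_for_endpoint,
)

# Placeholder endpoint list; in the agent this comes from settings.endpoints
endpoints = [{"endpoint_url": "https://sparql.uniprot.org/"}]

prefixes_map = get_prefixes_for_endpoints([e["endpoint_url"] for e in endpoints])
endpoints_void_dict: EndpointsSchemaDict = {
    e["endpoint_url"]: get_schema_for_endpoint(e["endpoint_url"], e.get("void_file"))
    for e in endpoints
}
```

With these two structures precomputed, the node itself reduces to a single call to `validate_sparql_in_msg(last_msg, prefixes_map, endpoints_void_dict)` plus the `StepOutput` bookkeeping shown above.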
2 changes: 1 addition & 1 deletion packages/sparql-llm/src/sparql_llm/__init__.py
@@ -2,6 +2,6 @@

__version__ = "0.0.4"

from .validate_sparql import validate_sparql_with_void
from .validate_sparql import validate_sparql_in_msg, validate_sparql_with_void
from .sparql_examples_loader import SparqlExamplesLoader
from .sparql_void_shapes_loader import SparqlVoidShapesLoader, get_shex_dict_from_void, get_shex_from_void
4 changes: 2 additions & 2 deletions packages/sparql-llm/src/sparql_llm/sparql_void_shapes_loader.py
@@ -3,7 +3,7 @@
from langchain_core.document_loaders.base import BaseLoader
from langchain_core.documents import Document

from sparql_llm.utils import get_prefix_converter, get_prefixes_for_endpoints, get_void_for_endpoint, query_sparql
from sparql_llm.utils import get_prefix_converter, get_prefixes_for_endpoints, get_schema_for_endpoint, query_sparql

DEFAULT_NAMESPACES_TO_IGNORE = [
"http://www.w3.org/ns/sparql-service-description#",
@@ -30,7 +30,7 @@ def get_shex_dict_from_void(
prefix_map = prefix_map or get_prefixes_for_endpoints([endpoint_url])
namespaces_to_ignore = namespaces_to_ignore or DEFAULT_NAMESPACES_TO_IGNORE
prefix_converter = get_prefix_converter(prefix_map)
void_dict = get_void_for_endpoint(endpoint_url, void_file)
void_dict = get_schema_for_endpoint(endpoint_url, void_file)
shex_dict = {}

for subject_cls, predicates in void_dict.items():
15 changes: 11 additions & 4 deletions packages/sparql-llm/src/sparql_llm/utils.py
@@ -17,6 +17,13 @@
} ORDER BY ?prefix"""


# def get_endpoints_schema_and_prefixes(endpoints: list[str]) -> tuple["EndpointsSchemaDict", dict[str, str]]:
# """Return a tuple of VoID descriptions and prefixes for the given endpoints."""
# return (

# )


def get_prefixes_for_endpoints(endpoints: list[str]) -> dict[str, str]:
"""Return a dictionary of prefixes for the given endpoints."""
prefixes: dict[str, str] = {}
@@ -62,16 +69,16 @@ def get_prefix_converter(prefix_dict: dict[str, str]) -> Converter:

# A dictionary to store triples like structure: dict[subject][predicate] = list[object]
# Also used to store VoID description of an endpoint: dict[subject_cls][predicate] = list[object_cls/datatype]
TripleDict = dict[str, dict[str, list[str]]]
SchemaDict = dict[str, dict[str, list[str]]]
# The VoidDict type, but we also store the endpoints URLs in an outer dict
SparqlTriplesDict = dict[str, TripleDict]
EndpointsSchemaDict = dict[str, SchemaDict]


def get_void_for_endpoint(endpoint_url: str, void_file: Optional[str] = None) -> TripleDict:
def get_schema_for_endpoint(endpoint_url: str, void_file: Optional[str] = None) -> SchemaDict:
"""Get a dict of VoID description of a SPARQL endpoint directly from the endpoint or from a VoID description URL.
Formatted as: dict[subject_cls][predicate] = list[object_cls/datatype]"""
void_dict: TripleDict = {}
void_dict: SchemaDict = {}
try:
if void_file:
g = rdflib.Graph()
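
To make the renamed helpers concrete, here is a small hedged example; only the shapes of `SchemaDict` and `EndpointsSchemaDict` and the `get_schema_for_endpoint` signature come from the code above, and the endpoint URL is just an example.

```python
# Illustrative sketch: fetch the VoID-derived schema of one endpoint and nest it
# in the outer endpoint-keyed dictionary.
from sparql_llm.utils import EndpointsSchemaDict, SchemaDict, get_schema_for_endpoint

# SchemaDict maps subject class -> predicate -> list of object classes/datatypes
schema: SchemaDict = get_schema_for_endpoint("https://sparql.uniprot.org/")

# EndpointsSchemaDict adds an outer level keyed by the endpoint URL
schemas: EndpointsSchemaDict = {"https://sparql.uniprot.org/": schema}

for subject_cls, predicates in schema.items():
    for predicate, object_classes in predicates.items():
        print(subject_cls, predicate, object_classes)
```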
80 changes: 70 additions & 10 deletions packages/sparql-llm/src/sparql_llm/validate_sparql.py
@@ -1,18 +1,18 @@
import re
from collections import defaultdict
from typing import Any, Optional, Union
from typing import Any, Optional, TypedDict, Union

from curies_rs import Converter
from rdflib import Namespace, Variable
from rdflib.paths import AlternativePath, MulPath, Path, SequencePath
from rdflib.plugins.sparql import prepareQuery

from sparql_llm.utils import (
SparqlTriplesDict,
TripleDict,
EndpointsSchemaDict,
SchemaDict,
get_prefix_converter,
get_prefixes_for_endpoints,
get_void_for_endpoint,
get_schema_for_endpoint,
)

queries_pattern = re.compile(r"```sparql(.*?)```", re.DOTALL)
@@ -63,9 +63,9 @@ def add_missing_prefixes(query: str, prefixes_map: dict[str, str]) -> str:
sqc = Namespace("http://example.org/sqc/") # SPARQL query check


def sparql_query_to_dict(sparql_query: str, sparql_endpoint: str) -> SparqlTriplesDict:
def sparql_query_to_dict(sparql_query: str, sparql_endpoint: str) -> EndpointsSchemaDict:
"""Convert a SPARQL query string to a dictionary of triples looking like dict[endpoint][subject][predicate] = list[object]"""
query_dict: SparqlTriplesDict = defaultdict(TripleDict)
query_dict: EndpointsSchemaDict = defaultdict(SchemaDict)
path_var_count = 1

def handle_path(endpoint: str, subj: str, pred: Union[str, Path], obj: str):
@@ -141,7 +141,7 @@ def validate_sparql_with_void(
query: str,
endpoint_url: str,
prefix_converter: Optional[Converter] = None,
endpoints_void_dict: Optional[SparqlTriplesDict] = None,
endpoints_void_dict: Optional[EndpointsSchemaDict] = None,
) -> set[str]:
"""Validate SPARQL query using the VoID description of endpoints. Returns a set of human-readable error messages."""
if prefix_converter is None:
@@ -153,8 +153,8 @@

def validate_triple_pattern(
subj: str,
subj_dict: TripleDict,
void_dict: TripleDict,
subj_dict: SchemaDict,
void_dict: SchemaDict,
endpoint: str,
issues: set[str],
parent_type: Optional[str] = None,
@@ -254,7 +254,7 @@ def validate_triple_pattern(
# Go through the query BGPs and check if they match the VoID description
for endpoint, subj_dict in query_dict.items():
void_dict = (
endpoints_void_dict[endpoint] if endpoint in endpoints_void_dict else get_void_for_endpoint(endpoint)
endpoints_void_dict[endpoint] if endpoint in endpoints_void_dict else get_schema_for_endpoint(endpoint)
)

if len(void_dict) == 0:
@@ -290,3 +290,63 @@
# wrong_predicate: str
# available_options: list[str]
# message: str


class QueryValidationOutput(TypedDict):
original_query: str
endpoint_url: str
fixed_query: Optional[str]
errors: list[str]


def validate_sparql_in_msg(
msg: str,
prefixes_map: Optional[dict[str, str]] = None,
endpoints_void_dict: Optional[EndpointsSchemaDict] = None,
) -> list[QueryValidationOutput]:
"""Validate SPARQL queries in a markdown response using VoID descriptions of endpoints."""
validation_outputs = []
generated_sparqls = extract_sparql_queries(msg)

# Get prefixes if not provided
if endpoints_void_dict is None:
endpoints_void_dict = {}
if prefixes_map is None:
prefixes_map = get_prefixes_for_endpoints(
list({gen_sparql["endpoint_url"] for gen_sparql in generated_sparqls if gen_sparql.get("endpoint_url")})
)
prefix_converter = get_prefix_converter(prefixes_map)

for gen_sparql in generated_sparqls:
validation_output: QueryValidationOutput = {
"original_query": gen_sparql["query"],
"endpoint_url": gen_sparql["endpoint_url"],
"fixed_query": None,
"errors": [],
}
# 1. Check if the query is syntactically valid, auto fix prefixes when possible
try:
# Try to parse, to fix prefixes and structural issues
prepareQuery(gen_sparql["query"])
except Exception as e:
if "Unknown namespace prefix" in str(e):
# Automatically fix missing prefixes
validation_output["fixed_query"] = add_missing_prefixes(gen_sparql["query"], prefixes_map)
gen_sparql["query"] = validation_output["fixed_query"]
# Check if other syntax errors are present
validation_output["errors"] = [
line for line in str(e).splitlines() if "Unknown namespace prefix" not in line
]

# 2. Validate the SPARQL query based on schema from VoID description if no syntactic errors
if gen_sparql["endpoint_url"] and not validation_output["errors"]:
validation_output["errors"] = list(
validate_sparql_with_void(
gen_sparql["query"],
gen_sparql["endpoint_url"],
prefix_converter,
endpoints_void_dict,
)
)
validation_outputs.append(validation_output)
return validation_outputs
4 changes: 2 additions & 2 deletions packages/sparql-llm/tests/test_components.py
@@ -5,7 +5,7 @@
SparqlVoidShapesLoader,
validate_sparql_with_void,
)
from sparql_llm.utils import get_void_for_endpoint
from sparql_llm.utils import get_schema_for_endpoint


def test_sparql_examples_loader_uniprot():
@@ -27,7 +27,7 @@ def test_sparql_void_shape_loader():
# uv run pytest tests/test_components.py::test_sparql_void_from_url
def test_sparql_void_from_file():
void_filepath = os.path.join(os.path.dirname(__file__), "void_uniprot.ttl")
void_dict = get_void_for_endpoint("https://sparql.uniprot.org/", void_filepath)
void_dict = get_schema_for_endpoint("https://sparql.uniprot.org/", void_filepath)
# From URL: void_dict = get_void_for_endpoint("https://sparql.uniprot.org/", "https://sparql.uniprot.org/.well-known/void/")
assert len(void_dict) >= 2

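
A hypothetical companion test for the new helper could look like the sketch below. It is not part of this commit; it reuses the bundled void_uniprot.ttl fixture and an example prefix map, and only checks that a missing PREFIX declaration gets fixed.

```python
import os

from sparql_llm import validate_sparql_in_msg
from sparql_llm.utils import get_schema_for_endpoint


# Hypothetical test, not included in this commit
def test_validate_sparql_in_msg():
    void_filepath = os.path.join(os.path.dirname(__file__), "void_uniprot.ttl")
    endpoints_void_dict = {
        "https://sparql.uniprot.org/": get_schema_for_endpoint("https://sparql.uniprot.org/", void_filepath)
    }
    prefixes_map = {"up": "http://purl.uniprot.org/core/"}  # example prefix map
    fence = "`" * 3
    msg = (
        f"{fence}sparql\n"
        "SELECT ?protein WHERE { ?protein a up:Protein . }\n"
        f"{fence}\n"
    )
    outputs = validate_sparql_in_msg(msg, prefixes_map, endpoints_void_dict)
    assert len(outputs) == 1
    # The missing up: prefix declaration should have been added automatically
    assert outputs[0]["fixed_query"] is not None
```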
