feat: Add result ranking prompt for NLDataframeRetriever
This commit adds a new result ranking prompt to the NLDataframeRetriever class. The prompt takes a schema and a query and asks the LLM to rate how well the schema models the domain of the query. The relevance must be a number between 0 and 1, where 1 indicates high relevance and 0 indicates low relevance.

The significant changes include:
- Added DEFAULT_RESULT_RANKING_TMPL constant for the result ranking template
- Added DEFAULT_RESULT_RANKING_PROMPTROMPT constant for the result ranking prompt template
- Updated NLDataframeRetriever constructor to accept a result_ranking_prompt parameter
- Initialized self._result_ranking_prompt with either the provided parameter or the default prompt template
- Updated the LLM completion call in NLDataframeRetriever._aretrieve() to use self._result_ranking_prompt instead of a hard-coded inline prompt

These changes allow users of NLDataframeRetriever to easily rank the relevance of schemas in modeling their queries, providing more accurate results.
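The ranking template is a plain format string with `{schema}` and `{query}` placeholders. A quick stdlib-only sketch of what the rendered prompt looks like (the template text is paraphrased from the diff, and the sample schema and query are hypothetical):

```python
# Stand-in for DEFAULT_RESULT_RANKING_TMPL, using the same placeholders as the diff.
RESULT_RANKING_TMPL = (
    "given the schema: {schema}\n"
    "and the query: {query}\n"
    "how relevant is the schema?\n"
    "the relevance must be a number between 0 and 1 where 1 indicates that "
    "the schema is able to model the domain of the query and 0 indicates "
    "that the schema is not able to model the domain of the query.\n"
    "produce only the numeric value and nothing else.\n"
    "relevance:\n"
)

# Render the prompt the same way PromptTemplate.format(schema=..., query=...) would.
prompt = RESULT_RANKING_TMPL.format(
    schema="CREATE TABLE sales (region TEXT, amount REAL)",
    query="total sales per region",
)
print(prompt)
```

Any custom prompt passed as `result_ranking_prompt` needs to expose the same two template variables, since the retriever fills in `query` and `schema` by name.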
colombod committed Sep 13, 2024
1 parent 35c0c91 commit 8c6045d
Showing 1 changed file with 22 additions and 10 deletions.
@@ -52,6 +52,21 @@
prompt_type=PromptType.CUSTOM,
)

DEFAULT_RESULT_RANKING_TMPL = """\
given the schema: {schema}
and the query: {query}
how relevant is the schema?
the relevance must be a number between 0 and 1 where 1 indicates that the schema is able to model the domain of the query and 0 indicates that the schema is not able to model the domain of the query.
produce only the numeric value and nothing else.
relevance:
"""

DEFAULT_RESULT_RANKING_PROMPTROMPT = PromptTemplate(
DEFAULT_RESULT_RANKING_TMPL,
prompt_type=PromptType.CUSTOM,
)


class NLDataframeRetriever(BaseRetriever):
def __init__(
@@ -62,13 +77,17 @@ def __init__(
text_to_sql_prompt: Optional[BasePromptTemplate] = None,
schema_to_owl_prompt: Optional[BasePromptTemplate] = None,
schema_use_detection_prompt: Optional[BasePromptTemplate] = None,
result_ranking_prompt: Optional[BasePromptTemplate] = None,
similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
callback_manager: Optional[CallbackManager] = None,
verbose: bool = False,
):
self._llm = resolve_llm(llm)
self._similarity_top_k = similarity_top_k
self._text_to_sql_prompt = text_to_sql_prompt or DEFAULT_TEXT_TO_SQL_PROMPT
self._result_ranking_prompt = (
result_ranking_prompt or DEFAULT_RESULT_RANKING_PROMPTROMPT
)
self._schema_to_owl_prompt = (
schema_to_owl_prompt or DEFAULT_OWL_GENERATOR_PROMPT
)
@@ -142,16 +161,9 @@ async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
)

        rank = self._llm.complete(
-            f"""
-            given the schema:{tables_desc_str}
-            and the query: {query_bundle.query_str}
-            how relevant is the schema?
-            the relevance must be a number between 0 and 1 where 1 indicates that the schema is able to model the domain of the query and 0 indicates that the schema is not able to model the domain of the query.
-            produce only the numeric value and nothing else.
-            relevance:
-            """
+            self._result_ranking_prompt,
+            query=query_bundle.query_str,
+            schema=tables_desc_str,
        )

score = 1.0
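The prompt asks for a bare number in [0, 1], and the hunk above starts from an optimistic `score = 1.0` default, but LLM output can be noisy. A defensive parse (a hypothetical helper, not part of this commit) might extract the first number, clamp it, and fall back to the default when nothing numeric comes back:

```python
import re

def parse_relevance(text: str, default: float = 1.0) -> float:
    """Extract the first number from an LLM completion and clamp it to [0, 1]."""
    match = re.search(r"\d+(?:\.\d+)?", text)
    if match is None:
        # No numeric value produced; keep the optimistic default score.
        return default
    return max(0.0, min(1.0, float(match.group())))
```

This tolerates completions like "relevance: 0.25" as well as out-of-range values such as "2".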
