from __future__ import annotations

import typing as t

from langchain import LLMChain, PromptTemplate
from langchain.chat_models import ChatOpenAI

from sqlmesh.core.model import Model

_QUERY_PROMPT_TEMPLATE = """Given an input request, create a syntactically correct {dialect} SQL query.
Use full table names.
Convert string operands to lowercase in the WHERE clause.
Reply with a SQL query and nothing else.

Use the following tables and columns:

{table_info}

Request: {input}"""


class LLMIntegration:
    """Generates SQL queries from natural-language requests using an OpenAI chat model."""

    def __init__(
        self,
        models: t.Iterable[Model],
        dialect: str,
        temperature: float = 0.7,
        verbose: bool = False,
    ):
        # Pre-bind the dialect and the table schema so that only the user's
        # request remains to be filled in at query time.
        query_prompt_template = PromptTemplate.from_template(_QUERY_PROMPT_TEMPLATE).partial(
            dialect=dialect, table_info=_to_table_info(models)
        )
        llm = ChatOpenAI(temperature=temperature)  # type: ignore
        self._query_chain = LLMChain(llm=llm, prompt=query_prompt_template, verbose=verbose)

    def query(self, prompt: str) -> str:
        result = self._query_chain.predict(input=prompt).strip()
        # Drop any commentary the model emitted before the first (uppercase) SELECT keyword.
        select_pos = result.find("SELECT")
        if select_pos >= 0:
            return result[select_pos:]
        return result
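
# A minimal sketch of how the partially-bound template behaves (the "duckdb"
# dialect and the table_info string below are illustrative assumptions, not
# values from a real project): after .partial(), only {input} remains unbound.
#
#   template = PromptTemplate.from_template(_QUERY_PROMPT_TEMPLATE).partial(
#       dialect="duckdb", table_info="Table: db.orders\nColumns: id, customer_id\n"
#   )
#   template.format(input="Count orders per customer")
#   # -> the full prompt text with dialect, table_info, and input filled in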


def _to_table_info(models: t.Iterable[Model]) -> str:
    """Renders each materialized model as its table name plus a CSV list of its columns."""
    infos = []
    for model in models:
        # Models that aren't materialized have no physical table to query against.
        if not model.kind.is_materialized:
            continue

        columns_csv = ", ".join(model.columns_to_types)
        infos.append(f"Table: {model.name}\nColumns: {columns_csv}\n")

    return "\n".join(infos)
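

# A minimal usage sketch (not part of the original diff): wire the integration
# to a SQLMesh project and ask for a query. `Context(paths=".")`, the "duckdb"
# dialect, and the example prompt are illustrative assumptions; ChatOpenAI also
# requires OPENAI_API_KEY to be set in the environment.
if __name__ == "__main__":
    from sqlmesh import Context

    context = Context(paths=".")
    integration = LLMIntegration(
        models=context.models.values(),
        dialect="duckdb",
        temperature=0.0,  # a lower temperature makes the generated SQL more deterministic
    )
    print(integration.query("How many orders did each customer place?"))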