Skip to content

Commit b60df35

Browse files
authored
Feat: Introduce LLM integration (#799)
1 parent 620f114 commit b60df35

File tree

4 files changed

+99
-0
lines changed

4 files changed

+99
-0
lines changed

setup.cfg

+3
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ ignore_missing_imports = True
6666
[mypy-psycopg2.*]
6767
ignore_missing_imports = True
6868

69+
[mypy-langchain.*]
70+
ignore_missing_imports = True
71+
6972
[autoflake]
7073
in-place = True
7174
expand-star-imports = True

setup.py

+4
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@
8787
"dbt": [
8888
"dbt-core<1.5.0",
8989
],
90+
"llm": [
91+
"langchain",
92+
"openai",
93+
],
9094
"postgres": [
9195
"psycopg2",
9296
],

sqlmesh/cli/main.py

+39
Original file line numberDiff line numberDiff line change
@@ -378,5 +378,44 @@ def migrate(ctx: click.Context) -> None:
378378
ctx.obj.migrate()
379379

380380

381+
@cli.command("prompt")
@click.argument("prompt")
@click.option(
    "-e",
    "--evaluate",
    is_flag=True,
    help="Evaluate the generated SQL query and display the results.",
)
@click.option(
    "-t",
    "--temperature",
    type=float,
    help="Sampling temperature. 0.0 - precise and predictable, 0.5 - balanced, 1.0 - creative. Default: 0.7",
    default=0.7,
)
@opt.verbose
@click.pass_context
@error_handler
def prompt(
    ctx: click.Context, prompt: str, evaluate: bool, temperature: float, verbose: bool
) -> None:
    """Uses LLM to generate a SQL query from a prompt."""
    # Imported lazily so the optional "llm" extra is only required for this command.
    from sqlmesh.integrations.llm import LLMIntegration

    context = ctx.obj

    integration = LLMIntegration(
        context.models.values(),
        context.engine_adapter.dialect,
        temperature=temperature,
        verbose=verbose,
    )
    sql = integration.query(prompt)

    context.console.log_status_update(sql)
    if evaluate:
        context.console.log_success(context.fetchdf(sql))
418+
419+
381420
if __name__ == "__main__":
382421
cli()

sqlmesh/integrations/llm.py

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from __future__ import annotations
2+
3+
import typing as t
4+
5+
from langchain import LLMChain, PromptTemplate
6+
from langchain.chat_models import ChatOpenAI
7+
8+
from sqlmesh.core.model import Model
9+
10+
_QUERY_PROMPT_TEMPLATE = """Given an input request, create a syntactically correct {dialect} SQL query.
11+
Use full table names.
12+
Convert string operands to lowercase in the WHERE clause.
13+
Reply with a SQL query and nothing else.
14+
15+
Use the following tables and columns:
16+
17+
{table_info}
18+
19+
Request: {input}"""
20+
21+
22+
class LLMIntegration:
    """Generates SQL queries from natural-language prompts via an OpenAI chat model.

    Args:
        models: Models whose materialized tables/columns are exposed to the LLM.
        dialect: Target SQL dialect name injected into the prompt template.
        temperature: Sampling temperature forwarded to the chat model. Default: 0.7.
        verbose: If True, the underlying chain logs its intermediate steps.
    """

    def __init__(
        self,
        models: t.Iterable[Model],
        dialect: str,
        temperature: float = 0.7,
        verbose: bool = False,
    ):
        # Bind the static parts of the prompt up front; only "input" varies per call.
        query_prompt_template = PromptTemplate.from_template(_QUERY_PROMPT_TEMPLATE).partial(
            dialect=dialect, table_info=_to_table_info(models)
        )
        llm = ChatOpenAI(temperature=temperature)  # type: ignore
        self._query_chain = LLMChain(llm=llm, prompt=query_prompt_template, verbose=verbose)

    def query(self, prompt: str) -> str:
        """Return a SQL query generated from *prompt*.

        Chat models often prefix the query with prose despite the template's
        instructions; everything before the first occurrence of "SELECT"
        (matched case-insensitively, since replies may use lowercase) is
        stripped. If no SELECT is found, the stripped reply is returned as-is.
        """
        result = self._query_chain.predict(input=prompt).strip()
        # Case-insensitive search; upper() preserves positions so the slice
        # below keeps the reply's original casing.
        select_pos = result.upper().find("SELECT")
        if select_pos >= 0:
            return result[select_pos:]
        return result
42+
43+
44+
def _to_table_info(models: t.Iterable[Model]) -> str:
45+
infos = []
46+
for model in models:
47+
if not model.kind.is_materialized:
48+
continue
49+
50+
columns_csv = ", ".join(model.columns_to_types)
51+
infos.append(f"Table: {model.name}\nColumns: {columns_csv}\n")
52+
53+
return "\n".join(infos)

0 commit comments

Comments
 (0)