Skip to content

Commit 969d929

Browse files
committed
Rewrite llm query operator
This is part of the series of tasks for converting llm operator to use Henry's new LLMMap or LLMElementMap.
1 parent 4a9e52e commit 969d929

File tree

5 files changed

+144
-14
lines changed

5 files changed

+144
-14
lines changed

lib/sycamore/sycamore/docset.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,10 +1331,9 @@ def llm_query(self, query_agent: LLMTextQueryAgent, **kwargs) -> "DocSet":
13311331
element_type: (Optional) Parameter to only execute the LLM query on a particular element type. If not
13321332
specified, the query will be executed on all elements.
13331333
"""
1334-
from sycamore.transforms import LLMQuery
1334+
query = query_agent.as_llm_map(self.plan, **kwargs)
13351335

1336-
queries = LLMQuery(self.plan, query_agent=query_agent, **kwargs)
1337-
return DocSet(self.context, queries)
1336+
return DocSet(self.context, query)
13381337

13391338
def groupby(self, grouped_key: Union[str, list[str]], entity: Optional[str] = None) -> "GroupedData":
13401339
from sycamore.grouped_data import GroupedData

lib/sycamore/sycamore/llms/prompts/prompts.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,3 +662,58 @@ def render_element(self, elt: Element, doc: Document) -> RenderedPrompt:
662662
if self.response_format is not None:
663663
result.response_format = self.response_format
664664
return result
665+
666+
667+
class JinjaTableMergerPrompt(JinjaElementPrompt):
668+
def __init__(
669+
self,
670+
*,
671+
system: Optional[str] = None,
672+
user: Union[None, str, list[str]] = None,
673+
include_image: bool = False,
674+
response_format: ResponseFormat = None,
675+
**kwargs,
676+
):
677+
from jinja2.sandbox import SandboxedEnvironment
678+
from jinja2 import Template
679+
680+
super().__init__()
681+
self.system = system
682+
self.user = user
683+
self.include_image = include_image
684+
self.response_format = response_format
685+
self.kwargs = kwargs
686+
self._env = SandboxedEnvironment(extensions=["jinja2.ext.loopcontrols"])
687+
self._sys_template: Optional[Template] = None
688+
self._user_templates: Union[None, list[Template]] = None
689+
690+
def render_element(self, elt: Element, doc: Document) -> RenderedPrompt:
691+
filtered = [e for e in doc.elements if e.type == "table"]
692+
idx = filtered.index(elt)
693+
prev = None
694+
if idx > 0:
695+
prev = filtered[idx - 1]
696+
697+
if self._user_templates is None:
698+
userlist = self.user if isinstance(self.user, list) else [self.user] # type: ignore
699+
templates = compile_templates([self.system] + userlist, self._env) # type: ignore
700+
self._sys_template = templates[0]
701+
self._user_templates = [t for t in templates[1:] if t is not None]
702+
703+
render_args = copy.deepcopy(self.kwargs)
704+
render_args["elt"] = elt
705+
render_args["doc"] = doc
706+
render_args["prev"] = prev
707+
708+
result = render_templates(self._sys_template, self._user_templates, render_args)
709+
if self.include_image and len(result.messages) > 0:
710+
from sycamore.utils.pdf_utils import get_element_image
711+
712+
images = []
713+
if prev:
714+
images.append(get_element_image(prev, doc))
715+
images.append(get_element_image(elt, doc))
716+
result.messages[-1].images = images
717+
if self.response_format is not None:
718+
result.response_format = self.response_format
719+
return result

lib/sycamore/sycamore/tests/unit/transforms/test_llm_query.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ def test_llm_query_text_does_not_call_llm(self, mocker):
1515
prompt = "Give me a one word summary response about the text"
1616
output_property = "output_property"
1717
query_agent = LLMTextQueryAgent(prompt=prompt, llm=llm, output_property=output_property)
18-
doc = query_agent.execute_query(doc)
18+
doc = query_agent.as_llm_map(None)._local_process([doc])
1919

20-
assert output_property not in doc.elements[0].properties
20+
assert output_property not in doc[0].properties
2121

2222
def test_summarize_text_element_calls_llm(self, mocker):
2323
llm = mocker.Mock(spec=OpenAI)
24-
generate = mocker.patch.object(llm, "generate_old")
24+
generate = mocker.patch.object(llm, "generate")
2525
generate.return_value = {"summary": "summary"}
2626
doc = Document()
2727
element1 = Element()
@@ -32,14 +32,14 @@ def test_summarize_text_element_calls_llm(self, mocker):
3232
prompt = "Give me a one word summary response about the text"
3333
output_property = "output_property"
3434
query_agent = LLMTextQueryAgent(prompt=prompt, llm=llm, output_property=output_property)
35-
doc = query_agent.execute_query(doc)
35+
doc = query_agent.as_llm_map(None)._local_process([doc])
3636

37-
assert doc.elements[0].properties[output_property] == {"summary": "summary"}
38-
assert doc.elements[1].properties[output_property] == {"summary": "summary"}
37+
assert doc[0].elements[0].properties[output_property] == {"summary": "summary"}
38+
assert doc[0].elements[1].properties[output_property] == {"summary": "summary"}
3939

4040
def test_summarize_text_document_calls_llm(self, mocker):
4141
llm = mocker.Mock(spec=OpenAI)
42-
generate = mocker.patch.object(llm, "generate_old")
42+
generate = mocker.patch.object(llm, "generate")
4343
generate.return_value = {"summary": "summary"}
4444
doc = Document()
4545
element1 = Element()
@@ -51,6 +51,6 @@ def test_summarize_text_document_calls_llm(self, mocker):
5151
prompt = "Give me a one word summary response about the text"
5252
output_property = "output_property"
5353
query_agent = LLMTextQueryAgent(prompt=prompt, llm=llm, per_element=False, output_property=output_property)
54-
doc = query_agent.execute_query(doc)
54+
doc = query_agent.as_llm_map(None)._local_process([doc])
5555

56-
assert doc.properties[output_property] == {"summary": "summary"}
56+
assert doc[0].properties[output_property] == {"summary": "summary"}

lib/sycamore/sycamore/transforms/base_llm.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ def __init__(
188188
validate: Callable[[Element], bool] = lambda e: True,
189189
max_tries: int = 5,
190190
filter: Callable[[Element], bool] = lambda e: True,
191+
number_of_elements: Optional[int] = None,
191192
**kwargs,
192193
):
193194
self._prompt = prompt
@@ -199,6 +200,7 @@ def __init__(
199200
self._validate = validate
200201
self._max_tries = max_tries
201202
self._filter = filter
203+
self._number_of_elements = number_of_elements
202204
super().__init__(child, f=self.llm_map_elements, **kwargs)
203205

204206
def llm_map_elements(self, documents: list[Document]) -> list[Document]:
@@ -207,7 +209,15 @@ def llm_map_elements(self, documents: list[Document]) -> list[Document]:
207209
for e, _ in elt_doc_pairs:
208210
e.properties[self._iteration_var] = 0
209211

210-
skips = [not self._filter(e) for e, _ in elt_doc_pairs]
212+
skips = []
213+
counter = 0
214+
for e, _ in elt_doc_pairs:
215+
if self._filter(e) and (not self._number_of_elements or counter < self._number_of_elements):
216+
counter += 1
217+
skips.append(False)
218+
else:
219+
skips.append(True)
220+
211221
tries = 0
212222
while not all(skips) and tries < self._max_tries:
213223
tries += 1

lib/sycamore/sycamore/transforms/llm_query.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from typing import Optional, Any, Union
22

33
from sycamore.data import Element, Document
4+
from sycamore.llms.prompts.prompts import JinjaPrompt, ElementListPrompt, JinjaElementPrompt, JinjaTableMergerPrompt
45
from sycamore.plan_nodes import NonCPUUser, NonGPUUser, Node
56
from sycamore.llms import LLM
7+
from sycamore.transforms.base_llm import LLMMapElements, LLMMap
68
from sycamore.transforms.map import Map
79
from sycamore.utils.time_trace import timetrace
810
from jinja2.sandbox import SandboxedEnvironment
@@ -55,7 +57,7 @@ def __init__(
5557
self._output_property = output_property
5658
self._llm_kwargs = llm_kwargs
5759
self._per_element = per_element
58-
self._format_kwargs = format_kwargs
60+
self._format_kwargs = format_kwargs if format_kwargs else {}
5961
self._number_of_elements = number_of_elements
6062
self._element_type = element_type
6163
self._table_cont = table_cont
@@ -125,6 +127,70 @@ def _query_text_object(
125127
object["properties"][self._output_property] = llm_resp
126128
return object
127129

130+
def get_element_prompt(self):
131+
if self._table_cont:
132+
prompt = (
133+
self._prompt
134+
+ "\n"
135+
+ "{% if prev %}"
136+
+ "ELEMENT 1: \n"
137+
+ "{prev}"
138+
+ "\n\n"
139+
+ "ELEMENT 2: \n"
140+
+ "{elt}"
141+
+ "{% else %}"
142+
+ "{elt}"
143+
+ "{% endif %}"
144+
)
145+
return JinjaTableMergerPrompt(system=None, user=prompt, **self._format_kwargs)
146+
else:
147+
return JinjaElementPrompt(system=None, user=self._prompt, **self._format_kwargs)
148+
149+
def get_document_prompt(self):
150+
if self._number_of_elements:
151+
152+
def selector(elts: list[Element]):
153+
elts = [elt for elt in elts if elt.type == self._element_type]
154+
if self._number_of_elements:
155+
elts = elts[: self._number_of_elements]
156+
return elts
157+
158+
return ElementListPrompt(
159+
system=None,
160+
user=self._prompt,
161+
element_select=selector,
162+
element_list_constructor=lambda elts: "\n".join([elt.text_representation for elt in elts]),
163+
**self._format_kwargs,
164+
)
165+
else:
166+
if self._format_kwargs:
167+
return JinjaPrompt(system=None, user=self._prompt, **self._format_kwargs)
168+
else:
169+
return JinjaPrompt(system=None, user=self._prompt + "\n" + "{{doc}}", **self._format_kwargs)
170+
171+
def as_llm_map(self, child) -> Node:
172+
if self._per_element:
173+
174+
def element_filter(e: Element):
175+
return not self._element_type or e.type == self._element_type
176+
177+
prompt = self.get_element_prompt()
178+
llm_map = LLMMapElements(
179+
child,
180+
prompt=prompt,
181+
output_field=self._output_property,
182+
llm=self._llm,
183+
filter=element_filter,
184+
number_of_elements=self._number_of_elements,
185+
**self._llm_kwargs,
186+
)
187+
else:
188+
prompt = self.get_document_prompt()
189+
llm_map = LLMMap(
190+
child, prompt=prompt, output_field=self._output_property, llm=self._llm, **self._llm_kwargs
191+
)
192+
return llm_map
193+
128194

129195
class LLMQuery(NonCPUUser, NonGPUUser, Map):
130196
"""

0 commit comments

Comments
 (0)