Skip to content

Commit ff9749a

Browse files
niklub, alex-medicratic, nik
authored
Parallel skill merge branch (#35)
Co-authored-by: alex-medicratic <137122049+alex-medicratic@users.noreply.github.com> Co-authored-by: nik <nik@heartex.net>
1 parent 8ac8cc9 commit ff9749a

File tree

2 files changed

+176
-2
lines changed

2 files changed

+176
-2
lines changed

adala/skills/skillset.py

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,95 @@ def __rich__(self):
244244
class ParallelSkillSet(SkillSet):
    """
    Represents a set of skills that are acquired simultaneously to reach a goal.

    In a ParallelSkillSet, each skill can be developed independently of the others. This is useful
    for agents that require multiple, diverse capabilities, or tasks where each skill contributes a piece of
    the overall solution.

    Examples:
        Create a ParallelSkillSet with a list of skills specified as BaseSkill instances

        >>> from adala.skills import ParallelSkillSet, TextClassificationSkill, TextGenerationSkill
        >>> skillset = ParallelSkillSet(skills=[
        ...     TextClassificationSkill(name='Classify sentiment', instructions='Classify the sentiment'),
        ...     TextGenerationSkill(name='Summarize text', instructions='Generate a summary'),
        ... ])

        Create a ParallelSkillSet with a dictionary of skill names to BaseSkill instances

        >>> from adala.skills import ParallelSkillSet, TextClassificationSkill, TextGenerationSkill
        >>> skillset = ParallelSkillSet(skills={
        ...     'sentiment_analysis': TextClassificationSkill(name='Classify sentiment', instructions='Classify the sentiment'),
        ...     'text_summary': TextGenerationSkill(name='Summarize text', instructions='Generate a summary'),
        ... })
    """

    @field_validator("skills", mode="before")
    @classmethod
    def skills_validator(
        cls, v: Union[List[BaseSkill], Dict[str, BaseSkill]]
    ) -> Dict[str, BaseSkill]:
        """
        Validates and converts the skills attribute to a dictionary of skill names to BaseSkill instances.

        Args:
            v (Union[List[BaseSkill], Dict[str, BaseSkill]]): The skills attribute to validate.

        Returns:
            Dict[str, BaseSkill]: Dictionary mapping skill names to their corresponding BaseSkill instances.
                A falsy input yields an empty ordered dictionary; a dictionary input is returned as-is.

        Raises:
            ValueError: If `v` is neither a dictionary nor a list made up entirely of BaseSkill instances.
        """
        if not v:
            return OrderedDict()

        if isinstance(v, dict):
            return v

        # Check every element, not only the first, so a mixed list is rejected
        # here with a clear message instead of failing on `.name` later.
        if isinstance(v, list) and all(isinstance(skill, BaseSkill) for skill in v):
            # Convert the list of skills to a dictionary keyed by skill name,
            # preserving list order. Skills sharing a name overwrite each other.
            return OrderedDict((skill.name, skill) for skill in v)

        # Raise ValueError (not ValidationError) from inside a validator, and
        # report the type of the offending input rather than of the local result.
        raise ValueError(
            f"skills must be a list of BaseSkill instances or a dictionary, not {type(v)}"
        )

    def apply(
        self,
        dataset: Union[Dataset, InternalDataFrame],
        runtime: Runtime,
        improved_skill: Optional[str] = None,
    ) -> InternalDataFrame:
        """
        Applies each skill on the dataset, enhancing the agent's experience.

        Skills are applied in the iteration order of `self.skills`. The first skill
        receives `dataset`; every following skill receives the output of the skill
        applied before it.

        Args:
            dataset (Union[Dataset, InternalDataFrame]): The dataset to apply the skills on.
            runtime (Runtime): The runtime environment in which to apply the skills.
            improved_skill (Optional[str], optional): Unused in ParallelSkillSet. Defaults to None.

        Returns:
            InternalDataFrame: Output of the last applied skill. None if the skill set is empty.
        """
        # NOTE(review): despite the "parallel" name, outputs are chained from one
        # skill to the next - confirm whether every skill should instead receive
        # the original `dataset`.
        predictions = None

        for i, (skill_name, skill) in enumerate(self.skills.items()):
            # Only the first skill sees the raw input dataset.
            input_dataset = dataset if i == 0 else predictions
            print_text(f"Applying skill: {skill_name}")
            predictions = skill.apply(input_dataset, runtime)

        return predictions

    def select_skill_to_improve(
        self, accuracy: Mapping, accuracy_threshold: Optional[float] = 0.9
    ) -> Optional[BaseSkill]:
        """
        Selects the skill with the lowest accuracy to improve.

        Args:
            accuracy (Mapping): Accuracy of each skill, keyed by skill name.
            accuracy_threshold (Optional[float], optional): Accuracy threshold. Defaults to 0.9.

        Returns:
            Optional[BaseSkill]: The skill with the lowest accuracy among those strictly
                below the threshold. None if no skill is below the threshold.
        """
        skills_below_threshold = [
            skill_name
            for skill_name in self.skills
            if accuracy[skill_name] < accuracy_threshold
        ]
        if not skills_below_threshold:
            return None

        weakest_skill_name = min(skills_below_threshold, key=accuracy.get)
        return self.skills[weakest_skill_name]

tests/test_llm_parallel_skillset.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import pandas as pd
2+
3+
from utils import patching, PatchedCalls
4+
5+
# Input sentences shared by the mocked LLM calls, the dataset and the expected frame.
_APPLE_TEXT = "Apple's latest product, the iPhone 15, was released in September 2023."
_OBAMA_TEXT = "Barack Obama was the 44th president of the United States."

# (input text, mocked completion) pairs in the order the guidance calls are
# mocked: two texts per skill, skills in the order person, product, date, location.
_GUIDANCE_REPLIES = [
    (_APPLE_TEXT, ""),  # No person mentioned
    (_OBAMA_TEXT, "Barack Obama"),
    (_APPLE_TEXT, "iPhone 15"),
    (_OBAMA_TEXT, ""),  # No product mentioned
    (_APPLE_TEXT, "September 2023"),
    (_OBAMA_TEXT, ""),  # No date mentioned
    (_APPLE_TEXT, ""),  # No location mentioned
    (_OBAMA_TEXT, "United States"),
]


@patching(
    target_function=PatchedCalls.OPENAI_MODEL_LIST.value,
    data=[{"input": {}, "output": {"data": [{"id": "gpt-3.5-turbo-instruct"}]}}],
)
@patching(
    target_function=PatchedCalls.GUIDANCE.value,
    data=[
        {"input": {"text_": text}, "output": {"predictions": reply}}
        for text, reply in _GUIDANCE_REPLIES
    ],
    strict=False,
)
def test_llm_parallel_skillset():
    """Apply four extraction skills through a ParallelSkillSet and check the combined output frame."""
    from adala.skills.skillset import ParallelSkillSet, LLMSkill
    from adala.datasets import DataFrameDataset, InternalDataFrame
    from adala.runtimes import OpenAIRuntime

    # One LLMSkill per entity type, all reading from the "text" column.
    skill_specs = [
        ("skill_person", "Extract person's name"),
        ("skill_product", "Extract product name"),
        ("skill_date", "Extract date"),
        ("skill_location", "Extract location"),
    ]
    skillset = ParallelSkillSet(
        skills=[
            LLMSkill(name=name, instructions=instructions, input_data_field="text")
            for name, instructions in skill_specs
        ]
    )

    input_frame = InternalDataFrame([_APPLE_TEXT, _OBAMA_TEXT], columns=["text"])
    dataset = DataFrameDataset(df=input_frame)

    predictions = skillset.apply(
        dataset=dataset,
        runtime=OpenAIRuntime(verbose=True),
    )

    # Expected result: the input text plus one column per skill, in skill order.
    expected = InternalDataFrame.from_records(
        [
            {
                "text": _APPLE_TEXT,
                "skill_person": "",  # No person mentioned
                "skill_product": "iPhone 15",
                "skill_date": "September 2023",
                "skill_location": "",  # No location mentioned
            },
            {
                "text": _OBAMA_TEXT,
                "skill_person": "Barack Obama",
                "skill_product": "",  # No product mentioned
                "skill_date": "",  # No date mentioned
                "skill_location": "United States",
            },
        ]
    )
    pd.testing.assert_frame_equal(expected, predictions)

0 commit comments

Comments
 (0)