Skip to content

Commit 66406f4

Browse files
niklubnik
and
nik
authored
Add guidance,langchain,openai runtimes; multi-I/O skills (#34)
Co-authored-by: nik <nik@heartex.net>
1 parent ff9749a commit 66406f4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+1681
-2081
lines changed

README.md

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -104,52 +104,45 @@ Click [here](./examples/quickstart.ipynb) to see an extended quickstart example.
104104
import pandas as pd
105105

106106
from adala.agents import Agent
107-
from adala.datasets import DataFrameDataset
108-
from adala.environments import BasicEnvironment
107+
from adala.environments import StaticEnvironment
109108
from adala.skills import ClassificationSkill
110-
from adala.runtimes import OpenAIRuntime
109+
from adala.runtimes import OpenAIChatRuntime
111110
from rich import print
112111

113112
# Train dataset
114-
ground_truth_df = pd.DataFrame([
113+
train_df = pd.DataFrame([
115114
["It was the negative first impressions, and then it started working.", "Positive"],
116115
["Not loud enough and doesn't turn on like it should.", "Negative"],
117116
["I don't know what to say.", "Neutral"],
118117
["Manager was rude, but the most important that mic shows very flat frequency response.", "Positive"],
119118
["The phone doesn't seem to accept anything except CBR mp3s.", "Negative"],
120119
["I tried it before, I bought this device for my son.", "Neutral"],
121-
], columns=["text", "ground_truth"])
120+
], columns=["text", "sentiment"])
122121

123122
# Test dataset
124-
predict_df = pd.DataFrame([
123+
test_df = pd.DataFrame([
125124
"All three broke within two months of use.",
126125
"The device worked for a long time, can't say anything bad.",
127126
"Just a random line of text."
128127
], columns=["text"])
129128

130-
ground_truth_dataset = DataFrameDataset(df=ground_truth_df)
131-
predict_dataset = DataFrameDataset(df=predict_df)
132-
133129
agent = Agent(
134130
# connect to a dataset
135-
environment=BasicEnvironment(
136-
ground_truth_dataset=ground_truth_dataset,
137-
ground_truth_columns={"sentiment_classification": "ground_truth"}
138-
),
131+
environment=StaticEnvironment(df=train_df),
139132

140133
# define a skill
141134
skills=ClassificationSkill(
142-
name='sentiment_classification',
135+
name='sentiment',
143136
instructions="Label text as positive, negative or neutral.",
144-
labels=["Positive", "Negative", "Neutral"],
145-
input_data_field='text'
137+
labels={'sentiment': ["Positive", "Negative", "Neutral"]},
138+
input_template="Text: {text}",
139+
output_template="Sentiment: {sentiment}"
146140
),
147141

148142
# define all the different runtimes your skills may use
149143
runtimes = {
150144
# You can specify your OPENAI API KEY here via `OpenAIRuntime(..., api_key='your-api-key')`
151-
'openai': OpenAIRuntime(model='gpt-3.5-turbo-instruct'),
152-
'openai-gpt3': OpenAIRuntime(model='gpt-3.5-turbo')
145+
'openai': OpenAIChatRuntime(model='gpt-3.5-turbo'),
153146
},
154147
default_runtime='openai',
155148

@@ -166,7 +159,7 @@ print(agent.skills)
166159
agent.learn(learning_iterations=3, accuracy_threshold=0.95)
167160

168161
print('\n=> Run tests ...')
169-
predictions = agent.run(predict_dataset)
162+
predictions = agent.run(test_df)
170163
print('\n => Test results:')
171164
print(predictions)
172165
```

adala/agents/base.py

Lines changed: 140 additions & 119 deletions
Large diffs are not rendered by default.

adala/datasets/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

adala/datasets/base.py

Lines changed: 0 additions & 82 deletions
This file was deleted.

adala/datasets/dataframe.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

adala/datasets/label_studio.py

Lines changed: 0 additions & 129 deletions
This file was deleted.

adala/environments/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .base import Environment, BasicEnvironment
1+
from .base import Environment, StaticEnvironment
22
from .console import ConsoleEnvironment
3-
from .web import WebEnvironment
3+
from .web import WebStaticEnvironment

0 commit comments

Comments
 (0)