Skip to content

Commit b357938

Browse files
feat: add extra body to assistant
1 parent fcc0746 commit b357938

File tree

11 files changed

+157
-12
lines changed

11 files changed

+157
-12
lines changed

app/api/routes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,3 @@ def router_init():
1313
api_router.include_router(files.router, prefix="/files", tags=["files"])
1414
api_router.include_router(token.router, prefix="/tokens", tags=["tokens"])
1515
api_router.include_router(action.router, prefix="/actions", tags=["actions"])
16-

app/core/runner/llm_backend.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,25 @@ class LLMBackend:
1313
def __init__(self, base_url: str, api_key) -> None:
1414
self.base_url = base_url + "/" if base_url else None
1515
self.api_key = api_key
16-
self.client = OpenAI(
17-
base_url=self.base_url,
18-
api_key=self.api_key
19-
)
16+
self.client = OpenAI(base_url=self.base_url, api_key=self.api_key)
2017

2118
def run(
22-
self, messages: List, model: str, tools: List = None, tool_choice="auto", stream=False
19+
self, messages: List, model: str, tools: List = None, tool_choice="auto", stream=False, extra_body=None
2320
) -> ChatCompletion | Stream[ChatCompletionChunk]:
2421
chat_params = {
2522
"messages": messages,
2623
"model": model,
2724
"stream": stream,
2825
}
26+
if extra_body:
27+
model_params = extra_body.get("model_params")
28+
if model_params:
29+
if "n" in model_params:
30+
raise ValueError("n is not allowed in model_params")
31+
chat_params.update(model_params)
2932
if tools:
30-
chat_params['tools'] = tools
31-
chat_params['tool_choice'] = tool_choice if tool_choice else "auto"
33+
chat_params["tools"] = tools
34+
chat_params["tool_choice"] = tool_choice if tool_choice else "auto"
3235
logging.info("chat_params: %s", chat_params)
3336
response = self.client.chat.completions.create(**chat_params)
3437
logging.info("chat_response: %s", response)

app/core/runner/thread_runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def __run_step(self, llm: LLMBackend, run: Run, run_steps: List[RunStep], instru
106106
tools=[tool.openai_function for tool in tools],
107107
tool_choice="auto" if len(run_steps) < self.max_step else "none",
108108
stream=True,
109+
extra_body=run.extra_body,
109110
)
110111

111112
# create message creation run step callback

app/models/action.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from sqlalchemy import Column, JSON
22
from typing import Optional
33

4-
from sqlmodel import Field, Boolean
4+
from sqlmodel import Field
55

66
from app.models.base_model import BaseModel, TimeStampMixin, PrimaryKeyMixin
77

@@ -27,5 +27,4 @@ class ActionBase(BaseModel):
2727

2828

2929
class Action(ActionBase, PrimaryKeyMixin, TimeStampMixin, table=True):
30-
3130
object: str = Field(nullable=False, default="action")

app/models/assistant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class AssistantBase(BaseModel):
1414
metadata_: Optional[dict] = Field(default=None, sa_column=Column("metadata", JSON))
1515
name: Optional[str] = Field(default=None)
1616
tools: Optional[list] = Field(default=None, sa_column=Column(JSON))
17+
extra_body: Optional[dict] = Field(default=None, sa_column=Column(JSON))
1718

1819

1920
class Assistant(AssistantBase, PrimaryKeyMixin, TimeStampMixin, table=True):

app/models/run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Any
1+
from typing import Optional, Any, Union
22

33
from sqlalchemy import Column, Enum
44
from sqlalchemy.sql.sqltypes import JSON, TEXT
@@ -61,7 +61,7 @@ class RunCreate(BaseModel):
6161
file_ids: Optional[list] = []
6262
metadata_: Optional[dict] = Field(default={}, alias="metadata")
6363
tools: Optional[list] = []
64-
extra_body: Optional[dict[str, dict[str, Authentication]]] = {}
64+
extra_body: Optional[dict[str, Union[dict[str, Union[Authentication, Any]], Any]]] = {}
6565

6666
@root_validator()
6767
def root_validator(cls, data: Any):

app/services/run/run.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def create_run(
3030
body.instructions = db_asst.instructions
3131
if not body.tools and db_asst.tools:
3232
body.tools = db_asst.tools
33+
if not body.extra_body and db_asst.extra_body:
34+
body.extra_body = db_asst.extra_body
3335
# create run
3436
db_run = Run.model_validate(body.model_dump(), update={"thread_id": thread_id, "file_ids": db_asst.file_ids})
3537
session.add(db_run)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""add extra body to assistants
2+
3+
Revision ID: 8dbb8f38ef77
4+
Revises: e7339aab6549
5+
Create Date: 2024-03-19 15:27:39.793603
6+
7+
"""
8+
from typing import Sequence, Union
9+
10+
from alembic import op
11+
import sqlalchemy as sa
12+
import sqlmodel
13+
14+
15+
# revision identifiers, used by Alembic.
16+
revision: str = '8dbb8f38ef77'
17+
down_revision: Union[str, None] = 'e7339aab6549'
18+
branch_labels: Union[str, Sequence[str], None] = None
19+
depends_on: Union[str, Sequence[str], None] = None
20+
21+
22+
def upgrade() -> None:
23+
# ### commands auto generated by Alembic - please adjust! ###
24+
op.add_column('assistant', sa.Column('extra_body', sa.JSON(), nullable=True))
25+
# ### end Alembic commands ###
26+
27+
28+
def downgrade() -> None:
29+
# ### commands auto generated by Alembic - please adjust! ###
30+
op.drop_column('assistant', 'extra_body')
31+
# ### end Alembic commands ###

tests/e2e/assistant_test.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import openai
2+
3+
from app.api.deps import get_session
4+
from app.services.assistant.assistant import AssistantService
5+
6+
7+
# 测试创建动作
8+
def test_create_assistant():
9+
client = openai.OpenAI(base_url="http://localhost:8086/api/v1", api_key="xxx")
10+
assistant = client.beta.assistants.create(
11+
name="Assistant Demo",
12+
instructions="你是一个有用的助手",
13+
extra_body={
14+
"extra_body": {
15+
"frequency_penalty": 0,
16+
"logit_bias": None,
17+
"logprobs": True,
18+
"top_logprobs": 0,
19+
"max_tokens": 1024,
20+
"presence_penalty": 0.6,
21+
"temperature": 1,
22+
"n": 1,
23+
"presence_penalty": 0,
24+
"top_p": 1,
25+
}
26+
},
27+
# https://platform.openai.com/docs/api-reference/chat/create 具体参数看这里
28+
model="gpt-3.5-turbo-1106",
29+
)
30+
session = next(get_session())
31+
assistant = AssistantService.get_assistant(session=session, assistant_id=assistant.id)
32+
assert assistant.name == "Assistant Demo"
33+
assert assistant.instructions == "你是一个有用的助手"
34+
assert assistant.model == "gpt-3.5-turbo-1106"
35+
assert assistant.extra_body == {
36+
"frequency_penalty": 0,
37+
"logit_bias": None,
38+
"logprobs": True,
39+
"top_logprobs": 0,
40+
"max_tokens": 1024,
41+
"presence_penalty": 0.6,
42+
"temperature": 1,
43+
"n": 1,
44+
"presence_penalty": 0,
45+
"top_p": 1,
46+
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import time
2+
3+
import openai
4+
5+
from app.api.deps import get_session
6+
from app.schemas.tool.action import ActionBulkCreateRequest
7+
from app.schemas.tool.authentication import Authentication, AuthenticationType
8+
from app.services.tool.action import ActionService
9+
10+
11+
def test_run_with_assistant_extra_body():
12+
client = openai.OpenAI(base_url="http://localhost:8086/api/v1", api_key="xxx")
13+
# 创建带有 action 的 assistant
14+
assistant = client.beta.assistants.create(
15+
name="Assistant Demo",
16+
instructions="你是一个有用的助手",
17+
model="gpt-3.5-turbo-1106",
18+
extra_body={
19+
"extra_body": {
20+
"model_params": {
21+
"frequency_penalty": 0,
22+
"logit_bias": None,
23+
"max_tokens": 1024,
24+
"presence_penalty": 0.6,
25+
"temperature": 1,
26+
"presence_penalty": 0,
27+
"top_p": 1,
28+
}
29+
}
30+
},
31+
)
32+
print(assistant, end="\n\n")
33+
34+
thread = client.beta.threads.create()
35+
print(thread, end="\n\n")
36+
37+
message = client.beta.threads.messages.create(
38+
thread_id=thread.id,
39+
role="user",
40+
content="你好,介绍一下你自己",
41+
)
42+
print(message, end="\n\n")
43+
44+
run = client.beta.threads.runs.create(thread_id=thread.id, assistant_id=assistant.id, instructions="")
45+
print(run, end="\n\n")
46+
47+
while True:
48+
# run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
49+
run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
50+
if run.status == "completed":
51+
print("done!", end="\n\n")
52+
messages = client.beta.threads.messages.list(thread_id=thread.id)
53+
54+
print("messages: ")
55+
for message in messages:
56+
assert message.content[0].type == "text"
57+
print(messages)
58+
print({"role": message.role, "message": message.content[0].text.value})
59+
60+
break
61+
else:
62+
print("\nin progress...")
63+
time.sleep(1)

0 commit comments

Comments
 (0)