Skip to content

Commit 470709b

Browse files
liuoooYangZhiBoGreenHand
authored andcommitted
[Feat] support thread copy
1 parent b357938 commit 470709b

File tree

6 files changed

+67
-2
lines changed

6 files changed

+67
-2
lines changed

app/api/deps.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,17 @@ async def get_param_from_request(request: Request):
9090
return get_param_from_request
9191

9292

93-
def verify_token_relation(relation_type: RelationType, name: str):
93+
def verify_token_relation(relation_type: RelationType, name: str, ignore_none_relation_id: bool = False):
94+
"""
95+
param relation_type: relation type
96+
param name: param name
97+
param ignore_none_relation_id: if ignore_none_relation_id is set, return where relation_id is None, use for copy thread api
98+
"""
9499
async def verify_authorization(
95100
session=Depends(get_session), token_id=Depends(get_token_id), relation_id=Depends(get_param(name))
96101
):
102+
if token_id and ignore_none_relation_id:
103+
return
97104
if token_id and relation_id:
98105
verify = TokenRelationQuery(token_id=token_id, relation_type=relation_type, relation_id=relation_id)
99106
if TokenRelationService.verify_relation(session=session, verify=verify):

app/models/thread.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ class ThreadCreate(BaseModel):
1616
object: str = "thread"
1717
messages: Optional[list[MessageCreate]]
1818
# metadata: Optional[dict]
19+
thread_id: Optional[str]
20+
end_message_id: Optional[str]
1921

2022

2123
class ThreadUpdate(BaseModel):

app/providers/auth_provider.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ def enable(self):
6161
verify_thread_depends = Depends(verify_token_relation(relation_type=RelationType.Thread, name="thread_id"))
6262
for route in thread.router.routes:
6363
if route.name == thread.create_thread.__name__:
64-
route.dependencies.append(Depends(verfiy_token))
64+
route.dependencies.append(Depends(
65+
verify_token_relation(relation_type=RelationType.Thread, name="thread_id", ignore_none_relation_id=True)))
6566
else:
6667
route.dependencies.append(verify_thread_depends)
6768

app/services/message/message.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,21 @@ def get_message_file(*, session: Session, thread_id: str, message_id: str, file_
7474
if msg_file is None:
7575
raise ResourceNotFoundError(message="Message file not found")
7676
return msg_file
77+
78+
@staticmethod
79+
def copy_messages(*, session: Session, from_thread_id: str, to_thread_id: str, end_message_id: str):
80+
"""
81+
copy thread messages to another thread
82+
"""
83+
statement = select(Message).where(Message.thread_id == from_thread_id)
84+
if end_message_id:
85+
statement = statement.where(Message.id <= end_message_id)
86+
original_messages = session.exec(statement.order_by(Message.id))
87+
88+
for original_message in original_messages:
89+
new_message = Message(
90+
thread_id=to_thread_id,
91+
**original_message.model_dump(exclude={"id", "created_at", "updated_at", "thread_id"})
92+
)
93+
session.add(new_message)
94+
session.commit()

app/services/thread/thread.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@ def create_thread(*, session: Session, body: ThreadCreate, token_id=None) -> Thr
3030
thread_id=thread_id,
3131
body=MessageCreate.from_orm(message),
3232
)
33+
elif body.thread_id:
34+
# copy thread
35+
from app.services.message.message import MessageService
36+
37+
MessageService.copy_messages(session=session,
38+
from_thread_id=body.thread_id,
39+
to_thread_id=thread_id,
40+
end_message_id=body.end_message_id)
3341
session.refresh(db_thread)
3442
return db_thread
3543

tests/e2e/test_thread_api.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import pytest
2+
import openai
3+
4+
@pytest.fixture(name="client")
5+
def client_for_test():
6+
"""
7+
a openai client connected to local server for test
8+
"""
9+
return openai.OpenAI(base_url="http://localhost:8086/api/v1", api_key="ml-xxx")
10+
11+
12+
def test_thread_copy(client):
13+
"""
14+
test copy thread
15+
"""
16+
thread = client.beta.threads.create()
17+
contents = ["test1", "test2", "test3"]
18+
messages = [client.beta.threads.messages.create(thread_id=thread.id, role="user", content=content) for content in contents]
19+
for index, message in enumerate(messages):
20+
print(index)
21+
new_thread = client.beta.threads.create(extra_body={"thread_id": thread.id,
22+
"end_message_id": message.id})
23+
new_messages = client.beta.threads.messages.list(thread_id=new_thread.id).data
24+
assert len(new_messages) == index + 1
25+
26+
for i in range(index + 1):
27+
assert new_messages[i].content[0].text.value == contents[i]
28+
client.beta.threads.delete(thread_id=new_thread.id)
29+
client.beta.threads.delete(thread_id=thread.id)

0 commit comments

Comments
 (0)