diff --git a/__init__.py b/__init__.py index 9e629dd..a832479 100644 --- a/__init__.py +++ b/__init__.py @@ -634,19 +634,19 @@ async def chat_completions(request): async def generate(): try: - client=openai_client(key,api_url) - - response = client.chat.completions.create( - model=model_name, - messages=messages, - stream=True - ) - - for chunk in response: - if hasattr(chunk.choices[0].delta, 'content'): - content = chunk.choices[0].delta.content - if content is not None: - yield content.encode('utf-8') + b"\r\n" + headers = { + 'Authorization': f'Bearer {key}', + 'Content-Type': 'application/json' + } + payload = { + 'model': model_name, + 'messages': messages, + 'stream': True + } + async with aiohttp.ClientSession() as session: + async with session.post(f'{api_url}/chat/completions', json=payload, headers=headers) as resp: + async for line in resp.content: + yield line except Exception as e: yield f"Error: {str(e)}".encode('utf-8') + b"\r\n" diff --git a/web/javascript/chat.js b/web/javascript/chat.js index 516cefe..c425433 100644 --- a/web/javascript/chat.js +++ b/web/javascript/chat.js @@ -127,14 +127,14 @@ export async function* chatCompletion ( stream: true, key: apiKey, model_name: model_name, - api_url, + api_url } let response = await fetch(mixlabAPI, { method: 'POST', headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${apiKey}` + 'Content-Type': 'application/json' + // Authorization: `Bearer ${apiKey}` }, body: JSON.stringify(requestBody), mode: 'cors', // This is to ensure the request is made with CORS @@ -154,15 +154,12 @@ export async function* chatCompletion ( if (result.done) { break } - - // Add any leftover data to the current chunk of data const text = leftover + decoder.decode(result.value) - // Check if the last character is a line break - const endsWithLineBreak = text.endsWith('\r\n') + const endsWithLineBreak = text.endsWith('\n') // Split the text into lines - let lines = text.split('\r\n') + let lines = text.split('\n') // If the text doesn't end with a line break, then the last line is incomplete // Store it in leftover to be added to the next chunk of data @@ -172,13 +169,31 @@ export async function* chatCompletion ( leftover = '' // Reset leftover if we have a line break at the end } + // Parse all sse events and add them to result + const regex = /^(\S+):\s(.*)$/gm for (const line of lines) { - if (line) { - content += line - yield line // Yield the trimmed line - } else { - cont = false - break + const match = regex.exec(line) + if (match) { + result[match[1]] = match[2] + // since we know this is llama.cpp, let's just decode the json in data + if (result.data) { + result.data = JSON.parse(result.data) + + + content += result.data.choices[0].delta?.content || '' + // console.log('#result.content',content) + // yield + yield result + + // if we got a stop token from server, we will break here + if (result.data.choices[0].finish_reason == 'stop') { + if (result.data.generation_settings) { + // generation_settings = result.data.generation_settings; + } + cont = false + break + } + } } } } diff --git a/web/javascript/ui_mixlab.js b/web/javascript/ui_mixlab.js index c70b7d1..3bfe377 100644 --- a/web/javascript/ui_mixlab.js +++ b/web/javascript/ui_mixlab.js @@ -834,11 +834,14 @@ function createModelsModal (models, llmKey) { background-color: var(--comfy-input-bg);font-size: 16px;">MixLab App` const siliconflowHelp = document.createElement('a') - siliconflowHelp.textContent = showTextByLanguage('Use Siliconflow', { - "Use Siliconflow": '使用硅基流动' - }) +'\n' +showTextByLanguage('Or Local LLM', { - "Or Local LLM": '或者本地LLM' - }) + siliconflowHelp.textContent = + showTextByLanguage('Use Siliconflow', { + 'Use Siliconflow': '使用硅基流动' + }) + + '\n' + + showTextByLanguage('Or Local LLM', { + 'Or Local LLM': '或者本地LLM' + }) siliconflowHelp.style = `color: var(--input-text); background-color: var(--comfy-input-bg);margin-top:14px;font-size: 16px;` siliconflowHelp.href = 'https://cloud.siliconflow.cn/s/mixlabs' @@ -1356,7 +1359,7 @@ app.registerExtension({ Object.values(getLocalData('_mixlab_llm_api_key'))[0], getLocalData('_mixlab_llm_api_url')['-'] || Object.values(getLocalData('_mixlab_llm_api_url'))[0], - getLocalData('_mixlab_llm_model_name')['-'] || + getLocalData('_mixlab_llm_model_name')['-'] || Object.values(getLocalData('_mixlab_llm_model_name'))[0], [ { @@ -1367,9 +1370,11 @@ app.registerExtension({ ], controller, t => { - // console.log(t.endsWith('\r')) - widget.value += t - jsonStr += t + let content = t.data?.choices[0]?.delta?.content || '' + + console.log(content) + widget.value += content + // jsonStr += content } ) } catch (error) { @@ -1410,7 +1415,6 @@ app.registerExtension({ // content: localStorage.getItem('_mixlab_system_prompt') // }, // // { role: 'user', content: userInput } - // { // role: 'user', // content: [ @@ -1428,7 +1432,6 @@ app.registerExtension({ // t => { // // console.log(t) // widget.value += t - // NoteNode.size[1] = widget.element.scrollHeight + 20 // widget.computedHeight = NoteNode.size[1] // app.canvas.centerOnNode(NoteNode)