diff --git a/.changeset/shaggy-pets-knock.md b/.changeset/shaggy-pets-knock.md new file mode 100644 index 0000000..7f2e031 --- /dev/null +++ b/.changeset/shaggy-pets-knock.md @@ -0,0 +1,9 @@ +--- +'@modelscope-studio/antdx': patch +'@modelscope-studio/antd': patch +'@modelscope-studio/pro': patch +'@modelscope-studio/frontend': patch +'modelscope_studio': patch +--- + +fix: the logic of uploading files diff --git a/backend/modelscope_studio/components/pro/multimodal_input/__init__.py b/backend/modelscope_studio/components/pro/multimodal_input/__init__.py index 2b46b68..5ae14a6 100644 --- a/backend/modelscope_studio/components/pro/multimodal_input/__init__.py +++ b/backend/modelscope_studio/components/pro/multimodal_input/__init__.py @@ -18,9 +18,11 @@ class MultimodalInputUploadConfig(GradioModel): fullscreen_drop: Optional[bool] = False allow_paste_file: Optional[bool] = True + allow_speech: Optional[bool] = False show_count: Optional[bool] = True button_tooltip: Optional[str] = None accept: Optional[str] = None + max_count: Optional[int] = None directory: Optional[bool] = False multiple: Optional[bool] = False disabled: Optional[bool] = False diff --git a/docs/app.py b/docs/app.py index 66b5a6a..26f99ef 100644 --- a/docs/app.py +++ b/docs/app.py @@ -529,4 +529,5 @@ def more_components(): demo = site.render() if __name__ == "__main__": - demo.queue().launch(ssr_mode=False) + demo.queue(default_concurrency_limit=100, + max_size=100).launch(ssr_mode=False, max_threads=100) diff --git a/docs/layout_templates/chatbot/demos/app.py b/docs/layout_templates/chatbot/demos/app.py index 19525b7..c00900a 100644 --- a/docs/layout_templates/chatbot/demos/app.py +++ b/docs/layout_templates/chatbot/demos/app.py @@ -203,9 +203,12 @@ def preprocess_submit_handler(state_value): conversation["meta"]["disabled"] = True return { **({ - sender: gr.update(value=None, loading=True), - attachments: gr.update(value=[]), - attachments_badge: gr.update(dot=False), + sender: + gr.update(value=None, loading=True) if clear_input else gr.update(loading=True), + attachments: + gr.update(value=[]), + attachments_badge: + gr.update(dot=False), } if clear_input else {}), conversations: gr.update(active_key=state_value["conversation_id"], @@ -1020,13 +1023,14 @@ def logo(): sender, conversation_delete_menu_item, clear_btn, conversations, add_conversation_btn, chatbot, state ]) - sender.cancel(fn=None, cancels=[submit_event, regenerating_event]) sender.cancel(fn=Gradio_Events.cancel, inputs=[state], outputs=[ sender, conversation_delete_menu_item, clear_btn, conversations, add_conversation_btn, chatbot, state - ]) + ], + cancels=[submit_event, regenerating_event], + queue=False) if __name__ == "__main__": demo.queue().launch(ssr_mode=False) diff --git a/frontend/antd/upload/dragger/upload.dragger.tsx b/frontend/antd/upload/dragger/upload.dragger.tsx index f140df1..2f6cb7a 100644 --- a/frontend/antd/upload/dragger/upload.dragger.tsx +++ b/frontend/antd/upload/dragger/upload.dragger.tsx @@ -5,6 +5,7 @@ import type { FileData } from '@gradio/client'; import { useFunction } from '@utils/hooks/useFunction'; import { renderParamsSlot } from '@utils/renderParamsSlot'; import { type GetProps, Upload as AUpload, type UploadFile } from 'antd'; +import type { RcFile } from 'antd/es/upload'; import { noop } from 'lodash-es'; const isUploadFile = (file: FileData | UploadFile): file is UploadFile => { @@ -21,7 +22,7 @@ export const UploadDragger = sveltify< Omit, 'fileList' | 'onChange'> & { onValueChange?: (value: FileData[]) => void; onChange?: (value: string[]) => void; - upload: (files: File[]) => Promise<(FileData | null)[]>; + upload: (files: RcFile[]) => Promise<(FileData | null)[]>; fileList: FileData[]; setSlotParams: SetSlotParams; }, @@ -91,13 +92,19 @@ export const UploadDragger = sveltify< setFileList(fileListProp); }, [fileListProp]); const validFileList = useMemo(() => { + const visited: Record = {}; return ( - fileList?.map((file) => { + fileList.map((file) => { if (!isUploadFile(file)) { + const uid = file.url || file.path; + if (!visited[uid]) { + visited[uid] = 0; + } + visited[uid]++; return { ...file, name: file.orig_name || file.path, - uid: file.uid || file.url || file.path, + uid: file.uid || uid + '-' + visited[uid], status: 'done' as const, }; } @@ -105,6 +112,7 @@ export const UploadDragger = sveltify< }) || [] ); }, [fileList]); + return ( { - if (uploadingRef.current) { - return; - } - onRemove?.(file); + // onRemove={(file) => { + // if (uploadingRef.current) { + // return; + // } + // onRemove?.(file); + // const index = validFileList.findIndex((v) => v.uid === file.uid); + // const newFileList = fileList.slice() as FileData[]; + // newFileList.splice(index, 1); + // onValueChange?.(newFileList); + // onChange?.(newFileList.map((v) => v.path)); + // }} + maxCount={maxCount} + onChange={async (info) => { + const file = info.file; + const files = info.fileList; + + // remove const index = validFileList.findIndex((v) => v.uid === file.uid); - const newFileList = fileList.slice() as FileData[]; - newFileList.splice(index, 1); - onValueChange?.(newFileList); - onChange?.(newFileList.map((v) => v.path)); - }} - beforeUpload={async (file, files) => { - if (beforeUploadFunction) { - if (!(await beforeUploadFunction(file, files))) { - return false; + + if (index !== -1) { + if (uploadingRef.current) { + return; } - } + onRemove?.(file); + const newFileList = fileList.slice() as FileData[]; + newFileList.splice(index, 1); + onValueChange?.(newFileList); + onChange?.(newFileList.map((v) => v.path)); + } else { + // add + if (beforeUploadFunction) { + if (!(await beforeUploadFunction(file, files))) { + return; + } + } + if (uploadingRef.current) { + return; + } + uploadingRef.current = true; + let validFiles = files.filter((v) => v.status !== 'done'); - if (uploadingRef.current) { - return false; - } - uploadingRef.current = true; - let validFiles = files; - if (typeof maxCount === 'number') { - const max = maxCount - fileList.length; - validFiles = files.slice(0, max < 0 ? 0 : max); - } else if (maxCount === 1) { - validFiles = files.slice(0, 1); - } else if (validFiles.length === 0) { - uploadingRef.current = false; - return false; - } + if (maxCount === 1) { + validFiles = validFiles.slice(0, 1); + } else if (validFiles.length === 0) { + uploadingRef.current = false; + return; + } else if (typeof maxCount === 'number') { + const max = maxCount - fileList.length; + validFiles = validFiles.slice(0, max < 0 ? 0 : max); + } + + const lastFileList = fileList; - const lastFileList = fileList; - setFileList((prev) => [ - ...(maxCount === 1 ? [] : prev), - ...validFiles.map((v) => { - return { - ...v, - size: v.size, - uid: v.uid, - name: v.name, - status: 'uploading' as const, - }; - }), - ]); - const fileDataList = (await upload(validFiles)).filter( - (v) => v - ) as (FileData & { uid: string })[]; + setFileList((prev) => [ + ...(maxCount === 1 ? [] : prev), + ...validFiles.map((v) => { + return { + ...v, + size: v.size, + uid: v.uid, + name: v.name, + status: 'uploading' as const, + }; + }), + ]); - const mergedFileList = - maxCount === 1 - ? fileDataList - : ([...lastFileList, ...fileDataList] as FileData[]); - uploadingRef.current = false; - onValueChange?.(mergedFileList); - onChange?.(mergedFileList.map((v) => v.path)); - return false; + const fileDataList = ( + await upload(validFiles.map((f) => f.originFileObj as RcFile)) + ).filter(Boolean) as FileData[]; + const mergedFileList = + maxCount === 1 + ? fileDataList + : ([...lastFileList, ...fileDataList] as FileData[]); + + uploadingRef.current = false; + + setFileList(mergedFileList); + onValueChange?.(mergedFileList); + onChange?.(mergedFileList.map((v) => v.path)); + } }} - maxCount={1} customRequest={customRequestFunction || noop} progress={ progress diff --git a/frontend/antd/upload/upload.tsx b/frontend/antd/upload/upload.tsx index 416dd55..712c7a5 100644 --- a/frontend/antd/upload/upload.tsx +++ b/frontend/antd/upload/upload.tsx @@ -93,13 +93,19 @@ export const Upload = sveltify< setFileList(fileListProp); }, [fileListProp]); const validFileList = useMemo(() => { + const visited: Record = {}; return ( - fileList?.map((file) => { + fileList.map((file) => { if (!isUploadFile(file)) { + const uid = file.url || file.path; + if (!visited[uid]) { + visited[uid] = 0; + } + visited[uid]++; return { ...file, name: file.orig_name || file.path, - uid: file.uid || file.url || file.path, + uid: file.uid || uid + '-' + visited[uid], status: 'done' as const, }; } @@ -114,7 +120,7 @@ export const Upload = sveltify< data={dataFunction || data} previewFile={previewFileFunction} isImageUrl={isImageUrlFunction} - maxCount={1} + maxCount={maxCount} itemRender={ slots.itemRender ? renderParamsSlot({ slots, setSlotParams, key: 'itemRender' }) @@ -125,63 +131,85 @@ export const Upload = sveltify< ? renderParamsSlot({ slots, setSlotParams, key: 'iconRender' }) : iconRenderFunction } - onRemove={(file) => { - if (uploadingRef.current) { - return; - } - onRemove?.(file); - const index = validFileList.findIndex((v) => v.uid === file.uid); - const newFileList = fileList.slice() as FileData[]; - newFileList.splice(index, 1); - onValueChange?.(newFileList); - onChange?.(newFileList.map((v) => v.path)); - }} + // onRemove={(file) => { + // if (uploadingRef.current) { + // return; + // } + // onRemove?.(file); + // const index = validFileList.findIndex((v) => v.uid === file.uid); + // const newFileList = fileList.slice() as FileData[]; + // newFileList.splice(index, 1); + // onValueChange?.(newFileList); + // onChange?.(newFileList.map((v) => v.path)); + // }} customRequest={customRequestFunction || noop} - beforeUpload={async (file, files) => { - if (beforeUploadFunction) { - if (!(await beforeUploadFunction(file, files))) { - return false; + onChange={async (info) => { + const file = info.file; + const files = info.fileList; + // remove + const index = validFileList.findIndex((v) => v.uid === file.uid); + + if (index !== -1) { + if (uploadingRef.current) { + return; } - } + onRemove?.(file); + const newFileList = fileList.slice() as FileData[]; + newFileList.splice(index, 1); + onValueChange?.(newFileList); + onChange?.(newFileList.map((v) => v.path)); + } else { + // add + if (beforeUploadFunction) { + if (!(await beforeUploadFunction(file, files))) { + return; + } + } + if (uploadingRef.current) { + return; + } + uploadingRef.current = true; + let validFiles = files.filter((v) => v.status !== 'done'); + + if (maxCount === 1) { + validFiles = validFiles.slice(0, 1); + } else if (validFiles.length === 0) { + uploadingRef.current = false; + return; + } else if (typeof maxCount === 'number') { + const max = maxCount - fileList.length; + validFiles = validFiles.slice(0, max < 0 ? 0 : max); + } + + const lastFileList = fileList; + + setFileList((prev) => [ + ...(maxCount === 1 ? [] : prev), + ...validFiles.map((v) => { + return { + ...v, + size: v.size, + uid: v.uid, + name: v.name, + status: 'uploading' as const, + }; + }), + ]); + + const fileDataList = ( + await upload(validFiles.map((f) => f.originFileObj as RcFile)) + ).filter(Boolean) as FileData[]; + const mergedFileList = + maxCount === 1 + ? fileDataList + : ([...lastFileList, ...fileDataList] as FileData[]); - if (uploadingRef.current) { - return false; - } - uploadingRef.current = true; - let validFiles = files; - if (typeof maxCount === 'number') { - const max = maxCount - fileList.length; - validFiles = files.slice(0, max < 0 ? 0 : max); - } else if (maxCount === 1) { - validFiles = files.slice(0, 1); - } else if (validFiles.length === 0) { uploadingRef.current = false; - return false; + + setFileList(mergedFileList); + onValueChange?.(mergedFileList); + onChange?.(mergedFileList.map((v) => v.path)); } - const lastFileList = fileList; - setFileList((prev) => [ - ...(maxCount === 1 ? [] : prev), - ...validFiles.map((v) => { - return { - ...v, - size: v.size, - uid: v.uid, - name: v.name, - status: 'uploading' as const, - }; - }), - ]); - const fileDataList = (await upload(validFiles)).filter( - (v) => v - ) as (FileData & { uid: string })[]; - const mergedFileList = - maxCount === 1 - ? fileDataList - : ([...lastFileList, ...fileDataList] as FileData[]); - uploadingRef.current = false; - onValueChange?.(mergedFileList); - onChange?.(mergedFileList.map((v) => v.path)); - return false; }} progress={ progress diff --git a/frontend/antdx/attachments/attachments.tsx b/frontend/antdx/attachments/attachments.tsx index 20af24f..6e9a6b7 100644 --- a/frontend/antdx/attachments/attachments.tsx +++ b/frontend/antdx/attachments/attachments.tsx @@ -65,6 +65,7 @@ export const Attachments = sveltify< placeholder, getDropContainer, children, + maxCount, ...props }) => { const supportShowUploadListConfig = @@ -113,13 +114,19 @@ export const Attachments = sveltify< setFileList(items); }, [items]); const validFileList = useMemo(() => { + const visited: Record = {}; return ( - fileList?.map((file) => { + fileList.map((file) => { if (!isUploadFile(file)) { + const uid = file.url || file.path; + if (!visited[uid]) { + visited[uid] = 0; + } + visited[uid]++; return { ...file, name: file.orig_name || file.path, - uid: file.uid || file.url || file.path, + uid: file.uid || uid + '-' + visited[uid], status: 'done' as const, }; } @@ -183,27 +190,29 @@ export const Attachments = sveltify< ? renderParamsSlot({ slots, setSlotParams, key: 'iconRender' }) : iconRenderFunction } - onRemove={(file) => { - if (uploadingRef.current) { - return; - } - onRemove?.(file); - const index = validFileList.findIndex((v) => v.uid === file.uid); - const newFileList = fileList.slice() as FileData[]; - newFileList.splice(index, 1); - onValueChange?.(newFileList); - onChange?.(newFileList.map((v) => v.path)); - }} + maxCount={maxCount} + // onRemove={(file) => { + // if (uploadingRef.current) { + // return; + // } + // onRemove?.(file); + // const index = validFileList.findIndex((v) => v.uid === file.uid); + // const newFileList = fileList.slice() as FileData[]; + // newFileList.splice(index, 1); + // onValueChange?.(newFileList); + // onChange?.(newFileList.map((v) => v.path)); + // }} onChange={async (info) => { const file = info.file; const files = info.fileList; // remove - if (validFileList.find((v) => v.uid === file.uid)) { + const index = validFileList.findIndex((v) => v.uid === file.uid); + + if (index !== -1) { if (uploadingRef.current) { return; } onRemove?.(file); - const index = validFileList.findIndex((v) => v.uid === file.uid); const newFileList = fileList.slice() as FileData[]; newFileList.splice(index, 1); onValueChange?.(newFileList); @@ -212,17 +221,29 @@ export const Attachments = sveltify< // add if (beforeUploadFunction) { if (!(await beforeUploadFunction(file, files))) { - return false; + return; } } if (uploadingRef.current) { - return false; + return; } uploadingRef.current = true; - const validFiles = files.filter((v) => v.status !== 'done'); + let validFiles = files.filter((v) => v.status !== 'done'); + + if (maxCount === 1) { + validFiles = validFiles.slice(0, 1); + } else if (validFiles.length === 0) { + uploadingRef.current = false; + return; + } else if (typeof maxCount === 'number') { + const max = maxCount - fileList.length; + validFiles = validFiles.slice(0, max < 0 ? 0 : max); + } + const lastFileList = fileList; + setFileList((prev) => [ - ...prev, + ...(maxCount === 1 ? [] : prev), ...validFiles.map((v) => { return { ...v, @@ -233,14 +254,18 @@ export const Attachments = sveltify< }; }), ]); + const fileDataList = ( await upload(validFiles.map((f) => f.originFileObj as RcFile)) - ).filter((v) => v) as (FileData & { uid: string })[]; - const mergedFileList = [ - ...lastFileList, - ...fileDataList, - ] as FileData[]; + ).filter(Boolean) as FileData[]; + const mergedFileList = + maxCount === 1 + ? fileDataList + : ([...lastFileList, ...fileDataList] as FileData[]); + uploadingRef.current = false; + + setFileList(mergedFileList); onValueChange?.(mergedFileList); onChange?.(mergedFileList.map((v) => v.path)); } diff --git a/frontend/antdx/conversations/conversations.tsx b/frontend/antdx/conversations/conversations.tsx index 1ea9d3e..40c6817 100644 --- a/frontend/antdx/conversations/conversations.tsx +++ b/frontend/antdx/conversations/conversations.tsx @@ -34,9 +34,16 @@ function patchMenuEvents(menuProps: MenuProps, conversation: Conversation) { return Object.keys(menuProps).reduce((acc, key) => { if (key.startsWith('on') && isFunction(menuProps[key])) { const originalEvent = menuProps[key]; - acc[key] = (...args: any[]) => { - originalEvent?.(conversation, ...args); - }; + if (key === 'onClick') { + acc[key] = (menuInfo, ...args) => { + menuInfo.domEvent.stopPropagation(); + originalEvent?.(conversation, menuInfo, ...args); + }; + } else { + acc[key] = (...args: any[]) => { + originalEvent?.(conversation, ...args); + }; + } } else { acc[key] = menuProps[key]; } diff --git a/frontend/pro/multimodal-input/multimodal-input.tsx b/frontend/pro/multimodal-input/multimodal-input.tsx index f1b77c0..beb3f02 100644 --- a/frontend/pro/multimodal-input/multimodal-input.tsx +++ b/frontend/pro/multimodal-input/multimodal-input.tsx @@ -11,12 +11,16 @@ import { } from '@ant-design/x'; import { type FileData } from '@gradio/client'; import { convertObjectKeyToCamelCase } from '@utils/convertToCamelCase'; +import { useMemoizedFn } from '@utils/hooks/useMemoizedFn'; import { useValueChange } from '@utils/hooks/useValueChange'; import { omitUndefinedProps } from '@utils/omitUndefinedProps'; import { Badge, Button, Tooltip, type UploadFile } from 'antd'; import type { RcFile } from 'antd/es/upload'; import { noop, omit } from 'lodash-es'; +import { useRecorder } from './recorder'; +import { processAudio } from './utils'; + const isUploadFile = (file: FileData | UploadFile): file is UploadFile => { return !!(file as UploadFile).name; }; @@ -33,6 +37,7 @@ export interface MultimodalInputChangedValue { export interface UploadConfig extends Omit { fullscreenDrop?: boolean; + allowSpeech?: boolean; allowPasteFile?: boolean; showCount?: boolean; buttonTooltip?: string; @@ -66,7 +71,7 @@ export const MultimodalInput = sveltify< > & { children?: React.ReactNode; value?: MultimodalInputValue; - upload: (files: RcFile[]) => Promise; + upload: (files: File[]) => Promise; onPasteFile?: (value: string[]) => void; onValueChange: (value: MultimodalInputValue) => void; onChange?: (value: MultimodalInputChangedValue) => void; @@ -99,8 +104,44 @@ export const MultimodalInput = sveltify< }) => { const [open, setOpen] = useState(false); const suggestionOpen = useSuggestionOpenContext(); + const recorderContainerRef = useRef(null); + + const uploadFile = useMemoizedFn(async (file: File | File[]) => { + if (!(uploadConfig?.allowPasteFile ?? true)) { + return; + } + const maxCount = uploadConfig?.maxCount; + if ( + typeof maxCount === 'number' && + maxCount > 0 && + fileList.length >= maxCount + ) { + return; + } + const filesData = await upload(Array.isArray(file) ? file : [file]); - // const [recording, setRecording] = useState(false); + const newValue: MultimodalInputValue = { + ...value, + files: [...(fileList as FileData[]), ...filesData], + }; + onChange?.(formatChangedValue(newValue)); + setValue(newValue); + return filesData; + }); + + const { start, stop, recording } = useRecorder({ + container: recorderContainerRef.current, + async onStop(blob) { + const audioFile = new File( + [await processAudio(blob)], + `${Date.now()}_recording_result.wav`, + { + type: 'audio/wav', + } + ); + uploadFile(audioFile); + }, + }); const [value, setValue] = useValueChange({ onValueChange, value: valueProp, @@ -124,13 +165,19 @@ export const MultimodalInput = sveltify< }, [value?.files]); const validFileList = useMemo(() => { + const visited: Record = {}; return ( fileList.map((file) => { if (!isUploadFile(file)) { + const uid = file.url || file.path; + if (!visited[uid]) { + visited[uid] = 0; + } + visited[uid]++; return { ...file, name: file.orig_name || file.path, - uid: file.uid || file.url || file.path, + uid: file.uid || uid + '-' + visited[uid], status: 'done' as const, }; } @@ -138,8 +185,10 @@ export const MultimodalInput = sveltify< }) || [] ); }, [fileList]); + return ( <> +
{children}
{ @@ -172,28 +229,10 @@ export const MultimodalInput = sveltify< setValue(newValue); }} onPasteFile={async (file) => { - if (!(uploadConfig?.allowPasteFile ?? true)) { - return; - } - const maxCount = uploadConfig?.maxCount; - if ( - typeof maxCount === 'number' && - maxCount > 0 && - fileList.length >= maxCount - ) { - return; + const filesData = await uploadFile(file); + if (filesData) { + onPasteFile?.(filesData.map((url) => url.path)); } - const filesData = await upload( - (Array.isArray(file) ? file : [file]) as RcFile[] - ); - - onPasteFile?.(filesData.map((url) => url.path)); - const newValue: MultimodalInputValue = { - ...value, - files: [...(fileList as FileData[]), ...filesData], - }; - onChange?.(formatChangedValue(newValue)); - setValue(newValue); }} prefix={ <> @@ -265,35 +304,19 @@ export const MultimodalInput = sveltify< onDownload={onDownload} onPreview={onPreview} onDrop={onDrop} - onRemove={(file) => { - if (uploadingRef.current) { - return; - } - onRemove?.(file); - const index = validFileList.findIndex( - (v) => v.uid === file.uid - ); - const newFileList = fileList.slice() as FileData[]; - newFileList.splice(index, 1); - const newValue: MultimodalInputValue = { - ...value, - files: newFileList, - }; - setValue(newValue); - onChange?.(formatChangedValue(newValue)); - }} onChange={async (info) => { const file = info.file; const files = info.fileList; // remove - if (validFileList.find((v) => v.uid === file.uid)) { + const index = validFileList.findIndex( + (v) => v.uid === file.uid + ); + + if (index !== -1) { if (uploadingRef.current) { return; } onRemove?.(file); - const index = validFileList.findIndex( - (v) => v.uid === file.uid - ); const newFileList = fileList.slice() as FileData[]; newFileList.splice(index, 1); const newValue: MultimodalInputValue = { @@ -305,13 +328,26 @@ export const MultimodalInput = sveltify< } else { // add if (uploadingRef.current) { - return false; + return; } uploadingRef.current = true; - const validFiles = files.filter((v) => v.status !== 'done'); + let validFiles = files.filter((v) => v.status !== 'done'); + + const maxCount = uploadConfig?.maxCount; + if (maxCount === 1) { + validFiles = validFiles.slice(0, 1); + } else if (validFiles.length === 0) { + uploadingRef.current = false; + return; + } else if (typeof maxCount === 'number') { + const max = maxCount - fileList.length; + validFiles = validFiles.slice(0, max < 0 ? 0 : max); + } + const lastFileList = fileList; + setFileList((prev) => [ - ...prev, + ...(maxCount === 1 ? [] : prev), ...validFiles.map((v) => { return { ...v, @@ -327,17 +363,19 @@ export const MultimodalInput = sveltify< await upload( validFiles.map((f) => f.originFileObj as RcFile) ) - ).filter((v) => v) as (FileData & { uid: string })[]; - const mergedFileList = [ - ...lastFileList, - ...fileDataList, - ] as FileData[]; + ).filter(Boolean) as FileData[]; + const mergedFileList = + maxCount === 1 + ? fileDataList + : ([...lastFileList, ...fileDataList] as FileData[]); uploadingRef.current = false; const newValue: MultimodalInputValue = { ...value, files: mergedFileList, }; + + setFileList(mergedFileList); onValueChange?.(newValue); onChange?.(formatChangedValue(newValue)); } diff --git a/frontend/pro/multimodal-input/recorder.ts b/frontend/pro/multimodal-input/recorder.ts new file mode 100644 index 0000000..a693f11 --- /dev/null +++ b/frontend/pro/multimodal-input/recorder.ts @@ -0,0 +1,47 @@ +import { useEffect, useRef, useState } from 'react'; +import { useMemoizedFn } from '@utils/hooks/useMemoizedFn'; +import RecordPlugin from 'wavesurfer.js/dist/plugins/record.js'; +import WaveSurfer from 'wavesurfer.js/dist/wavesurfer'; + +export interface UseRecorderOptions { + container: HTMLElement | null; + onStop: (blob: Blob) => void; +} + +export function useRecorder({ container, onStop }: UseRecorderOptions) { + const recorderRef = useRef(null); + const [recording, setRecording] = useState(false); + + const start = useMemoizedFn(() => { + recorderRef.current?.startRecording(); + }); + + const stop = useMemoizedFn(() => { + recorderRef.current?.stopRecording(); + }); + const onStopMemoized = useMemoizedFn(onStop); + + useEffect(() => { + if (container) { + const micWaveform = WaveSurfer.create({ + normalize: false, + container, + }); + const recorder = micWaveform.registerPlugin(RecordPlugin.create()); + recorderRef.current = recorder; + + recorder.on('record-start', () => { + setRecording(true); + }); + recorder.on('record-end', (blob) => { + onStopMemoized(blob); + setRecording(false); + }); + } + }, [container, onStopMemoized]); + return { + recording, + start, + stop, + }; +} diff --git a/frontend/pro/multimodal-input/utils.ts b/frontend/pro/multimodal-input/utils.ts new file mode 100644 index 0000000..23a82c2 --- /dev/null +++ b/frontend/pro/multimodal-input/utils.ts @@ -0,0 +1,124 @@ +export function audioBufferToWav(audioBuffer: AudioBuffer): Uint8Array { + // Write WAV header + const writeString = function ( + view: DataView, + offset: number, + string: string + ): void { + for (let i = 0; i < string.length; i++) { + view.setUint8(offset + i, string.charCodeAt(i)); + } + }; + + const numOfChan = audioBuffer.numberOfChannels; + const length = audioBuffer.length * numOfChan * 2 + 44; + const buffer = new ArrayBuffer(length); + const view = new DataView(buffer); + let offset = 0; + + writeString(view, offset, 'RIFF'); + offset += 4; + view.setUint32(offset, length - 8, true); + offset += 4; + writeString(view, offset, 'WAVE'); + offset += 4; + writeString(view, offset, 'fmt '); + offset += 4; + view.setUint32(offset, 16, true); + offset += 4; // Sub-chunk size, 16 for PCM + view.setUint16(offset, 1, true); + offset += 2; // PCM format + view.setUint16(offset, numOfChan, true); + offset += 2; + view.setUint32(offset, audioBuffer.sampleRate, true); + offset += 4; + view.setUint32(offset, audioBuffer.sampleRate * 2 * numOfChan, true); + offset += 4; + view.setUint16(offset, numOfChan * 2, true); + offset += 2; + view.setUint16(offset, 16, true); + offset += 2; + writeString(view, offset, 'data'); + offset += 4; + view.setUint32(offset, audioBuffer.length * numOfChan * 2, true); + offset += 4; + + // Write PCM audio data + for (let i = 0; i < audioBuffer.numberOfChannels; i++) { + const channel = audioBuffer.getChannelData(i); + for (let j = 0; j < channel.length; j++) { + view.setInt16(offset, channel[j] * 0xffff, true); + offset += 2; + } + } + + return new Uint8Array(buffer); +} + +export const process_audio = ( + audioBuffer: AudioBuffer, + start?: number, + end?: number +): Promise => { + const audioContext = new AudioContext(); + const numberOfChannels = audioBuffer.numberOfChannels; + const sampleRate = audioBuffer.sampleRate; + + let trimmedLength = audioBuffer.length; + let startOffset = 0; + + if (start && end) { + startOffset = Math.round(start * sampleRate); + const endOffset = Math.round(end * sampleRate); + trimmedLength = endOffset - startOffset; + } + + const trimmedAudioBuffer = audioContext.createBuffer( + numberOfChannels, + trimmedLength, + sampleRate + ); + + for (let channel = 0; channel < numberOfChannels; channel++) { + const channelData = audioBuffer.getChannelData(channel); + const trimmedData = trimmedAudioBuffer.getChannelData(channel); + for (let i = 0; i < trimmedLength; i++) { + trimmedData[i] = channelData[startOffset + i]; + } + } + + return Promise.resolve(audioBufferToWav(trimmedAudioBuffer)); +}; +export async function processAudio(blob: Blob, start?: number, end?: number) { + const arrayBuffer = await blob.arrayBuffer(); + const context = new AudioContext(); + const audioBuffer = await context.decodeAudioData(arrayBuffer); + const audioContext = new AudioContext(); + const numberOfChannels = audioBuffer.numberOfChannels; + const sampleRate = audioBuffer.sampleRate; + + let trimmedLength = audioBuffer.length; + let startOffset = 0; + + if (start && end) { + startOffset = Math.round(start * sampleRate); + const endOffset = Math.round(end * sampleRate); + trimmedLength = endOffset - startOffset; + } + + const trimmedAudioBuffer = audioContext.createBuffer( + numberOfChannels, + trimmedLength, + sampleRate + ); + + for (let channel = 0; channel < numberOfChannels; channel++) { + const channelData = audioBuffer.getChannelData(channel); + const trimmedData = trimmedAudioBuffer.getChannelData(channel); + for (let i = 0; i < trimmedLength; i++) { + trimmedData[i] = channelData[startOffset + i]; + } + } + + return Promise.resolve(audioBufferToWav(trimmedAudioBuffer)); +}