Skip to content

Commit 4690e54

Browse files
committed
Implement basic fill-in-the-middle support
1 parent 15acb02 commit 4690e54

File tree

1 file changed

+165
-59
lines changed

1 file changed

+165
-59
lines changed

mikupad.html

+165-59
Original file line numberDiff line numberDiff line change
@@ -3634,8 +3634,7 @@
36343634

36353635
const promptText = useMemo(() => joinPrompt(promptChunks), [promptChunks]);
36363636

3637-
// compute separately as I imagine this can get expensive
3638-
const assembledWorldInfo = useMemo(() => {
3637+
const assembleWorldInfo = (promptText) => {
36393638
// assemble non-empty wi
36403639
const validWorldInfo = !Array.isArray(worldInfo.entries) ? [] : worldInfo.entries.filter(entry =>
36413640
entry.keys.length > 0 && !(entry.keys.length == 1 && entry.keys[0] == "") && entry.text !== "");
@@ -3670,16 +3669,19 @@
36703669
});
36713670
});
36723671

3673-
const assembledWorldInfo = activeWorldInfo.length > 0
3672+
return activeWorldInfo.length > 0
36743673
? activeWorldInfo.map(entry => entry.text).join("\n")
36753674
: "";
3675+
};
36763676

3677-
return assembledWorldInfo
3677+
// compute separately as I imagine this can get expensive
3678+
const assembledWorldInfo = useMemo(() => {
3679+
return assembleWorldInfo(promptText);
36783680
}, [worldInfo]);
36793681

3680-
const additionalContextPrompt = useMemo(() => {
3681-
// add world info to memory for easier assembly
3682-
memoryTokens["worldInfo"] = assembledWorldInfo;
3682+
const assembleAdditionalContext = (assembledWorldInfo, promptText) => {
3683+
if ("worldInfo" in memoryTokens)
3684+
delete memoryTokens["worldInfo"];
36833685

36843686
const order = ["prefix","text","suffix"]
36853687
const assembledAuthorNote = authorNoteTokens.text && authorNoteTokens.text !== ""
@@ -3688,19 +3690,19 @@
36883690

36893691
// replacements for the contextOrder string
36903692
const contextReplacements = {
3691-
"{wiPrefix}": memoryTokens.worldInfo && memoryTokens.worldInfo !== ""
3693+
"{wiPrefix}": assembledWorldInfo && assembledWorldInfo !== ""
36923694
? worldInfo.prefix
36933695
: "", // wi prefix and suffix will be added whenever wi isn't empty
3694-
"{wiText}": memoryTokens.worldInfo,
3695-
"{wiSuffix}": memoryTokens.worldInfo && memoryTokens.worldInfo !== ""
3696+
"{wiText}": assembledWorldInfo,
3697+
"{wiSuffix}": assembledWorldInfo && assembledWorldInfo !== ""
36963698
? worldInfo.suffix
36973699
: "",
36983700

3699-
"{memPrefix}": memoryTokens.text && memoryTokens.text !== "" || memoryTokens.worldInfo !== ""
3701+
"{memPrefix}": memoryTokens.text && memoryTokens.text !== "" || assembledWorldInfo !== ""
37003702
? memoryTokens.prefix
37013703
: "", // memory prefix and suffix will be added whenever memory or wi aren't empty
37023704
"{memText}": memoryTokens.text,
3703-
"{memSuffix}": memoryTokens.text && memoryTokens.text !== "" || memoryTokens.worldInfo !== ""
3705+
"{memSuffix}": memoryTokens.text && memoryTokens.text !== "" || assembledWorldInfo !== ""
37043706
? memoryTokens.suffix
37053707
: "",
37063708
}
@@ -3749,9 +3751,13 @@
37493751
}).join("\n").replace(/\\n/g, '\n');
37503752

37513753
return permContextPrompt;
3754+
};
3755+
3756+
const additionalContextPrompt = useMemo(() => {
3757+
return assembleAdditionalContext(assembledWorldInfo, promptText);
37523758
}, [contextLength, promptText, memoryTokens, authorNoteTokens, authorNoteDepth, assembledWorldInfo, worldInfo.prefix, worldInfo.suffix]);
37533759

3754-
const modifiedPrompt = useMemo(() => {
3760+
const assembleFinalPrompt = (additionalContextPrompt) => {
37553761
const templateReplacements = {
37563762
"{inst}": templates[selectedTemplate]?.instPre && templates[selectedTemplate]?.instPre !== ""
37573763
? templates[selectedTemplate]?.instPre
@@ -3774,9 +3780,75 @@
37743780
}).replace(/\\n/g, '\n');
37753781

37763782
return finalPrompt;
3783+
}
3784+
3785+
const modifiedPrompt = useMemo(() => {
3786+
return assembleFinalPrompt(additionalContextPrompt);
37773787
}, [additionalContextPrompt, templates, selectedTemplate]);
37783788

3779-
async function predict(prompt = modifiedPrompt, chunkCount = promptChunks.length) {
3789+
// predict all {fill} placeholders
3790+
async function fillsPredict() {
3791+
const fillPlaceholder = "{fill}";
3792+
3793+
let leftPromptChunks = [];
3794+
let rightPromptChunks = [];
3795+
let fillIdx = undefined;
3796+
3797+
for (let i = 0; i < promptChunks.length; i++) {
3798+
const chunk = promptChunks[i];
3799+
if (chunk.content.includes(fillPlaceholder)) {
3800+
// split the chunk in 2
3801+
const left = { content: chunk.content.substring(0, chunk.content.indexOf(fillPlaceholder)), type: "user" };
3802+
const right = { content: chunk.content.substring(chunk.content.indexOf(fillPlaceholder) + fillPlaceholder.length), type: "user" };
3803+
fillIdx = i + 1;
3804+
leftPromptChunks = [
3805+
...promptChunks.slice(0, Math.max(0, i - 1)),
3806+
...[left]
3807+
];
3808+
rightPromptChunks = [
3809+
...[right],
3810+
...promptChunks.slice(i + 1, promptChunks.length - 1),
3811+
];
3812+
break;
3813+
}
3814+
}
3815+
3816+
if (!fillIdx)
3817+
return;
3818+
3819+
const promptText = joinPrompt(leftPromptChunks);
3820+
const assembledWorldInfo = assembleWorldInfo(promptText);
3821+
const additionalContextPrompt = assembleAdditionalContext(assembledWorldInfo, promptText);
3822+
const finalPrompt = assembleFinalPrompt(additionalContextPrompt);
3823+
3824+
predict(finalPrompt, leftPromptChunks.length, (chunk) => {
3825+
console.log(chunk);
3826+
if (rightPromptChunks[0]) {
3827+
if (chunk.content.trim().startsWith(rightPromptChunks[0].content[0])) {
3828+
if (chunk.content[0] == ' ' && rightPromptChunks[0].content[0] != ' ') {
3829+
rightPromptChunks[0].content = ' ' + rightPromptChunks[0].content;
3830+
setPromptChunks(p => [
3831+
...leftPromptChunks,
3832+
...rightPromptChunks
3833+
]);
3834+
}
3835+
return false;
3836+
}
3837+
}
3838+
leftPromptChunks = [
3839+
...leftPromptChunks,
3840+
chunk
3841+
];
3842+
setPromptChunks(p => [
3843+
...leftPromptChunks,
3844+
...rightPromptChunks
3845+
]);
3846+
setTokens(t => t + (chunk?.completion_probabilities?.length ?? 1));
3847+
return true;
3848+
});
3849+
}
3850+
3851+
async function predict(prompt = modifiedPrompt, chunkCount = promptChunks.length, callback = undefined) {
37803852
if (cancel) {
37813853
cancel?.();
37823854

@@ -3786,7 +3858,7 @@
37863858
setCancel(() => () => cancelled = true);
37873859
await new Promise(resolve => setTimeout(resolve, 500));
37883860
if (cancelled)
3789-
return;
3861+
return false;
37903862
}
37913863

37923864
const ac = new AbortController();
@@ -3806,30 +3878,32 @@
38063878
// so let's set the predictStartTokens beforehand.
38073879
setPredictStartTokens(tokens);
38083880

3809-
const tokenCount = await getTokenCount({
3810-
endpoint,
3811-
endpointAPI,
3812-
...(endpointAPI == 3 || endpointAPI == 0 ? { endpointAPIKey } : {}),
3813-
content: prompt,
3814-
signal: ac.signal,
3815-
...(isMikupadEndpoint ? { proxyEndpoint: sessionStorage.proxyEndpoint } : {})
3816-
});
3817-
setTokens(tokenCount);
3818-
setPredictStartTokens(tokenCount);
3819-
3820-
// Chat Mode
3821-
if (chatMode && !restartedPredict) {
3822-
// add user EOT template (instruct suffix) if not switch completion
3823-
const eotUser = templates[selectedTemplate]?.instSuf.replace(/\\n/g, '\n')
3824-
setPromptChunks(p => [...p, { type: 'user', content: eotUser }])
3825-
prompt += `${eotUser}`
3826-
}
3827-
setRestartedPredict(false)
3881+
if (!callback) {
3882+
const tokenCount = await getTokenCount({
3883+
endpoint,
3884+
endpointAPI,
3885+
...(endpointAPI == 3 || endpointAPI == 0 ? { endpointAPIKey } : {}),
3886+
content: prompt,
3887+
signal: ac.signal,
3888+
...(isMikupadEndpoint ? { proxyEndpoint: sessionStorage.proxyEndpoint } : {})
3889+
});
3890+
setTokens(tokenCount);
3891+
setPredictStartTokens(tokenCount);
3892+
3893+
// Chat Mode
3894+
if (chatMode && !restartedPredict) {
3895+
// add user EOT template (instruct suffix) if not switch completion
3896+
const eotUser = templates[selectedTemplate]?.instSuf.replace(/\\n/g, '\n')
3897+
setPromptChunks(p => [...p, { type: 'user', content: eotUser }])
3898+
prompt += `${eotUser}`
3899+
}
3900+
setRestartedPredict(false)
38283901

3829-
while (undoStack.current.at(-1) >= chunkCount)
3830-
undoStack.current.pop();
3831-
undoStack.current.push(chunkCount);
3832-
redoStack.current = [];
3902+
while (undoStack.current.at(-1) >= chunkCount)
3903+
undoStack.current.pop();
3904+
undoStack.current.push(chunkCount);
3905+
redoStack.current = [];
3906+
}
38333907
setUndoHovered(false);
38343908
setRejectedAPIKey(false);
38353909
promptArea.current.scrollTarget = undefined;
@@ -3882,8 +3956,13 @@
38823956
chunk.content = chunk.stopping_word;
38833957
if (!chunk.content)
38843958
continue;
3885-
setPromptChunks(p => [...p, chunk]);
3886-
setTokens(t => t + (chunk?.completion_probabilities?.length ?? 1));
3959+
if (callback) {
3960+
if (!callback(chunk))
3961+
break;
3962+
} else {
3963+
setPromptChunks(p => [...p, chunk]);
3964+
setTokens(t => t + (chunk?.completion_probabilities?.length ?? 1));
3965+
}
38873966
chunkCount += 1;
38883967
}
38893968
} catch (e) {
@@ -3902,16 +3981,21 @@
39023981
return false;
39033982
} finally {
39043983
setCancel(c => c === cancelThis ? null : c);
3905-
if (undoStack.current.at(-1) === chunkCount)
3906-
undoStack.current.pop();
3984+
if (!callback) {
3985+
if (undoStack.current.at(-1) === chunkCount)
3986+
undoStack.current.pop();
3987+
}
39073988
}
3989+
39083990
// Chat Mode
3909-
if (chatMode) {
3991+
if (!callback && chatMode) {
39103992
// add bot EOT template (instruct prefix)
39113993
const eotBot = templates[selectedTemplate]?.instPre.replace(/\\n/g, '\n')
39123994
setPromptChunks(p => [...p, { type: 'user', content: eotBot }])
39133995
prompt += `${eotBot}`
39143996
}
3997+
3998+
return true;
39153999
}
39164000

39174001
function undo() {
@@ -4139,7 +4223,7 @@
41394223
switch (`${altKey}:${ctrlKey}:${shiftKey}:${key}`) {
41404224
case 'false:false:true:Enter':
41414225
case 'false:true:false:Enter':
4142-
predict();
4226+
fillsPredict();//predict();
41434227
break;
41444228
case 'false:false:false:Escape':
41454229
cancel();
@@ -4286,28 +4370,50 @@
42864370
newValue = newValue.slice(0, -chunk.content.length);
42874371
}
42884372

4373+
// Merge chunks if they're from the user
4374+
let mergeUserChunks = (chunks, newContent) => {
4375+
let lastChunk = chunks[chunks.length - 1];
4376+
while (lastChunk && lastChunk.type === 'user') {
4377+
lastChunk.content += newContent;
4378+
if (chunks[chunks.length - 2] && chunks[chunks.length - 2].type === 'user') {
4379+
newContent = lastChunk.content;
4380+
lastChunk = chunks[chunks.length - 2];
4381+
chunks.splice(chunks.length - 1, 1);
4382+
} else {
4383+
return chunks;
4384+
}
4385+
}
4386+
return [...chunks, { type: 'user', content: newContent }];
4387+
};
4388+
4389+
let newPrompt = [...start];
4390+
if (newValue) {
4391+
newPrompt = mergeUserChunks(newPrompt, newValue);
4392+
}
4393+
if (end.length && end[0].type === 'user') {
4394+
newPrompt = mergeUserChunks(newPrompt, end.shift().content);
4395+
}
4396+
newPrompt.push(...end);
4397+
42894398
// Remove all undo positions within the modified range.
4290-
undoStack.current = undoStack.current.filter(pos => start.length < pos);
4399+
undoStack.current = undoStack.current.filter(pos => pos > start.length && pos < newPrompt.length);
42914400
if (!undoStack.current.length)
42924401
setUndoHovered(false);
42934402

4294-
// Update all undo positions.
4295-
if (start.length + end.length + (+!!newValue) !== oldPromptLength) {
4296-
// Reset redo stack if a new chunk is added/removed at the end.
4297-
if (!end.length)
4298-
redoStack.current = [];
4403+
// Adjust undo/redo stacks.
4404+
const chunkDifference = oldPromptLength - newPrompt.length;
4405+
undoStack.current = undoStack.current.map(pos => {
4406+
if (pos >= start.length) {
4407+
return pos - chunkDifference;
4408+
}
4409+
return pos;
4410+
});
42994411

4300-
if (!oldPrompt.length)
4301-
undoStack.current = undoStack.current.map(pos => pos + 1);
4302-
else
4303-
undoStack.current = undoStack.current.map(pos => pos - oldPrompt.length);
4412+
// Reset redo stack if a new chunk is added/removed at the end.
4413+
if (chunkDifference < 0 && !end.length) {
4414+
redoStack.current = [];
43044415
}
43054416

4306-
const newPrompt = [
4307-
...start,
4308-
...(newValue ? [{ type: 'user', content: newValue }] : []),
4309-
...end,
4310-
];
43114417
return newPrompt;
43124418
});
43134419
}

0 commit comments

Comments
 (0)