Skip to content

Commit 62484d4

Browse files
committed
Implement basic fill-in-the-middle support
1 parent f2735d8 commit 62484d4

File tree

1 file changed

+147
-48
lines changed

1 file changed

+147
-48
lines changed

mikupad.html

+147-48
Original file line numberDiff line numberDiff line change
@@ -3046,8 +3046,7 @@
30463046

30473047
const promptText = useMemo(() => joinPrompt(promptChunks), [promptChunks]);
30483048

3049-
// compute separately as I imagine this can get expensive
3050-
const assembledWorldInfo = useMemo(() => {
3049+
const assembleWorldInfo = (promptText) => {
30513050
// assemble non-empty wi
30523051
const validWorldInfo = !Array.isArray(worldInfo.entries) ? [] : worldInfo.entries.filter(entry =>
30533052
entry.keys.length > 0 && !(entry.keys.length == 1 && entry.keys[0] == "") && entry.text !== "");
@@ -3082,16 +3081,19 @@
30823081
});
30833082
});
30843083

3085-
const assembledWorldInfo = activeWorldInfo.length > 0
3084+
return activeWorldInfo.length > 0
30863085
? activeWorldInfo.map(entry => entry.text).join("\n")
30873086
: "";
3087+
};
30883088

3089-
return assembledWorldInfo
3089+
// compute separately as I imagine this can get expensive
3090+
const assembledWorldInfo = useMemo(() => {
3091+
return assembleWorldInfo(promptText);
30903092
}, [worldInfo]);
30913093

3092-
const modifiedPrompt = useMemo(() => {
3093-
// add world info to memory for easier assembly
3094-
memoryTokens["worldInfo"] = assembledWorldInfo;
3094+
const assembleFinalPrompt = (assembledWorldInfo, promptText) => {
3095+
if ("worldInfo" in memoryTokens)
3096+
delete memoryTokens["worldInfo"];
30953097

30963098
const order = ["prefix","text","suffix"]
30973099
const assembledAuthorNote = authorNoteTokens.text && authorNoteTokens.text !== ""
@@ -3100,19 +3102,19 @@
31003102

31013103
// replacements for the contextOrder string
31023104
const replacements = {
3103-
"{wiPrefix}": memoryTokens.worldInfo && memoryTokens.worldInfo !== ""
3105+
"{wiPrefix}": assembledWorldInfo && assembledWorldInfo !== ""
31043106
? worldInfo.prefix
31053107
: "", // wi prefix and suffix will be added whenever wi isn't empty
3106-
"{wiText}": memoryTokens.worldInfo,
3107-
"{wiSuffix}": memoryTokens.worldInfo && memoryTokens.worldInfo !== ""
3108+
"{wiText}": assembledWorldInfo,
3109+
"{wiSuffix}": assembledWorldInfo && assembledWorldInfo !== ""
31083110
? worldInfo.suffix
31093111
: "",
31103112

3111-
"{memPrefix}": memoryTokens.text && memoryTokens.text !== "" || memoryTokens.worldInfo !== ""
3113+
"{memPrefix}": memoryTokens.text && memoryTokens.text !== "" || assembledWorldInfo !== ""
31123114
? memoryTokens.prefix
31133115
: "", // memory prefix and suffix will be added whenever memory or wi aren't empty
31143116
"{memText}": memoryTokens.text,
3115-
"{memSuffix}": memoryTokens.text && memoryTokens.text !== "" || memoryTokens.worldInfo !== ""
3117+
"{memSuffix}": memoryTokens.text && memoryTokens.text !== "" || assembledWorldInfo !== ""
31163118
? memoryTokens.suffix
31173119
: "",
31183120
}
@@ -3160,9 +3162,74 @@
31603162
}).join("\n").replace(/\\n/g, '\n');
31613163

31623164
return permContextPrompt;
3165+
};
3166+
3167+
const modifiedPrompt = useMemo(() => {
3168+
return assembleFinalPrompt(assembledWorldInfo, promptText);
31633169
}, [contextLength, promptText, memoryTokens, authorNoteTokens, authorNoteDepth, assembledWorldInfo, worldInfo.prefix, worldInfo.suffix]);
31643170

3165-
async function predict(prompt = modifiedPrompt, chunkCount = promptChunks.length) {
3171+
// predict all {fill} placeholders
3172+
async function fillsPredict() {
3173+
const fillPlaceholder = "{fill}";
3174+
3175+
let leftPromptChunks = [];
3176+
let rightPromptChunks = [];
3177+
let fillIdx = undefined;
3178+
3179+
for (let i = 0; i < promptChunks.length; i++) {
3180+
const chunk = promptChunks[i];
3181+
if (chunk.content.includes(fillPlaceholder)) {
3182+
// split the chunk in 2
3183+
const left = { content: chunk.content.substring(0, chunk.content.indexOf(fillPlaceholder)), type: "user" };
3184+
const right = { content: chunk.content.substring(chunk.content.indexOf(fillPlaceholder) + fillPlaceholder.length), type: "user" };
3185+
fillIdx = i + 1;
3186+
leftPromptChunks = [
3187+
...promptChunks.slice(0, Math.max(0, i - 1)),
3188+
...[left]
3189+
];
3190+
rightPromptChunks = [
3191+
...[right],
3192+
...promptChunks.slice(i + 1, promptChunks.length - 1),
3193+
];
3194+
break;
3195+
}
3196+
}
3197+
3198+
if (!fillIdx)
3199+
return;
3200+
3201+
let promptText = joinPrompt(leftPromptChunks);
3202+
let assembledWorldInfo = assembleWorldInfo(promptText);
3203+
let finalPrompt = assembleFinalPrompt(assembledWorldInfo, promptText);
3204+
3205+
predict(finalPrompt, leftPromptChunks.length, (chunk) => {
3206+
console.log(chunk);
3207+
if (rightPromptChunks[0]) {
3208+
if (chunk.content.trim().startsWith(rightPromptChunks[0].content[0])) {
3209+
if (chunk.content[0] == ' ' && rightPromptChunks[0].content[0] != ' ') {
3210+
rightPromptChunks[0].content = ' ' + rightPromptChunks[0].content;
3211+
setPromptChunks(p => [
3212+
...leftPromptChunks,
3213+
...rightPromptChunks
3214+
]);
3215+
}
3216+
return false;
3217+
}
3218+
}
3219+
leftPromptChunks = [
3220+
...leftPromptChunks,
3221+
chunk
3222+
];
3223+
setPromptChunks(p => [
3224+
...leftPromptChunks,
3225+
...rightPromptChunks
3226+
]);
3227+
setTokens(t => t + (chunk?.completion_probabilities?.length ?? 1));
3228+
return true;
3229+
});
3230+
}
3231+
3232+
async function predict(prompt = modifiedPrompt, chunkCount = promptChunks.length, callback = undefined) {
31663233
if (cancel) {
31673234
cancel?.();
31683235

@@ -3172,7 +3239,7 @@
31723239
setCancel(() => () => cancelled = true);
31733240
await new Promise(resolve => setTimeout(resolve, 500));
31743241
if (cancelled)
3175-
return;
3242+
return false;
31763243
}
31773244

31783245
const ac = new AbortController();
@@ -3192,21 +3259,23 @@
31923259
// so let's set the predictStartTokens beforehand.
31933260
setPredictStartTokens(tokens);
31943261

3195-
const tokenCount = await getTokenCount({
3196-
endpoint,
3197-
endpointAPI,
3198-
...(endpointAPI == 3 || endpointAPI == 0 ? { endpointAPIKey } : {}),
3199-
content: prompt,
3200-
signal: ac.signal,
3201-
...(isMikupadEndpoint ? { proxyEndpoint: sessionStorage.proxyEndpoint } : {})
3202-
});
3203-
setTokens(tokenCount);
3204-
setPredictStartTokens(tokenCount);
3262+
if (!callback) {
3263+
const tokenCount = await getTokenCount({
3264+
endpoint,
3265+
endpointAPI,
3266+
...(endpointAPI == 3 || endpointAPI == 0 ? { endpointAPIKey } : {}),
3267+
content: prompt,
3268+
signal: ac.signal,
3269+
...(isMikupadEndpoint ? { proxyEndpoint: sessionStorage.proxyEndpoint } : {})
3270+
});
3271+
setTokens(tokenCount);
3272+
setPredictStartTokens(tokenCount);
32053273

3206-
while (undoStack.current.at(-1) >= chunkCount)
3207-
undoStack.current.pop();
3208-
undoStack.current.push(chunkCount);
3209-
redoStack.current = [];
3274+
while (undoStack.current.at(-1) >= chunkCount)
3275+
undoStack.current.pop();
3276+
undoStack.current.push(chunkCount);
3277+
redoStack.current = [];
3278+
}
32103279
setUndoHovered(false);
32113280
setRejectedAPIKey(false);
32123281
promptArea.current.scrollTarget = undefined;
@@ -3259,8 +3328,13 @@
32593328
chunk.content = chunk.stopping_word;
32603329
if (!chunk.content)
32613330
continue;
3262-
setPromptChunks(p => [...p, chunk]);
3263-
setTokens(t => t + (chunk?.completion_probabilities?.length ?? 1));
3331+
if (callback) {
3332+
if (!callback(chunk))
3333+
break;
3334+
} else {
3335+
setPromptChunks(p => [...p, chunk]);
3336+
setTokens(t => t + (chunk?.completion_probabilities?.length ?? 1));
3337+
}
32643338
chunkCount += 1;
32653339
}
32663340
} catch (e) {
@@ -3279,9 +3353,12 @@
32793353
return false;
32803354
} finally {
32813355
setCancel(c => c === cancelThis ? null : c);
3282-
if (undoStack.current.at(-1) === chunkCount)
3283-
undoStack.current.pop();
3356+
if (!callback) {
3357+
if (undoStack.current.at(-1) === chunkCount)
3358+
undoStack.current.pop();
3359+
}
32843360
}
3361+
return true;
32853362
}
32863363

32873364
function undo() {
@@ -3507,7 +3584,7 @@
35073584
switch (`${altKey}:${ctrlKey}:${shiftKey}:${key}`) {
35083585
case 'false:false:true:Enter':
35093586
case 'false:true:false:Enter':
3510-
predict();
3587+
fillsPredict();//predict();
35113588
break;
35123589
case 'false:false:false:Escape':
35133590
cancel();
@@ -3654,28 +3731,50 @@
36543731
newValue = newValue.slice(0, -chunk.content.length);
36553732
}
36563733

3734+
// Merge chunks if they're from the user
3735+
let mergeUserChunks = (chunks, newContent) => {
3736+
let lastChunk = chunks[chunks.length - 1];
3737+
while (lastChunk && lastChunk.type === 'user') {
3738+
lastChunk.content += newContent;
3739+
if (chunks[chunks.length - 2] && chunks[chunks.length - 2].type === 'user') {
3740+
newContent = lastChunk.content;
3741+
lastChunk = chunks[chunks.length - 2];
3742+
chunks.splice(chunks.length - 1, 1);
3743+
} else {
3744+
return chunks;
3745+
}
3746+
}
3747+
return [...chunks, { type: 'user', content: newContent }];
3748+
};
3749+
3750+
let newPrompt = [...start];
3751+
if (newValue) {
3752+
newPrompt = mergeUserChunks(newPrompt, newValue);
3753+
}
3754+
if (end.length && end[0].type === 'user') {
3755+
newPrompt = mergeUserChunks(newPrompt, end.shift().content);
3756+
}
3757+
newPrompt.push(...end);
3758+
36573759
// Remove all undo positions within the modified range.
3658-
undoStack.current = undoStack.current.filter(pos => start.length < pos);
3760+
undoStack.current = undoStack.current.filter(pos => pos > start.length && pos < newPrompt.length);
36593761
if (!undoStack.current.length)
36603762
setUndoHovered(false);
36613763

3662-
// Update all undo positions.
3663-
if (start.length + end.length + (+!!newValue) !== oldPromptLength) {
3664-
// Reset redo stack if a new chunk is added/removed at the end.
3665-
if (!end.length)
3666-
redoStack.current = [];
3764+
// Adjust undo/redo stacks.
3765+
const chunkDifference = oldPromptLength - newPrompt.length;
3766+
undoStack.current = undoStack.current.map(pos => {
3767+
if (pos >= start.length) {
3768+
return pos - chunkDifference;
3769+
}
3770+
return pos;
3771+
});
36673772

3668-
if (!oldPrompt.length)
3669-
undoStack.current = undoStack.current.map(pos => pos + 1);
3670-
else
3671-
undoStack.current = undoStack.current.map(pos => pos - oldPrompt.length);
3773+
// Reset redo stack if a new chunk is added/removed at the end.
3774+
if (chunkDifference < 0 && !end.length) {
3775+
redoStack.current = [];
36723776
}
36733777

3674-
const newPrompt = [
3675-
...start,
3676-
...(newValue ? [{ type: 'user', content: newValue }] : []),
3677-
...end,
3678-
];
36793778
return newPrompt;
36803779
});
36813780
}

0 commit comments

Comments
 (0)