Skip to content

Commit c60e3ab

Browse files
committed
Sampling: Add forced and additive support
Forced overrides overwrite what's passed in to that specific value. Additive adds the override to an existing array. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
1 parent 4b1e6d9 commit c60e3ab

File tree

2 files changed

+47
-5
lines changed

2 files changed

+47
-5
lines changed

common/samplerOverrides.ts

+36-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import * as YAML from "@std/yaml";
22
import * as z from "@/common/myZod.ts";
33
import { logger } from "@/common/logging.ts";
4+
import { BaseSamplerRequest } from "@/common/sampling.ts";
45

56
export const SamplerOverride = z.object({
67
override: z.unknown().refine((val) => val !== undefined && val !== null, {
@@ -16,17 +17,28 @@ export type SamplerOverride = z.infer<typeof SamplerOverride>;
1617
class SamplerOverridesContainer {
1718
selectedPreset?: string;
1819
overrides: Record<string, SamplerOverride> = {};
20+
forcedOverrides: Record<string, SamplerOverride> = {};
1921
}
2022

21-
export const overridesContainer = new SamplerOverridesContainer();
23+
// No need to export this, the functions work properly
24+
const overridesContainer = new SamplerOverridesContainer();
2225

2326
export function overridesFromDict(newOverrides: Record<string, unknown>) {
2427
const parsedOverrides: Record<string, SamplerOverride> = {};
2528

29+
// Forced also includes additive
30+
const forcedOverrides: Record<string, SamplerOverride> = {};
31+
2632
// Validate each entry as a SamplerOverride type
2733
for (const [key, value] of Object.entries(newOverrides)) {
2834
try {
29-
parsedOverrides[key] = SamplerOverride.parse(value);
35+
const parsedOverride = SamplerOverride.parse(value);
36+
parsedOverrides[key] = parsedOverride;
37+
38+
// Add to forced object for faster lookup
39+
if (parsedOverride.force || parsedOverride.additive) {
40+
forcedOverrides[key] = parsedOverride;
41+
}
3042
} catch (error) {
3143
if (error instanceof Error) {
3244
logger.error(error.stack);
@@ -39,6 +51,7 @@ export function overridesFromDict(newOverrides: Record<string, unknown>) {
3951
}
4052

4153
overridesContainer.overrides = parsedOverrides;
54+
overridesContainer.forcedOverrides = forcedOverrides;
4255
}
4356

4457
export async function overridesFromFile(presetName: string) {
@@ -62,6 +75,27 @@ export async function overridesFromFile(presetName: string) {
6275
}
6376
}
6477

78+
export function forcedSamplerOverrides(params: BaseSamplerRequest) {
79+
const castParams = params as Record<string, unknown>;
80+
81+
for (
82+
const [key, value] of Object.entries(overridesContainer.forcedOverrides)
83+
) {
84+
if (value.force) {
85+
castParams[key] = value.override;
86+
} else if (
87+
value.additive && Array.isArray(value.override) &&
88+
Array.isArray(castParams[key])
89+
) {
90+
castParams[key] = Array.from(
91+
new Set([...castParams[key], ...value.override]),
92+
);
93+
}
94+
}
95+
96+
return params;
97+
}
98+
6599
export function getSamplerDefault<T>(key: string, fallback?: T): T {
66100
const defaultValue = overridesContainer.overrides[key]?.override ??
67101
fallback;

common/sampling.ts

+11-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import * as z from "@/common/myZod.ts";
2+
import { forcedSamplerOverrides } from "@/common/samplerOverrides.ts";
23

34
// Sampling schemas
45
const GenerationOptionsSchema = z.aliasedObject(
@@ -217,8 +218,8 @@ const MirostatSchema = z.object({
217218
description: "Mirostat options",
218219
});
219220

220-
// Construct from aliased sampler requests
221-
export const BaseSamplerRequest = GenerationOptionsSchema
221+
// Define the schema
222+
const BaseSamplerRequestSchema = GenerationOptionsSchema
222223
.and(TemperatureSamplerSchema)
223224
.and(AlphabetSamplerSchema)
224225
.and(PenaltySamplerSchema)
@@ -227,4 +228,11 @@ export const BaseSamplerRequest = GenerationOptionsSchema
227228
.and(DynatempSchema)
228229
.and(MirostatSchema);
229230

230-
export type BaseSamplerRequest = z.infer<typeof BaseSamplerRequest>;
231+
// Define the type from the schema
232+
export type BaseSamplerRequest = z.infer<typeof BaseSamplerRequestSchema>;
233+
234+
// Apply transforms and expose the type
235+
export const BaseSamplerRequest = BaseSamplerRequestSchema
236+
.transform((obj) => {
237+
return forcedSamplerOverrides(obj);
238+
});

0 commit comments

Comments
 (0)