1
1
import * as YAML from "@std/yaml" ;
2
2
import * as z from "@/common/myZod.ts" ;
3
3
import { logger } from "@/common/logging.ts" ;
4
+ import { BaseSamplerRequest } from "@/common/sampling.ts" ;
4
5
5
6
export const SamplerOverride = z . object ( {
6
7
override : z . unknown ( ) . refine ( ( val ) => val !== undefined && val !== null , {
@@ -16,17 +17,28 @@ export type SamplerOverride = z.infer<typeof SamplerOverride>;
16
17
class SamplerOverridesContainer {
17
18
selectedPreset ?: string ;
18
19
overrides : Record < string , SamplerOverride > = { } ;
20
+ forcedOverrides : Record < string , SamplerOverride > = { } ;
19
21
}
20
22
21
- export const overridesContainer = new SamplerOverridesContainer ( ) ;
23
+ // No need to export this, the functions work properly
24
+ const overridesContainer = new SamplerOverridesContainer ( ) ;
22
25
23
26
export function overridesFromDict ( newOverrides : Record < string , unknown > ) {
24
27
const parsedOverrides : Record < string , SamplerOverride > = { } ;
25
28
29
+ // Forced also includes additive
30
+ const forcedOverrides : Record < string , SamplerOverride > = { } ;
31
+
26
32
// Validate each entry as a SamplerOverride type
27
33
for ( const [ key , value ] of Object . entries ( newOverrides ) ) {
28
34
try {
29
- parsedOverrides [ key ] = SamplerOverride . parse ( value ) ;
35
+ const parsedOverride = SamplerOverride . parse ( value ) ;
36
+ parsedOverrides [ key ] = parsedOverride ;
37
+
38
+ // Add to forced object for faster lookup
39
+ if ( parsedOverride . force || parsedOverride . additive ) {
40
+ forcedOverrides [ key ] = parsedOverride ;
41
+ }
30
42
} catch ( error ) {
31
43
if ( error instanceof Error ) {
32
44
logger . error ( error . stack ) ;
@@ -39,6 +51,7 @@ export function overridesFromDict(newOverrides: Record<string, unknown>) {
39
51
}
40
52
41
53
overridesContainer . overrides = parsedOverrides ;
54
+ overridesContainer . forcedOverrides = forcedOverrides ;
42
55
}
43
56
44
57
export async function overridesFromFile ( presetName : string ) {
@@ -62,6 +75,27 @@ export async function overridesFromFile(presetName: string) {
62
75
}
63
76
}
64
77
78
+ export function forcedSamplerOverrides ( params : BaseSamplerRequest ) {
79
+ const castParams = params as Record < string , unknown > ;
80
+
81
+ for (
82
+ const [ key , value ] of Object . entries ( overridesContainer . forcedOverrides )
83
+ ) {
84
+ if ( value . force ) {
85
+ castParams [ key ] = value . override ;
86
+ } else if (
87
+ value . additive && Array . isArray ( value . override ) &&
88
+ Array . isArray ( castParams [ key ] )
89
+ ) {
90
+ castParams [ key ] = Array . from (
91
+ new Set ( [ ...castParams [ key ] , ...value . override ] ) ,
92
+ ) ;
93
+ }
94
+ }
95
+
96
+ return params ;
97
+ }
98
+
65
99
export function getSamplerDefault < T > ( key : string , fallback ?: T ) : T {
66
100
const defaultValue = overridesContainer . overrides [ key ] ?. override ??
67
101
fallback ;
0 commit comments