Skip to content

Commit 775902c

Browse files
committed
Sampler priority: better logging, always save to presets
1 parent acfbe6b commit 775902c

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

modules/presets.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,12 @@ def generate_preset_yaml(state):
120120
defaults = default_preset()
121121
data = {k: state[k] for k in presets_params()}
122122

123-
# Remove entries that are identical to the defaults
123+
# Remove entries that are identical to the defaults.
124+
# sampler_priority is always saved because it is experimental
125+
# and the default order may change.
126+
124127
for k in list(data.keys()):
125-
if data[k] == defaults[k]:
128+
if data[k] == defaults[k] and k != 'sampler_priority':
126129
del data[k]
127130

128131
return yaml.dump(data, sort_keys=False)

modules/sampler_hijack.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -428,16 +428,15 @@ def custom_sort_key(obj):
428428

429429
# Sort the list using the custom key function
430430
warpers = sorted(warpers, key=custom_sort_key)
431+
if shared.args.verbose:
432+
logger.info("WARPERS=")
433+
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint([x.__class__.__name__ for x in warpers])
431434

432435
if normalize is not None:
433436
warpers.append(normalize)
434437

435438
warpers.append(SpyLogitsWarper())
436439
warpers = LogitsProcessorList(warpers)
437-
if shared.args.verbose:
438-
logger.info("WARPERS=")
439-
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint([x.__class__.__name__ for x in warpers])
440-
441440
return warpers
442441

443442

0 commit comments

Comments
 (0)