-
Notifications
You must be signed in to change notification settings - Fork 272
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Make Magic Prompts aware of LoRA syntax too
Refs #707
- Loading branch information
Showing
5 changed files
with
91 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from itertools import zip_longest | ||
|
||
from dynamicprompts.generators.magicprompt import MagicPromptGenerator | ||
|
||
from sd_dynamic_prompts.special_syntax import ( | ||
append_chunks, | ||
remove_a1111_special_syntax_chunks, | ||
) | ||
|
||
|
||
class SpecialSyntaxAwareMagicPromptGenerator(MagicPromptGenerator):
    """
    Magic Prompt generator that is aware of A1111 special syntax (LoRA, hypernet, etc.).

    Special-syntax chunks (e.g. ``<lora:name:0.7>``) are stripped from each
    prompt before it is fed to the underlying Magic Prompt model, then
    re-appended to the generated prompt so they survive unchanged.
    """

    def _generate_magic_prompts(self, orig_prompts: list[str]) -> list[str]:
        """
        Strip special-syntax chunks, generate magic prompts, re-append chunks.

        :param orig_prompts: prompts that may contain A1111 special syntax.
        :return: magic prompts with the original special-syntax chunks appended.
        """
        if not orig_prompts:
            # Guard: zip(*...) over an empty iterable yields nothing, so the
            # two-name unpack below would raise ValueError on an empty list.
            return []
        stripped_prompts, chunks = zip(
            *(remove_a1111_special_syntax_chunks(p) for p in orig_prompts),
        )
        magic_prompts = super()._generate_magic_prompts(stripped_prompts)
        # in case we somehow get less magic prompts than we started with,
        # use zip_longest instead of zip (append_chunks treats None as "no chunks").
        return [
            append_chunks(prompt, chunk)
            for prompt, chunk in zip_longest(magic_prompts, chunks, fillvalue=None)
        ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import re | ||
|
||
# A1111 special syntax (LoRA, hypernet, etc.), including any whitespace
# immediately before the angle-bracketed chunk.
A1111_SPECIAL_SYNTAX_RE = re.compile(r"\s*<[^>]+>")


def remove_a1111_special_syntax_chunks(s: str) -> tuple[str, list[str]]:
    """
    Strip A1111 special syntax chunks from *s*.

    Returns the stripped string together with the removed chunks,
    in the order they appeared.
    """
    removed: list[str] = []

    def _collect(match) -> str:
        # Remember the matched chunk, replace it with nothing.
        removed.append(match.group(0))
        return ""

    stripped = A1111_SPECIAL_SYNTAX_RE.sub(_collect, s)
    return stripped, removed
|
||
|
||
def append_chunks(s: str, chunks: list[str]) -> str:
    """
    Append (A1111 special syntax) chunks to a string.

    A falsy ``chunks`` value (empty list or None) leaves *s* untouched.
    """
    return s + "".join(chunks) if chunks else s
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
def fake_generator(prompts, **_kwargs):
    """Fake HF pipeline: yield one result-dict list per prompt, prefixed with 'magical '."""
    for prompt in prompts:
        # Callers must have stripped all special syntax before reaching the model.
        assert "<" not in prompt
        result = {"generated_text": f"magical {prompt}"}
        yield [result]
|
||
|
||
def test_magic_prompts(monkeypatch):
    # Instrument the superclass so it doesn't try to load the model.
    # NOTE: patching must happen BEFORE importing sd_dynamic_prompts.magic_prompt
    # below, so the generator class binds to the patched machinery.
    import dynamicprompts.generators.magicprompt as mp

    # Some dynamicprompts versions expose a lazy transformers import hook;
    # stub it out when present so no real transformers install is required.
    if hasattr(mp, "_import_transformers"):
        monkeypatch.setattr(mp, "_import_transformers", lambda: None)
    # Replace the pipeline loader with our fake generator so no model
    # weights are ever downloaded or loaded.
    monkeypatch.setattr(
        mp.MagicPromptGenerator,
        "_load_pipeline",
        lambda self, model_name: fake_generator,
    )

    from sd_dynamic_prompts.magic_prompt import SpecialSyntaxAwareMagicPromptGenerator

    generator = SpecialSyntaxAwareMagicPromptGenerator()
    # Generate 5 prompts from a template containing LoRA/hypernet chunks;
    # fake_generator asserts that the chunks were stripped before generation.
    for prompt in generator.generate(
        "purple cat singing opera, artistic, painting "
        "<lora:loraname:0.7> <hypernet:v18000Steps:1>",
        5,
    ):
        # These must remain unchanged
        assert "<lora:loraname:0.7>" in prompt
        assert "<hypernet:v18000Steps:1>" in prompt
        # but we should expect to see some magic
        assert prompt.startswith("magical ")