Skip to content

Commit

Permalink
Make Magic Prompts aware of LoRA syntax too
Browse files Browse the repository at this point in the history
Refs #707
  • Loading branch information
akx committed Jan 16, 2024
1 parent b1edd80 commit c3f2f87
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 28 deletions.
30 changes: 4 additions & 26 deletions sd_dynamic_prompts/attention_generator.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,9 @@
import re

from dynamicprompts.generators.attentiongenerator import AttentionGenerator

# A1111 special syntax (LoRA, hypernet, etc.)
A1111_SPECIAL_SYNTAX_RE = re.compile(r"\s*<[^>]+>")


def remove_a1111_special_syntax_chunks(s: str) -> tuple[str, list[str]]:
"""
Remove A1111 special syntax chunks from a string and return the string and the chunks.
"""
chunks: list[str] = []

def put_chunk(m):
chunks.append(m.group(0))
return ""

return re.sub(A1111_SPECIAL_SYNTAX_RE, put_chunk, s), chunks


def append_chunks(s: str, chunks: list[str]) -> str:
"""
Append (A1111 special syntax) chunks to a string.
"""
if not chunks:
return s
return f"{s}{''.join(chunks)}"
from sd_dynamic_prompts.special_syntax import (
append_chunks,
remove_a1111_special_syntax_chunks,
)


class SpecialSyntaxAwareAttentionGenerator(AttentionGenerator):
Expand Down
6 changes: 4 additions & 2 deletions sd_dynamic_prompts/generator_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,11 @@ def create_generator(self):
generator = self.create_basic_generator()

if self._is_magic_prompt:
from dynamicprompts.generators.magicprompt import MagicPromptGenerator
from sd_dynamic_prompts.magic_prompt import (
SpecialSyntaxAwareMagicPromptGenerator,
)

generator = MagicPromptGenerator(
generator = SpecialSyntaxAwareMagicPromptGenerator(
generator,
model_name=self._magic_model,
device=self._device,
Expand Down
26 changes: 26 additions & 0 deletions sd_dynamic_prompts/magic_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from itertools import zip_longest

from dynamicprompts.generators.magicprompt import MagicPromptGenerator

from sd_dynamic_prompts.special_syntax import (
append_chunks,
remove_a1111_special_syntax_chunks,
)


class SpecialSyntaxAwareMagicPromptGenerator(MagicPromptGenerator):
"""
Magic Prompt generator that is aware of A1111 special syntax (LoRA, hypernet, etc.).
"""

def _generate_magic_prompts(self, orig_prompts: list[str]) -> list[str]:
orig_prompts, chunks = zip(
*(remove_a1111_special_syntax_chunks(p) for p in orig_prompts),
)
magic_prompts = super()._generate_magic_prompts(orig_prompts)
# in case we somehow get less magic prompts than we started with,
# use zip_longest instead of zip.
return [
append_chunks(prompt, chunk)
for prompt, chunk in zip_longest(magic_prompts, chunks, fillvalue=None)
]
26 changes: 26 additions & 0 deletions sd_dynamic_prompts/special_syntax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import re

# A1111 special syntax (LoRA, hypernet, etc.)
A1111_SPECIAL_SYNTAX_RE = re.compile(r"\s*<[^>]+>")


def remove_a1111_special_syntax_chunks(s: str) -> tuple[str, list[str]]:
"""
Remove A1111 special syntax chunks from a string and return the string and the chunks.
"""
chunks: list[str] = []

def put_chunk(m):
chunks.append(m.group(0))
return ""

return re.sub(A1111_SPECIAL_SYNTAX_RE, put_chunk, s), chunks


def append_chunks(s: str, chunks: list[str]) -> str:
"""
Append (A1111 special syntax) chunks to a string.
"""
if not chunks:
return s
return f"{s}{''.join(chunks)}"
31 changes: 31 additions & 0 deletions tests/prompts/test_magic_prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
def fake_generator(prompts, **_kwargs):
for prompt in prompts:
assert "<" not in prompt # should have been stripped
yield [{"generated_text": f"magical {prompt}"}]


def test_magic_prompts(monkeypatch):
# Instrument the superclass so it doesn't try to load the model
import dynamicprompts.generators.magicprompt as mp

if hasattr(mp, "_import_transformers"):
monkeypatch.setattr(mp, "_import_transformers", lambda: None)
monkeypatch.setattr(
mp.MagicPromptGenerator,
"_load_pipeline",
lambda self, model_name: fake_generator,
)

from sd_dynamic_prompts.magic_prompt import SpecialSyntaxAwareMagicPromptGenerator

generator = SpecialSyntaxAwareMagicPromptGenerator()
for prompt in generator.generate(
"purple cat singing opera, artistic, painting "
"<lora:loraname:0.7> <hypernet:v18000Steps:1>",
5,
):
# These must remain unchanged
assert "<lora:loraname:0.7>" in prompt
assert "<hypernet:v18000Steps:1>" in prompt
# but we should expect to see some magic
assert prompt.startswith("magical ")

0 comments on commit c3f2f87

Please sign in to comment.