@cloneofsimo
Last active September 13, 2023 11:01
Prompt Reweighting
from diffusers import DiffusionPipeline
from typing import List, Optional
import torch
import re
MODEL_CACHE = "./cache"
def split_by_emphasis(text, tokenizer, normalize=True):
    """Split `text` on (phrase:weight) spans; return padded token ids plus per-token weights."""

    def _toked(te):
        # Tokenize a fragment and drop the BOS/EOS tokens added by the CLIP tokenizer
        return tokenizer(te, truncation=True, return_tensors="pt").input_ids[0, 1:-1]

    # Regular expression pattern to match the emphasized parts of the text, e.g. "(cat:1.2)"
    pattern = r"\(([^)]+)\)"
    # Find all matches of the pattern in the text
    matches = re.findall(pattern, text)

    # Token ids of the empty prompt, used as a padded template to write the prompt into
    nulltext = tokenizer(
        "",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids[0]
    eos_token_id = tokenizer.eos_token_id

    # Split the text using the emphasis markers and create a list of tuples
    # containing the split parts and their corresponding emphasis values
    result = []
    start = 0
    for match in matches:
        emphasis_start = text.find("(" + match + ")", start)
        emphasis_end = emphasis_start + len("(" + match + ")")
        result.append((_toked(text[start:emphasis_start]), 1.0))
        parts = match.split(":")
        result.append((_toked(parts[0]), float(parts[1])))
        start = emphasis_end
    result.append((_toked(text[start:]), 1.0))

    # Concatenate all token ids, expanding each emphasis value to one weight per token
    result_id = torch.cat([r[0] for r in result], dim=0)
    result_emphasis = torch.cat([torch.tensor([r[1]] * len(r[0])) for r in result], dim=0)

    # Optionally normalize the weights so their mean is 1
    if normalize:
        result_emphasis = result_emphasis / result_emphasis.mean()

    # Truncate overly long prompts so they fit alongside the BOS and EOS tokens
    max_prompt_tokens = tokenizer.model_max_length - 2
    result_id = result_id[:max_prompt_tokens]
    result_emphasis = result_emphasis[:max_prompt_tokens]

    # Write the prompt tokens (and weights) into the padded template, keeping BOS at position 0
    nulltext[1 : len(result_id) + 1] = result_id
    emphasis_ones = torch.ones_like(nulltext).float()
    emphasis_ones[1 : len(result_id) + 1] = result_emphasis
    # Place the EOS token right after the last prompt token
    nulltext[len(result_id) + 1] = eos_token_id
    return nulltext, emphasis_ones
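
# Quick sanity check (hypothetical usage; assumes a CLIP tokenizer such as the
# `pipe.tokenizer` created further below, with model_max_length == 77):
#   ids, weights = split_by_emphasis("a (cat:1.3) on a mat", pipe.tokenizer)
#   ids.shape      -> torch.Size([77])
#   weights.shape  -> torch.Size([77])   # "cat" tokens carry ~1.3 before normalization, the rest ~1.0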
def new_encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
):
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # Define tokenizers and text encoders
    tokenizers = (
        [self.tokenizer, self.tokenizer_2]
        if self.tokenizer is not None
        else [self.tokenizer_2]
    )
    text_encoders = (
        [self.text_encoder, self.text_encoder_2]
        if self.text_encoder is not None
        else [self.text_encoder_2]
    )

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        prompt_embeds_list = []
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            if isinstance(prompt, str):
                text_input_ids_new, emphasis = split_by_emphasis(prompt, tokenizer)
                text_input_ids_new = text_input_ids_new.unsqueeze(0).to(device)
            else:
                new_input_ids, emphasises = [], []
                for p in prompt:
                    new_input_id, emphasis = split_by_emphasis(p, tokenizer)
                    new_input_ids.append(new_input_id)
                    emphasises.append(emphasis)
                text_input_ids_new = torch.stack(new_input_ids, dim=0).to(device)
                emphasis = torch.stack(emphasises, dim=0).to(device)

            prompt_embeds = text_encoder(
                text_input_ids_new.to(device),
                output_hidden_states=True,
            )

            pooled_prompt_embeds = prompt_embeds[0]
            # Reweight the per-token embeddings by the emphasis values
            # (reshape per batch element so list prompts work as well as a single string)
            prompt_embeds = prompt_embeds.hidden_states[-2] * emphasis.reshape(
                text_input_ids_new.shape[0], -1, 1
            ).to(dtype=text_encoder.dtype, device=device)

            bs_embed, seq_len, _ = prompt_embeds.shape
            # duplicate text embeddings for each generation per prompt, using mps friendly method
            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
            prompt_embeds = prompt_embeds.view(
                bs_embed * num_images_per_prompt, seq_len, -1
            )

            prompt_embeds_list.append(prompt_embeds)

        # SDXL concatenates the embeddings of both text encoders along the channel dimension
        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    # get unconditional embeddings for classifier free guidance
    zero_out_negative_prompt = (
        negative_prompt is None and self.config.force_zeros_for_empty_prompt
    )
    if (
        do_classifier_free_guidance
        and negative_prompt_embeds is None
        and zero_out_negative_prompt
    ):
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        uncond_tokens: List[str]
        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        negative_prompt_embeds_list = []
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            max_length = prompt_embeds.shape[1]
            uncond_input = tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            negative_prompt_embeds = text_encoder(
                uncond_input.input_ids.to(device),
                output_hidden_states=True,
            )
            # We are only ALWAYS interested in the pooled output of the final text encoder
            negative_pooled_prompt_embeds = negative_prompt_embeds[0]
            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

            if do_classifier_free_guidance:
                # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
                seq_len = negative_prompt_embeds.shape[1]

                negative_prompt_embeds = negative_prompt_embeds.to(
                    dtype=text_encoder.dtype, device=device
                )
                negative_prompt_embeds = negative_prompt_embeds.repeat(
                    1, num_images_per_prompt, 1
                )
                negative_prompt_embeds = negative_prompt_embeds.view(
                    batch_size * num_images_per_prompt, seq_len, -1
                )

                # For classifier free guidance, we need to do two forward passes.
                # Here we concatenate the unconditional and text embeddings into a single batch
                # to avoid doing two forward passes

            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
        1, num_images_per_prompt
    ).view(bs_embed * num_images_per_prompt, -1)

    return (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    )
## use it like
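# Example parameters for the usage sketch below (illustrative values; `prompt` reuses
# the sample prompt from the author's comment at the end):
seed = 42
num_outputs = 1
width, height = 1024, 1024
guidance_scale = 7.5
num_inference_steps = 30
prompt = "(cat:0.8) and a (dog:1.1), particle dynamics, artwork, visual explosion, (blue force field:1.5), (colorful:1.2)"
negative_prompt = ""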
pipe = DiffusionPipeline.from_pretrained(
    MODEL_CACHE,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
pipe.to("cuda")

# Monkeypatch the emphasis-aware encoder onto the pipeline class
pipe.__class__.encode_prompt = new_encode_prompt

generator = torch.Generator("cuda").manual_seed(seed)
output = pipe(
    prompt=[prompt] * num_outputs,
    negative_prompt=[negative_prompt] * num_outputs,
    width=width,
    height=height,
    guidance_scale=guidance_scale,
    generator=generator,
    num_inference_steps=num_inference_steps,
)
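# `output.images` is a list of PIL images (the default `output_type="pil"`); save them, e.g.:
for i, image in enumerate(output.images):
    image.save(f"output_{i}.png")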
@cloneofsimo (Author):

of course, bootstrapped from Huggingface's diffusers!!

(cat:0.8) and a (dog:1.1), particle dynamics, artwork, visual explosion, (blue force field:1.5), (colorful:1.2)
[sample image generated with the prompt above]
