Prompt Reweighting
from diffusers import DiffusionPipeline
from typing import List, Optional

import torch
import re

MODEL_CACHE = "./cache"

def split_by_emphasis(text, tokenizer, normalize=True):
    """Tokenize `text`, giving tokens inside "(phrase:weight)" spans their
    weight and every other token a weight of 1.0."""

    def _toked(te):
        # Tokenize a fragment and strip the BOS/EOS tokens that CLIP adds.
        return tokenizer(te, truncation=True, return_tensors="pt").input_ids[0, 1:-1]

    # Regular expression pattern to match the emphasized parts of the text.
    pattern = r"\(([^)]+)\)"

    # Find all matches of the pattern in the text.
    matches = re.findall(pattern, text)

    # A fully padded "empty" prompt to write the token ids into.
    nulltext = tokenizer(
        "",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids[0]
    eos_token_id = tokenizer.eos_token_id

    # Split the text at the emphasis markers and collect a list of
    # (token_ids, emphasis) tuples for the plain and weighted parts.
    result = []
    start = 0
    for match in matches:
        emphasis_start = text.find("(" + match + ")", start)
        emphasis_end = emphasis_start + len("(" + match + ")")
        result.append((_toked(text[start:emphasis_start]), 1.0))
        # Each emphasized span is expected to look like "phrase:weight".
        parts = match.split(":")
        result.append((_toked(parts[0]), float(parts[1])))
        start = emphasis_end
    result.append((_toked(text[start:]), 1.0))

    # Concatenate all token ids, expanding each weight over its tokens.
    result_id = torch.cat([r[0] for r in result], dim=0)
    result_emphasis = torch.cat(
        [torch.tensor([r[1]] * len(r[0])) for r in result], dim=0
    )

    # Clamp so there is room for the BOS and EOS tokens.
    max_tokens = tokenizer.model_max_length - 2
    result_id = result_id[:max_tokens]
    result_emphasis = result_emphasis[:max_tokens]

    # If normalize, rescale the emphasis so its mean is 1.
    if normalize:
        result_emphasis = result_emphasis / result_emphasis.mean()

    nulltext[1 : len(result_id) + 1] = result_id
    emphasis_ones = torch.ones_like(nulltext).float()
    emphasis_ones[1 : len(result_id) + 1] = result_emphasis

    # Add EOS right after the prompt tokens.
    nulltext[len(result_id) + 1] = eos_token_id
    return nulltext, emphasis_ones
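
# A minimal smoke test for split_by_emphasis (hypothetical, not part of the
# pipeline flow; assumes the openai/clip-vit-large-patch14 tokenizer, which
# matches SDXL's first text encoder):
def _demo_split_by_emphasis():
    from transformers import CLIPTokenizer

    tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    ids, weights = split_by_emphasis("a (red:1.5) ball", tok)
    # Both tensors are padded to the tokenizer max length (77 for CLIP); with
    # normalize=True the "red" token lands above its neighbors, not at 1.5.
    print(ids.shape, weights.shape)
    print(weights[:8])
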
def new_encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
):
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # Define tokenizers and text encoders
    tokenizers = (
        [self.tokenizer, self.tokenizer_2]
        if self.tokenizer is not None
        else [self.tokenizer_2]
    )
    text_encoders = (
        [self.text_encoder, self.text_encoder_2]
        if self.text_encoder is not None
        else [self.text_encoder_2]
    )
    if prompt_embeds is None:
        prompt_embeds_list = []
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            if isinstance(prompt, str):
                text_input_ids_new, emphasis = split_by_emphasis(prompt, tokenizer)
                text_input_ids_new = text_input_ids_new.unsqueeze(0).to(device)
                emphasis = emphasis.unsqueeze(0)
            else:
                new_input_ids, emphases = [], []
                for p in prompt:
                    new_input_id, emphasis = split_by_emphasis(p, tokenizer)
                    new_input_ids.append(new_input_id)
                    emphases.append(emphasis)
                text_input_ids_new = torch.stack(new_input_ids, dim=0).to(device)
                emphasis = torch.stack(emphases, dim=0).to(device)

            prompt_embeds = text_encoder(
                text_input_ids_new.to(device),
                output_hidden_states=True,
            )

            # Index 0 is the pooled output; the loop leaves the one from the
            # final text encoder in place.
            pooled_prompt_embeds = prompt_embeds[0]
            # Scale each token's embedding by its emphasis weight; unsqueeze(-1)
            # broadcasts over the hidden dimension for single and batched prompts.
            prompt_embeds = prompt_embeds.hidden_states[-2] * emphasis.unsqueeze(-1).to(
                dtype=text_encoder.dtype, device=device
            )

            bs_embed, seq_len, _ = prompt_embeds.shape
            # duplicate text embeddings for each generation per prompt, using mps friendly method
            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
            prompt_embeds = prompt_embeds.view(
                bs_embed * num_images_per_prompt, seq_len, -1
            )
            prompt_embeds_list.append(prompt_embeds)

        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    # get unconditional embeddings for classifier free guidance
    zero_out_negative_prompt = (
        negative_prompt is None and self.config.force_zeros_for_empty_prompt
    )
    if (
        do_classifier_free_guidance
        and negative_prompt_embeds is None
        and zero_out_negative_prompt
    ):
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        uncond_tokens: List[str]
        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt
        negative_prompt_embeds_list = []
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            max_length = prompt_embeds.shape[1]
            uncond_input = tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            negative_prompt_embeds = text_encoder(
                uncond_input.input_ids.to(device),
                output_hidden_states=True,
            )
            # We are only ALWAYS interested in the pooled output of the final text encoder
            negative_pooled_prompt_embeds = negative_prompt_embeds[0]
            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

            if do_classifier_free_guidance:
                # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
                seq_len = negative_prompt_embeds.shape[1]
                negative_prompt_embeds = negative_prompt_embeds.to(
                    dtype=text_encoder.dtype, device=device
                )
                negative_prompt_embeds = negative_prompt_embeds.repeat(
                    1, num_images_per_prompt, 1
                )
                negative_prompt_embeds = negative_prompt_embeds.view(
                    batch_size * num_images_per_prompt, seq_len, -1
                )
                # For classifier free guidance, we need to do two forward passes.
                # Here we concatenate the unconditional and text embeddings into a single batch
                # to avoid doing two forward passes
            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
        1, num_images_per_prompt
    ).view(bs_embed * num_images_per_prompt, -1)

    return (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    )
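
# Hypothetical direct call, useful for inspecting the reweighted embeddings
# without running the full pipeline (assumes `pipe` is built as shown below):
#
#   embeds, neg_embeds, pooled, neg_pooled = new_encode_prompt(
#       pipe, "(cat:0.8) and a (dog:1.1)", "cuda", 1, True, negative_prompt=""
#   )
#   print(embeds.shape)  # [1, 77, 2048]: SDXL's two encoders concatenated
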
## use it like
pipe = DiffusionPipeline.from_pretrained(
    MODEL_CACHE,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
pipe.to("cuda")

# Monkey-patch the pipeline class so prompts are encoded with emphasis weights.
pipe.__class__.encode_prompt = new_encode_prompt

# Example inputs (the prompt is the one from the gist description); adjust to taste.
prompt = "(cat:0.8) and a (dog:1.1), particle dynamics, artwork, visual explosion, (blue force field:1.5), (colorful:1.2)"
negative_prompt = ""
num_outputs = 1
width, height = 1024, 1024
guidance_scale = 7.5
num_inference_steps = 30
seed = 1234

generator = torch.Generator("cuda").manual_seed(seed)

output = pipe(
    prompt=[prompt] * num_outputs,
    negative_prompt=[negative_prompt] * num_outputs,
    width=width,
    height=height,
    guidance_scale=guidance_scale,
    generator=generator,
    num_inference_steps=num_inference_steps,
)
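
# The pipeline call returns an output object whose generated PIL images live
# in `output.images`; a minimal way to save them:
for i, image in enumerate(output.images):
    image.save(f"output_{i}.png")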
of course, bootstrapped from Huggingface's diffusers!!
(cat:0.8) and a (dog:1.1), particle dynamics, artwork, visual explosion, (blue force field:1.5), (colorful:1.2)
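Each (phrase:weight) span multiplies the token embeddings of that phrase by its weight; with normalize=True the per-token weights are rescaled to mean 1, so the overall embedding magnitude stays comparable to an unweighted prompt.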
