Created November 18, 2022 19:33
-
-
Save pplantinga/5c56a19dfcff8e139952b9d41340f48f to your computer and use it in GitHub Desktop.
Simple UI to play Telephone Diffusion
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch, sys, argparse, os | |
import tkinter as tk | |
from diffusers import StableDiffusionPipeline | |
MAXLEN = 4  # width of the zero-padded number in saved file names, e.g. "0007.png"

# Parse the name of the game folder
p = argparse.ArgumentParser(description="Simple UI to play Telephone Diffusion.")
p.add_argument("game_name", help="folder where this game's prompts and images are saved")
args = p.parse_args()

# Resume from the most recent file, if any.
# Only files whose stem is a plain decimal number ("0007.txt", "0008.png") are
# game files. Any other file in the folder (e.g. "notes.md") is ignored because
# int() could not parse its name. Files are sorted by their numeric value, so
# the order stays correct even after indices grow past MAXLEN digits, where
# plain string sorting would put "10000" before "9999".
os.makedirs(args.game_name, exist_ok=True)
files = sorted(
    (
        name
        for name in os.listdir(args.game_name)
        if os.path.splitext(name)[0].isdecimal()
    ),
    key=lambda name: (int(os.path.splitext(name)[0]), name),
)
if files:
    number, _ = os.path.splitext(files[-1])
    current_index = int(number) + 1
else:
    current_index = 0
# Initialize Stable Diffusion: load the v1-4 weights in float16, move the
# pipeline to the GPU, and enable attention slicing.
# NOTE(review): recent diffusers releases deprecate `use_auth_token` (now
# `token`) and `revision="fp16"` (now `variant="fp16"`). Confirm against the
# installed diffusers version before changing these arguments.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16,
    revision="fp16",
    use_auth_token=True,
).to("cuda")
pipe.enable_attention_slicing()
# Initialize window: a row for the prompt entry on top, the image label below.
window = tk.Tk()
entryframe = tk.Frame(master=window, width=500)
entryframe.pack(fill=tk.X, side=tk.TOP)
imagelabel = tk.Label(master=window)

# If an image already exists from an earlier session, display it.
# The newest game file can be a ".txt" prompt with no matching ".png", for
# example if the program stopped between writing the prompt and saving the
# image. So check that the PNG exists instead of letting PhotoImage raise and
# abort start-up.
if files:
    last_image = os.path.join(args.game_name, number + ".png")
    if os.path.isfile(last_image):
        imagetk = tk.PhotoImage(file=last_image)
        imagelabel.configure(image=imagetk)
        # Hold a reference on the widget (as generate_from_prompt does) so the
        # PhotoImage is not garbage-collected while it is displayed.
        imagelabel.image = imagetk
imagelabel.pack(fill=tk.BOTH, side=tk.TOP)

entry = tk.Entry(master=entryframe, width=80)
def index2fpath(index):
    """Build the extension-less path of the game file numbered ``index``.

    The number is left-padded with "0" to MAXLEN characters (7 -> "0007" when
    MAXLEN is 4). Longer numbers are used unchanged. This padding lets the
    sorted directory listing used at start-up find the newest file, as long
    as numbers stay within MAXLEN digits. The padded name is joined onto the
    game folder given on the command line.
    """
    padded_name = str(index).rjust(MAXLEN, "0")
    return os.path.join(args.game_name, padded_name)
# Generate a new image from the prompt on button press
def generate_from_prompt():
    """Turn the typed prompt into the next image of the game.

    Reads the prompt from ``entry``, runs the Stable Diffusion ``pipe`` on it,
    and records the turn as two consecutively numbered files in the game
    folder:

    - ``<n>.txt`` holds the prompt.
    - ``<n+1>.png`` holds the generated image, resized to 1024x1024.

    The image is then shown in ``imagelabel`` and ``current_index`` advances
    by two.

    Blank prompts are ignored. Files are written, the index is advanced, and
    the entry is cleared only after the image has been generated. So if
    generation fails (e.g. the GPU runs out of memory), no orphaned prompt
    file is left behind, no index is consumed, and the player's text stays in
    the entry for a retry.
    """
    global current_index

    prompt = entry.get()
    # Nothing to draw: don't spend a generation or a pair of file indices.
    if not prompt.strip():
        return

    # Stable-ly diffuse. torch.autocast("cuda") replaces the deprecated
    # torch.cuda.amp.autocast().
    with torch.autocast("cuda"):
        image = pipe(prompt).images[0]

    # Each turn uses two indices: first the prompt, then the image it produced.
    prompt_path = index2fpath(current_index)
    image_path = index2fpath(current_index + 1)

    # Write UTF-8 explicitly so non-ASCII prompts don't depend on the
    # platform's default encoding.
    with open(prompt_path + ".txt", "w", encoding="utf-8") as w:
        w.write(prompt)
    image = image.resize((1024, 1024))
    image.save(image_path + ".png")
    current_index += 2

    # Clear the entry for the next prompt and display the new image.
    entry.delete(0, tk.END)
    imagetk = tk.PhotoImage(file=image_path + ".png")
    imagelabel.configure(image=imagetk)
    # Keep a reference on the widget. Otherwise the PhotoImage is
    # garbage-collected when this function returns and the label goes blank.
    imagelabel.image = imagetk
# Add the submit button to the right of the entry row, let the entry fill the
# remaining width, then hand control to the Tk event loop.
# The pack order matters: the button is packed before the entry.
button = tk.Button(entryframe, command=generate_from_prompt, text="Press me!")
button.pack(side=tk.RIGHT)
entry.pack(side=tk.LEFT, fill=tk.X)

window.mainloop()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment