Skip to content

Instantly share code, notes, and snippets.

@pplantinga
Created November 18, 2022 19:33
Show Gist options
  • Save pplantinga/5c56a19dfcff8e139952b9d41340f48f to your computer and use it in GitHub Desktop.
Save pplantinga/5c56a19dfcff8e139952b9d41340f48f to your computer and use it in GitHub Desktop.
Simple UI to play Telephone Diffusion
import torch, sys, argparse, os
import tkinter as tk
from diffusers import StableDiffusionPipeline
# Width of the zero-padded numeric stem shared by every game file
# (with MAXLEN=4, index 7 is stored as 0007.txt / 0007.png).
MAXLEN = 4

# Parse the name of the game folder
p = argparse.ArgumentParser(description="Simple UI to play Telephone Diffusion.")
p.add_argument(
    "game_name",
    help="folder holding the game's numbered prompt (.txt) and image (.png) "
    "files; created if it does not exist",
)
args = p.parse_args()

# Resume from the most recent image, if any. Only PNGs with a purely numeric
# stem are considered, so unrelated files in the folder (e.g. .DS_Store) are
# ignored instead of crashing int(). They are ordered by number rather than by
# name so the order stays right once indices outgrow MAXLEN digits. Keying on
# images means a prompt file left behind without its image (e.g. a failed
# generation) is simply overwritten on the next turn, instead of the game
# trying to display a missing PNG at startup. `files` and `number` are also
# read further down to display that latest image.
os.makedirs(args.game_name, exist_ok=True)
files = sorted(
    (
        name
        for name in os.listdir(args.game_name)
        if name.endswith(".png") and os.path.splitext(name)[0].isdecimal()
    ),
    key=lambda name: int(os.path.splitext(name)[0]),
)
if files:
    number, _ = os.path.splitext(files[-1])
    current_index = int(number) + 1
else:
    current_index = 0
# Initialize stable diffusion: load the "fp16" revision of the Stable
# Diffusion v1.4 weights as float16 tensors and move the pipeline to the GPU.
# use_auth_token=True presumably makes diffusers use the locally saved
# Hugging Face token (needed for gated weights) -- confirm for your setup.
# NOTE(review): newer diffusers releases deprecate revision="fp16" in favour
# of variant="fp16", and use_auth_token in favour of token; switch once the
# installed version is confirmed.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16,
    revision="fp16",
    use_auth_token=True,
).to("cuda")
# Per the diffusers docs, attention slicing trades a little speed for lower
# GPU memory use.
pipe.enable_attention_slicing()
# Initialize window: a full-width row at the top for the prompt entry and
# button, with the image label stacked underneath it.
window = tk.Tk()
entryframe = tk.Frame(window, width=500)
entryframe.pack(side=tk.TOP, fill=tk.X)
imagelabel = tk.Label(window)
# Resuming a saved game: show its latest image. The module-level name keeps
# the PhotoImage referenced so it is not garbage-collected while displayed.
if files:
    imagetk = tk.PhotoImage(file=os.path.join(args.game_name, f"{number}.png"))
    imagelabel.config(image=imagetk)
imagelabel.pack(side=tk.TOP, fill=tk.BOTH)
# The prompt entry belongs to the top row; it is packed later, beside the button.
entry = tk.Entry(entryframe, width=80)
def index2fpath(index):
    """Return the extension-less path of game file number ``index``.

    The index is left-padded with zeros to ``MAXLEN`` characters and joined
    onto the game folder (with MAXLEN=4: ``7 -> <game_name>/0007``); callers
    append ``.txt`` or ``.png``. Indices already ``MAXLEN`` characters or
    longer are used without padding.
    """
    stem = str(index).rjust(MAXLEN, "0")
    return os.path.join(args.game_name, stem)
# Generate a new image from the prompt on button press
def generate_from_prompt():
    """Play one turn: render the typed prompt as an image and show it.

    Each turn uses two consecutive indices. The prompt is saved as
    ``<current_index>.txt`` and the image, resized to 1024x1024, as
    ``<current_index + 1>.png`` in the game folder. ``current_index`` then
    advances by two.

    Blank prompts are ignored. Nothing is written, the index is not advanced,
    and the entry text is kept unless the image was generated successfully.
    A failed generation therefore cannot break the prompt/image alternation
    or leave an orphan prompt file behind.
    """
    global current_index

    prompt = entry.get()
    if not prompt.strip():
        return

    # Stable-ly diffuse. This runs synchronously inside the Tk callback, so
    # the window is unresponsive until it finishes.
    with torch.autocast("cuda"):
        image = pipe(prompt).images[0]

    # Generate filenames: prompt and image occupy consecutive indices.
    prompt_path = index2fpath(current_index) + ".txt"
    image_path = index2fpath(current_index + 1) + ".png"

    # Write prompt to file (explicit encoding so any Unicode prompt round-trips)
    with open(prompt_path, "w", encoding="utf-8") as w:
        w.write(prompt)

    # Save & display image
    image = image.resize((1024, 1024))
    image.save(image_path)
    imagetk = tk.PhotoImage(file=image_path)
    imagelabel.configure(image=imagetk)
    # Keep a Python reference so the PhotoImage is not garbage-collected
    # (and blanked) once this function returns.
    imagelabel.image = imagetk

    current_index += 2
    entry.delete(0, tk.END)
# Top row: the generate button is packed first so it takes the right edge;
# the prompt entry is packed to its left.
button = tk.Button(master=entryframe, text="Press me!", command=generate_from_prompt)
button.pack(side=tk.RIGHT)
# With side=LEFT, fill=X only stretches the entry within its own parcel, which
# is just its requested width. expand=True gives it the remaining width of the
# row, so the entry actually grows with the window as fill=X intended.
entry.pack(fill=tk.X, side=tk.LEFT, expand=True)
# Run the Tk event loop until the window is closed.
window.mainloop()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment