Last active: October 20, 2023 17:53
Save Saren-Arterius/2014d9799361711d25ff33740990a7b8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/modules/api/api.py b/modules/api/api.py
index e6edffe7..86f51fa3 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -337,6 +337,10 @@ class Api:
         return script_args
 
     def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
+        # Snap width/height to the nearest multiple of 64 (halves round up), never below 64.
+        txt2imgreq.width = max(64, int(txt2imgreq.width / 64 + 0.5) * 64)
+        txt2imgreq.height = max(64, int(txt2imgreq.height / 64 + 0.5) * 64)
+        print('[t2i]', txt2imgreq.width, 'x', txt2imgreq.height, '|', txt2imgreq.prompt)
         script_runner = scripts.scripts_txt2img
         if not script_runner.scripts:
             script_runner.initialize_scripts(False)
@@ -387,6 +391,10 @@ class Api:
         return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
 
     def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
+        # Snap width/height to the nearest multiple of 64 (halves round up), never below 64.
+        img2imgreq.width = max(64, int(img2imgreq.width / 64 + 0.5) * 64)
+        img2imgreq.height = max(64, int(img2imgreq.height / 64 + 0.5) * 64)
+        print('[i2i]', img2imgreq.width, 'x', img2imgreq.height, '|', img2imgreq.prompt)
         init_images = img2imgreq.init_images
         if init_images is None:
             raise HTTPException(status_code=404, detail="Init image not found")
diff --git a/modules/processing.py b/modules/processing.py
index e124e7f0..7c13be80 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -233,6 +233,10 @@ class StableDiffusionProcessing:
         self.cached_uc = StableDiffusionProcessing.cached_uc
         self.cached_c = StableDiffusionProcessing.cached_c
 
+        # Per-job key "<lowercased prompt>|<width>*<height>*<batch_size>"; sd_unet's
+        # UNetModel_forward compares it with shared.skip_unet_prompt and scans it for "<lora:".
+        shared.current_prompt = '{}|{}*{}*{}'.format(self.prompt.lower(), self.width, self.height, self.batch_size)
+
     @property
     def sd_model(self):
         return shared.sd_model
diff --git a/modules/sd_unet.py b/modules/sd_unet.py | |
index 5525cfbc..49daf1d4 100644 | |
--- a/modules/sd_unet.py | |
+++ b/modules/sd_unet.py | |
@@ -1,8 +1,8 @@ | |
import torch.nn | |
import ldm.modules.diffusionmodules.openaimodel | |
+import time | |
from modules import script_callbacks, shared, devices | |
- | |
unet_options = [] | |
current_unet_option = None | |
current_unet = None | |
@@ -85,8 +85,27 @@ class SdUnet(torch.nn.Module):
 
 
 def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
-    if current_unet is not None:
-        return current_unet.forward(x, timesteps, context, *args, **kwargs)
-
-    return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
-
+    """Run ``current_unet`` if one is set, otherwise webui's stock UNet forward.
+
+    The override is bypassed when the current job key equals ``shared.skip_unet_prompt``,
+    when ``shared.opts.sd_unet`` contains ``[TRT]`` and the key contains ``<lora:``, or
+    when the override raises. After bypassing, the stock diffusion model is moved to
+    ``devices.device`` and the key is stored in ``shared.skip_unet_prompt`` so further
+    calls for the same job go straight to the stock forward.
+    """
+    if current_unet is not None and shared.current_prompt != shared.skip_unet_prompt:
+        if '[TRT]' in shared.opts.sd_unet and '<lora:' in shared.current_prompt:
+            reason = 'LoRA unsupported in TRT UNet'
+        else:
+            try:
+                return current_unet.forward(x, timesteps, context, *args, **kwargs)
+            except Exception as e:  # deliberate best-effort: any override failure falls back
+                reason = e
+
+        start = time.perf_counter()  # monotonic clock, unlike time.time()
+        print('[UNet] Skipping TRT UNet for this request:', reason, '-', shared.current_prompt)
+        shared.sd_model.model.diffusion_model.to(devices.device)  # presumably stock weights may be off-device; confirm
+        shared.skip_unet_prompt = shared.current_prompt
+        print('[UNet] Used', time.perf_counter() - start, 'seconds')
+
+    return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
diff --git a/modules/shared.py b/modules/shared.py
index 63661939..577bd100 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -85,3 +85,9 @@ list_checkpoint_tiles = shared_items.list_checkpoint_tiles
 refresh_checkpoints = shared_items.refresh_checkpoints
 list_samplers = shared_items.list_samplers
 reload_hypernetworks = shared_items.reload_hypernetworks
+
+# Key of the job being processed: "<lowercased prompt>|<width>*<height>*<batch_size>".
+current_prompt: str = ''
+
+# Key of the job for which the UNet override was abandoned; such jobs use the stock UNet.
+skip_unet_prompt: str = ''
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.