Skip to content

Instantly share code, notes, and snippets.

@d0rc
Created September 13, 2024 15:13
Show Gist options
  • Save d0rc/9c254eafb106dd6a77c90f17f188a57b to your computer and use it in GitHub Desktop.
Save d0rc/9c254eafb106dd6a77c90f17f188a57b to your computer and use it in GitHub Desktop.
import os
import ssl
import threading
from io import BytesIO

import torch
from diffusers import FluxPipeline
from flask import Flask, request, send_file
# Flask application object that the /generate route is registered on and that
# the __main__ block serves.
app = Flask(__name__)
# Initialize the pipeline
# Make CUDA the default device for tensors created without an explicit device.
# NOTE(review): this is process-wide and is combined with
# enable_model_cpu_offload() below, which manages device placement itself;
# confirm the global default is still needed.
torch.set_default_device('cuda')
# Model identifier handed to from_pretrained (a Hugging Face Hub repo id).
model_id = "black-forest-labs/FLUX.1-dev"
# Load the pipeline with bfloat16 weights. This runs at import time, so merely
# importing this module loads (and presumably may download) the model.
pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
# Keep sub-models on the CPU and move each to the GPU only while it runs,
# trading generation speed for lower peak GPU memory.
pipe.enable_model_cpu_offload()
# Diffusers pipelines are not thread-safe (scheduler state lives on the shared
# `pipe` object), and Flask's development server handles requests on multiple
# threads by default, so generations on `pipe` must be serialized.
_PIPE_LOCK = threading.Lock()

# Largest `num_inference_steps` a client may request; bounds how long a single
# request can occupy the GPU.
MAX_INFERENCE_STEPS = 100

# Inclusive seed range accepted by torch.Generator.manual_seed; values outside
# it make torch raise RuntimeError.
_MIN_SEED = -0x8000_0000_0000_0000
_MAX_SEED = 0xFFFF_FFFF_FFFF_FFFF


def _is_int(value):
    """Return True for real integers, excluding bool (a subclass of int)."""
    return isinstance(value, int) and not isinstance(value, bool)


def _parse_generation_params(data):
    """Extract and validate generation parameters from a JSON payload dict.

    Missing keys fall back to the defaults: a fixed sample prompt, seed 42 and
    25 inference steps.

    Returns:
        A ``(prompt, seed, num_inference_steps)`` tuple.

    Raises:
        ValueError: with a client-facing message when a value has the wrong
            type or lies outside its accepted range.
    """
    prompt = data.get('prompt', 'A cat holding a sign that says hello world')
    seed = data.get('seed', 42)
    num_inference_steps = data.get('num_inference_steps', 25)

    if not isinstance(prompt, str):
        raise ValueError("'prompt' must be a string")
    if not _is_int(seed) or not _MIN_SEED <= seed <= _MAX_SEED:
        raise ValueError(
            f"'seed' must be an integer in [{_MIN_SEED}, {_MAX_SEED}]"
        )
    if (not _is_int(num_inference_steps)
            or not 1 <= num_inference_steps <= MAX_INFERENCE_STEPS):
        raise ValueError(
            "'num_inference_steps' must be an integer in "
            f"[1, {MAX_INFERENCE_STEPS}]"
        )
    return prompt, seed, num_inference_steps


@app.route('/generate', methods=['POST'])
def generate_image():
    """Generate an image from a JSON request and return it as a PNG.

    The optional JSON object body may contain ``prompt`` (str), ``seed`` (int)
    and ``num_inference_steps`` (int). Omitted fields use their defaults, and
    an empty body uses all defaults.

    Returns:
        The generated image with mimetype ``image/png``. If the body is not a
        JSON object or a parameter is invalid, returns a JSON
        ``{"error": ...}`` response with status 400.
    """
    # silent=True makes get_json return None (instead of aborting) when the
    # body is missing, not JSON-typed, or malformed.
    data = request.get_json(silent=True)
    if data is None:
        if request.get_data():
            # A body was sent but could not be used as JSON.
            return {'error': 'request body must be a JSON object'}, 400
        data = {}
    if not isinstance(data, dict):
        return {'error': 'request body must be a JSON object'}, 400
    try:
        prompt, seed, num_inference_steps = _parse_generation_params(data)
    except ValueError as err:
        return {'error': str(err)}, 400

    with _PIPE_LOCK:
        image = pipe(
            prompt,
            output_type="pil",
            num_inference_steps=num_inference_steps,
            generator=torch.Generator("cuda").manual_seed(seed),
        ).images[0]

    img_io = BytesIO()
    image.save(img_io, 'PNG')
    img_io.seek(0)
    return send_file(img_io, mimetype='image/png')
def create_self_signed_cert(cert_file, key_file, common_name="localhost",
                            valid_days=10 * 365):
    """Generate a self-signed X.509 certificate and RSA private key as PEM files.

    A 2048-bit RSA key pair is created, and the certificate's issuer is set to
    its own subject (self-signed). The certificate is signed with SHA-256 and
    is valid from now for ``valid_days`` days.

    Args:
        cert_file: Path the PEM-encoded certificate is written to.
        key_file: Path the PEM-encoded private key is written to. The file is
            given owner-only permissions (0o600) because it holds the server's
            secret key.
        common_name: Subject CN of the certificate (default ``"localhost"``).
        valid_days: Validity period in days (default 10 years).

    Both files are overwritten if they already exist.
    """
    import secrets

    from OpenSSL import crypto

    # Create a key pair
    key = crypto.PKey()
    key.generate_key(crypto.TYPE_RSA, 2048)

    # Create a self-signed cert. get_subject() wraps the certificate's own
    # subject field, so edits through `subject` modify the certificate.
    cert = crypto.X509()
    subject = cert.get_subject()
    subject.C = "US"
    subject.ST = "State"
    subject.L = "City"
    subject.O = "Organization"
    subject.OU = "Organizational Unit"
    subject.CN = common_name
    # A random positive serial keeps a regenerated certificate from reusing the
    # issuer-name + serial pair of an earlier one, which clients may reject.
    cert.set_serial_number(secrets.randbelow(2 ** 127 - 1) + 1)
    cert.gmtime_adj_notBefore(0)
    cert.gmtime_adj_notAfter(valid_days * 24 * 60 * 60)
    cert.set_issuer(subject)
    cert.set_pubkey(key)
    cert.sign(key, 'sha256')

    # Write the cert and key files
    with open(cert_file, "wb") as f:
        f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))

    # Create the key file owner-only from the start so the private key is never
    # readable by other users. The explicit chmod also tightens a pre-existing
    # file, whose mode O_CREAT leaves unchanged. O_BINARY only exists on Windows.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
    fd = os.open(key_file, flags, 0o600)
    with os.fdopen(fd, "wb") as f:
        os.chmod(key_file, 0o600)
        f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key))
if __name__ == '__main__':
    # Serve over HTTPS. Reuse the existing certificate/key pair, and generate
    # a fresh self-signed pair (rewriting both files) only if either is missing.
    cert_path, key_path = 'cert.pem', 'key.pem'
    if not all(os.path.exists(path) for path in (cert_path, key_path)):
        create_self_signed_cert(cert_path, key_path)

    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    tls_context.load_cert_chain(cert_path, key_path)

    # Listen on all interfaces, port 5000.
    app.run(host='0.0.0.0', port=5000, ssl_context=tls_context)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment