Last active
July 4, 2024 01:38
-
-
Save tsvikas/186015d1d085e0e01e1e5170d54c2b8c to your computer and use it in GitHub Desktop.
generate images with dall-e
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Generate images using Dall-E, save the parameters and the output easily to file. | |
Usage: | |
```python | |
import os; os.environ["OPENAI_API_KEY"] = "SECRET_KEY" # set your API key | |
from generate_images import GeneratedImage, GeneratedImagesFile # import this code | |
img = GeneratedImage.generate("Astronaut") # generate an image | |
img # display the generated image + metadata in Jupyter | |
img.save_image("astronaut.png") # save a specific image | |
astronaut_images = GeneratedImagesFile("astronaut.jsonl") # load a file with many images | |
astronaut_images.append(img) # add a generated image to the file | |
astronaut_images.generate("Cool astronaut") # or generate and add to the file with one function | |
astronaut_images.generate_many("Psychadelic astronaut", n=10) # you can generate and add more than one image | |
astronaut_images # display thumbnails in Jupyter | |
astronaut_images[0] # access a specific image | |
astronaut_images[1, -1] # display thumbnails for a subset of images | |
astronaut_images.select(1, -1).display() # display a subset of images | |
astronaut_images.select(1, -1).copy_to("something.jsonl") # copy a subset of images | |
``` | |
""" # noqa: E501 | |
# TODO: improve docs, maybe with copilot / claude | |
# TODO: add testing, maybe with mock | |
import base64 | |
import dataclasses | |
import functools | |
from collections.abc import Iterable, Mapping | |
from pathlib import Path | |
from typing import Any, Literal, Self | |
import jsonlines | |
import openai | |
from IPython.display import Markdown, display | |
# Allowed values for the image-generation parameters. Which values are valid for
# which model is checked at generation time (see `GeneratedImage.generate`).
ModelType = Literal["dall-e-2", "dall-e-3"]
SizeType = Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]
# `None` asks for the model's default; it is also the only value dall-e-2 accepts.
QualityType = Literal["standard", "hd"] | None
StyleType = Literal["vivid", "natural"] | None

# Module-wide OpenAI client, created at import time and used for every request.
# TODO: if no API KEY, make it read-only.
client = openai.OpenAI()
def _check_valid_values( | |
name: str, value: str | None, valid_values: list[str | None], model: str | |
) -> None: | |
if value not in valid_values: | |
raise ValueError( | |
f"Invalid {name} for model {model!r}. " | |
f"Expected one of: {valid_values}. " | |
f"Received: {value!r}" | |
) | |
def _image_b64_to_html(image_b64: str, width: int) -> str: | |
img_src = f"data:image/png;base64,{image_b64}" | |
style = f"width:{width}px; display:inline-block; margin-right: 10px;" | |
return f'<img src="{img_src}" style="{style}"/>' | |
@dataclasses.dataclass
class GeneratedImage:
    """
    An image generated by OpenAI, together with the parameters used to create it.

    API documentation: https://platform.openai.com/docs/api-reference/images/create
    pricing details: https://openai.com/api/pricing/
    """

    # The prompt sent to the API; includes the pre-prompt when `generate` was
    # called with `use_exact_prompt=True`.
    prompt: str | None
    # The prompt as revised by the API, as returned in the response.
    revised_prompt: str | None
    model: ModelType
    size: SizeType
    quality: QualityType
    style: StyleType
    # Base64-encoded image data; hidden from the repr because it is very long.
    image_b64: str = dataclasses.field(repr=False)

    @property
    def image_bytes(self) -> bytes:
        """Decode the base64 image data into raw bytes."""
        return base64.b64decode(self.image_b64)

    def save_image(self, filename: Path | str) -> None:
        """
        Save the image to file.

        The API does not specify which format it uses, but it seems to return PNG.
        """
        Path(filename).write_bytes(self.image_bytes)

    def _repr_markdown_(self) -> str:
        """Render the prompts, the generation details and the image for Jupyter."""
        header_fields = ["prompt", "revised_prompt"]
        detail_fields = ["model", "size", "quality", "style"]
        header_markdowns = [f"**{fld}**: {getattr(self, fld)}" for fld in header_fields]
        # TODO: add max width
        # Embed the image itself as a data URI so it is displayed inline.
        image_markdown = f"![generated image](data:image/png;base64,{self.image_b64})"
        detail_values = [getattr(self, fld) for fld in detail_fields]
        details_markdown = "**details**: " + " | ".join(
            str(value) for value in detail_values if value is not None
        )
        # Two trailing spaces before the newline force a Markdown line break.
        return "  \n".join([*header_markdowns, details_markdown, image_markdown])

    def to_dict(self) -> dict[str, Any]:
        """Convert to a dictionary of all fields (as stored in jsonlines files)."""
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, d: Mapping) -> Self:
        """Create an instance from a dictionary, e.g. one produced by `to_dict`."""
        return cls(**d)

    @classmethod
    def generate(  # noqa: PLR0913
        cls,
        prompt: str,
        *,
        model: ModelType = "dall-e-3",
        size: SizeType = "1024x1024",
        quality: QualityType = None,
        style: StyleType = None,
        use_exact_prompt: bool = False,
    ) -> Self:
        """
        Create an image given a prompt, using the OpenAI images API.

        Args:
            prompt: the text prompt.
            model: "dall-e-2" or "dall-e-3".
            size: dall-e-2 accepts "256x256", "512x512" and "1024x1024";
                dall-e-3 accepts "1024x1024", "1792x1024" and "1024x1792".
            quality: must be None for dall-e-2; for dall-e-3 "standard" or "hd"
                (None is replaced by "standard").
            style: must be None for dall-e-2; for dall-e-3 "vivid" or "natural"
                (None is replaced by "vivid").
            use_exact_prompt: prepend an OpenAI-recommended pre-prompt that asks the
                model not to revise the prompt. The stored `prompt` includes it.

        Returns:
            The generated image, with the parameters actually sent to the API.

        Raises:
            ValueError: if the model is unsupported, or a parameter value is not
                valid for the chosen model.
        """
        if model == "dall-e-2":
            _check_valid_values(
                "size", size, ["256x256", "512x512", "1024x1024"], model
            )
            _check_valid_values("quality", quality, [None], model)
            _check_valid_values("style", style, [None], model)
        elif model == "dall-e-3":
            _check_valid_values(
                "size", size, ["1024x1024", "1792x1024", "1024x1792"], model
            )
            # Fill in the defaults explicitly so they are recorded with the image.
            if quality is None:
                quality = "standard"
            if style is None:
                style = "vivid"
            # Validate locally, like for dall-e-2, to fail before calling the API.
            _check_valid_values("quality", quality, ["standard", "hd"], model)
            _check_valid_values("style", style, ["vivid", "natural"], model)
        else:
            raise ValueError("Unsupported model")

        # see https://platform.openai.com/docs/guides/images/prompting
        pre_prompt = (
            "I NEED to test how the tool works with extremely simple prompts. "
            "DO NOT add any detail, just use it AS-IS: "
            if use_exact_prompt
            else ""
        )
        prompt = pre_prompt + prompt
        response = client.images.generate(
            prompt=prompt,
            model=model,
            n=1,
            quality=quality,
            response_format="b64_json",
            size=size,
            style=style,
        )
        response_data = dict(response.data[0])
        return cls(
            image_b64=response_data["b64_json"],
            prompt=prompt,
            revised_prompt=response_data["revised_prompt"],
            model=model,
            size=size,
            quality=quality,
            style=style,
        )

    def thumbnail(self, width: int = 128) -> Markdown:
        """Return a Markdown object with an HTML thumbnail of the image."""
        # TODO: maybe lower the resolution, to help with filesize
        return Markdown(_image_b64_to_html(self.image_b64, width))

    def display(self) -> None:
        """Display the image (with its metadata) in Jupyter."""
        display(self)
class GeneratedImages:
    """
    A list of generated images, displayed as thumbnails in Jupyter.

    Indexing with an int returns a single `GeneratedImage`. Indexing with a slice
    or an iterable of ints (e.g. `images[1, -1]`) returns a new `GeneratedImages`.
    """

    def __init__(self, images: Iterable[GeneratedImage]):
        self._images = list(images)

    @functools.singledispatchmethod
    def __getitem__(self, index):  # noqa: ANN001
        # Fallback for any index type without a registered implementation below.
        raise TypeError(f"Invalid index type: {type(index)}")

    @__getitem__.register(slice)
    def _(self, index: slice) -> "GeneratedImages":
        return GeneratedImages(self._images[index])

    @__getitem__.register(Iterable)
    def _(self, index: Iterable[int]) -> "GeneratedImages":
        # str/bytes are iterable, but they are not a collection of indexes.
        if isinstance(index, str | bytes):
            raise TypeError(f"Invalid index type: {type(index)}")
        return GeneratedImages([self._images[i] for i in index])

    @__getitem__.register(int)
    def _(self, index: int) -> GeneratedImage:
        return self._images[index]

    def select(self, *indexes: int) -> "GeneratedImages":
        """Return a subset of images by position, e.g. `images.select(0, -1)`."""
        return self[indexes]

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._images!r})"

    def thumbnails(self, width: int = 128) -> Markdown:
        """Return a Markdown object with HTML thumbnails for all the images."""
        # TODO: add alt text for index / prompt
        return Markdown("\n".join(img.thumbnail(width).data for img in self._images))

    def _repr_markdown_(self) -> str:
        return self.thumbnails(width=128).data

    def copy_to(self, filename: Path | str) -> "GeneratedImagesFile":
        """
        Copy the images to a jsonlines file and return it as a `GeneratedImagesFile`.

        If the file already exists, its images are kept and these are appended.
        """
        other = GeneratedImagesFile(filename)
        return other.extend(self._images)

    def display(self) -> None:
        """Display all images, each with its metadata."""
        for img in self._images:
            img.display()
class GeneratedImagesFile(GeneratedImages):
    """
    A jsonlines file containing several generated images.

    Existing images are loaded when the object is created. Generated images can be
    appended or generated directly; each change is written to the file at once.
    """

    def __init__(self, filename: Path | str):
        self.filename = Path(filename)
        if self.filename.exists():
            with jsonlines.open(self.filename, "r") as reader:
                images = [GeneratedImage.from_dict(data) for data in reader]
        else:
            images = []
        super().__init__(images)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(filename={self.filename})"

    def _repr_markdown_(self) -> str:
        return (
            # TODO: the filename part is not working with jupyter
            # f"**filename**: {self.filename}\n " +
            super()._repr_markdown_()
        )

    def overwrite(self, images: Iterable[GeneratedImage]) -> Self:
        """
        Overwrite the file with the given images.

        The images are written to a temporary file next to the target, which then
        replaces it. A failure while writing therefore leaves both the file and the
        in-memory list unchanged.
        """
        # TODO: add backup before?
        new_images = list(images)
        tmp_path = self.filename.with_name(self.filename.name + ".tmp")
        try:
            with jsonlines.open(tmp_path, "w") as writer:
                for img in new_images:
                    writer.write(img.to_dict())
            tmp_path.replace(self.filename)
        finally:
            # No-op after a successful replace; removes a partial file on failure.
            tmp_path.unlink(missing_ok=True)
        self._images = new_images
        return self

    def remove_last(self) -> Self:
        """
        Remove the last image from the file.

        This is inefficient, since it rewrites all the other images.
        """
        return self.overwrite(self._images[:-1])

    def append(self, img: GeneratedImage) -> Self:
        """Add an image to the end of the file, printing its index."""
        with jsonlines.open(self.filename, mode="a") as writer:
            writer.write(img.to_dict())
        print(f"Image saved to index {len(self._images)}")
        self._images.append(img)
        return self

    def extend(self, imgs: Iterable[GeneratedImage]) -> Self:
        """Add several images to the file, one at a time."""
        for img in imgs:
            self.append(img)
        return self

    def generate(  # noqa: PLR0913
        self,
        prompt: str,
        *,
        model: ModelType = "dall-e-3",
        size: SizeType = "1024x1024",
        quality: QualityType = None,
        style: StyleType = None,
        use_exact_prompt: bool = False,
    ) -> GeneratedImage:
        """Create an image given a prompt, save it to the file and return it."""
        img = GeneratedImage.generate(
            prompt=prompt,
            model=model,
            size=size,
            quality=quality,
            style=style,
            use_exact_prompt=use_exact_prompt,
        )
        self.append(img)
        return img

    def generate_many(  # noqa: PLR0913
        self,
        prompt: str,
        n: int = 1,
        *,
        model: ModelType = "dall-e-3",
        size: SizeType = "1024x1024",
        quality: QualityType = None,
        style: StyleType = None,
        use_exact_prompt: bool = False,
    ) -> list[GeneratedImage]:
        """Create `n` images with the same prompt, save them to the file, return them."""
        return [
            self.generate(
                prompt=prompt,
                model=model,
                size=size,
                quality=quality,
                style=style,
                use_exact_prompt=use_exact_prompt,
            )
            for _ in range(n)
        ]

    def generate_or_load(
        self, index: int | None, prompt: str | None = None, **kwargs: Any
    ) -> GeneratedImage:
        """
        Generate and save an image, or load it from the file.

        A helper to verify that an image is not created twice. Call it with
        `index=None` to generate and append a new image (`append` prints its index).
        Call it with that index later to load the saved image instead.

        Args:
            index: position of a saved image to load, or None to generate a new one.
            prompt: the prompt. Required when generating; when loading, it is
                checked against the saved image.
            **kwargs: further arguments for `generate` (model, size, quality, style,
                use_exact_prompt). When loading, each stored one is checked against
                the saved image.

        Returns:
            The generated or loaded image.

        Raises:
            ValueError: if generating without a prompt, or if a given argument does
                not match the saved image.
        """
        # TODO: maybe instead search for the prompt in the file?
        if index is None:
            if prompt is None:
                raise ValueError("A prompt is required when generating a new image.")
            return self.generate(prompt=prompt, **kwargs)
        print(f"Image loaded from index {index}")
        img = self._images[index]
        # `use_exact_prompt` is not stored on the image. It only prepends a
        # pre-prompt, so in that case the saved prompt ends with the given one.
        use_exact_prompt = kwargs.pop("use_exact_prompt", False)
        expected = {"prompt": prompt, **kwargs} if prompt is not None else kwargs
        for name, value in expected.items():
            saved_value = getattr(img, name)
            if name == "prompt" and use_exact_prompt:
                matches = isinstance(saved_value, str) and saved_value.endswith(value)
            else:
                matches = saved_value == value
            if not matches:
                raise ValueError(f"MISMATCH {name}, should be {saved_value!r}")
        return img
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment