A danbooru crawler for making large datasets
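The script walks Danbooru post IDs downward from a start ID, pulls each post's metadata from the JSON API, drops posts carrying banned style tags, strips uninformative meta tags from the rest, downscales images so the short side is at most 1024 px, and writes each result as {id}.webp alongside a {id}.txt tag caption, giving a directory of image/caption pairs. Fill in DANBOORU_USERNAME and DANBOORU_API_KEY and adjust the constants at the top before running.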
import asyncio
import json
import traceback
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path

import cv2
import numpy as np
from httpx import AsyncClient, HTTPError, Timeout
from tqdm import tqdm

TARGET_RES = 1024        # images are downscaled until the short side is this many pixels
MIN_RES = 512            # skip images whose short side is below this
MAX_PIXEL = 30_000_000   # skip images with more total pixels than this
OUTPUT_DIR = "./6600K-1600K"
WEBP_QUALITY = 90
TIMEOUT = 2.5            # per-request timeout in seconds
RETRIES = 5
DOWNLOAD_WORKERS = 64    # concurrent downloads (semaphore size)
PROCESS_WORKERS = 32     # processes for CPU-bound image decoding/saving
DANBOORU_USERNAME = ""
DANBOORU_API_KEY = ""

# pixiv rejects hotlinked downloads without a browser user-agent and a pixiv referer
headers_pixiv = {
    "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.0.0 Safari/537.36",
    "referer": "https://www.pixiv.net/",
}

# posts carrying any of these tags are skipped entirely
banned_tags = [
    "furry", "realistic", "3d",
    "1940s_(style)", "1950s_(style)", "1960s_(style)", "1970s_(style)",
    "1980s_(style)", "1990s_(style)", "retro_artstyle", "screentones",
    "pixel_art", "magazine_scan", "scan",
]
# these tags are merely stripped from the caption
bad_tags = [
    "absurdres", "jpeg_artifacts", "highres", "translation_request",
    "translated", "commentary", "commentary_request", "commentary_typo",
    "character_request", "bad_id", "bad_link", "bad_pixiv_id",
    "bad_twitter_id", "bad_tumblr_id", "bad_deviantart_id",
    "bad_nicoseiga_id", "md5_mismatch", "cosplay_request",
    "artist_request", "wide_image", "author_request", "artist_name",
]

def rescale(
    image: np.ndarray,
    output_size: int,
) -> np.ndarray:
    '''
    Downscale so the short side becomes output_size, preserving aspect ratio.
    '''
    h, w = image.shape[:2]
    r = max(output_size / h, output_size / w)
    new_h, new_w = int(h * r), int(w * r)
    return cv2.resize(image, (new_w, new_h))

def save_img(
    img_id: int,
    img: np.ndarray,
    tags: list[str],
) -> None:
    '''
    Filter, resize, and save an image plus its tag caption.
    '''
    # normalize to [0, 1]; np.iinfo sees the original integer dtype because
    # the right-hand side is evaluated before img is reassigned
    img = img.astype(np.float32) / np.iinfo(img.dtype).max
    if min(img.shape[:2]) < MIN_RES:
        return None
    if img.shape[0] * img.shape[1] > MAX_PIXEL:
        return None
    if img.shape[-1] == 4:
        # composite RGBA over a white background
        alpha = img[:, :, -1][:, :, np.newaxis]
        img = (1 - alpha) * 1 + alpha * img[:, :, :-1]
    if len(img.shape) < 3 or img.shape[-1] == 1:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    if min(img.shape[:2]) > TARGET_RES:
        img = rescale(img, TARGET_RES)
    output_dir = Path(OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)
    img_path = output_dir / f'{img_id}.webp'
    # img is already BGR, which is the channel order imwrite expects
    cv2.imwrite(
        str(img_path),
        (img * 255).astype("uint8"),
        [int(cv2.IMWRITE_WEBP_QUALITY), WEBP_QUALITY],
    )
    with open(output_dir / f'{img_id}.txt', "w") as f:
        tags = ", ".join(tags).replace("_", " ").strip()
        f.write(tags)

async def get_image(
    img_id: int,
    client_session: AsyncClient,
):
    '''
    Fetch a post's metadata and its image.
    Returns (image, tag list) on success, (None, reason) on failure.
    '''
    url = (
        f'https://danbooru.donmai.us/posts/{img_id}.json'
        f'?login={DANBOORU_USERNAME}&api_key={DANBOORU_API_KEY}'
    )
    try:
        res = await client_session.get(url)
        if res.status_code == 404:
            return None, '404'
        reason = res.status_code
        success = res.status_code == 200
    except HTTPError:
        reason = traceback.format_exc()
        success = False
    if not success:
        return None, reason

    res = json.loads(res.text)
    if res["file_ext"] not in ["jpg", "png", "webp"]:
        return None, 'not image'
    # prefer danbooru's own file_url; fall back to the pixiv original,
    # which needs the pixiv referer headers to download
    img_url = None
    if 'file_url' in res:
        img_url = res["file_url"] or img_url
    elif 'source' in res and 'i.pximg.net' in res['source']:
        img_url = res['source'] or img_url
    if img_url is None:
        return None, 'no img url'

    tags = res["tag_string"].split()
    tags = [tag for tag in tags if tag not in bad_tags]
    for tag in banned_tags:
        if tag in tags:
            return None, 'banned tag'

    try:
        img_res = await client_session.get(img_url, headers=headers_pixiv)
        if img_res.status_code == 404:
            return None, 'img url 404'
        reason = f'img download {img_res.status_code}'
        success = img_res.status_code == 200
    except HTTPError:
        reason = traceback.format_exc()
        success = False
    if not success:
        return None, reason

    img = cv2.imdecode(np.frombuffer(img_res.read(), np.uint8), cv2.IMREAD_UNCHANGED)
    return img, tags

async def download_image(
    img_id: int,
    client_session: AsyncClient,
    event_loop: asyncio.AbstractEventLoop | None = None,
    semaphore: asyncio.Semaphore | None = None,
    proc_pool: ProcessPoolExecutor | None = None,
    progress_bar: tqdm | None = None,
):
    try:
        # retry with a fixed backoff; the for/else gives up (returns)
        # only when every attempt failed
        for _ in range(RETRIES):
            if semaphore is not None:
                async with semaphore:
                    img, tags = await get_image(img_id, client_session)
            else:
                img, tags = await get_image(img_id, client_session)
            if img is None:
                await asyncio.sleep(TIMEOUT * 2)
            else:
                break
        else:
            return
        # saving is CPU-bound, so hand it to the process pool when available
        if event_loop is not None and proc_pool is not None:
            await event_loop.run_in_executor(proc_pool, save_img, img_id, img, tags)
        else:
            save_img(img_id, img, tags)
    except Exception as e:
        print(e)
        print(traceback.format_exc())
    finally:
        if progress_bar is not None:
            progress_bar.update()

async def test():
    test_id = 6_600_000
    async with AsyncClient() as client_session:
        await download_image(test_id, client_session)


async def test_batch(size=32):
    test_id = 6_600_000
    test_end = test_id - size
    semaphore = asyncio.Semaphore(32)
    pool = ProcessPoolExecutor(16)
    loop = asyncio.get_running_loop()
    pbar = tqdm(total=test_id - test_end, desc='download danbooru', leave=False)
    async with AsyncClient(timeout=Timeout(timeout=TIMEOUT)) as client_session:
        tasks = [
            download_image(
                i, client_session, loop,
                semaphore=semaphore, proc_pool=pool, progress_bar=pbar,
            )
            for i in range(test_id, test_end, -1)
        ]
        await asyncio.gather(*tasks)
    print()  # move past the tqdm bar

async def main(start_id, end_id):
    semaphore = asyncio.Semaphore(DOWNLOAD_WORKERS)
    pool = ProcessPoolExecutor(PROCESS_WORKERS)
    loop = asyncio.get_running_loop()
    pbar = tqdm(total=start_id - end_id, desc='download danbooru', leave=False)
    async with AsyncClient(timeout=Timeout(timeout=TIMEOUT)) as client_session:
        tasks = [
            download_image(
                i, client_session, loop,
                semaphore=semaphore, proc_pool=pool, progress_bar=pbar,
            )
            for i in range(start_id, end_id, -1)
        ]
        await asyncio.gather(*tasks)
    print()


if __name__ == '__main__':
    # crawl post IDs 6,600,000 down to (but not including) 5,000,000
    asyncio.run(main(6_600_000, 5_000_000))
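For a quick smoke test before committing to a multi-million-post crawl, the bundled test helpers can be driven the same way; a minimal sketch, assuming the API credentials at the top are filled in (the batch size of 64 is an arbitrary placeholder):

# hypothetical driver using the helpers defined above
import asyncio

asyncio.run(test())          # fetch and save a single post (id 6_600_000)
asyncio.run(test_batch(64))  # fetch a small batch of 64 consecutive posts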