@KohakuBlueleaf · Created August 29, 2023 05:02
A danbooru crawler for making large datasets
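To use it, fill in DANBOORU_USERNAME and DANBOORU_API_KEY (and adjust the constants at the top as needed), then run the file: the __main__ block as written walks post IDs downward from 6,600,000 to 5,000,000, fetches each post's metadata from the Danbooru JSON API, filters by tags and resolution, and writes a resized <id>.webp plus an <id>.txt caption into OUTPUT_DIR.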
import asyncio
import cv2
import json
import traceback
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
import numpy as np
from httpx import AsyncClient, HTTPError, Timeout
from tqdm import tqdm
TARGET_RES = 1024        # downscale so the shorter side is at most this many pixels
MIN_RES = 512            # skip images whose shorter side is below this
MAX_PIXEL = 30_000_000   # skip images with more total pixels than this
OUTPUT_DIR = "./6600K-1600K"
WEBP_QUALITY = 90
TIMEOUT = 2.5            # httpx timeout in seconds; also the base of the retry back-off
RETRIES = 5
DOWNLOAD_WORKERS = 64    # max concurrent downloads (semaphore size)
PROCESS_WORKERS = 32     # worker processes for decode/resize/encode
DANBOORU_USERNAME = ""   # your Danbooru login name
DANBOORU_API_KEY = ""    # your Danbooru API key
headers_pixiv = {
    "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.0.0 Safari/537.36",
    "referer": "https://www.pixiv.net/",
}
# Posts carrying any of these tags are skipped entirely.
banned_tags = [
    "furry", "realistic", "3d", "1940s_(style)", "1950s_(style)", "1960s_(style)",
    "1970s_(style)", "1980s_(style)", "1990s_(style)", "retro_artstyle",
    "screentones", "pixel_art", "magazine_scan", "scan",
]
# Meta tags that are dropped from the caption but do not disqualify a post.
bad_tags = [
    "absurdres", "jpeg_artifacts", "highres", "translation_request", "translated",
    "commentary", "commentary_request", "commentary_typo", "character_request",
    "bad_id", "bad_link", "bad_pixiv_id", "bad_twitter_id", "bad_tumblr_id",
    "bad_deviantart_id", "bad_nicoseiga_id", "md5_mismatch", "cosplay_request",
    "artist_request", "wide_image", "author_request", "artist_name",
]
def rescale(image: np.ndarray, output_size: int) -> np.ndarray:
    """Resize so the shorter side equals output_size, keeping the aspect ratio."""
    h, w = image.shape[:2]
    r = max(output_size / h, output_size / w)
    new_h, new_w = int(h * r), int(w * r)
    return cv2.resize(image, (new_w, new_h))
def save_img(img_id: int, img: np.ndarray, tags: list[str]) -> None:
    """Filter, resize, and save the image as webp plus its tags as a txt caption."""
    # Normalize to float32 in [0, 1] regardless of the source bit depth.
    img = img.astype(np.float32) / np.iinfo(img.dtype).max
    if min(img.shape[:2]) < MIN_RES:
        return None
    if img.shape[0] * img.shape[1] > MAX_PIXEL:
        return None
    if img.shape[-1] == 4:
        # Composite the alpha channel onto a white background.
        alpha = img[:, :, -1][:, :, np.newaxis]
        img = (1 - alpha) * 1 + alpha * img[:, :, :-1]
    if len(img.shape) < 3 or img.shape[-1] == 1:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    if min(img.shape[:2]) > TARGET_RES:
        img = rescale(img, TARGET_RES)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    output_dir = Path(OUTPUT_DIR)
    output_dir.mkdir(exist_ok=True)
    img_path = output_dir / f"{img_id}.webp"
    cv2.imwrite(
        str(img_path),
        cv2.cvtColor((img * 255).astype("uint8"), cv2.COLOR_RGB2BGR),
        [int(cv2.IMWRITE_WEBP_QUALITY), WEBP_QUALITY],
    )
    with open(output_dir / f"{img_id}.txt", "w") as f:
        caption = ", ".join(tags).replace("_", " ").strip()
        f.write(caption)
async def get_image(
    img_id: int,
    client_session: AsyncClient,
):
    """Fetch post metadata from the Danbooru API, then download and decode the image.

    Returns (image, tags) on success, or (None, reason) on failure.
    """
    url = f"https://danbooru.donmai.us/posts/{img_id}.json?login={DANBOORU_USERNAME}&api_key={DANBOORU_API_KEY}"
    try:
        res = await client_session.get(url)
        if res.status_code == 404:
            return None, "404"
        reason = res.status_code
        success = res.status_code == 200
    except HTTPError:
        reason = traceback.format_exc()
        success = False
    if not success:
        return None, reason

    post = json.loads(res.text)
    if post["file_ext"] not in ["jpg", "png", "webp"]:
        return None, "not image"

    img_url = None
    if "file_url" in post:
        img_url = post["file_url"] or img_url
    elif "source" in post and "i.pximg.net" in post["source"]:
        img_url = post["source"] or img_url
    if img_url is None:
        return None, "no img url"

    tags = post["tag_string"].split()
    tags = [tag for tag in tags if tag not in bad_tags]
    for tag in banned_tags:
        if tag in tags:
            return None, "banned tag"

    try:
        # The pixiv referer in headers_pixiv lets i.pximg.net fallback URLs be fetched.
        img_res = await client_session.get(img_url, headers=headers_pixiv)
        if img_res.status_code == 404:
            return None, "img url 404"
        reason = f"img download {img_res.status_code}"
        success = img_res.status_code == 200
    except HTTPError:
        reason = traceback.format_exc()
        success = False
    if not success:
        return None, reason

    img = cv2.imdecode(np.frombuffer(img_res.read(), np.uint8), cv2.IMREAD_UNCHANGED)
    return img, tags
async def download_image(
    img_id: int,
    client_session: AsyncClient,
    event_loop: asyncio.AbstractEventLoop | None = None,
    semaphore: asyncio.Semaphore | None = None,
    proc_pool: ProcessPoolExecutor | None = None,
    progress_bar: tqdm | None = None,
):
    """Download one post with retries, then hand the image off for processing/saving."""
    try:
        for _ in range(RETRIES):
            if semaphore is not None:
                async with semaphore:
                    img, tags = await get_image(img_id, client_session)
            else:
                img, tags = await get_image(img_id, client_session)
            if img is None:
                await asyncio.sleep(TIMEOUT * 2)
            else:
                break
        else:
            # Every attempt failed (or the post was filtered out); skip this id.
            return
        if event_loop is not None and proc_pool is not None:
            # Run the CPU-bound resize/encode work in a worker process.
            await event_loop.run_in_executor(proc_pool, save_img, img_id, img, tags)
        else:
            save_img(img_id, img, tags)
    except Exception as e:
        print(e)
        print(traceback.format_exc())
    finally:
        if progress_bar is not None:
            progress_bar.update()
async def test():
    test_id = 6_600_000
    async with AsyncClient() as client_session:
        await download_image(test_id, client_session)


async def test_batch(size=32):
    test_id = 6_600_000
    test_end = 6_600_000 - size
    semaphore = asyncio.Semaphore(32)
    pool = ProcessPoolExecutor(16)
    loop = asyncio.get_running_loop()
    pbar = tqdm(total=test_id - test_end, desc="download danbooru", leave=False)
    async with AsyncClient(timeout=Timeout(timeout=TIMEOUT)) as client_session:
        tasks = [
            download_image(
                i, client_session, loop,
                semaphore=semaphore, proc_pool=pool, progress_bar=pbar
            )
            for i in range(test_id, test_end, -1)
        ]
        await asyncio.gather(*tasks)
    print()
async def main(start_id, end_id):
    semaphore = asyncio.Semaphore(DOWNLOAD_WORKERS)
    pool = ProcessPoolExecutor(PROCESS_WORKERS)
    loop = asyncio.get_running_loop()
    pbar = tqdm(total=start_id - end_id, desc="download danbooru", leave=False)
    async with AsyncClient(timeout=Timeout(timeout=TIMEOUT)) as client_session:
        # Walk post IDs downward from start_id toward end_id (exclusive).
        tasks = [
            download_image(
                i, client_session, loop,
                semaphore=semaphore, proc_pool=pool, progress_bar=pbar
            )
            for i in range(start_id, end_id, -1)
        ]
        await asyncio.gather(*tasks)
    print()


if __name__ == '__main__':
    asyncio.run(main(6_600_000, 5_000_000))
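
Each kept post ends up as an <id>.webp image and an <id>.txt caption (comma-separated tags with underscores replaced by spaces) in OUTPUT_DIR. A minimal sketch of reading one pair back, assuming a crawl has already produced output in ./6600K-1600K:

import cv2
from pathlib import Path

out = Path("./6600K-1600K")                    # OUTPUT_DIR used above
sample = next(out.glob("*.webp"))              # pick any saved image
image = cv2.imread(str(sample))                # BGR uint8 array
caption = sample.with_suffix(".txt").read_text()
print(sample.name, image.shape, caption[:80])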