Created
December 22, 2022 07:07
-
-
Save KohakuBlueleaf/420bb7febecd955aee07380024eef4c0 to your computer and use it in GitHub Desktop.
A very simple asyncio-based downloader for the SBU Captions dataset (sbucaptions)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os, sys | |
from time import time_ns | |
from json import load | |
from io import BytesIO | |
import math | |
from PIL import Image | |
import asyncio | |
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor | |
from aiohttp import ClientSession, ClientTimeout | |
from tqdm import tqdm | |
if __name__ == '__main__': | |
import webdataset as wds | |
def load_metedatas(file_string):
    """Read the sbucaptions metadata JSON and pair each image URL with its caption.

    The file must contain parallel lists under 'image_urls' and 'captions';
    equal length is enforced with an assert.

    Returns:
        list of (url, caption) tuples.
    """
    with open(file_string, 'r') as fp:
        meta = load(fp)
    urls, captions = meta['image_urls'], meta['captions']
    assert len(urls) == len(captions)
    return list(zip(urls, captions))
async def download_task(
    url: str,
    cap: str,
    session: ClientSession,
    queue: asyncio.Queue
):
    """Fetch one image URL (up to 3 attempts) and enqueue the result.

    On HTTP 200 the raw response body and caption are put on the queue as a
    (bytes, str) tuple and the task returns. 404/410 are treated as
    permanently gone (no retry). Any other status, or a network/timeout
    exception, triggers another attempt. If no attempt succeeds, None is
    queued so the consumer still counts this item toward the expected total.
    """
    for _ in range(3):
        try:
            async with session.get(url) as response:
                match response.status:
                    case 200:
                        data = await response.read()
                        await queue.put((data, cap))
                        return
                    case 404|410:
                        # Resource is gone for good -- retrying is pointless.
                        break
                    case _:
                        continue
        except Exception as e:
            # Best-effort log that overwrites itself, to avoid flooding the
            # tqdm progress bar in the consumer.
            print(f'\r{e}', end='\r')
    await queue.put(None)
async def download_all(
    metas: list[tuple[str, str]],
    queue: asyncio.Queue
):
    """Download every (url, caption) pair concurrently, then signal completion.

    One shared ClientSession is used for all requests; sock timeouts are
    bounded (20s) while the total timeout is disabled since the whole batch
    can legitimately run for a long time. When every download_task has
    finished, the 'Finish' sentinel is queued so the consumer can stop early.
    """
    timeout = ClientTimeout(total=None, sock_read=20, sock_connect=20)
    async with ClientSession(timeout=timeout, read_bufsize=2**20) as session:
        all_task = [download_task(url, cap, session, queue) for url, cap in metas]
        await asyncio.gather(*all_task)
    await queue.put('Finish')
def process_img(data, caption, img_id):
    """Decode image bytes, letterbox to 256x256 on white, return a webdataset sample.

    Args:
        data: raw downloaded image bytes.
        caption: caption text, stored under the 'text' key.
        img_id: numeric id used to build the sample key.

    Returns:
        dict with '__key__', 'jpg' (re-encoded JPEG bytes) and 'text'.
    """
    img = Image.open(BytesIO(data))
    # JPEG cannot encode palette/alpha modes (P, RGBA, LA, ...). The original
    # only went through an RGB canvas on the non-square paths, so a square
    # RGBA/P image crashed the final save. Normalise to RGB up front.
    if img.mode != 'RGB':
        img = img.convert('RGB')
    img.thumbnail((256, 256), Image.Resampling.BOX)  # shrink-only, keeps aspect
    x, y = img.size
    if x > y:
        # Landscape: centre vertically on a white 256x256 canvas.
        result = Image.new('RGB', (256, 256), (255, 255, 255))
        result.paste(img, (0, (256 - y) // 2))
        img = result
    elif y > x:
        # Portrait: centre horizontally.
        result = Image.new('RGB', (256, 256), (255, 255, 255))
        result.paste(img, ((256 - x) // 2, 0))
        img = result
    # NOTE(review): square images smaller than 256 are kept at their original
    # size (thumbnail never upscales) -- this matches the original behaviour.
    img_bin = BytesIO()
    img.save(img_bin, format='jpeg')
    return {
        '__key__': f'sample-{img_id}',
        'jpg': img_bin.getvalue(),
        'text': caption
    }
async def save_task( | |
total_len: int, | |
queue: asyncio.Queue | |
): | |
loop = asyncio.get_running_loop() | |
executor = ProcessPoolExecutor(60) | |
failed = 0 | |
datas: list[tuple[int, bytes, str]] = [] | |
tar_data = [] | |
for i in tqdm(range(total_len), total=total_len, leave=False): | |
item = await queue.get() | |
match item: | |
case (bytes(data), str(cap)): | |
datas.append((i, data, cap)) | |
case None: | |
failed += 1 | |
case 'Finish': | |
break | |
if len(datas) > 512: | |
tar_data += await asyncio.gather(*( | |
loop.run_in_executor(executor, process_img, img, cap, idx) | |
for idx, img, cap in datas | |
)) | |
datas = [] | |
tar_data += await asyncio.gather(*( | |
loop.run_in_executor(executor, process_img, img, cap, idx) | |
for idx, img, cap in datas | |
)) | |
return tar_data, failed | |
async def download_shard(metas, shard_id):
    """Download one shard of (url, caption) pairs and write it to a tar file.

    Runs the producer (download_all) as a background task concurrently with
    the consumer (save_task) over a shared queue, then serialises the
    processed samples to ./sbucaptions/{shard_id:06}.tar.
    """
    queue = asyncio.Queue()
    # Producer runs in the background; save_task consumes until it has seen
    # len(metas) items or the 'Finish' sentinel.
    asyncio.ensure_future(download_all(metas, queue))
    tar_data, failed = await save_task(len(metas), queue)
    writer = wds.TarWriter(f'./sbucaptions/{shard_id:06}.tar')
    try:
        for sample in tar_data:
            writer.write(sample)
    finally:
        # The original never closed the writer, so the tar stream could be
        # left unflushed/truncated. Always close it.
        writer.close()
async def main():
    """Split the metadata into 10000-item shards and download them sequentially."""
    # download_shard writes into ./sbucaptions/ -- create it up front so the
    # first TarWriter open does not fail on a fresh checkout.
    os.makedirs('./sbucaptions', exist_ok=True)
    metas = load_metedatas('./sbu-captions-all.json')
    total_len = len(metas)
    shards = math.ceil(total_len / 10000)
    for i in tqdm(range(shards), total=shards, leave=False):
        await download_shard(metas[i*10000:(i+1)*10000], i)
# Script entry point: run the whole download under a single event loop.
if __name__ == '__main__':
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment