@KohakuBlueleaf · Created December 22, 2022
A very simple asyncio downloader for the SBU Captions dataset: it reads parallel lists of image URLs and captions from sbu-captions-all.json, downloads the images concurrently with aiohttp, letterboxes them to 256×256, and packs them into WebDataset tar shards of 10,000 samples each.
import os
import math
import asyncio
from json import load
from io import BytesIO
from concurrent.futures import ProcessPoolExecutor

from PIL import Image
from aiohttp import ClientSession, ClientTimeout
from tqdm import tqdm

# webdataset is only needed in the parent process; guarding the import keeps
# worker processes spawned by ProcessPoolExecutor from loading it.
if __name__ == '__main__':
    import webdataset as wds
def load_metadata(file_string):
    """Load the SBU Captions JSON and return (url, caption) pairs."""
    with open(file_string, 'r') as f:
        data = load(f)
    img = data['image_urls']
    cap = data['captions']
    assert len(img) == len(cap)
    return list(zip(img, cap))
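
# Expected layout of sbu-captions-all.json, inferred from the keys read above
# (two parallel lists of equal length):
#
#     {
#         "image_urls": ["http://.../0001.jpg", ...],
#         "captions":   ["an example caption", ...]
#     }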
async def download_task(
    url: str,
    cap: str,
    session: ClientSession,
    queue: asyncio.Queue,
):
    """Fetch one image, retrying up to 3 times on transient errors."""
    for _ in range(3):
        try:
            async with session.get(url) as response:
                match response.status:
                    case 200:
                        data = await response.read()
                        await queue.put((data, cap))
                        return
                    case 404 | 410:
                        # Gone for good; no point retrying.
                        break
                    case _:
                        continue
        except Exception as e:
            print(f'\r{e}', end='\r')
    # All retries failed (or the resource is gone); report a failure.
    await queue.put(None)
async def download_all(
    metas: list[tuple[str, str]],
    queue: asyncio.Queue,
):
    """Download every (url, caption) pair, then signal completion."""
    timeout = ClientTimeout(total=None, sock_read=20, sock_connect=20)
    async with ClientSession(timeout=timeout, read_bufsize=2**20) as session:
        all_task = [download_task(url, cap, session, queue) for url, cap in metas]
        await asyncio.gather(*all_task)
    await queue.put('Finish')
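
# Note: download_all launches one task per URL, i.e. up to 10,000 concurrent
# requests per shard, bounded only by the socket timeouts above. If that
# overwhelms the network, a hypothetical bounded variant (not in the original
# gist) could cap concurrency with a semaphore:
#
#     sem = asyncio.Semaphore(256)
#     async def bounded_task(url, cap, session, queue):
#         async with sem:
#             await download_task(url, cap, session, queue)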
def process_img(data, caption, img_id):
    """Decode, letterbox to 256x256 on white, and re-encode as JPEG."""
    img = Image.open(BytesIO(data))
    img.thumbnail((256, 256), Image.Resampling.BOX)
    img = img.convert('RGB')  # JPEG cannot store RGBA/P-mode images
    x, y = img.size
    if x > y:
        result = Image.new('RGB', (256, 256), (255, 255, 255))
        result.paste(img, (0, (256 - y) // 2))
        img = result
    elif y > x:
        result = Image.new('RGB', (256, 256), (255, 255, 255))
        result.paste(img, ((256 - x) // 2, 0))
        img = result
    img_bin = BytesIO()
    img.save(img_bin, format='jpeg')
    return {
        '__key__': f'sample-{img_id}',
        'jpg': img_bin.getvalue(),
        'text': caption,
    }
async def save_task(
    total_len: int,
    queue: asyncio.Queue,
):
    """Drain the queue, batching CPU-bound image processing onto a pool."""
    loop = asyncio.get_running_loop()
    executor = ProcessPoolExecutor(60)
    failed = 0
    datas: list[tuple[int, bytes, str]] = []
    tar_data = []
    for i in tqdm(range(total_len), total=total_len, leave=False):
        item = await queue.get()
        match item:
            case (bytes(data), str(cap)):
                datas.append((i, data, cap))
            case None:
                failed += 1
            case 'Finish':
                break
        # Hand batches of images to the process pool for resizing/encoding.
        if len(datas) > 512:
            tar_data += await asyncio.gather(*(
                loop.run_in_executor(executor, process_img, img, cap, idx)
                for idx, img, cap in datas
            ))
            datas = []
    # Flush whatever is left over after the loop.
    tar_data += await asyncio.gather(*(
        loop.run_in_executor(executor, process_img, img, cap, idx)
        for idx, img, cap in datas
    ))
    executor.shutdown()  # release this shard's worker processes
    return tar_data, failed
async def download_shard(metas, shard_id):
    """Download one shard's metadata and write it out as a WebDataset tar."""
    queue = asyncio.Queue()
    asyncio.ensure_future(download_all(metas, queue))
    # `failed` counts URLs that never yielded a usable image
    tar_data, failed = await save_task(len(metas), queue)
    writer = wds.TarWriter(f'./sbucaptions/{shard_id:06}.tar')
    for data in tar_data:
        writer.write(data)
    writer.close()
async def main():
    metas = load_metadata('./sbu-captions-all.json')
    total_len = len(metas)
    shards = math.ceil(total_len / 10000)
    os.makedirs('./sbucaptions', exist_ok=True)
    for i in tqdm(range(shards), total=shards, leave=False):
        await download_shard(metas[i*10000:(i+1)*10000], i)


if __name__ == '__main__':
    asyncio.run(main())
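
For a quick sanity check, the resulting shards can be read back with webdataset's standard reader. This is a minimal sketch, not part of the gist above; it assumes at least one shard was written and uses the 'jpg'/'text' field names from process_img.

import webdataset as wds
from itertools import islice

# decode('pil') turns the 'jpg' field into a PIL.Image; the 'text' field
# may come back as str or bytes depending on the webdataset version's
# default handlers.
dataset = wds.WebDataset('./sbucaptions/000000.tar').decode('pil')
for sample in islice(dataset, 3):
    print(sample['__key__'], sample['jpg'].size, sample['text'][:40])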