Last active
April 7, 2025 05:26
-
-
Save jwickens/7be655d478f546f8262de0037a70b7ce to your computer and use it in GitHub Desktop.
PyTorch-like iterable Dataset example backed by an async iterator (PostgreSQL)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncpg | |
async def create_pool(
    database="research",
    user="jwickens",
    min_size=32,
    max_size=32,
    setup=None,
    **kwargs,
):
    """Create and return an asyncpg connection pool.

    Called with no arguments, this reproduces the original fixed settings:
    database "research", user "jwickens", and exactly 32 connections.

    Args:
        database: Name of the database to connect to.
        user: Database user name.
        min_size: Minimum number of connections kept in the pool.
        max_size: Maximum number of connections in the pool.
        setup: Coroutine function passed to ``asyncpg.create_pool`` as
            ``setup``. Defaults to ``setup_connection``.
        **kwargs: Any further keyword arguments (e.g. ``host``, ``password``)
            are forwarded unchanged to ``asyncpg.create_pool``.

    Returns:
        The initialised pool.
    """
    if setup is None:
        # Looked up at call time: setup_connection is defined after this
        # function, so it cannot be used as a default argument value.
        setup = setup_connection
    pool = await asyncpg.create_pool(
        database=database,
        user=user,
        setup=setup,
        min_size=min_size,
        max_size=max_size,
        **kwargs,
    )
    return pool
async def setup_connection(connection):
    """Set the connection's search_path to the ``endofday`` schema.

    create_pool hands this coroutine to the pool as its ``setup`` hook.
    Unqualified table names in later queries on this connection (such as
    ``stock_data``) therefore resolve against ``endofday``.
    """
    statement = "set search_path to endofday"
    await connection.execute(statement)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import datetime, date | |
from torch.utils.data import DataLoader | |
from typing import NamedTuple, List, Optional, AsyncIterator, Iterator | |
from enum import Enum | |
from asyncpg.pool import Pool | |
from asyncio import AbstractEventLoop | |
import asyncio | |
from utils import wrap_async_iter | |
class OHLCV(NamedTuple):
    """Open, high, low, close and volume values for one row of stock data.

    The field order matches the column order that get_sequence selects after
    ``date``, so a row tail can be unpacked positionally into this tuple.
    """

    # Price-like columns.
    open: float
    high: float
    low: float
    close: float
    # Integer volume column.
    volume: int
class DatedOHLCV(NamedTuple):
    """An OHLCV record paired with the date of the row it came from."""

    # Annotated with datetime.date (the field name only shadows it as an
    # attribute, not at annotation time).
    date: date
    ohlcv: OHLCV
class OHLCVSequence(NamedTuple):
    """A ticker symbol together with a run of dated OHLCV records.

    Ordering is not enforced here; get_sequence fills ``sequence`` in
    ascending ``date_id`` order.
    """

    ticker: str
    sequence: List[DatedOHLCV]
class OHLCVSequenceDataset:
    """Iterable dataset yielding batches of OHLCV sequences read from Postgres.

    Start points ``(ticker, ticker_id, date_id)`` are streamed from a
    TABLESAMPLE'd sample table joined with ``stock_data``. For each start
    point, up to ``sequence_length`` rows of the same ticker are fetched,
    beginning at that ``date_id``.

    Batches are lists of up to ``batch_size`` sequences. They are available
    through ``async for`` or through a plain ``for``; the latter drives the
    async iterator on ``self.loop``.
    """

    class Type(Enum):
        """Which sample table start points are drawn from."""

        training = 'training_sample'
        test = 'test_sample'

    batch_size: int
    sequence_length: int
    percent_sample: float
    pool: Pool
    loop: AbstractEventLoop
    type: Type

    def __init__(self,
                 type: Type,
                 loop: AbstractEventLoop,
                 pool: Pool,
                 percent_sample: float = 100,
                 sequence_length: int = 32,
                 batch_size: int = 32
                 ):
        """Store the dataset configuration.

        Args:
            type: ``Type`` member naming the sample table to draw from.
            loop: Event loop that ``__iter__`` schedules the async
                iteration on.
            pool: asyncpg pool that connections are acquired from.
            percent_sample: Percentage passed to ``TABLESAMPLE SYSTEM``.
            sequence_length: Maximum number of rows per sequence.
            batch_size: Number of start points gathered per yielded batch.
        """
        super().__init__()
        self.type = type
        self.loop = loop
        self.batch_size = batch_size
        self.pool = pool
        self.percent_sample = percent_sample
        self.sequence_length = sequence_length

    async def iter_start_point_cursor(self):
        """Stream ``(ticker, ticker_id, date_id)`` records for sampled rows.

        One pool connection is held, inside a transaction, for the whole
        iteration because the rows come from a cursor.
        """
        # The table name comes from the Type enum, never from free-form
        # input, so interpolating it is safe. Identifiers cannot be bound
        # as parameters, but the sample percentage can ($1).
        table = self.type.value
        query = f"""
            SELECT ticker, ticker_id, date_id
            FROM stock_data, {table} TABLESAMPLE SYSTEM ($1)
            WHERE stock_data.id = {table}.stock_data_id
        """
        async with self.pool.acquire() as connection:
            async with connection.transaction():
                # float() so the bound value always suits the real-typed
                # TABLESAMPLE argument, even if an int was passed.
                cursor = connection.cursor(query, float(self.percent_sample))
                async for record in cursor:
                    yield record

    async def get_sequence(self, start_point) -> OHLCVSequence:
        """Fetch up to ``sequence_length`` rows starting at ``start_point``.

        Args:
            start_point: A ``(ticker, ticker_id, date_id)`` record as
                produced by ``iter_start_point_cursor``.

        Returns:
            An ``OHLCVSequence`` for ``ticker`` whose rows have
            ``date_id >= start_point.date_id``, in ascending ``date_id``
            order.
        """
        ticker, ticker_id, date_id = start_point
        async with self.pool.acquire() as connection:
            # Bound parameters keep values out of the SQL text. The query
            # string is therefore constant and asyncpg can reuse its
            # prepared statement instead of preparing a new one per call.
            rows = await connection.fetch(
                """
                SELECT date, open, high, low, close, volume
                FROM stock_data
                WHERE stock_data.ticker_id = $1
                  AND stock_data.date_id >= $2
                ORDER BY stock_data.date_id
                LIMIT $3
                """,
                ticker_id,
                date_id,
                int(self.sequence_length),
            )
        # Column 0 is the date; the remaining columns map onto OHLCV fields.
        sequence = [
            DatedOHLCV(date=row[0], ohlcv=OHLCV(*row[1:]))
            for row in rows
        ]
        return OHLCVSequence(ticker=ticker, sequence=sequence)

    async def _fetch_batch(self, start_points) -> List[OHLCVSequence]:
        # Fetch all sequences of one batch concurrently. gather preserves
        # input order, so batch[i] corresponds to start_points[i].
        # NOTE(review): the cursor in iter_start_point_cursor holds one pool
        # connection meanwhile, so with a pool of batch_size connections one
        # of these fetches waits for another to finish.
        return await asyncio.gather(*map(self.get_sequence, start_points))

    async def __aiter__(self) -> AsyncIterator[List[OHLCVSequence]]:
        """Yield lists of up to ``batch_size`` sequences.

        Every batch is full except possibly the last. An empty sample
        yields nothing.
        """
        pending: List = []
        async for start_point in self.iter_start_point_cursor():
            pending.append(start_point)
            if len(pending) == self.batch_size:
                yield await self._fetch_batch(pending)
                pending = []
        if pending:
            yield await self._fetch_batch(pending)

    def __iter__(self) -> Iterator[List[OHLCVSequence]]:
        """Iterate synchronously by running ``__aiter__`` on ``self.loop``.

        wrap_async_iter blocks the calling thread while waiting for items,
        so ``self.loop`` is expected to be running in a different thread.
        """
        return wrap_async_iter(self, self.loop)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataset import OHLCVSequenceDataset | |
from torch.utils.data import DataLoader | |
from datetime import datetime | |
from database import create_pool | |
import threading | |
import asyncio | |
# Create a dedicated asyncio loop and run it in a background thread to
# serve our asyncio needs. asyncio.get_event_loop() without a running loop
# is deprecated (and an error in newer Python), so create one explicitly.
loop = asyncio.new_event_loop()
loop_thread = threading.Thread(target=loop.run_forever, daemon=True)
loop_thread.start()

pool = asyncio.run_coroutine_threadsafe(create_pool(), loop=loop).result()

start = datetime.now()
d = OHLCVSequenceDataset(
    type=OHLCVSequenceDataset.Type.test,
    percent_sample=0.1,
    loop=loop,
    pool=pool)

# `batches` counts batches received so far (1-based), so the milestone
# messages fire after exactly 1, 2 and 10 batches.
batches = 0
for batches, x in enumerate(d, start=1):
    if batches == 1:
        print(f"first batch in {datetime.now() - start}")
        print(len(x))              # sequences in the batch
        print(len(x[0].sequence))  # rows in the first sequence
        print(x[0].ticker)
        print(x[0].sequence[0].date)
        print(x[0].sequence[0].ohlcv)
    elif batches == 2:
        print(f"second batches in {datetime.now() - start}")
    elif batches == 10:
        print(f"10 batches in {datetime.now() - start}")
print(f"{batches} batches in {datetime.now() - start}")

# Iteration has finished, so every pool connection has been released.
# Close the pool, then stop and close the background loop.
asyncio.run_coroutine_threadsafe(pool.close(), loop=loop).result()
loop.call_soon_threadsafe(loop.stop)
loop_thread.join()
loop.close()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://stackoverflow.com/a/55164899
def wrap_async_iter(ait, loop):
    """Wrap an asynchronous iterator into a synchronous one.

    ``ait`` is consumed by a coroutine scheduled on ``loop`` via
    ``asyncio.run_coroutine_threadsafe``, and items are handed to the caller
    through a thread-safe queue. The returned generator blocks on that queue,
    so ``loop`` must be running in a different thread from the consumer.

    If ``ait`` raises, the exception is re-raised from the returned iterator
    after all items produced before the failure have been yielded. Closing
    the returned iterator early cancels the background consumption.
    """
    # This module never imported these at top level, which made every call
    # fail with NameError; import them locally.
    import asyncio
    import queue

    q = queue.Queue()
    _END = object()  # sentinel: the async iteration has finished

    async def aiter_to_queue():
        try:
            async for item in ait:
                q.put(item)
        finally:
            # Always signal the end, even on error or cancellation, so the
            # consumer never blocks forever on q.get().
            q.put(_END)

    def yield_queue_items():
        try:
            while True:
                next_item = q.get()
                if next_item is _END:
                    break
                yield next_item
            # After observing _END we know the aiter_to_queue coroutine has
            # completed. Invoke result() for side effect - if an exception
            # was raised by the async iterator, it will be propagated here.
            async_result.result()
        finally:
            # No-op if the coroutine already finished. If the consumer
            # stopped early, stop draining `ait` in the background.
            async_result.cancel()

    async_result = asyncio.run_coroutine_threadsafe(aiter_to_queue(), loop)
    return yield_queue_items()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment