import random

from metaflow import FlowSpec, step, S3, Flow, Parameter, profile, kubernetes, conda, conda_base

# change the columns to match your schema (or remove the column list to load all columns)
COLUMNS = ['VendorID', 'tpep_pickup_datetime', 'tpep_dropoff_datetime']

# group parquet files into batches of roughly 1GB each
def shard_data(src, batch_size=1_000_000_000):
    with S3() as s3:
        objs = s3.list_recursive([src])
        # shuffle the listing, then pack URLs greedily into batches of
        # up to batch_size bytes
        random.shuffle(objs)
        while objs:
            size = 0
            batch = []
            while objs and size < batch_size:
                obj = objs.pop()
                batch.append(obj.url)
                size += obj.size
            yield batch
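# For example, with a hypothetical bucket holding ~250MB parquet files,
# shard_data('s3://my-bucket/nyc-taxi/') would yield lists of roughly four
# URLs each, e.g. [['s3://my-bucket/nyc-taxi/a.parquet', ...], ...]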
@conda_base(python='3.8.10')
class ShardedDataFlow(FlowSpec):

    s3root = Parameter('s3root', help="S3 root for data")

    @step
    def start(self):
        # compute the ~1GB batches and fan out one task per batch
        self.shards = list(shard_data(self.s3root))
        self.next(self.process_shard_arrow, foreach='shards')

    @kubernetes(memory=12000)
    @conda(libraries={'pyarrow': '5.0.0'})
    @step
    def process_shard_arrow(self):
        import pyarrow
        from pyarrow.parquet import ParquetFile

        # download this shard's files in parallel, then decode them with pyarrow;
        # profile() prints the time taken by each block
        self.shard_files = self.input
        with S3() as s3:
            with profile('loading data'):
                objs = s3.get_many(self.shard_files)
            with profile('deserializing parquet'):
                table = pyarrow.concat_tables(
                    [ParquetFile(obj.path).read(columns=COLUMNS) for obj in objs]
                )
        self.arrow_table_len = len(table)
        self.next(self.process_shard_polars)
    @kubernetes(memory=12000)
    @conda(libraries={'polars': '0.16.13'})
    @step
    def process_shard_polars(self):
        import polars

        # decode the same shard with polars for comparison
        self.shard_files = self.input
        with S3() as s3:
            with profile('loading data'):
                objs = s3.get_many(self.shard_files)
            with profile('deserializing polars'):
                table = polars.concat(
                    [polars.read_parquet(obj.path, columns=COLUMNS) for obj in objs]
                )
        print('table', table)
        self.polars_table_len = len(table)
        self.next(self.join)
    @step
    def join(self, inputs):
        # each foreach branch stored its row count as an artifact; sum them up
        print('total rows', sum(inp.arrow_table_len for inp in inputs))
        self.next(self.end)

    @step
    def end(self):
        pass

if __name__ == '__main__':
    ShardedDataFlow()
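
# A minimal sketch of how to run this flow, assuming the file is saved as
# sharded_data.py and the bucket path below is replaced with a real S3
# prefix holding parquet files:
#
#   python sharded_data.py --environment=conda run --s3root s3://my-bucket/nyc-taxi/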
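#
# After a run completes, the per-shard row counts can be read back with the
# Metaflow Client API (Flow is already imported above):
#
#   run = Flow('ShardedDataFlow').latest_run
#   arrow_total = sum(t.data.arrow_table_len for t in run['process_shard_arrow'])
#   polars_total = sum(t.data.polars_table_len for t in run['process_shard_polars'])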