Apache Beam Dask Limitation MRE

import warnings
import time
from contextlib import contextmanager

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.dask.dask_runner import DaskRunner
from dask.distributed import Client
from distributed.versions import VersionMismatchWarning
from dask_kubernetes.operator import KubeCluster

# Reduce output noise
for warn_type in [VersionMismatchWarning, FutureWarning]:
    warnings.filterwarnings("ignore", category=warn_type)


@contextmanager
def daskcluster():
    """Get a Dask cluster however you prefer."""
    n_workers = 256
    with KubeCluster(
        name="beam-test",
        n_workers=n_workers,
        env={"EXTRA_PIP_PACKAGES": "apache-beam"},
        resources={
            "requests": {"cpu": "500m", "memory": "1Gi"},
            "limits": {"cpu": "1000m", "memory": "1.85Gi"},
        },
        shutdown_on_close=False,  # Leave running for reuse next time
    ) as cluster:
        cluster.scale(
            n_workers
        )  # Ensure the right number of workers if reusing a cluster
        print(f"Dashboard at: {cluster.dashboard_link}")
        with Client(cluster) as client:
            print(f"Waiting for all {n_workers} workers")
            client.wait_for_workers(n_workers=n_workers)
            yield client


class NoopDoFn(beam.DoFn):
    def process(self, item):
        time.sleep(10)
        return [item]


def main() -> None:
    # If this is 199 I get one task per item per stage; if this is 200 I get at most 100 tasks per stage
    n_items = 200
    with daskcluster() as client:
        # Start a beam pipeline with a dask backend, and its options.
        print("Running Pipeline")
        pipeline = beam.Pipeline(
            runner=DaskRunner(),
            options=PipelineOptions(
                ["--dask_client_address", client.cluster.scheduler_address]
            ),
        )
        (
            pipeline
            | "Create collection" >> beam.Create(range(n_items))
            | "Noop 1" >> beam.ParDo(NoopDoFn())
            | "Noop 2" >> beam.ParDo(NoopDoFn())
            | "Noop 3" >> beam.ParDo(NoopDoFn())
        )
        result = pipeline.run()
        result.wait_until_finish()


if __name__ == "__main__":
    main()
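
The comment in main() captures the limitation being reproduced: with 199 items each stage fans out to one task per item, but with 200 items each stage appears capped at around 100 tasks. If you don't have a Kubernetes cluster to hand, a local variant of the same pipeline should show the same behaviour on the Dask dashboard. This is a minimal sketch, assuming apache-beam and dask.distributed are installed locally; the worker count, shorter sleep, and standalone NoopDoFn here are illustrative and not part of the original gist.

import time

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.dask.dask_runner import DaskRunner
from dask.distributed import Client, LocalCluster


class NoopDoFn(beam.DoFn):
    """Same no-op step as above: sleep briefly, then pass the item through."""

    def process(self, item):
        time.sleep(1)
        return [item]


if __name__ == "__main__":
    # A LocalCluster stands in for KubeCluster; no Kubernetes required.
    with LocalCluster(n_workers=4, threads_per_worker=1) as cluster, Client(cluster) as client:
        print(f"Dashboard at: {cluster.dashboard_link}")
        pipeline = beam.Pipeline(
            runner=DaskRunner(),
            options=PipelineOptions(
                ["--dask_client_address", cluster.scheduler_address]
            ),
        )
        (
            pipeline
            | "Create collection" >> beam.Create(range(200))  # 200 items is where the fan-out cap shows up
            | "Noop" >> beam.ParDo(NoopDoFn())
        )
        # Watch the dashboard while this runs to count the tasks per stage.
        pipeline.run().wait_until_finish()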