Convert an RDD of pandas DataFrames to a single Spark DataFrame using Arrow, without collecting all data in the driver.
import pandas as pd

def _dataframe_to_arrow_record_batch(pdf, schema=None, timezone=None, parallelism=1):
    """
    Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
    to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
    data types will be used to coerce the data in Pandas to Arrow conversion.
    """
    from pyspark.serializers import ArrowSerializer, _create_batch
    from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType, Row, DataType, StringType, StructType
    from pyspark.sql.utils import require_minimum_pandas_version, \
        require_minimum_pyarrow_version

    require_minimum_pandas_version()
    require_minimum_pyarrow_version()

    from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype

    # Determine arrow types to coerce data when creating batches
    if isinstance(schema, StructType):
        arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
    elif isinstance(schema, DataType):
        raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
    else:
        # Any timestamps must be coerced to be compatible with Spark
        arrow_types = [to_arrow_type(TimestampType())
                       if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
                       for t in pdf.dtypes]

    # Slice the DataFrame to be batched
    step = -(-len(pdf) // parallelism)  # round int up
    pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))

    # Create Arrow record batches
    batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
                             timezone)
               for pdf_slice in pdf_slices]

    return map(bytearray, map(ArrowSerializer().dumps, batches))
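
# For reference: the ceiling-division slicing above cuts the frame into
# `parallelism` roughly equal pieces, e.g. a 10-row frame with parallelism=3
# gives step = -(-10 // 3) == 4 and slices of 4, 4 and 2 rows:
#
#     pdf = pd.DataFrame({"x": range(10)})
#     [len(pdf[s:s + 4]) for s in range(0, 10, 4)]   # -> [4, 4, 2]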

def createFromPandasDataframesRDD(self, prdd, schema=None, timezone=None):
    from pyspark.sql.types import from_arrow_schema
    from pyspark.sql.dataframe import DataFrame
    from pyspark.serializers import ArrowSerializer, PickleSerializer, AutoBatchedSerializer

    # Map rdd of pandas dataframes to arrow record batches
    prdd = prdd.filter(lambda x: isinstance(x, pd.DataFrame)).cache()

    # If schema is not defined, get from the first dataframe
    if schema is None:
        schema = [str(x) if not isinstance(x, basestring) else
                  (x.encode('utf-8') if not isinstance(x, str) else x)
                  for x in prdd.map(lambda x: x.columns).first()]

    prdd = prdd.flatMap(lambda x: _dataframe_to_arrow_record_batch(x, schema=schema, timezone=timezone))

    # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
    struct = from_arrow_schema(ArrowSerializer().loads(prdd.first()).schema)
    for i, name in enumerate(schema):
        struct.fields[i].name = name
        struct.names[i] = name
    schema = struct

    # Create the Spark DataFrame directly from the Arrow data and schema
    jrdd = prdd._to_java_object_rdd()
    jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
        jrdd, schema.json(), self._wrapped._jsqlContext)
    df = DataFrame(jdf, self._wrapped)
    df._schema = schema
    return df
from pyspark.sql import SparkSession
SparkSession.createFromPandasDataframesRDD = createFromPandasDataframesRDD
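
A minimal usage sketch (not part of the gist itself; it assumes a Spark 2.x session where the internal pyspark helpers used above, such as ArrowSerializer, _create_batch and PythonSQLUtils.arrowPayloadToDataFrame, still exist, and the example data is made up for illustration):

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

# Build an RDD whose elements are pandas DataFrames, one per task.
def make_pdf(i):
    return pd.DataFrame({"part": [i] * 5, "value": range(i * 5, i * 5 + 5)})

prdd = sc.parallelize(range(4), 4).map(make_pdf)

# Convert to a single Spark DataFrame without collecting the pandas data on the driver.
df = spark.createFromPandasDataframesRDD(prdd)
df.show()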
@tahashmi

Thank you so much, @linar-jether! It's very helpful.
I want to avoid Pandas conversion because my data is in Arrow RecordBatches on all worker nodes.
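
A rough sketch of that idea, assuming the RDD already holds pyarrow.RecordBatch objects and the same internal Spark 2.x serializer used in the gist (untested; batch_rdd is a hypothetical name):

from pyspark.serializers import ArrowSerializer

def _record_batch_to_payload(batch):
    # Serialize an existing pyarrow.RecordBatch the same way the gist serializes
    # the batches it builds from pandas slices.
    return bytearray(ArrowSerializer().dumps(batch))

# batch_rdd holds pyarrow.RecordBatch objects on the workers.
# payload_rdd = batch_rdd.map(_record_batch_to_payload)
# The schema / PythonSQLUtils.arrowPayloadToDataFrame steps from
# createFromPandasDataframesRDD above then apply unchanged.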

@ashish615

@linar-jether @tahashmi While running your code I am facing the error below:
Traceback (most recent call last):
  File "recordbatch.py", line 48, in <module>
    spark._wrapped._jsqlContext)
AttributeError: 'SparkSession' object has no attribute '_wrapped'
Am I missing something?
