Github user BryanCutler commented on a diff in the pull request:
https://github.com/apache/spark/pull/19459#discussion_r145291702
--- Diff: python/pyspark/sql/session.py ---
@@ -414,6 +415,39 @@ def _createFromLocal(self, data, schema):
             data = [schema.toInternal(row) for row in data]
         return self._sc.parallelize(data), schema
 
+    def _createFromPandasWithArrow(self, pdf, schema):
+        """
+        Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
+        to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
+        data types will be used to coerce the data in Pandas to Arrow conversion.
+        """
+        from pyspark.serializers import ArrowSerializer
+        from pyspark.sql.types import from_arrow_schema, to_arrow_schema
+        import pyarrow as pa
+
+        # Slice the DataFrame into batches
+        step = -(-len(pdf) // self.sparkContext.defaultParallelism)  # round int up
+        pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))
+        arrow_schema = to_arrow_schema(schema) if schema is not None else None
+        batches = [pa.RecordBatch.from_pandas(pdf_slice, schema=arrow_schema, preserve_index=False)
+                   for pdf_slice in pdf_slices]
+
+        # Verify schema, there will be at least 1 batch from pandas.DataFrame
+        schema_from_arrow = from_arrow_schema(batches[0].schema)
+        if schema is not None and schema != schema_from_arrow:
+            raise ValueError("Supplied schema does not match result from Arrow\nsupplied: " +
+                             "%s\n!=\nfrom Arrow: %s" % (str(schema), str(schema_from_arrow)))
--- End diff --
@ueshin and @HyukjinKwon after thinking about what to do when the schemas are
not equal, I have some concerns:
1. Fall back to `createDataFrame` without Arrow - I implemented this and it
works fine, but there is no logging in Python (afaik), so my concern is that
the fallback happens silently, causing bad performance without the user ever
knowing why (see the first sketch after this list).
2. Cast types using `astype`, similar to `ArrowPandasSerializer.dump_stream`
- The issue I see with that is handling null values when ints have been
promoted to floats. This works fine in `dump_stream` because we are working
with a `pd.Series`, and pyarrow allows us to pass a validity mask that ignores
the filled values. There is no option to pass masks for a `pd.DataFrame`, so I
believe it will try to interpret whatever fill values are there and cause an
error (see the second sketch after this list). I can look into this more,
though.
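
For (1), a rough sketch of what the fallback could look like inside
`SparkSession` (the wrapper name is hypothetical, and `warnings.warn` stands
in for the Python-side logging we don't have):

```python
import warnings

def _createFromPandasWithFallback(self, pdf, schema):
    # Hypothetical wrapper around the method in this diff: try the Arrow
    # path first and, on a schema mismatch, fall back to the existing
    # non-Arrow path, warning the user so the slowdown is not silent.
    try:
        return self._createFromPandasWithArrow(pdf, schema)
    except ValueError as e:
        warnings.warn("Arrow conversion failed (%s); falling back to the "
                      "non-Arrow path, which may be much slower." % e)
        # Convert pandas rows to plain lists for the regular local path.
        data = [r.tolist() for r in pdf.to_records(index=False)]
        return self._createFromLocal(data, schema)
```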
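
For (2), a minimal standalone sketch (not part of the PR) contrasting the
`pd.Series` mask path with the `pd.DataFrame` path; the toy data and variable
names are just for illustration:

```python
import numpy as np
import pandas as pd
import pyarrow as pa

# A null in an int column promotes the whole column to float64 in pandas.
s = pd.Series([1, None, 3])   # dtype float64, middle value is NaN

# Series path (as in ArrowPandasSerializer.dump_stream): fill the nulls,
# cast to the target dtype, and pass a validity mask so pyarrow treats
# the filled slots as null rather than real data.
mask = s.isnull().values
arr = pa.Array.from_pandas(s.fillna(0).astype(np.int64), mask=mask,
                           type=pa.int64())
print(arr.null_count)  # 1 -- the filled 0 is masked out

# DataFrame path: RecordBatch.from_pandas accepts a schema but no masks,
# so there is no way to tell pyarrow which filled values are really null.
df = pd.DataFrame({"x": s})
int_schema = pa.schema([pa.field("x", pa.int64())])
# pa.RecordBatch.from_pandas(df.fillna(0).astype(np.int64),
#                            schema=int_schema, preserve_index=False)
# would turn the null into a real 0, which is the concern above.
```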
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]