Github user ueshin commented on a diff in the pull request:
https://github.com/apache/spark/pull/19349#discussion_r141250227
--- Diff: python/pyspark/serializers.py ---
@@ -211,33 +212,37 @@ def __repr__(self):
return "ArrowSerializer"
+def _create_batch(series):
+ import pyarrow as pa
+ # Make input conform to [(series1, type1), (series2, type2), ...]
+ if not isinstance(series, (list, tuple)) or \
+ (len(series) == 2 and isinstance(series[1], pa.DataType)):
+ series = [series]
+ series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
+
+ # If a nullable integer series has been promoted to floating point with NaNs, need to cast
+ # NOTE: this is not necessary with Arrow >= 0.7
+ def cast_series(s, t):
+ if t is None or s.dtype == t.to_pandas_dtype():
+ return s
+ else:
+ return s.fillna(0).astype(t.to_pandas_dtype(), copy=False)
+
+ arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series]
+ return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
+
+
class ArrowPandasSerializer(ArrowSerializer):
--- End diff --
Thanks! I'll remove it.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]