Github user viirya commented on a diff in the pull request:
https://github.com/apache/spark/pull/19646#discussion_r148695963
--- Diff: python/pyspark/sql/session.py ---
@@ -512,9 +512,39 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
         except Exception:
             has_pandas = False
         if has_pandas and isinstance(data, pandas.DataFrame):
+            import numpy as np
+
+            # Convert pandas.DataFrame to list of numpy records
+            np_records = data.to_records(index=False)
+
+            # Check if any columns need to be fixed for Spark to infer properly
+            record_type_list = None
+            if schema is None and len(np_records) > 0:
+                cur_dtypes = np_records[0].dtype
+                col_names = cur_dtypes.names
+                record_type_list = []
+                has_rec_fix = False
+                for i in xrange(len(cur_dtypes)):
+                    curr_type = cur_dtypes[i]
+                    # If type is a datetime64 timestamp, convert to microseconds
+                    # NOTE: if dtype is M8[ns] then np.record.tolist() will output values as longs,
+                    # this conversion will lead to an output of py datetime objects, see SPARK-22417
+                    if curr_type == np.dtype('M8[ns]'):
+                        curr_type = 'M8[us]'
+                        has_rec_fix = True
+                    record_type_list.append((str(col_names[i]), curr_type))
+                if not has_rec_fix:
+                    record_type_list = None
--- End diff ---
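For context, the `M8[ns]` NOTE refers to how numpy unboxes datetime64 record fields: at nanosecond resolution `tolist()` yields plain longs, while at microsecond (or coarser) resolution it yields `datetime.datetime` objects. A quick repro (the printed values are from a local run and only illustrative):

```python
import pandas as pd

df = pd.DataFrame({'ts': [pd.Timestamp('2017-10-31 12:00:00')]})
rec = df.to_records(index=False)[0]

print(rec.dtype)     # (numpy.record, [('ts', '<M8[ns]')])
print(rec.tolist())  # (1509451200000000000,)  <- a long, not a datetime

# Re-typing the record to microsecond resolution fixes the unboxing
fixed = rec.astype([('ts', 'M8[us]')])
print(fixed.tolist())  # (datetime.datetime(2017, 10, 31, 12, 0),)
```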
Shall we put this into an internal method?
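For example, something along these lines (just a rough sketch; `_get_numpy_record_dtype` is only a placeholder name):

```python
def _get_numpy_record_dtype(self, rec):
    """
    Given a numpy record, return a corrected dtype list if any fields
    need fixing for Spark to infer types properly, otherwise None.
    """
    import numpy as np
    cur_dtypes = rec.dtype
    col_names = cur_dtypes.names
    record_type_list = []
    has_rec_fix = False
    for i in range(len(cur_dtypes)):
        curr_type = cur_dtypes[i]
        # If type is a datetime64 timestamp, convert to microseconds
        # NOTE: if dtype is M8[ns] then np.record.tolist() will output values as longs,
        # this conversion will lead to an output of py datetime objects, see SPARK-22417
        if curr_type == np.dtype('M8[ns]'):
            curr_type = 'M8[us]'
            has_rec_fix = True
        record_type_list.append((str(col_names[i]), curr_type))
    return record_type_list if has_rec_fix else None
```

Then `createDataFrame` would only need to call it when `schema` is None and there is at least one record, e.g. `record_type_list = self._get_numpy_record_dtype(np_records[0])`.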
---