Repository: spark Updated Branches: refs/heads/branch-2.3 eb4fa551e -> 551ccfba5
[SPARK-23009][PYTHON] Fix for non-str col names to createDataFrame from Pandas ## What changes were proposed in this pull request? This fixes the case when calling `SparkSession.createDataFrame` using a Pandas DataFrame that has non-str column labels. The column name conversion logic to handle non-string or unicode in Python 2 is: ``` if column is not any type of string: name = str(column) else if column is unicode in Python 2: name = column.encode('utf-8') ``` ## How was this patch tested? Added a new test with a Pandas DataFrame that has int column labels. Author: Bryan Cutler <cutl...@gmail.com> Closes #20210 from BryanCutler/python-createDataFrame-int-col-error-SPARK-23009. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/551ccfba Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/551ccfba Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/551ccfba Branch: refs/heads/branch-2.3 Commit: 551ccfba529996e987c4d2e8d4dd61c4ab9a2e95 Parents: eb4fa55 Author: Bryan Cutler <cutl...@gmail.com> Authored: Wed Jan 10 14:55:24 2018 +0900 Committer: hyukjinkwon <gurwls...@gmail.com> Committed: Thu Jan 11 09:46:50 2018 +0900 ---------------------------------------------------------------------- python/pyspark/sql/session.py | 4 +++- python/pyspark/sql/tests.py | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/551ccfba/python/pyspark/sql/session.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 3e45747..604021c 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -648,7 +648,9 @@ class SparkSession(object): # If no schema supplied by user then get the names of columns only if schema is None: - schema = [x.encode('utf-8') if not isinstance(x, 
str) else x for x in data.columns] + schema = [str(x) if not isinstance(x, basestring) else + (x.encode('utf-8') if not isinstance(x, str) else x) + for x in data.columns] if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \ and len(data) > 0: http://git-wip-us.apache.org/repos/asf/spark/blob/551ccfba/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 13576ff..80a94a9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3532,6 +3532,15 @@ class ArrowTests(ReusedSQLTestCase): self.assertTrue(expected[r][e] == result_arrow[r][e] and result[r][e] == result_arrow[r][e]) + def test_createDataFrame_with_int_col_names(self): + import numpy as np + import pandas as pd + pdf = pd.DataFrame(np.random.rand(4, 2)) + df, df_arrow = self._createDataFrame_toggle(pdf) + pdf_col_names = [str(c) for c in pdf.columns] + self.assertEqual(pdf_col_names, df.columns) + self.assertEqual(pdf_col_names, df_arrow.columns) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class PandasUDFTests(ReusedSQLTestCase): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org