Repository: spark Updated Branches: refs/heads/master 5bed13a87 -> 056883e07
[SPARK-13266] [SQL] None read/writer options were not translated to "null" ## What changes were proposed in this pull request? In Python, the `option` and `options` methods of `DataFrameReader` and `DataFrameWriter` were sending the string "None" instead of `null` when passed `None`, therefore making it impossible to send an actual `null`. This fixes that problem. This is based on #11305 from mathieulongtin. ## How was this patch tested? Added test to readwriter.py. Author: Liang-Chi Hsieh <[email protected]> Author: mathieu longtin <[email protected]> Closes #12494 from viirya/py-df-none-option. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/056883e0 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/056883e0 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/056883e0 Branch: refs/heads/master Commit: 056883e070bd258d193fd4d783ab608a19b86c36 Parents: 5bed13a Author: Liang-Chi Hsieh <[email protected]> Authored: Fri Apr 22 09:19:36 2016 -0700 Committer: Davies Liu <[email protected]> Committed: Fri Apr 22 09:19:36 2016 -0700 ---------------------------------------------------------------------- python/pyspark/sql/readwriter.py | 9 ++++++--- python/pyspark/sql/tests.py | 3 +++ .../spark/sql/execution/datasources/csv/CSVOptions.scala | 6 +++++- 3 files changed, 14 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/056883e0/python/pyspark/sql/readwriter.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 6c809d1..e39cf1a 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -33,10 +33,13 @@ __all__ = ["DataFrameReader", "DataFrameWriter"] def to_str(value): """ - A wrapper over str(), but convert bool values to lower case string + A wrapper over 
str(), but converts bool values to lower case strings. + If None is given, just returns None, instead of converting it to string "None". """ if isinstance(value, bool): return str(value).lower() + elif value is None: + return value else: return str(value) @@ -398,7 +401,7 @@ class DataFrameWriter(object): def option(self, key, value): """Adds an output option for the underlying data source. """ - self._jwrite = self._jwrite.option(key, value) + self._jwrite = self._jwrite.option(key, to_str(value)) return self @since(1.4) @@ -406,7 +409,7 @@ class DataFrameWriter(object): """Adds output options for the underlying data source. """ for k in options: - self._jwrite = self._jwrite.option(k, options[k]) + self._jwrite = self._jwrite.option(k, to_str(options[k])) return self @since(1.4) http://git-wip-us.apache.org/repos/asf/spark/blob/056883e0/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 3b1b294..42e2830 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -859,6 +859,9 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(sorted(df.collect()), sorted(actual.collect())) self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + csvpath = os.path.join(tempfile.mkdtemp(), 'data') + df.write.option('quote', None).format('csv').save(csvpath) + shutil.rmtree(tmpPath) def test_save_and_load_builder(self): http://git-wip-us.apache.org/repos/asf/spark/blob/056883e0/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 7b9d3b6..80a0ad7 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -29,6 +29,7 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str val paramValue = parameters.get(paramName) paramValue match { case None => default + case Some(null) => default case Some(value) if value.length == 0 => '\u0000' case Some(value) if value.length == 1 => value.charAt(0) case _ => throw new RuntimeException(s"$paramName cannot be more than one character") @@ -39,6 +40,7 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str val paramValue = parameters.get(paramName) paramValue match { case None => default + case Some(null) => default case Some(value) => try { value.toInt } catch { @@ -50,7 +52,9 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str private def getBool(paramName: String, default: Boolean = false): Boolean = { val param = parameters.getOrElse(paramName, default.toString) - if (param.toLowerCase == "true") { + if (param == null) { + default + } else if (param.toLowerCase == "true") { true } else if (param.toLowerCase == "false") { false --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
