Repository: spark Updated Branches: refs/heads/master 683ffe062 -> 4c5269f1a
[SPARK-22370][SQL][PYSPARK] Config values should be captured in Driver. ## What changes were proposed in this pull request? `ArrowEvalPythonExec` and `FlatMapGroupsInPandasExec` refer to config values of `SQLConf` inside the functions passed to `mapPartitions`/`mapPartitionsInternal`, but we should capture them in the driver. ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN <[email protected]> Closes #19587 from ueshin/issues/SPARK-22370. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4c5269f1 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4c5269f1 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4c5269f1 Branch: refs/heads/master Commit: 4c5269f1aa529e6a397b68d6dc409d89e32685bd Parents: 683ffe0 Author: Takuya UESHIN <[email protected]> Authored: Sat Oct 28 18:33:09 2017 +0100 Committer: Wenchen Fan <[email protected]> Committed: Sat Oct 28 18:33:09 2017 +0100 ---------------------------------------------------------------------- python/pyspark/sql/tests.py | 20 ++++++++++++++++++++ .../spark/sql/catalyst/plans/QueryPlan.scala | 6 ++++++ .../execution/python/ArrowEvalPythonExec.scala | 6 ++++-- .../python/FlatMapGroupsInPandasExec.scala | 3 ++- 4 files changed, 32 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/4c5269f1/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 98afae6..8ed37c9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3476,6 +3476,26 @@ class VectorizedUDFTests(ReusedPySparkTestCase): expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz)) self.assertEquals(expected, ts) + def test_vectorized_udf_check_config(self): + from pyspark.sql.functions import pandas_udf, col + orig_value 
= self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None) + self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3) + try: + df = self.spark.range(10, numPartitions=1) + + @pandas_udf(returnType=LongType()) + def check_records_per_batch(x): + self.assertTrue(x.size <= 3) + return x + + result = df.select(check_records_per_batch(col("id"))) + self.assertEquals(df.collect(), result.collect()) + finally: + if orig_value is None: + self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") + else: + self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedPySparkTestCase): http://git-wip-us.apache.org/repos/asf/spark/blob/4c5269f1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index d21b4af..ddf2cbf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -25,6 +25,12 @@ import org.apache.spark.sql.types.{DataType, StructType} abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] { self: PlanType => + /** + * The active config object within the current scope. + * Note that if you want to refer config values during execution, you have to capture them + * in Driver and use the captured values in Executors. + * See [[SQLConf.get]] for more information. 
+ */ def conf: SQLConf = SQLConf.get def output: Seq[Attribute] http://git-wip-us.apache.org/repos/asf/spark/blob/4c5269f1/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 0db463a..bcda2da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -61,6 +61,9 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int) case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) extends EvalPythonExec(udfs, output, child) { + private val batchSize = conf.arrowMaxRecordsPerBatch + private val sessionLocalTimeZone = conf.sessionLocalTimeZone + protected override def evaluate( funcs: Seq[ChainedPythonFunctions], bufferSize: Int, @@ -73,13 +76,12 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex .map { case (attr, i) => attr.withName(s"_$i") }) - val batchSize = conf.arrowMaxRecordsPerBatch // DO NOT use iter.grouped(). See BatchIterator. 
val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) val columnarBatchIter = new ArrowPythonRunner( funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, conf.sessionLocalTimeZone) + PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, sessionLocalTimeZone) .compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { http://git-wip-us.apache.org/repos/asf/spark/blob/4c5269f1/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index cc93fda..e1e04e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -77,6 +77,7 @@ case class FlatMapGroupsInPandasExec( val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray) val schema = StructType(child.schema.drop(groupingAttributes.length)) + val sessionLocalTimeZone = conf.sessionLocalTimeZone inputRDD.mapPartitionsInternal { iter => val grouped = if (groupingAttributes.isEmpty) { @@ -94,7 +95,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema, conf.sessionLocalTimeZone) + PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema, sessionLocalTimeZone) .compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) 
--------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
