This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 307e46cc4df [SPARK-44532][CONNECT][SQL] Move ArrowUtils to sql/api
307e46cc4df is described below

commit 307e46cc4dfdad1f442e8c5c50ecb53c9ef7dc47
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Tue Jul 25 13:35:37 2023 -0400

    [SPARK-44532][CONNECT][SQL] Move ArrowUtils to sql/api
    
    ### What changes were proposed in this pull request?
    This PR moves `ArrowUtils` to `sql/api`. The one method it contained for configuring the Python Arrow runner has been moved to `ArrowPythonRunner`.
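    
    For illustration, a minimal sketch of the resulting call site. The `ArrowPythonRunner.getPythonRunnerConfMap` and `SQLConf` names are taken from the diff below; the surrounding object and method are hypothetical:
    
    ```scala
    import org.apache.spark.sql.execution.python.ArrowPythonRunner
    import org.apache.spark.sql.internal.SQLConf
    
    // Hypothetical caller in sql/core: previously this would have called
    // ArrowUtils.getPythonRunnerConfMap(conf); after this change the
    // Arrow/Pandas runner settings come from ArrowPythonRunner.
    object PythonRunnerConfExample {
      def pythonRunnerConf(conf: SQLConf): Map[String, String] =
        ArrowPythonRunner.getPythonRunnerConfMap(conf)
    }
    ```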
    
    ### Why are the changes needed?
`ArrowUtils` is used by Connect's direct Arrow encoding (and by many other things in sql). We want to remove the Connect Scala client's catalyst dependency, and we need to move `ArrowUtils` in order to do so.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing tests.
    
    Closes #42137 from hvanhovell/SPARK-44532.
    
    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../connect/client/arrow/ArrowDeserializer.scala   |  4 +--
 .../sql/connect/client/arrow/ArrowSerializer.scala |  4 +--
 sql/api/pom.xml                                    |  8 +++++
 .../org/apache/spark/sql/errors/ArrowErrors.scala  | 39 ++++++++++++++++++++++
 .../apache/spark/sql/errors/DataTypeErrors.scala   |  6 ++++
 .../org/apache/spark/sql/util/ArrowUtils.scala     | 19 +++--------
 sql/catalyst/pom.xml                               |  8 -----
 .../spark/sql/errors/QueryExecutionErrors.scala    | 20 -----------
 .../spark/sql/execution/arrow/ArrowWriter.scala    |  4 +--
 .../org/apache/spark/sql/execution/Columnar.scala  |  4 +--
 .../execution/python/AggregateInPandasExec.scala   |  3 +-
 .../sql/execution/python/ArrowEvalPythonExec.scala |  3 +-
 .../execution/python/ArrowEvalPythonUDTFExec.scala |  3 +-
 .../sql/execution/python/ArrowPythonRunner.scala   | 12 +++++++
 .../python/FlatMapCoGroupsInPandasExec.scala       |  3 +-
 .../python/FlatMapGroupsInPandasExec.scala         |  3 +-
 .../FlatMapGroupsInPandasWithStateExec.scala       |  3 +-
 .../sql/execution/python/MapInBatchExec.scala      |  3 +-
 .../python/WindowInPandasEvaluatorFactory.scala    |  4 +--
 19 files changed, 86 insertions(+), 67 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
index 01aba9cb0ce..4177a88ba52 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
-import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.errors.{DataTypeErrors, QueryCompilationErrors}
 import org.apache.spark.sql.types.Decimal
 
 /**
@@ -341,7 +341,7 @@ object ArrowDeserializers {
         }
 
       case (CalendarIntervalEncoder | _: UDTEncoder[_], _) =>
-        throw QueryExecutionErrors.unsupportedDataTypeError(encoder.dataType)
+        throw DataTypeErrors.unsupportedDataTypeError(encoder.dataType)
 
       case _ =>
         throw new RuntimeException(
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
index d29f90a6a19..9b39a75ceed 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.errors.DataTypeErrors
 import org.apache.spark.sql.types.Decimal
 import org.apache.spark.sql.util.ArrowUtils
 
@@ -439,7 +439,7 @@ object ArrowSerializer {
         }
 
       case (CalendarIntervalEncoder | _: UDTEncoder[_], _) =>
-        throw QueryExecutionErrors.unsupportedDataTypeError(encoder.dataType)
+        throw DataTypeErrors.unsupportedDataTypeError(encoder.dataType)
 
       case _ =>
         throw new RuntimeException(s"Unsupported Encoder($encoder)/Vector($v) combination.")
diff --git a/sql/api/pom.xml b/sql/api/pom.xml
index 16db70180e3..312b5de55ba 100644
--- a/sql/api/pom.xml
+++ b/sql/api/pom.xml
@@ -64,6 +64,14 @@
             <groupId>org.antlr</groupId>
             <artifactId>antlr4-runtime</artifactId>
         </dependency>
+        <dependency>
+            <groupId>org.apache.arrow</groupId>
+            <artifactId>arrow-vector</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.arrow</groupId>
+            <artifactId>arrow-memory-netty</artifactId>
+        </dependency>
     </dependencies>
     <build>
         <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/ArrowErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/ArrowErrors.scala
new file mode 100644
index 00000000000..59e35c9c9ba
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ArrowErrors.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.errors
+
+import org.apache.arrow.vector.types.pojo.ArrowType
+
+import org.apache.spark.SparkUnsupportedOperationException
+
+trait ArrowErrors {
+
+  def unsupportedArrowTypeError(typeName: ArrowType): SparkUnsupportedOperationException = {
+    new SparkUnsupportedOperationException(
+      errorClass = "UNSUPPORTED_ARROWTYPE",
+      messageParameters = Map("typeName" -> typeName.toString))
+  }
+
+  def duplicatedFieldNameInArrowStructError(
+      fieldNames: Seq[String]): SparkUnsupportedOperationException = {
+    new SparkUnsupportedOperationException(
+      errorClass = "DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT",
+      messageParameters = Map("fieldNames" -> fieldNames.mkString("[", ", ", "]")))
+  }
+}
+
+object ArrowErrors extends ArrowErrors
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala
index fcc3086b573..1dd3cf3dd5d 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala
@@ -295,4 +295,10 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase {
       messageParameters = Map("operation" -> operation),
       cause = null)
   }
+
+  def unsupportedDataTypeError(typeName: DataType): SparkUnsupportedOperationException = {
+    new SparkUnsupportedOperationException(
+      errorClass = "UNSUPPORTED_DATATYPE",
+      messageParameters = Map("typeName" -> toSQLType(typeName)))
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
similarity index 91%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index e880e973176..49990f7f033 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -26,8 +26,7 @@ import org.apache.arrow.vector.complex.MapVector
 import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit}
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
 
-import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.errors.{ArrowErrors, DataTypeErrors}
 import org.apache.spark.sql.types._
 
 private[sql] object ArrowUtils {
@@ -61,7 +60,7 @@ private[sql] object ArrowUtils {
     case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
     case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND)
     case _ =>
-      throw QueryExecutionErrors.unsupportedDataTypeError(dt)
+      throw DataTypeErrors.unsupportedDataTypeError(dt)
   }
 
   def fromArrowType(dt: ArrowType): DataType = dt match {
@@ -86,7 +85,7 @@ private[sql] object ArrowUtils {
     case ArrowType.Null.INSTANCE => NullType
     case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType()
     case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType()
-    case _ => throw QueryExecutionErrors.unsupportedArrowTypeError(dt)
+    case _ => throw ArrowErrors.unsupportedArrowTypeError(dt)
   }
 
   /** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */
@@ -172,16 +171,6 @@ private[sql] object ArrowUtils {
     }.toArray)
   }
 
-  /** Return Map with conf settings to be used in ArrowPythonRunner */
-  def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
-    val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone)
-    val pandasColsByName = Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key ->
-      conf.pandasGroupedMapAssignColumnsByName.toString)
-    val arrowSafeTypeCheck = Seq(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key ->
-      conf.arrowSafeTypeConversion.toString)
-    Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
-  }
-
   private def deduplicateFieldNames(
       dt: DataType, errorOnDuplicatedFieldNames: Boolean): DataType = dt match {
     case udt: UserDefinedType[_] => deduplicateFieldNames(udt.sqlType, errorOnDuplicatedFieldNames)
@@ -190,7 +179,7 @@ private[sql] object ArrowUtils {
         st.names
       } else {
         if (errorOnDuplicatedFieldNames) {
-          throw QueryExecutionErrors.duplicatedFieldNameInArrowStructError(st.names)
+          throw ArrowErrors.duplicatedFieldNameInArrowStructError(st.names)
         }
         val genNawName = st.names.groupBy(identity).map {
           case (name, names) if names.length > 1 =>
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 91f25beb29f..81893e2fe5c 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -116,14 +116,6 @@
       <artifactId>univocity-parsers</artifactId>
       <type>jar</type>
     </dependency>
-    <dependency>
-      <groupId>org.apache.arrow</groupId>
-      <artifactId>arrow-vector</artifactId>
-    </dependency>
-    <dependency>
-      <groupId>org.apache.arrow</groupId>
-      <artifactId>arrow-memory-netty</artifactId>
-    </dependency>
     <dependency>
       <groupId>org.apache.datasketches</groupId>
       <artifactId>datasketches-java</artifactId>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 183c5425202..3fb14cd079f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -25,7 +25,6 @@ import java.time.temporal.ChronoField
 import java.util.concurrent.TimeoutException
 
 import com.fasterxml.jackson.core.{JsonParser, JsonToken}
-import org.apache.arrow.vector.types.pojo.ArrowType
 import org.apache.hadoop.fs.{FileAlreadyExistsException, FileStatus, Path}
 import org.apache.hadoop.fs.permission.FsPermission
 import org.codehaus.commons.compiler.{CompileException, InternalCompilerException}
@@ -1142,25 +1141,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
       messageParameters = Map("cost" -> cost))
   }
 
-  def unsupportedArrowTypeError(typeName: ArrowType): SparkUnsupportedOperationException = {
-    new SparkUnsupportedOperationException(
-      errorClass = "UNSUPPORTED_ARROWTYPE",
-      messageParameters = Map("typeName" -> typeName.toString))
-  }
-
-  def unsupportedDataTypeError(typeName: DataType): SparkUnsupportedOperationException = {
-    new SparkUnsupportedOperationException(
-      errorClass = "UNSUPPORTED_DATATYPE",
-      messageParameters = Map("typeName" -> toSQLType(typeName)))
-  }
-
-  def duplicatedFieldNameInArrowStructError(
-      fieldNames: Seq[String]): SparkUnsupportedOperationException = {
-    new SparkUnsupportedOperationException(
-      errorClass = "DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT",
-      messageParameters = Map("fieldNames" -> fieldNames.mkString("[", ", ", "]")))
-  }
-
   def notSupportTypeError(dataType: DataType): Throwable = {
     new SparkException(
       errorClass = "_LEGACY_ERROR_TEMP_2100",
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index a55e4f0cfcd..534cd8f9ab0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -24,7 +24,7 @@ import org.apache.arrow.vector.complex._
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.errors.DataTypeErrors
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.ArrowUtils
 
@@ -83,7 +83,7 @@ object ArrowWriter {
       case (_: YearMonthIntervalType, vector: IntervalYearVector) => new IntervalYearWriter(vector)
       case (_: DayTimeIntervalType, vector: DurationVector) => new DurationWriter(vector)
       case (dt, _) =>
-        throw QueryExecutionErrors.unsupportedDataTypeError(dt)
+        throw DataTypeErrors.unsupportedDataTypeError(dt)
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
index 9932f2741b0..a2029816c23 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.errors.DataTypeErrors
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.datasources.V1WriteCommand
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -277,7 +277,7 @@ private object RowToColumnConverter {
       case dt: DecimalType => new DecimalConverter(dt)
       case mt: MapType => MapConverter(getConverterForType(mt.keyType, nullable = false),
         getConverterForType(mt.valueType, mt.valueContainsNull))
-      case unknown => throw QueryExecutionErrors.unsupportedDataTypeError(unknown)
+      case unknown => throw DataTypeErrors.unsupportedDataTypeError(unknown)
     }
 
     if (nullable) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
index 50452a5d999..73560a596ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
@@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib
 import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.aggregate.UpdatingSessionsIterator
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
-import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.util.Utils
 
 /**
@@ -102,7 +101,7 @@ case class AggregateInPandasExec(
 
     val sessionLocalTimeZone = conf.sessionLocalTimeZone
     val largeVarTypes = conf.arrowUseLargeVarTypes
-    val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+    val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
 
     val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 0e4a420b4b3..7db43a34a88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.ArrowUtils
 
 /**
  * Grouped a iterator into batches.
@@ -75,7 +74,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
       evalType,
       conf.sessionLocalTimeZone,
       conf.arrowUseLargeVarTypes,
-      ArrowUtils.getPythonRunnerConfMap(conf),
+      ArrowPythonRunner.getPythonRunnerConfMap(conf),
       pythonMetrics,
       jobArtifactUUID)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
index 9c9ae0fca2d..9c0addfd2ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
 
 /**
@@ -49,7 +48,7 @@ case class ArrowEvalPythonUDTFExec(
   private val batchSize = conf.arrowMaxRecordsPerBatch
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
   private val largeVarTypes = conf.arrowUseLargeVarTypes
-  private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+  private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
   private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
 
   override protected def evaluate(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index ea861df3c1f..d9bce96c477 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -56,3 +56,15 @@ class ArrowPythonRunner(
     "Pandas execution requires more than 4 bytes. Please set higher buffer. " +
       s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.")
 }
+
+object ArrowPythonRunner {
+  /** Return Map with conf settings to be used in ArrowPythonRunner */
+  def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
+    val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone)
+    val pandasColsByName = Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key ->
+      conf.pandasGroupedMapAssignColumnsByName.toString)
+    val arrowSafeTypeCheck = Seq(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key ->
+      conf.arrowSafeTypeConversion.toString)
+    Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
index 9ef133c6bea..bbfe97d1947 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan}
 import org.apache.spark.sql.execution.python.PandasGroupUtils._
-import org.apache.spark.sql.util.ArrowUtils
 
 
 /**
@@ -58,7 +57,7 @@ case class FlatMapCoGroupsInPandasExec(
   extends SparkPlan with BinaryExecNode with PythonSQLMetrics {
 
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
-  private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+  private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
   private val pandasFunction = func.asInstanceOf[PythonUDF].func
   private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index 0ae5a998943..f2d21ce8e96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.python.PandasGroupUtils._
-import org.apache.spark.sql.util.ArrowUtils
 
 
 /**
@@ -55,7 +54,7 @@ case class FlatMapGroupsInPandasExec(
 
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
   private val largeVarTypes = conf.arrowUseLargeVarTypes
-  private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+  private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
   private val pandasFunction = func.asInstanceOf[PythonUDF].func
   private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
   private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
index d80320404b0..b05a2d130d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
@@ -33,7 +33,6 @@ import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExec
 import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.util.CompletionIterator
 
 /**
@@ -81,7 +80,7 @@ case class FlatMapGroupsInPandasWithStateExec(
   override def output: Seq[Attribute] = outAttributes
 
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
-  private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+  private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
 
   private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func
   private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index 368184934fa..4a47c2089d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.UnaryExecNode
-import org.apache.spark.sql.util.ArrowUtils
 
 /**
  * A relation produced by applying a function that takes an iterator of batches
@@ -46,7 +45,7 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
   override protected def doExecute(): RDD[InternalRow] = {
-    val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+    val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
     val pythonFunction = func.asInstanceOf[PythonUDF].func
     val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
     val evaluatorFactory = new MapInBatchEvaluatorFactory(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
index 364e94ab158..a32d892622b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
@@ -33,7 +33,6 @@ import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.window.{SlidingWindowFunctionFrame, UnboundedFollowingWindowFunctionFrame, UnboundedPrecedingWindowFunctionFrame, UnboundedWindowFunctionFrame, WindowEvaluatorFactoryBase, WindowFunctionFrame}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType}
-import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.util.Utils
 
 class WindowInPandasEvaluatorFactory(
@@ -162,7 +161,8 @@ class WindowInPandasEvaluatorFactory(
 
     private val udfWindowBoundTypes = pyFuncs.indices.map(i =>
       frameWindowBoundTypes(expressionIndexToFrameIndex(i)))
-    private val pythonRunnerConf: Map[String, String] = (ArrowUtils.getPythonRunnerConfMap(conf)
+    private val pythonRunnerConf: Map[String, String] =
+      (ArrowPythonRunner.getPythonRunnerConfMap(conf)
       + (windowBoundTypeConf -> udfWindowBoundTypes.map(_.value).mkString(",")))
 
     // Filter child output attributes down to only those that are UDF inputs.

