This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 307e46cc4df [SPARK-44532][CONNECT][SQL] Move ArrowUtils to sql/api
307e46cc4df is described below

commit 307e46cc4dfdad1f442e8c5c50ecb53c9ef7dc47
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Tue Jul 25 13:35:37 2023 -0400

    [SPARK-44532][CONNECT][SQL] Move ArrowUtils to sql/api

    ### What changes were proposed in this pull request?
    This PR moves `ArrowUtils` to `sql/api`. One method, used for configuring Python's Arrow runner, has been moved to `ArrowPythonRunner`.

    ### Why are the changes needed?
    `ArrowUtils` is used by Connect's direct Arrow encoding (and by a lot of other things in sql). We want to remove the Connect Scala client's catalyst dependency, and we need to move `ArrowUtils` in order to do so.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    Existing tests.

    Closes #42137 from hvanhovell/SPARK-44532.

    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../connect/client/arrow/ArrowDeserializer.scala   |  4 +--
 .../sql/connect/client/arrow/ArrowSerializer.scala |  4 +--
 sql/api/pom.xml                                    |  8 +++++
 .../org/apache/spark/sql/errors/ArrowErrors.scala  | 39 ++++++++++++++++++++++
 .../apache/spark/sql/errors/DataTypeErrors.scala   |  6 ++++
 .../org/apache/spark/sql/util/ArrowUtils.scala     | 19 +++--------
 sql/catalyst/pom.xml                               |  8 -----
 .../spark/sql/errors/QueryExecutionErrors.scala    | 20 -----------
 .../spark/sql/execution/arrow/ArrowWriter.scala    |  4 +--
 .../org/apache/spark/sql/execution/Columnar.scala  |  4 +--
 .../execution/python/AggregateInPandasExec.scala   |  3 +-
 .../sql/execution/python/ArrowEvalPythonExec.scala |  3 +-
 .../execution/python/ArrowEvalPythonUDTFExec.scala |  3 +-
 .../sql/execution/python/ArrowPythonRunner.scala   | 12 +++++++
 .../python/FlatMapCoGroupsInPandasExec.scala       |  3 +-
 .../python/FlatMapGroupsInPandasExec.scala         |  3 +-
 .../FlatMapGroupsInPandasWithStateExec.scala       |  3 +-
 .../sql/execution/python/MapInBatchExec.scala      |  3 +-
 .../python/WindowInPandasEvaluatorFactory.scala    |  4 +--
 19 files changed, 86 insertions(+), 67 deletions(-)
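[Editor's note] To make the dependency change concrete, here is a minimal sketch of what the move enables. The caller is hypothetical and not part of this diff; only `DataTypeErrors.unsupportedDataTypeError` is taken from the patch below.

```scala
// Code that depends only on sql/api can now raise the UNSUPPORTED_DATATYPE
// error; before this commit the factory lived in catalyst's QueryExecutionErrors.
import org.apache.spark.sql.errors.DataTypeErrors
import org.apache.spark.sql.types.{CalendarIntervalType, DataType}

def requireArrowCompatible(dt: DataType): Unit = dt match {
  case CalendarIntervalType =>
    throw DataTypeErrors.unsupportedDataTypeError(dt)
  case _ => () // supported; nothing to do
}
```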
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
index 01aba9cb0ce..4177a88ba52 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
-import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.errors.{DataTypeErrors, QueryCompilationErrors}
 import org.apache.spark.sql.types.Decimal
 
 /**
@@ -341,7 +341,7 @@ object ArrowDeserializers {
       }
 
     case (CalendarIntervalEncoder | _: UDTEncoder[_], _) =>
-      throw QueryExecutionErrors.unsupportedDataTypeError(encoder.dataType)
+      throw DataTypeErrors.unsupportedDataTypeError(encoder.dataType)
 
     case _ =>
       throw new RuntimeException(
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
index d29f90a6a19..9b39a75ceed 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.errors.DataTypeErrors
 import org.apache.spark.sql.types.Decimal
 import org.apache.spark.sql.util.ArrowUtils
 
@@ -439,7 +439,7 @@ object ArrowSerializer {
       }
 
     case (CalendarIntervalEncoder | _: UDTEncoder[_], _) =>
-      throw QueryExecutionErrors.unsupportedDataTypeError(encoder.dataType)
+      throw DataTypeErrors.unsupportedDataTypeError(encoder.dataType)
 
     case _ =>
       throw new RuntimeException(s"Unsupported Encoder($encoder)/Vector($v) combination.")
diff --git a/sql/api/pom.xml b/sql/api/pom.xml
index 16db70180e3..312b5de55ba 100644
--- a/sql/api/pom.xml
+++ b/sql/api/pom.xml
@@ -64,6 +64,14 @@
       <groupId>org.antlr</groupId>
       <artifactId>antlr4-runtime</artifactId>
     </dependency>
+    <dependency>
+      <groupId>org.apache.arrow</groupId>
+      <artifactId>arrow-vector</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.arrow</groupId>
+      <artifactId>arrow-memory-netty</artifactId>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/ArrowErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/ArrowErrors.scala
new file mode 100644
index 00000000000..59e35c9c9ba
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ArrowErrors.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.errors
+
+import org.apache.arrow.vector.types.pojo.ArrowType
+
+import org.apache.spark.SparkUnsupportedOperationException
+
+trait ArrowErrors {
+
+  def unsupportedArrowTypeError(typeName: ArrowType): SparkUnsupportedOperationException = {
+    new SparkUnsupportedOperationException(
+      errorClass = "UNSUPPORTED_ARROWTYPE",
+      messageParameters = Map("typeName" -> typeName.toString))
+  }
+
+  def duplicatedFieldNameInArrowStructError(
+      fieldNames: Seq[String]): SparkUnsupportedOperationException = {
+    new SparkUnsupportedOperationException(
+      errorClass = "DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT",
+      messageParameters = Map("fieldNames" -> fieldNames.mkString("[", ", ", "]")))
+  }
+}
+
+object ArrowErrors extends ArrowErrors
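[Editor's note] A quick sketch of the new helpers in use. The call sites here are hypothetical; the real ones appear in the `ArrowUtils` hunks further down.

```scala
import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.spark.sql.errors.ArrowErrors

// Raises error class UNSUPPORTED_ARROWTYPE for an unmapped Arrow type.
def rejectType(t: ArrowType): Nothing =
  throw ArrowErrors.unsupportedArrowTypeError(t)

// Raises DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT, e.g. for Seq("a", "a").
def rejectDuplicates(fieldNames: Seq[String]): Nothing =
  throw ArrowErrors.duplicatedFieldNameInArrowStructError(fieldNames)
```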
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala
index fcc3086b573..1dd3cf3dd5d 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala
@@ -295,4 +295,10 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase {
       messageParameters = Map("operation" -> operation),
       cause = null)
   }
+
+  def unsupportedDataTypeError(typeName: DataType): SparkUnsupportedOperationException = {
+    new SparkUnsupportedOperationException(
+      errorClass = "UNSUPPORTED_DATATYPE",
+      messageParameters = Map("typeName" -> toSQLType(typeName)))
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
similarity index 91%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index e880e973176..49990f7f033 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -26,8 +26,7 @@ import org.apache.arrow.vector.complex.MapVector
 import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit}
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
 
-import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.errors.{ArrowErrors, DataTypeErrors}
 import org.apache.spark.sql.types._
 
 private[sql] object ArrowUtils {
@@ -61,7 +60,7 @@ private[sql] object ArrowUtils {
     case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
     case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND)
     case _ =>
-      throw QueryExecutionErrors.unsupportedDataTypeError(dt)
+      throw DataTypeErrors.unsupportedDataTypeError(dt)
   }
 
   def fromArrowType(dt: ArrowType): DataType = dt match {
@@ -86,7 +85,7 @@ private[sql] object ArrowUtils {
     case ArrowType.Null.INSTANCE => NullType
     case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType()
     case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType()
-    case _ => throw QueryExecutionErrors.unsupportedArrowTypeError(dt)
+    case _ => throw ArrowErrors.unsupportedArrowTypeError(dt)
   }
 
   /** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */
@@ -172,16 +171,6 @@ private[sql] object ArrowUtils {
     }.toArray)
   }
 
-  /** Return Map with conf settings to be used in ArrowPythonRunner */
-  def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
-    val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone)
-    val pandasColsByName = Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key ->
-      conf.pandasGroupedMapAssignColumnsByName.toString)
-    val arrowSafeTypeCheck = Seq(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key ->
-      conf.arrowSafeTypeConversion.toString)
-    Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
-  }
-
   private def deduplicateFieldNames(
       dt: DataType, errorOnDuplicatedFieldNames: Boolean): DataType = dt match {
     case udt: UserDefinedType[_] => deduplicateFieldNames(udt.sqlType, errorOnDuplicatedFieldNames)
@@ -190,7 +179,7 @@ private[sql] object ArrowUtils {
         st.names
       } else {
         if (errorOnDuplicatedFieldNames) {
-          throw QueryExecutionErrors.duplicatedFieldNameInArrowStructError(st.names)
+          throw ArrowErrors.duplicatedFieldNameInArrowStructError(st.names)
        }
         val genNawName = st.names.groupBy(identity).map {
           case (name, names) if names.length > 1 =>
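[Editor's note] The hunks above only swap the error factories; the type mapping itself is unchanged. A sketch of `fromArrowType` behavior, using only cases visible in this diff (`ArrowUtils` is `private[sql]`, so this assumes a caller inside Spark's sql packages):

```scala
import org.apache.arrow.vector.types.IntervalUnit
import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.spark.sql.types.{NullType, YearMonthIntervalType}
import org.apache.spark.sql.util.ArrowUtils

// Arrow null and year-month interval map to their Spark counterparts.
assert(ArrowUtils.fromArrowType(ArrowType.Null.INSTANCE) == NullType)
assert(ArrowUtils.fromArrowType(
  new ArrowType.Interval(IntervalUnit.YEAR_MONTH)) == YearMonthIntervalType())
// Unmapped Arrow types now surface ArrowErrors.unsupportedArrowTypeError.
```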
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 91f25beb29f..81893e2fe5c 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -116,14 +116,6 @@
       <artifactId>univocity-parsers</artifactId>
       <type>jar</type>
     </dependency>
-    <dependency>
-      <groupId>org.apache.arrow</groupId>
-      <artifactId>arrow-vector</artifactId>
-    </dependency>
-    <dependency>
-      <groupId>org.apache.arrow</groupId>
-      <artifactId>arrow-memory-netty</artifactId>
-    </dependency>
     <dependency>
       <groupId>org.apache.datasketches</groupId>
       <artifactId>datasketches-java</artifactId>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 183c5425202..3fb14cd079f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -25,7 +25,6 @@ import java.time.temporal.ChronoField
 import java.util.concurrent.TimeoutException
 
 import com.fasterxml.jackson.core.{JsonParser, JsonToken}
-import org.apache.arrow.vector.types.pojo.ArrowType
 import org.apache.hadoop.fs.{FileAlreadyExistsException, FileStatus, Path}
 import org.apache.hadoop.fs.permission.FsPermission
 import org.codehaus.commons.compiler.{CompileException, InternalCompilerException}
@@ -1142,25 +1141,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
       messageParameters = Map("cost" -> cost))
   }
 
-  def unsupportedArrowTypeError(typeName: ArrowType): SparkUnsupportedOperationException = {
-    new SparkUnsupportedOperationException(
-      errorClass = "UNSUPPORTED_ARROWTYPE",
-      messageParameters = Map("typeName" -> typeName.toString))
-  }
-
-  def unsupportedDataTypeError(typeName: DataType): SparkUnsupportedOperationException = {
-    new SparkUnsupportedOperationException(
-      errorClass = "UNSUPPORTED_DATATYPE",
-      messageParameters = Map("typeName" -> toSQLType(typeName)))
-  }
-
-  def duplicatedFieldNameInArrowStructError(
-      fieldNames: Seq[String]): SparkUnsupportedOperationException = {
-    new SparkUnsupportedOperationException(
-      errorClass = "DUPLICATED_FIELD_NAME_IN_ARROW_STRUCT",
-      messageParameters = Map("fieldNames" -> fieldNames.mkString("[", ", ", "]")))
-  }
-
   def notSupportTypeError(dataType: DataType): Throwable = {
     new SparkException(
       errorClass = "_LEGACY_ERROR_TEMP_2100",
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index a55e4f0cfcd..534cd8f9ab0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -24,7 +24,7 @@ import org.apache.arrow.vector.complex._
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.errors.DataTypeErrors
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.ArrowUtils
 
@@ -83,7 +83,7 @@ object ArrowWriter {
       case (_: YearMonthIntervalType, vector: IntervalYearVector) => new IntervalYearWriter(vector)
       case (_: DayTimeIntervalType, vector: DurationVector) => new DurationWriter(vector)
       case (dt, _) =>
-        throw QueryExecutionErrors.unsupportedDataTypeError(dt)
+        throw DataTypeErrors.unsupportedDataTypeError(dt)
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
index 9932f2741b0..a2029816c23 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.errors.DataTypeErrors
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.datasources.V1WriteCommand
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -277,7 +277,7 @@ private object RowToColumnConverter {
       case dt: DecimalType => new DecimalConverter(dt)
       case mt: MapType => MapConverter(getConverterForType(mt.keyType, nullable = false),
         getConverterForType(mt.valueType, mt.valueContainsNull))
-      case unknown => throw QueryExecutionErrors.unsupportedDataTypeError(unknown)
+      case unknown => throw DataTypeErrors.unsupportedDataTypeError(unknown)
     }
 
     if (nullable) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
index 50452a5d999..73560a596ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
@@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib
 import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.aggregate.UpdatingSessionsIterator
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
-import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.util.Utils
 
 /**
@@ -102,7 +101,7 @@ case class AggregateInPandasExec(
 
     val sessionLocalTimeZone = conf.sessionLocalTimeZone
     val largeVarTypes = conf.arrowUseLargeVarTypes
-    val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+    val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
 
     val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 0e4a420b4b3..7db43a34a88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.ArrowUtils
 
 /**
  * Grouped a iterator into batches.
@@ -75,7 +74,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
       evalType,
       conf.sessionLocalTimeZone,
       conf.arrowUseLargeVarTypes,
-      ArrowUtils.getPythonRunnerConfMap(conf),
+      ArrowPythonRunner.getPythonRunnerConfMap(conf),
       pythonMetrics,
       jobArtifactUUID)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
index 9c9ae0fca2d..9c0addfd2ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
 
 /**
@@ -49,7 +48,7 @@ case class ArrowEvalPythonUDTFExec(
   private val batchSize = conf.arrowMaxRecordsPerBatch
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
   private val largeVarTypes = conf.arrowUseLargeVarTypes
-  private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+  private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
   private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
 
   override protected def evaluate(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index ea861df3c1f..d9bce96c477 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -56,3 +56,15 @@ class ArrowPythonRunner(
     "Pandas execution requires more than 4 bytes. Please set higher buffer. " +
       s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.")
 }
+
+object ArrowPythonRunner {
+  /** Return Map with conf settings to be used in ArrowPythonRunner */
+  def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
+    val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone)
+    val pandasColsByName = Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key ->
+      conf.pandasGroupedMapAssignColumnsByName.toString)
+    val arrowSafeTypeCheck = Seq(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key ->
+      conf.arrowSafeTypeConversion.toString)
+    Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
index 9ef133c6bea..bbfe97d1947 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan}
 import org.apache.spark.sql.execution.python.PandasGroupUtils._
-import org.apache.spark.sql.util.ArrowUtils
 
 
 /**
@@ -58,7 +57,7 @@ case class FlatMapCoGroupsInPandasExec(
   extends SparkPlan with BinaryExecNode with PythonSQLMetrics {
 
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
-  private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+  private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
   private val pandasFunction = func.asInstanceOf[PythonUDF].func
   private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index 0ae5a998943..f2d21ce8e96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.python.PandasGroupUtils._
-import org.apache.spark.sql.util.ArrowUtils
 
 
 /**
@@ -55,7 +54,7 @@ case class FlatMapGroupsInPandasExec(
 
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
   private val largeVarTypes = conf.arrowUseLargeVarTypes
-  private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+  private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
   private val pandasFunction = func.asInstanceOf[PythonUDF].func
   private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
   private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
index d80320404b0..b05a2d130d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
@@ -33,7 +33,6 @@ import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExec
 import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.util.CompletionIterator
 
 /**
@@ -81,7 +80,7 @@ case class FlatMapGroupsInPandasWithStateExec(
   override def output: Seq[Attribute] = outAttributes
 
   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
-  private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+  private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
   private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func
   private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index 368184934fa..4a47c2089d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.UnaryExecNode
-import org.apache.spark.sql.util.ArrowUtils
 
 /**
  * A relation produced by applying a function that takes an iterator of batches
@@ -46,7 +45,7 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
   override protected def doExecute(): RDD[InternalRow] = {
-    val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+    val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
     val pythonFunction = func.asInstanceOf[PythonUDF].func
     val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
     val evaluatorFactory = new MapInBatchEvaluatorFactory(
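[Editor's note] The `ArrowPythonRunner` companion object added earlier now owns the runner conf, and every exec node in these hunks switches to it. A sketch of a call site (hypothetical method; `SQLConf` is Spark-internal, so this assumes code inside Spark's sql packages):

```scala
import org.apache.spark.sql.execution.python.ArrowPythonRunner
import org.apache.spark.sql.internal.SQLConf

// Builds the same map the method produced when it lived in ArrowUtils:
// the session time zone, pandas grouped-map column assignment, and Arrow
// safe-type-conversion settings handed to the Python worker.
def confForWorker(conf: SQLConf): Map[String, String] =
  ArrowPythonRunner.getPythonRunnerConfMap(conf)
```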
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
index 364e94ab158..a32d892622b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
@@ -33,7 +33,6 @@ import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.window.{SlidingWindowFunctionFrame, UnboundedFollowingWindowFunctionFrame, UnboundedPrecedingWindowFunctionFrame, UnboundedWindowFunctionFrame, WindowEvaluatorFactoryBase, WindowFunctionFrame}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType}
-import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.util.Utils
 
 class WindowInPandasEvaluatorFactory(
@@ -162,7 +161,8 @@ class WindowInPandasEvaluatorFactory(
   private val udfWindowBoundTypes = pyFuncs.indices.map(i =>
     frameWindowBoundTypes(expressionIndexToFrameIndex(i)))
 
-  private val pythonRunnerConf: Map[String, String] = (ArrowUtils.getPythonRunnerConfMap(conf)
+  private val pythonRunnerConf: Map[String, String] =
+    (ArrowPythonRunner.getPythonRunnerConfMap(conf)
     + (windowBoundTypeConf -> udfWindowBoundTypes.map(_.value).mkString(",")))
 
   // Filter child output attributes down to only those that are UDF inputs.

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org