This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d427aa3c6f47 [SPARK-50473][SQL] Simplify classic Column handling
d427aa3c6f47 is described below
commit d427aa3c6f473e33c60eb9f8c7211dfb76cc3e3d
Author: Herman van Hovell <[email protected]>
AuthorDate: Tue Dec 3 11:42:21 2024 -0400
[SPARK-50473][SQL] Simplify classic Column handling
### What changes were proposed in this pull request?
We added a couple of helper functions that make it easier to work with Columns
in the Classic API. This covers functionality for creating a Column from an
Expression and for creating (named) Expressions from a Column. There are
currently multiple ways of doing the same thing, which is confusing. This PR
simplifies this by making the following changes:
- `ExpressionUtils` is moved to the background;
`ClassicConversions`/`ColumnConversions` are now the predominant entry points.
The benefit of this is that most code now looks like pre-Spark 4 code (see the
usage sketch after this list).
- `ExpressionUtils.expression(..)` and `ExpressionUtils.column(..)` are no
longer implicit; the implicit conversions were confusing.
- `testImplicits` now supports both Expression -> Column and Column ->
Expression conversions.
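
As a rough usage sketch (illustrative only, not part of the patch; the local
`spark` session and the value names are assumed), the new helpers are intended
to be used along these lines:

```scala
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, NamedExpression}
// Enables constructing a Column directly from a Catalyst Expression.
import org.apache.spark.sql.classic.ClassicConversions._

val spark: SparkSession = SparkSession.builder().master("local[*]").getOrCreate()

// Expression -> Column: wrap a Catalyst Expression in a Column.
val c: Column = Column(Literal(1))

// Column -> Expression: ask the session-bound converter explicitly ...
val e: Expression = spark.expression(c)

// ... or import the RichColumn implicit to get the pre-Spark 4 `expr`/`named`
// methods back on Column.
import spark.toRichColumn
val underlying: Expression = c.expr
val named: NamedExpression = c.named
```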
### Why are the changes needed?
It makes the code easier to understand.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49038 from hvanhovell/SPARK-50473.
Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../CheckConnectJvmClientCompatibility.scala | 11 +++--
.../org/apache/spark/ml/stat/Summarizer.scala | 11 +++--
.../sql/connect/planner/SparkConnectPlanner.scala | 56 +++++++++++-----------
.../apache/spark/sql/DataFrameNaFunctions.scala | 2 +-
.../main/scala/org/apache/spark/sql/Dataset.scala | 2 +-
.../spark/sql/RelationalGroupedDataset.scala | 6 +--
.../scala/org/apache/spark/sql/SparkSession.scala | 26 +++-------
.../spark/sql/api/python/PythonSQLUtils.scala | 6 ++-
...{ClassicConversions.scala => conversions.scala} | 54 ++++++++++++++++++++-
.../python/UserDefinedPythonFunction.scala | 7 +--
.../spark/sql/execution/stat/FrequentItems.scala | 8 ++--
.../spark/sql/internal/MergeIntoWriterImpl.scala | 2 +-
.../spark/sql/internal/columnNodeSupport.scala | 11 ++---
.../spark/sql/DataFrameComplexTypeSuite.scala | 5 +-
.../apache/spark/sql/DataFrameFunctionsSuite.scala | 7 ++-
.../apache/spark/sql/DataFrameSelfJoinSuite.scala | 7 ++-
.../org/apache/spark/sql/DataFrameSuite.scala | 11 ++---
.../spark/sql/DataFrameWindowFunctionsSuite.scala | 3 +-
.../apache/spark/sql/IntegratedUDFTestUtils.scala | 5 +-
.../org/apache/spark/sql/JsonFunctionsSuite.scala | 3 +-
.../spark/sql/TypedImperativeAggregateSuite.scala | 8 ++--
.../connector/functions/V2FunctionBenchmark.scala | 8 +++-
.../execution/datasources/orc/OrcFilterSuite.scala | 9 ++--
.../sql/execution/datasources/orc/OrcTest.scala | 6 +--
.../datasources/orc/OrcV1FilterSuite.scala | 12 ++---
.../datasources/parquet/ParquetFilterSuite.scala | 14 ++++--
.../spark/sql/sources/BucketedReadSuite.scala | 3 +-
.../org/apache/spark/sql/test/SQLTestUtils.scala | 10 ++--
.../ObjectHashAggregateExecBenchmark.scala | 5 +-
.../hive/OptimizeHiveMetadataOnlyQuerySuite.scala | 2 +-
.../hive/execution/ObjectHashAggregateSuite.scala | 5 +-
31 files changed, 187 insertions(+), 138 deletions(-)
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index d9ff8d9122ea..b5ea973aa1d7 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -233,9 +233,11 @@ object CheckConnectJvmClientCompatibility {
"org.apache.spark.sql.artifact.ArtifactManager$SparkContextResourceType$"),
// ColumnNode conversions
+
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.SparkSession"),
ProblemFilters.exclude[DirectMissingMethodProblem](
- "org.apache.spark.sql.SparkSession.Converter"),
-
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSession$Converter$"),
+ "org.apache.spark.sql.SparkSession.expression"),
+ ProblemFilters.exclude[DirectMissingMethodProblem](
+ "org.apache.spark.sql.SparkSession.toRichColumn"),
// UDFRegistration
ProblemFilters.exclude[DirectMissingMethodProblem](
@@ -295,10 +297,9 @@ object CheckConnectJvmClientCompatibility {
"org.apache.spark.sql.KeyValueGroupedDatasetImpl$"),
// ColumnNode conversions
- ProblemFilters.exclude[IncompatibleResultTypeProblem](
- "org.apache.spark.sql.SparkSession#RichColumn.expr"),
ProblemFilters.exclude[DirectMissingMethodProblem](
- "org.apache.spark.sql.SparkSession#RichColumn.typedExpr"),
+ "org.apache.spark.sql.SparkSession.RichColumn"),
+
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSession$RichColumn"),
// New public APIs added in the client
// Dataset
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
index 4c3242c13209..e67b72e09060 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
@@ -29,8 +29,9 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression,
ImplicitCastInputTypes}
import
org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.catalyst.trees.BinaryLike
+import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.functions.lit
-import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}
+import org.apache.spark.sql.internal.ExpressionUtils.expression
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -249,13 +250,13 @@ private[ml] class SummaryBuilderImpl(
) extends SummaryBuilder {
override def summary(featuresCol: Column, weightCol: Column): Column = {
- SummaryBuilderImpl.MetricsAggregate(
+ Column(SummaryBuilderImpl.MetricsAggregate(
requestedMetrics,
requestedCompMetrics,
- featuresCol,
- weightCol,
+ expression(featuresCol),
+ expression(weightCol),
mutableAggBufferOffset = 0,
- inputAggBufferOffset = 0)
+ inputAggBufferOffset = 0))
}
}
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index ee030a52b221..6ecf3ce11038 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -43,7 +43,7 @@ import
org.apache.spark.connect.proto.WriteStreamOperationStart.TriggerCase
import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile,
TaskResourceProfile, TaskResourceRequest}
-import org.apache.spark.sql.{Dataset, Encoders, ForeachWriter, Observation,
RelationalGroupedDataset, Row, SparkSession}
+import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter,
Observation, RelationalGroupedDataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier,
FunctionIdentifier, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry,
GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery,
PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute,
UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue,
UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar,
UnresolvedTableValuedFunction, UnresolvedTranspose}
import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder,
ExpressionEncoder, RowEncoder}
@@ -58,6 +58,7 @@ import
org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap,
CharVarcharUtils}
+import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter,
ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter,
StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
import
org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
@@ -77,7 +78,6 @@ import
org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeout
import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper
import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator,
SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction}
import org.apache.spark.sql.internal.{CatalogImpl, MergeIntoWriterImpl,
TypedAggUtils}
-import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode,
StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -106,7 +106,7 @@ class SparkConnectPlanner(
@Since("4.0.0")
@DeveloperApi
def session: SparkSession = sessionHolder.session
- import sessionHolder.session.RichColumn
+ import sessionHolder.session.toRichColumn
private[connect] def parser = session.sessionState.sqlParser
@@ -554,7 +554,7 @@ class SparkConnectPlanner(
.ofRows(session, transformRelation(rel.getInput))
.stat
.sampleBy(
- col = column(transformExpression(rel.getCol)),
+ col = Column(transformExpression(rel.getCol)),
fractions = fractions.toMap,
seed = if (rel.hasSeed) rel.getSeed else Utils.random.nextLong)
.logicalPlan
@@ -646,17 +646,17 @@ class SparkConnectPlanner(
val pythonUdf = transformPythonUDF(commonUdf)
val cols =
rel.getGroupingExpressionsList.asScala.toSeq.map(expr =>
- column(transformExpression(expr)))
+ Column(transformExpression(expr)))
val group = Dataset
.ofRows(session, transformRelation(rel.getInput))
.groupBy(cols: _*)
pythonUdf.evalType match {
case PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF =>
- group.flatMapGroupsInPandas(column(pythonUdf)).logicalPlan
+ group.flatMapGroupsInPandas(Column(pythonUdf)).logicalPlan
case PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF =>
- group.flatMapGroupsInArrow(column(pythonUdf)).logicalPlan
+ group.flatMapGroupsInArrow(Column(pythonUdf)).logicalPlan
case _ =>
throw InvalidPlanInput(
@@ -765,10 +765,10 @@ class SparkConnectPlanner(
case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
val inputCols =
rel.getInputGroupingExpressionsList.asScala.toSeq.map(expr =>
- column(transformExpression(expr)))
+ Column(transformExpression(expr)))
val otherCols =
rel.getOtherGroupingExpressionsList.asScala.toSeq.map(expr =>
- column(transformExpression(expr)))
+ Column(transformExpression(expr)))
val input = Dataset
.ofRows(session, transformRelation(rel.getInput))
@@ -783,10 +783,10 @@ class SparkConnectPlanner(
pythonUdf.evalType match {
case PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF =>
- input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan
+ input.flatMapCoGroupsInPandas(other, Column(pythonUdf)).logicalPlan
case PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF =>
- input.flatMapCoGroupsInArrow(other, pythonUdf).logicalPlan
+ input.flatMapCoGroupsInArrow(other, Column(pythonUdf)).logicalPlan
case _ =>
throw InvalidPlanInput(
@@ -982,7 +982,7 @@ class SparkConnectPlanner(
private def transformApplyInPandasWithState(rel:
proto.ApplyInPandasWithState): LogicalPlan = {
val pythonUdf = transformPythonUDF(rel.getFunc)
val cols =
- rel.getGroupingExpressionsList.asScala.toSeq.map(expr =>
column(transformExpression(expr)))
+ rel.getGroupingExpressionsList.asScala.toSeq.map(expr =>
Column(transformExpression(expr)))
val outputSchema = parseSchema(rel.getOutputSchema)
@@ -992,7 +992,7 @@ class SparkConnectPlanner(
.ofRows(session, transformRelation(rel.getInput))
.groupBy(cols: _*)
.applyInPandasWithState(
- column(pythonUdf),
+ Column(pythonUdf),
outputSchema,
stateSchema,
rel.getOutputMode,
@@ -1080,7 +1080,7 @@ class SparkConnectPlanner(
Metadata.empty
}
- (alias.getName(0), column(transformExpression(alias.getExpr)),
metadata)
+ (alias.getName(0), Column(transformExpression(alias.getExpr)),
metadata)
}.unzip3
Dataset
@@ -1142,7 +1142,7 @@ class SparkConnectPlanner(
private def transformUnpivot(rel: proto.Unpivot): LogicalPlan = {
val ids = rel.getIdsList.asScala.toArray.map { expr =>
- column(transformExpression(expr))
+ Column(transformExpression(expr))
}
if (!rel.hasValues) {
@@ -1155,7 +1155,7 @@ class SparkConnectPlanner(
transformRelation(rel.getInput))
} else {
val values = rel.getValues.getValuesList.asScala.toArray.map { expr =>
- column(transformExpression(expr))
+ Column(transformExpression(expr))
}
Unpivot(
@@ -1184,7 +1184,7 @@ class SparkConnectPlanner(
private def transformCollectMetrics(rel: proto.CollectMetrics, planId:
Long): LogicalPlan = {
val metrics = rel.getMetricsList.asScala.toSeq.map { expr =>
- column(transformExpression(expr))
+ Column(transformExpression(expr))
}
val name = rel.getName
val input = transformRelation(rel.getInput)
@@ -2112,10 +2112,10 @@ class SparkConnectPlanner(
private def transformAsOfJoin(rel: proto.AsOfJoin): LogicalPlan = {
val left = Dataset.ofRows(session, transformRelation(rel.getLeft))
val right = Dataset.ofRows(session, transformRelation(rel.getRight))
- val leftAsOf = column(transformExpression(rel.getLeftAsOf))
- val rightAsOf = column(transformExpression(rel.getRightAsOf))
+ val leftAsOf = Column(transformExpression(rel.getLeftAsOf))
+ val rightAsOf = Column(transformExpression(rel.getRightAsOf))
val joinType = rel.getJoinType
- val tolerance = if (rel.hasTolerance)
column(transformExpression(rel.getTolerance)) else null
+ val tolerance = if (rel.hasTolerance)
Column(transformExpression(rel.getTolerance)) else null
val allowExactMatches = rel.getAllowExactMatches
val direction = rel.getDirection
@@ -2131,7 +2131,7 @@ class SparkConnectPlanner(
allowExactMatches = allowExactMatches,
direction = direction)
} else {
- val joinExprs = if (rel.hasJoinExpr)
column(transformExpression(rel.getJoinExpr)) else null
+ val joinExprs = if (rel.hasJoinExpr)
Column(transformExpression(rel.getJoinExpr)) else null
left.joinAsOf(
other = right,
leftAsOf = leftAsOf,
@@ -2172,7 +2172,7 @@ class SparkConnectPlanner(
private def transformDrop(rel: proto.Drop): LogicalPlan = {
var output = Dataset.ofRows(session, transformRelation(rel.getInput))
if (rel.getColumnsCount > 0) {
- val cols = rel.getColumnsList.asScala.toSeq.map(expr =>
column(transformExpression(expr)))
+ val cols = rel.getColumnsList.asScala.toSeq.map(expr =>
Column(transformExpression(expr)))
output = output.drop(cols.head, cols.tail: _*)
}
if (rel.getColumnNamesCount > 0) {
@@ -2247,7 +2247,7 @@ class SparkConnectPlanner(
rel.getPivot.getValuesList.asScala.toSeq.map(transformLiteral)
} else {
RelationalGroupedDataset
- .collectPivotValues(Dataset.ofRows(session, input),
column(pivotExpr))
+ .collectPivotValues(Dataset.ofRows(session, input),
Column(pivotExpr))
.map(expressions.Literal.apply)
}
logical.Pivot(
@@ -2574,12 +2574,12 @@ class SparkConnectPlanner(
if (!namedArguments.isEmpty) {
session.sql(
sql.getQuery,
- namedArguments.asScala.toMap.transform((_, e) =>
column(transformExpression(e))),
+ namedArguments.asScala.toMap.transform((_, e) =>
Column(transformExpression(e))),
tracker)
} else if (!posArguments.isEmpty) {
session.sql(
sql.getQuery,
- posArguments.asScala.map(e => column(transformExpression(e))).toArray,
+ posArguments.asScala.map(e => Column(transformExpression(e))).toArray,
tracker)
} else if (!args.isEmpty) {
session.sql(
@@ -2830,7 +2830,7 @@ class SparkConnectPlanner(
if (writeOperation.getPartitioningColumnsCount > 0) {
val names = writeOperation.getPartitioningColumnsList.asScala
.map(transformExpression)
- .map(column)
+ .map(Column(_))
.toSeq
w.partitionedBy(names.head, names.tail: _*)
}
@@ -2848,7 +2848,7 @@ class SparkConnectPlanner(
w.create()
}
case proto.WriteOperationV2.Mode.MODE_OVERWRITE =>
-
w.overwrite(column(transformExpression(writeOperation.getOverwriteCondition)))
+
w.overwrite(Column(transformExpression(writeOperation.getOverwriteCondition)))
case proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS =>
w.overwritePartitions()
case proto.WriteOperationV2.Mode.MODE_APPEND =>
@@ -3410,7 +3410,7 @@ class SparkConnectPlanner(
val sourceDs = Dataset.ofRows(session,
transformRelation(cmd.getSourceTablePlan))
val mergeInto = sourceDs
- .mergeInto(cmd.getTargetTableName,
column(transformExpression(cmd.getMergeCondition)))
+ .mergeInto(cmd.getTargetTableName,
Column(transformExpression(cmd.getMergeCondition)))
.asInstanceOf[MergeIntoWriterImpl[Row]]
mergeInto.matchedActions ++= matchedActions
mergeInto.notMatchedActions ++= notMatchedActions
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 53e12f58edd6..0d49e850b463 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types._
@Stable
final class DataFrameNaFunctions private[sql](df: DataFrame)
extends api.DataFrameNaFunctions {
- import df.sparkSession.RichColumn
+ import df.sparkSession.toRichColumn
protected def drop(minNonNulls: Option[Int]): Dataset[Row] = {
drop0(minNonNulls, outputAttributes)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a74d93b44db9..846d97b25786 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -225,7 +225,7 @@ class Dataset[T] private[sql](
queryExecution.sparkSession
}
- import sparkSession.RichColumn
+ import sparkSession.toRichColumn
// A globally unique id of this Dataset.
private[sql] val id = Dataset.curId.getAndIncrement()
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 6f0db42ec1f5..b8c4b03fc13d 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.errors.{QueryCompilationErrors,
QueryExecutionErrors}
import org.apache.spark.sql.execution.QueryExecution
-import org.apache.spark.sql.internal.ExpressionUtils.{column, generateAlias}
+import org.apache.spark.sql.internal.ExpressionUtils.generateAlias
import org.apache.spark.sql.internal.TypedAggUtils.withInputType
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{NumericType, StructType}
@@ -114,7 +114,7 @@ class RelationalGroupedDataset protected[sql](
namedExpr
}
}
- columnExprs.map(column)
+ columnExprs.map(Column(_))
}
/** @inheritdoc */
@@ -238,7 +238,7 @@ class RelationalGroupedDataset protected[sql](
broadcastVars: Array[Broadcast[Object]],
outputSchema: StructType): DataFrame = {
val groupingNamedExpressions = groupingExprs.map(alias)
- val groupingCols = groupingNamedExpressions.map(column)
+ val groupingCols = groupingNamedExpressions.map(Column(_))
val groupingDataFrame = df.select(groupingCols : _*)
val groupingAttributes = groupingNamedExpressions.map(_.toAttribute)
Dataset.ofRows(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index dbe4543c3310..878fdc8e267a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery,
PosParameterizedQuery, UnresolvedRelation}
import org.apache.spark.sql.catalyst.encoders._
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference,
Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference,
Expression}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody,
LocalRelation, LogicalPlan, Range}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -98,7 +98,7 @@ class SparkSession private(
@transient private[sql] val extensions: SparkSessionExtensions,
@transient private[sql] val initialSessionOptions: Map[String, String],
@transient private val parentManagedJobTags: Map[String, String])
- extends api.SparkSession with Logging { self =>
+ extends api.SparkSession with Logging with classic.ColumnConversions { self
=>
// The call site where this SparkSession was constructed.
private val creationSite: CallSite = Utils.getCallSite()
@@ -797,23 +797,11 @@ class SparkSession private(
.getOrElse(sparkContext.defaultParallelism)
}
- private[sql] object Converter extends ColumnNodeToExpressionConverter with
Serializable {
- override protected def parser: ParserInterface = sessionState.sqlParser
- override protected def conf: SQLConf = sessionState.conf
- }
-
- private[sql] def expression(e: Column): Expression = Converter(e.node)
-
- private[sql] implicit class RichColumn(val column: Column) {
- /**
- * Returns the expression for this column.
- */
- def expr: Expression = Converter(column.node)
- /**
- * Returns the expression for this column either with an existing or auto
assigned name.
- */
- def named: NamedExpression = ExpressionUtils.toNamed(expr)
- }
+ override protected[sql] val converter: ColumnNodeToExpressionConverter =
+ new ColumnNodeToExpressionConverter with Serializable {
+ override protected def parser: ParserInterface = sessionState.sqlParser
+ override protected def conf: SQLConf = sessionState.conf
+ }
private[sql] lazy val observationManager = new ObservationManager(this)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index a66a6e54a7c8..da03293ce743 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -33,10 +33,11 @@ import
org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TableFunctionRe
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.execution.{ExplainMode, QueryExecution}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.execution.python.EvaluatePython
-import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}
+import org.apache.spark.sql.internal.ExpressionUtils.expression
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.{MutableURLClassLoader, Utils}
@@ -152,7 +153,8 @@ private[sql] object PythonSQLUtils extends Logging {
Column(internal.LambdaFunction(function.node, arguments))
}
- def namedArgumentExpression(name: String, e: Column): Column =
NamedArgumentExpression(name, e)
+ def namedArgumentExpression(name: String, e: Column): Column =
+ Column(NamedArgumentExpression(name, expression(e)))
@scala.annotation.varargs
def fn(name: String, arguments: Column*): Column = Column.fn(name,
arguments: _*)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala
b/sql/core/src/main/scala/org/apache/spark/sql/classic/conversions.scala
similarity index 56%
rename from
sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/classic/conversions.scala
index 8c3223fa72f5..e90fd4b6a603 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/conversions.scala
@@ -20,8 +20,8 @@ import scala.language.implicitConversions
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.internal.ExpressionUtils
+import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
+import org.apache.spark.sql.internal.{ColumnNodeToExpressionConverter,
ExpressionUtils}
/**
* Conversions from sql interfaces to the Classic specific implementation.
@@ -56,4 +56,54 @@ trait ClassicConversions {
}
}
+@DeveloperApi
object ClassicConversions extends ClassicConversions
+
+/**
+ * Conversions from a [[Column]] to an [[Expression]].
+ */
+@DeveloperApi
+trait ColumnConversions {
+ protected def converter: ColumnNodeToExpressionConverter
+
+ /**
+ * Convert a [[Column]] into an [[Expression]].
+ */
+ @DeveloperApi
+ def expression(column: Column): Expression = converter(column.node)
+
+ /**
+ * Wrap a [[Column]] with a [[RichColumn]] to provide the `expr` and `named`
methods.
+ */
+ @DeveloperApi
+ implicit def toRichColumn(column: Column): RichColumn = new
RichColumn(column, converter)
+}
+
+/**
+ * Automatic conversions from a Column to an Expression. This uses the active
SparkSession for
+ * parsing, and the active SQLConf for fetching configurations.
+ *
+ * This functionality is not part of the ClassicConversions because it is
generally better to use
+ * `SparkSession.toRichColumn(...)` or `SparkSession.expression(...)` directly.
+ */
+@DeveloperApi
+object ColumnConversions extends ColumnConversions {
+ override protected def converter: ColumnNodeToExpressionConverter =
+ ColumnNodeToExpressionConverter
+}
+
+/**
+ * Helper class that adds the `expr` and `named` methods to a Column. This can
be used to reinstate
+ * the pre-Spark 4 Column functionality.
+ */
+@DeveloperApi
+class RichColumn(column: Column, converter: ColumnNodeToExpressionConverter) {
+ /**
+ * Returns the expression for this column.
+ */
+ def expr: Expression = converter(column.node)
+ /**
+ * Returns the expression for this column either with an existing or auto
assigned name.
+ */
+ def named: NamedExpression = ExpressionUtils.toNamed(expr)
+}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index ea1f5e6ae134..388ede5d062e 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -28,8 +28,9 @@ import org.apache.spark.sql.{Column, DataFrame, Dataset,
SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending,
Descending, Expression, FunctionTableSubqueryArgumentExpression,
NamedArgumentExpression, NullsFirst, NullsLast, PythonUDAF, PythonUDF,
PythonUDTF, PythonUDTFAnalyzeResult, PythonUDTFSelectedExpression, SortOrder,
UnresolvedPolymorphicPythonUDTF}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan,
NamedParametersSupport, OneRowRelation}
+import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}
+import org.apache.spark.sql.internal.ExpressionUtils.expression
import org.apache.spark.sql.types.{DataType, StructType}
/**
@@ -75,10 +76,10 @@ case class UserDefinedPythonFunction(
* Returns a [[Column]] that will evaluate the UDF expression with the given
input.
*/
def fromUDFExpr(expr: Expression): Column = {
- expr match {
+ Column(expr match {
case udaf: PythonUDAF => udaf.toAggregateExpression()
case _ => expr
- }
+ })
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index 148766f9d002..221ca17ddf19 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -22,13 +22,13 @@ import java.io.{ByteArrayInputStream,
ByteArrayOutputStream, DataInputStream, Da
import scala.collection.mutable
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{functions, DataFrame}
+import org.apache.spark.sql.{functions, Column, DataFrame}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression,
UnsafeProjection, UnsafeRow}
import
org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate,
TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.util.GenericArrayData
-import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}
+import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -52,13 +52,15 @@ object FrequentItems extends Logging {
df: DataFrame,
cols: Seq[String],
support: Double): DataFrame = {
+ import df.sparkSession.expression
require(support >= 1e-4 && support <= 1.0, s"Support must be in [1e-4, 1],
but got $support.")
// number of max items to keep counts for
val sizeOfMap = (1 / support).toInt
val frequentItemCols = cols.map { col =>
- column(new CollectFrequentItems(functions.col(col),
sizeOfMap)).as(s"${col}_freqItems")
+ Column(new CollectFrequentItems(expression(functions.col(col)),
sizeOfMap))
+ .as(s"${col}_freqItems")
}
df.select(frequentItemCols: _*)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala
b/sql/core/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala
index bb8146e3e0e3..2f1a34648a47 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala
@@ -44,7 +44,7 @@ class MergeIntoWriterImpl[T] private[sql] (table: String, ds:
Dataset[T], on: Co
private val df: DataFrame = ds.toDF()
private[sql] val sparkSession = ds.sparkSession
- import sparkSession.RichColumn
+ import sparkSession.toRichColumn
private val tableName =
sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala
b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala
index 00e9a01f33c1..8b4726114890 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala
@@ -16,8 +16,6 @@
*/
package org.apache.spark.sql.internal
-import scala.language.implicitConversions
-
import UserDefinedFunctionUtils.toScalaUDF
import org.apache.spark.SparkException
@@ -302,13 +300,14 @@ private[spark] object ExpressionUtils {
/**
* Create an Expression backed Column.
*/
- implicit def column(e: Expression): Column = Column(ExpressionColumnNode(e))
+ def column(e: Expression): Column = Column(ExpressionColumnNode(e))
/**
- * Create an ColumnNode backed Expression. Please not that this has to be
converted to an actual
- * Expression before it is used.
+ * Create an ColumnNode backed Expression. This can only be used for
expressions that will be
+ * used to construct a [[Column]]. In all other cases please use
`SparkSession.expression(...)`,
+ * `SparkSession.toRichColumn(...)`, or
`org.apache.spark.sql.classic.ColumnConversions`.
*/
- implicit def expression(c: Column): Expression = ColumnNodeExpression(c.node)
+ def expression(c: Column): Expression = ColumnNodeExpression(c.node)
/**
* Returns the expression either with an existing or auto assigned name.
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
index 48ea0e01a437..8024b579e5d0 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.objects.MapObjects
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{ArrayType, BooleanType, Decimal,
DoubleType, IntegerType, MapType, StringType, StructField, StructType}
@@ -92,8 +91,8 @@ class DataFrameComplexTypeSuite extends QueryTest with
SharedSparkSession {
// items: Seq[Int] => items.map { item => Seq(Struct(item)) }
val result = df.select(
- column(MapObjects(
- (item: Expression) => array(struct(column(item))).expr,
+ Column(MapObjects(
+ (item: Expression) => array(struct(Column(item))).expr,
$"items".expr,
df.schema("items").dataType.asInstanceOf[ArrayType].elementType
)) as "items"
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 4494057b1eef..ce34db47c6df 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -32,7 +32,6 @@ import
org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
import
org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
@@ -406,7 +405,7 @@ class DataFrameFunctionsSuite extends QueryTest with
SharedSparkSession {
callSitePattern = "",
startIndex = 0,
stopIndex = 0))
- expr = nullifzero(Literal.create(20201231, DateType))
+ expr = nullifzero(Column(Literal.create(20201231, DateType)))
checkError(
intercept[AnalysisException](df.select(expr)),
condition = "DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES",
@@ -588,7 +587,7 @@ class DataFrameFunctionsSuite extends QueryTest with
SharedSparkSession {
callSitePattern = "",
startIndex = 0,
stopIndex = 0))
- expr = zeroifnull(Literal.create(20201231, DateType))
+ expr = zeroifnull(Column(Literal.create(20201231, DateType)))
checkError(
intercept[AnalysisException](df.select(expr)),
condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES",
@@ -5737,7 +5736,7 @@ class DataFrameFunctionsSuite extends QueryTest with
SharedSparkSession {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false),
(false, true))) {
val c = if (codegenFallback) {
- column(CodegenFallbackExpr(v.expr))
+ Column(CodegenFallbackExpr(v.expr))
} else {
v
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
index f0ed2241fd28..0e9b1c9d2104 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.expressions.{Alias,
Ascending, AttributeRef
import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate,
ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, count, explode, sum, year}
-import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SQLTestData.TestData
@@ -375,7 +374,7 @@ class DataFrameSelfJoinSuite extends QueryTest with
SharedSparkSession {
Seq.empty,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
true)
- val df7 = df1.mapInPandas(mapInPandasUDF)
+ val df7 = df1.mapInPandas(Column(mapInPandasUDF))
val df8 = df7.filter($"x" > 0)
assertAmbiguousSelfJoin(df7.join(df8, df7("x") === df8("y")))
assertAmbiguousSelfJoin(df8.join(df7, df7("x") === df8("y")))
@@ -386,7 +385,7 @@ class DataFrameSelfJoinSuite extends QueryTest with
SharedSparkSession {
Seq.empty,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
true)
- val df9 =
df1.groupBy($"key1").flatMapGroupsInPandas(flatMapGroupsInPandasUDF)
+ val df9 =
df1.groupBy($"key1").flatMapGroupsInPandas(Column(flatMapGroupsInPandasUDF))
val df10 = df9.filter($"x" > 0)
assertAmbiguousSelfJoin(df9.join(df10, df9("x") === df10("y")))
assertAmbiguousSelfJoin(df10.join(df9, df9("x") === df10("y")))
@@ -398,7 +397,7 @@ class DataFrameSelfJoinSuite extends QueryTest with
SharedSparkSession {
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
true)
val df11 = df1.groupBy($"key1").flatMapCoGroupsInPandas(
- df1.groupBy($"key2"), flatMapCoGroupsInPandasUDF)
+ df1.groupBy($"key2"), Column(flatMapCoGroupsInPandasUDF))
val df12 = df11.filter($"x" > 0)
assertAmbiguousSelfJoin(df11.join(df12, df11("x") === df12("y")))
assertAmbiguousSelfJoin(df12.join(df11, df11("x") === df12("y")))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index ff251ddbbfb5..c1d977dad82d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -43,7 +43,6 @@ import
org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.expressions.{Aggregator, Window}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT,
SharedSparkSession}
import org.apache.spark.sql.test.SQLTestData.{ArrayStringWrapper,
ContainerStringWrapper, StringWrapper, TestData2}
@@ -1567,7 +1566,7 @@ class DataFrameSuite extends QueryTest
test("SPARK-46794: exclude subqueries from LogicalRDD constraints") {
withTempDir { checkpointDir =>
val subquery =
-
column(ScalarSubquery(spark.range(10).selectExpr("max(id)").logicalPlan))
+
Column(ScalarSubquery(spark.range(10).selectExpr("max(id)").logicalPlan))
val df = spark.range(1000).filter($"id" === subquery)
assert(df.logicalPlan.constraints.exists(_.exists(_.isInstanceOf[ScalarSubquery])))
@@ -2054,18 +2053,18 @@ class DataFrameSuite extends QueryTest
// the number of keys must match
val exception1 = intercept[IllegalArgumentException] {
df1.groupBy($"key1", $"key2").flatMapCoGroupsInPandas(
- df2.groupBy($"key2"), flatMapCoGroupsInPandasUDF)
+ df2.groupBy($"key2"), Column(flatMapCoGroupsInPandasUDF))
}
assert(exception1.getMessage.contains("Cogroup keys must have same size: 2
!= 1"))
val exception2 = intercept[IllegalArgumentException] {
df1.groupBy($"key1").flatMapCoGroupsInPandas(
- df2.groupBy($"key1", $"key2"), flatMapCoGroupsInPandasUDF)
+ df2.groupBy($"key1", $"key2"), Column(flatMapCoGroupsInPandasUDF))
}
assert(exception2.getMessage.contains("Cogroup keys must have same size: 1
!= 2"))
// but different keys are allowed
val actual = df1.groupBy($"key1").flatMapCoGroupsInPandas(
- df2.groupBy($"key2"), flatMapCoGroupsInPandasUDF)
+ df2.groupBy($"key2"), Column(flatMapCoGroupsInPandasUDF))
// can't evaluate the DataFrame as there is no PythonFunction given
assert(actual != null)
}
@@ -2419,7 +2418,7 @@ class DataFrameSuite extends QueryTest
| SELECT a, b FROM (SELECT a, b FROM VALUES (1, 2) AS t(a, b))
|)
|""".stripMargin)
- val stringCols = df.logicalPlan.output.map(column(_).cast(StringType))
+ val stringCols = df.logicalPlan.output.map(Column(_).cast(StringType))
val castedDf = df.select(stringCols: _*)
checkAnswer(castedDf, Row("1", "1") :: Row("1", "2") :: Nil)
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 8a86aa10887c..01e72daead44 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -29,7 +29,6 @@ import
org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, Exchange, S
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer,
UserDefinedAggregateFunction, Window}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
@@ -862,7 +861,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest
lead($"value", 2, null, true).over(window),
lead($"value", 3, null, true).over(window),
lead(concat($"value", $"key"), 1, null, true).over(window),
- column(Lag($"value".expr, NonFoldableLiteral(1), Literal(null),
true)).over(window),
+ Column(Lag($"value".expr, NonFoldableLiteral(1), Literal(null),
true)).over(window),
lag($"value", 2).over(window),
lag($"value", 0, null, true).over(window),
lag($"value", 1, null, true).over(window),
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
index cdea4446d946..22f55819d1d4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
@@ -31,10 +31,11 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExprId,
PythonUDF}
import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.classic.ClassicConversions._
import
org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource
import org.apache.spark.sql.execution.python.{UserDefinedPythonFunction,
UserDefinedPythonTableFunction}
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
-import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}
+import org.apache.spark.sql.internal.ExpressionUtils.expression
import org.apache.spark.sql.internal.UserDefinedFunctionUtils.toScalaUDF
import org.apache.spark.sql.types.{DataType, IntegerType, NullType,
StringType, StructType, VariantType}
import org.apache.spark.util.ArrayImplicits._
@@ -1592,7 +1593,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
Cast(toScalaUDF(udf, Cast(expr, StringType) :: Nil), rt)
}
- def apply(exprs: Column*): Column = builder(exprs.map(expression))
+ def apply(exprs: Column*): Column = Column(builder(exprs.map(expression)))
val prettyName: String = "Scala UDF"
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 84408d8e2495..3803360f2da4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
@@ -1394,7 +1393,7 @@ class JsonFunctionsSuite extends QueryTest with
SharedSparkSession {
val df = Seq(1).toDF("a")
val schema = StructType(StructField("b",
ObjectType(classOf[java.lang.Integer])) :: Nil)
val row = InternalRow.fromSeq(Seq(Integer.valueOf(1)))
- val structData = column(Literal.create(row, schema))
+ val structData = Column(Literal.create(row, schema))
checkError(
exception = intercept[AnalysisException] {
df.select($"a").withColumn("c", to_json(structData)).collect()
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
index 624bae70ce09..662eead137c4 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.internal.ExpressionUtils.{column => toColumn,
expression}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
@@ -89,7 +88,7 @@ class TypedImperativeAggregateSuite extends QueryTest with
SharedSparkSession {
test("dataframe aggregate with object aggregate buffer, should not use
HashAggregate") {
val df = data.toDF("a", "b")
- val max = TypedMax($"a")
+ val max = Column(TypedMax($"a".expr))
// Always uses SortAggregateExec
val sparkPlan = df.select(max).queryExecution.sparkPlan
@@ -212,9 +211,10 @@ class TypedImperativeAggregateSuite extends QueryTest with
SharedSparkSession {
checkAnswer(query, expected)
}
- private def typedMax(column: Column): Column = TypedMax(column)
+ private def typedMax(column: Column): Column = Column(TypedMax(column.expr))
- private def nullableTypedMax(column: Column): Column = TypedMax(column,
nullable = true)
+ private def nullableTypedMax(column: Column): Column =
+ Column(TypedMax(column.expr, nullable = true))
}
object TypedImperativeAggregateSuite {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala
index 1401048cf705..a5f0285bf2ef 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala
@@ -21,15 +21,16 @@ import
test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd
import
test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.{JavaLongAddDefault,
JavaLongAddMagic, JavaLongAddStaticMagic}
import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, EvalMode,
Expression}
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryCatalog}
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction,
ScalarFunction, UnboundFunction}
import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark
import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{AbstractDataType, DataType, LongType,
NumericType, StructType}
@@ -64,6 +65,7 @@ object V2FunctionBenchmark extends SqlBasedBenchmark {
N: Long,
codegenEnabled: Boolean,
resultNullable: Boolean): Unit = {
+ import spark.toRichColumn
withSQLConf(s"spark.sql.catalog.$catalogName" ->
classOf[InMemoryCatalog].getName) {
createFunction("java_long_add_default",
new JavaLongAdd(new JavaLongAddDefault(resultNullable)))
@@ -81,7 +83,9 @@ object V2FunctionBenchmark extends SqlBasedBenchmark {
s"codegen = $codegenEnabled"
val benchmark = new Benchmark(name, N, output = output)
benchmark.addCase(s"native_long_add", numIters = 3) { _ =>
- spark.range(N).select(NativeAdd(col("id"), col("id"),
resultNullable)).noop()
+ spark.range(N)
+ .select(Column(NativeAdd(col("id").expr, col("id").expr,
resultNullable)))
+ .noop()
}
Seq("java_long_add_default", "java_long_add_magic",
"java_long_add_static_magic",
"scala_long_add_default", "scala_long_add_magic").foreach {
functionName =>
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
index 500c0647bcb2..bf9740970a66 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
@@ -28,14 +28,13 @@ import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf,
SearchArgument, SearchA
import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder
import org.apache.spark.{SparkConf, SparkException, SparkRuntimeException}
-import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
+import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.internal.ExpressionUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
@@ -47,7 +46,7 @@ import org.apache.spark.util.ArrayImplicits._
*/
@ExtendedSQLTest
class OrcFilterSuite extends OrcTest with SharedSparkSession {
- import testImplicits.toRichColumn
+ import testImplicits.{toRichColumn, ColumnConstructorExt}
override protected def sparkConf: SparkConf =
super
@@ -60,8 +59,8 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
checker: (SearchArgument) => Unit): Unit = {
val output = predicate.collect { case a: Attribute => a }.distinct
val query = df
- .select(output.map(e => ExpressionUtils.column(e)): _*)
- .where(ExpressionUtils.column(predicate))
+ .select(output.map(e => Column(e)): _*)
+ .where(Column(predicate))
query.queryExecution.optimizedPlan match {
case PhysicalOperation(_, filters, DataSourceV2ScanRelation(_, o:
OrcScan, _, _, _)) =>
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
index b8669ee4d1ef..9fbc872ad262 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
@@ -28,10 +28,10 @@ import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{Attribute, Predicate}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.execution.datasources.FileBasedDataSourceTest
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
-import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.ORC_IMPLEMENTATION
import org.apache.spark.util.ArrayImplicits._
@@ -118,8 +118,8 @@ trait OrcTest extends QueryTest with
FileBasedDataSourceTest with BeforeAndAfter
(implicit df: DataFrame): Unit = {
val output = predicate.collect { case a: Attribute => a }.distinct
val query = df
- .select(output.map(e => column(e)): _*)
- .where(predicate)
+ .select(output.map(e => Column(e)): _*)
+ .where(Column(predicate))
query.queryExecution.optimizedPlan match {
case PhysicalOperation(_, filters, DataSourceV2ScanRelation(_, o:
OrcScan, _, _, _)) =>
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala
index 5260ebf15e4f..8018417f923a 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala
@@ -21,12 +21,12 @@ import scala.jdk.CollectionConverters._
import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentImpl
import org.apache.spark.SparkConf
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Predicate}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy,
HadoopFsRelation, LogicalRelationWithTable}
import org.apache.spark.sql.execution.datasources.orc.OrcShimUtils.{Operator,
SearchArgument}
-import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.tags.ExtendedSQLTest
@@ -44,8 +44,8 @@ class OrcV1FilterSuite extends OrcFilterSuite {
checker: (SearchArgument) => Unit): Unit = {
val output = predicate.collect { case a: Attribute => a }.distinct
val query = df
- .select(output.map(e => column(e)): _*)
- .where(predicate)
+ .select(output.map(e => Column(e)): _*)
+ .where(Column(predicate))
var maybeRelation: Option[HadoopFsRelation] = None
val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect {
@@ -90,8 +90,8 @@ class OrcV1FilterSuite extends OrcFilterSuite {
(implicit df: DataFrame): Unit = {
val output = predicate.collect { case a: Attribute => a }.distinct
val query = df
- .select(output.map(e => column(e)): _*)
- .where(predicate)
+ .select(output.map(e => Column(e)): _*)
+ .where(Column(predicate))
var maybeRelation: Option[HadoopFsRelation] = None
val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 37edb9ea2315..5f7a0c9e7e74 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -49,7 +49,7 @@ import
org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsR
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.internal.{ExpressionUtils, LegacyBehaviorPolicy,
SQLConf}
+import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.internal.LegacyBehaviorPolicy.{CORRECTED, LEGACY}
import
org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType.{INT96,
TIMESTAMP_MICROS, TIMESTAMP_MILLIS}
import org.apache.spark.sql.test.SharedSparkSession
@@ -2233,6 +2233,8 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
@ExtendedSQLTest
class ParquetV1FilterSuite extends ParquetFilterSuite {
+ import testImplicits.ColumnConstructorExt
+
override protected def sparkConf: SparkConf =
super
.sparkConf
@@ -2260,8 +2262,8 @@ class ParquetV1FilterSuite extends ParquetFilterSuite {
SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false",
SQLConf.NESTED_PREDICATE_PUSHDOWN_FILE_SOURCE_LIST.key ->
pushdownDsList) {
val query = df
- .select(output.map(ExpressionUtils.column): _*)
- .where(ExpressionUtils.column(predicate))
+ .select(output.map(Column(_)): _*)
+ .where(Column(predicate))
val nestedOrAttributes = predicate.collectFirst {
case g: GetStructField => g
@@ -2313,6 +2315,8 @@ class ParquetV1FilterSuite extends ParquetFilterSuite {
@ExtendedSQLTest
class ParquetV2FilterSuite extends ParquetFilterSuite {
+ import testImplicits.ColumnConstructorExt
+
// TODO: enable Parquet V2 write path after file source V2 writers are workable.
override protected def sparkConf: SparkConf =
super
@@ -2339,8 +2343,8 @@ class ParquetV2FilterSuite extends ParquetFilterSuite {
SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
InferFiltersFromConstraints.ruleName,
SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
val query = df
- .select(output.map(ExpressionUtils.column): _*)
- .where(ExpressionUtils.column(predicate))
+ .select(output.map(Column(_)): _*)
+ .where(Column(predicate))
query.queryExecution.optimizedPlan.collectFirst {
case PhysicalOperation(_, filters,
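[Illustrative sketch, not part of the patch] The Parquet suites pull the same Column(<Expression>) constructor in through testImplicits.ColumnConstructorExt rather than a direct ClassicConversions import. A rough sketch of how a suite built on SharedSparkSession can use it (the suite name and test body are hypothetical):

    import org.apache.spark.sql.{Column, QueryTest, Row}
    import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal}
    import org.apache.spark.sql.test.SharedSparkSession

    class ColumnFromExpressionSuite extends QueryTest with SharedSparkSession {
      import testImplicits.ColumnConstructorExt

      test("filter with a Catalyst predicate") {
        val df = spark.range(10).toDF("id")
        val attr = df.queryExecution.analyzed.output.head
        // Column(<Expression>) comes from the ColumnConstructorExt implicit class.
        checkAnswer(df.where(Column(EqualTo(attr, Literal(3L)))), Row(3L))
      }
    }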
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 24732223c669..c4b09c4b289e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -32,7 +32,6 @@ import org.apache.spark.sql.execution.datasources.BucketingUtils
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
@@ -229,7 +228,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
checkPrunedAnswers(
bucketSpec,
bucketValues = Seq(bucketValue, bucketValue + 1, bucketValue + 2,
bucketValue + 3),
- filterCondition = column(inSetExpr),
+ filterCondition = Column(inSetExpr),
df)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index fe5a0f8ee257..c93f17701c62 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -41,10 +41,11 @@ import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.classic.{ClassicConversions, ColumnConversions}
import org.apache.spark.sql.execution.FilterExec
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution
import org.apache.spark.sql.execution.datasources.DataSourceUtils
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.{ColumnNodeToExpressionConverter, SQLConf}
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.UninterruptibleThread
import org.apache.spark.util.Utils
@@ -239,9 +240,12 @@ private[sql] trait SQLTestUtilsBase
* This is because we create the `SparkSession` immediately before the first test is run,
* but the implicits import is needed in the constructor.
*/
- protected object testImplicits extends SQLImplicits {
+ protected object testImplicits
+ extends SQLImplicits
+ with ClassicConversions
+ with ColumnConversions {
override protected def session: SparkSession = self.spark
- implicit def toRichColumn(c: Column): SparkSession#RichColumn = session.RichColumn(c)
+ override protected def converter: ColumnNodeToExpressionConverter = self.spark.converter
}
protected override def withSQLConf[T](pairs: (String, String)*)(f: => T): T = {
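[Illustrative sketch, not part of the patch] With testImplicits now mixing in ClassicConversions and ColumnConversions, one import gives a test both directions of conversion. A rough usage sketch, assuming a suite that mixes in this trait (so `spark` and `testImplicits` are in scope):

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, Literal}
    import testImplicits._

    val df = spark.range(5).toDF("id")
    val attr = df.queryExecution.analyzed.output.head

    // Expression -> Column: the Column(<Expression>) constructor from ClassicConversions.
    val cond: Column = Column(EqualTo(attr, Literal(1L)))

    // Column -> Expression: ColumnConversions enriches Column so its Catalyst
    // expression is reachable again, here via .expr (as in the Hive suite below).
    val condExpr: Expression = cond.expr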
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala
index 700a4984a4e3..f5bf49439d3f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala
@@ -23,10 +23,11 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFPercentileApprox
import org.apache.spark.benchmark.Benchmark
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
+import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.functions.{lit, percentile_approx => pa}
import org.apache.spark.sql.hive.execution.TestingTypedCount
import org.apache.spark.sql.hive.test.TestHive
-import org.apache.spark.sql.internal.ExpressionUtils.{column => toCol, expression}
+import org.apache.spark.sql.internal.ExpressionUtils.expression
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.LongType
@@ -117,7 +118,7 @@ object ObjectHashAggregateExecBenchmark extends SqlBasedBenchmark {
output = output
)
- def typed_count(column: Column): Column = TestingTypedCount(column)
+ def typed_count(column: Column): Column = Column(TestingTypedCount(expression(column)))
val df = spark.range(N)
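[Illustrative sketch, not part of the patch] The benchmark has no testImplicits, so it keeps the explicit ExpressionUtils.expression helper for the Column -> Expression direction and ClassicConversions for the way back. Restated with the imports it relies on, the helper amounts to:

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.classic.ClassicConversions._
    import org.apache.spark.sql.hive.execution.TestingTypedCount
    import org.apache.spark.sql.internal.ExpressionUtils.expression

    // Column -> Expression via ExpressionUtils.expression, then Expression -> Column
    // via the Column(<Expression>) constructor added by ClassicConversions.
    def typed_count(column: Column): Column =
      Column(TestingTypedCount(expression(column)))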
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala
index 2152a29b17ff..6709a139dcf9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala
@@ -32,7 +32,7 @@ class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleto
with BeforeAndAfter with SQLTestUtils {
import spark.implicits._
- import spark.RichColumn
+ import spark.toRichColumn
override def beforeAll(): Unit = {
super.beforeAll()
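[Illustrative sketch, not part of the patch] The Hive suite above only renames the implicit it imports: spark.RichColumn becomes spark.toRichColumn, the conversion on the classic SparkSession that enriches a Column with access to its underlying Catalyst expression. An assumed minimal usage, given a classic SparkSession `spark`:

    import spark.implicits._
    import spark.toRichColumn

    val df = Seq(1, 2, 3).toDF("id")
    // toRichColumn is what makes .expr available on a Column here.
    val idExpr = df("id").expr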
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
index bcd0644af078..008a324f73da 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
@@ -23,12 +23,11 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax
import org.scalatest.matchers.must.Matchers._
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper}
+import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec,
ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.functions.{col, count_distinct, first, lit, max,
percentile_approx => pa}
import org.apache.spark.sql.hive.test.TestHiveSingleton
-import org.apache.spark.sql.internal.ExpressionUtils.{column => toCol, expression}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
@@ -181,7 +180,7 @@ class ObjectHashAggregateSuite
pa(column, lit(percentage), lit(10000))
}
- private def typed_count(column: Column): Column = TestingTypedCount(column)
+ private def typed_count(column: Column): Column = Column(TestingTypedCount(column.expr))
// Generates 50 random rows for a given schema.
private def generateRandomRows(schemaForGenerator: StructType): Seq[Row] = {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]