This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6bbfa2dad8c7 [SPARK-50979][CONNECT] Remove .expr/.typedExpr implicits
6bbfa2dad8c7 is described below

commit 6bbfa2dad8c70b94ca52eb7cddde5ec68efbe0b1
Author: Herman van Hovell <[email protected]>
AuthorDate: Wed Jan 29 09:43:08 2025 -0400

    [SPARK-50979][CONNECT] Remove .expr/.typedExpr implicits
    
    ### What changes were proposed in this pull request?
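    This PR removes the `.expr`/`.typedExpr` Column conversion implicits (the
    `RichColumn` implicit class on `SparkSession`) from the Connect client. Call
    sites now invoke `ColumnNodeToProtoConverter.toExpr`, `toTypedExpr`, and the
    newly added `toLiteral` explicitly.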
    
    ### Why are the changes needed?
    Code clean-up.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
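    Existing tests. The `hint` query-test golden files (`hint.json` and
    `hint.proto.bin`) were updated as part of this change.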
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #49657 from hvanhovell/SPARK-50979.
    
    Authored-by: Herman van Hovell <[email protected]>
    Signed-off-by: Herman van Hovell <[email protected]>
---
 .../apache/spark/sql/catalyst/trees/origin.scala   |  12 +++---
 .../spark/sql/connect/DataFrameNaFunctions.scala   |   9 ++---
 .../spark/sql/connect/DataFrameStatFunctions.scala |   7 ++--
 .../spark/sql/connect/DataFrameWriterV2.scala      |   6 +--
 .../org/apache/spark/sql/connect/Dataset.scala     |  44 ++++++++++-----------
 .../spark/sql/connect/KeyValueGroupedDataset.scala |  15 ++++---
 .../apache/spark/sql/connect/MergeIntoWriter.scala |  10 ++---
 .../sql/connect/RelationalGroupedDataset.scala     |  19 +++++----
 .../apache/spark/sql/connect/SparkSession.scala    |  12 ++----
 .../spark/sql/connect/columnNodeSupport.scala      |   4 +-
 .../test/resources/query-tests/queries/hint.json   |   8 ++--
 .../resources/query-tests/queries/hint.proto.bin   | Bin 240 -> 248 bytes
 12 files changed, 72 insertions(+), 74 deletions(-)

diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
index 9fbfb9e679e5..23de9c222724 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
@@ -127,11 +127,13 @@ object CurrentOrigin {
     }
   }
 
-  private val sparkCodePattern = 
Pattern.compile("(org\\.apache\\.spark\\.sql\\." +
-    "(?:(classic|connect)\\.)?" +
-    
"(?:functions|Column|ColumnName|SQLImplicits|Dataset|DataFrameStatFunctions|DatasetHolder)"
 +
-    "(?:|\\..*|\\$.*))" +
-    "|(scala\\.collection\\..*)")
+  private val sparkCodePattern = Pattern.compile(
+    "(org\\.apache\\.spark\\.sql\\." +
+      "(?:(classic|connect)\\.)?" +
+      
"(?:functions|Column|ColumnName|SQLImplicits|Dataset|DataFrameStatFunctions|DatasetHolder"
 +
+      "|SparkSession|ColumnNodeToProtoConverter)" +
+      "(?:|\\..*|\\$.*))" +
+      "|(scala\\.collection\\..*)")
 
   private def sparkCode(ste: StackTraceElement): Boolean = {
     sparkCodePattern.matcher(ste.getClassName).matches()
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameNaFunctions.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameNaFunctions.scala
index 8f6c6ef07b3d..7b79387fbfde 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameNaFunctions.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameNaFunctions.scala
@@ -23,8 +23,8 @@ import org.apache.spark.connect.proto.{NAReplace, Relation}
 import org.apache.spark.connect.proto.Expression.{Literal => GLiteral}
 import org.apache.spark.connect.proto.NAReplace.Replacement
 import org.apache.spark.sql
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
 import org.apache.spark.sql.connect.ConnectConversions._
-import org.apache.spark.sql.functions
 
 /**
  * Functionality for working with missing data in `DataFrame`s.
@@ -33,7 +33,6 @@ import org.apache.spark.sql.functions
  */
 final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, 
root: Relation)
     extends sql.DataFrameNaFunctions {
-  import sparkSession.RichColumn
 
   override protected def drop(minNonNulls: Option[Int]): DataFrame =
     buildDropDataFrame(None, minNonNulls)
@@ -103,7 +102,7 @@ final class DataFrameNaFunctions private[sql] 
(sparkSession: SparkSession, root:
     sparkSession.newDataFrame { builder =>
       val fillNaBuilder = builder.getFillNaBuilder.setInput(root)
       values.map { case (colName, replaceValue) =>
-        
fillNaBuilder.addCols(colName).addValues(functions.lit(replaceValue).expr.getLiteral)
+        
fillNaBuilder.addCols(colName).addValues(toLiteral(replaceValue).getLiteral)
       }
     }
   }
@@ -143,8 +142,8 @@ final class DataFrameNaFunctions private[sql] 
(sparkSession: SparkSession, root:
     replacementMap.map { case (oldValue, newValue) =>
       Replacement
         .newBuilder()
-        .setOldValue(functions.lit(oldValue).expr.getLiteral)
-        .setNewValue(functions.lit(newValue).expr.getLiteral)
+        .setOldValue(toLiteral(oldValue).getLiteral)
+        .setNewValue(toLiteral(newValue).getLiteral)
         .build()
     }
   }
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameStatFunctions.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameStatFunctions.scala
index f3c3f82a233a..a510afc716a7 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameStatFunctions.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameStatFunctions.scala
@@ -23,9 +23,9 @@ import org.apache.spark.connect.proto.{Relation, StatSampleBy}
 import org.apache.spark.sql
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, 
PrimitiveDoubleEncoder}
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, 
toLiteral}
 import org.apache.spark.sql.connect.ConnectConversions._
 import 
org.apache.spark.sql.connect.DataFrameStatFunctions.approxQuantileResultEncoder
-import org.apache.spark.sql.functions.lit
 
 /**
  * Statistic functions for `DataFrame`s.
@@ -120,20 +120,19 @@ final class DataFrameStatFunctions private[sql] 
(protected val df: DataFrame)
 
   /** @inheritdoc */
   def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): 
DataFrame = {
-    import sparkSession.RichColumn
     require(
       fractions.values.forall(p => p >= 0.0 && p <= 1.0),
       s"Fractions must be in [0, 1], but got $fractions.")
     sparkSession.newDataFrame { builder =>
       val sampleByBuilder = builder.getSampleByBuilder
         .setInput(root)
-        .setCol(col.expr)
+        .setCol(toExpr(col))
         .setSeed(seed)
       fractions.foreach { case (k, v) =>
         sampleByBuilder.addFractions(
           StatSampleBy.Fraction
             .newBuilder()
-            .setStratum(lit(k).expr.getLiteral)
+            .setStratum(toLiteral(k).getLiteral)
             .setFraction(v))
       }
     }
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriterV2.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriterV2.scala
index 42cf2cdfad58..06d339487bfb 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriterV2.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriterV2.scala
@@ -23,6 +23,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.connect.proto
 import org.apache.spark.sql
 import org.apache.spark.sql.Column
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toExpr
 
 /**
  * Interface used to write a [[org.apache.spark.sql.Dataset]] to external 
storage using the v2
@@ -33,7 +34,6 @@ import org.apache.spark.sql.Column
 @Experimental
 final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])
     extends sql.DataFrameWriterV2[T] {
-  import ds.sparkSession.RichColumn
 
   private val builder = proto.WriteOperationV2
     .newBuilder()
@@ -73,7 +73,7 @@ final class DataFrameWriterV2[T] private[sql] (table: String, 
ds: Dataset[T])
   /** @inheritdoc */
   @scala.annotation.varargs
   override def partitionedBy(column: Column, columns: Column*): this.type = {
-    builder.addAllPartitioningColumns((column +: columns).map(_.expr).asJava)
+    builder.addAllPartitioningColumns((column +: columns).map(toExpr).asJava)
     this
   }
 
@@ -106,7 +106,7 @@ final class DataFrameWriterV2[T] private[sql] (table: 
String, ds: Dataset[T])
 
   /** @inheritdoc */
   def overwrite(condition: Column): Unit = {
-    builder.setOverwriteCondition(condition.expr)
+    builder.setOverwriteCondition(toExpr(condition))
     executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE)
   }
 
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
index 36003283a336..419ac3b7f74a 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions.OrderUtils
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, 
toLiteral, toTypedExpr}
 import org.apache.spark.sql.connect.ConnectConversions._
 import org.apache.spark.sql.connect.client.SparkResult
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, 
StorageLevelProtoConverter}
@@ -140,7 +141,6 @@ class Dataset[T] private[sql] (
     @DeveloperApi val plan: proto.Plan,
     val encoder: Encoder[T])
     extends sql.Dataset[T] {
-  import sparkSession.RichColumn
 
   // Make sure we don't forget to set plan id.
   assert(plan.getRoot.getCommon.hasPlanId)
@@ -336,7 +336,7 @@ class Dataset[T] private[sql] (
     buildJoin(right, Seq(joinExprs)) { builder =>
       builder
         .setJoinType(toJoinType(joinType))
-        .setJoinCondition(joinExprs.expr)
+        .setJoinCondition(toExpr(joinExprs))
     }
   }
 
@@ -375,7 +375,7 @@ class Dataset[T] private[sql] (
         .setLeft(plan.getRoot)
         .setRight(other.plan.getRoot)
         .setJoinType(joinTypeValue)
-        .setJoinCondition(condition.expr)
+        .setJoinCondition(toExpr(condition))
         .setJoinDataType(joinBuilder.getJoinDataTypeBuilder
           .setIsLeftStruct(this.agnosticEncoder.isStruct)
           .setIsRightStruct(other.agnosticEncoder.isStruct))
@@ -396,7 +396,7 @@ class Dataset[T] private[sql] (
     sparkSession.newDataFrame(joinExprs.toSeq) { builder =>
       val lateralJoinBuilder = builder.getLateralJoinBuilder
       lateralJoinBuilder.setLeft(plan.getRoot).setRight(right.plan.getRoot)
-      joinExprs.foreach(c => lateralJoinBuilder.setJoinCondition(c.expr))
+      joinExprs.foreach(c => lateralJoinBuilder.setJoinCondition(toExpr(c)))
       lateralJoinBuilder.setJoinType(joinTypeValue)
     }
   }
@@ -440,7 +440,7 @@ class Dataset[T] private[sql] (
       builder.getHintBuilder
         .setInput(plan.getRoot)
         .setName(name)
-        .addAllParameters(parameters.map(p => functions.lit(p).expr).asJava)
+        .addAllParameters(parameters.map(p => toLiteral(p)).asJava)
     }
 
   private def getPlanId: Option[Long] =
@@ -486,7 +486,7 @@ class Dataset[T] private[sql] (
     sparkSession.newDataset(encoder) { builder =>
       builder.getProjectBuilder
         .setInput(plan.getRoot)
-        .addExpressions(col.typedExpr(this.encoder))
+        .addExpressions(toTypedExpr(col, this.encoder))
     }
   }
 
@@ -504,14 +504,14 @@ class Dataset[T] private[sql] (
     sparkSession.newDataset(encoder, cols) { builder =>
       builder.getProjectBuilder
         .setInput(plan.getRoot)
-        .addAllExpressions(cols.map(_.typedExpr(this.encoder)).asJava)
+        .addAllExpressions(cols.map(c => toTypedExpr(c, this.encoder)).asJava)
     }
   }
 
   /** @inheritdoc */
   def filter(condition: Column): Dataset[T] = {
     sparkSession.newDataset(agnosticEncoder, Seq(condition)) { builder =>
-      
builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
+      
builder.getFilterBuilder.setInput(plan.getRoot).setCondition(toExpr(condition))
     }
   }
 
@@ -523,12 +523,12 @@ class Dataset[T] private[sql] (
     sparkSession.newDataFrame(ids.toSeq ++ valuesOption.toSeq.flatten) { 
builder =>
       val unpivot = builder.getUnpivotBuilder
         .setInput(plan.getRoot)
-        .addAllIds(ids.toImmutableArraySeq.map(_.expr).asJava)
+        .addAllIds(ids.toImmutableArraySeq.map(toExpr).asJava)
         .setVariableColumnName(variableColumnName)
         .setValueColumnName(valueColumnName)
       valuesOption.foreach { values =>
         unpivot.getValuesBuilder
-          .addAllValues(values.toImmutableArraySeq.map(_.expr).asJava)
+          .addAllValues(values.toImmutableArraySeq.map(toExpr).asJava)
       }
     }
   }
@@ -537,7 +537,7 @@ class Dataset[T] private[sql] (
     sparkSession.newDataFrame(indices) { builder =>
       val transpose = builder.getTransposeBuilder.setInput(plan.getRoot)
       indices.foreach { indexColumn =>
-        transpose.addIndexColumns(indexColumn.expr)
+        transpose.addIndexColumns(toExpr(indexColumn))
       }
     }
 
@@ -553,7 +553,7 @@ class Dataset[T] private[sql] (
       function = func,
       inputEncoders = agnosticEncoder :: agnosticEncoder :: Nil,
       outputEncoder = agnosticEncoder)
-    val reduceExpr = Column.fn("reduce", udf.apply(col("*"), col("*"))).expr
+    val reduceExpr = toExpr(Column.fn("reduce", udf.apply(col("*"), col("*"))))
 
     val result = sparkSession
       .newDataset(agnosticEncoder) { builder =>
@@ -590,7 +590,7 @@ class Dataset[T] private[sql] (
     val groupingSetMsgs = groupingSets.map { groupingSet =>
       val groupingSetMsg = proto.Aggregate.GroupingSets.newBuilder()
       for (groupCol <- groupingSet) {
-        groupingSetMsg.addGroupingSet(groupCol.expr)
+        groupingSetMsg.addGroupingSet(toExpr(groupCol))
       }
       groupingSetMsg.build()
     }
@@ -779,7 +779,7 @@ class Dataset[T] private[sql] (
       s"The size of column names: ${names.size} isn't equal to " +
         s"the size of columns: ${values.size}")
     val aliases = values.zip(names).map { case (value, name) =>
-      value.name(name).expr.getAlias
+      toExpr(value.name(name)).getAlias
     }
     sparkSession.newDataFrame(values) { builder =>
       builder.getWithColumnsBuilder
@@ -812,7 +812,7 @@ class Dataset[T] private[sql] (
   def withMetadata(columnName: String, metadata: Metadata): DataFrame = {
     val newAlias = proto.Expression.Alias
       .newBuilder()
-      .setExpr(col(columnName).expr)
+      .setExpr(toExpr(col(columnName)))
       .addName(columnName)
       .setMetadata(metadata.json)
     sparkSession.newDataFrame { builder =>
@@ -845,7 +845,7 @@ class Dataset[T] private[sql] (
     sparkSession.newDataFrame(cols) { builder =>
       builder.getDropBuilder
         .setInput(plan.getRoot)
-        .addAllColumns(cols.map(_.expr).asJava)
+        .addAllColumns(cols.map(toExpr).asJava)
     }
   }
 
@@ -915,7 +915,7 @@ class Dataset[T] private[sql] (
     sparkSession.newDataset[T](agnosticEncoder) { builder =>
       builder.getFilterBuilder
         .setInput(plan.getRoot)
-        .setCondition(udf.apply(col("*")).expr)
+        .setCondition(toExpr(udf.apply(col("*"))))
     }
   }
 
@@ -944,7 +944,7 @@ class Dataset[T] private[sql] (
     sparkSession.newDataset(outputEncoder) { builder =>
       builder.getMapPartitionsBuilder
         .setInput(plan.getRoot)
-        .setFunc(udf.apply(col("*")).expr.getCommonInlineUserDefinedFunction)
+        
.setFunc(toExpr(udf.apply(col("*"))).getCommonInlineUserDefinedFunction)
     }
   }
 
@@ -1020,7 +1020,7 @@ class Dataset[T] private[sql] (
     sparkSession.newDataset(agnosticEncoder, partitionExprs) { builder =>
       val repartitionBuilder = builder.getRepartitionByExpressionBuilder
         .setInput(plan.getRoot)
-        .addAllPartitionExprs(partitionExprs.map(_.expr).asJava)
+        .addAllPartitionExprs(partitionExprs.map(toExpr).asJava)
       numPartitions.foreach(repartitionBuilder.setNumPartitions)
     }
   }
@@ -1036,7 +1036,7 @@ class Dataset[T] private[sql] (
     // The underlying `LogicalPlan` operator special-cases all-`SortOrder` 
arguments.
     // However, we don't want to complicate the semantics of this API method.
     // Instead, let's give users a friendly error message, pointing them to 
the new method.
-    val sortOrders = partitionExprs.filter(_.expr.hasSortOrder)
+    val sortOrders = partitionExprs.filter(e => toExpr(e).hasSortOrder)
     if (sortOrders.nonEmpty) {
       throw new IllegalArgumentException(
         s"Invalid partitionExprs specified: $sortOrders\n" +
@@ -1050,7 +1050,7 @@ class Dataset[T] private[sql] (
       partitionExprs: Seq[Column]): Dataset[T] = {
     require(partitionExprs.nonEmpty, "At least one partition-by expression 
must be specified.")
     val sortExprs = partitionExprs.map {
-      case e if e.expr.hasSortOrder => e
+      case e if toExpr(e).hasSortOrder => e
       case e => e.asc
     }
     buildRepartitionByExpression(numPartitions, sortExprs)
@@ -1158,7 +1158,7 @@ class Dataset[T] private[sql] (
       builder.getCollectMetricsBuilder
         .setInput(plan.getRoot)
         .setName(name)
-        .addAllMetrics((expr +: exprs).map(_.expr).asJava)
+        .addAllMetrics((expr +: exprs).map(toExpr).asJava)
     }
   }
 
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala
index c984582ed6ae..dc494649b397 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql
 import org.apache.spark.sql.{Column, Encoder, TypedColumn}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, 
ProductEncoder}
-import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toExpr
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, 
toTypedExpr}
 import org.apache.spark.sql.connect.ConnectConversions._
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfUtils}
 import org.apache.spark.sql.expressions.SparkUserDefinedFunction
@@ -394,7 +394,6 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
     private val valueMapFunc: Option[IV => V],
     private val keysFunc: () => Dataset[IK])
     extends KeyValueGroupedDataset[K, V] {
-  import sparkSession.RichColumn
 
   override def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = {
     new KeyValueGroupedDatasetImpl[L, V, IK, IV](
@@ -436,7 +435,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
     sparkSession.newDataset[U](outputEncoder) { builder =>
       builder.getGroupMapBuilder
         .setInput(plan.getRoot)
-        .addAllSortingExpressions(sortExprs.map(e => e.expr).asJava)
+        .addAllSortingExpressions(sortExprs.map(toExpr).asJava)
         .addAllGroupingExpressions(groupingExprs)
         .setFunc(getUdf(nf, outputEncoder)(ivEncoder))
     }
@@ -453,10 +452,10 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
       builder.getCoGroupMapBuilder
         .setInput(plan.getRoot)
         .addAllInputGroupingExpressions(groupingExprs)
-        .addAllInputSortingExpressions(thisSortExprs.map(e => e.expr).asJava)
+        .addAllInputSortingExpressions(thisSortExprs.map(toExpr).asJava)
         .setOther(otherImpl.plan.getRoot)
         .addAllOtherGroupingExpressions(otherImpl.groupingExprs)
-        .addAllOtherSortingExpressions(otherSortExprs.map(e => e.expr).asJava)
+        .addAllOtherSortingExpressions(otherSortExprs.map(toExpr).asJava)
         .setFunc(getUdf(nf, outputEncoder)(ivEncoder, otherImpl.ivEncoder))
     }
   }
@@ -469,7 +468,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
         .setInput(plan.getRoot)
         .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
         .addAllGroupingExpressions(groupingExprs)
-        .addAllAggregateExpressions(columns.map(_.typedExpr(vEncoder)).asJava)
+        .addAllAggregateExpressions(columns.map(c => toTypedExpr(c, 
vEncoder)).asJava)
     }
   }
 
@@ -534,7 +533,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
       function = nf,
       inputEncoders = inputEncoders,
       outputEncoder = outputEncoder)
-    udf.apply(inputEncoders.map(_ => col("*")): 
_*).expr.getCommonInlineUserDefinedFunction
+    toExpr(udf.apply(inputEncoders.map(_ => col("*")): 
_*)).getCommonInlineUserDefinedFunction
   }
 
   private def getUdf[U: Encoder, S: Encoder](
@@ -549,7 +548,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
       function = nf,
       inputEncoders = inputEncoders,
       outputEncoder = outputEncoder)
-    udf.apply(inputEncoders.map(_ => col("*")): 
_*).expr.getCommonInlineUserDefinedFunction
+    toExpr(udf.apply(inputEncoders.map(_ => col("*")): 
_*)).getCommonInlineUserDefinedFunction
   }
 
   /**
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/MergeIntoWriter.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/MergeIntoWriter.scala
index c245a8644a3c..66354e63ca8a 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/MergeIntoWriter.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/MergeIntoWriter.scala
@@ -24,6 +24,7 @@ import org.apache.spark.connect.proto.{Expression, 
MergeAction, MergeIntoTableCo
 import org.apache.spark.connect.proto.MergeAction.ActionType._
 import org.apache.spark.sql
 import org.apache.spark.sql.Column
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toExpr
 import org.apache.spark.sql.functions.expr
 
 /**
@@ -44,13 +45,12 @@ import org.apache.spark.sql.functions.expr
 @Experimental
 class MergeIntoWriter[T] private[sql] (table: String, ds: Dataset[T], on: 
Column)
     extends sql.MergeIntoWriter[T] {
-  import ds.sparkSession.RichColumn
 
   private val builder = MergeIntoTableCommand
     .newBuilder()
     .setTargetTableName(table)
     .setSourceTablePlan(ds.plan.getRoot)
-    .setMergeCondition(on.expr)
+    .setMergeCondition(toExpr(on))
 
   /**
    * Executes the merge operation.
@@ -121,12 +121,12 @@ class MergeIntoWriter[T] private[sql] (table: String, ds: 
Dataset[T], on: Column
       condition: Option[Column],
       assignments: Map[String, Column] = Map.empty): Expression = {
     val builder = proto.MergeAction.newBuilder().setActionType(actionType)
-    condition.foreach(c => builder.setCondition(c.expr))
+    condition.foreach(c => builder.setCondition(toExpr(c)))
     assignments.foreach { case (k, v) =>
       builder
         .addAssignmentsBuilder()
-        .setKey(expr(k).expr)
-        .setValue(v.expr)
+        .setKey(toExpr(expr(k)))
+        .setValue(toExpr(v))
     }
     Expression
       .newBuilder()
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RelationalGroupedDataset.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RelationalGroupedDataset.scala
index 00dc1fb6906f..ac361047bbd0 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RelationalGroupedDataset.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RelationalGroupedDataset.scala
@@ -23,6 +23,7 @@ import org.apache.spark.connect.proto
 import org.apache.spark.sql
 import org.apache.spark.sql.{functions, Column, Encoder}
 import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, 
toTypedExpr}
 import org.apache.spark.sql.connect.ConnectConversions._
 
 /**
@@ -44,14 +45,13 @@ class RelationalGroupedDataset private[sql] (
     pivot: Option[proto.Aggregate.Pivot] = None,
     groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None)
     extends sql.RelationalGroupedDataset {
-  import df.sparkSession.RichColumn
 
   protected def toDF(aggExprs: Seq[Column]): DataFrame = {
     df.sparkSession.newDataFrame(groupingExprs ++ aggExprs) { builder =>
       val aggBuilder = builder.getAggregateBuilder
         .setInput(df.plan.getRoot)
-      groupingExprs.foreach(c => aggBuilder.addGroupingExpressions(c.expr))
-      aggExprs.foreach(c => 
aggBuilder.addAggregateExpressions(c.typedExpr(df.encoder)))
+      groupingExprs.foreach(c => aggBuilder.addGroupingExpressions(toExpr(c)))
+      aggExprs.foreach(c => aggBuilder.addAggregateExpressions(toTypedExpr(c, 
df.encoder)))
 
       groupType match {
         case proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP =>
@@ -152,10 +152,13 @@ class RelationalGroupedDataset private[sql] (
     groupType match {
       case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY =>
         val valueExprs = values.map {
-          case c: Column if c.expr.hasLiteral => c.expr.getLiteral
-          case c: Column if !c.expr.hasLiteral =>
-            throw new IllegalArgumentException("values only accept literal 
Column")
-          case v => functions.lit(v).expr.getLiteral
+          case c: Column =>
+            val e = toExpr(c)
+            if (!e.hasLiteral) {
+              throw new IllegalArgumentException("values only accept literal 
Column")
+            }
+            e.getLiteral
+          case v => toExpr(functions.lit(v)).getLiteral
         }
         new RelationalGroupedDataset(
           df,
@@ -164,7 +167,7 @@ class RelationalGroupedDataset private[sql] (
           Some(
             proto.Aggregate.Pivot
               .newBuilder()
-              .setCol(pivotColumn.expr)
+              .setCol(toExpr(pivotColumn))
               .addAllValues(valueExprs.asJava)
               .build()))
       case _ =>
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
index f7998cf60eca..032ab670dab0 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
@@ -43,11 +43,10 @@ import org.apache.spark.sql.{Column, Encoder, 
ExperimentalMethods, Observation,
 import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
 import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, 
BoxedLongEncoder, UnboundRowEncoder}
-import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, 
toTypedExpr}
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
 import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, 
SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
-import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
@@ -213,7 +212,7 @@ class SparkSession private[sql] (
     val sqlCommand = proto.SqlCommand
       .newBuilder()
       .setSql(sqlText)
-      .addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava)
+      .addAllPosArguments(args.map(a => 
toLiteral(a)).toImmutableArraySeq.asJava)
       .build()
     sql(sqlCommand)
   }
@@ -228,7 +227,7 @@ class SparkSession private[sql] (
     val sqlCommand = proto.SqlCommand
       .newBuilder()
       .setSql(sqlText)
-      .putAllNamedArguments(args.asScala.map { case (k, v) => (k, lit(v).expr) 
}.asJava)
+      .putAllNamedArguments(args.asScala.map { case (k, v) => (k, 
toLiteral(v)) }.asJava)
       .build()
     sql(sqlCommand)
   }
@@ -653,11 +652,6 @@ class SparkSession private[sql] (
   }
 
   override private[sql] def isUsable: Boolean = client.isSessionValid
-
-  implicit class RichColumn(c: Column) {
-    def expr: proto.Expression = toExpr(c)
-    def typedExpr[T](e: Encoder[T]): proto.Expression = toTypedExpr(c, e)
-  }
 }
 
 // The minimal builder needed to create a spark session.
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala
index f44ec5b2d504..f08b6e709f13 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala
@@ -24,7 +24,7 @@ import org.apache.spark.connect.proto.Expression
 import 
org.apache.spark.connect.proto.Expression.SortOrder.NullOrdering.{SORT_NULLS_FIRST,
 SORT_NULLS_LAST}
 import 
org.apache.spark.connect.proto.Expression.SortOrder.SortDirection.{SORT_DIRECTION_ASCENDING,
 SORT_DIRECTION_DESCENDING}
 import 
org.apache.spark.connect.proto.Expression.Window.WindowFrame.{FrameBoundary, 
FrameType}
-import org.apache.spark.sql.{Column, Encoder}
+import org.apache.spark.sql.{functions, Column, Encoder}
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import 
org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProtoBuilder
@@ -37,6 +37,8 @@ import org.apache.spark.sql.internal.{Alias, 
CaseWhenOtherwise, Cast, ColumnNode
 object ColumnNodeToProtoConverter extends (ColumnNode => proto.Expression) {
   def toExpr(column: Column): proto.Expression = apply(column.node, None)
 
+  def toLiteral(v: Any): proto.Expression = apply(functions.lit(v).node, None)
+
   def toTypedExpr[I](column: Column, encoder: Encoder[I]): proto.Expression = {
     apply(column.node, Option(encoder))
   }
diff --git 
a/sql/connect/common/src/test/resources/query-tests/queries/hint.json 
b/sql/connect/common/src/test/resources/query-tests/queries/hint.json
index 2ac930c0a3a7..2348d0f84715 100644
--- a/sql/connect/common/src/test/resources/query-tests/queries/hint.json
+++ b/sql/connect/common/src/test/resources/query-tests/queries/hint.json
@@ -22,13 +22,13 @@
             "stackTrace": [{
               "classLoaderName": "app",
               "declaringClass": "org.apache.spark.sql.connect.Dataset",
-              "methodName": "~~trimmed~anonfun~~",
+              "methodName": "hint",
               "fileName": "Dataset.scala"
             }, {
               "classLoaderName": "app",
-              "declaringClass": "org.apache.spark.sql.connect.SparkSession",
-              "methodName": "newDataset",
-              "fileName": "SparkSession.scala"
+              "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+              "methodName": "~~trimmed~anonfun~~",
+              "fileName": "PlanGenerationTestSuite.scala"
             }]
           }
         }
diff --git 
a/sql/connect/common/src/test/resources/query-tests/queries/hint.proto.bin 
b/sql/connect/common/src/test/resources/query-tests/queries/hint.proto.bin
index 06459ee5b765..ce7c63b57c47 100644
Binary files 
a/sql/connect/common/src/test/resources/query-tests/queries/hint.proto.bin and 
b/sql/connect/common/src/test/resources/query-tests/queries/hint.proto.bin 
differ


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to