This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6bbfa2dad8c7 [SPARK-50979][CONNECT] Remove .expr/.typedExpr implicits
6bbfa2dad8c7 is described below
commit 6bbfa2dad8c70b94ca52eb7cddde5ec68efbe0b1
Author: Herman van Hovell <[email protected]>
AuthorDate: Wed Jan 29 09:43:08 2025 -0400
[SPARK-50979][CONNECT] Remove .expr/.typedExpr implicits
### What changes were proposed in this pull request?
This PR removes the .expr/.typedExpr Column conversion implicits from the
Connect client.
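Concretely, this drops the sparkSession.RichColumn implicit class (which added
.expr and .typedExpr to Column) and has call sites invoke
ColumnNodeToProtoConverter.toExpr/toTypedExpr/toLiteral directly. A minimal
sketch of the call-site pattern, based on the diff below; the columnToProto
helper name is only illustrative:

```scala
import org.apache.spark.connect.proto
import org.apache.spark.sql.Column
import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toExpr

// Before (via the removed implicit): condition.expr
// After (explicit converter call):
def columnToProto(condition: Column): proto.Expression = toExpr(condition)
```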
### Why are the changes needed?
Code clean-up.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49657 from hvanhovell/SPARK-50979.
Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../apache/spark/sql/catalyst/trees/origin.scala | 12 +++---
.../spark/sql/connect/DataFrameNaFunctions.scala | 9 ++---
.../spark/sql/connect/DataFrameStatFunctions.scala | 7 ++--
.../spark/sql/connect/DataFrameWriterV2.scala | 6 +--
.../org/apache/spark/sql/connect/Dataset.scala | 44 ++++++++++-----------
.../spark/sql/connect/KeyValueGroupedDataset.scala | 15 ++++---
.../apache/spark/sql/connect/MergeIntoWriter.scala | 10 ++---
.../sql/connect/RelationalGroupedDataset.scala | 19 +++++----
.../apache/spark/sql/connect/SparkSession.scala | 12 ++----
.../spark/sql/connect/columnNodeSupport.scala | 4 +-
.../test/resources/query-tests/queries/hint.json | 8 ++--
.../resources/query-tests/queries/hint.proto.bin | Bin 240 -> 248 bytes
12 files changed, 72 insertions(+), 74 deletions(-)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
index 9fbfb9e679e5..23de9c222724 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
@@ -127,11 +127,13 @@ object CurrentOrigin {
}
}
- private val sparkCodePattern = Pattern.compile("(org\\.apache\\.spark\\.sql\\." +
- "(?:(classic|connect)\\.)?" +
- "(?:functions|Column|ColumnName|SQLImplicits|Dataset|DataFrameStatFunctions|DatasetHolder)" +
- "(?:|\\..*|\\$.*))" +
- "|(scala\\.collection\\..*)")
+ private val sparkCodePattern = Pattern.compile(
+ "(org\\.apache\\.spark\\.sql\\." +
+ "(?:(classic|connect)\\.)?" +
+ "(?:functions|Column|ColumnName|SQLImplicits|Dataset|DataFrameStatFunctions|DatasetHolder" +
+ "|SparkSession|ColumnNodeToProtoConverter)" +
+ "(?:|\\..*|\\$.*))" +
+ "|(scala\\.collection\\..*)")
private def sparkCode(ste: StackTraceElement): Boolean = {
sparkCodePattern.matcher(ste.getClassName).matches()
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameNaFunctions.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameNaFunctions.scala
index 8f6c6ef07b3d..7b79387fbfde 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameNaFunctions.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameNaFunctions.scala
@@ -23,8 +23,8 @@ import org.apache.spark.connect.proto.{NAReplace, Relation}
import org.apache.spark.connect.proto.Expression.{Literal => GLiteral}
import org.apache.spark.connect.proto.NAReplace.Replacement
import org.apache.spark.sql
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
import org.apache.spark.sql.connect.ConnectConversions._
-import org.apache.spark.sql.functions
/**
* Functionality for working with missing data in `DataFrame`s.
@@ -33,7 +33,6 @@ import org.apache.spark.sql.functions
*/
final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: Relation)
extends sql.DataFrameNaFunctions {
- import sparkSession.RichColumn
override protected def drop(minNonNulls: Option[Int]): DataFrame =
buildDropDataFrame(None, minNonNulls)
@@ -103,7 +102,7 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root:
sparkSession.newDataFrame { builder =>
val fillNaBuilder = builder.getFillNaBuilder.setInput(root)
values.map { case (colName, replaceValue) =>
- fillNaBuilder.addCols(colName).addValues(functions.lit(replaceValue).expr.getLiteral)
+ fillNaBuilder.addCols(colName).addValues(toLiteral(replaceValue).getLiteral)
}
}
}
@@ -143,8 +142,8 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root:
replacementMap.map { case (oldValue, newValue) =>
Replacement
.newBuilder()
- .setOldValue(functions.lit(oldValue).expr.getLiteral)
- .setNewValue(functions.lit(newValue).expr.getLiteral)
+ .setOldValue(toLiteral(oldValue).getLiteral)
+ .setNewValue(toLiteral(newValue).getLiteral)
.build()
}
}
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameStatFunctions.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameStatFunctions.scala
index f3c3f82a233a..a510afc716a7 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameStatFunctions.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameStatFunctions.scala
@@ -23,9 +23,9 @@ import org.apache.spark.connect.proto.{Relation, StatSampleBy}
import org.apache.spark.sql
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, PrimitiveDoubleEncoder}
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toLiteral}
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.DataFrameStatFunctions.approxQuantileResultEncoder
-import org.apache.spark.sql.functions.lit
/**
* Statistic functions for `DataFrame`s.
@@ -120,20 +120,19 @@ final class DataFrameStatFunctions private[sql] (protected val df: DataFrame)
/** @inheritdoc */
def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = {
- import sparkSession.RichColumn
require(
fractions.values.forall(p => p >= 0.0 && p <= 1.0),
s"Fractions must be in [0, 1], but got $fractions.")
sparkSession.newDataFrame { builder =>
val sampleByBuilder = builder.getSampleByBuilder
.setInput(root)
- .setCol(col.expr)
+ .setCol(toExpr(col))
.setSeed(seed)
fractions.foreach { case (k, v) =>
sampleByBuilder.addFractions(
StatSampleBy.Fraction
.newBuilder()
- .setStratum(lit(k).expr.getLiteral)
+ .setStratum(toLiteral(k).getLiteral)
.setFraction(v))
}
}
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriterV2.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriterV2.scala
index 42cf2cdfad58..06d339487bfb 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriterV2.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/DataFrameWriterV2.scala
@@ -23,6 +23,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.connect.proto
import org.apache.spark.sql
import org.apache.spark.sql.Column
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toExpr
/**
* Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2
@@ -33,7 +34,6 @@ import org.apache.spark.sql.Column
@Experimental
final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])
extends sql.DataFrameWriterV2[T] {
- import ds.sparkSession.RichColumn
private val builder = proto.WriteOperationV2
.newBuilder()
@@ -73,7 +73,7 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])
/** @inheritdoc */
@scala.annotation.varargs
override def partitionedBy(column: Column, columns: Column*): this.type = {
- builder.addAllPartitioningColumns((column +: columns).map(_.expr).asJava)
+ builder.addAllPartitioningColumns((column +: columns).map(toExpr).asJava)
this
}
@@ -106,7 +106,7 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])
/** @inheritdoc */
def overwrite(condition: Column): Unit = {
- builder.setOverwriteCondition(condition.expr)
+ builder.setOverwriteCondition(toExpr(condition))
executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE)
}
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
index 36003283a336..419ac3b7f74a 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.expressions.OrderUtils
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toLiteral, toTypedExpr}
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.client.SparkResult
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter}
@@ -140,7 +141,6 @@ class Dataset[T] private[sql] (
@DeveloperApi val plan: proto.Plan,
val encoder: Encoder[T])
extends sql.Dataset[T] {
- import sparkSession.RichColumn
// Make sure we don't forget to set plan id.
assert(plan.getRoot.getCommon.hasPlanId)
@@ -336,7 +336,7 @@ class Dataset[T] private[sql] (
buildJoin(right, Seq(joinExprs)) { builder =>
builder
.setJoinType(toJoinType(joinType))
- .setJoinCondition(joinExprs.expr)
+ .setJoinCondition(toExpr(joinExprs))
}
}
@@ -375,7 +375,7 @@ class Dataset[T] private[sql] (
.setLeft(plan.getRoot)
.setRight(other.plan.getRoot)
.setJoinType(joinTypeValue)
- .setJoinCondition(condition.expr)
+ .setJoinCondition(toExpr(condition))
.setJoinDataType(joinBuilder.getJoinDataTypeBuilder
.setIsLeftStruct(this.agnosticEncoder.isStruct)
.setIsRightStruct(other.agnosticEncoder.isStruct))
@@ -396,7 +396,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataFrame(joinExprs.toSeq) { builder =>
val lateralJoinBuilder = builder.getLateralJoinBuilder
lateralJoinBuilder.setLeft(plan.getRoot).setRight(right.plan.getRoot)
- joinExprs.foreach(c => lateralJoinBuilder.setJoinCondition(c.expr))
+ joinExprs.foreach(c => lateralJoinBuilder.setJoinCondition(toExpr(c)))
lateralJoinBuilder.setJoinType(joinTypeValue)
}
}
@@ -440,7 +440,7 @@ class Dataset[T] private[sql] (
builder.getHintBuilder
.setInput(plan.getRoot)
.setName(name)
- .addAllParameters(parameters.map(p => functions.lit(p).expr).asJava)
+ .addAllParameters(parameters.map(p => toLiteral(p)).asJava)
}
private def getPlanId: Option[Long] =
@@ -486,7 +486,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataset(encoder) { builder =>
builder.getProjectBuilder
.setInput(plan.getRoot)
- .addExpressions(col.typedExpr(this.encoder))
+ .addExpressions(toTypedExpr(col, this.encoder))
}
}
@@ -504,14 +504,14 @@ class Dataset[T] private[sql] (
sparkSession.newDataset(encoder, cols) { builder =>
builder.getProjectBuilder
.setInput(plan.getRoot)
- .addAllExpressions(cols.map(_.typedExpr(this.encoder)).asJava)
+ .addAllExpressions(cols.map(c => toTypedExpr(c, this.encoder)).asJava)
}
}
/** @inheritdoc */
def filter(condition: Column): Dataset[T] = {
sparkSession.newDataset(agnosticEncoder, Seq(condition)) { builder =>
- builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
+ builder.getFilterBuilder.setInput(plan.getRoot).setCondition(toExpr(condition))
}
}
@@ -523,12 +523,12 @@ class Dataset[T] private[sql] (
sparkSession.newDataFrame(ids.toSeq ++ valuesOption.toSeq.flatten) { builder =>
val unpivot = builder.getUnpivotBuilder
.setInput(plan.getRoot)
- .addAllIds(ids.toImmutableArraySeq.map(_.expr).asJava)
+ .addAllIds(ids.toImmutableArraySeq.map(toExpr).asJava)
.setVariableColumnName(variableColumnName)
.setValueColumnName(valueColumnName)
valuesOption.foreach { values =>
unpivot.getValuesBuilder
- .addAllValues(values.toImmutableArraySeq.map(_.expr).asJava)
+ .addAllValues(values.toImmutableArraySeq.map(toExpr).asJava)
}
}
}
@@ -537,7 +537,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataFrame(indices) { builder =>
val transpose = builder.getTransposeBuilder.setInput(plan.getRoot)
indices.foreach { indexColumn =>
- transpose.addIndexColumns(indexColumn.expr)
+ transpose.addIndexColumns(toExpr(indexColumn))
}
}
@@ -553,7 +553,7 @@ class Dataset[T] private[sql] (
function = func,
inputEncoders = agnosticEncoder :: agnosticEncoder :: Nil,
outputEncoder = agnosticEncoder)
- val reduceExpr = Column.fn("reduce", udf.apply(col("*"), col("*"))).expr
+ val reduceExpr = toExpr(Column.fn("reduce", udf.apply(col("*"), col("*"))))
val result = sparkSession
.newDataset(agnosticEncoder) { builder =>
@@ -590,7 +590,7 @@ class Dataset[T] private[sql] (
val groupingSetMsgs = groupingSets.map { groupingSet =>
val groupingSetMsg = proto.Aggregate.GroupingSets.newBuilder()
for (groupCol <- groupingSet) {
- groupingSetMsg.addGroupingSet(groupCol.expr)
+ groupingSetMsg.addGroupingSet(toExpr(groupCol))
}
groupingSetMsg.build()
}
@@ -779,7 +779,7 @@ class Dataset[T] private[sql] (
s"The size of column names: ${names.size} isn't equal to " +
s"the size of columns: ${values.size}")
val aliases = values.zip(names).map { case (value, name) =>
- value.name(name).expr.getAlias
+ toExpr(value.name(name)).getAlias
}
sparkSession.newDataFrame(values) { builder =>
builder.getWithColumnsBuilder
@@ -812,7 +812,7 @@ class Dataset[T] private[sql] (
def withMetadata(columnName: String, metadata: Metadata): DataFrame = {
val newAlias = proto.Expression.Alias
.newBuilder()
- .setExpr(col(columnName).expr)
+ .setExpr(toExpr(col(columnName)))
.addName(columnName)
.setMetadata(metadata.json)
sparkSession.newDataFrame { builder =>
@@ -845,7 +845,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataFrame(cols) { builder =>
builder.getDropBuilder
.setInput(plan.getRoot)
- .addAllColumns(cols.map(_.expr).asJava)
+ .addAllColumns(cols.map(toExpr).asJava)
}
}
@@ -915,7 +915,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataset[T](agnosticEncoder) { builder =>
builder.getFilterBuilder
.setInput(plan.getRoot)
- .setCondition(udf.apply(col("*")).expr)
+ .setCondition(toExpr(udf.apply(col("*"))))
}
}
@@ -944,7 +944,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataset(outputEncoder) { builder =>
builder.getMapPartitionsBuilder
.setInput(plan.getRoot)
- .setFunc(udf.apply(col("*")).expr.getCommonInlineUserDefinedFunction)
+ .setFunc(toExpr(udf.apply(col("*"))).getCommonInlineUserDefinedFunction)
}
}
@@ -1020,7 +1020,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataset(agnosticEncoder, partitionExprs) { builder =>
val repartitionBuilder = builder.getRepartitionByExpressionBuilder
.setInput(plan.getRoot)
- .addAllPartitionExprs(partitionExprs.map(_.expr).asJava)
+ .addAllPartitionExprs(partitionExprs.map(toExpr).asJava)
numPartitions.foreach(repartitionBuilder.setNumPartitions)
}
}
@@ -1036,7 +1036,7 @@ class Dataset[T] private[sql] (
// The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments.
// However, we don't want to complicate the semantics of this API method.
// Instead, let's give users a friendly error message, pointing them to the new method.
- val sortOrders = partitionExprs.filter(_.expr.hasSortOrder)
+ val sortOrders = partitionExprs.filter(e => toExpr(e).hasSortOrder)
if (sortOrders.nonEmpty) {
throw new IllegalArgumentException(
s"Invalid partitionExprs specified: $sortOrders\n" +
@@ -1050,7 +1050,7 @@ class Dataset[T] private[sql] (
partitionExprs: Seq[Column]): Dataset[T] = {
require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.")
val sortExprs = partitionExprs.map {
- case e if e.expr.hasSortOrder => e
+ case e if toExpr(e).hasSortOrder => e
case e => e.asc
}
buildRepartitionByExpression(numPartitions, sortExprs)
@@ -1158,7 +1158,7 @@ class Dataset[T] private[sql] (
builder.getCollectMetricsBuilder
.setInput(plan.getRoot)
.setName(name)
- .addAllMetrics((expr +: exprs).map(_.expr).asJava)
+ .addAllMetrics((expr +: exprs).map(toExpr).asJava)
}
}
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala
index c984582ed6ae..dc494649b397 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/KeyValueGroupedDataset.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql
import org.apache.spark.sql.{Column, Encoder, TypedColumn}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, ProductEncoder}
-import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toExpr
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toTypedExpr}
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfUtils}
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
@@ -394,7 +394,6 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
private val valueMapFunc: Option[IV => V],
private val keysFunc: () => Dataset[IK])
extends KeyValueGroupedDataset[K, V] {
- import sparkSession.RichColumn
override def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = {
new KeyValueGroupedDatasetImpl[L, V, IK, IV](
@@ -436,7 +435,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
sparkSession.newDataset[U](outputEncoder) { builder =>
builder.getGroupMapBuilder
.setInput(plan.getRoot)
- .addAllSortingExpressions(sortExprs.map(e => e.expr).asJava)
+ .addAllSortingExpressions(sortExprs.map(toExpr).asJava)
.addAllGroupingExpressions(groupingExprs)
.setFunc(getUdf(nf, outputEncoder)(ivEncoder))
}
@@ -453,10 +452,10 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
builder.getCoGroupMapBuilder
.setInput(plan.getRoot)
.addAllInputGroupingExpressions(groupingExprs)
- .addAllInputSortingExpressions(thisSortExprs.map(e => e.expr).asJava)
+ .addAllInputSortingExpressions(thisSortExprs.map(toExpr).asJava)
.setOther(otherImpl.plan.getRoot)
.addAllOtherGroupingExpressions(otherImpl.groupingExprs)
- .addAllOtherSortingExpressions(otherSortExprs.map(e => e.expr).asJava)
+ .addAllOtherSortingExpressions(otherSortExprs.map(toExpr).asJava)
.setFunc(getUdf(nf, outputEncoder)(ivEncoder, otherImpl.ivEncoder))
}
}
@@ -469,7 +468,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
.setInput(plan.getRoot)
.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
.addAllGroupingExpressions(groupingExprs)
- .addAllAggregateExpressions(columns.map(_.typedExpr(vEncoder)).asJava)
+ .addAllAggregateExpressions(columns.map(c => toTypedExpr(c, vEncoder)).asJava)
}
}
@@ -534,7 +533,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
function = nf,
inputEncoders = inputEncoders,
outputEncoder = outputEncoder)
- udf.apply(inputEncoders.map(_ => col("*")): _*).expr.getCommonInlineUserDefinedFunction
+ toExpr(udf.apply(inputEncoders.map(_ => col("*")): _*)).getCommonInlineUserDefinedFunction
}
private def getUdf[U: Encoder, S: Encoder](
@@ -549,7 +548,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
function = nf,
inputEncoders = inputEncoders,
outputEncoder = outputEncoder)
- udf.apply(inputEncoders.map(_ => col("*")): _*).expr.getCommonInlineUserDefinedFunction
+ toExpr(udf.apply(inputEncoders.map(_ => col("*")): _*)).getCommonInlineUserDefinedFunction
}
/**
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/MergeIntoWriter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/MergeIntoWriter.scala
index c245a8644a3c..66354e63ca8a 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/MergeIntoWriter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/MergeIntoWriter.scala
@@ -24,6 +24,7 @@ import org.apache.spark.connect.proto.{Expression, MergeAction, MergeIntoTableCo
import org.apache.spark.connect.proto.MergeAction.ActionType._
import org.apache.spark.sql
import org.apache.spark.sql.Column
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toExpr
import org.apache.spark.sql.functions.expr
/**
@@ -44,13 +45,12 @@ import org.apache.spark.sql.functions.expr
@Experimental
class MergeIntoWriter[T] private[sql] (table: String, ds: Dataset[T], on: Column)
extends sql.MergeIntoWriter[T] {
- import ds.sparkSession.RichColumn
private val builder = MergeIntoTableCommand
.newBuilder()
.setTargetTableName(table)
.setSourceTablePlan(ds.plan.getRoot)
- .setMergeCondition(on.expr)
+ .setMergeCondition(toExpr(on))
/**
* Executes the merge operation.
@@ -121,12 +121,12 @@ class MergeIntoWriter[T] private[sql] (table: String, ds: Dataset[T], on: Column
condition: Option[Column],
assignments: Map[String, Column] = Map.empty): Expression = {
val builder = proto.MergeAction.newBuilder().setActionType(actionType)
- condition.foreach(c => builder.setCondition(c.expr))
+ condition.foreach(c => builder.setCondition(toExpr(c)))
assignments.foreach { case (k, v) =>
builder
.addAssignmentsBuilder()
- .setKey(expr(k).expr)
- .setValue(v.expr)
+ .setKey(toExpr(expr(k)))
+ .setValue(toExpr(v))
}
Expression
.newBuilder()
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RelationalGroupedDataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RelationalGroupedDataset.scala
index 00dc1fb6906f..ac361047bbd0 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RelationalGroupedDataset.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/RelationalGroupedDataset.scala
@@ -23,6 +23,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.sql
import org.apache.spark.sql.{functions, Column, Encoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toTypedExpr}
import org.apache.spark.sql.connect.ConnectConversions._
/**
@@ -44,14 +45,13 @@ class RelationalGroupedDataset private[sql] (
pivot: Option[proto.Aggregate.Pivot] = None,
groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None)
extends sql.RelationalGroupedDataset {
- import df.sparkSession.RichColumn
protected def toDF(aggExprs: Seq[Column]): DataFrame = {
df.sparkSession.newDataFrame(groupingExprs ++ aggExprs) { builder =>
val aggBuilder = builder.getAggregateBuilder
.setInput(df.plan.getRoot)
- groupingExprs.foreach(c => aggBuilder.addGroupingExpressions(c.expr))
- aggExprs.foreach(c => aggBuilder.addAggregateExpressions(c.typedExpr(df.encoder)))
+ groupingExprs.foreach(c => aggBuilder.addGroupingExpressions(toExpr(c)))
+ aggExprs.foreach(c => aggBuilder.addAggregateExpressions(toTypedExpr(c, df.encoder)))
groupType match {
case proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP =>
@@ -152,10 +152,13 @@ class RelationalGroupedDataset private[sql] (
groupType match {
case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY =>
val valueExprs = values.map {
- case c: Column if c.expr.hasLiteral => c.expr.getLiteral
- case c: Column if !c.expr.hasLiteral =>
- throw new IllegalArgumentException("values only accept literal Column")
- case v => functions.lit(v).expr.getLiteral
+ case c: Column =>
+ val e = toExpr(c)
+ if (!e.hasLiteral) {
+ throw new IllegalArgumentException("values only accept literal Column")
+ }
+ e.getLiteral
+ case v => toExpr(functions.lit(v)).getLiteral
}
new RelationalGroupedDataset(
df,
@@ -164,7 +167,7 @@ class RelationalGroupedDataset private[sql] (
Some(
proto.Aggregate.Pivot
.newBuilder()
- .setCol(pivotColumn.expr)
+ .setCol(toExpr(pivotColumn))
.addAllValues(valueExprs.asJava)
.build()))
case _ =>
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
index f7998cf60eca..032ab670dab0 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
@@ -43,11 +43,10 @@ import org.apache.spark.sql.{Column, Encoder, ExperimentalMethods, Observation,
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, BoxedLongEncoder, UnboundRowEncoder}
-import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toTypedExpr}
+import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
-import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
@@ -213,7 +212,7 @@ class SparkSession private[sql] (
val sqlCommand = proto.SqlCommand
.newBuilder()
.setSql(sqlText)
- .addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava)
+ .addAllPosArguments(args.map(a => toLiteral(a)).toImmutableArraySeq.asJava)
.build()
sql(sqlCommand)
}
@@ -228,7 +227,7 @@ class SparkSession private[sql] (
val sqlCommand = proto.SqlCommand
.newBuilder()
.setSql(sqlText)
- .putAllNamedArguments(args.asScala.map { case (k, v) => (k, lit(v).expr) }.asJava)
+ .putAllNamedArguments(args.asScala.map { case (k, v) => (k, toLiteral(v)) }.asJava)
.build()
sql(sqlCommand)
}
@@ -653,11 +652,6 @@ class SparkSession private[sql] (
}
override private[sql] def isUsable: Boolean = client.isSessionValid
-
- implicit class RichColumn(c: Column) {
- def expr: proto.Expression = toExpr(c)
- def typedExpr[T](e: Encoder[T]): proto.Expression = toTypedExpr(c, e)
- }
}
// The minimal builder needed to create a spark session.
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala
index f44ec5b2d504..f08b6e709f13 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala
@@ -24,7 +24,7 @@ import org.apache.spark.connect.proto.Expression
import org.apache.spark.connect.proto.Expression.SortOrder.NullOrdering.{SORT_NULLS_FIRST, SORT_NULLS_LAST}
import org.apache.spark.connect.proto.Expression.SortOrder.SortDirection.{SORT_DIRECTION_ASCENDING, SORT_DIRECTION_DESCENDING}
import org.apache.spark.connect.proto.Expression.Window.WindowFrame.{FrameBoundary, FrameType}
-import org.apache.spark.sql.{Column, Encoder}
+import org.apache.spark.sql.{functions, Column, Encoder}
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProtoBuilder
@@ -37,6 +37,8 @@ import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, ColumnNode
object ColumnNodeToProtoConverter extends (ColumnNode => proto.Expression) {
def toExpr(column: Column): proto.Expression = apply(column.node, None)
+ def toLiteral(v: Any): proto.Expression = apply(functions.lit(v).node, None)
+
def toTypedExpr[I](column: Column, encoder: Encoder[I]): proto.Expression = {
apply(column.node, Option(encoder))
}
diff --git a/sql/connect/common/src/test/resources/query-tests/queries/hint.json b/sql/connect/common/src/test/resources/query-tests/queries/hint.json
index 2ac930c0a3a7..2348d0f84715 100644
--- a/sql/connect/common/src/test/resources/query-tests/queries/hint.json
+++ b/sql/connect/common/src/test/resources/query-tests/queries/hint.json
@@ -22,13 +22,13 @@
"stackTrace": [{
"classLoaderName": "app",
"declaringClass": "org.apache.spark.sql.connect.Dataset",
- "methodName": "~~trimmed~anonfun~~",
+ "methodName": "hint",
"fileName": "Dataset.scala"
}, {
"classLoaderName": "app",
- "declaringClass": "org.apache.spark.sql.connect.SparkSession",
- "methodName": "newDataset",
- "fileName": "SparkSession.scala"
+ "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+ "methodName": "~~trimmed~anonfun~~",
+ "fileName": "PlanGenerationTestSuite.scala"
}]
}
}
diff --git a/sql/connect/common/src/test/resources/query-tests/queries/hint.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/hint.proto.bin
index 06459ee5b765..ce7c63b57c47 100644
Binary files a/sql/connect/common/src/test/resources/query-tests/queries/hint.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/hint.proto.bin differ
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]