This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new ba8cae2031f [SPARK-43223][CONNECT] Typed agg, reduce functions, RelationalGroupedDataset#as
ba8cae2031f is described below

commit ba8cae2031f81dc326d386cbe7d19c1f0a8f239e
Author: Zhen Li <zhenli...@users.noreply.github.com>
AuthorDate: Mon May 15 11:05:33 2023 -0400

[SPARK-43223][CONNECT] Typed agg, reduce functions, RelationalGroupedDataset#as

### What changes were proposed in this pull request?
Added agg and reduce support to `KeyValueGroupedDataset`.
Added `Dataset#reduce`.
Added `RelationalGroupedDataset#as`.

Summary:
* `KVGDS#agg`: `KVGDS#agg` and `RelationalGroupedDS#agg` share the exact same proto. The only difference is that the KVGDS always passes a UDF as the first grouping expression. That is also how the two are told apart in this PR.
* `KVGDS#reduce`: Reduce is a special aggregation. The client uses an UnresolvedFunction named "reduce" to mark that the agg operator is a `ReduceAggregator`, and then calls `KVGDS#agg` directly. The server picks this function up and reuses the agg code path by plugging in a `ReduceAggregator`.
* `Dataset#reduce`: This comes for free after `KVGDS#reduce`.
* `RelationalGroupedDS#as`: The only difference between a `KVGDS` created via `ds#groupByKey` and one created via `ds#groupBy#as` is the grouping expressions. The former requires exactly one grouping func as the grouping expression, while the latter uses a dummy func (to pass encoders/types to the server) plus the grouping expressions. Thus the server can count how many grouping expressions it received and decide whether the `KVGDS` was created via `ds#groupByKey` or `ds#groupBy#as`.

Followups:
* [SPARK-43415] Support mapValues in the agg functions.
* [SPARK-43416] The tupled ProductEncoder does not pick up the field names from the server.

### Why are the changes needed?
Missing APIs in the Scala client.

### Does this PR introduce _any_ user-facing change?
Added `KeyValueGroupedDataset#agg`, `KeyValueGroupedDataset#reduceGroups`, `Dataset#reduce`, and `RelationalGroupedDataset#as` methods for the Scala client.

### How was this patch tested?
E2E tests

Closes #40796 from zhenlineo/typed-agg.
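For illustration only (not part of the patch itself), a minimal sketch of how the new Scala client APIs fit together; the helper object, sample data, and the connection string mentioned in the comment are placeholders:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object TypedAggExample {
  // `spark` is assumed to be an already-connected Spark Connect client session,
  // e.g. built from SparkSession.builder().remote("sc://localhost:15002").
  def run(spark: SparkSession): Unit = {
    import spark.implicits._

    // Dataset#reduce: folds all elements with a commutative, associative function.
    val total: Long = spark.range(10).map(_ + 1).reduce(_ + _) // 55

    // KeyValueGroupedDataset: typed agg and reduceGroups on a groupByKey.
    val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()
    val sums = ds.groupByKey(_._1).agg(sum("_2").as[Long])             // ("a", 30), ("b", 1)
    val reduced = ds.groupByKey(_._1).reduceGroups((l, r) => (l._1, l._2 + r._2))

    // RelationalGroupedDataset#as: turns an untyped groupBy into a typed KVGDS.
    val kvgds = ds.toDF("key", "value").groupBy($"key").as[String, (String, Int)]
    val counts = kvgds.count()                                          // ("a", 2), ("b", 1)

    println((total, sums.collect().toSeq, reduced.collect().toSeq, counts.collect().toSeq))
  }
}
```

Under the hood, both `reduce` calls above are sent as an UnresolvedFunction named "reduce" wrapping the client UDF, which the server rewrites into a `ReduceAggregator` (see the planner changes below).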
Authored-by: Zhen Li <zhenli...@users.noreply.github.com> Signed-off-by: Herman van Hovell <her...@databricks.com> --- .../main/scala/org/apache/spark/sql/Dataset.scala | 66 +++-- .../apache/spark/sql/KeyValueGroupedDataset.scala | 255 ++++++++++++++++-- .../spark/sql/RelationalGroupedDataset.scala | 14 +- .../sql/KeyValueGroupedDatasetE2ETestSuite.scala | 290 ++++++++++++++++++--- .../sql/UserDefinedFunctionE2ETestSuite.scala | 18 ++ .../CheckConnectJvmClientCompatibility.scala | 8 - .../spark/sql/connect/client/util/QueryTest.scala | 36 ++- .../apache/spark/sql/connect/common/UdfUtils.scala | 4 + .../sql/connect/planner/SparkConnectPlanner.scala | 209 +++++++++++---- .../spark/sql/catalyst/plans/logical/object.scala | 16 ++ .../main/scala/org/apache/spark/sql/Column.scala | 13 +- .../apache/spark/sql/KeyValueGroupedDataset.scala | 15 +- .../spark/sql/RelationalGroupedDataset.scala | 53 ++-- .../spark/sql/expressions/ReduceAggregator.scala | 6 + .../apache/spark/sql/internal/TypedAggUtils.scala | 62 +++++ 15 files changed, 883 insertions(+), 182 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 555f6c312c5..7a680bde7d3 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1242,10 +1242,7 @@ class Dataset[T] private[sql] ( */ @scala.annotation.varargs def groupBy(cols: Column*): RelationalGroupedDataset = { - new RelationalGroupedDataset( - toDF(), - cols.map(_.expr), - proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY) + new RelationalGroupedDataset(toDF(), cols, proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY) } /** @@ -1273,10 +1270,45 @@ class Dataset[T] private[sql] ( val colNames: Seq[String] = col1 +: cols new RelationalGroupedDataset( toDF(), - colNames.map(colName => Column(colName).expr), + colNames.map(colName => Column(colName)), proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY) } + /** + * (Scala-specific) Reduces the elements of this Dataset using the specified binary function. + * The given `func` must be commutative and associative or the result may be non-deterministic. + * + * @group action + * @since 3.5.0 + */ + def reduce(func: (T, T) => T): T = { + val udf = ScalarUserDefinedFunction( + function = func, + inputEncoders = encoder :: encoder :: Nil, + outputEncoder = encoder) + val reduceExpr = Column.fn("reduce", udf.apply(col("*"), col("*"))).expr + + val result = sparkSession + .newDataset(encoder) { builder => + builder.getAggregateBuilder + .setInput(plan.getRoot) + .addAggregateExpressions(reduceExpr) + .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY) + } + .collect() + assert(result.length == 1) + result(0) + } + + /** + * (Java-specific) Reduces the elements of this Dataset using the specified binary function. The + * given `func` must be commutative and associative or the result may be non-deterministic. + * + * @group action + * @since 3.5.0 + */ + def reduce(func: ReduceFunction[T]): T = reduce(UdfUtils.mapReduceFuncToScalaFunc(func)) + /** * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given * key `func`. 
@@ -1285,15 +1317,7 @@ class Dataset[T] private[sql] ( * @since 3.5.0 */ def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { - val kEncoder = encoderFor[K] - new KeyValueGroupedDatasetImpl[K, T, K, T]( - this, - sparkSession, - plan, - kEncoder, - kEncoder, - func, - UdfUtils.identical()) + KeyValueGroupedDatasetImpl[K, T](this, encoderFor[K], func) } /** @@ -1327,10 +1351,7 @@ class Dataset[T] private[sql] ( */ @scala.annotation.varargs def rollup(cols: Column*): RelationalGroupedDataset = { - new RelationalGroupedDataset( - toDF(), - cols.map(_.expr), - proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP) + new RelationalGroupedDataset(toDF(), cols, proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP) } /** @@ -1360,7 +1381,7 @@ class Dataset[T] private[sql] ( val colNames: Seq[String] = col1 +: cols new RelationalGroupedDataset( toDF(), - colNames.map(colName => Column(colName).expr), + colNames.map(colName => Column(colName)), proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP) } @@ -1385,10 +1406,7 @@ class Dataset[T] private[sql] ( */ @scala.annotation.varargs def cube(cols: Column*): RelationalGroupedDataset = { - new RelationalGroupedDataset( - toDF(), - cols.map(_.expr), - proto.Aggregate.GroupType.GROUP_TYPE_CUBE) + new RelationalGroupedDataset(toDF(), cols, proto.Aggregate.GroupType.GROUP_TYPE_CUBE) } /** @@ -1417,7 +1435,7 @@ class Dataset[T] private[sql] ( val colNames: Seq[String] = col1 +: cols new RelationalGroupedDataset( toDF(), - colNames.map(colName => Column(colName).expr), + colNames.map(colName => Column(colName)), proto.Aggregate.GroupType.GROUP_TYPE_CUBE) } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 2d712bc4c51..7b2fa3b52be 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -25,6 +25,7 @@ import scala.language.existentials import org.apache.spark.api.java.function._ import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder import org.apache.spark.sql.connect.common.UdfUtils import org.apache.spark.sql.expressions.ScalarUserDefinedFunction import org.apache.spark.sql.functions.col @@ -239,6 +240,153 @@ abstract class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable mapGroups(UdfUtils.mapGroupsFuncToScalaFunc(f))(encoder) } + /** + * (Scala-specific) Reduces the elements of each group of data using the specified binary + * function. The given function must be commutative and associative or the result may be + * non-deterministic. + * + * @since 3.5.0 + */ + def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = { + throw new UnsupportedOperationException + } + + /** + * (Java-specific) Reduces the elements of each group of data using the specified binary + * function. The given function must be commutative and associative or the result may be + * non-deterministic. + * + * @since 3.5.0 + */ + def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = { + reduceGroups(UdfUtils.mapReduceFuncToScalaFunc(f)) + } + + /** + * Internal helper function for building typed aggregations that return tuples. 
For simplicity + * and code reuse, we do this without the help of the type system and then use helper functions + * that cast appropriately for the user facing interface. + */ + protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { + throw new UnsupportedOperationException + } + + /** + * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key and the + * result of computing this aggregation over all elements in the group. + * + * @since 3.5.0 + */ + def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = + aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and + * the result of computing these aggregations over all elements in the group. + * + * @since 3.5.0 + */ + def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = + aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and + * the result of computing these aggregations over all elements in the group. + * + * @since 3.5.0 + */ + def agg[U1, U2, U3]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = + aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and + * the result of computing these aggregations over all elements in the group. + * + * @since 3.5.0 + */ + def agg[U1, U2, U3, U4]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = + aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and + * the result of computing these aggregations over all elements in the group. + * + * @since 3.5.0 + */ + def agg[U1, U2, U3, U4, U5]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] = + aggUntyped(col1, col2, col3, col4, col5).asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and + * the result of computing these aggregations over all elements in the group. + * + * @since 3.5.0 + */ + def agg[U1, U2, U3, U4, U5, U6]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] = + aggUntyped(col1, col2, col3, col4, col5, col6) + .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and + * the result of computing these aggregations over all elements in the group. 
+ * + * @since 3.5.0 + */ + def agg[U1, U2, U3, U4, U5, U6, U7]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6], + col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] = + aggUntyped(col1, col2, col3, col4, col5, col6, col7) + .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and + * the result of computing these aggregations over all elements in the group. + * + * @since 3.5.0 + */ + def agg[U1, U2, U3, U4, U5, U6, U7, U8]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6], + col7: TypedColumn[V, U7], + col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] = + aggUntyped(col1, col2, col3, col4, col5, col6, col7, col8) + .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)]] + + /** + * Returns a [[Dataset]] that contains a tuple with each key and the number of items present for + * that key. + * + * @since 3.5.0 + */ + def count(): Dataset[(K, Long)] = agg(functions.count("*")) + /** * (Scala-specific) Applies the given function to each cogrouped data. For each unique group, * the function will be passed the grouping key and 2 iterators containing all elements in the @@ -322,41 +470,45 @@ abstract class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable * [[KeyValueGroupedDataset]] behaves on the server. */ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( - private val ds: Dataset[IV], private val sparkSession: SparkSession, private val plan: proto.Plan, private val ikEncoder: AgnosticEncoder[IK], private val kEncoder: AgnosticEncoder[K], - private val groupingFunc: IV => IK, - private val valueMapFunc: IV => V) + private val ivEncoder: AgnosticEncoder[IV], + private val vEncoder: AgnosticEncoder[V], + private val groupingExprs: java.util.List[proto.Expression], + private val valueMapFunc: IV => V, + private val keysFunc: () => Dataset[IK]) extends KeyValueGroupedDataset[K, V] { - private val ivEncoder = ds.encoder - override def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = { new KeyValueGroupedDatasetImpl[L, V, IK, IV]( - ds, sparkSession, plan, ikEncoder, encoderFor[L], - groupingFunc, - valueMapFunc) + ivEncoder, + vEncoder, + groupingExprs, + valueMapFunc, + keysFunc) } override def mapValues[W: Encoder](valueFunc: V => W): KeyValueGroupedDataset[K, W] = { new KeyValueGroupedDatasetImpl[K, W, IK, IV]( - ds, sparkSession, plan, ikEncoder, kEncoder, - groupingFunc, - valueMapFunc.andThen(valueFunc)) + ivEncoder, + encoderFor[W], + groupingExprs, + valueMapFunc.andThen(valueFunc), + keysFunc) } override def keys: Dataset[K] = { - ds.map(groupingFunc)(ikEncoder) + keysFunc() .dropDuplicates() .as(kEncoder) } @@ -371,7 +523,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( builder.getGroupMapBuilder .setInput(plan.getRoot) .addAllSortingExpressions(sortExprs.map(e => e.expr).asJava) - .addAllGroupingExpressions(getGroupingExpressions) + .addAllGroupingExpressions(groupingExprs) .setFunc(getUdf(nf, outputEncoder)(ivEncoder)) } } @@ -387,21 +539,37 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( sparkSession.newDataset[R](outputEncoder) { builder => builder.getCoGroupMapBuilder .setInput(plan.getRoot) - .addAllInputGroupingExpressions(getGroupingExpressions) + 
.addAllInputGroupingExpressions(groupingExprs) .addAllInputSortingExpressions(thisSortExprs.map(e => e.expr).asJava) .setOther(otherImpl.plan.getRoot) - .addAllOtherGroupingExpressions(otherImpl.getGroupingExpressions) + .addAllOtherGroupingExpressions(otherImpl.groupingExprs) .addAllOtherSortingExpressions(otherSortExprs.map(e => e.expr).asJava) .setFunc(getUdf(nf, outputEncoder)(ivEncoder, otherImpl.ivEncoder)) } } - private def getGroupingExpressions = { - val gf = ScalarUserDefinedFunction( - function = groupingFunc, - inputEncoders = ivEncoder :: Nil, // Using the original value and key encoders - outputEncoder = ikEncoder) - Arrays.asList(gf.apply(col("*")).expr) + override protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { + // TODO(SPARK-43415): For each column, apply the valueMap func first + val rEnc = ProductEncoder.tuple(kEncoder +: columns.map(_.encoder)) // apply keyAs change + sparkSession.newDataset(rEnc) { builder => + builder.getAggregateBuilder + .setInput(plan.getRoot) + .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY) + .addAllGroupingExpressions(groupingExprs) + .addAllAggregateExpressions(columns.map(_.expr).asJava) + } + } + + override def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = { + val inputEncoders = Seq(vEncoder, vEncoder) + val udf = ScalarUserDefinedFunction( + function = f, + inputEncoders = inputEncoders, + outputEncoder = vEncoder) + val input = udf.apply(inputEncoders.map(_ => col("*")): _*) + val expr = Column.fn("reduce", input).expr + val aggregator: TypedColumn[V, V] = new TypedColumn[V, V](expr, vEncoder) + agg(aggregator) } private def getUdf[U: Encoder](nf: AnyRef, outputEncoder: AgnosticEncoder[U])( @@ -414,3 +582,48 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( udf.apply(inputEncoders.map(_ => col("*")): _*).expr.getCommonInlineUserDefinedFunction } } + +private object KeyValueGroupedDatasetImpl { + def apply[K, V]( + ds: Dataset[V], + kEncoder: AgnosticEncoder[K], + groupingFunc: V => K): KeyValueGroupedDatasetImpl[K, V, K, V] = { + val gf = ScalarUserDefinedFunction( + function = groupingFunc, + inputEncoders = ds.encoder :: Nil, // Using the original value and key encoders + outputEncoder = kEncoder) + new KeyValueGroupedDatasetImpl( + ds.sparkSession, + ds.plan, + kEncoder, + kEncoder, + ds.encoder, + ds.encoder, + Arrays.asList(gf.apply(col("*")).expr), + UdfUtils.identical(), + () => ds.map(groupingFunc)(kEncoder)) + } + + def apply[K, V]( + df: DataFrame, + kEncoder: AgnosticEncoder[K], + vEncoder: AgnosticEncoder[V], + groupingExprs: Seq[Column]): KeyValueGroupedDatasetImpl[K, V, K, V] = { + // Use a dummy udf to pass the K V encoders + val dummyGroupingFunc = ScalarUserDefinedFunction( + function = UdfUtils.noOp[V, K](), + inputEncoders = vEncoder :: Nil, + outputEncoder = kEncoder).apply(col("*")) + + new KeyValueGroupedDatasetImpl( + df.sparkSession, + df.plan, + kEncoder, + kEncoder, + vEncoder, + vEncoder, + (Seq(dummyGroupingFunc) ++ groupingExprs).map(_.expr).asJava, + UdfUtils.identical(), + () => df.select(groupingExprs: _*).as(kEncoder)) + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 5a10e1d52eb..c19314a0d5c 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -37,7 +37,7 @@ import org.apache.spark.connect.proto */ class RelationalGroupedDataset private[sql] ( private[sql] val df: DataFrame, - private[sql] val groupingExprs: Seq[proto.Expression], + private[sql] val groupingExprs: Seq[Column], groupType: proto.Aggregate.GroupType, pivot: Option[proto.Aggregate.Pivot] = None) { @@ -45,7 +45,7 @@ class RelationalGroupedDataset private[sql] ( df.sparkSession.newDataFrame { builder => builder.getAggregateBuilder .setInput(df.plan.getRoot) - .addAllGroupingExpressions(groupingExprs.asJava) + .addAllGroupingExpressions(groupingExprs.map(_.expr).asJava) .addAllAggregateExpressions(aggExprs.map(e => e.expr).asJava) groupType match { @@ -65,6 +65,16 @@ class RelationalGroupedDataset private[sql] ( } } + /** + * Returns a `KeyValueGroupedDataset` where the data is grouped by the grouping expressions of + * current `RelationalGroupedDataset`. + * + * @since 3.5.0 + */ + def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] = { + KeyValueGroupedDatasetImpl[K, T](df, encoderFor[K], encoderFor[T], groupingExprs) + } + /** * (Scala-specific) Compute aggregates by specifying the column names and aggregate methods. The * resulting `DataFrame` will also contain the grouping columns. diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala index 097efa01a42..e7a77eed70d 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala @@ -20,13 +20,17 @@ import java.util.Arrays import io.grpc.StatusRuntimeException -import org.apache.spark.sql.connect.client.util.RemoteSparkSession +import org.apache.spark.sql.connect.client.util.QueryTest +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ /** - * All tests in this class requires client UDF artifacts synced with the server. TODO: It means - * these tests only works with SBT for now. + * All tests in this class requires client UDF artifacts synced with the server. */ -class KeyValueGroupedDatasetE2ETestSuite extends RemoteSparkSession { +class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { + + lazy val session: SparkSession = spark + import session.implicits._ test("mapGroups") { val session: SparkSession = spark @@ -40,8 +44,6 @@ class KeyValueGroupedDatasetE2ETestSuite extends RemoteSparkSession { } test("flatGroupMap") { - val session: SparkSession = spark - import session.implicits._ val values = spark .range(10) .groupByKey(v => v % 2) @@ -51,8 +53,6 @@ class KeyValueGroupedDatasetE2ETestSuite extends RemoteSparkSession { } test("keys") { - val session: SparkSession = spark - import session.implicits._ val values = spark .range(10) .groupByKey(v => v % 2) @@ -63,8 +63,6 @@ class KeyValueGroupedDatasetE2ETestSuite extends RemoteSparkSession { test("keyAs - keys") { // It is okay to cast from Long to Double, but not Long to Int. 
- val session: SparkSession = spark - import session.implicits._ val values = spark .range(10) .groupByKey(v => v % 2) @@ -75,8 +73,6 @@ class KeyValueGroupedDatasetE2ETestSuite extends RemoteSparkSession { } test("keyAs - flatGroupMap") { - val session: SparkSession = spark - import session.implicits._ val values = spark .range(10) .groupByKey(v => v % 2) @@ -87,8 +83,6 @@ class KeyValueGroupedDatasetE2ETestSuite extends RemoteSparkSession { } test("keyAs mapValues - cogroup") { - val session: SparkSession = spark - import session.implicits._ val grouped = spark .range(10) .groupByKey(v => v % 2) @@ -120,8 +114,6 @@ class KeyValueGroupedDatasetE2ETestSuite extends RemoteSparkSession { } test("mapValues - flatGroupMap") { - val session: SparkSession = spark - import session.implicits._ val values = spark .range(10) .groupByKey(v => v % 2) @@ -132,8 +124,6 @@ class KeyValueGroupedDatasetE2ETestSuite extends RemoteSparkSession { } test("mapValues - keys") { - val session: SparkSession = spark - import session.implicits._ val values = spark .range(10) .groupByKey(v => v % 2) @@ -144,13 +134,11 @@ class KeyValueGroupedDatasetE2ETestSuite extends RemoteSparkSession { } test("flatMapSortedGroups") { - val session: SparkSession = spark - import session.implicits._ val grouped = spark .range(10) .groupByKey(v => v % 2) val values = grouped - .flatMapSortedGroups(functions.desc("id")) { (g, iter) => + .flatMapSortedGroups(desc("id")) { (g, iter) => Iterator(String.valueOf(g), iter.mkString(",")) } .collectAsList() @@ -160,7 +148,7 @@ class KeyValueGroupedDatasetE2ETestSuite extends RemoteSparkSession { // Star is not allowed as group sort column val message = intercept[StatusRuntimeException] { grouped - .flatMapSortedGroups(functions.col("*")) { (g, iter) => + .flatMapSortedGroups(col("*")) { (g, iter) => Iterator(String.valueOf(g), iter.mkString(",")) } .collectAsList() @@ -169,8 +157,6 @@ class KeyValueGroupedDatasetE2ETestSuite extends RemoteSparkSession { } test("cogroup") { - val session: SparkSession = spark - import session.implicits._ val grouped = spark .range(10) .groupByKey(v => v % 2) @@ -187,8 +173,6 @@ class KeyValueGroupedDatasetE2ETestSuite extends RemoteSparkSession { } test("cogroupSorted") { - val session: SparkSession = spark - import session.implicits._ val grouped = spark .range(10) .groupByKey(v => v % 2) @@ -196,9 +180,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends RemoteSparkSession { .range(10) .groupByKey(v => v / 2) val values = grouped - .cogroupSorted(otherGrouped)(functions.desc("id"))(functions.desc("id")) { - (k, it, otherIt) => - Iterator(String.valueOf(k), it.mkString(",") + ";" + otherIt.mkString(",")) + .cogroupSorted(otherGrouped)(desc("id"))(desc("id")) { (k, it, otherIt) => + Iterator(String.valueOf(k), it.mkString(",") + ";" + otherIt.mkString(",")) } .collectAsList() @@ -215,4 +198,253 @@ class KeyValueGroupedDatasetE2ETestSuite extends RemoteSparkSession { "4", ";9,8")) } + + test("agg, keyAs") { + val ds = spark + .range(10) + .groupByKey(v => v % 2) + .keyAs[Double] + .agg(count("*")) + + checkDatasetUnorderly(ds, (0.0, 5L), (1.0, 5L)) + } + + test("typed aggregation: expr") { + val session: SparkSession = spark + import session.implicits._ + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkDatasetUnorderly( + ds.groupByKey(_._1).agg(sum("_2").as[Long]), + ("a", 30L), + ("b", 3L), + ("c", 1L)) + } + + test("typed aggregation: expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + 
+ checkDatasetUnorderly( + ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]), + ("a", 30L, 32L), + ("b", 3L, 5L), + ("c", 1L, 2L)) + } + + test("typed aggregation: expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkDatasetUnorderly( + ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")), + ("a", 30L, 32L, 2L), + ("b", 3L, 5L, 2L), + ("c", 1L, 2L, 1L)) + } + + test("typed aggregation: expr, expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkDatasetUnorderly( + ds.groupByKey(_._1) + .agg( + sum("_2").as[Long], + sum($"_2" + 1).as[Long], + count("*").as[Long], + avg("_2").as[Double]), + ("a", 30L, 32L, 2L, 15.0), + ("b", 3L, 5L, 2L, 1.5), + ("c", 1L, 2L, 1L, 1.0)) + } + + test("typed aggregation: expr, expr, expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkDatasetUnorderly( + ds.groupByKey(_._1) + .agg( + sum("_2").as[Long], + sum($"_2" + 1).as[Long], + count("*").as[Long], + avg("_2").as[Double], + countDistinct("*").as[Long]), + ("a", 30L, 32L, 2L, 15.0, 2L), + ("b", 3L, 5L, 2L, 1.5, 2L), + ("c", 1L, 2L, 1L, 1.0, 1L)) + } + + test("typed aggregation: expr, expr, expr, expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkDatasetUnorderly( + ds.groupByKey(_._1) + .agg( + sum("_2").as[Long], + sum($"_2" + 1).as[Long], + count("*").as[Long], + avg("_2").as[Double], + countDistinct("*").as[Long], + max("_2").as[Long]), + ("a", 30L, 32L, 2L, 15.0, 2L, 20L), + ("b", 3L, 5L, 2L, 1.5, 2L, 2L), + ("c", 1L, 2L, 1L, 1.0, 1L, 1L)) + } + + test("typed aggregation: expr, expr, expr, expr, expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkDatasetUnorderly( + ds.groupByKey(_._1) + .agg( + sum("_2").as[Long], + sum($"_2" + 1).as[Long], + count("*").as[Long], + avg("_2").as[Double], + countDistinct("*").as[Long], + max("_2").as[Long], + min("_2").as[Long]), + ("a", 30L, 32L, 2L, 15.0, 2L, 20L, 10L), + ("b", 3L, 5L, 2L, 1.5, 2L, 2L, 1L), + ("c", 1L, 2L, 1L, 1.0, 1L, 1L, 1L)) + } + + test("typed aggregation: expr, expr, expr, expr, expr, expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkDatasetUnorderly( + ds.groupByKey(_._1) + .agg( + sum("_2").as[Long], + sum($"_2" + 1).as[Long], + count("*").as[Long], + avg("_2").as[Double], + countDistinct("*").as[Long], + max("_2").as[Long], + min("_2").as[Long], + mean("_2").as[Double]), + ("a", 30L, 32L, 2L, 15.0, 2L, 20L, 10L, 15.0), + ("b", 3L, 5L, 2L, 1.5, 2L, 2L, 1L, 1.5), + ("c", 1L, 2L, 1L, 1.0, 1L, 1L, 1L, 1.0)) + } + + test("SPARK-24762: Enable top-level Option of Product encoders") { + val data = Seq(Some((1, "a")), Some((2, "b")), None) + val ds = data.toDS() + + checkDataset(ds, data: _*) + + val schema = new StructType().add( + "value", + new StructType() + .add("_1", IntegerType, nullable = false) + .add("_2", StringType, nullable = true), + nullable = true) + + assert(ds.schema == schema) + + val nestedOptData = Seq(Some((Some((1, "a")), 2.0)), Some((Some((2, "b")), 3.0))) + val nestedDs = nestedOptData.toDS() + + checkDataset(nestedDs, nestedOptData: _*) + + val nestedSchema = StructType( + Seq(StructField( + "value", + StructType(Seq( + StructField( + "_1", + StructType(Seq( + StructField("_1", IntegerType, nullable = false), + StructField("_2", StringType, nullable = true)))), + 
StructField("_2", DoubleType, nullable = false))), + nullable = true))) + assert(nestedDs.schema == nestedSchema) + } + + test("SPARK-24762: Resolving Option[Product] field") { + val ds = Seq((1, ("a", 1.0)), (2, ("b", 2.0)), (3, null)) + .toDS() + .as[(Int, Option[(String, Double)])] + checkDataset(ds, (1, Some(("a", 1.0))), (2, Some(("b", 2.0))), (3, None)) + } + + test("SPARK-24762: select Option[Product] field") { + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + val ds1 = ds.select(expr("struct(_2, _2 + 1)").as[Option[(Int, Int)]]) + checkDataset(ds1, Some((1, 2)), Some((2, 3)), Some((3, 4))) + + val ds2 = ds.select(expr("if(_2 > 2, struct(_2, _2 + 1), null)").as[Option[(Int, Int)]]) + checkDataset(ds2, None, None, Some((3, 4))) + } + + test("SPARK-24762: typed agg on Option[Product] type") { + val ds = Seq(Some((1, 2)), Some((2, 3)), Some((1, 3))).toDS() + assert(ds.groupByKey(_.get._1).count().collect() === Seq((1, 2), (2, 1))) + + assert( + ds.groupByKey(x => x).count().collect() === + Seq((Some((1, 2)), 1), (Some((2, 3)), 1), (Some((1, 3)), 1))) + } + + test("SPARK-25942: typed aggregation on primitive type") { + val ds = Seq(1, 2, 3).toDS() + + val agg = ds + .groupByKey(_ >= 2) + .agg(sum("value").as[Long], sum($"value" + 1).as[Long]) + checkDatasetUnorderly(agg, (false, 1L, 2L), (true, 5L, 7L)) + } + + test("SPARK-25942: typed aggregation on product type") { + val ds = Seq((1, 2), (2, 3), (3, 4)).toDS() + val agg = ds.groupByKey(x => x).agg(sum("_1").as[Long], sum($"_2" + 1).as[Long]) + checkDatasetUnorderly(agg, ((1, 2), 1L, 3L), ((2, 3), 2L, 4L), ((3, 4), 3L, 5L)) + } + + test("SPARK-26085: fix key attribute name for atomic type for typed aggregation") { + // TODO(SPARK-43416): Recursively rename the position based tuple to the schema name from the + // server. 
+ val ds = Seq(1, 2, 3).toDS() + assert(ds.groupByKey(x => x).count().schema.head.name == "_1") + + // Enable legacy flag to follow previous Spark behavior + withSQLConf("spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue" -> "true") { + assert(ds.groupByKey(x => x).count().schema.head.name == "_1") + } + } + + test("reduceGroups") { + val ds = Seq("abc", "xyz", "hello").toDS() + checkDatasetUnorderly( + ds.groupByKey(_.length).reduceGroups(_ + _), + (3, "abcxyz"), + (5, "hello")) + } + + test("groupby") { + val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1)) + .toDF("key", "seq", "value") + val grouped = ds.groupBy($"key").as[String, (String, Int, Int)] + val aggregated = grouped + .flatMapSortedGroups($"seq", expr("length(key)"), $"value") { (g, iter) => + Iterator(g, iter.mkString(", ")) + } + + checkDatasetUnorderly( + aggregated, + "a", + "(a,1,10), (a,2,20)", + "b", + "(b,1,2), (b,2,1)", + "c", + "(c,1,1)") + } + + test("groupby - keyAs, keys") { + val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1)) + .toDF("key", "seq", "value") + val grouped = ds.groupBy($"value").as[String, (String, Int, Int)] + val keys = grouped.keyAs[String].keys.sort($"value") + + checkDataset(keys, "1", "2", "10", "20") + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala index b07d1459df5..b5bbee67803 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala @@ -197,4 +197,22 @@ class UserDefinedFunctionE2ETestSuite extends RemoteSparkSession { spark.range(10).repartition(1).foreachPartition(func) assert(sum.get() == 0) // The value is not 45 } + + test("Dataset reduce") { + val session: SparkSession = spark + import session.implicits._ + assert(spark.range(10).map(_ + 1).reduce(_ + _) == 55) + } + + test("Dataset reduce - java") { + val session: SparkSession = spark + import session.implicits._ + assert( + spark + .range(10) + .map(_ + 1) + .reduce(new ReduceFunction[Long] { + override def call(v1: Long, v2: Long): Long = v1 + v2 + }) == 55) + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 28a28994a76..32a44c350d9 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -165,7 +165,6 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.joinWith"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.metadataColumn"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.selectUntyped"), // protected - ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.reduce"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.explode"), // deprecated ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.rdd"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.toJavaRDD"), @@ -190,17 +189,10 @@ 
object CheckConnectJvmClientCompatibility { ), // streaming ProblemFilters.exclude[Problem]( "org.apache.spark.sql.KeyValueGroupedDataset.queryExecution"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.KeyValueGroupedDataset.reduceGroups"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.KeyValueGroupedDataset.agg"), - ProblemFilters.exclude[Problem]( - "org.apache.spark.sql.KeyValueGroupedDataset.aggUntyped" - ), // protected internal - ProblemFilters.exclude[Problem]("org.apache.spark.sql.KeyValueGroupedDataset.count"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.KeyValueGroupedDataset.this"), // RelationalGroupedDataset ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.apply"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.as"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.this"), // SparkSession diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala index 1c3f49f897f..fdbb3edbf84 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala @@ -21,7 +21,7 @@ import java.util.TimeZone import org.scalatest.Assertions -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.catalyst.util.sideBySide abstract class QueryTest extends RemoteSparkSession { @@ -45,6 +45,40 @@ abstract class QueryTest extends RemoteSparkSession { protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = { checkAnswer(df, expectedAnswer.collect()) } + + /** + * Evaluates a dataset to make sure that the result of calling collect matches the given + * expected answer. + */ + protected def checkDataset[T](ds: => Dataset[T], expectedAnswer: T*): Unit = { + val result = ds.collect() + + if (!QueryTest.compare(result.toSeq, expectedAnswer)) { + fail(s""" + |Decoded objects do not match expected objects: + |expected: $expectedAnswer + |actual: ${result.toSeq} + """.stripMargin) + } + } + + /** + * Evaluates a dataset to make sure that the result of calling collect matches the given + * expected answer, after sort. 
+ */ + protected def checkDatasetUnorderly[T: Ordering]( + ds: => Dataset[T], + expectedAnswer: T*): Unit = { + val result = ds.collect() + + if (!QueryTest.compare(result.toSeq.sorted, expectedAnswer.sorted)) { + fail(s""" + |Decoded objects do not match expected objects: + |expected: $expectedAnswer + |actual: ${result.toSeq} + """.stripMargin) + } + } } object QueryTest extends Assertions { diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala index 7cd251b245f..06a6c74f268 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala @@ -95,5 +95,9 @@ private[sql] object UdfUtils extends Serializable { } } + def mapReduceFuncToScalaFunc[T](func: ReduceFunction[T]): (T, T) => T = func.call + def identical[T](): T => T = t => t + + def noOp[V, K](): V => K = _ => null.asInstanceOf[K] } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index b86ed866d6e..8562722a95b 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -40,7 +40,7 @@ import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.Streami import org.apache.spark.connect.proto.WriteStreamOperationStart import org.apache.spark.connect.proto.WriteStreamOperationStart.TriggerCase import org.apache.spark.ml.{functions => MLFunctions} -import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession} +import org.apache.spark.sql.{Column, Dataset, Encoders, RelationalGroupedDataset, SparkSession} import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro} import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier} import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar} @@ -67,7 +67,8 @@ import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartiti import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper -import org.apache.spark.sql.internal.CatalogImpl +import org.apache.spark.sql.expressions.ReduceAggregator +import org.apache.spark.sql.internal.{CatalogImpl, TypedAggUtils} import org.apache.spark.sql.streaming.Trigger import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -526,7 +527,7 @@ class SparkConnectPlanner(val session: SparkSession) { private def transformTypedMapPartitions( fun: proto.CommonInlineUserDefinedFunction, child: LogicalPlan): LogicalPlan = { - val udf = ScalaUdf(fun) + val udf = TypedScalaUdf(fun) val deserialized = DeserializeToObject(udf.inputDeserializer(), udf.inputObjAttr, child) val mapped = MapPartitions( udf.function.asInstanceOf[Iterator[Any] => Iterator[Any]], @@ -562,7 +563,7 @@ class SparkConnectPlanner(val session: 
SparkSession) { private def transformTypedGroupMap( rel: proto.GroupMap, commonUdf: proto.CommonInlineUserDefinedFunction): LogicalPlan = { - val udf = ScalaUdf(commonUdf) + val udf = TypedScalaUdf(commonUdf) val ds = UntypedKeyValueGroupedDataset( rel.getInput, rel.getGroupingExpressionsList, @@ -614,7 +615,7 @@ class SparkConnectPlanner(val session: SparkSession) { private def transformTypedCoGroupMap( rel: proto.CoGroupMap, commonUdf: proto.CommonInlineUserDefinedFunction): LogicalPlan = { - val udf = ScalaUdf(commonUdf) + val udf = TypedScalaUdf(commonUdf) val left = UntypedKeyValueGroupedDataset( rel.getInput, rel.getInputGroupingExpressionsList, @@ -644,57 +645,89 @@ class SparkConnectPlanner(val session: SparkSession) { } /** - * This is the untyped version of [[KeyValueGroupedDataset]]. + * This is the untyped version of [[org.apache.spark.sql.KeyValueGroupedDataset]]. */ private case class UntypedKeyValueGroupedDataset( kEncoder: ExpressionEncoder[_], vEncoder: ExpressionEncoder[_], - valueDeserializer: Expression, analyzed: LogicalPlan, dataAttributes: Seq[Attribute], groupingAttributes: Seq[Attribute], - sortOrder: Seq[SortOrder]) + sortOrder: Seq[SortOrder]) { + val valueDeserializer: Expression = + UnresolvedDeserializer(vEncoder.deserializer, dataAttributes) + } + private object UntypedKeyValueGroupedDataset { def apply( input: proto.Relation, groupingExprs: java.util.List[proto.Expression], sortingExprs: java.util.List[proto.Expression]): UntypedKeyValueGroupedDataset = { - val logicalPlan = transformRelation(input) - assert(groupingExprs.size() == 1) - val groupFunc = groupingExprs.asScala.toSeq - .map(expr => unpackUdf(expr.getCommonInlineUserDefinedFunction)) - .head - - assert(groupFunc.inputEncoders.size == 1) - val vEnc = ExpressionEncoder(groupFunc.inputEncoders.head) - val kEnc = ExpressionEncoder(groupFunc.outputEncoder) - - val withGroupingKey = new AppendColumns( - groupFunc.function.asInstanceOf[Any => Any], - vEnc.clsTag.runtimeClass, - vEnc.schema, - UnresolvedDeserializer(vEnc.deserializer), - kEnc.namedExpressions, - logicalPlan) - - // The input logical plan of KeyValueGroupedDataset need to be executed and analyzed - val analyzed = session.sessionState.executePlan(withGroupingKey).analyzed - val dataAttributes = logicalPlan.output - val groupingAttributes = withGroupingKey.newColumns - val valueDeserializer = UnresolvedDeserializer(vEnc.deserializer, dataAttributes) // Compute sort order val sortExprs = sortingExprs.asScala.toSeq.map(expr => transformExpression(expr)) val sortOrder: Seq[SortOrder] = MapGroups.sortOrder(sortExprs) + apply(transformRelation(input), groupingExprs, sortOrder) + } + + def apply( + logicalPlan: LogicalPlan, + groupingExprs: java.util.List[proto.Expression], + sortOrder: Seq[SortOrder]): UntypedKeyValueGroupedDataset = { + // If created via ds#groupByKey, then there should be only one groupingFunc. 
+ // If created via relationalGroupedDS#as, then we are expecting a dummy groupingFuc + // (for types) + groupingExprs + if (groupingExprs.size() == 1) { + createFromGroupByKeyFunc(logicalPlan, groupingExprs, sortOrder) + } else if (groupingExprs.size() > 1) { + createFromRelationalDataset(logicalPlan, groupingExprs, sortOrder) + } else { + throw InvalidPlanInput( + "The grouping expression cannot be absent for KeyValueGroupedDataset") + } + } + + private def createFromRelationalDataset( + logicalPlan: LogicalPlan, + groupingExprs: java.util.List[proto.Expression], + sortOrder: Seq[SortOrder]): UntypedKeyValueGroupedDataset = { + assert(groupingExprs.size() >= 1) + val dummyFunc = TypedScalaUdf(groupingExprs.get(0)) + val groupExprs = groupingExprs.asScala.toSeq.drop(1).map(expr => transformExpression(expr)) + + val (qe, aliasedGroupings) = + RelationalGroupedDataset.handleGroupingExpression(logicalPlan, session, groupExprs) + + UntypedKeyValueGroupedDataset( + dummyFunc.outEnc, + dummyFunc.inEnc, + qe.analyzed, + logicalPlan.output, + aliasedGroupings, + sortOrder) + } + + private def createFromGroupByKeyFunc( + logicalPlan: LogicalPlan, + groupingExprs: java.util.List[proto.Expression], + sortOrder: Seq[SortOrder]): UntypedKeyValueGroupedDataset = { + assert(groupingExprs.size() == 1) + val groupFunc = TypedScalaUdf(groupingExprs.get(0)) + val vEnc = groupFunc.inEnc + val kEnc = groupFunc.outEnc + + val withGroupingKey = AppendColumns(groupFunc.function, vEnc, kEnc, logicalPlan) + // The input logical plan of KeyValueGroupedDataset need to be executed and analyzed + val analyzed = session.sessionState.executePlan(withGroupingKey).analyzed + UntypedKeyValueGroupedDataset( kEnc, vEnc, - valueDeserializer, analyzed, - dataAttributes, - groupingAttributes, + logicalPlan.output, + withGroupingKey.newColumns, sortOrder) } } @@ -702,7 +735,7 @@ class SparkConnectPlanner(val session: SparkSession) { /** * The UDF used in typed APIs, where the input column is absent. */ - private case class ScalaUdf( + private case class TypedScalaUdf( function: AnyRef, outEnc: ExpressionEncoder[_], outputObjAttr: Attribute, @@ -713,17 +746,26 @@ class SparkConnectPlanner(val session: SparkSession) { UnresolvedDeserializer(inEnc.deserializer, inputAttributes) } } - private object ScalaUdf { - def apply(commonUdf: proto.CommonInlineUserDefinedFunction): ScalaUdf = { + private object TypedScalaUdf { + def apply(expr: proto.Expression): TypedScalaUdf = { + if (expr.hasCommonInlineUserDefinedFunction + && expr.getCommonInlineUserDefinedFunction.hasScalarScalaUdf) { + apply(expr.getCommonInlineUserDefinedFunction) + } else { + throw InvalidPlanInput(s"Expecting a Scala UDF, but get ${expr.getExprTypeCase}") + } + } + + def apply(commonUdf: proto.CommonInlineUserDefinedFunction): TypedScalaUdf = { val udf = unpackUdf(commonUdf) val outEnc = ExpressionEncoder(udf.outputEncoder) // There might be more than one inputs, but we only interested in the first one. // Most typed API takes one UDF input. // For the few that takes more than one inputs, e.g. grouping function mapping UDFs, - // we only interested in the first input which is the key of the grouping function. + // the first input which is the key of the grouping function. 
assert(udf.inputEncoders.nonEmpty) val inEnc = ExpressionEncoder(udf.inputEncoders.head) // single input encoder or key encoder - ScalaUdf(udf.function, outEnc, generateObjAttr(outEnc), inEnc, generateObjAttr(inEnc)) + TypedScalaUdf(udf.function, outEnc, generateObjAttr(outEnc), inEnc, generateObjAttr(inEnc)) } } @@ -1117,27 +1159,31 @@ class SparkConnectPlanner(val session: SparkSession) { assert(rel.hasInput) val baseRel = transformRelation(rel.getInput) val cond = rel.getCondition - cond.getExprTypeCase match { - case proto.Expression.ExprTypeCase.COMMON_INLINE_USER_DEFINED_FUNCTION - if isTypedFilter(cond.getCommonInlineUserDefinedFunction) => - transformTypedFilter(cond.getCommonInlineUserDefinedFunction, baseRel) - case _ => - logical.Filter(condition = transformExpression(cond), child = baseRel) + if (isTypedScalaUdfExpr(cond)) { + transformTypedFilter(cond.getCommonInlineUserDefinedFunction, baseRel) + } else { + logical.Filter(condition = transformExpression(cond), child = baseRel) } } - private def isTypedFilter(udf: proto.CommonInlineUserDefinedFunction): Boolean = { - // It is a scala udf && the udf argument is an unresolved start. - // This means the udf is a typed filter to filter on all inputs - udf.getFunctionCase == proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF && - udf.getArgumentsCount == 1 && - udf.getArguments(0).getExprTypeCase == proto.Expression.ExprTypeCase.UNRESOLVED_STAR + private def isTypedScalaUdfExpr(expr: proto.Expression): Boolean = { + expr.getExprTypeCase match { + case proto.Expression.ExprTypeCase.COMMON_INLINE_USER_DEFINED_FUNCTION => + val udf = expr.getCommonInlineUserDefinedFunction + // A typed scala udf is a scala udf && the udf argument is an unresolved start. + udf.getFunctionCase == + proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF && + udf.getArgumentsCount == 1 && + udf.getArguments(0).getExprTypeCase == proto.Expression.ExprTypeCase.UNRESOLVED_STAR + case _ => + false + } } private def transformTypedFilter( fun: proto.CommonInlineUserDefinedFunction, child: LogicalPlan): TypedFilter = { - val udf = ScalaUdf(fun) + val udf = TypedScalaUdf(fun) TypedFilter(udf.function, child)(udf.inEnc) } @@ -1853,13 +1899,39 @@ class SparkConnectPlanner(val session: SparkSession) { } private def transformAggregate(rel: proto.Aggregate): LogicalPlan = { + rel.getGroupType match { + case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY + // This relies on the assumption that a KVGDS always requires the head to be a Typed UDF. + // This is the case for datasets created via groupByKey, + // and also via RelationalGroupedDS#as, as the first is a dummy UDF currently. 
+ if rel.getGroupingExpressionsList.size() >= 1 && + isTypedScalaUdfExpr(rel.getGroupingExpressionsList.get(0)) => + transformKeyValueGroupedAggregate(rel) + case _ => + transformRelationalGroupedAggregate(rel) + } + } + + private def transformKeyValueGroupedAggregate(rel: proto.Aggregate): LogicalPlan = { + val input = transformRelation(rel.getInput) + val ds = UntypedKeyValueGroupedDataset(input, rel.getGroupingExpressionsList, Seq.empty) + + val keyColumn = TypedAggUtils.aggKeyColumn(ds.kEncoder, ds.groupingAttributes) + val namedColumns = rel.getAggregateExpressionsList.asScala.toSeq + .map(expr => transformExpressionWithTypedReduceExpression(expr, input)) + .map(toNamedExpression) + logical.Aggregate(ds.groupingAttributes, keyColumn +: namedColumns, ds.analyzed) + } + + private def transformRelationalGroupedAggregate(rel: proto.Aggregate): LogicalPlan = { if (!rel.hasInput) { throw InvalidPlanInput("Aggregate needs a plan input") } val input = transformRelation(rel.getInput) val groupingExprs = rel.getGroupingExpressionsList.asScala.toSeq.map(transformExpression) - val aggExprs = rel.getAggregateExpressionsList.asScala.toSeq.map(transformExpression) + val aggExprs = rel.getAggregateExpressionsList.asScala.toSeq + .map(expr => transformExpressionWithTypedReduceExpression(expr, input)) val aliasedAgg = (groupingExprs ++ aggExprs).map(toNamedExpression) rel.getGroupType match { @@ -1917,6 +1989,37 @@ class SparkConnectPlanner(val session: SparkSession) { } } + private def transformTypedReduceExpression( + fun: proto.Expression.UnresolvedFunction, + dataAttributes: Seq[Attribute]): Expression = { + assert(fun.getFunctionName == "reduce") + if (fun.getArgumentsCount != 1) { + throw InvalidPlanInput("reduce requires single child expression") + } + val udf = fun.getArgumentsList.asScala.toSeq.map(transformExpression) match { + case Seq(f: ScalaUDF) => + f + case other => + throw InvalidPlanInput(s"reduce should carry a scalar scala udf, but got $other") + } + assert(udf.outputEncoder.isDefined) + val tEncoder = udf.outputEncoder.get // (T, T) => T + val reduce = ReduceAggregator(udf.function)(tEncoder).toColumn.expr + TypedAggUtils.withInputType(reduce, tEncoder, dataAttributes) + } + + private def transformExpressionWithTypedReduceExpression( + expr: proto.Expression, + plan: LogicalPlan): Expression = { + expr.getExprTypeCase match { + case proto.Expression.ExprTypeCase.UNRESOLVED_FUNCTION + if expr.getUnresolvedFunction.getFunctionName == "reduce" => + // The reduce func needs the input data attribute, thus handle it specially here + transformTypedReduceExpression(expr.getUnresolvedFunction, plan.output) + case _ => transformExpression(expr) + } + } + def process( command: proto.Command, userId: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 35640c4aec8..54c0b84ff52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -337,6 +337,22 @@ object AppendColumns { encoderFor[U].namedExpressions, child) } + + private[sql] def apply( + func: AnyRef, + inEncoder: ExpressionEncoder[_], + outEncoder: ExpressionEncoder[_], + child: LogicalPlan, + inputAttributes: Seq[Attribute] = Nil): AppendColumns = { + new AppendColumns( + func.asInstanceOf[Any => Any], + inEncoder.clsTag.runtimeClass, + inEncoder.schema, + 
UnresolvedDeserializer(inEncoder.deserializer, inputAttributes), + outEncoder.namedExpressions, + child + ) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 3c9f3e58cec..8d4cb50dfa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.internal.TypedAggUtils import org.apache.spark.sql.types._ private[sql] object Column { @@ -81,17 +82,7 @@ class TypedColumn[-T, U]( private[sql] def withInputType( inputEncoder: ExpressionEncoder[_], inputAttributes: Seq[Attribute]): TypedColumn[T, U] = { - val unresolvedDeserializer = UnresolvedDeserializer(inputEncoder.deserializer, inputAttributes) - - // This only inserts inputs into typed aggregate expressions. For untyped aggregate expressions, - // the resolving is handled in the analyzer directly. - val newExpr = expr transform { - case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => - ta.withInputInfo( - deser = unresolvedDeserializer, - cls = inputEncoder.clsTag.runtimeClass, - schema = inputEncoder.schema) - } + val newExpr = TypedAggUtils.withInputType(expr, inputEncoder, inputAttributes) new TypedColumn[T, U](newExpr, encoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index a8e78043c28..4c2ccb27eab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -21,11 +21,11 @@ import scala.collection.JavaConverters._ import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, CreateStruct, Expression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.TypedAggUtils import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode} /** @@ -673,16 +673,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( val encoders = columns.map(_.encoder) val namedColumns = columns.map(_.withInputType(vExprEnc, dataAttributes).named) - val keyColumn = if (!kExprEnc.isSerializedAsStructForTopLevel) { - assert(groupingAttributes.length == 1) - if (SQLConf.get.nameNonStructGroupingKeyAsValue) { - groupingAttributes.head - } else { - Alias(groupingAttributes.head, "key")() - } - } else { - Alias(CreateStruct(groupingAttributes), "key")() - } + val keyColumn = TypedAggUtils.aggKeyColumn(kExprEnc, groupingAttributes) val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) val execution = new QueryExecution(sparkSession, aggregate) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 31c303921f3..29138b5bf58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} +import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{NumericType, StructType} @@ -53,6 +54,7 @@ class RelationalGroupedDataset protected[sql]( private[sql] val df: DataFrame, private[sql] val groupingExprs: Seq[Expression], groupType: RelationalGroupedDataset.GroupType) { + import RelationalGroupedDataset._ private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { val aggregates = if (df.sparkSession.sessionState.conf.dataFrameRetainGroupColumns) { @@ -85,14 +87,6 @@ class RelationalGroupedDataset protected[sql]( } } - private[this] def alias(expr: Expression): NamedExpression = expr match { - case expr: NamedExpression => expr - case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => - UnresolvedAlias(a, Some(Column.generateAlias)) - case u: UnresolvedFunction => UnresolvedAlias(expr, None) - case expr: Expression => Alias(expr, toPrettySQL(expr))() - } - private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction) : DataFrame = { @@ -143,25 +137,15 @@ class RelationalGroupedDataset protected[sql]( val keyEncoder = encoderFor[K] val valueEncoder = encoderFor[T] - // Resolves grouping expressions. - val dummyPlan = Project(groupingExprs.map(alias), LocalRelation(df.logicalPlan.output)) - val analyzedPlan = df.sparkSession.sessionState.analyzer.execute(dummyPlan) - .asInstanceOf[Project] - df.sparkSession.sessionState.analyzer.checkAnalysis(analyzedPlan) - val aliasedGroupings = analyzedPlan.projectList - - // Adds the grouping expressions that are not in base DataFrame into outputs. - val addedCols = aliasedGroupings.filter(g => !df.logicalPlan.outputSet.contains(g.toAttribute)) - val qe = Dataset.ofRows( - df.sparkSession, - Project(df.logicalPlan.output ++ addedCols, df.logicalPlan)).queryExecution + val (qe, groupingAttributes) = + handleGroupingExpression(df.logicalPlan, df.sparkSession, groupingExprs) new KeyValueGroupedDataset( keyEncoder, valueEncoder, qe, df.logicalPlan.output, - aliasedGroupings.map(_.toAttribute)) + groupingAttributes) } /** @@ -700,6 +684,33 @@ private[sql] object RelationalGroupedDataset { new RelationalGroupedDataset(df, groupingExprs, groupType: GroupType) } + private[sql] def handleGroupingExpression( + logicalPlan: LogicalPlan, + sparkSession: SparkSession, + groupingExprs: Seq[Expression]): (QueryExecution, Seq[Attribute]) = { + // Resolves grouping expressions. + val dummyPlan = Project(groupingExprs.map(alias), LocalRelation(logicalPlan.output)) + val analyzedPlan = sparkSession.sessionState.analyzer.execute(dummyPlan) + .asInstanceOf[Project] + sparkSession.sessionState.analyzer.checkAnalysis(analyzedPlan) + val aliasedGroupings = analyzedPlan.projectList + + // Adds the grouping expressions that are not in base DataFrame into outputs. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
index e266ae55cc4..41306cd0a99 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
@@ -66,3 +66,9 @@ private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T)
     reduction._2
   }
 }
+
+private[sql] object ReduceAggregator {
+  def apply[T: Encoder](f: AnyRef): ReduceAggregator[T] = {
+    new ReduceAggregator(f.asInstanceOf[(T, T) => T])
+  }
+}
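And a minimal sketch (not part of the commit) of the reduction that ReduceAggregator wraps as an aggregator, exercised through the public Dataset API; the session setup and data are illustrative assumptions.

import org.apache.spark.sql.SparkSession

// Hypothetical example, not from this patch: reduce a whole Dataset and each group
// with the same commutative, associative binary function.
object ReduceSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("reduce-sketch").getOrCreate()
    import spark.implicits._

    val ds = Seq(1, 2, 3, 4).toDS()

    val total = ds.reduce(_ + _)                          // 10
    val perKey = ds.groupByKey(_ % 2).reduceGroups(_ + _) // (0, 6), (1, 4)

    println(s"total=$total perKey=${perKey.collect().toSeq}")
    spark.stop()
  }
}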
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala
new file mode 100644
index 00000000000..68bda47cf8c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
+
+private[sql] object TypedAggUtils {
+
+  def aggKeyColumn[A](
+      encoder: ExpressionEncoder[A],
+      groupingAttributes: Seq[Attribute]): NamedExpression = {
+    if (!encoder.isSerializedAsStructForTopLevel) {
+      assert(groupingAttributes.length == 1)
+      if (SQLConf.get.nameNonStructGroupingKeyAsValue) {
+        groupingAttributes.head
+      } else {
+        Alias(groupingAttributes.head, "key")()
+      }
+    } else {
+      Alias(CreateStruct(groupingAttributes), "key")()
+    }
+  }
+
+  /**
+   * Insert inputs into typed aggregate expressions. For untyped aggregate expressions,
+   * the resolving is handled in the analyzer directly.
+   */
+  private[sql] def withInputType(
+      expr: Expression,
+      inputEncoder: ExpressionEncoder[_],
+      inputAttributes: Seq[Attribute]): Expression = {
+    val unresolvedDeserializer = UnresolvedDeserializer(inputEncoder.deserializer, inputAttributes)
+
+    expr transform {
+      case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty =>
+        ta.withInputInfo(
+          deser = unresolvedDeserializer,
+          cls = inputEncoder.clsTag.runtimeClass,
+          schema = inputEncoder.schema
+        )
+    }
+  }
+}

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org