This is an automated email from the ASF dual-hosted git repository. hvanhovell pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new 0843b7741fa [SPARK-44311][CONNECT][SQL] Improved support for UDFs on value classes 0843b7741fa is described below commit 0843b7741fa959173fcc66067eedda9be501192c Author: Emil Ejbyfeldt <eejbyfe...@liveintent.com> AuthorDate: Tue Aug 1 10:50:04 2023 -0400 [SPARK-44311][CONNECT][SQL] Improved support for UDFs on value classes ### What changes were proposed in this pull request? This PR fixes using UDFs on value classes when the value class is serialized as its underlying type. Previously it would only work if one either defined a UDF taking the underlying type, or in cases where the derived schema does not "unbox" the value to its underlying type. Before this change, the following code: ``` final case class ValueClass(a: Int) extends AnyVal final case class Wrapper(v: ValueClass) val f = udf((a: ValueClass) => a.a > 0) spark.createDataset(Seq(Wrapper(ValueClass(1)))).filter(f(col("v"))).show() ``` would fail with ``` java.lang.ClassCastException: class org.apache.spark.sql.types.IntegerType$ cannot be cast to class org.apache.spark.sql.types.StructType (org.apache.spark.sql.types.IntegerType$ and org.apache.spark.sql.types.StructType are in unnamed module of loader 'app') at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveEncodersInUDF$$anonfun$apply$42$$anonfun$applyOrElse$218.$anonfun$applyOrElse$220(Analyzer.scala:3241) at scala.Option.map(Option.scala:242) at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveEncodersInUDF$$anonfun$apply$42$$anonfun$applyOrElse$218.$anonfun$applyOrElse$219(Analyzer.scala:3239) at scala.collection.immutable.List.map(List.scala:246) at scala.collection.immutable.List.map(List.scala:79) at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveEncodersInUDF$$anonfun$apply$42$$anonfun$applyOrElse$218.applyOrElse(Analyzer.scala:3237) at 
org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveEncodersInUDF$$anonfun$apply$42$$anonfun$applyOrElse$218.applyOrElse(Analyzer.scala:3234) at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUpWithPruning$2(TreeNode.scala:566) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:104) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUpWithPruning(TreeNode.scala:566) ``` ### Why are the changes needed? This is something that, as a user, I would expect to just work. ### Does this PR introduce _any_ user-facing change? Yes, it fixes using a UDF on a value class that is serialized as its underlying type. ### How was this patch tested? Existing tests and new test cases in DatasetSuite.scala Closes #41876 from eejbyfeldt/SPARK-44311. Authored-by: Emil Ejbyfeldt <eejbyfe...@liveintent.com> Signed-off-by: Herman van Hovell <her...@databricks.com> (cherry picked from commit 821026bc730ce87e6e97d304c7673bfcb23fd03a) Signed-off-by: Herman van Hovell <her...@databricks.com> --- .../spark/sql/catalyst/analysis/Analyzer.scala | 7 ++++++- .../spark/sql/catalyst/expressions/ScalaUDF.scala | 4 +++- .../scala/org/apache/spark/sql/DatasetSuite.scala | 24 ++++++++++++++++++++++ 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 30c6e4b4bc0..7f2471c9e19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3245,7 +3245,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor val dataType = udf.children(i).dataType encOpt.map { enc => val attrs = if (enc.isSerializedAsStructForTopLevel) { - DataTypeUtils.toAttributes(dataType.asInstanceOf[StructType]) + // Value class that has been replaced with its 
underlying type + if (enc.schema.fields.size == 1 && enc.schema.fields.head.dataType == dataType) { + DataTypeUtils.toAttributes(enc.schema.asInstanceOf[StructType]) + } else { + DataTypeUtils.toAttributes(dataType.asInstanceOf[StructType]) + } } else { // the field name doesn't matter here, so we use // a simple literal to avoid any overhead diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 40274a83340..910960bf84b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -162,7 +162,9 @@ case class ScalaUDF( if (useEncoder) { val enc = inputEncoders(i).get val fromRow = enc.createDeserializer() - val converter = if (enc.isSerializedAsStructForTopLevel) { + val unwrappedValueClass = enc.isSerializedAsStruct && + enc.schema.fields.size == 1 && enc.schema.fields.head.dataType == dataType + val converter = if (enc.isSerializedAsStructForTopLevel && !unwrappedValueClass) { row: Any => fromRow(row.asInstanceOf[InternalRow]) } else { val inputRow = new GenericInternalRow(1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index a021b049cf0..c967540541a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -2514,6 +2514,27 @@ class DatasetSuite extends QueryTest } } } + + test("SPARK-44311: UDF on value class taking underlying type (backwards compatability)") { + val f = udf((v: Int) => v > 1) + val ds = Seq(ValueClassContainer(ValueClass(1)), ValueClassContainer(ValueClass(2))).toDS() + + checkDataset(ds.filter(f(col("v"))), ValueClassContainer(ValueClass(2))) + } + + test("SPARK-44311: UDF on value class field in 
product") { + val f = udf((v: ValueClass) => v.i > 1) + val ds = Seq(ValueClassContainer(ValueClass(1)), ValueClassContainer(ValueClass(2))).toDS() + + checkDataset(ds.filter(f(col("v"))), ValueClassContainer(ValueClass(2))) + } + + test("SPARK-44311: UDF on value class this is stored as a struct") { + val f = udf((v: ValueClass) => v.i > 1) + val ds = Seq(Tuple1(ValueClass(1)), Tuple1(ValueClass(2))).toDS() + + checkDataset(ds.filter(f(col("_1"))), Tuple1(ValueClass(2))) + } } class DatasetLargeResultCollectingSuite extends QueryTest @@ -2545,6 +2566,9 @@ class DatasetLargeResultCollectingSuite extends QueryTest } } +case class ValueClass(i: Int) extends AnyVal +case class ValueClassContainer(v: ValueClass) + case class Bar(a: Int) object AssertExecutionId { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org