This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1a1d3034b1d [SPARK-45930][SQL] Support non-deterministic UDFs in
MapInPandas/MapInArrow
1a1d3034b1d is described below
commit 1a1d3034b1d7d3c457ef0b1b5693698c1c5e77d8
Author: allisonwang-db <[email protected]>
AuthorDate: Thu Nov 16 11:50:43 2023 +0900
[SPARK-45930][SQL] Support non-deterministic UDFs in MapInPandas/MapInArrow
### What changes were proposed in this pull request?
This PR adds support for non-deterministic UDFs in MapInPandas and MapInArrow.
### Why are the changes needed?
Currently, MapInPandas and MapInArrow do not support non-deterministic
UDFs. The analyzer will fail with this error:
`org.apache.spark.sql.AnalysisException:
[INVALID_NON_DETERMINISTIC_EXPRESSIONS] The operator expects a deterministic
expression, but the actual expression is "pyUDF()"`.
This is needed for https://github.com/apache/spark/pull/43791.
### Does this PR introduce _any_ user-facing change?
No. Users cannot directly create a non-deterministic UDF in PySpark to be
used in MapInPandas/MapInArrow.
### How was this patch tested?
New unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #43810 from allisonwang-db/spark-45930-map-in-pandas-non-det.
Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../sql/catalyst/analysis/CheckAnalysis.scala | 2 ++
.../sql/catalyst/analysis/AnalysisSuite.scala | 32 ++++++++++++++++++++++
2 files changed, 34 insertions(+)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index d41345f38c2..176a45a6f8e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -746,6 +746,8 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
!o.isInstanceOf[Expand] &&
!o.isInstanceOf[Generate] &&
!o.isInstanceOf[CreateVariable] &&
+ !o.isInstanceOf[MapInPandas] &&
+ !o.isInstanceOf[PythonMapInArrow] &&
// Lateral join is checked in checkSubqueryExpression.
!o.isInstanceOf[LateralJoin] =>
// The rule above is used to check Aggregate operator.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 8e514e245cb..441b5fb6ca6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -711,6 +711,38 @@ class AnalysisSuite extends AnalysisTest with Matchers {
Project(Seq(UnresolvedAttribute("temp0.a"),
UnresolvedAttribute("temp1.a")), join))
}
+ test("SPARK-45930: MapInPandas with non-deterministic UDF") {
+ val pythonUdf = PythonUDF("pyUDF", null,
+ StructType(Seq(StructField("a", LongType))),
+ Seq.empty,
+ PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
+ false)
+ val output =
DataTypeUtils.toAttributes(pythonUdf.dataType.asInstanceOf[StructType])
+ val project = Project(Seq(UnresolvedAttribute("a")), testRelation)
+ val mapInPandas = MapInPandas(
+ pythonUdf,
+ output,
+ project,
+ false)
+ assertAnalysisSuccess(mapInPandas)
+ }
+
+ test("SPARK-45930: MapInArrow with non-deterministic UDF") {
+ val pythonUdf = PythonUDF("pyUDF", null,
+ StructType(Seq(StructField("a", LongType))),
+ Seq.empty,
+ PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
+ false)
+ val output =
DataTypeUtils.toAttributes(pythonUdf.dataType.asInstanceOf[StructType])
+ val project = Project(Seq(UnresolvedAttribute("a")), testRelation)
+ val mapInArrow = PythonMapInArrow(
+ pythonUdf,
+ output,
+ project,
+ false)
+ assertAnalysisSuccess(mapInArrow)
+ }
+
test("SPARK-34741: Avoid ambiguous reference in MergeIntoTable") {
val cond = $"a" > 1
assertAnalysisErrorClass(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]