Repository: spark
Updated Branches:
  refs/heads/branch-2.3 86457a16d -> 32429256f


[SPARK-24208][SQL] Fix attribute deduplication for FlatMapGroupsInPandas

A self-join on a dataset which contains a `FlatMapGroupsInPandas` fails because 
of duplicate attributes. This happens because we are not dealing with this 
specific case in our `dedupAttr` rules.

The PR fixes the issue by adding handling for this specific case.

Added a unit test and verified the change with manual tests.

Author: Marco Gaido <marcogaid...@gmail.com>
Author: Marco Gaido <mga...@hortonworks.com>

Closes #21737 from mgaido91/SPARK-24208.

(cherry picked from commit ebf4bfb966389342bfd9bdb8e3b612828c18730c)
Signed-off-by: Xiao Li <gatorsm...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/32429256
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/32429256
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/32429256

Branch: refs/heads/branch-2.3
Commit: 32429256f3e659c648462e5b2740747645740c97
Parents: 86457a1
Author: Marco Gaido <marcogaid...@gmail.com>
Authored: Wed Jul 11 09:29:19 2018 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Wed Jul 11 09:35:44 2018 -0700

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                         | 16 ++++++++++++++++
 .../spark/sql/catalyst/analysis/Analyzer.scala      |  4 ++++
 .../org/apache/spark/sql/GroupedDatasetSuite.scala  | 12 ++++++++++++
 3 files changed, 32 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/32429256/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index aa7d8eb..6bfb329 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -4691,6 +4691,22 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
         result = df.groupby('time').apply(foo_udf).sort('time')
         self.assertPandasEqual(df.toPandas(), result.toPandas())
 
+    def test_self_join_with_pandas(self):
+        import pyspark.sql.functions as F
+
+        @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
+        def dummy_pandas_udf(df):
+            return df[['key', 'col']]
+
+        df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, 
col='B'),
+                                         Row(key=2, col='C')])
+        dfWithPandas = df.groupBy('key').apply(dummy_pandas_udf)
+
+        # this was throwing an AnalysisException before SPARK-24208
+        res = dfWithPandas.alias('temp0').join(dfWithPandas.alias('temp1'),
+                                               F.col('temp0.key') == 
F.col('temp1.key'))
+        self.assertEquals(res.count(), 5)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests import *

http://git-wip-us.apache.org/repos/asf/spark/blob/32429256/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8597d83..a584cb8 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -723,6 +723,10 @@ class Analyzer(
             if 
findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
           (oldVersion, oldVersion.copy(aggregateExpressions = 
newAliases(aggregateExpressions)))
 
+        case oldVersion @ FlatMapGroupsInPandas(_, _, output, _)
+            if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty 
=>
+          (oldVersion, oldVersion.copy(output = output.map(_.newInstance())))
+
         case oldVersion: Generate
             if 
oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
           val newOutput = oldVersion.generatorOutput.map(_.newInstance())

http://git-wip-us.apache.org/repos/asf/spark/blob/32429256/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
index 218a1b7..9699fad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
@@ -93,4 +93,16 @@ class GroupedDatasetSuite extends QueryTest with 
SharedSQLContext {
     }
     datasetWithUDF.unpersist(true)
   }
+
+  test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") {
+    val df = datasetWithUDF.groupBy("s").flatMapGroupsInPandas(PythonUDF(
+      "pyUDF",
+      null,
+      StructType(Seq(StructField("s", LongType))),
+      Seq.empty,
+      PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+      true))
+    val df1 = df.alias("temp0").join(df.alias("temp1"), $"temp0.s" === 
$"temp1.s")
+    df1.queryExecution.assertAnalyzed()
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to