Repository: spark
Updated Branches:
  refs/heads/master 592cc8458 -> ebf4bfb96


[SPARK-24208][SQL] Fix attribute deduplication for FlatMapGroupsInPandas

## What changes were proposed in this pull request?

A self-join on a dataset which contains a `FlatMapGroupsInPandas` fails because 
of duplicate attributes. This happens because we are not dealing with this 
specific case in our `dedupAttr` rules.

This PR fixes the issue by adding handling for this specific case.

## How was this patch tested?

Added unit tests (Python and Scala) and ran manual tests.

Author: Marco Gaido <marcogaid...@gmail.com>
Author: Marco Gaido <mga...@hortonworks.com>

Closes #21737 from mgaido91/SPARK-24208.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ebf4bfb9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ebf4bfb9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ebf4bfb9

Branch: refs/heads/master
Commit: ebf4bfb966389342bfd9bdb8e3b612828c18730c
Parents: 592cc84
Author: Marco Gaido <marcogaid...@gmail.com>
Authored: Wed Jul 11 09:29:19 2018 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Wed Jul 11 09:29:19 2018 -0700

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                         | 16 ++++++++++++++++
 .../spark/sql/catalyst/analysis/Analyzer.scala      |  4 ++++
 .../org/apache/spark/sql/GroupedDatasetSuite.scala  | 12 ++++++++++++
 3 files changed, 32 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ebf4bfb9/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 8d73806..4404dbe 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -5925,6 +5925,22 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
                     'mixture.*aggregate function.*group aggregate pandas UDF'):
                 df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
 
+    def test_self_join_with_pandas(self):
+        import pyspark.sql.functions as F
+
+        @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
+        def dummy_pandas_udf(df):
+            return df[['key', 'col']]
+
+        df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, 
col='B'),
+                                         Row(key=2, col='C')])
+        dfWithPandas = df.groupBy('key').apply(dummy_pandas_udf)
+
+        # this was throwing an AnalysisException before SPARK-24208
+        res = dfWithPandas.alias('temp0').join(dfWithPandas.alias('temp1'),
+                                               F.col('temp0.key') == 
F.col('temp1.key'))
+        self.assertEquals(res.count(), 5)
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,

http://git-wip-us.apache.org/repos/asf/spark/blob/ebf4bfb9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e187133..c078efd 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -738,6 +738,10 @@ class Analyzer(
             if 
findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
           (oldVersion, oldVersion.copy(aggregateExpressions = 
newAliases(aggregateExpressions)))
 
+        case oldVersion @ FlatMapGroupsInPandas(_, _, output, _)
+            if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty 
=>
+          (oldVersion, oldVersion.copy(output = output.map(_.newInstance())))
+
         case oldVersion: Generate
             if 
oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
           val newOutput = oldVersion.generatorOutput.map(_.newInstance())

http://git-wip-us.apache.org/repos/asf/spark/blob/ebf4bfb9/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
index 147c0b6..bd54ea4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
@@ -93,4 +93,16 @@ class GroupedDatasetSuite extends QueryTest with 
SharedSQLContext {
     }
     datasetWithUDF.unpersist(true)
   }
+
+  test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") {
+    val df = datasetWithUDF.groupBy("s").flatMapGroupsInPandas(PythonUDF(
+      "pyUDF",
+      null,
+      StructType(Seq(StructField("s", LongType))),
+      Seq.empty,
+      PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+      true))
+    val df1 = df.alias("temp0").join(df.alias("temp1"), $"temp0.s" === 
$"temp1.s")
+    df1.queryExecution.assertAnalyzed()
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to