spark git commit: [SPARK-24208][SQL] Fix attribute deduplication for FlatMapGroupsInPandas

2018-07-11 Thread lixiao
Repository: spark
Updated Branches:
  refs/heads/branch-2.3 86457a16d -> 32429256f


[SPARK-24208][SQL] Fix attribute deduplication for FlatMapGroupsInPandas

A self-join on a dataset which contains a `FlatMapGroupsInPandas` fails because 
of duplicate attributes. This happens because we are not dealing with this 
specific case in our `dedupAttr` rules.

The PR fixes the issue by adding handling for this specific case.

added UT + manual tests

Author: Marco Gaido 
Author: Marco Gaido 

Closes #21737 from mgaido91/SPARK-24208.

(cherry picked from commit ebf4bfb966389342bfd9bdb8e3b612828c18730c)
Signed-off-by: Xiao Li 


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/32429256
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/32429256
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/32429256

Branch: refs/heads/branch-2.3
Commit: 32429256f3e659c648462e5b2740747645740c97
Parents: 86457a1
Author: Marco Gaido 
Authored: Wed Jul 11 09:29:19 2018 -0700
Committer: Xiao Li 
Committed: Wed Jul 11 09:35:44 2018 -0700

--
 python/pyspark/sql/tests.py | 16 
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  4 
 .../org/apache/spark/sql/GroupedDatasetSuite.scala  | 12 
 3 files changed, 32 insertions(+)
--


http://git-wip-us.apache.org/repos/asf/spark/blob/32429256/python/pyspark/sql/tests.py
--
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index aa7d8eb..6bfb329 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -4691,6 +4691,22 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
 result = df.groupby('time').apply(foo_udf).sort('time')
 self.assertPandasEqual(df.toPandas(), result.toPandas())
 
+def test_self_join_with_pandas(self):
+import pyspark.sql.functions as F
+
+@F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
+def dummy_pandas_udf(df):
+return df[['key', 'col']]
+
+df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, 
col='B'),
+ Row(key=2, col='C')])
+dfWithPandas = df.groupBy('key').apply(dummy_pandas_udf)
+
+# this was throwing an AnalysisException before SPARK-24208
+res = dfWithPandas.alias('temp0').join(dfWithPandas.alias('temp1'),
+   F.col('temp0.key') == 
F.col('temp1.key'))
+self.assertEquals(res.count(), 5)
+
 
 if __name__ == "__main__":
 from pyspark.sql.tests import *

http://git-wip-us.apache.org/repos/asf/spark/blob/32429256/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
--
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8597d83..a584cb8 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -723,6 +723,10 @@ class Analyzer(
 if 
findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
   (oldVersion, oldVersion.copy(aggregateExpressions = 
newAliases(aggregateExpressions)))
 
+case oldVersion @ FlatMapGroupsInPandas(_, _, output, _)
+if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty 
=>
+  (oldVersion, oldVersion.copy(output = output.map(_.newInstance(
+
 case oldVersion: Generate
 if 
oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
   val newOutput = oldVersion.generatorOutput.map(_.newInstance())

http://git-wip-us.apache.org/repos/asf/spark/blob/32429256/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
--
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
index 218a1b7..9699fad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
@@ -93,4 +93,16 @@ class GroupedDatasetSuite extends QueryTest with 
SharedSQLContext {
 }
 datasetWithUDF.unpersist(true)
   }
+
+  test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") {
+val df = datasetWithUDF.groupBy("s").flatMapGroupsInPandas(PythonUDF(
+  "pyUDF",
+  null,
+  

spark git commit: [SPARK-24208][SQL] Fix attribute deduplication for FlatMapGroupsInPandas

2018-07-11 Thread lixiao
Repository: spark
Updated Branches:
  refs/heads/master 592cc8458 -> ebf4bfb96


[SPARK-24208][SQL] Fix attribute deduplication for FlatMapGroupsInPandas

## What changes were proposed in this pull request?

A self-join on a dataset which contains a `FlatMapGroupsInPandas` fails because 
of duplicate attributes. This happens because we are not dealing with this 
specific case in our `dedupAttr` rules.

The PR fixes the issue by adding handling for this specific case.

## How was this patch tested?

added UT + manual tests

Author: Marco Gaido 
Author: Marco Gaido 

Closes #21737 from mgaido91/SPARK-24208.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ebf4bfb9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ebf4bfb9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ebf4bfb9

Branch: refs/heads/master
Commit: ebf4bfb966389342bfd9bdb8e3b612828c18730c
Parents: 592cc84
Author: Marco Gaido 
Authored: Wed Jul 11 09:29:19 2018 -0700
Committer: Xiao Li 
Committed: Wed Jul 11 09:29:19 2018 -0700

--
 python/pyspark/sql/tests.py | 16 
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  4 
 .../org/apache/spark/sql/GroupedDatasetSuite.scala  | 12 
 3 files changed, 32 insertions(+)
--


http://git-wip-us.apache.org/repos/asf/spark/blob/ebf4bfb9/python/pyspark/sql/tests.py
--
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 8d73806..4404dbe 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -5925,6 +5925,22 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
 'mixture.*aggregate function.*group aggregate pandas UDF'):
 df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
 
+def test_self_join_with_pandas(self):
+import pyspark.sql.functions as F
+
+@F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
+def dummy_pandas_udf(df):
+return df[['key', 'col']]
+
+df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, 
col='B'),
+ Row(key=2, col='C')])
+dfWithPandas = df.groupBy('key').apply(dummy_pandas_udf)
+
+# this was throwing an AnalysisException before SPARK-24208
+res = dfWithPandas.alias('temp0').join(dfWithPandas.alias('temp1'),
+   F.col('temp0.key') == 
F.col('temp1.key'))
+self.assertEquals(res.count(), 5)
+
 
 @unittest.skipIf(
 not _have_pandas or not _have_pyarrow,

http://git-wip-us.apache.org/repos/asf/spark/blob/ebf4bfb9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
--
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e187133..c078efd 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -738,6 +738,10 @@ class Analyzer(
 if 
findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
   (oldVersion, oldVersion.copy(aggregateExpressions = 
newAliases(aggregateExpressions)))
 
+case oldVersion @ FlatMapGroupsInPandas(_, _, output, _)
+if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty 
=>
+  (oldVersion, oldVersion.copy(output = output.map(_.newInstance(
+
 case oldVersion: Generate
 if 
oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
   val newOutput = oldVersion.generatorOutput.map(_.newInstance())

http://git-wip-us.apache.org/repos/asf/spark/blob/ebf4bfb9/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
--
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
index 147c0b6..bd54ea4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
@@ -93,4 +93,16 @@ class GroupedDatasetSuite extends QueryTest with 
SharedSQLContext {
 }
 datasetWithUDF.unpersist(true)
   }
+
+  test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") {
+val df = datasetWithUDF.groupBy("s").flatMapGroupsInPandas(PythonUDF(
+  "pyUDF",
+  null,
+