This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d9dd9944bf4 [SPARK-40601][PYTHON] Assert identical key size when 
cogrouping groups
d9dd9944bf4 is described below

commit d9dd9944bf4c3adba4bcb458304793376f083000
Author: Enrico Minack <git...@enrico.minack.dev>
AuthorDate: Thu Oct 13 15:20:31 2022 +0900

    [SPARK-40601][PYTHON] Assert identical key size when cogrouping groups
    
    Cogrouping two grouped DataFrames in PySpark that have a different number of group keys raises an error that is not very descriptive:
    
    ```python
    left.groupby("id", "k")
        .cogroup(right.groupby("id"))
    ```
    
    ```
    py4j.protocol.Py4JJavaError: An error occurred while calling o726.collectToPython.
    : java.lang.IndexOutOfBoundsException: 1
            at scala.collection.mutable.ResizableArray.apply(ResizableArray.scala:46)
            at scala.collection.mutable.ResizableArray.apply$(ResizableArray.scala:45)
            at scala.collection.mutable.ArrayBuffer.apply(ArrayBuffer.scala:49)
            at org.apache.spark.sql.catalyst.plans.physical.HashShuffleSpec.$anonfun$createPartitioning$5(partitioning.scala:650)
    ...
    org.apache.spark.sql.execution.exchange.EnsureRequirements.$anonfun$ensureDistributionAndOrdering$14(EnsureRequirements.scala:159)
    ```
    
    ### What changes were proposed in this pull request?
    Require that both sides of a cogroup have the same number of group-by keys, and raise a meaningful error when they do not.
    
    ### Why are the changes needed?
    The error does not provide information on how to solve the problem.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, raises an `IllegalArgumentException: requirement failed: Cogroup keys must have same size: 2 != 1` instead.
    
    ### How was this patch tested?
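    Adds tests `test_different_keys` and `test_different_group_key_cardinality` to `pyspark.sql.tests.test_pandas_cogrouped_map`, and the test "SPARK-40601: flatMapCoGroupsInPandas should fail with different number of keys" to `DataFrameSuite`.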
    
    Closes #38036 from EnricoMi/branch-cogroup-key-mismatch.
    
    Authored-by: Enrico Minack <git...@enrico.minack.dev>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../pyspark/sql/tests/test_pandas_cogrouped_map.py | 41 +++++++++++++++++++++-
 .../spark/sql/RelationalGroupedDataset.scala       |  3 ++
 .../org/apache/spark/sql/DataFrameSuite.scala      | 32 ++++++++++++++++-
 3 files changed, 74 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py 
b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py
index 3f403d9c9d6..88ba396e3f5 100644
--- a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py
@@ -20,7 +20,7 @@ from typing import cast
 
 from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf
 from pyspark.sql.types import DoubleType, StructType, StructField, Row
-from pyspark.sql.utils import PythonException
+from pyspark.sql.utils import IllegalArgumentException, PythonException
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
     have_pandas,
@@ -80,6 +80,29 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
         right = self.data2.withColumn("v3", lit("a"))
         self._test_merge(self.data1, right, "id long, k int, v int, v2 int, v3 
string")
 
+    def test_different_keys(self):
+        left = self.data1
+        right = self.data2
+
+        def merge_pandas(lft, rgt):
+            return pd.merge(lft.rename(columns={"id2": "id"}), rgt, on=["id", 
"k"])
+
+        result = (
+            left.withColumnRenamed("id", "id2")
+            .groupby("id2")
+            .cogroup(right.groupby("id"))
+            .applyInPandas(merge_pandas, "id long, k int, v int, v2 int")
+            .sort(["id", "k"])
+            .toPandas()
+        )
+
+        left = left.toPandas()
+        right = right.toPandas()
+
+        expected = pd.merge(left, right, on=["id", "k"]).sort_values(by=["id", 
"k"])
+
+        assert_frame_equal(expected, result)
+
     def test_complex_group_by(self):
         left = pd.DataFrame.from_dict({"id": [1, 2, 3], "k": [5, 6, 7], "v": 
[9, 10, 11]})
 
@@ -125,6 +148,22 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
 
         assert_frame_equal(expected, result)
 
+    def test_different_group_key_cardinality(self):
+        left = self.data1
+        right = self.data2
+
+        def merge_pandas(lft, _):
+            return lft
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegex(
+                IllegalArgumentException,
+                "requirement failed: Cogroup keys must have same size: 2 != 1",
+            ):
+                (left.groupby("id", 
"k").cogroup(right.groupby("id"))).applyInPandas(
+                    merge_pandas, "id long, k int, v int"
+                )
+
     def test_apply_in_pandas_not_returning_pandas_dataframe(self):
         left = self.data1
         right = self.data2
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 0429fd27a41..61517de0dfa 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -594,6 +594,9 @@ class RelationalGroupedDataset protected[sql](
       expr: PythonUDF): DataFrame = {
     require(expr.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
       "Must pass a cogrouped map udf")
+    require(this.groupingExprs.length == r.groupingExprs.length,
+      "Cogroup keys must have same size: " +
+        s"${this.groupingExprs.length} != ${r.groupingExprs.length}")
     require(expr.dataType.isInstanceOf[StructType],
       s"The returnType of the udf must be a ${StructType.simpleString}")
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 60dd7e3952f..cb453902ce9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -30,11 +30,12 @@ import scala.util.Random
 import org.scalatest.matchers.should.Matchers._
 
 import org.apache.spark.SparkException
+import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
 import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, 
AttributeReference, EqualTo, ExpressionSet, GreaterThan, Literal, Uuid}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, 
AttributeReference, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, 
Uuid}
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, 
LocalRelation, LogicalPlan, OneRowRelation, Statistics}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -2845,6 +2846,35 @@ class DataFrameSuite extends QueryTest
       parameters = Map("objectName" -> "`d`", "proposal" -> "`a`, `b`, `c`"))
   }
 
+  test("SPARK-40601: flatMapCoGroupsInPandas should fail with different number 
of keys") {
+    val df1 = Seq((1, 2, "A1"), (2, 1, "A2")).toDF("key1", "key2", "value")
+    val df2 = df1.filter($"value" === "A2")
+
+    val flatMapCoGroupsInPandasUDF = PythonUDF("flagMapCoGroupsInPandasUDF", 
null,
+      StructType(Seq(StructField("x", LongType), StructField("y", LongType))),
+      Seq.empty,
+      PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+      true)
+
+    // the number of keys must match
+    val exception1 = intercept[IllegalArgumentException] {
+      df1.groupBy($"key1", $"key2").flatMapCoGroupsInPandas(
+        df2.groupBy($"key2"), flatMapCoGroupsInPandasUDF)
+    }
+    assert(exception1.getMessage.contains("Cogroup keys must have same size: 2 
!= 1"))
+    val exception2 = intercept[IllegalArgumentException] {
+      df1.groupBy($"key1").flatMapCoGroupsInPandas(
+        df2.groupBy($"key1", $"key2"), flatMapCoGroupsInPandasUDF)
+    }
+    assert(exception2.getMessage.contains("Cogroup keys must have same size: 1 
!= 2"))
+
+    // but different keys are allowed
+    val actual = df1.groupBy($"key1").flatMapCoGroupsInPandas(
+      df2.groupBy($"key2"), flatMapCoGroupsInPandasUDF)
+    // can't evaluate the DataFrame as there is no PythonFunction given
+    assert(actual != null)
+  }
+
   test("emptyDataFrame should be foldable") {
     val emptyDf = spark.emptyDataFrame.withColumn("id", lit(1L))
     val joined = spark.range(10).join(emptyDf, "id")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to