This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new d9dd9944bf4 [SPARK-40601][PYTHON] Assert identical key size when cogrouping groups d9dd9944bf4 is described below commit d9dd9944bf4c3adba4bcb458304793376f083000 Author: Enrico Minack <git...@enrico.minack.dev> AuthorDate: Thu Oct 13 15:20:31 2022 +0900 [SPARK-40601][PYTHON] Assert identical key size when cogrouping groups Cogrouping two grouped DataFrames in PySpark that have different group key cardinalities raises an error that is not very descriptive: ```python left.groupby("id", "k") .cogroup(right.groupby("id")) ``` ``` py4j.protocol.Py4JJavaError: An error occurred while calling o726.collectToPython. : java.lang.IndexOutOfBoundsException: 1 at scala.collection.mutable.ResizableArray.apply(ResizableArray.scala:46) at scala.collection.mutable.ResizableArray.apply$(ResizableArray.scala:45) at scala.collection.mutable.ArrayBuffer.apply(ArrayBuffer.scala:49) at org.apache.spark.sql.catalyst.plans.physical.HashShuffleSpec.$anonfun$createPartitioning$5(partitioning.scala:650) ... org.apache.spark.sql.execution.exchange.EnsureRequirements.$anonfun$ensureDistributionAndOrdering$14(EnsureRequirements.scala:159) ``` ### What changes were proposed in this pull request? Require that both sides of a cogroup have the same number of grouping keys, and raise a meaningful error when they do not. ### Why are the changes needed? The current error does not indicate what is wrong or how to fix it. ### Does this PR introduce _any_ user-facing change? Yes, cogrouping with a different number of grouping keys now raises `IllegalArgumentException: requirement failed: Cogroup keys must have same size: 2 != 1` (with the actual key counts) instead. ### How was this patch tested? Adds test `test_different_group_key_cardinality` to `pyspark.sql.tests.test_pandas_cogrouped_map`. Also adds `test_different_keys` there (different key columns of the same count are still allowed) and a `SPARK-40601` test to `DataFrameSuite`. Closes #38036 from EnricoMi/branch-cogroup-key-mismatch. 
Authored-by: Enrico Minack <git...@enrico.minack.dev> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../pyspark/sql/tests/test_pandas_cogrouped_map.py | 41 +++++++++++++++++++++- .../spark/sql/RelationalGroupedDataset.scala | 3 ++ .../org/apache/spark/sql/DataFrameSuite.scala | 32 ++++++++++++++++- 3 files changed, 74 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py index 3f403d9c9d6..88ba396e3f5 100644 --- a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py @@ -20,7 +20,7 @@ from typing import cast from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf from pyspark.sql.types import DoubleType, StructType, StructField, Row -from pyspark.sql.utils import PythonException +from pyspark.sql.utils import IllegalArgumentException, PythonException from pyspark.testing.sqlutils import ( ReusedSQLTestCase, have_pandas, @@ -80,6 +80,29 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase): right = self.data2.withColumn("v3", lit("a")) self._test_merge(self.data1, right, "id long, k int, v int, v2 int, v3 string") + def test_different_keys(self): + left = self.data1 + right = self.data2 + + def merge_pandas(lft, rgt): + return pd.merge(lft.rename(columns={"id2": "id"}), rgt, on=["id", "k"]) + + result = ( + left.withColumnRenamed("id", "id2") + .groupby("id2") + .cogroup(right.groupby("id")) + .applyInPandas(merge_pandas, "id long, k int, v int, v2 int") + .sort(["id", "k"]) + .toPandas() + ) + + left = left.toPandas() + right = right.toPandas() + + expected = pd.merge(left, right, on=["id", "k"]).sort_values(by=["id", "k"]) + + assert_frame_equal(expected, result) + def test_complex_group_by(self): left = pd.DataFrame.from_dict({"id": [1, 2, 3], "k": [5, 6, 7], "v": [9, 10, 11]}) @@ -125,6 +148,22 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase): assert_frame_equal(expected, 
result) + def test_different_group_key_cardinality(self): + left = self.data1 + right = self.data2 + + def merge_pandas(lft, _): + return lft + + with QuietTest(self.sc): + with self.assertRaisesRegex( + IllegalArgumentException, + "requirement failed: Cogroup keys must have same size: 2 != 1", + ): + (left.groupby("id", "k").cogroup(right.groupby("id"))).applyInPandas( + merge_pandas, "id long, k int, v int" + ) + def test_apply_in_pandas_not_returning_pandas_dataframe(self): left = self.data1 right = self.data2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 0429fd27a41..61517de0dfa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -594,6 +594,9 @@ class RelationalGroupedDataset protected[sql]( expr: PythonUDF): DataFrame = { require(expr.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, "Must pass a cogrouped map udf") + require(this.groupingExprs.length == r.groupingExprs.length, + "Cogroup keys must have same size: " + + s"${this.groupingExprs.length} != ${r.groupingExprs.length}") require(expr.dataType.isInstanceOf[StructType], s"The returnType of the udf must be a ${StructType.simpleString}") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 60dd7e3952f..cb453902ce9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -30,11 +30,12 @@ import scala.util.Random import org.scalatest.matchers.should.Matchers._ import org.apache.spark.SparkException +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd} import 
org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, EqualTo, ExpressionSet, GreaterThan, Literal, Uuid} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, Uuid} import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LocalRelation, LogicalPlan, OneRowRelation, Statistics} import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -2845,6 +2846,35 @@ class DataFrameSuite extends QueryTest parameters = Map("objectName" -> "`d`", "proposal" -> "`a`, `b`, `c`")) } + test("SPARK-40601: flatMapCoGroupsInPandas should fail with different number of keys") { + val df1 = Seq((1, 2, "A1"), (2, 1, "A2")).toDF("key1", "key2", "value") + val df2 = df1.filter($"value" === "A2") + + val flatMapCoGroupsInPandasUDF = PythonUDF("flagMapCoGroupsInPandasUDF", null, + StructType(Seq(StructField("x", LongType), StructField("y", LongType))), + Seq.empty, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + true) + + // the number of keys must match + val exception1 = intercept[IllegalArgumentException] { + df1.groupBy($"key1", $"key2").flatMapCoGroupsInPandas( + df2.groupBy($"key2"), flatMapCoGroupsInPandasUDF) + } + assert(exception1.getMessage.contains("Cogroup keys must have same size: 2 != 1")) + val exception2 = intercept[IllegalArgumentException] { + df1.groupBy($"key1").flatMapCoGroupsInPandas( + df2.groupBy($"key1", $"key2"), flatMapCoGroupsInPandasUDF) + } + assert(exception2.getMessage.contains("Cogroup keys must have same size: 1 != 2")) + + // but different keys are allowed + val actual = df1.groupBy($"key1").flatMapCoGroupsInPandas( + 
df2.groupBy($"key2"), flatMapCoGroupsInPandasUDF) + // can't evaluate the DataFrame as there is no PythonFunction given + assert(actual != null) + } + test("emptyDataFrame should be foldable") { val emptyDf = spark.emptyDataFrame.withColumn("id", lit(1L)) val joined = spark.range(10).join(emptyDf, "id") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org