This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 032623f62c07 [SPARK-50553][CONNECT] Throw `InvalidPlanInput` for invalid plan message
032623f62c07 is described below

commit 032623f62c071d3b0f2879f1951cbf4a6b3e55d3
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Dec 12 16:29:47 2024 +0900

    [SPARK-50553][CONNECT] Throw `InvalidPlanInput` for invalid plan message
    
    ### What changes were proposed in this pull request?
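    Replace the `assert(...)` calls in `SparkConnectPlanner` with a new private
    `assertPlan(...)` helper that throws `InvalidPlanInput` when a plan message is invalid.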
    
    ### Why are the changes needed?
    An invalid plan message should be reported as `InvalidPlanInput`, not as an
    `AssertionError`.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, an error improvement: invalid plan messages now raise `InvalidPlanInput` instead of `AssertionError`.
    
    ### How was this patch tested?
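    Updated the "Simple Join" test in `SparkConnectPlannerSuite` to expect `InvalidPlanInput` instead of `AssertionError`.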
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49161 from zhengruifeng/assert_plan.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  | 42 +++++++++++-----------
 .../connect/planner/SparkConnectPlannerSuite.scala |  2 +-
 2 files changed, 23 insertions(+), 21 deletions(-)

diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index ec67c57a37f5..82dfcf7a3694 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -558,7 +558,7 @@ class SparkConnectPlanner(
 
   private def transformToSchema(rel: proto.ToSchema): LogicalPlan = {
     val schema = transformDataType(rel.getSchema)
-    assert(schema.isInstanceOf[StructType])
+    assertPlan(schema.isInstanceOf[StructType])
 
     Dataset
       .ofRows(session, transformRelation(rel.getInput))
@@ -876,7 +876,7 @@ class SparkConnectPlanner(
         logicalPlan: LogicalPlan,
         groupingExprs: java.util.List[proto.Expression],
         sortOrder: Seq[SortOrder]): UntypedKeyValueGroupedDataset = {
-      assert(groupingExprs.size() >= 1)
+      assertPlan(groupingExprs.size() >= 1)
       val dummyFunc = TypedScalaUdf(groupingExprs.get(0), None)
       val groupExprs = groupingExprs.asScala.toSeq.drop(1).map(expr => 
transformExpression(expr))
 
@@ -896,7 +896,7 @@ class SparkConnectPlanner(
         logicalPlan: LogicalPlan,
         groupingExprs: java.util.List[proto.Expression],
         sortOrder: Seq[SortOrder]): UntypedKeyValueGroupedDataset = {
-      assert(groupingExprs.size() == 1)
+      assertPlan(groupingExprs.size() == 1)
       val groupFunc = TypedScalaUdf(groupingExprs.get(0), 
Some(logicalPlan.output))
       val vEnc = groupFunc.inEnc
       val kEnc = groupFunc.outEnc
@@ -952,7 +952,7 @@ class SparkConnectPlanner(
       // Most typed API takes one UDF input.
       // For the few that takes more than one inputs, e.g. grouping function 
mapping UDFs,
       // the first input which is the key of the grouping function.
-      assert(udf.inputEncoders.nonEmpty)
+      assertPlan(udf.inputEncoders.nonEmpty)
       val inEnc = udf.inputEncoders.head // single input encoder or key encoder
       TypedScalaUdf(udf.function, udf.outputEncoder, inEnc, inputAttrs)
     }
@@ -1431,7 +1431,7 @@ class SparkConnectPlanner(
   }
 
   private def transformFilter(rel: proto.Filter): LogicalPlan = {
-    assert(rel.hasInput)
+    assertPlan(rel.hasInput)
     val baseRel = transformRelation(rel.getInput)
     val cond = rel.getCondition
     if (isTypedScalaUdfExpr(cond)) {
@@ -1754,7 +1754,7 @@ class SparkConnectPlanner(
     val udf = fun.getScalarScalaUdf
     val udfPacket = unpackUdf(fun)
     if (udf.getAggregate) {
-      assert(udfPacket.inputEncoders.size == 1, "UDAF should have exactly one 
input encoder")
+      assertPlan(udfPacket.inputEncoders.size == 1, "UDAF should have exactly 
one input encoder")
       UserDefinedAggregator(
         aggregator = udfPacket.function.asInstanceOf[Aggregator[Any, Any, 
Any]],
         inputEncoder = ExpressionEncoder(udfPacket.inputEncoders.head),
@@ -2072,7 +2072,7 @@ class SparkConnectPlanner(
   }
 
   private def transformJoin(rel: proto.Join): LogicalPlan = {
-    assert(rel.hasLeft && rel.hasRight, "Both join sides must be present")
+    assertPlan(rel.hasLeft && rel.hasRight, "Both join sides must be present")
     if (rel.hasJoinCondition && rel.getUsingColumnsCount > 0) {
       throw InvalidPlanInput(
         s"Using columns or join conditions cannot be set at the same time in 
Join")
@@ -2144,7 +2144,7 @@ class SparkConnectPlanner(
   }
 
   private def transformLateralJoin(rel: proto.LateralJoin): LogicalPlan = {
-    assert(rel.hasLeft && rel.hasRight, "Both join sides must be present")
+    assertPlan(rel.hasLeft && rel.hasRight, "Both join sides must be present")
     val joinCondition =
       if (rel.hasJoinCondition) 
Some(transformExpression(rel.getJoinCondition)) else None
     val joinType = transformJoinType(
@@ -2157,7 +2157,7 @@ class SparkConnectPlanner(
   }
 
   private def transformSort(sort: proto.Sort): LogicalPlan = {
-    assert(sort.getOrderCount > 0, "'order' must be present and contain 
elements.")
+    assertPlan(sort.getOrderCount > 0, "'order' must be present and contain 
elements.")
     logical.Sort(
       child = transformRelation(sort.getInput),
       global = sort.getIsGlobal,
@@ -2287,10 +2287,8 @@ class SparkConnectPlanner(
   private def transformTypedReduceExpression(
       fun: proto.Expression.UnresolvedFunction,
       dataAttributes: Seq[Attribute]): Expression = {
-    assert(fun.getFunctionName == "reduce")
-    if (fun.getArgumentsCount != 1) {
-      throw InvalidPlanInput("reduce requires single child expression")
-    }
+    assertPlan(fun.getFunctionName == "reduce")
+    assertPlan(fun.getArgumentsCount == 1, "reduce requires single child 
expression")
     val udf = fun.getArgumentsList.asScala match {
       case collection.Seq(e)
           if e.hasCommonInlineUserDefinedFunction &&
@@ -2320,10 +2318,10 @@ class SparkConnectPlanner(
       expr: proto.TypedAggregateExpression,
       baseRelationOpt: Option[LogicalPlan]): AggregateExpression = {
     val udf = expr.getScalarScalaUdf
-    assert(udf.getAggregate)
+    assertPlan(udf.getAggregate)
 
     val udfPacket = unpackScalaUDF[UdfPacket](udf)
-    assert(udfPacket.inputEncoders.size == 1, "UDAF should have exactly one 
input encoder")
+    assertPlan(udfPacket.inputEncoders.size == 1, "UDAF should have exactly 
one input encoder")
 
     val aggregator = udfPacket.function.asInstanceOf[Aggregator[Any, Any, Any]]
     val tae =
@@ -2356,17 +2354,17 @@ class SparkConnectPlanner(
     }.toSeq
     action.getActionType match {
       case proto.MergeAction.ActionType.ACTION_TYPE_DELETE =>
-        assert(assignments.isEmpty, "Delete action should not have 
assignment.")
+        assertPlan(assignments.isEmpty, "Delete action should not have 
assignment.")
         DeleteAction(condition)
       case proto.MergeAction.ActionType.ACTION_TYPE_INSERT =>
         InsertAction(condition, assignments)
       case proto.MergeAction.ActionType.ACTION_TYPE_INSERT_STAR =>
-        assert(assignments.isEmpty, "InsertStar action should not have 
assignment.")
+        assertPlan(assignments.isEmpty, "InsertStar action should not have 
assignment.")
         InsertStarAction(condition)
       case proto.MergeAction.ActionType.ACTION_TYPE_UPDATE =>
         UpdateAction(condition, assignments)
       case proto.MergeAction.ActionType.ACTION_TYPE_UPDATE_STAR =>
-        assert(assignments.isEmpty, "UpdateStar action should not have 
assignment.")
+        assertPlan(assignments.isEmpty, "UpdateStar action should not have 
assignment.")
         UpdateStarAction(condition)
       case _ =>
         throw InvalidPlanInput(s"Unsupported merge action type 
${action.getActionType}.")
@@ -3578,7 +3576,7 @@ class SparkConnectPlanner(
       getCreateExternalTable: proto.CreateExternalTable): LogicalPlan = {
     val schema = if (getCreateExternalTable.hasSchema) {
       val struct = transformDataType(getCreateExternalTable.getSchema)
-      assert(struct.isInstanceOf[StructType])
+      assertPlan(struct.isInstanceOf[StructType])
       struct.asInstanceOf[StructType]
     } else {
       new StructType
@@ -3608,7 +3606,7 @@ class SparkConnectPlanner(
   private def transformCreateTable(getCreateTable: proto.CreateTable): 
LogicalPlan = {
     val schema = if (getCreateTable.hasSchema) {
       val struct = transformDataType(getCreateTable.getSchema)
-      assert(struct.isInstanceOf[StructType])
+      assertPlan(struct.isInstanceOf[StructType])
       struct.asInstanceOf[StructType]
     } else {
       new StructType
@@ -3724,4 +3722,8 @@ class SparkConnectPlanner(
   private def transformLazyExpression(getLazyExpression: 
proto.LazyExpression): Expression = {
     LazyExpression(transformExpression(getLazyExpression.getChild))
   }
+
+  private def assertPlan(assertion: Boolean, message: String = ""): Unit = {
+    if (!assertion) throw InvalidPlanInput(message)
+  }
 }
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index e44d3eacc66d..84d5fd68d4c7 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -274,7 +274,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with 
SparkConnectPlanTest {
   test("Simple Join") {
     val incompleteJoin =
       
proto.Relation.newBuilder.setJoin(proto.Join.newBuilder.setLeft(readRel)).build()
-    intercept[AssertionError](transform(incompleteJoin))
+    intercept[InvalidPlanInput](transform(incompleteJoin))
 
     // Join type JOIN_TYPE_UNSPECIFIED is not supported.
     intercept[InvalidPlanInput] {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to