This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 032623f62c07 [SPARK-50553][CONNECT] Throw `InvalidPlanInput` for invalid plan message
032623f62c07 is described below
commit 032623f62c071d3b0f2879f1951cbf4a6b3e55d3
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Dec 12 16:29:47 2024 +0900
[SPARK-50553][CONNECT] Throw `InvalidPlanInput` for invalid plan message
### What changes were proposed in this pull request?
Replace `assert` calls in `SparkConnectPlanner` with a new `assertPlan` helper that throws `InvalidPlanInput` for an invalid plan message.
### Why are the changes needed?
The planner should throw `InvalidPlanInput` for an invalid plan message instead of `AssertionError`: `AssertionError` signals an internal server bug, while `InvalidPlanInput` properly reports malformed input from the client.
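A minimal, self-contained sketch of the pattern (the `InvalidPlanInput` below is a simplified stand-in for Spark's `org.apache.spark.sql.connect.common.InvalidPlanInput`, and `validateJoin` is a hypothetical example rather than actual planner code):

```scala
object AssertPlanSketch {
  // Simplified stand-in for org.apache.spark.sql.connect.common.InvalidPlanInput.
  case class InvalidPlanInput(message: String = "") extends Exception(message)

  // Unlike the built-in assert, which throws AssertionError on failure,
  // this helper reports a failed check as invalid user input.
  def assertPlan(assertion: Boolean, message: String = ""): Unit = {
    if (!assertion) throw InvalidPlanInput(message)
  }

  // Hypothetical example: rejecting a Join relation that is missing one side.
  def validateJoin(hasLeft: Boolean, hasRight: Boolean): Unit =
    assertPlan(hasLeft && hasRight, "Both join sides must be present")
}
```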
### Does this PR introduce _any_ user-facing change?
Yes, an error improvement: these checks now throw `InvalidPlanInput` instead of `AssertionError`.
### How was this patch tested?
updated tests
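As a hedged illustration of the test update, a ScalaTest sketch reusing the stand-ins above (`AssertPlanSketchSuite` is hypothetical, not the actual suite):

```scala
import org.scalatest.funsuite.AnyFunSuite
import AssertPlanSketch._

class AssertPlanSketchSuite extends AnyFunSuite {
  test("missing join side is reported as InvalidPlanInput") {
    // Previously this surfaced as AssertionError; now the specific,
    // catchable InvalidPlanInput is thrown instead.
    val e = intercept[InvalidPlanInput] {
      validateJoin(hasLeft = true, hasRight = false)
    }
    assert(e.getMessage.contains("join sides"))
  }
}
```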
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #49161 from zhengruifeng/assert_plan.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../sql/connect/planner/SparkConnectPlanner.scala | 42 +++++++++++-----------
.../connect/planner/SparkConnectPlannerSuite.scala | 2 +-
2 files changed, 23 insertions(+), 21 deletions(-)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index ec67c57a37f5..82dfcf7a3694 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -558,7 +558,7 @@ class SparkConnectPlanner(
private def transformToSchema(rel: proto.ToSchema): LogicalPlan = {
val schema = transformDataType(rel.getSchema)
- assert(schema.isInstanceOf[StructType])
+ assertPlan(schema.isInstanceOf[StructType])
Dataset
.ofRows(session, transformRelation(rel.getInput))
@@ -876,7 +876,7 @@ class SparkConnectPlanner(
logicalPlan: LogicalPlan,
groupingExprs: java.util.List[proto.Expression],
sortOrder: Seq[SortOrder]): UntypedKeyValueGroupedDataset = {
- assert(groupingExprs.size() >= 1)
+ assertPlan(groupingExprs.size() >= 1)
val dummyFunc = TypedScalaUdf(groupingExprs.get(0), None)
val groupExprs = groupingExprs.asScala.toSeq.drop(1).map(expr => transformExpression(expr))
@@ -896,7 +896,7 @@ class SparkConnectPlanner(
logicalPlan: LogicalPlan,
groupingExprs: java.util.List[proto.Expression],
sortOrder: Seq[SortOrder]): UntypedKeyValueGroupedDataset = {
- assert(groupingExprs.size() == 1)
+ assertPlan(groupingExprs.size() == 1)
val groupFunc = TypedScalaUdf(groupingExprs.get(0), Some(logicalPlan.output))
val vEnc = groupFunc.inEnc
val kEnc = groupFunc.outEnc
@@ -952,7 +952,7 @@ class SparkConnectPlanner(
// Most typed API takes one UDF input.
// For the few that takes more than one inputs, e.g. grouping function mapping UDFs,
// the first input which is the key of the grouping function.
- assert(udf.inputEncoders.nonEmpty)
+ assertPlan(udf.inputEncoders.nonEmpty)
val inEnc = udf.inputEncoders.head // single input encoder or key encoder
TypedScalaUdf(udf.function, udf.outputEncoder, inEnc, inputAttrs)
}
@@ -1431,7 +1431,7 @@ class SparkConnectPlanner(
}
private def transformFilter(rel: proto.Filter): LogicalPlan = {
- assert(rel.hasInput)
+ assertPlan(rel.hasInput)
val baseRel = transformRelation(rel.getInput)
val cond = rel.getCondition
if (isTypedScalaUdfExpr(cond)) {
@@ -1754,7 +1754,7 @@ class SparkConnectPlanner(
val udf = fun.getScalarScalaUdf
val udfPacket = unpackUdf(fun)
if (udf.getAggregate) {
- assert(udfPacket.inputEncoders.size == 1, "UDAF should have exactly one input encoder")
+ assertPlan(udfPacket.inputEncoders.size == 1, "UDAF should have exactly one input encoder")
UserDefinedAggregator(
aggregator = udfPacket.function.asInstanceOf[Aggregator[Any, Any, Any]],
inputEncoder = ExpressionEncoder(udfPacket.inputEncoders.head),
@@ -2072,7 +2072,7 @@ class SparkConnectPlanner(
}
private def transformJoin(rel: proto.Join): LogicalPlan = {
- assert(rel.hasLeft && rel.hasRight, "Both join sides must be present")
+ assertPlan(rel.hasLeft && rel.hasRight, "Both join sides must be present")
if (rel.hasJoinCondition && rel.getUsingColumnsCount > 0) {
throw InvalidPlanInput(
s"Using columns or join conditions cannot be set at the same time in
Join")
@@ -2144,7 +2144,7 @@ class SparkConnectPlanner(
}
private def transformLateralJoin(rel: proto.LateralJoin): LogicalPlan = {
- assert(rel.hasLeft && rel.hasRight, "Both join sides must be present")
+ assertPlan(rel.hasLeft && rel.hasRight, "Both join sides must be present")
val joinCondition =
if (rel.hasJoinCondition) Some(transformExpression(rel.getJoinCondition)) else None
val joinType = transformJoinType(
@@ -2157,7 +2157,7 @@ class SparkConnectPlanner(
}
private def transformSort(sort: proto.Sort): LogicalPlan = {
- assert(sort.getOrderCount > 0, "'order' must be present and contain elements.")
+ assertPlan(sort.getOrderCount > 0, "'order' must be present and contain elements.")
logical.Sort(
child = transformRelation(sort.getInput),
global = sort.getIsGlobal,
@@ -2287,10 +2287,8 @@ class SparkConnectPlanner(
private def transformTypedReduceExpression(
fun: proto.Expression.UnresolvedFunction,
dataAttributes: Seq[Attribute]): Expression = {
- assert(fun.getFunctionName == "reduce")
- if (fun.getArgumentsCount != 1) {
- throw InvalidPlanInput("reduce requires single child expression")
- }
+ assertPlan(fun.getFunctionName == "reduce")
+ assertPlan(fun.getArgumentsCount == 1, "reduce requires single child expression")
val udf = fun.getArgumentsList.asScala match {
case collection.Seq(e)
if e.hasCommonInlineUserDefinedFunction &&
@@ -2320,10 +2318,10 @@ class SparkConnectPlanner(
expr: proto.TypedAggregateExpression,
baseRelationOpt: Option[LogicalPlan]): AggregateExpression = {
val udf = expr.getScalarScalaUdf
- assert(udf.getAggregate)
+ assertPlan(udf.getAggregate)
val udfPacket = unpackScalaUDF[UdfPacket](udf)
- assert(udfPacket.inputEncoders.size == 1, "UDAF should have exactly one input encoder")
+ assertPlan(udfPacket.inputEncoders.size == 1, "UDAF should have exactly one input encoder")
val aggregator = udfPacket.function.asInstanceOf[Aggregator[Any, Any, Any]]
val tae =
@@ -2356,17 +2354,17 @@ class SparkConnectPlanner(
}.toSeq
action.getActionType match {
case proto.MergeAction.ActionType.ACTION_TYPE_DELETE =>
- assert(assignments.isEmpty, "Delete action should not have
assignment.")
+ assertPlan(assignments.isEmpty, "Delete action should not have
assignment.")
DeleteAction(condition)
case proto.MergeAction.ActionType.ACTION_TYPE_INSERT =>
InsertAction(condition, assignments)
case proto.MergeAction.ActionType.ACTION_TYPE_INSERT_STAR =>
- assert(assignments.isEmpty, "InsertStar action should not have
assignment.")
+ assertPlan(assignments.isEmpty, "InsertStar action should not have
assignment.")
InsertStarAction(condition)
case proto.MergeAction.ActionType.ACTION_TYPE_UPDATE =>
UpdateAction(condition, assignments)
case proto.MergeAction.ActionType.ACTION_TYPE_UPDATE_STAR =>
- assert(assignments.isEmpty, "UpdateStar action should not have
assignment.")
+ assertPlan(assignments.isEmpty, "UpdateStar action should not have
assignment.")
UpdateStarAction(condition)
case _ =>
throw InvalidPlanInput(s"Unsupported merge action type
${action.getActionType}.")
@@ -3578,7 +3576,7 @@ class SparkConnectPlanner(
getCreateExternalTable: proto.CreateExternalTable): LogicalPlan = {
val schema = if (getCreateExternalTable.hasSchema) {
val struct = transformDataType(getCreateExternalTable.getSchema)
- assert(struct.isInstanceOf[StructType])
+ assertPlan(struct.isInstanceOf[StructType])
struct.asInstanceOf[StructType]
} else {
new StructType
@@ -3608,7 +3606,7 @@ class SparkConnectPlanner(
private def transformCreateTable(getCreateTable: proto.CreateTable): LogicalPlan = {
val schema = if (getCreateTable.hasSchema) {
val struct = transformDataType(getCreateTable.getSchema)
- assert(struct.isInstanceOf[StructType])
+ assertPlan(struct.isInstanceOf[StructType])
struct.asInstanceOf[StructType]
} else {
new StructType
@@ -3724,4 +3722,8 @@ class SparkConnectPlanner(
private def transformLazyExpression(getLazyExpression: proto.LazyExpression): Expression = {
LazyExpression(transformExpression(getLazyExpression.getChild))
}
+
+ private def assertPlan(assertion: Boolean, message: String = ""): Unit = {
+ if (!assertion) throw InvalidPlanInput(message)
+ }
}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index e44d3eacc66d..84d5fd68d4c7 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -274,7 +274,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
test("Simple Join") {
val incompleteJoin =
proto.Relation.newBuilder.setJoin(proto.Join.newBuilder.setLeft(readRel)).build()
- intercept[AssertionError](transform(incompleteJoin))
+ intercept[InvalidPlanInput](transform(incompleteJoin))
// Join type JOIN_TYPE_UNSPECIFIED is not supported.
intercept[InvalidPlanInput] {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]