This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 6e75fd85e260 [SPARK-51213][SQL] Keep Expression class info when resolving hint parameters
6e75fd85e260 is described below

commit 6e75fd85e2601d0a415cba04ee7b872840d65bb4
Author: Kent Yao <y...@apache.org>
AuthorDate: Fri Feb 14 16:05:36 2025 +0800

    [SPARK-51213][SQL] Keep Expression class info when resolving hint parameters
    
    ### What changes were proposed in this pull request?
    
    Currently, the expression class info is explicitly erased when resolving hint
    parameters. This PR undoes that so the class info is kept and can be used in
    error handling, giving a better and more consistent representation in error
    messages.
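    
    As a hypothetical illustration (the query and table `t` are placeholders, assuming a
    `spark` session such as the one in spark-shell): an invalid hint parameter is now
    rendered through `toSQLExpr`, so the offending value appears in its SQL form in the
    analysis error, matching the updated test expectations below.
    
    ```scala
    // The second argument is a literal, not a column, so analysis fails with an
    // AnalysisException whose message now quotes the parameter as an SQL expression.
    spark.sql("SELECT /*+ REBALANCE(1, 1) */ * FROM t").show()
    // before: ... but 1 can not be recognized as either partitionNum or columns.
    // after:  ... but "1" can not be recognized as either partitionNum or columns.
    ```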
    
    ### Why are the changes needed?
    
    Code refactoring and improved, more consistent error messages.
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    
    New tests added.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49950 from yaooqinn/SPARK-51213.
    
    Authored-by: Kent Yao <y...@apache.org>
    Signed-off-by: Kent Yao <y...@apache.org>
    (cherry picked from commit 47edf4b54e2e267d04d038ea116d0bad986d6328)
    Signed-off-by: Kent Yao <y...@apache.org>
---
 .../sql/catalyst/analysis/HintErrorLogger.scala    |  3 ++-
 .../spark/sql/catalyst/analysis/ResolveHints.scala | 22 +++++++++-------------
 .../spark/sql/catalyst/plans/logical/hints.scala   |  2 +-
 .../spark/sql/errors/QueryCompilationErrors.scala  |  8 ++++----
 .../sql/catalyst/analysis/ResolveHintsSuite.scala  |  6 +++++-
 5 files changed, 21 insertions(+), 20 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HintErrorLogger.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HintErrorLogger.scala
index 68c6ae9c03e3..5301a3683c7d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HintErrorLogger.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HintErrorLogger.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.{QUERY_HINT, RELATION_NAME, UNSUPPORTED_HINT_REASON}
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.{HintErrorHandler, HintInfo}
 
 /**
@@ -27,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{HintErrorHandler, HintInfo}
 object HintErrorLogger extends HintErrorHandler with Logging {
   import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 
-  override def hintNotRecognized(name: String, parameters: Seq[Any]): Unit = {
+  override def hintNotRecognized(name: String, parameters: Seq[Expression]): Unit = {
     logWarning(log"Unrecognized hint: " +
       log"${MDC(QUERY_HINT, hintToPrettyString(name, parameters))}")
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
index 30f26c1f08d0..0a47448fa049 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
@@ -174,7 +174,7 @@ object ResolveHints {
   * COALESCE Hint accepts names "COALESCE", "REPARTITION", and "REPARTITION_BY_RANGE".
    */
   object ResolveCoalesceHints extends Rule[LogicalPlan] {
-    private def getNumOfPartitions(hint: UnresolvedHint): (Option[Int], Seq[Any]) = {
+    private def getNumOfPartitions(hint: UnresolvedHint): (Option[Int], Seq[Expression]) = {
       hint.parameters match {
         case Seq(ByteLiteral(numPartitions), _*) =>
           (Some(numPartitions.toInt), hint.parameters.tail)
@@ -185,7 +185,7 @@ object ResolveHints {
       }
     }
 
-    private def validateParameters(hint: String, parms: Seq[Any]): Unit = {
+    private def validateParameters(hint: String, parms: Seq[Expression]): Unit = {
       val invalidParams = parms.filter(!_.isInstanceOf[UnresolvedAttribute])
       if (invalidParams.nonEmpty) {
         val hintName = hint.toUpperCase(Locale.ROOT)
@@ -198,18 +198,16 @@ object ResolveHints {
      * The "COALESCE" hint only has a partition number as a parameter. The 
"REPARTITION" hint
      * has a partition number, columns, or both of them as parameters.
      */
-    private def createRepartition(
-        shuffle: Boolean, hint: UnresolvedHint): LogicalPlan = {
+    private def createRepartition(shuffle: Boolean, hint: UnresolvedHint): LogicalPlan = {
 
       def createRepartitionByExpression(
-          numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
+          numPartitions: Option[Int], partitionExprs: Seq[Expression]): RepartitionByExpression = {
         val sortOrders = partitionExprs.filter(_.isInstanceOf[SortOrder])
         if (sortOrders.nonEmpty) {
          throw QueryCompilationErrors.invalidRepartitionExpressionsError(sortOrders)
         }
         validateParameters(hint.name, partitionExprs)
-        RepartitionByExpression(
-          partitionExprs.map(_.asInstanceOf[Expression]), hint.child, numPartitions)
+        RepartitionByExpression(partitionExprs, hint.child, numPartitions)
       }
 
       getNumOfPartitions(hint) match {
@@ -232,7 +230,7 @@ object ResolveHints {
      */
     private def createRepartitionByRange(hint: UnresolvedHint): RepartitionByExpression = {
       def createRepartitionByExpression(
-          numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
+          numPartitions: Option[Int], partitionExprs: Seq[Expression]): RepartitionByExpression = {
         validateParameters(hint.name, partitionExprs)
         val sortOrder = partitionExprs.map {
           case expr: SortOrder => expr
@@ -251,12 +249,10 @@ object ResolveHints {
 
     private def createRebalance(hint: UnresolvedHint): LogicalPlan = {
       def createRebalancePartitions(
-          partitionExprs: Seq[Any], initialNumPartitions: Option[Int]): RebalancePartitions = {
+          partitionExprs: Seq[Expression],
+          initialNumPartitions: Option[Int]): RebalancePartitions = {
         validateParameters(hint.name, partitionExprs)
-        RebalancePartitions(
-          partitionExprs.map(_.asInstanceOf[Expression]),
-          hint.child,
-          initialNumPartitions)
+        RebalancePartitions(partitionExprs, hint.child, initialNumPartitions)
       }
 
       getNumOfPartitions(hint) match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
index c8d2be298745..8a6182b87b77 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
@@ -213,7 +213,7 @@ trait HintErrorHandler {
    * @param name the unrecognized hint name
    * @param parameters the hint parameters
    */
-  def hintNotRecognized(name: String, parameters: Seq[Any]): Unit
+  def hintNotRecognized(name: String, parameters: Seq[Expression]): Unit
 
   /**
   * Callback for relation names specified in a hint that cannot be associated with any relation
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 83188af14a67..fa0a90135934 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -949,21 +949,21 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
       messageParameters = Map.empty)
   }
 
-  def joinStrategyHintParameterNotSupportedError(unsupported: Any): Throwable = {
+  def joinStrategyHintParameterNotSupportedError(unsupported: Expression): Throwable = {
     new AnalysisException(
       errorClass = "_LEGACY_ERROR_TEMP_1046",
       messageParameters = Map(
-        "unsupported" -> unsupported.toString,
+        "unsupported" -> toSQLExpr(unsupported),
         "class" -> unsupported.getClass.toString))
   }
 
   def invalidHintParameterError(
-      hintName: String, invalidParams: Seq[Any]): Throwable = {
+      hintName: String, invalidParams: Seq[Expression]): Throwable = {
     new AnalysisException(
       errorClass = "_LEGACY_ERROR_TEMP_1047",
       messageParameters = Map(
         "hintName" -> hintName,
-        "invalidParams" -> invalidParams.mkString(", ")))
+        "invalidParams" -> invalidParams.map(toSQLExpr).mkString(", ")))
   }
 
   def invalidCoalesceHintParameterError(hintName: String): Throwable = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
index 86c87ff866b7..1c36728663f8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
@@ -347,7 +347,7 @@ class ResolveHintsSuite extends AnalysisTest {
     }
 
     val msg = "REBALANCE Hint parameters should include an optional integral 
partitionNum " +
-      "and/or columns, but 1 can not be recognized as either partitionNum or 
columns."
+      "and/or columns, but \"1\" can not be recognized as either partitionNum 
or columns."
     assertAnalysisError(
       UnresolvedHint("REBALANCE", Seq(Literal(1), Literal(1)), table("TaBlE")),
       Seq(msg))
@@ -355,6 +355,10 @@ class ResolveHintsSuite extends AnalysisTest {
     assertAnalysisError(
       UnresolvedHint("REBALANCE", Seq(1, Literal(1)), table("TaBlE")),
       Seq(msg))
+
+    assertAnalysisError(
+      UnresolvedHint("REBALANCE", Seq(1, Literal(Array[Byte](0, 1, 3))), 
table("TaBlE")),
+      Seq("X'000103'"))
   }
 
   test("SPARK-38410: Support specify initial partition number for rebalance") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
