This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3d5a7d90712 [SPARK-40178][SQL][CONNECT] Support coalesce hints with ease for PySpark and R
3d5a7d90712 is described below

commit 3d5a7d9071228affd5e06ff03091fc7525031ccc
Author: Xianjin <xian...@apache.org>
AuthorDate: Mon Aug 21 09:28:12 2023 +0900

    [SPARK-40178][SQL][CONNECT] Support coalesce hints with ease for PySpark and R
    
    ### What changes were proposed in this pull request?
    1. Refactor `UnresolvedHint` to accept only `Expression`s as parameters
    2. `ResolveHints` now parses a `StringLiteral` as an `UnresolvedAttribute`, which allows users to pass column names as plain strings in hint parameters
    3. The `hint` method in `Dataset` now treats all its parameters as `Column`s or `Literal`s; all other values are rejected. The method signature is kept for compatibility and ease of use, and it matches how the hint method is handled in the Connect module.
    4. Connect: PySpark Connect now accepts `Column` as a hint parameter.
    5. PySpark: allows `Column` as a hint parameter and tightens the input type check: for list input, only lists of primitive values are now allowed (see the sketch after this list).
    6. SparkR: allows `Column` as a hint parameter and adds a corresponding test.
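    
    A minimal PySpark sketch of the tightened check (illustrative only; the hint names and columns below are arbitrary examples, not part of this patch):
    
    ```python
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col

    spark = SparkSession.builder.getOrCreate()
    df = spark.range(10)

    # Columns are now accepted directly as hint parameters.
    df.hint("rebalance", col("id"))

    # Lists are still accepted, but only lists of primitives (str, float, int).
    df.hint("my hint", [1, 2, 3])

    # A list containing non-primitive items (e.g. a Column) now raises PySparkTypeError:
    # df.hint("my hint", [col("id")])
    ```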
    
    ### Why are the changes needed?
    This is a rework of #37616. Before this commit, there was no way for users to directly specify hint parameters that include column references in PySpark's `hint` method. In other words, a `rebalance` hint that requires column refs was not possible before this PR.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. PySpark and SparkR users may now specify rebalance and repartition hints with ease, for example as shown below.
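    
    For illustration, a sketch of the newly possible calls (column name `a` is a hypothetical example; the rebalance hint assumes adaptive query execution is enabled, which is the default):
    
    ```python
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([("A", 1), ("B", 2)], ["a", "n"])

    # Column references (or plain column-name strings) are now valid hint parameters.
    df.hint("rebalance", col("a")).explain(True)
    df.hint("repartition_by_range", 2, "a").explain(True)
    ```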
    
    ### How was this patch tested?
    Added UTs.
    
    Closes #42255 from advancedxy/SPARK-40178.
    
    Lead-authored-by: Xianjin <xian...@apache.org>
    Co-authored-by: Xianjin YE <xian...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 R/pkg/R/DataFrame.R                                | 11 ++++-
 R/pkg/tests/fulltests/test_sparkSQL.R              |  6 +++
 .../sql/connect/planner/SparkConnectPlanner.scala  | 16 +-------
 python/pyspark/sql/connect/dataframe.py            | 21 ++++++++--
 python/pyspark/sql/connect/plan.py                 |  2 +-
 python/pyspark/sql/dataframe.py                    | 48 ++++++++++++++++++----
 .../sql/tests/connect/test_connect_basic.py        |  2 +-
 python/pyspark/sql/tests/test_dataframe.py         | 41 +++++++++++++++++-
 .../spark/sql/catalyst/analysis/ResolveHints.scala | 30 +++++++-------
 .../apache/spark/sql/catalyst/dsl/package.scala    |  3 +-
 .../spark/sql/catalyst/plans/logical/hints.scala   |  4 +-
 .../spark/sql/catalyst/analysis/DSLHintSuite.scala |  5 ---
 .../sql/catalyst/analysis/ResolveHintsSuite.scala  |  6 +--
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 19 ++++++++-
 .../org/apache/spark/sql/DataFrameHintSuite.scala  | 30 ++++++++++----
 15 files changed, 178 insertions(+), 66 deletions(-)

diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 3f9bc9cb6d0..8650df64fe5 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -4146,7 +4146,7 @@ setMethod("hint",
           function(x, name, ...) {
             parameters <- list(...)
             if (!all(sapply(parameters, function(y) {
-              if (is.character(y) || is.numeric(y)) {
+              if (is.character(y) || is.numeric(y) || is(class(y), "characterOrColumn")) {
                 TRUE
               } else if (is.list(y)) {
                all(sapply(y, function(z) { is.character(z) || is.numeric(z) }))
@@ -4156,7 +4156,14 @@ setMethod("hint",
             }))) {
               stop("sql hint should be character, numeric, or list with 
character or numeric.")
             }
-            jdf <- callJMethod(x@sdf, "hint", name, parameters)
+            jparams <- lapply(parameters, function(c) {
+              if (is.character(c) || is.numeric(c) || is.list(c)) {
+                c
+              } else {
+                c@jc
+              }
+            })
+            jdf <- callJMethod(x@sdf, "hint", name, jparams)
             dataFrame(jdf)
           })
 
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index 47688d7560c..1730c7be7f5 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -2808,6 +2808,12 @@ test_that("test hint", {
     explain(hint(df, "hint1", 1.23456, "aaaaaaaaaa", hintList), TRUE)
   )
   expect_true(any(grepl("1.23456, aaaaaaaaaa", execution_plan_hint)))
+
+  a <- column("id")
+  rebalance_plan_hint <- capture.output(
+      explain(hint(df, "rebalance", as.integer(2), a), TRUE)
+  )
+  expect_true(any(grepl("RebalancePartitions", rebalance_plan_hint)))
 })
 
 test_that("toJSON() on DataFrame", {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 69990af8187..c23fd35acc7 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -994,21 +994,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
 
   private def transformHint(rel: proto.Hint): LogicalPlan = {
 
-    def extractValue(expr: Expression): Any = {
-      expr match {
-        case Literal(s, StringType) if s != null =>
-          UnresolvedAttribute.quotedString(s.toString)
-        case literal: Literal => literal.value
-        case UnresolvedFunction(Seq("array"), arguments, _, _, _) =>
-          arguments.map(extractValue).toArray
-        case other =>
-          throw InvalidPlanInput(
-            s"Expression should be a Literal or CreateMap or CreateArray, " +
-              s"but got ${other.getClass} $other")
-      }
-    }
-
-    val params = rel.getParametersList.asScala.toSeq.map(transformExpression).map(extractValue)
+    val params = rel.getParametersList.asScala.toSeq.map(transformExpression)
     UnresolvedHint(rel.getName, params, transformRelation(rel.getInput))
   }
 
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 7b326538a8e..16c42637fbd 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -879,7 +879,7 @@ class DataFrame:
     withWatermark.__doc__ = PySparkDataFrame.withWatermark.__doc__
 
     def hint(
-        self, name: str, *parameters: Union["PrimitiveType", List["PrimitiveType"]]
+        self, name: str, *parameters: Union["PrimitiveType", "Column", List["PrimitiveType"]]
     ) -> "DataFrame":
         if len(parameters) == 1 and isinstance(parameters[0], list):
             parameters = parameters[0]  # type: ignore[assignment]
@@ -890,17 +890,32 @@ class DataFrame:
                 message_parameters={"arg_name": "name", "arg_type": 
type(name).__name__},
             )
 
-        allowed_types = (str, list, float, int)
+        allowed_types = (str, float, int, Column, list)
+        allowed_primitive_types = (str, float, int)
+        allowed_types_repr = ", ".join(
+            [t.__name__ for t in allowed_types[:-1]]
+            + ["list[" + t.__name__ + "]" for t in allowed_primitive_types]
+        )
         for p in parameters:
             if not isinstance(p, allowed_types):
                 raise PySparkTypeError(
                     error_class="INVALID_ITEM_FOR_CONTAINER",
                     message_parameters={
                         "arg_name": "parameters",
-                        "allowed_types": ", ".join([t.__name__ for t in 
allowed_types]),
+                        "allowed_types": allowed_types_repr,
                         "item_type": type(p).__name__,
                     },
                 )
+            if isinstance(p, list):
+                if not all(isinstance(e, allowed_primitive_types) for e in p):
+                    raise PySparkTypeError(
+                        error_class="INVALID_ITEM_FOR_CONTAINER",
+                        message_parameters={
+                            "arg_name": "parameters",
+                            "allowed_types": allowed_types_repr,
+                            "item_type": type(p).__name__ + "[" + 
type(p[0]).__name__ + "]",
+                        },
+                    )
 
         return DataFrame.withPlan(
             plan.Hint(self._plan, name, list(parameters)),
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 7da93ef413c..e21a0fab000 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -579,7 +579,7 @@ class Hint(LogicalPlan):
         self._name = name
 
         for param in parameters:
-            assert isinstance(param, (list, str, float, int))
+            assert isinstance(param, (list, str, float, int, Column))
             if isinstance(param, list):
                 assert all(isinstance(p, (str, float, int)) for p in param)
 
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 03aaee8f2ec..0999d365cf0 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -60,7 +60,7 @@ from pyspark.sql.types import (
     Row,
     _parse_datatype_json_string,
 )
-from pyspark.sql.utils import get_active_spark_context
+from pyspark.sql.utils import get_active_spark_context, toJArray
 from pyspark.sql.pandas.conversion import PandasConversionMixin
 from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
 
@@ -1117,7 +1117,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         return DataFrame(jdf, self.sparkSession)
 
     def hint(
-        self, name: str, *parameters: Union["PrimitiveType", List["PrimitiveType"]]
+        self, name: str, *parameters: Union["PrimitiveType", "Column", List["PrimitiveType"]]
     ) -> "DataFrame":
         """Specifies some hint on the current :class:`DataFrame`.
 
@@ -1165,20 +1165,54 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                 message_parameters={"arg_name": "name", "arg_type": 
type(name).__name__},
             )
 
-        allowed_types = (str, list, float, int)
+        allowed_types = (str, float, int, Column, list)
+        allowed_primitive_types = (str, float, int)
+        allowed_types_repr = ", ".join(
+            [t.__name__ for t in allowed_types[:-1]]
+            + ["list[" + t.__name__ + "]" for t in allowed_primitive_types]
+        )
         for p in parameters:
             if not isinstance(p, allowed_types):
                 raise PySparkTypeError(
                     error_class="DISALLOWED_TYPE_FOR_CONTAINER",
                     message_parameters={
                         "arg_name": "parameters",
-                        "arg_type": type(parameters).__name__,
-                        "allowed_types": ", ".join(map(lambda x: x.__name__, 
allowed_types)),
-                        "return_type": type(p).__name__,
+                        "allowed_types": allowed_types_repr,
+                        "item_type": type(p).__name__,
                     },
                 )
+            if isinstance(p, list):
+                if not all(isinstance(e, allowed_primitive_types) for e in p):
+                    raise PySparkTypeError(
+                        error_class="DISALLOWED_TYPE_FOR_CONTAINER",
+                        message_parameters={
+                            "arg_name": "parameters",
+                            "allowed_types": allowed_types_repr,
+                            "item_type": type(p).__name__ + "[" + 
type(p[0]).__name__ + "]",
+                        },
+                    )
+
+        def _converter(parameter: Union[str, list, float, int, Column]) -> Any:
+            if isinstance(parameter, Column):
+                return _to_java_column(parameter)
+            elif isinstance(parameter, list):
+                # for list input, we are assuming only one element type exist in the list.
+                # for empty list, we are converting it into an empty long[] in the JVM side.
+                gateway = self._sc._gateway
+                assert gateway is not None
+                jclass = gateway.jvm.long
+                if len(parameter) >= 1:
+                    mapping = {
+                        str: gateway.jvm.java.lang.String,
+                        float: gateway.jvm.double,
+                        int: gateway.jvm.long,
+                    }
+                    jclass = mapping[type(parameter[0])]
+                return toJArray(gateway, jclass, parameter)
+            else:
+                return parameter
 
-        jdf = self._jdf.hint(name, self._jseq(parameters))
+        jdf = self._jdf.hint(name, self._jseq(parameters, _converter))
         return DataFrame(jdf, self.sparkSession)
 
     def count(self) -> int:
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 9e8f5623971..aa27217be03 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1900,7 +1900,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             error_class="INVALID_ITEM_FOR_CONTAINER",
             message_parameters={
                 "arg_name": "parameters",
-                "allowed_types": "str, list, float, int",
+                "allowed_types": "str, float, int, Column, list[str], 
list[float], list[int]",
                 "item_type": "dict",
             },
         )
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 33049233dee..868d3f6f0aa 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -27,7 +27,7 @@ import io
 from contextlib import redirect_stdout
 
 from pyspark import StorageLevel
-from pyspark.sql import SparkSession, Row
+from pyspark.sql import SparkSession, Row, functions
 from pyspark.sql.functions import col, lit, count, sum, mean, struct
 from pyspark.sql.pandas.utils import pyarrow_version_less_than_minimum
 from pyspark.sql.types import (
@@ -645,11 +645,48 @@ class DataFrameTestsMixin:
             df1.join(df2.hint("broadcast"), "id").explain(True)
             self.assertEqual(1, buf.getvalue().count("BroadcastHashJoin"))
 
+    def test_coalesce_hints_with_string_parameter(self):
+        with self.sql_conf({"spark.sql.adaptive.coalescePartitions.enabled": False}):
+            df = self.spark.createDataFrame(
+                zip(["A", "B"] * 2**9, range(2**10)),
+                StructType([StructField("a", StringType()), StructField("n", 
IntegerType())]),
+            )
+            with io.StringIO() as buf, redirect_stdout(buf):
+                # COALESCE
+                coalesce = df.hint("coalesce", 2)
+                coalesce.explain(True)
+                output = buf.getvalue()
+                self.assertGreaterEqual(output.count("Coalesce 2"), 1)
+                buf.truncate(0)
+                buf.seek(0)
+
+                # REPARTITION_BY_RANGE
+                range_partitioned = df.hint("REPARTITION_BY_RANGE", 2, "a")
+                range_partitioned.explain(True)
+                output = buf.getvalue()
+                self.assertGreaterEqual(output.count("REPARTITION_BY_NUM"), 1)
+                buf.truncate(0)
+                buf.seek(0)
+
+                # REBALANCE
+                rebalanced1 = df.hint("REBALANCE", "a")  # just check this 
doesn't error
+                rebalanced1.explain(True)
+                rebalanced2 = df.hint("REBALANCE", 2)
+                rebalanced2.explain(True)
+                rebalanced3 = df.hint("REBALANCE", 2, "a")
+                rebalanced3.explain(True)
+                rebalanced4 = df.hint("REBALANCE", functions.col("a"))
+                rebalanced4.explain(True)
+                output = buf.getvalue()
+                self.assertGreaterEqual(output.count("REBALANCE_PARTITIONS_BY_NONE"), 1)
+                self.assertGreaterEqual(output.count("REBALANCE_PARTITIONS_BY_COL"), 3)
+
     # add tests for SPARK-23647 (test more types for hint)
     def test_extended_hint_types(self):
         df = self.spark.range(10e10).toDF("id")
         such_a_nice_list = ["itworks1", "itworks2", "itworks3"]
-        hinted_df = df.hint("my awesome hint", 1.2345, "what", 
such_a_nice_list)
+        int_list = [1, 2, 3]
+        hinted_df = df.hint("my awesome hint", 1.2345, "what", 
such_a_nice_list, int_list)
 
         self.assertIsInstance(df.hint("broadcast", []), type(df))
         self.assertIsInstance(df.hint("broadcast", ["foo", "bar"]), type(df))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
index 46ebffea1ae..bb8678ebe25 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
@@ -21,7 +21,7 @@ import java.util.Locale
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, IntegerLiteral, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, IntegerLiteral, SortOrder, StringLiteral}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin
@@ -153,7 +153,7 @@ object ResolveHints {
         } else {
           // Otherwise, find within the subtree query plans to apply the hint.
           val relationNamesInHint = h.parameters.map {
-            case tableName: String => UnresolvedAttribute.parseAttributeName(tableName)
+            case StringLiteral(tableName) => UnresolvedAttribute.parseAttributeName(tableName)
             case tableId: UnresolvedAttribute => tableId.nameParts
             case unsupported =>
              throw QueryCompilationErrors.joinStrategyHintParameterNotSupportedError(unsupported)
@@ -204,16 +204,12 @@ object ResolveHints {
       hint.parameters match {
         case Seq(IntegerLiteral(numPartitions)) =>
           Repartition(numPartitions, shuffle, hint.child)
-        case Seq(numPartitions: Int) =>
-          Repartition(numPartitions, shuffle, hint.child)
         // The "COALESCE" hint (shuffle = false) must have a partition number 
only
         case _ if !shuffle =>
          throw QueryCompilationErrors.invalidCoalesceHintParameterError(hintName)
 
         case param @ Seq(IntegerLiteral(numPartitions), _*) if shuffle =>
           createRepartitionByExpression(Some(numPartitions), param.tail)
-        case param @ Seq(numPartitions: Int, _*) if shuffle =>
-          createRepartitionByExpression(Some(numPartitions), param.tail)
         case param @ Seq(_*) if shuffle =>
           createRepartitionByExpression(None, param)
       }
@@ -242,8 +238,6 @@ object ResolveHints {
       hint.parameters match {
         case param @ Seq(IntegerLiteral(numPartitions), _*) =>
           createRepartitionByExpression(Some(numPartitions), param.tail)
-        case param @ Seq(numPartitions: Int, _*) =>
-          createRepartitionByExpression(Some(numPartitions), param.tail)
         case param @ Seq(_*) =>
           createRepartitionByExpression(None, param)
       }
@@ -266,24 +260,32 @@ object ResolveHints {
       hint.parameters match {
         case param @ Seq(IntegerLiteral(numPartitions), _*) =>
           createRebalancePartitions(param.tail, Some(numPartitions))
-        case param @ Seq(numPartitions: Int, _*) =>
-          createRebalancePartitions(param.tail, Some(numPartitions))
         case partitionExprs @ Seq(_*) =>
           createRebalancePartitions(partitionExprs, None)
       }
     }
 
+    private def transformStringToAttribute(hint: UnresolvedHint): UnresolvedHint = {
+      // for all the coalesce hints, it's safe to transform the string literal to an attribute as
+      // all the parameters should be column names.
+      val parameters = hint.parameters.map {
+        case StringLiteral(name) => UnresolvedAttribute(name)
+        case e => e
+      }
+      hint.copy(parameters = parameters)
+    }
+
     def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
       _.containsPattern(UNRESOLVED_HINT), ruleId) {
       case hint @ UnresolvedHint(hintName, _, _) => hintName.toUpperCase(Locale.ROOT) match {
           case "REPARTITION" =>
-            createRepartition(shuffle = true, hint)
+            createRepartition(shuffle = true, transformStringToAttribute(hint))
           case "COALESCE" =>
-            createRepartition(shuffle = false, hint)
+            createRepartition(shuffle = false, transformStringToAttribute(hint))
           case "REPARTITION_BY_RANGE" =>
-            createRepartitionByRange(hint)
+            createRepartitionByRange(transformStringToAttribute(hint))
           case "REBALANCE" if conf.adaptiveExecutionEnabled =>
-            createRebalance(hint)
+            createRebalance(transformStringToAttribute(hint))
           case _ => hint
         }
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 08fda363905..559822369d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -521,8 +521,9 @@ package object dsl {
         EliminateSubqueryAliases(analyzed)
       }
 
-      def hint(name: String, parameters: Any*): LogicalPlan =
+      def hint(name: String, parameters: Expression*): LogicalPlan = {
         UnresolvedHint(name, parameters, logicalPlan)
+      }
 
       def sample(
           lowerBound: Double,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
index b17bab7849b..ff7c79fbe89 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_HINT}
 
 /**
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_
  * @param parameters the parameters of the hint
  * @param child the [[LogicalPlan]] on which this hint applies
  */
-case class UnresolvedHint(name: String, parameters: Seq[Any], child: LogicalPlan)
+case class UnresolvedHint(name: String, parameters: Seq[Expression], child: LogicalPlan)
   extends UnaryNode {
 
   // we need it to be resolved so that the analyzer can continue to analyze the rest of the query
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DSLHintSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DSLHintSuite.scala
index 358346bfa2b..ab5f35c61b1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DSLHintSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DSLHintSuite.scala
@@ -42,10 +42,5 @@ class DSLHintSuite extends AnalysisTest {
       r1.hint("hint1", 1, $"a"),
       UnresolvedHint("hint1", Seq(1, $"a"), r1)
     )
-
-    comparePlans(
-      r1.hint("hint1", Seq(1, 2, 3), Seq($"a", $"b", $"c")),
-      UnresolvedHint("hint1", Seq(Seq(1, 2, 3), Seq($"a", $"b", $"c")), r1)
-    )
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
index 50bbe8be916..fa02e77975b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
@@ -352,11 +352,11 @@ class ResolveHintsSuite extends AnalysisTest {
     withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3") {
       Seq(
         Nil -> 3,
-        Seq(1) -> 1,
+        Seq(Literal(1)) -> 1,
         Seq(UnresolvedAttribute("a")) -> 3,
-        Seq(1, UnresolvedAttribute("a")) -> 1).foreach { case (param, 
initialNumPartitions) =>
+        Seq(Literal(1), UnresolvedAttribute("a")) -> 1).foreach { case (param, 
numberPartitions) =>
         assert(UnresolvedHint("REBALANCE", param, testRelation).analyze
-          .asInstanceOf[RebalancePartitions].partitioning.numPartitions == initialNumPartitions)
+          .asInstanceOf[RebalancePartitions].partitioning.numPartitions == numberPartitions)
       }
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index fd8421fa096..2d9e0e231d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1401,12 +1401,29 @@ class Dataset[T] private[sql](
    *   df1.join(df2.hint("broadcast"))
    * }}}
    *
+   * the following code specifies that this dataset could be rebalanced with given number of
+   * partitions:
+   *
+   * {{{
+   *    df1.hint("rebalance", 10)
+   * }}}
+   *
+   * @param name the name of the hint
+   * @param parameters the parameters of the hint, all the parameters should be a `Column` or
+   *                   `Expression` or `Symbol` or could be converted into a `Literal`
+   *
    * @group basic
    * @since 2.2.0
    */
   @scala.annotation.varargs
   def hint(name: String, parameters: Any*): Dataset[T] = withTypedPlan {
-    UnresolvedHint(name, parameters, logicalPlan)
+    val exprs = parameters.map {
+      case c: Column => c.expr
+      case s: Symbol => Column(s.name).expr
+      case e: Expression => e
+      case literal => Literal(literal)
+    }.toSeq
+    UnresolvedHint(name, exprs, logicalPlan)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala
index 37dc8f1bcc7..6fdee8613d4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala
@@ -18,7 +18,9 @@
 package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.analysis.AnalysisTest
+import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.functions.array
 import org.apache.spark.sql.test.SharedSparkSession
 
 class DataFrameHintSuite extends AnalysisTest with SharedSparkSession {
@@ -42,36 +44,36 @@ class DataFrameHintSuite extends AnalysisTest with SharedSparkSession {
 
     check(
       df.hint("hint1", 1, "a"),
-      UnresolvedHint("hint1", Seq(1, "a"), df.logicalPlan)
+      UnresolvedHint("hint1", Seq(Literal(1), Literal("a")), df.logicalPlan)
     )
 
     check(
       df.hint("hint1", 1, $"a"),
-      UnresolvedHint("hint1", Seq(1, $"a"),
+      UnresolvedHint("hint1", Seq(Literal(1), $"a".expr),
         df.logicalPlan
       )
     )
 
     check(
-      df.hint("hint1", Seq(1, 2, 3), Seq($"a", $"b", $"c")),
-      UnresolvedHint("hint1", Seq(Seq(1, 2, 3), Seq($"a", $"b", $"c")),
+      df.hint("hint1", Array(1, 2, 3), array($"a", $"b", $"c")),
+      UnresolvedHint("hint1", Seq(Literal(Array(1, 2, 3)), array($"a", $"b", 
$"c").expr),
         df.logicalPlan
       )
     )
   }
 
-  test("coalesce and repartition hint") {
+  test("coalesce, repartition and rebalance hint") {
     check(
       df.hint("COALESCE", 10),
-      UnresolvedHint("COALESCE", Seq(10), df.logicalPlan))
+      UnresolvedHint("COALESCE", Seq(Literal(10)), df.logicalPlan))
 
     check(
       df.hint("REPARTITION", 100),
-      UnresolvedHint("REPARTITION", Seq(100), df.logicalPlan))
+      UnresolvedHint("REPARTITION", Seq(Literal(100)), df.logicalPlan))
 
     check(
       df.hint("REPARTITION", 10, $"id".expr),
-      UnresolvedHint("REPARTITION", Seq(10, $"id".expr), df.logicalPlan))
+      UnresolvedHint("REPARTITION", Seq(Literal(10), $"id".expr), 
df.logicalPlan))
 
     check(
       df.hint("REPARTITION_BY_RANGE", $"id".expr),
@@ -79,6 +81,16 @@ class DataFrameHintSuite extends AnalysisTest with SharedSparkSession {
 
     check(
       df.hint("REPARTITION_BY_RANGE", 10, $"id".expr),
-      UnresolvedHint("REPARTITION_BY_RANGE", Seq(10, $"id".expr), 
df.logicalPlan))
+      UnresolvedHint("REPARTITION_BY_RANGE", Seq(Literal(10), $"id".expr), 
df.logicalPlan))
+
+    // simple column name should be accepted
+    check(
+      df.hint("REBALANCE", 10, "id"),
+      UnresolvedHint("REBALANCE", Seq(Literal(10), Literal("id")), 
df.logicalPlan))
+
+    check(
+      df.hint("REBALANCE", 10, $"id".expr),
+      UnresolvedHint("REBALANCE", Seq(Literal(10), $"id".expr), 
df.logicalPlan))
   }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
