This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 3d5a7d90712 [SPARK-40178][SQL][CONNECT] Support coalesce hints with ease for PySpark and R
3d5a7d90712 is described below

commit 3d5a7d9071228affd5e06ff03091fc7525031ccc
Author: Xianjin <xian...@apache.org>
AuthorDate: Mon Aug 21 09:28:12 2023 +0900

    [SPARK-40178][SQL][CONNECT] Support coalesce hints with ease for PySpark and R

    ### What changes were proposed in this pull request?
    1. Refactor `UnresolvedHint` to accept only `Expression`s as parameters.
    2. `ResolveHints` now parses a `StringLiteral` as an `UnresolvedAttribute`, which allows users to specify column names as plain strings.
    3. The `hint` method in `Dataset` now treats all of its parameters as `Column`s or `Literal`s; all other values are rejected. The method signature is kept for compatibility and ease of use, and it now matches how the hint method is handled in the Connect module.
    4. Connect: PySpark Connect now accepts `Column` as a hint parameter.
    5. PySpark: allows `Column` as a hint parameter and tightens the input parameter type check: for list input, only lists of primitive values are now allowed.
    6. SparkR: allows `Column` as a hint parameter, with a corresponding test.

    ### Why are the changes needed?
    This is a rework of #37616. Before this commit, there was no way for users to pass column references directly to PySpark's hint method; in other words, a `rebalance` hint that requires column references was not possible before this PR.

    ### Does this PR introduce _any_ user-facing change?
    Yes. PySpark and SparkR users may now specify rebalance and repartition hints with ease.

    ### How was this patch tested?
    Added UTs.

    Closes #42255 from advancedxy/SPARK-40178.

    Lead-authored-by: Xianjin <xian...@apache.org>
    Co-authored-by: Xianjin YE <xian...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
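To make the user-facing change concrete, here is a minimal PySpark sketch of the
calls this commit enables (the DataFrame and column name are illustrative, not
taken from the patch):

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([("A", 1), ("B", 2)], ["a", "n"])

    # A Column as a hint parameter was previously rejected by the type check.
    df.hint("rebalance", col("a"))

    # For coalesce-family hints, a plain string is now parsed as a column reference.
    df.hint("rebalance", 2, "a")

    # Partition number plus column also works for range repartitioning.
    df.hint("repartition_by_range", 2, col("a"))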
---
 R/pkg/R/DataFrame.R                                | 11 ++++-
 R/pkg/tests/fulltests/test_sparkSQL.R              |  6 +++
 .../sql/connect/planner/SparkConnectPlanner.scala  | 16 +-------
 python/pyspark/sql/connect/dataframe.py            | 21 ++++++++--
 python/pyspark/sql/connect/plan.py                 |  2 +-
 python/pyspark/sql/dataframe.py                    | 48 ++++++++++++++++++----
 .../sql/tests/connect/test_connect_basic.py        |  2 +-
 python/pyspark/sql/tests/test_dataframe.py         | 41 +++++++++++++++++-
 .../spark/sql/catalyst/analysis/ResolveHints.scala | 30 +++++++-------
 .../apache/spark/sql/catalyst/dsl/package.scala    |  3 +-
 .../spark/sql/catalyst/plans/logical/hints.scala   |  4 +-
 .../spark/sql/catalyst/analysis/DSLHintSuite.scala |  5 ---
 .../sql/catalyst/analysis/ResolveHintsSuite.scala  |  6 +--
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 19 ++++++++-
 .../org/apache/spark/sql/DataFrameHintSuite.scala  | 30 ++++++++++----
 15 files changed, 178 insertions(+), 66 deletions(-)

diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 3f9bc9cb6d0..8650df64fe5 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -4146,7 +4146,7 @@ setMethod("hint",
           function(x, name, ...) {
             parameters <- list(...)
             if (!all(sapply(parameters, function(y) {
-              if (is.character(y) || is.numeric(y)) {
+              if (is.character(y) || is.numeric(y) || is(class(y), "characterOrColumn")) {
                 TRUE
               } else if (is.list(y)) {
                 all(sapply(y, function(z) { is.character(z) || is.numeric(z) }))
@@ -4156,7 +4156,14 @@ setMethod("hint",
             }))) {
              stop("sql hint should be character, numeric, or list with character or numeric.")
             }
-            jdf <- callJMethod(x@sdf, "hint", name, parameters)
+            jparams <- lapply(parameters, function(c) {
+              if (is.character(c) || is.numeric(c) || is.list(c)) {
+                c
+              } else {
+                c@jc
+              }
+            })
+            jdf <- callJMethod(x@sdf, "hint", name, jparams)
             dataFrame(jdf)
           })

diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index 47688d7560c..1730c7be7f5 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -2808,6 +2808,12 @@ test_that("test hint", {
     explain(hint(df, "hint1", 1.23456, "aaaaaaaaaa", hintList), TRUE)
   )
   expect_true(any(grepl("1.23456, aaaaaaaaaa", execution_plan_hint)))
+
+  a <- column("id")
+  rebalance_plan_hint <- capture.output(
+    explain(hint(df, "rebalance", as.integer(2), a), TRUE)
+  )
+  expect_true(any(grepl("RebalancePartitions", rebalance_plan_hint)))
 })

 test_that("toJSON() on DataFrame", {

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 69990af8187..c23fd35acc7 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -994,21 +994,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {

   private def transformHint(rel: proto.Hint): LogicalPlan = {
-    def extractValue(expr: Expression): Any = {
-      expr match {
-        case Literal(s, StringType) if s != null =>
-          UnresolvedAttribute.quotedString(s.toString)
-        case literal: Literal => literal.value
-        case UnresolvedFunction(Seq("array"), arguments, _, _, _) =>
-          arguments.map(extractValue).toArray
-        case other =>
-          throw InvalidPlanInput(
-            s"Expression should be a Literal or CreateMap or CreateArray, " +
-              s"but got ${other.getClass} $other")
-      }
-    }
-
-    val params = rel.getParametersList.asScala.toSeq.map(transformExpression).map(extractValue)
+    val params = rel.getParametersList.asScala.toSeq.map(transformExpression)
     UnresolvedHint(rel.getName, params, transformRelation(rel.getInput))
   }
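With the planner change above, the Connect server now forwards hint parameters
to `UnresolvedHint` as expressions instead of eagerly extracting literal values.
From the client side, this is what makes calls like the following possible over
Spark Connect (a sketch; the remote URL is a placeholder):

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    df = spark.range(10)

    # The Column is serialized as an expression and resolved on the server.
    df.hint("rebalance", col("id")).explain()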
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 7b326538a8e..16c42637fbd 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -879,7 +879,7 @@ class DataFrame:
     withWatermark.__doc__ = PySparkDataFrame.withWatermark.__doc__

     def hint(
-        self, name: str, *parameters: Union["PrimitiveType", List["PrimitiveType"]]
+        self, name: str, *parameters: Union["PrimitiveType", "Column", List["PrimitiveType"]]
     ) -> "DataFrame":
         if len(parameters) == 1 and isinstance(parameters[0], list):
             parameters = parameters[0]  # type: ignore[assignment]
@@ -890,17 +890,32 @@ class DataFrame:
                 message_parameters={"arg_name": "name", "arg_type": type(name).__name__},
             )

-        allowed_types = (str, list, float, int)
+        allowed_types = (str, float, int, Column, list)
+        allowed_primitive_types = (str, float, int)
+        allowed_types_repr = ", ".join(
+            [t.__name__ for t in allowed_types[:-1]]
+            + ["list[" + t.__name__ + "]" for t in allowed_primitive_types]
+        )
         for p in parameters:
             if not isinstance(p, allowed_types):
                 raise PySparkTypeError(
                     error_class="INVALID_ITEM_FOR_CONTAINER",
                     message_parameters={
                         "arg_name": "parameters",
-                        "allowed_types": ", ".join([t.__name__ for t in allowed_types]),
+                        "allowed_types": allowed_types_repr,
                         "item_type": type(p).__name__,
                     },
                 )
+            if isinstance(p, list):
+                if not all(isinstance(e, allowed_primitive_types) for e in p):
+                    raise PySparkTypeError(
+                        error_class="INVALID_ITEM_FOR_CONTAINER",
+                        message_parameters={
+                            "arg_name": "parameters",
+                            "allowed_types": allowed_types_repr,
+                            "item_type": type(p).__name__ + "[" + type(p[0]).__name__ + "]",
+                        },
+                    )

         return DataFrame.withPlan(
             plan.Hint(self._plan, name, list(parameters)),

diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 7da93ef413c..e21a0fab000 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -579,7 +579,7 @@ class Hint(LogicalPlan):
         self._name = name

         for param in parameters:
-            assert isinstance(param, (list, str, float, int))
+            assert isinstance(param, (list, str, float, int, Column))
             if isinstance(param, list):
                 assert all(isinstance(p, (str, float, int)) for p in param)
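The tightened check in the two Connect files above accepts a list only when
every element is a primitive; anything else is rejected on the client side. A
sketch of the resulting behavior (hint names and the `df` below are arbitrary
examples, assuming an existing session):

    from pyspark.sql.functions import col

    df = spark.range(10)

    df.hint("my hint", [1, 2, 3])        # accepted: homogeneous primitive list
    df.hint("rebalance", col("id"))      # accepted: Column parameter
    df.hint("my hint", {"k": "v"})       # raises PySparkTypeError (INVALID_ITEM_FOR_CONTAINER)
    df.hint("my hint", [1, col("id")])   # raises: list elements must be str, float, or int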
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 03aaee8f2ec..0999d365cf0 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -60,7 +60,7 @@ from pyspark.sql.types import (
     Row,
     _parse_datatype_json_string,
 )
-from pyspark.sql.utils import get_active_spark_context
+from pyspark.sql.utils import get_active_spark_context, toJArray
 from pyspark.sql.pandas.conversion import PandasConversionMixin
 from pyspark.sql.pandas.map_ops import PandasMapOpsMixin

@@ -1117,7 +1117,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         return DataFrame(jdf, self.sparkSession)

     def hint(
-        self, name: str, *parameters: Union["PrimitiveType", List["PrimitiveType"]]
+        self, name: str, *parameters: Union["PrimitiveType", "Column", List["PrimitiveType"]]
     ) -> "DataFrame":
         """Specifies some hint on the current :class:`DataFrame`.
@@ -1165,20 +1165,54 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                 message_parameters={"arg_name": "name", "arg_type": type(name).__name__},
             )

-        allowed_types = (str, list, float, int)
+        allowed_types = (str, float, int, Column, list)
+        allowed_primitive_types = (str, float, int)
+        allowed_types_repr = ", ".join(
+            [t.__name__ for t in allowed_types[:-1]]
+            + ["list[" + t.__name__ + "]" for t in allowed_primitive_types]
+        )
         for p in parameters:
             if not isinstance(p, allowed_types):
                 raise PySparkTypeError(
                     error_class="DISALLOWED_TYPE_FOR_CONTAINER",
                     message_parameters={
                         "arg_name": "parameters",
-                        "arg_type": type(parameters).__name__,
-                        "allowed_types": ", ".join(map(lambda x: x.__name__, allowed_types)),
-                        "return_type": type(p).__name__,
+                        "allowed_types": allowed_types_repr,
+                        "item_type": type(p).__name__,
                     },
                 )
+            if isinstance(p, list):
+                if not all(isinstance(e, allowed_primitive_types) for e in p):
+                    raise PySparkTypeError(
+                        error_class="DISALLOWED_TYPE_FOR_CONTAINER",
+                        message_parameters={
+                            "arg_name": "parameters",
+                            "allowed_types": allowed_types_repr,
+                            "item_type": type(p).__name__ + "[" + type(p[0]).__name__ + "]",
+                        },
+                    )
+
+        def _converter(parameter: Union[str, list, float, int, Column]) -> Any:
+            if isinstance(parameter, Column):
+                return _to_java_column(parameter)
+            elif isinstance(parameter, list):
+                # for list input, we assume only one element type exists in the list.
+                # an empty list is converted into an empty long[] on the JVM side.
+                gateway = self._sc._gateway
+                assert gateway is not None
+                jclass = gateway.jvm.long
+                if len(parameter) >= 1:
+                    mapping = {
+                        str: gateway.jvm.java.lang.String,
+                        float: gateway.jvm.double,
+                        int: gateway.jvm.long,
+                    }
+                    jclass = mapping[type(parameter[0])]
+                return toJArray(gateway, jclass, parameter)
+            else:
+                return parameter

-        jdf = self._jdf.hint(name, self._jseq(parameters))
+        jdf = self._jdf.hint(name, self._jseq(parameters, _converter))
         return DataFrame(jdf, self.sparkSession)

     def count(self) -> int:
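With the `_converter` added above, a `Column` is unwrapped to its backing Java
column, a homogeneous primitive list becomes a typed Java array (an empty list
becomes an empty long[]), and bare primitives pass through for the JVM side to
wrap as `Literal`s. A short usage sketch mirroring the tests that follow:

    from pyspark.sql.functions import col

    df = spark.range(10).toDF("id")

    # The int list crosses Py4J as long[] and surfaces as a single array literal.
    df.hint("my awesome hint", 1.2345, "what", [1, 2, 3]).explain(True)

    # The Column is unwrapped, enabling column-based hints such as REBALANCE.
    df.hint("rebalance", col("id")).explain(True)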
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 9e8f5623971..aa27217be03 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1900,7 +1900,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             error_class="INVALID_ITEM_FOR_CONTAINER",
             message_parameters={
                 "arg_name": "parameters",
-                "allowed_types": "str, list, float, int",
+                "allowed_types": "str, float, int, Column, list[str], list[float], list[int]",
                 "item_type": "dict",
             },
         )

diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 33049233dee..868d3f6f0aa 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -27,7 +27,7 @@ import io
 from contextlib import redirect_stdout

 from pyspark import StorageLevel
-from pyspark.sql import SparkSession, Row
+from pyspark.sql import SparkSession, Row, functions
 from pyspark.sql.functions import col, lit, count, sum, mean, struct
 from pyspark.sql.pandas.utils import pyarrow_version_less_than_minimum
 from pyspark.sql.types import (
@@ -645,11 +645,48 @@ class DataFrameTestsMixin:
             df1.join(df2.hint("broadcast"), "id").explain(True)
         self.assertEqual(1, buf.getvalue().count("BroadcastHashJoin"))

+    def test_coalesce_hints_with_string_parameter(self):
+        with self.sql_conf({"spark.sql.adaptive.coalescePartitions.enabled": False}):
+            df = self.spark.createDataFrame(
+                zip(["A", "B"] * 2**9, range(2**10)),
+                StructType([StructField("a", StringType()), StructField("n", IntegerType())]),
+            )
+            with io.StringIO() as buf, redirect_stdout(buf):
+                # COALESCE
+                coalesce = df.hint("coalesce", 2)
+                coalesce.explain(True)
+                output = buf.getvalue()
+                self.assertGreaterEqual(output.count("Coalesce 2"), 1)
+                buf.truncate(0)
+                buf.seek(0)
+
+                # REPARTITION_BY_RANGE
+                range_partitioned = df.hint("REPARTITION_BY_RANGE", 2, "a")
+                range_partitioned.explain(True)
+                output = buf.getvalue()
+                self.assertGreaterEqual(output.count("REPARTITION_BY_NUM"), 1)
+                buf.truncate(0)
+                buf.seek(0)
+
+                # REBALANCE
+                rebalanced1 = df.hint("REBALANCE", "a")  # just check this doesn't error
+                rebalanced1.explain(True)
+                rebalanced2 = df.hint("REBALANCE", 2)
+                rebalanced2.explain(True)
+                rebalanced3 = df.hint("REBALANCE", 2, "a")
+                rebalanced3.explain(True)
+                rebalanced4 = df.hint("REBALANCE", functions.col("a"))
+                rebalanced4.explain(True)
+                output = buf.getvalue()
+                self.assertGreaterEqual(output.count("REBALANCE_PARTITIONS_BY_NONE"), 1)
+                self.assertGreaterEqual(output.count("REBALANCE_PARTITIONS_BY_COL"), 3)
+
     # add tests for SPARK-23647 (test more types for hint)
     def test_extended_hint_types(self):
         df = self.spark.range(10e10).toDF("id")
         such_a_nice_list = ["itworks1", "itworks2", "itworks3"]
-        hinted_df = df.hint("my awesome hint", 1.2345, "what", such_a_nice_list)
+        int_list = [1, 2, 3]
+        hinted_df = df.hint("my awesome hint", 1.2345, "what", such_a_nice_list, int_list)
         self.assertIsInstance(df.hint("broadcast", []), type(df))
         self.assertIsInstance(df.hint("broadcast", ["foo", "bar"]), type(df))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
index 46ebffea1ae..bb8678ebe25 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
@@ -21,7 +21,7 @@ import java.util.Locale

 import scala.collection.mutable

-import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, IntegerLiteral, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, IntegerLiteral, SortOrder, StringLiteral}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin
@@ -153,7 +153,7 @@ object ResolveHints {
       } else {
         // Otherwise, find within the subtree query plans to apply the hint.
         val relationNamesInHint = h.parameters.map {
-          case tableName: String => UnresolvedAttribute.parseAttributeName(tableName)
+          case StringLiteral(tableName) => UnresolvedAttribute.parseAttributeName(tableName)
           case tableId: UnresolvedAttribute => tableId.nameParts
           case unsupported =>
             throw QueryCompilationErrors.joinStrategyHintParameterNotSupportedError(unsupported)
@@ -204,16 +204,12 @@ object ResolveHints {
       hint.parameters match {
         case Seq(IntegerLiteral(numPartitions)) =>
           Repartition(numPartitions, shuffle, hint.child)
-        case Seq(numPartitions: Int) =>
-          Repartition(numPartitions, shuffle, hint.child)
         // The "COALESCE" hint (shuffle = false) must have a partition number only
         case _ if !shuffle =>
           throw QueryCompilationErrors.invalidCoalesceHintParameterError(hintName)

         case param @ Seq(IntegerLiteral(numPartitions), _*) if shuffle =>
           createRepartitionByExpression(Some(numPartitions), param.tail)
-        case param @ Seq(numPartitions: Int, _*) if shuffle =>
-          createRepartitionByExpression(Some(numPartitions), param.tail)
         case param @ Seq(_*) if shuffle =>
           createRepartitionByExpression(None, param)
       }
@@ -242,8 +238,6 @@ object ResolveHints {
       hint.parameters match {
         case param @ Seq(IntegerLiteral(numPartitions), _*) =>
           createRepartitionByExpression(Some(numPartitions), param.tail)
-        case param @ Seq(numPartitions: Int, _*) =>
-          createRepartitionByExpression(Some(numPartitions), param.tail)
         case param @ Seq(_*) =>
           createRepartitionByExpression(None, param)
       }
@@ -266,24 +260,32 @@ object ResolveHints {
       hint.parameters match {
         case param @ Seq(IntegerLiteral(numPartitions), _*) =>
           createRebalancePartitions(param.tail, Some(numPartitions))
-        case param @ Seq(numPartitions: Int, _*) =>
-          createRebalancePartitions(param.tail, Some(numPartitions))
         case partitionExprs @ Seq(_*) =>
           createRebalancePartitions(partitionExprs, None)
       }
     }

+    private def transformStringToAttribute(hint: UnresolvedHint): UnresolvedHint = {
+      // for all the coalesce hints, it's safe to transform the string literal to an attribute as
+      // all the parameters should be column names.
+      val parameters = hint.parameters.map {
+        case StringLiteral(name) => UnresolvedAttribute(name)
+        case e => e
+      }
+      hint.copy(parameters = parameters)
+    }
+
     def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
       _.containsPattern(UNRESOLVED_HINT), ruleId) {
       case hint @ UnresolvedHint(hintName, _, _) => hintName.toUpperCase(Locale.ROOT) match {
         case "REPARTITION" =>
-          createRepartition(shuffle = true, hint)
+          createRepartition(shuffle = true, transformStringToAttribute(hint))
         case "COALESCE" =>
-          createRepartition(shuffle = false, hint)
+          createRepartition(shuffle = false, transformStringToAttribute(hint))
         case "REPARTITION_BY_RANGE" =>
-          createRepartitionByRange(hint)
+          createRepartitionByRange(transformStringToAttribute(hint))
         case "REBALANCE" if conf.adaptiveExecutionEnabled =>
-          createRebalance(hint)
+          createRebalance(transformStringToAttribute(hint))
         case _ => hint
       }
     }
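Because `transformStringToAttribute` above rewrites string literals to
`UnresolvedAttribute`s for the coalesce-family hints (REPARTITION, COALESCE,
REPARTITION_BY_RANGE, and REBALANCE when AQE is enabled), a plain string and an
explicit column now behave identically there, while join-strategy hints
continue to treat string parameters as relation names. A sketch (assuming an
existing `df` with column `a`):

    from pyspark.sql.functions import col

    # Both forms resolve "a" as a column reference when rebalancing:
    df.hint("rebalance", "a")
    df.hint("rebalance", col("a"))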
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 08fda363905..559822369d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -521,8 +521,9 @@ package object dsl {
         EliminateSubqueryAliases(analyzed)
       }

-      def hint(name: String, parameters: Any*): LogicalPlan =
+      def hint(name: String, parameters: Expression*): LogicalPlan = {
         UnresolvedHint(name, parameters, logicalPlan)
+      }

       def sample(
           lowerBound: Double,

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
index b17bab7849b..ff7c79fbe89 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.catalyst.plans.logical

-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_HINT}

 /**
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_
  * @param parameters the parameters of the hint
  * @param child the [[LogicalPlan]] on which this hint applies
  */
-case class UnresolvedHint(name: String, parameters: Seq[Any], child: LogicalPlan)
+case class UnresolvedHint(name: String, parameters: Seq[Expression], child: LogicalPlan)
   extends UnaryNode {

   // we need it to be resolved so that the analyzer can continue to analyze the rest of the query

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DSLHintSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DSLHintSuite.scala
index 358346bfa2b..ab5f35c61b1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DSLHintSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DSLHintSuite.scala
@@ -42,10 +42,5 @@ class DSLHintSuite extends AnalysisTest {
       r1.hint("hint1", 1, $"a"),
       UnresolvedHint("hint1", Seq(1, $"a"), r1)
     )
-
-    comparePlans(
-      r1.hint("hint1", Seq(1, 2, 3), Seq($"a", $"b", $"c")),
-      UnresolvedHint("hint1", Seq(Seq(1, 2, 3), Seq($"a", $"b", $"c")), r1)
-    )
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
index 50bbe8be916..fa02e77975b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
@@ -352,11 +352,11 @@ class ResolveHintsSuite extends AnalysisTest {
     withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3") {
       Seq(
         Nil -> 3,
-        Seq(1) -> 1,
+        Seq(Literal(1)) -> 1,
         Seq(UnresolvedAttribute("a")) -> 3,
-        Seq(1, UnresolvedAttribute("a")) -> 1).foreach { case (param, initialNumPartitions) =>
+        Seq(Literal(1), UnresolvedAttribute("a")) -> 1).foreach { case (param, numberPartitions) =>
         assert(UnresolvedHint("REBALANCE", param, testRelation).analyze
-          .asInstanceOf[RebalancePartitions].partitioning.numPartitions == initialNumPartitions)
+          .asInstanceOf[RebalancePartitions].partitioning.numPartitions == numberPartitions)
       }
     }
   }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index fd8421fa096..2d9e0e231d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1401,12 +1401,29 @@ class Dataset[T] private[sql](
    *   df1.join(df2.hint("broadcast"))
    * }}}
    *
+   * The following code specifies that this Dataset could be rebalanced with the given number
+   * of partitions:
+   *
+   * {{{
+   *   df1.hint("rebalance", 10)
+   * }}}
+   *
+   * @param name the name of the hint
+   * @param parameters the parameters of the hint; each parameter should be a `Column`, an
+   *                   `Expression`, a `Symbol`, or a value that can be converted into a `Literal`
+   *
    * @group basic
    * @since 2.2.0
    */
   @scala.annotation.varargs
   def hint(name: String, parameters: Any*): Dataset[T] = withTypedPlan {
-    UnresolvedHint(name, parameters, logicalPlan)
+    val exprs = parameters.map {
+      case c: Column => c.expr
+      case s: Symbol => Column(s.name).expr
+      case e: Expression => e
+      case literal => Literal(literal)
+    }.toSeq
+    UnresolvedHint(name, exprs, logicalPlan)
   }
UnresolvedHint("hint1", Seq(Literal(Array(1, 2, 3)), array($"a", $"b", $"c").expr), df.logicalPlan ) ) } - test("coalesce and repartition hint") { + test("coalesce, repartition and rebalance hint") { check( df.hint("COALESCE", 10), - UnresolvedHint("COALESCE", Seq(10), df.logicalPlan)) + UnresolvedHint("COALESCE", Seq(Literal(10)), df.logicalPlan)) check( df.hint("REPARTITION", 100), - UnresolvedHint("REPARTITION", Seq(100), df.logicalPlan)) + UnresolvedHint("REPARTITION", Seq(Literal(100)), df.logicalPlan)) check( df.hint("REPARTITION", 10, $"id".expr), - UnresolvedHint("REPARTITION", Seq(10, $"id".expr), df.logicalPlan)) + UnresolvedHint("REPARTITION", Seq(Literal(10), $"id".expr), df.logicalPlan)) check( df.hint("REPARTITION_BY_RANGE", $"id".expr), @@ -79,6 +81,16 @@ class DataFrameHintSuite extends AnalysisTest with SharedSparkSession { check( df.hint("REPARTITION_BY_RANGE", 10, $"id".expr), - UnresolvedHint("REPARTITION_BY_RANGE", Seq(10, $"id".expr), df.logicalPlan)) + UnresolvedHint("REPARTITION_BY_RANGE", Seq(Literal(10), $"id".expr), df.logicalPlan)) + + // simple column name should be accepted + check( + df.hint("REBALANCE", 10, "id"), + UnresolvedHint("REBALANCE", Seq(Literal(10), Literal("id")), df.logicalPlan)) + + check( + df.hint("REBALANCE", 10, $"id".expr), + UnresolvedHint("REBALANCE", Seq(Literal(10), $"id".expr), df.logicalPlan)) } + } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org