This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch branch-3.1 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.1 by this push: new 9268392 [SPARK-33945][SQL][3.1] Handles a random seed consisting of an expr tree 9268392 is described below commit 9268392b957b263692e13fecaf9adec2136e1865 Author: Takeshi Yamamuro <yamam...@apache.org> AuthorDate: Sun Jan 3 21:36:25 2021 -0800 [SPARK-33945][SQL][3.1] Handles a random seed consisting of an expr tree ### What changes were proposed in this pull request? This PR fixes a minor bug where an analysis exception is thrown when the seed parameter of `rand`/`randn` is an expression tree (e.g., `rand(1 + 1)`) and constant folding (`ConstantFolding` and `ReorderAssociativeOperator`) is disabled. A query that reproduces this issue is as follows: ``` // v3.1.0, v3.0.2, and v2.4.8 $./bin/spark-shell scala> sql("select rand(1 + 2)").show() +-------------------+ | rand((1 + 2))| +-------------------+ |0.25738143505962285| +-------------------+ $./bin/spark-shell --conf spark.sql.optimizer.excludedRules=org.apache.spark.sql.catalyst.optimizer.ConstantFolding,org.apache.spark.sql.catalyst.optimizer.ReorderAssociativeOperator scala> sql("select rand(1 + 2)").show() org.apache.spark.sql.AnalysisException: Input argument to rand must be an integer, long or null literal.; at org.apache.spark.sql.catalyst.expressions.RDG.seed$lzycompute(randomExpressions.scala:49) at org.apache.spark.sql.catalyst.expressions.RDG.seed(randomExpressions.scala:46) at org.apache.spark.sql.catalyst.expressions.Rand.doGenCode(randomExpressions.scala:98) at org.apache.spark.sql.catalyst.expressions.Expression.$anonfun$genCode$3(Expression.scala:146) at scala.Option.getOrElse(Option.scala:189) ... ``` The root cause is that the match-case code below only accepts a `Literal` seed, so it cannot handle an unfolded expression such as `1 + 2`: https://github.com/apache/spark/blob/42f5e62403469cec6da680b9fbedd0aa508dcbe5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala#L46-L51 ### Why are the changes needed? Bugfix. 
### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Checking if GA/Jenkins can pass Closes #30977 from maropu/FixRandSeedIssue. Authored-by: Takeshi Yamamuro <yamam...@apache.org> Signed-off-by: Dongjoon Hyun <dh...@apple.com> --- .../sql/catalyst/expressions/randomExpressions.scala | 6 +++--- .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 17 ++++++++++++++++- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 6a94517..a14b1fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -44,10 +44,10 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful } @transient protected lazy val seed: Long = child match { - case Literal(s, IntegerType) => s.asInstanceOf[Int] - case Literal(s, LongType) => s.asInstanceOf[Long] + case e if child.foldable && e.dataType == IntegerType => e.eval().asInstanceOf[Int] + case e if child.foldable && e.dataType == LongType => e.eval().asInstanceOf[Long] case _ => throw new AnalysisException( - s"Input argument to $prettyName must be an integer, long or null literal.") + s"Input argument to $prettyName must be an integer, long, or null constant.") } override def nullable: Boolean = false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 237d2c3..a003275 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.{AccumulatorSuite, SparkException} import 
org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial} -import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, NestedColumnAliasingSuite} +import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, ConvertToLocalRelation, NestedColumnAliasingSuite, ReorderAssociativeOperator} import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -3758,6 +3758,21 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } }) } + + test("SPARK-33945: handles a random seed consisting of an expr tree") { + val excludedRules = Seq(ConstantFolding, ReorderAssociativeOperator).map(_.ruleName) + withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> excludedRules.mkString(",")) { + Seq("rand", "randn").foreach { f => + // Just checks if a query works correctly + sql(s"SELECT $f(1 + 1)").collect() + + val msg = intercept[AnalysisException] { + sql(s"SELECT $f(id + 1) FROM range(0, 3)").collect() + }.getMessage + assert(msg.contains("must be an integer, long, or null constant")) + } + } + } } case class Foo(bar: Option[String]) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org