This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new fab4ceb [SPARK-38240][SQL] Improve RuntimeReplaceable and add a
guideline for adding new functions
fab4ceb is described below
commit fab4ceb157baac870f6d50b942084bb9b2cd4ad2
Author: Wenchen Fan <[email protected]>
AuthorDate: Wed Feb 23 15:32:00 2022 +0800
[SPARK-38240][SQL] Improve RuntimeReplaceable and add a guideline for
adding new functions
### What changes were proposed in this pull request?
This PR improves `RuntimeReplaceable` so that it can
1. Customize the type coercion behavior instead of always inheriting from
the replacement expression. This is useful for expressions like `ToBinary`,
where its replacement expression can be `Cast` that does not have type coercion.
2. Support aggregate functions.
This PR also adds a guideline for adding new SQL functions, with
`RuntimeReplaceable` and `ExpressionBuilder`. See
https://github.com/apache/spark/pull/35534/files#diff-6c6ba3e220b9d155160e4e25305fdd3a4835b7ce9eba230a7ae70bdd97047313R330
### Why are the changes needed?
Since we keep adding new functions, it's better to make
`RuntimeReplaceable` more useful and set up a standard for adding functions.
### Does this PR introduce _any_ user-facing change?
Improves error messages of some functions.
### How was this patch tested?
existing tests
Closes #35534 from cloud-fan/refactor.
Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/examples/extensions/AgeExample.scala | 13 +-
.../sql/catalyst/analysis/CheckAnalysis.scala | 4 +
.../sql/catalyst/analysis/FunctionRegistry.scala | 63 +++-
.../sql/catalyst/analysis/TimeTravelSpec.scala | 2 +-
.../sql/catalyst/expressions/Expression.scala | 81 +++--
.../spark/sql/catalyst/expressions/TryEval.scala | 51 ++-
.../catalyst/expressions/aggregate/CountIf.scala | 35 +--
.../catalyst/expressions/aggregate/RegrCount.scala | 19 +-
...{UnevaluableAggs.scala => boolAggregates.scala} | 41 +--
.../expressions/collectionOperations.scala | 53 +++-
.../catalyst/expressions/datetimeExpressions.scala | 343 ++++++++++-----------
.../catalyst/expressions/intervalExpressions.scala | 10 +-
.../sql/catalyst/expressions/mathExpressions.scala | 97 +++---
.../spark/sql/catalyst/expressions/misc.scala | 91 +++---
.../sql/catalyst/expressions/nullExpressions.scala | 54 +---
.../catalyst/expressions/regexpExpressions.scala | 19 +-
.../catalyst/expressions/stringExpressions.scala | 207 ++++++-------
.../sql/catalyst/optimizer/finishAnalysis.scala | 21 +-
.../spark/sql/catalyst/parser/AstBuilder.scala | 2 +-
.../spark/sql/catalyst/trees/TreePatterns.scala | 3 -
.../apache/spark/sql/catalyst/util/package.scala | 4 +-
.../spark/sql/errors/QueryCompilationErrors.scala | 24 +-
.../spark/sql/errors/QueryExecutionErrors.scala | 8 +-
.../expressions/DateExpressionsSuite.scala | 8 +-
.../scala/org/apache/spark/sql/functions.scala | 4 +-
.../sql-functions/sql-expression-schema.md | 20 +-
.../sql-tests/inputs/string-functions.sql | 9 +-
.../resources/sql-tests/results/ansi/map.sql.out | 4 +-
.../results/ansi/string-functions.sql.out | 28 +-
.../results/ceil-floor-with-scale-param.sql.out | 14 +-
.../resources/sql-tests/results/extract.sql.out | 4 +-
.../resources/sql-tests/results/group-by.sql.out | 12 +-
.../test/resources/sql-tests/results/map.sql.out | 4 +-
.../sql-tests/results/string-functions.sql.out | 28 +-
.../sql-tests/results/timestamp-ltz.sql.out | 2 +-
.../sql-tests/results/udf/udf-group-by.sql.out | 8 +-
.../apache/spark/sql/DataFrameAggregateSuite.scala | 3 +-
37 files changed, 657 insertions(+), 736 deletions(-)
diff --git
a/examples/src/main/scala/org/apache/spark/examples/extensions/AgeExample.scala
b/examples/src/main/scala/org/apache/spark/examples/extensions/AgeExample.scala
index d25f220..e484024 100644
---
a/examples/src/main/scala/org/apache/spark/examples/extensions/AgeExample.scala
+++
b/examples/src/main/scala/org/apache/spark/examples/extensions/AgeExample.scala
@@ -18,14 +18,15 @@
package org.apache.spark.examples.extensions
import org.apache.spark.sql.catalyst.expressions.{CurrentDate, Expression,
RuntimeReplaceable, SubtractDates}
+import org.apache.spark.sql.catalyst.trees.UnaryLike
/**
* How old are you in days?
*/
-case class AgeExample(birthday: Expression, child: Expression) extends
RuntimeReplaceable {
-
- def this(birthday: Expression) = this(birthday, SubtractDates(CurrentDate(),
birthday))
- override def exprsReplaced: Seq[Expression] = Seq(birthday)
-
- override protected def withNewChildInternal(newChild: Expression):
Expression = copy(newChild)
+case class AgeExample(birthday: Expression) extends RuntimeReplaceable with
UnaryLike[Expression] {
+ override lazy val replacement: Expression = SubtractDates(CurrentDate(),
birthday)
+ override def child: Expression = birthday
+ override protected def withNewChildInternal(newChild: Expression):
Expression = {
+ copy(birthday = newChild)
+ }
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index eacb5b2..0bf748c 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -199,6 +199,10 @@ trait CheckAnalysis extends PredicateHelper with
LookupCatalog {
failAnalysis(s"invalid cast from ${c.child.dataType.catalogString}
to " +
c.dataType.catalogString)
+ case e: RuntimeReplaceable if !e.replacement.resolved =>
+ throw new IllegalStateException("Illegal RuntimeReplaceable: " + e
+
+ "\nReplacement is unresolved: " + e.replacement)
+
case g: Grouping =>
failAnalysis("grouping() can only be used with
GroupingSets/Cube/Rollup")
case g: GroupingID =>
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 12fa723..6cf0fd1 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -111,9 +111,11 @@ object FunctionRegistryBase {
name: String,
since: Option[String]): (ExpressionInfo, Seq[Expression] => T) = {
val runtimeClass = scala.reflect.classTag[T].runtimeClass
- // For `RuntimeReplaceable`, skip the constructor with most arguments,
which is the main
- // constructor and contains non-parameter `child` and should not be used
as function builder.
- val constructors = if
(classOf[RuntimeReplaceable].isAssignableFrom(runtimeClass)) {
+ // For `InheritAnalysisRules`, skip the constructor with most arguments,
which is the main
+ // constructor and contains non-parameter `replacement` and should not be
used as
+ // function builder.
+ val isRuntime =
classOf[InheritAnalysisRules].isAssignableFrom(runtimeClass)
+ val constructors = if (isRuntime) {
val all = runtimeClass.getConstructors
val maxNumArgs = all.map(_.getParameterCount).max
all.filterNot(_.getParameterCount == maxNumArgs)
@@ -324,7 +326,36 @@ object FunctionRegistry {
val FUNC_ALIAS = TreeNodeTag[String]("functionAliasName")
- // Note: Whenever we add a new entry here, make sure we also update
ExpressionToSQLSuite
+ //
==============================================================================================
+ // The guideline for adding SQL functions
+ //
==============================================================================================
+ // To add a SQL function, we usually need to create a new `Expression` for
the function, and
+ // implement the function logic in both the interpretation code path and
codegen code path of the
+ // `Expression`. We also need to define the type coercion behavior for the
function inputs, by
+ // extending `ImplicitCastInputTypes` or updating type coercion rules
directly.
+ //
+ // It's much simpler if the SQL function can be implemented with existing
expression(s). There are
+ // a few cases:
+ // - The function is simply an alias of another function. We can just
register the same
+ // expression with a different function name, e.g.
`expression[Rand]("random", true)`.
+ // - The function is mostly the same with another function, but has a
different parameter list.
+ // We can use `RuntimeReplaceable` to create a new expression, which can
customize the
+ // parameter list and analysis behavior (type coercion). The
`RuntimeReplaceable` expression
+ // will be replaced by the actual expression at the end of analysis. See
`Left` as an example.
+ // - The function can be implemented by combining some existing
expressions. We can use
+ // `RuntimeReplaceable` to define the combination. See `ParseToDate` as
an example.
+ // We can also inherit the analysis behavior from the replacement
expression, by
+ // mixing `InheritAnalysisRules`. See `TryAdd` as an example.
+ // - Similarly, we can use `RuntimeReplaceableAggregate` to implement new
aggregate functions.
+ //
+ // Sometimes, multiple functions share the same/similar expression
replacement logic and it's
+ // tedious to create many similar `RuntimeReplaceable` expressions. We can
use `ExpressionBuilder`
+ // to share the replacement logic. See
`ParseToTimestampLTZExpressionBuilder` as an example.
+ //
+ // With these tools, we can even implement a new SQL function with a Java
(static) method, and
+ // then create a `RuntimeReplaceable` expression to call the Java method
with `Invoke` or
+ // `StaticInvoke` expression. By doing so we don't need to implement codegen
for new functions
+ // anymore. See `AesEncrypt`/`AesDecrypt` as an example.
val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map(
// misc non-aggregate functions
expression[Abs]("abs"),
@@ -336,7 +367,7 @@ object FunctionRegistry {
expression[Inline]("inline"),
expressionGeneratorOuter[Inline]("inline_outer"),
expression[IsNaN]("isnan"),
- expression[IfNull]("ifnull"),
+ expression[Nvl]("ifnull", setAlias = true),
expression[IsNull]("isnull"),
expression[IsNotNull]("isnotnull"),
expression[Least]("least"),
@@ -565,8 +596,9 @@ object FunctionRegistry {
expression[ToBinary]("to_binary"),
expression[ToUnixTimestamp]("to_unix_timestamp"),
expression[ToUTCTimestamp]("to_utc_timestamp"),
- expression[ParseToTimestampNTZ]("to_timestamp_ntz"),
- expression[ParseToTimestampLTZ]("to_timestamp_ltz"),
+ // We keep the 2 expression builders below to have different function docs.
+ expressionBuilder("to_timestamp_ntz",
ParseToTimestampNTZExpressionBuilder, setAlias = true),
+ expressionBuilder("to_timestamp_ltz",
ParseToTimestampLTZExpressionBuilder, setAlias = true),
expression[TruncDate]("trunc"),
expression[TruncTimestamp]("date_trunc"),
expression[UnixTimestamp]("unix_timestamp"),
@@ -578,13 +610,15 @@ object FunctionRegistry {
expression[SessionWindow]("session_window"),
expression[MakeDate]("make_date"),
expression[MakeTimestamp]("make_timestamp"),
- expression[MakeTimestampNTZ]("make_timestamp_ntz"),
- expression[MakeTimestampLTZ]("make_timestamp_ltz"),
+ // We keep the 2 expression builders below to have different function docs.
+ expressionBuilder("make_timestamp_ntz", MakeTimestampNTZExpressionBuilder,
setAlias = true),
+ expressionBuilder("make_timestamp_ltz", MakeTimestampLTZExpressionBuilder,
setAlias = true),
expression[MakeInterval]("make_interval"),
expression[MakeDTInterval]("make_dt_interval"),
expression[MakeYMInterval]("make_ym_interval"),
- expression[DatePart]("date_part"),
expression[Extract]("extract"),
+ // We keep the `DatePartExpressionBuilder` to have different function docs.
+ expressionBuilder("date_part", DatePartExpressionBuilder, setAlias = true),
expression[DateFromUnixDate]("date_from_unix_date"),
expression[UnixDate]("unix_date"),
expression[SecondsToTimestamp]("timestamp_seconds"),
@@ -806,12 +840,13 @@ object FunctionRegistry {
}
private def expressionBuilder[T <: ExpressionBuilder : ClassTag](
- name: String, builder: T, setAlias: Boolean = false)
- : (String, (ExpressionInfo, FunctionBuilder)) = {
+ name: String,
+ builder: T,
+ setAlias: Boolean = false): (String, (ExpressionInfo, FunctionBuilder))
= {
val info = FunctionRegistryBase.expressionInfo[T](name, None)
val funcBuilder = (expressions: Seq[Expression]) => {
assert(expressions.forall(_.resolved), "function arguments must be
resolved.")
- val expr = builder.build(expressions)
+ val expr = builder.build(name, expressions)
if (setAlias) expr.setTagValue(FUNC_ALIAS, name)
expr
}
@@ -915,5 +950,5 @@ object TableFunctionRegistry {
}
trait ExpressionBuilder {
- def build(expressions: Seq[Expression]): Expression
+ def build(funcName: String, expressions: Seq[Expression]): Expression
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TimeTravelSpec.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TimeTravelSpec.scala
index cbb6e8b..7e79c03 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TimeTravelSpec.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TimeTravelSpec.scala
@@ -41,7 +41,7 @@ object TimeTravelSpec {
throw QueryCompilationErrors.invalidTimestampExprForTimeTravel(ts)
}
val tsToEval = ts.transform {
- case r: RuntimeReplaceable => r.child
+ case r: RuntimeReplaceable => r.replacement
case _: Unevaluable =>
throw QueryCompilationErrors.invalidTimestampExprForTimeTravel(ts)
case e if !e.deterministic =>
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 32b25f5..4ff5267 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -21,7 +21,7 @@ import java.util.Locale
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry,
TypeCheckResult, TypeCoercion}
-import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike,
QuaternaryLike, TernaryLike, TreeNode, UnaryLike}
@@ -352,34 +352,41 @@ trait Unevaluable extends Expression {
* An expression that gets replaced at runtime (currently by the optimizer)
into a different
* expression for evaluation. This is mainly used to provide compatibility
with other databases.
* For example, we use this to support "nvl" by replacing it with "coalesce".
- *
- * A RuntimeReplaceable should have the original parameters along with a
"child" expression in the
- * case class constructor, and define a normal constructor that accepts only
the original
- * parameters. For an example, see [[Nvl]]. To make sure the explain plan and
expression SQL
- * works correctly, the implementation should also override flatArguments
method and sql method.
*/
-trait RuntimeReplaceable extends UnaryExpression with Unevaluable {
- override def nullable: Boolean = child.nullable
- override def dataType: DataType = child.dataType
+trait RuntimeReplaceable extends Expression {
+ def replacement: Expression
+
+ override val nodePatterns: Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE)
+ override def nullable: Boolean = replacement.nullable
+ override def dataType: DataType = replacement.dataType
// As this expression gets replaced at optimization with its `child"
expression,
// two `RuntimeReplaceable` are considered to be semantically equal if their
"child" expressions
// are semantically equal.
- override lazy val preCanonicalized: Expression = child.preCanonicalized
+ override lazy val preCanonicalized: Expression = replacement.preCanonicalized
- /**
- * Only used to generate SQL representation of this expression.
- *
- * Implementations should override this with original parameters
- */
- def exprsReplaced: Seq[Expression]
-
- override def sql: String = mkString(exprsReplaced.map(_.sql))
-
- final override val nodePatterns: Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE)
+ final override def eval(input: InternalRow = null): Any =
+ throw QueryExecutionErrors.cannotEvaluateExpressionError(this)
+ final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode):
ExprCode =
+ throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this)
+}
- def mkString(childrenString: Seq[String]): String = {
- prettyName + childrenString.mkString("(", ", ", ")")
+/**
+ * An add-on of [[RuntimeReplaceable]]. It makes `replacement` the child of
the expression, to
+ * inherit the analysis rules for it, such as type coercion. The
implementation should put
+ * `replacement` in the case class constructor, and define a normal
constructor that accepts only
+ * the original parameters. For an example, see [[TryAdd]]. To make sure the
explain plan and
+ * expression SQL works correctly, the implementation should also implement
the `parameters` method.
+ */
+trait InheritAnalysisRules extends UnaryLike[Expression] { self:
RuntimeReplaceable =>
+ override def child: Expression = replacement
+ def parameters: Seq[Expression]
+ override def flatArguments: Iterator[Any] = parameters.iterator
+ // This method is used to generate a SQL string with transformed inputs.
This is necessary as
+ // the actual inputs are not the children of this expression.
+ def makeSQLString(childrenSQL: Seq[String]): String = {
+ prettyName + childrenSQL.mkString("(", ", ", ")")
}
+ final override def sql: String = makeSQLString(parameters.map(_.sql))
}
/**
@@ -388,29 +395,13 @@ trait RuntimeReplaceable extends UnaryExpression with
Unevaluable {
* with other databases. For example, we use this to support every, any/some
aggregates by rewriting
* them with Min and Max respectively.
*/
-trait UnevaluableAggregate extends DeclarativeAggregate {
-
- override def nullable: Boolean = true
-
- override lazy val aggBufferAttributes =
- throw QueryExecutionErrors.evaluateUnevaluableAggregateUnsupportedError(
- "aggBufferAttributes", this)
-
- override lazy val initialValues: Seq[Expression] =
- throw QueryExecutionErrors.evaluateUnevaluableAggregateUnsupportedError(
- "initialValues", this)
-
- override lazy val updateExpressions: Seq[Expression] =
- throw QueryExecutionErrors.evaluateUnevaluableAggregateUnsupportedError(
- "updateExpressions", this)
-
- override lazy val mergeExpressions: Seq[Expression] =
- throw QueryExecutionErrors.evaluateUnevaluableAggregateUnsupportedError(
- "mergeExpressions", this)
-
- override lazy val evaluateExpression: Expression =
- throw QueryExecutionErrors.evaluateUnevaluableAggregateUnsupportedError(
- "evaluateExpression", this)
+abstract class RuntimeReplaceableAggregate extends AggregateFunction with
RuntimeReplaceable {
+ def aggBufferSchema: StructType = throw new IllegalStateException(
+ "RuntimeReplaceableAggregate.aggBufferSchema should not be called")
+ def aggBufferAttributes: Seq[AttributeReference] = throw new
IllegalStateException(
+ "RuntimeReplaceableAggregate.aggBufferAttributes should not be called")
+ def inputAggBufferAttributes: Seq[AttributeReference] = throw new
IllegalStateException(
+ "RuntimeReplaceableAggregate.inputAggBufferAttributes should not be
called")
}
/**
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala
index 4663d48..7a8a689 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala
@@ -75,19 +75,17 @@ case class TryEval(child: Expression) extends
UnaryExpression with NullIntoleran
since = "3.2.0",
group = "math_funcs")
// scalastyle:on line.size.limit
-case class TryAdd(left: Expression, right: Expression, child: Expression)
- extends RuntimeReplaceable {
+case class TryAdd(left: Expression, right: Expression, replacement: Expression)
+ extends RuntimeReplaceable with InheritAnalysisRules {
def this(left: Expression, right: Expression) =
this(left, right, TryEval(Add(left, right, failOnError = true)))
- override def flatArguments: Iterator[Any] = Iterator(left, right)
-
- override def exprsReplaced: Seq[Expression] = Seq(left, right)
-
override def prettyName: String = "try_add"
+ override def parameters: Seq[Expression] = Seq(left, right)
+
override protected def withNewChildInternal(newChild: Expression):
Expression =
- this.copy(child = newChild)
+ this.copy(replacement = newChild)
}
// scalastyle:off line.size.limit
@@ -110,19 +108,18 @@ case class TryAdd(left: Expression, right: Expression,
child: Expression)
since = "3.2.0",
group = "math_funcs")
// scalastyle:on line.size.limit
-case class TryDivide(left: Expression, right: Expression, child: Expression)
- extends RuntimeReplaceable {
+case class TryDivide(left: Expression, right: Expression, replacement:
Expression)
+ extends RuntimeReplaceable with InheritAnalysisRules {
def this(left: Expression, right: Expression) =
this(left, right, TryEval(Divide(left, right, failOnError = true)))
- override def flatArguments: Iterator[Any] = Iterator(left, right)
-
- override def exprsReplaced: Seq[Expression] = Seq(left, right)
-
override def prettyName: String = "try_divide"
- override protected def withNewChildInternal(newChild: Expression):
Expression =
- this.copy(child = newChild)
+ override def parameters: Seq[Expression] = Seq(left, right)
+
+ override protected def withNewChildInternal(newChild: Expression):
Expression = {
+ copy(replacement = newChild)
+ }
}
@ExpressionDescription(
@@ -145,19 +142,17 @@ case class TryDivide(left: Expression, right: Expression,
child: Expression)
""",
since = "3.3.0",
group = "math_funcs")
-case class TrySubtract(left: Expression, right: Expression, child: Expression)
- extends RuntimeReplaceable {
+case class TrySubtract(left: Expression, right: Expression, replacement:
Expression)
+ extends RuntimeReplaceable with InheritAnalysisRules {
def this(left: Expression, right: Expression) =
this(left, right, TryEval(Subtract(left, right, failOnError = true)))
- override def flatArguments: Iterator[Any] = Iterator(left, right)
-
- override def exprsReplaced: Seq[Expression] = Seq(left, right)
-
override def prettyName: String = "try_subtract"
+ override def parameters: Seq[Expression] = Seq(left, right)
+
override protected def withNewChildInternal(newChild: Expression):
Expression =
- this.copy(child = newChild)
+ this.copy(replacement = newChild)
}
@ExpressionDescription(
@@ -174,17 +169,15 @@ case class TrySubtract(left: Expression, right:
Expression, child: Expression)
""",
since = "3.3.0",
group = "math_funcs")
-case class TryMultiply(left: Expression, right: Expression, child: Expression)
- extends RuntimeReplaceable {
+case class TryMultiply(left: Expression, right: Expression, replacement:
Expression)
+ extends RuntimeReplaceable with InheritAnalysisRules {
def this(left: Expression, right: Expression) =
this(left, right, TryEval(Multiply(left, right, failOnError = true)))
- override def flatArguments: Iterator[Any] = Iterator(left, right)
-
- override def exprsReplaced: Seq[Expression] = Seq(left, right)
-
override def prettyName: String = "try_multiply"
+ override def parameters: Seq[Expression] = Seq(left, right)
+
override protected def withNewChildInternal(newChild: Expression):
Expression =
- this.copy(child = newChild)
+ this.copy(replacement = newChild)
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala
index 66800b2..6973641 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountIf.scala
@@ -17,11 +17,9 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.{Expression,
ExpressionDescription, ImplicitCastInputTypes, UnevaluableAggregate}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{COUNT_IF, TreePattern}
+import org.apache.spark.sql.catalyst.expressions.{Expression,
ExpressionDescription, ImplicitCastInputTypes, Literal, NullIf,
RuntimeReplaceableAggregate}
import org.apache.spark.sql.catalyst.trees.UnaryLike
-import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType,
LongType}
+import org.apache.spark.sql.types.{AbstractDataType, BooleanType}
@ExpressionDescription(
usage = """
@@ -36,30 +34,11 @@ import org.apache.spark.sql.types.{AbstractDataType,
BooleanType, DataType, Long
""",
group = "agg_funcs",
since = "3.0.0")
-case class CountIf(predicate: Expression) extends UnevaluableAggregate with
ImplicitCastInputTypes
- with UnaryLike[Expression] {
-
- override def prettyName: String = "count_if"
-
- override def child: Expression = predicate
-
- override def nullable: Boolean = false
-
- override def dataType: DataType = LongType
-
+case class CountIf(child: Expression) extends RuntimeReplaceableAggregate
+ with ImplicitCastInputTypes with UnaryLike[Expression] {
+ override lazy val replacement: Expression = Count(new NullIf(child,
Literal.FalseLiteral))
+ override def nodeName: String = "count_if"
override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType)
-
- final override val nodePatterns: Seq[TreePattern] = Seq(COUNT_IF)
-
- override def checkInputDataTypes(): TypeCheckResult = predicate.dataType
match {
- case BooleanType =>
- TypeCheckResult.TypeCheckSuccess
- case _ =>
- TypeCheckResult.TypeCheckFailure(
- s"function $prettyName requires boolean type, not
${predicate.dataType.catalogString}"
- )
- }
-
override protected def withNewChildInternal(newChild: Expression): CountIf =
- copy(predicate = newChild)
+ copy(child = newChild)
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/RegrCount.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/RegrCount.scala
index 57dbc14..80df012 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/RegrCount.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/RegrCount.scala
@@ -17,10 +17,9 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
-import org.apache.spark.sql.catalyst.expressions.{Expression,
ExpressionDescription, ImplicitCastInputTypes, UnevaluableAggregate}
+import org.apache.spark.sql.catalyst.expressions.{Expression,
ExpressionDescription, ImplicitCastInputTypes, RuntimeReplaceableAggregate}
import org.apache.spark.sql.catalyst.trees.BinaryLike
-import org.apache.spark.sql.catalyst.trees.TreePattern.{REGR_COUNT,
TreePattern}
-import org.apache.spark.sql.types.{AbstractDataType, DataType, LongType,
NumericType}
+import org.apache.spark.sql.types.{AbstractDataType, NumericType}
@ExpressionDescription(
usage = """
@@ -38,18 +37,10 @@ import org.apache.spark.sql.types.{AbstractDataType,
DataType, LongType, Numeric
group = "agg_funcs",
since = "3.3.0")
case class RegrCount(left: Expression, right: Expression)
- extends UnevaluableAggregate with ImplicitCastInputTypes with
BinaryLike[Expression] {
-
- override def prettyName: String = "regr_count"
-
- override def nullable: Boolean = false
-
- override def dataType: DataType = LongType
-
+ extends RuntimeReplaceableAggregate with ImplicitCastInputTypes with
BinaryLike[Expression] {
+ override lazy val replacement: Expression = Count(Seq(left, right))
+ override def nodeName: String = "regr_count"
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType,
NumericType)
-
- final override val nodePatterns: Seq[TreePattern] = Seq(REGR_COUNT)
-
override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): RegrCount =
this.copy(left = newLeft, right = newRight)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/boolAggregates.scala
similarity index 63%
rename from
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala
rename to
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/boolAggregates.scala
index 244e9d9..59c75f2 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/boolAggregates.scala
@@ -17,33 +17,10 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
-import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry,
TypeCheckResult}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.trees.TreePattern.{BOOL_AGG, TreePattern}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.types._
-abstract class UnevaluableBooleanAggBase(arg: Expression)
- extends UnevaluableAggregate with ImplicitCastInputTypes with
UnaryLike[Expression] {
-
- override def child: Expression = arg
-
- override def dataType: DataType = BooleanType
-
- override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType)
-
- final override val nodePatterns: Seq[TreePattern] = Seq(BOOL_AGG)
-
- override def checkInputDataTypes(): TypeCheckResult = {
- arg.dataType match {
- case dt if dt != BooleanType =>
- TypeCheckResult.TypeCheckFailure(s"Input to function '$prettyName'
should have been " +
- s"${BooleanType.simpleString}, but it's
[${arg.dataType.catalogString}].")
- case _ => TypeCheckResult.TypeCheckSuccess
- }
- }
-}
-
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns true if all values of `expr` are true.",
examples = """
@@ -57,10 +34,13 @@ abstract class UnevaluableBooleanAggBase(arg: Expression)
""",
group = "agg_funcs",
since = "3.0.0")
-case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
- override def nodeName: String =
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bool_and")
+case class BoolAnd(child: Expression) extends RuntimeReplaceableAggregate
+ with ImplicitCastInputTypes with UnaryLike[Expression] {
+ override lazy val replacement: Expression = Min(child)
+ override def nodeName: String = "bool_and"
+ override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType)
override protected def withNewChildInternal(newChild: Expression):
Expression =
- copy(arg = newChild)
+ copy(child = newChild)
}
@ExpressionDescription(
@@ -76,8 +56,11 @@ case class BoolAnd(arg: Expression) extends
UnevaluableBooleanAggBase(arg) {
""",
group = "agg_funcs",
since = "3.0.0")
-case class BoolOr(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
- override def nodeName: String =
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bool_or")
+case class BoolOr(child: Expression) extends RuntimeReplaceableAggregate
+ with ImplicitCastInputTypes with UnaryLike[Expression] {
+ override lazy val replacement: Expression = Max(child)
+ override def nodeName: String = "bool_or"
+ override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType)
override protected def withNewChildInternal(newChild: Expression):
Expression =
- copy(arg = newChild)
+ copy(child = newChild)
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 65b6a05..0cd8593 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -27,6 +27,7 @@ import
org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT,
TreePattern}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
@@ -182,19 +183,41 @@ case class MapKeys(child: Expression)
""",
group = "map_funcs",
since = "3.3.0")
-case class MapContainsKey(
- left: Expression,
- right: Expression,
- child: Expression) extends RuntimeReplaceable {
- def this(left: Expression, right: Expression) =
- this(left, right, ArrayContains(MapKeys(left), right))
+case class MapContainsKey(left: Expression, right: Expression)
+ extends RuntimeReplaceable with BinaryLike[Expression] with
ImplicitCastInputTypes {
+
+ override lazy val replacement: Expression = ArrayContains(MapKeys(left),
right)
- override def exprsReplaced: Seq[Expression] = Seq(left, right)
+ override def inputTypes: Seq[AbstractDataType] = {
+ (left.dataType, right.dataType) match {
+ case (_, NullType) => Seq.empty
+ case (MapType(kt, vt, valueContainsNull), dt) =>
+ TypeCoercion.findWiderTypeWithoutStringPromotionForTwo(kt, dt) match {
+ case Some(widerType) => Seq(MapType(widerType, vt,
valueContainsNull), widerType)
+ case _ => Seq.empty
+ }
+ case _ => Seq.empty
+ }
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ (left.dataType, right.dataType) match {
+ case (_, NullType) =>
+ TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as
arguments")
+ case (MapType(kt, _, _), dt) if kt.sameType(dt) =>
+ TypeUtils.checkForOrderingExpr(kt, s"function $prettyName")
+ case _ => TypeCheckResult.TypeCheckFailure(s"Input to function
$prettyName should have " +
+ s"been ${MapType.simpleString} followed by a value with same key type,
but it's " +
+ s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
+ }
+ }
override def prettyName: String = "map_contains_key"
- override protected def withNewChildInternal(newChild: Expression):
MapContainsKey =
- copy(child = newChild)
+ override protected def withNewChildrenInternal(
+ newLeft: Expression, newRight: Expression): Expression = {
+ copy(newLeft, newRight)
+ }
}
@ExpressionDescription(
@@ -2229,19 +2252,17 @@ case class ElementAt(
""",
since = "3.3.0",
group = "map_funcs")
-case class TryElementAt(left: Expression, right: Expression, child: Expression)
- extends RuntimeReplaceable {
+case class TryElementAt(left: Expression, right: Expression, replacement:
Expression)
+ extends RuntimeReplaceable with InheritAnalysisRules {
def this(left: Expression, right: Expression) =
this(left, right, ElementAt(left, right, failOnError = false))
- override def flatArguments: Iterator[Any] = Iterator(left, right)
-
- override def exprsReplaced: Seq[Expression] = Seq(left, right)
-
override def prettyName: String = "try_element_at"
+ override def parameters: Seq[Expression] = Seq(left, right)
+
override protected def withNewChildInternal(newChild: Expression):
Expression =
- this.copy(child = newChild)
+ this.copy(replacement = newChild)
}
/**
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index e73e989..9780b9d 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -26,6 +26,7 @@ import org.apache.commons.text.StringEscapeUtils
import org.apache.spark.SparkDateTimeException
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder,
FunctionRegistry}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreePattern._
@@ -1112,25 +1113,15 @@ case class GetTimestamp(
group = "datetime_funcs",
since = "3.3.0")
// scalastyle:on line.size.limit
-case class ParseToTimestampNTZ(
- left: Expression,
- format: Option[Expression],
- child: Expression) extends RuntimeReplaceable {
-
- def this(left: Expression, format: Expression) = {
- this(left, Option(format), GetTimestamp(left, format, TimestampNTZType))
+object ParseToTimestampNTZExpressionBuilder extends ExpressionBuilder {
+ override def build(funcName: String, expressions: Seq[Expression]):
Expression = {
+ val numArgs = expressions.length
+ if (numArgs == 1 || numArgs == 2) {
+ ParseToTimestamp(expressions(0), expressions.drop(1).lastOption,
TimestampNTZType)
+ } else {
+ throw QueryCompilationErrors.invalidFunctionArgumentNumberError(Seq(1,
2), funcName, numArgs)
+ }
}
-
- def this(left: Expression) = this(left, None, Cast(left, TimestampNTZType))
-
- override def flatArguments: Iterator[Any] = Iterator(left, format)
- override def exprsReplaced: Seq[Expression] = left +: format.toSeq
-
- override def prettyName: String = "to_timestamp_ntz"
- override def dataType: DataType = TimestampNTZType
-
- override protected def withNewChildInternal(newChild: Expression):
ParseToTimestampNTZ =
- copy(child = newChild)
}
/**
@@ -1159,25 +1150,15 @@ case class ParseToTimestampNTZ(
group = "datetime_funcs",
since = "3.3.0")
// scalastyle:on line.size.limit
-case class ParseToTimestampLTZ(
- left: Expression,
- format: Option[Expression],
- child: Expression) extends RuntimeReplaceable {
-
- def this(left: Expression, format: Expression) = {
- this(left, Option(format), GetTimestamp(left, format, TimestampType))
+object ParseToTimestampLTZExpressionBuilder extends ExpressionBuilder {
+ override def build(funcName: String, expressions: Seq[Expression]):
Expression = {
+ val numArgs = expressions.length
+ if (numArgs == 1 || numArgs == 2) {
+ ParseToTimestamp(expressions(0), expressions.drop(1).lastOption,
TimestampType)
+ } else {
+ throw QueryCompilationErrors.invalidFunctionArgumentNumberError(Seq(1,
2), funcName, numArgs)
+ }
}
-
- def this(left: Expression) = this(left, None, Cast(left, TimestampType))
-
- override def flatArguments: Iterator[Any] = Iterator(left, format)
- override def exprsReplaced: Seq[Expression] = left +: format.toSeq
-
- override def prettyName: String = "to_timestamp_ltz"
- override def dataType: DataType = TimestampType
-
- override protected def withNewChildInternal(newChild: Expression):
ParseToTimestampLTZ =
- copy(child = newChild)
}
abstract class ToTimestamp
@@ -1606,12 +1587,19 @@ case class TimeAdd(start: Expression, interval:
Expression, timeZoneId: Option[S
case class DatetimeSub(
start: Expression,
interval: Expression,
- child: Expression) extends RuntimeReplaceable {
- override def exprsReplaced: Seq[Expression] = Seq(start, interval)
+ replacement: Expression) extends RuntimeReplaceable with
InheritAnalysisRules {
+
+ override def parameters: Seq[Expression] = Seq(start, interval)
+
+ override def makeSQLString(childrenSQL: Seq[String]): String = {
+ childrenSQL.mkString(" - ")
+ }
+
override def toString: String = s"$start - $interval"
- override def mkString(childrenString: Seq[String]): String =
childrenString.mkString(" - ")
- override protected def withNewChildInternal(newChild: Expression):
DatetimeSub =
- copy(child = newChild)
+
+ override protected def withNewChildInternal(newChild: Expression):
Expression = {
+ copy(replacement = newChild)
+ }
}
/**
@@ -1991,25 +1979,48 @@ case class MonthsBetween(
group = "datetime_funcs",
since = "1.5.0")
// scalastyle:on line.size.limit
-case class ParseToDate(left: Expression, format: Option[Expression], child:
Expression)
- extends RuntimeReplaceable {
+case class ParseToDate(
+ left: Expression,
+ format: Option[Expression],
+ timeZoneId: Option[String] = None)
+ extends RuntimeReplaceable with ImplicitCastInputTypes with
TimeZoneAwareExpression {
+
+ override lazy val replacement: Expression = format.map { f =>
+ Cast(GetTimestamp(left, f, TimestampType, timeZoneId), DateType,
timeZoneId)
+ }.getOrElse(Cast(left, DateType, timeZoneId)) // backwards compatibility
def this(left: Expression, format: Expression) = {
- this(left, Option(format), Cast(GetTimestamp(left, format, TimestampType),
DateType))
+ this(left, Option(format))
}
def this(left: Expression) = {
- // backwards compatibility
- this(left, None, Cast(left, DateType))
+ this(left, None)
}
- override def exprsReplaced: Seq[Expression] = left +: format.toSeq
- override def flatArguments: Iterator[Any] = Iterator(left, format)
-
override def prettyName: String = "to_date"
- override protected def withNewChildInternal(newChild: Expression):
ParseToDate =
- copy(child = newChild)
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Some(timeZoneId))
+
+ override def nodePatternsInternal(): Seq[TreePattern] =
Seq(RUNTIME_REPLACEABLE)
+
+ override def children: Seq[Expression] = left +: format.toSeq
+
+ override def inputTypes: Seq[AbstractDataType] = {
+ // Note: ideally this function should only take string input, but we allow
more types here to
+ // be backward compatible.
+ TypeCollection(StringType, DateType, TimestampType, TimestampNTZType) +:
+ format.map(_ => StringType).toSeq
+ }
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): Expression = {
+ if (format.isDefined) {
+ copy(left = newChildren.head, format = Some(newChildren.last))
+ } else {
+ copy(left = newChildren.head)
+ }
+ }
}
/**
@@ -2043,23 +2054,44 @@ case class ParseToTimestamp(
left: Expression,
format: Option[Expression],
override val dataType: DataType,
- child: Expression) extends RuntimeReplaceable {
+ timeZoneId: Option[String] = None)
+ extends RuntimeReplaceable with ImplicitCastInputTypes with
TimeZoneAwareExpression {
+
+ override lazy val replacement: Expression = format.map { f =>
+ GetTimestamp(left, f, dataType, timeZoneId)
+ }.getOrElse(Cast(left, dataType, timeZoneId))
def this(left: Expression, format: Expression) = {
- this(left, Option(format), SQLConf.get.timestampType,
- GetTimestamp(left, format, SQLConf.get.timestampType))
+ this(left, Option(format), SQLConf.get.timestampType)
}
def this(left: Expression) =
- this(left, None, SQLConf.get.timestampType, Cast(left,
SQLConf.get.timestampType))
+ this(left, None, SQLConf.get.timestampType)
- override def flatArguments: Iterator[Any] = Iterator(left, format)
- override def exprsReplaced: Seq[Expression] = left +: format.toSeq
+ override def nodeName: String = "to_timestamp"
- override def prettyName: String = "to_timestamp"
+ override def nodePatternsInternal(): Seq[TreePattern] =
Seq(RUNTIME_REPLACEABLE)
- override protected def withNewChildInternal(newChild: Expression):
ParseToTimestamp =
- copy(child = newChild)
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Some(timeZoneId))
+
+ override def children: Seq[Expression] = left +: format.toSeq
+
+ override def inputTypes: Seq[AbstractDataType] = {
+ // Note: ideally this function should only take string input, but we allow
more types here to
+ // be backward compatible.
+ TypeCollection(StringType, DateType, TimestampType, TimestampNTZType) +:
+ format.map(_ => StringType).toSeq
+ }
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): Expression = {
+ if (format.isDefined) {
+ copy(left = newChildren.head, format = Some(newChildren.last))
+ } else {
+ copy(left = newChildren.head)
+ }
+ }
}
trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes {
@@ -2410,32 +2442,22 @@ case class MakeDate(
group = "datetime_funcs",
since = "3.3.0")
// scalastyle:on line.size.limit
-case class MakeTimestampNTZ(
- year: Expression,
- month: Expression,
- day: Expression,
- hour: Expression,
- min: Expression,
- sec: Expression,
- failOnError: Boolean = SQLConf.get.ansiEnabled,
- child: Expression) extends RuntimeReplaceable {
- def this(
- year: Expression,
- month: Expression,
- day: Expression,
- hour: Expression,
- min: Expression,
- sec: Expression) = {
- this(year, month, day, hour, min, sec, failOnError =
SQLConf.get.ansiEnabled,
- MakeTimestamp(year, month, day, hour, min, sec, dataType =
TimestampNTZType))
+object MakeTimestampNTZExpressionBuilder extends ExpressionBuilder {
+ override def build(funcName: String, expressions: Seq[Expression]):
Expression = {
+ val numArgs = expressions.length
+ if (numArgs == 6) {
+ MakeTimestamp(
+ expressions(0),
+ expressions(1),
+ expressions(2),
+ expressions(3),
+ expressions(4),
+ expressions(5),
+ dataType = TimestampNTZType)
+ } else {
+ throw QueryCompilationErrors.invalidFunctionArgumentNumberError(Seq(6),
funcName, numArgs)
+ }
}
-
- override def prettyName: String = "make_timestamp_ntz"
-
- override def exprsReplaced: Seq[Expression] = Seq(year, month, day, hour,
min, sec)
-
- override protected def withNewChildInternal(newChild: Expression):
Expression =
- copy(child = newChild)
}
// scalastyle:off line.size.limit
@@ -2469,45 +2491,23 @@ case class MakeTimestampNTZ(
group = "datetime_funcs",
since = "3.3.0")
// scalastyle:on line.size.limit
-case class MakeTimestampLTZ(
- year: Expression,
- month: Expression,
- day: Expression,
- hour: Expression,
- min: Expression,
- sec: Expression,
- timezone: Option[Expression],
- failOnError: Boolean = SQLConf.get.ansiEnabled,
- child: Expression) extends RuntimeReplaceable {
- def this(
- year: Expression,
- month: Expression,
- day: Expression,
- hour: Expression,
- min: Expression,
- sec: Expression) = {
- this(year, month, day, hour, min, sec, None, failOnError =
SQLConf.get.ansiEnabled,
- MakeTimestamp(year, month, day, hour, min, sec, dataType =
TimestampType))
- }
-
- def this(
- year: Expression,
- month: Expression,
- day: Expression,
- hour: Expression,
- min: Expression,
- sec: Expression,
- timezone: Expression) = {
- this(year, month, day, hour, min, sec, Some(timezone), failOnError =
SQLConf.get.ansiEnabled,
- MakeTimestamp(year, month, day, hour, min, sec, Some(timezone), dataType
= TimestampType))
+object MakeTimestampLTZExpressionBuilder extends ExpressionBuilder {
+ override def build(funcName: String, expressions: Seq[Expression]):
Expression = {
+ val numArgs = expressions.length
+ if (numArgs == 6 || numArgs == 7) {
+ MakeTimestamp(
+ expressions(0),
+ expressions(1),
+ expressions(2),
+ expressions(3),
+ expressions(4),
+ expressions(5),
+ expressions.drop(6).lastOption,
+ dataType = TimestampType)
+ } else {
+ throw QueryCompilationErrors.invalidFunctionArgumentNumberError(Seq(6),
funcName, numArgs)
+ }
}
-
- override def prettyName: String = "make_timestamp_ltz"
-
- override def exprsReplaced: Seq[Expression] = Seq(year, month, day, hour,
min, sec)
-
- override protected def withNewChildInternal(newChild: Expression):
Expression =
- copy(child = newChild)
}
// scalastyle:off line.size.limit
@@ -2699,7 +2699,7 @@ case class MakeTimestamp(
})
}
- override def prettyName: String = "make_timestamp"
+ override def nodeName: String = "make_timestamp"
// override def children: Seq[Expression] = Seq(year, month, day, hour, min,
sec) ++ timezone
override protected def withNewChildrenInternal(
@@ -2720,8 +2720,7 @@ object DatePart {
def parseExtractField(
extractField: String,
- source: Expression,
- errorHandleFunc: => Nothing): Expression =
extractField.toUpperCase(Locale.ROOT) match {
+ source: Expression): Expression = extractField.toUpperCase(Locale.ROOT)
match {
case "YEAR" | "Y" | "YEARS" | "YR" | "YRS" => Year(source)
case "YEAROFWEEK" => YearOfWeek(source)
case "QUARTER" | "QTR" => Quarter(source)
@@ -2734,29 +2733,8 @@ object DatePart {
case "HOUR" | "H" | "HOURS" | "HR" | "HRS" => Hour(source)
case "MINUTE" | "M" | "MIN" | "MINS" | "MINUTES" => Minute(source)
case "SECOND" | "S" | "SEC" | "SECONDS" | "SECS" =>
SecondWithFraction(source)
- case _ => errorHandleFunc
- }
-
- def toEquivalentExpr(field: Expression, source: Expression): Expression = {
- if (!field.foldable) {
- throw QueryCompilationErrors.unfoldableFieldUnsupportedError
- }
- val fieldEval = field.eval()
- if (fieldEval == null) {
- Literal(null, DoubleType)
- } else {
- val fieldStr = fieldEval.asInstanceOf[UTF8String].toString
-
- def analysisException =
- throw
QueryCompilationErrors.literalTypeUnsupportedForSourceTypeError(fieldStr,
source)
-
- source.dataType match {
- case _: AnsiIntervalType | CalendarIntervalType =>
- ExtractIntervalPart.parseExtractField(fieldStr, source,
analysisException)
- case _ =>
- DatePart.parseExtractField(fieldStr, source, analysisException)
- }
- }
+ case _ =>
+ throw
QueryCompilationErrors.literalTypeUnsupportedForSourceTypeError(extractField,
source)
}
}
@@ -2793,20 +2771,17 @@ object DatePart {
group = "datetime_funcs",
since = "3.0.0")
// scalastyle:on line.size.limit
-case class DatePart(field: Expression, source: Expression, child: Expression)
- extends RuntimeReplaceable {
-
- def this(field: Expression, source: Expression) = {
- this(field, source, DatePart.toEquivalentExpr(field, source))
+object DatePartExpressionBuilder extends ExpressionBuilder {
+ override def build(funcName: String, expressions: Seq[Expression]):
Expression = {
+ val numArgs = expressions.length
+ if (numArgs == 2) {
+ val field = expressions(0)
+ val source = expressions(1)
+ Extract(field, source, Extract.createExpr(funcName, field, source))
+ } else {
+ throw QueryCompilationErrors.invalidFunctionArgumentNumberError(Seq(2),
funcName, numArgs)
+ }
}
-
- override def flatArguments: Iterator[Any] = Iterator(field, source)
- override def exprsReplaced: Seq[Expression] = Seq(field, source)
-
- override def prettyName: String = "date_part"
-
- override protected def withNewChildInternal(newChild: Expression): DatePart =
- copy(child = newChild)
}
// scalastyle:off line.size.limit
@@ -2862,23 +2837,45 @@ case class DatePart(field: Expression, source:
Expression, child: Expression)
group = "datetime_funcs",
since = "3.0.0")
// scalastyle:on line.size.limit
-case class Extract(field: Expression, source: Expression, child: Expression)
- extends RuntimeReplaceable {
+case class Extract(field: Expression, source: Expression, replacement:
Expression)
+ extends RuntimeReplaceable with InheritAnalysisRules {
- def this(field: Expression, source: Expression) = {
- this(field, source, DatePart.toEquivalentExpr(field, source))
- }
+ def this(field: Expression, source: Expression) =
+ this(field, source, Extract.createExpr("extract", field, source))
- override def flatArguments: Iterator[Any] = Iterator(field, source)
+ override def parameters: Seq[Expression] = Seq(field, source)
- override def exprsReplaced: Seq[Expression] = Seq(field, source)
+ override def makeSQLString(childrenSQL: Seq[String]): String = {
+ getTagValue(FunctionRegistry.FUNC_ALIAS) match {
+ case Some("date_part") => s"$prettyName(${childrenSQL.mkString(", ")})"
+ case _ => s"$prettyName(${childrenSQL.mkString(" FROM ")})"
+ }
+ }
- override def mkString(childrenString: Seq[String]): String = {
- prettyName + childrenString.mkString("(", " FROM ", ")")
+ override protected def withNewChildInternal(newChild: Expression):
Expression = {
+ copy(replacement = newChild)
}
+}
- override protected def withNewChildInternal(newChild: Expression): Extract =
- copy(child = newChild)
+object Extract {
+ def createExpr(funcName: String, field: Expression, source: Expression):
Expression = {
+ // both string and null literals are allowed.
+ if ((field.dataType == StringType || field.dataType == NullType) &&
field.foldable) {
+ val fieldStr = field.eval().asInstanceOf[UTF8String]
+ if (fieldStr == null) {
+ Literal(null, DoubleType)
+ } else {
+ source.dataType match {
+ case _: AnsiIntervalType | CalendarIntervalType =>
+ ExtractIntervalPart.parseExtractField(fieldStr.toString, source)
+ case _ =>
+ DatePart.parseExtractField(fieldStr.toString, source)
+ }
+ }
+ } else {
+ throw QueryCompilationErrors.requireLiteralParameter(funcName, "field",
"string")
+ }
+ }
}
/**
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
index 5568d7c..c461b8f 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
@@ -26,7 +26,7 @@ import
org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MONTHS_PER_YEAR
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.catalyst.util.IntervalUtils._
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.errors.{QueryCompilationErrors,
QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE,
SECOND}
@@ -122,10 +122,7 @@ case class ExtractANSIIntervalSeconds(child: Expression)
object ExtractIntervalPart {
- def parseExtractField(
- extractField: String,
- source: Expression,
- errorHandleFunc: => Nothing): Expression = {
+ def parseExtractField(extractField: String, source: Expression): Expression
= {
(extractField.toUpperCase(Locale.ROOT), source.dataType) match {
case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS",
YearMonthIntervalType(start, end))
if isUnitInIntervalRange(YEAR, start, end) =>
@@ -157,7 +154,8 @@ object ExtractIntervalPart {
ExtractANSIIntervalSeconds(source)
case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", CalendarIntervalType)
=>
ExtractIntervalSeconds(source)
- case _ => errorHandleFunc
+ case _ =>
+ throw
QueryCompilationErrors.literalTypeUnsupportedForSourceTypeError(extractField,
source)
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index d34b837..f64b6ea 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -269,28 +269,32 @@ case class Ceil(child: Expression) extends
UnaryMathExpression(math.ceil, "CEIL"
override protected def withNewChildInternal(newChild: Expression): Ceil =
copy(child = newChild)
}
-trait CeilFloorExpressionBuilder extends ExpressionBuilder {
- val functionName: String
- def build(expressions: Seq[Expression]): Expression
-
- def extractChildAndScaleParam(expressions: Seq[Expression]): (Expression,
Expression) = {
- val child = expressions(0)
- val scale = expressions(1)
- if (! (scale.foldable && scale.dataType == DataTypes.IntegerType)) {
- throw QueryCompilationErrors.invalidScaleParameterRoundBase(functionName)
- }
- val scaleV = scale.eval(EmptyRow)
- if (scaleV == null) {
- throw QueryCompilationErrors.invalidScaleParameterRoundBase(functionName)
+trait CeilFloorExpressionBuilderBase extends ExpressionBuilder {
+ protected def buildWithOneParam(param: Expression): Expression
+ protected def buildWithTwoParams(param1: Expression, param2: Expression):
Expression
+
+ override def build(funcName: String, expressions: Seq[Expression]):
Expression = {
+ val numArgs = expressions.length
+ if (numArgs == 1) {
+ buildWithOneParam(expressions.head)
+ } else if (numArgs == 2) {
+ val scale = expressions(1)
+ if (!(scale.foldable && scale.dataType == IntegerType)) {
+ throw QueryCompilationErrors.requireLiteralParameter(funcName,
"scale", "int")
+ }
+ if (scale.eval() == null) {
+ throw QueryCompilationErrors.requireLiteralParameter(funcName,
"scale", "int")
+ }
+ buildWithTwoParams(expressions(0), scale)
+ } else {
+ throw QueryCompilationErrors.invalidFunctionArgumentNumberError(Seq(2),
funcName, numArgs)
}
- (child, scale)
}
}
+// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = """
- _FUNC_(expr[, scale]) - Returns the smallest number after rounding up that
is not smaller
- than `expr`. A optional `scale` parameter can be specified to control the
rounding behavior.""",
+ usage = "_FUNC_(expr[, scale]) - Returns the smallest number after rounding
up that is not smaller than `expr`. An optional `scale` parameter can be
specified to control the rounding behavior.",
examples = """
Examples:
> SELECT _FUNC_(-0.1);
@@ -304,24 +308,17 @@ trait CeilFloorExpressionBuilder extends
ExpressionBuilder {
""",
since = "3.3.0",
group = "math_funcs")
-object CeilExpressionBuilder extends CeilFloorExpressionBuilder {
- val functionName: String = "ceil"
-
- def build(expressions: Seq[Expression]): Expression = {
- if (expressions.length == 1) {
- Ceil(expressions.head)
- } else if (expressions.length == 2) {
- val (child, scale) = extractChildAndScaleParam(expressions)
- RoundCeil(child, scale)
- } else {
- throw
QueryCompilationErrors.invalidNumberOfFunctionParameters(functionName)
- }
- }
+// scalastyle:on line.size.limit
+object CeilExpressionBuilder extends CeilFloorExpressionBuilderBase {
+ override protected def buildWithOneParam(param: Expression): Expression =
Ceil(param)
+
+ override protected def buildWithTwoParams(param1: Expression, param2:
Expression): Expression =
+ RoundCeil(param1, param2)
}
case class RoundCeil(child: Expression, scale: Expression)
extends RoundBase(child, scale, BigDecimal.RoundingMode.CEILING,
"ROUND_CEILING")
- with Serializable with ImplicitCastInputTypes {
+ with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType,
IntegerType)
@@ -335,9 +332,11 @@ case class RoundCeil(child: Expression, scale: Expression)
case t => t
}
- override protected def withNewChildrenInternal(newLeft: Expression,
newRight: Expression)
- : RoundCeil = copy(child = newLeft, scale = newRight)
override def nodeName: String = "ceil"
+
+ override protected def withNewChildrenInternal(
+ newLeft: Expression, newRight: Expression): RoundCeil =
+ copy(child = newLeft, scale = newRight)
}
@ExpressionDescription(
@@ -539,10 +538,9 @@ case class Floor(child: Expression) extends
UnaryMathExpression(math.floor, "FLO
copy(child = newChild)
}
+// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = """
- _FUNC_(expr[, scale]) - Returns the largest number after rounding down that
is not greater
- than `expr`. An optional `scale` parameter can be specified to control the
rounding behavior.""",
+ usage = " _FUNC_(expr[, scale]) - Returns the largest number after rounding
down that is not greater than `expr`. An optional `scale` parameter can be
specified to control the rounding behavior.",
examples = """
Examples:
> SELECT _FUNC_(-0.1);
@@ -556,24 +554,17 @@ case class Floor(child: Expression) extends
UnaryMathExpression(math.floor, "FLO
""",
since = "3.3.0",
group = "math_funcs")
-object FloorExpressionBuilder extends CeilFloorExpressionBuilder {
- val functionName: String = "floor"
-
- def build(expressions: Seq[Expression]): Expression = {
- if (expressions.length == 1) {
- Floor(expressions.head)
- } else if (expressions.length == 2) {
- val(child, scale) = extractChildAndScaleParam(expressions)
- RoundFloor(child, scale)
- } else {
- throw
QueryCompilationErrors.invalidNumberOfFunctionParameters(functionName)
- }
- }
+// scalastyle:on line.size.limit
+object FloorExpressionBuilder extends CeilFloorExpressionBuilderBase {
+ override protected def buildWithOneParam(param: Expression): Expression =
Floor(param)
+
+ override protected def buildWithTwoParams(param1: Expression, param2:
Expression): Expression =
+ RoundFloor(param1, param2)
}
case class RoundFloor(child: Expression, scale: Expression)
extends RoundBase(child, scale, BigDecimal.RoundingMode.FLOOR, "ROUND_FLOOR")
- with Serializable with ImplicitCastInputTypes {
+ with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType,
IntegerType)
@@ -587,9 +578,11 @@ case class RoundFloor(child: Expression, scale: Expression)
case t => t
}
- override protected def withNewChildrenInternal(newLeft: Expression,
newRight: Expression)
- : RoundFloor = copy(child = newLeft, scale = newRight)
override def nodeName: String = "floor"
+
+ override protected def withNewChildrenInternal(
+ newLeft: Expression, newRight: Expression): RoundFloor =
+ copy(child = newLeft, scale = newRight)
}
object Factorial {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 941ccb7..eb21bd5 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -126,8 +126,8 @@ object RaiseError {
""",
since = "2.0.0",
group = "misc_funcs")
-case class AssertTrue(left: Expression, right: Expression, child: Expression)
- extends RuntimeReplaceable {
+case class AssertTrue(left: Expression, right: Expression, replacement:
Expression)
+ extends RuntimeReplaceable with InheritAnalysisRules {
override def prettyName: String = "assert_true"
@@ -139,11 +139,10 @@ case class AssertTrue(left: Expression, right:
Expression, child: Expression)
this(left, Literal(s"'${left.simpleString(SQLConf.get.maxToStringFields)}'
is not true!"))
}
- override def flatArguments: Iterator[Any] = Iterator(left, right)
- override def exprsReplaced: Seq[Expression] = Seq(left, right)
+ override def parameters: Seq[Expression] = Seq(left, right)
override protected def withNewChildInternal(newChild: Expression):
AssertTrue =
- copy(child = newChild)
+ copy(replacement = newChild)
}
object AssertTrue {
@@ -341,31 +340,31 @@ case class AesEncrypt(
input: Expression,
key: Expression,
mode: Expression,
- padding: Expression,
- child: Expression)
- extends RuntimeReplaceable {
-
- def this(input: Expression, key: Expression, mode: Expression, padding:
Expression) = {
- this(
- input,
- key,
- mode,
- padding,
- StaticInvoke(
- classOf[ExpressionImplUtils],
- BinaryType,
- "aesEncrypt",
- Seq(input, key, mode, padding),
- Seq(BinaryType, BinaryType, StringType, StringType)))
- }
+ padding: Expression)
+ extends RuntimeReplaceable with ImplicitCastInputTypes {
+
+ override lazy val replacement: Expression = StaticInvoke(
+ classOf[ExpressionImplUtils],
+ BinaryType,
+ "aesEncrypt",
+ Seq(input, key, mode, padding),
+ inputTypes)
+
def this(input: Expression, key: Expression, mode: Expression) =
this(input, key, mode, Literal("DEFAULT"))
def this(input: Expression, key: Expression) =
this(input, key, Literal("GCM"))
- def exprsReplaced: Seq[Expression] = Seq(input, key, mode, padding)
- protected def withNewChildInternal(newChild: Expression): AesEncrypt =
- copy(child = newChild)
+ override def prettyName: String = "aes_encrypt"
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType, BinaryType,
StringType, StringType)
+
+ override def children: Seq[Expression] = Seq(input, key, mode, padding)
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): Expression = {
+ copy(newChildren(0), newChildren(1), newChildren(2), newChildren(3))
+ }
}
/**
@@ -405,30 +404,32 @@ case class AesDecrypt(
input: Expression,
key: Expression,
mode: Expression,
- padding: Expression,
- child: Expression)
- extends RuntimeReplaceable {
-
- def this(input: Expression, key: Expression, mode: Expression, padding:
Expression) = {
- this(
- input,
- key,
- mode,
- padding,
- StaticInvoke(
- classOf[ExpressionImplUtils],
- BinaryType,
- "aesDecrypt",
- Seq(input, key, mode, padding),
- Seq(BinaryType, BinaryType, StringType, StringType)))
- }
+ padding: Expression)
+ extends RuntimeReplaceable with ImplicitCastInputTypes {
+
+ override lazy val replacement: Expression = StaticInvoke(
+ classOf[ExpressionImplUtils],
+ BinaryType,
+ "aesDecrypt",
+ Seq(input, key, mode, padding),
+ inputTypes)
+
def this(input: Expression, key: Expression, mode: Expression) =
this(input, key, mode, Literal("DEFAULT"))
def this(input: Expression, key: Expression) =
this(input, key, Literal("GCM"))
- def exprsReplaced: Seq[Expression] = Seq(input, key)
- protected def withNewChildInternal(newChild: Expression): AesDecrypt =
- copy(child = newChild)
+ override def inputTypes: Seq[AbstractDataType] = {
+ Seq(BinaryType, BinaryType, StringType, StringType)
+ }
+
+ override def prettyName: String = "aes_decrypt"
+
+ override def children: Seq[Expression] = Seq(input, key, mode, padding)
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): Expression = {
+ copy(newChildren(0), newChildren(1), newChildren(2), newChildren(3))
+ }
}
// scalastyle:on line.size.limit
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index a15126a..3c6a9b8 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -130,29 +130,6 @@ case class Coalesce(children: Seq[Expression]) extends
ComplexTypeMergingExpress
@ExpressionDescription(
- usage = "_FUNC_(expr1, expr2) - Returns `expr2` if `expr1` is null, or
`expr1` otherwise.",
- examples = """
- Examples:
- > SELECT _FUNC_(NULL, array('2'));
- ["2"]
- """,
- since = "2.0.0",
- group = "conditional_funcs")
-case class IfNull(left: Expression, right: Expression, child: Expression)
- extends RuntimeReplaceable {
-
- def this(left: Expression, right: Expression) = {
- this(left, right, Coalesce(Seq(left, right)))
- }
-
- override def flatArguments: Iterator[Any] = Iterator(left, right)
- override def exprsReplaced: Seq[Expression] = Seq(left, right)
-
- override protected def withNewChildInternal(newChild: Expression): IfNull =
copy(child = newChild)
-}
-
-
-@ExpressionDescription(
usage = "_FUNC_(expr1, expr2) - Returns null if `expr1` equals to `expr2`,
or `expr1` otherwise.",
examples = """
Examples:
@@ -161,17 +138,18 @@ case class IfNull(left: Expression, right: Expression,
child: Expression)
""",
since = "2.0.0",
group = "conditional_funcs")
-case class NullIf(left: Expression, right: Expression, child: Expression)
- extends RuntimeReplaceable {
+case class NullIf(left: Expression, right: Expression, replacement: Expression)
+ extends RuntimeReplaceable with InheritAnalysisRules {
def this(left: Expression, right: Expression) = {
this(left, right, If(EqualTo(left, right), Literal.create(null,
left.dataType), left))
}
- override def flatArguments: Iterator[Any] = Iterator(left, right)
- override def exprsReplaced: Seq[Expression] = Seq(left, right)
+ override def parameters: Seq[Expression] = Seq(left, right)
- override protected def withNewChildInternal(newChild: Expression): NullIf =
copy(child = newChild)
+ override protected def withNewChildInternal(newChild: Expression): NullIf = {
+ copy(replacement = newChild)
+ }
}
@@ -184,16 +162,17 @@ case class NullIf(left: Expression, right: Expression,
child: Expression)
""",
since = "2.0.0",
group = "conditional_funcs")
-case class Nvl(left: Expression, right: Expression, child: Expression) extends
RuntimeReplaceable {
+case class Nvl(left: Expression, right: Expression, replacement: Expression)
+ extends RuntimeReplaceable with InheritAnalysisRules {
def this(left: Expression, right: Expression) = {
this(left, right, Coalesce(Seq(left, right)))
}
- override def flatArguments: Iterator[Any] = Iterator(left, right)
- override def exprsReplaced: Seq[Expression] = Seq(left, right)
+ override def parameters: Seq[Expression] = Seq(left, right)
- override protected def withNewChildInternal(newChild: Expression): Nvl =
copy(child = newChild)
+ override protected def withNewChildInternal(newChild: Expression): Nvl =
+ copy(replacement = newChild)
}
@@ -208,17 +187,18 @@ case class Nvl(left: Expression, right: Expression,
child: Expression) extends R
since = "2.0.0",
group = "conditional_funcs")
// scalastyle:on line.size.limit
-case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression,
child: Expression)
- extends RuntimeReplaceable {
+case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression,
replacement: Expression)
+ extends RuntimeReplaceable with InheritAnalysisRules {
def this(expr1: Expression, expr2: Expression, expr3: Expression) = {
this(expr1, expr2, expr3, If(IsNotNull(expr1), expr2, expr3))
}
- override def flatArguments: Iterator[Any] = Iterator(expr1, expr2, expr3)
- override def exprsReplaced: Seq[Expression] = Seq(expr1, expr2, expr3)
+ override def parameters: Seq[Expression] = Seq(expr1, expr2, expr3)
- override protected def withNewChildInternal(newChild: Expression): Nvl2 =
copy(child = newChild)
+ override protected def withNewChildInternal(newChild: Expression): Nvl2 = {
+ copy(replacement = newChild)
+ }
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 889c53b..368cbfd 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure,
TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY,
TreePattern}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
@@ -240,18 +241,20 @@ case class Like(left: Expression, right: Expression,
escapeChar: Char)
case class ILike(
left: Expression,
right: Expression,
- escapeChar: Char,
- child: Expression) extends RuntimeReplaceable {
- def this(left: Expression, right: Expression, escapeChar: Char) =
- this(left, right, escapeChar, Like(Lower(left), Lower(right), escapeChar))
+ escapeChar: Char) extends RuntimeReplaceable
+ with ImplicitCastInputTypes with BinaryLike[Expression] {
+
+ override lazy val replacement: Expression = Like(Lower(left), Lower(right),
escapeChar)
+
def this(left: Expression, right: Expression) =
this(left, right, '\\')
- override def exprsReplaced: Seq[Expression] = Seq(left, right)
- override def flatArguments: Iterator[Any] = Iterator(left, right, escapeChar)
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
- override protected def withNewChildInternal(newChild: Expression): ILike =
- copy(child = newChild)
+ override protected def withNewChildrenInternal(
+ newLeft: Expression, newRight: Expression): Expression = {
+ copy(left = newLeft, right = newRight)
+ }
}
sealed abstract class MultiLikeBase
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 56cd224..021ddbe 100755
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -30,6 +30,7 @@ import
org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegist
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
+import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern,
UPPER_OR_LOWER}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData,
TypeUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors,
QueryExecutionErrors}
@@ -1047,8 +1048,8 @@ case class StringTrim(srcStr: Expression, trimStr:
Option[Expression] = None)
""",
since = "3.2.0",
group = "string_funcs")
-case class StringTrimBoth(srcStr: Expression, trimStr: Option[Expression],
child: Expression)
- extends RuntimeReplaceable {
+case class StringTrimBoth(srcStr: Expression, trimStr: Option[Expression],
replacement: Expression)
+ extends RuntimeReplaceable with InheritAnalysisRules {
def this(srcStr: Expression, trimStr: Expression) = {
this(srcStr, Option(trimStr), StringTrim(srcStr, trimStr))
@@ -1058,13 +1059,12 @@ case class StringTrimBoth(srcStr: Expression, trimStr:
Option[Expression], child
this(srcStr, None, StringTrim(srcStr))
}
- override def exprsReplaced: Seq[Expression] = srcStr +: trimStr.toSeq
- override def flatArguments: Iterator[Any] = Iterator(srcStr, trimStr)
-
override def prettyName: String = "btrim"
+ override def parameters: Seq[Expression] = srcStr +: trimStr.toSeq
+
override protected def withNewChildInternal(newChild: Expression):
StringTrimBoth =
- copy(child = newChild)
+ copy(replacement = newChild)
}
object StringTrimLeft {
@@ -1376,17 +1376,17 @@ case class StringLocate(substr: Expression, str:
Expression, start: Expression)
}
trait PadExpressionBuilderBase extends ExpressionBuilder {
- override def build(expressions: Seq[Expression]): Expression = {
+ override def build(funcName: String, expressions: Seq[Expression]):
Expression = {
val numArgs = expressions.length
if (numArgs == 2) {
if (expressions(0).dataType == BinaryType) {
- createBinaryPad(expressions(0), expressions(1),
Literal(Array[Byte](0)))
+ BinaryPad(funcName, expressions(0), expressions(1),
Literal(Array[Byte](0)))
} else {
createStringPad(expressions(0), expressions(1), Literal(" "))
}
} else if (numArgs == 3) {
if (expressions(0).dataType == BinaryType && expressions(2).dataType ==
BinaryType) {
- createBinaryPad(expressions(0), expressions(1), expressions(2))
+ BinaryPad(funcName, expressions(0), expressions(1), expressions(2))
} else {
createStringPad(expressions(0), expressions(1), expressions(2))
}
@@ -1395,8 +1395,6 @@ trait PadExpressionBuilderBase extends ExpressionBuilder {
}
}
- protected def funcName: String
- protected def createBinaryPad(str: Expression, len: Expression, pad:
Expression): Expression
protected def createStringPad(str: Expression, len: Expression, pad:
Expression): Expression
}
@@ -1423,10 +1421,6 @@ trait PadExpressionBuilderBase extends ExpressionBuilder
{
since = "1.5.0",
group = "string_funcs")
object LPadExpressionBuilder extends PadExpressionBuilderBase {
- override def funcName: String = "lpad"
- override def createBinaryPad(str: Expression, len: Expression, pad:
Expression): Expression = {
- new BinaryLPad(str, len, pad)
- }
override def createStringPad(str: Expression, len: Expression, pad:
Expression): Expression = {
StringLPad(str, len, pad)
}
@@ -1459,21 +1453,28 @@ case class StringLPad(str: Expression, len: Expression,
pad: Expression)
copy(str = newFirst, len = newSecond, pad = newThird)
}
-case class BinaryLPad(str: Expression, len: Expression, pad: Expression,
child: Expression)
- extends RuntimeReplaceable {
+case class BinaryPad(funcName: String, str: Expression, len: Expression, pad:
Expression)
+ extends RuntimeReplaceable with ImplicitCastInputTypes {
+ assert(funcName == "lpad" || funcName == "rpad")
- def this(str: Expression, len: Expression, pad: Expression) = this(str, len,
pad, StaticInvoke(
+ override lazy val replacement: Expression = StaticInvoke(
classOf[ByteArray],
BinaryType,
- "lpad",
+ funcName,
Seq(str, len, pad),
- Seq(BinaryType, IntegerType, BinaryType),
+ inputTypes,
returnNullable = false)
- )
- override def prettyName: String = "lpad"
- def exprsReplaced: Seq[Expression] = Seq(str, len, pad)
- protected def withNewChildInternal(newChild: Expression): BinaryLPad =
copy(child = newChild)
+ override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType,
IntegerType, BinaryType)
+
+ override def nodeName: String = funcName
+
+ override def children: Seq[Expression] = Seq(str, len, pad)
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): Expression = {
+ copy(str = newChildren(0), len = newChildren(1), pad = newChildren(2))
+ }
}
@ExpressionDescription(
@@ -1499,10 +1500,6 @@ case class BinaryLPad(str: Expression, len: Expression,
pad: Expression, child:
since = "1.5.0",
group = "string_funcs")
object RPadExpressionBuilder extends PadExpressionBuilderBase {
- override def funcName: String = "rpad"
- override def createBinaryPad(str: Expression, len: Expression, pad:
Expression): Expression = {
- new BinaryRPad(str, len, pad)
- }
override def createStringPad(str: Expression, len: Expression, pad:
Expression): Expression = {
StringRPad(str, len, pad)
}
@@ -1535,23 +1532,6 @@ case class StringRPad(str: Expression, len: Expression,
pad: Expression = Litera
copy(str = newFirst, len = newSecond, pad = newThird)
}
-case class BinaryRPad(str: Expression, len: Expression, pad: Expression,
child: Expression)
- extends RuntimeReplaceable {
-
- def this(str: Expression, len: Expression, pad: Expression) = this(str, len,
pad, StaticInvoke(
- classOf[ByteArray],
- BinaryType,
- "rpad",
- Seq(str, len, pad),
- Seq(BinaryType, IntegerType, BinaryType),
- returnNullable = false)
- )
-
- override def prettyName: String = "rpad"
- def exprsReplaced: Seq[Expression] = Seq(str, len, pad)
- protected def withNewChildInternal(newChild: Expression): BinaryRPad =
copy(child = newChild)
-}
-
object ParseUrl {
private val HOST = UTF8String.fromString("HOST")
private val PATH = UTF8String.fromString("PATH")
@@ -2025,16 +2005,26 @@ case class Substring(str: Expression, pos: Expression,
len: Expression)
since = "2.3.0",
group = "string_funcs")
// scalastyle:on line.size.limit
-case class Right(str: Expression, len: Expression, child: Expression) extends
RuntimeReplaceable {
- def this(str: Expression, len: Expression) = {
- this(str, len, If(IsNull(str), Literal(null, StringType),
If(LessThanOrEqual(len, Literal(0)),
- Literal(UTF8String.EMPTY_UTF8, StringType), new Substring(str,
UnaryMinus(len)))))
- }
-
- override def flatArguments: Iterator[Any] = Iterator(str, len)
- override def exprsReplaced: Seq[Expression] = Seq(str, len)
+case class Right(str: Expression, len: Expression) extends RuntimeReplaceable
+ with ImplicitCastInputTypes with BinaryLike[Expression] {
+
+ override lazy val replacement: Expression = If(
+ IsNull(str),
+ Literal(null, StringType),
+ If(
+ LessThanOrEqual(len, Literal(0)),
+ Literal(UTF8String.EMPTY_UTF8, StringType),
+ new Substring(str, UnaryMinus(len))
+ )
+ )
- override protected def withNewChildInternal(newChild: Expression): Right =
copy(child = newChild)
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType)
+ override def left: Expression = str
+ override def right: Expression = len
+ override protected def withNewChildrenInternal(
+ newLeft: Expression, newRight: Expression): Expression = {
+ copy(str = newLeft, len = newRight)
+ }
}
/**
@@ -2051,14 +2041,21 @@ case class Right(str: Expression, len: Expression,
child: Expression) extends Ru
since = "2.3.0",
group = "string_funcs")
// scalastyle:on line.size.limit
-case class Left(str: Expression, len: Expression, child: Expression) extends
RuntimeReplaceable {
- def this(str: Expression, len: Expression) = {
- this(str, len, Substring(str, Literal(1), len))
+case class Left(str: Expression, len: Expression) extends RuntimeReplaceable
+ with ImplicitCastInputTypes with BinaryLike[Expression] {
+
+ override lazy val replacement: Expression = Substring(str, Literal(1), len)
+
+ override def inputTypes: Seq[AbstractDataType] = {
+ Seq(TypeCollection(StringType, BinaryType), IntegerType)
}
- override def flatArguments: Iterator[Any] = Iterator(str, len)
- override def exprsReplaced: Seq[Expression] = Seq(str, len)
- override protected def withNewChildInternal(newChild: Expression): Left =
copy(child = newChild)
+ override def left: Expression = str
+ override def right: Expression = len
+ override protected def withNewChildrenInternal(
+ newLeft: Expression, newRight: Expression): Expression = {
+ copy(str = newLeft, len = newRight)
+ }
}
/**
@@ -2438,16 +2435,16 @@ object Decode {
since = "3.2.0",
group = "string_funcs")
// scalastyle:on line.size.limit
-case class Decode(params: Seq[Expression], child: Expression) extends
RuntimeReplaceable {
+case class Decode(params: Seq[Expression], replacement: Expression)
+ extends RuntimeReplaceable with InheritAnalysisRules {
- def this(params: Seq[Expression]) = {
- this(params, Decode.createExpr(params))
- }
+ def this(params: Seq[Expression]) = this(params, Decode.createExpr(params))
- override def flatArguments: Iterator[Any] = Iterator(params)
- override def exprsReplaced: Seq[Expression] = params
+ override def parameters: Seq[Expression] = params
- override protected def withNewChildInternal(newChild: Expression): Decode =
copy(child = newChild)
+ override protected def withNewChildInternal(newChild: Expression):
Expression = {
+ copy(replacement = newChild)
+ }
}
/**
@@ -2557,56 +2554,52 @@ case class Encode(value: Expression, charset:
Expression)
since = "3.3.0",
group = "string_funcs")
// scalastyle:on line.size.limit
-case class ToBinary(expr: Expression, format: Option[Expression], child:
Expression)
- extends RuntimeReplaceable {
-
- def this(expr: Expression, format: Expression) = this(expr, Option(format),
- format match {
- case lit if (lit.foldable && Seq(StringType,
NullType).contains(lit.dataType)) =>
- val value = lit.eval()
- if (value == null) Literal(null, BinaryType)
- else {
- value.asInstanceOf[UTF8String].toString.toLowerCase(Locale.ROOT)
match {
- case "hex" => Unhex(expr)
- case "utf-8" => Encode(expr, Literal("UTF-8"))
- case "base64" => UnBase64(expr)
- case _ => lit
- }
- }
-
- case other => other
+case class ToBinary(expr: Expression, format: Option[Expression]) extends
RuntimeReplaceable
+ with ImplicitCastInputTypes {
+
+ override lazy val replacement: Expression = format.map { f =>
+ assert(f.foldable && (f.dataType == StringType || f.dataType == NullType))
+ val value = f.eval()
+ if (value == null) {
+ Literal(null, BinaryType)
+ } else {
+ value.asInstanceOf[UTF8String].toString.toLowerCase(Locale.ROOT) match {
+ case "hex" => Unhex(expr)
+ case "utf-8" => Encode(expr, Literal("UTF-8"))
+ case "base64" => UnBase64(expr)
+ case other => throw
QueryCompilationErrors.invalidStringLiteralParameter(
+ "to_binary", "format", other,
+ Some("The value has to be a case-insensitive string literal of " +
+ "'hex', 'utf-8', or 'base64'."))
+ }
}
- )
+ }.getOrElse(Unhex(expr))
- def this(expr: Expression) = this(expr, None, Unhex(expr))
+ def this(expr: Expression) = this(expr, None)
- override def flatArguments: Iterator[Any] = Iterator(expr, format)
- override def exprsReplaced: Seq[Expression] = expr +: format.toSeq
+ def this(expr: Expression, format: Expression) = this(expr, Some({
+ // We perform this check in the constructor to make it eager and not go
through type coercion.
+ if (format.foldable && (format.dataType == StringType || format.dataType
== NullType)) {
+ format
+ } else {
+ throw QueryCompilationErrors.requireLiteralParameter("to_binary",
"format", "string")
+ }
+ }))
override def prettyName: String = "to_binary"
- override def dataType: DataType = BinaryType
- override def checkInputDataTypes(): TypeCheckResult = {
- def checkFormat(lit: Expression) = {
- if (lit.foldable && Seq(StringType, NullType).contains(lit.dataType)) {
- val value = lit.eval()
- value == null ||
- Seq("hex", "utf-8", "base64").contains(
- value.asInstanceOf[UTF8String].toString.toLowerCase(Locale.ROOT))
- } else false
- }
+ override def children: Seq[Expression] = expr +: format.toSeq
+
+ override def inputTypes: Seq[AbstractDataType] = children.map(_ =>
StringType)
- if (format.forall(checkFormat)) {
- super.checkInputDataTypes()
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): Expression = {
+ if (format.isDefined) {
+ copy(expr = newChildren.head, format = Some(newChildren.last))
} else {
- TypeCheckResult.TypeCheckFailure(
- s"Unsupported encoding format: $format. The format has to be " +
- s"a case-insensitive string literal of 'hex', 'utf-8', or 'base64'")
+ copy(expr = newChildren.head)
}
}
-
- override protected def withNewChildInternal(newChild: Expression): ToBinary =
- copy(child = newChild)
}
/**
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
index 645ff6b..7b896e2 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -21,7 +21,6 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreePattern._
@@ -32,26 +31,18 @@ import org.apache.spark.util.Utils
/**
- * Finds all the expressions that are unevaluable and replace/rewrite them
with semantically
- * equivalent expressions that can be evaluated. Currently we replace two
kinds of expressions:
- * 1) [[RuntimeReplaceable]] expressions
- * 2) [[UnevaluableAggregate]] expressions such as Every, Some, Any, CountIf
+ * Finds all the [[RuntimeReplaceable]] expressions that are unevaluable and
replaces them
+ * with semantically equivalent expressions that can be evaluated.
+ *
* This is mainly used to provide compatibility with other databases.
* Few examples are:
- * we use this to support "nvl" by replacing it with "coalesce".
+ * we use this to support "left" by replacing it with "substring".
* we use this to replace Every and Any with Min and Max respectively.
- *
- * TODO: In future, explore an option to replace aggregate functions similar to
- * how RuntimeReplaceable does.
*/
object ReplaceExpressions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan =
plan.transformAllExpressionsWithPruning(
- _.containsAnyPattern(RUNTIME_REPLACEABLE, COUNT_IF, BOOL_AGG, REGR_COUNT))
{
- case e: RuntimeReplaceable => e.child
- case CountIf(predicate) => Count(new NullIf(predicate,
Literal.FalseLiteral))
- case BoolOr(arg) => Max(arg)
- case BoolAnd(arg) => Min(arg)
- case RegrCount(left, right) => Count(Seq(left, right))
+ _.containsAnyPattern(RUNTIME_REPLACEABLE)) {
+ case e: RuntimeReplaceable => e.replacement
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 08f2cb9..257df58 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1670,7 +1670,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with
SQLConfHelper with Logg
str.charAt(0)
}.getOrElse('\\')
val likeExpr = ctx.kind.getType match {
- case SqlBaseParser.ILIKE => new ILike(e,
expression(ctx.pattern), escapeChar)
+ case SqlBaseParser.ILIKE => ILike(e, expression(ctx.pattern),
escapeChar)
case _ => Like(e, expression(ctx.pattern), escapeChar)
}
invertIfNotDefined(likeExpr)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 8db2f55..b595966 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -33,13 +33,11 @@ object TreePattern extends Enumeration {
val GROUPING_ANALYTICS: Value = Value
val BINARY_ARITHMETIC: Value = Value
val BINARY_COMPARISON: Value = Value
- val BOOL_AGG: Value = Value
val CASE_WHEN: Value = Value
val CAST: Value = Value
val COALESCE: Value = Value
val CONCAT: Value = Value
val COUNT: Value = Value
- val COUNT_IF: Value = Value
val CREATE_NAMED_STRUCT: Value = Value
val CURRENT_LIKE: Value = Value
val DESERIALIZE_TO_OBJECT: Value = Value
@@ -74,7 +72,6 @@ object TreePattern extends Enumeration {
val PIVOT: Value = Value
val PLAN_EXPRESSION: Value = Value
val PYTHON_UDF: Value = Value
- val REGR_COUNT: Value = Value
val RUNTIME_REPLACEABLE: Value = Value
val SCALAR_SUBQUERY: Value = Value
val SCALA_UDF: Value = Value
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
index e26f397..ed9dc03 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -135,8 +135,8 @@ package object util extends Logging {
PrettyAttribute(usePrettyExpression(e.child).sql + "." + name,
e.dataType)
case e: GetArrayStructFields =>
PrettyAttribute(usePrettyExpression(e.child) + "." + e.field.name,
e.dataType)
- case r: RuntimeReplaceable =>
- PrettyAttribute(r.mkString(r.exprsReplaced.map(toPrettySQL)), r.dataType)
+ case r: InheritAnalysisRules =>
+ PrettyAttribute(r.makeSQLString(r.parameters.map(toPrettySQL)),
r.dataType)
case c: CastBase if
!c.getTagValue(Cast.USER_SPECIFIED_CAST).getOrElse(false) =>
PrettyAttribute(usePrettyExpression(c.child).sql, c.dataType)
case p: PythonUDF => PrettyPythonUDF(p.name, p.dataType, p.children)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 28be81d..880c28d 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -721,8 +721,20 @@ object QueryCompilationErrors {
s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}.")
}
- def unfoldableFieldUnsupportedError(): Throwable = {
- new AnalysisException("The field parameter needs to be a foldable string
value.")
+ def requireLiteralParameter(
+ funcName: String, argName: String, requiredType: String): Throwable = {
+ new AnalysisException(
+ s"The '$argName' parameter of function '$funcName' needs to be a
$requiredType literal.")
+ }
+
+ def invalidStringLiteralParameter(
+ funcName: String,
+ argName: String,
+ invalidValue: String,
+ allowedValues: Option[String] = None): Throwable = {
+ val endingMsg = allowedValues.map(" " + _).getOrElse("")
+ new AnalysisException(s"Invalid value for the '$argName' parameter of
function '$funcName': " +
+ s"$invalidValue.$endingMsg")
}
def literalTypeUnsupportedForSourceTypeError(field: String, source:
Expression): Throwable = {
@@ -2375,12 +2387,4 @@ object QueryCompilationErrors {
new AnalysisException(
"Sinks cannot request distribution and ordering in continuous execution
mode")
}
-
- def invalidScaleParameterRoundBase(function: String): Throwable = {
- new AnalysisException(s"The 'scale' parameter of function '$function' must
be an int constant.")
- }
-
- def invalidNumberOfFunctionParameters(function: String): Throwable = {
- new AnalysisException(s"Invalid number of parameters to the function
'$function'.")
- }
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index d1db017..fcf9a6b 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection.Schema
import org.apache.spark.sql.catalyst.WalkedTypePath
import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference,
Expression, UnevaluableAggregate}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference,
Expression}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical.{DomainJoin, LogicalPlan}
@@ -126,12 +126,6 @@ object QueryExecutionErrors {
messageParameters = Array.empty)
}
- def evaluateUnevaluableAggregateUnsupportedError(
- methodName: String, unEvaluable: UnevaluableAggregate): Throwable = {
- new SparkUnsupportedOperationException(errorClass = "INTERNAL_ERROR",
- messageParameters = Array(s"Cannot evaluate expression: $methodName:
$unEvaluable"))
- }
-
def dataTypeUnsupportedError(dataType: String, failure: String): Throwable =
{
new SparkIllegalArgumentException(errorClass = "UNSUPPORTED_DATATYPE",
messageParameters = Array(dataType + failure))
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index 31d7da3..84603ee 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -1481,8 +1481,8 @@ class DateExpressionsSuite extends SparkFunSuite with
ExpressionEvalHelper {
test("Consistent error handling for datetime formatting and parsing
functions") {
def checkException[T <: Exception : ClassTag](c: String): Unit = {
- checkExceptionInExpression[T](new ParseToTimestamp(Literal("1"),
Literal(c)).child, c)
- checkExceptionInExpression[T](new ParseToDate(Literal("1"),
Literal(c)).child, c)
+ checkExceptionInExpression[T](new ParseToTimestamp(Literal("1"),
Literal(c)).replacement, c)
+ checkExceptionInExpression[T](new ParseToDate(Literal("1"),
Literal(c)).replacement, c)
checkExceptionInExpression[T](ToUnixTimestamp(Literal("1"), Literal(c)),
c)
checkExceptionInExpression[T](UnixTimestamp(Literal("1"), Literal(c)), c)
if (!Set("E", "F", "q", "Q").contains(c)) {
@@ -1502,10 +1502,10 @@ class DateExpressionsSuite extends SparkFunSuite with
ExpressionEvalHelper {
test("SPARK-31896: Handle am-pm timestamp parsing when hour is missing") {
checkEvaluation(
- new ParseToTimestamp(Literal("PM"), Literal("a")).child,
+ new ParseToTimestamp(Literal("PM"), Literal("a")).replacement,
Timestamp.valueOf("1970-01-01 12:00:00.0"))
checkEvaluation(
- new ParseToTimestamp(Literal("11:11 PM"), Literal("mm:ss a")).child,
+ new ParseToTimestamp(Literal("11:11 PM"), Literal("mm:ss
a")).replacement,
Timestamp.valueOf("1970-01-01 12:11:11.0"))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index ea410a6..58e855e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2772,7 +2772,7 @@ object functions {
* @since 3.3.0
*/
def lpad(str: Column, len: Int, pad: Array[Byte]): Column = withExpr {
- new BinaryLPad(str.expr, lit(len).expr, lit(pad).expr)
+ BinaryPad("lpad", str.expr, lit(len).expr, lit(pad).expr)
}
/**
@@ -2861,7 +2861,7 @@ object functions {
* @since 3.3.0
*/
def rpad(str: Column, len: Int, pad: Array[Byte]): Column = withExpr {
- new BinaryRPad(str.expr, lit(len).expr, lit(pad).expr)
+ BinaryPad("rpad", str.expr, lit(len).expr, lit(pad).expr)
}
/**
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 8b1a12f..a817440 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -11,8 +11,8 @@
| org.apache.spark.sql.catalyst.expressions.Acosh | acosh | SELECT acosh(1) |
struct<ACOSH(1):double> |
| org.apache.spark.sql.catalyst.expressions.Add | + | SELECT 1 + 2 | struct<(1
+ 2):int> |
| org.apache.spark.sql.catalyst.expressions.AddMonths | add_months | SELECT
add_months('2016-08-31', 1) | struct<add_months(2016-08-31, 1):date> |
-| org.apache.spark.sql.catalyst.expressions.AesDecrypt | aes_decrypt | SELECT
aes_decrypt(unhex('83F16B2AA704794132802D248E6BFD4E380078182D1544813898AC97E709B28A94'),
'0000111122223333') |
struct<aesdecrypt(unhex(83F16B2AA704794132802D248E6BFD4E380078182D1544813898AC97E709B28A94),
0000111122223333):binary> |
-| org.apache.spark.sql.catalyst.expressions.AesEncrypt | aes_encrypt | SELECT
hex(aes_encrypt('Spark', '0000111122223333')) | struct<hex(aesencrypt(Spark,
0000111122223333, GCM, DEFAULT)):string> |
+| org.apache.spark.sql.catalyst.expressions.AesDecrypt | aes_decrypt | SELECT
aes_decrypt(unhex('83F16B2AA704794132802D248E6BFD4E380078182D1544813898AC97E709B28A94'),
'0000111122223333') |
struct<aes_decrypt(unhex(83F16B2AA704794132802D248E6BFD4E380078182D1544813898AC97E709B28A94),
0000111122223333, GCM, DEFAULT):binary> |
+| org.apache.spark.sql.catalyst.expressions.AesEncrypt | aes_encrypt | SELECT
hex(aes_encrypt('Spark', '0000111122223333')) | struct<hex(aes_encrypt(Spark,
0000111122223333, GCM, DEFAULT)):string> |
| org.apache.spark.sql.catalyst.expressions.And | and | SELECT true and true |
struct<(true AND true):boolean> |
| org.apache.spark.sql.catalyst.expressions.ArrayAggregate | aggregate |
SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) |
struct<aggregate(array(1, 2, 3), 0, lambdafunction((namedlambdavariable() +
namedlambdavariable()), namedlambdavariable(), namedlambdavariable()),
lambdafunction(namedlambdavariable(), namedlambdavariable())):int> |
| org.apache.spark.sql.catalyst.expressions.ArrayContains | array_contains |
SELECT array_contains(array(1, 2, 3), 2) | struct<array_contains(array(1, 2,
3), 2):boolean> |
@@ -99,7 +99,7 @@
| org.apache.spark.sql.catalyst.expressions.DateDiff | datediff | SELECT
datediff('2009-07-31', '2009-07-30') | struct<datediff(2009-07-31,
2009-07-30):int> |
| org.apache.spark.sql.catalyst.expressions.DateFormatClass | date_format |
SELECT date_format('2016-04-08', 'y') | struct<date_format(2016-04-08,
y):string> |
| org.apache.spark.sql.catalyst.expressions.DateFromUnixDate |
date_from_unix_date | SELECT date_from_unix_date(1) |
struct<date_from_unix_date(1):date> |
-| org.apache.spark.sql.catalyst.expressions.DatePart | date_part | SELECT
date_part('YEAR', TIMESTAMP '2019-08-12 01:00:00.123456') |
struct<date_part(YEAR, TIMESTAMP '2019-08-12 01:00:00.123456'):int> |
+| org.apache.spark.sql.catalyst.expressions.DatePartExpressionBuilder$ |
date_part | SELECT date_part('YEAR', TIMESTAMP '2019-08-12 01:00:00.123456') |
struct<date_part(YEAR, TIMESTAMP '2019-08-12 01:00:00.123456'):int> |
| org.apache.spark.sql.catalyst.expressions.DateSub | date_sub | SELECT
date_sub('2016-07-30', 1) | struct<date_sub(2016-07-30, 1):date> |
| org.apache.spark.sql.catalyst.expressions.DayOfMonth | day | SELECT
day('2009-07-30') | struct<day(2009-07-30):int> |
| org.apache.spark.sql.catalyst.expressions.DayOfMonth | dayofmonth | SELECT
dayofmonth('2009-07-30') | struct<dayofmonth(2009-07-30):int> |
@@ -120,7 +120,7 @@
| org.apache.spark.sql.catalyst.expressions.Explode | explode | SELECT
explode(array(10, 20)) | struct<col:int> |
| org.apache.spark.sql.catalyst.expressions.Explode | explode_outer | SELECT
explode_outer(array(10, 20)) | struct<col:int> |
| org.apache.spark.sql.catalyst.expressions.Expm1 | expm1 | SELECT expm1(0) |
struct<EXPM1(0):double> |
-| org.apache.spark.sql.catalyst.expressions.Extract | extract | SELECT
extract(YEAR FROM TIMESTAMP '2019-08-12 01:00:00.123456') | struct<extract(YEAR
FROM TIMESTAMP '2019-08-12 01:00:00.123456'):int> |
+| org.apache.spark.sql.catalyst.expressions.ExtractExpressionBuilder$ |
extract | SELECT extract(YEAR FROM TIMESTAMP '2019-08-12 01:00:00.123456') |
struct<extract(YEAR FROM TIMESTAMP '2019-08-12 01:00:00.123456'):int> |
| org.apache.spark.sql.catalyst.expressions.Factorial | factorial | SELECT
factorial(5) | struct<factorial(5):bigint> |
| org.apache.spark.sql.catalyst.expressions.FindInSet | find_in_set | SELECT
find_in_set('ab','abc,b,ab,c,def') | struct<find_in_set(ab,
abc,b,ab,c,def):int> |
| org.apache.spark.sql.catalyst.expressions.Flatten | flatten | SELECT
flatten(array(array(1, 2), array(3, 4))) | struct<flatten(array(array(1, 2),
array(3, 4))):array<int>> |
@@ -141,7 +141,6 @@
| org.apache.spark.sql.catalyst.expressions.Hypot | hypot | SELECT hypot(3, 4)
| struct<HYPOT(3, 4):double> |
| org.apache.spark.sql.catalyst.expressions.ILike | ilike | SELECT
ilike('Spark', '_Park') | struct<ilike(Spark, _Park):boolean> |
| org.apache.spark.sql.catalyst.expressions.If | if | SELECT if(1 < 2, 'a',
'b') | struct<(IF((1 < 2), a, b)):string> |
-| org.apache.spark.sql.catalyst.expressions.IfNull | ifnull | SELECT
ifnull(NULL, array('2')) | struct<ifnull(NULL, array(2)):array<string>> |
| org.apache.spark.sql.catalyst.expressions.In | in | SELECT 1 in(1, 2, 3) |
struct<(1 IN (1, 2, 3)):boolean> |
| org.apache.spark.sql.catalyst.expressions.InitCap | initcap | SELECT
initcap('sPark sql') | struct<initcap(sPark sql):string> |
| org.apache.spark.sql.catalyst.expressions.Inline | inline | SELECT
inline(array(struct(1, 'a'), struct(2, 'b'))) | struct<col1:int,col2:string> |
@@ -182,8 +181,8 @@
| org.apache.spark.sql.catalyst.expressions.MakeDate | make_date | SELECT
make_date(2013, 7, 15) | struct<make_date(2013, 7, 15):date> |
| org.apache.spark.sql.catalyst.expressions.MakeInterval | make_interval |
SELECT make_interval(100, 11, 1, 1, 12, 30, 01.001001) |
struct<make_interval(100, 11, 1, 1, 12, 30, 1.001001):interval> |
| org.apache.spark.sql.catalyst.expressions.MakeTimestamp | make_timestamp |
SELECT make_timestamp(2014, 12, 28, 6, 30, 45.887) |
struct<make_timestamp(2014, 12, 28, 6, 30, 45.887):timestamp> |
-| org.apache.spark.sql.catalyst.expressions.MakeTimestampLTZ |
make_timestamp_ltz | SELECT make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887) |
struct<make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887):timestamp> |
-| org.apache.spark.sql.catalyst.expressions.MakeTimestampNTZ |
make_timestamp_ntz | SELECT make_timestamp_ntz(2014, 12, 28, 6, 30, 45.887) |
struct<make_timestamp_ntz(2014, 12, 28, 6, 30, 45.887):timestamp_ntz> |
+| org.apache.spark.sql.catalyst.expressions.MakeTimestampLTZExpressionBuilder$
| make_timestamp_ltz | SELECT make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887) |
struct<make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887):timestamp> |
+| org.apache.spark.sql.catalyst.expressions.MakeTimestampNTZExpressionBuilder$
| make_timestamp_ntz | SELECT make_timestamp_ntz(2014, 12, 28, 6, 30, 45.887) |
struct<make_timestamp_ntz(2014, 12, 28, 6, 30, 45.887):timestamp_ntz> |
| org.apache.spark.sql.catalyst.expressions.MakeYMInterval | make_ym_interval
| SELECT make_ym_interval(1, 2) | struct<make_ym_interval(1, 2):interval year
to month> |
| org.apache.spark.sql.catalyst.expressions.MapConcat | map_concat | SELECT
map_concat(map(1, 'a', 2, 'b'), map(3, 'c')) | struct<map_concat(map(1, a, 2,
b), map(3, c)):map<int,string>> |
| org.apache.spark.sql.catalyst.expressions.MapContainsKey | map_contains_key
| SELECT map_contains_key(map(1, 'a', 2, 'b'), 1) |
struct<map_contains_key(map(1, a, 2, b), 1):boolean> |
@@ -211,15 +210,16 @@
| org.apache.spark.sql.catalyst.expressions.Now | now | SELECT now() |
struct<now():timestamp> |
| org.apache.spark.sql.catalyst.expressions.NthValue | nth_value | SELECT a,
b, nth_value(b, 2) OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2),
('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct<a:string,b:int,nth_value(b,
2) OVER (PARTITION BY a ORDER BY b ASC NULLS FIRST RANGE BETWEEN UNBOUNDED
PRECEDING AND CURRENT ROW):int> |
| org.apache.spark.sql.catalyst.expressions.NullIf | nullif | SELECT nullif(2,
2) | struct<nullif(2, 2):int> |
+| org.apache.spark.sql.catalyst.expressions.Nvl | ifnull | SELECT ifnull(NULL,
array('2')) | struct<ifnull(NULL, array(2)):array<string>> |
| org.apache.spark.sql.catalyst.expressions.Nvl | nvl | SELECT nvl(NULL,
array('2')) | struct<nvl(NULL, array(2)):array<string>> |
| org.apache.spark.sql.catalyst.expressions.Nvl2 | nvl2 | SELECT nvl2(NULL, 2,
1) | struct<nvl2(NULL, 2, 1):int> |
| org.apache.spark.sql.catalyst.expressions.OctetLength | octet_length |
SELECT octet_length('Spark SQL') | struct<octet_length(Spark SQL):int> |
| org.apache.spark.sql.catalyst.expressions.Or | or | SELECT true or false |
struct<(true OR false):boolean> |
| org.apache.spark.sql.catalyst.expressions.Overlay | overlay | SELECT
overlay('Spark SQL' PLACING '_' FROM 6) | struct<overlay(Spark SQL, _, 6,
-1):string> |
| org.apache.spark.sql.catalyst.expressions.ParseToDate | to_date | SELECT
to_date('2009-07-30 04:17:52') | struct<to_date(2009-07-30 04:17:52):date> |
-| org.apache.spark.sql.catalyst.expressions.ParseToTimestamp | to_timestamp |
SELECT to_timestamp('2016-12-31 00:12:00') | struct<to_timestamp(2016-12-31
00:12:00):timestamp> |
-| org.apache.spark.sql.catalyst.expressions.ParseToTimestampLTZ |
to_timestamp_ltz | SELECT to_timestamp_ltz('2016-12-31 00:12:00') |
struct<to_timestamp_ltz(2016-12-31 00:12:00):timestamp> |
-| org.apache.spark.sql.catalyst.expressions.ParseToTimestampNTZ |
to_timestamp_ntz | SELECT to_timestamp_ntz('2016-12-31 00:12:00') |
struct<to_timestamp_ntz(2016-12-31 00:12:00):timestamp_ntz> |
+| org.apache.spark.sql.catalyst.expressions.ParseToTimestampExpressionBuilder$
| to_timestamp | SELECT to_timestamp('2016-12-31 00:12:00') |
struct<to_timestamp(2016-12-31 00:12:00):timestamp> |
+|
org.apache.spark.sql.catalyst.expressions.ParseToTimestampLTZExpressionBuilder$
| to_timestamp_ltz | SELECT to_timestamp_ltz('2016-12-31 00:12:00') |
struct<to_timestamp_ltz(2016-12-31 00:12:00):timestamp> |
+|
org.apache.spark.sql.catalyst.expressions.ParseToTimestampNTZExpressionBuilder$
| to_timestamp_ntz | SELECT to_timestamp_ntz('2016-12-31 00:12:00') |
struct<to_timestamp_ntz(2016-12-31 00:12:00):timestamp_ntz> |
| org.apache.spark.sql.catalyst.expressions.ParseUrl | parse_url | SELECT
parse_url('http://spark.apache.org/path?query=1', 'HOST') |
struct<parse_url(http://spark.apache.org/path?query=1, HOST):string> |
| org.apache.spark.sql.catalyst.expressions.PercentRank | percent_rank |
SELECT a, b, percent_rank(b) OVER (PARTITION BY a ORDER BY b) FROM VALUES
('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) |
struct<a:string,b:int,PERCENT_RANK() OVER (PARTITION BY a ORDER BY b ASC NULLS
FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):double> |
| org.apache.spark.sql.catalyst.expressions.Pi | pi | SELECT pi() |
struct<PI():double> |
diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
index 94eb96f..fef16b7 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
@@ -141,14 +141,17 @@ select to_binary('abc');
select to_binary('abc', 'utf-8');
select to_binary('abc', 'base64');
select to_binary('abc', 'hex');
+-- 'format' parameter can be any foldable string value, not just literal.
select to_binary('abc', concat('utf', '-8'));
-select to_binary('abc', concat('base', '64'));
+-- 'format' parameter is case insensitive.
select to_binary('abc', 'Hex');
-select to_binary('abc', 'UTF-8');
+-- null inputs lead to null result.
select to_binary('abc', null);
select to_binary(null, 'utf-8');
select to_binary(null, null);
select to_binary(null, cast(null as string));
+-- 'format' parameter must be string type or void type.
select to_binary(null, cast(null as int));
-select to_binary('abc', 'invalidFormat');
select to_binary('abc', 1);
+-- invalid inputs.
+select to_binary('abc', 'invalidFormat');
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/map.sql.out
b/sql/core/src/test/resources/sql-tests/results/ansi/map.sql.out
index 7a27a89..5f7bd9f 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/map.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/map.sql.out
@@ -74,7 +74,7 @@ select map_contains_key(map('1', 'a', '2', 'b'), 1)
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-cannot resolve 'array_contains(map_keys(map('1', 'a', '2', 'b')), 1)' due to
data type mismatch: Input to function array_contains should have been array
followed by a value with same element type, but it's [array<string>, int].;
line 1 pos 7
+cannot resolve 'map_contains_key(map('1', 'a', '2', 'b'), 1)' due to data type
mismatch: Input to function map_contains_key should have been map followed by a
value with same key type, but it's [map<string,string>, int].; line 1 pos 7
-- !query
@@ -83,7 +83,7 @@ select map_contains_key(map(1, 'a', 2, 'b'), '1')
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-cannot resolve 'array_contains(map_keys(map(1, 'a', 2, 'b')), '1')' due to
data type mismatch: Input to function array_contains should have been array
followed by a value with same element type, but it's [array<int>, string].;
line 1 pos 7
+cannot resolve 'map_contains_key(map(1, 'a', 2, 'b'), '1')' due to data type
mismatch: Input to function map_contains_key should have been map followed by a
value with same key type, but it's [map<int,string>, string].; line 1 pos 7
-- !query
diff --git
a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
index ec7f41d..913f1cf 100644
---
a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
+++
b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 117
+-- Number of queries: 115
-- !query
@@ -867,14 +867,6 @@ abc
-- !query
-select to_binary('abc', concat('base', '64'))
--- !query schema
-struct<to_binary(abc, concat(base, 64)):binary>
--- !query output
-i�
-
-
--- !query
select to_binary('abc', 'Hex')
-- !query schema
struct<to_binary(abc, Hex):binary>
@@ -883,14 +875,6 @@ struct<to_binary(abc, Hex):binary>
-- !query
-select to_binary('abc', 'UTF-8')
--- !query schema
-struct<to_binary(abc, UTF-8):binary>
--- !query output
-abc
-
-
--- !query
select to_binary('abc', null)
-- !query schema
struct<to_binary(abc, NULL):binary>
@@ -928,22 +912,22 @@ select to_binary(null, cast(null as int))
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-cannot resolve 'to_binary(NULL, CAST(NULL AS INT))' due to data type mismatch:
Unsupported encoding format: Some(cast(null as int)). The format has to be a
case-insensitive string literal of 'hex', 'utf-8', or 'base64'; line 1 pos 7
+The 'format' parameter of function 'to_binary' needs to be a string literal.;
line 1 pos 7
-- !query
-select to_binary('abc', 'invalidFormat')
+select to_binary('abc', 1)
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-cannot resolve 'to_binary('abc', 'invalidFormat')' due to data type mismatch:
Unsupported encoding format: Some(invalidFormat). The format has to be a
case-insensitive string literal of 'hex', 'utf-8', or 'base64'; line 1 pos 7
+The 'format' parameter of function 'to_binary' needs to be a string literal.;
line 1 pos 7
-- !query
-select to_binary('abc', 1)
+select to_binary('abc', 'invalidFormat')
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-cannot resolve 'to_binary('abc', 1)' due to data type mismatch: Unsupported
encoding format: Some(1). The format has to be a case-insensitive string
literal of 'hex', 'utf-8', or 'base64'; line 1 pos 7
+Invalid value for the 'format' parameter of function 'to_binary':
invalidformat. The value has to be a case-insensitive string literal of 'hex',
'utf-8', or 'base64'.
diff --git
a/sql/core/src/test/resources/sql-tests/results/ceil-floor-with-scale-param.sql.out
b/sql/core/src/test/resources/sql-tests/results/ceil-floor-with-scale-param.sql.out
index 1ec00af..132bd96 100644
---
a/sql/core/src/test/resources/sql-tests/results/ceil-floor-with-scale-param.sql.out
+++
b/sql/core/src/test/resources/sql-tests/results/ceil-floor-with-scale-param.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 18
+-- Number of queries: 24
-- !query
@@ -80,7 +80,7 @@ SELECT CEIL(2.5, null)
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-The 'scale' parameter of function 'ceil' must be an int constant.; line 1 pos 7
+The 'scale' parameter of function 'ceil' needs to be a int literal.; line 1
pos 7
-- !query
@@ -89,7 +89,7 @@ SELECT CEIL(2.5, 'a')
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-The 'scale' parameter of function 'ceil' must be an int constant.; line 1 pos 7
+The 'scale' parameter of function 'ceil' needs to be a int literal.; line 1
pos 7
-- !query
@@ -98,7 +98,7 @@ SELECT CEIL(2.5, 0, 0)
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-Invalid number of parameters to the function 'ceil'.; line 1 pos 7
+Invalid number of arguments for function ceil. Expected: 2; Found: 3; line 1
pos 7
-- !query
@@ -179,7 +179,7 @@ SELECT FLOOR(2.5, null)
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-The 'scale' parameter of function 'floor' must be an int constant.; line 1 pos
7
+The 'scale' parameter of function 'floor' needs to be a int literal.; line 1
pos 7
-- !query
@@ -188,7 +188,7 @@ SELECT FLOOR(2.5, 'a')
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-The 'scale' parameter of function 'floor' must be an int constant.; line 1 pos
7
+The 'scale' parameter of function 'floor' needs to be a int literal.; line 1
pos 7
-- !query
@@ -197,4 +197,4 @@ SELECT FLOOR(2.5, 0, 0)
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-Invalid number of parameters to the function 'floor'.; line 1 pos 7
+Invalid number of arguments for function floor. Expected: 2; Found: 3; line 1
pos 7
diff --git a/sql/core/src/test/resources/sql-tests/results/extract.sql.out
b/sql/core/src/test/resources/sql-tests/results/extract.sql.out
index e3f676d..55776d3 100644
--- a/sql/core/src/test/resources/sql-tests/results/extract.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/extract.sql.out
@@ -660,7 +660,7 @@ select date_part(c, c) from t
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-The field parameter needs to be a foldable string value.; line 1 pos 7
+The 'field' parameter of function 'date_part' needs to be a string literal.;
line 1 pos 7
-- !query
@@ -677,7 +677,7 @@ select date_part(i, i) from t
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-The field parameter needs to be a foldable string value.; line 1 pos 7
+The 'field' parameter of function 'date_part' needs to be a string literal.;
line 1 pos 7
-- !query
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index cd0fa48..400d6c9 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -470,7 +470,7 @@ SELECT every(1)
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-cannot resolve 'every(1)' due to data type mismatch: Input to function 'every'
should have been boolean, but it's [int].; line 1 pos 7
+cannot resolve 'every(1)' due to data type mismatch: argument 1 requires
boolean type, however, '1' is of int type.; line 1 pos 7
-- !query
@@ -479,7 +479,7 @@ SELECT some(1S)
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-cannot resolve 'some(1S)' due to data type mismatch: Input to function 'some'
should have been boolean, but it's [smallint].; line 1 pos 7
+cannot resolve 'some(1S)' due to data type mismatch: argument 1 requires
boolean type, however, '1S' is of smallint type.; line 1 pos 7
-- !query
@@ -488,7 +488,7 @@ SELECT any(1L)
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-cannot resolve 'any(1L)' due to data type mismatch: Input to function 'any'
should have been boolean, but it's [bigint].; line 1 pos 7
+cannot resolve 'any(1L)' due to data type mismatch: argument 1 requires
boolean type, however, '1L' is of bigint type.; line 1 pos 7
-- !query
@@ -497,7 +497,7 @@ SELECT every("true")
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-cannot resolve 'every('true')' due to data type mismatch: Input to function
'every' should have been boolean, but it's [string].; line 1 pos 7
+cannot resolve 'every('true')' due to data type mismatch: argument 1 requires
boolean type, however, ''true'' is of string type.; line 1 pos 7
-- !query
@@ -506,7 +506,7 @@ SELECT bool_and(1.0)
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-cannot resolve 'bool_and(1.0BD)' due to data type mismatch: Input to function
'bool_and' should have been boolean, but it's [decimal(2,1)].; line 1 pos 7
+cannot resolve 'bool_and(1.0BD)' due to data type mismatch: argument 1
requires boolean type, however, '1.0BD' is of decimal(2,1) type.; line 1 pos 7
-- !query
@@ -515,7 +515,7 @@ SELECT bool_or(1.0D)
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-cannot resolve 'bool_or(1.0D)' due to data type mismatch: Input to function
'bool_or' should have been boolean, but it's [double].; line 1 pos 7
+cannot resolve 'bool_or(1.0D)' due to data type mismatch: argument 1 requires
boolean type, however, '1.0D' is of double type.; line 1 pos 7
-- !query
diff --git a/sql/core/src/test/resources/sql-tests/results/map.sql.out
b/sql/core/src/test/resources/sql-tests/results/map.sql.out
index aa13fee..b615a62 100644
--- a/sql/core/src/test/resources/sql-tests/results/map.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/map.sql.out
@@ -72,7 +72,7 @@ select map_contains_key(map('1', 'a', '2', 'b'), 1)
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-cannot resolve 'array_contains(map_keys(map('1', 'a', '2', 'b')), 1)' due to
data type mismatch: Input to function array_contains should have been array
followed by a value with same element type, but it's [array<string>, int].;
line 1 pos 7
+cannot resolve 'map_contains_key(map('1', 'a', '2', 'b'), 1)' due to data type
mismatch: Input to function map_contains_key should have been map followed by a
value with same key type, but it's [map<string,string>, int].; line 1 pos 7
-- !query
@@ -81,4 +81,4 @@ select map_contains_key(map(1, 'a', 2, 'b'), '1')
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-cannot resolve 'array_contains(map_keys(map(1, 'a', 2, 'b')), '1')' due to
data type mismatch: Input to function array_contains should have been array
followed by a value with same element type, but it's [array<int>, string].;
line 1 pos 7
+cannot resolve 'map_contains_key(map(1, 'a', 2, 'b'), '1')' due to data type
mismatch: Input to function map_contains_key should have been map followed by a
value with same key type, but it's [map<int,string>, string].; line 1 pos 7
diff --git
a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
index bb2974d..bf4348d 100644
--- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 117
+-- Number of queries: 115
-- !query
@@ -863,14 +863,6 @@ abc
-- !query
-select to_binary('abc', concat('base', '64'))
--- !query schema
-struct<to_binary(abc, concat(base, 64)):binary>
--- !query output
-i�
-
-
--- !query
select to_binary('abc', 'Hex')
-- !query schema
struct<to_binary(abc, Hex):binary>
@@ -879,14 +871,6 @@ struct<to_binary(abc, Hex):binary>
-- !query
-select to_binary('abc', 'UTF-8')
--- !query schema
-struct<to_binary(abc, UTF-8):binary>
--- !query output
-abc
-
-
--- !query
select to_binary('abc', null)
-- !query schema
struct<to_binary(abc, NULL):binary>
@@ -924,22 +908,22 @@ select to_binary(null, cast(null as int))
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-cannot resolve 'to_binary(NULL, CAST(NULL AS INT))' due to data type mismatch:
Unsupported encoding format: Some(cast(null as int)). The format has to be a
case-insensitive string literal of 'hex', 'utf-8', or 'base64'; line 1 pos 7
+The 'format' parameter of function 'to_binary' needs to be a string literal.;
line 1 pos 7
-- !query
-select to_binary('abc', 'invalidFormat')
+select to_binary('abc', 1)
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-cannot resolve 'to_binary('abc', 'invalidFormat')' due to data type mismatch:
Unsupported encoding format: Some(invalidFormat). The format has to be a
case-insensitive string literal of 'hex', 'utf-8', or 'base64'; line 1 pos 7
+The 'format' parameter of function 'to_binary' needs to be a string literal.;
line 1 pos 7
-- !query
-select to_binary('abc', 1)
+select to_binary('abc', 'invalidFormat')
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-cannot resolve 'to_binary('abc', 1)' due to data type mismatch: Unsupported
encoding format: Some(1). The format has to be a case-insensitive string
literal of 'hex', 'utf-8', or 'base64'; line 1 pos 7
+Invalid value for the 'format' parameter of function 'to_binary':
invalidformat. The value has to be a case-insensitive string literal of 'hex',
'utf-8', or 'base64'.
diff --git
a/sql/core/src/test/resources/sql-tests/results/timestamp-ltz.sql.out
b/sql/core/src/test/resources/sql-tests/results/timestamp-ltz.sql.out
index 48036c6..057cdf1 100644
--- a/sql/core/src/test/resources/sql-tests/results/timestamp-ltz.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/timestamp-ltz.sql.out
@@ -45,7 +45,7 @@ struct<make_timestamp_ltz(2021, 7, 11, 6, 30,
45.678):timestamp>
-- !query
SELECT make_timestamp_ltz(2021, 07, 11, 6, 30, 45.678, 'CET')
-- !query schema
-struct<make_timestamp_ltz(2021, 7, 11, 6, 30, 45.678):timestamp>
+struct<make_timestamp_ltz(2021, 7, 11, 6, 30, 45.678, CET):timestamp>
-- !query output
2021-07-10 21:30:45.678
diff --git
a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out
b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out
index 5db0f4d..d543c6a 100644
--- a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out
@@ -380,7 +380,7 @@ SELECT every(udf(1))
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-cannot resolve 'every(CAST(udf(cast(1 as string)) AS INT))' due to data type
mismatch: Input to function 'every' should have been boolean, but it's [int].;
line 1 pos 7
+cannot resolve 'every(CAST(udf(cast(1 as string)) AS INT))' due to data type
mismatch: argument 1 requires boolean type, however, 'CAST(udf(cast(1 as
string)) AS INT)' is of int type.; line 1 pos 7
-- !query
@@ -389,7 +389,7 @@ SELECT some(udf(1S))
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-cannot resolve 'some(CAST(udf(cast(1 as string)) AS SMALLINT))' due to data
type mismatch: Input to function 'some' should have been boolean, but it's
[smallint].; line 1 pos 7
+cannot resolve 'some(CAST(udf(cast(1 as string)) AS SMALLINT))' due to data
type mismatch: argument 1 requires boolean type, however, 'CAST(udf(cast(1 as
string)) AS SMALLINT)' is of smallint type.; line 1 pos 7
-- !query
@@ -398,7 +398,7 @@ SELECT any(udf(1L))
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-cannot resolve 'any(CAST(udf(cast(1 as string)) AS BIGINT))' due to data type
mismatch: Input to function 'any' should have been boolean, but it's [bigint].;
line 1 pos 7
+cannot resolve 'any(CAST(udf(cast(1 as string)) AS BIGINT))' due to data type
mismatch: argument 1 requires boolean type, however, 'CAST(udf(cast(1 as
string)) AS BIGINT)' is of bigint type.; line 1 pos 7
-- !query
@@ -407,7 +407,7 @@ SELECT udf(every("true"))
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
-cannot resolve 'every('true')' due to data type mismatch: Input to function
'every' should have been boolean, but it's [string].; line 1 pos 11
+cannot resolve 'every('true')' due to data type mismatch: argument 1 requires
boolean type, however, ''true'' is of string type.; line 1 pos 11
-- !query
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 215d38d..42293bc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1028,7 +1028,8 @@ class DataFrameAggregateSuite extends QueryTest
val error = intercept[AnalysisException] {
sql("SELECT COUNT_IF(x) FROM tempView")
}
- assert(error.message.contains("function count_if requires boolean type"))
+ assert(error.message.contains("cannot resolve 'count_if(tempview.x)' due
to data type " +
+ "mismatch: argument 1 requires boolean type, however, 'tempview.x' is
of string type"))
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]