This is an automated email from the ASF dual-hosted git repository. hvanhovell pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 686e37e640d [SPARK-44030][SQL] Implement DataTypeExpression to offer Unapply for expression 686e37e640d is described below commit 686e37e640d078f9727e5457e47ce58033ce8684 Author: Rui Wang <rui.w...@databricks.com> AuthorDate: Mon Jun 26 22:47:01 2023 -0400 [SPARK-44030][SQL] Implement DataTypeExpression to offer Unapply for expression ### What changes were proposed in this pull request? Implement DataTypeExpression to offer `Unapply` for expression. By doing so we can drop `Unapply` from DataType. ### Why are the changes needed? Simplify DataType interface. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? existing tests Closes #41559 from amaliujia/move_datatypes_1. Authored-by: Rui Wang <rui.w...@databricks.com> Signed-off-by: Herman van Hovell <her...@databricks.com> --- .../sql/catalyst/analysis/AnsiTypeCoercion.scala | 29 +++++----- .../sql/catalyst/analysis/DecimalPrecision.scala | 20 +++---- .../spark/sql/catalyst/analysis/TypeCoercion.scala | 46 +++++++-------- .../apache/spark/sql/types/AbstractDataType.scala | 40 +------------ .../org/apache/spark/sql/types/DataType.scala | 10 ---- .../spark/sql/types/DataTypeExpression.scala | 67 ++++++++++++++++++++++ .../apache/spark/sql/hive/client/HiveShim.scala | 4 +- 7 files changed, 119 insertions(+), 97 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index 56dbb2a8590..d3f20f87493 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -244,28 +244,29 @@ object AnsiTypeCoercion extends TypeCoercionBase { val promoteType = findWiderTypeForString(left.dataType, right.dataType).get b.withNewChildren(Seq(castExpr(left, promoteType), castExpr(right, promoteType))) - case Abs(e @ StringType(), failOnError) => Abs(Cast(e, DoubleType), failOnError) - case m @ UnaryMinus(e @ StringType(), _) => m.withNewChildren(Seq(Cast(e, DoubleType))) - case UnaryPositive(e @ StringType()) => UnaryPositive(Cast(e, DoubleType)) + case Abs(e @ StringTypeExpression(), failOnError) => Abs(Cast(e, DoubleType), failOnError) + case m @ UnaryMinus(e @ StringTypeExpression(), _) => + m.withNewChildren(Seq(Cast(e, DoubleType))) + case UnaryPositive(e @ StringTypeExpression()) => UnaryPositive(Cast(e, DoubleType)) - case d @ DateAdd(left @ StringType(), _) => + case d @ DateAdd(left @ StringTypeExpression(), _) => d.copy(startDate = Cast(d.startDate, DateType)) - case d @ DateAdd(_, right @ StringType()) => + case d @ DateAdd(_, right @ StringTypeExpression()) => d.copy(days = Cast(right, IntegerType)) - case d @ DateSub(left @ StringType(), _) => + case d @ DateSub(left @ StringTypeExpression(), _) => d.copy(startDate = Cast(d.startDate, DateType)) - case d @ DateSub(_, right @ StringType()) => + case d @ DateSub(_, right @ StringTypeExpression()) => d.copy(days = Cast(right, IntegerType)) - case s @ SubtractDates(left @ StringType(), _, _) => + case s @ SubtractDates(left @ StringTypeExpression(), _, _) => s.copy(left = Cast(s.left, DateType)) - case s @ SubtractDates(_, right @ StringType(), _) => + case s @ SubtractDates(_, right @ StringTypeExpression(), _) => s.copy(right = Cast(s.right, DateType)) - case t @ TimeAdd(left @ StringType(), _, _) => + case t @ TimeAdd(left @ StringTypeExpression(), _, _) => t.copy(start = Cast(t.start, TimestampType)) - case t @ SubtractTimestamps(left @ StringType(), _, _, _) => + case t @ SubtractTimestamps(left @ StringTypeExpression(), _, _, _) => t.copy(left = Cast(t.left, t.right.dataType)) - case t @ SubtractTimestamps(_, right @ StringType(), _, _) => + case t @ SubtractTimestamps(_, right @ StringTypeExpression(), _, _) => t.copy(right = Cast(right, t.left.dataType)) } } @@ -296,9 +297,9 @@ object AnsiTypeCoercion extends TypeCoercionBase { case d @ DateAdd(AnyTimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) case d @ DateSub(AnyTimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) - case s @ SubtractTimestamps(DateType(), AnyTimestampType(), _, _) => + case s @ SubtractTimestamps(DateTypeExpression(), AnyTimestampType(), _, _) => s.copy(left = Cast(s.left, s.right.dataType)) - case s @ SubtractTimestamps(AnyTimestampType(), DateType(), _, _) => + case s @ SubtractTimestamps(AnyTimestampType(), DateTypeExpression(), _, _) => s.copy(right = Cast(s.right, s.left.dataType)) case s @ SubtractTimestamps(AnyTimestampType(), AnyTimestampType(), _, _) if s.left.dataType != s.right.dataType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 46fbf071f43..90fd13dfb54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -105,7 +105,7 @@ object DecimalPrecision extends TypeCoercionRule { */ private val integralAndDecimalLiteral: PartialFunction[Expression, Expression] = { - case GreaterThan(i @ IntegralType(), DecimalLiteral(value)) => + case GreaterThan(i @ IntegralTypeExpression(), DecimalLiteral(value)) => if (DecimalLiteral.smallerThanSmallestLong(value)) { TrueLiteral } else if (DecimalLiteral.largerThanLargestLong(value)) { @@ -114,7 +114,7 @@ object DecimalPrecision extends TypeCoercionRule { GreaterThan(i, Literal(value.floor.toLong)) } - case GreaterThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) => + case GreaterThanOrEqual(i @ IntegralTypeExpression(), DecimalLiteral(value)) => if (DecimalLiteral.smallerThanSmallestLong(value)) { TrueLiteral } else if (DecimalLiteral.largerThanLargestLong(value)) { @@ -123,7 +123,7 @@ object DecimalPrecision extends TypeCoercionRule { GreaterThanOrEqual(i, Literal(value.ceil.toLong)) } - case LessThan(i @ IntegralType(), DecimalLiteral(value)) => + case LessThan(i @ IntegralTypeExpression(), DecimalLiteral(value)) => if (DecimalLiteral.smallerThanSmallestLong(value)) { FalseLiteral } else if (DecimalLiteral.largerThanLargestLong(value)) { @@ -132,7 +132,7 @@ object DecimalPrecision extends TypeCoercionRule { LessThan(i, Literal(value.ceil.toLong)) } - case LessThanOrEqual(i @ IntegralType(), DecimalLiteral(value)) => + case LessThanOrEqual(i @ IntegralTypeExpression(), DecimalLiteral(value)) => if (DecimalLiteral.smallerThanSmallestLong(value)) { FalseLiteral } else if (DecimalLiteral.largerThanLargestLong(value)) { @@ -141,7 +141,7 @@ object DecimalPrecision extends TypeCoercionRule { LessThanOrEqual(i, Literal(value.floor.toLong)) } - case GreaterThan(DecimalLiteral(value), i @ IntegralType()) => + case GreaterThan(DecimalLiteral(value), i @ IntegralTypeExpression()) => if (DecimalLiteral.smallerThanSmallestLong(value)) { FalseLiteral } else if (DecimalLiteral.largerThanLargestLong(value)) { @@ -150,7 +150,7 @@ object DecimalPrecision extends TypeCoercionRule { GreaterThan(Literal(value.ceil.toLong), i) } - case GreaterThanOrEqual(DecimalLiteral(value), i @ IntegralType()) => + case GreaterThanOrEqual(DecimalLiteral(value), i @ IntegralTypeExpression()) => if (DecimalLiteral.smallerThanSmallestLong(value)) { FalseLiteral } else if (DecimalLiteral.largerThanLargestLong(value)) { @@ -159,7 +159,7 @@ object DecimalPrecision extends TypeCoercionRule { GreaterThanOrEqual(Literal(value.floor.toLong), i) } - case LessThan(DecimalLiteral(value), i @ IntegralType()) => + case LessThan(DecimalLiteral(value), i @ IntegralTypeExpression()) => if (DecimalLiteral.smallerThanSmallestLong(value)) { TrueLiteral } else if (DecimalLiteral.largerThanLargestLong(value)) { @@ -168,7 +168,7 @@ object DecimalPrecision extends TypeCoercionRule { LessThan(Literal(value.floor.toLong), i) } - case LessThanOrEqual(DecimalLiteral(value), i @ IntegralType()) => + case LessThanOrEqual(DecimalLiteral(value), i @ IntegralTypeExpression()) => if (DecimalLiteral.smallerThanSmallestLong(value)) { TrueLiteral } else if (DecimalLiteral.largerThanLargestLong(value)) { @@ -208,9 +208,9 @@ object DecimalPrecision extends TypeCoercionRule { b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r)))) // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles - case (l @ IntegralType(), r @ DecimalType.Expression(_, _)) => + case (l @ IntegralTypeExpression(), r @ DecimalType.Expression(_, _)) => b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r)) - case (l @ DecimalType.Expression(_, _), r @ IntegralType()) => + case (l @ DecimalType.Expression(_, _), r @ IntegralTypeExpression()) => b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType)))) case (l, r @ DecimalType.Expression(_, _)) if isFloat(l.dataType) => b.makeCopy(Array(l, Cast(r, DoubleType))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index bd2255134fc..ae4db0575ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -461,8 +461,8 @@ abstract class TypeCoercionBase { m.copy(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) }) // Hive lets you do aggregation of timestamps... for some reason - case Sum(e @ TimestampType(), _) => Sum(Cast(e, DoubleType)) - case Average(e @ TimestampType(), _) => Average(Cast(e, DoubleType)) + case Sum(e @ TimestampTypeExpression(), _) => Sum(Cast(e, DoubleType)) + case Average(e @ TimestampTypeExpression(), _) => Average(Cast(e, DoubleType)) // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and @@ -1105,18 +1105,18 @@ object TypeCoercion extends TypeCoercionBase { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case a @ BinaryArithmetic(left @ StringType(), right) + case a @ BinaryArithmetic(left @ StringTypeExpression(), right) if right.dataType != CalendarIntervalType => a.makeCopy(Array(Cast(left, DoubleType), right)) - case a @ BinaryArithmetic(left, right @ StringType()) + case a @ BinaryArithmetic(left, right @ StringTypeExpression()) if left.dataType != CalendarIntervalType => a.makeCopy(Array(left, Cast(right, DoubleType))) // For equality between string and timestamp we cast the string to a timestamp // so that things like rounding of subsecond precision does not affect the comparison. - case p @ Equality(left @ StringType(), right @ TimestampType()) => + case p @ Equality(left @ StringTypeExpression(), right @ TimestampTypeExpression()) => p.makeCopy(Array(Cast(left, TimestampType), right)) - case p @ Equality(left @ TimestampType(), right @ StringType()) => + case p @ Equality(left @ TimestampTypeExpression(), right @ StringTypeExpression()) => p.makeCopy(Array(left, Cast(right, TimestampType))) case p @ BinaryComparison(left, right) @@ -1142,30 +1142,30 @@ object TypeCoercion extends TypeCoercionBase { // We may simplify the expression if one side is literal numeric values // TODO: Maybe these rules should go into the optimizer. - case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType)) + case EqualTo(bool @ BooleanTypeExpression(), Literal(value, _: NumericType)) if trueValues.contains(value) => bool - case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType)) + case EqualTo(bool @ BooleanTypeExpression(), Literal(value, _: NumericType)) if falseValues.contains(value) => Not(bool) - case EqualTo(Literal(value, _: NumericType), bool @ BooleanType()) + case EqualTo(Literal(value, _: NumericType), bool @ BooleanTypeExpression()) if trueValues.contains(value) => bool - case EqualTo(Literal(value, _: NumericType), bool @ BooleanType()) + case EqualTo(Literal(value, _: NumericType), bool @ BooleanTypeExpression()) if falseValues.contains(value) => Not(bool) - case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType)) + case EqualNullSafe(bool @ BooleanTypeExpression(), Literal(value, _: NumericType)) if trueValues.contains(value) => And(IsNotNull(bool), bool) - case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType)) + case EqualNullSafe(bool @ BooleanTypeExpression(), Literal(value, _: NumericType)) if falseValues.contains(value) => And(IsNotNull(bool), Not(bool)) - case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType()) + case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanTypeExpression()) if trueValues.contains(value) => And(IsNotNull(bool), bool) - case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType()) + case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanTypeExpression()) if falseValues.contains(value) => And(IsNotNull(bool), Not(bool)) - case EqualTo(left @ BooleanType(), right @ NumericType()) => + case EqualTo(left @ BooleanTypeExpression(), right @ NumericTypeExpression()) => EqualTo(Cast(left, right.dataType), right) - case EqualTo(left @ NumericType(), right @ BooleanType()) => + case EqualTo(left @ NumericTypeExpression(), right @ BooleanTypeExpression()) => EqualTo(left, Cast(right, left.dataType)) - case EqualNullSafe(left @ BooleanType(), right @ NumericType()) => + case EqualNullSafe(left @ BooleanTypeExpression(), right @ NumericTypeExpression()) => EqualNullSafe(Cast(left, right.dataType), right) - case EqualNullSafe(left @ NumericType(), right @ BooleanType()) => + case EqualNullSafe(left @ NumericTypeExpression(), right @ BooleanTypeExpression()) => EqualNullSafe(left, Cast(right, left.dataType)) } } @@ -1175,13 +1175,13 @@ object TypeCoercion extends TypeCoercionBase { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e case d @ DateAdd(AnyTimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) - case d @ DateAdd(StringType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) + case d @ DateAdd(StringTypeExpression(), _) => d.copy(startDate = Cast(d.startDate, DateType)) case d @ DateSub(AnyTimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) - case d @ DateSub(StringType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) + case d @ DateSub(StringTypeExpression(), _) => d.copy(startDate = Cast(d.startDate, DateType)) - case s @ SubtractTimestamps(DateType(), AnyTimestampType(), _, _) => + case s @ SubtractTimestamps(DateTypeExpression(), AnyTimestampType(), _, _) => s.copy(left = Cast(s.left, s.right.dataType)) - case s @ SubtractTimestamps(AnyTimestampType(), DateType(), _, _) => + case s @ SubtractTimestamps(AnyTimestampType(), DateTypeExpression(), _, _) => s.copy(right = Cast(s.right, s.left.dataType)) case s @ SubtractTimestamps(AnyTimestampType(), AnyTimestampType(), _, _) if s.left.dataType != s.right.dataType => @@ -1189,7 +1189,7 @@ object TypeCoercion extends TypeCoercionBase { val newRight = castIfNotSameType(s.right, TimestampNTZType) s.copy(left = newLeft, right = newRight) - case t @ TimeAdd(StringType(), _, _) => t.copy(start = Cast(t.start, TimestampType)) + case t @ TimeAdd(StringTypeExpression(), _, _) => t.copy(start = Cast(t.start, TimestampType)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index f498282d4f3..01fa27822b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -121,16 +121,7 @@ protected[sql] object AnyDataType extends AbstractDataType with Serializable { */ protected[sql] abstract class AtomicType extends DataType -object AtomicType { - /** - * Enables matching against AtomicType for expressions: - * {{{ - * case Cast(child @ AtomicType(), StringType) => - * ... - * }}} - */ - def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[AtomicType] -} +object AtomicType /** @@ -143,15 +134,6 @@ abstract class NumericType extends AtomicType private[spark] object NumericType extends AbstractDataType { - /** - * Enables matching against NumericType for expressions: - * {{{ - * case Cast(child @ NumericType(), StringType) => - * ... - * }}} - */ - def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] - override private[spark] def defaultConcreteType: DataType = DoubleType override private[spark] def simpleString: String = "numeric" @@ -162,15 +144,6 @@ private[spark] object NumericType extends AbstractDataType { private[sql] object IntegralType extends AbstractDataType { - /** - * Enables matching against IntegralType for expressions: - * {{{ - * case Cast(child @ IntegralType(), StringType) => - * ... - * }}} - */ - def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType] - override private[sql] def defaultConcreteType: DataType = IntegerType override private[sql] def simpleString: String = "integral" @@ -182,16 +155,7 @@ private[sql] object IntegralType extends AbstractDataType { private[sql] abstract class IntegralType extends NumericType -private[sql] object FractionalType { - /** - * Enables matching against FractionalType for expressions: - * {{{ - * case Cast(child @ FractionalType(), StringType) => - * ... - * }}} - */ - def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[FractionalType] -} +private[sql] object FractionalType private[sql] abstract class FractionalType extends NumericType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index f78a8de5e6a..893a41f3e39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -30,7 +30,6 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkThrowable import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.analysis.Resolver -import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.DataTypeJsonUtils.{DataTypeJsonDeserializer, DataTypeJsonSerializer} @@ -50,15 +49,6 @@ import org.apache.spark.util.Utils @JsonSerialize(using = classOf[DataTypeJsonSerializer]) @JsonDeserialize(using = classOf[DataTypeJsonDeserializer]) abstract class DataType extends AbstractDataType { - /** - * Enables matching against DataType for expressions: - * {{{ - * case Cast(child @ BinaryType(), StringType) => - * ... - * }}} - */ - private[sql] def unapply(e: Expression): Boolean = e.dataType == this - /** * The default size of a value of this data type, used internally for size estimation. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeExpression.scala new file mode 100644 index 00000000000..f88e266b943 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeExpression.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.types + +import org.apache.spark.sql.catalyst.expressions.Expression + +abstract class DataTypeExpression(val dataType: DataType) { + /** + * Enables matching against DataType for expressions: + * {{{ + * case Cast(child @ BinaryType(), StringType) => + * ... + * }}} + */ + private[sql] def unapply(e: Expression): Boolean = e.dataType == dataType +} + +case object BooleanTypeExpression extends DataTypeExpression(BooleanType) +case object StringTypeExpression extends DataTypeExpression(StringType) +case object TimestampTypeExpression extends DataTypeExpression(TimestampType) +case object DateTypeExpression extends DataTypeExpression(DateType) +case object ByteTypeExpression extends DataTypeExpression(ByteType) +case object ShortTypeExpression extends DataTypeExpression(ShortType) +case object IntegerTypeExpression extends DataTypeExpression(IntegerType) +case object LongTypeExpression extends DataTypeExpression(LongType) +case object DoubleTypeExpression extends DataTypeExpression(DoubleType) +case object FloatTypeExpression extends DataTypeExpression(FloatType) + +object NumericTypeExpression { + /** + * Enables matching against NumericType for expressions: + * {{{ + * case Cast(child @ NumericType(), StringType) => + * ... + * }}} + */ + def unapply(e: Expression): Boolean = { + e.dataType.isInstanceOf[NumericType] + } +} + +object IntegralTypeExpression { + /** + * Enables matching against IntegralType for expressions: + * {{{ + * case Cast(child @ IntegralType(), StringType) => + * ... + * }}} + */ + def unapply(e: Expression): Boolean = { + e.dataType.isInstanceOf[IntegralType] + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 9defd87aa7d..08615b90d80 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -50,7 +50,7 @@ import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateFormatter, Type import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{AtomicType, DateType, IntegralType, StringType} +import org.apache.spark.sql.types.{AtomicType, DateType, IntegralType, IntegralTypeExpression, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -987,7 +987,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { def unapply(expr: Expression): Option[Attribute] = { expr match { case attr: Attribute => Some(attr) - case Cast(child @ IntegralType(), dt: IntegralType, _, _) + case Cast(child @ IntegralTypeExpression(), dt: IntegralType, _, _) if Cast.canUpCast(child.dataType.asInstanceOf[AtomicType], dt) => unapply(child) case _ => None } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org