This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new d2332d9c47e  [SPARK-39470][SQL] Support cast of ANSI intervals to decimals
d2332d9c47e is described below

commit d2332d9c47e8f250a015d6dc5edb028b334aa905
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Thu Jun 16 12:07:43 2022 +0300

    [SPARK-39470][SQL] Support cast of ANSI intervals to decimals

    ### What changes were proposed in this pull request?
    In the PR, I propose to support casts of ANSI intervals to decimals, following the SQL standard:
    <img width="801" alt="Screenshot 2022-06-12 at 13 04 44" src="https://user-images.githubusercontent.com/1580697/173663908-71945980-5638-4b46-9020-4d2e4badef0c.png">

    ### Why are the changes needed?
    To improve the user experience with Spark SQL, and to conform to the SQL standard.

    ### Does this PR introduce _any_ user-facing change?
    No, it just extends the existing behavior of casts.

    Before:
    ```sql
    spark-sql> SELECT CAST(INTERVAL '1.001002' SECOND AS DECIMAL(10, 6));
    Error in query: cannot resolve 'CAST(INTERVAL '01.001002' SECOND AS DECIMAL(10,6))' due to data type mismatch: cannot cast interval second to decimal(10,6); line 1 pos 7;
    'Project [unresolvedalias(cast(INTERVAL '01.001002' SECOND as decimal(10,6)), None)]
    +- OneRowRelation
    ```

    After:
    ```sql
    spark-sql> SELECT CAST(INTERVAL '1.001002' SECOND AS DECIMAL(10, 6));
    1.001002
    ```

    ### How was this patch tested?
    By running new tests:
    ```
    $ build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite -- -z cast.sql"
    $ build/sbt "test:testOnly *CastWithAnsiOnSuite"
    $ build/sbt "test:testOnly *CastWithAnsiOffSuite"
    ```

    Closes #36857 from MaxGekk/cast-ansi-intervals-to-decimal.

    Authored-by: Max Gekk <max.g...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
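For readers who want to try the change end to end, here is a minimal usage sketch (not part of the commit) that exercises the new cast through a local SparkSession; the `local[1]` master and the app name are illustrative, and it assumes a Spark build that includes this commit:

```scala
import org.apache.spark.sql.SparkSession

object IntervalToDecimalExample {
  def main(args: Array[String]): Unit = {
    // Assumes a Spark build that includes commit d2332d9c47e.
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("interval-to-decimal")
      .getOrCreate()

    // INTERVAL '1.001002' SECOND is stored as 1001002 microseconds,
    // so DECIMAL(10, 6) represents it exactly.
    spark.sql("SELECT CAST(INTERVAL '1.001002' SECOND AS DECIMAL(10, 6))").show(truncate = false)
    // Expected single value: 1.001002

    spark.stop()
  }
}
```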
---
 .../spark/sql/catalyst/expressions/Cast.scala      | 59 ++++++++++++++++---
 .../spark/sql/catalyst/util/IntervalUtils.scala    |  9 +++
 .../sql/catalyst/expressions/CastSuiteBase.scala   | 33 +++++++++++
 .../src/test/resources/sql-tests/inputs/cast.sql   | 10 ++++
 .../resources/sql-tests/results/ansi/cast.sql.out  | 68 ++++++++++++++++++++++
 .../test/resources/sql-tests/results/cast.sql.out  | 65 +++++++++++++++++++++
 6 files changed, 237 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 0746bc0fcd0..45950607e0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
-import org.apache.spark.sql.catalyst.util.IntervalUtils.{dayTimeIntervalToByte, dayTimeIntervalToInt, dayTimeIntervalToLong, dayTimeIntervalToShort, yearMonthIntervalToByte, yearMonthIntervalToInt, yearMonthIntervalToShort}
+import org.apache.spark.sql.catalyst.util.IntervalUtils.{dayTimeIntervalToByte, dayTimeIntervalToDecimal, dayTimeIntervalToInt, dayTimeIntervalToLong, dayTimeIntervalToShort, yearMonthIntervalToByte, yearMonthIntervalToInt, yearMonthIntervalToShort}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -110,7 +110,7 @@ object Cast {
     case (StringType, _: CalendarIntervalType) => true
     case (StringType, _: AnsiIntervalType) => true

-    case (_: AnsiIntervalType, _: IntegralType) => true
+    case (_: AnsiIntervalType, _: IntegralType | _: DecimalType) => true

     case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true
     case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true
@@ -194,8 +194,7 @@ object Cast {
     case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true
     case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true

-    case (_: DayTimeIntervalType, _: IntegralType) => true
-    case (_: YearMonthIntervalType, _: IntegralType) => true
+    case (_: AnsiIntervalType, _: IntegralType | _: DecimalType) => true

     case (StringType, _: NumericType) => true
     case (BooleanType, _: NumericType) => true
@@ -967,10 +966,17 @@ case class Cast(
    * NOTE: this modifies `value` in-place, so don't call it on external data.
    */
   private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = {
+    changePrecision(value, decimalType, !ansiEnabled)
+  }
+
+  private[this] def changePrecision(
+      value: Decimal,
+      decimalType: DecimalType,
+      nullOnOverflow: Boolean): Decimal = {
     if (value.changePrecision(decimalType.precision, decimalType.scale)) {
       value
     } else {
-      if (!ansiEnabled) {
+      if (nullOnOverflow) {
         null
       } else {
         throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(
@@ -1015,6 +1021,18 @@ case class Cast(
       } catch {
         case _: NumberFormatException => null
       }
+    case x: DayTimeIntervalType =>
+      buildCast[Long](_, dt =>
+        changePrecision(
+          value = dayTimeIntervalToDecimal(dt, x.endField),
+          decimalType = target,
+          nullOnOverflow = false))
+    case x: YearMonthIntervalType =>
+      buildCast[Int](_, ym =>
+        changePrecision(
+          value = Decimal(yearMonthIntervalToInt(ym, x.startField, x.endField)),
+          decimalType = target,
+          nullOnOverflow = false))
   }

   // DoubleConverter
@@ -1515,14 +1533,15 @@ case class Cast(
       evPrim: ExprValue,
       evNull: ExprValue,
       canNullSafeCast: Boolean,
-      ctx: CodegenContext): Block = {
+      ctx: CodegenContext,
+      nullOnOverflow: Boolean): Block = {
     if (canNullSafeCast) {
       code"""
          |$d.changePrecision(${decimalType.precision}, ${decimalType.scale});
          |$evPrim = $d;
        """.stripMargin
     } else {
-      val overflowCode = if (!ansiEnabled) {
+      val overflowCode = if (nullOnOverflow) {
         s"$evNull = true;"
       } else {
         s"""
@@ -1540,6 +1559,16 @@ case class Cast(
     }
   }

+  private[this] def changePrecision(
+      d: ExprValue,
+      decimalType: DecimalType,
+      evPrim: ExprValue,
+      evNull: ExprValue,
+      canNullSafeCast: Boolean,
+      ctx: CodegenContext): Block = {
+    changePrecision(d, decimalType, evPrim, evNull, canNullSafeCast, ctx, !ansiEnabled)
+  }
+
   private[this] def castToDecimalCode(
       from: DataType,
       target: DecimalType,
@@ -1605,6 +1634,22 @@ case class Cast(
             $evNull = true;
           }
         """
+      case x: DayTimeIntervalType =>
+        (c, evPrim, evNull) =>
+          val u = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
+          code"""
+            Decimal $tmp = $u.dayTimeIntervalToDecimal($c, (byte)${x.endField});
+            ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast, ctx, false)}
+          """
+      case x: YearMonthIntervalType =>
+        (c, evPrim, evNull) =>
+          val u = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
+          val tmpYm = ctx.freshVariable("tmpYm", classOf[Int])
+          code"""
+            int $tmpYm = $u.yearMonthIntervalToInt($c, (byte)${x.startField}, (byte)${x.endField});
+            Decimal $tmp = Decimal.apply($tmpYm);
+            ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast, ctx, false)}
+          """
     }
   }
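A note on the Cast.scala changes above: both the interpreted path and the generated code pass `nullOnOverflow = false` for interval sources, so an interval-to-decimal cast that does not fit the target precision raises an error even when ANSI mode is off, unlike most other decimal casts, which return NULL there. A sketch of the observable difference, assuming an existing SparkSession named `spark`:

```scala
// Sketch only: `spark` is an existing SparkSession.
spark.conf.set("spark.sql.ansi.enabled", "false")

// A numeric source that overflows DECIMAL(2, 0) yields NULL when ANSI is off...
spark.sql("SELECT CAST(12345 AS DECIMAL(2, 0))").show()

// ...but an interval source throws regardless, because the converters above
// hard-code nullOnOverflow = false:
// spark.sql("SELECT CAST(INTERVAL '10.123' SECOND AS DECIMAL(1, 0))").show()
// => org.apache.spark.SparkArithmeticException (CANNOT_CHANGE_DECIMAL_PRECISION)
```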
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index dad58b7ae45..721f50208ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -1346,6 +1346,15 @@ object IntervalUtils {
     }
   }

+  def dayTimeIntervalToDecimal(v: Long, endField: Byte): Decimal = {
+    endField match {
+      case DAY => Decimal(v / MICROS_PER_DAY)
+      case HOUR => Decimal(v / MICROS_PER_HOUR)
+      case MINUTE => Decimal(v / MICROS_PER_MINUTE)
+      case SECOND => Decimal(v, Decimal.MAX_LONG_DIGITS, 6)
+    }
+  }
+
   def dayTimeIntervalToInt(v: Long, startField: Byte, endField: Byte): Int = {
     val vLong = dayTimeIntervalToLong(v, startField, endField)
     val vInt = vLong.toInt
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index ca492e11226..97cbc781829 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -1272,4 +1272,37 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
         "to restore the behavior before Spark 3.0."))
     }
   }
+
+  test("cast ANSI intervals to decimals") {
+    Seq(
+      (Duration.ZERO, DayTimeIntervalType(DAY), DecimalType(10, 3)) -> Decimal(0, 10, 3),
+      (Duration.ofHours(-1), DayTimeIntervalType(HOUR), DecimalType(10, 1)) -> Decimal(-10, 10, 1),
+      (Duration.ofMinutes(1), DayTimeIntervalType(MINUTE), DecimalType(8, 2)) -> Decimal(100, 8, 2),
+      (Duration.ofSeconds(59), DayTimeIntervalType(SECOND), DecimalType(6, 0)) -> Decimal(59, 6, 0),
+      (Duration.ofSeconds(-60).minusMillis(1), DayTimeIntervalType(SECOND),
+        DecimalType(10, 3)) -> Decimal(-60.001, 10, 3),
+      (Duration.ZERO, DayTimeIntervalType(DAY, SECOND), DecimalType(10, 6)) -> Decimal(0, 10, 6),
+      (Duration.ofHours(-23).minusMinutes(59).minusSeconds(59).minusNanos(123456000),
+        DayTimeIntervalType(HOUR, SECOND), DecimalType(18, 6)) -> Decimal(-86399.123456, 18, 6),
+      (Period.ZERO, YearMonthIntervalType(YEAR), DecimalType(5, 2)) -> Decimal(0, 5, 2),
+      (Period.ofMonths(-1), YearMonthIntervalType(MONTH),
+        DecimalType(8, 0)) -> Decimal(-1, 8, 0),
+      (Period.ofYears(-1).minusMonths(1), YearMonthIntervalType(YEAR, MONTH),
+        DecimalType(8, 3)) -> Decimal(-13000, 8, 3)
+    ).foreach { case ((duration, intervalType, targetType), expected) =>
+      checkEvaluation(
+        Cast(Literal.create(duration, intervalType), targetType),
+        expected)
+    }
+
+    dayTimeIntervalTypes.foreach { it =>
+      checkConsistencyBetweenInterpretedAndCodegenAllowingException((child: Expression) =>
+        Cast(child, DecimalType.USER_DEFAULT), it)
+    }
+
+    yearMonthIntervalTypes.foreach { it =>
+      checkConsistencyBetweenInterpretedAndCodegenAllowingException((child: Expression) =>
+        Cast(child, DecimalType.USER_DEFAULT), it)
+    }
+  }
 }
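The conversion rule added to IntervalUtils above is easy to check by hand: a day-time interval is physically a Long count of microseconds, and `endField` picks the unit; DAY, HOUR, and MINUTE truncate via integer division, while SECOND keeps all six fractional digits as `Decimal(v, 18, 6)`. A plain-Scala sketch (no Spark classes, constants written out) reproducing the HOUR TO SECOND example from the SQL tests below:

```scala
import java.math.{BigDecimal => JBigDecimal}

object IntervalMicrosSketch {
  val MICROS_PER_SECOND = 1000000L
  val MICROS_PER_MINUTE = 60L * MICROS_PER_SECOND
  val MICROS_PER_HOUR = 60L * MICROS_PER_MINUTE

  // SECOND end field: keep the 6 fractional digits, like Decimal(v, 18, 6).
  def secondsAsDecimal(micros: Long): JBigDecimal = JBigDecimal.valueOf(micros, 6)

  def main(args: Array[String]): Unit = {
    // INTERVAL '08:11:10.001' HOUR TO SECOND = 29470001000 microseconds.
    val micros = 8 * MICROS_PER_HOUR + 11 * MICROS_PER_MINUTE + 10001000L
    assert(secondsAsDecimal(micros) == new JBigDecimal("29470.001000"))
  }
}
```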
diff --git a/sql/core/src/test/resources/sql-tests/inputs/cast.sql b/sql/core/src/test/resources/sql-tests/inputs/cast.sql
index 5198611a2b3..66a78ec9473 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/cast.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/cast.sql
@@ -116,3 +116,13 @@ select cast(interval '10' day as bigint);

 select cast(interval '-1000' month as tinyint);
 select cast(interval '1000000' second as smallint);
+
+-- cast ANSI intervals to decimals
+select cast(interval '-1' year as decimal(10, 0));
+select cast(interval '1.000001' second as decimal(10, 6));
+select cast(interval '08:11:10.001' hour to second as decimal(10, 4));
+select cast(interval '1 01:02:03.1' day to second as decimal(8, 1));
+select cast(interval '10.123' second as decimal(4, 2));
+select cast(interval '10.005' second as decimal(4, 2));
+select cast(interval '10.123' second as decimal(5, 2));
+select cast(interval '10.123' second as decimal(1, 0));
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out
index 11753f2b5ca..470a6081c46 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out
@@ -838,3 +838,71 @@ struct<>
 -- !query output
 org.apache.spark.SparkArithmeticException
 [CAST_OVERFLOW] The value INTERVAL '1000000' SECOND of the type "INTERVAL SECOND" cannot be cast to "SMALLINT" due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
+
+
+-- !query
+select cast(interval '-1' year as decimal(10, 0))
+-- !query schema
+struct<CAST(INTERVAL '-1' YEAR AS DECIMAL(10,0)):decimal(10,0)>
+-- !query output
+-1
+
+
+-- !query
+select cast(interval '1.000001' second as decimal(10, 6))
+-- !query schema
+struct<CAST(INTERVAL '01.000001' SECOND AS DECIMAL(10,6)):decimal(10,6)>
+-- !query output
+1.000001
+
+
+-- !query
+select cast(interval '08:11:10.001' hour to second as decimal(10, 4))
+-- !query schema
+struct<CAST(INTERVAL '08:11:10.001' HOUR TO SECOND AS DECIMAL(10,4)):decimal(10,4)>
+-- !query output
+29470.0010
+
+
+-- !query
+select cast(interval '1 01:02:03.1' day to second as decimal(8, 1))
+-- !query schema
+struct<CAST(INTERVAL '1 01:02:03.1' DAY TO SECOND AS DECIMAL(8,1)):decimal(8,1)>
+-- !query output
+90123.1
+
+
+-- !query
+select cast(interval '10.123' second as decimal(4, 2))
+-- !query schema
+struct<CAST(INTERVAL '10.123' SECOND AS DECIMAL(4,2)):decimal(4,2)>
+-- !query output
+10.12
+
+
+-- !query
+select cast(interval '10.005' second as decimal(4, 2))
+-- !query schema
+struct<CAST(INTERVAL '10.005' SECOND AS DECIMAL(4,2)):decimal(4,2)>
+-- !query output
+10.01
+
+
+-- !query
+select cast(interval '10.123' second as decimal(5, 2))
+-- !query schema
+struct<CAST(INTERVAL '10.123' SECOND AS DECIMAL(5,2)):decimal(5,2)>
+-- !query output
+10.12
+
+
+-- !query
+select cast(interval '10.123' second as decimal(1, 0))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+[CANNOT_CHANGE_DECIMAL_PRECISION] Decimal(compact, 10, 18, 6) cannot be represented as Decimal(1, 0). If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
+== SQL(line 1, position 8) ==
+select cast(interval '10.123' second as decimal(1, 0))
+       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
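One detail visible in the expected outputs above: `interval '10.005' second` casts to `10.01` under `decimal(4, 2)` because `Decimal.changePrecision` rounds half-up when narrowing the scale (half-even rounding would have produced `10.00`). A one-line check with plain `java.math` types (variable names are illustrative):

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

// 10.005 seconds held at microsecond precision, as in Decimal(v, 18, 6).
val seconds = JBigDecimal.valueOf(10005000L, 6)
// Narrowing to scale 2 with HALF_UP reproduces the 10.01 above.
assert(seconds.setScale(2, RoundingMode.HALF_UP) == new JBigDecimal("10.01"))
```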
diff --git a/sql/core/src/test/resources/sql-tests/results/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/cast.sql.out
index 9c00e1b985e..911eaff30b9 100644
--- a/sql/core/src/test/resources/sql-tests/results/cast.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/cast.sql.out
@@ -666,3 +666,68 @@ struct<>
 -- !query output
 org.apache.spark.SparkArithmeticException
 [CAST_OVERFLOW] The value INTERVAL '1000000' SECOND of the type "INTERVAL SECOND" cannot be cast to "SMALLINT" due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
+
+
+-- !query
+select cast(interval '-1' year as decimal(10, 0))
+-- !query schema
+struct<CAST(INTERVAL '-1' YEAR AS DECIMAL(10,0)):decimal(10,0)>
+-- !query output
+-1
+
+
+-- !query
+select cast(interval '1.000001' second as decimal(10, 6))
+-- !query schema
+struct<CAST(INTERVAL '01.000001' SECOND AS DECIMAL(10,6)):decimal(10,6)>
+-- !query output
+1.000001
+
+
+-- !query
+select cast(interval '08:11:10.001' hour to second as decimal(10, 4))
+-- !query schema
+struct<CAST(INTERVAL '08:11:10.001' HOUR TO SECOND AS DECIMAL(10,4)):decimal(10,4)>
+-- !query output
+29470.0010
+
+
+-- !query
+select cast(interval '1 01:02:03.1' day to second as decimal(8, 1))
+-- !query schema
+struct<CAST(INTERVAL '1 01:02:03.1' DAY TO SECOND AS DECIMAL(8,1)):decimal(8,1)>
+-- !query output
+90123.1
+
+
+-- !query
+select cast(interval '10.123' second as decimal(4, 2))
+-- !query schema
+struct<CAST(INTERVAL '10.123' SECOND AS DECIMAL(4,2)):decimal(4,2)>
+-- !query output
+10.12
+
+
+-- !query
+select cast(interval '10.005' second as decimal(4, 2))
+-- !query schema
+struct<CAST(INTERVAL '10.005' SECOND AS DECIMAL(4,2)):decimal(4,2)>
+-- !query output
+10.01
+
+
+-- !query
+select cast(interval '10.123' second as decimal(5, 2))
+-- !query schema
+struct<CAST(INTERVAL '10.123' SECOND AS DECIMAL(5,2)):decimal(5,2)>
+-- !query output
+10.12
+
+
+-- !query
+select cast(interval '10.123' second as decimal(1, 0))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+[CANNOT_CHANGE_DECIMAL_PRECISION] Decimal(compact, 10, 18, 6) cannot be represented as Decimal(1, 0). If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org