This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 61ce8f7 [SPARK-35680][SQL] Add fields to `YearMonthIntervalType` 61ce8f7 is described below commit 61ce8f764982306f2c7a8b2b3dfe22963b49f2d5 Author: Max Gekk <max.g...@gmail.com> AuthorDate: Tue Jun 15 23:08:12 2021 +0300 [SPARK-35680][SQL] Add fields to `YearMonthIntervalType` ### What changes were proposed in this pull request? Extend `YearMonthIntervalType` to support interval fields. Valid interval field values: - 0 (YEAR) - 1 (MONTH) After the changes, the following year-month interval types are supported: 1. `YearMonthIntervalType(0, 0)` or `YearMonthIntervalType(YEAR, YEAR)` 2. `YearMonthIntervalType(0, 1)` or `YearMonthIntervalType(YEAR, MONTH)`. **It is the default one**. 3. `YearMonthIntervalType(1, 1)` or `YearMonthIntervalType(MONTH, MONTH)` Closes #32825 ### Why are the changes needed? In the current implementation, Spark supports only `interval year to month`, but the SQL standard allows specifying the start and end fields. The changes will allow Spark to follow the ANSI SQL standard more precisely. ### Does this PR introduce _any_ user-facing change? Yes, but `YearMonthIntervalType` has not been released yet. ### How was this patch tested? By existing test suites. Closes #32909 from MaxGekk/add-fields-to-YearMonthIntervalType. 
Authored-by: Max Gekk <max.g...@gmail.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- .../spark/sql/catalyst/expressions/UnsafeRow.java | 11 ++--- .../java/org/apache/spark/sql/types/DataTypes.java | 19 +++++--- .../sql/catalyst/CatalystTypeConverters.scala | 3 +- .../apache/spark/sql/catalyst/InternalRow.scala | 4 +- .../spark/sql/catalyst/JavaTypeInference.scala | 2 +- .../spark/sql/catalyst/ScalaReflection.scala | 10 ++--- .../spark/sql/catalyst/SerializerBuildHelper.scala | 2 +- .../spark/sql/catalyst/analysis/Analyzer.scala | 18 ++++---- .../apache/spark/sql/catalyst/dsl/package.scala | 7 ++- .../spark/sql/catalyst/encoders/RowEncoder.scala | 6 +-- .../spark/sql/catalyst/expressions/Cast.scala | 35 +++++++++------ .../expressions/InterpretedUnsafeProjection.scala | 2 +- .../catalyst/expressions/SpecificInternalRow.scala | 2 +- .../catalyst/expressions/aggregate/Average.scala | 6 +-- .../sql/catalyst/expressions/aggregate/Sum.scala | 2 +- .../sql/catalyst/expressions/arithmetic.scala | 10 ++--- .../expressions/codegen/CodeGenerator.scala | 4 +- .../expressions/collectionOperations.scala | 8 ++-- .../catalyst/expressions/datetimeExpressions.scala | 2 +- .../spark/sql/catalyst/expressions/hash.scala | 2 +- .../catalyst/expressions/intervalExpressions.scala | 10 ++--- .../spark/sql/catalyst/expressions/literals.scala | 16 ++++--- .../catalyst/expressions/windowExpressions.scala | 4 +- .../spark/sql/catalyst/parser/AstBuilder.scala | 6 ++- .../spark/sql/catalyst/util/IntervalUtils.scala | 17 ++++++- .../apache/spark/sql/catalyst/util/TypeUtils.scala | 4 +- .../spark/sql/errors/QueryCompilationErrors.scala | 11 +++++ .../org/apache/spark/sql/types/DataType.scala | 4 +- .../spark/sql/types/YearMonthIntervalType.scala | 52 ++++++++++++++++++---- .../org/apache/spark/sql/util/ArrowUtils.scala | 4 +- .../org/apache/spark/sql/RandomDataGenerator.scala | 2 +- .../spark/sql/RandomDataGeneratorSuite.scala | 4 +- .../sql/catalyst/CatalystTypeConvertersSuite.scala | 2 
+- .../sql/catalyst/encoders/RowEncoderSuite.scala | 18 ++++---- .../expressions/ArithmeticExpressionSuite.scala | 10 ++--- .../spark/sql/catalyst/expressions/CastSuite.scala | 30 ++++++------- .../sql/catalyst/expressions/CastSuiteBase.scala | 8 ++-- .../expressions/DateExpressionsSuite.scala | 19 ++++---- .../expressions/HashExpressionsSuite.scala | 2 +- .../expressions/IntervalExpressionsSuite.scala | 24 ++++++---- .../expressions/LiteralExpressionSuite.scala | 6 +-- .../catalyst/expressions/LiteralGenerator.scala | 4 +- .../expressions/MutableProjectionSuite.scala | 14 +++--- .../optimizer/PushFoldableIntoBranchesSuite.scala | 16 +++---- .../sql/catalyst/parser/DataTypeParserSuite.scala | 2 +- .../sql/catalyst/util/IntervalUtilsSuite.scala | 5 ++- .../org/apache/spark/sql/types/DataTypeSuite.scala | 6 +-- .../apache/spark/sql/types/DataTypeTestUtils.scala | 18 ++++---- .../apache/spark/sql/util/ArrowUtilsSuite.scala | 2 +- .../apache/spark/sql/execution/HiveResult.scala | 4 +- .../sql/execution/aggregate/HashMapGenerator.scala | 3 +- .../spark/sql/execution/aggregate/udaf.scala | 4 +- .../spark/sql/execution/arrow/ArrowWriter.scala | 2 +- .../sql/execution/columnar/ColumnAccessor.scala | 2 +- .../sql/execution/columnar/ColumnBuilder.scala | 2 +- .../spark/sql/execution/columnar/ColumnType.scala | 4 +- .../columnar/GenerateColumnAccessor.scala | 2 +- .../sql/execution/window/WindowExecBase.scala | 4 +- .../apache/spark/sql/DataFrameAggregateSuite.scala | 14 ++++-- .../test/scala/org/apache/spark/sql/UDFSuite.scala | 9 ++-- .../sql/execution/arrow/ArrowWriterSuite.scala | 9 ++-- .../SparkExecuteStatementOperation.scala | 4 +- .../thriftserver/SparkGetColumnsOperation.scala | 5 ++- .../thriftserver/SparkMetadataOperationSuite.scala | 4 +- .../org/apache/spark/sql/hive/HiveInspectors.scala | 10 ++--- .../execution/HiveScriptTransformationSuite.scala | 3 +- 66 files changed, 340 insertions(+), 220 deletions(-) diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 572e901..5088d06 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -76,7 +76,7 @@ public final class UnsafeRow extends InternalRow implements Externalizable, Kryo */ public static final Set<DataType> mutableFieldTypes; - // DecimalType and DayTimeIntervalType are also mutable + // DecimalType, DayTimeIntervalType and YearMonthIntervalType are also mutable static { mutableFieldTypes = Collections.unmodifiableSet( new HashSet<>( @@ -90,8 +90,7 @@ public final class UnsafeRow extends InternalRow implements Externalizable, Kryo FloatType, DoubleType, DateType, - TimestampType, - YearMonthIntervalType + TimestampType }))); } @@ -103,7 +102,8 @@ public final class UnsafeRow extends InternalRow implements Externalizable, Kryo if (dt instanceof DecimalType) { return ((DecimalType) dt).precision() <= Decimal.MAX_LONG_DIGITS(); } else { - return dt instanceof DayTimeIntervalType || mutableFieldTypes.contains(dt); + return dt instanceof DayTimeIntervalType || dt instanceof YearMonthIntervalType || + mutableFieldTypes.contains(dt); } } @@ -113,7 +113,8 @@ public final class UnsafeRow extends InternalRow implements Externalizable, Kryo } return mutableFieldTypes.contains(dt) || dt instanceof DecimalType || - dt instanceof CalendarIntervalType || dt instanceof DayTimeIntervalType; + dt instanceof CalendarIntervalType || dt instanceof DayTimeIntervalType || + dt instanceof YearMonthIntervalType; } ////////////////////////////////////////////////////////////////////////////// diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java index 6604837..25d0a00 100644 --- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java @@ -100,11 +100,6 @@ public class DataTypes { public static final DataType NullType = NullType$.MODULE$; /** - * Gets the YearMonthIntervalType object. - */ - public static final DataType YearMonthIntervalType = YearMonthIntervalType$.MODULE$; - - /** * Creates an ArrayType by specifying the data type of elements ({@code elementType}). * The field of {@code containsNull} is set to {@code true}. */ @@ -155,6 +150,20 @@ public class DataTypes { } /** + * Creates a YearMonthIntervalType by specifying the start and end fields. + */ + public static YearMonthIntervalType createYearMonthIntervalType(byte startField, byte endField) { + return YearMonthIntervalType$.MODULE$.apply(startField, endField); + } + + /** + * Creates a YearMonthIntervalType with default start and end fields: interval year to month. + */ + public static YearMonthIntervalType createYearMonthIntervalType() { + return YearMonthIntervalType$.MODULE$.DEFAULT(); + } + + /** * Creates a MapType by specifying the data type of keys ({@code keyType}) and values * ({@code keyType}). The field of {@code valueContainsNull} is set to {@code true}. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 2efdf37..38790e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -77,7 +77,8 @@ object CatalystTypeConverters { case DoubleType => DoubleConverter // TODO(SPARK-35726): Truncate java.time.Duration by fields of day-time interval type case _: DayTimeIntervalType => DurationConverter - case YearMonthIntervalType => PeriodConverter + // TODO(SPARK-35769): Truncate java.time.Period by fields of year-month interval type + case _: YearMonthIntervalType => PeriodConverter case dataType: DataType => IdentityConverter(dataType) } converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index ab668d7..b431d67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -132,7 +132,7 @@ object InternalRow { case BooleanType => (input, ordinal) => input.getBoolean(ordinal) case ByteType => (input, ordinal) => input.getByte(ordinal) case ShortType => (input, ordinal) => input.getShort(ordinal) - case IntegerType | DateType | YearMonthIntervalType => + case IntegerType | DateType | _: YearMonthIntervalType => (input, ordinal) => input.getInt(ordinal) case LongType | TimestampType | TimestampWithoutTZType | _: DayTimeIntervalType => (input, ordinal) => input.getLong(ordinal) @@ -169,7 +169,7 @@ object InternalRow { case BooleanType => (input, v) => input.setBoolean(ordinal, v.asInstanceOf[Boolean]) case ByteType => (input, v) => input.setByte(ordinal, 
v.asInstanceOf[Byte]) case ShortType => (input, v) => input.setShort(ordinal, v.asInstanceOf[Short]) - case IntegerType | DateType | YearMonthIntervalType => + case IntegerType | DateType | _: YearMonthIntervalType => (input, v) => input.setInt(ordinal, v.asInstanceOf[Int]) case LongType | TimestampType | TimestampWithoutTZType | _: DayTimeIntervalType => (input, v) => input.setLong(ordinal, v.asInstanceOf[Long]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 6a31d4a..3a51957d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -121,7 +121,7 @@ object JavaTypeInference { case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) case c: Class[_] if c == classOf[java.time.LocalDateTime] => (TimestampWithoutTZType, true) case c: Class[_] if c == classOf[java.time.Duration] => (DayTimeIntervalType(), true) - case c: Class[_] if c == classOf[java.time.Period] => (YearMonthIntervalType, true) + case c: Class[_] if c == classOf[java.time.Period] => (YearMonthIntervalType(), true) case _ if typeToken.isArray => val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index c6854e9..680b61f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -761,7 +761,7 @@ object ScalaReflection extends ScalaReflection { case t if isSubtype(t, localTypeOf[java.time.Duration]) => Schema(DayTimeIntervalType(), nullable = true) case t if isSubtype(t, localTypeOf[java.time.Period]) 
=> - Schema(YearMonthIntervalType, nullable = true) + Schema(YearMonthIntervalType(), nullable = true) case t if isSubtype(t, localTypeOf[BigDecimal]) => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => @@ -860,8 +860,7 @@ object ScalaReflection extends ScalaReflection { TimestampType -> classOf[TimestampType.InternalType], TimestampWithoutTZType -> classOf[TimestampWithoutTZType.InternalType], BinaryType -> classOf[BinaryType.InternalType], - CalendarIntervalType -> classOf[CalendarInterval], - YearMonthIntervalType -> classOf[YearMonthIntervalType.InternalType] + CalendarIntervalType -> classOf[CalendarInterval] ) val typeBoxedJavaMapping = Map[DataType, Class[_]]( @@ -874,14 +873,14 @@ object ScalaReflection extends ScalaReflection { DoubleType -> classOf[java.lang.Double], DateType -> classOf[java.lang.Integer], TimestampType -> classOf[java.lang.Long], - TimestampWithoutTZType -> classOf[java.lang.Long], - YearMonthIntervalType -> classOf[java.lang.Integer] + TimestampWithoutTZType -> classOf[java.lang.Long] ) def dataTypeJavaClass(dt: DataType): Class[_] = { dt match { case _: DecimalType => classOf[Decimal] case it: DayTimeIntervalType => classOf[it.InternalType] + case it: YearMonthIntervalType => classOf[it.InternalType] case _: StructType => classOf[InternalRow] case _: ArrayType => classOf[ArrayData] case _: MapType => classOf[MapData] @@ -893,6 +892,7 @@ object ScalaReflection extends ScalaReflection { def javaBoxedType(dt: DataType): Class[_] = dt match { case _: DecimalType => classOf[Decimal] case _: DayTimeIntervalType => classOf[java.lang.Long] + case _: YearMonthIntervalType => classOf[java.lang.Integer] case BinaryType => classOf[Array[Byte]] case StringType => classOf[UTF8String] case CalendarIntervalType => classOf[CalendarInterval] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala index 9045d53..9b81f07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala @@ -125,7 +125,7 @@ object SerializerBuildHelper { def createSerializerForJavaPeriod(inputObject: Expression): Expression = { StaticInvoke( IntervalUtils.getClass, - YearMonthIntervalType, + YearMonthIntervalType(), "periodToMonths", inputObject :: Nil, returnNullable = false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 668a661..e22c012 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -351,10 +351,10 @@ class Analyzer(override val catalogManager: CatalogManager) case a @ Add(l, r, f) if a.childrenResolved => (l.dataType, r.dataType) match { case (DateType, _: DayTimeIntervalType) => TimeAdd(Cast(l, TimestampType), r) case (_: DayTimeIntervalType, DateType) => TimeAdd(Cast(r, TimestampType), l) - case (DateType, YearMonthIntervalType) => DateAddYMInterval(l, r) - case (YearMonthIntervalType, DateType) => DateAddYMInterval(r, l) - case (TimestampType, YearMonthIntervalType) => TimestampAddYMInterval(l, r) - case (YearMonthIntervalType, TimestampType) => TimestampAddYMInterval(r, l) + case (DateType, _: YearMonthIntervalType) => DateAddYMInterval(l, r) + case (_: YearMonthIntervalType, DateType) => DateAddYMInterval(r, l) + case (TimestampType, _: YearMonthIntervalType) => TimestampAddYMInterval(l, r) + case (_: YearMonthIntervalType, TimestampType) => TimestampAddYMInterval(r, l) case (CalendarIntervalType, CalendarIntervalType) | (_: DayTimeIntervalType, _: DayTimeIntervalType) => a case 
(DateType, CalendarIntervalType) => DateAddInterval(l, r, ansiEnabled = f) @@ -368,9 +368,9 @@ class Analyzer(override val catalogManager: CatalogManager) case s @ Subtract(l, r, f) if s.childrenResolved => (l.dataType, r.dataType) match { case (DateType, _: DayTimeIntervalType) => DatetimeSub(l, r, TimeAdd(Cast(l, TimestampType), UnaryMinus(r, f))) - case (DateType, YearMonthIntervalType) => + case (DateType, _: YearMonthIntervalType) => DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, f))) - case (TimestampType, YearMonthIntervalType) => + case (TimestampType, _: YearMonthIntervalType) => DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, f))) case (CalendarIntervalType, CalendarIntervalType) | (_: DayTimeIntervalType, _: DayTimeIntervalType) => s @@ -387,15 +387,15 @@ class Analyzer(override val catalogManager: CatalogManager) case m @ Multiply(l, r, f) if m.childrenResolved => (l.dataType, r.dataType) match { case (CalendarIntervalType, _) => MultiplyInterval(l, r, f) case (_, CalendarIntervalType) => MultiplyInterval(r, l, f) - case (YearMonthIntervalType, _) => MultiplyYMInterval(l, r) - case (_, YearMonthIntervalType) => MultiplyYMInterval(r, l) + case (_: YearMonthIntervalType, _) => MultiplyYMInterval(l, r) + case (_, _: YearMonthIntervalType) => MultiplyYMInterval(r, l) case (_: DayTimeIntervalType, _) => MultiplyDTInterval(l, r) case (_, _: DayTimeIntervalType) => MultiplyDTInterval(r, l) case _ => m } case d @ Divide(l, r, f) if d.childrenResolved => (l.dataType, r.dataType) match { case (CalendarIntervalType, _) => DivideInterval(l, r, f) - case (YearMonthIntervalType, _) => DivideYMInterval(l, r) + case (_: YearMonthIntervalType, _) => DivideYMInterval(l, r) case (_: DayTimeIntervalType, _) => DivideDTInterval(l, r) case _ => d } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 211af78..b8f74ee7 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -310,8 +310,11 @@ package object dsl { } /** Creates a new AttributeReference of the year-month interval type */ - def yearMonthInterval: AttributeReference = { - AttributeReference(s, YearMonthIntervalType, nullable = true)() + def yearMonthInterval(startField: Byte, endField: Byte): AttributeReference = { + AttributeReference(s, YearMonthIntervalType(startField, endField), nullable = true)() + } + def yearMonthInterval(): AttributeReference = { + AttributeReference(s, YearMonthIntervalType(), nullable = true)() } /** Creates a new AttributeReference of type binary */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index cf22b35..5373a2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -116,7 +116,7 @@ object RowEncoder { case _: DayTimeIntervalType => createSerializerForJavaDuration(inputObject) - case YearMonthIntervalType => createSerializerForJavaPeriod(inputObject) + case _: YearMonthIntervalType => createSerializerForJavaPeriod(inputObject) case d: DecimalType => CheckOverflow(StaticInvoke( @@ -239,7 +239,7 @@ object RowEncoder { ObjectType(classOf[java.sql.Date]) } case _: DayTimeIntervalType => ObjectType(classOf[java.time.Duration]) - case YearMonthIntervalType => ObjectType(classOf[java.time.Period]) + case _: YearMonthIntervalType => ObjectType(classOf[java.time.Period]) case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) case StringType => ObjectType(classOf[java.lang.String]) case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]]) @@ -299,7 +299,7 @@ object RowEncoder { case _: DayTimeIntervalType 
=> createDeserializerForDuration(input) - case YearMonthIntervalType => createDeserializerForPeriod(input) + case _: YearMonthIntervalType => createDeserializerForPeriod(input) case _: DecimalType => createDeserializerForJavaBigDecimal(input, returnNullable = false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 5598085..1bf2851 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -79,7 +79,7 @@ object Cast { case (StringType, CalendarIntervalType) => true case (StringType, _: DayTimeIntervalType) => true - case (StringType, YearMonthIntervalType) => true + case (StringType, _: YearMonthIntervalType) => true case (StringType, _: NumericType) => true case (BooleanType, _: NumericType) => true @@ -422,9 +422,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case pudt: PythonUserDefinedType => castToString(pudt.sqlType) case udt: UserDefinedType[_] => buildCast[Any](_, o => UTF8String.fromString(udt.deserialize(o).toString)) - case YearMonthIntervalType => + case YearMonthIntervalType(startField, endField) => buildCast[Int](_, i => UTF8String.fromString( - IntervalUtils.toYearMonthIntervalString(i, ANSI_STYLE))) + IntervalUtils.toYearMonthIntervalString(i, ANSI_STYLE, startField, endField))) case DayTimeIntervalType(startField, endField) => buildCast[Long](_, i => UTF8String.fromString( IntervalUtils.toDayTimeIntervalString(i, ANSI_STYLE, startField, endField))) @@ -566,8 +566,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit IntervalUtils.castStringToDTInterval(s, it.startField, it.endField)) } - private[this] def castToYearMonthInterval(from: DataType): Any => Any = from match { - case StringType => buildCast[UTF8String](_, s => 
IntervalUtils.castStringToYMInterval(s)) + private[this] def castToYearMonthInterval( + from: DataType, + it: YearMonthIntervalType): Any => Any = from match { + case StringType => buildCast[UTF8String](_, s => + IntervalUtils.castStringToYMInterval(s, it.startField, it.endField)) } // LongConverter @@ -876,7 +879,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case TimestampWithoutTZType => castToTimestampWithoutTZ(from) case CalendarIntervalType => castToInterval(from) case it: DayTimeIntervalType => castToDayTimeInterval(from, it) - case YearMonthIntervalType => castToYearMonthInterval(from) + case it: YearMonthIntervalType => castToYearMonthInterval(from, it) case BooleanType => castToBoolean(from) case ByteType => castToByte(from) case ShortType => castToShort(from) @@ -937,7 +940,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case TimestampWithoutTZType => castToTimestampWithoutTZCode(from, ctx) case CalendarIntervalType => castToIntervalCode(from) case it: DayTimeIntervalType => castToDayTimeIntervalCode(from, it) - case YearMonthIntervalType => castToYearMonthIntervalCode(from) + case it: YearMonthIntervalType => castToYearMonthIntervalCode(from, it) case BooleanType => castToBooleanCode(from) case ByteType => castToByteCode(from, ctx) case ShortType => castToShortCode(from, ctx) @@ -1176,15 +1179,16 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit (c, evPrim, evNull) => { code"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());" } - case YearMonthIntervalType => + case i: YearMonthIntervalType => val iu = IntervalUtils.getClass.getName.stripSuffix("$") val iss = IntervalStringStyles.getClass.getName.stripSuffix("$") val style = s"$iss$$.MODULE$$.ANSI_STYLE()" (c, evPrim, _) => code""" - $evPrim = UTF8String.fromString($iu.toYearMonthIntervalString($c, $style)); + $evPrim = UTF8String.fromString($iu.toYearMonthIntervalString($c, 
$style, + (byte)${i.startField}, (byte)${i.endField})); """ - case i : DayTimeIntervalType => + case i: DayTimeIntervalType => val iu = IntervalUtils.getClass.getName.stripSuffix("$") val iss = IntervalStringStyles.getClass.getName.stripSuffix("$") val style = s"$iss$$.MODULE$$.ANSI_STYLE()" @@ -1441,10 +1445,15 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit """ } - private[this] def castToYearMonthIntervalCode(from: DataType): CastFunction = from match { + private[this] def castToYearMonthIntervalCode( + from: DataType, + it: YearMonthIntervalType): CastFunction = from match { case StringType => val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") - (c, evPrim, _) => code"$evPrim = $util.castStringToYMInterval($c);" + (c, evPrim, _) => + code""" + $evPrim = $util.castStringToYMInterval($c, (byte)${it.startField}, (byte)${it.endField}); + """ } private[this] def decimalToTimestampCode(d: ExprValue): Block = { @@ -2012,7 +2021,7 @@ object AnsiCast { case (StringType, _: CalendarIntervalType) => true case (StringType, _: DayTimeIntervalType) => true - case (StringType, YearMonthIntervalType) => true + case (StringType, _: YearMonthIntervalType) => true case (StringType, DateType) => true case (TimestampType, DateType) => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 79dbce0..2058af1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -157,7 +157,7 @@ object InterpretedUnsafeProjection { case ShortType => (v, i) => writer.write(i, v.getShort(i)) - case IntegerType | DateType | YearMonthIntervalType => + case IntegerType | DateType | _: YearMonthIntervalType 
=> (v, i) => writer.write(i, v.getInt(i)) case LongType | TimestampType | TimestampWithoutTZType | _: DayTimeIntervalType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala index 891ac82..80438e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala @@ -194,7 +194,7 @@ final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGen private[this] def dataTypeToMutableValue(dataType: DataType): MutableValue = dataType match { // We use INT for DATE and YearMonthIntervalType internally - case IntegerType | DateType | YearMonthIntervalType => new MutableInt + case IntegerType | DateType | _: YearMonthIntervalType => new MutableInt // We use Long for Timestamp, Timestamp without time zone and DayTimeInterval internally case LongType | TimestampType | TimestampWithoutTZType | _: DayTimeIntervalType => new MutableLong diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index a64ca57..77a6cf7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -57,14 +57,14 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit private lazy val resultType = child.dataType match { case DecimalType.Fixed(p, s) => DecimalType.bounded(p + 4, s + 4) - case _: YearMonthIntervalType => YearMonthIntervalType + case _: YearMonthIntervalType => YearMonthIntervalType() case _: DayTimeIntervalType => DayTimeIntervalType() case _ => DoubleType 
} private lazy val sumDataType = child.dataType match { case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s) - case _: YearMonthIntervalType => YearMonthIntervalType + case _: YearMonthIntervalType => YearMonthIntervalType() case _: DayTimeIntervalType => DayTimeIntervalType() case _ => DoubleType } @@ -92,7 +92,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType) case _: YearMonthIntervalType => If(EqualTo(count, Literal(0L)), - Literal(null, YearMonthIntervalType), DivideYMInterval(sum, count)) + Literal(null, YearMonthIntervalType()), DivideYMInterval(sum, count)) case _: DayTimeIntervalType => If(EqualTo(count, Literal(0L)), Literal(null, DayTimeIntervalType()), DivideDTInterval(sum, count)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 09d88c4..80dda69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -59,7 +59,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 10, scale) case _: IntegralType => LongType - case _: YearMonthIntervalType => YearMonthIntervalType + case it: YearMonthIntervalType => it case it: DayTimeIntervalType => it case _ => DoubleType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 9527df9..a1fdbb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -85,7 +85,7 @@ case class UnaryMinus( val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") val method = if (failOnError) "negateExact" else "negate" defineCodeGen(ctx, ev, c => s"$iu.$method($c)") - case _: DayTimeIntervalType | YearMonthIntervalType => + case _: DayTimeIntervalType | _: YearMonthIntervalType => nullSafeCodeGen(ctx, ev, eval => { val mathClass = classOf[Math].getName s"${ev.value} = $mathClass.negateExact($eval);" @@ -97,7 +97,7 @@ case class UnaryMinus( IntervalUtils.negateExact(input.asInstanceOf[CalendarInterval]) case CalendarIntervalType => IntervalUtils.negate(input.asInstanceOf[CalendarInterval]) case _: DayTimeIntervalType => Math.negateExact(input.asInstanceOf[Long]) - case YearMonthIntervalType => Math.negateExact(input.asInstanceOf[Int]) + case _: YearMonthIntervalType => Math.negateExact(input.asInstanceOf[Int]) case _ => numeric.negate(input) } @@ -229,7 +229,7 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { case CalendarIntervalType => val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") defineCodeGen(ctx, ev, (eval1, eval2) => s"$iu.$calendarIntervalMethod($eval1, $eval2)") - case _: DayTimeIntervalType | YearMonthIntervalType => + case _: DayTimeIntervalType | _: YearMonthIntervalType => assert(exactMathMethod.isDefined, s"The expression '$nodeName' must override the exactMathMethod() method " + "if it is supposed to operate over interval types.") @@ -319,7 +319,7 @@ case class Add( input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) case _: DayTimeIntervalType => Math.addExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long]) - case YearMonthIntervalType => + case _: YearMonthIntervalType => Math.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int]) case _ => numeric.plus(input1, input2) } @@ -365,7 +365,7 @@ case class Subtract( 
input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) case _: DayTimeIntervalType => Math.subtractExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long]) - case YearMonthIntervalType => + case _: YearMonthIntervalType => Math.subtractExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int]) case _ => numeric.minus(input1, input2) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index db7a349..9831b13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1816,7 +1816,7 @@ object CodeGenerator extends Logging { case BooleanType => JAVA_BOOLEAN case ByteType => JAVA_BYTE case ShortType => JAVA_SHORT - case IntegerType | DateType | YearMonthIntervalType => JAVA_INT + case IntegerType | DateType | _: YearMonthIntervalType => JAVA_INT case LongType | TimestampType | TimestampWithoutTZType | _: DayTimeIntervalType => JAVA_LONG case FloatType => JAVA_FLOAT case DoubleType => JAVA_DOUBLE @@ -1837,7 +1837,7 @@ object CodeGenerator extends Logging { case BooleanType => java.lang.Boolean.TYPE case ByteType => java.lang.Byte.TYPE case ShortType => java.lang.Short.TYPE - case IntegerType | DateType | YearMonthIntervalType => java.lang.Integer.TYPE + case IntegerType | DateType | _: YearMonthIntervalType => java.lang.Integer.TYPE case LongType | TimestampType | TimestampWithoutTZType | _: DayTimeIntervalType => java.lang.Long.TYPE case FloatType => java.lang.Float.TYPE diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 72041e0..41e7de1 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2564,14 +2564,14 @@ case class Sequence( |1. The start and stop expressions must resolve to the same type. |2. If start and stop expressions resolve to the 'date' or 'timestamp' type |then the step expression must resolve to the 'interval' or - |'${YearMonthIntervalType.typeName}' or '${DayTimeIntervalType.simpleString}' type, + |'${YearMonthIntervalType.simpleString}' or '${DayTimeIntervalType.simpleString}' type, |otherwise to the same type as the start and stop expressions. """.stripMargin) } } private def isNotIntervalType(expr: Expression) = expr.dataType match { - case CalendarIntervalType | YearMonthIntervalType | _: DayTimeIntervalType => false + case CalendarIntervalType | _: YearMonthIntervalType | _: DayTimeIntervalType => false case _ => true } @@ -2749,10 +2749,10 @@ object Sequence { override val defaultStep: DefaultStep = new DefaultStep( (dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], - YearMonthIntervalType, + YearMonthIntervalType(), Period.of(0, 1, 0)) - val intervalType: DataType = YearMonthIntervalType + val intervalType: DataType = YearMonthIntervalType() def splitStep(input: Any): (Int, Int, Long) = { (input.asInstanceOf[Int], 0, 0) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index b43bc07..7a0b555 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -2396,7 +2396,7 @@ object DatePart { throw QueryCompilationErrors.literalTypeUnsupportedForSourceTypeError(fieldStr, source) source.dataType match { - case 
YearMonthIntervalType | _: DayTimeIntervalType | CalendarIntervalType => + case _: YearMonthIntervalType | _: DayTimeIntervalType | CalendarIntervalType => ExtractIntervalPart.parseExtractField(fieldStr, source, analysisException) case _ => DatePart.parseExtractField(fieldStr, source, analysisException) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 9b8b2b9..d730586 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -496,7 +496,7 @@ abstract class HashExpression[E] extends Expression { case d: DecimalType => genHashDecimal(ctx, d, input, result) case CalendarIntervalType => genHashCalendarInterval(input, result) case _: DayTimeIntervalType => genHashLong(input, result) - case YearMonthIntervalType => genHashInt(input, result) + case _: YearMonthIntervalType => genHashInt(input, result) case BinaryType => genHashBytes(input, result) case StringType => genHashString(input, result) case ArrayType(et, containsNull) => genHashForArray(ctx, input, result, et, containsNull) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index adc1f4c..92a8ce0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -125,11 +125,11 @@ object ExtractIntervalPart { source: Expression, errorHandleFunc: => Nothing): Expression = { (extractField.toUpperCase(Locale.ROOT), source.dataType) match { - case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", YearMonthIntervalType) => + case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", _: 
YearMonthIntervalType) => ExtractANSIIntervalYears(source) case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", CalendarIntervalType) => ExtractIntervalYears(source) - case ("MONTH" | "MON" | "MONS" | "MONTHS", YearMonthIntervalType) => + case ("MONTH" | "MON" | "MONS" | "MONTHS", _: YearMonthIntervalType) => ExtractANSIIntervalMonths(source) case ("MONTH" | "MON" | "MONS" | "MONTHS", CalendarIntervalType) => ExtractIntervalMonths(source) @@ -374,7 +374,7 @@ case class MakeYMInterval(years: Expression, months: Expression) override def left: Expression = years override def right: Expression = months override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, IntegerType) - override def dataType: DataType = YearMonthIntervalType + override def dataType: DataType = YearMonthIntervalType() override def nullSafeEval(year: Any, month: Any): Any = { Math.toIntExact(Math.addExact(month.asInstanceOf[Number].longValue(), @@ -407,7 +407,7 @@ case class MultiplyYMInterval( override def right: Expression = num override def inputTypes: Seq[AbstractDataType] = Seq(YearMonthIntervalType, NumericType) - override def dataType: DataType = YearMonthIntervalType + override def dataType: DataType = YearMonthIntervalType() @transient private lazy val evalFunc: (Int, Any) => Any = right.dataType match { @@ -517,7 +517,7 @@ case class DivideYMInterval( override def right: Expression = num override def inputTypes: Seq[AbstractDataType] = Seq(YearMonthIntervalType, NumericType) - override def dataType: DataType = YearMonthIntervalType + override def dataType: DataType = YearMonthIntervalType() @transient private lazy val evalFunc: (Int, Any) => Any = right.dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 92a6e83..d31634c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -84,7 +84,7 @@ object Literal { case ld: LocalDate => Literal(ld.toEpochDay.toInt, DateType) case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType) case d: Duration => Literal(durationToMicros(d), DayTimeIntervalType()) - case p: Period => Literal(periodToMonths(p), YearMonthIntervalType) + case p: Period => Literal(periodToMonths(p), YearMonthIntervalType()) case a: Array[Byte] => Literal(a, BinaryType) case a: collection.mutable.WrappedArray[_] => apply(a.array) case a: Array[_] => @@ -122,7 +122,7 @@ object Literal { case _ if clz == classOf[Timestamp] => TimestampType case _ if clz == classOf[LocalDateTime] => TimestampWithoutTZType case _ if clz == classOf[Duration] => DayTimeIntervalType() - case _ if clz == classOf[Period] => YearMonthIntervalType + case _ if clz == classOf[Period] => YearMonthIntervalType() case _ if clz == classOf[JavaBigDecimal] => DecimalType.SYSTEM_DEFAULT case _ if clz == classOf[Array[Byte]] => BinaryType case _ if clz == classOf[Array[Char]] => StringType @@ -181,7 +181,7 @@ object Literal { case TimestampType => create(0L, TimestampType) case TimestampWithoutTZType => create(0L, TimestampWithoutTZType) case it: DayTimeIntervalType => create(0L, it) - case YearMonthIntervalType => create(0, YearMonthIntervalType) + case it: YearMonthIntervalType => create(0, it) case StringType => Literal("") case BinaryType => Literal("".getBytes(StandardCharsets.UTF_8)) case CalendarIntervalType => Literal(new CalendarInterval(0, 0, 0)) @@ -200,7 +200,7 @@ object Literal { case BooleanType => v.isInstanceOf[Boolean] case ByteType => v.isInstanceOf[Byte] case ShortType => v.isInstanceOf[Short] - case IntegerType | DateType | YearMonthIntervalType => v.isInstanceOf[Int] + case IntegerType | DateType | _: YearMonthIntervalType => v.isInstanceOf[Int] case LongType | TimestampType | TimestampWithoutTZType | _: DayTimeIntervalType => v.isInstanceOf[Long] 
case FloatType => v.isInstanceOf[Float] @@ -348,7 +348,8 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { TimestampFormatter.getFractionFormatter(timeZoneId).format(value.asInstanceOf[Long]) case DayTimeIntervalType(startField, endField) => toDayTimeIntervalString(value.asInstanceOf[Long], ANSI_STYLE, startField, endField) - case YearMonthIntervalType => toYearMonthIntervalString(value.asInstanceOf[Int], ANSI_STYLE) + case YearMonthIntervalType(startField, endField) => + toYearMonthIntervalString(value.asInstanceOf[Int], ANSI_STYLE, startField, endField) case _ => other.toString } @@ -401,7 +402,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { ExprCode.forNonNullValue(JavaCode.literal(code, dataType)) } dataType match { - case BooleanType | IntegerType | DateType | YearMonthIntervalType => + case BooleanType | IntegerType | DateType | _: YearMonthIntervalType => toExprCode(value.toString) case FloatType => value.asInstanceOf[Float] match { @@ -471,7 +472,8 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { case (v: Array[Byte], BinaryType) => s"X'${DatatypeConverter.printHexBinary(v)}'" case (i: Long, DayTimeIntervalType(startField, endField)) => toDayTimeIntervalString(i, ANSI_STYLE, startField, endField) - case (i: Int, YearMonthIntervalType) => toYearMonthIntervalString(i, ANSI_STYLE) + case (i: Int, YearMonthIntervalType(startField, endField)) => + toYearMonthIntervalString(i, ANSI_STYLE, startField, endField) case _ => value.toString } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 2960ec2..2555c6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ 
-101,9 +101,9 @@ case class WindowSpecDefinition( private def isValidFrameType(ft: DataType): Boolean = (orderSpec.head.dataType, ft) match { case (DateType, IntegerType) => true - case (DateType, YearMonthIntervalType) => true + case (DateType, _: YearMonthIntervalType) => true case (TimestampType, CalendarIntervalType) => true - case (TimestampType, YearMonthIntervalType) => true + case (TimestampType, _: YearMonthIntervalType) => true case (TimestampType, _: DayTimeIntervalType) => true case (a, b) => a == b } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 165cf13..4ff7691 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2354,7 +2354,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg val toUnit = ctx.errorCapturingUnitToUnitInterval.body.to.getText.toLowerCase(Locale.ROOT) if (toUnit == "month") { assert(calendarInterval.days == 0 && calendarInterval.microseconds == 0) - Literal(calendarInterval.months, YearMonthIntervalType) + // TODO(SPARK-35773): Parse year-month interval literals to tightest types + Literal(calendarInterval.months, YearMonthIntervalType()) } else { assert(calendarInterval.months == 0) val fromUnit = @@ -2513,7 +2514,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } override def visitYearMonthIntervalDataType(ctx: YearMonthIntervalDataTypeContext): DataType = { - YearMonthIntervalType + // TODO(SPARK-35774): Parse any year-month interval types in SQL + YearMonthIntervalType() } override def visitDayTimeIntervalDataType(ctx: DayTimeIntervalDataTypeContext): DataType = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index dda5581..aca3d15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -106,7 +106,11 @@ object IntervalUtils { private val yearMonthLiteralRegex = (s"(?i)^INTERVAL\\s+([+|-])?'$yearMonthPatternString'\\s+YEAR\\s+TO\\s+MONTH$$").r - def castStringToYMInterval(input: UTF8String): Int = { + def castStringToYMInterval( + input: UTF8String, + // TODO(SPARK-35768): Take into account year-month interval fields in cast + startField: Byte, + endField: Byte): Int = { input.trimAll().toString match { case yearMonthRegex("-", year, month) => toYMInterval(year, month, -1) case yearMonthRegex(_, year, month) => toYMInterval(year, month, 1) @@ -934,9 +938,16 @@ object IntervalUtils { * * @param months The number of months, positive or negative * @param style The style of textual representation of the interval + * @param startField The start field (YEAR or MONTH) which the interval comprises of. + * @param endField The end field (YEAR or MONTH) which the interval comprises of. * @return Year-month interval string */ - def toYearMonthIntervalString(months: Int, style: IntervalStyle): String = { + def toYearMonthIntervalString( + months: Int, + style: IntervalStyle, + // TODO(SPARK-35771): Format year-month intervals using type fields + startField: Byte, + endField: Byte): String = { var sign = "" var absMonths: Long = months if (months < 0) { @@ -956,6 +967,8 @@ object IntervalUtils { * * @param micros The number of microseconds, positive or negative * @param style The style of textual representation of the interval + * @param startField The start field (DAY, HOUR, MINUTE, SECOND) which the interval comprises of. + * @param endField The end field (DAY, HOUR, MINUTE, SECOND) which the interval comprises of. 
* @return Day-time interval string */ def toDayTimeIntervalString( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index ea1e227..015dca8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -63,7 +63,7 @@ object TypeUtils { def checkForAnsiIntervalOrNumericType( dt: DataType, funcName: String): TypeCheckResult = dt match { - case YearMonthIntervalType | _: DayTimeIntervalType | NullType => + case _: YearMonthIntervalType | _: DayTimeIntervalType | NullType => TypeCheckResult.TypeCheckSuccess case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess case other => TypeCheckResult.TypeCheckFailure( @@ -117,7 +117,7 @@ object TypeUtils { def invokeOnceForInterval(dataType: DataType)(f: => Unit): Unit = { def isInterval(dataType: DataType): Boolean = dataType match { - case CalendarIntervalType | _: DayTimeIntervalType | YearMonthIntervalType => true + case CalendarIntervalType | _: DayTimeIntervalType | _: YearMonthIntervalType => true case _ => false } if (dataType.existsRecursively(isInterval)) f diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index e2822f7..6559716 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -1636,4 +1636,15 @@ private[spark] object QueryCompilationErrors { def invalidDayTimeIntervalType(startFieldName: String, endFieldName: String): Throwable = { new AnalysisException(s"'interval $startFieldName to $endFieldName' is invalid.") } + + def invalidYearMonthField(field: Byte): Throwable = { + val 
supportedIds = YearMonthIntervalType.yearMonthFields + .map(i => s"$i (${YearMonthIntervalType.fieldToString(i)})") + new AnalysisException(s"Invalid field id '$field' in year-month interval. " + + s"Supported interval fields: ${supportedIds.mkString(", ")}.") + } + + def invalidYearMonthIntervalType(startFieldName: String, endFieldName: String): Throwable = { + new AnalysisException(s"'interval $startFieldName to $endFieldName' is invalid.") + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 11d33f0..35fbe5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -172,7 +172,9 @@ object DataType { Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType, // TODO(SPARK-35732): Parse DayTimeIntervalType from JSON - DayTimeIntervalType(), YearMonthIntervalType, TimestampWithoutTZType) + DayTimeIntervalType(), + // TODO(SPARK-35770): Parse YearMonthIntervalType from JSON + YearMonthIntervalType(), TimestampWithoutTZType) .map(t => t.typeName -> t).toMap } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala index 8ee4bef..e6e2643 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala @@ -21,6 +21,8 @@ import scala.math.Ordering import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.Unstable +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.types.YearMonthIntervalType.fieldToString /** * The type represents year-month intervals of 
the SQL standard. A year-month interval is made up @@ -30,12 +32,15 @@ import org.apache.spark.annotation.Unstable * * `YearMonthIntervalType` represents positive as well as negative year-month intervals. * - * Please use the singleton `DataTypes.YearMonthIntervalType` to refer the type. + * @param startField The leftmost field which the type comprises of. Valid values: + * 0 (YEAR), 1 (MONTH). + * @param endField The rightmost field which the type comprises of. Valid values: + * 0 (YEAR), 1 (MONTH). * * @since 3.2.0 */ @Unstable -class YearMonthIntervalType private() extends AtomicType { +case class YearMonthIntervalType(startField: Byte, endField: Byte) extends AtomicType { /** * Internally, values of year-month intervals are stored in `Int` values as amount of months * that are calculated by the formula: @@ -55,16 +60,47 @@ class YearMonthIntervalType private() extends AtomicType { private[spark] override def asNullable: YearMonthIntervalType = this - override def typeName: String = "interval year to month" + override val typeName: String = { + val startFieldName = fieldToString(startField) + val endFieldName = fieldToString(endField) + if (startFieldName == endFieldName) { + s"interval $startFieldName" + } else if (startField < endField) { + s"interval $startFieldName to $endFieldName" + } else { + throw QueryCompilationErrors.invalidYearMonthIntervalType(startFieldName, endFieldName) + } + } } /** - * The companion case object and its class is separated so the companion object also subclasses - * the YearMonthIntervalType class. Otherwise, the companion object would be of type - * "YearMonthIntervalType$" in byte code. Defined with a private constructor so the companion object - * is the only possible instantiation. + * Extra factory methods and pattern matchers for YearMonthIntervalType.
* * @since 3.2.0 */ @Unstable -case object YearMonthIntervalType extends YearMonthIntervalType +case object YearMonthIntervalType extends AbstractDataType { + val YEAR: Byte = 0 + val MONTH: Byte = 1 + val yearMonthFields = Seq(YEAR, MONTH) + + def fieldToString(field: Byte): String = field match { + case YEAR => "year" + case MONTH => "month" + case invalid => throw QueryCompilationErrors.invalidYearMonthField(invalid) + } + + val stringToField: Map[String, Byte] = yearMonthFields.map(i => fieldToString(i) -> i).toMap + + val DEFAULT = YearMonthIntervalType(YEAR, MONTH) + + def apply(): YearMonthIntervalType = DEFAULT + + override private[sql] def defaultConcreteType: DataType = DEFAULT + + override private[sql] def acceptsType(other: DataType): Boolean = { + other.isInstanceOf[YearMonthIntervalType] + } + + override private[sql] def simpleString: String = defaultConcreteType.simpleString +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 67d9dfd..48a5491 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -54,7 +54,7 @@ private[sql] object ArrowUtils { new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) } case NullType => ArrowType.Null.INSTANCE - case YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH) + case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH) case _: DayTimeIntervalType => new ArrowType.Interval(IntervalUnit.DAY_TIME) case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dt.catalogString}") @@ -76,7 +76,7 @@ private[sql] object ArrowUtils { case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType case ArrowType.Null.INSTANCE => NullType - case yi: 
ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType + case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType() case di: ArrowType.Interval if di.getUnit == IntervalUnit.DAY_TIME => DayTimeIntervalType() case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 3b4978e..603c88d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -284,7 +284,7 @@ object RandomDataGenerator { new CalendarInterval(months, days, ns) }) case _: DayTimeIntervalType => Some(() => Duration.of(rand.nextLong(), ChronoUnit.MICROS)) - case YearMonthIntervalType => Some(() => Period.ofMonths(rand.nextInt()).normalized()) + case _: YearMonthIntervalType => Some(() => Period.ofMonths(rand.nextInt()).normalized()) case DecimalType.Fixed(precision, scale) => Some( () => BigDecimal.apply( rand.nextLong() % math.pow(10, precision).toLong, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala index cd8fae5..e217730 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.DataTypeTestUtils.dayTimeIntervalTypes +import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearMonthIntervalTypes} /** * Tests of 
[[RandomDataGenerator]]. @@ -146,7 +146,7 @@ class RandomDataGeneratorSuite extends SparkFunSuite with SQLHelper { } test("SPARK-35116: The generated data fits the precision of DayTimeIntervalType in spark") { - (dayTimeIntervalTypes :+ YearMonthIntervalType).foreach { dt => + (dayTimeIntervalTypes ++ yearMonthIntervalTypes).foreach { dt => for (seed <- 1 to 1000) { val generator = RandomDataGenerator.forType(dt, false, new Random(seed)).get val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dt) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index 3a90e94..1c2359c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -320,7 +320,7 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper { val months = sign * input val period = IntervalUtils.monthsToPeriod(months) assert( - CatalystTypeConverters.createToScalaConverter(YearMonthIntervalType)(months) === period) + CatalystTypeConverters.createToScalaConverter(YearMonthIntervalType())(months) === period) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 316845d..ad1c11e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, IntervalUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import 
org.apache.spark.sql.types.DataTypeTestUtils.dayTimeIntervalTypes +import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearMonthIntervalTypes} @SQLUserDefinedType(udt = classOf[ExamplePointUDT]) class ExamplePoint(val x: Double, val y: Double) extends Serializable { @@ -366,13 +366,15 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { } test("SPARK-34615: encoding/decoding YearMonthIntervalType to/from java.time.Period") { - val schema = new StructType().add("p", YearMonthIntervalType) - val encoder = RowEncoder(schema).resolveAndBind() - val period = java.time.Period.ofMonths(1) - val row = toRow(encoder, Row(period)) - assert(row.getInt(0) === IntervalUtils.periodToMonths(period)) - val readback = fromRow(encoder, row) - assert(readback.get(0).equals(period)) + yearMonthIntervalTypes.foreach { yearMonthIntervalType => + val schema = new StructType().add("p", yearMonthIntervalType) + val encoder = RowEncoder(schema).resolveAndBind() + val period = java.time.Period.ofMonths(1) + val row = toRow(encoder, Row(period)) + assert(row.getInt(0) === IntervalUtils.periodToMonths(period)) + val readback = fromRow(encoder, row) + assert(readback.get(0).equals(period)) + } } for { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index ec8ae7f..d1bb3e1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -614,20 +614,20 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Seq(true, false).foreach { failOnError => checkExceptionInExpression[ArithmeticException]( UnaryMinus( - Literal.create(Period.ofMonths(Int.MinValue), YearMonthIntervalType), + 
Literal.create(Period.ofMonths(Int.MinValue), YearMonthIntervalType()), failOnError), "overflow") checkExceptionInExpression[ArithmeticException]( Subtract( - Literal.create(Period.ofMonths(Int.MinValue), YearMonthIntervalType), - Literal.create(Period.ofMonths(10), YearMonthIntervalType), + Literal.create(Period.ofMonths(Int.MinValue), YearMonthIntervalType()), + Literal.create(Period.ofMonths(10), YearMonthIntervalType()), failOnError ), "overflow") checkExceptionInExpression[ArithmeticException]( Add( - Literal.create(Period.ofMonths(Int.MaxValue), YearMonthIntervalType), - Literal.create(Period.ofMonths(10), YearMonthIntervalType), + Literal.create(Period.ofMonths(Int.MaxValue), YearMonthIntervalType()), + Literal.create(Period.ofMonths(10), YearMonthIntervalType()), failOnError ), "overflow") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1e5ff53..c121741 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -612,43 +612,43 @@ class CastSuite extends CastSuiteBase { test("SPARK-35111: Cast string to year-month interval") { checkEvaluation(cast(Literal.create("INTERVAL '1-0' YEAR TO MONTH"), - YearMonthIntervalType), 12) + YearMonthIntervalType()), 12) checkEvaluation(cast(Literal.create("INTERVAL '-1-0' YEAR TO MONTH"), - YearMonthIntervalType), -12) + YearMonthIntervalType()), -12) checkEvaluation(cast(Literal.create("INTERVAL -'-1-0' YEAR TO MONTH"), - YearMonthIntervalType), 12) + YearMonthIntervalType()), 12) checkEvaluation(cast(Literal.create("INTERVAL +'-1-0' YEAR TO MONTH"), - YearMonthIntervalType), -12) + YearMonthIntervalType()), -12) checkEvaluation(cast(Literal.create("INTERVAL +'+1-0' YEAR TO MONTH"), - YearMonthIntervalType), 12) + YearMonthIntervalType()), 12) 
checkEvaluation(cast(Literal.create("INTERVAL +'1-0' YEAR TO MONTH"), - YearMonthIntervalType), 12) + YearMonthIntervalType()), 12) checkEvaluation(cast(Literal.create(" interval +'1-0' YEAR TO MONTH "), - YearMonthIntervalType), 12) - checkEvaluation(cast(Literal.create(" -1-0 "), YearMonthIntervalType), -12) - checkEvaluation(cast(Literal.create("-1-0"), YearMonthIntervalType), -12) - checkEvaluation(cast(Literal.create(null, StringType), YearMonthIntervalType), null) + YearMonthIntervalType()), 12) + checkEvaluation(cast(Literal.create(" -1-0 "), YearMonthIntervalType()), -12) + checkEvaluation(cast(Literal.create("-1-0"), YearMonthIntervalType()), -12) + checkEvaluation(cast(Literal.create(null, StringType), YearMonthIntervalType()), null) Seq("0-0", "10-1", "-178956970-7", "178956970-7", "-178956970-8").foreach { interval => val ansiInterval = s"INTERVAL '$interval' YEAR TO MONTH" checkEvaluation( - cast(cast(Literal.create(interval), YearMonthIntervalType), StringType), ansiInterval) + cast(cast(Literal.create(interval), YearMonthIntervalType()), StringType), ansiInterval) checkEvaluation(cast(cast(Literal.create(ansiInterval), - YearMonthIntervalType), StringType), ansiInterval) + YearMonthIntervalType()), StringType), ansiInterval) } Seq("INTERVAL '-178956970-9' YEAR TO MONTH", "INTERVAL '178956970-8' YEAR TO MONTH") .foreach { interval => val e = intercept[IllegalArgumentException] { - cast(Literal.create(interval), YearMonthIntervalType).eval() + cast(Literal.create(interval), YearMonthIntervalType()).eval() }.getMessage assert(e.contains("Error parsing interval year-month string: integer overflow")) } Seq(Byte.MaxValue, Short.MaxValue, Int.MaxValue, Int.MinValue + 1, Int.MinValue) .foreach { period => - val interval = Literal.create(Period.ofMonths(period), YearMonthIntervalType) - checkEvaluation(cast(cast(interval, StringType), YearMonthIntervalType), period) + val interval = Literal.create(Period.ofMonths(period), YearMonthIntervalType()) + 
checkEvaluation(cast(cast(interval, StringType), YearMonthIntervalType()), period) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala index 46522e6..fdea6d7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.catalyst.util.IntervalUtils.microsToDuration import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.DataTypeTestUtils.dayTimeIntervalTypes +import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearMonthIntervalTypes} import org.apache.spark.unsafe.types.UTF8String abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { @@ -831,8 +831,10 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { s"INTERVAL '$intervalPayload' YEAR TO MONTH") } - checkConsistencyBetweenInterpretedAndCodegen( - (child: Expression) => Cast(child, StringType), YearMonthIntervalType) + yearMonthIntervalTypes.foreach { it => + checkConsistencyBetweenInterpretedAndCodegen( + (child: Expression) => Cast(child, StringType), it) + } } test("SPARK-34668: cast day-time interval to string") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 77fdc3f..6e584a4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -36,7 +36,7 @@ import 
org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, TimeZoneUTC} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.DataTypeTestUtils.dayTimeIntervalTypes +import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearMonthIntervalTypes} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -526,7 +526,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { private def testAddMonths(dataType: DataType): Unit = { def addMonths(date: Literal, months: Any): AddMonthsBase = dataType match { case IntegerType => AddMonths(date, Literal.create(months, dataType)) - case YearMonthIntervalType => + case _: YearMonthIntervalType => val period = if (months == null) null else Period.ofMonths(months.asInstanceOf[Int]) DateAddYMInterval(date, Literal.create(period, dataType)) } @@ -561,7 +561,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("SPARK-34721: add a year-month interval to a date") { - testAddMonths(YearMonthIntervalType) + testAddMonths(YearMonthIntervalType()) // Test evaluation results between Interpreted mode and Codegen mode forAll ( LiteralGenerator.randomGen(DateType), @@ -1596,18 +1596,21 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( TimestampAddYMInterval( Literal(new Timestamp(sdf.parse("2016-01-29 10:00:00.000").getTime)), - Literal.create(null, YearMonthIntervalType), + Literal.create(null, YearMonthIntervalType()), timeZoneId), null) checkEvaluation( TimestampAddYMInterval( Literal.create(null, TimestampType), - Literal.create(null, YearMonthIntervalType), + Literal.create(null, YearMonthIntervalType()), timeZoneId), null) - checkConsistencyBetweenInterpretedAndCodegen( - (ts: Expression, interval: 
Expression) => TimestampAddYMInterval(ts, interval, timeZoneId), - TimestampType, YearMonthIntervalType) + yearMonthIntervalTypes.foreach { it => + checkConsistencyBetweenInterpretedAndCodegen( + (ts: Expression, interval: Expression) => + TimestampAddYMInterval(ts, interval, timeZoneId), + TimestampType, it) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index 585f03e..aa010ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -699,7 +699,7 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-35113: HashExpression support DayTimeIntervalType/YearMonthIntervalType") { val dayTime = Literal.create(Duration.ofSeconds(1237123123), DayTimeIntervalType()) - val yearMonth = Literal.create(Period.ofMonths(1234), YearMonthIntervalType) + val yearMonth = Literal.create(Period.ofMonths(1234), YearMonthIntervalType()) checkEvaluation(Murmur3Hash(Seq(dayTime), 10), -428664612) checkEvaluation(Murmur3Hash(Seq(yearMonth), 10), -686520021) checkEvaluation(XxHash64(Seq(dayTime), 10), 8228802290839366895L) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala index afc6223..3ae952f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.IntervalUtils.{safeStringToInterval, 
stringToInterval} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DayTimeIntervalType, Decimal, DecimalType, YearMonthIntervalType} -import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, numericTypes} +import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, numericTypes, yearMonthIntervalTypes} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -280,6 +280,7 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + // TODO(SPARK-35778): Check multiply/divide of year-month intervals of any fields by numeric test("SPARK-34824: multiply year-month interval by numeric") { Seq( (Period.ofYears(-123), Literal(null, DecimalType.USER_DEFAULT)) -> null, @@ -307,9 +308,11 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } numericTypes.foreach { numType => - checkConsistencyBetweenInterpretedAndCodegenAllowingException( - (interval: Expression, num: Expression) => MultiplyYMInterval(interval, num), - YearMonthIntervalType, numType) + yearMonthIntervalTypes.foreach { it => + checkConsistencyBetweenInterpretedAndCodegenAllowingException( + (interval: Expression, num: Expression) => MultiplyYMInterval(interval, num), + it, numType) + } } } @@ -349,6 +352,7 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + // TODO(SPARK-35778): Check multiply/divide of year-month intervals of any fields by numeric test("SPARK-34868: divide year-month interval by numeric") { Seq( (Period.ofYears(-123), Literal(null, DecimalType.USER_DEFAULT)) -> null, @@ -376,9 +380,11 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } numericTypes.foreach { numType => - checkConsistencyBetweenInterpretedAndCodegenAllowingException( - (interval: Expression, num: Expression) => DivideYMInterval(interval, 
num), - YearMonthIntervalType, numType) + yearMonthIntervalTypes.foreach { it => + checkConsistencyBetweenInterpretedAndCodegenAllowingException( + (interval: Expression, num: Expression) => DivideYMInterval(interval, num), + it, numType) + } } } @@ -429,8 +435,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ExtractANSIIntervalMonths(Literal(p)), IntervalUtils.getMonths(p.toTotalMonths.toInt)) } - checkEvaluation(ExtractANSIIntervalYears(Literal(null, YearMonthIntervalType)), null) - checkEvaluation(ExtractANSIIntervalMonths(Literal(null, YearMonthIntervalType)), null) + checkEvaluation(ExtractANSIIntervalYears(Literal(null, YearMonthIntervalType())), null) + checkEvaluation(ExtractANSIIntervalMonths(Literal(null, YearMonthIntervalType())), null) } test("ANSI: extract days, hours, minutes and seconds") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index 7baffbf..6410651 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -50,7 +50,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, DateType), null) checkEvaluation(Literal.create(null, TimestampType), null) checkEvaluation(Literal.create(null, CalendarIntervalType), null) - checkEvaluation(Literal.create(null, YearMonthIntervalType), null) + checkEvaluation(Literal.create(null, YearMonthIntervalType()), null) checkEvaluation(Literal.create(null, DayTimeIntervalType()), null) checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null) checkEvaluation(Literal.create(null, ArrayType(StringType, true)), null) @@ -79,7 +79,7 @@ class LiteralExpressionSuite 
extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.default(TimestampType), Instant.ofEpochSecond(0)) } checkEvaluation(Literal.default(CalendarIntervalType), new CalendarInterval(0, 0, 0L)) - checkEvaluation(Literal.default(YearMonthIntervalType), 0) + checkEvaluation(Literal.default(YearMonthIntervalType()), 0) checkEvaluation(Literal.default(DayTimeIntervalType()), 0L) checkEvaluation(Literal.default(ArrayType(StringType)), Array()) checkEvaluation(Literal.default(MapType(IntegerType, StringType)), Map()) @@ -345,7 +345,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { assert(Literal(Array("1", "2", "3")) == Literal.create(Array("1", "2", "3"), ArrayType(StringType))) assert(Literal(Array(Period.ofMonths(1))) == - Literal.create(Array(Period.ofMonths(1)), ArrayType(YearMonthIntervalType))) + Literal.create(Array(Period.ofMonths(1)), ArrayType(YearMonthIntervalType()))) } test("SPARK-34342: Date/Timestamp toString") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala index ac04897..b6f0c2b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala @@ -184,7 +184,7 @@ object LiteralGenerator { lazy val yearMonthIntervalLiteralGen: Gen[Literal] = { for { months <- Gen.choose(-1 * maxIntervalInMonths, maxIntervalInMonths) } - yield Literal.create(Period.ofMonths(months), YearMonthIntervalType) + yield Literal.create(Period.ofMonths(months), YearMonthIntervalType()) } def randomGen(dt: DataType): Gen[Literal] = { @@ -204,7 +204,7 @@ object LiteralGenerator { case CalendarIntervalType => calendarIntervalLiterGen case DecimalType.Fixed(precision, scale) => decimalLiteralGen(precision, scale) case _: DayTimeIntervalType => 
dayTimeIntervalLiteralGen - case YearMonthIntervalType => yearMonthIntervalLiteralGen + case _: YearMonthIntervalType => yearMonthIntervalLiteralGen case dt => throw new IllegalArgumentException(s"not supported type $dt") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala index 8804000..0f01bfb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.DataTypeTestUtils.dayTimeIntervalTypes +import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearMonthIntervalTypes} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.UTF8String @@ -31,7 +31,7 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { val fixedLengthTypes = Array[DataType]( BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, - DateType, TimestampType, YearMonthIntervalType) ++ dayTimeIntervalTypes + DateType, TimestampType) ++ dayTimeIntervalTypes ++ yearMonthIntervalTypes val variableLengthTypes = Array( StringType, DecimalType.defaultConcreteType, CalendarIntervalType, BinaryType, @@ -44,16 +44,18 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { testBothCodegenAndInterpreted("fixed-length types") { val inputRow = InternalRow.fromSeq(Seq( - true, 3.toByte, 15.toShort, -83, 129L, 1.0f, 5.0, 1, 2L, Int.MaxValue) ++ - Seq.tabulate(dayTimeIntervalTypes.length)(_ => Long.MaxValue)) + true, 
3.toByte, 15.toShort, -83, 129L, 1.0f, 5.0, 1, 2L) ++ + Seq.tabulate(dayTimeIntervalTypes.length)(_ => Long.MaxValue) ++ + Seq.tabulate(yearMonthIntervalTypes.length)(_ => Int.MaxValue)) val proj = createMutableProjection(fixedLengthTypes) assert(proj(inputRow) === inputRow) } testBothCodegenAndInterpreted("unsafe buffer") { val inputRow = InternalRow.fromSeq(Seq( - false, 1.toByte, 9.toShort, -18, 53L, 3.2f, 7.8, 4, 9L, Int.MinValue) ++ - Seq.tabulate(dayTimeIntervalTypes.length)(_ => Long.MaxValue)) + false, 1.toByte, 9.toShort, -18, 53L, 3.2f, 7.8, 4, 9L) ++ + Seq.tabulate(dayTimeIntervalTypes.length)(_ => Long.MaxValue) ++ + Seq.tabulate(yearMonthIntervalTypes.length)(_ => Int.MaxValue)) val numFields = fixedLengthTypes.length val numBytes = Platform.BYTE_ARRAY_OFFSET + UnsafeRow.calculateBitSetWidthInBytes(numFields) + UnsafeRow.WORD_SIZE * numFields diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index f669b91..2f6cff3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -275,8 +275,8 @@ class PushFoldableIntoBranchesSuite Literal(new CalendarInterval(1, 2, 0)))), If(a, Literal(Date.valueOf("2021-02-02")), Literal(Date.valueOf("2021-02-03")))) assertEquivalent(DateAddYMInterval(Literal(d), - If(a, Literal.create(Period.ofMonths(1), YearMonthIntervalType), - Literal.create(Period.ofMonths(2), YearMonthIntervalType))), + If(a, Literal.create(Period.ofMonths(1), YearMonthIntervalType()), + Literal.create(Period.ofMonths(2), YearMonthIntervalType()))), If(a, Literal(Date.valueOf("2021-02-01")), Literal(Date.valueOf("2021-03-01")))) assertEquivalent(DateDiff(Literal(d), If(a, Literal(Date.valueOf("2021-02-01")), 
Literal(Date.valueOf("2021-03-01")))), @@ -286,8 +286,8 @@ class PushFoldableIntoBranchesSuite If(a, Literal(Date.valueOf("2020-12-31")), Literal(Date.valueOf("2020-12-30")))) assertEquivalent(TimestampAddYMInterval( Literal.create(Timestamp.valueOf("2021-01-01 00:00:00.000"), TimestampType), - If(a, Literal.create(Period.ofMonths(1), YearMonthIntervalType), - Literal.create(Period.ofMonths(2), YearMonthIntervalType))), + If(a, Literal.create(Period.ofMonths(1), YearMonthIntervalType()), + Literal.create(Period.ofMonths(2), YearMonthIntervalType()))), If(a, Literal.create(Timestamp.valueOf("2021-02-01 00:00:00"), TimestampType), Literal.create(Timestamp.valueOf("2021-03-01 00:00:00"), TimestampType))) assertEquivalent(TimeAdd( @@ -312,8 +312,8 @@ class PushFoldableIntoBranchesSuite CaseWhen(Seq((a, Literal(Date.valueOf("2021-02-02"))), (c, Literal(Date.valueOf("2021-02-03")))), None)) assertEquivalent(DateAddYMInterval(Literal(d), - CaseWhen(Seq((a, Literal.create(Period.ofMonths(1), YearMonthIntervalType)), - (c, Literal.create(Period.ofMonths(2), YearMonthIntervalType))), None)), + CaseWhen(Seq((a, Literal.create(Period.ofMonths(1), YearMonthIntervalType())), + (c, Literal.create(Period.ofMonths(2), YearMonthIntervalType()))), None)), CaseWhen(Seq((a, Literal(Date.valueOf("2021-02-01"))), (c, Literal(Date.valueOf("2021-03-01")))), None)) assertEquivalent(DateDiff(Literal(d), @@ -326,8 +326,8 @@ class PushFoldableIntoBranchesSuite (c, Literal(Date.valueOf("2020-12-30")))), None)) assertEquivalent(TimestampAddYMInterval( Literal.create(Timestamp.valueOf("2021-01-01 00:00:00.000"), TimestampType), - CaseWhen(Seq((a, Literal.create(Period.ofMonths(1), YearMonthIntervalType)), - (c, Literal.create(Period.ofMonths(2), YearMonthIntervalType))), None)), + CaseWhen(Seq((a, Literal.create(Period.ofMonths(1), YearMonthIntervalType())), + (c, Literal.create(Period.ofMonths(2), YearMonthIntervalType()))), None)), CaseWhen(Seq((a, Literal.create(Timestamp.valueOf("2021-02-01 
00:00:00"), TimestampType)), (c, Literal.create(Timestamp.valueOf("2021-03-01 00:00:00"), TimestampType))), None)) assertEquivalent(TimeAdd( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala index 8e0a01d..bd8b4cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala @@ -63,7 +63,7 @@ class DataTypeParserSuite extends SparkFunSuite { checkDataType("BINARY", BinaryType) checkDataType("void", NullType) checkDataType("interval", CalendarIntervalType) - checkDataType("INTERVAL YEAR TO MONTH", YearMonthIntervalType) + checkDataType("INTERVAL YEAR TO MONTH", YearMonthIntervalType()) checkDataType("interval day to second", DayTimeIntervalType()) checkDataType("array<doublE>", ArrayType(DoubleType, true)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala index 0d496c9..c2ece95 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala @@ -503,6 +503,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { } test("SPARK-35016: format year-month intervals") { + import org.apache.spark.sql.types.YearMonthIntervalType._ Seq( 0 -> ("0-0", "INTERVAL '0-0' YEAR TO MONTH"), -11 -> ("-0-11", "INTERVAL '-0-11' YEAR TO MONTH"), @@ -514,8 +515,8 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { Int.MinValue -> ("-178956970-8", "INTERVAL '-178956970-8' YEAR TO MONTH"), Int.MaxValue -> ("178956970-7", "INTERVAL '178956970-7' YEAR TO MONTH") ).foreach { case (months, (hiveIntervalStr, 
ansiIntervalStr)) => - assert(toYearMonthIntervalString(months, ANSI_STYLE) === ansiIntervalStr) - assert(toYearMonthIntervalString(months, HIVE_STYLE) === hiveIntervalStr) + assert(toYearMonthIntervalString(months, ANSI_STYLE, YEAR, MONTH) === ansiIntervalStr) + assert(toYearMonthIntervalString(months, HIVE_STYLE, YEAR, MONTH) === hiveIntervalStr) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index d2f13f8..d543af4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -22,7 +22,7 @@ import com.fasterxml.jackson.core.JsonParseException import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat -import org.apache.spark.sql.types.DataTypeTestUtils.dayTimeIntervalTypes +import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearMonthIntervalTypes} class DataTypeSuite extends SparkFunSuite { @@ -256,7 +256,7 @@ class DataTypeSuite extends SparkFunSuite { checkDataTypeFromJson(VarcharType(10)) checkDataTypeFromDDL(VarcharType(11)) - checkDataTypeFromDDL(YearMonthIntervalType) + checkDataTypeFromDDL(YearMonthIntervalType()) dayTimeIntervalTypes.foreach(checkDataTypeFromDDL) val metadata = new MetadataBuilder() @@ -325,7 +325,7 @@ class DataTypeSuite extends SparkFunSuite { checkDefaultSize(CharType(100), 100) checkDefaultSize(VarcharType(5), 5) checkDefaultSize(VarcharType(10), 10) - checkDefaultSize(YearMonthIntervalType, 4) + yearMonthIntervalTypes.foreach(checkDefaultSize(_, 4)) dayTimeIntervalTypes.foreach(checkDefaultSize(_, 8)) def checkEqualsIgnoreCompatibleNullability( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala index 12f3783..d358ab6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.types import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND} +import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR} /** * Utility functions for working with DataTypes in tests. @@ -63,12 +64,16 @@ object DataTypeTestUtils { DayTimeIntervalType(MINUTE, SECOND), DayTimeIntervalType(SECOND, SECOND)) + val yearMonthIntervalTypes: Seq[YearMonthIntervalType] = Seq( + YearMonthIntervalType(YEAR, MONTH), + YearMonthIntervalType(YEAR, YEAR), + YearMonthIntervalType(MONTH, MONTH)) + /** * Instances of all [[NumericType]]s and [[CalendarIntervalType]] */ - val numericAndInterval: Set[DataType] = numericTypeWithoutDecimal ++ Set( - CalendarIntervalType, - YearMonthIntervalType) ++ dayTimeIntervalTypes + val numericAndInterval: Set[DataType] = numericTypeWithoutDecimal ++ + Set(CalendarIntervalType) ++ dayTimeIntervalTypes ++ yearMonthIntervalTypes /** * All the types that support ordering @@ -79,8 +84,7 @@ object DataTypeTestUtils { TimestampWithoutTZType, DateType, StringType, - BinaryType, - YearMonthIntervalType) ++ dayTimeIntervalTypes + BinaryType) ++ dayTimeIntervalTypes ++ yearMonthIntervalTypes /** * All the types that we can use in a property check @@ -96,9 +100,7 @@ object DataTypeTestUtils { DateType, StringType, TimestampType, - TimestampWithoutTZType, - YearMonthIntervalType - ) ++ dayTimeIntervalTypes + TimestampWithoutTZType) ++ dayTimeIntervalTypes ++ yearMonthIntervalTypes /** * Instances of [[ArrayType]] for all [[AtomicType]]s. Arrays of these types may contain null. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala index 7955d6b..642b387 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala @@ -48,7 +48,7 @@ class ArrowUtilsSuite extends SparkFunSuite { roundtrip(BinaryType) roundtrip(DecimalType.SYSTEM_DEFAULT) roundtrip(DateType) - roundtrip(YearMonthIntervalType) + roundtrip(YearMonthIntervalType()) roundtrip(DayTimeIntervalType()) val tsExMsg = intercept[UnsupportedOperationException] { roundtrip(TimestampType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index b49757f..e38b722 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -117,8 +117,8 @@ object HiveResult { struct.toSeq.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveString((v, t.dataType), true, formatters)}""" }.mkString("{", ",", "}") - case (period: Period, YearMonthIntervalType) => - toYearMonthIntervalString(periodToMonths(period), HIVE_STYLE) + case (period: Period, YearMonthIntervalType(startField, endField)) => + toYearMonthIntervalString(periodToMonths(period), HIVE_STYLE, startField, endField) case (duration: Duration, DayTimeIntervalType(startField, endField)) => toDayTimeIntervalString(durationToMicros(duration), HIVE_STYLE, startField, endField) case (other, _: UserDefinedType[_]) => other.toString diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 2487312..b3f5e34 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -158,7 +158,8 @@ abstract class HashMapGenerator( dataType match { case BooleanType => hashInt(s"$input ? 1 : 0") - case ByteType | ShortType | IntegerType | DateType | YearMonthIntervalType => hashInt(input) + case ByteType | ShortType | IntegerType | DateType | _: YearMonthIntervalType => + hashInt(input) case LongType | TimestampType | _: DayTimeIntervalType => hashLong(input) case FloatType => hashInt(s"Float.floatToIntBits($input)") case DoubleType => hashLong(s"Double.doubleToLongBits($input)") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 83caca2..8879e14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -87,7 +87,7 @@ sealed trait BufferSetterGetterUtils { (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.getLong(ordinal) - case YearMonthIntervalType => + case _: YearMonthIntervalType => (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.getInt(ordinal) @@ -195,7 +195,7 @@ sealed trait BufferSetterGetterUtils { row.setNullAt(ordinal) } - case YearMonthIntervalType => + case _: YearMonthIntervalType => (row: InternalRow, ordinal: Int, value: Any) => if (value != null) { row.setInt(ordinal, value.asInstanceOf[Int]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 6786c2c..0afacf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -75,7 +75,7 @@ object ArrowWriter { } new StructWriter(vector, children.toArray) case (NullType, vector: NullVector) => new NullWriter(vector) - case (YearMonthIntervalType, vector: IntervalYearVector) => new IntervalYearWriter(vector) + case (_: YearMonthIntervalType, vector: IntervalYearVector) => new IntervalYearWriter(vector) case (_: DayTimeIntervalType, vector: IntervalDayVector) => new IntervalDayWriter(vector) case (dt, _) => throw QueryExecutionErrors.unsupportedDataTypeError(dt) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 16c3191..2f68e89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -139,7 +139,7 @@ private[sql] object ColumnAccessor { case BooleanType => new BooleanColumnAccessor(buf) case ByteType => new ByteColumnAccessor(buf) case ShortType => new ShortColumnAccessor(buf) - case IntegerType | DateType | YearMonthIntervalType => new IntColumnAccessor(buf) + case IntegerType | DateType | _: YearMonthIntervalType => new IntColumnAccessor(buf) case LongType | TimestampType | _: DayTimeIntervalType => new LongColumnAccessor(buf) case FloatType => new FloatColumnAccessor(buf) case DoubleType => new DoubleColumnAccessor(buf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala index e2a9f90..e9251e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala @@ -174,7 +174,7 @@ private[columnar] object ColumnBuilder { case BooleanType 
=> new BooleanColumnBuilder case ByteType => new ByteColumnBuilder case ShortType => new ShortColumnBuilder - case IntegerType | DateType | YearMonthIntervalType => new IntColumnBuilder + case IntegerType | DateType | _: YearMonthIntervalType => new IntColumnBuilder case LongType | TimestampType | _: DayTimeIntervalType => new LongColumnBuilder case FloatType => new FloatColumnBuilder case DoubleType => new DoubleColumnBuilder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index cd6b74a..8e99368 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -257,7 +257,7 @@ private[columnar] object LONG extends NativeColumnType(LongType, 8) { } } -private[columnar] object YEAR_MONTH_INTERVAL extends NativeColumnType(YearMonthIntervalType, 4) { +private[columnar] object YEAR_MONTH_INTERVAL extends NativeColumnType(YearMonthIntervalType(), 4) { override def append(v: Int, buffer: ByteBuffer): Unit = { buffer.putInt(v) } @@ -817,7 +817,7 @@ private[columnar] object ColumnType { case BooleanType => BOOLEAN case ByteType => BYTE case ShortType => SHORT - case IntegerType | DateType | YearMonthIntervalType => INT + case IntegerType | DateType | _: YearMonthIntervalType => INT case LongType | TimestampType | _: DayTimeIntervalType => LONG case FloatType => FLOAT case DoubleType => DOUBLE diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index f3ac428..190c2c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -80,7 +80,7 @@ 
object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera case BooleanType => classOf[BooleanColumnAccessor].getName case ByteType => classOf[ByteColumnAccessor].getName case ShortType => classOf[ShortColumnAccessor].getName - case IntegerType | DateType | YearMonthIntervalType => classOf[IntColumnAccessor].getName + case IntegerType | DateType | _: YearMonthIntervalType => classOf[IntColumnAccessor].getName case LongType | TimestampType | _: DayTimeIntervalType => classOf[LongColumnAccessor].getName case FloatType => classOf[FloatColumnAccessor].getName diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala index f9b2c92..2aa0b02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala @@ -95,9 +95,9 @@ trait WindowExecBase extends UnaryExecNode { // Create the projection which returns the current 'value' modified by adding the offset. 
val boundExpr = (expr.dataType, boundOffset.dataType) match { case (DateType, IntegerType) => DateAdd(expr, boundOffset) - case (DateType, YearMonthIntervalType) => DateAddYMInterval(expr, boundOffset) + case (DateType, _: YearMonthIntervalType) => DateAddYMInterval(expr, boundOffset) case (TimestampType, CalendarIntervalType) => TimeAdd(expr, boundOffset, Some(timeZone)) - case (TimestampType, YearMonthIntervalType) => + case (TimestampType, _: YearMonthIntervalType) => TimestampAddYMInterval(expr, boundOffset, Some(timeZone)) case (TimestampType, _: DayTimeIntervalType) => TimeAdd(expr, boundOffset, Some(timeZone)) case (a, b) if a == b => Add(expr, boundOffset) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index ff8361a..2cfa298 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -1129,7 +1129,9 @@ class DataFrameAggregateSuite extends QueryTest val sumDF = df.select(sum($"year-month"), sum($"day-time")) checkAnswer(sumDF, Row(Period.of(2, 5, 0), Duration.ofDays(0))) assert(find(sumDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined) - assert(sumDF.schema == StructType(Seq(StructField("sum(year-month)", YearMonthIntervalType), + assert(sumDF.schema == StructType(Seq( + // TODO(SPARK-35775): Check all year-month interval types in aggregate expressions + StructField("sum(year-month)", YearMonthIntervalType()), // TODO(SPARK-35729): Check all day-time interval types in aggregate expressions StructField("sum(day-time)", DayTimeIntervalType())))) @@ -1139,7 +1141,8 @@ class DataFrameAggregateSuite extends QueryTest Row(3, Period.of(1, 6, 0), Duration.ofDays(-11)) :: Nil) assert(find(sumDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined) assert(sumDF2.schema == 
StructType(Seq(StructField("class", IntegerType, false), - StructField("sum(year-month)", YearMonthIntervalType), + // TODO(SPARK-35775): Check all year-month interval types in aggregate expressions + StructField("sum(year-month)", YearMonthIntervalType()), // TODO(SPARK-35729): Check all day-time interval types in aggregate expressions StructField("sum(day-time)", DayTimeIntervalType())))) @@ -1169,7 +1172,9 @@ class DataFrameAggregateSuite extends QueryTest val avgDF = df.select(avg($"year-month"), avg($"day-time")) checkAnswer(avgDF, Row(Period.ofMonths(7), Duration.ofDays(0))) assert(find(avgDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined) - assert(avgDF.schema == StructType(Seq(StructField("avg(year-month)", YearMonthIntervalType), + assert(avgDF.schema == StructType(Seq( + // TODO(SPARK-35775): Check all year-month interval types in aggregate expressions + StructField("avg(year-month)", YearMonthIntervalType()), // TODO(SPARK-35729): Check all day-time interval types in aggregate expressions StructField("avg(day-time)", DayTimeIntervalType())))) @@ -1179,7 +1184,8 @@ class DataFrameAggregateSuite extends QueryTest Row(3, Period.ofMonths(9), Duration.ofDays(-5).plusHours(-12)) :: Nil) assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined) assert(avgDF2.schema == StructType(Seq(StructField("class", IntegerType, false), - StructField("avg(year-month)", YearMonthIntervalType), + // TODO(SPARK-35775): Check all year-month interval types in aggregate expressions + StructField("avg(year-month)", YearMonthIntervalType()), // TODO(SPARK-35729): Check all day-time interval types in aggregate expressions StructField("avg(day-time)", DayTimeIntervalType())))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 6327744..65cbaf5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -908,18 +908,21 @@ class UDFSuite extends QueryTest with SharedSparkSession { val incMonth = udf((p: java.time.Period) => p.plusMonths(1)) val result = input.select(incMonth($"p").as("new_p")) checkAnswer(result, Row(java.time.Period.ofYears(1)) :: Nil) - assert(result.schema === new StructType().add("new_p", YearMonthIntervalType)) + // TODO(SPARK-35777): Check all year-month interval types in UDF + assert(result.schema === new StructType().add("new_p", YearMonthIntervalType())) // UDF produces `null` val nullFunc = udf((_: java.time.Period) => null.asInstanceOf[java.time.Period]) val nullResult = input.select(nullFunc($"p").as("null_p")) checkAnswer(nullResult, Row(null) :: Nil) - assert(nullResult.schema === new StructType().add("null_p", YearMonthIntervalType)) + // TODO(SPARK-35777): Check all year-month interval types in UDF + assert(nullResult.schema === new StructType().add("null_p", YearMonthIntervalType())) // Input parameter of UDF is null val nullInput = Seq(null.asInstanceOf[java.time.Period]).toDF("null_p") val constPeriod = udf((_: java.time.Period) => java.time.Period.ofYears(10)) val constResult = nullInput.select(constPeriod($"null_p").as("10_years")) checkAnswer(constResult, Row(java.time.Period.ofYears(10)) :: Nil) - assert(constResult.schema === new StructType().add("10_years", YearMonthIntervalType)) + // TODO(SPARK-35777): Check all year-month interval types in UDF + assert(constResult.schema === new StructType().add("10_years", YearMonthIntervalType())) // Error in the conversion of UDF result to the internal representation of year-month interval val overflowFunc = udf((p: java.time.Period) => p.plusYears(Long.MaxValue)) val e = intercept[SparkException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index 52fc021..c56a210 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -56,7 +56,7 @@ class ArrowWriterSuite extends SparkFunSuite { case BinaryType => reader.getBinary(rowId) case DateType => reader.getInt(rowId) case TimestampType => reader.getLong(rowId) - case YearMonthIntervalType => reader.getInt(rowId) + case _: YearMonthIntervalType => reader.getInt(rowId) case _: DayTimeIntervalType => reader.getLong(rowId) } assert(value === datum) @@ -77,7 +77,7 @@ class ArrowWriterSuite extends SparkFunSuite { check(DateType, Seq(0, 1, 2, null, 4)) check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), "America/Los_Angeles") check(NullType, Seq(null, null, null)) - check(YearMonthIntervalType, Seq(null, 0, 1, -1, Int.MaxValue, Int.MinValue)) + check(YearMonthIntervalType(), Seq(null, 0, 1, -1, Int.MaxValue, Int.MinValue)) check(DayTimeIntervalType(), Seq(null, 0L, 1000L, -1000L, (Long.MaxValue - 807L), (Long.MinValue + 808L))) } @@ -128,7 +128,7 @@ class ArrowWriterSuite extends SparkFunSuite { case DoubleType => reader.getDoubles(0, data.size) case DateType => reader.getInts(0, data.size) case TimestampType => reader.getLongs(0, data.size) - case YearMonthIntervalType => reader.getInts(0, data.size) + case _: YearMonthIntervalType => reader.getInts(0, data.size) case _: DayTimeIntervalType => reader.getLongs(0, data.size) } assert(values === data) @@ -144,7 +144,8 @@ class ArrowWriterSuite extends SparkFunSuite { check(DoubleType, (0 until 10).map(_.toDouble)) check(DateType, (0 until 10)) check(TimestampType, (0 until 10).map(_ * 4.32e10.toLong), "America/Los_Angeles") - check(YearMonthIntervalType, (0 until 10)) + // TODO(SPARK-35776): Check all year-month interval types in arrow + check(YearMonthIntervalType(), (0 until 10)) // TODO(SPARK-35731): Check all day-time interval types in arrow check(DayTimeIntervalType(), (-10 until 10).map(_ * 
1000.toLong)) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 6f0c32b..4747d20 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -121,7 +121,7 @@ private[hive] class SparkExecuteStatementOperation( false, timeFormatters) case _: ArrayType | _: StructType | _: MapType | _: UserDefinedType[_] | - YearMonthIntervalType | _: DayTimeIntervalType => + _: YearMonthIntervalType | _: DayTimeIntervalType => to += toHiveString((from.get(ordinal), dataTypes(ordinal)), false, timeFormatters) } } @@ -377,7 +377,7 @@ object SparkExecuteStatementOperation { val attrTypeString = field.dataType match { case NullType => "void" case CalendarIntervalType => StringType.catalogString - case YearMonthIntervalType => "interval_year_month" + case _: YearMonthIntervalType => "interval_year_month" case _: DayTimeIntervalType => "interval_day_time" case other => other.catalogString } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala index 2d8d103..d80ae09 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala @@ -131,7 +131,8 @@ private[hive] class SparkGetColumnsOperation( */ private def getColumnSize(typ: DataType): Option[Int] = typ match { case dt @ (BooleanType | _: NumericType | DateType | TimestampType | - CalendarIntervalType | NullType | 
YearMonthIntervalType | _: DayTimeIntervalType) => + CalendarIntervalType | NullType | + _: YearMonthIntervalType | _: DayTimeIntervalType) => Some(dt.defaultSize) case CharType(n) => Some(n) case StructType(fields) => @@ -186,7 +187,7 @@ private[hive] class SparkGetColumnsOperation( case _: MapType => java.sql.Types.JAVA_OBJECT case _: StructType => java.sql.Types.STRUCT // Hive's year-month and day-time intervals are mapping to java.sql.Types.OTHER - case _: CalendarIntervalType | YearMonthIntervalType | _: DayTimeIntervalType => + case _: CalendarIntervalType | _: YearMonthIntervalType | _: DayTimeIntervalType => java.sql.Types.OTHER case _ => throw new IllegalArgumentException(s"Unrecognized type name: ${typ.sql}") } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala index bac8757..9acae1b 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala @@ -392,8 +392,8 @@ class SparkMetadataOperationSuite extends HiveThriftServer2TestBase { assert(rowSet.getString("TABLE_NAME") === viewName1) assert(rowSet.getString("COLUMN_NAME") === "i") assert(rowSet.getInt("DATA_TYPE") === java.sql.Types.OTHER) - assert(rowSet.getString("TYPE_NAME").equalsIgnoreCase(YearMonthIntervalType.sql)) - assert(rowSet.getInt("COLUMN_SIZE") === YearMonthIntervalType.defaultSize) + assert(rowSet.getString("TYPE_NAME").equalsIgnoreCase(YearMonthIntervalType().sql)) + assert(rowSet.getInt("COLUMN_SIZE") === YearMonthIntervalType().defaultSize) assert(rowSet.getInt("DECIMAL_DIGITS") === 0) assert(rowSet.getInt("NUM_PREC_RADIX") === 0) assert(rowSet.getInt("NULLABLE") === 0) diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 3f83c2f..f49018b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -842,7 +842,7 @@ private[hive] trait HiveInspectors { case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector case _: DayTimeIntervalType => PrimitiveObjectInspectorFactory.javaHiveIntervalDayTimeObjectInspector - case YearMonthIntervalType => + case _: YearMonthIntervalType => PrimitiveObjectInspectorFactory.javaHiveIntervalYearMonthObjectInspector // TODO decimal precision? case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector @@ -891,7 +891,7 @@ private[hive] trait HiveInspectors { getPrimitiveNullWritableConstantObjectInspector case Literal(_, _: DayTimeIntervalType) => getHiveIntervalDayTimeWritableConstantObjectInspector - case Literal(_, YearMonthIntervalType) => + case Literal(_, _: YearMonthIntervalType) => getHiveIntervalYearMonthWritableConstantObjectInspector case Literal(value, ArrayType(dt, _)) => val listObjectInspector = toInspector(dt) @@ -971,8 +971,8 @@ private[hive] trait HiveInspectors { case _: JavaTimestampObjectInspector => TimestampType case _: WritableHiveIntervalDayTimeObjectInspector => DayTimeIntervalType() case _: JavaHiveIntervalDayTimeObjectInspector => DayTimeIntervalType() - case _: WritableHiveIntervalYearMonthObjectInspector => YearMonthIntervalType - case _: JavaHiveIntervalYearMonthObjectInspector => YearMonthIntervalType + case _: WritableHiveIntervalYearMonthObjectInspector => YearMonthIntervalType() + case _: JavaHiveIntervalYearMonthObjectInspector => YearMonthIntervalType() case _: WritableVoidObjectInspector => NullType case _: JavaVoidObjectInspector => NullType } @@ -1156,7 +1156,7 @@ private[hive] trait HiveInspectors { case 
TimestampType => timestampTypeInfo case NullType => voidTypeInfo case _: DayTimeIntervalType => intervalDayTimeTypeInfo - case YearMonthIntervalType => intervalYearMonthTypeInfo + case _: YearMonthIntervalType => intervalYearMonthTypeInfo case dt => throw new AnalysisException( s"${dt.catalogString} cannot be converted to Hive TypeInfo") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala index b396ddc..8cea781 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala @@ -546,7 +546,8 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T AttributeReference("a", DayTimeIntervalType())(), AttributeReference("b", DayTimeIntervalType())(), AttributeReference("c", DayTimeIntervalType())(), - AttributeReference("d", YearMonthIntervalType)()), + // TODO(SPARK-35772): Check all year-month interval types in HiveInspectors tests + AttributeReference("d", YearMonthIntervalType())()), child = child, ioschema = hiveIOSchema), df.select($"a", $"b", $"c", $"d").collect()) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org