This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e10bf64 [SPARK-34615][SQL] Support `java.time.Period` as an external
type of the year-month interval type
e10bf64 is described below
commit e10bf6476969c801808956ba7c5d79464bbebd1a
Author: Max Gekk <[email protected]>
AuthorDate: Mon Mar 8 08:33:09 2021 +0000
[SPARK-34615][SQL] Support `java.time.Period` as an external type of the
year-month interval type
### What changes were proposed in this pull request?
In the PR, I propose to extend Spark SQL API to accept
[`java.time.Period`](https://docs.oracle.com/javase/8/docs/api/java/time/Period.html)
as an external type of the recently added Catalyst type -
`YearMonthIntervalType` (see #31614). The Java class `java.time.Period` has
semantics similar to the ANSI SQL year-month interval type, and it is the most
suitable external type for `YearMonthIntervalType`. In more detail:
1. Added `PeriodConverter` which converts `java.time.Period` instances
to/from internal representation of the Catalyst type `YearMonthIntervalType`
(to `Int` type). The `PeriodConverter` object uses new methods of
`IntervalUtils`:
- `periodToMonths()` converts the input period to the total length in
months. If this period is too large to fit `Int`, the method throws the
exception `ArithmeticException`. **Note:** _the input period has "days"
precision, the method just ignores the days unit._
- `monthsToPeriod()` obtains a `java.time.Period` representing a number
of months.
2. Support new type `YearMonthIntervalType` in `RowEncoder` via the methods
`createDeserializerForPeriod()` and `createSerializerForJavaPeriod()`.
3. Extended the Literal API to construct literals from `java.time.Period`
instances.
### Why are the changes needed?
1. To allow users to parallelize `java.time.Period` collections and
construct year-month interval columns, and also to collect such columns back to the
driver side.
2. This will allow writing tests in other sub-tasks of SPARK-27790.
### Does this PR introduce _any_ user-facing change?
The PR extends existing functionality. So, users can parallelize instances
of the `java.time.Period` class and collect them back:
```scala
scala> val ds = Seq(java.time.Period.ofYears(10).withMonths(2)).toDS
ds: org.apache.spark.sql.Dataset[java.time.Period] = [value:
yearmonthinterval]
scala> ds.collect
res0: Array[java.time.Period] = Array(P10Y2M)
```
### How was this patch tested?
- Added a few tests to `CatalystTypeConvertersSuite` to check conversion
from/to `java.time.Period`.
- Checking row encoding by new tests in `RowEncoderSuite`.
- Making literals of `YearMonthIntervalType` is tested in
`LiteralExpressionSuite`.
- Check collecting by `DatasetSuite` and `JavaDatasetSuite`.
- New tests in `IntervalUtilsSuite` to check conversions
`java.time.Period` <-> months.
Closes #31765 from MaxGekk/java-time-period.
Authored-by: Max Gekk <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../expressions/SpecializedGettersReader.java | 3 ++
.../main/scala/org/apache/spark/sql/Encoders.scala | 8 +++++
.../sql/catalyst/CatalystTypeConverters.scala | 16 ++++++++-
.../sql/catalyst/DeserializerBuildHelper.scala | 9 +++++
.../apache/spark/sql/catalyst/InternalRow.scala | 6 ++--
.../spark/sql/catalyst/JavaTypeInference.scala | 6 ++++
.../spark/sql/catalyst/ScalaReflection.scala | 14 ++++++--
.../spark/sql/catalyst/SerializerBuildHelper.scala | 9 +++++
.../apache/spark/sql/catalyst/dsl/package.scala | 5 +++
.../spark/sql/catalyst/encoders/RowEncoder.scala | 6 ++++
.../expressions/InterpretedUnsafeProjection.scala | 2 +-
.../catalyst/expressions/SpecificInternalRow.scala | 4 +--
.../expressions/codegen/CodeGenerator.scala | 4 +--
.../spark/sql/catalyst/expressions/literals.scala | 11 ++++---
.../spark/sql/catalyst/util/IntervalUtils.scala | 33 ++++++++++++++++++-
.../org/apache/spark/sql/types/DataType.scala | 2 +-
.../sql/catalyst/CatalystTypeConvertersSuite.scala | 38 +++++++++++++++++++++-
.../sql/catalyst/encoders/RowEncoderSuite.scala | 10 ++++++
.../expressions/LiteralExpressionSuite.scala | 20 +++++++++++-
.../sql/catalyst/util/IntervalUtilsSuite.scala | 26 ++++++++++++++-
.../scala/org/apache/spark/sql/SQLImplicits.scala | 3 ++
.../org/apache/spark/sql/JavaDatasetSuite.java | 9 +++++
.../scala/org/apache/spark/sql/DatasetSuite.scala | 5 +++
23 files changed, 230 insertions(+), 19 deletions(-)
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
index d1bb719..90f340b 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
@@ -86,6 +86,9 @@ public final class SpecializedGettersReader {
if (dataType instanceof DayTimeIntervalType) {
return obj.getLong(ordinal);
}
+ if (dataType instanceof YearMonthIntervalType) {
+ return obj.getInt(ordinal);
+ }
throw new UnsupportedOperationException("Unsupported data type " +
dataType.simpleString());
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
index 5e72b19..d508295 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -144,6 +144,14 @@ object Encoders {
def DURATION: Encoder[java.time.Duration] = ExpressionEncoder()
/**
+ * Creates an encoder that serializes instances of the `java.time.Period`
class
+ * to the internal representation of nullable Catalyst's
YearMonthIntervalType.
+ *
+ * @since 3.2.0
+ */
+ def PERIOD: Encoder[java.time.Period] = ExpressionEncoder()
+
+ /**
* Creates an encoder for Java Bean of type T.
*
* T must be publicly accessible.
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 8201fd7d..b55d1b7 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable}
import java.math.{BigDecimal => JavaBigDecimal}
import java.math.{BigInteger => JavaBigInteger}
import java.sql.{Date, Timestamp}
-import java.time.{Duration, Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate, Period}
import java.util.{Map => JavaMap}
import javax.annotation.Nullable
@@ -75,6 +75,7 @@ object CatalystTypeConverters {
case FloatType => FloatConverter
case DoubleType => DoubleConverter
case DayTimeIntervalType => DurationConverter
+ case YearMonthIntervalType => PeriodConverter
case dataType: DataType => IdentityConverter(dataType)
}
converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]]
@@ -413,6 +414,18 @@ object CatalystTypeConverters {
IntervalUtils.microsToDuration(row.getLong(column))
}
+ private object PeriodConverter extends CatalystTypeConverter[Period, Period,
Any] {
+ override def toCatalystImpl(scalaValue: Period): Int = {
+ IntervalUtils.periodToMonths(scalaValue)
+ }
+ override def toScala(catalystValue: Any): Period = {
+ if (catalystValue == null) null
+ else IntervalUtils.monthsToPeriod(catalystValue.asInstanceOf[Int])
+ }
+ override def toScalaImpl(row: InternalRow, column: Int): Period =
+ IntervalUtils.monthsToPeriod(row.getInt(column))
+ }
+
/**
* Creates a converter function that will convert Scala objects to the
specified Catalyst type.
* Typical use case would be converting a collection of rows that have the
same schema. You will
@@ -479,6 +492,7 @@ object CatalystTypeConverters {
(key: Any) => convertToCatalyst(key),
(value: Any) => convertToCatalyst(value))
case d: Duration => DurationConverter.toCatalyst(d)
+ case p: Period => PeriodConverter.toCatalyst(p)
case other => other
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index 03243b4..eaa7c17 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -152,6 +152,15 @@ object DeserializerBuildHelper {
returnNullable = false)
}
+ def createDeserializerForPeriod(path: Expression): Expression = {
+ StaticInvoke(
+ IntervalUtils.getClass,
+ ObjectType(classOf[java.time.Period]),
+ "monthsToPeriod",
+ path :: Nil,
+ returnNullable = false)
+ }
+
/**
* When we build the `deserializer` for an encoder, we set up a lot of
"unresolved" stuff
* and lost the required data type, which may lead to runtime error if the
real type doesn't
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index 00b2d16..fd74f60 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -132,7 +132,8 @@ object InternalRow {
case BooleanType => (input, ordinal) => input.getBoolean(ordinal)
case ByteType => (input, ordinal) => input.getByte(ordinal)
case ShortType => (input, ordinal) => input.getShort(ordinal)
- case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
+ case IntegerType | DateType | YearMonthIntervalType =>
+ (input, ordinal) => input.getInt(ordinal)
case LongType | TimestampType | DayTimeIntervalType =>
(input, ordinal) => input.getLong(ordinal)
case FloatType => (input, ordinal) => input.getFloat(ordinal)
@@ -168,7 +169,8 @@ object InternalRow {
case BooleanType => (input, v) => input.setBoolean(ordinal,
v.asInstanceOf[Boolean])
case ByteType => (input, v) => input.setByte(ordinal, v.asInstanceOf[Byte])
case ShortType => (input, v) => input.setShort(ordinal,
v.asInstanceOf[Short])
- case IntegerType | DateType => (input, v) => input.setInt(ordinal,
v.asInstanceOf[Int])
+ case IntegerType | DateType | YearMonthIntervalType =>
+ (input, v) => input.setInt(ordinal, v.asInstanceOf[Int])
case LongType | TimestampType | DayTimeIntervalType =>
(input, v) => input.setLong(ordinal, v.asInstanceOf[Long])
case FloatType => (input, v) => input.setFloat(ordinal,
v.asInstanceOf[Float])
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 7f055a1..541b783 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -119,6 +119,7 @@ object JavaTypeInference {
case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType,
true)
case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType,
true)
case c: Class[_] if c == classOf[java.time.Duration] =>
(DayTimeIntervalType, true)
+ case c: Class[_] if c == classOf[java.time.Period] =>
(YearMonthIntervalType, true)
case _ if typeToken.isArray =>
val (dataType, nullable) = inferDataType(typeToken.getComponentType,
seenTypeSet)
@@ -253,6 +254,9 @@ object JavaTypeInference {
case c if c == classOf[java.time.Duration] =>
createDeserializerForDuration(path)
+ case c if c == classOf[java.time.Period] =>
+ createDeserializerForPeriod(path)
+
case c if c == classOf[java.lang.String] =>
createDeserializerForString(path, returnNullable = true)
@@ -412,6 +416,8 @@ object JavaTypeInference {
case c if c == classOf[java.time.Duration] =>
createSerializerForJavaDuration(inputObject)
+ case c if c == classOf[java.time.Period] =>
createSerializerForJavaPeriod(inputObject)
+
case c if c == classOf[java.math.BigDecimal] =>
createSerializerForJavaBigDecimal(inputObject)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index bdb2a8e..c258cdf 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -243,6 +243,9 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
createDeserializerForDuration(path)
+ case t if isSubtype(t, localTypeOf[java.time.Period]) =>
+ createDeserializerForPeriod(path)
+
case t if isSubtype(t, localTypeOf[java.lang.String]) =>
createDeserializerForString(path, returnNullable = false)
@@ -528,6 +531,9 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
createSerializerForJavaDuration(inputObject)
+ case t if isSubtype(t, localTypeOf[java.time.Period]) =>
+ createSerializerForJavaPeriod(inputObject)
+
case t if isSubtype(t, localTypeOf[BigDecimal]) =>
createSerializerForScalaBigDecimal(inputObject)
@@ -748,6 +754,8 @@ object ScalaReflection extends ScalaReflection {
Schema(CalendarIntervalType, nullable = true)
case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
Schema(DayTimeIntervalType, nullable = true)
+ case t if isSubtype(t, localTypeOf[java.time.Period]) =>
+ Schema(YearMonthIntervalType, nullable = true)
case t if isSubtype(t, localTypeOf[BigDecimal]) =>
Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
@@ -846,7 +854,8 @@ object ScalaReflection extends ScalaReflection {
TimestampType -> classOf[TimestampType.InternalType],
BinaryType -> classOf[BinaryType.InternalType],
CalendarIntervalType -> classOf[CalendarInterval],
- DayTimeIntervalType -> classOf[DayTimeIntervalType.InternalType]
+ DayTimeIntervalType -> classOf[DayTimeIntervalType.InternalType],
+ YearMonthIntervalType -> classOf[YearMonthIntervalType.InternalType]
)
val typeBoxedJavaMapping = Map[DataType, Class[_]](
@@ -859,7 +868,8 @@ object ScalaReflection extends ScalaReflection {
DoubleType -> classOf[java.lang.Double],
DateType -> classOf[java.lang.Integer],
TimestampType -> classOf[java.lang.Long],
- DayTimeIntervalType -> classOf[java.lang.Long]
+ DayTimeIntervalType -> classOf[java.lang.Long],
+ YearMonthIntervalType -> classOf[java.lang.Integer]
)
def dataTypeJavaClass(dt: DataType): Class[_] = {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index fcecfbe..f80fab5 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -113,6 +113,15 @@ object SerializerBuildHelper {
returnNullable = false)
}
+ def createSerializerForJavaPeriod(inputObject: Expression): Expression = {
+ StaticInvoke(
+ IntervalUtils.getClass,
+ YearMonthIntervalType,
+ "periodToMonths",
+ inputObject :: Nil,
+ returnNullable = false)
+ }
+
def createSerializerForJavaBigDecimal(inputObject: Expression): Expression =
{
CheckOverflow(StaticInvoke(
Decimal.getClass,
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 5d55973..626ece3 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -302,6 +302,11 @@ package object dsl {
AttributeReference(s, DayTimeIntervalType, nullable = true)()
}
+ /** Creates a new AttributeReference of the year-month interval type */
+ def yearMonthInterval: AttributeReference = {
+ AttributeReference(s, YearMonthIntervalType, nullable = true)()
+ }
+
/** Creates a new AttributeReference of type binary */
def binary: AttributeReference = AttributeReference(s, BinaryType,
nullable = true)()
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index ebda55b..b67f707 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -54,6 +54,7 @@ import org.apache.spark.sql.types._
* TimestampType -> java.time.Instant if spark.sql.datetime.java8API.enabled
is true
*
* DayTimeIntervalType -> java.time.Duration
+ * YearMonthIntervalType -> java.time.Period
*
* BinaryType -> byte array
* ArrayType -> scala.collection.Seq or Array
@@ -112,6 +113,8 @@ object RowEncoder {
case DayTimeIntervalType => createSerializerForJavaDuration(inputObject)
+ case YearMonthIntervalType => createSerializerForJavaPeriod(inputObject)
+
case d: DecimalType =>
CheckOverflow(StaticInvoke(
Decimal.getClass,
@@ -231,6 +234,7 @@ object RowEncoder {
ObjectType(classOf[java.sql.Date])
}
case DayTimeIntervalType => ObjectType(classOf[java.time.Duration])
+ case YearMonthIntervalType => ObjectType(classOf[java.time.Period])
case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
case StringType => ObjectType(classOf[java.lang.String])
case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
@@ -288,6 +292,8 @@ object RowEncoder {
case DayTimeIntervalType => createDeserializerForDuration(input)
+ case YearMonthIntervalType => createDeserializerForPeriod(input)
+
case _: DecimalType => createDeserializerForJavaBigDecimal(input,
returnNullable = false)
case StringType => createDeserializerForString(input, returnNullable =
false)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
index 00ac3d6..908b73a 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
@@ -157,7 +157,7 @@ object InterpretedUnsafeProjection {
case ShortType =>
(v, i) => writer.write(i, v.getShort(i))
- case IntegerType | DateType =>
+ case IntegerType | DateType | YearMonthIntervalType =>
(v, i) => writer.write(i, v.getInt(i))
case LongType | TimestampType | DayTimeIntervalType =>
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
index fd22978..0f26192 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
@@ -193,8 +193,8 @@ final class MutableAny extends MutableValue {
final class SpecificInternalRow(val values: Array[MutableValue]) extends
BaseGenericInternalRow {
private[this] def dataTypeToMutableValue(dataType: DataType): MutableValue =
dataType match {
- // We use INT for DATE internally
- case IntegerType | DateType => new MutableInt
+ // We use INT for DATE and YearMonthIntervalType internally
+ case IntegerType | DateType | YearMonthIntervalType => new MutableInt
// We use Long for Timestamp and DayTimeInterval internally
case LongType | TimestampType | DayTimeIntervalType => new MutableLong
case FloatType => new MutableFloat
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 67c4adf..45ee193 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1812,7 +1812,7 @@ object CodeGenerator extends Logging {
case BooleanType => JAVA_BOOLEAN
case ByteType => JAVA_BYTE
case ShortType => JAVA_SHORT
- case IntegerType | DateType => JAVA_INT
+ case IntegerType | DateType | YearMonthIntervalType => JAVA_INT
case LongType | TimestampType | DayTimeIntervalType => JAVA_LONG
case FloatType => JAVA_FLOAT
case DoubleType => JAVA_DOUBLE
@@ -1833,7 +1833,7 @@ object CodeGenerator extends Logging {
case BooleanType => java.lang.Boolean.TYPE
case ByteType => java.lang.Byte.TYPE
case ShortType => java.lang.Short.TYPE
- case IntegerType | DateType => java.lang.Integer.TYPE
+ case IntegerType | DateType | YearMonthIntervalType =>
java.lang.Integer.TYPE
case LongType | TimestampType | DayTimeIntervalType => java.lang.Long.TYPE
case FloatType => java.lang.Float.TYPE
case DoubleType => java.lang.Double.TYPE
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 203e98c..2ea73e8 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -28,7 +28,7 @@ import java.lang.{Short => JavaShort}
import java.math.{BigDecimal => JavaBigDecimal}
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
-import java.time.{Duration, Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate, Period}
import java.util
import java.util.Objects
import javax.xml.bind.DatatypeConverter
@@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters,
InternalRow, Scala
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeUtils.instantToMicros
-import org.apache.spark.sql.catalyst.util.IntervalUtils.durationToMicros
+import org.apache.spark.sql.catalyst.util.IntervalUtils.{durationToMicros,
periodToMonths}
import org.apache.spark.sql.errors.{QueryCompilationErrors,
QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -78,6 +78,7 @@ object Literal {
case ld: LocalDate => Literal(ld.toEpochDay.toInt, DateType)
case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
case d: Duration => Literal(durationToMicros(d), DayTimeIntervalType)
+ case p: Period => Literal(periodToMonths(p), YearMonthIntervalType)
case a: Array[Byte] => Literal(a, BinaryType)
case a: collection.mutable.WrappedArray[_] => apply(a.array)
case a: Array[_] =>
@@ -114,6 +115,7 @@ object Literal {
case _ if clz == classOf[Instant] => TimestampType
case _ if clz == classOf[Timestamp] => TimestampType
case _ if clz == classOf[Duration] => DayTimeIntervalType
+ case _ if clz == classOf[Period] => YearMonthIntervalType
case _ if clz == classOf[JavaBigDecimal] => DecimalType.SYSTEM_DEFAULT
case _ if clz == classOf[Array[Byte]] => BinaryType
case _ if clz == classOf[Array[Char]] => StringType
@@ -171,6 +173,7 @@ object Literal {
case DateType => create(0, DateType)
case TimestampType => create(0L, TimestampType)
case DayTimeIntervalType => create(0L, DayTimeIntervalType)
+ case YearMonthIntervalType => create(0, YearMonthIntervalType)
case StringType => Literal("")
case BinaryType => Literal("".getBytes(StandardCharsets.UTF_8))
case CalendarIntervalType => Literal(new CalendarInterval(0, 0, 0))
@@ -189,7 +192,7 @@ object Literal {
case BooleanType => v.isInstanceOf[Boolean]
case ByteType => v.isInstanceOf[Byte]
case ShortType => v.isInstanceOf[Short]
- case IntegerType | DateType => v.isInstanceOf[Int]
+ case IntegerType | DateType | YearMonthIntervalType =>
v.isInstanceOf[Int]
case LongType | TimestampType | DayTimeIntervalType =>
v.isInstanceOf[Long]
case FloatType => v.isInstanceOf[Float]
case DoubleType => v.isInstanceOf[Double]
@@ -366,7 +369,7 @@ case class Literal (value: Any, dataType: DataType) extends
LeafExpression {
ExprCode.forNonNullValue(JavaCode.literal(code, dataType))
}
dataType match {
- case BooleanType | IntegerType | DateType =>
+ case BooleanType | IntegerType | DateType | YearMonthIntervalType =>
toExprCode(value.toString)
case FloatType =>
value.asInstanceOf[Float] match {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index 6be4e9f..06ab4b6 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.util
-import java.time.Duration
+import java.time.{Duration, Period}
import java.time.temporal.ChronoUnit
import java.util.concurrent.TimeUnit
@@ -791,4 +791,35 @@ object IntervalUtils {
* @return A [[Duration]], not null
*/
def microsToDuration(micros: Long): Duration = Duration.of(micros,
ChronoUnit.MICROS)
+
+ /**
+ * Gets the total number of months in this period.
+ * <p>
+ * This returns the total number of months in the period by multiplying the
+ * number of years by 12 and adding the number of months.
+ * <p>
+ *
+ * @return The total number of months in the period, may be negative
+ * @throws ArithmeticException If numeric overflow occurs
+ */
+ def periodToMonths(period: Period): Int = {
+ val monthsInYears = Math.multiplyExact(period.getYears, MONTHS_PER_YEAR)
+ Math.addExact(monthsInYears, period.getMonths)
+ }
+
+ /**
+ * Obtains a [[Period]] representing a number of months. The days unit will
be zero, and the years
+ * and months units will be normalized.
+ *
+ * <p>
+ * The months unit is adjusted to have an absolute value < 12, with the
years unit being adjusted
+ * to compensate. For example, the method returns "2 years and 3 months" for
the 27 input months.
+ * <p>
+ * The sign of the years and months units will be the same after
normalization.
+ * For example, -13 months will be converted to "-1 year and -1 month".
+ *
+ * @param months The number of months, positive or negative
+ * @return The period of months, not null
+ */
+ def monthsToPeriod(months: Int): Period =
Period.ofMonths(months).normalized()
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 0527327..5c5742c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -171,7 +171,7 @@ object DataType {
private val otherTypes = {
Seq(NullType, DateType, TimestampType, BinaryType, IntegerType,
BooleanType, LongType,
DoubleType, FloatType, ShortType, ByteType, StringType,
CalendarIntervalType,
- DayTimeIntervalType)
+ DayTimeIntervalType, YearMonthIntervalType)
.map(t => t.typeName -> t).toMap
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
index 6b66af8..0dbae70 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst
-import java.time.{Duration, Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate, Period}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
@@ -258,4 +258,40 @@ class CatalystTypeConvertersSuite extends SparkFunSuite
with SQLHelper {
}
}
}
+
+ test("SPARK-34615: converting java.time.Period to YearMonthIntervalType") {
+ Seq(
+ Period.ZERO,
+ Period.ofMonths(1),
+ Period.ofMonths(-1),
+ Period.ofMonths(Int.MaxValue).normalized(),
+ Period.ofMonths(Int.MinValue).normalized(),
+ Period.ofYears(106751991),
+ Period.ofYears(-106751991)).foreach { input =>
+ val result = CatalystTypeConverters.convertToCatalyst(input)
+ val expected = IntervalUtils.periodToMonths(input)
+ assert(result === expected)
+ }
+
+ val errMsg = intercept[ArithmeticException] {
+ IntervalUtils.periodToMonths(Period.of(Int.MaxValue, Int.MaxValue,
Int.MaxValue))
+ }.getMessage
+ assert(errMsg.contains("integer overflow"))
+ }
+
+ test("SPARK-34615: converting YearMonthIntervalType to java.time.Period") {
+ Seq(
+ 0,
+ 1,
+ 999999,
+ 1000000,
+ Int.MaxValue).foreach { input =>
+ Seq(1, -1).foreach { sign =>
+ val months = sign * input
+ val period = IntervalUtils.monthsToPeriod(months)
+ assert(
+
CatalystTypeConverters.createToScalaConverter(YearMonthIntervalType)(months)
=== period)
+ }
+ }
+ }
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 9ab3361..6c22c14 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -352,6 +352,16 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
assert(readback.get(0).equals(duration))
}
+ test("SPARK-34615: encoding/decoding YearMonthIntervalType to/from
java.time.Period") {
+ val schema = new StructType().add("p", YearMonthIntervalType)
+ val encoder = RowEncoder(schema).resolveAndBind()
+ val period = java.time.Period.ofMonths(1)
+ val row = toRow(encoder, Row(period))
+ assert(row.getInt(0) === IntervalUtils.periodToMonths(period))
+ val readback = fromRow(encoder, row)
+ assert(readback.get(0).equals(period))
+ }
+
for {
elementType <- Seq(IntegerType, StringType)
containsNull <- Seq(true, false)
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index 8cba46c..f8766f3 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import java.nio.charset.StandardCharsets
-import java.time.{Duration, Instant, LocalDate, LocalDateTime, ZoneOffset}
+import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period,
ZoneOffset}
import java.util.TimeZone
import scala.reflect.runtime.universe.TypeTag
@@ -367,4 +367,22 @@ class LiteralExpressionSuite extends SparkFunSuite with
ExpressionEvalHelper {
val duration1 = Duration.ofHours(-1024)
checkEvaluation(Literal(Array(duration0, duration1)), Array(duration0,
duration1))
}
+
+ test("SPARK-34615: construct literals from java.time.Period") {
+ Seq(
+ Period.ofYears(0),
+ Period.of(-1, 11, 0),
+ Period.of(1, -11, 0),
+ Period.ofMonths(Int.MaxValue),
+ Period.ofMonths(Int.MinValue)).foreach { period =>
+ checkEvaluation(Literal(period), period)
+ }
+ }
+
+ test("SPARK-34615: construct literals from arrays of java.time.Period") {
+ val period0 = Period.ofYears(123).withMonths(456)
+ checkEvaluation(Literal(Array(period0)), Array(period0))
+ val period1 = Period.ofMonths(-1024)
+ checkEvaluation(Literal(Array(period0, period1)), Array(period0, period1))
+ }
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
index 51fd291..df2656f 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.util
-import java.time.Duration
+import java.time.{Duration, Period}
import java.util.concurrent.TimeUnit
import org.apache.spark.SparkFunSuite
@@ -401,4 +401,28 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper {
}.getMessage
assert(errMsg.contains("long overflow"))
}
+
+ test("SPARK-34615: period to months") {
+ assert(periodToMonths(Period.ZERO) === 0)
+ assert(periodToMonths(Period.of(0, -1, 0)) === -1)
+ assert(periodToMonths(Period.of(-1, 0, 10)) === -12) // ignore days
+ assert(periodToMonths(Period.of(178956970, 7, 0)) === Int.MaxValue)
+ assert(periodToMonths(Period.of(-178956970, -8, 123)) === Int.MinValue)
+ assert(periodToMonths(Period.of(0, Int.MaxValue, Int.MaxValue)) === Int.MaxValue)
+
+ val errMsg = intercept[ArithmeticException] {
+ periodToMonths(Period.of(Int.MaxValue, 0, 0))
+ }.getMessage
+ assert(errMsg.contains("integer overflow"))
+ }
+
+ test("SPARK-34615: months to period") {
+ assert(monthsToPeriod(0) === Period.ZERO)
+ assert(monthsToPeriod(-11) === Period.of(0, -11, 0))
+ assert(monthsToPeriod(11) === Period.of(0, 11, 0))
+ assert(monthsToPeriod(27) === Period.of(2, 3, 0))
+ assert(monthsToPeriod(-13) === Period.of(-1, -1, 0))
+ assert(monthsToPeriod(Int.MaxValue) === Period.ofYears(178956970).withMonths(7))
+ assert(monthsToPeriod(Int.MinValue) === Period.ofYears(-178956970).withMonths(-8))
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index bcc4871..90188ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -92,6 +92,9 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
implicit def newDurationEncoder: Encoder[java.time.Duration] =
Encoders.DURATION
/** @since 3.2.0 */
+ implicit def newPeriodEncoder: Encoder[java.time.Period] = Encoders.PERIOD
+
+ /** @since 3.2.0 */
implicit def newJavaEnumEncoder[A <: java.lang.Enum[_] : TypeTag]:
Encoder[A] =
ExpressionEncoder()
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 85ad80e..93566e0 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -24,6 +24,7 @@ import java.sql.Timestamp;
import java.time.Duration;
import java.time.Instant;
import java.time.LocalDate;
+import java.time.Period;
import java.util.*;
import javax.annotation.Nonnull;
@@ -421,6 +422,14 @@ public class JavaDatasetSuite implements Serializable {
Assert.assertEquals(data, ds.collectAsList());
}
+ @Test
+ public void testPeriodEncoder() {
+ Encoder<Period> encoder = Encoders.PERIOD();
+ List<Period> data = Arrays.asList(Period.ofYears(10));
+ Dataset<Period> ds = spark.createDataset(data, encoder);
+ Assert.assertEquals(data, ds.collectAsList());
+ }
+
public static class KryoSerializable {
String value;
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 843696e..a98bb06 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -2012,6 +2012,11 @@ class DatasetSuite extends QueryTest
val duration = java.time.Duration.ofMinutes(10)
assert(spark.range(1).map { _ => duration }.head === duration)
}
+
+ test("SPARK-34615: implicit encoder for java.time.Period") {
+ val period = java.time.Period.ofYears(9999).withMonths(11)
+ assert(spark.range(1).map { _ => period }.head === period)
+ }
}
case class Bar(a: Int)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]