This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e10bf64  [SPARK-34615][SQL] Support `java.time.Period` as an external 
type of the year-month interval type
e10bf64 is described below

commit e10bf6476969c801808956ba7c5d79464bbebd1a
Author: Max Gekk <[email protected]>
AuthorDate: Mon Mar 8 08:33:09 2021 +0000

    [SPARK-34615][SQL] Support `java.time.Period` as an external type of the 
year-month interval type
    
    ### What changes were proposed in this pull request?
    In this PR, I propose to extend the Spark SQL API to accept 
[`java.time.Period`](https://docs.oracle.com/javase/8/docs/api/java/time/Period.html)
 as an external type of the recently added Catalyst type 
`YearMonthIntervalType` (see #31614). The Java class `java.time.Period` has 
semantics similar to the ANSI SQL year-month interval type, so it is the most 
suitable external type for `YearMonthIntervalType`. In more detail:
    1. Added `PeriodConverter`, which converts `java.time.Period` instances 
to/from the internal representation of the Catalyst type `YearMonthIntervalType` 
(an `Int`). The `PeriodConverter` object uses new methods of 
`IntervalUtils`:
        - `periodToMonths()` converts the input period to its total length in 
months. If the total is too large to fit into an `Int`, the method throws 
`ArithmeticException`. **Note:** _the input period has "days" 
precision; the method ignores the days unit._
        - `monthsToPeriod()` obtains a normalized `java.time.Period` representing 
a number of months.
    2. Supported the new type `YearMonthIntervalType` in `RowEncoder` via the methods 
`createDeserializerForPeriod()` and `createSerializerForJavaPeriod()`.
    3. Extended the Literal API to construct literals from `java.time.Period` 
instances.
    
    ### Why are the changes needed?
    1. To allow users to parallelize collections of `java.time.Period`, to 
construct year-month interval columns, and to collect such columns back to the 
driver side.
    2. To make it possible to write tests in other sub-tasks of SPARK-27790.
    
    ### Does this PR introduce _any_ user-facing change?
    The PR extends existing functionality: users can parallelize instances 
of the `java.time.Period` class and collect them back:
    
    ```scala
    scala> val ds = Seq(java.time.Period.ofYears(10).withMonths(2)).toDS
    ds: org.apache.spark.sql.Dataset[java.time.Period] = [value: 
yearmonthinterval]
    
    scala> ds.collect
    res0: Array[java.time.Period] = Array(P10Y2M)
    ```
    
    ### How was this patch tested?
    - Added a few tests to `CatalystTypeConvertersSuite` to check conversion 
from/to `java.time.Period`.
    - Checking row encoding by new tests in `RowEncoderSuite`.
    - Construction of `YearMonthIntervalType` literals is tested in 
`LiteralExpressionSuite`.
    - Checked collecting in `DatasetSuite` and `JavaDatasetSuite`.
    - New tests in `IntervalUtilsSuite` check the conversions 
`java.time.Period` <-> months.
    
    Closes #31765 from MaxGekk/java-time-period.
    
    Authored-by: Max Gekk <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../expressions/SpecializedGettersReader.java      |  3 ++
 .../main/scala/org/apache/spark/sql/Encoders.scala |  8 +++++
 .../sql/catalyst/CatalystTypeConverters.scala      | 16 ++++++++-
 .../sql/catalyst/DeserializerBuildHelper.scala     |  9 +++++
 .../apache/spark/sql/catalyst/InternalRow.scala    |  6 ++--
 .../spark/sql/catalyst/JavaTypeInference.scala     |  6 ++++
 .../spark/sql/catalyst/ScalaReflection.scala       | 14 ++++++--
 .../spark/sql/catalyst/SerializerBuildHelper.scala |  9 +++++
 .../apache/spark/sql/catalyst/dsl/package.scala    |  5 +++
 .../spark/sql/catalyst/encoders/RowEncoder.scala   |  6 ++++
 .../expressions/InterpretedUnsafeProjection.scala  |  2 +-
 .../catalyst/expressions/SpecificInternalRow.scala |  4 +--
 .../expressions/codegen/CodeGenerator.scala        |  4 +--
 .../spark/sql/catalyst/expressions/literals.scala  | 11 ++++---
 .../spark/sql/catalyst/util/IntervalUtils.scala    | 33 ++++++++++++++++++-
 .../org/apache/spark/sql/types/DataType.scala      |  2 +-
 .../sql/catalyst/CatalystTypeConvertersSuite.scala | 38 +++++++++++++++++++++-
 .../sql/catalyst/encoders/RowEncoderSuite.scala    | 10 ++++++
 .../expressions/LiteralExpressionSuite.scala       | 20 +++++++++++-
 .../sql/catalyst/util/IntervalUtilsSuite.scala     | 26 ++++++++++++++-
 .../scala/org/apache/spark/sql/SQLImplicits.scala  |  3 ++
 .../org/apache/spark/sql/JavaDatasetSuite.java     |  9 +++++
 .../scala/org/apache/spark/sql/DatasetSuite.scala  |  5 +++
 23 files changed, 230 insertions(+), 19 deletions(-)

diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
index d1bb719..90f340b 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
@@ -86,6 +86,9 @@ public final class SpecializedGettersReader {
     if (dataType instanceof DayTimeIntervalType) {
       return obj.getLong(ordinal);
     }
+    if (dataType instanceof YearMonthIntervalType) {
+      return obj.getInt(ordinal);
+    }
 
     throw new UnsupportedOperationException("Unsupported data type " + 
dataType.simpleString());
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
index 5e72b19..d508295 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -144,6 +144,14 @@ object Encoders {
   def DURATION: Encoder[java.time.Duration] = ExpressionEncoder()
 
   /**
+   * Creates an encoder that serializes instances of the `java.time.Period` 
class
+   * to the internal representation of nullable Catalyst's 
YearMonthIntervalType.
+   *
+   * @since 3.2.0
+   */
+  def PERIOD: Encoder[java.time.Period] = ExpressionEncoder()
+
+  /**
    * Creates an encoder for Java Bean of type T.
    *
    * T must be publicly accessible.
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 8201fd7d..b55d1b7 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable}
 import java.math.{BigDecimal => JavaBigDecimal}
 import java.math.{BigInteger => JavaBigInteger}
 import java.sql.{Date, Timestamp}
-import java.time.{Duration, Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate, Period}
 import java.util.{Map => JavaMap}
 import javax.annotation.Nullable
 
@@ -75,6 +75,7 @@ object CatalystTypeConverters {
       case FloatType => FloatConverter
       case DoubleType => DoubleConverter
       case DayTimeIntervalType => DurationConverter
+      case YearMonthIntervalType => PeriodConverter
       case dataType: DataType => IdentityConverter(dataType)
     }
     converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]]
@@ -413,6 +414,18 @@ object CatalystTypeConverters {
       IntervalUtils.microsToDuration(row.getLong(column))
   }
 
+  private object PeriodConverter extends CatalystTypeConverter[Period, Period, 
Any] {
+    override def toCatalystImpl(scalaValue: Period): Int = {
+      IntervalUtils.periodToMonths(scalaValue)
+    }
+    override def toScala(catalystValue: Any): Period = {
+      if (catalystValue == null) null
+      else IntervalUtils.monthsToPeriod(catalystValue.asInstanceOf[Int])
+    }
+    override def toScalaImpl(row: InternalRow, column: Int): Period =
+      IntervalUtils.monthsToPeriod(row.getInt(column))
+  }
+
   /**
    * Creates a converter function that will convert Scala objects to the 
specified Catalyst type.
    * Typical use case would be converting a collection of rows that have the 
same schema. You will
@@ -479,6 +492,7 @@ object CatalystTypeConverters {
         (key: Any) => convertToCatalyst(key),
         (value: Any) => convertToCatalyst(value))
     case d: Duration => DurationConverter.toCatalyst(d)
+    case p: Period => PeriodConverter.toCatalyst(p)
     case other => other
   }
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index 03243b4..eaa7c17 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -152,6 +152,15 @@ object DeserializerBuildHelper {
       returnNullable = false)
   }
 
+  def createDeserializerForPeriod(path: Expression): Expression = {
+    StaticInvoke(
+      IntervalUtils.getClass,
+      ObjectType(classOf[java.time.Period]),
+      "monthsToPeriod",
+      path :: Nil,
+      returnNullable = false)
+  }
+
   /**
    * When we build the `deserializer` for an encoder, we set up a lot of 
"unresolved" stuff
    * and lost the required data type, which may lead to runtime error if the 
real type doesn't
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index 00b2d16..fd74f60 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -132,7 +132,8 @@ object InternalRow {
       case BooleanType => (input, ordinal) => input.getBoolean(ordinal)
       case ByteType => (input, ordinal) => input.getByte(ordinal)
       case ShortType => (input, ordinal) => input.getShort(ordinal)
-      case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
+      case IntegerType | DateType | YearMonthIntervalType =>
+        (input, ordinal) => input.getInt(ordinal)
       case LongType | TimestampType | DayTimeIntervalType =>
         (input, ordinal) => input.getLong(ordinal)
       case FloatType => (input, ordinal) => input.getFloat(ordinal)
@@ -168,7 +169,8 @@ object InternalRow {
     case BooleanType => (input, v) => input.setBoolean(ordinal, 
v.asInstanceOf[Boolean])
     case ByteType => (input, v) => input.setByte(ordinal, v.asInstanceOf[Byte])
     case ShortType => (input, v) => input.setShort(ordinal, 
v.asInstanceOf[Short])
-    case IntegerType | DateType => (input, v) => input.setInt(ordinal, 
v.asInstanceOf[Int])
+    case IntegerType | DateType | YearMonthIntervalType =>
+      (input, v) => input.setInt(ordinal, v.asInstanceOf[Int])
     case LongType | TimestampType | DayTimeIntervalType =>
       (input, v) => input.setLong(ordinal, v.asInstanceOf[Long])
     case FloatType => (input, v) => input.setFloat(ordinal, 
v.asInstanceOf[Float])
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 7f055a1..541b783 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -119,6 +119,7 @@ object JavaTypeInference {
       case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, 
true)
       case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, 
true)
       case c: Class[_] if c == classOf[java.time.Duration] => 
(DayTimeIntervalType, true)
+      case c: Class[_] if c == classOf[java.time.Period] => 
(YearMonthIntervalType, true)
 
       case _ if typeToken.isArray =>
         val (dataType, nullable) = inferDataType(typeToken.getComponentType, 
seenTypeSet)
@@ -253,6 +254,9 @@ object JavaTypeInference {
       case c if c == classOf[java.time.Duration] =>
         createDeserializerForDuration(path)
 
+      case c if c == classOf[java.time.Period] =>
+        createDeserializerForPeriod(path)
+
       case c if c == classOf[java.lang.String] =>
         createDeserializerForString(path, returnNullable = true)
 
@@ -412,6 +416,8 @@ object JavaTypeInference {
 
         case c if c == classOf[java.time.Duration] => 
createSerializerForJavaDuration(inputObject)
 
+        case c if c == classOf[java.time.Period] => 
createSerializerForJavaPeriod(inputObject)
+
         case c if c == classOf[java.math.BigDecimal] =>
           createSerializerForJavaBigDecimal(inputObject)
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index bdb2a8e..c258cdf 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -243,6 +243,9 @@ object ScalaReflection extends ScalaReflection {
       case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
         createDeserializerForDuration(path)
 
+      case t if isSubtype(t, localTypeOf[java.time.Period]) =>
+        createDeserializerForPeriod(path)
+
       case t if isSubtype(t, localTypeOf[java.lang.String]) =>
         createDeserializerForString(path, returnNullable = false)
 
@@ -528,6 +531,9 @@ object ScalaReflection extends ScalaReflection {
       case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
         createSerializerForJavaDuration(inputObject)
 
+      case t if isSubtype(t, localTypeOf[java.time.Period]) =>
+        createSerializerForJavaPeriod(inputObject)
+
       case t if isSubtype(t, localTypeOf[BigDecimal]) =>
         createSerializerForScalaBigDecimal(inputObject)
 
@@ -748,6 +754,8 @@ object ScalaReflection extends ScalaReflection {
         Schema(CalendarIntervalType, nullable = true)
       case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
         Schema(DayTimeIntervalType, nullable = true)
+      case t if isSubtype(t, localTypeOf[java.time.Period]) =>
+        Schema(YearMonthIntervalType, nullable = true)
       case t if isSubtype(t, localTypeOf[BigDecimal]) =>
         Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
       case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
@@ -846,7 +854,8 @@ object ScalaReflection extends ScalaReflection {
     TimestampType -> classOf[TimestampType.InternalType],
     BinaryType -> classOf[BinaryType.InternalType],
     CalendarIntervalType -> classOf[CalendarInterval],
-    DayTimeIntervalType -> classOf[DayTimeIntervalType.InternalType]
+    DayTimeIntervalType -> classOf[DayTimeIntervalType.InternalType],
+    YearMonthIntervalType -> classOf[YearMonthIntervalType.InternalType]
   )
 
   val typeBoxedJavaMapping = Map[DataType, Class[_]](
@@ -859,7 +868,8 @@ object ScalaReflection extends ScalaReflection {
     DoubleType -> classOf[java.lang.Double],
     DateType -> classOf[java.lang.Integer],
     TimestampType -> classOf[java.lang.Long],
-    DayTimeIntervalType -> classOf[java.lang.Long]
+    DayTimeIntervalType -> classOf[java.lang.Long],
+    YearMonthIntervalType -> classOf[java.lang.Integer]
   )
 
   def dataTypeJavaClass(dt: DataType): Class[_] = {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index fcecfbe..f80fab5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -113,6 +113,15 @@ object SerializerBuildHelper {
       returnNullable = false)
   }
 
+  def createSerializerForJavaPeriod(inputObject: Expression): Expression = {
+    StaticInvoke(
+      IntervalUtils.getClass,
+      YearMonthIntervalType,
+      "periodToMonths",
+      inputObject :: Nil,
+      returnNullable = false)
+  }
+
   def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = 
{
     CheckOverflow(StaticInvoke(
       Decimal.getClass,
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 5d55973..626ece3 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -302,6 +302,11 @@ package object dsl {
         AttributeReference(s, DayTimeIntervalType, nullable = true)()
       }
 
+      /** Creates a new AttributeReference of the year-month interval type */
+      def yearMonthInterval: AttributeReference = {
+        AttributeReference(s, YearMonthIntervalType, nullable = true)()
+      }
+
       /** Creates a new AttributeReference of type binary */
       def binary: AttributeReference = AttributeReference(s, BinaryType, 
nullable = true)()
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index ebda55b..b67f707 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -54,6 +54,7 @@ import org.apache.spark.sql.types._
  *   TimestampType -> java.time.Instant if spark.sql.datetime.java8API.enabled 
is true
  *
  *   DayTimeIntervalType -> java.time.Duration
+ *   YearMonthIntervalType -> java.time.Period
  *
  *   BinaryType -> byte array
  *   ArrayType -> scala.collection.Seq or Array
@@ -112,6 +113,8 @@ object RowEncoder {
 
     case DayTimeIntervalType => createSerializerForJavaDuration(inputObject)
 
+    case YearMonthIntervalType => createSerializerForJavaPeriod(inputObject)
+
     case d: DecimalType =>
       CheckOverflow(StaticInvoke(
         Decimal.getClass,
@@ -231,6 +234,7 @@ object RowEncoder {
         ObjectType(classOf[java.sql.Date])
       }
     case DayTimeIntervalType => ObjectType(classOf[java.time.Duration])
+    case YearMonthIntervalType => ObjectType(classOf[java.time.Period])
     case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
     case StringType => ObjectType(classOf[java.lang.String])
     case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
@@ -288,6 +292,8 @@ object RowEncoder {
 
     case DayTimeIntervalType => createDeserializerForDuration(input)
 
+    case YearMonthIntervalType => createDeserializerForPeriod(input)
+
     case _: DecimalType => createDeserializerForJavaBigDecimal(input, 
returnNullable = false)
 
     case StringType => createDeserializerForString(input, returnNullable = 
false)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
index 00ac3d6..908b73a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
@@ -157,7 +157,7 @@ object InterpretedUnsafeProjection {
       case ShortType =>
         (v, i) => writer.write(i, v.getShort(i))
 
-      case IntegerType | DateType =>
+      case IntegerType | DateType | YearMonthIntervalType =>
         (v, i) => writer.write(i, v.getInt(i))
 
       case LongType | TimestampType | DayTimeIntervalType =>
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
index fd22978..0f26192 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
@@ -193,8 +193,8 @@ final class MutableAny extends MutableValue {
 final class SpecificInternalRow(val values: Array[MutableValue]) extends 
BaseGenericInternalRow {
 
   private[this] def dataTypeToMutableValue(dataType: DataType): MutableValue = 
dataType match {
-    // We use INT for DATE internally
-    case IntegerType | DateType => new MutableInt
+    // We use INT for DATE and YearMonthIntervalType internally
+    case IntegerType | DateType | YearMonthIntervalType => new MutableInt
     // We use Long for Timestamp and DayTimeInterval internally
     case LongType | TimestampType | DayTimeIntervalType => new MutableLong
     case FloatType => new MutableFloat
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 67c4adf..45ee193 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1812,7 +1812,7 @@ object CodeGenerator extends Logging {
     case BooleanType => JAVA_BOOLEAN
     case ByteType => JAVA_BYTE
     case ShortType => JAVA_SHORT
-    case IntegerType | DateType => JAVA_INT
+    case IntegerType | DateType | YearMonthIntervalType => JAVA_INT
     case LongType | TimestampType | DayTimeIntervalType => JAVA_LONG
     case FloatType => JAVA_FLOAT
     case DoubleType => JAVA_DOUBLE
@@ -1833,7 +1833,7 @@ object CodeGenerator extends Logging {
     case BooleanType => java.lang.Boolean.TYPE
     case ByteType => java.lang.Byte.TYPE
     case ShortType => java.lang.Short.TYPE
-    case IntegerType | DateType => java.lang.Integer.TYPE
+    case IntegerType | DateType | YearMonthIntervalType => 
java.lang.Integer.TYPE
     case LongType | TimestampType | DayTimeIntervalType => java.lang.Long.TYPE
     case FloatType => java.lang.Float.TYPE
     case DoubleType => java.lang.Double.TYPE
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 203e98c..2ea73e8 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -28,7 +28,7 @@ import java.lang.{Short => JavaShort}
 import java.math.{BigDecimal => JavaBigDecimal}
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
-import java.time.{Duration, Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate, Period}
 import java.util
 import java.util.Objects
 import javax.xml.bind.DatatypeConverter
@@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, 
InternalRow, Scala
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.instantToMicros
-import org.apache.spark.sql.catalyst.util.IntervalUtils.durationToMicros
+import org.apache.spark.sql.catalyst.util.IntervalUtils.{durationToMicros, 
periodToMonths}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -78,6 +78,7 @@ object Literal {
     case ld: LocalDate => Literal(ld.toEpochDay.toInt, DateType)
     case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
     case d: Duration => Literal(durationToMicros(d), DayTimeIntervalType)
+    case p: Period => Literal(periodToMonths(p), YearMonthIntervalType)
     case a: Array[Byte] => Literal(a, BinaryType)
     case a: collection.mutable.WrappedArray[_] => apply(a.array)
     case a: Array[_] =>
@@ -114,6 +115,7 @@ object Literal {
     case _ if clz == classOf[Instant] => TimestampType
     case _ if clz == classOf[Timestamp] => TimestampType
     case _ if clz == classOf[Duration] => DayTimeIntervalType
+    case _ if clz == classOf[Period] => YearMonthIntervalType
     case _ if clz == classOf[JavaBigDecimal] => DecimalType.SYSTEM_DEFAULT
     case _ if clz == classOf[Array[Byte]] => BinaryType
     case _ if clz == classOf[Array[Char]] => StringType
@@ -171,6 +173,7 @@ object Literal {
     case DateType => create(0, DateType)
     case TimestampType => create(0L, TimestampType)
     case DayTimeIntervalType => create(0L, DayTimeIntervalType)
+    case YearMonthIntervalType => create(0, YearMonthIntervalType)
     case StringType => Literal("")
     case BinaryType => Literal("".getBytes(StandardCharsets.UTF_8))
     case CalendarIntervalType => Literal(new CalendarInterval(0, 0, 0))
@@ -189,7 +192,7 @@ object Literal {
       case BooleanType => v.isInstanceOf[Boolean]
       case ByteType => v.isInstanceOf[Byte]
       case ShortType => v.isInstanceOf[Short]
-      case IntegerType | DateType => v.isInstanceOf[Int]
+      case IntegerType | DateType | YearMonthIntervalType => 
v.isInstanceOf[Int]
       case LongType | TimestampType | DayTimeIntervalType => 
v.isInstanceOf[Long]
       case FloatType => v.isInstanceOf[Float]
       case DoubleType => v.isInstanceOf[Double]
@@ -366,7 +369,7 @@ case class Literal (value: Any, dataType: DataType) extends 
LeafExpression {
         ExprCode.forNonNullValue(JavaCode.literal(code, dataType))
       }
       dataType match {
-        case BooleanType | IntegerType | DateType =>
+        case BooleanType | IntegerType | DateType | YearMonthIntervalType =>
           toExprCode(value.toString)
         case FloatType =>
           value.asInstanceOf[Float] match {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index 6be4e9f..06ab4b6 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.util
 
-import java.time.Duration
+import java.time.{Duration, Period}
 import java.time.temporal.ChronoUnit
 import java.util.concurrent.TimeUnit
 
@@ -791,4 +791,35 @@ object IntervalUtils {
    * @return A [[Duration]], not null
    */
   def microsToDuration(micros: Long): Duration = Duration.of(micros, 
ChronoUnit.MICROS)
+
+  /**
+   * Gets the total number of months in this period.
+   * <p>
+   * This returns the total number of months in the period by multiplying the
+   * number of years by 12 and adding the number of months.
+   * <p>
+   *
+   * @return The total number of months in the period, may be negative
+   * @throws ArithmeticException If numeric overflow occurs
+   */
+  def periodToMonths(period: Period): Int = {
+    val monthsInYears = Math.multiplyExact(period.getYears, MONTHS_PER_YEAR)
+    Math.addExact(monthsInYears, period.getMonths)
+  }
+
+  /**
+   * Obtains a [[Period]] representing a number of months. The days unit will 
be zero, and the years
+   * and months units will be normalized.
+   *
+   * <p>
+   * The months unit is adjusted to have an absolute value < 12, with the 
years unit being adjusted
+   * to compensate. For example, the method returns "2 years and 3 months" for 
the 27 input months.
+   * <p>
+   * The sign of the years and months units will be the same after 
normalization.
+   * For example, -13 months will be converted to "-1 year and -1 month".
+   *
+   * @param months The number of months, positive or negative
+   * @return The period of months, not null
+   */
+  def monthsToPeriod(months: Int): Period = 
Period.ofMonths(months).normalized()
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 0527327..5c5742c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -171,7 +171,7 @@ object DataType {
   private val otherTypes = {
     Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, 
BooleanType, LongType,
       DoubleType, FloatType, ShortType, ByteType, StringType, 
CalendarIntervalType,
-      DayTimeIntervalType)
+      DayTimeIntervalType, YearMonthIntervalType)
       .map(t => t.typeName -> t).toMap
   }
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
index 6b66af8..0dbae70 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst
 
-import java.time.{Duration, Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate, Period}
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
@@ -258,4 +258,40 @@ class CatalystTypeConvertersSuite extends SparkFunSuite 
with SQLHelper {
       }
     }
   }
+
+  test("SPARK-34615: converting java.time.Period to YearMonthIntervalType") {
+    Seq(
+      Period.ZERO,
+      Period.ofMonths(1),
+      Period.ofMonths(-1),
+      Period.ofMonths(Int.MaxValue).normalized(),
+      Period.ofMonths(Int.MinValue).normalized(),
+      Period.ofYears(106751991),
+      Period.ofYears(-106751991)).foreach { input =>
+      val result = CatalystTypeConverters.convertToCatalyst(input)
+      val expected = IntervalUtils.periodToMonths(input)
+      assert(result === expected)
+    }
+
+    val errMsg = intercept[ArithmeticException] {
+      IntervalUtils.periodToMonths(Period.of(Int.MaxValue, Int.MaxValue, 
Int.MaxValue))
+    }.getMessage
+    assert(errMsg.contains("integer overflow"))
+  }
+
+  test("SPARK-34615: converting YearMonthIntervalType to java.time.Period") {
+    Seq(
+      0,
+      1,
+      999999,
+      1000000,
+      Int.MaxValue).foreach { input =>
+      Seq(1, -1).foreach { sign =>
+        val months = sign * input
+        val period = IntervalUtils.monthsToPeriod(months)
+        assert(
+          
CatalystTypeConverters.createToScalaConverter(YearMonthIntervalType)(months) 
=== period)
+      }
+    }
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 9ab3361..6c22c14 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -352,6 +352,16 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
     assert(readback.get(0).equals(duration))
   }
 
+  test("SPARK-34615: encoding/decoding YearMonthIntervalType to/from java.time.Period") {
+    val schema = new StructType().add("p", YearMonthIntervalType)
+    val encoder = RowEncoder(schema).resolveAndBind()
+    val period = java.time.Period.ofMonths(1)
+    val row = toRow(encoder, Row(period))
+    assert(row.getInt(0) === IntervalUtils.periodToMonths(period))
+    val readback = fromRow(encoder, row)
+    assert(readback.get(0).equals(period))
+  }
+
   for {
     elementType <- Seq(IntegerType, StringType)
     containsNull <- Seq(true, false)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index 8cba46c..f8766f3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import java.nio.charset.StandardCharsets
-import java.time.{Duration, Instant, LocalDate, LocalDateTime, ZoneOffset}
+import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period, ZoneOffset}
 import java.util.TimeZone
 
 import scala.reflect.runtime.universe.TypeTag
@@ -367,4 +367,22 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
     val duration1 = Duration.ofHours(-1024)
     checkEvaluation(Literal(Array(duration0, duration1)), Array(duration0, duration1))
   }
+
+  test("SPARK-34615: construct literals from java.time.Period") {
+    Seq(
+      Period.ofYears(0),
+      Period.of(-1, 11, 0),
+      Period.of(1, -11, 0),
+      Period.ofMonths(Int.MaxValue),
+      Period.ofMonths(Int.MinValue)).foreach { period =>
+      checkEvaluation(Literal(period), period)
+    }
+  }
+
+  test("SPARK-34615: construct literals from arrays of java.time.Period") {
+    val period0 = Period.ofYears(123).withMonths(456)
+    checkEvaluation(Literal(Array(period0)), Array(period0))
+    val period1 = Period.ofMonths(-1024)
+    checkEvaluation(Literal(Array(period0, period1)), Array(period0, period1))
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
index 51fd291..df2656f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.util
 
-import java.time.Duration
+import java.time.{Duration, Period}
 import java.util.concurrent.TimeUnit
 
 import org.apache.spark.SparkFunSuite
@@ -401,4 +401,28 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper {
     }.getMessage
     assert(errMsg.contains("long overflow"))
   }
+
+  test("SPARK-34615: period to months") {
+    assert(periodToMonths(Period.ZERO) === 0)
+    assert(periodToMonths(Period.of(0, -1, 0)) === -1)
+    assert(periodToMonths(Period.of(-1, 0, 10)) === -12) // ignore days
+    assert(periodToMonths(Period.of(178956970, 7, 0)) === Int.MaxValue)
+    assert(periodToMonths(Period.of(-178956970, -8, 123)) === Int.MinValue)
+    assert(periodToMonths(Period.of(0, Int.MaxValue, Int.MaxValue)) === Int.MaxValue)
+
+    val errMsg = intercept[ArithmeticException] {
+      periodToMonths(Period.of(Int.MaxValue, 0, 0))
+    }.getMessage
+    assert(errMsg.contains("integer overflow"))
+  }
+
+  test("SPARK-34615: months to period") {
+    assert(monthsToPeriod(0) === Period.ZERO)
+    assert(monthsToPeriod(-11) === Period.of(0, -11, 0))
+    assert(monthsToPeriod(11) === Period.of(0, 11, 0))
+    assert(monthsToPeriod(27) === Period.of(2, 3, 0))
+    assert(monthsToPeriod(-13) === Period.of(-1, -1, 0))
+    assert(monthsToPeriod(Int.MaxValue) === Period.ofYears(178956970).withMonths(7))
+    assert(monthsToPeriod(Int.MinValue) === Period.ofYears(-178956970).withMonths(-8))
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index bcc4871..90188ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -92,6 +92,9 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
   implicit def newDurationEncoder: Encoder[java.time.Duration] = Encoders.DURATION
 
   /** @since 3.2.0 */
+  implicit def newPeriodEncoder: Encoder[java.time.Period] = Encoders.PERIOD
+
+  /** @since 3.2.0 */
   implicit def newJavaEnumEncoder[A <: java.lang.Enum[_] : TypeTag]: Encoder[A] =
     ExpressionEncoder()
 
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 85ad80e..93566e0 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -24,6 +24,7 @@ import java.sql.Timestamp;
 import java.time.Duration;
 import java.time.Instant;
 import java.time.LocalDate;
+import java.time.Period;
 import java.util.*;
 import javax.annotation.Nonnull;
 
@@ -421,6 +422,14 @@ public class JavaDatasetSuite implements Serializable {
     Assert.assertEquals(data, ds.collectAsList());
   }
 
+  @Test
+  public void testPeriodEncoder() {
+    Encoder<Period> encoder = Encoders.PERIOD();
+    List<Period> data = Arrays.asList(Period.ofYears(10));
+    Dataset<Period> ds = spark.createDataset(data, encoder);
+    Assert.assertEquals(data, ds.collectAsList());
+  }
+
   public static class KryoSerializable {
     String value;
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 843696e..a98bb06 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -2012,6 +2012,11 @@ class DatasetSuite extends QueryTest
     val duration = java.time.Duration.ofMinutes(10)
     assert(spark.range(1).map { _ => duration }.head === duration)
   }
+
+  test("SPARK-34615: implicit encoder for java.time.Period") {
+    val period = java.time.Period.ofYears(9999).withMonths(11)
+    assert(spark.range(1).map { _ => period }.head === period)
+  }
 }
 
 case class Bar(a: Int)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to