This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new b666d83  [SPARK-31488][SQL] Support `java.time.LocalDate` in Parquet filter pushdown
b666d83 is described below

commit b666d8360e39aba583c1b98a6c93826046fccc1b
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Fri Apr 24 02:21:53 2020 +0000

    [SPARK-31488][SQL] Support `java.time.LocalDate` in Parquet filter pushdown
    
    ### What changes were proposed in this pull request?
    1. Modified `ParquetFilters.valueCanMakeFilterOn()` to accept filters with `java.time.LocalDate` attributes.
    2. Modified `ParquetFilters.dateToDays()` to support both `java.sql.Date` and `java.time.LocalDate` in conversions to days.
    3. Added an implicit conversion from `LocalDate` to `Expression` (`Literal`); a usage sketch follows the list.
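    
    For illustration only (not part of this patch), a minimal sketch of how the new implicit can be exercised through the Catalyst DSL, assuming `org.apache.spark.sql.catalyst.dsl.expressions._` is on the classpath and imported:
    
        import java.time.LocalDate
        import org.apache.spark.sql.catalyst.dsl.expressions._
        
        // The new localDateToLiteral implicit turns the LocalDate into a DateType
        // Literal, so the comparison builds a GreaterThan predicate directly.
        val pred = 'd > LocalDate.of(2018, 3, 20)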
    
    ### Why are the changes needed?
    To support pushed-down filters with `java.time.LocalDate` attributes. Before the changes, date filters were not pushed down to the Parquet datasource when `spark.sql.datetime.java8API.enabled` is set to `true`.
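    
    For illustration only, a minimal spark-shell sketch (the path and column name are hypothetical): with the Java 8 API enabled, a date predicate over a Parquet column should now appear under `PushedFilters` in the physical plan.
    
        // spark-shell session; spark.implicits._ (for the $-interpolator) is imported automatically.
        spark.conf.set("spark.sql.datetime.java8API.enabled", "true")
        
        spark.sql("SELECT DATE'2018-03-18' AS d").write.mode("overwrite").parquet("/tmp/dates")
        
        val q = spark.read.parquet("/tmp/dates")
          .where($"d" === java.time.LocalDate.parse("2018-03-18"))
        // After this change, the date filter should be listed under PushedFilters.
        q.explain()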
    
    ### Does this PR introduce any user-facing change?
    No
    
    ### How was this patch tested?
    Added a test to `ParquetFilterSuite`
    
    Closes #28259 from MaxGekk/parquet-filter-java8-date-time.
    
    Authored-by: Max Gekk <max.g...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 26165427c7ec0b0b64a527edae42b05ba9d47a19)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../apache/spark/sql/catalyst/dsl/package.scala    |  2 +
 .../datasources/parquet/ParquetFilters.scala       | 21 +++--
 .../datasources/parquet/ParquetFilterSuite.scala   | 99 ++++++++++++----------
 3 files changed, 69 insertions(+), 53 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index b4a8baf..cc96d90 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst
 
 import java.sql.{Date, Timestamp}
+import java.time.LocalDate
 
 import scala.language.implicitConversions
 
@@ -146,6 +147,7 @@ package object dsl {
     implicit def doubleToLiteral(d: Double): Literal = Literal(d)
     implicit def stringToLiteral(s: String): Literal = Literal.create(s, StringType)
     implicit def dateToLiteral(d: Date): Literal = Literal(d)
+    implicit def localDateToLiteral(d: LocalDate): Literal = Literal(d)
     implicit def bigDecimalToLiteral(d: BigDecimal): Literal = Literal(d.underlying())
     implicit def bigDecimalToLiteral(d: java.math.BigDecimal): Literal = Literal(d)
     implicit def decimalToLiteral(d: Decimal): Literal = Literal(d)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index f206f59..d89186a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet
 import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong}
 import java.math.{BigDecimal => JBigDecimal}
 import java.sql.{Date, Timestamp}
+import java.time.LocalDate
 import java.util.Locale
 
 import scala.collection.JavaConverters.asScalaBufferConverter
@@ -123,8 +124,9 @@ class ParquetFilters(
   private val ParquetTimestampMicrosType = ParquetSchemaType(TIMESTAMP_MICROS, INT64, 0, null)
   private val ParquetTimestampMillisType = ParquetSchemaType(TIMESTAMP_MILLIS, INT64, 0, null)
 
-  private def dateToDays(date: Date): SQLDate = {
-    DateTimeUtils.fromJavaDate(date)
+  private def dateToDays(date: Any): SQLDate = date match {
+    case d: Date => DateTimeUtils.fromJavaDate(d)
+    case ld: LocalDate => DateTimeUtils.localDateToDays(ld)
   }
 
   private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue()
@@ -173,7 +175,7 @@ class ParquetFilters(
     case ParquetDateType if pushDownDate =>
       (n: Array[String], v: Any) => FilterApi.eq(
         intColumn(n),
-        Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
+        Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull)
     case ParquetTimestampMicrosType if pushDownTimestamp =>
       (n: Array[String], v: Any) => FilterApi.eq(
         longColumn(n),
@@ -224,7 +226,7 @@ class ParquetFilters(
     case ParquetDateType if pushDownDate =>
       (n: Array[String], v: Any) => FilterApi.notEq(
         intColumn(n),
-        Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
+        Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull)
     case ParquetTimestampMicrosType if pushDownTimestamp =>
       (n: Array[String], v: Any) => FilterApi.notEq(
         longColumn(n),
@@ -269,7 +271,7 @@ class ParquetFilters(
         FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
     case ParquetDateType if pushDownDate =>
       (n: Array[String], v: Any) =>
-        FilterApi.lt(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer])
+        FilterApi.lt(intColumn(n), dateToDays(v).asInstanceOf[Integer])
     case ParquetTimestampMicrosType if pushDownTimestamp =>
       (n: Array[String], v: Any) => FilterApi.lt(
         longColumn(n),
@@ -310,7 +312,7 @@ class ParquetFilters(
         FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
     case ParquetDateType if pushDownDate =>
       (n: Array[String], v: Any) =>
-        FilterApi.ltEq(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer])
+        FilterApi.ltEq(intColumn(n), dateToDays(v).asInstanceOf[Integer])
     case ParquetTimestampMicrosType if pushDownTimestamp =>
       (n: Array[String], v: Any) => FilterApi.ltEq(
         longColumn(n),
@@ -351,7 +353,7 @@ class ParquetFilters(
         FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
     case ParquetDateType if pushDownDate =>
       (n: Array[String], v: Any) =>
-        FilterApi.gt(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer])
+        FilterApi.gt(intColumn(n), dateToDays(v).asInstanceOf[Integer])
     case ParquetTimestampMicrosType if pushDownTimestamp =>
       (n: Array[String], v: Any) => FilterApi.gt(
         longColumn(n),
@@ -392,7 +394,7 @@ class ParquetFilters(
         FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
     case ParquetDateType if pushDownDate =>
       (n: Array[String], v: Any) =>
-        FilterApi.gtEq(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer])
+        FilterApi.gtEq(intColumn(n), dateToDays(v).asInstanceOf[Integer])
     case ParquetTimestampMicrosType if pushDownTimestamp =>
       (n: Array[String], v: Any) => FilterApi.gtEq(
         longColumn(n),
@@ -471,7 +473,8 @@ class ParquetFilters(
       case ParquetDoubleType => value.isInstanceOf[JDouble]
       case ParquetStringType => value.isInstanceOf[String]
       case ParquetBinaryType => value.isInstanceOf[Array[Byte]]
-      case ParquetDateType => value.isInstanceOf[Date]
+      case ParquetDateType =>
+        value.isInstanceOf[Date] || value.isInstanceOf[LocalDate]
       case ParquetTimestampMicrosType | ParquetTimestampMillisType =>
         value.isInstanceOf[Timestamp]
       case ParquetSchemaType(DECIMAL, INT32, _, decimalMeta) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index d1161e3..20bfb32 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet
 import java.math.{BigDecimal => JBigDecimal}
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
+import java.time.LocalDate
 
 import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators}
 import org.apache.parquet.filter2.predicate.FilterApi._
@@ -525,52 +526,62 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
       def date: Date = Date.valueOf(s)
     }
 
-    val data = Seq("2018-03-18", "2018-03-19", "2018-03-20", "2018-03-21").map(_.date)
+    val data = Seq("2018-03-18", "2018-03-19", "2018-03-20", "2018-03-21")
     import testImplicits._
-    withNestedDataFrame(data.map(i => Tuple1(i)).toDF()) { case (inputDF, colName, resultFun) =>
-      withParquetDataFrame(inputDF) { implicit df =>
-        val dateAttr: Expression = df(colName).expr
-        assert(df(colName).expr.dataType === DateType)
 
-        checkFilterPredicate(dateAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
-        checkFilterPredicate(dateAttr.isNotNull, classOf[NotEq[_]],
-          data.map(i => Row.apply(resultFun(i))))
-
-        checkFilterPredicate(dateAttr === "2018-03-18".date, classOf[Eq[_]],
-          resultFun("2018-03-18".date))
-        checkFilterPredicate(dateAttr <=> "2018-03-18".date, classOf[Eq[_]],
-          resultFun("2018-03-18".date))
-        checkFilterPredicate(dateAttr =!= "2018-03-18".date, classOf[NotEq[_]],
-          Seq("2018-03-19", "2018-03-20", "2018-03-21").map(i => 
Row.apply(resultFun(i.date))))
-
-        checkFilterPredicate(dateAttr < "2018-03-19".date, classOf[Lt[_]],
-          resultFun("2018-03-18".date))
-        checkFilterPredicate(dateAttr > "2018-03-20".date, classOf[Gt[_]],
-          resultFun("2018-03-21".date))
-        checkFilterPredicate(dateAttr <= "2018-03-18".date, classOf[LtEq[_]],
-          resultFun("2018-03-18".date))
-        checkFilterPredicate(dateAttr >= "2018-03-21".date, classOf[GtEq[_]],
-          resultFun("2018-03-21".date))
-
-        checkFilterPredicate(Literal("2018-03-18".date) === dateAttr, classOf[Eq[_]],
-          resultFun("2018-03-18".date))
-        checkFilterPredicate(Literal("2018-03-18".date) <=> dateAttr, classOf[Eq[_]],
-          resultFun("2018-03-18".date))
-        checkFilterPredicate(Literal("2018-03-19".date) > dateAttr, classOf[Lt[_]],
-          resultFun("2018-03-18".date))
-        checkFilterPredicate(Literal("2018-03-20".date) < dateAttr, classOf[Gt[_]],
-          resultFun("2018-03-21".date))
-        checkFilterPredicate(Literal("2018-03-18".date) >= dateAttr, classOf[LtEq[_]],
-          resultFun("2018-03-18".date))
-        checkFilterPredicate(Literal("2018-03-21".date) <= dateAttr, classOf[GtEq[_]],
-          resultFun("2018-03-21".date))
-
-        checkFilterPredicate(!(dateAttr < "2018-03-21".date), classOf[GtEq[_]],
-          resultFun("2018-03-21".date))
-        checkFilterPredicate(
-          dateAttr < "2018-03-19".date || dateAttr > "2018-03-20".date,
-          classOf[Operators.Or],
-          Seq(Row(resultFun("2018-03-18".date)), Row(resultFun("2018-03-21".date))))
+    Seq(false, true).foreach { java8Api =>
+      withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) {
+        val df = data.map(i => Tuple1(Date.valueOf(i))).toDF()
+        withNestedDataFrame(df) { case (inputDF, colName, fun) =>
+          def resultFun(dateStr: String): Any = {
+            val parsed = if (java8Api) LocalDate.parse(dateStr) else Date.valueOf(dateStr)
+            fun(parsed)
+          }
+          withParquetDataFrame(inputDF) { implicit df =>
+            val dateAttr: Expression = df(colName).expr
+            assert(df(colName).expr.dataType === DateType)
+
+            checkFilterPredicate(dateAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+            checkFilterPredicate(dateAttr.isNotNull, classOf[NotEq[_]],
+              data.map(i => Row.apply(resultFun(i))))
+
+            checkFilterPredicate(dateAttr === "2018-03-18".date, classOf[Eq[_]],
+              resultFun("2018-03-18"))
+            checkFilterPredicate(dateAttr <=> "2018-03-18".date, classOf[Eq[_]],
+              resultFun("2018-03-18"))
+            checkFilterPredicate(dateAttr =!= "2018-03-18".date, classOf[NotEq[_]],
+              Seq("2018-03-19", "2018-03-20", "2018-03-21").map(i => Row.apply(resultFun(i))))
+
+            checkFilterPredicate(dateAttr < "2018-03-19".date, classOf[Lt[_]],
+              resultFun("2018-03-18"))
+            checkFilterPredicate(dateAttr > "2018-03-20".date, classOf[Gt[_]],
+              resultFun("2018-03-21"))
+            checkFilterPredicate(dateAttr <= "2018-03-18".date, classOf[LtEq[_]],
+              resultFun("2018-03-18"))
+            checkFilterPredicate(dateAttr >= "2018-03-21".date, classOf[GtEq[_]],
+              resultFun("2018-03-21"))
+
+            checkFilterPredicate(Literal("2018-03-18".date) === dateAttr, classOf[Eq[_]],
+              resultFun("2018-03-18"))
+            checkFilterPredicate(Literal("2018-03-18".date) <=> dateAttr, classOf[Eq[_]],
+              resultFun("2018-03-18"))
+            checkFilterPredicate(Literal("2018-03-19".date) > dateAttr, classOf[Lt[_]],
+              resultFun("2018-03-18"))
+            checkFilterPredicate(Literal("2018-03-20".date) < dateAttr, classOf[Gt[_]],
+              resultFun("2018-03-21"))
+            checkFilterPredicate(Literal("2018-03-18".date) >= dateAttr, classOf[LtEq[_]],
+              resultFun("2018-03-18"))
+            checkFilterPredicate(Literal("2018-03-21".date) <= dateAttr, classOf[GtEq[_]],
+              resultFun("2018-03-21"))
+
+            checkFilterPredicate(!(dateAttr < "2018-03-21".date), classOf[GtEq[_]],
+              resultFun("2018-03-21"))
+            checkFilterPredicate(
+              dateAttr < "2018-03-19".date || dateAttr > "2018-03-20".date,
+              classOf[Operators.Or],
+              Seq(Row(resultFun("2018-03-18")), Row(resultFun("2018-03-21"))))
+          }
+        }
       }
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
