This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new ebfc6bb  [SPARK-36960][SQL] Pushdown filters with ANSI interval values to ORC
ebfc6bb is described below

commit ebfc6bbe0e9200f87ebb52fb71d009b2d71b956d
Author: Kousuke Saruta <saru...@oss.nttdata.com>
AuthorDate: Sat Oct 9 16:55:59 2021 +0300

    [SPARK-36960][SQL] Pushdown filters with ANSI interval values to ORC

    ### What changes were proposed in this pull request?
    This PR proposes to push down filters with ANSI interval values to ORC.

    ### Why are the changes needed?
    After SPARK-36931 (#34184), the V1 and V2 ORC datasources support ANSI intervals, so filters with ANSI interval values should also be pushed down for better performance.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    New tests.

    Closes #34224 from sarutak/orc-ansi-interval-pushdown.

    Lead-authored-by: Kousuke Saruta <saru...@oss.nttdata.com>
    Co-authored-by: Kousuke Saruta <saru...@oss.nttdata.co.jp>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 .../apache/spark/sql/catalyst/dsl/package.scala    |  4 +-
 .../sql/execution/datasources/orc/OrcFilters.scala | 10 ++-
 .../execution/datasources/orc/OrcFilterSuite.scala | 97 ++++++++++++++++++++++
 3 files changed, 108 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 4a97a8d..979c280 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst

 import java.sql.{Date, Timestamp}
-import java.time.{Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate, Period}

 import scala.language.implicitConversions

@@ -167,6 +167,8 @@ package object dsl {
     implicit def timestampToLiteral(t: Timestamp): Literal = Literal(t)
     implicit def instantToLiteral(i: Instant): Literal = Literal(i)
     implicit def binaryToLiteral(a: Array[Byte]): Literal = Literal(a)
+    implicit def periodToLiteral(p: Period): Literal = Literal(p)
+    implicit def durationToLiteral(d: Duration): Literal = Literal(d)

     implicit def symbolToUnresolvedAttribute(s: Symbol): analysis.UnresolvedAttribute =
       analysis.UnresolvedAttribute(s.name)
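The two new implicits mirror the existing date/time conversions: with them in scope, java.time.Period and java.time.Duration values are lifted to Catalyst Literals automatically. A minimal sketch of how that reads in the Catalyst DSL (the attribute names here are made up; these build unresolved expressions, in the style of Catalyst unit tests):

    import java.time.{Duration, Period}

    import org.apache.spark.sql.catalyst.dsl.expressions._

    // Period and Duration now convert to Literal implicitly, so ANSI
    // interval comparisons can be written directly against DSL attributes.
    val ymPredicate = Symbol("ym").attr > Period.ofMonths(6)
    val dtPredicate = Symbol("dt").attr <= Duration.ofHours(12)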
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
index 5abfa4c..8e02fc3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
@@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.orc

-import java.time.{Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate, Period}

 import org.apache.hadoop.hive.common.`type`.HiveDecimal
 import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument}
@@ -26,6 +26,7 @@ import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder
 import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable

 import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp}
+import org.apache.spark.sql.catalyst.util.IntervalUtils
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.Filter
@@ -140,7 +141,8 @@ private[sql] object OrcFilters extends OrcFiltersBase {
    */
   def getPredicateLeafType(dataType: DataType): PredicateLeaf.Type = dataType match {
     case BooleanType => PredicateLeaf.Type.BOOLEAN
-    case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG
+    case ByteType | ShortType | IntegerType | LongType |
+         _: AnsiIntervalType => PredicateLeaf.Type.LONG
     case FloatType | DoubleType => PredicateLeaf.Type.FLOAT
     case StringType => PredicateLeaf.Type.STRING
     case DateType => PredicateLeaf.Type.DATE
@@ -166,6 +168,10 @@ private[sql] object OrcFilters extends OrcFiltersBase {
       toJavaDate(localDateToDays(value.asInstanceOf[LocalDate]))
     case _: TimestampType if value.isInstanceOf[Instant] =>
       toJavaTimestamp(instantToMicros(value.asInstanceOf[Instant]))
+    case _: YearMonthIntervalType =>
+      IntervalUtils.periodToMonths(value.asInstanceOf[Period]).longValue()
+    case _: DayTimeIntervalType =>
+      IntervalUtils.durationToMicros(value.asInstanceOf[Duration])
     case _ => value
   }
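Both ANSI interval types map to PredicateLeaf.Type.LONG above because they are physically integral: a year-month interval is its total number of months, and a day-time interval is its total number of microseconds. A simplified sketch of what the two IntervalUtils conversions compute (hypothetical helper names; the real periodToMonths additionally range-checks the month count):

    import java.time.{Duration, Period}

    // Year-month interval: total months. The Period's day component has no
    // representation in the interval type.
    def periodToMonthsSketch(p: Period): Long =
      p.getYears * 12L + p.getMonths

    // Day-time interval: total microseconds, with overflow checking.
    def durationToMicrosSketch(d: Duration): Long =
      Math.addExact(Math.multiplyExact(d.getSeconds, 1000000L), (d.getNano / 1000).toLong)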
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
index 681ed91..c53cc10 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.orc
 import java.math.MathContext
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
+import java.time.{Duration, Period}

 import scala.collection.JavaConverters._

@@ -33,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
+import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
@@ -383,6 +385,101 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
     }
   }

+  test("SPARK-36960: filter pushdown - year-month interval") {
+    DataTypeTestUtils.yearMonthIntervalTypes.foreach { ymIntervalType =>
+
+      def periods(i: Int): Expression = Literal(Period.of(i, i, 0)).cast(ymIntervalType)
+
+      val baseDF = spark.createDataFrame((1 to 4).map { i =>
+        Tuple1.apply(Period.of(i, i, 0))
+      }).select(col("_1").cast(ymIntervalType))
+
+      withNestedOrcDataFrame(baseDF) {
+        case (inputDF, colName, _) =>
+          implicit val df: DataFrame = inputDF
+
+          val ymIntervalAttr = df(colName).expr
+          assert(df(colName).expr.dataType === ymIntervalType)
+
+          checkFilterPredicate(ymIntervalAttr.isNull, PredicateLeaf.Operator.IS_NULL)
+
+          checkFilterPredicate(ymIntervalAttr === periods(1),
+            PredicateLeaf.Operator.EQUALS)
+          checkFilterPredicate(ymIntervalAttr <=> periods(1),
+            PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+          checkFilterPredicate(ymIntervalAttr < periods(2),
+            PredicateLeaf.Operator.LESS_THAN)
+          checkFilterPredicate(ymIntervalAttr > periods(3),
+            PredicateLeaf.Operator.LESS_THAN_EQUALS)
+          checkFilterPredicate(ymIntervalAttr <= periods(1),
+            PredicateLeaf.Operator.LESS_THAN_EQUALS)
+          checkFilterPredicate(ymIntervalAttr >= periods(4),
+            PredicateLeaf.Operator.LESS_THAN)
+
+          checkFilterPredicate(periods(1) === ymIntervalAttr,
+            PredicateLeaf.Operator.EQUALS)
+          checkFilterPredicate(periods(1) <=> ymIntervalAttr,
+            PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+          checkFilterPredicate(periods(2) > ymIntervalAttr,
+            PredicateLeaf.Operator.LESS_THAN)
+          checkFilterPredicate(periods(3) < ymIntervalAttr,
+            PredicateLeaf.Operator.LESS_THAN_EQUALS)
+          checkFilterPredicate(periods(1) >= ymIntervalAttr,
+            PredicateLeaf.Operator.LESS_THAN_EQUALS)
+          checkFilterPredicate(periods(4) <= ymIntervalAttr,
+            PredicateLeaf.Operator.LESS_THAN)
+      }
+    }
+  }
+
+  test("SPARK-36960: filter pushdown - day-time interval") {
+    DataTypeTestUtils.dayTimeIntervalTypes.foreach { dtIntervalType =>
+
+      def durations(i: Int): Expression =
+        Literal(Duration.ofDays(i).plusHours(i).plusMinutes(i).plusSeconds(i)).cast(dtIntervalType)
+
+      val baseDF = spark.createDataFrame((1 to 4).map { i =>
+        Tuple1.apply(Duration.ofDays(i).plusHours(i).plusMinutes(i).plusSeconds(i))
+      }).select(col("_1").cast(dtIntervalType))
+
+      withNestedOrcDataFrame(baseDF) {
+        case (inputDF, colName, _) =>
+          implicit val df: DataFrame = inputDF
+
+          val dtIntervalAttr = df(colName).expr
+          assert(df(colName).expr.dataType === dtIntervalType)
+
+          checkFilterPredicate(dtIntervalAttr.isNull, PredicateLeaf.Operator.IS_NULL)
+
+          checkFilterPredicate(dtIntervalAttr === durations(1),
+            PredicateLeaf.Operator.EQUALS)
+          checkFilterPredicate(dtIntervalAttr <=> durations(1),
+            PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+          checkFilterPredicate(dtIntervalAttr < durations(2),
+            PredicateLeaf.Operator.LESS_THAN)
+          checkFilterPredicate(dtIntervalAttr > durations(3),
+            PredicateLeaf.Operator.LESS_THAN_EQUALS)
+          checkFilterPredicate(dtIntervalAttr <= durations(1),
+            PredicateLeaf.Operator.LESS_THAN_EQUALS)
+          checkFilterPredicate(dtIntervalAttr >= durations(4),
+            PredicateLeaf.Operator.LESS_THAN)
+
+          checkFilterPredicate(durations(1) === dtIntervalAttr,
+            PredicateLeaf.Operator.EQUALS)
+          checkFilterPredicate(durations(1) <=> dtIntervalAttr,
+            PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+          checkFilterPredicate(durations(2) > dtIntervalAttr,
+            PredicateLeaf.Operator.LESS_THAN)
+          checkFilterPredicate(durations(3) < dtIntervalAttr,
+            PredicateLeaf.Operator.LESS_THAN_EQUALS)
+          checkFilterPredicate(durations(1) >= dtIntervalAttr,
+            PredicateLeaf.Operator.LESS_THAN_EQUALS)
+          checkFilterPredicate(durations(4) <= dtIntervalAttr,
+            PredicateLeaf.Operator.LESS_THAN)
+      }
+    }
+  }
+
   test("no filter pushdown - non-supported types") {
     implicit class IntToBinary(int: Int) {
       def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8)
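End to end, the tests above exercise the following kind of query: once a DataFrame with an ANSI interval column has been written to ORC, a comparison against an interval literal becomes a candidate for compilation into an ORC search argument instead of being evaluated only on the Spark side. A rough usage sketch (the path and column name are made up; assumes spark.sql.orc.filterPushdown, which is enabled by default):

    import java.time.Period

    import org.apache.spark.sql.functions._

    // Write a year-month interval column to ORC.
    spark.sql("SELECT id, make_ym_interval(0, CAST(id AS INT)) AS ym FROM range(10)")
      .write.mode("overwrite").orc("/tmp/ym_orc")

    // With this change, the range predicate below can be pushed down
    // into the ORC reader.
    spark.read.orc("/tmp/ym_orc")
      .where(col("ym") > lit(Period.ofMonths(5)))
      .show()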