This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new c9412307394f [SPARK-46092][SQL][3.3] Don't push down Parquet row group filters that overflow
c9412307394f is described below

commit c9412307394fd1a277dd7fd5b173ec34e4b123d6
Author: Johan Lasperas <johan.laspe...@databricks.com>
AuthorDate: Mon Dec 4 12:50:57 2023 -0800

    [SPARK-46092][SQL][3.3] Don't push down Parquet row group filters that overflow
    
    This is a cherry-pick of https://github.com/apache/spark/pull/44006 to Spark 3.3.
    
    ### What changes were proposed in this pull request?
    This change adds a check for overflows when creating Parquet row group filters on an INT32 (byte/short/int) Parquet type, to avoid incorrectly skipping row groups if the predicate value doesn't fit in an INT. This can happen if the read schema is specified as LONG, e.g. via `.schema("col LONG")`.
    While the Parquet readers don't support reading INT32 into a LONG, the overflow can lead to row groups being incorrectly skipped, bypassing the reader altogether and producing incorrect results instead of failing.
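    
    For illustration, a minimal Scala sketch (not part of the patch) of the failure mode and of the bounds check the fix applies: truncating a LONG predicate value such as `Long.MaxValue` to an Int wraps around, so a row group filter built from the truncated value can skip row groups that actually match.
    ```
    // Illustrative sketch only; variable names are hypothetical.
    val value: Long = Long.MaxValue
    // Truncation wraps around: a pushed-down filter like `a < Long.MaxValue` would
    // effectively become `a < -1` on the INT32 column and skip matching row groups.
    println(value.toInt)  // prints -1
    // Bounds check mirroring the fix: only build an INT32 filter when the value fits in an Int.
    val fitsInInt = value >= Int.MinValue && value <= Int.MaxValue
    println(fitsInInt)    // prints false, so no row group filter is created
    ```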
    
    ### Why are the changes needed?
    Reading a Parquet file containing INT32 values with a read schema specified as LONG can produce incorrect results today:
    ```
    Seq(0).toDF("a").write.parquet(path)
    spark.read.schema("a LONG").parquet(path).where(s"a < ${Long.MaxValue}").collect()
    ```
    will return an empty result. The correct result is either:
    - Failing the query if the Parquet reader doesn't support upcasting integers to longs (all Parquet readers in Spark today)
    - Returning the result `[0]` if the Parquet reader supports that upcast (no readers in Spark as of now, but I'm looking into adding this capability).
    
    ### Does this PR introduce _any_ user-facing change?
    The following:
    ```
    Seq(0).toDF("a").write.parquet(path)
    spark.read.schema("a LONG").parquet(path).where(s"a < ${Long.MaxValue}").collect()
    ```
    produces an (incorrect) empty result before this change. After this change, the read will fail, raising an error about the unsupported conversion from INT to LONG in the Parquet reader.
    
    ### How was this patch tested?
    - Added tests to `ParquetFilterSuite` to ensure that no row group filter is created when the predicate value overflows or when the value type isn't compatible with the Parquet type.
    - Added a test to `ParquetQuerySuite` covering the correctness issue described above.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #44156 from johanl-db/SPARK-46092-row-group-skipping-overflow-3.3.
    
    Authored-by: Johan Lasperas <johan.laspe...@databricks.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../datasources/parquet/ParquetFilters.scala       | 10 ++-
 .../datasources/parquet/ParquetFilterSuite.scala   | 71 ++++++++++++++++++++++
 .../datasources/parquet/ParquetQuerySuite.scala    | 20 ++++++
 3 files changed, 99 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index 210f37d473ad..969fbab746ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.parquet
 
-import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong}
+import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort}
 import java.math.{BigDecimal => JBigDecimal}
 import java.nio.charset.StandardCharsets.UTF_8
 import java.sql.{Date, Timestamp}
@@ -600,7 +600,13 @@ class ParquetFilters(
     value == null || (nameToParquetField(name).fieldType match {
       case ParquetBooleanType => value.isInstanceOf[JBoolean]
       case ParquetIntegerType if value.isInstanceOf[Period] => true
-      case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number]
+      case ParquetByteType | ParquetShortType | ParquetIntegerType => value match {
+        // Byte/Short/Int are all stored as INT32 in Parquet so filters are built using type Int.
+        // We don't create a filter if the value would overflow.
+        case _: JByte | _: JShort | _: Integer => true
+        case v: JLong => v.longValue() >= Int.MinValue && v.longValue() <= Int.MaxValue
+        case _ => false
+      }
       case ParquetLongType => value.isInstanceOf[JLong] || value.isInstanceOf[Duration]
       case ParquetFloatType => value.isInstanceOf[JFloat]
       case ParquetDoubleType => value.isInstanceOf[JDouble]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index f291e1e71f6c..6fe47fc4132e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.datasources.parquet
 
 import java.io.File
+import java.lang.{Double => JDouble, Float => JFloat, Long => JLong}
 import java.math.{BigDecimal => JBigDecimal}
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
@@ -897,6 +898,76 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
     }
   }
 
+  test("don't push down filters that would result in overflows") {
+    val schema = StructType(Seq(
+      StructField("cbyte", ByteType),
+      StructField("cshort", ShortType),
+      StructField("cint", IntegerType)
+    ))
+
+    val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema)
+    val parquetFilters = createParquetFilters(parquetSchema)
+
+    for {
+      column <- Seq("cbyte", "cshort", "cint")
+      value <- Seq(JLong.MAX_VALUE, JLong.MIN_VALUE).map(JLong.valueOf)
+    } {
+      val filters = Seq(
+        sources.LessThan(column, value),
+        sources.LessThanOrEqual(column, value),
+        sources.GreaterThan(column, value),
+        sources.GreaterThanOrEqual(column, value),
+        sources.EqualTo(column, value),
+        sources.EqualNullSafe(column, value),
+        sources.Not(sources.EqualTo(column, value)),
+        sources.In(column, Array(value))
+      )
+      for (filter <- filters) {
+        assert(parquetFilters.createFilter(filter).isEmpty,
+          s"Row group filter $filter shouldn't be pushed down.")
+      }
+    }
+  }
+
+  test("don't push down filters when value type doesn't match column type") {
+    val schema = StructType(Seq(
+      StructField("cbyte", ByteType),
+      StructField("cshort", ShortType),
+      StructField("cint", IntegerType),
+      StructField("clong", LongType),
+      StructField("cfloat", FloatType),
+      StructField("cdouble", DoubleType),
+      StructField("cboolean", BooleanType),
+      StructField("cstring", StringType),
+      StructField("cdate", DateType),
+      StructField("ctimestamp", TimestampType),
+      StructField("cbinary", BinaryType),
+      StructField("cdecimal", DecimalType(10, 0))
+    ))
+
+    val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema)
+    val parquetFilters = createParquetFilters(parquetSchema)
+
+    val filters = Seq(
+      sources.LessThan("cbyte", String.valueOf("1")),
+      sources.LessThan("cshort", JBigDecimal.valueOf(1)),
+      sources.LessThan("cint", JFloat.valueOf(JFloat.NaN)),
+      sources.LessThan("clong", String.valueOf("1")),
+      sources.LessThan("cfloat", JDouble.valueOf(1.0D)),
+      sources.LessThan("cdouble", JFloat.valueOf(1.0F)),
+      sources.LessThan("cboolean", String.valueOf("true")),
+      sources.LessThan("cstring", Integer.valueOf(1)),
+      sources.LessThan("cdate", Timestamp.valueOf("2018-01-01 00:00:00")),
+      sources.LessThan("ctimestamp", Date.valueOf("2018-01-01")),
+      sources.LessThan("cbinary", Integer.valueOf(1)),
+      sources.LessThan("cdecimal", Integer.valueOf(1234))
+    )
+    for (filter <- filters) {
+      assert(parquetFilters.createFilter(filter).isEmpty,
+        s"Row group filter $filter shouldn't be pushed down.")
+    }
+  }
+
   test("SPARK-6554: don't push down predicates which reference partition 
columns") {
     import testImplicits._
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index 51de8fa04c6d..8ccbec829eed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -1001,6 +1001,26 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
     }
   }
 
+  test("row group skipping doesn't overflow when reading into larger type") {
+    withTempPath { path =>
+      Seq(0).toDF("a").write.parquet(path.toString)
+      // The vectorized and non-vectorized readers will produce different exceptions, we don't need
+      // to test both as this covers row group skipping.
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
+        // Reading integer 'a' as a long isn't supported. Check that an exception is raised instead
+        // of incorrectly skipping the single row group and producing incorrect results.
+        val exception = intercept[SparkException] {
+          spark.read
+            .schema("a LONG")
+            .parquet(path.toString)
+            .where(s"a < ${Long.MaxValue}")
+            .collect()
+        }
+        assert(exception.getCause.getCause.isInstanceOf[SchemaColumnConvertNotSupportedException])
+      }
+    }
+  }
+
   test("SPARK-36825, SPARK-36852: create table with ANSI intervals") {
     withTable("tbl") {
       sql("create table tbl (c1 interval day, c2 interval year to month) using 
parquet")

