Github user wangyum commented on a diff in the pull request:
https://github.com/apache/spark/pull/21603#discussion_r202500542
--- Diff: sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala ---
@@ -747,6 +748,66 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
     // Test inverseCanDrop() has taken effect
     testStringStartsWith(spark.range(1024).map(c => "100").toDF(), "value not like '10%'")
   }
+
+  test("SPARK-17091: Convert IN predicate to Parquet filter push-down") {
+    val schema = StructType(Seq(
+      StructField("a", IntegerType, nullable = false)
+    ))
+
+    val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema)
+
+    assertResult(Some(FilterApi.eq(intColumn("a"), null: Integer))) {
+      parquetFilters.createFilter(parquetSchema, sources.In("a", Array(null)))
+    }
+
+    assertResult(Some(FilterApi.eq(intColumn("a"), 10: Integer))) {
+      parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10)))
+    }
+
+    // Remove duplicates
+    assertResult(Some(FilterApi.eq(intColumn("a"), 10: Integer))) {
+      parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10, 10)))
+    }
+
+    assertResult(Some(or(
+      FilterApi.eq(intColumn("a"), 10: Integer),
+      FilterApi.eq(intColumn("a"), 20: Integer)))
+    ) {
+      parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10, 20)))
+    }
+
+    assertResult(Some(or(or(
+      FilterApi.eq(intColumn("a"), 10: Integer),
+      FilterApi.eq(intColumn("a"), 20: Integer)),
+      FilterApi.eq(intColumn("a"), 30: Integer)))
+    ) {
+      parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10, 20, 30)))
+    }
+
+    assert(parquetFilters.createFilter(parquetSchema, sources.In("a",
+      Range(0, conf.parquetFilterPushDownInFilterThreshold).toArray)).isDefined)
+    assert(parquetFilters.createFilter(parquetSchema, sources.In("a",
+      Range(0, conf.parquetFilterPushDownInFilterThreshold + 1).toArray)).isEmpty)
+
+    import testImplicits._
+    withTempPath { path =>
+      (0 to 1024).toDF("a").selectExpr("if (a = 1024, null, a) AS a") // convert 1024 to null
+        .coalesce(1).write.option("parquet.block.size", 512)
+        .parquet(path.getAbsolutePath)
+      val df = spark.read.parquet(path.getAbsolutePath)
+      Seq(true, false).foreach { pushEnabled =>
--- End diff ---
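For reference, the shape the or-based assertions above expect is a left-deep chain: each IN value becomes an eq predicate, and the predicates are folded pairwise with or. A minimal sketch against parquet-mr's FilterApi; `inToOrFilter` is a hypothetical helper, not code from this PR:

```scala
import java.lang.{Integer => JInt}

import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}
import org.apache.parquet.filter2.predicate.FilterApi.{intColumn, or}

// Hypothetical helper mirroring the conversion under test: each IN value
// becomes eq(column, value), and the results are folded left-to-right into
// or(), so Seq(10, 20, 30) yields or(or(eq(a, 10), eq(a, 20)), eq(a, 30)).
// Assumes `values` is non-empty and already de-duplicated, as the test
// expects of the real conversion.
def inToOrFilter(column: String, values: Seq[JInt]): FilterPredicate =
  values
    .map(v => FilterApi.eq(intColumn(column), v): FilterPredicate)
    .reduceLeft(or(_, _))

// Same shape as the three-value assertResult above:
//   inToOrFilter("a", Seq[JInt](10, 20, 30))
```

assertResult can compare these predicates directly because parquet-mr's operator classes define structural equality.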
Updated to:
```scala
val actual = stripSparkFilter(df.where(filter)).collect().length
if (pushEnabled && count <= conf.parquetFilterPushDownInFilterThreshold) {
  assert(actual > 1 && actual < data.length)
} else {
  assert(actual === data.length)
}
```
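For context on that assertion: stripSparkFilter removes Spark's own Filter node, so the surviving row count reflects only what Parquet row-group pruning let through. With push-down enabled and the IN list within the threshold, some row groups are skipped (actual < data.length) but the pruning is not exact (actual > 1); otherwise every row comes back. A minimal sketch of toggling the two confs involved; the threshold key is my assumption of the conf this PR introduces for SPARK-17091:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()

// Existing switch for Parquet predicate push-down.
spark.conf.set("spark.sql.parquet.filterPushdown", "true")
// Assumed key for the threshold added here: IN lists up to this many
// values are rewritten into the eq/or chain sketched above.
spark.conf.set("spark.sql.parquet.pushdown.inFilterThreshold", "10")

val df = spark.read.parquet("/tmp/t")  // hypothetical dataset with column `a`
df.where("a in (10, 20, 30)").count()  // eligible: 3 values <= threshold

// Above the threshold the source filter is not converted, Parquet scans
// every row group, and Spark alone evaluates the IN predicate.
spark.conf.set("spark.sql.parquet.pushdown.inFilterThreshold", "2")
df.where("a in (10, 20, 30)").count()
```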
---