Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/21603#discussion_r202506240
--- Diff:
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
---
@@ -803,6 +804,67 @@ class ParquetFilterSuite extends QueryTest with
ParquetTest with SharedSQLContex
// Test inverseCanDrop() has taken effect
testStringStartsWith(spark.range(1024).map(c => "100").toDF(), "value
not like '10%'")
}
+
+ test("SPARK-17091: Convert IN predicate to Parquet filter push-down") {
+ val schema = StructType(Seq(
+ StructField("a", IntegerType, nullable = false)
+ ))
+
+ val parquetSchema = new
SparkToParquetSchemaConverter(conf).convert(schema)
+
+ assertResult(Some(FilterApi.eq(intColumn("a"), null: Integer))) {
+ parquetFilters.createFilter(parquetSchema, sources.In("a",
Array(null)))
+ }
+
+ assertResult(Some(FilterApi.eq(intColumn("a"), 10: Integer))) {
+ parquetFilters.createFilter(parquetSchema, sources.In("a",
Array(10)))
+ }
+
+ // Remove duplicates
+ assertResult(Some(FilterApi.eq(intColumn("a"), 10: Integer))) {
+ parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10,
10)))
+ }
+
+ assertResult(Some(or(or(
+ FilterApi.eq(intColumn("a"), 10: Integer),
+ FilterApi.eq(intColumn("a"), 20: Integer)),
+ FilterApi.eq(intColumn("a"), 30: Integer)))
+ ) {
+ parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10,
20, 30)))
+ }
+
+ assert(parquetFilters.createFilter(parquetSchema, sources.In("a",
+ Range(0,
conf.parquetFilterPushDownInFilterThreshold).toArray)).isDefined)
+ assert(parquetFilters.createFilter(parquetSchema, sources.In("a",
+ Range(0, conf.parquetFilterPushDownInFilterThreshold +
1).toArray)).isEmpty)
+
+ import testImplicits._
+ withTempPath { path =>
+ val data = 0 to 1024
+ data.toDF("a").selectExpr("if (a = 1024, null, a) AS a") // convert
1024 to null
+ .coalesce(1).write.option("parquet.block.size", 512)
+ .parquet(path.getAbsolutePath)
+ val df = spark.read.parquet(path.getAbsolutePath)
+ Seq(true, false).foreach { pushEnabled =>
+ withSQLConf(
+ SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key ->
pushEnabled.toString) {
+ Seq(1, 5, 10, 11).foreach { count =>
+ val filter = s"a in(${Range(0, count).mkString(",")})"
+ assert(df.where(filter).count() === count)
+ val actual =
stripSparkFilter(df.where(filter)).collect().length
+ if (pushEnabled && count <=
conf.parquetFilterPushDownInFilterThreshold) {
+ assert(actual > 1 && actual < data.length)
--- End diff --
Ah, okay, this tests block-level filtering. LGTM.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]