Github user wangyum commented on a diff in the pull request:

https://github.com/apache/spark/pull/21603#discussion_r202500542

--- Diff: sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala ---

```diff
@@ -747,6 +748,66 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
     // Test inverseCanDrop() has taken effect
     testStringStartsWith(spark.range(1024).map(c => "100").toDF(), "value not like '10%'")
   }
+
+  test("SPARK-17091: Convert IN predicate to Parquet filter push-down") {
+    val schema = StructType(Seq(
+      StructField("a", IntegerType, nullable = false)
+    ))
+
+    val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema)
+
+    assertResult(Some(FilterApi.eq(intColumn("a"), null: Integer))) {
+      parquetFilters.createFilter(parquetSchema, sources.In("a", Array(null)))
+    }
+
+    assertResult(Some(FilterApi.eq(intColumn("a"), 10: Integer))) {
+      parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10)))
+    }
+
+    // Remove duplicates
+    assertResult(Some(FilterApi.eq(intColumn("a"), 10: Integer))) {
+      parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10, 10)))
+    }
+
+    assertResult(Some(or(
+      FilterApi.eq(intColumn("a"), 10: Integer),
+      FilterApi.eq(intColumn("a"), 20: Integer)))
+    ) {
+      parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10, 20)))
+    }
+
+    assertResult(Some(or(or(
+      FilterApi.eq(intColumn("a"), 10: Integer),
+      FilterApi.eq(intColumn("a"), 20: Integer)),
+      FilterApi.eq(intColumn("a"), 30: Integer)))
+    ) {
+      parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10, 20, 30)))
+    }
+
+    assert(parquetFilters.createFilter(parquetSchema, sources.In("a",
+      Range(0, conf.parquetFilterPushDownInFilterThreshold).toArray)).isDefined)
+    assert(parquetFilters.createFilter(parquetSchema, sources.In("a",
+      Range(0, conf.parquetFilterPushDownInFilterThreshold + 1).toArray)).isEmpty)
+
+    import testImplicits._
+    withTempPath { path =>
+      (0 to 1024).toDF("a").selectExpr("if (a = 1024, null, a) AS a") // convert 1024 to null
+        .coalesce(1).write.option("parquet.block.size", 512)
+        .parquet(path.getAbsolutePath)
+      val df = spark.read.parquet(path.getAbsolutePath)
+      Seq(true, false).foreach { pushEnabled =>
```
--- End diff --

Updated to:

```scala
val actual = stripSparkFilter(df.where(filter)).collect().length
if (pushEnabled && count <= conf.parquetFilterPushDownInFilterThreshold) {
  assert(actual > 1 && actual < data.length)
} else {
  assert(actual === data.length)
}
```
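For context, the conversion these assertions exercise can be sketched roughly as below. This is a minimal illustration, not the PR's actual `ParquetFilters` code: the object `InToOrChain`, its `threshold` constant, and the `convert` helper are hypothetical names standing in for the real implementation and for `conf.parquetFilterPushDownInFilterThreshold`; only `FilterApi.eq`, `FilterApi.or`, and `intColumn` come from parquet-mr, and the predicate shape matches what the assertions above expect.

```scala
// A minimal sketch (assumed names, not the PR's implementation) of turning an
// IN predicate over an INT column into the left-nested OR chain the test
// asserts: In("a", Array(10, 20, 30)) => or(or(eq(a, 10), eq(a, 20)), eq(a, 30)).
import org.apache.parquet.filter2.predicate.FilterApi.intColumn
import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}

object InToOrChain {
  // Hypothetical stand-in for conf.parquetFilterPushDownInFilterThreshold.
  val threshold: Int = 10

  def convert(name: String, values: Array[Integer]): Option[FilterPredicate] = {
    val distinctValues = values.distinct // duplicates collapse to a single eq, as tested above
    if (distinctValues.length > threshold) {
      None // too many values: skip push-down rather than build a huge predicate
    } else {
      val column = intColumn(name)
      distinctValues
        .map(v => FilterApi.eq(column, v): FilterPredicate)
        .reduceOption(FilterApi.or(_, _)) // left-associative fold: or(or(eq, eq), eq)
    }
  }
}
```

Under these assumptions, `InToOrChain.convert("a", Array(10, 20, 30))` yields `Some(or(or(eq(a, 10), eq(a, 20)), eq(a, 30)))`, the same left-nested structure the `assertResult` calls above check, and an empty result once the value count exceeds the threshold.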