Github user wangyum commented on a diff in the pull request:
https://github.com/apache/spark/pull/21556#discussion_r199442189
--- Diff:
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
---
@@ -359,6 +369,70 @@ class ParquetFilterSuite extends QueryTest with
ParquetTest with SharedSQLContex
}
}
+ test("filter pushdown - decimal") {
+ Seq(true, false).foreach { legacyFormat =>
+ withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key ->
legacyFormat.toString) {
+ Seq(s"_1 decimal(${Decimal.MAX_INT_DIGITS}, 2)", //
32BitDecimalType
+ s"_1 decimal(${Decimal.MAX_LONG_DIGITS}, 2)", //
64BitDecimalType
+ "_1 decimal(38, 18)" //
ByteArrayDecimalType
+ ).foreach { schemaDDL =>
+ val schema = StructType.fromDDL(schemaDDL)
+ val rdd =
+ spark.sparkContext.parallelize((1 to 4).map(i => Row(new
java.math.BigDecimal(i))))
+ val dataFrame = spark.createDataFrame(rdd, schema)
+ testDecimalPushDown(dataFrame) { implicit df =>
+ assert(df.schema === schema)
+ checkFilterPredicate('_1.isNull, classOf[Eq[_]],
Seq.empty[Row])
+ checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to
4).map(Row.apply(_)))
+
+ checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
+ checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1)
+ checkFilterPredicate('_1 =!= 1, classOf[NotEq[_]], (2 to
4).map(Row.apply(_)))
+
+ checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
+ checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
+ checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
+ checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
+
+ checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1)
+ checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1)
+ checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1)
+ checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4)
+ checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
+ checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
+
+ checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
+ checkFilterPredicate('_1 < 2 || '_1 > 3,
classOf[Operators.Or], Seq(Row(1), Row(4)))
+ }
+ }
+ }
+ }
+ }
+
+ test("incompatible parquet file format will throw exeception") {
--- End diff --
I have created a PR: https://github.com/apache/spark/pull/21696
After that PR, decimal pushdown support should look like this:
https://github.com/wangyum/spark/blob/refactor-decimal-pushdown/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala#L118-L146
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]