This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git

The following commit(s) were added to refs/heads/main by this push:
     new c4c998cf8  feat: pushdown filter for native_iceberg_compat (#1566)
c4c998cf8 is described below

commit c4c998cf88d39ece09e0d9cc4be39b749bd26072
Author: Zhen Wang <643348...@qq.com>
AuthorDate: Tue Apr 1 05:59:22 2025 +0800

    feat: pushdown filter for native_iceberg_compat (#1566)

    * feat: pushdown filter for native_iceberg_compat
    * fix style
    * add data schema
    * fix filter bound
    * fix in expr
    * add primitive type tests
    * enable native_datafusion test
---
 .../main/java/org/apache/comet/parquet/Native.java |   2 +
 .../apache/comet/parquet/NativeBatchReader.java    |  29 +++-
 native/core/src/execution/planner.rs               |   2 +-
 native/core/src/parquet/mod.rs                     |  25 ++-
 .../comet/parquet/CometParquetFileFormat.scala     |  51 +++---
 .../CometParquetPartitionReaderFactory.scala       |   1 +
 .../org/apache/comet/parquet/ParquetFilters.scala  | 156 ++++++++++++++++++
 .../apache/comet/parquet/SourceFilterSerde.scala   | 175 +++++++++++++++++++++
 .../apache/comet/parquet/ParquetReadSuite.scala    |  71 ++++++++-
 .../scala/org/apache/spark/sql/CometTestBase.scala |  11 +-
 10 files changed, 490 insertions(+), 33 deletions(-)

diff --git a/common/src/main/java/org/apache/comet/parquet/Native.java b/common/src/main/java/org/apache/comet/parquet/Native.java
index d0b0f951d..4e50190f8 100644
--- a/common/src/main/java/org/apache/comet/parquet/Native.java
+++ b/common/src/main/java/org/apache/comet/parquet/Native.java
@@ -253,7 +253,9 @@ public final class Native extends NativeBase {
       long fileSize,
       long start,
       long length,
+      byte[] filter,
       byte[] requiredSchema,
+      byte[] dataSchema,
       String sessionTimezone);
 
   // arrow native version of read batch
diff --git a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java
index 0051b412c..4f6991c5d 100644
--- a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java
+++ b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java
@@ -108,6 +108,7 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
   private final Map<String, SQLMetric> metrics;
 
   private StructType sparkSchema;
+  private StructType dataSchema;
   private MessageType requestedSchema;
   private CometVector[] vectors;
   private AbstractColumnReader[] columnReaders;
@@ -117,6 +118,7 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
   private boolean[] missingColumns;
   private boolean isInitialized;
   private ParquetMetadata footer;
+  private byte[] nativeFilter;
 
   /**
    * Whether the native scan should always return decimal represented by 128 bits, regardless of its
@@ -190,8 +192,10 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
       Configuration conf,
       PartitionedFile inputSplit,
       ParquetMetadata footer,
+      byte[] nativeFilter,
       int capacity,
       StructType sparkSchema,
+      StructType dataSchema,
       boolean isCaseSensitive,
       boolean useFieldId,
       boolean ignoreMissingIds,
@@ -202,6 +206,7 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
     this.conf = conf;
     this.capacity = capacity;
     this.sparkSchema = sparkSchema;
+    this.dataSchema = dataSchema;
     this.isCaseSensitive = isCaseSensitive;
     this.useFieldId = useFieldId;
     this.ignoreMissingIds = ignoreMissingIds;
@@ -210,6 +215,7 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
    this.partitionValues = partitionValues;
     this.file = inputSplit;
     this.footer = footer;
+    this.nativeFilter = nativeFilter;
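[Editor's note] The two new byte[] parameters carry (1) a serialized protobuf filter expression and (2) the Arrow IPC encoding of the file's data schema. A minimal Scala sketch of the Arrow IPC schema serialization step, assuming the Arrow Java vector artifacts are on the classpath; the object name and example schema are stand-ins, since the reader itself builds its schemas with Utils.toArrowSchema(sparkSchema, "UTC") and the serializeArrowSchema helper added further down in this diff:

    import java.io.ByteArrayOutputStream
    import java.nio.channels.Channels

    import scala.collection.JavaConverters._

    import org.apache.arrow.vector.ipc.WriteChannel
    import org.apache.arrow.vector.ipc.message.MessageSerializer
    import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

    object ArrowSchemaSerdeSketch {
      // Serialize an Arrow schema to the IPC message bytes the native reader deserializes.
      def serialize(schema: Schema): Array[Byte] = {
        val out = new ByteArrayOutputStream()
        MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), schema)
        out.toByteArray
      }

      def main(args: Array[String]): Unit = {
        // Stand-in schema with a single nullable 64-bit integer column.
        val field = new Field("id", FieldType.nullable(new ArrowType.Int(64, true)), null)
        val schema = new Schema(Seq(field).asJava)
        println(s"serialized ${serialize(schema).length} bytes")
      }
    }

The same bytes are what deserialize_schema reconstructs on the Rust side of the JNI call.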
this.metrics = metrics; this.taskContext = TaskContext$.MODULE$.get(); } @@ -262,10 +268,9 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme String timeZoneId = conf.get("spark.sql.session.timeZone"); // Native code uses "UTC" always as the timeZoneId when converting from spark to arrow schema. Schema arrowSchema = Utils$.MODULE$.toArrowSchema(sparkSchema, "UTC"); - ByteArrayOutputStream out = new ByteArrayOutputStream(); - WriteChannel writeChannel = new WriteChannel(Channels.newChannel(out)); - MessageSerializer.serialize(writeChannel, arrowSchema); - byte[] serializedRequestedArrowSchema = out.toByteArray(); + byte[] serializedRequestedArrowSchema = serializeArrowSchema(arrowSchema); + Schema dataArrowSchema = Utils$.MODULE$.toArrowSchema(dataSchema, "UTC"); + byte[] serializedDataArrowSchema = serializeArrowSchema(dataArrowSchema); //// Create Column readers List<ColumnDescriptor> columns = requestedSchema.getColumns(); @@ -350,7 +355,14 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme this.handle = Native.initRecordBatchReader( - filePath, fileSize, start, length, serializedRequestedArrowSchema, timeZoneId); + filePath, + fileSize, + start, + length, + nativeFilter, + serializedRequestedArrowSchema, + serializedDataArrowSchema, + timeZoneId); isInitialized = true; } @@ -524,4 +536,11 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme return Option.apply(null); // None } } + + private byte[] serializeArrowSchema(Schema schema) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + WriteChannel writeChannel = new WriteChannel(Channels.newChannel(out)); + MessageSerializer.serialize(writeChannel, schema); + return out.toByteArray(); + } } diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 60803dfeb..851237a9a 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -230,7 +230,7 @@ impl PhysicalPlanner { } /// Create a DataFusion physical expression from Spark physical expression - fn create_expr( + pub(crate) fn create_expr( &self, spark_expr: &Expr, input_schema: SchemaRef, diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs index bcf73fc54..9289d9d42 100644 --- a/native/core/src/parquet/mod.rs +++ b/native/core/src/parquet/mod.rs @@ -45,6 +45,8 @@ use jni::{ use self::util::jni::TypePromotionInfo; use crate::execution::operators::ExecutionError; +use crate::execution::planner::PhysicalPlanner; +use crate::execution::serde; use crate::execution::utils::SparkArrowConvert; use crate::parquet::data_type::AsBytes; use crate::parquet::parquet_exec::init_datasource_exec; @@ -644,7 +646,9 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat file_size: jlong, start: jlong, length: jlong, + filter: jbyteArray, required_schema: jbyteArray, + data_schema: jbyteArray, session_timezone: jstring, ) -> jlong { try_unwrap_or_throw(&e, |mut env| unsafe { @@ -666,6 +670,23 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat let required_schema_buffer = env.convert_byte_array(&required_schema_array)?; let required_schema = Arc::new(deserialize_schema(required_schema_buffer.as_bytes())?); + let data_schema_array = JByteArray::from_raw(data_schema); + let data_schema_buffer = env.convert_byte_array(&data_schema_array)?; + let data_schema = Arc::new(deserialize_schema(data_schema_buffer.as_bytes())?); + + let 
planer = PhysicalPlanner::default(); + + let data_filters = if !filter.is_null() { + let filter_array = JByteArray::from_raw(filter); + let filter_buffer = env.convert_byte_array(&filter_array)?; + let filter_expr = serde::deserialize_expr(filter_buffer.as_slice())?; + Some(vec![ + planer.create_expr(&filter_expr, Arc::clone(&data_schema))? + ]) + } else { + None + }; + let file_groups = get_file_groups_single_file(&object_store_path, file_size as u64, start, length); @@ -676,13 +697,13 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat let scan = init_datasource_exec( required_schema, - None, + Some(data_schema), None, None, object_store_url, file_groups, None, - None, + data_filters, session_timezone.as_str(), )?; diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala index a3e917f79..b67f99ad8 100644 --- a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala +++ b/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala @@ -114,36 +114,33 @@ class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with footerFileMetaData, datetimeRebaseModeInRead) - val pushed = if (parquetFilterPushDown) { - val parquetSchema = footerFileMetaData.getSchema - val parquetFilters = new ParquetFilters( - parquetSchema, - pushDownDate, - pushDownTimestamp, - pushDownDecimal, - pushDownStringPredicate, - pushDownInFilterThreshold, - isCaseSensitive, - datetimeRebaseSpec) - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can - // be converted (`ParquetFilters.createFilter` returns an `Option`). That's why a - // `flatMap` is used here. - .flatMap(parquetFilters.createFilter) - .reduceOption(FilterApi.and) - } else { - None - } - pushed.foreach(p => ParquetInputFormat.setFilterPredicate(sharedConf, p)) + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = new ParquetFilters( + parquetSchema, + dataSchema, + pushDownDate, + pushDownTimestamp, + pushDownDecimal, + pushDownStringPredicate, + pushDownInFilterThreshold, + isCaseSensitive, + datetimeRebaseSpec) val recordBatchReader = if (nativeIcebergCompat) { + val pushed = if (parquetFilterPushDown) { + parquetFilters.createNativeFilters(filters) + } else { + None + } val batchReader = new NativeBatchReader( sharedConf, file, footer, + pushed.orNull, capacity, requiredSchema, + dataSchema, isCaseSensitive, useFieldId, ignoreMissingIds, @@ -160,6 +157,18 @@ class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with } batchReader } else { + val pushed = if (parquetFilterPushDown) { + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates + // can be converted (`ParquetFilters.createFilter` returns an `Option`). That's why + // a `flatMap` is used here. 
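[Editor's note] On the native side, the deserialized filter is planned with PhysicalPlanner::create_expr against the deserialized data schema and both are handed to init_datasource_exec, so the ordinals inside the filter refer to positions in the file's data schema rather than the projected schema. That is why the Scala side resolves column names against dataSchema when it builds the filter. A small illustrative sketch (the schema and column names are hypothetical):

    import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}

    // Hypothetical file schema: each field's ordinal is what ends up in the filter's BoundReference.
    val dataSchema = StructType(Seq(
      StructField("id", LongType),       // ordinal 0
      StructField("name", StringType),   // ordinal 1
      StructField("qty", IntegerType)))  // ordinal 2

    // SourceFilterSerde.createNameExpr("qty", dataSchema) resolves the field by name and emits a
    // BoundReference with index 2, which the native planner then binds against the same data schema.
    val ordinal = dataSchema.fieldIndex("qty") // 2

createNameExpr in SourceFilterSerde (added below) performs exactly this name-to-ordinal resolution before emitting the BoundReference.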
+ .flatMap(parquetFilters.createFilter) + .reduceOption(FilterApi.and) + } else { + None + } + pushed.foreach(p => ParquetInputFormat.setFilterPredicate(sharedConf, p)) + val batchReader = new BatchReader( sharedConf, file, diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala b/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala index 01ccde94b..4dc099735 100644 --- a/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala +++ b/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala @@ -199,6 +199,7 @@ case class CometParquetPartitionReaderFactory( val parquetSchema = footerFileMetaData.getSchema val parquetFilters = new ParquetFilters( parquetSchema, + readDataSchema, pushDownDate, pushDownTimestamp, pushDownDecimal, diff --git a/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala index 30d7804e8..54db9308c 100644 --- a/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala +++ b/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala @@ -19,12 +19,14 @@ package org.apache.comet.parquet +import java.io.ByteArrayOutputStream import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort} import java.math.{BigDecimal => JBigDecimal} import java.sql.{Date, Timestamp} import java.time.{Duration, Instant, LocalDate, Period} import java.util.Locale +import scala.collection.JavaConverters._ import scala.collection.JavaConverters.asScalaBufferConverter import org.apache.parquet.column.statistics.{Statistics => ParquetStatistics} @@ -39,8 +41,11 @@ import org.apache.parquet.schema.Type.Repetition import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, CaseInsensitiveMap, DateTimeUtils, IntervalUtils} import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, RebaseSpec} import org.apache.spark.sql.sources +import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String +import org.apache.comet.parquet.SourceFilterSerde.{createBinaryExpr, createNameExpr, createUnaryExpr, createValueExpr} +import org.apache.comet.serde.ExprOuterClass import org.apache.comet.shims.ShimSQLConf /** @@ -51,6 +56,7 @@ import org.apache.comet.shims.ShimSQLConf */ class ParquetFilters( schema: MessageType, + dataSchema: StructType, pushDownDate: Boolean, pushDownTimestamp: Boolean, pushDownDecimal: Boolean, @@ -876,4 +882,154 @@ class ParquetFilters( case _ => None } } + + def createNativeFilters(predicates: Seq[sources.Filter]): Option[Array[Byte]] = { + predicates.reduceOption(sources.And).flatMap(createNativeFilter).map { expr => + val outputStream = new ByteArrayOutputStream() + expr.writeTo(outputStream) + outputStream.close() + outputStream.toByteArray + } + } + + private def createNativeFilter(predicate: sources.Filter): Option[ExprOuterClass.Expr] = { + def nameUnaryExpr(name: String)( + f: (ExprOuterClass.Expr.Builder, ExprOuterClass.UnaryExpr) => ExprOuterClass.Expr.Builder) + : Option[ExprOuterClass.Expr] = { + createNameExpr(name, dataSchema).map { case (_, childExpr) => + createUnaryExpr(childExpr, f) + } + } + + def nameValueBinaryExpr(name: String, value: Any)( + f: ( + ExprOuterClass.Expr.Builder, + ExprOuterClass.BinaryExpr) => ExprOuterClass.Expr.Builder) + : Option[ExprOuterClass.Expr] = { + createNameExpr(name, 
dataSchema).flatMap { case (dataType, childExpr) => + createValueExpr(value, dataType).map(createBinaryExpr(childExpr, _, f)) + } + } + + predicate match { + case sources.IsNull(name) if canMakeFilterOn(name, null) => + nameUnaryExpr(name) { (builder, unaryExpr) => + builder.setIsNull(unaryExpr) + } + case sources.IsNotNull(name) if canMakeFilterOn(name, null) => + nameUnaryExpr(name) { (builder, unaryExpr) => + builder.setIsNotNull(unaryExpr) + } + + case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => + nameValueBinaryExpr(name, value) { (builder, binaryExpr) => + builder.setEq(binaryExpr) + } + + case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => + nameValueBinaryExpr(name, value) { (builder, binaryExpr) => + builder.setNeq(binaryExpr) + } + + case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => + nameValueBinaryExpr(name, value) { (builder, binaryExpr) => + builder.setEqNullSafe(binaryExpr) + } + + case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => + nameValueBinaryExpr(name, value) { (builder, binaryExpr) => + builder.setNeqNullSafe(binaryExpr) + } + + case sources.LessThan(name, value) if (value != null) && canMakeFilterOn(name, value) => + nameValueBinaryExpr(name, value) { (builder, binaryExpr) => + builder.setLt(binaryExpr) + } + + case sources.LessThanOrEqual(name, value) + if (value != null) && canMakeFilterOn(name, value) => + nameValueBinaryExpr(name, value) { (builder, binaryExpr) => + builder.setLtEq(binaryExpr) + } + + case sources.GreaterThan(name, value) if (value != null) && canMakeFilterOn(name, value) => + nameValueBinaryExpr(name, value) { (builder, binaryExpr) => + builder.setGt(binaryExpr) + } + + case sources.GreaterThanOrEqual(name, value) + if (value != null) && canMakeFilterOn(name, value) => + nameValueBinaryExpr(name, value) { (builder, binaryExpr) => + builder.setGtEq(binaryExpr) + } + + case sources.And(lhs, rhs) => + (createNativeFilter(lhs), createNativeFilter(rhs)) match { + case (Some(leftExpr), Some(rightExpr)) => + Some( + createBinaryExpr( + leftExpr, + rightExpr, + (builder, binaryExpr) => builder.setAnd(binaryExpr))) + case _ => None + } + + case sources.Or(lhs, rhs) => + (createNativeFilter(lhs), createNativeFilter(rhs)) match { + case (Some(leftExpr), Some(rightExpr)) => + Some( + createBinaryExpr( + leftExpr, + rightExpr, + (builder, binaryExpr) => builder.setOr(binaryExpr))) + case _ => None + } + + case sources.Not(pred) => + val childExpr = createNativeFilter(pred) + childExpr.map { expr => + createUnaryExpr(expr, (builder, unaryExpr) => builder.setNot(unaryExpr)) + } + + case sources.In(name, values) + if pushDownInFilterThreshold > 0 && values.nonEmpty && + canMakeFilterOn(name, values.head) => + createNameExpr(name, dataSchema).flatMap { case (dataType, nameExpr) => + val valueExprs = values.flatMap(createValueExpr(_, dataType)) + if (valueExprs.length != values.length) { + None + } else { + val builder = ExprOuterClass.In.newBuilder() + builder.setInValue(nameExpr) + builder.addAllLists(valueExprs.toSeq.asJava) + builder.setNegated(false) + Some( + ExprOuterClass.Expr + .newBuilder() + .setIn(builder) + .build()) + } + } + + case sources.StringStartsWith(name, prefix) + if pushDownStringPredicate && canMakeFilterOn(name, prefix) => + nameValueBinaryExpr(name, prefix) { (builder, binaryExpr) => + builder.setStartsWith(binaryExpr) + } + + case sources.StringEndsWith(name, suffix) + if pushDownStringPredicate && canMakeFilterOn(name, 
suffix) => + nameValueBinaryExpr(name, suffix) { (builder, binaryExpr) => + builder.setEndsWith(binaryExpr) + } + + case sources.StringContains(name, value) + if pushDownStringPredicate && canMakeFilterOn(name, value) => + nameValueBinaryExpr(name, value) { (builder, binaryExpr) => + builder.setContains(binaryExpr) + } + + case _ => None + } + } } diff --git a/spark/src/main/scala/org/apache/comet/parquet/SourceFilterSerde.scala b/spark/src/main/scala/org/apache/comet/parquet/SourceFilterSerde.scala new file mode 100644 index 000000000..4ad467cd8 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/parquet/SourceFilterSerde.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.parquet + +import java.math.{BigDecimal => JavaBigDecimal} +import java.sql.{Date, Timestamp} +import java.time.{Instant, LocalDate, LocalDateTime} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ + +import org.apache.comet.serde.ExprOuterClass +import org.apache.comet.serde.ExprOuterClass.Expr +import org.apache.comet.serde.QueryPlanSerde.serializeDataType + +object SourceFilterSerde extends Logging { + + def createNameExpr( + name: String, + schema: StructType): Option[(DataType, ExprOuterClass.Expr)] = { + val filedWithIndex = schema.fields.zipWithIndex.find { case (field, _) => + field.name == name + } + if (filedWithIndex.isDefined) { + val (field, index) = filedWithIndex.get + val dataType = serializeDataType(field.dataType) + if (dataType.isDefined) { + val boundExpr = ExprOuterClass.BoundReference + .newBuilder() + .setIndex(index) + .setDatatype(dataType.get) + .build() + Some( + field.dataType, + ExprOuterClass.Expr + .newBuilder() + .setBound(boundExpr) + .build()) + } else { + None + } + } else { + None + } + + } + + /** + * create a literal value native expression for source filter value, the value is a scala value + */ + def createValueExpr(value: Any, dataType: DataType): Option[ExprOuterClass.Expr] = { + val exprBuilder = ExprOuterClass.Literal.newBuilder() + var valueIsSet = true + if (value == null) { + exprBuilder.setIsNull(true) + } else { + exprBuilder.setIsNull(false) + // value is a scala value, not a catalyst value + // refer to org.apache.spark.sql.catalyst.CatalystTypeConverters.CatalystTypeConverter#toScala + dataType match { + case _: BooleanType => exprBuilder.setBoolVal(value.asInstanceOf[Boolean]) + case _: ByteType => exprBuilder.setByteVal(value.asInstanceOf[Byte]) + case _: ShortType => exprBuilder.setShortVal(value.asInstanceOf[Short]) + case _: IntegerType => exprBuilder.setIntVal(value.asInstanceOf[Int]) + case _: LongType => exprBuilder.setLongVal(value.asInstanceOf[Long]) + case _: 
FloatType => exprBuilder.setFloatVal(value.asInstanceOf[Float]) + case _: DoubleType => exprBuilder.setDoubleVal(value.asInstanceOf[Double]) + case _: StringType => exprBuilder.setStringVal(value.asInstanceOf[String]) + case _: TimestampType => + value match { + case v: Timestamp => exprBuilder.setLongVal(DateTimeUtils.fromJavaTimestamp(v)) + case v: Instant => exprBuilder.setLongVal(DateTimeUtils.instantToMicros(v)) + case v: Long => exprBuilder.setLongVal(v) + case _ => + valueIsSet = false + logWarning(s"Unexpected timestamp type '${value.getClass}' for value '$value'") + } + case _: TimestampNTZType => + value match { + case v: LocalDateTime => + exprBuilder.setLongVal(DateTimeUtils.localDateTimeToMicros(v)) + case v: Long => exprBuilder.setLongVal(v) + case _ => + valueIsSet = false + logWarning(s"Unexpected timestamp type '${value.getClass}' for value' $value'") + } + case _: DecimalType => + // Pass decimal literal as bytes. + val unscaled = value.asInstanceOf[JavaBigDecimal].unscaledValue + exprBuilder.setDecimalVal(com.google.protobuf.ByteString.copyFrom(unscaled.toByteArray)) + case _: BinaryType => + val byteStr = + com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]]) + exprBuilder.setBytesVal(byteStr) + case _: DateType => + value match { + case v: LocalDate => exprBuilder.setIntVal(DateTimeUtils.localDateToDays(v)) + case v: Date => exprBuilder.setIntVal(DateTimeUtils.fromJavaDate(v)) + case v: Int => exprBuilder.setIntVal(v) + case _ => + valueIsSet = false + logWarning(s"Unexpected date type '${value.getClass}' for value '$value'") + } + case dt => + valueIsSet = false + logWarning(s"Unexpected data type '$dt' for literal value '$value'") + } + } + + val dt = serializeDataType(dataType) + + if (valueIsSet && dt.isDefined) { + exprBuilder.setDatatype(dt.get) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setLiteral(exprBuilder) + .build()) + } else { + None + } + } + + def createUnaryExpr( + childExpr: Expr, + f: (ExprOuterClass.Expr.Builder, ExprOuterClass.UnaryExpr) => ExprOuterClass.Expr.Builder) + : ExprOuterClass.Expr = { + // create the generic UnaryExpr message + val inner = ExprOuterClass.UnaryExpr + .newBuilder() + .setChild(childExpr) + .build() + f( + ExprOuterClass.Expr + .newBuilder(), + inner).build() + } + + def createBinaryExpr( + leftExpr: Expr, + rightExpr: Expr, + f: (ExprOuterClass.Expr.Builder, ExprOuterClass.BinaryExpr) => ExprOuterClass.Expr.Builder) + : ExprOuterClass.Expr = { + // create the generic BinaryExpr message + val inner = ExprOuterClass.BinaryExpr + .newBuilder() + .setLeft(leftExpr) + .setRight(rightExpr) + .build() + f( + ExprOuterClass.Expr + .newBuilder(), + inner).build() + } + +} diff --git a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala index f5d0cbb42..d810b4fb8 100644 --- a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala +++ b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala @@ -20,11 +20,12 @@ package org.apache.comet.parquet import java.io.{File, FileFilter} -import java.math.BigDecimal +import java.math.{BigDecimal, BigInteger} import java.time.{ZoneId, ZoneOffset} import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import scala.util.control.Breaks.{break, breakable} import org.scalactic.source.Position import org.scalatest.Tag @@ -1493,6 +1494,74 @@ class ParquetReadV1Suite extends ParquetReadSuite with AdaptiveSparkPlanHelper { }) } } + + 
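[Editor's note] Putting the pieces together: ParquetFilters.createNativeFilter maps each supported sources.Filter onto the ExprOuterClass protobuf using the SourceFilterSerde helpers above. A hedged sketch of what gets assembled for a single GreaterThan predicate (the schema and column name are illustrative, and the commit itself streams the message with writeTo rather than toByteArray):

    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    import org.apache.comet.parquet.SourceFilterSerde.{createBinaryExpr, createNameExpr, createValueExpr}

    // Illustrative data schema; in CometParquetFileFormat this is the table's dataSchema.
    val dataSchema = StructType(Seq(StructField("_4", IntegerType)))

    // Roughly what createNativeFilter assembles for sources.GreaterThan("_4", 10):
    val serialized: Option[Array[Byte]] =
      createNameExpr("_4", dataSchema).flatMap { case (dataType, nameExpr) =>
        createValueExpr(10, dataType).map { valueExpr =>
          createBinaryExpr(nameExpr, valueExpr, (builder, binaryExpr) => builder.setGt(binaryExpr))
            .toByteArray // protobuf bytes handed to Native.initRecordBatchReader
        }
      }

Because createNativeFilters first reduces all pushed predicates with sources.And, and an And only converts when both sides convert, the native reader either receives a filter it can evaluate completely or no filter at all.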
test("test V1 parquet scan filter pushdown of primitive types uses native_iceberg_compat") { + withTempPath { dir => + val path = new Path(dir.toURI.toString, "test1.parquet") + val rows = 1000 + withSQLConf( + CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_ICEBERG_COMPAT, + CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.key -> "false") { + makeParquetFileAllTypes(path, dictionaryEnabled = false, 0, rows, nullEnabled = false) + } + Seq( + (CometConf.SCAN_NATIVE_DATAFUSION, "output_rows"), + (CometConf.SCAN_NATIVE_ICEBERG_COMPAT, "numOutputRows")).foreach { + case (scanMode, metricKey) => + Seq(true, false).foreach { pushDown => + breakable { + withSQLConf( + CometConf.COMET_NATIVE_SCAN_IMPL.key -> scanMode, + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> pushDown.toString) { + if (scanMode == CometConf.SCAN_NATIVE_DATAFUSION && !pushDown) { + // FIXME: native_datafusion always pushdown data filters + break() + } + Seq( + ("_1 = true", Math.ceil(rows.toDouble / 2)), // Boolean + ("_2 = 1", Math.ceil(rows.toDouble / 256)), // Byte + ("_3 = 1", 1), // Short + ("_4 = 1", 1), // Integer + ("_5 = 1", 1), // Long + ("_6 = 1.0", 1), // Float + ("_7 = 1.0", 1), // Double + (s"_8 = '${1.toString * 48}'", 1), // String + ("_21 = to_binary('1', 'utf-8')", 1), // binary + ("_15 = 0.0", 1), // DECIMAL(5, 2) + ("_16 = 0.0", 1), // DECIMAL(18, 10) + ( + s"_17 = ${new BigDecimal(new BigInteger(("1" * 16).getBytes), 37).toString}", + Math.ceil(rows.toDouble / 10) + ), // DECIMAL(38, 37) + (s"_19 = TIMESTAMP '${DateTimeUtils.toJavaTimestamp(1)}'", 1), // Timestamp + ("_20 = DATE '1970-01-02'", 1) // Date + ).foreach { case (whereCause, expectedRows) => + val df = spark.read + .parquet(path.toString) + .where(whereCause) + val (_, cometPlan) = checkSparkAnswer(df) + val scan = collect(cometPlan) { + case p: CometScanExec => + assert(p.dataFilters.nonEmpty) + p + case p: CometNativeScanExec => + assert(p.dataFilters.nonEmpty) + p + } + assert(scan.size == 1) + + if (pushDown) { + assert(scan.head.metrics(metricKey).value == expectedRows) + } else { + assert(scan.head.metrics(metricKey).value == rows) + } + } + } + } + } + } + } + } } class ParquetReadV2Suite extends ParquetReadSuite with AdaptiveSparkPlanHelper { diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 7bdec4215..3891b9112 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -461,6 +461,7 @@ abstract class CometTestBase | optional INT64 _18(TIMESTAMP(MILLIS,true)); | optional INT64 _19(TIMESTAMP(MICROS,true)); | optional INT32 _20(DATE); + | optional binary _21; | optional INT32 _id; |} """.stripMargin @@ -487,6 +488,7 @@ abstract class CometTestBase | optional INT64 _18(TIMESTAMP(MILLIS,true)); | optional INT64 _19(TIMESTAMP(MICROS,true)); | optional INT32 _20(DATE); + | optional binary _21; | optional INT32 _id; |} """.stripMargin @@ -498,6 +500,7 @@ abstract class CometTestBase dictionaryEnabled: Boolean, begin: Int, end: Int, + nullEnabled: Boolean = true, pageSize: Int = 128, randomSize: Int = 0): Unit = { // alwaysIncludeUnsignedIntTypes means we include unsignedIntTypes in the test even if the @@ -516,7 +519,7 @@ abstract class CometTestBase val rand = scala.util.Random val data = (begin until end).map { i => - if (rand.nextBoolean()) { + if (nullEnabled && rand.nextBoolean()) { None } else { if (dictionaryEnabled) Some(i % 4) else Some(i) @@ -546,7 
+549,8 @@ abstract class CometTestBase record.add(17, i.toLong) record.add(18, i.toLong) record.add(19, i) - record.add(20, idGenerator.getAndIncrement()) + record.add(20, i.toString) + record.add(21, idGenerator.getAndIncrement()) case _ => } writer.write(record) @@ -574,7 +578,8 @@ abstract class CometTestBase record.add(17, i) record.add(18, i) record.add(19, i.toInt) - record.add(20, idGenerator.getAndIncrement()) + record.add(20, i.toString) + record.add(21, idGenerator.getAndIncrement()) writer.write(record) } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org For additional commands, e-mail: commits-h...@datafusion.apache.org