This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new c4c998cf8 feat: pushdown filter for native_iceberg_compat (#1566)
c4c998cf8 is described below

commit c4c998cf88d39ece09e0d9cc4be39b749bd26072
Author: Zhen Wang <643348...@qq.com>
AuthorDate: Tue Apr 1 05:59:22 2025 +0800

    feat: pushdown filter for native_iceberg_compat (#1566)
    
    * feat: pushdown filter for native_iceberg_compat
    
    * fix style
    
    * add data schema
    
    * fix filter bound
    
    * fix in expr
    
    * add primitive type tests
    
    * enable native_datafusion test
---
 .../main/java/org/apache/comet/parquet/Native.java |   2 +
 .../apache/comet/parquet/NativeBatchReader.java    |  29 +++-
 native/core/src/execution/planner.rs               |   2 +-
 native/core/src/parquet/mod.rs                     |  25 ++-
 .../comet/parquet/CometParquetFileFormat.scala     |  51 +++---
 .../CometParquetPartitionReaderFactory.scala       |   1 +
 .../org/apache/comet/parquet/ParquetFilters.scala  | 156 ++++++++++++++++++
 .../apache/comet/parquet/SourceFilterSerde.scala   | 175 +++++++++++++++++++++
 .../apache/comet/parquet/ParquetReadSuite.scala    |  71 ++++++++-
 .../scala/org/apache/spark/sql/CometTestBase.scala |  11 +-
 10 files changed, 490 insertions(+), 33 deletions(-)
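
For context, a minimal usage sketch, not part of the patch and mirroring the new test added below: with the native_iceberg_compat scan implementation and Parquet filter pushdown enabled, a data filter on a Parquet scan is serialized and handed to the native reader instead of being evaluated only on the Spark side.

    // Hedged sketch: assumes a CometTestBase-style suite where spark, withSQLConf,
    // CometConf and SQLConf are in scope, and `path` points at a Parquet file.
    withSQLConf(
      CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_ICEBERG_COMPAT,
      SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
      // The predicate is serialized and passed to Native.initRecordBatchReader,
      // so the native scan can filter rows before returning batches.
      spark.read.parquet(path.toString).where("_4 = 1").show()
    }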

diff --git a/common/src/main/java/org/apache/comet/parquet/Native.java b/common/src/main/java/org/apache/comet/parquet/Native.java
index d0b0f951d..4e50190f8 100644
--- a/common/src/main/java/org/apache/comet/parquet/Native.java
+++ b/common/src/main/java/org/apache/comet/parquet/Native.java
@@ -253,7 +253,9 @@ public final class Native extends NativeBase {
       long fileSize,
       long start,
       long length,
+      byte[] filter,
       byte[] requiredSchema,
+      byte[] dataSchema,
       String sessionTimezone);
 
   // arrow native version of read batch
diff --git a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java
index 0051b412c..4f6991c5d 100644
--- a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java
+++ b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java
@@ -108,6 +108,7 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
   private final Map<String, SQLMetric> metrics;
 
   private StructType sparkSchema;
+  private StructType dataSchema;
   private MessageType requestedSchema;
   private CometVector[] vectors;
   private AbstractColumnReader[] columnReaders;
@@ -117,6 +118,7 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
   private boolean[] missingColumns;
   private boolean isInitialized;
   private ParquetMetadata footer;
+  private byte[] nativeFilter;
 
   /**
    * Whether the native scan should always return decimal represented by 128 bits, regardless of its
@@ -190,8 +192,10 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
       Configuration conf,
       PartitionedFile inputSplit,
       ParquetMetadata footer,
+      byte[] nativeFilter,
       int capacity,
       StructType sparkSchema,
+      StructType dataSchema,
       boolean isCaseSensitive,
       boolean useFieldId,
       boolean ignoreMissingIds,
@@ -202,6 +206,7 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
     this.conf = conf;
     this.capacity = capacity;
     this.sparkSchema = sparkSchema;
+    this.dataSchema = dataSchema;
     this.isCaseSensitive = isCaseSensitive;
     this.useFieldId = useFieldId;
     this.ignoreMissingIds = ignoreMissingIds;
@@ -210,6 +215,7 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
     this.partitionValues = partitionValues;
     this.file = inputSplit;
     this.footer = footer;
+    this.nativeFilter = nativeFilter;
     this.metrics = metrics;
     this.taskContext = TaskContext$.MODULE$.get();
   }
@@ -262,10 +268,9 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
     String timeZoneId = conf.get("spark.sql.session.timeZone");
     // Native code uses "UTC" always as the timeZoneId when converting from spark to arrow schema.
     Schema arrowSchema = Utils$.MODULE$.toArrowSchema(sparkSchema, "UTC");
-    ByteArrayOutputStream out = new ByteArrayOutputStream();
-    WriteChannel writeChannel = new WriteChannel(Channels.newChannel(out));
-    MessageSerializer.serialize(writeChannel, arrowSchema);
-    byte[] serializedRequestedArrowSchema = out.toByteArray();
+    byte[] serializedRequestedArrowSchema = serializeArrowSchema(arrowSchema);
+    Schema dataArrowSchema = Utils$.MODULE$.toArrowSchema(dataSchema, "UTC");
+    byte[] serializedDataArrowSchema = serializeArrowSchema(dataArrowSchema);
 
     //// Create Column readers
     List<ColumnDescriptor> columns = requestedSchema.getColumns();
@@ -350,7 +355,14 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
 
     this.handle =
         Native.initRecordBatchReader(
-            filePath, fileSize, start, length, serializedRequestedArrowSchema, timeZoneId);
+            filePath,
+            fileSize,
+            start,
+            length,
+            nativeFilter,
+            serializedRequestedArrowSchema,
+            serializedDataArrowSchema,
+            timeZoneId);
     isInitialized = true;
   }
 
@@ -524,4 +536,11 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
       return Option.apply(null); // None
     }
   }
+
+  private byte[] serializeArrowSchema(Schema schema) throws IOException {
+    ByteArrayOutputStream out = new ByteArrayOutputStream();
+    WriteChannel writeChannel = new WriteChannel(Channels.newChannel(out));
+    MessageSerializer.serialize(writeChannel, schema);
+    return out.toByteArray();
+  }
 }
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 60803dfeb..851237a9a 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -230,7 +230,7 @@ impl PhysicalPlanner {
     }
 
     /// Create a DataFusion physical expression from Spark physical expression
-    fn create_expr(
+    pub(crate) fn create_expr(
         &self,
         spark_expr: &Expr,
         input_schema: SchemaRef,
diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs
index bcf73fc54..9289d9d42 100644
--- a/native/core/src/parquet/mod.rs
+++ b/native/core/src/parquet/mod.rs
@@ -45,6 +45,8 @@ use jni::{
 
 use self::util::jni::TypePromotionInfo;
 use crate::execution::operators::ExecutionError;
+use crate::execution::planner::PhysicalPlanner;
+use crate::execution::serde;
 use crate::execution::utils::SparkArrowConvert;
 use crate::parquet::data_type::AsBytes;
 use crate::parquet::parquet_exec::init_datasource_exec;
@@ -644,7 +646,9 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat
     file_size: jlong,
     start: jlong,
     length: jlong,
+    filter: jbyteArray,
     required_schema: jbyteArray,
+    data_schema: jbyteArray,
     session_timezone: jstring,
 ) -> jlong {
     try_unwrap_or_throw(&e, |mut env| unsafe {
@@ -666,6 +670,23 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat
         let required_schema_buffer = env.convert_byte_array(&required_schema_array)?;
         let required_schema = Arc::new(deserialize_schema(required_schema_buffer.as_bytes())?);
 
+        let data_schema_array = JByteArray::from_raw(data_schema);
+        let data_schema_buffer = env.convert_byte_array(&data_schema_array)?;
+        let data_schema = Arc::new(deserialize_schema(data_schema_buffer.as_bytes())?);
+
+        let planner = PhysicalPlanner::default();
+
+        let data_filters = if !filter.is_null() {
+            let filter_array = JByteArray::from_raw(filter);
+            let filter_buffer = env.convert_byte_array(&filter_array)?;
+            let filter_expr = serde::deserialize_expr(filter_buffer.as_slice())?;
+            Some(vec![
+                planner.create_expr(&filter_expr, Arc::clone(&data_schema))?
+            ])
+        } else {
+            None
+        };
+
         let file_groups =
             get_file_groups_single_file(&object_store_path, file_size as u64, start, length);
 
@@ -676,13 +697,13 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat
 
         let scan = init_datasource_exec(
             required_schema,
-            None,
+            Some(data_schema),
             None,
             None,
             object_store_url,
             file_groups,
             None,
-            None,
+            data_filters,
             session_timezone.as_str(),
         )?;
 
diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala
index a3e917f79..b67f99ad8 100644
--- a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala
+++ b/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala
@@ -114,36 +114,33 @@ class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with
         footerFileMetaData,
         datetimeRebaseModeInRead)
 
-      val pushed = if (parquetFilterPushDown) {
-        val parquetSchema = footerFileMetaData.getSchema
-        val parquetFilters = new ParquetFilters(
-          parquetSchema,
-          pushDownDate,
-          pushDownTimestamp,
-          pushDownDecimal,
-          pushDownStringPredicate,
-          pushDownInFilterThreshold,
-          isCaseSensitive,
-          datetimeRebaseSpec)
-        filters
-          // Collects all converted Parquet filter predicates. Notice that not all predicates can
-          // be converted (`ParquetFilters.createFilter` returns an `Option`). That's why a
-          // `flatMap` is used here.
-          .flatMap(parquetFilters.createFilter)
-          .reduceOption(FilterApi.and)
-      } else {
-        None
-      }
-      pushed.foreach(p => ParquetInputFormat.setFilterPredicate(sharedConf, p))
+      val parquetSchema = footerFileMetaData.getSchema
+      val parquetFilters = new ParquetFilters(
+        parquetSchema,
+        dataSchema,
+        pushDownDate,
+        pushDownTimestamp,
+        pushDownDecimal,
+        pushDownStringPredicate,
+        pushDownInFilterThreshold,
+        isCaseSensitive,
+        datetimeRebaseSpec)
 
       val recordBatchReader =
         if (nativeIcebergCompat) {
+          val pushed = if (parquetFilterPushDown) {
+            parquetFilters.createNativeFilters(filters)
+          } else {
+            None
+          }
           val batchReader = new NativeBatchReader(
             sharedConf,
             file,
             footer,
+            pushed.orNull,
             capacity,
             requiredSchema,
+            dataSchema,
             isCaseSensitive,
             useFieldId,
             ignoreMissingIds,
@@ -160,6 +157,18 @@ class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with
           }
           batchReader
         } else {
+          val pushed = if (parquetFilterPushDown) {
+            filters
+              // Collects all converted Parquet filter predicates. Notice that not all predicates
+              // can be converted (`ParquetFilters.createFilter` returns an `Option`). That's why
+              // a `flatMap` is used here.
+              .flatMap(parquetFilters.createFilter)
+              .reduceOption(FilterApi.and)
+          } else {
+            None
+          }
+          pushed.foreach(p => ParquetInputFormat.setFilterPredicate(sharedConf, p))
+
           val batchReader = new BatchReader(
             sharedConf,
             file,
diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala b/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala
index 01ccde94b..4dc099735 100644
--- a/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala
+++ b/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala
@@ -199,6 +199,7 @@ case class CometParquetPartitionReaderFactory(
       val parquetSchema = footerFileMetaData.getSchema
       val parquetFilters = new ParquetFilters(
         parquetSchema,
+        readDataSchema,
         pushDownDate,
         pushDownTimestamp,
         pushDownDecimal,
diff --git a/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala
index 30d7804e8..54db9308c 100644
--- a/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala
+++ b/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala
@@ -19,12 +19,14 @@
 
 package org.apache.comet.parquet
 
+import java.io.ByteArrayOutputStream
 import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort}
 import java.math.{BigDecimal => JBigDecimal}
 import java.sql.{Date, Timestamp}
 import java.time.{Duration, Instant, LocalDate, Period}
 import java.util.Locale
 
+import scala.collection.JavaConverters._
 import scala.collection.JavaConverters.asScalaBufferConverter
 
 import org.apache.parquet.column.statistics.{Statistics => ParquetStatistics}
@@ -39,8 +41,11 @@ import org.apache.parquet.schema.Type.Repetition
 import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, CaseInsensitiveMap, DateTimeUtils, IntervalUtils}
 import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, RebaseSpec}
 import org.apache.spark.sql.sources
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.unsafe.types.UTF8String
 
+import org.apache.comet.parquet.SourceFilterSerde.{createBinaryExpr, createNameExpr, createUnaryExpr, createValueExpr}
+import org.apache.comet.serde.ExprOuterClass
 import org.apache.comet.shims.ShimSQLConf
 
 /**
@@ -51,6 +56,7 @@ import org.apache.comet.shims.ShimSQLConf
  */
 class ParquetFilters(
     schema: MessageType,
+    dataSchema: StructType,
     pushDownDate: Boolean,
     pushDownTimestamp: Boolean,
     pushDownDecimal: Boolean,
@@ -876,4 +882,154 @@ class ParquetFilters(
       case _ => None
     }
   }
+
+  def createNativeFilters(predicates: Seq[sources.Filter]): Option[Array[Byte]] = {
+    predicates.reduceOption(sources.And).flatMap(createNativeFilter).map { expr =>
+      val outputStream = new ByteArrayOutputStream()
+      expr.writeTo(outputStream)
+      outputStream.close()
+      outputStream.toByteArray
+    }
+  }
+
+  private def createNativeFilter(predicate: sources.Filter): Option[ExprOuterClass.Expr] = {
+    def nameUnaryExpr(name: String)(
+        f: (ExprOuterClass.Expr.Builder, ExprOuterClass.UnaryExpr) => ExprOuterClass.Expr.Builder)
+        : Option[ExprOuterClass.Expr] = {
+      createNameExpr(name, dataSchema).map { case (_, childExpr) =>
+        createUnaryExpr(childExpr, f)
+      }
+    }
+
+    def nameValueBinaryExpr(name: String, value: Any)(
+        f: (
+            ExprOuterClass.Expr.Builder,
+            ExprOuterClass.BinaryExpr) => ExprOuterClass.Expr.Builder)
+        : Option[ExprOuterClass.Expr] = {
+      createNameExpr(name, dataSchema).flatMap { case (dataType, childExpr) =>
+        createValueExpr(value, dataType).map(createBinaryExpr(childExpr, _, f))
+      }
+    }
+
+    predicate match {
+      case sources.IsNull(name) if canMakeFilterOn(name, null) =>
+        nameUnaryExpr(name) { (builder, unaryExpr) =>
+          builder.setIsNull(unaryExpr)
+        }
+      case sources.IsNotNull(name) if canMakeFilterOn(name, null) =>
+        nameUnaryExpr(name) { (builder, unaryExpr) =>
+          builder.setIsNotNull(unaryExpr)
+        }
+
+      case sources.EqualTo(name, value) if canMakeFilterOn(name, value) =>
+        nameValueBinaryExpr(name, value) { (builder, binaryExpr) =>
+          builder.setEq(binaryExpr)
+        }
+
+      case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) =>
+        nameValueBinaryExpr(name, value) { (builder, binaryExpr) =>
+          builder.setNeq(binaryExpr)
+        }
+
+      case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) =>
+        nameValueBinaryExpr(name, value) { (builder, binaryExpr) =>
+          builder.setEqNullSafe(binaryExpr)
+        }
+
+      case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) =>
+        nameValueBinaryExpr(name, value) { (builder, binaryExpr) =>
+          builder.setNeqNullSafe(binaryExpr)
+        }
+
+      case sources.LessThan(name, value) if (value != null) && canMakeFilterOn(name, value) =>
+        nameValueBinaryExpr(name, value) { (builder, binaryExpr) =>
+          builder.setLt(binaryExpr)
+        }
+
+      case sources.LessThanOrEqual(name, value)
+          if (value != null) && canMakeFilterOn(name, value) =>
+        nameValueBinaryExpr(name, value) { (builder, binaryExpr) =>
+          builder.setLtEq(binaryExpr)
+        }
+
+      case sources.GreaterThan(name, value) if (value != null) && canMakeFilterOn(name, value) =>
+        nameValueBinaryExpr(name, value) { (builder, binaryExpr) =>
+          builder.setGt(binaryExpr)
+        }
+
+      case sources.GreaterThanOrEqual(name, value)
+          if (value != null) && canMakeFilterOn(name, value) =>
+        nameValueBinaryExpr(name, value) { (builder, binaryExpr) =>
+          builder.setGtEq(binaryExpr)
+        }
+
+      case sources.And(lhs, rhs) =>
+        (createNativeFilter(lhs), createNativeFilter(rhs)) match {
+          case (Some(leftExpr), Some(rightExpr)) =>
+            Some(
+              createBinaryExpr(
+                leftExpr,
+                rightExpr,
+                (builder, binaryExpr) => builder.setAnd(binaryExpr)))
+          case _ => None
+        }
+
+      case sources.Or(lhs, rhs) =>
+        (createNativeFilter(lhs), createNativeFilter(rhs)) match {
+          case (Some(leftExpr), Some(rightExpr)) =>
+            Some(
+              createBinaryExpr(
+                leftExpr,
+                rightExpr,
+                (builder, binaryExpr) => builder.setOr(binaryExpr)))
+          case _ => None
+        }
+
+      case sources.Not(pred) =>
+        val childExpr = createNativeFilter(pred)
+        childExpr.map { expr =>
+          createUnaryExpr(expr, (builder, unaryExpr) => builder.setNot(unaryExpr))
+        }
+
+      case sources.In(name, values)
+          if pushDownInFilterThreshold > 0 && values.nonEmpty &&
+            canMakeFilterOn(name, values.head) =>
+        createNameExpr(name, dataSchema).flatMap { case (dataType, nameExpr) =>
+          val valueExprs = values.flatMap(createValueExpr(_, dataType))
+          if (valueExprs.length != values.length) {
+            None
+          } else {
+            val builder = ExprOuterClass.In.newBuilder()
+            builder.setInValue(nameExpr)
+            builder.addAllLists(valueExprs.toSeq.asJava)
+            builder.setNegated(false)
+            Some(
+              ExprOuterClass.Expr
+                .newBuilder()
+                .setIn(builder)
+                .build())
+          }
+        }
+
+      case sources.StringStartsWith(name, prefix)
+          if pushDownStringPredicate && canMakeFilterOn(name, prefix) =>
+        nameValueBinaryExpr(name, prefix) { (builder, binaryExpr) =>
+          builder.setStartsWith(binaryExpr)
+        }
+
+      case sources.StringEndsWith(name, suffix)
+          if pushDownStringPredicate && canMakeFilterOn(name, suffix) =>
+        nameValueBinaryExpr(name, suffix) { (builder, binaryExpr) =>
+          builder.setEndsWith(binaryExpr)
+        }
+
+      case sources.StringContains(name, value)
+          if pushDownStringPredicate && canMakeFilterOn(name, value) =>
+        nameValueBinaryExpr(name, value) { (builder, binaryExpr) =>
+          builder.setContains(binaryExpr)
+        }
+
+      case _ => None
+    }
+  }
 }
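
To make the new serde path concrete, the following sketch (illustrative only; it assumes the comet-spark module is on the classpath) strings together the SourceFilterSerde helpers added in the next file to build roughly what createNativeFilter produces for a simple comparison predicate:

    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    import org.apache.comet.parquet.SourceFilterSerde.{createBinaryExpr, createNameExpr, createValueExpr}
    import org.apache.comet.serde.ExprOuterClass

    // Roughly what createNativeFilter emits for sources.GreaterThan("_4", 1):
    // a BoundReference(_4) > Literal(1) protobuf expression.
    val dataSchema = StructType(Seq(StructField("_4", IntegerType)))
    val gtExpr: Option[ExprOuterClass.Expr] =
      createNameExpr("_4", dataSchema).flatMap { case (dataType, nameExpr) =>
        createValueExpr(1, dataType).map { literalExpr =>
          createBinaryExpr(nameExpr, literalExpr, (builder, binaryExpr) => builder.setGt(binaryExpr))
        }
      }
    // createNativeFilters(Seq(sources.GreaterThan("_4", 1))) serializes this expression
    // with writeTo(...) into the byte[] that NativeBatchReader forwards over JNI.
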
diff --git a/spark/src/main/scala/org/apache/comet/parquet/SourceFilterSerde.scala b/spark/src/main/scala/org/apache/comet/parquet/SourceFilterSerde.scala
new file mode 100644
index 000000000..4ad467cd8
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/parquet/SourceFilterSerde.scala
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.parquet
+
+import java.math.{BigDecimal => JavaBigDecimal}
+import java.sql.{Date, Timestamp}
+import java.time.{Instant, LocalDate, LocalDateTime}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+
+import org.apache.comet.serde.ExprOuterClass
+import org.apache.comet.serde.ExprOuterClass.Expr
+import org.apache.comet.serde.QueryPlanSerde.serializeDataType
+
+object SourceFilterSerde extends Logging {
+
+  def createNameExpr(
+      name: String,
+      schema: StructType): Option[(DataType, ExprOuterClass.Expr)] = {
+    val fieldWithIndex = schema.fields.zipWithIndex.find { case (field, _) =>
+      field.name == name
+    }
+    if (fieldWithIndex.isDefined) {
+      val (field, index) = fieldWithIndex.get
+      val dataType = serializeDataType(field.dataType)
+      if (dataType.isDefined) {
+        val boundExpr = ExprOuterClass.BoundReference
+          .newBuilder()
+          .setIndex(index)
+          .setDatatype(dataType.get)
+          .build()
+        Some(
+          field.dataType,
+          ExprOuterClass.Expr
+            .newBuilder()
+            .setBound(boundExpr)
+            .build())
+      } else {
+        None
+      }
+    } else {
+      None
+    }
+
+  }
+
+  /**
+   * Create a literal native expression for a source filter value; the value is a Scala value.
+   */
+  def createValueExpr(value: Any, dataType: DataType): Option[ExprOuterClass.Expr] = {
+    val exprBuilder = ExprOuterClass.Literal.newBuilder()
+    var valueIsSet = true
+    if (value == null) {
+      exprBuilder.setIsNull(true)
+    } else {
+      exprBuilder.setIsNull(false)
+      // value is a scala value, not a catalyst value
+      // refer to org.apache.spark.sql.catalyst.CatalystTypeConverters.CatalystTypeConverter#toScala
+      dataType match {
+        case _: BooleanType => exprBuilder.setBoolVal(value.asInstanceOf[Boolean])
+        case _: ByteType => exprBuilder.setByteVal(value.asInstanceOf[Byte])
+        case _: ShortType => exprBuilder.setShortVal(value.asInstanceOf[Short])
+        case _: IntegerType => exprBuilder.setIntVal(value.asInstanceOf[Int])
+        case _: LongType => exprBuilder.setLongVal(value.asInstanceOf[Long])
+        case _: FloatType => exprBuilder.setFloatVal(value.asInstanceOf[Float])
+        case _: DoubleType => exprBuilder.setDoubleVal(value.asInstanceOf[Double])
+        case _: StringType => exprBuilder.setStringVal(value.asInstanceOf[String])
+        case _: TimestampType =>
+          value match {
+            case v: Timestamp => exprBuilder.setLongVal(DateTimeUtils.fromJavaTimestamp(v))
+            case v: Instant => exprBuilder.setLongVal(DateTimeUtils.instantToMicros(v))
+            case v: Long => exprBuilder.setLongVal(v)
+            case _ =>
+              valueIsSet = false
+              logWarning(s"Unexpected timestamp type '${value.getClass}' for 
value '$value'")
+          }
+        case _: TimestampNTZType =>
+          value match {
+            case v: LocalDateTime =>
+              exprBuilder.setLongVal(DateTimeUtils.localDateTimeToMicros(v))
+            case v: Long => exprBuilder.setLongVal(v)
+            case _ =>
+              valueIsSet = false
+              logWarning(s"Unexpected timestamp type '${value.getClass}' for 
value' $value'")
+          }
+        case _: DecimalType =>
+          // Pass decimal literal as bytes.
+          val unscaled = value.asInstanceOf[JavaBigDecimal].unscaledValue
+          exprBuilder.setDecimalVal(com.google.protobuf.ByteString.copyFrom(unscaled.toByteArray))
+        case _: BinaryType =>
+          val byteStr =
+            com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]])
+          exprBuilder.setBytesVal(byteStr)
+        case _: DateType =>
+          value match {
+            case v: LocalDate => exprBuilder.setIntVal(DateTimeUtils.localDateToDays(v))
+            case v: Date => exprBuilder.setIntVal(DateTimeUtils.fromJavaDate(v))
+            case v: Int => exprBuilder.setIntVal(v)
+            case _ =>
+              valueIsSet = false
+              logWarning(s"Unexpected date type '${value.getClass}' for value 
'$value'")
+          }
+        case dt =>
+          valueIsSet = false
+          logWarning(s"Unexpected data type '$dt' for literal value '$value'")
+      }
+    }
+
+    val dt = serializeDataType(dataType)
+
+    if (valueIsSet && dt.isDefined) {
+      exprBuilder.setDatatype(dt.get)
+
+      Some(
+        ExprOuterClass.Expr
+          .newBuilder()
+          .setLiteral(exprBuilder)
+          .build())
+    } else {
+      None
+    }
+  }
+
+  def createUnaryExpr(
+      childExpr: Expr,
+      f: (ExprOuterClass.Expr.Builder, ExprOuterClass.UnaryExpr) => ExprOuterClass.Expr.Builder)
+      : ExprOuterClass.Expr = {
+    // create the generic UnaryExpr message
+    val inner = ExprOuterClass.UnaryExpr
+      .newBuilder()
+      .setChild(childExpr)
+      .build()
+    f(
+      ExprOuterClass.Expr
+        .newBuilder(),
+      inner).build()
+  }
+
+  def createBinaryExpr(
+      leftExpr: Expr,
+      rightExpr: Expr,
+      f: (ExprOuterClass.Expr.Builder, ExprOuterClass.BinaryExpr) => ExprOuterClass.Expr.Builder)
+      : ExprOuterClass.Expr = {
+    // create the generic BinaryExpr message
+    val inner = ExprOuterClass.BinaryExpr
+      .newBuilder()
+      .setLeft(leftExpr)
+      .setRight(rightExpr)
+      .build()
+    f(
+      ExprOuterClass.Expr
+        .newBuilder(),
+      inner).build()
+  }
+
+}
diff --git a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala
index f5d0cbb42..d810b4fb8 100644
--- a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala
@@ -20,11 +20,12 @@
 package org.apache.comet.parquet
 
 import java.io.{File, FileFilter}
-import java.math.BigDecimal
+import java.math.{BigDecimal, BigInteger}
 import java.time.{ZoneId, ZoneOffset}
 
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
+import scala.util.control.Breaks.{break, breakable}
 
 import org.scalactic.source.Position
 import org.scalatest.Tag
@@ -1493,6 +1494,74 @@ class ParquetReadV1Suite extends ParquetReadSuite with AdaptiveSparkPlanHelper {
         })
     }
   }
+
+  test("test V1 parquet scan filter pushdown of primitive types uses 
native_iceberg_compat") {
+    withTempPath { dir =>
+      val path = new Path(dir.toURI.toString, "test1.parquet")
+      val rows = 1000
+      withSQLConf(
+        CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_ICEBERG_COMPAT,
+        CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.key -> "false") {
+        makeParquetFileAllTypes(path, dictionaryEnabled = false, 0, rows, nullEnabled = false)
+      }
+      Seq(
+        (CometConf.SCAN_NATIVE_DATAFUSION, "output_rows"),
+        (CometConf.SCAN_NATIVE_ICEBERG_COMPAT, "numOutputRows")).foreach {
+        case (scanMode, metricKey) =>
+          Seq(true, false).foreach { pushDown =>
+            breakable {
+              withSQLConf(
+                CometConf.COMET_NATIVE_SCAN_IMPL.key -> scanMode,
+                SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> pushDown.toString) {
+                if (scanMode == CometConf.SCAN_NATIVE_DATAFUSION && !pushDown) {
+                  // FIXME: native_datafusion always pushes down data filters
+                  break()
+                }
+                Seq(
+                  ("_1 = true", Math.ceil(rows.toDouble / 2)), // Boolean
+                  ("_2 = 1", Math.ceil(rows.toDouble / 256)), // Byte
+                  ("_3 = 1", 1), // Short
+                  ("_4 = 1", 1), // Integer
+                  ("_5 = 1", 1), // Long
+                  ("_6 = 1.0", 1), // Float
+                  ("_7 = 1.0", 1), // Double
+                  (s"_8 = '${1.toString * 48}'", 1), // String
+                  ("_21 = to_binary('1', 'utf-8')", 1), // binary
+                  ("_15 = 0.0", 1), // DECIMAL(5, 2)
+                  ("_16 = 0.0", 1), // DECIMAL(18, 10)
+                  (
+                    s"_17 = ${new BigDecimal(new BigInteger(("1" * 
16).getBytes), 37).toString}",
+                    Math.ceil(rows.toDouble / 10)
+                  ), // DECIMAL(38, 37)
+                  (s"_19 = TIMESTAMP '${DateTimeUtils.toJavaTimestamp(1)}'", 
1), // Timestamp
+                  ("_20 = DATE '1970-01-02'", 1) // Date
+                ).foreach { case (whereClause, expectedRows) =>
+                  val df = spark.read
+                    .parquet(path.toString)
+                    .where(whereClause)
+                  val (_, cometPlan) = checkSparkAnswer(df)
+                  val scan = collect(cometPlan) {
+                    case p: CometScanExec =>
+                      assert(p.dataFilters.nonEmpty)
+                      p
+                    case p: CometNativeScanExec =>
+                      assert(p.dataFilters.nonEmpty)
+                      p
+                  }
+                  assert(scan.size == 1)
+
+                  if (pushDown) {
+                    assert(scan.head.metrics(metricKey).value == expectedRows)
+                  } else {
+                    assert(scan.head.metrics(metricKey).value == rows)
+                  }
+                }
+              }
+            }
+          }
+      }
+    }
+  }
 }
 
 class ParquetReadV2Suite extends ParquetReadSuite with AdaptiveSparkPlanHelper {
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 7bdec4215..3891b9112 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -461,6 +461,7 @@ abstract class CometTestBase
        |  optional INT64                    _18(TIMESTAMP(MILLIS,true));
        |  optional INT64                    _19(TIMESTAMP(MICROS,true));
        |  optional INT32                    _20(DATE);
+       |  optional binary                   _21;
        |  optional INT32                    _id;
        |}
       """.stripMargin
@@ -487,6 +488,7 @@ abstract class CometTestBase
        |  optional INT64                    _18(TIMESTAMP(MILLIS,true));
        |  optional INT64                    _19(TIMESTAMP(MICROS,true));
        |  optional INT32                    _20(DATE);
+       |  optional binary                   _21;
        |  optional INT32                    _id;
        |}
       """.stripMargin
@@ -498,6 +500,7 @@ abstract class CometTestBase
       dictionaryEnabled: Boolean,
       begin: Int,
       end: Int,
+      nullEnabled: Boolean = true,
       pageSize: Int = 128,
       randomSize: Int = 0): Unit = {
     // alwaysIncludeUnsignedIntTypes means we include unsignedIntTypes in the test even if the
@@ -516,7 +519,7 @@ abstract class CometTestBase
 
     val rand = scala.util.Random
     val data = (begin until end).map { i =>
-      if (rand.nextBoolean()) {
+      if (nullEnabled && rand.nextBoolean()) {
         None
       } else {
         if (dictionaryEnabled) Some(i % 4) else Some(i)
@@ -546,7 +549,8 @@ abstract class CometTestBase
           record.add(17, i.toLong)
           record.add(18, i.toLong)
           record.add(19, i)
-          record.add(20, idGenerator.getAndIncrement())
+          record.add(20, i.toString)
+          record.add(21, idGenerator.getAndIncrement())
         case _ =>
       }
       writer.write(record)
@@ -574,7 +578,8 @@ abstract class CometTestBase
       record.add(17, i)
       record.add(18, i)
       record.add(19, i.toInt)
-      record.add(20, idGenerator.getAndIncrement())
+      record.add(20, i.toString)
+      record.add(21, idGenerator.getAndIncrement())
       writer.write(record)
     }
 

