This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch comet-parquet-exec
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/comet-parquet-exec by this 
push:
     new 2686a4b6 feat: [comet-parquet-exec] Use Datafusion based record batch 
reader for use in iceberg reads (#1174)
2686a4b6 is described below

commit 2686a4b69e40d7f18c5cf9b488f17b64ba40c5b7
Author: Parth Chandra <[email protected]>
AuthorDate: Tue Dec 17 15:41:39 2024 -0800

    feat: [comet-parquet-exec] Use Datafusion based record batch reader for use 
in iceberg reads (#1174)
    
    * wip. Use DF's ParquetExec for Iceberg API
    
    * wip - await??
    
    * wip
    
    * wip -
    
    * fix shading issue
    
    * fix shading issue
    
    * fixes
    
    * refactor to remove arrow based reader
    
    * rename config
    
    * Fix config defaults
    
    ---------
    
    Co-authored-by: Andy Grove <[email protected]>
---
 .../main/java/org/apache/comet/parquet/Native.java |   7 +-
 .../apache/comet/parquet/NativeBatchReader.java    |  36 ++--
 .../main/scala/org/apache/comet/CometConf.scala    |  12 +-
 .../apache/spark/sql/comet/CometArrowUtils.scala   | 180 +++++++++++++++++
 native/Cargo.lock                                  |   2 +
 native/Cargo.toml                                  |   2 +
 native/core/Cargo.toml                             |   2 +
 native/core/src/execution/datafusion/mod.rs        |   2 +-
 native/core/src/parquet/mod.rs                     | 224 ++++++++++-----------
 native/core/src/parquet/util/jni.rs                |  55 +++++
 native/spark-expr/src/cast.rs                      |  10 +-
 .../scala/org/apache/comet/DataTypeSupport.scala   |   2 +-
 .../comet/parquet/CometParquetFileFormat.scala     |   2 +-
 .../scala/org/apache/spark/sql/CometTestBase.scala |   4 +-
 .../spark/sql/comet/CometPlanStabilitySuite.scala  |  16 +-
 15 files changed, 396 insertions(+), 160 deletions(-)

diff --git a/common/src/main/java/org/apache/comet/parquet/Native.java 
b/common/src/main/java/org/apache/comet/parquet/Native.java
index 1ed01d32..b33ec60d 100644
--- a/common/src/main/java/org/apache/comet/parquet/Native.java
+++ b/common/src/main/java/org/apache/comet/parquet/Native.java
@@ -246,15 +246,10 @@ public final class Native extends NativeBase {
    * @param filePath
    * @param start
    * @param length
-   * @param required_columns array of names of fields to read
    * @return a handle to the record batch reader, used in subsequent calls.
    */
   public static native long initRecordBatchReader(
-      String filePath, long start, long length, Object[] required_columns);
-
-  public static native int numRowGroups(long handle);
-
-  public static native long numTotalRows(long handle);
+      String filePath, long fileSize, long start, long length, byte[] 
requiredSchema);
 
   // arrow native version of read batch
   /**
diff --git 
a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java 
b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java
index 3ac55ba4..8461bb50 100644
--- a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java
+++ b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java
@@ -19,11 +19,13 @@
 
 package org.apache.comet.parquet;
 
+import java.io.ByteArrayOutputStream;
 import java.io.Closeable;
 import java.io.IOException;
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
 import java.net.URISyntaxException;
+import java.nio.channels.Channels;
 import java.util.*;
 
 import scala.Option;
@@ -36,6 +38,9 @@ import org.slf4j.LoggerFactory;
 import org.apache.arrow.c.CometSchemaImporter;
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.ipc.WriteChannel;
+import org.apache.arrow.vector.ipc.message.MessageSerializer;
+import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.mapreduce.InputSplit;
@@ -52,6 +57,7 @@ import org.apache.spark.TaskContext;
 import org.apache.spark.TaskContext$;
 import org.apache.spark.executor.TaskMetrics;
 import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.comet.CometArrowUtils;
 import org.apache.spark.sql.comet.parquet.CometParquetReadSupport;
 import org.apache.spark.sql.execution.datasources.PartitionedFile;
 import 
org.apache.spark.sql.execution.datasources.parquet.ParquetToSparkSchemaConverter;
@@ -99,7 +105,6 @@ public class NativeBatchReader extends RecordReader<Void, 
ColumnarBatch> impleme
   private PartitionedFile file;
   private final Map<String, SQLMetric> metrics;
 
-  private long rowsRead;
   private StructType sparkSchema;
   private MessageType requestedSchema;
   private CometVector[] vectors;
@@ -111,9 +116,6 @@ public class NativeBatchReader extends RecordReader<Void, 
ColumnarBatch> impleme
   private boolean isInitialized;
   private ParquetMetadata footer;
 
-  /** The total number of rows across all row groups of the input split. */
-  private long totalRowCount;
-
   /**
    * Whether the native scan should always return decimal represented by 128 
bits, regardless of its
    * precision. Normally, this should be true if native execution is enabled, 
since Arrow compute
@@ -224,6 +226,7 @@ public class NativeBatchReader extends RecordReader<Void, 
ColumnarBatch> impleme
     long start = file.start();
     long length = file.length();
     String filePath = file.filePath().toString();
+    long fileSize = file.fileSize();
 
     requestedSchema = footer.getFileMetaData().getSchema();
     MessageType fileSchema = requestedSchema;
@@ -254,6 +257,13 @@ public class NativeBatchReader extends RecordReader<Void, 
ColumnarBatch> impleme
       }
     } ////// End get requested schema
 
+    String timeZoneId = conf.get("spark.sql.session.timeZone");
+    Schema arrowSchema = CometArrowUtils.toArrowSchema(sparkSchema, 
timeZoneId);
+    ByteArrayOutputStream out = new ByteArrayOutputStream();
+    WriteChannel writeChannel = new WriteChannel(Channels.newChannel(out));
+    MessageSerializer.serialize(writeChannel, arrowSchema);
+    byte[] serializedRequestedArrowSchema = out.toByteArray();
+
     //// Create Column readers
     List<ColumnDescriptor> columns = requestedSchema.getColumns();
     int numColumns = columns.size();
@@ -334,13 +344,9 @@ public class NativeBatchReader extends RecordReader<Void, 
ColumnarBatch> impleme
       }
     }
 
-    // TODO: (ARROW NATIVE) Use a ProjectionMask here ?
-    ArrayList<String> requiredColumns = new ArrayList<>();
-    for (Type col : requestedSchema.asGroupType().getFields()) {
-      requiredColumns.add(col.getName());
-    }
-    this.handle = Native.initRecordBatchReader(filePath, start, length, 
requiredColumns.toArray());
-    totalRowCount = Native.numRowGroups(handle);
+    this.handle =
+        Native.initRecordBatchReader(
+            filePath, fileSize, start, length, serializedRequestedArrowSchema);
     isInitialized = true;
   }
 
@@ -375,7 +381,7 @@ public class NativeBatchReader extends RecordReader<Void, 
ColumnarBatch> impleme
 
   @Override
   public float getProgress() {
-    return (float) rowsRead / totalRowCount;
+    return 0;
   }
 
   /**
@@ -395,7 +401,7 @@ public class NativeBatchReader extends RecordReader<Void, 
ColumnarBatch> impleme
   public boolean nextBatch() throws IOException {
     Preconditions.checkState(isInitialized, "init() should be called first!");
 
-    if (rowsRead >= totalRowCount) return false;
+    //    if (rowsRead >= totalRowCount) return false;
     int batchSize;
 
     try {
@@ -432,7 +438,6 @@ public class NativeBatchReader extends RecordReader<Void, 
ColumnarBatch> impleme
     }
 
     currentBatch.setNumRows(batchSize);
-    rowsRead += batchSize;
     return true;
   }
 
@@ -457,6 +462,9 @@ public class NativeBatchReader extends RecordReader<Void, 
ColumnarBatch> impleme
     long startNs = System.nanoTime();
 
     int batchSize = Native.readNextRecordBatch(this.handle);
+    if (batchSize == 0) {
+      return batchSize;
+    }
     if (importer != null) importer.close();
     importer = new CometSchemaImporter(ALLOCATOR);
 
diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala 
b/common/src/main/scala/org/apache/comet/CometConf.scala
index 275114a1..fabdd30c 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -75,7 +75,7 @@ object CometConf extends ShimCometConf {
         "that to enable native vectorized execution, both this config and " +
         "'spark.comet.exec.enabled' need to be enabled.")
     .booleanConf
-    .createWithDefault(true)
+    .createWithDefault(false)
 
   val COMET_FULL_NATIVE_SCAN_ENABLED: ConfigEntry[Boolean] = conf(
     "spark.comet.native.scan.enabled")
@@ -85,15 +85,15 @@ object CometConf extends ShimCometConf {
         "read supported data sources (currently only Parquet is supported 
natively)." +
         " By default, this config is true.")
     .booleanConf
-    .createWithDefault(false)
+    .createWithDefault(true)
 
-  val COMET_NATIVE_ARROW_SCAN_ENABLED: ConfigEntry[Boolean] = conf(
+  val COMET_NATIVE_RECORDBATCH_READER_ENABLED: ConfigEntry[Boolean] = conf(
     "spark.comet.native.arrow.scan.enabled")
     .internal()
     .doc(
-      "Whether to enable the fully native arrow based scan. When this is 
turned on, Spark will " +
-        "use Comet to read Parquet files natively via the Arrow based Parquet 
reader." +
-        " By default, this config is false.")
+      "Whether to enable the fully native datafusion based column reader. When 
this is turned on," +
+        " Spark will use Comet to read Parquet files natively via the 
Datafusion based Parquet" +
+        " reader. By default, this config is false.")
     .booleanConf
     .createWithDefault(false)
 
diff --git 
a/common/src/main/scala/org/apache/spark/sql/comet/CometArrowUtils.scala 
b/common/src/main/scala/org/apache/spark/sql/comet/CometArrowUtils.scala
new file mode 100644
index 00000000..2f4f55fc
--- /dev/null
+++ b/common/src/main/scala/org/apache/spark/sql/comet/CometArrowUtils.scala
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.complex.MapVector
+import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, 
IntervalUnit, TimeUnit}
+import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+object CometArrowUtils {
+
+  val rootAllocator = new RootAllocator(Long.MaxValue)
+
+  // todo: support more types.
+
+  /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for 
TimestampTypes */
+  def toArrowType(dt: DataType, timeZoneId: String): ArrowType = dt match {
+    case BooleanType => ArrowType.Bool.INSTANCE
+    case ByteType => new ArrowType.Int(8, true)
+    case ShortType => new ArrowType.Int(8 * 2, true)
+    case IntegerType => new ArrowType.Int(8 * 4, true)
+    case LongType => new ArrowType.Int(8 * 8, true)
+    case FloatType => new 
ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
+    case DoubleType => new 
ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
+    case StringType => ArrowType.Utf8.INSTANCE
+    case BinaryType => ArrowType.Binary.INSTANCE
+    case DecimalType.Fixed(precision, scale) => new 
ArrowType.Decimal(precision, scale)
+    case DateType => new ArrowType.Date(DateUnit.DAY)
+    case TimestampType if timeZoneId == null =>
+      throw new IllegalStateException("Missing timezoneId where it is 
mandatory.")
+    case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, 
timeZoneId)
+    case TimestampNTZType =>
+      new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)
+    case NullType => ArrowType.Null.INSTANCE
+    case _: YearMonthIntervalType => new 
ArrowType.Interval(IntervalUnit.YEAR_MONTH)
+    case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND)
+    case _ =>
+      throw new IllegalArgumentException()
+  }
+
+  def fromArrowType(dt: ArrowType): DataType = dt match {
+    case ArrowType.Bool.INSTANCE => BooleanType
+    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => 
ByteType
+    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => 
ShortType
+    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => 
IntegerType
+    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => 
LongType
+    case float: ArrowType.FloatingPoint
+        if float.getPrecision() == FloatingPointPrecision.SINGLE =>
+      FloatType
+    case float: ArrowType.FloatingPoint
+        if float.getPrecision() == FloatingPointPrecision.DOUBLE =>
+      DoubleType
+    case ArrowType.Utf8.INSTANCE => StringType
+    case ArrowType.Binary.INSTANCE => BinaryType
+    case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
+    case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
+    case ts: ArrowType.Timestamp
+        if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null =>
+      TimestampNTZType
+    case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => 
TimestampType
+    case ArrowType.Null.INSTANCE => NullType
+    case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH =>
+      YearMonthIntervalType()
+    case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => 
DayTimeIntervalType()
+    case _ => throw new IllegalArgumentException()
+    // throw QueryExecutionErrors.unsupportedArrowTypeError(dt)
+  }
+
+  /** Maps field from Spark to Arrow. NOTE: timeZoneId required for 
TimestampType */
+  def toArrowField(name: String, dt: DataType, nullable: Boolean, timeZoneId: 
String): Field = {
+    dt match {
+      case ArrayType(elementType, containsNull) =>
+        val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null)
+        new Field(
+          name,
+          fieldType,
+          Seq(toArrowField("element", elementType, containsNull, 
timeZoneId)).asJava)
+      case StructType(fields) =>
+        val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, 
null)
+        new Field(
+          name,
+          fieldType,
+          fields
+            .map { field =>
+              toArrowField(field.name, field.dataType, field.nullable, 
timeZoneId)
+            }
+            .toSeq
+            .asJava)
+      case MapType(keyType, valueType, valueContainsNull) =>
+        val mapType = new FieldType(nullable, new ArrowType.Map(false), null)
+        // Note: Map Type struct can not be null, Struct Type key field can 
not be null
+        new Field(
+          name,
+          mapType,
+          Seq(
+            toArrowField(
+              MapVector.DATA_VECTOR_NAME,
+              new StructType()
+                .add(MapVector.KEY_NAME, keyType, nullable = false)
+                .add(MapVector.VALUE_NAME, valueType, nullable = 
valueContainsNull),
+              nullable = false,
+              timeZoneId)).asJava)
+      case udt: UserDefinedType[_] => toArrowField(name, udt.sqlType, 
nullable, timeZoneId)
+      case dataType =>
+        val fieldType = new FieldType(nullable, toArrowType(dataType, 
timeZoneId), null)
+        new Field(name, fieldType, Seq.empty[Field].asJava)
+    }
+  }
+
+  def fromArrowField(field: Field): DataType = {
+    field.getType match {
+      case _: ArrowType.Map =>
+        val elementField = field.getChildren.get(0)
+        val keyType = fromArrowField(elementField.getChildren.get(0))
+        val valueType = fromArrowField(elementField.getChildren.get(1))
+        MapType(keyType, valueType, elementField.getChildren.get(1).isNullable)
+      case ArrowType.List.INSTANCE =>
+        val elementField = field.getChildren().get(0)
+        val elementType = fromArrowField(elementField)
+        ArrayType(elementType, containsNull = elementField.isNullable)
+      case ArrowType.Struct.INSTANCE =>
+        val fields = field.getChildren().asScala.map { child =>
+          val dt = fromArrowField(child)
+          StructField(child.getName, dt, child.isNullable)
+        }
+        StructType(fields.toArray)
+      case arrowType => fromArrowType(arrowType)
+    }
+  }
+
+  /**
+   * Maps schema from Spark to Arrow. NOTE: timeZoneId required for 
TimestampType in StructType
+   */
+  def toArrowSchema(schema: StructType, timeZoneId: String): Schema = {
+    new Schema(schema.map { field =>
+      toArrowField(field.name, field.dataType, field.nullable, timeZoneId)
+    }.asJava)
+  }
+
+  def fromArrowSchema(schema: Schema): StructType = {
+    StructType(schema.getFields.asScala.map { field =>
+      val dt = fromArrowField(field)
+      StructField(field.getName, dt, field.isNullable)
+    }.toArray)
+  }
+
+  /** Return Map with conf settings to be used in ArrowPythonRunner */
+  def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
+    val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> 
conf.sessionLocalTimeZone)
+    val pandasColsByName = Seq(
+      SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key ->
+        conf.pandasGroupedMapAssignColumnsByName.toString)
+    val arrowSafeTypeCheck = Seq(
+      SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key ->
+        conf.arrowSafeTypeConversion.toString)
+    Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
+  }
+
+}
diff --git a/native/Cargo.lock b/native/Cargo.lock
index c3a664ff..27e97268 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -898,6 +898,7 @@ dependencies = [
  "arrow-array",
  "arrow-buffer",
  "arrow-data",
+ "arrow-ipc",
  "arrow-schema",
  "assertables",
  "async-trait",
@@ -913,6 +914,7 @@ dependencies = [
  "datafusion-expr",
  "datafusion-functions-nested",
  "datafusion-physical-expr",
+ "flatbuffers",
  "flate2",
  "futures",
  "half",
diff --git a/native/Cargo.toml b/native/Cargo.toml
index 4b89231c..b78c1d68 100644
--- a/native/Cargo.toml
+++ b/native/Cargo.toml
@@ -37,7 +37,9 @@ arrow = { version = "53.2.0", features = ["prettyprint", 
"ffi", "chrono-tz"] }
 arrow-array = { version = "53.2.0" }
 arrow-buffer = { version = "53.2.0" }
 arrow-data = { version = "53.2.0" }
+arrow-ipc = { version = "53.2.0" }
 arrow-schema = { version = "53.2.0" }
+flatbuffers = { version = "24.3.25" }
 parquet = { version = "53.2.0", default-features = false, features = 
["experimental"] }
 datafusion-common = { version = "43.0.0" }
 datafusion = { version = "43.0.0", default-features = false, features = 
["unicode_expressions", "crypto_expressions", "parquet"] }
diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml
index 8d30b38c..35035ff3 100644
--- a/native/core/Cargo.toml
+++ b/native/core/Cargo.toml
@@ -40,6 +40,8 @@ arrow-array = { workspace = true }
 arrow-buffer = { workspace = true }
 arrow-data = { workspace = true }
 arrow-schema = { workspace = true }
+arrow-ipc = { workspace = true }
+flatbuffers = { workspace = true }
 parquet = { workspace = true, default-features = false, features = 
["experimental"] }
 half = { version = "2.4.1", default-features = false }
 futures = "0.3.28"
diff --git a/native/core/src/execution/datafusion/mod.rs 
b/native/core/src/execution/datafusion/mod.rs
index fb9c8829..af32b4be 100644
--- a/native/core/src/execution/datafusion/mod.rs
+++ b/native/core/src/execution/datafusion/mod.rs
@@ -20,6 +20,6 @@
 pub mod expressions;
 mod operators;
 pub mod planner;
-mod schema_adapter;
+pub(crate) mod schema_adapter;
 pub mod shuffle_writer;
 mod util;
diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs
index afca6066..c234b6f7 100644
--- a/native/core/src/parquet/mod.rs
+++ b/native/core/src/parquet/mod.rs
@@ -23,7 +23,7 @@ pub use mutable_vector::*;
 pub mod util;
 pub mod read;
 
-use std::fs::File;
+use std::task::Poll;
 use std::{boxed::Box, ptr::NonNull, sync::Arc};
 
 use crate::errors::{try_unwrap_or_throw, CometError};
@@ -42,17 +42,21 @@ use jni::{
 
 use crate::execution::operators::ExecutionError;
 use crate::execution::utils::SparkArrowConvert;
+use crate::parquet::data_type::AsBytes;
 use arrow::buffer::{Buffer, MutableBuffer};
 use arrow_array::{Array, RecordBatch};
-use jni::objects::{
-    JBooleanArray, JLongArray, JObjectArray, JPrimitiveArray, JString, 
ReleaseMode,
-};
+use datafusion::datasource::listing::PartitionedFile;
+use datafusion::datasource::physical_plan::parquet::ParquetExecBuilder;
+use datafusion::datasource::physical_plan::FileScanConfig;
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion_common::config::TableParquetOptions;
+use datafusion_execution::{SendableRecordBatchStream, TaskContext};
+use futures::{poll, StreamExt};
+use jni::objects::{JBooleanArray, JByteArray, JLongArray, JPrimitiveArray, 
JString, ReleaseMode};
 use jni::sys::jstring;
-use parquet::arrow::arrow_reader::{ParquetRecordBatchReader, 
ParquetRecordBatchReaderBuilder};
-use parquet::arrow::ProjectionMask;
+use parquet::arrow::arrow_reader::ParquetRecordBatchReader;
 use read::ColumnReader;
-use url::Url;
-use util::jni::{convert_column_descriptor, convert_encoding};
+use util::jni::{convert_column_descriptor, convert_encoding, 
deserialize_schema, get_file_path};
 
 use self::util::jni::TypePromotionInfo;
 
@@ -600,11 +604,11 @@ enum ParquetReaderState {
 }
 /// Parquet read context maintained across multiple JNI calls.
 struct BatchContext {
-    batch_reader: ParquetRecordBatchReader,
+    runtime: tokio::runtime::Runtime,
+    batch_stream: Option<SendableRecordBatchStream>,
+    batch_reader: Option<ParquetRecordBatchReader>,
     current_batch: Option<RecordBatch>,
     reader_state: ParquetReaderState,
-    num_row_groups: i32,
-    total_rows: i64,
 }
 
 #[inline]
@@ -616,10 +620,12 @@ fn get_batch_context<'a>(handle: jlong) -> Result<&'a mut 
BatchContext, CometErr
     }
 }
 
+/*
 #[inline]
 fn get_batch_reader<'a>(handle: jlong) -> Result<&'a mut 
ParquetRecordBatchReader, CometError> {
-    Ok(&mut get_batch_context(handle)?.batch_reader)
+    Ok(&mut get_batch_context(handle)?.batch_reader.unwrap())
 }
+*/
 
 /// # Safety
 /// This function is inherently unsafe since it deals with raw pointers passed 
from JNI.
@@ -628,118 +634,80 @@ pub unsafe extern "system" fn 
Java_org_apache_comet_parquet_Native_initRecordBat
     e: JNIEnv,
     _jclass: JClass,
     file_path: jstring,
+    file_size: jlong,
     start: jlong,
     length: jlong,
-    required_columns: jobjectArray,
+    required_schema: jbyteArray,
 ) -> jlong {
     try_unwrap_or_throw(&e, |mut env| unsafe {
         let path: String = env
             .get_string(&JString::from_raw(file_path))
             .unwrap()
             .into();
-        //TODO: (ARROW NATIVE) - this works only for 'file://' urls
-        let path = Url::parse(path.as_ref()).unwrap().to_file_path().unwrap();
-        let file = File::open(path).unwrap();
-
-        // Create a async parquet reader builder with batch_size.
-        // batch_size is the number of rows to read up to buffer once from 
pages, defaults to 1024
-        // TODO: (ARROW NATIVE) Use async reader 
ParquetRecordBatchStreamBuilder
-        let mut builder = ParquetRecordBatchReaderBuilder::try_new(file)
-            .unwrap()
-            .with_batch_size(8192); // TODO: (ARROW NATIVE) Use batch size 
configured in JVM
-
-        let num_row_groups;
-        let mut total_rows: i64 = 0;
-        //TODO: (ARROW NATIVE) if we can get the ParquetMetadata serialized, 
we need not do this.
-        {
-            let metadata = builder.metadata();
-
-            let mut columns_to_read: Vec<usize> = Vec::new();
-            let columns_to_read_array = 
JObjectArray::from_raw(required_columns);
-            let array_len = env.get_array_length(&columns_to_read_array)?;
-            let mut required_columns: Vec<String> = Vec::new();
-            for i in 0..array_len {
-                let p: JString = env
-                    .get_object_array_element(&columns_to_read_array, i)?
-                    .into();
-                required_columns.push(env.get_string(&p)?.into());
-            }
-            for (i, col) in metadata
-                .file_metadata()
-                .schema_descr()
-                .columns()
-                .iter()
-                .enumerate()
-            {
-                for required in required_columns.iter() {
-                    if col.name().to_uppercase().eq(&required.to_uppercase()) {
-                        columns_to_read.push(i);
-                        break;
-                    }
-                }
-            }
-            //TODO: (ARROW NATIVE) make this work for complex types 
(especially deeply nested structs)
-            let mask =
-                
ProjectionMask::leaves(metadata.file_metadata().schema_descr(), 
columns_to_read);
-            // Set projection mask to read only root columns 1 and 2.
-
-            let mut row_groups_to_read: Vec<usize> = Vec::new();
-            // get row groups -
-            for (i, rg) in metadata.row_groups().iter().enumerate() {
-                let rg_start = rg.file_offset().unwrap();
-                let rg_end = rg_start + rg.compressed_size();
-                if rg_start >= start && rg_end <= start + length {
-                    row_groups_to_read.push(i);
-                    total_rows += rg.num_rows();
-                }
-            }
-            num_row_groups = row_groups_to_read.len();
-            builder = builder
-                .with_projection(mask)
-                .with_row_groups(row_groups_to_read.clone())
-        }
-
-        // Build a sync parquet reader.
-        let batch_reader = builder.build().unwrap();
+        let batch_stream: Option<SendableRecordBatchStream>;
+        let batch_reader: Option<ParquetRecordBatchReader> = None;
+        // TODO: (ARROW NATIVE) Use the common global runtime
+        let runtime = tokio::runtime::Builder::new_multi_thread()
+            .enable_all()
+            .build()?;
+
+        // EXPERIMENTAL - BEGIN
+        //TODO: Need an execution context and a spark plan equivalent so that 
we can reuse
+        // code from jni_api.rs
+        let (object_store_url, object_store_path) = 
get_file_path(path.clone()).unwrap();
+        // TODO: (ARROW NATIVE) - Remove code duplication between this and POC 
1
+        // copy the input on-heap buffer to native
+        let required_schema_array = JByteArray::from_raw(required_schema);
+        let required_schema_buffer = 
env.convert_byte_array(&required_schema_array)?;
+        let required_schema_arrow = 
deserialize_schema(required_schema_buffer.as_bytes())?;
+        let mut partitioned_file = PartitionedFile::new_with_range(
+            String::new(), // Dummy file path. We will override this with our 
path so that url encoding does not occur
+            file_size as u64,
+            start,
+            start + length,
+        );
+        partitioned_file.object_meta.location = object_store_path;
+        // We build the file scan config with the *required* schema so that 
the reader knows
+        // the output schema we want
+        let file_scan_config = FileScanConfig::new(object_store_url, 
Arc::new(required_schema_arrow))
+                .with_file(partitioned_file)
+                // TODO: (ARROW NATIVE) - do partition columns in native
+                //   - will need partition schema and partition values to do so
+                // .with_table_partition_cols(partition_fields)
+                ;
+        let mut table_parquet_options = TableParquetOptions::new();
+        // TODO: Maybe these are configs?
+        table_parquet_options.global.pushdown_filters = true;
+        table_parquet_options.global.reorder_filters = true;
+
+        let builder2 = ParquetExecBuilder::new(file_scan_config)
+            .with_table_parquet_options(table_parquet_options)
+            .with_schema_adapter_factory(Arc::new(
+                
crate::execution::datafusion::schema_adapter::CometSchemaAdapterFactory::default(),
+            ));
+
+        //TODO: (ARROW NATIVE) - predicate pushdown??
+        // builder = builder.with_predicate(filter);
+
+        let scan = builder2.build();
+        let ctx = TaskContext::default();
+        let partition_index: usize = 0;
+        batch_stream = Some(scan.execute(partition_index, Arc::new(ctx))?);
+
+        // EXPERIMENTAL - END
 
         let ctx = BatchContext {
+            runtime,
+            batch_stream,
             batch_reader,
             current_batch: None,
             reader_state: ParquetReaderState::Init,
-            num_row_groups: num_row_groups as i32,
-            total_rows,
         };
         let res = Box::new(ctx);
         Ok(Box::into_raw(res) as i64)
     })
 }
 
-#[no_mangle]
-pub extern "system" fn Java_org_apache_comet_parquet_Native_numRowGroups(
-    e: JNIEnv,
-    _jclass: JClass,
-    handle: jlong,
-) -> jint {
-    try_unwrap_or_throw(&e, |_env| {
-        let context = get_batch_context(handle)?;
-        // Read data
-        Ok(context.num_row_groups)
-    }) as jint
-}
-
-#[no_mangle]
-pub extern "system" fn Java_org_apache_comet_parquet_Native_numTotalRows(
-    e: JNIEnv,
-    _jclass: JClass,
-    handle: jlong,
-) -> jlong {
-    try_unwrap_or_throw(&e, |_env| {
-        let context = get_batch_context(handle)?;
-        // Read data
-        Ok(context.total_rows)
-    }) as jlong
-}
-
 #[no_mangle]
 pub extern "system" fn 
Java_org_apache_comet_parquet_Native_readNextRecordBatch(
     e: JNIEnv,
@@ -748,21 +716,39 @@ pub extern "system" fn 
Java_org_apache_comet_parquet_Native_readNextRecordBatch(
 ) -> jint {
     try_unwrap_or_throw(&e, |_env| {
         let context = get_batch_context(handle)?;
-        let batch_reader = &mut context.batch_reader;
-        // Read data
         let mut rows_read: i32 = 0;
-        let batch = batch_reader.next();
-
-        match batch {
-            Some(record_batch) => {
-                let batch = record_batch?;
-                rows_read = batch.num_rows() as i32;
-                context.current_batch = Some(batch);
-                context.reader_state = ParquetReaderState::Reading;
-            }
-            None => {
-                context.current_batch = None;
-                context.reader_state = ParquetReaderState::Complete;
+        let batch_stream = context.batch_stream.as_mut().unwrap();
+        let runtime = &context.runtime;
+
+        // let mut stream = batch_stream.as_mut();
+        loop {
+            let next_item = batch_stream.next();
+            let poll_batch: 
Poll<Option<datafusion_common::Result<RecordBatch>>> =
+                runtime.block_on(async { poll!(next_item) });
+
+            match poll_batch {
+                Poll::Ready(Some(batch)) => {
+                    let batch = batch?;
+                    rows_read = batch.num_rows() as i32;
+                    context.current_batch = Some(batch);
+                    context.reader_state = ParquetReaderState::Reading;
+                    break;
+                }
+                Poll::Ready(None) => {
+                    // EOF
+
+                    // TODO: (ARROW NATIVE) We can update metrics here
+                    // crate::execution::jni_api::update_metrics(&mut env, 
exec_context)?;
+
+                    context.current_batch = None;
+                    context.reader_state = ParquetReaderState::Complete;
+                    break;
+                }
+                Poll::Pending => {
+                    // TODO: (ARROW NATIVE): Just keep polling??
+                    // Ideally we want to yield to avoid consuming CPU while 
blocked on IO ??
+                    continue;
+                }
             }
         }
         Ok(rows_read)
diff --git a/native/core/src/parquet/util/jni.rs 
b/native/core/src/parquet/util/jni.rs
index b61fbeab..596277b3 100644
--- a/native/core/src/parquet/util/jni.rs
+++ b/native/core/src/parquet/util/jni.rs
@@ -24,11 +24,17 @@ use jni::{
     JNIEnv,
 };
 
+use crate::execution::sort::RdxSort;
+use arrow::error::ArrowError;
+use arrow::ipc::reader::StreamReader;
+use datafusion_execution::object_store::ObjectStoreUrl;
+use object_store::path::Path;
 use parquet::{
     basic::{Encoding, LogicalType, TimeUnit, Type as PhysicalType},
     format::{MicroSeconds, MilliSeconds, NanoSeconds},
     schema::types::{ColumnDescriptor, ColumnPath, PrimitiveTypeBuilder},
 };
+use url::{ParseError, Url};
 
 /// Convert primitives from Spark side into a `ColumnDescriptor`.
 #[allow(clippy::too_many_arguments)]
@@ -198,3 +204,52 @@ fn fix_type_length(t: &PhysicalType, type_length: i32) -> 
i32 {
         _ => type_length,
     }
 }
+
+/// Reads an Arrow `Schema` out of `ipc_bytes` by opening the bytes as an IPC stream.
+///
+/// Any `ArrowError` raised while creating the stream reader is returned to the caller.
+pub fn deserialize_schema(ipc_bytes: &[u8]) -> Result<arrow::datatypes::Schema, ArrowError> {
+    let cursor = std::io::Cursor::new(ipc_bytes);
+    StreamReader::try_new(cursor, None).map(|stream| stream.schema().as_ref().clone())
+}
+
+/// Parses `url_` and returns the object store URL it belongs to together with the
+/// path of the object inside that store.
+///
+/// The store origin is `scheme + "://" + authority + "/"`. For the `s3a` scheme the
+/// scheme is rewritten to `s3` and the first path segment is appended to the origin
+/// (`authority + "/" + bucket`); every remaining segment forms the object path.
+///
+/// # Errors
+/// Returns the `ParseError` produced by `Url::parse` when `url_` is not a valid URL.
+///
+/// # Panics
+/// Panics when the path is rejected by `Path::from_url_path` or the origin is rejected
+/// by `ObjectStoreUrl::parse`; those failures cannot be represented as a `ParseError`.
+pub fn get_file_path(url_: String) -> Result<(ObjectStoreUrl, Path), ParseError> {
+    // Propagate malformed input to the caller instead of panicking.
+    let url = Url::parse(url_.as_ref())?;
+
+    let (object_store_origin, object_store_path) = if url.scheme() == "s3a" {
+        // A URL without path segments yields an empty bucket and an empty object path.
+        let mut segments = url.path_segments().into_iter().flatten();
+        let bucket = segments.next().unwrap_or_default();
+        // Keep every segment after the bucket, including the final one (the object name).
+        let key = segments.collect::<Vec<_>>().join("/");
+        // NOTE(review): the origin built here carries a path component (the bucket);
+        // confirm that `ObjectStoreUrl::parse` accepts it.
+        let origin = format!("s3://{}/{}", url.authority(), bucket);
+        //TODO: (ARROW NATIVE) check the use of unwrap here
+        (origin, Path::from_url_path(key).unwrap())
+    } else {
+        let origin = format!("{}://{}/", url.scheme(), url.authority());
+        (origin, Path::from_url_path(url.path()).unwrap())
+    };
+
+    Ok((
+        ObjectStoreUrl::parse(object_store_origin).unwrap(),
+        object_store_path,
+    ))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// A `file://` URL must map to the percent-decoded path of the file.
+    #[test]
+    fn test_get_file_path() {
+        let inp = "file:///comet/spark-warehouse/t1/part1=2019-01-01%2011%253A11%253A11/part-00000-84d7ed74-8f28-456c-9270-f45376eea144.c000.snappy.parquet";
+        let expected = "comet/spark-warehouse/t1/part1=2019-01-01 11%3A11%3A11/part-00000-84d7ed74-8f28-456c-9270-f45376eea144.c000.snappy.parquet";
+
+        // Fail loudly on a parse error; an `if let Ok(..)` guard would skip the assertion
+        // and let the test pass without checking anything.
+        let (_obj_store_url, path) =
+            get_file_path(inp.to_string()).expect("a well-formed file URL must parse");
+        assert_eq!(path.to_string(), expected);
+    }
+}
diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs
index a6d13971..95fde373 100644
--- a/native/spark-expr/src/cast.rs
+++ b/native/spark-expr/src/cast.rs
@@ -571,7 +571,7 @@ impl SparkCastOptions {
             eval_mode,
             timezone: timezone.to_string(),
             allow_incompat,
-            is_adapting_schema: false
+            is_adapting_schema: false,
         }
     }
 
@@ -583,7 +583,6 @@ impl SparkCastOptions {
             is_adapting_schema: false,
         }
     }
-
 }
 
 /// Spark-compatible cast implementation. Defers to DataFusion's cast where 
that is known
@@ -2309,8 +2308,7 @@ mod tests {
     #[test]
     fn test_cast_invalid_timezone() {
         let timestamps: PrimitiveArray<TimestampMicrosecondType> = 
vec![i64::MAX].into();
-        let cast_options =
-            SparkCastOptions::new(EvalMode::Legacy, "Not a valid timezone", 
false);
+        let cast_options = SparkCastOptions::new(EvalMode::Legacy, "Not a 
valid timezone", false);
         let result = cast_array(
             Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
             &DataType::Date32,
@@ -2401,9 +2399,7 @@ mod tests {
         let cast_array = spark_cast(
             ColumnarValue::Array(c),
             &DataType::Struct(fields),
-            &SparkCastOptions::new(EvalMode::Legacy,
-            "UTC",
-            false)
+            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
         )
         .unwrap();
         if let ColumnarValue::Array(cast_array) = cast_array {
diff --git a/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala 
b/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala
index eb524af9..e4235495 100644
--- a/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala
+++ b/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala
@@ -41,7 +41,7 @@ trait DataTypeSupport {
     case t: DataType if t.typeName == "timestamp_ntz" => true
     case _: StructType
         if CometConf.COMET_FULL_NATIVE_SCAN_ENABLED
-          .get() || CometConf.COMET_NATIVE_ARROW_SCAN_ENABLED.get() =>
+          .get() || CometConf.COMET_NATIVE_RECORDBATCH_READER_ENABLED.get() =>
       true
     case _ => false
   }
diff --git 
a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala 
b/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala
index 4c96bef4..c142abb5 100644
--- a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala
+++ b/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala
@@ -100,7 +100,7 @@ class CometParquetFileFormat extends ParquetFileFormat with 
MetricsSupport with
 
     // Comet specific configurations
     val capacity = CometConf.COMET_BATCH_SIZE.get(sqlConf)
-    val nativeArrowReaderEnabled = 
CometConf.COMET_NATIVE_ARROW_SCAN_ENABLED.get(sqlConf)
+    val nativeArrowReaderEnabled = 
CometConf.COMET_NATIVE_RECORDBATCH_READER_ENABLED.get(sqlConf)
 
     (file: PartitionedFile) => {
       val sharedConf = broadcastedHadoopConf.value.value
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala 
b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 99ed5d3c..e997c5bf 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -80,8 +80,8 @@ abstract class CometTestBase
     conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true")
     conf.set(CometConf.COMET_SPARK_TO_ARROW_ENABLED.key, "true")
     conf.set(CometConf.COMET_NATIVE_SCAN_ENABLED.key, "true")
-    conf.set(CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.key, "true")
-    conf.set(CometConf.COMET_NATIVE_ARROW_SCAN_ENABLED.key, "false")
+    conf.set(CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.key, "false")
+    conf.set(CometConf.COMET_NATIVE_RECORDBATCH_READER_ENABLED.key, "true")
     conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "2g")
     
conf.set(CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key, 
"true")
     conf
diff --git 
a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala 
b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
index a553e61c..080655fe 100644
--- 
a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
+++ 
b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
@@ -89,9 +89,13 @@ trait CometPlanStabilitySuite extends 
DisableAdaptiveExecutionSuite with TPCDSBa
       actualSimplifiedPlan: String,
       actualExplain: String): Boolean = {
     val simplifiedFile = new File(dir, "simplified.txt")
-    val expectedSimplified = FileUtils.readFileToString(simplifiedFile, 
StandardCharsets.UTF_8)
-    lazy val explainFile = new File(dir, "explain.txt")
-    lazy val expectedExplain = FileUtils.readFileToString(explainFile, 
StandardCharsets.UTF_8)
+    var expectedSimplified = FileUtils.readFileToString(simplifiedFile, 
StandardCharsets.UTF_8)
+    val explainFile = new File(dir, "explain.txt")
+    var expectedExplain = FileUtils.readFileToString(explainFile, 
StandardCharsets.UTF_8)
+    if (!CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.get()) {
+      expectedExplain = expectedExplain.replace("CometNativeScan", "CometScan")
+      expectedSimplified = expectedSimplified.replace("CometNativeScan", 
"CometScan")
+    }
     expectedSimplified == actualSimplifiedPlan && expectedExplain == 
actualExplain
   }
 
@@ -259,6 +263,9 @@ trait CometPlanStabilitySuite extends 
DisableAdaptiveExecutionSuite with TPCDSBa
     // Disable char/varchar read-side handling for better performance.
     withSQLConf(
       CometConf.COMET_ENABLED.key -> "true",
+      CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "true",
+      CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.key -> "false",
+      CometConf.COMET_NATIVE_RECORDBATCH_READER_ENABLED.key -> "true",
       CometConf.COMET_EXEC_ENABLED.key -> "true",
       CometConf.COMET_DPP_FALLBACK_ENABLED.key -> "false",
       CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
@@ -288,6 +295,9 @@ trait CometPlanStabilitySuite extends 
DisableAdaptiveExecutionSuite with TPCDSBa
       "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
     conf.set(CometConf.COMET_ENABLED.key, "true")
     conf.set(CometConf.COMET_EXEC_ENABLED.key, "true")
+    conf.set(CometConf.COMET_NATIVE_SCAN_ENABLED.key, "true")
+    conf.set(CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.key, "false")
+    conf.set(CometConf.COMET_NATIVE_RECORDBATCH_READER_ENABLED.key, "true")
     conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "1g")
     conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true")
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to