This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch comet-parquet-exec
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/comet-parquet-exec by this push:
new 2686a4b6 feat: [comet-parquet-exec] Use Datafusion based record batch reader for use in iceberg reads (#1174)
2686a4b6 is described below
commit 2686a4b69e40d7f18c5cf9b488f17b64ba40c5b7
Author: Parth Chandra <[email protected]>
AuthorDate: Tue Dec 17 15:41:39 2024 -0800
feat: [comet-parquet-exec] Use Datafusion based record batch reader for use in iceberg reads (#1174)
* wip. Use DF's ParquetExec for Iceberg API
* wip - await??
* wip
* wip -
* fix shading issue
* fix shading issue
* fixes
* refactor to remove arrow based reader
* rename config
* Fix config defaults
---------
Co-authored-by: Andy Grove <[email protected]>
---
.../main/java/org/apache/comet/parquet/Native.java | 7 +-
.../apache/comet/parquet/NativeBatchReader.java | 36 ++--
.../main/scala/org/apache/comet/CometConf.scala | 12 +-
.../apache/spark/sql/comet/CometArrowUtils.scala | 180 +++++++++++++++++
native/Cargo.lock | 2 +
native/Cargo.toml | 2 +
native/core/Cargo.toml | 2 +
native/core/src/execution/datafusion/mod.rs | 2 +-
native/core/src/parquet/mod.rs | 224 ++++++++++-----------
native/core/src/parquet/util/jni.rs | 55 +++++
native/spark-expr/src/cast.rs | 10 +-
.../scala/org/apache/comet/DataTypeSupport.scala | 2 +-
.../comet/parquet/CometParquetFileFormat.scala | 2 +-
.../scala/org/apache/spark/sql/CometTestBase.scala | 4 +-
.../spark/sql/comet/CometPlanStabilitySuite.scala | 16 +-
15 files changed, 396 insertions(+), 160 deletions(-)
diff --git a/common/src/main/java/org/apache/comet/parquet/Native.java b/common/src/main/java/org/apache/comet/parquet/Native.java
index 1ed01d32..b33ec60d 100644
--- a/common/src/main/java/org/apache/comet/parquet/Native.java
+++ b/common/src/main/java/org/apache/comet/parquet/Native.java
@@ -246,15 +246,10 @@ public final class Native extends NativeBase {
* @param filePath
* @param start
* @param length
- * @param required_columns array of names of fields to read
* @return a handle to the record batch reader, used in subsequent calls.
*/
public static native long initRecordBatchReader(
- String filePath, long start, long length, Object[] required_columns);
-
- public static native int numRowGroups(long handle);
-
- public static native long numTotalRows(long handle);
+ String filePath, long fileSize, long start, long length, byte[] requiredSchema);
// arrow native version of read batch
/**
diff --git a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java
index 3ac55ba4..8461bb50 100644
--- a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java
+++ b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java
@@ -19,11 +19,13 @@
package org.apache.comet.parquet;
+import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URISyntaxException;
+import java.nio.channels.Channels;
import java.util.*;
import scala.Option;
@@ -36,6 +38,9 @@ import org.slf4j.LoggerFactory;
import org.apache.arrow.c.CometSchemaImporter;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.ipc.WriteChannel;
+import org.apache.arrow.vector.ipc.message.MessageSerializer;
+import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapreduce.InputSplit;
@@ -52,6 +57,7 @@ import org.apache.spark.TaskContext;
import org.apache.spark.TaskContext$;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.comet.CometArrowUtils;
import org.apache.spark.sql.comet.parquet.CometParquetReadSupport;
import org.apache.spark.sql.execution.datasources.PartitionedFile;
import org.apache.spark.sql.execution.datasources.parquet.ParquetToSparkSchemaConverter;
@@ -99,7 +105,6 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
private PartitionedFile file;
private final Map<String, SQLMetric> metrics;
- private long rowsRead;
private StructType sparkSchema;
private MessageType requestedSchema;
private CometVector[] vectors;
@@ -111,9 +116,6 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
private boolean isInitialized;
private ParquetMetadata footer;
- /** The total number of rows across all row groups of the input split. */
- private long totalRowCount;
-
/**
* Whether the native scan should always return decimal represented by 128 bits, regardless of its
* precision. Normally, this should be true if native execution is enabled, since Arrow compute
@@ -224,6 +226,7 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
long start = file.start();
long length = file.length();
String filePath = file.filePath().toString();
+ long fileSize = file.fileSize();
requestedSchema = footer.getFileMetaData().getSchema();
MessageType fileSchema = requestedSchema;
@@ -254,6 +257,13 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
}
} ////// End get requested schema
+ String timeZoneId = conf.get("spark.sql.session.timeZone");
+ Schema arrowSchema = CometArrowUtils.toArrowSchema(sparkSchema, timeZoneId);
+ ByteArrayOutputStream out = new ByteArrayOutputStream();
+ WriteChannel writeChannel = new WriteChannel(Channels.newChannel(out));
+ MessageSerializer.serialize(writeChannel, arrowSchema);
+ byte[] serializedRequestedArrowSchema = out.toByteArray();
+
//// Create Column readers
List<ColumnDescriptor> columns = requestedSchema.getColumns();
int numColumns = columns.size();
@@ -334,13 +344,9 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
}
}
- // TODO: (ARROW NATIVE) Use a ProjectionMask here ?
- ArrayList<String> requiredColumns = new ArrayList<>();
- for (Type col : requestedSchema.asGroupType().getFields()) {
- requiredColumns.add(col.getName());
- }
- this.handle = Native.initRecordBatchReader(filePath, start, length, requiredColumns.toArray());
- totalRowCount = Native.numRowGroups(handle);
+ this.handle =
+ Native.initRecordBatchReader(
+ filePath, fileSize, start, length, serializedRequestedArrowSchema);
isInitialized = true;
}
@@ -375,7 +381,7 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
@Override
public float getProgress() {
- return (float) rowsRead / totalRowCount;
+ return 0;
}
/**
@@ -395,7 +401,7 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
public boolean nextBatch() throws IOException {
Preconditions.checkState(isInitialized, "init() should be called first!");
- if (rowsRead >= totalRowCount) return false;
+ // if (rowsRead >= totalRowCount) return false;
int batchSize;
try {
@@ -432,7 +438,6 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
}
currentBatch.setNumRows(batchSize);
- rowsRead += batchSize;
return true;
}
@@ -457,6 +462,9 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
long startNs = System.nanoTime();
int batchSize = Native.readNextRecordBatch(this.handle);
+ if (batchSize == 0) {
+ return batchSize;
+ }
if (importer != null) importer.close();
importer = new CometSchemaImporter(ALLOCATOR);
diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index 275114a1..fabdd30c 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -75,7 +75,7 @@ object CometConf extends ShimCometConf {
"that to enable native vectorized execution, both this config and " +
"'spark.comet.exec.enabled' need to be enabled.")
.booleanConf
- .createWithDefault(true)
+ .createWithDefault(false)
val COMET_FULL_NATIVE_SCAN_ENABLED: ConfigEntry[Boolean] = conf(
"spark.comet.native.scan.enabled")
@@ -85,15 +85,15 @@ object CometConf extends ShimCometConf {
"read supported data sources (currently only Parquet is supported
natively)." +
" By default, this config is true.")
.booleanConf
- .createWithDefault(false)
+ .createWithDefault(true)
- val COMET_NATIVE_ARROW_SCAN_ENABLED: ConfigEntry[Boolean] = conf(
+ val COMET_NATIVE_RECORDBATCH_READER_ENABLED: ConfigEntry[Boolean] = conf(
"spark.comet.native.arrow.scan.enabled")
.internal()
.doc(
- "Whether to enable the fully native arrow based scan. When this is
turned on, Spark will " +
- "use Comet to read Parquet files natively via the Arrow based Parquet
reader." +
- " By default, this config is false.")
+ "Whether to enable the fully native datafusion based column reader. When
this is turned on," +
+ " Spark will use Comet to read Parquet files natively via the
Datafusion based Parquet" +
+ " reader. By default, this config is false.")
.booleanConf
.createWithDefault(false)
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/CometArrowUtils.scala b/common/src/main/scala/org/apache/spark/sql/comet/CometArrowUtils.scala
new file mode 100644
index 00000000..2f4f55fc
--- /dev/null
+++ b/common/src/main/scala/org/apache/spark/sql/comet/CometArrowUtils.scala
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.complex.MapVector
+import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit}
+import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+object CometArrowUtils {
+
+ val rootAllocator = new RootAllocator(Long.MaxValue)
+
+ // todo: support more types.
+
+ /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */
+ def toArrowType(dt: DataType, timeZoneId: String): ArrowType = dt match {
+ case BooleanType => ArrowType.Bool.INSTANCE
+ case ByteType => new ArrowType.Int(8, true)
+ case ShortType => new ArrowType.Int(8 * 2, true)
+ case IntegerType => new ArrowType.Int(8 * 4, true)
+ case LongType => new ArrowType.Int(8 * 8, true)
+ case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
+ case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
+ case StringType => ArrowType.Utf8.INSTANCE
+ case BinaryType => ArrowType.Binary.INSTANCE
+ case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale)
+ case DateType => new ArrowType.Date(DateUnit.DAY)
+ case TimestampType if timeZoneId == null =>
+ throw new IllegalStateException("Missing timezoneId where it is
mandatory.")
+ case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND,
timeZoneId)
+ case TimestampNTZType =>
+ new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)
+ case NullType => ArrowType.Null.INSTANCE
+ case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
+ case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND)
+ case _ =>
+ throw new IllegalArgumentException()
+ }
+
+ def fromArrowType(dt: ArrowType): DataType = dt match {
+ case ArrowType.Bool.INSTANCE => BooleanType
+ case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType
+ case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType
+ case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType
+ case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType
+ case float: ArrowType.FloatingPoint
+ if float.getPrecision() == FloatingPointPrecision.SINGLE =>
+ FloatType
+ case float: ArrowType.FloatingPoint
+ if float.getPrecision() == FloatingPointPrecision.DOUBLE =>
+ DoubleType
+ case ArrowType.Utf8.INSTANCE => StringType
+ case ArrowType.Binary.INSTANCE => BinaryType
+ case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
+ case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
+ case ts: ArrowType.Timestamp
+ if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null =>
+ TimestampNTZType
+ case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType
+ case ArrowType.Null.INSTANCE => NullType
+ case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH =>
+ YearMonthIntervalType()
+ case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType()
+ case _ => throw new IllegalArgumentException()
+ // throw QueryExecutionErrors.unsupportedArrowTypeError(dt)
+ }
+
+ /** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */
+ def toArrowField(name: String, dt: DataType, nullable: Boolean, timeZoneId: String): Field = {
+ dt match {
+ case ArrayType(elementType, containsNull) =>
+ val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null)
+ new Field(
+ name,
+ fieldType,
+ Seq(toArrowField("element", elementType, containsNull,
timeZoneId)).asJava)
+ case StructType(fields) =>
+ val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
+ new Field(
+ name,
+ fieldType,
+ fields
+ .map { field =>
+ toArrowField(field.name, field.dataType, field.nullable, timeZoneId)
+ }
+ .toSeq
+ .asJava)
+ case MapType(keyType, valueType, valueContainsNull) =>
+ val mapType = new FieldType(nullable, new ArrowType.Map(false), null)
+ // Note: Map Type struct can not be null, Struct Type key field can not be null
+ new Field(
+ name,
+ mapType,
+ Seq(
+ toArrowField(
+ MapVector.DATA_VECTOR_NAME,
+ new StructType()
+ .add(MapVector.KEY_NAME, keyType, nullable = false)
+ .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull),
+ nullable = false,
+ timeZoneId)).asJava)
+ case udt: UserDefinedType[_] => toArrowField(name, udt.sqlType, nullable, timeZoneId)
+ case dataType =>
+ val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId), null)
+ new Field(name, fieldType, Seq.empty[Field].asJava)
+ }
+ }
+
+ def fromArrowField(field: Field): DataType = {
+ field.getType match {
+ case _: ArrowType.Map =>
+ val elementField = field.getChildren.get(0)
+ val keyType = fromArrowField(elementField.getChildren.get(0))
+ val valueType = fromArrowField(elementField.getChildren.get(1))
+ MapType(keyType, valueType, elementField.getChildren.get(1).isNullable)
+ case ArrowType.List.INSTANCE =>
+ val elementField = field.getChildren().get(0)
+ val elementType = fromArrowField(elementField)
+ ArrayType(elementType, containsNull = elementField.isNullable)
+ case ArrowType.Struct.INSTANCE =>
+ val fields = field.getChildren().asScala.map { child =>
+ val dt = fromArrowField(child)
+ StructField(child.getName, dt, child.isNullable)
+ }
+ StructType(fields.toArray)
+ case arrowType => fromArrowType(arrowType)
+ }
+ }
+
+ /**
+ * Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType
+ */
+ def toArrowSchema(schema: StructType, timeZoneId: String): Schema = {
+ new Schema(schema.map { field =>
+ toArrowField(field.name, field.dataType, field.nullable, timeZoneId)
+ }.asJava)
+ }
+
+ def fromArrowSchema(schema: Schema): StructType = {
+ StructType(schema.getFields.asScala.map { field =>
+ val dt = fromArrowField(field)
+ StructField(field.getName, dt, field.isNullable)
+ }.toArray)
+ }
+
+ /** Return Map with conf settings to be used in ArrowPythonRunner */
+ def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
+ val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone)
+ val pandasColsByName = Seq(
+ SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key -> conf.pandasGroupedMapAssignColumnsByName.toString)
+ val arrowSafeTypeCheck = Seq(
+ SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key -> conf.arrowSafeTypeConversion.toString)
+ Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
+ }
+
+}
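For context, the schema that CometArrowUtils.toArrowSchema builds on the JVM side deserializes into ordinary arrow-rs types once it reaches the native reader. A minimal illustrative sketch follows; the column names, the DECIMAL(10, 2) type, and the "UTC" session timezone are invented for the example and are not part of the commit:

    use std::sync::Arc;
    use arrow::datatypes::{DataType, Field, Schema, TimeUnit};

    // Hypothetical Spark schema (id INT, name STRING, ts TIMESTAMP, price DECIMAL(10, 2))
    // as it would roughly look after CometArrowUtils.toArrowSchema with a "UTC" session timezone.
    fn example_arrow_schema() -> Schema {
        Schema::new(vec![
            Field::new("id", DataType::Int32, true),
            Field::new("name", DataType::Utf8, true),
            // Spark TimestampType carries the session timezone; TimestampNTZType would map to None.
            Field::new(
                "ts",
                DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))),
                true,
            ),
            Field::new("price", DataType::Decimal128(10, 2), true),
        ])
    }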
diff --git a/native/Cargo.lock b/native/Cargo.lock
index c3a664ff..27e97268 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -898,6 +898,7 @@ dependencies = [
"arrow-array",
"arrow-buffer",
"arrow-data",
+ "arrow-ipc",
"arrow-schema",
"assertables",
"async-trait",
@@ -913,6 +914,7 @@ dependencies = [
"datafusion-expr",
"datafusion-functions-nested",
"datafusion-physical-expr",
+ "flatbuffers",
"flate2",
"futures",
"half",
diff --git a/native/Cargo.toml b/native/Cargo.toml
index 4b89231c..b78c1d68 100644
--- a/native/Cargo.toml
+++ b/native/Cargo.toml
@@ -37,7 +37,9 @@ arrow = { version = "53.2.0", features = ["prettyprint", "ffi", "chrono-tz"] }
arrow-array = { version = "53.2.0" }
arrow-buffer = { version = "53.2.0" }
arrow-data = { version = "53.2.0" }
+arrow-ipc = { version = "53.2.0" }
arrow-schema = { version = "53.2.0" }
+flatbuffers = { version = "24.3.25" }
parquet = { version = "53.2.0", default-features = false, features =
["experimental"] }
datafusion-common = { version = "43.0.0" }
datafusion = { version = "43.0.0", default-features = false, features =
["unicode_expressions", "crypto_expressions", "parquet"] }
diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml
index 8d30b38c..35035ff3 100644
--- a/native/core/Cargo.toml
+++ b/native/core/Cargo.toml
@@ -40,6 +40,8 @@ arrow-array = { workspace = true }
arrow-buffer = { workspace = true }
arrow-data = { workspace = true }
arrow-schema = { workspace = true }
+arrow-ipc = { workspace = true }
+flatbuffers = { workspace = true }
parquet = { workspace = true, default-features = false, features = ["experimental"] }
half = { version = "2.4.1", default-features = false }
futures = "0.3.28"
diff --git a/native/core/src/execution/datafusion/mod.rs b/native/core/src/execution/datafusion/mod.rs
index fb9c8829..af32b4be 100644
--- a/native/core/src/execution/datafusion/mod.rs
+++ b/native/core/src/execution/datafusion/mod.rs
@@ -20,6 +20,6 @@
pub mod expressions;
mod operators;
pub mod planner;
-mod schema_adapter;
+pub(crate) mod schema_adapter;
pub mod shuffle_writer;
mod util;
diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs
index afca6066..c234b6f7 100644
--- a/native/core/src/parquet/mod.rs
+++ b/native/core/src/parquet/mod.rs
@@ -23,7 +23,7 @@ pub use mutable_vector::*;
pub mod util;
pub mod read;
-use std::fs::File;
+use std::task::Poll;
use std::{boxed::Box, ptr::NonNull, sync::Arc};
use crate::errors::{try_unwrap_or_throw, CometError};
@@ -42,17 +42,21 @@ use jni::{
use crate::execution::operators::ExecutionError;
use crate::execution::utils::SparkArrowConvert;
+use crate::parquet::data_type::AsBytes;
use arrow::buffer::{Buffer, MutableBuffer};
use arrow_array::{Array, RecordBatch};
-use jni::objects::{
- JBooleanArray, JLongArray, JObjectArray, JPrimitiveArray, JString, ReleaseMode,
-};
+use datafusion::datasource::listing::PartitionedFile;
+use datafusion::datasource::physical_plan::parquet::ParquetExecBuilder;
+use datafusion::datasource::physical_plan::FileScanConfig;
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion_common::config::TableParquetOptions;
+use datafusion_execution::{SendableRecordBatchStream, TaskContext};
+use futures::{poll, StreamExt};
+use jni::objects::{JBooleanArray, JByteArray, JLongArray, JPrimitiveArray, JString, ReleaseMode};
use jni::sys::jstring;
-use parquet::arrow::arrow_reader::{ParquetRecordBatchReader, ParquetRecordBatchReaderBuilder};
-use parquet::arrow::ProjectionMask;
+use parquet::arrow::arrow_reader::ParquetRecordBatchReader;
use read::ColumnReader;
-use url::Url;
-use util::jni::{convert_column_descriptor, convert_encoding};
+use util::jni::{convert_column_descriptor, convert_encoding, deserialize_schema, get_file_path};
use self::util::jni::TypePromotionInfo;
@@ -600,11 +604,11 @@ enum ParquetReaderState {
}
/// Parquet read context maintained across multiple JNI calls.
struct BatchContext {
- batch_reader: ParquetRecordBatchReader,
+ runtime: tokio::runtime::Runtime,
+ batch_stream: Option<SendableRecordBatchStream>,
+ batch_reader: Option<ParquetRecordBatchReader>,
current_batch: Option<RecordBatch>,
reader_state: ParquetReaderState,
- num_row_groups: i32,
- total_rows: i64,
}
#[inline]
@@ -616,10 +620,12 @@ fn get_batch_context<'a>(handle: jlong) -> Result<&'a mut BatchContext, CometErr
}
}
+/*
#[inline]
fn get_batch_reader<'a>(handle: jlong) -> Result<&'a mut ParquetRecordBatchReader, CometError> {
- Ok(&mut get_batch_context(handle)?.batch_reader)
+ Ok(&mut get_batch_context(handle)?.batch_reader.unwrap())
}
+*/
/// # Safety
/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
@@ -628,118 +634,80 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat
e: JNIEnv,
_jclass: JClass,
file_path: jstring,
+ file_size: jlong,
start: jlong,
length: jlong,
- required_columns: jobjectArray,
+ required_schema: jbyteArray,
) -> jlong {
try_unwrap_or_throw(&e, |mut env| unsafe {
let path: String = env
.get_string(&JString::from_raw(file_path))
.unwrap()
.into();
- //TODO: (ARROW NATIVE) - this works only for 'file://' urls
- let path = Url::parse(path.as_ref()).unwrap().to_file_path().unwrap();
- let file = File::open(path).unwrap();
-
- // Create a async parquet reader builder with batch_size.
- // batch_size is the number of rows to read up to buffer once from pages, defaults to 1024
- // TODO: (ARROW NATIVE) Use async reader ParquetRecordBatchStreamBuilder
- let mut builder = ParquetRecordBatchReaderBuilder::try_new(file)
- .unwrap()
- .with_batch_size(8192); // TODO: (ARROW NATIVE) Use batch size configured in JVM
-
- let num_row_groups;
- let mut total_rows: i64 = 0;
- //TODO: (ARROW NATIVE) if we can get the ParquetMetadata serialized, we need not do this.
- {
- let metadata = builder.metadata();
-
- let mut columns_to_read: Vec<usize> = Vec::new();
- let columns_to_read_array = JObjectArray::from_raw(required_columns);
- let array_len = env.get_array_length(&columns_to_read_array)?;
- let mut required_columns: Vec<String> = Vec::new();
- for i in 0..array_len {
- let p: JString = env
- .get_object_array_element(&columns_to_read_array, i)?
- .into();
- required_columns.push(env.get_string(&p)?.into());
- }
- for (i, col) in metadata
- .file_metadata()
- .schema_descr()
- .columns()
- .iter()
- .enumerate()
- {
- for required in required_columns.iter() {
- if col.name().to_uppercase().eq(&required.to_uppercase()) {
- columns_to_read.push(i);
- break;
- }
- }
- }
- //TODO: (ARROW NATIVE) make this work for complex types (especially deeply nested structs)
- let mask =
- ProjectionMask::leaves(metadata.file_metadata().schema_descr(), columns_to_read);
- // Set projection mask to read only root columns 1 and 2.
-
- let mut row_groups_to_read: Vec<usize> = Vec::new();
- // get row groups -
- for (i, rg) in metadata.row_groups().iter().enumerate() {
- let rg_start = rg.file_offset().unwrap();
- let rg_end = rg_start + rg.compressed_size();
- if rg_start >= start && rg_end <= start + length {
- row_groups_to_read.push(i);
- total_rows += rg.num_rows();
- }
- }
- num_row_groups = row_groups_to_read.len();
- builder = builder
- .with_projection(mask)
- .with_row_groups(row_groups_to_read.clone())
- }
-
- // Build a sync parquet reader.
- let batch_reader = builder.build().unwrap();
+ let batch_stream: Option<SendableRecordBatchStream>;
+ let batch_reader: Option<ParquetRecordBatchReader> = None;
+ // TODO: (ARROW NATIVE) Use the common global runtime
+ let runtime = tokio::runtime::Builder::new_multi_thread()
+ .enable_all()
+ .build()?;
+
+ // EXPERIMENTAL - BEGIN
+ //TODO: Need an execution context and a spark plan equivalent so that we can reuse
+ // code from jni_api.rs
+ let (object_store_url, object_store_path) = get_file_path(path.clone()).unwrap();
+ // TODO: (ARROW NATIVE) - Remove code duplication between this and POC 1
+ // copy the input on-heap buffer to native
+ let required_schema_array = JByteArray::from_raw(required_schema);
+ let required_schema_buffer = env.convert_byte_array(&required_schema_array)?;
+ let required_schema_arrow = deserialize_schema(required_schema_buffer.as_bytes())?;
+ let mut partitioned_file = PartitionedFile::new_with_range(
+ String::new(), // Dummy file path. We will override this with our path so that url encoding does not occur
+ file_size as u64,
+ start,
+ start + length,
+ );
+ partitioned_file.object_meta.location = object_store_path;
+ // We build the file scan config with the *required* schema so that the reader knows
+ // the output schema we want
+ let file_scan_config = FileScanConfig::new(object_store_url, Arc::new(required_schema_arrow))
+ .with_file(partitioned_file)
+ // TODO: (ARROW NATIVE) - do partition columns in native
+ // - will need partition schema and partition values to do so
+ // .with_table_partition_cols(partition_fields)
+ ;
+ let mut table_parquet_options = TableParquetOptions::new();
+ // TODO: Maybe these are configs?
+ table_parquet_options.global.pushdown_filters = true;
+ table_parquet_options.global.reorder_filters = true;
+
+ let builder2 = ParquetExecBuilder::new(file_scan_config)
+ .with_table_parquet_options(table_parquet_options)
+ .with_schema_adapter_factory(Arc::new(
+ crate::execution::datafusion::schema_adapter::CometSchemaAdapterFactory::default(),
+ ));
+
+ //TODO: (ARROW NATIVE) - predicate pushdown??
+ // builder = builder.with_predicate(filter);
+
+ let scan = builder2.build();
+ let ctx = TaskContext::default();
+ let partition_index: usize = 0;
+ batch_stream = Some(scan.execute(partition_index, Arc::new(ctx))?);
+
+ // EXPERIMENTAL - END
let ctx = BatchContext {
+ runtime,
+ batch_stream,
batch_reader,
current_batch: None,
reader_state: ParquetReaderState::Init,
- num_row_groups: num_row_groups as i32,
- total_rows,
};
let res = Box::new(ctx);
Ok(Box::into_raw(res) as i64)
})
}
-#[no_mangle]
-pub extern "system" fn Java_org_apache_comet_parquet_Native_numRowGroups(
- e: JNIEnv,
- _jclass: JClass,
- handle: jlong,
-) -> jint {
- try_unwrap_or_throw(&e, |_env| {
- let context = get_batch_context(handle)?;
- // Read data
- Ok(context.num_row_groups)
- }) as jint
-}
-
-#[no_mangle]
-pub extern "system" fn Java_org_apache_comet_parquet_Native_numTotalRows(
- e: JNIEnv,
- _jclass: JClass,
- handle: jlong,
-) -> jlong {
- try_unwrap_or_throw(&e, |_env| {
- let context = get_batch_context(handle)?;
- // Read data
- Ok(context.total_rows)
- }) as jlong
-}
-
#[no_mangle]
pub extern "system" fn
Java_org_apache_comet_parquet_Native_readNextRecordBatch(
e: JNIEnv,
@@ -748,21 +716,39 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_readNextRecordBatch(
) -> jint {
try_unwrap_or_throw(&e, |_env| {
let context = get_batch_context(handle)?;
- let batch_reader = &mut context.batch_reader;
- // Read data
let mut rows_read: i32 = 0;
- let batch = batch_reader.next();
-
- match batch {
- Some(record_batch) => {
- let batch = record_batch?;
- rows_read = batch.num_rows() as i32;
- context.current_batch = Some(batch);
- context.reader_state = ParquetReaderState::Reading;
- }
- None => {
- context.current_batch = None;
- context.reader_state = ParquetReaderState::Complete;
+ let batch_stream = context.batch_stream.as_mut().unwrap();
+ let runtime = &context.runtime;
+
+ // let mut stream = batch_stream.as_mut();
+ loop {
+ let next_item = batch_stream.next();
+ let poll_batch: Poll<Option<datafusion_common::Result<RecordBatch>>> =
+ runtime.block_on(async { poll!(next_item) });
+
+ match poll_batch {
+ Poll::Ready(Some(batch)) => {
+ let batch = batch?;
+ rows_read = batch.num_rows() as i32;
+ context.current_batch = Some(batch);
+ context.reader_state = ParquetReaderState::Reading;
+ break;
+ }
+ Poll::Ready(None) => {
+ // EOF
+
+ // TODO: (ARROW NATIVE) We can update metrics here
+ // crate::execution::jni_api::update_metrics(&mut env, exec_context)?;
+
+ context.current_batch = None;
+ context.reader_state = ParquetReaderState::Complete;
+ break;
+ }
+ Poll::Pending => {
+ // TODO: (ARROW NATIVE): Just keeping polling??
+ // Ideally we want to yield to avoid consuming CPU while blocked on IO ??
+ continue;
+ }
}
}
Ok(rows_read)
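A note on the loop above: on Poll::Pending it keeps polling, which is what the TODO about yielding refers to. A possible alternative, sketched here under the same types used in this file (the helper name next_batch is hypothetical and not part of the commit), is to block on the stream's next() future and let the tokio runtime park the thread while the scan waits on IO:

    use arrow_array::RecordBatch;
    use datafusion_common::DataFusionError;
    use datafusion_execution::SendableRecordBatchStream;
    use futures::StreamExt;
    use tokio::runtime::Runtime;

    // Drive the DataFusion stream synchronously: block_on(next()) sleeps instead of
    // busy-polling, and a None return signals end of stream, like Poll::Ready(None) above.
    fn next_batch(
        runtime: &Runtime,
        stream: &mut SendableRecordBatchStream,
    ) -> Result<Option<RecordBatch>, DataFusionError> {
        runtime.block_on(stream.next()).transpose()
    }

transpose() flips the Option<Result<RecordBatch>> yielded by the stream into Result<Option<RecordBatch>>, so the JNI wrapper can keep returning 0 rows at end of stream.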
diff --git a/native/core/src/parquet/util/jni.rs b/native/core/src/parquet/util/jni.rs
index b61fbeab..596277b3 100644
--- a/native/core/src/parquet/util/jni.rs
+++ b/native/core/src/parquet/util/jni.rs
@@ -24,11 +24,17 @@ use jni::{
JNIEnv,
};
+use crate::execution::sort::RdxSort;
+use arrow::error::ArrowError;
+use arrow::ipc::reader::StreamReader;
+use datafusion_execution::object_store::ObjectStoreUrl;
+use object_store::path::Path;
use parquet::{
basic::{Encoding, LogicalType, TimeUnit, Type as PhysicalType},
format::{MicroSeconds, MilliSeconds, NanoSeconds},
schema::types::{ColumnDescriptor, ColumnPath, PrimitiveTypeBuilder},
};
+use url::{ParseError, Url};
/// Convert primitives from Spark side into a `ColumnDescriptor`.
#[allow(clippy::too_many_arguments)]
@@ -198,3 +204,52 @@ fn fix_type_length(t: &PhysicalType, type_length: i32) -> i32 {
_ => type_length,
}
}
+
+pub fn deserialize_schema(ipc_bytes: &[u8]) -> Result<arrow::datatypes::Schema, ArrowError> {
+ let reader = StreamReader::try_new(std::io::Cursor::new(ipc_bytes), None)?;
+ let schema = reader.schema().as_ref().clone();
+ Ok(schema)
+}
+
+// parses the url and returns a tuple of the scheme and object store path
+pub fn get_file_path(url_: String) -> Result<(ObjectStoreUrl, Path), ParseError> {
+ // we define origin of a url as scheme + "://" + authority + ["/" + bucket]
+ let url = Url::parse(url_.as_ref()).unwrap();
+ let mut object_store_origin = url.scheme().to_owned();
+ let mut object_store_path = Path::from_url_path(url.path()).unwrap();
+ if object_store_origin == "s3a" {
+ object_store_origin = "s3".to_string();
+ object_store_origin.push_str("://");
+ object_store_origin.push_str(url.authority());
+ object_store_origin.push('/');
+ let path_splits = url.path_segments().map(|c| c.collect::<Vec<_>>()).unwrap();
+ object_store_origin.push_str(path_splits.first().unwrap());
+ let new_path = path_splits[1..path_splits.len() - 1].join("/");
+ //TODO: (ARROW NATIVE) check the use of unwrap here
+ object_store_path = Path::from_url_path(new_path.clone().as_str()).unwrap();
+ } else {
+ object_store_origin.push_str("://");
+ object_store_origin.push_str(url.authority());
+ object_store_origin.push('/');
+ }
+ Ok((
+ ObjectStoreUrl::parse(object_store_origin).unwrap(),
+ object_store_path,
+ ))
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_get_file_path() {
+ let inp = "file:///comet/spark-warehouse/t1/part1=2019-01-01%2011%253A11%253A11/part-00000-84d7ed74-8f28-456c-9270-f45376eea144.c000.snappy.parquet";
+ let expected = "comet/spark-warehouse/t1/part1=2019-01-01 11%3A11%3A11/part-00000-84d7ed74-8f28-456c-9270-f45376eea144.c000.snappy.parquet";
+
+ if let Ok((_obj_store_url, path)) = get_file_path(inp.to_string()) {
+ let actual = path.to_string();
+ assert_eq!(actual, expected);
+ }
+ }
+}
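deserialize_schema expects the bytes to form an Arrow IPC stream that begins with a schema message, which is what the JVM side produces via MessageSerializer over a WriteChannel. A small self-contained sketch of the equivalent round-trip entirely in Rust with arrow-ipc (the schema_ipc_roundtrip name and the single Int64 column are made up for illustration):

    use std::io::Cursor;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::error::ArrowError;
    use arrow::ipc::reader::StreamReader;
    use arrow::ipc::writer::StreamWriter;

    // Serialize a schema in the Arrow IPC stream format and read it back, the same
    // framing deserialize_schema parses from the bytes passed over JNI.
    fn schema_ipc_roundtrip() -> Result<(), ArrowError> {
        let schema = Schema::new(vec![Field::new("c1", DataType::Int64, true)]);
        let mut buf: Vec<u8> = Vec::new();
        {
            // StreamWriter emits the schema message as soon as it is created.
            let mut writer = StreamWriter::try_new(&mut buf, &schema)?;
            writer.finish()?;
        }
        let reader = StreamReader::try_new(Cursor::new(&buf), None)?;
        assert_eq!(reader.schema().as_ref(), &schema);
        Ok(())
    }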
diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs
index a6d13971..95fde373 100644
--- a/native/spark-expr/src/cast.rs
+++ b/native/spark-expr/src/cast.rs
@@ -571,7 +571,7 @@ impl SparkCastOptions {
eval_mode,
timezone: timezone.to_string(),
allow_incompat,
- is_adapting_schema: false
+ is_adapting_schema: false,
}
}
@@ -583,7 +583,6 @@ impl SparkCastOptions {
is_adapting_schema: false,
}
}
-
}
/// Spark-compatible cast implementation. Defers to DataFusion's cast where that is known
@@ -2309,8 +2308,7 @@ mod tests {
#[test]
fn test_cast_invalid_timezone() {
let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
- let cast_options =
- SparkCastOptions::new(EvalMode::Legacy, "Not a valid timezone", false);
+ let cast_options = SparkCastOptions::new(EvalMode::Legacy, "Not a valid timezone", false);
let result = cast_array(
Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
&DataType::Date32,
@@ -2401,9 +2399,7 @@ mod tests {
let cast_array = spark_cast(
ColumnarValue::Array(c),
&DataType::Struct(fields),
- &SparkCastOptions::new(EvalMode::Legacy,
- "UTC",
- false)
+ &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
)
.unwrap();
if let ColumnarValue::Array(cast_array) = cast_array {
diff --git a/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala b/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala
index eb524af9..e4235495 100644
--- a/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala
+++ b/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala
@@ -41,7 +41,7 @@ trait DataTypeSupport {
case t: DataType if t.typeName == "timestamp_ntz" => true
case _: StructType
if CometConf.COMET_FULL_NATIVE_SCAN_ENABLED
- .get() || CometConf.COMET_NATIVE_ARROW_SCAN_ENABLED.get() =>
+ .get() || CometConf.COMET_NATIVE_RECORDBATCH_READER_ENABLED.get() =>
true
case _ => false
}
diff --git a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala b/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala
index 4c96bef4..c142abb5 100644
--- a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala
+++ b/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala
@@ -100,7 +100,7 @@ class CometParquetFileFormat extends ParquetFileFormat with MetricsSupport with
// Comet specific configurations
val capacity = CometConf.COMET_BATCH_SIZE.get(sqlConf)
- val nativeArrowReaderEnabled = CometConf.COMET_NATIVE_ARROW_SCAN_ENABLED.get(sqlConf)
+ val nativeArrowReaderEnabled = CometConf.COMET_NATIVE_RECORDBATCH_READER_ENABLED.get(sqlConf)
(file: PartitionedFile) => {
val sharedConf = broadcastedHadoopConf.value.value
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 99ed5d3c..e997c5bf 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -80,8 +80,8 @@ abstract class CometTestBase
conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true")
conf.set(CometConf.COMET_SPARK_TO_ARROW_ENABLED.key, "true")
conf.set(CometConf.COMET_NATIVE_SCAN_ENABLED.key, "true")
- conf.set(CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.key, "true")
- conf.set(CometConf.COMET_NATIVE_ARROW_SCAN_ENABLED.key, "false")
+ conf.set(CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.key, "false")
+ conf.set(CometConf.COMET_NATIVE_RECORDBATCH_READER_ENABLED.key, "true")
conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "2g")
conf.set(CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key, "true")
conf
diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
index a553e61c..080655fe 100644
--- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
@@ -89,9 +89,13 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa
actualSimplifiedPlan: String,
actualExplain: String): Boolean = {
val simplifiedFile = new File(dir, "simplified.txt")
- val expectedSimplified = FileUtils.readFileToString(simplifiedFile, StandardCharsets.UTF_8)
- lazy val explainFile = new File(dir, "explain.txt")
- lazy val expectedExplain = FileUtils.readFileToString(explainFile, StandardCharsets.UTF_8)
+ var expectedSimplified = FileUtils.readFileToString(simplifiedFile, StandardCharsets.UTF_8)
+ val explainFile = new File(dir, "explain.txt")
+ var expectedExplain = FileUtils.readFileToString(explainFile, StandardCharsets.UTF_8)
+ if (!CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.get()) {
+ expectedExplain = expectedExplain.replace("CometNativeScan", "CometScan")
+ expectedSimplified = expectedSimplified.replace("CometNativeScan",
"CometScan")
+ }
expectedSimplified == actualSimplifiedPlan && expectedExplain == actualExplain
}
@@ -259,6 +263,9 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa
// Disable char/varchar read-side handling for better performance.
withSQLConf(
CometConf.COMET_ENABLED.key -> "true",
+ CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "true",
+ CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.key -> "false",
+ CometConf.COMET_NATIVE_RECORDBATCH_READER_ENABLED.key -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true",
CometConf.COMET_DPP_FALLBACK_ENABLED.key -> "false",
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
@@ -288,6 +295,9 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa
"org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
conf.set(CometConf.COMET_ENABLED.key, "true")
conf.set(CometConf.COMET_EXEC_ENABLED.key, "true")
+ conf.set(CometConf.COMET_NATIVE_SCAN_ENABLED.key, "true")
+ conf.set(CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.key, "false")
+ conf.set(CometConf.COMET_NATIVE_RECORDBATCH_READER_ENABLED.key, "true")
conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "1g")
conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]