This is an automated email from the ASF dual-hosted git repository.
parthc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new c3f59a65c chore(deps): bump jni from 0.21.1 to 0.22.4 in /native (#3753)
c3f59a65c is described below
commit c3f59a65c9562572327115a1fb4de14647b6ad89
Author: Manu Zhang <[email protected]>
AuthorDate: Fri Apr 3 01:10:42 2026 +0800
chore(deps): bump jni from 0.21.1 to 0.22.4 in /native (#3753)
Co-authored-by: Codex <[email protected]>
---
native/Cargo.lock | 90 ++++-
native/core/Cargo.toml | 4 +-
native/core/src/execution/expressions/subquery.rs | 46 +--
native/core/src/execution/jni_api.rs | 150 ++++----
.../core/src/execution/memory_pools/fair_pool.rs | 20 +-
native/core/src/execution/memory_pools/mod.rs | 4 +-
.../src/execution/memory_pools/unified_pool.rs | 20 +-
native/core/src/execution/metrics/utils.rs | 4 +-
native/core/src/execution/operators/projection.rs | 4 +-
native/core/src/execution/operators/scan.rs | 178 +++++----
.../core/src/execution/operators/shuffle_scan.rs | 92 ++---
native/core/src/execution/planner.rs | 6 +-
.../src/execution/planner/operator_registry.rs | 6 +-
native/core/src/lib.rs | 20 +-
native/core/src/parquet/encryption_support.rs | 155 ++++----
native/core/src/parquet/mod.rs | 120 +++---
native/core/src/parquet/util/jni.rs | 17 +-
native/jni-bridge/Cargo.toml | 4 +-
native/jni-bridge/src/batch_iterator.rs | 31 +-
native/jni-bridge/src/comet_exec.rs | 73 +++-
native/jni-bridge/src/comet_metric_node.rs | 28 +-
native/jni-bridge/src/comet_task_memory_manager.rs | 19 +-
native/jni-bridge/src/errors.rs | 406 +++++++++++----------
native/jni-bridge/src/lib.rs | 110 +++---
native/jni-bridge/src/shuffle_block_iterator.rs | 25 +-
25 files changed, 902 insertions(+), 730 deletions(-)
diff --git a/native/Cargo.lock b/native/Cargo.lock
index 0cf1f2031..9ac87ee25 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -1321,7 +1321,7 @@ checksum =
"0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
"glob",
"libc",
- "libloading 0.8.9",
+ "libloading",
]
[[package]]
@@ -1853,7 +1853,7 @@ dependencies = [
"iceberg",
"iceberg-storage-opendal",
"itertools 0.14.0",
- "jni",
+ "jni 0.22.4",
"lazy_static",
"log",
"log4rs",
@@ -1914,7 +1914,7 @@ dependencies = [
"assertables",
"datafusion",
"datafusion-comet-common",
- "jni",
+ "jni 0.22.4",
"lazy_static",
"once_cell",
"parquet",
@@ -1962,7 +1962,7 @@ dependencies = [
"datafusion-comet-spark-expr",
"futures",
"itertools 0.14.0",
- "jni",
+ "jni 0.21.1",
"log",
"lz4_flex 0.13.0",
"simd-adler32",
@@ -3715,20 +3715,72 @@ dependencies = [
"cesu8",
"cfg-if",
"combine",
- "java-locator",
- "jni-sys",
- "libloading 0.7.4",
+ "jni-sys 0.3.1",
"log",
"thiserror 1.0.69",
"walkdir",
"windows-sys 0.45.0",
]
+[[package]]
+name = "jni"
+version = "0.22.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5efd9a482cf3a427f00d6b35f14332adc7902ce91efb778580e180ff90fa3498"
+dependencies = [
+ "cfg-if",
+ "combine",
+ "java-locator",
+ "jni-macros",
+ "jni-sys 0.4.1",
+ "libloading",
+ "log",
+ "simd_cesu8",
+ "thiserror 2.0.18",
+ "walkdir",
+ "windows-link",
+]
+
+[[package]]
+name = "jni-macros"
+version = "0.22.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a00109accc170f0bdb141fed3e393c565b6f5e072365c3bd58f5b062591560a3"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "rustc_version",
+ "simd_cesu8",
+ "syn 2.0.117",
+]
+
[[package]]
name = "jni-sys"
-version = "0.3.0"
+version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
+checksum = "41a652e1f9b6e0275df1f15b32661cf0d4b78d4d87ddec5e0c3c20f097433258"
+dependencies = [
+ "jni-sys 0.4.1",
+]
+
+[[package]]
+name = "jni-sys"
+version = "0.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2"
+dependencies = [
+ "jni-sys-macros",
+]
+
+[[package]]
+name = "jni-sys-macros"
+version = "0.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264"
+dependencies = [
+ "quote",
+ "syn 2.0.117",
+]
[[package]]
name = "jobserver"
@@ -3864,16 +3916,6 @@ version = "0.2.183"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d"
-[[package]]
-name = "libloading"
-version = "0.7.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f"
-dependencies = [
- "cfg-if",
- "winapi",
-]
-
[[package]]
name = "libloading"
version = "0.8.9"
@@ -5692,6 +5734,16 @@ version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214"
+[[package]]
+name = "simd_cesu8"
+version = "1.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "94f90157bb87cddf702797c5dadfa0be7d266cdf49e22da2fcaa32eff75b2c33"
+dependencies = [
+ "rustc_version",
+ "simdutf8",
+]
+
[[package]]
name = "simdutf8"
version = "0.1.5"
diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml
index decc83c50..b87e389e9 100644
--- a/native/core/Cargo.toml
+++ b/native/core/Cargo.toml
@@ -46,7 +46,7 @@ async-trait = { workspace = true }
log = "0.4"
log4rs = "1.4.0"
prost = "0.14.3"
-jni = "0.21"
+jni = "0.22.4"
rand = { workspace = true }
num = { workspace = true }
bytes = { workspace = true }
@@ -87,7 +87,7 @@ hdrs = { version = "0.3.2", features = ["vendored"] }
[dev-dependencies]
pprof = { version = "0.15", features = ["flamegraph"] }
criterion = { version = "0.7", features = ["async", "async_tokio",
"async_std"] }
-jni = { version = "0.21", features = ["invocation"] }
+jni = { version = "0.22.4", features = ["invocation"] }
lazy_static = "1.4"
assertables = "9"
hex = "0.4.3"
diff --git a/native/core/src/execution/expressions/subquery.rs
b/native/core/src/execution/expressions/subquery.rs
index ad4106c25..9272ede60 100644
--- a/native/core/src/execution/expressions/subquery.rs
+++ b/native/core/src/execution/expressions/subquery.rs
@@ -25,7 +25,7 @@ use datafusion::common::{internal_err, ScalarValue};
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr::PhysicalExpr;
use jni::{
- objects::JByteArray,
+ objects::{JByteArray, JString},
sys::{jboolean, jbyte, jint, jlong, jshort},
};
use std::{
@@ -80,14 +80,12 @@ impl PhysicalExpr for Subquery {
}
fn evaluate(&self, _: &RecordBatch) ->
datafusion::common::Result<ColumnarValue> {
- let mut env = JVMClasses::get_env()?;
-
- unsafe {
- let is_null = jni_static_call!(&mut env,
+ JVMClasses::with_env(|env| unsafe {
+ let is_null = jni_static_call!(env,
comet_exec.is_null(self.exec_context_id, self.id) -> jboolean
)?;
- if is_null > 0 {
+ if is_null {
return Ok(ColumnarValue::Scalar(ScalarValue::try_from(
&self.data_type,
)?));
@@ -95,53 +93,53 @@ impl PhysicalExpr for Subquery {
match &self.data_type {
DataType::Boolean => {
- let r = jni_static_call!(&mut env,
+ let r = jni_static_call!(env,
comet_exec.get_bool(self.exec_context_id, self.id) ->
jboolean
)?;
- Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(r >
0))))
+ Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(r))))
}
DataType::Int8 => {
- let r = jni_static_call!(&mut env,
+ let r = jni_static_call!(env,
comet_exec.get_byte(self.exec_context_id, self.id) ->
jbyte
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(r))))
}
DataType::Int16 => {
- let r = jni_static_call!(&mut env,
+ let r = jni_static_call!(env,
comet_exec.get_short(self.exec_context_id, self.id) ->
jshort
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(r))))
}
DataType::Int32 => {
- let r = jni_static_call!(&mut env,
+ let r = jni_static_call!(env,
comet_exec.get_int(self.exec_context_id, self.id) ->
jint
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(r))))
}
DataType::Int64 => {
- let r = jni_static_call!(&mut env,
+ let r = jni_static_call!(env,
comet_exec.get_long(self.exec_context_id, self.id) ->
jlong
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(r))))
}
DataType::Float32 => {
- let r = jni_static_call!(&mut env,
+ let r = jni_static_call!(env,
comet_exec.get_float(self.exec_context_id, self.id) ->
f32
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(r))))
}
DataType::Float64 => {
- let r = jni_static_call!(&mut env,
+ let r = jni_static_call!(env,
comet_exec.get_double(self.exec_context_id, self.id)
-> f64
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(r))))
}
DataType::Decimal128(p, s) => {
- let bytes = jni_static_call!(&mut env,
+ let bytes = jni_static_call!(env,
comet_exec.get_decimal(self.exec_context_id, self.id)
-> BinaryWrapper
)?;
- let bytes: &JByteArray = bytes.get().into();
+ let bytes = JByteArray::from_raw(env,
bytes.get().as_raw());
let slice = env.convert_byte_array(bytes).unwrap();
Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
@@ -151,14 +149,14 @@ impl PhysicalExpr for Subquery {
)))
}
DataType::Date32 => {
- let r = jni_static_call!(&mut env,
+ let r = jni_static_call!(env,
comet_exec.get_int(self.exec_context_id, self.id) ->
jint
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Date32(Some(r))))
}
DataType::Timestamp(TimeUnit::Microsecond, timezone) => {
- let r = jni_static_call!(&mut env,
+ let r = jni_static_call!(env,
comet_exec.get_long(self.exec_context_id, self.id) ->
jlong
)?;
@@ -168,25 +166,27 @@ impl PhysicalExpr for Subquery {
)))
}
DataType::Utf8 => {
- let string = jni_static_call!(&mut env,
+ let string = jni_static_call!(env,
comet_exec.get_string(self.exec_context_id, self.id)
-> StringWrapper
)?;
- let string = env.get_string(string.get()).unwrap().into();
+ let string = JString::from_raw(env, string.get().as_raw())
+ .try_to_string(env)
+ .unwrap();
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(string))))
}
DataType::Binary => {
- let bytes = jni_static_call!(&mut env,
+ let bytes = jni_static_call!(env,
comet_exec.get_binary(self.exec_context_id, self.id)
-> BinaryWrapper
)?;
- let bytes: &JByteArray = bytes.get().into();
+ let bytes = JByteArray::from_raw(env,
bytes.get().as_raw());
let slice = env.convert_byte_array(bytes).unwrap();
Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(slice))))
}
_ => internal_err!("Unsupported scalar subquery data type
{:?}", self.data_type),
}
- }
+ })
}
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
diff --git a/native/core/src/execution/jni_api.rs
b/native/core/src/execution/jni_api.rs
index e0a395ebb..b34493ad6 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -66,11 +66,11 @@ use jni::sys::{jlongArray, JNI_FALSE};
use jni::{
errors::Result as JNIResult,
objects::{
- GlobalRef, JByteArray, JClass, JIntArray, JLongArray, JObject,
JObjectArray, JString,
+ Global, JByteArray, JClass, JIntArray, JLongArray, JObject,
JObjectArray, JString,
ReleaseMode,
},
sys::{jboolean, jdouble, jint, jlong},
- JNIEnv,
+ Env, EnvUnowned,
};
use std::collections::HashMap;
use std::path::PathBuf;
@@ -154,13 +154,13 @@ struct ExecutionContext {
/// The shuffle scan input sources for the DataFusion plan
pub shuffle_scans: Vec<ShuffleScanExec>,
/// The global reference of input sources for the DataFusion plan
- pub input_sources: Vec<Arc<GlobalRef>>,
+ pub input_sources: Vec<Arc<Global<JObject<'static>>>>,
/// The record batch stream to pull results from
pub stream: Option<SendableRecordBatchStream>,
/// Receives batches from a spawned tokio task (async I/O path)
pub batch_receiver: Option<mpsc::Receiver<DataFusionResult<RecordBatch>>>,
/// Native metrics
- pub metrics: Arc<GlobalRef>,
+ pub metrics: Arc<Global<JObject<'static>>>,
// The interval in milliseconds to update metrics
pub metrics_update_interval: Option<Duration>,
// The last update time of metrics
@@ -186,7 +186,7 @@ struct ExecutionContext {
/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
#[no_mangle]
pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
- e: JNIEnv,
+ e: EnvUnowned,
_class: JClass,
id: jlong,
iterators: JObjectArray,
@@ -206,7 +206,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_createPlan(
task_cpus: jlong,
key_unwrapper_obj: JObject,
) -> jlong {
- try_unwrap_or_throw(&e, |mut env| {
+ try_unwrap_or_throw(&e, |env| {
// Deserialize Spark configs
let bytes = env.convert_byte_array(serialized_spark_configs)?;
let spark_configs = serde::deserialize_config(bytes.as_slice())?;
@@ -227,7 +227,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_createPlan(
with_trace("createPlan", tracing_enabled, || {
// Init JVM classes
- JVMClasses::init(&mut env);
+ JVMClasses::init(env);
let start = Instant::now();
@@ -239,9 +239,9 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_createPlan(
// Get the global references of input sources
let mut input_sources = vec![];
- let num_inputs = env.get_array_length(&iterators)?;
+ let num_inputs = iterators.len(env)?;
for i in 0..num_inputs {
- let input_source = env.get_object_array_element(&iterators,
i)?;
+ let input_source = iterators.get_element(env, i)?;
let input_source = Arc::new(jni_new_global_ref!(env,
input_source)?);
input_sources.push(input_source);
}
@@ -250,7 +250,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_createPlan(
let task_memory_manager =
Arc::new(jni_new_global_ref!(env,
comet_task_memory_manager_obj)?);
- let memory_pool_type = env.get_string(&memory_pool_type)?.into();
+ let memory_pool_type = memory_pool_type.try_to_string(env)?;
let memory_pool_config = parse_memory_pool_config(
off_heap_mode != JNI_FALSE,
memory_pool_type,
@@ -267,12 +267,13 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_createPlan(
};
// Get local directories for storing spill files
- let num_local_dirs = env.get_array_length(&local_dirs)?;
+ let num_local_dirs = local_dirs.len(env)?;
let mut local_dirs_vec = vec![];
for i in 0..num_local_dirs {
- let local_dir: JString =
env.get_object_array_element(&local_dirs, i)?.into();
- let local_dir = env.get_string(&local_dir)?;
- local_dirs_vec.push(local_dir.into());
+ let local_dir = local_dirs.get_element(env, i)?;
+ let local_dir = unsafe { JString::from_raw(&*env,
local_dir.into_raw()) };
+ let local_dir = local_dir.try_to_string(env)?;
+ local_dirs_vec.push(local_dir);
}
// We need to keep the session context alive. Some session state
like temporary
@@ -298,7 +299,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_createPlan(
// Handle key unwrapper for encrypted files
if !key_unwrapper_obj.is_null() {
let encryption_factory = CometEncryptionFactory {
- key_unwrapper: jni_new_global_ref!(env,
key_unwrapper_obj)?,
+ key_unwrapper: Arc::new(jni_new_global_ref!(env,
key_unwrapper_obj)?),
};
session.runtime_env().register_parquet_encryption_factory(
ENCRYPTION_FACTORY_ID,
@@ -414,18 +415,18 @@ fn register_datafusion_spark_function(session_ctx:
&SessionContext) {
/// Prepares arrow arrays for output.
fn prepare_output(
- env: &mut JNIEnv,
+ env: &mut Env,
array_addrs: JLongArray,
schema_addrs: JLongArray,
output_batch: RecordBatch,
validate: bool,
) -> CometResult<jlong> {
- let num_cols = env.get_array_length(&array_addrs)? as usize;
+ let num_cols = array_addrs.len(env)?;
- let array_addrs = unsafe { env.get_array_elements(&array_addrs,
ReleaseMode::NoCopyBack)? };
+ let array_addrs = unsafe { array_addrs.get_elements(env,
ReleaseMode::NoCopyBack)? };
let array_addrs = &*array_addrs;
- let schema_addrs = unsafe { env.get_array_elements(&schema_addrs,
ReleaseMode::NoCopyBack)? };
+ let schema_addrs = unsafe { schema_addrs.get_elements(env,
ReleaseMode::NoCopyBack)? };
let schema_addrs = &*schema_addrs;
let results = output_batch.columns();
@@ -507,7 +508,7 @@ fn pull_input_batches(exec_context: &mut ExecutionContext)
-> Result<(), CometEr
/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
#[no_mangle]
pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
- e: JNIEnv,
+ e: EnvUnowned,
_class: JClass,
stage_id: jint,
partition: jint,
@@ -515,7 +516,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_executePlan(
array_addrs: JLongArray,
schema_addrs: JLongArray,
) -> jlong {
- try_unwrap_or_throw(&e, |mut env| {
+ try_unwrap_or_throw(&e, |env| {
// Retrieve the query
let exec_context = get_execution_context(exec_context);
@@ -618,9 +619,9 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_executePlan(
if let Some(rx) = &mut exec_context.batch_receiver {
match rx.blocking_recv() {
Some(Ok(batch)) => {
- update_metrics(&mut env, exec_context)?;
+ update_metrics(env, exec_context)?;
return prepare_output(
- &mut env,
+ env,
array_addrs,
schema_addrs,
batch,
@@ -649,7 +650,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_executePlan(
if exec_context.poll_count_since_metrics_check >= 100 {
let now = Instant::now();
if now - exec_context.metrics_last_update_time >=
interval {
- update_metrics(&mut env, exec_context)?;
+ update_metrics(env, exec_context)?;
exec_context.metrics_last_update_time = now;
}
exec_context.poll_count_since_metrics_check = 0;
@@ -659,7 +660,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_executePlan(
match poll_output {
Poll::Ready(Some(output)) => {
return prepare_output(
- &mut env,
+ env,
array_addrs,
schema_addrs,
output?,
@@ -686,15 +687,15 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_executePlan(
#[no_mangle]
/// Drop the native query plan object and context object.
pub extern "system" fn Java_org_apache_comet_Native_releasePlan(
- e: JNIEnv,
+ e: EnvUnowned,
_class: JClass,
exec_context: jlong,
) {
- try_unwrap_or_throw(&e, |mut env| unsafe {
+ try_unwrap_or_throw(&e, |env| unsafe {
let execution_context = get_execution_context(exec_context);
// Update metrics
- update_metrics(&mut env, execution_context)?;
+ update_metrics(env, execution_context)?;
handle_task_shared_pool_release(
execution_context.memory_pool_config.pool_type,
@@ -706,8 +707,7 @@ pub extern "system" fn
Java_org_apache_comet_Native_releasePlan(
})
}
-/// Updates the metrics of the query plan.
-fn update_metrics(env: &mut JNIEnv, exec_context: &mut ExecutionContext) ->
CometResult<()> {
+fn update_metrics(env: &mut Env, exec_context: &mut ExecutionContext) ->
CometResult<()> {
if let Some(native_query) = &exec_context.root_op {
let metrics = exec_context.metrics.as_obj();
update_comet_metric(env, metrics, native_query)
@@ -732,15 +732,15 @@ fn log_plan_metrics(exec_context: &ExecutionContext,
stage_id: jint, partition:
}
fn convert_datatype_arrays(
- env: &'_ mut JNIEnv<'_>,
+ env: &mut Env,
serialized_datatypes: JObjectArray,
) -> JNIResult<Vec<ArrowDataType>> {
- let array_len = env.get_array_length(&serialized_datatypes)?;
+ let array_len = serialized_datatypes.len(env)?;
let mut res: Vec<ArrowDataType> = Vec::new();
for i in 0..array_len {
- let inner_array = env.get_object_array_element(&serialized_datatypes,
i)?;
- let inner_array: JByteArray = inner_array.into();
+ let inner_array = serialized_datatypes.get_element(env, i)?;
+ let inner_array = unsafe { JByteArray::from_raw(&*env,
inner_array.into_raw()) };
let bytes = env.convert_byte_array(inner_array)?;
let data_type =
serde::deserialize_data_type(bytes.as_slice()).unwrap();
let arrow_dt = to_arrow_datatype(&data_type);
@@ -763,7 +763,7 @@ fn get_execution_context<'a>(id: i64) -> &'a mut
ExecutionContext {
/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
#[no_mangle]
pub unsafe extern "system" fn
Java_org_apache_comet_Native_writeSortedFileNative(
- e: JNIEnv,
+ e: EnvUnowned,
_class: JClass,
row_addresses: JLongArray,
row_sizes: JIntArray,
@@ -778,25 +778,23 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_writeSortedFileNative
compression_level: jint,
tracing_enabled: jboolean,
) -> jlongArray {
- try_unwrap_or_throw(&e, |mut env| unsafe {
+ try_unwrap_or_throw(&e, |env| unsafe {
with_trace(
"writeSortedFileNative",
tracing_enabled != JNI_FALSE,
|| {
- let data_types = convert_datatype_arrays(&mut env,
serialized_datatypes)?;
+ let data_types = convert_datatype_arrays(env,
serialized_datatypes)?;
- let row_num = env.get_array_length(&row_addresses)? as usize;
- let row_addresses =
- env.get_array_elements(&row_addresses,
ReleaseMode::NoCopyBack)?;
+ let row_num = row_addresses.len(env)?;
+ let row_addresses = row_addresses.get_elements(env,
ReleaseMode::NoCopyBack)?;
- let row_sizes = env.get_array_elements(&row_sizes,
ReleaseMode::NoCopyBack)?;
+ let row_sizes = row_sizes.get_elements(env,
ReleaseMode::NoCopyBack)?;
let row_addresses_ptr = row_addresses.as_ptr();
let row_sizes_ptr = row_sizes.as_ptr();
- let output_path: String =
env.get_string(&file_path).unwrap().into();
+ let output_path: String =
file_path.try_to_string(env).unwrap();
- let checksum_enabled = checksum_enabled == 1;
let current_checksum = if current_checksum == i64::MIN {
// Initial checksum is not available.
None
@@ -804,7 +802,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_writeSortedFileNative
Some(current_checksum as u32)
};
- let compression_codec: String =
env.get_string(&compression_codec).unwrap().into();
+ let compression_codec: String =
compression_codec.try_to_string(env).unwrap();
let compression_codec = match compression_codec.as_str() {
"zstd" => CompressionCodec::Zstd(compression_level),
@@ -836,7 +834,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_writeSortedFileNative
};
let long_array = env.new_long_array(2)?;
- env.set_long_array_region(&long_array, 0, &[written_bytes,
checksum])?;
+ long_array.set_region(env, 0, &[written_bytes, checksum])?;
Ok(long_array.into_raw())
},
@@ -847,7 +845,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_writeSortedFileNative
#[no_mangle]
/// Used by Comet shuffle external sorter to sort in-memory row partition ids.
pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative(
- e: JNIEnv,
+ e: EnvUnowned,
_class: JClass,
address: jlong,
size: jlong,
@@ -880,7 +878,7 @@ pub extern "system" fn
Java_org_apache_comet_Native_sortRowPartitionsNative(
/// # Safety
/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock(
- e: JNIEnv,
+ e: EnvUnowned,
_class: JClass,
byte_buffer: JByteBuffer,
length: jint,
@@ -888,13 +886,13 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_decodeShuffleBlock(
schema_addrs: JLongArray,
tracing_enabled: jboolean,
) -> jlong {
- try_unwrap_or_throw(&e, |mut env| {
+ try_unwrap_or_throw(&e, |env| {
with_trace("decodeShuffleBlock", tracing_enabled != JNI_FALSE, || {
let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?;
let length = length as usize;
let slice: &[u8] = unsafe {
std::slice::from_raw_parts(raw_pointer, length) };
let batch = read_ipc_compressed(slice)?;
- prepare_output(&mut env, array_addrs, schema_addrs, batch, false)
+ prepare_output(env, array_addrs, schema_addrs, batch, false)
})
})
}
@@ -903,12 +901,12 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_decodeShuffleBlock(
/// # Safety
/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
pub unsafe extern "system" fn Java_org_apache_comet_Native_traceBegin(
- e: JNIEnv,
+ e: EnvUnowned,
_class: JClass,
event: JString,
) {
- try_unwrap_or_throw(&e, |mut env| {
- let name: String = env.get_string(&event).unwrap().into();
+ try_unwrap_or_throw(&e, |env| {
+ let name: String = event.try_to_string(env).unwrap();
trace_begin(&name);
Ok(())
})
@@ -918,12 +916,12 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_traceBegin(
/// # Safety
/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
pub unsafe extern "system" fn Java_org_apache_comet_Native_traceEnd(
- e: JNIEnv,
+ e: EnvUnowned,
_class: JClass,
event: JString,
) {
- try_unwrap_or_throw(&e, |mut env| {
- let name: String = env.get_string(&event).unwrap().into();
+ try_unwrap_or_throw(&e, |env| {
+ let name: String = event.try_to_string(env).unwrap();
trace_end(&name);
Ok(())
})
@@ -933,13 +931,13 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_traceEnd(
/// # Safety
/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
pub unsafe extern "system" fn Java_org_apache_comet_Native_logMemoryUsage(
- e: JNIEnv,
+ e: EnvUnowned,
_class: JClass,
name: JString,
value: jlong,
) {
- try_unwrap_or_throw(&e, |mut env| {
- let name: String = env.get_string(&name).unwrap().into();
+ try_unwrap_or_throw(&e, |env| {
+ let name: String = name.try_to_string(env).unwrap();
log_memory_usage(&name, value as u64);
Ok(())
})
@@ -958,14 +956,14 @@ use arrow::ffi::{from_ffi, FFI_ArrowArray,
FFI_ArrowSchema};
/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
#[no_mangle]
pub unsafe extern "system" fn Java_org_apache_comet_Native_columnarToRowInit(
- e: JNIEnv,
+ e: EnvUnowned,
_class: JClass,
serialized_schema: JObjectArray,
batch_size: jint,
) -> jlong {
- try_unwrap_or_throw(&e, |mut env| {
+ try_unwrap_or_throw(&e, |env| {
// Deserialize the schema
- let schema = convert_datatype_arrays(&mut env, serialized_schema)?;
+ let schema = convert_datatype_arrays(env, serialized_schema)?;
// Create the context
let ctx = Box::new(ColumnarToRowContext::new(schema, batch_size as
usize));
@@ -980,26 +978,27 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_columnarToRowInit(
/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
#[no_mangle]
pub unsafe extern "system" fn
Java_org_apache_comet_Native_columnarToRowConvert(
- e: JNIEnv,
+ e: EnvUnowned,
_class: JClass,
c2r_handle: jlong,
array_addrs: JLongArray,
schema_addrs: JLongArray,
num_rows: jint,
) -> jni::sys::jobject {
- try_unwrap_or_throw(&e, |mut env| {
+ try_unwrap_or_throw(&e, |env| {
// Get the context
debug_assert!(c2r_handle != 0, "columnarToRowConvert: c2r_handle is
null");
let ctx = (c2r_handle as *mut ColumnarToRowContext)
.as_mut()
.ok_or_else(|| CometError::Internal("Null columnar to row
context".to_string()))?;
- let num_cols = env.get_array_length(&array_addrs)? as usize;
+ let num_cols = array_addrs.len(env)?;
// Get array and schema addresses
- let array_addrs_elements = env.get_array_elements(&array_addrs,
ReleaseMode::NoCopyBack)?;
+ let array_addrs_elements =
+ unsafe { array_addrs.get_elements(env, ReleaseMode::NoCopyBack)? };
let schema_addrs_elements =
- env.get_array_elements(&schema_addrs, ReleaseMode::NoCopyBack)?;
+ unsafe { schema_addrs.get_elements(env, ReleaseMode::NoCopyBack)?
};
// Import Arrow arrays from FFI
let mut arrays = Vec::with_capacity(num_cols);
@@ -1019,8 +1018,8 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_columnarToRowConvert(
);
// Take ownership of the FFI structures
- let ffi_array = std::ptr::read(array_ptr);
- let ffi_schema = std::ptr::read(schema_ptr);
+ let ffi_array = unsafe { std::ptr::read(array_ptr) };
+ let ffi_schema = unsafe { std::ptr::read(schema_ptr) };
// Convert to Arrow ArrayData
let array_data = from_ffi(ffi_array, &ffi_schema)
@@ -1038,17 +1037,18 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_columnarToRowConvert(
let (buffer_ptr, offsets, lengths) = ctx.convert(&arrays, num_rows as
usize)?;
// Create Java int arrays for offsets and lengths
- let offsets_array = env.new_int_array(offsets.len() as i32)?;
- env.set_int_array_region(&offsets_array, 0, offsets)?;
+ let offsets_array = env.new_int_array(offsets.len())?;
+ offsets_array.set_region(env, 0, offsets)?;
- let lengths_array = env.new_int_array(lengths.len() as i32)?;
- env.set_int_array_region(&lengths_array, 0, lengths)?;
+ let lengths_array = env.new_int_array(lengths.len())?;
+ lengths_array.set_region(env, 0, lengths)?;
// Create the NativeColumnarToRowInfo object
- let info_class =
env.find_class("org/apache/comet/NativeColumnarToRowInfo")?;
+ let info_class =
+
env.find_class(jni::jni_str!("org/apache/comet/NativeColumnarToRowInfo"))?;
let info_obj = env.new_object(
info_class,
- "(J[I[I)V",
+ jni::jni_sig!("(J[I[I)V"),
&[
jni::objects::JValue::Long(buffer_ptr as jlong),
jni::objects::JValue::Object(&offsets_array),
@@ -1066,7 +1066,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_columnarToRowConvert(
/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
#[no_mangle]
pub unsafe extern "system" fn Java_org_apache_comet_Native_columnarToRowClose(
- e: JNIEnv,
+ e: EnvUnowned,
_class: JClass,
c2r_handle: jlong,
) {
diff --git a/native/core/src/execution/memory_pools/fair_pool.rs
b/native/core/src/execution/memory_pools/fair_pool.rs
index 2c25fe944..308071181 100644
--- a/native/core/src/execution/memory_pools/fair_pool.rs
+++ b/native/core/src/execution/memory_pools/fair_pool.rs
@@ -20,7 +20,7 @@ use std::{
sync::Arc,
};
-use jni::objects::GlobalRef;
+use jni::objects::{Global, JObject};
use crate::{errors::CometResult, jvm_bridge::JVMClasses};
use datafusion::common::resources_err;
@@ -34,7 +34,7 @@ use parking_lot::Mutex;
/// A DataFusion fair `MemoryPool` implementation for Comet. Internally this is
/// implemented via delegating calls to
[`crate::jvm_bridge::CometTaskMemoryManager`].
pub struct CometFairMemoryPool {
- task_memory_manager_handle: Arc<GlobalRef>,
+ task_memory_manager_handle: Arc<Global<JObject<'static>>>,
pool_size: usize,
state: Mutex<CometFairPoolState>,
}
@@ -57,7 +57,7 @@ impl Debug for CometFairMemoryPool {
impl CometFairMemoryPool {
pub fn new(
- task_memory_manager_handle: Arc<GlobalRef>,
+ task_memory_manager_handle: Arc<Global<JObject<'static>>>,
pool_size: usize,
) -> CometFairMemoryPool {
Self {
@@ -68,20 +68,18 @@ impl CometFairMemoryPool {
}
fn acquire(&self, additional: usize) -> CometResult<i64> {
- let mut env = JVMClasses::get_env()?;
let handle = self.task_memory_manager_handle.as_obj();
- unsafe {
- jni_call!(&mut env,
+ JVMClasses::with_env(|env| unsafe {
+ jni_call!(env,
comet_task_memory_manager(handle).acquire_memory(additional as
i64) -> i64)
- }
+ })
}
fn release(&self, size: usize) -> CometResult<()> {
- let mut env = JVMClasses::get_env()?;
let handle = self.task_memory_manager_handle.as_obj();
- unsafe {
- jni_call!(&mut env,
comet_task_memory_manager(handle).release_memory(size as i64) -> ())
- }
+ JVMClasses::with_env(|env| unsafe {
+ jni_call!(env,
comet_task_memory_manager(handle).release_memory(size as i64) -> ())
+ })
}
}
diff --git a/native/core/src/execution/memory_pools/mod.rs
b/native/core/src/execution/memory_pools/mod.rs
index d8b347335..d44290d05 100644
--- a/native/core/src/execution/memory_pools/mod.rs
+++ b/native/core/src/execution/memory_pools/mod.rs
@@ -25,7 +25,7 @@ use datafusion::execution::memory_pool::{
FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool,
UnboundedMemoryPool,
};
use fair_pool::CometFairMemoryPool;
-use jni::objects::GlobalRef;
+use jni::objects::{Global, JObject};
use once_cell::sync::OnceCell;
use std::num::NonZeroUsize;
use std::sync::Arc;
@@ -36,7 +36,7 @@ pub(crate) use task_shared::*;
pub(crate) fn create_memory_pool(
memory_pool_config: &MemoryPoolConfig,
- comet_task_memory_manager: Arc<GlobalRef>,
+ comet_task_memory_manager: Arc<Global<JObject<'static>>>,
task_attempt_id: i64,
) -> Arc<dyn MemoryPool> {
const NUM_TRACKED_CONSUMERS: usize = 10;
diff --git a/native/core/src/execution/memory_pools/unified_pool.rs
b/native/core/src/execution/memory_pools/unified_pool.rs
index 3233dd6d4..f34418ee9 100644
--- a/native/core/src/execution/memory_pools/unified_pool.rs
+++ b/native/core/src/execution/memory_pools/unified_pool.rs
@@ -28,14 +28,14 @@ use datafusion::{
common::{resources_datafusion_err, DataFusionError},
execution::memory_pool::{MemoryPool, MemoryReservation},
};
-use jni::objects::GlobalRef;
+use jni::objects::{Global, JObject};
use log::warn;
/// A DataFusion `MemoryPool` implementation for Comet that delegates to
/// Spark's off-heap executor memory pool via JNI by calling
/// [`crate::jvm_bridge::CometTaskMemoryManager`].
pub struct CometUnifiedMemoryPool {
- task_memory_manager_handle: Arc<GlobalRef>,
+ task_memory_manager_handle: Arc<Global<JObject<'static>>>,
used: AtomicUsize,
task_attempt_id: i64,
}
@@ -50,7 +50,7 @@ impl Debug for CometUnifiedMemoryPool {
impl CometUnifiedMemoryPool {
pub fn new(
- task_memory_manager_handle: Arc<GlobalRef>,
+ task_memory_manager_handle: Arc<Global<JObject<'static>>>,
task_attempt_id: i64,
) -> CometUnifiedMemoryPool {
Self {
@@ -62,21 +62,19 @@ impl CometUnifiedMemoryPool {
/// Request memory from Spark's off-heap memory pool via JNI
fn acquire_from_spark(&self, additional: usize) -> CometResult<i64> {
- let mut env = JVMClasses::get_env()?;
let handle = self.task_memory_manager_handle.as_obj();
- unsafe {
- jni_call!(&mut env,
+ JVMClasses::with_env(|env| unsafe {
+ jni_call!(env,
comet_task_memory_manager(handle).acquire_memory(additional as
i64) -> i64)
- }
+ })
}
/// Release memory to Spark's off-heap memory pool via JNI
fn release_to_spark(&self, size: usize) -> CometResult<()> {
- let mut env = JVMClasses::get_env()?;
let handle = self.task_memory_manager_handle.as_obj();
- unsafe {
- jni_call!(&mut env,
comet_task_memory_manager(handle).release_memory(size as i64) -> ())
- }
+ JVMClasses::with_env(|env| unsafe {
+ jni_call!(env,
comet_task_memory_manager(handle).release_memory(size as i64) -> ())
+ })
}
}
diff --git a/native/core/src/execution/metrics/utils.rs
b/native/core/src/execution/metrics/utils.rs
index 161c1f1cf..eb7e10bfc 100644
--- a/native/core/src/execution/metrics/utils.rs
+++ b/native/core/src/execution/metrics/utils.rs
@@ -19,7 +19,7 @@ use crate::errors::CometError;
use crate::execution::spark_plan::SparkPlan;
use datafusion::physical_plan::metrics::MetricValue;
use datafusion_comet_proto::spark_metric::NativeMetricNode;
-use jni::{objects::JObject, JNIEnv};
+use jni::{objects::JObject, Env};
use prost::Message;
use std::collections::HashMap;
use std::sync::Arc;
@@ -28,7 +28,7 @@ use std::sync::Arc;
/// update the metrics of all the children nodes. The metrics are pulled from
the
/// native execution plan and pushed to the Java side through JNI.
pub(crate) fn update_comet_metric(
- env: &mut JNIEnv,
+ env: &mut Env,
metric_node: &JObject,
spark_plan: &Arc<SparkPlan>,
) -> Result<(), CometError> {
diff --git a/native/core/src/execution/operators/projection.rs
b/native/core/src/execution/operators/projection.rs
index 194fa6769..355b962e9 100644
--- a/native/core/src/execution/operators/projection.rs
+++ b/native/core/src/execution/operators/projection.rs
@@ -21,7 +21,7 @@ use std::sync::Arc;
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion_comet_proto::spark_operator::Operator;
-use jni::objects::GlobalRef;
+use jni::objects::{Global, JObject};
use crate::{
execution::{
@@ -38,7 +38,7 @@ impl OperatorBuilder for ProjectionBuilder {
fn build(
&self,
spark_plan: &Operator,
- inputs: &mut Vec<Arc<GlobalRef>>,
+ inputs: &mut Vec<Arc<Global<JObject<'static>>>>,
partition_count: usize,
planner: &PhysicalPlanner,
) -> PlanCreationResult {
diff --git a/native/core/src/execution/operators/scan.rs
b/native/core/src/execution/operators/scan.rs
index 2394912e4..dfcb50a68 100644
--- a/native/core/src/execution/operators/scan.rs
+++ b/native/core/src/execution/operators/scan.rs
@@ -40,9 +40,7 @@ use datafusion::{
};
use futures::Stream;
use itertools::Itertools;
-use jni::objects::JValueGen;
-use jni::objects::{GlobalRef, JObject};
-use jni::sys::jsize;
+use jni::objects::{Global, JObject, JValue};
use std::rc::Rc;
use std::{
any::Any,
@@ -61,7 +59,7 @@ pub struct ScanExec {
/// environment `JNIEnv` from the execution context.
pub exec_context_id: i64,
/// The input source of scan node. It is a global reference of JVM
`CometBatchIterator` object.
- pub input_source: Option<Arc<GlobalRef>>,
+ pub input_source: Option<Arc<Global<JObject<'static>>>>,
/// A description of the input source for informational purposes
pub input_source_description: String,
/// The data types of columns of the input batch. Converted from Spark
schema.
@@ -84,7 +82,7 @@ pub struct ScanExec {
impl ScanExec {
pub fn new(
exec_context_id: i64,
- input_source: Option<Arc<GlobalRef>>,
+ input_source: Option<Arc<Global<JObject<'static>>>>,
input_source_description: &str,
data_types: Vec<DataType>,
arrow_ffi_safe: bool,
@@ -175,94 +173,94 @@ impl ScanExec {
))));
}
- let mut env = JVMClasses::get_env()?;
+ JVMClasses::with_env(|env| {
+ let num_rows: i32 = unsafe {
+ jni_call!(env,
+ comet_batch_iterator(iter).has_next() -> i32)?
+ };
- let num_rows: i32 = unsafe {
- jni_call!(&mut env,
- comet_batch_iterator(iter).has_next() -> i32)?
- };
+ if num_rows == -1 {
+ return Ok(InputBatch::EOF);
+ }
- if num_rows == -1 {
- return Ok(InputBatch::EOF);
- }
+ // Check for selection vectors and get selection indices if needed
from
+ // JVM via FFI
+ // Selection vectors can be provided by, for instance, Iceberg to
+ // remove rows that have been deleted.
+ let selection_indices_arrays = Self::get_selection_indices(env,
iter, num_cols)?;
- // Check for selection vectors and get selection indices if needed from
- // JVM via FFI
- // Selection vectors can be provided by, for instance, Iceberg to
- // remove rows that have been deleted.
- let selection_indices_arrays = Self::get_selection_indices(&mut env,
iter, num_cols)?;
-
- // fetch batch data from JVM via FFI
- let (num_rows, array_addrs, schema_addrs) =
- Self::allocate_and_fetch_batch(&mut env, iter, num_cols)?;
-
- let mut inputs: Vec<ArrayRef> = Vec::with_capacity(num_cols);
-
- // Process each column
- for i in 0..num_cols {
- let array_ptr = array_addrs[i];
- let schema_ptr = schema_addrs[i];
- let array_data = ArrayData::from_spark((array_ptr, schema_ptr))?;
-
- // TODO: validate array input data
- // array_data.validate_full()?;
-
- let array = make_array(array_data);
-
- // Apply selection if selection vectors exist (applies to all
columns)
- let array = if let Some(ref selection_arrays) =
selection_indices_arrays {
- let indices = &selection_arrays[i];
- // Apply the selection using Arrow's take kernel
- match take(&*array, &**indices, None) {
- Ok(selected_array) => selected_array,
- Err(e) => {
- return
Err(CometError::from(ExecutionError::ArrowError(format!(
- "Failed to apply selection for column {i}: {e}",
- ))));
+ // fetch batch data from JVM via FFI
+ let (num_rows, array_addrs, schema_addrs) =
+ Self::allocate_and_fetch_batch(env, iter, num_cols)?;
+
+ let mut inputs: Vec<ArrayRef> = Vec::with_capacity(num_cols);
+
+ // Process each column
+ for i in 0..num_cols {
+ let array_ptr = array_addrs[i];
+ let schema_ptr = schema_addrs[i];
+ let array_data = ArrayData::from_spark((array_ptr,
schema_ptr))?;
+
+ // TODO: validate array input data
+ // array_data.validate_full()?;
+
+ let array = make_array(array_data);
+
+ // Apply selection if selection vectors exist (applies to all
columns)
+ let array = if let Some(ref selection_arrays) =
selection_indices_arrays {
+ let indices = &selection_arrays[i];
+ // Apply the selection using Arrow's take kernel
+ match take(&*array, &**indices, None) {
+ Ok(selected_array) => selected_array,
+ Err(e) => {
+ return
Err(CometError::from(ExecutionError::ArrowError(format!(
+ "Failed to apply selection for column {i}:
{e}",
+ ))));
+ }
}
- }
- } else {
- array
- };
+ } else {
+ array
+ };
- let array = if arrow_ffi_safe {
- // ownership of this array has been transferred to native
- // but we still need to unpack dictionary arrays
- copy_or_unpack_array(&array, &CopyMode::UnpackOrClone)?
- } else {
- // it is necessary to copy the array because the contents may
be
- // overwritten on the JVM side in the future
- copy_array(&array)
- };
+ let array = if arrow_ffi_safe {
+ // ownership of this array has been transferred to native
+ // but we still need to unpack dictionary arrays
+ copy_or_unpack_array(&array, &CopyMode::UnpackOrClone)?
+ } else {
+ // it is necessary to copy the array because the contents
may be
+ // overwritten on the JVM side in the future
+ copy_array(&array)
+ };
- inputs.push(array);
+ inputs.push(array);
- // Drop the Arcs to avoid memory leak
- unsafe {
- Rc::from_raw(array_ptr as *const FFI_ArrowArray);
- Rc::from_raw(schema_ptr as *const FFI_ArrowSchema);
+ // Drop the Arcs to avoid memory leak
+ unsafe {
+ Rc::from_raw(array_ptr as *const FFI_ArrowArray);
+ Rc::from_raw(schema_ptr as *const FFI_ArrowSchema);
+ }
}
- }
- // If selection was applied, determine the actual row count from the
selected arrays
- let actual_num_rows = if let Some(ref selection_arrays) =
selection_indices_arrays {
- if !selection_arrays.is_empty() {
- // Use the length of the first selection array as the actual
row count
- selection_arrays[0].len()
+ // If selection was applied, determine the actual row count from
the selected arrays
+ let actual_num_rows = if let Some(ref selection_arrays) =
selection_indices_arrays {
+ if !selection_arrays.is_empty() {
+ // Use the length of the first selection array as the
actual row count
+ selection_arrays[0].len()
+ } else {
+ num_rows as usize
+ }
} else {
num_rows as usize
- }
- } else {
- num_rows as usize
- };
+ };
- Ok(InputBatch::new(inputs, Some(actual_num_rows)))
+ Ok(InputBatch::new(inputs, Some(actual_num_rows)))
+ })
}
/// Allocates Arrow FFI structures and calls JNI to get the next batch
data.
/// Returns the number of rows and the allocated array/schema addresses.
fn allocate_and_fetch_batch(
- env: &mut jni::JNIEnv,
+ env: &mut jni::Env,
iter: &JObject,
num_cols: usize,
) -> Result<(i32, Vec<i64>, Vec<i64>), CometError> {
@@ -282,17 +280,17 @@ impl ScanExec {
}
// Prepare the java array parameters
- let long_array_addrs = env.new_long_array(num_cols as jsize)?;
- let long_schema_addrs = env.new_long_array(num_cols as jsize)?;
+ let long_array_addrs = env.new_long_array(num_cols)?;
+ let long_schema_addrs = env.new_long_array(num_cols)?;
- env.set_long_array_region(&long_array_addrs, 0, &array_addrs)?;
- env.set_long_array_region(&long_schema_addrs, 0, &schema_addrs)?;
+ long_array_addrs.set_region(env, 0, &array_addrs)?;
+ long_schema_addrs.set_region(env, 0, &schema_addrs)?;
let array_obj = JObject::from(long_array_addrs);
let schema_obj = JObject::from(long_schema_addrs);
- let array_obj = JValueGen::Object(array_obj.as_ref());
- let schema_obj = JValueGen::Object(schema_obj.as_ref());
+ let array_obj = JValue::Object(array_obj.as_ref());
+ let schema_obj = JValue::Object(schema_obj.as_ref());
let num_rows: i32 = unsafe {
jni_call!(env,
@@ -309,7 +307,7 @@ impl ScanExec {
/// Checks for selection vectors and exports selection indices if needed.
/// Returns selection arrays if they exist (applies to all columns).
fn get_selection_indices(
- env: &mut jni::JNIEnv,
+ env: &mut jni::Env,
iter: &JObject,
num_cols: usize,
) -> Result<Option<Vec<ArrayRef>>, CometError> {
@@ -318,7 +316,7 @@ impl ScanExec {
jni_call!(env,
comet_batch_iterator(iter).has_selection_vectors() ->
jni::sys::jboolean)?
};
- let has_selection_vectors = has_selection_vectors_result != 0;
+ let has_selection_vectors = has_selection_vectors_result;
let selection_indices_arrays = if has_selection_vectors {
// Allocate arrays for selection indices export (one per column)
@@ -333,17 +331,17 @@ impl ScanExec {
}
// Prepare JNI arrays for the export call
- let indices_array_obj = env.new_long_array(num_cols as jsize)?;
- let indices_schema_obj = env.new_long_array(num_cols as jsize)?;
- env.set_long_array_region(&indices_array_obj, 0,
&indices_array_addrs)?;
- env.set_long_array_region(&indices_schema_obj, 0,
&indices_schema_addrs)?;
+ let indices_array_obj = env.new_long_array(num_cols)?;
+ let indices_schema_obj = env.new_long_array(num_cols)?;
+ indices_array_obj.set_region(env, 0, &indices_array_addrs)?;
+ indices_schema_obj.set_region(env, 0, &indices_schema_addrs)?;
// Export selection indices from JVM
let _exported_count: i32 = unsafe {
jni_call!(env,
comet_batch_iterator(iter).export_selection_indices(
-
JValueGen::Object(JObject::from(indices_array_obj).as_ref()),
-
JValueGen::Object(JObject::from(indices_schema_obj).as_ref())
+
JValue::Object(JObject::from(indices_array_obj).as_ref()),
+
JValue::Object(JObject::from(indices_schema_obj).as_ref())
) -> i32)?
};
diff --git a/native/core/src/execution/operators/shuffle_scan.rs
b/native/core/src/execution/operators/shuffle_scan.rs
index a1ad52310..1f3810ee3 100644
--- a/native/core/src/execution/operators/shuffle_scan.rs
+++ b/native/core/src/execution/operators/shuffle_scan.rs
@@ -35,7 +35,7 @@ use datafusion::{
physical_plan::{ExecutionPlan, *},
};
use futures::Stream;
-use jni::objects::{GlobalRef, JByteBuffer, JObject};
+use jni::objects::{Global, JByteBuffer, JObject};
use std::{
any::Any,
pin::Pin,
@@ -53,7 +53,7 @@ pub struct ShuffleScanExec {
/// The ID of the execution context that owns this subquery.
pub exec_context_id: i64,
/// The input source: a global reference to a JVM
CometShuffleBlockIterator object.
- pub input_source: Option<Arc<GlobalRef>>,
+ pub input_source: Option<Arc<Global<JObject<'static>>>>,
/// The data types of columns in the shuffle output.
pub data_types: Vec<DataType>,
/// Schema of the shuffle output.
@@ -73,7 +73,7 @@ pub struct ShuffleScanExec {
impl ShuffleScanExec {
pub fn new(
exec_context_id: i64,
- input_source: Option<Arc<GlobalRef>>,
+ input_source: Option<Arc<Global<JObject<'static>>>>,
data_types: Vec<DataType>,
) -> Result<Self, CometError> {
let metrics_set = ExecutionPlanMetricsSet::default();
@@ -149,55 +149,55 @@ impl ShuffleScanExec {
))));
}
- let mut env = JVMClasses::get_env()?;
+ JVMClasses::with_env(|env| {
+ // has_next() reads the next block and returns its length, or -1
if EOF
+ let block_length: i32 = unsafe {
+ jni_call!(env,
+ comet_shuffle_block_iterator(iter).has_next() -> i32)?
+ };
- // has_next() reads the next block and returns its length, or -1 if EOF
- let block_length: i32 = unsafe {
- jni_call!(&mut env,
- comet_shuffle_block_iterator(iter).has_next() -> i32)?
- };
-
- if block_length == -1 {
- return Ok(InputBatch::EOF);
- }
-
- // Get the DirectByteBuffer containing the compressed shuffle block
- let buffer: JObject = unsafe {
- jni_call!(&mut env,
- comet_shuffle_block_iterator(iter).get_buffer() -> JObject)?
- };
-
- let byte_buffer = JByteBuffer::from(buffer);
- let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?;
- let length = block_length as usize;
- let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer,
length) };
+ if block_length == -1 {
+ return Ok(InputBatch::EOF);
+ }
- // Decode the compressed IPC data
- let mut timer = decode_time.timer();
- let batch = read_ipc_compressed(slice)?;
- timer.stop();
+ // Get the DirectByteBuffer containing the compressed shuffle block
+ let buffer: JObject = unsafe {
+ jni_call!(env,
+ comet_shuffle_block_iterator(iter).get_buffer() ->
JObject)?
+ };
- let num_rows = batch.num_rows();
+ let byte_buffer = unsafe { JByteBuffer::from_raw(env,
buffer.into_raw()) };
+ let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?;
+ let length = block_length as usize;
+ let slice: &[u8] = unsafe {
std::slice::from_raw_parts(raw_pointer, length) };
- // Extract column arrays, unpacking any dictionary-encoded columns.
- // Native shuffle may dictionary-encode string/binary columns for
efficiency,
- // but downstream DataFusion operators expect the value types declared
in the
- // schema (e.g. Utf8, not Dictionary<Int32, Utf8>).
- let columns: Vec<ArrayRef> = batch
- .columns()
- .iter()
- .map(|col| unpack_dictionary(col))
- .collect();
+ // Decode the compressed IPC data
+ let mut timer = decode_time.timer();
+ let batch = read_ipc_compressed(slice)?;
+ timer.stop();
- debug_assert_eq!(
- columns.len(),
- data_types.len(),
- "Shuffle block column count mismatch: got {} but expected {}",
- columns.len(),
- data_types.len()
- );
+ let num_rows = batch.num_rows();
+
+ // Extract column arrays, unpacking any dictionary-encoded columns.
+ // Native shuffle may dictionary-encode string/binary columns for
efficiency,
+ // but downstream DataFusion operators expect the value types
declared in the
+ // schema (e.g. Utf8, not Dictionary<Int32, Utf8>).
+ let columns: Vec<ArrayRef> = batch
+ .columns()
+ .iter()
+ .map(|col| unpack_dictionary(col))
+ .collect();
+
+ debug_assert_eq!(
+ columns.len(),
+ data_types.len(),
+ "Shuffle block column count mismatch: got {} but expected {}",
+ columns.len(),
+ data_types.len()
+ );
- Ok(InputBatch::new(columns, Some(num_rows)))
+ Ok(InputBatch::new(columns, Some(num_rows)))
+ })
}
}
diff --git a/native/core/src/execution/planner.rs
b/native/core/src/execution/planner.rs
index 0f96c829e..d1487aaea 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -127,7 +127,7 @@ use datafusion_comet_spark_expr::{
WideDecimalBinaryExpr, WideDecimalOp,
};
use itertools::Itertools;
-use jni::objects::GlobalRef;
+use jni::objects::{Global, JObject};
use num::{BigInt, ToPrimitive};
use object_store::path::Path;
use std::cmp::max;
@@ -910,7 +910,7 @@ impl PhysicalPlanner {
pub(crate) fn create_plan<'a>(
&'a self,
spark_plan: &'a Operator,
- inputs: &mut Vec<Arc<GlobalRef>>,
+ inputs: &mut Vec<Arc<Global<JObject<'static>>>>,
partition_count: usize,
) -> PlanCreationResult {
// Try to use the modular registry first - this automatically handles
any registered operator types
@@ -1817,7 +1817,7 @@ impl PhysicalPlanner {
#[allow(clippy::too_many_arguments)]
fn parse_join_parameters(
&self,
- inputs: &mut Vec<Arc<GlobalRef>>,
+ inputs: &mut Vec<Arc<Global<JObject<'static>>>>,
children: &[Operator],
left_join_keys: &[Expr],
right_join_keys: &[Expr],
diff --git a/native/core/src/execution/planner/operator_registry.rs
b/native/core/src/execution/planner/operator_registry.rs
index cad5df40c..eb3118446 100644
--- a/native/core/src/execution/planner/operator_registry.rs
+++ b/native/core/src/execution/planner/operator_registry.rs
@@ -23,7 +23,7 @@ use std::{
};
use datafusion_comet_proto::spark_operator::Operator;
-use jni::objects::GlobalRef;
+use jni::objects::{Global, JObject};
use super::{PhysicalPlanner, PlanCreationResult};
use crate::execution::operators::ExecutionError;
@@ -34,7 +34,7 @@ pub trait OperatorBuilder: Send + Sync {
fn build(
&self,
spark_plan: &datafusion_comet_proto::spark_operator::Operator,
- inputs: &mut Vec<Arc<GlobalRef>>,
+ inputs: &mut Vec<Arc<Global<JObject<'static>>>>,
partition_count: usize,
planner: &PhysicalPlanner,
) -> PlanCreationResult;
@@ -94,7 +94,7 @@ impl OperatorRegistry {
pub fn create_plan(
&self,
spark_operator: &Operator,
- inputs: &mut Vec<Arc<GlobalRef>>,
+ inputs: &mut Vec<Arc<Global<JObject<'static>>>>,
partition_count: usize,
planner: &PhysicalPlanner,
) -> PlanCreationResult {
diff --git a/native/core/src/lib.rs b/native/core/src/lib.rs
index 1b87dc1db..a9fdf6cdc 100644
--- a/native/core/src/lib.rs
+++ b/native/core/src/lib.rs
@@ -31,7 +31,7 @@ extern crate datafusion_comet_jni_bridge;
use jni::{
objects::{JClass, JString},
- JNIEnv,
+ EnvUnowned,
};
use log::info;
use log4rs::{
@@ -87,7 +87,7 @@ static GLOBAL: MiMalloc = MiMalloc;
#[no_mangle]
pub extern "system" fn Java_org_apache_comet_NativeBase_init(
- e: JNIEnv,
+ e: EnvUnowned,
_: JClass,
log_conf_path: JString,
log_level: JString,
@@ -95,14 +95,14 @@ pub extern "system" fn
Java_org_apache_comet_NativeBase_init(
// Initialize the error handling to capture panic backtraces
errors::init();
- try_unwrap_or_throw(&e, |mut env| {
- let path: String = env.get_string(&log_conf_path)?.into();
+ try_unwrap_or_throw(&e, |env| {
+ let path: String = log_conf_path.try_to_string(env)?;
// empty path means there is no custom log4rs config file provided, so
fallback to use
// the default configuration
let log_config = if path.is_empty() {
- let log_level: String = match env.get_string(&log_level) {
- Ok(level) => level.into(),
+ let log_level: String = match log_level.try_to_string(env) {
+ Ok(level) => level,
Err(_) => "info".parse().unwrap(),
};
default_logger_config(&log_level)
@@ -136,12 +136,12 @@ const LOG_PATTERN: &str = "{d(%y/%m/%d %H:%M:%S)} {l}
{f}: {m}{n}";
/// * `0` (false) if the feature is disabled or unknown
#[no_mangle]
pub extern "system" fn Java_org_apache_comet_NativeBase_isFeatureEnabled(
- env: JNIEnv,
+ env: EnvUnowned,
_: JClass,
feature_name: JString,
) -> jni::sys::jboolean {
- try_unwrap_or_throw(&env, |mut env| {
- let feature: String = env.get_string(&feature_name)?.into();
+ try_unwrap_or_throw(&env, |env| {
+ let feature: String = feature_name.try_to_string(env)?;
let enabled = match feature.as_str() {
"jemalloc" => cfg!(feature = "jemalloc"),
@@ -150,7 +150,7 @@ pub extern "system" fn
Java_org_apache_comet_NativeBase_isFeatureEnabled(
_ => false, // Unknown features return false
};
- Ok(enabled as u8)
+ Ok(enabled)
})
}
diff --git a/native/core/src/parquet/encryption_support.rs
b/native/core/src/parquet/encryption_support.rs
index 4540c217d..afcae086a 100644
--- a/native/core/src/parquet/encryption_support.rs
+++ b/native/core/src/parquet/encryption_support.rs
@@ -23,7 +23,7 @@ use datafusion::common::extensions_options;
use datafusion::config::EncryptionFactoryOptions;
use datafusion::error::DataFusionError;
use datafusion::execution::parquet_encryption::EncryptionFactory;
-use jni::objects::{GlobalRef, JMethodID};
+use jni::objects::{Global, JMethodID, JObject};
use object_store::path::Path;
use parquet::encryption::decrypt::{FileDecryptionProperties, KeyRetriever};
use parquet::encryption::encrypt::FileEncryptionProperties;
@@ -42,7 +42,7 @@ extensions_options! {
#[derive(Debug)]
pub struct CometEncryptionFactory {
- pub(crate) key_unwrapper: GlobalRef,
+ pub(crate) key_unwrapper: Arc<Global<JObject<'static>>>,
}
/// `EncryptionFactory` is a DataFusion trait for types that generate
@@ -73,7 +73,7 @@ impl EncryptionFactory for CometEncryptionFactory {
let config: CometEncryptionConfig = options.to_extension_options()?;
let full_path: String = config.uri_base + file_path.as_ref();
- let key_retriever = CometKeyRetriever::new(&full_path,
self.key_unwrapper.clone())
+ let key_retriever = CometKeyRetriever::new(&full_path,
Arc::clone(&self.key_unwrapper))
.map_err(|e| DataFusionError::External(Box::new(e)))?;
let decryption_properties =
FileDecryptionProperties::with_key_retriever(Arc::new(key_retriever)).build()?;
@@ -83,26 +83,29 @@ impl EncryptionFactory for CometEncryptionFactory {
pub struct CometKeyRetriever {
file_path: String,
- key_unwrapper: GlobalRef,
+ key_unwrapper: Arc<Global<JObject<'static>>>,
get_key_method_id: JMethodID,
}
impl CometKeyRetriever {
- pub fn new(file_path: &str, key_unwrapper: GlobalRef) -> Result<Self,
ExecutionError> {
- let mut env = JVMClasses::get_env()?;
-
- Ok(CometKeyRetriever {
- file_path: file_path.to_string(),
- key_unwrapper,
- get_key_method_id: env
- .get_method_id(
- "org/apache/comet/parquet/CometFileKeyUnwrapper",
- "getKey",
- "(Ljava/lang/String;[B)[B",
- )
- .map_err(|e| {
- ExecutionError::GeneralError(format!("Failed to get JNI
method ID: {}", e))
- })?,
+ pub fn new(
+ file_path: &str,
+ key_unwrapper: Arc<Global<JObject<'static>>>,
+ ) -> Result<Self, ExecutionError> {
+ JVMClasses::with_env(|env| {
+ Ok(CometKeyRetriever {
+ file_path: file_path.to_string(),
+ key_unwrapper,
+ get_key_method_id: env
+ .get_method_id(
+
jni::jni_str!("org/apache/comet/parquet/CometFileKeyUnwrapper"),
+ jni::jni_str!("getKey"),
+ jni::jni_sig!("(Ljava/lang/String;[B)[B"),
+ )
+ .map_err(|e| {
+ ExecutionError::GeneralError(format!("Failed to get
JNI method ID: {}", e))
+ })?,
+ })
})
}
}
@@ -110,63 +113,61 @@ impl CometKeyRetriever {
impl KeyRetriever for CometKeyRetriever {
/// Get a data encryption key using the metadata stored in the Parquet
file.
fn retrieve_key(&self, key_metadata: &[u8]) ->
datafusion::parquet::errors::Result<Vec<u8>> {
- use jni::{objects::JObject, signature::ReturnType};
-
- // Get JNI environment
- let mut env = JVMClasses::get_env()?;
-
- // Get the key unwrapper instance from GlobalRef
- let unwrapper_instance = self.key_unwrapper.as_obj();
-
- let instance: JObject = unsafe {
JObject::from_raw(unwrapper_instance.as_raw()) };
-
- // Convert file path to JString
- let file_path_jstring = env
- .new_string(&self.file_path)
- .map_err(|e| ParquetError::General(format!("Failed to create
JString: {}", e)))?;
-
- // Convert key_metadata to JByteArray
- let key_metadata_array = env
- .byte_array_from_slice(key_metadata)
- .map_err(|e| ParquetError::General(format!("Failed to create byte
array: {}", e)))?;
-
- // Call instance method FileKeyUnwrapper.getKey(String, byte[]) ->
byte[]
- let result = unsafe {
- env.call_method_unchecked(
- instance,
- self.get_key_method_id,
- ReturnType::Array,
- &[
- jni::objects::JValue::from(&file_path_jstring).as_jni(),
- jni::objects::JValue::from(&key_metadata_array).as_jni(),
- ],
- )
- };
-
- // Check for Java exceptions first, before processing the result
- if let Some(exception) = check_exception(&mut env).map_err(|e| {
- ParquetError::General(format!("Failed to check for Java exception:
{}", e))
- })? {
- return Err(ParquetError::General(format!(
- "Java exception during key retrieval: {}",
- exception
- )));
- }
-
- let result =
- result.map_err(|e| ParquetError::General(format!("JNI method call
failed: {}", e)))?;
-
- // Extract the byte array from the result
- let result_array = result
- .l()
- .map_err(|e| ParquetError::General(format!("Failed to extract
result: {}", e)))?;
-
- // Convert JObject to JByteArray and then to Vec<u8>
- let byte_array: jni::objects::JByteArray = result_array.into();
-
- let result_vec = env
- .convert_byte_array(&byte_array)
- .map_err(|e| ParquetError::General(format!("Failed to convert byte
array: {}", e)))?;
- Ok(result_vec)
+ use jni::signature::ReturnType;
+
+ JVMClasses::with_env(|env| {
+ // Get the key unwrapper instance from Global
+ let instance = self.key_unwrapper.as_obj();
+
+ // Convert file path to JString
+ let file_path_jstring = env
+ .new_string(&self.file_path)
+ .map_err(|e| ParquetError::General(format!("Failed to create
JString: {}", e)))?;
+
+ // Convert key_metadata to JByteArray
+ let key_metadata_array =
env.byte_array_from_slice(key_metadata).map_err(|e| {
+ ParquetError::General(format!("Failed to create byte array:
{}", e))
+ })?;
+
+ // Call instance method FileKeyUnwrapper.getKey(String, byte[]) ->
byte[]
+ let result = unsafe {
+ env.call_method_unchecked(
+ instance,
+ self.get_key_method_id,
+ ReturnType::Array,
+ &[
+
jni::objects::JValue::from(&file_path_jstring).as_jni(),
+
jni::objects::JValue::from(&key_metadata_array).as_jni(),
+ ],
+ )
+ };
+
+ // Check for Java exceptions first, before processing the result
+ if let Some(exception) = check_exception(env).map_err(|e| {
+ ParquetError::General(format!("Failed to check for Java
exception: {}", e))
+ })? {
+ return Err(ParquetError::General(format!(
+ "Java exception during key retrieval: {}",
+ exception
+ )));
+ }
+
+ let result = result
+ .map_err(|e| ParquetError::General(format!("JNI method call
failed: {}", e)))?;
+
+ // Extract the byte array from the result
+ let result_array = result
+ .l()
+ .map_err(|e| ParquetError::General(format!("Failed to extract
result: {}", e)))?;
+
+ // Convert JObject to JByteArray and then to Vec<u8>
+ let byte_array =
+ unsafe { jni::objects::JByteArray::from_raw(env,
result_array.into_raw()) };
+
+ let result_vec = env.convert_byte_array(&byte_array).map_err(|e| {
+ ParquetError::General(format!("Failed to convert byte array:
{}", e))
+ })?;
+ Ok(result_vec)
+ })
}
}
diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs
index c1ff725c0..61ff4fc0d 100644
--- a/native/core/src/parquet/mod.rs
+++ b/native/core/src/parquet/mod.rs
@@ -38,10 +38,10 @@ use std::{boxed::Box, sync::Arc};
use crate::errors::{try_unwrap_or_throw, CometError};
/// JNI exposed methods
-use jni::JNIEnv;
use jni::{
- objects::{GlobalRef, JClass},
+ objects::{Global, JClass},
sys::{jboolean, jint, jlong},
+ Env, EnvUnowned,
};
use self::util::jni::TypePromotionInfo;
@@ -77,7 +77,7 @@ struct Context {
#[no_mangle]
pub extern "system" fn Java_org_apache_comet_parquet_Native_initColumnReader(
- e: JNIEnv,
+ e: EnvUnowned,
_jclass: JClass,
primitive_type: jint,
logical_type: jint,
@@ -99,9 +99,9 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_initColumnReader(
use_decimal_128: jboolean,
use_legacy_date_timestamp: jboolean,
) -> jlong {
- try_unwrap_or_throw(&e, |mut env| {
+ try_unwrap_or_throw(&e, |env| {
let desc = convert_column_descriptor(
- &mut env,
+ env,
primitive_type,
logical_type,
max_dl,
@@ -126,8 +126,8 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_initColumnReader(
desc,
promotion_info,
batch_size as usize,
- use_decimal_128 != 0,
- use_legacy_date_timestamp != 0,
+ use_decimal_128,
+ use_legacy_date_timestamp,
),
};
let res = Box::new(ctx);
@@ -139,7 +139,7 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_initColumnReader(
/// This function is inheritly unsafe since it deals with raw pointers passed
from JNI.
#[no_mangle]
pub unsafe extern "system" fn
Java_org_apache_comet_parquet_Native_setDictionaryPage(
- e: JNIEnv,
+ e: EnvUnowned,
_jclass: JClass,
handle: jlong,
page_value_count: jint,
@@ -153,9 +153,9 @@ pub unsafe extern "system" fn
Java_org_apache_comet_parquet_Native_setDictionary
let encoding = convert_encoding(encoding);
// copy the input on-heap buffer to native
- let page_len = env.get_array_length(&page_data)?;
- let mut buffer = MutableBuffer::from_len_zeroed(page_len as usize);
- env.get_byte_array_region(&page_data, 0,
from_u8_slice(buffer.as_slice_mut()))?;
+ let page_len = page_data.len(env)?;
+ let mut buffer = MutableBuffer::from_len_zeroed(page_len);
+ page_data.get_region(env, 0, from_u8_slice(buffer.as_slice_mut()))?;
reader.set_dictionary_page(page_value_count as usize, buffer.into(),
encoding);
Ok(())
@@ -166,7 +166,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_parquet_Native_setDictionary
/// This function is inheritly unsafe since it deals with raw pointers passed
from JNI.
#[no_mangle]
pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setPageV1(
- e: JNIEnv,
+ e: EnvUnowned,
_jclass: JClass,
handle: jlong,
page_value_count: jint,
@@ -180,9 +180,9 @@ pub unsafe extern "system" fn
Java_org_apache_comet_parquet_Native_setPageV1(
let encoding = convert_encoding(value_encoding);
// copy the input on-heap buffer to native
- let page_len = env.get_array_length(&page_data)?;
- let mut buffer = MutableBuffer::from_len_zeroed(page_len as usize);
- env.get_byte_array_region(&page_data, 0,
from_u8_slice(buffer.as_slice_mut()))?;
+ let page_len = page_data.len(env)?;
+ let mut buffer = MutableBuffer::from_len_zeroed(page_len);
+ page_data.get_region(env, 0, from_u8_slice(buffer.as_slice_mut()))?;
reader.set_page_v1(page_value_count as usize, buffer.into(), encoding);
Ok(())
@@ -193,7 +193,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_parquet_Native_setPageV1(
/// This function is inheritly unsafe since it deals with raw pointers passed
from JNI.
#[no_mangle]
pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setPageV2(
- e: JNIEnv,
+ e: EnvUnowned,
_jclass: JClass,
handle: jlong,
page_value_count: jint,
@@ -209,17 +209,17 @@ pub unsafe extern "system" fn
Java_org_apache_comet_parquet_Native_setPageV2(
let encoding = convert_encoding(value_encoding);
// copy the input on-heap buffer to native
- let dl_len = env.get_array_length(&def_level_data)?;
- let mut dl_buffer = MutableBuffer::from_len_zeroed(dl_len as usize);
- env.get_byte_array_region(&def_level_data, 0,
from_u8_slice(dl_buffer.as_slice_mut()))?;
+ let dl_len = def_level_data.len(env)?;
+ let mut dl_buffer = MutableBuffer::from_len_zeroed(dl_len);
+ def_level_data.get_region(env, 0,
from_u8_slice(dl_buffer.as_slice_mut()))?;
- let rl_len = env.get_array_length(&rep_level_data)?;
- let mut rl_buffer = MutableBuffer::from_len_zeroed(rl_len as usize);
- env.get_byte_array_region(&rep_level_data, 0,
from_u8_slice(rl_buffer.as_slice_mut()))?;
+ let rl_len = rep_level_data.len(env)?;
+ let mut rl_buffer = MutableBuffer::from_len_zeroed(rl_len);
+ rep_level_data.get_region(env, 0,
from_u8_slice(rl_buffer.as_slice_mut()))?;
- let v_len = env.get_array_length(&value_data)?;
- let mut v_buffer = MutableBuffer::from_len_zeroed(v_len as usize);
- env.get_byte_array_region(&value_data, 0,
from_u8_slice(v_buffer.as_slice_mut()))?;
+ let v_len = value_data.len(env)?;
+ let mut v_buffer = MutableBuffer::from_len_zeroed(v_len);
+ value_data.get_region(env, 0, from_u8_slice(v_buffer.as_slice_mut()))?;
reader.set_page_v2(
page_value_count as usize,
@@ -234,7 +234,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_parquet_Native_setPageV2(
#[no_mangle]
pub extern "system" fn Java_org_apache_comet_parquet_Native_resetBatch(
- env: JNIEnv,
+ env: EnvUnowned,
_jclass: JClass,
handle: jlong,
) {
@@ -247,7 +247,7 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_resetBatch(
#[no_mangle]
pub extern "system" fn Java_org_apache_comet_parquet_Native_readBatch(
- e: JNIEnv,
+ e: EnvUnowned,
_jclass: JClass,
handle: jlong,
batch_size: jint,
@@ -259,14 +259,14 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_readBatch(
reader.read_batch(batch_size as usize, null_pad_size as usize);
let res = env.new_int_array(2)?;
let buf: [i32; 2] = [num_values as i32, num_nulls as i32];
- env.set_int_array_region(&res, 0, &buf)?;
+ res.set_region(env, 0, &buf)?;
Ok(res.into_raw())
})
}
#[no_mangle]
pub extern "system" fn Java_org_apache_comet_parquet_Native_skipBatch(
- env: JNIEnv,
+ env: EnvUnowned,
_jclass: JClass,
handle: jlong,
batch_size: jint,
@@ -274,13 +274,13 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_skipBatch(
) -> jint {
try_unwrap_or_throw(&env, |_| {
let reader = get_reader(handle)?;
- Ok(reader.skip_batch(batch_size as usize, discard == 0) as jint)
+ Ok(reader.skip_batch(batch_size as usize, !discard) as jint) // jboolean is now `bool`: old `discard == 0` == `!discard`
})
}
#[no_mangle]
pub extern "system" fn Java_org_apache_comet_parquet_Native_currentBatch(
- e: JNIEnv,
+ e: EnvUnowned,
_jclass: JClass,
handle: jlong,
array_addr: jlong,
@@ -311,7 +311,7 @@ fn get_reader<'a>(handle: jlong) -> Result<&'a mut
ColumnReader, CometError> {
#[no_mangle]
pub extern "system" fn Java_org_apache_comet_parquet_Native_closeColumnReader(
- env: JNIEnv,
+ env: EnvUnowned,
_jclass: JClass,
handle: jlong,
) {
@@ -338,7 +338,7 @@ enum ParquetReaderState {
/// Parquet read context maintained across multiple JNI calls.
struct BatchContext {
native_plan: Arc<SparkPlan>,
- metrics_node: Arc<GlobalRef>,
+ metrics_node: Arc<Global<JObject<'static>>>,
batch_stream: Option<SendableRecordBatchStream>,
current_batch: Option<RecordBatch>,
reader_state: ParquetReaderState,
@@ -375,16 +375,20 @@ fn get_file_groups_single_file(
}
pub fn get_object_store_options(
- env: &mut JNIEnv,
+ env: &mut Env,
map_object: JObject,
) -> Result<HashMap<String, String>, CometError> {
- let map = JMap::from_env(env, &map_object)?;
+ let map = env.cast_local::<JMap>(map_object)?;
// Convert to a HashMap
let mut collected_map = HashMap::new();
map.iter(env).and_then(|mut iter| {
- while let Some((key, value)) = iter.next(env)? {
- let key_string: String =
String::from(env.get_string(&JString::from(key))?);
- let value_string: String =
String::from(env.get_string(&JString::from(value))?);
+ while let Some(entry) = iter.next(env)? {
+ let key = entry.key(env)?;
+ let value = entry.value(env)?;
+ let key = unsafe { JString::from_raw(env, key.into_raw()) };
+ let value = unsafe { JString::from_raw(env, value.into_raw()) };
+ let key_string = key.try_to_string(env)?;
+ let value_string = value.try_to_string(env)?;
collected_map.insert(key_string, value_string);
}
Ok(())
@@ -397,18 +401,18 @@ pub fn get_object_store_options(
/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
#[no_mangle]
pub unsafe extern "system" fn
Java_org_apache_comet_parquet_Native_validateObjectStoreConfig(
- e: JNIEnv,
+ e: EnvUnowned,
_jclass: JClass,
file_path: JString,
object_store_options: JObject,
) {
- try_unwrap_or_throw(&e, |mut env| {
+ try_unwrap_or_throw(&e, |env| {
let session_config = SessionConfig::new();
let planner =
PhysicalPlanner::new(Arc::new(SessionContext::new_with_config(session_config)),
0);
let session_ctx = planner.session_ctx();
- let path: String = env.get_string(&file_path).unwrap().into();
- let object_store_config = get_object_store_options(&mut env,
object_store_options)?;
+ let path: String = file_path.try_to_string(env).unwrap();
+ let object_store_config = get_object_store_options(env,
object_store_options)?;
let (_, _) = prepare_object_store_with_configs(
session_ctx.runtime_env(),
path.clone(),
@@ -422,7 +426,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_parquet_Native_validateObjec
/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
#[no_mangle]
pub unsafe extern "system" fn
Java_org_apache_comet_parquet_Native_initRecordBatchReader(
- e: JNIEnv,
+ e: EnvUnowned,
_jclass: JClass,
file_path: JString,
file_size: jlong,
@@ -438,16 +442,16 @@ pub unsafe extern "system" fn
Java_org_apache_comet_parquet_Native_initRecordBat
key_unwrapper_obj: JObject,
metrics_node: JObject,
) -> jlong {
- try_unwrap_or_throw(&e, |mut env| unsafe {
- JVMClasses::init(&mut env);
+ try_unwrap_or_throw(&e, |env| unsafe {
+ JVMClasses::init(env);
let session_config = SessionConfig::new().with_batch_size(batch_size
as usize);
let planner =
PhysicalPlanner::new(Arc::new(SessionContext::new_with_config(session_config)),
0);
let session_ctx = planner.session_ctx();
- let path: String = env.get_string(&file_path).unwrap().into();
+ let path: String = file_path.try_to_string(env).unwrap();
- let object_store_config = get_object_store_options(&mut env,
object_store_options)?;
+ let object_store_config = get_object_store_options(env,
object_store_options)?;
let (object_store_url, object_store_path) =
prepare_object_store_with_configs(
session_ctx.runtime_env(),
path.clone(),
@@ -469,21 +473,21 @@ pub unsafe extern "system" fn
Java_org_apache_comet_parquet_Native_initRecordBat
} else {
None
};
- let starts = env.get_array_elements(&starts, ReleaseMode::NoCopyBack)?;
+ let starts = starts.get_elements(env, ReleaseMode::NoCopyBack)?;
let starts = core::slice::from_raw_parts_mut(starts.as_ptr(),
starts.len());
- let lengths = env.get_array_elements(&lengths,
ReleaseMode::NoCopyBack)?;
+ let lengths = lengths.get_elements(env, ReleaseMode::NoCopyBack)?;
let lengths = core::slice::from_raw_parts_mut(lengths.as_ptr(),
lengths.len());
let file_groups =
get_file_groups_single_file(&object_store_path, file_size as u64,
starts, lengths);
- let session_timezone: String =
env.get_string(&session_timezone).unwrap().into();
+ let session_timezone: String =
session_timezone.try_to_string(env).unwrap();
// Handle key unwrapper for encrypted files
let encryption_enabled = if !key_unwrapper_obj.is_null() {
let encryption_factory = CometEncryptionFactory {
- key_unwrapper: jni_new_global_ref!(env, key_unwrapper_obj)?,
+ key_unwrapper: Arc::new(jni_new_global_ref!(env,
key_unwrapper_obj)?),
};
session_ctx
.runtime_env()
@@ -529,11 +533,11 @@ pub unsafe extern "system" fn
Java_org_apache_comet_parquet_Native_initRecordBat
#[no_mangle]
pub extern "system" fn
Java_org_apache_comet_parquet_Native_readNextRecordBatch(
- e: JNIEnv,
+ e: EnvUnowned,
_jclass: JClass,
handle: jlong,
) -> jint {
- try_unwrap_or_throw(&e, |mut env| {
+ try_unwrap_or_throw(&e, |env| {
let context = get_batch_context(handle)?;
let mut rows_read: i32 = 0;
let batch_stream = context.batch_stream.as_mut().unwrap();
@@ -555,11 +559,7 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_readNextRecordBatch(
Poll::Ready(None) => {
// EOF
- update_comet_metric(
- &mut env,
- context.metrics_node.as_obj(),
- &context.native_plan,
- )?;
+ update_comet_metric(env, context.metrics_node.as_obj(),
&context.native_plan)?;
context.current_batch = None;
context.reader_state = ParquetReaderState::Complete;
@@ -578,7 +578,7 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_readNextRecordBatch(
#[no_mangle]
pub extern "system" fn Java_org_apache_comet_parquet_Native_currentColumnBatch(
- e: JNIEnv,
+ e: EnvUnowned,
_jclass: JClass,
handle: jlong,
column_idx: jint,
@@ -601,7 +601,7 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_currentColumnBatch(
#[no_mangle]
pub extern "system" fn
Java_org_apache_comet_parquet_Native_closeRecordBatchReader(
- env: JNIEnv,
+ env: EnvUnowned,
_jclass: JClass,
handle: jlong,
) {
diff --git a/native/core/src/parquet/util/jni.rs
b/native/core/src/parquet/util/jni.rs
index 2223f508f..1fb9519c8 100644
--- a/native/core/src/parquet/util/jni.rs
+++ b/native/core/src/parquet/util/jni.rs
@@ -21,7 +21,7 @@ use jni::{
errors::Result as JNIResult,
objects::{JObjectArray, JString},
sys::{jboolean, jint},
- JNIEnv,
+ Env,
};
use arrow::error::ArrowError;
@@ -37,7 +37,7 @@ use url::{ParseError, Url};
/// Convert primitives from Spark side into a `ColumnDescriptor`.
#[allow(clippy::too_many_arguments)]
pub fn convert_column_descriptor(
- env: &mut JNIEnv,
+ env: &mut Env,
physical_type_id: jint,
logical_type_id: jint,
max_dl: jint,
@@ -131,12 +131,13 @@ impl TypePromotionInfo {
}
}
-fn convert_column_path(env: &mut JNIEnv, path_array: JObjectArray) ->
JNIResult<ColumnPath> {
- let array_len = env.get_array_length(&path_array)?;
+fn convert_column_path(env: &mut Env, path_array: JObjectArray) ->
JNIResult<ColumnPath> {
+ let array_len = path_array.len(env)?;
let mut res: Vec<String> = Vec::new();
for i in 0..array_len {
- let p: JString = env.get_object_array_element(&path_array, i)?.into();
- res.push(env.get_string(&p)?.into());
+ let p = path_array.get_element(env, i)?;
+ let p: JString = unsafe { JString::from_raw(env, p.into_raw()) };
+ res.push(p.try_to_string(env)?);
}
Ok(ColumnPath::new(res))
}
@@ -167,13 +168,13 @@ fn convert_logical_type(
match id {
0 => LogicalType::Integer {
bit_width: bit_width as i8,
- is_signed: is_signed != 0,
+ is_signed,
},
1 => LogicalType::String,
2 => LogicalType::Decimal { scale, precision },
3 => LogicalType::Date,
4 => LogicalType::Timestamp {
- is_adjusted_to_u_t_c: is_adjusted_utc != 0,
+ is_adjusted_to_u_t_c: is_adjusted_utc,
unit: convert_time_unit(time_unit),
},
5 => LogicalType::Enum,
diff --git a/native/jni-bridge/Cargo.toml b/native/jni-bridge/Cargo.toml
index 0c5082566..a0ef4a73c 100644
--- a/native/jni-bridge/Cargo.toml
+++ b/native/jni-bridge/Cargo.toml
@@ -32,7 +32,7 @@ publish = false
arrow = { workspace = true }
parquet = { workspace = true }
datafusion = { workspace = true }
-jni = "0.21"
+jni = "0.22.4"
thiserror = { workspace = true }
regex = { workspace = true }
lazy_static = "1.4.0"
@@ -42,5 +42,5 @@ prost = "0.14.3"
datafusion-comet-common = { workspace = true }
[dev-dependencies]
-jni = { version = "0.21", features = ["invocation"] }
+jni = { version = "0.22.4", features = ["invocation"] }
assertables = "9"
diff --git a/native/jni-bridge/src/batch_iterator.rs
b/native/jni-bridge/src/batch_iterator.rs
index 2824bdbfc..65ca7e7d1 100644
--- a/native/jni-bridge/src/batch_iterator.rs
+++ b/native/jni-bridge/src/batch_iterator.rs
@@ -20,7 +20,8 @@ use jni::{
errors::Result as JniResult,
objects::{JClass, JMethodID},
signature::ReturnType,
- JNIEnv,
+ strings::JNIString,
+ Env,
};
/// A struct that holds all the JNI methods and fields for JVM
`CometBatchIterator` class.
@@ -40,25 +41,33 @@ pub struct CometBatchIterator<'a> {
impl<'a> CometBatchIterator<'a> {
pub const JVM_CLASS: &'static str = "org/apache/comet/CometBatchIterator";
- pub fn new(env: &mut JNIEnv<'a>) -> JniResult<CometBatchIterator<'a>> {
- let class = env.find_class(Self::JVM_CLASS)?;
+ pub fn new(env: &mut Env<'a>) -> JniResult<CometBatchIterator<'a>> {
+ let class = env.find_class(JNIString::new(Self::JVM_CLASS))?;
Ok(CometBatchIterator {
class,
- method_has_next: env.get_method_id(Self::JVM_CLASS, "hasNext",
"()I")?,
+ method_has_next: env.get_method_id(
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("hasNext"),
+ jni::jni_sig!("()I"),
+ )?,
method_has_next_ret: ReturnType::Primitive(Primitive::Int),
- method_next: env.get_method_id(Self::JVM_CLASS, "next",
"([J[J)I")?,
+ method_next: env.get_method_id(
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("next"),
+ jni::jni_sig!("([J[J)I"),
+ )?,
method_next_ret: ReturnType::Primitive(Primitive::Int),
method_has_selection_vectors: env.get_method_id(
- Self::JVM_CLASS,
- "hasSelectionVectors",
- "()Z",
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("hasSelectionVectors"),
+ jni::jni_sig!("()Z"),
)?,
method_has_selection_vectors_ret:
ReturnType::Primitive(Primitive::Boolean),
method_export_selection_indices: env.get_method_id(
- Self::JVM_CLASS,
- "exportSelectionIndices",
- "([J[J)I",
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("exportSelectionIndices"),
+ jni::jni_sig!("([J[J)I"),
)?,
method_export_selection_indices_ret:
ReturnType::Primitive(Primitive::Int),
})
diff --git a/native/jni-bridge/src/comet_exec.rs
b/native/jni-bridge/src/comet_exec.rs
index 1bcbbc4ad..a0b39d0ea 100644
--- a/native/jni-bridge/src/comet_exec.rs
+++ b/native/jni-bridge/src/comet_exec.rs
@@ -19,7 +19,8 @@ use jni::{
errors::Result as JniResult,
objects::{JClass, JStaticMethodID},
signature::{Primitive, ReturnType},
- JNIEnv,
+ strings::JNIString,
+ Env,
};
/// A struct that holds all the JNI methods and fields for JVM CometExec
object.
@@ -52,39 +53,75 @@ pub struct CometExec<'a> {
impl<'a> CometExec<'a> {
pub const JVM_CLASS: &'static str =
"org/apache/spark/sql/comet/CometScalarSubquery";
- pub fn new(env: &mut JNIEnv<'a>) -> JniResult<CometExec<'a>> {
- let class = env.find_class(Self::JVM_CLASS)?;
+ pub fn new(env: &mut Env<'a>) -> JniResult<CometExec<'a>> {
+ let class = env.find_class(JNIString::new(Self::JVM_CLASS))?;
Ok(CometExec {
- method_get_bool: env.get_static_method_id(Self::JVM_CLASS,
"getBoolean", "(JJ)Z")?,
+ method_get_bool: env.get_static_method_id(
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("getBoolean"),
+ jni::jni_sig!("(JJ)Z"),
+ )?,
method_get_bool_ret: ReturnType::Primitive(Primitive::Boolean),
- method_get_byte: env.get_static_method_id(Self::JVM_CLASS,
"getByte", "(JJ)B")?,
+ method_get_byte: env.get_static_method_id(
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("getByte"),
+ jni::jni_sig!("(JJ)B"),
+ )?,
method_get_byte_ret: ReturnType::Primitive(Primitive::Byte),
- method_get_short: env.get_static_method_id(Self::JVM_CLASS,
"getShort", "(JJ)S")?,
+ method_get_short: env.get_static_method_id(
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("getShort"),
+ jni::jni_sig!("(JJ)S"),
+ )?,
method_get_short_ret: ReturnType::Primitive(Primitive::Short),
- method_get_int: env.get_static_method_id(Self::JVM_CLASS,
"getInt", "(JJ)I")?,
+ method_get_int: env.get_static_method_id(
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("getInt"),
+ jni::jni_sig!("(JJ)I"),
+ )?,
method_get_int_ret: ReturnType::Primitive(Primitive::Int),
- method_get_long: env.get_static_method_id(Self::JVM_CLASS,
"getLong", "(JJ)J")?,
+ method_get_long: env.get_static_method_id(
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("getLong"),
+ jni::jni_sig!("(JJ)J"),
+ )?,
method_get_long_ret: ReturnType::Primitive(Primitive::Long),
- method_get_float: env.get_static_method_id(Self::JVM_CLASS,
"getFloat", "(JJ)F")?,
+ method_get_float: env.get_static_method_id(
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("getFloat"),
+ jni::jni_sig!("(JJ)F"),
+ )?,
method_get_float_ret: ReturnType::Primitive(Primitive::Float),
- method_get_double: env.get_static_method_id(Self::JVM_CLASS,
"getDouble", "(JJ)D")?,
+ method_get_double: env.get_static_method_id(
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("getDouble"),
+ jni::jni_sig!("(JJ)D"),
+ )?,
method_get_double_ret: ReturnType::Primitive(Primitive::Double),
method_get_decimal: env.get_static_method_id(
- Self::JVM_CLASS,
- "getDecimal",
- "(JJ)[B",
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("getDecimal"),
+ jni::jni_sig!("(JJ)[B"),
)?,
method_get_decimal_ret: ReturnType::Array,
method_get_string: env.get_static_method_id(
- Self::JVM_CLASS,
- "getString",
- "(JJ)Ljava/lang/String;",
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("getString"),
+ jni::jni_sig!("(JJ)Ljava/lang/String;"),
)?,
method_get_string_ret: ReturnType::Object,
- method_get_binary: env.get_static_method_id(Self::JVM_CLASS,
"getBinary", "(JJ)[B")?,
+ method_get_binary: env.get_static_method_id(
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("getBinary"),
+ jni::jni_sig!("(JJ)[B"),
+ )?,
method_get_binary_ret: ReturnType::Array,
- method_is_null: env.get_static_method_id(Self::JVM_CLASS,
"isNull", "(JJ)Z")?,
+ method_is_null: env.get_static_method_id(
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("isNull"),
+ jni::jni_sig!("(JJ)Z"),
+ )?,
method_is_null_ret: ReturnType::Primitive(Primitive::Boolean),
class,
})
diff --git a/native/jni-bridge/src/comet_metric_node.rs
b/native/jni-bridge/src/comet_metric_node.rs
index f1f025584..4cc8ae163 100644
--- a/native/jni-bridge/src/comet_metric_node.rs
+++ b/native/jni-bridge/src/comet_metric_node.rs
@@ -18,8 +18,10 @@
use jni::{
errors::Result as JniResult,
objects::{JClass, JMethodID},
+ signature::RuntimeMethodSignature,
signature::{Primitive, ReturnType},
- JNIEnv,
+ strings::JNIString,
+ Env,
};
/// A struct that holds all the JNI methods and fields for JVM CometMetricNode
class.
@@ -37,22 +39,28 @@ pub struct CometMetricNode<'a> {
impl<'a> CometMetricNode<'a> {
pub const JVM_CLASS: &'static str =
"org/apache/spark/sql/comet/CometMetricNode";
- pub fn new(env: &mut JNIEnv<'a>) -> JniResult<CometMetricNode<'a>> {
- let class = env.find_class(Self::JVM_CLASS)?;
+ pub fn new(env: &mut Env<'a>) -> JniResult<CometMetricNode<'a>> {
+ let class = env.find_class(JNIString::new(Self::JVM_CLASS))?;
+ let get_child_node_sig =
+ RuntimeMethodSignature::from_str(format!("(I)L{};",
Self::JVM_CLASS))?;
Ok(CometMetricNode {
method_get_child_node: env.get_method_id(
- Self::JVM_CLASS,
- "getChildNode",
- format!("(I)L{:};", Self::JVM_CLASS).as_str(),
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("getChildNode"),
+ get_child_node_sig.method_signature(),
)?,
method_get_child_node_ret: ReturnType::Object,
- method_set: env.get_method_id(Self::JVM_CLASS, "set",
"(Ljava/lang/String;J)V")?,
+ method_set: env.get_method_id(
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("set"),
+ jni::jni_sig!("(Ljava/lang/String;J)V"),
+ )?,
method_set_ret: ReturnType::Primitive(Primitive::Void),
method_set_all_from_bytes: env.get_method_id(
- Self::JVM_CLASS,
- "set_all_from_bytes",
- "([B)V",
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("set_all_from_bytes"),
+ jni::jni_sig!("([B)V"),
)?,
method_set_all_from_bytes_ret:
ReturnType::Primitive(Primitive::Void),
class,
diff --git a/native/jni-bridge/src/comet_task_memory_manager.rs
b/native/jni-bridge/src/comet_task_memory_manager.rs
index 22c3332c6..cec0b7051 100644
--- a/native/jni-bridge/src/comet_task_memory_manager.rs
+++ b/native/jni-bridge/src/comet_task_memory_manager.rs
@@ -19,7 +19,8 @@ use jni::{
errors::Result as JniResult,
objects::{JClass, JMethodID},
signature::{Primitive, ReturnType},
- JNIEnv,
+ strings::JNIString,
+ Env,
};
/// A wrapper which delegate acquire/release memory calls to the
@@ -38,20 +39,20 @@ pub struct CometTaskMemoryManager<'a> {
impl<'a> CometTaskMemoryManager<'a> {
pub const JVM_CLASS: &'static str =
"org/apache/spark/CometTaskMemoryManager";
- pub fn new(env: &mut JNIEnv<'a>) -> JniResult<CometTaskMemoryManager<'a>> {
- let class = env.find_class(Self::JVM_CLASS)?;
+ pub fn new(env: &mut Env<'a>) -> JniResult<CometTaskMemoryManager<'a>> {
+ let class = env.find_class(JNIString::new(Self::JVM_CLASS))?;
let result = CometTaskMemoryManager {
class,
method_acquire_memory: env.get_method_id(
- Self::JVM_CLASS,
- "acquireMemory",
- "(J)J".to_string(),
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("acquireMemory"),
+ jni::jni_sig!("(J)J"),
)?,
method_release_memory: env.get_method_id(
- Self::JVM_CLASS,
- "releaseMemory",
- "(J)V".to_string(),
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("releaseMemory"),
+ jni::jni_sig!("(J)V"),
)?,
method_acquire_memory_ret: ReturnType::Primitive(Primitive::Long),
method_release_memory_ret: ReturnType::Primitive(Primitive::Void),
diff --git a/native/jni-bridge/src/errors.rs b/native/jni-bridge/src/errors.rs
index 640201f6f..4fbcb28e2 100644
--- a/native/jni-bridge/src/errors.rs
+++ b/native/jni-bridge/src/errors.rs
@@ -27,7 +27,7 @@ use std::{
any::Any,
convert,
fmt::Write,
- panic::{catch_unwind, UnwindSafe},
+ panic::UnwindSafe,
result, str,
str::Utf8Error,
sync::{Arc, Mutex},
@@ -38,8 +38,8 @@ use std::{
// lifetime checker won't let us.
use jni::sys::{jboolean, jbyte, jchar, jdouble, jfloat, jint, jlong, jobject,
jshort};
-use jni::objects::{GlobalRef, JThrowable};
-use jni::JNIEnv;
+use jni::objects::{Global, JThrowable};
+use jni::{strings::JNIString, Env, EnvUnowned, Outcome};
use lazy_static::lazy_static;
use parquet::errors::ParquetError;
use thiserror::Error;
@@ -72,7 +72,7 @@ pub enum ExecutionError {
JavaException {
class: String,
msg: String,
- throwable: GlobalRef,
+ throwable: Global<JThrowable<'static>>,
},
}
@@ -167,7 +167,7 @@ pub enum CometError {
JavaException {
class: String,
msg: String,
- throwable: GlobalRef,
+ throwable: Global<JThrowable<'static>>,
},
}
@@ -388,7 +388,7 @@ pub trait JNIDefault {
impl JNIDefault for jboolean {
fn default() -> jboolean {
- 0
+ false
}
}
@@ -449,7 +449,7 @@ impl JNIDefault for () {
// `RuntimeException` back to the calling Java. Since a return result is
required, use `JNIDefault`
// to create a reasonable result. This returned default value will be ignored
due to the exception.
pub fn unwrap_or_throw_default<T: JNIDefault>(
- env: &mut JNIEnv,
+ env: &mut Env,
result: std::result::Result<T, CometError>,
) -> T {
match result {
@@ -465,16 +465,20 @@ pub fn unwrap_or_throw_default<T: JNIDefault>(
}
}
-fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace:
Option<String>) {
+fn throw_exception(env: &mut Env, error: &CometError, backtrace:
Option<String>) {
// If there isn't already an exception?
- if env.exception_check().is_ok() {
+ if !env.exception_check() {
// ... then throw new exception
- match error {
+ // Note: in jni 0.22.x, throw/throw_new return Err(JavaException) on
success
+ // (to signal the pending exception to Rust callers via `?`). We
discard the
+ // result here because we're in an error-handling path and just need
the
+ // exception to be pending in the JVM.
+ let _ = match error {
CometError::JavaException {
class: _,
msg: _,
throwable,
- } => env.throw(<&JThrowable>::from(throwable.as_obj())),
+ } => env.throw(throwable),
CometError::Execution {
source:
ExecutionError::JavaException {
@@ -482,7 +486,7 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError,
backtrace: Option<Strin
msg: _,
throwable,
},
- } => env.throw(<&JThrowable>::from(throwable.as_obj())),
+ } => env.throw(throwable),
// Handle DataFusion errors containing SparkError or
SparkErrorWithContext
CometError::DataFusion {
msg: _,
@@ -491,14 +495,14 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError,
backtrace: Option<Strin
if let Some(spark_error_with_ctx) =
e.downcast_ref::<SparkErrorWithContext>() {
let json_message = spark_error_with_ctx.to_json();
env.throw_new(
-
"org/apache/comet/exceptions/CometQueryExecutionException",
- json_message,
+
jni::jni_str!("org/apache/comet/exceptions/CometQueryExecutionException"),
+ JNIString::new(json_message),
)
} else if let Some(spark_error) =
e.downcast_ref::<SparkError>() {
let json_message = spark_error.to_json();
env.throw_new(
-
"org/apache/comet/exceptions/CometQueryExecutionException",
- json_message,
+
jni::jni_str!("org/apache/comet/exceptions/CometQueryExecutionException"),
+ JNIString::new(json_message),
)
} else {
// Check for file-not-found errors from object store
@@ -513,10 +517,15 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError,
backtrace: Option<Strin
let exception = error.to_exception();
match backtrace {
Some(backtrace_string) => env.throw_new(
- exception.class,
- to_stacktrace_string(exception.msg,
backtrace_string).unwrap(),
+ JNIString::new(exception.class),
+ JNIString::new(
+ to_stacktrace_string(exception.msg,
backtrace_string).unwrap(),
+ ),
+ ),
+ _ => env.throw_new(
+ JNIString::new(exception.class),
+ JNIString::new(exception.msg),
),
- _ => env.throw_new(exception.class, exception.msg),
}
}
}
@@ -537,30 +546,31 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError,
backtrace: Option<Strin
let exception = error.to_exception();
match backtrace {
Some(backtrace_string) => env.throw_new(
- exception.class,
- to_stacktrace_string(exception.msg,
backtrace_string).unwrap(),
+ JNIString::new(exception.class),
+ JNIString::new(
+ to_stacktrace_string(exception.msg,
backtrace_string).unwrap(),
+ ),
+ ),
+ _ => env.throw_new(
+ JNIString::new(exception.class),
+ JNIString::new(exception.msg),
),
- _ => env.throw_new(exception.class, exception.msg),
}
}
}
- }
- .expect("Thrown exception")
+ };
}
}
/// Throws a CometQueryExecutionException with JSON-encoded SparkError
-fn throw_spark_error_as_json(
- env: &mut JNIEnv,
- spark_error: &SparkError,
-) -> jni::errors::Result<()> {
+fn throw_spark_error_as_json(env: &mut Env, spark_error: &SparkError) ->
jni::errors::Result<()> {
// Serialize error to JSON
let json_message = spark_error.to_json();
// Throw CometQueryExecutionException with JSON message
env.throw_new(
- "org/apache/comet/exceptions/CometQueryExecutionException",
- json_message,
+
jni::jni_str!("org/apache/comet/exceptions/CometQueryExecutionException"),
+ JNIString::new(json_message),
)
}
@@ -659,33 +669,26 @@ fn to_stacktrace_string(msg: String, backtrace_string:
String) -> Result<String,
Ok(res)
}
-fn flatten<T, E>(result: Result<Result<T, E>, E>) -> Result<T, E> {
- result.and_then(convert::identity)
-}
-
-// Implements "currying" from `FnOnce(T) -> R` to `FnOnce() -> R`, given
-// an instance of T. Curring is not supported in Rust so we have to use this
-// custom function to achieve something similar here.
-fn curry<'a, T: 'a, F, R>(f: F, t: T) -> impl FnOnce() -> R + 'a
-where
- F: FnOnce(T) -> R + 'a,
-{
- || f(t)
-}
-
// It is currently undefined behavior to unwind from Rust code into foreign
code, so we can wrap
// our JNI functions and turn these panics into a `RuntimeException`.
-pub fn try_unwrap_or_throw<T, F>(env: &JNIEnv, f: F) -> T
+pub fn try_unwrap_or_throw<T, F>(env: &EnvUnowned, f: F) -> T
where
T: JNIDefault,
- F: FnOnce(JNIEnv) -> Result<T, CometError> + UnwindSafe,
+ F: FnOnce(&mut Env) -> Result<T, CometError> + UnwindSafe,
{
- let mut env1 = unsafe { JNIEnv::from_raw(env.get_raw()).unwrap() };
- let env2 = unsafe { JNIEnv::from_raw(env.get_raw()).unwrap() };
- unwrap_or_throw_default(
- &mut env1,
- flatten(catch_unwind(curry(f, env2)).map_err(CometError::from)),
- )
+ let raw = env.as_raw();
+ let mut env1 = unsafe { EnvUnowned::from_raw(raw) };
+ match env1.with_env(f).into_outcome() {
+ Outcome::Ok(value) => value,
+ Outcome::Err(err) => {
+ let mut guard = unsafe { jni::AttachGuard::from_unowned(raw) };
+ unwrap_or_throw_default(guard.borrow_env_mut(), Err(err))
+ }
+ Outcome::Panic(payload) => {
+ let mut guard = unsafe { jni::AttachGuard::from_unowned(raw) };
+ unwrap_or_throw_default(guard.borrow_env_mut(),
Err(CometError::from(payload)))
+ }
+ }
}
#[cfg(test)]
@@ -702,7 +705,7 @@ mod tests {
use jni::{
objects::{JClass, JIntArray, JString, JThrowable},
sys::{jintArray, jstring},
- AttachGuard, InitArgsBuilder, JNIEnv, JNIVersion, JavaVM,
+ EnvUnowned, InitArgsBuilder, JNIVersion, JavaVM,
};
use assertables::assert_starts_with;
@@ -728,7 +731,7 @@ mod tests {
// Build the VM properties
let jvm_args = InitArgsBuilder::new()
// Pass the JNI API version (default is 8)
- .version(JNIVersion::V8)
+ .version(JNIVersion::V1_8)
// You can additionally pass any JVM options (standard, like a
system property,
// or VM-specific).
// Here we enable some extra JNI checks useful during
development
@@ -751,25 +754,24 @@ mod tests {
}
}
- fn attach_current_thread() -> AttachGuard<'static> {
- jvm().attach_current_thread().expect("Unable to attach JVM")
- }
-
#[test]
#[cfg_attr(miri, ignore)] // miri can't call foreign function `dlopen`
pub fn error_from_panic() {
- let _guard = attach_current_thread();
- let mut env = jvm().get_env().unwrap();
-
- try_unwrap_or_throw(&env, |_| -> CometResult<()> {
- panic!("oops!");
- });
-
- assert_pending_java_exception_detailed(
- &mut env,
- Some("java/lang/RuntimeException"),
- Some("oops!"),
- );
+ jvm()
+ .attach_current_thread(|env| -> jni::errors::Result<()> {
+ let env_unowned = unsafe { EnvUnowned::from_raw(env.get_raw())
};
+ try_unwrap_or_throw(&env_unowned, |_| -> CometResult<()> {
+ panic!("oops!");
+ });
+
+ assert_pending_java_exception_detailed(
+ env,
+ Some("java/lang/RuntimeException"),
+ Some("oops!"),
+ );
+ Ok(())
+ })
+ .unwrap();
}
// Verify that functions that return an object are handled correctly.
This is basically
@@ -777,17 +779,21 @@ mod tests {
#[test]
#[cfg_attr(miri, ignore)] // miri can't call foreign function `dlopen`
pub fn object_result() {
- let _guard = attach_current_thread();
- let mut env = jvm().get_env().unwrap();
-
- let clazz = env.find_class("java/lang/Object").unwrap();
- let input = env.new_string("World".to_string()).unwrap();
-
- let actual = Java_Errors_hello(&env, clazz, input);
- let actual_s = unsafe { JString::from_raw(actual) };
-
- let actual_string =
String::from(env.get_string(&actual_s).unwrap().to_str().unwrap());
- assert_eq!("Hello, World!", actual_string);
+ jvm()
+ .attach_current_thread(|env| -> jni::errors::Result<()> {
+ let clazz =
env.find_class(jni::jni_str!("java/lang/Object")).unwrap();
+ let input = env.new_string("World").unwrap();
+
+ let actual = unsafe {
+ Java_Errors_hello(&EnvUnowned::from_raw(env.get_raw()),
clazz, input)
+ };
+ let actual_s = unsafe { JString::from_raw(env, actual) };
+
+ let actual_string = actual_s.try_to_string(env).unwrap();
+ assert_eq!("Hello, World!", actual_string);
+ Ok(())
+ })
+ .unwrap();
}
// Verify that functions that return an native time are handled correctly.
This is basically
@@ -795,16 +801,19 @@ mod tests {
#[test]
#[cfg_attr(miri, ignore)] // miri can't call foreign function `dlopen`
pub fn jlong_result() {
- let _guard = attach_current_thread();
- let mut env = jvm().get_env().unwrap();
-
- // Class java.lang.object is just a stand-in
- let class = env.find_class("java/lang/Object").unwrap();
- let a: jlong = 6;
- let b: jlong = 3;
- let actual = Java_Errors_div(&env, class, a, b);
-
- assert_eq!(2, actual);
+ jvm()
+ .attach_current_thread(|env| -> jni::errors::Result<()> {
+ // Class java.lang.object is just a stand-in
+ let class =
env.find_class(jni::jni_str!("java/lang/Object")).unwrap();
+ let a: jlong = 6;
+ let b: jlong = 3;
+ let actual =
+ unsafe {
Java_Errors_div(&EnvUnowned::from_raw(env.get_raw()), class, a, b) };
+
+ assert_eq!(2, actual);
+ Ok(())
+ })
+ .unwrap();
}
// Verify that functions that return an array can handle throwing
exceptions. The test
@@ -812,20 +821,23 @@ mod tests {
#[test]
#[cfg_attr(miri, ignore)] // miri can't call foreign function `dlopen`
pub fn jlong_panic_exception() {
- let _guard = attach_current_thread();
- let mut env = jvm().get_env().unwrap();
-
- // Class java.lang.object is just a stand-in
- let class = env.find_class("java/lang/Object").unwrap();
- let a: jlong = 6;
- let b: jlong = 0;
- let _actual = Java_Errors_div(&env, class, a, b);
-
- assert_pending_java_exception_detailed(
- &mut env,
- Some("java/lang/RuntimeException"),
- Some("attempt to divide by zero"),
- );
+ jvm()
+ .attach_current_thread(|env| -> jni::errors::Result<()> {
+ // Class java.lang.object is just a stand-in
+ let class =
env.find_class(jni::jni_str!("java/lang/Object")).unwrap();
+ let a: jlong = 6;
+ let b: jlong = 0;
+ let _actual =
+ unsafe {
Java_Errors_div(&EnvUnowned::from_raw(env.get_raw()), class, a, b) };
+
+ assert_pending_java_exception_detailed(
+ env,
+ Some("java/lang/RuntimeException"),
+ Some("attempt to divide by zero"),
+ );
+ Ok(())
+ })
+ .unwrap();
}
// Verify that functions that return an native time are handled correctly.
This is basically
@@ -833,16 +845,20 @@ mod tests {
#[test]
#[cfg_attr(miri, ignore)] // miri can't call foreign function `dlopen`
pub fn jlong_result_ok() {
- let _guard = attach_current_thread();
- let mut env = jvm().get_env().unwrap();
-
- // Class java.lang.object is just a stand-in
- let class = env.find_class("java/lang/Object").unwrap();
- let a: JString = env.new_string("9".to_string()).unwrap();
- let b: JString = env.new_string("3".to_string()).unwrap();
- let actual = Java_Errors_div_with_parse(&env, class, a, b);
-
- assert_eq!(3, actual);
+ jvm()
+ .attach_current_thread(|env| -> jni::errors::Result<()> {
+ // Class java.lang.object is just a stand-in
+ let class =
env.find_class(jni::jni_str!("java/lang/Object")).unwrap();
+ let a: JString = env.new_string("9").unwrap();
+ let b: JString = env.new_string("3").unwrap();
+ let actual = unsafe {
+
Java_Errors_div_with_parse(&EnvUnowned::from_raw(env.get_raw()), class, a, b)
+ };
+
+ assert_eq!(3, actual);
+ Ok(())
+ })
+ .unwrap();
}
// Verify that functions that return an native time are handled correctly.
This is basically
@@ -850,20 +866,24 @@ mod tests {
#[test]
#[cfg_attr(miri, ignore)] // miri can't call foreign function `dlopen`
pub fn jlong_result_err() {
- let _guard = attach_current_thread();
- let mut env = jvm().get_env().unwrap();
-
- // Class java.lang.object is just a stand-in
- let class = env.find_class("java/lang/Object").unwrap();
- let a: JString = env.new_string("NaN".to_string()).unwrap();
- let b: JString = env.new_string("3".to_string()).unwrap();
- let _actual = Java_Errors_div_with_parse(&env, class, a, b);
-
- assert_pending_java_exception_detailed(
- &mut env,
- Some("java/lang/NumberFormatException"),
- Some("invalid digit found in string"),
- );
+ jvm()
+ .attach_current_thread(|env| -> jni::errors::Result<()> {
+ // Class java.lang.object is just a stand-in
+ let class =
env.find_class(jni::jni_str!("java/lang/Object")).unwrap();
+ let a: JString = env.new_string("NaN").unwrap();
+ let b: JString = env.new_string("3").unwrap();
+ let _actual = unsafe {
+
Java_Errors_div_with_parse(&EnvUnowned::from_raw(env.get_raw()), class, a, b)
+ };
+
+ assert_pending_java_exception_detailed(
+ env,
+ Some("java/lang/NumberFormatException"),
+ Some("invalid digit found in string"),
+ );
+ Ok(())
+ })
+ .unwrap();
}
// Verify that functions that return an array are handled correctly. This
is basically
@@ -871,20 +891,24 @@ mod tests {
#[test]
#[cfg_attr(miri, ignore)] // miri can't call foreign function `dlopen`
pub fn jint_array_result() {
- let _guard = attach_current_thread();
- let mut env = jvm().get_env().unwrap();
-
- // Class java.lang.object is just a stand-in
- let class = env.find_class("java/lang/Object").unwrap();
- let buf = [2, 4, 6];
- let input = env.new_int_array(3).unwrap();
- env.set_int_array_region(&input, 0, &buf).unwrap();
- let actual = Java_Errors_array_div(&env, class, &input, 2);
- let actual_s = unsafe { JIntArray::from_raw(actual) };
-
- let mut buf: [i32; 3] = [0; 3];
- env.get_int_array_region(&actual_s, 0, &mut buf).unwrap();
- assert_eq!([1, 2, 3], buf);
+ jvm()
+ .attach_current_thread(|env| -> jni::errors::Result<()> {
+ // Class java.lang.object is just a stand-in
+ let class =
env.find_class(jni::jni_str!("java/lang/Object")).unwrap();
+ let buf = [2, 4, 6];
+ let input = env.new_int_array(3).unwrap();
+ input.set_region(env, 0, &buf).unwrap();
+ let actual = unsafe {
+
Java_Errors_array_div(&EnvUnowned::from_raw(env.get_raw()), class, &input, 2)
+ };
+ let actual_s = unsafe { JIntArray::from_raw(env, actual) };
+
+ let mut buf: [i32; 3] = [0; 3];
+ actual_s.get_region(env, 0, &mut buf).unwrap();
+ assert_eq!([1, 2, 3], buf);
+ Ok(())
+ })
+ .unwrap();
}
// Verify that functions that return an array can handle throwing
exceptions. The test
@@ -892,21 +916,25 @@ mod tests {
#[test]
#[cfg_attr(miri, ignore)] // miri can't call foreign function `dlopen`
pub fn jint_array_panic_exception() {
- let _guard = attach_current_thread();
- let mut env = jvm().get_env().unwrap();
-
- // Class java.lang.object is just a stand-in
- let class = env.find_class("java/lang/Object").unwrap();
- let buf = [2, 4, 6];
- let input = env.new_int_array(3).unwrap();
- env.set_int_array_region(&input, 0, &buf).unwrap();
- let _actual = Java_Errors_array_div(&env, class, &input, 0);
-
- assert_pending_java_exception_detailed(
- &mut env,
- Some("java/lang/RuntimeException"),
- Some("attempt to divide by zero"),
- );
+ jvm()
+ .attach_current_thread(|env| -> jni::errors::Result<()> {
+ // Class java.lang.object is just a stand-in
+ let class =
env.find_class(jni::jni_str!("java/lang/Object")).unwrap();
+ let buf = [2, 4, 6];
+ let input = env.new_int_array(3).unwrap();
+ input.set_region(env, 0, &buf).unwrap();
+ let _actual = unsafe {
+
Java_Errors_array_div(&EnvUnowned::from_raw(env.get_raw()), class, &input, 0)
+ };
+
+ assert_pending_java_exception_detailed(
+ env,
+ Some("java/lang/RuntimeException"),
+ Some("attempt to divide by zero"),
+ );
+ Ok(())
+ })
+ .unwrap();
}
/// Test that conversion of a serialized backtrace to an equivalent
stacktrace message.
@@ -946,15 +974,12 @@ mod tests {
// * throwing an exception from `.expect()`
#[no_mangle]
pub extern "system" fn Java_Errors_hello(
- e: &JNIEnv,
+ e: &EnvUnowned,
_class: JClass,
input: JString,
) -> jstring {
- try_unwrap_or_throw(e, |mut env| {
- let input: String = env
- .get_string(&input)
- .expect("Couldn't get java string!")
- .into();
+ try_unwrap_or_throw(e, |env| {
+ let input: String = input.try_to_string(env).expect("Couldn't get
java string!");
let output = env
.new_string(format!("Hello, {input}!"))
@@ -969,7 +994,7 @@ mod tests {
// * throwing an exception when dividing by zero
#[no_mangle]
pub extern "system" fn Java_Errors_div(
- env: &JNIEnv,
+ env: &EnvUnowned,
_class: JClass,
a: jlong,
b: jlong,
@@ -979,14 +1004,14 @@ mod tests {
#[no_mangle]
pub extern "system" fn Java_Errors_div_with_parse(
- e: &JNIEnv,
+ e: &EnvUnowned,
_class: JClass,
a: JString,
b: JString,
) -> jlong {
- try_unwrap_or_throw(e, |mut env| {
- let a_value: i64 = env.get_string(&a)?.to_str()?.parse()?;
- let b_value: i64 = env.get_string(&b)?.to_str()?.parse()?;
+ try_unwrap_or_throw(e, |env| {
+ let a_value: i64 = a.try_to_string(env)?.parse()?;
+ let b_value: i64 = b.try_to_string(env)?.parse()?;
Ok(a_value / b_value)
})
}
@@ -996,19 +1021,19 @@ mod tests {
// * throwing an exception when dividing by zero
#[no_mangle]
pub extern "system" fn Java_Errors_array_div(
- e: &JNIEnv,
+ e: &EnvUnowned,
_class: JClass,
input: &JIntArray,
divisor: jint,
) -> jintArray {
try_unwrap_or_throw(e, |env| {
let mut input_buf: [jint; 3] = [0; 3];
- env.get_int_array_region(input, 0, &mut input_buf)?;
+ input.get_region(env, 0, &mut input_buf)?;
let buf = input_buf.map(|v| -> jint { v / divisor });
let result = env.new_int_array(3)?;
- env.set_int_array_region(&result, 0, &buf)?;
+ result.set_region(env, 0, &buf)?;
Ok(result.into_raw())
})
}
@@ -1016,13 +1041,13 @@ mod tests {
// Helper method that asserts there is a pending Java exception which is
an `instance_of`
// `expected_type` with a message matching `expected_message` and clears
it if any.
fn assert_pending_java_exception_detailed(
- env: &mut JNIEnv,
+ env: &mut Env,
expected_type: Option<&str>,
expected_message: Option<&str>,
) {
- assert!(env.exception_check().unwrap());
+ assert!(env.exception_check());
let exception = env.exception_occurred().expect("Unable to get
exception");
- env.exception_clear().unwrap();
+ env.exception_clear();
if let Some(expected_type) = expected_type {
assert_exception_type(env, &exception, expected_type);
@@ -1034,29 +1059,42 @@ mod tests {
}
// Asserts that exception is an `instance_of` `expected_type` type.
- fn assert_exception_type(env: &mut JNIEnv, exception: &JThrowable,
expected_type: &str) {
- if !env.is_instance_of(exception, expected_type).unwrap() {
+ fn assert_exception_type(env: &mut Env, exception: &JThrowable,
expected_type: &str) {
+ if !env
+ .is_instance_of(exception,
jni::strings::JNIString::new(expected_type))
+ .unwrap()
+ {
let class: JClass = env.get_object_class(exception).unwrap();
let name = env
- .call_method(class, "getName", "()Ljava/lang/String;", &[])
+ .call_method(
+ class,
+ jni::jni_str!("getName"),
+ jni::jni_sig!("()Ljava/lang/String;"),
+ &[],
+ )
.unwrap()
.l()
.unwrap();
- let name_string = name.into();
- let class_name: String =
env.get_string(&name_string).unwrap().into();
+ let name_string = unsafe { JString::from_raw(env, name.into_raw())
};
+ let class_name: String = name_string.try_to_string(env).unwrap();
assert_eq!(class_name.replace('.', "/"), expected_type);
};
}
// Asserts that exception's message matches `expected_message`.
- fn assert_exception_message(env: &mut JNIEnv, exception: JThrowable,
expected_message: &str) {
+ fn assert_exception_message(env: &mut Env, exception: JThrowable,
expected_message: &str) {
let message = env
- .call_method(exception, "getMessage", "()Ljava/lang/String;", &[])
+ .call_method(
+ exception,
+ jni::jni_str!("getMessage"),
+ jni::jni_sig!("()Ljava/lang/String;"),
+ &[],
+ )
.unwrap()
.l()
.unwrap();
- let message_string = message.into();
- let msg_rust: String = env.get_string(&message_string).unwrap().into();
+ let message_string = unsafe { JString::from_raw(env,
message.into_raw()) };
+ let msg_rust: String = message_string.try_to_string(env).unwrap();
println!("{msg_rust}");
// Since panics result in multi-line messages which include the
backtrace, just use the
// first line.
diff --git a/native/jni-bridge/src/lib.rs b/native/jni-bridge/src/lib.rs
index a2e25c3e2..5b0c0a4a5 100644
--- a/native/jni-bridge/src/lib.rs
+++ b/native/jni-bridge/src/lib.rs
@@ -24,9 +24,9 @@
use jni::objects::JClass;
use jni::{
errors::Error,
- objects::{JMethodID, JObject, JString, JThrowable, JValueGen, JValueOwned},
+ objects::{JMethodID, JObject, JString, JThrowable, JValueOwned},
signature::ReturnType,
- AttachGuard, JNIEnv, JavaVM,
+ Env, JavaVM,
};
use once_cell::sync::OnceCell;
@@ -127,15 +127,15 @@ macro_rules! jni_new_global_ref {
/// Wrapper for JString. Because we cannot implement `TryFrom` trait for
`JString` as they
/// are defined in different crates.
pub struct StringWrapper<'a> {
- value: JString<'a>,
+ value: JObject<'a>,
}
impl<'a> StringWrapper<'a> {
- pub fn new(value: JString<'a>) -> StringWrapper<'a> {
+ pub fn new(value: JObject<'a>) -> StringWrapper<'a> {
Self { value }
}
- pub fn get(&self) -> &JString<'_> {
+ pub fn get(&self) -> &JObject<'_> {
&self.value
}
}
@@ -159,7 +159,7 @@ impl<'a> TryFrom<JValueOwned<'a>> for StringWrapper<'a> {
fn try_from(value: JValueOwned<'a>) -> Result<StringWrapper<'a>, Error> {
match value {
- JValueGen::Object(b) => Ok(StringWrapper::new(JString::from(b))),
+ JValueOwned::Object(b) => Ok(StringWrapper::new(b)),
_ => Err(Error::WrongJValueType("object", value.type_name())),
}
}
@@ -170,7 +170,7 @@ impl<'a> TryFrom<JValueOwned<'a>> for BinaryWrapper<'a> {
fn try_from(value: JValueOwned<'a>) -> Result<BinaryWrapper<'a>, Error> {
match value {
- JValueGen::Object(b) => Ok(BinaryWrapper::new(b)),
+ JValueOwned::Object(b) => Ok(BinaryWrapper::new(b)),
_ => Err(Error::WrongJValueType("object", value.type_name())),
}
}
@@ -228,29 +228,47 @@ static JVM_CLASSES: OnceCell<JVMClasses> =
OnceCell::new();
impl JVMClasses<'_> {
/// Creates a new JVMClasses struct.
- pub fn init(env: &mut JNIEnv) {
+ pub fn init(env: &mut Env) {
JVM_CLASSES.get_or_init(|| {
- // A hack to make the `JNIEnv` static. It is not safe but we don't
really use the
- // `JNIEnv` except for creating the global references of the
classes.
- let env = unsafe { std::mem::transmute::<&mut JNIEnv, &'static mut
JNIEnv>(env) };
+ // A hack to make the `Env` static. It is not safe but we don't
really use the
+ // `Env` except for creating the global references of the classes.
+ let env = unsafe { std::mem::transmute::<&mut Env, &'static mut
Env>(env) };
- let java_lang_object = env.find_class("java/lang/Object").unwrap();
+ let java_lang_object =
env.find_class(jni::jni_str!("java/lang/Object")).unwrap();
let object_get_class_method = env
- .get_method_id(&java_lang_object, "getClass",
"()Ljava/lang/Class;")
+ .get_method_id(
+ &java_lang_object,
+ jni::jni_str!("getClass"),
+ jni::jni_sig!("()Ljava/lang/Class;"),
+ )
.unwrap();
- let java_lang_class = env.find_class("java/lang/Class").unwrap();
+ let java_lang_class =
env.find_class(jni::jni_str!("java/lang/Class")).unwrap();
let class_get_name_method = env
- .get_method_id(&java_lang_class, "getName",
"()Ljava/lang/String;")
+ .get_method_id(
+ &java_lang_class,
+ jni::jni_str!("getName"),
+ jni::jni_sig!("()Ljava/lang/String;"),
+ )
.unwrap();
- let java_lang_throwable =
env.find_class("java/lang/Throwable").unwrap();
+ let java_lang_throwable = env
+ .find_class(jni::jni_str!("java/lang/Throwable"))
+ .unwrap();
let throwable_get_message_method = env
- .get_method_id(&java_lang_throwable, "getMessage",
"()Ljava/lang/String;")
+ .get_method_id(
+ &java_lang_throwable,
+ jni::jni_str!("getMessage"),
+ jni::jni_sig!("()Ljava/lang/String;"),
+ )
.unwrap();
let throwable_get_cause_method = env
- .get_method_id(&java_lang_throwable, "getCause",
"()Ljava/lang/Throwable;")
+ .get_method_id(
+ &java_lang_throwable,
+ jni::jni_str!("getCause"),
+ jni::jni_sig!("()Ljava/lang/Throwable;"),
+ )
.unwrap();
// SAFETY: According to the documentation for `JMethodID`, it is
our
@@ -281,27 +299,34 @@ impl JVMClasses<'_> {
unsafe { JVM_CLASSES.get_unchecked() }
}
- /// Gets the JNIEnv for the current thread.
- pub fn get_env() -> CometResult<AttachGuard<'static>> {
+ /// Runs a closure with an attached JNI environment for the current thread.
+ pub fn with_env<T, E, F>(f: F) -> Result<T, E>
+ where
+ F: FnOnce(&mut Env) -> Result<T, E>,
+ E: From<CometError>,
+ {
debug_assert!(
JAVA_VM.get().is_some(),
- "JVMClasses::get_env: JAVA_VM not initialized"
+ "JVMClasses::with_env: JAVA_VM not initialized"
);
unsafe {
let java_vm = JAVA_VM.get_unchecked();
- java_vm.attach_current_thread().map_err(|e| {
- CometError::Internal(format!(
- "JVMClasses::get_env() failed to attach current thread:
{e}"
- ))
- })
+ let mut scope = jni::ScopeToken::default();
+ let mut guard = java_vm
+ .attach_current_thread_guard(Default::default, &mut scope)
+ .map_err(CometError::from)
+ .map_err(E::from)?;
+ f(guard.borrow_env_mut())
}
}
}
-pub fn check_exception(env: &mut JNIEnv) -> CometResult<Option<CometError>> {
- let result = if env.exception_check()? {
- let exception = env.exception_occurred()?;
- env.exception_clear()?;
+pub fn check_exception(env: &mut Env) -> CometResult<Option<CometError>> {
+ let result = if env.exception_check() {
+ let exception = env
+ .exception_occurred()
+ .expect("exception_check returned true without an exception");
+ env.exception_clear();
let exception_err = convert_exception(env, &exception)?;
Some(exception_err)
} else {
@@ -315,7 +340,7 @@ pub fn check_exception(env: &mut JNIEnv) ->
CometResult<Option<CometError>> {
/// 1. get the `Class` object of the input `throwable` via `Object#getClass`
method
/// 2. get the exception class name via calling `Class#getName` on the above
object
fn get_throwable_class_name(
- env: &mut JNIEnv,
+ env: &mut Env,
jvm_classes: &JVMClasses,
throwable: &JThrowable,
) -> CometResult<String> {
@@ -328,16 +353,17 @@ fn get_throwable_class_name(
&[],
)?
.l()?;
+ let class_obj = JClass::from_raw(env, class_obj.into_raw());
let class_name = env
.call_method_unchecked(
- class_obj,
+ &class_obj,
jvm_classes.class_get_name_method,
ReturnType::Object,
&[],
)?
- .l()?
- .into();
- let class_name_str = env.get_string(&class_name)?.into();
+ .l()?;
+ let class_name = JString::from_raw(env, class_name.into_raw());
+ let class_name_str = class_name.try_to_string(env)?;
Ok(class_name_str)
}
@@ -345,7 +371,7 @@ fn get_throwable_class_name(
/// Get the exception message via calling `Throwable#getMessage` on the
throwable object
fn get_throwable_message(
- env: &mut JNIEnv,
+ env: &mut Env,
jvm_classes: &JVMClasses,
throwable: &JThrowable,
) -> CometResult<String> {
@@ -357,10 +383,10 @@ fn get_throwable_message(
ReturnType::Object,
&[],
)?
- .l()?
- .into();
+ .l()
+ .map(|obj| JString::from_raw(env, obj.into_raw()))?;
let message_str = if !message.is_null() {
- env.get_string(&message)?.into()
+ message.try_to_string(env)?
} else {
String::from("null")
};
@@ -372,8 +398,8 @@ fn get_throwable_message(
ReturnType::Object,
&[],
)?
- .l()?
- .into();
+ .l()
+ .map(|obj| JThrowable::from_raw(env, obj.into_raw()))?;
if !cause.is_null() {
let cause_class_name = get_throwable_class_name(env, jvm_classes,
&cause)?;
@@ -391,7 +417,7 @@ fn get_throwable_message(
/// this converts it into a `CometError::JavaException` with the exception
class name
/// and exception message. This error can then be populated to the JVM side to
let
/// users know the cause of the native side error.
-pub fn convert_exception(env: &mut JNIEnv, throwable: &JThrowable) ->
CometResult<CometError> {
+pub fn convert_exception(env: &mut Env, throwable: &JThrowable) ->
CometResult<CometError> {
let cache = JVMClasses::get();
let exception_class_name_str = get_throwable_class_name(env, cache,
throwable)?;
let message_str = get_throwable_message(env, cache, throwable)?;
diff --git a/native/jni-bridge/src/shuffle_block_iterator.rs
b/native/jni-bridge/src/shuffle_block_iterator.rs
index c3bb5af5f..fb65bf725 100644
--- a/native/jni-bridge/src/shuffle_block_iterator.rs
+++ b/native/jni-bridge/src/shuffle_block_iterator.rs
@@ -20,7 +20,8 @@ use jni::{
errors::Result as JniResult,
objects::{JClass, JMethodID},
signature::ReturnType,
- JNIEnv,
+ strings::JNIString,
+ Env,
};
/// A struct that holds all the JNI methods and fields for JVM
`CometShuffleBlockIterator` class.
@@ -38,23 +39,27 @@ pub struct CometShuffleBlockIterator<'a> {
impl<'a> CometShuffleBlockIterator<'a> {
pub const JVM_CLASS: &'static str =
"org/apache/comet/CometShuffleBlockIterator";
- pub fn new(env: &mut JNIEnv<'a>) ->
JniResult<CometShuffleBlockIterator<'a>> {
- let class = env.find_class(Self::JVM_CLASS)?;
+ pub fn new(env: &mut Env<'a>) -> JniResult<CometShuffleBlockIterator<'a>> {
+ let class = env.find_class(JNIString::new(Self::JVM_CLASS))?;
Ok(CometShuffleBlockIterator {
class,
- method_has_next: env.get_method_id(Self::JVM_CLASS, "hasNext",
"()I")?,
+ method_has_next: env.get_method_id(
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("hasNext"),
+ jni::jni_sig!("()I"),
+ )?,
method_has_next_ret: ReturnType::Primitive(Primitive::Int),
method_get_buffer: env.get_method_id(
- Self::JVM_CLASS,
- "getBuffer",
- "()Ljava/nio/ByteBuffer;",
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("getBuffer"),
+ jni::jni_sig!("()Ljava/nio/ByteBuffer;"),
)?,
method_get_buffer_ret: ReturnType::Object,
method_get_current_block_length: env.get_method_id(
- Self::JVM_CLASS,
- "getCurrentBlockLength",
- "()I",
+ JNIString::new(Self::JVM_CLASS),
+ jni::jni_str!("getCurrentBlockLength"),
+ jni::jni_sig!("()I"),
)?,
method_get_current_block_length_ret:
ReturnType::Primitive(Primitive::Int),
})
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]