This is an automated email from the ASF dual-hosted git repository.
sunchao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 7018225 feat: Upgrade to `jni-rs` 0.21 (#50)
7018225 is described below
commit 701822599e6e14dd49d3dda481691ea5a38e498f
Author: Chao Sun <[email protected]>
AuthorDate: Tue Feb 20 10:30:42 2024 -0800
feat: Upgrade to `jni-rs` 0.21 (#50)
---
core/Cargo.lock | 104 +++++++++-
core/Cargo.toml | 4 +-
core/src/errors.rs | 169 ++++++++--------
.../execution/datafusion/expressions/subquery.rs | 204 +++++++++----------
core/src/execution/jni_api.rs | 219 +++++++++++----------
core/src/execution/metrics/utils.rs | 32 +--
core/src/jvm_bridge/comet_exec.rs | 102 +++++-----
core/src/jvm_bridge/comet_metric_node.rs | 22 +--
core/src/jvm_bridge/mod.rs | 24 +--
core/src/lib.rs | 8 +-
core/src/parquet/mod.rs | 168 +++++++++-------
core/src/parquet/util/jni.rs | 28 +--
core/src/parquet/util/mod.rs | 2 -
13 files changed, 602 insertions(+), 484 deletions(-)
diff --git a/core/Cargo.lock b/core/Cargo.lock
index 0585d7e..9c40b91 100644
--- a/core/Cargo.lock
+++ b/core/Cargo.lock
@@ -1087,7 +1087,7 @@ source =
"registry+https://github.com/rust-lang/crates.io-index"
checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245"
dependencies = [
"libc",
- "windows-sys",
+ "windows-sys 0.52.0",
]
[[package]]
@@ -1336,7 +1336,7 @@ version = "0.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5"
dependencies = [
- "windows-sys",
+ "windows-sys 0.52.0",
]
[[package]]
@@ -1436,7 +1436,7 @@ checksum =
"0bad00257d07be169d870ab665980b06cdb366d792ad690bf2e76876dc503455"
dependencies = [
"hermit-abi",
"rustix",
- "windows-sys",
+ "windows-sys 0.52.0",
]
[[package]]
@@ -1472,18 +1472,32 @@ version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
+[[package]]
+name = "java-locator"
+version = "0.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "90003f2fd9c52f212c21d8520f1128da0080bad6fff16b68fe6e7f2f0c3780c2"
+dependencies = [
+ "glob",
+ "lazy_static",
+]
+
[[package]]
name = "jni"
-version = "0.19.0"
+version = "0.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c6df18c2e3db7e453d3c6ac5b3e9d5182664d28788126d39b91f2d1e22b017ec"
+checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97"
dependencies = [
"cesu8",
+ "cfg-if",
"combine",
+ "java-locator",
"jni-sys",
+ "libloading",
"log",
"thiserror",
"walkdir",
+ "windows-sys 0.45.0",
]
[[package]]
@@ -1586,6 +1600,16 @@ version = "0.2.151"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4"
+[[package]]
+name = "libloading"
+version = "0.7.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f"
+dependencies = [
+ "cfg-if",
+ "winapi",
+]
+
[[package]]
name = "libm"
version = "0.2.8"
@@ -2319,7 +2343,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys",
- "windows-sys",
+ "windows-sys 0.52.0",
]
[[package]]
@@ -2602,7 +2626,7 @@ dependencies = [
"fastrand",
"redox_syscall",
"rustix",
- "windows-sys",
+ "windows-sys 0.52.0",
]
[[package]]
@@ -3009,6 +3033,15 @@ dependencies = [
"windows-targets 0.52.0",
]
+[[package]]
+name = "windows-sys"
+version = "0.45.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0"
+dependencies = [
+ "windows-targets 0.42.2",
+]
+
[[package]]
name = "windows-sys"
version = "0.52.0"
@@ -3018,6 +3051,21 @@ dependencies = [
"windows-targets 0.52.0",
]
+[[package]]
+name = "windows-targets"
+version = "0.42.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071"
+dependencies = [
+ "windows_aarch64_gnullvm 0.42.2",
+ "windows_aarch64_msvc 0.42.2",
+ "windows_i686_gnu 0.42.2",
+ "windows_i686_msvc 0.42.2",
+ "windows_x86_64_gnu 0.42.2",
+ "windows_x86_64_gnullvm 0.42.2",
+ "windows_x86_64_msvc 0.42.2",
+]
+
[[package]]
name = "windows-targets"
version = "0.48.5"
@@ -3048,6 +3096,12 @@ dependencies = [
"windows_x86_64_msvc 0.52.0",
]
+[[package]]
+name = "windows_aarch64_gnullvm"
+version = "0.42.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8"
+
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.48.5"
@@ -3060,6 +3114,12 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea"
+[[package]]
+name = "windows_aarch64_msvc"
+version = "0.42.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43"
+
[[package]]
name = "windows_aarch64_msvc"
version = "0.48.5"
@@ -3072,6 +3132,12 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef"
+[[package]]
+name = "windows_i686_gnu"
+version = "0.42.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f"
+
[[package]]
name = "windows_i686_gnu"
version = "0.48.5"
@@ -3084,6 +3150,12 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313"
+[[package]]
+name = "windows_i686_msvc"
+version = "0.42.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060"
+
[[package]]
name = "windows_i686_msvc"
version = "0.48.5"
@@ -3096,6 +3168,12 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a"
+[[package]]
+name = "windows_x86_64_gnu"
+version = "0.42.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36"
+
[[package]]
name = "windows_x86_64_gnu"
version = "0.48.5"
@@ -3108,6 +3186,12 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd"
+[[package]]
+name = "windows_x86_64_gnullvm"
+version = "0.42.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3"
+
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.48.5"
@@ -3120,6 +3204,12 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e"
+[[package]]
+name = "windows_x86_64_msvc"
+version = "0.42.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0"
+
[[package]]
name = "windows_x86_64_msvc"
version = "0.48.5"
diff --git a/core/Cargo.toml b/core/Cargo.toml
index d27b833..b4df34d 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -48,7 +48,7 @@ serde = { version = "1", features = ["derive"] }
lazy_static = "1.4.0"
prost = "0.12.1"
thrift = "0.17"
-jni = "0.19"
+jni = "0.21"
byteorder = "1.4.3"
snap = "1.1"
brotli = "3.3"
@@ -81,7 +81,7 @@ prost-build = "0.9.0"
[dev-dependencies]
pprof = { version = "0.13.0", features = ["flamegraph"] }
criterion = "0.5.1"
-jni = { version = "0.19", features = ["invocation"] }
+jni = { version = "0.21", features = ["invocation"] }
lazy_static = "1.4"
assertables = "7"
diff --git a/core/src/errors.rs b/core/src/errors.rs
index a5f52d3..7188ebd 100644
--- a/core/src/errors.rs
+++ b/core/src/errors.rs
@@ -298,7 +298,7 @@ impl JNIDefault for () {
// `RuntimeException` back to the calling Java. Since a return result is
required, use `JNIDefault`
// to create a reasonable result. This returned default value will be ignored
due to the exception.
pub fn unwrap_or_throw_default<T: JNIDefault>(
- env: &JNIEnv,
+ env: &mut JNIEnv,
result: std::result::Result<T, CometError>,
) -> T {
match result {
@@ -314,7 +314,7 @@ pub fn unwrap_or_throw_default<T: JNIDefault>(
}
}
-fn throw_exception<E: ToException>(env: &JNIEnv, error: &E, backtrace:
Option<String>) {
+fn throw_exception<E: ToException>(env: &mut JNIEnv, error: &E, backtrace:
Option<String>) {
// If there isn't already an exception?
if env.exception_check().is_ok() {
// ... then throw new exception
@@ -380,37 +380,46 @@ fn flatten<T, E>(result: Result<Result<T, E>, E>) ->
Result<T, E> {
result.and_then(convert::identity)
}
-// It is currently undefined behavior to unwind from Rust code into foreign
code, so we can wrap
-// our JNI functions and turn these panics into a `RuntimeException`.
-pub fn try_or_throw<T, F>(env: JNIEnv, f: F) -> T
+// Implements "currying" from `FnOnce(T) -> R` to `FnOnce() -> R`, given
+// an instance of T. Currying is not supported in Rust so we have to use this
+// custom function to achieve something similar here.
+fn curry<'a, T: 'a, F, R>(f: F, t: T) -> impl FnOnce() -> R + 'a
where
- T: JNIDefault,
- F: FnOnce() -> T + UnwindSafe,
+ F: FnOnce(T) -> R + 'a,
{
- unwrap_or_throw_default(&env, catch_unwind(f).map_err(CometError::from))
+ || f(t)
}
// This is a duplicate of `try_unwrap_or_throw`, which is used to work around
Arrow's lack of
// `UnwindSafe` handling.
-pub fn try_assert_unwind_safe_or_throw<T, F>(env: JNIEnv, f: F) -> T
+pub fn try_assert_unwind_safe_or_throw<T, F>(env: &JNIEnv, f: F) -> T
where
T: JNIDefault,
- F: FnOnce() -> Result<T, CometError>,
+ F: FnOnce(JNIEnv) -> Result<T, CometError>,
{
+ let mut env1 = unsafe { JNIEnv::from_raw(env.get_raw()).unwrap() };
+ let env2 = unsafe { JNIEnv::from_raw(env.get_raw()).unwrap() };
unwrap_or_throw_default(
- &env,
-
flatten(catch_unwind(std::panic::AssertUnwindSafe(f)).map_err(CometError::from)),
+ &mut env1,
+ flatten(
+ catch_unwind(std::panic::AssertUnwindSafe(curry(f,
env2))).map_err(CometError::from),
+ ),
)
}
// It is currently undefined behavior to unwind from Rust code into foreign
code, so we can wrap
// our JNI functions and turn these panics into a `RuntimeException`.
-pub fn try_unwrap_or_throw<T, F>(env: JNIEnv, f: F) -> T
+pub fn try_unwrap_or_throw<T, F>(env: &JNIEnv, f: F) -> T
where
T: JNIDefault,
- F: FnOnce() -> Result<T, CometError> + UnwindSafe,
+ F: FnOnce(JNIEnv) -> Result<T, CometError> + UnwindSafe,
{
- unwrap_or_throw_default(&env,
flatten(catch_unwind(f).map_err(CometError::from)))
+ let mut env1 = unsafe { JNIEnv::from_raw(env.get_raw()).unwrap() };
+ let env2 = unsafe { JNIEnv::from_raw(env.get_raw()).unwrap() };
+ unwrap_or_throw_default(
+ &mut env1,
+ flatten(catch_unwind(curry(f, env2)).map_err(CometError::from)),
+ )
}
#[cfg(test)]
@@ -425,7 +434,7 @@ mod tests {
};
use jni::{
- objects::{JClass, JObject, JString, JThrowable},
+ objects::{JClass, JIntArray, JString, JThrowable},
sys::{jintArray, jstring},
AttachGuard, InitArgsBuilder, JNIEnv, JNIVersion, JavaVM,
};
@@ -482,14 +491,14 @@ mod tests {
#[test]
pub fn error_from_panic() {
let _guard = attach_current_thread();
- let env = jvm().get_env().unwrap();
+ let mut env = jvm().get_env().unwrap();
- try_or_throw(env, || {
+ try_unwrap_or_throw(&env, |_| -> CometResult<()> {
panic!("oops!");
});
assert_pending_java_exception_detailed(
- &env,
+ &mut env,
Some("java/lang/RuntimeException"),
Some("oops!"),
);
@@ -500,38 +509,16 @@ mod tests {
#[test]
pub fn object_result() {
let _guard = attach_current_thread();
- let env = jvm().get_env().unwrap();
+ let mut env = jvm().get_env().unwrap();
let clazz = env.find_class("java/lang/Object").unwrap();
let input = env.new_string("World".to_string()).unwrap();
- let actual = Java_Errors_hello(env, clazz, input);
-
- let actual_string =
String::from(env.get_string(actual.into()).unwrap().to_str().unwrap());
- assert_eq!("Hello, World!", actual_string);
- }
-
- // Verify that functions that return an object can handle throwing
exceptions. The test
- // causes an exception by passing a `null` where a string value is
expected.
- #[test]
- pub fn object_panic_exception() {
- let _guard = attach_current_thread();
- let env = jvm().get_env().unwrap();
- // Class java.lang.object is just a stand-in
- let class = env.find_class("java/lang/Object").unwrap();
- let input = JString::from(JObject::null());
- let _actual = Java_Errors_hello(env, class, input);
-
- assert!(env.exception_check().unwrap());
- let exception = env.exception_occurred().expect("Unable to get
exception");
- env.exception_clear().unwrap();
+ let actual = Java_Errors_hello(&env, clazz, input);
+ let actual_s = unsafe { JString::from_raw(actual) };
- assert_exception_message_with_stacktrace(
- &env,
- exception,
- "Couldn't get java string!: NullPtr(\"get_string obj argument\")",
- "at Java_Errors_hello(",
- );
+ let actual_string =
String::from(env.get_string(&actual_s).unwrap().to_str().unwrap());
+ assert_eq!("Hello, World!", actual_string);
}
// Verify that functions that return an native time are handled correctly.
This is basically
@@ -539,13 +526,13 @@ mod tests {
#[test]
pub fn jlong_result() {
let _guard = attach_current_thread();
- let env = jvm().get_env().unwrap();
+ let mut env = jvm().get_env().unwrap();
// Class java.lang.object is just a stand-in
let class = env.find_class("java/lang/Object").unwrap();
let a: jlong = 6;
let b: jlong = 3;
- let actual = Java_Errors_div(env, class, a, b);
+ let actual = Java_Errors_div(&env, class, a, b);
assert_eq!(2, actual);
}
@@ -555,16 +542,16 @@ mod tests {
#[test]
pub fn jlong_panic_exception() {
let _guard = attach_current_thread();
- let env = jvm().get_env().unwrap();
+ let mut env = jvm().get_env().unwrap();
// Class java.lang.object is just a stand-in
let class = env.find_class("java/lang/Object").unwrap();
let a: jlong = 6;
let b: jlong = 0;
- let _actual = Java_Errors_div(env, class, a, b);
+ let _actual = Java_Errors_div(&env, class, a, b);
assert_pending_java_exception_detailed(
- &env,
+ &mut env,
Some("java/lang/RuntimeException"),
Some("attempt to divide by zero"),
);
@@ -575,13 +562,13 @@ mod tests {
#[test]
pub fn jlong_result_ok() {
let _guard = attach_current_thread();
- let env = jvm().get_env().unwrap();
+ let mut env = jvm().get_env().unwrap();
// Class java.lang.object is just a stand-in
let class = env.find_class("java/lang/Object").unwrap();
let a: JString = env.new_string("9".to_string()).unwrap();
let b: JString = env.new_string("3".to_string()).unwrap();
- let actual = Java_Errors_div_with_parse(env, class, a, b);
+ let actual = Java_Errors_div_with_parse(&env, class, a, b);
assert_eq!(3, actual);
}
@@ -591,16 +578,16 @@ mod tests {
#[test]
pub fn jlong_result_err() {
let _guard = attach_current_thread();
- let env = jvm().get_env().unwrap();
+ let mut env = jvm().get_env().unwrap();
// Class java.lang.object is just a stand-in
let class = env.find_class("java/lang/Object").unwrap();
let a: JString = env.new_string("NaN".to_string()).unwrap();
let b: JString = env.new_string("3".to_string()).unwrap();
- let _actual = Java_Errors_div_with_parse(env, class, a, b);
+ let _actual = Java_Errors_div_with_parse(&env, class, a, b);
assert_pending_java_exception_detailed(
- &env,
+ &mut env,
Some("java/lang/NumberFormatException"),
Some("invalid digit found in string"),
);
@@ -611,17 +598,18 @@ mod tests {
#[test]
pub fn jint_array_result() {
let _guard = attach_current_thread();
- let env = jvm().get_env().unwrap();
+ let mut env = jvm().get_env().unwrap();
// Class java.lang.object is just a stand-in
let class = env.find_class("java/lang/Object").unwrap();
let buf = [2, 4, 6];
let input = env.new_int_array(3).unwrap();
- env.set_int_array_region(input, 0, &buf).unwrap();
- let actual = Java_Errors_array_div(env, class, input, 2);
+ env.set_int_array_region(&input, 0, &buf).unwrap();
+ let actual = Java_Errors_array_div(&env, class, &input, 2);
+ let actual_s = unsafe { JIntArray::from_raw(actual) };
let mut buf: [i32; 3] = [0; 3];
- env.get_int_array_region(actual, 0, &mut buf).unwrap();
+ env.get_int_array_region(&actual_s, 0, &mut buf).unwrap();
assert_eq!([1, 2, 3], buf);
}
@@ -630,17 +618,17 @@ mod tests {
#[test]
pub fn jint_array_panic_exception() {
let _guard = attach_current_thread();
- let env = jvm().get_env().unwrap();
+ let mut env = jvm().get_env().unwrap();
// Class java.lang.object is just a stand-in
let class = env.find_class("java/lang/Object").unwrap();
let buf = [2, 4, 6];
let input = env.new_int_array(3).unwrap();
- env.set_int_array_region(input, 0, &buf).unwrap();
- let _actual = Java_Errors_array_div(env, class, input, 0);
+ env.set_int_array_region(&input, 0, &buf).unwrap();
+ let _actual = Java_Errors_array_div(&env, class, &input, 0);
assert_pending_java_exception_detailed(
- &env,
+ &mut env,
Some("java/lang/RuntimeException"),
Some("attempt to divide by zero"),
);
@@ -683,13 +671,13 @@ mod tests {
// * throwing an exception from `.expect()`
#[no_mangle]
pub extern "system" fn Java_Errors_hello(
- env: JNIEnv,
+ e: &JNIEnv,
_class: JClass,
input: JString,
) -> jstring {
- try_or_throw(env, || {
+ try_unwrap_or_throw(&e, |mut env| {
let input: String = env
- .get_string(input)
+ .get_string(&input)
.expect("Couldn't get java string!")
.into();
@@ -697,7 +685,7 @@ mod tests {
.new_string(format!("Hello, {}!", input))
.expect("Couldn't create java string!");
- output.into_inner()
+ Ok(output.into_raw())
})
}
@@ -706,24 +694,24 @@ mod tests {
// * throwing an exception when dividing by zero
#[no_mangle]
pub extern "system" fn Java_Errors_div(
- env: JNIEnv,
+ env: &JNIEnv,
_class: JClass,
a: jlong,
b: jlong,
) -> jlong {
- try_or_throw(env, || a / b)
+ try_unwrap_or_throw(env, |_| Ok(a / b))
}
#[no_mangle]
pub extern "system" fn Java_Errors_div_with_parse(
- env: JNIEnv,
+ e: &JNIEnv,
_class: JClass,
a: JString,
b: JString,
) -> jlong {
- try_unwrap_or_throw(env, || {
- let a_value: i64 = env.get_string(a)?.to_str()?.parse()?;
- let b_value: i64 = env.get_string(b)?.to_str()?.parse()?;
+ try_unwrap_or_throw(e, |mut env| {
+ let a_value: i64 = env.get_string(&a)?.to_str()?.parse()?;
+ let b_value: i64 = env.get_string(&b)?.to_str()?.parse()?;
Ok(a_value / b_value)
})
}
@@ -733,27 +721,27 @@ mod tests {
// * throwing an exception when dividing by zero
#[no_mangle]
pub extern "system" fn Java_Errors_array_div(
- env: JNIEnv,
+ e: &JNIEnv,
_class: JClass,
- input: jintArray,
+ input: &JIntArray,
divisor: jint,
) -> jintArray {
- try_or_throw(env, || {
+ try_unwrap_or_throw(e, |env| {
let mut input_buf: [jint; 3] = [0; 3];
- env.get_int_array_region(input, 0, &mut input_buf).unwrap();
+ env.get_int_array_region(input, 0, &mut input_buf)?;
let buf = input_buf.map(|v| -> jint { v / divisor });
- let result = env.new_int_array(3).unwrap();
- env.set_int_array_region(result, 0, &buf).unwrap();
- result
+ let result = env.new_int_array(3)?;
+ env.set_int_array_region(&result, 0, &buf)?;
+ Ok(result.into_raw())
})
}
// Helper method that asserts there is a pending Java exception which is
an `instance_of`
// `expected_type` with a message matching `expected_message` and clears
it if any.
fn assert_pending_java_exception_detailed(
- env: &JNIEnv,
+ env: &mut JNIEnv,
expected_type: Option<&str>,
expected_message: Option<&str>,
) {
@@ -762,7 +750,7 @@ mod tests {
env.exception_clear().unwrap();
if let Some(expected_type) = expected_type {
- assert_exception_type(env, exception, expected_type);
+ assert_exception_type(env, &exception, expected_type);
}
if let Some(expected_message) = expected_message {
@@ -771,7 +759,7 @@ mod tests {
}
// Asserts that exception is an `instance_of` `expected_type` type.
- fn assert_exception_type(env: &JNIEnv, exception: JThrowable,
expected_type: &str) {
+ fn assert_exception_type(env: &mut JNIEnv, exception: &JThrowable,
expected_type: &str) {
if !env.is_instance_of(exception, expected_type).unwrap() {
let class: JClass = env.get_object_class(exception).unwrap();
let name = env
@@ -779,19 +767,21 @@ mod tests {
.unwrap()
.l()
.unwrap();
- let class_name: String =
env.get_string(name.into()).unwrap().into();
+ let name_string = name.into();
+ let class_name: String =
env.get_string(&name_string).unwrap().into();
assert_eq!(class_name.replace('.', "/"), expected_type);
};
}
// Asserts that exception's message matches `expected_message`.
- fn assert_exception_message(env: &JNIEnv, exception: JThrowable,
expected_message: &str) {
+ fn assert_exception_message(env: &mut JNIEnv, exception: JThrowable,
expected_message: &str) {
let message = env
.call_method(exception, "getMessage", "()Ljava/lang/String;", &[])
.unwrap()
.l()
.unwrap();
- let msg_rust: String = env.get_string(message.into()).unwrap().into();
+ let message_string = message.into();
+ let msg_rust: String = env.get_string(&message_string).unwrap().into();
println!("{}", msg_rust);
// Since panics result in multi-line messages which include the
backtrace, just use the
// first line.
@@ -800,7 +790,7 @@ mod tests {
// Asserts that exception's message matches `expected_message`.
fn assert_exception_message_with_stacktrace(
- env: &JNIEnv,
+ env: &mut JNIEnv,
exception: JThrowable,
expected_message: &str,
stacktrace_contains: &str,
@@ -810,7 +800,8 @@ mod tests {
.unwrap()
.l()
.unwrap();
- let msg_rust: String = env.get_string(message.into()).unwrap().into();
+ let message_string = message.into();
+ let msg_rust: String = env.get_string(&message_string).unwrap().into();
// Since panics result in multi-line messages which include the
backtrace, just use the
// first line.
assert_starts_with!(msg_rust, expected_message);
diff --git a/core/src/execution/datafusion/expressions/subquery.rs
b/core/src/execution/datafusion/expressions/subquery.rs
index a82fb35..a4b32ba 100644
--- a/core/src/execution/datafusion/expressions/subquery.rs
+++ b/core/src/execution/datafusion/expressions/subquery.rs
@@ -20,7 +20,10 @@ use arrow_schema::{DataType, Schema, TimeUnit};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{internal_err, DataFusionError, ScalarValue};
use datafusion_physical_expr::PhysicalExpr;
-use jni::sys::{jboolean, jbyte, jint, jlong, jshort};
+use jni::{
+ objects::JByteArray,
+ sys::{jboolean, jbyte, jint, jlong, jshort},
+};
use std::{
any::Any,
fmt::{Display, Formatter},
@@ -87,109 +90,112 @@ impl PhysicalExpr for Subquery {
}
fn evaluate(&self, _: &RecordBatch) ->
datafusion_common::Result<ColumnarValue> {
- let env = JVMClasses::get_env();
-
- let is_null =
- jni_static_call!(env, comet_exec.is_null(self.exec_context_id,
self.id) -> jboolean)?;
+ let mut env = JVMClasses::get_env();
- if is_null > 0 {
- return Ok(ColumnarValue::Scalar(ScalarValue::try_from(
- &self.data_type,
- )?));
- }
+ unsafe {
+ let is_null = jni_static_call!(env,
+ comet_exec.is_null(self.exec_context_id, self.id) -> jboolean
+ )?;
- match &self.data_type {
- DataType::Boolean => {
- let r = jni_static_call!(env,
- comet_exec.get_bool(self.exec_context_id, self.id) ->
jboolean
- )?;
- Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(r > 0))))
- }
- DataType::Int8 => {
- let r = jni_static_call!(env,
- comet_exec.get_byte(self.exec_context_id, self.id) -> jbyte
- )?;
- Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(r))))
- }
- DataType::Int16 => {
- let r = jni_static_call!(env,
- comet_exec.get_short(self.exec_context_id, self.id) ->
jshort
- )?;
- Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(r))))
+ if is_null > 0 {
+ return Ok(ColumnarValue::Scalar(ScalarValue::try_from(
+ &self.data_type,
+ )?));
}
- DataType::Int32 => {
- let r = jni_static_call!(env,
- comet_exec.get_int(self.exec_context_id, self.id) -> jint
- )?;
- Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(r))))
- }
- DataType::Int64 => {
- let r = jni_static_call!(env,
- comet_exec.get_long(self.exec_context_id, self.id) -> jlong
- )?;
- Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(r))))
- }
- DataType::Float32 => {
- let r = jni_static_call!(env,
- comet_exec.get_float(self.exec_context_id, self.id) -> f32
- )?;
- Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(r))))
- }
- DataType::Float64 => {
- let r = jni_static_call!(env,
- comet_exec.get_double(self.exec_context_id, self.id) -> f64
- )?;
-
- Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(r))))
- }
- DataType::Decimal128(p, s) => {
- let bytes = jni_static_call!(env,
- comet_exec.get_decimal(self.exec_context_id, self.id) ->
BinaryWrapper
- )?;
-
- let slice =
env.convert_byte_array((*bytes.get()).into_inner()).unwrap();
-
- Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
- Some(bytes_to_i128(&slice)),
- *p,
- *s,
- )))
- }
- DataType::Date32 => {
- let r = jni_static_call!(env,
- comet_exec.get_int(self.exec_context_id, self.id) -> jint
- )?;
-
- Ok(ColumnarValue::Scalar(ScalarValue::Date32(Some(r))))
- }
- DataType::Timestamp(TimeUnit::Microsecond, timezone) => {
- let r = jni_static_call!(env,
- comet_exec.get_long(self.exec_context_id, self.id) -> jlong
- )?;
-
- Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(
- Some(r),
- timezone.clone(),
- )))
- }
- DataType::Utf8 => {
- let string = jni_static_call!(env,
- comet_exec.get_string(self.exec_context_id, self.id) ->
StringWrapper
- )?;
-
- let string = env.get_string(*string.get()).unwrap().into();
- Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(string))))
- }
- DataType::Binary => {
- let bytes = jni_static_call!(env,
- comet_exec.get_binary(self.exec_context_id, self.id) ->
BinaryWrapper
- )?;
-
- let slice =
env.convert_byte_array((*bytes.get()).into_inner()).unwrap();
- Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(slice))))
+ match &self.data_type {
+ DataType::Boolean => {
+ let r = jni_static_call!(env,
+ comet_exec.get_bool(self.exec_context_id, self.id) ->
jboolean
+ )?;
+ Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(r >
0))))
+ }
+ DataType::Int8 => {
+ let r = jni_static_call!(env,
+ comet_exec.get_byte(self.exec_context_id, self.id) ->
jbyte
+ )?;
+ Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(r))))
+ }
+ DataType::Int16 => {
+ let r = jni_static_call!(env,
+ comet_exec.get_short(self.exec_context_id, self.id) ->
jshort
+ )?;
+ Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(r))))
+ }
+ DataType::Int32 => {
+ let r = jni_static_call!(env,
+ comet_exec.get_int(self.exec_context_id, self.id) ->
jint
+ )?;
+ Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(r))))
+ }
+ DataType::Int64 => {
+ let r = jni_static_call!(env,
+ comet_exec.get_long(self.exec_context_id, self.id) ->
jlong
+ )?;
+ Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(r))))
+ }
+ DataType::Float32 => {
+ let r = jni_static_call!(env,
+ comet_exec.get_float(self.exec_context_id, self.id) ->
f32
+ )?;
+ Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(r))))
+ }
+ DataType::Float64 => {
+ let r = jni_static_call!(env,
+ comet_exec.get_double(self.exec_context_id, self.id)
-> f64
+ )?;
+
+ Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(r))))
+ }
+ DataType::Decimal128(p, s) => {
+ let bytes = jni_static_call!(env,
+ comet_exec.get_decimal(self.exec_context_id, self.id)
-> BinaryWrapper
+ )?;
+ let bytes: &JByteArray = bytes.get().into();
+ let slice = env.convert_byte_array(bytes).unwrap();
+
+ Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+ Some(bytes_to_i128(&slice)),
+ *p,
+ *s,
+ )))
+ }
+ DataType::Date32 => {
+ let r = jni_static_call!(env,
+ comet_exec.get_int(self.exec_context_id, self.id) ->
jint
+ )?;
+
+ Ok(ColumnarValue::Scalar(ScalarValue::Date32(Some(r))))
+ }
+ DataType::Timestamp(TimeUnit::Microsecond, timezone) => {
+ let r = jni_static_call!(env,
+ comet_exec.get_long(self.exec_context_id, self.id) ->
jlong
+ )?;
+
+ Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(
+ Some(r),
+ timezone.clone(),
+ )))
+ }
+ DataType::Utf8 => {
+ let string = jni_static_call!(env,
+ comet_exec.get_string(self.exec_context_id, self.id)
-> StringWrapper
+ )?;
+
+ let string = env.get_string(string.get()).unwrap().into();
+ Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(string))))
+ }
+ DataType::Binary => {
+ let bytes = jni_static_call!(env,
+ comet_exec.get_binary(self.exec_context_id, self.id)
-> BinaryWrapper
+ )?;
+ let bytes: &JByteArray = bytes.get().into();
+ let slice = env.convert_byte_array(bytes).unwrap();
+
+ Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(slice))))
+ }
+ _ => internal_err!("Unsupported scalar subquery data type
{:?}", self.data_type),
}
- _ => internal_err!("Unsupported scalar subquery data type {:?}",
self.data_type),
}
}
diff --git a/core/src/execution/jni_api.rs b/core/src/execution/jni_api.rs
index 9981cec..831f788 100644
--- a/core/src/execution/jni_api.rs
+++ b/core/src/execution/jni_api.rs
@@ -36,7 +36,10 @@ use datafusion_common::DataFusionError;
use futures::poll;
use jni::{
errors::Result as JNIResult,
- objects::{JClass, JMap, JObject, JString, ReleaseMode},
+ objects::{
+ AutoElements, JBooleanArray, JByteArray, JClass, JIntArray,
JLongArray, JMap, JObject,
+ JObjectArray, JPrimitiveArray, JString, ReleaseMode,
+ },
sys::{jbyteArray, jint, jlong, jlongArray},
JNIEnv,
};
@@ -45,7 +48,7 @@ use std::{collections::HashMap, sync::Arc, task::Poll};
use super::{serde, utils::SparkArrowConvert};
use crate::{
- errors::{try_unwrap_or_throw, CometError},
+ errors::{try_unwrap_or_throw, CometError, CometResult},
execution::{
datafusion::planner::PhysicalPlanner,
metrics::utils::update_comet_metric,
serde::to_arrow_datatype, shuffle::row::process_sorted_row_partition,
sort::RdxSort,
@@ -55,7 +58,7 @@ use crate::{
};
use futures::stream::StreamExt;
use jni::{
- objects::{AutoArray, GlobalRef},
+ objects::GlobalRef,
sys::{jboolean, jbooleanArray, jdouble, jintArray, jobjectArray, jstring},
};
use tokio::runtime::Runtime;
@@ -88,21 +91,24 @@ struct ExecutionContext {
pub debug_native: bool,
}
-#[no_mangle]
/// Accept serialized query plan and return the address of the native query
plan.
-pub extern "system" fn Java_org_apache_comet_Native_createPlan(
- env: JNIEnv,
+/// # Safety
+/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
+#[no_mangle]
+pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
+ e: JNIEnv,
_class: JClass,
id: jlong,
config_object: JObject,
serialized_query: jbyteArray,
metrics_node: JObject,
) -> jlong {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&e, |mut env| {
// Init JVM classes
- JVMClasses::init(&env);
+ JVMClasses::init(&mut env);
- let bytes = env.convert_byte_array(serialized_query)?;
+ let array = unsafe { JPrimitiveArray::from_raw(serialized_query) };
+ let bytes = env.convert_byte_array(array)?;
// Deserialize query plan
let spark_plan = serde::deserialize_op(bytes.as_slice())?;
@@ -110,13 +116,13 @@ pub extern "system" fn
Java_org_apache_comet_Native_createPlan(
// Sets up context
let mut configs = HashMap::new();
- let config_map = JMap::from_env(&env, config_object)?;
- config_map.iter()?.for_each(|config| {
- let key: String =
env.get_string(JString::from(config.0)).unwrap().into();
- let value: String =
env.get_string(JString::from(config.1)).unwrap().into();
-
+ let config_map = JMap::from_env(&mut env, &config_object)?;
+ let mut map_iter = config_map.iter(&mut env)?;
+ while let Some((key, value)) = map_iter.next(&mut env)? {
+ let key: String =
env.get_string(&JString::from(key)).unwrap().into();
+ let value: String =
env.get_string(&JString::from(value)).unwrap().into();
configs.insert(key, value);
- });
+ }
// Whether we've enabled additional debugging on the native side
let debug_native = configs
@@ -157,8 +163,8 @@ pub extern "system" fn
Java_org_apache_comet_Native_createPlan(
/// Parse Comet configs and configure DataFusion session context.
fn prepare_datafusion_session_context(
conf: &HashMap<String, String>,
-) -> Result<SessionContext, CometError> {
- // Get the batch size from Comet JVM side
+) -> CometResult<SessionContext> {
+ // Get the batch size from Comet JVM side
let batch_size = conf
.get("batch_size")
.ok_or(CometError::Internal(
@@ -205,10 +211,10 @@ fn prepare_datafusion_session_context(
/// Prepares arrow arrays for output.
fn prepare_output(
+ env: &mut JNIEnv,
output: Result<RecordBatch, DataFusionError>,
- env: JNIEnv,
exec_context: &mut ExecutionContext,
-) -> Result<jlongArray, CometError> {
+) -> CometResult<jlongArray> {
let output_batch = output?;
let results = output_batch.columns();
let num_rows = output_batch.num_rows();
@@ -226,7 +232,7 @@ fn prepare_output(
let return_flag = 1;
let long_array = env.new_long_array((results.len() * 2) as i32 + 2)?;
- env.set_long_array_region(long_array, 0, &[return_flag, num_rows as
jlong])?;
+ env.set_long_array_region(&long_array, 0, &[return_flag, num_rows as
jlong])?;
let mut arrays = vec![];
@@ -241,48 +247,61 @@ fn prepare_output(
arrays.push((arrow_array, arrow_schema));
}
- env.set_long_array_region(long_array, (i * 2) as i32 + 2, &[array,
schema])?;
+ env.set_long_array_region(&long_array, (i * 2) as i32 + 2, &[array,
schema])?;
i += 1;
}
// Update metrics
- update_metrics(&env, exec_context)?;
+ update_metrics(env, exec_context)?;
// Record the pointer to allocated Arrow Arrays
exec_context.ffi_arrays = arrays;
- Ok(long_array)
+ Ok(long_array.into_raw())
}
-#[no_mangle]
/// Accept serialized query plan and the addresses of Arrow Arrays from Spark,
/// then execute the query. Return addresses of arrow vector.
-pub extern "system" fn Java_org_apache_comet_Native_executePlan(
- env: JNIEnv,
+/// # Safety
+/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
+#[no_mangle]
+pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
+ e: JNIEnv,
_class: JClass,
exec_context: jlong,
addresses_array: jobjectArray,
finishes: jbooleanArray,
batch_rows: jint,
) -> jlongArray {
- try_unwrap_or_throw(env, || {
- let addresses_vec = convert_addresses_arrays(&env, addresses_array)?;
- let mut all_inputs: Vec<Vec<ArrayRef>> =
Vec::with_capacity(addresses_vec.len());
-
+ try_unwrap_or_throw(&e, |mut env| unsafe {
let exec_context = get_execution_context(exec_context);
- for addresses in addresses_vec.iter() {
+
+ let addresses = JObjectArray::from_raw(addresses_array);
+ let num_addresses = env.get_array_length(&addresses)? as usize;
+
+ let mut all_inputs: Vec<Vec<ArrayRef>> =
Vec::with_capacity(num_addresses);
+
+ for i in 0..num_addresses {
let mut inputs: Vec<ArrayRef> = vec![];
- let array_num = addresses.size()? as usize;
- assert_eq!(array_num % 2, 0, "Arrow Array addresses are invalid!");
+ let inner_addresses = env.get_object_array_element(&addresses, i
as i32)?.into();
+ let inner_address_array: AutoElements<jlong> =
+ env.get_array_elements(&inner_addresses,
ReleaseMode::NoCopyBack)?;
- let num_arrays = array_num / 2;
- let array_elements = addresses.as_ptr();
+ let num_inner_address = inner_address_array.len();
+ assert_eq!(
+ num_inner_address % 2,
+ 0,
+ "Arrow Array addresses are invalid!"
+ );
+
+ let num_arrays = num_inner_address / 2;
+ let array_elements = inner_address_array.as_ptr();
let mut i: usize = 0;
while i < num_arrays {
- let array_ptr = unsafe { *(array_elements.add(i * 2)) };
- let schema_ptr = unsafe { *(array_elements.add(i * 2 + 1)) };
+ let array_ptr = *(array_elements.add(i * 2));
+ let schema_ptr = *(array_elements.add(i * 2 + 1));
let array_data = ArrayData::from_spark((array_ptr,
schema_ptr))?;
if exec_context.debug_native {
@@ -298,7 +317,8 @@ pub extern "system" fn
Java_org_apache_comet_Native_executePlan(
}
// Prepares the input batches.
- let eofs = env.get_boolean_array_elements(finishes,
ReleaseMode::NoCopyBack)?;
+ let array = JBooleanArray::from_raw(finishes);
+ let eofs = env.get_array_elements(&array, ReleaseMode::NoCopyBack)?;
let eof_flags = eofs.as_ptr();
// Whether reaching the end of input batches.
@@ -306,7 +326,7 @@ pub extern "system" fn
Java_org_apache_comet_Native_executePlan(
let mut input_batches = all_inputs
.into_iter()
.enumerate()
- .map(|(idx, inputs)| unsafe {
+ .map(|(idx, inputs)| {
let eof = eof_flags.add(idx);
if *eof == 1 {
@@ -364,25 +384,25 @@ pub extern "system" fn
Java_org_apache_comet_Native_executePlan(
match poll_output {
Poll::Ready(Some(output)) => {
- return prepare_output(output, env, exec_context);
+ return prepare_output(&mut env, output, exec_context);
}
Poll::Ready(None) => {
// Reaches EOF of output.
// Update metrics
- update_metrics(&env, exec_context)?;
+ update_metrics(&mut env, exec_context)?;
let long_array = env.new_long_array(1)?;
- env.set_long_array_region(long_array, 0, &[-1])?;
+ env.set_long_array_region(&long_array, 0, &[-1])?;
- return Ok(long_array);
+ return Ok(long_array.into_raw());
}
- // After reaching the end of any input, a poll pending means
there are more than one
- // blocking operators, we don't need go back-forth between
JVM/Native. Just
- // keeping polling.
+ // After reaching the end of any input, a poll pending means
there are more than
+ // one blocking operator, we don't need to go back and forth
+ // between JVM/Native. Just keep polling.
Poll::Pending if finished => {
// Update metrics
- update_metrics(&env, exec_context)?;
+ update_metrics(&mut env, exec_context)?;
// Output not ready yet
continue;
@@ -391,7 +411,7 @@ pub extern "system" fn
Java_org_apache_comet_Native_executePlan(
// operators. Just returning to keep reading next input.
Poll::Pending => {
// Update metrics
- update_metrics(&env, exec_context)?;
+ update_metrics(&mut env, exec_context)?;
return return_pending(env);
}
}
@@ -401,19 +421,18 @@ pub extern "system" fn
Java_org_apache_comet_Native_executePlan(
fn return_pending(env: JNIEnv) -> Result<jlongArray, CometError> {
let long_array = env.new_long_array(1)?;
- env.set_long_array_region(long_array, 0, &[0])?;
-
- Ok(long_array)
+ env.set_long_array_region(&long_array, 0, &[0])?;
+ Ok(long_array.into_raw())
}
#[no_mangle]
/// Peeks into next output if any.
pub extern "system" fn Java_org_apache_comet_Native_peekNext(
- env: JNIEnv,
+ e: JNIEnv,
_class: JClass,
exec_context: jlong,
) -> jlongArray {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&e, |mut env| {
// Retrieve the query
let exec_context = get_execution_context(exec_context);
@@ -427,10 +446,10 @@ pub extern "system" fn
Java_org_apache_comet_Native_peekNext(
let poll_output = exec_context.runtime.block_on(async {
poll!(next_item) });
match poll_output {
- Poll::Ready(Some(output)) => prepare_output(output, env,
exec_context),
+ Poll::Ready(Some(output)) => prepare_output(&mut env, output,
exec_context),
_ => {
// Update metrics
- update_metrics(&env, exec_context)?;
+ update_metrics(&mut env, exec_context)?;
return_pending(env)
}
}
@@ -440,11 +459,11 @@ pub extern "system" fn
Java_org_apache_comet_Native_peekNext(
#[no_mangle]
/// Drop the native query plan object and context object.
pub extern "system" fn Java_org_apache_comet_Native_releasePlan(
- env: JNIEnv,
+ e: JNIEnv,
_class: JClass,
exec_context: jlong,
) {
- try_unwrap_or_throw(env, || unsafe {
+ try_unwrap_or_throw(&e, |_| unsafe {
let execution_context = get_execution_context(exec_context);
let _: Box<ExecutionContext> = Box::from_raw(execution_context);
Ok(())
@@ -452,51 +471,32 @@ pub extern "system" fn
Java_org_apache_comet_Native_releasePlan(
}
/// Updates the metrics of the query plan.
-fn update_metrics(env: &JNIEnv, exec_context: &ExecutionContext) -> Result<(),
CometError> {
+fn update_metrics(env: &mut JNIEnv, exec_context: &ExecutionContext) ->
CometResult<()> {
let native_query = exec_context.root_op.as_ref().unwrap();
let metrics = exec_context.metrics.as_obj();
update_comet_metric(env, metrics, native_query)
}
-/// Converts a Java array of address arrays to a Rust vector of address arrays.
-fn convert_addresses_arrays<'a>(
- env: &'a JNIEnv<'a>,
- addresses_array: jobjectArray,
-) -> JNIResult<Vec<AutoArray<'a, 'a, jlong>>> {
- let array_len = env.get_array_length(addresses_array)?;
- let mut res: Vec<AutoArray<jlong>> = Vec::new();
-
- for i in 0..array_len {
- let array: AutoArray<jlong> = env.get_array_elements(
- env.get_object_array_element(addresses_array, i)?
- .into_inner() as jlongArray,
- ReleaseMode::NoCopyBack,
- )?;
- res.push(array);
- }
-
- Ok(res)
-}
-
fn convert_datatype_arrays(
- env: &'_ JNIEnv<'_>,
+ env: &'_ mut JNIEnv<'_>,
serialized_datatypes: jobjectArray,
) -> JNIResult<Vec<ArrowDataType>> {
- let array_len = env.get_array_length(serialized_datatypes)?;
- let mut res: Vec<ArrowDataType> = Vec::new();
-
- for i in 0..array_len {
- let array = env
- .get_object_array_element(serialized_datatypes, i)?
- .into_inner() as jbyteArray;
+ unsafe {
+ let obj_array = JObjectArray::from_raw(serialized_datatypes);
+ let array_len = env.get_array_length(&obj_array)?;
+ let mut res: Vec<ArrowDataType> = Vec::new();
+
+ for i in 0..array_len {
+ let inner_array = env.get_object_array_element(&obj_array, i)?;
+ let inner_array: JByteArray = inner_array.into();
+ let bytes = env.convert_byte_array(inner_array)?;
+ let data_type =
serde::deserialize_data_type(bytes.as_slice()).unwrap();
+ let arrow_dt = to_arrow_datatype(&data_type);
+ res.push(arrow_dt);
+ }
- let bytes = env.convert_byte_array(array)?;
- let data_type =
serde::deserialize_data_type(bytes.as_slice()).unwrap();
- let arrow_dt = to_arrow_datatype(&data_type);
- res.push(arrow_dt);
+ Ok(res)
}
-
- Ok(res)
}
fn get_execution_context<'a>(id: i64) -> &'a mut ExecutionContext {
@@ -507,10 +507,12 @@ fn get_execution_context<'a>(id: i64) -> &'a mut
ExecutionContext {
}
}
+/// Used by Comet shuffle external sorter to write sorted records to disk.
+/// # Safety
+/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
#[no_mangle]
-/// Used by Comet shuffle external sorter to write sorted records to disk.
-pub extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative(
- env: JNIEnv,
+pub unsafe extern "system" fn
Java_org_apache_comet_Native_writeSortedFileNative(
+ e: JNIEnv,
_class: JClass,
row_addresses: jlongArray,
row_sizes: jintArray,
@@ -521,18 +523,23 @@ pub extern "system" fn
Java_org_apache_comet_Native_writeSortedFileNative(
checksum_algo: jint,
current_checksum: jlong,
) -> jlongArray {
- try_unwrap_or_throw(env, || {
- let row_num = env.get_array_length(row_addresses)? as usize;
+ try_unwrap_or_throw(&e, |mut env| unsafe {
+ let data_types = convert_datatype_arrays(&mut env,
serialized_datatypes)?;
- let data_types = convert_datatype_arrays(&env, serialized_datatypes)?;
+ let row_address_array = JLongArray::from_raw(row_addresses);
+ let row_num = env.get_array_length(&row_address_array)? as usize;
+ let row_addresses = env.get_array_elements(&row_address_array,
ReleaseMode::NoCopyBack)?;
- let row_addresses = env.get_long_array_elements(row_addresses,
ReleaseMode::NoCopyBack)?;
- let row_sizes = env.get_int_array_elements(row_sizes,
ReleaseMode::NoCopyBack)?;
+ let row_size_array = JIntArray::from_raw(row_sizes);
+ let row_sizes = env.get_array_elements(&row_size_array,
ReleaseMode::NoCopyBack)?;
let row_addresses_ptr = row_addresses.as_ptr();
let row_sizes_ptr = row_sizes.as_ptr();
- let output_path: String =
env.get_string(JString::from(file_path)).unwrap().into();
+ let output_path: String = env
+ .get_string(&JString::from_raw(file_path))
+ .unwrap()
+ .into();
let checksum_enabled = checksum_enabled == 1;
let current_checksum = if current_checksum == i64::MIN {
@@ -563,21 +570,21 @@ pub extern "system" fn
Java_org_apache_comet_Native_writeSortedFileNative(
};
let long_array = env.new_long_array(2)?;
- env.set_long_array_region(long_array, 0, &[written_bytes, checksum])?;
+ env.set_long_array_region(&long_array, 0, &[written_bytes, checksum])?;
- Ok(long_array)
+ Ok(long_array.into_raw())
})
}
#[no_mangle]
-/// Used by Comet shuffle external sorter to sort in-memory row partition ids.
+/// Used by Comet shuffle external sorter to sort in-memory row partition ids.
pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative(
- env: JNIEnv,
+ e: JNIEnv,
_class: JClass,
address: jlong,
size: jlong,
) {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&e, |_| {
// SAFETY: JVM unsafe memory allocation is aligned with long.
let array = unsafe { std::slice::from_raw_parts_mut(address as *mut
i64, size as usize) };
array.rdxsort();
diff --git a/core/src/execution/metrics/utils.rs
b/core/src/execution/metrics/utils.rs
index eb36a55..6990aa5 100644
--- a/core/src/execution/metrics/utils.rs
+++ b/core/src/execution/metrics/utils.rs
@@ -27,8 +27,8 @@ use std::sync::Arc;
/// update the metrics of all the children nodes. The metrics are pulled from
the
/// DataFusion execution plan and pushed to the Java side through JNI.
pub fn update_comet_metric(
- env: &JNIEnv,
- metric_node: JObject,
+ env: &mut JNIEnv,
+ metric_node: &JObject,
execution_plan: &Arc<dyn ExecutionPlan>,
) -> Result<(), CometError> {
update_metrics(
@@ -43,27 +43,31 @@ pub fn update_comet_metric(
.collect::<Vec<_>>(),
)?;
- for (i, child_plan) in execution_plan.children().iter().enumerate() {
- let child_metric_node: JObject = jni_call!(env,
- comet_metric_node(metric_node).get_child_node(i as i32) -> JObject
- )?;
- if child_metric_node.is_null() {
- continue;
+ unsafe {
+ for (i, child_plan) in execution_plan.children().iter().enumerate() {
+ let child_metric_node: JObject = jni_call!(env,
+ comet_metric_node(metric_node).get_child_node(i as i32) ->
JObject
+ )?;
+ if child_metric_node.is_null() {
+ continue;
+ }
+ update_comet_metric(env, &child_metric_node, child_plan)?;
}
- update_comet_metric(env, child_metric_node, child_plan)?;
}
Ok(())
}
#[inline]
fn update_metrics(
- env: &JNIEnv,
- metric_node: JObject,
+ env: &mut JNIEnv,
+ metric_node: &JObject,
metric_values: &[(&str, i64)],
) -> Result<(), CometError> {
- for &(name, value) in metric_values {
- let jname = jni_new_string!(env, &name)?;
- jni_call!(env, comet_metric_node(metric_node).add(jname, value) ->
())?;
+ unsafe {
+ for &(name, value) in metric_values {
+ let jname = jni_new_string!(env, &name)?;
+ jni_call!(env, comet_metric_node(metric_node).add(&jname, value)
-> ())?;
+ }
}
Ok(())
}
diff --git a/core/src/jvm_bridge/comet_exec.rs
b/core/src/jvm_bridge/comet_exec.rs
index e28fc08..6b6652e 100644
--- a/core/src/jvm_bridge/comet_exec.rs
+++ b/core/src/jvm_bridge/comet_exec.rs
@@ -18,7 +18,7 @@
use jni::{
errors::Result as JniResult,
objects::{JClass, JStaticMethodID},
- signature::{JavaType, Primitive},
+ signature::{Primitive, ReturnType},
JNIEnv,
};
@@ -27,75 +27,83 @@ use super::get_global_jclass;
/// A struct that holds all the JNI methods and fields for JVM CometExec
object.
pub struct CometExec<'a> {
pub class: JClass<'a>,
- pub method_get_bool: JStaticMethodID<'a>,
- pub method_get_bool_ret: JavaType,
- pub method_get_byte: JStaticMethodID<'a>,
- pub method_get_byte_ret: JavaType,
- pub method_get_short: JStaticMethodID<'a>,
- pub method_get_short_ret: JavaType,
- pub method_get_int: JStaticMethodID<'a>,
- pub method_get_int_ret: JavaType,
- pub method_get_long: JStaticMethodID<'a>,
- pub method_get_long_ret: JavaType,
- pub method_get_float: JStaticMethodID<'a>,
- pub method_get_float_ret: JavaType,
- pub method_get_double: JStaticMethodID<'a>,
- pub method_get_double_ret: JavaType,
- pub method_get_decimal: JStaticMethodID<'a>,
- pub method_get_decimal_ret: JavaType,
- pub method_get_string: JStaticMethodID<'a>,
- pub method_get_string_ret: JavaType,
- pub method_get_binary: JStaticMethodID<'a>,
- pub method_get_binary_ret: JavaType,
- pub method_is_null: JStaticMethodID<'a>,
- pub method_is_null_ret: JavaType,
+ pub method_get_bool: JStaticMethodID,
+ pub method_get_bool_ret: ReturnType,
+ pub method_get_byte: JStaticMethodID,
+ pub method_get_byte_ret: ReturnType,
+ pub method_get_short: JStaticMethodID,
+ pub method_get_short_ret: ReturnType,
+ pub method_get_int: JStaticMethodID,
+ pub method_get_int_ret: ReturnType,
+ pub method_get_long: JStaticMethodID,
+ pub method_get_long_ret: ReturnType,
+ pub method_get_float: JStaticMethodID,
+ pub method_get_float_ret: ReturnType,
+ pub method_get_double: JStaticMethodID,
+ pub method_get_double_ret: ReturnType,
+ pub method_get_decimal: JStaticMethodID,
+ pub method_get_decimal_ret: ReturnType,
+ pub method_get_string: JStaticMethodID,
+ pub method_get_string_ret: ReturnType,
+ pub method_get_binary: JStaticMethodID,
+ pub method_get_binary_ret: ReturnType,
+ pub method_is_null: JStaticMethodID,
+ pub method_is_null_ret: ReturnType,
}
impl<'a> CometExec<'a> {
pub const JVM_CLASS: &'static str =
"org/apache/spark/sql/comet/CometScalarSubquery";
- pub fn new(env: &JNIEnv<'a>) -> JniResult<CometExec<'a>> {
+ pub fn new(env: &mut JNIEnv<'a>) -> JniResult<CometExec<'a>> {
// Get the global class reference
let class = get_global_jclass(env, Self::JVM_CLASS)?;
Ok(CometExec {
- class,
method_get_bool: env
- .get_static_method_id(class, "getBoolean", "(JJ)Z")
+ .get_static_method_id(Self::JVM_CLASS, "getBoolean", "(JJ)Z")
+ .unwrap(),
+ method_get_bool_ret: ReturnType::Primitive(Primitive::Boolean),
+ method_get_byte: env
+ .get_static_method_id(Self::JVM_CLASS, "getByte", "(JJ)B")
.unwrap(),
- method_get_bool_ret: JavaType::Primitive(Primitive::Boolean),
- method_get_byte: env.get_static_method_id(class, "getByte",
"(JJ)B").unwrap(),
- method_get_byte_ret: JavaType::Primitive(Primitive::Byte),
+ method_get_byte_ret: ReturnType::Primitive(Primitive::Byte),
method_get_short: env
- .get_static_method_id(class, "getShort", "(JJ)S")
+ .get_static_method_id(Self::JVM_CLASS, "getShort", "(JJ)S")
+ .unwrap(),
+ method_get_short_ret: ReturnType::Primitive(Primitive::Short),
+ method_get_int: env
+ .get_static_method_id(Self::JVM_CLASS, "getInt", "(JJ)I")
.unwrap(),
- method_get_short_ret: JavaType::Primitive(Primitive::Short),
- method_get_int: env.get_static_method_id(class, "getInt",
"(JJ)I").unwrap(),
- method_get_int_ret: JavaType::Primitive(Primitive::Int),
- method_get_long: env.get_static_method_id(class, "getLong",
"(JJ)J").unwrap(),
- method_get_long_ret: JavaType::Primitive(Primitive::Long),
+ method_get_int_ret: ReturnType::Primitive(Primitive::Int),
+ method_get_long: env
+ .get_static_method_id(Self::JVM_CLASS, "getLong", "(JJ)J")
+ .unwrap(),
+ method_get_long_ret: ReturnType::Primitive(Primitive::Long),
method_get_float: env
- .get_static_method_id(class, "getFloat", "(JJ)F")
+ .get_static_method_id(Self::JVM_CLASS, "getFloat", "(JJ)F")
.unwrap(),
- method_get_float_ret: JavaType::Primitive(Primitive::Float),
+ method_get_float_ret: ReturnType::Primitive(Primitive::Float),
method_get_double: env
- .get_static_method_id(class, "getDouble", "(JJ)D")
+ .get_static_method_id(Self::JVM_CLASS, "getDouble", "(JJ)D")
.unwrap(),
- method_get_double_ret: JavaType::Primitive(Primitive::Double),
+ method_get_double_ret: ReturnType::Primitive(Primitive::Double),
method_get_decimal: env
- .get_static_method_id(class, "getDecimal", "(JJ)[B")
+ .get_static_method_id(Self::JVM_CLASS, "getDecimal", "(JJ)[B")
.unwrap(),
- method_get_decimal_ret:
JavaType::Array(Box::new(JavaType::Primitive(Primitive::Byte))),
+ method_get_decimal_ret: ReturnType::Array,
method_get_string: env
- .get_static_method_id(class, "getString",
"(JJ)Ljava/lang/String;")
+ .get_static_method_id(Self::JVM_CLASS, "getString",
"(JJ)Ljava/lang/String;")
.unwrap(),
- method_get_string_ret:
JavaType::Object("java/lang/String".to_owned()),
+ method_get_string_ret: ReturnType::Object,
method_get_binary: env
- .get_static_method_id(class, "getBinary", "(JJ)[B")
+ .get_static_method_id(Self::JVM_CLASS, "getBinary", "(JJ)[B")
+ .unwrap(),
+ method_get_binary_ret: ReturnType::Array,
+ method_is_null: env
+ .get_static_method_id(Self::JVM_CLASS, "isNull", "(JJ)Z")
.unwrap(),
- method_get_binary_ret:
JavaType::Array(Box::new(JavaType::Primitive(Primitive::Byte))),
- method_is_null: env.get_static_method_id(class, "isNull",
"(JJ)Z").unwrap(),
- method_is_null_ret: JavaType::Primitive(Primitive::Boolean),
+ method_is_null_ret: ReturnType::Primitive(Primitive::Boolean),
+ class,
})
}
}
diff --git a/core/src/jvm_bridge/comet_metric_node.rs
b/core/src/jvm_bridge/comet_metric_node.rs
index 1d4928a..d0176f4 100644
--- a/core/src/jvm_bridge/comet_metric_node.rs
+++ b/core/src/jvm_bridge/comet_metric_node.rs
@@ -18,7 +18,7 @@
use jni::{
errors::Result as JniResult,
objects::{JClass, JMethodID},
- signature::{JavaType, Primitive},
+ signature::{Primitive, ReturnType},
JNIEnv,
};
@@ -27,33 +27,33 @@ use super::get_global_jclass;
/// A struct that holds all the JNI methods and fields for JVM CometMetricNode
class.
pub struct CometMetricNode<'a> {
pub class: JClass<'a>,
- pub method_get_child_node: JMethodID<'a>,
- pub method_get_child_node_ret: JavaType,
- pub method_add: JMethodID<'a>,
- pub method_add_ret: JavaType,
+ pub method_get_child_node: JMethodID,
+ pub method_get_child_node_ret: ReturnType,
+ pub method_add: JMethodID,
+ pub method_add_ret: ReturnType,
}
impl<'a> CometMetricNode<'a> {
pub const JVM_CLASS: &'static str =
"org/apache/spark/sql/comet/CometMetricNode";
- pub fn new(env: &JNIEnv<'a>) -> JniResult<CometMetricNode<'a>> {
+ pub fn new(env: &mut JNIEnv<'a>) -> JniResult<CometMetricNode<'a>> {
// Get the global class reference
let class = get_global_jclass(env, Self::JVM_CLASS)?;
Ok(CometMetricNode {
- class,
method_get_child_node: env
.get_method_id(
- class,
+ Self::JVM_CLASS,
"getChildNode",
format!("(I)L{:};", Self::JVM_CLASS).as_str(),
)
.unwrap(),
- method_get_child_node_ret:
JavaType::Object(Self::JVM_CLASS.to_owned()),
+ method_get_child_node_ret: ReturnType::Object,
method_add: env
- .get_method_id(class, "add", "(Ljava/lang/String;J)V")
+ .get_method_id(Self::JVM_CLASS, "add",
"(Ljava/lang/String;J)V")
.unwrap(),
- method_add_ret: JavaType::Primitive(Primitive::Void),
+ method_add_ret: ReturnType::Primitive(Primitive::Void),
+ class,
})
}
}
diff --git a/core/src/jvm_bridge/mod.rs b/core/src/jvm_bridge/mod.rs
index 6f162a0..331e776 100644
--- a/core/src/jvm_bridge/mod.rs
+++ b/core/src/jvm_bridge/mod.rs
@@ -19,7 +19,7 @@
use jni::{
errors::{Error, Result as JniResult},
- objects::{JClass, JObject, JString, JValue},
+ objects::{JClass, JObject, JString, JValueGen, JValueOwned},
AttachGuard, JNIEnv,
};
use once_cell::sync::OnceCell;
@@ -38,7 +38,7 @@ macro_rules! jni_map_error {
/// Macro for converting Rust types to JNI types.
macro_rules! jvalues {
($($args:expr,)* $(,)?) => {{
- &[$(jni::objects::JValue::from($args)),*] as &[jni::objects::JValue]
+ &[$(jni::objects::JValue::from($args).as_jni()),*] as
&[jni::sys::jvalue]
}}
}
@@ -75,7 +75,7 @@ macro_rules! jni_static_call {
$crate::jvm_bridge::jni_map_error!(
$env,
$env.call_static_method_unchecked(
- paste::paste!
{$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<class>]},
+ &paste::paste!
{$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<class>]},
paste::paste!
{$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method>]},
paste::paste!
{$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method
_ret>]}.clone(),
$crate::jvm_bridge::jvalues!($($args,)*)
@@ -114,23 +114,23 @@ impl<'a> BinaryWrapper<'a> {
}
}
-impl<'a> TryFrom<JValue<'a>> for StringWrapper<'a> {
+impl<'a> TryFrom<JValueOwned<'a>> for StringWrapper<'a> {
type Error = Error;
- fn try_from(value: JValue<'a>) -> Result<StringWrapper<'a>, Error> {
+ fn try_from(value: JValueOwned<'a>) -> Result<StringWrapper<'a>, Error> {
match value {
- JValue::Object(b) => Ok(StringWrapper::new(JString::from(b))),
+ JValueGen::Object(b) => Ok(StringWrapper::new(JString::from(b))),
_ => Err(Error::WrongJValueType("object", value.type_name())),
}
}
}
-impl<'a> TryFrom<JValue<'a>> for BinaryWrapper<'a> {
+impl<'a> TryFrom<JValueOwned<'a>> for BinaryWrapper<'a> {
type Error = Error;
- fn try_from(value: JValue<'a>) -> Result<BinaryWrapper<'a>, Error> {
+ fn try_from(value: JValueOwned<'a>) -> Result<BinaryWrapper<'a>, Error> {
match value {
- JValue::Object(b) => Ok(BinaryWrapper::new(b)),
+ JValueGen::Object(b) => Ok(BinaryWrapper::new(b)),
_ => Err(Error::WrongJValueType("object", value.type_name())),
}
}
@@ -151,7 +151,7 @@ pub(crate) use jni_static_call;
pub(crate) use jvalues;
/// Gets a global reference to a Java class.
-pub fn get_global_jclass(env: &JNIEnv<'_>, cls: &str) ->
JniResult<JClass<'static>> {
+pub fn get_global_jclass(env: &mut JNIEnv, cls: &str) ->
JniResult<JClass<'static>> {
let local_jclass = env.find_class(cls)?;
let global = env.new_global_ref::<JObject>(local_jclass.into())?;
@@ -186,11 +186,11 @@ static JVM_CLASSES: OnceCell<JVMClasses> =
OnceCell::new();
impl JVMClasses<'_> {
/// Creates a new JVMClasses struct.
- pub fn init(env: &JNIEnv) {
+ pub fn init(env: &mut JNIEnv) {
JVM_CLASSES.get_or_init(|| {
// A hack to make the `JNIEnv` static. It is not safe but we don't
really use the
// `JNIEnv` except for creating the global references of the
classes.
- let env = unsafe { std::mem::transmute::<_, &'static JNIEnv>(env)
};
+ let env = unsafe { std::mem::transmute::<_, &'static mut
JNIEnv>(env) };
JVMClasses {
comet_metric_node: CometMetricNode::new(env).unwrap(),
diff --git a/core/src/lib.rs b/core/src/lib.rs
index c85263f..d104788 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -45,7 +45,7 @@ use once_cell::sync::OnceCell;
pub use data_type::*;
-use crate::errors::{try_unwrap_or_throw, CometError, CometResult};
+use errors::{try_unwrap_or_throw, CometError, CometResult};
#[macro_use]
mod errors;
@@ -64,15 +64,15 @@ static JAVA_VM: OnceCell<JavaVM> = OnceCell::new();
#[no_mangle]
pub extern "system" fn Java_org_apache_comet_NativeBase_init(
- env: JNIEnv,
+ e: JNIEnv,
_: JClass,
log_conf_path: JString,
) {
// Initialize the error handling to capture panic backtraces
errors::init();
- try_unwrap_or_throw(env, || {
- let path: String = env.get_string(log_conf_path)?.into();
+ try_unwrap_or_throw(&e, |mut env| {
+ let path: String = env.get_string(&log_conf_path)?.into();
// empty path means there is no custom log4rs config file provided, so
fallback to use
// the default configuration
diff --git a/core/src/parquet/mod.rs b/core/src/parquet/mod.rs
index b1a7b93..4f87d15 100644
--- a/core/src/parquet/mod.rs
+++ b/core/src/parquet/mod.rs
@@ -41,7 +41,7 @@ use jni::{
use crate::execution::utils::SparkArrowConvert;
use arrow::buffer::{Buffer, MutableBuffer};
-use jni::objects::ReleaseMode;
+use jni::objects::{JBooleanArray, JLongArray, JPrimitiveArray, ReleaseMode};
use read::ColumnReader;
use util::jni::{convert_column_descriptor, convert_encoding};
@@ -58,7 +58,7 @@ struct Context {
#[no_mangle]
pub extern "system" fn Java_org_apache_comet_parquet_Native_initColumnReader(
- env: JNIEnv,
+ e: JNIEnv,
_jclass: JClass,
primitive_type: jint,
logical_type: jint,
@@ -78,9 +78,9 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_initColumnReader(
use_decimal_128: jboolean,
use_legacy_date_timestamp: jboolean,
) -> jlong {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&e, |mut env| {
let desc = convert_column_descriptor(
- &env,
+ &mut env,
primitive_type,
logical_type,
max_dl,
@@ -111,66 +111,74 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_initColumnReader(
})
}
+/// # Safety
+/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
#[no_mangle]
-pub extern "system" fn Java_org_apache_comet_parquet_Native_setDictionaryPage(
- env: JNIEnv,
+pub unsafe extern "system" fn
Java_org_apache_comet_parquet_Native_setDictionaryPage(
+ e: JNIEnv,
_jclass: JClass,
handle: jlong,
page_value_count: jint,
page_data: jbyteArray,
encoding: jint,
) {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&e, |env| {
let reader = get_reader(handle)?;
// convert value encoding ordinal to the native encoding definition
let encoding = convert_encoding(encoding);
// copy the input on-heap buffer to native
- let page_len = env.get_array_length(page_data)?;
+ let page_data_array = unsafe { JPrimitiveArray::from_raw(page_data) };
+ let page_len = env.get_array_length(&page_data_array)?;
let mut buffer = MutableBuffer::from_len_zeroed(page_len as usize);
- env.get_byte_array_region(page_data, 0,
from_u8_slice(buffer.as_slice_mut()))?;
+ env.get_byte_array_region(&page_data_array, 0,
from_u8_slice(buffer.as_slice_mut()))?;
reader.set_dictionary_page(page_value_count as usize, buffer.into(),
encoding);
Ok(())
})
}
+/// # Safety
+/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
#[no_mangle]
-pub extern "system" fn Java_org_apache_comet_parquet_Native_setPageV1(
- env: JNIEnv,
+pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setPageV1(
+ e: JNIEnv,
_jclass: JClass,
handle: jlong,
page_value_count: jint,
page_data: jbyteArray,
value_encoding: jint,
) {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&e, |env| {
let reader = get_reader(handle)?;
// convert value encoding ordinal to the native encoding definition
let encoding = convert_encoding(value_encoding);
// copy the input on-heap buffer to native
- let page_len = env.get_array_length(page_data)?;
+ let page_data_array = unsafe { JPrimitiveArray::from_raw(page_data) };
+ let page_len = env.get_array_length(&page_data_array)?;
let mut buffer = MutableBuffer::from_len_zeroed(page_len as usize);
- env.get_byte_array_region(page_data, 0,
from_u8_slice(buffer.as_slice_mut()))?;
+ env.get_byte_array_region(&page_data_array, 0,
from_u8_slice(buffer.as_slice_mut()))?;
reader.set_page_v1(page_value_count as usize, buffer.into(), encoding);
Ok(())
})
}
+/// # Safety
+/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
#[no_mangle]
-pub extern "system" fn Java_org_apache_comet_parquet_Native_setPageBufferV1(
- env: JNIEnv,
+pub unsafe extern "system" fn
Java_org_apache_comet_parquet_Native_setPageBufferV1(
+ e: JNIEnv,
_jclass: JClass,
handle: jlong,
page_value_count: jint,
buffer: jobject,
value_encoding: jint,
) {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&e, |env| {
let ctx = get_context(handle)?;
let reader = &mut ctx.column_reader;
@@ -178,19 +186,20 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_setPageBufferV1(
let encoding = convert_encoding(value_encoding);
// Get slices from Java DirectByteBuffer
- let jbuffer = JByteBuffer::from(buffer);
+ let jbuffer = unsafe { JByteBuffer::from_raw(buffer) };
// Convert the page to global reference so it won't get GC'd by Java.
Also free the last
// page if there is any.
- ctx.last_data_page = Some(env.new_global_ref(jbuffer)?);
+ ctx.last_data_page = Some(env.new_global_ref(&jbuffer)?);
- let buf_slice = env.get_direct_buffer_address(jbuffer)?;
+ let buf_slice = env.get_direct_buffer_address(&jbuffer)?;
+ let buf_capacity = env.get_direct_buffer_capacity(&jbuffer)?;
unsafe {
- let page_ptr = NonNull::new_unchecked(buf_slice.as_ptr() as *mut
u8);
+ let page_ptr = NonNull::new_unchecked(buf_slice);
let buffer = Buffer::from_custom_allocation(
page_ptr,
- buf_slice.len(),
+ buf_capacity,
Arc::new(FFI_ArrowArray::empty()),
);
reader.set_page_v1(page_value_count as usize, buffer, encoding);
@@ -199,9 +208,11 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_setPageBufferV1(
})
}
+/// # Safety
+/// This function is inherently unsafe since it deals with raw pointers passed
from JNI.
#[no_mangle]
-pub extern "system" fn Java_org_apache_comet_parquet_Native_setPageV2(
- env: JNIEnv,
+pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setPageV2(
+ e: JNIEnv,
_jclass: JClass,
handle: jlong,
page_value_count: jint,
@@ -210,24 +221,27 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_setPageV2(
value_data: jbyteArray,
value_encoding: jint,
) {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&e, |env| {
let reader = get_reader(handle)?;
// convert value encoding ordinal to the native encoding definition
let encoding = convert_encoding(value_encoding);
// copy the input on-heap buffer to native
- let dl_len = env.get_array_length(def_level_data)?;
+ let def_level_array = unsafe {
JPrimitiveArray::from_raw(def_level_data) };
+ let dl_len = env.get_array_length(&def_level_array)?;
let mut dl_buffer = MutableBuffer::from_len_zeroed(dl_len as usize);
- env.get_byte_array_region(def_level_data, 0,
from_u8_slice(dl_buffer.as_slice_mut()))?;
+ env.get_byte_array_region(&def_level_array, 0,
from_u8_slice(dl_buffer.as_slice_mut()))?;
- let rl_len = env.get_array_length(rep_level_data)?;
+ let rep_level_array = unsafe {
JPrimitiveArray::from_raw(rep_level_data) };
+ let rl_len = env.get_array_length(&rep_level_array)?;
let mut rl_buffer = MutableBuffer::from_len_zeroed(rl_len as usize);
- env.get_byte_array_region(rep_level_data, 0,
from_u8_slice(rl_buffer.as_slice_mut()))?;
+ env.get_byte_array_region(&rep_level_array, 0,
from_u8_slice(rl_buffer.as_slice_mut()))?;
- let v_len = env.get_array_length(value_data)?;
+ let value_array = unsafe { JPrimitiveArray::from_raw(value_data) };
+ let v_len = env.get_array_length(&value_array)?;
let mut v_buffer = MutableBuffer::from_len_zeroed(v_len as usize);
- env.get_byte_array_region(value_data, 0,
from_u8_slice(v_buffer.as_slice_mut()))?;
+ env.get_byte_array_region(&value_array, 0,
from_u8_slice(v_buffer.as_slice_mut()))?;
reader.set_page_v2(
page_value_count as usize,
@@ -246,7 +260,7 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_setNull(
_jclass: JClass,
handle: jlong,
) {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&env, |_| {
let reader = get_reader(handle)?;
reader.set_null();
Ok(())
@@ -260,7 +274,7 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_setBoolean(
handle: jlong,
value: jboolean,
) {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&env, |_| {
let reader = get_reader(handle)?;
reader.set_boolean(value != 0);
Ok(())
@@ -274,7 +288,7 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_setByte(
handle: jlong,
value: jbyte,
) {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&env, |_| {
let reader = get_reader(handle)?;
reader.set_fixed::<i8>(value);
Ok(())
@@ -288,7 +302,7 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_setShort(
handle: jlong,
value: jshort,
) {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&env, |_| {
let reader = get_reader(handle)?;
reader.set_fixed::<i16>(value);
Ok(())
@@ -302,7 +316,7 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_setInt(
handle: jlong,
value: jint,
) {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&env, |_| {
let reader = get_reader(handle)?;
reader.set_fixed::<i32>(value);
Ok(())
@@ -316,7 +330,7 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_setLong(
handle: jlong,
value: jlong,
) {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&env, |_| {
let reader = get_reader(handle)?;
reader.set_fixed::<i64>(value);
Ok(())
@@ -330,7 +344,7 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_setFloat(
handle: jlong,
value: jfloat,
) {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&env, |_| {
let reader = get_reader(handle)?;
reader.set_fixed::<f32>(value);
Ok(())
@@ -344,44 +358,50 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_setDouble(
handle: jlong,
value: jdouble,
) {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&env, |_| {
let reader = get_reader(handle)?;
reader.set_fixed::<f64>(value);
Ok(())
})
}
+/// # Safety
+/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
#[no_mangle]
-pub extern "system" fn Java_org_apache_comet_parquet_Native_setBinary(
- env: JNIEnv,
+pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setBinary(
+ e: JNIEnv,
_jclass: JClass,
handle: jlong,
value: jbyteArray,
) {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&e, |env| {
let reader = get_reader(handle)?;
- let len = env.get_array_length(value)?;
+ let value_array = unsafe { JPrimitiveArray::from_raw(value) };
+ let len = env.get_array_length(&value_array)?;
let mut buffer = MutableBuffer::from_len_zeroed(len as usize);
- env.get_byte_array_region(value, 0,
from_u8_slice(buffer.as_slice_mut()))?;
+ env.get_byte_array_region(&value_array, 0,
from_u8_slice(buffer.as_slice_mut()))?;
reader.set_binary(buffer);
Ok(())
})
}
+/// # Safety
+/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
#[no_mangle]
-pub extern "system" fn Java_org_apache_comet_parquet_Native_setDecimal(
- env: JNIEnv,
+pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setDecimal(
+ e: JNIEnv,
_jclass: JClass,
handle: jlong,
value: jbyteArray,
) {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&e, |env| {
let reader = get_reader(handle)?;
- let len = env.get_array_length(value)?;
+ let value_array = unsafe { JPrimitiveArray::from_raw(value) };
+ let len = env.get_array_length(&value_array)?;
let mut buffer = MutableBuffer::from_len_zeroed(len as usize);
- env.get_byte_array_region(value, 0,
from_u8_slice(buffer.as_slice_mut()))?;
+ env.get_byte_array_region(&value_array, 0,
from_u8_slice(buffer.as_slice_mut()))?;
reader.set_decimal_flba(buffer);
Ok(())
})
@@ -395,26 +415,29 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_setPosition(
value: jlong,
size: jint,
) {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&env, |_| {
let reader = get_reader(handle)?;
reader.set_position(value, size as usize);
Ok(())
})
}
+/// # Safety
+/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
#[no_mangle]
-pub extern "system" fn Java_org_apache_comet_parquet_Native_setIndices(
- env: JNIEnv,
+pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setIndices(
+ e: JNIEnv,
_jclass: JClass,
handle: jlong,
offset: jlong,
batch_size: jint,
indices: jlongArray,
) -> jlong {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&e, |mut env| {
let reader = get_reader(handle)?;
- let indices = env.get_long_array_elements(indices,
ReleaseMode::NoCopyBack)?;
- let len = indices.size()? as usize;
+ let indice_array = unsafe { JLongArray::from_raw(indices) };
+ let indices = unsafe { env.get_array_elements(&indice_array,
ReleaseMode::NoCopyBack)? };
+ let len = indices.len();
// paris alternately contains start index and length of continuous
indices
let pairs = unsafe { core::slice::from_raw_parts_mut(indices.as_ptr(),
len) };
let mut skipped = 0;
@@ -437,19 +460,22 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_setIndices(
})
}
+/// # Safety
+/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
#[no_mangle]
-pub extern "system" fn Java_org_apache_comet_parquet_Native_setIsDeleted(
- env: JNIEnv,
+pub unsafe extern "system" fn
Java_org_apache_comet_parquet_Native_setIsDeleted(
+ e: JNIEnv,
_jclass: JClass,
handle: jlong,
is_deleted: jbooleanArray,
) {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&e, |env| {
let reader = get_reader(handle)?;
- let len = env.get_array_length(is_deleted)?;
+ let is_deleted_array = unsafe { JBooleanArray::from_raw(is_deleted) };
+ let len = env.get_array_length(&is_deleted_array)?;
let mut buffer = MutableBuffer::from_len_zeroed(len as usize);
- env.get_boolean_array_region(is_deleted, 0, buffer.as_slice_mut())?;
+ env.get_boolean_array_region(&is_deleted_array, 0,
buffer.as_slice_mut())?;
reader.set_is_deleted(buffer);
Ok(())
})
@@ -461,7 +487,7 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_resetBatch(
_jclass: JClass,
handle: jlong,
) {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&env, |_| {
let reader = get_reader(handle)?;
reader.reset_batch();
Ok(())
@@ -470,20 +496,20 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_resetBatch(
#[no_mangle]
pub extern "system" fn Java_org_apache_comet_parquet_Native_readBatch(
- env: JNIEnv,
+ e: JNIEnv,
_jclass: JClass,
handle: jlong,
batch_size: jint,
null_pad_size: jint,
) -> jintArray {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&e, |env| {
let reader = get_reader(handle)?;
let (num_values, num_nulls) =
reader.read_batch(batch_size as usize, null_pad_size as usize);
let res = env.new_int_array(2)?;
let buf: [i32; 2] = [num_values as i32, num_nulls as i32];
- env.set_int_array_region(res, 0, &buf)?;
- Ok(res)
+ env.set_int_array_region(&res, 0, &buf)?;
+ Ok(res.into_raw())
})
}
@@ -495,7 +521,7 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_skipBatch(
batch_size: jint,
discard: jboolean,
) -> jint {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&env, |_| {
let reader = get_reader(handle)?;
Ok(reader.skip_batch(batch_size as usize, discard == 0) as jint)
})
@@ -503,11 +529,11 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_skipBatch(
#[no_mangle]
pub extern "system" fn Java_org_apache_comet_parquet_Native_currentBatch(
- env: JNIEnv,
+ e: JNIEnv,
_jclass: JClass,
handle: jlong,
) -> jlongArray {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&e, |env| {
let ctx = get_context(handle)?;
let reader = &mut ctx.column_reader;
let data = reader.current_batch();
@@ -520,9 +546,9 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_currentBatch(
let res = env.new_long_array(2)?;
let buf: [i64; 2] = [array, schema];
- env.set_long_array_region(res, 0, &buf)
+ env.set_long_array_region(&res, 0, &buf)
.expect("set long array region failed");
- Ok(res)
+ Ok(res.into_raw())
}
})
}
@@ -547,7 +573,7 @@ pub extern "system" fn
Java_org_apache_comet_parquet_Native_closeColumnReader(
_jclass: JClass,
handle: jlong,
) {
- try_unwrap_or_throw(env, || {
+ try_unwrap_or_throw(&env, |_| {
unsafe {
let ctx = handle as *mut Context;
let _ = Box::from_raw(ctx);
diff --git a/core/src/parquet/util/jni.rs b/core/src/parquet/util/jni.rs
index 000eeee..225abfc 100644
--- a/core/src/parquet/util/jni.rs
+++ b/core/src/parquet/util/jni.rs
@@ -19,8 +19,8 @@ use std::sync::Arc;
use jni::{
errors::Result as JNIResult,
- objects::{JMethodID, JString},
- sys::{jboolean, jint, jobjectArray, jstring},
+ objects::{JObjectArray, JString},
+ sys::{jboolean, jint, jobjectArray},
JNIEnv,
};
@@ -33,7 +33,7 @@ use parquet::{
/// Convert primitives from Spark side into a `ColumnDescriptor`.
#[allow(clippy::too_many_arguments)]
pub fn convert_column_descriptor(
- env: &JNIEnv,
+ env: &mut JNIEnv,
physical_type_id: jint,
logical_type_id: jint,
max_dl: jint,
@@ -114,12 +114,13 @@ impl TypePromotionInfo {
}
}
-fn convert_column_path(env: &JNIEnv, path: jobjectArray) ->
JNIResult<ColumnPath> {
- let array_len = env.get_array_length(path)?;
+fn convert_column_path(env: &mut JNIEnv, path: jobjectArray) ->
JNIResult<ColumnPath> {
+ let path_array = unsafe { JObjectArray::from_raw(path) };
+ let array_len = env.get_array_length(&path_array)?;
let mut res: Vec<String> = Vec::new();
for i in 0..array_len {
- let p: JString = (env.get_object_array_element(path, i)?.into_inner()
as jstring).into();
- res.push(env.get_string(p)?.into());
+ let p: JString = env.get_object_array_element(&path_array, i)?.into();
+ res.push(env.get_string(&p)?.into());
}
Ok(ColumnPath::new(res))
}
@@ -184,16 +185,3 @@ fn fix_type_length(t: &PhysicalType, type_length: i32) ->
i32 {
_ => type_length,
}
}
-
-fn get_method_id<'a>(env: &'a JNIEnv, class: &'a str, method: &str, sig: &str)
-> JMethodID<'a> {
- // first verify the class exists
- let _ = env
- .find_class(class)
- .unwrap_or_else(|_| panic!("Class '{}' not found", class));
- env.get_method_id(class, method, sig).unwrap_or_else(|_| {
- panic!(
- "Method '{}' with signature '{}' of class '{}' not found",
- method, sig, class
- )
- })
-}
diff --git a/core/src/parquet/util/mod.rs b/core/src/parquet/util/mod.rs
index 6a8c731..7a37b78 100644
--- a/core/src/parquet/util/mod.rs
+++ b/core/src/parquet/util/mod.rs
@@ -22,7 +22,5 @@ pub mod memory;
mod buffer;
pub use buffer::*;
-mod jni_buffer;
-pub use jni_buffer::*;
pub mod test_common;