This is an automated email from the ASF dual-hosted git repository.
sunchao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 180f962 feat: Handle exception thrown from native side (#61)
180f962 is described below
commit 180f962c17c4a3362e3a4b25e0d3462aba4cab21
Author: Chao Sun <[email protected]>
AuthorDate: Tue Feb 20 12:46:15 2024 -0800
feat: Handle exception thrown from native side (#61)
This PR catches exceptions thrown when the native side calls Java
methods, and converts them into a `CometError::JavaException`, which can then
be properly propagated to the JVM.
---
core/src/errors.rs | 3 +
.../execution/datafusion/expressions/subquery.rs | 26 ++--
core/src/jvm_bridge/mod.rs | 157 ++++++++++++++++++---
3 files changed, 153 insertions(+), 33 deletions(-)
diff --git a/core/src/errors.rs b/core/src/errors.rs
index 7188ebd..936d97d 100644
--- a/core/src/errors.rs
+++ b/core/src/errors.rs
@@ -122,6 +122,9 @@ pub enum CometError {
#[from]
source: DataFusionError,
},
+
+ #[error("{class}: {msg}")]
+ JavaException { class: String, msg: String },
}
pub fn init() {
diff --git a/core/src/execution/datafusion/expressions/subquery.rs
b/core/src/execution/datafusion/expressions/subquery.rs
index a4b32ba..7cae129 100644
--- a/core/src/execution/datafusion/expressions/subquery.rs
+++ b/core/src/execution/datafusion/expressions/subquery.rs
@@ -93,7 +93,7 @@ impl PhysicalExpr for Subquery {
let mut env = JVMClasses::get_env();
unsafe {
- let is_null = jni_static_call!(env,
+ let is_null = jni_static_call!(&mut env,
comet_exec.is_null(self.exec_context_id, self.id) -> jboolean
)?;
@@ -105,50 +105,50 @@ impl PhysicalExpr for Subquery {
match &self.data_type {
DataType::Boolean => {
- let r = jni_static_call!(env,
+ let r = jni_static_call!(&mut env,
comet_exec.get_bool(self.exec_context_id, self.id) ->
jboolean
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(r >
0))))
}
DataType::Int8 => {
- let r = jni_static_call!(env,
+ let r = jni_static_call!(&mut env,
comet_exec.get_byte(self.exec_context_id, self.id) ->
jbyte
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(r))))
}
DataType::Int16 => {
- let r = jni_static_call!(env,
+ let r = jni_static_call!(&mut env,
comet_exec.get_short(self.exec_context_id, self.id) ->
jshort
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(r))))
}
DataType::Int32 => {
- let r = jni_static_call!(env,
+ let r = jni_static_call!(&mut env,
comet_exec.get_int(self.exec_context_id, self.id) ->
jint
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(r))))
}
DataType::Int64 => {
- let r = jni_static_call!(env,
+ let r = jni_static_call!(&mut env,
comet_exec.get_long(self.exec_context_id, self.id) ->
jlong
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(r))))
}
DataType::Float32 => {
- let r = jni_static_call!(env,
+ let r = jni_static_call!(&mut env,
comet_exec.get_float(self.exec_context_id, self.id) ->
f32
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(r))))
}
DataType::Float64 => {
- let r = jni_static_call!(env,
+ let r = jni_static_call!(&mut env,
comet_exec.get_double(self.exec_context_id, self.id)
-> f64
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(r))))
}
DataType::Decimal128(p, s) => {
- let bytes = jni_static_call!(env,
+ let bytes = jni_static_call!(&mut env,
comet_exec.get_decimal(self.exec_context_id, self.id)
-> BinaryWrapper
)?;
let bytes: &JByteArray = bytes.get().into();
@@ -161,14 +161,14 @@ impl PhysicalExpr for Subquery {
)))
}
DataType::Date32 => {
- let r = jni_static_call!(env,
+ let r = jni_static_call!(&mut env,
comet_exec.get_int(self.exec_context_id, self.id) ->
jint
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Date32(Some(r))))
}
DataType::Timestamp(TimeUnit::Microsecond, timezone) => {
- let r = jni_static_call!(env,
+ let r = jni_static_call!(&mut env,
comet_exec.get_long(self.exec_context_id, self.id) ->
jlong
)?;
@@ -178,7 +178,7 @@ impl PhysicalExpr for Subquery {
)))
}
DataType::Utf8 => {
- let string = jni_static_call!(env,
+ let string = jni_static_call!(&mut env,
comet_exec.get_string(self.exec_context_id, self.id)
-> StringWrapper
)?;
@@ -186,7 +186,7 @@ impl PhysicalExpr for Subquery {
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(string))))
}
DataType::Binary => {
- let bytes = jni_static_call!(env,
+ let bytes = jni_static_call!(&mut env,
comet_exec.get_binary(self.exec_context_id, self.id)
-> BinaryWrapper
)?;
let bytes: &JByteArray = bytes.get().into();
diff --git a/core/src/jvm_bridge/mod.rs b/core/src/jvm_bridge/mod.rs
index 331e776..d3db7ba 100644
--- a/core/src/jvm_bridge/mod.rs
+++ b/core/src/jvm_bridge/mod.rs
@@ -19,7 +19,8 @@
use jni::{
errors::{Error, Result as JniResult},
- objects::{JClass, JObject, JString, JValueGen, JValueOwned},
+ objects::{JClass, JMethodID, JObject, JString, JThrowable, JValueGen,
JValueOwned},
+ signature::ReturnType,
AttachGuard, JNIEnv,
};
use once_cell::sync::OnceCell;
@@ -58,29 +59,52 @@ macro_rules! jni_new_string {
/// jname and value are the arguments.
macro_rules! jni_call {
($env:expr, $clsname:ident($obj:expr).$method:ident($($args:expr),* $(,)?)
-> $ret:ty) => {{
- $crate::jvm_bridge::jni_map_error!(
- $env,
- $env.call_method_unchecked(
- $obj,
- paste::paste!
{$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method>]},
- paste::paste!
{$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method
_ret>]}.clone(),
- $crate::jvm_bridge::jvalues!($($args,)*)
- )
- ).and_then(|result| $crate::jvm_bridge::jni_map_error!($env,
<$ret>::try_from(result)))
+ let method_id = paste::paste! {
+ $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_
$method>]
+ };
+ let ret_type = paste::paste! {
+ $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_
$method _ret>]
+ }.clone();
+ let args = $crate::jvm_bridge::jvalues!($($args,)*);
+
+ // Call the JVM method and obtain the returned value
+ let ret = $env.call_method_unchecked($obj, method_id, ret_type, args);
+
+ // Check if JVM has thrown any exception, and handle it if so.
+ let result = if let Some(exception) =
$crate::jvm_bridge::check_exception($env).unwrap() {
+ Err(exception.into())
+ } else {
+ $crate::jvm_bridge::jni_map_error!($env, ret)
+ };
+
+ result.and_then(|result| $crate::jvm_bridge::jni_map_error!($env,
<$ret>::try_from(result)))
}}
}
macro_rules! jni_static_call {
($env:expr, $clsname:ident.$method:ident($($args:expr),* $(,)?) ->
$ret:ty) => {{
- $crate::jvm_bridge::jni_map_error!(
- $env,
- $env.call_static_method_unchecked(
- &paste::paste!
{$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<class>]},
- paste::paste!
{$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method>]},
- paste::paste!
{$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method
_ret>]}.clone(),
- $crate::jvm_bridge::jvalues!($($args,)*)
- )
- ).and_then(|result| $crate::jvm_bridge::jni_map_error!($env,
<$ret>::try_from(result)))
+ let clazz = &paste::paste! {
+ $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<class>]
+ };
+ let method_id = paste::paste! {
+ $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_
$method>]
+ };
+ let ret_type = paste::paste! {
+ $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_
$method _ret>]
+ }.clone();
+ let args = $crate::jvm_bridge::jvalues!($($args,)*);
+
+ // Call the JVM static method and obtain the returned value
+ let ret = $env.call_static_method_unchecked(clazz, method_id,
ret_type, args);
+
+ // Check if JVM has thrown any exception, and handle it if so.
+ let result = if let Some(exception) =
$crate::jvm_bridge::check_exception($env).unwrap() {
+ Err(exception.into())
+ } else {
+ $crate::jvm_bridge::jni_map_error!($env, ret)
+ };
+
+ result.and_then(|result| $crate::jvm_bridge::jni_map_error!($env,
<$ret>::try_from(result)))
}}
}
@@ -167,11 +191,21 @@ pub fn get_global_jclass(env: &mut JNIEnv, cls: &str) ->
JniResult<JClass<'stati
mod comet_exec;
pub use comet_exec::*;
mod comet_metric_node;
-use crate::JAVA_VM;
+use crate::{
+ errors::{CometError, CometResult},
+ JAVA_VM,
+};
pub use comet_metric_node::*;
/// The JVM classes that are used in the JNI calls.
pub struct JVMClasses<'a> {
+ /// Cached method ID for "java.lang.Object#getClass"
+ pub object_get_class_method: JMethodID,
+ /// Cached method ID for "java.lang.Class#getName"
+ pub class_get_name_method: JMethodID,
+ /// Cached method ID for "java.lang.Throwable#getMessage"
+ pub throwable_get_message_method: JMethodID,
+
/// The CometMetricNode class. Used for updating the metrics.
pub comet_metric_node: CometMetricNode<'a>,
/// The static CometExec class. Used for getting the subquery result.
@@ -192,7 +226,25 @@ impl JVMClasses<'_> {
// `JNIEnv` except for creating the global references of the
classes.
let env = unsafe { std::mem::transmute::<_, &'static mut
JNIEnv>(env) };
+ let clazz = env.find_class("java/lang/Object").unwrap();
+ let object_get_class_method = env
+ .get_method_id(clazz, "getClass", "()Ljava/lang/Class;")
+ .unwrap();
+
+ let clazz = env.find_class("java/lang/Class").unwrap();
+ let class_get_name_method = env
+ .get_method_id(clazz, "getName", "()Ljava/lang/String;")
+ .unwrap();
+
+ let clazz = env.find_class("java/lang/Throwable").unwrap();
+ let throwable_get_message_method = env
+ .get_method_id(clazz, "getMessage", "()Ljava/lang/String;")
+ .unwrap();
+
JVMClasses {
+ object_get_class_method,
+ class_get_name_method,
+ throwable_get_message_method,
comet_metric_node: CometMetricNode::new(env).unwrap(),
comet_exec: CometExec::new(env).unwrap(),
}
@@ -211,3 +263,68 @@ impl JVMClasses<'_> {
}
}
}
+
+pub(crate) fn check_exception(env: &mut JNIEnv) ->
CometResult<Option<CometError>> {
+ let result = if env.exception_check()? {
+ let exception = env.exception_occurred()?;
+ env.exception_clear()?;
+ let exception_err = convert_exception(env, &exception)?;
+ Some(exception_err)
+ } else {
+ None
+ };
+
+ Ok(result)
+}
+
+/// Given a `JThrowable` which is thrown from calling a Java method on the
native side,
+/// this converts it into a `CometError::JavaException` with the exception
class name
+/// and exception message. This error can then be propagated to the JVM side to
let
+/// users know the cause of the native side error.
+pub(crate) fn convert_exception(
+ env: &mut JNIEnv,
+ throwable: &JThrowable,
+) -> CometResult<CometError> {
+ unsafe {
+ let cache = JVMClasses::get();
+
+ // get the class name of the exception by:
+ // 1. get the `Class` object of the input `throwable` via
`Object#getClass` method
+ // 2. get the exception class name via calling `Class#getName` on the
above object
+ let class_obj = env
+ .call_method_unchecked(
+ throwable,
+ cache.object_get_class_method,
+ ReturnType::Object,
+ &[],
+ )?
+ .l()?;
+ let exception_class_name = env
+ .call_method_unchecked(
+ class_obj,
+ cache.class_get_name_method,
+ ReturnType::Object,
+ &[],
+ )?
+ .l()?
+ .into();
+ let exception_class_name_str =
env.get_string(&exception_class_name)?.into();
+
+ // get the exception message via calling `Throwable#getMessage` on the
throwable object
+ let message = env
+ .call_method_unchecked(
+ throwable,
+ cache.throwable_get_message_method,
+ ReturnType::Object,
+ &[],
+ )?
+ .l()?
+ .into();
+ let message_str = env.get_string(&message)?.into();
+
+ Ok(CometError::JavaException {
+ class: exception_class_name_str,
+ msg: message_str,
+ })
+ }
+}