This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 180f962  feat: Handle exception thrown from native side (#61)
180f962 is described below

commit 180f962c17c4a3362e3a4b25e0d3462aba4cab21
Author: Chao Sun <[email protected]>
AuthorDate: Tue Feb 20 12:46:15 2024 -0800

    feat: Handle exception thrown from native side (#61)
    
    This PR catches exceptions thrown by Java methods that are called from the
native side, and converts them into a `CometError::JavaException`, which can
then be properly propagated to the JVM.
---
 core/src/errors.rs                                 |   3 +
 .../execution/datafusion/expressions/subquery.rs   |  26 ++--
 core/src/jvm_bridge/mod.rs                         | 157 ++++++++++++++++++---
 3 files changed, 153 insertions(+), 33 deletions(-)

diff --git a/core/src/errors.rs b/core/src/errors.rs
index 7188ebd..936d97d 100644
--- a/core/src/errors.rs
+++ b/core/src/errors.rs
@@ -122,6 +122,9 @@ pub enum CometError {
         #[from]
         source: DataFusionError,
     },
+
+    #[error("{class}: {msg}")]
+    JavaException { class: String, msg: String },
 }
 
 pub fn init() {
diff --git a/core/src/execution/datafusion/expressions/subquery.rs 
b/core/src/execution/datafusion/expressions/subquery.rs
index a4b32ba..7cae129 100644
--- a/core/src/execution/datafusion/expressions/subquery.rs
+++ b/core/src/execution/datafusion/expressions/subquery.rs
@@ -93,7 +93,7 @@ impl PhysicalExpr for Subquery {
         let mut env = JVMClasses::get_env();
 
         unsafe {
-            let is_null = jni_static_call!(env,
+            let is_null = jni_static_call!(&mut env,
                 comet_exec.is_null(self.exec_context_id, self.id) -> jboolean
             )?;
 
@@ -105,50 +105,50 @@ impl PhysicalExpr for Subquery {
 
             match &self.data_type {
                 DataType::Boolean => {
-                    let r = jni_static_call!(env,
+                    let r = jni_static_call!(&mut env,
                         comet_exec.get_bool(self.exec_context_id, self.id) -> 
jboolean
                     )?;
                     Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(r > 
0))))
                 }
                 DataType::Int8 => {
-                    let r = jni_static_call!(env,
+                    let r = jni_static_call!(&mut env,
                         comet_exec.get_byte(self.exec_context_id, self.id) -> 
jbyte
                     )?;
                     Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(r))))
                 }
                 DataType::Int16 => {
-                    let r = jni_static_call!(env,
+                    let r = jni_static_call!(&mut env,
                         comet_exec.get_short(self.exec_context_id, self.id) -> 
jshort
                     )?;
                     Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(r))))
                 }
                 DataType::Int32 => {
-                    let r = jni_static_call!(env,
+                    let r = jni_static_call!(&mut env,
                         comet_exec.get_int(self.exec_context_id, self.id) -> 
jint
                     )?;
                     Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(r))))
                 }
                 DataType::Int64 => {
-                    let r = jni_static_call!(env,
+                    let r = jni_static_call!(&mut env,
                         comet_exec.get_long(self.exec_context_id, self.id) -> 
jlong
                     )?;
                     Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(r))))
                 }
                 DataType::Float32 => {
-                    let r = jni_static_call!(env,
+                    let r = jni_static_call!(&mut env,
                         comet_exec.get_float(self.exec_context_id, self.id) -> 
f32
                     )?;
                     Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(r))))
                 }
                 DataType::Float64 => {
-                    let r = jni_static_call!(env,
+                    let r = jni_static_call!(&mut env,
                         comet_exec.get_double(self.exec_context_id, self.id) 
-> f64
                     )?;
 
                     Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(r))))
                 }
                 DataType::Decimal128(p, s) => {
-                    let bytes = jni_static_call!(env,
+                    let bytes = jni_static_call!(&mut env,
                         comet_exec.get_decimal(self.exec_context_id, self.id) 
-> BinaryWrapper
                     )?;
                     let bytes: &JByteArray = bytes.get().into();
@@ -161,14 +161,14 @@ impl PhysicalExpr for Subquery {
                     )))
                 }
                 DataType::Date32 => {
-                    let r = jni_static_call!(env,
+                    let r = jni_static_call!(&mut env,
                         comet_exec.get_int(self.exec_context_id, self.id) -> 
jint
                     )?;
 
                     Ok(ColumnarValue::Scalar(ScalarValue::Date32(Some(r))))
                 }
                 DataType::Timestamp(TimeUnit::Microsecond, timezone) => {
-                    let r = jni_static_call!(env,
+                    let r = jni_static_call!(&mut env,
                         comet_exec.get_long(self.exec_context_id, self.id) -> 
jlong
                     )?;
 
@@ -178,7 +178,7 @@ impl PhysicalExpr for Subquery {
                     )))
                 }
                 DataType::Utf8 => {
-                    let string = jni_static_call!(env,
+                    let string = jni_static_call!(&mut env,
                         comet_exec.get_string(self.exec_context_id, self.id) 
-> StringWrapper
                     )?;
 
@@ -186,7 +186,7 @@ impl PhysicalExpr for Subquery {
                     Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(string))))
                 }
                 DataType::Binary => {
-                    let bytes = jni_static_call!(env,
+                    let bytes = jni_static_call!(&mut env,
                         comet_exec.get_binary(self.exec_context_id, self.id) 
-> BinaryWrapper
                     )?;
                     let bytes: &JByteArray = bytes.get().into();
diff --git a/core/src/jvm_bridge/mod.rs b/core/src/jvm_bridge/mod.rs
index 331e776..d3db7ba 100644
--- a/core/src/jvm_bridge/mod.rs
+++ b/core/src/jvm_bridge/mod.rs
@@ -19,7 +19,8 @@
 
 use jni::{
     errors::{Error, Result as JniResult},
-    objects::{JClass, JObject, JString, JValueGen, JValueOwned},
+    objects::{JClass, JMethodID, JObject, JString, JThrowable, JValueGen, 
JValueOwned},
+    signature::ReturnType,
     AttachGuard, JNIEnv,
 };
 use once_cell::sync::OnceCell;
@@ -58,29 +59,52 @@ macro_rules! jni_new_string {
 /// jname and value are the arguments.
 macro_rules! jni_call {
     ($env:expr, $clsname:ident($obj:expr).$method:ident($($args:expr),* $(,)?) 
-> $ret:ty) => {{
-        $crate::jvm_bridge::jni_map_error!(
-            $env,
-            $env.call_method_unchecked(
-                $obj,
-                paste::paste! 
{$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method>]},
-                paste::paste! 
{$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method 
_ret>]}.clone(),
-                $crate::jvm_bridge::jvalues!($($args,)*)
-            )
-        ).and_then(|result| $crate::jvm_bridge::jni_map_error!($env, 
<$ret>::try_from(result)))
+        let method_id = paste::paste! {
+            $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ 
$method>]
+        };
+        let ret_type = paste::paste! {
+            $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ 
$method _ret>]
+        }.clone();
+        let args = $crate::jvm_bridge::jvalues!($($args,)*);
+
+        // Call the JVM method and obtain the returned value
+        let ret = $env.call_method_unchecked($obj, method_id, ret_type, args);
+
+        // Check if JVM has thrown any exception, and handle it if so.
+        let result = if let Some(exception) = 
$crate::jvm_bridge::check_exception($env).unwrap() {
+            Err(exception.into())
+        } else {
+            $crate::jvm_bridge::jni_map_error!($env, ret)
+        };
+
+        result.and_then(|result| $crate::jvm_bridge::jni_map_error!($env, 
<$ret>::try_from(result)))
     }}
 }
 
 macro_rules! jni_static_call {
     ($env:expr, $clsname:ident.$method:ident($($args:expr),* $(,)?) -> 
$ret:ty) => {{
-        $crate::jvm_bridge::jni_map_error!(
-            $env,
-            $env.call_static_method_unchecked(
-                &paste::paste! 
{$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<class>]},
-                paste::paste! 
{$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method>]},
-                paste::paste! 
{$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method 
_ret>]}.clone(),
-                $crate::jvm_bridge::jvalues!($($args,)*)
-            )
-        ).and_then(|result| $crate::jvm_bridge::jni_map_error!($env, 
<$ret>::try_from(result)))
+        let clazz = &paste::paste! {
+            $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<class>]
+        };
+        let method_id = paste::paste! {
+            $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ 
$method>]
+        };
+        let ret_type = paste::paste! {
+            $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ 
$method _ret>]
+        }.clone();
+        let args = $crate::jvm_bridge::jvalues!($($args,)*);
+
+        // Call the JVM static method and obtain the returned value
+        let ret = $env.call_static_method_unchecked(clazz, method_id, 
ret_type, args);
+
+        // Check if JVM has thrown any exception, and handle it if so.
+        let result = if let Some(exception) = 
$crate::jvm_bridge::check_exception($env).unwrap() {
+            Err(exception.into())
+        } else {
+            $crate::jvm_bridge::jni_map_error!($env, ret)
+        };
+
+        result.and_then(|result| $crate::jvm_bridge::jni_map_error!($env, 
<$ret>::try_from(result)))
     }}
 }
 
@@ -167,11 +191,21 @@ pub fn get_global_jclass(env: &mut JNIEnv, cls: &str) -> 
JniResult<JClass<'stati
 mod comet_exec;
 pub use comet_exec::*;
 mod comet_metric_node;
-use crate::JAVA_VM;
+use crate::{
+    errors::{CometError, CometResult},
+    JAVA_VM,
+};
 pub use comet_metric_node::*;
 
 /// The JVM classes that are used in the JNI calls.
 pub struct JVMClasses<'a> {
+    /// Cached method ID for "java.lang.Object#getClass"
+    pub object_get_class_method: JMethodID,
+    /// Cached method ID for "java.lang.Class#getName"
+    pub class_get_name_method: JMethodID,
+    /// Cached method ID for "java.lang.Throwable#getMessage"
+    pub throwable_get_message_method: JMethodID,
+
     /// The CometMetricNode class. Used for updating the metrics.
     pub comet_metric_node: CometMetricNode<'a>,
     /// The static CometExec class. Used for getting the subquery result.
@@ -192,7 +226,25 @@ impl JVMClasses<'_> {
             // `JNIEnv` except for creating the global references of the 
classes.
             let env = unsafe { std::mem::transmute::<_, &'static mut 
JNIEnv>(env) };
 
+            let clazz = env.find_class("java/lang/Object").unwrap();
+            let object_get_class_method = env
+                .get_method_id(clazz, "getClass", "()Ljava/lang/Class;")
+                .unwrap();
+
+            let clazz = env.find_class("java/lang/Class").unwrap();
+            let class_get_name_method = env
+                .get_method_id(clazz, "getName", "()Ljava/lang/String;")
+                .unwrap();
+
+            let clazz = env.find_class("java/lang/Throwable").unwrap();
+            let throwable_get_message_method = env
+                .get_method_id(clazz, "getMessage", "()Ljava/lang/String;")
+                .unwrap();
+
             JVMClasses {
+                object_get_class_method,
+                class_get_name_method,
+                throwable_get_message_method,
                 comet_metric_node: CometMetricNode::new(env).unwrap(),
                 comet_exec: CometExec::new(env).unwrap(),
             }
@@ -211,3 +263,68 @@ impl JVMClasses<'_> {
         }
     }
 }
+
+pub(crate) fn check_exception(env: &mut JNIEnv) -> 
CometResult<Option<CometError>> {
+    let result = if env.exception_check()? {
+        let exception = env.exception_occurred()?;
+        env.exception_clear()?;
+        let exception_err = convert_exception(env, &exception)?;
+        Some(exception_err)
+    } else {
+        None
+    };
+
+    Ok(result)
+}
+
+/// Given a `JThrowable` which is thrown from calling a Java method on the 
native side,
+/// this converts it into a `CometError::JavaException` with the exception 
class name
+/// and exception message. This error can then be propagated to the JVM side to 
let
+/// users know the cause of the native side error.
+pub(crate) fn convert_exception(
+    env: &mut JNIEnv,
+    throwable: &JThrowable,
+) -> CometResult<CometError> {
+    unsafe {
+        let cache = JVMClasses::get();
+
+        // get the class name of the exception by:
+        //  1. get the `Class` object of the input `throwable` via 
`Object#getClass` method
+        //  2. get the exception class name via calling `Class#getName` on the 
above object
+        let class_obj = env
+            .call_method_unchecked(
+                throwable,
+                cache.object_get_class_method,
+                ReturnType::Object,
+                &[],
+            )?
+            .l()?;
+        let exception_class_name = env
+            .call_method_unchecked(
+                class_obj,
+                cache.class_get_name_method,
+                ReturnType::Object,
+                &[],
+            )?
+            .l()?
+            .into();
+        let exception_class_name_str = 
env.get_string(&exception_class_name)?.into();
+
+        // get the exception message via calling `Throwable#getMessage` on the 
throwable object
+        let message = env
+            .call_method_unchecked(
+                throwable,
+                cache.throwable_get_message_method,
+                ReturnType::Object,
+                &[],
+            )?
+            .l()?
+            .into();
+        let message_str = env.get_string(&message)?.into();
+
+        Ok(CometError::JavaException {
+            class: exception_class_name_str,
+            msg: message_str,
+        })
+    }
+}

Reply via email to