emkornfield commented on a change in pull request #7030:
URL: https://github.com/apache/arrow/pull/7030#discussion_r414990082

File path: cpp/src/jni/dataset/jni_wrapper.cpp
@@ -0,0 +1,577 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//   http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#include <arrow/dataset/file_base.h>
+#include <arrow/dataset/api.h>
+#include <arrow/filesystem/localfs.h>
+#include <arrow/ipc/api.h>
+#include <arrow/util/iterator.h>
+#include <arrow/filesystem/hdfs.h>
+#include <arrow/io/api.h>
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/message.h>
+#include "concurrent_map.h"
+#include "arrow/compute/kernel.h"
+#include "arrow/compute/kernels/cast.h"
+#include "arrow/compute/kernels/compare.h"
+#include "jni/dataset/Types.pb.h"
+#include "org_apache_arrow_dataset_jni_JniWrapper.h"
+#include "org_apache_arrow_dataset_file_JniWrapper.h"
+static jclass illegal_access_exception_class;
+static jclass illegal_argument_exception_class;
+static jclass runtime_exception_class;
+static jclass record_batch_handle_class;
+static jclass record_batch_handle_field_class;
+static jclass record_batch_handle_buffer_class;
+static jmethodID record_batch_handle_constructor;
+static jmethodID record_batch_handle_field_constructor;
+static jmethodID record_batch_handle_buffer_constructor;
+static jint JNI_VERSION = JNI_VERSION_1_6;
+using arrow::jni::ConcurrentMap;
+static ConcurrentMap<std::shared_ptr<arrow::dataset::DatasetFactory>> 
+static ConcurrentMap<std::shared_ptr<arrow::dataset::Dataset>> dataset_holder_;
+static ConcurrentMap<std::shared_ptr<arrow::dataset::ScanTask>> 
+static ConcurrentMap<std::shared_ptr<arrow::dataset::Scanner>> scanner_holder_;
+static ConcurrentMap<std::shared_ptr<arrow::RecordBatchIterator>> 
+static ConcurrentMap<std::shared_ptr<arrow::Buffer>> buffer_holder_;
+#define JNI_ASSIGN_OR_THROW_IMPL(t, lhs, rexpr)                             \
+  auto t = (rexpr);                                                         \
+  if (!t.status().ok()) {                                                   \
+    env->ThrowNew(runtime_exception_class, t.status().message().c_str());   \
+  }                                                                         \
+  lhs = std::move(t).ValueOrDie();
+#define JNI_ASSIGN_OR_THROW(lhs, rexpr)                                     \
lhs, rexpr)
+#define JNI_ASSERT_OK_OR_THROW(expr)                                          \
+  do {                                                                        \
+    auto _res = (expr);                                                       \
+    arrow::Status _st = ::arrow::internal::GenericToStatus(_res);             \
+    if (!_st.ok()) {                                                          \
+       env->ThrowNew(runtime_exception_class, _st.message().c_str());  \
+    }                                                                         \
+  } while (false);
+jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) {
+  jclass local_class = env->FindClass(class_name);
+  jclass global_class = (jclass)env->NewGlobalRef(local_class);
+  env->DeleteLocalRef(local_class);
+  return global_class;
+jmethodID GetMethodID(JNIEnv* env, jclass this_class, const char* name, const 
char* sig) {
+  jmethodID ret = env->GetMethodID(this_class, name, sig);
+  if (ret == nullptr) {
+    std::string error_message = "Unable to find method " + std::string(name) +
+        " within signature" + std::string(sig);
+    env->ThrowNew(illegal_access_exception_class, error_message.c_str());
+  }
+  return ret;
+jint JNI_OnLoad(JavaVM* vm, void* reserved) {
+  JNIEnv* env;
+  if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) != JNI_OK) {
+    return JNI_ERR;
+  }
+  illegal_access_exception_class =
+      CreateGlobalClassReference(env, "Ljava/lang/IllegalAccessException;");
+  illegal_argument_exception_class =
+      CreateGlobalClassReference(env, "Ljava/lang/IllegalArgumentException;");
+  runtime_exception_class =
+      CreateGlobalClassReference(env, "Ljava/lang/RuntimeException;");
+  record_batch_handle_class =
+      CreateGlobalClassReference(env, 
+  record_batch_handle_field_class =
+      CreateGlobalClassReference(env, 
+  record_batch_handle_buffer_class =
+      CreateGlobalClassReference(env, 
+  record_batch_handle_constructor = GetMethodID(env, 
record_batch_handle_class, "<init>",
+  record_batch_handle_field_constructor = GetMethodID(env, 
record_batch_handle_field_class, "<init>",
+                                                      "(JJ)V");
+  record_batch_handle_buffer_constructor = GetMethodID(env, 
record_batch_handle_buffer_class, "<init>",
+                                                       "(JJJJ)V");
+  env->ExceptionDescribe();
+  return JNI_VERSION;
+void JNI_OnUnload(JavaVM* vm, void* reserved) {
+  JNIEnv* env;
+  vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION);
+  env->DeleteGlobalRef(illegal_access_exception_class);
+  env->DeleteGlobalRef(illegal_argument_exception_class);
+  env->DeleteGlobalRef(runtime_exception_class);
+  env->DeleteGlobalRef(record_batch_handle_class);
+  env->DeleteGlobalRef(record_batch_handle_field_class);
+  env->DeleteGlobalRef(record_batch_handle_buffer_class);
+  dataset_factory_holder_.Clear();
+  dataset_holder_.Clear();
+  scan_task_holder_.Clear();
+  scanner_holder_.Clear();
+  iterator_holder_.Clear();
+  buffer_holder_.Clear();
+std::shared_ptr<arrow::Schema> SchemaFromColumnNames(
+    const std::shared_ptr<arrow::Schema>& input, const 
std::vector<std::string>& column_names) {
+  std::vector<std::shared_ptr<arrow::Field>> columns;
+  for (const auto& name : column_names) {
+    columns.push_back(input->GetFieldByName(name));
+  }
+  return std::make_shared<arrow::Schema>(columns);
+std::shared_ptr<arrow::dataset::FileFormat> GetFileFormat(JNIEnv *env, jint 
id) {
+  switch (id) {
+    case 0:
+      return std::make_shared<arrow::dataset::ParquetFileFormat>();
+    default:
+      std::string error_message = "illegal file format id: " + 
+      env->ThrowNew(illegal_argument_exception_class, error_message.c_str());
+      return nullptr; // unreachable
+  }
+std::shared_ptr<arrow::fs::FileSystem> GetFileSystem(JNIEnv *env, jint id, 
std::string path,
+                                       std::string* out_path) {
+  switch (id) {
+    case 0:
+      *out_path = path;
+      return std::make_shared<arrow::fs::LocalFileSystem>();
+    case 1: {
+      JNI_ASSIGN_OR_THROW(std::shared_ptr<arrow::fs::FileSystem> ret, 
arrow::fs::FileSystemFromUri(path, out_path))
+      return ret;
+    }
+    default:std::string error_message = "illegal filesystem id: " + 
+      env->ThrowNew(illegal_argument_exception_class, error_message.c_str());
+      return nullptr; // unreachable
+  }
+std::string JStringToCString(JNIEnv* env, jstring string) {
+  jboolean copied;
+  int32_t length = env->GetStringUTFLength(string);
+  const char *chars = env->GetStringUTFChars(string, &copied);
+  std::string str = std::string(chars, length);
+  // fixme calling ReleaseStringUTFChars if memory leak faced
+  return str;
+std::vector<std::string> ToStringVector(JNIEnv* env, jobjectArray& str_array) {
+  int length = env->GetArrayLength(str_array);
+  std::vector<std::string> vector;
+  for (int i = 0; i < length; i++) {
+    auto string = (jstring) (env->GetObjectArrayElement(str_array, i));
+    vector.push_back(JStringToCString(env, string));
+  }
+  return vector;
+template <typename T>
+std::vector<T> collect(JNIEnv* env, arrow::Iterator<T> itr) {
+  std::vector<T> vector;
+  while(true) {
+    JNI_ASSIGN_OR_THROW(T t, itr.Next())
+    if (!t) {
+      break;
+    }
+    vector.push_back(t);
+  }
+  return vector;
+// FIXME: COPIED FROM intel/master on which this branch is not rebased yet
+// FIXME: 
+jbyteArray ToSchemaByteArray(JNIEnv* env, std::shared_ptr<arrow::Schema> 
schema) {
+  JNI_ASSIGN_OR_THROW(std::shared_ptr<arrow::Buffer> buffer,
+      arrow::ipc::SerializeSchema(*schema.get(), nullptr, 
+  jbyteArray out = env->NewByteArray(buffer->size());
+  auto src = reinterpret_cast<const jbyte*>(buffer->data());
+  env->SetByteArrayRegion(out, 0, buffer->size(), src);
+  return out;
+// FIXME: COPIED FROM intel/master on which this branch is not rebased yet
+// FIXME: 
+std::shared_ptr<arrow::Schema> FromSchemaByteArray(JNIEnv* env, jbyteArray 
schemaBytes) {
+  arrow::ipc::DictionaryMemo in_memo;
+  int schemaBytes_len = env->GetArrayLength(schemaBytes);
+  jbyte* schemaBytes_data = env->GetByteArrayElements(schemaBytes, 0);
+  auto serialized_schema =
+      std::make_shared<arrow::Buffer>((uint8_t*)schemaBytes_data, 
+  arrow::io::BufferReader buf_reader(serialized_schema);
+  JNI_ASSIGN_OR_THROW(std::shared_ptr<arrow::Schema> schema, 
arrow::ipc::ReadSchema(&buf_reader, &in_memo))
+  env->ReleaseByteArrayElements(schemaBytes, schemaBytes_data, JNI_ABORT);
+  return schema;
+bool ParseProtobuf(uint8_t* buf, int bufLen, google::protobuf::Message* msg) {
+    google::protobuf::io::CodedInputStream cis(buf, bufLen);
+    cis.SetRecursionLimit(1000);
+    return msg->ParseFromCodedStream(&cis);
+void releaseFilterInput(jbyteArray condition_arr, jbyte* condition_bytes, 
JNIEnv* env) {
+  env->ReleaseByteArrayElements(condition_arr, condition_bytes, JNI_ABORT);
+// fixme in development. Not all node types considered.
+std::shared_ptr<arrow::dataset::Expression> translateNode(types::TreeNode 
node, JNIEnv* env) {
+  if (node.has_fieldnode()) {
+    const types::FieldNode& f_node = node.fieldnode();
+    const std::string& name = f_node.name();
+    return std::make_shared<arrow::dataset::FieldExpression>(name);
+  }
+  if (node.has_intnode()) {
+    const types::IntNode& int_node = node.intnode();
+    int32_t val = int_node.value();
+    return 
+  }
+  if (node.has_longnode()) {
+    const types::LongNode& long_node = node.longnode();
+    int64_t val = long_node.value();
+    return 
+  }
+  if (node.has_floatnode()) {
+    const types::FloatNode& float_node = node.floatnode();
+    float_t val = float_node.value();
+    return 
+  }
+  if (node.has_doublenode()) {
+    const types::DoubleNode& double_node = node.doublenode();
+    double_t val = double_node.value();
+    return 
+  }
+  if (node.has_booleannode()) {
+    const types::BooleanNode& boolean_node = node.booleannode();
+    bool val = boolean_node.value();
+    return 
+  }
+  if (node.has_andnode()) {
+    const types::AndNode& and_node = node.andnode();
+    const types::TreeNode& left_arg = and_node.leftarg();
+    const types::TreeNode& right_arg = and_node.rightarg();
+    return 
std::make_shared<arrow::dataset::AndExpression>(translateNode(left_arg, env), 
translateNode(right_arg, env));
+  }
+  if (node.has_ornode()) {
+    const types::OrNode& or_node = node.ornode();
+    const types::TreeNode& left_arg = or_node.leftarg();
+    const types::TreeNode& right_arg = or_node.rightarg();
+    return 
std::make_shared<arrow::dataset::OrExpression>(translateNode(left_arg, env), 
translateNode(right_arg, env));
+  }
+  if (node.has_cpnode()) {
+    const types::ComparisonNode& cp_node = node.cpnode();
+    const std::string& op_name = cp_node.opname();
+    arrow::compute::CompareOperator op;
+    if (op_name == "equal") {
+      op = arrow::compute::CompareOperator::EQUAL;
+    } else if (op_name == "greaterThan") {
+      op = arrow::compute::CompareOperator::GREATER;
+    } else if (op_name == "greaterThanOrEqual") {
+      op = arrow::compute::CompareOperator::GREATER_EQUAL;
+    } else if (op_name == "lessThan") {
+      op = arrow::compute::CompareOperator::LESS;
+    } else if (op_name == "lessThanOrEqual") {
+      op = arrow::compute::CompareOperator::LESS_EQUAL;
+    } else {
+      std::string error_message = "Unknown operation name in comparison node";
+      env->ThrowNew(illegal_argument_exception_class, error_message.c_str());
+      return nullptr; // unreachable
+    }
+    const types::TreeNode& left_arg = cp_node.leftarg();
+    const types::TreeNode& right_arg = cp_node.rightarg();
+    return std::make_shared<arrow::dataset::ComparisonExpression>(op,
+        translateNode(left_arg, env), translateNode(right_arg, env));
+  }
+  if (node.has_notnode()) {
+    const types::NotNode& not_node = node.notnode();
+    const ::types::TreeNode& child = not_node.args();
+    std::shared_ptr<arrow::dataset::Expression> translatedChild = 
translateNode(child, env);
+    return std::make_shared<arrow::dataset::NotExpression>(translatedChild);
+  }
+  if (node.has_isvalidnode()) {
+    const types::IsValidNode& is_valid_node = node.isvalidnode();
+    const ::types::TreeNode& child = is_valid_node.args();
+    std::shared_ptr<arrow::dataset::Expression> translatedChild = 
translateNode(child, env);
+    return 
+  }
+  std::string error_message = "Unknown node type";
+  env->ThrowNew(illegal_argument_exception_class, error_message.c_str());
+  return nullptr; // unreachable
+std::shared_ptr<arrow::dataset::Expression> translateFilter(types::Condition 
condition, JNIEnv* env) {

Review comment:
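       Style nit: I think pure C++ functions (as opposed to the JNI entry points) should use CamelCase names, so this should be `TranslateFilter` rather than `translateFilter`.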

This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:

Reply via email to