zhztheplayer commented on a change in pull request #7030:
URL: https://github.com/apache/arrow/pull/7030#discussion_r567768143



##########
File path: cpp/src/jni/dataset/jni_wrapper.cpp
##########
@@ -0,0 +1,859 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <mutex>
+
+#include <arrow/array.h>
+#include <arrow/dataset/api.h>
+#include <arrow/dataset/file_base.h>
+#include <arrow/filesystem/localfs.h>
+#include <arrow/io/api.h>
+#include <arrow/ipc/api.h>
+#include <arrow/util/iterator.h>
+#include <arrow/util/logging.h>
+
+#include "org_apache_arrow_dataset_file_JniWrapper.h"
+#include "org_apache_arrow_dataset_jni_JniWrapper.h"
+#include "org_apache_arrow_dataset_jni_NativeMemoryPool.h"
+
+namespace arrow {
+namespace dataset {
+namespace jni {
+
// Cached global references to Java exception classes, presumably created in
// JNI_OnLoad and released in JNI_OnUnload — TODO confirm (lifecycle code is
// outside this view).
static jclass illegal_access_exception_class;
static jclass illegal_argument_exception_class;
static jclass runtime_exception_class;

// Cached global references to the Java-side handle/listener classes used to
// hand record batches and memory reservation callbacks across the bridge.
static jclass record_batch_handle_class;
static jclass record_batch_handle_field_class;
static jclass record_batch_handle_buffer_class;
static jclass java_reservation_listener_class;

// Cached method IDs resolved against the classes above.
static jmethodID record_batch_handle_constructor;
static jmethodID record_batch_handle_field_constructor;
static jmethodID record_batch_handle_buffer_constructor;
static jmethodID reserve_memory_method;
static jmethodID unreserve_memory_method;

// Native-ref ID of the default memory pool; -1 until assigned — assignment
// is not visible in this chunk, verify against the initialization code.
static jlong default_memory_pool_id = -1L;

// JNI version requested when attaching/looking up JNIEnv.
static jint JNI_VERSION = JNI_VERSION_1_6;
+
/// \brief Exception type used to unwind C++ frames back to the JNI boundary,
/// carrying the error message to be rethrown on the Java side.
class JniPendingException : public std::runtime_error {
 public:
  explicit JniPendingException(const std::string& message)
      : std::runtime_error(message) {}
};
+
+void ThrowPendingException(const std::string& message) {
+  throw JniPendingException(message);
+}
+
+template <typename T>
+T JniGetOrThrow(arrow::Result<T> result) {
+  if (!result.status().ok()) {
+    ThrowPendingException(result.status().message());
+  }
+  return std::move(result).ValueOrDie();
+}
+
+void JniAssertOkOrThrow(arrow::Status status) {
+  if (!status.ok()) {
+    ThrowPendingException(status.message());
+  }
+}
+
+void JniThrow(std::string message) { ThrowPendingException(message); }
+
+jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) {
+  jclass local_class = env->FindClass(class_name);
+  jclass global_class = (jclass)env->NewGlobalRef(local_class);
+  env->DeleteLocalRef(local_class);
+  return global_class;
+}
+
+arrow::Result<jmethodID> GetMethodID(JNIEnv* env, jclass this_class, const 
char* name,
+                                     const char* sig) {
+  jmethodID ret = env->GetMethodID(this_class, name, sig);
+  if (ret == nullptr) {
+    std::string error_message = "Unable to find method " + std::string(name) +
+                                " within signature" + std::string(sig);
+    return arrow::Status::Invalid(error_message);
+  }
+  return ret;
+}
+
+arrow::Result<jmethodID> GetStaticMethodID(JNIEnv* env, jclass this_class,
+                                           const char* name, const char* sig) {
+  jmethodID ret = env->GetStaticMethodID(this_class, name, sig);
+  if (ret == nullptr) {
+    std::string error_message = "Unable to find static method " + 
std::string(name) +
+                                " within signature" + std::string(sig);
+    return arrow::Status::Invalid(error_message);
+  }
+  return ret;
+}
+
/// \brief Build a schema containing only the named columns, reusing the field
/// definitions from the input schema (order follows column_names).
///
/// NOTE(review): Schema::GetFieldByName() returns nullptr for a name not
/// present in the input schema; that nullptr is pushed into the result
/// unchecked — confirm callers only pass existing column names.
std::shared_ptr<arrow::Schema> SchemaFromColumnNames(
    const std::shared_ptr<arrow::Schema>& input,
    const std::vector<std::string>& column_names) {
  std::vector<std::shared_ptr<arrow::Field>> columns;
  for (const auto& name : column_names) {
    columns.push_back(input->GetFieldByName(name));
  }
  return std::make_shared<arrow::Schema>(columns);
}
+
+arrow::Result<std::shared_ptr<arrow::dataset::FileFormat>> GetFileFormat(jint 
id) {
+  switch (id) {
+    case 0:
+      return std::make_shared<arrow::dataset::ParquetFileFormat>();
+    default:
+      std::string error_message = "illegal file format id: " + 
std::to_string(id);
+      return arrow::Status::Invalid(error_message);
+  }
+}
+
+std::string JStringToCString(JNIEnv* env, jstring string) {
+  if (string == nullptr) {
+    return std::string();
+  }
+  jboolean copied;
+  const char* chars = env->GetStringUTFChars(string, &copied);
+  std::string ret = strdup(chars);
+  env->ReleaseStringUTFChars(string, chars);
+  return ret;
+}
+
+std::vector<std::string> ToStringVector(JNIEnv* env, jobjectArray& str_array) {
+  int length = env->GetArrayLength(str_array);
+  std::vector<std::string> vector;
+  for (int i = 0; i < length; i++) {
+    auto string = (jstring)(env->GetObjectArrayElement(str_array, i));
+    vector.push_back(JStringToCString(env, string));
+  }
+  return vector;
+}
+
+template <typename T>
+jlong CreateNativeRef(std::shared_ptr<T> t) {
+  std::shared_ptr<T>* retained_ptr = new std::shared_ptr<T>(t);
+  return reinterpret_cast<jlong>(retained_ptr);
+}
+
+template <typename T>
+std::shared_ptr<T> RetrieveNativeInstance(jlong ref) {
+  std::shared_ptr<T>* retrieved_ptr = 
reinterpret_cast<std::shared_ptr<T>*>(ref);
+  return *retrieved_ptr;
+}
+
+template <typename T>
+void ReleaseNativeRef(jlong ref) {
+  std::shared_ptr<T>* retrieved_ptr = 
reinterpret_cast<std::shared_ptr<T>*>(ref);
+  delete retrieved_ptr;
+}
+
+arrow::Result<jbyteArray> ToSchemaByteArray(JNIEnv* env,
+                                            std::shared_ptr<arrow::Schema> 
schema) {
+  ARROW_ASSIGN_OR_RAISE(
+      std::shared_ptr<arrow::Buffer> buffer,
+      arrow::ipc::SerializeSchema(*schema, arrow::default_memory_pool()))
+
+  jbyteArray out = env->NewByteArray(buffer->size());
+  auto src = reinterpret_cast<const jbyte*>(buffer->data());
+  env->SetByteArrayRegion(out, 0, buffer->size(), src);
+  return out;
+}
+
+arrow::Result<std::shared_ptr<arrow::Schema>> FromSchemaByteArray(
+    JNIEnv* env, jbyteArray schemaBytes) {
+  arrow::ipc::DictionaryMemo in_memo;
+  int schemaBytes_len = env->GetArrayLength(schemaBytes);
+  jbyte* schemaBytes_data = env->GetByteArrayElements(schemaBytes, nullptr);
+  auto serialized_schema = std::make_shared<arrow::Buffer>(
+      reinterpret_cast<uint8_t*>(schemaBytes_data), schemaBytes_len);
+  arrow::io::BufferReader buf_reader(serialized_schema);
+  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Schema> schema,
+                        arrow::ipc::ReadSchema(&buf_reader, &in_memo))
+  env->ReleaseByteArrayElements(schemaBytes, schemaBytes_data, JNI_ABORT);
+  return schema;
+}
+
/// Listener to act on reservations/unreservations from
/// ReservationListenableMemoryPool.
///
/// Note the memory pool calls this listener only when a block-level memory
/// reservation/unreservation is granted, so the invocation parameter "size"
/// is always a multiple of the block size (by default, 512k) configured on
/// the memory pool.
class ReservationListener {
 public:
  virtual ~ReservationListener() = default;

  /// \brief Called when "size" additional bytes have been reserved.
  virtual arrow::Status OnReservation(int64_t size) = 0;
  /// \brief Called when "size" previously reserved bytes are released.
  virtual arrow::Status OnRelease(int64_t size) = 0;

 protected:
  ReservationListener() = default;
};
+
/// Internal implementation of ReservationListenableMemoryPool (pimpl).
///
/// Reservations are tracked block-wise: bytes_reserved_ is the exact byte
/// total requested so far, blocks_reserved_ the number of block_size_ blocks
/// currently held. The listener is notified only when the block count changes.
class ReservationListenableMemoryPoolImpl {
 public:
  explicit ReservationListenableMemoryPoolImpl(
      arrow::MemoryPool* pool, std::shared_ptr<ReservationListener> listener,
      int64_t block_size)
      : pool_(pool),
        listener_(listener),
        block_size_(block_size),
        blocks_reserved_(0),
        bytes_reserved_(0) {}

  /// Reserve from the listener first, then allocate from the underlying pool;
  /// the reservation is rolled back if the underlying allocation fails.
  arrow::Status Allocate(int64_t size, uint8_t** out) {
    RETURN_NOT_OK(UpdateReservation(size));
    arrow::Status error = pool_->Allocate(size, out);
    if (!error.ok()) {
      RETURN_NOT_OK(UpdateReservation(-size));
      return error;
    }
    return arrow::Status::OK();
  }

  /// Grow: reserve the delta before reallocating; shrink: release the delta
  /// after reallocating. Either way a failed realloc leaves the reservation
  /// unchanged.
  arrow::Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) {
    bool reserved = false;
    int64_t diff = new_size - old_size;
    if (new_size >= old_size) {
      // new_size >= old_size, pre-reserve bytes from listener before
      // allocating from underlying pool
      RETURN_NOT_OK(UpdateReservation(diff));
      reserved = true;
    }
    arrow::Status error = pool_->Reallocate(old_size, new_size, ptr);
    if (!error.ok()) {
      if (reserved) {
        // roll back reservations on error
        RETURN_NOT_OK(UpdateReservation(-diff));
      }
      return error;
    }
    if (!reserved) {
      // otherwise (e.g. new_size < old_size), make updates after calling
      // underlying pool
      RETURN_NOT_OK(UpdateReservation(diff));
    }
    return arrow::Status::OK();
  }

  void Free(uint8_t* buffer, int64_t size) {
    pool_->Free(buffer, size);
    // FIXME: See ARROW-11143, currently method ::Free doesn't allow Status
    // return
    arrow::Status s = UpdateReservation(-size);
    if (!s.ok()) {
      ARROW_LOG(FATAL) << "Failed to update reservation while freeing bytes: "
                       << s.message();
      return;
    }
  }

  /// Notify the listener when the block-granular reservation changes:
  /// granted > 0 means additional bytes were reserved, granted < 0 means
  /// bytes were released, 0 means the block count did not change.
  arrow::Status UpdateReservation(int64_t diff) {
    int64_t granted = Reserve(diff);
    if (granted == 0) {
      return arrow::Status::OK();
    }
    if (granted < 0) {
      RETURN_NOT_OK(listener_->OnRelease(-granted));
      return arrow::Status::OK();
    }
    RETURN_NOT_OK(listener_->OnReservation(granted));
    return arrow::Status::OK();
  }

  /// Apply "diff" bytes to the running total (under the mutex) and return the
  /// block-granular change in reserved bytes — always a multiple of
  /// block_size_, possibly negative.
  int64_t Reserve(int64_t diff) {
    std::lock_guard<std::mutex> lock(mutex_);
    bytes_reserved_ += diff;
    int64_t new_block_count;
    if (bytes_reserved_ == 0) {
      new_block_count = 0;
    } else {
      // ceil to get the required block number
      new_block_count = (bytes_reserved_ - 1) / block_size_ + 1;
    }
    int64_t bytes_granted = (new_block_count - blocks_reserved_) * block_size_;
    blocks_reserved_ = new_block_count;
    return bytes_granted;
  }

  int64_t bytes_allocated() { return pool_->bytes_allocated(); }

  int64_t max_memory() { return pool_->max_memory(); }

  std::string backend_name() { return pool_->backend_name(); }

  std::shared_ptr<ReservationListener> get_listener() { return listener_; }

 private:
  arrow::MemoryPool* pool_;        // underlying pool doing the real malloc/free
  std::shared_ptr<ReservationListener> listener_;
  int64_t block_size_;
  int64_t blocks_reserved_;        // guarded by mutex_
  int64_t bytes_reserved_;         // guarded by mutex_
  std::mutex mutex_;
};
+
+/// A memory pool implementation for pre-reserving memory blocks from a
+/// customizable listener. This will typically be used when memory allocations
+/// have to subject to another "virtual" resource manager, which just tracks or
+/// limits number of bytes of application's overall memory usage. The 
underlying
+/// memory pool will still be responsible for actual malloc/free operations.
+class ReservationListenableMemoryPool : public arrow::MemoryPool {
+ public:
+  /// \brief Constructor.
+  ///
+  /// \param[in] pool the underlying memory pool
+  /// \param[in] listener a listener for block-level reservations/releases.
+  /// \param[in] block_size size of each block to reserve from the listener
+  explicit ReservationListenableMemoryPool(MemoryPool* pool,
+                                           
std::shared_ptr<ReservationListener> listener,
+                                           int64_t block_size = 512 * 1024) {
+    impl_.reset(new ReservationListenableMemoryPoolImpl(pool, listener, 
block_size));
+  }
+
+  arrow::Status Allocate(int64_t size, uint8_t** out) override {
+    return impl_->Allocate(size, out);
+  }
+
+  arrow::Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) 
override {
+    return impl_->Reallocate(old_size, new_size, ptr);
+  }
+
+  void Free(uint8_t* buffer, int64_t size) override { return 
impl_->Free(buffer, size); }
+
+  int64_t bytes_allocated() const override { return impl_->bytes_allocated(); }
+
+  int64_t max_memory() const override { return impl_->max_memory(); }
+
+  std::string backend_name() const override { return impl_->backend_name(); }
+
+  std::shared_ptr<ReservationListener> get_listener() { return 
impl_->get_listener(); }
+
+ private:
+  std::unique_ptr<ReservationListenableMemoryPoolImpl> impl_;
+};
+
+class ReserveFromJava : public ReservationListener {
+ public:
+  ReserveFromJava(JavaVM* vm, jobject java_reservation_listener)
+      : vm_(vm), java_reservation_listener_(java_reservation_listener) {}
+
+  arrow::Status OnReservation(int64_t size) override {
+    JNIEnv* env;
+    if (vm_->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) != JNI_OK) {
+      return arrow::Status::Invalid("JNIEnv was not attached to current 
thread");
+    }
+    env->CallObjectMethod(java_reservation_listener_, reserve_memory_method, 
size);
+    if (env->ExceptionCheck()) {
+      env->ExceptionDescribe();
+      env->ExceptionClear();
+      return arrow::Status::Invalid("Error calling Java side reservation 
listener");
+    }
+    return arrow::Status::OK();
+  }
+
+  arrow::Status OnRelease(int64_t size) override {
+    JNIEnv* env;
+    if (vm_->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) != JNI_OK) {
+      return arrow::Status::Invalid("JNIEnv was not attached to current 
thread");
+    }
+    env->CallObjectMethod(java_reservation_listener_, unreserve_memory_method, 
size);
+    if (env->ExceptionCheck()) {
+      env->ExceptionDescribe();
+      env->ExceptionClear();
+      return arrow::Status::Invalid("Error calling Java side reservation 
listener");
+    }
+    return arrow::Status::OK();
+  }
+
+  jobject GetJavaReservationListener() { return java_reservation_listener_; }
+
+ private:
+  JavaVM* vm_;
+  jobject java_reservation_listener_;
+};
+
+/// \class DisposableScannerAdaptor
+/// \brief An adaptor that iterates over a Scanner instance then returns 
RecordBatches
+/// directly.
+///
+/// This lessens the complexity of the JNI bridge to make sure it to be easier 
to
+/// maintain. On Java-side, NativeScanner can only produces a single 
NativeScanTask
+/// instance during its whole lifecycle. Each task stands for a 
DisposableScannerAdaptor
+/// instance through JNI bridge.
+///
+class DisposableScannerAdaptor {
+ public:
+  DisposableScannerAdaptor(std::shared_ptr<arrow::dataset::Scanner> scanner,
+                           arrow::dataset::ScanTaskIterator task_itr) {
+    this->scanner_ = std::move(scanner);
+    this->task_itr_ = std::move(task_itr);
+  }
+
+  static arrow::Result<std::shared_ptr<DisposableScannerAdaptor>> Create(
+      std::shared_ptr<arrow::dataset::Scanner> scanner) {
+    ARROW_ASSIGN_OR_RAISE(arrow::dataset::ScanTaskIterator task_itr, 
scanner->Scan())
+    return std::make_shared<DisposableScannerAdaptor>(scanner, 
std::move(task_itr));
+  }
+
+  arrow::Result<std::shared_ptr<arrow::RecordBatch>> Next() {
+    do {
+      ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::RecordBatch> batch, 
NextBatch())
+      if (batch != nullptr) {
+        return batch;
+      }
+      // batch is null, current task is fully consumed
+      ARROW_ASSIGN_OR_RAISE(bool has_next_task, NextTask())
+      if (!has_next_task) {
+        // no more tasks
+        return nullptr;
+      }
+      // new task appended, read again
+    } while (true);
+  }
+
+  const std::shared_ptr<arrow::dataset::Scanner>& GetScanner() const { return 
scanner_; }
+
+ protected:

Review comment:
       done




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to