zhztheplayer commented on a change in pull request #7030: URL: https://github.com/apache/arrow/pull/7030#discussion_r569376564
########## File path: cpp/src/jni/dataset/jni_wrapper.cpp ########## @@ -0,0 +1,859 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <mutex> + +#include <arrow/array.h> +#include <arrow/dataset/api.h> +#include <arrow/dataset/file_base.h> +#include <arrow/filesystem/localfs.h> +#include <arrow/io/api.h> +#include <arrow/ipc/api.h> +#include <arrow/util/iterator.h> +#include <arrow/util/logging.h> + +#include "org_apache_arrow_dataset_file_JniWrapper.h" +#include "org_apache_arrow_dataset_jni_JniWrapper.h" +#include "org_apache_arrow_dataset_jni_NativeMemoryPool.h" + +namespace arrow { +namespace dataset { +namespace jni { + +static jclass illegal_access_exception_class; +static jclass illegal_argument_exception_class; +static jclass runtime_exception_class; + +static jclass record_batch_handle_class; +static jclass record_batch_handle_field_class; +static jclass record_batch_handle_buffer_class; +static jclass java_reservation_listener_class; + +static jmethodID record_batch_handle_constructor; +static jmethodID record_batch_handle_field_constructor; +static jmethodID record_batch_handle_buffer_constructor; +static jmethodID reserve_memory_method; +static jmethodID unreserve_memory_method; + +static jlong default_memory_pool_id = -1L; + +static jint JNI_VERSION = JNI_VERSION_1_6; + +class JniPendingException : public std::runtime_error { + public: + explicit JniPendingException(const std::string& arg) : runtime_error(arg) {} +}; + +void ThrowPendingException(const std::string& message) { + throw JniPendingException(message); +} + +template <typename T> +T JniGetOrThrow(arrow::Result<T> result) { + if (!result.status().ok()) { + ThrowPendingException(result.status().message()); + } + return std::move(result).ValueOrDie(); +} + +void JniAssertOkOrThrow(arrow::Status status) { + if (!status.ok()) { + ThrowPendingException(status.message()); + } +} + +void JniThrow(std::string message) { ThrowPendingException(message); } + +jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) { + jclass local_class = env->FindClass(class_name); + jclass global_class = (jclass)env->NewGlobalRef(local_class); + env->DeleteLocalRef(local_class); + return global_class; +} + +arrow::Result<jmethodID> GetMethodID(JNIEnv* env, jclass this_class, const char* name, + const char* sig) { + jmethodID ret = env->GetMethodID(this_class, name, sig); + if (ret == nullptr) { + std::string error_message = "Unable to find method " + std::string(name) + + " within signature" + std::string(sig); + return arrow::Status::Invalid(error_message); + } + return ret; +} + +arrow::Result<jmethodID> GetStaticMethodID(JNIEnv* env, jclass this_class, + const char* name, const char* sig) { + jmethodID ret = env->GetStaticMethodID(this_class, name, sig); + if (ret == nullptr) { + std::string error_message = "Unable to find static method " + std::string(name) + + " within signature" + std::string(sig); + return arrow::Status::Invalid(error_message); + } + return ret; +} + +std::shared_ptr<arrow::Schema> SchemaFromColumnNames( + const std::shared_ptr<arrow::Schema>& input, + const std::vector<std::string>& column_names) { + std::vector<std::shared_ptr<arrow::Field>> columns; + for (const auto& name : column_names) { + columns.push_back(input->GetFieldByName(name)); + } + return std::make_shared<arrow::Schema>(columns); +} + +arrow::Result<std::shared_ptr<arrow::dataset::FileFormat>> GetFileFormat(jint id) { + switch (id) { + case 0: + return std::make_shared<arrow::dataset::ParquetFileFormat>(); + default: + std::string error_message = "illegal file format id: " + std::to_string(id); + return arrow::Status::Invalid(error_message); + } +} + +std::string JStringToCString(JNIEnv* env, jstring string) { + if (string == nullptr) { + return std::string(); + } + jboolean copied; + const char* chars = env->GetStringUTFChars(string, &copied); + std::string ret = strdup(chars); + env->ReleaseStringUTFChars(string, chars); + return ret; +} + +std::vector<std::string> ToStringVector(JNIEnv* env, jobjectArray& str_array) { + int length = env->GetArrayLength(str_array); + std::vector<std::string> vector; + for (int i = 0; i < length; i++) { + auto string = (jstring)(env->GetObjectArrayElement(str_array, i)); + vector.push_back(JStringToCString(env, string)); + } + return vector; +} + +template <typename T> +jlong CreateNativeRef(std::shared_ptr<T> t) { + std::shared_ptr<T>* retained_ptr = new std::shared_ptr<T>(t); + return reinterpret_cast<jlong>(retained_ptr); +} + +template <typename T> +std::shared_ptr<T> RetrieveNativeInstance(jlong ref) { + std::shared_ptr<T>* retrieved_ptr = reinterpret_cast<std::shared_ptr<T>*>(ref); + return *retrieved_ptr; +} + +template <typename T> +void ReleaseNativeRef(jlong ref) { + std::shared_ptr<T>* retrieved_ptr = reinterpret_cast<std::shared_ptr<T>*>(ref); + delete retrieved_ptr; +} + +arrow::Result<jbyteArray> ToSchemaByteArray(JNIEnv* env, + std::shared_ptr<arrow::Schema> schema) { + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr<arrow::Buffer> buffer, + arrow::ipc::SerializeSchema(*schema, arrow::default_memory_pool())) + + jbyteArray out = env->NewByteArray(buffer->size()); + auto src = reinterpret_cast<const jbyte*>(buffer->data()); + env->SetByteArrayRegion(out, 0, buffer->size(), src); + return out; +} + +arrow::Result<std::shared_ptr<arrow::Schema>> FromSchemaByteArray( + JNIEnv* env, jbyteArray schemaBytes) { + arrow::ipc::DictionaryMemo in_memo; + int schemaBytes_len = env->GetArrayLength(schemaBytes); + jbyte* schemaBytes_data = env->GetByteArrayElements(schemaBytes, nullptr); + auto serialized_schema = std::make_shared<arrow::Buffer>( + reinterpret_cast<uint8_t*>(schemaBytes_data), schemaBytes_len); + arrow::io::BufferReader buf_reader(serialized_schema); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Schema> schema, + arrow::ipc::ReadSchema(&buf_reader, &in_memo)) + env->ReleaseByteArrayElements(schemaBytes, schemaBytes_data, JNI_ABORT); + return schema; +} + +/// Listener to act on reservations/unreservations from ReservationListenableMemoryPool. +/// +/// Note the memory pool will call this listener only on block-level memory +/// reservation/unreservation is granted. So the invocation parameter "size" is always +/// multiple of block size (by default, 512k) specified in memory pool. +class ReservationListener { + public: + virtual ~ReservationListener() = default; + + virtual arrow::Status OnReservation(int64_t size) = 0; + virtual arrow::Status OnRelease(int64_t size) = 0; + + protected: + ReservationListener() = default; +}; + +class ReservationListenableMemoryPoolImpl { + public: + explicit ReservationListenableMemoryPoolImpl( + arrow::MemoryPool* pool, std::shared_ptr<ReservationListener> listener, + int64_t block_size) + : pool_(pool), + listener_(listener), + block_size_(block_size), + blocks_reserved_(0), + bytes_reserved_(0) {} + + arrow::Status Allocate(int64_t size, uint8_t** out) { + RETURN_NOT_OK(UpdateReservation(size)); + arrow::Status error = pool_->Allocate(size, out); + if (!error.ok()) { + RETURN_NOT_OK(UpdateReservation(-size)); + return error; + } + return arrow::Status::OK(); + } + + arrow::Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) { + bool reserved = false; + int64_t diff = new_size - old_size; + if (new_size >= old_size) { + // new_size >= old_size, pre-reserve bytes from listener before allocating + // from underlying pool + RETURN_NOT_OK(UpdateReservation(diff)); + reserved = true; + } + arrow::Status error = pool_->Reallocate(old_size, new_size, ptr); + if (!error.ok()) { + if (reserved) { + // roll back reservations on error + RETURN_NOT_OK(UpdateReservation(-diff)); + } + return error; + } + if (!reserved) { + // otherwise (e.g. new_size < old_size), make updates after calling underlying pool + RETURN_NOT_OK(UpdateReservation(diff)); + } + return arrow::Status::OK(); + } + + void Free(uint8_t* buffer, int64_t size) { + pool_->Free(buffer, size); + // FIXME: See ARROW-11143, currently method ::Free doesn't allow Status return + arrow::Status s = UpdateReservation(-size); + if (!s.ok()) { + ARROW_LOG(FATAL) << "Failed to update reservation while freeing bytes: " + << s.message(); + return; + } + } + + arrow::Status UpdateReservation(int64_t diff) { + int64_t granted = Reserve(diff); + if (granted == 0) { + return arrow::Status::OK(); + } + if (granted < 0) { + RETURN_NOT_OK(listener_->OnRelease(-granted)); + return arrow::Status::OK(); + } + RETURN_NOT_OK(listener_->OnReservation(granted)); + return arrow::Status::OK(); + } + + int64_t Reserve(int64_t diff) { + std::lock_guard<std::mutex> lock(mutex_); + bytes_reserved_ += diff; + int64_t new_block_count; + if (bytes_reserved_ == 0) { + new_block_count = 0; + } else { + // ceil to get the required block number + new_block_count = (bytes_reserved_ - 1) / block_size_ + 1; + } + int64_t bytes_granted = (new_block_count - blocks_reserved_) * block_size_; + blocks_reserved_ = new_block_count; + return bytes_granted; + } + + int64_t bytes_allocated() { return pool_->bytes_allocated(); } + + int64_t max_memory() { return pool_->max_memory(); } + + std::string backend_name() { return pool_->backend_name(); } + + std::shared_ptr<ReservationListener> get_listener() { return listener_; } + + private: + arrow::MemoryPool* pool_; + std::shared_ptr<ReservationListener> listener_; + int64_t block_size_; + int64_t blocks_reserved_; + int64_t bytes_reserved_; + std::mutex mutex_; +}; + +/// A memory pool implementation for pre-reserving memory blocks from a Review comment: Added some unit tests ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org