zhztheplayer commented on a change in pull request #7030: URL: https://github.com/apache/arrow/pull/7030#discussion_r567767970
########## File path: cpp/src/jni/dataset/jni_wrapper.cpp ########## @@ -0,0 +1,859 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <mutex> + +#include <arrow/array.h> +#include <arrow/dataset/api.h> +#include <arrow/dataset/file_base.h> +#include <arrow/filesystem/localfs.h> +#include <arrow/io/api.h> +#include <arrow/ipc/api.h> +#include <arrow/util/iterator.h> +#include <arrow/util/logging.h> + +#include "org_apache_arrow_dataset_file_JniWrapper.h" +#include "org_apache_arrow_dataset_jni_JniWrapper.h" +#include "org_apache_arrow_dataset_jni_NativeMemoryPool.h" + +namespace arrow { +namespace dataset { +namespace jni { + +static jclass illegal_access_exception_class; +static jclass illegal_argument_exception_class; +static jclass runtime_exception_class; + +static jclass record_batch_handle_class; +static jclass record_batch_handle_field_class; +static jclass record_batch_handle_buffer_class; +static jclass java_reservation_listener_class; + +static jmethodID record_batch_handle_constructor; +static jmethodID record_batch_handle_field_constructor; +static jmethodID record_batch_handle_buffer_constructor; +static jmethodID reserve_memory_method; +static jmethodID unreserve_memory_method; + +static 
jlong default_memory_pool_id = -1L; + +static jint JNI_VERSION = JNI_VERSION_1_6; + +class JniPendingException : public std::runtime_error { + public: + explicit JniPendingException(const std::string& arg) : runtime_error(arg) {} +}; + +void ThrowPendingException(const std::string& message) { + throw JniPendingException(message); +} + +template <typename T> +T JniGetOrThrow(arrow::Result<T> result) { + if (!result.status().ok()) { + ThrowPendingException(result.status().message()); + } + return std::move(result).ValueOrDie(); +} + +void JniAssertOkOrThrow(arrow::Status status) { + if (!status.ok()) { + ThrowPendingException(status.message()); + } +} + +void JniThrow(std::string message) { ThrowPendingException(message); } + +jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) { + jclass local_class = env->FindClass(class_name); + jclass global_class = (jclass)env->NewGlobalRef(local_class); + env->DeleteLocalRef(local_class); + return global_class; +} + +arrow::Result<jmethodID> GetMethodID(JNIEnv* env, jclass this_class, const char* name, + const char* sig) { + jmethodID ret = env->GetMethodID(this_class, name, sig); + if (ret == nullptr) { + std::string error_message = "Unable to find method " + std::string(name) + + " within signature" + std::string(sig); + return arrow::Status::Invalid(error_message); + } + return ret; +} + +arrow::Result<jmethodID> GetStaticMethodID(JNIEnv* env, jclass this_class, + const char* name, const char* sig) { + jmethodID ret = env->GetStaticMethodID(this_class, name, sig); + if (ret == nullptr) { + std::string error_message = "Unable to find static method " + std::string(name) + + " within signature" + std::string(sig); + return arrow::Status::Invalid(error_message); + } + return ret; +} + +std::shared_ptr<arrow::Schema> SchemaFromColumnNames( + const std::shared_ptr<arrow::Schema>& input, + const std::vector<std::string>& column_names) { + std::vector<std::shared_ptr<arrow::Field>> columns; + for (const auto& name 
: column_names) { + columns.push_back(input->GetFieldByName(name)); + } + return std::make_shared<arrow::Schema>(columns); +} + +arrow::Result<std::shared_ptr<arrow::dataset::FileFormat>> GetFileFormat(jint id) { + switch (id) { + case 0: + return std::make_shared<arrow::dataset::ParquetFileFormat>(); + default: + std::string error_message = "illegal file format id: " + std::to_string(id); + return arrow::Status::Invalid(error_message); + } +} + +std::string JStringToCString(JNIEnv* env, jstring string) { + if (string == nullptr) { + return std::string(); + } + jboolean copied; + const char* chars = env->GetStringUTFChars(string, &copied); + std::string ret = strdup(chars); + env->ReleaseStringUTFChars(string, chars); + return ret; +} + +std::vector<std::string> ToStringVector(JNIEnv* env, jobjectArray& str_array) { + int length = env->GetArrayLength(str_array); + std::vector<std::string> vector; + for (int i = 0; i < length; i++) { + auto string = (jstring)(env->GetObjectArrayElement(str_array, i)); + vector.push_back(JStringToCString(env, string)); + } + return vector; +} + +template <typename T> +jlong CreateNativeRef(std::shared_ptr<T> t) { + std::shared_ptr<T>* retained_ptr = new std::shared_ptr<T>(t); + return reinterpret_cast<jlong>(retained_ptr); +} + +template <typename T> +std::shared_ptr<T> RetrieveNativeInstance(jlong ref) { + std::shared_ptr<T>* retrieved_ptr = reinterpret_cast<std::shared_ptr<T>*>(ref); + return *retrieved_ptr; +} + +template <typename T> +void ReleaseNativeRef(jlong ref) { + std::shared_ptr<T>* retrieved_ptr = reinterpret_cast<std::shared_ptr<T>*>(ref); + delete retrieved_ptr; +} + +arrow::Result<jbyteArray> ToSchemaByteArray(JNIEnv* env, + std::shared_ptr<arrow::Schema> schema) { + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr<arrow::Buffer> buffer, + arrow::ipc::SerializeSchema(*schema, arrow::default_memory_pool())) + + jbyteArray out = env->NewByteArray(buffer->size()); + auto src = reinterpret_cast<const jbyte*>(buffer->data()); + 
env->SetByteArrayRegion(out, 0, buffer->size(), src); + return out; +} + +arrow::Result<std::shared_ptr<arrow::Schema>> FromSchemaByteArray( + JNIEnv* env, jbyteArray schemaBytes) { + arrow::ipc::DictionaryMemo in_memo; + int schemaBytes_len = env->GetArrayLength(schemaBytes); + jbyte* schemaBytes_data = env->GetByteArrayElements(schemaBytes, nullptr); + auto serialized_schema = std::make_shared<arrow::Buffer>( + reinterpret_cast<uint8_t*>(schemaBytes_data), schemaBytes_len); + arrow::io::BufferReader buf_reader(serialized_schema); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Schema> schema, + arrow::ipc::ReadSchema(&buf_reader, &in_memo)) + env->ReleaseByteArrayElements(schemaBytes, schemaBytes_data, JNI_ABORT); + return schema; +} + +/// Listener to act on reservations/unreservations from ReservationListenableMemoryPool. +/// +/// Note the memory pool will call this listener only on block-level memory +/// reservation/unreservation is granted. So the invocation parameter "size" is always +/// multiple of block size (by default, 512k) specified in memory pool. +class ReservationListener { + public: + virtual ~ReservationListener() = default; + + virtual arrow::Status OnReservation(int64_t size) = 0; + virtual arrow::Status OnRelease(int64_t size) = 0; + + protected: + ReservationListener() = default; +}; + +class ReservationListenableMemoryPoolImpl { Review comment: See above. Do we still have to do this after the refactors? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org