This is an automated email from the ASF dual-hosted git repository.
panxiaolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new c07e2ada43 [imporve](udaf) refactor java-udaf executor by using for
loop (#21713)
c07e2ada43 is described below
commit c07e2ada43e870c17e4308477eae3979fff9a3e1
Author: zhangstar333 <[email protected]>
AuthorDate: Fri Jul 14 11:37:19 2023 +0800
[imporve](udaf) refactor java-udaf executor by using for loop (#21713)
refactor java-udaf executor by using for loop
---
.../aggregate_function_java_udaf.h | 119 ++++++----
fe/be-java-extensions/java-udf/pom.xml | 6 +
.../java/org/apache/doris/udf/BaseExecutor.java | 181 ++++++++++++++
.../java/org/apache/doris/udf/UdafExecutor.java | 89 ++++++-
.../main/java/org/apache/doris/udf/UdfConvert.java | 262 +++++++++++----------
.../java/org/apache/doris/udf/UdfExecutor.java | 165 +------------
6 files changed, 485 insertions(+), 337 deletions(-)
diff --git a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
index fa0c4efd9d..6fe4742064 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
@@ -128,63 +128,80 @@ public:
return Status::OK();
}
- Status add(const int64_t places_address[], bool is_single_place, const
IColumn** columns,
- size_t row_num_start, size_t row_num_end, const DataTypes&
argument_types) {
+ Status add(int64_t places_address, bool is_single_place, const IColumn**
columns,
+ int row_num_start, int row_num_end, const DataTypes&
argument_types,
+ int place_offset) {
JNIEnv* env = nullptr;
RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf
add function");
+ jclass obj_class = env->FindClass("[Ljava/lang/Object;");
+ jobjectArray arg_objects = env->NewObjectArray(argument_size,
obj_class, nullptr);
+ int64_t nullmap_address = 0;
+
for (int arg_idx = 0; arg_idx < argument_size; ++arg_idx) {
+ bool arg_column_nullable = false;
auto data_col = columns[arg_idx];
if (auto* nullable = check_and_get_column<const
ColumnNullable>(*columns[arg_idx])) {
+ arg_column_nullable = true;
+ auto null_col = nullable->get_null_map_column_ptr();
data_col = nullable->get_nested_column_ptr();
- auto null_col = check_and_get_column<ColumnVector<UInt8>>(
- nullable->get_null_map_column_ptr());
- input_nulls_buffer_ptr.get()[arg_idx] =
- reinterpret_cast<int64_t>(null_col->get_data().data());
- } else {
- input_nulls_buffer_ptr.get()[arg_idx] = -1;
+ nullmap_address = reinterpret_cast<int64_t>(
+
check_and_get_column<ColumnVector<UInt8>>(null_col)->get_data().data());
}
- if (data_col->is_column_string()) {
- const ColumnString* str_col =
check_and_get_column<ColumnString>(data_col);
- input_values_buffer_ptr.get()[arg_idx] =
- reinterpret_cast<int64_t>(str_col->get_chars().data());
- input_offsets_ptrs.get()[arg_idx] =
-
reinterpret_cast<int64_t>(str_col->get_offsets().data());
- } else if (data_col->is_numeric() ||
data_col->is_column_decimal()) {
- input_values_buffer_ptr.get()[arg_idx] =
-
reinterpret_cast<int64_t>(data_col->get_raw_data().data);
+ // convert argument column data into java type
+ jobjectArray arr_obj = nullptr;
+ if (data_col->is_numeric() || data_col->is_column_decimal()) {
+ arr_obj = (jobjectArray)env->CallObjectMethod(
+ executor_obj, executor_convert_basic_argument_id,
arg_idx,
+ arg_column_nullable, row_num_start, row_num_end,
nullmap_address,
+
reinterpret_cast<int64_t>(data_col->get_raw_data().data), 0);
+ } else if (data_col->is_column_string()) {
+ const ColumnString* str_col = assert_cast<const
ColumnString*>(data_col);
+ arr_obj = (jobjectArray)env->CallObjectMethod(
+ executor_obj, executor_convert_basic_argument_id,
arg_idx,
+ arg_column_nullable, row_num_start, row_num_end,
nullmap_address,
+ reinterpret_cast<int64_t>(str_col->get_chars().data()),
+
reinterpret_cast<int64_t>(str_col->get_offsets().data()));
} else if (data_col->is_column_array()) {
const ColumnArray* array_col = assert_cast<const
ColumnArray*>(data_col);
- input_offsets_ptrs.get()[arg_idx] = reinterpret_cast<int64_t>(
- array_col->get_offsets_column().get_raw_data().data);
const ColumnNullable& array_nested_nullable =
assert_cast<const
ColumnNullable&>(array_col->get_data());
auto data_column_null_map =
array_nested_nullable.get_null_map_column_ptr();
auto data_column =
array_nested_nullable.get_nested_column_ptr();
- input_array_nulls_buffer_ptr.get()[arg_idx] =
reinterpret_cast<int64_t>(
+ auto offset_address = reinterpret_cast<int64_t>(
+ array_col->get_offsets_column().get_raw_data().data);
+ auto nested_nullmap_address = reinterpret_cast<int64_t>(
check_and_get_column<ColumnVector<UInt8>>(data_column_null_map)
->get_data()
.data());
-
- //need pass FE, nullamp and offset, chars
+ int64_t nested_data_address = 0, nested_offset_address = 0;
+ // array type need pass address: [nullmap_address],
offset_address, nested_nullmap_address,
nested_data_address/nested_char_address,nested_offset_address
if (data_column->is_column_string()) {
const ColumnString* col = assert_cast<const
ColumnString*>(data_column.get());
- input_values_buffer_ptr.get()[arg_idx] =
- reinterpret_cast<int64_t>(col->get_chars().data());
- input_array_string_offsets_ptrs.get()[arg_idx] =
-
reinterpret_cast<int64_t>(col->get_offsets().data());
+ nested_data_address =
reinterpret_cast<int64_t>(col->get_chars().data());
+ nested_offset_address =
reinterpret_cast<int64_t>(col->get_offsets().data());
} else {
- input_values_buffer_ptr.get()[arg_idx] =
+ nested_data_address =
reinterpret_cast<int64_t>(data_column->get_raw_data().data);
}
+ arr_obj = (jobjectArray)env->CallObjectMethod(
+ executor_obj, executor_convert_array_argument_id,
arg_idx,
+ arg_column_nullable, row_num_start, row_num_end,
nullmap_address,
+ offset_address, nested_nullmap_address,
nested_data_address,
+ nested_offset_address);
} else {
return Status::InvalidArgument(
strings::Substitute("Java UDAF doesn't support type is
$0 now !",
argument_types[arg_idx]->get_name()));
}
+ env->SetObjectArrayElement(arg_objects, arg_idx, arr_obj);
+ env->DeleteLocalRef(arr_obj);
}
- *input_place_ptrs = reinterpret_cast<int64_t>(places_address);
- env->CallNonvirtualVoidMethod(executor_obj, executor_cl,
executor_add_id, is_single_place,
- row_num_start, row_num_end);
+ RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env));
+ // invoke add batch
+ env->CallObjectMethod(executor_obj, executor_add_batch_id,
is_single_place, row_num_start,
+ row_num_end, places_address, place_offset,
arg_objects);
+ env->DeleteLocalRef(arg_objects);
+ env->DeleteLocalRef(obj_class);
return JniUtil::GetJniExceptionMsg(env);
}
@@ -392,6 +409,12 @@ private:
register_id("getValue", UDAF_EXECUTOR_RESULT_SIGNATURE,
executor_result_id));
RETURN_IF_ERROR(
register_id("destroy", UDAF_EXECUTOR_DESTROY_SIGNATURE,
executor_destroy_id));
+ RETURN_IF_ERROR(register_id("convertBasicArguments",
"(IZIIJJJ)[Ljava/lang/Object;",
+ executor_convert_basic_argument_id));
+ RETURN_IF_ERROR(register_id("convertArrayArguments",
"(IZIIJJJJJ)[Ljava/lang/Object;",
+ executor_convert_array_argument_id));
+ RETURN_IF_ERROR(
+ register_id("addBatch", "(ZIIJI[Ljava/lang/Object;)V",
executor_add_batch_id));
return Status::OK();
}
@@ -403,12 +426,15 @@ private:
jmethodID executor_ctor_id;
jmethodID executor_add_id;
+ jmethodID executor_add_batch_id;
jmethodID executor_merge_id;
jmethodID executor_serialize_id;
jmethodID executor_result_id;
jmethodID executor_reset_id;
jmethodID executor_close_id;
jmethodID executor_destroy_id;
+ jmethodID executor_convert_basic_argument_id;
+ jmethodID executor_convert_array_argument_id;
std::unique_ptr<int64_t[]> input_values_buffer_ptr;
std::unique_ptr<int64_t[]> input_nulls_buffer_ptr;
@@ -481,11 +507,10 @@ public:
void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
Arena*) const override {
- int64_t places_address[1];
- places_address[0] = reinterpret_cast<int64_t>(place);
- Status st =
- this->data(_exec_place)
- .add(places_address, true, columns, row_num, row_num +
1, argument_types);
+ int64_t places_address = reinterpret_cast<int64_t>(place);
+ Status st = this->data(_exec_place)
+ .add(places_address, true, columns, row_num,
row_num + 1,
+ argument_types, 0);
if (UNLIKELY(st != Status::OK())) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
}
@@ -493,25 +518,20 @@ public:
void add_batch(size_t batch_size, AggregateDataPtr* places, size_t
place_offset,
const IColumn** columns, Arena* /*arena*/, bool
/*agg_many*/) const override {
- int64_t places_address[batch_size];
- for (size_t i = 0; i < batch_size; ++i) {
- places_address[i] = reinterpret_cast<int64_t>(places[i] +
place_offset);
- }
+ int64_t places_address = reinterpret_cast<int64_t>(places);
Status st = this->data(_exec_place)
- .add(places_address, false, columns, 0,
batch_size, argument_types);
+ .add(places_address, false, columns, 0,
batch_size, argument_types,
+ place_offset);
if (UNLIKELY(st != Status::OK())) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
}
}
- // TODO: Here we calling method by jni, And if we get a thrown from FE,
- // But can't let user known the error, only return directly and output
error to log file.
void add_batch_single_place(size_t batch_size, AggregateDataPtr place,
const IColumn** columns,
Arena* /*arena*/) const override {
- int64_t places_address[1];
- places_address[0] = reinterpret_cast<int64_t>(place);
+ int64_t places_address = reinterpret_cast<int64_t>(place);
Status st = this->data(_exec_place)
- .add(places_address, true, columns, 0, batch_size,
argument_types);
+ .add(places_address, true, columns, 0, batch_size,
argument_types, 0);
if (UNLIKELY(st != Status::OK())) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
}
@@ -522,11 +542,10 @@ public:
Arena* arena) const override {
frame_start = std::max<int64_t>(frame_start, partition_start);
frame_end = std::min<int64_t>(frame_end, partition_end);
- int64_t places_address[1];
- places_address[0] = reinterpret_cast<int64_t>(place);
- Status st =
- this->data(_exec_place)
- .add(places_address, true, columns, frame_start,
frame_end, argument_types);
+ int64_t places_address = reinterpret_cast<int64_t>(place);
+ Status st = this->data(_exec_place)
+ .add(places_address, true, columns, frame_start,
frame_end,
+ argument_types, 0);
if (UNLIKELY(st != Status::OK())) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
}
diff --git a/fe/be-java-extensions/java-udf/pom.xml
b/fe/be-java-extensions/java-udf/pom.xml
index bf05aeafc2..67921aa2cf 100644
--- a/fe/be-java-extensions/java-udf/pom.xml
+++ b/fe/be-java-extensions/java-udf/pom.xml
@@ -41,6 +41,12 @@ under the License.
<artifactId>java-common</artifactId>
<version>${project.version}</version>
</dependency>
+ <!--
https://mvnrepository.com/artifact/com.esotericsoftware/reflectasm -->
+ <dependency>
+ <groupId>com.esotericsoftware</groupId>
+ <artifactId>reflectasm</artifactId>
+ <version>1.11.9</version>
+ </dependency>
</dependencies>
<build>
diff --git
a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java
b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java
index 3dbe10ca27..ef405197d6 100644
---
a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java
+++
b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java
@@ -25,12 +25,14 @@ import org.apache.doris.common.jni.utils.UdfUtils;
import org.apache.doris.common.jni.utils.UdfUtils.JavaUdfDataType;
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
+import com.google.common.base.Preconditions;
import org.apache.log4j.Logger;
import org.apache.thrift.TDeserializer;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import java.io.IOException;
+import java.lang.reflect.Array;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;
@@ -1021,4 +1023,183 @@ public abstract class BaseExecutor {
protected void updateOutputOffset(long offset) {
}
+
+ public Object[] convertBasicArg(boolean isUdf, int argIdx, boolean
isNullable, int rowStart, int rowEnd,
+ long nullMapAddr, long columnAddr, long strOffsetAddr) {
+ switch (argTypes[argIdx]) {
+ case BOOLEAN:
+ return UdfConvert.convertBooleanArg(isNullable, rowStart,
rowEnd, nullMapAddr, columnAddr);
+ case TINYINT:
+ return UdfConvert.convertTinyIntArg(isNullable, rowStart,
rowEnd, nullMapAddr, columnAddr);
+ case SMALLINT:
+ return UdfConvert.convertSmallIntArg(isNullable, rowStart,
rowEnd, nullMapAddr, columnAddr);
+ case INT:
+ return UdfConvert.convertIntArg(isNullable, rowStart, rowEnd,
nullMapAddr, columnAddr);
+ case BIGINT:
+ return UdfConvert.convertBigIntArg(isNullable, rowStart,
rowEnd, nullMapAddr, columnAddr);
+ case LARGEINT:
+ return UdfConvert.convertLargeIntArg(isNullable, rowStart,
rowEnd, nullMapAddr, columnAddr);
+ case FLOAT:
+ return UdfConvert.convertFloatArg(isNullable, rowStart,
rowEnd, nullMapAddr, columnAddr);
+ case DOUBLE:
+ return UdfConvert.convertDoubleArg(isNullable, rowStart,
rowEnd, nullMapAddr, columnAddr);
+ case CHAR:
+ case VARCHAR:
+ case STRING:
+ return UdfConvert
+ .convertStringArg(isNullable, rowStart, rowEnd,
nullMapAddr, columnAddr, strOffsetAddr);
+ case DATE: // udaf maybe argClass[i + argClassOffset] need add +1
+ return UdfConvert
+ .convertDateArg(isUdf ? argClass[argIdx] :
argClass[argIdx + 1], isNullable, rowStart, rowEnd,
+ nullMapAddr, columnAddr);
+ case DATETIME:
+ return UdfConvert
+ .convertDateTimeArg(isUdf ? argClass[argIdx] :
argClass[argIdx + 1], isNullable, rowStart,
+ rowEnd, nullMapAddr, columnAddr);
+ case DATEV2:
+ return UdfConvert
+ .convertDateV2Arg(isUdf ? argClass[argIdx] :
argClass[argIdx + 1], isNullable, rowStart, rowEnd,
+ nullMapAddr, columnAddr);
+ case DATETIMEV2:
+ return UdfConvert
+ .convertDateTimeV2Arg(isUdf ? argClass[argIdx] :
argClass[argIdx + 1], isNullable, rowStart,
+ rowEnd, nullMapAddr, columnAddr);
+ case DECIMALV2:
+ case DECIMAL128:
+ return UdfConvert
+ .convertDecimalArg(argTypes[argIdx].getScale(), 16L,
isNullable, rowStart, rowEnd, nullMapAddr,
+ columnAddr);
+ case DECIMAL32:
+ return UdfConvert
+ .convertDecimalArg(argTypes[argIdx].getScale(), 4L,
isNullable, rowStart, rowEnd, nullMapAddr,
+ columnAddr);
+ case DECIMAL64:
+ return UdfConvert
+ .convertDecimalArg(argTypes[argIdx].getScale(), 8L,
isNullable, rowStart, rowEnd, nullMapAddr,
+ columnAddr);
+ default: {
+ LOG.info("Not support type: " + argTypes[argIdx].toString());
+ Preconditions.checkState(false, "Not support type: " +
argTypes[argIdx].toString());
+ break;
+ }
+ }
+ return null;
+ }
+
+ public Object[] convertArrayArg(int argIdx, boolean isNullable, int
rowStart, int rowEnd, long nullMapAddr,
+ long offsetsAddr, long nestedNullMapAddr, long dataAddr, long
strOffsetAddr) {
+ Object[] argument = (Object[]) Array.newInstance(ArrayList.class,
rowEnd - rowStart);
+ for (int row = rowStart; row < rowEnd; ++row) {
+ long offsetStart = UdfUtils.UNSAFE.getLong(null, offsetsAddr + 8L
* (row - 1));
+ long offsetEnd = UdfUtils.UNSAFE.getLong(null, offsetsAddr + 8L *
(row));
+ int currentRowNum = (int) (offsetEnd - offsetStart);
+ switch (argTypes[argIdx].getItemType().getPrimitiveType()) {
+ case BOOLEAN: {
+ argument[row - rowStart] = UdfConvert
+ .convertArrayBooleanArg(row, currentRowNum,
offsetStart, isNullable, nullMapAddr,
+ nestedNullMapAddr, dataAddr);
+ break;
+ }
+ case TINYINT: {
+ argument[row - rowStart] = UdfConvert
+ .convertArrayTinyIntArg(row, currentRowNum,
offsetStart, isNullable, nullMapAddr,
+ nestedNullMapAddr, dataAddr);
+ break;
+ }
+ case SMALLINT: {
+ argument[row - rowStart] = UdfConvert
+ .convertArraySmallIntArg(row, currentRowNum,
offsetStart, isNullable, nullMapAddr,
+ nestedNullMapAddr, dataAddr);
+ break;
+ }
+ case INT: {
+ argument[row - rowStart] = UdfConvert
+ .convertArrayIntArg(row, currentRowNum,
offsetStart, isNullable, nullMapAddr,
+ nestedNullMapAddr, dataAddr);
+ break;
+ }
+ case BIGINT: {
+ argument[row - rowStart] = UdfConvert
+ .convertArrayBigIntArg(row, currentRowNum,
offsetStart, isNullable, nullMapAddr,
+ nestedNullMapAddr, dataAddr);
+ break;
+ }
+ case LARGEINT: {
+ argument[row - rowStart] = UdfConvert
+ .convertArrayLargeIntArg(row, currentRowNum,
offsetStart, isNullable, nullMapAddr,
+ nestedNullMapAddr, dataAddr);
+ break;
+ }
+ case FLOAT: {
+ argument[row - rowStart] = UdfConvert
+ .convertArrayFloatArg(row, currentRowNum,
offsetStart, isNullable, nullMapAddr,
+ nestedNullMapAddr, dataAddr);
+ break;
+ }
+ case DOUBLE: {
+ argument[row - rowStart] = UdfConvert
+ .convertArrayDoubleArg(row, currentRowNum,
offsetStart, isNullable, nullMapAddr,
+ nestedNullMapAddr, dataAddr);
+ break;
+ }
+ case CHAR:
+ case VARCHAR:
+ case STRING: {
+ argument[row - rowStart] = UdfConvert
+ .convertArrayStringArg(row, currentRowNum,
offsetStart, isNullable, nullMapAddr,
+ nestedNullMapAddr, dataAddr,
strOffsetAddr);
+ break;
+ }
+ case DATE: {
+ argument[row - rowStart] = UdfConvert
+ .convertArrayDateArg(row, currentRowNum,
offsetStart, isNullable, nullMapAddr,
+ nestedNullMapAddr, dataAddr);
+ break;
+ }
+ case DATETIME: {
+ argument[row - rowStart] = UdfConvert
+ .convertArrayDateTimeArg(row, currentRowNum,
offsetStart, isNullable, nullMapAddr,
+ nestedNullMapAddr, dataAddr);
+ break;
+ }
+ case DATEV2: {
+ argument[row - rowStart] = UdfConvert
+ .convertArrayDateV2Arg(row, currentRowNum,
offsetStart, isNullable, nullMapAddr,
+ nestedNullMapAddr, dataAddr);
+ break;
+ }
+ case DATETIMEV2: {
+ argument[row - rowStart] = UdfConvert
+ .convertArrayDateTimeV2Arg(row, currentRowNum,
offsetStart, isNullable,
+ nullMapAddr, nestedNullMapAddr, dataAddr);
+ break;
+ }
+ case DECIMALV2:
+ case DECIMAL128: {
+ argument[row - rowStart] = UdfConvert
+
.convertArrayDecimalArg(argTypes[argIdx].getScale(), 16L, row, currentRowNum,
+ offsetStart, isNullable, nullMapAddr,
nestedNullMapAddr, dataAddr);
+ break;
+ }
+ case DECIMAL32: {
+ argument[row - rowStart] = UdfConvert
+
.convertArrayDecimalArg(argTypes[argIdx].getScale(), 4L, row, currentRowNum,
+ offsetStart, isNullable, nullMapAddr,
nestedNullMapAddr, dataAddr);
+ break;
+ }
+ case DECIMAL64: {
+ argument[row - rowStart] = UdfConvert
+
.convertArrayDecimalArg(argTypes[argIdx].getScale(), 8L, row, currentRowNum,
+ offsetStart, isNullable, nullMapAddr,
nestedNullMapAddr, dataAddr);
+ break;
+ }
+ default: {
+ LOG.info("Not support: " + argTypes[argIdx]);
+ Preconditions.checkState(false, "Not support type " +
argTypes[argIdx].toString());
+ break;
+ }
+ }
+ }
+ return argument;
+ }
}
diff --git
a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
index e2d8ab1b75..a0736b5a72 100644
---
a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
+++
b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
@@ -24,6 +24,7 @@ import org.apache.doris.common.jni.utils.UdfUtils;
import org.apache.doris.common.jni.utils.UdfUtils.JavaUdfDataType;
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
+import com.esotericsoftware.reflectasm.MethodAccess;
import com.google.common.base.Joiner;
import com.google.common.collect.Lists;
import org.apache.log4j.Logger;
@@ -36,6 +37,7 @@ import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
/**
@@ -49,6 +51,8 @@ public class UdafExecutor extends BaseExecutor {
private HashMap<String, Method> allMethods;
private HashMap<Long, Object> stateObjMap;
private Class retClass;
+ private int addIndex;
+ private MethodAccess methodAccess;
/**
* Constructor to create an object.
@@ -66,6 +70,84 @@ public class UdafExecutor extends BaseExecutor {
super.close();
}
+ public Object[] convertBasicArguments(int argIdx, boolean isNullable, int
rowStart, int rowEnd, long nullMapAddr,
+ long columnAddr, long strOffsetAddr) {
+ return convertBasicArg(false, argIdx, isNullable, rowStart, rowEnd,
nullMapAddr, columnAddr, strOffsetAddr);
+ }
+
+ public Object[] convertArrayArguments(int argIdx, boolean isNullable, int
rowStart, int rowEnd, long nullMapAddr,
+ long offsetsAddr, long nestedNullMapAddr, long dataAddr, long
strOffsetAddr) {
+ return convertArrayArg(argIdx, isNullable, rowStart, rowEnd,
nullMapAddr, offsetsAddr, nestedNullMapAddr,
+ dataAddr, strOffsetAddr);
+ }
+
+ public void addBatch(boolean isSinglePlace, int rowStart, int rowEnd, long
placeAddr, int offset, Object[] column)
+ throws UdfRuntimeException {
+ if (isSinglePlace) {
+ addBatchSingle(rowStart, rowEnd, placeAddr, column);
+ } else {
+ addBatchPlaces(rowStart, rowEnd, placeAddr, offset, column);
+ }
+ }
+
+ public void addBatchSingle(int rowStart, int rowEnd, long placeAddr,
Object[] column) throws UdfRuntimeException {
+ try {
+ Long curPlace = placeAddr;
+ Object[] inputArgs = new Object[argTypes.length + 1];
+ Object state = stateObjMap.get(curPlace);
+ if (state != null) {
+ inputArgs[0] = state;
+ } else {
+ Object newState = createAggState();
+ stateObjMap.put(curPlace, newState);
+ inputArgs[0] = newState;
+ }
+
+ Object[][] inputs = (Object[][]) column;
+ for (int i = 0; i < (rowEnd - rowStart); ++i) {
+ for (int j = 0; j < column.length; ++j) {
+ inputArgs[j + 1] = inputs[j][i];
+ }
+ methodAccess.invoke(udf, addIndex, inputArgs);
+ }
+ } catch (Exception e) {
+ LOG.warn("invoke add function meet some error: " +
e.getCause().toString());
+ throw new UdfRuntimeException("UDAF failed to addBatchSingle: ",
e);
+ }
+ }
+
+ public void addBatchPlaces(int rowStart, int rowEnd, long placeAddr, int
offset, Object[] column)
+ throws UdfRuntimeException {
+ try {
+ Object[][] inputs = (Object[][]) column;
+ ArrayList<Object> placeState = new ArrayList<>(rowEnd - rowStart);
+ for (int row = rowStart; row < rowEnd; ++row) {
+ Long curPlace = UdfUtils.UNSAFE.getLong(null, placeAddr + (8L
* row)) + offset;
+ Object state = stateObjMap.get(curPlace);
+ if (state != null) {
+ placeState.add(state);
+ } else {
+ Object newState = createAggState();
+ stateObjMap.put(curPlace, newState);
+ placeState.add(newState);
+ }
+ }
+ //spilt into two for loop
+
+ Object[] inputArgs = new Object[argTypes.length + 1];
+ for (int row = 0; row < (rowEnd - rowStart); ++row) {
+ inputArgs[0] = placeState.get(row);
+ for (int j = 0; j < column.length; ++j) {
+ inputArgs[j + 1] = inputs[j][row];
+ }
+ methodAccess.invoke(udf, addIndex, inputArgs);
+ }
+ } catch (Exception e) {
+ LOG.warn("invoke add function meet some error: " +
Arrays.toString(e.getStackTrace()));
+ throw new UdfRuntimeException("UDAF failed to addBatchPlaces: ",
e);
+ }
+ }
+
/**
* invoke add function, add row in loop [rowStart, rowEnd).
*/
@@ -224,10 +306,10 @@ public class UdafExecutor extends BaseExecutor {
protected long getCurrentOutputOffset(long row, boolean isArrayType) {
if (isArrayType) {
return Integer.toUnsignedLong(
- UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null,
outputOffsetsPtr) + 8L * (row - 1)));
+ UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null,
outputOffsetsPtr) + 8L * (row - 1)));
} else {
return Integer.toUnsignedLong(
- UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null,
outputOffsetsPtr) + 4L * (row - 1)));
+ UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null,
outputOffsetsPtr) + 4L * (row - 1)));
}
}
@@ -251,6 +333,7 @@ public class UdafExecutor extends BaseExecutor {
loader = ClassLoader.getSystemClassLoader();
}
Class<?> c = Class.forName(className, true, loader);
+ methodAccess = MethodAccess.get(c);
Constructor<?> ctor = c.getConstructor();
udf = ctor.newInstance();
Method[] methods = c.getDeclaredMethods();
@@ -281,7 +364,7 @@ public class UdafExecutor extends BaseExecutor {
}
case UDAF_ADD_FUNCTION: {
allMethods.put(methods[idx].getName(), methods[idx]);
-
+ addIndex = methodAccess.getIndex(UDAF_ADD_FUNCTION);
argClass = methods[idx].getParameterTypes();
if (argClass.length != parameterTypes.length + 1) {
LOG.debug("add function parameterTypes length not
equal " + argClass.length + " "
diff --git
a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfConvert.java
b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfConvert.java
index 4519b23a54..fb2ead5a3f 100644
---
a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfConvert.java
+++
b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfConvert.java
@@ -37,263 +37,269 @@ import java.util.Arrays;
public class UdfConvert {
private static final Logger LOG = Logger.getLogger(UdfConvert.class);
- public static Object[] convertBooleanArg(boolean isNullable, int numRows,
long nullMapAddr, long columnAddr) {
- Boolean[] argument = new Boolean[numRows];
+ public static Object[] convertBooleanArg(boolean isNullable, int
rowsStart, int rowsEnd, long nullMapAddr,
+ long columnAddr) {
+ Boolean[] argument = new Boolean[rowsEnd - rowsStart];
if (isNullable) {
- for (int i = 0; i < numRows; ++i) {
+ for (int i = rowsStart; i < rowsEnd; ++i) {
if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) {
- argument[i] = UdfUtils.UNSAFE.getBoolean(null, columnAddr
+ i);
+ argument[i - rowsStart] = UdfUtils.UNSAFE.getBoolean(null,
columnAddr + i);
} // else is the current row is null
}
} else {
- for (int i = 0; i < numRows; ++i) {
- argument[i] = UdfUtils.UNSAFE.getBoolean(null, columnAddr + i);
+ for (int i = rowsStart; i < rowsEnd; ++i) {
+ argument[i - rowsStart] = UdfUtils.UNSAFE.getBoolean(null,
columnAddr + i);
}
}
return argument;
}
- public static Object[] convertTinyIntArg(boolean isNullable, int numRows,
long nullMapAddr, long columnAddr) {
- Byte[] argument = new Byte[numRows];
+ public static Object[] convertTinyIntArg(boolean isNullable, int
rowsStart, int rowsEnd, long nullMapAddr,
+ long columnAddr) {
+ Byte[] argument = new Byte[rowsEnd - rowsStart];
if (isNullable) {
- for (int i = 0; i < numRows; ++i) {
+ for (int i = rowsStart; i < rowsEnd; ++i) {
if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) {
- argument[i] = UdfUtils.UNSAFE.getByte(null, columnAddr +
i);
+ argument[i - rowsStart] = UdfUtils.UNSAFE.getByte(null,
columnAddr + i);
} // else is the current row is null
}
} else {
- for (int i = 0; i < numRows; ++i) {
- argument[i] = UdfUtils.UNSAFE.getByte(null, columnAddr + i);
+ for (int i = rowsStart; i < rowsEnd; ++i) {
+ argument[i - rowsStart] = UdfUtils.UNSAFE.getByte(null,
columnAddr + i);
}
}
return argument;
}
- public static Object[] convertSmallIntArg(boolean isNullable, int numRows,
long nullMapAddr, long columnAddr) {
- Short[] argument = new Short[numRows];
+ public static Object[] convertSmallIntArg(boolean isNullable, int
rowsStart, int rowsEnd, long nullMapAddr,
+ long columnAddr) {
+ Short[] argument = new Short[rowsEnd - rowsStart];
if (isNullable) {
- for (int i = 0; i < numRows; ++i) {
+ for (int i = rowsStart; i < rowsEnd; ++i) {
if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) {
- argument[i] = UdfUtils.UNSAFE.getShort(null, columnAddr +
(i * 2L));
+ argument[i - rowsStart] = UdfUtils.UNSAFE.getShort(null,
columnAddr + (i * 2L));
} // else is the current row is null
}
} else {
- for (int i = 0; i < numRows; ++i) {
- argument[i] = UdfUtils.UNSAFE.getShort(null, columnAddr + (i *
2L));
+ for (int i = rowsStart; i < rowsEnd; ++i) {
+ argument[i - rowsStart] = UdfUtils.UNSAFE.getShort(null,
columnAddr + (i * 2L));
}
}
return argument;
}
- public static Object[] convertIntArg(boolean isNullable, int numRows, long
nullMapAddr, long columnAddr) {
- Integer[] argument = new Integer[numRows];
+ public static Object[] convertIntArg(boolean isNullable, int rowsStart,
int rowsEnd, long nullMapAddr,
+ long columnAddr) {
+ Integer[] argument = new Integer[rowsEnd - rowsStart];
if (isNullable) {
- for (int i = 0; i < numRows; ++i) {
+ for (int i = rowsStart; i < rowsEnd; ++i) {
if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) {
- argument[i] = UdfUtils.UNSAFE.getInt(null, columnAddr + (i
* 4L));
+ argument[i - rowsStart] = UdfUtils.UNSAFE.getInt(null,
columnAddr + (i * 4L));
} // else is the current row is null
}
} else {
- for (int i = 0; i < numRows; ++i) {
- argument[i] = UdfUtils.UNSAFE.getInt(null, columnAddr + (i *
4L));
+ for (int i = rowsStart; i < rowsEnd; ++i) {
+ argument[i - rowsStart] = UdfUtils.UNSAFE.getInt(null,
columnAddr + (i * 4L));
}
}
return argument;
}
- public static Object[] convertBigIntArg(boolean isNullable, int numRows,
long nullMapAddr, long columnAddr) {
- Long[] argument = new Long[numRows];
+ public static Object[] convertBigIntArg(boolean isNullable, int rowsStart,
int rowsEnd, long nullMapAddr,
+ long columnAddr) {
+ Long[] argument = new Long[rowsEnd - rowsStart];
if (isNullable) {
- for (int i = 0; i < numRows; ++i) {
+ for (int i = rowsStart; i < rowsEnd; ++i) {
if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) {
- argument[i] = UdfUtils.UNSAFE.getLong(null, columnAddr +
(i * 8L));
+ argument[i - rowsStart] = UdfUtils.UNSAFE.getLong(null,
columnAddr + (i * 8L));
} // else is the current row is null
}
} else {
- for (int i = 0; i < numRows; ++i) {
- argument[i] = UdfUtils.UNSAFE.getLong(null, columnAddr + (i *
8L));
+ for (int i = rowsStart; i < rowsEnd; ++i) {
+ argument[i - rowsStart] = UdfUtils.UNSAFE.getLong(null,
columnAddr + (i * 8L));
}
}
return argument;
}
- public static Object[] convertFloatArg(boolean isNullable, int numRows,
long nullMapAddr, long columnAddr) {
- Float[] argument = new Float[numRows];
+ public static Object[] convertFloatArg(boolean isNullable, int rowsStart,
int rowsEnd, long nullMapAddr,
+ long columnAddr) {
+ Float[] argument = new Float[rowsEnd - rowsStart];
if (isNullable) {
- for (int i = 0; i < numRows; ++i) {
+ for (int i = rowsStart; i < rowsEnd; ++i) {
if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) {
- argument[i] = UdfUtils.UNSAFE.getFloat(null, columnAddr +
(i * 4L));
+ argument[i - rowsStart] = UdfUtils.UNSAFE.getFloat(null,
columnAddr + (i * 4L));
} // else is the current row is null
}
} else {
- for (int i = 0; i < numRows; ++i) {
- argument[i] = UdfUtils.UNSAFE.getFloat(null, columnAddr + (i *
4L));
+ for (int i = rowsStart; i < rowsEnd; ++i) {
+ argument[i - rowsStart] = UdfUtils.UNSAFE.getFloat(null,
columnAddr + (i * 4L));
}
}
return argument;
}
- public static Object[] convertDoubleArg(boolean isNullable, int numRows,
long nullMapAddr, long columnAddr) {
- Double[] argument = new Double[numRows];
+ public static Object[] convertDoubleArg(boolean isNullable, int rowsStart,
int rowsEnd, long nullMapAddr,
+ long columnAddr) {
+ Double[] argument = new Double[rowsEnd - rowsStart];
if (isNullable) {
- for (int i = 0; i < numRows; ++i) {
+ for (int i = rowsStart; i < rowsEnd; ++i) {
if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) {
- argument[i] = UdfUtils.UNSAFE.getDouble(null, columnAddr +
(i * 8L));
+ argument[i - rowsStart] = UdfUtils.UNSAFE.getDouble(null,
columnAddr + (i * 8L));
} // else is the current row is null
}
} else {
- for (int i = 0; i < numRows; ++i) {
- argument[i] = UdfUtils.UNSAFE.getDouble(null, columnAddr + (i
* 8L));
+ for (int i = rowsStart; i < rowsEnd; ++i) {
+ argument[i - rowsStart] = UdfUtils.UNSAFE.getDouble(null,
columnAddr + (i * 8L));
}
}
return argument;
}
- public static Object[] convertDateArg(Class argTypeClass, boolean
isNullable, int numRows, long nullMapAddr,
- long columnAddr) {
- Object[] argument = (Object[]) Array.newInstance(argTypeClass,
numRows);
+ public static Object[] convertDateArg(Class argTypeClass, boolean
isNullable, int rowsStart, int rowsEnd,
+ long nullMapAddr, long columnAddr) {
+ Object[] argument = (Object[]) Array.newInstance(argTypeClass, rowsEnd
- rowsStart);
if (isNullable) {
- for (int i = 0; i < numRows; ++i) {
+ for (int i = rowsStart; i < rowsEnd; ++i) {
if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) {
long value = UdfUtils.UNSAFE.getLong(null, columnAddr + (i
* 8L));
- argument[i] = UdfUtils.convertDateToJavaDate(value,
argTypeClass);
+ argument[i - rowsStart] =
UdfUtils.convertDateToJavaDate(value, argTypeClass);
} // else is the current row is null
}
} else {
- for (int i = 0; i < numRows; ++i) {
+ for (int i = rowsStart; i < rowsEnd; ++i) {
long value = UdfUtils.UNSAFE.getLong(null, columnAddr + (i *
8L));
- argument[i] = UdfUtils.convertDateToJavaDate(value,
argTypeClass);
+ argument[i - rowsStart] =
UdfUtils.convertDateToJavaDate(value, argTypeClass);
}
}
return argument;
}
- public static Object[] convertDateTimeArg(Class argTypeClass, boolean
isNullable, int numRows, long nullMapAddr,
- long columnAddr) {
- Object[] argument = (Object[]) Array.newInstance(argTypeClass,
numRows);
+ public static Object[] convertDateTimeArg(Class argTypeClass, boolean
isNullable, int rowsStart, int rowsEnd,
+ long nullMapAddr, long columnAddr) {
+ Object[] argument = (Object[]) Array.newInstance(argTypeClass, rowsEnd
- rowsStart);
if (isNullable) {
- for (int i = 0; i < numRows; ++i) {
+ for (int i = rowsStart; i < rowsEnd; ++i) {
if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) {
long value = UdfUtils.UNSAFE.getLong(null, columnAddr + (i
* 8L));
- argument[i] = UdfUtils
+ argument[i - rowsStart] = UdfUtils
.convertDateTimeToJavaDateTime(value,
argTypeClass);
} // else is the current row is null
}
} else {
- for (int i = 0; i < numRows; ++i) {
+ for (int i = rowsStart; i < rowsEnd; ++i) {
long value = UdfUtils.UNSAFE.getLong(null, columnAddr + (i *
8L));
- argument[i] = UdfUtils.convertDateTimeToJavaDateTime(value,
argTypeClass);
+ argument[i - rowsStart] =
UdfUtils.convertDateTimeToJavaDateTime(value, argTypeClass);
}
}
return argument;
}
- public static Object[] convertDateV2Arg(Class argTypeClass, boolean
isNullable, int numRows, long nullMapAddr,
- long columnAddr) {
- Object[] argument = (Object[]) Array.newInstance(argTypeClass,
numRows);
+ public static Object[] convertDateV2Arg(Class argTypeClass, boolean
isNullable, int rowsStart, int rowsEnd,
+ long nullMapAddr, long columnAddr) {
+ Object[] argument = (Object[]) Array.newInstance(argTypeClass, rowsEnd
- rowsStart);
if (isNullable) {
- for (int i = 0; i < numRows; ++i) {
+ for (int i = rowsStart; i < rowsEnd; ++i) {
if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) {
int value = UdfUtils.UNSAFE.getInt(null, columnAddr + (i *
4L));
- argument[i] = UdfUtils.convertDateV2ToJavaDate(value,
argTypeClass);
+ argument[i - rowsStart] =
UdfUtils.convertDateV2ToJavaDate(value, argTypeClass);
} // else is the current row is null
}
} else {
- for (int i = 0; i < numRows; ++i) {
+ for (int i = rowsStart; i < rowsEnd; ++i) {
int value = UdfUtils.UNSAFE.getInt(null, columnAddr + (i *
4L));
- argument[i] = UdfUtils.convertDateV2ToJavaDate(value,
argTypeClass);
+ argument[i - rowsStart] =
UdfUtils.convertDateV2ToJavaDate(value, argTypeClass);
}
}
return argument;
}
- public static Object[] convertDateTimeV2Arg(Class argTypeClass, boolean
isNullable, int numRows, long nullMapAddr,
- long columnAddr) {
- Object[] argument = (Object[]) Array.newInstance(argTypeClass,
numRows);
+ public static Object[] convertDateTimeV2Arg(Class argTypeClass, boolean
isNullable, int rowsStart, int rowsEnd,
+ long nullMapAddr, long columnAddr) {
+ Object[] argument = (Object[]) Array.newInstance(argTypeClass, rowsEnd
- rowsStart);
if (isNullable) {
- for (int i = 0; i < numRows; ++i) {
+ for (int i = rowsStart; i < rowsEnd; ++i) {
if (UdfUtils.UNSAFE.getByte(null, nullMapAddr + i) == 0) {
long value = UdfUtils.UNSAFE.getLong(columnAddr + (i *
8L));
- argument[i] = UdfUtils
+ argument[i - rowsStart] = UdfUtils
.convertDateTimeV2ToJavaDateTime(value,
argTypeClass);
} // else is the current row is null
}
} else {
- for (int i = 0; i < numRows; ++i) {
+ for (int i = rowsStart; i < rowsEnd; ++i) {
long value = UdfUtils.UNSAFE.getLong(null, columnAddr + (i *
8L));
- argument[i] = UdfUtils
+ argument[i - rowsStart] = UdfUtils
.convertDateTimeV2ToJavaDateTime(value, argTypeClass);
}
}
return argument;
}
- public static Object[] convertLargeIntArg(boolean isNullable, int numRows,
long nullMapAddr, long columnAddr) {
- BigInteger[] argument = new BigInteger[numRows];
+ public static Object[] convertLargeIntArg(boolean isNullable, int
rowsStart, int rowsEnd, long nullMapAddr,
+ long columnAddr) {
+ BigInteger[] argument = new BigInteger[rowsEnd - rowsStart];
byte[] bytes = new byte[16];
if (isNullable) {
- for (int i = 0; i < numRows; ++i) {
+ for (int i = rowsStart; i < rowsEnd; ++i) {
if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) {
UdfUtils.copyMemory(null, columnAddr + (i * 16L), bytes,
UdfUtils.BYTE_ARRAY_OFFSET, 16);
- argument[i] = new
BigInteger(UdfUtils.convertByteOrder(bytes));
+ argument[i - rowsStart] = new
BigInteger(UdfUtils.convertByteOrder(bytes));
} // else is the current row is null
}
} else {
- for (int i = 0; i < numRows; ++i) {
+ for (int i = rowsStart; i < rowsEnd; ++i) {
UdfUtils.copyMemory(null, columnAddr + (i * 16L), bytes,
UdfUtils.BYTE_ARRAY_OFFSET, 16);
- argument[i] = new BigInteger(UdfUtils.convertByteOrder(bytes));
+ argument[i - rowsStart] = new
BigInteger(UdfUtils.convertByteOrder(bytes));
}
}
return argument;
}
- public static Object[] convertDecimalArg(int scale, long typeLen, boolean
isNullable, int numRows, long nullMapAddr,
- long columnAddr) {
- BigDecimal[] argument = new BigDecimal[numRows];
+ public static Object[] convertDecimalArg(int scale, long typeLen, boolean
isNullable, int rowsStart, int rowsEnd,
+ long nullMapAddr, long columnAddr) {
+ BigDecimal[] argument = new BigDecimal[rowsEnd - rowsStart];
byte[] bytes = new byte[(int) typeLen];
if (isNullable) {
- for (int i = 0; i < numRows; ++i) {
+ for (int i = rowsStart; i < rowsEnd; ++i) {
if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) {
UdfUtils.copyMemory(null, columnAddr + (i * typeLen),
bytes, UdfUtils.BYTE_ARRAY_OFFSET, typeLen);
BigInteger bigInteger = new
BigInteger(UdfUtils.convertByteOrder(bytes));
- argument[i] = new BigDecimal(bigInteger, scale); //show to
pass scale info
+ argument[i - rowsStart] = new BigDecimal(bigInteger,
scale); //show to pass scale info
} // else is the current row is null
}
} else {
- for (int i = 0; i < numRows; ++i) {
+ for (int i = rowsStart; i < rowsEnd; ++i) {
UdfUtils.copyMemory(null, columnAddr + (i * typeLen), bytes,
UdfUtils.BYTE_ARRAY_OFFSET, typeLen);
BigInteger bigInteger = new
BigInteger(UdfUtils.convertByteOrder(bytes));
- argument[i] = new BigDecimal(bigInteger, scale);
+ argument[i - rowsStart] = new BigDecimal(bigInteger, scale);
}
}
return argument;
}
- public static Object[] convertStringArg(boolean isNullable, int numRows,
long nullMapAddr,
+ public static Object[] convertStringArg(boolean isNullable, int rowsStart,
int rowsEnd, long nullMapAddr,
long charsAddr, long offsetsAddr) {
- String[] argument = new String[numRows];
+ String[] argument = new String[rowsEnd - rowsStart];
Preconditions.checkState(UdfUtils.UNSAFE.getInt(null, offsetsAddr + 4L
* (0 - 1)) == 0,
"offsetsAddr[-1] should be 0;");
-
+ final int totalLen = UdfUtils.UNSAFE.getInt(null, offsetsAddr +
(rowsEnd - 1) * 4L);
+ byte[] bytes = new byte[totalLen];
+ UdfUtils.copyMemory(null, charsAddr, bytes,
UdfUtils.BYTE_ARRAY_OFFSET, totalLen);
if (isNullable) {
- for (int row = 0; row < numRows; ++row) {
+ for (int row = rowsStart; row < rowsEnd; ++row) {
if (UdfUtils.UNSAFE.getByte(nullMapAddr + row) == 0) {
- int offset = UdfUtils.UNSAFE.getInt(null, offsetsAddr +
row * 4L);
- int numBytes = offset - UdfUtils.UNSAFE.getInt(null,
offsetsAddr + 4L * (row - 1));
- long base = charsAddr + offset - numBytes;
- byte[] bytes = new byte[numBytes];
- UdfUtils.copyMemory(null, base, bytes,
UdfUtils.BYTE_ARRAY_OFFSET, numBytes);
- argument[row] = new String(bytes, StandardCharsets.UTF_8);
+ int prevOffset = UdfUtils.UNSAFE.getInt(null, offsetsAddr
+ 4L * (row - 1));
+ int currOffset = UdfUtils.UNSAFE.getInt(null, offsetsAddr
+ row * 4L);
+ argument[row - rowsStart] = new String(bytes, prevOffset,
currOffset - prevOffset,
+ StandardCharsets.UTF_8);
} // else is the current row is null
}
} else {
- for (int row = 0; row < numRows; ++row) {
- int offset = UdfUtils.UNSAFE.getInt(null, offsetsAddr + row *
4L);
- int numBytes = offset - UdfUtils.UNSAFE.getInt(null,
offsetsAddr + 4L * (row - 1));
- long base = charsAddr + offset - numBytes;
- byte[] bytes = new byte[numBytes];
- UdfUtils.copyMemory(null, base, bytes,
UdfUtils.BYTE_ARRAY_OFFSET, numBytes);
- argument[row] = new String(bytes, StandardCharsets.UTF_8);
+ for (int row = rowsStart; row < rowsEnd; ++row) {
+ int prevOffset = UdfUtils.UNSAFE.getInt(null, offsetsAddr + 4L
* (row - 1));
+ int currOffset = UdfUtils.UNSAFE.getInt(null, offsetsAddr + 4L
* row);
+ argument[row - rowsStart] = new String(bytes, prevOffset,
currOffset - prevOffset,
+ StandardCharsets.UTF_8);
}
}
return argument;
@@ -1314,7 +1320,7 @@ public class UdfConvert {
}
//////////////////////////////////////////convertArray///////////////////////////////////////////////////////////
- public static void convertArrayBooleanArg(Object[] argument, int row, int
currentRowNum, long offsetStart,
+ public static ArrayList<Boolean> convertArrayBooleanArg(int row, int
currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long
dataAddr) {
ArrayList<Boolean> data = null;
if (isNullable) {
@@ -1340,10 +1346,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
- argument[row] = data;
+ return data;
}
- public static void convertArrayTinyIntArg(Object[] argument, int row, int
currentRowNum, long offsetStart,
+ public static ArrayList<Byte> convertArrayTinyIntArg(int row, int
currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long
dataAddr) {
ArrayList<Byte> data = null;
if (isNullable) {
@@ -1369,10 +1375,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
- argument[row] = data;
+ return data;
}
- public static void convertArraySmallIntArg(Object[] argument, int row, int
currentRowNum, long offsetStart,
+ public static ArrayList<Short> convertArraySmallIntArg(int row, int
currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long
dataAddr) {
ArrayList<Short> data = null;
if (isNullable) {
@@ -1398,10 +1404,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
- argument[row] = data;
+ return data;
}
- public static void convertArrayIntArg(Object[] argument, int row, int
currentRowNum, long offsetStart,
+ public static ArrayList<Integer> convertArrayIntArg(int row, int
currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long
dataAddr) {
ArrayList<Integer> data = null;
if (isNullable) {
@@ -1427,10 +1433,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
- argument[row] = data;
+ return data;
}
- public static void convertArrayBigIntArg(Object[] argument, int row, int
currentRowNum, long offsetStart,
+ public static ArrayList<Long> convertArrayBigIntArg(int row, int
currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long
dataAddr) {
ArrayList<Long> data = null;
if (isNullable) {
@@ -1456,10 +1462,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
- argument[row] = data;
+ return data;
}
- public static void convertArrayFloatArg(Object[] argument, int row, int
currentRowNum, long offsetStart,
+ public static ArrayList<Float> convertArrayFloatArg(int row, int
currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long
dataAddr) {
ArrayList<Float> data = null;
if (isNullable) {
@@ -1485,10 +1491,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
- argument[row] = data;
+ return data;
}
- public static void convertArrayDoubleArg(Object[] argument, int row, int
currentRowNum, long offsetStart,
+ public static ArrayList<Double> convertArrayDoubleArg(int row, int
currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long
dataAddr) {
ArrayList<Double> data = null;
if (isNullable) {
@@ -1514,10 +1520,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
- argument[row] = data;
+ return data;
}
- public static void convertArrayDateArg(Object[] argument, int row, int
currentRowNum, long offsetStart,
+ public static ArrayList<LocalDate> convertArrayDateArg(int row, int
currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long
dataAddr) {
ArrayList<LocalDate> data = null;
if (isNullable) {
@@ -1549,10 +1555,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
- argument[row] = data;
+ return data;
}
- public static void convertArrayDateTimeArg(Object[] argument, int row, int
currentRowNum, long offsetStart,
+ public static ArrayList<LocalDateTime> convertArrayDateTimeArg(int row,
int currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long
dataAddr) {
ArrayList<LocalDateTime> data = null;
if (isNullable) {
@@ -1582,10 +1588,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
- argument[row] = data;
+ return data;
}
- public static void convertArrayDateV2Arg(Object[] argument, int row, int
currentRowNum, long offsetStart,
+ public static ArrayList<LocalDate> convertArrayDateV2Arg(int row, int
currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long
dataAddr) {
ArrayList<LocalDate> data = null;
if (isNullable) {
@@ -1613,10 +1619,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
- argument[row] = data;
+ return data;
}
- public static void convertArrayDateTimeV2Arg(Object[] argument, int row,
int currentRowNum, long offsetStart,
+ public static ArrayList<LocalDateTime> convertArrayDateTimeV2Arg(int row,
int currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long
dataAddr) {
ArrayList<LocalDateTime> data = null;
if (isNullable) {
@@ -1646,10 +1652,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
- argument[row] = data;
+ return data;
}
- public static void convertArrayLargeIntArg(Object[] argument, int row, int
currentRowNum, long offsetStart,
+ public static ArrayList<BigInteger> convertArrayLargeIntArg(int row, int
currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long
dataAddr) {
ArrayList<BigInteger> data = null;
byte[] bytes = new byte[16];
@@ -1678,10 +1684,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
- argument[row] = data;
+ return data;
}
- public static void convertArrayDecimalArg(int scale, long typeLen,
Object[] argument, int row, int currentRowNum,
+ public static ArrayList<BigDecimal> convertArrayDecimalArg(int scale, long
typeLen, int row, int currentRowNum,
long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long
dataAddr) {
ArrayList<BigDecimal> data = null;
@@ -1713,10 +1719,10 @@ public class UdfConvert {
}
} // for loop
} // end for all current row
- argument[row] = data;
+ return data;
}
- public static void convertArrayStringArg(Object[] argument, int row, int
currentRowNum, long offsetStart,
+ public static ArrayList<String> convertArrayStringArg(int row, int
currentRowNum, long offsetStart,
boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long
dataAddr, long strOffsetAddr) {
ArrayList<String> data = null;
if (isNullable) {
@@ -1755,6 +1761,6 @@ public class UdfConvert {
}
}
}
- argument[row] = data;
+ return data;
}
}
diff --git
a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
index f1993ec488..50528d007b 100644
---
a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
+++
b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
@@ -24,6 +24,7 @@ import org.apache.doris.common.jni.utils.UdfUtils;
import org.apache.doris.common.jni.utils.UdfUtils.JavaUdfDataType;
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
+import com.esotericsoftware.reflectasm.MethodAccess;
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
@@ -50,6 +51,8 @@ public class UdfExecutor extends BaseExecutor {
private long rowIdx;
private long batchSizePtr;
+ private int evaluateIndex;
+ private MethodAccess methodAccess;
/**
* Create a UdfExecutor, using parameters from a serialized thrift object.
Used by
@@ -113,166 +116,14 @@ public class UdfExecutor extends BaseExecutor {
public Object[] convertBasicArguments(int argIdx, boolean isNullable, int
numRows, long nullMapAddr,
long columnAddr, long strOffsetAddr) {
- switch (argTypes[argIdx]) {
- case BOOLEAN:
- return UdfConvert.convertBooleanArg(isNullable, numRows,
nullMapAddr, columnAddr);
- case TINYINT:
- return UdfConvert.convertTinyIntArg(isNullable, numRows,
nullMapAddr, columnAddr);
- case SMALLINT:
- return UdfConvert.convertSmallIntArg(isNullable, numRows,
nullMapAddr, columnAddr);
- case INT:
- return UdfConvert.convertIntArg(isNullable, numRows,
nullMapAddr, columnAddr);
- case BIGINT:
- return UdfConvert.convertBigIntArg(isNullable, numRows,
nullMapAddr, columnAddr);
- case LARGEINT:
- return UdfConvert.convertLargeIntArg(isNullable, numRows,
nullMapAddr, columnAddr);
- case FLOAT:
- return UdfConvert.convertFloatArg(isNullable, numRows,
nullMapAddr, columnAddr);
- case DOUBLE:
- return UdfConvert.convertDoubleArg(isNullable, numRows,
nullMapAddr, columnAddr);
- case CHAR:
- case VARCHAR:
- case STRING:
- return UdfConvert.convertStringArg(isNullable, numRows,
nullMapAddr, columnAddr, strOffsetAddr);
- case DATE: // udaf maybe argClass[i + argClassOffset] need add +1
- return UdfConvert.convertDateArg(argClass[argIdx], isNullable,
numRows, nullMapAddr, columnAddr);
- case DATETIME:
- return UdfConvert.convertDateTimeArg(argClass[argIdx],
isNullable, numRows, nullMapAddr, columnAddr);
- case DATEV2:
- return UdfConvert.convertDateV2Arg(argClass[argIdx],
isNullable, numRows, nullMapAddr, columnAddr);
- case DATETIMEV2:
- return UdfConvert.convertDateTimeV2Arg(argClass[argIdx],
isNullable, numRows, nullMapAddr, columnAddr);
- case DECIMALV2:
- case DECIMAL128:
- return
UdfConvert.convertDecimalArg(argTypes[argIdx].getScale(), 16L, isNullable,
numRows, nullMapAddr,
- columnAddr);
- case DECIMAL32:
- return
UdfConvert.convertDecimalArg(argTypes[argIdx].getScale(), 4L, isNullable,
numRows, nullMapAddr,
- columnAddr);
- case DECIMAL64:
- return
UdfConvert.convertDecimalArg(argTypes[argIdx].getScale(), 8L, isNullable,
numRows, nullMapAddr,
- columnAddr);
- default: {
- LOG.info("Not support type: " + argTypes[argIdx].toString());
- Preconditions.checkState(false, "Not support type: " +
argTypes[argIdx].toString());
- break;
- }
- }
- return null;
+ return convertBasicArg(true, argIdx, isNullable, 0, numRows,
nullMapAddr, columnAddr, strOffsetAddr);
}
public Object[] convertArrayArguments(int argIdx, boolean isNullable, int
numRows, long nullMapAddr,
long offsetsAddr, long nestedNullMapAddr, long dataAddr, long
strOffsetAddr) {
- Object[] argument = (Object[]) Array.newInstance(ArrayList.class,
numRows);
- for (int row = 0; row < numRows; ++row) {
- long offsetStart = UdfUtils.UNSAFE.getLong(null, offsetsAddr + 8L
* (row - 1));
- long offsetEnd = UdfUtils.UNSAFE.getLong(null, offsetsAddr + 8L *
(row));
- int currentRowNum = (int) (offsetEnd - offsetStart);
- switch (argTypes[argIdx].getItemType().getPrimitiveType()) {
- case BOOLEAN: {
- UdfConvert
- .convertArrayBooleanArg(argument, row,
currentRowNum, offsetStart, isNullable, nullMapAddr,
- nestedNullMapAddr, dataAddr);
- break;
- }
- case TINYINT: {
- UdfConvert
- .convertArrayTinyIntArg(argument, row,
currentRowNum, offsetStart, isNullable, nullMapAddr,
- nestedNullMapAddr, dataAddr);
- break;
- }
- case SMALLINT: {
- UdfConvert
- .convertArraySmallIntArg(argument, row,
currentRowNum, offsetStart, isNullable, nullMapAddr,
- nestedNullMapAddr, dataAddr);
- break;
- }
- case INT: {
- UdfConvert.convertArrayIntArg(argument, row,
currentRowNum, offsetStart, isNullable, nullMapAddr,
- nestedNullMapAddr, dataAddr);
- break;
- }
- case BIGINT: {
- UdfConvert.convertArrayBigIntArg(argument, row,
currentRowNum, offsetStart, isNullable, nullMapAddr,
- nestedNullMapAddr, dataAddr);
- break;
- }
- case LARGEINT: {
- UdfConvert
- .convertArrayLargeIntArg(argument, row,
currentRowNum, offsetStart, isNullable, nullMapAddr,
- nestedNullMapAddr, dataAddr);
- break;
- }
- case FLOAT: {
- UdfConvert.convertArrayFloatArg(argument, row,
currentRowNum, offsetStart, isNullable, nullMapAddr,
- nestedNullMapAddr, dataAddr);
- break;
- }
- case DOUBLE: {
- UdfConvert.convertArrayDoubleArg(argument, row,
currentRowNum, offsetStart, isNullable, nullMapAddr,
- nestedNullMapAddr, dataAddr);
- break;
- }
- case CHAR:
- case VARCHAR:
- case STRING: {
- UdfConvert.convertArrayStringArg(argument, row,
currentRowNum, offsetStart, isNullable, nullMapAddr,
- nestedNullMapAddr, dataAddr, strOffsetAddr);
- break;
- }
- case DATE: {
- UdfConvert.convertArrayDateArg(argument, row,
currentRowNum, offsetStart, isNullable, nullMapAddr,
- nestedNullMapAddr, dataAddr);
- break;
- }
- case DATETIME: {
- UdfConvert
- .convertArrayDateTimeArg(argument, row,
currentRowNum, offsetStart, isNullable, nullMapAddr,
- nestedNullMapAddr, dataAddr);
- break;
- }
- case DATEV2: {
- UdfConvert.convertArrayDateV2Arg(argument, row,
currentRowNum, offsetStart, isNullable, nullMapAddr,
- nestedNullMapAddr, dataAddr);
- break;
- }
- case DATETIMEV2: {
- UdfConvert.convertArrayDateTimeV2Arg(argument, row,
currentRowNum, offsetStart, isNullable,
- nullMapAddr,
- nestedNullMapAddr, dataAddr);
- break;
- }
- case DECIMALV2:
- case DECIMAL128: {
-
UdfConvert.convertArrayDecimalArg(argTypes[argIdx].getScale(), 16L, argument,
row, currentRowNum,
- offsetStart, isNullable,
- nullMapAddr,
- nestedNullMapAddr, dataAddr);
- break;
- }
- case DECIMAL32: {
-
UdfConvert.convertArrayDecimalArg(argTypes[argIdx].getScale(), 4L, argument,
row, currentRowNum,
- offsetStart, isNullable,
- nullMapAddr,
- nestedNullMapAddr, dataAddr);
- break;
- }
- case DECIMAL64: {
-
UdfConvert.convertArrayDecimalArg(argTypes[argIdx].getScale(), 8L, argument,
row, currentRowNum,
- offsetStart, isNullable,
- nullMapAddr,
- nestedNullMapAddr, dataAddr);
- break;
- }
- default: {
- LOG.info("Not support: " + argTypes[argIdx]);
- Preconditions.checkState(false, "Not support type " +
argTypes[argIdx].toString());
- break;
- }
- }
- }
- return argument;
+ return convertArrayArg(argIdx, isNullable, 0, numRows, nullMapAddr,
offsetsAddr, nestedNullMapAddr, dataAddr,
+ strOffsetAddr);
}
/**
@@ -287,7 +138,7 @@ public class UdfExecutor extends BaseExecutor {
for (int j = 0; j < column.length; ++j) {
parameters[j] = inputs[j][i];
}
- result[i] = method.invoke(udf, parameters);
+ result[i] = methodAccess.invoke(udf, evaluateIndex,
parameters);
}
return result;
} catch (Exception e) {
@@ -581,6 +432,7 @@ public class UdfExecutor extends BaseExecutor {
loader = ClassLoader.getSystemClassLoader();
}
Class<?> c = Class.forName(className, true, loader);
+ methodAccess = MethodAccess.get(c);
Constructor<?> ctor = c.getConstructor();
udf = ctor.newInstance();
Method[] methods = c.getMethods();
@@ -597,6 +449,7 @@ public class UdfExecutor extends BaseExecutor {
continue;
}
method = m;
+ evaluateIndex = methodAccess.getIndex(UDF_FUNCTION_NAME);
Pair<Boolean, JavaUdfDataType> returnType;
if (argClass.length == 0 && parameterTypes.length == 0) {
// Special case where the UDF doesn't take any input args
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]