This is an automated email from the ASF dual-hosted git repository.
morningman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git
The following commit(s) were added to refs/heads/master by this push:
new 0d761f9909 [feature-wip][UDF][DIP-1] Support variable-size input and output for Java UDF (#8678)
0d761f9909 is described below
commit 0d761f99096be50b0fd687f231921041d685e901
Author: Gabriel <[email protected]>
AuthorDate: Mon Apr 11 09:36:16 2022 +0800
[feature-wip][UDF][DIP-1] Support variable-size input and output for Java UDF (#8678)
This feature was proposed in DSIP-1. This PR supports variable-length input and output for Java UDFs.
---
be/src/vec/functions/function_java_udf.cpp | 106 ++++++---
be/src/vec/functions/function_java_udf.h | 47 ++--
bin/start_be.sh | 1 +
.../apache/doris/analysis/CreateFunctionStmt.java | 1 +
.../java/org/apache/doris/udf/UdfExecutor.java | 180 +++++++++++-----
.../main/java/org/apache/doris/udf/UdfUtils.java | 53 ++++-
.../java/org/apache/doris/udf/StringConcatUdf.java | 24 +++
.../java/org/apache/doris/udf/UdfExecutorTest.java | 240 ++++++++++++++++-----
gensrc/thrift/Types.thrift | 6 +-
run-be-ut.sh | 44 ++++
10 files changed, 546 insertions(+), 156 deletions(-)
diff --git a/be/src/vec/functions/function_java_udf.cpp
b/be/src/vec/functions/function_java_udf.cpp
index 5a47a5c0c2..cddf95661b 100644
--- a/be/src/vec/functions/function_java_udf.cpp
+++ b/be/src/vec/functions/function_java_udf.cpp
@@ -77,12 +77,15 @@ Status JavaFunctionCall::prepare(FunctionContext* context,
FunctionContext::Func
TJavaUdfExecutorCtorParams ctor_params;
ctor_params.__set_fn(fn_);
ctor_params.__set_location(local_location);
- ctor_params.__set_input_byte_offsets(jni_ctx->input_byte_offsets_ptr);
- ctor_params.__set_input_buffer_ptrs(jni_ctx->input_values_buffer_ptr);
- ctor_params.__set_input_nulls_ptrs(jni_ctx->input_nulls_buffer_ptr);
- ctor_params.__set_output_buffer_ptr(jni_ctx->output_value_buffer);
- ctor_params.__set_output_null_ptr(jni_ctx->output_null_value);
- ctor_params.__set_batch_size_ptr(jni_ctx->batch_size_ptr);
+ ctor_params.__set_input_offsets_ptrs((int64_t)
jni_ctx->input_offsets_ptrs.get());
+ ctor_params.__set_input_buffer_ptrs((int64_t)
jni_ctx->input_values_buffer_ptr.get());
+ ctor_params.__set_input_nulls_ptrs((int64_t)
jni_ctx->input_nulls_buffer_ptr.get());
+ ctor_params.__set_output_buffer_ptr((int64_t)
jni_ctx->output_value_buffer.get());
+ ctor_params.__set_output_null_ptr((int64_t)
jni_ctx->output_null_value.get());
+ ctor_params.__set_output_offsets_ptr((int64_t)
jni_ctx->output_offsets_ptr.get());
+ ctor_params.__set_output_intermediate_state_ptr(
+ (int64_t) jni_ctx->output_intermediate_state_ptr.get());
+ ctor_params.__set_batch_size_ptr((int64_t)
jni_ctx->batch_size_ptr.get());
jbyteArray ctor_params_bytes;
@@ -100,11 +103,6 @@ Status JavaFunctionCall::prepare(FunctionContext* context,
FunctionContext::Func
Status JavaFunctionCall::execute(FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
size_t result, size_t num_rows, bool dry_run) {
- auto return_type = block.get_data_type(result);
- if (!return_type->have_maximum_size_of_value()) {
- return Status::InvalidArgument(strings::Substitute(
- "Java UDF doesn't support return type $0 now !",
return_type->get_name()));
- }
JNIEnv* env;
RETURN_IF_ERROR(JniUtil::GetJNIEnv(&env));
JniContext* jni_ctx = reinterpret_cast<JniContext*>(
@@ -119,50 +117,94 @@ Status JavaFunctionCall::execute(FunctionContext*
context, Block& block, const C
arg_idx, column.type->get_name(),
_argument_types[arg_idx]->get_name()));
}
- if (!column.type->have_maximum_size_of_value()) {
- return Status::InvalidArgument(strings::Substitute(
- "Java UDF doesn't support input type $0 now !",
return_type->get_name()));
- }
auto data_col = col;
if (auto* nullable = check_and_get_column<const ColumnNullable>(*col))
{
data_col = nullable->get_nested_column_ptr();
auto null_col =
check_and_get_column<ColumnVector<UInt8>>(nullable->get_null_map_column_ptr());
- ((int64_t*) jni_ctx->input_nulls_buffer_ptr)[arg_idx] =
+ jni_ctx->input_nulls_buffer_ptr.get()[arg_idx] =
reinterpret_cast<int64_t>(null_col->get_data().data());
+ } else {
+ jni_ctx->input_nulls_buffer_ptr.get()[arg_idx] = -1;
+ }
+ if (const ColumnString* str_col =
check_and_get_column<ColumnString>(data_col.get())) {
+ jni_ctx->input_values_buffer_ptr.get()[arg_idx] =
+ reinterpret_cast<int64_t>(str_col->get_chars().data());
+ jni_ctx->input_offsets_ptrs.get()[arg_idx] =
+ reinterpret_cast<int64_t>(str_col->get_offsets().data());
+ } else if (data_col->is_numeric()) {
+ jni_ctx->input_values_buffer_ptr.get()[arg_idx] =
+ reinterpret_cast<int64_t>(data_col->get_raw_data().data);
+ } else {
+ return Status::InvalidArgument(strings::Substitute(
+ "Java UDF doesn't support type $0 now !",
_argument_types[arg_idx]->get_name()));
}
- ((int64_t*) jni_ctx->input_values_buffer_ptr)[arg_idx] =
- reinterpret_cast<int64_t>(data_col->get_raw_data().data);
arg_idx++;
}
-
+ *(jni_ctx->batch_size_ptr) = num_rows;
+ auto return_type = block.get_data_type(result);
if (return_type->is_nullable()) {
auto null_type = std::reinterpret_pointer_cast<const
DataTypeNullable>(return_type);
auto data_col = null_type->get_nested_type()->create_column();
auto null_col = ColumnUInt8::create(data_col->size(), 0);
null_col->reserve(num_rows);
null_col->resize(num_rows);
- data_col->reserve(num_rows);
- data_col->resize(num_rows);
- *((int64_t*) jni_ctx->output_null_value) =
+ *(jni_ctx->output_null_value) =
reinterpret_cast<int64_t>(null_col->get_data().data());
- *((int64_t*) jni_ctx->output_value_buffer) =
reinterpret_cast<int64_t>(data_col->get_raw_data().data);
+#ifndef EVALUATE_JAVA_UDF
+#define EVALUATE_JAVA_UDF
\
+ if (const ColumnString* str_col =
check_and_get_column<ColumnString>(data_col.get())) { \
+ ColumnString::Chars& chars =
const_cast<ColumnString::Chars&>(str_col->get_chars()); \
+ ColumnString::Offsets& offsets =
\
+ const_cast<ColumnString::Offsets&>(str_col->get_offsets());
\
+ int increase_buffer_size = 0;
\
+ int32_t buffer_size =
\
+
JavaFunctionCall::IncreaseReservedBufferSize(increase_buffer_size);
\
+ chars.reserve(buffer_size);
\
+ chars.resize(buffer_size);
\
+ offsets.reserve(num_rows);
\
+ offsets.resize(num_rows);
\
+ *(jni_ctx->output_value_buffer) =
\
+ reinterpret_cast<int64_t>(chars.data());
\
+ *(jni_ctx->output_offsets_ptr) =
\
+ reinterpret_cast<int64_t>(offsets.data());
\
+ jni_ctx->output_intermediate_state_ptr->row_idx = 0;
\
+ jni_ctx->output_intermediate_state_ptr->buffer_size = buffer_size;
\
+ env->CallNonvirtualVoidMethodA(
\
+ jni_ctx->executor, executor_cl_, executor_evaluate_id_,
nullptr); \
+ while (jni_ctx->output_intermediate_state_ptr->row_idx < num_rows)
{ \
+ increase_buffer_size++;
\
+ int32_t buffer_size =
\
+
JavaFunctionCall::IncreaseReservedBufferSize(increase_buffer_size); \
+ chars.resize(buffer_size);
\
+ *(jni_ctx->output_value_buffer) =
\
+ reinterpret_cast<int64_t>(chars.data());
\
+ jni_ctx->output_intermediate_state_ptr->buffer_size =
buffer_size; \
+ env->CallNonvirtualVoidMethodA(
\
+ jni_ctx->executor, executor_cl_,
executor_evaluate_id_, nullptr); \
+ }
\
+ } else if (data_col->is_numeric()) {
\
+ data_col->reserve(num_rows);
\
+ data_col->resize(num_rows);
\
+ *(jni_ctx->output_value_buffer) =
\
+ reinterpret_cast<int64_t>(data_col->get_raw_data().data);
\
+ env->CallNonvirtualVoidMethodA(
\
+ jni_ctx->executor, executor_cl_, executor_evaluate_id_,
nullptr); \
+ } else {
\
+ return Status::InvalidArgument(strings::Substitute(
\
+ "Java UDF doesn't support return type $0 now !",
return_type->get_name())); \
+ }
+#endif
+ EVALUATE_JAVA_UDF;
block.replace_by_position(result,
ColumnNullable::create(std::move(data_col),
std::move(null_col)));
} else {
+ *(jni_ctx->output_null_value) = -1;
auto data_col = return_type->create_column();
- data_col->reserve(num_rows);
- data_col->resize(num_rows);
-
- *((int64_t*) jni_ctx->output_value_buffer) =
reinterpret_cast<int64_t>(data_col->get_raw_data().data);
+ EVALUATE_JAVA_UDF;
block.replace_by_position(result, std::move(data_col));
}
- *((int32_t*) jni_ctx->batch_size_ptr) = num_rows;
- // Using this version of Call has the lowest overhead. This eliminates the
- // vtable lookup and setting up return stacks.
- env->CallNonvirtualVoidMethodA(
- jni_ctx->executor, executor_cl_, executor_evaluate_id_, nullptr);
return JniUtil::GetJniExceptionMsg(env);
}
diff --git a/be/src/vec/functions/function_java_udf.h
b/be/src/vec/functions/function_java_udf.h
index 2bc8ce88d8..8c0fcbe8e7 100644
--- a/be/src/vec/functions/function_java_udf.h
+++ b/be/src/vec/functions/function_java_udf.h
@@ -76,27 +76,36 @@ private:
jmethodID executor_evaluate_id_;
jmethodID executor_close_id_;
+ struct IntermediateState {
+ size_t buffer_size;
+ size_t row_idx;
+ };
+
struct JniContext {
JavaFunctionCall* parent = nullptr;
jobject executor = nullptr;
- int64_t input_values_buffer_ptr;
- int64_t input_nulls_buffer_ptr;
- int64_t input_byte_offsets_ptr;
- int64_t output_value_buffer;
- int64_t output_null_value;
- int64_t batch_size_ptr;
+ std::unique_ptr<int64_t[]> input_values_buffer_ptr;
+ std::unique_ptr<int64_t[]> input_nulls_buffer_ptr;
+ std::unique_ptr<int64_t[]> input_offsets_ptrs;
+ std::unique_ptr<int64_t> output_value_buffer;
+ std::unique_ptr<int64_t> output_null_value;
+ std::unique_ptr<int64_t> output_offsets_ptr;
+ std::unique_ptr<int32_t> batch_size_ptr;
+ // intermediate_state includes two parts: reserved / used buffer size
and rows
+ std::unique_ptr<IntermediateState> output_intermediate_state_ptr;
JniContext(int64_t num_args, JavaFunctionCall* parent):
parent(parent) {
- input_values_buffer_ptr = (int64_t) new int64_t[num_args];
- input_nulls_buffer_ptr = (int64_t) new int64_t[num_args];
- input_byte_offsets_ptr = (int64_t) new int64_t[num_args];
-
- output_value_buffer = (int64_t) malloc(sizeof(int64_t));
- output_null_value = (int64_t) malloc(sizeof(int64_t));
- batch_size_ptr = (int64_t) malloc(sizeof(int32_t));
+ input_values_buffer_ptr.reset(new int64_t[num_args]);
+ input_nulls_buffer_ptr.reset(new int64_t[num_args]);
+ input_offsets_ptrs.reset(new int64_t[num_args]);
+ output_value_buffer.reset((int64_t*) malloc(sizeof(int64_t)));
+ output_null_value.reset((int64_t*) malloc(sizeof(int64_t)));
+ batch_size_ptr.reset((int32_t*) malloc(sizeof(int32_t)));
+ output_offsets_ptr.reset((int64_t*) malloc(sizeof(int64_t)));
+ output_intermediate_state_ptr.reset((IntermediateState*)
malloc(sizeof(IntermediateState)));
}
~JniContext() {
@@ -109,12 +118,6 @@ private:
Status s = JniUtil::GetJniExceptionMsg(env);
if (!s.ok()) LOG(WARNING) << s.get_error_msg();
env->DeleteGlobalRef(executor);
- delete[] ((int64*) input_values_buffer_ptr);
- delete[] ((int64*) input_nulls_buffer_ptr);
- delete[] ((int64*) input_byte_offsets_ptr);
- free((int64*) output_value_buffer);
- free((int64*) output_null_value);
- free((int32*) batch_size_ptr);
}
/// These functions are cross-compiled to IR and used by codegen.
@@ -122,6 +125,12 @@ private:
JniContext* jni_ctx, int index, uint8_t value);
static uint8_t* GetInputValuesBufferAtOffset(JniContext* jni_ctx, int
offset);
};
+
+ static const int32_t INITIAL_RESERVED_BUFFER_SIZE = 1024;
+ // TODO: we need a heuristic strategy to increase buffer size for
variable-size output.
+ static inline int32_t IncreaseReservedBufferSize(int n) {
+ return INITIAL_RESERVED_BUFFER_SIZE << n;
+ }
};
} // namespace vectorized
diff --git a/bin/start_be.sh b/bin/start_be.sh
index c5b8da5c61..8bb8a6ff30 100755
--- a/bin/start_be.sh
+++ b/bin/start_be.sh
@@ -92,6 +92,7 @@ jdk_version() {
}
jvm_arch="amd64"
+MACHINE_TYPE=$(uname -m)
if [[ "${MACHINE_TYPE}" == "aarch64" ]]; then
jvm_arch="aarch64"
fi
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java
b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java
index 0ecd3b1ffb..d2516954f7 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java
@@ -375,6 +375,7 @@ public class CreateFunctionStmt extends DdlStmt {
.put(PrimitiveType.BIGINT, Sets.newHashSet(Long.class,
long.class))
.put(PrimitiveType.CHAR, Sets.newHashSet(String.class))
.put(PrimitiveType.VARCHAR, Sets.newHashSet(String.class))
+ .put(PrimitiveType.STRING, Sets.newHashSet(String.class))
.build();
private void checkUdfType(Class clazz, Method method, Type expType, Class
pType, String pname)
diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
index 89b6ea79a2..cca787400a 100644
--- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
+++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
@@ -29,36 +29,18 @@ import org.apache.thrift.TDeserializer;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
-import sun.misc.Unsafe;
-
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Constructor;
-import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
-import java.security.AccessController;
-import java.security.PrivilegedAction;
+import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
public class UdfExecutor {
private static final Logger LOG = Logger.getLogger(UdfExecutor.class);
- public static final Unsafe UNSAFE;
-
- static {
- UNSAFE = (Unsafe) AccessController.doPrivileged(
- (PrivilegedAction<Object>) () -> {
- try {
- Field f = Unsafe.class.getDeclaredField("theUnsafe");
- f.setAccessible(true);
- return f.get(null);
- } catch (NoSuchFieldException | IllegalAccessException e) {
- throw new Error();
- }
- });
- }
// By convention, the function in the class must be called evaluate()
public static final String UDF_FUNCTION_NAME = "evaluate";
@@ -82,10 +64,13 @@ public class UdfExecutor {
// These buffers are allocated in the BE.
private final long inputBufferPtrs_;
private final long inputNullsPtrs_;
+ private final long inputOffsetsPtrs_;
// Output buffer to return non-string values. These buffers are allocated
in the BE.
private final long outputBufferPtr_;
private final long outputNullPtr_;
+ private final long outputOffsetsPtr_;
+ private final long outputIntermediateStatePtr_;
// Pre-constructed input objects for the UDF. This minimizes object
creation overhead
// as these objects are reused across calls to evaluate().
@@ -93,6 +78,9 @@ public class UdfExecutor {
// inputArgs_[i] is either inputObjects_[i] or null
private Object[] inputArgs_;
+ private long outputOffset_;
+ private long row_idx_;
+
private final long batch_size_ptr_;
// Data types that are supported as return or argument types in Java UDFs.
@@ -104,7 +92,10 @@ public class UdfExecutor {
INT("INT", TPrimitiveType.INT, 4),
BIGINT("BIGINT", TPrimitiveType.BIGINT, 8),
FLOAT("FLOAT", TPrimitiveType.FLOAT, 4),
- DOUBLE("DOUBLE", TPrimitiveType.DOUBLE, 8);
+ DOUBLE("DOUBLE", TPrimitiveType.DOUBLE, 8),
+ CHAR("CHAR", TPrimitiveType.CHAR, 0),
+ VARCHAR("VARCHAR", TPrimitiveType.VARCHAR, 0),
+ STRING("STRING", TPrimitiveType.STRING, 0);
private final String description_;
private final TPrimitiveType thriftType_;
@@ -144,6 +135,10 @@ public class UdfExecutor {
return JavaUdfDataType.FLOAT;
} else if (c == double.class || c == Double.class) {
return JavaUdfDataType.DOUBLE;
+ } else if (c == char.class || c == Character.class) {
+ return JavaUdfDataType.CHAR;
+ } else if (c == String.class) {
+ return JavaUdfDataType.STRING;
}
return JavaUdfDataType.INVALID_TYPE;
}
@@ -183,8 +178,15 @@ public class UdfExecutor {
batch_size_ptr_ = request.batch_size_ptr;
inputBufferPtrs_ = request.input_buffer_ptrs;
inputNullsPtrs_ = request.input_nulls_ptrs;
+ inputOffsetsPtrs_ = request.input_offsets_ptrs;
+
outputBufferPtr_ = request.output_buffer_ptr;
outputNullPtr_ = request.output_null_ptr;
+ outputOffsetsPtr_ = request.output_offsets_ptr;
+ outputIntermediateStatePtr_ = request.output_intermediate_state_ptr;
+
+ outputOffset_ = 0L;
+ row_idx_ = 0L;
init(jarFile, className, retType, parameterTypes);
}
@@ -218,22 +220,52 @@ public class UdfExecutor {
* been serialized to 'input'
*/
public void evaluate() throws UdfRuntimeException {
+ int batch_size = UdfUtils.UNSAFE.getInt(null, batch_size_ptr_);
try {
- int batch_size = UNSAFE.getInt(null, batch_size_ptr_);
- for (int row = 0; row < batch_size; row++) {
- allocateInputObjects(row);
+ if (retType_.equals(JavaUdfDataType.STRING) ||
retType_.equals(JavaUdfDataType.VARCHAR)
+ || retType_.equals(JavaUdfDataType.CHAR)) {
+ // If this udf return variable-size type (e.g.) String, we
have to allocate output
+ // buffer multiple times until buffer size is enough to store
output column. So we
+ // always begin with the last evaluated row instead of
beginning of this batch.
+ row_idx_ = UdfUtils.UNSAFE.getLong(null,
outputIntermediateStatePtr_ + 8);
+ if (row_idx_ == 0) {
+ outputOffset_ = 0L;
+ }
+ } else {
+ row_idx_ = 0;
+ }
+ for (; row_idx_ < batch_size; row_idx_++) {
+ allocateInputObjects(row_idx_);
for (int i = 0; i < argTypes_.length; ++i) {
- if (UNSAFE.getByte(null, UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputNullsPtrs_, i)) + row * 1L) == 0) {
+ // Currently, -1 indicates this column is not nullable. So
input argument is
+ // null iff inputNullsPtrs_ != -1 and nullCol[row_idx] !=
0.
+ if (UdfUtils.UNSAFE.getLong(null,
+ UdfUtils.getAddressAtOffset(inputNullsPtrs_, i))
== -1 ||
+ UdfUtils.UNSAFE.getByte(null,
UdfUtils.UNSAFE.getLong(null,
+
UdfUtils.getAddressAtOffset(inputNullsPtrs_, i)) + row_idx_) == 0) {
inputArgs_[i] = inputObjects_[i];
} else {
inputArgs_[i] = null;
}
}
- storeUdfResult(evaluate(inputArgs_), row);
+ // `storeUdfResult` is called to store udf result to output
column. If true
+ // is returned, current value is stored successfully.
Otherwise, current result is
+ // not processed successfully (e.g. current output buffer is
not large enough) so
+ // we break this loop directly.
+ if (!storeUdfResult(evaluate(inputArgs_), row_idx_)) {
+ UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr_
+ 8, row_idx_);
+ return;
+ }
}
} catch (Exception e) {
+ if (retType_.equals(JavaUdfDataType.STRING)) {
+ UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr_ + 8,
batch_size);
+ }
throw new UdfRuntimeException("UDF::evaluate() ran into a
problem.", e);
}
+ if (retType_.equals(JavaUdfDataType.STRING)) {
+ UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr_ + 8,
row_idx_);
+ }
}
/**
@@ -252,42 +284,73 @@ public class UdfExecutor {
}
// Sets the result object 'obj' into the outputBufferPtr_ and
outputNullPtr_
- private void storeUdfResult(Object obj, int row) throws
UdfRuntimeException {
+ private boolean storeUdfResult(Object obj, long row) throws
UdfRuntimeException {
if (obj == null) {
- UNSAFE.putByte(null, UNSAFE.getLong(null, outputNullPtr_) + row *
1L, (byte) 1);
- return;
+ assert (UdfUtils.UNSAFE.getLong(null, outputNullPtr_) != -1);
+ UdfUtils.UNSAFE.putByte(null, UdfUtils.UNSAFE.getLong(null,
outputNullPtr_) + row, (byte) 1);
+ if (retType_.equals(JavaUdfDataType.STRING)) {
+ long bufferSize = UdfUtils.UNSAFE.getLong(null,
outputIntermediateStatePtr_);
+ if (outputOffset_ + 1 > bufferSize) {
+ return false;
+ }
+ outputOffset_ += 1;
+ UdfUtils.UNSAFE.putChar(null, UdfUtils.UNSAFE.getLong(null,
outputBufferPtr_) +
+ outputOffset_ - 1, UdfUtils.END_OF_STRING);
+ UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null,
outputOffsetsPtr_) +
+ 4L * row,
Integer.parseUnsignedInt(String.valueOf(outputOffset_)));
+ }
+ return true;
+ }
+ if (UdfUtils.UNSAFE.getLong(null, outputNullPtr_) != -1) {
+ UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null,
outputNullPtr_) + row, (byte) 0);
}
- UNSAFE.putByte(UNSAFE.getLong(null, outputNullPtr_) + row * 1L, (byte)
0);
switch (retType_) {
case BOOLEAN: {
boolean val = (boolean) obj;
- UNSAFE.putByte(UNSAFE.getLong(null, outputBufferPtr_) + row *
retType_.getLen(), val ? (byte) 1 : 0);
- return;
+ UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr_) + row * retType_.getLen(), val ? (byte) 1 : 0);
+ return true;
}
case TINYINT: {
- UNSAFE.putByte(UNSAFE.getLong(null, outputBufferPtr_) + row *
retType_.getLen(), (byte) obj);
- return;
+ UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr_) + row * retType_.getLen(), (byte) obj);
+ return true;
}
case SMALLINT: {
- UNSAFE.putShort(UNSAFE.getLong(null, outputBufferPtr_) + row *
retType_.getLen(), (short) obj);
- return;
+ UdfUtils.UNSAFE.putShort(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr_) + row * retType_.getLen(), (short) obj);
+ return true;
}
case INT: {
- UNSAFE.putInt(UNSAFE.getLong(null, outputBufferPtr_) + row *
retType_.getLen(), (int) obj);
- return;
+ UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr_) + row * retType_.getLen(), (int) obj);
+ return true;
}
case BIGINT: {
- UNSAFE.putLong(UNSAFE.getLong(null, outputBufferPtr_) + row *
retType_.getLen(), (long) obj);
- return;
+ UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr_) + row * retType_.getLen(), (long) obj);
+ return true;
}
case FLOAT: {
- UNSAFE.putFloat(UNSAFE.getLong(null, outputBufferPtr_) + row *
retType_.getLen(), (float) obj);
- return;
+ UdfUtils.UNSAFE.putFloat(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr_) + row * retType_.getLen(), (float) obj);
+ return true;
}
case DOUBLE: {
- UNSAFE.putDouble(UNSAFE.getLong(null, outputBufferPtr_) + row
* retType_.getLen(), (double) obj);
- return;
+ UdfUtils.UNSAFE.putDouble(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr_) + row * retType_.getLen(), (double) obj);
+ return true;
}
+ case CHAR:
+ case VARCHAR:
+ case STRING:
+ long bufferSize = UdfUtils.UNSAFE.getLong(null,
outputIntermediateStatePtr_);
+ byte[] bytes = ((String) obj).getBytes(StandardCharsets.UTF_8);
+ if (outputOffset_ + bytes.length + 1 > bufferSize) {
+ return false;
+ }
+ outputOffset_ += (bytes.length + 1);
+ UdfUtils.UNSAFE.putChar(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr_) +
+ outputOffset_ - 1, UdfUtils.END_OF_STRING);
+ UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null,
outputOffsetsPtr_) + 4L * row,
+
Integer.parseUnsignedInt(String.valueOf(outputOffset_)));
+ UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null,
+ UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) +
+ outputOffset_ - bytes.length - 1,
bytes.length);
+ return true;
default:
throw new UdfRuntimeException("Unsupported return type: " +
retType_);
}
@@ -295,32 +358,47 @@ public class UdfExecutor {
// Preallocate the input objects that will be passed to the underlying UDF.
// These objects are allocated once and reused across calls to evaluate()
- private void allocateInputObjects(int row) throws UdfRuntimeException {
+ private void allocateInputObjects(long row) throws UdfRuntimeException {
inputObjects_ = new Object[argTypes_.length];
inputArgs_ = new Object[argTypes_.length];
for (int i = 0; i < argTypes_.length; ++i) {
switch (argTypes_[i]) {
case BOOLEAN:
- inputObjects_[i] = UNSAFE.getBoolean(null,
UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 1L *
row);
+ inputObjects_[i] = UdfUtils.UNSAFE.getBoolean(null,
UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i))
+ row);
break;
case TINYINT:
- inputObjects_[i] = UNSAFE.getByte(null,
UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 1L *
row);
+ inputObjects_[i] = UdfUtils.UNSAFE.getByte(null,
UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i))
+ row);
break;
case SMALLINT:
- inputObjects_[i] = UNSAFE.getShort(null,
UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 2L *
row);
+ inputObjects_[i] = UdfUtils.UNSAFE.getShort(null,
UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i))
+ 2L * row);
break;
case INT:
- inputObjects_[i] = UNSAFE.getInt(null,
UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 4L *
row);
+ inputObjects_[i] = UdfUtils.UNSAFE.getInt(null,
UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i))
+ 4L * row);
break;
case BIGINT:
- inputObjects_[i] = UNSAFE.getLong(null,
UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 8L *
row);
+ inputObjects_[i] = UdfUtils.UNSAFE.getLong(null,
UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i))
+ 8L * row);
break;
case FLOAT:
- inputObjects_[i] = UNSAFE.getFloat(null,
UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 4L *
row);
+ inputObjects_[i] = UdfUtils.UNSAFE.getFloat(null,
UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i))
+ 4L * row);
break;
case DOUBLE:
- inputObjects_[i] = UNSAFE.getDouble(null,
UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 8L *
row);
+ inputObjects_[i] = UdfUtils.UNSAFE.getDouble(null,
UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i))
+ 8L * row);
+ break;
+ case CHAR:
+ case VARCHAR:
+ case STRING:
+ long offset =
Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null,
+ UdfUtils.UNSAFE.getLong(null,
+
UdfUtils.getAddressAtOffset(inputOffsetsPtrs_, i)) + 4L * row));
+ long numBytes = row == 0 ? offset - 1 : offset -
Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null,
+ UdfUtils.UNSAFE.getLong(null,
+
UdfUtils.getAddressAtOffset(inputOffsetsPtrs_, i)) + 4L * (row - 1))) - 1;
+ long base = row == 0 ? UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) :
+ UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + offset - numBytes - 1;
+ byte[] bytes = new byte[(int) numBytes];
+ UdfUtils.copyMemory(null, base, bytes,
UdfUtils.BYTE_ARRAY_OFFSET, numBytes);
+ inputObjects_[i] = new String(bytes,
StandardCharsets.UTF_8);
break;
default:
throw new UdfRuntimeException("Unsupported argument type:
" + argTypes_[i]);
diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java
b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java
index 9f33977df3..f412d8593f 100644
--- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java
+++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java
@@ -18,17 +18,42 @@
package org.apache.doris.udf;
import com.google.common.base.Preconditions;
+
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Pair;
-
import org.apache.doris.thrift.TPrimitiveType;
import org.apache.doris.thrift.TScalarType;
import org.apache.doris.thrift.TTypeDesc;
import org.apache.doris.thrift.TTypeNode;
+import sun.misc.Unsafe;
+
+import java.lang.reflect.Field;
+import java.security.AccessController;
+import java.security.PrivilegedAction;
+
public class UdfUtils {
+ public static final Unsafe UNSAFE;
+ private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L;
+ public static final long BYTE_ARRAY_OFFSET;
+ public static final char END_OF_STRING = '\0';
+
+ static {
+ UNSAFE = (Unsafe) AccessController.doPrivileged(
+ (PrivilegedAction<Object>) () -> {
+ try {
+ Field f = Unsafe.class.getDeclaredField("theUnsafe");
+ f.setAccessible(true);
+ return f.get(null);
+ } catch (NoSuchFieldException | IllegalAccessException e) {
+ throw new Error();
+ }
+ });
+ BYTE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(byte[].class);
+ }
+
protected static Pair<Type, Integer> fromThrift(TTypeDesc typeDesc, int
nodeIdx) throws InternalException {
TTypeNode node = typeDesc.getTypes().get(nodeIdx);
Type type = null;
@@ -62,4 +87,30 @@ public class UdfUtils {
protected static long getAddressAtOffset(long base, int offset) {
return base + 8L * offset;
}
+
+ public static void copyMemory(
+ Object src, long srcOffset, Object dst, long dstOffset, long
length) {
+ // Check if dstOffset is before or after srcOffset to determine if we
should copy
+ // forward or backwards. This is necessary in case src and dst overlap.
+ if (dstOffset < srcOffset) {
+ while (length > 0) {
+ long size = Math.min(length, UNSAFE_COPY_THRESHOLD);
+ UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size);
+ length -= size;
+ srcOffset += size;
+ dstOffset += size;
+ }
+ } else {
+ srcOffset += length;
+ dstOffset += length;
+ while (length > 0) {
+ long size = Math.min(length, UNSAFE_COPY_THRESHOLD);
+ srcOffset -= size;
+ dstOffset -= size;
+ UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size);
+ length -= size;
+ }
+
+ }
+ }
}
diff --git
a/fe/java-udf/src/test/java/org/apache/doris/udf/StringConcatUdf.java
b/fe/java-udf/src/test/java/org/apache/doris/udf/StringConcatUdf.java
new file mode 100644
index 0000000000..2fa6c2754d
--- /dev/null
+++ b/fe/java-udf/src/test/java/org/apache/doris/udf/StringConcatUdf.java
@@ -0,0 +1,24 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.udf;
+
+public class StringConcatUdf {
+ public String evaluate(String a, String b) {
+ return a == null || b == null? null: a + b;
+ }
+}
diff --git
a/fe/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java
b/fe/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java
index 6b1c487604..6113d94e84 100644
--- a/fe/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java
+++ b/fe/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java
@@ -22,6 +22,7 @@ import org.apache.thrift.TSerializer;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.junit.Test;
+import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -42,24 +43,24 @@ public class UdfExecutorTest {
fn.name = new TFunctionName("ConstantOne");
- long batchSizePtr = UdfExecutor.UNSAFE.allocateMemory(32);
+ long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(4);
int batchSize = 10;
- UdfExecutor.UNSAFE.putInt(batchSizePtr, batchSize);
+ UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize);
TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams();
- params.batch_size_ptr = batchSizePtr;
- params.fn = fn;
-
- long outputBuffer = UdfExecutor.UNSAFE.allocateMemory(32 * batchSize);
- long outputNull = UdfExecutor.UNSAFE.allocateMemory(8 * batchSize);
- long outputBufferPtr = UdfExecutor.UNSAFE.allocateMemory(64);
- UdfExecutor.UNSAFE.putLong(outputBufferPtr, outputBuffer);
- long outputNullPtr = UdfExecutor.UNSAFE.allocateMemory(64);
- UdfExecutor.UNSAFE.putLong(outputNullPtr, outputNull);
- params.output_buffer_ptr = outputBufferPtr;
- params.output_null_ptr = outputNullPtr;
- params.input_buffer_ptrs = 0;
- params.input_nulls_ptrs = 0;
- params.input_byte_offsets = 0;
+ params.setBatchSizePtr(batchSizePtr);
+ params.setFn(fn);
+
+ long outputBuffer = UdfUtils.UNSAFE.allocateMemory(4 * batchSize);
+ long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize);
+ long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8);
+ UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer);
+ long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8);
+ UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull);
+ params.setOutputBufferPtr(outputBufferPtr);
+ params.setOutputNullPtr(outputNullPtr);
+ params.setInputBufferPtrs(0);
+ params.setInputNullsPtrs(0);
+ params.setInputOffsetsPtrs(0);
TBinaryProtocol.Factory factory =
new TBinaryProtocol.Factory();
@@ -70,8 +71,8 @@ public class UdfExecutorTest {
executor.evaluate();
- for (int i = 0; i < 10; i ++) {
- assert (UdfExecutor.UNSAFE.getByte(outputNull + 8 * i) == 0);
- assert (UdfExecutor.UNSAFE.getInt(outputBuffer + 32 * i) == 1);
+ for (int i = 0; i < batchSize; i ++) {
+ assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0);
+ assert (UdfUtils.UNSAFE.getInt(outputBuffer + 4 * i) == 1);
}
}
@@ -91,52 +92,52 @@ public class UdfExecutorTest {
fn.name = new TFunctionName("SimpleAdd");
- long batchSizePtr = UdfExecutor.UNSAFE.allocateMemory(32);
+ long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(4);
int batchSize = 10;
- UdfExecutor.UNSAFE.putInt(batchSizePtr, batchSize);
+ UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize);
TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams();
- params.batch_size_ptr = batchSizePtr;
- params.fn = fn;
+ params.setBatchSizePtr(batchSizePtr);
+ params.setFn(fn);
- long outputBufferPtr = UdfExecutor.UNSAFE.allocateMemory(64);
- long outputNullPtr = UdfExecutor.UNSAFE.allocateMemory(64);
- long outputBuffer = UdfExecutor.UNSAFE.allocateMemory(32 * batchSize);
- long outputNull = UdfExecutor.UNSAFE.allocateMemory(8 * batchSize);
- UdfExecutor.UNSAFE.putLong(outputBufferPtr, outputBuffer);
- UdfExecutor.UNSAFE.putLong(outputNullPtr, outputNull);
+ long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8);
+ long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8);
+ long outputBuffer = UdfUtils.UNSAFE.allocateMemory(4 * batchSize);
+ long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize);
+ UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer);
+ UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull);
- params.output_buffer_ptr = outputBufferPtr;
- params.output_null_ptr = outputNullPtr;
+ params.setOutputBufferPtr(outputBufferPtr);
+ params.setOutputNullPtr(outputNullPtr);
int numCols = 2;
- long inputBufferPtr = UdfExecutor.UNSAFE.allocateMemory(64 * numCols);
- long inputNullPtr = UdfExecutor.UNSAFE.allocateMemory(64 * numCols);
+ long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols);
+ long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols);
- long inputBuffer1 = UdfExecutor.UNSAFE.allocateMemory(32 * batchSize);
- long inputNull1 = UdfExecutor.UNSAFE.allocateMemory(8 * batchSize);
- long inputBuffer2 = UdfExecutor.UNSAFE.allocateMemory(32 * batchSize);
- long inputNull2 = UdfExecutor.UNSAFE.allocateMemory(8 * batchSize);
+ long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(4 * batchSize);
+ long inputNull1 = UdfUtils.UNSAFE.allocateMemory(batchSize);
+ long inputBuffer2 = UdfUtils.UNSAFE.allocateMemory(4 * batchSize);
+ long inputNull2 = UdfUtils.UNSAFE.allocateMemory(batchSize);
- UdfExecutor.UNSAFE.putLong(inputBufferPtr, inputBuffer1);
- UdfExecutor.UNSAFE.putLong(inputBufferPtr + 64, inputBuffer2);
- UdfExecutor.UNSAFE.putLong(inputNullPtr, inputNull1);
- UdfExecutor.UNSAFE.putLong(inputNullPtr + 64, inputNull2);
+ UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1);
+ UdfUtils.UNSAFE.putLong(inputBufferPtr + 8, inputBuffer2);
+ UdfUtils.UNSAFE.putLong(inputNullPtr, inputNull1);
+ UdfUtils.UNSAFE.putLong(inputNullPtr + 8, inputNull2);
for (int i = 0; i < batchSize; i ++) {
- UdfExecutor.UNSAFE.putInt(null, inputBuffer1 + i * 32, i);
- UdfExecutor.UNSAFE.putInt(null, inputBuffer2 + i * 32, i);
+ UdfUtils.UNSAFE.putInt(null, inputBuffer1 + i * 4, i);
+ UdfUtils.UNSAFE.putInt(null, inputBuffer2 + i * 4, i);
if (i % 2 == 0) {
- UdfExecutor.UNSAFE.putByte(null, inputNull1 + i * 8, (byte) 1);
+ UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 1);
} else {
- UdfExecutor.UNSAFE.putByte(null, inputNull1 + i * 8, (byte) 0);
+ UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 0);
}
- UdfExecutor.UNSAFE.putByte(null, inputNull2 + i * 8, (byte) 0);
+ UdfUtils.UNSAFE.putByte(null, inputNull2 + i, (byte) 0);
}
- params.input_buffer_ptrs = inputBufferPtr;
- params.input_nulls_ptrs = inputNullPtr;
- params.input_byte_offsets = 0;
+ params.setInputBufferPtrs(inputBufferPtr);
+ params.setInputNullsPtrs(inputNullPtr);
+ params.setInputOffsetsPtrs(0);
TBinaryProtocol.Factory factory =
new TBinaryProtocol.Factory();
@@ -148,11 +149,148 @@ public class UdfExecutorTest {
executor.evaluate();
for (int i = 0; i < batchSize; i ++) {
if (i % 2 == 0) {
- assert (UdfExecutor.UNSAFE.getByte(outputNull + 8 * i) == 1);
+ assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 1);
+ } else {
+ assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0);
+ assert (UdfUtils.UNSAFE.getInt(outputBuffer + 4 * i) == i * 2);
+ }
+ }
+ }
+
+ @Test
+ public void testStringConcatUdf() throws Exception {
+ TScalarFunction scalarFunction = new TScalarFunction();
+ scalarFunction.symbol = "org.apache.doris.udf.StringConcatUdf";
+
+ TFunction fn = new TFunction();
+ fn.binary_type = TFunctionBinaryType.JAVA_UDF;
+ TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR);
+ typeNode.scalar_type = new TScalarType(TPrimitiveType.STRING);
+ TTypeDesc typeDesc = new TTypeDesc(Collections.singletonList(typeNode));
+ fn.ret_type = typeDesc;
+ fn.arg_types = Arrays.asList(typeDesc, typeDesc);
+ fn.scalar_fn = scalarFunction;
+ fn.name = new TFunctionName("StringConcat");
+
+
+ long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(4);
+ int batchSize = 10;
+ UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize);
+
+ TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams();
+ params.setBatchSizePtr(batchSizePtr);
+ params.setFn(fn);
+
+ long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8);
+ long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8);
+ long outputOffsetsPtr = UdfUtils.UNSAFE.allocateMemory(8);
+ long outputIntermediateStatePtr = UdfUtils.UNSAFE.allocateMemory(8 * 2);
+
+ String[] input1 = new String[batchSize];
+ String[] input2 = new String[batchSize];
+ long[] inputOffsets1 = new long[batchSize];
+ long[] inputOffsets2 = new long[batchSize];
+ long inputBufferSize1 = 0;
+ long inputBufferSize2 = 0;
+ for (int i = 0; i < batchSize; i ++) {
+ input1[i] = "Input1_" + i;
+ input2[i] = "Input2_" + i;
+ inputOffsets1[i] = i == 0? input1[i].getBytes(StandardCharsets.UTF_8).length + 1:
+ inputOffsets1[i - 1] + input1[i].getBytes(StandardCharsets.UTF_8).length + 1;
+ inputOffsets2[i] = i == 0? input2[i].getBytes(StandardCharsets.UTF_8).length + 1:
+ inputOffsets2[i - 1] + input2[i].getBytes(StandardCharsets.UTF_8).length + 1;
+ inputBufferSize1 += input1[i].getBytes(StandardCharsets.UTF_8).length;
+ inputBufferSize2 += input2[i].getBytes(StandardCharsets.UTF_8).length;
+ }
+ // In our test case, each output row is (8 + 8 + 1) bytes: both inputs plus the terminator
+ long outputBuffer = UdfUtils.UNSAFE.allocateMemory(inputBufferSize1 + inputBufferSize2 + batchSize);
+ long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize);
+ long outputOffset = UdfUtils.UNSAFE.allocateMemory(4 * batchSize);
+
+ UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer);
+ UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull);
+ UdfUtils.UNSAFE.putLong(outputOffsetsPtr, outputOffset);
+ // reserved buffer size
+ UdfUtils.UNSAFE.putLong(outputIntermediateStatePtr, inputBufferSize1 + inputBufferSize2 + batchSize);
+ // current row id
+ UdfUtils.UNSAFE.putLong(outputIntermediateStatePtr + 8, 0);
+
+ params.setOutputBufferPtr(outputBufferPtr);
+ params.setOutputNullPtr(outputNullPtr);
+ params.setOutputOffsetsPtr(outputOffsetsPtr);
+ params.setOutputIntermediateStatePtr(outputIntermediateStatePtr);
+
+ int numCols = 2;
+ long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols);
+ long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols);
+ long inputOffsetsPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols);
+
+ long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(inputBufferSize1 + batchSize);
+ long inputOffset1 = UdfUtils.UNSAFE.allocateMemory(4 * batchSize);
+ long inputBuffer2 = UdfUtils.UNSAFE.allocateMemory(inputBufferSize2 + batchSize);
+ long inputOffset2 = UdfUtils.UNSAFE.allocateMemory(4 * batchSize);
+
+ UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1);
+ UdfUtils.UNSAFE.putLong(inputBufferPtr + 8, inputBuffer2);
+ UdfUtils.UNSAFE.putLong(inputNullPtr, -1);
+ UdfUtils.UNSAFE.putLong(inputNullPtr + 8, -1);
+ UdfUtils.UNSAFE.putLong(inputOffsetsPtr, inputOffset1);
+ UdfUtils.UNSAFE.putLong(inputOffsetsPtr + 8, inputOffset2);
+
+ for (int i = 0; i < batchSize; i ++) {
+ if (i == 0) {
+ UdfUtils.copyMemory(input1[i].getBytes(StandardCharsets.UTF_8),
+ UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer1,
+ input1[i].getBytes(StandardCharsets.UTF_8).length);
+ UdfUtils.copyMemory(input2[i].getBytes(StandardCharsets.UTF_8),
+ UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer2,
+ input2[i].getBytes(StandardCharsets.UTF_8).length);
+ } else {
+ UdfUtils.copyMemory(input1[i].getBytes(StandardCharsets.UTF_8),
+ UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer1 + inputOffsets1[i - 1],
+ input1[i].getBytes(StandardCharsets.UTF_8).length);
+ UdfUtils.copyMemory(input2[i].getBytes(StandardCharsets.UTF_8),
+ UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer2 + inputOffsets2[i - 1],
+ input2[i].getBytes(StandardCharsets.UTF_8).length);
+ }
+ UdfUtils.UNSAFE.putInt(null, inputOffset1 + 4L * i,
+ (int) inputOffsets1[i]); // offsets are stored as 32-bit values
+ UdfUtils.UNSAFE.putInt(null, inputOffset2 + 4L * i,
+ (int) inputOffsets2[i]);
+ UdfUtils.UNSAFE.putByte(null, inputBuffer1 + inputOffsets1[i] - 1,
+ (byte) UdfUtils.END_OF_STRING); // putByte, not putChar: a 2-byte write overruns inputBuffer1 on the last row
+ UdfUtils.UNSAFE.putByte(null, inputBuffer2 + inputOffsets2[i] - 1,
+ (byte) UdfUtils.END_OF_STRING);
+
+ }
+ params.setInputBufferPtrs(inputBufferPtr);
+ params.setInputNullsPtrs(inputNullPtr);
+ params.setInputOffsetsPtrs(inputOffsetsPtr);
+
+ TBinaryProtocol.Factory factory =
+ new TBinaryProtocol.Factory();
+ TSerializer serializer = new TSerializer(factory);
+
+ UdfExecutor executor;
+ executor = new UdfExecutor(serializer.serialize(params));
+
+ executor.evaluate();
+ for (int i = 0; i < batchSize; i ++) {
+ byte[] bytes = new byte[input1[i].getBytes(StandardCharsets.UTF_8).length +
+ input2[i].getBytes(StandardCharsets.UTF_8).length];
+ assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0);
+ if (i == 0) {
+ UdfUtils.copyMemory(null, outputBuffer, bytes, UdfUtils.BYTE_ARRAY_OFFSET,
+ bytes.length);
} else {
- assert (UdfExecutor.UNSAFE.getByte(outputNull + 8 * i) == 0);
- assert (UdfExecutor.UNSAFE.getInt(outputBuffer + 32 * i) == i * 2);
+ long lastOffset = UdfUtils.UNSAFE.getInt(null, outputOffset + 4 * (i - 1));
+ UdfUtils.copyMemory(null, outputBuffer + lastOffset, bytes, UdfUtils.BYTE_ARRAY_OFFSET,
+ bytes.length);
}
+ long curOffset = UdfUtils.UNSAFE.getInt(null, outputOffset + 4 * i);
+ assert (new String(bytes, StandardCharsets.UTF_8).equals(input1[i] + input2[i]));
+ assert (UdfUtils.UNSAFE.getByte(null, outputBuffer + curOffset - 1) == UdfUtils.END_OF_STRING);
+ assert (UdfUtils.UNSAFE.getByte(null, outputNull + i) == 0);
}
}
}
diff --git a/gensrc/thrift/Types.thrift b/gensrc/thrift/Types.thrift
index d34f5499eb..145cc45555 100644
--- a/gensrc/thrift/Types.thrift
+++ b/gensrc/thrift/Types.thrift
@@ -350,7 +350,7 @@ struct TJavaUdfExecutorCtorParams {
// call the Java executor with a buffer for all the inputs.
- // input_byte_offsets[0] is the byte offset in the buffer for the first
- // argument; input_byte_offsets[1] is the second, etc.
- 3: optional i64 input_byte_offsets
+ // Native ptr (cast as i64) to an array with one pointer per argument; each points
+ // to that argument's per-row end offsets (used for variable-length args, e.g. STRING).
+ 3: optional i64 input_offsets_ptrs
// Native input buffer ptr (cast as i64) for the inputs. The input arguments
// are written to this buffer directly and read from java with no copies
@@ -365,8 +365,12 @@ struct TJavaUdfExecutorCtorParams {
// NULL.
6: optional i64 output_null_ptr
7: optional i64 output_buffer_ptr
8: optional i64 batch_size_ptr
+ // Native ptr to the output's per-row end-offsets buffer (variable-length results).
+ 9: optional i64 output_offsets_ptr
+ // Native ptr to two i64 values: reserved output buffer size, then current row id.
+ 10: optional i64 output_intermediate_state_ptr
}
// Contains all interesting statistics from a single 'memory pool' in the JVM.
diff --git a/run-be-ut.sh b/run-be-ut.sh
index e86b289952..2b9d873527 100755
--- a/run-be-ut.sh
+++ b/run-be-ut.sh
@@ -176,6 +176,50 @@ done
export DORIS_TEST_BINARY_DIR=${DORIS_TEST_BINARY_DIR}/test/
+# prepare jvm if needed
+jdk_version() {
+ local result
+ local java_cmd=$JAVA_HOME/bin/java
+ local IFS=$'\n'
+ if [[ ! -x "$java_cmd" ]]
+ then
+ result=no_java
+ else
+ # remove \r for Cygwin
+ local lines=$("$java_cmd" -Xms32M -Xmx32M -version 2>&1 | tr '\r' '\n')
+ for line in $lines; do
+ if [[ (-z $result) && ($line = *"version \""*) ]]
+ then
+ local ver=$(echo $line | sed -e 's/.*version "\(.*\)"\(.*\)/\1/; 1q')
+ # on macOS, sed doesn't support '?'
+ if [[ $ver = "1."* ]]
+ then
+ result=$(echo $ver | sed -e 's/1\.\([0-9]*\)\(.*\)/\1/; 1q')
+ else
+ result=$(echo $ver | sed -e 's/\([0-9]*\)\(.*\)/\1/; 1q')
+ fi
+ fi
+ done
+ fi
+ echo "$result"
+}
+
+jvm_arch="amd64"
+MACHINE_TYPE=$(uname -m)
+if [[ "${MACHINE_TYPE}" == "aarch64" ]]; then
+ jvm_arch="aarch64"
+fi
+java_version=$(jdk_version)
+if [[ $java_version -gt 8 ]]; then
+ export LD_LIBRARY_PATH=$JAVA_HOME/lib/server:$JAVA_HOME/lib:$LD_LIBRARY_PATH
+# JAVA_HOME is jdk
+elif [[ -d "$JAVA_HOME/jre" ]]; then
+ export LD_LIBRARY_PATH=$JAVA_HOME/jre/lib/$jvm_arch/server:$JAVA_HOME/jre/lib/$jvm_arch:$LD_LIBRARY_PATH
+# JAVA_HOME is jre
+else
+ export LD_LIBRARY_PATH=$JAVA_HOME/lib/$jvm_arch/server:$JAVA_HOME/lib/$jvm_arch:$LD_LIBRARY_PATH
+fi
+
# prepare gtest output dir
GTEST_OUTPUT_DIR=${CMAKE_BUILD_DIR}/gtest_output
rm -rf ${GTEST_OUTPUT_DIR} && mkdir ${GTEST_OUTPUT_DIR}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]