This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 01c001e2ac [refactor](javaudf) simplify UdfExecutor and UdafExecutor (#16050)
01c001e2ac is described below
commit 01c001e2ac06513d8049d519341e05b712460c7a
Author: Gabriel <[email protected]>
AuthorDate: Sat Jan 21 08:07:28 2023 +0800
[refactor](javaudf) simplify UdfExecutor and UdafExecutor (#16050)
* [refactor](javaudf) simplify UdfExecutor and UdafExecutor
* update
* update
---
.../udf/{UdfExecutor.java => BaseExecutor.java} | 502 +++++++--------------
.../java/org/apache/doris/udf/UdafExecutor.java | 379 ++--------------
.../java/org/apache/doris/udf/UdfExecutor.java | 384 ++--------------
3 files changed, 227 insertions(+), 1038 deletions(-)
diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java
similarity index 60%
copy from fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
copy to fe/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java
index 62deef5cda..55ff08f700 100644
--- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
+++ b/fe/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java
@@ -18,80 +18,67 @@
package org.apache.doris.udf;
import org.apache.doris.catalog.Type;
-import org.apache.doris.common.Pair;
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
import org.apache.doris.udf.UdfUtils.JavaUdfDataType;
-import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
-import com.google.common.collect.Lists;
import org.apache.log4j.Logger;
import org.apache.thrift.TDeserializer;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import java.io.IOException;
-import java.lang.reflect.Constructor;
-import java.lang.reflect.Method;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;
-import java.net.MalformedURLException;
import java.net.URLClassLoader;
import java.nio.charset.StandardCharsets;
-import java.util.ArrayList;
import java.util.Arrays;
-public class UdfExecutor {
- private static final Logger LOG = Logger.getLogger(UdfExecutor.class);
+public abstract class BaseExecutor {
+ private static final Logger LOG = Logger.getLogger(BaseExecutor.class);
// By convention, the function in the class must be called evaluate()
public static final String UDF_FUNCTION_NAME = "evaluate";
+ public static final String UDAF_CREATE_FUNCTION = "create";
+ public static final String UDAF_DESTROY_FUNCTION = "destroy";
+ public static final String UDAF_ADD_FUNCTION = "add";
+ public static final String UDAF_SERIALIZE_FUNCTION = "serialize";
+ public static final String UDAF_DESERIALIZE_FUNCTION = "deserialize";
+ public static final String UDAF_MERGE_FUNCTION = "merge";
+ public static final String UDAF_RESULT_FUNCTION = "getValue";
// Object to deserialize ctor params from BE.
- private static final TBinaryProtocol.Factory PROTOCOL_FACTORY =
+ protected static final TBinaryProtocol.Factory PROTOCOL_FACTORY =
new TBinaryProtocol.Factory();
- private Object udf;
+ protected Object udf;
// setup by init() and cleared by close()
- private Method method;
- // setup by init() and cleared by close()
- private URLClassLoader classLoader;
+ protected URLClassLoader classLoader;
// Return and argument types of the function inferred from the udf method
signature.
// The JavaUdfDataType enum maps it to corresponding primitive type.
- private JavaUdfDataType[] argTypes;
- private JavaUdfDataType retType;
+ protected JavaUdfDataType[] argTypes;
+ protected JavaUdfDataType retType;
// Input buffer from the backend. This is valid for the duration of an
evaluate() call.
// These buffers are allocated in the BE.
- private final long inputBufferPtrs;
- private final long inputNullsPtrs;
- private final long inputOffsetsPtrs;
+ protected final long inputBufferPtrs;
+ protected final long inputNullsPtrs;
+ protected final long inputOffsetsPtrs;
// Output buffer to return non-string values. These buffers are allocated
in the BE.
- private final long outputBufferPtr;
- private final long outputNullPtr;
- private final long outputOffsetsPtr;
- private final long outputIntermediateStatePtr;
-
- // Pre-constructed input objects for the UDF. This minimizes object
creation overhead
- // as these objects are reused across calls to evaluate().
- private Object[] inputObjects;
- // inputArgs_[i] is either inputObjects[i] or null
- private Object[] inputArgs;
-
- private long outputOffset;
- private long rowIdx;
-
- private final long batchSizePtr;
- private Class[] argClass;
+ protected final long outputBufferPtr;
+ protected final long outputNullPtr;
+ protected final long outputOffsetsPtr;
+ protected final long outputIntermediateStatePtr;
+ protected Class[] argClass;
/**
* Create a UdfExecutor, using parameters from a serialized thrift object.
Used by
* the backend.
*/
- public UdfExecutor(byte[] thriftParams) throws Exception {
+ public BaseExecutor(byte[] thriftParams) throws Exception {
TJavaUdfExecutorCtorParams request = new TJavaUdfExecutorCtorParams();
TDeserializer deserializer = new TDeserializer(PROTOCOL_FACTORY);
try {
@@ -99,14 +86,6 @@ public class UdfExecutor {
} catch (TException e) {
throw new InternalException(e.getMessage());
}
- String className = request.fn.scalar_fn.symbol;
- String jarFile = request.location;
- Type retType = UdfUtils.fromThrift(request.fn.ret_type, 0).first;
- Type[] parameterTypes = new Type[request.fn.arg_types.size()];
- for (int i = 0; i < request.fn.arg_types.size(); ++i) {
- parameterTypes[i] = Type.fromThrift(request.fn.arg_types.get(i));
- }
- batchSizePtr = request.batch_size_ptr;
inputBufferPtrs = request.input_buffer_ptrs;
inputNullsPtrs = request.input_nulls_ptrs;
inputOffsetsPtrs = request.input_offsets_ptrs;
@@ -116,18 +95,139 @@ public class UdfExecutor {
outputOffsetsPtr = request.output_offsets_ptr;
outputIntermediateStatePtr = request.output_intermediate_state_ptr;
- outputOffset = 0L;
- rowIdx = 0L;
+ Type[] parameterTypes = new Type[request.fn.arg_types.size()];
+ for (int i = 0; i < request.fn.arg_types.size(); ++i) {
+ parameterTypes[i] = Type.fromThrift(request.fn.arg_types.get(i));
+ }
+ String jarFile = request.location;
+ Type funcRetType = UdfUtils.fromThrift(request.fn.ret_type, 0).first;
- init(jarFile, className, retType, parameterTypes);
+ init(request, jarFile, funcRetType, parameterTypes);
}
- @Override
- protected void finalize() throws Throwable {
- close();
- super.finalize();
+ protected abstract void init(TJavaUdfExecutorCtorParams request, String
jarPath,
+ Type funcRetType, Type... parameterTypes) throws
UdfRuntimeException;
+
+ protected Object[] allocateInputObjects(long row, int argClassOffset)
throws UdfRuntimeException {
+ Object[] inputObjects = new Object[argTypes.length];
+
+ for (int i = 0; i < argTypes.length; ++i) {
+ if (UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) != -1
+ && (UdfUtils.UNSAFE.getByte(null,
UdfUtils.UNSAFE.getLong(null,
+ UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) + row) ==
1)) {
+ inputObjects[i] = null;
+ continue;
+ }
+ switch (argTypes[i]) {
+ case BOOLEAN:
+ inputObjects[i] = UdfUtils.UNSAFE.getBoolean(null,
+ UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row);
+ break;
+ case TINYINT:
+ inputObjects[i] = UdfUtils.UNSAFE.getByte(null,
+ UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row);
+ break;
+ case SMALLINT:
+ inputObjects[i] = UdfUtils.UNSAFE.getShort(null,
+ UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row);
+ break;
+ case INT:
+ inputObjects[i] = UdfUtils.UNSAFE.getInt(null,
+ UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row);
+ break;
+ case BIGINT:
+ inputObjects[i] = UdfUtils.UNSAFE.getLong(null,
+ UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row);
+ break;
+ case FLOAT:
+ inputObjects[i] = UdfUtils.UNSAFE.getFloat(null,
+ UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row);
+ break;
+ case DOUBLE:
+ inputObjects[i] = UdfUtils.UNSAFE.getDouble(null,
+ UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row);
+ break;
+ case DATE: {
+ long data = UdfUtils.UNSAFE.getLong(null,
+ UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row);
+ inputObjects[i] = UdfUtils.convertDateToJavaDate(data,
argClass[i + argClassOffset]);
+ break;
+ }
+ case DATETIME: {
+ long data = UdfUtils.UNSAFE.getLong(null,
+ UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row);
+ inputObjects[i] =
UdfUtils.convertDateTimeToJavaDateTime(data, argClass[i + argClassOffset]);
+ break;
+ }
+ case DATEV2: {
+ int data = UdfUtils.UNSAFE.getInt(null,
+ UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row);
+ inputObjects[i] = UdfUtils.convertDateV2ToJavaDate(data,
argClass[i + argClassOffset]);
+ break;
+ }
+ case DATETIMEV2: {
+ long data = UdfUtils.UNSAFE.getLong(null,
+ UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row);
+ inputObjects[i] =
UdfUtils.convertDateTimeV2ToJavaDateTime(data, argClass[i + argClassOffset]);
+ break;
+ }
+ case LARGEINT: {
+ long base = UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row;
+ byte[] bytes = new byte[argTypes[i].getLen()];
+ UdfUtils.copyMemory(null, base, bytes,
UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen());
+
+ inputObjects[i] = new
BigInteger(UdfUtils.convertByteOrder(bytes));
+ break;
+ }
+ case DECIMALV2:
+ case DECIMAL32:
+ case DECIMAL64:
+ case DECIMAL128: {
+ long base = UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row;
+ byte[] bytes = new byte[argTypes[i].getLen()];
+ UdfUtils.copyMemory(null, base, bytes,
UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen());
+
+ BigInteger value = new
BigInteger(UdfUtils.convertByteOrder(bytes));
+ inputObjects[i] = new BigDecimal(value,
argTypes[i].getScale());
+ break;
+ }
+ case CHAR:
+ case VARCHAR:
+ case STRING: {
+ long offset =
Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null,
UdfUtils.UNSAFE.getLong(null,
+ UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i))
+ 4L * row));
+ long numBytes = row == 0 ? offset : offset -
Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null,
+ UdfUtils.UNSAFE.getLong(null,
+
UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row - 1)));
+ long base =
+ row == 0 ? UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) :
+ UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + offset - numBytes;
+ byte[] bytes = new byte[(int) numBytes];
+ UdfUtils.copyMemory(null, base, bytes,
UdfUtils.BYTE_ARRAY_OFFSET, numBytes);
+ inputObjects[i] = new String(bytes,
StandardCharsets.UTF_8);
+ break;
+ }
+ default:
+ throw new UdfRuntimeException("Unsupported argument type:
" + argTypes[i]);
+ }
+ }
+ return inputObjects;
}
+ protected abstract long getCurrentOutputOffset(long row);
+
/**
* Close the class loader we may have created.
*/
@@ -142,91 +242,11 @@ public class UdfExecutor {
}
// We are now un-usable (because the class loader has been
// closed), so null out method_ and classLoader_.
- method = null;
classLoader = null;
}
- /**
- * evaluate function called by the backend. The inputs to the UDF have
- * been serialized to 'input'
- */
- public void evaluate() throws UdfRuntimeException {
- int batchSize = UdfUtils.UNSAFE.getInt(null, batchSizePtr);
- try {
- if (retType.equals(JavaUdfDataType.STRING) ||
retType.equals(JavaUdfDataType.VARCHAR)
- || retType.equals(JavaUdfDataType.CHAR)) {
- // If this udf return variable-size type (e.g.) String, we
have to allocate output
- // buffer multiple times until buffer size is enough to store
output column. So we
- // always begin with the last evaluated row instead of
beginning of this batch.
- rowIdx = UdfUtils.UNSAFE.getLong(null,
outputIntermediateStatePtr + 8);
- if (rowIdx == 0) {
- outputOffset = 0L;
- }
- } else {
- rowIdx = 0;
- }
- for (; rowIdx < batchSize; rowIdx++) {
- allocateInputObjects(rowIdx);
- for (int i = 0; i < argTypes.length; ++i) {
- // Currently, -1 indicates this column is not nullable. So
input argument is
- // null iff inputNullsPtrs_ != -1 and nullCol[row_idx] !=
0.
- if (UdfUtils.UNSAFE.getLong(null,
- UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) ==
-1
- || UdfUtils.UNSAFE.getByte(null,
UdfUtils.UNSAFE.getLong(null,
-
UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) + rowIdx) == 0) {
- inputArgs[i] = inputObjects[i];
- } else {
- inputArgs[i] = null;
- }
- }
- // `storeUdfResult` is called to store udf result to output
column. If true
- // is returned, current value is stored successfully.
Otherwise, current result is
- // not processed successfully (e.g. current output buffer is
not large enough) so
- // we break this loop directly.
- if (!storeUdfResult(evaluate(inputArgs), rowIdx)) {
- UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr +
8, rowIdx);
- return;
- }
- }
- } catch (Exception e) {
- if (retType.equals(JavaUdfDataType.STRING)) {
- UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr + 8,
batchSize);
- }
- throw new UdfRuntimeException("UDF::evaluate() ran into a
problem.", e);
- }
- if (retType.equals(JavaUdfDataType.STRING)) {
- UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr + 8,
rowIdx);
- }
- }
-
- /**
- * Evaluates the UDF with 'args' as the input to the UDF.
- */
- private Object evaluate(Object... args) throws UdfRuntimeException {
- try {
- return method.invoke(udf, args);
- } catch (Exception e) {
- throw new UdfRuntimeException("UDF failed to evaluate", e);
- }
- }
-
- public Method getMethod() {
- return method;
- }
-
// Sets the result object 'obj' into the outputBufferPtr and outputNullPtr_
- private boolean storeUdfResult(Object obj, long row) throws
UdfRuntimeException {
- if (obj == null) {
- if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) == -1) {
- throw new UdfRuntimeException("UDF failed to store null data
to not null column");
- }
- UdfUtils.UNSAFE.putByte(null, UdfUtils.UNSAFE.getLong(null,
outputNullPtr) + row, (byte) 1);
- if (retType.equals(JavaUdfDataType.STRING)) {
- UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null,
outputOffsetsPtr)
- + 4L * row,
Integer.parseUnsignedInt(String.valueOf(outputOffset)));
- }
- return true;
- }
+ protected boolean storeUdfResult(Object obj, long row, Class retClass)
throws UdfRuntimeException {
if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) != -1) {
UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null,
outputNullPtr) + row, (byte) 0);
}
@@ -268,22 +288,22 @@ public class UdfExecutor {
return true;
}
case DATE: {
- long time = UdfUtils.convertToDate(obj,
method.getReturnType());
+ long time = UdfUtils.convertToDate(obj, retClass);
UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(), time);
return true;
}
case DATETIME: {
- long time = UdfUtils.convertToDateTime(obj,
method.getReturnType());
+ long time = UdfUtils.convertToDateTime(obj, retClass);
UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(), time);
return true;
}
case DATEV2: {
- int time = UdfUtils.convertToDateV2(obj,
method.getReturnType());
+ int time = UdfUtils.convertToDateV2(obj, retClass);
UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(), time);
return true;
}
case DATETIMEV2: {
- long time = UdfUtils.convertToDateTimeV2(obj,
method.getReturnType());
+ long time = UdfUtils.convertToDateTimeV2(obj, retClass);
UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(), time);
return true;
}
@@ -349,14 +369,16 @@ public class UdfExecutor {
case STRING: {
long bufferSize = UdfUtils.UNSAFE.getLong(null,
outputIntermediateStatePtr);
byte[] bytes = ((String) obj).getBytes(StandardCharsets.UTF_8);
- if (outputOffset + bytes.length > bufferSize) {
+ long offset = getCurrentOutputOffset(row);
+ if (offset + bytes.length > bufferSize) {
return false;
}
- outputOffset += bytes.length;
+ offset += bytes.length;
UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null,
outputOffsetsPtr) + 4L * row,
-
Integer.parseUnsignedInt(String.valueOf(outputOffset)));
+ Integer.parseUnsignedInt(String.valueOf(offset)));
UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null,
- UdfUtils.UNSAFE.getLong(null, outputBufferPtr) +
outputOffset - bytes.length, bytes.length);
+ UdfUtils.UNSAFE.getLong(null, outputBufferPtr) +
offset - bytes.length, bytes.length);
+ updateOutputOffset(offset);
return true;
}
default:
@@ -364,203 +386,5 @@ public class UdfExecutor {
}
}
- // Preallocate the input objects that will be passed to the underlying UDF.
- // These objects are allocated once and reused across calls to evaluate()
- private void allocateInputObjects(long row) throws UdfRuntimeException {
- inputObjects = new Object[argTypes.length];
- inputArgs = new Object[argTypes.length];
-
- for (int i = 0; i < argTypes.length; ++i) {
- switch (argTypes[i]) {
- case BOOLEAN:
- inputObjects[i] = UdfUtils.UNSAFE.getBoolean(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row);
- break;
- case TINYINT:
- inputObjects[i] = UdfUtils.UNSAFE.getByte(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row);
- break;
- case SMALLINT:
- inputObjects[i] = UdfUtils.UNSAFE.getShort(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- break;
- case INT:
- inputObjects[i] = UdfUtils.UNSAFE.getInt(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- break;
- case BIGINT:
- inputObjects[i] = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- break;
- case FLOAT:
- inputObjects[i] = UdfUtils.UNSAFE.getFloat(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- break;
- case DOUBLE:
- inputObjects[i] = UdfUtils.UNSAFE.getDouble(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- break;
- case DATE: {
- long data = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- inputObjects[i] = UdfUtils.convertDateToJavaDate(data,
argClass[i]);
- break;
- }
- case DATETIME: {
- long data = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- inputObjects[i] =
UdfUtils.convertDateTimeToJavaDateTime(data, argClass[i]);
- break;
- }
- case DATEV2: {
- int data = UdfUtils.UNSAFE.getInt(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- inputObjects[i] = UdfUtils.convertDateV2ToJavaDate(data,
argClass[i]);
- break;
- }
- case DATETIMEV2: {
- long data = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- inputObjects[i] =
UdfUtils.convertDateTimeV2ToJavaDateTime(data, argClass[i]);
- break;
- }
- case LARGEINT: {
- long base = UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row;
- byte[] bytes = new byte[argTypes[i].getLen()];
- UdfUtils.copyMemory(null, base, bytes,
UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen());
-
- inputObjects[i] = new
BigInteger(UdfUtils.convertByteOrder(bytes));
- break;
- }
- case DECIMALV2:
- case DECIMAL32:
- case DECIMAL64:
- case DECIMAL128: {
- long base = UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row;
- byte[] bytes = new byte[argTypes[i].getLen()];
- UdfUtils.copyMemory(null, base, bytes,
UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen());
-
- BigInteger value = new
BigInteger(UdfUtils.convertByteOrder(bytes));
- inputObjects[i] = new BigDecimal(value,
argTypes[i].getScale());
- break;
- }
- case CHAR:
- case VARCHAR:
- case STRING: {
- long offset =
Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null,
UdfUtils.UNSAFE.getLong(null,
- UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i))
+ 4L * row));
- long numBytes = row == 0 ? offset : offset -
Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null,
- UdfUtils.UNSAFE.getLong(null,
-
UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row - 1)));
- long base =
- row == 0 ? UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) :
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + offset - numBytes;
- byte[] bytes = new byte[(int) numBytes];
- UdfUtils.copyMemory(null, base, bytes,
UdfUtils.BYTE_ARRAY_OFFSET, numBytes);
- inputObjects[i] = new String(bytes,
StandardCharsets.UTF_8);
- break;
- }
- default:
- throw new UdfRuntimeException("Unsupported argument type:
" + argTypes[i]);
- }
- }
- }
-
- private void init(String jarPath, String udfPath, Type funcRetType,
Type... parameterTypes)
- throws UdfRuntimeException {
- ArrayList<String> signatures = Lists.newArrayList();
- try {
- LOG.debug("Loading UDF '" + udfPath + "' from " + jarPath);
- ClassLoader loader;
- if (jarPath != null) {
- // Save for cleanup.
- ClassLoader parent = getClass().getClassLoader();
- classLoader = UdfUtils.getClassLoader(jarPath, parent);
- loader = classLoader;
- } else {
- // for test
- loader = ClassLoader.getSystemClassLoader();
- }
- Class<?> c = Class.forName(udfPath, true, loader);
- Constructor<?> ctor = c.getConstructor();
- udf = ctor.newInstance();
- Method[] methods = c.getMethods();
- for (Method m : methods) {
- // By convention, the udf must contain the function "evaluate"
- if (!m.getName().equals(UDF_FUNCTION_NAME)) {
- continue;
- }
- signatures.add(m.toGenericString());
- argClass = m.getParameterTypes();
-
- // Try to match the arguments
- if (argClass.length != parameterTypes.length) {
- continue;
- }
- method = m;
- Pair<Boolean, JavaUdfDataType> returnType;
- if (argClass.length == 0 && parameterTypes.length == 0) {
- // Special case where the UDF doesn't take any input args
- returnType = UdfUtils.setReturnType(funcRetType,
m.getReturnType());
- if (!returnType.first) {
- continue;
- } else {
- retType = returnType.second;
- }
- argTypes = new JavaUdfDataType[0];
- LOG.debug("Loaded UDF '" + udfPath + "' from " + jarPath);
- return;
- }
- returnType = UdfUtils.setReturnType(funcRetType,
m.getReturnType());
- if (!returnType.first) {
- continue;
- } else {
- retType = returnType.second;
- }
- Pair<Boolean, JavaUdfDataType[]> inputType =
UdfUtils.setArgTypes(parameterTypes, argClass, false);
- if (!inputType.first) {
- continue;
- } else {
- argTypes = inputType.second;
- }
- LOG.debug("Loaded UDF '" + udfPath + "' from " + jarPath);
- return;
- }
-
- StringBuilder sb = new StringBuilder();
- sb.append("Unable to find evaluate function with the correct
signature: ")
- .append(udfPath + ".evaluate(")
- .append(Joiner.on(", ").join(parameterTypes))
- .append(")\n")
- .append("UDF contains: \n ")
- .append(Joiner.on("\n ").join(signatures));
- throw new UdfRuntimeException(sb.toString());
- } catch (MalformedURLException e) {
- throw new UdfRuntimeException("Unable to load jar.", e);
- } catch (SecurityException e) {
- throw new UdfRuntimeException("Unable to load function.", e);
- } catch (ClassNotFoundException e) {
- throw new UdfRuntimeException("Unable to find class.", e);
- } catch (NoSuchMethodException e) {
- throw new UdfRuntimeException(
- "Unable to find constructor with no arguments.", e);
- } catch (IllegalArgumentException e) {
- throw new UdfRuntimeException(
- "Unable to call UDF constructor with no arguments.", e);
- } catch (Exception e) {
- throw new UdfRuntimeException("Unable to call create UDF
instance.", e);
- }
- }
+ protected void updateOutputOffset(long offset) {}
}
diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
index 0e6028b06e..4f88fa967e 100644
--- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
+++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
@@ -23,12 +23,8 @@ import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
import org.apache.doris.udf.UdfUtils.JavaUdfDataType;
import com.google.common.base.Joiner;
-import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.apache.log4j.Logger;
-import org.apache.thrift.TDeserializer;
-import org.apache.thrift.TException;
-import org.apache.thrift.protocol.TBinaryProtocol;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
@@ -36,100 +32,36 @@ import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
-import java.math.BigDecimal;
-import java.math.BigInteger;
-import java.math.RoundingMode;
import java.net.MalformedURLException;
-import java.net.URLClassLoader;
-import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.HashMap;
/**
* udaf executor.
*/
-public class UdafExecutor {
- public static final String UDAF_CREATE_FUNCTION = "create";
- public static final String UDAF_DESTROY_FUNCTION = "destroy";
- public static final String UDAF_ADD_FUNCTION = "add";
- public static final String UDAF_SERIALIZE_FUNCTION = "serialize";
- public static final String UDAF_DESERIALIZE_FUNCTION = "deserialize";
- public static final String UDAF_MERGE_FUNCTION = "merge";
- public static final String UDAF_RESULT_FUNCTION = "getValue";
+public class UdafExecutor extends BaseExecutor {
+
private static final Logger LOG = Logger.getLogger(UdafExecutor.class);
- private static final TBinaryProtocol.Factory PROTOCOL_FACTORY = new
TBinaryProtocol.Factory();
- private final long inputBufferPtrs;
- private final long inputNullsPtrs;
- private final long inputOffsetsPtrs;
- private final long inputPlacesPtr;
- private final long outputBufferPtr;
- private final long outputNullPtr;
- private final long outputOffsetsPtr;
- private final long outputIntermediateStatePtr;
- private Object udaf;
+
+ private long inputPlacesPtr;
private HashMap<String, Method> allMethods;
private HashMap<Long, Object> stateObjMap;
- private URLClassLoader classLoader;
- private JavaUdfDataType[] argTypes;
- private JavaUdfDataType retType;
- private Class[] argClass;
private Class retClass;
/**
* Constructor to create an object.
*/
public UdafExecutor(byte[] thriftParams) throws Exception {
- TJavaUdfExecutorCtorParams request = new TJavaUdfExecutorCtorParams();
- TDeserializer deserializer = new TDeserializer(PROTOCOL_FACTORY);
- try {
- deserializer.deserialize(request, thriftParams);
- } catch (TException e) {
- throw new InternalException(e.getMessage());
- }
- Type[] parameterTypes = new Type[request.fn.arg_types.size()];
- for (int i = 0; i < request.fn.arg_types.size(); ++i) {
- parameterTypes[i] = Type.fromThrift(request.fn.arg_types.get(i));
- }
- inputBufferPtrs = request.input_buffer_ptrs;
- inputNullsPtrs = request.input_nulls_ptrs;
- inputOffsetsPtrs = request.input_offsets_ptrs;
- inputPlacesPtr = request.input_places_ptr;
-
- outputBufferPtr = request.output_buffer_ptr;
- outputNullPtr = request.output_null_ptr;
- outputOffsetsPtr = request.output_offsets_ptr;
- outputIntermediateStatePtr = request.output_intermediate_state_ptr;
- allMethods = new HashMap<>();
- stateObjMap = new HashMap<>();
- String className = request.fn.aggregate_fn.symbol;
- String jarFile = request.location;
- Type funcRetType = UdfUtils.fromThrift(request.fn.ret_type, 0).first;
- init(jarFile, className, funcRetType, parameterTypes);
+ super(thriftParams);
}
/**
* close and invoke destroy function.
*/
+ @Override
public void close() {
- if (classLoader != null) {
- try {
- classLoader.close();
- } catch (Exception e) {
- // Log and ignore.
- LOG.debug("Error closing the URLClassloader.", e);
- }
- }
- // We are now un-usable (because the class loader has been
- // closed), so null out allMethods and classLoader.
allMethods = null;
- classLoader = null;
- }
-
- @Override
- protected void finalize() throws Throwable {
- close();
- super.finalize();
+ super.close();
}
/**
@@ -144,11 +76,11 @@ public class UdafExecutor {
stateObjMap.putIfAbsent(curPlace, createAggState());
inputArgs[0] = stateObjMap.get(curPlace);
do {
- Object[] inputObjects = allocateInputObjects(idx);
+ Object[] inputObjects = allocateInputObjects(idx, 1);
for (int i = 0; i < argTypes.length; ++i) {
inputArgs[i + 1] = inputObjects[i];
}
- allMethods.get(UDAF_ADD_FUNCTION).invoke(udaf, inputArgs);
+ allMethods.get(UDAF_ADD_FUNCTION).invoke(udf, inputArgs);
idx++;
} while (isSinglePlace && idx < rowEnd);
} while (idx < rowEnd);
@@ -162,7 +94,7 @@ public class UdafExecutor {
*/
public Object createAggState() throws UdfRuntimeException {
try {
- return allMethods.get(UDAF_CREATE_FUNCTION).invoke(udaf, null);
+ return allMethods.get(UDAF_CREATE_FUNCTION).invoke(udf, null);
} catch (Exception e) {
throw new UdfRuntimeException("UDAF failed to create: ", e);
}
@@ -174,7 +106,7 @@ public class UdafExecutor {
public void destroy() throws UdfRuntimeException {
try {
for (Object obj : stateObjMap.values()) {
- allMethods.get(UDAF_DESTROY_FUNCTION).invoke(udaf, obj);
+ allMethods.get(UDAF_DESTROY_FUNCTION).invoke(udf, obj);
}
stateObjMap.clear();
} catch (Exception e) {
@@ -191,7 +123,7 @@ public class UdafExecutor {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
args[0] = stateObjMap.get((Long) place);
args[1] = new DataOutputStream(baos);
- allMethods.get(UDAF_SERIALIZE_FUNCTION).invoke(udaf, args);
+ allMethods.get(UDAF_SERIALIZE_FUNCTION).invoke(udf, args);
return baos.toByteArray();
} catch (Exception e) {
throw new UdfRuntimeException("UDAF failed to serialize: ", e);
@@ -208,12 +140,12 @@ public class UdafExecutor {
ByteArrayInputStream bins = new ByteArrayInputStream(data);
args[0] = createAggState();
args[1] = new DataInputStream(bins);
- allMethods.get(UDAF_DESERIALIZE_FUNCTION).invoke(udaf, args);
+ allMethods.get(UDAF_DESERIALIZE_FUNCTION).invoke(udf, args);
args[1] = args[0];
Long curPlace = place;
stateObjMap.putIfAbsent(curPlace, createAggState());
args[0] = stateObjMap.get(curPlace);
- allMethods.get(UDAF_MERGE_FUNCTION).invoke(udaf, args);
+ allMethods.get(UDAF_MERGE_FUNCTION).invoke(udf, args);
} catch (Exception e) {
throw new UdfRuntimeException("UDAF failed to merge: ", e);
}
@@ -224,14 +156,15 @@ public class UdafExecutor {
*/
public boolean getValue(long row, long place) throws UdfRuntimeException {
try {
- return
storeUdfResult(allMethods.get(UDAF_RESULT_FUNCTION).invoke(udaf,
stateObjMap.get((Long) place)),
- row);
+ return
storeUdfResult(allMethods.get(UDAF_RESULT_FUNCTION).invoke(udf,
stateObjMap.get((Long) place)),
+ row, retClass);
} catch (Exception e) {
throw new UdfRuntimeException("UDAF failed to result", e);
}
}
- private boolean storeUdfResult(Object obj, long row) throws
UdfRuntimeException {
+ @Override
+ protected boolean storeUdfResult(Object obj, long row, Class retClass)
throws UdfRuntimeException {
if (obj == null) {
// If result is null, return true directly when row == 0 as we
have already inserted default value.
if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) == -1) {
@@ -239,267 +172,23 @@ public class UdafExecutor {
}
return true;
}
- if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) != -1) {
- UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null,
outputNullPtr) + row, (byte) 0);
- }
- switch (retType) {
- case BOOLEAN: {
- boolean val = (boolean) obj;
- UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(),
- val ? (byte) 1 : 0);
- return true;
- }
- case TINYINT: {
- UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(),
- (byte) obj);
- return true;
- }
- case SMALLINT: {
- UdfUtils.UNSAFE.putShort(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(),
- (short) obj);
- return true;
- }
- case INT: {
- UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(),
- (int) obj);
- return true;
- }
- case BIGINT: {
- UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(),
- (long) obj);
- return true;
- }
- case FLOAT: {
- UdfUtils.UNSAFE.putFloat(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(),
- (float) obj);
- return true;
- }
- case DOUBLE: {
- UdfUtils.UNSAFE.putDouble(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(),
- (double) obj);
- return true;
- }
- case DATE: {
- long time = UdfUtils.convertToDate(obj, retClass);
- UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(), time);
- return true;
- }
- case DATETIME: {
- long time = UdfUtils.convertToDateTime(obj, retClass);
- UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(), time);
- return true;
- }
- case DATEV2: {
- long time = UdfUtils.convertToDateV2(obj, retClass);
- UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(), time);
- return true;
- }
- case DATETIMEV2: {
- long time = UdfUtils.convertToDateTimeV2(obj, retClass);
- UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(), time);
- return true;
- }
- case LARGEINT: {
- BigInteger data = (BigInteger) obj;
- byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray());
-
- //here value is 16 bytes, so if result data greater than the
maximum of 16 bytes
- //it will return a wrong num to backend;
- byte[] value = new byte[16];
- //check data is negative
- if (data.signum() == -1) {
- Arrays.fill(value, (byte) -1);
- }
- for (int index = 0; index < Math.min(bytes.length,
value.length); ++index) {
- value[index] = bytes[index];
- }
-
- UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null,
- UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row *
retType.getLen(), value.length);
- return true;
- }
- case DECIMALV2: {
- Preconditions.checkArgument(((BigDecimal) obj).scale() == 9,
"Scale of DECIMALV2 must be 9");
- BigInteger data = ((BigDecimal) obj).unscaledValue();
- byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray());
- //TODO: here is maybe overflow also, and may find a better way
to handle
- byte[] value = new byte[16];
- if (data.signum() == -1) {
- Arrays.fill(value, (byte) -1);
- }
-
- for (int index = 0; index < Math.min(bytes.length,
value.length); ++index) {
- value[index] = bytes[index];
- }
-
- UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null,
- UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row *
retType.getLen(), value.length);
- return true;
- }
- case DECIMAL32:
- case DECIMAL64:
- case DECIMAL128: {
- BigDecimal retValue = ((BigDecimal)
obj).setScale(retType.getScale(), RoundingMode.HALF_EVEN);
- BigInteger data = retValue.unscaledValue();
- byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray());
- //TODO: here is maybe overflow also, and may find a better way
to handle
- byte[] value = new byte[retType.getLen()];
- if (data.signum() == -1) {
- Arrays.fill(value, (byte) -1);
- }
-
- for (int index = 0; index < Math.min(bytes.length,
value.length); ++index) {
- value[index] = bytes[index];
- }
-
- UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null,
- UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row *
retType.getLen(), value.length);
- return true;
- }
- case CHAR:
- case VARCHAR:
- case STRING: {
- long bufferSize = UdfUtils.UNSAFE.getLong(null,
outputIntermediateStatePtr);
- byte[] bytes = ((String) obj).getBytes(StandardCharsets.UTF_8);
- long offset = Integer.toUnsignedLong(
- UdfUtils.UNSAFE.getInt(null,
UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * (row - 1)));
- if (offset + bytes.length > bufferSize) {
- return false;
- }
- offset += bytes.length;
- UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null,
outputOffsetsPtr) + 4L * row,
- Integer.parseUnsignedInt(String.valueOf(offset)));
- UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null,
- UdfUtils.UNSAFE.getLong(null, outputBufferPtr) +
offset - bytes.length, bytes.length);
- return true;
- }
- default:
- throw new UdfRuntimeException("Unsupported return type: " +
retType);
- }
+ return super.storeUdfResult(obj, row, retClass);
}
- private Object[] allocateInputObjects(long row) throws UdfRuntimeException
{
- Object[] inputObjects = new Object[argTypes.length];
-
- for (int i = 0; i < argTypes.length; ++i) {
- // skip the input column of current row is null
- if (UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) != -1
- && (UdfUtils.UNSAFE.getByte(null,
UdfUtils.UNSAFE.getLong(null,
- UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) + row) ==
1)) {
- inputObjects[i] = null;
- continue;
- }
- switch (argTypes[i]) {
- case BOOLEAN:
- inputObjects[i] = UdfUtils.UNSAFE.getBoolean(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row);
- break;
- case TINYINT:
- inputObjects[i] = UdfUtils.UNSAFE.getByte(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row);
- break;
- case SMALLINT:
- inputObjects[i] = UdfUtils.UNSAFE.getShort(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- break;
- case INT:
- inputObjects[i] = UdfUtils.UNSAFE.getInt(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- break;
- case BIGINT:
- inputObjects[i] = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- break;
- case FLOAT:
- inputObjects[i] = UdfUtils.UNSAFE.getFloat(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- break;
- case DOUBLE:
- inputObjects[i] = UdfUtils.UNSAFE.getDouble(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- break;
- case DATE: {
- long data = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- inputObjects[i] = UdfUtils.convertDateToJavaDate(data,
argClass[i + 1]);
- break;
- }
- case DATETIME: {
- long data = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- inputObjects[i] =
UdfUtils.convertDateTimeToJavaDateTime(data, argClass[i + 1]);
- break;
- }
- case DATEV2: {
- int data = UdfUtils.UNSAFE.getInt(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- inputObjects[i] = UdfUtils.convertDateV2ToJavaDate(data,
argClass[i + 1]);
- break;
- }
- case DATETIMEV2: {
- long data = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- inputObjects[i] =
UdfUtils.convertDateTimeV2ToJavaDateTime(data, argClass[i + 1]);
- break;
- }
- case LARGEINT: {
- long base = UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row;
- byte[] bytes = new byte[argTypes[i].getLen()];
- UdfUtils.copyMemory(null, base, bytes,
UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen());
-
- inputObjects[i] = new
BigInteger(UdfUtils.convertByteOrder(bytes));
- break;
- }
- case DECIMALV2:
- case DECIMAL32:
- case DECIMAL64:
- case DECIMAL128: {
- long base = UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row;
- byte[] bytes = new byte[argTypes[i].getLen()];
- UdfUtils.copyMemory(null, base, bytes,
UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen());
-
- BigInteger value = new
BigInteger(UdfUtils.convertByteOrder(bytes));
- inputObjects[i] = new BigDecimal(value,
argTypes[i].getScale());
- break;
- }
- case CHAR:
- case VARCHAR:
- case STRING: {
- long offset =
Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i))
- + 4L * row));
- long numBytes = row == 0 ? offset : offset -
Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row
- - 1)));
- long base = row == 0 ? UdfUtils.UNSAFE.getLong(null,
- UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- : UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + offset
- - numBytes;
- byte[] bytes = new byte[(int) numBytes];
- UdfUtils.copyMemory(null, base, bytes,
UdfUtils.BYTE_ARRAY_OFFSET, numBytes);
- inputObjects[i] = new String(bytes,
StandardCharsets.UTF_8);
- break;
- }
- default:
- throw new UdfRuntimeException("Unsupported argument type:
" + argTypes[i]);
- }
- }
- return inputObjects;
+ @Override
+ protected long getCurrentOutputOffset(long row) {
+ return Integer.toUnsignedLong(
+ UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null,
outputOffsetsPtr) + 4L * (row - 1)));
}
- private void init(String jarPath, String udfPath, Type funcRetType,
Type... parameterTypes)
- throws UdfRuntimeException {
+ @Override
+ protected void init(TJavaUdfExecutorCtorParams request, String jarPath,
Type funcRetType,
+ Type... parameterTypes) throws UdfRuntimeException {
+ String className = request.fn.aggregate_fn.symbol;
+ inputPlacesPtr = request.input_places_ptr;
+ allMethods = new HashMap<>();
+ stateObjMap = new HashMap<>();
+
ArrayList<String> signatures = Lists.newArrayList();
try {
ClassLoader loader;
@@ -511,9 +200,9 @@ public class UdafExecutor {
// for test
loader = ClassLoader.getSystemClassLoader();
}
- Class<?> c = Class.forName(udfPath, true, loader);
+ Class<?> c = Class.forName(className, true, loader);
Constructor<?> ctor = c.getConstructor();
- udaf = ctor.newInstance();
+ udf = ctor.newInstance();
Method[] methods = c.getDeclaredMethods();
int idx = 0;
for (idx = 0; idx < methods.length; ++idx) {
@@ -569,7 +258,7 @@ public class UdafExecutor {
return;
}
StringBuilder sb = new StringBuilder();
- sb.append("Unable to find evaluate function with the correct
signature: ").append(udfPath + ".evaluate(")
+ sb.append("Unable to find evaluate function with the correct
signature: ").append(className + ".evaluate(")
.append(Joiner.on(",
").join(parameterTypes)).append(")\n").append("UDF contains: \n ")
.append(Joiner.on("\n ").join(signatures));
throw new UdfRuntimeException(sb.toString());
diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
index 62deef5cda..5f043f64a8 100644
--- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
+++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
@@ -23,127 +23,45 @@ import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
import org.apache.doris.udf.UdfUtils.JavaUdfDataType;
import com.google.common.base.Joiner;
-import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.apache.log4j.Logger;
-import org.apache.thrift.TDeserializer;
-import org.apache.thrift.TException;
-import org.apache.thrift.protocol.TBinaryProtocol;
-import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
-import java.math.BigDecimal;
-import java.math.BigInteger;
-import java.math.RoundingMode;
import java.net.MalformedURLException;
-import java.net.URLClassLoader;
-import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
-import java.util.Arrays;
-public class UdfExecutor {
+public class UdfExecutor extends BaseExecutor {
private static final Logger LOG = Logger.getLogger(UdfExecutor.class);
-
- // By convention, the function in the class must be called evaluate()
- public static final String UDF_FUNCTION_NAME = "evaluate";
-
- // Object to deserialize ctor params from BE.
- private static final TBinaryProtocol.Factory PROTOCOL_FACTORY =
- new TBinaryProtocol.Factory();
-
- private Object udf;
// setup by init() and cleared by close()
private Method method;
- // setup by init() and cleared by close()
- private URLClassLoader classLoader;
-
- // Return and argument types of the function inferred from the udf method
signature.
- // The JavaUdfDataType enum maps it to corresponding primitive type.
- private JavaUdfDataType[] argTypes;
- private JavaUdfDataType retType;
-
- // Input buffer from the backend. This is valid for the duration of an
evaluate() call.
- // These buffers are allocated in the BE.
- private final long inputBufferPtrs;
- private final long inputNullsPtrs;
- private final long inputOffsetsPtrs;
-
- // Output buffer to return non-string values. These buffers are allocated
in the BE.
- private final long outputBufferPtr;
- private final long outputNullPtr;
- private final long outputOffsetsPtr;
- private final long outputIntermediateStatePtr;
// Pre-constructed input objects for the UDF. This minimizes object
creation overhead
// as these objects are reused across calls to evaluate().
private Object[] inputObjects;
- // inputArgs_[i] is either inputObjects[i] or null
- private Object[] inputArgs;
private long outputOffset;
private long rowIdx;
- private final long batchSizePtr;
- private Class[] argClass;
+ private long batchSizePtr;
/**
* Create a UdfExecutor, using parameters from a serialized thrift object.
Used by
* the backend.
*/
public UdfExecutor(byte[] thriftParams) throws Exception {
- TJavaUdfExecutorCtorParams request = new TJavaUdfExecutorCtorParams();
- TDeserializer deserializer = new TDeserializer(PROTOCOL_FACTORY);
- try {
- deserializer.deserialize(request, thriftParams);
- } catch (TException e) {
- throw new InternalException(e.getMessage());
- }
- String className = request.fn.scalar_fn.symbol;
- String jarFile = request.location;
- Type retType = UdfUtils.fromThrift(request.fn.ret_type, 0).first;
- Type[] parameterTypes = new Type[request.fn.arg_types.size()];
- for (int i = 0; i < request.fn.arg_types.size(); ++i) {
- parameterTypes[i] = Type.fromThrift(request.fn.arg_types.get(i));
- }
- batchSizePtr = request.batch_size_ptr;
- inputBufferPtrs = request.input_buffer_ptrs;
- inputNullsPtrs = request.input_nulls_ptrs;
- inputOffsetsPtrs = request.input_offsets_ptrs;
-
- outputBufferPtr = request.output_buffer_ptr;
- outputNullPtr = request.output_null_ptr;
- outputOffsetsPtr = request.output_offsets_ptr;
- outputIntermediateStatePtr = request.output_intermediate_state_ptr;
-
- outputOffset = 0L;
- rowIdx = 0L;
-
- init(jarFile, className, retType, parameterTypes);
- }
-
- @Override
- protected void finalize() throws Throwable {
- close();
- super.finalize();
+ super(thriftParams);
}
/**
* Close the class loader we may have created.
*/
+ @Override
public void close() {
- if (classLoader != null) {
- try {
- classLoader.close();
- } catch (IOException e) {
- // Log and ignore.
- LOG.debug("Error closing the URLClassloader.", e);
- }
- }
// We are now un-usable (because the class loader has been
// closed), so null out method_ and classLoader_.
method = null;
- classLoader = null;
+ super.close();
}
/**
@@ -166,24 +84,12 @@ public class UdfExecutor {
rowIdx = 0;
}
for (; rowIdx < batchSize; rowIdx++) {
- allocateInputObjects(rowIdx);
- for (int i = 0; i < argTypes.length; ++i) {
- // Currently, -1 indicates this column is not nullable. So
input argument is
- // null iff inputNullsPtrs_ != -1 and nullCol[row_idx] !=
0.
- if (UdfUtils.UNSAFE.getLong(null,
- UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) ==
-1
- || UdfUtils.UNSAFE.getByte(null,
UdfUtils.UNSAFE.getLong(null,
-
UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) + rowIdx) == 0) {
- inputArgs[i] = inputObjects[i];
- } else {
- inputArgs[i] = null;
- }
- }
+ inputObjects = allocateInputObjects(rowIdx, 0);
// `storeUdfResult` is called to store udf result to output
column. If true
// is returned, current value is stored successfully.
Otherwise, current result is
// not processed successfully (e.g. current output buffer is
not large enough) so
// we break this loop directly.
- if (!storeUdfResult(evaluate(inputArgs), rowIdx)) {
+ if (!storeUdfResult(evaluate(inputObjects), rowIdx,
method.getReturnType())) {
UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr +
8, rowIdx);
return;
}
@@ -215,7 +121,8 @@ public class UdfExecutor {
}
// Sets the result object 'obj' into the outputBufferPtr and outputNullPtr_
- private boolean storeUdfResult(Object obj, long row) throws
UdfRuntimeException {
+ @Override
+ protected boolean storeUdfResult(Object obj, long row, Class retClass)
throws UdfRuntimeException {
if (obj == null) {
if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) == -1) {
throw new UdfRuntimeException("UDF failed to store null data
to not null column");
@@ -227,262 +134,31 @@ public class UdfExecutor {
}
return true;
}
- if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) != -1) {
- UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null,
outputNullPtr) + row, (byte) 0);
- }
- switch (retType) {
- case BOOLEAN: {
- boolean val = (boolean) obj;
- UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(),
- val ? (byte) 1 : 0);
- return true;
- }
- case TINYINT: {
- UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(),
- (byte) obj);
- return true;
- }
- case SMALLINT: {
- UdfUtils.UNSAFE.putShort(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(),
- (short) obj);
- return true;
- }
- case INT: {
- UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(),
- (int) obj);
- return true;
- }
- case BIGINT: {
- UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(),
- (long) obj);
- return true;
- }
- case FLOAT: {
- UdfUtils.UNSAFE.putFloat(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(),
- (float) obj);
- return true;
- }
- case DOUBLE: {
- UdfUtils.UNSAFE.putDouble(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(),
- (double) obj);
- return true;
- }
- case DATE: {
- long time = UdfUtils.convertToDate(obj,
method.getReturnType());
- UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(), time);
- return true;
- }
- case DATETIME: {
- long time = UdfUtils.convertToDateTime(obj,
method.getReturnType());
- UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(), time);
- return true;
- }
- case DATEV2: {
- int time = UdfUtils.convertToDateV2(obj,
method.getReturnType());
- UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(), time);
- return true;
- }
- case DATETIMEV2: {
- long time = UdfUtils.convertToDateTimeV2(obj,
method.getReturnType());
- UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + row * retType.getLen(), time);
- return true;
- }
- case LARGEINT: {
- BigInteger data = (BigInteger) obj;
- byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray());
-
- //here value is 16 bytes, so if result data greater than the
maximum of 16 bytes
- //it will return a wrong num to backend;
- byte[] value = new byte[16];
- //check data is negative
- if (data.signum() == -1) {
- Arrays.fill(value, (byte) -1);
- }
- for (int index = 0; index < Math.min(bytes.length,
value.length); ++index) {
- value[index] = bytes[index];
- }
-
- UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null,
- UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row *
retType.getLen(), value.length);
- return true;
- }
- case DECIMALV2: {
- Preconditions.checkArgument(((BigDecimal) obj).scale() == 9,
"Scale of DECIMALV2 must be 9");
- BigInteger data = ((BigDecimal) obj).unscaledValue();
- byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray());
- //TODO: here is maybe overflow also, and may find a better way
to handle
- byte[] value = new byte[16];
- if (data.signum() == -1) {
- Arrays.fill(value, (byte) -1);
- }
-
- for (int index = 0; index < Math.min(bytes.length,
value.length); ++index) {
- value[index] = bytes[index];
- }
-
- UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null,
- UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row *
retType.getLen(), value.length);
- return true;
- }
- case DECIMAL32:
- case DECIMAL64:
- case DECIMAL128: {
- BigDecimal retValue = ((BigDecimal)
obj).setScale(retType.getScale(), RoundingMode.HALF_EVEN);
- BigInteger data = retValue.unscaledValue();
- byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray());
- //TODO: here is maybe overflow also, and may find a better way
to handle
- byte[] value = new byte[retType.getLen()];
- if (data.signum() == -1) {
- Arrays.fill(value, (byte) -1);
- }
+ return super.storeUdfResult(obj, row, retClass);
+ }
- for (int index = 0; index < Math.min(bytes.length,
value.length); ++index) {
- value[index] = bytes[index];
- }
+ @Override
+ protected long getCurrentOutputOffset(long row) {
+ return outputOffset;
+ }
- UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null,
- UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row *
retType.getLen(), value.length);
- return true;
- }
- case CHAR:
- case VARCHAR:
- case STRING: {
- long bufferSize = UdfUtils.UNSAFE.getLong(null,
outputIntermediateStatePtr);
- byte[] bytes = ((String) obj).getBytes(StandardCharsets.UTF_8);
- if (outputOffset + bytes.length > bufferSize) {
- return false;
- }
- outputOffset += bytes.length;
- UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null,
outputOffsetsPtr) + 4L * row,
-
Integer.parseUnsignedInt(String.valueOf(outputOffset)));
- UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null,
- UdfUtils.UNSAFE.getLong(null, outputBufferPtr) +
outputOffset - bytes.length, bytes.length);
- return true;
- }
- default:
- throw new UdfRuntimeException("Unsupported return type: " +
retType);
- }
+ @Override
+ protected void updateOutputOffset(long offset) {
+ outputOffset = offset;
}
// Preallocate the input objects that will be passed to the underlying UDF.
// These objects are allocated once and reused across calls to evaluate()
- private void allocateInputObjects(long row) throws UdfRuntimeException {
- inputObjects = new Object[argTypes.length];
- inputArgs = new Object[argTypes.length];
-
- for (int i = 0; i < argTypes.length; ++i) {
- switch (argTypes[i]) {
- case BOOLEAN:
- inputObjects[i] = UdfUtils.UNSAFE.getBoolean(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row);
- break;
- case TINYINT:
- inputObjects[i] = UdfUtils.UNSAFE.getByte(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row);
- break;
- case SMALLINT:
- inputObjects[i] = UdfUtils.UNSAFE.getShort(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- break;
- case INT:
- inputObjects[i] = UdfUtils.UNSAFE.getInt(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- break;
- case BIGINT:
- inputObjects[i] = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- break;
- case FLOAT:
- inputObjects[i] = UdfUtils.UNSAFE.getFloat(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- break;
- case DOUBLE:
- inputObjects[i] = UdfUtils.UNSAFE.getDouble(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- break;
- case DATE: {
- long data = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- inputObjects[i] = UdfUtils.convertDateToJavaDate(data,
argClass[i]);
- break;
- }
- case DATETIME: {
- long data = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- inputObjects[i] =
UdfUtils.convertDateTimeToJavaDateTime(data, argClass[i]);
- break;
- }
- case DATEV2: {
- int data = UdfUtils.UNSAFE.getInt(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- inputObjects[i] = UdfUtils.convertDateV2ToJavaDate(data,
argClass[i]);
- break;
- }
- case DATETIMEV2: {
- long data = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row);
- inputObjects[i] =
UdfUtils.convertDateTimeV2ToJavaDateTime(data, argClass[i]);
- break;
- }
- case LARGEINT: {
- long base = UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row;
- byte[] bytes = new byte[argTypes[i].getLen()];
- UdfUtils.copyMemory(null, base, bytes,
UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen());
-
- inputObjects[i] = new
BigInteger(UdfUtils.convertByteOrder(bytes));
- break;
- }
- case DECIMALV2:
- case DECIMAL32:
- case DECIMAL64:
- case DECIMAL128: {
- long base = UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + argTypes[i].getLen() * row;
- byte[] bytes = new byte[argTypes[i].getLen()];
- UdfUtils.copyMemory(null, base, bytes,
UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen());
-
- BigInteger value = new
BigInteger(UdfUtils.convertByteOrder(bytes));
- inputObjects[i] = new BigDecimal(value,
argTypes[i].getScale());
- break;
- }
- case CHAR:
- case VARCHAR:
- case STRING: {
- long offset =
Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null,
UdfUtils.UNSAFE.getLong(null,
- UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i))
+ 4L * row));
- long numBytes = row == 0 ? offset : offset -
Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null,
- UdfUtils.UNSAFE.getLong(null,
-
UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row - 1)));
- long base =
- row == 0 ? UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) :
- UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + offset - numBytes;
- byte[] bytes = new byte[(int) numBytes];
- UdfUtils.copyMemory(null, base, bytes,
UdfUtils.BYTE_ARRAY_OFFSET, numBytes);
- inputObjects[i] = new String(bytes,
StandardCharsets.UTF_8);
- break;
- }
- default:
- throw new UdfRuntimeException("Unsupported argument type:
" + argTypes[i]);
- }
- }
- }
-
- private void init(String jarPath, String udfPath, Type funcRetType,
Type... parameterTypes)
- throws UdfRuntimeException {
+ @Override
+ protected void init(TJavaUdfExecutorCtorParams request, String jarPath,
Type funcRetType,
+ Type... parameterTypes) throws UdfRuntimeException {
+ String className = request.fn.scalar_fn.symbol;
+ batchSizePtr = request.batch_size_ptr;
+ outputOffset = 0L;
+ rowIdx = 0L;
ArrayList<String> signatures = Lists.newArrayList();
try {
- LOG.debug("Loading UDF '" + udfPath + "' from " + jarPath);
+ LOG.debug("Loading UDF '" + className + "' from " + jarPath);
ClassLoader loader;
if (jarPath != null) {
// Save for cleanup.
@@ -493,7 +169,7 @@ public class UdfExecutor {
// for test
loader = ClassLoader.getSystemClassLoader();
}
- Class<?> c = Class.forName(udfPath, true, loader);
+ Class<?> c = Class.forName(className, true, loader);
Constructor<?> ctor = c.getConstructor();
udf = ctor.newInstance();
Method[] methods = c.getMethods();
@@ -520,7 +196,7 @@ public class UdfExecutor {
retType = returnType.second;
}
argTypes = new JavaUdfDataType[0];
- LOG.debug("Loaded UDF '" + udfPath + "' from " + jarPath);
+ LOG.debug("Loaded UDF '" + className + "' from " +
jarPath);
return;
}
returnType = UdfUtils.setReturnType(funcRetType,
m.getReturnType());
@@ -535,13 +211,13 @@ public class UdfExecutor {
} else {
argTypes = inputType.second;
}
- LOG.debug("Loaded UDF '" + udfPath + "' from " + jarPath);
+ LOG.debug("Loaded UDF '" + className + "' from " + jarPath);
return;
}
StringBuilder sb = new StringBuilder();
sb.append("Unable to find evaluate function with the correct
signature: ")
- .append(udfPath + ".evaluate(")
+ .append(className + ".evaluate(")
.append(Joiner.on(", ").join(parameterTypes))
.append(")\n")
.append("UDF contains: \n ")
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org