This is an automated email from the ASF dual-hosted git repository. chenyz pushed a commit to branch udsf in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit b8a4fb6ab45adfbeb5f992b3ef4494795839c859 Author: Chen YZ <[email protected]> AuthorDate: Tue Nov 19 16:43:56 2024 +0800 save --- .../confignode/it/IoTDBConfigNodeSnapshotIT.java | 1 - .../iotdb/udf/api/relational/ScalarFunction.java | 12 ++- .../relational/ColumnTransformerBuilder.java | 32 ++++-- .../relational/metadata/TableMetadataImpl.java | 13 ++- .../udf/UserDefineScalarFunctionTransformer.java | 109 ++++++++------------- .../iotdb/commons/udf/access/RecordIterator.java | 105 ++++++++++++++++++++ .../iotdb/commons/udf/utils/TableUDFUtils.java | 21 ++-- 7 files changed, 194 insertions(+), 99 deletions(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/confignode/it/IoTDBConfigNodeSnapshotIT.java b/integration-test/src/test/java/org/apache/iotdb/confignode/it/IoTDBConfigNodeSnapshotIT.java index 03faf93585c..7a9d99006a8 100644 --- a/integration-test/src/test/java/org/apache/iotdb/confignode/it/IoTDBConfigNodeSnapshotIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/confignode/it/IoTDBConfigNodeSnapshotIT.java @@ -31,7 +31,6 @@ import org.apache.iotdb.commons.path.PathDeserializeUtil; import org.apache.iotdb.commons.trigger.TriggerInformation; import org.apache.iotdb.commons.trigger.service.TriggerExecutableManager; import org.apache.iotdb.commons.udf.UDFInformation; -import org.apache.iotdb.confignode.consensus.request.read.function.GetFunctionTablePlan; import org.apache.iotdb.confignode.rpc.thrift.TCQEntry; import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; diff --git a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/ScalarFunction.java b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/ScalarFunction.java index bd0586d051c..6f52103805e 100644 --- a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/ScalarFunction.java +++ b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/ScalarFunction.java @@ -26,12 +26,20 @@ import org.apache.iotdb.udf.api.type.Type; public interface ScalarFunction extends SQLFunction { /** - * This method is mainly used to validate {@link FunctionParameters} and infer output data type. + * This method is used to validate {@link FunctionParameters}. * * @param parameters parameters used to validate * @throws Exception if any parameter is not valid */ - Type validateAndInferOutputType(FunctionParameters parameters) throws Exception; + void validate(FunctionParameters parameters) throws Exception; + + /** + * This method is used to infer the output data type of the transformation. + * + * @param parameters input parameters + * @return the output data type + */ + Type inferOutputType(FunctionParameters parameters); /** * This method will be called to process the transformation. In a single UDF query, this method diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java index 26b30bec4f8..429ffdeadac 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java @@ -19,7 +19,8 @@ package org.apache.iotdb.db.queryengine.execution.relational; -import org.apache.iotdb.commons.udf.service.UDFManagementService; +import org.apache.iotdb.commons.udf.utils.TableUDFUtils; +import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer; import org.apache.iotdb.db.exception.sql.SemanticException; import org.apache.iotdb.db.queryengine.common.SessionInfo; import org.apache.iotdb.db.queryengine.plan.analyze.TypeProvider; @@ -95,6 +96,7 @@ import org.apache.iotdb.db.queryengine.transformation.dag.column.multi.LogicalAn import org.apache.iotdb.db.queryengine.transformation.dag.column.multi.LogicalOrMultiColumnTransformer; import org.apache.iotdb.db.queryengine.transformation.dag.column.ternary.BetweenColumnTransformer; import org.apache.iotdb.db.queryengine.transformation.dag.column.ternary.Like3ColumnTransformer; +import org.apache.iotdb.db.queryengine.transformation.dag.column.udf.UserDefineScalarFunctionTransformer; import org.apache.iotdb.db.queryengine.transformation.dag.column.unary.IsNullColumnTransformer; import org.apache.iotdb.db.queryengine.transformation.dag.column.unary.LikeColumnTransformer; import org.apache.iotdb.db.queryengine.transformation.dag.column.unary.LogicNotColumnTransformer; @@ -153,7 +155,7 @@ import org.apache.iotdb.db.queryengine.transformation.dag.column.unary.scalar.Tr import org.apache.iotdb.db.queryengine.transformation.dag.column.unary.scalar.TrimColumnTransformer; import org.apache.iotdb.db.queryengine.transformation.dag.column.unary.scalar.TryCastFunctionColumnTransformer; import org.apache.iotdb.db.queryengine.transformation.dag.column.unary.scalar.UpperColumnTransformer; -import org.apache.iotdb.db.queryengine.transformation.dag.column.unary.scalar.UserDefineScalarFunctionTransformer; +import org.apache.iotdb.udf.api.customizer.parameter.FunctionParameters; import org.apache.iotdb.udf.api.relational.ScalarFunction; import org.apache.tsfile.common.conf.TSFileConfig; @@ -173,6 +175,7 @@ import org.apache.tsfile.utils.Binary; import java.time.ZoneId; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -999,13 +1002,24 @@ public class ColumnTransformerBuilder source, ((LongLiteral) children.get(3)).getParsedValue(), context.sessionInfo.getZoneId()); - } else if (UDFManagementService.getInstance() - .isAssignableFrom(functionName, ScalarFunction.class)) { - List<ColumnTransformer> childrenColumnTransformer = - children.stream().map(child -> process(child, context)).collect(Collectors.toList()); - // TODO(UDSF): check the return type of the function - return new UserDefineScalarFunctionTransformer( - INT32, functionName, children, childrenColumnTransformer); + } else { + // user defined function + ScalarFunction scalarFunction = TableUDFUtils.tryGetScalarFunction(functionName); + if (scalarFunction != null) { + List<ColumnTransformer> childrenColumnTransformer = + children.stream().map(child -> process(child, context)).collect(Collectors.toList()); + FunctionParameters parameters = + new FunctionParameters( + childrenColumnTransformer.stream() + .map(i -> UDFDataTypeTransformer.transformReadTypeToUDFDataType(i.getType())) + .collect(Collectors.toList()), + Collections.emptyMap()); + Type returnType = + UDFDataTypeTransformer.transformUDFDataTypeToReadType( + scalarFunction.inferOutputType(parameters)); + return new UserDefineScalarFunctionTransformer( + returnType, scalarFunction, childrenColumnTransformer); + } } throw new IllegalArgumentException(String.format("Unknown function: %s", functionName)); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java index e127de7186d..d8c0c8c9b89 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java @@ -23,7 +23,6 @@ import org.apache.iotdb.commons.partition.DataPartition; import org.apache.iotdb.commons.partition.DataPartitionQueryParam; import org.apache.iotdb.commons.partition.SchemaPartition; import org.apache.iotdb.commons.schema.table.TsTable; -import org.apache.iotdb.commons.udf.service.UDFManagementService; import org.apache.iotdb.commons.udf.utils.TableUDFUtils; import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer; import org.apache.iotdb.db.exception.sql.SemanticException; @@ -622,10 +621,9 @@ public class TableMetadataImpl implements Metadata { // ignore } - // 根据 argumentTypes 获取返回类型,这边暂时先 mock 一个 INT32 - if (TableUDFUtils.isScalarFunction(functionName)) { - ScalarFunction scalarFunction = - UDFManagementService.getInstance().reflect(functionName, ScalarFunction.class); + // User-defined scalar function + ScalarFunction scalarFunction = TableUDFUtils.tryGetScalarFunction(functionName); + if (scalarFunction != null) { FunctionParameters functionParameters = new FunctionParameters( argumentTypes.stream() @@ -633,11 +631,12 @@ public class TableMetadataImpl implements Metadata { .collect(Collectors.toList()), Collections.emptyMap()); try { - return UDFDataTypeTransformer.transformUDFDataTypeToReadType( - scalarFunction.validateAndInferOutputType(functionParameters)); + scalarFunction.validate(functionParameters); } catch (Exception e) { throw new SemanticException("Invalid function parameters: " + e.getMessage()); } + return UDFDataTypeTransformer.transformUDFDataTypeToReadType( + scalarFunction.inferOutputType(functionParameters)); } // TODO UDAF diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/column/udf/UserDefineScalarFunctionTransformer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/column/udf/UserDefineScalarFunctionTransformer.java index 7689bbcca9d..c81c42135d5 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/column/udf/UserDefineScalarFunctionTransformer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/column/udf/UserDefineScalarFunctionTransformer.java @@ -19,107 +19,80 @@ package org.apache.iotdb.db.queryengine.transformation.dag.column.udf; -import org.apache.iotdb.commons.udf.service.UDFManagementService; -import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer; -import org.apache.iotdb.db.exception.sql.SemanticException; -import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression; +import org.apache.iotdb.commons.udf.access.RecordIterator; import org.apache.iotdb.db.queryengine.transformation.dag.column.ColumnTransformer; import org.apache.iotdb.db.queryengine.transformation.dag.column.multi.MultiColumnTransformer; -import org.apache.iotdb.db.queryengine.transformation.dag.udf.UDFParametersFactory; -import org.apache.iotdb.udf.api.access.ColumnToRowIterator; -import org.apache.iotdb.udf.api.customizer.parameter.UDFParameters; import org.apache.iotdb.udf.api.relational.ScalarFunction; +import org.apache.iotdb.udf.api.relational.access.Record; import org.apache.tsfile.block.column.Column; import org.apache.tsfile.block.column.ColumnBuilder; -import org.apache.tsfile.enums.TSDataType; import org.apache.tsfile.read.common.type.Type; -import java.util.Collections; import java.util.List; -import java.util.stream.Collectors; -// TODO(UDSF): encapsulate refect and validate logic public class UserDefineScalarFunctionTransformer extends MultiColumnTransformer { private final ScalarFunction scalarFunction; - private final List<TSDataType> childrenTypes; public UserDefineScalarFunctionTransformer( Type returnType, - String functionName, - List<Expression> children, + ScalarFunction scalarFunction, List<ColumnTransformer> childrenTransformers) { super(returnType, childrenTransformers); - ScalarFunction scalarFunction = - UDFManagementService.getInstance().reflect(functionName, ScalarFunction.class); - this.childrenTypes = - childrenTransformers.stream() - .map(ColumnTransformer::getType) - .map(UDFDataTypeTransformer::transformReadTypeToTSDataType) - .collect(Collectors.toList()); - // TODO: 1、Table UDF 里不应该再用 String Expression 了 - // TODO:2、想办法弄到 attributes - UDFParameters udfParameters = - UDFParametersFactory.buildUdfParameters( - children.stream().map(Expression::toString).collect(Collectors.toList()), - childrenTypes, - Collections.emptyMap()); - try { - // scalarFunction.validate(new UDFParameterValidator(udfParameters)); - // scalarFunction.beforeStart(udfParameters, new ScalarFunctionConfig()); - } catch (Exception e) { - throw new SemanticException(e.getMessage()); - } - this.scalarFunction = scalarFunction; } @Override protected void doTransform( List<Column> childrenColumns, ColumnBuilder builder, int positionCount) { - ColumnToRowIterator iterator = - new ColumnToRowIterator(childrenTypes, childrenColumns, positionCount); - // while (iterator.hasNextRow()) { - // try { - // Row row = iterator.next(); - // Object result = scalarFunction.evaluate(row); - // if (result == null) { - // builder.appendNull(); - // } else { - // builder.writeObject(result); - // } - // } catch (Exception e) { - // throw new RuntimeException( - // "Error occurs when evaluating UDF " + scalarFunction.getClass().getName(), e); - // } - // } + RecordIterator iterator = new RecordIterator(childrenColumns, positionCount); + while (iterator.hasNext()) { + try { + Object result = scalarFunction.evaluate(iterator.next()); + if (result == null) { + builder.appendNull(); + } else { + builder.writeObject(result); + } + } catch (Exception e) { + throw new RuntimeException( + "Error occurs when evaluating user-defined scalar function " + + scalarFunction.getClass().getName(), + e); + } + } } @Override protected void doTransform( List<Column> childrenColumns, ColumnBuilder builder, int positionCount, boolean[] selection) { - ColumnToRowIterator iterator = - new ColumnToRowIterator(childrenTypes, childrenColumns, positionCount); + RecordIterator iterator = new RecordIterator(childrenColumns, positionCount); int i = 0; - // while (iterator.hasNextRow()) { - // try { - // Row row = iterator.next(); - // Object result = scalarFunction.evaluate(row); - // if (selection[i++] || result == null) { - // builder.appendNull(); - // } else { - // builder.writeObject(result); - // } - // } catch (Exception e) { - // throw new RuntimeException( - // "Error occurs when evaluating UDF " + scalarFunction.getClass().getName(), e); - // } - // } + while (iterator.hasNext()) { + try { + Record input = iterator.next(); + if (selection[i++]) { + builder.appendNull(); + continue; + } + Object result = scalarFunction.evaluate(input); + if (result == null) { + builder.appendNull(); + } else { + builder.writeObject(result); + } + } catch (Exception e) { + throw new RuntimeException( + "Error occurs when evaluating user-defined scalar function " + + scalarFunction.getClass().getName(), + e); + } + } } @Override protected void checkType() { - // TODO: implement this method + // do nothing } } diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/access/RecordIterator.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/access/RecordIterator.java new file mode 100644 index 00000000000..29f473b3c9d --- /dev/null +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/access/RecordIterator.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.commons.udf.access; + +import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer; +import org.apache.iotdb.udf.api.relational.access.Record; +import org.apache.iotdb.udf.api.type.Binary; +import org.apache.iotdb.udf.api.type.Type; + +import org.apache.tsfile.block.column.Column; + +import java.io.IOException; +import java.util.Iterator; +import java.util.List; + +public class RecordIterator implements Iterator<Record> { + + private final List<Column> childrenColumns; + private final int positionCount; + private int currentIndex; + + public RecordIterator(List<Column> childrenColumns, int positionCount) { + this.childrenColumns = childrenColumns; + this.positionCount = positionCount; + } + + @Override + public boolean hasNext() { + return currentIndex < positionCount; + } + + @Override + public Record next() { + final int index = currentIndex++; + return new Record() { + @Override + public int getInt(int columnIndex) throws IOException { + return childrenColumns.get(columnIndex).getInt(index); + } + + @Override + public long getLong(int columnIndex) throws IOException { + return childrenColumns.get(columnIndex).getLong(index); + } + + @Override + public float getFloat(int columnIndex) throws IOException { + return childrenColumns.get(columnIndex).getFloat(index); + } + + @Override + public double getDouble(int columnIndex) throws IOException { + return childrenColumns.get(columnIndex).getDouble(index); + } + + @Override + public boolean getBoolean(int columnIndex) throws IOException { + return childrenColumns.get(columnIndex).getBoolean(index); + } + + @Override + public Binary getBinary(int columnIndex) throws IOException { + return new Binary(childrenColumns.get(columnIndex).getBinary(index).getValues()); + } + + @Override + public String getString(int columnIndex) throws IOException { + return childrenColumns.get(columnIndex).getBinary(index).toString(); + } + + @Override + public Type getDataType(int columnIndex) { + return UDFDataTypeTransformer.transformToUDFDataType( + childrenColumns.get(columnIndex).getDataType()); + } + + @Override + public boolean isNull(int columnIndex) throws IOException { + return childrenColumns.get(columnIndex).isNull(index); + } + + @Override + public int size() { + return childrenColumns.size(); + } + }; + } +} diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/TableUDFUtils.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/TableUDFUtils.java index bc2d42507a9..06bfa55feba 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/TableUDFUtils.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/TableUDFUtils.java @@ -25,30 +25,27 @@ import org.apache.iotdb.udf.api.relational.ScalarFunction; import org.apache.iotdb.udf.api.relational.TableFunction; public class TableUDFUtils { - public static boolean isScalarFunction(String functionName) { + public static ScalarFunction tryGetScalarFunction(String functionName) { try { - UDFManagementService.getInstance().reflect(functionName, ScalarFunction.class); - return true; + return UDFManagementService.getInstance().reflect(functionName, ScalarFunction.class); } catch (Throwable e) { - return false; + return null; } } - public static boolean isTableFunction(String functionName) { + public static TableFunction tryGetTableFunction(String functionName) { try { - UDFManagementService.getInstance().reflect(functionName, TableFunction.class); - return true; + return UDFManagementService.getInstance().reflect(functionName, TableFunction.class); } catch (Throwable e) { - return false; + return null; } } - public static boolean isAggregateFunction(String functionName) { + public static AggregateFunction tryGetAggregateFunction(String functionName) { try { - UDFManagementService.getInstance().reflect(functionName, AggregateFunction.class); - return true; + return UDFManagementService.getInstance().reflect(functionName, AggregateFunction.class); } catch (Throwable e) { - return false; + return null; } } }
