This is an automated email from the ASF dual-hosted git repository.

chenyz pushed a commit to branch udsf
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit b8a4fb6ab45adfbeb5f992b3ef4494795839c859
Author: Chen YZ <[email protected]>
AuthorDate: Tue Nov 19 16:43:56 2024 +0800

    save
---
 .../confignode/it/IoTDBConfigNodeSnapshotIT.java   |   1 -
 .../iotdb/udf/api/relational/ScalarFunction.java   |  12 ++-
 .../relational/ColumnTransformerBuilder.java       |  32 ++++--
 .../relational/metadata/TableMetadataImpl.java     |  13 ++-
 .../udf/UserDefineScalarFunctionTransformer.java   | 109 ++++++++-------------
 .../iotdb/commons/udf/access/RecordIterator.java   | 105 ++++++++++++++++++++
 .../iotdb/commons/udf/utils/TableUDFUtils.java     |  21 ++--
 7 files changed, 194 insertions(+), 99 deletions(-)

diff --git a/integration-test/src/test/java/org/apache/iotdb/confignode/it/IoTDBConfigNodeSnapshotIT.java b/integration-test/src/test/java/org/apache/iotdb/confignode/it/IoTDBConfigNodeSnapshotIT.java
index 03faf93585c..7a9d99006a8 100644
--- a/integration-test/src/test/java/org/apache/iotdb/confignode/it/IoTDBConfigNodeSnapshotIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/confignode/it/IoTDBConfigNodeSnapshotIT.java
@@ -31,7 +31,6 @@ import org.apache.iotdb.commons.path.PathDeserializeUtil;
 import org.apache.iotdb.commons.trigger.TriggerInformation;
 import org.apache.iotdb.commons.trigger.service.TriggerExecutableManager;
 import org.apache.iotdb.commons.udf.UDFInformation;
-import 
org.apache.iotdb.confignode.consensus.request.read.function.GetFunctionTablePlan;
 import org.apache.iotdb.confignode.rpc.thrift.TCQEntry;
 import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq;
 import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq;
diff --git a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/ScalarFunction.java b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/ScalarFunction.java
index bd0586d051c..6f52103805e 100644
--- a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/ScalarFunction.java
+++ b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/ScalarFunction.java
@@ -26,12 +26,20 @@ import org.apache.iotdb.udf.api.type.Type;
 public interface ScalarFunction extends SQLFunction {
 
   /**
-   * This method is mainly used to validate {@link FunctionParameters} and 
infer output data type.
+   * This method is used to validate {@link FunctionParameters}.
    *
    * @param parameters parameters used to validate
    * @throws Exception if any parameter is not valid
    */
-  Type validateAndInferOutputType(FunctionParameters parameters) throws 
Exception;
+  void validate(FunctionParameters parameters) throws Exception;
+
+  /**
+   * This method is used to infer the output data type of the transformation.
+   *
+   * @param parameters input parameters
+   * @return the output data type
+   */
+  Type inferOutputType(FunctionParameters parameters);
 
   /**
    * This method will be called to process the transformation. In a single UDF 
query, this method
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java
index 26b30bec4f8..429ffdeadac 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java
@@ -19,7 +19,8 @@
 
 package org.apache.iotdb.db.queryengine.execution.relational;
 
-import org.apache.iotdb.commons.udf.service.UDFManagementService;
+import org.apache.iotdb.commons.udf.utils.TableUDFUtils;
+import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer;
 import org.apache.iotdb.db.exception.sql.SemanticException;
 import org.apache.iotdb.db.queryengine.common.SessionInfo;
 import org.apache.iotdb.db.queryengine.plan.analyze.TypeProvider;
@@ -95,6 +96,7 @@ import 
org.apache.iotdb.db.queryengine.transformation.dag.column.multi.LogicalAn
 import 
org.apache.iotdb.db.queryengine.transformation.dag.column.multi.LogicalOrMultiColumnTransformer;
 import 
org.apache.iotdb.db.queryengine.transformation.dag.column.ternary.BetweenColumnTransformer;
 import 
org.apache.iotdb.db.queryengine.transformation.dag.column.ternary.Like3ColumnTransformer;
+import 
org.apache.iotdb.db.queryengine.transformation.dag.column.udf.UserDefineScalarFunctionTransformer;
 import 
org.apache.iotdb.db.queryengine.transformation.dag.column.unary.IsNullColumnTransformer;
 import 
org.apache.iotdb.db.queryengine.transformation.dag.column.unary.LikeColumnTransformer;
 import 
org.apache.iotdb.db.queryengine.transformation.dag.column.unary.LogicNotColumnTransformer;
@@ -153,7 +155,7 @@ import 
org.apache.iotdb.db.queryengine.transformation.dag.column.unary.scalar.Tr
 import 
org.apache.iotdb.db.queryengine.transformation.dag.column.unary.scalar.TrimColumnTransformer;
 import 
org.apache.iotdb.db.queryengine.transformation.dag.column.unary.scalar.TryCastFunctionColumnTransformer;
 import 
org.apache.iotdb.db.queryengine.transformation.dag.column.unary.scalar.UpperColumnTransformer;
-import 
org.apache.iotdb.db.queryengine.transformation.dag.column.unary.scalar.UserDefineScalarFunctionTransformer;
+import org.apache.iotdb.udf.api.customizer.parameter.FunctionParameters;
 import org.apache.iotdb.udf.api.relational.ScalarFunction;
 
 import org.apache.tsfile.common.conf.TSFileConfig;
@@ -173,6 +175,7 @@ import org.apache.tsfile.utils.Binary;
 import java.time.ZoneId;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -999,13 +1002,24 @@ public class ColumnTransformerBuilder
           source,
           ((LongLiteral) children.get(3)).getParsedValue(),
           context.sessionInfo.getZoneId());
-    } else if (UDFManagementService.getInstance()
-        .isAssignableFrom(functionName, ScalarFunction.class)) {
-      List<ColumnTransformer> childrenColumnTransformer =
-          children.stream().map(child -> process(child, 
context)).collect(Collectors.toList());
-      // TODO(UDSF): check the return type of the function
-      return new UserDefineScalarFunctionTransformer(
-          INT32, functionName, children, childrenColumnTransformer);
+    } else {
+      // user defined function
+      ScalarFunction scalarFunction = 
TableUDFUtils.tryGetScalarFunction(functionName);
+      if (scalarFunction != null) {
+        List<ColumnTransformer> childrenColumnTransformer =
+            children.stream().map(child -> process(child, 
context)).collect(Collectors.toList());
+        FunctionParameters parameters =
+            new FunctionParameters(
+                childrenColumnTransformer.stream()
+                    .map(i -> 
UDFDataTypeTransformer.transformReadTypeToUDFDataType(i.getType()))
+                    .collect(Collectors.toList()),
+                Collections.emptyMap());
+        Type returnType =
+            UDFDataTypeTransformer.transformUDFDataTypeToReadType(
+                scalarFunction.inferOutputType(parameters));
+        return new UserDefineScalarFunctionTransformer(
+            returnType, scalarFunction, childrenColumnTransformer);
+      }
     }
     throw new IllegalArgumentException(String.format("Unknown function: %s", 
functionName));
   }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java
index e127de7186d..d8c0c8c9b89 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java
@@ -23,7 +23,6 @@ import org.apache.iotdb.commons.partition.DataPartition;
 import org.apache.iotdb.commons.partition.DataPartitionQueryParam;
 import org.apache.iotdb.commons.partition.SchemaPartition;
 import org.apache.iotdb.commons.schema.table.TsTable;
-import org.apache.iotdb.commons.udf.service.UDFManagementService;
 import org.apache.iotdb.commons.udf.utils.TableUDFUtils;
 import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer;
 import org.apache.iotdb.db.exception.sql.SemanticException;
@@ -622,10 +621,9 @@ public class TableMetadataImpl implements Metadata {
         // ignore
     }
 
-    // 根据 argumentTypes 获取返回类型,这边暂时先 mock 一个 INT32
-    if (TableUDFUtils.isScalarFunction(functionName)) {
-      ScalarFunction scalarFunction =
-          UDFManagementService.getInstance().reflect(functionName, 
ScalarFunction.class);
+    // User-defined scalar function
+    ScalarFunction scalarFunction = 
TableUDFUtils.tryGetScalarFunction(functionName);
+    if (scalarFunction != null) {
       FunctionParameters functionParameters =
           new FunctionParameters(
               argumentTypes.stream()
@@ -633,11 +631,12 @@ public class TableMetadataImpl implements Metadata {
                   .collect(Collectors.toList()),
               Collections.emptyMap());
       try {
-        return UDFDataTypeTransformer.transformUDFDataTypeToReadType(
-            scalarFunction.validateAndInferOutputType(functionParameters));
+        scalarFunction.validate(functionParameters);
       } catch (Exception e) {
         throw new SemanticException("Invalid function parameters: " + 
e.getMessage());
       }
+      return UDFDataTypeTransformer.transformUDFDataTypeToReadType(
+          scalarFunction.inferOutputType(functionParameters));
     }
 
     // TODO UDAF
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/column/udf/UserDefineScalarFunctionTransformer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/column/udf/UserDefineScalarFunctionTransformer.java
index 7689bbcca9d..c81c42135d5 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/column/udf/UserDefineScalarFunctionTransformer.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/column/udf/UserDefineScalarFunctionTransformer.java
@@ -19,107 +19,80 @@
 
 package org.apache.iotdb.db.queryengine.transformation.dag.column.udf;
 
-import org.apache.iotdb.commons.udf.service.UDFManagementService;
-import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer;
-import org.apache.iotdb.db.exception.sql.SemanticException;
-import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
+import org.apache.iotdb.commons.udf.access.RecordIterator;
 import 
org.apache.iotdb.db.queryengine.transformation.dag.column.ColumnTransformer;
 import 
org.apache.iotdb.db.queryengine.transformation.dag.column.multi.MultiColumnTransformer;
-import 
org.apache.iotdb.db.queryengine.transformation.dag.udf.UDFParametersFactory;
-import org.apache.iotdb.udf.api.access.ColumnToRowIterator;
-import org.apache.iotdb.udf.api.customizer.parameter.UDFParameters;
 import org.apache.iotdb.udf.api.relational.ScalarFunction;
+import org.apache.iotdb.udf.api.relational.access.Record;
 
 import org.apache.tsfile.block.column.Column;
 import org.apache.tsfile.block.column.ColumnBuilder;
-import org.apache.tsfile.enums.TSDataType;
 import org.apache.tsfile.read.common.type.Type;
 
-import java.util.Collections;
 import java.util.List;
-import java.util.stream.Collectors;
 
-// TODO(UDSF): encapsulate refect and validate logic
 public class UserDefineScalarFunctionTransformer extends 
MultiColumnTransformer {
 
   private final ScalarFunction scalarFunction;
-  private final List<TSDataType> childrenTypes;
 
   public UserDefineScalarFunctionTransformer(
       Type returnType,
-      String functionName,
-      List<Expression> children,
+      ScalarFunction scalarFunction,
       List<ColumnTransformer> childrenTransformers) {
     super(returnType, childrenTransformers);
-    ScalarFunction scalarFunction =
-        UDFManagementService.getInstance().reflect(functionName, 
ScalarFunction.class);
-    this.childrenTypes =
-        childrenTransformers.stream()
-            .map(ColumnTransformer::getType)
-            .map(UDFDataTypeTransformer::transformReadTypeToTSDataType)
-            .collect(Collectors.toList());
-    // TODO: 1、Table UDF 里不应该再用 String Expression 了
-    // TODO:2、想办法弄到 attributes
-    UDFParameters udfParameters =
-        UDFParametersFactory.buildUdfParameters(
-            
children.stream().map(Expression::toString).collect(Collectors.toList()),
-            childrenTypes,
-            Collections.emptyMap());
-    try {
-      //      scalarFunction.validate(new 
UDFParameterValidator(udfParameters));
-      //      scalarFunction.beforeStart(udfParameters, new 
ScalarFunctionConfig());
-    } catch (Exception e) {
-      throw new SemanticException(e.getMessage());
-    }
-
     this.scalarFunction = scalarFunction;
   }
 
   @Override
   protected void doTransform(
       List<Column> childrenColumns, ColumnBuilder builder, int positionCount) {
-    ColumnToRowIterator iterator =
-        new ColumnToRowIterator(childrenTypes, childrenColumns, positionCount);
-    //    while (iterator.hasNextRow()) {
-    //      try {
-    //        Row row = iterator.next();
-    //        Object result = scalarFunction.evaluate(row);
-    //        if (result == null) {
-    //          builder.appendNull();
-    //        } else {
-    //          builder.writeObject(result);
-    //        }
-    //      } catch (Exception e) {
-    //        throw new RuntimeException(
-    //            "Error occurs when evaluating UDF " + 
scalarFunction.getClass().getName(), e);
-    //      }
-    //    }
+    RecordIterator iterator = new RecordIterator(childrenColumns, 
positionCount);
+    while (iterator.hasNext()) {
+      try {
+        Object result = scalarFunction.evaluate(iterator.next());
+        if (result == null) {
+          builder.appendNull();
+        } else {
+          builder.writeObject(result);
+        }
+      } catch (Exception e) {
+        throw new RuntimeException(
+            "Error occurs when evaluating user-defined scalar function "
+                + scalarFunction.getClass().getName(),
+            e);
+      }
+    }
   }
 
   @Override
   protected void doTransform(
       List<Column> childrenColumns, ColumnBuilder builder, int positionCount, 
boolean[] selection) {
-    ColumnToRowIterator iterator =
-        new ColumnToRowIterator(childrenTypes, childrenColumns, positionCount);
+    RecordIterator iterator = new RecordIterator(childrenColumns, 
positionCount);
     int i = 0;
-    //    while (iterator.hasNextRow()) {
-    //      try {
-    //        Row row = iterator.next();
-    //        Object result = scalarFunction.evaluate(row);
-    //        if (selection[i++] || result == null) {
-    //          builder.appendNull();
-    //        } else {
-    //          builder.writeObject(result);
-    //        }
-    //      } catch (Exception e) {
-    //        throw new RuntimeException(
-    //            "Error occurs when evaluating UDF " + 
scalarFunction.getClass().getName(), e);
-    //      }
-    //    }
+    while (iterator.hasNext()) {
+      try {
+        Record input = iterator.next();
+        if (selection[i++]) {
+          builder.appendNull();
+          continue;
+        }
+        Object result = scalarFunction.evaluate(input);
+        if (result == null) {
+          builder.appendNull();
+        } else {
+          builder.writeObject(result);
+        }
+      } catch (Exception e) {
+        throw new RuntimeException(
+            "Error occurs when evaluating user-defined scalar function "
+                + scalarFunction.getClass().getName(),
+            e);
+      }
+    }
   }
 
   @Override
   protected void checkType() {
-    // TODO: implement this method
+    // do nothing
   }
 }
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/access/RecordIterator.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/access/RecordIterator.java
new file mode 100644
index 00000000000..29f473b3c9d
--- /dev/null
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/access/RecordIterator.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.commons.udf.access;
+
+import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer;
+import org.apache.iotdb.udf.api.relational.access.Record;
+import org.apache.iotdb.udf.api.type.Binary;
+import org.apache.iotdb.udf.api.type.Type;
+
+import org.apache.tsfile.block.column.Column;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.List;
+
+public class RecordIterator implements Iterator<Record> {
+
+  private final List<Column> childrenColumns;
+  private final int positionCount;
+  private int currentIndex;
+
+  public RecordIterator(List<Column> childrenColumns, int positionCount) {
+    this.childrenColumns = childrenColumns;
+    this.positionCount = positionCount;
+  }
+
+  @Override
+  public boolean hasNext() {
+    return currentIndex < positionCount;
+  }
+
+  @Override
+  public Record next() {
+    final int index = currentIndex++;
+    return new Record() {
+      @Override
+      public int getInt(int columnIndex) throws IOException {
+        return childrenColumns.get(columnIndex).getInt(index);
+      }
+
+      @Override
+      public long getLong(int columnIndex) throws IOException {
+        return childrenColumns.get(columnIndex).getLong(index);
+      }
+
+      @Override
+      public float getFloat(int columnIndex) throws IOException {
+        return childrenColumns.get(columnIndex).getFloat(index);
+      }
+
+      @Override
+      public double getDouble(int columnIndex) throws IOException {
+        return childrenColumns.get(columnIndex).getDouble(index);
+      }
+
+      @Override
+      public boolean getBoolean(int columnIndex) throws IOException {
+        return childrenColumns.get(columnIndex).getBoolean(index);
+      }
+
+      @Override
+      public Binary getBinary(int columnIndex) throws IOException {
+        return new 
Binary(childrenColumns.get(columnIndex).getBinary(index).getValues());
+      }
+
+      @Override
+      public String getString(int columnIndex) throws IOException {
+        return childrenColumns.get(columnIndex).getBinary(index).toString();
+      }
+
+      @Override
+      public Type getDataType(int columnIndex) {
+        return UDFDataTypeTransformer.transformToUDFDataType(
+            childrenColumns.get(columnIndex).getDataType());
+      }
+
+      @Override
+      public boolean isNull(int columnIndex) throws IOException {
+        return childrenColumns.get(columnIndex).isNull(index);
+      }
+
+      @Override
+      public int size() {
+        return childrenColumns.size();
+      }
+    };
+  }
+}
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/TableUDFUtils.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/TableUDFUtils.java
index bc2d42507a9..06bfa55feba 100644
--- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/TableUDFUtils.java
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/TableUDFUtils.java
@@ -25,30 +25,27 @@ import org.apache.iotdb.udf.api.relational.ScalarFunction;
 import org.apache.iotdb.udf.api.relational.TableFunction;
 
 public class TableUDFUtils {
-  public static boolean isScalarFunction(String functionName) {
+  public static ScalarFunction tryGetScalarFunction(String functionName) {
     try {
-      UDFManagementService.getInstance().reflect(functionName, 
ScalarFunction.class);
-      return true;
+      return UDFManagementService.getInstance().reflect(functionName, 
ScalarFunction.class);
     } catch (Throwable e) {
-      return false;
+      return null;
     }
   }
 
-  public static boolean isTableFunction(String functionName) {
+  public static TableFunction tryGetTableFunction(String functionName) {
     try {
-      UDFManagementService.getInstance().reflect(functionName, 
TableFunction.class);
-      return true;
+      return UDFManagementService.getInstance().reflect(functionName, 
TableFunction.class);
     } catch (Throwable e) {
-      return false;
+      return null;
     }
   }
 
-  public static boolean isAggregateFunction(String functionName) {
+  public static AggregateFunction tryGetAggregateFunction(String functionName) 
{
     try {
-      UDFManagementService.getInstance().reflect(functionName, 
AggregateFunction.class);
-      return true;
+      return UDFManagementService.getInstance().reflect(functionName, 
AggregateFunction.class);
     } catch (Throwable e) {
-      return false;
+      return null;
     }
   }
 }

Reply via email to