This is an automated email from the ASF dual-hosted git repository.

chenyz pushed a commit to branch table_udsf
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit ea028efd5e90d67ca143eafd60787c108294a987
Author: Chen YZ <[email protected]>
AuthorDate: Tue Nov 19 15:03:26 2024 +0800

    refactor
---
 .../apache/iotdb/udf/ScalarFunctionExample.java    |  42 +++
 .../iotdb/udf/api/access/ColumnToRowIterator.java  |  64 +++++
 .../customizer/config/ScalarFunctionConfig.java    |  30 +++
 .../customizer/parameter/FunctionParameters.java   |  56 ++++
 .../udf/api/relational/AggregationFunction.java    |   3 +
 .../iotdb/udf/api/relational/ScalarFunction.java   |  33 +++
 .../iotdb/udf/api/relational/TableFunction.java    |   3 +
 .../org/apache/iotdb/db/exp/ExpScalarFunction.java | 288 +++++++++++++++++++++
 .../UserDefineScalarFunctionTransformer.java       | 125 +++++++++
 9 files changed, 644 insertions(+)

diff --git 
a/example/udf/src/main/java/org/apache/iotdb/udf/ScalarFunctionExample.java 
b/example/udf/src/main/java/org/apache/iotdb/udf/ScalarFunctionExample.java
new file mode 100644
index 00000000000..e67d31cc03b
--- /dev/null
+++ b/example/udf/src/main/java/org/apache/iotdb/udf/ScalarFunctionExample.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.udf;
+
+import org.apache.iotdb.udf.api.customizer.parameter.FunctionParameters;
+import org.apache.iotdb.udf.api.relational.ScalarFunction;
+import org.apache.iotdb.udf.api.relational.data.Record;
+import org.apache.iotdb.udf.api.type.Type;
+
+/**
+ * Example scalar function that negates its single INT32 argument.
+ *
+ * <p>A null input yields a null output.
+ */
+public class ScalarFunctionExample implements ScalarFunction {
+
+  @Override
+  public Type validateAndInferOutputType(FunctionParameters parameters) throws Exception {
+    // evaluate() reads exactly one INT32 value, so reject any other signature here instead of
+    // failing (or misreading data) row by row at execution time.
+    if (parameters.getChildExpressionsSize() != 1) {
+      throw new IllegalArgumentException(
+          "ScalarFunctionExample expects exactly one argument, but got "
+              + parameters.getChildExpressionsSize());
+    }
+    if (parameters.getDataType(0) != Type.INT32) {
+      throw new IllegalArgumentException(
+          "ScalarFunctionExample expects an INT32 argument, but got " + parameters.getDataType(0));
+    }
+    return Type.INT32;
+  }
+
+  @Override
+  public Object evaluate(Record input) throws Exception {
+    if (input.isNull(0)) {
+      return null;
+    }
+    // Plain int negation: Integer.MIN_VALUE wraps around to itself.
+    return -input.getInt(0);
+  }
+}
diff --git 
a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/access/ColumnToRowIterator.java
 
b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/access/ColumnToRowIterator.java
new file mode 100644
index 00000000000..3881b536dd1
--- /dev/null
+++ 
b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/access/ColumnToRowIterator.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.udf.api.access;
+
+import org.apache.iotdb.udf.api.utils.RowImpl;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.enums.TSDataType;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * Exposes a list of columns as a {@link RowIterator}, one position at a time.
+ *
+ * <p>The same {@link Row} instance is returned by every call to {@link #next()} and is overwritten
+ * on each call, so callers must copy any values they need before advancing. Instances hold
+ * unsynchronized mutable state and are not meant for concurrent use.
+ */
+public class ColumnToRowIterator implements RowIterator {
+  private final RowImpl row;
+  private final List<Column> columnList;
+  private final Object[] rowRecord;
+  private final int positionCount;
+  private int curIndex = 0;
+
+  /**
+   * @param dataTypes data type of each column, in column order
+   * @param columnList columns to iterate; must contain one column per entry of {@code dataTypes}
+   * @param positionCount number of positions (rows) to expose
+   * @throws IllegalArgumentException if {@code dataTypes} and {@code columnList} differ in size
+   */
+  public ColumnToRowIterator(
+      List<TSDataType> dataTypes, List<Column> columnList, int positionCount) {
+    if (dataTypes.size() != columnList.size()) {
+      throw new IllegalArgumentException(
+          "Expected one data type per column, but got "
+              + dataTypes.size()
+              + " data types for "
+              + columnList.size()
+              + " columns");
+    }
+    this.rowRecord = new Object[dataTypes.size()];
+    this.columnList = columnList;
+    this.positionCount = positionCount;
+    this.row = new RowImpl(dataTypes.toArray(new TSDataType[0]), false);
+    this.row.setRowRecord(rowRecord);
+  }
+
+  @Override
+  public boolean hasNextRow() {
+    return curIndex < positionCount;
+  }
+
+  /**
+   * Loads the next position into the shared {@link Row} and returns it.
+   *
+   * @throws java.util.NoSuchElementException if all positions have already been returned
+   */
+  @Override
+  public Row next() throws IOException {
+    if (!hasNextRow()) {
+      throw new java.util.NoSuchElementException(
+          "No more rows: already returned " + positionCount + " of " + positionCount);
+    }
+    for (int i = 0; i < rowRecord.length; i++) {
+      Column column = columnList.get(i);
+      // Column#getObject is not guaranteed to return null for a null position (primitive columns
+      // may hand back a default value), so consult the null mask explicitly.
+      rowRecord[i] = column.isNull(curIndex) ? null : column.getObject(curIndex);
+    }
+    curIndex++;
+    return row;
+  }
+
+  /** Rewinds the iterator to the first position. */
+  @Override
+  public void reset() {
+    curIndex = 0;
+  }
+}
diff --git 
a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/customizer/config/ScalarFunctionConfig.java
 
b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/customizer/config/ScalarFunctionConfig.java
new file mode 100644
index 00000000000..adacabed085
--- /dev/null
+++ 
b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/customizer/config/ScalarFunctionConfig.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.udf.api.customizer.config;
+
+import org.apache.iotdb.udf.api.type.Type;
+
+/** {@link UDFConfigurations} for scalar functions, adding a fluent output-type setter. */
+public class ScalarFunctionConfig extends UDFConfigurations {
+
+  /**
+   * Stores the output data type in these configurations.
+   *
+   * @param outputDataType the output data type to record
+   * @return this configuration, so calls can be chained
+   */
+  public ScalarFunctionConfig setOutputDataType(Type outputDataType) {
+    super.outputDataType = outputDataType;
+    return this;
+  }
+}
diff --git 
a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/customizer/parameter/FunctionParameters.java
 
b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/customizer/parameter/FunctionParameters.java
new file mode 100644
index 00000000000..5af832cbaa7
--- /dev/null
+++ 
b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/customizer/parameter/FunctionParameters.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.udf.api.customizer.parameter;
+
+import org.apache.iotdb.udf.api.type.Type;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * The arguments a function is invoked with: the data type of each child expression plus a map of
+ * system attributes.
+ *
+ * <p>Both collections are copied on construction and exposed only as unmodifiable views, so
+ * neither the caller nor the function can change an instance after it is built.
+ */
+public class FunctionParameters {
+  private final List<Type> childExpressionDataTypes;
+  private final Map<String, String> systemAttributes;
+
+  /**
+   * @param childExpressionDataTypes data type of each child expression, in argument order
+   * @param systemAttributes system attributes keyed by attribute name
+   * @throws NullPointerException if either argument is {@code null}
+   */
+  public FunctionParameters(
+      List<Type> childExpressionDataTypes, Map<String, String> systemAttributes) {
+    this.childExpressionDataTypes =
+        Collections.unmodifiableList(
+            new ArrayList<>(
+                Objects.requireNonNull(childExpressionDataTypes, "childExpressionDataTypes")));
+    // LinkedHashMap keeps the caller's iteration order in the copy.
+    this.systemAttributes =
+        Collections.unmodifiableMap(
+            new LinkedHashMap<>(Objects.requireNonNull(systemAttributes, "systemAttributes")));
+  }
+
+  /** Returns an unmodifiable view of the child expression data types, in argument order. */
+  public List<Type> getChildExpressionDataTypes() {
+    return childExpressionDataTypes;
+  }
+
+  /** Returns the number of child expressions (arguments). */
+  public int getChildExpressionsSize() {
+    return childExpressionDataTypes.size();
+  }
+
+  /**
+   * Returns the data type of the argument at {@code index}.
+   *
+   * @throws IndexOutOfBoundsException if {@code index} is not a valid argument position
+   */
+  public Type getDataType(int index) {
+    return childExpressionDataTypes.get(index);
+  }
+
+  /** Returns whether a system attribute named {@code attributeKey} is present. */
+  public boolean hasSystemAttribute(String attributeKey) {
+    return systemAttributes.containsKey(attributeKey);
+  }
+
+  /** Returns an unmodifiable view of all system attributes. */
+  public Map<String, String> getSystemAttributes() {
+    return systemAttributes;
+  }
+}
diff --git 
a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/AggregationFunction.java
 
b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/AggregationFunction.java
new file mode 100644
index 00000000000..3d926059e6b
--- /dev/null
+++ 
b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/AggregationFunction.java
@@ -0,0 +1,3 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.udf.api.relational;
+
+/**
+ * User-defined aggregation function for the relational model.
+ *
+ * <p>Currently a placeholder: it declares no methods beyond those inherited from {@link
+ * SQLFunction}.
+ */
+public interface AggregationFunction extends SQLFunction {}
diff --git 
a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/ScalarFunction.java
 
b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/ScalarFunction.java
new file mode 100644
index 00000000000..d62c0185f10
--- /dev/null
+++ 
b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/ScalarFunction.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.udf.api.relational;
+
+import org.apache.iotdb.udf.api.customizer.parameter.FunctionParameters;
+import org.apache.iotdb.udf.api.relational.data.Record;
+import org.apache.iotdb.udf.api.type.Type;
+
+/** User-defined scalar function for the relational model: computes one value per input row. */
+public interface ScalarFunction extends SQLFunction {
+
+  /**
+   * Validates the {@link FunctionParameters} the function is invoked with and infers its output
+   * data type.
+   *
+   * @param parameters parameters used to validate
+   * @return the inferred output data type
+   * @throws Exception if any parameter is not valid
+   */
+  Type validateAndInferOutputType(FunctionParameters parameters) throws Exception;
+
+  /**
+   * Computes the function's value for a single input row. In a single query this method may be
+   * called multiple times.
+   *
+   * @param input original input data row
+   * @return the computed value, or {@code null} for a null result
+   * @throws Exception the user can throw errors if necessary
+   */
+  Object evaluate(Record input) throws Exception;
+
+  /** This method is mainly used to release the resources used in the SQLFunction. */
+  default void beforeDestroy() {
+    // do nothing
+  }
+}
diff --git 
a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/TableFunction.java
 
b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/TableFunction.java
new file mode 100644
index 00000000000..8d24f957ef1
--- /dev/null
+++ 
b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/TableFunction.java
@@ -0,0 +1,3 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.udf.api.relational;
+
+/**
+ * User-defined table function for the relational model.
+ *
+ * <p>Currently a placeholder: it declares no methods beyond those inherited from {@link
+ * SQLFunction}.
+ */
+public interface TableFunction extends SQLFunction {}
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exp/ExpScalarFunction.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exp/ExpScalarFunction.java
new file mode 100644
index 00000000000..709e17164f2
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exp/ExpScalarFunction.java
@@ -0,0 +1,288 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.exp;
+
+import org.apache.iotdb.db.queryengine.transformation.dag.util.TypeUtils;
+import org.apache.iotdb.udf.api.access.Row;
+import org.apache.iotdb.udf.api.type.Binary;
+import org.apache.iotdb.udf.api.type.Type;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.read.common.block.column.LongColumn;
+
+import java.io.IOException;
+import java.lang.invoke.MethodHandle;
+import java.lang.invoke.MethodHandles;
+import java.lang.invoke.MethodType;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+public class ExpScalarFunction {
+
+  void processBatch(Column[] columns, ColumnBuilder builder) {
+    int count = columns[0].getPositionCount();
+    for (int i = 0; i < count; i++) {
+      long sum = 0;
+      for (int j = 0; j < columns.length; j++) {
+        sum += columns[j].getLong(i);
+      }
+      builder.writeLong(Math.abs(sum));
+    }
+  }
+
+  void processSingle(Column[] columns, ColumnBuilder builder) {
+    Object[] row = new Object[columns.length];
+    for (int i = 0; i < columns[0].getPositionCount(); i++) {
+      for (int j = 0; j < columns.length; j++) {
+        row[j] = columns[j].getObject(i);
+      }
+      builder.writeLong((long) processSingle(row));
+    }
+  }
+
+  Object processSingle(Object[] inputs) {
+    long sum = 0;
+    for (int i = 0; i < inputs.length; i++) {
+      if (inputs[i] != null) {
+        sum += (long) inputs[i];
+      }
+    }
+    return Math.abs(sum);
+  }
+
+  void processRow(Column[] columns, ColumnBuilder builder) throws IOException {
+    for (int i = 0; i < columns[0].getPositionCount(); i++) {
+      final int finalI = i;
+      builder.writeLong(
+          (long)
+              processRow(
+                  new Row() {
+                    @Override
+                    public long getTime() throws IOException {
+                      return 0;
+                    }
+
+                    @Override
+                    public int getInt(int columnIndex) throws IOException {
+                      return 0;
+                    }
+
+                    @Override
+                    public long getLong(int columnIndex) throws IOException {
+                      return columns[columnIndex].getLong(finalI);
+                    }
+
+                    @Override
+                    public float getFloat(int columnIndex) throws IOException {
+                      return 0;
+                    }
+
+                    @Override
+                    public double getDouble(int columnIndex) throws 
IOException {
+                      return 0;
+                    }
+
+                    @Override
+                    public boolean getBoolean(int columnIndex) throws 
IOException {
+                      return false;
+                    }
+
+                    @Override
+                    public Binary getBinary(int columnIndex) throws 
IOException {
+                      return null;
+                    }
+
+                    @Override
+                    public String getString(int columnIndex) throws 
IOException {
+                      return "";
+                    }
+
+                    @Override
+                    public Type getDataType(int columnIndex) {
+                      return null;
+                    }
+
+                    @Override
+                    public boolean isNull(int columnIndex) throws IOException {
+                      return false;
+                    }
+
+                    @Override
+                    public int size() {
+                      return columns.length;
+                    }
+                  }));
+    }
+  }
+
+  Object processRow(Row row) throws IOException {
+    long sum = 0;
+    for (int i = 0; i < row.size(); i++) {
+      sum += row.getLong(i);
+    }
+    return Math.abs(sum);
+  }
+
+  static void processMethodHandle1(
+      Column[] columns, ColumnBuilder builder, MethodHandle processRawHandle) 
throws Throwable {
+    Object[] row = new Object[columns.length];
+    for (int i = 0; i < columns[0].getPositionCount(); i++) {
+      for (int j = 0; j < columns.length; j++) {
+        row[j] = columns[j].getObject(i);
+      }
+      builder.writeLong((long) processRawHandle.invokeWithArguments(row));
+    }
+  }
+
+  static void processMethodHandle2(
+      Column[] columns, ColumnBuilder builder, MethodHandle processRawHandle) 
throws Throwable {
+    for (int i = 0; i < columns[0].getPositionCount(); i++) {
+      long a = columns[0].getLong(i);
+      long b = columns[1].getLong(i);
+      long c = columns[2].getLong(i);
+      long d = columns[3].getLong(i);
+      long e = columns[4].getLong(i);
+      builder.writeLong((long) processRawHandle.invoke(a, b, c, d, e));
+    }
+  }
+
+  static long processRaw(long a, long b, long c, long d, long e) {
+    return Math.abs(a + b + c + d + e);
+  }
+
+  public static void main(String[] args) {
+    int COLUMN_COUNT = 5;
+    int ROW_COUNT = 50000;
+    int LOOP_COUNT = 100;
+    int EPOCH = 10;
+    ExpScalarFunction expScalarFunction = new ExpScalarFunction();
+    // create column[] and columnBuilder
+    List<Column[]> columns = new ArrayList<>();
+    for (int loop = 0; loop < LOOP_COUNT; loop++) {
+      columns.add(new LongColumn[COLUMN_COUNT]);
+      for (int i = 0; i < COLUMN_COUNT; i++) {
+        long[] values = new long[ROW_COUNT];
+        for (int j = 0; j < ROW_COUNT; j++) {
+          // random value
+          values[j] = (long) (Math.random() * 100);
+        }
+        columns.get(loop)[i] = new LongColumn(ROW_COUNT, Optional.empty(), 
values);
+      }
+    }
+    System.out.println("Start testing...");
+
+    try {
+      // 获取 processRaw 方法的 MethodHandle
+      MethodHandle processRawHandle =
+          MethodHandles.lookup()
+              .findStatic(
+                  ExpScalarFunction.class,
+                  "processRaw",
+                  MethodType.methodType(
+                      long.class, long.class, long.class, long.class, 
long.class, long.class));
+      long startTime = System.currentTimeMillis();
+      //            for (int j = 0; j < EPOCH; j++) {
+      //                for (int i = 0; i < LOOP_COUNT; i++) {
+      //                    ColumnBuilder builder = 
TypeUtils.initColumnBuilder(TSDataType.INT64,
+      // ROW_COUNT);
+      //                    try {
+      //                        processMethodHandle1(columns.get(i), builder, 
processRawHandle);
+      //                    } catch (Throwable throwable) {
+      //                        throwable.printStackTrace();
+      //                    }
+      //                }
+      //            }
+      long endTime = System.currentTimeMillis();
+      System.out.println(
+          "Process by processRawHandle#invokeWithArguments time: " + (endTime 
- startTime) + "ms");
+
+      startTime = System.currentTimeMillis();
+      for (int j = 0; j < EPOCH; j++) {
+        for (int i = 0; i < LOOP_COUNT; i++) {
+          ColumnBuilder builder = 
TypeUtils.initColumnBuilder(TSDataType.INT64, ROW_COUNT);
+          try {
+            processMethodHandle2(columns.get(i), builder, processRawHandle);
+          } catch (Throwable throwable) {
+            throwable.printStackTrace();
+          }
+
+          if (i == 0) {
+            System.out.println(builder.build().getLong(ROW_COUNT - 1));
+          }
+        }
+      }
+      endTime = System.currentTimeMillis();
+      System.out.println(
+          "Process by processRawHandle#invoke time: " + (endTime - startTime) 
+ "ms");
+    } catch (NoSuchMethodException | IllegalAccessException e) {
+      e.printStackTrace();
+    }
+
+    // batch
+    long startTime = System.currentTimeMillis();
+    for (int j = 0; j < EPOCH; j++) {
+      for (int i = 0; i < LOOP_COUNT; i++) {
+        ColumnBuilder builder = TypeUtils.initColumnBuilder(TSDataType.INT64, 
ROW_COUNT);
+        expScalarFunction.processBatch(columns.get(i), builder);
+        if (i == 0) {
+          System.out.println(builder.build().getLong(ROW_COUNT - 1));
+        }
+      }
+    }
+    long endTime = System.currentTimeMillis();
+    System.out.println("Batch time: " + (endTime - startTime) + "ms");
+
+    // single
+    startTime = System.currentTimeMillis();
+    for (int j = 0; j < EPOCH; j++) {
+      for (int i = 0; i < LOOP_COUNT; i++) {
+        ColumnBuilder builder = TypeUtils.initColumnBuilder(TSDataType.INT64, 
ROW_COUNT);
+        expScalarFunction.processSingle(columns.get(i), builder);
+        builder.build().getLong(0);
+        if (i == 0) {
+          System.out.println(builder.build().getLong(ROW_COUNT - 1));
+        }
+      }
+    }
+    endTime = System.currentTimeMillis();
+    System.out.println("Single time: " + (endTime - startTime) + "ms");
+
+    // row
+    startTime = System.currentTimeMillis();
+    for (int j = 0; j < EPOCH; j++) {
+      for (int i = 0; i < LOOP_COUNT; i++) {
+        ColumnBuilder builder = TypeUtils.initColumnBuilder(TSDataType.INT64, 
ROW_COUNT);
+        try {
+          expScalarFunction.processRow(columns.get(i), builder);
+        } catch (IOException e) {
+          e.printStackTrace();
+        }
+        if (i == 0) {
+          System.out.println(builder.build().getLong(ROW_COUNT - 1));
+        }
+      }
+    }
+    endTime = System.currentTimeMillis();
+    System.out.println("Row time: " + (endTime - startTime) + "ms");
+  }
+}
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/column/unary/scalar/UserDefineScalarFunctionTransformer.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/column/unary/scalar/UserDefineScalarFunctionTransformer.java
new file mode 100644
index 00000000000..3f72c677a80
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/column/unary/scalar/UserDefineScalarFunctionTransformer.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.transformation.dag.column.unary.scalar;
+
+import org.apache.iotdb.commons.udf.service.UDFManagementService;
+import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer;
+import org.apache.iotdb.db.exception.sql.SemanticException;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
+import 
org.apache.iotdb.db.queryengine.transformation.dag.column.ColumnTransformer;
+import 
org.apache.iotdb.db.queryengine.transformation.dag.column.multi.MultiColumnTransformer;
+import 
org.apache.iotdb.db.queryengine.transformation.dag.udf.UDFParametersFactory;
+import org.apache.iotdb.udf.api.access.ColumnToRowIterator;
+import org.apache.iotdb.udf.api.customizer.parameter.UDFParameters;
+import org.apache.iotdb.udf.api.relational.ScalarFunction;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.read.common.type.Type;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+// TODO(UDSF): encapsulate reflect and validate logic
+/**
+ * Column transformer that applies a user-defined {@link ScalarFunction} to the columns produced by
+ * its child transformers.
+ *
+ * <p>Work in progress: the function is resolved and the children's data types are collected, but
+ * per-row evaluation is not wired up yet. {@link ScalarFunction#evaluate} takes a {@code Record},
+ * whereas {@link ColumnToRowIterator} produces {@code Row}s. Until that is bridged, transforming a
+ * non-empty input throws {@link UnsupportedOperationException}. Without that check the output
+ * column would silently receive no values.
+ */
+public class UserDefineScalarFunctionTransformer extends MultiColumnTransformer {
+
+  private final ScalarFunction scalarFunction;
+  private final List<TSDataType> childrenTypes;
+
+  public UserDefineScalarFunctionTransformer(
+      Type returnType,
+      String functionName,
+      List<Expression> children,
+      List<ColumnTransformer> childrenTransformers) {
+    super(returnType, childrenTransformers);
+    ScalarFunction scalarFunction =
+        UDFManagementService.getInstance().reflect(functionName, ScalarFunction.class);
+    this.childrenTypes =
+        childrenTransformers.stream()
+            .map(ColumnTransformer::getType)
+            .map(UDFDataTypeTransformer::transformReadTypeToTSDataType)
+            .collect(Collectors.toList());
+    // TODO: 1. Table UDFs should no longer be passed String expressions.
+    // TODO: 2. Find a way to obtain the function's attributes.
+    UDFParameters udfParameters =
+        UDFParametersFactory.buildUdfParameters(
+            children.stream().map(Expression::toString).collect(Collectors.toList()),
+            childrenTypes,
+            Collections.emptyMap());
+    try {
+      // TODO(UDSF): validate through scalarFunction.validateAndInferOutputType(...) once the
+      //  arguments can be expressed as FunctionParameters instead of udfParameters.
+    } catch (Exception e) {
+      throw new SemanticException(e.getMessage());
+    }
+
+    this.scalarFunction = scalarFunction;
+  }
+
+  @Override
+  protected void doTransform(
+      List<Column> childrenColumns, ColumnBuilder builder, int positionCount) {
+    // TODO(UDSF): for each position, build a Record from childrenColumns, call
+    //  scalarFunction.evaluate, and append the result (or a null when it returns null).
+    // With no positions there is nothing to evaluate and an empty output is already correct.
+    if (positionCount > 0) {
+      throw notImplemented();
+    }
+  }
+
+  @Override
+  protected void doTransform(
+      List<Column> childrenColumns, ColumnBuilder builder, int positionCount, boolean[] selection) {
+    // TODO(UDSF): same as the unselected overload, but honor `selection`. NOTE(review): confirm
+    //  the polarity against MultiColumnTransformer. Presumably true marks rows to evaluate and the
+    //  others get a null. Rows that are not selected should not be evaluated at all.
+    if (positionCount > 0) {
+      throw notImplemented();
+    }
+  }
+
+  /** Builds the error raised while per-row evaluation is not yet implemented. */
+  private UnsupportedOperationException notImplemented() {
+    return new UnsupportedOperationException(
+        "Evaluation of user-defined scalar function "
+            + scalarFunction.getClass().getName()
+            + " is not implemented yet");
+  }
+
+  @Override
+  protected void checkType() {
+    // TODO: implement this method
+  }
+}

Reply via email to