This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new c1580924bca Support built-in forecast function through UDTF for tree 
model (#15682)
c1580924bca is described below

commit c1580924bcad91a04a066160eed5552ce59d9a13
Author: YangCaiyin <[email protected]>
AuthorDate: Mon Oct 27 20:12:21 2025 +0800

    Support built-in forecast function through UDTF for tree model (#15682)
---
 .../constant/BuiltinAggregationFunctionEnum.java   |   7 +-
 .../BuiltinTimeSeriesGeneratingFunctionEnum.java   |   1 +
 .../ainode/it/AINodeConcurrentInferenceIT.java     |  31 ++-
 .../execution/operator/AggregationUtil.java        |   2 +-
 .../plan/analyze/ExpressionAnalyzer.java           |   4 +-
 .../config/metadata/ShowFunctionsTask.java         |   6 +-
 .../plan/expression/ExpressionFactory.java         |   2 +-
 .../plan/expression/multi/FunctionExpression.java  |   4 +-
 .../plan/optimization/AggregationPushDown.java     |   2 +-
 .../plan/udf}/BuiltinAggregationFunction.java      |   2 +-
 .../plan/udf}/BuiltinScalarFunction.java           |   2 +-
 .../udf}/BuiltinTimeSeriesGeneratingFunction.java  |  51 +++-
 .../queryengine/plan/udf/UDFManagementService.java |   3 -
 .../db/queryengine/plan/udf/UDTFForecast.java      | 273 +++++++++++++++++++++
 .../plan/relational/analyzer/TSBSMetadata.java     |   2 +-
 .../plan/relational/analyzer/TestMetadata.java     |   2 +-
 16 files changed, 361 insertions(+), 33 deletions(-)

diff --git 
a/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java
 
b/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java
index 7c2c283d30e..85721cfb7a1 100644
--- 
a/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java
+++ 
b/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java
@@ -56,8 +56,11 @@ public enum BuiltinAggregationFunctionEnum {
 
   private static final Set<String> NATIVE_FUNCTION_NAMES =
       new HashSet<>(
-          
Arrays.stream(org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction.values())
-              
.map(org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction::getFunctionName)
+          Arrays.stream(
+                  
org.apache.iotdb.db.queryengine.plan.udf.BuiltinAggregationFunction.values())
+              .map(
+                  
org.apache.iotdb.db.queryengine.plan.udf.BuiltinAggregationFunction
+                      ::getFunctionName)
               .collect(Collectors.toList()));
 
   public static Set<String> getNativeFunctionNames() {
diff --git 
a/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinTimeSeriesGeneratingFunctionEnum.java
 
b/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinTimeSeriesGeneratingFunctionEnum.java
index fa0c5b5b148..5167c52f4d5 100644
--- 
a/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinTimeSeriesGeneratingFunctionEnum.java
+++ 
b/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinTimeSeriesGeneratingFunctionEnum.java
@@ -74,6 +74,7 @@ public enum BuiltinTimeSeriesGeneratingFunctionEnum {
   EQUAL_SIZE_BUCKET_OUTLIER_SAMPLE("EQUAL_SIZE_BUCKET_OUTLIER_SAMPLE"),
   JEXL("JEXL"),
   MASTER_REPAIR("MASTER_REPAIR"),
+  FORECAST("FORECAST"),
   M4("M4");
 
   private final String functionName;
diff --git 
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java
 
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java
index a5884a3dc8d..42fcf1d30d6 100644
--- 
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java
+++ 
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java
@@ -145,13 +145,21 @@ public class AINodeConcurrentInferenceIT {
     }
   }
 
+  // Table-model query: the FORECAST table function takes its input from a subquery ordered by
+  // time. Format arguments: model id, predict length. All named arguments must sit inside the
+  // FORECAST(...) call.
+  String forecastTableFunctionSql =
+      "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, "
+          + "predict_length=>%d)";
+  // Tree-model query through the forecast UDTF. UDTFForecast reads the forecast length from the
+  // OUTPUT_LENGTH attribute (it does not read PREDICT_LENGTH). Format arguments: model id,
+  // output length.
+  // NOTE(review): this is tree-model SQL, but the concurrent*ForecastTest helpers open
+  // TABLE_SQL_DIALECT connections -- confirm the UDTF cases run on a tree-dialect connection.
+  String forecastUDTFSql =
+      "SELECT forecast(s, 'MODEL_ID'='%s', 'OUTPUT_LENGTH'='%d') FROM root.AI";
+
+  /**
+   * Runs the CPU concurrent-inference check for the timer_xl and sundial models, first through the
+   * forecast UDTF SQL template and then through the FORECAST table-function SQL template.
+   */
   @Test
   public void concurrentCPUForecastTest() throws SQLException, 
InterruptedException {
-    concurrentCPUForecastTest("timer_xl");
-    concurrentCPUForecastTest("sundial");
+    concurrentCPUForecastTest("timer_xl", forecastUDTFSql);
+    concurrentCPUForecastTest("sundial", forecastUDTFSql);
+    concurrentCPUForecastTest("timer_xl", forecastTableFunctionSql);
+    concurrentCPUForecastTest("sundial", forecastTableFunctionSql);
   }
 
-  private void concurrentCPUForecastTest(String modelId) throws SQLException, 
InterruptedException {
+  private void concurrentCPUForecastTest(String modelId, String selectSQL)
+      throws SQLException, InterruptedException {
     try (Connection connection = 
EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
         Statement statement = connection.createStatement()) {
       final int threadCnt = 4;
@@ -162,9 +170,7 @@ public class AINodeConcurrentInferenceIT {
       long startTime = System.currentTimeMillis();
       concurrentInference(
           statement,
-          String.format(
-              "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s 
FROM root.AI) ORDER BY time), predict_length=>%d",
-              modelId, predictLength),
+          String.format(selectSQL, modelId, predictLength),
           threadCnt,
           loop,
           predictLength);
@@ -179,11 +185,14 @@ public class AINodeConcurrentInferenceIT {
 
+  /**
+   * Runs the GPU concurrent-inference check for the timer_xl and sundial models, first through the
+   * forecast UDTF SQL template and then through the FORECAST table-function SQL template.
+   */
   @Test
   public void concurrentGPUForecastTest() throws SQLException, 
InterruptedException {
-    concurrentGPUForecastTest("timer_xl");
-    concurrentGPUForecastTest("sundial");
+    concurrentGPUForecastTest("timer_xl", forecastUDTFSql);
+    concurrentGPUForecastTest("sundial", forecastUDTFSql);
+    concurrentGPUForecastTest("timer_xl", forecastTableFunctionSql);
+    concurrentGPUForecastTest("sundial", forecastTableFunctionSql);
   }
 
-  public void concurrentGPUForecastTest(String modelId) throws SQLException, 
InterruptedException {
+  public void concurrentGPUForecastTest(String modelId, String selectSql)
+      throws SQLException, InterruptedException {
     try (Connection connection = 
EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
         Statement statement = connection.createStatement()) {
       final int threadCnt = 10;
@@ -195,9 +204,7 @@ public class AINodeConcurrentInferenceIT {
       long startTime = System.currentTimeMillis();
       concurrentInference(
           statement,
-          String.format(
-              "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s 
FROM root.AI) ORDER BY time), predict_length=>%d",
-              modelId, predictLength),
+          String.format(selectSql, modelId, predictLength),
           threadCnt,
           loop,
           predictLength);
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/AggregationUtil.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/AggregationUtil.java
index cba7c3defa6..7c671e7b274 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/AggregationUtil.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/AggregationUtil.java
@@ -19,7 +19,6 @@
 
 package org.apache.iotdb.db.queryengine.execution.operator;
 
-import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
 import org.apache.iotdb.db.queryengine.execution.aggregation.TreeAggregator;
 import 
org.apache.iotdb.db.queryengine.execution.aggregation.timerangeiterator.ITimeRangeIterator;
 import 
org.apache.iotdb.db.queryengine.execution.aggregation.timerangeiterator.SingleTimeWindowIterator;
@@ -29,6 +28,7 @@ import 
org.apache.iotdb.db.queryengine.execution.operator.window.TimeWindow;
 import org.apache.iotdb.db.queryengine.plan.analyze.TypeProvider;
 import 
org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.AggregationDescriptor;
 import 
org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.GroupByTimeParameter;
+import org.apache.iotdb.db.queryengine.plan.udf.BuiltinAggregationFunction;
 import org.apache.iotdb.db.queryengine.statistics.StatisticsManager;
 
 import org.apache.tsfile.block.column.Column;
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionAnalyzer.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionAnalyzer.java
index a53ee3f7d32..9a847a9ebd1 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionAnalyzer.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionAnalyzer.java
@@ -24,8 +24,6 @@ import org.apache.iotdb.commons.path.MeasurementPath;
 import org.apache.iotdb.commons.path.PartialPath;
 import org.apache.iotdb.commons.path.PathPatternTree;
 import org.apache.iotdb.commons.schema.column.ColumnHeader;
-import org.apache.iotdb.commons.udf.builtin.BuiltinScalarFunction;
-import 
org.apache.iotdb.commons.udf.builtin.BuiltinTimeSeriesGeneratingFunction;
 import org.apache.iotdb.db.exception.sql.SemanticException;
 import org.apache.iotdb.db.queryengine.common.MPPQueryContext;
 import org.apache.iotdb.db.queryengine.common.schematree.ISchemaTree;
@@ -56,6 +54,8 @@ import 
org.apache.iotdb.db.queryengine.plan.expression.visitor.cartesian.ConcatD
 import 
org.apache.iotdb.db.queryengine.plan.expression.visitor.cartesian.ConcatDeviceAndBindSchemaForPredicateVisitor;
 import 
org.apache.iotdb.db.queryengine.plan.expression.visitor.cartesian.ConcatExpressionWithSuffixPathsVisitor;
 import org.apache.iotdb.db.queryengine.plan.statement.component.ResultColumn;
+import org.apache.iotdb.db.queryengine.plan.udf.BuiltinScalarFunction;
+import 
org.apache.iotdb.db.queryengine.plan.udf.BuiltinTimeSeriesGeneratingFunction;
 import org.apache.iotdb.db.utils.constant.SqlConstant;
 
 import org.apache.tsfile.common.constant.TsFileConstant;
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ShowFunctionsTask.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ShowFunctionsTask.java
index ab5c8261de1..608cc0b284e 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ShowFunctionsTask.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ShowFunctionsTask.java
@@ -24,9 +24,6 @@ import org.apache.iotdb.commons.schema.column.ColumnHeader;
 import org.apache.iotdb.commons.schema.column.ColumnHeaderConstant;
 import org.apache.iotdb.commons.udf.UDFInformation;
 import org.apache.iotdb.commons.udf.UDFType;
-import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
-import org.apache.iotdb.commons.udf.builtin.BuiltinScalarFunction;
-import 
org.apache.iotdb.commons.udf.builtin.BuiltinTimeSeriesGeneratingFunction;
 import 
org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinAggregationFunction;
 import 
org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinScalarFunction;
 import org.apache.iotdb.db.queryengine.common.header.DatasetHeader;
@@ -35,6 +32,9 @@ import 
org.apache.iotdb.db.queryengine.plan.execution.config.ConfigTaskResult;
 import org.apache.iotdb.db.queryengine.plan.execution.config.IConfigTask;
 import 
org.apache.iotdb.db.queryengine.plan.execution.config.executor.IConfigTaskExecutor;
 import 
org.apache.iotdb.db.queryengine.plan.relational.function.TableBuiltinTableFunction;
+import org.apache.iotdb.db.queryengine.plan.udf.BuiltinAggregationFunction;
+import org.apache.iotdb.db.queryengine.plan.udf.BuiltinScalarFunction;
+import 
org.apache.iotdb.db.queryengine.plan.udf.BuiltinTimeSeriesGeneratingFunction;
 import org.apache.iotdb.db.queryengine.plan.udf.TableUDFUtils;
 import org.apache.iotdb.db.queryengine.plan.udf.TreeUDFUtils;
 import org.apache.iotdb.rpc.TSStatusCode;
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/expression/ExpressionFactory.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/expression/ExpressionFactory.java
index e2f162d101a..b45a90c5b80 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/expression/ExpressionFactory.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/expression/ExpressionFactory.java
@@ -23,7 +23,6 @@ import org.apache.iotdb.common.rpc.thrift.TAggregationType;
 import org.apache.iotdb.commons.exception.IllegalPathException;
 import org.apache.iotdb.commons.path.MeasurementPath;
 import org.apache.iotdb.commons.path.PartialPath;
-import 
org.apache.iotdb.commons.udf.builtin.BuiltinTimeSeriesGeneratingFunction;
 import 
org.apache.iotdb.db.queryengine.plan.expression.binary.AdditionExpression;
 import 
org.apache.iotdb.db.queryengine.plan.expression.binary.EqualToExpression;
 import 
org.apache.iotdb.db.queryengine.plan.expression.binary.GreaterEqualExpression;
@@ -46,6 +45,7 @@ import 
org.apache.iotdb.db.queryengine.plan.expression.unary.LikeExpression;
 import 
org.apache.iotdb.db.queryengine.plan.expression.unary.LogicNotExpression;
 import org.apache.iotdb.db.queryengine.plan.expression.unary.RegularExpression;
 import 
org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.GroupByTimeParameter;
+import 
org.apache.iotdb.db.queryengine.plan.udf.BuiltinTimeSeriesGeneratingFunction;
 
 import org.apache.tsfile.enums.TSDataType;
 import org.apache.tsfile.utils.TimeDuration;
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/expression/multi/FunctionExpression.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/expression/multi/FunctionExpression.java
index 957856f4093..aef4a199c39 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/expression/multi/FunctionExpression.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/expression/multi/FunctionExpression.java
@@ -21,8 +21,6 @@ package org.apache.iotdb.db.queryengine.plan.expression.multi;
 
 import org.apache.iotdb.commons.conf.IoTDBConstant;
 import org.apache.iotdb.commons.path.PartialPath;
-import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
-import org.apache.iotdb.commons.udf.builtin.BuiltinScalarFunction;
 import org.apache.iotdb.db.queryengine.common.NodeRef;
 import org.apache.iotdb.db.queryengine.execution.MemoryEstimationHelper;
 import org.apache.iotdb.db.queryengine.plan.expression.Expression;
@@ -31,6 +29,8 @@ import 
org.apache.iotdb.db.queryengine.plan.expression.leaf.TimeSeriesOperand;
 import 
org.apache.iotdb.db.queryengine.plan.expression.multi.builtin.BuiltInScalarFunctionHelperFactory;
 import 
org.apache.iotdb.db.queryengine.plan.expression.visitor.ExpressionVisitor;
 import 
org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.InputLocation;
+import org.apache.iotdb.db.queryengine.plan.udf.BuiltinAggregationFunction;
+import org.apache.iotdb.db.queryengine.plan.udf.BuiltinScalarFunction;
 import org.apache.iotdb.db.queryengine.plan.udf.TreeUDFUtils;
 import 
org.apache.iotdb.db.queryengine.transformation.dag.memory.LayerMemoryAssigner;
 import org.apache.iotdb.db.queryengine.transformation.dag.udf.UDTFExecutor;
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/optimization/AggregationPushDown.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/optimization/AggregationPushDown.java
index fb98cedeace..461087768e5 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/optimization/AggregationPushDown.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/optimization/AggregationPushDown.java
@@ -25,7 +25,6 @@ import org.apache.iotdb.commons.path.AlignedPath;
 import org.apache.iotdb.commons.path.MeasurementPath;
 import org.apache.iotdb.commons.path.PartialPath;
 import org.apache.iotdb.commons.schema.column.ColumnHeaderConstant;
-import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
 import org.apache.iotdb.db.queryengine.common.MPPQueryContext;
 import org.apache.iotdb.db.queryengine.execution.MemoryEstimationHelper;
 import org.apache.iotdb.db.queryengine.plan.analyze.Analysis;
@@ -57,6 +56,7 @@ import 
org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.GroupByTimePa
 import org.apache.iotdb.db.queryengine.plan.statement.StatementType;
 import org.apache.iotdb.db.queryengine.plan.statement.component.Ordering;
 import org.apache.iotdb.db.queryengine.plan.statement.crud.QueryStatement;
+import org.apache.iotdb.db.queryengine.plan.udf.BuiltinAggregationFunction;
 import org.apache.iotdb.db.schemaengine.schemaregion.utils.MetaUtils;
 import org.apache.iotdb.db.utils.SchemaUtils;
 
diff --git 
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/BuiltinAggregationFunction.java
similarity index 98%
rename from 
iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java
rename to 
iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/BuiltinAggregationFunction.java
index 1c6b25ef53a..2913ddf86db 100644
--- 
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/BuiltinAggregationFunction.java
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-package org.apache.iotdb.commons.udf.builtin;
+package org.apache.iotdb.db.queryengine.plan.udf;
 
 import java.util.Arrays;
 import java.util.HashSet;
diff --git 
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinScalarFunction.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/BuiltinScalarFunction.java
similarity index 97%
rename from 
iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinScalarFunction.java
rename to 
iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/BuiltinScalarFunction.java
index eaa41f7f397..f47b738cc89 100644
--- 
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinScalarFunction.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/BuiltinScalarFunction.java
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-package org.apache.iotdb.commons.udf.builtin;
+package org.apache.iotdb.db.queryengine.plan.udf;
 
 import com.google.common.collect.ImmutableSet;
 
diff --git 
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinTimeSeriesGeneratingFunction.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/BuiltinTimeSeriesGeneratingFunction.java
similarity index 65%
rename from 
iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinTimeSeriesGeneratingFunction.java
rename to 
iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/BuiltinTimeSeriesGeneratingFunction.java
index 304fdcb040f..67009258bc2 100644
--- 
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinTimeSeriesGeneratingFunction.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/BuiltinTimeSeriesGeneratingFunction.java
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-package org.apache.iotdb.commons.udf.builtin;
+package org.apache.iotdb.db.queryengine.plan.udf;
 
 import org.apache.iotdb.commons.udf.builtin.String.UDTFConcat;
 import org.apache.iotdb.commons.udf.builtin.String.UDTFEndsWith;
@@ -28,6 +28,51 @@ import 
org.apache.iotdb.commons.udf.builtin.String.UDTFStrLength;
 import org.apache.iotdb.commons.udf.builtin.String.UDTFStrLocate;
 import org.apache.iotdb.commons.udf.builtin.String.UDTFTrim;
 import org.apache.iotdb.commons.udf.builtin.String.UDTFUpper;
+import org.apache.iotdb.commons.udf.builtin.UDTFAbs;
+import org.apache.iotdb.commons.udf.builtin.UDTFAcos;
+import org.apache.iotdb.commons.udf.builtin.UDTFAsin;
+import org.apache.iotdb.commons.udf.builtin.UDTFAtan;
+import org.apache.iotdb.commons.udf.builtin.UDTFBottomK;
+import org.apache.iotdb.commons.udf.builtin.UDTFCeil;
+import org.apache.iotdb.commons.udf.builtin.UDTFChangePoints;
+import org.apache.iotdb.commons.udf.builtin.UDTFCommonDerivative;
+import org.apache.iotdb.commons.udf.builtin.UDTFCommonValueDifference;
+import org.apache.iotdb.commons.udf.builtin.UDTFConst;
+import org.apache.iotdb.commons.udf.builtin.UDTFConstE;
+import org.apache.iotdb.commons.udf.builtin.UDTFConstPi;
+import org.apache.iotdb.commons.udf.builtin.UDTFContains;
+import org.apache.iotdb.commons.udf.builtin.UDTFCos;
+import org.apache.iotdb.commons.udf.builtin.UDTFCosh;
+import org.apache.iotdb.commons.udf.builtin.UDTFDegrees;
+import org.apache.iotdb.commons.udf.builtin.UDTFEqualSizeBucketAggSample;
+import org.apache.iotdb.commons.udf.builtin.UDTFEqualSizeBucketM4Sample;
+import org.apache.iotdb.commons.udf.builtin.UDTFEqualSizeBucketOutlierSample;
+import org.apache.iotdb.commons.udf.builtin.UDTFEqualSizeBucketRandomSample;
+import org.apache.iotdb.commons.udf.builtin.UDTFExp;
+import org.apache.iotdb.commons.udf.builtin.UDTFFloor;
+import org.apache.iotdb.commons.udf.builtin.UDTFInRange;
+import org.apache.iotdb.commons.udf.builtin.UDTFJexl;
+import org.apache.iotdb.commons.udf.builtin.UDTFLog;
+import org.apache.iotdb.commons.udf.builtin.UDTFLog10;
+import org.apache.iotdb.commons.udf.builtin.UDTFM4;
+import org.apache.iotdb.commons.udf.builtin.UDTFMasterRepair;
+import org.apache.iotdb.commons.udf.builtin.UDTFMatches;
+import org.apache.iotdb.commons.udf.builtin.UDTFNonNegativeDerivative;
+import org.apache.iotdb.commons.udf.builtin.UDTFNonNegativeValueDifference;
+import org.apache.iotdb.commons.udf.builtin.UDTFNonZeroCount;
+import org.apache.iotdb.commons.udf.builtin.UDTFNonZeroDuration;
+import org.apache.iotdb.commons.udf.builtin.UDTFOnOff;
+import org.apache.iotdb.commons.udf.builtin.UDTFRadians;
+import org.apache.iotdb.commons.udf.builtin.UDTFSign;
+import org.apache.iotdb.commons.udf.builtin.UDTFSin;
+import org.apache.iotdb.commons.udf.builtin.UDTFSinh;
+import org.apache.iotdb.commons.udf.builtin.UDTFSqrt;
+import org.apache.iotdb.commons.udf.builtin.UDTFTan;
+import org.apache.iotdb.commons.udf.builtin.UDTFTanh;
+import org.apache.iotdb.commons.udf.builtin.UDTFTimeDifference;
+import org.apache.iotdb.commons.udf.builtin.UDTFTopK;
+import org.apache.iotdb.commons.udf.builtin.UDTFZeroCount;
+import org.apache.iotdb.commons.udf.builtin.UDTFZeroDuration;
 
 import com.google.common.collect.ImmutableSet;
 
@@ -93,7 +138,9 @@ public enum BuiltinTimeSeriesGeneratingFunction {
       "EQUAL_SIZE_BUCKET_OUTLIER_SAMPLE", 
UDTFEqualSizeBucketOutlierSample.class),
   JEXL("JEXL", UDTFJexl.class),
   MASTER_REPAIR("MASTER_REPAIR", UDTFMasterRepair.class),
-  M4("M4", UDTFM4.class);
+  M4("M4", UDTFM4.class),
+  FORECAST("FORECAST", UDTFForecast.class),
+  ;
 
   private final String functionName;
   private final Class<?> functionClass;
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDFManagementService.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDFManagementService.java
index 9efa02d7579..14acfa529e3 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDFManagementService.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDFManagementService.java
@@ -23,9 +23,6 @@ import org.apache.iotdb.common.rpc.thrift.Model;
 import org.apache.iotdb.commons.udf.UDFInformation;
 import org.apache.iotdb.commons.udf.UDFTable;
 import org.apache.iotdb.commons.udf.UDFType;
-import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
-import org.apache.iotdb.commons.udf.builtin.BuiltinScalarFunction;
-import 
org.apache.iotdb.commons.udf.builtin.BuiltinTimeSeriesGeneratingFunction;
 import 
org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinAggregationFunction;
 import 
org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinScalarFunction;
 import org.apache.iotdb.commons.udf.service.UDFClassLoader;
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
new file mode 100644
index 00000000000..22c2bce7b5e
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.udf;
+
+import org.apache.iotdb.ainode.rpc.thrift.TForecastResp;
+import org.apache.iotdb.common.rpc.thrift.TEndPoint;
+import org.apache.iotdb.commons.client.IClientManager;
+import org.apache.iotdb.commons.client.ainode.AINodeClient;
+import org.apache.iotdb.commons.client.ainode.AINodeClientManager;
+import org.apache.iotdb.commons.exception.IoTDBRuntimeException;
+import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher;
+import org.apache.iotdb.db.queryengine.plan.analyze.ModelFetcher;
+import 
org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor;
+import org.apache.iotdb.rpc.TSStatusCode;
+import org.apache.iotdb.udf.api.UDTF;
+import org.apache.iotdb.udf.api.access.Row;
+import org.apache.iotdb.udf.api.collector.PointCollector;
+import org.apache.iotdb.udf.api.customizer.config.UDTFConfigurations;
+import org.apache.iotdb.udf.api.customizer.parameter.UDFParameters;
+import org.apache.iotdb.udf.api.customizer.strategy.RowByRowAccessStrategy;
+import org.apache.iotdb.udf.api.type.Type;
+
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.read.common.block.TsBlock;
+import org.apache.tsfile.read.common.block.TsBlockBuilder;
+import org.apache.tsfile.read.common.block.column.TsBlockSerde;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+public class UDTFForecast implements UDTF {
+  // Shared TsBlock (de)serializer; presumably used for the AINode request/response payloads.
+  private static final TsBlockSerde serde = new TsBlockSerde();
+  // Shared pool of AINode RPC clients, keyed by AINode endpoint.
+  private static final IClientManager<TEndPoint, AINodeClient> CLIENT_MANAGER =
+      AINodeClientManager.getInstance();
+  // Placeholder endpoint; beforeStart() overwrites it with the AINode returned by the model lookup.
+  private TEndPoint targetAINode = new TEndPoint("127.0.0.1", 10810);
+  // Value of the MODEL_ID attribute.
+  private String model_id;
+  // First dimension of the model's input shape; presumably the maximum number of input rows the
+  // model accepts -- confirm against the inference code.
+  private int maxInputLength;
+  // Forecast settings read from the function attributes in beforeStart().
+  private int outputLength;
+  private long outputStartTime;
+  private long outputInterval;
+  private boolean keepInput;
+  // Key/value pairs parsed from the MODEL_OPTIONS attribute.
+  Map<String, String> options;
+  // Declared data types of the input series, indexed by input column.
+  List<Type> types;
+  // NOTE(review): presumably buffers received input rows; created empty in beforeStart().
+  private LinkedList<Row> inputRows;
+  // Builder whose columns are all DOUBLE, one per input series (created in beforeStart()).
+  private TsBlockBuilder inputTsBlockBuilder;
+  // Resolves a model id to its inference descriptor (target AINode and model information).
+  private final IModelFetcher modelFetcher = ModelFetcher.getInstance();
+
+  /** Numeric input types this function accepts; enforced by {@code checkType()}. */
+  private static final Set<Type> ALLOWED_INPUT_TYPES =
+      new HashSet<>(Arrays.asList(Type.INT32, Type.INT64, Type.FLOAT, Type.DOUBLE));
+
+  /** Required attribute: id of the model used for inference. */
+  private static final String MODEL_ID_PARAMETER_NAME = "MODEL_ID";
+  /** Attribute read into {@code outputLength}. */
+  private static final String OUTPUT_LENGTH_PARAMETER_NAME = "OUTPUT_LENGTH";
+  private static final int DEFAULT_OUTPUT_LENGTH = 96;
+  /** Attribute read into {@code outputStartTime}. */
+  private static final String OUTPUT_START_TIME = "OUTPUT_START_TIME";
+  // NOTE(review): Long.MIN_VALUE presumably means "not specified"; confirm where it is consumed.
+  public static final long DEFAULT_OUTPUT_START_TIME = Long.MIN_VALUE;
+  /** Attribute read into {@code outputInterval}. */
+  private static final String OUTPUT_INTERVAL = "OUTPUT_INTERVAL";
+  // NOTE(review): 0 presumably means "not specified"; confirm where it is consumed.
+  public static final long DEFAULT_OUTPUT_INTERVAL = 0L;
+  /** Attribute read into {@code keepInput}; note the SQL-facing name is PRESERVE_INPUT. */
+  private static final String KEEP_INPUT_PARAMETER_NAME = "PRESERVE_INPUT";
+  // Primitive rather than boxed: it is used as the default of getBooleanOrDefault.
+  private static final boolean DEFAULT_KEEP_INPUT = false;
+  /** Attribute holding comma-separated {@code key=value} model options. */
+  private static final String OPTIONS_PARAMETER_NAME = "MODEL_OPTIONS";
+  private static final String DEFAULT_OPTIONS = "";
+
+  /**
+   * Ensures every input series has one of the numeric types in {@code ALLOWED_INPUT_TYPES}.
+   *
+   * @throws IllegalArgumentException naming the first input type that is not allowed
+   */
+  private void checkType() {
+    final int columnCount = types.size();
+    for (int index = 0; index < columnCount; index++) {
+      final Type candidate = types.get(index);
+      if (ALLOWED_INPUT_TYPES.contains(candidate)) {
+        continue;
+      }
+      throw new IllegalArgumentException(
+          String.format(
+              "Input data type %s is not supported, only %s are allowed.",
+              candidate, ALLOWED_INPUT_TYPES));
+    }
+  }
+
+  /**
+   * Validates the input types, resolves the model and reads the function attributes.
+   *
+   * <p>Attributes: {@code MODEL_ID} (required), {@code OUTPUT_LENGTH} (default 96, must be
+   * positive), {@code OUTPUT_START_TIME}, {@code OUTPUT_INTERVAL}, {@code PRESERVE_INPUT} (default
+   * false) and {@code MODEL_OPTIONS} (comma-separated {@code key=value} pairs).
+   *
+   * @param parameters the UDF call parameters (input data types and attributes)
+   * @param configurations receives the row-by-row access strategy and the DOUBLE output type
+   * @throws IllegalArgumentException if an input type is unsupported, {@code MODEL_ID} is missing
+   *     or empty, or {@code OUTPUT_LENGTH} is not positive
+   * @throws Exception propagated from the model lookup or attribute parsing
+   */
+  @Override
+  public void beforeStart(UDFParameters parameters, UDTFConfigurations configurations)
+      throws Exception {
+    this.types = parameters.getDataTypes();
+    checkType();
+    configurations.setAccessStrategy(new RowByRowAccessStrategy()).setOutputDataType(Type.DOUBLE);
+
+    this.model_id = parameters.getString(MODEL_ID_PARAMETER_NAME);
+    if (this.model_id == null || this.model_id.isEmpty()) {
+      throw new IllegalArgumentException(
+          "MODEL_ID parameter must be provided and cannot be empty.");
+    }
+    ModelInferenceDescriptor descriptor = modelFetcher.fetchModel(this.model_id);
+    this.targetAINode = descriptor.getTargetAINode();
+    // NOTE(review): assumed to be the model's maximum input window (first dim of its input shape).
+    this.maxInputLength = descriptor.getModelInformation().getInputShape()[0];
+
+    this.outputInterval = parameters.getLongOrDefault(OUTPUT_INTERVAL, DEFAULT_OUTPUT_INTERVAL);
+    this.outputLength =
+        parameters.getIntOrDefault(OUTPUT_LENGTH_PARAMETER_NAME, DEFAULT_OUTPUT_LENGTH);
+    if (this.outputLength <= 0) {
+      throw new IllegalArgumentException(
+          String.format(
+              "%s must be a positive integer, but got %d.",
+              OUTPUT_LENGTH_PARAMETER_NAME, this.outputLength));
+    }
+    this.outputStartTime =
+        parameters.getLongOrDefault(OUTPUT_START_TIME, DEFAULT_OUTPUT_START_TIME);
+    this.keepInput = parameters.getBooleanOrDefault(KEEP_INPUT_PARAMETER_NAME, DEFAULT_KEEP_INPUT);
+    this.options =
+        parseModelOptions(parameters.getStringOrDefault(OPTIONS_PARAMETER_NAME, DEFAULT_OPTIONS));
+    this.inputRows = new LinkedList<>();
+    // Every input column is buffered as DOUBLE, regardless of its declared numeric type.
+    List<TSDataType> tsDataTypeList = new ArrayList<>(this.types.size());
+    for (int i = 0; i < this.types.size(); i++) {
+      tsDataTypeList.add(TSDataType.DOUBLE);
+    }
+    this.inputTsBlockBuilder = new TsBlockBuilder(tsDataTypeList);
+  }
+
+  /**
+   * Parses a {@code MODEL_OPTIONS} string of the form {@code "k1=v1,k2=v2"}.
+   *
+   * <p>Keys and values are trimmed. Entries without {@code '='} or with a blank key are skipped.
+   * Only the first {@code '='} separates key from value, and a repeated key keeps its last value.
+   */
+  private static Map<String, String> parseModelOptions(String rawOptions) {
+    return Arrays.stream(rawOptions.split(","))
+        .map(entry -> entry.split("=", 2))
+        .filter(kv -> kv.length == 2 && !kv[0].trim().isEmpty())
+        .collect(
+            Collectors.toMap(
+                kv -> kv[0].trim(),
+                kv -> kv[1].trim(),
+                // On a duplicate key, keep the value that appears later in the string.
+                (earlier, later) -> later));
+  }
+
+  private void setByType(Row row, PointCollector collector) throws IOException 
{
+    for (int i = 0; i < row.size(); i++) {
+      switch (this.types.get(i)) {
+        case INT32:
+          collector.putInt(row.getTime(), row.getInt(i));
+          break;
+        case INT64:
+          collector.putLong(row.getTime(), row.getLong(i));
+          break;
+        case FLOAT:
+          collector.putFloat(row.getTime(), row.getFloat(i));
+          break;
+        case DOUBLE:
+          collector.putDouble(row.getTime(), row.getDouble(i));
+          break;
+        default:
+          throw new IllegalArgumentException(
+              String.format("Unsupported data type %s", this.types.get(i + 
1)));
+      }
+    }
+  }
+
+  private void setByType(Row row, TsBlockBuilder tsBlockBuilder) throws 
IOException {
+    for (int i = 0; i < row.size(); i++) {
+      if (row.isNull(i)) {
+        tsBlockBuilder.getColumnBuilder(i).appendNull();
+        continue;
+      }
+      switch (this.types.get(i)) {
+        case INT32:
+          tsBlockBuilder.getColumnBuilder(i).writeInt(row.getInt(i));
+          break;
+        case INT64:
+          tsBlockBuilder.getColumnBuilder(i).writeLong(row.getLong(i));
+          break;
+        case FLOAT:
+          tsBlockBuilder.getColumnBuilder(i).writeFloat(row.getFloat(i));
+          break;
+        case DOUBLE:
+          tsBlockBuilder.getColumnBuilder(i).writeDouble(row.getDouble(i));
+          break;
+        default:
+          throw new IllegalArgumentException(
+              String.format("Unsupported data type %s", this.types.get(i + 
1)));
+      }
+    }
+  }
+
+  @Override
+  public void transform(Row row, PointCollector collector) throws Exception {
+    if (this.keepInput) {
+      setByType(row, collector);
+    }
+
+    if (maxInputLength != 0 && inputRows.size() >= maxInputLength) {
+      // If the input rows exceed the maximum length, remove the oldest row
+      inputRows.removeFirst();
+    }
+    inputRows.add(row);
+  }
+
+  private TsBlock forecast() throws Exception {
+    // Build the input data which will be sent to AINode
+    while (!inputRows.isEmpty()) {
+      Row row = inputRows.removeFirst();
+      inputTsBlockBuilder.getTimeColumnBuilder().writeLong(row.getTime());
+      setByType(row, inputTsBlockBuilder);
+      inputTsBlockBuilder.declarePosition();
+    }
+
+    TsBlock inputData = inputTsBlockBuilder.build();
+
+    TForecastResp resp;
+    try (AINodeClient client = CLIENT_MANAGER.borrowClient(targetAINode)) {
+      resp = client.forecast(model_id, inputData, outputLength, options);
+    } catch (Exception e) {
+      throw new IoTDBRuntimeException(
+          e.getMessage(), TSStatusCode.CAN_NOT_CONNECT_AINODE.getStatusCode());
+    }
+
+    if (resp.getStatus().getCode() != 
TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+      throw new IoTDBRuntimeException(
+          String.format(
+              "Forecast failed due to %d %s",
+              resp.getStatus().getCode(), resp.getStatus().getMessage()),
+          resp.getStatus().getCode());
+    }
+    return serde.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
+  }
+
+  @Override
+  public void terminate(PointCollector collector) throws Exception {
+    long inputStartTime = inputRows.get(0).getTime();
+    long inputEndTime = inputRows.get(inputRows.size() - 1).getTime();
+    if (inputStartTime > inputEndTime) {
+      throw new IllegalArgumentException(
+          String.format(
+              "input end time should never less than start time, start time is 
%s, end time is %s",
+              inputStartTime, inputEndTime));
+    }
+    long interval = this.outputInterval;
+    if (outputInterval <= 0) {
+      interval = (inputEndTime - inputStartTime) / (inputRows.size() - 1);
+    }
+    long outputTime =
+        (this.outputStartTime == Long.MIN_VALUE) ? inputEndTime + interval : 
this.outputStartTime;
+    long[] outputTimes = new long[this.outputLength];
+    for (int i = 0; i < this.outputLength; i++) {
+      outputTimes[i] = outputTime + interval * i;
+    }
+
+    TsBlock forecastResult = forecast();
+    if (forecastResult.getPositionCount() != this.outputLength) {
+      throw new IllegalArgumentException(
+          String.format(
+              "The forecast result length %d does not match the expected 
output length %d",
+              forecastResult.getPositionCount(), this.outputLength));
+    }
+    if (forecastResult.getValueColumnCount() != 1) {
+      throw new IllegalArgumentException(
+          String.format(
+              "The forecast result should have only one value column, but got 
%d",
+              forecastResult.getValueColumnCount()));
+    }
+
+    for (int i = 0; i < forecastResult.getPositionCount(); i++) {
+      collector.putDouble(outputTimes[i], 
forecastResult.getValueColumns()[0].getDouble(i));
+    }
+  }
+}
diff --git 
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java
 
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java
index ff25549b33d..4822b15be86 100644
--- 
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java
+++ 
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java
@@ -25,7 +25,6 @@ import 
org.apache.iotdb.commons.partition.SchemaNodeManagementPartition;
 import org.apache.iotdb.commons.partition.SchemaPartition;
 import org.apache.iotdb.commons.path.PathPatternTree;
 import org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory;
-import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
 import org.apache.iotdb.db.queryengine.common.MPPQueryContext;
 import org.apache.iotdb.db.queryengine.common.SessionInfo;
 import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher;
@@ -46,6 +45,7 @@ import 
org.apache.iotdb.db.queryengine.plan.relational.type.InternalTypeManager;
 import org.apache.iotdb.db.queryengine.plan.relational.type.TypeManager;
 import 
org.apache.iotdb.db.queryengine.plan.relational.type.TypeNotFoundException;
 import org.apache.iotdb.db.queryengine.plan.relational.type.TypeSignature;
+import org.apache.iotdb.db.queryengine.plan.udf.BuiltinAggregationFunction;
 import org.apache.iotdb.mpp.rpc.thrift.TRegionRouteReq;
 import org.apache.iotdb.udf.api.relational.TableFunction;
 
diff --git 
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java
 
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java
index 369d3607fd3..ccea3fb4ab6 100644
--- 
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java
+++ 
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java
@@ -28,7 +28,6 @@ import org.apache.iotdb.commons.partition.SchemaPartition;
 import org.apache.iotdb.commons.path.PathPatternTree;
 import org.apache.iotdb.commons.schema.table.TsTable;
 import org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory;
-import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
 import org.apache.iotdb.db.exception.sql.SemanticException;
 import org.apache.iotdb.db.queryengine.common.MPPQueryContext;
 import org.apache.iotdb.db.queryengine.common.SessionInfo;
@@ -60,6 +59,7 @@ import 
org.apache.iotdb.db.queryengine.plan.relational.type.InternalTypeManager;
 import org.apache.iotdb.db.queryengine.plan.relational.type.TypeManager;
 import 
org.apache.iotdb.db.queryengine.plan.relational.type.TypeNotFoundException;
 import org.apache.iotdb.db.queryengine.plan.relational.type.TypeSignature;
+import org.apache.iotdb.db.queryengine.plan.udf.BuiltinAggregationFunction;
 import org.apache.iotdb.db.queryengine.plan.udf.TableUDFUtils;
 import org.apache.iotdb.db.schemaengine.table.InformationSchemaUtils;
 import org.apache.iotdb.mpp.rpc.thrift.TRegionRouteReq;


Reply via email to