This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new b6c13d7e3e8 [AINode] Forecast table function version2 (#16922)
b6c13d7e3e8 is described below

commit b6c13d7e3e87581ea43ba82e03a5ede133cdd452
Author: Yongzao <[email protected]>
AuthorDate: Thu Dec 18 09:41:24 2025 +0800

    [AINode] Forecast table function version2 (#16922)
    
    
    ---------
    
    Co-authored-by: Liu Zhengyun <[email protected]>
---
 .../ainode/it/AINodeConcurrentForecastIT.java      |   2 +-
 .../apache/iotdb/ainode/it/AINodeForecastIT.java   | 100 +++++-
 .../ainode/core/model/sundial/modeling_sundial.py  |  13 +-
 .../ainode/core/model/timer_xl/modeling_timer.py   |   6 +-
 .../plan/planner/TableOperatorGenerator.java       |  25 +-
 .../function/TableBuiltinTableFunction.java        |   6 +-
 .../function/tvf/ClassifyTableFunction.java        | 383 +++++++++++++++++++++
 .../function/tvf/ForecastTableFunction.java        | 263 ++++----------
 .../relational/utils/ResultColumnAppender.java     | 145 ++++++++
 .../db/queryengine/plan/udf/UDTFForecast.java      |   2 +
 .../relational/analyzer/TableFunctionTest.java     |   4 +-
 11 files changed, 723 insertions(+), 226 deletions(-)

diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
index 844ec1d8223..fe19f991e57 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
@@ -49,7 +49,7 @@ public class AINodeConcurrentForecastIT {
   private static final Logger LOGGER = 
LoggerFactory.getLogger(AINodeConcurrentForecastIT.class);
 
   private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
-      "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM 
root.AI) ORDER BY time, output_length=>%d)";
+      "SELECT * FROM FORECAST(model_id=>'%s', targets=>(SELECT time,s FROM 
root.AI) ORDER BY time, output_length=>%d)";
 
   @BeforeClass
   public static void setUp() throws Exception {
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
index a06656d4ada..bb0de13ed49 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
@@ -38,13 +38,21 @@ import java.sql.SQLException;
 import java.sql.Statement;
 
 import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
 
 @RunWith(IoTDBTestRunner.class)
 @Category({AIClusterIT.class})
 public class AINodeForecastIT {
 
   private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
-      "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time, s%d FROM 
db.AI) ORDER BY time)";
+      "SELECT * FROM FORECAST("
+          + "model_id=>'%s', "
+          + "targets=>(SELECT time, s%d FROM db.AI WHERE time<%d ORDER BY time 
DESC LIMIT %d) ORDER BY time, "
+          + "output_start_time=>%d, "
+          + "output_length=>%d, "
+          + "output_interval=>%d, "
+          + "timecol=>'%s'"
+          + ")";
 
   @BeforeClass
   public static void setUp() throws Exception {
@@ -55,7 +63,7 @@ public class AINodeForecastIT {
       statement.execute("CREATE DATABASE db");
       statement.execute(
           "CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 
FIELD, s3 INT64 FIELD)");
-      for (int i = 0; i < 2880; i++) {
+      for (int i = 0; i < 5760; i++) {
         statement.execute(
             String.format(
                 "INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
@@ -81,18 +89,100 @@ public class AINodeForecastIT {
 
   public void forecastTableFunctionTest(
       Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws 
SQLException {
-    // Invoke call inference for specified models, there should exist result.
+    // Invoke forecast table function for specified models, there should exist 
result.
     for (int i = 0; i < 4; i++) {
       String forecastTableFunctionSQL =
-          String.format(FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, 
modelInfo.getModelId(), i);
+          String.format(
+              FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
+              modelInfo.getModelId(),
+              i,
+              5760,
+              2880,
+              5760,
+              96,
+              1,
+              "time");
       try (ResultSet resultSet = 
statement.executeQuery(forecastTableFunctionSQL)) {
         int count = 0;
         while (resultSet.next()) {
           count++;
         }
-        // Ensure the call inference return results
+        // Ensure the forecast sentence return results
         Assert.assertTrue(count > 0);
       }
     }
   }
+
+  @Test
+  public void forecastTableFunctionErrorTest() throws SQLException {
+    for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) 
{
+      try (Connection connection = 
EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
+          Statement statement = connection.createStatement()) {
+        forecastTableFunctionErrorTest(statement, modelInfo);
+      }
+    }
+  }
+
+  public void forecastTableFunctionErrorTest(
+      Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws 
SQLException {
+    // OUTPUT_START_TIME error
+    String invalidOutputStartTimeSQL =
+        String.format(
+            FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
+            modelInfo.getModelId(),
+            0,
+            5760,
+            2880,
+            5759,
+            96,
+            1,
+            "time");
+    errorTest(
+        statement,
+        invalidOutputStartTimeSQL,
+        "701: The OUTPUT_START_TIME should be greater than the maximum 
timestamp of target time series. Expected greater than [5759] but found 
[5759].");
+
+    // OUTPUT_LENGTH error
+    String invalidOutputLengthSQL =
+        String.format(
+            FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
+            modelInfo.getModelId(),
+            0,
+            5760,
+            2880,
+            5760,
+            0,
+            1,
+            "time");
+    errorTest(statement, invalidOutputLengthSQL, "701: OUTPUT_LENGTH should be 
greater than 0");
+
+    // OUTPUT_INTERVAL error
+    String invalidOutputIntervalSQL =
+        String.format(
+            FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
+            modelInfo.getModelId(),
+            0,
+            5760,
+            2880,
+            5760,
+            96,
+            -1,
+            "time");
+    errorTest(statement, invalidOutputIntervalSQL, "701: OUTPUT_INTERVAL 
should be greater than 0");
+
+    // TIMECOL error
+    String invalidTimecolSQL2 =
+        String.format(
+            FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
+            modelInfo.getModelId(),
+            0,
+            5760,
+            2880,
+            5760,
+            96,
+            1,
+            "s0");
+    errorTest(
+        statement, invalidTimecolSQL2, "701: The type of the column [s0] is 
not as expected.");
+  }
 }
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py
index dc1de32506e..4e892a6817b 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py
@@ -610,7 +610,11 @@ class SundialForPrediction(SundialPreTrainedModel, TSGenerationMixin):
             if attention_mask is not None and attention_mask.shape[1] > (
                 input_ids.shape[1] // self.config.input_token_len
             ):
-                input_ids = input_ids[:, -(attention_mask.shape[1] - 
past_length) :]
+                input_ids = input_ids[
+                    :,
+                    -(attention_mask.shape[1] - past_length)
+                    * self.config.input_token_len :,
+                ]
             # 2 - If the past_length is smaller than input_ids', then 
input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
             elif past_length < (input_ids.shape[1] // 
self.config.input_token_len):
@@ -623,9 +627,10 @@ class SundialForPrediction(SundialPreTrainedModel, TSGenerationMixin):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[
-                    :, -(input_ids.shape[1] // self.config.input_token_len) :
-                ]
+                token_num = (
+                    input_ids.shape[1] + self.config.input_token_len - 1
+                ) // self.config.input_token_len
+                position_ids = position_ids[:, -token_num:]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st 
generation step
         if inputs_embeds is not None and past_key_values is None:
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py
index fc9d7b41388..1722b27d715 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py
@@ -603,7 +603,11 @@ class TimerForPrediction(TimerPreTrainedModel, TSGenerationMixin):
             if attention_mask is not None and attention_mask.shape[1] > (
                 input_ids.shape[1] // self.config.input_token_len
             ):
-                input_ids = input_ids[:, -(attention_mask.shape[1] - 
past_length) :]
+                input_ids = input_ids[
+                    :,
+                    -(attention_mask.shape[1] - past_length)
+                    * self.config.input_token_len :,
+                ]
             # 2 - If the past_length is smaller than input_ids', then 
input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
             elif past_length < (input_ids.shape[1] // 
self.config.input_token_len):
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
index 6c6ae5cf327..a2bd0dc4815 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
@@ -2013,15 +2013,22 @@ public class TableOperatorGenerator extends PlanVisitor<Operator, LocalExecution
 
   @Override
   public Operator visitGroup(GroupNode node, LocalExecutionPlanContext 
context) {
-    StreamSortNode streamSortNode =
-        new StreamSortNode(
-            node.getPlanNodeId(),
-            node.getChild(),
-            node.getOrderingScheme(),
-            false,
-            false,
-            node.getPartitionKeyCount() - 1);
-    return visitStreamSort(streamSortNode, context);
+    if (node.getPartitionKeyCount() == 0) {
+      SortNode sortNode =
+          new SortNode(
+              node.getPlanNodeId(), node.getChild(), node.getOrderingScheme(), 
false, false);
+      return visitSort(sortNode, context);
+    } else {
+      StreamSortNode streamSortNode =
+          new StreamSortNode(
+              node.getPlanNodeId(),
+              node.getChild(),
+              node.getOrderingScheme(),
+              false,
+              false,
+              node.getPartitionKeyCount() - 1);
+      return visitStreamSort(streamSortNode, context);
+    }
   }
 
   @Override
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java
index a5a57cebadb..61b96809f84 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java
@@ -25,6 +25,7 @@ import org.apache.iotdb.commons.udf.builtin.relational.tvf.HOPTableFunction;
 import 
org.apache.iotdb.commons.udf.builtin.relational.tvf.SessionTableFunction;
 import org.apache.iotdb.commons.udf.builtin.relational.tvf.TumbleTableFunction;
 import 
org.apache.iotdb.commons.udf.builtin.relational.tvf.VariationTableFunction;
+import 
org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ClassifyTableFunction;
 import 
org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction;
 import 
org.apache.iotdb.db.queryengine.plan.relational.function.tvf.PatternMatchTableFunction;
 import org.apache.iotdb.udf.api.relational.TableFunction;
@@ -42,7 +43,8 @@ public enum TableBuiltinTableFunction {
   VARIATION("variation"),
   CAPACITY("capacity"),
   FORECAST("forecast"),
-  PATTERN_MATCH("pattern_match");
+  PATTERN_MATCH("pattern_match"),
+  CLASSIFY("classify");
 
   private final String functionName;
 
@@ -86,6 +88,8 @@ public enum TableBuiltinTableFunction {
         return new CapacityTableFunction();
       case "forecast":
         return new ForecastTableFunction();
+      case "classify":
+        return new ClassifyTableFunction();
       default:
         throw new UnsupportedOperationException("Unsupported table function: " 
+ functionName);
     }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java
new file mode 100644
index 00000000000..670e019a4b6
--- /dev/null
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java
@@ -0,0 +1,383 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.function.tvf;
+
+import org.apache.iotdb.ainode.rpc.thrift.TForecastReq;
+import org.apache.iotdb.ainode.rpc.thrift.TForecastResp;
+import org.apache.iotdb.commons.client.IClientManager;
+import org.apache.iotdb.commons.exception.IoTDBRuntimeException;
+import org.apache.iotdb.db.exception.sql.SemanticException;
+import org.apache.iotdb.db.protocol.client.an.AINodeClient;
+import org.apache.iotdb.db.protocol.client.an.AINodeClientManager;
+import 
org.apache.iotdb.db.queryengine.plan.relational.utils.ResultColumnAppender;
+import org.apache.iotdb.rpc.TSStatusCode;
+import org.apache.iotdb.udf.api.exception.UDFException;
+import org.apache.iotdb.udf.api.relational.TableFunction;
+import org.apache.iotdb.udf.api.relational.access.Record;
+import org.apache.iotdb.udf.api.relational.table.TableFunctionAnalysis;
+import org.apache.iotdb.udf.api.relational.table.TableFunctionHandle;
+import 
org.apache.iotdb.udf.api.relational.table.TableFunctionProcessorProvider;
+import org.apache.iotdb.udf.api.relational.table.argument.Argument;
+import org.apache.iotdb.udf.api.relational.table.argument.DescribedSchema;
+import org.apache.iotdb.udf.api.relational.table.argument.ScalarArgument;
+import org.apache.iotdb.udf.api.relational.table.argument.TableArgument;
+import 
org.apache.iotdb.udf.api.relational.table.processor.TableFunctionDataProcessor;
+import 
org.apache.iotdb.udf.api.relational.table.specification.ParameterSpecification;
+import 
org.apache.iotdb.udf.api.relational.table.specification.ScalarParameterSpecification;
+import 
org.apache.iotdb.udf.api.relational.table.specification.TableParameterSpecification;
+import org.apache.iotdb.udf.api.type.Type;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.read.common.block.TsBlock;
+import org.apache.tsfile.read.common.block.TsBlockBuilder;
+import org.apache.tsfile.read.common.block.column.TsBlockSerde;
+import org.apache.tsfile.utils.PublicBAOS;
+import org.apache.tsfile.utils.ReadWriteIOUtils;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static 
org.apache.iotdb.commons.udf.builtin.relational.tvf.WindowTVFUtils.findColumnIndex;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.utils.ResultColumnAppender.createResultColumnAppender;
+import static org.apache.iotdb.rpc.TSStatusCode.CAN_NOT_CONNECT_AINODE;
+
+public class ClassifyTableFunction implements TableFunction {
+
+  public static class ClassifyTableFunctionHandle implements 
TableFunctionHandle {
+    String modelId;
+    int maxInputLength;
+    List<Type> inputColumnTypes;
+
+    public ClassifyTableFunctionHandle() {}
+
+    public ClassifyTableFunctionHandle(
+        String modelId, int maxInputLength, List<Type> inputColumnTypes) {
+      this.modelId = modelId;
+      this.maxInputLength = maxInputLength;
+      this.inputColumnTypes = inputColumnTypes;
+    }
+
+    @Override
+    public byte[] serialize() {
+      try (PublicBAOS publicBAOS = new PublicBAOS();
+          DataOutputStream outputStream = new DataOutputStream(publicBAOS)) {
+        ReadWriteIOUtils.write(modelId, outputStream);
+        ReadWriteIOUtils.write(maxInputLength, outputStream);
+        ReadWriteIOUtils.write(inputColumnTypes.size(), outputStream);
+        for (Type type : inputColumnTypes) {
+          ReadWriteIOUtils.write(type.getType(), outputStream);
+        }
+        outputStream.flush();
+        return publicBAOS.toByteArray();
+      } catch (IOException e) {
+        throw new IoTDBRuntimeException(
+            String.format(
+                "Error occurred while serializing ForecastTableFunctionHandle: 
%s", e.getMessage()),
+            TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
+      }
+    }
+
+    @Override
+    public void deserialize(byte[] bytes) {
+      ByteBuffer buffer = ByteBuffer.wrap(bytes);
+      this.modelId = ReadWriteIOUtils.readString(buffer);
+      this.maxInputLength = ReadWriteIOUtils.readInt(buffer);
+      int size = ReadWriteIOUtils.readInt(buffer);
+      this.inputColumnTypes = new ArrayList<>(size);
+      for (int i = 0; i < size; i++) {
+        
inputColumnTypes.add(Type.valueOf(ReadWriteIOUtils.readString(buffer)));
+      }
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      ClassifyTableFunctionHandle that = (ClassifyTableFunctionHandle) o;
+      return maxInputLength == that.maxInputLength
+          && Objects.equals(modelId, that.modelId)
+          && Objects.equals(inputColumnTypes, that.inputColumnTypes);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(modelId, maxInputLength, inputColumnTypes);
+    }
+  }
+
+  private static final String INPUT_PARAMETER_NAME = "INPUT";
+  private static final String MODEL_ID_PARAMETER_NAME = "MODEL_ID";
+  public static final String TIMECOL_PARAMETER_NAME = "TIMECOL";
+  private static final String DEFAULT_TIME_COL = "time";
+  private static final String DEFAULT_OUTPUT_COLUMN_NAME = "category";
+  private static final int MAX_INPUT_LENGTH = 2880;
+
+  private static final Set<Type> ALLOWED_INPUT_TYPES = new HashSet<>();
+
+  static {
+    ALLOWED_INPUT_TYPES.add(Type.INT32);
+    ALLOWED_INPUT_TYPES.add(Type.INT64);
+    ALLOWED_INPUT_TYPES.add(Type.FLOAT);
+    ALLOWED_INPUT_TYPES.add(Type.DOUBLE);
+  }
+
+  @Override
+  public List<ParameterSpecification> getArgumentsSpecifications() {
+    return Arrays.asList(
+        
TableParameterSpecification.builder().name(INPUT_PARAMETER_NAME).setSemantics().build(),
+        ScalarParameterSpecification.builder()
+            .name(MODEL_ID_PARAMETER_NAME)
+            .type(Type.STRING)
+            .build(),
+        ScalarParameterSpecification.builder()
+            .name(TIMECOL_PARAMETER_NAME)
+            .type(Type.STRING)
+            .defaultValue(DEFAULT_TIME_COL)
+            .build());
+  }
+
+  @Override
+  public TableFunctionAnalysis analyze(Map<String, Argument> arguments) throws 
UDFException {
+    TableArgument input = (TableArgument) arguments.get(INPUT_PARAMETER_NAME);
+    String modelId = (String) ((ScalarArgument) 
arguments.get(MODEL_ID_PARAMETER_NAME)).getValue();
+    // modelId should never be null or empty
+    if (modelId == null || modelId.isEmpty()) {
+      throw new SemanticException(
+          String.format("%s should never be null or empty", 
MODEL_ID_PARAMETER_NAME));
+    }
+
+    String timeColumn =
+        ((String) ((ScalarArgument) 
arguments.get(TIMECOL_PARAMETER_NAME)).getValue())
+            .toLowerCase(Locale.ENGLISH);
+
+    if (timeColumn.isEmpty()) {
+      throw new SemanticException(
+          String.format("%s should never be null or empty.", 
TIMECOL_PARAMETER_NAME));
+    }
+
+    // predicated columns should never contain partition by columns and time 
column
+    Set<String> excludedColumns =
+        input.getPartitionBy().stream()
+            .map(s -> s.toLowerCase(Locale.ENGLISH))
+            .collect(Collectors.toSet());
+    excludedColumns.add(timeColumn);
+    int timeColumnIndex = findColumnIndex(input, timeColumn, 
Collections.singleton(Type.TIMESTAMP));
+
+    // List of required column indexes
+    List<Integer> requiredIndexList = new ArrayList<>();
+    requiredIndexList.add(timeColumnIndex);
+    DescribedSchema.Builder properColumnSchemaBuilder =
+        new DescribedSchema.Builder().addField(timeColumn, Type.TIMESTAMP);
+
+    List<Type> inputColumnTypes = new ArrayList<>();
+    List<Optional<String>> allInputColumnsName = input.getFieldNames();
+    List<Type> allInputColumnsType = input.getFieldTypes();
+
+    for (int i = 0, size = allInputColumnsName.size(); i < size; i++) {
+      Optional<String> fieldName = allInputColumnsName.get(i);
+      // All input value columns are required for model forecasting
+      if (!fieldName.isPresent()
+          || 
!excludedColumns.contains(fieldName.get().toLowerCase(Locale.ENGLISH))) {
+        Type columnType = allInputColumnsType.get(i);
+        checkType(columnType, fieldName.orElse(""));
+        inputColumnTypes.add(columnType);
+        requiredIndexList.add(i);
+      }
+    }
+    properColumnSchemaBuilder.addField(DEFAULT_OUTPUT_COLUMN_NAME, Type.INT32);
+
+    ClassifyTableFunctionHandle functionHandle =
+        new ClassifyTableFunctionHandle(modelId, MAX_INPUT_LENGTH, 
inputColumnTypes);
+
+    // outputColumnSchema
+    return TableFunctionAnalysis.builder()
+        .properColumnSchema(properColumnSchemaBuilder.build())
+        .handle(functionHandle)
+        .requiredColumns(INPUT_PARAMETER_NAME, requiredIndexList)
+        .build();
+  }
+
+  // only allow for INT32, INT64, FLOAT, DOUBLE
+  private void checkType(Type type, String columnName) {
+    if (!ALLOWED_INPUT_TYPES.contains(type)) {
+      throw new SemanticException(
+          String.format(
+              "The type of the column [%s] is [%s], only INT32, INT64, FLOAT, 
DOUBLE is allowed",
+              columnName, type));
+    }
+  }
+
+  @Override
+  public TableFunctionHandle createTableFunctionHandle() {
+    return new ClassifyTableFunctionHandle();
+  }
+
+  @Override
+  public TableFunctionProcessorProvider getProcessorProvider(
+      TableFunctionHandle tableFunctionHandle) {
+    return new TableFunctionProcessorProvider() {
+      @Override
+      public TableFunctionDataProcessor getDataProcessor() {
+        return new ClassifyDataProcessor((ClassifyTableFunctionHandle) 
tableFunctionHandle);
+      }
+    };
+  }
+
+  private static class ClassifyDataProcessor implements 
TableFunctionDataProcessor {
+
+    private static final TsBlockSerde SERDE = new TsBlockSerde();
+    private static final IClientManager<Integer, AINodeClient> CLIENT_MANAGER =
+        AINodeClientManager.getInstance();
+
+    private final String modelId;
+    private final int maxInputLength;
+    private final LinkedList<Record> inputRecords;
+    private final TsBlockBuilder inputTsBlockBuilder;
+    private final List<ResultColumnAppender> inputColumnAppenderList;
+    private final List<ResultColumnAppender> resultColumnAppenderList;
+
+    public ClassifyDataProcessor(ClassifyTableFunctionHandle functionHandle) {
+      this.modelId = functionHandle.modelId;
+      this.maxInputLength = functionHandle.maxInputLength;
+      this.inputRecords = new LinkedList<>();
+      List<TSDataType> inputTsDataTypeList =
+          new ArrayList<>(functionHandle.inputColumnTypes.size());
+      this.inputColumnAppenderList = new 
ArrayList<>(functionHandle.inputColumnTypes.size());
+      for (Type type : functionHandle.inputColumnTypes) {
+        // AINode currently only accept double input
+        inputTsDataTypeList.add(TSDataType.DOUBLE);
+        inputColumnAppenderList.add(createResultColumnAppender(Type.DOUBLE));
+      }
+      this.inputTsBlockBuilder = new TsBlockBuilder(inputTsDataTypeList);
+      this.resultColumnAppenderList = new ArrayList<>(1);
+      
this.resultColumnAppenderList.add(createResultColumnAppender(Type.INT32));
+    }
+
+    @Override
+    public void process(
+        Record input,
+        List<ColumnBuilder> properColumnBuilders,
+        ColumnBuilder passThroughIndexBuilder) {
+      // only keep at most maxInputLength rows
+      if (maxInputLength != 0 && inputRecords.size() == maxInputLength) {
+        inputRecords.removeFirst();
+      }
+      inputRecords.add(input);
+    }
+
+    @Override
+    public void finish(
+        List<ColumnBuilder> properColumnBuilders, ColumnBuilder 
passThroughIndexBuilder) {
+
+      // time column
+      long inputStartTime = inputRecords.getFirst().getLong(0);
+      long inputEndTime = inputRecords.getLast().getLong(0);
+      if (inputEndTime < inputStartTime) {
+        throw new SemanticException(
+            String.format(
+                "input end time should never less than start time, start time 
is %s, end time is %s",
+                inputStartTime, inputEndTime));
+      }
+      int outputLength = inputRecords.size();
+      for (Record inputRecord : inputRecords) {
+        properColumnBuilders.get(0).writeLong(inputRecord.getLong(0));
+      }
+
+      // predicated columns
+      TsBlock predicatedResult = classify();
+      if (predicatedResult.getPositionCount() != outputLength) {
+        throw new IoTDBRuntimeException(
+            String.format(
+                "Model %s output length is %s, doesn't equal to specified %s",
+                modelId, predicatedResult.getPositionCount(), outputLength),
+            TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
+      }
+
+      for (int columnIndex = 1, size = predicatedResult.getValueColumnCount();
+          columnIndex <= size;
+          columnIndex++) {
+        Column column = predicatedResult.getColumn(columnIndex - 1);
+        ColumnBuilder builder = properColumnBuilders.get(columnIndex);
+        ResultColumnAppender appender = 
resultColumnAppenderList.get(columnIndex - 1);
+        for (int row = 0; row < outputLength; row++) {
+          if (column.isNull(row)) {
+            builder.appendNull();
+          } else {
+            // convert double to real type
+            appender.writeDouble(column.getDouble(row), builder);
+          }
+        }
+      }
+    }
+
+    private TsBlock classify() {
+      int outputLength = inputRecords.size();
+      // construct inputTSBlock for AINode
+      while (!inputRecords.isEmpty()) {
+        Record row = inputRecords.removeFirst();
+        inputTsBlockBuilder.getTimeColumnBuilder().writeLong(row.getLong(0));
+        for (int i = 1, size = row.size(); i < size; i++) {
+          // we set null input to 0.0
+          if (row.isNull(i)) {
+            inputTsBlockBuilder.getColumnBuilder(i - 1).writeDouble(0.0);
+          } else {
+            // need to transform other types to DOUBLE
+            inputTsBlockBuilder
+                .getColumnBuilder(i - 1)
+                .writeDouble(inputColumnAppenderList.get(i - 1).getDouble(row, 
i));
+          }
+        }
+        inputTsBlockBuilder.declarePosition();
+      }
+      TsBlock inputData = inputTsBlockBuilder.build();
+
+      TForecastResp resp;
+      try (AINodeClient client =
+          
CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) {
+        resp = client.forecast(new TForecastReq(modelId, 
SERDE.serialize(inputData), outputLength));
+      } catch (Exception e) {
+        throw new IoTDBRuntimeException(e.getMessage(), 
CAN_NOT_CONNECT_AINODE.getStatusCode());
+      }
+
+      if (resp.getStatus().getCode() != 
TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+        String message =
+            String.format(
+                "Error occurred while executing classify:[%s]", 
resp.getStatus().getMessage());
+        throw new IoTDBRuntimeException(message, resp.getStatus().getCode());
+      }
+      return SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
+    }
+  }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
index 08f7ec6c833..950d8c464e8 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
@@ -26,7 +26,7 @@ import org.apache.iotdb.commons.exception.IoTDBRuntimeException;
 import org.apache.iotdb.db.exception.sql.SemanticException;
 import org.apache.iotdb.db.protocol.client.an.AINodeClient;
 import org.apache.iotdb.db.protocol.client.an.AINodeClientManager;
-import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher;
+import 
org.apache.iotdb.db.queryengine.plan.relational.utils.ResultColumnAppender;
 import org.apache.iotdb.rpc.TSStatusCode;
 import org.apache.iotdb.udf.api.relational.TableFunction;
 import org.apache.iotdb.udf.api.relational.access.Record;
@@ -58,6 +58,7 @@ import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedList;
@@ -70,12 +71,11 @@ import java.util.Set;
 import java.util.stream.Collectors;
 
 import static 
org.apache.iotdb.commons.udf.builtin.relational.tvf.WindowTVFUtils.findColumnIndex;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.utils.ResultColumnAppender.createResultColumnAppender;
 import static org.apache.iotdb.rpc.TSStatusCode.CAN_NOT_CONNECT_AINODE;
 
 public class ForecastTableFunction implements TableFunction {
 
-  private static final TsBlockSerde SERDE = new TsBlockSerde();
-
   public static class ForecastTableFunctionHandle implements 
TableFunctionHandle {
     String modelId;
     int maxInputLength;
@@ -84,7 +84,7 @@ public class ForecastTableFunction implements TableFunction {
     long outputInterval;
     boolean keepInput;
     Map<String, String> options;
-    List<Type> types;
+    List<Type> targetColumntypes;
 
     public ForecastTableFunctionHandle() {}
 
@@ -96,7 +96,7 @@ public class ForecastTableFunction implements TableFunction {
         int outputLength,
         long outputStartTime,
         long outputInterval,
-        List<Type> types) {
+        List<Type> targetColumntypes) {
       this.keepInput = keepInput;
       this.maxInputLength = maxInputLength;
       this.modelId = modelId;
@@ -104,7 +104,7 @@ public class ForecastTableFunction implements TableFunction 
{
       this.outputLength = outputLength;
       this.outputStartTime = outputStartTime;
       this.outputInterval = outputInterval;
-      this.types = types;
+      this.targetColumntypes = targetColumntypes;
     }
 
     @Override
@@ -118,8 +118,8 @@ public class ForecastTableFunction implements TableFunction 
{
         ReadWriteIOUtils.write(outputInterval, outputStream);
         ReadWriteIOUtils.write(keepInput, outputStream);
         ReadWriteIOUtils.write(options, outputStream);
-        ReadWriteIOUtils.write(types.size(), outputStream);
-        for (Type type : types) {
+        ReadWriteIOUtils.write(targetColumntypes.size(), outputStream);
+        for (Type type : targetColumntypes) {
           ReadWriteIOUtils.write(type.getType(), outputStream);
         }
         outputStream.flush();
@@ -143,9 +143,9 @@ public class ForecastTableFunction implements TableFunction 
{
       this.keepInput = ReadWriteIOUtils.readBoolean(buffer);
       this.options = ReadWriteIOUtils.readMap(buffer);
       int size = ReadWriteIOUtils.readInt(buffer);
-      this.types = new ArrayList<>(size);
+      this.targetColumntypes = new ArrayList<>(size);
       for (int i = 0; i < size; i++) {
-        types.add(Type.valueOf(ReadWriteIOUtils.readByte(buffer)));
+        targetColumntypes.add(Type.valueOf(ReadWriteIOUtils.readByte(buffer)));
       }
     }
 
@@ -165,7 +165,7 @@ public class ForecastTableFunction implements TableFunction 
{
           && keepInput == that.keepInput
           && Objects.equals(modelId, that.modelId)
           && Objects.equals(options, that.options)
-          && Objects.equals(types, that.types);
+          && Objects.equals(targetColumntypes, that.targetColumntypes);
     }
 
     @Override
@@ -178,16 +178,14 @@ public class ForecastTableFunction implements 
TableFunction {
           outputInterval,
           keepInput,
           options,
-          types);
+          targetColumntypes);
     }
   }
 
-  private static final String INPUT_PARAMETER_NAME = "INPUT";
+  private static final String TARGETS_PARAMETER_NAME = "TARGETS";
   private static final String MODEL_ID_PARAMETER_NAME = "MODEL_ID";
   private static final String OUTPUT_LENGTH_PARAMETER_NAME = "OUTPUT_LENGTH";
   private static final int DEFAULT_OUTPUT_LENGTH = 96;
-  private static final String PREDICATED_COLUMNS_PARAMETER_NAME = 
"PREDICATED_COLUMNS";
-  private static final String DEFAULT_PREDICATED_COLUMNS = "";
   private static final String OUTPUT_START_TIME = "OUTPUT_START_TIME";
   public static final long DEFAULT_OUTPUT_START_TIME = Long.MIN_VALUE;
   private static final String OUTPUT_INTERVAL = "OUTPUT_INTERVAL";
@@ -212,20 +210,10 @@ public class ForecastTableFunction implements 
TableFunction {
     ALLOWED_INPUT_TYPES.add(Type.DOUBLE);
   }
 
-  // need to set before analyze method is called
-  // should only be used in fe scope, never be used in 
TableFunctionProcessorProvider
-  // The reason we don't directly set modelFetcher=ModelFetcher.getInstance() 
is that we need to
-  // mock IModelFetcher in UT
-  private IModelFetcher modelFetcher = null;
-
-  public void setModelFetcher(IModelFetcher modelFetcher) {
-    this.modelFetcher = modelFetcher;
-  }
-
   @Override
   public List<ParameterSpecification> getArgumentsSpecifications() {
     return Arrays.asList(
-        
TableParameterSpecification.builder().name(INPUT_PARAMETER_NAME).setSemantics().build(),
+        
TableParameterSpecification.builder().name(TARGETS_PARAMETER_NAME).setSemantics().build(),
         ScalarParameterSpecification.builder()
             .name(MODEL_ID_PARAMETER_NAME)
             .type(Type.STRING)
@@ -245,11 +233,6 @@ public class ForecastTableFunction implements 
TableFunction {
             .type(Type.INT64)
             .defaultValue(DEFAULT_OUTPUT_INTERVAL)
             .build(),
-        ScalarParameterSpecification.builder()
-            .name(PREDICATED_COLUMNS_PARAMETER_NAME)
-            .type(Type.STRING)
-            .defaultValue(DEFAULT_PREDICATED_COLUMNS)
-            .build(),
         ScalarParameterSpecification.builder()
             .name(TIMECOL_PARAMETER_NAME)
             .type(Type.STRING)
@@ -269,7 +252,7 @@ public class ForecastTableFunction implements TableFunction 
{
 
   @Override
   public TableFunctionAnalysis analyze(Map<String, Argument> arguments) {
-    TableArgument input = (TableArgument) arguments.get(INPUT_PARAMETER_NAME);
+    TableArgument targets = (TableArgument) 
arguments.get(TARGETS_PARAMETER_NAME);
     String modelId = (String) ((ScalarArgument) 
arguments.get(MODEL_ID_PARAMETER_NAME)).getValue();
     // modelId should never be null or empty
     if (modelId == null || modelId.isEmpty()) {
@@ -284,82 +267,58 @@ public class ForecastTableFunction implements 
TableFunction {
           String.format("%s should be greater than 0", 
OUTPUT_LENGTH_PARAMETER_NAME));
     }
 
-    String predicatedColumns =
-        (String) ((ScalarArgument) 
arguments.get(PREDICATED_COLUMNS_PARAMETER_NAME)).getValue();
-
     String timeColumn =
         ((String) ((ScalarArgument) 
arguments.get(TIMECOL_PARAMETER_NAME)).getValue())
             .toLowerCase(Locale.ENGLISH);
-
     if (timeColumn.isEmpty()) {
       throw new SemanticException(
           String.format("%s should never be null or empty.", 
TIMECOL_PARAMETER_NAME));
     }
 
+    long outputInterval = (long) ((ScalarArgument) 
arguments.get(OUTPUT_INTERVAL)).getValue();
+    if (outputInterval < 0) {
+      throw new SemanticException(String.format("%s should be greater than 0", 
OUTPUT_INTERVAL));
+    }
+
     // predicated columns should never contain partition by columns and time 
column
     Set<String> excludedColumns =
-        input.getPartitionBy().stream()
+        targets.getPartitionBy().stream()
             .map(s -> s.toLowerCase(Locale.ENGLISH))
             .collect(Collectors.toSet());
     excludedColumns.add(timeColumn);
-    int timeColumnIndex = findColumnIndex(input, timeColumn, 
Collections.singleton(Type.TIMESTAMP));
+    int timeColumnIndex =
+        findColumnIndex(targets, timeColumn, 
Collections.singleton(Type.TIMESTAMP));
 
+    // List of required column indexes
     List<Integer> requiredIndexList = new ArrayList<>();
     requiredIndexList.add(timeColumnIndex);
     DescribedSchema.Builder properColumnSchemaBuilder =
         new DescribedSchema.Builder().addField(timeColumn, Type.TIMESTAMP);
 
-    List<Type> predicatedColumnTypes = new ArrayList<>();
-    List<Optional<String>> allInputColumnsName = input.getFieldNames();
-    List<Type> allInputColumnsType = input.getFieldTypes();
-    if (predicatedColumns.isEmpty()) {
-      // predicated columns by default include all columns from input table 
except for timecol and
-      // partition by columns
-      for (int i = 0, size = allInputColumnsName.size(); i < size; i++) {
-        Optional<String> fieldName = allInputColumnsName.get(i);
-        if (!fieldName.isPresent()
-            || 
!excludedColumns.contains(fieldName.get().toLowerCase(Locale.ENGLISH))) {
-          Type columnType = allInputColumnsType.get(i);
-          predicatedColumnTypes.add(columnType);
-          checkType(columnType, fieldName.orElse(""));
-          requiredIndexList.add(i);
-          properColumnSchemaBuilder.addField(fieldName, columnType);
-        }
-      }
-    } else {
-      String[] predictedColumnsArray = predicatedColumns.split(";");
-      Map<String, Integer> inputColumnIndexMap = new HashMap<>();
-      for (int i = 0, size = allInputColumnsName.size(); i < size; i++) {
-        Optional<String> fieldName = allInputColumnsName.get(i);
-        if (!fieldName.isPresent()) {
-          continue;
-        }
-        inputColumnIndexMap.put(fieldName.get().toLowerCase(Locale.ENGLISH), 
i);
+    List<Type> targetColumnTypes = new ArrayList<>();
+    List<Optional<String>> allInputColumnsName = targets.getFieldNames();
+    List<Type> allInputColumnsType = targets.getFieldTypes();
+
+    // target columns = all input columns except the time column and partition-by columns
+    for (int i = 0, size = allInputColumnsName.size(); i < size; i++) {
+      Optional<String> fieldName = allInputColumnsName.get(i);
+      if (!fieldName.isPresent()
+          || 
excludedColumns.contains(fieldName.get().toLowerCase(Locale.ENGLISH))) {
+        continue;
       }
 
-      Set<Integer> requiredIndexSet = new 
HashSet<>(predictedColumnsArray.length);
-      // columns need to be predicated
-      for (String outputColumn : predictedColumnsArray) {
-        String lowerCaseOutputColumn = 
outputColumn.toLowerCase(Locale.ENGLISH);
-        if (excludedColumns.contains(lowerCaseOutputColumn)) {
-          throw new SemanticException(
-              String.format("%s is in partition by clause or is time column", 
outputColumn));
-        }
-        Integer inputColumnIndex = 
inputColumnIndexMap.get(lowerCaseOutputColumn);
-        if (inputColumnIndex == null) {
-          throw new SemanticException(
-              String.format("Column %s don't exist in input", outputColumn));
-        }
-        if (!requiredIndexSet.add(inputColumnIndex)) {
-          throw new SemanticException(String.format("Duplicate column %s", 
outputColumn));
-        }
+      Type columnType = allInputColumnsType.get(i);
+      targetColumnTypes.add(columnType);
+      checkType(columnType, fieldName.get());
+      requiredIndexList.add(i);
+      properColumnSchemaBuilder.addField(fieldName.get(), columnType);
+    }
 
-        Type columnType = allInputColumnsType.get(inputColumnIndex);
-        predicatedColumnTypes.add(columnType);
-        checkType(columnType, outputColumn);
-        requiredIndexList.add(inputColumnIndex);
-        properColumnSchemaBuilder.addField(outputColumn, columnType);
-      }
+    if (targetColumnTypes.size() > 1) {
+      throw new SemanticException(
+          String.format(
+              "%s should not contain more than one target column, found [%s] 
target columns.",
+              TARGETS_PARAMETER_NAME, targetColumnTypes.size()));
     }
 
     boolean keepInput =
@@ -369,7 +328,6 @@ public class ForecastTableFunction implements TableFunction 
{
     }
 
     long outputStartTime = (long) ((ScalarArgument) 
arguments.get(OUTPUT_START_TIME)).getValue();
-    long outputInterval = (long) ((ScalarArgument) 
arguments.get(OUTPUT_INTERVAL)).getValue();
     String options = (String) ((ScalarArgument) 
arguments.get(OPTIONS_PARAMETER_NAME)).getValue();
 
     ForecastTableFunctionHandle functionHandle =
@@ -381,13 +339,13 @@ public class ForecastTableFunction implements 
TableFunction {
             outputLength,
             outputStartTime,
             outputInterval,
-            predicatedColumnTypes);
+            targetColumnTypes);
 
     // outputColumnSchema
     return TableFunctionAnalysis.builder()
         .properColumnSchema(properColumnSchemaBuilder.build())
         .handle(functionHandle)
-        .requiredColumns(INPUT_PARAMETER_NAME, requiredIndexList)
+        .requiredColumns(TARGETS_PARAMETER_NAME, requiredIndexList)
         .build();
   }
 
@@ -465,9 +423,9 @@ public class ForecastTableFunction implements TableFunction 
{
       this.keepInput = functionHandle.keepInput;
       this.options = functionHandle.options;
       this.inputRecords = new LinkedList<>();
-      this.resultColumnAppenderList = new 
ArrayList<>(functionHandle.types.size());
-      List<TSDataType> tsDataTypeList = new 
ArrayList<>(functionHandle.types.size());
-      for (Type type : functionHandle.types) {
+      this.resultColumnAppenderList = new 
ArrayList<>(functionHandle.targetColumntypes.size());
+      List<TSDataType> tsDataTypeList = new 
ArrayList<>(functionHandle.targetColumntypes.size());
+      for (Type type : functionHandle.targetColumntypes) {
         resultColumnAppenderList.add(createResultColumnAppender(type));
         // ainode currently only accept double input
         tsDataTypeList.add(TSDataType.DOUBLE);
@@ -475,21 +433,6 @@ public class ForecastTableFunction implements 
TableFunction {
       this.inputTsBlockBuilder = new TsBlockBuilder(tsDataTypeList);
     }
 
-    private static ResultColumnAppender createResultColumnAppender(Type type) {
-      switch (type) {
-        case INT32:
-          return new Int32Appender();
-        case INT64:
-          return new Int64Appender();
-        case FLOAT:
-          return new FloatAppender();
-        case DOUBLE:
-          return new DoubleAppender();
-        default:
-          throw new IllegalArgumentException("Unsupported column type: " + 
type);
-      }
-    }
-
     @Override
     public void process(
         Record input,
@@ -528,6 +471,9 @@ public class ForecastTableFunction implements TableFunction 
{
 
       int columnSize = properColumnBuilders.size();
 
+      // sort inputRecords in ascending order by timestamp
+      inputRecords.sort(Comparator.comparingLong(record -> record.getLong(0)));
+
       // time column
       long inputStartTime = inputRecords.getFirst().getLong(0);
       long inputEndTime = inputRecords.getLast().getLong(0);
@@ -546,6 +492,12 @@ public class ForecastTableFunction implements 
TableFunction {
       }
       long outputTime =
           (outputStartTime == Long.MIN_VALUE) ? (inputEndTime + interval) : 
outputStartTime;
+      if (outputTime <= inputEndTime) {
+        throw new SemanticException(
+            String.format(
+                "The %s should be greater than the maximum timestamp of target 
time series. Expected greater than [%s] but found [%s].",
+                OUTPUT_START_TIME, inputEndTime, outputTime));
+      }
       for (int i = 0; i < outputLength; i++) {
         properColumnBuilders.get(0).writeLong(outputTime + interval * i);
       }
@@ -585,6 +537,7 @@ public class ForecastTableFunction implements TableFunction 
{
     }
 
     private TsBlock forecast() {
+      // drain the buffered input records into the input TsBlock for AINode
       while (!inputRecords.isEmpty()) {
         Record row = inputRecords.removeFirst();
         inputTsBlockBuilder.getTimeColumnBuilder().writeLong(row.getLong(0));
@@ -632,100 +585,4 @@ public class ForecastTableFunction implements 
TableFunction {
       return res;
     }
   }
-
-  private interface ResultColumnAppender {
-    void append(Record row, int columnIndex, ColumnBuilder 
properColumnBuilder);
-
-    double getDouble(Record row, int columnIndex);
-
-    void writeDouble(double value, ColumnBuilder columnBuilder);
-  }
-
-  private static class Int32Appender implements ResultColumnAppender {
-
-    @Override
-    public void append(Record row, int columnIndex, ColumnBuilder 
properColumnBuilder) {
-      if (row.isNull(columnIndex)) {
-        properColumnBuilder.appendNull();
-      } else {
-        properColumnBuilder.writeInt(row.getInt(columnIndex));
-      }
-    }
-
-    @Override
-    public double getDouble(Record row, int columnIndex) {
-      return row.getInt(columnIndex);
-    }
-
-    @Override
-    public void writeDouble(double value, ColumnBuilder columnBuilder) {
-      columnBuilder.writeInt((int) value);
-    }
-  }
-
-  private static class Int64Appender implements ResultColumnAppender {
-
-    @Override
-    public void append(Record row, int columnIndex, ColumnBuilder 
properColumnBuilder) {
-      if (row.isNull(columnIndex)) {
-        properColumnBuilder.appendNull();
-      } else {
-        properColumnBuilder.writeLong(row.getLong(columnIndex));
-      }
-    }
-
-    @Override
-    public double getDouble(Record row, int columnIndex) {
-      return row.getLong(columnIndex);
-    }
-
-    @Override
-    public void writeDouble(double value, ColumnBuilder columnBuilder) {
-      columnBuilder.writeLong((long) value);
-    }
-  }
-
-  private static class FloatAppender implements ResultColumnAppender {
-
-    @Override
-    public void append(Record row, int columnIndex, ColumnBuilder 
properColumnBuilder) {
-      if (row.isNull(columnIndex)) {
-        properColumnBuilder.appendNull();
-      } else {
-        properColumnBuilder.writeFloat(row.getFloat(columnIndex));
-      }
-    }
-
-    @Override
-    public double getDouble(Record row, int columnIndex) {
-      return row.getFloat(columnIndex);
-    }
-
-    @Override
-    public void writeDouble(double value, ColumnBuilder columnBuilder) {
-      columnBuilder.writeFloat((float) value);
-    }
-  }
-
-  private static class DoubleAppender implements ResultColumnAppender {
-
-    @Override
-    public void append(Record row, int columnIndex, ColumnBuilder 
properColumnBuilder) {
-      if (row.isNull(columnIndex)) {
-        properColumnBuilder.appendNull();
-      } else {
-        properColumnBuilder.writeDouble(row.getDouble(columnIndex));
-      }
-    }
-
-    @Override
-    public double getDouble(Record row, int columnIndex) {
-      return row.getDouble(columnIndex);
-    }
-
-    @Override
-    public void writeDouble(double value, ColumnBuilder columnBuilder) {
-      columnBuilder.writeDouble(value);
-    }
-  }
 }
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/utils/ResultColumnAppender.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/utils/ResultColumnAppender.java
new file mode 100644
index 00000000000..7748f39382d
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/utils/ResultColumnAppender.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.utils;
+
+import org.apache.iotdb.udf.api.relational.access.Record;
+import org.apache.iotdb.udf.api.type.Type;
+
+import org.apache.tsfile.block.column.ColumnBuilder;
+
+public interface ResultColumnAppender {
+
+  void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder);
+
+  double getDouble(Record row, int columnIndex);
+
+  void writeDouble(double value, ColumnBuilder columnBuilder);
+
+  /**
+   * Returns the appender for INT32, INT64, FLOAT or DOUBLE columns; any other Type is rejected
+   * with an IllegalArgumentException.
+   */
+  static ResultColumnAppender createResultColumnAppender(Type type) {
+    switch (type) {
+      case INT32:
+        return new Int32Appender();
+      case INT64:
+        return new Int64Appender();
+      case FLOAT:
+        return new FloatAppender();
+      case DOUBLE:
+        return new DoubleAppender();
+      default:
+        throw new IllegalArgumentException("Unsupported column type: " + type);
+    }
+  }
+
+  /** Appender for INT32 columns; writeDouble narrows the value with an (int) cast. */
+  class Int32Appender implements ResultColumnAppender {
+
+    @Override
+    public void append(Record row, int columnIndex, ColumnBuilder 
properColumnBuilder) {
+      if (row.isNull(columnIndex)) {
+        properColumnBuilder.appendNull();
+      } else {
+        properColumnBuilder.writeInt(row.getInt(columnIndex));
+      }
+    }
+
+    @Override
+    public double getDouble(Record row, int columnIndex) {
+      return row.getInt(columnIndex);
+    }
+
+    @Override
+    public void writeDouble(double value, ColumnBuilder columnBuilder) {
+      columnBuilder.writeInt((int) value);
+    }
+  }
+
+  /** Appender for INT64 columns; writeDouble narrows the value with a (long) cast. */
+  class Int64Appender implements ResultColumnAppender {
+
+    @Override
+    public void append(Record row, int columnIndex, ColumnBuilder 
properColumnBuilder) {
+      if (row.isNull(columnIndex)) {
+        properColumnBuilder.appendNull();
+      } else {
+        properColumnBuilder.writeLong(row.getLong(columnIndex));
+      }
+    }
+
+    @Override
+    public double getDouble(Record row, int columnIndex) {
+      return row.getLong(columnIndex);
+    }
+
+    @Override
+    public void writeDouble(double value, ColumnBuilder columnBuilder) {
+      columnBuilder.writeLong((long) value);
+    }
+  }
+
+  /** Appender for FLOAT columns; writeDouble narrows the value with a (float) cast. */
+  class FloatAppender implements ResultColumnAppender {
+
+    @Override
+    public void append(Record row, int columnIndex, ColumnBuilder 
properColumnBuilder) {
+      if (row.isNull(columnIndex)) {
+        properColumnBuilder.appendNull();
+      } else {
+        properColumnBuilder.writeFloat(row.getFloat(columnIndex));
+      }
+    }
+
+    @Override
+    public double getDouble(Record row, int columnIndex) {
+      return row.getFloat(columnIndex);
+    }
+
+    @Override
+    public void writeDouble(double value, ColumnBuilder columnBuilder) {
+      columnBuilder.writeFloat((float) value);
+    }
+  }
+
+  /** Appender for DOUBLE columns; values are read and written without conversion. */
+  class DoubleAppender implements ResultColumnAppender {
+
+    @Override
+    public void append(Record row, int columnIndex, ColumnBuilder 
properColumnBuilder) {
+      if (row.isNull(columnIndex)) {
+        properColumnBuilder.appendNull();
+      } else {
+        properColumnBuilder.writeDouble(row.getDouble(columnIndex));
+      }
+    }
+
+    @Override
+    public double getDouble(Record row, int columnIndex) {
+      return row.getDouble(columnIndex);
+    }
+
+    @Override
+    public void writeDouble(double value, ColumnBuilder columnBuilder) {
+      columnBuilder.writeDouble(value);
+    }
+  }
+}
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
index 09f00f8ed67..ebecf79f5b7 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
@@ -85,6 +85,7 @@ public class UDTFForecast implements UDTF {
   private static final Boolean DEFAULT_KEEP_INPUT = Boolean.FALSE;
   private static final String OPTIONS_PARAMETER_NAME = "MODEL_OPTIONS";
   private static final String DEFAULT_OPTIONS = "";
+  private static final int MAX_INPUT_LENGTH = 2880;
 
   private void checkType() {
     for (Type type : this.types) {
@@ -109,6 +110,7 @@ public class UDTFForecast implements UDTF {
       throw new IllegalArgumentException(
           "MODEL_ID parameter must be provided and cannot be empty.");
     }
+    this.maxInputLength = MAX_INPUT_LENGTH;
 
     this.outputInterval = parameters.getLongOrDefault(OUTPUT_INTERVAL, 
DEFAULT_OUTPUT_INTERVAL);
     this.outputLength =
diff --git 
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java
 
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java
index 7bbfe150ade..344c69cfc1a 100644
--- 
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java
+++ 
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java
@@ -356,7 +356,7 @@ public class TableFunctionTest {
 
     String sql =
         "SELECT * FROM FORECAST("
-            + "input => (SELECT time,s3 FROM table1 WHERE tag1='shanghai' AND 
tag2='A3' AND tag3='YY' ORDER BY time DESC LIMIT 1440), "
+            + "targets => (SELECT time,s3 FROM table1 WHERE tag1='shanghai' 
AND tag2='A3' AND tag3='YY' ORDER BY time DESC LIMIT 1440), "
             + "model_id => 'timer_xl')";
     LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
 
@@ -416,7 +416,7 @@ public class TableFunctionTest {
 
     String sql =
         "SELECT * FROM FORECAST("
-            + "input => (SELECT time,s3 FROM table1 WHERE tag1='shanghai' AND 
tag2='A3' AND tag3='YY' ORDER BY time DESC LIMIT 1440), "
+            + "targets => (SELECT time,s3 FROM table1 WHERE tag1='shanghai' 
AND tag2='A3' AND tag3='YY' ORDER BY time DESC LIMIT 1440), "
             + "model_id => 'timer_xl', timecol=>'TiME')";
     LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
 

Reply via email to