This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch refactor-forecast
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 671b0c22608abfb50b6fa99b831e741a05b5c4ed
Author: Yongzao <[email protected]>
AuthorDate: Sat Nov 1 17:02:11 2025 +0800

    finish
---
 iotdb-core/ainode/.gitignore                       |  1 -
 .../function/tvf/ForecastTableFunction.java        | 88 ++++++++++++++--------
 .../db/queryengine/plan/udf/UDTFForecast.java      |  7 +-
 .../relational/analyzer/TableFunctionTest.java     |  2 +
 4 files changed, 60 insertions(+), 38 deletions(-)

diff --git a/iotdb-core/ainode/.gitignore b/iotdb-core/ainode/.gitignore
index 8cc2098c3fd..80221b44dc5 100644
--- a/iotdb-core/ainode/.gitignore
+++ b/iotdb-core/ainode/.gitignore
@@ -5,7 +5,6 @@
 /iotdb/thrift/
 /iotdb/tsfile/
 /iotdb/utils/
-/iotdb/__init__.py
 /iotdb/Session.py
 /iotdb/SessionPool.py
 /iotdb/table_session.py
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
index 5521062a24e..af09dbf4584 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
@@ -84,7 +84,8 @@ public class ForecastTableFunction implements TableFunction {
     long outputInterval;
     boolean keepInput;
     Map<String, String> options;
-    List<Type> types;
+    List<Type> inputColumnTypes;
+    List<Type> predicatedColumnTypes;
 
     public ForecastTableFunctionHandle() {}
 
@@ -97,7 +98,8 @@ public class ForecastTableFunction implements TableFunction {
         long outputStartTime,
         long outputInterval,
         TEndPoint targetAINode,
-        List<Type> types) {
+        List<Type> inputColumnTypes,
+        List<Type> predicatedColumnTypes) {
       this.keepInput = keepInput;
       this.maxInputLength = maxInputLength;
       this.modelId = modelId;
@@ -106,7 +108,8 @@ public class ForecastTableFunction implements TableFunction {
       this.outputStartTime = outputStartTime;
       this.outputInterval = outputInterval;
       this.targetAINode = targetAINode;
-      this.types = types;
+      this.inputColumnTypes = inputColumnTypes;
+      this.predicatedColumnTypes = predicatedColumnTypes;
     }
 
     @Override
@@ -122,8 +125,12 @@ public class ForecastTableFunction implements TableFunction {
         ReadWriteIOUtils.write(outputInterval, outputStream);
         ReadWriteIOUtils.write(keepInput, outputStream);
         ReadWriteIOUtils.write(options, outputStream);
-        ReadWriteIOUtils.write(types.size(), outputStream);
-        for (Type type : types) {
+        ReadWriteIOUtils.write(inputColumnTypes.size(), outputStream);
+        for (Type type : inputColumnTypes) {
+          ReadWriteIOUtils.write(type.getType(), outputStream);
+        }
+        ReadWriteIOUtils.write(predicatedColumnTypes.size(), outputStream);
+        for (Type type : predicatedColumnTypes) {
           ReadWriteIOUtils.write(type.getType(), outputStream);
         }
         outputStream.flush();
@@ -149,9 +156,14 @@ public class ForecastTableFunction implements TableFunction {
       this.keepInput = ReadWriteIOUtils.readBoolean(buffer);
       this.options = ReadWriteIOUtils.readMap(buffer);
       int size = ReadWriteIOUtils.readInt(buffer);
-      this.types = new ArrayList<>(size);
+      this.inputColumnTypes = new ArrayList<>(size);
+      for (int i = 0; i < size; i++) {
+        
inputColumnTypes.add(Type.valueOf(ReadWriteIOUtils.readString(buffer)));
+      }
+      size = ReadWriteIOUtils.readInt(buffer);
+      this.predicatedColumnTypes = new ArrayList<>(size);
       for (int i = 0; i < size; i++) {
-        types.add(Type.valueOf(ReadWriteIOUtils.readByte(buffer)));
+        
predicatedColumnTypes.add(Type.valueOf(ReadWriteIOUtils.readByte(buffer)));
       }
     }
 
@@ -172,7 +184,8 @@ public class ForecastTableFunction implements TableFunction {
           && Objects.equals(targetAINode, that.targetAINode)
           && Objects.equals(modelId, that.modelId)
           && Objects.equals(options, that.options)
-          && Objects.equals(types, that.types);
+          && Objects.equals(inputColumnTypes, that.inputColumnTypes)
+          && Objects.equals(predicatedColumnTypes, that.predicatedColumnTypes);
     }
 
     @Override
@@ -186,7 +199,8 @@ public class ForecastTableFunction implements TableFunction {
           outputInterval,
           keepInput,
           options,
-          types);
+          inputColumnTypes,
+          predicatedColumnTypes);
     }
   }
 
@@ -319,12 +333,22 @@ public class ForecastTableFunction implements TableFunction {
     DescribedSchema.Builder properColumnSchemaBuilder =
         new DescribedSchema.Builder().addField(timeColumn, Type.TIMESTAMP);
 
+    List<Type> inputColumnTypes = new ArrayList<>();
     List<Type> predicatedColumnTypes = new ArrayList<>();
     List<Optional<String>> allInputColumnsName = input.getFieldNames();
     List<Type> allInputColumnsType = input.getFieldTypes();
+    for (int i = 0, size = allInputColumnsName.size(); i < size; i++) {
+      Optional<String> fieldName = allInputColumnsName.get(i);
+      // All input value columns are required for model forecasting
+      if (!fieldName.isPresent()
+          || 
!excludedColumns.contains(fieldName.get().toLowerCase(Locale.ENGLISH))) {
+        inputColumnTypes.add(allInputColumnsType.get(i));
+        requiredIndexList.add(i);
+      }
+    }
     if (predicatedColumns.isEmpty()) {
-      // predicated columns by default include all columns from input table 
except for timecol and
-      // partition by columns
+      // predicated columns by default include all columns from input table 
except for
+      // timecol and partition by columns
       for (int i = 0, size = allInputColumnsName.size(); i < size; i++) {
         Optional<String> fieldName = allInputColumnsName.get(i);
         if (!fieldName.isPresent()
@@ -332,7 +356,6 @@ public class ForecastTableFunction implements TableFunction {
           Type columnType = allInputColumnsType.get(i);
           predicatedColumnTypes.add(columnType);
           checkType(columnType, fieldName.orElse(""));
-          requiredIndexList.add(i);
           properColumnSchemaBuilder.addField(fieldName, columnType);
         }
       }
@@ -347,7 +370,7 @@ public class ForecastTableFunction implements TableFunction {
         inputColumnIndexMap.put(fieldName.get().toLowerCase(Locale.ENGLISH), 
i);
       }
 
-      Set<Integer> requiredIndexSet = new 
HashSet<>(predictedColumnsArray.length);
+      Set<Integer> predicatedIndexSet = new 
HashSet<>(predictedColumnsArray.length);
       // columns need to be predicated
       for (String outputColumn : predictedColumnsArray) {
         String lowerCaseOutputColumn = 
outputColumn.toLowerCase(Locale.ENGLISH);
@@ -360,14 +383,13 @@ public class ForecastTableFunction implements TableFunction {
           throw new SemanticException(
               String.format("Column %s don't exist in input", outputColumn));
         }
-        if (!requiredIndexSet.add(inputColumnIndex)) {
+        if (!predicatedIndexSet.add(inputColumnIndex)) {
           throw new SemanticException(String.format("Duplicate column %s", 
outputColumn));
         }
 
         Type columnType = allInputColumnsType.get(inputColumnIndex);
         predicatedColumnTypes.add(columnType);
         checkType(columnType, outputColumn);
-        requiredIndexList.add(inputColumnIndex);
         properColumnSchemaBuilder.addField(outputColumn, columnType);
       }
     }
@@ -392,6 +414,7 @@ public class ForecastTableFunction implements TableFunction {
             outputStartTime,
             outputInterval,
             targetAINode,
+            inputColumnTypes,
             predicatedColumnTypes);
 
     // outputColumnSchema
@@ -469,8 +492,9 @@ public class ForecastTableFunction implements TableFunction {
     private final boolean keepInput;
     private final Map<String, String> options;
     private final LinkedList<Record> inputRecords;
-    private final List<ResultColumnAppender> resultColumnAppenderList;
     private final TsBlockBuilder inputTsBlockBuilder;
+    private final List<ResultColumnAppender> inputColumnAppenderList;
+    private final List<ResultColumnAppender> resultColumnAppenderList;
 
     public ForecastDataProcessor(ForecastTableFunctionHandle functionHandle) {
       this.targetAINode = functionHandle.targetAINode;
@@ -482,14 +506,22 @@ public class ForecastTableFunction implements TableFunction {
       this.keepInput = functionHandle.keepInput;
       this.options = functionHandle.options;
       this.inputRecords = new LinkedList<>();
-      this.resultColumnAppenderList = new 
ArrayList<>(functionHandle.types.size());
-      List<TSDataType> tsDataTypeList = new 
ArrayList<>(functionHandle.types.size());
-      for (Type type : functionHandle.types) {
+      List<TSDataType> inputTsDataTypeList =
+          new ArrayList<>(functionHandle.inputColumnTypes.size());
+      for (Type type : functionHandle.inputColumnTypes) {
+        // AINode currently only accept double input
+        inputTsDataTypeList.add(TSDataType.DOUBLE);
+      }
+      this.inputTsBlockBuilder = new TsBlockBuilder(inputTsDataTypeList);
+      this.inputColumnAppenderList = new 
ArrayList<>(functionHandle.inputColumnTypes.size());
+      for (Type type : functionHandle.inputColumnTypes) {
+        // AINode currently only accept double input
+        inputColumnAppenderList.add(createResultColumnAppender(Type.DOUBLE));
+      }
+      this.resultColumnAppenderList = new 
ArrayList<>(functionHandle.predicatedColumnTypes.size());
+      for (Type type : functionHandle.predicatedColumnTypes) {
         resultColumnAppenderList.add(createResultColumnAppender(type));
-        // ainode currently only accept double input
-        tsDataTypeList.add(TSDataType.DOUBLE);
       }
-      this.inputTsBlockBuilder = new TsBlockBuilder(tsDataTypeList);
     }
 
     private static ResultColumnAppender createResultColumnAppender(Type type) {
@@ -613,7 +645,7 @@ public class ForecastTableFunction implements TableFunction {
             // need to transform other types to DOUBLE
             inputTsBlockBuilder
                 .getColumnBuilder(i - 1)
-                .writeDouble(resultColumnAppenderList.get(i - 
1).getDouble(row, i));
+                .writeDouble(inputColumnAppenderList.get(i - 1).getDouble(row, 
i));
           }
         }
         inputTsBlockBuilder.declarePosition();
@@ -634,15 +666,7 @@ public class ForecastTableFunction implements TableFunction {
         throw new IoTDBRuntimeException(message, resp.getStatus().getCode());
       }
 
-      TsBlock res = 
SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
-      if (res.getValueColumnCount() != inputData.getValueColumnCount()) {
-        throw new IoTDBRuntimeException(
-            String.format(
-                "Model %s output %s columns, doesn't equal to specified %s",
-                modelId, res.getValueColumnCount(), 
inputData.getValueColumnCount()),
-            TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
-      }
-      return res;
+      return SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
     }
   }
 
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
index 22c2bce7b5e..f1e387c0ff6 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
@@ -128,11 +128,8 @@ public class UDTFForecast implements UDTF {
         Arrays.stream(
                 parameters.getStringOrDefault(OPTIONS_PARAMETER_NAME, 
DEFAULT_OPTIONS).split(","))
             .map(s -> s.split("="))
-            .filter(arr -> arr.length == 2 && !arr[0].isEmpty()) // 防御性检查
-            .collect(
-                Collectors.toMap(
-                    arr -> arr[0].trim(), arr -> arr[1].trim(), (v1, v2) -> v2 
// 如果 key 重复,保留后一个
-                    ));
+            .filter(arr -> arr.length == 2 && !arr[0].isEmpty())
+            .collect(Collectors.toMap(arr -> arr[0].trim(), arr -> 
arr[1].trim(), (v1, v2) -> v2));
     this.inputRows = new LinkedList<>();
     List<TSDataType> tsDataTypeList = new ArrayList<>(this.types.size() - 1);
     for (int i = 0; i < this.types.size(); i++) {
diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java
index e56b48936b9..0e8ffb484f1 100644
--- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java
+++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java
@@ -379,6 +379,7 @@ public class TableFunctionTest {
                         DEFAULT_OUTPUT_START_TIME,
                         DEFAULT_OUTPUT_INTERVAL,
                         new TEndPoint("127.0.0.1", 10810),
+                        Collections.singletonList(DOUBLE),
                         Collections.singletonList(DOUBLE)));
     // Verify full LogicalPlan
     // Output - TableFunctionProcessor - TableScan
@@ -440,6 +441,7 @@ public class TableFunctionTest {
                         DEFAULT_OUTPUT_START_TIME,
                         DEFAULT_OUTPUT_INTERVAL,
                         new TEndPoint("127.0.0.1", 10810),
+                        Collections.singletonList(DOUBLE),
                         Collections.singletonList(DOUBLE)));
     // Verify full LogicalPlan
     // Output - TableFunctionProcessor - TableScan

Reply via email to