This is an automated email from the ASF dual-hosted git repository.

jackietien pushed a commit to branch ty/forecast
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/ty/forecast by this push:
     new ee510fc353a Delete inputLength
ee510fc353a is described below

commit ee510fc353ac6a3b98942a8bc6b45d0d22737eab
Author: JackieTien97 <[email protected]>
AuthorDate: Tue Apr 22 11:53:15 2025 +0800

    Delete inputLength
---
 .../relational/TableBuiltinTableFunction.java       |  2 +-
 .../relational/tvf/ForecastTableFunction.java       | 21 ++++++++++-----------
 .../thrift-ainode/src/main/thrift/ainode.thrift     | 14 ++++++++++++++
 3 files changed, 25 insertions(+), 12 deletions(-)

diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinTableFunction.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinTableFunction.java
index 9bece352a75..395615589d0 100644
--- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinTableFunction.java
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinTableFunction.java
@@ -20,8 +20,8 @@
 package org.apache.iotdb.commons.udf.builtin.relational;
 
 import 
org.apache.iotdb.commons.udf.builtin.relational.tvf.CapacityTableFunction;
-import 
org.apache.iotdb.commons.udf.builtin.relational.tvf.ForecastTableFunction;
 import 
org.apache.iotdb.commons.udf.builtin.relational.tvf.CumulateTableFunction;
+import 
org.apache.iotdb.commons.udf.builtin.relational.tvf.ForecastTableFunction;
 import org.apache.iotdb.commons.udf.builtin.relational.tvf.HOPTableFunction;
 import 
org.apache.iotdb.commons.udf.builtin.relational.tvf.SessionTableFunction;
 import org.apache.iotdb.commons.udf.builtin.relational.tvf.TumbleTableFunction;
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java
index e7731a3406b..f44e55c6a99 100644
--- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java
@@ -41,6 +41,7 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -130,12 +131,6 @@ public class ForecastTableFunction implements TableFunction {
               MODEL_ID_PARAMETER_NAME, modelId));
     }
 
-    int inputLength =
-        (int) ((ScalarArgument) 
arguments.get(INPUT_LENGTH_PARAMETER_NAME)).getValue();
-    if (inputLength <= 0) {
-      throw new UDFException(
-          String.format("%s should be greater than 0", 
INPUT_LENGTH_PARAMETER_NAME));
-    }
     int outputLength =
         (int) ((ScalarArgument) 
arguments.get(OUTPUT_LENGTH_PARAMETER_NAME)).getValue();
     if (outputLength <= 0) {
@@ -209,7 +204,6 @@ public class ForecastTableFunction implements TableFunction {
     // outputColumnSchema
     return TableFunctionAnalysis.builder()
         .properColumnSchema(properColumnSchemaBuilder.build())
-        .requireRecordSnapshot(false)
         .requiredColumns(INPUT_PARAMETER_NAME, requiredIndexList)
         .build();
   }
@@ -247,24 +241,29 @@ public class ForecastTableFunction implements TableFunction {
 
   private static class ForecastDataProcessor implements 
TableFunctionDataProcessor {
 
-    private final int inputLength;
+    private final int maxInputLength;
     private final int outputLength;
     private final boolean keepInput;
     private final String options;
 
+    private final List<Record> inputRecords;
+
     public ForecastDataProcessor(
-        int inputLength, int outputLength, boolean keepInput, String options) {
-      this.inputLength = inputLength;
+        int maxInputLength, int outputLength, boolean keepInput, String 
options) {
+      this.maxInputLength = maxInputLength;
       this.outputLength = outputLength;
       this.keepInput = keepInput;
       this.options = options;
+      this.inputRecords = new LinkedList<>();
     }
 
     @Override
     public void process(
         Record input,
         List<ColumnBuilder> properColumnBuilders,
-        ColumnBuilder passThroughIndexBuilder) {}
+        ColumnBuilder passThroughIndexBuilder) {
+      inputRecords.add(input);
+    }
 
     @Override
     public void finish(List<ColumnBuilder> columnBuilders, ColumnBuilder 
passThroughIndexBuilder) {
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index 9ac07b48dca..1849cdf19d5 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -89,6 +89,18 @@ struct TTrainingReq {
   6: optional string existingModelId
 }
 
+struct TForecastReq {
+  1: required string modelId
+  2: required list<list<double>> inputData
+  3: required i32 outputLength
+  4: optional map<string, string> options
+}
+
+struct TForecastResp {
+  1: required common.TSStatus status
+  2: required list<list<double>> inferenceResult
+}
+
 service IAINodeRPCService {
 
   // -------------- For Config Node --------------
@@ -104,4 +116,6 @@ service IAINodeRPCService {
   // -------------- For Data Node --------------
 
   TInferenceResp inference(TInferenceReq req)
+
+  TForecastResp forecast(TForecastReq req)
 }
\ No newline at end of file

Reply via email to