This is an automated email from the ASF dual-hosted git repository.
jackietien pushed a commit to branch ty/forecast
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/ty/forecast by this push:
new ee510fc353a Delete inputLength
ee510fc353a is described below
commit ee510fc353ac6a3b98942a8bc6b45d0d22737eab
Author: JackieTien97 <[email protected]>
AuthorDate: Tue Apr 22 11:53:15 2025 +0800
Delete inputLength
---
.../relational/TableBuiltinTableFunction.java | 2 +-
.../relational/tvf/ForecastTableFunction.java | 21 ++++++++++-----------
.../thrift-ainode/src/main/thrift/ainode.thrift | 14 ++++++++++++++
3 files changed, 25 insertions(+), 12 deletions(-)
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinTableFunction.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinTableFunction.java
index 9bece352a75..395615589d0 100644
--- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinTableFunction.java
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinTableFunction.java
@@ -20,8 +20,8 @@
package org.apache.iotdb.commons.udf.builtin.relational;
import
org.apache.iotdb.commons.udf.builtin.relational.tvf.CapacityTableFunction;
-import
org.apache.iotdb.commons.udf.builtin.relational.tvf.ForecastTableFunction;
import
org.apache.iotdb.commons.udf.builtin.relational.tvf.CumulateTableFunction;
+import
org.apache.iotdb.commons.udf.builtin.relational.tvf.ForecastTableFunction;
import org.apache.iotdb.commons.udf.builtin.relational.tvf.HOPTableFunction;
import
org.apache.iotdb.commons.udf.builtin.relational.tvf.SessionTableFunction;
import org.apache.iotdb.commons.udf.builtin.relational.tvf.TumbleTableFunction;
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java
index e7731a3406b..f44e55c6a99 100644
--- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java
@@ -41,6 +41,7 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
+import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -130,12 +131,6 @@ public class ForecastTableFunction implements TableFunction {
MODEL_ID_PARAMETER_NAME, modelId));
}
- int inputLength =
- (int) ((ScalarArgument)
arguments.get(INPUT_LENGTH_PARAMETER_NAME)).getValue();
- if (inputLength <= 0) {
- throw new UDFException(
- String.format("%s should be greater than 0",
INPUT_LENGTH_PARAMETER_NAME));
- }
int outputLength =
(int) ((ScalarArgument)
arguments.get(OUTPUT_LENGTH_PARAMETER_NAME)).getValue();
if (outputLength <= 0) {
@@ -209,7 +204,6 @@ public class ForecastTableFunction implements TableFunction {
// outputColumnSchema
return TableFunctionAnalysis.builder()
.properColumnSchema(properColumnSchemaBuilder.build())
- .requireRecordSnapshot(false)
.requiredColumns(INPUT_PARAMETER_NAME, requiredIndexList)
.build();
}
@@ -247,24 +241,29 @@ public class ForecastTableFunction implements TableFunction {
private static class ForecastDataProcessor implements
TableFunctionDataProcessor {
- private final int inputLength;
+ private final int maxInputLength;
private final int outputLength;
private final boolean keepInput;
private final String options;
+ private final List<Record> inputRecords;
+
public ForecastDataProcessor(
- int inputLength, int outputLength, boolean keepInput, String options) {
- this.inputLength = inputLength;
+ int maxInputLength, int outputLength, boolean keepInput, String
options) {
+ this.maxInputLength = maxInputLength;
this.outputLength = outputLength;
this.keepInput = keepInput;
this.options = options;
+ this.inputRecords = new LinkedList<>();
}
@Override
public void process(
Record input,
List<ColumnBuilder> properColumnBuilders,
- ColumnBuilder passThroughIndexBuilder) {}
+ ColumnBuilder passThroughIndexBuilder) {
+ inputRecords.add(input);
+ }
@Override
public void finish(List<ColumnBuilder> columnBuilders, ColumnBuilder
passThroughIndexBuilder) {
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index 9ac07b48dca..1849cdf19d5 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -89,6 +89,18 @@ struct TTrainingReq {
6: optional string existingModelId
}
+struct TForecastReq {
+ 1: required string modelId
+ 2: required list<list<double>> inputData
+ 3: required i32 outputLength
+ 4: optional map<string, string> options
+}
+
+struct TForecastResp {
+ 1: required common.TSStatus status
+ 2: required list<list<double>> inferenceResult
+}
+
service IAINodeRPCService {
// -------------- For Config Node --------------
@@ -104,4 +116,6 @@ service IAINodeRPCService {
// -------------- For Data Node --------------
TInferenceResp inference(TInferenceReq req)
+
+ TForecastResp forecast(TForecastReq req)
}
\ No newline at end of file