This is an automated email from the ASF dual-hosted git repository.
jackietien pushed a commit to branch ty/forecast
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/ty/forecast by this push:
new 8f1e6b5dd9e change ainode forecast interface
8f1e6b5dd9e is described below
commit 8f1e6b5dd9eca1d5dcc419f555625a19cdae0288
Author: JackieTien97 <[email protected]>
AuthorDate: Tue Apr 22 17:19:01 2025 +0800
change ainode forecast interface
---
.../relational/tvf/ForecastTableFunction.java | 188 ++++++++++++++++++++-
.../thrift-ainode/src/main/thrift/ainode.thrift | 4 +-
2 files changed, 184 insertions(+), 8 deletions(-)
diff --git
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java
index f44e55c6a99..a1ed1b73d04 100644
---
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java
+++
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java
@@ -19,6 +19,8 @@
package org.apache.iotdb.commons.udf.builtin.relational.tvf;
+import org.apache.iotdb.commons.exception.IoTDBRuntimeException;
+import org.apache.iotdb.rpc.TSStatusCode;
import org.apache.iotdb.udf.api.exception.UDFException;
import org.apache.iotdb.udf.api.relational.TableFunction;
import org.apache.iotdb.udf.api.relational.access.Record;
@@ -62,6 +64,7 @@ public class ForecastTableFunction implements TableFunction {
private static final String DEFAULT_TIME_COL = "time";
private static final String KEEP_INPUT_PARAMETER_NAME = "KEEP_INPUT";
private static final Boolean DEFAULT_KEEP_INPUT = Boolean.FALSE;
+ private static final String IS_INPUT_COLUMN_NAME = "is_input"; // extra BOOLEAN output column, added only when KEEP_INPUT is true
private static final String OPTIONS_PARAMETER_NAME = "OPTIONS";
private static final String DEFAULT_OPTIONS = "";
@@ -201,6 +204,12 @@ public class ForecastTableFunction implements
TableFunction {
}
}
+ // With KEEP_INPUT, add a trailing is_input BOOLEAN column: the data processor writes true
+ // for echoed input rows and false for forecast rows.
+ boolean keepInput =
+ (boolean) ((ScalarArgument)
arguments.get(KEEP_INPUT_PARAMETER_NAME)).getValue();
+ if (keepInput) {
+ properColumnSchemaBuilder.addField(IS_INPUT_COLUMN_NAME, Type.BOOLEAN);
+ }
+
// outputColumnSchema
return TableFunctionAnalysis.builder()
.properColumnSchema(properColumnSchemaBuilder.build())
@@ -234,27 +243,55 @@ public class ForecastTableFunction implements
TableFunction {
return new TableFunctionProcessorProvider() {
@Override
public TableFunctionDataProcessor getDataProcessor() {
- return new ForecastDataProcessor(inputLength, outputLength, keepInput,
options);
+ // NOTE(review): the empty type list means ForecastDataProcessor builds no column appenders,
+ // so process() throws IndexOutOfBoundsException for any predicated column when KEEP_INPUT
+ // is true; the predicated column types still need to be passed here — TODO confirm source.
+ return new ForecastDataProcessor(
+ inputLength, outputLength, keepInput, parseOptions(options),
Collections.emptyList());
}
};
}
+ // Placeholder parser: ignores the OPTIONS string and always returns an immutable empty map,
+ // so user-supplied options are not forwarded yet.
+ private static Map<String, String> parseOptions(String options) {
+ return Collections.emptyMap();
+ }
+
private static class ForecastDataProcessor implements
TableFunctionDataProcessor {
private final int maxInputLength;
private final int outputLength;
private final boolean keepInput;
- private final String options;
-
+ private final Map<String, String> options; // parsed OPTIONS; not read anywhere in the visible processor code yet
private final List<Record> inputRecords;
+ private final List<ResultColumnAppender> resultColumnAppenderList; // one appender per predicated column, in column order (output column i uses entry i - 1)
public ForecastDataProcessor(
- int maxInputLength, int outputLength, boolean keepInput, String
options) {
+ int maxInputLength,
+ int outputLength,
+ boolean keepInput,
+ Map<String, String> options,
+ List<Type> types) { // types: data type of each predicated column; one appender is built per entry
this.maxInputLength = maxInputLength;
this.outputLength = outputLength;
this.keepInput = keepInput;
this.options = options;
this.inputRecords = new LinkedList<>();
+ this.resultColumnAppenderList = new ArrayList<>(types.size());
+ for (Type type : types) {
+ resultColumnAppenderList.add(createResultColumnAppender(type));
+ }
+ }
+
+ // Returns the appender for a predicated column's data type; only INT32, INT64, FLOAT and
+ // DOUBLE are supported and any other type is rejected with IllegalArgumentException.
+ private static ResultColumnAppender createResultColumnAppender(Type type) {
+ switch (type) {
+ case INT32:
+ return new Int32Appender();
+ case INT64:
+ return new Int64Appender();
+ case FLOAT:
+ return new FloatAppender();
+ case DOUBLE:
+ return new DoubleAppender();
+ default:
+ throw new IllegalArgumentException("Unsupported column type: " +
type);
+ }
}
@Override
@@ -262,12 +299,151 @@ public class ForecastTableFunction implements
TableFunction {
Record input,
List<ColumnBuilder> properColumnBuilders,
ColumnBuilder passThroughIndexBuilder) {
+ // Buffers the row for forecasting and, when KEEP_INPUT is true, also echoes it to the
+ // output: time column first, then the predicated columns, then is_input = true.
+
+ if (keepInput) {
+ int columnSize = properColumnBuilders.size();
+
+ // time column, will never be null
+ if (input.isNull(0)) {
+ throw new IoTDBRuntimeException(
+ "Time column should never be null", TSStatusCode.SEMANTIC_ERROR.getStatusCode());
+ }
+ properColumnBuilders.get(0).writeLong(input.getLong(0));
+
+ // predicated columns: input column i is copied by appender i - 1; the last output column
+ // is is_input, so it is excluded from the loop
+ for (int i = 1, size = columnSize - 1; i < size; i++) {
+ resultColumnAppenderList.get(i - 1).append(input, i, properColumnBuilders.get(i));
+ }
+
+ // is_input column
+ properColumnBuilders.get(columnSize - 1).writeBoolean(true);
+ }
+
+ // Sliding window: keep at most maxInputLength rows, evicting the oldest. List.remove(0) is
+ // used because List only gained removeFirst() in Java 21 (SequencedCollection).
+ // NOTE(review): Record references are retained across calls — confirm the framework does
+ // not reuse Record instances between process() invocations.
+ if (!inputRecords.isEmpty() && inputRecords.size() >= maxInputLength) {
+ inputRecords.remove(0);
+ }
inputRecords.add(input);
}
@Override
- public void finish(List<ColumnBuilder> columnBuilders, ColumnBuilder
passThroughIndexBuilder) {
- TableFunctionDataProcessor.super.finish(columnBuilders,
passThroughIndexBuilder);
+ public void finish(
+ List<ColumnBuilder> properColumnBuilders, ColumnBuilder passThroughIndexBuilder) {
+ // Emits outputLength forecast rows after the buffered input window: timestamps are
+ // extrapolated from the window's mean sampling interval and is_input is false.
+ if (inputRecords.isEmpty()) {
+ // nothing was buffered, so there is no time base to extrapolate from
+ return;
+ }
+
+ int columnSize = properColumnBuilders.size();
+ int windowSize = inputRecords.size();
+
+ // time column: n buffered points span n - 1 intervals, and the first forecast row lies one
+ // interval after the last input row so it never collides with endTime.
+ // NOTE(review): with a single buffered row no interval can be inferred (interval = 0), so
+ // every forecast timestamp equals endTime — decide on a fallback.
+ long startTime = inputRecords.get(0).getLong(0);
+ long endTime = inputRecords.get(windowSize - 1).getLong(0);
+ long interval = windowSize > 1 ? (endTime - startTime) / (windowSize - 1) : 0;
+ for (int i = 1; i <= outputLength; i++) {
+ properColumnBuilders.get(0).writeLong(endTime + interval * i);
+ }
+
+ // predicated columns
+ // TODO(review): forecast values are not written yet, so these builders currently receive
+ // fewer rows than the time column; fill them from the AINode forecast result.
+
+ // is_input column if keep_input is true
+ if (keepInput) {
+ for (int i = 0; i < outputLength; i++) {
+ properColumnBuilders.get(columnSize - 1).writeBoolean(false);
+ }
+ }
+ }
+ }
+
+ /**
+ * Per-data-type bridge between a predicated input column and its output ColumnBuilder.
+ * Integral implementations round to the nearest value when narrowing a double.
+ */
+ private interface ResultColumnAppender {
+ /** Copies row[columnIndex] into properColumnBuilder, appending null if the value is null. */
+ void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder);
+
+ /** Reads row[columnIndex] widened to double; implementations do not check isNull first. */
+ double getDouble(Record row, int columnIndex);
+
+ /** Narrows value to the column's type and writes it to properColumnBuilder. */
+ void writeDouble(double value, ColumnBuilder properColumnBuilder);
+ }
+
+ private static class Int32Appender implements ResultColumnAppender {
+
+ @Override
+ public void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder) {
+ if (row.isNull(columnIndex)) {
+ properColumnBuilder.appendNull();
+ } else {
+ properColumnBuilder.writeInt(row.getInt(columnIndex));
+ }
+ }
+
+ @Override
+ public double getDouble(Record row, int columnIndex) {
+ return row.getInt(columnIndex);
+ }
+
+ @Override
+ public void writeDouble(double value, ColumnBuilder properColumnBuilder) {
+ // round to nearest (ties to even) instead of truncating toward zero; out-of-range values
+ // still saturate and NaN still becomes 0, exactly as with a plain cast
+ properColumnBuilder.writeInt((int) Math.rint(value));
+ }
+ }
+
+ private static class Int64Appender implements ResultColumnAppender {
+
+ @Override
+ public void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder) {
+ if (row.isNull(columnIndex)) {
+ properColumnBuilder.appendNull();
+ } else {
+ properColumnBuilder.writeLong(row.getLong(columnIndex));
+ }
+ }
+
+ @Override
+ public double getDouble(Record row, int columnIndex) {
+ return row.getLong(columnIndex);
+ }
+
+ @Override
+ public void writeDouble(double value, ColumnBuilder properColumnBuilder) {
+ // round to nearest (ties to even) instead of truncating toward zero; saturates like a cast
+ properColumnBuilder.writeLong((long) Math.rint(value));
+ }
+ }
+
+ private static class FloatAppender implements ResultColumnAppender {
+
+ @Override
+ public void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder) {
+ if (row.isNull(columnIndex)) {
+ properColumnBuilder.appendNull();
+ } else {
+ properColumnBuilder.writeFloat(row.getFloat(columnIndex));
+ }
+ }
+
+ @Override
+ public double getDouble(Record row, int columnIndex) {
+ return row.getFloat(columnIndex);
+ }
+
+ @Override
+ public void writeDouble(double value, ColumnBuilder properColumnBuilder) {
+ properColumnBuilder.writeFloat((float) value);
+ }
+ }
+
+ private static class DoubleAppender implements ResultColumnAppender {
+
+ @Override
+ public void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder) {
+ if (row.isNull(columnIndex)) {
+ properColumnBuilder.appendNull();
+ } else {
+ properColumnBuilder.writeDouble(row.getDouble(columnIndex));
+ }
+ }
+
+ @Override
+ public double getDouble(Record row, int columnIndex) {
+ return row.getDouble(columnIndex);
+ }
+
+ @Override
+ public void writeDouble(double value, ColumnBuilder properColumnBuilder) {
+ properColumnBuilder.writeDouble(value);
}
}
}
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index 1849cdf19d5..95ab108b948 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -91,14 +91,14 @@ struct TTrainingReq {
struct TForecastReq {
1: required string modelId
- 2: required list<list<double>> inputData
+ 2: required binary inputData // NOTE(review): document the byte layout of this payload; reusing field id 2 with a new type is wire-incompatible with peers built against the old list<list<double>> IDL
3: required i32 outputLength
4: optional map<string, string> options
}
struct TForecastResp {
1: required common.TSStatus status
- 2: required list<list<double>> inferenceResult
+ 2: required binary inferenceResult // NOTE(review): document the byte layout of this payload; reusing field id 2 with a new type is wire-incompatible with peers built against the old list<list<double>> IDL
}
service IAINodeRPCService {