This is an automated email from the ASF dual-hosted git repository.

jackietien pushed a commit to branch ty/forecast
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/ty/forecast by this push:
     new 8f1e6b5dd9e change ainode forecast interface
8f1e6b5dd9e is described below

commit 8f1e6b5dd9eca1d5dcc419f555625a19cdae0288
Author: JackieTien97 <[email protected]>
AuthorDate: Tue Apr 22 17:19:01 2025 +0800

    change ainode forecast interface
---
 .../relational/tvf/ForecastTableFunction.java      | 188 ++++++++++++++++++++-
 .../thrift-ainode/src/main/thrift/ainode.thrift    |   4 +-
 2 files changed, 184 insertions(+), 8 deletions(-)

diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java
index f44e55c6a99..a1ed1b73d04 100644
--- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java
@@ -19,6 +19,8 @@
 
 package org.apache.iotdb.commons.udf.builtin.relational.tvf;
 
+import org.apache.iotdb.commons.exception.IoTDBRuntimeException;
+import org.apache.iotdb.rpc.TSStatusCode;
 import org.apache.iotdb.udf.api.exception.UDFException;
 import org.apache.iotdb.udf.api.relational.TableFunction;
 import org.apache.iotdb.udf.api.relational.access.Record;
@@ -62,6 +64,7 @@ public class ForecastTableFunction implements TableFunction {
   private static final String DEFAULT_TIME_COL = "time";
   private static final String KEEP_INPUT_PARAMETER_NAME = "KEEP_INPUT";
   private static final Boolean DEFAULT_KEEP_INPUT = Boolean.FALSE;
+  private static final String IS_INPUT_COLUMN_NAME = "is_input";
   private static final String OPTIONS_PARAMETER_NAME = "OPTIONS";
   private static final String DEFAULT_OPTIONS = "";
 
@@ -201,6 +204,12 @@ public class ForecastTableFunction implements TableFunction {
       }
     }
 
+    boolean keepInput =
+        (boolean) ((ScalarArgument) 
arguments.get(KEEP_INPUT_PARAMETER_NAME)).getValue();
+    if (keepInput) {
+      properColumnSchemaBuilder.addField(IS_INPUT_COLUMN_NAME, Type.BOOLEAN);
+    }
+
     // outputColumnSchema
     return TableFunctionAnalysis.builder()
         .properColumnSchema(properColumnSchemaBuilder.build())
@@ -234,27 +243,55 @@ public class ForecastTableFunction implements TableFunction {
     return new TableFunctionProcessorProvider() {
       @Override
       public TableFunctionDataProcessor getDataProcessor() {
-        return new ForecastDataProcessor(inputLength, outputLength, keepInput, 
options);
+        return new ForecastDataProcessor(
+            inputLength, outputLength, keepInput, parseOptions(options), 
Collections.emptyList());
       }
     };
   }
 
+  private static Map<String, String> parseOptions(String options) {
+    return Collections.emptyMap();
+  }
+
   private static class ForecastDataProcessor implements 
TableFunctionDataProcessor {
 
     private final int maxInputLength;
     private final int outputLength;
     private final boolean keepInput;
-    private final String options;
-
+    private final Map<String, String> options;
     private final List<Record> inputRecords;
+    private final List<ResultColumnAppender> resultColumnAppenderList;
 
     public ForecastDataProcessor(
-        int maxInputLength, int outputLength, boolean keepInput, String 
options) {
+        int maxInputLength,
+        int outputLength,
+        boolean keepInput,
+        Map<String, String> options,
+        List<Type> types) {
       this.maxInputLength = maxInputLength;
       this.outputLength = outputLength;
       this.keepInput = keepInput;
       this.options = options;
       this.inputRecords = new LinkedList<>();
+      this.resultColumnAppenderList = new ArrayList<>(types.size());
+      for (Type type : types) {
+        resultColumnAppenderList.add(createResultColumnAppender(type));
+      }
+    }
+
+    private static ResultColumnAppender createResultColumnAppender(Type type) {
+      switch (type) {
+        case INT32:
+          return new Int32Appender();
+        case INT64:
+          return new Int64Appender();
+        case FLOAT:
+          return new FloatAppender();
+        case DOUBLE:
+          return new DoubleAppender();
+        default:
+          throw new IllegalArgumentException("Unsupported column type: " + 
type);
+      }
     }
 
     @Override
@@ -262,12 +299,151 @@ public class ForecastTableFunction implements TableFunction {
         Record input,
         List<ColumnBuilder> properColumnBuilders,
         ColumnBuilder passThroughIndexBuilder) {
+
+      if (keepInput) {
+        int columnSize = properColumnBuilders.size();
+
+        // time column, will never be null
+        if (input.isNull(0)) {
+          throw new IoTDBRuntimeException(
+              "Time column should never be null", 
TSStatusCode.SEMANTIC_ERROR.getStatusCode());
+        }
+        properColumnBuilders.get(0).writeLong(input.getLong(0));
+
+        // predicated columns
+        for (int i = 1, size = columnSize - 1; i < size; i++) {
+          resultColumnAppenderList.get(i - 1).append(input, i, 
properColumnBuilders.get(i));
+        }
+
+        // is_input column
+        properColumnBuilders.get(columnSize - 1).writeBoolean(true);
+      }
+
+      // only keep at most maxInputLength rows
+      if (inputRecords.size() == maxInputLength) {
+        inputRecords.removeFirst();
+      }
       inputRecords.add(input);
     }
 
     @Override
-    public void finish(List<ColumnBuilder> columnBuilders, ColumnBuilder 
passThroughIndexBuilder) {
-      TableFunctionDataProcessor.super.finish(columnBuilders, 
passThroughIndexBuilder);
+    public void finish(
+        List<ColumnBuilder> properColumnBuilders, ColumnBuilder 
passThroughIndexBuilder) {
+
+      int columnSize = properColumnBuilders.size();
+
+      // time column
+      long startTime = inputRecords.getFirst().getLong(0);
+      long endTime = inputRecords.getLast().getLong(0);
+      long interval = (endTime - startTime) / inputRecords.size();
+      for (int i = 0; i < outputLength; i++) {
+        properColumnBuilders.get(0).writeLong(endTime + interval * i);
+      }
+
+      // predicated columns
+
+      // is_input column if keep_input is true
+      if (keepInput) {
+        for (int i = 0; i < outputLength; i++) {
+          properColumnBuilders.get(columnSize - 1).writeBoolean(false);
+        }
+      }
+    }
+  }
+
+  private interface ResultColumnAppender {
+    void append(Record row, int columnIndex, ColumnBuilder 
properColumnBuilder);
+
+    double getDouble(Record row, int columnIndex);
+
+    void writeDouble(double value, ColumnBuilder properColumnBuilder);
+  }
+
+  private static class Int32Appender implements ResultColumnAppender {
+
+    @Override
+    public void append(Record row, int columnIndex, ColumnBuilder 
properColumnBuilder) {
+      if (row.isNull(columnIndex)) {
+        properColumnBuilder.appendNull();
+      } else {
+        properColumnBuilder.writeInt(row.getInt(columnIndex));
+      }
+    }
+
+    @Override
+    public double getDouble(Record row, int columnIndex) {
+      return row.getInt(columnIndex);
+    }
+
+    @Override
+    public void writeDouble(double value, ColumnBuilder properColumnBuilder) {
+      properColumnBuilder.writeInt((int) value);
+    }
+  }
+
+  private static class Int64Appender implements ResultColumnAppender {
+
+    @Override
+    public void append(Record row, int columnIndex, ColumnBuilder 
properColumnBuilder) {
+      if (row.isNull(columnIndex)) {
+        properColumnBuilder.appendNull();
+      } else {
+        properColumnBuilder.writeLong(row.getLong(columnIndex));
+      }
+    }
+
+    @Override
+    public double getDouble(Record row, int columnIndex) {
+      return row.getLong(columnIndex);
+    }
+
+    @Override
+    public void writeDouble(double value, ColumnBuilder properColumnBuilder) {
+      properColumnBuilder.writeLong((long) value);
+    }
+  }
+
+  private static class FloatAppender implements ResultColumnAppender {
+
+    @Override
+    public void append(Record row, int columnIndex, ColumnBuilder 
properColumnBuilder) {
+      if (row.isNull(columnIndex)) {
+        properColumnBuilder.appendNull();
+      } else {
+        properColumnBuilder.writeFloat(row.getFloat(columnIndex));
+      }
+    }
+
+    @Override
+    public double getDouble(Record row, int columnIndex) {
+      return row.getFloat(columnIndex);
+    }
+
+    @Override
+    public void writeDouble(double value, ColumnBuilder properColumnBuilder) {
+      properColumnBuilder.writeFloat((float) value);
+    }
+  }
+
+  private static class DoubleAppender implements ResultColumnAppender {
+
+    @Override
+    public void append(Record row, int columnIndex, ColumnBuilder 
properColumnBuilder) {
+      if (row.isNull(columnIndex)) {
+        properColumnBuilder.appendNull();
+      } else {
+        properColumnBuilder.writeDouble(row.getDouble(columnIndex));
+      }
+    }
+
+    @Override
+    public double getDouble(Record row, int columnIndex) {
+      return row.getDouble(columnIndex);
+    }
+
+    @Override
+    public void writeDouble(double value, ColumnBuilder properColumnBuilder) {
+      properColumnBuilder.writeDouble(value);
     }
   }
 }
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index 1849cdf19d5..95ab108b948 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -91,14 +91,14 @@ struct TTrainingReq {
 
 struct TForecastReq {
   1: required string modelId
-  2: required list<list<double>> inputData
+  2: required binary inputData
   3: required i32 outputLength
   4: optional map<string, string> options
 }
 
 struct TForecastResp {
   1: required common.TSStatus status
-  2: required list<list<double>> inferenceResult
+  2: required binary inferenceResult
 }
 
 service IAINodeRPCService {

Reply via email to