This is an automated email from the ASF dual-hosted git repository.

jackietien pushed a commit to branch ty/forecast
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/ty/forecast by this push:
     new a4cfe9bad23 perfect ForecastTableFunction
a4cfe9bad23 is described below

commit a4cfe9bad23d9b44dd5541bf187ccf1247c44d6f
Author: JackieTien97 <[email protected]>
AuthorDate: Tue Apr 22 20:23:10 2025 +0800

    perfect ForecastTableFunction
---
 .../java/org/apache/iotdb/rpc/TSStatusCode.java    |   1 +
 .../iotdb/db/exception/ainode/ModelException.java  |  11 +-
 .../db/queryengine/plan/analyze/IModelFetcher.java |   4 +
 .../db/queryengine/plan/analyze/ModelFetcher.java  |  26 ++
 .../function}/ForecastTableFunction.java           | 286 +++++++++++++++++----
 .../iotdb/commons/client/ainode/AINodeClient.java  |  27 ++
 .../thrift-ainode/src/main/thrift/ainode.thrift    |   2 +-
 7 files changed, 302 insertions(+), 55 deletions(-)

diff --git a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
index 0c732282bcf..a3114333bcc 100644
--- a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
+++ b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
@@ -190,6 +190,7 @@ public enum TSStatusCode {
   TRANSFER_LEADER_ERROR(1008),
   GET_CLUSTER_ID_ERROR(1009),
   CAN_NOT_CONNECT_CONFIGNODE(1010),
+  CAN_NOT_CONNECT_AINODE(1011),
 
   // Sync, Load TsFile
   LOAD_FILE_ERROR(1100),
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelException.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelException.java
index 4a007e7048c..cce01c8ad28 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelException.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelException.java
@@ -19,17 +19,18 @@
 
 package org.apache.iotdb.db.exception.ainode;
 
+import org.apache.iotdb.commons.exception.IoTDBRuntimeException;
 import org.apache.iotdb.rpc.TSStatusCode;
 
-public class ModelException extends RuntimeException {
-  TSStatusCode statusCode;
+import static org.apache.iotdb.rpc.TSStatusCode.representOf;
+
+public class ModelException extends IoTDBRuntimeException {
 
   public ModelException(String message, TSStatusCode code) {
-    super(message);
-    this.statusCode = code;
+    super(message, code.getStatusCode());
   }
 
   public TSStatusCode getStatusCode() {
-    return statusCode;
+    return representOf(getErrorCode());
   }
 }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java
index 1feecaefde9..586e12e589a 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java
@@ -20,8 +20,12 @@
 package org.apache.iotdb.db.queryengine.plan.analyze;
 
 import org.apache.iotdb.common.rpc.thrift.TSStatus;
+import 
org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor;
 
 public interface IModelFetcher {
   /** Get model information by model id from configNode. */
   TSStatus fetchModel(String modelId, Analysis analysis);
+
+  // currently only used by table model
+  ModelInferenceDescriptor fetchModel(String modelName);
 }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java
index 8cefb5e0cf3..36382348b8e 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java
@@ -23,6 +23,7 @@ import org.apache.iotdb.common.rpc.thrift.TSStatus;
 import org.apache.iotdb.commons.client.IClientManager;
 import org.apache.iotdb.commons.client.exception.ClientManagerException;
 import org.apache.iotdb.commons.consensus.ConfigRegionId;
+import org.apache.iotdb.commons.exception.IoTDBRuntimeException;
 import org.apache.iotdb.commons.model.ModelInformation;
 import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
 import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
@@ -78,4 +79,29 @@ public class ModelFetcher implements IModelFetcher {
       throw new StatementAnalyzeException(e.getMessage());
     }
   }
+
+  @Override
+  public ModelInferenceDescriptor fetchModel(String modelName) {
+    try (ConfigNodeClient client =
+        configNodeClientManager.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) 
{
+      TGetModelInfoResp getModelInfoResp = client.getModelInfo(new 
TGetModelInfoReq(modelName));
+      if (getModelInfoResp.getStatus().getCode() == 
TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+        if (getModelInfoResp.modelInfo != null && 
getModelInfoResp.isSetAiNodeAddress()) {
+          return new ModelInferenceDescriptor(
+              getModelInfoResp.aiNodeAddress,
+              ModelInformation.deserialize(getModelInfoResp.modelInfo));
+        } else {
+          throw new IoTDBRuntimeException(
+              String.format("model [%s] is not available", modelName),
+              TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode());
+        }
+      } else {
+        throw new 
ModelNotFoundException(getModelInfoResp.getStatus().getMessage());
+      }
+    } catch (ClientManagerException | TException e) {
+      throw new IoTDBRuntimeException(
+          String.format("fetch model [%s] info failed: %s", modelName, 
e.getMessage()),
+          TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode());
+    }
+  }
 }
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/ForecastTableFunction.java
similarity index 59%
rename from iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java
rename to iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/ForecastTableFunction.java
index a1ed1b73d04..481b5154ad5 100644
--- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/ForecastTableFunction.java
@@ -17,11 +17,19 @@
  * under the License.
  */
 
-package org.apache.iotdb.commons.udf.builtin.relational.tvf;
+package org.apache.iotdb.db.queryengine.plan.relational.function;
 
+import org.apache.iotdb.ainode.rpc.thrift.TForecastResp;
+import org.apache.iotdb.common.rpc.thrift.TEndPoint;
+import org.apache.iotdb.commons.client.IClientManager;
+import org.apache.iotdb.commons.client.ainode.AINodeClient;
+import org.apache.iotdb.commons.client.ainode.AINodeClientManager;
 import org.apache.iotdb.commons.exception.IoTDBRuntimeException;
+import org.apache.iotdb.db.exception.sql.SemanticException;
+import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher;
+import org.apache.iotdb.db.queryengine.plan.analyze.ModelFetcher;
+import 
org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor;
 import org.apache.iotdb.rpc.TSStatusCode;
-import org.apache.iotdb.udf.api.exception.UDFException;
 import org.apache.iotdb.udf.api.relational.TableFunction;
 import org.apache.iotdb.udf.api.relational.access.Record;
 import org.apache.iotdb.udf.api.relational.table.TableFunctionAnalysis;
@@ -36,8 +44,18 @@ import org.apache.iotdb.udf.api.relational.table.specification.ScalarParameterSp
 import 
org.apache.iotdb.udf.api.relational.table.specification.TableParameterSpecification;
 import org.apache.iotdb.udf.api.type.Type;
 
+import org.apache.tsfile.block.column.Column;
 import org.apache.tsfile.block.column.ColumnBuilder;
-
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.read.common.block.TsBlock;
+import org.apache.tsfile.read.common.block.TsBlockBuilder;
+import org.apache.tsfile.read.common.block.column.TsBlockSerde;
+import org.apache.tsfile.utils.PublicBAOS;
+import org.apache.tsfile.utils.ReadWriteIOUtils;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
@@ -50,9 +68,81 @@ import java.util.Optional;
 import java.util.Set;
 
 import static 
org.apache.iotdb.commons.udf.builtin.relational.tvf.WindowTVFUtils.findColumnIndex;
+import static org.apache.iotdb.rpc.TSStatusCode.CAN_NOT_CONNECT_AINODE;
 
 public class ForecastTableFunction implements TableFunction {
 
+  private static class ForecastTableFunctionHandle {
+    TEndPoint targetAINode;
+    String modelId;
+    int maxInputLength;
+    int outputLength;
+    boolean keepInput;
+    Map<String, String> options;
+    List<Type> types;
+
+    public ForecastTableFunctionHandle() {}
+
+    public ForecastTableFunctionHandle(
+        boolean keepInput,
+        int maxInputLength,
+        String modelId,
+        Map<String, String> options,
+        int outputLength,
+        TEndPoint targetAINode,
+        List<Type> types) {
+      this.keepInput = keepInput;
+      this.maxInputLength = maxInputLength;
+      this.modelId = modelId;
+      this.options = options;
+      this.outputLength = outputLength;
+      this.targetAINode = targetAINode;
+      this.types = types;
+    }
+
+    public byte[] serialize() {
+      try (PublicBAOS publicBAOS = new PublicBAOS();
+          DataOutputStream outputStream = new DataOutputStream(publicBAOS)) {
+        ReadWriteIOUtils.write(targetAINode.getIp(), outputStream);
+        ReadWriteIOUtils.write(targetAINode.getPort(), outputStream);
+        ReadWriteIOUtils.write(modelId, outputStream);
+        ReadWriteIOUtils.write(maxInputLength, outputStream);
+        ReadWriteIOUtils.write(outputLength, outputStream);
+        ReadWriteIOUtils.write(keepInput, outputStream);
+        ReadWriteIOUtils.write(options, outputStream);
+        ReadWriteIOUtils.write(types.size(), outputStream);
+        for (Type type : types) {
+          ReadWriteIOUtils.write(type.getType(), outputStream);
+        }
+        outputStream.flush();
+        return publicBAOS.toByteArray();
+      } catch (IOException e) {
+        throw new IoTDBRuntimeException(
+            String.format(
+                "Error occurred while serializing ForecastTableFunctionHandle: 
%s", e.getMessage()),
+            TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
+      }
+    }
+
+    public void deserialize(byte[] bytes) {
+      ByteBuffer buffer = ByteBuffer.wrap(bytes);
+      this.targetAINode =
+          new TEndPoint(ReadWriteIOUtils.readString(buffer), 
ReadWriteIOUtils.readInt(buffer));
+      this.modelId = ReadWriteIOUtils.readString(buffer);
+      this.maxInputLength = ReadWriteIOUtils.readInt(buffer);
+      this.outputLength = ReadWriteIOUtils.readInt(buffer);
+      this.keepInput = ReadWriteIOUtils.readBoolean(buffer);
+      this.options = ReadWriteIOUtils.readMap(buffer);
+      int size = ReadWriteIOUtils.readInt(buffer);
+      this.types = new ArrayList<>(size);
+      for (int i = 0; i < size; i++) {
+        types.add(Type.valueOf(ReadWriteIOUtils.readByte(buffer)));
+      }
+    }
+  }
+
+  private static final IModelFetcher MODEL_FETCHER = 
ModelFetcher.getInstance();
+
   private static final String INPUT_PARAMETER_NAME = "INPUT";
   private static final String MODEL_ID_PARAMETER_NAME = "MODEL_ID";
   private static final String INPUT_LENGTH_PARAMETER_NAME = "INPUT_LENGTH";
@@ -68,6 +158,8 @@ public class ForecastTableFunction implements TableFunction {
   private static final String OPTIONS_PARAMETER_NAME = "OPTIONS";
   private static final String DEFAULT_OPTIONS = "";
 
+  private static final String INVALID_OPTIONS_FORMAT = "Invalid options: %s";
+
   private static final Set<Type> ALLOWED_INPUT_TYPES = new HashSet<>();
 
   static {
@@ -117,27 +209,30 @@ public class ForecastTableFunction implements TableFunction {
   }
 
   @Override
-  public TableFunctionAnalysis analyze(Map<String, Argument> arguments) throws 
UDFException {
+  public TableFunctionAnalysis analyze(Map<String, Argument> arguments) {
     TableArgument input = (TableArgument) arguments.get(INPUT_PARAMETER_NAME);
     String modelId = (String) ((ScalarArgument) 
arguments.get(MODEL_ID_PARAMETER_NAME)).getValue();
     // modelId should never be null or empty
     if (modelId == null || modelId.isEmpty()) {
-      throw new UDFException(
+      throw new SemanticException(
           String.format("%s should never be null or empty", 
MODEL_ID_PARAMETER_NAME));
     }
 
     // make sure modelId exists
-    if (!checkModelIdExist(modelId)) {
-      throw new UDFException(
-          String.format(
-              "%s %s doesn't exist, you can use `show models` command to 
choose existing model.",
-              MODEL_ID_PARAMETER_NAME, modelId));
+    ModelInferenceDescriptor descriptor = getModelInfo(modelId);
+    if (descriptor == null || !descriptor.getModelInformation().available()) {
+      throw new IoTDBRuntimeException(
+          String.format("model [%s] is not available", modelId),
+          TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode());
     }
 
+    int maxInputLength = descriptor.getModelInformation().getInputShape()[0];
+    TEndPoint targetAINode = descriptor.getTargetAINode();
+
     int outputLength =
         (int) ((ScalarArgument) 
arguments.get(OUTPUT_LENGTH_PARAMETER_NAME)).getValue();
     if (outputLength <= 0) {
-      throw new UDFException(
+      throw new SemanticException(
           String.format("%s should be greater than 0", 
OUTPUT_LENGTH_PARAMETER_NAME));
     }
 
@@ -157,6 +252,7 @@ public class ForecastTableFunction implements TableFunction {
     DescribedSchema.Builder properColumnSchemaBuilder =
         new DescribedSchema.Builder().addField(timeColumn, Type.TIMESTAMP);
 
+    List<Type> predicatedColumnTypes = new ArrayList<>();
     List<Optional<String>> allInputColumnsName = input.getFieldNames();
     List<Type> allInputColumnsType = input.getFieldTypes();
     if (predicatedColumns.isEmpty()) {
@@ -166,6 +262,7 @@ public class ForecastTableFunction implements TableFunction {
         Optional<String> fieldName = allInputColumnsName.get(i);
         if (!fieldName.isPresent() || 
!excludedColumns.contains(fieldName.get())) {
           Type columnType = allInputColumnsType.get(i);
+          predicatedColumnTypes.add(columnType);
           checkType(columnType, fieldName.orElse(""));
           requiredIndexList.add(i);
           properColumnSchemaBuilder.addField(fieldName, columnType);
@@ -186,18 +283,20 @@ public class ForecastTableFunction implements TableFunction {
       // columns need to be predicated
       for (String outputColumn : predictedColumnsArray) {
         if (excludedColumns.contains(outputColumn)) {
-          throw new UDFException(
+          throw new SemanticException(
               String.format("%s is in partition by clause or is time column", 
outputColumn));
         }
         Integer inputColumnIndex = inputColumnIndexMap.get(outputColumn);
         if (inputColumnIndex == null) {
-          throw new UDFException(String.format("Column %s don't exist in 
input", outputColumn));
+          throw new SemanticException(
+              String.format("Column %s don't exist in input", outputColumn));
         }
         if (!requiredIndexSet.add(inputColumnIndex)) {
-          throw new UDFException(String.format("Duplicate column %s", 
outputColumn));
+          throw new SemanticException(String.format("Duplicate column %s", 
outputColumn));
         }
 
         Type columnType = allInputColumnsType.get(inputColumnIndex);
+        predicatedColumnTypes.add(columnType);
         checkType(columnType, outputColumn);
         requiredIndexList.add(inputColumnIndex);
         properColumnSchemaBuilder.addField(outputColumn, columnType);
@@ -210,6 +309,19 @@ public class ForecastTableFunction implements TableFunction {
       properColumnSchemaBuilder.addField(IS_INPUT_COLUMN_NAME, Type.BOOLEAN);
     }
 
+    String options = (String) ((ScalarArgument) 
arguments.get(OPTIONS_PARAMETER_NAME)).getValue();
+
+    // TODO put functionHandle into TableFunctionAnalysis after after yanze's 
pr being merged
+    ForecastTableFunctionHandle functionHandle =
+        new ForecastTableFunctionHandle(
+            keepInput,
+            maxInputLength,
+            modelId,
+            parseOptions(options),
+            outputLength,
+            targetAINode,
+            predicatedColumnTypes);
+
     // outputColumnSchema
     return TableFunctionAnalysis.builder()
         .properColumnSchema(properColumnSchemaBuilder.build())
@@ -217,14 +329,14 @@ public class ForecastTableFunction implements TableFunction {
         .build();
   }
 
-  private boolean checkModelIdExist(String modelId) {
-    return true;
+  private ModelInferenceDescriptor getModelInfo(String modelId) {
+    return MODEL_FETCHER.fetchModel(modelId);
   }
 
   // only allow for INT32, INT64, FLOAT, DOUBLE
   private void checkType(Type type, String columnName) {
     if (!ALLOWED_INPUT_TYPES.contains(type)) {
-      throw new UDFException(
+      throw new SemanticException(
           String.format(
               "The type of the column [%s] is [%s], only INT32, INT64, FLOAT, 
DOUBLE is allowed",
               columnName, type));
@@ -233,50 +345,67 @@ public class ForecastTableFunction implements TableFunction {
 
   @Override
   public TableFunctionProcessorProvider getProcessorProvider(Map<String, 
Argument> arguments) {
-    int inputLength =
-        (int) ((ScalarArgument) 
arguments.get(INPUT_LENGTH_PARAMETER_NAME)).getValue();
-    int outputLength =
-        (int) ((ScalarArgument) 
arguments.get(OUTPUT_LENGTH_PARAMETER_NAME)).getValue();
-    boolean keepInput =
-        (boolean) ((ScalarArgument) 
arguments.get(KEEP_INPUT_PARAMETER_NAME)).getValue();
-    String options = (String) ((ScalarArgument) 
arguments.get(OPTIONS_PARAMETER_NAME)).getValue();
+    // TODO use functionHandle in parameter after yanze's pr being merged
+    ForecastTableFunctionHandle functionHandle = new 
ForecastTableFunctionHandle();
     return new TableFunctionProcessorProvider() {
       @Override
       public TableFunctionDataProcessor getDataProcessor() {
-        return new ForecastDataProcessor(
-            inputLength, outputLength, keepInput, parseOptions(options), 
Collections.emptyList());
+        return new ForecastDataProcessor(functionHandle);
       }
     };
   }
 
   private static Map<String, String> parseOptions(String options) {
-    return Collections.emptyMap();
+    // DEFAULT_OPTIONS is "", and "".split(",") yields [""], so blank input must return early
+    if (options == null || options.trim().isEmpty()) {
+      return Collections.emptyMap();
+    }
+    String[] optionArray = options.split(",", -1); // -1 keeps empty entries so "k=v," fails
+    Map<String, String> optionsMap = new HashMap<>(optionArray.length);
+    for (String option : optionArray) {
+      int index = option.indexOf('=');
+      String key = index < 0 ? "" : option.substring(0, index).trim();
+      String value = index < 0 ? "" : option.substring(index + 1).trim();
+      if (key.isEmpty() || value.isEmpty()) {
+        throw new SemanticException(String.format(INVALID_OPTIONS_FORMAT, option));
+      }
+      optionsMap.put(key, value);
+    }
+    return optionsMap;
   }
 
   private static class ForecastDataProcessor implements 
TableFunctionDataProcessor {
 
+    private static final TsBlockSerde SERDE = new TsBlockSerde();
+    private static final IClientManager<TEndPoint, AINodeClient> 
CLIENT_MANAGER =
+        AINodeClientManager.getInstance();
+
+    private final TEndPoint targetAINode;
+    private final String modelId;
     private final int maxInputLength;
     private final int outputLength;
     private final boolean keepInput;
     private final Map<String, String> options;
     private final List<Record> inputRecords;
     private final List<ResultColumnAppender> resultColumnAppenderList;
-
-    public ForecastDataProcessor(
-        int maxInputLength,
-        int outputLength,
-        boolean keepInput,
-        Map<String, String> options,
-        List<Type> types) {
-      this.maxInputLength = maxInputLength;
-      this.outputLength = outputLength;
-      this.keepInput = keepInput;
-      this.options = options;
+    private final TsBlockBuilder inputTsBlockBuilder;
+
+    public ForecastDataProcessor(ForecastTableFunctionHandle functionHandle) {
+      this.targetAINode = functionHandle.targetAINode;
+      this.modelId = functionHandle.modelId;
+      this.maxInputLength = functionHandle.maxInputLength;
+      this.outputLength = functionHandle.outputLength;
+      this.keepInput = functionHandle.keepInput;
+      this.options = functionHandle.options;
       this.inputRecords = new LinkedList<>();
-      this.resultColumnAppenderList = new ArrayList<>(types.size());
-      for (Type type : types) {
+      this.resultColumnAppenderList = new 
ArrayList<>(functionHandle.types.size());
+      List<TSDataType> tsDataTypeList = new 
ArrayList<>(functionHandle.types.size());
+      for (Type type : functionHandle.types) {
         resultColumnAppenderList.add(createResultColumnAppender(type));
+        // ainode currently only accept double input
+        tsDataTypeList.add(TSDataType.DOUBLE);
       }
+      this.inputTsBlockBuilder = new TsBlockBuilder(tsDataTypeList);
     }
 
     private static ResultColumnAppender createResultColumnAppender(Type type) {
@@ -341,6 +470,29 @@ public class ForecastTableFunction implements TableFunction {
       }
 
       // predicated columns
+      TsBlock predicatedResult = forecast();
+      if (predicatedResult.getPositionCount() != outputLength) {
+        throw new IoTDBRuntimeException(
+            String.format(
+                "Model %s output length is %s, doesn't equal to specified %s",
+                modelId, predicatedResult.getPositionCount(), outputLength),
+            TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
+      }
+      for (int columnIndex = 1, size = predicatedResult.getValueColumnCount();
+          columnIndex <= size;
+          columnIndex++) {
+        Column column = predicatedResult.getColumn(columnIndex - 1);
+        ColumnBuilder builder = properColumnBuilders.get(columnIndex);
+        ResultColumnAppender appender = 
resultColumnAppenderList.get(columnIndex - 1);
+        for (int row = 0; row < outputLength; row++) {
+          if (column.isNull(row)) {
+            builder.appendNull();
+          } else {
+            // convert double to real type
+            appender.writeDouble(column.getDouble(row), builder);
+          }
+        }
+      }
 
       // is_input column if keep_input is true
       if (keepInput) {
@@ -349,6 +501,42 @@ public class ForecastTableFunction implements TableFunction {
         }
       }
     }
+
+    private TsBlock forecast() {
+      while (!inputRecords.isEmpty()) {
+        Record row = inputRecords.remove(0); // List#removeFirst() only exists since Java 21
+        inputTsBlockBuilder.getTimeColumnBuilder().writeLong(row.getLong(0));
+        for (int i = 1, size = row.size(); i < size; i++) {
+          // we set null input to 0.0
+          if (row.isNull(i)) {
+            inputTsBlockBuilder.getColumnBuilder(i - 1).writeDouble(0.0);
+          } else {
+            // need to transform other types to DOUBLE
+            inputTsBlockBuilder
+                .getColumnBuilder(i - 1)
+                .writeDouble(resultColumnAppenderList.get(i - 1).getDouble(row, i));
+          }
+        }
+        inputTsBlockBuilder.declarePosition();
+      }
+      TsBlock inputData = inputTsBlockBuilder.build();
+      // ask the model's AINode for outputLength predicted rows of the buffered input
+      TForecastResp resp;
+      try (AINodeClient client = CLIENT_MANAGER.borrowClient(targetAINode)) {
+        resp = client.forecast(modelId, inputData, outputLength, options);
+      } catch (Exception e) {
+        throw new IoTDBRuntimeException(e.getMessage(), CAN_NOT_CONNECT_AINODE.getStatusCode());
+      }
+
+      if (resp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+        String message =
+            String.format(
+                "Error occurred while executing forecast:[%s]", resp.getStatus().getMessage());
+        throw new IoTDBRuntimeException(message, resp.getStatus().getCode());
+      }
+
+      return SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
+    }
   }
 
   private interface ResultColumnAppender {
@@ -356,7 +544,7 @@ public class ForecastTableFunction implements TableFunction {
 
     double getDouble(Record row, int columnIndex);
 
-    void writeDouble(double value, ColumnBuilder properColumnBuilder);
+    void writeDouble(double value, ColumnBuilder columnBuilder);
   }
 
   private static class Int32Appender implements ResultColumnAppender {
@@ -376,8 +564,8 @@ public class ForecastTableFunction implements TableFunction {
     }
 
     @Override
-    public void writeDouble(double value, ColumnBuilder properColumnBuilder) {
-      properColumnBuilder.writeInt((int) value);
+    public void writeDouble(double value, ColumnBuilder columnBuilder) {
+      columnBuilder.writeInt((int) value);
     }
   }
 
@@ -398,8 +586,8 @@ public class ForecastTableFunction implements TableFunction {
     }
 
     @Override
-    public void writeDouble(double value, ColumnBuilder properColumnBuilder) {
-      properColumnBuilder.writeLong((long) value);
+    public void writeDouble(double value, ColumnBuilder columnBuilder) {
+      columnBuilder.writeLong((long) value);
     }
   }
 
@@ -420,8 +608,8 @@ public class ForecastTableFunction implements TableFunction {
     }
 
     @Override
-    public void writeDouble(double value, ColumnBuilder properColumnBuilder) {
-      properColumnBuilder.writeFloat((float) value);
+    public void writeDouble(double value, ColumnBuilder columnBuilder) {
+      columnBuilder.writeFloat((float) value);
     }
   }
 
@@ -442,8 +630,8 @@ public class ForecastTableFunction implements TableFunction {
     }
 
     @Override
-    public void writeDouble(double value, ColumnBuilder properColumnBuilder) {
-      properColumnBuilder.writeDouble(value);
+    public void writeDouble(double value, ColumnBuilder columnBuilder) {
+      columnBuilder.writeDouble(value);
     }
   }
 }
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
index 1eca6e7f16a..3cc416f0aad 100644
--- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
@@ -22,6 +22,8 @@ package org.apache.iotdb.commons.client.ainode;
 import org.apache.iotdb.ainode.rpc.thrift.IAINodeRPCService;
 import org.apache.iotdb.ainode.rpc.thrift.TConfigs;
 import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq;
+import org.apache.iotdb.ainode.rpc.thrift.TForecastReq;
+import org.apache.iotdb.ainode.rpc.thrift.TForecastResp;
 import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq;
 import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp;
 import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq;
@@ -53,10 +55,14 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
+import java.nio.ByteBuffer;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 
+import static org.apache.iotdb.rpc.TSStatusCode.CAN_NOT_CONNECT_AINODE;
+import static org.apache.iotdb.rpc.TSStatusCode.INTERNAL_SERVER_ERROR;
+
 public class AINodeClient implements AutoCloseable, ThriftClient {
 
   private static final Logger logger = 
LoggerFactory.getLogger(AINodeClient.class);
@@ -188,6 +194,27 @@ public class AINodeClient implements AutoCloseable, ThriftClient {
     }
   }
 
+  public TForecastResp forecast(
+      String modelId, TsBlock inputTsBlock, int outputLength, Map<String, 
String> options) {
+    try {
+      TForecastReq forecastReq =
+          new TForecastReq(modelId, tsBlockSerde.serialize(inputTsBlock), 
outputLength);
+      forecastReq.setOptions(options);
+      return client.forecast(forecastReq);
+    } catch (IOException e) {
+      TSStatus tsStatus = new TSStatus(INTERNAL_SERVER_ERROR.getStatusCode());
+      tsStatus.setMessage(String.format("Failed to serialize input tsblock 
%s", e.getMessage()));
+      return new TForecastResp(tsStatus, ByteBuffer.allocate(0));
+    } catch (TException e) {
+      TSStatus tsStatus = new TSStatus(CAN_NOT_CONNECT_AINODE.getStatusCode());
+      tsStatus.setMessage(
+          String.format(
+              "Failed to connect to AINode from DataNode when executing %s: 
%s",
+              Thread.currentThread().getStackTrace()[1].getMethodName(), 
e.getMessage()));
+      return new TForecastResp(tsStatus, ByteBuffer.allocate(0));
+    }
+  }
+
   public TSStatus createTrainingTask(TTrainingReq req) throws TException {
     try {
       return client.createTrainingTask(req);
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index 95ab108b948..5643da743a8 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -98,7 +98,7 @@ struct TForecastReq {
 
 struct TForecastResp {
   1: required common.TSStatus status
-  2: required binary inferenceResult
+  2: required binary forecastResult
 }
 
 service IAINodeRPCService {

Reply via email to