This is an automated email from the ASF dual-hosted git repository.
jackietien pushed a commit to branch ty/forecast
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/ty/forecast by this push:
new a4cfe9bad23 perfect ForecastTableFunction
a4cfe9bad23 is described below
commit a4cfe9bad23d9b44dd5541bf187ccf1247c44d6f
Author: JackieTien97 <[email protected]>
AuthorDate: Tue Apr 22 20:23:10 2025 +0800
perfect ForecastTableFunction
---
.../java/org/apache/iotdb/rpc/TSStatusCode.java | 1 +
.../iotdb/db/exception/ainode/ModelException.java | 11 +-
.../db/queryengine/plan/analyze/IModelFetcher.java | 4 +
.../db/queryengine/plan/analyze/ModelFetcher.java | 26 ++
.../function}/ForecastTableFunction.java | 286 +++++++++++++++++----
.../iotdb/commons/client/ainode/AINodeClient.java | 27 ++
.../thrift-ainode/src/main/thrift/ainode.thrift | 2 +-
7 files changed, 302 insertions(+), 55 deletions(-)
diff --git
a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
index 0c732282bcf..a3114333bcc 100644
---
a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
+++
b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
@@ -190,6 +190,7 @@ public enum TSStatusCode {
TRANSFER_LEADER_ERROR(1008),
GET_CLUSTER_ID_ERROR(1009),
CAN_NOT_CONNECT_CONFIGNODE(1010),
+ CAN_NOT_CONNECT_AINODE(1011),
// Sync, Load TsFile
LOAD_FILE_ERROR(1100),
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelException.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelException.java
index 4a007e7048c..cce01c8ad28 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelException.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelException.java
@@ -19,17 +19,18 @@
package org.apache.iotdb.db.exception.ainode;
+import org.apache.iotdb.commons.exception.IoTDBRuntimeException;
import org.apache.iotdb.rpc.TSStatusCode;
-public class ModelException extends RuntimeException {
- TSStatusCode statusCode;
+import static org.apache.iotdb.rpc.TSStatusCode.representOf;
+
+/**
+ * Runtime exception for model-related failures. The {@link TSStatusCode} passed to the
+ * constructor is not kept as a field; its numeric value is stored as the error code of
+ * {@link IoTDBRuntimeException}.
+ */
+public class ModelException extends IoTDBRuntimeException {
public ModelException(String message, TSStatusCode code) {
- super(message);
- this.statusCode = code;
+ super(message, code.getStatusCode());
}
+ /** Converts the stored numeric error code back into a {@link TSStatusCode} via representOf. */
public TSStatusCode getStatusCode() {
- return statusCode;
+ return representOf(getErrorCode());
}
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java
index 1feecaefde9..586e12e589a 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java
@@ -20,8 +20,12 @@
package org.apache.iotdb.db.queryengine.plan.analyze;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
+import
org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor;
public interface IModelFetcher {
/** Get model information by model id from configNode. */
TSStatus fetchModel(String modelId, Analysis analysis);
+
+ /**
+ * Fetches the inference descriptor of the given model, i.e. its model information together
+ * with the AINode address returned for it. Currently only used by the table model.
+ *
+ * @param modelName the model to look up
+ * @return the model's inference descriptor
+ */
+ ModelInferenceDescriptor fetchModel(String modelName);
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java
index 8cefb5e0cf3..36382348b8e 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java
@@ -23,6 +23,7 @@ import org.apache.iotdb.common.rpc.thrift.TSStatus;
import org.apache.iotdb.commons.client.IClientManager;
import org.apache.iotdb.commons.client.exception.ClientManagerException;
import org.apache.iotdb.commons.consensus.ConfigRegionId;
+import org.apache.iotdb.commons.exception.IoTDBRuntimeException;
import org.apache.iotdb.commons.model.ModelInformation;
import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
@@ -78,4 +79,29 @@ public class ModelFetcher implements IModelFetcher {
throw new StatementAnalyzeException(e.getMessage());
}
}
+
+ /**
+ * Fetches the model's metadata from the ConfigNode and wraps it, together with the AINode
+ * address in the response, into a {@link ModelInferenceDescriptor}.
+ *
+ * @param modelName the model to look up
+ * @return the descriptor; never null, because every failure path below throws
+ * @throws ModelNotFoundException if the ConfigNode answers with a non-success status
+ * @throws IoTDBRuntimeException with GET_MODEL_INFO_ERROR if the response lacks the model info
+ *     or the AINode address, or if borrowing the ConfigNode client or the RPC itself fails
+ */
+ @Override
+ public ModelInferenceDescriptor fetchModel(String modelName) {
+ try (ConfigNodeClient client =
+ configNodeClientManager.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID))
{
+ TGetModelInfoResp getModelInfoResp = client.getModelInfo(new
TGetModelInfoReq(modelName));
+ if (getModelInfoResp.getStatus().getCode() ==
TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+ // both the serialized model info and the AINode address are needed to build the descriptor
+ if (getModelInfoResp.modelInfo != null &&
getModelInfoResp.isSetAiNodeAddress()) {
+ return new ModelInferenceDescriptor(
+ getModelInfoResp.aiNodeAddress,
+ ModelInformation.deserialize(getModelInfoResp.modelInfo));
+ } else {
+ throw new IoTDBRuntimeException(
+ String.format("model [%s] is not available", modelName),
+ TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode());
+ }
+ } else {
+ // NOTE(review): every non-success status is reported as ModelNotFoundException,
+ // whatever the actual status code was
+ throw new
ModelNotFoundException(getModelInfoResp.getStatus().getMessage());
+ }
+ } catch (ClientManagerException | TException e) {
+ // NOTE(review): only e.getMessage() is kept; the original exception is not attached as cause
+ throw new IoTDBRuntimeException(
+ String.format("fetch model [%s] info failed: %s", modelName,
e.getMessage()),
+ TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode());
+ }
+ }
}
diff --git
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/ForecastTableFunction.java
similarity index 59%
rename from
iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java
rename to
iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/ForecastTableFunction.java
index a1ed1b73d04..481b5154ad5 100644
---
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/tvf/ForecastTableFunction.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/ForecastTableFunction.java
@@ -17,11 +17,19 @@
* under the License.
*/
-package org.apache.iotdb.commons.udf.builtin.relational.tvf;
+package org.apache.iotdb.db.queryengine.plan.relational.function;
+import org.apache.iotdb.ainode.rpc.thrift.TForecastResp;
+import org.apache.iotdb.common.rpc.thrift.TEndPoint;
+import org.apache.iotdb.commons.client.IClientManager;
+import org.apache.iotdb.commons.client.ainode.AINodeClient;
+import org.apache.iotdb.commons.client.ainode.AINodeClientManager;
import org.apache.iotdb.commons.exception.IoTDBRuntimeException;
+import org.apache.iotdb.db.exception.sql.SemanticException;
+import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher;
+import org.apache.iotdb.db.queryengine.plan.analyze.ModelFetcher;
+import
org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor;
import org.apache.iotdb.rpc.TSStatusCode;
-import org.apache.iotdb.udf.api.exception.UDFException;
import org.apache.iotdb.udf.api.relational.TableFunction;
import org.apache.iotdb.udf.api.relational.access.Record;
import org.apache.iotdb.udf.api.relational.table.TableFunctionAnalysis;
@@ -36,8 +44,18 @@ import
org.apache.iotdb.udf.api.relational.table.specification.ScalarParameterSp
import
org.apache.iotdb.udf.api.relational.table.specification.TableParameterSpecification;
import org.apache.iotdb.udf.api.type.Type;
+import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.block.column.ColumnBuilder;
-
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.read.common.block.TsBlock;
+import org.apache.tsfile.read.common.block.TsBlockBuilder;
+import org.apache.tsfile.read.common.block.column.TsBlockSerde;
+import org.apache.tsfile.utils.PublicBAOS;
+import org.apache.tsfile.utils.ReadWriteIOUtils;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -50,9 +68,81 @@ import java.util.Optional;
import java.util.Set;
import static
org.apache.iotdb.commons.udf.builtin.relational.tvf.WindowTVFUtils.findColumnIndex;
+import static org.apache.iotdb.rpc.TSStatusCode.CAN_NOT_CONNECT_AINODE;
public class ForecastTableFunction implements TableFunction {
+ /**
+ * Per-query settings computed in analyze() and consumed by ForecastDataProcessor: the AINode
+ * holding the model, the model id, the model's maximum input length, the requested output
+ * length, the KEEP_INPUT flag, the parsed OPTIONS and the types of the predicated columns.
+ *
+ * <p>serialize() and deserialize(byte[]) must write and read the fields in the same order.
+ */
+ private static class ForecastTableFunctionHandle {
+ TEndPoint targetAINode;
+ String modelId;
+ // taken from the first dimension of the model's input shape in analyze()
+ int maxInputLength;
+ int outputLength;
+ boolean keepInput;
+ Map<String, String> options;
+ // types of the predicated columns, in the order they are added to the output schema
+ List<Type> types;
+
+ // NOTE(review): leaves every field unset (null/0/false); presumably meant to be populated via
+ // deserialize(byte[]) - confirm before using an instance built this way
+ public ForecastTableFunctionHandle() {}
+
+ public ForecastTableFunctionHandle(
+ boolean keepInput,
+ int maxInputLength,
+ String modelId,
+ Map<String, String> options,
+ int outputLength,
+ TEndPoint targetAINode,
+ List<Type> types) {
+ this.keepInput = keepInput;
+ this.maxInputLength = maxInputLength;
+ this.modelId = modelId;
+ this.options = options;
+ this.outputLength = outputLength;
+ this.targetAINode = targetAINode;
+ this.types = types;
+ }
+
+ /** Serializes all fields to bytes; IOExceptions are rethrown as IoTDBRuntimeException. */
+ public byte[] serialize() {
+ try (PublicBAOS publicBAOS = new PublicBAOS();
+ DataOutputStream outputStream = new DataOutputStream(publicBAOS)) {
+ // field order must stay in sync with deserialize(byte[])
+ ReadWriteIOUtils.write(targetAINode.getIp(), outputStream);
+ ReadWriteIOUtils.write(targetAINode.getPort(), outputStream);
+ ReadWriteIOUtils.write(modelId, outputStream);
+ ReadWriteIOUtils.write(maxInputLength, outputStream);
+ ReadWriteIOUtils.write(outputLength, outputStream);
+ ReadWriteIOUtils.write(keepInput, outputStream);
+ ReadWriteIOUtils.write(options, outputStream);
+ // types are written as a count followed by one type byte each
+ ReadWriteIOUtils.write(types.size(), outputStream);
+ for (Type type : types) {
+ ReadWriteIOUtils.write(type.getType(), outputStream);
+ }
+ outputStream.flush();
+ return publicBAOS.toByteArray();
+ } catch (IOException e) {
+ throw new IoTDBRuntimeException(
+ String.format(
+ "Error occurred while serializing ForecastTableFunctionHandle:
%s", e.getMessage()),
+ TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
+ }
+ }
+
+ /** Restores all fields from bytes produced by serialize(), overwriting the current values. */
+ public void deserialize(byte[] bytes) {
+ ByteBuffer buffer = ByteBuffer.wrap(bytes);
+ // reads the fields back in exactly the order serialize() wrote them
+ this.targetAINode =
+ new TEndPoint(ReadWriteIOUtils.readString(buffer),
ReadWriteIOUtils.readInt(buffer));
+ this.modelId = ReadWriteIOUtils.readString(buffer);
+ this.maxInputLength = ReadWriteIOUtils.readInt(buffer);
+ this.outputLength = ReadWriteIOUtils.readInt(buffer);
+ this.keepInput = ReadWriteIOUtils.readBoolean(buffer);
+ this.options = ReadWriteIOUtils.readMap(buffer);
+ int size = ReadWriteIOUtils.readInt(buffer);
+ this.types = new ArrayList<>(size);
+ for (int i = 0; i < size; i++) {
+ types.add(Type.valueOf(ReadWriteIOUtils.readByte(buffer)));
+ }
+ }
+ }
+
+ // ModelFetcher singleton; analyze() resolves models through it (see getModelInfo)
+ private static final IModelFetcher MODEL_FETCHER =
ModelFetcher.getInstance();
+
private static final String INPUT_PARAMETER_NAME = "INPUT";
private static final String MODEL_ID_PARAMETER_NAME = "MODEL_ID";
private static final String INPUT_LENGTH_PARAMETER_NAME = "INPUT_LENGTH";
@@ -68,6 +158,8 @@ public class ForecastTableFunction implements TableFunction {
private static final String OPTIONS_PARAMETER_NAME = "OPTIONS";
private static final String DEFAULT_OPTIONS = "";
+ // message template for SemanticExceptions raised by parseOptions on malformed OPTIONS
+ private static final String INVALID_OPTIONS_FORMAT = "Invalid options: %s";
+
private static final Set<Type> ALLOWED_INPUT_TYPES = new HashSet<>();
static {
@@ -117,27 +209,30 @@ public class ForecastTableFunction implements
TableFunction {
}
@Override
- public TableFunctionAnalysis analyze(Map<String, Argument> arguments) throws
UDFException {
+ public TableFunctionAnalysis analyze(Map<String, Argument> arguments) {
TableArgument input = (TableArgument) arguments.get(INPUT_PARAMETER_NAME);
String modelId = (String) ((ScalarArgument)
arguments.get(MODEL_ID_PARAMETER_NAME)).getValue();
// modelId should never be null or empty
if (modelId == null || modelId.isEmpty()) {
- throw new UDFException(
+ throw new SemanticException(
String.format("%s should never be null or empty",
MODEL_ID_PARAMETER_NAME));
}
// make sure modelId exists
- if (!checkModelIdExist(modelId)) {
- throw new UDFException(
- String.format(
- "%s %s doesn't exist, you can use `show models` command to
choose existing model.",
- MODEL_ID_PARAMETER_NAME, modelId));
+ ModelInferenceDescriptor descriptor = getModelInfo(modelId);
+ if (descriptor == null || !descriptor.getModelInformation().available()) {
+ throw new IoTDBRuntimeException(
+ String.format("model [%s] is not available", modelId),
+ TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode());
}
+ int maxInputLength = descriptor.getModelInformation().getInputShape()[0];
+ TEndPoint targetAINode = descriptor.getTargetAINode();
+
int outputLength =
(int) ((ScalarArgument)
arguments.get(OUTPUT_LENGTH_PARAMETER_NAME)).getValue();
if (outputLength <= 0) {
- throw new UDFException(
+ throw new SemanticException(
String.format("%s should be greater than 0",
OUTPUT_LENGTH_PARAMETER_NAME));
}
@@ -157,6 +252,7 @@ public class ForecastTableFunction implements TableFunction
{
DescribedSchema.Builder properColumnSchemaBuilder =
new DescribedSchema.Builder().addField(timeColumn, Type.TIMESTAMP);
+ List<Type> predicatedColumnTypes = new ArrayList<>();
List<Optional<String>> allInputColumnsName = input.getFieldNames();
List<Type> allInputColumnsType = input.getFieldTypes();
if (predicatedColumns.isEmpty()) {
@@ -166,6 +262,7 @@ public class ForecastTableFunction implements TableFunction
{
Optional<String> fieldName = allInputColumnsName.get(i);
if (!fieldName.isPresent() ||
!excludedColumns.contains(fieldName.get())) {
Type columnType = allInputColumnsType.get(i);
+ predicatedColumnTypes.add(columnType);
checkType(columnType, fieldName.orElse(""));
requiredIndexList.add(i);
properColumnSchemaBuilder.addField(fieldName, columnType);
@@ -186,18 +283,20 @@ public class ForecastTableFunction implements
TableFunction {
// columns need to be predicated
for (String outputColumn : predictedColumnsArray) {
if (excludedColumns.contains(outputColumn)) {
- throw new UDFException(
+ throw new SemanticException(
String.format("%s is in partition by clause or is time column",
outputColumn));
}
Integer inputColumnIndex = inputColumnIndexMap.get(outputColumn);
if (inputColumnIndex == null) {
- throw new UDFException(String.format("Column %s don't exist in
input", outputColumn));
+ throw new SemanticException(
+ String.format("Column %s don't exist in input", outputColumn));
}
if (!requiredIndexSet.add(inputColumnIndex)) {
- throw new UDFException(String.format("Duplicate column %s",
outputColumn));
+ throw new SemanticException(String.format("Duplicate column %s",
outputColumn));
}
Type columnType = allInputColumnsType.get(inputColumnIndex);
+ predicatedColumnTypes.add(columnType);
checkType(columnType, outputColumn);
requiredIndexList.add(inputColumnIndex);
properColumnSchemaBuilder.addField(outputColumn, columnType);
@@ -210,6 +309,19 @@ public class ForecastTableFunction implements
TableFunction {
properColumnSchemaBuilder.addField(IS_INPUT_COLUMN_NAME, Type.BOOLEAN);
}
+ String options = (String) ((ScalarArgument)
arguments.get(OPTIONS_PARAMETER_NAME)).getValue();
+
+ // TODO put functionHandle into TableFunctionAnalysis after yanze's PR is merged
+ ForecastTableFunctionHandle functionHandle =
+ new ForecastTableFunctionHandle(
+ keepInput,
+ maxInputLength,
+ modelId,
+ parseOptions(options),
+ outputLength,
+ targetAINode,
+ predicatedColumnTypes);
+
// outputColumnSchema
return TableFunctionAnalysis.builder()
.properColumnSchema(properColumnSchemaBuilder.build())
@@ -217,14 +329,14 @@ public class ForecastTableFunction implements
TableFunction {
.build();
}
- private boolean checkModelIdExist(String modelId) {
- return true;
+ /** Looks up the inference descriptor of {@code modelId} through MODEL_FETCHER. */
+ private ModelInferenceDescriptor getModelInfo(String modelId) {
+ return MODEL_FETCHER.fetchModel(modelId);
}
// only allow for INT32, INT64, FLOAT, DOUBLE
private void checkType(Type type, String columnName) {
if (!ALLOWED_INPUT_TYPES.contains(type)) {
- throw new UDFException(
+ throw new SemanticException(
String.format(
"The type of the column [%s] is [%s], only INT32, INT64, FLOAT,
DOUBLE is allowed",
columnName, type));
@@ -233,50 +345,67 @@ public class ForecastTableFunction implements
TableFunction {
+ /**
+ * Returns a provider that builds a new ForecastDataProcessor on every getDataProcessor() call.
+ *
+ * <p>NOTE(review): the handle below comes from the no-arg constructor, so all of its fields are
+ * unset. ForecastDataProcessor's constructor calls functionHandle.types.size(), so this path
+ * throws a NullPointerException until the handle built in analyze() is passed through.
+ */
@Override
public TableFunctionProcessorProvider getProcessorProvider(Map<String,
Argument> arguments) {
- int inputLength =
- (int) ((ScalarArgument)
arguments.get(INPUT_LENGTH_PARAMETER_NAME)).getValue();
- int outputLength =
- (int) ((ScalarArgument)
arguments.get(OUTPUT_LENGTH_PARAMETER_NAME)).getValue();
- boolean keepInput =
- (boolean) ((ScalarArgument)
arguments.get(KEEP_INPUT_PARAMETER_NAME)).getValue();
- String options = (String) ((ScalarArgument)
arguments.get(OPTIONS_PARAMETER_NAME)).getValue();
+ // TODO receive the functionHandle built in analyze() as a parameter after yanze's PR is merged
+ ForecastTableFunctionHandle functionHandle = new
ForecastTableFunctionHandle();
return new TableFunctionProcessorProvider() {
@Override
public TableFunctionDataProcessor getDataProcessor() {
- return new ForecastDataProcessor(
- inputLength, outputLength, keepInput, parseOptions(options),
Collections.emptyList());
+ return new ForecastDataProcessor(functionHandle);
}
};
}
private static Map<String, String> parseOptions(String options) {
- return Collections.emptyMap();
+ String[] optionArray = options.split(",");
+ if (optionArray.length == 0) {
+ throw new SemanticException(String.format(INVALID_OPTIONS_FORMAT,
options));
+ }
+
+ Map<String, String> optionsMap = new HashMap<>(optionArray.length);
+ for (String option : optionArray) {
+ int index = option.indexOf('=');
+ if (index == -1 || index == option.length() - 1) {
+ throw new SemanticException(String.format(INVALID_OPTIONS_FORMAT,
option));
+ }
+ String key = option.substring(0, index).trim();
+ String value = option.substring(index + 1).trim();
+ optionsMap.put(key, value);
+ }
+ return optionsMap;
}
private static class ForecastDataProcessor implements
TableFunctionDataProcessor {
+ private static final TsBlockSerde SERDE = new TsBlockSerde();
+ private static final IClientManager<TEndPoint, AINodeClient>
CLIENT_MANAGER =
+ AINodeClientManager.getInstance();
+
+ private final TEndPoint targetAINode;
+ private final String modelId;
private final int maxInputLength;
private final int outputLength;
private final boolean keepInput;
private final Map<String, String> options;
private final List<Record> inputRecords;
private final List<ResultColumnAppender> resultColumnAppenderList;
-
- public ForecastDataProcessor(
- int maxInputLength,
- int outputLength,
- boolean keepInput,
- Map<String, String> options,
- List<Type> types) {
- this.maxInputLength = maxInputLength;
- this.outputLength = outputLength;
- this.keepInput = keepInput;
- this.options = options;
+ // collects the buffered input rows (every value column as DOUBLE) before they are sent out
+ private final TsBlockBuilder inputTsBlockBuilder;
+
+ /**
+ * Copies the per-query settings out of the function handle. Also creates one result appender
+ * per predicated column type and a TsBlockBuilder with one DOUBLE column per predicated column.
+ *
+ * <p>NOTE(review): dereferences functionHandle.types, so a handle created with the no-arg
+ * constructor (as getProcessorProvider currently does) throws a NullPointerException here.
+ */
+ public ForecastDataProcessor(ForecastTableFunctionHandle functionHandle) {
+ this.targetAINode = functionHandle.targetAINode;
+ this.modelId = functionHandle.modelId;
+ this.maxInputLength = functionHandle.maxInputLength;
+ this.outputLength = functionHandle.outputLength;
+ this.keepInput = functionHandle.keepInput;
+ this.options = functionHandle.options;
this.inputRecords = new LinkedList<>();
- this.resultColumnAppenderList = new ArrayList<>(types.size());
- for (Type type : types) {
+ this.resultColumnAppenderList = new
ArrayList<>(functionHandle.types.size());
+ List<TSDataType> tsDataTypeList = new
ArrayList<>(functionHandle.types.size());
+ for (Type type : functionHandle.types) {
resultColumnAppenderList.add(createResultColumnAppender(type));
+ // AINode currently only accepts DOUBLE input, so every column is sent as DOUBLE
+ tsDataTypeList.add(TSDataType.DOUBLE);
}
+ this.inputTsBlockBuilder = new TsBlockBuilder(tsDataTypeList);
}
private static ResultColumnAppender createResultColumnAppender(Type type) {
@@ -341,6 +470,29 @@ public class ForecastTableFunction implements
TableFunction {
}
// predicated columns
+ TsBlock predicatedResult = forecast();
+ if (predicatedResult.getPositionCount() != outputLength) {
+ throw new IoTDBRuntimeException(
+ String.format(
+ "Model %s output length is %s, doesn't equal to specified %s",
+ modelId, predicatedResult.getPositionCount(), outputLength),
+ TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
+ }
+ for (int columnIndex = 1, size = predicatedResult.getValueColumnCount();
+ columnIndex <= size;
+ columnIndex++) {
+ Column column = predicatedResult.getColumn(columnIndex - 1);
+ ColumnBuilder builder = properColumnBuilders.get(columnIndex);
+ ResultColumnAppender appender =
resultColumnAppenderList.get(columnIndex - 1);
+ for (int row = 0; row < outputLength; row++) {
+ if (column.isNull(row)) {
+ builder.appendNull();
+ } else {
+ // convert double to real type
+ appender.writeDouble(column.getDouble(row), builder);
+ }
+ }
+ }
// is_input column if keep_input is true
if (keepInput) {
@@ -349,6 +501,42 @@ public class ForecastTableFunction implements
TableFunction {
}
}
}
+
+    /**
+     * Sends all buffered input rows to the target AINode and returns the model's forecast.
+     *
+     * <p>Each buffered record becomes one row of {@code inputTsBlockBuilder}. Column 0 of the
+     * record is written to the time column. The remaining columns are converted to DOUBLE, with
+     * null values sent as 0.0. The record buffer is emptied afterwards.
+     *
+     * @return the forecast deserialized from the AINode response
+     * @throws IoTDBRuntimeException with CAN_NOT_CONNECT_AINODE if borrowing (or closing) the
+     *     AINode client throws, or with the AINode's own status code if it reports an error
+     */
+    private TsBlock forecast() {
+      // inputRecords is declared as List<Record>, and List#removeFirst() only exists since Java 21
+      // (SequencedCollection). Iterate in insertion order instead and clear the buffer afterwards.
+      for (Record row : inputRecords) {
+        inputTsBlockBuilder.getTimeColumnBuilder().writeLong(row.getLong(0));
+        for (int i = 1, size = row.size(); i < size; i++) {
+          // we set null input to 0.0
+          if (row.isNull(i)) {
+            inputTsBlockBuilder.getColumnBuilder(i - 1).writeDouble(0.0);
+          } else {
+            // need to transform other types to DOUBLE
+            inputTsBlockBuilder
+                .getColumnBuilder(i - 1)
+                .writeDouble(resultColumnAppenderList.get(i - 1).getDouble(row, i));
+          }
+        }
+        inputTsBlockBuilder.declarePosition();
+      }
+      inputRecords.clear();
+      TsBlock inputData = inputTsBlockBuilder.build();
+
+      TForecastResp resp;
+      try (AINodeClient client = CLIENT_MANAGER.borrowClient(targetAINode)) {
+        resp = client.forecast(modelId, inputData, outputLength, options);
+      } catch (Exception e) {
+        throw new IoTDBRuntimeException(e.getMessage(), CAN_NOT_CONNECT_AINODE.getStatusCode());
+      }
+
+      if (resp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+        String message =
+            String.format(
+                "Error occurred while executing forecast:[%s]", resp.getStatus().getMessage());
+        throw new IoTDBRuntimeException(message, resp.getStatus().getCode());
+      }
+
+      return SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
+    }
}
private interface ResultColumnAppender {
@@ -356,7 +544,7 @@ public class ForecastTableFunction implements TableFunction
{
double getDouble(Record row, int columnIndex);
- void writeDouble(double value, ColumnBuilder properColumnBuilder);
+ void writeDouble(double value, ColumnBuilder columnBuilder);
}
private static class Int32Appender implements ResultColumnAppender {
@@ -376,8 +564,8 @@ public class ForecastTableFunction implements TableFunction
{
}
@Override
- public void writeDouble(double value, ColumnBuilder properColumnBuilder) {
- properColumnBuilder.writeInt((int) value);
+ public void writeDouble(double value, ColumnBuilder columnBuilder) {
+ // narrowing cast rounds toward zero (NaN becomes 0, out-of-range values saturate)
+ columnBuilder.writeInt((int) value);
}
}
@@ -398,8 +586,8 @@ public class ForecastTableFunction implements TableFunction
{
}
@Override
- public void writeDouble(double value, ColumnBuilder properColumnBuilder) {
- properColumnBuilder.writeLong((long) value);
+ public void writeDouble(double value, ColumnBuilder columnBuilder) {
+ // narrowing cast rounds toward zero (NaN becomes 0, out-of-range values saturate)
+ columnBuilder.writeLong((long) value);
}
}
@@ -420,8 +608,8 @@ public class ForecastTableFunction implements TableFunction
{
}
@Override
- public void writeDouble(double value, ColumnBuilder properColumnBuilder) {
- properColumnBuilder.writeFloat((float) value);
+ public void writeDouble(double value, ColumnBuilder columnBuilder) {
+ // double-to-float conversion rounds to the nearest representable float
+ columnBuilder.writeFloat((float) value);
}
}
@@ -442,8 +630,8 @@ public class ForecastTableFunction implements TableFunction
{
}
@Override
- public void writeDouble(double value, ColumnBuilder properColumnBuilder) {
- properColumnBuilder.writeDouble(value);
+ public void writeDouble(double value, ColumnBuilder columnBuilder) {
+ // the value is already a double; written without conversion
+ columnBuilder.writeDouble(value);
}
}
}
diff --git
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
index 1eca6e7f16a..3cc416f0aad 100644
---
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
+++
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
@@ -22,6 +22,8 @@ package org.apache.iotdb.commons.client.ainode;
import org.apache.iotdb.ainode.rpc.thrift.IAINodeRPCService;
import org.apache.iotdb.ainode.rpc.thrift.TConfigs;
import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq;
+import org.apache.iotdb.ainode.rpc.thrift.TForecastReq;
+import org.apache.iotdb.ainode.rpc.thrift.TForecastResp;
import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq;
import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp;
import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq;
@@ -53,10 +55,14 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
+import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import static org.apache.iotdb.rpc.TSStatusCode.CAN_NOT_CONNECT_AINODE;
+import static org.apache.iotdb.rpc.TSStatusCode.INTERNAL_SERVER_ERROR;
+
public class AINodeClient implements AutoCloseable, ThriftClient {
private static final Logger logger =
LoggerFactory.getLogger(AINodeClient.class);
@@ -188,6 +194,27 @@ public class AINodeClient implements AutoCloseable,
ThriftClient {
}
}
+ /**
+ * Sends a forecast request for {@code modelId} to the AINode this client is connected to.
+ *
+ * <p>Failures are reported through the returned status instead of being thrown. A
+ * serialization failure yields INTERNAL_SERVER_ERROR and a Thrift failure yields
+ * CAN_NOT_CONNECT_AINODE, each with an empty result buffer.
+ *
+ * @param modelId the model to run
+ * @param inputTsBlock the input data, serialized with tsBlockSerde before sending
+ * @param outputLength the number of points to forecast
+ * @param options model-specific options, attached to the request as-is
+ * @return the AINode response, or a locally built error response as described above
+ */
+ public TForecastResp forecast(
+ String modelId, TsBlock inputTsBlock, int outputLength, Map<String,
String> options) {
+ try {
+ TForecastReq forecastReq =
+ new TForecastReq(modelId, tsBlockSerde.serialize(inputTsBlock),
outputLength);
+ forecastReq.setOptions(options);
+ return client.forecast(forecastReq);
+ } catch (IOException e) {
+ TSStatus tsStatus = new TSStatus(INTERNAL_SERVER_ERROR.getStatusCode());
+ tsStatus.setMessage(String.format("Failed to serialize input tsblock
%s", e.getMessage()));
+ return new TForecastResp(tsStatus, ByteBuffer.allocate(0));
+ } catch (TException e) {
+ TSStatus tsStatus = new TSStatus(CAN_NOT_CONNECT_AINODE.getStatusCode());
+ // NOTE(review): getStackTrace()[1] is expected to be this method ("forecast"), but the JVM
+ // is allowed to omit stack frames; a string literal would be more robust
+ tsStatus.setMessage(
+ String.format(
+ "Failed to connect to AINode from DataNode when executing %s:
%s",
+ Thread.currentThread().getStackTrace()[1].getMethodName(),
e.getMessage()));
+ return new TForecastResp(tsStatus, ByteBuffer.allocate(0));
+ }
+ }
+
public TSStatus createTrainingTask(TTrainingReq req) throws TException {
try {
return client.createTrainingTask(req);
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index 95ab108b948..5643da743a8 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -98,7 +98,7 @@ struct TForecastReq {
struct TForecastResp {
1: required common.TSStatus status
- 2: required binary inferenceResult
+ // serialized TsBlock holding the forecast; the DataNode decodes it with TsBlockSerde
+ 2: required binary forecastResult
}
service IAINodeRPCService {