This is an automated email from the ASF dual-hosted git repository.
hui pushed a commit to branch research/MLEngine
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/research/MLEngine by this push:
new 8a19bc3dbfd [IOTDB-5083] [IoTDB ML] Support FORECAST function on
DataNode (#9906)
8a19bc3dbfd is described below
commit 8a19bc3dbfd1705126ea788d7a0a1a92ca217e99
Author: liuminghui233 <[email protected]>
AuthorDate: Fri Jun 2 10:36:21 2023 +0800
[IOTDB-5083] [IoTDB ML] Support FORECAST function on DataNode (#9906)
---
.../consensus/request/ConfigPhysicalPlan.java | 4 +
.../consensus/request/ConfigPhysicalPlanType.java | 1 +
.../request/read/model/GetModelInfoPlan.java | 79 +++++++
.../consensus/response/model/GetModelInfoResp.java | 29 ++-
.../response/{ => model}/ModelTableResp.java | 2 +-
.../response/{ => model}/TrailTableResp.java | 2 +-
.../iotdb/confignode/manager/ConfigManager.java | 10 +
.../apache/iotdb/confignode/manager/IManager.java | 5 +
.../iotdb/confignode/manager/ModelManager.java | 77 +++++--
.../iotdb/confignode/persistence/ModelInfo.java | 32 ++-
.../persistence/executor/ConfigPlanExecutor.java | 3 +
.../thrift/ConfigNodeRPCServiceProcessor.java | 7 +
.../src/main/thrift/confignode.thrift | 14 ++
.../thrift-mlnode/src/main/thrift/mlnode.thrift | 5 +-
.../commons/model/ForecastModeInformation.java | 139 ++++++++++++
.../iotdb/commons/model/ModelInformation.java | 61 ++++-
.../udf/builtin/ModelInferenceFunction.java | 34 ++-
.../commons/udf/service/UDFManagementService.java | 19 ++
.../apache/iotdb/db/client/ConfigNodeClient.java | 18 ++
.../org/apache/iotdb/db/client/MLNodeClient.java | 28 ++-
.../java/org/apache/iotdb/db/conf/IoTDBConfig.java | 11 +
.../org/apache/iotdb/db/constant/SqlConstant.java | 3 +
.../ModelInferenceProcessException.java} | 12 +-
.../fragment/FragmentInstanceManager.java | 9 +
.../db/mpp/execution/operator/AggregationUtil.java | 2 +-
.../operator/process/ml/ForecastOperator.java | 251 +++++++++++++++++++++
.../apache/iotdb/db/mpp/plan/analyze/Analysis.java | 11 +
.../iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java | 103 ++++++++-
.../mpp/plan/analyze/ClusterPartitionFetcher.java | 25 ++
.../db/mpp/plan/analyze/ExpressionAnalyzer.java | 2 +
.../mpp/plan/analyze/ExpressionTypeAnalyzer.java | 2 +
.../db/mpp/plan/analyze/IPartitionFetcher.java | 3 +
.../config/executor/ClusterConfigTaskExecutor.java | 16 +-
.../plan/expression/multi/FunctionExpression.java | 10 +
.../db/mpp/plan/expression/multi/FunctionType.java | 3 +-
.../db/mpp/plan/planner/LogicalPlanBuilder.java | 10 +
.../db/mpp/plan/planner/LogicalPlanVisitor.java | 17 ++
.../db/mpp/plan/planner/OperatorTreeGenerator.java | 49 ++++
.../plan/planner/plan/node/PlanGraphPrinter.java | 12 +
.../mpp/plan/planner/plan/node/PlanNodeType.java | 6 +-
.../db/mpp/plan/planner/plan/node/PlanVisitor.java | 5 +
.../planner/plan/node/process/ml/ForecastNode.java | 122 ++++++++++
.../model/ForecastModelInferenceDescriptor.java | 176 +++++++++++++++
.../parameter/model/ModelInferenceDescriptor.java | 111 +++++++++
.../mpp/plan/statement/component/ResultColumn.java | 3 +-
.../plan/statement/component/SelectComponent.java | 10 +-
.../db/mpp/plan/statement/crud/QueryStatement.java | 52 ++++-
.../mpp/plan/analyze/FakePartitionFetcherImpl.java | 6 +
.../iotdb/db/mpp/plan/plan/distribution/Util.java | 6 +
.../tsfile/file/metadata/enums/TSDataType.java | 10 +
.../iotdb/tsfile/read/common/block/TsBlock.java | 13 ++
.../tsfile/read/common/block/TsBlockBuilder.java | 4 +-
52 files changed, 1576 insertions(+), 68 deletions(-)
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java
index ff915a93e97..742b6efce0f 100644
---
a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java
@@ -24,6 +24,7 @@ import
org.apache.iotdb.confignode.consensus.request.read.database.CountDatabase
import
org.apache.iotdb.confignode.consensus.request.read.database.GetDatabasePlan;
import
org.apache.iotdb.confignode.consensus.request.read.datanode.GetDataNodeConfigurationPlan;
import
org.apache.iotdb.confignode.consensus.request.read.function.GetFunctionTablePlan;
+import
org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan;
import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan;
import org.apache.iotdb.confignode.consensus.request.read.model.ShowTrailPlan;
import
org.apache.iotdb.confignode.consensus.request.read.partition.CountTimeSlotListPlan;
@@ -460,6 +461,9 @@ public abstract class ConfigPhysicalPlan implements
IConsensusRequest {
case ShowTrail:
plan = new ShowTrailPlan();
break;
+ case GetModelInfo:
+ plan = new GetModelInfoPlan();
+ break;
case CreatePipePlugin:
plan = new CreatePipePluginPlan();
break;
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlanType.java
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlanType.java
index b4ab2ad4953..d32c5c255de 100644
---
a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlanType.java
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlanType.java
@@ -162,6 +162,7 @@ public enum ConfigPhysicalPlanType {
DropModel((short) 1203),
ShowModel((short) 1204),
ShowTrail((short) 1205),
+ GetModelInfo((short) 1206),
/** Pipe Plugin */
CreatePipePlugin((short) 1300),
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java
new file mode 100644
index 00000000000..79c7e9e0142
--- /dev/null
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.confignode.consensus.request.read.model;
+
+import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan;
+import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Objects;
+
+/** Read plan asking the ConfigNode for the information of a single model. */
+public class GetModelInfoPlan extends ConfigPhysicalPlan {
+
+  /** Id of the model whose information is requested. */
+  private String modelId;
+
+  /** Creates an empty plan; {@code modelId} is populated by {@link #deserializeImpl}. */
+  public GetModelInfoPlan() {
+    super(ConfigPhysicalPlanType.GetModelInfo);
+  }
+
+  /**
+   * Builds the plan from a thrift request.
+   *
+   * @param getModelInfoReq request carrying the target model id
+   */
+  public GetModelInfoPlan(TGetModelInfoReq getModelInfoReq) {
+    this();
+    modelId = getModelInfoReq.getModelId();
+  }
+
+  public String getModelId() {
+    return modelId;
+  }
+
+  /** Writes the plan type code followed by the model id. */
+  @Override
+  protected void serializeImpl(DataOutputStream stream) throws IOException {
+    stream.writeShort(getType().getPlanType());
+    ReadWriteIOUtils.write(modelId, stream);
+  }
+
+  /** Reads back only the model id; the type code written above is consumed before this call. */
+  @Override
+  protected void deserializeImpl(ByteBuffer buffer) throws IOException {
+    modelId = ReadWriteIOUtils.readString(buffer);
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (o == this) {
+      return true;
+    }
+    if (o == null || o.getClass() != getClass() || !super.equals(o)) {
+      return false;
+    }
+    GetModelInfoPlan other = (GetModelInfoPlan) o;
+    return Objects.equals(modelId, other.modelId);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(super.hashCode(), modelId);
+  }
+}
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java
similarity index 52%
copy from
server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
copy to
confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java
index 734ebb4bef4..5e7ee2641fa 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java
@@ -17,11 +17,28 @@
* under the License.
*/
-package org.apache.iotdb.db.mpp.plan.expression.multi;
+package org.apache.iotdb.confignode.consensus.response.model;
-/** */
-public enum FunctionType {
- AGGREGATION_FUNCTION,
- BUILT_IN_SCALAR_FUNCTION,
- UDF
+import org.apache.iotdb.common.rpc.thrift.TSStatus;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
+import org.apache.iotdb.consensus.common.DataSet;
+
+import java.nio.ByteBuffer;
+
+public class GetModelInfoResp implements DataSet {
+
+ private final TSStatus status;
+ private ByteBuffer serializedModelInformation;
+
+ public GetModelInfoResp(TSStatus status) {
+ this.status = status;
+ }
+
+ public void setModelInfo(ByteBuffer serializedModelInformation) {
+ this.serializedModelInformation = serializedModelInformation;
+ }
+
+ public TGetModelInfoResp convertToThriftResponse() {
+ return new TGetModelInfoResp(status, serializedModelInformation);
+ }
}
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ModelTableResp.java
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java
similarity index 97%
rename from
confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ModelTableResp.java
rename to
confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java
index 6642f76be97..9a23d9ed713 100644
---
a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ModelTableResp.java
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java
@@ -17,7 +17,7 @@
* under the License.
*/
-package org.apache.iotdb.confignode.consensus.response;
+package org.apache.iotdb.confignode.consensus.response.model;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
import org.apache.iotdb.commons.model.ModelInformation;
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/TrailTableResp.java
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/TrailTableResp.java
similarity index 97%
rename from
confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/TrailTableResp.java
rename to
confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/TrailTableResp.java
index 1f9a6b5acbf..3f2c2823dac 100644
---
a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/TrailTableResp.java
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/TrailTableResp.java
@@ -17,7 +17,7 @@
* under the License.
*/
-package org.apache.iotdb.confignode.consensus.response;
+package org.apache.iotdb.confignode.consensus.response.model;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
import org.apache.iotdb.commons.model.TrailInformation;
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
b/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
index 42baef81f85..88f8a7a7e2c 100644
---
a/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
@@ -132,6 +132,8 @@ import
org.apache.iotdb.confignode.rpc.thrift.TGetDataNodeLocationsResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq;
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetRegionIdReq;
@@ -1853,6 +1855,14 @@ public class ConfigManager implements IManager {
: status;
}
+  @Override
+  public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) {
+    // Delegate to the model manager only when confirmLeader() succeeds; otherwise hand the
+    // leader-check status back with no payload.
+    TSStatus leaderStatus = confirmLeader();
+    if (leaderStatus.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+      return new TGetModelInfoResp(leaderStatus, null);
+    }
+    return modelManager.getModelInfo(req);
+  }
+
@Override
public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) {
TSStatus status = confirmLeader();
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java
b/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java
index 9e8560d0aec..ca5f750564c 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java
@@ -75,6 +75,8 @@ import
org.apache.iotdb.confignode.rpc.thrift.TGetDataNodeLocationsResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq;
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetRegionIdReq;
@@ -661,6 +663,9 @@ public interface IManager {
/** Update the model state */
TSStatus updateModelState(TUpdateModelStateReq req);
+  /**
+   * Get the information of the model with the given id.
+   *
+   * @param req request carrying the target model id
+   * @return the status of the lookup and, when available, the serialized model information
+   */
+  TGetModelInfoResp getModelInfo(TGetModelInfoReq req);
+
/** Set space quota */
TSStatus setSpaceQuota(TSetSpaceQuotaReq req);
}
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
b/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
index 425bafbd4c0..9a7c1080d96 100644
---
a/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
@@ -19,17 +19,23 @@
package org.apache.iotdb.confignode.manager;
+import org.apache.iotdb.common.rpc.thrift.ModelTask;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
+import org.apache.iotdb.commons.model.ForecastModeInformation;
import org.apache.iotdb.commons.model.ModelInformation;
+import
org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan;
import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan;
import org.apache.iotdb.confignode.consensus.request.read.model.ShowTrailPlan;
import
org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan;
import
org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelStatePlan;
-import org.apache.iotdb.confignode.consensus.response.ModelTableResp;
-import org.apache.iotdb.confignode.consensus.response.TrailTableResp;
+import org.apache.iotdb.confignode.consensus.response.model.GetModelInfoResp;
+import org.apache.iotdb.confignode.consensus.response.model.ModelTableResp;
+import org.apache.iotdb.confignode.consensus.response.model.TrailTableResp;
import org.apache.iotdb.confignode.persistence.ModelInfo;
import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq;
import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
import org.apache.iotdb.confignode.rpc.thrift.TShowModelReq;
import org.apache.iotdb.confignode.rpc.thrift.TShowModelResp;
import org.apache.iotdb.confignode.rpc.thrift.TShowTrailReq;
@@ -39,12 +45,17 @@ import
org.apache.iotdb.confignode.rpc.thrift.TUpdateModelStateReq;
import org.apache.iotdb.consensus.common.response.ConsensusReadResponse;
import org.apache.iotdb.consensus.common.response.ConsensusWriteResponse;
import org.apache.iotdb.rpc.TSStatusCode;
+import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
+import java.util.Arrays;
import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
public class ModelManager {
@@ -63,15 +74,41 @@ public class ModelManager {
}
public TSStatus createModel(TCreateModelReq req) {
- ModelInformation modelInformation =
- new ModelInformation(
- req.getModelId(),
- req.getModelTask(),
- req.getModelType(),
- req.isIsAuto(),
- req.getQueryExpressions(),
- req.getQueryFilter());
- return configManager.getProcedureManager().createModel(modelInformation,
req.getModelConfigs());
+ ModelTask modelTask = req.getModelTask();
+ Map<String, String> modelConfigs = req.getModelConfigs();
+ ModelInformation modelInformation;
+ switch (modelTask) {
+ case FORECAST:
+ String inputTypeListStr = modelConfigs.get("input_type_list");
+ List<TSDataType> inputTypeList =
+ Arrays.stream(inputTypeListStr.split(","))
+ .sequential()
+ .map(s -> TSDataType.valueOf(s.toUpperCase()))
+ .collect(Collectors.toList());
+
+ String predictIndexListStr = modelConfigs.get("predict_index_list");
+ List<Integer> predictIndexList =
+ Arrays.stream(predictIndexListStr.split(","))
+ .sequential()
+ .map(Integer::valueOf)
+ .collect(Collectors.toList());
+
+ modelInformation =
+ new ForecastModeInformation(
+ req.getModelId(),
+ req.getModelType(),
+ req.isIsAuto(),
+ req.getQueryExpressions(),
+ req.getQueryFilter(),
+ inputTypeList,
+ predictIndexList,
+ Integer.parseInt(modelConfigs.getOrDefault("input_length",
"96")),
+ Integer.parseInt(modelConfigs.getOrDefault("predict_length",
"96")));
+ break;
+ default:
+ throw new IllegalArgumentException("Invalid task type: " + modelTask);
+ }
+ return configManager.getProcedureManager().createModel(modelInformation,
modelConfigs);
}
public TSStatus dropModel(TDropModelReq req) {
@@ -126,7 +163,7 @@ public class ModelManager {
return new TShowModelResp(res, Collections.emptyList());
}
} catch (IOException e) {
- LOGGER.error("Fail to get ModelTable", e);
+ LOGGER.warn("Fail to get ModelTable", e);
return new TShowModelResp(
new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
.setMessage(e.getMessage()),
@@ -148,11 +185,25 @@ public class ModelManager {
return new TShowTrailResp(res, Collections.emptyList());
}
} catch (IOException e) {
- LOGGER.error("Fail to get TrailTable", e);
+ LOGGER.warn("Fail to get TrailTable", e);
return new TShowTrailResp(
new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
.setMessage(e.getMessage()),
Collections.emptyList());
}
}
+
+ public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) {
+ ConsensusReadResponse response =
+ configManager.getConsensusManager().read(new GetModelInfoPlan(req));
+ if (response.getDataset() != null) {
+ return ((GetModelInfoResp)
response.getDataset()).convertToThriftResponse();
+ } else {
+ LOGGER.warn("Unexpected error happened while getting model: ",
response.getException());
+ // consensus layer related errors
+ TSStatus res = new
TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode());
+ res.setMessage(response.getException().toString());
+ return new TGetModelInfoResp(res, null);
+ }
+ }
}
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
index 3c72e09570f..1e14bb6b137 100644
---
a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
@@ -24,15 +24,18 @@ import org.apache.iotdb.commons.model.ModelInformation;
import org.apache.iotdb.commons.model.ModelTable;
import org.apache.iotdb.commons.model.TrailInformation;
import org.apache.iotdb.commons.snapshot.SnapshotProcessor;
+import
org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan;
import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan;
import org.apache.iotdb.confignode.consensus.request.read.model.ShowTrailPlan;
import
org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan;
import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan;
import
org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan;
import
org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelStatePlan;
-import org.apache.iotdb.confignode.consensus.response.ModelTableResp;
-import org.apache.iotdb.confignode.consensus.response.TrailTableResp;
+import org.apache.iotdb.confignode.consensus.response.model.GetModelInfoResp;
+import org.apache.iotdb.confignode.consensus.response.model.ModelTableResp;
+import org.apache.iotdb.confignode.consensus.response.model.TrailTableResp;
import org.apache.iotdb.rpc.TSStatusCode;
+import org.apache.iotdb.tsfile.utils.PublicBAOS;
import org.apache.thrift.TException;
import org.slf4j.Logger;
@@ -40,10 +43,12 @@ import org.slf4j.LoggerFactory;
import javax.annotation.concurrent.ThreadSafe;
+import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
+import java.nio.ByteBuffer;
import java.util.concurrent.locks.ReentrantLock;
@ThreadSafe
@@ -155,6 +160,29 @@ public class ModelInfo implements SnapshotProcessor {
}
}
+  /**
+   * Looks up a model under the model-table lock and serializes its {@link ModelInformation}.
+   *
+   * @param plan read plan carrying the target model id
+   * @return SUCCESS with the serialized model information when the model exists; otherwise an
+   *     EXECUTE_STATEMENT_ERROR status whose message describes the failure and no payload
+   */
+  public GetModelInfoResp getModelInfo(GetModelInfoPlan plan) {
+    acquireModelTableLock();
+    try {
+      ModelInformation modelInformation = modelTable.getModelInformationById(plan.getModelId());
+      if (modelInformation == null) {
+        // Report an unknown model explicitly rather than as SUCCESS without a payload, which
+        // callers could not deserialize.
+        return new GetModelInfoResp(
+            new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
+                .setMessage(String.format("Model [%s] does not exist", plan.getModelId())));
+      }
+      PublicBAOS buffer = new PublicBAOS();
+      try (DataOutputStream stream = new DataOutputStream(buffer)) {
+        modelInformation.serialize(stream);
+      }
+      GetModelInfoResp getModelInfoResp =
+          new GetModelInfoResp(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
+      getModelInfoResp.setModelInfo(ByteBuffer.wrap(buffer.getBuf(), 0, buffer.size()));
+      return getModelInfoResp;
+    } catch (IOException e) {
+      LOGGER.warn("Fail to get model info", e);
+      return new GetModelInfoResp(
+          new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
+              .setMessage(e.getMessage()));
+    } finally {
+      releaseModelTableLock();
+    }
+  }
+
public TSStatus updateModelInfo(UpdateModelInfoPlan plan) {
acquireModelTableLock();
try {
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java
b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java
index 2cd14b69fc8..1a2485f3671 100644
---
a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java
@@ -29,6 +29,7 @@ import
org.apache.iotdb.confignode.consensus.request.auth.AuthorPlan;
import
org.apache.iotdb.confignode.consensus.request.read.database.CountDatabasePlan;
import
org.apache.iotdb.confignode.consensus.request.read.database.GetDatabasePlan;
import
org.apache.iotdb.confignode.consensus.request.read.datanode.GetDataNodeConfigurationPlan;
+import
org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan;
import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan;
import org.apache.iotdb.confignode.consensus.request.read.model.ShowTrailPlan;
import
org.apache.iotdb.confignode.consensus.request.read.partition.CountTimeSlotListPlan;
@@ -279,6 +280,8 @@ public class ConfigPlanExecutor {
return modelInfo.showModel((ShowModelPlan) req);
case ShowTrail:
return modelInfo.showTrail((ShowTrailPlan) req);
+ case GetModelInfo:
+ return modelInfo.getModelInfo((GetModelInfoPlan) req);
case GetPipePluginTable:
return pipeInfo.getPipePluginInfo().showPipePlugins();
case GetPipePluginJar:
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java
b/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java
index 6e73fee6675..92eb56872dd 100644
---
a/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java
@@ -111,6 +111,8 @@ import
org.apache.iotdb.confignode.rpc.thrift.TGetDataNodeLocationsResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq;
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPipeSinkReq;
@@ -989,6 +991,11 @@ public class ConfigNodeRPCServiceProcessor implements
IConfigNodeRPCService.Ifac
return configManager.updateModelState(req);
}
+  /** Thrift entry point for model-info lookups; all work is done by the config manager. */
+  @Override
+  public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) {
+    final TGetModelInfoResp resp = configManager.getModelInfo(req);
+    return resp;
+  }
+
@Override
public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) throws TException {
return configManager.setSpaceQuota(req);
diff --git a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
index e9d41faea32..8fbe508ad9b 100644
--- a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
+++ b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
@@ -736,6 +736,10 @@ struct TShowTrailReq {
2: optional string trailId
}
+// Request for the information of a single model.
+struct TGetModelInfoReq {
+ // id of the model to look up
+ 1: required string modelId
+}
+
struct TShowTrailResp {
1: required common.TSStatus status
2: required list<binary> trailInfoList
@@ -753,6 +757,11 @@ struct TUpdateModelStateReq {
3: optional string bestTrailId
}
+// Response of getModelInfo.
+struct TGetModelInfoResp {
+  1: required common.TSStatus status
+  // Serialized ModelInformation. It is null on error responses (e.g. when the ConfigNode is not
+  // the leader), so it must not be `required`: the generated validate() would then reject the
+  // response on write. Default requiredness keeps the (status, modelInfo) constructor.
+  2: binary modelInfo
+}
+
// ====================================================
// Quota
// ====================================================
@@ -1378,6 +1387,11 @@ service IConfigNodeRPCService {
*/
common.TSStatus updateModelState(TUpdateModelStateReq req)
+ /**
+ * Return the information of the model identified by req.modelId.
+ * modelInfo carries the serialized ModelInformation and may be unset when status is not SUCCESS.
+ */
+ TGetModelInfoResp getModelInfo(TGetModelInfoReq req)
+
// ======================================================
// Quota
// ======================================================
diff --git a/iotdb-protocol/thrift-mlnode/src/main/thrift/mlnode.thrift
b/iotdb-protocol/thrift-mlnode/src/main/thrift/mlnode.thrift
index abadc795768..46f7b025f47 100644
--- a/iotdb-protocol/thrift-mlnode/src/main/thrift/mlnode.thrift
+++ b/iotdb-protocol/thrift-mlnode/src/main/thrift/mlnode.thrift
@@ -36,7 +36,10 @@ struct TDeleteModelReq {
struct TForecastReq {
1: required string modelPath
- 2: required binary dataset
+ // NOTE(review): field 2 is renamed and fields 3-5 are new `required` fields, so MLNode and
+ // DataNode must be upgraded together.
+ // inputData: presumably the serialized input window (a TsBlock?) - confirm against MLNode.
+ 2: required binary inputData
+ // inputTypeList / inputColumnNameList: assumed to be one entry per input column - confirm.
+ 3: required list<string> inputTypeList
+ 4: required list<string> inputColumnNameList
+ // number of points the model is asked to predict
+ 5: required i32 predictLength
}
struct TForecastResp {
diff --git
a/node-commons/src/main/java/org/apache/iotdb/commons/model/ForecastModeInformation.java
b/node-commons/src/main/java/org/apache/iotdb/commons/model/ForecastModeInformation.java
new file mode 100644
index 00000000000..f32a0d5190f
--- /dev/null
+++
b/node-commons/src/main/java/org/apache/iotdb/commons/model/ForecastModeInformation.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.commons.model;
+
+import org.apache.iotdb.common.rpc.thrift.ModelTask;
+import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
+
+import javax.annotation.Nullable;
+
+import java.io.DataOutputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * {@link ModelInformation} of a FORECAST model. Besides the common model metadata it keeps the
+ * input data types, the predict indexes and the input/predict lengths; these are written after
+ * the base fields when serialized and read back in the same order.
+ */
+public class ForecastModeInformation extends ModelInformation {
+
+  /** Data types of the model inputs. */
+  private final List<TSDataType> inputTypeList;
+
+  /** Indexes to predict (assumed to index into the inputs - confirm with the inference side). */
+  private final List<Integer> predictIndexList;
+
+  /** Configured input length. */
+  private final int inputLength;
+
+  /** Configured predict length. */
+  private final int predictLength;
+
+  public ForecastModeInformation(
+      String modelId,
+      String modelType,
+      boolean isAuto,
+      List<String> queryExpressions,
+      @Nullable String queryFilter,
+      List<TSDataType> inputTypeList,
+      List<Integer> predictIndexList,
+      int inputLength,
+      int predictLength) {
+    super(ModelTask.FORECAST, modelId, modelType, isAuto, queryExpressions, queryFilter);
+    this.predictLength = predictLength;
+    this.inputLength = inputLength;
+    this.predictIndexList = predictIndexList;
+    this.inputTypeList = inputTypeList;
+  }
+
+  /** Deserializes a forecast model whose base fields are followed by the forecast fields. */
+  public ForecastModeInformation(ByteBuffer buffer) {
+    super(ModelTask.FORECAST, buffer);
+    int typeCount = ReadWriteIOUtils.readInt(buffer);
+    List<TSDataType> types = new ArrayList<>(typeCount);
+    while (types.size() < typeCount) {
+      types.add(TSDataType.deserializeFrom(buffer));
+    }
+    inputTypeList = types;
+
+    int indexCount = ReadWriteIOUtils.readInt(buffer);
+    List<Integer> indexes = new ArrayList<>(indexCount);
+    while (indexes.size() < indexCount) {
+      indexes.add(ReadWriteIOUtils.readInt(buffer));
+    }
+    predictIndexList = indexes;
+
+    inputLength = ReadWriteIOUtils.readInt(buffer);
+    predictLength = ReadWriteIOUtils.readInt(buffer);
+  }
+
+  /** Stream counterpart of {@link #ForecastModeInformation(ByteBuffer)}. */
+  public ForecastModeInformation(InputStream stream) throws IOException {
+    super(ModelTask.FORECAST, stream);
+    int typeCount = ReadWriteIOUtils.readInt(stream);
+    List<TSDataType> types = new ArrayList<>(typeCount);
+    while (types.size() < typeCount) {
+      types.add(TSDataType.deserializeFrom(stream));
+    }
+    inputTypeList = types;
+
+    int indexCount = ReadWriteIOUtils.readInt(stream);
+    List<Integer> indexes = new ArrayList<>(indexCount);
+    while (indexes.size() < indexCount) {
+      indexes.add(ReadWriteIOUtils.readInt(stream));
+    }
+    predictIndexList = indexes;
+
+    inputLength = ReadWriteIOUtils.readInt(stream);
+    predictLength = ReadWriteIOUtils.readInt(stream);
+  }
+
+  public List<TSDataType> getInputTypeList() {
+    return inputTypeList;
+  }
+
+  public List<Integer> getPredictIndexList() {
+    return predictIndexList;
+  }
+
+  public int getInputLength() {
+    return inputLength;
+  }
+
+  public int getPredictLength() {
+    return predictLength;
+  }
+
+  /** Writes the base fields, then the type list, the index list and the two lengths. */
+  @Override
+  public void serialize(DataOutputStream stream) throws IOException {
+    super.serialize(stream);
+    int typeCount = inputTypeList.size();
+    ReadWriteIOUtils.write(typeCount, stream);
+    for (int i = 0; i < typeCount; i++) {
+      inputTypeList.get(i).serializeTo(stream);
+    }
+    int indexCount = predictIndexList.size();
+    ReadWriteIOUtils.write(indexCount, stream);
+    for (int i = 0; i < indexCount; i++) {
+      Integer predictIndex = predictIndexList.get(i);
+      ReadWriteIOUtils.write(predictIndex, stream);
+    }
+    ReadWriteIOUtils.write(inputLength, stream);
+    ReadWriteIOUtils.write(predictLength, stream);
+  }
+
+  /** File-stream counterpart of {@link #serialize(DataOutputStream)}; same layout. */
+  @Override
+  public void serialize(FileOutputStream stream) throws IOException {
+    super.serialize(stream);
+    int typeCount = inputTypeList.size();
+    ReadWriteIOUtils.write(typeCount, stream);
+    for (int i = 0; i < typeCount; i++) {
+      inputTypeList.get(i).serializeTo(stream);
+    }
+    int indexCount = predictIndexList.size();
+    ReadWriteIOUtils.write(indexCount, stream);
+    for (int i = 0; i < indexCount; i++) {
+      Integer predictIndex = predictIndexList.get(i);
+      ReadWriteIOUtils.write(predictIndex, stream);
+    }
+    ReadWriteIOUtils.write(inputLength, stream);
+    ReadWriteIOUtils.write(predictLength, stream);
+  }
+}
diff --git
a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
index 522f609e51e..fc0cdb45dde 100644
---
a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
+++
b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
@@ -39,7 +39,7 @@ import java.util.Map;
import static org.apache.iotdb.commons.model.TrailInformation.MODEL_PATH;
-public class ModelInformation {
+public abstract class ModelInformation {
private final String modelId;
private final ModelTask modelTask;
@@ -55,8 +55,8 @@ public class ModelInformation {
private final Map<String, TrailInformation> trailMap;
public ModelInformation(
- String modelId,
ModelTask modelTask,
+ String modelId,
String modelType,
boolean isAuto,
List<String> queryExpressions,
@@ -71,9 +71,10 @@ public class ModelInformation {
this.trailMap = new HashMap<>();
}
- public ModelInformation(ByteBuffer buffer) {
+ public ModelInformation(ModelTask modelTask, ByteBuffer buffer) {
+ this.modelTask = modelTask;
+
this.modelId = ReadWriteIOUtils.readString(buffer);
- this.modelTask = ModelTask.findByValue(ReadWriteIOUtils.readInt(buffer));
this.modelType = ReadWriteIOUtils.readString(buffer);
int listSize = ReadWriteIOUtils.readInt(buffer);
@@ -103,9 +104,10 @@ public class ModelInformation {
}
}
- public ModelInformation(InputStream stream) throws IOException {
+ public ModelInformation(ModelTask modelTask, InputStream stream) throws
IOException {
+ this.modelTask = modelTask;
+
this.modelId = ReadWriteIOUtils.readString(stream);
- this.modelTask = ModelTask.findByValue(ReadWriteIOUtils.readInt(stream));
this.modelType = ReadWriteIOUtils.readString(stream);
int listSize = ReadWriteIOUtils.readInt(stream);
@@ -152,6 +154,10 @@ public class ModelInformation {
return queryFilter;
}
+  /** Returns whether the model can serve inference, i.e. its training state is FINISHED. */
+  public boolean available() {
+    return TrainingState.FINISHED == trainingState;
+  }
+
public TrailInformation getTrailInformationById(String trailId) {
if (trailMap.containsKey(trailId)) {
return trailMap.get(trailId);
@@ -186,9 +192,19 @@ public class ModelInformation {
}
}
+  /**
+   * Returns the model path of the best trail, or {@code "UNKNOWN"} when no best trail has been
+   * chosen yet or the chosen trail id has no entry in the trail map (previously this case threw a
+   * NullPointerException).
+   */
+  public String getModelPath() {
+    TrailInformation bestTrail = bestTrailId == null ? null : trailMap.get(bestTrailId);
+    return bestTrail == null ? "UNKNOWN" : bestTrail.getModelPath();
+  }
+
public void serialize(DataOutputStream stream) throws IOException {
- ReadWriteIOUtils.write(modelId, stream);
ReadWriteIOUtils.write(modelTask.ordinal(), stream);
+
+ ReadWriteIOUtils.write(modelId, stream);
ReadWriteIOUtils.write(modelType, stream);
ReadWriteIOUtils.write(queryExpressions.size(), stream);
for (String queryExpression : queryExpressions) {
@@ -219,8 +235,9 @@ public class ModelInformation {
}
public void serialize(FileOutputStream stream) throws IOException {
- ReadWriteIOUtils.write(modelId, stream);
ReadWriteIOUtils.write(modelTask.ordinal(), stream);
+
+ ReadWriteIOUtils.write(modelId, stream);
ReadWriteIOUtils.write(modelType, stream);
ReadWriteIOUtils.write(queryExpressions.size(), stream);
@@ -251,12 +268,32 @@ public class ModelInformation {
}
}
- public static ModelInformation deserialize(InputStream stream) throws
IOException {
- return new ModelInformation(stream);
+  /**
+   * Reads a {@link ModelInformation} written by {@code serialize}. The leading int identifies the
+   * {@link ModelTask}, which selects the concrete subclass that reads the remaining fields.
+   *
+   * @throws IllegalArgumentException if the task value is unknown or has no concrete subclass
+   */
+  public static ModelInformation deserialize(ByteBuffer buffer) {
+    int taskValue = ReadWriteIOUtils.readInt(buffer);
+    ModelTask modelTask = ModelTask.findByValue(taskValue);
+    if (modelTask == null) {
+      throw new IllegalArgumentException("Invalid task type value: " + taskValue);
+    }
+
+    switch (modelTask) {
+      case FORECAST:
+        return new ForecastModeInformation(buffer);
+      default:
+        throw new IllegalArgumentException("Invalid task type: " + modelTask);
+    }
   }
- public static ModelInformation deserialize(ByteBuffer buffer) {
- return new ModelInformation(buffer);
+  /**
+   * Stream counterpart of {@code deserialize(ByteBuffer)}: the leading int identifies the {@link
+   * ModelTask}, which selects the concrete subclass that reads the remaining fields.
+   *
+   * @throws IOException if reading from {@code stream} fails
+   * @throws IllegalArgumentException if the task value is unknown or has no concrete subclass
+   */
+  public static ModelInformation deserialize(InputStream stream) throws IOException {
+    int taskValue = ReadWriteIOUtils.readInt(stream);
+    ModelTask modelTask = ModelTask.findByValue(taskValue);
+    if (modelTask == null) {
+      throw new IllegalArgumentException("Invalid task type value: " + taskValue);
+    }
+
+    switch (modelTask) {
+      case FORECAST:
+        return new ForecastModeInformation(stream);
+      default:
+        throw new IllegalArgumentException("Invalid task type: " + modelTask);
+    }
   }
public ByteBuffer serializeShowModelResult() throws IOException {
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
b/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/ModelInferenceFunction.java
similarity index 52%
copy from
server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
copy to
node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/ModelInferenceFunction.java
index 734ebb4bef4..f52a0348af6 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
+++
b/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/ModelInferenceFunction.java
@@ -17,11 +17,33 @@
* under the License.
*/
-package org.apache.iotdb.db.mpp.plan.expression.multi;
+package org.apache.iotdb.commons.udf.builtin;
-/** */
-public enum FunctionType {
- AGGREGATION_FUNCTION,
- BUILT_IN_SCALAR_FUNCTION,
- UDF
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/** Built-in SQL functions that run inference with a trained ML model. */
+public enum ModelInferenceFunction {
+  FORECAST("forecast");
+
+  /**
+   * Names of all model inference functions, as declared on the constants. Wrapped unmodifiable so
+   * callers of {@link #getNativeFunctionNames()} cannot alter this shared static state.
+   */
+  private static final Set<String> NATIVE_FUNCTION_NAMES =
+      Collections.unmodifiableSet(
+          Arrays.stream(ModelInferenceFunction.values())
+              .map(ModelInferenceFunction::getFunctionName)
+              .collect(Collectors.toCollection(HashSet::new)));
+
+  /** SQL name of the function, e.g. {@code forecast}. */
+  private final String functionName;
+
+  ModelInferenceFunction(String functionName) {
+    this.functionName = functionName;
+  }
+
+  public String getFunctionName() {
+    return functionName;
+  }
+
+  /** Returns an unmodifiable set of the names of all model inference functions. */
+  public static Set<String> getNativeFunctionNames() {
+    return NATIVE_FUNCTION_NAMES;
+  }
 }
diff --git
a/node-commons/src/main/java/org/apache/iotdb/commons/udf/service/UDFManagementService.java
b/node-commons/src/main/java/org/apache/iotdb/commons/udf/service/UDFManagementService.java
index fbdf8684e91..e7ecb32f008 100644
---
a/node-commons/src/main/java/org/apache/iotdb/commons/udf/service/UDFManagementService.java
+++
b/node-commons/src/main/java/org/apache/iotdb/commons/udf/service/UDFManagementService.java
@@ -22,6 +22,7 @@ package org.apache.iotdb.commons.udf.service;
import org.apache.iotdb.commons.udf.UDFInformation;
import org.apache.iotdb.commons.udf.UDFTable;
import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
+import org.apache.iotdb.commons.udf.builtin.ModelInferenceFunction;
import org.apache.iotdb.commons.utils.TestOnly;
import org.apache.iotdb.udf.api.UDF;
import org.apache.iotdb.udf.api.UDTF;
@@ -111,8 +112,26 @@ public class UDFManagementService {
throw new UDFManagementException(errorMessage);
}
+  /**
+   * Rejects a UDF registration whose name (compared in lower case) collides with a built-in model
+   * inference function such as {@code forecast}.
+   *
+   * @throws UDFManagementException if the name is reserved for model inference
+   */
+  private void checkIsModelInferenceFunctionName(UDFInformation udfInformation)
+      throws UDFManagementException {
+    String functionName = udfInformation.getFunctionName();
+    boolean conflictsWithModelInference =
+        ModelInferenceFunction.getNativeFunctionNames().contains(functionName.toLowerCase());
+    if (conflictsWithModelInference) {
+      String errorMessage =
+          String.format(
+              "Failed to register UDF %s(%s), because the given function name conflicts with the ML model inference function name",
+              functionName, udfInformation.getClassName());
+      LOGGER.warn(errorMessage);
+      throw new UDFManagementException(errorMessage);
+    }
+  }
+
private void checkIfRegistered(UDFInformation udfInformation) throws
UDFManagementException {
checkIsBuiltInAggregationFunctionName(udfInformation);
+ checkIsModelInferenceFunctionName(udfInformation);
String functionName = udfInformation.getFunctionName();
String className = udfInformation.getClassName();
UDFInformation information = udfTable.getUDFInformation(functionName);
diff --git
a/server/src/main/java/org/apache/iotdb/db/client/ConfigNodeClient.java
b/server/src/main/java/org/apache/iotdb/db/client/ConfigNodeClient.java
index f9d69c2caa4..31f85adaa75 100644
--- a/server/src/main/java/org/apache/iotdb/db/client/ConfigNodeClient.java
+++ b/server/src/main/java/org/apache/iotdb/db/client/ConfigNodeClient.java
@@ -80,6 +80,8 @@ import
org.apache.iotdb.confignode.rpc.thrift.TGetDataNodeLocationsResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq;
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPipeSinkReq;
@@ -2127,6 +2129,22 @@ public class ConfigNodeClient implements
IConfigNodeRPCService.Iface, ThriftClie
throw new TException(new UnsupportedOperationException().getCause());
}
+  /**
+   * Fetches model information from the ConfigNode. A response that causes a leader update, or a
+   * transport failure, triggers a reconnect and another attempt, up to {@code RETRY_NUM} times.
+   */
+  @Override
+  public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) throws TException {
+    int attempt = 0;
+    while (attempt < RETRY_NUM) {
+      try {
+        TGetModelInfoResp response = client.getModelInfo(req);
+        if (!updateConfigNodeLeader(response.getStatus())) {
+          return response;
+        }
+      } catch (TException e) {
+        // Forget the cached leader so the reconnect below picks a fresh one.
+        configLeader = null;
+      }
+      waitAndReconnect();
+      attempt++;
+    }
+    throw new TException(MSG_RECONNECTION_FAIL);
+  }
+
@Override
public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) throws TException {
for (int i = 0; i < RETRY_NUM; i++) {
diff --git a/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java
b/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java
index 1ff54d43b6f..768bf7a74da 100644
--- a/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java
+++ b/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java
@@ -30,7 +30,7 @@ import org.apache.iotdb.mlnode.rpc.thrift.TDeleteModelReq;
import org.apache.iotdb.mlnode.rpc.thrift.TForecastReq;
import org.apache.iotdb.mlnode.rpc.thrift.TForecastResp;
import org.apache.iotdb.rpc.TConfigurationConst;
-import org.apache.iotdb.rpc.TSStatusCode;
+import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
import org.apache.iotdb.tsfile.read.common.block.TsBlock;
import org.apache.iotdb.tsfile.read.common.block.column.TsBlockSerde;
@@ -45,6 +45,8 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -115,14 +117,26 @@ public class MLNodeClient implements AutoCloseable {
}
}
- public TsBlock forecast(String modelPath, TsBlock inputTsBlock) throws
TException {
+ public TForecastResp forecast(
+ String modelPath,
+ TsBlock inputTsBlock,
+ List<TSDataType> inputTypeList,
+ List<String> inputColumnNameList,
+ int predictLength)
+ throws TException {
try {
- TForecastReq forecastReq = new TForecastReq(modelPath,
tsBlockSerde.serialize(inputTsBlock));
- TForecastResp resp = client.forecast(forecastReq);
- if (resp.status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
- throw new TException("Failed to execute forecast task, because: " +
resp.status.message);
+ List<String> reqInputTypeList = new ArrayList<>();
+ for (TSDataType dataType : inputTypeList) {
+ reqInputTypeList.add(dataType.toString());
}
- return tsBlockSerde.deserialize(resp.forecastResult);
+ TForecastReq forecastReq =
+ new TForecastReq(
+ modelPath,
+ tsBlockSerde.serialize(inputTsBlock),
+ reqInputTypeList,
+ inputColumnNameList,
+ predictLength);
+ return client.forecast(forecastReq);
} catch (IOException e) {
throw new TException("An exception occurred while serializing input
tsblock", e);
} catch (TException e) {
diff --git a/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
b/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
index cdd894b5d47..ee4757d2b71 100644
--- a/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
+++ b/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
@@ -722,6 +722,9 @@ public class IoTDBConfig {
/** The number of threads in the thread pool that execute insert-tablet
tasks. */
private int intoOperationExecutionThreadCount = 2;
+ /** The number of threads in the thread pool that execute model inference
tasks. */
+ private int modelInferenceExecutionThreadCount = 10;
+
/** Default TSfile storage is in local file system */
private FSType tsFileStorageFs = FSType.LOCAL;
@@ -2045,6 +2048,14 @@ public class IoTDBConfig {
this.intoOperationExecutionThreadCount = intoOperationExecutionThreadCount;
}
+  /** Returns the number of threads in the pool that executes model inference tasks. */
+  public int getModelInferenceExecutionThreadCount() {
+    return this.modelInferenceExecutionThreadCount;
+  }
+
+  /** Sets the number of threads in the pool that executes model inference tasks. */
+  public void setModelInferenceExecutionThreadCount(int threadCount) {
+    this.modelInferenceExecutionThreadCount = threadCount;
+  }
+
public int getCompactionWriteThroughputMbPerSec() {
return compactionWriteThroughputMbPerSec;
}
diff --git a/server/src/main/java/org/apache/iotdb/db/constant/SqlConstant.java
b/server/src/main/java/org/apache/iotdb/db/constant/SqlConstant.java
index 4f18776b6b5..976862547cb 100644
--- a/server/src/main/java/org/apache/iotdb/db/constant/SqlConstant.java
+++ b/server/src/main/java/org/apache/iotdb/db/constant/SqlConstant.java
@@ -78,6 +78,9 @@ public class SqlConstant {
public static final String SUBSTRING_IS_STANDARD = "isStandard";
public static final String SUBSTRING_FOR = "FOR";
+ public static final String MODEL_ID = "model_id";
+ public static final String PREDICT_LENGTH = "predict_length";
+
public static String[] getSingleRootArray() {
return SINGLE_ROOT_ARRAY;
}
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
b/server/src/main/java/org/apache/iotdb/db/exception/ModelInferenceProcessException.java
similarity index 80%
copy from
server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
copy to
server/src/main/java/org/apache/iotdb/db/exception/ModelInferenceProcessException.java
index 734ebb4bef4..1ddb212f113 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
+++
b/server/src/main/java/org/apache/iotdb/db/exception/ModelInferenceProcessException.java
@@ -17,11 +17,11 @@
* under the License.
*/
-package org.apache.iotdb.db.mpp.plan.expression.multi;
+package org.apache.iotdb.db.exception;
-/** */
-public enum FunctionType {
- AGGREGATION_FUNCTION,
- BUILT_IN_SCALAR_FUNCTION,
- UDF
+/**
+ * Unchecked exception thrown when model inference fails on the DataNode, e.g. while a forecast
+ * request is executed.
+ */
+public class ModelInferenceProcessException extends RuntimeException {
+
+  public ModelInferenceProcessException(String message) {
+    super(message);
+  }
+
+  /** Creates an exception that keeps {@code cause}, so the original stack trace is not lost. */
+  public ModelInferenceProcessException(String message, Throwable cause) {
+    super(message, cause);
+  }
 }
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceManager.java
b/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceManager.java
index 74e2985207d..1ce3b5acc51 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceManager.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceManager.java
@@ -105,6 +105,11 @@ public class FragmentInstanceManager {
IoTDBThreadPoolFactory.newFixedThreadPool(
IoTDBDescriptor.getInstance().getConfig().getIntoOperationExecutionThreadCount(),
"into-operation-executor");
+
+ this.modelInferenceExecutor =
+ IoTDBThreadPoolFactory.newFixedThreadPool(
+
IoTDBDescriptor.getInstance().getConfig().getModelInferenceExecutionThreadCount(),
+ "model-inference-executor");
}
public int getInstanceContextSize() {
@@ -330,6 +335,10 @@ public class FragmentInstanceManager {
return intoOperationExecutor;
}
+  /**
+   * Returns the fixed-size pool (threads named "model-inference-executor") on which model
+   * inference requests are run.
+   */
+  public ExecutorService getModelInferenceExecutor() {
+    return this.modelInferenceExecutor;
+  }
+
private static class InstanceHolder {
private InstanceHolder() {}
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/AggregationUtil.java
b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/AggregationUtil.java
index 586ff48be35..da8efa94e74 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/AggregationUtil.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/AggregationUtil.java
@@ -223,7 +223,7 @@ public class AggregationUtil {
return timeValueColumnsSizePerLine;
}
- private static long getOutputColumnSizePerLine(
+ public static long getOutputColumnSizePerLine(
TSDataType tsDataType, PartialPath inputSeriesPath) {
switch (tsDataType) {
case INT32:
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/process/ml/ForecastOperator.java
b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/process/ml/ForecastOperator.java
new file mode 100644
index 00000000000..b7375985942
--- /dev/null
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/process/ml/ForecastOperator.java
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.execution.operator.process.ml;
+
+import org.apache.iotdb.db.client.MLNodeClient;
+import org.apache.iotdb.db.conf.IoTDBDescriptor;
+import org.apache.iotdb.db.exception.ModelInferenceProcessException;
+import org.apache.iotdb.db.mpp.execution.operator.Operator;
+import org.apache.iotdb.db.mpp.execution.operator.OperatorContext;
+import org.apache.iotdb.db.mpp.execution.operator.process.ProcessOperator;
+import org.apache.iotdb.mlnode.rpc.thrift.TForecastResp;
+import org.apache.iotdb.rpc.TSStatusCode;
+import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
+import org.apache.iotdb.tsfile.read.common.block.TsBlock;
+import org.apache.iotdb.tsfile.read.common.block.TsBlockBuilder;
+import org.apache.iotdb.tsfile.read.common.block.column.ColumnBuilder;
+import org.apache.iotdb.tsfile.read.common.block.column.TimeColumnBuilder;
+import org.apache.iotdb.tsfile.read.common.block.column.TsBlockSerde;
+
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.thrift.TException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+
+import static com.google.common.util.concurrent.Futures.successfulAsList;
+
+/**
+ * Collects every row produced by its child, sends them in one request to the ML node for
+ * forecasting, and returns the forecast result as a single TsBlock.
+ */
+public class ForecastOperator implements ProcessOperator {
+
+  private static final Logger LOGGER = LoggerFactory.getLogger(ForecastOperator.class);
+
+  private final OperatorContext operatorContext;
+  private final Operator child;
+
+  private final String modelPath;
+  private final List<TSDataType> inputTypeList;
+  private final List<String> inputColumnNameList;
+  private final int expectedPredictLength;
+
+  /** Accumulates all rows of the child until the child is exhausted. */
+  private final TsBlockBuilder inputTsBlockBuilder;
+
+  /** Created lazily when the forecast task is submitted; may still be null at close(). */
+  private MLNodeClient client;
+
+  private final ExecutorService modelInferenceExecutor;
+
+  /** Null until all input has been collected and the forecast request has been submitted. */
+  private ListenableFuture<TForecastResp> forecastExecutionFuture;
+
+  private boolean finished = false;
+
+  private final long maxRetainedSize;
+  private final long maxReturnSize;
+
+  public ForecastOperator(
+      OperatorContext operatorContext,
+      Operator child,
+      String modelPath,
+      List<TSDataType> inputTypeList,
+      List<String> inputColumnNameList,
+      int expectedPredictLength,
+      ExecutorService modelInferenceExecutor,
+      long maxRetainedSize,
+      long maxReturnSize) {
+    this.operatorContext = operatorContext;
+    this.child = child;
+    this.modelPath = modelPath;
+    this.inputTypeList = inputTypeList;
+    this.inputColumnNameList = inputColumnNameList;
+    this.expectedPredictLength = expectedPredictLength;
+    this.inputTsBlockBuilder = new TsBlockBuilder(inputTypeList);
+    this.modelInferenceExecutor = modelInferenceExecutor;
+    this.maxRetainedSize = maxRetainedSize;
+    this.maxReturnSize = maxReturnSize;
+  }
+
+  @Override
+  public OperatorContext getOperatorContext() {
+    return operatorContext;
+  }
+
+  /**
+   * Unblocked only when both the child and the (possibly not yet submitted) forecast request are
+   * done; otherwise returns whichever of the two is still pending, or both combined.
+   */
+  @Override
+  public ListenableFuture<?> isBlocked() {
+    ListenableFuture<?> childBlocked = child.isBlocked();
+    boolean executionDone = forecastExecutionDone();
+    if (executionDone && childBlocked.isDone()) {
+      return NOT_BLOCKED;
+    } else if (childBlocked.isDone()) {
+      return forecastExecutionFuture;
+    } else if (executionDone) {
+      return childBlocked;
+    } else {
+      return successfulAsList(Arrays.asList(forecastExecutionFuture, childBlocked));
+    }
+  }
+
+  private boolean forecastExecutionDone() {
+    if (forecastExecutionFuture == null) {
+      return true;
+    }
+    return forecastExecutionFuture.isDone();
+  }
+
+  @Override
+  public boolean hasNext() throws Exception {
+    return !finished;
+  }
+
+  /**
+   * While no request has been submitted, consumes one child block per call (returning null) and
+   * submits the forecast once the child is exhausted. Afterwards returns the forecast result.
+   *
+   * @throws ModelInferenceProcessException if the request fails or the ML node reports an error
+   */
+  @Override
+  public TsBlock next() throws Exception {
+    if (forecastExecutionFuture == null) {
+      if (child.hasNextWithTimer()) {
+        TsBlock inputTsBlock = child.nextWithTimer();
+        if (inputTsBlock != null) {
+          appendTsBlockToBuilder(inputTsBlock);
+        }
+      } else {
+        submitForecastTask();
+      }
+      return null;
+    }
+
+    if (!forecastExecutionFuture.isDone()) {
+      throw new IllegalStateException(
+          "The operator cannot continue until the forecast execution is done.");
+    }
+
+    TForecastResp forecastResp;
+    try {
+      forecastResp = forecastExecutionFuture.get();
+    } catch (InterruptedException e) {
+      LOGGER.warn("{}: interrupted while waiting for the forecast result", this, e);
+      Thread.currentThread().interrupt();
+      throw withCause(new ModelInferenceProcessException(e.getMessage()), e);
+    } catch (ExecutionException e) {
+      Throwable cause = e.getCause() == null ? e : e.getCause();
+      throw withCause(new ModelInferenceProcessException(e.getMessage()), cause);
+    }
+
+    if (forecastResp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+      String message =
+          String.format(
+              "Error occurred while executing forecast: %s",
+              forecastResp.getStatus().getMessage());
+      throw new ModelInferenceProcessException(message);
+    }
+
+    finished = true;
+    TsBlock resultTsBlock = new TsBlockSerde().deserialize(forecastResp.bufferForForecastResult());
+    return modifyTimeColumn(resultTsBlock);
+  }
+
+  /** Attaches {@code cause} to {@code exception} so the original stack trace is preserved. */
+  private static ModelInferenceProcessException withCause(
+      ModelInferenceProcessException exception, Throwable cause) {
+    exception.initCause(cause);
+    return exception;
+  }
+
+  /**
+   * Rebuilds the result block with its timestamps divided down to the configured precision.
+   * NOTE(review): assumes the ML node returns nanosecond timestamps — confirm against the MLNode.
+   */
+  private TsBlock modifyTimeColumn(TsBlock resultTsBlock) {
+    long delta =
+        nanosPerTimestampUnit(
+            IoTDBDescriptor.getInstance().getConfig().getTimestampPrecision());
+
+    TsBlockBuilder newTsBlockBuilder = TsBlockBuilder.createWithOnlyTimeColumn();
+    TimeColumnBuilder timeColumnBuilder = newTsBlockBuilder.getTimeColumnBuilder();
+    for (int i = 0; i < resultTsBlock.getPositionCount(); i++) {
+      timeColumnBuilder.writeLong(resultTsBlock.getTimeByIndex(i) / delta);
+      newTsBlockBuilder.declarePosition();
+    }
+    return newTsBlockBuilder.build().appendValueColumns(resultTsBlock.getValueColumns());
+  }
+
+  /**
+   * Returns how many nanoseconds one unit of {@code timestampPrecision} spans: "ms" -> 10^6,
+   * "ns" -> 1 (previously "ns" was wrongly divided by 1000), anything else ("us") -> 10^3.
+   */
+  private static long nanosPerTimestampUnit(String timestampPrecision) {
+    if ("ms".equals(timestampPrecision)) {
+      return 1_000_000L;
+    } else if ("ns".equals(timestampPrecision)) {
+      return 1L;
+    } else {
+      return 1_000L;
+    }
+  }
+
+  /** Copies every row (time and all value columns) of {@code inputTsBlock} into the builder. */
+  private void appendTsBlockToBuilder(TsBlock inputTsBlock) {
+    TimeColumnBuilder timeColumnBuilder = inputTsBlockBuilder.getTimeColumnBuilder();
+    ColumnBuilder[] columnBuilders = inputTsBlockBuilder.getValueColumnBuilders();
+
+    for (int i = 0; i < inputTsBlock.getPositionCount(); i++) {
+      timeColumnBuilder.writeLong(inputTsBlock.getTimeByIndex(i));
+      for (int columnIndex = 0; columnIndex < inputTsBlock.getValueColumnCount(); columnIndex++) {
+        columnBuilders[columnIndex].write(inputTsBlock.getColumn(columnIndex), i);
+      }
+      inputTsBlockBuilder.declarePosition();
+    }
+  }
+
+  /** Creates the ML node client if needed and submits the forecast request asynchronously. */
+  private void submitForecastTask() {
+    if (client == null) {
+      try {
+        client = new MLNodeClient();
+      } catch (TException e) {
+        throw withCause(new ModelInferenceProcessException(e.getMessage()), e);
+      }
+    }
+
+    TsBlock inputTsBlock = inputTsBlockBuilder.build();
+    // The analyzer orders the input by time DESC; presumably reverse() restores ascending order.
+    inputTsBlock.reverse();
+
+    forecastExecutionFuture =
+        Futures.submit(
+            () ->
+                client.forecast(
+                    modelPath,
+                    inputTsBlock,
+                    inputTypeList,
+                    inputColumnNameList,
+                    expectedPredictLength),
+            modelInferenceExecutor);
+  }
+
+  @Override
+  public boolean isFinished() throws Exception {
+    return finished;
+  }
+
+  /**
+   * Cancels any in-flight request, then releases the client (which may never have been created)
+   * and always closes the child, even if closing the client fails.
+   */
+  @Override
+  public void close() throws Exception {
+    try {
+      if (forecastExecutionFuture != null) {
+        forecastExecutionFuture.cancel(true);
+      }
+      if (client != null) {
+        client.close();
+      }
+    } finally {
+      child.close();
+    }
+  }
+
+  @Override
+  public long calculateMaxPeekMemory() {
+    return maxReturnSize + maxRetainedSize + child.calculateMaxPeekMemory();
+  }
+
+  @Override
+  public long calculateMaxReturnSize() {
+    return maxReturnSize;
+  }
+
+  @Override
+  public long calculateRetainedSizeAfterCallingNext() {
+    return maxRetainedSize + child.calculateRetainedSizeAfterCallingNext();
+  }
+}
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/Analysis.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/Analysis.java
index b50ef2dfd2a..c85142a72f1 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/Analysis.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/Analysis.java
@@ -39,6 +39,7 @@ import
org.apache.iotdb.db.mpp.plan.planner.plan.parameter.GroupByParameter;
import
org.apache.iotdb.db.mpp.plan.planner.plan.parameter.GroupByTimeParameter;
import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.IntoPathDescriptor;
import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.OrderByParameter;
+import
org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model.ModelInferenceDescriptor;
import org.apache.iotdb.db.mpp.plan.statement.Statement;
import org.apache.iotdb.db.mpp.plan.statement.component.Ordering;
import org.apache.iotdb.db.mpp.plan.statement.component.SortItem;
@@ -206,6 +207,8 @@ public class Analysis {
// indicate whether the Nodes produce source data are VirtualSourceNodes
private boolean isVirtualSource = false;
+ private ModelInferenceDescriptor modelInferenceDescriptor;
+
/////////////////////////////////////////////////////////////////////////////////////////////////
// SELECT INTO Analysis
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -709,4 +712,12 @@ public class Analysis {
public void setTimeseriesOrderingForLastQuery(Ordering
timeseriesOrderingForLastQuery) {
this.timeseriesOrderingForLastQuery = timeseriesOrderingForLastQuery;
}
+
+  /** Returns the model inference descriptor of this query, or null if none has been set. */
+  public ModelInferenceDescriptor getModelInferenceDescriptor() {
+    return this.modelInferenceDescriptor;
+  }
+
+  /** Records the model inference (e.g. FORECAST) this query performs. */
+  public void setModelInferenceDescriptor(ModelInferenceDescriptor descriptor) {
+    this.modelInferenceDescriptor = descriptor;
+  }
}
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java
index a96bd165a59..7295ce923b5 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java
@@ -25,6 +25,8 @@ import org.apache.iotdb.commons.conf.IoTDBConstant;
import org.apache.iotdb.commons.exception.IllegalPathException;
import org.apache.iotdb.commons.exception.IoTDBException;
import org.apache.iotdb.commons.exception.MetadataException;
+import org.apache.iotdb.commons.model.ForecastModeInformation;
+import org.apache.iotdb.commons.model.ModelInformation;
import org.apache.iotdb.commons.partition.DataPartition;
import org.apache.iotdb.commons.partition.DataPartitionQueryParam;
import org.apache.iotdb.commons.partition.SchemaNodeManagementPartition;
@@ -35,6 +37,8 @@ import org.apache.iotdb.commons.path.PathPatternTree;
import org.apache.iotdb.commons.schema.view.LogicalViewSchema;
import org.apache.iotdb.commons.schema.view.viewExpression.ViewExpression;
import org.apache.iotdb.commons.service.metric.PerformanceOverviewMetrics;
+import
org.apache.iotdb.commons.service.metric.enums.PerformanceOverviewMetrics;
+import org.apache.iotdb.commons.udf.builtin.ModelInferenceFunction;
import org.apache.iotdb.confignode.rpc.thrift.TGetDataNodeLocationsResp;
import org.apache.iotdb.db.client.ConfigNodeClient;
import org.apache.iotdb.db.client.ConfigNodeClientManager;
@@ -79,6 +83,8 @@ import
org.apache.iotdb.db.mpp.plan.planner.plan.parameter.GroupByTimeParameter;
import
org.apache.iotdb.db.mpp.plan.planner.plan.parameter.GroupByVariationParameter;
import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.IntoPathDescriptor;
import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.OrderByParameter;
+import
org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model.ForecastModelInferenceDescriptor;
+import
org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model.ModelInferenceDescriptor;
import org.apache.iotdb.db.mpp.plan.statement.Statement;
import org.apache.iotdb.db.mpp.plan.statement.StatementNode;
import org.apache.iotdb.db.mpp.plan.statement.StatementVisitor;
@@ -90,6 +96,7 @@ import
org.apache.iotdb.db.mpp.plan.statement.component.GroupBySessionComponent;
import org.apache.iotdb.db.mpp.plan.statement.component.GroupByTimeComponent;
import
org.apache.iotdb.db.mpp.plan.statement.component.GroupByVariationComponent;
import org.apache.iotdb.db.mpp.plan.statement.component.IntoComponent;
+import org.apache.iotdb.db.mpp.plan.statement.component.OrderByComponent;
import org.apache.iotdb.db.mpp.plan.statement.component.Ordering;
import org.apache.iotdb.db.mpp.plan.statement.component.ResultColumn;
import org.apache.iotdb.db.mpp.plan.statement.component.SortItem;
@@ -187,6 +194,8 @@ import static
org.apache.iotdb.commons.conf.IoTDBConstant.ALLOWED_SCHEMA_PROPS;
import static org.apache.iotdb.commons.conf.IoTDBConstant.DEADBAND;
import static org.apache.iotdb.commons.conf.IoTDBConstant.LOSS;
import static
org.apache.iotdb.commons.conf.IoTDBConstant.ONE_LEVEL_PATH_WILDCARD;
+import static
org.apache.iotdb.commons.udf.builtin.ModelInferenceFunction.FORECAST;
+import static org.apache.iotdb.db.constant.SqlConstant.MODEL_ID;
import static
org.apache.iotdb.db.metadata.view.viewExpression.visitor.GetSourcePathsVisitor.getSourcePaths;
import static
org.apache.iotdb.db.mpp.common.header.ColumnHeaderConstant.DEVICE;
import static
org.apache.iotdb.db.mpp.common.header.ColumnHeaderConstant.ENDTIME;
@@ -240,12 +249,16 @@ public class AnalyzeVisitor extends
StatementVisitor<Analysis, MPPQueryContext>
try {
// check for semantic errors
queryStatement.semanticCheck();
+ analysis.setStatement(queryStatement);
+
+ if (queryStatement.isModelInferenceQuery()) {
+ analyzeModelInference(analysis, queryStatement);
+ }
// concat path and construct path pattern tree
PathPatternTree patternTree = new
PathPatternTree(queryStatement.useWildcard());
queryStatement =
(QueryStatement) new ConcatPathRewriter().rewrite(queryStatement,
patternTree);
- analysis.setStatement(queryStatement);
// request schema fetch API
long startTime = System.nanoTime();
@@ -375,6 +388,47 @@ public class AnalyzeVisitor extends
StatementVisitor<Analysis, MPPQueryContext>
return analysis;
}
+  /**
+   * Resolves the model referenced by the model inference call in the first select column, records
+   * its descriptor in {@code analysis}, and rewrites the statement to select the raw input
+   * expressions ordered by time descending.
+   *
+   * @throws SemanticException if the model id is missing, the model is unknown, not available or
+   *     of the wrong kind, or an attribute value is invalid
+   */
+  private void analyzeModelInference(Analysis analysis, QueryStatement queryStatement) {
+    FunctionExpression modelInferenceExpression =
+        (FunctionExpression)
+            queryStatement.getSelectComponent().getResultColumns().get(0).getExpression();
+    Map<String, String> modelInferenceAttributes =
+        modelInferenceExpression.getFunctionAttributes();
+
+    String modelId = modelInferenceAttributes.get(MODEL_ID);
+    if (modelId == null) {
+      throw new SemanticException(
+          String.format(
+              "The attribute '%s' must be specified for function %s",
+              MODEL_ID, modelInferenceExpression.getFunctionName()));
+    }
+
+    ModelInformation modelInformation = partitionFetcher.getModelInformation(modelId);
+    if (modelInformation == null) {
+      throw new SemanticException(String.format("Model [%s] does not exist", modelId));
+    }
+    if (!modelInformation.available()) {
+      throw new SemanticException(
+          String.format("Model [%s] is not available: its training has not finished", modelId));
+    }
+
+    ModelInferenceFunction functionType =
+        ModelInferenceFunction.valueOf(modelInferenceExpression.getFunctionName().toUpperCase());
+    switch (functionType) {
+      case FORECAST:
+        if (!(modelInformation instanceof ForecastModeInformation)) {
+          throw new SemanticException(
+              String.format("Model [%s] is not a forecast model", modelId));
+        }
+        ForecastModelInferenceDescriptor modelInferenceDescriptor =
+            new ForecastModelInferenceDescriptor(
+                functionType, (ForecastModeInformation) modelInformation);
+        if (modelInferenceAttributes.containsKey("predict_length")) {
+          modelInferenceDescriptor.setExpectedPredictLength(
+              parsePredictLength(modelInferenceAttributes.get("predict_length")));
+        }
+        analysis.setModelInferenceDescriptor(modelInferenceDescriptor);
+
+        // The model consumes the raw input series, so select them instead of the function call.
+        List<ResultColumn> newResultColumns = new ArrayList<>();
+        for (Expression inputExpression : modelInferenceExpression.getExpressions()) {
+          newResultColumns.add(new ResultColumn(inputExpression, ResultColumn.ColumnType.RAW));
+        }
+        queryStatement.getSelectComponent().setResultColumns(newResultColumns);
+
+        OrderByComponent descTimeOrder = new OrderByComponent();
+        descTimeOrder.addSortItem(new SortItem("TIME", Ordering.DESC));
+        queryStatement.setOrderByComponent(descTimeOrder);
+        break;
+      default:
+        throw new IllegalArgumentException(
+            "Unsupported model inference function: " + functionType);
+    }
+  }
+
+  /**
+   * Parses the {@code predict_length} attribute of a FORECAST call.
+   *
+   * @throws SemanticException if the value is not a positive integer
+   */
+  private static int parsePredictLength(String predictLengthValue) {
+    int predictLength;
+    try {
+      predictLength = Integer.parseInt(predictLengthValue);
+    } catch (NumberFormatException e) {
+      throw new SemanticException(
+          String.format(
+              "The attribute 'predict_length' must be an integer, but got '%s'",
+              predictLengthValue));
+    }
+    if (predictLength <= 0) {
+      throw new SemanticException(
+          String.format(
+              "The attribute 'predict_length' must be positive, but got %d", predictLength));
+    }
+    return predictLength;
+  }
+
private Analysis finishQuery(QueryStatement queryStatement, Analysis
analysis) {
if (queryStatement.isSelectInto()) {
analysis.setRespDatasetHeader(
@@ -1252,6 +1306,53 @@ public class AnalyzeVisitor extends
StatementVisitor<Analysis, MPPQueryContext>
return;
}
+ if (queryStatement.isModelInferenceQuery()) {
+ List<ColumnHeader> columnHeaders = new ArrayList<>();
+ boolean isIgnoreTimestamp;
+
+ ModelInferenceDescriptor modelInferenceDescriptor =
analysis.getModelInferenceDescriptor();
+ switch (modelInferenceDescriptor.getFunctionType()) {
+ case FORECAST:
+ isIgnoreTimestamp = false;
+ ForecastModelInferenceDescriptor forecastModelInferenceDescriptor =
+ (ForecastModelInferenceDescriptor) modelInferenceDescriptor;
+
+ List<TSDataType> inputTypeList =
forecastModelInferenceDescriptor.getInputTypeList();
+ if (outputExpressions.size() != inputTypeList.size()) {
+ throw new SemanticException("");
+ }
+ for (int i = 0; i < inputTypeList.size(); i++) {
+ Expression inputExpression = outputExpressions.get(i).left;
+ if (analysis.getType(inputExpression) != inputTypeList.get(i)) {
+ throw new SemanticException("");
+ }
+ }
+
+ List<FunctionExpression> modelInferenceOutputExpressions = new
ArrayList<>();
+ for (int predictIndex :
forecastModelInferenceDescriptor.getPredictIndexList()) {
+ Expression inputExpression =
outputExpressions.get(predictIndex).left;
+ FunctionExpression modelInferenceOutputExpression =
+ new FunctionExpression(
+ FORECAST.getFunctionName(),
+ forecastModelInferenceDescriptor.getOutputAttributes(),
+ Collections.singletonList(inputExpression));
+ analyzeExpression(analysis, modelInferenceOutputExpression);
+
modelInferenceOutputExpressions.add(modelInferenceOutputExpression);
+ columnHeaders.add(
+ new ColumnHeader(
+ modelInferenceOutputExpression.toString(),
+ analysis.getType(modelInferenceOutputExpression)));
+ }
+ forecastModelInferenceDescriptor.setModelInferenceOutputExpressions(
+ modelInferenceOutputExpressions);
+ break;
+ default:
+ throw new SemanticException("");
+ }
+ analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders,
isIgnoreTimestamp));
+ return;
+ }
+
boolean isIgnoreTimestamp = queryStatement.isAggregationQuery() &&
!queryStatement.isGroupBy();
List<ColumnHeader> columnHeaders = new ArrayList<>();
if (queryStatement.isAlignByDevice()) {
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ClusterPartitionFetcher.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ClusterPartitionFetcher.java
index 7bf6cfecd31..98063c19b74 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ClusterPartitionFetcher.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ClusterPartitionFetcher.java
@@ -26,6 +26,7 @@ import org.apache.iotdb.commons.client.IClientManager;
import org.apache.iotdb.commons.client.exception.ClientManagerException;
import org.apache.iotdb.commons.consensus.ConfigRegionId;
import org.apache.iotdb.commons.exception.IoTDBException;
+import org.apache.iotdb.commons.model.ModelInformation;
import org.apache.iotdb.commons.partition.DataPartition;
import org.apache.iotdb.commons.partition.DataPartitionQueryParam;
import org.apache.iotdb.commons.partition.SchemaNodeManagementPartition;
@@ -34,6 +35,8 @@ import
org.apache.iotdb.commons.partition.executor.SeriesPartitionExecutor;
import org.apache.iotdb.commons.path.PathPatternTree;
import org.apache.iotdb.confignode.rpc.thrift.TDataPartitionReq;
import org.apache.iotdb.confignode.rpc.thrift.TDataPartitionTableResp;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
import org.apache.iotdb.confignode.rpc.thrift.TSchemaNodeManagementReq;
import org.apache.iotdb.confignode.rpc.thrift.TSchemaNodeManagementResp;
import org.apache.iotdb.confignode.rpc.thrift.TSchemaPartitionReq;
@@ -293,6 +296,28 @@ public class ClusterPartitionFetcher implements
IPartitionFetcher {
partitionCache.invalidAllCache();
}
+  /**
+   * Asks the ConfigNode for the information of the model registered under {@code modelId}.
+   *
+   * @return the deserialized model information, or {@code null} when the successful response
+   *     carries no model payload
+   * @throws StatementAnalyzeException if the ConfigNode reports a failure status or the RPC fails
+   */
+  @Override
+  public ModelInformation getModelInformation(String modelId) {
+    try (ConfigNodeClient configNodeClient =
+        configNodeClientManager.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) {
+      TGetModelInfoResp resp = configNodeClient.getModelInfo(new TGetModelInfoReq(modelId));
+      if (resp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+        throw new StatementAnalyzeException(
+            "An error occurred when executing getModelInformation():"
+                + resp.getStatus().getMessage());
+      }
+      // A successful response without a payload is surfaced to callers as null.
+      return resp.modelInfo == null ? null : ModelInformation.deserialize(resp.modelInfo);
+    } catch (ClientManagerException | TException e) {
+      throw new StatementAnalyzeException(
+          "An error occurred when executing getModelInformation():" + e.getMessage());
+    }
+  }
+
/** split data partition query param by database */
private Map<String, List<DataPartitionQueryParam>>
splitDataPartitionQueryParam(
List<DataPartitionQueryParam> dataPartitionQueryParams, boolean
isAutoCreate) {
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionAnalyzer.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionAnalyzer.java
index f283eb762b0..7187579ddbc 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionAnalyzer.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionAnalyzer.java
@@ -187,6 +187,8 @@ public class ExpressionAnalyzer {
}
}
return ResultColumn.ColumnType.AGGREGATION;
+ } else if (((FunctionExpression) expression).isModelInferenceFunction())
{
+ return ResultColumn.ColumnType.MODEL_INFERENCE;
} else {
ResultColumn.ColumnType checkedType = null;
int lastCheckedIndex = 0;
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionTypeAnalyzer.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionTypeAnalyzer.java
index 1f47297c458..73f3c09fa6d 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionTypeAnalyzer.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionTypeAnalyzer.java
@@ -266,6 +266,8 @@ public class ExpressionTypeAnalyzer {
functionExpression,
TypeInferenceUtils.getBuiltInScalarFunctionDataType(
functionExpression,
expressionTypes.get(NodeRef.of(inputExpressions.get(0)))));
+ } else if (functionExpression.isModelInferenceFunction()) {
+ return setExpressionType(functionExpression, TSDataType.DOUBLE);
} else {
return setExpressionType(
functionExpression,
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/IPartitionFetcher.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/IPartitionFetcher.java
index e2c4b4a7c88..d50bf23fdb5 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/IPartitionFetcher.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/IPartitionFetcher.java
@@ -18,6 +18,7 @@
*/
package org.apache.iotdb.db.mpp.plan.analyze;
+import org.apache.iotdb.commons.model.ModelInformation;
import org.apache.iotdb.commons.partition.DataPartition;
import org.apache.iotdb.commons.partition.DataPartitionQueryParam;
import org.apache.iotdb.commons.partition.SchemaNodeManagementPartition;
@@ -87,4 +88,6 @@ public interface IPartitionFetcher {
/** Invalid all partition cache */
void invalidAllCache();
+
+ /**
+ * Fetches the information of the model registered under {@code modelId}.
+ *
+ * @param modelId id of the model, e.g. the MODEL_ID attribute of a model inference function
+ * @return the model information; NOTE(review): the ConfigNode-backed implementation returns
+ * {@code null} when the response carries no model payload — confirm all implementations
+ * follow the same contract
+ */
+ ModelInformation getModelInformation(String modelId);
}
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java
index 53061ae18df..cd3e94e6da4 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java
@@ -2182,6 +2182,20 @@ public class ClusterConfigTaskExecutor implements
IConfigTaskExecutor {
Expression whereExpression = analysis.getWhereExpression();
String queryFilter = whereExpression == null ? null :
whereExpression.getExpressionString();
+ Map<String, String> modelConfigs = createModelStatement.getAttributes();
+ if (!modelConfigs.containsKey("input_type_list")) {
+ String inputTypeListStr =
analysis.getRespDatasetHeader().getRespDataTypeList().toString();
+ modelConfigs.put(
+ "input_type_list", inputTypeListStr.substring(1,
inputTypeListStr.length() - 1));
+ }
+ if (!modelConfigs.containsKey("predict_index_list")) {
+ StringBuilder predictIndexListStr = new StringBuilder("0");
+ for (int i = 1; i <
analysis.getRespDatasetHeader().getOutputValueColumnCount(); i++) {
+ predictIndexListStr.append(",").append(i);
+ }
+ modelConfigs.put("predict_index_list", predictIndexListStr.toString());
+ }
+
SettableFuture<ConfigTaskResult> future = SettableFuture.create();
try (ConfigNodeClient client =
CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) {
@@ -2192,7 +2206,7 @@ public class ClusterConfigTaskExecutor implements
IConfigTaskExecutor {
createModelReq.setIsAuto(createModelStatement.isAuto());
createModelReq.setQueryExpressions(queryExpressions);
createModelReq.setQueryFilter(queryFilter);
- createModelReq.setModelConfigs(createModelStatement.getAttributes());
+ createModelReq.setModelConfigs(modelConfigs);
final TSStatus executionStatus = client.createModel(createModelReq);
if (TSStatusCode.SUCCESS_STATUS.getStatusCode() !=
executionStatus.getCode()) {
LOGGER.warn(
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java
index af515d13179..771047ce7fd 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java
@@ -23,6 +23,7 @@ import org.apache.iotdb.commons.conf.IoTDBConstant;
import org.apache.iotdb.commons.path.PartialPath;
import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
import org.apache.iotdb.commons.udf.builtin.BuiltinScalarFunction;
+import org.apache.iotdb.commons.udf.builtin.ModelInferenceFunction;
import org.apache.iotdb.db.mpp.common.NodeRef;
import org.apache.iotdb.db.mpp.plan.expression.Expression;
import org.apache.iotdb.db.mpp.plan.expression.ExpressionType;
@@ -105,6 +106,8 @@ public class FunctionExpression extends Expression {
functionType = FunctionType.AGGREGATION_FUNCTION;
} else if
(BuiltinScalarFunction.getNativeFunctionNames().contains(functionName)) {
functionType = FunctionType.BUILT_IN_SCALAR_FUNCTION;
+ } else if
(ModelInferenceFunction.getNativeFunctionNames().contains(functionName)) {
+ functionType = FunctionType.MODEL_INFERENCE_FUNCTION;
} else {
functionType = FunctionType.UDF;
}
@@ -125,6 +128,13 @@ public class FunctionExpression extends Expression {
return functionType == FunctionType.BUILT_IN_SCALAR_FUNCTION;
}
+  /**
+   * Tells whether this expression calls a built-in model inference function. The function type is
+   * resolved on first use and stored in {@code functionType}.
+   */
+  public boolean isModelInferenceFunction() {
+    if (functionType != null) {
+      return functionType == FunctionType.MODEL_INFERENCE_FUNCTION;
+    }
+    initializeFunctionType();
+    return FunctionType.MODEL_INFERENCE_FUNCTION == functionType;
+  }
+
@Override
public boolean isConstantOperandInternal() {
if (isConstantOperandCache == null) {
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
index 734ebb4bef4..c7e9d4ade2c 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
@@ -23,5 +23,6 @@ package org.apache.iotdb.db.mpp.plan.expression.multi;
+/** Category of the function invoked by a {@link FunctionExpression}. */
public enum FunctionType {
AGGREGATION_FUNCTION,
BUILT_IN_SCALAR_FUNCTION,
- UDF
+ UDF,
+ // Built-in model inference functions (names listed by ModelInferenceFunction, e.g. FORECAST);
+ // FunctionExpression checks for these before falling back to UDF.
+ MODEL_INFERENCE_FUNCTION
}
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanBuilder.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanBuilder.java
index f2d19fe9025..178ee26a6fa 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanBuilder.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanBuilder.java
@@ -72,6 +72,7 @@ import
org.apache.iotdb.db.mpp.plan.planner.plan.node.process.SortNode;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TimeJoinNode;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TransformNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ml.ForecastNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedLastQueryScanNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedSeriesAggregationScanNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedSeriesScanNode;
@@ -89,6 +90,7 @@ import
org.apache.iotdb.db.mpp.plan.planner.plan.parameter.GroupByParameter;
import
org.apache.iotdb.db.mpp.plan.planner.plan.parameter.GroupByTimeParameter;
import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.IntoPathDescriptor;
import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.OrderByParameter;
+import
org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model.ForecastModelInferenceDescriptor;
import org.apache.iotdb.db.mpp.plan.statement.component.OrderByKey;
import org.apache.iotdb.db.mpp.plan.statement.component.Ordering;
import org.apache.iotdb.db.mpp.plan.statement.component.SortItem;
@@ -1328,4 +1330,12 @@ public class LogicalPlanBuilder {
}
return this;
}
+
+  /**
+   * Puts a {@link ForecastNode} on top of the current plan root; the previous root becomes its
+   * child and the forecast node becomes the new root.
+   *
+   * @param forecastModelInferenceDescriptor forecast settings carried by the new node
+   * @return this builder, for chaining
+   */
+  public LogicalPlanBuilder planForecast(
+      ForecastModelInferenceDescriptor forecastModelInferenceDescriptor) {
+    root =
+        new ForecastNode(
+            context.getQueryId().genPlanNodeId(), this.root, forecastModelInferenceDescriptor);
+    return this;
+  }
}
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanVisitor.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanVisitor.java
index 14a49611613..8fc5defed6a 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanVisitor.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanVisitor.java
@@ -48,6 +48,8 @@ import
org.apache.iotdb.db.mpp.plan.planner.plan.node.write.InsertRowsNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.write.InsertRowsOfOneDeviceNode;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.write.InsertTabletNode;
import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.AggregationStep;
+import
org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model.ForecastModelInferenceDescriptor;
+import
org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model.ModelInferenceDescriptor;
import org.apache.iotdb.db.mpp.plan.statement.StatementNode;
import org.apache.iotdb.db.mpp.plan.statement.StatementVisitor;
import org.apache.iotdb.db.mpp.plan.statement.component.Ordering;
@@ -208,6 +210,21 @@ public class LogicalPlanVisitor extends
StatementVisitor<PlanNode, MPPQueryConte
.planOffset(queryStatement.getRowOffset())
.planLimit(queryStatement.getRowLimit());
+ if (queryStatement.isModelInferenceQuery()) {
+ ModelInferenceDescriptor modelInferenceDescriptor =
analysis.getModelInferenceDescriptor();
+ switch (modelInferenceDescriptor.getFunctionType()) {
+ case FORECAST:
+ ForecastModelInferenceDescriptor forecastModelInferenceDescriptor =
+ (ForecastModelInferenceDescriptor) modelInferenceDescriptor;
+ planBuilder
+
.planLimit(forecastModelInferenceDescriptor.getModelInputLength())
+ .planForecast(forecastModelInferenceDescriptor);
+ break;
+ default:
+ throw new IllegalArgumentException();
+ }
+ }
+
// plan select into
if (queryStatement.isAlignByDevice()) {
planBuilder =
planBuilder.planDeviceViewInto(analysis.getDeviceViewIntoPathDescriptor());
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
index 9273acd3361..2ad6289520b 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
@@ -104,6 +104,7 @@ import
org.apache.iotdb.db.mpp.execution.operator.process.last.LastQuerySortOper
import org.apache.iotdb.db.mpp.execution.operator.process.last.LastQueryUtil;
import
org.apache.iotdb.db.mpp.execution.operator.process.last.UpdateLastCacheOperator;
import
org.apache.iotdb.db.mpp.execution.operator.process.last.UpdateViewPathLastCacheOperator;
+import org.apache.iotdb.db.mpp.execution.operator.process.ml.ForecastOperator;
import
org.apache.iotdb.db.mpp.execution.operator.schema.CountGroupByLevelMergeOperator;
import
org.apache.iotdb.db.mpp.execution.operator.schema.CountGroupByLevelScanOperator;
import org.apache.iotdb.db.mpp.execution.operator.schema.CountMergeOperator;
@@ -180,6 +181,7 @@ import
org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TransformNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryCollectNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryMergeNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ml.ForecastNode;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.ShuffleSinkNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedLastQueryScanNode;
@@ -203,6 +205,7 @@ import
org.apache.iotdb.db.mpp.plan.planner.plan.parameter.InputLocation;
import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.IntoPathDescriptor;
import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.OutputColumn;
import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.SeriesScanOptions;
+import
org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model.ForecastModelInferenceDescriptor;
import org.apache.iotdb.db.mpp.plan.statement.component.FillPolicy;
import org.apache.iotdb.db.mpp.plan.statement.component.OrderByKey;
import org.apache.iotdb.db.mpp.plan.statement.component.Ordering;
@@ -216,6 +219,8 @@ import org.apache.iotdb.db.utils.datastructure.TimeSelector;
import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
import org.apache.iotdb.tsfile.read.TimeValuePair;
import org.apache.iotdb.tsfile.read.common.block.TsBlockBuilder;
+import org.apache.iotdb.tsfile.read.common.block.column.DoubleColumn;
+import org.apache.iotdb.tsfile.read.common.block.column.TimeColumn;
import org.apache.iotdb.tsfile.read.filter.basic.Filter;
import org.apache.iotdb.tsfile.read.filter.operator.Gt;
import org.apache.iotdb.tsfile.read.filter.operator.GtEq;
@@ -244,6 +249,7 @@ import static
com.google.common.base.Preconditions.checkArgument;
import static org.apache.iotdb.db.mpp.common.DataNodeEndPoints.isSameNode;
import static
org.apache.iotdb.db.mpp.execution.operator.AggregationUtil.calculateMaxAggregationResultSize;
import static
org.apache.iotdb.db.mpp.execution.operator.AggregationUtil.calculateMaxAggregationResultSizeForLastQuery;
+import static
org.apache.iotdb.db.mpp.execution.operator.AggregationUtil.getOutputColumnSizePerLine;
import static
org.apache.iotdb.db.mpp.execution.operator.AggregationUtil.initTimeRangeIterator;
import static
org.apache.iotdb.db.mpp.plan.planner.plan.parameter.SeriesScanOptions.updateFilterUsingTTL;
@@ -1637,6 +1643,49 @@ public class OperatorTreeGenerator extends
PlanVisitor<Operator, LocalExecutionP
MergeSortComparator.getComparator(sortItemList, sortItemIndexList,
sortItemDataTypeList));
}
+  /**
+   * Builds the {@link ForecastOperator} for a {@link ForecastNode}, together with the memory
+   * estimates it is constructed with.
+   */
+  @Override
+  public Operator visitForecast(ForecastNode node, LocalExecutionPlanContext context) {
+    // Build the child operator before registering this operator's context, keeping the original
+    // call order of context.getNextOperatorId().
+    Operator childOperator = node.getChild().accept(this, context);
+    OperatorContext forecastOperatorContext =
+        context
+            .getDriverContext()
+            .addOperatorContext(
+                context.getNextOperatorId(),
+                node.getPlanNodeId(),
+                ForecastOperator.class.getSimpleName());
+
+    ForecastModelInferenceDescriptor descriptor = node.getModelInferenceDescriptor();
+    List<TSDataType> inputTypes = descriptor.getInputTypeList();
+    int predictLength = descriptor.getExpectedPredictLength();
+
+    // maxRetainedSize: modelInputLength rows of a time column plus one column per input type.
+    long maxRetainedSize = estimateInputRowSize(inputTypes) * descriptor.getModelInputLength();
+    // maxReturnSize: predictLength rows of a time column plus one DOUBLE column per predict index.
+    long maxReturnSize =
+        (TimeColumn.SIZE_IN_BYTES_PER_POSITION
+                + (long) descriptor.getPredictIndexList().size()
+                    * DoubleColumn.SIZE_IN_BYTES_PER_POSITION)
+            * predictLength;
+
+    context.getTimeSliceAllocator().recordExecutionWeight(forecastOperatorContext, 1);
+
+    return new ForecastOperator(
+        forecastOperatorContext,
+        childOperator,
+        descriptor.getModelPath(),
+        inputTypes,
+        node.getChild().getOutputColumnNames(),
+        predictLength,
+        FragmentInstanceManager.getInstance().getModelInferenceExecutor(),
+        maxRetainedSize,
+        maxReturnSize);
+  }
+
+  /**
+   * Estimates the bytes per row of a block holding one time column plus one value column for each
+   * of {@code inputTypes}.
+   */
+  private static long estimateInputRowSize(List<TSDataType> inputTypes) {
+    long rowSize = TimeColumn.SIZE_IN_BYTES_PER_POSITION;
+    for (TSDataType inputType : inputTypes) {
+      rowSize += getOutputColumnSizePerLine(inputType, new PartialPath());
+    }
+    return rowSize;
+  }
+
@Override
public Operator visitInto(IntoNode node, LocalExecutionPlanContext context) {
Operator child = node.getChild().accept(this, context);
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanGraphPrinter.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanGraphPrinter.java
index a9726d0c2f7..2ad21bb98f3 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanGraphPrinter.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanGraphPrinter.java
@@ -45,6 +45,7 @@ import
org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TransformNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryCollectNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryMergeNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ml.ForecastNode;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.ShuffleSinkNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedLastQueryScanNode;
@@ -443,6 +444,17 @@ public class PlanGraphPrinter extends
PlanVisitor<List<String>, PlanGraphPrinter
return render(node, boxValue, context);
}
+  /** Renders a {@link ForecastNode} box: its id followed by one line per output column. */
+  @Override
+  public List<String> visitForecast(ForecastNode node, GraphContext context) {
+    List<String> boxLines = new ArrayList<>();
+    boxLines.add(String.format("Forecast-%s", node.getPlanNodeId().getId()));
+    boxLines.add("Output: ");
+    node.getOutputColumnNames()
+        .forEach(columnName -> boxLines.add(String.format(" %s", columnName)));
+    return render(node, boxLines, context);
+  }
+
private String printRegion(TRegionReplicaSet regionReplicaSet) {
return String.format(
"Partition: %s",
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanNodeType.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanNodeType.java
index f0ff085eded..77250c97e70 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanNodeType.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanNodeType.java
@@ -78,6 +78,7 @@ import
org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TransformNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryCollectNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryMergeNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ml.ForecastNode;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.ShuffleSinkNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedLastQueryScanNode;
@@ -177,7 +178,8 @@ public enum PlanNodeType {
CONSTRUCT_LOGICAL_VIEW_BLACK_LIST((short) 74),
ROLLBACK_LOGICAL_VIEW_BLACK_LIST((short) 75),
DELETE_LOGICAL_VIEW((short) 76),
- LOGICAL_VIEW_SCHEMA_SCAN((short) 77);
+ LOGICAL_VIEW_SCHEMA_SCAN((short) 77),
+ FORECAST((short) 78);
public static final int BYTES = Short.BYTES;
@@ -380,6 +382,8 @@ public enum PlanNodeType {
return DeleteLogicalViewNode.deserialize(buffer);
case 77:
return LogicalViewSchemaScanNode.deserialize(buffer);
+ case 78:
+ return ForecastNode.deserialize(buffer);
default:
throw new IllegalArgumentException("Invalid node type: " + nodeType);
}
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanVisitor.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanVisitor.java
index d2955839a8c..486dbabfb29 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanVisitor.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanVisitor.java
@@ -77,6 +77,7 @@ import
org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TransformNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryCollectNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryMergeNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ml.ForecastNode;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.ShuffleSinkNode;
import
org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedLastQueryScanNode;
@@ -190,6 +191,10 @@ public abstract class PlanVisitor<R, C> {
return visitSingleChildProcess(node, context);
}
+ /** Visits a {@link ForecastNode}; by default it is handled like any single-child process node. */
+ public R visitForecast(ForecastNode node, C context) {
+ return visitSingleChildProcess(node, context);
+ }
+
// multi child
--------------------------------------------------------------------------------
public R visitMultiChildProcess(MultiChildProcessNode node, C context) {
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/ml/ForecastNode.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/ml/ForecastNode.java
new file mode 100644
index 00000000000..47809354641
--- /dev/null
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/ml/ForecastNode.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ml;
+
+import org.apache.iotdb.db.mpp.plan.expression.Expression;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeType;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanVisitor;
+import
org.apache.iotdb.db.mpp.plan.planner.plan.node.process.SingleChildProcessNode;
+import
org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model.ForecastModelInferenceDescriptor;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Single-child process node that runs a forecast model over its child's output.
+ *
+ * <p>All forecast settings live in the {@link ForecastModelInferenceDescriptor}; the output column
+ * names are the string forms of the descriptor's model inference output expressions.
+ */
+public class ForecastNode extends SingleChildProcessNode {
+
+  private final ForecastModelInferenceDescriptor modelInferenceDescriptor;
+
+  /** Cached result of {@link #getOutputColumnNames()}, built on the first call. */
+  private List<String> outputColumnNames;
+
+  public ForecastNode(
+      PlanNodeId id, PlanNode child, ForecastModelInferenceDescriptor modelInferenceDescriptor) {
+    super(id, child);
+    this.modelInferenceDescriptor = modelInferenceDescriptor;
+  }
+
+  /** Creates a node with no child attached yet, e.g. when deserializing. */
+  public ForecastNode(PlanNodeId id, ForecastModelInferenceDescriptor modelInferenceDescriptor) {
+    super(id);
+    this.modelInferenceDescriptor = modelInferenceDescriptor;
+  }
+
+  public ForecastModelInferenceDescriptor getModelInferenceDescriptor() {
+    return modelInferenceDescriptor;
+  }
+
+  @Override
+  public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
+    return visitor.visitForecast(this, context);
+  }
+
+  /** Copies this node; the copy keeps the same child and shares the same descriptor instance. */
+  @Override
+  public PlanNode clone() {
+    return new ForecastNode(getPlanNodeId(), child, modelInferenceDescriptor);
+  }
+
+  @Override
+  public List<String> getOutputColumnNames() {
+    if (outputColumnNames != null) {
+      return outputColumnNames;
+    }
+    outputColumnNames = new ArrayList<>();
+    for (Expression outputExpression :
+        modelInferenceDescriptor.getModelInferenceOutputExpressions()) {
+      outputColumnNames.add(outputExpression.toString());
+    }
+    return outputColumnNames;
+  }
+
+  /** Writes the FORECAST type tag followed by the descriptor. */
+  @Override
+  protected void serializeAttributes(ByteBuffer byteBuffer) {
+    PlanNodeType.FORECAST.serialize(byteBuffer);
+    modelInferenceDescriptor.serialize(byteBuffer);
+  }
+
+  /** Stream counterpart of {@link #serializeAttributes(ByteBuffer)}. */
+  @Override
+  protected void serializeAttributes(DataOutputStream stream) throws IOException {
+    PlanNodeType.FORECAST.serialize(stream);
+    modelInferenceDescriptor.serialize(stream);
+  }
+
+  /**
+   * Rebuilds a node from {@code buffer}, reading the descriptor first and then the plan node id.
+   * The returned node has no child attached.
+   */
+  public static ForecastNode deserialize(ByteBuffer buffer) {
+    ForecastModelInferenceDescriptor descriptor =
+        ForecastModelInferenceDescriptor.deserialize(buffer);
+    PlanNodeId nodeId = PlanNodeId.deserialize(buffer);
+    return new ForecastNode(nodeId, descriptor);
+  }
+
+  @Override
+  public String toString() {
+    return "ForecastNode-" + this.getPlanNodeId();
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass() || !super.equals(o)) {
+      return false;
+    }
+    ForecastNode other = (ForecastNode) o;
+    return modelInferenceDescriptor.equals(other.modelInferenceDescriptor);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(super.hashCode(), modelInferenceDescriptor);
+  }
+}
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java
new file mode 100644
index 00000000000..d4aff88e07b
--- /dev/null
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model;
+
+import org.apache.iotdb.commons.model.ForecastModeInformation;
+import org.apache.iotdb.commons.udf.builtin.ModelInferenceFunction;
+import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Objects;
+
+import static org.apache.iotdb.db.constant.SqlConstant.MODEL_ID;
+import static org.apache.iotdb.db.constant.SqlConstant.PREDICT_LENGTH;
+
+public class ForecastModelInferenceDescriptor extends ModelInferenceDescriptor
{
+
+ private final List<TSDataType> inputTypeList;
+ private final List<Integer> predictIndexList;
+
+ private final int modelInputLength;
+ private final int modelPredictLength;
+ private int expectedPredictLength;
+
+ private LinkedHashMap<String, String> outputAttributes;
+
+ public ForecastModelInferenceDescriptor(
+ ModelInferenceFunction functionType, ForecastModeInformation
modelInformation) {
+ super(functionType, modelInformation);
+ this.inputTypeList = modelInformation.getInputTypeList();
+ this.predictIndexList = modelInformation.getPredictIndexList();
+ this.modelInputLength = modelInformation.getInputLength();
+ this.expectedPredictLength = this.modelPredictLength =
modelInformation.getPredictLength();
+ }
+
+ public ForecastModelInferenceDescriptor(ByteBuffer buffer) {
+ super(buffer);
+ int listSize = ReadWriteIOUtils.readInt(buffer);
+ this.inputTypeList = new ArrayList<>(listSize);
+ for (int i = 0; i < listSize; i++) {
+ this.inputTypeList.add(TSDataType.deserializeFrom(buffer));
+ }
+ listSize = ReadWriteIOUtils.readInt(buffer);
+ this.predictIndexList = new ArrayList<>(listSize);
+ for (int i = 0; i < listSize; i++) {
+ this.predictIndexList.add(ReadWriteIOUtils.readInt(buffer));
+ }
+ this.modelInputLength = ReadWriteIOUtils.readInt(buffer);
+ this.modelPredictLength = ReadWriteIOUtils.readInt(buffer);
+ this.expectedPredictLength = ReadWriteIOUtils.readInt(buffer);
+ }
+
+ public List<Integer> getPredictIndexList() {
+ return predictIndexList;
+ }
+
+ public List<TSDataType> getInputTypeList() {
+ return inputTypeList;
+ }
+
+ public int getModelInputLength() {
+ return modelInputLength;
+ }
+
+ public int getModelPredictLength() {
+ return modelPredictLength;
+ }
+
+ public int getExpectedPredictLength() {
+ return expectedPredictLength;
+ }
+
+ public void setExpectedPredictLength(int expectedPredictLength) {
+ this.expectedPredictLength = expectedPredictLength;
+ }
+
+ @Override
+ public LinkedHashMap<String, String> getOutputAttributes() {
+ if (outputAttributes == null) {
+ outputAttributes = new LinkedHashMap<>();
+ outputAttributes.put(MODEL_ID, modelId);
+ if (expectedPredictLength != modelPredictLength) {
+ outputAttributes.put(PREDICT_LENGTH,
String.valueOf(expectedPredictLength));
+ }
+ }
+ return outputAttributes;
+ }
+
+ @Override
+ public void serialize(ByteBuffer byteBuffer) {
+ super.serialize(byteBuffer);
+ ReadWriteIOUtils.write(inputTypeList.size(), byteBuffer);
+ for (TSDataType dataType : inputTypeList) {
+ dataType.serializeTo(byteBuffer);
+ }
+ ReadWriteIOUtils.write(predictIndexList.size(), byteBuffer);
+ for (Integer index : predictIndexList) {
+ ReadWriteIOUtils.write(index, byteBuffer);
+ }
+ ReadWriteIOUtils.write(modelInputLength, byteBuffer);
+ ReadWriteIOUtils.write(modelPredictLength, byteBuffer);
+ ReadWriteIOUtils.write(expectedPredictLength, byteBuffer);
+ }
+
+ @Override
+ public void serialize(DataOutputStream stream) throws IOException {
+ super.serialize(stream);
+ ReadWriteIOUtils.write(inputTypeList.size(), stream);
+ for (TSDataType dataType : inputTypeList) {
+ dataType.serializeTo(stream);
+ }
+ ReadWriteIOUtils.write(predictIndexList.size(), stream);
+ for (Integer index : predictIndexList) {
+ ReadWriteIOUtils.write(index, stream);
+ }
+ ReadWriteIOUtils.write(modelInputLength, stream);
+ ReadWriteIOUtils.write(modelPredictLength, stream);
+ ReadWriteIOUtils.write(expectedPredictLength, stream);
+ }
+
+ public static ForecastModelInferenceDescriptor deserialize(ByteBuffer
buffer) {
+ return new ForecastModelInferenceDescriptor(buffer);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ if (!super.equals(o)) {
+ return false;
+ }
+ ForecastModelInferenceDescriptor that = (ForecastModelInferenceDescriptor)
o;
+ return modelInputLength == that.modelInputLength
+ && modelPredictLength == that.modelPredictLength
+ && expectedPredictLength == that.expectedPredictLength
+ && inputTypeList.equals(that.inputTypeList)
+ && predictIndexList.equals(that.predictIndexList);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(
+ super.hashCode(),
+ inputTypeList,
+ predictIndexList,
+ modelInputLength,
+ modelPredictLength,
+ expectedPredictLength);
+ }
+}
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java
new file mode 100644
index 00000000000..1948ebeb983
--- /dev/null
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model;
+
+import org.apache.iotdb.commons.model.ModelInformation;
+import org.apache.iotdb.commons.udf.builtin.ModelInferenceFunction;
+import org.apache.iotdb.db.mpp.plan.expression.multi.FunctionExpression;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Objects;
+
+public abstract class ModelInferenceDescriptor {
+
+ protected final ModelInferenceFunction functionType;
+
+ protected final String modelId;
+
+ protected final String modelPath;
+
+ protected List<FunctionExpression> modelInferenceOutputExpressions;
+
+ public ModelInferenceDescriptor(
+ ModelInferenceFunction functionType, ModelInformation modelInformation) {
+ this.functionType = functionType;
+ this.modelId = modelInformation.getModelId();
+ this.modelPath = modelInformation.getModelPath();
+ }
+
+ public ModelInferenceDescriptor(ByteBuffer buffer) {
+ this.functionType =
ModelInferenceFunction.values()[ReadWriteIOUtils.readInt(buffer)];
+ this.modelId = ReadWriteIOUtils.readString(buffer);
+ this.modelPath = ReadWriteIOUtils.readString(buffer);
+ }
+
+ public ModelInferenceFunction getFunctionType() {
+ return functionType;
+ }
+
+ public String getModelId() {
+ return modelId;
+ }
+
+ public String getModelPath() {
+ return modelPath;
+ }
+
+ public List<FunctionExpression> getModelInferenceOutputExpressions() {
+ return modelInferenceOutputExpressions;
+ }
+
+ public void setModelInferenceOutputExpressions(
+ List<FunctionExpression> modelInferenceOutputExpressions) {
+ this.modelInferenceOutputExpressions = modelInferenceOutputExpressions;
+ }
+
+ public abstract LinkedHashMap<String, String> getOutputAttributes();
+
+ public void serialize(ByteBuffer byteBuffer) {
+ ReadWriteIOUtils.write(functionType.ordinal(), byteBuffer);
+ ReadWriteIOUtils.write(modelId, byteBuffer);
+ ReadWriteIOUtils.write(modelPath, byteBuffer);
+ }
+
+ public void serialize(DataOutputStream stream) throws IOException {
+ ReadWriteIOUtils.write(functionType.ordinal(), stream);
+ ReadWriteIOUtils.write(modelId, stream);
+ ReadWriteIOUtils.write(modelPath, stream);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ ModelInferenceDescriptor that = (ModelInferenceDescriptor) o;
+ return functionType == that.functionType
+ && modelId.equals(that.modelId)
+ && modelPath.equals(that.modelPath)
+ &&
modelInferenceOutputExpressions.equals(that.modelInferenceOutputExpressions);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(functionType, modelId, modelPath,
modelInferenceOutputExpressions);
+ }
+}
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/component/ResultColumn.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/component/ResultColumn.java
index 38523b11012..8887589f991 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/component/ResultColumn.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/component/ResultColumn.java
@@ -127,6 +127,7 @@ public class ResultColumn extends StatementNode {
public enum ColumnType {
RAW,
AGGREGATION,
- CONSTANT
+ CONSTANT,
+ // Column whose expression is a model-inference call; SelectComponent.addResultColumn uses it to
+ // flag the statement as a model-inference query.
+ MODEL_INFERENCE
}
}
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/component/SelectComponent.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/component/SelectComponent.java
index 74ac2c554c1..ef5e0853888 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/component/SelectComponent.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/component/SelectComponent.java
@@ -36,6 +36,8 @@ public class SelectComponent extends StatementNode {
private boolean hasBuiltInAggregationFunction = false;
+ private boolean hasModelInferenceFunction = false;
+
protected List<ResultColumn> resultColumns = new ArrayList<>();
private Map<String, Expression> aliasToColumnMap;
@@ -48,15 +50,21 @@ public class SelectComponent extends StatementNode {
return zoneId;
}
- public boolean isHasBuiltInAggregationFunction() {
+ /** Returns whether an AGGREGATION result column has been added via addResultColumn. */
+ public boolean hasBuiltInAggregationFunction() {
return hasBuiltInAggregationFunction;
}
+  /** Reports whether a MODEL_INFERENCE result column has been added via addResultColumn. */
+  public boolean hasModelInferenceFunction() {
+    return this.hasModelInferenceFunction;
+  }
+
+ /**
+ * Appends a result column and, based on its column type, records whether the select list now
+ * contains a built-in aggregation or a model-inference function.
+ */
public void addResultColumn(ResultColumn resultColumn) {
resultColumns.add(resultColumn);
ResultColumn.ColumnType columnType = resultColumn.getColumnType();
if (columnType == ResultColumn.ColumnType.AGGREGATION) {
hasBuiltInAggregationFunction = true;
+ } else if (columnType == ResultColumn.ColumnType.MODEL_INFERENCE) {
+ hasModelInferenceFunction = true;
}
}
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/crud/QueryStatement.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/crud/QueryStatement.java
index 3b7fe31c517..ba6a3b3c68a 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/crud/QueryStatement.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/crud/QueryStatement.java
@@ -51,6 +51,8 @@ import java.util.HashSet;
import java.util.List;
import java.util.Set;
+import static org.apache.iotdb.db.constant.SqlConstant.MODEL_ID;
+
/**
* Base class of SELECT statement.
*
@@ -276,7 +278,7 @@ public class QueryStatement extends Statement {
}
+ /** Returns whether the select list contains a built-in aggregation function. */
public boolean isAggregationQuery() {
- return selectComponent.isHasBuiltInAggregationFunction();
+ return selectComponent.hasBuiltInAggregationFunction();
}
public boolean isGroupByLevel() {
@@ -431,8 +433,7 @@ public class QueryStatement extends Statement {
List<SortItem> sortItems = getSortItemList();
List<SortItem> newSortItems = new ArrayList<>();
int expressionIndex = 0;
- for (int i = 0; i < sortItems.size(); i++) {
- SortItem sortItem = sortItems.get(i);
+ for (SortItem sortItem : sortItems) {
SortItem newSortItem =
new SortItem(sortItem.getSortKey(), sortItem.getOrdering(),
sortItem.getNullOrdering());
if (sortItem.isExpression()) {
@@ -480,7 +481,52 @@ public class QueryStatement extends Statement {
return useWildcard;
}
+  /** Tells whether the select list contains a model-inference function call. */
+  public boolean isModelInferenceQuery() {
+    return this.selectComponent.hasModelInferenceFunction();
+  }
+
public void semanticCheck() {
+ if (isModelInferenceQuery()) {
+ if (selectComponent.getResultColumns().size() > 1) {
+ throw new SemanticException("");
+ }
+
+ Expression modelInferenceExpression =
+ selectComponent.getResultColumns().get(0).getExpression();
+ if (!(modelInferenceExpression instanceof FunctionExpression
+ && ((FunctionExpression)
modelInferenceExpression).isModelInferenceFunction())) {
+ throw new SemanticException("");
+ }
+ if (!((FunctionExpression) modelInferenceExpression)
+ .getFunctionAttributes()
+ .containsKey(MODEL_ID)) {
+ throw new SemanticException("");
+ }
+ if
(ExpressionAnalyzer.searchAggregationExpressions(modelInferenceExpression).size()
> 0) {
+ throw new SemanticException("");
+ }
+
+ if (hasHaving()
+ || isGroupBy()
+ || isGroupByLevel()
+ || isGroupByTag()
+ || isAlignByDevice()
+ || isLastQuery()
+ || seriesLimit > 0
+ || seriesOffset > 0
+ || isSelectInto()
+ || isOrderByDevice()
+ || isOrderByTimeseries()) {
+ throw new SemanticException("");
+ }
+
+ if (orderByComponent != null
+ && (!orderByComponent.isOrderByTime()
+ || orderByComponent.getTimeOrder() != Ordering.ASC)) {
+ throw new SemanticException("");
+ }
+ }
+
if (isAggregationQuery()) {
if (disableAlign()) {
throw new SemanticException("AGGREGATION doesn't support disable align
clause.");
diff --git
a/server/src/test/java/org/apache/iotdb/db/mpp/plan/analyze/FakePartitionFetcherImpl.java
b/server/src/test/java/org/apache/iotdb/db/mpp/plan/analyze/FakePartitionFetcherImpl.java
index b15229ca779..01fb0795ac8 100644
---
a/server/src/test/java/org/apache/iotdb/db/mpp/plan/analyze/FakePartitionFetcherImpl.java
+++
b/server/src/test/java/org/apache/iotdb/db/mpp/plan/analyze/FakePartitionFetcherImpl.java
@@ -26,6 +26,7 @@ import org.apache.iotdb.common.rpc.thrift.TEndPoint;
import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet;
import org.apache.iotdb.common.rpc.thrift.TSeriesPartitionSlot;
import org.apache.iotdb.common.rpc.thrift.TTimePartitionSlot;
+import org.apache.iotdb.commons.model.ModelInformation;
import org.apache.iotdb.commons.partition.DataPartition;
import org.apache.iotdb.commons.partition.DataPartitionQueryParam;
import org.apache.iotdb.commons.partition.SchemaNodeManagementPartition;
@@ -298,4 +299,9 @@ public class FakePartitionFetcherImpl implements
IPartitionFetcher {
@Override
public void invalidAllCache() {}
+
+ // Stub: this fake registers no models, so every lookup yields null.
+ @Override
+ public ModelInformation getModelInformation(String modelId) {
+ return null;
+ }
}
diff --git
a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/Util.java
b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/Util.java
index a8aa156554c..a167129114f 100644
---
a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/Util.java
+++
b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/Util.java
@@ -27,6 +27,7 @@ import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet;
import org.apache.iotdb.common.rpc.thrift.TSeriesPartitionSlot;
import org.apache.iotdb.common.rpc.thrift.TTimePartitionSlot;
import org.apache.iotdb.commons.exception.IllegalPathException;
+import org.apache.iotdb.commons.model.ModelInformation;
import org.apache.iotdb.commons.partition.DataPartition;
import org.apache.iotdb.commons.partition.DataPartitionQueryParam;
import org.apache.iotdb.commons.partition.SchemaNodeManagementPartition;
@@ -395,6 +396,11 @@ public class Util {
@Override
public void invalidAllCache() {}
+
+ // Stub: this test fetcher registers no models, so every lookup yields null.
+ @Override
+ public ModelInformation getModelInformation(String modelId) {
+ return null;
+ }
};
}
diff --git
a/tsfile/src/main/java/org/apache/iotdb/tsfile/file/metadata/enums/TSDataType.java
b/tsfile/src/main/java/org/apache/iotdb/tsfile/file/metadata/enums/TSDataType.java
index 583a6cd5a24..4c686d2a24a 100644
---
a/tsfile/src/main/java/org/apache/iotdb/tsfile/file/metadata/enums/TSDataType.java
+++
b/tsfile/src/main/java/org/apache/iotdb/tsfile/file/metadata/enums/TSDataType.java
@@ -21,7 +21,9 @@ package org.apache.iotdb.tsfile.file.metadata.enums;
import org.apache.iotdb.tsfile.exception.write.UnSupportedDataTypeException;
import java.io.DataOutputStream;
+import java.io.EOFException;
+import java.io.FileOutputStream;
import java.io.IOException;
+import java.io.InputStream;
import java.nio.ByteBuffer;
public enum TSDataType {
@@ -96,6 +98,10 @@ public enum TSDataType {
return deserialize(buffer.get());
}
+ public static TSDataType deserializeFrom(InputStream stream) throws
IOException {
+ return deserialize((byte) stream.read());
+ }
+
public static int getSerializedSize() {
return Byte.BYTES;
}
@@ -108,6 +114,10 @@ public enum TSDataType {
outputStream.write(serialize());
}
+ public void serializeTo(FileOutputStream outputStream) throws IOException {
+ outputStream.write(serialize());
+ }
+
public int getDataTypeSize() {
switch (this) {
case BOOLEAN:
diff --git
a/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlock.java
b/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlock.java
index 9f37f4b780d..46ea0b667c9 100644
---
a/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlock.java
+++
b/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlock.java
@@ -147,6 +147,19 @@ public class TsBlock {
return wrapBlocksWithoutCopy(positionCount, timeColumn, newBlocks);
}
+ public TsBlock appendValueColumns(Column[] columns) {
+ Column[] newBlocks = Arrays.copyOf(valueColumns, valueColumns.length +
columns.length);
+ int newColumnIndex = valueColumns.length;
+ for (Column column : columns) {
+ requireNonNull(column, "Column is null");
+ if (positionCount != column.getPositionCount()) {
+ throw new IllegalArgumentException("Block does not have same position
count");
+ }
+ newBlocks[newColumnIndex++] = column;
+ }
+ return wrapBlocksWithoutCopy(positionCount, timeColumn, newBlocks);
+ }
+
/**
* Attention. This method uses System.arraycopy() to extend the valueColumn
array, so its
* performance is not ensured if you have many insert operations.
diff --git
a/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlockBuilder.java
b/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlockBuilder.java
index c309835a09f..cab5e90e3e0 100644
---
a/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlockBuilder.java
+++
b/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlockBuilder.java
@@ -81,6 +81,7 @@ public class TsBlockBuilder {
res.timeColumnBuilder =
new TimeColumnBuilder(
res.tsBlockBuilderStatus.createColumnBuilderStatus(),
DEFAULT_INITIAL_EXPECTED_ENTRIES);
+ res.valueColumnBuilders = new ColumnBuilder[0];
return res;
}
@@ -285,9 +286,6 @@ public class TsBlockBuilder {
}
public TsBlock build() {
- if (valueColumnBuilders.length == 0) {
- return new TsBlock(declaredPositions);
- }
TimeColumn timeColumn = (TimeColumn) timeColumnBuilder.build();
if (timeColumn.getPositionCount() != declaredPositions) {
throw new IllegalStateException(