This is an automated email from the ASF dual-hosted git repository.
hui pushed a commit to branch lmh/forecast
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/lmh/forecast by this push:
new f6ee9360f08 fix iotdb
f6ee9360f08 is described below
commit f6ee9360f087ed710b022039a393e49759879632
Author: Minghui Liu <[email protected]>
AuthorDate: Mon May 22 10:42:01 2023 +0800
fix iotdb
---
.../consensus/request/ConfigPhysicalPlan.java | 4 +
.../consensus/request/ConfigPhysicalPlanType.java | 1 +
.../request/read/model/GetModelInfoPlan.java | 79 ++++++++++++
.../consensus/response/model/GetModelInfoResp.java | 44 +++++++
.../response/{ => model}/ModelTableResp.java | 2 +-
.../response/{ => model}/TrailTableResp.java | 2 +-
.../iotdb/confignode/manager/ConfigManager.java | 10 ++
.../apache/iotdb/confignode/manager/IManager.java | 5 +
.../iotdb/confignode/manager/ModelManager.java | 77 ++++++++++--
.../iotdb/confignode/persistence/ModelInfo.java | 32 ++++-
.../persistence/executor/ConfigPlanExecutor.java | 3 +
.../thrift/ConfigNodeRPCServiceProcessor.java | 7 ++
.../commons/model/ForecastModeInformation.java | 139 +++++++++++++++++++++
.../iotdb/commons/model/ModelInformation.java | 48 +++++--
.../apache/iotdb/db/client/ConfigNodeClient.java | 18 +++
.../operator/process/ml/ForecastOperator.java | 21 +++-
.../iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java | 12 +-
.../mpp/plan/analyze/ClusterPartitionFetcher.java | 16 ++-
.../config/executor/ClusterConfigTaskExecutor.java | 16 ++-
.../db/mpp/plan/planner/OperatorTreeGenerator.java | 2 +
.../model/ForecastModelInferenceDescriptor.java | 28 ++---
.../src/main/thrift/confignode.thrift | 14 +++
.../tsfile/file/metadata/enums/TSDataType.java | 10 ++
.../iotdb/tsfile/read/common/block/TsBlock.java | 13 ++
.../tsfile/read/common/block/TsBlockBuilder.java | 4 +-
25 files changed, 548 insertions(+), 59 deletions(-)
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java
index 2991f05ea65..51eef38d4f4 100644
---
a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java
@@ -24,6 +24,7 @@ import
org.apache.iotdb.confignode.consensus.request.read.database.CountDatabase
import
org.apache.iotdb.confignode.consensus.request.read.database.GetDatabasePlan;
import
org.apache.iotdb.confignode.consensus.request.read.datanode.GetDataNodeConfigurationPlan;
import
org.apache.iotdb.confignode.consensus.request.read.function.GetFunctionTablePlan;
+import
org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan;
import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan;
import org.apache.iotdb.confignode.consensus.request.read.model.ShowTrailPlan;
import
org.apache.iotdb.confignode.consensus.request.read.partition.CountTimeSlotListPlan;
@@ -456,6 +457,9 @@ public abstract class ConfigPhysicalPlan implements
IConsensusRequest {
case ShowTrail:
plan = new ShowTrailPlan();
break;
+ case GetModelInfo:
+ plan = new GetModelInfoPlan();
+ break;
case CreatePipePlugin:
plan = new CreatePipePluginPlan();
break;
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlanType.java
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlanType.java
index d36a9973cd5..9adace96172 100644
---
a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlanType.java
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlanType.java
@@ -162,6 +162,7 @@ public enum ConfigPhysicalPlanType {
DropModel((short) 1203),
ShowModel((short) 1204),
ShowTrail((short) 1205),
+ GetModelInfo((short) 1206),
/** Pipe Plugin */
CreatePipePlugin((short) 1300),
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java
new file mode 100644
index 00000000000..79c7e9e0142
--- /dev/null
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.confignode.consensus.request.read.model;
+
+import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan;
+import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Objects;
+
+public class GetModelInfoPlan extends ConfigPhysicalPlan {
+
+ private String modelId;
+
+ public GetModelInfoPlan() {
+ super(ConfigPhysicalPlanType.GetModelInfo);
+ }
+
+ public GetModelInfoPlan(TGetModelInfoReq getModelInfoReq) {
+ super(ConfigPhysicalPlanType.GetModelInfo);
+ this.modelId = getModelInfoReq.getModelId();
+ }
+
+ public String getModelId() {
+ return modelId;
+ }
+
+ @Override
+ protected void serializeImpl(DataOutputStream stream) throws IOException {
+ stream.writeShort(getType().getPlanType());
+ ReadWriteIOUtils.write(modelId, stream);
+ }
+
+ @Override
+ protected void deserializeImpl(ByteBuffer buffer) throws IOException {
+ this.modelId = ReadWriteIOUtils.readString(buffer);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ if (!super.equals(o)) {
+ return false;
+ }
+ GetModelInfoPlan that = (GetModelInfoPlan) o;
+ return Objects.equals(modelId, that.modelId);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(super.hashCode(), modelId);
+ }
+}
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java
new file mode 100644
index 00000000000..5e7ee2641fa
--- /dev/null
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.confignode.consensus.response.model;
+
+import org.apache.iotdb.common.rpc.thrift.TSStatus;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
+import org.apache.iotdb.consensus.common.DataSet;
+
+import java.nio.ByteBuffer;
+
+public class GetModelInfoResp implements DataSet {
+
+ private final TSStatus status;
+ private ByteBuffer serializedModelInformation;
+
+ public GetModelInfoResp(TSStatus status) {
+ this.status = status;
+ }
+
+ public void setModelInfo(ByteBuffer serializedModelInformation) {
+ this.serializedModelInformation = serializedModelInformation;
+ }
+
+ public TGetModelInfoResp convertToThriftResponse() {
+ return new TGetModelInfoResp(status, serializedModelInformation);
+ }
+}
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ModelTableResp.java
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java
similarity index 97%
rename from
confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ModelTableResp.java
rename to
confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java
index 6642f76be97..9a23d9ed713 100644
---
a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ModelTableResp.java
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java
@@ -17,7 +17,7 @@
* under the License.
*/
-package org.apache.iotdb.confignode.consensus.response;
+package org.apache.iotdb.confignode.consensus.response.model;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
import org.apache.iotdb.commons.model.ModelInformation;
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/TrailTableResp.java
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/TrailTableResp.java
similarity index 97%
rename from
confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/TrailTableResp.java
rename to
confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/TrailTableResp.java
index 1f9a6b5acbf..3f2c2823dac 100644
---
a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/TrailTableResp.java
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/TrailTableResp.java
@@ -17,7 +17,7 @@
* under the License.
*/
-package org.apache.iotdb.confignode.consensus.response;
+package org.apache.iotdb.confignode.consensus.response.model;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
import org.apache.iotdb.commons.model.TrailInformation;
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
b/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
index 4d9a4cdc8c8..7028240dd9e 100644
---
a/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
@@ -131,6 +131,8 @@ import
org.apache.iotdb.confignode.rpc.thrift.TGetDataNodeLocationsResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq;
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetRegionIdReq;
@@ -1850,6 +1852,14 @@ public class ConfigManager implements IManager {
: status;
}
+ @Override
+ public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) {
+ TSStatus status = confirmLeader();
+ return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()
+ ? modelManager.getModelInfo(req)
+ : new TGetModelInfoResp(status, null);
+ }
+
@Override
public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) {
TSStatus status = confirmLeader();
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java
b/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java
index 4a305f822bc..57c8ba1c1b9 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java
@@ -74,6 +74,8 @@ import
org.apache.iotdb.confignode.rpc.thrift.TGetDataNodeLocationsResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq;
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetRegionIdReq;
@@ -666,6 +668,9 @@ public interface IManager {
/** Update the model state */
TSStatus updateModelState(TUpdateModelStateReq req);
+ /** Get the information of the model specified by the model id in the request */
+ TGetModelInfoResp getModelInfo(TGetModelInfoReq req);
+
/** Set space quota */
TSStatus setSpaceQuota(TSetSpaceQuotaReq req);
}
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
b/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
index 425bafbd4c0..9a7c1080d96 100644
---
a/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
@@ -19,17 +19,23 @@
package org.apache.iotdb.confignode.manager;
+import org.apache.iotdb.common.rpc.thrift.ModelTask;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
+import org.apache.iotdb.commons.model.ForecastModeInformation;
import org.apache.iotdb.commons.model.ModelInformation;
+import
org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan;
import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan;
import org.apache.iotdb.confignode.consensus.request.read.model.ShowTrailPlan;
import
org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan;
import
org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelStatePlan;
-import org.apache.iotdb.confignode.consensus.response.ModelTableResp;
-import org.apache.iotdb.confignode.consensus.response.TrailTableResp;
+import org.apache.iotdb.confignode.consensus.response.model.GetModelInfoResp;
+import org.apache.iotdb.confignode.consensus.response.model.ModelTableResp;
+import org.apache.iotdb.confignode.consensus.response.model.TrailTableResp;
import org.apache.iotdb.confignode.persistence.ModelInfo;
import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq;
import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
import org.apache.iotdb.confignode.rpc.thrift.TShowModelReq;
import org.apache.iotdb.confignode.rpc.thrift.TShowModelResp;
import org.apache.iotdb.confignode.rpc.thrift.TShowTrailReq;
@@ -39,12 +45,17 @@ import
org.apache.iotdb.confignode.rpc.thrift.TUpdateModelStateReq;
import org.apache.iotdb.consensus.common.response.ConsensusReadResponse;
import org.apache.iotdb.consensus.common.response.ConsensusWriteResponse;
import org.apache.iotdb.rpc.TSStatusCode;
+import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
+import java.util.Arrays;
import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
public class ModelManager {
@@ -63,15 +74,41 @@ public class ModelManager {
}
public TSStatus createModel(TCreateModelReq req) {
- ModelInformation modelInformation =
- new ModelInformation(
- req.getModelId(),
- req.getModelTask(),
- req.getModelType(),
- req.isIsAuto(),
- req.getQueryExpressions(),
- req.getQueryFilter());
- return configManager.getProcedureManager().createModel(modelInformation,
req.getModelConfigs());
+ ModelTask modelTask = req.getModelTask();
+ Map<String, String> modelConfigs = req.getModelConfigs();
+ ModelInformation modelInformation;
+ switch (modelTask) {
+ case FORECAST:
+ String inputTypeListStr = modelConfigs.get("input_type_list");
+ List<TSDataType> inputTypeList =
+ Arrays.stream(inputTypeListStr.split(","))
+ .sequential()
+ .map(s -> TSDataType.valueOf(s.toUpperCase()))
+ .collect(Collectors.toList());
+
+ String predictIndexListStr = modelConfigs.get("predict_index_list");
+ List<Integer> predictIndexList =
+ Arrays.stream(predictIndexListStr.split(","))
+ .sequential()
+ .map(Integer::valueOf)
+ .collect(Collectors.toList());
+
+ modelInformation =
+ new ForecastModeInformation(
+ req.getModelId(),
+ req.getModelType(),
+ req.isIsAuto(),
+ req.getQueryExpressions(),
+ req.getQueryFilter(),
+ inputTypeList,
+ predictIndexList,
+ Integer.parseInt(modelConfigs.getOrDefault("input_length",
"96")),
+ Integer.parseInt(modelConfigs.getOrDefault("predict_length",
"96")));
+ break;
+ default:
+ throw new IllegalArgumentException("Invalid task type: " + modelTask);
+ }
+ return configManager.getProcedureManager().createModel(modelInformation,
modelConfigs);
}
public TSStatus dropModel(TDropModelReq req) {
@@ -126,7 +163,7 @@ public class ModelManager {
return new TShowModelResp(res, Collections.emptyList());
}
} catch (IOException e) {
- LOGGER.error("Fail to get ModelTable", e);
+ LOGGER.warn("Fail to get ModelTable", e);
return new TShowModelResp(
new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
.setMessage(e.getMessage()),
@@ -148,11 +185,25 @@ public class ModelManager {
return new TShowTrailResp(res, Collections.emptyList());
}
} catch (IOException e) {
- LOGGER.error("Fail to get TrailTable", e);
+ LOGGER.warn("Fail to get TrailTable", e);
return new TShowTrailResp(
new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
.setMessage(e.getMessage()),
Collections.emptyList());
}
}
+
+ public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) {
+ ConsensusReadResponse response =
+ configManager.getConsensusManager().read(new GetModelInfoPlan(req));
+ if (response.getDataset() != null) {
+ return ((GetModelInfoResp)
response.getDataset()).convertToThriftResponse();
+ } else {
+ LOGGER.warn("Unexpected error happened while getting model: ",
response.getException());
+ // consensus layer related errors
+ TSStatus res = new
TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode());
+ res.setMessage(response.getException().toString());
+ return new TGetModelInfoResp(res, null);
+ }
+ }
}
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
index 3c72e09570f..1e14bb6b137 100644
---
a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
@@ -24,15 +24,18 @@ import org.apache.iotdb.commons.model.ModelInformation;
import org.apache.iotdb.commons.model.ModelTable;
import org.apache.iotdb.commons.model.TrailInformation;
import org.apache.iotdb.commons.snapshot.SnapshotProcessor;
+import
org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan;
import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan;
import org.apache.iotdb.confignode.consensus.request.read.model.ShowTrailPlan;
import
org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan;
import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan;
import
org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan;
import
org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelStatePlan;
-import org.apache.iotdb.confignode.consensus.response.ModelTableResp;
-import org.apache.iotdb.confignode.consensus.response.TrailTableResp;
+import org.apache.iotdb.confignode.consensus.response.model.GetModelInfoResp;
+import org.apache.iotdb.confignode.consensus.response.model.ModelTableResp;
+import org.apache.iotdb.confignode.consensus.response.model.TrailTableResp;
import org.apache.iotdb.rpc.TSStatusCode;
+import org.apache.iotdb.tsfile.utils.PublicBAOS;
import org.apache.thrift.TException;
import org.slf4j.Logger;
@@ -40,10 +43,12 @@ import org.slf4j.LoggerFactory;
import javax.annotation.concurrent.ThreadSafe;
+import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
+import java.nio.ByteBuffer;
import java.util.concurrent.locks.ReentrantLock;
@ThreadSafe
@@ -155,6 +160,29 @@ public class ModelInfo implements SnapshotProcessor {
}
}
+ public GetModelInfoResp getModelInfo(GetModelInfoPlan plan) {
+ acquireModelTableLock();
+ try {
+ GetModelInfoResp getModelInfoResp =
+ new GetModelInfoResp(new
TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
+ ModelInformation modelInformation =
modelTable.getModelInformationById(plan.getModelId());
+ if (modelInformation != null) {
+ PublicBAOS buffer = new PublicBAOS();
+ DataOutputStream stream = new DataOutputStream(buffer);
+ modelInformation.serialize(stream);
+ getModelInfoResp.setModelInfo(ByteBuffer.wrap(buffer.getBuf(), 0,
buffer.size()));
+ }
+ return getModelInfoResp;
+ } catch (IOException e) {
+ LOGGER.warn("Fail to get model info", e);
+ return new GetModelInfoResp(
+ new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
+ .setMessage(e.getMessage()));
+ } finally {
+ releaseModelTableLock();
+ }
+ }
+
public TSStatus updateModelInfo(UpdateModelInfoPlan plan) {
acquireModelTableLock();
try {
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java
b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java
index 2f812b173b2..44527d67817 100644
---
a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java
@@ -29,6 +29,7 @@ import
org.apache.iotdb.confignode.consensus.request.auth.AuthorPlan;
import
org.apache.iotdb.confignode.consensus.request.read.database.CountDatabasePlan;
import
org.apache.iotdb.confignode.consensus.request.read.database.GetDatabasePlan;
import
org.apache.iotdb.confignode.consensus.request.read.datanode.GetDataNodeConfigurationPlan;
+import
org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan;
import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan;
import org.apache.iotdb.confignode.consensus.request.read.model.ShowTrailPlan;
import
org.apache.iotdb.confignode.consensus.request.read.partition.CountTimeSlotListPlan;
@@ -278,6 +279,8 @@ public class ConfigPlanExecutor {
return modelInfo.showModel((ShowModelPlan) req);
case ShowTrail:
return modelInfo.showTrail((ShowTrailPlan) req);
+ case GetModelInfo:
+ return modelInfo.getModelInfo((GetModelInfoPlan) req);
case GetPipePluginTable:
return pipeInfo.getPipePluginInfo().showPipePlugins();
case GetPipePluginJar:
diff --git
a/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java
b/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java
index 22367d7c164..1aa9a29f95f 100644
---
a/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java
+++
b/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java
@@ -110,6 +110,8 @@ import
org.apache.iotdb.confignode.rpc.thrift.TGetDataNodeLocationsResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq;
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPipeSinkReq;
@@ -990,6 +992,11 @@ public class ConfigNodeRPCServiceProcessor implements
IConfigNodeRPCService.Ifac
return configManager.updateModelState(req);
}
+ @Override
+ public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) {
+ return configManager.getModelInfo(req);
+ }
+
@Override
public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) throws TException {
return configManager.setSpaceQuota(req);
diff --git
a/node-commons/src/main/java/org/apache/iotdb/commons/model/ForecastModeInformation.java
b/node-commons/src/main/java/org/apache/iotdb/commons/model/ForecastModeInformation.java
new file mode 100644
index 00000000000..f32a0d5190f
--- /dev/null
+++
b/node-commons/src/main/java/org/apache/iotdb/commons/model/ForecastModeInformation.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.commons.model;
+
+import org.apache.iotdb.common.rpc.thrift.ModelTask;
+import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
+
+import javax.annotation.Nullable;
+
+import java.io.DataOutputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+public class ForecastModeInformation extends ModelInformation {
+
+ private final List<TSDataType> inputTypeList;
+
+ private final List<Integer> predictIndexList;
+
+ private final int inputLength;
+ private final int predictLength;
+
+ public ForecastModeInformation(
+ String modelId,
+ String modelType,
+ boolean isAuto,
+ List<String> queryExpressions,
+ @Nullable String queryFilter,
+ List<TSDataType> inputTypeList,
+ List<Integer> predictIndexList,
+ int inputLength,
+ int predictLength) {
+ super(ModelTask.FORECAST, modelId, modelType, isAuto, queryExpressions,
queryFilter);
+ this.inputTypeList = inputTypeList;
+ this.predictIndexList = predictIndexList;
+ this.inputLength = inputLength;
+ this.predictLength = predictLength;
+ }
+
+ public ForecastModeInformation(ByteBuffer buffer) {
+ super(ModelTask.FORECAST, buffer);
+ int listSize = ReadWriteIOUtils.readInt(buffer);
+ this.inputTypeList = new ArrayList<>(listSize);
+ for (int i = 0; i < listSize; i++) {
+ inputTypeList.add(TSDataType.deserializeFrom(buffer));
+ }
+ listSize = ReadWriteIOUtils.readInt(buffer);
+ this.predictIndexList = new ArrayList<>(listSize);
+ for (int i = 0; i < listSize; i++) {
+ predictIndexList.add(ReadWriteIOUtils.readInt(buffer));
+ }
+ this.inputLength = ReadWriteIOUtils.readInt(buffer);
+ this.predictLength = ReadWriteIOUtils.readInt(buffer);
+ }
+
+ public ForecastModeInformation(InputStream stream) throws IOException {
+ super(ModelTask.FORECAST, stream);
+ int listSize = ReadWriteIOUtils.readInt(stream);
+ this.inputTypeList = new ArrayList<>(listSize);
+ for (int i = 0; i < listSize; i++) {
+ inputTypeList.add(TSDataType.deserializeFrom(stream));
+ }
+ listSize = ReadWriteIOUtils.readInt(stream);
+ this.predictIndexList = new ArrayList<>(listSize);
+ for (int i = 0; i < listSize; i++) {
+ predictIndexList.add(ReadWriteIOUtils.readInt(stream));
+ }
+ this.inputLength = ReadWriteIOUtils.readInt(stream);
+ this.predictLength = ReadWriteIOUtils.readInt(stream);
+ }
+
+ public List<TSDataType> getInputTypeList() {
+ return inputTypeList;
+ }
+
+ public List<Integer> getPredictIndexList() {
+ return predictIndexList;
+ }
+
+ public int getInputLength() {
+ return inputLength;
+ }
+
+ public int getPredictLength() {
+ return predictLength;
+ }
+
+ @Override
+ public void serialize(DataOutputStream stream) throws IOException {
+ super.serialize(stream);
+ ReadWriteIOUtils.write(inputTypeList.size(), stream);
+ for (TSDataType inputType : inputTypeList) {
+ inputType.serializeTo(stream);
+ }
+ ReadWriteIOUtils.write(predictIndexList.size(), stream);
+ for (Integer index : predictIndexList) {
+ ReadWriteIOUtils.write(index, stream);
+ }
+ ReadWriteIOUtils.write(inputLength, stream);
+ ReadWriteIOUtils.write(predictLength, stream);
+ }
+
+ @Override
+ public void serialize(FileOutputStream stream) throws IOException {
+ super.serialize(stream);
+ ReadWriteIOUtils.write(inputTypeList.size(), stream);
+ for (TSDataType inputType : inputTypeList) {
+ inputType.serializeTo(stream);
+ }
+ ReadWriteIOUtils.write(predictIndexList.size(), stream);
+ for (Integer index : predictIndexList) {
+ ReadWriteIOUtils.write(index, stream);
+ }
+ ReadWriteIOUtils.write(inputLength, stream);
+ ReadWriteIOUtils.write(predictLength, stream);
+ }
+}
diff --git
a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
index 7c46a6c3437..fc0cdb45dde 100644
---
a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
+++
b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
@@ -39,7 +39,7 @@ import java.util.Map;
import static org.apache.iotdb.commons.model.TrailInformation.MODEL_PATH;
-public class ModelInformation {
+public abstract class ModelInformation {
private final String modelId;
private final ModelTask modelTask;
@@ -55,8 +55,8 @@ public class ModelInformation {
private final Map<String, TrailInformation> trailMap;
public ModelInformation(
- String modelId,
ModelTask modelTask,
+ String modelId,
String modelType,
boolean isAuto,
List<String> queryExpressions,
@@ -71,9 +71,10 @@ public class ModelInformation {
this.trailMap = new HashMap<>();
}
- public ModelInformation(ByteBuffer buffer) {
+ public ModelInformation(ModelTask modelTask, ByteBuffer buffer) {
+ this.modelTask = modelTask;
+
this.modelId = ReadWriteIOUtils.readString(buffer);
- this.modelTask = ModelTask.findByValue(ReadWriteIOUtils.readInt(buffer));
this.modelType = ReadWriteIOUtils.readString(buffer);
int listSize = ReadWriteIOUtils.readInt(buffer);
@@ -103,9 +104,10 @@ public class ModelInformation {
}
}
- public ModelInformation(InputStream stream) throws IOException {
+ public ModelInformation(ModelTask modelTask, InputStream stream) throws
IOException {
+ this.modelTask = modelTask;
+
this.modelId = ReadWriteIOUtils.readString(stream);
- this.modelTask = ModelTask.findByValue(ReadWriteIOUtils.readInt(stream));
this.modelType = ReadWriteIOUtils.readString(stream);
int listSize = ReadWriteIOUtils.readInt(stream);
@@ -200,8 +202,9 @@ public class ModelInformation {
}
public void serialize(DataOutputStream stream) throws IOException {
- ReadWriteIOUtils.write(modelId, stream);
ReadWriteIOUtils.write(modelTask.ordinal(), stream);
+
+ ReadWriteIOUtils.write(modelId, stream);
ReadWriteIOUtils.write(modelType, stream);
ReadWriteIOUtils.write(queryExpressions.size(), stream);
for (String queryExpression : queryExpressions) {
@@ -232,8 +235,9 @@ public class ModelInformation {
}
public void serialize(FileOutputStream stream) throws IOException {
- ReadWriteIOUtils.write(modelId, stream);
ReadWriteIOUtils.write(modelTask.ordinal(), stream);
+
+ ReadWriteIOUtils.write(modelId, stream);
ReadWriteIOUtils.write(modelType, stream);
ReadWriteIOUtils.write(queryExpressions.size(), stream);
@@ -264,12 +268,32 @@ public class ModelInformation {
}
}
- public static ModelInformation deserialize(InputStream stream) throws
IOException {
- return new ModelInformation(stream);
+ public static ModelInformation deserialize(ByteBuffer buffer) {
+ ModelTask modelTask =
ModelTask.findByValue(ReadWriteIOUtils.readInt(buffer));
+ if (modelTask == null) {
+ throw new IllegalArgumentException();
+ }
+
+ switch (modelTask) {
+ case FORECAST:
+ return new ForecastModeInformation(buffer);
+ default:
+ throw new IllegalArgumentException("Invalid task type: " + modelTask);
+ }
}
- public static ModelInformation deserialize(ByteBuffer buffer) {
- return new ModelInformation(buffer);
+ public static ModelInformation deserialize(InputStream stream) throws
IOException {
+ ModelTask modelTask =
ModelTask.findByValue(ReadWriteIOUtils.readInt(stream));
+ if (modelTask == null) {
+ throw new IllegalArgumentException();
+ }
+
+ switch (modelTask) {
+ case FORECAST:
+ return new ForecastModeInformation(stream);
+ default:
+ throw new IllegalArgumentException("Invalid task type: " + modelTask);
+ }
}
public ByteBuffer serializeShowModelResult() throws IOException {
diff --git
a/server/src/main/java/org/apache/iotdb/db/client/ConfigNodeClient.java
b/server/src/main/java/org/apache/iotdb/db/client/ConfigNodeClient.java
index ca35b446fc3..8cae68b3d12 100644
--- a/server/src/main/java/org/apache/iotdb/db/client/ConfigNodeClient.java
+++ b/server/src/main/java/org/apache/iotdb/db/client/ConfigNodeClient.java
@@ -79,6 +79,8 @@ import
org.apache.iotdb.confignode.rpc.thrift.TGetDataNodeLocationsResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq;
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPipeSinkReq;
@@ -2122,6 +2124,22 @@ public class ConfigNodeClient implements
IConfigNodeRPCService.Iface, ThriftClie
throw new TException(new UnsupportedOperationException().getCause());
}
+ @Override
+ public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) throws
TException {
+ for (int i = 0; i < RETRY_NUM; i++) {
+ try {
+ TGetModelInfoResp resp = client.getModelInfo(req);
+ if (!updateConfigNodeLeader(resp.getStatus())) {
+ return resp;
+ }
+ } catch (TException e) {
+ configLeader = null;
+ }
+ waitAndReconnect();
+ }
+ throw new TException(MSG_RECONNECTION_FAIL);
+ }
+
@Override
public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) throws TException {
for (int i = 0; i < RETRY_NUM; i++) {
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/process/ml/ForecastOperator.java
b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/process/ml/ForecastOperator.java
index 5aac6830fbc..b7375985942 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/process/ml/ForecastOperator.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/process/ml/ForecastOperator.java
@@ -20,6 +20,7 @@
package org.apache.iotdb.db.mpp.execution.operator.process.ml;
import org.apache.iotdb.db.client.MLNodeClient;
+import org.apache.iotdb.db.conf.IoTDBDescriptor;
import org.apache.iotdb.db.exception.ModelInferenceProcessException;
import org.apache.iotdb.db.mpp.execution.operator.Operator;
import org.apache.iotdb.db.mpp.execution.operator.OperatorContext;
@@ -152,7 +153,10 @@ public class ForecastOperator implements ProcessOperator {
}
finished = true;
- return new
TsBlockSerde().deserialize(forecastResp.bufferForForecastResult());
+ TsBlock resultTsBlock =
+ new
TsBlockSerde().deserialize(forecastResp.bufferForForecastResult());
+ resultTsBlock = modifyTimeColumn(resultTsBlock);
+ return resultTsBlock;
} catch (InterruptedException e) {
LOGGER.warn(
"{}: interrupted when processing write operation future with
exception {}", this, e);
@@ -164,6 +168,21 @@ public class ForecastOperator implements ProcessOperator {
}
}
+ private TsBlock modifyTimeColumn(TsBlock resultTsBlock) {
+ long delta =
+
IoTDBDescriptor.getInstance().getConfig().getTimestampPrecision().equals("ms")
+ ? 1_000_000L
+ : 1_000L;
+
+ TsBlockBuilder newTsBlockBuilder =
TsBlockBuilder.createWithOnlyTimeColumn();
+ TimeColumnBuilder timeColumnBuilder =
newTsBlockBuilder.getTimeColumnBuilder();
+ for (int i = 0; i < resultTsBlock.getPositionCount(); i++) {
+ timeColumnBuilder.writeLong(resultTsBlock.getTimeByIndex(i) / delta);
+ newTsBlockBuilder.declarePosition();
+ }
+ return
newTsBlockBuilder.build().appendValueColumns(resultTsBlock.getValueColumns());
+ }
+
private void appendTsBlockToBuilder(TsBlock inputTsBlock) {
TimeColumnBuilder timeColumnBuilder =
inputTsBlockBuilder.getTimeColumnBuilder();
ColumnBuilder[] columnBuilders =
inputTsBlockBuilder.getValueColumnBuilders();
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java
index 5034007063e..1310dc93f0e 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java
@@ -24,6 +24,7 @@ import
org.apache.iotdb.commons.client.exception.ClientManagerException;
import org.apache.iotdb.commons.conf.IoTDBConstant;
import org.apache.iotdb.commons.exception.IllegalPathException;
import org.apache.iotdb.commons.exception.MetadataException;
+import org.apache.iotdb.commons.model.ForecastModeInformation;
import org.apache.iotdb.commons.model.ModelInformation;
import org.apache.iotdb.commons.partition.DataPartition;
import org.apache.iotdb.commons.partition.DataPartitionQueryParam;
@@ -383,8 +384,15 @@ public class AnalyzeVisitor extends
StatementVisitor<Analysis, MPPQueryContext>
ModelInferenceFunction.valueOf(modelInferenceExpression.getFunctionName().toUpperCase());
switch (functionType) {
case FORECAST:
- ModelInferenceDescriptor modelInferenceDescriptor =
- new ForecastModelInferenceDescriptor(functionType,
modelInformation);
+ ForecastModelInferenceDescriptor modelInferenceDescriptor =
+ new ForecastModelInferenceDescriptor(
+ functionType, (ForecastModeInformation) modelInformation);
+ Map<String, String> modelInferenceAttributes =
+ modelInferenceExpression.getFunctionAttributes();
+ if (modelInferenceAttributes.containsKey("predict_length")) {
+ modelInferenceDescriptor.setExpectedPredictLength(
+
Integer.parseInt(modelInferenceAttributes.get("predict_length")));
+ }
analysis.setModelInferenceDescriptor(modelInferenceDescriptor);
List<ResultColumn> newResultColumns = new ArrayList<>();
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ClusterPartitionFetcher.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ClusterPartitionFetcher.java
index dbc2e3068b9..98063c19b74 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ClusterPartitionFetcher.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ClusterPartitionFetcher.java
@@ -35,12 +35,12 @@ import
org.apache.iotdb.commons.partition.executor.SeriesPartitionExecutor;
import org.apache.iotdb.commons.path.PathPatternTree;
import org.apache.iotdb.confignode.rpc.thrift.TDataPartitionReq;
import org.apache.iotdb.confignode.rpc.thrift.TDataPartitionTableResp;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
+import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
import org.apache.iotdb.confignode.rpc.thrift.TSchemaNodeManagementReq;
import org.apache.iotdb.confignode.rpc.thrift.TSchemaNodeManagementResp;
import org.apache.iotdb.confignode.rpc.thrift.TSchemaPartitionReq;
import org.apache.iotdb.confignode.rpc.thrift.TSchemaPartitionTableResp;
-import org.apache.iotdb.confignode.rpc.thrift.TShowModelReq;
-import org.apache.iotdb.confignode.rpc.thrift.TShowModelResp;
import org.apache.iotdb.confignode.rpc.thrift.TTimeSlotList;
import org.apache.iotdb.db.client.ConfigNodeClient;
import org.apache.iotdb.db.client.ConfigNodeClientManager;
@@ -300,19 +300,17 @@ public class ClusterPartitionFetcher implements
IPartitionFetcher {
public ModelInformation getModelInformation(String modelId) {
try (ConfigNodeClient client =
configNodeClientManager.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID))
{
- TShowModelReq showModelReq = new TShowModelReq();
- showModelReq.setModelId(modelId);
- TShowModelResp showModelResp = client.showModel(showModelReq);
- if (showModelResp.getStatus().getCode() ==
TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
- if (showModelResp.modelInfoList.size() > 0) {
- return
ModelInformation.deserialize(showModelResp.modelInfoList.get(0));
+ TGetModelInfoResp getModelInfoResp = client.getModelInfo(new
TGetModelInfoReq(modelId));
+ if (getModelInfoResp.getStatus().getCode() ==
TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+ if (getModelInfoResp.modelInfo != null) {
+ return ModelInformation.deserialize(getModelInfoResp.modelInfo);
} else {
return null;
}
} else {
throw new StatementAnalyzeException(
"An error occurred when executing getModelInformation():"
- + showModelResp.getStatus().getMessage());
+ + getModelInfoResp.getStatus().getMessage());
}
} catch (ClientManagerException | TException e) {
throw new StatementAnalyzeException(
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java
index 3119043f3f6..92a8994cbe0 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java
@@ -1956,6 +1956,20 @@ public class ClusterConfigTaskExecutor implements
IConfigTaskExecutor {
Expression whereExpression = analysis.getWhereExpression();
String queryFilter = whereExpression == null ? null :
whereExpression.toString();
+ Map<String, String> modelConfigs = createModelStatement.getAttributes();
+ if (!modelConfigs.containsKey("input_type_list")) {
+ String inputTypeListStr =
analysis.getRespDatasetHeader().getRespDataTypeList().toString();
+ modelConfigs.put(
+ "input_type_list", inputTypeListStr.substring(1,
inputTypeListStr.length() - 1));
+ }
+ if (!modelConfigs.containsKey("predict_index_list")) {
+ StringBuilder predictIndexListStr = new StringBuilder("0");
+ for (int i = 1; i <
analysis.getRespDatasetHeader().getOutputValueColumnCount(); i++) {
+ predictIndexListStr.append(",").append(i);
+ }
+ modelConfigs.put("predict_index_list", predictIndexListStr.toString());
+ }
+
SettableFuture<ConfigTaskResult> future = SettableFuture.create();
try (ConfigNodeClient client =
CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) {
@@ -1966,7 +1980,7 @@ public class ClusterConfigTaskExecutor implements
IConfigTaskExecutor {
createModelReq.setIsAuto(createModelStatement.isAuto());
createModelReq.setQueryExpressions(queryExpressions);
createModelReq.setQueryFilter(queryFilter);
- createModelReq.setModelConfigs(createModelStatement.getAttributes());
+ createModelReq.setModelConfigs(modelConfigs);
final TSStatus executionStatus = client.createModel(createModelReq);
if (TSStatusCode.SUCCESS_STATUS.getStatusCode() !=
executionStatus.getCode()) {
LOGGER.warn(
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
index e40a87c976f..cc240b1e768 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
@@ -1678,6 +1678,8 @@ public class OperatorTreeGenerator extends
PlanVisitor<Operator, LocalExecutionP
+ (long) outputColumnNum *
DoubleColumn.SIZE_IN_BYTES_PER_POSITION)
* expectedPredictLength;
+ context.getTimeSliceAllocator().recordExecutionWeight(operatorContext, 1);
+
return new ForecastOperator(
operatorContext,
child,
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java
index 97c1a8cab1e..d4aff88e07b 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java
@@ -19,7 +19,7 @@
package org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model;
-import org.apache.iotdb.commons.model.ModelInformation;
+import org.apache.iotdb.commons.model.ForecastModeInformation;
import org.apache.iotdb.commons.udf.builtin.ModelInferenceFunction;
import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
@@ -37,18 +37,22 @@ import static
org.apache.iotdb.db.constant.SqlConstant.PREDICT_LENGTH;
public class ForecastModelInferenceDescriptor extends ModelInferenceDescriptor
{
- private List<TSDataType> inputTypeList;
- private List<Integer> predictIndexList;
+ private final List<TSDataType> inputTypeList;
+ private final List<Integer> predictIndexList;
- private int modelInputLength;
- private int modelPredictLength;
+ private final int modelInputLength;
+ private final int modelPredictLength;
private int expectedPredictLength;
private LinkedHashMap<String, String> outputAttributes;
public ForecastModelInferenceDescriptor(
- ModelInferenceFunction functionType, ModelInformation modelInformation) {
+ ModelInferenceFunction functionType, ForecastModeInformation
modelInformation) {
super(functionType, modelInformation);
+ this.inputTypeList = modelInformation.getInputTypeList();
+ this.predictIndexList = modelInformation.getPredictIndexList();
+ this.modelInputLength = modelInformation.getInputLength();
+ this.expectedPredictLength = this.modelPredictLength =
modelInformation.getPredictLength();
}
public ForecastModelInferenceDescriptor(ByteBuffer buffer) {
@@ -72,18 +76,10 @@ public class ForecastModelInferenceDescriptor extends
ModelInferenceDescriptor {
return predictIndexList;
}
- public void setPredictIndexList(List<Integer> predictIndexList) {
- this.predictIndexList = predictIndexList;
- }
-
public List<TSDataType> getInputTypeList() {
return inputTypeList;
}
- public void setInputTypeList(List<TSDataType> inputTypeList) {
- this.inputTypeList = inputTypeList;
- }
-
public int getModelInputLength() {
return modelInputLength;
}
@@ -96,6 +92,10 @@ public class ForecastModelInferenceDescriptor extends
ModelInferenceDescriptor {
return expectedPredictLength;
}
+ public void setExpectedPredictLength(int expectedPredictLength) {
+ this.expectedPredictLength = expectedPredictLength;
+ }
+
@Override
public LinkedHashMap<String, String> getOutputAttributes() {
if (outputAttributes == null) {
diff --git a/thrift-confignode/src/main/thrift/confignode.thrift
b/thrift-confignode/src/main/thrift/confignode.thrift
index 8fa78991448..1453a7a2434 100644
--- a/thrift-confignode/src/main/thrift/confignode.thrift
+++ b/thrift-confignode/src/main/thrift/confignode.thrift
@@ -731,6 +731,10 @@ struct TShowTrailReq {
2: optional string trailId
}
+struct TGetModelInfoReq {
+ 1: required string modelId
+}
+
struct TShowTrailResp {
1: required common.TSStatus status
2: required list<binary> trailInfoList
@@ -748,6 +752,11 @@ struct TUpdateModelStateReq {
3: optional string bestTrailId
}
+struct TGetModelInfoResp {
+ 1: required common.TSStatus status
+ 2: required binary modelInfo
+}
+
// ====================================================
// Quota
// ====================================================
@@ -1374,6 +1383,11 @@ service IConfigNodeRPCService {
*/
common.TSStatus updateModelState(TUpdateModelStateReq req)
+ /**
+ * Return the model info by model_id
+ */
+ TGetModelInfoResp getModelInfo(TGetModelInfoReq req)
+
// ======================================================
// Quota
// ======================================================
diff --git
a/tsfile/src/main/java/org/apache/iotdb/tsfile/file/metadata/enums/TSDataType.java
b/tsfile/src/main/java/org/apache/iotdb/tsfile/file/metadata/enums/TSDataType.java
index f4b88db27ca..f686a39c708 100644
---
a/tsfile/src/main/java/org/apache/iotdb/tsfile/file/metadata/enums/TSDataType.java
+++
b/tsfile/src/main/java/org/apache/iotdb/tsfile/file/metadata/enums/TSDataType.java
@@ -21,7 +21,9 @@ package org.apache.iotdb.tsfile.file.metadata.enums;
import org.apache.iotdb.tsfile.exception.write.UnSupportedDataTypeException;
import java.io.DataOutputStream;
+import java.io.FileOutputStream;
import java.io.IOException;
+import java.io.InputStream;
import java.nio.ByteBuffer;
public enum TSDataType {
@@ -91,6 +93,10 @@ public enum TSDataType {
return deserialize(buffer.get());
}
+ public static TSDataType deserializeFrom(InputStream stream) throws
IOException {
+ return deserialize((byte) stream.read());
+ }
+
public static int getSerializedSize() {
return Byte.BYTES;
}
@@ -103,6 +109,10 @@ public enum TSDataType {
outputStream.write(serialize());
}
+ public void serializeTo(FileOutputStream outputStream) throws IOException {
+ outputStream.write(serialize());
+ }
+
public int getDataTypeSize() {
switch (this) {
case BOOLEAN:
diff --git
a/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlock.java
b/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlock.java
index 9f37f4b780d..46ea0b667c9 100644
---
a/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlock.java
+++
b/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlock.java
@@ -147,6 +147,19 @@ public class TsBlock {
return wrapBlocksWithoutCopy(positionCount, timeColumn, newBlocks);
}
+ public TsBlock appendValueColumns(Column[] columns) {
+ Column[] newBlocks = Arrays.copyOf(valueColumns, valueColumns.length +
columns.length);
+ int newColumnIndex = valueColumns.length;
+ for (Column column : columns) {
+ requireNonNull(column, "Column is null");
+ if (positionCount != column.getPositionCount()) {
+ throw new IllegalArgumentException("Block does not have same position
count");
+ }
+ newBlocks[newColumnIndex++] = column;
+ }
+ return wrapBlocksWithoutCopy(positionCount, timeColumn, newBlocks);
+ }
+
/**
* Attention. This method uses System.arraycopy() to extend the valueColumn
array, so its
* performance is not ensured if you have many insert operations.
diff --git
a/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlockBuilder.java
b/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlockBuilder.java
index c309835a09f..cab5e90e3e0 100644
---
a/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlockBuilder.java
+++
b/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlockBuilder.java
@@ -81,6 +81,7 @@ public class TsBlockBuilder {
res.timeColumnBuilder =
new TimeColumnBuilder(
res.tsBlockBuilderStatus.createColumnBuilderStatus(),
DEFAULT_INITIAL_EXPECTED_ENTRIES);
+ res.valueColumnBuilders = new ColumnBuilder[0];
return res;
}
@@ -285,9 +286,6 @@ public class TsBlockBuilder {
}
public TsBlock build() {
- if (valueColumnBuilders.length == 0) {
- return new TsBlock(declaredPositions);
- }
TimeColumn timeColumn = (TimeColumn) timeColumnBuilder.build();
if (timeColumn.getPositionCount() != declaredPositions) {
throw new IllegalStateException(