This is an automated email from the ASF dual-hosted git repository. hui pushed a commit to branch lmh/forecastTest in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit a455bc6e0ccdd2b2b86a646b9b50f39be578e855 Author: Minghui Liu <[email protected]> AuthorDate: Mon May 22 10:42:01 2023 +0800 fix iotdb --- .../consensus/request/ConfigPhysicalPlan.java | 4 + .../consensus/request/ConfigPhysicalPlanType.java | 1 + .../request/read/model/GetModelInfoPlan.java | 79 ++++++++++++ .../consensus/response/model/GetModelInfoResp.java | 44 +++++++ .../response/{ => model}/ModelTableResp.java | 2 +- .../response/{ => model}/TrailTableResp.java | 2 +- .../iotdb/confignode/manager/ConfigManager.java | 10 ++ .../apache/iotdb/confignode/manager/IManager.java | 5 + .../iotdb/confignode/manager/ModelManager.java | 77 ++++++++++-- .../iotdb/confignode/persistence/ModelInfo.java | 32 ++++- .../persistence/executor/ConfigPlanExecutor.java | 3 + .../thrift/ConfigNodeRPCServiceProcessor.java | 7 ++ .../commons/model/ForecastModeInformation.java | 139 +++++++++++++++++++++ .../iotdb/commons/model/ModelInformation.java | 48 +++++-- .../apache/iotdb/db/client/ConfigNodeClient.java | 18 +++ .../operator/process/ml/ForecastOperator.java | 21 +++- .../iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java | 12 +- .../mpp/plan/analyze/ClusterPartitionFetcher.java | 16 ++- .../config/executor/ClusterConfigTaskExecutor.java | 16 ++- .../db/mpp/plan/planner/OperatorTreeGenerator.java | 2 + .../model/ForecastModelInferenceDescriptor.java | 28 ++--- .../src/main/thrift/confignode.thrift | 14 +++ .../tsfile/file/metadata/enums/TSDataType.java | 10 ++ .../iotdb/tsfile/read/common/block/TsBlock.java | 13 ++ .../tsfile/read/common/block/TsBlockBuilder.java | 4 +- 25 files changed, 548 insertions(+), 59 deletions(-) diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java index 2991f05ea65..51eef38d4f4 100644 --- a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java +++ b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java @@ -24,6 +24,7 @@ import org.apache.iotdb.confignode.consensus.request.read.database.CountDatabase import org.apache.iotdb.confignode.consensus.request.read.database.GetDatabasePlan; import org.apache.iotdb.confignode.consensus.request.read.datanode.GetDataNodeConfigurationPlan; import org.apache.iotdb.confignode.consensus.request.read.function.GetFunctionTablePlan; +import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan; import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan; import org.apache.iotdb.confignode.consensus.request.read.model.ShowTrailPlan; import org.apache.iotdb.confignode.consensus.request.read.partition.CountTimeSlotListPlan; @@ -456,6 +457,9 @@ public abstract class ConfigPhysicalPlan implements IConsensusRequest { case ShowTrail: plan = new ShowTrailPlan(); break; + case GetModelInfo: + plan = new GetModelInfoPlan(); + break; case CreatePipePlugin: plan = new CreatePipePluginPlan(); break; diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlanType.java b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlanType.java index d36a9973cd5..9adace96172 100644 --- a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlanType.java +++ b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlanType.java @@ -162,6 +162,7 @@ public enum ConfigPhysicalPlanType { DropModel((short) 1203), ShowModel((short) 1204), ShowTrail((short) 1205), + GetModelInfo((short) 1206), /** Pipe Plugin */ CreatePipePlugin((short) 1300), diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java new file mode 100644 index 00000000000..79c7e9e0142 --- /dev/null +++ b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.confignode.consensus.request.read.model; + +import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; +import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; +import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Objects; + +public class GetModelInfoPlan extends ConfigPhysicalPlan { + + private String modelId; + + public GetModelInfoPlan() { + super(ConfigPhysicalPlanType.GetModelInfo); + } + + public GetModelInfoPlan(TGetModelInfoReq getModelInfoReq) { + super(ConfigPhysicalPlanType.GetModelInfo); + this.modelId = getModelInfoReq.getModelId(); + } + + public String getModelId() { + return modelId; + } + + @Override + protected void serializeImpl(DataOutputStream stream) throws IOException { + stream.writeShort(getType().getPlanType()); + ReadWriteIOUtils.write(modelId, stream); + } + + @Override + protected void deserializeImpl(ByteBuffer buffer) throws IOException { + this.modelId = ReadWriteIOUtils.readString(buffer); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + if (!super.equals(o)) { + return false; + } + GetModelInfoPlan that = (GetModelInfoPlan) o; + return Objects.equals(modelId, that.modelId); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), modelId); + } +} diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java new file mode 100644 index 00000000000..5e7ee2641fa --- /dev/null +++ b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.confignode.consensus.response.model; + +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; +import org.apache.iotdb.consensus.common.DataSet; + +import java.nio.ByteBuffer; + +public class GetModelInfoResp implements DataSet { + + private final TSStatus status; + private ByteBuffer serializedModelInformation; + + public GetModelInfoResp(TSStatus status) { + this.status = status; + } + + public void setModelInfo(ByteBuffer serializedModelInformation) { + this.serializedModelInformation = serializedModelInformation; + } + + public TGetModelInfoResp convertToThriftResponse() { + return new TGetModelInfoResp(status, serializedModelInformation); + } +} diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ModelTableResp.java b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java similarity index 97% rename from confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ModelTableResp.java rename to confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java index 6642f76be97..9a23d9ed713 100644 --- a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ModelTableResp.java +++ b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java @@ -17,7 +17,7 @@ * under the License. */ -package org.apache.iotdb.confignode.consensus.response; +package org.apache.iotdb.confignode.consensus.response.model; import org.apache.iotdb.common.rpc.thrift.TSStatus; import org.apache.iotdb.commons.model.ModelInformation; diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/TrailTableResp.java b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/TrailTableResp.java similarity index 97% rename from confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/TrailTableResp.java rename to confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/TrailTableResp.java index 1f9a6b5acbf..3f2c2823dac 100644 --- a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/TrailTableResp.java +++ b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/TrailTableResp.java @@ -17,7 +17,7 @@ * under the License. */ -package org.apache.iotdb.confignode.consensus.response; +package org.apache.iotdb.confignode.consensus.response.model; import org.apache.iotdb.common.rpc.thrift.TSStatus; import org.apache.iotdb.commons.model.TrailInformation; diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java b/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java index 4d9a4cdc8c8..7028240dd9e 100644 --- a/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java +++ b/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java @@ -131,6 +131,8 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetDataNodeLocationsResp; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; import org.apache.iotdb.confignode.rpc.thrift.TGetRegionIdReq; @@ -1850,6 +1852,14 @@ public class ConfigManager implements IManager { : status; } + @Override + public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) { + TSStatus status = confirmLeader(); + return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() + ? modelManager.getModelInfo(req) + : new TGetModelInfoResp(status, null); + } + @Override public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) { TSStatus status = confirmLeader(); diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java b/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java index 4a305f822bc..57c8ba1c1b9 100644 --- a/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java +++ b/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java @@ -74,6 +74,8 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetDataNodeLocationsResp; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; import org.apache.iotdb.confignode.rpc.thrift.TGetRegionIdReq; @@ -666,6 +668,9 @@ public interface IManager { /** Update the model state */ TSStatus updateModelState(TUpdateModelStateReq req); + /** Update the model state */ + TGetModelInfoResp getModelInfo(TGetModelInfoReq req); + /** Set space quota */ TSStatus setSpaceQuota(TSetSpaceQuotaReq req); } diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java b/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java index 425bafbd4c0..9a7c1080d96 100644 --- a/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java +++ b/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java @@ -19,17 +19,23 @@ package org.apache.iotdb.confignode.manager; +import org.apache.iotdb.common.rpc.thrift.ModelTask; import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.commons.model.ForecastModeInformation; import org.apache.iotdb.commons.model.ModelInformation; +import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan; import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan; import org.apache.iotdb.confignode.consensus.request.read.model.ShowTrailPlan; import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelStatePlan; -import org.apache.iotdb.confignode.consensus.response.ModelTableResp; -import org.apache.iotdb.confignode.consensus.response.TrailTableResp; +import org.apache.iotdb.confignode.consensus.response.model.GetModelInfoResp; +import org.apache.iotdb.confignode.consensus.response.model.ModelTableResp; +import org.apache.iotdb.confignode.consensus.response.model.TrailTableResp; import org.apache.iotdb.confignode.persistence.ModelInfo; import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TShowModelReq; import org.apache.iotdb.confignode.rpc.thrift.TShowModelResp; import org.apache.iotdb.confignode.rpc.thrift.TShowTrailReq; @@ -39,12 +45,17 @@ import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelStateReq; import org.apache.iotdb.consensus.common.response.ConsensusReadResponse; import org.apache.iotdb.consensus.common.response.ConsensusWriteResponse; import org.apache.iotdb.rpc.TSStatusCode; +import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; +import java.util.Arrays; import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; public class ModelManager { @@ -63,15 +74,41 @@ public class ModelManager { } public TSStatus createModel(TCreateModelReq req) { - ModelInformation modelInformation = - new ModelInformation( - req.getModelId(), - req.getModelTask(), - req.getModelType(), - req.isIsAuto(), - req.getQueryExpressions(), - req.getQueryFilter()); - return configManager.getProcedureManager().createModel(modelInformation, req.getModelConfigs()); + ModelTask modelTask = req.getModelTask(); + Map<String, String> modelConfigs = req.getModelConfigs(); + ModelInformation modelInformation; + switch (modelTask) { + case FORECAST: + String inputTypeListStr = modelConfigs.get("input_type_list"); + List<TSDataType> inputTypeList = + Arrays.stream(inputTypeListStr.split(",")) + .sequential() + .map(s -> TSDataType.valueOf(s.toUpperCase())) + .collect(Collectors.toList()); + + String predictIndexListStr = modelConfigs.get("predict_index_list"); + List<Integer> predictIndexList = + Arrays.stream(predictIndexListStr.split(",")) + .sequential() + .map(Integer::valueOf) + .collect(Collectors.toList()); + + modelInformation = + new ForecastModeInformation( + req.getModelId(), + req.getModelType(), + req.isIsAuto(), + req.getQueryExpressions(), + req.getQueryFilter(), + inputTypeList, + predictIndexList, + Integer.parseInt(modelConfigs.getOrDefault("input_length", "96")), + Integer.parseInt(modelConfigs.getOrDefault("predict_length", "96"))); + break; + default: + throw new IllegalArgumentException("Invalid task type: " + modelTask); + } + return configManager.getProcedureManager().createModel(modelInformation, modelConfigs); } public TSStatus dropModel(TDropModelReq req) { @@ -126,7 +163,7 @@ public class ModelManager { return new TShowModelResp(res, Collections.emptyList()); } } catch (IOException e) { - LOGGER.error("Fail to get ModelTable", e); + LOGGER.warn("Fail to get ModelTable", e); return new TShowModelResp( new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()) .setMessage(e.getMessage()), @@ -148,11 +185,25 @@ public class ModelManager { return new TShowTrailResp(res, Collections.emptyList()); } } catch (IOException e) { - LOGGER.error("Fail to get TrailTable", e); + LOGGER.warn("Fail to get TrailTable", e); return new TShowTrailResp( new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()) .setMessage(e.getMessage()), Collections.emptyList()); } } + + public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) { + ConsensusReadResponse response = + configManager.getConsensusManager().read(new GetModelInfoPlan(req)); + if (response.getDataset() != null) { + return ((GetModelInfoResp) response.getDataset()).convertToThriftResponse(); + } else { + LOGGER.warn("Unexpected error happened while getting model: ", response.getException()); + // consensus layer related errors + TSStatus res = new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()); + res.setMessage(response.getException().toString()); + return new TGetModelInfoResp(res, null); + } + } } diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java index 3c72e09570f..1e14bb6b137 100644 --- a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java +++ b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java @@ -24,15 +24,18 @@ import org.apache.iotdb.commons.model.ModelInformation; import org.apache.iotdb.commons.model.ModelTable; import org.apache.iotdb.commons.model.TrailInformation; import org.apache.iotdb.commons.snapshot.SnapshotProcessor; +import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan; import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan; import org.apache.iotdb.confignode.consensus.request.read.model.ShowTrailPlan; import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan; import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelStatePlan; -import org.apache.iotdb.confignode.consensus.response.ModelTableResp; -import org.apache.iotdb.confignode.consensus.response.TrailTableResp; +import org.apache.iotdb.confignode.consensus.response.model.GetModelInfoResp; +import org.apache.iotdb.confignode.consensus.response.model.ModelTableResp; +import org.apache.iotdb.confignode.consensus.response.model.TrailTableResp; import org.apache.iotdb.rpc.TSStatusCode; +import org.apache.iotdb.tsfile.utils.PublicBAOS; import org.apache.thrift.TException; import org.slf4j.Logger; @@ -40,10 +43,12 @@ import org.slf4j.LoggerFactory; import javax.annotation.concurrent.ThreadSafe; +import java.io.DataOutputStream; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.concurrent.locks.ReentrantLock; @ThreadSafe @@ -155,6 +160,29 @@ public class ModelInfo implements SnapshotProcessor { } } + public GetModelInfoResp getModelInfo(GetModelInfoPlan plan) { + acquireModelTableLock(); + try { + GetModelInfoResp getModelInfoResp = + new GetModelInfoResp(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); + ModelInformation modelInformation = modelTable.getModelInformationById(plan.getModelId()); + if (modelInformation != null) { + PublicBAOS buffer = new PublicBAOS(); + DataOutputStream stream = new DataOutputStream(buffer); + modelInformation.serialize(stream); + getModelInfoResp.setModelInfo(ByteBuffer.wrap(buffer.getBuf(), 0, buffer.size())); + } + return getModelInfoResp; + } catch (IOException e) { + LOGGER.warn("Fail to get model info", e); + return new GetModelInfoResp( + new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()) + .setMessage(e.getMessage())); + } finally { + releaseModelTableLock(); + } + } + public TSStatus updateModelInfo(UpdateModelInfoPlan plan) { acquireModelTableLock(); try { diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java index 2f812b173b2..44527d67817 100644 --- a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java +++ b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java @@ -29,6 +29,7 @@ import org.apache.iotdb.confignode.consensus.request.auth.AuthorPlan; import org.apache.iotdb.confignode.consensus.request.read.database.CountDatabasePlan; import org.apache.iotdb.confignode.consensus.request.read.database.GetDatabasePlan; import org.apache.iotdb.confignode.consensus.request.read.datanode.GetDataNodeConfigurationPlan; +import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan; import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan; import org.apache.iotdb.confignode.consensus.request.read.model.ShowTrailPlan; import org.apache.iotdb.confignode.consensus.request.read.partition.CountTimeSlotListPlan; @@ -278,6 +279,8 @@ public class ConfigPlanExecutor { return modelInfo.showModel((ShowModelPlan) req); case ShowTrail: return modelInfo.showTrail((ShowTrailPlan) req); + case GetModelInfo: + return modelInfo.getModelInfo((GetModelInfoPlan) req); case GetPipePluginTable: return pipeInfo.getPipePluginInfo().showPipePlugins(); case GetPipePluginJar: diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java b/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java index 22367d7c164..1aa9a29f95f 100644 --- a/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java +++ b/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java @@ -110,6 +110,8 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetDataNodeLocationsResp; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipeSinkReq; @@ -990,6 +992,11 @@ public class ConfigNodeRPCServiceProcessor implements IConfigNodeRPCService.Ifac return configManager.updateModelState(req); } + @Override + public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) { + return configManager.getModelInfo(req); + } + @Override public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) throws TException { return configManager.setSpaceQuota(req); diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/model/ForecastModeInformation.java b/node-commons/src/main/java/org/apache/iotdb/commons/model/ForecastModeInformation.java new file mode 100644 index 00000000000..f32a0d5190f --- /dev/null +++ b/node-commons/src/main/java/org/apache/iotdb/commons/model/ForecastModeInformation.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.commons.model; + +import org.apache.iotdb.common.rpc.thrift.ModelTask; +import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType; +import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils; + +import javax.annotation.Nullable; + +import java.io.DataOutputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +public class ForecastModeInformation extends ModelInformation { + + private final List<TSDataType> inputTypeList; + + private final List<Integer> predictIndexList; + + private final int inputLength; + private final int predictLength; + + public ForecastModeInformation( + String modelId, + String modelType, + boolean isAuto, + List<String> queryExpressions, + @Nullable String queryFilter, + List<TSDataType> inputTypeList, + List<Integer> predictIndexList, + int inputLength, + int predictLength) { + super(ModelTask.FORECAST, modelId, modelType, isAuto, queryExpressions, queryFilter); + this.inputTypeList = inputTypeList; + this.predictIndexList = predictIndexList; + this.inputLength = inputLength; + this.predictLength = predictLength; + } + + public ForecastModeInformation(ByteBuffer buffer) { + super(ModelTask.FORECAST, buffer); + int listSize = ReadWriteIOUtils.readInt(buffer); + this.inputTypeList = new ArrayList<>(listSize); + for (int i = 0; i < listSize; i++) { + inputTypeList.add(TSDataType.deserializeFrom(buffer)); + } + listSize = ReadWriteIOUtils.readInt(buffer); + this.predictIndexList = new ArrayList<>(listSize); + for (int i = 0; i < listSize; i++) { + predictIndexList.add(ReadWriteIOUtils.readInt(buffer)); + } + this.inputLength = ReadWriteIOUtils.readInt(buffer); + this.predictLength = ReadWriteIOUtils.readInt(buffer); + } + + public ForecastModeInformation(InputStream stream) throws IOException { + super(ModelTask.FORECAST, stream); + int listSize = ReadWriteIOUtils.readInt(stream); + this.inputTypeList = new ArrayList<>(listSize); + for (int i = 0; i < listSize; i++) { + inputTypeList.add(TSDataType.deserializeFrom(stream)); + } + listSize = ReadWriteIOUtils.readInt(stream); + this.predictIndexList = new ArrayList<>(listSize); + for (int i = 0; i < listSize; i++) { + predictIndexList.add(ReadWriteIOUtils.readInt(stream)); + } + this.inputLength = ReadWriteIOUtils.readInt(stream); + this.predictLength = ReadWriteIOUtils.readInt(stream); + } + + public List<TSDataType> getInputTypeList() { + return inputTypeList; + } + + public List<Integer> getPredictIndexList() { + return predictIndexList; + } + + public int getInputLength() { + return inputLength; + } + + public int getPredictLength() { + return predictLength; + } + + @Override + public void serialize(DataOutputStream stream) throws IOException { + super.serialize(stream); + ReadWriteIOUtils.write(inputTypeList.size(), stream); + for (TSDataType inputType : inputTypeList) { + inputType.serializeTo(stream); + } + ReadWriteIOUtils.write(predictIndexList.size(), stream); + for (Integer index : predictIndexList) { + ReadWriteIOUtils.write(index, stream); + } + ReadWriteIOUtils.write(inputLength, stream); + ReadWriteIOUtils.write(predictLength, stream); + } + + @Override + public void serialize(FileOutputStream stream) throws IOException { + super.serialize(stream); + ReadWriteIOUtils.write(inputTypeList.size(), stream); + for (TSDataType inputType : inputTypeList) { + inputType.serializeTo(stream); + } + ReadWriteIOUtils.write(predictIndexList.size(), stream); + for (Integer index : predictIndexList) { + ReadWriteIOUtils.write(index, stream); + } + ReadWriteIOUtils.write(inputLength, stream); + ReadWriteIOUtils.write(predictLength, stream); + } +} diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java index 7c46a6c3437..fc0cdb45dde 100644 --- a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java +++ b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java @@ -39,7 +39,7 @@ import java.util.Map; import static org.apache.iotdb.commons.model.TrailInformation.MODEL_PATH; -public class ModelInformation { +public abstract class ModelInformation { private final String modelId; private final ModelTask modelTask; @@ -55,8 +55,8 @@ public class ModelInformation { private final Map<String, TrailInformation> trailMap; public ModelInformation( - String modelId, ModelTask modelTask, + String modelId, String modelType, boolean isAuto, List<String> queryExpressions, @@ -71,9 +71,10 @@ public class ModelInformation { this.trailMap = new HashMap<>(); } - public ModelInformation(ByteBuffer buffer) { + public ModelInformation(ModelTask modelTask, ByteBuffer buffer) { + this.modelTask = modelTask; + this.modelId = ReadWriteIOUtils.readString(buffer); - this.modelTask = ModelTask.findByValue(ReadWriteIOUtils.readInt(buffer)); this.modelType = ReadWriteIOUtils.readString(buffer); int listSize = ReadWriteIOUtils.readInt(buffer); @@ -103,9 +104,10 @@ public class ModelInformation { } } - public ModelInformation(InputStream stream) throws IOException { + public ModelInformation(ModelTask modelTask, InputStream stream) throws IOException { + this.modelTask = modelTask; + this.modelId = ReadWriteIOUtils.readString(stream); - this.modelTask = ModelTask.findByValue(ReadWriteIOUtils.readInt(stream)); this.modelType = ReadWriteIOUtils.readString(stream); int listSize = ReadWriteIOUtils.readInt(stream); @@ -200,8 +202,9 @@ public class ModelInformation { } public void serialize(DataOutputStream stream) throws IOException { - ReadWriteIOUtils.write(modelId, stream); ReadWriteIOUtils.write(modelTask.ordinal(), stream); + + ReadWriteIOUtils.write(modelId, stream); ReadWriteIOUtils.write(modelType, stream); ReadWriteIOUtils.write(queryExpressions.size(), stream); for (String queryExpression : queryExpressions) { @@ -232,8 +235,9 @@ public class ModelInformation { } public void serialize(FileOutputStream stream) throws IOException { - ReadWriteIOUtils.write(modelId, stream); ReadWriteIOUtils.write(modelTask.ordinal(), stream); + + ReadWriteIOUtils.write(modelId, stream); ReadWriteIOUtils.write(modelType, stream); ReadWriteIOUtils.write(queryExpressions.size(), stream); @@ -264,12 +268,32 @@ public class ModelInformation { } } - public static ModelInformation deserialize(InputStream stream) throws IOException { - return new ModelInformation(stream); + public static ModelInformation deserialize(ByteBuffer buffer) { + ModelTask modelTask = ModelTask.findByValue(ReadWriteIOUtils.readInt(buffer)); + if (modelTask == null) { + throw new IllegalArgumentException(); + } + + switch (modelTask) { + case FORECAST: + return new ForecastModeInformation(buffer); + default: + throw new IllegalArgumentException("Invalid task type: " + modelTask); + } } - public static ModelInformation deserialize(ByteBuffer buffer) { - return new ModelInformation(buffer); + public static ModelInformation deserialize(InputStream stream) throws IOException { + ModelTask modelTask = ModelTask.findByValue(ReadWriteIOUtils.readInt(stream)); + if (modelTask == null) { + throw new IllegalArgumentException(); + } + + switch (modelTask) { + case FORECAST: + return new ForecastModeInformation(stream); + default: + throw new IllegalArgumentException("Invalid task type: " + modelTask); + } } public ByteBuffer serializeShowModelResult() throws IOException { diff --git a/server/src/main/java/org/apache/iotdb/db/client/ConfigNodeClient.java b/server/src/main/java/org/apache/iotdb/db/client/ConfigNodeClient.java index ca35b446fc3..8cae68b3d12 100644 --- a/server/src/main/java/org/apache/iotdb/db/client/ConfigNodeClient.java +++ b/server/src/main/java/org/apache/iotdb/db/client/ConfigNodeClient.java @@ -79,6 +79,8 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetDataNodeLocationsResp; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipeSinkReq; @@ -2122,6 +2124,22 @@ public class ConfigNodeClient implements IConfigNodeRPCService.Iface, ThriftClie throw new TException(new UnsupportedOperationException().getCause()); } + @Override + public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) throws TException { + for (int i = 0; i < RETRY_NUM; i++) { + try { + TGetModelInfoResp resp = client.getModelInfo(req); + if (!updateConfigNodeLeader(resp.getStatus())) { + return resp; + } + } catch (TException e) { + configLeader = null; + } + waitAndReconnect(); + } + throw new TException(MSG_RECONNECTION_FAIL); + } + @Override public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) throws TException { for (int i = 0; i < RETRY_NUM; i++) { diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/process/ml/ForecastOperator.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/process/ml/ForecastOperator.java index 5aac6830fbc..b7375985942 100644 --- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/process/ml/ForecastOperator.java +++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/process/ml/ForecastOperator.java @@ -20,6 +20,7 @@ package org.apache.iotdb.db.mpp.execution.operator.process.ml; import org.apache.iotdb.db.client.MLNodeClient; +import org.apache.iotdb.db.conf.IoTDBDescriptor; import org.apache.iotdb.db.exception.ModelInferenceProcessException; import org.apache.iotdb.db.mpp.execution.operator.Operator; import org.apache.iotdb.db.mpp.execution.operator.OperatorContext; @@ -152,7 +153,10 @@ public class ForecastOperator implements ProcessOperator { } finished = true; - return new TsBlockSerde().deserialize(forecastResp.bufferForForecastResult()); + TsBlock resultTsBlock = + new TsBlockSerde().deserialize(forecastResp.bufferForForecastResult()); + resultTsBlock = modifyTimeColumn(resultTsBlock); + return resultTsBlock; } catch (InterruptedException e) { LOGGER.warn( "{}: interrupted when processing write operation future with exception {}", this, e); @@ -164,6 +168,21 @@ public class ForecastOperator implements ProcessOperator { } } + private TsBlock modifyTimeColumn(TsBlock resultTsBlock) { + long delta = + IoTDBDescriptor.getInstance().getConfig().getTimestampPrecision().equals("ms") + ? 1_000_000L + : 1_000L; + + TsBlockBuilder newTsBlockBuilder = TsBlockBuilder.createWithOnlyTimeColumn(); + TimeColumnBuilder timeColumnBuilder = newTsBlockBuilder.getTimeColumnBuilder(); + for (int i = 0; i < resultTsBlock.getPositionCount(); i++) { + timeColumnBuilder.writeLong(resultTsBlock.getTimeByIndex(i) / delta); + newTsBlockBuilder.declarePosition(); + } + return newTsBlockBuilder.build().appendValueColumns(resultTsBlock.getValueColumns()); + } + private void appendTsBlockToBuilder(TsBlock inputTsBlock) { TimeColumnBuilder timeColumnBuilder = inputTsBlockBuilder.getTimeColumnBuilder(); ColumnBuilder[] columnBuilders = inputTsBlockBuilder.getValueColumnBuilders(); diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java index 5034007063e..1310dc93f0e 100644 --- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java +++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java @@ -24,6 +24,7 @@ import org.apache.iotdb.commons.client.exception.ClientManagerException; import org.apache.iotdb.commons.conf.IoTDBConstant; import org.apache.iotdb.commons.exception.IllegalPathException; import org.apache.iotdb.commons.exception.MetadataException; +import org.apache.iotdb.commons.model.ForecastModeInformation; import org.apache.iotdb.commons.model.ModelInformation; import org.apache.iotdb.commons.partition.DataPartition; import org.apache.iotdb.commons.partition.DataPartitionQueryParam; @@ -383,8 +384,15 @@ public class AnalyzeVisitor extends StatementVisitor<Analysis, MPPQueryContext> ModelInferenceFunction.valueOf(modelInferenceExpression.getFunctionName().toUpperCase()); switch (functionType) { case FORECAST: - ModelInferenceDescriptor modelInferenceDescriptor = - new ForecastModelInferenceDescriptor(functionType, modelInformation); + ForecastModelInferenceDescriptor modelInferenceDescriptor = + new ForecastModelInferenceDescriptor( + functionType, (ForecastModeInformation) modelInformation); + Map<String, String> modelInferenceAttributes = + modelInferenceExpression.getFunctionAttributes(); + if (modelInferenceAttributes.containsKey("predict_length")) { + modelInferenceDescriptor.setExpectedPredictLength( + Integer.parseInt(modelInferenceAttributes.get("predict_length"))); + } analysis.setModelInferenceDescriptor(modelInferenceDescriptor); List<ResultColumn> newResultColumns = new ArrayList<>(); diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ClusterPartitionFetcher.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ClusterPartitionFetcher.java index dbc2e3068b9..98063c19b74 100644 --- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ClusterPartitionFetcher.java +++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ClusterPartitionFetcher.java @@ -35,12 +35,12 @@ import org.apache.iotdb.commons.partition.executor.SeriesPartitionExecutor; import org.apache.iotdb.commons.path.PathPatternTree; import org.apache.iotdb.confignode.rpc.thrift.TDataPartitionReq; import org.apache.iotdb.confignode.rpc.thrift.TDataPartitionTableResp; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TSchemaNodeManagementReq; import org.apache.iotdb.confignode.rpc.thrift.TSchemaNodeManagementResp; import org.apache.iotdb.confignode.rpc.thrift.TSchemaPartitionReq; import org.apache.iotdb.confignode.rpc.thrift.TSchemaPartitionTableResp; -import org.apache.iotdb.confignode.rpc.thrift.TShowModelReq; -import org.apache.iotdb.confignode.rpc.thrift.TShowModelResp; import org.apache.iotdb.confignode.rpc.thrift.TTimeSlotList; import org.apache.iotdb.db.client.ConfigNodeClient; import org.apache.iotdb.db.client.ConfigNodeClientManager; @@ -300,19 +300,17 @@ public class ClusterPartitionFetcher implements IPartitionFetcher { public ModelInformation getModelInformation(String modelId) { try (ConfigNodeClient client = configNodeClientManager.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { - TShowModelReq showModelReq = new TShowModelReq(); - showModelReq.setModelId(modelId); - TShowModelResp showModelResp = client.showModel(showModelReq); - if (showModelResp.getStatus().getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - if (showModelResp.modelInfoList.size() > 0) { - return ModelInformation.deserialize(showModelResp.modelInfoList.get(0)); + TGetModelInfoResp getModelInfoResp = client.getModelInfo(new TGetModelInfoReq(modelId)); + if (getModelInfoResp.getStatus().getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + if (getModelInfoResp.modelInfo != null) { + return ModelInformation.deserialize(getModelInfoResp.modelInfo); } else { return null; } } else { throw new StatementAnalyzeException( "An error occurred when executing getModelInformation():" - + showModelResp.getStatus().getMessage()); + + getModelInfoResp.getStatus().getMessage()); } } catch (ClientManagerException | TException e) { throw new StatementAnalyzeException( diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java index 9c48fdf7e15..2b0c045d0c2 100644 --- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java +++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java @@ -1929,6 +1929,20 @@ public class ClusterConfigTaskExecutor implements IConfigTaskExecutor { Expression whereExpression = analysis.getWhereExpression(); String queryFilter = whereExpression == null ? null : whereExpression.toString(); + Map<String, String> modelConfigs = createModelStatement.getAttributes(); + if (!modelConfigs.containsKey("input_type_list")) { + String inputTypeListStr = analysis.getRespDatasetHeader().getRespDataTypeList().toString(); + modelConfigs.put( + "input_type_list", inputTypeListStr.substring(1, inputTypeListStr.length() - 1)); + } + if (!modelConfigs.containsKey("predict_index_list")) { + StringBuilder predictIndexListStr = new StringBuilder("0"); + for (int i = 1; i < analysis.getRespDatasetHeader().getOutputValueColumnCount(); i++) { + predictIndexListStr.append(",").append(i); + } + modelConfigs.put("predict_index_list", predictIndexListStr.toString()); + } + SettableFuture<ConfigTaskResult> future = SettableFuture.create(); try (ConfigNodeClient client = CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { @@ -1939,7 +1953,7 @@ public class ClusterConfigTaskExecutor implements IConfigTaskExecutor { createModelReq.setIsAuto(createModelStatement.isAuto()); createModelReq.setQueryExpressions(queryExpressions); createModelReq.setQueryFilter(queryFilter); - createModelReq.setModelConfigs(createModelStatement.getAttributes()); + createModelReq.setModelConfigs(modelConfigs); final TSStatus executionStatus = client.createModel(createModelReq); if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != executionStatus.getCode()) { LOGGER.warn( diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java index e40a87c976f..cc240b1e768 100644 --- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java +++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java @@ -1678,6 +1678,8 @@ public class OperatorTreeGenerator extends PlanVisitor<Operator, LocalExecutionP + (long) outputColumnNum * DoubleColumn.SIZE_IN_BYTES_PER_POSITION) * expectedPredictLength; + context.getTimeSliceAllocator().recordExecutionWeight(operatorContext, 1); + return new ForecastOperator( operatorContext, child, diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java index 97c1a8cab1e..d4aff88e07b 100644 --- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java +++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java @@ -19,7 +19,7 @@ package org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model; -import org.apache.iotdb.commons.model.ModelInformation; +import org.apache.iotdb.commons.model.ForecastModeInformation; import org.apache.iotdb.commons.udf.builtin.ModelInferenceFunction; import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType; import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils; @@ -37,18 +37,22 @@ import static org.apache.iotdb.db.constant.SqlConstant.PREDICT_LENGTH; public class ForecastModelInferenceDescriptor extends ModelInferenceDescriptor { - private List<TSDataType> inputTypeList; - private List<Integer> predictIndexList; + private final List<TSDataType> inputTypeList; + private final List<Integer> predictIndexList; - private int modelInputLength; - private int modelPredictLength; + private final int modelInputLength; + private final int modelPredictLength; private int expectedPredictLength; private LinkedHashMap<String, String> outputAttributes; public ForecastModelInferenceDescriptor( - ModelInferenceFunction functionType, ModelInformation modelInformation) { + ModelInferenceFunction functionType, ForecastModeInformation modelInformation) { super(functionType, modelInformation); + this.inputTypeList = modelInformation.getInputTypeList(); + this.predictIndexList = modelInformation.getPredictIndexList(); + this.modelInputLength = modelInformation.getInputLength(); + this.expectedPredictLength = this.modelPredictLength = modelInformation.getPredictLength(); } public ForecastModelInferenceDescriptor(ByteBuffer buffer) { @@ -72,18 +76,10 @@ public class ForecastModelInferenceDescriptor extends ModelInferenceDescriptor { return predictIndexList; } - public void setPredictIndexList(List<Integer> predictIndexList) { - this.predictIndexList = predictIndexList; - } - public List<TSDataType> getInputTypeList() { return inputTypeList; } - public void setInputTypeList(List<TSDataType> inputTypeList) { - this.inputTypeList = inputTypeList; - } - public int getModelInputLength() { return modelInputLength; } @@ -96,6 +92,10 @@ public class ForecastModelInferenceDescriptor extends ModelInferenceDescriptor { return expectedPredictLength; } + public void setExpectedPredictLength(int expectedPredictLength) { + this.expectedPredictLength = expectedPredictLength; + } + @Override public LinkedHashMap<String, String> getOutputAttributes() { if (outputAttributes == null) { diff --git a/thrift-confignode/src/main/thrift/confignode.thrift b/thrift-confignode/src/main/thrift/confignode.thrift index 8fa78991448..1453a7a2434 100644 --- a/thrift-confignode/src/main/thrift/confignode.thrift +++ b/thrift-confignode/src/main/thrift/confignode.thrift @@ -731,6 +731,10 @@ struct TShowTrailReq { 2: optional string trailId } +struct TGetModelInfoReq { + 1: required string modelId +} + struct TShowTrailResp { 1: required common.TSStatus status 2: required list<binary> trailInfoList @@ -748,6 +752,11 @@ struct TUpdateModelStateReq { 3: optional string bestTrailId } +struct TGetModelInfoResp { + 1: required common.TSStatus status + 2: required binary modelInfo +} + // ==================================================== // Quota // ==================================================== @@ -1374,6 +1383,11 @@ service IConfigNodeRPCService { */ common.TSStatus updateModelState(TUpdateModelStateReq req) + /** + * Return the model info by model_id + */ + TGetModelInfoResp getModelInfo(TGetModelInfoReq req) + // ====================================================== // Quota // ====================================================== diff --git a/tsfile/src/main/java/org/apache/iotdb/tsfile/file/metadata/enums/TSDataType.java b/tsfile/src/main/java/org/apache/iotdb/tsfile/file/metadata/enums/TSDataType.java index f4b88db27ca..f686a39c708 100644 --- a/tsfile/src/main/java/org/apache/iotdb/tsfile/file/metadata/enums/TSDataType.java +++ b/tsfile/src/main/java/org/apache/iotdb/tsfile/file/metadata/enums/TSDataType.java @@ -21,7 +21,9 @@ package org.apache.iotdb.tsfile.file.metadata.enums; import org.apache.iotdb.tsfile.exception.write.UnSupportedDataTypeException; import java.io.DataOutputStream; +import java.io.FileOutputStream; import java.io.IOException; +import java.io.InputStream; import java.nio.ByteBuffer; public enum TSDataType { @@ -91,6 +93,10 @@ public enum TSDataType { return deserialize(buffer.get()); } + public static TSDataType deserializeFrom(InputStream stream) throws IOException { + return deserialize((byte) stream.read()); + } + public static int getSerializedSize() { return Byte.BYTES; } @@ -103,6 +109,10 @@ public enum TSDataType { outputStream.write(serialize()); } + public void serializeTo(FileOutputStream outputStream) throws IOException { + outputStream.write(serialize()); + } + public int getDataTypeSize() { switch (this) { case BOOLEAN: diff --git a/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlock.java b/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlock.java index 9f37f4b780d..46ea0b667c9 100644 --- a/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlock.java +++ b/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlock.java @@ -147,6 +147,19 @@ public class TsBlock { return wrapBlocksWithoutCopy(positionCount, timeColumn, newBlocks); } + public TsBlock appendValueColumns(Column[] columns) { + Column[] newBlocks = Arrays.copyOf(valueColumns, valueColumns.length + columns.length); + int newColumnIndex = valueColumns.length; + for (Column column : columns) { + requireNonNull(column, "Column is null"); + if (positionCount != column.getPositionCount()) { + throw new IllegalArgumentException("Block does not have same position count"); + } + newBlocks[newColumnIndex++] = column; + } + return wrapBlocksWithoutCopy(positionCount, timeColumn, newBlocks); + } + /** * Attention. This method uses System.arraycopy() to extend the valueColumn array, so its * performance is not ensured if you have many insert operations. diff --git a/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlockBuilder.java b/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlockBuilder.java index c309835a09f..cab5e90e3e0 100644 --- a/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlockBuilder.java +++ b/tsfile/src/main/java/org/apache/iotdb/tsfile/read/common/block/TsBlockBuilder.java @@ -81,6 +81,7 @@ public class TsBlockBuilder { res.timeColumnBuilder = new TimeColumnBuilder( res.tsBlockBuilderStatus.createColumnBuilderStatus(), DEFAULT_INITIAL_EXPECTED_ENTRIES); + res.valueColumnBuilders = new ColumnBuilder[0]; return res; } @@ -285,9 +286,6 @@ public class TsBlockBuilder { } public TsBlock build() { - if (valueColumnBuilders.length == 0) { - return new TsBlock(declaredPositions); - } TimeColumn timeColumn = (TimeColumn) timeColumnBuilder.build(); if (timeColumn.getPositionCount() != declaredPositions) { throw new IllegalStateException(
