This is an automated email from the ASF dual-hosted git repository. hui pushed a commit to branch lmh/MLSQL in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 71fb614bb14cb9fdfbe87a2053628ba5e676fb96 Author: Minghui Liu <[email protected]> AuthorDate: Thu Mar 23 15:17:35 2023 +0800 fix bug & finish --- .../iotdb/confignode/persistence/ModelInfo.java | 2 +- .../procedure/impl/model/CreateModelProcedure.java | 2 +- .../procedure/impl/model/DropModelProcedure.java | 27 +------- .../procedure/state/model/DropModelState.java | 1 - .../procedure/store/ProcedureFactory.java | 4 ++ mlnode/iotdb/mlnode/service.py | 2 +- .../iotdb/commons/model/ModelInformation.java | 79 ++++++++++++++++++---- .../org/apache/iotdb/db/client/MLNodeClient.java | 18 +++-- .../impl/DataNodeInternalRPCServiceImpl.java | 3 +- 9 files changed, 88 insertions(+), 50 deletions(-) diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java index 13e4dabe73..3c72e09570 100644 --- a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java +++ b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java @@ -84,7 +84,7 @@ public class ModelInfo implements SnapshotProcessor { return new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()) .setMessage(errorMessage); } - return null; + return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); } public TSStatus dropModel(DropModelPlan plan) { diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java index 0a4d306fde..7dff5fe06e 100644 --- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java +++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java @@ -49,7 +49,7 @@ import java.util.Objects; public class CreateModelProcedure extends AbstractNodeProcedure<CreateModelState> { private static final Logger LOGGER = LoggerFactory.getLogger(CreateModelProcedure.class); - private static final int RETRY_THRESHOLD = 5; + private static final int RETRY_THRESHOLD = 1; private ModelInformation modelInformation; private Map<String, String> modelConfigs; diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java index 1268368bd6..bfa461a8b2 100644 --- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java +++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java @@ -21,19 +21,16 @@ package org.apache.iotdb.confignode.procedure.impl.model; import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation; import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.common.rpc.thrift.TrainingState; import org.apache.iotdb.commons.model.exception.ModelManagementException; import org.apache.iotdb.confignode.client.DataNodeRequestType; import org.apache.iotdb.confignode.client.sync.SyncDataNodeClientPool; import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelStatePlan; import org.apache.iotdb.confignode.persistence.ModelInfo; import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv; import org.apache.iotdb.confignode.procedure.exception.ProcedureException; import org.apache.iotdb.confignode.procedure.impl.node.AbstractNodeProcedure; import org.apache.iotdb.confignode.procedure.state.model.DropModelState; import org.apache.iotdb.confignode.procedure.store.ProcedureType; -import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelStateReq; import org.apache.iotdb.consensus.common.response.ConsensusWriteResponse; import org.apache.iotdb.db.client.MLNodeClient; import org.apache.iotdb.mpp.rpc.thrift.TDeleteModelMetricsReq; @@ -53,7 +50,7 @@ import java.util.Optional; public class DropModelProcedure extends AbstractNodeProcedure<DropModelState> { private static final Logger LOGGER = LoggerFactory.getLogger(DropModelProcedure.class); - private static final int RETRY_THRESHOLD = 5; + private static final int RETRY_THRESHOLD = 1; private String modelId; @@ -87,25 +84,6 @@ public class DropModelProcedure extends AbstractNodeProcedure<DropModelState> { break; case VALIDATED: - LOGGER.info("Change state of model [{}] to DROPPING", modelId); - - ConsensusWriteResponse response = - env.getConfigManager() - .getConsensusManager() - .write( - new UpdateModelStatePlan( - new TUpdateModelStateReq(modelId, TrainingState.DROPPING))); - if (!response.isSuccessful()) { - throw new ModelManagementException( - String.format( - "Failed to drop model [%s], fail to modify model state: %s", - modelId, response.getErrorMessage())); - } - - setNextState(DropModelState.CONFIG_NODE_DROPPING); - break; - - case CONFIG_NODE_DROPPING: LOGGER.info("Start to drop model metrics [{}] on Data Nodes", modelId); Optional<TDataNodeLocation> targetDataNode = @@ -153,7 +131,8 @@ public class DropModelProcedure extends AbstractNodeProcedure<DropModelState> { case ML_NODE_DROPPED: LOGGER.info("Start to drop model [{}] on Config Nodes", modelId); - response = env.getConfigManager().getConsensusManager().write(new DropModelPlan(modelId)); + ConsensusWriteResponse response = + env.getConfigManager().getConsensusManager().write(new DropModelPlan(modelId)); if (!response.isSuccessful()) { throw new ModelManagementException( String.format( diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/DropModelState.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/DropModelState.java index 5f8c5a6f6e..54e32e86da 100644 --- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/DropModelState.java +++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/DropModelState.java @@ -22,7 +22,6 @@ package org.apache.iotdb.confignode.procedure.state.model; public enum DropModelState { INIT, VALIDATED, - CONFIG_NODE_DROPPING, DATA_NODE_DROPPED, ML_NODE_DROPPED, CONFIG_NODE_DROPPED diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java index 48a4cfd997..4c026cb4f2 100644 --- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java +++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java @@ -158,6 +158,10 @@ public class ProcedureFactory implements IProcedureFactory { return ProcedureType.DEACTIVATE_TEMPLATE_PROCEDURE; } else if (procedure instanceof UnsetTemplateProcedure) { return ProcedureType.UNSET_TEMPLATE_PROCEDURE; + } else if (procedure instanceof CreateModelProcedure) { + return ProcedureType.CREATE_MODEL_PROCEDURE; + } else if (procedure instanceof DropModelProcedure) { + return ProcedureType.DROP_MODEL_PROCEDURE; } return null; } diff --git a/mlnode/iotdb/mlnode/service.py b/mlnode/iotdb/mlnode/service.py index 8314dc363e..a2c05ea5c3 100644 --- a/mlnode/iotdb/mlnode/service.py +++ b/mlnode/iotdb/mlnode/service.py @@ -33,7 +33,7 @@ class RPCService(threading.Thread): super().__init__() processor = IMLNodeRPCService.Processor(handler=MLNodeRPCServiceHandler()) transport = TSocket.TServerSocket(host=config.get_mn_rpc_address(), port=config.get_mn_rpc_port()) - transport_factory = TTransport.TBufferedTransportFactory() + transport_factory = TTransport.TFramedTransportFactory() protocol_factory = TCompactProtocol.TCompactProtocolFactory() self.__pool_server = TServer.TThreadPoolServer(processor, transport, transport_factory, protocol_factory) diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java index e6fbf13c95..a8cff6968d 100644 --- a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java +++ b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java @@ -24,6 +24,8 @@ import org.apache.iotdb.common.rpc.thrift.TrainingState; import org.apache.iotdb.tsfile.utils.PublicBAOS; import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils; +import javax.annotation.Nullable; + import java.io.DataOutputStream; import java.io.FileOutputStream; import java.io.IOException; @@ -44,12 +46,12 @@ public class ModelInformation { private final String modelType; private final List<String> queryExpressions; - private final String queryFilter; + @Nullable private String queryFilter; private final boolean isAuto; private TrainingState trainingState; - private String bestTrailId; + @Nullable private String bestTrailId; private final Map<String, TrailInformation> trailMap; public ModelInformation( @@ -58,11 +60,12 @@ public class ModelInformation { String modelType, boolean isAuto, List<String> queryExpressions, - String queryFilter) { + @Nullable String queryFilter) { this.modelId = modelId; this.modelTask = modelTask; this.modelType = modelType; this.isAuto = isAuto; + this.trainingState = TrainingState.PENDING; this.queryExpressions = queryExpressions; this.queryFilter = queryFilter; this.trailMap = new HashMap<>(); @@ -79,10 +82,18 @@ public class ModelInformation { this.queryExpressions.add(ReadWriteIOUtils.readString(buffer)); } - this.queryFilter = ReadWriteIOUtils.readString(buffer); + byte isNull = ReadWriteIOUtils.readByte(buffer); + if (isNull == 1) { + this.queryFilter = ReadWriteIOUtils.readString(buffer); + } + this.isAuto = ReadWriteIOUtils.readBool(buffer); this.trainingState = TrainingState.findByValue(ReadWriteIOUtils.readInt(buffer)); - this.bestTrailId = ReadWriteIOUtils.readString(buffer); + + isNull = ReadWriteIOUtils.readByte(buffer); + if (isNull == 1) { + this.bestTrailId = ReadWriteIOUtils.readString(buffer); + } int mapSize = ReadWriteIOUtils.readInt(buffer); this.trailMap = new HashMap<>(); @@ -103,10 +114,18 @@ public class ModelInformation { this.queryExpressions.add(ReadWriteIOUtils.readString(stream)); } - this.queryFilter = ReadWriteIOUtils.readString(stream); + byte isNull = ReadWriteIOUtils.readByte(stream); + if (isNull == 1) { + this.queryFilter = ReadWriteIOUtils.readString(stream); + } + this.isAuto = ReadWriteIOUtils.readBool(stream); this.trainingState = TrainingState.findByValue(ReadWriteIOUtils.readInt(stream)); - this.bestTrailId = ReadWriteIOUtils.readString(stream); + + isNull = ReadWriteIOUtils.readByte(stream); + if (isNull == 1) { + this.bestTrailId = ReadWriteIOUtils.readString(stream); + } int mapSize = ReadWriteIOUtils.readInt(stream); this.trailMap = new HashMap<>(); @@ -128,6 +147,7 @@ public class ModelInformation { return queryExpressions; } + @Nullable public String getQueryFilter() { return queryFilter; } @@ -174,10 +194,24 @@ public class ModelInformation { for (String queryExpression : queryExpressions) { ReadWriteIOUtils.write(queryExpression, stream); } - ReadWriteIOUtils.write(queryFilter, stream); + + if (queryFilter == null) { + ReadWriteIOUtils.write((byte) 0, stream); + } else { + ReadWriteIOUtils.write((byte) 1, stream); + ReadWriteIOUtils.write(queryFilter, stream); + } + ReadWriteIOUtils.write(isAuto, stream); ReadWriteIOUtils.write(trainingState.ordinal(), stream); - ReadWriteIOUtils.write(bestTrailId, stream); + + if (bestTrailId == null) { + ReadWriteIOUtils.write((byte) 0, stream); + } else { + ReadWriteIOUtils.write((byte) 1, stream); + ReadWriteIOUtils.write(bestTrailId, stream); + } + ReadWriteIOUtils.write(trailMap.size(), stream); for (TrailInformation trailInformation : trailMap.values()) { trailInformation.serialize(stream); @@ -194,10 +228,22 @@ public class ModelInformation { ReadWriteIOUtils.write(queryExpression, stream); } - ReadWriteIOUtils.write(queryFilter, stream); + if (queryFilter == null) { + ReadWriteIOUtils.write((byte) 0, stream); + } else { + ReadWriteIOUtils.write((byte) 1, stream); + ReadWriteIOUtils.write(queryFilter, stream); + } + ReadWriteIOUtils.write(isAuto, stream); ReadWriteIOUtils.write(trainingState.ordinal(), stream); - ReadWriteIOUtils.write(bestTrailId, stream); + + if (bestTrailId == null) { + ReadWriteIOUtils.write((byte) 0, stream); + } else { + ReadWriteIOUtils.write((byte) 1, stream); + ReadWriteIOUtils.write(bestTrailId, stream); + } ReadWriteIOUtils.write(trailMap.size(), stream); for (TrailInformation trailInformation : trailMap.values()) { @@ -222,9 +268,14 @@ public class ModelInformation { ReadWriteIOUtils.write(Arrays.toString(queryExpressions.toArray(new String[0])), stream); ReadWriteIOUtils.write(trainingState.toString(), stream); - TrailInformation bestTrail = trailMap.get(bestTrailId); - ReadWriteIOUtils.write(bestTrail.getModelHyperparameter().toString(), stream); - ReadWriteIOUtils.write(bestTrail.getModelPath(), stream); + if (bestTrailId != null) { + TrailInformation bestTrail = trailMap.get(bestTrailId); + ReadWriteIOUtils.write(bestTrail.getModelHyperparameter().toString(), stream); + ReadWriteIOUtils.write(bestTrail.getModelPath(), stream); + } else { + ReadWriteIOUtils.write("UNKNOWN", stream); + ReadWriteIOUtils.write("UNKNOWN", stream); + } return ByteBuffer.wrap(buffer.getBuf(), 0, buffer.size()); } } diff --git a/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java b/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java index 84278d9ba4..1ff54d43b6 100644 --- a/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java +++ b/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java @@ -29,16 +29,18 @@ import org.apache.iotdb.mlnode.rpc.thrift.TCreateTrainingTaskReq; import org.apache.iotdb.mlnode.rpc.thrift.TDeleteModelReq; import org.apache.iotdb.mlnode.rpc.thrift.TForecastReq; import org.apache.iotdb.mlnode.rpc.thrift.TForecastResp; -import org.apache.iotdb.rpc.RpcTransportFactory; +import org.apache.iotdb.rpc.TConfigurationConst; import org.apache.iotdb.rpc.TSStatusCode; import org.apache.iotdb.tsfile.read.common.block.TsBlock; import org.apache.iotdb.tsfile.read.common.block.column.TsBlockSerde; import org.apache.thrift.TException; -import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TCompactProtocol; import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TSocket; import org.apache.thrift.transport.TTransport; import org.apache.thrift.transport.TTransportException; +import org.apache.thrift.transport.layered.TFramedTransport; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -63,9 +65,13 @@ public class MLNodeClient implements AutoCloseable { try { long connectionTimeout = ClientPoolProperty.DefaultProperty.WAIT_CLIENT_TIMEOUT_MS; transport = - RpcTransportFactory.INSTANCE.getTransport( - // As there is a try-catch already, we do not need to use TSocket.wrap - endpoint.getIp(), endpoint.getPort(), (int) connectionTimeout); + new TFramedTransport.Factory() + .getTransport( + new TSocket( + TConfigurationConst.defaultTConfiguration, + endpoint.getIp(), + endpoint.getPort(), + (int) connectionTimeout)); if (!transport.isOpen()) { transport.open(); } @@ -73,7 +79,7 @@ public class MLNodeClient implements AutoCloseable { throw new TException(MSG_CONNECTION_FAIL); } - TProtocolFactory protocolFactory = new TBinaryProtocol.Factory(); + TProtocolFactory protocolFactory = new TCompactProtocol.Factory(); client = new IMLNodeRPCService.Client(protocolFactory.getProtocol(transport)); } diff --git a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java index 60edda2b56..c2fa6249f8 100644 --- a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java +++ b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java @@ -875,8 +875,7 @@ public class DataNodeInternalRPCServiceImpl implements IDataNodeRPCService.Iface @Override public TSStatus deleteModelMetrics(TDeleteModelMetricsReq req) throws TException { - // TODO - throw new TException(new UnsupportedOperationException().getCause()); + return RpcUtils.SUCCESS_STATUS; } @Override
