This is an automated email from the ASF dual-hosted git repository. hui pushed a commit to branch mlnode/test in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 3dcb848c81879817de772028a4b17880fcf13ffb Author: Minghui Liu <[email protected]> AuthorDate: Tue Apr 4 17:08:59 2023 +0800 support drop model --- .../procedure/impl/model/DropModelProcedure.java | 4 +++- mlnode/iotdb/mlnode/handler.py | 11 ++++++--- mlnode/iotdb/mlnode/util.py | 2 +- .../db/mpp/plan/parser/StatementGenerator.java | 21 ++++++++++++++--- .../impl/DataNodeInternalRPCServiceImpl.java | 27 +++++++++++++++++++++- .../service/thrift/impl/MLNodeRPCServiceImpl.java | 7 +++--- 6 files changed, 59 insertions(+), 13 deletions(-) diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java index bfa461a8b2..8f41de070d 100644 --- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java +++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java @@ -156,7 +156,9 @@ public class DropModelProcedure extends AbstractNodeProcedure<DropModelState> { if (getCycles() > RETRY_THRESHOLD) { setFailure( new ProcedureException( - String.format("Fail to drop model [%s] at STATE [%s]", modelId, state))); + String.format( + "Fail to drop model [%s] at STATE [%s], %s", + modelId, state, e.getMessage()))); } } } diff --git a/mlnode/iotdb/mlnode/handler.py b/mlnode/iotdb/mlnode/handler.py index 1a6e3eb90a..b4c64d94b1 100644 --- a/mlnode/iotdb/mlnode/handler.py +++ b/mlnode/iotdb/mlnode/handler.py @@ -21,6 +21,7 @@ from iotdb.mlnode.constant import TSStatusCode from iotdb.mlnode.data_access.factory import create_forecast_dataset from iotdb.mlnode.parser import parse_training_request from iotdb.mlnode.process.manager import TaskManager +from iotdb.mlnode.storage import model_storage from iotdb.mlnode.util import get_status from iotdb.thrift.mlnode import IMLNodeRPCService from iotdb.thrift.mlnode.ttypes import (TCreateTrainingTaskReq, @@ -33,7 +34,11 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface): self.__task_manager = TaskManager(pool_num=10) # TODO: add pool num to config def deleteModel(self, req: TDeleteModelReq): - return get_status(TSStatusCode.SUCCESS_STATUS, "") + try: + model_storage.delete_model(req.modelId) + return get_status(TSStatusCode.SUCCESS_STATUS) + except Exception as e: + return get_status(TSStatusCode.FAIL_STATUS, str(e)) def createTrainingTask(self, req: TCreateTrainingTaskReq): task = None @@ -50,7 +55,7 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface): # create task & check task config legitimacy task = self.__task_manager.create_training_task(dataset, model, model_config, task_config) - return get_status(TSStatusCode.SUCCESS_STATUS, 'Successfully create training task') + return get_status(TSStatusCode.SUCCESS_STATUS) except Exception as e: return get_status(TSStatusCode.FAIL_STATUS, str(e)) finally: @@ -58,6 +63,6 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface): self.__task_manager.submit_training_task(task) def forecast(self, req: TForecastReq): - status = get_status(TSStatusCode.SUCCESS_STATUS, "") + status = get_status(TSStatusCode.SUCCESS_STATUS) forecast_result = b'forecast result' return TForecastResp(status, forecast_result) diff --git a/mlnode/iotdb/mlnode/util.py b/mlnode/iotdb/mlnode/util.py index 5d3a2d670e..5cdc52f01a 100644 --- a/mlnode/iotdb/mlnode/util.py +++ b/mlnode/iotdb/mlnode/util.py @@ -45,7 +45,7 @@ def parse_endpoint_url(endpoint_url: str) -> TEndPoint: raise BadNodeUrlError(endpoint_url) -def get_status(status_code: TSStatusCode, message: str) -> TSStatus: +def get_status(status_code: TSStatusCode, message: str = None) -> TSStatus: status = TSStatus(status_code.get_status_code()) status.message = message return status diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java index 422617162d..f0da4a29f1 100644 --- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java +++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java @@ -66,6 +66,7 @@ import org.apache.iotdb.db.mpp.plan.statement.metadata.template.UnsetSchemaTempl import org.apache.iotdb.db.qp.sql.IoTDBSqlParser; import org.apache.iotdb.db.qp.sql.SqlLexer; import org.apache.iotdb.db.utils.QueryDataSetUtils; +import org.apache.iotdb.mpp.rpc.thrift.TDeleteModelMetricsReq; import org.apache.iotdb.mpp.rpc.thrift.TFetchTimeseriesReq; import org.apache.iotdb.mpp.rpc.thrift.TRecordModelMetricsReq; import org.apache.iotdb.service.rpc.thrift.TSAggregationQueryReq; @@ -110,6 +111,9 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import static org.apache.iotdb.commons.conf.IoTDBConstant.MULTI_LEVEL_PATH_WILDCARD; +import static org.apache.iotdb.db.service.thrift.impl.MLNodeRPCServiceImpl.ML_METRICS_PATH_PREFIX; + /** Convert SQL and RPC requests to {@link Statement}. */ public class StatementGenerator { private static final PerformanceOverviewMetrics PERFORMANCE_OVERVIEW_METRICS = @@ -806,10 +810,10 @@ public class StatementGenerator { return databasePath; } - public static InsertRowStatement createStatement( - TRecordModelMetricsReq recordModelMetricsReq, String prefix) throws IllegalPathException { + public static InsertRowStatement createStatement(TRecordModelMetricsReq recordModelMetricsReq) + throws IllegalPathException { String path = - prefix + ML_METRICS_PATH_PREFIX + TsFileConstant.PATH_SEPARATOR + recordModelMetricsReq.getModelId() + TsFileConstant.PATH_SEPARATOR @@ -870,4 +874,15 @@ public class StatementGenerator { queryStatement.setSelectComponent(selectComponent); return queryStatement; } + + public static DeleteTimeSeriesStatement createStatement(TDeleteModelMetricsReq req) + throws IllegalPathException { + String path = + ML_METRICS_PATH_PREFIX + + TsFileConstant.PATH_SEPARATOR + + req.getModelId() + + TsFileConstant.PATH_SEPARATOR + + MULTI_LEVEL_PATH_WILDCARD; + return new DeleteTimeSeriesStatement(Collections.singletonList(new PartialPath(path))); + } } diff --git a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java index d00903759a..ea1cc229c6 100644 --- a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java +++ b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java @@ -102,6 +102,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.write.DeleteDataNode; import org.apache.iotdb.db.mpp.plan.scheduler.load.LoadTsFileScheduler; import org.apache.iotdb.db.mpp.plan.statement.component.WhereCondition; import org.apache.iotdb.db.mpp.plan.statement.crud.QueryStatement; +import org.apache.iotdb.db.mpp.plan.statement.metadata.DeleteTimeSeriesStatement; import org.apache.iotdb.db.pipe.agent.PipeAgent; import org.apache.iotdb.db.query.control.SessionManager; import org.apache.iotdb.db.query.control.clientsession.IClientSession; @@ -196,6 +197,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.TimeZone; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; @@ -870,7 +872,30 @@ public class DataNodeInternalRPCServiceImpl implements IDataNodeRPCService.Iface @Override public TSStatus deleteModelMetrics(TDeleteModelMetricsReq req) throws TException { - return RpcUtils.SUCCESS_STATUS; + IClientSession session = new InternalClientSession(req.getModelId()); + SESSION_MANAGER.registerSession(session); + SESSION_MANAGER.supplySession( + session, "MLNode", TimeZone.getDefault().getID(), ClientVersion.V_1_0); + + try { + DeleteTimeSeriesStatement deleteTimeSeriesStatement = StatementGenerator.createStatement(req); + + long queryId = SESSION_MANAGER.requestQueryId(); + ExecutionResult result = + COORDINATOR.execute( + deleteTimeSeriesStatement, + queryId, + SESSION_MANAGER.getSessionInfo(session), + "", + PARTITION_FETCHER, + SCHEMA_FETCHER); + return result.status; + } catch (Exception e) { + return onQueryException(e, OperationType.DELETE_TIMESERIES); + } finally { + SESSION_MANAGER.closeSession(session, COORDINATOR::cleanupQueryExecution); + SESSION_MANAGER.removeCurrSession(); + } } private PathPatternTree filterPathPatternTree(PathPatternTree patternTree, String storageGroup) { diff --git a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java index 544d4cd04c..74e14f9f3f 100644 --- a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java +++ b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java @@ -60,14 +60,14 @@ import static org.apache.iotdb.db.utils.ErrorHandlingUtils.onQueryException; public class MLNodeRPCServiceImpl implements IMLNodeRPCServiceWithHandler { + public static final String ML_METRICS_PATH_PREFIX = "root.__system.ml.exp"; + private static final Logger LOGGER = LoggerFactory.getLogger(MLNodeRPCServiceImpl.class); private static final SessionManager SESSION_MANAGER = SessionManager.getInstance(); private static final Coordinator COORDINATOR = Coordinator.getInstance(); - private static final String ML_METRICS_STORAGE_GROUP = "root.__system.ml.exp"; - private final IPartitionFetcher PARTITION_FETCHER; private final ISchemaFetcher SCHEMA_FETCHER; @@ -176,8 +176,7 @@ public class MLNodeRPCServiceImpl implements IMLNodeRPCServiceWithHandler { @Override public TSStatus recordModelMetrics(TRecordModelMetricsReq req) throws TException { try { - InsertRowStatement insertRowStatement = - StatementGenerator.createStatement(req, ML_METRICS_STORAGE_GROUP); + InsertRowStatement insertRowStatement = StatementGenerator.createStatement(req); long queryId = SESSION_MANAGER.requestQueryId(); ExecutionResult result =
