This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 6a3ade356f [IOTDB-5767] [IoTDB ML] Support delete model file and metrics (#9573)
6a3ade356f is described below

commit 6a3ade356fdd990ba15bf4491a40867d2468504e
Author: liuminghui233 <[email protected]>
AuthorDate: Tue Apr 11 10:37:38 2023 +0800

    [IOTDB-5767] [IoTDB ML] Support delete model file and metrics (#9573)
---
 .../procedure/impl/model/DropModelProcedure.java   |   4 +-
 mlnode/.gitignore                                  |   6 +-
 mlnode/iotdb/mlnode/client.py                      | 107 ++++++++++++++-------
 mlnode/iotdb/mlnode/config.py                      |  14 +--
 mlnode/iotdb/mlnode/constant.py                    |  10 ++
 mlnode/iotdb/mlnode/handler.py                     |  29 ++----
 mlnode/iotdb/mlnode/service.py                     |   6 +-
 .../iotdb/mlnode/{model_storage.py => storage.py}  |  23 +++--
 mlnode/iotdb/mlnode/util.py                        |  18 +++-
 mlnode/pyproject.toml                              |   1 +
 mlnode/requirements.txt                            |   2 +-
 mlnode/test/test_model_storage.py                  |  37 ++++---
 .../db/mpp/plan/parser/StatementGenerator.java     |  21 +++-
 .../impl/DataNodeInternalRPCServiceImpl.java       |  27 +++++-
 .../service/thrift/impl/MLNodeRPCServiceImpl.java  |   7 +-
 .../java/org/apache/iotdb/rpc/TSStatusCode.java    |   1 +
 thrift-mlnode/src/main/thrift/mlnode.thrift        |   2 +-
 thrift/src/main/thrift/datanode.thrift             |   1 -
 18 files changed, 217 insertions(+), 99 deletions(-)

diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
index bfa461a8b2..8f41de070d 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
@@ -156,7 +156,9 @@ public class DropModelProcedure extends AbstractNodeProcedure<DropModelState> {
         if (getCycles() > RETRY_THRESHOLD) {
           setFailure(
               new ProcedureException(
-                  String.format("Fail to drop model [%s] at STATE [%s]", 
modelId, state)));
+                  String.format(
+                      "Fail to drop model [%s] at STATE [%s], %s",
+                      modelId, state, e.getMessage())));
         }
       }
     }
diff --git a/mlnode/.gitignore b/mlnode/.gitignore
index 94606bf62f..4a6f6eac65 100644
--- a/mlnode/.gitignore
+++ b/mlnode/.gitignore
@@ -1,4 +1,8 @@
+# generated by Thrift
 /iotdb/thrift/
 
 # generated by Poetry
-/dist/
\ No newline at end of file
+/dist/
+
+# generated by MLNode
+*.pt
\ No newline at end of file
diff --git a/mlnode/iotdb/mlnode/client.py b/mlnode/iotdb/mlnode/client.py
index 244b6975c9..3157006e57 100644
--- a/mlnode/iotdb/mlnode/client.py
+++ b/mlnode/iotdb/mlnode/client.py
@@ -16,38 +16,33 @@
 # under the License.
 #
 import time
+from typing import Dict, List
 
+import pandas as pd
 from thrift.protocol import TBinaryProtocol, TCompactProtocol
 from thrift.Thrift import TException
 from thrift.transport import TSocket, TTransport
 
-from iotdb.mlnode.config import config
+from iotdb.mlnode import serde
+from iotdb.mlnode.config import descriptor
+from iotdb.mlnode.constant import TSStatusCode
 from iotdb.mlnode.log import logger
-from iotdb.thrift.common.ttypes import TEndPoint, TSStatus
+from iotdb.mlnode.util import verify_success
+from iotdb.thrift.common.ttypes import TEndPoint, TrainingState, TSStatus
 from iotdb.thrift.confignode import IConfigNodeRPCService
-from iotdb.thrift.confignode.ttypes import TUpdateModelInfoReq
-from iotdb.thrift.datanode import IDataNodeRPCService
+from iotdb.thrift.confignode.ttypes import (TUpdateModelInfoReq,
+                                            TUpdateModelStateReq)
+from iotdb.thrift.datanode import IMLNodeInternalRPCService
 from iotdb.thrift.datanode.ttypes import (TFetchTimeseriesReq,
-                                          TFetchTimeseriesResp,
                                           TRecordModelMetricsReq)
 from iotdb.thrift.mlnode import IMLNodeRPCService
 from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq, TDeleteModelReq
 
-# status code
-SUCCESS_STATUS = 200
-REDIRECTION_RECOMMEND = 400
-
-
-def verify_success(status: TSStatus, err_msg: str) -> None:
-    if status.code != SUCCESS_STATUS:
-        logger.warn(err_msg + ", error status is ", status)
-        raise RuntimeError(str(status.code) + ": " + status.message)
-
 
 class ClientManager(object):
     def __init__(self):
-        self.__data_node_endpoint = config.get_mn_target_data_node()
-        self.__config_node_endpoint = config.get_mn_target_config_node()
+        self.__data_node_endpoint = 
descriptor.get_config().get_mn_target_data_node()
+        self.__config_node_endpoint = 
descriptor.get_config().get_mn_target_config_node()
 
     def borrow_data_node_client(self):
         return DataNodeClient(host=self.__data_node_endpoint.ip,
@@ -77,9 +72,9 @@ class MLNodeClient(object):
     def create_training_task(self,
                              model_id: str,
                              is_auto: bool,
-                             model_configs: dict,
-                             query_expressions: list[str],
-                             query_filter: str = None) -> None:
+                             model_configs: Dict,
+                             query_expressions: List[str],
+                             query_filter: str = '') -> None:
         req = TCreateTrainingTaskReq(
             modelId=model_id,
             isAuto=is_auto,
@@ -124,20 +119,17 @@ class DataNodeClient(object):
                 transport.open()
             except TTransport.TTransportException as e:
                 logger.exception("TTransportException!", exc_info=e)
+                raise e
 
         protocol = TBinaryProtocol.TBinaryProtocol(transport)
-        self.__client = IDataNodeRPCService.Client(protocol)
+        self.__client = IMLNodeInternalRPCService.Client(protocol)
 
     def fetch_timeseries(self,
-                         session_id: int,
-                         statement_id: int,
-                         query_expressions: list[str],
+                         query_expressions: List[str],
                          query_filter: str = None,
                          fetch_size: int = DEFAULT_FETCH_SIZE,
-                         timeout: int = DEFAULT_TIMEOUT) -> 
TFetchTimeseriesResp:
+                         timeout: int = DEFAULT_TIMEOUT) -> [int, bool, 
pd.DataFrame]:
         req = TFetchTimeseriesReq(
-            sessionId=session_id,
-            statementId=statement_id,
             queryExpressions=query_expressions,
             queryFilter=query_filter,
             fetchSize=fetch_size,
@@ -146,15 +138,35 @@ class DataNodeClient(object):
         try:
             resp = self.__client.fetchTimeseries(req)
             verify_success(resp.status, "An error occurs when calling 
fetch_timeseries()")
-            return resp
-        except TTransport.TException as e:
+
+            if len(resp.tsDataset) == 0:
+                raise RuntimeError(f'No data fetched with query filter: 
{query_filter}')
+
+            data = serde.convert_to_df(resp.columnNameList,
+                                       resp.columnTypeList,
+                                       resp.columnNameIndexMap,
+                                       resp.tsDataset)
+            if data.empty:
+                raise RuntimeError(
+                    f'Fetched empty data with query expressions: 
{query_expressions} and query filter: {query_filter}')
+            return resp.queryId, resp.hasMoreData, data
+        except Exception as e:
+            logger.warn(
+                f'Fail to fetch data with query expressions: 
{query_expressions} and query filter: {query_filter}')
             raise e
 
+    def fetch_window_batch(self,
+                           query_expressions: list,
+                           query_filter: str = None,
+                           fetch_size: int = DEFAULT_FETCH_SIZE,
+                           timeout: int = DEFAULT_TIMEOUT) -> [int, bool, 
List[pd.DataFrame]]:
+        pass
+
     def record_model_metrics(self,
                              model_id: str,
                              trial_id: str,
-                             metrics: list[str],
-                             values: list[float]) -> None:
+                             metrics: List[str],
+                             values: List) -> None:
         req = TRecordModelMetricsReq(
             modelId=model_id,
             trialId=trial_id,
@@ -192,6 +204,7 @@ class ConfigNodeClient(object):
         if self.__config_leader is not None:
             try:
                 self.__connect(self.__config_leader)
+                return
             except TException:
                 logger.warn("The current node {} may have been down, try next 
node", self.__config_leader)
                 self.__config_leader = None
@@ -206,6 +219,7 @@ class ConfigNodeClient(object):
             try_endpoint = self.__config_nodes[self.__cursor]
             try:
                 self.__connect(try_endpoint)
+                return
             except TException:
                 logger.warn("The current node {} may have been down, try next 
node", try_endpoint)
 
@@ -223,7 +237,7 @@ class ConfigNodeClient(object):
             except TTransport.TTransportException as e:
                 logger.exception("TTransportException!", exc_info=e)
 
-        protocol = TCompactProtocol.TBinaryProtocol(transport)
+        protocol = TBinaryProtocol.TBinaryProtocol(transport)
         self.__client = IConfigNodeRPCService.Client(protocol)
 
     def __wait_and_reconnect(self) -> None:
@@ -242,7 +256,7 @@ class ConfigNodeClient(object):
         pass
 
     def __update_config_node_leader(self, status: TSStatus) -> bool:
-        if status.code == REDIRECTION_RECOMMEND:
+        if status.code == TSStatusCode.REDIRECTION_RECOMMEND:
             if status.redirectNode is not None:
                 self.__config_leader = status.redirectNode
             else:
@@ -250,15 +264,38 @@ class ConfigNodeClient(object):
             return True
         return False
 
+    def update_model_state(self,
+                           model_id: str,
+                           training_state: TrainingState,
+                           best_trail_id: str = None) -> None:
+        req = TUpdateModelStateReq(
+            modelId=model_id,
+            state=training_state,
+            bestTrailId=best_trail_id
+        )
+        for i in range(0, self.__RETRY_NUM):
+            try:
+                status = self.__client.updateModelState(req)
+                if not self.__update_config_node_leader(status):
+                    verify_success(status, "An error occurs when calling 
update_model_state()")
+                    return
+            except TTransport.TException:
+                logger.warn("Failed to connect to ConfigNode {} from MLNode 
when executing update_model_info()",
+                            self.__config_leader)
+                self.__config_leader = None
+            self.__wait_and_reconnect()
+
+        raise TException(self.__MSG_RECONNECTION_FAIL)
+
     def update_model_info(self,
                           model_id: str,
                           trial_id: str,
-                          model_info: dict) -> None:
+                          model_info: Dict) -> None:
         if model_info is None:
             model_info = {}
         req = TUpdateModelInfoReq(
             modelId=model_id,
-            trialId=trial_id,
+            trailId=trial_id,
             modelInfo={k: str(v) for k, v in model_info.items()},
         )
 
diff --git a/mlnode/iotdb/mlnode/config.py b/mlnode/iotdb/mlnode/config.py
index e59338209a..0ccfdc2cbb 100644
--- a/mlnode/iotdb/mlnode/config.py
+++ b/mlnode/iotdb/mlnode/config.py
@@ -44,7 +44,7 @@ class MLNodeConfig(object):
         self.__mn_target_config_node: TEndPoint = TEndPoint("127.0.0.1", 10710)
 
         # Target DataNode to be connected by MLNode
-        self.__mn_target_data_node: TEndPoint = TEndPoint("127.0.0.1", 10730)
+        self.__mn_target_data_node: TEndPoint = TEndPoint("127.0.0.1", 10780)
 
     def get_mn_rpc_address(self) -> str:
         return self.__mn_rpc_address
@@ -61,13 +61,13 @@ class MLNodeConfig(object):
     def get_mn_model_storage_dir(self) -> str:
         return self.__mn_model_storage_dir
 
-    def set_mn_model_storage_dir(self, mn_model_storage_dir: str):
+    def set_mn_model_storage_dir(self, mn_model_storage_dir: str) -> None:
         self.__mn_model_storage_dir = mn_model_storage_dir
 
     def get_mn_model_storage_cache_size(self) -> int:
         return self.__mn_model_storage_cache_size
 
-    def set_mn_model_storage_cache_size(self, mn_model_storage_cache_size: 
int):
+    def set_mn_model_storage_cache_size(self, mn_model_storage_cache_size: 
int) -> None:
         self.__mn_model_storage_cache_size = mn_model_storage_cache_size
 
     def get_mn_target_config_node(self) -> TEndPoint:
@@ -86,9 +86,8 @@ class MLNodeConfig(object):
 class MLNodeDescriptor(object):
     def __init__(self):
         self.__config = MLNodeConfig()
-        self.__load_config_from_file()
 
-    def __load_config_from_file(self) -> None:
+    def load_config_from_file(self) -> None:
         conf_file = os.path.join(os.getcwd(), MLNODE_CONF_DIRECTORY_NAME, 
MLNODE_CONF_FILE_NAME)
         if not os.path.exists(conf_file):
             logger.info("Cannot find MLNode config file '{}', use default 
configuration.".format(conf_file))
@@ -113,7 +112,7 @@ class MLNodeDescriptor(object):
                 
self.__config.set_mn_model_storage_dir(file_configs.mn_model_storage_dir)
 
             if file_configs.mn_model_storage_cache_size is not None:
-                
self.__config.set_mn_model_storage_cachesize(file_configs.mn_model_storage_cache_size)
+                
self.__config.set_mn_model_storage_cache_size(file_configs.mn_model_storage_cache_size)
 
             if file_configs.mn_target_config_node is not None:
                 
self.__config.set_mn_target_config_node(file_configs.mn_target_config_node)
@@ -129,4 +128,5 @@ class MLNodeDescriptor(object):
         return self.__config
 
 
-config = MLNodeDescriptor().get_config()
+# initialize a singleton
+descriptor = MLNodeDescriptor()
diff --git a/mlnode/iotdb/mlnode/constant.py b/mlnode/iotdb/mlnode/constant.py
index 8a38aa95d8..68240af12a 100644
--- a/mlnode/iotdb/mlnode/constant.py
+++ b/mlnode/iotdb/mlnode/constant.py
@@ -15,9 +15,19 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+from enum import Enum
 
 MLNODE_CONF_DIRECTORY_NAME = "conf"
 MLNODE_CONF_FILE_NAME = "iotdb-mlnode.toml"
 MLNODE_LOG_CONF_FILE_NAME = "logging_config.ini"
 
 MLNODE_MODEL_STORAGE_DIRECTORY_NAME = "models"
+
+
+class TSStatusCode(Enum):
+    SUCCESS_STATUS = 200
+    REDIRECTION_RECOMMEND = 400
+    MLNODE_INTERNAL_ERROR = 1510
+
+    def get_status_code(self) -> int:
+        return self.value
diff --git a/mlnode/iotdb/mlnode/handler.py b/mlnode/iotdb/mlnode/handler.py
index 8a36353d47..78021420d4 100644
--- a/mlnode/iotdb/mlnode/handler.py
+++ b/mlnode/iotdb/mlnode/handler.py
@@ -15,39 +15,30 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-from enum import Enum
-
-from iotdb.thrift.common.ttypes import TSStatus
+from iotdb.mlnode.constant import TSStatusCode
+from iotdb.mlnode.storage import model_storage
+from iotdb.mlnode.util import get_status
 from iotdb.thrift.mlnode import IMLNodeRPCService
 from iotdb.thrift.mlnode.ttypes import (TCreateTrainingTaskReq,
                                         TDeleteModelReq, TForecastReq,
                                         TForecastResp)
 
 
-class TSStatusCode(Enum):
-    SUCCESS_STATUS = 200
-
-    def get_status_code(self) -> int:
-        return self.value
-
-
-def get_status(status_code: TSStatusCode, message: str) -> TSStatus:
-    status = TSStatus(status_code.get_status_code())
-    status.message = message
-    return status
-
-
 class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
     def __init__(self):
         pass
 
     def deleteModel(self, req: TDeleteModelReq):
-        return get_status(TSStatusCode.SUCCESS_STATUS, "")
+        try:
+            model_storage.delete_model(req.modelId)
+            return get_status(TSStatusCode.SUCCESS_STATUS)
+        except Exception as e:
+            return get_status(TSStatusCode.MLNODE_INTERNAL_ERROR, str(e))
 
     def createTrainingTask(self, req: TCreateTrainingTaskReq):
-        return get_status(TSStatusCode.SUCCESS_STATUS, "")
+        return get_status(TSStatusCode.SUCCESS_STATUS)
 
     def forecast(self, req: TForecastReq):
-        status = get_status(TSStatusCode.SUCCESS_STATUS, "")
+        status = get_status(TSStatusCode.SUCCESS_STATUS)
         forecast_result = b'forecast result'
         return TForecastResp(status, forecast_result)
diff --git a/mlnode/iotdb/mlnode/service.py b/mlnode/iotdb/mlnode/service.py
index a2c05ea5c3..8bf97402d1 100644
--- a/mlnode/iotdb/mlnode/service.py
+++ b/mlnode/iotdb/mlnode/service.py
@@ -22,7 +22,7 @@ from thrift.protocol import TCompactProtocol
 from thrift.server import TServer
 from thrift.transport import TSocket, TTransport
 
-from iotdb.mlnode.config import config
+from iotdb.mlnode.config import descriptor
 from iotdb.mlnode.handler import MLNodeRPCServiceHandler
 from iotdb.mlnode.log import logger
 from iotdb.thrift.mlnode import IMLNodeRPCService
@@ -32,7 +32,8 @@ class RPCService(threading.Thread):
     def __init__(self):
         super().__init__()
         processor = 
IMLNodeRPCService.Processor(handler=MLNodeRPCServiceHandler())
-        transport = TSocket.TServerSocket(host=config.get_mn_rpc_address(), 
port=config.get_mn_rpc_port())
+        transport = 
TSocket.TServerSocket(host=descriptor.get_config().get_mn_rpc_address(),
+                                          
port=descriptor.get_config().get_mn_rpc_port())
         transport_factory = TTransport.TFramedTransportFactory()
         protocol_factory = TCompactProtocol.TCompactProtocolFactory()
 
@@ -45,6 +46,7 @@ class RPCService(threading.Thread):
 
 class MLNode(object):
     def __init__(self):
+        descriptor.load_config_from_file()
         self.__rpc_service = RPCService()
 
     def start(self) -> None:
diff --git a/mlnode/iotdb/mlnode/model_storage.py b/mlnode/iotdb/mlnode/storage.py
similarity index 82%
rename from mlnode/iotdb/mlnode/model_storage.py
rename to mlnode/iotdb/mlnode/storage.py
index ee745689b1..84d7dfd7ed 100644
--- a/mlnode/iotdb/mlnode/model_storage.py
+++ b/mlnode/iotdb/mlnode/storage.py
@@ -19,42 +19,49 @@
 import json
 import os
 import shutil
+from typing import Dict, Tuple
 
 import torch
 import torch.nn as nn
 from pylru import lrucache
 
-from iotdb.mlnode.config import config
+from iotdb.mlnode.config import descriptor
 from iotdb.mlnode.exception import ModelNotExistError
+from iotdb.mlnode.log import logger
 
 
 class ModelStorage(object):
     def __init__(self):
-        self.__model_dir = os.path.join(os.getcwd(), 
config.get_mn_model_storage_dir())
+        self.__model_dir = os.path.join('.', 
descriptor.get_config().get_mn_model_storage_dir())
         if not os.path.exists(self.__model_dir):
-            os.mkdir(self.__model_dir)
+            try:
+                os.mkdir(self.__model_dir)
+            except PermissionError as e:
+                logger.error(e)
+                raise e
 
-        self.__model_cache = lrucache(config.get_mn_model_storage_cache_size())
+        self.__model_cache = 
lrucache(descriptor.get_config().get_mn_model_storage_cache_size())
 
     def save_model(self,
                    model: nn.Module,
-                   model_config: dict,
+                   model_config: Dict,
                    model_id: str,
-                   trial_id: str) -> None:
+                   trial_id: str) -> str:
         """
         Note: model config for time series should contain 'input_len' and 
'input_vars'
         """
         model_dir_path = os.path.join(self.__model_dir, f'{model_id}')
         if not os.path.exists(model_dir_path):
-            os.mkdir(model_dir_path)
+            os.makedirs(model_dir_path)
         model_file_path = os.path.join(model_dir_path, f'{trial_id}.pt')
 
         sample_input = [torch.randn(1, model_config['input_len'], 
model_config['input_vars'])]
         torch.jit.save(torch.jit.trace(model, sample_input),
                        model_file_path,
                        _extra_files={'model_config': json.dumps(model_config)})
+        return os.path.abspath(model_file_path)
 
-    def load_model(self, model_id: str, trial_id: str) -> 
(torch.jit.ScriptModule, dict):
+    def load_model(self, model_id: str, trial_id: str) -> 
Tuple[torch.jit.ScriptModule, Dict]:
         """
         Returns:
             jit_model: a ScriptModule contains model architecture and 
parameters, which can be deployed cross-platform
diff --git a/mlnode/iotdb/mlnode/util.py b/mlnode/iotdb/mlnode/util.py
index 8932479c4a..5cdc52f01a 100644
--- a/mlnode/iotdb/mlnode/util.py
+++ b/mlnode/iotdb/mlnode/util.py
@@ -15,20 +15,18 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+from iotdb.mlnode.constant import TSStatusCode
 from iotdb.mlnode.exception import BadNodeUrlError
 from iotdb.mlnode.log import logger
-from iotdb.thrift.common.ttypes import TEndPoint
+from iotdb.thrift.common.ttypes import TEndPoint, TSStatus
 
 
 def parse_endpoint_url(endpoint_url: str) -> TEndPoint:
     """ Parse TEndPoint from a given endpoint url.
-
     Args:
         endpoint_url: an endpoint url, format: ip:port
-
     Returns:
         TEndPoint
-
     Raises:
         BadNodeUrlError
     """
@@ -45,3 +43,15 @@ def parse_endpoint_url(endpoint_url: str) -> TEndPoint:
     except ValueError as e:
         logger.warning("Illegal endpoint url format: {} 
({})".format(endpoint_url, e))
         raise BadNodeUrlError(endpoint_url)
+
+
+def get_status(status_code: TSStatusCode, message: str = None) -> TSStatus:
+    status = TSStatus(status_code.get_status_code())
+    status.message = message
+    return status
+
+
+def verify_success(status: TSStatus, err_msg: str) -> None:
+    if status.code != TSStatusCode.SUCCESS_STATUS.get_status_code():
+        logger.warn(err_msg + ", error status is ", status)
+        raise RuntimeError(str(status.code) + ": " + status.message)
diff --git a/mlnode/pyproject.toml b/mlnode/pyproject.toml
index 3944e2910d..56290f8d4e 100644
--- a/mlnode/pyproject.toml
+++ b/mlnode/pyproject.toml
@@ -49,6 +49,7 @@ packages = [
 python = "^3.7"
 thrift = "^0.13.0"
 dynaconf = "^3.1.11"
+pylru = "^1.2.1"
 
 [tool.poetry.scripts]
 mlnode = "iotdb.mlnode.script:main"
\ No newline at end of file
diff --git a/mlnode/requirements.txt b/mlnode/requirements.txt
index edd85701ab..c49c8a0189 100644
--- a/mlnode/requirements.txt
+++ b/mlnode/requirements.txt
@@ -20,7 +20,7 @@ pandas>=1.3.5
 numpy>=1.21.4
 apache-iotdb
 poetry
-torch
+torch~=2.0.0
 pylru
 
 thrift~=0.13.0
diff --git a/mlnode/test/test_model_storage.py b/mlnode/test/test_model_storage.py
index 99857db37e..863e73b716 100644
--- a/mlnode/test/test_model_storage.py
+++ b/mlnode/test/test_model_storage.py
@@ -16,26 +16,26 @@
 # under the License.
 #
 
-
 import os
 import time
 
 import torch.nn as nn
 
-from iotdb.mlnode.config import config
-from iotdb.mlnode.model_storage import model_storage
+from iotdb.mlnode.config import descriptor
+from iotdb.mlnode.exception import ModelNotExistError
+from iotdb.mlnode.storage import model_storage
 
 
-class TestModel(nn.Module):
+class ExampleModel(nn.Module):
     def __init__(self):
-        super(TestModel, self).__init__()
+        super(ExampleModel, self).__init__()
         self.layer = nn.Identity()
 
     def forward(self, x):
         return self.layer(x)
 
 
-model = TestModel()
+model = ExampleModel()
 model_config = {
     'input_len': 1,
     'input_vars': 1,
@@ -47,7 +47,8 @@ def test_save_model():
     trial_id = 'tid_0'
     model_id = 'mid_test_model_save'
     model_storage.save_model(model, model_config, model_id=model_id, 
trial_id=trial_id)
-    assert os.path.exists(os.path.join(config.get_mn_model_storage_dir(), 
model_id, f'{trial_id}.pt'))
+    assert os.path.exists(
+        os.path.join(descriptor.get_config().get_mn_model_storage_dir(), 
f'{model_id}', f'{trial_id}.pt'))
 
 
 def test_load_model():
@@ -58,6 +59,17 @@ def test_load_model():
     assert model_config == model_config_loaded
 
 
+def test_load_not_exist_model():
+    trial_id = 'dummy_trial'
+    model_id = 'dummy_model'
+    try:
+        model_loaded, model_config_loaded = 
model_storage.load_model(model_id=model_id, trial_id=trial_id)
+    except Exception as e:
+        assert e.message == ModelNotExistError(
+            os.path.join('.', 
descriptor.get_config().get_mn_model_storage_dir(),
+                         model_id, f'{trial_id}.pt')).message
+
+
 def test_delete_model():
     trial_id1 = 'tid_1'
     trial_id2 = 'tid_2'
@@ -65,9 +77,11 @@ def test_delete_model():
     model_storage.save_model(model, model_config, model_id=model_id, 
trial_id=trial_id1)
     model_storage.save_model(model, model_config, model_id=model_id, 
trial_id=trial_id2)
     model_storage.delete_model(model_id=model_id)
-    assert not os.path.exists(os.path.join(config.get_mn_model_storage_dir(), 
model_id, f'{trial_id1}.pt'))
-    assert not os.path.exists(os.path.join(config.get_mn_model_storage_dir(), 
model_id, f'{trial_id2}.pt'))
-    assert not os.path.exists(os.path.join(config.get_mn_model_storage_dir(), 
model_id))
+    assert not os.path.exists(
+        os.path.join(descriptor.get_config().get_mn_model_storage_dir(), 
f'{model_id}', f'{trial_id1}.pt'))
+    assert not os.path.exists(
+        os.path.join(descriptor.get_config().get_mn_model_storage_dir(), 
f'{model_id}', f'{trial_id2}.pt'))
+    assert not 
os.path.exists(os.path.join(descriptor.get_config().get_mn_model_storage_dir(), 
f'{model_id}'))
 
 
 def test_delete_trial():
@@ -75,4 +89,5 @@ def test_delete_trial():
     model_id = 'mid_test_model_delete'
     model_storage.save_model(model, model_config, model_id=model_id, 
trial_id=trial_id)
     model_storage.delete_trial(model_id=model_id, trial_id=trial_id)
-    assert not os.path.exists(os.path.join(config.get_mn_model_storage_dir(), 
model_id, f'{trial_id}.pt'))
+    assert not os.path.exists(
+        os.path.join(descriptor.get_config().get_mn_model_storage_dir(), 
f'{model_id}', f'{trial_id}.pt'))
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java
index ee59471e6e..432b76a718 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java
@@ -66,6 +66,7 @@ import org.apache.iotdb.db.mpp.plan.statement.metadata.template.UnsetSchemaTempl
 import org.apache.iotdb.db.qp.sql.IoTDBSqlParser;
 import org.apache.iotdb.db.qp.sql.SqlLexer;
 import org.apache.iotdb.db.utils.QueryDataSetUtils;
+import org.apache.iotdb.mpp.rpc.thrift.TDeleteModelMetricsReq;
 import org.apache.iotdb.mpp.rpc.thrift.TFetchTimeseriesReq;
 import org.apache.iotdb.mpp.rpc.thrift.TRecordModelMetricsReq;
 import org.apache.iotdb.service.rpc.thrift.TSAggregationQueryReq;
@@ -111,6 +112,9 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 
+import static 
org.apache.iotdb.commons.conf.IoTDBConstant.MULTI_LEVEL_PATH_WILDCARD;
+import static 
org.apache.iotdb.db.service.thrift.impl.MLNodeRPCServiceImpl.ML_METRICS_PATH_PREFIX;
+
 /** Convert SQL and RPC requests to {@link Statement}. */
 public class StatementGenerator {
   private static final PerformanceOverviewMetrics PERFORMANCE_OVERVIEW_METRICS 
=
@@ -807,10 +811,10 @@ public class StatementGenerator {
     return databasePath;
   }
 
-  public static InsertRowStatement createStatement(
-      TRecordModelMetricsReq recordModelMetricsReq, String prefix) throws 
IllegalPathException {
+  public static InsertRowStatement createStatement(TRecordModelMetricsReq 
recordModelMetricsReq)
+      throws IllegalPathException {
     String path =
-        prefix
+        ML_METRICS_PATH_PREFIX
             + TsFileConstant.PATH_SEPARATOR
             + recordModelMetricsReq.getModelId()
             + TsFileConstant.PATH_SEPARATOR
@@ -873,4 +877,15 @@ public class StatementGenerator {
     }
     return queryStatement;
   }
+
+  public static DeleteTimeSeriesStatement 
createStatement(TDeleteModelMetricsReq req)
+      throws IllegalPathException {
+    String path =
+        ML_METRICS_PATH_PREFIX
+            + TsFileConstant.PATH_SEPARATOR
+            + req.getModelId()
+            + TsFileConstant.PATH_SEPARATOR
+            + MULTI_LEVEL_PATH_WILDCARD;
+    return new DeleteTimeSeriesStatement(Collections.singletonList(new 
PartialPath(path)));
+  }
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
index c50c74e841..b575bc7543 100644
--- a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
+++ b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
@@ -104,6 +104,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.write.DeleteDataNode;
 import org.apache.iotdb.db.mpp.plan.scheduler.load.LoadTsFileScheduler;
 import org.apache.iotdb.db.mpp.plan.statement.component.WhereCondition;
 import org.apache.iotdb.db.mpp.plan.statement.crud.QueryStatement;
+import 
org.apache.iotdb.db.mpp.plan.statement.metadata.DeleteTimeSeriesStatement;
 import org.apache.iotdb.db.pipe.agent.PipeAgent;
 import org.apache.iotdb.db.query.control.SessionManager;
 import org.apache.iotdb.db.query.control.clientsession.IClientSession;
@@ -200,6 +201,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.TimeZone;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
@@ -880,7 +882,30 @@ public class DataNodeInternalRPCServiceImpl implements IDataNodeRPCService.Iface
 
   @Override
   public TSStatus deleteModelMetrics(TDeleteModelMetricsReq req) throws 
TException {
-    return RpcUtils.SUCCESS_STATUS;
+    IClientSession session = new InternalClientSession(req.getModelId());
+    SESSION_MANAGER.registerSession(session);
+    SESSION_MANAGER.supplySession(
+        session, "MLNode", TimeZone.getDefault().getID(), ClientVersion.V_1_0);
+
+    try {
+      DeleteTimeSeriesStatement deleteTimeSeriesStatement = 
StatementGenerator.createStatement(req);
+
+      long queryId = SESSION_MANAGER.requestQueryId();
+      ExecutionResult result =
+          COORDINATOR.execute(
+              deleteTimeSeriesStatement,
+              queryId,
+              SESSION_MANAGER.getSessionInfo(session),
+              "",
+              PARTITION_FETCHER,
+              SCHEMA_FETCHER);
+      return result.status;
+    } catch (Exception e) {
+      return onQueryException(e, OperationType.DELETE_TIMESERIES);
+    } finally {
+      SESSION_MANAGER.closeSession(session, 
COORDINATOR::cleanupQueryExecution);
+      SESSION_MANAGER.removeCurrSession();
+    }
   }
 
   @Override
diff --git a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java
index 544d4cd04c..74e14f9f3f 100644
--- a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java
+++ b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java
@@ -60,14 +60,14 @@ import static org.apache.iotdb.db.utils.ErrorHandlingUtils.onQueryException;
 
 public class MLNodeRPCServiceImpl implements IMLNodeRPCServiceWithHandler {
 
+  public static final String ML_METRICS_PATH_PREFIX = "root.__system.ml.exp";
+
   private static final Logger LOGGER = 
LoggerFactory.getLogger(MLNodeRPCServiceImpl.class);
 
   private static final SessionManager SESSION_MANAGER = 
SessionManager.getInstance();
 
   private static final Coordinator COORDINATOR = Coordinator.getInstance();
 
-  private static final String ML_METRICS_STORAGE_GROUP = 
"root.__system.ml.exp";
-
   private final IPartitionFetcher PARTITION_FETCHER;
 
   private final ISchemaFetcher SCHEMA_FETCHER;
@@ -176,8 +176,7 @@ public class MLNodeRPCServiceImpl implements IMLNodeRPCServiceWithHandler {
   @Override
   public TSStatus recordModelMetrics(TRecordModelMetricsReq req) throws 
TException {
     try {
-      InsertRowStatement insertRowStatement =
-          StatementGenerator.createStatement(req, ML_METRICS_STORAGE_GROUP);
+      InsertRowStatement insertRowStatement = 
StatementGenerator.createStatement(req);
 
       long queryId = SESSION_MANAGER.requestQueryId();
       ExecutionResult result =
diff --git a/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java b/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
index 5f0fd876be..f9ae2b356a 100644
--- a/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
+++ b/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
@@ -186,6 +186,7 @@ public enum TSStatusCode {
   // ML Model
   CREATE_MODEL_ERROR(1500),
   DROP_MODEL_ERROR(1501),
+  MLNODE_INTERNAL_ERROR(1510),
 
   // Pipe Plugin
   CREATE_PIPE_PLUGIN_ERROR(1600),
diff --git a/thrift-mlnode/src/main/thrift/mlnode.thrift b/thrift-mlnode/src/main/thrift/mlnode.thrift
index 916022e973..abadc79576 100644
--- a/thrift-mlnode/src/main/thrift/mlnode.thrift
+++ b/thrift-mlnode/src/main/thrift/mlnode.thrift
@@ -31,7 +31,7 @@ struct TCreateTrainingTaskReq {
 
 struct TDeleteModelReq {
   1: required string modelId
-  2: optional string trailId
+  2: optional string trialId
 }
 
 struct TForecastReq {
diff --git a/thrift/src/main/thrift/datanode.thrift b/thrift/src/main/thrift/datanode.thrift
index 49066f56c2..23cc0fa032 100644
--- a/thrift/src/main/thrift/datanode.thrift
+++ b/thrift/src/main/thrift/datanode.thrift
@@ -809,4 +809,3 @@ service IMLNodeInternalRPCService{
   */
   common.TSStatus recordModelMetrics(TRecordModelMetricsReq req)
 }
-


Reply via email to