This is an automated email from the ASF dual-hosted git repository.
hui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 6a3ade356f [IOTDB-5767] [IoTDB ML] Support delete model file and metrics (#9573)
6a3ade356f is described below
commit 6a3ade356fdd990ba15bf4491a40867d2468504e
Author: liuminghui233 <[email protected]>
AuthorDate: Tue Apr 11 10:37:38 2023 +0800
[IOTDB-5767] [IoTDB ML] Support delete model file and metrics (#9573)
---
.../procedure/impl/model/DropModelProcedure.java | 4 +-
mlnode/.gitignore | 6 +-
mlnode/iotdb/mlnode/client.py | 107 ++++++++++++++-------
mlnode/iotdb/mlnode/config.py | 14 +--
mlnode/iotdb/mlnode/constant.py | 10 ++
mlnode/iotdb/mlnode/handler.py | 29 ++----
mlnode/iotdb/mlnode/service.py | 6 +-
.../iotdb/mlnode/{model_storage.py => storage.py} | 23 +++--
mlnode/iotdb/mlnode/util.py | 18 +++-
mlnode/pyproject.toml | 1 +
mlnode/requirements.txt | 2 +-
mlnode/test/test_model_storage.py | 37 ++++---
.../db/mpp/plan/parser/StatementGenerator.java | 21 +++-
.../impl/DataNodeInternalRPCServiceImpl.java | 27 +++++-
.../service/thrift/impl/MLNodeRPCServiceImpl.java | 7 +-
.../java/org/apache/iotdb/rpc/TSStatusCode.java | 1 +
thrift-mlnode/src/main/thrift/mlnode.thrift | 2 +-
thrift/src/main/thrift/datanode.thrift | 1 -
18 files changed, 217 insertions(+), 99 deletions(-)
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
index bfa461a8b2..8f41de070d 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
@@ -156,7 +156,9 @@ public class DropModelProcedure extends AbstractNodeProcedure<DropModelState> {
if (getCycles() > RETRY_THRESHOLD) {
setFailure(
new ProcedureException(
- String.format("Fail to drop model [%s] at STATE [%s]",
modelId, state)));
+ String.format(
+ "Fail to drop model [%s] at STATE [%s], %s",
+ modelId, state, e.getMessage())));
}
}
}
diff --git a/mlnode/.gitignore b/mlnode/.gitignore
index 94606bf62f..4a6f6eac65 100644
--- a/mlnode/.gitignore
+++ b/mlnode/.gitignore
@@ -1,4 +1,8 @@
+# generated by Thrift
/iotdb/thrift/
# generated by Poetry
-/dist/
\ No newline at end of file
+/dist/
+
+# generated by MLNode
+*.pt
\ No newline at end of file
diff --git a/mlnode/iotdb/mlnode/client.py b/mlnode/iotdb/mlnode/client.py
index 244b6975c9..3157006e57 100644
--- a/mlnode/iotdb/mlnode/client.py
+++ b/mlnode/iotdb/mlnode/client.py
@@ -16,38 +16,33 @@
# under the License.
#
import time
+from typing import Dict, List
+import pandas as pd
from thrift.protocol import TBinaryProtocol, TCompactProtocol
from thrift.Thrift import TException
from thrift.transport import TSocket, TTransport
-from iotdb.mlnode.config import config
+from iotdb.mlnode import serde
+from iotdb.mlnode.config import descriptor
+from iotdb.mlnode.constant import TSStatusCode
from iotdb.mlnode.log import logger
-from iotdb.thrift.common.ttypes import TEndPoint, TSStatus
+from iotdb.mlnode.util import verify_success
+from iotdb.thrift.common.ttypes import TEndPoint, TrainingState, TSStatus
from iotdb.thrift.confignode import IConfigNodeRPCService
-from iotdb.thrift.confignode.ttypes import TUpdateModelInfoReq
-from iotdb.thrift.datanode import IDataNodeRPCService
+from iotdb.thrift.confignode.ttypes import (TUpdateModelInfoReq,
+ TUpdateModelStateReq)
+from iotdb.thrift.datanode import IMLNodeInternalRPCService
from iotdb.thrift.datanode.ttypes import (TFetchTimeseriesReq,
- TFetchTimeseriesResp,
TRecordModelMetricsReq)
from iotdb.thrift.mlnode import IMLNodeRPCService
from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq, TDeleteModelReq
-# status code
-SUCCESS_STATUS = 200
-REDIRECTION_RECOMMEND = 400
-
-
-def verify_success(status: TSStatus, err_msg: str) -> None:
- if status.code != SUCCESS_STATUS:
- logger.warn(err_msg + ", error status is ", status)
- raise RuntimeError(str(status.code) + ": " + status.message)
-
class ClientManager(object):
def __init__(self):
- self.__data_node_endpoint = config.get_mn_target_data_node()
- self.__config_node_endpoint = config.get_mn_target_config_node()
+ self.__data_node_endpoint =
descriptor.get_config().get_mn_target_data_node()
+ self.__config_node_endpoint =
descriptor.get_config().get_mn_target_config_node()
def borrow_data_node_client(self):
return DataNodeClient(host=self.__data_node_endpoint.ip,
@@ -77,9 +72,9 @@ class MLNodeClient(object):
def create_training_task(self,
model_id: str,
is_auto: bool,
- model_configs: dict,
- query_expressions: list[str],
- query_filter: str = None) -> None:
+ model_configs: Dict,
+ query_expressions: List[str],
+ query_filter: str = '') -> None:
req = TCreateTrainingTaskReq(
modelId=model_id,
isAuto=is_auto,
@@ -124,20 +119,17 @@ class DataNodeClient(object):
transport.open()
except TTransport.TTransportException as e:
logger.exception("TTransportException!", exc_info=e)
+ raise e
protocol = TBinaryProtocol.TBinaryProtocol(transport)
- self.__client = IDataNodeRPCService.Client(protocol)
+ self.__client = IMLNodeInternalRPCService.Client(protocol)
def fetch_timeseries(self,
- session_id: int,
- statement_id: int,
- query_expressions: list[str],
+ query_expressions: List[str],
query_filter: str = None,
fetch_size: int = DEFAULT_FETCH_SIZE,
- timeout: int = DEFAULT_TIMEOUT) ->
TFetchTimeseriesResp:
+ timeout: int = DEFAULT_TIMEOUT) -> [int, bool,
pd.DataFrame]:
req = TFetchTimeseriesReq(
- sessionId=session_id,
- statementId=statement_id,
queryExpressions=query_expressions,
queryFilter=query_filter,
fetchSize=fetch_size,
@@ -146,15 +138,35 @@ class DataNodeClient(object):
try:
resp = self.__client.fetchTimeseries(req)
verify_success(resp.status, "An error occurs when calling
fetch_timeseries()")
- return resp
- except TTransport.TException as e:
+
+ if len(resp.tsDataset) == 0:
+ raise RuntimeError(f'No data fetched with query filter:
{query_filter}')
+
+ data = serde.convert_to_df(resp.columnNameList,
+ resp.columnTypeList,
+ resp.columnNameIndexMap,
+ resp.tsDataset)
+ if data.empty:
+ raise RuntimeError(
+ f'Fetched empty data with query expressions:
{query_expressions} and query filter: {query_filter}')
+ return resp.queryId, resp.hasMoreData, data
+ except Exception as e:
+ logger.warn(
+ f'Fail to fetch data with query expressions:
{query_expressions} and query filter: {query_filter}')
raise e
+ def fetch_window_batch(self,
+ query_expressions: list,
+ query_filter: str = None,
+ fetch_size: int = DEFAULT_FETCH_SIZE,
+ timeout: int = DEFAULT_TIMEOUT) -> [int, bool,
List[pd.DataFrame]]:
+ pass
+
def record_model_metrics(self,
model_id: str,
trial_id: str,
- metrics: list[str],
- values: list[float]) -> None:
+ metrics: List[str],
+ values: List) -> None:
req = TRecordModelMetricsReq(
modelId=model_id,
trialId=trial_id,
@@ -192,6 +204,7 @@ class ConfigNodeClient(object):
if self.__config_leader is not None:
try:
self.__connect(self.__config_leader)
+ return
except TException:
logger.warn("The current node {} may have been down, try next
node", self.__config_leader)
self.__config_leader = None
@@ -206,6 +219,7 @@ class ConfigNodeClient(object):
try_endpoint = self.__config_nodes[self.__cursor]
try:
self.__connect(try_endpoint)
+ return
except TException:
logger.warn("The current node {} may have been down, try next
node", try_endpoint)
@@ -223,7 +237,7 @@ class ConfigNodeClient(object):
except TTransport.TTransportException as e:
logger.exception("TTransportException!", exc_info=e)
- protocol = TCompactProtocol.TBinaryProtocol(transport)
+ protocol = TBinaryProtocol.TBinaryProtocol(transport)
self.__client = IConfigNodeRPCService.Client(protocol)
def __wait_and_reconnect(self) -> None:
@@ -242,7 +256,7 @@ class ConfigNodeClient(object):
pass
def __update_config_node_leader(self, status: TSStatus) -> bool:
- if status.code == REDIRECTION_RECOMMEND:
+ if status.code == TSStatusCode.REDIRECTION_RECOMMEND:
if status.redirectNode is not None:
self.__config_leader = status.redirectNode
else:
@@ -250,15 +264,38 @@ class ConfigNodeClient(object):
return True
return False
+ def update_model_state(self,
+ model_id: str,
+ training_state: TrainingState,
+ best_trail_id: str = None) -> None:
+ req = TUpdateModelStateReq(
+ modelId=model_id,
+ state=training_state,
+ bestTrailId=best_trail_id
+ )
+ for i in range(0, self.__RETRY_NUM):
+ try:
+ status = self.__client.updateModelState(req)
+ if not self.__update_config_node_leader(status):
+ verify_success(status, "An error occurs when calling
update_model_state()")
+ return
+ except TTransport.TException:
+ logger.warn("Failed to connect to ConfigNode {} from MLNode
when executing update_model_info()",
+ self.__config_leader)
+ self.__config_leader = None
+ self.__wait_and_reconnect()
+
+ raise TException(self.__MSG_RECONNECTION_FAIL)
+
def update_model_info(self,
model_id: str,
trial_id: str,
- model_info: dict) -> None:
+ model_info: Dict) -> None:
if model_info is None:
model_info = {}
req = TUpdateModelInfoReq(
modelId=model_id,
- trialId=trial_id,
+ trailId=trial_id,
modelInfo={k: str(v) for k, v in model_info.items()},
)
diff --git a/mlnode/iotdb/mlnode/config.py b/mlnode/iotdb/mlnode/config.py
index e59338209a..0ccfdc2cbb 100644
--- a/mlnode/iotdb/mlnode/config.py
+++ b/mlnode/iotdb/mlnode/config.py
@@ -44,7 +44,7 @@ class MLNodeConfig(object):
self.__mn_target_config_node: TEndPoint = TEndPoint("127.0.0.1", 10710)
# Target DataNode to be connected by MLNode
- self.__mn_target_data_node: TEndPoint = TEndPoint("127.0.0.1", 10730)
+ self.__mn_target_data_node: TEndPoint = TEndPoint("127.0.0.1", 10780)
def get_mn_rpc_address(self) -> str:
return self.__mn_rpc_address
@@ -61,13 +61,13 @@ class MLNodeConfig(object):
def get_mn_model_storage_dir(self) -> str:
return self.__mn_model_storage_dir
- def set_mn_model_storage_dir(self, mn_model_storage_dir: str):
+ def set_mn_model_storage_dir(self, mn_model_storage_dir: str) -> None:
self.__mn_model_storage_dir = mn_model_storage_dir
def get_mn_model_storage_cache_size(self) -> int:
return self.__mn_model_storage_cache_size
- def set_mn_model_storage_cache_size(self, mn_model_storage_cache_size:
int):
+ def set_mn_model_storage_cache_size(self, mn_model_storage_cache_size:
int) -> None:
self.__mn_model_storage_cache_size = mn_model_storage_cache_size
def get_mn_target_config_node(self) -> TEndPoint:
@@ -86,9 +86,8 @@ class MLNodeConfig(object):
class MLNodeDescriptor(object):
def __init__(self):
self.__config = MLNodeConfig()
- self.__load_config_from_file()
- def __load_config_from_file(self) -> None:
+ def load_config_from_file(self) -> None:
conf_file = os.path.join(os.getcwd(), MLNODE_CONF_DIRECTORY_NAME,
MLNODE_CONF_FILE_NAME)
if not os.path.exists(conf_file):
logger.info("Cannot find MLNode config file '{}', use default
configuration.".format(conf_file))
@@ -113,7 +112,7 @@ class MLNodeDescriptor(object):
self.__config.set_mn_model_storage_dir(file_configs.mn_model_storage_dir)
if file_configs.mn_model_storage_cache_size is not None:
-
self.__config.set_mn_model_storage_cachesize(file_configs.mn_model_storage_cache_size)
+
self.__config.set_mn_model_storage_cache_size(file_configs.mn_model_storage_cache_size)
if file_configs.mn_target_config_node is not None:
self.__config.set_mn_target_config_node(file_configs.mn_target_config_node)
@@ -129,4 +128,5 @@ class MLNodeDescriptor(object):
return self.__config
-config = MLNodeDescriptor().get_config()
+# initialize a singleton
+descriptor = MLNodeDescriptor()
diff --git a/mlnode/iotdb/mlnode/constant.py b/mlnode/iotdb/mlnode/constant.py
index 8a38aa95d8..68240af12a 100644
--- a/mlnode/iotdb/mlnode/constant.py
+++ b/mlnode/iotdb/mlnode/constant.py
@@ -15,9 +15,19 @@
# specific language governing permissions and limitations
# under the License.
#
+from enum import Enum
MLNODE_CONF_DIRECTORY_NAME = "conf"
MLNODE_CONF_FILE_NAME = "iotdb-mlnode.toml"
MLNODE_LOG_CONF_FILE_NAME = "logging_config.ini"
MLNODE_MODEL_STORAGE_DIRECTORY_NAME = "models"
+
+
+class TSStatusCode(Enum):
+ SUCCESS_STATUS = 200
+ REDIRECTION_RECOMMEND = 400
+ MLNODE_INTERNAL_ERROR = 1510
+
+ def get_status_code(self) -> int:
+ return self.value
diff --git a/mlnode/iotdb/mlnode/handler.py b/mlnode/iotdb/mlnode/handler.py
index 8a36353d47..78021420d4 100644
--- a/mlnode/iotdb/mlnode/handler.py
+++ b/mlnode/iotdb/mlnode/handler.py
@@ -15,39 +15,30 @@
# specific language governing permissions and limitations
# under the License.
#
-from enum import Enum
-
-from iotdb.thrift.common.ttypes import TSStatus
+from iotdb.mlnode.constant import TSStatusCode
+from iotdb.mlnode.storage import model_storage
+from iotdb.mlnode.util import get_status
from iotdb.thrift.mlnode import IMLNodeRPCService
from iotdb.thrift.mlnode.ttypes import (TCreateTrainingTaskReq,
TDeleteModelReq, TForecastReq,
TForecastResp)
-class TSStatusCode(Enum):
- SUCCESS_STATUS = 200
-
- def get_status_code(self) -> int:
- return self.value
-
-
-def get_status(status_code: TSStatusCode, message: str) -> TSStatus:
- status = TSStatus(status_code.get_status_code())
- status.message = message
- return status
-
-
class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
def __init__(self):
pass
def deleteModel(self, req: TDeleteModelReq):
- return get_status(TSStatusCode.SUCCESS_STATUS, "")
+ try:
+ model_storage.delete_model(req.modelId)
+ return get_status(TSStatusCode.SUCCESS_STATUS)
+ except Exception as e:
+ return get_status(TSStatusCode.MLNODE_INTERNAL_ERROR, str(e))
def createTrainingTask(self, req: TCreateTrainingTaskReq):
- return get_status(TSStatusCode.SUCCESS_STATUS, "")
+ return get_status(TSStatusCode.SUCCESS_STATUS)
def forecast(self, req: TForecastReq):
- status = get_status(TSStatusCode.SUCCESS_STATUS, "")
+ status = get_status(TSStatusCode.SUCCESS_STATUS)
forecast_result = b'forecast result'
return TForecastResp(status, forecast_result)
diff --git a/mlnode/iotdb/mlnode/service.py b/mlnode/iotdb/mlnode/service.py
index a2c05ea5c3..8bf97402d1 100644
--- a/mlnode/iotdb/mlnode/service.py
+++ b/mlnode/iotdb/mlnode/service.py
@@ -22,7 +22,7 @@ from thrift.protocol import TCompactProtocol
from thrift.server import TServer
from thrift.transport import TSocket, TTransport
-from iotdb.mlnode.config import config
+from iotdb.mlnode.config import descriptor
from iotdb.mlnode.handler import MLNodeRPCServiceHandler
from iotdb.mlnode.log import logger
from iotdb.thrift.mlnode import IMLNodeRPCService
@@ -32,7 +32,8 @@ class RPCService(threading.Thread):
def __init__(self):
super().__init__()
processor =
IMLNodeRPCService.Processor(handler=MLNodeRPCServiceHandler())
- transport = TSocket.TServerSocket(host=config.get_mn_rpc_address(),
port=config.get_mn_rpc_port())
+ transport =
TSocket.TServerSocket(host=descriptor.get_config().get_mn_rpc_address(),
+
port=descriptor.get_config().get_mn_rpc_port())
transport_factory = TTransport.TFramedTransportFactory()
protocol_factory = TCompactProtocol.TCompactProtocolFactory()
@@ -45,6 +46,7 @@ class RPCService(threading.Thread):
class MLNode(object):
def __init__(self):
+ descriptor.load_config_from_file()
self.__rpc_service = RPCService()
def start(self) -> None:
diff --git a/mlnode/iotdb/mlnode/model_storage.py b/mlnode/iotdb/mlnode/storage.py
similarity index 82%
rename from mlnode/iotdb/mlnode/model_storage.py
rename to mlnode/iotdb/mlnode/storage.py
index ee745689b1..84d7dfd7ed 100644
--- a/mlnode/iotdb/mlnode/model_storage.py
+++ b/mlnode/iotdb/mlnode/storage.py
@@ -19,42 +19,49 @@
import json
import os
import shutil
+from typing import Dict, Tuple
import torch
import torch.nn as nn
from pylru import lrucache
-from iotdb.mlnode.config import config
+from iotdb.mlnode.config import descriptor
from iotdb.mlnode.exception import ModelNotExistError
+from iotdb.mlnode.log import logger
class ModelStorage(object):
def __init__(self):
- self.__model_dir = os.path.join(os.getcwd(),
config.get_mn_model_storage_dir())
+ self.__model_dir = os.path.join('.',
descriptor.get_config().get_mn_model_storage_dir())
if not os.path.exists(self.__model_dir):
- os.mkdir(self.__model_dir)
+ try:
+ os.mkdir(self.__model_dir)
+ except PermissionError as e:
+ logger.error(e)
+ raise e
- self.__model_cache = lrucache(config.get_mn_model_storage_cache_size())
+ self.__model_cache =
lrucache(descriptor.get_config().get_mn_model_storage_cache_size())
def save_model(self,
model: nn.Module,
- model_config: dict,
+ model_config: Dict,
model_id: str,
- trial_id: str) -> None:
+ trial_id: str) -> str:
"""
Note: model config for time series should contain 'input_len' and
'input_vars'
"""
model_dir_path = os.path.join(self.__model_dir, f'{model_id}')
if not os.path.exists(model_dir_path):
- os.mkdir(model_dir_path)
+ os.makedirs(model_dir_path)
model_file_path = os.path.join(model_dir_path, f'{trial_id}.pt')
sample_input = [torch.randn(1, model_config['input_len'],
model_config['input_vars'])]
torch.jit.save(torch.jit.trace(model, sample_input),
model_file_path,
_extra_files={'model_config': json.dumps(model_config)})
+ return os.path.abspath(model_file_path)
- def load_model(self, model_id: str, trial_id: str) ->
(torch.jit.ScriptModule, dict):
+ def load_model(self, model_id: str, trial_id: str) ->
Tuple[torch.jit.ScriptModule, Dict]:
"""
Returns:
jit_model: a ScriptModule contains model architecture and
parameters, which can be deployed cross-platform
diff --git a/mlnode/iotdb/mlnode/util.py b/mlnode/iotdb/mlnode/util.py
index 8932479c4a..5cdc52f01a 100644
--- a/mlnode/iotdb/mlnode/util.py
+++ b/mlnode/iotdb/mlnode/util.py
@@ -15,20 +15,18 @@
# specific language governing permissions and limitations
# under the License.
#
+from iotdb.mlnode.constant import TSStatusCode
from iotdb.mlnode.exception import BadNodeUrlError
from iotdb.mlnode.log import logger
-from iotdb.thrift.common.ttypes import TEndPoint
+from iotdb.thrift.common.ttypes import TEndPoint, TSStatus
def parse_endpoint_url(endpoint_url: str) -> TEndPoint:
""" Parse TEndPoint from a given endpoint url.
-
Args:
endpoint_url: an endpoint url, format: ip:port
-
Returns:
TEndPoint
-
Raises:
BadNodeUrlError
"""
@@ -45,3 +43,15 @@ def parse_endpoint_url(endpoint_url: str) -> TEndPoint:
except ValueError as e:
logger.warning("Illegal endpoint url format: {}
({})".format(endpoint_url, e))
raise BadNodeUrlError(endpoint_url)
+
+
+def get_status(status_code: TSStatusCode, message: str = None) -> TSStatus:
+ status = TSStatus(status_code.get_status_code())
+ status.message = message
+ return status
+
+
+def verify_success(status: TSStatus, err_msg: str) -> None:
+ if status.code != TSStatusCode.SUCCESS_STATUS.get_status_code():
+ logger.warn(err_msg + ", error status is ", status)
+ raise RuntimeError(str(status.code) + ": " + status.message)
diff --git a/mlnode/pyproject.toml b/mlnode/pyproject.toml
index 3944e2910d..56290f8d4e 100644
--- a/mlnode/pyproject.toml
+++ b/mlnode/pyproject.toml
@@ -49,6 +49,7 @@ packages = [
python = "^3.7"
thrift = "^0.13.0"
dynaconf = "^3.1.11"
+pylru = "^1.2.1"
[tool.poetry.scripts]
mlnode = "iotdb.mlnode.script:main"
\ No newline at end of file
diff --git a/mlnode/requirements.txt b/mlnode/requirements.txt
index edd85701ab..c49c8a0189 100644
--- a/mlnode/requirements.txt
+++ b/mlnode/requirements.txt
@@ -20,7 +20,7 @@ pandas>=1.3.5
numpy>=1.21.4
apache-iotdb
poetry
-torch
+torch~=2.0.0
pylru
thrift~=0.13.0
diff --git a/mlnode/test/test_model_storage.py b/mlnode/test/test_model_storage.py
index 99857db37e..863e73b716 100644
--- a/mlnode/test/test_model_storage.py
+++ b/mlnode/test/test_model_storage.py
@@ -16,26 +16,26 @@
# under the License.
#
-
import os
import time
import torch.nn as nn
-from iotdb.mlnode.config import config
-from iotdb.mlnode.model_storage import model_storage
+from iotdb.mlnode.config import descriptor
+from iotdb.mlnode.exception import ModelNotExistError
+from iotdb.mlnode.storage import model_storage
-class TestModel(nn.Module):
+class ExampleModel(nn.Module):
def __init__(self):
- super(TestModel, self).__init__()
+ super(ExampleModel, self).__init__()
self.layer = nn.Identity()
def forward(self, x):
return self.layer(x)
-model = TestModel()
+model = ExampleModel()
model_config = {
'input_len': 1,
'input_vars': 1,
@@ -47,7 +47,8 @@ def test_save_model():
trial_id = 'tid_0'
model_id = 'mid_test_model_save'
model_storage.save_model(model, model_config, model_id=model_id,
trial_id=trial_id)
- assert os.path.exists(os.path.join(config.get_mn_model_storage_dir(),
model_id, f'{trial_id}.pt'))
+ assert os.path.exists(
+ os.path.join(descriptor.get_config().get_mn_model_storage_dir(),
f'{model_id}', f'{trial_id}.pt'))
def test_load_model():
@@ -58,6 +59,17 @@ def test_load_model():
assert model_config == model_config_loaded
+def test_load_not_exist_model():
+ trial_id = 'dummy_trial'
+ model_id = 'dummy_model'
+ try:
+ model_loaded, model_config_loaded =
model_storage.load_model(model_id=model_id, trial_id=trial_id)
+ except Exception as e:
+ assert e.message == ModelNotExistError(
+ os.path.join('.',
descriptor.get_config().get_mn_model_storage_dir(),
+ model_id, f'{trial_id}.pt')).message
+
+
def test_delete_model():
trial_id1 = 'tid_1'
trial_id2 = 'tid_2'
@@ -65,9 +77,11 @@ def test_delete_model():
model_storage.save_model(model, model_config, model_id=model_id,
trial_id=trial_id1)
model_storage.save_model(model, model_config, model_id=model_id,
trial_id=trial_id2)
model_storage.delete_model(model_id=model_id)
- assert not os.path.exists(os.path.join(config.get_mn_model_storage_dir(),
model_id, f'{trial_id1}.pt'))
- assert not os.path.exists(os.path.join(config.get_mn_model_storage_dir(),
model_id, f'{trial_id2}.pt'))
- assert not os.path.exists(os.path.join(config.get_mn_model_storage_dir(),
model_id))
+ assert not os.path.exists(
+ os.path.join(descriptor.get_config().get_mn_model_storage_dir(),
f'{model_id}', f'{trial_id1}.pt'))
+ assert not os.path.exists(
+ os.path.join(descriptor.get_config().get_mn_model_storage_dir(),
f'{model_id}', f'{trial_id2}.pt'))
+ assert not
os.path.exists(os.path.join(descriptor.get_config().get_mn_model_storage_dir(),
f'{model_id}'))
def test_delete_trial():
@@ -75,4 +89,5 @@ def test_delete_trial():
model_id = 'mid_test_model_delete'
model_storage.save_model(model, model_config, model_id=model_id,
trial_id=trial_id)
model_storage.delete_trial(model_id=model_id, trial_id=trial_id)
- assert not os.path.exists(os.path.join(config.get_mn_model_storage_dir(),
model_id, f'{trial_id}.pt'))
+ assert not os.path.exists(
+ os.path.join(descriptor.get_config().get_mn_model_storage_dir(),
f'{model_id}', f'{trial_id}.pt'))
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java
index ee59471e6e..432b76a718 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java
@@ -66,6 +66,7 @@ import org.apache.iotdb.db.mpp.plan.statement.metadata.template.UnsetSchemaTempl
import org.apache.iotdb.db.qp.sql.IoTDBSqlParser;
import org.apache.iotdb.db.qp.sql.SqlLexer;
import org.apache.iotdb.db.utils.QueryDataSetUtils;
+import org.apache.iotdb.mpp.rpc.thrift.TDeleteModelMetricsReq;
import org.apache.iotdb.mpp.rpc.thrift.TFetchTimeseriesReq;
import org.apache.iotdb.mpp.rpc.thrift.TRecordModelMetricsReq;
import org.apache.iotdb.service.rpc.thrift.TSAggregationQueryReq;
@@ -111,6 +112,9 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
+import static
org.apache.iotdb.commons.conf.IoTDBConstant.MULTI_LEVEL_PATH_WILDCARD;
+import static
org.apache.iotdb.db.service.thrift.impl.MLNodeRPCServiceImpl.ML_METRICS_PATH_PREFIX;
+
/** Convert SQL and RPC requests to {@link Statement}. */
public class StatementGenerator {
private static final PerformanceOverviewMetrics PERFORMANCE_OVERVIEW_METRICS
=
@@ -807,10 +811,10 @@ public class StatementGenerator {
return databasePath;
}
- public static InsertRowStatement createStatement(
- TRecordModelMetricsReq recordModelMetricsReq, String prefix) throws
IllegalPathException {
+ public static InsertRowStatement createStatement(TRecordModelMetricsReq
recordModelMetricsReq)
+ throws IllegalPathException {
String path =
- prefix
+ ML_METRICS_PATH_PREFIX
+ TsFileConstant.PATH_SEPARATOR
+ recordModelMetricsReq.getModelId()
+ TsFileConstant.PATH_SEPARATOR
@@ -873,4 +877,15 @@ public class StatementGenerator {
}
return queryStatement;
}
+
+ public static DeleteTimeSeriesStatement
createStatement(TDeleteModelMetricsReq req)
+ throws IllegalPathException {
+ String path =
+ ML_METRICS_PATH_PREFIX
+ + TsFileConstant.PATH_SEPARATOR
+ + req.getModelId()
+ + TsFileConstant.PATH_SEPARATOR
+ + MULTI_LEVEL_PATH_WILDCARD;
+ return new DeleteTimeSeriesStatement(Collections.singletonList(new
PartialPath(path)));
+ }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
index c50c74e841..b575bc7543 100644
--- a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
+++ b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
@@ -104,6 +104,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.write.DeleteDataNode;
import org.apache.iotdb.db.mpp.plan.scheduler.load.LoadTsFileScheduler;
import org.apache.iotdb.db.mpp.plan.statement.component.WhereCondition;
import org.apache.iotdb.db.mpp.plan.statement.crud.QueryStatement;
+import
org.apache.iotdb.db.mpp.plan.statement.metadata.DeleteTimeSeriesStatement;
import org.apache.iotdb.db.pipe.agent.PipeAgent;
import org.apache.iotdb.db.query.control.SessionManager;
import org.apache.iotdb.db.query.control.clientsession.IClientSession;
@@ -200,6 +201,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.TimeZone;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
@@ -880,7 +882,30 @@ public class DataNodeInternalRPCServiceImpl implements IDataNodeRPCService.Iface
@Override
public TSStatus deleteModelMetrics(TDeleteModelMetricsReq req) throws
TException {
- return RpcUtils.SUCCESS_STATUS;
+ IClientSession session = new InternalClientSession(req.getModelId());
+ SESSION_MANAGER.registerSession(session);
+ SESSION_MANAGER.supplySession(
+ session, "MLNode", TimeZone.getDefault().getID(), ClientVersion.V_1_0);
+
+ try {
+ DeleteTimeSeriesStatement deleteTimeSeriesStatement =
StatementGenerator.createStatement(req);
+
+ long queryId = SESSION_MANAGER.requestQueryId();
+ ExecutionResult result =
+ COORDINATOR.execute(
+ deleteTimeSeriesStatement,
+ queryId,
+ SESSION_MANAGER.getSessionInfo(session),
+ "",
+ PARTITION_FETCHER,
+ SCHEMA_FETCHER);
+ return result.status;
+ } catch (Exception e) {
+ return onQueryException(e, OperationType.DELETE_TIMESERIES);
+ } finally {
+ SESSION_MANAGER.closeSession(session,
COORDINATOR::cleanupQueryExecution);
+ SESSION_MANAGER.removeCurrSession();
+ }
}
@Override
diff --git a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java
index 544d4cd04c..74e14f9f3f 100644
--- a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java
+++ b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java
@@ -60,14 +60,14 @@ import static org.apache.iotdb.db.utils.ErrorHandlingUtils.onQueryException;
public class MLNodeRPCServiceImpl implements IMLNodeRPCServiceWithHandler {
+ public static final String ML_METRICS_PATH_PREFIX = "root.__system.ml.exp";
+
private static final Logger LOGGER =
LoggerFactory.getLogger(MLNodeRPCServiceImpl.class);
private static final SessionManager SESSION_MANAGER =
SessionManager.getInstance();
private static final Coordinator COORDINATOR = Coordinator.getInstance();
- private static final String ML_METRICS_STORAGE_GROUP =
"root.__system.ml.exp";
-
private final IPartitionFetcher PARTITION_FETCHER;
private final ISchemaFetcher SCHEMA_FETCHER;
@@ -176,8 +176,7 @@ public class MLNodeRPCServiceImpl implements IMLNodeRPCServiceWithHandler {
@Override
public TSStatus recordModelMetrics(TRecordModelMetricsReq req) throws
TException {
try {
- InsertRowStatement insertRowStatement =
- StatementGenerator.createStatement(req, ML_METRICS_STORAGE_GROUP);
+ InsertRowStatement insertRowStatement =
StatementGenerator.createStatement(req);
long queryId = SESSION_MANAGER.requestQueryId();
ExecutionResult result =
diff --git a/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java b/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
index 5f0fd876be..f9ae2b356a 100644
--- a/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
+++ b/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
@@ -186,6 +186,7 @@ public enum TSStatusCode {
// ML Model
CREATE_MODEL_ERROR(1500),
DROP_MODEL_ERROR(1501),
+ MLNODE_INTERNAL_ERROR(1510),
// Pipe Plugin
CREATE_PIPE_PLUGIN_ERROR(1600),
diff --git a/thrift-mlnode/src/main/thrift/mlnode.thrift b/thrift-mlnode/src/main/thrift/mlnode.thrift
index 916022e973..abadc79576 100644
--- a/thrift-mlnode/src/main/thrift/mlnode.thrift
+++ b/thrift-mlnode/src/main/thrift/mlnode.thrift
@@ -31,7 +31,7 @@ struct TCreateTrainingTaskReq {
struct TDeleteModelReq {
1: required string modelId
- 2: optional string trailId
+ 2: optional string trialId
}
struct TForecastReq {
diff --git a/thrift/src/main/thrift/datanode.thrift b/thrift/src/main/thrift/datanode.thrift
index 49066f56c2..23cc0fa032 100644
--- a/thrift/src/main/thrift/datanode.thrift
+++ b/thrift/src/main/thrift/datanode.thrift
@@ -809,4 +809,3 @@ service IMLNodeInternalRPCService{
*/
common.TSStatus recordModelMetrics(TRecordModelMetricsReq req)
}
-