This is an automated email from the ASF dual-hosted git repository. hui pushed a commit to branch mlnode/test in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 833d0619ed19711132042578ae0ea6123bec77c4 Author: Minghui Liu <[email protected]> AuthorDate: Mon Apr 3 16:11:31 2023 +0800 make mlnode available --- mlnode/iotdb/mlnode/algorithm/enums.py | 3 ++ mlnode/iotdb/mlnode/algorithm/factory.py | 1 + .../mlnode/algorithm/models/forecast/dlinear.py | 3 +- mlnode/iotdb/mlnode/client.py | 32 ++++++------ mlnode/iotdb/mlnode/config.py | 10 ++-- mlnode/iotdb/mlnode/constant.py | 6 --- mlnode/iotdb/mlnode/data_access/enums.py | 3 ++ mlnode/iotdb/mlnode/data_access/offline/source.py | 4 +- mlnode/iotdb/mlnode/handler.py | 32 +++++------- mlnode/iotdb/mlnode/parser.py | 7 ++- mlnode/iotdb/mlnode/process/manager.py | 36 +++++++------ mlnode/iotdb/mlnode/process/task.py | 4 +- mlnode/iotdb/mlnode/process/task_factory.py | 2 +- mlnode/iotdb/mlnode/process/trial.py | 59 ++++++++++++---------- mlnode/iotdb/mlnode/service.py | 11 ++-- mlnode/iotdb/mlnode/storage.py | 11 ++-- mlnode/iotdb/mlnode/util.py | 2 +- mlnode/pyproject.toml | 1 + mlnode/requirements.txt | 2 +- 19 files changed, 122 insertions(+), 107 deletions(-) diff --git a/mlnode/iotdb/mlnode/algorithm/enums.py b/mlnode/iotdb/mlnode/algorithm/enums.py index 4b05aa4bf8..2def3751cd 100644 --- a/mlnode/iotdb/mlnode/algorithm/enums.py +++ b/mlnode/iotdb/mlnode/algorithm/enums.py @@ -27,3 +27,6 @@ class ForecastTaskType(Enum): def __eq__(self, other: str) -> bool: return self.value == other + + def __hash__(self) -> int: + return hash(self.value) diff --git a/mlnode/iotdb/mlnode/algorithm/factory.py b/mlnode/iotdb/mlnode/algorithm/factory.py index 92cb01a883..26eab10860 100644 --- a/mlnode/iotdb/mlnode/algorithm/factory.py +++ b/mlnode/iotdb/mlnode/algorithm/factory.py @@ -19,6 +19,7 @@ import torch.nn as nn from iotdb.mlnode.algorithm.enums import ForecastTaskType from iotdb.mlnode.algorithm.models.forecast import support_forecasting_models +from iotdb.mlnode.algorithm.models.forecast.dlinear import dlinear from iotdb.mlnode.exception import BadConfigValueError diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py index fa9ee04e56..966ea20347 100644 --- a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py +++ b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py @@ -21,6 +21,7 @@ import math import torch import torch.nn as nn +from iotdb.mlnode.algorithm.enums import ForecastTaskType from iotdb.mlnode.exception import BadConfigValueError @@ -65,7 +66,7 @@ class DLinear(nn.Module): pred_len=96, input_vars=1, output_vars=1, - forecast_type='m', # TODO, support others + forecast_task_type=ForecastTaskType.ENDOGENOUS, # TODO, support others ): super(DLinear, self).__init__() self.input_len = input_len diff --git a/mlnode/iotdb/mlnode/client.py b/mlnode/iotdb/mlnode/client.py index 6c1c549ea1..724d517316 100644 --- a/mlnode/iotdb/mlnode/client.py +++ b/mlnode/iotdb/mlnode/client.py @@ -21,7 +21,7 @@ from thrift.protocol import TBinaryProtocol, TCompactProtocol from thrift.Thrift import TException from thrift.transport import TSocket, TTransport -from iotdb.mlnode.config import config +from iotdb.mlnode.config import descriptor from iotdb.mlnode.constant import TSStatusCode from iotdb.mlnode.log import logger from iotdb.mlnode.util import verify_success @@ -29,7 +29,7 @@ from iotdb.thrift.common.ttypes import TEndPoint, TrainingState, TSStatus from iotdb.thrift.confignode import IConfigNodeRPCService from iotdb.thrift.confignode.ttypes import (TUpdateModelInfoReq, TUpdateModelStateReq) -from iotdb.thrift.datanode import IDataNodeRPCService +from iotdb.thrift.datanode import IMLNodeInternalRPCService from iotdb.thrift.datanode.ttypes import (TFetchTimeseriesReq, TFetchTimeseriesResp, TRecordModelMetricsReq) @@ -39,8 +39,8 @@ from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq, TDeleteModelReq class ClientManager(object): def __init__(self): - self.__data_node_endpoint = config.get_mn_target_data_node() - self.__config_node_endpoint = config.get_mn_target_config_node() + self.__data_node_endpoint = descriptor.get_config().get_mn_target_data_node() + self.__config_node_endpoint = descriptor.get_config().get_mn_target_config_node() def borrow_data_node_client(self): return DataNodeClient(host=self.__data_node_endpoint.ip, @@ -120,18 +120,14 @@ class DataNodeClient(object): raise e protocol = TBinaryProtocol.TBinaryProtocol(transport) - self.__client = IDataNodeRPCService.Client(protocol) + self.__client = IMLNodeInternalRPCService.Client(protocol) def fetch_timeseries(self, - session_id: int, - statement_id: int, query_expressions: list = [], query_filter: str = None, fetch_size: int = DEFAULT_FETCH_SIZE, timeout: int = DEFAULT_TIMEOUT) -> TFetchTimeseriesResp: req = TFetchTimeseriesReq( - sessionId=session_id, - statementId=statement_id, queryExpressions=query_expressions, queryFilter=query_filter, fetchSize=fetch_size, @@ -147,8 +143,8 @@ class DataNodeClient(object): def record_model_metrics(self, model_id: str, trial_id: str, - metrics: list = [], - values: list = []) -> None: + metrics: list, + values: list) -> None: req = TRecordModelMetricsReq( modelId=model_id, trialId=trial_id, @@ -186,6 +182,7 @@ class ConfigNodeClient(object): if self.__config_leader is not None: try: self.__connect(self.__config_leader) + return except TException: logger.warn("The current node {} may have been down, try next node", self.__config_leader) self.__config_leader = None @@ -200,6 +197,7 @@ class ConfigNodeClient(object): try_endpoint = self.__config_nodes[self.__cursor] try: self.__connect(try_endpoint) + return except TException: logger.warn("The current node {} may have been down, try next node", try_endpoint) @@ -217,7 +215,7 @@ class ConfigNodeClient(object): except TTransport.TTransportException as e: logger.exception("TTransportException!", exc_info=e) - protocol = TCompactProtocol.TBinaryProtocol(transport) + protocol = TBinaryProtocol.TBinaryProtocol(transport) self.__client = IConfigNodeRPCService.Client(protocol) def __wait_and_reconnect(self) -> None: @@ -246,12 +244,12 @@ class ConfigNodeClient(object): def update_model_state(self, model_id: str, - trial_id: str, - training_state: TrainingState) -> None: + training_state: TrainingState, + best_trail_id: str = None) -> None: req = TUpdateModelStateReq( modelId=model_id, - trialId=trial_id, - trainingState=training_state + state=training_state, + bestTrailId=best_trail_id ) for i in range(0, self.__RETRY_NUM): try: @@ -275,7 +273,7 @@ class ConfigNodeClient(object): model_info = {} req = TUpdateModelInfoReq( modelId=model_id, - trialId=trial_id, + trailId=trial_id, modelInfo={k: str(v) for k, v in model_info.items()}, ) diff --git a/mlnode/iotdb/mlnode/config.py b/mlnode/iotdb/mlnode/config.py index e59338209a..109452eab5 100644 --- a/mlnode/iotdb/mlnode/config.py +++ b/mlnode/iotdb/mlnode/config.py @@ -44,7 +44,7 @@ class MLNodeConfig(object): self.__mn_target_config_node: TEndPoint = TEndPoint("127.0.0.1", 10710) # Target DataNode to be connected by MLNode - self.__mn_target_data_node: TEndPoint = TEndPoint("127.0.0.1", 10730) + self.__mn_target_data_node: TEndPoint = TEndPoint("127.0.0.1", 10780) def get_mn_rpc_address(self) -> str: return self.__mn_rpc_address @@ -86,9 +86,8 @@ class MLNodeConfig(object): class MLNodeDescriptor(object): def __init__(self): self.__config = MLNodeConfig() - self.__load_config_from_file() - def __load_config_from_file(self) -> None: + def load_config_from_file(self) -> None: conf_file = os.path.join(os.getcwd(), MLNODE_CONF_DIRECTORY_NAME, MLNODE_CONF_FILE_NAME) if not os.path.exists(conf_file): logger.info("Cannot find MLNode config file '{}', use default configuration.".format(conf_file)) @@ -113,7 +112,7 @@ class MLNodeDescriptor(object): self.__config.set_mn_model_storage_dir(file_configs.mn_model_storage_dir) if file_configs.mn_model_storage_cache_size is not None: - self.__config.set_mn_model_storage_cachesize(file_configs.mn_model_storage_cache_size) + self.__config.set_mn_model_storage_cache_size(file_configs.mn_model_storage_cache_size) if file_configs.mn_target_config_node is not None: self.__config.set_mn_target_config_node(file_configs.mn_target_config_node) @@ -129,4 +128,5 @@ class MLNodeDescriptor(object): return self.__config -config = MLNodeDescriptor().get_config() +# initialize a singleton +descriptor = MLNodeDescriptor() diff --git a/mlnode/iotdb/mlnode/constant.py b/mlnode/iotdb/mlnode/constant.py index e0be2a7b63..3bffa06526 100644 --- a/mlnode/iotdb/mlnode/constant.py +++ b/mlnode/iotdb/mlnode/constant.py @@ -31,9 +31,3 @@ class TSStatusCode(Enum): def get_status_code(self) -> int: return self.value - - -class ModelState(Enum): - RUNNING = 'running' - FINISHED = 'finished' - FAILED = 'failed' diff --git a/mlnode/iotdb/mlnode/data_access/enums.py b/mlnode/iotdb/mlnode/data_access/enums.py index d21a9f69c4..e7f5417b3d 100644 --- a/mlnode/iotdb/mlnode/data_access/enums.py +++ b/mlnode/iotdb/mlnode/data_access/enums.py @@ -27,3 +27,6 @@ class DatasetType(Enum): def __eq__(self, other: str) -> bool: return self.value == other + + def __hash__(self) -> int: + return hash(self.value) diff --git a/mlnode/iotdb/mlnode/data_access/offline/source.py b/mlnode/iotdb/mlnode/data_access/offline/source.py index a63371ec7a..0422bb373d 100644 --- a/mlnode/iotdb/mlnode/data_access/offline/source.py +++ b/mlnode/iotdb/mlnode/data_access/offline/source.py @@ -74,8 +74,8 @@ class ThriftDataSource(DataSource): try: res = data_client.fetch_timeseries( - queryExpressions=self.query_expressions, - queryFilter=self.query_filter, + query_expressions=self.query_expressions, + query_filter=self.query_filter, ) except Exception: raise RuntimeError(f'Fail to fetch data with query expressions: {self.query_expressions}' diff --git a/mlnode/iotdb/mlnode/handler.py b/mlnode/iotdb/mlnode/handler.py index e43f26c226..1a6e3eb90a 100644 --- a/mlnode/iotdb/mlnode/handler.py +++ b/mlnode/iotdb/mlnode/handler.py @@ -19,7 +19,6 @@ from iotdb.mlnode.algorithm.factory import create_forecast_model from iotdb.mlnode.constant import TSStatusCode from iotdb.mlnode.data_access.factory import create_forecast_dataset -from iotdb.mlnode.log import logger from iotdb.mlnode.parser import parse_training_request from iotdb.mlnode.process.manager import TaskManager from iotdb.mlnode.util import get_status @@ -37,29 +36,26 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface): return get_status(TSStatusCode.SUCCESS_STATUS, "") def createTrainingTask(self, req: TCreateTrainingTaskReq): - # parse request stage (check required config and config type) - data_config, model_config, task_config = parse_training_request(req) - - # create model stage (check model config legitimacy) + task = None try: + # parse request, check required config and config type + data_config, model_config, task_config = parse_training_request(req) + + # create model & check model config legitimacy model, model_config = create_forecast_model(**model_config) - except Exception as e: # Create model failed - return get_status(TSStatusCode.FAIL_STATUS, str(e)) - logger.info('model config: ' + str(model_config)) - # create data stage (check data config legitimacy) - try: + # create dataset & check data config legitimacy dataset, data_config = create_forecast_dataset(**data_config) - except Exception as e: # Create data failed - return get_status(TSStatusCode.FAIL_STATUS, str(e)) - logger.info('data config: ' + str(data_config)) - - # create task stage (check task config legitimacy) - # submit task stage (check resource and decide pending/start) - self.__task_manager.submit_training_task(task_config, model_config, model, dataset) + # create task & check task config legitimacy + task = self.__task_manager.create_training_task(dataset, model, model_config, task_config) - return get_status(TSStatusCode.SUCCESS_STATUS, 'Successfully create training task') + return get_status(TSStatusCode.SUCCESS_STATUS, 'Successfully create training task') + except Exception as e: + return get_status(TSStatusCode.FAIL_STATUS, str(e)) + finally: + # submit task stage & check resource and decide pending/start + self.__task_manager.submit_training_task(task) def forecast(self, req: TForecastReq): status = get_status(TSStatusCode.SUCCESS_STATUS, "") diff --git a/mlnode/iotdb/mlnode/parser.py b/mlnode/iotdb/mlnode/parser.py index 236032b9a0..c052cd5050 100644 --- a/mlnode/iotdb/mlnode/parser.py +++ b/mlnode/iotdb/mlnode/parser.py @@ -91,8 +91,9 @@ class _ConfigParser(argparse.ArgumentParser): - output_vars: number of output variables """ _data_config_parser = _ConfigParser() -_data_config_parser.add_argument('--source_type', type=str, required=True) -_data_config_parser.add_argument('--dataset_type', type=DatasetType, required=True) +_data_config_parser.add_argument('--source_type', type=str, default="thrift") +_data_config_parser.add_argument('--dataset_type', type=DatasetType, default=DatasetType.WINDOW, + choices=list(DatasetType)) _data_config_parser.add_argument('--filename', type=str, default='') _data_config_parser.add_argument('--query_expressions', type=str, nargs='*', default=[]) _data_config_parser.add_argument('--query_filter', type=str, default='') @@ -183,6 +184,8 @@ def parse_training_request(req: TCreateTrainingTaskReq): task_config: configurations related to task """ config = req.modelConfigs + config.update(model_name=config['model_type']) + config.update(task_class=config['model_task']) config.update(model_id=req.modelId) config.update(tuning=req.isAuto) config.update(query_expressions=req.queryExpressions) diff --git a/mlnode/iotdb/mlnode/process/manager.py b/mlnode/iotdb/mlnode/process/manager.py index bfb035f27b..0af0353973 100644 --- a/mlnode/iotdb/mlnode/process/manager.py +++ b/mlnode/iotdb/mlnode/process/manager.py @@ -18,7 +18,11 @@ import multiprocessing as mp +from torch import nn +from torch.utils.data import Dataset + from iotdb.mlnode.log import logger +from iotdb.mlnode.process.task import ForecastingTrainingTask from iotdb.mlnode.process.task_factory import create_task @@ -33,22 +37,22 @@ class TaskManager(object): self.__pid_info = self.__shared_resource_manager.dict() self.__training_process_pool = mp.Pool(pool_num) - def submit_training_task(self, task_configs, model_configs, model, dataset): - assert 'model_id' in task_configs.keys(), 'Task config should contain model_id' + def create_training_task(self, + dataset: Dataset, + model: nn.Module, + model_configs: dict, + task_configs: dict) -> ForecastingTrainingTask: model_id = task_configs['model_id'] self.__pid_info[model_id] = self.__shared_resource_manager.dict() - try: - task = create_task( - task_configs, - model_configs, - model, - dataset, - self.__pid_info - ) - except Exception as e: - logger.exception(e) - return e, False + return create_task( + task_configs, + model_configs, + model, + dataset, + self.__pid_info + ) - logger.info(f'Task: ({model_id}) - Training process submitted successfully') - self.__training_process_pool.apply_async(task, args=()) - return model_id, True + def submit_training_task(self, task: ForecastingTrainingTask) -> None: + if task is not None: + self.__training_process_pool.apply_async(task, args=()) + logger.info(f'Task: ({task.model_id}) - Training process submitted successfully') diff --git a/mlnode/iotdb/mlnode/process/task.py b/mlnode/iotdb/mlnode/process/task.py index 7fac9cb1c5..85d5b5d2cf 100644 --- a/mlnode/iotdb/mlnode/process/task.py +++ b/mlnode/iotdb/mlnode/process/task.py @@ -75,7 +75,7 @@ class _BasicTask(object): class ForecastingTrainingTask(_BasicTask): def __init__(self, task_configs, model_configs, model, dataset, task_trial_map): super(ForecastingTrainingTask, self).__init__(task_configs, model_configs, model, dataset, task_trial_map) - model_id = self.task_configs['model_id'] + self.model_id = self.task_configs['model_id'] self.tuning = self.task_configs["tuning"] if self.tuning: # TODO implement tuning task @@ -83,7 +83,7 @@ class ForecastingTrainingTask(_BasicTask): else: self.task_configs['trial_id'] = 'tid_0' # TODO: set a default trial id self.trial = ForecastingTrainingTrial(self.task_configs, self.model, self.model_configs, self.dataset) - self.task_trial_map[model_id]['tid_0'] = os.getpid() + self.task_trial_map[self.model_id]['tid_0'] = os.getpid() def __call__(self): try: diff --git a/mlnode/iotdb/mlnode/process/task_factory.py b/mlnode/iotdb/mlnode/process/task_factory.py index 7b9966a8f3..083b84eba2 100644 --- a/mlnode/iotdb/mlnode/process/task_factory.py +++ b/mlnode/iotdb/mlnode/process/task_factory.py @@ -20,7 +20,7 @@ from iotdb.mlnode.process.task import ForecastingTrainingTask support_task_types = { - 'forecast_training_task': ForecastingTrainingTask + 'forecast': ForecastingTrainingTask } diff --git a/mlnode/iotdb/mlnode/process/trial.py b/mlnode/iotdb/mlnode/process/trial.py index f8671b4657..9852e3ffb4 100644 --- a/mlnode/iotdb/mlnode/process/trial.py +++ b/mlnode/iotdb/mlnode/process/trial.py @@ -23,11 +23,11 @@ import torch import torch.nn as nn from torch.utils.data import DataLoader, Dataset -from iotdb.mlnode.algorithm.metric import all_metrics +from iotdb.mlnode.algorithm.metric import MAE, MSE, all_metrics from iotdb.mlnode.client import client_manager from iotdb.mlnode.log import logger from iotdb.mlnode.storage import model_storage -from iotdb.mlnode.constant import ModelState +from iotdb.thrift.common.ttypes import TrainingState def _parse_trial_config(**kwargs): @@ -188,8 +188,8 @@ class ForecastingTrainingTrial(BasicTrial): val_loss.append(loss.item()) for name in self.metric_names: - value = eval(name)(outputs.detach().cpu().numpy(), - batch_y.detach().cpu().numpy()) + metric = eval(name)() + value = metric(outputs.detach().cpu().numpy(), batch_y.detach().cpu().numpy()) metrics_dict[name].append(value) for name, value_list in metrics_dict.items(): @@ -207,25 +207,32 @@ class ForecastingTrainingTrial(BasicTrial): return val_loss, metrics_dict def start(self) -> float: - self.confignode_client.update_model_state(self.model_id, self.trial_id, ModelState.RUNNING) - best_loss = np.inf - best_metrics_dict = None - for epoch in range(self.epochs): - self._train(epoch) - val_loss, metrics_dict = self._validate(epoch) - if val_loss < best_loss: - best_loss = val_loss - best_metrics_dict = metrics_dict - model_storage.save_model(self.model, - self.model_configs, - model_id=self.model_id, - trial_id=self.trial_id) - - logger.info(f'Trial: ({self.model_id}_{self.trial_id}) - Finished with best model saved successfully') - - self.confignode_client.update_model_state(self.model_id, self.trial_id, ModelState.RUNNING) - model_info = {} - model_info.update(best_metrics_dict) - model_info.update(self.trial_configs) - self.confignode_client.update_model_info(self.model_id, self.trial_id, model_info) - return best_loss + try: + self.confignode_client.update_model_state(self.model_id, TrainingState.RUNNING) + best_loss = np.inf + best_metrics_dict = None + model_path = None + for epoch in range(self.epochs): + self._train(epoch) + val_loss, metrics_dict = self._validate(epoch) + if val_loss < best_loss: + best_loss = val_loss + best_metrics_dict = metrics_dict + model_path = model_storage.save_model(self.model, + self.model_configs, + model_id=self.model_id, + trial_id=self.trial_id) + + logger.info(f'Trial: ({self.model_id}_{self.trial_id}) - Finished with best model saved successfully') + + model_info = {} + model_info.update(best_metrics_dict) + model_info.update(self.trial_configs) + model_info['model_path'] = model_path + self.confignode_client.update_model_info(self.model_id, self.trial_id, model_info) + self.confignode_client.update_model_state(self.model_id, TrainingState.FINISHED, self.trial_id) + return best_loss + except Exception as e: + logger.warn(e) + self.confignode_client.update_model_state(self.model_id, TrainingState.FAILED) + raise e diff --git a/mlnode/iotdb/mlnode/service.py b/mlnode/iotdb/mlnode/service.py index a2c05ea5c3..ae0727cc5a 100644 --- a/mlnode/iotdb/mlnode/service.py +++ b/mlnode/iotdb/mlnode/service.py @@ -19,10 +19,10 @@ import threading import time from thrift.protocol import TCompactProtocol -from thrift.server import TServer +from thrift.server import TProcessPoolServer from thrift.transport import TSocket, TTransport -from iotdb.mlnode.config import config +from iotdb.mlnode.config import descriptor from iotdb.mlnode.handler import MLNodeRPCServiceHandler from iotdb.mlnode.log import logger from iotdb.thrift.mlnode import IMLNodeRPCService @@ -32,11 +32,13 @@ class RPCService(threading.Thread): def __init__(self): super().__init__() processor = IMLNodeRPCService.Processor(handler=MLNodeRPCServiceHandler()) - transport = TSocket.TServerSocket(host=config.get_mn_rpc_address(), port=config.get_mn_rpc_port()) + transport = TSocket.TServerSocket(host=descriptor.get_config().get_mn_rpc_address(), + port=descriptor.get_config().get_mn_rpc_port()) transport_factory = TTransport.TFramedTransportFactory() protocol_factory = TCompactProtocol.TCompactProtocolFactory() - self.__pool_server = TServer.TThreadPoolServer(processor, transport, transport_factory, protocol_factory) + self.__pool_server = TProcessPoolServer.TProcessPoolServer(processor, transport, transport_factory, + protocol_factory) def run(self) -> None: logger.info("The RPC service thread begin to run...") @@ -45,6 +47,7 @@ class RPCService(threading.Thread): class MLNode(object): def __init__(self): + descriptor.load_config_from_file() self.__rpc_service = RPCService() def start(self) -> None: diff --git a/mlnode/iotdb/mlnode/storage.py b/mlnode/iotdb/mlnode/storage.py index ee745689b1..78a0be43bf 100644 --- a/mlnode/iotdb/mlnode/storage.py +++ b/mlnode/iotdb/mlnode/storage.py @@ -24,35 +24,36 @@ import torch import torch.nn as nn from pylru import lrucache -from iotdb.mlnode.config import config +from iotdb.mlnode.config import descriptor from iotdb.mlnode.exception import ModelNotExistError class ModelStorage(object): def __init__(self): - self.__model_dir = os.path.join(os.getcwd(), config.get_mn_model_storage_dir()) + self.__model_dir = os.path.join('.', descriptor.get_config().get_mn_model_storage_dir()) if not os.path.exists(self.__model_dir): os.mkdir(self.__model_dir) - self.__model_cache = lrucache(config.get_mn_model_storage_cache_size()) + self.__model_cache = lrucache(descriptor.get_config().get_mn_model_storage_cache_size()) def save_model(self, model: nn.Module, model_config: dict, model_id: str, - trial_id: str) -> None: + trial_id: str) -> str: """ Note: model config for time series should contain 'input_len' and 'input_vars' """ model_dir_path = os.path.join(self.__model_dir, f'{model_id}') if not os.path.exists(model_dir_path): - os.mkdir(model_dir_path) + os.makedirs(model_dir_path) model_file_path = os.path.join(model_dir_path, f'{trial_id}.pt') sample_input = [torch.randn(1, model_config['input_len'], model_config['input_vars'])] torch.jit.save(torch.jit.trace(model, sample_input), model_file_path, _extra_files={'model_config': json.dumps(model_config)}) + return os.path.abspath(model_file_path) def load_model(self, model_id: str, trial_id: str) -> (torch.jit.ScriptModule, dict): """ diff --git a/mlnode/iotdb/mlnode/util.py b/mlnode/iotdb/mlnode/util.py index e451d2b25a..5d3a2d670e 100644 --- a/mlnode/iotdb/mlnode/util.py +++ b/mlnode/iotdb/mlnode/util.py @@ -52,6 +52,6 @@ def get_status(status_code: TSStatusCode, message: str) -> TSStatus: def verify_success(status: TSStatus, err_msg: str) -> None: - if status.code != TSStatusCode.SUCCESS_STATUS: + if status.code != TSStatusCode.SUCCESS_STATUS.get_status_code(): logger.warn(err_msg + ", error status is ", status) raise RuntimeError(str(status.code) + ": " + status.message) diff --git a/mlnode/pyproject.toml b/mlnode/pyproject.toml index 3944e2910d..56290f8d4e 100644 --- a/mlnode/pyproject.toml +++ b/mlnode/pyproject.toml @@ -49,6 +49,7 @@ packages = [ python = "^3.7" thrift = "^0.13.0" dynaconf = "^3.1.11" +pylru = "^1.2.1" [tool.poetry.scripts] mlnode = "iotdb.mlnode.script:main" \ No newline at end of file diff --git a/mlnode/requirements.txt b/mlnode/requirements.txt index edd85701ab..c49c8a0189 100644 --- a/mlnode/requirements.txt +++ b/mlnode/requirements.txt @@ -20,7 +20,7 @@ pandas>=1.3.5 numpy>=1.21.4 apache-iotdb poetry -torch +torch~=2.0.0 pylru thrift~=0.13.0
