This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch research/MLEngine
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/research/MLEngine by this push:
     new 96301b8fda0 refactor mlnode (tmp save)
96301b8fda0 is described below

commit 96301b8fda08d028720c0e939c240c5b179be0ae
Author: Minghui Liu <[email protected]>
AuthorDate: Mon Jun 19 10:09:15 2023 +0800

    refactor mlnode (tmp save)
---
 .../thrift-mlnode/src/main/thrift/mlnode.thrift    |  7 +-
 mlnode/iotdb/mlnode/algorithm/factory.py           | 36 +++-----
 mlnode/iotdb/mlnode/algorithm/hyperparameter.py    | 88 +++++++++++++++++++
 .../mlnode/algorithm/models/forecast/dlinear.py    | 14 ++++
 .../mlnode/{constant.py => algorithm/validator.py} | 33 +++++---
 mlnode/iotdb/mlnode/constant.py                    | 20 +++++
 mlnode/iotdb/mlnode/data_access/factory.py         | 16 ++--
 mlnode/iotdb/mlnode/exception.py                   |  6 ++
 mlnode/iotdb/mlnode/handler.py                     | 27 +++---
 mlnode/iotdb/mlnode/parser.py                      | 56 ++++++++++++-
 mlnode/iotdb/mlnode/process/manager.py             | 56 ++++++-------
 mlnode/iotdb/mlnode/process/task.py                | 98 +++++++++-------------
 mlnode/iotdb/mlnode/process/trial.py               | 14 ++--
 mlnode/pom.xml                                     |  8 +-
 14 files changed, 322 insertions(+), 157 deletions(-)

diff --git a/iotdb-protocol/thrift-mlnode/src/main/thrift/mlnode.thrift 
b/iotdb-protocol/thrift-mlnode/src/main/thrift/mlnode.thrift
index 46f7b025f47..5faa939a903 100644
--- a/iotdb-protocol/thrift-mlnode/src/main/thrift/mlnode.thrift
+++ b/iotdb-protocol/thrift-mlnode/src/main/thrift/mlnode.thrift
@@ -23,10 +23,9 @@ namespace py iotdb.thrift.mlnode
 
 struct TCreateTrainingTaskReq {
   1: required string modelId
-  2: required bool isAuto
-  3: required map<string, string> modelConfigs
-  4: required list<string> queryExpressions
-  5: optional string queryFilter
+  2: required map<string, string> options
+  3: required map<string, string> hyperparameters
+  4: required string queryBody
 }
 
 struct TDeleteModelReq {
diff --git a/mlnode/iotdb/mlnode/algorithm/factory.py 
b/mlnode/iotdb/mlnode/algorithm/factory.py
index ed39e4e2b8c..70a11da55b7 100644
--- a/mlnode/iotdb/mlnode/algorithm/factory.py
+++ b/mlnode/iotdb/mlnode/algorithm/factory.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-from typing import Dict, Tuple
+from typing import Dict
 
 import torch.nn as nn
 
@@ -24,9 +24,10 @@ from iotdb.mlnode.algorithm.models.forecast.dlinear import 
(dlinear,
                                                             dlinear_individual)
 from iotdb.mlnode.algorithm.models.forecast.nbeats import nbeats
 from iotdb.mlnode.exception import BadConfigValueError
+# Common configs for all forecasting model with default values
+from iotdb.mlnode.parser import TaskOptions
 
 
-# Common configs for all forecasting model with default values
 def _common_config(**kwargs):
     return {
         'input_len': 96,
@@ -51,14 +52,9 @@ _forecasting_model_default_config_dict = {
 
 
 def create_forecast_model(
-        model_name,
-        input_len=96,
-        pred_len=96,
-        input_vars=1,
-        output_vars=1,
-        forecast_task_type=ForecastTaskType.ENDOGENOUS,
-        **kwargs,
-) -> Tuple[nn.Module, Dict]:
+        task_options: TaskOptions,
+        model_configs: Dict,
+) -> nn.Module:
     """
     Factory method for all support forecasting models
     the given arguments is common configs shared by all forecasting models
@@ -83,13 +79,6 @@ def create_forecast_model(
         raise BadConfigValueError('forecast_task_type', forecast_task_type,
                                   f'It should be one of 
{list(_forecasting_model_default_config_dict.keys())}')
 
-    common_config = _forecasting_model_default_config_dict[forecast_task_type]
-    common_config['input_len'] = input_len
-    common_config['pred_len'] = pred_len
-    common_config['input_vars'] = input_vars
-    common_config['output_vars'] = output_vars
-    common_config['forecast_task_type'] = str(forecast_task_type)
-
     if not input_len > 0:
         raise BadConfigValueError('input_len', input_len,
                                   'Length of input series should be positive')
@@ -108,23 +97,24 @@ def create_forecast_model(
                                       'Number of input/output variables should 
be '
                                       'the same in endogenous forecast')
 
-    if model_name == ForecastModelType.DLINEAR.value:
+    if task_options.model_type == ForecastModelType.DLINEAR.value:
         model, model_config = dlinear(
             common_config=common_config,
             **kwargs
         )
-    elif model_name == ForecastModelType.DLINEAR_INDIVIDUAL.value:
+    elif task_options.model_type == ForecastModelType.DLINEAR_INDIVIDUAL.value:
         model, model_config = dlinear_individual(
             common_config=common_config,
             **kwargs
         )
-    elif model_name == ForecastModelType.NBEATS.value:
+    elif task_options.model_type == ForecastModelType.NBEATS.value:
         model, model_config = nbeats(
             common_config=common_config,
             **kwargs
         )
     else:
-        raise BadConfigValueError('model_name', model_name, f'It should be one 
of {ForecastModelType.values()}')
+        raise BadConfigValueError('model_name', task_options.model_type,
+                                  f'It should be one of 
{ForecastModelType.values()}')
 
-    model_config['model_name'] = model_name
-    return model, model_config
+    return model
diff --git a/mlnode/iotdb/mlnode/algorithm/hyperparameter.py 
b/mlnode/iotdb/mlnode/algorithm/hyperparameter.py
new file mode 100644
index 00000000000..7b10dc7b9e6
--- /dev/null
+++ b/mlnode/iotdb/mlnode/algorithm/hyperparameter.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from enum import Enum
+from typing import Optional, List, Dict, Tuple
+
+import optuna
+
+from iotdb.mlnode.algorithm.enums import ForecastModelType
+from iotdb.mlnode.algorithm.models.forecast.dlinear import 
dlinear_hyperparameter_map
+from iotdb.mlnode.algorithm.validator import Validator
+from iotdb.mlnode.parser import TaskOptions
+
+
+class Hyperparameter(object):
+
+    def __init__(self, name: str, log: bool):
+        """
+        Args:
+            name: name of the hyperparameter
+        """
+        self.__name = name
+        self.__log = log
+
+
+class FloatHyperparameter(Hyperparameter):
+    def __init__(self, name: str,
+                 log: bool,
+                 default_value: float,
+                 value_validators: List[Validator],
+                 default_low: float,
+                 low_validators: List[Validator],
+                 default_high: float,
+                 high_validators: List[Validator],
+                 step: Optional[float] = None):
+        super(FloatHyperparameter, self).__init__(name, log)
+        self.__default_value = default_value
+        self.__value_validators = value_validators
+        self.__default_low = default_low
+        self.__low_validators = low_validators
+        self.__default_high = default_high
+        self.__high_validators = high_validators
+        self.__step = step
+
+
+class HyperparameterName(Enum):
+    LEARNING_RATE = "learning_rate"
+
+    def name(self):
+        return self.value
+
+
+def get_model_hyperparameter_map(model_type: ForecastModelType):
+    if model_type == ForecastModelType.DLINEAR:
+        return dlinear_hyperparameter_map
+    else:
+        raise NotImplementedError(f"Model type {model_type} is not supported 
yet.")
+
+
+def parse_fixed_hyperparameters(
+        task_options: TaskOptions,
+        hyperparameters: Dict[str, str]
+) -> Tuple[Dict, Dict]:
+    hyperparameter_map = get_model_hyperparameter_map(task_options.model_type)
+    return None, None
+
+
+def generate_hyperparameters(
+        optuna_suggest: optuna.Trial,
+        task_options: TaskOptions,
+        hyperparameters: Dict[str, str]
+) -> Tuple[Dict, Dict]:
+    hyperparameter_map = get_model_hyperparameter_map(task_options.model_type)
+    return None, None
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py 
b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
index 35ba728fa78..b839f811733 100644
--- a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
+++ b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
@@ -23,6 +23,8 @@ import torch
 import torch.nn as nn
 
 from iotdb.mlnode.algorithm.enums import ForecastTaskType
+from iotdb.mlnode.algorithm.hyperparameter import FloatHyperparameter, 
HyperparameterName
+from iotdb.mlnode.algorithm.validator import FloatRangeValidator
 from iotdb.mlnode.exception import BadConfigValueError
 
 
@@ -143,6 +145,18 @@ def _model_config(**kwargs):
     }
 
 
+dlinear_hyperparameter_map = {
+    HyperparameterName.LEARNING_RATE: 
FloatHyperparameter(name=HyperparameterName.LEARNING_RATE.name(),
+                                                          log=True,
+                                                          default_value=1e-3,
+                                                          
value_validators=[FloatRangeValidator(1, 10)],
+                                                          default_low=1e-5,
+                                                          low_validators=[],
+                                                          default_high=1e-1,
+                                                          high_validators=[]),
+}
+
+
 def dlinear(common_config: Dict, kernel_size=25, **kwargs) -> Tuple[DLinear, 
Dict]:
     config = _model_config()
     config.update(**common_config)
diff --git a/mlnode/iotdb/mlnode/constant.py 
b/mlnode/iotdb/mlnode/algorithm/validator.py
similarity index 52%
copy from mlnode/iotdb/mlnode/constant.py
copy to mlnode/iotdb/mlnode/algorithm/validator.py
index 68240af12a4..7c26e15342b 100644
--- a/mlnode/iotdb/mlnode/constant.py
+++ b/mlnode/iotdb/mlnode/algorithm/validator.py
@@ -15,19 +15,30 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-from enum import Enum
+from abc import abstractmethod
 
-MLNODE_CONF_DIRECTORY_NAME = "conf"
-MLNODE_CONF_FILE_NAME = "iotdb-mlnode.toml"
-MLNODE_LOG_CONF_FILE_NAME = "logging_config.ini"
 
-MLNODE_MODEL_STORAGE_DIRECTORY_NAME = "models"
+class Validator(object):
+    @abstractmethod
+    def validate(self, value):
+        """
+        Checks whether the given value is valid.
 
+        Parameters:
+        - value: The value to validate
 
-class TSStatusCode(Enum):
-    SUCCESS_STATUS = 200
-    REDIRECTION_RECOMMEND = 400
-    MLNODE_INTERNAL_ERROR = 1510
+        Returns:
+        - True if the value is valid, False otherwise.
+        """
+        raise NotImplementedError("Subclasses must implement the validate() 
method.")
 
-    def get_status_code(self) -> int:
-        return self.value
+
+class FloatRangeValidator(Validator):
+    def __init__(self, min_value, max_value):
+        self.min_value = min_value
+        self.max_value = max_value
+
+    def validate(self, value):
+        if isinstance(value, float) and self.min_value <= value <= 
self.max_value:
+            return True
+        return False
diff --git a/mlnode/iotdb/mlnode/constant.py b/mlnode/iotdb/mlnode/constant.py
index 68240af12a4..7292856a5e9 100644
--- a/mlnode/iotdb/mlnode/constant.py
+++ b/mlnode/iotdb/mlnode/constant.py
@@ -23,6 +23,8 @@ MLNODE_LOG_CONF_FILE_NAME = "logging_config.ini"
 
 MLNODE_MODEL_STORAGE_DIRECTORY_NAME = "models"
 
+DEFAULT_TRIAL_ID = "__trial_0"
+
 
 class TSStatusCode(Enum):
     SUCCESS_STATUS = 200
@@ -31,3 +33,21 @@ class TSStatusCode(Enum):
 
     def get_status_code(self) -> int:
         return self.value
+
+
+class TaskType(Enum):
+    FORECAST = "forecast"
+
+
+class OptionsKey(Enum):
+    # common
+    TASK_TYPE = "task_type"
+    MODEL_TYPE = "model_type"
+    AUTO_TUNING = "auto_tuning"
+
+    # forecast
+    INPUT_LENGTH = "input_length"
+    PREDICT_LENGTH = "predict_length"
+
+    def name(self) -> str:
+        return self.value
diff --git a/mlnode/iotdb/mlnode/data_access/factory.py 
b/mlnode/iotdb/mlnode/data_access/factory.py
index ee8d57ffb45..9850bc45ef8 100644
--- a/mlnode/iotdb/mlnode/data_access/factory.py
+++ b/mlnode/iotdb/mlnode/data_access/factory.py
@@ -19,12 +19,14 @@ from typing import Dict, Tuple
 
 from torch.utils.data import Dataset
 
+from iotdb.mlnode.constant import TaskType
 from iotdb.mlnode.data_access.enums import DatasetType, DataSourceType
 from iotdb.mlnode.data_access.offline.dataset import (TimeSeriesDataset,
                                                       WindowDataset)
 from iotdb.mlnode.data_access.offline.source import (FileDataSource,
                                                      ThriftDataSource)
 from iotdb.mlnode.exception import BadConfigValueError, MissingConfigError
+from iotdb.mlnode.parser import TaskOptions
 
 
 def _dataset_common_config(**kwargs):
@@ -43,11 +45,15 @@ _dataset_default_config_dict = {
 }
 
 
-def create_forecast_dataset(
-        source_type,
-        dataset_type,
-        **kwargs,
-) -> Tuple[Dataset, Dict]:
+def create_dataset(query_body: str, task_options: TaskOptions) -> Dataset:
+    task_type = task_options.get_task_type()
+    if task_type == TaskType.FORECAST:
+        return create_forecast_dataset(query_body, task_options)
+    else:
+        raise Exception(f"task type {task_type} not supported.")
+
+
+def create_forecast_dataset(query_body: str, task_options: TaskOptions) -> 
Dataset:
     """
     Factory method for all support dataset
     currently implement two types of PyTorch dataset: WindowDataset, 
TimeSeriesDataset
diff --git a/mlnode/iotdb/mlnode/exception.py b/mlnode/iotdb/mlnode/exception.py
index 47edb95eb43..f679417544c 100644
--- a/mlnode/iotdb/mlnode/exception.py
+++ b/mlnode/iotdb/mlnode/exception.py
@@ -18,6 +18,7 @@
 
 class _BaseError(Exception):
     """Base class for exceptions in this module."""
+
     def __init__(self):
         self.message = None
 
@@ -45,6 +46,11 @@ class MissingConfigError(_BaseError):
         self.message = "Missing config: {}".format(config_name)
 
 
+class MissingOptionError(_BaseError):
+    def __init__(self, config_name: str):
+        self.message = "Missing task option: {}".format(config_name)
+
+
 class WrongTypeConfigError(_BaseError):
     def __init__(self, config_name: str, expected_type: str):
         self.message = "Wrong type for config: {0}, expected: 
{1}".format(config_name, expected_type)
diff --git a/mlnode/iotdb/mlnode/handler.py b/mlnode/iotdb/mlnode/handler.py
index d4b4ca3b034..c42c05c8100 100644
--- a/mlnode/iotdb/mlnode/handler.py
+++ b/mlnode/iotdb/mlnode/handler.py
@@ -18,9 +18,9 @@
 
 from iotdb.mlnode.config import descriptor
 from iotdb.mlnode.constant import TSStatusCode
-from iotdb.mlnode.data_access.factory import create_forecast_dataset
+from iotdb.mlnode.data_access.factory import create_dataset
 from iotdb.mlnode.log import logger
-from iotdb.mlnode.parser import parse_forecast_request, parse_training_request
+from iotdb.mlnode.parser import parse_forecast_request, parse_task_options
 from iotdb.mlnode.process.manager import TaskManager
 from iotdb.mlnode.serde import convert_to_binary
 from iotdb.mlnode.storage import model_storage
@@ -46,24 +46,25 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
     def createTrainingTask(self, req: TCreateTrainingTaskReq):
         task = None
         try:
-            # parse request, check required config and config type
-            data_config, model_config, task_config = 
parse_training_request(req)
-            # create dataset & check data config legitimacy
-            dataset, data_config = create_forecast_dataset(**data_config)
+            # parse options
+            task_options = parse_task_options(req.options)
 
-            model_config['input_vars'] = data_config['input_vars']
-            model_config['output_vars'] = data_config['output_vars']
-
-            # create task & check task config legitimacy
-            task = self.__task_manager.create_training_task(dataset, 
data_config, model_config, task_config)
+            # create task
+            task = self.__task_manager.create_forecast_training_task(
+                model_id=req.modelId,
+                task_options=task_options,
+                hyperparameters=req.hyperparameters,
+                dataset=create_dataset(req.queryBody, task_options)
+            )
 
             return get_status(TSStatusCode.SUCCESS_STATUS)
         except Exception as e:
             logger.warn(e)
             return get_status(TSStatusCode.MLNODE_INTERNAL_ERROR, str(e))
         finally:
-            # submit task stage & check resource and decide pending/start
-            self.__task_manager.submit_training_task(task)
+            if task is not None:
+                # submit task to process pool
+                self.__task_manager.submit_training_task(task)
 
     def forecast(self, req: TForecastReq):
         model_path, data, pred_length = parse_forecast_request(req)
diff --git a/mlnode/iotdb/mlnode/parser.py b/mlnode/iotdb/mlnode/parser.py
index 34bb7eb65a9..739fd95e071 100644
--- a/mlnode/iotdb/mlnode/parser.py
+++ b/mlnode/iotdb/mlnode/parser.py
@@ -19,15 +19,67 @@
 
 import argparse
 import re
+from abc import abstractmethod
 from typing import Dict, List, Tuple
 
-from iotdb.mlnode.algorithm.enums import ForecastTaskType
+from iotdb.mlnode.algorithm.enums import ForecastTaskType, ForecastModelType
+from iotdb.mlnode.constant import TaskType, OptionsKey
 from iotdb.mlnode.data_access.enums import DatasetType, DataSourceType
-from iotdb.mlnode.exception import MissingConfigError, WrongTypeConfigError
+from iotdb.mlnode.exception import MissingConfigError, WrongTypeConfigError, 
MissingOptionError
 from iotdb.mlnode.serde import convert_to_df
 from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq, TForecastReq
 
 
+class TaskOptions(object):
+    def __init__(self, options: Dict):
+        self.__raw_options = options
+
+        if OptionsKey.MODEL_TYPE not in self.__raw_options:
+            raise MissingOptionError(OptionsKey.MODEL_TYPE.name())
+        self.model_type = getattr(ForecastModelType, 
self.__raw_options.pop(OptionsKey.MODEL_TYPE), None)
+        if not self.model_type:
+            raise Exception(f"model_type {self.model_type} not supported.")
+
+        # training with auto-tuning as default
+        self.auto_tuning = self.__raw_options.pop(OptionsKey.AUTO_TUNING, True)
+
+    @abstractmethod
+    def get_task_type(self) -> TaskType:
+        raise NotImplementedError("Subclasses must implement the get_task_type() method.")
+
+    def _check_redundant_options(self) -> None:
+        if len(self.__raw_options):
+            raise Exception(f"redundant options: {self.__raw_options}.")
+
+
+class ForecastTaskOptions(TaskOptions):
+    def __init__(self, options: Dict):
+        super().__init__(options)
+        self.input_length = self.__raw_options.pop(OptionsKey.INPUT_LENGTH, 96)
+        self.predict_length = self.__raw_options.pop(OptionsKey.PREDICT_LENGTH, 96)
+        super()._check_redundant_options()
+
+    def get_task_type(self) -> TaskType:
+        return TaskType.FORECAST
+
+
+def parse_task_type(options: Dict) -> TaskType:
+    if OptionsKey.TASK_TYPE not in options:
+        raise MissingOptionError(OptionsKey.TASK_TYPE.name())
+    task_type = getattr(TaskType, options.pop(OptionsKey.TASK_TYPE), None)
+    if not task_type:
+        raise Exception(f"task type {task_type} not supported.")
+    return task_type
+
+
+def parse_task_options(options) -> TaskOptions:
+    task_type = parse_task_type(options)
+    if task_type == TaskType.FORECAST:
+        return ForecastTaskOptions(options)
+    else:
+        raise Exception(f"task type {task_type} not supported.")
+
+
 class _ConfigParser(argparse.ArgumentParser):
     """
     A parser for parsing configs from configs: dict
diff --git a/mlnode/iotdb/mlnode/process/manager.py 
b/mlnode/iotdb/mlnode/process/manager.py
index 5f40efd3f8f..7f82110502e 100644
--- a/mlnode/iotdb/mlnode/process/manager.py
+++ b/mlnode/iotdb/mlnode/process/manager.py
@@ -28,9 +28,10 @@ from torch.utils.data import Dataset
 from subprocess import call
 
 from iotdb.mlnode.log import logger
+from iotdb.mlnode.parser import TaskOptions, ForecastTaskOptions
 from iotdb.mlnode.process.task import (ForecastingInferenceTask,
-                                       ForecastingSingleTrainingTask,
-                                       ForecastingTuningTrainingTask)
+                                       ForecastFixedParamTrainingTask,
+                                       ForecastAutoTuningTrainingTask)
 
 
 class TaskManager(object):
@@ -48,47 +49,42 @@ class TaskManager(object):
         self.__training_process_pool = mp.Pool(pool_size)
         self.__inference_process_pool = mp.Pool(pool_size)
 
-    def create_training_task(self,
-                             dataset: Dataset,
-                             data_configs: Dict,
-                             model_configs: Dict,
-                             task_configs: Dict):
+    def create_forecast_training_task(self,
+                                      model_id: str,
+                                      task_options: TaskOptions,
+                                      hyperparameters: Dict[str, str],
+                                      dataset: Dataset):
         """
 
         Args:
+            model_id:
+            task_options:
             dataset: a torch dataset to be used for training
-            data_configs: dict of data configurations
-            model_configs: dict of model configurations
-            task_configs: dict of task configurations
+            hyperparameters:
 
         Returns:
             task: a training task for forecasting, which can be submitted to 
self.__training_process_pool
         """
-        model_id = task_configs['model_id']
-        if task_configs['tuning']:
-            task = ForecastingTuningTrainingTask(
-                task_configs,
-                model_configs,
-                self.__pid_info,
-                data_configs,
-                dataset,
+        if task_options.auto_tuning:
+            return ForecastAutoTuningTrainingTask(
                 model_id,
+                task_options,
+                hyperparameters,
+                dataset,
+                pid_info=self.__pid_info,
             )
         else:
-            task = ForecastingSingleTrainingTask(
-                task_configs,
-                model_configs,
-                self.__pid_info,
-                data_configs,
-                dataset,
-                model_id,
+            return ForecastFixedParamTrainingTask(
+                model_id=model_id,
+                task_options=task_options,
+                hyperparameters=hyperparameters,
+                dataset=dataset,
+                pid_info=self.__pid_info,
             )
-        return task
 
-    def submit_training_task(self, task: Union[ForecastingTuningTrainingTask, 
ForecastingSingleTrainingTask]) -> None:
-        if task is not None:
-            self.__training_process_pool.apply_async(task, args=())
-            logger.info(f'Task: ({task.model_id}) - Training process submitted 
successfully')
+    def submit_training_task(self, task: Union[ForecastAutoTuningTrainingTask, 
ForecastFixedParamTrainingTask]) -> None:
+        self.__training_process_pool.apply_async(task, args=())
+        logger.info(f'Task: ({task.model_id}) - Training process submitted 
successfully')
 
     def create_forecast_task(self,
                              task_configs,
diff --git a/mlnode/iotdb/mlnode/process/task.py 
b/mlnode/iotdb/mlnode/process/task.py
index afdef68f4ea..6eeace39220 100644
--- a/mlnode/iotdb/mlnode/process/task.py
+++ b/mlnode/iotdb/mlnode/process/task.py
@@ -28,9 +28,11 @@ import torch
 from torch.utils.data import Dataset
 
 from iotdb.mlnode.algorithm.factory import create_forecast_model
+from iotdb.mlnode.algorithm.hyperparameter import parse_fixed_hyperparameters, 
generate_hyperparameters
 from iotdb.mlnode.client import client_manager
 from iotdb.mlnode.config import descriptor
 from iotdb.mlnode.log import logger
+from iotdb.mlnode.parser import TaskOptions
 from iotdb.mlnode.process.trial import ForecastingTrainingTrial
 from iotdb.mlnode.storage import model_storage
 from iotdb.thrift.common.ttypes import TrainingState
@@ -45,26 +47,21 @@ class ForestingTrainingObjective:
 
     def __init__(
             self,
-            trial_configs: Dict,
-            model_configs: Dict,
+            task_options: TaskOptions,
+            hyperparameters: Dict[str, str],
             dataset: Dataset,
-            # pid_info: Dict
+            pid_info: Dict
     ):
-        self.trial_configs = trial_configs
-        self.model_configs = model_configs
+        self.task_options = task_options
+        self.hyperparameters = hyperparameters
         self.dataset = dataset
-        # self.pid_info = pid_info
+        self.pid_info = pid_info
 
-    def __call__(self, trial: optuna.Trial):
-        # TODO: decide which parameters to tune
-        trial_configs = self.trial_configs
-        trial_configs['learning_rate'] = trial.suggest_float("lr", 1e-7, 1e-1, 
log=True)
-        trial_configs['trial_id'] = 'tid_' + str(trial._trial_id)
-        # TODO: check args
-        model, model_configs = create_forecast_model(**self.model_configs)
-        # self.pid_info[self.trial_configs['model_id']][trial._trial_id] = 
os.getpid()
-        _trial = ForecastingTrainingTrial(trial_configs, model, model_configs, 
self.dataset)
-        loss = _trial.start()
+    def __call__(self, optuna_suggest: optuna.Trial):
+        model_configs, task_configs = generate_hyperparameters(optuna_suggest, 
self.task_options, self.hyperparameters)
+        model = create_forecast_model(self.task_options, model_configs)
+        trial = ForecastingTrainingTrial(model, task_configs, self.dataset)
+        loss = trial.start()
         return loss
 
 
@@ -76,19 +73,13 @@ class _BasicTask(object):
 
     def __init__(
             self,
-            task_configs: Dict,
-            model_configs: Dict,
             pid_info: Dict
     ):
         """
         Args:
-            task_configs:
-            model_configs:
             pid_info:
         """
         self.pid_info = pid_info
-        self.task_configs = task_configs
-        self.model_configs = model_configs
 
     @abstractmethod
     def __call__(self):
@@ -98,23 +89,18 @@ class _BasicTask(object):
 class _BasicTrainingTask(_BasicTask):
     def __init__(
             self,
-            task_configs: Dict,
-            model_configs: Dict,
-            pid_info: Dict,
-            data_configs: Dict,
+            model_id: str,
+            task_options: TaskOptions,
+            hyperparameters: Dict[str, str],
             dataset: Dataset,
+            pid_info: Dict
     ):
-        """
-        Args:
-            task_configs:
-            model_configs:
-            pid_info:
-            data_configs:
-            dataset:
-        """
-        super().__init__(task_configs, model_configs, pid_info)
-        self.data_configs = data_configs
+        super().__init__(pid_info)
+        self.model_id = model_id
+        self.task_options = task_options
+        self.hyperparameters = hyperparameters
         self.dataset = dataset
+
         self.confignode_client = client_manager.borrow_config_node_client()
 
     @abstractmethod
@@ -153,15 +139,14 @@ class _BasicInferenceTask(_BasicTask):
         raise NotImplementedError
 
 
-class ForecastingSingleTrainingTask(_BasicTrainingTask):
+class ForecastFixedParamTrainingTask(_BasicTrainingTask):
     def __init__(
             self,
-            task_configs: Dict,
-            model_configs: Dict,
-            pid_info: Dict,
-            data_configs: Dict,
-            dataset: Dataset,
             model_id: str,
+            task_options: TaskOptions,
+            hyperparameters: Dict[str, str],
+            dataset: Dataset,
+            pid_info: Dict
     ):
         """
         Args:
@@ -171,33 +156,31 @@ class ForecastingSingleTrainingTask(_BasicTrainingTask):
             data_configs: dict of data configurations
             dataset: training dataset
         """
-        super().__init__(task_configs, model_configs, pid_info, data_configs, 
dataset)
-        self.model_id = model_id
-        self.default_trial_id = 'tid_0'
-        self.task_configs['trial_id'] = self.default_trial_id
-        model, model_configs = create_forecast_model(**model_configs)
-        self.trial = ForecastingTrainingTrial(task_configs, model, 
model_configs, dataset)
+        super().__init__(model_id, task_options, hyperparameters, dataset, 
pid_info)
+        model_configs, task_configs = 
parse_fixed_hyperparameters(task_options, hyperparameters)
+        self.trial = 
ForecastingTrainingTrial(create_forecast_model(task_options, model_configs),
+                                              task_configs,
+                                              dataset)
 
     def __call__(self):
         try:
             self.pid_info[self.model_id] = os.getpid()
             self.trial.start()
-            self.confignode_client.update_model_state(self.model_id, 
TrainingState.FINISHED, self.default_trial_id)
+            self.confignode_client.update_model_state(self.model_id, 
TrainingState.FINISHED, self.trial.trial_id)
         except Exception as e:
             logger.warn(e)
-            self.confignode_client.update_model_state(self.model_id, 
TrainingState.FAILED, self.default_trial_id)
+            self.confignode_client.update_model_state(self.model_id, 
TrainingState.FAILED, self.trial.trial_id)
             raise e
 
 
-class ForecastingTuningTrainingTask(_BasicTrainingTask):
+class ForecastAutoTuningTrainingTask(_BasicTrainingTask):
     def __init__(
             self,
-            task_configs: Dict,
-            model_configs: Dict,
-            pid_info: Dict,
-            data_configs: Dict,
-            dataset: Dataset,
             model_id: str,
+            task_options: TaskOptions,
+            hyperparameters: Dict[str, str],
+            dataset: Dataset,
+            pid_info: Dict
     ):
         """
         Args:
@@ -207,8 +190,7 @@ class ForecastingTuningTrainingTask(_BasicTrainingTask):
             data_configs: dict of data configurations
             dataset: training dataset
         """
-        super().__init__(task_configs, model_configs, pid_info, data_configs, 
dataset)
-        self.model_id = model_id
+        super().__init__(model_id, task_options, hyperparameters, dataset, 
pid_info)
         self.study = optuna.create_study(direction='minimize')
 
     def __call__(self):
diff --git a/mlnode/iotdb/mlnode/process/trial.py 
b/mlnode/iotdb/mlnode/process/trial.py
index 7fc34477686..be6b57fbccf 100644
--- a/mlnode/iotdb/mlnode/process/trial.py
+++ b/mlnode/iotdb/mlnode/process/trial.py
@@ -26,7 +26,9 @@ from torch.utils.data import DataLoader, Dataset
 
 from iotdb.mlnode.algorithm.metric import all_metrics, build_metrics
 from iotdb.mlnode.client import client_manager
+from iotdb.mlnode.constant import DEFAULT_TRIAL_ID
 from iotdb.mlnode.log import logger
+from iotdb.mlnode.parser import TaskOptions
 from iotdb.mlnode.storage import model_storage
 from iotdb.thrift.common.ttypes import TrainingState
 
@@ -81,9 +83,8 @@ def _parse_trial_config(**kwargs):
 class BasicTrial(object):
     def __init__(
             self,
-            task_configs: Dict,
             model: nn.Module,
-            model_configs: Dict,
+            task_configs: Dict,
             dataset: Dataset
     ):
         self.trial_configs = task_configs
@@ -131,10 +132,9 @@ class BasicTrial(object):
 class ForecastingTrainingTrial(BasicTrial):
     def __init__(
             self,
-            task_configs: Dict,
             model: nn.Module,
-            model_configs: Dict,
-            dataset: Dataset,
+            task_configs: Dict,
+            dataset: Dataset
     ):
         """
         A training trial, accept all parameters needed and train a single 
model.
@@ -146,8 +146,8 @@ class ForecastingTrainingTrial(BasicTrial):
             dataset: training dataset
             **kwargs:
         """
-        super(ForecastingTrainingTrial, self).__init__(task_configs, model, 
model_configs, dataset)
-
+        super(ForecastingTrainingTrial, self).__init__(model, task_configs, 
dataset)
+        self.trial_id = DEFAULT_TRIAL_ID
         self.dataloader = self._build_dataloader()
         self.datanode_client = client_manager.borrow_data_node_client()
         self.confignode_client = client_manager.borrow_config_node_client()
diff --git a/mlnode/pom.xml b/mlnode/pom.xml
index c75e9c6fb33..f94bfc5a496 100644
--- a/mlnode/pom.xml
+++ b/mlnode/pom.xml
@@ -104,13 +104,13 @@
                             
<outputDirectory>${basedir}/iotdb/thrift/</outputDirectory>
                             <resources>
                                 <resource>
-                                    
<directory>${basedir}/../thrift-commons/target/generated-sources-python/iotdb/thrift/</directory>
+                                    
<directory>${basedir}/../iotdb-protocol/thrift-commons/target/generated-sources-python/iotdb/thrift/</directory>
                                 </resource>
                                 <resource>
-                                    
<directory>${basedir}/../thrift-confignode/target/generated-sources-python/iotdb/thrift/</directory>
+                                    
<directory>${basedir}/../iotdb-protocol/thrift-confignode/target/generated-sources-python/iotdb/thrift/</directory>
                                 </resource>
                                 <resource>
-                                    
<directory>${basedir}/../thrift-mlnode/target/generated-sources-python/iotdb/thrift/</directory>
+                                    
<directory>${basedir}/../iotdb-protocol/thrift-mlnode/target/generated-sources-python/iotdb/thrift/</directory>
                                 </resource>
                             </resources>
                         </configuration>
@@ -127,7 +127,7 @@
                             
<outputDirectory>${basedir}/iotdb/thrift/datanode</outputDirectory>
                             <resources>
                                 <resource>
-                                    
<directory>${basedir}/../thrift/target/generated-sources-python/iotdb/thrift/datanode/</directory>
+                                    
<directory>${basedir}/../iotdb-protocol/thrift/target/generated-sources-python/iotdb/thrift/datanode/</directory>
                                 </resource>
                             </resources>
                         </configuration>


Reply via email to