This is an automated email from the ASF dual-hosted git repository.
hui pushed a commit to branch research/MLEngine
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/research/MLEngine by this push:
new 96301b8fda0 refactor mlnode (tmp save)
96301b8fda0 is described below
commit 96301b8fda08d028720c0e939c240c5b179be0ae
Author: Minghui Liu <[email protected]>
AuthorDate: Mon Jun 19 10:09:15 2023 +0800
refactor mlnode (tmp save)
---
.../thrift-mlnode/src/main/thrift/mlnode.thrift | 7 +-
mlnode/iotdb/mlnode/algorithm/factory.py | 36 +++-----
mlnode/iotdb/mlnode/algorithm/hyperparameter.py | 88 +++++++++++++++++++
.../mlnode/algorithm/models/forecast/dlinear.py | 14 ++++
.../mlnode/{constant.py => algorithm/validator.py} | 33 +++++---
mlnode/iotdb/mlnode/constant.py | 20 +++++
mlnode/iotdb/mlnode/data_access/factory.py | 16 ++--
mlnode/iotdb/mlnode/exception.py | 6 ++
mlnode/iotdb/mlnode/handler.py | 27 +++---
mlnode/iotdb/mlnode/parser.py | 56 ++++++++++++-
mlnode/iotdb/mlnode/process/manager.py | 56 ++++++-------
mlnode/iotdb/mlnode/process/task.py | 98 +++++++++-------------
mlnode/iotdb/mlnode/process/trial.py | 14 ++--
mlnode/pom.xml | 8 +-
14 files changed, 322 insertions(+), 157 deletions(-)
diff --git a/iotdb-protocol/thrift-mlnode/src/main/thrift/mlnode.thrift
b/iotdb-protocol/thrift-mlnode/src/main/thrift/mlnode.thrift
index 46f7b025f47..5faa939a903 100644
--- a/iotdb-protocol/thrift-mlnode/src/main/thrift/mlnode.thrift
+++ b/iotdb-protocol/thrift-mlnode/src/main/thrift/mlnode.thrift
@@ -23,10 +23,9 @@ namespace py iotdb.thrift.mlnode
struct TCreateTrainingTaskReq {
1: required string modelId
- 2: required bool isAuto
- 3: required map<string, string> modelConfigs
- 4: required list<string> queryExpressions
- 5: optional string queryFilter
+ 2: required map<string, string> options // task options keyed by OptionsKey values (task_type, model_type, auto_tuning, ...); parsed by parse_task_options
+ 3: required map<string, string> hyperparameters // raw hyperparameter values as strings, forwarded to the training task
+ 4: required string queryBody // query from which create_dataset builds the training dataset
}
struct TDeleteModelReq {
diff --git a/mlnode/iotdb/mlnode/algorithm/factory.py
b/mlnode/iotdb/mlnode/algorithm/factory.py
index ed39e4e2b8c..70a11da55b7 100644
--- a/mlnode/iotdb/mlnode/algorithm/factory.py
+++ b/mlnode/iotdb/mlnode/algorithm/factory.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
#
-from typing import Dict, Tuple
+from typing import Dict
import torch.nn as nn
@@ -24,9 +24,10 @@ from iotdb.mlnode.algorithm.models.forecast.dlinear import
(dlinear,
dlinear_individual)
from iotdb.mlnode.algorithm.models.forecast.nbeats import nbeats
from iotdb.mlnode.exception import BadConfigValueError
+from iotdb.mlnode.parser import TaskOptions
# Common configs for all forecasting model with default values
def _common_config(**kwargs):
return {
'input_len': 96,
@@ -51,14 +52,9 @@ _forecasting_model_default_config_dict = {
def create_forecast_model(
- model_name,
- input_len=96,
- pred_len=96,
- input_vars=1,
- output_vars=1,
- forecast_task_type=ForecastTaskType.ENDOGENOUS,
- **kwargs,
-) -> Tuple[nn.Module, Dict]:
+ task_options: TaskOptions,
+ model_configs: Dict,
+) -> nn.Module:
"""
Factory method for all support forecasting models
the given arguments is common configs shared by all forecasting models
@@ -83,13 +79,6 @@ def create_forecast_model(
raise BadConfigValueError('forecast_task_type', forecast_task_type,
f'It should be one of
{list(_forecasting_model_default_config_dict.keys())}')
- common_config = _forecasting_model_default_config_dict[forecast_task_type]
- common_config['input_len'] = input_len
- common_config['pred_len'] = pred_len
- common_config['input_vars'] = input_vars
- common_config['output_vars'] = output_vars
- common_config['forecast_task_type'] = str(forecast_task_type)
-
if not input_len > 0:
raise BadConfigValueError('input_len', input_len,
'Length of input series should be positive')
@@ -108,23 +97,24 @@ def create_forecast_model(
'Number of input/output variables should
be '
'the same in endogenous forecast')
- if model_name == ForecastModelType.DLINEAR.value:
+ # TODO(review): input_len, forecast_task_type, common_config and kwargs are no longer parameters of this function; derive them from task_options / model_configs before this can run
+ if task_options.model_type == ForecastModelType.DLINEAR:
model, model_config = dlinear(
common_config=common_config,
**kwargs
)
- elif model_name == ForecastModelType.DLINEAR_INDIVIDUAL.value:
+ elif task_options.model_type == ForecastModelType.DLINEAR_INDIVIDUAL:
model, model_config = dlinear_individual(
common_config=common_config,
**kwargs
)
- elif model_name == ForecastModelType.NBEATS.value:
+ elif task_options.model_type == ForecastModelType.NBEATS:
model, model_config = nbeats(
common_config=common_config,
**kwargs
)
else:
- raise BadConfigValueError('model_name', model_name, f'It should be one
of {ForecastModelType.values()}')
+ raise BadConfigValueError('model_name', task_options.model_type,
+ f'It should be one of
{ForecastModelType.values()}')
- model_config['model_name'] = model_name
- return model, model_config
+ return model
diff --git a/mlnode/iotdb/mlnode/algorithm/hyperparameter.py
b/mlnode/iotdb/mlnode/algorithm/hyperparameter.py
new file mode 100644
index 00000000000..7b10dc7b9e6
--- /dev/null
+++ b/mlnode/iotdb/mlnode/algorithm/hyperparameter.py
@@ -0,0 +1,135 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from enum import Enum
+from typing import Optional, List, Dict, Tuple
+
+import optuna
+
+from iotdb.mlnode.algorithm.enums import ForecastModelType
+from iotdb.mlnode.algorithm.validator import Validator
+from iotdb.mlnode.parser import TaskOptions
+
+
+class Hyperparameter(object):
+
+    def __init__(self, name: str, log: bool):
+        """
+        Args:
+            name: name of the hyperparameter
+            log: whether the value is sampled on a log scale
+                (presumably forwarded to optuna's suggest_* ``log`` argument)
+        """
+        self.__name = name
+        self.__log = log
+
+
+class FloatHyperparameter(Hyperparameter):
+    def __init__(self, name: str,
+                 log: bool,
+                 default_value: float,
+                 value_validators: List[Validator],
+                 default_low: float,
+                 low_validators: List[Validator],
+                 default_high: float,
+                 high_validators: List[Validator],
+                 step: Optional[float] = None):
+        """
+        A float hyperparameter. The roles below are inferred from the argument
+        names; nothing in this module consumes them yet (TODO confirm).
+
+        Args:
+            name: name of the hyperparameter
+            log: whether the value is sampled on a log scale
+            default_value: value to use when none is given
+            value_validators: checks for a user-given value
+            default_low: default lower bound of the search range
+            low_validators: checks for a user-given lower bound
+            default_high: default upper bound of the search range
+            high_validators: checks for a user-given upper bound
+            step: optional step of the search range
+        """
+        super(FloatHyperparameter, self).__init__(name, log)
+        self.__default_value = default_value
+        self.__value_validators = value_validators
+        self.__default_low = default_low
+        self.__low_validators = low_validators
+        self.__default_high = default_high
+        self.__high_validators = high_validators
+        self.__step = step
+
+
+class HyperparameterName(Enum):
+    LEARNING_RATE = "learning_rate"
+
+    # NOTE: shadows Enum.name; returns the string key (the member's value)
+    def name(self):
+        return self.value
+
+
+def get_model_hyperparameter_map(model_type: ForecastModelType):
+    """
+    Return the hyperparameter definitions of the given model type.
+
+    Raises:
+        NotImplementedError: if the model type has no definitions yet.
+    """
+    if model_type == ForecastModelType.DLINEAR:
+        # Imported lazily: dlinear.py imports this module at module level, so a
+        # top-level import here would be circular and fail at import time.
+        from iotdb.mlnode.algorithm.models.forecast.dlinear import dlinear_hyperparameter_map
+        return dlinear_hyperparameter_map
+    else:
+        raise NotImplementedError(f"Model type {model_type} is not supported yet.")
+
+
+def parse_fixed_hyperparameters(
+        task_options: TaskOptions,
+        hyperparameters: Dict[str, str]
+) -> Tuple[Dict, Dict]:
+    """
+    Turn user-given hyperparameters into (model_configs, task_configs) for a
+    single training run without tuning.
+
+    Raises:
+        NotImplementedError: always for now; the parsing is not written yet.
+    """
+    # fails early for model types without hyperparameter definitions
+    get_model_hyperparameter_map(task_options.model_type)
+    # TODO: validate `hyperparameters` against the map and split them into
+    # model/task configs. Raising (instead of returning placeholder Nones) keeps
+    # callers from unpacking and passing on configs that do not exist.
+    raise NotImplementedError("parse_fixed_hyperparameters is not implemented yet.")
+
+
+def generate_hyperparameters(
+        optuna_suggest: optuna.Trial,
+        task_options: TaskOptions,
+        hyperparameters: Dict[str, str]
+) -> Tuple[Dict, Dict]:
+    """
+    Sample (model_configs, task_configs) for one optuna trial.
+
+    Raises:
+        NotImplementedError: always for now; the sampling is not written yet.
+    """
+    # fails early for model types without hyperparameter definitions
+    get_model_hyperparameter_map(task_options.model_type)
+    # TODO: suggest a value for each hyperparameter in the map via optuna_suggest.
+    # Raising (instead of returning placeholder Nones) keeps callers from
+    # unpacking and passing on configs that do not exist.
+    raise NotImplementedError("generate_hyperparameters is not implemented yet.")
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
index 35ba728fa78..b839f811733 100644
--- a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
+++ b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
@@ -23,6 +23,8 @@ import torch
import torch.nn as nn
from iotdb.mlnode.algorithm.enums import ForecastTaskType
+from iotdb.mlnode.algorithm.hyperparameter import FloatHyperparameter,
HyperparameterName
+from iotdb.mlnode.algorithm.validator import FloatRangeValidator
from iotdb.mlnode.exception import BadConfigValueError
@@ -143,6 +145,18 @@ def _model_config(**kwargs):
}
+dlinear_hyperparameter_map = {
+ HyperparameterName.LEARNING_RATE:
FloatHyperparameter(name=HyperparameterName.LEARNING_RATE.name(),
+ log=True,
+ default_value=1e-3,
+
value_validators=[FloatRangeValidator(1e-5, 1e-1)], # same bounds as default_low/default_high, so default_value=1e-3 passes validation
+ default_low=1e-5,
+ low_validators=[],
+ default_high=1e-1,
+ high_validators=[]),
+}
+
+
def dlinear(common_config: Dict, kernel_size=25, **kwargs) -> Tuple[DLinear,
Dict]:
config = _model_config()
config.update(**common_config)
diff --git a/mlnode/iotdb/mlnode/constant.py
b/mlnode/iotdb/mlnode/algorithm/validator.py
similarity index 52%
copy from mlnode/iotdb/mlnode/constant.py
copy to mlnode/iotdb/mlnode/algorithm/validator.py
index 68240af12a4..7c26e15342b 100644
--- a/mlnode/iotdb/mlnode/constant.py
+++ b/mlnode/iotdb/mlnode/algorithm/validator.py
@@ -15,19 +15,35 @@
# specific language governing permissions and limitations
# under the License.
#
-from enum import Enum
+from abc import ABC, abstractmethod
-MLNODE_CONF_DIRECTORY_NAME = "conf"
-MLNODE_CONF_FILE_NAME = "iotdb-mlnode.toml"
-MLNODE_LOG_CONF_FILE_NAME = "logging_config.ini"
-MLNODE_MODEL_STORAGE_DIRECTORY_NAME = "models"
+class Validator(ABC):
+    """Base class of the checks attached to hyperparameter definitions."""
+
+    @abstractmethod
+    def validate(self, value):
+        """
+        Checks whether the given value is valid.
+        Parameters:
+        - value: The value to validate
-class TSStatusCode(Enum):
-    SUCCESS_STATUS = 200
-    REDIRECTION_RECOMMEND = 400
-    MLNODE_INTERNAL_ERROR = 1510
+        Returns:
+        - True if the value is valid, False otherwise.
+        """
+        raise NotImplementedError("Subclasses must implement the validate() method.")
-    def get_status_code(self) -> int:
-        return self.value
+
+class FloatRangeValidator(Validator):
+    """Accepts real numbers (int or float, not bool) within [min_value, max_value], bounds inclusive."""
+
+    def __init__(self, min_value, max_value):
+        self.min_value = min_value
+        self.max_value = max_value
+
+    def validate(self, value):
+        # bool is a subclass of int, so it is rejected explicitly
+        if isinstance(value, bool) or not isinstance(value, (int, float)):
+            return False
+        return self.min_value <= value <= self.max_value
diff --git a/mlnode/iotdb/mlnode/constant.py b/mlnode/iotdb/mlnode/constant.py
index 68240af12a4..7292856a5e9 100644
--- a/mlnode/iotdb/mlnode/constant.py
+++ b/mlnode/iotdb/mlnode/constant.py
@@ -23,6 +23,8 @@ MLNODE_LOG_CONF_FILE_NAME = "logging_config.ini"
MLNODE_MODEL_STORAGE_DIRECTORY_NAME = "models"
+DEFAULT_TRIAL_ID = "__trial_0" # trial id that ForecastingTrainingTrial assigns to every trial it creates
+
class TSStatusCode(Enum):
SUCCESS_STATUS = 200
@@ -31,3 +33,21 @@ class TSStatusCode(Enum):
def get_status_code(self) -> int:
return self.value
+
+
+class TaskType(Enum): # kind of ML task; selects the TaskOptions subclass built by parse_task_options
+ FORECAST = "forecast"
+
+
+class OptionsKey(Enum): # keys of the options map in TCreateTrainingTaskReq
+ # common
+ TASK_TYPE = "task_type"
+ MODEL_TYPE = "model_type"
+ AUTO_TUNING = "auto_tuning"
+
+ # forecast
+ INPUT_LENGTH = "input_length"
+ PREDICT_LENGTH = "predict_length"
+
+ def name(self) -> str: # NOTE(review): shadows Enum.name; returns the option key string (the member's value)
+ return self.value
diff --git a/mlnode/iotdb/mlnode/data_access/factory.py
b/mlnode/iotdb/mlnode/data_access/factory.py
index ee8d57ffb45..9850bc45ef8 100644
--- a/mlnode/iotdb/mlnode/data_access/factory.py
+++ b/mlnode/iotdb/mlnode/data_access/factory.py
@@ -19,12 +19,14 @@ from typing import Dict, Tuple
from torch.utils.data import Dataset
+from iotdb.mlnode.constant import TaskType
from iotdb.mlnode.data_access.enums import DatasetType, DataSourceType
from iotdb.mlnode.data_access.offline.dataset import (TimeSeriesDataset,
WindowDataset)
from iotdb.mlnode.data_access.offline.source import (FileDataSource,
ThriftDataSource)
from iotdb.mlnode.exception import BadConfigValueError, MissingConfigError
+from iotdb.mlnode.parser import TaskOptions
def _dataset_common_config(**kwargs):
@@ -43,11 +45,15 @@ _dataset_default_config_dict = {
}
-def create_forecast_dataset(
- source_type,
- dataset_type,
- **kwargs,
-) -> Tuple[Dataset, Dict]:
+def create_dataset(query_body: str, task_options: TaskOptions) -> Dataset: # dispatches on task_options.get_task_type(); only TaskType.FORECAST is handled
+ task_type = task_options.get_task_type()
+ if task_type == TaskType.FORECAST:
+ return create_forecast_dataset(query_body, task_options)
+ else:
+ raise Exception(f"task type {task_type} not supported.")
+
+
+def create_forecast_dataset(query_body: str, task_options: TaskOptions) ->
Dataset:
"""
Factory method for all support dataset
currently implement two types of PyTorch dataset: WindowDataset,
TimeSeriesDataset
diff --git a/mlnode/iotdb/mlnode/exception.py b/mlnode/iotdb/mlnode/exception.py
index 47edb95eb43..f679417544c 100644
--- a/mlnode/iotdb/mlnode/exception.py
+++ b/mlnode/iotdb/mlnode/exception.py
@@ -18,6 +18,7 @@
class _BaseError(Exception):
"""Base class for exceptions in this module."""
+
def __init__(self):
self.message = None
@@ -45,6 +46,11 @@ class MissingConfigError(_BaseError):
self.message = "Missing config: {}".format(config_name)
+class MissingOptionError(_BaseError): # raised when a required task option (e.g. task_type, model_type) is absent
+ def __init__(self, config_name: str): # config_name is the name of the missing option
+ self.message = "Missing task option: {}".format(config_name)
+
+
class WrongTypeConfigError(_BaseError):
def __init__(self, config_name: str, expected_type: str):
self.message = "Wrong type for config: {0}, expected:
{1}".format(config_name, expected_type)
diff --git a/mlnode/iotdb/mlnode/handler.py b/mlnode/iotdb/mlnode/handler.py
index d4b4ca3b034..c42c05c8100 100644
--- a/mlnode/iotdb/mlnode/handler.py
+++ b/mlnode/iotdb/mlnode/handler.py
@@ -18,9 +18,9 @@
from iotdb.mlnode.config import descriptor
from iotdb.mlnode.constant import TSStatusCode
-from iotdb.mlnode.data_access.factory import create_forecast_dataset
+from iotdb.mlnode.data_access.factory import create_dataset
from iotdb.mlnode.log import logger
-from iotdb.mlnode.parser import parse_forecast_request, parse_training_request
+from iotdb.mlnode.parser import parse_forecast_request, parse_task_options
from iotdb.mlnode.process.manager import TaskManager
from iotdb.mlnode.serde import convert_to_binary
from iotdb.mlnode.storage import model_storage
@@ -46,24 +46,24 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
def createTrainingTask(self, req: TCreateTrainingTaskReq):
task = None
try:
- # parse request, check required config and config type
- data_config, model_config, task_config =
parse_training_request(req)
- # create dataset & check data config legitimacy
- dataset, data_config = create_forecast_dataset(**data_config)
+ # parse options
+ task_options = parse_task_options(req.options)
- model_config['input_vars'] = data_config['input_vars']
- model_config['output_vars'] = data_config['output_vars']
-
- # create task & check task config legitimacy
- task = self.__task_manager.create_training_task(dataset,
data_config, model_config, task_config)
+ # create task
+ task = self.__task_manager.create_forecast_training_task(
+ model_id=req.modelId,
+ task_options=task_options,
+ hyperparameters=req.hyperparameters,
+ dataset=create_dataset(req.queryBody, task_options)
+ )
+ # submit inside the try block so that a submission failure is reported
+ # through the returned status instead of propagating out of this RPC
+ self.__task_manager.submit_training_task(task)
return get_status(TSStatusCode.SUCCESS_STATUS)
except Exception as e:
logger.warn(e)
return get_status(TSStatusCode.MLNODE_INTERNAL_ERROR, str(e))
- finally:
- # submit task stage & check resource and decide pending/start
- self.__task_manager.submit_training_task(task)
def forecast(self, req: TForecastReq):
model_path, data, pred_length = parse_forecast_request(req)
diff --git a/mlnode/iotdb/mlnode/parser.py b/mlnode/iotdb/mlnode/parser.py
index 34bb7eb65a9..739fd95e071 100644
--- a/mlnode/iotdb/mlnode/parser.py
+++ b/mlnode/iotdb/mlnode/parser.py
@@ -19,15 +19,121 @@
import argparse
import re
+from abc import ABC, abstractmethod
from typing import Dict, List, Tuple
-from iotdb.mlnode.algorithm.enums import ForecastTaskType
+from iotdb.mlnode.algorithm.enums import ForecastTaskType, ForecastModelType
+from iotdb.mlnode.constant import TaskType, OptionsKey
from iotdb.mlnode.data_access.enums import DatasetType, DataSourceType
-from iotdb.mlnode.exception import MissingConfigError, WrongTypeConfigError
+from iotdb.mlnode.exception import MissingConfigError, WrongTypeConfigError,
MissingOptionError
from iotdb.mlnode.serde import convert_to_df
from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq, TForecastReq
+def _parse_enum(enum_cls, raw_value):
+    """
+    Return the member of enum_cls whose value or member name equals raw_value.
+
+    Only genuine members are considered, so class attributes such as methods can
+    never be returned (unlike getattr). Returns None when nothing matches.
+    """
+    for member in enum_cls:
+        if member.value == raw_value:
+            return member
+    return enum_cls.__members__.get(raw_value)
+
+
+def _parse_bool(key: str, raw_value) -> bool:
+    """Convert a bool, or a 'true'/'false' string (case-insensitive), to bool."""
+    if isinstance(raw_value, bool):
+        return raw_value
+    if isinstance(raw_value, str):
+        normalized = raw_value.strip().lower()
+        if normalized in ('true', 'false'):
+            return normalized == 'true'
+    raise WrongTypeConfigError(key, 'bool')
+
+
+def _parse_int(key: str, raw_value) -> int:
+    """Convert an int, or a string holding an integer, to int."""
+    if isinstance(raw_value, bool):
+        raise WrongTypeConfigError(key, 'int')
+    try:
+        return int(raw_value)
+    except (TypeError, ValueError):
+        raise WrongTypeConfigError(key, 'int') from None
+
+
+class TaskOptions(ABC):
+    """
+    Options shared by all task types, parsed from the options map of a training
+    request. Recognized keys are consumed from a private copy, so a subclass can
+    report every leftover key as redundant.
+    """
+
+    def __init__(self, options: Dict):
+        # work on a copy so that the caller's dict is never mutated
+        self._raw_options = dict(options)
+
+        model_type_key = OptionsKey.MODEL_TYPE.value
+        if model_type_key not in self._raw_options:
+            raise MissingOptionError(model_type_key)
+        raw_model_type = self._raw_options.pop(model_type_key)
+        self.model_type = _parse_enum(ForecastModelType, raw_model_type)
+        if self.model_type is None:
+            raise Exception(f"model_type {raw_model_type} not supported.")
+
+        # training with auto-tuning as default; the option arrives as a string
+        auto_tuning_key = OptionsKey.AUTO_TUNING.value
+        self.auto_tuning = _parse_bool(auto_tuning_key, self._raw_options.pop(auto_tuning_key, True))
+
+    @abstractmethod
+    def get_task_type(self) -> TaskType:
+        raise NotImplementedError("Subclasses must implement the get_task_type() method.")
+
+    def _check_redundant_options(self) -> None:
+        if len(self._raw_options):
+            raise Exception(f"redundant options: {self._raw_options}.")
+
+
+class ForecastTaskOptions(TaskOptions):
+    """Options of a forecasting task; both window lengths default to 96."""
+
+    def __init__(self, options: Dict):
+        super().__init__(options)
+        input_length_key = OptionsKey.INPUT_LENGTH.value
+        predict_length_key = OptionsKey.PREDICT_LENGTH.value
+        self.input_length = _parse_int(input_length_key, self._raw_options.pop(input_length_key, 96))
+        self.predict_length = _parse_int(predict_length_key, self._raw_options.pop(predict_length_key, 96))
+        self._check_redundant_options()
+
+    def get_task_type(self) -> TaskType:
+        return TaskType.FORECAST
+
+
+def parse_task_type(options: Dict) -> TaskType:
+    """Pop the task_type option from options and return the matching TaskType."""
+    task_type_key = OptionsKey.TASK_TYPE.value
+    if task_type_key not in options:
+        raise MissingOptionError(task_type_key)
+    raw_task_type = options.pop(task_type_key)
+    task_type = _parse_enum(TaskType, raw_task_type)
+    if task_type is None:
+        raise Exception(f"task type {raw_task_type} not supported.")
+    return task_type
+
+
+def parse_task_options(options) -> TaskOptions:
+    """Build the TaskOptions subclass selected by the task_type option."""
+    # parse a copy so that the caller's dict (e.g. req.options) is left untouched
+    options = dict(options)
+    task_type = parse_task_type(options)
+    if task_type == TaskType.FORECAST:
+        return ForecastTaskOptions(options)
+    else:
+        raise Exception(f"task type {task_type} not supported.")
+
+
class _ConfigParser(argparse.ArgumentParser):
"""
A parser for parsing configs from configs: dict
diff --git a/mlnode/iotdb/mlnode/process/manager.py
b/mlnode/iotdb/mlnode/process/manager.py
index 5f40efd3f8f..7f82110502e 100644
--- a/mlnode/iotdb/mlnode/process/manager.py
+++ b/mlnode/iotdb/mlnode/process/manager.py
@@ -28,9 +28,10 @@ from torch.utils.data import Dataset
from subprocess import call
from iotdb.mlnode.log import logger
+from iotdb.mlnode.parser import TaskOptions, ForecastTaskOptions
from iotdb.mlnode.process.task import (ForecastingInferenceTask,
- ForecastingSingleTrainingTask,
- ForecastingTuningTrainingTask)
+ ForecastFixedParamTrainingTask,
+ ForecastAutoTuningTrainingTask)
class TaskManager(object):
@@ -48,47 +49,46 @@ class TaskManager(object):
self.__training_process_pool = mp.Pool(pool_size)
self.__inference_process_pool = mp.Pool(pool_size)
- def create_training_task(self,
- dataset: Dataset,
- data_configs: Dict,
- model_configs: Dict,
- task_configs: Dict):
+ def create_forecast_training_task(self,
+ model_id: str,
+ task_options: TaskOptions,
+ hyperparameters: Dict[str, str],
+ dataset: Dataset):
"""
Args:
+ model_id: id of the model to be trained
+ task_options: parsed task options; task_options.auto_tuning selects the task class
dataset: a torch dataset to be used for training
- data_configs: dict of data configurations
- model_configs: dict of model configurations
- task_configs: dict of task configurations
+ hyperparameters: raw hyperparameter values (strings) from the request
Returns:
task: a training task for forecasting, which can be submitted to
self.__training_process_pool
"""
- model_id = task_configs['model_id']
- if task_configs['tuning']:
- task = ForecastingTuningTrainingTask(
- task_configs,
- model_configs,
- self.__pid_info,
- data_configs,
- dataset,
+ if task_options.auto_tuning:
+ return ForecastAutoTuningTrainingTask(
- model_id,
+ model_id=model_id,
+ task_options=task_options,
+ hyperparameters=hyperparameters,
+ dataset=dataset,
+ pid_info=self.__pid_info,
)
else:
- task = ForecastingSingleTrainingTask(
- task_configs,
- model_configs,
- self.__pid_info,
- data_configs,
- dataset,
- model_id,
+ return ForecastFixedParamTrainingTask(
+ model_id=model_id,
+ task_options=task_options,
+ hyperparameters=hyperparameters,
+ dataset=dataset,
+ pid_info=self.__pid_info,
)
- return task
- def submit_training_task(self, task: Union[ForecastingTuningTrainingTask,
ForecastingSingleTrainingTask]) -> None:
- if task is not None:
- self.__training_process_pool.apply_async(task, args=())
- logger.info(f'Task: ({task.model_id}) - Training process submitted
successfully')
+ def submit_training_task(self, task: Union[ForecastAutoTuningTrainingTask,
ForecastFixedParamTrainingTask]) -> None:
+ # the AsyncResult is discarded, so without an error_callback an exception
+ # raised by the task in the worker process would go unnoticed
+ self.__training_process_pool.apply_async(
+ task, args=(),
+ error_callback=lambda e: logger.warn(f'Task: ({task.model_id}) - Training process failed: {e}'))
+ logger.info(f'Task: ({task.model_id}) - Training process submitted
successfully')
def create_forecast_task(self,
task_configs,
diff --git a/mlnode/iotdb/mlnode/process/task.py
b/mlnode/iotdb/mlnode/process/task.py
index afdef68f4ea..6eeace39220 100644
--- a/mlnode/iotdb/mlnode/process/task.py
+++ b/mlnode/iotdb/mlnode/process/task.py
@@ -28,9 +28,11 @@ import torch
from torch.utils.data import Dataset
from iotdb.mlnode.algorithm.factory import create_forecast_model
+from iotdb.mlnode.algorithm.hyperparameter import parse_fixed_hyperparameters,
generate_hyperparameters
from iotdb.mlnode.client import client_manager
from iotdb.mlnode.config import descriptor
from iotdb.mlnode.log import logger
+from iotdb.mlnode.parser import TaskOptions
from iotdb.mlnode.process.trial import ForecastingTrainingTrial
from iotdb.mlnode.storage import model_storage
from iotdb.thrift.common.ttypes import TrainingState
@@ -45,26 +47,24 @@ class ForestingTrainingObjective:
def __init__(
self,
- trial_configs: Dict,
- model_configs: Dict,
+ task_options: TaskOptions,
+ hyperparameters: Dict[str, str],
dataset: Dataset,
- # pid_info: Dict
+ pid_info: Dict
):
- self.trial_configs = trial_configs
- self.model_configs = model_configs
+ self.task_options = task_options
+ self.hyperparameters = hyperparameters
self.dataset = dataset
- # self.pid_info = pid_info
+ self.pid_info = pid_info
- def __call__(self, trial: optuna.Trial):
- # TODO: decide which parameters to tune
- trial_configs = self.trial_configs
- trial_configs['learning_rate'] = trial.suggest_float("lr", 1e-7, 1e-1,
log=True)
- trial_configs['trial_id'] = 'tid_' + str(trial._trial_id)
- # TODO: check args
- model, model_configs = create_forecast_model(**self.model_configs)
- # self.pid_info[self.trial_configs['model_id']][trial._trial_id] =
os.getpid()
- _trial = ForecastingTrainingTrial(trial_configs, model, model_configs,
self.dataset)
- loss = _trial.start()
+ def __call__(self, optuna_suggest: optuna.Trial):
+ model_configs, task_configs = generate_hyperparameters(optuna_suggest,
self.task_options, self.hyperparameters)
+ model = create_forecast_model(self.task_options, model_configs)
+ trial = ForecastingTrainingTrial(model, task_configs, self.dataset)
+ # ForecastingTrainingTrial starts every trial with the shared DEFAULT_TRIAL_ID;
+ # give each optuna trial its own id, as the replaced code did with 'tid_<id>'
+ trial.trial_id = f'__trial_{optuna_suggest.number}'
+ loss = trial.start()
return loss
@@ -76,19 +73,13 @@ class _BasicTask(object):
def __init__(
self,
- task_configs: Dict,
- model_configs: Dict,
pid_info: Dict
):
"""
Args:
- task_configs:
- model_configs:
pid_info:
"""
self.pid_info = pid_info
- self.task_configs = task_configs
- self.model_configs = model_configs
@abstractmethod
def __call__(self):
@@ -98,23 +89,18 @@ class _BasicTask(object):
class _BasicTrainingTask(_BasicTask):
def __init__(
self,
- task_configs: Dict,
- model_configs: Dict,
- pid_info: Dict,
- data_configs: Dict,
+ model_id: str,
+ task_options: TaskOptions,
+ hyperparameters: Dict[str, str],
dataset: Dataset,
+ pid_info: Dict
):
- """
- Args:
- task_configs:
- model_configs:
- pid_info:
- data_configs:
- dataset:
- """
- super().__init__(task_configs, model_configs, pid_info)
- self.data_configs = data_configs
+ super().__init__(pid_info)
+ self.model_id = model_id
+ self.task_options = task_options
+ self.hyperparameters = hyperparameters
self.dataset = dataset
+
self.confignode_client = client_manager.borrow_config_node_client()
@abstractmethod
@@ -153,15 +139,14 @@ class _BasicInferenceTask(_BasicTask):
raise NotImplementedError
-class ForecastingSingleTrainingTask(_BasicTrainingTask):
+class ForecastFixedParamTrainingTask(_BasicTrainingTask):
def __init__(
self,
- task_configs: Dict,
- model_configs: Dict,
- pid_info: Dict,
- data_configs: Dict,
- dataset: Dataset,
model_id: str,
+ task_options: TaskOptions,
+ hyperparameters: Dict[str, str],
+ dataset: Dataset,
+ pid_info: Dict
):
"""
Args:
@@ -171,33 +156,31 @@ class ForecastingSingleTrainingTask(_BasicTrainingTask):
data_configs: dict of data configurations
dataset: training dataset
"""
- super().__init__(task_configs, model_configs, pid_info, data_configs,
dataset)
- self.model_id = model_id
- self.default_trial_id = 'tid_0'
- self.task_configs['trial_id'] = self.default_trial_id
- model, model_configs = create_forecast_model(**model_configs)
- self.trial = ForecastingTrainingTrial(task_configs, model,
model_configs, dataset)
+ super().__init__(model_id, task_options, hyperparameters, dataset,
pid_info)
+ model_configs, task_configs =
parse_fixed_hyperparameters(task_options, hyperparameters) # NOTE(review): parse_fixed_hyperparameters is not implemented yet in this commit
+ self.trial =
ForecastingTrainingTrial(create_forecast_model(task_options, model_configs),
+ task_configs,
+ dataset)
def __call__(self):
try:
self.pid_info[self.model_id] = os.getpid()
self.trial.start()
- self.confignode_client.update_model_state(self.model_id,
TrainingState.FINISHED, self.default_trial_id)
+ self.confignode_client.update_model_state(self.model_id,
TrainingState.FINISHED, self.trial.trial_id)
except Exception as e:
logger.warn(e)
- self.confignode_client.update_model_state(self.model_id,
TrainingState.FAILED, self.default_trial_id)
+ self.confignode_client.update_model_state(self.model_id,
TrainingState.FAILED, self.trial.trial_id)
raise e
-class ForecastingTuningTrainingTask(_BasicTrainingTask):
+class ForecastAutoTuningTrainingTask(_BasicTrainingTask):
def __init__(
self,
- task_configs: Dict,
- model_configs: Dict,
- pid_info: Dict,
- data_configs: Dict,
- dataset: Dataset,
model_id: str,
+ task_options: TaskOptions,
+ hyperparameters: Dict[str, str],
+ dataset: Dataset,
+ pid_info: Dict
):
"""
Args:
@@ -207,8 +190,7 @@ class ForecastingTuningTrainingTask(_BasicTrainingTask):
data_configs: dict of data configurations
dataset: training dataset
"""
- super().__init__(task_configs, model_configs, pid_info, data_configs,
dataset)
- self.model_id = model_id
+ super().__init__(model_id, task_options, hyperparameters, dataset,
pid_info)
self.study = optuna.create_study(direction='minimize')
def __call__(self):
diff --git a/mlnode/iotdb/mlnode/process/trial.py
b/mlnode/iotdb/mlnode/process/trial.py
index 7fc34477686..be6b57fbccf 100644
--- a/mlnode/iotdb/mlnode/process/trial.py
+++ b/mlnode/iotdb/mlnode/process/trial.py
@@ -26,7 +26,9 @@ from torch.utils.data import DataLoader, Dataset
from iotdb.mlnode.algorithm.metric import all_metrics, build_metrics
from iotdb.mlnode.client import client_manager
+from iotdb.mlnode.constant import DEFAULT_TRIAL_ID
from iotdb.mlnode.log import logger
+from iotdb.mlnode.parser import TaskOptions
from iotdb.mlnode.storage import model_storage
from iotdb.thrift.common.ttypes import TrainingState
@@ -81,9 +83,8 @@ def _parse_trial_config(**kwargs):
class BasicTrial(object):
def __init__(
self,
- task_configs: Dict,
model: nn.Module,
- model_configs: Dict,
+ task_configs: Dict,
dataset: Dataset
):
self.trial_configs = task_configs
@@ -131,10 +132,9 @@ class BasicTrial(object):
class ForecastingTrainingTrial(BasicTrial):
def __init__(
self,
- task_configs: Dict,
model: nn.Module,
- model_configs: Dict,
- dataset: Dataset,
+ task_configs: Dict,
+ dataset: Dataset
):
"""
A training trial, accept all parameters needed and train a single
model.
@@ -146,8 +146,8 @@ class ForecastingTrainingTrial(BasicTrial):
dataset: training dataset
**kwargs:
"""
- super(ForecastingTrainingTrial, self).__init__(task_configs, model,
model_configs, dataset)
-
+ super(ForecastingTrainingTrial, self).__init__(model, task_configs,
dataset)
+ self.trial_id = DEFAULT_TRIAL_ID # NOTE(review): every trial starts with this same id; callers running several trials (auto-tuning) must assign distinct ids
self.dataloader = self._build_dataloader()
self.datanode_client = client_manager.borrow_data_node_client()
self.confignode_client = client_manager.borrow_config_node_client()
diff --git a/mlnode/pom.xml b/mlnode/pom.xml
index c75e9c6fb33..f94bfc5a496 100644
--- a/mlnode/pom.xml
+++ b/mlnode/pom.xml
@@ -104,13 +104,13 @@
<outputDirectory>${basedir}/iotdb/thrift/</outputDirectory>
<resources>
<resource>
-
<directory>${basedir}/../thrift-commons/target/generated-sources-python/iotdb/thrift/</directory>
+
<directory>${basedir}/../iotdb-protocol/thrift-commons/target/generated-sources-python/iotdb/thrift/</directory> <!-- thrift modules now live under iotdb-protocol/, matching the mlnode.thrift path in this commit -->
</resource>
<resource>
-
<directory>${basedir}/../thrift-confignode/target/generated-sources-python/iotdb/thrift/</directory>
+
<directory>${basedir}/../iotdb-protocol/thrift-confignode/target/generated-sources-python/iotdb/thrift/</directory>
</resource>
<resource>
-
<directory>${basedir}/../thrift-mlnode/target/generated-sources-python/iotdb/thrift/</directory>
+
<directory>${basedir}/../iotdb-protocol/thrift-mlnode/target/generated-sources-python/iotdb/thrift/</directory>
</resource>
</resources>
</configuration>
@@ -127,7 +127,7 @@
<outputDirectory>${basedir}/iotdb/thrift/datanode</outputDirectory>
<resources>
<resource>
-
<directory>${basedir}/../thrift/target/generated-sources-python/iotdb/thrift/datanode/</directory>
+
<directory>${basedir}/../iotdb-protocol/thrift/target/generated-sources-python/iotdb/thrift/datanode/</directory>
</resource>
</resources>
</configuration>