This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 0750fa6f49 [IOTDB-5728] Implement config parser & model/dataset 
factory on MLNode (#9458)
0750fa6f49 is described below

commit 0750fa6f492b052fea5c01e88a21032a12631488
Author: lichenyu <[email protected]>
AuthorDate: Fri Mar 31 15:35:26 2023 +0800

    [IOTDB-5728] Implement config parser & model/dataset factory on MLNode 
(#9458)
    
    Co-authored-by: Wenwei <[email protected]>
---
 .../{models/forecast/__init__.py => enums.py}      |  12 ++
 mlnode/iotdb/mlnode/algorithm/factory.py           | 128 ++++++++++++++
 .../mlnode/algorithm/models/forecast/__init__.py   |   3 +
 .../mlnode/algorithm/models/forecast/dlinear.py    |  41 ++++-
 .../mlnode/algorithm/models/forecast/nbeats.py     |  47 ++++-
 mlnode/iotdb/mlnode/client.py                      |   9 +-
 mlnode/iotdb/mlnode/constant.py                    |   1 +
 .../{datats/utils => data_access}/__init__.py      |   0
 .../forecast/__init__.py => data_access/enums.py}  |  12 ++
 mlnode/iotdb/mlnode/data_access/factory.py         | 105 +++++++++++
 .../forecast => data_access/offline}/__init__.py   |   0
 .../{datats => data_access}/offline/dataset.py     |  30 +---
 .../offline/source.py}                             |   9 +-
 .../forecast => data_access/utils}/__init__.py     |   0
 .../{datats => data_access}/utils/timefeatures.py  |   2 -
 mlnode/iotdb/mlnode/exception.py                   |  16 +-
 mlnode/iotdb/mlnode/handler.py                     |  27 ++-
 mlnode/iotdb/mlnode/parser.py                      | 194 +++++++++++++++++++++
 mlnode/iotdb/mlnode/serde.py                       |  30 +++-
 mlnode/iotdb/mlnode/util.py                        |   4 +-
 mlnode/test/test_parse_training_request.py         | 136 +++++++++++++++
 21 files changed, 747 insertions(+), 59 deletions(-)

diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py 
b/mlnode/iotdb/mlnode/algorithm/enums.py
similarity index 76%
copy from mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
copy to mlnode/iotdb/mlnode/algorithm/enums.py
index 2a1e720805..4b05aa4bf8 100644
--- a/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
+++ b/mlnode/iotdb/mlnode/algorithm/enums.py
@@ -15,3 +15,15 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+from enum import Enum
+
+
+class ForecastTaskType(Enum):
+    ENDOGENOUS = "endogenous"
+    EXOGENOUS = "exogenous"
+
+    def __str__(self):
+        return self.value
+
+    def __eq__(self, other: str) -> bool:
+        return self.value == other
diff --git a/mlnode/iotdb/mlnode/algorithm/factory.py 
b/mlnode/iotdb/mlnode/algorithm/factory.py
new file mode 100644
index 0000000000..92cb01a883
--- /dev/null
+++ b/mlnode/iotdb/mlnode/algorithm/factory.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import torch.nn as nn
+
+from iotdb.mlnode.algorithm.enums import ForecastTaskType
+from iotdb.mlnode.algorithm.models.forecast import support_forecasting_models
+from iotdb.mlnode.exception import BadConfigValueError
+
+
+# Common configs for all forecasting model with default values
+def _common_config(**kwargs):
+    return {
+        'input_len': 96,
+        'pred_len': 96,
+        'input_vars': 1,
+        'output_vars': 1,
+        **kwargs
+    }
+
+
+# Common forecasting task configs
+support_common_configs = {
+    # multivariate forecasting; currently only this is supported
+    ForecastTaskType.ENDOGENOUS: _common_config(
+        input_vars=1,
+        output_vars=1),
+
+    # univariate forecasting with observable exogenous variables
+    ForecastTaskType.EXOGENOUS: _common_config(
+        output_vars=1),
+}
+
+
+def is_model(model_name: str) -> bool:
+    """
+    Check if a model name exists
+    """
+    return model_name in support_forecasting_models
+
+
+def list_model() -> list[str]:
+    """
+    List supported forecasting models
+    """
+    return support_forecasting_models
+
+
+def create_forecast_model(
+        model_name,
+        forecast_task_type=ForecastTaskType.ENDOGENOUS,
+        input_len=96,
+        pred_len=96,
+        input_vars=1,
+        output_vars=1,
+        **kwargs,
+) -> [nn.Module, dict]:
+    """
+    Factory method for all supported forecasting models.
+    The given arguments are common configs shared by all forecasting models;
+    for model-specific configs, see _model_config in `models/forecast/MODELNAME.py`
+
+    Args:
+        model_name: see available models by `list_model`
+        forecast_task_type: ForecastTaskType.ENDOGENOUS (requires input_vars ==
+                   output_vars) or ForecastTaskType.EXOGENOUS
+        input_len: time length of model input
+        pred_len: time length of model output
+        input_vars: number of input series
+        output_vars: number of output series
+        kwargs: for specific model configs, see returned `model_config` with 
kwargs=None
+
+    Returns:
+        model: torch.nn.Module
+        model_config: dict of model configurations
+    """
+    if not is_model(model_name):
+        raise BadConfigValueError('model_name', model_name, f'It should be one 
of {list_model()}')
+    if forecast_task_type not in support_common_configs.keys():
+        raise BadConfigValueError('forecast_task_type', forecast_task_type,
+                                  f'It should be one of 
{list(support_common_configs.keys())}')
+
+    common_config = support_common_configs[forecast_task_type]
+    common_config['input_len'] = input_len
+    common_config['pred_len'] = pred_len
+    common_config['input_vars'] = input_vars
+    common_config['output_vars'] = output_vars
+    common_config['forecast_task_type'] = str(forecast_task_type)
+
+    if not input_len > 0:
+        raise BadConfigValueError('input_len', input_len,
+                                  'Length of input series should be positive')
+    if not pred_len > 0:
+        raise BadConfigValueError('pred_len', pred_len,
+                                  'Length of predicted series should be 
positive')
+    if not input_vars > 0:
+        raise BadConfigValueError('input_vars', input_vars,
+                                  'Number of input variates should be 
positive')
+    if not output_vars > 0:
+        raise BadConfigValueError('output_vars', output_vars,
+                                  'Number of output variates should be 
positive')
+    if forecast_task_type == ForecastTaskType.ENDOGENOUS:
+        if input_vars != output_vars:
+            raise BadConfigValueError('forecast_task_type', forecast_task_type,
+                                      'Number of input/output variates should 
be '
+                                      'the same in multivariate forecast')
+    create_fn = eval(model_name)
+    model, model_config = create_fn(
+        common_config=common_config,
+        **kwargs
+    )
+    model_config['model_name'] = model_name
+
+    return model, model_config
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py 
b/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
index 2a1e720805..2abb5faf37 100644
--- a/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
+++ b/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
@@ -15,3 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+
+
+support_forecasting_models = ['dlinear', 'dlinear_individual', 'nbeats']
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py 
b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
index 58fb12bf29..fa9ee04e56 100644
--- a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
+++ b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
@@ -15,12 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-import argparse
+
 import math
 
 import torch
 import torch.nn as nn
 
+from iotdb.mlnode.exception import BadConfigValueError
+
 
 class MovingAverageBlock(nn.Module):
     """ Moving average block to highlight the trend of time series """
@@ -61,7 +63,9 @@ class DLinear(nn.Module):
             kernel_size=25,
             input_len=96,
             pred_len=96,
-            input_vars=1
+            input_vars=1,
+            output_vars=1,
+            forecast_type='m',  # TODO, support others
     ):
         super(DLinear, self).__init__()
         self.input_len = input_len
@@ -94,7 +98,9 @@ class DLinearIndividual(nn.Module):
             kernel_size=25,
             input_len=96,
             pred_len=96,
-            input_vars=1
+            input_vars=1,
+            output_vars=1,
+            forecast_type='m',  # TODO, support others
     ):
         super(DLinearIndividual, self).__init__()
         self.input_len = input_len
@@ -128,11 +134,28 @@ class DLinearIndividual(nn.Module):
         return x.permute(0, 2, 1)  # to [Batch, Output length, Channel]
 
 
-def dlinear(model_config: argparse.Namespace) -> DLinear:
-    # TODO (@lcy)
-    pass
+def _model_config(**kwargs):
+    return {
+        'kernel_size': 25,
+        **kwargs
+    }
+
+
+def dlinear(common_config: dict, kernel_size=25, **kwargs) -> [DLinear, dict]:
+    config = _model_config()
+    config.update(**common_config)
+    if not kernel_size > 0:
+        raise BadConfigValueError('kernel_size', kernel_size,
+                                  'Kernel size of dlinear should larger than 
0')
+    config['kernel_size'] = kernel_size
+    return DLinear(**config), config
 
 
-def dlinear_individual(model_config: argparse.Namespace) -> DLinearIndividual:
-    # TODO (@lcy)
-    pass
+def dlinear_individual(common_config: dict, kernel_size=25, **kwargs) -> 
[DLinearIndividual, dict]:
+    config = _model_config()
+    config.update(**common_config)
+    if not kernel_size > 0:
+        raise BadConfigValueError('kernel_size', kernel_size,
+                                  'Kernel size of dlinear_individual should 
larger than 0')
+    config['kernel_size'] = kernel_size
+    return DLinearIndividual(**config), config
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/nbeats.py 
b/mlnode/iotdb/mlnode/algorithm/models/forecast/nbeats.py
index 0744cd4460..e3c3ca6a0a 100644
--- a/mlnode/iotdb/mlnode/algorithm/models/forecast/nbeats.py
+++ b/mlnode/iotdb/mlnode/algorithm/models/forecast/nbeats.py
@@ -16,12 +16,13 @@
 # under the License.
 #
 
-import argparse
 from typing import Tuple
 
 import torch
 import torch.nn as nn
 
+from iotdb.mlnode.exception import BadConfigValueError
+
 
 class GenericBasis(nn.Module):
     """ Generic basis function """
@@ -37,10 +38,6 @@ class GenericBasis(nn.Module):
 
 block_dict = {
     'generic': GenericBasis,
-
-    # TODO(@lcy) support more block type
-    # 'trend': TrendBasis,
-    # 'seasonality': SeasonalityBasis,
 }
 
 
@@ -109,6 +106,8 @@ class NBeats(nn.Module):
             input_len=96,
             pred_len=96,
             input_vars=1,
+            output_vars=1,
+            forecast_type='m',  # TODO, support others
     ):
         super(NBeats, self).__init__()
         self.enc_in = input_vars
@@ -133,6 +132,38 @@ class NBeats(nn.Module):
         return torch.stack(res, dim=-1)  # to [Batch, Output length, Channel]
 
 
-def nbeats(model_config: argparse.Namespace) -> NBeats:
-    # TODO (@lcy)
-    pass
+def _model_config(**kwargs):
+    return {
+        'block_type': 'generic',
+        'd_model': 128,
+        'inner_layers': 4,
+        'outer_layers': 4,
+        **kwargs
+    }
+
+
+"""
+Specific configs for NBeats variants
+"""
+support_model_configs = {
+    'nbeats': _model_config(
+        block_type='generic'),
+}
+
+
+def nbeats(common_config: dict, d_model=128, inner_layers=4, outer_layers=4, 
**kwargs) -> [NBeats, dict]:
+    config = _model_config()
+    config.update(**common_config)
+    if not d_model > 0:
+        raise BadConfigValueError('d_model', d_model,
+                                  'Model dimension (d_model) of nbeats should 
larger than 0')
+    if not inner_layers > 0:
+        raise BadConfigValueError('inner_layers', inner_layers,
+                                  'Number of inner layers of nbeats should 
larger than 0')
+    if not outer_layers > 0:
+        raise BadConfigValueError('outer_layers', outer_layers,
+                                  'Number of outer layers of nbeats should 
larger than 0')
+    config['d_model'] = d_model
+    config['inner_layers'] = inner_layers
+    config['outer_layers'] = outer_layers
+    return NBeats(**config), config
diff --git a/mlnode/iotdb/mlnode/client.py b/mlnode/iotdb/mlnode/client.py
index aa1536e130..76eb754596 100644
--- a/mlnode/iotdb/mlnode/client.py
+++ b/mlnode/iotdb/mlnode/client.py
@@ -70,7 +70,7 @@ class MLNodeClient(object):
                              model_id: str,
                              is_auto: bool,
                              model_configs: dict,
-                             query_expressions: list[str],
+                             query_expressions: list = [],
                              query_filter: str = None) -> None:
         req = TCreateTrainingTaskReq(
             modelId=model_id,
@@ -116,6 +116,7 @@ class DataNodeClient(object):
                 transport.open()
             except TTransport.TTransportException as e:
                 logger.exception("TTransportException!", exc_info=e)
+                raise e
 
         protocol = TBinaryProtocol.TBinaryProtocol(transport)
         self.__client = IDataNodeRPCService.Client(protocol)
@@ -123,7 +124,7 @@ class DataNodeClient(object):
     def fetch_timeseries(self,
                          session_id: int,
                          statement_id: int,
-                         query_expressions: list[str],
+                         query_expressions: list = [],
                          query_filter: str = None,
                          fetch_size: int = DEFAULT_FETCH_SIZE,
                          timeout: int = DEFAULT_TIMEOUT) -> 
TFetchTimeseriesResp:
@@ -145,8 +146,8 @@ class DataNodeClient(object):
     def record_model_metrics(self,
                              model_id: str,
                              trial_id: str,
-                             metrics: list[str],
-                             values: list[float]) -> None:
+                             metrics: list = [],
+                             values: list = []) -> None:
         req = TRecordModelMetricsReq(
             modelId=model_id,
             trialId=trial_id,
diff --git a/mlnode/iotdb/mlnode/constant.py b/mlnode/iotdb/mlnode/constant.py
index 810d7c261e..3bffa06526 100644
--- a/mlnode/iotdb/mlnode/constant.py
+++ b/mlnode/iotdb/mlnode/constant.py
@@ -27,6 +27,7 @@ MLNODE_MODEL_STORAGE_DIRECTORY_NAME = "models"
 class TSStatusCode(Enum):
     SUCCESS_STATUS = 200
     REDIRECTION_RECOMMEND = 400
+    FAIL_STATUS = 404
 
     def get_status_code(self) -> int:
         return self.value
diff --git a/mlnode/iotdb/mlnode/datats/utils/__init__.py 
b/mlnode/iotdb/mlnode/data_access/__init__.py
similarity index 100%
rename from mlnode/iotdb/mlnode/datats/utils/__init__.py
rename to mlnode/iotdb/mlnode/data_access/__init__.py
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py 
b/mlnode/iotdb/mlnode/data_access/enums.py
similarity index 77%
copy from mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
copy to mlnode/iotdb/mlnode/data_access/enums.py
index 2a1e720805..d21a9f69c4 100644
--- a/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
+++ b/mlnode/iotdb/mlnode/data_access/enums.py
@@ -15,3 +15,15 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+from enum import Enum
+
+
+class DatasetType(Enum):
+    TIMESERIES = "timeseries"
+    WINDOW = "window"
+
+    def __str__(self):
+        return self.value
+
+    def __eq__(self, other: str) -> bool:
+        return self.value == other
diff --git a/mlnode/iotdb/mlnode/data_access/factory.py 
b/mlnode/iotdb/mlnode/data_access/factory.py
new file mode 100644
index 0000000000..d0041388a6
--- /dev/null
+++ b/mlnode/iotdb/mlnode/data_access/factory.py
@@ -0,0 +1,105 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from torch.utils.data import Dataset
+
+from iotdb.mlnode.data_access.enums import DatasetType
+from iotdb.mlnode.data_access.offline.dataset import (TimeSeriesDataset,
+                                                      WindowDataset)
+from iotdb.mlnode.data_access.offline.source import (FileDataSource,
+                                                     ThriftDataSource)
+from iotdb.mlnode.exception import BadConfigValueError, MissingConfigError
+
+support_forecasting_dataset = {
+    DatasetType.TIMESERIES: TimeSeriesDataset,
+    DatasetType.WINDOW: WindowDataset
+}
+
+
+def _dataset_config(**kwargs):
+    return {
+        'time_embed': 'h',
+        **kwargs
+    }
+
+
+support_dataset_configs = {
+    DatasetType.TIMESERIES: _dataset_config(),
+    DatasetType.WINDOW: _dataset_config(
+        input_len=96,
+        pred_len=96,
+    )
+}
+
+
+def create_forecast_dataset(
+        source_type,
+        dataset_type,
+        **kwargs,
+) -> [Dataset, dict]:
+    """
+    Factory method for all supported datasets
+    currently implements WindowDataset, TimeSeriesDataset
+    for specific dataset configs, see _dataset_config in this module
+
+    Args:
+        dataset_type: available choice in support_forecasting_dataset
+        source_type:  available choice in ['file', 'thrift']
+        kwargs: for specific dataset configs, see returned `dataset_config` 
with kwargs=None
+
+    Returns:
+        dataset: torch.utils.data.Dataset
+        dataset_config: dict of dataset configurations
+    """
+    if dataset_type not in support_forecasting_dataset.keys():
+        raise BadConfigValueError('dataset_type', dataset_type,
+                                  f'It should be one of 
{list(support_forecasting_dataset.keys())}')
+
+    if source_type == 'file':
+        if 'filename' not in kwargs.keys():
+            raise MissingConfigError('filename')
+        datasource = FileDataSource(kwargs['filename'])
+    elif source_type == 'thrift':
+        if 'query_expressions' not in kwargs.keys():
+            raise MissingConfigError('query_expressions')
+        if 'query_filter' not in kwargs.keys():
+            raise MissingConfigError('query_filter')
+        datasource = ThriftDataSource(kwargs['query_expressions'], 
kwargs['query_filter'])
+    else:
+        raise BadConfigValueError('source_type', source_type, "It should be 
one of ['file', 'thrift]")
+
+    dataset_fn = support_forecasting_dataset[dataset_type]
+    dataset_config = support_dataset_configs[dataset_type]
+
+    for k, v in kwargs.items():
+        if k in dataset_config.keys():
+            dataset_config[k] = v
+
+    dataset = dataset_fn(datasource, **dataset_config)
+
+    if 'input_vars' in kwargs.keys() and dataset.get_variable_num() != 
kwargs['input_vars']:
+        raise BadConfigValueError('input_vars', kwargs['input_vars'],
+                                  f'Variable number of fetched data: 
({dataset.get_variable_num()})'
+                                  f' should be consistent with input_vars')
+
+    data_config = dataset_config.copy()
+    data_config['input_vars'] = dataset.get_variable_num()
+    data_config['output_vars'] = dataset.get_variable_num()
+    data_config['source_type'] = source_type
+    data_config['dataset_type'] = dataset_type
+
+    return dataset, data_config
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py 
b/mlnode/iotdb/mlnode/data_access/offline/__init__.py
similarity index 100%
copy from mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
copy to mlnode/iotdb/mlnode/data_access/offline/__init__.py
diff --git a/mlnode/iotdb/mlnode/datats/offline/dataset.py 
b/mlnode/iotdb/mlnode/data_access/offline/dataset.py
similarity index 85%
rename from mlnode/iotdb/mlnode/datats/offline/dataset.py
rename to mlnode/iotdb/mlnode/data_access/offline/dataset.py
index c71aaf87c5..1a96e81a4a 100644
--- a/mlnode/iotdb/mlnode/datats/offline/dataset.py
+++ b/mlnode/iotdb/mlnode/data_access/offline/dataset.py
@@ -15,16 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-
-
-import argparse
-
 from torch.utils.data import Dataset
 
-from iotdb.mlnode.datats.offline.data_source import DataSource
-from iotdb.mlnode.datats.utils.timefeatures import time_features
-
-# currently support for multivariate forecasting only
+from iotdb.mlnode.data_access.offline.source import DataSource
+from iotdb.mlnode.data_access.utils.timefeatures import time_features
 
 
 class TimeSeriesDataset(Dataset):
@@ -81,11 +75,11 @@ class WindowDataset(TimeSeriesDataset):
                  time_embed: str = 'h'):
         self.input_len = input_len
         self.pred_len = pred_len
-        if input_len <= self.data.shape[0]:
+        super(WindowDataset, self).__init__(data_source, time_embed)
+        if input_len > self.data.shape[0]:
             raise RuntimeError('input_len should not be larger than the number 
of time series points')
-        if pred_len <= self.data.shape[0]:
+        if pred_len > self.data.shape[0]:
             raise RuntimeError('pred_len should not be larger than the number 
of time series points')
-        super(WindowDataset, self).__init__(data_source, time_embed)
 
     def __getitem__(self, index):
         s_begin = index
@@ -100,17 +94,3 @@ class WindowDataset(TimeSeriesDataset):
 
     def __len__(self):
         return len(self.data) - self.input_len - self.pred_len + 1
-
-
-def get_timeseries_dataset(data_config: argparse.Namespace) -> 
TimeSeriesDataset:
-    # TODO (@lcy)
-    # init datasource
-    # init dataset
-    pass
-
-
-def get_window_dataset(data_config: argparse.Namespace) -> WindowDataset:
-    # TODO (@lcy)
-    # init datasource
-    # init dataset
-    pass
diff --git a/mlnode/iotdb/mlnode/datats/offline/data_source.py 
b/mlnode/iotdb/mlnode/data_access/offline/source.py
similarity index 96%
rename from mlnode/iotdb/mlnode/datats/offline/data_source.py
rename to mlnode/iotdb/mlnode/data_access/offline/source.py
index cd8e9a891c..a63371ec7a 100644
--- a/mlnode/iotdb/mlnode/datats/offline/data_source.py
+++ b/mlnode/iotdb/mlnode/data_access/offline/source.py
@@ -33,6 +33,7 @@ class DataSource(object):
     def __init__(self):
         self.data = None
         self.timestamp = None
+        self._read_data()
 
     def _read_data(self):
         raise NotImplementedError
@@ -46,9 +47,8 @@ class DataSource(object):
 
 class FileDataSource(DataSource):
     def __init__(self, filename: str = None):
-        super(FileDataSource, self).__init__()
         self.filename = filename
-        self._read_data()
+        super(FileDataSource, self).__init__()
 
     def _read_data(self):
         try:
@@ -62,15 +62,14 @@ class FileDataSource(DataSource):
 
 class ThriftDataSource(DataSource):
     def __init__(self, query_expressions: list = None, query_filter: str = 
None):
-        super(DataSource, self).__init__()
         self.query_expressions = query_expressions
         self.query_filter = query_filter
-        self._read_data()
+        super(ThriftDataSource, self).__init__()
 
     def _read_data(self):
         try:
             data_client = client_manager.borrow_data_node_client()
-        except Exception:  # is this exception catch needed???
+        except Exception:
             raise RuntimeError('Fail to establish connection with DataNode')
 
         try:
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py 
b/mlnode/iotdb/mlnode/data_access/utils/__init__.py
similarity index 100%
copy from mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
copy to mlnode/iotdb/mlnode/data_access/utils/__init__.py
diff --git a/mlnode/iotdb/mlnode/datats/utils/timefeatures.py 
b/mlnode/iotdb/mlnode/data_access/utils/timefeatures.py
similarity index 99%
rename from mlnode/iotdb/mlnode/datats/utils/timefeatures.py
rename to mlnode/iotdb/mlnode/data_access/utils/timefeatures.py
index bd1681cfbf..ecd6784ca4 100644
--- a/mlnode/iotdb/mlnode/datats/utils/timefeatures.py
+++ b/mlnode/iotdb/mlnode/data_access/utils/timefeatures.py
@@ -15,8 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-
-
 from typing import List
 
 import numpy as np
diff --git a/mlnode/iotdb/mlnode/exception.py b/mlnode/iotdb/mlnode/exception.py
index 3907a67d58..a7b211dbc2 100644
--- a/mlnode/iotdb/mlnode/exception.py
+++ b/mlnode/iotdb/mlnode/exception.py
@@ -16,7 +16,6 @@
 # under the License.
 #
 
-
 class _BaseError(Exception):
     """Base class for exceptions in this module."""
     pass
@@ -30,3 +29,18 @@ class BadNodeUrlError(_BaseError):
 class ModelNotExistError(_BaseError):
     def __init__(self, file_path: str):
         self.message = "Model path: ({}) not exists".format(file_path)
+
+
+class BadConfigValueError(_BaseError):
+    def __init__(self, config_name: str, config_value, hint: str = ''):
+        self.message = "Bad value ({0}) for config: ({1}). 
{2}".format(config_value, config_name, hint)
+
+
+class MissingConfigError(_BaseError):
+    def __init__(self, config_name: str):
+        self.message = "Missing config: ({})".format(config_name)
+
+
+class WrongTypeConfigError(_BaseError):
+    def __init__(self, config_name: str, expected_type: str):
+        self.message = "Wrong type for config: ({0}), expected: 
({1})".format(config_name, expected_type)
diff --git a/mlnode/iotdb/mlnode/handler.py b/mlnode/iotdb/mlnode/handler.py
index d1f21ff517..e7ff76cbe0 100644
--- a/mlnode/iotdb/mlnode/handler.py
+++ b/mlnode/iotdb/mlnode/handler.py
@@ -16,7 +16,11 @@
 # under the License.
 #
 
+from iotdb.mlnode.algorithm.factory import create_forecast_model
 from iotdb.mlnode.constant import TSStatusCode
+from iotdb.mlnode.data_access.factory import create_forecast_dataset
+from iotdb.mlnode.log import logger
+from iotdb.mlnode.parser import parse_training_request
 from iotdb.mlnode.util import get_status
 from iotdb.thrift.mlnode import IMLNodeRPCService
 from iotdb.thrift.mlnode.ttypes import (TCreateTrainingTaskReq,
@@ -32,7 +36,28 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
         return get_status(TSStatusCode.SUCCESS_STATUS, "")
 
     def createTrainingTask(self, req: TCreateTrainingTaskReq):
-        return get_status(TSStatusCode.SUCCESS_STATUS, "")
+        # parse request stage (check required config and config type)
+        data_config, model_config, task_config = parse_training_request(req)
+
+        # create model stage (check model config legitimacy)
+        try:
+            model, model_config = create_forecast_model(**model_config)
+        except Exception as e:  # Create model failed
+            return get_status(TSStatusCode.FAIL_STATUS, str(e))
+        logger.info('model config: ' + str(model_config))
+
+        # create data stage (check data config legitimacy)
+        try:
+            dataset, data_config = create_forecast_dataset(**data_config)
+        except Exception as e:  # Create data failed
+            return get_status(TSStatusCode.FAIL_STATUS, str(e))
+        logger.info('data config: ' + str(data_config))
+
+        # create task stage (check task config legitimacy)
+
+        # submit task stage (check resource and decide pending/start)
+
+        return get_status(TSStatusCode.SUCCESS_STATUS, 'Successfully create 
training task')
 
     def forecast(self, req: TForecastReq):
         status = get_status(TSStatusCode.SUCCESS_STATUS, "")
diff --git a/mlnode/iotdb/mlnode/parser.py b/mlnode/iotdb/mlnode/parser.py
new file mode 100644
index 0000000000..236032b9a0
--- /dev/null
+++ b/mlnode/iotdb/mlnode/parser.py
@@ -0,0 +1,194 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+import argparse
+import ast
+import re
+
+from iotdb.mlnode.algorithm.enums import ForecastTaskType
+from iotdb.mlnode.data_access.enums import DatasetType
+from iotdb.mlnode.exception import MissingConfigError, WrongTypeConfigError
+from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq
+
+
+class _ConfigParser(argparse.ArgumentParser):
+    """
+    An argparse-based parser that reads its arguments from a dict of configs
+    instead of the command line, and raises exceptions instead of exiting
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def parse_configs(self, configs):
+        """
+        Parse configs from a dict
+        Args:
+            configs: a dict of all configs which contains all required arguments
+        Returns:
+            a dict of parsed configs; keys that are not registered on this
+            parser are dropped
+        Raises:
+            MissingConfigError: a required argument is absent
+            WrongTypeConfigError: a value cannot be converted to its declared type
+            Exception: any other argparse error
+        """
+        args = self.parse_dict(configs)
+        # parse_known_args returns (namespace, unknown_args); only the
+        # namespace is of interest here
+        namespace, _ = self.parse_known_args(args)
+        return vars(namespace)
+
+    @staticmethod
+    def parse_dict(config_dict):
+        """
+        Parse a dict of configs to a list of arguments
+        Args:
+            config_dict: a dict of configs; a value may be a list, or a string
+                holding a list literal such as "['MSE', 'MAE']"
+        Returns:
+            a list of arguments which can be parsed by argparse
+        Raises:
+            ValueError, SyntaxError: a bracketed string is not a valid literal
+        """
+        args = []
+        for k, v in config_dict.items():
+            args.append("--{}".format(k))
+            if isinstance(v, str) and re.match(r'^\[(.*)]$', v):
+                # The string comes from a request, so it must never be executed:
+                # literal_eval accepts Python literals only and rejects anything
+                # else, whereas eval would run arbitrary expressions.
+                args.extend(str(i) for i in ast.literal_eval(v))
+            elif isinstance(v, list):
+                args.extend(str(i) for i in v)
+            else:
+                args.append(v)
+        return args
+
+    def error(self, message: str):
+        """
+        Override the error method to raise exceptions instead of exiting
+        Args:
+            message: the error text produced by argparse
+        Raises:
+            MissingConfigError: for the first argument reported as required
+            WrongTypeConfigError: for an "invalid <type> value" message
+            Exception: for every other message
+        """
+        if message.startswith('the following arguments are required:'):
+            missing_arg = re.findall(r': --(\w+)', message)[0]
+            raise MissingConfigError(missing_arg)
+        wrong_type = re.match(r'argument --(\w+): invalid (\w+) value:', message)
+        if wrong_type:
+            argument, expected_type = wrong_type.groups()
+            raise WrongTypeConfigError(argument, expected_type)
+        raise Exception(message)
+
+
+""" Argument description:
+ - query_expressions: query expressions
+ - query_filter: query filter
+ - source_type: source type
+ - filename: filename
+ - dataset_type: dataset type
+ - time_embed: freq for time features encoding
+ - input_len: input sequence length
+ - pred_len: prediction sequence length
+ - input_vars: number of input variables
+ - output_vars: number of output variables
+"""
+_data_config_parser = _ConfigParser()
+_data_config_parser.add_argument('--source_type', type=str, required=True)
+_data_config_parser.add_argument('--dataset_type', type=DatasetType, 
required=True)
+_data_config_parser.add_argument('--filename', type=str, default='')
+_data_config_parser.add_argument('--query_expressions', type=str, nargs='*', 
default=[])
+_data_config_parser.add_argument('--query_filter', type=str, default='')
+_data_config_parser.add_argument('--time_embed', type=str, default='h')
+_data_config_parser.add_argument('--input_len', type=int, default=96)
+_data_config_parser.add_argument('--pred_len', type=int, default=96)
+_data_config_parser.add_argument('--input_vars', type=int, default=1)
+_data_config_parser.add_argument('--output_vars', type=int, default=1)
+
+""" Argument description:
+ - model_name: model name
+ - input_len: input sequence length
+ - pred_len: prediction sequence length
+ - input_vars: number of input variables
+ - output_vars: number of output variables
+ - task_type: task type, options:[M, S, MS];
+        M:multivariate predict multivariate,
+        S:univariate predict univariate,
+        MS:multivariate predict univariate'
+ - kernel_size: kernel size
+ - block_type: block type
+ - d_model: dimension of feature in model
+ - inner_layers: number of inner layers
+ - outer_layers: number of outer layers
+"""
+_model_config_parser = _ConfigParser()
+_model_config_parser.add_argument('--model_name', type=str, required=True)
+_model_config_parser.add_argument('--input_len', type=int, default=96)
+_model_config_parser.add_argument('--pred_len', type=int, default=96)
+_model_config_parser.add_argument('--input_vars', type=int, default=1)
+_model_config_parser.add_argument('--output_vars', type=int, default=1)
+_model_config_parser.add_argument('--forecast_task_type', 
type=ForecastTaskType, default=ForecastTaskType.ENDOGENOUS,
+                                  choices=list(ForecastTaskType))
+_model_config_parser.add_argument('--kernel_size', type=int, default=25)
+_model_config_parser.add_argument('--block_type', type=str, default='generic')
+_model_config_parser.add_argument('--d_model', type=int, default=128)
+_model_config_parser.add_argument('--inner_layers', type=int, default=4)
+_model_config_parser.add_argument('--outer_layers', type=int, default=4)
+
+""" Argument description:
+ - model_id: model id
+ - tuning: whether to tune hyperparameters
+ - task_type: task type, options:[M, S, MS]; M:multivariate predict 
multivariate, S:univariate predict univariate,
+        MS:multivariate predict univariate'
+ - task_class: task class
+ - input_len: input sequence length
+ - pred_len: prediction sequence length
+ - input_vars: number of input variables
+ - output_vars: number of output variables
+ - learning_rate: learning rate
+ - batch_size: batch size
+ - num_workers: number of workers
+ - epochs: number of epochs
+ - use_gpu: whether to use gpu
+ - use_multi_gpu: whether to use multi-gpu
+ - devices: devices to use
+ - metric_names: metric to use
+"""
+_task_config_parser = _ConfigParser()
+_task_config_parser.add_argument('--task_class', type=str, required=True)
+_task_config_parser.add_argument('--model_id', type=str, required=True)
+_task_config_parser.add_argument('--tuning', type=bool, default=False)
+_task_config_parser.add_argument('--forecast_task_type', 
type=ForecastTaskType, default=ForecastTaskType.ENDOGENOUS,
+                                 choices=list(ForecastTaskType))
+_task_config_parser.add_argument('--input_len', type=int, default=96)
+_task_config_parser.add_argument('--pred_len', type=int, default=96)
+_task_config_parser.add_argument('--input_vars', type=int, default=1)
+_task_config_parser.add_argument('--output_vars', type=int, default=1)
+_task_config_parser.add_argument('--learning_rate', type=float, default=0.0001)
+_task_config_parser.add_argument('--batch_size', type=int, default=32)
+_task_config_parser.add_argument('--num_workers', type=int, default=0)
+_task_config_parser.add_argument('--epochs', type=int, default=10)
+_task_config_parser.add_argument('--use_gpu', type=bool, default=False)
+_task_config_parser.add_argument('--gpu', type=int, default=0)
+_task_config_parser.add_argument('--use_multi_gpu', type=bool, default=False)
+_task_config_parser.add_argument('--devices', type=int, nargs='+', default=[0])
+_task_config_parser.add_argument('--metric_names', type=str, nargs='+', 
default=['MSE', 'MAE'])
+
+
+def parse_training_request(req: TCreateTrainingTaskReq):
+    """
+    Parse TCreateTrainingTaskReq with given yaml template
+    Args:
+        req: TCreateTrainingTaskReq
+    Returns:
+        data_config: configurations related to data
+        model_config: configurations related to model
+        task_config: configurations related to task
+    """
+    config = req.modelConfigs
+    config.update(model_id=req.modelId)
+    config.update(tuning=req.isAuto)
+    config.update(query_expressions=req.queryExpressions)
+    config.update(query_filter=req.queryFilter)
+
+    data_config = _data_config_parser.parse_configs(config)
+    model_config = _model_config_parser.parse_configs(config)
+    task_config = _task_config_parser.parse_configs(config)
+    return data_config, model_config, task_config
diff --git a/mlnode/iotdb/mlnode/serde.py b/mlnode/iotdb/mlnode/serde.py
index 26860faf38..5e98636e2e 100644
--- a/mlnode/iotdb/mlnode/serde.py
+++ b/mlnode/iotdb/mlnode/serde.py
@@ -15,10 +15,38 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+from enum import Enum
+
 import numpy as np
 import pandas as pd
 
-from iotdb.utils.IoTDBConstants import TSDataType
+
+class TSDataType(Enum):
+    """Data types of a time series together with their numpy dtypes."""
+    BOOLEAN = 0
+    INT32 = 1
+    INT64 = 2
+    FLOAT = 3
+    DOUBLE = 4
+    TEXT = 5
+
+    # this method is implemented to avoid the issue reported by:
+    # https://bugs.python.org/issue30545
+    def __eq__(self, other) -> bool:
+        try:
+            return self.value == other.value
+        except AttributeError:
+            # other has no value (e.g. None or a plain int): let Python fall
+            # back to its default comparison instead of raising
+            return NotImplemented
+
+    # defined together with __eq__ so members stay usable as dict keys
+    def __hash__(self):
+        return self.value
+
+    def np_dtype(self):
+        """Return the numpy dtype used for this data type ('>' = big-endian)."""
+        return {
+            TSDataType.BOOLEAN: np.dtype(">?"),
+            TSDataType.FLOAT: np.dtype(">f4"),
+            TSDataType.DOUBLE: np.dtype(">f8"),
+            TSDataType.INT32: np.dtype(">i4"),
+            TSDataType.INT64: np.dtype(">i8"),
+            TSDataType.TEXT: np.dtype("str"),
+        }[self]
+
 
 TIMESTAMP_STR = "Time"
 START_INDEX = 2
diff --git a/mlnode/iotdb/mlnode/util.py b/mlnode/iotdb/mlnode/util.py
index c15e84da11..d67ba1290d 100644
--- a/mlnode/iotdb/mlnode/util.py
+++ b/mlnode/iotdb/mlnode/util.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+
 from iotdb.mlnode.constant import TSStatusCode
 from iotdb.mlnode.exception import BadNodeUrlError
 from iotdb.mlnode.log import logger
@@ -23,13 +24,10 @@ from iotdb.thrift.common.ttypes import TEndPoint, TSStatus
 
 def parse_endpoint_url(endpoint_url: str) -> TEndPoint:
     """ Parse TEndPoint from a given endpoint url.
-
     Args:
         endpoint_url: an endpoint url, format: ip:port
-
     Returns:
         TEndPoint
-
     Raises:
         BadNodeUrlError
     """
diff --git a/mlnode/test/test_parse_training_request.py 
b/mlnode/test/test_parse_training_request.py
new file mode 100644
index 0000000000..ec318ae60d
--- /dev/null
+++ b/mlnode/test/test_parse_training_request.py
@@ -0,0 +1,136 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from iotdb.mlnode.parser import parse_training_request
+from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq
+
+
+def test_parse_training_request():
+    """
+    Parse a complete request and compare every config returned by a parser
+    with the value that was put into the request.
+    """
+    expected = {
+        'task_class': 'forecast_training_task',
+        'source_type': 'thrift',
+        'dataset_type': 'window',
+        'filename': 'ETTh1.csv',
+        'time_embed': 'h',
+        'input_len': 96,
+        'pred_len': 96,
+        'model_name': 'dlinear',
+        'input_vars': 7,
+        'output_vars': 7,
+        'forecast_type': 'm',
+        'kernel_size': 25,
+        'learning_rate': 1e-3,
+        'batch_size': 32,
+        'num_workers': 0,
+        'epochs': 10,
+        'metric_names': ['MSE', 'MAE']
+    }
+    # the request transports every config as a string
+    req = TCreateTrainingTaskReq(
+        modelId='mid_etth1_dlinear_default',
+        isAuto=False,
+        modelConfigs={name: str(value) for name, value in expected.items()},
+        queryExpressions=['root.eg.etth1.**'] * 3,
+        queryFilter='0,1501516800000',
+    )
+    data_config, model_config, task_config = parse_training_request(req)
+    for name, value in expected.items():
+        # a config is only checked against the parsers that return it
+        for parsed in (data_config, model_config, task_config):
+            if name in parsed:
+                assert parsed[name] == value
+
+
+def test_missing_argument():
+    """Parsing a request that lacks the required model_name must fail."""
+    # missing model_name
+    model_id = 'mid_etth1_dlinear_default'
+    is_auto = False
+    model_configs = {
+        'task_class': 'forecast_training_task',
+        'source_type': 'thrift',
+        'dataset_type': 'window',
+        'filename': 'ETTh1.csv',
+        'time_embed': 'h',
+        'input_len': 96,
+        'pred_len': 96,
+        'input_vars': 7,
+        'output_vars': 7,
+        'forecast_type': 'm',
+        'kernel_size': 25,
+        'learning_rate': 1e-3,
+        'batch_size': 32,
+        'num_workers': 0,
+        'epochs': 10,
+        'metric_names': ['MSE', 'MAE']
+    }
+    query_expressions = ['root.eg.etth1.**', 'root.eg.etth1.**', 'root.eg.etth1.**']
+    query_filter = '0,1501516800000'
+    req = TCreateTrainingTaskReq(
+        modelId=str(model_id),
+        isAuto=is_auto,
+        modelConfigs={k: str(v) for k, v in model_configs.items()},
+        queryExpressions=[str(query) for query in query_expressions],
+        queryFilter=str(query_filter),
+    )
+    try:
+        parse_training_request(req)
+    except Exception as e:
+        assert e.message == 'Missing config: (model_name)'
+    else:
+        # without this branch the test would pass when nothing is raised
+        raise AssertionError('parsing a request without model_name must fail')
+
+
+def test_wrong_argument_type():
+    """Parsing a request whose input_len is not an int must fail."""
+    model_id = 'mid_etth1_dlinear_default'
+    is_auto = False
+    model_configs = {
+        'task_class': 'forecast_training_task',
+        'source_type': 'thrift',
+        'dataset_type': 'window',
+        'filename': 'ETTh1.csv',
+        'time_embed': 'h',
+        'input_len': 96.7,  # deliberately not an int
+        'pred_len': 96,
+        'model_name': 'dlinear',
+        'input_vars': 7,
+        'output_vars': 7,
+        'forecast_type': 'm',
+        'kernel_size': 25,
+        'learning_rate': 1e-3,
+        'batch_size': 32,
+        'num_workers': 0,
+        'epochs': 10,
+        'metric_names': ['MSE', 'MAE']
+    }
+    query_expressions = ['root.eg.etth1.**', 'root.eg.etth1.**', 'root.eg.etth1.**']
+    query_filter = '0,1501516800000'
+    req = TCreateTrainingTaskReq(
+        modelId=str(model_id),
+        isAuto=is_auto,
+        modelConfigs={k: str(v) for k, v in model_configs.items()},
+        queryExpressions=[str(query) for query in query_expressions],
+        queryFilter=str(query_filter),
+    )
+    try:
+        parse_training_request(req)
+    except Exception as e:
+        message = "Wrong type for config: ({})".format('input_len')
+        message += ", expected: ({})".format('int')
+        assert e.message == message
+    else:
+        # without this branch the test would pass when nothing is raised
+        raise AssertionError('parsing a request with a float input_len must fail')


Reply via email to