This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch mlnode/test
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/mlnode/test by this push:
     new 3fa855d8da refactor factory
3fa855d8da is described below

commit 3fa855d8daaccd29ee601594c4d6b0e64480ad21
Author: Minghui Liu <[email protected]>
AuthorDate: Mon Apr 3 22:32:49 2023 +0800

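    refactor factory

    Replace the eval()/dict-based dispatch in the forecast model and dataset
    factories with explicit ForecastModelType / DatasetType branches, drop
    process/task_factory.py in favour of constructing ForecastingTrainingTask
    directly, build metric objects once per trial instead of per batch, and
    fix the trailId -> trialId typo in TDeleteModelReq.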
---
 mlnode/iotdb/mlnode/algorithm/enums.py             | 22 +++++---
 mlnode/iotdb/mlnode/algorithm/factory.py           | 62 +++++++++++-----------
 mlnode/iotdb/mlnode/algorithm/metric.py            |  7 +++
 .../mlnode/algorithm/models/forecast/__init__.py   |  5 --
 mlnode/iotdb/mlnode/client.py                      |  4 +-
 mlnode/iotdb/mlnode/data_access/enums.py           |  2 +-
 mlnode/iotdb/mlnode/data_access/factory.py         | 37 +++++--------
 mlnode/iotdb/mlnode/process/manager.py             |  4 +-
 mlnode/iotdb/mlnode/process/task.py                |  3 +-
 mlnode/iotdb/mlnode/process/task_factory.py        | 40 --------------
 mlnode/iotdb/mlnode/process/trial.py               | 25 ++++-----
 mlnode/iotdb/mlnode/storage.py                     |  4 +-
 mlnode/test/test_create_forecast_dataset.py        |  5 +-
 mlnode/test/test_create_forecast_model.py          |  5 +-
 mlnode/test/test_model_storage.py                  |  2 +-
 mlnode/test/test_parse_training_request.py         |  4 +-
 thrift-mlnode/src/main/thrift/mlnode.thrift        |  2 +-
 17 files changed, 100 insertions(+), 133 deletions(-)

diff --git a/mlnode/iotdb/mlnode/algorithm/enums.py 
b/mlnode/iotdb/mlnode/algorithm/enums.py
index 0f93cf056b..8d9c50640b 100644
--- a/mlnode/iotdb/mlnode/algorithm/enums.py
+++ b/mlnode/iotdb/mlnode/algorithm/enums.py
@@ -21,23 +21,33 @@ from enum import Enum
 class ForecastTaskType(Enum):
     """
     In multivariable time series forecasting tasks, the columns to be 
predicted are called the endogenous variables
-    and the columns that are independent or non-forecast are called the 
exogenous variables (they might be helpful 
+    and the columns that are independent or non-forecast are called the 
exogenous variables (they might be helpful
     to predict the endogenous variables). Both of them can appear in the input 
time series.
-    
+
     ForecastTaskType.ENDOGENOUS: all input time series of input are endogenous 
variables
     ForecastTaskType.EXOGENOUS: the input time series is combined with 
endogenous and exogenous variables
     """
     ENDOGENOUS = "endogenous"
     EXOGENOUS = "exogenous"
 
-    def __str__(self):
+    def __str__(self) -> str:
         return self.value
 
-    def __hash__(self):
-        return hash(self.value)
-
     def __eq__(self, other: str) -> bool:
         return self.value == other
 
     def __hash__(self) -> int:
         return hash(self.value)
+
+
+class ForecastModelType(Enum):
+    DLINEAR = "dlinear"
+    DLINEAR_INDIVIDUAL = "dlinear_individual"
+    NBEATS = "nbeats"
+
+    @classmethod
+    def values(cls) -> list[str]:
+        values = []
+        for item in list(cls):
+            values.append(item.value)
+        return values
diff --git a/mlnode/iotdb/mlnode/algorithm/factory.py 
b/mlnode/iotdb/mlnode/algorithm/factory.py
index 37f81c1e68..0488a6dfae 100644
--- a/mlnode/iotdb/mlnode/algorithm/factory.py
+++ b/mlnode/iotdb/mlnode/algorithm/factory.py
@@ -16,10 +16,11 @@
 # under the License.
 #
 import torch.nn as nn
-from iotdb.mlnode.algorithm.models.forecast import *
-from iotdb.mlnode.algorithm.enums import ForecastTaskType
-from iotdb.mlnode.algorithm.models.forecast import support_forecasting_models
-from iotdb.mlnode.algorithm.models.forecast.dlinear import dlinear
+
+from iotdb.mlnode.algorithm.enums import ForecastModelType, ForecastTaskType
+from iotdb.mlnode.algorithm.models.forecast.dlinear import (dlinear,
+                                                            dlinear_individual)
+from iotdb.mlnode.algorithm.models.forecast.nbeats import nbeats
 from iotdb.mlnode.exception import BadConfigValueError
 
 
@@ -35,7 +36,7 @@ def _common_config(**kwargs):
 
 
 # Common forecasting task configs
-support_common_configs = {
+_forecasting_model_default_config_dict = {
     # multivariable forecasting with all endogenous variables, current support 
this only
     ForecastTaskType.ENDOGENOUS: _common_config(
         input_vars=1,
@@ -47,20 +48,6 @@ support_common_configs = {
 }
 
 
-def is_model(model_name: str) -> bool:
-    """
-    Check if a model name exists
-    """
-    return model_name in support_forecasting_models
-
-
-def list_model() -> list:
-    """
-    List support forecasting model
-    """
-    return support_forecasting_models
-
-
 def create_forecast_model(
         model_name,
         input_len=96,
@@ -88,13 +75,13 @@ def create_forecast_model(
         model: torch.nn.Module
         model_config: dict of model configurations
     """
-    if not is_model(model_name):
-        raise BadConfigValueError('model_name', model_name, f'It should be one 
of {list_model()}')
-    if forecast_task_type not in support_common_configs.keys():
+    if model_name not in ForecastModelType.values():
+        raise BadConfigValueError('model_name', model_name, f'It should be one 
of {ForecastModelType.values()}')
+    if forecast_task_type not in _forecasting_model_default_config_dict.keys():
         raise BadConfigValueError('forecast_task_type', forecast_task_type,
-                                  f'It should be one of 
{list(support_common_configs.keys())}')
+                                  f'It should be one of 
{list(_forecasting_model_default_config_dict.keys())}')
 
-    common_config = support_common_configs[forecast_task_type]
+    common_config = _forecasting_model_default_config_dict[forecast_task_type]
     common_config['input_len'] = input_len
     common_config['pred_len'] = pred_len
     common_config['input_vars'] = input_vars
@@ -113,16 +100,29 @@ def create_forecast_model(
     if not output_vars > 0:
         raise BadConfigValueError('output_vars', output_vars,
                                   'Number of output variables should be 
positive')
-    if forecast_task_type == ForecastTaskType.ENDOGENOUS:
+    if forecast_task_type is ForecastTaskType.ENDOGENOUS:
         if input_vars != output_vars:
             raise BadConfigValueError('forecast_task_type', forecast_task_type,
                                       'Number of input/output variables should 
be '
                                       'the same in endogenous forecast')
-    create_fn = eval(model_name)
-    model, model_config = create_fn(
-        common_config=common_config,
-        **kwargs
-    )
-    model_config['model_name'] = model_name
 
+    if model_name == ForecastModelType.DLINEAR.value:
+        model, model_config = dlinear(
+            common_config=common_config,
+            **kwargs
+        )
+    elif model_name == ForecastModelType.DLINEAR_INDIVIDUAL.value:
+        model, model_config = dlinear_individual(
+            common_config=common_config,
+            **kwargs
+        )
+    elif model_name == ForecastModelType.NBEATS.value:
+        model, model_config = nbeats(
+            common_config=common_config,
+            **kwargs
+        )
+    else:
+        raise BadConfigValueError('model_name', model_name, f'It should be one 
of {ForecastModelType.values()}')
+
+    model_config['model_name'] = model_name
     return model, model_config
diff --git a/mlnode/iotdb/mlnode/algorithm/metric.py 
b/mlnode/iotdb/mlnode/algorithm/metric.py
index e623642191..41519dc956 100644
--- a/mlnode/iotdb/mlnode/algorithm/metric.py
+++ b/mlnode/iotdb/mlnode/algorithm/metric.py
@@ -67,3 +67,10 @@ class MAPE(Metric):
 class MSPE(Metric):
     def calculate(self, pred, ground_truth):
         return np.mean(np.square((pred - ground_truth) / ground_truth))
+
+
+def build_metrics(metric_names: list[str]) -> dict[str, Metric]:
+    metrics_dict = {}
+    for metric_name in metric_names:
+        metrics_dict[metric_name] = eval(metric_name)()
+    return metrics_dict
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py 
b/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
index 8be9fa80b5..2a1e720805 100644
--- a/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
+++ b/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
@@ -15,8 +15,3 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-
-from iotdb.mlnode.algorithm.models.forecast.dlinear import dlinear, 
dlinear_individual
-from iotdb.mlnode.algorithm.models.forecast.nbeats import nbeats
-
-support_forecasting_models = ['dlinear', 'dlinear_individual', 'nbeats']
diff --git a/mlnode/iotdb/mlnode/client.py b/mlnode/iotdb/mlnode/client.py
index ecfe75859f..1560b0507a 100644
--- a/mlnode/iotdb/mlnode/client.py
+++ b/mlnode/iotdb/mlnode/client.py
@@ -71,7 +71,7 @@ class MLNodeClient(object):
                              model_id: str,
                              is_auto: bool,
                              model_configs: dict,
-                             query_expressions: list = [],
+                             query_expressions: list,
                              query_filter: str = '') -> None:
         req = TCreateTrainingTaskReq(
             modelId=model_id,
@@ -123,7 +123,7 @@ class DataNodeClient(object):
         self.__client = IMLNodeInternalRPCService.Client(protocol)
 
     def fetch_timeseries(self,
-                         query_expressions: list = [],
+                         query_expressions: list,
                          query_filter: str = None,
                          fetch_size: int = DEFAULT_FETCH_SIZE,
                          timeout: int = DEFAULT_TIMEOUT) -> 
TFetchTimeseriesResp:
diff --git a/mlnode/iotdb/mlnode/data_access/enums.py 
b/mlnode/iotdb/mlnode/data_access/enums.py
index aad21d2c8c..a5d415fc0e 100644
--- a/mlnode/iotdb/mlnode/data_access/enums.py
+++ b/mlnode/iotdb/mlnode/data_access/enums.py
@@ -43,4 +43,4 @@ class DataSourceType(Enum):
         return hash(self.value)
 
     def __eq__(self, other: str) -> bool:
-        return self.value == other
\ No newline at end of file
+        return self.value == other
diff --git a/mlnode/iotdb/mlnode/data_access/factory.py 
b/mlnode/iotdb/mlnode/data_access/factory.py
index 2e9a3342f7..b2d75f18fa 100644
--- a/mlnode/iotdb/mlnode/data_access/factory.py
+++ b/mlnode/iotdb/mlnode/data_access/factory.py
@@ -24,27 +24,17 @@ from iotdb.mlnode.data_access.offline.source import 
(FileDataSource,
                                                      ThriftDataSource)
 from iotdb.mlnode.exception import BadConfigValueError, MissingConfigError
 
-support_forecasting_dataset = {
-    DatasetType.TIMESERIES: TimeSeriesDataset,
-    DatasetType.WINDOW: WindowDataset
-}
-
-support_forecasting_datasource = {
-    DataSourceType.FILE: FileDataSource,
-    DataSourceType.THRIFT: ThriftDataSource
-}
 
-
-def _dataset_config(**kwargs):
+def _dataset_common_config(**kwargs):
     return {
         'time_embed': 'h',
         **kwargs
     }
 
 
-support_dataset_configs = {
-    DatasetType.TIMESERIES: _dataset_config(),
-    DatasetType.WINDOW: _dataset_config(
+_dataset_default_config_dict = {
+    DatasetType.TIMESERIES: _dataset_common_config(),
+    DatasetType.WINDOW: _dataset_common_config(
         input_len=96,
         pred_len=96,
     )
@@ -71,9 +61,8 @@ def create_forecast_dataset(
         dataset: torch.nn.Module
         dataset_config: dict of dataset configurations
     """
-    if dataset_type not in support_forecasting_dataset.keys():
-        raise BadConfigValueError('dataset_type', dataset_type,
-                                  f'It should be one of 
{list(support_forecasting_dataset.keys())}')
+    if dataset_type not in list(DatasetType):
+        raise BadConfigValueError('dataset_type', dataset_type, f'It should be 
one of {list(DatasetType)}')
 
     if source_type == DataSourceType.FILE:
         if 'filename' not in kwargs.keys():
@@ -86,17 +75,19 @@ def create_forecast_dataset(
             raise MissingConfigError('query_filter')
         datasource = ThriftDataSource(kwargs['query_expressions'], 
kwargs['query_filter'])
     else:
-        raise BadConfigValueError('source_type', source_type,
-                                  f"It should be one of 
{list(support_forecasting_datasource)}")
-
-    dataset_fn = support_forecasting_dataset[dataset_type]
-    dataset_config = support_dataset_configs[dataset_type]
+        raise BadConfigValueError('source_type', source_type, f"It should be 
one of {list(DataSourceType)}")
 
+    dataset_config = _dataset_default_config_dict[dataset_type]
     for k, v in kwargs.items():
         if k in dataset_config.keys():
             dataset_config[k] = v
 
-    dataset = dataset_fn(datasource, **dataset_config)
+    if dataset_type == DatasetType.TIMESERIES:
+        dataset = TimeSeriesDataset(datasource, **dataset_config)
+    elif dataset_type == DatasetType.WINDOW:
+        dataset = WindowDataset(datasource, **dataset_config)
+    else:
+        raise BadConfigValueError('dataset_type', dataset_type, f'It should be 
one of {list(DatasetType)}')
 
     if 'input_vars' in kwargs.keys() and dataset.get_variable_num() != 
kwargs['input_vars']:
         raise BadConfigValueError('input_vars', kwargs['input_vars'],
diff --git a/mlnode/iotdb/mlnode/process/manager.py 
b/mlnode/iotdb/mlnode/process/manager.py
index 0af0353973..88214a19bc 100644
--- a/mlnode/iotdb/mlnode/process/manager.py
+++ b/mlnode/iotdb/mlnode/process/manager.py
@@ -23,7 +23,6 @@ from torch.utils.data import Dataset
 
 from iotdb.mlnode.log import logger
 from iotdb.mlnode.process.task import ForecastingTrainingTask
-from iotdb.mlnode.process.task_factory import create_task
 
 
 class TaskManager(object):
@@ -44,13 +43,14 @@ class TaskManager(object):
                              task_configs: dict) -> ForecastingTrainingTask:
         model_id = task_configs['model_id']
         self.__pid_info[model_id] = self.__shared_resource_manager.dict()
-        return create_task(
+        task = ForecastingTrainingTask(
             task_configs,
             model_configs,
             model,
             dataset,
             self.__pid_info
         )
+        return task
 
     def submit_training_task(self, task: ForecastingTrainingTask) -> None:
         if task is not None:
diff --git a/mlnode/iotdb/mlnode/process/task.py 
b/mlnode/iotdb/mlnode/process/task.py
index 85d5b5d2cf..cf067cb165 100644
--- a/mlnode/iotdb/mlnode/process/task.py
+++ b/mlnode/iotdb/mlnode/process/task.py
@@ -97,4 +97,5 @@ class ForecastingTrainingTask(_BasicTask):
             else:
                 self.trial.start()
         except Exception as e:
-            logger.exception(e)
+            logger.warn(e)
+            raise e
diff --git a/mlnode/iotdb/mlnode/process/task_factory.py 
b/mlnode/iotdb/mlnode/process/task_factory.py
deleted file mode 100644
index 083b84eba2..0000000000
--- a/mlnode/iotdb/mlnode/process/task_factory.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-
-
-from iotdb.mlnode.process.task import ForecastingTrainingTask
-
-support_task_types = {
-    'forecast': ForecastingTrainingTask
-}
-
-
-def create_task(task_configs, model_configs, model, dataset, task_trial_map):
-    task_class = task_configs["task_class"]
-    if task_class not in support_task_types:
-        raise RuntimeError(f'Unknown task type: ({task_class}), which'
-                           f' should be one of {support_task_types.keys()}')
-    task_fn = support_task_types[task_class]
-    task = task_fn(
-        task_configs,
-        model_configs,
-        model,
-        dataset,
-        task_trial_map
-    )
-    return task
diff --git a/mlnode/iotdb/mlnode/process/trial.py 
b/mlnode/iotdb/mlnode/process/trial.py
index 9852e3ffb4..105dfeaf37 100644
--- a/mlnode/iotdb/mlnode/process/trial.py
+++ b/mlnode/iotdb/mlnode/process/trial.py
@@ -23,7 +23,7 @@ import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader, Dataset
 
-from iotdb.mlnode.algorithm.metric import MAE, MSE, all_metrics
+from iotdb.mlnode.algorithm.metric import all_metrics, build_metrics
 from iotdb.mlnode.client import client_manager
 from iotdb.mlnode.log import logger
 from iotdb.mlnode.storage import model_storage
@@ -122,6 +122,7 @@ class ForecastingTrainingTrial(BasicTrial):
         self.confignode_client = client_manager.borrow_config_node_client()
         self.criterion = torch.nn.MSELoss()
         self.optimizer = torch.optim.Adam(self.model.parameters(), 
lr=self.learning_rate)
+        self.metrics_dict = build_metrics(self.metric_names)
 
     def _build_dataloader(self):
         return DataLoader(
@@ -148,8 +149,10 @@ class ForecastingTrainingTrial(BasicTrial):
             # decoder input
             dec_inp = torch.zeros_like(batch_y[:, -self.pred_len:, :]).float()
             outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
+
             outputs = outputs[:, -self.pred_len:]
             batch_y = batch_y[:, -self.pred_len:]
+
             loss = self.criterion(outputs, batch_y)
             train_loss.append(loss.item())
 
@@ -161,7 +164,6 @@ class ForecastingTrainingTrial(BasicTrial):
             self.optimizer.step()
 
         train_loss = np.average(train_loss)
-        # TODO: manage these training output
         logger.info('Epoch: {0} cost time: {1} | Train Loss: {2:.7f}'
                     .format(epoch + 1, time.time() - epoch_time, train_loss))
         return train_loss
@@ -169,7 +171,7 @@ class ForecastingTrainingTrial(BasicTrial):
     def _validate(self, epoch):
         self.model.eval()
         val_loss = []
-        metrics_dict = {name: [] for name in self.metric_names}
+        metrics_value_dict = {name: [] for name in self.metric_names}
         for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in 
enumerate(self.dataloader):
             batch_x = batch_x.float().to(self.device)
             batch_y = batch_y.float().to(self.device)
@@ -185,26 +187,25 @@ class ForecastingTrainingTrial(BasicTrial):
             batch_y = batch_y[:, -self.pred_len:]
 
             loss = self.criterion(outputs, batch_y)
-
             val_loss.append(loss.item())
+
             for name in self.metric_names:
-                metric = eval(name)()
+                metric = self.metrics_dict[name]
                 value = metric(outputs.detach().cpu().numpy(), 
batch_y.detach().cpu().numpy())
-                metrics_dict[name].append(value)
+                metrics_value_dict[name].append(value)
 
-        for name, value_list in metrics_dict.items():
-            metrics_dict[name] = np.average(value_list)
+        for name, value_list in metrics_value_dict.items():
+            metrics_value_dict[name] = np.average(value_list)
 
-        # TODO: handle some exception
         self.datanode_client.record_model_metrics(
             model_id=self.model_id,
             trial_id=self.trial_id,
-            metrics=list(metrics_dict.keys()),
-            values=list(metrics_dict.values())
+            metrics=list(metrics_value_dict.keys()),
+            values=list(metrics_value_dict.values())
         )
         val_loss = np.average(val_loss)
         logger.info('Epoch: {0} Vali Loss: {1:.7f}'.format(epoch + 1, 
val_loss))
-        return val_loss, metrics_dict
+        return val_loss, metrics_value_dict
 
     def start(self) -> float:
         try:
diff --git a/mlnode/iotdb/mlnode/storage.py b/mlnode/iotdb/mlnode/storage.py
index 68392be53b..ae93cc991b 100644
--- a/mlnode/iotdb/mlnode/storage.py
+++ b/mlnode/iotdb/mlnode/storage.py
@@ -26,6 +26,7 @@ from pylru import lrucache
 
 from iotdb.mlnode.config import descriptor
 from iotdb.mlnode.exception import ModelNotExistError
+from iotdb.mlnode.log import logger
 
 
 class ModelStorage(object):
@@ -34,7 +35,8 @@ class ModelStorage(object):
         if not os.path.exists(self.__model_dir):
             try:
                 os.mkdir(self.__model_dir)
-            except PermissionError as e: # TODO: handle storage permission
+            except PermissionError as e:
+                logger.error(e)
                 raise e
 
         self.__model_cache = 
lrucache(descriptor.get_config().get_mn_model_storage_cache_size())
diff --git a/mlnode/test/test_create_forecast_dataset.py 
b/mlnode/test/test_create_forecast_dataset.py
index 69307e77a1..49e2d177e8 100644
--- a/mlnode/test/test_create_forecast_dataset.py
+++ b/mlnode/test/test_create_forecast_dataset.py
@@ -16,12 +16,11 @@
 # under the License.
 #
 import os
-import torch
+
 import requests
 
-from iotdb.mlnode.data_access.factory import create_forecast_dataset
-from iotdb.mlnode.exception import BadConfigValueError, MissingConfigError
 from iotdb.mlnode.data_access.enums import DatasetType, DataSourceType
+from iotdb.mlnode.data_access.factory import create_forecast_dataset
 
 
 def test_create_dataset():
diff --git a/mlnode/test/test_create_forecast_model.py 
b/mlnode/test/test_create_forecast_model.py
index 0e4e479c83..08c0fd10fb 100644
--- a/mlnode/test/test_create_forecast_model.py
+++ b/mlnode/test/test_create_forecast_model.py
@@ -17,9 +17,9 @@
 #
 import torch
 
+from iotdb.mlnode.algorithm.enums import ForecastTaskType
 from iotdb.mlnode.algorithm.factory import create_forecast_model
 from iotdb.mlnode.exception import BadConfigValueError
-from iotdb.mlnode.algorithm.enums import ForecastTaskType
 
 
 def test_create_forecast_model():
@@ -34,7 +34,7 @@ def test_create_forecast_model():
     model, model_config = create_forecast_model(model_name='nbeats',
                                                 kernel_size=25, d_model=64)
     assert model_config['d_model'] == 64
-    assert 'kernel_size' not in model_config # config kernel_size not belongs 
to nbeats model
+    assert 'kernel_size' not in model_config  # config kernel_size not belongs 
to nbeats model
 
 
 def test_bad_config_model1():
@@ -68,6 +68,7 @@ def test_bad_config_model4():
     except BadConfigValueError as e:
         print(e)  # ('forecast_task_type', 'dummy_task')
 
+
 def test_bad_config_model5():
     try:
         model, models = create_forecast_model(model_name='dlinear',
diff --git a/mlnode/test/test_model_storage.py 
b/mlnode/test/test_model_storage.py
index 8e2cbf4623..1b18f974aa 100644
--- a/mlnode/test/test_model_storage.py
+++ b/mlnode/test/test_model_storage.py
@@ -23,8 +23,8 @@ import time
 import torch.nn as nn
 
 from iotdb.mlnode.config import config
-from iotdb.mlnode.storage import model_storage
 from iotdb.mlnode.exception import ModelNotExistError
+from iotdb.mlnode.storage import model_storage
 
 
 class ExampleModel(nn.Module):
diff --git a/mlnode/test/test_parse_training_request.py 
b/mlnode/test/test_parse_training_request.py
index 8fe193f53d..bafa53695d 100644
--- a/mlnode/test/test_parse_training_request.py
+++ b/mlnode/test/test_parse_training_request.py
@@ -15,8 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-from iotdb.mlnode.parser import parse_training_request
 from iotdb.mlnode.exception import MissingConfigError, WrongTypeConfigError
+from iotdb.mlnode.parser import parse_training_request
 from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq
 
 
@@ -26,7 +26,7 @@ def test_parse_training_request():
     model_configs = {
         'task_class': 'forecast_training_task',
         'source_type': 'thrift',
-        'dataset_type': 'window', # or use DatasetType.WINDOW,
+        'dataset_type': 'window',  # or use DatasetType.WINDOW,
         'filename': 'ETTh1.csv',
         'time_embed': 'h',
         'input_len': 96,
diff --git a/thrift-mlnode/src/main/thrift/mlnode.thrift 
b/thrift-mlnode/src/main/thrift/mlnode.thrift
index 916022e973..abadc79576 100644
--- a/thrift-mlnode/src/main/thrift/mlnode.thrift
+++ b/thrift-mlnode/src/main/thrift/mlnode.thrift
@@ -31,7 +31,7 @@ struct TCreateTrainingTaskReq {
 
 struct TDeleteModelReq {
   1: required string modelId
-  2: optional string trailId
+  2: optional string trialId
 }
 
 struct TForecastReq {

Reply via email to