This is an automated email from the ASF dual-hosted git repository.
hui pushed a commit to branch research/MLEngine
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/research/MLEngine by this push:
new f457ab4c772 [IOTDB-5748] [IoTDB ML] Support executing training tasks
on MLNode (#9491)
f457ab4c772 is described below
commit f457ab4c772b2091f2d22f1a1a0e7a60b1371b93
Author: hangzhou188 <[email protected]>
AuthorDate: Fri Jun 2 10:46:11 2023 +0800
[IOTDB-5748] [IoTDB ML] Support executing training tasks on MLNode (#9491)
---
mlnode/.gitignore | 3 +-
mlnode/iotdb/mlnode/algorithm/enums.py | 54 ++++
mlnode/iotdb/mlnode/algorithm/factory.py | 130 +++++++++
mlnode/iotdb/mlnode/algorithm/metric.py | 77 ++++++
.../models/forecast/__init__.py} | 14 -
.../mlnode/algorithm/models/forecast/dlinear.py | 163 ++++++++++++
.../mlnode/algorithm/models/forecast/nbeats.py | 170 ++++++++++++
mlnode/iotdb/mlnode/client.py | 81 +++++-
mlnode/iotdb/mlnode/config.py | 36 +++
.../mlnode/{exception.py => data_access/enums.py} | 33 ++-
mlnode/iotdb/mlnode/data_access/factory.py | 105 ++++++++
mlnode/iotdb/mlnode/data_access/offline/dataset.py | 99 +++++++
mlnode/iotdb/mlnode/data_access/offline/source.py | 82 ++++++
.../utils/__init__.py} | 14 -
.../iotdb/mlnode/data_access/utils/timefeatures.py | 171 ++++++++++++
mlnode/iotdb/mlnode/exception.py | 23 +-
mlnode/iotdb/mlnode/handler.py | 51 +++-
mlnode/iotdb/mlnode/parser.py | 228 ++++++++++++++++
.../mlnode/{exception.py => process/__init__.py} | 14 -
mlnode/iotdb/mlnode/process/manager.py | 128 +++++++++
mlnode/iotdb/mlnode/process/task.py | 294 +++++++++++++++++++++
mlnode/iotdb/mlnode/process/trial.py | 263 ++++++++++++++++++
mlnode/iotdb/mlnode/serde.py | 142 +++++++---
mlnode/iotdb/mlnode/storage.py | 11 +-
mlnode/resources/conf/iotdb-mlnode.toml | 12 +
mlnode/test/test_create_forecast_dataset.py | 88 ++++++
mlnode/test/test_create_forecast_model.py | 78 ++++++
mlnode/test/test_model_storage.py | 2 +-
mlnode/test/test_parse_training_request.py | 135 ++++++++++
.../config/executor/ClusterConfigTaskExecutor.java | 166 ++++++------
30 files changed, 2673 insertions(+), 194 deletions(-)
diff --git a/mlnode/.gitignore b/mlnode/.gitignore
index 4a6f6eac650..cc99cc4f975 100644
--- a/mlnode/.gitignore
+++ b/mlnode/.gitignore
@@ -5,4 +5,5 @@
/dist/
# generated by MLNode
-*.pt
\ No newline at end of file
+*.pt
+/rebuild.bat
diff --git a/mlnode/iotdb/mlnode/algorithm/enums.py
b/mlnode/iotdb/mlnode/algorithm/enums.py
new file mode 100644
index 00000000000..aea00db377a
--- /dev/null
+++ b/mlnode/iotdb/mlnode/algorithm/enums.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from enum import Enum
+from typing import List
+
+
+class ForecastTaskType(Enum):
+ """
+ In multivariable time series forecasting tasks, the columns to be
predicted are called the endogenous variables
+ and the columns that are independent or non-forecast are called the
exogenous variables (they might be helpful
+ to predict the endogenous variables). Both of them can appear in the input
time series.
+
+ ForecastTaskType.ENDOGENOUS: all input time series of input are endogenous
variables
+ ForecastTaskType.EXOGENOUS: the input time series is combined with
endogenous and exogenous variables
+ """
+ ENDOGENOUS = "endogenous"
+ EXOGENOUS = "exogenous"
+
+ def __str__(self) -> str:
+ return self.value
+
+ def __eq__(self, other: str) -> bool:
+ return self.value == other
+
+ def __hash__(self) -> int:
+ return hash(self.value)
+
+
+class ForecastModelType(Enum):
+ DLINEAR = "dlinear"
+ DLINEAR_INDIVIDUAL = "dlinear_individual"
+ NBEATS = "nbeats"
+
+ @classmethod
+ def values(cls) -> List[str]:
+ values = []
+ for item in list(cls):
+ values.append(item.value)
+ return values
diff --git a/mlnode/iotdb/mlnode/algorithm/factory.py
b/mlnode/iotdb/mlnode/algorithm/factory.py
new file mode 100644
index 00000000000..ed39e4e2b8c
--- /dev/null
+++ b/mlnode/iotdb/mlnode/algorithm/factory.py
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from typing import Dict, Tuple
+
+import torch.nn as nn
+
+from iotdb.mlnode.algorithm.enums import ForecastModelType, ForecastTaskType
+from iotdb.mlnode.algorithm.models.forecast.dlinear import (dlinear,
+ dlinear_individual)
+from iotdb.mlnode.algorithm.models.forecast.nbeats import nbeats
+from iotdb.mlnode.exception import BadConfigValueError
+
+
+# Common configs for all forecasting model with default values
+def _common_config(**kwargs):
+ return {
+ 'input_len': 96,
+ 'pred_len': 96,
+ 'input_vars': 1,
+ 'output_vars': 1,
+ **kwargs
+ }
+
+
+# Common forecasting task configs
+_forecasting_model_default_config_dict = {
+ # multivariable forecasting with all endogenous variables, current support
this only
+ ForecastTaskType.ENDOGENOUS: _common_config(
+ input_vars=1,
+ output_vars=1),
+
+ # multivariable forecasting with some exogenous variables
+ ForecastTaskType.EXOGENOUS: _common_config(
+ output_vars=1),
+}
+
+
+def create_forecast_model(
+ model_name,
+ input_len=96,
+ pred_len=96,
+ input_vars=1,
+ output_vars=1,
+ forecast_task_type=ForecastTaskType.ENDOGENOUS,
+ **kwargs,
+) -> Tuple[nn.Module, Dict]:
+ """
+ Factory method for all support forecasting models
+ the given arguments is common configs shared by all forecasting models
+ for specific model configs, see _model_config in `MODELNAME.py`
+
+ Args:
+ model_name: see available models by `list_model`
+ forecast_task_type: see algorithm/enums for available choices
+ input_len: time length of model input
+ pred_len: time length of model output
+ input_vars: number of input series
+ output_vars: number of output series
+ kwargs: for specific model configs, see returned `model_config` with
kwargs=None
+
+ Returns:
+ model: torch.nn.Module
+ model_config: dict of model configurations
+ """
+ if model_name not in ForecastModelType.values():
+ raise BadConfigValueError('model_name', model_name, f'It should be one
of {ForecastModelType.values()}')
+ if forecast_task_type not in _forecasting_model_default_config_dict.keys():
+ raise BadConfigValueError('forecast_task_type', forecast_task_type,
+ f'It should be one of
{list(_forecasting_model_default_config_dict.keys())}')
+
+ common_config = _forecasting_model_default_config_dict[forecast_task_type]
+ common_config['input_len'] = input_len
+ common_config['pred_len'] = pred_len
+ common_config['input_vars'] = input_vars
+ common_config['output_vars'] = output_vars
+ common_config['forecast_task_type'] = str(forecast_task_type)
+
+ if not input_len > 0:
+ raise BadConfigValueError('input_len', input_len,
+ 'Length of input series should be positive')
+ if not pred_len > 0:
+ raise BadConfigValueError('pred_len', pred_len,
+ 'Length of predicted series should be
positive')
+ if not input_vars > 0:
+ raise BadConfigValueError('input_vars', input_vars,
+ 'Number of input variables should be
positive')
+ if not output_vars > 0:
+ raise BadConfigValueError('output_vars', output_vars,
+ 'Number of output variables should be
positive')
+ if forecast_task_type is ForecastTaskType.ENDOGENOUS:
+ if input_vars != output_vars:
+ raise BadConfigValueError('forecast_task_type', forecast_task_type,
+ 'Number of input/output variables should
be '
+ 'the same in endogenous forecast')
+
+ if model_name == ForecastModelType.DLINEAR.value:
+ model, model_config = dlinear(
+ common_config=common_config,
+ **kwargs
+ )
+ elif model_name == ForecastModelType.DLINEAR_INDIVIDUAL.value:
+ model, model_config = dlinear_individual(
+ common_config=common_config,
+ **kwargs
+ )
+ elif model_name == ForecastModelType.NBEATS.value:
+ model, model_config = nbeats(
+ common_config=common_config,
+ **kwargs
+ )
+ else:
+ raise BadConfigValueError('model_name', model_name, f'It should be one
of {ForecastModelType.values()}')
+
+ model_config['model_name'] = model_name
+ return model, model_config
diff --git a/mlnode/iotdb/mlnode/algorithm/metric.py
b/mlnode/iotdb/mlnode/algorithm/metric.py
new file mode 100644
index 00000000000..5ffc5f24a5c
--- /dev/null
+++ b/mlnode/iotdb/mlnode/algorithm/metric.py
@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from abc import abstractmethod
+from typing import Dict, List
+
+import numpy as np
+
+all_metrics = ['RSE', 'CORR', 'MAE', 'MSE', 'RMSE', 'MAPE', 'MSPE']
+
+
+class Metric(object):
+ def __call__(self, pred, ground_truth):
+ return self.calculate(pred, ground_truth)
+
+ @abstractmethod
+ def calculate(self, pred, ground_truth):
+ pass
+
+
+class RSE(Metric):
+ def calculate(self, pred, ground_truth):
+ return np.sqrt(np.sum((ground_truth - pred) ** 2)) /
np.sqrt(np.sum((ground_truth - ground_truth.mean()) ** 2))
+
+
+class CORR(Metric):
+ def calculate(self, pred, ground_truth):
+ u = ((ground_truth - ground_truth.mean(0)) * (pred -
pred.mean(0))).sum(0)
+ d = np.sqrt(((ground_truth - ground_truth.mean(0)) ** 2 * (pred -
pred.mean(0)) ** 2).sum(0))
+ return (u / d).mean(-1)
+
+
+class MAE(Metric):
+ def calculate(self, pred, ground_truth):
+ return np.mean(np.abs(pred - ground_truth))
+
+
+class MSE(Metric):
+ def calculate(self, pred, ground_truth):
+ return np.mean((pred - ground_truth) ** 2)
+
+
+class RMSE(Metric):
+ def calculate(self, pred, ground_truth):
+ mse = MSE()
+ return np.sqrt(mse(pred, ground_truth))
+
+
+class MAPE(Metric):
+ def calculate(self, pred, ground_truth):
+ return np.mean(np.abs((pred - ground_truth) / ground_truth))
+
+
+class MSPE(Metric):
+ def calculate(self, pred, ground_truth):
+ return np.mean(np.square((pred - ground_truth) / ground_truth))
+
+
+def build_metrics(metric_names: List) -> Dict[str, Metric]:
+ metrics_dict = {}
+ for metric_name in metric_names:
+ metrics_dict[metric_name] = eval(metric_name)()
+ return metrics_dict
diff --git a/mlnode/iotdb/mlnode/exception.py
b/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
similarity index 67%
copy from mlnode/iotdb/mlnode/exception.py
copy to mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
index 6307909a9ac..2a1e720805f 100644
--- a/mlnode/iotdb/mlnode/exception.py
+++ b/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
@@ -15,17 +15,3 @@
# specific language governing permissions and limitations
# under the License.
#
-
-class _BaseError(Exception):
- """Base class for exceptions in this module."""
- pass
-
-
-class BadNodeUrlError(_BaseError):
- def __init__(self, node_url: str):
- self.message = "Bad node url: {}".format(node_url)
-
-
-class ModelNotExistError(_BaseError):
- def __init__(self, file_path: str):
- self.message = "Model path: ({}) not exists".format(file_path)
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
new file mode 100644
index 00000000000..35ba728fa78
--- /dev/null
+++ b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
@@ -0,0 +1,163 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import math
+from typing import Dict, Tuple
+
+import torch
+import torch.nn as nn
+
+from iotdb.mlnode.algorithm.enums import ForecastTaskType
+from iotdb.mlnode.exception import BadConfigValueError
+
+
+class MovingAverageBlock(nn.Module):
+ """ Moving average block to highlight the trend of time series """
+
+ def __init__(self, kernel_size, stride):
+ super(MovingAverageBlock, self).__init__()
+ self.kernel_size = kernel_size
+ self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride,
padding=0)
+
+ def forward(self, x):
+ # padding on the both ends of time series
+ front = x[:, 0:1, :].repeat(1, self.kernel_size - 1 -
math.floor((self.kernel_size - 1) // 2), 1)
+ end = x[:, -1:, :].repeat(1, math.floor((self.kernel_size - 1) // 2),
1)
+ x = torch.cat([front, x, end], dim=1)
+ x = self.avg(x.permute(0, 2, 1))
+ x = x.permute(0, 2, 1)
+ return x
+
+
+class SeriesDecompositionBlock(nn.Module):
+ """ Series decomposition block """
+
+ def __init__(self, kernel_size):
+ super(SeriesDecompositionBlock, self).__init__()
+ self.moving_avg = MovingAverageBlock(kernel_size, stride=1)
+
+ def forward(self, x):
+ moving_mean = self.moving_avg(x)
+ res = x - moving_mean
+ return res, moving_mean
+
+
+class DLinear(nn.Module):
+ """ Decomposition Linear Model """
+
+ def __init__(
+ self,
+ kernel_size=25,
+ input_len=96,
+ pred_len=96,
+ input_vars=1,
+ output_vars=1,
+ forecast_task_type=ForecastTaskType.ENDOGENOUS, # TODO, support
others
+ ):
+ super(DLinear, self).__init__()
+ self.input_len = input_len
+ self.pred_len = pred_len
+ self.kernel_size = kernel_size
+ self.channels = input_vars
+
+ # decomposition Kernel Size
+ self.decomposition = SeriesDecompositionBlock(kernel_size)
+ self.linear_seasonal = nn.Linear(self.input_len, self.pred_len)
+ self.linear_trend = nn.Linear(self.input_len, self.pred_len)
+
+ def forward(self, x, *args):
+ # x: [Batch, Input length, Channel]
+ seasonal_init, trend_init = self.decomposition(x)
+ seasonal_init, trend_init = seasonal_init.permute(0, 2, 1),
trend_init.permute(0, 2, 1)
+
+ seasonal_output = self.linear_seasonal(seasonal_init)
+ trend_output = self.linear_trend(trend_init)
+
+ x = seasonal_output + trend_output
+ return x.permute(0, 2, 1) # to [Batch, Output length, Channel]
+
+
+class DLinearIndividual(nn.Module):
+ """ Decomposition Linear Model (individual) """
+
+ def __init__(
+ self,
+ kernel_size=25,
+ input_len=96,
+ pred_len=96,
+ input_vars=1,
+ output_vars=1,
+ forecast_task_type=ForecastTaskType.ENDOGENOUS, # TODO, support
others
+ ):
+ super(DLinearIndividual, self).__init__()
+ self.input_len = input_len
+ self.pred_len = pred_len
+ self.kernel_size = kernel_size
+ self.channels = input_vars
+
+ self.decomposition = SeriesDecompositionBlock(kernel_size)
+ self.Linear_Seasonal = nn.ModuleList(
+ [nn.Linear(self.input_len, self.pred_len) for _ in
range(self.channels)]
+ )
+ self.Linear_Trend = nn.ModuleList(
+ [nn.Linear(self.input_len, self.pred_len) for _ in
range(self.channels)]
+ )
+
+ def forward(self, x, *args):
+ # x: [Batch, Input length, Channel]
+ seasonal_init, trend_init = self.decomposition(x)
+ seasonal_init, trend_init = seasonal_init.permute(0, 2, 1),
trend_init.permute(0, 2, 1)
+
+ seasonal_output = torch.zeros([seasonal_init.size(0),
seasonal_init.size(1), self.pred_len],
+
dtype=seasonal_init.dtype).to(seasonal_init.device)
+ trend_output = torch.zeros([trend_init.size(0), trend_init.size(1),
self.pred_len],
+
dtype=trend_init.dtype).to(trend_init.device)
+ for i, linear_season_layer in enumerate(self.Linear_Seasonal):
+ seasonal_output[:, i, :] = linear_season_layer(seasonal_init[:, i,
:])
+ for i, linear_trend_layer in enumerate(self.Linear_Trend):
+ trend_output[:, i, :] = linear_trend_layer(trend_init[:, i, :])
+
+ x = seasonal_output + trend_output
+ return x.permute(0, 2, 1) # to [Batch, Output length, Channel]
+
+
+def _model_config(**kwargs):
+ return {
+ 'kernel_size': 25,
+ **kwargs
+ }
+
+
+def dlinear(common_config: Dict, kernel_size=25, **kwargs) -> Tuple[DLinear,
Dict]:
+ config = _model_config()
+ config.update(**common_config)
+ if not kernel_size > 0:
+ raise BadConfigValueError('kernel_size', kernel_size,
+ 'Kernel size of dlinear should larger than
0')
+ config['kernel_size'] = kernel_size
+ return DLinear(**config), config
+
+
+def dlinear_individual(common_config: Dict, kernel_size=25, **kwargs) ->
Tuple[DLinearIndividual, Dict]:
+ config = _model_config()
+ config.update(**common_config)
+ if not kernel_size > 0:
+ raise BadConfigValueError('kernel_size', kernel_size,
+ 'Kernel size of dlinear_individual should
larger than 0')
+ config['kernel_size'] = kernel_size
+ return DLinearIndividual(**config), config
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/nbeats.py
b/mlnode/iotdb/mlnode/algorithm/models/forecast/nbeats.py
new file mode 100644
index 00000000000..f12e4135bee
--- /dev/null
+++ b/mlnode/iotdb/mlnode/algorithm/models/forecast/nbeats.py
@@ -0,0 +1,170 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from typing import Dict, Tuple
+
+import torch
+import torch.nn as nn
+
+from iotdb.mlnode.algorithm.enums import ForecastTaskType
+from iotdb.mlnode.exception import BadConfigValueError
+
+
+class GenericBasis(nn.Module):
+ """ Generic basis function """
+
+ def __init__(self, backcast_size: int, forecast_size: int):
+ super().__init__()
+ self.backcast_size = backcast_size
+ self.forecast_size = forecast_size
+
+ def forward(self, theta: torch.Tensor):
+ return theta[:, :self.backcast_size], theta[:, -self.forecast_size:]
+
+
+block_dict = {
+ 'generic': GenericBasis,
+}
+
+
+class NBeatsBlock(nn.Module):
+ """ N-BEATS block which takes a basis function as an argument """
+
+ def __init__(self,
+ input_size,
+ theta_size: int,
+ basis_function: nn.Module,
+ layers: int,
+ layer_size: int):
+ """
+ N-BEATS block
+
+ Args:
+ input_size: input sample size
+ theta_size: number of parameters for the basis function
+ basis_function: basis function which takes the parameters and
produces backcast and forecast
+ layers: number of layers
+ layer_size: layer size
+ """
+ super().__init__()
+ self.layers = nn.ModuleList([nn.Linear(in_features=input_size,
out_features=layer_size)] + [
+ nn.Linear(in_features=layer_size, out_features=layer_size) for _
in range(layers - 1)])
+ self.basis_parameters = nn.Linear(in_features=layer_size,
out_features=theta_size)
+ self.basis_function = basis_function
+
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ block_input = x
+ for layer in self.layers:
+ block_input = torch.relu(layer(block_input))
+ basis_parameters = self.basis_parameters(block_input)
+ return self.basis_function(basis_parameters)
+
+
+class NBeatsUnivariable(nn.Module):
+ """ N-Beats Model (uni-variable) """
+
+ def __init__(self, blocks: nn.ModuleList):
+ super().__init__()
+ self.blocks = blocks
+
+ def forward(self, x):
+ residuals = x
+ forecast = None
+ for _, block in enumerate(self.blocks):
+ backcast, block_forecast = block(residuals)
+ residuals = (residuals - backcast)
+ if forecast is None:
+ forecast = block_forecast
+ else:
+ forecast += block_forecast
+ return forecast
+
+
+class NBeats(nn.Module):
+ """ Neural Basis Expansion Analysis Time Series """
+
+ def __init__(
+ self,
+ block_type='generic',
+ d_model=128,
+ inner_layers=4,
+ outer_layers=4,
+ input_len=96,
+ pred_len=96,
+ input_vars=1,
+ output_vars=1,
+ forecast_task_type=ForecastTaskType.ENDOGENOUS, # TODO, support
others
+ ):
+ super(NBeats, self).__init__()
+ self.enc_in = input_vars
+ self.block = block_dict[block_type]
+ self.model = NBeatsUnivariable(
+ torch.nn.ModuleList(
+ [NBeatsBlock(input_size=input_len,
+ theta_size=input_len + pred_len,
+
basis_function=self.block(backcast_size=input_len, forecast_size=pred_len),
+ layers=inner_layers,
+ layer_size=d_model)
+ for _ in range(outer_layers)]
+ )
+ )
+
+ def forward(self, x, *args):
+ # x: [Batch, Input length, Channel]
+ res = []
+ for i in range(self.enc_in):
+ dec_out = self.model(x[:, :, i])
+ res.append(dec_out)
+ return torch.stack(res, dim=-1) # to [Batch, Output length, Channel]
+
+
+def _model_config(**kwargs):
+ return {
+ 'block_type': 'generic',
+ 'd_model': 128,
+ 'inner_layers': 4,
+ 'outer_layers': 4,
+ **kwargs
+ }
+
+
+"""
+Specific configs for NBeats variants
+"""
+support_model_configs = {
+ 'nbeats': _model_config(
+ block_type='generic'),
+}
+
+
+def nbeats(common_config: Dict, d_model=128, inner_layers=4, outer_layers=4,
**kwargs) -> Tuple[NBeats, Dict]:
+ config = _model_config()
+ config.update(**common_config)
+ if not d_model > 0:
+ raise BadConfigValueError('d_model', d_model,
+ 'Model dimension (d_model) of nbeats should
larger than 0')
+ if not inner_layers > 0:
+ raise BadConfigValueError('inner_layers', inner_layers,
+ 'Number of inner layers of nbeats should
larger than 0')
+ if not outer_layers > 0:
+ raise BadConfigValueError('outer_layers', outer_layers,
+ 'Number of outer layers of nbeats should
larger than 0')
+ config['d_model'] = d_model
+ config['inner_layers'] = inner_layers
+ config['outer_layers'] = outer_layers
+ return NBeats(**config), config
diff --git a/mlnode/iotdb/mlnode/client.py b/mlnode/iotdb/mlnode/client.py
index 3157006e578..ed1ec11beb8 100644
--- a/mlnode/iotdb/mlnode/client.py
+++ b/mlnode/iotdb/mlnode/client.py
@@ -16,7 +16,7 @@
# under the License.
#
import time
-from typing import Dict, List
+from typing import Dict, List, Tuple
import pandas as pd
from thrift.protocol import TBinaryProtocol, TCompactProtocol
@@ -33,10 +33,12 @@ from iotdb.thrift.confignode import IConfigNodeRPCService
from iotdb.thrift.confignode.ttypes import (TUpdateModelInfoReq,
TUpdateModelStateReq)
from iotdb.thrift.datanode import IMLNodeInternalRPCService
-from iotdb.thrift.datanode.ttypes import (TFetchTimeseriesReq,
+from iotdb.thrift.datanode.ttypes import (TFetchMoreDataReq,
+ TFetchTimeseriesReq,
TRecordModelMetricsReq)
from iotdb.thrift.mlnode import IMLNodeRPCService
-from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq, TDeleteModelReq
+from iotdb.thrift.mlnode.ttypes import (TCreateTrainingTaskReq,
+ TDeleteModelReq, TForecastReq)
class ClientManager(object):
@@ -51,13 +53,17 @@ class ClientManager(object):
def borrow_config_node_client(self):
return ConfigNodeClient(config_leader=self.__config_node_endpoint)
+ def borrow_mlnode_client(self):
+ return MLNodeClient(descriptor.get_config().get_mn_rpc_address(),
+ descriptor.get_config().get_mn_rpc_port())
+
class MLNodeClient(object):
def __init__(self, host, port):
self.__host = host
self.__port = port
- transport = TTransport.TBufferedTransport(
+ transport = TTransport.TFramedTransport(
TSocket.TSocket(self.__host, self.__port)
)
if not transport.isOpen():
@@ -88,9 +94,29 @@ class MLNodeClient(object):
except TTransport.TException as e:
raise e
- def create_forecast_task(self) -> None:
- # TODO
- pass
+ def create_forecast_task(self,
+ model_path: str,
+ ts_dataset: List,
+ column_name_list: List[str],
+ column_type_list: List[str],
+ column_name_index_map: Dict[str, int],
+ pred_length: int,
+ model_id: str
+ ) -> None:
+ req = TForecastReq(
+ modelPath=model_path,
+ tsDataset=ts_dataset,
+ columnNameList=column_name_list,
+ columnTypeList=column_type_list,
+ columnNameIndexMap=column_name_index_map,
+ predLength=pred_length,
+ modelId=model_id
+ )
+ try:
+ result = self.__client.forecast(req)
+ print(result)
+ except Exception as e:
+ raise e
def delete_model(self,
model_id: str,
@@ -128,7 +154,7 @@ class DataNodeClient(object):
query_expressions: List[str],
query_filter: str = None,
fetch_size: int = DEFAULT_FETCH_SIZE,
- timeout: int = DEFAULT_TIMEOUT) -> [int, bool,
pd.DataFrame]:
+ timeout: int = DEFAULT_TIMEOUT) -> pd.DataFrame:
req = TFetchTimeseriesReq(
queryExpressions=query_expressions,
queryFilter=query_filter,
@@ -141,7 +167,6 @@ class DataNodeClient(object):
if len(resp.tsDataset) == 0:
raise RuntimeError(f'No data fetched with query filter:
{query_filter}')
-
data = serde.convert_to_df(resp.columnNameList,
resp.columnTypeList,
resp.columnNameIndexMap,
@@ -149,17 +174,36 @@ class DataNodeClient(object):
if data.empty:
raise RuntimeError(
f'Fetched empty data with query expressions:
{query_expressions} and query filter: {query_filter}')
- return resp.queryId, resp.hasMoreData, data
except Exception as e:
logger.warn(
f'Fail to fetch data with query expressions:
{query_expressions} and query filter: {query_filter}')
raise e
+ query_id = resp.queryId
+ column_name_list = resp.columnNameList
+ column_type_list = resp.columnTypeList
+ column_name_index_map = resp.columnNameIndexMap
+ has_more_data = resp.hasMoreData
+ while has_more_data:
+ req = TFetchMoreDataReq(queryId=query_id, fetchSize=fetch_size)
+ try:
+ resp = self.__client.fetchMoreData(req)
+ verify_success(resp.status, "An error occurs when calling
fetch_more_data()")
+ data = data.append(serde.convert_to_df(column_name_list,
+ column_type_list,
+ column_name_index_map,
+ resp.tsDataset))
+ has_more_data = resp.hasMoreData
+ except Exception as e:
+ logger.warn(
+ f'Fail to fetch more data with query id: {query_id}')
+ raise e
+ return data
def fetch_window_batch(self,
query_expressions: list,
query_filter: str = None,
fetch_size: int = DEFAULT_FETCH_SIZE,
- timeout: int = DEFAULT_TIMEOUT) -> [int, bool,
List[pd.DataFrame]]:
+ timeout: int = DEFAULT_TIMEOUT) -> Tuple[int, bool,
List[pd.DataFrame]]:
pass
def record_model_metrics(self,
@@ -315,3 +359,18 @@ class ConfigNodeClient(object):
client_manager = ClientManager()
+
+if __name__ == '__main__':
+ client = client_manager.borrow_mlnode_client()
+ import pickle
+ f =
open('D:\\undergraduate\\DL\\iotdb\\mlnode\\iotdb\\mlnode\\test_tsdataset.pkl',
'rb')
+ ts_dataset = pickle.load(f)
+ client.create_forecast_task(
+
'D:\\undergraduate\\DL\\iotdb\\mlnode\\iotdb\\mlnode\\models\\Model_1\\tid_0.pt',
+ ts_dataset,
+ ['root.eg.etth1.s0'],
+ ['FLOAT'],
+ {'root.eg.etth1.s0': 0},
+ 192,
+ 'Model_2'
+ )
diff --git a/mlnode/iotdb/mlnode/config.py b/mlnode/iotdb/mlnode/config.py
index 0ccfdc2cbb0..c54aa07ff25 100644
--- a/mlnode/iotdb/mlnode/config.py
+++ b/mlnode/iotdb/mlnode/config.py
@@ -40,6 +40,15 @@ class MLNodeConfig(object):
# Cache number of model storage to avoid repeated loading
self.__mn_model_storage_cache_size = 30
+ # Maximum number of training model tasks, otherwise the task is pending
+ self.__mn_task_pool_size = 10
+
+ # Maximum number of trials to be explored in a tuning task
+ self.__mn_tuning_trial_num = 20
+
+ # Concurrency of trials in a tuning task
+ self.__mn_tuning_trial_concurrency = 4
+
# Target ConfigNode to be connected by MLNode
self.__mn_target_config_node: TEndPoint = TEndPoint("127.0.0.1", 10710)
@@ -70,6 +79,24 @@ class MLNodeConfig(object):
def set_mn_model_storage_cache_size(self, mn_model_storage_cache_size:
int) -> None:
self.__mn_model_storage_cache_size = mn_model_storage_cache_size
+ def get_mn_mn_task_pool_size(self) -> int:
+ return self.__mn_task_pool_size
+
+ def set_mn_task_pool_size(self, mn_task_pool_size: int) -> None:
+ self.__mn_task_pool_size = mn_task_pool_size
+
+ def get_mn_tuning_trial_num(self) -> int:
+ return self.__mn_tuning_trial_num
+
+ def set_mn_tuning_trial_num(self, mn_tuning_trial_num: int) -> None:
+ self.__mn_tuning_trial_num = mn_tuning_trial_num
+
+ def get_mn_tuning_trial_concurrency(self) -> int:
+ return self.__mn_tuning_trial_concurrency
+
+ def set_mn_tuning_trial_concurrency(self, mn_tuning_trial_concurrency:
int) -> None:
+ self.__mn_tuning_trial_concurrency = mn_tuning_trial_concurrency
+
def get_mn_target_config_node(self) -> TEndPoint:
return self.__mn_target_config_node
@@ -114,6 +141,15 @@ class MLNodeDescriptor(object):
if file_configs.mn_model_storage_cache_size is not None:
self.__config.set_mn_model_storage_cache_size(file_configs.mn_model_storage_cache_size)
+ if file_configs.mn_task_pool_size is not None:
+
self.__config.set_mn_task_pool_size(file_configs.mn_task_pool_size)
+
+ if file_configs.mn_tuning_trial_num is not None:
+
self.__config.set_mn_tuning_trial_num(file_configs.mn_tuning_trial_num)
+
+ if file_configs.mn_tuning_trial_concurrency is not None:
+
self.__config.set_mn_tuning_trial_concurrency(file_configs.mn_tuning_trial_concurrency)
+
if file_configs.mn_target_config_node is not None:
self.__config.set_mn_target_config_node(file_configs.mn_target_config_node)
diff --git a/mlnode/iotdb/mlnode/exception.py
b/mlnode/iotdb/mlnode/data_access/enums.py
similarity index 59%
copy from mlnode/iotdb/mlnode/exception.py
copy to mlnode/iotdb/mlnode/data_access/enums.py
index 6307909a9ac..a5d415fc0e2 100644
--- a/mlnode/iotdb/mlnode/exception.py
+++ b/mlnode/iotdb/mlnode/data_access/enums.py
@@ -15,17 +15,32 @@
# specific language governing permissions and limitations
# under the License.
#
+from enum import Enum
-class _BaseError(Exception):
- """Base class for exceptions in this module."""
- pass
+class DatasetType(Enum):
+ TIMESERIES = "timeseries"
+ WINDOW = "window"
-class BadNodeUrlError(_BaseError):
- def __init__(self, node_url: str):
- self.message = "Bad node url: {}".format(node_url)
+ def __str__(self):
+ return self.value
+ def __hash__(self):
+ return hash(self.value)
-class ModelNotExistError(_BaseError):
- def __init__(self, file_path: str):
- self.message = "Model path: ({}) not exists".format(file_path)
+ def __eq__(self, other: str) -> bool:
+ return self.value == other
+
+
+class DataSourceType(Enum):
+ FILE = "file"
+ THRIFT = "thrift"
+
+ def __str__(self):
+ return self.value
+
+ def __hash__(self):
+ return hash(self.value)
+
+ def __eq__(self, other: str) -> bool:
+ return self.value == other
diff --git a/mlnode/iotdb/mlnode/data_access/factory.py
b/mlnode/iotdb/mlnode/data_access/factory.py
new file mode 100644
index 00000000000..ee8d57ffb45
--- /dev/null
+++ b/mlnode/iotdb/mlnode/data_access/factory.py
@@ -0,0 +1,105 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from typing import Dict, Tuple
+
+from torch.utils.data import Dataset
+
+from iotdb.mlnode.data_access.enums import DatasetType, DataSourceType
+from iotdb.mlnode.data_access.offline.dataset import (TimeSeriesDataset,
+ WindowDataset)
+from iotdb.mlnode.data_access.offline.source import (FileDataSource,
+ ThriftDataSource)
+from iotdb.mlnode.exception import BadConfigValueError, MissingConfigError
+
+
+def _dataset_common_config(**kwargs):
+ return {
+ 'time_embed': 'h',
+ **kwargs
+ }
+
+
+_dataset_default_config_dict = {
+ DatasetType.TIMESERIES: _dataset_common_config(),
+ DatasetType.WINDOW: _dataset_common_config(
+ input_len=96,
+ pred_len=96,
+ )
+}
+
+
+def create_forecast_dataset(
+ source_type,
+ dataset_type,
+ **kwargs,
+) -> Tuple[Dataset, Dict]:
+ """
+ Factory method for all support dataset
+ currently implement two types of PyTorch dataset: WindowDataset,
TimeSeriesDataset
+ support two types of offline data source: FileDataSource and
ThriftDataSource
+ for specific dataset/datasource configs, see _dataset_config in
`dataset.py` and `source.py`
+
+ Args:
+ dataset_type: see data_access/enums for available choices
+ source_type: see data_access/enums for available cho ices
+ kwargs: for specific dataset configs, see returned `dataset_config`
with kwargs=None
+
+ Returns:
+ dataset: torch.nn.Module
+ dataset_config: dict of dataset configurations
+ """
+ if source_type == DataSourceType.FILE:
+ if 'filename' not in kwargs.keys():
+ raise MissingConfigError('filename')
+ datasource = FileDataSource(kwargs['filename'])
+ elif source_type == DataSourceType.THRIFT:
+ if 'query_expressions' not in kwargs.keys():
+ raise MissingConfigError('query_expressions')
+ if 'query_filter' not in kwargs.keys():
+ raise MissingConfigError('query_filter')
+ datasource = ThriftDataSource(kwargs['query_expressions'],
kwargs['query_filter'])
+ else:
+ raise BadConfigValueError('source_type', source_type, f"It should be
one of {list(DataSourceType)}")
+
+ if dataset_type not in list(DatasetType):
+ raise BadConfigValueError('dataset_type', dataset_type, f'It should be
one of {list(DatasetType)}')
+ dataset_config = _dataset_default_config_dict[dataset_type]
+
+ for k, v in kwargs.items():
+ if k in dataset_config.keys():
+ dataset_config[k] = v
+
+ if dataset_type == DatasetType.TIMESERIES:
+ dataset = TimeSeriesDataset(datasource, **dataset_config)
+ elif dataset_type == DatasetType.WINDOW:
+ dataset = WindowDataset(datasource, **dataset_config)
+ else:
+ raise BadConfigValueError('dataset_type', dataset_type, f'It should be
one of {list(DatasetType)}')
+
+ if 'input_vars' in kwargs.keys() and dataset.get_variable_num() !=
kwargs['input_vars']:
+ raise BadConfigValueError('input_vars', kwargs['input_vars'],
+ f'Variable number of fetched data should be
consistent with '
+ f'input_vars, but got:
{dataset.get_variable_num()}')
+
+ data_config = dataset_config.copy()
+ data_config['input_vars'] = dataset.get_variable_num()
+ data_config['output_vars'] = dataset.get_variable_num()
+ data_config['source_type'] = str(source_type)
+ data_config['dataset_type'] = str(dataset_type)
+
+ return dataset, data_config
diff --git a/mlnode/iotdb/mlnode/data_access/offline/dataset.py
b/mlnode/iotdb/mlnode/data_access/offline/dataset.py
new file mode 100644
index 00000000000..a1b11ccd861
--- /dev/null
+++ b/mlnode/iotdb/mlnode/data_access/offline/dataset.py
@@ -0,0 +1,99 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from typing import Tuple
+
+import numpy as np
+from torch.utils.data import Dataset
+
+from iotdb.mlnode.data_access.offline.source import DataSource
+from iotdb.mlnode.data_access.utils.timefeatures import time_features
+
+
+class TimeSeriesDataset(Dataset):
+ """
+ Build Row-by-Row dataset (with each element as multivariable time series at
+ the same time and correponding timestamp embedding)
+
+ Args:
+ data_source: the whole multivariate time series for a while
+ time_embed: embedding frequency, see `utils/timefeatures.py` for more
detail
+
+ Returns:
+ Random accessible dataset
+ """
+
+ def __init__(self, data_source: DataSource, time_embed: str = 'h'):
+ self.time_embed = time_embed
+ self.data = data_source.get_data()
+ self.data_stamp = time_features(data_source.get_timestamp(),
time_embed=self.time_embed).transpose(1, 0)
+ self.n_vars = self.data.shape[-1]
+
+ def get_variable_num(self) -> int:
+ return self.n_vars # number of series in data_source
+
+ def __getitem__(self, index) -> Tuple[np.ndarray, np.ndarray]:
+ seq = self.data[index]
+ seq_t = self.data_stamp[index]
+ return seq, seq_t
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+
+class WindowDataset(TimeSeriesDataset):
+ """
+ Build Windowed dataset (with each element as multivariable time series
+ with a sliding window and corresponding timestamps embedding),
+ the sliding step is one unit in give data source
+
+ Args:
+ data_source: the whole multivariate time series for a while
+ time_embed: embedding frequency, see `utils/timefeatures.py` for more
detail
+ input_len: input window size (unit) [1, 2, ... I]
+ pred_len: output window size (unit) right after the input window [I+1,
I+2, ... I+P]
+
+ Returns:
+ Random accessible dataset
+ """
+
+ def __init__(self,
+ data_source: DataSource = None,
+ input_len: int = 96,
+ pred_len: int = 96,
+ time_embed: str = 'h'):
+ self.input_len = input_len
+ self.pred_len = pred_len
+ super(WindowDataset, self).__init__(data_source, time_embed)
+ if input_len > self.data.shape[0]:
+ raise RuntimeError('input_len should not be larger than the number
of time series points')
+ if pred_len > self.data.shape[0]:
+ raise RuntimeError('pred_len should not be larger than the number
of time series points')
+
+ def __getitem__(self, index) -> Tuple[np.ndarray, np.ndarray, np.ndarray,
np.ndarray]:
+ s_begin = index
+ s_end = s_begin + self.input_len
+ r_begin = s_end
+ r_end = s_end + self.pred_len
+ seq_x = self.data[s_begin:s_end]
+ seq_y = self.data[r_begin:r_end]
+ seq_x_t = self.data_stamp[s_begin:s_end]
+ seq_y_t = self.data_stamp[r_begin:r_end]
+ return seq_x, seq_y, seq_x_t, seq_y_t
+
+ def __len__(self) -> int:
+ return len(self.data) - self.input_len - self.pred_len + 1
diff --git a/mlnode/iotdb/mlnode/data_access/offline/source.py
b/mlnode/iotdb/mlnode/data_access/offline/source.py
new file mode 100644
index 00000000000..4f29241913f
--- /dev/null
+++ b/mlnode/iotdb/mlnode/data_access/offline/source.py
@@ -0,0 +1,82 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from typing import List
+
+import numpy as np
+import pandas as pd
+
+from iotdb.mlnode.client import client_manager
+
+
+class DataSource(object):
+ """
+ Pre-fetched in multi-variate time series in memory
+
+ Methods:
+ get_data: returns self.data, the time series value (Numpy.2DArray)
+ get_timestamp: returns self.timestamp, the aligned timestamp value
+ """
+
+ def __init__(self):
+ self.data = None
+ self.timestamp = None
+ self._read_data()
+
+ def _read_data(self):
+ raise NotImplementedError
+
+ def get_data(self) -> np.ndarray:
+ return self.data
+
+ def get_timestamp(self) -> np.ndarray:
+ return self.timestamp
+
+
+class FileDataSource(DataSource):
+ def __init__(self, filename: str = None):
+ self.filename = filename
+ super(FileDataSource, self).__init__()
+
+ def _read_data(self) -> None:
+ try:
+ raw_data = pd.read_csv(self.filename)
+ except Exception:
+ raise RuntimeError(f'Fail to load data with filename:
{self.filename}')
+ cols_data = raw_data.columns[1:]
+ self.data = raw_data[cols_data].values
+ self.timestamp = pd.to_datetime(raw_data[raw_data.columns[0]].values)
+
+
+class ThriftDataSource(DataSource):
+ def __init__(self, query_expressions: List = None, query_filter: str =
None):
+ self.query_expressions = query_expressions
+ self.query_filter = query_filter
+ super(ThriftDataSource, self).__init__()
+
+ def _read_data(self) -> None:
+ try:
+ data_client = client_manager.borrow_data_node_client()
+ except Exception:
+ raise RuntimeError('Fail to establish connection with DataNode')
+
+ raw_data = data_client.fetch_timeseries(self.query_expressions,
self.query_filter)
+
+ cols_data = raw_data.columns[1:]
+ self.data = raw_data[cols_data].values
+ self.timestamp = pd.to_datetime(raw_data[raw_data.columns[0]].values,
unit='ms', utc=True) \
+ .tz_convert('Asia/Shanghai') # for iotdb
diff --git a/mlnode/iotdb/mlnode/exception.py
b/mlnode/iotdb/mlnode/data_access/utils/__init__.py
similarity index 67%
copy from mlnode/iotdb/mlnode/exception.py
copy to mlnode/iotdb/mlnode/data_access/utils/__init__.py
index 6307909a9ac..2a1e720805f 100644
--- a/mlnode/iotdb/mlnode/exception.py
+++ b/mlnode/iotdb/mlnode/data_access/utils/__init__.py
@@ -15,17 +15,3 @@
# specific language governing permissions and limitations
# under the License.
#
-
-class _BaseError(Exception):
- """Base class for exceptions in this module."""
- pass
-
-
-class BadNodeUrlError(_BaseError):
- def __init__(self, node_url: str):
- self.message = "Bad node url: {}".format(node_url)
-
-
-class ModelNotExistError(_BaseError):
- def __init__(self, file_path: str):
- self.message = "Model path: ({}) not exists".format(file_path)
diff --git a/mlnode/iotdb/mlnode/data_access/utils/timefeatures.py
b/mlnode/iotdb/mlnode/data_access/utils/timefeatures.py
new file mode 100644
index 00000000000..ecd6784ca4a
--- /dev/null
+++ b/mlnode/iotdb/mlnode/data_access/utils/timefeatures.py
@@ -0,0 +1,171 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from typing import List
+
+import numpy as np
+import pandas as pd
+from pandas.tseries import offsets
+from pandas.tseries.frequencies import to_offset
+
+
+class TimeFeature:
+ def __init__(self):
+ pass
+
+ def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+ pass
+
+ def __repr__(self):
+ return self.__class__.__name__ + '()'
+
+
+class SecondOfMinute(TimeFeature):
+ """Minute of hour encoded as value between [-0.5, 0.5]"""
+
+ def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+ return index.second / 59.0 - 0.5
+
+
+class MinuteOfHour(TimeFeature):
+ """Minute of hour encoded as value between [-0.5, 0.5]"""
+
+ def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+ return index.minute / 59.0 - 0.5
+
+
+class HourOfDay(TimeFeature):
+ """Hour of day encoded as value between [-0.5, 0.5]"""
+
+ def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+ return index.hour / 23.0 - 0.5
+
+
+class DayOfWeek(TimeFeature):
+ """Hour of day encoded as value between [-0.5, 0.5]"""
+
+ def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+ return index.dayofweek / 6.0 - 0.5
+
+
+class DayOfMonth(TimeFeature):
+ """Day of month encoded as value between [-0.5, 0.5]"""
+
+ def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+ return (index.day - 1) / 30.0 - 0.5
+
+
+class DayOfYear(TimeFeature):
+ """Day of year encoded as value between [-0.5, 0.5]"""
+
+ def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+ return (index.dayofyear - 1) / 365.0 - 0.5
+
+
+class MonthOfYear(TimeFeature):
+ """Month of year encoded as value between [-0.5, 0.5]"""
+
+ def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+ return (index.month - 1) / 11.0 - 0.5
+
+
+class WeekOfYear(TimeFeature):
+ """Week of year encoded as value between [-0.5, 0.5]"""
+
+ def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+ return (index.isocalendar().week - 1) / 52.0 - 0.5
+
+
+def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
+ """
+ Embedding timestamp by given frequency string
+ Args:
+ freq_str: frequency string of the form [multiple][granularity] such as
'12H', '5min', '1D' etc.
+ Returns:
+ a list of time features that will be appropriate for the given
frequency string.
+ """
+
+ features_by_offsets = {
+ offsets.YearEnd: [],
+ offsets.QuarterEnd: [MonthOfYear],
+ offsets.MonthEnd: [MonthOfYear],
+ offsets.Week: [DayOfMonth, WeekOfYear],
+ offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
+ offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
+ offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
+ offsets.Minute: [
+ MinuteOfHour,
+ HourOfDay,
+ DayOfWeek,
+ DayOfMonth,
+ DayOfYear,
+ ],
+ offsets.Second: [
+ SecondOfMinute,
+ MinuteOfHour,
+ HourOfDay,
+ DayOfWeek,
+ DayOfMonth,
+ DayOfYear,
+ ],
+ }
+
+ try:
+ offset = to_offset(freq_str)
+
+ for offset_type, feature_classes in features_by_offsets.items():
+ if isinstance(offset, offset_type):
+ return [cls() for cls in feature_classes]
+ except ValueError:
+ supported_freq_msg = f'''
+ Unsupported time embedding frequency ({freq_str})
+ The following frequencies are supported (case-insensitive):
+ Y - yearly
+ alias: A
+ M - monthly
+ W - weekly
+ D - daily
+ B - business days
+ H - hourly
+ T - minutely
+ alias: min
+ S - secondly
+ '''
+ raise RuntimeError(supported_freq_msg)
+
+
+def time_features(dates, time_embed='h'):
+ return np.vstack([feat(dates) for feat in
time_features_from_frequency_str(time_embed)])
+
+
+def data_transform(data_raw: pd.DataFrame, freq='h'):
+ """
+ data: dataframe, column 0 is the time stamp
+ """
+ columns = data_raw.columns
+ data = data_raw[columns[1:]]
+ data_stamp = data_raw[columns[0]]
+ return data.values, data_stamp
+
+
+def timestamp_transform(timestamp_raw: pd.DataFrame, freq='h'):
+ """
+ """
+ timestamp = pd.to_datetime(timestamp_raw.values.squeeze(), unit='ms',
utc=True).tz_convert('Asia/Shanghai')
+ timestamp = time_features(timestamp, freq=freq)
+ timestamp = timestamp.transpose(1, 0)
+ return timestamp
diff --git a/mlnode/iotdb/mlnode/exception.py b/mlnode/iotdb/mlnode/exception.py
index 6307909a9ac..47edb95eb43 100644
--- a/mlnode/iotdb/mlnode/exception.py
+++ b/mlnode/iotdb/mlnode/exception.py
@@ -18,7 +18,11 @@
class _BaseError(Exception):
"""Base class for exceptions in this module."""
- pass
+ def __init__(self):
+ self.message = None
+
+ def __str__(self) -> str:
+ return self.message
class BadNodeUrlError(_BaseError):
@@ -28,4 +32,19 @@ class BadNodeUrlError(_BaseError):
class ModelNotExistError(_BaseError):
def __init__(self, file_path: str):
- self.message = "Model path: ({}) not exists".format(file_path)
+ self.message = "Model path is not exists: {} ".format(file_path)
+
+
+class BadConfigValueError(_BaseError):
+ def __init__(self, config_name: str, config_value, hint: str = ''):
+ self.message = "Bad value ({0}) for config ({1}).
{2}".format(config_value, config_name, hint)
+
+
+class MissingConfigError(_BaseError):
+ def __init__(self, config_name: str):
+ self.message = "Missing config: {}".format(config_name)
+
+
+class WrongTypeConfigError(_BaseError):
+ def __init__(self, config_name: str, expected_type: str):
+ self.message = "Wrong type for config: {0}, expected:
{1}".format(config_name, expected_type)
diff --git a/mlnode/iotdb/mlnode/handler.py b/mlnode/iotdb/mlnode/handler.py
index 78021420d43..d4b4ca3b034 100644
--- a/mlnode/iotdb/mlnode/handler.py
+++ b/mlnode/iotdb/mlnode/handler.py
@@ -15,7 +15,14 @@
# specific language governing permissions and limitations
# under the License.
#
+
+from iotdb.mlnode.config import descriptor
from iotdb.mlnode.constant import TSStatusCode
+from iotdb.mlnode.data_access.factory import create_forecast_dataset
+from iotdb.mlnode.log import logger
+from iotdb.mlnode.parser import parse_forecast_request, parse_training_request
+from iotdb.mlnode.process.manager import TaskManager
+from iotdb.mlnode.serde import convert_to_binary
from iotdb.mlnode.storage import model_storage
from iotdb.mlnode.util import get_status
from iotdb.thrift.mlnode import IMLNodeRPCService
@@ -26,19 +33,53 @@ from iotdb.thrift.mlnode.ttypes import
(TCreateTrainingTaskReq,
class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
def __init__(self):
- pass
+ self.__task_manager =
TaskManager(pool_size=descriptor.get_config().get_mn_mn_task_pool_size())
def deleteModel(self, req: TDeleteModelReq):
try:
model_storage.delete_model(req.modelId)
return get_status(TSStatusCode.SUCCESS_STATUS)
except Exception as e:
+ logger.warn(e)
return get_status(TSStatusCode.MLNODE_INTERNAL_ERROR, str(e))
def createTrainingTask(self, req: TCreateTrainingTaskReq):
- return get_status(TSStatusCode.SUCCESS_STATUS)
+ task = None
+ try:
+ # parse request, check required config and config type
+ data_config, model_config, task_config =
parse_training_request(req)
+ # create dataset & check data config legitimacy
+ dataset, data_config = create_forecast_dataset(**data_config)
+
+ model_config['input_vars'] = data_config['input_vars']
+ model_config['output_vars'] = data_config['output_vars']
+
+ # create task & check task config legitimacy
+ task = self.__task_manager.create_training_task(dataset,
data_config, model_config, task_config)
+
+ return get_status(TSStatusCode.SUCCESS_STATUS)
+ except Exception as e:
+ logger.warn(e)
+ return get_status(TSStatusCode.MLNODE_INTERNAL_ERROR, str(e))
+ finally:
+ # submit task stage & check resource and decide pending/start
+ self.__task_manager.submit_training_task(task)
def forecast(self, req: TForecastReq):
- status = get_status(TSStatusCode.SUCCESS_STATUS)
- forecast_result = b'forecast result'
- return TForecastResp(status, forecast_result)
+ model_path, data, pred_length = parse_forecast_request(req)
+ model, model_configs = model_storage.load_model(model_path)
+ task_configs = {'pred_len': pred_length}
+ try:
+ task = self.__task_manager.create_forecast_task(
+ task_configs,
+ model_configs,
+ data,
+ model_path
+ )
+ # submit task stage & check resource and decide pending/start
+ forecast_result =
convert_to_binary(self.__task_manager.submit_forecast_task(task))
+ resp = TForecastResp(get_status(TSStatusCode.SUCCESS_STATUS),
forecast_result)
+ return resp
+ except Exception as e:
+ logger.warn(e)
+ return get_status(TSStatusCode.MLNODE_INTERNAL_ERROR, str(e))
diff --git a/mlnode/iotdb/mlnode/parser.py b/mlnode/iotdb/mlnode/parser.py
new file mode 100644
index 00000000000..34bb7eb65a9
--- /dev/null
+++ b/mlnode/iotdb/mlnode/parser.py
@@ -0,0 +1,228 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+import argparse
+import re
+from typing import Dict, List, Tuple
+
+from iotdb.mlnode.algorithm.enums import ForecastTaskType
+from iotdb.mlnode.data_access.enums import DatasetType, DataSourceType
+from iotdb.mlnode.exception import MissingConfigError, WrongTypeConfigError
+from iotdb.mlnode.serde import convert_to_df
+from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq, TForecastReq
+
+
+class _ConfigParser(argparse.ArgumentParser):
+ """
+ A parser for parsing configs from configs: dict
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def parse_configs(self, configs) -> Dict:
+ """
+ Parse configs from a dict
+ Args:configs: a dict of all configs which contains all required
arguments
+ Returns: a dict of parsed configs
+ """
+ args = self.parse_dict(configs)
+ return vars(self.parse_known_args(args)[0])
+
+ @staticmethod
+ def parse_dict(config_dict: Dict) -> List:
+ """
+ Parse a dict of configs to a list of arguments
+ Args:config_dict: a dict of configs
+ Returns: a list of arguments which can be parsed by argparse
+ """
+ args = []
+ for k, v in config_dict.items():
+ if v is None:
+ continue
+ args.append("--{}".format(k))
+ if isinstance(v, str) and re.match(r'^\[(.*)]$', v):
+ v = eval(v)
+ v = [str(i) for i in v]
+ args.extend(v)
+ elif isinstance(v, list):
+ args.extend([str(i) for i in v])
+ elif isinstance(v, bool):
+ args.append(str(v).lower())
+ else:
+ args.append(v)
+ return args
+
+ def error(self, message: str):
+ """
+ Override the error method to raise exceptions instead of exiting
+ """
+ if message.startswith('the following arguments are required:'):
+ missing_arg = re.findall(r': --(\w+)', message)[0]
+ raise MissingConfigError(missing_arg)
+ elif re.match(r'argument --\w+: invalid \w+ value:', message):
+ argument = re.findall(r'argument --(\w+):', message)[0]
+ expected_type = re.findall(r'invalid (\w+) value:', message)[0]
+ raise WrongTypeConfigError(argument, expected_type)
+ else:
+ raise Exception(message)
+
+
+def str2bool(value):
+ if value.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ elif value.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ else:
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+""" Argument description:
+ - query_expressions: query expressions
+ - query_filter: query filter
+ - source_type: source type
+ - filename: filename
+ - dataset_type: dataset type
+ - time_embed: freq for time features encoding
+ - input_len: input sequence length
+ - pred_len: prediction sequence length
+ - input_vars: number of input variables
+ - output_vars: number of output variables
+"""
+_data_config_parser = _ConfigParser()
+_data_config_parser.add_argument('--source_type',
+ type=DataSourceType,
+ default=DataSourceType.THRIFT,
+ choices=list(DataSourceType))
+_data_config_parser.add_argument('--dataset_type',
+ type=DatasetType,
+ default=DatasetType.WINDOW,
+ choices=list(DatasetType))
+_data_config_parser.add_argument('--filename', type=str, default='')
+_data_config_parser.add_argument('--query_expressions', type=str, nargs='*',
default=[])
+_data_config_parser.add_argument('--query_filter', type=str, default=None)
+_data_config_parser.add_argument('--time_embed', type=str, default='h')
+_data_config_parser.add_argument('--input_len', type=int, default=96)
+_data_config_parser.add_argument('--pred_len', type=int, default=96)
+
+""" Argument description:
+ - model_name: model name
+ - input_len: input sequence length
+ - pred_len: prediction sequence length
+ - input_vars: number of input variables
+ - output_vars: number of output variables
+ - task_type: task type, options:[M, S, MS];
+ M:multivariate predict multivariate,
+ S:univariate predict univariate,
+ MS:multivariate predict univariate'
+ - kernel_size: kernel size
+ - block_type: block type
+ - d_model: dimension of feature in model
+ - inner_layers: number of inner layers
+ - outer_layers: number of outer layers
+"""
+_model_config_parser = _ConfigParser()
+_model_config_parser.add_argument('--model_name', type=str, required=True)
+_model_config_parser.add_argument('--input_len', type=int, default=96)
+_model_config_parser.add_argument('--pred_len', type=int, default=96)
+_model_config_parser.add_argument('--forecast_task_type',
+ type=ForecastTaskType,
+ default=ForecastTaskType.ENDOGENOUS,
+ choices=list(ForecastTaskType))
+_model_config_parser.add_argument('--kernel_size', type=int, default=25)
+_model_config_parser.add_argument('--block_type', type=str, default='generic')
+_model_config_parser.add_argument('--d_model', type=int, default=128)
+_model_config_parser.add_argument('--inner_layers', type=int, default=4)
+_model_config_parser.add_argument('--outer_layers', type=int, default=4)
+
+""" Argument description:
+ - model_id: model id
+ - tuning: whether to tune hyperparameters
+ - task_type: task type, options:[M, S, MS]; M:multivariate predict
multivariate, S:univariate predict univariate,
+ MS:multivariate predict univariate'
+ - task_class: task class
+ - input_len: input sequence length
+ - pred_len: prediction sequence length
+ - input_vars: number of input variables
+ - output_vars: number of output variables
+ - learning_rate: learning rate
+ - batch_size: batch size
+ - num_workers: number of workers
+ - epochs: number of epochs
+ - use_gpu: whether to use gpu
+ - use_multi_gpu: whether to use multi-gpu
+ - devices: devices to use
+ - metric_names: metric to use
+"""
+_task_config_parser = _ConfigParser()
+_task_config_parser.add_argument('--task_class', type=str, required=True)
+_task_config_parser.add_argument('--model_id', type=str, required=True)
+_task_config_parser.add_argument('--tuning', type=str2bool, default=False)
+_task_config_parser.add_argument('--forecast_task_type',
+ type=ForecastTaskType,
+ default=ForecastTaskType.ENDOGENOUS,
+ choices=list(ForecastTaskType))
+_task_config_parser.add_argument('--input_len', type=int, default=96)
+_task_config_parser.add_argument('--pred_len', type=int, default=96)
+_task_config_parser.add_argument('--learning_rate', type=float, default=0.0001)
+_task_config_parser.add_argument('--batch_size', type=int, default=32)
+_task_config_parser.add_argument('--num_workers', type=int, default=0)
+_task_config_parser.add_argument('--epochs', type=int, default=10)
+_task_config_parser.add_argument('--use_gpu', type=str2bool, default=False)
+# _task_config_parser.add_argument('--gpu', type=int, default=0)
+# _task_config_parser.add_argument('--use_multi_gpu', type=str2bool,
default=False)
+# _task_config_parser.add_argument('--devices', type=int, nargs='+',
default=[0])
+_task_config_parser.add_argument('--metric_names', type=str, nargs='+',
default=['MSE'])
+
+
+def parse_training_request(req: TCreateTrainingTaskReq) -> Tuple[Dict, Dict,
Dict]:
+ """
+ Parse TCreateTrainingTaskReq with given yaml template
+ Args:
+ req: TCreateTrainingTaskReq
+ Returns:
+ data_config: configurations related to data
+ model_config: configurations related to model
+ task_config: configurations related to task
+ """
+ config = req.modelConfigs
+ config.update(model_name=config['model_type'])
+ config.update(task_class=config['model_task'])
+ config.update(model_id=req.modelId)
+ config.update(tuning=req.isAuto)
+ config.update(query_expressions=req.queryExpressions)
+ config.update(query_filter=req.queryFilter)
+
+ data_config = _data_config_parser.parse_configs(config)
+ model_config = _model_config_parser.parse_configs(config)
+ task_config = _task_config_parser.parse_configs(config)
+ return data_config, model_config, task_config
+
+
+def parse_forecast_request(req: TForecastReq):
+ model_path = req.modelPath
+ column_name_list = req.inputColumnNameList
+ column_type_list = req.inputTypeList
+ ts_dataset = req.inputData
+ pred_len = req.predictLength
+
+ data = convert_to_df(column_name_list, column_type_list, None,
[ts_dataset])
+ time_stamp, data = data[data.columns[0:1]], data[data.columns[1:]]
+ full_data = (data, time_stamp)
+ return model_path, full_data, pred_len
diff --git a/mlnode/iotdb/mlnode/exception.py
b/mlnode/iotdb/mlnode/process/__init__.py
similarity index 67%
copy from mlnode/iotdb/mlnode/exception.py
copy to mlnode/iotdb/mlnode/process/__init__.py
index 6307909a9ac..2a1e720805f 100644
--- a/mlnode/iotdb/mlnode/exception.py
+++ b/mlnode/iotdb/mlnode/process/__init__.py
@@ -15,17 +15,3 @@
# specific language governing permissions and limitations
# under the License.
#
-
-class _BaseError(Exception):
- """Base class for exceptions in this module."""
- pass
-
-
-class BadNodeUrlError(_BaseError):
- def __init__(self, node_url: str):
- self.message = "Bad node url: {}".format(node_url)
-
-
-class ModelNotExistError(_BaseError):
- def __init__(self, file_path: str):
- self.message = "Model path: ({}) not exists".format(file_path)
diff --git a/mlnode/iotdb/mlnode/process/manager.py
b/mlnode/iotdb/mlnode/process/manager.py
new file mode 100644
index 00000000000..5f40efd3f8f
--- /dev/null
+++ b/mlnode/iotdb/mlnode/process/manager.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import sys
+import signal
+import psutil
+
+import multiprocessing as mp
+from typing import Dict, Union
+
+import pandas as pd
+from torch.utils.data import Dataset
+from subprocess import call
+
+from iotdb.mlnode.log import logger
+from iotdb.mlnode.process.task import (ForecastingInferenceTask,
+ ForecastingSingleTrainingTask,
+ ForecastingTuningTrainingTask)
+
+
+class TaskManager(object):
+ def __init__(self, pool_size: int):
+ """
+ Args:
+ pool_size: specify the maximum process number of the process pool
+
+ __shared_resource_manager: a manager that manage resources shared
between processes
+ __pid_info: a map shared between processes, can be used to find the
pid with model_id and trial_id
+ __training_process_pool: a multiprocessing process pool
+ """
+ self.__shared_resource_manager = mp.Manager()
+ self.__pid_info = self.__shared_resource_manager.dict()
+ self.__training_process_pool = mp.Pool(pool_size)
+ self.__inference_process_pool = mp.Pool(pool_size)
+
+ def create_training_task(self,
+ dataset: Dataset,
+ data_configs: Dict,
+ model_configs: Dict,
+ task_configs: Dict):
+ """
+
+ Args:
+ dataset: a torch dataset to be used for training
+ data_configs: dict of data configurations
+ model_configs: dict of model configurations
+ task_configs: dict of task configurations
+
+ Returns:
+ task: a training task for forecasting, which can be submitted to
self.__training_process_pool
+ """
+ model_id = task_configs['model_id']
+ if task_configs['tuning']:
+ task = ForecastingTuningTrainingTask(
+ task_configs,
+ model_configs,
+ self.__pid_info,
+ data_configs,
+ dataset,
+ model_id,
+ )
+ else:
+ task = ForecastingSingleTrainingTask(
+ task_configs,
+ model_configs,
+ self.__pid_info,
+ data_configs,
+ dataset,
+ model_id,
+ )
+ return task
+
+ def submit_training_task(self, task: Union[ForecastingTuningTrainingTask,
ForecastingSingleTrainingTask]) -> None:
+ if task is not None:
+ self.__training_process_pool.apply_async(task, args=())
+ logger.info(f'Task: ({task.model_id}) - Training process submitted
successfully')
+
+ def create_forecast_task(self,
+ task_configs,
+ model_configs,
+ data,
+ model_path) -> ForecastingInferenceTask:
+ task = ForecastingInferenceTask(
+ task_configs,
+ model_configs,
+ self.__pid_info,
+ data,
+ model_path
+ )
+ return task
+
+ def submit_forecast_task(self, task: ForecastingInferenceTask) ->
pd.DataFrame:
+ read_pipe, send_pipe = mp.Pipe()
+ if task is not None:
+ self.__inference_process_pool.apply_async(task, args=(send_pipe,))
+ logger.info('Forecasting process submitted successfully')
+ return read_pipe.recv()
+
+ def kill_task(self, model_id):
+ """
+ Kill the process by pid, will check whether the pid is training or
inference process
+ """
+ pid = self.__pid_info[model_id]
+ if sys.platform == 'win32':
+ try:
+ process = psutil.Process(pid=pid)
+ process.send_signal(signal.CTRL_BREAK_EVENT)
+ except psutil.NoSuchProcess:
+ print(f'Tried to kill process (pid = {pid}), '
+ f'but the process does not exist.')
+ else:
+ cmds = ['kill', str(pid)]
+ call(cmds)
diff --git a/mlnode/iotdb/mlnode/process/task.py
b/mlnode/iotdb/mlnode/process/task.py
new file mode 100644
index 00000000000..afdef68f4ea
--- /dev/null
+++ b/mlnode/iotdb/mlnode/process/task.py
@@ -0,0 +1,294 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+from abc import abstractmethod
+from multiprocessing.connection import Connection
+from typing import Dict, Tuple
+
+import numpy as np
+import optuna
+import pandas as pd
+import torch
+from torch.utils.data import Dataset
+
+from iotdb.mlnode.algorithm.factory import create_forecast_model
+from iotdb.mlnode.client import client_manager
+from iotdb.mlnode.config import descriptor
+from iotdb.mlnode.log import logger
+from iotdb.mlnode.process.trial import ForecastingTrainingTrial
+from iotdb.mlnode.storage import model_storage
+from iotdb.thrift.common.ttypes import TrainingState
+
+
+class ForestingTrainingObjective:
+ """
+ A class which serve as a function, should accept trial as args
+ and return the optimization objective.
+ Optuna will try to minimize the objective.
+ """
+
+ def __init__(
+ self,
+ trial_configs: Dict,
+ model_configs: Dict,
+ dataset: Dataset,
+ # pid_info: Dict
+ ):
+ self.trial_configs = trial_configs
+ self.model_configs = model_configs
+ self.dataset = dataset
+ # self.pid_info = pid_info
+
+ def __call__(self, trial: optuna.Trial):
+ # TODO: decide which parameters to tune
+ trial_configs = self.trial_configs
+ trial_configs['learning_rate'] = trial.suggest_float("lr", 1e-7, 1e-1,
log=True)
+ trial_configs['trial_id'] = 'tid_' + str(trial._trial_id)
+ # TODO: check args
+ model, model_configs = create_forecast_model(**self.model_configs)
+ # self.pid_info[self.trial_configs['model_id']][trial._trial_id] =
os.getpid()
+ _trial = ForecastingTrainingTrial(trial_configs, model, model_configs,
self.dataset)
+ loss = _trial.start()
+ return loss
+
+
+class _BasicTask(object):
+ """
+ This class serve as a function, accepting configs and launch trials
+ according to the configs.
+ """
+
+ def __init__(
+ self,
+ task_configs: Dict,
+ model_configs: Dict,
+ pid_info: Dict
+ ):
+ """
+ Args:
+ task_configs:
+ model_configs:
+ pid_info:
+ """
+ self.pid_info = pid_info
+ self.task_configs = task_configs
+ self.model_configs = model_configs
+
+ @abstractmethod
+ def __call__(self):
+ raise NotImplementedError
+
+
+class _BasicTrainingTask(_BasicTask):
+ def __init__(
+ self,
+ task_configs: Dict,
+ model_configs: Dict,
+ pid_info: Dict,
+ data_configs: Dict,
+ dataset: Dataset,
+ ):
+ """
+ Args:
+ task_configs:
+ model_configs:
+ pid_info:
+ data_configs:
+ dataset:
+ """
+ super().__init__(task_configs, model_configs, pid_info)
+ self.data_configs = data_configs
+ self.dataset = dataset
+ self.confignode_client = client_manager.borrow_config_node_client()
+
+ @abstractmethod
+ def __call__(self):
+ raise NotImplementedError
+
+
+class _BasicInferenceTask(_BasicTask):
+ def __init__(
+ self,
+ task_configs: Dict,
+ model_configs: Dict,
+ pid_info: Dict,
+ data: Tuple,
+ model_path: str
+ ):
+ """
+ Args:
+ task_configs:
+ model_configs:
+ pid_info:
+ data:
+ model:
+ """
+ super().__init__(task_configs, model_configs, pid_info)
+ self.data = data
+ self.input_len = self.model_configs['input_len']
+ self.model_path = model_path
+
+ @abstractmethod
+ def __call__(self, pipe: Connection = None):
+ raise NotImplementedError
+
+ @abstractmethod
+ def data_align(self, *args):
+ raise NotImplementedError
+
+
+class ForecastingSingleTrainingTask(_BasicTrainingTask):
+ def __init__(
+ self,
+ task_configs: Dict,
+ model_configs: Dict,
+ pid_info: Dict,
+ data_configs: Dict,
+ dataset: Dataset,
+ model_id: str,
+ ):
+ """
+ Args:
+ task_configs: dict of task configurations
+ model_configs: dict of model configurations
+ pid_info: a map shared between processes, can be used to find the
pid with model_id and trial_id
+ data_configs: dict of data configurations
+ dataset: training dataset
+ """
+ super().__init__(task_configs, model_configs, pid_info, data_configs,
dataset)
+ self.model_id = model_id
+ self.default_trial_id = 'tid_0'
+ self.task_configs['trial_id'] = self.default_trial_id
+ model, model_configs = create_forecast_model(**model_configs)
+ self.trial = ForecastingTrainingTrial(task_configs, model,
model_configs, dataset)
+
+ def __call__(self):
+ try:
+ self.pid_info[self.model_id] = os.getpid()
+ self.trial.start()
+ self.confignode_client.update_model_state(self.model_id,
TrainingState.FINISHED, self.default_trial_id)
+ except Exception as e:
+ logger.warn(e)
+ self.confignode_client.update_model_state(self.model_id,
TrainingState.FAILED, self.default_trial_id)
+ raise e
+
+
+class ForecastingTuningTrainingTask(_BasicTrainingTask):
+ def __init__(
+ self,
+ task_configs: Dict,
+ model_configs: Dict,
+ pid_info: Dict,
+ data_configs: Dict,
+ dataset: Dataset,
+ model_id: str,
+ ):
+ """
+ Args:
+ task_configs: dict of task configurations
+ model_configs: dict of model configurations
+ pid_info: a map shared between processes, can be used to find the
pid with model_id and trial_id
+ data_configs: dict of data configurations
+ dataset: training dataset
+ """
+ super().__init__(task_configs, model_configs, pid_info, data_configs,
dataset)
+ self.model_id = model_id
+ self.study = optuna.create_study(direction='minimize')
+
+ def __call__(self):
+ self.pid_info[self.model_id] = os.getpid()
+ try:
+ self.study.optimize(ForestingTrainingObjective(
+ self.task_configs,
+ self.model_configs,
+ self.dataset),
+ n_trials=descriptor.get_config().get_mn_tuning_trial_num(),
+
n_jobs=descriptor.get_config().get_mn_tuning_trial_concurrency())
+ best_trial_id = 'tid_' + str(self.study.best_trial._trial_id)
+ self.confignode_client.update_model_state(self.model_id,
TrainingState.FINISHED, best_trial_id)
+ except Exception as e:
+ logger.warn(e)
+ self.confignode_client.update_model_state(self.model_id,
TrainingState.FAILED)
+ raise e
+
+
+class ForecastingInferenceTask(_BasicInferenceTask):
+ def __init__(
+ self,
+ task_configs: Dict,
+ model_configs: Dict,
+ pid_info: Dict,
+ data: Tuple,
+ model_path: str
+ ):
+ super().__init__(task_configs, model_configs, pid_info, data,
model_path)
+ self.pred_len = self.task_configs['pred_len']
+ self.model_pred_len = self.model_configs['pred_len']
+
+ def __call__(self, pipe: Connection = None):
+ self.model, _ = model_storage.load_model(self.model_path)
+ data, time_stamp = self.data
+ L, C = data.shape
+ time_stamp = pd.to_datetime(time_stamp.values[:, 0], unit='ms',
utc=True) \
+ .tz_convert('Asia/Shanghai') # for iotdb
+ data, time_stamp = self.data_align(data, time_stamp)
+ full_data, full_data_stamp = data, time_stamp
+ current_pred_len = 0
+ while current_pred_len < self.pred_len:
+ current_data = full_data[:, -self.input_len:, :]
+ current_data = torch.Tensor(current_data)
+ output_data = self.model(current_data).detach().numpy()
+ full_data = np.concatenate([full_data, output_data], axis=1)
+ current_pred_len += self.model_pred_len
+ full_data_stamp = self.generate_future_mark(full_data_stamp,
self.pred_len)
+ ret_data = pd.concat(
+ [pd.DataFrame(full_data_stamp.astype(np.int64)),
+ pd.DataFrame(full_data[0, -self.pred_len:,
:]).astype(np.double)], axis=1)
+ ret_data.columns = list(np.arange(0, C + 1))
+ pipe.send(ret_data)
+
+ def data_align(self, data: pd.DataFrame, data_stamp) -> Tuple[np.ndarray,
pd.DataFrame]:
+ """
+ data: L x C, DataFrame, suppose no batch dim
+ time_stamp: L x 1, DataFrame
+ """
+ data_stamp = pd.DataFrame(data_stamp)
+ assert len(data.shape) == 2, 'expect inference data to have two
dimensions'
+ assert len(data_stamp.shape) == 2, 'expect inference timestamps to be
shaped as [L, 1]'
+ time_deltas = data_stamp.diff().dropna()
+ mean_timedelta = time_deltas.mean()[0]
+ data = data.values
+ if data.shape[0] < self.input_len:
+ extra_len = self.input_len - data.shape[0]
+ data = np.concatenate([np.mean(data, axis=0,
keepdims=True).repeat(extra_len, axis=0), data], axis=0)
+ extrapolated_timestamp = pd.date_range(data_stamp[0][0] -
extra_len * mean_timedelta, periods=extra_len,
+ freq=mean_timedelta)
+ data_stamp = pd.concat([extrapolated_timestamp.to_frame(),
data_stamp])
+ else:
+ data = data[-self.input_len:, :]
+ data_stamp = data_stamp[-self.input_len:]
+ data = data[None, :] # add batch dim
+ return data, data_stamp
+
+ def generate_future_mark(self, data_stamp: pd.DataFrame, future_len: int)
-> pd.DatetimeIndex:
+ time_deltas = data_stamp.diff().dropna()
+ mean_timedelta = time_deltas.mean()[0]
+ extrapolated_timestamp = pd.date_range(data_stamp.values[0][0],
periods=future_len,
+ freq=mean_timedelta)
+ return extrapolated_timestamp[:, None]
diff --git a/mlnode/iotdb/mlnode/process/trial.py
b/mlnode/iotdb/mlnode/process/trial.py
new file mode 100644
index 00000000000..7fc34477686
--- /dev/null
+++ b/mlnode/iotdb/mlnode/process/trial.py
@@ -0,0 +1,263 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import time
+from abc import abstractmethod
+from typing import Dict, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader, Dataset
+
+from iotdb.mlnode.algorithm.metric import all_metrics, build_metrics
+from iotdb.mlnode.client import client_manager
+from iotdb.mlnode.log import logger
+from iotdb.mlnode.storage import model_storage
+from iotdb.thrift.common.ttypes import TrainingState
+
+
+def _parse_trial_config(**kwargs):
+ support_cfg = {
+ "batch_size": 32,
+ "learning_rate": 0.0001,
+ "epochs": 10,
+ "input_len": 96,
+ "pred_len": 96,
+ "num_workers": 0,
+ "use_gpu": False,
+ # "gpu": 0,
+ # "use_multi_gpu": False,
+ # "devices": [0],
+ "metric_names": ["MSE"],
+ "model_id": 'default',
+ "trial_id": 'default_trial'
+ }
+
+ trial_config = {}
+
+ for k, v in kwargs.items():
+ if k in support_cfg.keys():
+ if not isinstance(v, type(support_cfg[k])):
+ raise RuntimeError(
+ 'Trial config {} should have {} type, but got {}
instead'.format(k, type(support_cfg[k]).__name__,
+
type(v).__name__)
+ )
+ trial_config[k] = v
+
+ if trial_config['input_len'] <= 0:
+ raise RuntimeError(
+ 'Trial config input_len should be positive integer but got
{}'.format(trial_config['input_len'])
+ )
+
+ if trial_config['pred_len'] <= 0:
+ raise RuntimeError(
+ 'Trial config pred_len should be positive integer but got
{}'.format(trial_config['pred_len'])
+ )
+
+ for metric in trial_config['metric_names']:
+ if metric not in all_metrics:
+ raise RuntimeError(
+ f'Unknown metric type: ({metric}), which'
+ f' should be one of {all_metrics}'
+ )
+ return trial_config
+
+
+class BasicTrial(object):
+ def __init__(
+ self,
+ task_configs: Dict,
+ model: nn.Module,
+ model_configs: Dict,
+ dataset: Dataset
+ ):
+ self.trial_configs = task_configs
+ self.model_id = task_configs['model_id']
+ self.trial_id = task_configs['trial_id']
+ self.batch_size = task_configs['batch_size']
+ self.learning_rate = task_configs['learning_rate']
+ self.epochs = task_configs['epochs']
+ self.num_workers = task_configs['num_workers']
+ self.pred_len = task_configs['pred_len']
+ self.metric_names = task_configs['metric_names']
+ self.use_gpu = task_configs['use_gpu']
+ self.model = model
+ self.model_configs = model_configs
+
+ self.device = self.__acquire_device()
+ self.model = self.model.to(self.device)
+ self.dataset = dataset
+
+ def __acquire_device(self):
+ if self.use_gpu:
+ raise NotImplementedError
+ else:
+ device = torch.device('cpu')
+ return device
+
+ def _build_dataloader(self) -> DataLoader:
+ """
+ Returns:
+ training dataloader built with the dataset
+ """
+ return DataLoader(
+ self.dataset,
+ shuffle=True,
+ drop_last=True,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers
+ )
+
+ @abstractmethod
+ def start(self):
+ raise NotImplementedError
+
+
+class ForecastingTrainingTrial(BasicTrial):
+ def __init__(
+ self,
+ task_configs: Dict,
+ model: nn.Module,
+ model_configs: Dict,
+ dataset: Dataset,
+ ):
+ """
+ A training trial, accept all parameters needed and train a single
model.
+
+ Args:
+ trial_configs: dict of trial's configurations
+ model: torch.nn.Module
+ model_configs: dict of model's configurations
+ dataset: training dataset
+ **kwargs:
+ """
+ super(ForecastingTrainingTrial, self).__init__(task_configs, model,
model_configs, dataset)
+
+ self.dataloader = self._build_dataloader()
+ self.datanode_client = client_manager.borrow_data_node_client()
+ self.confignode_client = client_manager.borrow_config_node_client()
+ self.criterion = torch.nn.MSELoss()
+ self.optimizer = torch.optim.Adam(self.model.parameters(),
lr=self.learning_rate)
+ self.metrics_dict = build_metrics(self.metric_names)
+
+ def train(self, epoch: int) -> float:
+ self.model.train()
+ train_loss = []
+ epoch_time = time.time()
+
+ for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in
enumerate(self.dataloader):
+ self.optimizer.zero_grad()
+ batch_x = batch_x.float().to(self.device)
+ batch_y = batch_y.float().to(self.device)
+
+ batch_x_mark = batch_x_mark.float().to(self.device)
+ batch_y_mark = batch_y_mark.float().to(self.device)
+
+ # decoder input
+ dec_inp = torch.zeros_like(batch_y[:, -self.pred_len:, :]).float()
+ outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
+
+ outputs = outputs[:, -self.pred_len:]
+ batch_y = batch_y[:, -self.pred_len:]
+
+ loss = self.criterion(outputs, batch_y)
+ train_loss.append(loss.item())
+
+ if (i + 1) % 500 == 0:
+ logger.info('\titers: {0}, epoch: {1} | loss: {2:.7f}'
+ .format(i + 1, epoch + 1, loss.item()))
+
+ loss.backward()
+ self.optimizer.step()
+
+ train_loss = np.average(train_loss)
+ logger.info('Epoch: {0} cost time: {1} | Train Loss: {2:.7f}'
+ .format(epoch + 1, time.time() - epoch_time, train_loss))
+ return train_loss
+
+ def vali(self, epoch: int) -> Tuple[float, Dict]:
+ self.model.eval()
+ val_loss = []
+ metrics_value_dict = {name: [] for name in self.metric_names}
+ for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in
enumerate(self.dataloader):
+ batch_x = batch_x.float().to(self.device)
+ batch_y = batch_y.float().to(self.device)
+
+ batch_x_mark = batch_x_mark.float().to(self.device)
+ batch_y_mark = batch_y_mark.float().to(self.device)
+
+ # decoder input
+ dec_inp = torch.zeros_like(batch_y[:, -self.pred_len:, :]).float()
+ outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
+
+ outputs = outputs[:, -self.pred_len:]
+ batch_y = batch_y[:, -self.pred_len:]
+
+ loss = self.criterion(outputs, batch_y)
+ val_loss.append(loss.item())
+
+ for name in self.metric_names:
+ metric = self.metrics_dict[name]
+ value = metric(outputs.detach().cpu().numpy(),
batch_y.detach().cpu().numpy())
+ metrics_value_dict[name].append(value)
+
+ for name, value_list in metrics_value_dict.items():
+ metrics_value_dict[name] = np.average(value_list)
+
+ self.datanode_client.record_model_metrics(
+ model_id=self.model_id,
+ trial_id=self.trial_id,
+ metrics=list(metrics_value_dict.keys()),
+ values=list(metrics_value_dict.values())
+ )
+ val_loss = np.average(val_loss)
+ logger.info('Epoch: {0} Vali Loss: {1:.7f}'.format(epoch + 1,
val_loss))
+ return val_loss, metrics_value_dict
+
+ def start(self) -> float:
+ """
+ Start training with the specified parameters, save the best model and
report metrics to the db.
+ """
+ try:
+ self.confignode_client.update_model_state(self.model_id,
TrainingState.RUNNING)
+ best_loss = np.inf
+ best_metrics_dict = None
+ model_path = None
+ for epoch in range(self.epochs):
+ self.train(epoch)
+ val_loss, metrics_dict = self.vali(epoch)
+ if val_loss < best_loss:
+ best_loss = val_loss
+ best_metrics_dict = metrics_dict
+ model_path = model_storage.save_model(self.model,
+ self.model_configs,
+
model_id=self.model_id,
+
trial_id=self.trial_id)
+
+ logger.info(f'Trial: ({self.model_id}_{self.trial_id}) - Finished
with best model saved successfully')
+ model_info = {}
+ model_info.update(best_metrics_dict)
+ model_info.update(self.trial_configs)
+ model_info['model_path'] = model_path
+ self.confignode_client.update_model_info(self.model_id,
self.trial_id, model_info)
+ self.confignode_client.update_model_state(self.model_id,
TrainingState.FINISHED, self.trial_id)
+ return best_loss
+ except Exception as e:
+ logger.warn(e)
+ self.confignode_client.update_model_state(self.model_id,
TrainingState.FAILED)
+ raise e
diff --git a/mlnode/iotdb/mlnode/serde.py b/mlnode/iotdb/mlnode/serde.py
index 26860faf386..4b5e90c2554 100644
--- a/mlnode/iotdb/mlnode/serde.py
+++ b/mlnode/iotdb/mlnode/serde.py
@@ -15,15 +15,87 @@
# specific language governing permissions and limitations
# under the License.
#
+from enum import Enum
+
import numpy as np
import pandas as pd
-from iotdb.utils.IoTDBConstants import TSDataType
+
+class TSDataType(Enum):
+ BOOLEAN = 0
+ INT32 = 1
+ INT64 = 2
+ FLOAT = 3
+ DOUBLE = 4
+ TEXT = 5
+
+ # this method is implemented to avoid the issue reported by:
+ # https://bugs.python.org/issue30545
+ def __eq__(self, other) -> bool:
+ return self.value == other.value
+
+ def __hash__(self):
+ return self.value
+
+ def np_dtype(self):
+ return {
+ TSDataType.BOOLEAN: np.dtype(">?"),
+ TSDataType.FLOAT: np.dtype(">f4"),
+ TSDataType.DOUBLE: np.dtype(">f8"),
+ TSDataType.INT32: np.dtype(">i4"),
+ TSDataType.INT64: np.dtype(">i8"),
+ TSDataType.TEXT: np.dtype("str"),
+ }[self]
+
TIMESTAMP_STR = "Time"
START_INDEX = 2
+# convert dataFrame to tsBlock in binary
+def convert_to_binary(data_frame: pd.DataFrame):
+ data_shape = data_frame.shape
+ value_column_size = data_shape[1] - 1
+ position_count = data_shape[0]
+ keys = data_frame.keys()
+
+ # binary_res = [value_column_size.to_bytes(4, byteorder="big")]
+ binary = value_column_size.to_bytes(4, byteorder="big")
+
+ # all the tsDataType are double
+ for i in range(value_column_size):
+ binary += b'\x04'
+
+ # position count
+ binary += position_count.to_bytes(4, byteorder="big")
+
+ # column encoding
+ binary += b'\x02'
+ for i in range(value_column_size):
+ binary += b'\x02'
+
+ # write columns, the column in index 0 must be timeColumn
+ binary += bool.to_bytes(False, 1, byteorder="big")
+ for i in range(position_count):
+ value = data_frame[keys[0]][i]
+ if value.dtype.byteorder != '>':
+ value = value.byteswap()
+ binary += value.tobytes()
+
+ for i in range(value_column_size):
+ # the value can't be null
+ binary += bool.to_bytes(False, 1, byteorder="big")
+ col = data_frame[keys[i + 1]]
+ for j in range(position_count):
+ value = col[j]
+ if value.dtype.byteorder != '>':
+ value = value.byteswap()
+ binary += value.tobytes()
+
+ return binary
+
+
+# convert tsBlock in binary to dataFrame
def convert_to_df(name_list, type_list, name_index, binary_list):
column_name_list = [TIMESTAMP_STR]
column_type_list = [TSDataType.INT64]
@@ -135,7 +207,7 @@ def convert_to_df(name_list, type_list, name_index,
binary_list):
elif data_type == TSDataType.BOOLEAN:
tmp_array = np.full(total_length, np.nan, np.float32)
elif data_type == TSDataType.TEXT:
- tmp_array = np.full(total_length, None,
dtype=data_array.dtype)
+ tmp_array = np.full(total_length, np.nan,
dtype=data_array.dtype)
else:
raise Exception("Unsupported dataType in deserialization")
@@ -181,7 +253,7 @@ def convert_to_df(name_list, type_list, name_index,
binary_list):
return df
-# Serialized tsblock:
+# Serialized tsBlock:
#
+-------------+---------------+---------+------------+-----------+----------+
# | val col cnt | val col types | pos cnt | encodings | time col | val
col |
#
+-------------+---------------+---------+------------+-----------+----------+
@@ -199,10 +271,10 @@ def deserialize(buffer):
column_values = [None] * value_column_count
null_indicators = [None] * value_column_count
for i in range(value_column_count):
- column_value, nullIndicator, buffer = read_column(column_encodings[i +
1], buffer, data_types[i],
- position_count)
+ column_value, null_indicator, buffer = read_column(column_encodings[i
+ 1], buffer, data_types[i],
+ position_count)
column_values[i] = column_value
- null_indicators[i] = nullIndicator
+ null_indicators[i] = null_indicator
return time_column_values, column_values, null_indicators, position_count
@@ -230,11 +302,11 @@ def read_column_types(buffer, value_column_count):
data_types = []
for i in range(value_column_count):
res, buffer = read_byte_from_buffer(buffer)
- data_types.append(get_dataType(res))
+ data_types.append(get_data_type(res))
return data_types, buffer
-def get_dataType(value):
+def get_data_type(value):
if value == b'\x00':
return TSDataType.BOOLEAN
elif value == b'\x01':
@@ -247,8 +319,6 @@ def get_dataType(value):
return TSDataType.DOUBLE
elif value == b'\x05':
return TSDataType.TEXT
- elif value == b'\x06':
- return TSDataType.VECTOR
# Read ColumnEncodings
@@ -264,8 +334,8 @@ def read_column_encoding(buffer, size):
# Read Column
def deserialize_null_indicators(buffer, size):
- mayHaveNull, buffer = read_byte_from_buffer(buffer)
- if mayHaveNull != b'\x00':
+ may_have_null, buffer = read_byte_from_buffer(buffer)
+ if may_have_null != b'\x00':
return deserialize_from_boolean_array(buffer, size)
return None, buffer
@@ -278,8 +348,8 @@ def deserialize_null_indicators(buffer, size):
# +---------------+-----------------+-------------+
def read_time_column(buffer, size):
- nullIndicators, buffer = deserialize_null_indicators(buffer, size)
- if nullIndicators is None:
+ null_indicators, buffer = deserialize_null_indicators(buffer, size)
+ if null_indicators is None:
values, buffer = read_from_buffer(
buffer, size * 8
)
@@ -288,16 +358,16 @@ def read_time_column(buffer, size):
return values, buffer
-def read_INT64_column(buffer, data_type, position_count):
- nullIndicators, buffer = deserialize_null_indicators(buffer,
position_count)
- if nullIndicators is None:
+def read_int64_column(buffer, data_type, position_count):
+ null_indicators, buffer = deserialize_null_indicators(buffer,
position_count)
+ if null_indicators is None:
size = position_count
else:
- size = nullIndicators.count(False)
+ size = null_indicators.count(False)
if TSDataType.INT64 == data_type or TSDataType.DOUBLE == data_type:
values, buffer = read_from_buffer(buffer, size * 8)
- return values, nullIndicators, buffer
+ return values, null_indicators, buffer
else:
raise Exception("Invalid data type: " + data_type)
@@ -309,16 +379,16 @@ def read_INT64_column(buffer, data_type, position_count):
# | byte | list[byte] | list[int32] |
# +---------------+-----------------+-------------+
-def read_Int32_column(buffer, data_type, position_count):
- nullIndicators, buffer = deserialize_null_indicators(buffer,
position_count)
- if nullIndicators is None:
+def read_int32_column(buffer, data_type, position_count):
+ null_indicators, buffer = deserialize_null_indicators(buffer,
position_count)
+ if null_indicators is None:
size = position_count
else:
- size = nullIndicators.count(False)
+ size = null_indicators.count(False)
if TSDataType.INT32 == data_type or TSDataType.FLOAT == data_type:
values, buffer = read_from_buffer(buffer, size * 4)
- return values, nullIndicators, buffer
+ return values, null_indicators, buffer
else:
raise Exception("Invalid data type: " + data_type)
@@ -333,9 +403,9 @@ def read_Int32_column(buffer, data_type, position_count):
def read_byte_column(buffer, data_type, position_count):
if data_type != TSDataType.BOOLEAN:
raise Exception("Invalid data type: " + data_type)
- nullIndicators, buffer = deserialize_null_indicators(buffer,
position_count)
+ null_indicators, buffer = deserialize_null_indicators(buffer,
position_count)
res, buffer = deserialize_from_boolean_array(buffer, position_count)
- return res, nullIndicators, buffer
+ return res, null_indicators, buffer
def deserialize_from_boolean_array(buffer, size):
@@ -386,31 +456,31 @@ def deserialize_from_boolean_array(buffer, size):
def read_binary_column(buffer, data_type, position_count):
if data_type != TSDataType.TEXT:
raise Exception("Invalid data type: " + data_type)
- nullIndicators, buffer = deserialize_null_indicators(buffer,
position_count)
+ null_indicators, buffer = deserialize_null_indicators(buffer,
position_count)
- if nullIndicators is None:
+ if null_indicators is None:
size = position_count
else:
- size = nullIndicators.count(False)
+ size = null_indicators.count(False)
values = [None] * size
for i in range(size):
length, buffer = read_int_from_buffer(buffer)
res, buffer = read_from_buffer(buffer, length)
values[i] = res
- return values, nullIndicators, buffer
+ return values, null_indicators, buffer
def read_column(encoding, buffer, data_type, position_count):
if encoding == b'\x00':
return read_byte_column(buffer, data_type, position_count)
elif encoding == b'\x01':
- return read_Int32_column(buffer, data_type, position_count)
+ return read_int32_column(buffer, data_type, position_count)
elif encoding == b'\x02':
- return read_INT64_column(buffer, data_type, position_count)
+ return read_int64_column(buffer, data_type, position_count)
elif encoding == b'\x03':
return read_binary_column(buffer, data_type, position_count)
elif encoding == b'\x04':
- return read_runLength_column(buffer, data_type, position_count)
+ return read_run_length_column(buffer, data_type, position_count)
else:
raise Exception("Unsupported encoding: " + encoding)
@@ -422,11 +492,11 @@ def read_column(encoding, buffer, data_type,
position_count):
# | byte | list[byte] |
# +-----------+-------------------------+
-def read_runLength_column(buffer, data_type, position_count):
+def read_run_length_column(buffer, data_type, position_count):
encoding, buffer = read_byte_from_buffer(buffer)
- column, nullIndicators, buffer = read_column(encoding, buffer, data_type,
1)
+ column, null_indicators, buffer = read_column(encoding, buffer, data_type,
1)
- return repeat(column, data_type, position_count), nullIndicators *
position_count, buffer
+ return repeat(column, data_type, position_count), null_indicators *
position_count, buffer
def repeat(buffer, data_type, position_count):
diff --git a/mlnode/iotdb/mlnode/storage.py b/mlnode/iotdb/mlnode/storage.py
index 84d7dfd7ed1..4c038786174 100644
--- a/mlnode/iotdb/mlnode/storage.py
+++ b/mlnode/iotdb/mlnode/storage.py
@@ -19,6 +19,7 @@
import json
import os
import shutil
+import threading
from typing import Dict, Tuple
import torch
@@ -39,7 +40,7 @@ class ModelStorage(object):
except PermissionError as e:
logger.error(e)
raise e
-
+ self.lock = threading.RLock()
self.__model_cache =
lrucache(descriptor.get_config().get_mn_model_storage_cache_size())
def save_model(self,
@@ -56,19 +57,21 @@ class ModelStorage(object):
model_file_path = os.path.join(model_dir_path, f'{trial_id}.pt')
sample_input = [torch.randn(1, model_config['input_len'],
model_config['input_vars'])]
+ self.lock.acquire()
torch.jit.save(torch.jit.trace(model, sample_input),
model_file_path,
_extra_files={'model_config': json.dumps(model_config)})
+ self.lock.release()
return os.path.abspath(model_file_path)
- def load_model(self, model_id: str, trial_id: str) ->
Tuple[torch.jit.ScriptModule, Dict]:
+ def load_model(self, file_path: str) -> Tuple[torch.jit.ScriptModule,
Dict]:
"""
Returns:
jit_model: a ScriptModule contains model architecture and
parameters, which can be deployed cross-platform
model_config: a dict contains model attributes
"""
- file_path = os.path.join(self.__model_dir, f'{model_id}',
f'{trial_id}.pt')
- if model_id in self.__model_cache:
+ file_path = os.path.join(self.__model_dir, file_path)
+ if file_path in self.__model_cache:
return self.__model_cache[file_path]
else:
if not os.path.exists(file_path):
diff --git a/mlnode/resources/conf/iotdb-mlnode.toml
b/mlnode/resources/conf/iotdb-mlnode.toml
index a029509e643..e15a5515739 100644
--- a/mlnode/resources/conf/iotdb-mlnode.toml
+++ b/mlnode/resources/conf/iotdb-mlnode.toml
@@ -39,6 +39,18 @@ mn_model_storage_dir = "models"
# Datatype: int
mn_model_storage_cachesize = 30
+# Maximum number of training model tasks, otherwise the task is pending
+# Datatype: int
+mn_task_pool_size = 10
+
+# Maximum number of trials to be explored in a tuning task
+# Datatype: int
+mn_tuning_trial_num = 20
+
+# Concurrency of trials in a tuning task
+# Datatype: int
+mn_tuning_trial_concurrency = 4
+
####################
### Target Config Node
####################
diff --git a/mlnode/test/test_create_forecast_dataset.py
b/mlnode/test/test_create_forecast_dataset.py
new file mode 100644
index 00000000000..49e2d177e8f
--- /dev/null
+++ b/mlnode/test/test_create_forecast_dataset.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import os
+
+import requests
+
+from iotdb.mlnode.data_access.enums import DatasetType, DataSourceType
+from iotdb.mlnode.data_access.factory import create_forecast_dataset
+
+
+def test_create_dataset():
+ if not os.path.exists('sample_data.csv'):
+ response =
requests.get('https://cloud.tsinghua.edu.cn/f/9127c193e7254baeaed2/?dl=1')
+ with open('sample_data.csv', 'wb') as f:
+ f.write(response.content)
+ data, data_config = create_forecast_dataset(dataset_type='window',
+ input_len=192,
+ pred_len=96,
+ source_type='file',
+ filename='sample_data.csv')
+ assert data_config['dataset_type'] == str(DatasetType.WINDOW)
+ assert data_config['source_type'] == str(DataSourceType.FILE)
+ assert data_config['input_vars'] == data.get_variable_num()
+ assert data_config['output_vars'] == data.get_variable_num()
+
+ data_item = data[0]
+ x, y, x_enc, y_enc = data_item
+ assert x.shape[0] == 192
+ assert y.shape[0] == 96
+ assert x.shape[1] == data.get_variable_num()
+ assert y.shape[1] == data.get_variable_num()
+
+ data, data_config = create_forecast_dataset(dataset_type='window',
+ input_len=192,
+ pred_len=96,
+ source_type='file',
+ filename='sample_data.csv',
+ query_filter='0,-1',
+
query_expressions=['root.eg.etth1.*'])
+ # config about thrift source not belongs to file source
+ assert 'query_expression' not in data_config
+ assert 'query_filter' not in data_config
+
+
+def test_bad_config_dataset1():
+ try:
+ data, data_config =
create_forecast_dataset(dataset_type='dummy_dataset',
+ source_type='file')
+ except Exception as e:
+ print(e) # ('dataset_type', 'dummy_dataset')
+ try:
+ data, data_config = create_forecast_dataset(dataset_type='window',
+ source_type='dummy_source')
+ except Exception as e:
+ print(e) # ('source_type', 'dummy_source')
+
+
+def test_missing_config_dataset1():
+ try:
+ data, data_config = create_forecast_dataset(dataset_type='window',
+ source_type='file')
+ except Exception as e:
+ print(e) # (filename) is Missing
+
+
+def test_bad_config_dataset2():
+ try:
+ data, data_config = create_forecast_dataset(dataset_type='window',
+ source_type='file',
+ filename='sample_data.csv',
+ input_vars=1)
+ except Exception as e:
+ print(e) # ('input_vars', 1)
diff --git a/mlnode/test/test_create_forecast_model.py
b/mlnode/test/test_create_forecast_model.py
new file mode 100644
index 00000000000..ed439fd9c05
--- /dev/null
+++ b/mlnode/test/test_create_forecast_model.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import torch
+
+from iotdb.mlnode.algorithm.enums import ForecastTaskType
+from iotdb.mlnode.algorithm.factory import create_forecast_model
+from iotdb.mlnode.exception import BadConfigValueError
+
+
+def test_create_forecast_model():
+ model, model_config = create_forecast_model(model_name='dlinear',
+ kernel_size=25, input_vars=8,
output_vars=8)
+ sample_input = torch.randn(1, model_config['input_len'],
model_config['input_vars'])
+ output = model(sample_input)
+ assert output.shape[1] == model_config['pred_len']
+ assert output.shape[2] == model_config['output_vars']
+ assert model_config['kernel_size'] == 25
+
+ model, model_config = create_forecast_model(model_name='nbeats',
+ kernel_size=25, d_model=64)
+ assert model_config['d_model'] == 64
+ assert 'kernel_size' not in model_config # config kernel_size not belongs
to nbeats model
+
+
+def test_bad_config_model1():
+ try:
+ model, models = create_forecast_model(model_name='dlinear_dummy',
+ kernel_size=25, input_vars=8,
output_vars=8)
+ except BadConfigValueError as e:
+ print(e) # BadConfigValueError: ('model_name', 'dlinear_dummy')
+
+
+def test_bad_config_model2():
+ try:
+ model, models = create_forecast_model(model_name='dlinear',
+ kernel_size=25, input_vars=0)
+ except BadConfigValueError as e:
+ print(e) # ('input_vars', 0)
+
+
+def test_bad_config_model3():
+ try:
+ model, models = create_forecast_model(model_name='dlinear',
+ kernel_size=-1, input_vars=8,
output_vars=8)
+ except BadConfigValueError as e:
+ print(e) # ('kernel_size', -1)
+
+
+def test_bad_config_model4():
+ try:
+ model, models = create_forecast_model(model_name='dlinear',
+ forecast_task_type='dummy_task')
+ except BadConfigValueError as e:
+ print(e) # ('forecast_task_type', 'dummy_task')
+
+
+def test_bad_config_model5():
+ try:
+ model, models = create_forecast_model(model_name='dlinear',
+
forecast_task_type=ForecastTaskType.ENDOGENOUS,
+ kernel_size=25, input_vars=8,
output_vars=1)
+ except BadConfigValueError as e:
+ print(e) # ('forecast_task_type', <ForecastTaskType.ENDOGENOUS:
'endogenous'>)
diff --git a/mlnode/test/test_model_storage.py
b/mlnode/test/test_model_storage.py
index 863e73b7164..9a6399b2ea2 100644
--- a/mlnode/test/test_model_storage.py
+++ b/mlnode/test/test_model_storage.py
@@ -63,7 +63,7 @@ def test_load_not_exist_model():
trial_id = 'dummy_trial'
model_id = 'dummy_model'
try:
- model_loaded, model_config_loaded =
model_storage.load_model(model_id=model_id, trial_id=trial_id)
+ model_storage.load_model(model_id=model_id, trial_id=trial_id)
except Exception as e:
assert e.message == ModelNotExistError(
os.path.join('.',
descriptor.get_config().get_mn_model_storage_dir(),
diff --git a/mlnode/test/test_parse_training_request.py
b/mlnode/test/test_parse_training_request.py
new file mode 100644
index 00000000000..2ea7f978da3
--- /dev/null
+++ b/mlnode/test/test_parse_training_request.py
@@ -0,0 +1,135 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from iotdb.mlnode.exception import MissingConfigError, WrongTypeConfigError
+from iotdb.mlnode.parser import parse_training_request
+from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq
+
+
+def test_parse_training_request():
+ model_id = 'mid_etth1_dlinear_default'
+ is_auto = False
+ model_configs = {
+ 'task_class': 'forecast_training_task',
+ 'source_type': 'thrift',
+ 'dataset_type': 'window', # or use DatasetType.WINDOW,
+ 'filename': 'ETTh1.csv',
+ 'time_embed': 'h',
+ 'input_len': 96,
+ 'pred_len': 96,
+ 'model_name': 'dlinear',
+ 'input_vars': 7,
+ 'output_vars': 7,
+ 'forecast_type': 'endogenous',
+ 'kernel_size': 25,
+ 'learning_rate': 1e-3,
+ 'batch_size': 32,
+ 'num_workers': 0,
+ 'epochs': 10,
+ 'metric_names': ['MSE', 'MAE']
+ }
+ query_expressions = ['root.eg.etth1.**', 'root.eg.etth1.**',
'root.eg.etth1.**']
+ query_filter = '0,1501516800000'
+ req = TCreateTrainingTaskReq(
+ modelId=str(model_id),
+ isAuto=is_auto,
+ modelConfigs={k: str(v) for k, v in model_configs.items()},
+ queryExpressions=[str(query) for query in query_expressions],
+ queryFilter=str(query_filter),
+ )
+ data_config, model_config, task_config = parse_training_request(req)
+ for config in model_configs:
+ if config in data_config:
+ assert data_config[config] == model_configs[config]
+ if config in model_config:
+ assert model_config[config] == model_configs[config]
+ if config in task_config:
+ assert task_config[config] == model_configs[config]
+
+
+def test_missing_argument():
+ # missing model_name
+ model_id = 'mid_etth1_dlinear_default'
+ is_auto = False
+ model_configs = {
+ 'task_class': 'forecast_training_task',
+ 'source_type': 'thrift',
+ 'dataset_type': 'window',
+ 'filename': 'ETTh1.csv',
+ 'time_embed': 'h',
+ 'input_len': 96,
+ 'pred_len': 96,
+ 'input_vars': 7,
+ 'output_vars': 7,
+ 'forecast_type': 'endogenous',
+ 'kernel_size': 25,
+ 'learning_rate': 1e-3,
+ 'batch_size': 32,
+ 'num_workers': 0,
+ 'epochs': 10,
+ 'metric_names': ['MSE', 'MAE']
+ }
+ query_expressions = ['root.eg.etth1.**', 'root.eg.etth1.**',
'root.eg.etth1.**']
+ query_filter = '0,1501516800000'
+ req = TCreateTrainingTaskReq(
+ modelId=str(model_id),
+ isAuto=is_auto,
+ modelConfigs={k: str(v) for k, v in model_configs.items()},
+ queryExpressions=[str(query) for query in query_expressions],
+ queryFilter=str(query_filter),
+ )
+ try:
+ parse_training_request(req)
+ except Exception as e:
+ assert e.message ==
MissingConfigError(config_name='model_name').message
+
+
+def test_wrong_argument_type():
+ model_id = 'mid_etth1_dlinear_default'
+ is_auto = False
+ model_configs = {
+ 'task_class': 'forecast_training_task',
+ 'source_type': 'thrift',
+ 'dataset_type': 'window',
+ 'filename': 'ETTh1.csv',
+ 'time_embed': 'h',
+ 'input_len': 96.7,
+ 'pred_len': 96,
+ 'model_name': 'dlinear',
+ 'input_vars': 7,
+ 'output_vars': 7,
+ 'forecast_type': 'endogenous',
+ 'kernel_size': 25,
+ 'learning_rate': 1e-3,
+ 'batch_size': 32,
+ 'num_workers': 0,
+ 'epochs': 10,
+ 'metric_names': ['MSE', 'MAE']
+ }
+ query_expressions = ['root.eg.etth1.**', 'root.eg.etth1.**',
'root.eg.etth1.**']
+ query_filter = '0,1501516800000'
+ req = TCreateTrainingTaskReq(
+ modelId=str(model_id),
+ isAuto=is_auto,
+ modelConfigs={k: str(v) for k, v in model_configs.items()},
+ queryExpressions=[str(query) for query in query_expressions],
+ queryFilter=str(query_filter),
+ )
+ try:
+ data_config, model_config, task_config = parse_training_request(req)
+ except Exception as e:
+ assert e.message == WrongTypeConfigError(config_name='input_len',
expected_type='int').message
diff --git
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java
index cd3e94e6da4..bf3773837b5 100644
---
a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java
+++
b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java
@@ -2141,33 +2141,6 @@ public class ClusterConfigTaskExecutor implements
IConfigTaskExecutor {
return future;
}
- @Override
- public SettableFuture<ConfigTaskResult> setSpaceQuota(
- SetSpaceQuotaStatement setSpaceQuotaStatement) {
- SettableFuture<ConfigTaskResult> future = SettableFuture.create();
- TSStatus tsStatus = new TSStatus();
- TSetSpaceQuotaReq req = new TSetSpaceQuotaReq();
- req.setDatabase(setSpaceQuotaStatement.getPrefixPathList());
- TSpaceQuota spaceQuota = new TSpaceQuota();
- spaceQuota.setDeviceNum(setSpaceQuotaStatement.getDeviceNum());
- spaceQuota.setTimeserieNum(setSpaceQuotaStatement.getTimeSeriesNum());
- spaceQuota.setDiskSize(setSpaceQuotaStatement.getDiskSize());
- req.setSpaceLimit(spaceQuota);
- try (ConfigNodeClient client =
-
CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) {
- // Send request to some API server
- tsStatus = client.setSpaceQuota(req);
- } catch (Exception e) {
- future.setException(e);
- }
- if (tsStatus.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
- future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS));
- } else {
- future.setException(new IoTDBException(tsStatus.message, tsStatus.code));
- }
- return future;
- }
-
@Override
public SettableFuture<ConfigTaskResult> createModel(CreateModelStatement
createModelStatement) {
createModelStatement.semanticCheck();
@@ -2224,6 +2197,89 @@ public class ClusterConfigTaskExecutor implements
IConfigTaskExecutor {
return future;
}
+ @Override
+ public SettableFuture<ConfigTaskResult> dropModel(String modelId) {
+ SettableFuture<ConfigTaskResult> future = SettableFuture.create();
+ try (ConfigNodeClient client =
+
CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) {
+ final TSStatus executionStatus = client.dropModel(new
TDropModelReq(modelId));
+ if (TSStatusCode.SUCCESS_STATUS.getStatusCode() !=
executionStatus.getCode()) {
+ LOGGER.warn("[{}] Failed to drop model {}.", executionStatus, modelId);
+ future.setException(new IoTDBException(executionStatus.message,
executionStatus.code));
+ } else {
+ future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS));
+ }
+ } catch (ClientManagerException | TException e) {
+ future.setException(e);
+ }
+ return future;
+ }
+
+ @Override
+ public SettableFuture<ConfigTaskResult> showModels() {
+ SettableFuture<ConfigTaskResult> future = SettableFuture.create();
+ try (ConfigNodeClient client =
+
CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) {
+ TShowModelResp showModelResp = client.showModel(new TShowModelReq());
+ if (showModelResp.getStatus().getCode() !=
TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+ future.setException(
+ new IoTDBException(showModelResp.getStatus().message,
showModelResp.getStatus().code));
+ return future;
+ }
+ // convert model info list and buildTsBlock
+ ShowModelsTask.buildTsBlock(showModelResp.getModelInfoList(), future);
+ } catch (ClientManagerException | TException e) {
+ future.setException(e);
+ }
+ return future;
+ }
+
+ @Override
+ public SettableFuture<ConfigTaskResult> showTrails(String modelId) {
+ SettableFuture<ConfigTaskResult> future = SettableFuture.create();
+ try (ConfigNodeClient client =
+
CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) {
+ TShowTrailResp showTrailResp = client.showTrail(new
TShowTrailReq(modelId));
+ if (showTrailResp.getStatus().getCode() !=
TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+ future.setException(
+ new IoTDBException(showTrailResp.getStatus().message,
showTrailResp.getStatus().code));
+ return future;
+ }
+ // convert trail info list and buildTsBlock
+ ShowTrailsTask.buildTsBlock(showTrailResp.getTrailInfoList(), future);
+ } catch (ClientManagerException | TException e) {
+ future.setException(e);
+ }
+ return future;
+ }
+
+ @Override
+ public SettableFuture<ConfigTaskResult> setSpaceQuota(
+ SetSpaceQuotaStatement setSpaceQuotaStatement) {
+ SettableFuture<ConfigTaskResult> future = SettableFuture.create();
+ TSStatus tsStatus = new TSStatus();
+ TSetSpaceQuotaReq req = new TSetSpaceQuotaReq();
+ req.setDatabase(setSpaceQuotaStatement.getPrefixPathList());
+ TSpaceQuota spaceQuota = new TSpaceQuota();
+ spaceQuota.setDeviceNum(setSpaceQuotaStatement.getDeviceNum());
+ spaceQuota.setTimeserieNum(setSpaceQuotaStatement.getTimeSeriesNum());
+ spaceQuota.setDiskSize(setSpaceQuotaStatement.getDiskSize());
+ req.setSpaceLimit(spaceQuota);
+ try (ConfigNodeClient client =
+
CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) {
+ // Send request to some API server
+ tsStatus = client.setSpaceQuota(req);
+ } catch (Exception e) {
+ future.setException(e);
+ }
+ if (tsStatus.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+ future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS));
+ } else {
+ future.setException(new IoTDBException(tsStatus.message, tsStatus.code));
+ }
+ return future;
+ }
+
@Override
public SettableFuture<ConfigTaskResult> showSpaceQuota(
ShowSpaceQuotaStatement showSpaceQuotaStatement) {
@@ -2306,43 +2362,6 @@ public class ClusterConfigTaskExecutor implements
IConfigTaskExecutor {
return throttleQuotaResp;
}
- @Override
- public SettableFuture<ConfigTaskResult> dropModel(String modelId) {
- SettableFuture<ConfigTaskResult> future = SettableFuture.create();
- try (ConfigNodeClient client =
-
CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) {
- final TSStatus executionStatus = client.dropModel(new
TDropModelReq(modelId));
- if (TSStatusCode.SUCCESS_STATUS.getStatusCode() !=
executionStatus.getCode()) {
- LOGGER.warn("[{}] Failed to drop model {}.", executionStatus, modelId);
- future.setException(new IoTDBException(executionStatus.message,
executionStatus.code));
- } else {
- future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS));
- }
- } catch (ClientManagerException | TException e) {
- future.setException(e);
- }
- return future;
- }
-
- @Override
- public SettableFuture<ConfigTaskResult> showModels() {
- SettableFuture<ConfigTaskResult> future = SettableFuture.create();
- try (ConfigNodeClient client =
-
CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) {
- TShowModelResp showModelResp = client.showModel(new TShowModelReq());
- if (showModelResp.getStatus().getCode() !=
TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
- future.setException(
- new IoTDBException(showModelResp.getStatus().message,
showModelResp.getStatus().code));
- return future;
- }
- // convert model info list and buildTsBlock
- ShowModelsTask.buildTsBlock(showModelResp.getModelInfoList(), future);
- } catch (ClientManagerException | TException e) {
- future.setException(e);
- }
- return future;
- }
-
@Override
public TSpaceQuotaResp getSpaceQuota() {
TSpaceQuotaResp spaceQuotaResp = new TSpaceQuotaResp();
@@ -2355,23 +2374,4 @@ public class ClusterConfigTaskExecutor implements
IConfigTaskExecutor {
}
return spaceQuotaResp;
}
-
- @Override
- public SettableFuture<ConfigTaskResult> showTrails(String modelId) {
- SettableFuture<ConfigTaskResult> future = SettableFuture.create();
- try (ConfigNodeClient client =
-
CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) {
- TShowTrailResp showTrailResp = client.showTrail(new
TShowTrailReq(modelId));
- if (showTrailResp.getStatus().getCode() !=
TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
- future.setException(
- new IoTDBException(showTrailResp.getStatus().message,
showTrailResp.getStatus().code));
- return future;
- }
- // convert trail info list and buildTsBlock
- ShowTrailsTask.buildTsBlock(showTrailResp.getTrailInfoList(), future);
- } catch (ClientManagerException | TException e) {
- future.setException(e);
- }
- return future;
- }
}