ycycse commented on code in PR #9338:
URL: https://github.com/apache/iotdb/pull/9338#discussion_r1141030512


##########
mlnode/iotdb/mlnode/algorithm/model_factory.py:
##########
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+from iotdb.mlnode.algorithm.models.DLinear import *
+from iotdb.mlnode.algorithm.models.NBeats import *
+
+
+support_forecasting_model = []
+support_forecasting_model.extend(DLinear.support_model_names)
+support_forecasting_model.extend(NBeats.support_model_names)
+
+"""
+Common configs for all forecasting model with default values
+"""
+
+
+def _common_cfg(**kwargs):
+    return {
+        'input_len': 96,
+        'pred_len': 96,
+        'input_vars': 1,
+        'output_vars': 1,
+        **kwargs
+    }
+
+
+"""
+Common forecasting task configs
+"""
+support_common_cfgs = {
+    # univariate forecasting
+    's': _common_cfg(
+        input_vars=1,
+        output_vars=1),
+
+    # univariate forecasting with observable exogenous variables
+    'ms': _common_cfg(
+        output_vars=1),
+
+    # multivariate forecasting, current support this only
+    'm': _common_cfg(),
+}
+
+
+def is_model(model_name: str) -> bool:
+    """
+    Check if a model name exists
+    """
+    return model_name in support_forecasting_model
+
+
+def list_model():
+    """
+    List support forecasting model
+    """
+    return support_forecasting_model
+
+
+def create_forecast_model(
+        model_name: str,
+        task_type: str = 'm',
+        input_len: int = 96,
+        pred_len: int = 96,
+        input_vars: int = 1,
+        output_vars: int = 1,
+        **kwargs,
+):
+    """
+    Factory method for all support forecasting models
+    the given arguments is common configs shared by all forecasting models
+    for specific model configs, see __model_cfg in 
`algorithm/models/MODELNAME.py`
+
+    Args:
+        model_name: see available models by `list_model`
+        task_type: 'm' for multivariate forecasting, currently only support 'm'
+                   'ms' for covariate forecasting,
+                   's' for univariate forecasting
+        input_len: time length of model input
+        pred_len: time length of model output
+        input_vars: number of input series
+        output_vars: number of output series
+        kwargs: for specific model configs, see returned `model_config` with 
kwargs=None
+
+    Returns:
+        model: torch.nn.Module
+        model_config: dict of model configurations
+    """
+    if not is_model(model_name):
+        raise RuntimeError(f'Unknown forecasting model: ({model_name}),'

Review Comment:
   RuntimeError is too generic here; consider adding a dedicated exception type for this in the future.



##########
mlnode/iotdb/mlnode/algorithm/models/NBeats.py:
##########
@@ -0,0 +1,266 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+import torch
+import torch.nn as nn
+from typing import Tuple
+
+__all__ = ['NBeats', 'nbeats']  # , 'nbeats_s', 'nbeats_t']
+
+"""
+Specific configs for NBeats with default values
+"""
+
+
+def _model_cfg(**kwargs):
+    return {
+        'block_type': 'g',
+        'd_model': 128,
+        'inner_layers': 4,
+        'outer_layers': 4,
+        # 'harmonics': 4,
+        # 'degree_of_polynomial': 3,
+        **kwargs
+    }
+
+
+"""
+Specific configs for NBeats variants
+"""
+support_model_cfgs = {
+    'nbeats': _model_cfg(
+        block_type='g'),
+    # 'nbeats_s': _model_cfg(
+    #     harmonics=4,
+    #     block_type='s'),
+    # 'nbeats_t': _model_cfg(
+    #     degree_of_polynomial=3,
+    #     block_type='t')
+}
+
+
+class GenericBasis(nn.Module):
+    """
+    Generic basis function.
+    """
+
+    def __init__(self, backcast_size: int, forecast_size: int):
+        super().__init__()
+        self.backcast_size = backcast_size
+        self.forecast_size = forecast_size
+
+    def forward(self, theta: torch.Tensor):
+        return theta[:, :self.backcast_size], theta[:, -self.forecast_size:]
+
+
+# class TrendBasis(nn.Module):
+#     """
+#     Trend basis function.
+#     """
+#     def __init__(self, degree_of_polynomial: int, backcast_size: int, 
forecast_size):
+#         super().__init__()
+#         polynomial_size = degree_of_polynomial + 1

Review Comment:
   Is this commented-out code still needed?



##########
mlnode/iotdb/mlnode/algorithm/models/NBeats.py:
##########
@@ -0,0 +1,266 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+import torch
+import torch.nn as nn
+from typing import Tuple
+
+__all__ = ['NBeats', 'nbeats']  # , 'nbeats_s', 'nbeats_t']
+
+"""
+Specific configs for NBeats with default values
+"""
+
+
+def _model_cfg(**kwargs):
+    return {
+        'block_type': 'g',
+        'd_model': 128,
+        'inner_layers': 4,
+        'outer_layers': 4,
+        # 'harmonics': 4,
+        # 'degree_of_polynomial': 3,
+        **kwargs
+    }
+
+
+"""
+Specific configs for NBeats variants
+"""
+support_model_cfgs = {
+    'nbeats': _model_cfg(
+        block_type='g'),
+    # 'nbeats_s': _model_cfg(
+    #     harmonics=4,
+    #     block_type='s'),
+    # 'nbeats_t': _model_cfg(
+    #     degree_of_polynomial=3,
+    #     block_type='t')
+}
+
+
+class GenericBasis(nn.Module):
+    """
+    Generic basis function.
+    """
+
+    def __init__(self, backcast_size: int, forecast_size: int):
+        super().__init__()
+        self.backcast_size = backcast_size
+        self.forecast_size = forecast_size
+
+    def forward(self, theta: torch.Tensor):
+        return theta[:, :self.backcast_size], theta[:, -self.forecast_size:]
+
+
+# class TrendBasis(nn.Module):
+#     """
+#     Trend basis function.
+#     """
+#     def __init__(self, degree_of_polynomial: int, backcast_size: int, 
forecast_size):
+#         super().__init__()
+#         polynomial_size = degree_of_polynomial + 1
+#         self.backcast_basis = nn.Parameter(
+#             torch.tensor(np.concatenate([np.power(np.arange(backcast_size, 
dtype=float) / backcast_size, i)[None, :]
+#                                     for i in range(polynomial_size)]), 
dtype=torch.float32), requires_grad=False)
+#         self.forecast_basis = nn.Parameter(
+#             torch.tensor(np.concatenate([np.power(np.arange(forecast_size, 
dtype=float) / forecast_size, i)[None, :]
+#                                     for i in range(polynomial_size)]), 
dtype=torch.float32), requires_grad=False)
+
+#     def forward(self, theta):
+#         cut_point = self.forecast_basis.shape[0]
+#         backcast = torch.einsum('bp,pt->bt', theta[:, cut_point:], 
self.backcast_basis)
+#         forecast = torch.einsum('bp,pt->bt', theta[:, :cut_point], 
self.forecast_basis)
+#         return backcast, forecast
+
+
+# class SeasonalityBasis(nn.Module):
+#     """
+#     Seasonality basis function.
+#     """
+#     def __init__(self, harmonics: int, backcast_size: int, forecast_size: 
int):
+#         super().__init__()
+#         frequency = np.append(np.zeros(1, dtype=float),
+#                                         np.arange(harmonics, harmonics / 2 * 
forecast_size,
+#                                                     dtype=float) / 
harmonics)[None, :]
+#         backcast_grid = -2 * np.pi * (
+#                 np.arange(backcast_size, dtype=float)[:, None] / 
forecast_size) * frequency
+#         forecast_grid = 2 * np.pi * (
+#                 np.arange(forecast_size, dtype=float)[:, None] / 
forecast_size) * frequency
+
+#         backcast_cos_template = 
torch.tensor(np.transpose(np.cos(backcast_grid)), dtype=torch.float32)
+#         backcast_sin_template = 
torch.tensor(np.transpose(np.sin(backcast_grid)), dtype=torch.float32)
+#         backcast_template = torch.cat([backcast_cos_template, 
backcast_sin_template], dim=0)
+
+#         forecast_cos_template = 
torch.tensor(np.transpose(np.cos(forecast_grid)), dtype=torch.float32)
+#         forecast_sin_template = 
torch.tensor(np.transpose(np.sin(forecast_grid)), dtype=torch.float32)
+#         forecast_template = torch.cat([forecast_cos_template, 
forecast_sin_template], dim=0)
+
+#         self.backcast_basis = nn.Parameter(backcast_template, 
requires_grad=False)
+#         self.forecast_basis = nn.Parameter(forecast_template, 
requires_grad=False)
+
+#     def forward(self, theta):
+#         cut_point = self.forecast_basis.shape[0]
+#         backcast = torch.einsum('bp,pt->bt', theta[:, cut_point:], 
self.backcast_basis)
+#         forecast = torch.einsum('bp,pt->bt', theta[:, :cut_point], 
self.forecast_basis)
+#         return backcast, forecast
+
+
+class NBeatsBlock(nn.Module):
+    """
+    N-BEATS block which takes a basis function as an argument
+    """
+
+    def __init__(self,
+                 input_size,
+                 theta_size: int,
+                 basis_function: nn.Module,
+                 layers: int,
+                 layer_size: int):
+        """
+        N-BEATS block
+
+        Args:
+            input_size: input sample size
+            theta_size:  number of parameters for the basis function
+            basis_function: basis function which takes the parameters and 
produces backcast and forecast
+            layers: number of layers
+            layer_size: layer size
+        """
+        super().__init__()
+        self.layers = nn.ModuleList([nn.Linear(in_features=input_size, 
out_features=layer_size)] +
+                                    [nn.Linear(in_features=layer_size, 
out_features=layer_size)
+                                     for _ in range(layers - 1)])
+        self.basis_parameters = nn.Linear(in_features=layer_size, 
out_features=theta_size)
+        self.basis_function = basis_function
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        block_input = x
+        for layer in self.layers:
+            block_input = torch.relu(layer(block_input))
+        basis_parameters = self.basis_parameters(block_input)
+        return self.basis_function(basis_parameters)
+
+
+class NBeatsUnivar(nn.Module):
+    """
+    N-Beats Model.
+    """
+
+    def __init__(self, blocks: nn.ModuleList):
+        super().__init__()
+        self.blocks = blocks
+
+    def forward(self, x):
+        residuals = x
+        forecast = None
+        for _, block in enumerate(self.blocks):
+            backcast, block_forecast = block(residuals)
+            residuals = (residuals - backcast)
+            if forecast is None:
+                forecast = block_forecast
+            else:
+                forecast += block_forecast
+        return forecast
+
+
+block_dict = {
+    'g': GenericBasis,
+    # 't': TrendBasis,
+    # 's': SeasonalityBasis,

Review Comment:
   If these will be supported in the future, there is no need to keep the 
code commented out; you can raise an `UnsupportedException` instead.



##########
mlnode/iotdb/mlnode/algorithm/model_factory.py:
##########
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+from iotdb.mlnode.algorithm.models.DLinear import *
+from iotdb.mlnode.algorithm.models.NBeats import *
+
+
+support_forecasting_model = []
+support_forecasting_model.extend(DLinear.support_model_names)
+support_forecasting_model.extend(NBeats.support_model_names)
+
+"""
+Common configs for all forecasting model with default values
+"""
+
+
+def _common_cfg(**kwargs):
+    return {
+        'input_len': 96,
+        'pred_len': 96,
+        'input_vars': 1,
+        'output_vars': 1,
+        **kwargs
+    }
+
+
+"""
+Common forecasting task configs
+"""
+support_common_cfgs = {
+    # univariate forecasting
+    's': _common_cfg(
+        input_vars=1,
+        output_vars=1),
+
+    # univariate forecasting with observable exogenous variables
+    'ms': _common_cfg(
+        output_vars=1),
+
+    # multivariate forecasting, current support this only
+    'm': _common_cfg(),
+}
+
+
+def is_model(model_name: str) -> bool:
+    """
+    Check if a model name exists
+    """
+    return model_name in support_forecasting_model
+
+
+def list_model():
+    """
+    List support forecasting model
+    """
+    return support_forecasting_model
+
+
+def create_forecast_model(
+        model_name: str,
+        task_type: str = 'm',
+        input_len: int = 96,
+        pred_len: int = 96,
+        input_vars: int = 1,
+        output_vars: int = 1,
+        **kwargs,
+):
+    """
+    Factory method for all support forecasting models
+    the given arguments is common configs shared by all forecasting models
+    for specific model configs, see __model_cfg in 
`algorithm/models/MODELNAME.py`
+
+    Args:
+        model_name: see available models by `list_model`
+        task_type: 'm' for multivariate forecasting, currently only support 'm'
+                   'ms' for covariate forecasting,
+                   's' for univariate forecasting
+        input_len: time length of model input
+        pred_len: time length of model output
+        input_vars: number of input series
+        output_vars: number of output series
+        kwargs: for specific model configs, see returned `model_config` with 
kwargs=None
+
+    Returns:
+        model: torch.nn.Module
+        model_config: dict of model configurations
+    """
+    if not is_model(model_name):
+        raise RuntimeError(f'Unknown forecasting model: ({model_name}),'
+                           f' which should be one of {list_model()}')
+    if task_type not in support_common_cfgs.keys():
+        raise RuntimeError(f'Unknown forecasting task type: ({task_type})'
+                           f' which should be one of 
{support_common_cfgs.keys()}')
+
+    common_cfg = support_common_cfgs[task_type]
+    common_cfg['input_len'] = input_len
+    common_cfg['pred_len'] = pred_len
+    common_cfg['input_vars'] = input_vars
+    common_cfg['output_vars'] = output_vars
+    assert input_len > 0 and pred_len > 0, 'Length of input/output series 
should be positive'
+    assert input_vars > 0 and output_vars > 0, 'Number of input/output series 
should be positive'

Review Comment:
   Raise an exception here if possible; `assert` is meant for debugging and is stripped when Python runs with `-O`, so it should not be used for input validation.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to