This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch lmh/forecastTest
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 4c8e850b422b9b826fbb18ee42af6aaa80f25184 Author: Minghui Liu <[email protected]> AuthorDate: Mon May 22 10:40:42 2023 +0800 fix mlnode --- mlnode/iotdb/mlnode/algorithm/metric.py | 2 +- .../mlnode/algorithm/models/forecast/dlinear.py | 2 +- mlnode/iotdb/mlnode/client.py | 9 +++---- mlnode/iotdb/mlnode/data_access/offline/source.py | 4 +-- mlnode/iotdb/mlnode/handler.py | 6 ++--- mlnode/iotdb/mlnode/parser.py | 4 +-- mlnode/iotdb/mlnode/process/manager.py | 13 +++++----- mlnode/iotdb/mlnode/process/task.py | 30 ++++++++-------------- mlnode/iotdb/mlnode/process/trial.py | 4 +-- mlnode/iotdb/mlnode/storage.py | 3 +-- mlnode/test/test_create_forecast_dataset.py | 1 + mlnode/test/test_create_forecast_model.py | 1 + mlnode/test/test_model_storage.py | 1 + mlnode/test/test_serde.py | 3 ++- 14 files changed, 38 insertions(+), 45 deletions(-) diff --git a/mlnode/iotdb/mlnode/algorithm/metric.py b/mlnode/iotdb/mlnode/algorithm/metric.py index c32580743ba..5ffc5f24a5c 100644 --- a/mlnode/iotdb/mlnode/algorithm/metric.py +++ b/mlnode/iotdb/mlnode/algorithm/metric.py @@ -16,7 +16,7 @@ # under the License. 
# from abc import abstractmethod -from typing import List, Dict +from typing import Dict, List import numpy as np diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py index a2ae149134f..35ba728fa78 100644 --- a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py +++ b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py @@ -17,10 +17,10 @@ # import math +from typing import Dict, Tuple import torch import torch.nn as nn -from typing import Dict, Tuple from iotdb.mlnode.algorithm.enums import ForecastTaskType from iotdb.mlnode.exception import BadConfigValueError diff --git a/mlnode/iotdb/mlnode/client.py b/mlnode/iotdb/mlnode/client.py index bb442cd7b8d..a5bdf109e8f 100644 --- a/mlnode/iotdb/mlnode/client.py +++ b/mlnode/iotdb/mlnode/client.py @@ -33,13 +33,12 @@ from iotdb.thrift.confignode import IConfigNodeRPCService from iotdb.thrift.confignode.ttypes import (TUpdateModelInfoReq, TUpdateModelStateReq) from iotdb.thrift.datanode import IMLNodeInternalRPCService -from iotdb.thrift.datanode.ttypes import (TFetchTimeseriesReq, - TRecordModelMetricsReq, - TFetchMoreDataReq) +from iotdb.thrift.datanode.ttypes import (TFetchMoreDataReq, + TFetchTimeseriesReq, + TRecordModelMetricsReq) from iotdb.thrift.mlnode import IMLNodeRPCService from iotdb.thrift.mlnode.ttypes import (TCreateTrainingTaskReq, - TDeleteModelReq, - TForecastReq) + TDeleteModelReq, TForecastReq) class ClientManager(object): diff --git a/mlnode/iotdb/mlnode/data_access/offline/source.py b/mlnode/iotdb/mlnode/data_access/offline/source.py index 418c9c7afc0..4f29241913f 100644 --- a/mlnode/iotdb/mlnode/data_access/offline/source.py +++ b/mlnode/iotdb/mlnode/data_access/offline/source.py @@ -15,11 +15,11 @@ # specific language governing permissions and limitations # under the License. 
# +from typing import List + import numpy as np import pandas as pd -from typing import List - from iotdb.mlnode.client import client_manager diff --git a/mlnode/iotdb/mlnode/handler.py b/mlnode/iotdb/mlnode/handler.py index a349a99d613..d4b4ca3b034 100644 --- a/mlnode/iotdb/mlnode/handler.py +++ b/mlnode/iotdb/mlnode/handler.py @@ -16,19 +16,19 @@ # under the License. # +from iotdb.mlnode.config import descriptor from iotdb.mlnode.constant import TSStatusCode from iotdb.mlnode.data_access.factory import create_forecast_dataset from iotdb.mlnode.log import logger -from iotdb.mlnode.parser import parse_training_request, parse_forecast_request +from iotdb.mlnode.parser import parse_forecast_request, parse_training_request from iotdb.mlnode.process.manager import TaskManager +from iotdb.mlnode.serde import convert_to_binary from iotdb.mlnode.storage import model_storage from iotdb.mlnode.util import get_status -from iotdb.mlnode.config import descriptor from iotdb.thrift.mlnode import IMLNodeRPCService from iotdb.thrift.mlnode.ttypes import (TCreateTrainingTaskReq, TDeleteModelReq, TForecastReq, TForecastResp) -from iotdb.mlnode.serde import convert_to_binary class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface): diff --git a/mlnode/iotdb/mlnode/parser.py b/mlnode/iotdb/mlnode/parser.py index 71ffbe7c372..34bb7eb65a9 100644 --- a/mlnode/iotdb/mlnode/parser.py +++ b/mlnode/iotdb/mlnode/parser.py @@ -24,8 +24,8 @@ from typing import Dict, List, Tuple from iotdb.mlnode.algorithm.enums import ForecastTaskType from iotdb.mlnode.data_access.enums import DatasetType, DataSourceType from iotdb.mlnode.exception import MissingConfigError, WrongTypeConfigError -from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq, TForecastReq from iotdb.mlnode.serde import convert_to_df +from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq, TForecastReq class _ConfigParser(argparse.ArgumentParser): @@ -222,7 +222,7 @@ def parse_forecast_request(req: TForecastReq): ts_dataset 
= req.inputData pred_len = req.predictLength - data = convert_to_df(column_name_list, column_type_list, None, ts_dataset) + data = convert_to_df(column_name_list, column_type_list, None, [ts_dataset]) time_stamp, data = data[data.columns[0:1]], data[data.columns[1:]] full_data = (data, time_stamp) return model_path, full_data, pred_len diff --git a/mlnode/iotdb/mlnode/process/manager.py b/mlnode/iotdb/mlnode/process/manager.py index ee7df83acae..5957dadbed2 100644 --- a/mlnode/iotdb/mlnode/process/manager.py +++ b/mlnode/iotdb/mlnode/process/manager.py @@ -17,13 +17,15 @@ # import multiprocessing as mp -import pandas as pd - from typing import Dict, Union + +import pandas as pd from torch.utils.data import Dataset + from iotdb.mlnode.log import logger -from iotdb.mlnode.process.task import ForecastingSingleTrainingTask, ForecastingTuningTrainingTask, \ - ForecastingInferenceTask +from iotdb.mlnode.process.task import (ForecastingInferenceTask, + ForecastingSingleTrainingTask, + ForecastingTuningTrainingTask) class TaskManager(object): @@ -102,6 +104,5 @@ class TaskManager(object): read_pipe, send_pipe = mp.Pipe() if task is not None: self.__training_process_pool.apply_async(task, args=(send_pipe,)) - logger.info(f'Forecasting process submitted successfully') - # task(send_pipe) + logger.info('Forecasting process submitted successfully') return read_pipe.recv() diff --git a/mlnode/iotdb/mlnode/process/task.py b/mlnode/iotdb/mlnode/process/task.py index ff0ee4c5a18..85611a948fb 100644 --- a/mlnode/iotdb/mlnode/process/task.py +++ b/mlnode/iotdb/mlnode/process/task.py @@ -17,25 +17,23 @@ # import os -import pandas as pd -import numpy as np - from abc import abstractmethod +from multiprocessing.connection import Connection from typing import Dict, Tuple +import numpy as np import optuna +import pandas as pd import torch -from torch import nn from torch.utils.data import Dataset -from multiprocessing.connection import Connection -from iotdb.mlnode.log import logger 
-from iotdb.mlnode.process.trial import ForecastingTrainingTrial from iotdb.mlnode.algorithm.factory import create_forecast_model from iotdb.mlnode.client import client_manager from iotdb.mlnode.config import descriptor -from iotdb.thrift.common.ttypes import TrainingState +from iotdb.mlnode.log import logger +from iotdb.mlnode.process.trial import ForecastingTrainingTrial from iotdb.mlnode.storage import model_storage +from iotdb.thrift.common.ttypes import TrainingState class ForestingTrainingObjective: @@ -229,7 +227,7 @@ class ForecastingInferenceTask(_BasicInferenceTask): task_configs: Dict, model_configs: Dict, pid_info: Dict, - data:Tuple, + data: Tuple, model_path: str ): super().__init__(task_configs, model_configs, pid_info, data, model_path) @@ -247,17 +245,14 @@ class ForecastingInferenceTask(_BasicInferenceTask): current_pred_len = 0 while current_pred_len < self.pred_len: current_data = full_data[:, -self.input_len:, :] - # current_data_stamp = timefeatures.time_features(full_data_stamp.iloc[-self.input_len:, :])[None, :] # batch current_data = torch.Tensor(current_data) output_data = self.model(current_data).detach().numpy() full_data = np.concatenate([full_data, output_data], axis=1) - # full_data_stamp = pd.concat([full_data_stamp, self.generate_future_mark(full_data_stamp, self.pred_len)]) current_pred_len += self.model_pred_len full_data_stamp = self.generate_future_mark(full_data_stamp, self.pred_len) - # ret_data = np.concatenate([full_data_stamp, full_data[0, -self.pred_len:, :]], axis=1) - # ret_data = ret_data[- - ret_data = pd.concat([pd.DataFrame(full_data_stamp.astype(np.int64)), pd.DataFrame(full_data[0, -self.pred_len:, :])], axis=1) - # ret_data = pd.DataFrame(ret_data) + ret_data = pd.concat( + [pd.DataFrame(full_data_stamp.astype(np.int64)), + pd.DataFrame(full_data[0, -self.pred_len:, :]).astype(np.double)], axis=1) ret_data.columns = list(np.arange(0, C + 1)) pipe.send(ret_data) @@ -284,12 +279,9 @@ class 
ForecastingInferenceTask(_BasicInferenceTask): data = data[None, :] # add batch dim return data, data_stamp - def generate_future_mark(self, data_stamp:pd.DataFrame, future_len: int) -> pd.DatetimeIndex: + def generate_future_mark(self, data_stamp: pd.DataFrame, future_len: int) -> pd.DatetimeIndex: time_deltas = data_stamp.diff().dropna() mean_timedelta = time_deltas.mean()[0] extrapolated_timestamp = pd.date_range(data_stamp.values[0][0], periods=future_len, freq=mean_timedelta) return extrapolated_timestamp[:, None] - -if __name__ == '__main__': - pass diff --git a/mlnode/iotdb/mlnode/process/trial.py b/mlnode/iotdb/mlnode/process/trial.py index 973fb31f6f0..7fc34477686 100644 --- a/mlnode/iotdb/mlnode/process/trial.py +++ b/mlnode/iotdb/mlnode/process/trial.py @@ -22,12 +22,10 @@ from typing import Dict, Tuple import numpy as np import torch import torch.nn as nn -from torch.nn.modules import loss -from torch.optim import Optimizer from torch.utils.data import DataLoader, Dataset from iotdb.mlnode.algorithm.metric import all_metrics, build_metrics -from iotdb.mlnode.client import client_manager, DataNodeClient, ConfigNodeClient +from iotdb.mlnode.client import client_manager from iotdb.mlnode.log import logger from iotdb.mlnode.storage import model_storage from iotdb.thrift.common.ttypes import TrainingState diff --git a/mlnode/iotdb/mlnode/storage.py b/mlnode/iotdb/mlnode/storage.py index db3b081a2f1..4c038786174 100644 --- a/mlnode/iotdb/mlnode/storage.py +++ b/mlnode/iotdb/mlnode/storage.py @@ -19,12 +19,11 @@ import json import os import shutil +import threading from typing import Dict, Tuple import torch import torch.nn as nn -import threading - from pylru import lrucache from iotdb.mlnode.config import descriptor diff --git a/mlnode/test/test_create_forecast_dataset.py b/mlnode/test/test_create_forecast_dataset.py index c9e506dfefb..49e2d177e8f 100644 --- a/mlnode/test/test_create_forecast_dataset.py +++ b/mlnode/test/test_create_forecast_dataset.py @@ 
-18,6 +18,7 @@ import os import requests + from iotdb.mlnode.data_access.enums import DatasetType, DataSourceType from iotdb.mlnode.data_access.factory import create_forecast_dataset diff --git a/mlnode/test/test_create_forecast_model.py b/mlnode/test/test_create_forecast_model.py index a100d01c1ca..ed439fd9c05 100644 --- a/mlnode/test/test_create_forecast_model.py +++ b/mlnode/test/test_create_forecast_model.py @@ -16,6 +16,7 @@ # under the License. # import torch + from iotdb.mlnode.algorithm.enums import ForecastTaskType from iotdb.mlnode.algorithm.factory import create_forecast_model from iotdb.mlnode.exception import BadConfigValueError diff --git a/mlnode/test/test_model_storage.py b/mlnode/test/test_model_storage.py index 90d6bebdc1d..9a6399b2ea2 100644 --- a/mlnode/test/test_model_storage.py +++ b/mlnode/test/test_model_storage.py @@ -20,6 +20,7 @@ import os import time import torch.nn as nn + from iotdb.mlnode.config import descriptor from iotdb.mlnode.exception import ModelNotExistError from iotdb.mlnode.storage import model_storage diff --git a/mlnode/test/test_serde.py b/mlnode/test/test_serde.py index 3454f41ddec..c05083be417 100644 --- a/mlnode/test/test_serde.py +++ b/mlnode/test/test_serde.py @@ -18,9 +18,10 @@ import numpy as np import pandas as pd -from iotdb.mlnode.serde import convert_to_df from pandas.testing import assert_frame_equal +from iotdb.mlnode.serde import convert_to_df + device_id = "root.wt1" ts_path_lst = [
