This is an automated email from the ASF dual-hosted git repository. hui pushed a commit to branch mlnode/test in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit c67e2a616f6e1b67c76b2f60120642dea4bf1397
Author: Minghui Liu <[email protected]>
AuthorDate: Tue Apr 4 17:08:46 2023 +0800

    fix fetch_timeseries
---
 mlnode/iotdb/mlnode/client.py                     | 29 +++++++++++++++++++----
 mlnode/iotdb/mlnode/data_access/offline/source.py | 22 +++--------------
 2 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/mlnode/iotdb/mlnode/client.py b/mlnode/iotdb/mlnode/client.py
index 1560b0507a..bf46846fe1 100644
--- a/mlnode/iotdb/mlnode/client.py
+++ b/mlnode/iotdb/mlnode/client.py
@@ -16,7 +16,8 @@
 # under the License.
 #
 import time
-
+import pandas as pd
+from iotdb.mlnode import serde
 from thrift.protocol import TBinaryProtocol, TCompactProtocol
 from thrift.Thrift import TException
 from thrift.transport import TSocket, TTransport
@@ -126,7 +127,7 @@ class DataNodeClient(object):
                          query_expressions: list,
                          query_filter: str = None,
                          fetch_size: int = DEFAULT_FETCH_SIZE,
-                         timeout: int = DEFAULT_TIMEOUT) -> TFetchTimeseriesResp:
+                         timeout: int = DEFAULT_TIMEOUT) -> [int, bool, pd.DataFrame]:
         req = TFetchTimeseriesReq(
             queryExpressions=query_expressions,
             queryFilter=query_filter,
@@ -136,10 +137,30 @@ class DataNodeClient(object):
         try:
             resp = self.__client.fetchTimeseries(req)
             verify_success(resp.status, "An error occurs when calling fetch_timeseries()")
-            return resp
-        except TTransport.TException as e:
+
+            if len(resp.tsDataset) == 0:
+                raise RuntimeError(f'No data fetched with query filter: {query_filter}')
+
+            data = serde.convert_to_df(resp.columnNameList,
+                                       resp.columnTypeList,
+                                       resp.columnNameIndexMap,
+                                       resp.tsDataset)
+            if data.empty:
+                raise RuntimeError(
+                    f'Fetched empty data with query expressions: {query_expressions} and query filter: {query_filter}')
+            return resp.queryId, resp.hasMoreData, data
+        except Exception as e:
+            logger.warn(
+                f'Fail to fetch data with query expressions: {query_expressions} and query filter: {query_filter}')
             raise e
 
+    def fetch_window_batch(self,
+                           query_expressions: list,
+                           query_filter: str = None,
+                           fetch_size: int = DEFAULT_FETCH_SIZE,
+                           timeout: int = DEFAULT_TIMEOUT) -> [int, bool, list[pd.DataFrame]]:
+        pass
+
     def record_model_metrics(self,
                              model_id: str,
                              trial_id: str,
diff --git a/mlnode/iotdb/mlnode/data_access/offline/source.py b/mlnode/iotdb/mlnode/data_access/offline/source.py
index 0422bb373d..05ef96d16a 100644
--- a/mlnode/iotdb/mlnode/data_access/offline/source.py
+++ b/mlnode/iotdb/mlnode/data_access/offline/source.py
@@ -72,25 +72,9 @@ class ThriftDataSource(DataSource):
         except Exception:
             raise RuntimeError('Fail to establish connection with DataNode')
 
-        try:
-            res = data_client.fetch_timeseries(
-                query_expressions=self.query_expressions,
-                query_filter=self.query_filter,
-            )
-        except Exception:
-            raise RuntimeError(f'Fail to fetch data with query expressions: {self.query_expressions}'
-                               f' and query filter: {self.query_filter}')
-
-        if len(res.tsDataset) == 0:
-            raise RuntimeError(f'No data fetched with query filter: {self.query_filter}')
-
-        raw_data = serde.convert_to_df(res.columnNameList,
-                                       res.columnTypeList,
-                                       res.columnNameIndexMap,
-                                       res.tsDataset)
-        if raw_data.empty:
-            raise RuntimeError(f'Fetched empty data with query expressions: '
-                               f'{self.query_expressions} and query filter: {self.query_filter}')
+        query_id, has_more_data, raw_data = data_client.fetch_timeseries(self.query_expressions, self.query_filter)
+        # TODO: consider has_more_data
+
         cols_data = raw_data.columns[1:]
         self.data = raw_data[cols_data].values
         self.timestamp = pd.to_datetime(raw_data[raw_data.columns[0]].values, unit='ms', utc=True) \
