This is an automated email from the ASF dual-hosted git repository.

ycycse pushed a commit to branch ycy/dataLoader
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 133aeab8c0c0360ed87096b24299c6d2bba06afe
Author: YangCaiyin <[email protected]>
AuthorDate: Tue Jun 10 10:21:21 2025 +0800

    init dataset module
---
 iotdb-core/ainode/ainode/core/config.py            |   8 +
 .../ainode/ainode/core/dataProvider/__init__.py    |  17 ++
 .../ainode/ainode/core/dataProvider/dataset.py     |  62 +++++
 .../ainode/ainode/core/dataProvider/iotdb.py       | 307 +++++++++++++++++++++
 iotdb-core/ainode/ainode/core/util/cache.py        |  85 ++++++
 5 files changed, 479 insertions(+)

diff --git a/iotdb-core/ainode/ainode/core/config.py b/iotdb-core/ainode/ainode/core/config.py
index 62de76fcbb8..29d925ec086 100644
--- a/iotdb-core/ainode/ainode/core/config.py
+++ b/iotdb-core/ainode/ainode/core/config.py
@@ -52,6 +52,8 @@ class AINodeConfig(object):
         # log directory
         self._ain_logs_dir: str = AINODE_LOG_DIR
 
+        self._ain_data_cache_size = 50
+
         # Directory to save models
         self._ain_models_dir = AINODE_MODELS_DIR
 
@@ -94,6 +96,12 @@ class AINodeConfig(object):
     def set_build_info(self, build_info: str) -> None:
         self._build_info = build_info
 
+    def get_ain_data_storage_cache_size(self) -> int:
+        """Return the dataset cache budget in MiB; 0 disables the cache."""
+        return self._ain_data_cache_size
+
+    def get_ain_data_cache_size(self) -> int:
+        """Alias of get_ain_data_storage_cache_size, named to match the setter."""
+        return self.get_ain_data_storage_cache_size()
+
+    def set_ain_data_cache_size(self, ain_data_cache_size: int) -> None:
+        """Set the dataset cache budget in MiB; 0 disables the cache.
+
+        :raises ValueError: if ``ain_data_cache_size`` is negative, which would
+            otherwise give the LRU cache a negative byte budget.
+        """
+        if ain_data_cache_size < 0:
+            raise ValueError(
+                f"ain_data_cache_size must be non-negative, got {ain_data_cache_size}"
+            )
+        self._ain_data_cache_size = ain_data_cache_size
+
     def set_version_info(self, version_info: str) -> None:
         self._version_info = version_info
 
diff --git a/iotdb-core/ainode/ainode/core/dataProvider/__init__.py b/iotdb-core/ainode/ainode/core/dataProvider/__init__.py
new file mode 100644
index 00000000000..2a1e720805f
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/dataProvider/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
diff --git a/iotdb-core/ainode/ainode/core/dataProvider/dataset.py b/iotdb-core/ainode/ainode/core/dataProvider/dataset.py
new file mode 100644
index 00000000000..b5e8493cefc
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/dataProvider/dataset.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from torch.utils.data import Dataset
+
+from ainode.core.dataProvider.iotdb import IoTDBTableModelDataset, 
IoTDBTreeModelDataset
+from ainode.core.util.decorator import singleton
+
+
+class BasicDatabaseDataset(Dataset):
+    """Base dataset for data served by a database reachable at ``ip:port``."""
+
+    def __init__(self, ip: str, port: int):
+        # Keep the endpoint; subclasses open their own database sessions.
+        self.ip, self.port = ip, port
+
+
+class BasicDatabaseForecastDataset(BasicDatabaseDataset):
+    """Database dataset whose samples pair an input window with a forecast window."""
+
+    def __init__(self, ip: str, port: int, input_len: int, output_len: int):
+        super().__init__(ip, port)
+        # History length fed to the model and horizon length it must predict.
+        self.input_len, self.output_len = input_len, output_len
+
+
+def register_dataset(key: str, dataset: Dataset):
+    """Register ``dataset`` with the shared DatasetFactory under ``key``.
+
+    Raises KeyError when ``key`` is already taken.
+    """
+    factory = DatasetFactory()
+    factory.register(key, dataset)
+
+
+@singleton
+class DatasetFactory(object):
+    """Registry mapping dataset keys (e.g. ``"iotdb.tree"``) to dataset classes."""
+
+    def __init__(self):
+        # Imported lazily: iotdb.py imports BasicDatabaseForecastDataset from
+        # this module, so importing iotdb.py at module import time is circular.
+        # NOTE(review): the module-level import of IoTDBTableModelDataset /
+        # IoTDBTreeModelDataset in this file must also be removed for the
+        # cycle to actually be broken.
+        from ainode.core.dataProvider.iotdb import (
+            IoTDBTableModelDataset,
+            IoTDBTreeModelDataset,
+        )
+
+        self.dataset_list = {
+            "iotdb.table": IoTDBTableModelDataset,
+            "iotdb.tree": IoTDBTreeModelDataset,
+        }
+
+    def register(self, key: str, dataset: Dataset):
+        """Register ``dataset`` under ``key``; raise KeyError if ``key`` exists."""
+        if key in self.dataset_list:
+            raise KeyError(f"Dataset {key} already exists")
+        self.dataset_list[key] = dataset
+
+    def deregister(self, key: str):
+        """Remove ``key``; raise KeyError if it is not registered."""
+        if key not in self.dataset_list:
+            raise KeyError(f"Dataset {key} does not exist")
+        del self.dataset_list[key]
+
+    def get_dataset(self, key: str):
+        """Return the dataset class registered under ``key``; raise KeyError if absent."""
+        if key not in self.dataset_list:
+            raise KeyError(f"Dataset {key} does not exist")
+        return self.dataset_list[key]
diff --git a/iotdb-core/ainode/ainode/core/dataProvider/iotdb.py b/iotdb-core/ainode/ainode/core/dataProvider/iotdb.py
new file mode 100644
index 00000000000..6e43ec568d4
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/dataProvider/iotdb.py
@@ -0,0 +1,307 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import torch
+from iotdb.Session import Session
+from iotdb.table_session import TableSession, TableSessionConfig
+from iotdb.utils.Field import Field
+from iotdb.utils.IoTDBConstants import TSDataType
+from util.cache import MemoryLRUCache
+
+from ainode.core.config import AINodeDescriptor
+from ainode.core.dataProvider.dataset import BasicDatabaseForecastDataset
+from ainode.core.log import Logger
+
+logger = Logger()
+
+
+def get_field_value(field: Field):
+    """Convert an IoTDB ``Field`` into a Python value according to its data type.
+
+    INT32 / INT64 / FLOAT / DOUBLE use the matching numeric getter; every
+    other type (including an unset type) falls back to the string getter.
+    """
+    data_type = field.get_data_type()
+    numeric_getters = (
+        (TSDataType.INT32, field.get_int_value),
+        (TSDataType.INT64, field.get_long_value),
+        (TSDataType.FLOAT, field.get_float_value),
+        (TSDataType.DOUBLE, field.get_double_value),
+    )
+    for candidate_type, getter in numeric_getters:
+        if data_type == candidate_type:
+            return getter()
+    return field.get_string_value()
+
+
+def _cache_enable() -> bool:
+    """Return True when a positive dataset cache budget is configured.
+
+    Zero disables the cache; a negative value is treated as disabled too,
+    since it would leave MemoryLRUCache with a negative byte budget.
+    """
+    return AINodeDescriptor().get_config().get_ain_data_storage_cache_size() > 0
+
+
+class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
+    """Sliding-window forecast dataset over IoTDB tree-model time series.
+
+    Every series matched by ``schema_list`` contributes
+    ``length - input_len - output_len + 1`` windows; series too short for one
+    window are skipped. Item ``i`` is ``(x, y)`` where ``x`` is the first
+    ``input_len`` points of the window and ``y`` the last ``output_len``
+    points, each with a trailing dimension of size 1. ``start_split`` and
+    ``end_split`` select a contiguous fraction of all windows.
+    """
+
+    # Shared (singleton) LRU cache of whole series, keyed by model id + path.
+    cache = MemoryLRUCache()
+
+    def __init__(
+        self,
+        model_id: str,
+        input_len: int,
+        out_len: int,
+        schema_list: list,
+        ip: str = "127.0.0.1",
+        port: int = 6667,
+        username: str = "root",
+        password: str = "root",
+        time_zone: str = "UTC+8",
+        start_split: float = 0,
+        end_split: float = 1,
+    ):
+        super().__init__(ip, port, input_len, out_len)
+        if end_split < start_split:
+            raise ValueError("end_split must be greater than start_split")
+
+        self.SHOW_TIMESERIES = "show timeseries %s%s"
+        self.COUNT_SERIES_SQL = "select count(%s) from %s%s"
+        self.FETCH_SERIES_SQL = "select %s from %s%s"
+        # The WHERE clause must come before OFFSET/LIMIT.
+        self.FETCH_SERIES_RANGE_SQL = "select %s from %s%s offset %s limit %s"
+
+        self.TIME_CONDITION = " where time>%s and time<%s"
+
+        self.session = Session.init_from_node_urls(
+            node_urls=[f"{ip}:{port}"],
+            user=username,
+            password=password,
+            zone_id=time_zone,
+        )
+        self.session.open(False)
+        self.context_length = self.input_len + self.output_len
+        self._fetch_schema(schema_list)
+        self.start_idx = int(self.total_count * start_split)
+        self.end_idx = int(self.total_count * end_split)
+        self.cache_enable = _cache_enable()
+        self.cache_key_prefix = model_id + "_"
+
+    def _fetch_schema(self, schema_list: list):
+        """Resolve ``schema_list`` into series and precompute window offsets.
+
+        Sets ``self.sorted_series`` to ``(path_parts, window_count,
+        cumulative_windows, time_condition)`` tuples ordered by series length,
+        ``self._cumulative_windows`` to the cumulative counts alone, and
+        ``self.total_count`` to the total number of windows.
+        """
+        series_to_length = {}
+        for schema in schema_list:
+            path_pattern = schema.schemaName
+            series_list = []
+
+            if schema.timeRange:
+                time_condition = self.TIME_CONDITION % (
+                    schema.timeRange[0],
+                    schema.timeRange[1],
+                )
+            else:
+                time_condition = ""
+
+            with self.session.execute_query_statement(
+                self.SHOW_TIMESERIES % (path_pattern, time_condition)
+            ) as show_timeseries_result:
+                while show_timeseries_result.has_next():
+                    series_list.append(
+                        get_field_value(show_timeseries_result.next().get_fields()[0])
+                    )
+
+            for series in series_list:
+                split_series = series.split(".")
+                with self.session.execute_query_statement(
+                    self.COUNT_SERIES_SQL
+                    % (split_series[-1], ".".join(split_series[:-1]), time_condition)
+                ) as count_series_result:
+                    while count_series_result.has_next():
+                        length = get_field_value(
+                            count_series_result.next().get_fields()[0]
+                        )
+                        series_to_length[series] = (
+                            split_series,
+                            length,
+                            time_condition,
+                        )
+
+        sorted_series = sorted(series_to_length.items(), key=lambda x: x[1][1])
+        sorted_series_with_prefix_sum = []
+        window_sum = 0
+        for _, (split_series, length, time_condition) in sorted_series:
+            window_count = length - self.context_length + 1
+            if window_count <= 0:
+                continue  # too short for even one window
+            window_sum += window_count
+            sorted_series_with_prefix_sum.append(
+                (split_series, window_count, window_sum, time_condition)
+            )
+
+        self.total_count = window_sum
+        self.sorted_series = sorted_series_with_prefix_sum
+        self._cumulative_windows = [entry[2] for entry in sorted_series_with_prefix_sum]
+
+    def _locate(self, index: int):
+        """Map a dataset index to ``(sorted_series entry, window offset)``.
+
+        ``index`` is relative to the ``start_split`` boundary. Windows of the
+        k-th series occupy global positions
+        ``[cumulative[k-1], cumulative[k])``.
+        """
+        from bisect import bisect_right
+
+        if not 0 <= index < len(self):
+            raise IndexError(f"index {index} out of range for dataset of size {len(self)}")
+        global_index = self.start_idx + index
+        series_index = bisect_right(self._cumulative_windows, global_index)
+        preceding = self._cumulative_windows[series_index - 1] if series_index else 0
+        return self.sorted_series[series_index], global_index - preceding
+
+    def _query_values(self, sql: str):
+        """Run ``sql`` and return the first column of every row, or None on error."""
+        values = []
+        try:
+            with self.session.execute_query_statement(sql) as query_result:
+                while query_result.has_next():
+                    values.append(get_field_value(query_result.next().get_fields()[0]))
+        except Exception as e:
+            logger.error(e)
+            return None
+        return values
+
+    def __getitem__(self, index):
+        """Return ``(x, y)`` tensors for one window.
+
+        For a complete window, ``x`` has shape ``(input_len, 1)`` and ``y``
+        has shape ``(output_len, 1)``. If the query fails, the error is logged
+        and empty tensors are returned (best effort).
+        """
+        (series, _, _, time_condition), window_index = self._locate(index)
+        measurement = series[-1]
+        device_path = ".".join(series[0:-1])
+
+        if self.cache_enable:
+            cache_key = self.cache_key_prefix + ".".join(series) + time_condition
+            series_data = self.cache.get(cache_key)
+            if series_data is None:
+                series_data = self._query_values(
+                    self.FETCH_SERIES_SQL % (measurement, device_path, time_condition)
+                )
+                if series_data is None:
+                    # Failed read: serve nothing and do not poison the cache.
+                    series_data = []
+                else:
+                    self.cache.put(cache_key, series_data)
+            # The cache holds the whole series; cut out just this window.
+            window = series_data[window_index : window_index + self.context_length]
+        else:
+            window = (
+                self._query_values(
+                    self.FETCH_SERIES_RANGE_SQL
+                    % (
+                        measurement,
+                        device_path,
+                        time_condition,
+                        window_index,
+                        self.context_length,
+                    )
+                )
+                or []
+            )
+
+        result = torch.tensor(window)
+        return result[0 : self.input_len].unsqueeze(-1), result[
+            -self.output_len :
+        ].unsqueeze(-1)
+
+    def close(self):
+        """Close the IoTDB session opened in ``__init__``."""
+        self.session.close()
+
+    def __len__(self):
+        return self.end_idx - self.start_idx
+
+
+class IoTDBTableModelDataset(BasicDatabaseForecastDataset):
+    """Sliding-window forecast dataset over IoTDB table-model data.
+
+    Each ``schemaName`` in ``data_schema_list`` is queried as a table with
+    ``item_id`` and ``value`` columns; every distinct ``item_id`` is one
+    series contributing ``length - input_len - output_len + 1`` windows.
+    ``start_split`` and ``end_split`` select a contiguous fraction of all
+    windows.
+    """
+
+    def __init__(
+        self,
+        input_len: int,
+        out_len: int,
+        data_schema_list: list,
+        ip: str = "127.0.0.1",
+        port: int = 6667,
+        username: str = "root",
+        password: str = "root",
+        time_zone: str = "UTC+8",
+        start_split: float = 0,
+        end_split: float = 1,
+    ):
+        super().__init__(ip, port, input_len, out_len)
+        if end_split < start_split:
+            raise ValueError("end_split must be greater than start_split")
+
+        # Parameters: table name (and item_id where applicable).
+        self.SELECT_SERIES_FORMAT_SQL = "select distinct item_id from %s"
+        self.COUNT_SERIES_LENGTH_SQL = (
+            "select count(value) from %s where item_id = '%s'"
+        )
+        # NOTE(review): there is no ORDER BY time; offset/limit windows assume
+        # rows come back in time order — confirm for the table model.
+        self.FETCH_SERIES_SQL = (
+            "select value from %s where item_id = '%s' offset %s limit %s"
+        )
+        self.SERIES_NAME = "%s.%s"
+
+        table_session_config = TableSessionConfig(
+            node_urls=[f"{ip}:{port}"],
+            username=username,
+            password=password,
+            time_zone=time_zone,
+        )
+
+        self.session = TableSession(table_session_config)
+        self.context_length = self.input_len + self.output_len
+        self._fetch_schema(data_schema_list)
+
+        self.start_index = int(self.total_count * start_split)
+        # int() so that __len__ returns an int, as len() requires.
+        self.end_index = int(self.total_count * end_split)
+
+    @staticmethod
+    def _quote(value) -> str:
+        """Escape ``value`` for use inside a single-quoted SQL string literal."""
+        return str(value).replace("'", "''")
+
+    def _fetch_schema(self, data_schema_list: list):
+        """Enumerate series per table and precompute window offsets.
+
+        Sets ``self.sorted_series`` to ``((table, item_id), window_count,
+        cumulative_windows)`` tuples ordered by series length,
+        ``self._cumulative_windows`` to the cumulative counts alone, and
+        ``self.total_count`` to the total number of windows.
+        """
+        series_to_length = {}
+        for data_schema in data_schema_list:
+            table_name = data_schema.schemaName
+            item_ids = []
+            with self.session.execute_query_statement(
+                self.SELECT_SERIES_FORMAT_SQL % table_name
+            ) as show_devices_result:
+                while show_devices_result.has_next():
+                    item_ids.append(
+                        get_field_value(show_devices_result.next().get_fields()[0])
+                    )
+
+            for item_id in item_ids:
+                with self.session.execute_query_statement(
+                    self.COUNT_SERIES_LENGTH_SQL % (table_name, self._quote(item_id))
+                ) as count_series_result:
+                    if not count_series_result.has_next():
+                        continue
+                    length = get_field_value(count_series_result.next().get_fields()[0])
+                    series_to_length[(table_name, item_id)] = length
+
+        sorted_series = sorted(series_to_length.items(), key=lambda x: x[1])
+        sorted_series_with_prefix_sum = []
+        window_sum = 0
+        for series_key, seq_length in sorted_series:
+            window_count = seq_length - self.context_length + 1
+            if window_count <= 0:
+                continue  # too short for even one window
+            window_sum += window_count
+            sorted_series_with_prefix_sum.append((series_key, window_count, window_sum))
+
+        self.total_count = window_sum
+        self.sorted_series = sorted_series_with_prefix_sum
+        self._cumulative_windows = [entry[2] for entry in sorted_series_with_prefix_sum]
+
+    def _locate(self, index: int):
+        """Map a dataset index to ``(sorted_series entry, window offset)``.
+
+        ``index`` is relative to the ``start_split`` boundary. Windows of the
+        k-th series occupy global positions
+        ``[cumulative[k-1], cumulative[k])``.
+        """
+        from bisect import bisect_right
+
+        if not 0 <= index < len(self):
+            raise IndexError(f"index {index} out of range for dataset of size {len(self)}")
+        global_index = self.start_index + index
+        series_index = bisect_right(self._cumulative_windows, global_index)
+        preceding = self._cumulative_windows[series_index - 1] if series_index else 0
+        return self.sorted_series[series_index], global_index - preceding
+
+    def __getitem__(self, index):
+        """Return ``(x, y)`` tensors for one window.
+
+        For a complete window, ``x`` has shape ``(input_len, 1)`` and ``y``
+        has shape ``(output_len, 1)``. If the query fails, the error is logged
+        and empty tensors are returned (best effort).
+        """
+        ((table_name, item_id), _, _), window_index = self._locate(index)
+
+        result = []
+        try:
+            with self.session.execute_query_statement(
+                self.FETCH_SERIES_SQL
+                % (table_name, self._quote(item_id), window_index, self.context_length)
+            ) as query_result:
+                while query_result.has_next():
+                    result.append(get_field_value(query_result.next().get_fields()[0]))
+        except Exception as e:
+            logger.error(
+                f"Error happens when loading dataset "
+                f"{self.SERIES_NAME % (table_name, item_id)}: {e}"
+            )
+            result = []  # discard a partial window
+        result = torch.tensor(result)
+        return result[0 : self.input_len].unsqueeze(-1), result[
+            -self.output_len :
+        ].unsqueeze(-1)
+
+    def close(self):
+        """Close the IoTDB table session opened in ``__init__``."""
+        self.session.close()
+
+    def __len__(self):
+        return self.end_index - self.start_index
diff --git a/iotdb-core/ainode/ainode/core/util/cache.py b/iotdb-core/ainode/ainode/core/util/cache.py
new file mode 100644
index 00000000000..fb5a572a04d
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/util/cache.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import sys
+from collections import OrderedDict
+
+from ainode.core.config import AINodeDescriptor
+from ainode.core.util.decorator import singleton
+
+
+def _estimate_size(obj):
+    """Roughly estimate the memory footprint of ``obj`` in bytes.
+
+    Uses fixed per-object overheads for str / int / list / tuple / dict and
+    recurses into containers; anything else falls back to ``sys.getsizeof``,
+    which is shallow. Good enough for LRU budgeting, not exact.
+    """
+    if isinstance(obj, str):
+        return len(obj) + 49
+    if isinstance(obj, int):
+        return 28
+    if isinstance(obj, (list, tuple)):
+        # Tuples recurse like lists; sys.getsizeof would ignore their contents.
+        return 64 + sum(_estimate_size(x) for x in obj)
+    if isinstance(obj, dict):
+        return 280 + sum(_estimate_size(k) + _estimate_size(v) for k, v in obj.items())
+    return sys.getsizeof(obj)
+
+
+def _get_item_memory(key, value) -> int:
+    """Estimated bytes charged to the cache for one ``key -> value`` entry."""
+    return sum(map(_estimate_size, (key, value)))
+
+
+@singleton
+class MemoryLRUCache:
+    """LRU cache bounded by an estimated memory budget.
+
+    The budget is ``get_ain_data_storage_cache_size()`` MiB, read once at
+    construction. Entry sizes are estimates from ``_get_item_memory``.
+    """
+
+    def __init__(self):
+        # key -> value, least recently used first.
+        self.cache = OrderedDict()
+        self.max_memory_bytes = (
+            AINodeDescriptor().get_config().get_ain_data_storage_cache_size()
+            * 1024
+            * 1024
+        )
+        # Sum of the estimated sizes of all entries currently stored.
+        self.current_memory = 0
+
+    def get(self, key):
+        """Return the value for ``key`` and mark it most recently used, or None."""
+        if key not in self.cache:
+            return None
+        self.cache.move_to_end(key)
+        return self.cache[key]
+
+    def put(self, key, value):
+        """Insert or replace ``key``, evicting LRU entries to stay within budget.
+
+        An entry whose own estimated size exceeds the whole budget is not
+        stored (any previous value for ``key`` is still dropped), because
+        admitting it would evict everything else and still exceed the limit.
+        """
+        # Remove any old entry first so its size is subtracted exactly once;
+        # otherwise eviction below could pop it and subtract it a second time.
+        if key in self.cache:
+            old_value = self.cache.pop(key)
+            self.current_memory -= _get_item_memory(key, old_value)
+
+        item_memory = _get_item_memory(key, value)
+        if item_memory > self.max_memory_bytes:
+            return
+
+        self.current_memory += item_memory
+        self._evict_if_needed()
+        # Inserting a fresh key places it at the most-recently-used end.
+        self.cache[key] = value
+
+    def _evict_if_needed(self):
+        """Pop least-recently-used entries until within budget or empty."""
+        while self.current_memory > self.max_memory_bytes and self.cache:
+            key, value = self.cache.popitem(last=False)
+            self.current_memory -= _get_item_memory(key, value)
+
+    def get_current_memory_mb(self) -> float:
+        """Return the estimated memory held by the cache, in MiB."""
+        return self.current_memory / (1024 * 1024)

Reply via email to