This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 03e701860ad [AINode] Add window_step options for dataset (#15857)
03e701860ad is described below
commit 03e701860ad303b732688faabc77a9a146518de4
Author: Yongzao <[email protected]>
AuthorDate: Wed Jul 2 19:54:30 2025 +0800
[AINode] Add window_step options for dataset (#15857)
---
iotdb-core/ainode/ainode/core/ingress/dataset.py | 2 ++
iotdb-core/ainode/ainode/core/ingress/iotdb.py | 40 ++++++++++++++++++------
2 files changed, 32 insertions(+), 10 deletions(-)
diff --git a/iotdb-core/ainode/ainode/core/ingress/dataset.py b/iotdb-core/ainode/ainode/core/ingress/dataset.py
index 316c4235067..4e3b5293c16 100644
--- a/iotdb-core/ainode/ainode/core/ingress/dataset.py
+++ b/iotdb-core/ainode/ainode/core/ingress/dataset.py
@@ -33,6 +33,7 @@ class BasicDatabaseForecastDataset(BasicDatabaseDataset):
seq_len: int,
input_token_len: int,
output_token_len: int,
+ window_step: int,
):
super().__init__(ip, port)
# The number of the time series data points of the model input
@@ -42,3 +43,4 @@ class BasicDatabaseForecastDataset(BasicDatabaseDataset):
# The number of the time series data points of the model output
self.output_token_len = output_token_len
self.token_num = self.seq_len // self.input_token_len
+ self.window_step = window_step
diff --git a/iotdb-core/ainode/ainode/core/ingress/iotdb.py b/iotdb-core/ainode/ainode/core/ingress/iotdb.py
index b9e844d9193..528c3cb7397 100644
--- a/iotdb-core/ainode/ainode/core/ingress/iotdb.py
+++ b/iotdb-core/ainode/ainode/core/ingress/iotdb.py
@@ -59,6 +59,7 @@ class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
seq_len: int,
input_token_len: int,
output_token_len: int,
+ window_step: int,
data_schema_list: list,
ip: str =
AINodeDescriptor().get_config().get_ain_cluster_ingress_address(),
port: int =
AINodeDescriptor().get_config().get_ain_cluster_ingress_port(),
@@ -74,7 +75,9 @@ class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
use_rate: float = 1.0,
offset_rate: float = 0.0,
):
- super().__init__(ip, port, seq_len, input_token_len, output_token_len)
+ super().__init__(
+ ip, port, seq_len, input_token_len, output_token_len, window_step
+ )
self.SHOW_TIMESERIES = "show timeseries %s%s"
self.COUNT_SERIES_SQL = "select count(%s) from %s%s"
@@ -139,7 +142,9 @@ class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
window_sum = 0
for seq_name, seq_value in sorted_series:
# calculate and sum the number of training data windows for each
time series
- window_count = seq_value[1] - self.seq_len - self.output_token_len
+ 1
+ window_count = (
+ seq_value[1] - self.seq_len - self.output_token_len + 1
+ ) // self.window_step
if window_count <= 1:
continue
use_window_count = int(window_count * self.use_rate)
@@ -176,14 +181,16 @@ class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
# try to get the training data window from cache first
series_data = torch.tensor(series_data)
result = series_data[
- window_index : window_index + self.seq_len +
self.output_token_len
+ window_index * self.window_step : window_index *
self.window_step
+ + self.seq_len
+ + self.output_token_len
]
return (
result[0 : self.seq_len],
result[self.input_token_len : self.seq_len +
self.output_token_len],
np.ones(self.token_num, dtype=np.int32),
)
- result = []
+ series_data = []
sql = ""
try:
if self.cache_enable:
@@ -204,12 +211,18 @@ class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
)
with self.session.execute_query_statement(sql) as query_result:
while query_result.has_next():
-
result.append(get_field_value(query_result.next().get_fields()[0]))
+ series_data.append(
+ get_field_value(query_result.next().get_fields()[0])
+ )
except Exception as e:
logger.error("Executing sql: {} with exception: {}".format(sql, e))
if self.cache_enable:
- self.cache.put(cache_key, result)
- result = torch.tensor(result)
+ self.cache.put(cache_key, series_data)
+ result = series_data[
+ window_index * self.window_step : window_index * self.window_step
+ + self.seq_len
+ + self.output_token_len
+ ]
return (
result[0 : self.seq_len],
result[self.input_token_len : self.seq_len +
self.output_token_len],
@@ -230,6 +243,7 @@ class IoTDBTableModelDataset(BasicDatabaseForecastDataset):
seq_len: int,
input_token_len: int,
output_token_len: int,
+ window_step: int,
data_schema_list: list,
ip: str =
AINodeDescriptor().get_config().get_ain_cluster_ingress_address(),
port: int =
AINodeDescriptor().get_config().get_ain_cluster_ingress_port(),
@@ -245,7 +259,9 @@ class IoTDBTableModelDataset(BasicDatabaseForecastDataset):
use_rate: float = 1.0,
offset_rate: float = 0.0,
):
- super().__init__(ip, port, seq_len, input_token_len, output_token_len)
+ super().__init__(
+ ip, port, seq_len, input_token_len, output_token_len, window_step
+ )
table_session_config = TableSessionConfig(
node_urls=[f"{ip}:{port}"],
@@ -302,7 +318,9 @@ class IoTDBTableModelDataset(BasicDatabaseForecastDataset):
window_sum = 0
for seq_name, seq_values in series_map.items():
# calculate and sum the number of training data windows for each
time series
- window_count = len(seq_values) - self.seq_len -
self.output_token_len + 1
+ window_count = (
+ len(seq_values) - self.seq_len - self.output_token_len + 1
+ ) // self.window_step
if window_count <= 1:
continue
use_window_count = int(window_count * self.use_rate)
@@ -331,7 +349,9 @@ class IoTDBTableModelDataset(BasicDatabaseForecastDataset):
window_index -= self.series_with_prefix_sum[series_index - 1][2]
window_index += self.series_with_prefix_sum[series_index][3]
result = self.series_with_prefix_sum[series_index][4][
- window_index : window_index + self.seq_len + self.output_token_len
+ window_index * self.window_step : window_index * self.window_step
+ + self.seq_len
+ + self.output_token_len
]
result = torch.tensor(result)
return (