This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 03e701860ad [AINode] Add window_step options for dataset (#15857)
03e701860ad is described below

commit 03e701860ad303b732688faabc77a9a146518de4
Author: Yongzao <[email protected]>
AuthorDate: Wed Jul 2 19:54:30 2025 +0800

    [AINode] Add window_step options for dataset (#15857)
---
 iotdb-core/ainode/ainode/core/ingress/dataset.py |  2 ++
 iotdb-core/ainode/ainode/core/ingress/iotdb.py   | 40 ++++++++++++++++++------
 2 files changed, 32 insertions(+), 10 deletions(-)

diff --git a/iotdb-core/ainode/ainode/core/ingress/dataset.py b/iotdb-core/ainode/ainode/core/ingress/dataset.py
index 316c4235067..4e3b5293c16 100644
--- a/iotdb-core/ainode/ainode/core/ingress/dataset.py
+++ b/iotdb-core/ainode/ainode/core/ingress/dataset.py
@@ -33,6 +33,7 @@ class BasicDatabaseForecastDataset(BasicDatabaseDataset):
         seq_len: int,
         input_token_len: int,
         output_token_len: int,
+        window_step: int,
     ):
         super().__init__(ip, port)
         # The number of the time series data points of the model input
@@ -42,3 +43,4 @@ class BasicDatabaseForecastDataset(BasicDatabaseDataset):
         # The number of the time series data points of the model output
         self.output_token_len = output_token_len
         self.token_num = self.seq_len // self.input_token_len
+        self.window_step = window_step
diff --git a/iotdb-core/ainode/ainode/core/ingress/iotdb.py b/iotdb-core/ainode/ainode/core/ingress/iotdb.py
index b9e844d9193..528c3cb7397 100644
--- a/iotdb-core/ainode/ainode/core/ingress/iotdb.py
+++ b/iotdb-core/ainode/ainode/core/ingress/iotdb.py
@@ -59,6 +59,7 @@ class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
         seq_len: int,
         input_token_len: int,
         output_token_len: int,
+        window_step: int,
         data_schema_list: list,
         ip: str = AINodeDescriptor().get_config().get_ain_cluster_ingress_address(),
         port: int = AINodeDescriptor().get_config().get_ain_cluster_ingress_port(),
@@ -74,7 +75,9 @@ class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
         use_rate: float = 1.0,
         offset_rate: float = 0.0,
     ):
-        super().__init__(ip, port, seq_len, input_token_len, output_token_len)
+        super().__init__(
+            ip, port, seq_len, input_token_len, output_token_len, window_step
+        )
 
         self.SHOW_TIMESERIES = "show timeseries %s%s"
         self.COUNT_SERIES_SQL = "select count(%s) from %s%s"
@@ -139,7 +142,9 @@ class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
         window_sum = 0
         for seq_name, seq_value in sorted_series:
             # calculate and sum the number of training data windows for each time series
-            window_count = seq_value[1] - self.seq_len - self.output_token_len + 1
+            window_count = (
+                seq_value[1] - self.seq_len - self.output_token_len + 1
+            ) // self.window_step
             if window_count <= 1:
                 continue
             use_window_count = int(window_count * self.use_rate)
@@ -176,14 +181,16 @@ class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
                 # try to get the training data window from cache first
                 series_data = torch.tensor(series_data)
                 result = series_data[
-                    window_index : window_index + self.seq_len + self.output_token_len
+                    window_index * self.window_step : window_index * self.window_step
+                    + self.seq_len
+                    + self.output_token_len
                 ]
                 return (
                     result[0 : self.seq_len],
                     result[self.input_token_len : self.seq_len + self.output_token_len],
                     np.ones(self.token_num, dtype=np.int32),
                 )
-        result = []
+        series_data = []
         sql = ""
         try:
             if self.cache_enable:
@@ -204,12 +211,18 @@ class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
                 )
             with self.session.execute_query_statement(sql) as query_result:
                 while query_result.has_next():
-                    result.append(get_field_value(query_result.next().get_fields()[0]))
+                    series_data.append(
+                        get_field_value(query_result.next().get_fields()[0])
+                    )
         except Exception as e:
             logger.error("Executing sql: {} with exception: {}".format(sql, e))
         if self.cache_enable:
-            self.cache.put(cache_key, result)
-        result = torch.tensor(result)
+            self.cache.put(cache_key, series_data)
+        result = series_data[
+            window_index * self.window_step : window_index * self.window_step
+            + self.seq_len
+            + self.output_token_len
+        ]
         return (
             result[0 : self.seq_len],
             result[self.input_token_len : self.seq_len + self.output_token_len],
@@ -230,6 +243,7 @@ class IoTDBTableModelDataset(BasicDatabaseForecastDataset):
         seq_len: int,
         input_token_len: int,
         output_token_len: int,
+        window_step: int,
         data_schema_list: list,
         ip: str = AINodeDescriptor().get_config().get_ain_cluster_ingress_address(),
         port: int = AINodeDescriptor().get_config().get_ain_cluster_ingress_port(),
@@ -245,7 +259,9 @@ class IoTDBTableModelDataset(BasicDatabaseForecastDataset):
         use_rate: float = 1.0,
         offset_rate: float = 0.0,
     ):
-        super().__init__(ip, port, seq_len, input_token_len, output_token_len)
+        super().__init__(
+            ip, port, seq_len, input_token_len, output_token_len, window_step
+        )
 
         table_session_config = TableSessionConfig(
             node_urls=[f"{ip}:{port}"],
@@ -302,7 +318,9 @@ class IoTDBTableModelDataset(BasicDatabaseForecastDataset):
         window_sum = 0
         for seq_name, seq_values in series_map.items():
             # calculate and sum the number of training data windows for each time series
-            window_count = len(seq_values) - self.seq_len - self.output_token_len + 1
+            window_count = (
+                len(seq_values) - self.seq_len - self.output_token_len + 1
+            ) // self.window_step
             if window_count <= 1:
                 continue
             use_window_count = int(window_count * self.use_rate)
@@ -331,7 +349,9 @@ class IoTDBTableModelDataset(BasicDatabaseForecastDataset):
             window_index -= self.series_with_prefix_sum[series_index - 1][2]
         window_index += self.series_with_prefix_sum[series_index][3]
         result = self.series_with_prefix_sum[series_index][4][
-            window_index : window_index + self.seq_len + self.output_token_len
+            window_index * self.window_step : window_index * self.window_step
+            + self.seq_len
+            + self.output_token_len
         ]
         result = torch.tensor(result)
         return (

Reply via email to