This is an automated email from the ASF dual-hosted git repository.

Caideyipi pushed a commit to branch hotfix/2.0.9.4-sjzt
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit a50eea9cbe5493b3202d421b32381ca73e55dc99
Author: 陈荣钊 <[email protected]>
AuthorDate: Sat Jun 6 09:48:42 2026 +0800

    [TIMECHODB][AINode] Support NaN auto-adapt fill for AINode forecast
---
 .../core/inference/pipeline/basic_pipeline.py      |  56 ++++++++-
 iotdb-core/ainode/test/test_forecast_auto_adapt.py | 135 +++++++++++++++++++++
 .../timecho/ainode/core/ingress/data_fetcher.py    |   4 +-
 .../function/tvf/TimechoForecastTableFunction.java |  48 +++++++-
 .../relational/analyzer/TableFunctionTest.java     |  57 +++++++++
 5 files changed, 292 insertions(+), 8 deletions(-)

diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
index f4bf914a846..4c9f18644d0 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
@@ -30,6 +30,44 @@ from iotdb.ainode.core.model.model_loader import load_model
 BACKEND = DeviceManager()
 logger = Logger()
 
+AUTO_ADAPT_FILL_VALUE_KEY = "auto_adapt_fill_value"
+_AUTO_ADAPT_ZERO_VALUES = {"0", "0.0", "zero"}
+_AUTO_ADAPT_NAN_VALUES = {"nan"}
+
+
+def _parse_auto_adapt_fill_value(infer_kwargs: dict) -> float:
+    raw_fill_value = infer_kwargs.get(AUTO_ADAPT_FILL_VALUE_KEY)
+    if raw_fill_value is None:
+        raw_fill_value = next(
+            (
+                value
+                for key, value in infer_kwargs.items()
+                if str(key).lower() == AUTO_ADAPT_FILL_VALUE_KEY
+            ),
+            "0",
+        )
+    if raw_fill_value is None:
+        return 0.0
+    normalized_fill_value = str(raw_fill_value).strip().lower()
+    if normalized_fill_value in _AUTO_ADAPT_ZERO_VALUES:
+        return 0.0
+    if normalized_fill_value in _AUTO_ADAPT_NAN_VALUES:
+        return torch.nan
+    raise ValueError(
+        f"Unsupported {AUTO_ADAPT_FILL_VALUE_KEY}: {raw_fill_value}. "
+        "Expected one of ['0', 'NaN']."
+    )
+
+
+def _pad_1d_tensor(
+    tensor: torch.Tensor, left_pad_size: int, right_pad_size: int, fill_value: float
+) -> torch.Tensor:
+    if left_pad_size == 0 and right_pad_size == 0:
+        return tensor
+    if torch.isnan(torch.tensor(fill_value)) and not torch.is_floating_point(tensor):
+        tensor = tensor.to(torch.float32)
+    return F.pad(tensor, (left_pad_size, right_pad_size), value=fill_value)
+
 
 class BasicPipeline(ABC):
     def __init__(self, model_info: ModelInfo, **model_kwargs):
@@ -95,6 +133,12 @@ class ForecastPipeline(BasicPipeline):
         if isinstance(inputs, list):
             output_length = infer_kwargs.get("output_length", 96)
             auto_adapt = infer_kwargs.get("auto_adapt", True)
+            if auto_adapt is None:
+                auto_adapt = True
+            auto_adapt_fill_value = _parse_auto_adapt_fill_value(infer_kwargs)
+            auto_adapt_fill_value_name = (
+                "NaN" if torch.isnan(torch.tensor(auto_adapt_fill_value)) else "0"
+            )
             for idx, input_dict in enumerate(inputs):
                 # Check if the dictionary contains the expected keys
                 if not isinstance(input_dict, dict):
@@ -150,11 +194,11 @@ class ForecastPipeline(BasicPipeline):
                                 past_covariates[cov_key] = cov_value[-input_length:]
                             else:
                                 logger.warning(
-                                    f"Past covariate {cov_key} at index {idx} has length {cov_value.shape[0]} (< {input_length}), which will be padded with zeros at the beginning."
+                                    f"Past covariate {cov_key} at index {idx} has length {cov_value.shape[0]} (< {input_length}), which will be padded with {auto_adapt_fill_value_name} at the beginning."
                                 )
                                 pad_size = input_length - cov_value.shape[0]
-                                past_covariates[cov_key] = F.pad(
-                                    cov_value, (pad_size, 0)
+                                past_covariates[cov_key] = _pad_1d_tensor(
+                                    cov_value, pad_size, 0, auto_adapt_fill_value
                                 )
                         else:
                             raise ValueError(
@@ -205,11 +249,11 @@ class ForecastPipeline(BasicPipeline):
                                     ]
                                 else:
                                     logger.warning(
-                                        f"Future covariate {cov_key} at index {idx} has length {cov_value.shape[0]} (< {output_length}), which will be padded with zeros at the end."
+                                        f"Future covariate {cov_key} at index {idx} has length {cov_value.shape[0]} (< {output_length}), which will be padded with {auto_adapt_fill_value_name} at the end."
                                     )
                                     pad_size = output_length - cov_value.shape[0]
-                                    future_covariates[cov_key] = F.pad(
-                                        cov_value, (0, pad_size)
+                                    future_covariates[cov_key] = _pad_1d_tensor(
+                                        cov_value, 0, pad_size, auto_adapt_fill_value
                                     )
                             else:
                                 raise ValueError(
diff --git a/iotdb-core/ainode/test/test_forecast_auto_adapt.py b/iotdb-core/ainode/test/test_forecast_auto_adapt.py
new file mode 100644
index 00000000000..7a0a1d04204
--- /dev/null
+++ b/iotdb-core/ainode/test/test_forecast_auto_adapt.py
@@ -0,0 +1,135 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import sys
+import types
+
+import torch
+
+model_loader_stub = types.ModuleType("iotdb.ainode.core.model.model_loader")
+model_loader_stub.load_model = lambda *args, **kwargs: None
+sys.modules["iotdb.ainode.core.model.model_loader"] = model_loader_stub
+
+from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline
+
+
+class _NoopForecastPipeline(ForecastPipeline):
+    def __init__(self):
+        pass
+
+    def forecast(self, inputs, **infer_kwargs):
+        return inputs
+
+
+def test_auto_adapt_padding_defaults_to_zero():
+    inputs = [
+        {
+            "targets": torch.tensor([[1.0, 2.0, 3.0]]),
+            "past_covariates": {"cov": torch.tensor([2.0])},
+            "future_covariates": {"cov": torch.tensor([4.0])},
+        }
+    ]
+
+    processed = _NoopForecastPipeline().preprocess(
+        inputs, output_length=3, auto_adapt=True
+    )
+
+    assert torch.equal(
+        processed[0]["past_covariates"]["cov"], torch.tensor([0.0, 0.0, 2.0])
+    )
+    assert torch.equal(
+        processed[0]["future_covariates"]["cov"], torch.tensor([4.0, 0.0, 0.0])
+    )
+
+
+def test_auto_adapt_padding_can_use_nan():
+    inputs = [
+        {
+            "targets": torch.tensor([[1.0, 2.0, 3.0]]),
+            "past_covariates": {"cov": torch.tensor([2])},
+            "future_covariates": {"cov": torch.tensor([4])},
+        }
+    ]
+
+    processed = _NoopForecastPipeline().preprocess(
+        inputs,
+        output_length=3,
+        auto_adapt=True,
+        auto_adapt_fill_value="NaN",
+    )
+
+    assert torch.isnan(processed[0]["past_covariates"]["cov"][:2]).all()
+    assert processed[0]["past_covariates"]["cov"][2].item() == 2
+    assert processed[0]["past_covariates"]["cov"].dtype == torch.float32
+    assert processed[0]["future_covariates"]["cov"][0].item() == 4
+    assert torch.isnan(processed[0]["future_covariates"]["cov"][1:]).all()
+    assert processed[0]["future_covariates"]["cov"].dtype == torch.float32
+
+
+def test_auto_adapt_fill_value_rejects_invalid_value():
+    inputs = [{"targets": torch.tensor([[1.0, 2.0, 3.0]])}]
+
+    try:
+        _NoopForecastPipeline().preprocess(
+            inputs,
+            output_length=3,
+            auto_adapt=True,
+            auto_adapt_fill_value="1",
+        )
+        assert False
+    except ValueError as e:
+        assert "Unsupported auto_adapt_fill_value" in str(e)
+
+
+def test_iotdb_data_fetcher_converts_covariates_to_float_tensors():
+    session_stub = types.ModuleType("iotdb.Session")
+    session_stub.Session = object
+    sys.modules["iotdb.Session"] = session_stub
+
+    table_session_stub = types.ModuleType("iotdb.table_session")
+    table_session_stub.TableSession = object
+    table_session_stub.TableSessionConfig = object
+    sys.modules["iotdb.table_session"] = table_session_stub
+
+    field_stub = types.ModuleType("iotdb.utils.Field")
+    field_stub.Field = object
+    sys.modules["iotdb.utils.Field"] = field_stub
+
+    constants_stub = types.ModuleType("iotdb.utils.IoTDBConstants")
+    constants_stub.TSDataType = types.SimpleNamespace(
+        INT32=1, INT64=2, FLOAT=3, DOUBLE=4, TIMESTAMP=5, TEXT=6
+    )
+    sys.modules["iotdb.utils.IoTDBConstants"] = constants_stub
+
+    row_record_stub = types.ModuleType("iotdb.utils.RowRecord")
+    row_record_stub.RowRecord = object
+    sys.modules["iotdb.utils.RowRecord"] = row_record_stub
+
+    from timecho.ainode.core.ingress.data_fetcher import IoTDBDataFetcher
+
+    series_map = {("__DEFAULT_TAG__",): {"cov": [1, 2, 3]}}
+    timestamps_map = {("__DEFAULT_TAG__",): [3, 1, 2]}
+
+    sorted_series_map, sorted_timestamps_map = IoTDBDataFetcher._sort_data_by_timestamp(
+        None, series_map, timestamps_map
+    )
+
+    covariates = sorted_series_map[("__DEFAULT_TAG__",)]["cov"]
+    assert covariates.dtype == torch.float32
+    assert torch.equal(covariates, torch.tensor([2.0, 3.0, 1.0]))
+    assert sorted_timestamps_map[("__DEFAULT_TAG__",)] == [1, 2, 3]
diff --git a/iotdb-core/ainode/timecho/ainode/core/ingress/data_fetcher.py b/iotdb-core/ainode/timecho/ainode/core/ingress/data_fetcher.py
index b615161f536..8c768407574 100644
--- a/iotdb-core/ainode/timecho/ainode/core/ingress/data_fetcher.py
+++ b/iotdb-core/ainode/timecho/ainode/core/ingress/data_fetcher.py
@@ -118,6 +118,8 @@ class IoTDBDataFetcher:
                     value_col_field: Field = cur_data.get_fields()[value_col]
                     col_name = column_names[value_col]
                     value = get_field_value(value_col_field)
+                    if value is None:
+                        value = torch.nan
 
                     if col_name not in series_map[tag_values]:
                         series_map[tag_values][col_name] = []
@@ -146,5 +148,5 @@ class IoTDBDataFetcher:
             timestamps_map[tag_values] = [timestamps[i] for i in sorted_indices]
             for col_name, values_list in cov_map.items():
                 sorted_values = [values_list[i] for i in sorted_indices]
-                cov_map[col_name] = torch.tensor(sorted_values)
+                cov_map[col_name] = torch.tensor(sorted_values, dtype=torch.float32)
         return series_map, timestamps_map
diff --git a/iotdb-core/datanode/src/main/java/com/timecho/iotdb/db/queryengine/plan/relational/function/tvf/TimechoForecastTableFunction.java b/iotdb-core/datanode/src/main/java/com/timecho/iotdb/db/queryengine/plan/relational/function/tvf/TimechoForecastTableFunction.java
index fd622630c06..18a10d69fe4 100644
--- a/iotdb-core/datanode/src/main/java/com/timecho/iotdb/db/queryengine/plan/relational/function/tvf/TimechoForecastTableFunction.java
+++ b/iotdb-core/datanode/src/main/java/com/timecho/iotdb/db/queryengine/plan/relational/function/tvf/TimechoForecastTableFunction.java
@@ -54,6 +54,7 @@ import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -194,6 +195,10 @@ public class TimechoForecastTableFunction extends ForecastTableFunction {
   private static final String DEFAULT_FUTURE_COVS = "";
   private static final String AUTO_ADAPT_PARAMETER_NAME = "AUTO_ADAPT";
   private static final Boolean DEFAULT_AUTO_ADAPT = Boolean.TRUE;
+  private static final String AUTO_ADAPT_FILL_VALUE_PARAMETER_NAME = "AUTO_ADAPT_FILL_VALUE";
+  private static final String AUTO_ADAPT_FILL_VALUE_OPTION_KEY = "auto_adapt_fill_value";
+  private static final String DEFAULT_AUTO_ADAPT_FILL_VALUE = "0";
+  private static final String NAN_AUTO_ADAPT_FILL_VALUE = "NaN";
 
   @Override
   public List<ParameterSpecification> getArgumentsSpecifications() {
@@ -243,6 +248,11 @@ public class TimechoForecastTableFunction extends ForecastTableFunction {
             .type(Type.BOOLEAN)
             .defaultValue(DEFAULT_AUTO_ADAPT)
             .build(),
+        ScalarParameterSpecification.builder()
+            .name(AUTO_ADAPT_FILL_VALUE_PARAMETER_NAME)
+            .type(Type.STRING)
+            .defaultValue(DEFAULT_AUTO_ADAPT_FILL_VALUE)
+            .build(),
         ScalarParameterSpecification.builder()
             .name(OPTIONS_PARAMETER_NAME)
             .type(Type.STRING)
@@ -327,7 +337,11 @@ public class TimechoForecastTableFunction extends ForecastTableFunction {
         (String) ((ScalarArgument) arguments.get(HISTORY_COVS_PARAMETER_NAME)).getValue();
     String futureCovs =
         (String) ((ScalarArgument) arguments.get(FUTURE_COVS_PARAMETER_NAME)).getValue();
+    String autoAdaptFillValue =
+        (String) ((ScalarArgument) arguments.get(AUTO_ADAPT_FILL_VALUE_PARAMETER_NAME)).getValue();
     String options = (String) ((ScalarArgument) arguments.get(OPTIONS_PARAMETER_NAME)).getValue();
+    Map<String, String> parsedOptions =
+        parseOptionsWithAutoAdaptFillValue(options, autoAdaptFillValue);
 
     ForecastTableFunctionHandle functionHandle =
         new TimechoForecastTableFunctionHandle(
@@ -335,7 +349,7 @@ public class TimechoForecastTableFunction extends ForecastTableFunction {
             keepInput,
             MAX_INPUT_LENGTH,
             modelId,
-            parseOptions(options),
+            parsedOptions,
             historyCovs,
             futureCovs,
             outputLength,
@@ -356,6 +370,38 @@ public class TimechoForecastTableFunction extends ForecastTableFunction {
     return new TimechoForecastTableFunctionHandle();
   }
 
+  private static Map<String, String> parseOptionsWithAutoAdaptFillValue(
+      String options, String autoAdaptFillValue) {
+    Map<String, String> parsedOptions = new HashMap<>(parseOptions(options));
+    if (parsedOptions.containsKey(AUTO_ADAPT_FILL_VALUE_OPTION_KEY)) {
+      parsedOptions.put(
+          AUTO_ADAPT_FILL_VALUE_OPTION_KEY,
+          normalizeAutoAdaptFillValue(parsedOptions.get(AUTO_ADAPT_FILL_VALUE_OPTION_KEY)));
+    }
+    String normalizedAutoAdaptFillValue = normalizeAutoAdaptFillValue(autoAdaptFillValue);
+    if (!DEFAULT_AUTO_ADAPT_FILL_VALUE.equals(normalizedAutoAdaptFillValue)) {
+      parsedOptions.put(AUTO_ADAPT_FILL_VALUE_OPTION_KEY, normalizedAutoAdaptFillValue);
+    }
+    return parsedOptions;
+  }
+
+  private static String normalizeAutoAdaptFillValue(String fillValue) {
+    if (fillValue == null) {
+      return DEFAULT_AUTO_ADAPT_FILL_VALUE;
+    }
+    String normalizedFillValue = fillValue.trim().toLowerCase(Locale.ENGLISH);
+    if (DEFAULT_AUTO_ADAPT_FILL_VALUE.equals(normalizedFillValue)
+        || "0.0".equals(normalizedFillValue)
+        || "zero".equals(normalizedFillValue)) {
+      return DEFAULT_AUTO_ADAPT_FILL_VALUE;
+    }
+    if ("nan".equals(normalizedFillValue)) {
+      return NAN_AUTO_ADAPT_FILL_VALUE;
+    }
+    throw new SemanticException(
+        String.format("%s should be either '0' or 'NaN'", AUTO_ADAPT_FILL_VALUE_PARAMETER_NAME));
+  }
+
   @Override
   public TableFunctionProcessorProvider getProcessorProvider(
       TableFunctionHandle tableFunctionHandle) {
diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java
index 9581083ba71..bb94508906f 100644
--- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java
+++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java
@@ -475,6 +475,52 @@ public class TableFunctionTest {
                 group(ImmutableList.of(sort("time_0", ASCENDING, FIRST)), 0, tableScan))));
   }
 
+  @Test
+  public void testForecastFunctionWithAutoAdaptFillValue() {
+    PlanTester planTester = new PlanTester();
+
+    String sql =
+        "SELECT * FROM FORECAST("
+            + "targets => (SELECT time,s3 FROM table1 WHERE tag1='shanghai' AND tag2='A3' AND tag3='YY' ORDER BY time DESC LIMIT 1440), "
+            + "model_id => 'timer_xl', auto_adapt_fill_value => 'NaN')";
+    LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
+
+    PlanMatchPattern tableScan =
+        tableScan("testdb.table1", ImmutableMap.of("time_0", "time", "s3_1", "s3"));
+    Consumer<TableFunctionProcessorMatcher.Builder> tableFunctionMatcher =
+        builder ->
+            builder
+                .name("forecast")
+                .properOutputs("time", "s3")
+                .requiredSymbols("time_0", "s3_1")
+                .handle(
+                    new TimechoForecastTableFunction.TimechoForecastTableFunctionHandle(
+                        true,
+                        false,
+                        2880,
+                        "timer_xl",
+                        ImmutableMap.of("auto_adapt_fill_value", "NaN"),
+                        "",
+                        "",
+                        96,
+                        DEFAULT_OUTPUT_START_TIME,
+                        DEFAULT_OUTPUT_INTERVAL,
+                        Collections.singletonList(DOUBLE)));
+
+    assertPlan(
+        logicalQueryPlan,
+        anyTree(
+            tableFunctionProcessor(
+                tableFunctionMatcher,
+                sort(
+                    ImmutableList.of(sort("time_0", ASCENDING, FIRST)),
+                    topK(
+                        1440,
+                        ImmutableList.of(sort("time_0", DESCENDING, LAST)),
+                        false,
+                        tableScan)))));
+  }
+
   @Test
   public void testForecastFunctionAbnormal() {
     // default order by time asc
@@ -488,5 +534,16 @@ public class TableFunctionTest {
     } catch (SemanticException e) {
       assertEquals("TIMECOL should never be null or empty.", e.getMessage());
     }
+
+    sql =
+        "SELECT * FROM FORECAST("
+            + "targets => (SELECT time,s3 FROM table1 WHERE tag1='shanghai' AND tag2='A3' AND tag3='YY' ORDER BY time DESC LIMIT 1440), "
+            + "model_id => 'timer_xl', auto_adapt_fill_value => '1')";
+    try {
+      analyzeSQL(sql, TEST_MATADATA, QUERY_CONTEXT);
+      fail();
+    } catch (SemanticException e) {
+      assertEquals("AUTO_ADAPT_FILL_VALUE should be either '0' or 'NaN'", e.getMessage());
+    }
   }
 }

Reply via email to