This is an automated email from the ASF dual-hosted git repository. Caideyipi pushed a commit to branch patch-2094 in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit a50eea9cbe5493b3202d421b32381ca73e55dc99 Author: 陈荣钊 <[email protected]> AuthorDate: Sat Jun 6 09:48:42 2026 +0800 [TIMECHODB][TIMECHODB][AINode] Support NaN auto-adapt fill for AINode forecast --- .../core/inference/pipeline/basic_pipeline.py | 56 ++++++++- iotdb-core/ainode/test/test_forecast_auto_adapt.py | 135 +++++++++++++++++++++ .../timecho/ainode/core/ingress/data_fetcher.py | 4 +- .../function/tvf/TimechoForecastTableFunction.java | 48 +++++++- .../relational/analyzer/TableFunctionTest.java | 57 +++++++++ 5 files changed, 292 insertions(+), 8 deletions(-) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py index f4bf914a846..4c9f18644d0 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py @@ -30,6 +30,44 @@ from iotdb.ainode.core.model.model_loader import load_model BACKEND = DeviceManager() logger = Logger() +AUTO_ADAPT_FILL_VALUE_KEY = "auto_adapt_fill_value" +_AUTO_ADAPT_ZERO_VALUES = {"0", "0.0", "zero"} +_AUTO_ADAPT_NAN_VALUES = {"nan"} + + +def _parse_auto_adapt_fill_value(infer_kwargs: dict) -> float: + raw_fill_value = infer_kwargs.get(AUTO_ADAPT_FILL_VALUE_KEY) + if raw_fill_value is None: + raw_fill_value = next( + ( + value + for key, value in infer_kwargs.items() + if str(key).lower() == AUTO_ADAPT_FILL_VALUE_KEY + ), + "0", + ) + if raw_fill_value is None: + return 0.0 + normalized_fill_value = str(raw_fill_value).strip().lower() + if normalized_fill_value in _AUTO_ADAPT_ZERO_VALUES: + return 0.0 + if normalized_fill_value in _AUTO_ADAPT_NAN_VALUES: + return torch.nan + raise ValueError( + f"Unsupported {AUTO_ADAPT_FILL_VALUE_KEY}: {raw_fill_value}. " + "Expected one of ['0', 'NaN']." + ) + + +def _pad_1d_tensor( + tensor: torch.Tensor, left_pad_size: int, right_pad_size: int, fill_value: float +) -> torch.Tensor: + if left_pad_size == 0 and right_pad_size == 0: + return tensor + if torch.isnan(torch.tensor(fill_value)) and not torch.is_floating_point(tensor): + tensor = tensor.to(torch.float32) + return F.pad(tensor, (left_pad_size, right_pad_size), value=fill_value) + class BasicPipeline(ABC): def __init__(self, model_info: ModelInfo, **model_kwargs): @@ -95,6 +133,12 @@ class ForecastPipeline(BasicPipeline): if isinstance(inputs, list): output_length = infer_kwargs.get("output_length", 96) auto_adapt = infer_kwargs.get("auto_adapt", True) + if auto_adapt is None: + auto_adapt = True + auto_adapt_fill_value = _parse_auto_adapt_fill_value(infer_kwargs) + auto_adapt_fill_value_name = ( + "NaN" if torch.isnan(torch.tensor(auto_adapt_fill_value)) else "0" + ) for idx, input_dict in enumerate(inputs): # Check if the dictionary contains the expected keys if not isinstance(input_dict, dict): @@ -150,11 +194,11 @@ class ForecastPipeline(BasicPipeline): past_covariates[cov_key] = cov_value[-input_length:] else: logger.warning( - f"Past covariate {cov_key} at index {idx} has length {cov_value.shape[0]} (< {input_length}), which will be padded with zeros at the beginning." + f"Past covariate {cov_key} at index {idx} has length {cov_value.shape[0]} (< {input_length}), which will be padded with {auto_adapt_fill_value_name} at the beginning." ) pad_size = input_length - cov_value.shape[0] - past_covariates[cov_key] = F.pad( - cov_value, (pad_size, 0) + past_covariates[cov_key] = _pad_1d_tensor( + cov_value, pad_size, 0, auto_adapt_fill_value ) else: raise ValueError( @@ -205,11 +249,11 @@ class ForecastPipeline(BasicPipeline): ] else: logger.warning( - f"Future covariate {cov_key} at index {idx} has length {cov_value.shape[0]} (< {output_length}), which will be padded with zeros at the end." + f"Future covariate {cov_key} at index {idx} has length {cov_value.shape[0]} (< {output_length}), which will be padded with {auto_adapt_fill_value_name} at the end." ) pad_size = output_length - cov_value.shape[0] - future_covariates[cov_key] = F.pad( - cov_value, (0, pad_size) + future_covariates[cov_key] = _pad_1d_tensor( + cov_value, 0, pad_size, auto_adapt_fill_value ) else: raise ValueError( diff --git a/iotdb-core/ainode/test/test_forecast_auto_adapt.py b/iotdb-core/ainode/test/test_forecast_auto_adapt.py new file mode 100644 index 00000000000..7a0a1d04204 --- /dev/null +++ b/iotdb-core/ainode/test/test_forecast_auto_adapt.py @@ -0,0 +1,135 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys +import types + +import torch + +model_loader_stub = types.ModuleType("iotdb.ainode.core.model.model_loader") +model_loader_stub.load_model = lambda *args, **kwargs: None +sys.modules["iotdb.ainode.core.model.model_loader"] = model_loader_stub + +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline + + +class _NoopForecastPipeline(ForecastPipeline): + def __init__(self): + pass + + def forecast(self, inputs, **infer_kwargs): + return inputs + + +def test_auto_adapt_padding_defaults_to_zero(): + inputs = [ + { + "targets": torch.tensor([[1.0, 2.0, 3.0]]), + "past_covariates": {"cov": torch.tensor([2.0])}, + "future_covariates": {"cov": torch.tensor([4.0])}, + } + ] + + processed = _NoopForecastPipeline().preprocess( + inputs, output_length=3, auto_adapt=True + ) + + assert torch.equal( + processed[0]["past_covariates"]["cov"], torch.tensor([0.0, 0.0, 2.0]) + ) + assert torch.equal( + processed[0]["future_covariates"]["cov"], torch.tensor([4.0, 0.0, 0.0]) + ) + + +def test_auto_adapt_padding_can_use_nan(): + inputs = [ + { + "targets": torch.tensor([[1.0, 2.0, 3.0]]), + "past_covariates": {"cov": torch.tensor([2])}, + "future_covariates": {"cov": torch.tensor([4])}, + } + ] + + processed = _NoopForecastPipeline().preprocess( + inputs, + output_length=3, + auto_adapt=True, + auto_adapt_fill_value="NaN", + ) + + assert torch.isnan(processed[0]["past_covariates"]["cov"][:2]).all() + assert processed[0]["past_covariates"]["cov"][2].item() == 2 + assert processed[0]["past_covariates"]["cov"].dtype == torch.float32 + assert processed[0]["future_covariates"]["cov"][0].item() == 4 + assert torch.isnan(processed[0]["future_covariates"]["cov"][1:]).all() + assert processed[0]["future_covariates"]["cov"].dtype == torch.float32 + + +def test_auto_adapt_fill_value_rejects_invalid_value(): + inputs = [{"targets": torch.tensor([[1.0, 2.0, 3.0]])}] + + try: + _NoopForecastPipeline().preprocess( + inputs, + output_length=3, + auto_adapt=True, + auto_adapt_fill_value="1", + ) + assert False + except ValueError as e: + assert "Unsupported auto_adapt_fill_value" in str(e) + + +def test_iotdb_data_fetcher_converts_covariates_to_float_tensors(): + session_stub = types.ModuleType("iotdb.Session") + session_stub.Session = object + sys.modules["iotdb.Session"] = session_stub + + table_session_stub = types.ModuleType("iotdb.table_session") + table_session_stub.TableSession = object + table_session_stub.TableSessionConfig = object + sys.modules["iotdb.table_session"] = table_session_stub + + field_stub = types.ModuleType("iotdb.utils.Field") + field_stub.Field = object + sys.modules["iotdb.utils.Field"] = field_stub + + constants_stub = types.ModuleType("iotdb.utils.IoTDBConstants") + constants_stub.TSDataType = types.SimpleNamespace( + INT32=1, INT64=2, FLOAT=3, DOUBLE=4, TIMESTAMP=5, TEXT=6 + ) + sys.modules["iotdb.utils.IoTDBConstants"] = constants_stub + + row_record_stub = types.ModuleType("iotdb.utils.RowRecord") + row_record_stub.RowRecord = object + sys.modules["iotdb.utils.RowRecord"] = row_record_stub + + from timecho.ainode.core.ingress.data_fetcher import IoTDBDataFetcher + + series_map = {("__DEFAULT_TAG__",): {"cov": [1, 2, 3]}} + timestamps_map = {("__DEFAULT_TAG__",): [3, 1, 2]} + + sorted_series_map, sorted_timestamps_map = IoTDBDataFetcher._sort_data_by_timestamp( + None, series_map, timestamps_map + ) + + covariates = sorted_series_map[("__DEFAULT_TAG__",)]["cov"] + assert covariates.dtype == torch.float32 + assert torch.equal(covariates, torch.tensor([2.0, 3.0, 1.0])) + assert sorted_timestamps_map[("__DEFAULT_TAG__",)] == [1, 2, 3] diff --git a/iotdb-core/ainode/timecho/ainode/core/ingress/data_fetcher.py b/iotdb-core/ainode/timecho/ainode/core/ingress/data_fetcher.py index b615161f536..8c768407574 100644 --- a/iotdb-core/ainode/timecho/ainode/core/ingress/data_fetcher.py +++ b/iotdb-core/ainode/timecho/ainode/core/ingress/data_fetcher.py @@ -118,6 +118,8 @@ class IoTDBDataFetcher: value_col_field: Field = cur_data.get_fields()[value_col] col_name = column_names[value_col] value = get_field_value(value_col_field) + if value is None: + value = torch.nan if col_name not in series_map[tag_values]: series_map[tag_values][col_name] = [] @@ -146,5 +148,5 @@ class IoTDBDataFetcher: timestamps_map[tag_values] = [timestamps[i] for i in sorted_indices] for col_name, values_list in cov_map.items(): sorted_values = [values_list[i] for i in sorted_indices] - cov_map[col_name] = torch.tensor(sorted_values) + cov_map[col_name] = torch.tensor(sorted_values, dtype=torch.float32) return series_map, timestamps_map diff --git a/iotdb-core/datanode/src/main/java/com/timecho/iotdb/db/queryengine/plan/relational/function/tvf/TimechoForecastTableFunction.java b/iotdb-core/datanode/src/main/java/com/timecho/iotdb/db/queryengine/plan/relational/function/tvf/TimechoForecastTableFunction.java index fd622630c06..18a10d69fe4 100644 --- a/iotdb-core/datanode/src/main/java/com/timecho/iotdb/db/queryengine/plan/relational/function/tvf/TimechoForecastTableFunction.java +++ b/iotdb-core/datanode/src/main/java/com/timecho/iotdb/db/queryengine/plan/relational/function/tvf/TimechoForecastTableFunction.java @@ -54,6 +54,7 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; @@ -194,6 +195,10 @@ public class TimechoForecastTableFunction extends ForecastTableFunction { private static final String DEFAULT_FUTURE_COVS = ""; private static final String AUTO_ADAPT_PARAMETER_NAME = "AUTO_ADAPT"; private static final Boolean DEFAULT_AUTO_ADAPT = Boolean.TRUE; + private static final String AUTO_ADAPT_FILL_VALUE_PARAMETER_NAME = "AUTO_ADAPT_FILL_VALUE"; + private static final String AUTO_ADAPT_FILL_VALUE_OPTION_KEY = "auto_adapt_fill_value"; + private static final String DEFAULT_AUTO_ADAPT_FILL_VALUE = "0"; + private static final String NAN_AUTO_ADAPT_FILL_VALUE = "NaN"; @Override public List<ParameterSpecification> getArgumentsSpecifications() { @@ -243,6 +248,11 @@ public class TimechoForecastTableFunction extends ForecastTableFunction { .type(Type.BOOLEAN) .defaultValue(DEFAULT_AUTO_ADAPT) .build(), + ScalarParameterSpecification.builder() + .name(AUTO_ADAPT_FILL_VALUE_PARAMETER_NAME) + .type(Type.STRING) + .defaultValue(DEFAULT_AUTO_ADAPT_FILL_VALUE) + .build(), ScalarParameterSpecification.builder() .name(OPTIONS_PARAMETER_NAME) .type(Type.STRING) @@ -327,7 +337,11 @@ public class TimechoForecastTableFunction extends ForecastTableFunction { (String) ((ScalarArgument) arguments.get(HISTORY_COVS_PARAMETER_NAME)).getValue(); String futureCovs = (String) ((ScalarArgument) arguments.get(FUTURE_COVS_PARAMETER_NAME)).getValue(); + String autoAdaptFillValue = + (String) ((ScalarArgument) arguments.get(AUTO_ADAPT_FILL_VALUE_PARAMETER_NAME)).getValue(); String options = (String) ((ScalarArgument) arguments.get(OPTIONS_PARAMETER_NAME)).getValue(); + Map<String, String> parsedOptions = + parseOptionsWithAutoAdaptFillValue(options, autoAdaptFillValue); ForecastTableFunctionHandle functionHandle = new TimechoForecastTableFunctionHandle( @@ -335,7 +349,7 @@ public class TimechoForecastTableFunction extends ForecastTableFunction { keepInput, MAX_INPUT_LENGTH, modelId, - parseOptions(options), + parsedOptions, historyCovs, futureCovs, outputLength, @@ -356,6 +370,38 @@ public class TimechoForecastTableFunction extends ForecastTableFunction { return new TimechoForecastTableFunctionHandle(); } + private static Map<String, String> parseOptionsWithAutoAdaptFillValue( + String options, String autoAdaptFillValue) { + Map<String, String> parsedOptions = new HashMap<>(parseOptions(options)); + if (parsedOptions.containsKey(AUTO_ADAPT_FILL_VALUE_OPTION_KEY)) { + parsedOptions.put( + AUTO_ADAPT_FILL_VALUE_OPTION_KEY, + normalizeAutoAdaptFillValue(parsedOptions.get(AUTO_ADAPT_FILL_VALUE_OPTION_KEY))); + } + String normalizedAutoAdaptFillValue = normalizeAutoAdaptFillValue(autoAdaptFillValue); + if (!DEFAULT_AUTO_ADAPT_FILL_VALUE.equals(normalizedAutoAdaptFillValue)) { + parsedOptions.put(AUTO_ADAPT_FILL_VALUE_OPTION_KEY, normalizedAutoAdaptFillValue); + } + return parsedOptions; + } + + private static String normalizeAutoAdaptFillValue(String fillValue) { + if (fillValue == null) { + return DEFAULT_AUTO_ADAPT_FILL_VALUE; + } + String normalizedFillValue = fillValue.trim().toLowerCase(Locale.ENGLISH); + if (DEFAULT_AUTO_ADAPT_FILL_VALUE.equals(normalizedFillValue) + || "0.0".equals(normalizedFillValue) + || "zero".equals(normalizedFillValue)) { + return DEFAULT_AUTO_ADAPT_FILL_VALUE; + } + if ("nan".equals(normalizedFillValue)) { + return NAN_AUTO_ADAPT_FILL_VALUE; + } + throw new SemanticException( + String.format("%s should be either '0' or 'NaN'", AUTO_ADAPT_FILL_VALUE_PARAMETER_NAME)); + } + @Override public TableFunctionProcessorProvider getProcessorProvider( TableFunctionHandle tableFunctionHandle) { diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java index 9581083ba71..bb94508906f 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java @@ -475,6 +475,52 @@ public class TableFunctionTest { group(ImmutableList.of(sort("time_0", ASCENDING, FIRST)), 0, tableScan)))); } + @Test + public void testForecastFunctionWithAutoAdaptFillValue() { + PlanTester planTester = new PlanTester(); + + String sql = + "SELECT * FROM FORECAST(" + + "targets => (SELECT time,s3 FROM table1 WHERE tag1='shanghai' AND tag2='A3' AND tag3='YY' ORDER BY time DESC LIMIT 1440), " + + "model_id => 'timer_xl', auto_adapt_fill_value => 'NaN')"; + LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql); + + PlanMatchPattern tableScan = + tableScan("testdb.table1", ImmutableMap.of("time_0", "time", "s3_1", "s3")); + Consumer<TableFunctionProcessorMatcher.Builder> tableFunctionMatcher = + builder -> + builder + .name("forecast") + .properOutputs("time", "s3") + .requiredSymbols("time_0", "s3_1") + .handle( + new TimechoForecastTableFunction.TimechoForecastTableFunctionHandle( + true, + false, + 2880, + "timer_xl", + ImmutableMap.of("auto_adapt_fill_value", "NaN"), + "", + "", + 96, + DEFAULT_OUTPUT_START_TIME, + DEFAULT_OUTPUT_INTERVAL, + Collections.singletonList(DOUBLE))); + + assertPlan( + logicalQueryPlan, + anyTree( + tableFunctionProcessor( + tableFunctionMatcher, + sort( + ImmutableList.of(sort("time_0", ASCENDING, FIRST)), + topK( + 1440, + ImmutableList.of(sort("time_0", DESCENDING, LAST)), + false, + tableScan))))); + } + @Test public void testForecastFunctionAbnormal() { // default order by time asc @@ -488,5 +534,16 @@ public class TableFunctionTest { } catch (SemanticException e) { assertEquals("TIMECOL should never be null or empty.", e.getMessage()); } + + sql = + "SELECT * FROM FORECAST(" + + "targets => (SELECT time,s3 FROM table1 WHERE tag1='shanghai' AND tag2='A3' AND tag3='YY' ORDER BY time DESC LIMIT 1440), " + + "model_id => 'timer_xl', auto_adapt_fill_value => '1')"; + try { + analyzeSQL(sql, TEST_MATADATA, QUERY_CONTEXT); + fail(); + } catch (SemanticException e) { + assertEquals("AUTO_ADAPT_FILL_VALUE should be either '0' or 'NaN'", e.getMessage()); + } } }
