This is an automated email from the ASF dual-hosted git repository. yongzao pushed a commit to branch cp-ain-to-206 in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 4da262eca7c52a137a2dd47445d8ee104cfa61d3 Author: Leo <[email protected]> AuthorDate: Wed Nov 19 09:23:19 2025 +0800 [AINode] Fix bug of sundial and forecast udf (#16768) (cherry picked from commit 2b47be756ad8703ce3673973260983f10c4f94e3) --- .../ainode/core/model/sundial/modeling_sundial.py | 7 +- .../ainode/core/model/timerxl/modeling_timer.py | 6 +- .../db/queryengine/plan/udf/UDTFForecast.java | 270 +++++++++++++++++++++ 3 files changed, 279 insertions(+), 4 deletions(-) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py index f66b4b13f22..6a2402952db 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py @@ -645,9 +645,10 @@ class SundialForPrediction(SundialPreTrainedModel, TSGenerationMixin): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[ - :, -(input_ids.shape[1] // self.config.input_token_len) : - ] + token_num = ( + input_ids.shape[1] + self.config.input_token_len - 1 + ) // self.config.input_token_len + position_ids = position_ids[:, -token_num:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py index 4aed1696af7..8dfddd335b7 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py @@ -613,7 +613,11 @@ class TimerForPrediction(TimerPreTrainedModel, TSGenerationMixin): if attention_mask is not None and attention_mask.shape[1] > ( input_ids.shape[1] // self.config.input_token_len ): - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + input_ids = input_ids[ + :, + -(attention_mask.shape[1] - past_length) + * self.config.input_token_len :, + ] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < (input_ids.shape[1] // self.config.input_token_len): diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java new file mode 100644 index 00000000000..260410954d4 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.udf; + +import org.apache.iotdb.ainode.rpc.thrift.TForecastResp; +import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.exception.IoTDBRuntimeException; +import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; +import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; +import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; +import org.apache.iotdb.db.queryengine.plan.analyze.ModelFetcher; +import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; +import org.apache.iotdb.rpc.TSStatusCode; +import org.apache.iotdb.udf.api.UDTF; +import org.apache.iotdb.udf.api.access.Row; +import org.apache.iotdb.udf.api.collector.PointCollector; +import org.apache.iotdb.udf.api.customizer.config.UDTFConfigurations; +import org.apache.iotdb.udf.api.customizer.parameter.UDFParameters; +import org.apache.iotdb.udf.api.customizer.strategy.RowByRowAccessStrategy; +import org.apache.iotdb.udf.api.type.Type; + +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.read.common.block.TsBlock; +import org.apache.tsfile.read.common.block.TsBlockBuilder; +import org.apache.tsfile.read.common.block.column.TsBlockSerde; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +public class UDTFForecast implements UDTF { + private static final TsBlockSerde serde = new TsBlockSerde(); + private static final AINodeClientManager CLIENT_MANAGER = AINodeClientManager.getInstance(); + private TEndPoint targetAINode = new TEndPoint("127.0.0.1", 10810); + private String model_id; + private int maxInputLength; + private int outputLength; + private long outputStartTime; + private long outputInterval; + private boolean keepInput; + Map<String, String> options; + List<Type> types; + private LinkedList<Row> inputRows; + private TsBlockBuilder inputTsBlockBuilder; + private final IModelFetcher modelFetcher = ModelFetcher.getInstance(); + + private static final Set<Type> ALLOWED_INPUT_TYPES = new HashSet<>(); + + static { + ALLOWED_INPUT_TYPES.add(Type.INT32); + ALLOWED_INPUT_TYPES.add(Type.INT64); + ALLOWED_INPUT_TYPES.add(Type.FLOAT); + ALLOWED_INPUT_TYPES.add(Type.DOUBLE); + } + + private static final String MODEL_ID_PARAMETER_NAME = "MODEL_ID"; + private static final String OUTPUT_LENGTH_PARAMETER_NAME = "OUTPUT_LENGTH"; + private static final int DEFAULT_OUTPUT_LENGTH = 96; + private static final String OUTPUT_START_TIME = "OUTPUT_START_TIME"; + public static final long DEFAULT_OUTPUT_START_TIME = Long.MIN_VALUE; + private static final String OUTPUT_INTERVAL = "OUTPUT_INTERVAL"; + public static final long DEFAULT_OUTPUT_INTERVAL = 0L; + private static final String KEEP_INPUT_PARAMETER_NAME = "PRESERVE_INPUT"; + private static final Boolean DEFAULT_KEEP_INPUT = Boolean.FALSE; + private static final String OPTIONS_PARAMETER_NAME = "MODEL_OPTIONS"; + private static final String DEFAULT_OPTIONS = ""; + + private void checkType() { + for (Type type : this.types) { + if (!ALLOWED_INPUT_TYPES.contains(type)) { + throw new IllegalArgumentException( + String.format( + "Input data type %s is not supported, only %s are allowed.", + type, ALLOWED_INPUT_TYPES)); + } + } + } + + @Override + public void beforeStart(UDFParameters parameters, UDTFConfigurations configurations) + throws Exception { + this.types = parameters.getDataTypes(); + checkType(); + configurations.setAccessStrategy(new RowByRowAccessStrategy()).setOutputDataType(Type.DOUBLE); + + this.model_id = parameters.getString(MODEL_ID_PARAMETER_NAME); + if (this.model_id == null || this.model_id.isEmpty()) { + throw new IllegalArgumentException( + "MODEL_ID parameter must be provided and cannot be empty."); + } + ModelInferenceDescriptor descriptor = modelFetcher.fetchModel(this.model_id); + this.targetAINode = descriptor.getTargetAINode(); + + this.outputInterval = parameters.getLongOrDefault(OUTPUT_INTERVAL, DEFAULT_OUTPUT_INTERVAL); + this.outputLength = + parameters.getIntOrDefault(OUTPUT_LENGTH_PARAMETER_NAME, DEFAULT_OUTPUT_LENGTH); + this.outputStartTime = + parameters.getLongOrDefault(OUTPUT_START_TIME, DEFAULT_OUTPUT_START_TIME); + this.keepInput = parameters.getBooleanOrDefault(KEEP_INPUT_PARAMETER_NAME, DEFAULT_KEEP_INPUT); + this.options = + Arrays.stream( + parameters.getStringOrDefault(OPTIONS_PARAMETER_NAME, DEFAULT_OPTIONS).split(",")) + .map(s -> s.split("=")) + .filter(arr -> arr.length == 2 && !arr[0].isEmpty()) // 防御性检查 + .collect( + Collectors.toMap( + arr -> arr[0].trim(), arr -> arr[1].trim(), (v1, v2) -> v2 // 如果 key 重复,保留后一个 + )); + this.inputRows = new LinkedList<>(); + List<TSDataType> tsDataTypeList = new ArrayList<>(this.types.size() - 1); + for (int i = 0; i < this.types.size(); i++) { + tsDataTypeList.add(TSDataType.DOUBLE); + } + this.inputTsBlockBuilder = new TsBlockBuilder(tsDataTypeList); + } + + private void setByType(Row row, PointCollector collector) throws IOException { + for (int i = 0; i < row.size(); i++) { + switch (this.types.get(i)) { + case INT32: + collector.putInt(row.getTime(), row.getInt(i)); + break; + case INT64: + collector.putLong(row.getTime(), row.getLong(i)); + break; + case FLOAT: + collector.putFloat(row.getTime(), row.getFloat(i)); + break; + case DOUBLE: + collector.putDouble(row.getTime(), row.getDouble(i)); + break; + default: + throw new IllegalArgumentException( + String.format("Unsupported data type %s", this.types.get(i + 1))); + } + } + } + + private void setByType(Row row, TsBlockBuilder tsBlockBuilder) throws IOException { + for (int i = 0; i < row.size(); i++) { + if (row.isNull(i)) { + tsBlockBuilder.getColumnBuilder(i).appendNull(); + continue; + } + switch (this.types.get(i)) { + case INT32: + tsBlockBuilder.getColumnBuilder(i).writeInt(row.getInt(i)); + break; + case INT64: + tsBlockBuilder.getColumnBuilder(i).writeLong(row.getLong(i)); + break; + case FLOAT: + tsBlockBuilder.getColumnBuilder(i).writeFloat(row.getFloat(i)); + break; + case DOUBLE: + tsBlockBuilder.getColumnBuilder(i).writeDouble(row.getDouble(i)); + break; + default: + throw new IllegalArgumentException( + String.format("Unsupported data type %s", this.types.get(i + 1))); + } + } + } + + @Override + public void transform(Row row, PointCollector collector) throws Exception { + if (this.keepInput) { + setByType(row, collector); + } + + if (maxInputLength != 0 && inputRows.size() >= maxInputLength) { + // If the input rows exceed the maximum length, remove the oldest row + inputRows.removeFirst(); + } + inputRows.add(row); + } + + private TsBlock forecast() throws Exception { + // Build the input data which will be sent to AINode + while (!inputRows.isEmpty()) { + Row row = inputRows.removeFirst(); + inputTsBlockBuilder.getTimeColumnBuilder().writeLong(row.getTime()); + setByType(row, inputTsBlockBuilder); + inputTsBlockBuilder.declarePosition(); + } + + TsBlock inputData = inputTsBlockBuilder.build(); + + TForecastResp resp; + try (AINodeClient client = CLIENT_MANAGER.borrowClient(targetAINode)) { + resp = client.forecast(model_id, inputData, outputLength, options); + } catch (Exception e) { + throw new IoTDBRuntimeException( + e.getMessage(), TSStatusCode.CAN_NOT_CONNECT_AINODE.getStatusCode()); + } + + if (resp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + throw new IoTDBRuntimeException( + String.format( + "Forecast failed due to %d %s", + resp.getStatus().getCode(), resp.getStatus().getMessage()), + resp.getStatus().getCode()); + } + return serde.deserialize(ByteBuffer.wrap(resp.getForecastResult())); + } + + @Override + public void terminate(PointCollector collector) throws Exception { + long inputStartTime = inputRows.get(0).getTime(); + long inputEndTime = inputRows.get(inputRows.size() - 1).getTime(); + if (inputStartTime > inputEndTime) { + throw new IllegalArgumentException( + String.format( + "input end time should never less than start time, start time is %s, end time is %s", + inputStartTime, inputEndTime)); + } + long interval = this.outputInterval; + if (outputInterval <= 0) { + interval = (inputEndTime - inputStartTime) / (inputRows.size() - 1); + } + long outputTime = + (this.outputStartTime == Long.MIN_VALUE) ? inputEndTime + interval : this.outputStartTime; + long[] outputTimes = new long[this.outputLength]; + for (int i = 0; i < this.outputLength; i++) { + outputTimes[i] = outputTime + interval * i; + } + + TsBlock forecastResult = forecast(); + if (forecastResult.getPositionCount() != this.outputLength) { + throw new IllegalArgumentException( + String.format( + "The forecast result length %d does not match the expected output length %d", + forecastResult.getPositionCount(), this.outputLength)); + } + if (forecastResult.getValueColumnCount() != 1) { + throw new IllegalArgumentException( + String.format( + "The forecast result should have only one value column, but got %d", + forecastResult.getValueColumnCount())); + } + + for (int i = 0; i < forecastResult.getPositionCount(); i++) { + collector.putDouble(outputTimes[i], forecastResult.getValueColumns()[0].getDouble(i)); + } + } +}
