This is an automated email from the ASF dual-hosted git repository. yongzao pushed a commit to branch refactor-forecast in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 671b0c22608abfb50b6fa99b831e741a05b5c4ed Author: Yongzao <[email protected]> AuthorDate: Sat Nov 1 17:02:11 2025 +0800 finish --- iotdb-core/ainode/.gitignore | 1 - .../function/tvf/ForecastTableFunction.java | 88 ++++++++++++++-------- .../db/queryengine/plan/udf/UDTFForecast.java | 7 +- .../relational/analyzer/TableFunctionTest.java | 2 + 4 files changed, 60 insertions(+), 38 deletions(-) diff --git a/iotdb-core/ainode/.gitignore b/iotdb-core/ainode/.gitignore index 8cc2098c3fd..80221b44dc5 100644 --- a/iotdb-core/ainode/.gitignore +++ b/iotdb-core/ainode/.gitignore @@ -5,7 +5,6 @@ /iotdb/thrift/ /iotdb/tsfile/ /iotdb/utils/ -/iotdb/__init__.py /iotdb/Session.py /iotdb/SessionPool.py /iotdb/table_session.py diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java index 5521062a24e..af09dbf4584 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java @@ -84,7 +84,8 @@ public class ForecastTableFunction implements TableFunction { long outputInterval; boolean keepInput; Map<String, String> options; - List<Type> types; + List<Type> inputColumnTypes; + List<Type> predicatedColumnTypes; public ForecastTableFunctionHandle() {} @@ -97,7 +98,8 @@ public class ForecastTableFunction implements TableFunction { long outputStartTime, long outputInterval, TEndPoint targetAINode, - List<Type> types) { + List<Type> inputColumnTypes, + List<Type> predicatedColumnTypes) { this.keepInput = keepInput; this.maxInputLength = maxInputLength; this.modelId = modelId; @@ -106,7 +108,8 @@ public class ForecastTableFunction implements TableFunction { this.outputStartTime = outputStartTime; this.outputInterval = outputInterval; this.targetAINode = targetAINode; - this.types = types; + this.inputColumnTypes = inputColumnTypes; + this.predicatedColumnTypes = predicatedColumnTypes; } @Override @@ -122,8 +125,12 @@ public class ForecastTableFunction implements TableFunction { ReadWriteIOUtils.write(outputInterval, outputStream); ReadWriteIOUtils.write(keepInput, outputStream); ReadWriteIOUtils.write(options, outputStream); - ReadWriteIOUtils.write(types.size(), outputStream); - for (Type type : types) { + ReadWriteIOUtils.write(inputColumnTypes.size(), outputStream); + for (Type type : inputColumnTypes) { + ReadWriteIOUtils.write(type.getType(), outputStream); + } + ReadWriteIOUtils.write(predicatedColumnTypes.size(), outputStream); + for (Type type : predicatedColumnTypes) { ReadWriteIOUtils.write(type.getType(), outputStream); } outputStream.flush(); @@ -149,9 +156,14 @@ public class ForecastTableFunction implements TableFunction { this.keepInput = ReadWriteIOUtils.readBoolean(buffer); this.options = ReadWriteIOUtils.readMap(buffer); int size = ReadWriteIOUtils.readInt(buffer); - this.types = new ArrayList<>(size); + this.inputColumnTypes = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + inputColumnTypes.add(Type.valueOf(ReadWriteIOUtils.readString(buffer))); + } + size = ReadWriteIOUtils.readInt(buffer); + this.predicatedColumnTypes = new ArrayList<>(size); for (int i = 0; i < size; i++) { - types.add(Type.valueOf(ReadWriteIOUtils.readByte(buffer))); + predicatedColumnTypes.add(Type.valueOf(ReadWriteIOUtils.readByte(buffer))); } } @@ -172,7 +184,8 @@ public class ForecastTableFunction implements TableFunction { && Objects.equals(targetAINode, that.targetAINode) && Objects.equals(modelId, that.modelId) && Objects.equals(options, that.options) - && Objects.equals(types, that.types); + && Objects.equals(inputColumnTypes, that.inputColumnTypes) + && Objects.equals(predicatedColumnTypes, that.predicatedColumnTypes); } @Override @@ -186,7 +199,8 @@ public class ForecastTableFunction implements TableFunction { outputInterval, keepInput, options, - types); + inputColumnTypes, + predicatedColumnTypes); } } @@ -319,12 +333,22 @@ public class ForecastTableFunction implements TableFunction { DescribedSchema.Builder properColumnSchemaBuilder = new DescribedSchema.Builder().addField(timeColumn, Type.TIMESTAMP); + List<Type> inputColumnTypes = new ArrayList<>(); List<Type> predicatedColumnTypes = new ArrayList<>(); List<Optional<String>> allInputColumnsName = input.getFieldNames(); List<Type> allInputColumnsType = input.getFieldTypes(); + for (int i = 0, size = allInputColumnsName.size(); i < size; i++) { + Optional<String> fieldName = allInputColumnsName.get(i); + // All input value columns are required for model forecasting + if (!fieldName.isPresent() + || !excludedColumns.contains(fieldName.get().toLowerCase(Locale.ENGLISH))) { + inputColumnTypes.add(allInputColumnsType.get(i)); + requiredIndexList.add(i); + } + } if (predicatedColumns.isEmpty()) { - // predicated columns by default include all columns from input table except for timecol and - // partition by columns + // predicated columns by default include all columns from input table except for + // timecol and partition by columns for (int i = 0, size = allInputColumnsName.size(); i < size; i++) { Optional<String> fieldName = allInputColumnsName.get(i); if (!fieldName.isPresent() @@ -332,7 +356,6 @@ public class ForecastTableFunction implements TableFunction { Type columnType = allInputColumnsType.get(i); predicatedColumnTypes.add(columnType); checkType(columnType, fieldName.orElse("")); - requiredIndexList.add(i); properColumnSchemaBuilder.addField(fieldName, columnType); } } @@ -347,7 +370,7 @@ public class ForecastTableFunction implements TableFunction { inputColumnIndexMap.put(fieldName.get().toLowerCase(Locale.ENGLISH), i); } - Set<Integer> requiredIndexSet = new HashSet<>(predictedColumnsArray.length); + Set<Integer> predicatedIndexSet = new HashSet<>(predictedColumnsArray.length); // columns need to be predicated for (String outputColumn : predictedColumnsArray) { String lowerCaseOutputColumn = outputColumn.toLowerCase(Locale.ENGLISH); @@ -360,14 +383,13 @@ public class ForecastTableFunction implements TableFunction { throw new SemanticException( String.format("Column %s don't exist in input", outputColumn)); } - if (!requiredIndexSet.add(inputColumnIndex)) { + if (!predicatedIndexSet.add(inputColumnIndex)) { throw new SemanticException(String.format("Duplicate column %s", outputColumn)); } Type columnType = allInputColumnsType.get(inputColumnIndex); predicatedColumnTypes.add(columnType); checkType(columnType, outputColumn); - requiredIndexList.add(inputColumnIndex); properColumnSchemaBuilder.addField(outputColumn, columnType); } } @@ -392,6 +414,7 @@ public class ForecastTableFunction implements TableFunction { outputStartTime, outputInterval, targetAINode, + inputColumnTypes, predicatedColumnTypes); // outputColumnSchema @@ -469,8 +492,9 @@ public class ForecastTableFunction implements TableFunction { private final boolean keepInput; private final Map<String, String> options; private final LinkedList<Record> inputRecords; - private final List<ResultColumnAppender> resultColumnAppenderList; private final TsBlockBuilder inputTsBlockBuilder; + private final List<ResultColumnAppender> inputColumnAppenderList; + private final List<ResultColumnAppender> resultColumnAppenderList; public ForecastDataProcessor(ForecastTableFunctionHandle functionHandle) { this.targetAINode = functionHandle.targetAINode; @@ -482,14 +506,22 @@ public class ForecastTableFunction implements TableFunction { this.keepInput = functionHandle.keepInput; this.options = functionHandle.options; this.inputRecords = new LinkedList<>(); - this.resultColumnAppenderList = new ArrayList<>(functionHandle.types.size()); - List<TSDataType> tsDataTypeList = new ArrayList<>(functionHandle.types.size()); - for (Type type : functionHandle.types) { + List<TSDataType> inputTsDataTypeList = + new ArrayList<>(functionHandle.inputColumnTypes.size()); + for (Type type : functionHandle.inputColumnTypes) { + // AINode currently only accept double input + inputTsDataTypeList.add(TSDataType.DOUBLE); + } + this.inputTsBlockBuilder = new TsBlockBuilder(inputTsDataTypeList); + this.inputColumnAppenderList = new ArrayList<>(functionHandle.inputColumnTypes.size()); + for (Type type : functionHandle.inputColumnTypes) { + // AINode currently only accept double input + inputColumnAppenderList.add(createResultColumnAppender(Type.DOUBLE)); + } + this.resultColumnAppenderList = new ArrayList<>(functionHandle.predicatedColumnTypes.size()); + for (Type type : functionHandle.predicatedColumnTypes) { resultColumnAppenderList.add(createResultColumnAppender(type)); - // ainode currently only accept double input - tsDataTypeList.add(TSDataType.DOUBLE); } - this.inputTsBlockBuilder = new TsBlockBuilder(tsDataTypeList); } private static ResultColumnAppender createResultColumnAppender(Type type) { @@ -613,7 +645,7 @@ public class ForecastTableFunction implements TableFunction { // need to transform other types to DOUBLE inputTsBlockBuilder .getColumnBuilder(i - 1) - .writeDouble(resultColumnAppenderList.get(i - 1).getDouble(row, i)); + .writeDouble(inputColumnAppenderList.get(i - 1).getDouble(row, i)); } } inputTsBlockBuilder.declarePosition(); @@ -634,15 +666,7 @@ public class ForecastTableFunction implements TableFunction { throw new IoTDBRuntimeException(message, resp.getStatus().getCode()); } - TsBlock res = SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult())); - if (res.getValueColumnCount() != inputData.getValueColumnCount()) { - throw new IoTDBRuntimeException( - String.format( - "Model %s output %s columns, doesn't equal to specified %s", - modelId, res.getValueColumnCount(), inputData.getValueColumnCount()), - TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode()); - } - return res; + return SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult())); } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java index 22c2bce7b5e..f1e387c0ff6 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java @@ -128,11 +128,8 @@ public class UDTFForecast implements UDTF { Arrays.stream( parameters.getStringOrDefault(OPTIONS_PARAMETER_NAME, DEFAULT_OPTIONS).split(",")) .map(s -> s.split("=")) - .filter(arr -> arr.length == 2 && !arr[0].isEmpty()) // 防御性检查 - .collect( - Collectors.toMap( - arr -> arr[0].trim(), arr -> arr[1].trim(), (v1, v2) -> v2 // 如果 key 重复,保留后一个 - )); + .filter(arr -> arr.length == 2 && !arr[0].isEmpty()) + .collect(Collectors.toMap(arr -> arr[0].trim(), arr -> arr[1].trim(), (v1, v2) -> v2)); this.inputRows = new LinkedList<>(); List<TSDataType> tsDataTypeList = new ArrayList<>(this.types.size() - 1); for (int i = 0; i < this.types.size(); i++) { diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java index e56b48936b9..0e8ffb484f1 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java @@ -379,6 +379,7 @@ public class TableFunctionTest { DEFAULT_OUTPUT_START_TIME, DEFAULT_OUTPUT_INTERVAL, new TEndPoint("127.0.0.1", 10810), + Collections.singletonList(DOUBLE), Collections.singletonList(DOUBLE))); // Verify full LogicalPlan // Output - TableFunctionProcessor - TableScan @@ -440,6 +441,7 @@ public class TableFunctionTest { DEFAULT_OUTPUT_START_TIME, DEFAULT_OUTPUT_INTERVAL, new TEndPoint("127.0.0.1", 10810), + Collections.singletonList(DOUBLE), Collections.singletonList(DOUBLE))); // Verify full LogicalPlan // Output - TableFunctionProcessor - TableScan
