This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new b6c13d7e3e8 [AINode] Forecast table function version2 (#16922)
b6c13d7e3e8 is described below
commit b6c13d7e3e87581ea43ba82e03a5ede133cdd452
Author: Yongzao <[email protected]>
AuthorDate: Thu Dec 18 09:41:24 2025 +0800
[AINode] Forecast table function version2 (#16922)
---------
Co-authored-by: Liu Zhengyun <[email protected]>
---
.../ainode/it/AINodeConcurrentForecastIT.java | 2 +-
.../apache/iotdb/ainode/it/AINodeForecastIT.java | 100 +++++-
.../ainode/core/model/sundial/modeling_sundial.py | 13 +-
.../ainode/core/model/timer_xl/modeling_timer.py | 6 +-
.../plan/planner/TableOperatorGenerator.java | 25 +-
.../function/TableBuiltinTableFunction.java | 6 +-
.../function/tvf/ClassifyTableFunction.java | 383 +++++++++++++++++++++
.../function/tvf/ForecastTableFunction.java | 263 ++++----------
.../relational/utils/ResultColumnAppender.java | 145 ++++++++
.../db/queryengine/plan/udf/UDTFForecast.java | 2 +
.../relational/analyzer/TableFunctionTest.java | 4 +-
11 files changed, 723 insertions(+), 226 deletions(-)
diff --git
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
index 844ec1d8223..fe19f991e57 100644
---
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
+++
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
@@ -49,7 +49,7 @@ public class AINodeConcurrentForecastIT {
private static final Logger LOGGER =
LoggerFactory.getLogger(AINodeConcurrentForecastIT.class);
private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
- "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM
root.AI) ORDER BY time, output_length=>%d)";
+ "SELECT * FROM FORECAST(model_id=>'%s', targets=>(SELECT time,s FROM
root.AI) ORDER BY time, output_length=>%d)";
@BeforeClass
public static void setUp() throws Exception {
diff --git
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
index a06656d4ada..bb0de13ed49 100644
---
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
+++
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
@@ -38,13 +38,21 @@ import java.sql.SQLException;
import java.sql.Statement;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
@RunWith(IoTDBTestRunner.class)
@Category({AIClusterIT.class})
public class AINodeForecastIT {
private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
- "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time, s%d FROM
db.AI) ORDER BY time)";
+ "SELECT * FROM FORECAST("
+ + "model_id=>'%s', "
+ + "targets=>(SELECT time, s%d FROM db.AI WHERE time<%d ORDER BY time
DESC LIMIT %d) ORDER BY time, "
+ + "output_start_time=>%d, "
+ + "output_length=>%d, "
+ + "output_interval=>%d, "
+ + "timecol=>'%s'"
+ + ")";
@BeforeClass
public static void setUp() throws Exception {
@@ -55,7 +63,7 @@ public class AINodeForecastIT {
statement.execute("CREATE DATABASE db");
statement.execute(
"CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32
FIELD, s3 INT64 FIELD)");
- for (int i = 0; i < 2880; i++) {
+ for (int i = 0; i < 5760; i++) {
statement.execute(
String.format(
"INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
@@ -81,18 +89,100 @@ public class AINodeForecastIT {
public void forecastTableFunctionTest(
Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws
SQLException {
- // Invoke call inference for specified models, there should exist result.
+ // Invoke the forecast table function for the specified models; results should
exist.
for (int i = 0; i < 4; i++) {
String forecastTableFunctionSQL =
- String.format(FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
modelInfo.getModelId(), i);
+ String.format(
+ FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
+ modelInfo.getModelId(),
+ i,
+ 5760,
+ 2880,
+ 5760,
+ 96,
+ 1,
+ "time");
try (ResultSet resultSet =
statement.executeQuery(forecastTableFunctionSQL)) {
int count = 0;
while (resultSet.next()) {
count++;
}
- // Ensure the call inference return results
+ // Ensure the forecast table function returns results
Assert.assertTrue(count > 0);
}
}
}
+
+ @Test
+ public void forecastTableFunctionErrorTest() throws SQLException {
+ for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values())
{
+ try (Connection connection =
EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
+ Statement statement = connection.createStatement()) {
+ forecastTableFunctionErrorTest(statement, modelInfo);
+ }
+ }
+ }
+
+ public void forecastTableFunctionErrorTest(
+ Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws
SQLException {
+ // OUTPUT_START_TIME error
+ String invalidOutputStartTimeSQL =
+ String.format(
+ FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
+ modelInfo.getModelId(),
+ 0,
+ 5760,
+ 2880,
+ 5759,
+ 96,
+ 1,
+ "time");
+ errorTest(
+ statement,
+ invalidOutputStartTimeSQL,
+ "701: The OUTPUT_START_TIME should be greater than the maximum
timestamp of target time series. Expected greater than [5759] but found
[5759].");
+
+ // OUTPUT_LENGTH error
+ String invalidOutputLengthSQL =
+ String.format(
+ FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
+ modelInfo.getModelId(),
+ 0,
+ 5760,
+ 2880,
+ 5760,
+ 0,
+ 1,
+ "time");
+ errorTest(statement, invalidOutputLengthSQL, "701: OUTPUT_LENGTH should be
greater than 0");
+
+ // OUTPUT_INTERVAL error
+ String invalidOutputIntervalSQL =
+ String.format(
+ FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
+ modelInfo.getModelId(),
+ 0,
+ 5760,
+ 2880,
+ 5760,
+ 96,
+ -1,
+ "time");
+ errorTest(statement, invalidOutputIntervalSQL, "701: OUTPUT_INTERVAL
should be greater than 0");
+
+ // TIMECOL error
+ String invalidTimecolSQL2 =
+ String.format(
+ FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
+ modelInfo.getModelId(),
+ 0,
+ 5760,
+ 2880,
+ 5760,
+ 96,
+ 1,
+ "s0");
+ errorTest(
+ statement, invalidTimecolSQL2, "701: The type of the column [s0] is
not as expected.");
+ }
}
diff --git
a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py
b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py
index dc1de32506e..4e892a6817b 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py
@@ -610,7 +610,11 @@ class SundialForPrediction(SundialPreTrainedModel,
TSGenerationMixin):
if attention_mask is not None and attention_mask.shape[1] > (
input_ids.shape[1] // self.config.input_token_len
):
- input_ids = input_ids[:, -(attention_mask.shape[1] -
past_length) :]
+ input_ids = input_ids[
+ :,
+ -(attention_mask.shape[1] - past_length)
+ * self.config.input_token_len :,
+ ]
# 2 - If the past_length is smaller than input_ids', then
input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < (input_ids.shape[1] //
self.config.input_token_len):
@@ -623,9 +627,10 @@ class SundialForPrediction(SundialPreTrainedModel,
TSGenerationMixin):
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
- position_ids = position_ids[
- :, -(input_ids.shape[1] // self.config.input_token_len) :
- ]
+ token_num = (
+ input_ids.shape[1] + self.config.input_token_len - 1
+ ) // self.config.input_token_len
+ position_ids = position_ids[:, -token_num:]
# if `inputs_embeds` are passed, we only want to use them in the 1st
generation step
if inputs_embeds is not None and past_key_values is None:
diff --git
a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py
b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py
index fc9d7b41388..1722b27d715 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py
@@ -603,7 +603,11 @@ class TimerForPrediction(TimerPreTrainedModel,
TSGenerationMixin):
if attention_mask is not None and attention_mask.shape[1] > (
input_ids.shape[1] // self.config.input_token_len
):
- input_ids = input_ids[:, -(attention_mask.shape[1] -
past_length) :]
+ input_ids = input_ids[
+ :,
+ -(attention_mask.shape[1] - past_length)
+ * self.config.input_token_len :,
+ ]
# 2 - If the past_length is smaller than input_ids', then
input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < (input_ids.shape[1] //
self.config.input_token_len):
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
index 6c6ae5cf327..a2bd0dc4815 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
@@ -2013,15 +2013,22 @@ public class TableOperatorGenerator extends
PlanVisitor<Operator, LocalExecution
@Override
public Operator visitGroup(GroupNode node, LocalExecutionPlanContext
context) {
- StreamSortNode streamSortNode =
- new StreamSortNode(
- node.getPlanNodeId(),
- node.getChild(),
- node.getOrderingScheme(),
- false,
- false,
- node.getPartitionKeyCount() - 1);
- return visitStreamSort(streamSortNode, context);
+ if (node.getPartitionKeyCount() == 0) {
+ SortNode sortNode =
+ new SortNode(
+ node.getPlanNodeId(), node.getChild(), node.getOrderingScheme(),
false, false);
+ return visitSort(sortNode, context);
+ } else {
+ StreamSortNode streamSortNode =
+ new StreamSortNode(
+ node.getPlanNodeId(),
+ node.getChild(),
+ node.getOrderingScheme(),
+ false,
+ false,
+ node.getPartitionKeyCount() - 1);
+ return visitStreamSort(streamSortNode, context);
+ }
}
@Override
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java
index a5a57cebadb..61b96809f84 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java
@@ -25,6 +25,7 @@ import
org.apache.iotdb.commons.udf.builtin.relational.tvf.HOPTableFunction;
import
org.apache.iotdb.commons.udf.builtin.relational.tvf.SessionTableFunction;
import org.apache.iotdb.commons.udf.builtin.relational.tvf.TumbleTableFunction;
import
org.apache.iotdb.commons.udf.builtin.relational.tvf.VariationTableFunction;
+import
org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ClassifyTableFunction;
import
org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction;
import
org.apache.iotdb.db.queryengine.plan.relational.function.tvf.PatternMatchTableFunction;
import org.apache.iotdb.udf.api.relational.TableFunction;
@@ -42,7 +43,8 @@ public enum TableBuiltinTableFunction {
VARIATION("variation"),
CAPACITY("capacity"),
FORECAST("forecast"),
- PATTERN_MATCH("pattern_match");
+ PATTERN_MATCH("pattern_match"),
+ CLASSIFY("classify");
private final String functionName;
@@ -86,6 +88,8 @@ public enum TableBuiltinTableFunction {
return new CapacityTableFunction();
case "forecast":
return new ForecastTableFunction();
+ case "classify":
+ return new ClassifyTableFunction();
default:
throw new UnsupportedOperationException("Unsupported table function: "
+ functionName);
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java
new file mode 100644
index 00000000000..670e019a4b6
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java
@@ -0,0 +1,383 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.function.tvf;
+
+import org.apache.iotdb.ainode.rpc.thrift.TForecastReq;
+import org.apache.iotdb.ainode.rpc.thrift.TForecastResp;
+import org.apache.iotdb.commons.client.IClientManager;
+import org.apache.iotdb.commons.exception.IoTDBRuntimeException;
+import org.apache.iotdb.db.exception.sql.SemanticException;
+import org.apache.iotdb.db.protocol.client.an.AINodeClient;
+import org.apache.iotdb.db.protocol.client.an.AINodeClientManager;
+import
org.apache.iotdb.db.queryengine.plan.relational.utils.ResultColumnAppender;
+import org.apache.iotdb.rpc.TSStatusCode;
+import org.apache.iotdb.udf.api.exception.UDFException;
+import org.apache.iotdb.udf.api.relational.TableFunction;
+import org.apache.iotdb.udf.api.relational.access.Record;
+import org.apache.iotdb.udf.api.relational.table.TableFunctionAnalysis;
+import org.apache.iotdb.udf.api.relational.table.TableFunctionHandle;
+import
org.apache.iotdb.udf.api.relational.table.TableFunctionProcessorProvider;
+import org.apache.iotdb.udf.api.relational.table.argument.Argument;
+import org.apache.iotdb.udf.api.relational.table.argument.DescribedSchema;
+import org.apache.iotdb.udf.api.relational.table.argument.ScalarArgument;
+import org.apache.iotdb.udf.api.relational.table.argument.TableArgument;
+import
org.apache.iotdb.udf.api.relational.table.processor.TableFunctionDataProcessor;
+import
org.apache.iotdb.udf.api.relational.table.specification.ParameterSpecification;
+import
org.apache.iotdb.udf.api.relational.table.specification.ScalarParameterSpecification;
+import
org.apache.iotdb.udf.api.relational.table.specification.TableParameterSpecification;
+import org.apache.iotdb.udf.api.type.Type;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.read.common.block.TsBlock;
+import org.apache.tsfile.read.common.block.TsBlockBuilder;
+import org.apache.tsfile.read.common.block.column.TsBlockSerde;
+import org.apache.tsfile.utils.PublicBAOS;
+import org.apache.tsfile.utils.ReadWriteIOUtils;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static
org.apache.iotdb.commons.udf.builtin.relational.tvf.WindowTVFUtils.findColumnIndex;
+import static
org.apache.iotdb.db.queryengine.plan.relational.utils.ResultColumnAppender.createResultColumnAppender;
+import static org.apache.iotdb.rpc.TSStatusCode.CAN_NOT_CONNECT_AINODE;
+
+public class ClassifyTableFunction implements TableFunction {
+
+ public static class ClassifyTableFunctionHandle implements
TableFunctionHandle {
+ String modelId;
+ int maxInputLength;
+ List<Type> inputColumnTypes;
+
+ public ClassifyTableFunctionHandle() {}
+
+ public ClassifyTableFunctionHandle(
+ String modelId, int maxInputLength, List<Type> inputColumnTypes) {
+ this.modelId = modelId;
+ this.maxInputLength = maxInputLength;
+ this.inputColumnTypes = inputColumnTypes;
+ }
+
+ @Override
+ public byte[] serialize() {
+ try (PublicBAOS publicBAOS = new PublicBAOS();
+ DataOutputStream outputStream = new DataOutputStream(publicBAOS)) {
+ ReadWriteIOUtils.write(modelId, outputStream);
+ ReadWriteIOUtils.write(maxInputLength, outputStream);
+ ReadWriteIOUtils.write(inputColumnTypes.size(), outputStream);
+ for (Type type : inputColumnTypes) {
+ ReadWriteIOUtils.write(type.getType(), outputStream);
+ }
+ outputStream.flush();
+ return publicBAOS.toByteArray();
+ } catch (IOException e) {
+ throw new IoTDBRuntimeException(
+ String.format(
+ "Error occurred while serializing ForecastTableFunctionHandle:
%s", e.getMessage()),
+ TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
+ }
+ }
+
+ @Override
+ public void deserialize(byte[] bytes) {
+ ByteBuffer buffer = ByteBuffer.wrap(bytes);
+ this.modelId = ReadWriteIOUtils.readString(buffer);
+ this.maxInputLength = ReadWriteIOUtils.readInt(buffer);
+ int size = ReadWriteIOUtils.readInt(buffer);
+ this.inputColumnTypes = new ArrayList<>(size);
+ for (int i = 0; i < size; i++) {
+
inputColumnTypes.add(Type.valueOf(ReadWriteIOUtils.readString(buffer)));
+ }
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ ClassifyTableFunctionHandle that = (ClassifyTableFunctionHandle) o;
+ return maxInputLength == that.maxInputLength
+ && Objects.equals(modelId, that.modelId)
+ && Objects.equals(inputColumnTypes, that.inputColumnTypes);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(modelId, maxInputLength, inputColumnTypes);
+ }
+ }
+
+ private static final String INPUT_PARAMETER_NAME = "INPUT";
+ private static final String MODEL_ID_PARAMETER_NAME = "MODEL_ID";
+ public static final String TIMECOL_PARAMETER_NAME = "TIMECOL";
+ private static final String DEFAULT_TIME_COL = "time";
+ private static final String DEFAULT_OUTPUT_COLUMN_NAME = "category";
+ private static final int MAX_INPUT_LENGTH = 2880;
+
+ private static final Set<Type> ALLOWED_INPUT_TYPES = new HashSet<>();
+
+ static {
+ ALLOWED_INPUT_TYPES.add(Type.INT32);
+ ALLOWED_INPUT_TYPES.add(Type.INT64);
+ ALLOWED_INPUT_TYPES.add(Type.FLOAT);
+ ALLOWED_INPUT_TYPES.add(Type.DOUBLE);
+ }
+
+ @Override
+ public List<ParameterSpecification> getArgumentsSpecifications() {
+ return Arrays.asList(
+
TableParameterSpecification.builder().name(INPUT_PARAMETER_NAME).setSemantics().build(),
+ ScalarParameterSpecification.builder()
+ .name(MODEL_ID_PARAMETER_NAME)
+ .type(Type.STRING)
+ .build(),
+ ScalarParameterSpecification.builder()
+ .name(TIMECOL_PARAMETER_NAME)
+ .type(Type.STRING)
+ .defaultValue(DEFAULT_TIME_COL)
+ .build());
+ }
+
+ @Override
+ public TableFunctionAnalysis analyze(Map<String, Argument> arguments) throws
UDFException {
+ TableArgument input = (TableArgument) arguments.get(INPUT_PARAMETER_NAME);
+ String modelId = (String) ((ScalarArgument)
arguments.get(MODEL_ID_PARAMETER_NAME)).getValue();
+ // modelId should never be null or empty
+ if (modelId == null || modelId.isEmpty()) {
+ throw new SemanticException(
+ String.format("%s should never be null or empty",
MODEL_ID_PARAMETER_NAME));
+ }
+
+ String timeColumn =
+ ((String) ((ScalarArgument)
arguments.get(TIMECOL_PARAMETER_NAME)).getValue())
+ .toLowerCase(Locale.ENGLISH);
+
+ if (timeColumn.isEmpty()) {
+ throw new SemanticException(
+ String.format("%s should never be null or empty.",
TIMECOL_PARAMETER_NAME));
+ }
+
+ // predicated columns should never contain partition by columns and time
column
+ Set<String> excludedColumns =
+ input.getPartitionBy().stream()
+ .map(s -> s.toLowerCase(Locale.ENGLISH))
+ .collect(Collectors.toSet());
+ excludedColumns.add(timeColumn);
+ int timeColumnIndex = findColumnIndex(input, timeColumn,
Collections.singleton(Type.TIMESTAMP));
+
+ // List of required column indexes
+ List<Integer> requiredIndexList = new ArrayList<>();
+ requiredIndexList.add(timeColumnIndex);
+ DescribedSchema.Builder properColumnSchemaBuilder =
+ new DescribedSchema.Builder().addField(timeColumn, Type.TIMESTAMP);
+
+ List<Type> inputColumnTypes = new ArrayList<>();
+ List<Optional<String>> allInputColumnsName = input.getFieldNames();
+ List<Type> allInputColumnsType = input.getFieldTypes();
+
+ for (int i = 0, size = allInputColumnsName.size(); i < size; i++) {
+ Optional<String> fieldName = allInputColumnsName.get(i);
+ // All input value columns are required for model classification
+ if (!fieldName.isPresent()
+ ||
!excludedColumns.contains(fieldName.get().toLowerCase(Locale.ENGLISH))) {
+ Type columnType = allInputColumnsType.get(i);
+ checkType(columnType, fieldName.orElse(""));
+ inputColumnTypes.add(columnType);
+ requiredIndexList.add(i);
+ }
+ }
+ properColumnSchemaBuilder.addField(DEFAULT_OUTPUT_COLUMN_NAME, Type.INT32);
+
+ ClassifyTableFunctionHandle functionHandle =
+ new ClassifyTableFunctionHandle(modelId, MAX_INPUT_LENGTH,
inputColumnTypes);
+
+ // outputColumnSchema
+ return TableFunctionAnalysis.builder()
+ .properColumnSchema(properColumnSchemaBuilder.build())
+ .handle(functionHandle)
+ .requiredColumns(INPUT_PARAMETER_NAME, requiredIndexList)
+ .build();
+ }
+
+ // only allow for INT32, INT64, FLOAT, DOUBLE
+ private void checkType(Type type, String columnName) {
+ if (!ALLOWED_INPUT_TYPES.contains(type)) {
+ throw new SemanticException(
+ String.format(
+ "The type of the column [%s] is [%s], only INT32, INT64, FLOAT,
DOUBLE is allowed",
+ columnName, type));
+ }
+ }
+
+ @Override
+ public TableFunctionHandle createTableFunctionHandle() {
+ return new ClassifyTableFunctionHandle();
+ }
+
+ @Override
+ public TableFunctionProcessorProvider getProcessorProvider(
+ TableFunctionHandle tableFunctionHandle) {
+ return new TableFunctionProcessorProvider() {
+ @Override
+ public TableFunctionDataProcessor getDataProcessor() {
+ return new ClassifyDataProcessor((ClassifyTableFunctionHandle)
tableFunctionHandle);
+ }
+ };
+ }
+
+ private static class ClassifyDataProcessor implements
TableFunctionDataProcessor {
+
+ private static final TsBlockSerde SERDE = new TsBlockSerde();
+ private static final IClientManager<Integer, AINodeClient> CLIENT_MANAGER =
+ AINodeClientManager.getInstance();
+
+ private final String modelId;
+ private final int maxInputLength;
+ private final LinkedList<Record> inputRecords;
+ private final TsBlockBuilder inputTsBlockBuilder;
+ private final List<ResultColumnAppender> inputColumnAppenderList;
+ private final List<ResultColumnAppender> resultColumnAppenderList;
+
+ public ClassifyDataProcessor(ClassifyTableFunctionHandle functionHandle) {
+ this.modelId = functionHandle.modelId;
+ this.maxInputLength = functionHandle.maxInputLength;
+ this.inputRecords = new LinkedList<>();
+ List<TSDataType> inputTsDataTypeList =
+ new ArrayList<>(functionHandle.inputColumnTypes.size());
+ this.inputColumnAppenderList = new
ArrayList<>(functionHandle.inputColumnTypes.size());
+ for (Type type : functionHandle.inputColumnTypes) {
+ // AINode currently only accepts double input
+ inputTsDataTypeList.add(TSDataType.DOUBLE);
+ inputColumnAppenderList.add(createResultColumnAppender(Type.DOUBLE));
+ }
+ this.inputTsBlockBuilder = new TsBlockBuilder(inputTsDataTypeList);
+ this.resultColumnAppenderList = new ArrayList<>(1);
+
this.resultColumnAppenderList.add(createResultColumnAppender(Type.INT32));
+ }
+
+ @Override
+ public void process(
+ Record input,
+ List<ColumnBuilder> properColumnBuilders,
+ ColumnBuilder passThroughIndexBuilder) {
+ // only keep at most maxInputLength rows
+ if (maxInputLength != 0 && inputRecords.size() == maxInputLength) {
+ inputRecords.removeFirst();
+ }
+ inputRecords.add(input);
+ }
+
+ @Override
+ public void finish(
+ List<ColumnBuilder> properColumnBuilders, ColumnBuilder
passThroughIndexBuilder) {
+
+ // time column
+ long inputStartTime = inputRecords.getFirst().getLong(0);
+ long inputEndTime = inputRecords.getLast().getLong(0);
+ if (inputEndTime < inputStartTime) {
+ throw new SemanticException(
+ String.format(
+ "input end time should never less than start time, start time
is %s, end time is %s",
+ inputStartTime, inputEndTime));
+ }
+ int outputLength = inputRecords.size();
+ for (Record inputRecord : inputRecords) {
+ properColumnBuilders.get(0).writeLong(inputRecord.getLong(0));
+ }
+
+ // predicated columns
+ TsBlock predicatedResult = classify();
+ if (predicatedResult.getPositionCount() != outputLength) {
+ throw new IoTDBRuntimeException(
+ String.format(
+ "Model %s output length is %s, doesn't equal to specified %s",
+ modelId, predicatedResult.getPositionCount(), outputLength),
+ TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
+ }
+
+ for (int columnIndex = 1, size = predicatedResult.getValueColumnCount();
+ columnIndex <= size;
+ columnIndex++) {
+ Column column = predicatedResult.getColumn(columnIndex - 1);
+ ColumnBuilder builder = properColumnBuilders.get(columnIndex);
+ ResultColumnAppender appender =
resultColumnAppenderList.get(columnIndex - 1);
+ for (int row = 0; row < outputLength; row++) {
+ if (column.isNull(row)) {
+ builder.appendNull();
+ } else {
+ // convert double to real type
+ appender.writeDouble(column.getDouble(row), builder);
+ }
+ }
+ }
+ }
+
+ private TsBlock classify() {
+ int outputLength = inputRecords.size();
+ // construct inputTSBlock for AINode
+ while (!inputRecords.isEmpty()) {
+ Record row = inputRecords.removeFirst();
+ inputTsBlockBuilder.getTimeColumnBuilder().writeLong(row.getLong(0));
+ for (int i = 1, size = row.size(); i < size; i++) {
+ // we set null input to 0.0
+ if (row.isNull(i)) {
+ inputTsBlockBuilder.getColumnBuilder(i - 1).writeDouble(0.0);
+ } else {
+ // need to transform other types to DOUBLE
+ inputTsBlockBuilder
+ .getColumnBuilder(i - 1)
+ .writeDouble(inputColumnAppenderList.get(i - 1).getDouble(row,
i));
+ }
+ }
+ inputTsBlockBuilder.declarePosition();
+ }
+ TsBlock inputData = inputTsBlockBuilder.build();
+
+ TForecastResp resp;
+ try (AINodeClient client =
+
CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) {
+ resp = client.forecast(new TForecastReq(modelId,
SERDE.serialize(inputData), outputLength));
+ } catch (Exception e) {
+ throw new IoTDBRuntimeException(e.getMessage(),
CAN_NOT_CONNECT_AINODE.getStatusCode());
+ }
+
+ if (resp.getStatus().getCode() !=
TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+ String message =
+ String.format(
+ "Error occurred while executing classify:[%s]",
resp.getStatus().getMessage());
+ throw new IoTDBRuntimeException(message, resp.getStatus().getCode());
+ }
+ return SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
+ }
+ }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
index 08f7ec6c833..950d8c464e8 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
@@ -26,7 +26,7 @@ import
org.apache.iotdb.commons.exception.IoTDBRuntimeException;
import org.apache.iotdb.db.exception.sql.SemanticException;
import org.apache.iotdb.db.protocol.client.an.AINodeClient;
import org.apache.iotdb.db.protocol.client.an.AINodeClientManager;
-import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher;
+import
org.apache.iotdb.db.queryengine.plan.relational.utils.ResultColumnAppender;
import org.apache.iotdb.rpc.TSStatusCode;
import org.apache.iotdb.udf.api.relational.TableFunction;
import org.apache.iotdb.udf.api.relational.access.Record;
@@ -58,6 +58,7 @@ import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
+import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
@@ -70,12 +71,11 @@ import java.util.Set;
import java.util.stream.Collectors;
import static
org.apache.iotdb.commons.udf.builtin.relational.tvf.WindowTVFUtils.findColumnIndex;
+import static
org.apache.iotdb.db.queryengine.plan.relational.utils.ResultColumnAppender.createResultColumnAppender;
import static org.apache.iotdb.rpc.TSStatusCode.CAN_NOT_CONNECT_AINODE;
public class ForecastTableFunction implements TableFunction {
- private static final TsBlockSerde SERDE = new TsBlockSerde();
-
public static class ForecastTableFunctionHandle implements
TableFunctionHandle {
String modelId;
int maxInputLength;
@@ -84,7 +84,7 @@ public class ForecastTableFunction implements TableFunction {
long outputInterval;
boolean keepInput;
Map<String, String> options;
- List<Type> types;
+ List<Type> targetColumntypes;
public ForecastTableFunctionHandle() {}
@@ -96,7 +96,7 @@ public class ForecastTableFunction implements TableFunction {
int outputLength,
long outputStartTime,
long outputInterval,
- List<Type> types) {
+ List<Type> targetColumntypes) {
this.keepInput = keepInput;
this.maxInputLength = maxInputLength;
this.modelId = modelId;
@@ -104,7 +104,7 @@ public class ForecastTableFunction implements TableFunction
{
this.outputLength = outputLength;
this.outputStartTime = outputStartTime;
this.outputInterval = outputInterval;
- this.types = types;
+ this.targetColumntypes = targetColumntypes;
}
@Override
@@ -118,8 +118,8 @@ public class ForecastTableFunction implements TableFunction
{
ReadWriteIOUtils.write(outputInterval, outputStream);
ReadWriteIOUtils.write(keepInput, outputStream);
ReadWriteIOUtils.write(options, outputStream);
- ReadWriteIOUtils.write(types.size(), outputStream);
- for (Type type : types) {
+ ReadWriteIOUtils.write(targetColumntypes.size(), outputStream);
+ for (Type type : targetColumntypes) {
ReadWriteIOUtils.write(type.getType(), outputStream);
}
outputStream.flush();
@@ -143,9 +143,9 @@ public class ForecastTableFunction implements TableFunction
{
this.keepInput = ReadWriteIOUtils.readBoolean(buffer);
this.options = ReadWriteIOUtils.readMap(buffer);
int size = ReadWriteIOUtils.readInt(buffer);
- this.types = new ArrayList<>(size);
+ this.targetColumntypes = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
- types.add(Type.valueOf(ReadWriteIOUtils.readByte(buffer)));
+ targetColumntypes.add(Type.valueOf(ReadWriteIOUtils.readByte(buffer)));
}
}
@@ -165,7 +165,7 @@ public class ForecastTableFunction implements TableFunction
{
&& keepInput == that.keepInput
&& Objects.equals(modelId, that.modelId)
&& Objects.equals(options, that.options)
- && Objects.equals(types, that.types);
+ && Objects.equals(targetColumntypes, that.targetColumntypes);
}
@Override
@@ -178,16 +178,14 @@ public class ForecastTableFunction implements
TableFunction {
outputInterval,
keepInput,
options,
- types);
+ targetColumntypes);
}
}
- private static final String INPUT_PARAMETER_NAME = "INPUT";
+ private static final String TARGETS_PARAMETER_NAME = "TARGETS";
private static final String MODEL_ID_PARAMETER_NAME = "MODEL_ID";
private static final String OUTPUT_LENGTH_PARAMETER_NAME = "OUTPUT_LENGTH";
private static final int DEFAULT_OUTPUT_LENGTH = 96;
- private static final String PREDICATED_COLUMNS_PARAMETER_NAME =
"PREDICATED_COLUMNS";
- private static final String DEFAULT_PREDICATED_COLUMNS = "";
private static final String OUTPUT_START_TIME = "OUTPUT_START_TIME";
public static final long DEFAULT_OUTPUT_START_TIME = Long.MIN_VALUE;
private static final String OUTPUT_INTERVAL = "OUTPUT_INTERVAL";
@@ -212,20 +210,10 @@ public class ForecastTableFunction implements
TableFunction {
ALLOWED_INPUT_TYPES.add(Type.DOUBLE);
}
- // need to set before analyze method is called
- // should only be used in fe scope, never be used in
TableFunctionProcessorProvider
- // The reason we don't directly set modelFetcher=ModelFetcher.getInstance()
is that we need to
- // mock IModelFetcher in UT
- private IModelFetcher modelFetcher = null;
-
- public void setModelFetcher(IModelFetcher modelFetcher) {
- this.modelFetcher = modelFetcher;
- }
-
@Override
public List<ParameterSpecification> getArgumentsSpecifications() {
return Arrays.asList(
-
TableParameterSpecification.builder().name(INPUT_PARAMETER_NAME).setSemantics().build(),
+
TableParameterSpecification.builder().name(TARGETS_PARAMETER_NAME).setSemantics().build(),
ScalarParameterSpecification.builder()
.name(MODEL_ID_PARAMETER_NAME)
.type(Type.STRING)
@@ -245,11 +233,6 @@ public class ForecastTableFunction implements
TableFunction {
.type(Type.INT64)
.defaultValue(DEFAULT_OUTPUT_INTERVAL)
.build(),
- ScalarParameterSpecification.builder()
- .name(PREDICATED_COLUMNS_PARAMETER_NAME)
- .type(Type.STRING)
- .defaultValue(DEFAULT_PREDICATED_COLUMNS)
- .build(),
ScalarParameterSpecification.builder()
.name(TIMECOL_PARAMETER_NAME)
.type(Type.STRING)
@@ -269,7 +252,7 @@ public class ForecastTableFunction implements TableFunction
{
@Override
public TableFunctionAnalysis analyze(Map<String, Argument> arguments) {
- TableArgument input = (TableArgument) arguments.get(INPUT_PARAMETER_NAME);
+ TableArgument targets = (TableArgument)
arguments.get(TARGETS_PARAMETER_NAME);
String modelId = (String) ((ScalarArgument)
arguments.get(MODEL_ID_PARAMETER_NAME)).getValue();
// modelId should never be null or empty
if (modelId == null || modelId.isEmpty()) {
@@ -284,82 +267,58 @@ public class ForecastTableFunction implements
TableFunction {
String.format("%s should be greater than 0",
OUTPUT_LENGTH_PARAMETER_NAME));
}
- String predicatedColumns =
- (String) ((ScalarArgument)
arguments.get(PREDICATED_COLUMNS_PARAMETER_NAME)).getValue();
-
String timeColumn =
((String) ((ScalarArgument)
arguments.get(TIMECOL_PARAMETER_NAME)).getValue())
.toLowerCase(Locale.ENGLISH);
-
if (timeColumn.isEmpty()) {
throw new SemanticException(
String.format("%s should never be null or empty.",
TIMECOL_PARAMETER_NAME));
}
+ long outputInterval = (long) ((ScalarArgument)
arguments.get(OUTPUT_INTERVAL)).getValue();
+ if (outputInterval < 0) {
+ throw new SemanticException(String.format("%s should be non-negative",
OUTPUT_INTERVAL));
+ }
+
// target columns should never contain partition-by columns or the time
column
Set<String> excludedColumns =
- input.getPartitionBy().stream()
+ targets.getPartitionBy().stream()
.map(s -> s.toLowerCase(Locale.ENGLISH))
.collect(Collectors.toSet());
excludedColumns.add(timeColumn);
- int timeColumnIndex = findColumnIndex(input, timeColumn,
Collections.singleton(Type.TIMESTAMP));
+ int timeColumnIndex =
+ findColumnIndex(targets, timeColumn,
Collections.singleton(Type.TIMESTAMP));
+ // List of required column indexes
List<Integer> requiredIndexList = new ArrayList<>();
requiredIndexList.add(timeColumnIndex);
DescribedSchema.Builder properColumnSchemaBuilder =
new DescribedSchema.Builder().addField(timeColumn, Type.TIMESTAMP);
- List<Type> predicatedColumnTypes = new ArrayList<>();
- List<Optional<String>> allInputColumnsName = input.getFieldNames();
- List<Type> allInputColumnsType = input.getFieldTypes();
- if (predicatedColumns.isEmpty()) {
- // predicated columns by default include all columns from input table
except for timecol and
- // partition by columns
- for (int i = 0, size = allInputColumnsName.size(); i < size; i++) {
- Optional<String> fieldName = allInputColumnsName.get(i);
- if (!fieldName.isPresent()
- ||
!excludedColumns.contains(fieldName.get().toLowerCase(Locale.ENGLISH))) {
- Type columnType = allInputColumnsType.get(i);
- predicatedColumnTypes.add(columnType);
- checkType(columnType, fieldName.orElse(""));
- requiredIndexList.add(i);
- properColumnSchemaBuilder.addField(fieldName, columnType);
- }
- }
- } else {
- String[] predictedColumnsArray = predicatedColumns.split(";");
- Map<String, Integer> inputColumnIndexMap = new HashMap<>();
- for (int i = 0, size = allInputColumnsName.size(); i < size; i++) {
- Optional<String> fieldName = allInputColumnsName.get(i);
- if (!fieldName.isPresent()) {
- continue;
- }
- inputColumnIndexMap.put(fieldName.get().toLowerCase(Locale.ENGLISH),
i);
+ List<Type> targetColumnTypes = new ArrayList<>();
+ List<Optional<String>> allInputColumnsName = targets.getFieldNames();
+ List<Type> allInputColumnsType = targets.getFieldTypes();
+
+ // target columns = all input columns except timecol / partition-by
columns
+ for (int i = 0, size = allInputColumnsName.size(); i < size; i++) {
+ Optional<String> fieldName = allInputColumnsName.get(i);
+ if (!fieldName.isPresent()
+ ||
excludedColumns.contains(fieldName.get().toLowerCase(Locale.ENGLISH))) {
+ continue;
}
- Set<Integer> requiredIndexSet = new
HashSet<>(predictedColumnsArray.length);
- // columns need to be predicated
- for (String outputColumn : predictedColumnsArray) {
- String lowerCaseOutputColumn =
outputColumn.toLowerCase(Locale.ENGLISH);
- if (excludedColumns.contains(lowerCaseOutputColumn)) {
- throw new SemanticException(
- String.format("%s is in partition by clause or is time column",
outputColumn));
- }
- Integer inputColumnIndex =
inputColumnIndexMap.get(lowerCaseOutputColumn);
- if (inputColumnIndex == null) {
- throw new SemanticException(
- String.format("Column %s don't exist in input", outputColumn));
- }
- if (!requiredIndexSet.add(inputColumnIndex)) {
- throw new SemanticException(String.format("Duplicate column %s",
outputColumn));
- }
+ Type columnType = allInputColumnsType.get(i);
+ targetColumnTypes.add(columnType);
+ checkType(columnType, fieldName.get());
+ requiredIndexList.add(i);
+ properColumnSchemaBuilder.addField(fieldName.get(), columnType);
+ }
- Type columnType = allInputColumnsType.get(inputColumnIndex);
- predicatedColumnTypes.add(columnType);
- checkType(columnType, outputColumn);
- requiredIndexList.add(inputColumnIndex);
- properColumnSchemaBuilder.addField(outputColumn, columnType);
- }
+ if (targetColumnTypes.size() > 1) {
+ throw new SemanticException(
+ String.format(
+ "%s should not contain more than one target column, found [%s]
target columns.",
+ TARGETS_PARAMETER_NAME, targetColumnTypes.size()));
}
boolean keepInput =
@@ -369,7 +328,6 @@ public class ForecastTableFunction implements TableFunction
{
}
long outputStartTime = (long) ((ScalarArgument)
arguments.get(OUTPUT_START_TIME)).getValue();
- long outputInterval = (long) ((ScalarArgument)
arguments.get(OUTPUT_INTERVAL)).getValue();
String options = (String) ((ScalarArgument)
arguments.get(OPTIONS_PARAMETER_NAME)).getValue();
ForecastTableFunctionHandle functionHandle =
@@ -381,13 +339,13 @@ public class ForecastTableFunction implements
TableFunction {
outputLength,
outputStartTime,
outputInterval,
- predicatedColumnTypes);
+ targetColumnTypes);
// outputColumnSchema
return TableFunctionAnalysis.builder()
.properColumnSchema(properColumnSchemaBuilder.build())
.handle(functionHandle)
- .requiredColumns(INPUT_PARAMETER_NAME, requiredIndexList)
+ .requiredColumns(TARGETS_PARAMETER_NAME, requiredIndexList)
.build();
}
@@ -465,9 +423,9 @@ public class ForecastTableFunction implements TableFunction
{
this.keepInput = functionHandle.keepInput;
this.options = functionHandle.options;
this.inputRecords = new LinkedList<>();
- this.resultColumnAppenderList = new
ArrayList<>(functionHandle.types.size());
- List<TSDataType> tsDataTypeList = new
ArrayList<>(functionHandle.types.size());
- for (Type type : functionHandle.types) {
+ this.resultColumnAppenderList = new
ArrayList<>(functionHandle.targetColumntypes.size());
+ List<TSDataType> tsDataTypeList = new
ArrayList<>(functionHandle.targetColumntypes.size());
+ for (Type type : functionHandle.targetColumntypes) {
resultColumnAppenderList.add(createResultColumnAppender(type));
// ainode currently only accept double input
tsDataTypeList.add(TSDataType.DOUBLE);
@@ -475,21 +433,6 @@ public class ForecastTableFunction implements
TableFunction {
this.inputTsBlockBuilder = new TsBlockBuilder(tsDataTypeList);
}
- private static ResultColumnAppender createResultColumnAppender(Type type) {
- switch (type) {
- case INT32:
- return new Int32Appender();
- case INT64:
- return new Int64Appender();
- case FLOAT:
- return new FloatAppender();
- case DOUBLE:
- return new DoubleAppender();
- default:
- throw new IllegalArgumentException("Unsupported column type: " +
type);
- }
- }
-
@Override
public void process(
Record input,
@@ -528,6 +471,9 @@ public class ForecastTableFunction implements TableFunction
{
int columnSize = properColumnBuilders.size();
+ // sort inputRecords in ascending order by timestamp
+ inputRecords.sort(Comparator.comparingLong(record -> record.getLong(0)));
+
// time column
long inputStartTime = inputRecords.getFirst().getLong(0);
long inputEndTime = inputRecords.getLast().getLong(0);
@@ -546,6 +492,12 @@ public class ForecastTableFunction implements
TableFunction {
}
long outputTime =
(outputStartTime == Long.MIN_VALUE) ? (inputEndTime + interval) :
outputStartTime;
+ if (outputTime <= inputEndTime) {
+ throw new SemanticException(
+ String.format(
+ "The %s should be greater than the maximum timestamp of target
time series. Expected greater than [%s] but found [%s].",
+ OUTPUT_START_TIME, inputEndTime, outputTime));
+ }
for (int i = 0; i < outputLength; i++) {
properColumnBuilders.get(0).writeLong(outputTime + interval * i);
}
@@ -585,6 +537,7 @@ public class ForecastTableFunction implements TableFunction
{
}
private TsBlock forecast() {
+ // construct inputTSBlock for AINode
while (!inputRecords.isEmpty()) {
Record row = inputRecords.removeFirst();
inputTsBlockBuilder.getTimeColumnBuilder().writeLong(row.getLong(0));
@@ -632,100 +585,4 @@ public class ForecastTableFunction implements
TableFunction {
return res;
}
}
-
- private interface ResultColumnAppender {
- void append(Record row, int columnIndex, ColumnBuilder
properColumnBuilder);
-
- double getDouble(Record row, int columnIndex);
-
- void writeDouble(double value, ColumnBuilder columnBuilder);
- }
-
- private static class Int32Appender implements ResultColumnAppender {
-
- @Override
- public void append(Record row, int columnIndex, ColumnBuilder
properColumnBuilder) {
- if (row.isNull(columnIndex)) {
- properColumnBuilder.appendNull();
- } else {
- properColumnBuilder.writeInt(row.getInt(columnIndex));
- }
- }
-
- @Override
- public double getDouble(Record row, int columnIndex) {
- return row.getInt(columnIndex);
- }
-
- @Override
- public void writeDouble(double value, ColumnBuilder columnBuilder) {
- columnBuilder.writeInt((int) value);
- }
- }
-
- private static class Int64Appender implements ResultColumnAppender {
-
- @Override
- public void append(Record row, int columnIndex, ColumnBuilder
properColumnBuilder) {
- if (row.isNull(columnIndex)) {
- properColumnBuilder.appendNull();
- } else {
- properColumnBuilder.writeLong(row.getLong(columnIndex));
- }
- }
-
- @Override
- public double getDouble(Record row, int columnIndex) {
- return row.getLong(columnIndex);
- }
-
- @Override
- public void writeDouble(double value, ColumnBuilder columnBuilder) {
- columnBuilder.writeLong((long) value);
- }
- }
-
- private static class FloatAppender implements ResultColumnAppender {
-
- @Override
- public void append(Record row, int columnIndex, ColumnBuilder
properColumnBuilder) {
- if (row.isNull(columnIndex)) {
- properColumnBuilder.appendNull();
- } else {
- properColumnBuilder.writeFloat(row.getFloat(columnIndex));
- }
- }
-
- @Override
- public double getDouble(Record row, int columnIndex) {
- return row.getFloat(columnIndex);
- }
-
- @Override
- public void writeDouble(double value, ColumnBuilder columnBuilder) {
- columnBuilder.writeFloat((float) value);
- }
- }
-
- private static class DoubleAppender implements ResultColumnAppender {
-
- @Override
- public void append(Record row, int columnIndex, ColumnBuilder
properColumnBuilder) {
- if (row.isNull(columnIndex)) {
- properColumnBuilder.appendNull();
- } else {
- properColumnBuilder.writeDouble(row.getDouble(columnIndex));
- }
- }
-
- @Override
- public double getDouble(Record row, int columnIndex) {
- return row.getDouble(columnIndex);
- }
-
- @Override
- public void writeDouble(double value, ColumnBuilder columnBuilder) {
- columnBuilder.writeDouble(value);
- }
- }
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/utils/ResultColumnAppender.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/utils/ResultColumnAppender.java
new file mode 100644
index 00000000000..7748f39382d
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/utils/ResultColumnAppender.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.utils;
+
+import org.apache.iotdb.udf.api.relational.access.Record;
+import org.apache.iotdb.udf.api.type.Type;
+
+import org.apache.tsfile.block.column.ColumnBuilder;
+
+public interface ResultColumnAppender {
+
+ void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder);
+
+ double getDouble(Record row, int columnIndex);
+
+ void writeDouble(double value, ColumnBuilder columnBuilder);
+
+ /**
+ * Static factory method to return the appropriate ResultColumnAppender
instance based on the
+ * Type.
+ */
+ static ResultColumnAppender createResultColumnAppender(Type type) {
+ switch (type) {
+ case INT32:
+ return new Int32Appender();
+ case INT64:
+ return new Int64Appender();
+ case FLOAT:
+ return new FloatAppender();
+ case DOUBLE:
+ return new DoubleAppender();
+ default:
+ throw new IllegalArgumentException("Unsupported column type: " + type);
+ }
+ }
+
+ /** INT32 Appender */
+ class Int32Appender implements ResultColumnAppender {
+
+ @Override
+ public void append(Record row, int columnIndex, ColumnBuilder
properColumnBuilder) {
+ if (row.isNull(columnIndex)) {
+ properColumnBuilder.appendNull();
+ } else {
+ properColumnBuilder.writeInt(row.getInt(columnIndex));
+ }
+ }
+
+ @Override
+ public double getDouble(Record row, int columnIndex) {
+ return row.getInt(columnIndex);
+ }
+
+ @Override
+ public void writeDouble(double value, ColumnBuilder columnBuilder) {
+ columnBuilder.writeInt((int) value);
+ }
+ }
+
+ /** INT64 Appender */
+ class Int64Appender implements ResultColumnAppender {
+
+ @Override
+ public void append(Record row, int columnIndex, ColumnBuilder
properColumnBuilder) {
+ if (row.isNull(columnIndex)) {
+ properColumnBuilder.appendNull();
+ } else {
+ properColumnBuilder.writeLong(row.getLong(columnIndex));
+ }
+ }
+
+ @Override
+ public double getDouble(Record row, int columnIndex) {
+ return row.getLong(columnIndex);
+ }
+
+ @Override
+ public void writeDouble(double value, ColumnBuilder columnBuilder) {
+ columnBuilder.writeLong((long) value);
+ }
+ }
+
+ /** FLOAT Appender */
+ class FloatAppender implements ResultColumnAppender {
+
+ @Override
+ public void append(Record row, int columnIndex, ColumnBuilder
properColumnBuilder) {
+ if (row.isNull(columnIndex)) {
+ properColumnBuilder.appendNull();
+ } else {
+ properColumnBuilder.writeFloat(row.getFloat(columnIndex));
+ }
+ }
+
+ @Override
+ public double getDouble(Record row, int columnIndex) {
+ return row.getFloat(columnIndex);
+ }
+
+ @Override
+ public void writeDouble(double value, ColumnBuilder columnBuilder) {
+ columnBuilder.writeFloat((float) value);
+ }
+ }
+
+ /** DOUBLE Appender */
+ class DoubleAppender implements ResultColumnAppender {
+
+ @Override
+ public void append(Record row, int columnIndex, ColumnBuilder
properColumnBuilder) {
+ if (row.isNull(columnIndex)) {
+ properColumnBuilder.appendNull();
+ } else {
+ properColumnBuilder.writeDouble(row.getDouble(columnIndex));
+ }
+ }
+
+ @Override
+ public double getDouble(Record row, int columnIndex) {
+ return row.getDouble(columnIndex);
+ }
+
+ @Override
+ public void writeDouble(double value, ColumnBuilder columnBuilder) {
+ columnBuilder.writeDouble(value);
+ }
+ }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
index 09f00f8ed67..ebecf79f5b7 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
@@ -85,6 +85,7 @@ public class UDTFForecast implements UDTF {
private static final Boolean DEFAULT_KEEP_INPUT = Boolean.FALSE;
private static final String OPTIONS_PARAMETER_NAME = "MODEL_OPTIONS";
private static final String DEFAULT_OPTIONS = "";
+ private static final int MAX_INPUT_LENGTH = 2880;
private void checkType() {
for (Type type : this.types) {
@@ -109,6 +110,7 @@ public class UDTFForecast implements UDTF {
throw new IllegalArgumentException(
"MODEL_ID parameter must be provided and cannot be empty.");
}
+ this.maxInputLength = MAX_INPUT_LENGTH;
this.outputInterval = parameters.getLongOrDefault(OUTPUT_INTERVAL,
DEFAULT_OUTPUT_INTERVAL);
this.outputLength =
diff --git
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java
index 7bbfe150ade..344c69cfc1a 100644
---
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java
+++
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java
@@ -356,7 +356,7 @@ public class TableFunctionTest {
String sql =
"SELECT * FROM FORECAST("
- + "input => (SELECT time,s3 FROM table1 WHERE tag1='shanghai' AND
tag2='A3' AND tag3='YY' ORDER BY time DESC LIMIT 1440), "
+ + "targets => (SELECT time,s3 FROM table1 WHERE tag1='shanghai'
AND tag2='A3' AND tag3='YY' ORDER BY time DESC LIMIT 1440), "
+ "model_id => 'timer_xl')";
LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
@@ -416,7 +416,7 @@ public class TableFunctionTest {
String sql =
"SELECT * FROM FORECAST("
- + "input => (SELECT time,s3 FROM table1 WHERE tag1='shanghai' AND
tag2='A3' AND tag3='YY' ORDER BY time DESC LIMIT 1440), "
+ + "targets => (SELECT time,s3 FROM table1 WHERE tag1='shanghai'
AND tag2='A3' AND tag3='YY' ORDER BY time DESC LIMIT 1440), "
+ "model_id => 'timer_xl', timecol=>'TiME')";
LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);