This is an automated email from the ASF dual-hosted git repository. jackietien pushed a commit to branch ai/forecast in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 273ecbceab71d8ce94bfd7f31be6a1b500571e1d Author: JackieTien97 <[email protected]> AuthorDate: Mon May 12 20:14:44 2025 +0800 Support built-in forecast table function for table model --- .../ScalarParameterSpecification.java | 2 +- .../function/TableBuiltinTableFunction.java | 9 +++-- .../function/tvf/ForecastTableFunction.java | 45 ++++++++++++++++++++-- 3 files changed, 47 insertions(+), 9 deletions(-) diff --git a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/table/specification/ScalarParameterSpecification.java b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/table/specification/ScalarParameterSpecification.java index 7333bd32601..84cc932cee5 100644 --- a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/table/specification/ScalarParameterSpecification.java +++ b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/table/specification/ScalarParameterSpecification.java @@ -63,7 +63,7 @@ public class ScalarParameterSpecification extends ParameterSpecification { private Type type; private boolean required = true; private Object defaultValue; - private List<Function<Object, String>> checkers = new ArrayList<>(); + private final List<Function<Object, String>> checkers = new ArrayList<>(); private Builder() {} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java index fda10eba8db..4a07f9a0c7b 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java @@ -25,6 +25,7 @@ import org.apache.iotdb.commons.udf.builtin.relational.tvf.HOPTableFunction; import org.apache.iotdb.commons.udf.builtin.relational.tvf.SessionTableFunction; import org.apache.iotdb.commons.udf.builtin.relational.tvf.TumbleTableFunction; import org.apache.iotdb.commons.udf.builtin.relational.tvf.VariationTableFunction; +import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction; import org.apache.iotdb.udf.api.relational.TableFunction; import java.util.Arrays; @@ -38,8 +39,8 @@ public enum TableBuiltinTableFunction { CUMULATE("cumulate"), SESSION("session"), VARIATION("variation"), - CAPACITY("capacity"); - // FORECAST("forecast"); + CAPACITY("capacity"), + FORECAST("forecast"); private final String functionName; @@ -79,8 +80,8 @@ public enum TableBuiltinTableFunction { return new VariationTableFunction(); case "capacity": return new CapacityTableFunction(); - // case "forecast": - // return new ForecastTableFunction(); + case "forecast": + return new ForecastTableFunction(); default: throw new UnsupportedOperationException("Unsupported table function: " + functionName); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java index 38b2c8d49aa..9086e6f5edd 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java @@ -78,6 +78,8 @@ public class ForecastTableFunction implements TableFunction { String modelId; int maxInputLength; int outputLength; + long outputStartTime; + long outputInterval; boolean keepInput; Map<String, String> options; List<Type> types; @@ -90,6 +92,8 @@ public class ForecastTableFunction implements TableFunction { String modelId, Map<String, String> options, int outputLength, + long outputStartTime, + long outputInterval, TEndPoint targetAINode, List<Type> types) { this.keepInput = keepInput; @@ -97,6 +101,8 @@ public class ForecastTableFunction implements TableFunction { this.modelId = modelId; this.options = options; this.outputLength = outputLength; + this.outputStartTime = outputStartTime; + this.outputInterval = outputInterval; this.targetAINode = targetAINode; this.types = types; } @@ -110,6 +116,8 @@ public class ForecastTableFunction implements TableFunction { ReadWriteIOUtils.write(modelId, outputStream); ReadWriteIOUtils.write(maxInputLength, outputStream); ReadWriteIOUtils.write(outputLength, outputStream); + ReadWriteIOUtils.write(outputStartTime, outputStream); + ReadWriteIOUtils.write(outputInterval, outputStream); ReadWriteIOUtils.write(keepInput, outputStream); ReadWriteIOUtils.write(options, outputStream); ReadWriteIOUtils.write(types.size(), outputStream); @@ -134,6 +142,8 @@ public class ForecastTableFunction implements TableFunction { this.modelId = ReadWriteIOUtils.readString(buffer); this.maxInputLength = ReadWriteIOUtils.readInt(buffer); this.outputLength = ReadWriteIOUtils.readInt(buffer); + this.outputStartTime = ReadWriteIOUtils.readLong(buffer); + this.outputInterval = ReadWriteIOUtils.readLong(buffer); this.keepInput = ReadWriteIOUtils.readBoolean(buffer); this.options = ReadWriteIOUtils.readMap(buffer); int size = ReadWriteIOUtils.readInt(buffer); @@ -152,6 +162,10 @@ public class ForecastTableFunction implements TableFunction { private static final int DEFAULT_OUTPUT_LENGTH = 96; private static final String PREDICATED_COLUMNS_PARAMETER_NAME = "PREDICATED_COLUMNS"; private static final String DEFAULT_PREDICATED_COLUMNS = ""; + private static final String OUTPUT_START_TIME = "OUTPUT_START_TIME"; + private static final long DEFAULT_OUTPUT_START_TIME = Long.MIN_VALUE; + private static final String OUTPUT_INTERVAL = "OUTPUT_INTERVAL"; + private static final long DEFAULT_OUTPUT_INTERVAL = 0L; private static final String TIMECOL_PARAMETER_NAME = "TIMECOL"; private static final String DEFAULT_TIME_COL = "time"; private static final String KEEP_INPUT_PARAMETER_NAME = "KEEP_INPUT"; @@ -184,6 +198,16 @@ public class ForecastTableFunction implements TableFunction { .type(Type.INT32) .defaultValue(DEFAULT_OUTPUT_LENGTH) .build(), + ScalarParameterSpecification.builder() + .name(OUTPUT_START_TIME) + .type(Type.TIMESTAMP) + .defaultValue(DEFAULT_OUTPUT_START_TIME) + .build(), + ScalarParameterSpecification.builder() + .name(OUTPUT_INTERVAL) + .type(Type.INT64) + .defaultValue(DEFAULT_OUTPUT_INTERVAL) + .build(), ScalarParameterSpecification.builder() .name(PREDICATED_COLUMNS_PARAMETER_NAME) .type(Type.STRING) @@ -307,6 +331,8 @@ public class ForecastTableFunction implements TableFunction { properColumnSchemaBuilder.addField(IS_INPUT_COLUMN_NAME, Type.BOOLEAN); } + long outputStartTime = (long) ((ScalarArgument) arguments.get(OUTPUT_START_TIME)).getValue(); + long outputInterval = (long) ((ScalarArgument) arguments.get(OUTPUT_INTERVAL)).getValue(); String options = (String) ((ScalarArgument) arguments.get(OPTIONS_PARAMETER_NAME)).getValue(); ForecastTableFunctionHandle functionHandle = @@ -316,6 +342,8 @@ public class ForecastTableFunction implements TableFunction { modelId, parseOptions(options), outputLength, + outputStartTime, + outputInterval, targetAINode, predicatedColumnTypes); @@ -389,6 +417,8 @@ public class ForecastTableFunction implements TableFunction { private final String modelId; private final int maxInputLength; private final int outputLength; + private final long outputStartTime; + private final long outputInterval; private final boolean keepInput; private final Map<String, String> options; private final LinkedList<Record> inputRecords; @@ -400,6 +430,8 @@ public class ForecastTableFunction implements TableFunction { this.modelId = functionHandle.modelId; this.maxInputLength = functionHandle.maxInputLength; this.outputLength = functionHandle.outputLength; + this.outputStartTime = functionHandle.outputStartTime; + this.outputInterval = functionHandle.outputInterval; this.keepInput = functionHandle.keepInput; this.options = functionHandle.options; this.inputRecords = new LinkedList<>(); @@ -467,11 +499,16 @@ public class ForecastTableFunction implements TableFunction { int columnSize = properColumnBuilders.size(); // time column - long startTime = inputRecords.getFirst().getLong(0); - long endTime = inputRecords.getLast().getLong(0); - long interval = (endTime - startTime) / inputRecords.size(); + long inputStartTime = inputRecords.getFirst().getLong(0); + long inputEndTime = inputRecords.getLast().getLong(0); + long interval = + outputInterval <= 0 + ? (inputEndTime - inputStartTime) / inputRecords.size() + : outputInterval; + long outputTime = + (outputStartTime == Long.MIN_VALUE) ? (inputEndTime + interval) : outputStartTime; for (int i = 0; i < outputLength; i++) { - properColumnBuilders.get(0).writeLong(endTime + interval * (i + 1)); + properColumnBuilders.get(0).writeLong(outputTime + interval * i); } // predicated columns
