This is an automated email from the ASF dual-hosted git repository.
jackietien pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new ea6a2d01b19 Support built-in forecast table function for table model
ea6a2d01b19 is described below
commit ea6a2d01b199ade15fa84c4ad6beeb67de8fcfad
Author: Jackie Tien <[email protected]>
AuthorDate: Tue May 13 12:11:25 2025 +0800
Support built-in forecast table function for table model
---
.../ScalarParameterSpecification.java | 2 +-
.../function/TableBuiltinTableFunction.java | 9 +++--
.../function/tvf/ForecastTableFunction.java | 45 ++++++++++++++++++++--
3 files changed, 47 insertions(+), 9 deletions(-)
diff --git
a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/table/specification/ScalarParameterSpecification.java
b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/table/specification/ScalarParameterSpecification.java
index 7333bd32601..84cc932cee5 100644
---
a/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/table/specification/ScalarParameterSpecification.java
+++
b/iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/table/specification/ScalarParameterSpecification.java
@@ -63,7 +63,7 @@ public class ScalarParameterSpecification extends
ParameterSpecification {
private Type type;
private boolean required = true;
private Object defaultValue;
- private List<Function<Object, String>> checkers = new ArrayList<>();
+ private final List<Function<Object, String>> checkers = new ArrayList<>();
private Builder() {}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java
index fda10eba8db..4a07f9a0c7b 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java
@@ -25,6 +25,7 @@ import
org.apache.iotdb.commons.udf.builtin.relational.tvf.HOPTableFunction;
import
org.apache.iotdb.commons.udf.builtin.relational.tvf.SessionTableFunction;
import org.apache.iotdb.commons.udf.builtin.relational.tvf.TumbleTableFunction;
import
org.apache.iotdb.commons.udf.builtin.relational.tvf.VariationTableFunction;
+import
org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction;
import org.apache.iotdb.udf.api.relational.TableFunction;
import java.util.Arrays;
@@ -38,8 +39,8 @@ public enum TableBuiltinTableFunction {
CUMULATE("cumulate"),
SESSION("session"),
VARIATION("variation"),
- CAPACITY("capacity");
- // FORECAST("forecast");
+ CAPACITY("capacity"),
+ FORECAST("forecast");
private final String functionName;
@@ -79,8 +80,8 @@ public enum TableBuiltinTableFunction {
return new VariationTableFunction();
case "capacity":
return new CapacityTableFunction();
- // case "forecast":
- // return new ForecastTableFunction();
+ case "forecast":
+ return new ForecastTableFunction();
default:
throw new UnsupportedOperationException("Unsupported table function: "
+ functionName);
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
index 38b2c8d49aa..9086e6f5edd 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
@@ -78,6 +78,8 @@ public class ForecastTableFunction implements TableFunction {
String modelId;
int maxInputLength;
int outputLength;
+ long outputStartTime;
+ long outputInterval;
boolean keepInput;
Map<String, String> options;
List<Type> types;
@@ -90,6 +92,8 @@ public class ForecastTableFunction implements TableFunction {
String modelId,
Map<String, String> options,
int outputLength,
+ long outputStartTime,
+ long outputInterval,
TEndPoint targetAINode,
List<Type> types) {
this.keepInput = keepInput;
@@ -97,6 +101,8 @@ public class ForecastTableFunction implements TableFunction {
this.modelId = modelId;
this.options = options;
this.outputLength = outputLength;
+ this.outputStartTime = outputStartTime;
+ this.outputInterval = outputInterval;
this.targetAINode = targetAINode;
this.types = types;
}
@@ -110,6 +116,8 @@ public class ForecastTableFunction implements TableFunction
{
ReadWriteIOUtils.write(modelId, outputStream);
ReadWriteIOUtils.write(maxInputLength, outputStream);
ReadWriteIOUtils.write(outputLength, outputStream);
+ ReadWriteIOUtils.write(outputStartTime, outputStream);
+ ReadWriteIOUtils.write(outputInterval, outputStream);
ReadWriteIOUtils.write(keepInput, outputStream);
ReadWriteIOUtils.write(options, outputStream);
ReadWriteIOUtils.write(types.size(), outputStream);
@@ -134,6 +142,8 @@ public class ForecastTableFunction implements TableFunction
{
this.modelId = ReadWriteIOUtils.readString(buffer);
this.maxInputLength = ReadWriteIOUtils.readInt(buffer);
this.outputLength = ReadWriteIOUtils.readInt(buffer);
+ this.outputStartTime = ReadWriteIOUtils.readLong(buffer);
+ this.outputInterval = ReadWriteIOUtils.readLong(buffer);
this.keepInput = ReadWriteIOUtils.readBoolean(buffer);
this.options = ReadWriteIOUtils.readMap(buffer);
int size = ReadWriteIOUtils.readInt(buffer);
@@ -152,6 +162,10 @@ public class ForecastTableFunction implements
TableFunction {
private static final int DEFAULT_OUTPUT_LENGTH = 96;
private static final String PREDICATED_COLUMNS_PARAMETER_NAME =
"PREDICATED_COLUMNS";
private static final String DEFAULT_PREDICATED_COLUMNS = "";
+ private static final String OUTPUT_START_TIME = "OUTPUT_START_TIME";
+ private static final long DEFAULT_OUTPUT_START_TIME = Long.MIN_VALUE;
+ private static final String OUTPUT_INTERVAL = "OUTPUT_INTERVAL";
+ private static final long DEFAULT_OUTPUT_INTERVAL = 0L;
private static final String TIMECOL_PARAMETER_NAME = "TIMECOL";
private static final String DEFAULT_TIME_COL = "time";
private static final String KEEP_INPUT_PARAMETER_NAME = "KEEP_INPUT";
@@ -184,6 +198,16 @@ public class ForecastTableFunction implements
TableFunction {
.type(Type.INT32)
.defaultValue(DEFAULT_OUTPUT_LENGTH)
.build(),
+ ScalarParameterSpecification.builder()
+ .name(OUTPUT_START_TIME)
+ .type(Type.TIMESTAMP)
+ .defaultValue(DEFAULT_OUTPUT_START_TIME)
+ .build(),
+ ScalarParameterSpecification.builder()
+ .name(OUTPUT_INTERVAL)
+ .type(Type.INT64)
+ .defaultValue(DEFAULT_OUTPUT_INTERVAL)
+ .build(),
ScalarParameterSpecification.builder()
.name(PREDICATED_COLUMNS_PARAMETER_NAME)
.type(Type.STRING)
@@ -307,6 +331,8 @@ public class ForecastTableFunction implements TableFunction
{
properColumnSchemaBuilder.addField(IS_INPUT_COLUMN_NAME, Type.BOOLEAN);
}
+ long outputStartTime = (long) ((ScalarArgument)
arguments.get(OUTPUT_START_TIME)).getValue();
+ long outputInterval = (long) ((ScalarArgument)
arguments.get(OUTPUT_INTERVAL)).getValue();
String options = (String) ((ScalarArgument)
arguments.get(OPTIONS_PARAMETER_NAME)).getValue();
ForecastTableFunctionHandle functionHandle =
@@ -316,6 +342,8 @@ public class ForecastTableFunction implements TableFunction
{
modelId,
parseOptions(options),
outputLength,
+ outputStartTime,
+ outputInterval,
targetAINode,
predicatedColumnTypes);
@@ -389,6 +417,8 @@ public class ForecastTableFunction implements TableFunction
{
private final String modelId;
private final int maxInputLength;
private final int outputLength;
+ private final long outputStartTime;
+ private final long outputInterval;
private final boolean keepInput;
private final Map<String, String> options;
private final LinkedList<Record> inputRecords;
@@ -400,6 +430,8 @@ public class ForecastTableFunction implements TableFunction
{
this.modelId = functionHandle.modelId;
this.maxInputLength = functionHandle.maxInputLength;
this.outputLength = functionHandle.outputLength;
+ this.outputStartTime = functionHandle.outputStartTime;
+ this.outputInterval = functionHandle.outputInterval;
this.keepInput = functionHandle.keepInput;
this.options = functionHandle.options;
this.inputRecords = new LinkedList<>();
@@ -467,11 +499,16 @@ public class ForecastTableFunction implements
TableFunction {
int columnSize = properColumnBuilders.size();
// time column
- long startTime = inputRecords.getFirst().getLong(0);
- long endTime = inputRecords.getLast().getLong(0);
- long interval = (endTime - startTime) / inputRecords.size();
+ long inputStartTime = inputRecords.getFirst().getLong(0);
+ long inputEndTime = inputRecords.getLast().getLong(0);
+ long interval =
+ outputInterval <= 0
+ ? (inputEndTime - inputStartTime) / inputRecords.size()
+ : outputInterval;
+ long outputTime =
+ (outputStartTime == Long.MIN_VALUE) ? (inputEndTime + interval) :
outputStartTime;
for (int i = 0; i < outputLength; i++) {
- properColumnBuilders.get(0).writeLong(endTime + interval * (i + 1));
+ properColumnBuilders.get(0).writeLong(outputTime + interval * i);
}
// predicated columns