This is an automated email from the ASF dual-hosted git repository.
jackietien pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new f16252d8105 [AINode] Support output time column for model inference
f16252d8105 is described below
commit f16252d8105cab29bbe28b5d8ddf3d3afda6d85b
Author: YangCaiyin <[email protected]>
AuthorDate: Thu Mar 13 09:19:45 2025 +0800
[AINode] Support output time column for model inference
---
.../org/apache/iotdb/ainode/it/AINodeBasicIT.java | 15 ++++----
.../operator/process/ai/InferenceOperator.java | 42 +++++++++++++++++++++-
.../queryengine/plan/analyze/AnalyzeVisitor.java | 3 +-
.../db/queryengine/plan/parser/ASTVisitor.java | 3 ++
.../plan/planner/LogicalPlanBuilder.java | 1 +
.../plan/planner/LogicalPlanVisitor.java | 2 +-
.../plan/planner/OperatorTreeGenerator.java | 1 +
.../plan/node/process/AI/InferenceNode.java | 18 ++++++++--
.../plan/statement/crud/QueryStatement.java | 9 +++++
9 files changed, 82 insertions(+), 12 deletions(-)
diff --git
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java
index 3b54b5c4409..07d29c0d224 100644
---
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java
+++
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java
@@ -178,21 +178,22 @@ public class AINodeBasicIT {
@Test
public void callInferenceTest() {
- String sql = "CALL INFERENCE(identity, \"select s0,s1,s2 from
root.AI.data\")";
+ String sql =
+ "CALL INFERENCE(identity, \"select s0,s1,s2 from root.AI.data\",
generateTime=true)";
String sql2 = "CALL INFERENCE(identity, \"select s2,s0,s1 from
root.AI.data\")";
String sql3 =
- "CALL INFERENCE(_NaiveForecaster, \"select s0 from root.AI.data\",
predict_length=3)";
+ "CALL INFERENCE(_NaiveForecaster, \"select s0 from root.AI.data\",
predict_length=3, generateTime=true)";
try (Connection connection = EnvFactory.getEnv().getConnection();
Statement statement = connection.createStatement()) {
try (ResultSet resultSet = statement.executeQuery(sql)) {
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- checkHeader(resultSetMetaData, "output0,output1,output2");
+ checkHeader(resultSetMetaData, "Time,output0,output1,output2");
+ // generateTime=true puts a Time column first, so the model outputs start at column 2.
int count = 0;
while (resultSet.next()) {
- float s0 = resultSet.getFloat(1);
- float s1 = resultSet.getFloat(2);
- float s2 = resultSet.getFloat(3);
+ float s0 = resultSet.getFloat(2);
+ float s1 = resultSet.getFloat(3);
+ float s2 = resultSet.getFloat(4);
assertEquals(s0, count + 1.0, 0.0001);
assertEquals(s1, count + 2.0, 0.0001);
@@ -221,7 +222,7 @@ public class AINodeBasicIT {
try (ResultSet resultSet = statement.executeQuery(sql3)) {
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- checkHeader(resultSetMetaData, "output0,output1,output2");
+ checkHeader(resultSetMetaData, "Time,output0,output1,output2");
int count = 0;
while (resultSet.next()) {
count++;
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
index 08cae2dc5b1..9bdb57dc5bd 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
@@ -36,6 +36,7 @@ import org.apache.iotdb.rpc.TSStatusCode;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.block.column.ColumnBuilder;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.read.common.block.TsBlock;
@@ -81,6 +82,16 @@ public class InferenceOperator implements ProcessOperator {
private final TsBlockSerde serde = new TsBlockSerde();
private InferenceWindowType windowType = null;
+ // Whether result blocks get their time column overwritten with generated timestamps.
+ private final boolean generateTimeColumn;
+ // Range of the input timestamps; maxTimestamp is the base of the generated timestamps.
+ private long maxTimestamp;
+ private long minTimestamp;
+ // Estimated spacing of the input timestamps, computed in submitInferenceTask().
+ private long interval;
+ // Number of output rows already given a generated timestamp, across all result blocks.
+ private long currentRowIndex;
+
public InferenceOperator(
OperatorContext operatorContext,
Operator child,
@@ -88,6 +99,7 @@ public class InferenceOperator implements ProcessOperator {
ExecutorService modelInferenceExecutor,
List<String> targetColumnNames,
List<String> inputColumnNames,
+ boolean generateTimeColumn,
long maxRetainedSize,
long maxReturnSize) {
this.operatorContext = operatorContext;
@@ -106,6 +118,14 @@ public class InferenceOperator implements ProcessOperator {
if (modelInferenceDescriptor.getInferenceWindowParameter() != null) {
windowType =
modelInferenceDescriptor.getInferenceWindowParameter().getWindowType();
}
+
+ if (generateTimeColumn) {
+ this.interval = 0;
+ this.minTimestamp = Long.MAX_VALUE;
+ this.maxTimestamp = Long.MIN_VALUE;
+ this.currentRowIndex = 0;
+ }
+ this.generateTimeColumn = generateTimeColumn;
}
@Override
@@ -140,6 +160,23 @@ public class InferenceOperator implements ProcessOperator {
return !finished || (results != null && results.size() != resultIndex);
}
+ /**
+ * Overwrites the time column of {@code tsBlock} with timestamps extrapolated past the last
+ * input timestamp. The k-th generated row (0-based, counted across all result blocks) gets
+ * {@code maxTimestamp + interval * (k + 1)}, so the first output row does not repeat the
+ * timestamp of the last input row.
+ */
+ private void fillTimeColumn(TsBlock tsBlock) {
+ Column timeColumn = tsBlock.getTimeColumn();
+ long[] time = timeColumn.getLongs();
+ // Bound by the block's row count, not by the length of the array returned by getLongs().
+ int rowCount = tsBlock.getPositionCount();
+ for (int i = 0; i < rowCount; i++) {
+ time[i] = maxTimestamp + interval * (currentRowIndex + 1);
+ currentRowIndex++;
+ }
+ }
+
@Override
public TsBlock next() throws Exception {
if (inferenceExecutionFuture == null) {
@@ -156,6 +193,9 @@ public class InferenceOperator implements ProcessOperator {
if (results != null && resultIndex != results.size()) {
TsBlock tsBlock = serde.deserialize(results.get(resultIndex));
+ if (generateTimeColumn) {
+ fillTimeColumn(tsBlock);
+ }
resultIndex++;
return tsBlock;
}
@@ -177,6 +217,9 @@ public class InferenceOperator implements ProcessOperator {
finished = true;
TsBlock resultTsBlock =
serde.deserialize(inferenceResp.inferenceResult.get(0));
+ if (generateTimeColumn) {
+ fillTimeColumn(resultTsBlock);
+ }
results = inferenceResp.inferenceResult;
resultIndex++;
return resultTsBlock;
@@ -194,7 +237,12 @@ public class InferenceOperator implements ProcessOperator {
ColumnBuilder[] columnBuilders =
inputTsBlockBuilder.getValueColumnBuilders();
totalRow += inputTsBlock.getPositionCount();
for (int i = 0; i < inputTsBlock.getPositionCount(); i++) {
- timeColumnBuilder.writeLong(inputTsBlock.getTimeByIndex(i));
+ long timestamp = inputTsBlock.getTimeByIndex(i);
+ if (generateTimeColumn) {
+ minTimestamp = Math.min(minTimestamp, timestamp);
+ maxTimestamp = Math.max(maxTimestamp, timestamp);
+ }
+ timeColumnBuilder.writeLong(timestamp);
for (int columnIndex = 0; columnIndex <
inputTsBlock.getValueColumnCount(); columnIndex++) {
columnBuilders[columnIndex].write(inputTsBlock.getColumn(columnIndex),
i);
}
@@ -259,6 +307,13 @@ public class InferenceOperator implements ProcessOperator {
private void submitInferenceTask() {
+ if (generateTimeColumn) {
+ // n input timestamps span n - 1 gaps; dividing by n shrinks the step and truncates it to 0
+ // for short spans. With fewer than two rows there is no spacing to estimate (min/max may
+ // still hold their Long.MAX_VALUE/Long.MIN_VALUE sentinels), so fall back to 0.
+ interval = totalRow > 1 ? (maxTimestamp - minTimestamp) / (totalRow - 1) : 0;
+ }
+
TsBlock inputTsBlock = inputTsBlockBuilder.build();
TsBlock finalInputTsBlock = preProcess(inputTsBlock);
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java
index 102cbd9732c..e0fa5c1c004 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java
@@ -1695,7 +1695,8 @@ public class AnalyzeVisitor extends
StatementVisitor<Analysis, MPPQueryContext>
.getModelInferenceDescriptor()
.setOutputColumnNames(
columnHeaders.stream().map(ColumnHeader::getColumnName).collect(Collectors.toList()));
- analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders, true));
+ // Hide the Time column from the response unless the statement asked for generated timestamps.
+ boolean isIgnoreTimestamp = !queryStatement.isGenerateTime();
+ analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders,
isIgnoreTimestamp));
return;
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java
index 1ca5a693050..1ae9fc34961 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java
@@ -4527,6 +4527,16 @@ public class ASTVisitor extends
IoTDBSqlParserBaseVisitor<Statement> {
"Window Function(e.g. HEAD, TAIL, COUNT) should be set in
value when key is 'WINDOW' in CALL INFERENCE");
}
parseWindowFunctionInInference(valueContext.windowFunction(),
statement);
+ } else if (paramKey.equalsIgnoreCase("GENERATETIME")) {
+ // Accept only true/false (case-insensitive): Boolean.parseBoolean alone would silently
+ // turn a typo such as "ture" into false and drop the Time column without any error.
+ String generateTimeValue = parseAttributeValue(valueContext.attributeValue());
+ if (!"true".equalsIgnoreCase(generateTimeValue)
+ && !"false".equalsIgnoreCase(generateTimeValue)) {
+ throw new SemanticException(
+ "Value of 'GENERATETIME' in CALL INFERENCE should be true or false, but got: " + generateTimeValue);
+ }
+ statement.setGenerateTime(Boolean.parseBoolean(generateTimeValue));
} else {
statement.addInferenceAttribute(
paramKey, parseAttributeValue(valueContext.attributeValue()));
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java
index c68c0ef4d74..fe71feb4639 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java
@@ -1383,6 +1383,7 @@ public class LogicalPlanBuilder {
context.getQueryId().genPlanNodeId(),
root,
analysis.getModelInferenceDescriptor(),
+ // generateTimeColumn: true unless the response header ignores the timestamp column.
+ !analysis.getRespDatasetHeader().isIgnoreTimestamp(),
analysis.getOutputExpressions().stream()
.map(expressionStringPair ->
expressionStringPair.left.getExpressionString())
.collect(Collectors.toList()));
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanVisitor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanVisitor.java
index 376f582eedd..1a8a3cf487b 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanVisitor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanVisitor.java
@@ -234,7 +234,7 @@ public class LogicalPlanVisitor extends
StatementVisitor<PlanNode, MPPQueryConte
}
if (queryStatement.hasModelInference()) {
- planBuilder.planInference(analysis);
+ // Keep the builder returned by planInference so later stages (e.g. select into) build on it;
+ // presumably its root is the new InferenceNode — confirm against planInference.
+ planBuilder = planBuilder.planInference(analysis);
}
// plan select into
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java
index 75328790154..59f717e2860 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java
@@ -2310,6 +2310,7 @@ public class OperatorTreeGenerator extends
PlanVisitor<Operator, LocalExecutionP
FragmentInstanceManager.getInstance().getModelInferenceExecutor(),
node.getInputColumnNames(),
node.getChild().getOutputColumnNames(),
+ // Forward the plan node's generateTimeColumn flag to the InferenceOperator.
+ node.isGenerateTimeColumn(),
maxRetainedSize,
maxReturnSize);
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java
index 95fe3437e78..09205c9eb56 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java
@@ -40,24 +40,31 @@ public class InferenceNode extends SingleChildProcessNode {
// the column order in select item which reflects the real input order
private final List<String> targetColumnNames;
+ // Whether the operator built from this node should output a generated time column.
+ // NOTE(review): if equals/hashCode/toString are overridden in this class, include this flag.
+ private final boolean generateTimeColumn;
public InferenceNode(
PlanNodeId id,
PlanNode child,
ModelInferenceDescriptor modelInferenceDescriptor,
+ boolean generateTimeColumn,
List<String> targetColumnNames) {
super(id, child);
this.modelInferenceDescriptor = modelInferenceDescriptor;
this.targetColumnNames = targetColumnNames;
+ this.generateTimeColumn = generateTimeColumn;
}
public InferenceNode(
PlanNodeId id,
ModelInferenceDescriptor modelInferenceDescriptor,
+ boolean generateTimeColumn,
List<String> inputColumnNames) {
super(id);
this.modelInferenceDescriptor = modelInferenceDescriptor;
this.targetColumnNames = inputColumnNames;
+ this.generateTimeColumn = generateTimeColumn;
}
public ModelInferenceDescriptor getModelInferenceDescriptor() {
@@ -68,6 +75,11 @@ public class InferenceNode extends SingleChildProcessNode {
return targetColumnNames;
}
+ /** Returns whether the operator built from this node should output a generated time column. */
+ public boolean isGenerateTimeColumn() {
+ return generateTimeColumn;
+ }
+
@Override
public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
return visitor.visitInference(this, context);
@@ -75,7 +87,8 @@ public class InferenceNode extends SingleChildProcessNode {
@Override
public PlanNode clone() {
- return new InferenceNode(getPlanNodeId(), child, modelInferenceDescriptor,
targetColumnNames);
+ return new InferenceNode(
+ getPlanNodeId(), child, modelInferenceDescriptor, generateTimeColumn,
targetColumnNames);
}
@Override
@@ -87,6 +100,8 @@ public class InferenceNode extends SingleChildProcessNode {
protected void serializeAttributes(ByteBuffer byteBuffer) {
PlanNodeType.INFERENCE.serialize(byteBuffer);
modelInferenceDescriptor.serialize(byteBuffer);
+ // New wire layout: the flag sits between the descriptor and the column names; deserialize() reads it there.
+ ReadWriteIOUtils.write(generateTimeColumn, byteBuffer);
ReadWriteIOUtils.writeStringList(targetColumnNames, byteBuffer);
}
@@ -94,15 +109,18 @@ public class InferenceNode extends SingleChildProcessNode {
protected void serializeAttributes(DataOutputStream stream) throws
IOException {
PlanNodeType.INFERENCE.serialize(stream);
modelInferenceDescriptor.serialize(stream);
+ ReadWriteIOUtils.write(generateTimeColumn, stream);
ReadWriteIOUtils.writeStringList(targetColumnNames, stream);
}
public static InferenceNode deserialize(ByteBuffer buffer) {
ModelInferenceDescriptor modelInferenceDescriptor =
ModelInferenceDescriptor.deserialize(buffer);
+ boolean generateTimeColumn = ReadWriteIOUtils.readBool(buffer);
List<String> inputColumnNames = ReadWriteIOUtils.readStringList(buffer);
PlanNodeId planNodeId = PlanNodeId.deserialize(buffer);
- return new InferenceNode(planNodeId, modelInferenceDescriptor,
inputColumnNames);
+ return new InferenceNode(
+ planNodeId, modelInferenceDescriptor, generateTimeColumn,
inputColumnNames);
}
@Override
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/QueryStatement.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/QueryStatement.java
index cf408552f74..1993def2326 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/QueryStatement.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/QueryStatement.java
@@ -139,9 +139,18 @@ public class QueryStatement extends
AuthorityInformationStatement {
// [IoTDB-AI] used for model inference, which will be removed in the future
private String modelName;
private boolean hasModelInference = false;
+ // Whether CALL INFERENCE should also return a generated Time column (GENERATETIME attribute).
+ private boolean generateTime = false;
private InferenceWindow inferenceWindow = null;
private Map<String, String> inferenceAttribute = null;
+ /** Sets whether the inference result should include a generated Time column. */
+ public void setGenerateTime(boolean generateTime) {
+ this.generateTime = generateTime;
+ }
+
+ /** Returns whether the inference result should include a generated Time column. */
+ public boolean isGenerateTime() {
+ return generateTime;
+ }
+
public void setModelName(String modelName) {
this.modelName = modelName;
}