This is an automated email from the ASF dual-hosted git repository.

ycycse pushed a commit to branch ycy/ainodeTimeColumnAppend
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit a92f15125c0e3ea9b6877c4b48aebec186b8af21
Author: YangCaiyin <[email protected]>
AuthorDate: Thu Mar 6 17:49:08 2025 +0800

    Support outputting a time column for model inference
---
 .../operator/process/ai/InferenceOperator.java     | 42 +++++++++++++++++++++-
 .../queryengine/plan/analyze/AnalyzeVisitor.java   |  3 +-
 .../plan/planner/LogicalPlanBuilder.java           |  1 +
 .../plan/planner/LogicalPlanVisitor.java           |  2 +-
 .../plan/planner/OperatorTreeGenerator.java        |  1 +
 .../plan/node/process/AI/InferenceNode.java        | 18 ++++++++--
 6 files changed, 62 insertions(+), 5 deletions(-)

diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
index 08cae2dc5b1..03ba37def17 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
@@ -40,6 +40,7 @@ import org.apache.tsfile.block.column.ColumnBuilder;
 import org.apache.tsfile.enums.TSDataType;
 import org.apache.tsfile.read.common.block.TsBlock;
 import org.apache.tsfile.read.common.block.TsBlockBuilder;
+import org.apache.tsfile.read.common.block.column.TimeColumn;
 import org.apache.tsfile.read.common.block.column.TimeColumnBuilder;
 import org.apache.tsfile.read.common.block.column.TsBlockSerde;
 import org.apache.tsfile.utils.RamUsageEstimator;
@@ -81,6 +82,12 @@ public class InferenceOperator implements ProcessOperator {
   private final TsBlockSerde serde = new TsBlockSerde();
   private InferenceWindowType windowType = null;
 
+  private final boolean generateTimeColumn;
+  private long maxTimestamp;
+  private long minTimestamp;
+  private long interval;
+  private long currentRowIndex;
+
   public InferenceOperator(
       OperatorContext operatorContext,
       Operator child,
@@ -88,6 +95,7 @@ public class InferenceOperator implements ProcessOperator {
       ExecutorService modelInferenceExecutor,
       List<String> targetColumnNames,
       List<String> inputColumnNames,
+      boolean generateTimeColumn,
       long maxRetainedSize,
       long maxReturnSize) {
     this.operatorContext = operatorContext;
@@ -106,6 +114,14 @@ public class InferenceOperator implements ProcessOperator {
     if (modelInferenceDescriptor.getInferenceWindowParameter() != null) {
       windowType = 
modelInferenceDescriptor.getInferenceWindowParameter().getWindowType();
     }
+
+    if (generateTimeColumn) {
+      this.interval = 0;
+      this.minTimestamp = Long.MAX_VALUE;
+      this.maxTimestamp = Long.MIN_VALUE;
+      this.currentRowIndex = 0;
+    }
+    this.generateTimeColumn = generateTimeColumn;
   }
 
   @Override
@@ -140,6 +156,15 @@ public class InferenceOperator implements ProcessOperator {
     return !finished || (results != null && results.size() != resultIndex);
   }
 
+  private void fillTimeColumn(TsBlock tsBlock) {
+    TimeColumn timeColumn = (TimeColumn) tsBlock.getTimeColumn();
+    long[] time = timeColumn.getTimes();
+    for (int i = 0; i < time.length; i++) {
+      time[i] = maxTimestamp + interval * currentRowIndex;
+      currentRowIndex++;
+    }
+  }
+
   @Override
   public TsBlock next() throws Exception {
     if (inferenceExecutionFuture == null) {
@@ -156,6 +181,9 @@ public class InferenceOperator implements ProcessOperator {
 
       if (results != null && resultIndex != results.size()) {
         TsBlock tsBlock = serde.deserialize(results.get(resultIndex));
+        if (generateTimeColumn) {
+          fillTimeColumn(tsBlock);
+        }
         resultIndex++;
         return tsBlock;
       }
@@ -177,6 +205,9 @@ public class InferenceOperator implements ProcessOperator {
 
         finished = true;
         TsBlock resultTsBlock = 
serde.deserialize(inferenceResp.inferenceResult.get(0));
+        if (generateTimeColumn) {
+          fillTimeColumn(resultTsBlock);
+        }
         results = inferenceResp.inferenceResult;
         resultIndex++;
         return resultTsBlock;
@@ -194,7 +225,12 @@ public class InferenceOperator implements ProcessOperator {
     ColumnBuilder[] columnBuilders = 
inputTsBlockBuilder.getValueColumnBuilders();
     totalRow += inputTsBlock.getPositionCount();
     for (int i = 0; i < inputTsBlock.getPositionCount(); i++) {
-      timeColumnBuilder.writeLong(inputTsBlock.getTimeByIndex(i));
+      long timestamp = inputTsBlock.getTimeByIndex(i);
+      if (generateTimeColumn) {
+        minTimestamp = Math.min(minTimestamp, timestamp);
+        maxTimestamp = Math.max(maxTimestamp, timestamp);
+      }
+      timeColumnBuilder.writeLong(timestamp);
       for (int columnIndex = 0; columnIndex < 
inputTsBlock.getValueColumnCount(); columnIndex++) {
         columnBuilders[columnIndex].write(inputTsBlock.getColumn(columnIndex), 
i);
       }
@@ -259,6 +295,10 @@ public class InferenceOperator implements ProcessOperator {
 
   private void submitInferenceTask() {
 
+    if (generateTimeColumn) {
+      interval = (maxTimestamp - minTimestamp) / totalRow;
+    }
+
     TsBlock inputTsBlock = inputTsBlockBuilder.build();
 
     TsBlock finalInputTsBlock = preProcess(inputTsBlock);
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java
index 7f51f95c362..8be0db3747e 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java
@@ -1695,7 +1695,8 @@ public class AnalyzeVisitor extends StatementVisitor<Analysis, MPPQueryContext>
           .getModelInferenceDescriptor()
           .setOutputColumnNames(
               
columnHeaders.stream().map(ColumnHeader::getColumnName).collect(Collectors.toList()));
-      analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders, true));
+      boolean isIgnoreTimestamp = modelInformation.isBuiltIn();
+      analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders, 
isIgnoreTimestamp));
       return;
     }
 
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java
index c68c0ef4d74..fe71feb4639 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java
@@ -1383,6 +1383,7 @@ public class LogicalPlanBuilder {
             context.getQueryId().genPlanNodeId(),
             root,
             analysis.getModelInferenceDescriptor(),
+            !analysis.getRespDatasetHeader().isIgnoreTimestamp(),
             analysis.getOutputExpressions().stream()
                 .map(expressionStringPair -> 
expressionStringPair.left.getExpressionString())
                 .collect(Collectors.toList()));
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanVisitor.java
index 376f582eedd..1a8a3cf487b 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanVisitor.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanVisitor.java
@@ -234,7 +234,7 @@ public class LogicalPlanVisitor extends StatementVisitor<PlanNode, MPPQueryConte
     }
 
     if (queryStatement.hasModelInference()) {
-      planBuilder.planInference(analysis);
+      planBuilder = planBuilder.planInference(analysis);
     }
 
     // plan select into
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java
index 75328790154..59f717e2860 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java
@@ -2310,6 +2310,7 @@ public class OperatorTreeGenerator extends PlanVisitor<Operator, LocalExecutionP
         FragmentInstanceManager.getInstance().getModelInferenceExecutor(),
         node.getInputColumnNames(),
         node.getChild().getOutputColumnNames(),
+        node.isGenerateTimeColumn(),
         maxRetainedSize,
         maxReturnSize);
   }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java
index 95fe3437e78..09205c9eb56 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java
@@ -40,24 +40,29 @@ public class InferenceNode extends SingleChildProcessNode {
 
   // the column order in select item which reflects the real input order
   private final List<String> targetColumnNames;
+  private boolean generateTimeColumn = false;
 
   public InferenceNode(
       PlanNodeId id,
       PlanNode child,
       ModelInferenceDescriptor modelInferenceDescriptor,
+      boolean generateTimeColumn,
       List<String> targetColumnNames) {
     super(id, child);
     this.modelInferenceDescriptor = modelInferenceDescriptor;
     this.targetColumnNames = targetColumnNames;
+    this.generateTimeColumn = generateTimeColumn;
   }
 
   public InferenceNode(
       PlanNodeId id,
       ModelInferenceDescriptor modelInferenceDescriptor,
+      boolean generateTimeColumn,
       List<String> inputColumnNames) {
     super(id);
     this.modelInferenceDescriptor = modelInferenceDescriptor;
     this.targetColumnNames = inputColumnNames;
+    this.generateTimeColumn = generateTimeColumn;
   }
 
   public ModelInferenceDescriptor getModelInferenceDescriptor() {
@@ -68,6 +73,10 @@ public class InferenceNode extends SingleChildProcessNode {
     return targetColumnNames;
   }
 
+  public boolean isGenerateTimeColumn() {
+    return generateTimeColumn;
+  }
+
   @Override
   public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
     return visitor.visitInference(this, context);
@@ -75,7 +84,8 @@ public class InferenceNode extends SingleChildProcessNode {
 
   @Override
   public PlanNode clone() {
-    return new InferenceNode(getPlanNodeId(), child, modelInferenceDescriptor, 
targetColumnNames);
+    return new InferenceNode(
+        getPlanNodeId(), child, modelInferenceDescriptor, generateTimeColumn, 
targetColumnNames);
   }
 
   @Override
@@ -87,6 +97,7 @@ public class InferenceNode extends SingleChildProcessNode {
   protected void serializeAttributes(ByteBuffer byteBuffer) {
     PlanNodeType.INFERENCE.serialize(byteBuffer);
     modelInferenceDescriptor.serialize(byteBuffer);
+    ReadWriteIOUtils.write(generateTimeColumn, byteBuffer);
     ReadWriteIOUtils.writeStringList(targetColumnNames, byteBuffer);
   }
 
@@ -94,15 +105,18 @@ public class InferenceNode extends SingleChildProcessNode {
   protected void serializeAttributes(DataOutputStream stream) throws 
IOException {
     PlanNodeType.INFERENCE.serialize(stream);
     modelInferenceDescriptor.serialize(stream);
+    ReadWriteIOUtils.write(generateTimeColumn, stream);
     ReadWriteIOUtils.writeStringList(targetColumnNames, stream);
   }
 
   public static InferenceNode deserialize(ByteBuffer buffer) {
     ModelInferenceDescriptor modelInferenceDescriptor =
         ModelInferenceDescriptor.deserialize(buffer);
+    boolean generateTimeColumn = ReadWriteIOUtils.readBool(buffer);
     List<String> inputColumnNames = ReadWriteIOUtils.readStringList(buffer);
     PlanNodeId planNodeId = PlanNodeId.deserialize(buffer);
-    return new InferenceNode(planNodeId, modelInferenceDescriptor, 
inputColumnNames);
+    return new InferenceNode(
+        planNodeId, modelInferenceDescriptor, generateTimeColumn, 
inputColumnNames);
   }
 
   @Override

Reply via email to