This is an automated email from the ASF dual-hosted git repository. ycycse pushed a commit to branch ycy/ainodeTimeColumnAppend in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit a92f15125c0e3ea9b6877c4b48aebec186b8af21 Author: YangCaiyin <[email protected]> AuthorDate: Thu Mar 6 17:49:08 2025 +0800 support output time column for model inference --- .../operator/process/ai/InferenceOperator.java | 42 +++++++++++++++++++++- .../queryengine/plan/analyze/AnalyzeVisitor.java | 3 +- .../plan/planner/LogicalPlanBuilder.java | 1 + .../plan/planner/LogicalPlanVisitor.java | 2 +- .../plan/planner/OperatorTreeGenerator.java | 1 + .../plan/node/process/AI/InferenceNode.java | 18 ++++++++-- 6 files changed, 62 insertions(+), 5 deletions(-) diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java index 08cae2dc5b1..03ba37def17 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java @@ -40,6 +40,7 @@ import org.apache.tsfile.block.column.ColumnBuilder; import org.apache.tsfile.enums.TSDataType; import org.apache.tsfile.read.common.block.TsBlock; import org.apache.tsfile.read.common.block.TsBlockBuilder; +import org.apache.tsfile.read.common.block.column.TimeColumn; import org.apache.tsfile.read.common.block.column.TimeColumnBuilder; import org.apache.tsfile.read.common.block.column.TsBlockSerde; import org.apache.tsfile.utils.RamUsageEstimator; @@ -81,6 +82,12 @@ public class InferenceOperator implements ProcessOperator { private final TsBlockSerde serde = new TsBlockSerde(); private InferenceWindowType windowType = null; + private final boolean generateTimeColumn; + private long maxTimestamp; + private long minTimestamp; + private long interval; + private long currentRowIndex; + public InferenceOperator( OperatorContext operatorContext, Operator child, @@
-88,6 +95,7 @@ public class InferenceOperator implements ProcessOperator { ExecutorService modelInferenceExecutor, List<String> targetColumnNames, List<String> inputColumnNames, + boolean generateTimeColumn, long maxRetainedSize, long maxReturnSize) { this.operatorContext = operatorContext; @@ -106,6 +114,14 @@ public class InferenceOperator implements ProcessOperator { if (modelInferenceDescriptor.getInferenceWindowParameter() != null) { windowType = modelInferenceDescriptor.getInferenceWindowParameter().getWindowType(); } + + if (generateTimeColumn) { + this.interval = 0; + this.minTimestamp = Long.MAX_VALUE; + this.maxTimestamp = Long.MIN_VALUE; + this.currentRowIndex = 0; + } + this.generateTimeColumn = generateTimeColumn; } @Override @@ -140,6 +156,15 @@ public class InferenceOperator implements ProcessOperator { return !finished || (results != null && results.size() != resultIndex); } + private void fillTimeColumn(TsBlock tsBlock) { + TimeColumn timeColumn = (TimeColumn) tsBlock.getTimeColumn(); + long[] time = timeColumn.getTimes(); + for (int i = 0; i < time.length; i++) { + currentRowIndex++; + time[i] = maxTimestamp + interval * currentRowIndex; + } + } + @Override public TsBlock next() throws Exception { if (inferenceExecutionFuture == null) { @@ -156,6 +181,9 @@ public class InferenceOperator implements ProcessOperator { if (results != null && resultIndex != results.size()) { TsBlock tsBlock = serde.deserialize(results.get(resultIndex)); + if (generateTimeColumn) { + fillTimeColumn(tsBlock); + } resultIndex++; return tsBlock; } @@ -177,6 +205,9 @@ public class InferenceOperator implements ProcessOperator { finished = true; TsBlock resultTsBlock = serde.deserialize(inferenceResp.inferenceResult.get(0)); + if (generateTimeColumn) { + fillTimeColumn(resultTsBlock); + } results = inferenceResp.inferenceResult; resultIndex++; return resultTsBlock; @@ -194,7 +225,12 @@ public class InferenceOperator implements ProcessOperator { ColumnBuilder[] columnBuilders
= inputTsBlockBuilder.getValueColumnBuilders(); totalRow += inputTsBlock.getPositionCount(); for (int i = 0; i < inputTsBlock.getPositionCount(); i++) { - timeColumnBuilder.writeLong(inputTsBlock.getTimeByIndex(i)); + long timestamp = inputTsBlock.getTimeByIndex(i); + if (generateTimeColumn) { + minTimestamp = Math.min(minTimestamp, timestamp); + maxTimestamp = Math.max(maxTimestamp, timestamp); + } + timeColumnBuilder.writeLong(timestamp); for (int columnIndex = 0; columnIndex < inputTsBlock.getValueColumnCount(); columnIndex++) { columnBuilders[columnIndex].write(inputTsBlock.getColumn(columnIndex), i); } @@ -259,6 +295,10 @@ public class InferenceOperator implements ProcessOperator { private void submitInferenceTask() { + if (generateTimeColumn) { + interval = totalRow > 1 ? (maxTimestamp - minTimestamp) / (totalRow - 1) : 0; + } + TsBlock inputTsBlock = inputTsBlockBuilder.build(); TsBlock finalInputTsBlock = preProcess(inputTsBlock); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java index 7f51f95c362..8be0db3747e 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java @@ -1695,7 +1695,8 @@ public class AnalyzeVisitor extends StatementVisitor<Analysis, MPPQueryContext> .getModelInferenceDescriptor() .setOutputColumnNames( columnHeaders.stream().map(ColumnHeader::getColumnName).collect(Collectors.toList())); - analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders, true)); + boolean isIgnoreTimestamp = modelInformation.isBuiltIn(); + analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders, isIgnoreTimestamp)); return; } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java index c68c0ef4d74..fe71feb4639 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java @@ -1383,6 +1383,7 @@ public class LogicalPlanBuilder { context.getQueryId().genPlanNodeId(), root, analysis.getModelInferenceDescriptor(), + !analysis.getRespDatasetHeader().isIgnoreTimestamp(), analysis.getOutputExpressions().stream() .map(expressionStringPair -> expressionStringPair.left.getExpressionString()) .collect(Collectors.toList())); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanVisitor.java index 376f582eedd..1a8a3cf487b 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanVisitor.java @@ -234,7 +234,7 @@ public class LogicalPlanVisitor extends StatementVisitor<PlanNode, MPPQueryConte } if (queryStatement.hasModelInference()) { - planBuilder.planInference(analysis); + planBuilder = planBuilder.planInference(analysis); } // plan select into diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java index 75328790154..59f717e2860 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java @@ -2310,6 +2310,7 @@ public class OperatorTreeGenerator extends
PlanVisitor<Operator, LocalExecutionP FragmentInstanceManager.getInstance().getModelInferenceExecutor(), node.getInputColumnNames(), node.getChild().getOutputColumnNames(), + node.isGenerateTimeColumn(), maxRetainedSize, maxReturnSize); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java index 95fe3437e78..09205c9eb56 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java @@ -40,24 +40,29 @@ public class InferenceNode extends SingleChildProcessNode { // the column order in select item which reflects the real input order private final List<String> targetColumnNames; + private final boolean generateTimeColumn; public InferenceNode( PlanNodeId id, PlanNode child, ModelInferenceDescriptor modelInferenceDescriptor, + boolean generateTimeColumn, List<String> targetColumnNames) { super(id, child); this.modelInferenceDescriptor = modelInferenceDescriptor; this.targetColumnNames = targetColumnNames; + this.generateTimeColumn = generateTimeColumn; } public InferenceNode( PlanNodeId id, ModelInferenceDescriptor modelInferenceDescriptor, + boolean generateTimeColumn, List<String> inputColumnNames) { super(id); this.modelInferenceDescriptor = modelInferenceDescriptor; this.targetColumnNames = inputColumnNames; + this.generateTimeColumn = generateTimeColumn; } public ModelInferenceDescriptor getModelInferenceDescriptor() { @@ -68,6 +73,10 @@ public class InferenceNode extends SingleChildProcessNode { return targetColumnNames; } + public boolean isGenerateTimeColumn() { + return generateTimeColumn; + } + @Override public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
return visitor.visitInference(this, context); @@ -75,7 +84,8 @@ public class InferenceNode extends SingleChildProcessNode { @Override public PlanNode clone() { - return new InferenceNode(getPlanNodeId(), child, modelInferenceDescriptor, targetColumnNames); + return new InferenceNode( + getPlanNodeId(), child, modelInferenceDescriptor, generateTimeColumn, targetColumnNames); } @Override @@ -87,6 +97,7 @@ public class InferenceNode extends SingleChildProcessNode { protected void serializeAttributes(ByteBuffer byteBuffer) { PlanNodeType.INFERENCE.serialize(byteBuffer); modelInferenceDescriptor.serialize(byteBuffer); + ReadWriteIOUtils.write(generateTimeColumn, byteBuffer); ReadWriteIOUtils.writeStringList(targetColumnNames, byteBuffer); } @@ -94,15 +105,18 @@ public class InferenceNode extends SingleChildProcessNode { protected void serializeAttributes(DataOutputStream stream) throws IOException { PlanNodeType.INFERENCE.serialize(stream); modelInferenceDescriptor.serialize(stream); + ReadWriteIOUtils.write(generateTimeColumn, stream); ReadWriteIOUtils.writeStringList(targetColumnNames, stream); } public static InferenceNode deserialize(ByteBuffer buffer) { ModelInferenceDescriptor modelInferenceDescriptor = ModelInferenceDescriptor.deserialize(buffer); + boolean generateTimeColumn = ReadWriteIOUtils.readBool(buffer); List<String> inputColumnNames = ReadWriteIOUtils.readStringList(buffer); PlanNodeId planNodeId = PlanNodeId.deserialize(buffer); - return new InferenceNode(planNodeId, modelInferenceDescriptor, inputColumnNames); + return new InferenceNode( + planNodeId, modelInferenceDescriptor, generateTimeColumn, inputColumnNames); } @Override
