This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new ccba945c67e [AINode] Simplify the CREATE MODEL SQL for model training
(#15840)
ccba945c67e is described below
commit ccba945c67ee853d7740578e7a4db3f7f3d4ef79
Author: Yongzao <[email protected]>
AuthorDate: Fri Jun 27 23:29:55 2025 +0800
[AINode] Simplify the CREATE MODEL SQL for model training (#15840)
---
.../org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4 | 2 +-
.../iotdb/confignode/manager/ConfigManager.java | 1 -
.../plan/execution/config/TableConfigTaskVisitor.java | 6 +-----
.../plan/execution/config/TreeConfigTaskVisitor.java | 1 -
.../config/executor/ClusterConfigTaskExecutor.java | 4 +---
.../config/executor/IConfigTaskExecutor.java | 1 -
.../config/metadata/ai/CreateTrainingTask.java | 19 ++-----------------
.../iotdb/db/queryengine/plan/parser/ASTVisitor.java | 4 +---
.../plan/relational/sql/ast/CreateTraining.java | 14 ++------------
.../plan/relational/sql/parser/AstBuilder.java | 3 +--
.../metadata/model/CreateTrainingStatement.java | 14 ++------------
.../iotdb/db/relational/grammar/sql/RelationalSql.g4 | 2 +-
.../thrift-ainode/src/main/thrift/ainode.thrift | 3 +--
.../src/main/thrift/confignode.thrift | 7 +++----
14 files changed, 16 insertions(+), 65 deletions(-)
diff --git
a/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4
b/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4
index 6a0c97ec321..01c7de75ede 100644
---
a/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4
+++
b/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4
@@ -699,7 +699,7 @@ dropSubscription
// ---- Create Model
createModel
: CREATE MODEL modelName=identifier uriClause
- | CREATE MODEL modelType=identifier modelId=identifier (WITH
HYPERPARAMETERS LR_BRACKET hparamPair (COMMA hparamPair)* RR_BRACKET)? (FROM
MODEL existingModelId=identifier)? ON DATASET LR_BRACKET trainingData RR_BRACKET
+ | CREATE MODEL modelId=identifier (WITH HYPERPARAMETERS LR_BRACKET
hparamPair (COMMA hparamPair)* RR_BRACKET)? FROM MODEL
existingModelId=identifier ON DATASET LR_BRACKET trainingData RR_BRACKET
;
trainingData
diff --git
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
index 82fd508c5c1..a95b945a73c 100644
---
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
+++
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
@@ -2660,7 +2660,6 @@ public class ConfigManager implements IManager {
TTrainingReq trainingReq = new TTrainingReq();
trainingReq.setModelId(req.getModelId());
- trainingReq.setModelType(req.getModelType());
if (req.isSetExistingModelId()) {
trainingReq.setExistingModelId(req.getExistingModelId());
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
index 10b75c424cc..b3b10606d3e 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
@@ -1359,11 +1359,7 @@ public class TableConfigTaskVisitor extends
AstVisitor<IConfigTask, MPPQueryCont
context.setQueryType(QueryType.WRITE);
return new CreateTrainingTask(
- node.getModelId(),
- node.getModelType(),
- node.getParameters(),
- node.getExistingModelId(),
- node.getTargetSql());
+ node.getModelId(), node.getParameters(), node.getExistingModelId(),
node.getTargetSql());
}
@Override
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java
index 5b23baa41b3..6bd1d0392d5 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java
@@ -808,7 +808,6 @@ public class TreeConfigTaskVisitor extends
StatementVisitor<IConfigTask, MPPQuer
}
return new CreateTrainingTask(
createTrainingStatement.getModelId(),
- createTrainingStatement.getModelType(),
createTrainingStatement.getParameters(),
createTrainingStatement.getTargetTimeRanges(),
createTrainingStatement.getExistingModelId(),
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
index 6d57f1ac12b..a8bdf942703 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
@@ -3331,7 +3331,6 @@ public class ClusterConfigTaskExecutor implements
IConfigTaskExecutor {
@Override
public SettableFuture<ConfigTaskResult> createTraining(
String modelId,
- String modelType,
boolean isTableModel,
Map<String, String> parameters,
List<List<Long>> timeRanges,
@@ -3341,7 +3340,7 @@ public class ClusterConfigTaskExecutor implements
IConfigTaskExecutor {
final SettableFuture<ConfigTaskResult> future = SettableFuture.create();
try (final ConfigNodeClient client =
CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) {
- final TCreateTrainingReq req = new TCreateTrainingReq(modelId,
modelType, isTableModel);
+ final TCreateTrainingReq req = new TCreateTrainingReq(modelId,
isTableModel, existingModelId);
if (isTableModel) {
TDataSchemaForTable dataSchemaForTable = new TDataSchemaForTable();
@@ -3354,7 +3353,6 @@ public class ClusterConfigTaskExecutor implements
IConfigTaskExecutor {
}
req.setParameters(parameters);
req.setTimeRanges(timeRanges);
- req.setExistingModelId(existingModelId);
final TSStatus executionStatus = client.createTraining(req);
if (TSStatusCode.SUCCESS_STATUS.getStatusCode() !=
executionStatus.getCode()) {
future.setException(new IoTDBException(executionStatus.message,
executionStatus.code));
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java
index cb49b444a52..11df7b6f689 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java
@@ -424,7 +424,6 @@ public interface IConfigTaskExecutor {
SettableFuture<ConfigTaskResult> createTraining(
String modelId,
- String modelType,
boolean isTableModel,
Map<String, String> parameters,
List<List<Long>> timeRanges,
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
index 821c01e27a4..9c93c5b7577 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
@@ -31,7 +31,6 @@ import java.util.Map;
public class CreateTrainingTask implements IConfigTask {
private final String modelId;
- private final String modelType;
private final boolean isTableModel;
private final Map<String, String> parameters;
@@ -45,13 +44,8 @@ public class CreateTrainingTask implements IConfigTask {
// For table model
public CreateTrainingTask(
- String modelId,
- String modelType,
- Map<String, String> parameters,
- String existingModelId,
- String targetSql) {
+ String modelId, Map<String, String> parameters, String existingModelId,
String targetSql) {
this.modelId = modelId;
- this.modelType = modelType;
this.parameters = parameters;
this.existingModelId = existingModelId;
this.targetSql = targetSql;
@@ -61,13 +55,11 @@ public class CreateTrainingTask implements IConfigTask {
// For tree model
public CreateTrainingTask(
String modelId,
- String modelType,
Map<String, String> parameters,
List<List<Long>> timeRanges,
String existingModelId,
List<String> targetPaths) {
this.modelId = modelId;
- this.modelType = modelType;
this.parameters = parameters;
this.timeRanges = timeRanges;
this.existingModelId = existingModelId;
@@ -80,13 +72,6 @@ public class CreateTrainingTask implements IConfigTask {
public ListenableFuture<ConfigTaskResult> execute(IConfigTaskExecutor
configTaskExecutor)
throws InterruptedException {
return configTaskExecutor.createTraining(
- modelId,
- modelType,
- isTableModel,
- parameters,
- timeRanges,
- existingModelId,
- targetSql,
- targetPaths);
+ modelId, isTableModel, parameters, timeRanges, existingModelId,
targetSql, targetPaths);
}
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java
index f838dadf283..ad84c41b262 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java
@@ -1368,9 +1368,7 @@ public class ASTVisitor extends
IoTDBSqlParserBaseVisitor<Statement> {
public Statement visitCreateModel(IoTDBSqlParser.CreateModelContext ctx) {
if (ctx.modelName == null) {
String modelId = ctx.modelId.getText();
- String modelType = ctx.modelType.getText();
- CreateTrainingStatement createTrainingStatement =
- new CreateTrainingStatement(modelId, modelType);
+ CreateTrainingStatement createTrainingStatement = new
CreateTrainingStatement(modelId);
if (ctx.hparamPair() != null) {
Map<String, String> parameterList = new HashMap<>();
for (IoTDBSqlParser.HparamPairContext hparamPairContext :
ctx.hparamPair()) {
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/CreateTraining.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/CreateTraining.java
index 1e621b7352e..eb8f63df9fb 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/CreateTraining.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/CreateTraining.java
@@ -26,16 +26,14 @@ import java.util.Objects;
public class CreateTraining extends Statement {
private final String modelId;
- private final String modelType;
private final String targetSql;
private Map<String, String> parameters;
private String existingModelId = null;
- public CreateTraining(String modelId, String modelType, String targetSql) {
+ public CreateTraining(String modelId, String targetSql) {
super(null);
this.modelId = modelId;
- this.modelType = modelType;
this.targetSql = targetSql;
}
@@ -56,10 +54,6 @@ public class CreateTraining extends Statement {
return modelId;
}
- public String getModelType() {
- return modelType;
- }
-
public Map<String, String> getParameters() {
return parameters;
}
@@ -79,7 +73,7 @@ public class CreateTraining extends Statement {
@Override
public int hashCode() {
- return Objects.hash(modelId, modelType, targetSql, existingModelId,
parameters);
+ return Objects.hash(modelId, targetSql, existingModelId, parameters);
}
@Override
@@ -89,7 +83,6 @@ public class CreateTraining extends Statement {
}
CreateTraining createTraining = (CreateTraining) obj;
return modelId.equals(createTraining.modelId)
- && modelType.equals(createTraining.modelType)
&& Objects.equals(existingModelId, createTraining.existingModelId)
&& Objects.equals(parameters, createTraining.parameters)
&& Objects.equals(targetSql, createTraining.targetSql);
@@ -101,9 +94,6 @@ public class CreateTraining extends Statement {
+ "modelId='"
+ modelId
+ '\''
- + ", modelType='"
- + modelType
- + '\''
+ ", parameters="
+ parameters
+ ", existingModelId='"
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java
index 85085baea5e..26fd39070fd 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java
@@ -3558,13 +3558,12 @@ public class AstBuilder extends
RelationalSqlBaseVisitor<Node> {
public Node
visitCreateModelStatement(RelationalSqlParser.CreateModelStatementContext ctx) {
String modelId = ctx.modelId.getText();
validateModelName(modelId);
- String modelType = ctx.modelType.getText();
if (ctx.targetData == null) {
throw new SemanticException("Target data in sql should be set in CREATE
MODEL");
}
String targetData = ((StringLiteral) visit(ctx.targetData)).getValue();
- CreateTraining createTraining = new CreateTraining(modelId, modelType,
targetData);
+ CreateTraining createTraining = new CreateTraining(modelId, targetData);
if (ctx.HYPERPARAMETERS() != null) {
Map<String, String> parameters = new HashMap<>();
for (RelationalSqlParser.HparamPairContext hparamPairContext :
ctx.hparamPair()) {
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/metadata/model/CreateTrainingStatement.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/metadata/model/CreateTrainingStatement.java
index 6f1dd4735d9..0628c22ae2b 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/metadata/model/CreateTrainingStatement.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/metadata/model/CreateTrainingStatement.java
@@ -32,7 +32,6 @@ import java.util.Objects;
public class CreateTrainingStatement extends Statement implements
IConfigStatement {
private final String modelId;
- private final String modelType;
private Map<String, String> parameters;
private String existingModelId = null;
@@ -40,9 +39,8 @@ public class CreateTrainingStatement extends Statement
implements IConfigStateme
private List<PartialPath> targetPathPatterns;
private List<List<Long>> targetTimeRanges;
- public CreateTrainingStatement(String modelId, String modelType) {
+ public CreateTrainingStatement(String modelId) {
this.modelId = modelId;
- this.modelType = modelType;
}
public void setTargetPathPatterns(List<PartialPath> targetPathPatterns) {
@@ -65,10 +63,6 @@ public class CreateTrainingStatement extends Statement
implements IConfigStateme
return modelId;
}
- public String getModelType() {
- return modelType;
- }
-
public void setExistingModelId(String existingModelId) {
this.existingModelId = existingModelId;
}
@@ -87,7 +81,7 @@ public class CreateTrainingStatement extends Statement
implements IConfigStateme
@Override
public int hashCode() {
- return Objects.hash(super.hashCode(), modelId, modelType, existingModelId,
parameters);
+ return Objects.hash(super.hashCode(), modelId, existingModelId,
parameters);
}
@Override
@@ -97,7 +91,6 @@ public class CreateTrainingStatement extends Statement
implements IConfigStateme
}
CreateTrainingStatement target = (CreateTrainingStatement) obj;
return modelId.equals(target.modelId)
- && modelType.equals(target.modelType)
&& Objects.equals(existingModelId, target.existingModelId)
&& Objects.equals(parameters, target.parameters);
}
@@ -108,9 +101,6 @@ public class CreateTrainingStatement extends Statement
implements IConfigStateme
+ "modelId='"
+ modelId
+ '\''
- + ", modelType='"
- + modelType
- + '\''
+ ", parameters="
+ parameters
+ ", existingModelId='"
diff --git
a/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4
b/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4
index 591592cfdbb..91bb547c27f 100644
---
a/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4
+++
b/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4
@@ -782,7 +782,7 @@ revokeGrantOpt
// ------------------------------------------- AI
---------------------------------------------------------
createModelStatement
- : CREATE MODEL modelType=identifier modelId=identifier (WITH
HYPERPARAMETERS '(' hparamPair (',' hparamPair)* ')')? (FROM MODEL
existingModelId=identifier)? ON DATASET '(' targetData=string ')'
+ : CREATE MODEL modelId=identifier (WITH HYPERPARAMETERS '(' hparamPair
(',' hparamPair)* ')')? FROM MODEL existingModelId=identifier ON DATASET '('
targetData=string ')'
;
hparamPair
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index 52e665a532e..df6cf5daca3 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -80,10 +80,9 @@ struct IDataSchema {
struct TTrainingReq {
1: required string dbType
2: required string modelId
- 3: required string modelType
+ 3: required string existingModelId
4: optional list<IDataSchema> targetDataSchema;
5: optional map<string, string> parameters;
- 6: optional string existingModelId
}
struct TForecastReq {
diff --git a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
index 59effaccae8..3ce020c731b 100644
--- a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
+++ b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
@@ -1098,13 +1098,12 @@ struct TDataSchemaForTree{
struct TCreateTrainingReq {
1: required string modelId
- 2: required string modelType
- 3: required bool isTableModel
+ 2: required bool isTableModel
+ 3: required string existingModelId
4: optional TDataSchemaForTable dataSchemaForTable
5: optional TDataSchemaForTree dataSchemaForTree
6: optional map<string, string> parameters
- 7: optional string existingModelId
- 8: optional list<list<i64>> timeRanges
+ 7: optional list<list<i64>> timeRanges
}
// ====================================================