This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 5c4c297e63d [AINode] Modify dataset module for table-model training (#15816)
5c4c297e63d is described below
commit 5c4c297e63d794b50a946b9fc9a4e07acc6a49f6
Author: YangCaiyin <[email protected]>
AuthorDate: Thu Jun 26 20:26:46 2025 +0800
[AINode] Modify dataset module for table-model training (#15816)
---
iotdb-core/ainode/ainode/core/ingress/iotdb.py | 152 ++++++++++-----------
.../iotdb/confignode/manager/ConfigManager.java | 35 +----
.../execution/config/TableConfigTaskVisitor.java | 17 +--
.../execution/config/TreeConfigTaskVisitor.java | 1 -
.../config/executor/ClusterConfigTaskExecutor.java | 8 +-
.../config/executor/IConfigTaskExecutor.java | 4 +-
.../config/metadata/ai/CreateTrainingTask.java | 26 ++--
.../plan/relational/sql/ast/CreateTraining.java | 58 ++------
.../plan/relational/sql/parser/AstBuilder.java | 48 +------
.../db/relational/grammar/sql/RelationalSql.g4 | 24 +---
.../src/main/thrift/confignode.thrift | 10 +-
11 files changed, 112 insertions(+), 271 deletions(-)
diff --git a/iotdb-core/ainode/ainode/core/ingress/iotdb.py
b/iotdb-core/ainode/ainode/core/ingress/iotdb.py
index 175e2eb2f65..399036bca1d 100644
--- a/iotdb-core/ainode/ainode/core/ingress/iotdb.py
+++ b/iotdb-core/ainode/ainode/core/ingress/iotdb.py
@@ -216,32 +216,24 @@ class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
class IoTDBTableModelDataset(BasicDatabaseForecastDataset):
+ DEFAULT_TAG = "__DEFAULT_TAG__"
+
def __init__(
self,
- input_len: int,
- out_len: int,
+ model_id: str,
+ seq_len: int,
+ input_token_len: int,
+ output_token_len: int,
data_schema_list: list,
ip: str = "127.0.0.1",
port: int = 6667,
username: str = "root",
password: str = "root",
time_zone: str = "UTC+8",
- start_split: float = 0,
- end_split: float = 1,
+ use_rate: float = 1.0,
+ offset_rate: float = 0.0,
):
- super().__init__(ip, port, input_len, out_len)
- if end_split < start_split:
- raise ValueError("end_split must be greater than start_split")
-
- # database , table
- self.SELECT_SERIES_FORMAT_SQL = "select distinct item_id from %s"
- self.COUNT_SERIES_LENGTH_SQL = (
- "select count(value) from %s where item_id = '%s'"
- )
- self.FETCH_SERIES_SQL = (
- "select value from %s where item_id = '%s' offset %s limit %s"
- )
- self.SERIES_NAME = "%s.%s"
+ super().__init__(ip, port, seq_len, input_token_len, output_token_len)
table_session_config = TableSessionConfig(
node_urls=[f"{ip}:{port}"],
@@ -249,87 +241,95 @@ class
IoTDBTableModelDataset(BasicDatabaseForecastDataset):
password=password,
time_zone=time_zone,
)
-
self.session = TableSession(table_session_config)
- self.context_length = self.input_len + self.output_len
- self.token_num = self.context_length // self.input_len
- self._fetch_schema(data_schema_list)
+ self.use_rate = use_rate
+ self.offset_rate = offset_rate
- self.start_index = int(self.total_count * start_split)
- self.end_index = self.total_count * end_split
+ # used for caching data
+ self._fetch_schema(data_schema_list)
def _fetch_schema(self, data_schema_list: list):
- series_to_length = {}
- for data_schema in data_schema_list:
- series_list = []
- with self.session.execute_query_statement(
- self.SELECT_SERIES_FORMAT_SQL % data_schema
- ) as show_devices_result:
- while show_devices_result.has_next():
+ series_map = {}
+ for data_schema in data_schema_list:
+ target_sql = data_schema.schemaName
+ with self.session.execute_query_statement(target_sql) as target_data:
+ while target_data.has_next():
+ cur_data = target_data.next()
+ # TODO: currently, we only support the following simple table form
+ time_col, value_col, tag_col = -1, -1, -1
+ for i, field in enumerate(cur_data.get_fields()):
+ if field.get_data_type() == TSDataType.TIMESTAMP:
+ time_col = i
+ elif field.get_data_type() in (
+ TSDataType.INT32,
+ TSDataType.INT64,
+ TSDataType.FLOAT,
+ TSDataType.DOUBLE,
+ ):
+ value_col = i
+ elif field.get_data_type() == TSDataType.TEXT:
+ tag_col = i
+ if time_col == -1 or value_col == -1:
+ raise ValueError(
+ "The training cannot start due to invalid data schema"
+ )
+ if tag_col == -1:
+ tag = self.DEFAULT_TAG
+ else:
+ tag = cur_data.get_fields()[tag_col].get_string_value()
+ if tag not in series_map:
+ series_map[tag] = []
+ series_list = series_map[tag]
series_list.append(
- get_field_value(show_devices_result.next().get_fields()[0])
+ get_field_value(cur_data.get_fields()[value_col])
)
- for series in series_list:
- with self.session.execute_query_statement(
- self.COUNT_SERIES_LENGTH_SQL % (data_schema.schemaName, series)
- ) as count_series_result:
- length = get_field_value(count_series_result.next().get_fields()[0])
- series_to_length[
- self.SERIES_NAME % (data_schema.schemaName, series)
- ] = length
-
- sorted_series = sorted(series_to_length.items(), key=lambda x: x[1])
- sorted_series_with_prefix_sum = []
+ # TODO: Unify the following implementation
+ # structure: [(series_name, the number of windows of this series, prefix sum of window number, window start offset, series_data), ...]
+ series_with_prefix_sum = []
window_sum = 0
- for seq_name, seq_length in sorted_series:
- window_count = seq_length - self.context_length + 1
- if window_count < 0:
+ for seq_name, seq_values in series_map.items():
+ # calculate and sum the number of training data windows for each time series
+ window_count = len(seq_values) - self.seq_len - self.output_token_len + 1
+ if window_count <= 1:
continue
- window_sum += window_count
- sorted_series_with_prefix_sum.append((seq_name, window_count, window_sum))
+ # use windows [window_offset, window_offset + use_window_count) of this
+ # series, clamped so that they never run past its last valid window
+ window_offset = int(window_count * self.offset_rate)
+ use_window_count = max(
+ 0, min(int(window_count * self.use_rate), window_count - window_offset)
+ )
+ window_sum += use_window_count
+ # (name, used window count, prefix sum, window offset, series data)
+ series_with_prefix_sum.append(
+ (seq_name, use_window_count, window_sum, window_offset, seq_values)
+ )
- self.total_count = window_sum
- self.sorted_series = sorted_series_with_prefix_sum
+ self.total_window_count = window_sum
+ self.series_with_prefix_sum = series_with_prefix_sum
def __getitem__(self, index):
window_index = index
-
+ # locate the series whose prefix-sum range [previous sum, own sum) contains index
series_index = 0
-
- while self.sorted_series[series_index][2] < window_index:
+ while self.series_with_prefix_sum[series_index][2] <= window_index:
series_index += 1
-
+ # turn the global index into a window start inside the series, shifted by its offset
if series_index != 0:
- window_index -= self.sorted_series[series_index - 1][2]
-
- if window_index != 0:
- window_index -= 1
- series = self.sorted_series[series_index][0]
- schema = series.split(".")
-
- result = []
- sql = self.FETCH_SERIES_SQL % (
- schema[0:1],
- schema[2],
- window_index,
- self.context_length,
- )
- try:
- with self.session.execute_query_statement(sql) as query_result:
- while query_result.has_next():
- result.append(get_field_value(query_result.next().get_fields()[0]))
- except Exception as e:
- logger.error("Executing sql: {} with exception: {}".format(sql, e))
+ window_index -= self.series_with_prefix_sum[series_index - 1][2]
+ window_index += self.series_with_prefix_sum[series_index][3]
+ result = self.series_with_prefix_sum[series_index][4][
+ window_index : window_index + self.seq_len + self.output_token_len
+ ]
result = torch.tensor(result)
return (
- result[0 : self.input_len],
- result[-self.output_len :],
+ result[0 : self.seq_len],
+ result[self.input_token_len : self.seq_len + self.output_token_len],
np.ones(self.token_num, dtype=np.int32),
)
def __len__(self):
- return self.end_index - self.start_index
+ return self.total_window_count
def register_dataset(key: str, dataset: Dataset):
diff --git
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
index 0c8f87c12b5..a6462a72a2a 100644
---
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
+++
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
@@ -167,7 +167,6 @@ import
org.apache.iotdb.confignode.rpc.thrift.TDataNodeRegisterReq;
import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRestartReq;
import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRestartResp;
import org.apache.iotdb.confignode.rpc.thrift.TDataPartitionTableResp;
-import org.apache.iotdb.confignode.rpc.thrift.TDataSchemaForTable;
import org.apache.iotdb.confignode.rpc.thrift.TDatabaseSchema;
import org.apache.iotdb.confignode.rpc.thrift.TDeactivateSchemaTemplateReq;
import org.apache.iotdb.confignode.rpc.thrift.TDeleteDatabasesReq;
@@ -248,7 +247,6 @@ import
org.apache.iotdb.confignode.rpc.thrift.TSpaceQuotaResp;
import org.apache.iotdb.confignode.rpc.thrift.TStartPipeReq;
import org.apache.iotdb.confignode.rpc.thrift.TStopPipeReq;
import org.apache.iotdb.confignode.rpc.thrift.TSubscribeReq;
-import org.apache.iotdb.confignode.rpc.thrift.TTableInfo;
import org.apache.iotdb.confignode.rpc.thrift.TThrottleQuotaResp;
import org.apache.iotdb.confignode.rpc.thrift.TTimeSlotList;
import org.apache.iotdb.confignode.rpc.thrift.TUnsetSchemaTemplateReq;
@@ -2641,10 +2639,6 @@ public class ConfigManager implements IManager {
private List<IDataSchema> fetchSchemaForTreeModel(TCreateTrainingReq req) {
List<IDataSchema> dataSchemaList = new ArrayList<>();
- if (req.useAllData) {
- dataSchemaList.add(new IDataSchema("root.**"));
- return dataSchemaList;
- }
for (int i = 0; i < req.getDataSchemaForTree().getPathSize(); i++) {
IDataSchema dataSchema = new
IDataSchema(req.getDataSchemaForTree().getPath().get(i));
dataSchema.setTimeRange(req.getTimeRanges().get(i));
@@ -2654,28 +2648,7 @@ public class ConfigManager implements IManager {
}
private List<IDataSchema> fetchSchemaForTableModel(TCreateTrainingReq req) {
- List<IDataSchema> dataSchemaList = new ArrayList<>();
- TDataSchemaForTable dataSchemaForTable = req.getDataSchemaForTable();
- if (req.useAllData || !dataSchemaForTable.getDatabaseList().isEmpty()) {
- List<String> databaseNameList = new ArrayList<>();
- if (req.useAllData) {
- TShowDatabaseResp resp = showDatabase(new TGetDatabaseReq());
- databaseNameList.addAll(resp.getDatabaseInfoMap().keySet());
- } else {
- databaseNameList.addAll(dataSchemaForTable.getDatabaseList());
- }
-
- for (String database : databaseNameList) {
- TShowTableResp resp = showTables(database, false);
- for (TTableInfo tableInfo : resp.getTableInfoList()) {
- dataSchemaList.add(new IDataSchema(database + DOT +
tableInfo.tableName));
- }
- }
- }
- for (String tableName : dataSchemaForTable.getTableList()) {
- dataSchemaList.add(new IDataSchema(tableName));
- }
- return dataSchemaList;
+ return Collections.singletonList(new
IDataSchema(req.getDataSchemaForTable().getTargetSql()));
}
public TSStatus createTraining(TCreateTrainingReq req) {
@@ -2687,11 +2660,11 @@ public class ConfigManager implements IManager {
TTrainingReq trainingReq = new TTrainingReq();
trainingReq.setModelId(req.getModelId());
- trainingReq.setModelType("sundial");
- if (req.existingModelId != null) {
+ trainingReq.setModelType(req.getModelType());
+ if (req.isSetExistingModelId()) {
trainingReq.setExistingModelId(req.getExistingModelId());
}
- if (!req.parameters.isEmpty()) {
+ if (req.isSetParameters() && !req.getParameters().isEmpty()) {
trainingReq.setParameters(req.getParameters());
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
index 95d9ecaa278..10b75c424cc 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
@@ -211,7 +211,6 @@ import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.utils.Binary;
import org.apache.tsfile.utils.Pair;
-import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
@@ -1359,26 +1358,12 @@ public class TableConfigTaskVisitor extends
AstVisitor<IConfigTask, MPPQueryCont
protected IConfigTask visitCreateTraining(CreateTraining node,
MPPQueryContext context) {
context.setQueryType(QueryType.WRITE);
- String curDatabase = clientSession.getDatabaseName();
- List<String> tableList = new ArrayList<>();
- for (QualifiedName tableName : node.getTargetTables()) {
- List<String> parts = tableName.getParts();
- if (parts.size() == 1) {
- tableList.add(curDatabase + "." + parts.get(0));
- } else {
- tableList.add(parts.get(1) + "." + parts.get(0));
- }
- }
-
return new CreateTrainingTask(
node.getModelId(),
node.getModelType(),
node.getParameters(),
- node.isUseAllData(),
- node.getTargetTimeRanges(),
node.getExistingModelId(),
- tableList,
- node.getTargetDbs());
+ node.getTargetSql());
}
@Override
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java
index 282839ec560..5b23baa41b3 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java
@@ -810,7 +810,6 @@ public class TreeConfigTaskVisitor extends
StatementVisitor<IConfigTask, MPPQuer
createTrainingStatement.getModelId(),
createTrainingStatement.getModelType(),
createTrainingStatement.getParameters(),
- false,
createTrainingStatement.getTargetTimeRanges(),
createTrainingStatement.getExistingModelId(),
targetPathPatterns);
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
index 39c2bda7298..f45c3ce187b 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
@@ -3334,11 +3334,9 @@ public class ClusterConfigTaskExecutor implements
IConfigTaskExecutor {
String modelType,
boolean isTableModel,
Map<String, String> parameters,
- boolean useAllData,
List<List<Long>> timeRanges,
String existingModelId,
- @Nullable List<String> tableList,
- @Nullable List<String> databaseList,
+ @Nullable String targetSql,
@Nullable List<String> pathList) {
final SettableFuture<ConfigTaskResult> future = SettableFuture.create();
try (final ConfigNodeClient client =
@@ -3347,8 +3345,7 @@ public class ClusterConfigTaskExecutor implements
IConfigTaskExecutor {
if (isTableModel) {
TDataSchemaForTable dataSchemaForTable = new TDataSchemaForTable();
- dataSchemaForTable.setTableList(tableList);
- dataSchemaForTable.setDatabaseList(databaseList);
+ dataSchemaForTable.setTargetSql(targetSql);
req.setDataSchemaForTable(dataSchemaForTable);
} else {
TDataSchemaForTree dataSchemaForTree = new TDataSchemaForTree();
@@ -3356,7 +3353,6 @@ public class ClusterConfigTaskExecutor implements
IConfigTaskExecutor {
req.setDataSchemaForTree(dataSchemaForTree);
}
req.setParameters(parameters);
- req.setUseAllData(useAllData);
req.setTimeRanges(timeRanges);
req.setExistingModelId(existingModelId);
final TSStatus executionStatus = client.createTraining(req);
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java
index 8120dca2b92..cb49b444a52 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java
@@ -427,10 +427,8 @@ public interface IConfigTaskExecutor {
String modelType,
boolean isTableModel,
Map<String, String> parameters,
- boolean useAllData,
List<List<Long>> timeRanges,
String existingModelId,
- @Nullable List<String> tableList,
- @Nullable List<String> databaseList,
+ @Nullable String targetSql,
@Nullable List<String> pathList);
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
index 91d3258dba1..821c01e27a4 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
@@ -34,49 +34,41 @@ public class CreateTrainingTask implements IConfigTask {
private final String modelType;
private final boolean isTableModel;
private final Map<String, String> parameters;
- private final boolean useAllData;
- private final List<List<Long>> timeRanges;
+
private final String existingModelId;
// Data schema for table model
- private List<String> targetTables;
- private List<String> targetDbs;
+ private String targetSql = null;
// Data schema for tree model
private List<String> targetPaths;
+ private List<List<Long>> timeRanges;
+ // For table model
public CreateTrainingTask(
String modelId,
String modelType,
Map<String, String> parameters,
- boolean useAllData,
- List<List<Long>> timeRanges,
String existingModelId,
- List<String> targetTables,
- List<String> targetDbs) {
+ String targetSql) {
this.modelId = modelId;
this.modelType = modelType;
this.parameters = parameters;
- this.useAllData = useAllData;
- this.timeRanges = timeRanges;
this.existingModelId = existingModelId;
-
+ this.targetSql = targetSql;
this.isTableModel = true;
- this.targetTables = targetTables;
- this.targetDbs = targetDbs;
}
+ // For tree model
public CreateTrainingTask(
String modelId,
String modelType,
Map<String, String> parameters,
- boolean useAllData,
List<List<Long>> timeRanges,
String existingModelId,
List<String> targetPaths) {
this.modelId = modelId;
this.modelType = modelType;
this.parameters = parameters;
- this.useAllData = useAllData;
this.timeRanges = timeRanges;
this.existingModelId = existingModelId;
@@ -92,11 +84,9 @@ public class CreateTrainingTask implements IConfigTask {
modelType,
isTableModel,
parameters,
- useAllData,
timeRanges,
existingModelId,
- targetTables,
- targetDbs,
+ targetSql,
targetPaths);
}
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/CreateTraining.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/CreateTraining.java
index 3c978ccb5c6..1e621b7352e 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/CreateTraining.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/CreateTraining.java
@@ -27,20 +27,16 @@ public class CreateTraining extends Statement {
private final String modelId;
private final String modelType;
+ private final String targetSql;
private Map<String, String> parameters;
private String existingModelId = null;
- private List<QualifiedName> targetTables;
- private List<String> targetDbs;
-
- private List<List<Long>> targetTimeRanges;
- private boolean useAllData = false;
-
- public CreateTraining(String modelId, String modelType) {
+ public CreateTraining(String modelId, String modelType, String targetSql) {
super(null);
this.modelId = modelId;
this.modelType = modelType;
+ this.targetSql = targetSql;
}
@Override
@@ -56,26 +52,6 @@ public class CreateTraining extends Statement {
this.existingModelId = existingModelId;
}
- public void setTargetDbs(List<String> targetDbs) {
- this.targetDbs = targetDbs;
- }
-
- public void setTargetTables(List<QualifiedName> targetTables) {
- this.targetTables = targetTables;
- }
-
- public void setUseAllData(boolean useAllData) {
- this.useAllData = useAllData;
- }
-
- public List<String> getTargetDbs() {
- return targetDbs;
- }
-
- public List<QualifiedName> getTargetTables() {
- return targetTables;
- }
-
public String getModelId() {
return modelId;
}
@@ -92,16 +68,8 @@ public class CreateTraining extends Statement {
return existingModelId;
}
- public boolean isUseAllData() {
- return useAllData;
- }
-
- public void setTargetTimeRanges(List<List<Long>> targetTimeRanges) {
- this.targetTimeRanges = targetTimeRanges;
- }
-
- public List<List<Long>> getTargetTimeRanges() {
- return targetTimeRanges;
+ public String getTargetSql() {
+ return targetSql;
}
@Override
@@ -111,8 +79,7 @@ public class CreateTraining extends Statement {
@Override
public int hashCode() {
- return Objects.hash(
- modelId, modelType, existingModelId, parameters, targetTimeRanges,
useAllData);
+ return Objects.hash(modelId, modelType, targetSql, existingModelId,
parameters);
}
@Override
@@ -125,8 +92,7 @@ public class CreateTraining extends Statement {
&& modelType.equals(createTraining.modelType)
&& Objects.equals(existingModelId, createTraining.existingModelId)
&& Objects.equals(parameters, createTraining.parameters)
- && Objects.equals(targetTimeRanges, createTraining.targetTimeRanges)
- && useAllData == createTraining.useAllData;
+ && Objects.equals(targetSql, createTraining.targetSql);
}
@Override
@@ -143,14 +109,9 @@ public class CreateTraining extends Statement {
+ ", existingModelId='"
+ existingModelId
+ '\''
- + ", targetTables="
- + targetTables
- + ", targetDbs="
- + targetDbs
- + ", targetTimeRanges="
- + targetTimeRanges
- + ", useAllData="
- + useAllData
+ + ", targetSql='"
+ + targetSql
+ + '\''
+ '}';
}
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java
index 268379e6b12..85085baea5e 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java
@@ -3554,20 +3554,17 @@ public class AstBuilder extends
RelationalSqlBaseVisitor<Node> {
}
}
- private List<Long> parseTimePair(RelationalSqlParser.TimeRangeContext
timeRangeContext) {
- long currentTime = CommonDateTimeUtils.currentTime();
- List<Long> timeRange = new ArrayList<>();
- timeRange.add(parseTimeValue(timeRangeContext.timeValue(0), currentTime));
- timeRange.add(parseTimeValue(timeRangeContext.timeValue(1), currentTime));
- return timeRange;
- }
-
@Override
public Node
visitCreateModelStatement(RelationalSqlParser.CreateModelStatementContext ctx) {
String modelId = ctx.modelId.getText();
validateModelName(modelId);
String modelType = ctx.modelType.getText();
- CreateTraining createTraining = new CreateTraining(modelId, modelType);
+
+ if (ctx.targetData == null) {
+ throw new SemanticException("Target data in sql should be set in CREATE
MODEL");
+ }
+ String targetData = ((StringLiteral) visit(ctx.targetData)).getValue();
+ CreateTraining createTraining = new CreateTraining(modelId, modelType,
targetData);
if (ctx.HYPERPARAMETERS() != null) {
Map<String, String> parameters = new HashMap<>();
for (RelationalSqlParser.HparamPairContext hparamPairContext :
ctx.hparamPair()) {
@@ -3581,39 +3578,6 @@ public class AstBuilder extends
RelationalSqlBaseVisitor<Node> {
createTraining.setExistingModelId(ctx.existingModelId.getText());
}
- List<List<Long>> dbTimeRange = new ArrayList<>();
- List<List<Long>> tableTimeRange = new ArrayList<>();
- if (ctx.trainingData().ALL() != null) {
- createTraining.setUseAllData(true);
- } else {
- List<QualifiedName> targetTables = new ArrayList<>();
- List<String> targetDbs = new ArrayList<>();
- for (RelationalSqlParser.DataElementContext dataElementContext :
- ctx.trainingData().dataElement()) {
- if (dataElementContext.databaseElement() != null) {
- targetDbs.add(
- ((Identifier)
visit(dataElementContext.databaseElement().database)).getValue());
- if (dataElementContext.databaseElement().timeRange() != null) {
-
dbTimeRange.add(parseTimePair(dataElementContext.databaseElement().timeRange()));
- }
- } else {
-
targetTables.add(getQualifiedName(dataElementContext.tableElement().qualifiedName()));
- if (dataElementContext.tableElement().timeRange() != null) {
-
tableTimeRange.add(parseTimePair(dataElementContext.tableElement().timeRange()));
- }
- }
- }
-
- if (targetDbs.isEmpty() && targetTables.isEmpty()) {
- throw new IllegalArgumentException(
- "No training data is supported for model, please indicate database
or table");
- }
- createTraining.setTargetDbs(targetDbs);
- createTraining.setTargetTables(targetTables);
-
- dbTimeRange.addAll(tableTimeRange);
- createTraining.setTargetTimeRanges(dbTimeRange);
- }
return createTraining;
}
diff --git
a/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4
b/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4
index 56b2687b064..591592cfdbb 100644
---
a/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4
+++
b/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4
@@ -782,29 +782,7 @@ revokeGrantOpt
// ------------------------------------------- AI
---------------------------------------------------------
createModelStatement
- : CREATE MODEL modelType=identifier modelId=identifier (WITH
HYPERPARAMETERS '(' hparamPair (',' hparamPair)* ')')? (FROM MODEL
existingModelId=identifier)? ON DATASET '(' trainingData ')'
- ;
-
-trainingData
- : ALL
- | dataElement(',' dataElement)*
- ;
-
-dataElement
- : databaseElement
- | tableElement
- ;
-
-databaseElement
- : DATABASE database=identifier ('(' timeRange ')')?
- ;
-
-tableElement
- : TABLE tableName=qualifiedName ('(' timeRange ')')?
- ;
-
-timeRange
- : '[' startTime=timeValue ',' endTime=timeValue ']'
+ : CREATE MODEL modelType=identifier modelId=identifier (WITH
HYPERPARAMETERS '(' hparamPair (',' hparamPair)* ')')? (FROM MODEL
existingModelId=identifier)? ON DATASET '(' targetData=string ')'
;
hparamPair
diff --git a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
index f767f35d67c..e5d599ab632 100644
--- a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
+++ b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
@@ -1086,8 +1086,7 @@ struct TUpdateModelInfoReq {
}
struct TDataSchemaForTable{
- 1: required list<string> databaseList
- 2: required list<string> tableList
+ 1: required string targetSql
}
struct TDataSchemaForTree{
@@ -1100,10 +1099,9 @@ struct TCreateTrainingReq {
3: required bool isTableModel
4: optional TDataSchemaForTable dataSchemaForTable
5: optional TDataSchemaForTree dataSchemaForTree
- 6: optional bool useAllData
- 7: optional map<string, string> parameters
- 8: optional string existingModelId
- 9: optional list<list<i64>> timeRanges
+ 6: optional map<string, string> parameters
+ 7: optional string existingModelId
+ 8: optional list<list<i64>> timeRanges
}
// ====================================================