This is an automated email from the ASF dual-hosted git repository.
jackietien pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 33d5b195fb0 Support Training Module of AINode
33d5b195fb0 is described below
commit 33d5b195fb014a9a8231c916f1ae8002bdbc6367
Author: YangCaiyin <[email protected]>
AuthorDate: Tue Mar 25 15:49:18 2025 +0800
Support Training Module of AINode
---
iotdb-client/client-py/resources/pyproject.toml | 2 +-
iotdb-core/ainode/.gitignore | 4 +-
iotdb-core/ainode/{iotdb => ainode}/__init__.py | 0
.../{iotdb/ainode => ainode/core}/__init__.py | 0
.../ainode/{iotdb/ainode => ainode/core}/client.py | 42 ++++--
.../ainode/{iotdb/ainode => ainode/core}/config.py | 22 +--
.../{iotdb/ainode => ainode/core}/constant.py | 0
.../{iotdb/ainode => ainode/core}/exception.py | 2 +-
.../{iotdb/ainode => ainode/core}/handler.py | 19 +--
.../ainode/{iotdb/ainode => ainode/core}/log.py | 4 +-
.../ainode => ainode/core}/manager/__init__.py | 0
.../core}/manager/cluster_manager.py | 4 +-
.../core}/manager/inference_manager.py | 14 +-
.../core}/manager/model_manager.py | 16 +--
.../ainode => ainode/core}/model/__init__.py | 0
.../core}/model/built_in_model_factory.py | 8 +-
.../ainode => ainode/core}/model/model_factory.py | 10 +-
.../ainode => ainode/core}/model/model_storage.py | 15 +-
.../ainode/{iotdb/ainode => ainode/core}/script.py | 16 +--
.../{iotdb/ainode => ainode/core}/service.py | 8 +-
.../{iotdb/ainode => ainode/core}/util/__init__.py | 0
.../ainode => ainode/core}/util/decorator.py | 0
.../{iotdb/ainode => ainode/core}/util/lock.py | 0
.../{iotdb/ainode => ainode/core}/util/serde.py | 6 +-
.../{iotdb/ainode => ainode/core}/util/status.py | 6 +-
iotdb-core/ainode/pom.xml | 8 +-
iotdb-core/ainode/pyproject.toml | 21 ++-
.../org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4 | 13 ++
.../antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4 | 4 +
.../request/write/model/UpdateModelInfoPlan.java | 4 +
.../iotdb/confignode/manager/ConfigManager.java | 112 +++++++++++++++
.../iotdb/confignode/manager/ModelManager.java | 51 ++++++-
.../thrift/ConfigNodeRPCServiceProcessor.java | 12 ++
.../iotdb/db/protocol/client/ConfigNodeClient.java | 14 ++
.../iotdb/db/queryengine/plan/Coordinator.java | 4 +
.../execution/config/TableConfigTaskVisitor.java | 36 +++++
.../execution/config/TreeConfigTaskVisitor.java | 30 +++-
.../config/executor/ClusterConfigTaskExecutor.java | 48 ++++++-
.../config/executor/IConfigTaskExecutor.java | 27 +++-
.../metadata/{model => ai}/CreateModelTask.java | 2 +-
.../config/metadata/ai/CreateTrainingTask.java | 108 ++++++++++++++
.../metadata/{model => ai}/DropModelTask.java | 2 +-
.../metadata/{model => ai}/ShowModelsTask.java | 2 +-
.../db/queryengine/plan/parser/ASTVisitor.java | 43 +++++-
.../plan/relational/sql/ast/AstVisitor.java | 8 ++
.../plan/relational/sql/ast/CreateTraining.java | 156 +++++++++++++++++++++
.../plan/relational/sql/ast/ShowModels.java | 74 ++++++++++
.../plan/relational/sql/parser/AstBuilder.java | 85 +++++++++++
.../plan/statement/StatementVisitor.java | 5 +
.../metadata/model/CreateTrainingStatement.java | 140 ++++++++++++++++++
.../iotdb/commons/client/ainode/AINodeClient.java | 13 ++
.../iotdb/commons/model/ModelInformation.java | 20 +++
.../apache/iotdb/commons/model/ModelStatus.java | 1 +
.../db/relational/grammar/sql/RelationalSql.g4 | 56 +++++++-
.../thrift-ainode/src/main/thrift/ainode.thrift | 16 +++
.../src/main/thrift/confignode.thrift | 35 +++++
56 files changed, 1227 insertions(+), 121 deletions(-)
diff --git a/iotdb-client/client-py/resources/pyproject.toml b/iotdb-client/client-py/resources/pyproject.toml
index 09a94636e64..bd8ea77a3e2 100644
--- a/iotdb-client/client-py/resources/pyproject.toml
+++ b/iotdb-client/client-py/resources/pyproject.toml
@@ -42,7 +42,7 @@ dependencies = [
"thrift>=0.14.1",
"pandas>=1.0.0",
"numpy>=1.0.0",
- "sqlalchemy<1.5,>=1.4",
+ "sqlalchemy>=1.4",
"sqlalchemy-utils>=0.37.8"
]
diff --git a/iotdb-core/ainode/.gitignore b/iotdb-core/ainode/.gitignore
index b7ad350dc97..ddf402ed3b3 100644
--- a/iotdb-core/ainode/.gitignore
+++ b/iotdb-core/ainode/.gitignore
@@ -1,8 +1,8 @@
# generated by Thrift
-/iotdb/thrift/
+/ainode/thrift/
# generated by maven
-/iotdb/conf/
+/ainode/conf/
# .whl of ainode, generated by Poetry
/dist/
diff --git a/iotdb-core/ainode/iotdb/__init__.py
b/iotdb-core/ainode/ainode/__init__.py
similarity index 100%
rename from iotdb-core/ainode/iotdb/__init__.py
rename to iotdb-core/ainode/ainode/__init__.py
diff --git a/iotdb-core/ainode/iotdb/ainode/__init__.py
b/iotdb-core/ainode/ainode/core/__init__.py
similarity index 100%
rename from iotdb-core/ainode/iotdb/ainode/__init__.py
rename to iotdb-core/ainode/ainode/core/__init__.py
diff --git a/iotdb-core/ainode/iotdb/ainode/client.py
b/iotdb-core/ainode/ainode/core/client.py
similarity index 81%
rename from iotdb-core/ainode/iotdb/ainode/client.py
rename to iotdb-core/ainode/ainode/core/client.py
index e44cdf47676..2796d46ab93 100644
--- a/iotdb-core/ainode/iotdb/ainode/client.py
+++ b/iotdb-core/ainode/ainode/core/client.py
@@ -21,15 +21,16 @@ from thrift.Thrift import TException
from thrift.protocol import TCompactProtocol, TBinaryProtocol
from thrift.transport import TSocket, TTransport
-from iotdb.ainode.config import AINodeDescriptor
-from iotdb.ainode.constant import TSStatusCode
-from iotdb.ainode.log import Logger
-from iotdb.ainode.util.decorator import singleton
-from iotdb.ainode.util.status import verify_success
-from iotdb.thrift.common.ttypes import TEndPoint, TSStatus, TAINodeLocation, TAINodeConfiguration
-from iotdb.thrift.confignode import IConfigNodeRPCService
-from iotdb.thrift.confignode.ttypes import (TAINodeRemoveReq, TNodeVersionInfo,
- TAINodeRegisterReq, TAINodeRestartReq)
+from ainode.core.config import AINodeDescriptor
+from ainode.core.constant import TSStatusCode
+from ainode.core.log import Logger
+from ainode.core.util.decorator import singleton
+from ainode.core.util.status import verify_success
+from ainode.thrift.common.ttypes import TEndPoint, TSStatus, TAINodeLocation, TAINodeConfiguration
+from ainode.thrift.confignode import IConfigNodeRPCService
+from ainode.thrift.confignode.ttypes import (TAINodeRemoveReq, TNodeVersionInfo,
+ TAINodeRegisterReq, TAINodeRestartReq)
+from ainode.thrift.confignode.ttypes import TUpdateModelInfoReq
logger = Logger()
@@ -202,3 +203,26 @@ class ConfigNodeClient(object):
self._config_leader = None
self._wait_and_reconnect()
raise TException(self._MSG_RECONNECTION_FAIL)
+
+ def update_model_info(self, model_id:str, model_status:int, attribute:str = "", ainode_id=None, input_length=0, output_length=0) -> None:
+ if ainode_id is None:
+ ainode_id = []
+ for _ in range(0, self._RETRY_NUM):
+ try:
+ req = TUpdateModelInfoReq(
+ model_id, model_status, attribute
+ )
+ if ainode_id is not None:
+ req.aiNodeIds = ainode_id
+ req.inputLength = input_length
+ req.outputLength = output_length
+ status = self._client.updateModelInfo(req)
+ if not self._update_config_node_leader(status):
+ verify_success(status, "An error occurs when calling update model info")
+ return status
+ except TTransport.TException:
+ logger.warning("Failed to connect to ConfigNode {} from AINode when executing update model info",
+ self._config_leader)
+ self._config_leader = None
+ self._wait_and_reconnect()
+ raise TException(self._MSG_RECONNECTION_FAIL)
diff --git a/iotdb-core/ainode/iotdb/ainode/config.py
b/iotdb-core/ainode/ainode/core/config.py
similarity index 91%
rename from iotdb-core/ainode/iotdb/ainode/config.py
rename to iotdb-core/ainode/ainode/core/config.py
index af66bc48ccf..036dc20b9c2 100644
--- a/iotdb-core/ainode/iotdb/ainode/config.py
+++ b/iotdb-core/ainode/ainode/core/config.py
@@ -17,17 +17,17 @@
#
import os
-from iotdb.ainode.constant import (AINODE_CONF_DIRECTORY_NAME,
- AINODE_CONF_FILE_NAME,
- AINODE_MODELS_DIR, AINODE_LOG_DIR,
AINODE_SYSTEM_DIR, AINODE_INFERENCE_RPC_ADDRESS,
- AINODE_INFERENCE_RPC_PORT,
AINODE_THRIFT_COMPRESSION_ENABLED,
- AINODE_SYSTEM_FILE_NAME,
AINODE_CLUSTER_NAME, AINODE_VERSION_INFO, AINODE_BUILD_INFO,
- AINODE_CONF_GIT_FILE_NAME,
AINODE_CONF_POM_FILE_NAME, AINODE_ROOT_DIR,
- AINODE_ROOT_CONF_DIRECTORY_NAME)
-from iotdb.ainode.exception import BadNodeUrlError
-from iotdb.ainode.log import Logger
-from iotdb.ainode.util.decorator import singleton
-from iotdb.thrift.common.ttypes import TEndPoint
+from ainode.core.constant import (AINODE_CONF_DIRECTORY_NAME,
+ AINODE_CONF_FILE_NAME,
+ AINODE_MODELS_DIR, AINODE_LOG_DIR,
AINODE_SYSTEM_DIR, AINODE_INFERENCE_RPC_ADDRESS,
+ AINODE_INFERENCE_RPC_PORT,
AINODE_THRIFT_COMPRESSION_ENABLED,
+ AINODE_SYSTEM_FILE_NAME,
AINODE_CLUSTER_NAME, AINODE_VERSION_INFO, AINODE_BUILD_INFO,
+ AINODE_CONF_GIT_FILE_NAME,
AINODE_CONF_POM_FILE_NAME, AINODE_ROOT_DIR,
+ AINODE_ROOT_CONF_DIRECTORY_NAME)
+from ainode.core.exception import BadNodeUrlError
+from ainode.core.log import Logger
+from ainode.core.util.decorator import singleton
+from ainode.thrift.common.ttypes import TEndPoint
logger = Logger()
diff --git a/iotdb-core/ainode/iotdb/ainode/constant.py
b/iotdb-core/ainode/ainode/core/constant.py
similarity index 100%
rename from iotdb-core/ainode/iotdb/ainode/constant.py
rename to iotdb-core/ainode/ainode/core/constant.py
diff --git a/iotdb-core/ainode/iotdb/ainode/exception.py
b/iotdb-core/ainode/ainode/core/exception.py
similarity index 98%
rename from iotdb-core/ainode/iotdb/ainode/exception.py
rename to iotdb-core/ainode/ainode/core/exception.py
index 56186ee2bef..a9b8c496d65 100644
--- a/iotdb-core/ainode/iotdb/ainode/exception.py
+++ b/iotdb-core/ainode/ainode/core/exception.py
@@ -17,7 +17,7 @@
#
import re
-from iotdb.ainode.constant import DEFAULT_MODEL_FILE_NAME,
DEFAULT_CONFIG_FILE_NAME
+from ainode.core.constant import DEFAULT_MODEL_FILE_NAME,
DEFAULT_CONFIG_FILE_NAME
class _BaseError(Exception):
diff --git a/iotdb-core/ainode/iotdb/ainode/handler.py
b/iotdb-core/ainode/ainode/core/handler.py
similarity index 69%
rename from iotdb-core/ainode/iotdb/ainode/handler.py
rename to iotdb-core/ainode/ainode/core/handler.py
index c27be605e51..7b94d209d08 100644
--- a/iotdb-core/ainode/iotdb/ainode/handler.py
+++ b/iotdb-core/ainode/ainode/core/handler.py
@@ -16,14 +16,14 @@
# under the License.
#
-from iotdb.ainode.manager.cluster_manager import ClusterManager
-from iotdb.ainode.manager.inference_manager import InferenceManager
-from iotdb.ainode.manager.model_manager import ModelManager
-from iotdb.thrift.ainode import IAINodeRPCService
-from iotdb.thrift.ainode.ttypes import (TDeleteModelReq, TRegisterModelReq,
- TAIHeartbeatReq, TInferenceReq,
TRegisterModelResp, TInferenceResp,
- TAIHeartbeatResp)
-from iotdb.thrift.common.ttypes import TSStatus
+from ainode.core.manager.cluster_manager import ClusterManager
+from ainode.core.manager.inference_manager import InferenceManager
+from ainode.core.manager.model_manager import ModelManager
+from ainode.thrift.ainode import IAINodeRPCService
+from ainode.thrift.ainode.ttypes import (TDeleteModelReq, TRegisterModelReq,
+ TAIHeartbeatReq, TInferenceReq,
TRegisterModelResp, TInferenceResp,
+ TAIHeartbeatResp, TTrainingReq)
+from ainode.thrift.common.ttypes import TSStatus
class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
@@ -41,3 +41,6 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
return ClusterManager.get_heart_beat(req)
+
+ def createTrainingTask(self, req: TTrainingReq) -> TSStatus:
+ pass
diff --git a/iotdb-core/ainode/iotdb/ainode/log.py
b/iotdb-core/ainode/ainode/core/log.py
similarity index 97%
rename from iotdb-core/ainode/iotdb/ainode/log.py
rename to iotdb-core/ainode/ainode/core/log.py
index 4c796dbfcf6..4b2f412eaaf 100644
--- a/iotdb-core/ainode/iotdb/ainode/log.py
+++ b/iotdb-core/ainode/ainode/core/log.py
@@ -23,8 +23,8 @@ import random
import sys
import threading
-from iotdb.ainode.constant import STD_LEVEL, AINODE_LOG_FILE_NAMES,
AINODE_LOG_FILE_LEVELS
-from iotdb.ainode.util.decorator import singleton
+from ainode.core.constant import STD_LEVEL, AINODE_LOG_FILE_NAMES,
AINODE_LOG_FILE_LEVELS
+from ainode.core.util.decorator import singleton
class LoggerFilter(logging.Filter):
diff --git a/iotdb-core/ainode/iotdb/ainode/manager/__init__.py
b/iotdb-core/ainode/ainode/core/manager/__init__.py
similarity index 100%
rename from iotdb-core/ainode/iotdb/ainode/manager/__init__.py
rename to iotdb-core/ainode/ainode/core/manager/__init__.py
diff --git a/iotdb-core/ainode/iotdb/ainode/manager/cluster_manager.py
b/iotdb-core/ainode/ainode/core/manager/cluster_manager.py
similarity index 93%
rename from iotdb-core/ainode/iotdb/ainode/manager/cluster_manager.py
rename to iotdb-core/ainode/ainode/core/manager/cluster_manager.py
index dff290a0b84..da7008b7762 100644
--- a/iotdb-core/ainode/iotdb/ainode/manager/cluster_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/cluster_manager.py
@@ -17,8 +17,8 @@
#
import psutil
-from iotdb.thrift.ainode.ttypes import TAIHeartbeatResp, TAIHeartbeatReq
-from iotdb.thrift.common.ttypes import TLoadSample
+from ainode.thrift.ainode.ttypes import TAIHeartbeatResp, TAIHeartbeatReq
+from ainode.thrift.common.ttypes import TLoadSample
class ClusterManager:
diff --git a/iotdb-core/ainode/iotdb/ainode/manager/inference_manager.py
b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
similarity index 95%
rename from iotdb-core/ainode/iotdb/ainode/manager/inference_manager.py
rename to iotdb-core/ainode/ainode/core/manager/inference_manager.py
index 4c33ef6f918..ebfc6d41c9e 100644
--- a/iotdb-core/ainode/iotdb/ainode/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -18,13 +18,13 @@
import pandas as pd
from torch import tensor
-from iotdb.ainode.constant import TSStatusCode
-from iotdb.ainode.exception import InvalidWindowArgumentError,
InferenceModelInternalError, runtime_error_extractor
-from iotdb.ainode.log import Logger
-from iotdb.ainode.manager.model_manager import ModelManager
-from iotdb.ainode.util.serde import convert_to_binary, convert_to_df
-from iotdb.ainode.util.status import get_status
-from iotdb.thrift.ainode.ttypes import TInferenceReq, TInferenceResp
+from ainode.core.constant import TSStatusCode
+from ainode.core.exception import InvalidWindowArgumentError,
InferenceModelInternalError, runtime_error_extractor
+from ainode.core.log import Logger
+from ainode.core.manager.model_manager import ModelManager
+from ainode.core.util.serde import convert_to_binary, convert_to_df
+from ainode.core.util.status import get_status
+from ainode.thrift.ainode.ttypes import TInferenceReq, TInferenceResp
logger = Logger()
diff --git a/iotdb-core/ainode/iotdb/ainode/manager/model_manager.py
b/iotdb-core/ainode/ainode/core/manager/model_manager.py
similarity index 86%
rename from iotdb-core/ainode/iotdb/ainode/manager/model_manager.py
rename to iotdb-core/ainode/ainode/core/manager/model_manager.py
index 1ccdea95998..ead833a5839 100644
--- a/iotdb-core/ainode/iotdb/ainode/manager/model_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/model_manager.py
@@ -19,14 +19,14 @@ from typing import Callable
from yaml import YAMLError
-from iotdb.ainode.constant import TSStatusCode, BuiltInModelType
-from iotdb.ainode.exception import InvalidUriError, BadConfigValueError,
BuiltInModelNotSupportError
-from iotdb.ainode.log import Logger
-from iotdb.ainode.model.built_in_model_factory import fetch_built_in_model
-from iotdb.ainode.model.model_storage import ModelStorage
-from iotdb.ainode.util.status import get_status
-from iotdb.thrift.ainode.ttypes import TRegisterModelReq, TRegisterModelResp,
TDeleteModelReq
-from iotdb.thrift.common.ttypes import TSStatus
+from ainode.core.constant import TSStatusCode, BuiltInModelType
+from ainode.core.exception import InvalidUriError, BadConfigValueError,
BuiltInModelNotSupportError
+from ainode.core.log import Logger
+from ainode.core.model.built_in_model_factory import fetch_built_in_model
+from ainode.core.model.model_storage import ModelStorage
+from ainode.core.util.status import get_status
+from ainode.thrift.ainode.ttypes import TRegisterModelReq, TRegisterModelResp,
TDeleteModelReq
+from ainode.thrift.common.ttypes import TSStatus
logger = Logger()
diff --git a/iotdb-core/ainode/iotdb/ainode/model/__init__.py
b/iotdb-core/ainode/ainode/core/model/__init__.py
similarity index 100%
rename from iotdb-core/ainode/iotdb/ainode/model/__init__.py
rename to iotdb-core/ainode/ainode/core/model/__init__.py
diff --git a/iotdb-core/ainode/iotdb/ainode/model/built_in_model_factory.py
b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
similarity index 99%
rename from iotdb-core/ainode/iotdb/ainode/model/built_in_model_factory.py
rename to iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
index 82443012176..4524272dfaf 100644
--- a/iotdb-core/ainode/iotdb/ainode/model/built_in_model_factory.py
+++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
@@ -27,11 +27,11 @@ from sktime.forecasting.exp_smoothing import
ExponentialSmoothing
from sktime.forecasting.naive import NaiveForecaster
from sktime.forecasting.trend import STLForecaster
-from iotdb.ainode.constant import AttributeName, BuiltInModelType
-from iotdb.ainode.exception import InferenceModelInternalError,
AttributeNotSupportError
-from iotdb.ainode.exception import WrongAttributeTypeError,
NumericalRangeException, StringRangeException, \
+from ainode.core.constant import AttributeName, BuiltInModelType
+from ainode.core.exception import InferenceModelInternalError,
AttributeNotSupportError
+from ainode.core.exception import WrongAttributeTypeError,
NumericalRangeException, StringRangeException, \
ListRangeException, BuiltInModelNotSupportError
-from iotdb.ainode.log import Logger
+from ainode.core.log import Logger
logger = Logger()
diff --git a/iotdb-core/ainode/iotdb/ainode/model/model_factory.py
b/iotdb-core/ainode/ainode/core/model/model_factory.py
similarity index 96%
rename from iotdb-core/ainode/iotdb/ainode/model/model_factory.py
rename to iotdb-core/ainode/ainode/core/model/model_factory.py
index a163cf1fa84..1700dd28eb6 100644
--- a/iotdb-core/ainode/iotdb/ainode/model/model_factory.py
+++ b/iotdb-core/ainode/ainode/core/model/model_factory.py
@@ -23,12 +23,12 @@ import yaml
from requests import Session
from requests.adapters import HTTPAdapter
-from iotdb.ainode.constant import DEFAULT_RECONNECT_TIMES,
DEFAULT_RECONNECT_TIMEOUT, DEFAULT_CHUNK_SIZE, \
+from ainode.core.constant import DEFAULT_RECONNECT_TIMES,
DEFAULT_RECONNECT_TIMEOUT, DEFAULT_CHUNK_SIZE, \
DEFAULT_CONFIG_FILE_NAME, DEFAULT_MODEL_FILE_NAME
-from iotdb.ainode.exception import InvalidUriError, BadConfigValueError
-from iotdb.ainode.log import Logger
-from iotdb.ainode.util.serde import get_data_type_byte_from_str
-from iotdb.thrift.ainode.ttypes import TConfigs
+from ainode.core.exception import InvalidUriError, BadConfigValueError
+from ainode.core.log import Logger
+from ainode.core.util.serde import get_data_type_byte_from_str
+from ainode.thrift.ainode.ttypes import TConfigs
HTTP_PREFIX = "http://"
HTTPS_PREFIX = "https://"
diff --git a/iotdb-core/ainode/iotdb/ainode/model/model_storage.py
b/iotdb-core/ainode/ainode/core/model/model_storage.py
similarity index 92%
rename from iotdb-core/ainode/iotdb/ainode/model/model_storage.py
rename to iotdb-core/ainode/ainode/core/model/model_storage.py
index 9a1df5f9c20..43ebf6c06b7 100644
--- a/iotdb-core/ainode/iotdb/ainode/model/model_storage.py
+++ b/iotdb-core/ainode/ainode/core/model/model_storage.py
@@ -24,14 +24,13 @@ import torch
import torch._dynamo
from pylru import lrucache
-from iotdb.ainode.config import AINodeDescriptor
-from iotdb.ainode.constant import (DEFAULT_MODEL_FILE_NAME,
- DEFAULT_CONFIG_FILE_NAME)
-from iotdb.ainode.exception import ModelNotExistError
-from iotdb.ainode.log import Logger
-from iotdb.ainode.model.model_factory import fetch_model_by_uri
-from iotdb.ainode.util.lock import ModelLockPool
-
+from ainode.core.config import AINodeDescriptor
+from ainode.core.constant import (DEFAULT_MODEL_FILE_NAME,
+ DEFAULT_CONFIG_FILE_NAME)
+from ainode.core.exception import ModelNotExistError
+from ainode.core.log import Logger
+from ainode.core.model.model_factory import fetch_model_by_uri
+from ainode.core.util.lock import ModelLockPool
logger = Logger()
diff --git a/iotdb-core/ainode/iotdb/ainode/script.py
b/iotdb-core/ainode/ainode/core/script.py
similarity index 93%
rename from iotdb-core/ainode/iotdb/ainode/script.py
rename to iotdb-core/ainode/ainode/core/script.py
index e06a3fe77e0..b27bb6ab61b 100644
--- a/iotdb-core/ainode/iotdb/ainode/script.py
+++ b/iotdb-core/ainode/ainode/core/script.py
@@ -22,14 +22,14 @@ from datetime import datetime
import psutil
-from iotdb.ainode.client import ClientManager
-from iotdb.ainode.config import AINodeDescriptor
-from iotdb.ainode.constant import TSStatusCode, AINODE_SYSTEM_FILE_NAME
-from iotdb.ainode.exception import MissingConfigError
-from iotdb.ainode.log import Logger
-from iotdb.ainode.service import RPCService
-from iotdb.thrift.common.ttypes import TAINodeLocation, TEndPoint,
TAINodeConfiguration, TNodeResource
-from iotdb.thrift.confignode.ttypes import TNodeVersionInfo
+from ainode.core.client import ClientManager
+from ainode.core.config import AINodeDescriptor
+from ainode.core.constant import TSStatusCode, AINODE_SYSTEM_FILE_NAME
+from ainode.core.exception import MissingConfigError
+from ainode.core.log import Logger
+from ainode.core.service import RPCService
+from ainode.thrift.common.ttypes import TAINodeLocation, TEndPoint,
TAINodeConfiguration, TNodeResource
+from ainode.thrift.confignode.ttypes import TNodeVersionInfo
logger = Logger()
diff --git a/iotdb-core/ainode/iotdb/ainode/service.py
b/iotdb-core/ainode/ainode/core/service.py
similarity index 91%
rename from iotdb-core/ainode/iotdb/ainode/service.py
rename to iotdb-core/ainode/ainode/core/service.py
index 54954dd7d50..9532093d1da 100644
--- a/iotdb-core/ainode/iotdb/ainode/service.py
+++ b/iotdb-core/ainode/ainode/core/service.py
@@ -21,10 +21,10 @@ from thrift.protocol import TCompactProtocol,
TBinaryProtocol
from thrift.server import TServer
from thrift.transport import TSocket, TTransport
-from iotdb.ainode.config import AINodeDescriptor
-from iotdb.ainode.handler import AINodeRPCServiceHandler
-from iotdb.ainode.log import Logger
-from iotdb.thrift.ainode import IAINodeRPCService
+from ainode.core.config import AINodeDescriptor
+from ainode.core.handler import AINodeRPCServiceHandler
+from ainode.core.log import Logger
+from ainode.thrift.ainode import IAINodeRPCService
logger = Logger()
diff --git a/iotdb-core/ainode/iotdb/ainode/util/__init__.py
b/iotdb-core/ainode/ainode/core/util/__init__.py
similarity index 100%
rename from iotdb-core/ainode/iotdb/ainode/util/__init__.py
rename to iotdb-core/ainode/ainode/core/util/__init__.py
diff --git a/iotdb-core/ainode/iotdb/ainode/util/decorator.py
b/iotdb-core/ainode/ainode/core/util/decorator.py
similarity index 100%
rename from iotdb-core/ainode/iotdb/ainode/util/decorator.py
rename to iotdb-core/ainode/ainode/core/util/decorator.py
diff --git a/iotdb-core/ainode/iotdb/ainode/util/lock.py
b/iotdb-core/ainode/ainode/core/util/lock.py
similarity index 100%
rename from iotdb-core/ainode/iotdb/ainode/util/lock.py
rename to iotdb-core/ainode/ainode/core/util/lock.py
diff --git a/iotdb-core/ainode/iotdb/ainode/util/serde.py
b/iotdb-core/ainode/ainode/core/util/serde.py
similarity index 98%
rename from iotdb-core/ainode/iotdb/ainode/util/serde.py
rename to iotdb-core/ainode/ainode/core/util/serde.py
index 4338dcdfefc..b9edccfd03e 100644
--- a/iotdb-core/ainode/iotdb/ainode/util/serde.py
+++ b/iotdb-core/ainode/ainode/core/util/serde.py
@@ -21,7 +21,7 @@ from enum import Enum
import numpy as np
import pandas as pd
-from iotdb.ainode.exception import BadConfigValueError
+from ainode.core.exception import BadConfigValueError
class TSDataType(Enum):
@@ -143,7 +143,7 @@ def convert_to_df(name_list, type_list, name_index, binary_list):
time_column_values, np.dtype(np.longlong).newbyteorder(">")
)
if time_array.dtype.byteorder == ">":
- time_array = time_array.byteswap().newbyteorder("<")
+ time_array = time_array.byteswap().view(time_array.dtype.newbyteorder("<"))
if result[TIMESTAMP_STR] is None:
result[TIMESTAMP_STR] = time_array
@@ -198,7 +198,7 @@ def convert_to_df(name_list, type_list, name_index, binary_list):
raise RuntimeError("unsupported data type {}.".format(data_type))
if data_array.dtype.byteorder == ">":
- data_array = data_array.byteswap().newbyteorder("<")
+ data_array = data_array.byteswap().view(data_array.dtype.newbyteorder("<"))
null_indicator = null_indicators[location]
if len(data_array) < total_length or (data_type == TSDataType.BOOLEAN and null_indicator is not None):
diff --git a/iotdb-core/ainode/iotdb/ainode/util/status.py
b/iotdb-core/ainode/ainode/core/util/status.py
similarity index 90%
rename from iotdb-core/ainode/iotdb/ainode/util/status.py
rename to iotdb-core/ainode/ainode/core/util/status.py
index 1bcbef7a806..37368b0068b 100644
--- a/iotdb-core/ainode/iotdb/ainode/util/status.py
+++ b/iotdb-core/ainode/ainode/core/util/status.py
@@ -16,9 +16,9 @@
# under the License.
#
-from iotdb.ainode.constant import TSStatusCode
-from iotdb.ainode.log import Logger
-from iotdb.thrift.common.ttypes import TSStatus
+from ainode.core.constant import TSStatusCode
+from ainode.core.log import Logger
+from ainode.thrift.common.ttypes import TSStatus
def get_status(status_code: TSStatusCode, message: str = None) -> TSStatus:
diff --git a/iotdb-core/ainode/pom.xml b/iotdb-core/ainode/pom.xml
index ad6679c7701..ee3137cf664 100644
--- a/iotdb-core/ainode/pom.xml
+++ b/iotdb-core/ainode/pom.xml
@@ -67,7 +67,7 @@
<directory>dist</directory>
</fileset>
<fileset>
- <directory>iotdb</directory>
+ <directory>ainode</directory>
<includes>
<include>conf/</include>
<include>thrift/</include>
@@ -138,7 +138,7 @@
<goal>copy-resources</goal>
</goals>
<configuration>
-
<outputDirectory>${basedir}/iotdb/thrift/</outputDirectory>
+
<outputDirectory>${basedir}/ainode/thrift/</outputDirectory>
<resources>
<resource>
<directory>${basedir}/../../iotdb-protocol/thrift-commons/target/generated-sources-python/iotdb/thrift/</directory>
@@ -167,7 +167,7 @@
<goal>copy-resources</goal>
</goals>
<configuration>
-
<outputDirectory>${basedir}/iotdb/conf/</outputDirectory>
+
<outputDirectory>${basedir}/ainode/conf/</outputDirectory>
<resources>
<resource>
<directory>${basedir}/resources/</directory>
@@ -192,7 +192,7 @@
</goals>
<configuration>
<generateGitPropertiesFile>true</generateGitPropertiesFile>
-
<generateGitPropertiesFilename>${project.basedir}/iotdb/conf/git.properties</generateGitPropertiesFilename>
+
<generateGitPropertiesFilename>${project.basedir}/ainode/conf/git.properties</generateGitPropertiesFilename>
<includeOnlyProperties>
<includeOnlyProperty>^git.commit.id.abbrev$</includeOnlyProperty>
<includeOnlyProperty>^git.dirty$</includeOnlyProperty>
diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml
index efbbdb1edaf..f5b083b8f69 100644
--- a/iotdb-core/ainode/pyproject.toml
+++ b/iotdb-core/ainode/pyproject.toml
@@ -34,26 +34,24 @@ classifiers = [
"Topic :: Software Development :: Libraries :: Python Modules",
]
include = [
- {path = "iotdb/thrift/*", format = "wheel"},
- {path = "iotdb/thrift/common/*", format = "wheel"},
- {path = "iotdb/thrift/confignode/*", format = "wheel"},
- {path = "iotdb/thrift/datanode/*", format = "wheel"},
- {path = "iotdb/thrift/ainode/*", format = "wheel"},
- {path = "iotdb/conf/*", format = "wheel"},
+ {path = "ainode/thrift/*", format = "wheel"},
+ {path = "ainode/thrift/common/*", format = "wheel"},
+ {path = "ainode/thrift/confignode/*", format = "wheel"},
+ {path = "ainode/thrift/datanode/*", format = "wheel"},
+ {path = "ainode/thrift/ainode/*", format = "wheel"},
+ {path = "ainode/conf/*", format = "wheel"},
]
packages = [
- { include = "iotdb" }
+ { include = "ainode" }
]
[tool.poetry.dependencies]
python = ">=3.8, <3.13"
-
numpy = "^1.21.4"
pandas = "^1.3.5"
torch = ">=2.2.0"
pylru = "^1.2.1"
-
-thrift = "^0.13.0"
+thrift = ">=0.14.0"
dynaconf = "^3.1.11"
requests = "^2.31.0"
optuna = "^3.2.0"
@@ -61,6 +59,7 @@ psutil = "^5.9.5"
sktime = "^0.24.1"
pmdarima = "^2.0.4"
hmmlearn = "^0.3.0"
+apache-iotdb = "2.0.1b0"
[tool.poetry.scripts]
-ainode = "iotdb.ainode.script:main"
\ No newline at end of file
+ainode = "ainode.core.script:main"
\ No newline at end of file
diff --git a/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4 b/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4
index d7bc94cd50a..b4a7701267f 100644
--- a/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4
+++ b/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4
@@ -694,6 +694,19 @@ showSubscriptions
// ---- Create Model
createModel
: CREATE MODEL modelName=identifier uriClause
+ | CREATE MODEL modelType=identifier modelId=identifier (WITH HYPERPARAMETERS LR_BRACKET hparamPair (COMMA hparamPair)* RR_BRACKET)? (FROM MODEL existingModelId=identifier)? ON DATASET LR_BRACKET trainingData RR_BRACKET
+ ;
+
+trainingData
+ : dataElement(COMMA dataElement)*
+ ;
+
+dataElement
+ : pathPatternElement (LR_BRACKET timeRange RR_BRACKET)?
+ ;
+
+pathPatternElement
+ : PATH path=prefixPath
;
windowFunction
diff --git
a/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4
b/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4
index 572e02b756e..e31d6718323 100644
--- a/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4
+++ b/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4
@@ -574,6 +574,10 @@ PASSWORD
: P A S S W O R D
;
+PATH
+ : P A T H
+ ;
+
PATHS
: P A T H S
;
diff --git
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java
index ca74d2daf69..ce7219e4281 100644
---
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java
+++
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java
@@ -72,6 +72,10 @@ public class UpdateModelInfoPlan extends ConfigPhysicalPlan {
return nodeIds;
}
+ public void setNodeIds(List<Integer> nodeIds) {
+ this.nodeIds = nodeIds;
+ }
+
@Override
protected void serializeImpl(DataOutputStream stream) throws IOException {
stream.writeShort(getType().getPlanType());
diff --git
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
index 614bdafaefe..3867c5ecb7f 100644
---
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
+++
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
@@ -19,6 +19,8 @@
package org.apache.iotdb.confignode.manager;
+import org.apache.iotdb.ainode.rpc.thrift.IDataSchema;
+import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq;
import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration;
import org.apache.iotdb.common.rpc.thrift.TAINodeLocation;
import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation;
@@ -37,6 +39,9 @@ import
org.apache.iotdb.common.rpc.thrift.TShowConfigurationResp;
import org.apache.iotdb.common.rpc.thrift.TTimePartitionSlot;
import org.apache.iotdb.commons.auth.AuthException;
import org.apache.iotdb.commons.auth.entity.PrivilegeUnion;
+import org.apache.iotdb.commons.client.ainode.AINodeClient;
+import org.apache.iotdb.commons.client.ainode.AINodeClientManager;
+import org.apache.iotdb.commons.client.ainode.AINodeInfo;
import org.apache.iotdb.commons.cluster.NodeStatus;
import org.apache.iotdb.commons.cluster.NodeType;
import org.apache.iotdb.commons.conf.CommonConfig;
@@ -46,6 +51,7 @@ import org.apache.iotdb.commons.conf.IoTDBConstant;
import org.apache.iotdb.commons.conf.TrimProperties;
import org.apache.iotdb.commons.exception.IllegalPathException;
import org.apache.iotdb.commons.exception.MetadataException;
+import org.apache.iotdb.commons.model.ModelStatus;
import org.apache.iotdb.commons.path.PartialPath;
import org.apache.iotdb.commons.path.PathPatternTree;
import org.apache.iotdb.commons.path.PathPatternUtil;
@@ -83,6 +89,7 @@ import
org.apache.iotdb.confignode.consensus.request.write.database.SetSchemaRep
import org.apache.iotdb.confignode.consensus.request.write.database.SetTTLPlan;
import
org.apache.iotdb.confignode.consensus.request.write.database.SetTimePartitionIntervalPlan;
import
org.apache.iotdb.confignode.consensus.request.write.datanode.RemoveDataNodePlan;
+import
org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan;
import
org.apache.iotdb.confignode.consensus.request.write.template.CreateSchemaTemplatePlan;
import
org.apache.iotdb.confignode.consensus.response.ainode.AINodeRegisterResp;
import org.apache.iotdb.confignode.consensus.response.auth.PermissionInfoResp;
@@ -151,11 +158,13 @@ import
org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateTopicReq;
+import org.apache.iotdb.confignode.rpc.thrift.TCreateTrainingReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateTriggerReq;
import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRegisterReq;
import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRestartReq;
import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRestartResp;
import org.apache.iotdb.confignode.rpc.thrift.TDataPartitionTableResp;
+import org.apache.iotdb.confignode.rpc.thrift.TDataSchemaForTable;
import org.apache.iotdb.confignode.rpc.thrift.TDatabaseSchema;
import org.apache.iotdb.confignode.rpc.thrift.TDeactivateSchemaTemplateReq;
import org.apache.iotdb.confignode.rpc.thrift.TDeleteDatabasesReq;
@@ -233,10 +242,12 @@ import
org.apache.iotdb.confignode.rpc.thrift.TSpaceQuotaResp;
import org.apache.iotdb.confignode.rpc.thrift.TStartPipeReq;
import org.apache.iotdb.confignode.rpc.thrift.TStopPipeReq;
import org.apache.iotdb.confignode.rpc.thrift.TSubscribeReq;
+import org.apache.iotdb.confignode.rpc.thrift.TTableInfo;
import org.apache.iotdb.confignode.rpc.thrift.TThrottleQuotaResp;
import org.apache.iotdb.confignode.rpc.thrift.TTimeSlotList;
import org.apache.iotdb.confignode.rpc.thrift.TUnsetSchemaTemplateReq;
import org.apache.iotdb.confignode.rpc.thrift.TUnsubscribeReq;
+import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq;
import org.apache.iotdb.consensus.common.DataSet;
import org.apache.iotdb.consensus.exception.ConsensusException;
import org.apache.iotdb.db.schemaengine.template.Template;
@@ -336,6 +347,8 @@ public class ConfigManager implements IManager {
private static final String DATABASE = "\tDatabase=";
+ private static final String DOT = ".";
+
public ConfigManager() throws IOException {
// Build the persistence module
ClusterInfo clusterInfo = new ClusterInfo();
@@ -2577,6 +2590,98 @@ public class ConfigManager implements IManager {
: status;
}
+ private List<IDataSchema> fetchSchemaForTreeModel(TCreateTrainingReq req) {
+ List<IDataSchema> dataSchemaList = new ArrayList<>();
+ if (req.useAllData) {
+ dataSchemaList.add(new IDataSchema("root.**"));
+ return dataSchemaList;
+ }
+ for (int i = 0; i < req.getDataSchemaForTree().getPathSize(); i++) {
+ IDataSchema dataSchema = new
IDataSchema(req.getDataSchemaForTree().getPath().get(i));
+ dataSchema.setTimeRange(req.getTimeRanges().get(i));
+ dataSchemaList.add(dataSchema);
+ }
+ return dataSchemaList;
+ }
+
+  /**
+   * Builds the data schemas sent to the AINode for a table-model training request.
+   *
+   * <p>The result is the concatenation of two groups, each schema named {@code database.table}:
+   *
+   * <ol>
+   *   <li>every table of every selected database, where the selected databases are all databases
+   *       returned by {@code showDatabase} when the request asks for all data, and the request's
+   *       database list otherwise;
+   *   <li>every entry of the request's table list, prefixed with the request's {@code curDatabase}.
+   * </ol>
+   *
+   * <p>The database list and the table list may be absent from the request, in which case their
+   * getters return {@code null}; an absent list is treated as empty instead of failing with a
+   * {@link NullPointerException}.
+   *
+   * @param req the training request received by the ConfigNode
+   * @return a newly created list of schemas, possibly empty
+   */
+  private List<IDataSchema> fetchSchemaForTableModel(TCreateTrainingReq req) {
+    List<IDataSchema> dataSchemaList = new ArrayList<>();
+    TDataSchemaForTable dataSchemaForTable = req.getDataSchemaForTable();
+    List<String> requestedDatabases = dataSchemaForTable.getDatabaseList();
+    List<String> requestedTables = dataSchemaForTable.getTableList();
+
+    List<String> databaseNameList = new ArrayList<>();
+    if (req.useAllData) {
+      TShowDatabaseResp resp = showDatabase(new TGetDatabaseReq());
+      databaseNameList.addAll(resp.getDatabaseInfoMap().keySet());
+    } else if (requestedDatabases != null) {
+      databaseNameList.addAll(requestedDatabases);
+    }
+
+    // Expand each selected database into all of its tables.
+    for (String database : databaseNameList) {
+      TShowTableResp resp = showTables(database, false);
+      for (TTableInfo tableInfo : resp.getTableInfoList()) {
+        dataSchemaList.add(new IDataSchema(database + DOT + tableInfo.tableName));
+      }
+    }
+
+    if (requestedTables != null) {
+      // NOTE(review): the sender may already have qualified these names with a database;
+      // confirm that prefixing them with curDatabase again is intended.
+      for (String tableName : requestedTables) {
+        dataSchemaList.add(new IDataSchema(dataSchemaForTable.curDatabase + DOT + tableName));
+      }
+    }
+    return dataSchemaList;
+  }
+
+ public TSStatus createTraining(TCreateTrainingReq req) {
+ TSStatus status = confirmLeader();
+ if (nodeManager.getRegisteredAINodes().isEmpty()) {
+ return new
TSStatus(TSStatusCode.NO_REGISTERED_AI_NODE_ERROR.getStatusCode())
+ .setMessage("There is no available AINode! Try to start one.");
+ }
+
+ TTrainingReq trainingReq = new TTrainingReq();
+ trainingReq.setModelId(req.getModelId());
+ trainingReq.setModelType("timer_xl");
+ if (req.existingModelId != null) {
+ trainingReq.setExistingModelId(req.getExistingModelId());
+ }
+ if (!req.parameters.isEmpty()) {
+ trainingReq.setParameters(req.getParameters());
+ }
+
+ try {
+ status = getConsensusManager().write(new
CreateModelPlan(req.getModelId()));
+ if (status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+ throw new MetadataException("Can't init model " + req.getModelId());
+ }
+
+ List<IDataSchema> dataSchema;
+ if (req.isTableModel) {
+ dataSchema = fetchSchemaForTableModel(req);
+ trainingReq.setDbType("iotdb.table");
+ } else {
+ dataSchema = fetchSchemaForTreeModel(req);
+ trainingReq.setDbType("iotdb.tree");
+ }
+ updateModelInfo(new TUpdateModelInfoReq(req.modelId,
ModelStatus.TRAINING.ordinal()));
+ trainingReq.setTargetDataSchema(dataSchema);
+
+ try (AINodeClient client =
+ AINodeClientManager.getInstance().borrowClient(AINodeInfo.endPoint))
{
+ status = client.createTrainingTask(trainingReq);
+ if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+ throw new IllegalArgumentException(status.message);
+ }
+ }
+ } catch (final Exception e) {
+ status.setCode(TSStatusCode.CAN_NOT_CONNECT_CONFIGNODE.getStatusCode());
+ status.setMessage(e.getMessage());
+ try {
+ updateModelInfo(new TUpdateModelInfoReq(req.modelId,
ModelStatus.UNAVAILABLE.ordinal()));
+ } catch (Exception e2) {
+ LOGGER.error(e2.getMessage());
+ }
+ }
+ return status;
+ }
+
@Override
public TSStatus dropModel(TDropModelReq req) {
TSStatus status = confirmLeader();
@@ -2601,6 +2706,13 @@ public class ConfigManager implements IManager {
: new TGetModelInfoResp(status);
}
+  /**
+   * Updates the stored information of a model, provided this node is currently the leader.
+   *
+   * @param req the update request, handed to the model manager unchanged
+   * @return the {@code confirmLeader()} status when it is not successful, otherwise the status
+   *     returned by the model manager
+   */
+  public TSStatus updateModelInfo(TUpdateModelInfoReq req) {
+    final TSStatus leaderStatus = confirmLeader();
+    if (leaderStatus.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+      return leaderStatus;
+    }
+    return modelManager.updateModelInfo(req);
+  }
+
@Override
public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) {
TSStatus status = confirmLeader();
diff --git
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
index 3c19dfdb14a..2ebc4d71dfc 100644
---
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
+++
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
@@ -20,9 +20,13 @@
package org.apache.iotdb.confignode.manager;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
+import org.apache.iotdb.commons.model.ModelInformation;
+import org.apache.iotdb.commons.model.ModelStatus;
import org.apache.iotdb.commons.model.ModelType;
import
org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan;
import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan;
+import
org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan;
+import
org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan;
import org.apache.iotdb.confignode.consensus.response.model.GetModelInfoResp;
import org.apache.iotdb.confignode.consensus.response.model.ModelTableResp;
import org.apache.iotdb.confignode.persistence.ModelInfo;
@@ -32,6 +36,7 @@ import
org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
import org.apache.iotdb.confignode.rpc.thrift.TShowModelReq;
import org.apache.iotdb.confignode.rpc.thrift.TShowModelResp;
+import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq;
import org.apache.iotdb.consensus.common.DataSet;
import org.apache.iotdb.consensus.exception.ConsensusException;
import org.apache.iotdb.rpc.TSStatusCode;
@@ -60,7 +65,18 @@ public class ModelManager {
return new TSStatus(TSStatusCode.MODEL_EXIST_ERROR.getStatusCode())
.setMessage(String.format("Model name %s already exists",
req.modelName));
}
- return configManager.getProcedureManager().createModel(req.modelName,
req.uri);
+ try {
+ if (req.uri.isEmpty()) {
+ return configManager.getConsensusManager().write(new
CreateModelPlan(req.modelName));
+ }
+ return configManager.getProcedureManager().createModel(req.modelName,
req.uri);
+ } catch (ConsensusException e) {
+ LOGGER.warn("Unexpected error happened while getting model: ", e);
+ // consensus layer related errors
+ TSStatus res = new
TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode());
+ res.setMessage(e.getMessage());
+ return res;
+ }
}
public TSStatus dropModel(TDropModelReq req) {
@@ -125,6 +141,39 @@ public class ModelManager {
}
}
+ // Currently this method is only used by built-in timer_xl
+ public TSStatus updateModelInfo(TUpdateModelInfoReq req) {
+ if (!modelInfo.contain(req.getModelId())) {
+ return new TSStatus(TSStatusCode.MODEL_NOT_FOUND_ERROR.getStatusCode())
+ .setMessage(String.format("Model %s doesn't exists",
req.getModelId()));
+ }
+ try {
+ ModelInformation modelInformation =
+ new ModelInformation(ModelType.USER_DEFINED, req.getModelId());
+
modelInformation.updateStatus(ModelStatus.values()[req.getModelStatus()]);
+ modelInformation.setAttribute(req.getAttributes());
+ modelInformation.setInputColumnSize(1);
+ if (req.isSetOutputLength()) {
+ modelInformation.setOutputLength(req.getOutputLength());
+ }
+ if (req.isSetInputLength()) {
+ modelInformation.setInputLength(req.getInputLength());
+ }
+ UpdateModelInfoPlan updateModelInfoPlan =
+ new UpdateModelInfoPlan(req.getModelId(), modelInformation);
+ if (req.isSetAiNodeIds()) {
+ updateModelInfoPlan.setNodeIds(req.getAiNodeIds());
+ }
+ return configManager.getConsensusManager().write(updateModelInfoPlan);
+ } catch (ConsensusException e) {
+ LOGGER.warn("Unexpected error happened while updating model info: ", e);
+ // consensus layer related errors
+ TSStatus res = new
TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode());
+ res.setMessage(e.getMessage());
+ return res;
+ }
+ }
+
public List<Integer> getModelDistributions(String modelName) {
return modelInfo.getNodeIds(modelName);
}
diff --git
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java
index 2137da48d0d..bba940b77e2 100644
---
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java
+++
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java
@@ -114,6 +114,7 @@ import
org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateTopicReq;
+import org.apache.iotdb.confignode.rpc.thrift.TCreateTrainingReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateTriggerReq;
import org.apache.iotdb.confignode.rpc.thrift.TDataNodeConfigurationResp;
import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRegisterReq;
@@ -217,6 +218,7 @@ import
org.apache.iotdb.confignode.rpc.thrift.TTestOperation;
import org.apache.iotdb.confignode.rpc.thrift.TThrottleQuotaResp;
import org.apache.iotdb.confignode.rpc.thrift.TUnsetSchemaTemplateReq;
import org.apache.iotdb.confignode.rpc.thrift.TUnsubscribeReq;
+import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq;
import org.apache.iotdb.confignode.service.ConfigNode;
import org.apache.iotdb.consensus.exception.ConsensusException;
import org.apache.iotdb.db.queryengine.plan.relational.type.AuthorRType;
@@ -1374,6 +1376,16 @@ public class ConfigNodeRPCServiceProcessor implements
IConfigNodeRPCService.Ifac
return configManager.getModelInfo(req);
}
+ /** Forwards the update-model-info request to {@code configManager} and returns its status unchanged. */
+ @Override
+ public TSStatus updateModelInfo(TUpdateModelInfoReq req) throws TException {
+ return configManager.updateModelInfo(req);
+ }
+
+ /** Forwards the create-training request to {@code configManager} and returns its status unchanged. */
+ @Override
+ public TSStatus createTraining(TCreateTrainingReq req) throws TException {
+ return configManager.createTraining(req);
+ }
+
@Override
public TSStatus setSpaceQuota(final TSetSpaceQuotaReq req) throws TException
{
return configManager.setSpaceQuota(req);
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java
index c12a63c4122..53361b8022f 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java
@@ -73,6 +73,7 @@ import
org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateTopicReq;
+import org.apache.iotdb.confignode.rpc.thrift.TCreateTrainingReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateTriggerReq;
import org.apache.iotdb.confignode.rpc.thrift.TDataNodeConfigurationResp;
import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRegisterReq;
@@ -176,6 +177,7 @@ import
org.apache.iotdb.confignode.rpc.thrift.TTestOperation;
import org.apache.iotdb.confignode.rpc.thrift.TThrottleQuotaResp;
import org.apache.iotdb.confignode.rpc.thrift.TUnsetSchemaTemplateReq;
import org.apache.iotdb.confignode.rpc.thrift.TUnsubscribeReq;
+import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq;
import org.apache.iotdb.db.conf.IoTDBConfig;
import org.apache.iotdb.db.conf.IoTDBDescriptor;
import org.apache.iotdb.rpc.DeepCopyRpcTransportFactory;
@@ -1285,6 +1287,18 @@ public class ConfigNodeClient implements
IConfigNodeRPCService.Iface, ThriftClie
() -> client.getModelInfo(req), resp ->
!updateConfigNodeLeader(resp.getStatus()));
}
+ /**
+ * Sends the update-model-info request to the ConfigNode via {@code executeRemoteCallWithRetry}.
+ *
+ * <p>The second lambda is applied to the returned status and yields the negation of {@code
+ * updateConfigNodeLeader(status)}. NOTE(review): presumably this tells the retry helper whether
+ * the result is final or the call must be repeated against a new leader - confirm against
+ * executeRemoteCallWithRetry.
+ */
+ @Override
+ public TSStatus updateModelInfo(TUpdateModelInfoReq req) throws TException {
+ return executeRemoteCallWithRetry(
+ () -> client.updateModelInfo(req), status ->
+ !updateConfigNodeLeader(status));
+ }
+
+ /**
+ * Sends the create-training request to the ConfigNode via {@code executeRemoteCallWithRetry}.
+ *
+ * <p>The second lambda is applied to the returned status and yields the negation of {@code
+ * updateConfigNodeLeader(status)}. NOTE(review): presumably this tells the retry helper whether
+ * the result is final or the call must be repeated against a new leader - confirm against
+ * executeRemoteCallWithRetry.
+ */
+ @Override
+ public TSStatus createTraining(TCreateTrainingReq req) throws TException {
+ return executeRemoteCallWithRetry(
+ () -> client.createTraining(req), status ->
+ !updateConfigNodeLeader(status));
+ }
+
@Override
public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) throws TException {
return executeRemoteCallWithRetry(
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/Coordinator.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/Coordinator.java
index 5b3e283ec7b..39f07c4c62d 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/Coordinator.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/Coordinator.java
@@ -64,6 +64,7 @@ import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ClearCache;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreateDB;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreateFunction;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreateTable;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreateTraining;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DeleteDevice;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DescribeTable;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DropColumn;
@@ -98,6 +99,7 @@ import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowCurrentUser;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowDB;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowDataNodes;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowFunctions;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowModels;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowRegions;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowTables;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowVariables;
@@ -443,6 +445,8 @@ public class Coordinator {
|| statement instanceof MigrateRegion
|| statement instanceof ReconstructRegion
|| statement instanceof ExtendRegion
+ || statement instanceof CreateTraining
+ || statement instanceof ShowModels
|| statement instanceof RemoveRegion) {
return new ConfigExecution(
queryContext,
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
index a90ed1e7b46..ce986ad8083 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
@@ -50,6 +50,8 @@ import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowFuncti
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowPipePluginsTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowRegionTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowVariablesTask;
+import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateTrainingTask;
+import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.ShowModelsTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.region.ExtendRegionTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.region.MigrateRegionTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.region.ReconstructRegionTask;
@@ -119,6 +121,7 @@ import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreatePipe;
import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreatePipePlugin;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreateTable;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreateTopic;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreateTraining;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DataType;
import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DatabaseStatement;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DeleteDevice;
@@ -165,6 +168,7 @@ import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowCurrentUser;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowDB;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowDataNodes;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowFunctions;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowModels;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowPipePlugins;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowPipes;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowRegions;
@@ -198,6 +202,7 @@ import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.utils.Binary;
import org.apache.tsfile.utils.Pair;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
@@ -1214,4 +1219,35 @@ public class TableConfigTaskVisitor extends
AstVisitor<IConfigTask, MPPQueryCont
// corresponding tree-model variant and execute that.
return new RemoveRegionTask(removeRegion);
}
+
+ @Override
+ protected IConfigTask visitCreateTraining(CreateTraining node,
MPPQueryContext context) {
+ context.setQueryType(QueryType.WRITE);
+
+ String curDatabase = clientSession.getDatabaseName();
+ List<String> tableList = new ArrayList<>();
+ for (QualifiedName tableName : node.getTargetTables()) {
+ List<String> parts = tableName.getParts();
+ if (parts.size() == 1) {
+ tableList.add(curDatabase + "." + parts.get(0));
+ } else {
+ tableList.add(parts.get(1) + "." + parts.get(0));
+ }
+ }
+
+ return new CreateTrainingTask(
+ node.getModelId(),
+ node.getModelType(),
+ node.getParameters(),
+ node.isUseAllData(),
+ node.getTargetTimeRanges(),
+ node.getExistingModelId(),
+ node.getTargetDbs(),
+ tableList);
+ }
+
+ /** Creates a {@code ShowModelsTask} for the model id carried by the node; the context is not used. */
+ @Override
+ protected IConfigTask visitShowModels(ShowModels node, MPPQueryContext
context) {
+ return new ShowModelsTask(node.getModelId());
+ }
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java
index fd8243377e6..c381d31857a 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java
@@ -21,6 +21,7 @@ package org.apache.iotdb.db.queryengine.plan.execution.config;
import org.apache.iotdb.common.rpc.thrift.Model;
import org.apache.iotdb.commons.executable.ExecutableManager;
+import org.apache.iotdb.commons.path.PartialPath;
import org.apache.iotdb.commons.pipe.config.constant.SystemConstant;
import org.apache.iotdb.db.exception.sql.SemanticException;
import org.apache.iotdb.db.queryengine.common.MPPQueryContext;
@@ -58,9 +59,10 @@ import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowTTLTas
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowTriggersTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowVariablesTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.UnSetTTLTask;
-import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.model.CreateModelTask;
-import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.model.DropModelTask;
-import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.model.ShowModelsTask;
+import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateModelTask;
+import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateTrainingTask;
+import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.DropModelTask;
+import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.ShowModelsTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.region.ExtendRegionTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.region.MigrateRegionTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.region.ReconstructRegionTask;
@@ -138,6 +140,7 @@ import
org.apache.iotdb.db.queryengine.plan.statement.metadata.ShowTriggersState
import
org.apache.iotdb.db.queryengine.plan.statement.metadata.ShowVariablesStatement;
import
org.apache.iotdb.db.queryengine.plan.statement.metadata.UnSetTTLStatement;
import
org.apache.iotdb.db.queryengine.plan.statement.metadata.model.CreateModelStatement;
+import
org.apache.iotdb.db.queryengine.plan.statement.metadata.model.CreateTrainingStatement;
import
org.apache.iotdb.db.queryengine.plan.statement.metadata.model.DropModelStatement;
import
org.apache.iotdb.db.queryengine.plan.statement.metadata.model.ShowAINodesStatement;
import
org.apache.iotdb.db.queryengine.plan.statement.metadata.model.ShowModelsStatement;
@@ -191,6 +194,9 @@ import
org.apache.iotdb.db.queryengine.plan.statement.sys.quota.ShowThrottleQuot
import org.apache.tsfile.exception.NotImplementedException;
+import java.util.ArrayList;
+import java.util.List;
+
import static
org.apache.iotdb.commons.executable.ExecutableManager.getUnTrustedUriErrorMsg;
import static
org.apache.iotdb.commons.executable.ExecutableManager.isUriTrusted;
@@ -758,4 +764,22 @@ public class TreeConfigTaskVisitor extends
StatementVisitor<IConfigTask, MPPQuer
ShowCurrentSqlDialectStatement node, MPPQueryContext context) {
return new
ShowCurrentSqlDialectTask(context.getSession().getSqlDialect().name());
}
+
+  /**
+   * Turns a tree-model {@code CreateTrainingStatement} into a {@link CreateTrainingTask}.
+   *
+   * <p>Each target path pattern is converted to its full-path string. The task is always created
+   * with the "use all data" flag set to {@code false}.
+   *
+   * @param createTrainingStatement the parsed statement
+   * @param context the query context (not used here)
+   * @return the task that carries the statement's settings and the path strings
+   */
+  @Override
+  public IConfigTask visitCreateTraining(
+      CreateTrainingStatement createTrainingStatement, MPPQueryContext context) {
+    final List<PartialPath> patterns = createTrainingStatement.getTargetPathPatterns();
+    final List<String> fullPaths = new ArrayList<>();
+    for (PartialPath pattern : patterns) {
+      final String fullPath = pattern.getFullPath();
+      fullPaths.add(fullPath);
+    }
+    return new CreateTrainingTask(
+        createTrainingStatement.getModelId(),
+        createTrainingStatement.getModelType(),
+        createTrainingStatement.getParameters(),
+        false,
+        createTrainingStatement.getTargetTimeRanges(),
+        createTrainingStatement.getExistingModelId(),
+        fullPaths);
+  }
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
index 299b8d80420..3f8dc3f1f21 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
@@ -82,9 +82,12 @@ import
org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateTopicReq;
+import org.apache.iotdb.confignode.rpc.thrift.TCreateTrainingReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateTriggerReq;
import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRemoveReq;
import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRemoveResp;
+import org.apache.iotdb.confignode.rpc.thrift.TDataSchemaForTable;
+import org.apache.iotdb.confignode.rpc.thrift.TDataSchemaForTree;
import org.apache.iotdb.confignode.rpc.thrift.TDatabaseSchema;
import org.apache.iotdb.confignode.rpc.thrift.TDeactivateSchemaTemplateReq;
import org.apache.iotdb.confignode.rpc.thrift.TDeleteDatabasesReq;
@@ -187,7 +190,7 @@ import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowRegion
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowTTLTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowTriggersTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowVariablesTask;
-import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.model.ShowModelsTask;
+import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.ShowModelsTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.region.ExtendRegionTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.region.MigrateRegionTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.region.ReconstructRegionTask;
@@ -3164,6 +3167,49 @@ public class ClusterConfigTaskExecutor implements
IConfigTaskExecutor {
return future;
}
+  /**
+   * Builds a {@code TCreateTrainingReq} from the arguments and sends it to the ConfigNode.
+   *
+   * <p>For a table-model request the table and database lists are attached; for a tree-model
+   * request the path list is attached. The nullable lists are passed on as they are.
+   *
+   * @return a future completed with {@code SUCCESS_STATUS} when the ConfigNode reports success,
+   *     or completed exceptionally with an {@code IoTDBException} carrying the returned message
+   *     and code, or with the client/transport exception raised while calling the ConfigNode
+   */
+  @Override
+  public SettableFuture<ConfigTaskResult> createTraining(
+      String modelId,
+      String modelType,
+      boolean isTableModel,
+      Map<String, String> parameters,
+      boolean useAllData,
+      List<List<Long>> timeRanges,
+      String existingModelId,
+      @Nullable List<String> tableList,
+      @Nullable List<String> databaseList,
+      @Nullable List<String> pathList) {
+    final SettableFuture<ConfigTaskResult> result = SettableFuture.create();
+    try (final ConfigNodeClient configNodeClient =
+        CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) {
+      final TCreateTrainingReq request = new TCreateTrainingReq(modelId, modelType, isTableModel);
+      request.setParameters(parameters);
+      request.setUseAllData(useAllData);
+      request.setTimeRanges(timeRanges);
+      request.setExistingModelId(existingModelId);
+
+      // Exactly one of the two schema descriptions is attached, chosen by the model kind.
+      if (!isTableModel) {
+        final TDataSchemaForTree treeSchema = new TDataSchemaForTree();
+        treeSchema.setPath(pathList);
+        request.setDataSchemaForTree(treeSchema);
+      } else {
+        final TDataSchemaForTable tableSchema = new TDataSchemaForTable();
+        tableSchema.setTableList(tableList);
+        tableSchema.setDatabaseList(databaseList);
+        request.setDataSchemaForTable(tableSchema);
+      }
+
+      final TSStatus response = configNodeClient.createTraining(request);
+      if (TSStatusCode.SUCCESS_STATUS.getStatusCode() == response.getCode()) {
+        result.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS));
+      } else {
+        result.setException(new IoTDBException(response.message, response.code));
+      }
+    } catch (final ClientManagerException | TException e) {
+      result.setException(e);
+    }
+    return result;
+  }
+
@Override
public SettableFuture<ConfigTaskResult> setSpaceQuota(
final SetSpaceQuotaStatement setSpaceQuotaStatement) {
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java
index f66a078b52a..6e844322a9e 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java
@@ -288,13 +288,6 @@ public interface IConfigTaskExecutor {
TThrottleQuotaResp getThrottleQuota();
- SettableFuture<ConfigTaskResult> createModel(
- CreateModelStatement createModelStatement, MPPQueryContext context);
-
- SettableFuture<ConfigTaskResult> dropModel(String modelName);
-
- SettableFuture<ConfigTaskResult> showModels(String modelName);
-
TPipeTransferResp handleTransferConfigPlan(String clientId, TPipeTransferReq
req);
void handlePipeConfigClientExit(String clientId);
@@ -399,4 +392,24 @@ public interface IConfigTaskExecutor {
SettableFuture<ConfigTaskResult> showCurrentDatabase(@Nullable String
currentDatabase);
SettableFuture<ConfigTaskResult> showCurrentTimestamp();
+
+ // =============================== AI
=========================================
+ SettableFuture<ConfigTaskResult> createModel(
+ CreateModelStatement createModelStatement, MPPQueryContext context);
+
+ SettableFuture<ConfigTaskResult> dropModel(String modelName);
+
+ SettableFuture<ConfigTaskResult> showModels(String modelName);
+
+ /**
+  * Creates a model training task and reports the outcome through the returned future.
+  *
+  * <p>The three {@code @Nullable} lists describe the training data; which of them is populated
+  * depends on {@code isTableModel}.
+  *
+  * <p>NOTE(review): presumably tableList/databaseList are used for the table model and pathList
+  * for the tree model - confirm against each implementation.
+  */
+ SettableFuture<ConfigTaskResult> createTraining(
+ String modelId,
+ String modelType,
+ boolean isTableModel,
+ Map<String, String> parameters,
+ boolean useAllData,
+ List<List<Long>> timeRanges,
+ String existingModelId,
+ @Nullable List<String> tableList,
+ @Nullable List<String> databaseList,
+ @Nullable List<String> pathList);
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/CreateModelTask.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateModelTask.java
similarity index 99%
rename from
iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/CreateModelTask.java
rename to
iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateModelTask.java
index 842b558fd8a..875455529b5 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/CreateModelTask.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateModelTask.java
@@ -17,7 +17,7 @@
* under the License.
*/
-package org.apache.iotdb.db.queryengine.plan.execution.config.metadata.model;
+package org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai;
import org.apache.iotdb.db.queryengine.common.MPPQueryContext;
import org.apache.iotdb.db.queryengine.plan.execution.config.ConfigTaskResult;
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
new file mode 100644
index 00000000000..84a6aa45f6d
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai;
+
+import org.apache.iotdb.db.queryengine.plan.execution.config.ConfigTaskResult;
+import org.apache.iotdb.db.queryengine.plan.execution.config.IConfigTask;
+import
org.apache.iotdb.db.queryengine.plan.execution.config.executor.IConfigTaskExecutor;
+
+import com.google.common.util.concurrent.ListenableFuture;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Config task that hands a model-training request over to an {@link IConfigTaskExecutor}.
+ *
+ * <p>A task is built either for the table model (target tables and databases) or for the tree
+ * model (target paths). The lists belonging to the other model are left {@code null} and are
+ * passed to the executor as such.
+ */
+public class CreateTrainingTask implements IConfigTask {
+
+  /** The only model type accepted by this task (compared case-insensitively). */
+  private static final String SUPPORTED_MODEL_TYPE = "timer_xl";
+
+  private final String modelId;
+  private final String modelType;
+  private final boolean isTableModel;
+  private final Map<String, String> parameters;
+  private final boolean useAllData;
+  private final List<List<Long>> timeRanges;
+  private final String existingModelId;
+
+  // Data schema for table model
+  private final List<String> targetTables;
+  private final List<String> targetDbs;
+  // Data schema for tree model
+  private final List<String> targetPaths;
+
+  /**
+   * Creates a training task for the table model.
+   *
+   * @throws UnsupportedOperationException if {@code modelType} is not {@code timer_xl}
+   */
+  public CreateTrainingTask(
+      String modelId,
+      String modelType,
+      Map<String, String> parameters,
+      boolean useAllData,
+      List<List<Long>> timeRanges,
+      String existingModelId,
+      List<String> targetTables,
+      List<String> targetDbs) {
+    this(
+        modelId,
+        modelType,
+        true,
+        parameters,
+        useAllData,
+        timeRanges,
+        existingModelId,
+        targetTables,
+        targetDbs,
+        null);
+  }
+
+  /**
+   * Creates a training task for the tree model.
+   *
+   * @throws UnsupportedOperationException if {@code modelType} is not {@code timer_xl}
+   */
+  public CreateTrainingTask(
+      String modelId,
+      String modelType,
+      Map<String, String> parameters,
+      boolean useAllData,
+      List<List<Long>> timeRanges,
+      String existingModelId,
+      List<String> targetPaths) {
+    this(
+        modelId,
+        modelType,
+        false,
+        parameters,
+        useAllData,
+        timeRanges,
+        existingModelId,
+        null,
+        null,
+        targetPaths);
+  }
+
+  /** Shared initialisation so that validation and field assignment live in one place. */
+  private CreateTrainingTask(
+      String modelId,
+      String modelType,
+      boolean isTableModel,
+      Map<String, String> parameters,
+      boolean useAllData,
+      List<List<Long>> timeRanges,
+      String existingModelId,
+      List<String> targetTables,
+      List<String> targetDbs,
+      List<String> targetPaths) {
+    // Constant-first comparison: a null model type is rejected like any other unsupported type
+    // instead of failing with a NullPointerException.
+    if (!SUPPORTED_MODEL_TYPE.equalsIgnoreCase(modelType)) {
+      throw new UnsupportedOperationException("Only TimerXL model is supported now.");
+    }
+    this.modelId = modelId;
+    this.modelType = modelType;
+    this.isTableModel = isTableModel;
+    this.parameters = parameters;
+    this.useAllData = useAllData;
+    this.timeRanges = timeRanges;
+    this.existingModelId = existingModelId;
+    this.targetTables = targetTables;
+    this.targetDbs = targetDbs;
+    this.targetPaths = targetPaths;
+  }
+
+  /** Delegates to {@link IConfigTaskExecutor#createTraining} with the captured arguments. */
+  @Override
+  public ListenableFuture<ConfigTaskResult> execute(IConfigTaskExecutor configTaskExecutor)
+      throws InterruptedException {
+    return configTaskExecutor.createTraining(
+        modelId,
+        modelType,
+        isTableModel,
+        parameters,
+        useAllData,
+        timeRanges,
+        existingModelId,
+        targetTables,
+        targetDbs,
+        targetPaths);
+  }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/DropModelTask.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/DropModelTask.java
similarity index 99%
rename from
iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/DropModelTask.java
rename to
iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/DropModelTask.java
index f8db88790d4..688e413cf85 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/DropModelTask.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/DropModelTask.java
@@ -17,7 +17,7 @@
* under the License.
*/
-package org.apache.iotdb.db.queryengine.plan.execution.config.metadata.model;
+package org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai;
import org.apache.iotdb.db.queryengine.plan.execution.config.ConfigTaskResult;
import org.apache.iotdb.db.queryengine.plan.execution.config.IConfigTask;
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/ShowModelsTask.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/ShowModelsTask.java
similarity index 99%
rename from
iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/ShowModelsTask.java
rename to
iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/ShowModelsTask.java
index 83bafbccee3..120a0737a38 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/ShowModelsTask.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/ShowModelsTask.java
@@ -17,7 +17,7 @@
* under the License.
*/
-package org.apache.iotdb.db.queryengine.plan.execution.config.metadata.model;
+package org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai;
import org.apache.iotdb.commons.model.ModelType;
import org.apache.iotdb.commons.schema.column.ColumnHeader;
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java
index af5afd4c0c9..1ece2b8510e 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java
@@ -172,6 +172,7 @@ import
org.apache.iotdb.db.queryengine.plan.statement.metadata.ShowTriggersState
import
org.apache.iotdb.db.queryengine.plan.statement.metadata.ShowVariablesStatement;
import
org.apache.iotdb.db.queryengine.plan.statement.metadata.UnSetTTLStatement;
import
org.apache.iotdb.db.queryengine.plan.statement.metadata.model.CreateModelStatement;
+import
org.apache.iotdb.db.queryengine.plan.statement.metadata.model.CreateTrainingStatement;
import
org.apache.iotdb.db.queryengine.plan.statement.metadata.model.DropModelStatement;
import
org.apache.iotdb.db.queryengine.plan.statement.metadata.model.ShowAINodesStatement;
import
org.apache.iotdb.db.queryengine.plan.statement.metadata.model.ShowModelsStatement;
@@ -1344,9 +1345,49 @@ public class ASTVisitor extends
IoTDBSqlParserBaseVisitor<Statement> {
@Override
public Statement visitCreateModel(IoTDBSqlParser.CreateModelContext ctx) {
- CreateModelStatement createModelStatement = new CreateModelStatement();
+ if (ctx.modelName == null) {
+ String modelId = ctx.modelId.getText();
+ String modelType = ctx.modelType.getText();
+ CreateTrainingStatement createTrainingStatement =
+ new CreateTrainingStatement(modelId, modelType);
+ if (ctx.hparamPair() != null) {
+ Map<String, String> parameterList = new HashMap<>();
+ for (IoTDBSqlParser.HparamPairContext hparamPairContext :
ctx.hparamPair()) {
+ parameterList.put(
+ hparamPairContext.hparamKey.getText(),
hparamPairContext.hparamValue().getText());
+ }
+ createTrainingStatement.setParameters(parameterList);
+ }
+
+ if (ctx.existingModelId != null) {
+
createTrainingStatement.setExistingModelId(ctx.existingModelId.getText());
+ }
+
+ if (ctx.trainingData() == null) {
+ throw new UnsupportedOperationException("data should not be set for
model training");
+ }
+
+ List<PartialPath> targetPath = new ArrayList<>();
+ List<List<Long>> timeRanges = new ArrayList<>();
+ for (IoTDBSqlParser.DataElementContext dataElementContext :
+ ctx.trainingData().dataElement()) {
+ if (dataElementContext.timeRange() != null) {
+ long currentTime = CommonDateTimeUtils.currentTime();
+ long startTime =
parseTimeValue(dataElementContext.timeRange().timeValue(0), currentTime);
+ long endTime =
parseTimeValue(dataElementContext.timeRange().timeValue(1), currentTime);
+ timeRanges.add(Arrays.asList(startTime, endTime));
+ } else {
+ timeRanges.add(Collections.emptyList());
+ }
+
targetPath.add(parsePrefixPath(dataElementContext.pathPatternElement().prefixPath()));
+ }
+ createTrainingStatement.setTargetTimeRanges(timeRanges);
+ createTrainingStatement.setTargetPathPatterns(targetPath);
+ return createTrainingStatement;
+ }
String modelName = ctx.modelName.getText();
validateModelName(modelName);
+ CreateModelStatement createModelStatement = new CreateModelStatement();
createModelStatement.setModelName(parseIdentifier(modelName));
createModelStatement.setUri(parseAndValidateURI(ctx.uriClause()));
return createModelStatement;
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/AstVisitor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/AstVisitor.java
index ff076f3e1ed..7175853fb5b 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/AstVisitor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/AstVisitor.java
@@ -685,6 +685,14 @@ public abstract class AstVisitor<R, C> {
return visitStatement(node, context);
}
+ /** Visits a {@link CreateTraining} node; the default implementation delegates to visitStatement. */
+ protected R visitCreateTraining(CreateTraining node, C context) {
+ return visitStatement(node, context);
+ }
+
+ /** Visits a {@link ShowModels} node; the default implementation delegates to visitStatement. */
+ protected R visitShowModels(ShowModels node, C context) {
+ return visitStatement(node, context);
+ }
+
public R visitTableArgument(TableFunctionTableArgument
tableFunctionTableArgument, C context) {
return visitNode(tableFunctionTableArgument, context);
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/CreateTraining.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/CreateTraining.java
new file mode 100644
index 00000000000..3c978ccb5c6
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/CreateTraining.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.sql.ast;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * AST node describing a request to train a model.
+ *
+ * <p>The model id and model type are fixed at construction. Hyperparameters, the id of an
+ * existing model, the target databases/tables, their time ranges and the "use all data" flag are
+ * filled in afterwards through setters and stay {@code null} (or {@code false}) until then.
+ */
+public class CreateTraining extends Statement {
+
+  private final String modelId;
+  private final String modelType;
+
+  private Map<String, String> parameters;
+  private String existingModelId = null;
+
+  private List<QualifiedName> targetTables;
+  private List<String> targetDbs;
+
+  private List<List<Long>> targetTimeRanges;
+  private boolean useAllData = false;
+
+  public CreateTraining(String modelId, String modelType) {
+    super(null);
+    this.modelId = modelId;
+    this.modelType = modelType;
+  }
+
+  @Override
+  public <R, C> R accept(AstVisitor<R, C> visitor, C context) {
+    return visitor.visitCreateTraining(this, context);
+  }
+
+  public void setParameters(Map<String, String> parameters) {
+    this.parameters = parameters;
+  }
+
+  public void setExistingModelId(String existingModelId) {
+    this.existingModelId = existingModelId;
+  }
+
+  public void setTargetDbs(List<String> targetDbs) {
+    this.targetDbs = targetDbs;
+  }
+
+  public void setTargetTables(List<QualifiedName> targetTables) {
+    this.targetTables = targetTables;
+  }
+
+  public void setUseAllData(boolean useAllData) {
+    this.useAllData = useAllData;
+  }
+
+  public List<String> getTargetDbs() {
+    return targetDbs;
+  }
+
+  public List<QualifiedName> getTargetTables() {
+    return targetTables;
+  }
+
+  public String getModelId() {
+    return modelId;
+  }
+
+  public String getModelType() {
+    return modelType;
+  }
+
+  public Map<String, String> getParameters() {
+    return parameters;
+  }
+
+  public String getExistingModelId() {
+    return existingModelId;
+  }
+
+  public boolean isUseAllData() {
+    return useAllData;
+  }
+
+  public void setTargetTimeRanges(List<List<Long>> targetTimeRanges) {
+    this.targetTimeRanges = targetTimeRanges;
+  }
+
+  public List<List<Long>> getTargetTimeRanges() {
+    return targetTimeRanges;
+  }
+
+  /**
+   * This node has no child nodes.
+   *
+   * @return an empty list (never {@code null}), so generic tree walkers can iterate safely
+   */
+  @Override
+  public List<? extends Node> getChildren() {
+    // Fully qualified to keep this file's import block untouched.
+    return java.util.Collections.emptyList();
+  }
+
+  /** Hashes every field that {@link #equals(Object)} compares. */
+  @Override
+  public int hashCode() {
+    return Objects.hash(
+        modelId,
+        modelType,
+        existingModelId,
+        parameters,
+        targetTables,
+        targetDbs,
+        targetTimeRanges,
+        useAllData);
+  }
+
+  /**
+   * Two nodes are equal when all of their fields are equal, including the training targets
+   * (tables and databases), so statements over different data are never considered the same.
+   */
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj) {
+      return true;
+    }
+    if (!(obj instanceof CreateTraining)) {
+      return false;
+    }
+    CreateTraining that = (CreateTraining) obj;
+    return useAllData == that.useAllData
+        && Objects.equals(modelId, that.modelId)
+        && Objects.equals(modelType, that.modelType)
+        && Objects.equals(existingModelId, that.existingModelId)
+        && Objects.equals(parameters, that.parameters)
+        && Objects.equals(targetTables, that.targetTables)
+        && Objects.equals(targetDbs, that.targetDbs)
+        && Objects.equals(targetTimeRanges, that.targetTimeRanges);
+  }
+
+  @Override
+  public String toString() {
+    return "CreateTraining{"
+        + "modelId='"
+        + modelId
+        + '\''
+        + ", modelType='"
+        + modelType
+        + '\''
+        + ", parameters="
+        + parameters
+        + ", existingModelId='"
+        + existingModelId
+        + '\''
+        + ", targetTables="
+        + targetTables
+        + ", targetDbs="
+        + targetDbs
+        + ", targetTimeRanges="
+        + targetTimeRanges
+        + ", useAllData="
+        + useAllData
+        + '}';
+  }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/ShowModels.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/ShowModels.java
new file mode 100644
index 00000000000..6032a29a246
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/ShowModels.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.sql.ast;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+import static com.google.common.base.MoreObjects.toStringHelper;
+
+public class ShowModels extends Statement {
+
+ private String modelId = null;
+
+ public ShowModels() {
+ super(null);
+ }
+
+ @Override
+ public List<? extends Node> getChildren() {
+ return ImmutableList.of();
+ }
+
+ public void setModelId(String modelId) {
+ this.modelId = modelId;
+ }
+
+ public String getModelId() {
+ return modelId;
+ }
+
+ @Override
+ public <R, C> R accept(AstVisitor<R, C> visitor, C context) {
+ return visitor.visitShowModels(this, context);
+ }
+
+ @Override
+ public int hashCode() {
+ return 0;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (!(obj instanceof ShowModels)) {
+ return false;
+ }
+ if (modelId == null) {
+ return ((ShowModels) obj).getModelId() == null;
+ }
+ return modelId.equals(((ShowModels) obj).getModelId());
+ }
+
+ @Override
+ public String toString() {
+ return toStringHelper(this).toString();
+ }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java
index 3a4e247ce58..a1c4aba7ab4 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java
@@ -60,6 +60,7 @@ import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreatePipe;
import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreatePipePlugin;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreateTable;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreateTopic;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreateTraining;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CurrentDatabase;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CurrentTime;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CurrentUser;
@@ -159,6 +160,7 @@ import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowDataNodes;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowDevice;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowFunctions;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowIndex;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowModels;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowPipePlugins;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowPipes;
import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowQueriesStatement;
@@ -2896,6 +2898,89 @@ public class AstBuilder extends
RelationalSqlBaseVisitor<Node> {
return super.visitIntervalField(ctx);
}
+ // ***************** AI *****************
+  /**
+   * Checks a model name and rejects it with a {@link SemanticException} when it is shorter than
+   * 2 or longer than 64 characters, starts with an underscore, or contains a character outside
+   * the pattern below.
+   *
+   * <p>NOTE(review): the pattern also admits '-', which the last error message does not mention.
+   *
+   * @param modelName the name to check
+   */
+  public static void validateModelName(String modelName) {
+    final int nameLength = modelName.length();
+    if (nameLength < 2 || nameLength > 64) {
+      throw new SemanticException("Model name should be 2-64 characters");
+    }
+    if (modelName.startsWith("_")) {
+      throw new SemanticException("Model name should not start with '_'");
+    }
+    final boolean onlyAllowedCharacters = modelName.matches("^[-\\w]*$");
+    if (!onlyAllowedCharacters) {
+      throw new SemanticException("ModelName can only contain letters, numbers, and underscores");
+    }
+  }
+
+ private List<Long> parseTimePair(RelationalSqlParser.TimeRangeContext
timeRangeContext) {
+ long currentTime = CommonDateTimeUtils.currentTime();
+ List<Long> timeRange = new ArrayList<>();
+ timeRange.add(parseTimeValue(timeRangeContext.timeValue(0), currentTime));
+ timeRange.add(parseTimeValue(timeRangeContext.timeValue(1), currentTime));
+ return timeRange;
+ }
+
+ /**
+  * Builds a {@link CreateTraining} node from a create-model statement.
+  *
+  * <p>The model id is validated with validateModelName. Hyperparameters and the existing model
+  * id are copied only when present. When the training data is declared as ALL, only the
+  * use-all-data flag is set and no targets or time ranges are recorded; otherwise the listed
+  * databases and tables are collected together with their time ranges.
+  *
+  * @throws IllegalArgumentException if neither a database nor a table is listed
+  */
+ @Override
+ public Node
visitCreateModelStatement(RelationalSqlParser.CreateModelStatementContext ctx) {
+ String modelId = ctx.modelId.getText();
+ validateModelName(modelId);
+ String modelType = ctx.modelType.getText();
+ CreateTraining createTraining = new CreateTraining(modelId, modelType);
+ // Hyperparameter keys and values are stored as their raw token text.
+ if (ctx.HYPERPARAMETERS() != null) {
+ Map<String, String> parameters = new HashMap<>();
+ for (RelationalSqlParser.HparamPairContext hparamPairContext :
ctx.hparamPair()) {
+ parameters.put(
+ hparamPairContext.hparamKey.getText(),
hparamPairContext.hyparamValue.getText());
+ }
+ createTraining.setParameters(parameters);
+ }
+
+ if (ctx.existingModelId != null) {
+ createTraining.setExistingModelId(ctx.existingModelId.getText());
+ }
+
+ // Time ranges are gathered per kind and concatenated further down, databases first.
+ // NOTE(review): a range is appended only for elements that declare one, so the combined list
+ // is not index-aligned with targetDbs/targetTables - confirm this is what the consumer expects.
+ List<List<Long>> dbTimeRange = new ArrayList<>();
+ List<List<Long>> tableTimeRange = new ArrayList<>();
+ if (ctx.trainingData().ALL() != null) {
+ createTraining.setUseAllData(true);
+ } else {
+ List<QualifiedName> targetTables = new ArrayList<>();
+ List<String> targetDbs = new ArrayList<>();
+ for (RelationalSqlParser.DataElementContext dataElementContext :
+ ctx.trainingData().dataElement()) {
+ // Each data element is either a database element or, failing that, a table element.
+ if (dataElementContext.databaseElement() != null) {
+ targetDbs.add(
+ ((Identifier)
visit(dataElementContext.databaseElement().database)).getValue());
+ if (dataElementContext.databaseElement().timeRange() != null) {
+
dbTimeRange.add(parseTimePair(dataElementContext.databaseElement().timeRange()));
+ }
+ } else {
+
targetTables.add(getQualifiedName(dataElementContext.tableElement().qualifiedName()));
+ if (dataElementContext.tableElement().timeRange() != null) {
+
tableTimeRange.add(parseTimePair(dataElementContext.tableElement().timeRange()));
+ }
+ }
+ }
+
+ if (targetDbs.isEmpty() && targetTables.isEmpty()) {
+ throw new IllegalArgumentException(
+ "No training data is supported for model, please indicate database
or table");
+ }
+ createTraining.setTargetDbs(targetDbs);
+ createTraining.setTargetTables(targetTables);
+
+ // dbTimeRange is reused as the combined list: database ranges followed by table ranges.
+ dbTimeRange.addAll(tableTimeRange);
+ createTraining.setTargetTimeRanges(dbTimeRange);
+ }
+ return createTraining;
+ }
+
+  /**
+   * Builds a {@link ShowModels} node, copying the model id token text when the statement
+   * provides one.
+   */
+  @Override
+  public Node visitShowModelsStatement(RelationalSqlParser.ShowModelsStatementContext ctx) {
+    final ShowModels statement = new ShowModels();
+    if (ctx.modelId == null) {
+      // No model id given: the node keeps its default (null) id.
+      return statement;
+    }
+    statement.setModelId(ctx.modelId.getText());
+    return statement;
+  }
+
// ***************** arguments *****************
@Override
public Node visitGenericType(RelationalSqlParser.GenericTypeContext ctx) {
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/StatementVisitor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/StatementVisitor.java
index 1ff54dff369..4ecfcb65079 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/StatementVisitor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/StatementVisitor.java
@@ -76,6 +76,7 @@ import
org.apache.iotdb.db.queryengine.plan.statement.metadata.ShowTriggersState
import
org.apache.iotdb.db.queryengine.plan.statement.metadata.ShowVariablesStatement;
import
org.apache.iotdb.db.queryengine.plan.statement.metadata.UnSetTTLStatement;
import
org.apache.iotdb.db.queryengine.plan.statement.metadata.model.CreateModelStatement;
+import
org.apache.iotdb.db.queryengine.plan.statement.metadata.model.CreateTrainingStatement;
import
org.apache.iotdb.db.queryengine.plan.statement.metadata.model.DropModelStatement;
import
org.apache.iotdb.db.queryengine.plan.statement.metadata.model.ShowAINodesStatement;
import
org.apache.iotdb.db.queryengine.plan.statement.metadata.model.ShowModelsStatement;
@@ -680,4 +681,8 @@ public abstract class StatementVisitor<R, C> {
public R visitShowCurrentUser(ShowCurrentUserStatement
showCurrentUserStatement, C context) {
return visitStatement(showCurrentUserStatement, context);
}
+
+ /** Visits a {@link CreateTrainingStatement}; the default implementation delegates to visitStatement. */
+ public R visitCreateTraining(CreateTrainingStatement
createTrainingStatement, C context) {
+ return visitStatement(createTrainingStatement, context);
+ }
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/metadata/model/CreateTrainingStatement.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/metadata/model/CreateTrainingStatement.java
new file mode 100644
index 00000000000..6f1dd4735d9
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/metadata/model/CreateTrainingStatement.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.statement.metadata.model;
+
+import org.apache.iotdb.commons.path.PartialPath;
+import org.apache.iotdb.db.queryengine.plan.analyze.QueryType;
+import org.apache.iotdb.db.queryengine.plan.statement.IConfigStatement;
+import org.apache.iotdb.db.queryengine.plan.statement.Statement;
+import org.apache.iotdb.db.queryengine.plan.statement.StatementVisitor;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Statement describing a request to train a model on data selected by path patterns.
+ *
+ * <p>The model id and model type are fixed at construction. Hyperparameters, the id of an
+ * existing model, the target path patterns and their time ranges are supplied afterwards through
+ * setters and stay {@code null} until then.
+ */
+public class CreateTrainingStatement extends Statement implements IConfigStatement {
+
+  private final String modelId;
+  private final String modelType;
+
+  private Map<String, String> parameters;
+  private String existingModelId = null;
+
+  private List<PartialPath> targetPathPatterns;
+  private List<List<Long>> targetTimeRanges;
+
+  public CreateTrainingStatement(String modelId, String modelType) {
+    this.modelId = modelId;
+    this.modelType = modelType;
+  }
+
+  public void setTargetPathPatterns(List<PartialPath> targetPathPatterns) {
+    this.targetPathPatterns = targetPathPatterns;
+  }
+
+  public Map<String, String> getParameters() {
+    return parameters;
+  }
+
+  public String getExistingModelId() {
+    return existingModelId;
+  }
+
+  public List<PartialPath> getTargetPathPatterns() {
+    return targetPathPatterns;
+  }
+
+  public String getModelId() {
+    return modelId;
+  }
+
+  public String getModelType() {
+    return modelType;
+  }
+
+  public void setExistingModelId(String existingModelId) {
+    this.existingModelId = existingModelId;
+  }
+
+  public void setTargetTimeRanges(List<List<Long>> targetTimeRanges) {
+    this.targetTimeRanges = targetTimeRanges;
+  }
+
+  public List<List<Long>> getTargetTimeRanges() {
+    return targetTimeRanges;
+  }
+
+  public void setParameters(Map<String, String> parameters) {
+    this.parameters = parameters;
+  }
+
+  /**
+   * Hashes exactly the fields compared by {@link #equals(Object)}. The superclass hash is left
+   * out on purpose: equals does not consult the superclass, so mixing it in could give equal
+   * statements different hash codes.
+   */
+  @Override
+  public int hashCode() {
+    return Objects.hash(
+        modelId, modelType, existingModelId, parameters, targetPathPatterns, targetTimeRanges);
+  }
+
+  /**
+   * Two statements are equal when all of their fields are equal, including the target path
+   * patterns and time ranges, so statements over different data are never considered the same.
+   */
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj) {
+      return true;
+    }
+    if (!(obj instanceof CreateTrainingStatement)) {
+      return false;
+    }
+    CreateTrainingStatement target = (CreateTrainingStatement) obj;
+    return Objects.equals(modelId, target.modelId)
+        && Objects.equals(modelType, target.modelType)
+        && Objects.equals(existingModelId, target.existingModelId)
+        && Objects.equals(parameters, target.parameters)
+        && Objects.equals(targetPathPatterns, target.targetPathPatterns)
+        && Objects.equals(targetTimeRanges, target.targetTimeRanges);
+  }
+
+  @Override
+  public String toString() {
+    return "CreateTrainingStatement{"
+        + "modelId='"
+        + modelId
+        + '\''
+        + ", modelType='"
+        + modelType
+        + '\''
+        + ", parameters="
+        + parameters
+        + ", existingModelId='"
+        + existingModelId
+        + '\''
+        + ", targetPathPatterns="
+        + targetPathPatterns
+        + ", targetTimeRanges="
+        + targetTimeRanges
+        + '}';
+  }
+
+  /** Returns the target path patterns (the same list as {@link #getTargetPathPatterns()}). */
+  @Override
+  public List<? extends PartialPath> getPaths() {
+    return targetPathPatterns;
+  }
+
+  @Override
+  public QueryType getQueryType() {
+    return QueryType.WRITE;
+  }
+
+  @Override
+  public <R, C> R accept(StatementVisitor<R, C> visitor, C context) {
+    return visitor.visitCreateTraining(this, context);
+  }
+}
diff --git
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
index cc93a8f32d5..1eca6e7f16a 100644
---
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
+++
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
@@ -26,6 +26,7 @@ import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq;
import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp;
import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq;
import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp;
+import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq;
import org.apache.iotdb.ainode.rpc.thrift.TWindowParams;
import org.apache.iotdb.common.rpc.thrift.TEndPoint;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
@@ -187,6 +188,18 @@ public class AINodeClient implements AutoCloseable,
ThriftClient {
}
}
+  /**
+   * Sends a training request to the AINode.
+   *
+   * @param req the training request to forward
+   * @return the status returned by the AINode
+   * @throws TException with the generic connection-failure message if the call fails; the
+   *     original exception is attached as the cause so the root failure is not lost
+   */
+  public TSStatus createTrainingTask(TTrainingReq req) throws TException {
+    try {
+      return client.createTrainingTask(req);
+    } catch (TException e) {
+      logger.warn(
+          "Failed to connect to AINode from DataNode when executing {}: {}",
+          Thread.currentThread().getStackTrace()[1].getMethodName(),
+          e.getMessage());
+      // Keep the caller-visible message unchanged but chain the underlying failure.
+      throw new TException(MSG_CONNECTION_FAIL, e);
+    }
+  }
+
@Override
public void close() throws Exception {
Optional.ofNullable(transport).ifPresent(TTransport::close);
diff --git
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
index 0582989e00c..9e84c92a311 100644
---
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
+++
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
@@ -120,6 +120,14 @@ public class ModelInformation {
return modelName;
}
+ public void setInputLength(int length) { // writes slot 0 of inputShape; slot 1 is written by setInputColumnSize
+ inputShape[0] = length;
+ }
+
+ public void setOutputLength(int length) { // writes slot 0 of outputShape; slot 1 is derived in setInputColumnSize
+ outputShape[0] = length;
+ }
+
// calculation modelType and outputColumn metadata for different built-in
models
public void setInputColumnSize(int size) {
inputShape[1] = size;
@@ -127,11 +135,16 @@ public class ModelInformation {
outputShape[1] = size;
} else if (modelType == ModelType.BUILT_IN_ANOMALY_DETECTION) {
outputShape[1] = 1;
+ } else {
+ outputShape[1] = size;
}
if (modelType == ModelType.BUILT_IN_FORECAST) {
buildOutputDataTypeForBuiltInModel(TSDataType.DOUBLE, outputShape[1]);
} else if (modelType == ModelType.BUILT_IN_ANOMALY_DETECTION) {
buildOutputDataTypeForBuiltInModel(TSDataType.INT32, outputShape[1]);
+ } else {
+ buildOutputDataTypeForBuiltInModel(TSDataType.FLOAT, outputShape[1]);
+ buildInputDataTypeForBuiltInModel(TSDataType.FLOAT, inputShape[1]);
}
}
@@ -146,6 +159,13 @@ public class ModelInformation {
}
}
+ private void buildInputDataTypeForBuiltInModel(TSDataType tsDataType, int
num) {
+ inputDataType = new TSDataType[num];
+ for (int i = 0; i < num; i++) {
+ inputDataType[i] = tsDataType;
+ }
+ }
+
public int[] getInputShape() {
return inputShape;
}
diff --git
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelStatus.java
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelStatus.java
index 7aac33dac23..20f536d34a9 100644
---
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelStatus.java
+++
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelStatus.java
@@ -22,6 +22,7 @@ package org.apache.iotdb.commons.model;
public enum ModelStatus {
INACTIVE,
LOADING,
+ TRAINING,
ACTIVE,
DROPPING,
UNAVAILABLE
diff --git
a/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4
b/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4
index 2b1f6b23b8d..1776a6ac88f 100644
---
a/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4
+++
b/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4
@@ -148,6 +148,10 @@ statement
| listUserStatement
| listRoleStatement
+ // AI
+ | createModelStatement
+ | showModelsStatement
+
// View, Trigger, pipe, CQ, Quota are not supported yet
;
@@ -702,6 +706,43 @@ revokeGrantOpt
: GRANT OPTION FOR
;
+// ------------------------------------------- AI
---------------------------------------------------------
+
+// CREATE MODEL <modelType> <modelId> [WITH HYPERPARAMETERS (k=v, ...)] [FROM MODEL <id>] ON DATASET (...)
+createModelStatement
+ : CREATE MODEL modelType=identifier modelId=identifier
+     (WITH HYPERPARAMETERS '(' hparamPair (',' hparamPair)* ')')?
+     (FROM MODEL existingModelId=identifier)?
+     ON DATASET '(' trainingData ')'
+ ;
+
+trainingData // either the keyword ALL or a comma-separated list of dataElement
+ : ALL
+ | dataElement(',' dataElement)*
+ ;
+
+dataElement // one training source: a databaseElement or a tableElement
+ : databaseElement
+ | tableElement
+ ;
+
+databaseElement // DATABASE <identifier>, optionally followed by a parenthesized timeRange
+ : DATABASE database=identifier ('(' timeRange ')')?
+ ;
+
+tableElement // TABLE <qualifiedName>, optionally followed by a parenthesized timeRange
+ : TABLE tableName=qualifiedName ('(' timeRange ')')?
+ ;
+
+timeRange // bracketed pair of timeValue; NOTE(review): whether the bounds are inclusive is not defined here - confirm
+ : '[' startTime=timeValue ',' endTime=timeValue ']'
+ ;
+
+hparamPair // one key '=' value hyperparameter entry
+ : hparamKey=identifier '=' hyparamValue=primaryExpression // NOTE(review): label "hyparamValue" looks like a typo of "hparamValue"; renaming changes the generated context field, so check the visitor code first
+ ;
+
+// SHOW MODELS, optionally narrowed to a single model id.
+showModelsStatement
+ : SHOW MODELS (modelId=identifier)?
+ ;
+
// ------------------------------------------- Query Statement
---------------------------------------------------------
queryStatement
: query
#statementDefault
@@ -1124,16 +1165,16 @@ nonReserved
: ABSENT | ADD | ADMIN | AFTER | ALL | ANALYZE | ANY | ARRAY | ASC | AT |
ATTRIBUTE | AUTHORIZATION
| BEGIN | BERNOULLI | BOTH
| CACHE | CALL | CALLED | CASCADE | CATALOG | CATALOGS | CHAR | CHARACTER
| CHARSET | CLEAR | CLUSTER | CLUSTERID | COLUMN | COLUMNS | COMMENT | COMMIT |
COMMITTED | CONDITION | CONDITIONAL | CONFIGNODES | CONFIGNODE | CONFIGURATION
| CONNECTOR | CONSTANT | COPARTITION | COUNT | CURRENT
- | DATA | DATABASE | DATABASES | DATANODE | DATANODES | DATE | DAY |
DECLARE | DEFAULT | DEFINE | DEFINER | DENY | DESC | DESCRIPTOR | DETAILS|
DETERMINISTIC | DEVICES | DISTRIBUTED | DO | DOUBLE
+ | DATA | DATABASE | DATABASES | DATANODE | DATANODES | DATASET | DATE |
DAY | DECLARE | DEFAULT | DEFINE | DEFINER | DENY | DESC | DESCRIPTOR |
DETAILS| DETERMINISTIC | DEVICES | DISTRIBUTED | DO | DOUBLE
| ELSEIF | EMPTY | ENCODING | ERROR | EXCLUDING | EXPLAIN | EXTRACTOR
| FETCH | FIELD | FILTER | FINAL | FIRST | FLUSH | FOLLOWING | FORMAT |
FUNCTION | FUNCTIONS
| GRACE | GRANT | GRANTED | GRANTS | GRAPHVIZ | GROUPS
- | HOUR
+ | HOUR | HYPERPARAMETERS
| INDEX | INDEXES | IF | IGNORE | IMMEDIATE | INCLUDING | INITIAL | INPUT
| INTERVAL | INVOKER | IO | ITERATE | ISOLATION
| JSON
| KEEP | KEY | KEYS | KILL
| LANGUAGE | LAST | LATERAL | LEADING | LEAVE | LEVEL | LIMIT | LINEAR |
LOAD | LOCAL | LOGICAL | LOOP
- | MANAGE_ROLE | MANAGE_USER | MAP | MATCH | MATCHED | MATCHES |
MATCH_RECOGNIZE | MATERIALIZED | MEASURES | METHOD | MERGE | MICROSECOND |
MIGRATE | MILLISECOND | MINUTE | MODIFY | MONTH
+ | MANAGE_ROLE | MANAGE_USER | MAP | MATCH | MATCHED | MATCHES |
MATCH_RECOGNIZE | MATERIALIZED | MEASURES | METHOD | MERGE | MICROSECOND |
MIGRATE | MILLISECOND | MINUTE | MODEL | MODELS | MODIFY | MONTH
| NANOSECOND | NESTED | NEXT | NFC | NFD | NFKC | NFKD | NO | NODEID |
NONE | NULLIF | NULLS
| OBJECT | OF | OFFSET | OMIT | ONE | ONLY | OPTION | ORDINALITY | OUTPUT
| OVER | OVERFLOW
| PARTITION | PARTITIONS | PASSING | PAST | PATH | PATTERN | PER | PERIOD
| PERMUTE | PIPE | PIPEPLUGIN | PIPEPLUGINS | PIPES | PLAN | POSITION |
PRECEDING | PRECISION | PRIVILEGES | PREVIOUS | PROCESSLIST | PROCESSOR |
PROPERTIES | PRUNE
@@ -1141,7 +1182,7 @@ nonReserved
| RANGE | READ | READONLY | RECONSTRUCT | REFRESH | REGION | REGIONID |
REGIONS | REMOVE | RENAME | REPAIR | REPEAT | REPEATABLE | REPLACE | RESET |
RESPECT | RESTRICT | RETURN | RETURNING | RETURNS | REVOKE | ROLE | ROLES |
ROLLBACK | ROW | ROWS | RUNNING
| SERIESSLOTID | SCALAR | SCHEMA | SCHEMAS | SECOND | SECURITY | SEEK |
SERIALIZABLE | SESSION | SET | SETS
| SHOW | SINK | SOME | SOURCE | START | STATS | STOP | SUBSCRIPTIONS |
SUBSET | SUBSTRING | SYSTEM
- | TABLES | TABLESAMPLE | TAG | TEXT | TEXT_STRING | TIES | TIME |
TIMEPARTITION | TIMESERIES | TIMESLOTID | TIMESTAMP | TO | TOPIC | TOPICS |
TRAILING | TRANSACTION | TRUNCATE | TRY_CAST | TYPE
+ | TABLES | TABLESAMPLE | TAG | TEXT | TEXT_STRING | TIES | TIME |
TIMEPARTITION | TIMER | TIMER_XL | TIMESERIES | TIMESLOTID | TIMESTAMP | TO |
TOPIC | TOPICS | TRAILING | TRANSACTION | TRUNCATE | TRY_CAST | TYPE
| UNBOUNDED | UNCOMMITTED | UNCONDITIONAL | UNIQUE | UNKNOWN | UNMATCHED |
UNTIL | UPDATE | URI | USE | USED | USER | UTF16 | UTF32 | UTF8
| VALIDATE | VALUE | VARIABLES | VARIATION | VERBOSE | VERSION | VIEW
| WEEK | WHILE | WINDOW | WITHIN | WITHOUT | WORK | WRAPPER | WRITE
@@ -1219,6 +1260,7 @@ DATABASE: 'DATABASE';
DATABASES: 'DATABASES';
DATANODE: 'DATANODE';
DATANODES: 'DATANODES';
+DATASET: 'DATASET';
DATE: 'DATE';
DATE_BIN: 'DATE_BIN';
DATE_BIN_GAPFILL: 'DATE_BIN_GAPFILL';
@@ -1282,6 +1324,7 @@ GROUPING: 'GROUPING';
GROUPS: 'GROUPS';
HAVING: 'HAVING';
HOUR: 'HOUR' | 'H';
+HYPERPARAMETERS: 'HYPERPARAMETERS';
INDEX: 'INDEX';
INDEXES: 'INDEXES';
IF: 'IF';
@@ -1289,6 +1332,7 @@ IGNORE: 'IGNORE';
IMMEDIATE: 'IMMEDIATE';
IN: 'IN';
INCLUDING: 'INCLUDING';
+INFERENCE: 'INFERENCE';
INITIAL: 'INITIAL';
INNER: 'INNER';
INPUT: 'INPUT';
@@ -1346,6 +1390,8 @@ MICROSECOND: 'US';
MIGRATE: 'MIGRATE';
MILLISECOND: 'MS';
MINUTE: 'MINUTE' | 'M';
+MODEL: 'MODEL';
+MODELS: 'MODELS';
MODIFY: 'MODIFY';
MONTH: 'MONTH' | 'MO';
NANOSECOND: 'NS';
@@ -1475,6 +1521,8 @@ TIME: 'TIME';
TIME_BOUND: 'TIME_BOUND';
TIME_COLUMN: 'TIME_COLUMN';
TIMEPARTITION: 'TIMEPARTITION';
+TIMER: 'TIMER';
+TIMER_XL: 'TIMER_XL';
TIMESERIES: 'TIMESERIES';
TIMESLOTID: 'TIMESLOTID';
TIMESTAMP: 'TIMESTAMP';
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index b3e8a67b8cc..9ac07b48dca 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -75,6 +75,20 @@ struct TInferenceResp {
2: required list<binary> inferenceResult
}
+struct IDataSchema { // element type of TTrainingReq.targetDataSchema
+ 1: required string schemaName
+ 2: optional list<i64> timeRange // NOTE(review): presumably a [start, end] timestamp pair - confirm with the AINode consumer
+}
+
+// Request payload of IAINodeRPCService.createTrainingTask.
+struct TTrainingReq {
+  1: required string dbType
+  2: required string modelId
+  3: required string modelType
+  4: optional list<IDataSchema> targetDataSchema
+  5: optional map<string, string> parameters
+  6: optional string existingModelId
+}
+
service IAINodeRPCService {
// -------------- For Config Node --------------
@@ -85,6 +99,8 @@ service IAINodeRPCService {
TAIHeartbeatResp getAIHeartbeat(TAIHeartbeatReq req)
+ common.TSStatus createTrainingTask(TTrainingReq req)
+
// -------------- For Data Node --------------
TInferenceResp inference(TInferenceReq req)
diff --git a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
index 42cee9388e5..6083ad513a6 100644
--- a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
+++ b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
@@ -1035,6 +1035,37 @@ struct TGetModelInfoResp {
3: optional common.TEndPoint aiNodeAddress
}
+struct TUpdateModelInfoReq { // request payload of IConfigNodeRPCService.updateModelInfo
+ 1: required string modelId
+ 2: required i32 modelStatus // NOTE(review): presumably an integer encoding of ModelStatus - confirm the encoding is stable now that TRAINING is inserted mid-enum
+ 3: optional string attributes
+ 4: optional list<i32> aiNodeIds
+ 5: optional i32 inputLength
+ 6: optional i32 outputLength
+}
+
+// Table-model data selection carried by TCreateTrainingReq.dataSchemaForTable.
+struct TDataSchemaForTable {
+  1: required list<string> databaseList
+  2: required list<string> tableList
+  3: required string curDatabase
+}
+
+// Tree-model data selection carried by TCreateTrainingReq.dataSchemaForTree.
+struct TDataSchemaForTree {
+  1: required list<string> path
+}
+
+struct TCreateTrainingReq { // request payload of IConfigNodeRPCService.createTraining
+ 1: required string modelId
+ 2: required string modelType
+ 3: required bool isTableModel // NOTE(review): presumably decides which of fields 4/5 is populated - confirm
+ 4: optional TDataSchemaForTable dataSchemaForTable
+ 5: optional TDataSchemaForTree dataSchemaForTree
+ 6: optional bool useAllData
+ 7: optional map<string, string> parameters
+ 8: optional string existingModelId
+ 9: optional list<list<i64>> timeRanges // NOTE(review): presumably one [start, end] pair per selected schema entry - confirm ordering
+}
+
// ====================================================
// Quota
// ====================================================
@@ -1899,6 +1930,10 @@ service IConfigNodeRPCService {
*/
TGetModelInfoResp getModelInfo(TGetModelInfoReq req)
+ common.TSStatus updateModelInfo(TUpdateModelInfoReq req)
+
+ common.TSStatus createTraining(TCreateTrainingReq req)
+
// ======================================================
// Quota
// ======================================================