This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 5a83067d05a [AINode] Support hubmixin models and modify pipeline
(#17334)
5a83067d05a is described below
commit 5a83067d05a511f8e1d8e2a0a480546d0c4e829b
Author: Leo <[email protected]>
AuthorDate: Tue Mar 24 21:45:56 2026 +0800
[AINode] Support hubmixin models and modify pipeline (#17334)
---
.../ainode/it/AINodeInstanceManagementIT.java | 70 ++++---
.../iotdb/ainode/it/AINodeModelManageIT.java | 59 ++++--
.../apache/iotdb/ainode/utils/AINodeTestUtils.java | 32 +++-
iotdb-core/ainode/iotdb/ainode/core/exception.py | 6 -
.../core/inference/pipeline/basic_pipeline.py | 209 +++++++++++++++++++--
.../iotdb/ainode/core/manager/inference_manager.py | 65 +++----
.../iotdb/ainode/core/manager/model_manager.py | 1 -
.../core/model/chronos2/pipeline_chronos2.py | 6 +-
.../iotdb/ainode/core/model/model_constants.py | 31 ++-
.../ainode/iotdb/ainode/core/model/model_info.py | 12 +-
.../ainode/iotdb/ainode/core/model/model_loader.py | 120 ++++--------
.../iotdb/ainode/core/model/model_storage.py | 106 +++++++++--
.../ainode/core/model/moirai2/pipeline_moirai2.py | 6 +-
.../ainode/core/model/sktime/pipeline_sktime.py | 7 +-
.../ainode/core/model/sundial/pipeline_sundial.py | 6 +-
.../ainode/core/model/timer_xl/pipeline_timer.py | 6 +-
iotdb-core/ainode/iotdb/ainode/core/model/utils.py | 118 +++++++++++-
.../ainode/iotdb/ainode/core/util/decorator.py | 8 +-
iotdb-core/ainode/iotdb/ainode/core/util/serde.py | 15 +-
.../function/tvf/ForecastTableFunction.java | 50 +----
.../function/tvf/TableFunctionUtils.java | 74 ++++++++
21 files changed, 709 insertions(+), 298 deletions(-)
diff --git
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java
index f8aa27ce688..8356a055311 100644
---
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java
+++
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java
@@ -20,12 +20,16 @@
package org.apache.iotdb.ainode.it;
import org.apache.iotdb.it.env.EnvFactory;
+import org.apache.iotdb.it.framework.IoTDBTestRunner;
+import org.apache.iotdb.itbase.category.AIClusterIT;
import org.apache.iotdb.itbase.env.BaseEnv;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
import java.sql.Connection;
import java.sql.ResultSet;
@@ -41,9 +45,13 @@ import static
org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpeci
import static
org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
+@RunWith(IoTDBTestRunner.class)
+@Category({AIClusterIT.class})
public class AINodeInstanceManagementIT {
- private static final Set<String> TARGET_DEVICES = new
HashSet<>(Arrays.asList("cpu", "0", "1"));
+ private static final String TARGET_DEVICES_STR = "0,1";
+ private static final Set<String> TARGET_DEVICES =
+ new HashSet<>(Arrays.asList(TARGET_DEVICES_STR.split(",")));
@BeforeClass
public static void setUp() throws Exception {
@@ -76,53 +84,57 @@ public class AINodeInstanceManagementIT {
// Ensure resources
try (ResultSet resultSet = statement.executeQuery("SHOW AI_DEVICES")) {
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- checkHeader(resultSetMetaData, "DeviceID");
+ checkHeader(resultSetMetaData, "DeviceId,DeviceType");
final Set<String> resultDevices = new HashSet<>();
while (resultSet.next()) {
- resultDevices.add(resultSet.getString("DeviceID"));
+ resultDevices.add(resultSet.getString("DeviceId"));
}
- Assert.assertEquals(TARGET_DEVICES, resultDevices);
+ Set<String> expected = new HashSet<>(TARGET_DEVICES);
+ expected.add("cpu");
+ Assert.assertEquals(expected, resultDevices);
}
// Load sundial to each device
- statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'",
TARGET_DEVICES));
- checkModelOnSpecifiedDevice(statement, "sundial",
TARGET_DEVICES.toString());
+ statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'",
TARGET_DEVICES_STR));
+ checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR);
// Unload sundial from each device
- statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'",
TARGET_DEVICES));
- checkModelNotOnSpecifiedDevice(statement, "sundial",
TARGET_DEVICES.toString());
+ statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'",
TARGET_DEVICES_STR));
+ checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR);
// Load timer_xl to each device
- statement.execute(String.format("LOAD MODEL timer_xl TO DEVICES '%s'",
TARGET_DEVICES));
- checkModelOnSpecifiedDevice(statement, "timer_xl",
TARGET_DEVICES.toString());
+ statement.execute(String.format("LOAD MODEL timer_xl TO DEVICES '%s'",
TARGET_DEVICES_STR));
+ checkModelOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES_STR);
// Unload timer_xl from each device
- statement.execute(String.format("UNLOAD MODEL timer_xl FROM DEVICES '%s'",
TARGET_DEVICES));
- checkModelNotOnSpecifiedDevice(statement, "timer_xl",
TARGET_DEVICES.toString());
+ statement.execute(String.format("UNLOAD MODEL timer_xl FROM DEVICES '%s'",
TARGET_DEVICES_STR));
+ checkModelNotOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES_STR);
}
private static final int LOOP_CNT = 10;
- @Test
+ // @Test
public void repeatLoadAndUnloadTest() throws SQLException,
InterruptedException {
try (Connection connection =
EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
for (int i = 0; i < LOOP_CNT; i++) {
- statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\"");
- checkModelOnSpecifiedDevice(statement, "sundial",
TARGET_DEVICES.toString());
- statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\"");
- checkModelNotOnSpecifiedDevice(statement, "sundial",
TARGET_DEVICES.toString());
+ statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'",
TARGET_DEVICES_STR));
+ checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR);
+ statement.execute(
+ String.format("UNLOAD MODEL sundial FROM DEVICES '%s'",
TARGET_DEVICES_STR));
+ checkModelNotOnSpecifiedDevice(statement, "sundial",
TARGET_DEVICES_STR);
}
}
}
- @Test
+ // @Test
public void concurrentLoadAndUnloadTest() throws SQLException,
InterruptedException {
try (Connection connection =
EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
for (int i = 0; i < LOOP_CNT; i++) {
- statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\"");
- statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\"");
+ statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'",
TARGET_DEVICES_STR));
+ statement.execute(
+ String.format("UNLOAD MODEL sundial FROM DEVICES '%s'",
TARGET_DEVICES_STR));
}
- checkModelNotOnSpecifiedDevice(statement, "sundial",
TARGET_DEVICES.toString());
+ checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR);
}
}
@@ -145,23 +157,23 @@ public class AINodeInstanceManagementIT {
private void failTest(Statement statement) {
errorTest(
statement,
- "LOAD MODEL unknown TO DEVICES \"cpu,0,1\"",
- "1505: Cannot load model [unknown], because it is neither a built-in
nor a fine-tuned model. You can use 'SHOW MODELS' to retrieve the available
models.");
+ "LOAD MODEL unknown TO DEVICES 'cpu,0,1'",
+ "1504: Model [unknown] is not registered yet. You can use 'SHOW
MODELS' to retrieve the available models.");
errorTest(
statement,
- "LOAD MODEL sundial TO DEVICES \"unknown\"",
- "1507: Device ID [unknown] is not available. You can use 'SHOW
AI_DEVICES' to retrieve the available devices.");
+ "LOAD MODEL sundial TO DEVICES '999'",
+ "1508: AIDevice ID [999] is not available. You can use 'SHOW
AI_DEVICES' to retrieve the available devices.");
errorTest(
statement,
- "UNLOAD MODEL sundial FROM DEVICES \"unknown\"",
- "1507: Device ID [unknown] is not available. You can use 'SHOW
AI_DEVICES' to retrieve the available devices.");
+ "UNLOAD MODEL sundial FROM DEVICES '999'",
+ "1508: AIDevice ID [999] is not available. You can use 'SHOW
AI_DEVICES' to retrieve the available devices.");
errorTest(
statement,
- "LOAD MODEL sundial TO DEVICES \"0,0\"",
+ "LOAD MODEL sundial TO DEVICES '0,0'",
"1509: Device ID list contains duplicate entries.");
errorTest(
statement,
- "UNLOAD MODEL sundial FROM DEVICES \"0,0\"",
+ "UNLOAD MODEL sundial FROM DEVICES '0,0'",
"1510: Device ID list contains duplicate entries.");
}
}
diff --git
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
index 8ece0ba7523..9e78c8b025c 100644
---
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
+++
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
@@ -71,15 +71,22 @@ public class AINodeModelManageIT {
public void userDefinedModelManagementTestInTree() throws SQLException,
InterruptedException {
try (Connection connection =
EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
- registerUserDefinedModel(statement);
- callInferenceTest(
- statement, new FakeModelInfo("user_chronos", "custom_t5",
"user_defined", "active"));
- dropUserDefinedModel(statement);
+ // Test transformers model (chronos2) in tree.
+ AINodeTestUtils.FakeModelInfo modelInfo =
+ new FakeModelInfo("user_chronos", "custom_t5", "user_defined",
"active");
+ registerUserDefinedModel(statement, modelInfo, "file:///data/chronos2");
+ callInferenceTest(statement, modelInfo);
+ dropUserDefinedModel(statement, modelInfo.getModelId());
errorTest(
statement,
"create model origin_chronos using uri
\"file:///data/chronos2_origin\"",
"1505: 't5' is already used by a Transformers config, pick another
name.");
statement.execute("drop model origin_chronos");
+
+ // Test PytorchModelHubMixin model (mantis) in tree.
+ modelInfo = new FakeModelInfo("user_mantis", "custom_mantis",
"user_defined", "active");
+ registerUserDefinedModel(statement, modelInfo, "file:///data/mantis");
+ dropUserDefinedModel(statement, modelInfo.getModelId());
}
}
@@ -87,23 +94,35 @@ public class AINodeModelManageIT {
public void userDefinedModelManagementTestInTable() throws SQLException,
InterruptedException {
try (Connection connection =
EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
- registerUserDefinedModel(statement);
- forecastTableFunctionTest(
- statement, new FakeModelInfo("user_chronos", "custom_t5",
"user_defined", "active"));
- dropUserDefinedModel(statement);
+ // Test transformers model (chronos2) in table.
+ AINodeTestUtils.FakeModelInfo modelInfo =
+ new FakeModelInfo("user_chronos", "custom_t5", "user_defined",
"active");
+ registerUserDefinedModel(statement, modelInfo, "file:///data/chronos2");
+ forecastTableFunctionTest(statement, modelInfo);
+ dropUserDefinedModel(statement, modelInfo.getModelId());
errorTest(
statement,
"create model origin_chronos using uri
\"file:///data/chronos2_origin\"",
"1505: 't5' is already used by a Transformers config, pick another
name.");
statement.execute("drop model origin_chronos");
+
+ // Test PytorchModelHubMixin model (mantis) in table.
+ modelInfo = new FakeModelInfo("user_mantis", "custom_mantis",
"user_defined", "active");
+ registerUserDefinedModel(statement, modelInfo, "file:///data/mantis");
+ dropUserDefinedModel(statement, modelInfo.getModelId());
}
}
- private void registerUserDefinedModel(Statement statement)
+ public static void registerUserDefinedModel(
+ Statement statement, AINodeTestUtils.FakeModelInfo modelInfo, String uri)
throws SQLException, InterruptedException {
+ String modelId = modelInfo.getModelId();
+ String modelType = modelInfo.getModelType();
+ String category = modelInfo.getCategory();
+ final String CREATE_MODEL_TEMPLATE = "create model %s using uri \"%s\"";
final String alterConfigSQL = "set configuration
\"trusted_uri_pattern\"='.*'";
- final String registerSql = "create model user_chronos using uri
\"file:///data/chronos2\"";
- final String showSql = "SHOW MODELS user_chronos";
+ final String registerSql = String.format(CREATE_MODEL_TEMPLATE, modelId,
uri);
+ final String showSql = String.format("SHOW MODELS %s", modelId);
statement.execute(alterConfigSQL);
statement.execute(registerSql);
boolean loading = true;
@@ -112,13 +131,13 @@ public class AINodeModelManageIT {
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State");
while (resultSet.next()) {
- String modelId = resultSet.getString(1);
- String modelType = resultSet.getString(2);
- String category = resultSet.getString(3);
+ String resultModelId = resultSet.getString(1);
+ String resultModelType = resultSet.getString(2);
+ String resultCategory = resultSet.getString(3);
String state = resultSet.getString(4);
- assertEquals("user_chronos", modelId);
- assertEquals("custom_t5", modelType);
- assertEquals("user_defined", category);
+ assertEquals(modelId, resultModelId);
+ assertEquals(modelType, resultModelType);
+ assertEquals(category, resultCategory);
if (state.equals("active")) {
loading = false;
} else if (state.equals("loading")) {
@@ -136,9 +155,9 @@ public class AINodeModelManageIT {
assertFalse(loading);
}
- private void dropUserDefinedModel(Statement statement) throws SQLException {
- final String showSql = "SHOW MODELS user_chronos";
- final String dropSql = "DROP MODEL user_chronos";
+ public static void dropUserDefinedModel(Statement statement, String modelId)
throws SQLException {
+ final String showSql = String.format("SHOW MODELS %s", modelId);
+ final String dropSql = String.format("DROP MODEL %s", modelId);
statement.execute(dropSql);
try (ResultSet resultSet = statement.executeQuery(showSql)) {
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
diff --git
a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
index 5a4dce53666..e41d3d4e0f9 100644
---
a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
+++
b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
@@ -51,10 +51,10 @@ public class AINodeTestUtils {
public static final Map<String, FakeModelInfo> BUILTIN_LTSM_MAP =
Stream.of(
- new AbstractMap.SimpleEntry<>(
- "sundial", new FakeModelInfo("sundial", "sundial",
"builtin", "active")),
new AbstractMap.SimpleEntry<>(
"timer_xl", new FakeModelInfo("timer_xl", "timer",
"builtin", "active")),
+ new AbstractMap.SimpleEntry<>(
+ "sundial", new FakeModelInfo("sundial", "sundial",
"builtin", "active")),
new AbstractMap.SimpleEntry<>(
"chronos2", new FakeModelInfo("chronos2", "t5", "builtin",
"active")),
new AbstractMap.SimpleEntry<>(
@@ -171,7 +171,7 @@ public class AINodeTestUtils {
LOGGER.info("Model {} found in device {}, count {}", loadedModelId,
deviceId, count);
if (loadedModelId.equals(modelId) &&
targetDevices.contains(deviceId) && count > 0) {
foundDevices.add(deviceId);
- LOGGER.info("Model {} is loaded to device {}", modelId, device);
+ LOGGER.info("Model {} is loaded to device {}", modelId, deviceId);
}
}
if (foundDevices.containsAll(targetDevices)) {
@@ -252,6 +252,32 @@ public class AINodeTestUtils {
}
}
+ /** Prepare db.AI2(s0 FLOAT,...) with 2880 rows of data in table. */
+ public static void prepareDataInTable2() throws SQLException {
+ try (Connection connection =
EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
+ Statement statement = connection.createStatement()) {
+ statement.execute("CREATE DATABASE db");
+ statement.execute(
+ "CREATE TABLE db.AI2 (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32
FIELD, s3 INT64 FIELD, s4 FLOAT FIELD, s5 DOUBLE FIELD, s6 INT32 FIELD, s7
INT64 FIELD, s8 FLOAT FIELD, s9 DOUBLE FIELD)");
+ for (int i = 0; i < 2880; i++) {
+ statement.execute(
+ String.format(
+ "INSERT INTO db.AI2(time,s0,s1,s2,s3,s4,s5,s6,s7,s8,s9)
VALUES(%d,%f,%f,%d,%d,%f,%f,%d,%d,%f,%f)",
+ i,
+ (float) i,
+ (double) i,
+ i,
+ i,
+ (float) (i * 2),
+ (double) (i * 2),
+ i * 2,
+ i * 2,
+ (float) (i * 3),
+ (double) (i * 3)));
+ }
+ }
+ }
+
public static class FakeModelInfo {
private final String modelId;
diff --git a/iotdb-core/ainode/iotdb/ainode/core/exception.py
b/iotdb-core/ainode/iotdb/ainode/core/exception.py
index b007ee58c48..b76baa3e2e2 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/exception.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/exception.py
@@ -15,12 +15,6 @@
# specific language governing permissions and limitations
# under the License.
#
-import re
-
-from iotdb.ainode.core.model.model_constants import (
- MODEL_CONFIG_FILE_IN_YAML,
- MODEL_WEIGHTS_FILE_IN_PT,
-)
class _BaseException(Exception):
diff --git
a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
index 5d0026522a1..f4bf914a846 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
@@ -56,17 +56,23 @@ class ForecastPipeline(BasicPipeline):
def __init__(self, model_info: ModelInfo, **model_kwargs):
super().__init__(model_info, **model_kwargs)
- def preprocess(
+ # ========================= Preprocess =========================
+ def preprocess(self, inputs, **infer_kwargs):
+ inputs = self._base_preprocess(inputs, **infer_kwargs)
+ return self._preprocess(inputs, **infer_kwargs)
+
+ def _base_preprocess(
self,
- inputs: list[dict[str, dict[str, torch.Tensor] | torch.Tensor]],
+ inputs,
**infer_kwargs,
- ):
+ ) -> list[dict[str, dict[str, torch.Tensor] | torch.Tensor]]:
"""
- Preprocess the input data before passing it to the model for
inference, validating the shape and type of the input data.
+ The common preprocess logic for all forecast pipelines,
+ validating the shape and type of the input data.
Args:
inputs (list[dict]):
- The input data, a list of dictionaries, where each dictionary
contains:
+ The input data, expected a list of dictionaries, where each
dictionary contains:
- 'targets': A tensor (1D or 2D) of shape (input_length,)
or (target_count, input_length).
- 'past_covariates': A dictionary of tensors (optional),
where each tensor has shape (input_length,).
- 'future_covariates': A dictionary of tensors (optional),
where each tensor has shape (input_length,).
@@ -79,7 +85,11 @@ class ForecastPipeline(BasicPipeline):
ValueError: If the input format is incorrect (e.g., missing keys,
invalid tensor shapes).
Returns:
- The preprocessed inputs, validated and ready for model inference.
+ list[dict]:
+ The validated input data, a list of dictionaries, where each
dictionary contains:
+ - 'targets': A tensor (1D or 2D) of shape (input_length,)
or (target_count, input_length).
+ - 'past_covariates': A dictionary of tensors (optional),
where each tensor has shape (input_length,).
+ - 'future_covariates': A dictionary of tensors (optional),
where each tensor has shape (input_length,).
"""
if isinstance(inputs, list):
@@ -211,10 +221,34 @@ class ForecastPipeline(BasicPipeline):
)
return inputs
+ def _preprocess(
+ self,
+ inputs: list[dict[str, dict[str, torch.Tensor] | torch.Tensor]],
+ **infer_kwargs,
+ ):
+ """
+ Optional hook for subclasses to implement custom preprocessing logic.
+ This method is called after the base validation in `_base_preprocess`,
so the inputs
+ are unified when this method is invoked.
+
+ Args:
+ inputs (list[dict]): The validated input data, a list of
dictionaries, where each dictionary contains:
+ - 'targets': A tensor of shape (input_length,) or
(target_count, input_length).
+ - 'past_covariates' (optional): A dictionary of 1-D tensors,
each of shape (input_length,).
+ - 'future_covariates' (optional): A dictionary of 1-D tensors,
each of shape (output_length,),
+ whose keys are guaranteed to be a subset of
'past_covariates'.
+ **infer_kwargs: Additional keyword arguments passed through from
the pipeline.
+
+ Returns:
+ inputs: The modified inputs ready for model inference.
+ """
+ return inputs
+
+ # ========================== Forecast ==========================
@abstractmethod
def forecast(self, inputs, **infer_kwargs):
"""
- Perform forecasting on the given inputs.
+ Perform forecasting on the given inputs, which must be implemented by
the subclasses.
Parameters:
inputs: The input data used for making predictions. The type and
structure
@@ -225,13 +259,35 @@ class ForecastPipeline(BasicPipeline):
Returns:
The forecasted output, which will depend on the specific model's
implementation.
"""
- pass
+ raise NotImplementedError("forecast not implemented")
+
+ # ========================= Postprocess ========================
+ def postprocess(self, outputs, **infer_kwargs):
+ outputs = self._postprocess(outputs, **infer_kwargs)
+ return self._base_postprocess(outputs, **infer_kwargs)
+
+ def _postprocess(self, outputs, **infer_kwargs) -> list[torch.Tensor]:
+ """
+ Optional hook for subclasses to implement custom postprocessing logic.
+ This method is called before the base validation in
`_base_postprocess`, so the outputs
+ must conform to the expected format when this method returns.
+
+ Args:
+ outputs: The raw model outputs.
+ **infer_kwargs: Additional keyword arguments passed through from
the pipeline.
- def postprocess(
+ Returns:
+ list[torch.Tensor]: The modified outputs, which must be a list of
2-D tensors
+ with shape (target_count, output_length), as this will be
validated by `_base_postprocess`.
+ """
+ return outputs
+
+ def _base_postprocess(
self, outputs: list[torch.Tensor], **infer_kwargs
) -> list[torch.Tensor]:
"""
- Postprocess the model outputs after inference, validating the shape of
the output data and ensures it matches the expected dimensions.
+        The common postprocess logic for all forecast pipelines,
+        validating the shape of the output data and ensuring it matches the
expected dimensions.
Args:
outputs:
@@ -262,14 +318,114 @@ class ClassificationPipeline(BasicPipeline):
def __init__(self, model_info: ModelInfo, **model_kwargs):
super().__init__(model_info, **model_kwargs)
- def preprocess(self, inputs, **kwargs):
+ # ========================= Preprocess =========================
+ def preprocess(self, inputs, **infer_kwargs):
+ inputs = self._base_preprocess(inputs, **infer_kwargs)
+ return self._preprocess(inputs, **infer_kwargs)
+
+ def _base_preprocess(self, inputs, **infer_kwargs) -> torch.Tensor:
+ """
+        The common preprocess logic for all classification pipelines,
+        validating and preprocessing the inputs.
+
+ Args:
+ inputs: The input data, expected to be a 3D-tensor.
+ **infer_kwargs: Additional inference parameters.
+
+ Returns:
+ torch.Tensor:
+ The preprocessed inputs, which will be a 3D-tensor with shape
(batch_size, variable_count, sequence_length).
+
+ Raises:
+ ValueError: If the input format is incorrect.
+ """
+ if isinstance(inputs, torch.Tensor) and inputs.ndim == 3:
+ return inputs
+ else:
+ raise ValueError(
+ f"The inputs should be a 3D-tensor, but got {type(inputs)}
with shape {tuple(inputs.shape)}."
+ )
+
+ def _preprocess(self, inputs: torch.Tensor, **infer_kwargs):
+ """
+ Optional hook for subclasses to implement custom preprocessing logic.
+ This method is called after the base validation in `_base_preprocess`,
so the inputs
+ are unified when this method is invoked.
+
+ Args:
+ inputs (torch.Tensor): The validated input data, a 3D tensor.
+ **infer_kwargs: Additional keyword arguments passed through from
the pipeline.
+
+ Returns:
+ torch.Tensor: The modified inputs ready for model inference.
+ """
return inputs
+ # ========================== Classify ==========================
@abstractmethod
- def classify(self, inputs, **kwargs):
- pass
+ def classify(self, inputs, **infer_kwargs):
+ """
+ Perform classification on the given inputs, which must be implemented
by the subclasses.
+
+ Parameters:
+            inputs: The input data used for classification. The type
and structure
+ depend on the specific implementation of the model.
+ **infer_kwargs: Additional inference parameters.
+
+ Returns:
+ The classified result, which will depend on the specific model's
implementation.
+ """
+ raise NotImplementedError("classify not implemented")
+
+ # ========================= Postprocess ========================
+ def postprocess(self, outputs, **infer_kwargs):
+ outputs = self._postprocess(outputs, **infer_kwargs)
+ return self._base_postprocess(outputs, **infer_kwargs)
+
+ def _postprocess(self, outputs, **infer_kwargs) -> list[torch.Tensor]:
+ """
+ Optional hook for subclasses to implement custom postprocessing logic.
+ This method is called before the base validation in
`_base_postprocess`, so the outputs
+ must conform to the expected format when this method returns.
+
+ Args:
+ outputs: The raw model outputs.
+ **infer_kwargs: Additional keyword arguments passed through from
the pipeline.
+
+ Returns:
+ list[torch.Tensor]: The modified outputs, which must be a list of
tensors,
+ as this will be validated by `_base_postprocess`.
+
+ Raises:
+ ValueError: If the output format is incorrect.
+ """
+ return outputs
- def postprocess(self, outputs, **kwargs):
+ def _base_postprocess(self, outputs, **infer_kwargs) -> list[torch.Tensor]:
+ """
+ The common postprocess logic for all classification pipelines,
+ validating the shape of the output data.
+
+ Args:
+ outputs (list[torch.Tensor]):
+ The output from the model.
+ **infer_kwargs:
+ Additional keyword arguments.
+
+ Returns:
+ list[torch.Tensor]:
+ The postprocessed outputs.
+
+ Raises:
+ ValueError:
+ If the output format is incorrect.
+ """
+ if not isinstance(outputs, list) or any(
+ not isinstance(output, torch.Tensor) for output in outputs
+ ):
+ raise ValueError(
+ f"The outputs should be a list of tensors, but got
{type(outputs)}."
+ )
return outputs
@@ -277,12 +433,29 @@ class ChatPipeline(BasicPipeline):
def __init__(self, model_info: ModelInfo, **model_kwargs):
super().__init__(model_info, **model_kwargs)
- def preprocess(self, inputs, **kwargs):
+ # ========================= Preprocess =========================
+ def preprocess(self, inputs, **infer_kwargs):
+ inputs = self._base_preprocess(inputs, **infer_kwargs)
+ return self._preprocess(inputs, **infer_kwargs)
+
+ def _base_preprocess(self, inputs, **infer_kwargs):
return inputs
+ def _preprocess(self, inputs, **infer_kwargs):
+ return inputs
+
+ # ========================== Chat ==========================
@abstractmethod
- def chat(self, inputs, **kwargs):
- pass
+ def chat(self, inputs, **infer_kwargs):
+ raise NotImplementedError("chat not implemented")
+
+ # ========================= Postprocess ========================
+ def postprocess(self, outputs, **infer_kwargs):
+ outputs = self._postprocess(outputs, **infer_kwargs)
+ return self._base_postprocess(outputs, **infer_kwargs)
+
+ def _postprocess(self, outputs, **infer_kwargs):
+ return outputs
- def postprocess(self, outputs, **kwargs):
+ def _base_postprocess(self, outputs, **infer_kwargs):
return outputs
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
index 07ca8a63bce..8dcf03627dd 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
@@ -178,30 +178,17 @@ class InferenceManager:
def _do_inference_and_construct_resp(
self,
model_id: str,
- model_inputs_list: list[dict[str, torch.Tensor | dict[str,
torch.Tensor]]],
- output_length: int,
+ model_inputs,
inference_attrs: dict,
- **kwargs,
) -> list[bytes]:
- auto_adapt = kwargs.get("auto_adapt", True)
- if (
- output_length
- >
AINodeDescriptor().get_config().get_ain_inference_max_output_length()
- ):
- raise NumericalRangeException(
- "output_length",
- output_length,
- 1,
-
AINodeDescriptor().get_config().get_ain_inference_max_output_length(),
- )
if self._pool_controller.has_running_pools(model_id):
+ # Only forecast task can use pool
+ output_length = int(inference_attrs.get("output_length", 96))
infer_req = InferenceRequest(
req_id=generate_req_id(),
model_id=model_id,
- inputs=torch.stack(
- [data["targets"] for data in model_inputs_list], dim=0
- ),
+ inputs=torch.stack([data["targets"] for data in model_inputs],
dim=0),
output_length=output_length,
)
outputs = self._process_request(infer_req)
@@ -210,23 +197,17 @@ class InferenceManager:
inference_pipeline = load_pipeline(
model_info, device=self._backend.torch_device("cpu")
)
- inputs = inference_pipeline.preprocess(
- model_inputs_list,
- output_length=output_length,
- auto_adapt=auto_adapt,
- )
+ inputs = inference_pipeline.preprocess(model_inputs,
**inference_attrs)
if isinstance(inference_pipeline, ForecastPipeline):
- outputs = inference_pipeline.forecast(
- inputs, output_length=output_length, **inference_attrs
- )
+ outputs = inference_pipeline.forecast(inputs,
**inference_attrs)
elif isinstance(inference_pipeline, ClassificationPipeline):
- outputs = inference_pipeline.classify(inputs)
+ outputs = inference_pipeline.classify(inputs,
**inference_attrs)
elif isinstance(inference_pipeline, ChatPipeline):
- outputs = inference_pipeline.chat(inputs)
+ outputs = inference_pipeline.chat(inputs, **inference_attrs)
else:
outputs = None
logger.error("[Inference] Unsupported pipeline type.")
- outputs = inference_pipeline.postprocess(outputs)
+ outputs = inference_pipeline.postprocess(outputs,
**inference_attrs)
# convert tensor into tsblock for the output in each batch
resp_list = []
@@ -235,7 +216,7 @@ class InferenceManager:
resp_list.append(resp)
return resp_list
- def _run(
+ def _run_forecast(
self,
req,
data_getter,
@@ -249,14 +230,26 @@ class InferenceManager:
inputs = convert_tsblock_to_tensor(raw)
inference_attrs = extract_attrs(req)
- output_length = int(inference_attrs.pop("output_length", 96))
+ output_length = int(inference_attrs.get("output_length", 96))
+ if (
+ output_length
+ >
AINodeDescriptor().get_config().get_ain_inference_max_output_length()
+ ):
+ raise NumericalRangeException(
+ "output_length",
+ output_length,
+ 1,
+ AINodeDescriptor()
+ .get_config()
+ .get_ain_inference_max_output_length(),
+ )
- model_inputs_list: list[
- dict[str, torch.Tensor | dict[str, torch.Tensor]]
- ] = [{"targets": inputs[0]}]
+ model_inputs: list[dict[str, torch.Tensor | dict[str,
torch.Tensor]]] = [
+ {"targets": inputs[0]}
+ ]
resp_list = self._do_inference_and_construct_resp(
- model_id, model_inputs_list, output_length, inference_attrs
+ model_id, model_inputs, inference_attrs
)
return resp_cls(
@@ -271,7 +264,7 @@ class InferenceManager:
return resp_cls(status, empty)
def forecast(self, req: TForecastReq):
- return self._run(
+ return self._run_forecast(
req,
data_getter=lambda r: r.inputData,
extract_attrs=lambda r: {
@@ -283,7 +276,7 @@ class InferenceManager:
)
def inference(self, req: TInferenceReq):
- return self._run(
+ return self._run_forecast(
req,
data_getter=lambda r: r.dataset,
extract_attrs=lambda r: {
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py
b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py
index ff4226e734f..e1d67873b53 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py
@@ -62,7 +62,6 @@ class ModelManager:
get_status(TSStatusCode.CREATE_MODEL_ERROR, str(e))
)
except Exception as e:
- # Catch-all for other exceptions (mainly from transformers
implementation)
return TRegisterModelResp(
get_status(TSStatusCode.CREATE_MODEL_ERROR, str(e))
)
diff --git
a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py
b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py
index b28f8f35a66..01ff78ba48d 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py
@@ -36,7 +36,7 @@ class Chronos2Pipeline(ForecastPipeline):
def __init__(self, model_info, **model_kwargs):
super().__init__(model_info, **model_kwargs)
- def preprocess(self, inputs, **infer_kwargs):
+ def _preprocess(self, inputs, **infer_kwargs):
"""
Preprocess input data of chronos2.
@@ -62,7 +62,6 @@ class Chronos2Pipeline(ForecastPipeline):
- 'future_covariates' (optional): dict of str to torch.Tensor
Unchanged future covariates.
"""
- super().preprocess(inputs, **infer_kwargs)
for item in inputs:
item["target"] = item.pop("targets")
return inputs
@@ -449,7 +448,7 @@ class Chronos2Pipeline(ForecastPipeline):
return prediction
- def postprocess(
+ def _postprocess(
self, outputs: list[torch.Tensor], **infer_kwargs
) -> list[torch.Tensor]:
"""
@@ -472,5 +471,4 @@ class Chronos2Pipeline(ForecastPipeline):
# If 0.5 quantile is not provided,
# get the mean of all quantiles
outputs_list.append(output.mean(dim=1))
- super().postprocess(outputs_list, **infer_kwargs)
return outputs_list
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py
b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py
index 9f1801b5073..a9495f16714 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py
@@ -17,11 +17,32 @@
#
from enum import Enum
-# Model file constants
-MODEL_WEIGHTS_FILE_IN_SAFETENSORS = "model.safetensors"
-MODEL_CONFIG_FILE_IN_JSON = "config.json"
-MODEL_WEIGHTS_FILE_IN_PT = "model.pt"
-MODEL_CONFIG_FILE_IN_YAML = "config.yaml"
+# ==================== File Name Constants ====================
+#
+# All file names used for model persistence are defined here.
+# Never hard-code these strings elsewhere – always import from
+# this module.
+
+# -- Config files --
+CONFIG_JSON = "config.json"
+CONFIG_YAML = "config.yaml"
+
+# -- Full model weights --
+MODEL_SAFETENSORS = "model.safetensors"
+MODEL_PT = "model.pt"
+MODEL_BIN = "pytorch_model.bin" # legacy HuggingFace format
+
+# -- Ordered tuples for detection / searching --
+MODEL_WEIGHT_FILES = (MODEL_SAFETENSORS, MODEL_PT, MODEL_BIN)
+
+# -- Backward-compatible aliases (deprecated, will be removed) --
+MODEL_WEIGHTS_FILE_IN_SAFETENSORS = MODEL_SAFETENSORS
+MODEL_CONFIG_FILE_IN_JSON = CONFIG_JSON
+MODEL_WEIGHTS_FILE_IN_PT = MODEL_PT
+MODEL_CONFIG_FILE_IN_YAML = CONFIG_YAML
+
+
+# ==================== Enumerations ====================
class ModelCategory(Enum):
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
index f253fb1e56f..da752cbd784 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
@@ -31,6 +31,7 @@ class ModelInfo:
pipeline_cls: str = "",
repo_id: str = "",
auto_map: Optional[Dict] = None,
+ hub_mixin_cls: Optional[str] = None,
transformers_registered: bool = False,
):
self.model_id = model_id
@@ -39,16 +40,16 @@ class ModelInfo:
self.state = state
self.pipeline_cls = pipeline_cls
self.repo_id = repo_id
- self.auto_map = auto_map # If exists, indicates it's a Transformers
model
- self.transformers_registered = (
- transformers_registered # Internal flag: whether registered to
Transformers
- )
+ self.auto_map = auto_map
+ self.hub_mixin_cls = hub_mixin_cls
+ self.transformers_registered = transformers_registered
def __repr__(self):
return (
f"ModelInfo(model_id={self.model_id},
model_type={self.model_type}, "
f"category={self.category.value}, state={self.state.value}, "
- f"has_auto_map={self.auto_map is not None})"
+ f"has_auto_map={self.auto_map is not None}), "
+ f"has_hub_mix_in_cls={self.hub_mixin_cls is not None})"
)
@@ -144,6 +145,7 @@ BUILTIN_HF_TRANSFORMERS_MODEL_MAP = {
"AutoConfig": "config.Chronos2CoreConfig",
"AutoModelForCausalLM": "model.Chronos2Model",
},
+ transformers_registered=True,
),
"moirai2": ModelInfo(
model_id="moirai2",
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
index 1da07cb9fef..2476a05856d 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
@@ -17,28 +17,21 @@
#
import os
-from pathlib import Path
from typing import Any
import torch
-from transformers import (
- AutoConfig,
- AutoModelForCausalLM,
- AutoModelForNextSentencePrediction,
- AutoModelForSeq2SeqLM,
- AutoModelForSequenceClassification,
- AutoModelForTimeSeriesPrediction,
- AutoModelForTokenClassification,
-)
-from iotdb.ainode.core.config import AINodeDescriptor
from iotdb.ainode.core.exception import ModelNotExistException
from iotdb.ainode.core.log import Logger
from iotdb.ainode.core.manager.device_manager import DeviceManager
-from iotdb.ainode.core.model.model_constants import ModelCategory
+from iotdb.ainode.core.model.model_constants import MODEL_PT
from iotdb.ainode.core.model.model_info import ModelInfo
from iotdb.ainode.core.model.sktime.modeling_sktime import create_sktime_model
-from iotdb.ainode.core.model.utils import import_class_from_path,
temporary_sys_path
+from iotdb.ainode.core.model.utils import (
+ get_model_and_config_by_auto_class,
+ get_model_and_config_by_native_code,
+ get_model_path,
+)
logger = Logger()
BACKEND = DeviceManager()
@@ -46,83 +39,64 @@ BACKEND = DeviceManager()
def load_model(model_info: ModelInfo, **model_kwargs) -> Any:
if model_info.auto_map is not None:
- model = load_model_from_transformers(model_info, **model_kwargs)
+ model = load_transformers_model(model_info, **model_kwargs)
+ elif model_info.hub_mixin_cls is not None:
+ model = _load_hub_mixin_model(model_info, **model_kwargs)
else:
if model_info.model_type == "sktime":
model = create_sktime_model(model_info.model_id)
else:
- model = load_model_from_pt(model_info, **model_kwargs)
+ model = _load_torchscript_model(model_info, **model_kwargs)
logger.info(
- f"Model {model_info.model_id} loaded to device {model.device if
model_info.model_type != 'sktime' else 'cpu'} successfully."
+ f"Model {model_info.model_id} loaded to device
{next(model.parameters()).device if model_info.model_type != 'sktime' else
'cpu'} successfully."
)
return model
-def load_model_from_transformers(model_info: ModelInfo, **model_kwargs):
+def load_transformers_model(model_info: ModelInfo, **model_kwargs):
device_map = model_kwargs.get("device_map", "cpu")
+ trust_remote_code = model_kwargs.get("trust_remote_code", True)
train_from_scratch = model_kwargs.get("train_from_scratch", False)
- model_path = os.path.join(
- os.getcwd(),
- AINodeDescriptor().get_config().get_ain_models_dir(),
- model_info.category.value,
- model_info.model_id,
- )
+ model_path = get_model_path(model_info)
- config_str = model_info.auto_map.get("AutoConfig", "")
- model_str = model_info.auto_map.get("AutoModelForCausalLM", "")
+ model_class, config_instance =
get_model_and_config_by_native_code(model_info)
+ if model_class is None:
+ model_class, config_instance =
get_model_and_config_by_auto_class(model_path)
- if model_info.category == ModelCategory.BUILTIN:
- module_name = (
- AINodeDescriptor().get_config().get_ain_models_builtin_dir()
- + "."
- + model_info.model_id
+ # ---- Load base model ----
+ if train_from_scratch:
+ model = model_class.from_config(
+ config_instance, trust_remote_code=trust_remote_code
)
- config_cls = import_class_from_path(module_name, config_str)
- model_cls = import_class_from_path(module_name, model_str)
- elif model_str and config_str:
- module_parent = str(Path(model_path).parent.absolute())
- with temporary_sys_path(module_parent):
- config_cls = import_class_from_path(model_info.model_id,
config_str)
- model_cls = import_class_from_path(model_info.model_id, model_str)
else:
- config_cls = AutoConfig.from_pretrained(model_path)
- if type(config_cls) in
AutoModelForTimeSeriesPrediction._model_mapping.keys():
- model_cls = AutoModelForTimeSeriesPrediction
- elif (
- type(config_cls) in
AutoModelForNextSentencePrediction._model_mapping.keys()
- ):
- model_cls = AutoModelForNextSentencePrediction
- elif type(config_cls) in AutoModelForSeq2SeqLM._model_mapping.keys():
- model_cls = AutoModelForSeq2SeqLM
- elif (
- type(config_cls) in
AutoModelForSequenceClassification._model_mapping.keys()
- ):
- model_cls = AutoModelForSequenceClassification
- elif type(config_cls) in
AutoModelForTokenClassification._model_mapping.keys():
- model_cls = AutoModelForTokenClassification
- else:
- model_cls = AutoModelForCausalLM
+ model = model_class.from_pretrained(
+ model_path,
+ config=config_instance,
+ trust_remote_code=trust_remote_code,
+ )
- if train_from_scratch:
- model = model_cls.from_config(config_cls)
- else:
- model = model_cls.from_pretrained(model_path)
+ return BACKEND.move_model(model, device_map)
+
+def _load_hub_mixin_model(model_info: ModelInfo, **model_kwargs):
+ device_map = model_kwargs.get("device_map", "cpu")
+ model_path = get_model_path(model_info)
+ model_class, _ = get_model_and_config_by_native_code(model_info)
+ if model_class is None:
+ logger.error(f"Model class not found for '{model_info.model_id}'")
+ raise ModelNotExistException(model_info.model_id)
+ # Load model
+ model = model_class.from_pretrained(model_path)
return BACKEND.move_model(model, device_map)
-def load_model_from_pt(model_info: ModelInfo, **kwargs):
+def _load_torchscript_model(model_info: ModelInfo, **kwargs):
device_map = kwargs.get("device_map", "cpu")
acceleration = kwargs.get("acceleration", False)
- model_path = os.path.join(
- os.getcwd(),
- AINodeDescriptor().get_config().get_ain_models_dir(),
- model_info.category.value,
- model_info.model_id,
- )
- model_file = os.path.join(model_path, "model.pt")
+ model_path = get_model_path(model_info)
+ model_file = os.path.join(model_path, MODEL_PT)
if not os.path.exists(model_file):
logger.error(f"Model file not found at {model_file}.")
raise ModelNotExistException(model_file)
@@ -134,17 +108,3 @@ def load_model_from_pt(model_info: ModelInfo, **kwargs):
except Exception as e:
logger.warning(f"acceleration failed, fallback to normal mode:
{str(e)}")
return BACKEND.move_model(model, device_map)
-
-
-def load_model_for_efficient_inference():
- # TODO: An efficient model loading method for inference based on
model_arguments
- pass
-
-
-def load_model_for_powerful_finetune():
- # TODO: An powerful model loading method for finetune based on
model_arguments
- pass
-
-
-def unload_model():
- pass
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
index 2cfb07fb56a..b7799df67a5 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
@@ -21,10 +21,15 @@ import json
import os
import shutil
from pathlib import Path
-from typing import Dict, List, Optional
-
-from huggingface_hub import hf_hub_download
-from transformers import AutoConfig, AutoModelForCausalLM
+from typing import Dict, Optional
+
+from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
+from transformers import (
+ AutoConfig,
+ AutoModelForCausalLM,
+ PretrainedConfig,
+ PreTrainedModel,
+)
from iotdb.ainode.core.config import AINodeDescriptor
from iotdb.ainode.core.constant import TSStatusCode
@@ -35,8 +40,8 @@ from iotdb.ainode.core.exception import (
)
from iotdb.ainode.core.log import Logger
from iotdb.ainode.core.model.model_constants import (
- MODEL_CONFIG_FILE_IN_JSON,
- MODEL_WEIGHTS_FILE_IN_SAFETENSORS,
+ CONFIG_JSON,
+ MODEL_SAFETENSORS,
ModelCategory,
ModelStates,
UriType,
@@ -147,13 +152,13 @@ class ModelStorage:
def _download_model_if_necessary() -> bool:
"""Returns: True if the model is existed or downloaded
successfully, False otherwise."""
repo_id = BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id].repo_id
- weights_path = os.path.join(model_dir,
MODEL_WEIGHTS_FILE_IN_SAFETENSORS)
- config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON)
+ weights_path = os.path.join(model_dir, MODEL_SAFETENSORS)
+ config_path = os.path.join(model_dir, CONFIG_JSON)
if not os.path.exists(weights_path):
try:
hf_hub_download(
repo_id=repo_id,
- filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS,
+ filename=MODEL_SAFETENSORS,
local_dir=model_dir,
)
except Exception as e:
@@ -165,7 +170,7 @@ class ModelStorage:
try:
hf_hub_download(
repo_id=repo_id,
- filename=MODEL_CONFIG_FILE_IN_JSON,
+ filename=CONFIG_JSON,
local_dir=model_dir,
)
except Exception as e:
@@ -191,7 +196,7 @@ class ModelStorage:
self._models_dir,
ModelCategory.BUILTIN.value,
model_id,
- MODEL_CONFIG_FILE_IN_JSON,
+ CONFIG_JSON,
)
if os.path.exists(config_path):
with open(config_path, "r", encoding="utf-8") as f:
@@ -218,15 +223,17 @@ class ModelStorage:
def _process_user_defined_model_directory(self, model_dir: str, model_id:
str):
"""Handling the discovery logic for a user-defined model directory."""
- config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON)
+ config_path = os.path.join(model_dir, CONFIG_JSON)
model_type = ""
auto_map = {}
pipeline_cls = ""
+ hub_mixin_cls = ""
if os.path.exists(config_path):
config = load_model_config_in_json(config_path)
model_type = config.get("model_type", "")
auto_map = config.get("auto_map", None)
pipeline_cls = config.get("pipeline_cls", "")
+ hub_mixin_cls = config.get("hub_mixin_cls", "")
model_info = ModelInfo(
model_id=model_id,
model_type=model_type,
@@ -234,6 +241,7 @@ class ModelStorage:
state=ModelStates.ACTIVE,
pipeline_cls=pipeline_cls,
auto_map=auto_map,
+ hub_mixin_cls=hub_mixin_cls,
transformers_registered=False, # Lazy registration
)
with self._lock_pool.get_lock(model_id).write_lock():
@@ -284,6 +292,7 @@ class ModelStorage:
model_type = config.get("model_type", "")
auto_map = config.get("auto_map")
pipeline_cls = config.get("pipeline_cls", "")
+ hub_mixin_cls = config.get("hub_mixin_cls", "")
with self._lock_pool.get_lock(model_id).write_lock():
model_info = ModelInfo(
@@ -293,6 +302,7 @@ class ModelStorage:
state=ModelStates.ACTIVE,
pipeline_cls=pipeline_cls,
auto_map=auto_map,
+ hub_mixin_cls=hub_mixin_cls,
transformers_registered=False, # Register later
)
self._models[ModelCategory.USER_DEFINED.value][model_id] =
model_info
@@ -308,6 +318,17 @@ class ModelStorage:
f"Failed to register Transformers model {model_id},
because {e}"
)
raise e
+ elif hub_mixin_cls:
+ # PyTorchModelHubMixin model: immediately register
+ try:
+ if self._register_hub_mixin_model(model_info):
+ model_info.transformers_registered = True
+ except Exception as e:
+ model_info.state = ModelStates.INACTIVE
+ logger.error(
+ f"Failed to register HubMixin model {model_id},
because {e}"
+ )
+ raise e
else:
# Other type models: only log
self._register_other_model(model_info)
@@ -321,6 +342,7 @@ class ModelStorage:
True if registration is successful
Raises:
Exception: Transformers internal exception if registration fails
+ ValueError: If class is invalid
"""
auto_map = model_info.auto_map
if not auto_map:
@@ -338,6 +360,14 @@ class ModelStorage:
config_class = import_class_from_path(
model_info.model_id, auto_config_path
)
+ # Validate config_class is a subclass of PretrainedConfig
+ if not (
+ isinstance(config_class, type)
+ and issubclass(config_class, PretrainedConfig)
+ ):
+ raise ValueError(
+ f"AutoConfig class '{auto_config_path}' must be a
subclass of PretrainedConfig"
+ )
AutoConfig.register(model_info.model_type, config_class)
logger.info(
f"Registered AutoConfig: {model_info.model_type} ->
{auto_config_path}"
@@ -346,6 +376,14 @@ class ModelStorage:
model_class = import_class_from_path(
model_info.model_id, auto_model_path
)
+ # Validate model_class is a subclass of PreTrainedModel
+ if not (
+ isinstance(model_class, type)
+ and issubclass(model_class, PreTrainedModel)
+ ):
+ raise ValueError(
+ f"AutoModelForCausalLM class '{auto_model_path}' must
be a subclass of PreTrainedModel"
+ )
AutoModelForCausalLM.register(config_class, model_class)
logger.info(
f"Registered AutoModelForCausalLM: {config_class.__name__}
-> {auto_model_path}"
@@ -357,6 +395,48 @@ class ModelStorage:
)
raise e
+ def _register_hub_mixin_model(self, model_info: ModelInfo) -> bool:
+ """
+ Register PyTorchModelHubMixin model (internal method).
+ For now, just validate the class.
+
+ Returns:
+ True if registration is successful
+ Raises:
+ ValueError: If class is invalid
+ Exception: For other errors
+ """
+ hub_mixin_cls = model_info.hub_mixin_cls
+ if not hub_mixin_cls:
+ return False
+
+ try:
+ model_path = os.path.join(
+ self._models_dir, model_info.category.value,
model_info.model_id
+ )
+ module_parent = str(Path(model_path).parent.absolute())
+ with temporary_sys_path(module_parent):
+ model_class = import_class_from_path(model_info.model_id,
hub_mixin_cls)
+
+ # Validate that the class inherits from PyTorchModelHubMixin
+ if not issubclass(model_class, PyTorchModelHubMixin):
+ raise ValueError(
+ f"Class '{model_class}' does not inherit from "
+ "PyTorchModelHubMixin."
+ )
+
+ logger.info(
+ f"Registered PyTorchModelHubMixin model: "
+ f"{model_info.model_id} -> {hub_mixin_cls}"
+ )
+ return True
+
+ except Exception as e:
+ logger.warning(
+ f"Failed to register PyTorchModelHubMixin model
{model_info.model_id}: {e}."
+ )
+ raise e
+
def _register_other_model(self, model_info: ModelInfo):
"""Register other type models (non-Transformers models)"""
logger.info(
@@ -526,7 +606,7 @@ class ModelStorage:
return self._models[category.value].get(model_id)
else:
# Category not specified, need to traverse all dictionaries, use
global lock
- with self._lock_pool.get_lock("").read_lock():
+ with self._lock_pool.get_lock(model_id).read_lock():
for category_dict in self._models.values():
if model_id in category_dict:
return category_dict[model_id]
diff --git
a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/pipeline_moirai2.py
b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/pipeline_moirai2.py
index fe2fb632362..666c3063df3 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/pipeline_moirai2.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/moirai2/pipeline_moirai2.py
@@ -30,7 +30,7 @@ class Moirai2Pipeline(ForecastPipeline):
def __init__(self, model_info, **model_kwargs):
super().__init__(model_info, **model_kwargs)
- def preprocess(self, inputs, **infer_kwargs):
+ def _preprocess(self, inputs, **infer_kwargs):
"""
Preprocess input data for moirai2.
@@ -48,7 +48,6 @@ class Moirai2Pipeline(ForecastPipeline):
list of dict
Processed inputs compatible with moirai2 format (time, features).
"""
- super().preprocess(inputs, **infer_kwargs)
# Moirai2.predict() expects past_target in (time, features) format
processed_inputs = []
for item in inputs:
@@ -141,7 +140,7 @@ class Moirai2Pipeline(ForecastPipeline):
f"Model must be an instance of Moirai2ForPrediction, got
{type(self.model)}"
)
- def postprocess(
+ def _postprocess(
self, outputs: list[torch.Tensor], **infer_kwargs
) -> list[torch.Tensor]:
"""
@@ -165,5 +164,4 @@ class Moirai2Pipeline(ForecastPipeline):
else:
# If no quantiles, get the mean
outputs_list.append(output.mean(dim=1))
- super().postprocess(outputs_list, **infer_kwargs)
return outputs_list
diff --git
a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py
b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py
index 12b2668543e..a528ce0ffc1 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py
@@ -33,7 +33,7 @@ class SktimePipeline(ForecastPipeline):
model_kwargs.pop("device", None) # sktime models run on CPU
super().__init__(model_info, **model_kwargs)
- def preprocess(
+ def _preprocess(
self,
inputs: list[dict[str, dict[str, torch.Tensor] | torch.Tensor]],
**infer_kwargs,
@@ -49,8 +49,6 @@ class SktimePipeline(ForecastPipeline):
"""
model_id = self.model_info.model_id
- inputs = super().preprocess(inputs, **infer_kwargs)
-
# Here, we assume element in list has same history_length,
# otherwise, the model cannot proceed
if inputs[0].get("past_covariates", None) or inputs[0].get(
@@ -96,7 +94,7 @@ class SktimePipeline(ForecastPipeline):
return outputs
- def postprocess(self, outputs: np.ndarray, **infer_kwargs) ->
list[torch.Tensor]:
+ def _postprocess(self, outputs: np.ndarray, **infer_kwargs) ->
list[torch.Tensor]:
"""
Postprocess the model's outputs.
@@ -111,5 +109,4 @@ class SktimePipeline(ForecastPipeline):
# Transform outputs into a 2D-tensor: [batch_size, output_length]
outputs = torch.from_numpy(outputs).float()
outputs = [outputs[i].unsqueeze(0) for i in range(outputs.size(0))]
- outputs = super().postprocess(outputs, **infer_kwargs)
return outputs
diff --git
a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
index 8aa9b175169..8e4ffefe316 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
@@ -30,7 +30,7 @@ class SundialPipeline(ForecastPipeline):
def __init__(self, model_info: ModelInfo, **model_kwargs):
super().__init__(model_info, **model_kwargs)
- def preprocess(self, inputs, **infer_kwargs) -> torch.Tensor:
+ def _preprocess(self, inputs, **infer_kwargs) -> torch.Tensor:
"""
Preprocess the input data by converting it to a 2D tensor (Sundial
only supports 2D inputs).
@@ -48,7 +48,6 @@ class SundialPipeline(ForecastPipeline):
(i.e., when inputs.shape[1] != 1).
"""
model_id = self.model_info.model_id
- inputs = super().preprocess(inputs, **infer_kwargs)
# Here, we assume element in list has same history_length,
# otherwise, the model cannot proceed
if inputs[0].get("past_covariates", None) or inputs[0].get(
@@ -93,7 +92,7 @@ class SundialPipeline(ForecastPipeline):
)
return outputs
- def postprocess(self, outputs: torch.Tensor, **infer_kwargs) ->
list[torch.Tensor]:
+ def _postprocess(self, outputs: torch.Tensor, **infer_kwargs) ->
list[torch.Tensor]:
"""
Postprocess the model's output by averaging across the num_samples
dimension and
expanding the dimensions to match the expected shape.
@@ -107,5 +106,4 @@ class SundialPipeline(ForecastPipeline):
"""
outputs = outputs.mean(dim=1).unsqueeze(1)
outputs = [outputs[i] for i in range(outputs.size(0))]
- outputs = super().postprocess(outputs, **infer_kwargs)
return outputs
diff --git
a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
index 213e6102c8b..3b7957259d6 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
@@ -30,7 +30,7 @@ class TimerPipeline(ForecastPipeline):
def __init__(self, model_info: ModelInfo, **model_kwargs):
super().__init__(model_info, **model_kwargs)
- def preprocess(self, inputs, **infer_kwargs) -> torch.Tensor:
+ def _preprocess(self, inputs, **infer_kwargs) -> torch.Tensor:
"""
Preprocess the input data by converting it to a 2D tensor (Timer-XL
only supports 2D inputs).
@@ -48,7 +48,6 @@ class TimerPipeline(ForecastPipeline):
(i.e., when inputs.shape[1] != 1).
"""
model_id = self.model_info.model_id
- inputs = super().preprocess(inputs, **infer_kwargs)
# Here, we assume element in list has same history_length,
# otherwise, the model cannot proceed
if inputs[0].get("past_covariates", None) or inputs[0].get(
@@ -86,7 +85,7 @@ class TimerPipeline(ForecastPipeline):
outputs = self.model.generate(inputs, max_new_tokens=output_length,
revin=revin)
return outputs
- def postprocess(self, outputs: torch.Tensor, **infer_kwargs) ->
list[torch.Tensor]:
+ def _postprocess(self, outputs: torch.Tensor, **infer_kwargs) ->
list[torch.Tensor]:
"""
Postprocess the model's output by expanding its dimensions to match
the expected shape.
@@ -98,5 +97,4 @@ class TimerPipeline(ForecastPipeline):
list of torch.Tensor: A list of 2D tensors with shape
[target_count(1), output_length].
"""
outputs = [outputs[i].unsqueeze(0) for i in range(outputs.size(0))]
- outputs = super().postprocess(outputs, **infer_kwargs)
return outputs
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py
b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py
index 815232c52b0..815f1076101 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py
@@ -23,17 +23,31 @@ import shutil
import sys
from contextlib import contextmanager
from pathlib import Path
-from typing import Dict, Tuple
-
-from huggingface_hub import snapshot_download
+from typing import Any, Dict, Optional, Tuple, Type
+
+from huggingface_hub import PyTorchModelHubMixin, snapshot_download
+from transformers import (
+ AutoConfig,
+ AutoModelForCausalLM,
+ AutoModelForNextSentencePrediction,
+ AutoModelForSeq2SeqLM,
+ AutoModelForSequenceClassification,
+ AutoModelForTimeSeriesPrediction,
+ AutoModelForTokenClassification,
+ PretrainedConfig,
+ PreTrainedModel,
+)
+from iotdb.ainode.core.config import AINodeDescriptor
from iotdb.ainode.core.exception import InvalidModelUriException
from iotdb.ainode.core.log import Logger
from iotdb.ainode.core.model.model_constants import (
- MODEL_CONFIG_FILE_IN_JSON,
- MODEL_WEIGHTS_FILE_IN_SAFETENSORS,
+ CONFIG_JSON,
+ MODEL_SAFETENSORS,
+ ModelCategory,
UriType,
)
+from iotdb.ainode.core.model.model_info import ModelInfo
logger = Logger()
@@ -69,11 +83,95 @@ def load_model_config_in_json(config_path: str) -> Dict:
return json.load(f)
+def get_model_path(model_info: ModelInfo) -> str:
+ return os.path.join(
+ os.getcwd(),
+ AINodeDescriptor().get_config().get_ain_models_dir(),
+ model_info.category.value,
+ model_info.model_id,
+ )
+
+
+def get_model_and_config_by_native_code(
+ model_info: ModelInfo,
+) -> Tuple[
+ Optional[Type[PreTrainedModel | PyTorchModelHubMixin]],
Optional[PretrainedConfig]
+]:
+ """
+ Return model_class and config_instance (optionally) from the model's
native code.
+ """
+
+ # Try to get model str and config str.
+ config_str = None
+ if model_info.auto_map:
+ config_str = model_info.auto_map.get("AutoConfig", "")
+ model_str = model_info.auto_map.get("AutoModelForCausalLM", "")
+ if not config_str or not model_str:
+ return None, None
+ elif model_info.hub_mixin_cls:
+ model_str = model_info.hub_mixin_cls
+ else:
+ return None, None
+
+ model_path = get_model_path(model_info)
+
+ # Try to import model and config class.
+ config_class, config_instance = None, None
+ model_class = None
+ if model_info.category == ModelCategory.BUILTIN:
+ module_name = (
+ AINodeDescriptor().get_config().get_ain_models_builtin_dir()
+ + "."
+ + model_info.model_id
+ )
+ if config_str:
+ # For Transformer models
+ config_class = import_class_from_path(module_name, config_str)
+ config_instance = config_class.from_pretrained(model_path)
+ model_class = import_class_from_path(module_name, model_str)
+ else:
+ module_parent = str(Path(model_path).parent.absolute())
+ with temporary_sys_path(module_parent):
+ if config_str:
+ # For Transformer models
+ config_class = import_class_from_path(model_info.model_id,
config_str)
+ config_instance = config_class.from_pretrained(model_path)
+ model_class = import_class_from_path(model_info.model_id,
model_str)
+
+ return model_class, config_instance
+
+
+def get_model_and_config_by_auto_class(model_path: str) -> Tuple[type, Any]:
+ """Return model_class and config_instance from Huggingface Transformers's
AutoClass."""
+ config_instance = AutoConfig.from_pretrained(model_path)
+
+ if type(config_instance) in
AutoModelForTimeSeriesPrediction._model_mapping.keys():
+ model_class = AutoModelForTimeSeriesPrediction
+ elif (
+ type(config_instance)
+ in AutoModelForNextSentencePrediction._model_mapping.keys()
+ ):
+ model_class = AutoModelForNextSentencePrediction
+ elif type(config_instance) in AutoModelForSeq2SeqLM._model_mapping.keys():
+ model_class = AutoModelForSeq2SeqLM
+ elif (
+ type(config_instance)
+ in AutoModelForSequenceClassification._model_mapping.keys()
+ ):
+ model_class = AutoModelForSequenceClassification
+ elif type(config_instance) in
AutoModelForTokenClassification._model_mapping.keys():
+ model_class = AutoModelForTokenClassification
+ else:
+ model_class = AutoModelForCausalLM
+
+ return model_class, config_instance
+
+
def validate_model_files(model_dir: str) -> Tuple[str, str]:
"""Validate model files exist, return config and weights file paths"""
- config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON)
- weights_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS)
+ config_path = os.path.join(model_dir, CONFIG_JSON)
+ weights_path = os.path.join(model_dir, MODEL_SAFETENSORS)
if not os.path.exists(config_path):
raise InvalidModelUriException(
@@ -116,9 +214,9 @@ def _fetch_model_from_local(source_path: str, storage_path:
str):
if not source_dir.is_dir():
raise InvalidModelUriException(f"Source path is not a directory:
{source_path}")
storage_dir = Path(storage_path)
- for file in source_dir.iterdir():
- if file.is_file():
- shutil.copy2(file, storage_dir / file.name)
+ if storage_dir.exists():
+ shutil.rmtree(storage_dir)
+ shutil.copytree(source_dir, storage_dir)
def _fetch_model_from_hf_repo(repo_id: str, storage_path: str):
diff --git a/iotdb-core/ainode/iotdb/ainode/core/util/decorator.py
b/iotdb-core/ainode/iotdb/ainode/core/util/decorator.py
index 5a84c3d6bb2..2abe08ca9dc 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/util/decorator.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/util/decorator.py
@@ -15,15 +15,21 @@
# specific language governing permissions and limitations
# under the License.
#
+import threading
from functools import wraps
def singleton(cls):
+ """Thread-safe singleton decorator."""
instances = {}
+ lock = threading.Lock()
+ @wraps(cls)
def get_instance(*args, **kwargs):
if cls not in instances:
- instances[cls] = cls(*args, **kwargs)
+ with lock:
+ if cls not in instances:
+ instances[cls] = cls(*args, **kwargs)
return instances[cls]
return get_instance
diff --git a/iotdb-core/ainode/iotdb/ainode/core/util/serde.py
b/iotdb-core/ainode/iotdb/ainode/core/util/serde.py
index a61032ba26f..f03d323f1b8 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/util/serde.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/util/serde.py
@@ -75,10 +75,17 @@ def convert_tsblock_to_tensor(tsblock_data: bytes):
# Convert DataFrame to TsBlock in binary, input shouldn't contain time column.
# Maybe contain multiple value columns.
def convert_tensor_to_tsblock(data_tensor: torch.Tensor):
- data_frame = pd.DataFrame(data_tensor).T
- data_shape = data_frame.shape
- value_column_size = data_shape[1]
- position_count = data_shape[0]
+ # Ensure the tensor is 2D with size [target_count, sequence_length]
+ if data_tensor.dim() == 0:
+ data_tensor = data_tensor.unsqueeze(0).unsqueeze(0)
+ elif data_tensor.dim() == 1:
+ data_tensor = data_tensor.unsqueeze(0)
+
+ # Transpose the tensor to [sequence_length, target_count]
+ data_frame = pd.DataFrame(data_tensor.cpu()).T
+ # sequence_length, target_count
+ position_count, value_column_size = data_frame.shape[0],
data_frame.shape[1]
+
keys = data_frame.keys()
binary = value_column_size.to_bytes(4, byteorder="big")
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
index 1a0fdd4bb96..f406c16c085 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
@@ -62,8 +62,6 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
-import java.util.HashMap;
-import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
@@ -74,6 +72,8 @@ import java.util.Set;
import java.util.stream.Collectors;
 import static org.apache.iotdb.commons.udf.builtin.relational.tvf.WindowTVFUtils.findColumnIndex;
+import static org.apache.iotdb.db.queryengine.plan.relational.function.tvf.TableFunctionUtils.checkType;
+import static org.apache.iotdb.db.queryengine.plan.relational.function.tvf.TableFunctionUtils.parseOptions;
 import static org.apache.iotdb.db.queryengine.plan.relational.utils.ResultColumnAppender.createResultColumnAppender;
public class ForecastTableFunction implements TableFunction {
@@ -201,17 +201,6 @@ public class ForecastTableFunction implements
TableFunction {
protected static final String DEFAULT_OPTIONS = "";
protected static final int MAX_INPUT_LENGTH = 2880;
- private static final String INVALID_OPTIONS_FORMAT = "Invalid options: %s";
-
- private static final Set<Type> ALLOWED_INPUT_TYPES = new HashSet<>();
-
- static {
- ALLOWED_INPUT_TYPES.add(Type.INT32);
- ALLOWED_INPUT_TYPES.add(Type.INT64);
- ALLOWED_INPUT_TYPES.add(Type.FLOAT);
- ALLOWED_INPUT_TYPES.add(Type.DOUBLE);
- }
-
@Override
public List<ParameterSpecification> getArgumentsSpecifications() {
return Arrays.asList(
@@ -367,38 +356,6 @@ public class ForecastTableFunction implements
TableFunction {
};
}
- // only allow for INT32, INT64, FLOAT, DOUBLE
- public void checkType(Type type, String columnName) {
- if (!ALLOWED_INPUT_TYPES.contains(type)) {
- throw new SemanticException(
- String.format(
-              "The type of the column [%s] is [%s], only INT32, INT64, FLOAT, DOUBLE is allowed",
- columnName, type));
- }
- }
-
- public static Map<String, String> parseOptions(String options) {
- if (options.isEmpty()) {
- return Collections.emptyMap();
- }
- String[] optionArray = options.split(",");
- if (optionArray.length == 0) {
-      throw new SemanticException(String.format(INVALID_OPTIONS_FORMAT, options));
- }
-
- Map<String, String> optionsMap = new HashMap<>(optionArray.length);
- for (String option : optionArray) {
- int index = option.indexOf('=');
- if (index == -1 || index == option.length() - 1) {
-        throw new SemanticException(String.format(INVALID_OPTIONS_FORMAT, option));
- }
- String key = option.substring(0, index).trim();
- String value = option.substring(index + 1).trim();
- optionsMap.put(key, value);
- }
- return optionsMap;
- }
-
protected static class ForecastDataProcessor implements
TableFunctionDataProcessor {
protected static final TsBlockSerde SERDE = new TsBlockSerde();
@@ -474,7 +431,7 @@ public class ForecastTableFunction implements TableFunction
{
int columnSize = properColumnBuilders.size();
// sort inputRecords in ascending order by timestamp
- inputRecords.sort(Comparator.comparingLong(record -> record.getLong(0)));
+ inputRecords.sort(Comparator.comparingLong(r -> r.getLong(0)));
// time column
long inputStartTime = inputRecords.getFirst().getLong(0);
@@ -514,6 +471,7 @@ public class ForecastTableFunction implements TableFunction
{
TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
}
+ // construct result column
for (int columnIndex = 1, size = predicatedResult.getValueColumnCount();
columnIndex <= size;
columnIndex++) {
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/TableFunctionUtils.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/TableFunctionUtils.java
new file mode 100644
index 00000000000..499f133c437
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/TableFunctionUtils.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.function.tvf;
+
+import org.apache.iotdb.db.exception.sql.SemanticException;
+import org.apache.iotdb.udf.api.type.Type;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+public class TableFunctionUtils {
+ private static final String INVALID_OPTIONS_FORMAT = "Invalid options: %s";
+
+ public static Map<String, String> parseOptions(String options) {
+ if (options.isEmpty()) {
+ return Collections.emptyMap();
+ }
+ String[] optionArray = options.split(",");
+ if (optionArray.length == 0) {
+ throw new SemanticException(String.format(INVALID_OPTIONS_FORMAT,
options));
+ }
+
+ Map<String, String> optionsMap = new HashMap<>(optionArray.length);
+ for (String option : optionArray) {
+ int index = option.indexOf('=');
+ if (index == -1 || index == option.length() - 1) {
+ throw new SemanticException(String.format(INVALID_OPTIONS_FORMAT,
option));
+ }
+ String key = option.substring(0, index).trim();
+ String value = option.substring(index + 1).trim();
+ optionsMap.put(key, value);
+ }
+ return optionsMap;
+ }
+
+ private static final Set<Type> ALLOWED_INPUT_TYPES = new HashSet<>();
+
+ static {
+ ALLOWED_INPUT_TYPES.add(Type.INT32);
+ ALLOWED_INPUT_TYPES.add(Type.INT64);
+ ALLOWED_INPUT_TYPES.add(Type.FLOAT);
+ ALLOWED_INPUT_TYPES.add(Type.DOUBLE);
+ }
+
+ // only allow for INT32, INT64, FLOAT, DOUBLE
+ public static void checkType(Type type, String columnName) {
+ if (!ALLOWED_INPUT_TYPES.contains(type)) {
+ throw new SemanticException(
+ String.format(
+ "The type of the column [%s] is [%s], only INT32, INT64, FLOAT,
DOUBLE is allowed",
+ columnName, type));
+ }
+ }
+}