ycycse commented on code in PR #15756:
URL: https://github.com/apache/iotdb/pull/15756#discussion_r2154164562
##########
iotdb-core/ainode/ainode/core/manager/inference_manager.py:
##########
@@ -50,32 +50,57 @@ def infer(self, full_data, **kwargs):
pass
-# [IoTDB] full data deserialized from iotdb is composed of [timestampList,
valueList, length],
-# we only get valueList currently.
+# [IoTDB] Full data deserialized from IoTDB is composed of [timestampList,
valueList, length],
+# currently we only use valueList.
class TimerXLStrategy(InferenceStrategy):
- def infer(self, full_data, predict_length=96, **_):
+ def infer(self, full_data, predict_length=96, **kwargs):
data = full_data[1][0]
if data.dtype.byteorder not in ("=", "|"):
data = data.byteswap().newbyteorder()
seqs = torch.tensor(data).unsqueeze(0).float()
- # TODO: unify model inference input
- output = self.model.generate(seqs, max_new_tokens=predict_length,
revin=True)
- df = pd.DataFrame(output[0])
- return convert_to_binary(df)
+
+ # Inference parameters for IoTDB models
+ revin = kwargs.get("revin", True)
+ max_tokens = kwargs.get("max_new_tokens", predict_length)
+
+ logger.debug(
+ f"TimerXL inference: input_shape={seqs.shape},
predict_length={max_tokens}"
+ )
+
+ try:
+ output = self.model.generate(seqs, max_new_tokens=max_tokens,
revin=revin)
Review Comment:
Why don't you use `predict_length`? It ensures the interface is
user-friendly. If we want to change it, we should not keep the parameter.
##########
iotdb-core/ainode/ainode/core/manager/inference_manager.py:
##########
@@ -50,32 +50,57 @@ def infer(self, full_data, **kwargs):
pass
-# [IoTDB] full data deserialized from iotdb is composed of [timestampList,
valueList, length],
-# we only get valueList currently.
+# [IoTDB] Full data deserialized from IoTDB is composed of [timestampList,
valueList, length],
+# currently we only use valueList.
class TimerXLStrategy(InferenceStrategy):
- def infer(self, full_data, predict_length=96, **_):
+ def infer(self, full_data, predict_length=96, **kwargs):
data = full_data[1][0]
if data.dtype.byteorder not in ("=", "|"):
data = data.byteswap().newbyteorder()
seqs = torch.tensor(data).unsqueeze(0).float()
- # TODO: unify model inference input
- output = self.model.generate(seqs, max_new_tokens=predict_length,
revin=True)
- df = pd.DataFrame(output[0])
- return convert_to_binary(df)
+
+ # Inference parameters for IoTDB models
+ revin = kwargs.get("revin", True)
+ max_tokens = kwargs.get("max_new_tokens", predict_length)
+
+ logger.debug(
+ f"TimerXL inference: input_shape={seqs.shape},
predict_length={max_tokens}"
+ )
+
+ try:
+ output = self.model.generate(seqs, max_new_tokens=max_tokens,
revin=revin)
+ df = pd.DataFrame(output[0])
+ return convert_to_binary(df)
+ except Exception as e:
+ logger.error(f"TimerXL inference failed: {e}")
+ raise InferenceModelInternalError(f"TimerXL inference error:
{str(e)}")
class SundialStrategy(InferenceStrategy):
- def infer(self, full_data, predict_length=96, **_):
+ def infer(self, full_data, predict_length=96, **kwargs):
data = full_data[1][0]
if data.dtype.byteorder not in ("=", "|"):
data = data.byteswap().newbyteorder()
seqs = torch.tensor(data).unsqueeze(0).float()
- # TODO: unify model inference input
- output = self.model.generate(
- seqs, max_new_tokens=predict_length, num_samples=10, revin=True
+
+ # Inference parameters for IoTDB models
+ revin = kwargs.get("revin", True)
+ max_tokens = kwargs.get("max_new_tokens", predict_length)
+ num_samples = kwargs.get("num_samples", 10)
+
+ logger.debug(
+ f"Sundial inference: input_shape={seqs.shape},
predict_length={max_tokens}, num_samples={num_samples}"
)
- df = pd.DataFrame(output[0].mean(dim=0))
- return convert_to_binary(df)
+
+ try:
+ output = self.model.generate(
+ seqs, max_new_tokens=max_tokens, num_samples=num_samples,
revin=revin
Review Comment:
Same as above. Maybe we can use `predict_length` directly.
##########
iotdb-core/ainode/ainode/core/manager/cluster_manager.py:
##########
@@ -15,32 +15,144 @@
# specific language governing permissions and limitations
# under the License.
#
+import threading
+import time
+
import psutil
+from ainode.core.config import AINodeDescriptor
+from ainode.core.log import Logger
from ainode.thrift.ainode.ttypes import TAIHeartbeatReq, TAIHeartbeatResp
from ainode.thrift.common.ttypes import TLoadSample
+logger = Logger()
+
class ClusterManager:
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls):
+ if cls._instance is None:
+ with cls._lock:
+ if cls._instance is None:
+ cls._instance = super().__new__(cls)
+ cls._instance._initialized = False
+ return cls._instance
+
+ def __init__(self):
+ if not self._initialized:
+ self._node_status = "STARTING"
+ self._start_time = time.time()
+ self._last_heartbeat = 0
+ self._heartbeat_count = 0
+ self._initialized = True
+
@staticmethod
def get_heart_beat(req: TAIHeartbeatReq) -> TAIHeartbeatResp:
+ """
+ Enhanced heartbeat response with additional node information
+ """
+ instance = ClusterManager()
+ instance._last_heartbeat = time.time()
+ instance._heartbeat_count += 1
+ instance._node_status = "RUNNING"
+
+ logger.debug(
+ f"Heartbeat request #{instance._heartbeat_count},
needSamplingLoad: {req.needSamplingLoad}"
+ )
+
if req.needSamplingLoad:
- cpu_percent = psutil.cpu_percent(interval=1)
- memory_percent = psutil.virtual_memory().percent
- disk_usage = psutil.disk_usage("/")
- disk_free = disk_usage.free
- load_sample = TLoadSample(
- cpuUsageRate=cpu_percent,
- memoryUsageRate=memory_percent,
- diskUsageRate=disk_usage.percent,
- freeDiskSpace=disk_free / 1024 / 1024 / 1024,
- )
- return TAIHeartbeatResp(
- heartbeatTimestamp=req.heartbeatTimestamp,
- status="Running",
- loadSample=load_sample,
- )
+ try:
+ # System load metrics
+ cpu_percent = psutil.cpu_percent(interval=1)
+ memory_info = psutil.virtual_memory()
+ memory_percent = memory_info.percent
+ disk_usage = psutil.disk_usage("/")
+ disk_free = disk_usage.free
+
+ load_sample = TLoadSample(
+ cpuUsageRate=cpu_percent,
+ memoryUsageRate=memory_percent,
+ diskUsageRate=disk_usage.percent,
+ freeDiskSpace=disk_free / 1024 / 1024 / 1024, # GB
+ )
+
+ logger.debug(
+ f"System load - CPU: {cpu_percent:.1f}%, "
+ f"Memory: {memory_percent:.1f}%, "
+ f"Disk Usage: {disk_usage.percent:.1f}%, "
+ f"Free Space: {disk_free / 1024 / 1024 / 1024:.1f}GB"
+ )
+
+ return TAIHeartbeatResp(
+ heartbeatTimestamp=req.heartbeatTimestamp,
+ status=instance._node_status,
+ loadSample=load_sample,
+ )
+ except Exception as e:
+ logger.error(f"Failed to retrieve system load metrics: {e}")
+ # Return basic heartbeat if system load cannot be retrieved
+ return TAIHeartbeatResp(
+ heartbeatTimestamp=req.heartbeatTimestamp,
+ status="RUNNING_WITH_ERROR",
+ )
else:
return TAIHeartbeatResp(
- heartbeatTimestamp=req.heartbeatTimestamp, status="Running"
+ heartbeatTimestamp=req.heartbeatTimestamp,
status=instance._node_status
)
+
+ def get_node_info(self) -> dict:
+ """Retrieve detailed node information"""
+ try:
+ config = AINodeDescriptor().get_config()
+ uptime = time.time() - self._start_time
+
+ return {
+ "node_id": config.get_ainode_id(),
+ "cluster_name": config.get_cluster_name(),
+ "status": self._node_status,
+ "uptime_seconds": uptime,
+ "heartbeat_count": self._heartbeat_count,
+ "last_heartbeat": self._last_heartbeat,
+ "rpc_address": config.get_ain_inference_rpc_address(),
+ "rpc_port": config.get_ain_inference_rpc_port(),
+ "version": config.get_version_info(),
+ "build": config.get_build_info(),
+ }
+ except Exception as e:
+ logger.error(f"Failed to retrieve node information: {e}")
+ return {"error": str(e)}
+
+ def set_node_status(self, status: str):
+ """Set the status of the current node"""
+ self._node_status = status
+ logger.info(f"Node status updated to: {status}")
+
+ def get_system_metrics(self) -> dict:
+ """Retrieve system-level metrics"""
+ try:
+ cpu_percent = psutil.cpu_percent(interval=1)
+ memory = psutil.virtual_memory()
+ disk = psutil.disk_usage("/")
+
+ return {
+ "cpu": {
+ "usage_percent": cpu_percent,
+ "count": psutil.cpu_count(),
+ },
+ "memory": {
+ "total_gb": memory.total / 1024 / 1024 / 1024,
+ "available_gb": memory.available / 1024 / 1024 / 1024,
+ "usage_percent": memory.percent,
+ },
+ "disk": {
+ "total_gb": disk.total / 1024 / 1024 / 1024,
+ "free_gb": disk.free / 1024 / 1024 / 1024,
+ "usage_percent": disk.percent,
+ },
+ "timestamp": time.time(),
+ }
+ except Exception as e:
+ logger.error(f"Failed to retrieve system metrics: {e}")
+ return {"error": str(e)}
Review Comment:
What do these methods do?
##########
iotdb-core/ainode/ainode/core/model/model_factory.py:
##########
@@ -41,6 +53,50 @@
logger = Logger()
+def _detect_model_format(base_path: str) -> tuple:
+ """
+ Detect model format: supports both IoTDB and legacy formats
+
+ Args:
+ base_path: Model directory path or network URI
+
+ Returns:
+ (format_type, config_file, weight_file): Format type and corresponding
file names
+ """
+ base_path = (
+ Path(base_path)
+ if not base_path.startswith(("http://", "https://"))
+ else base_path
+ )
+
+ # Check IoTDB format first (higher priority)
+ for config_file in IOTDB_CONFIG_FILES:
+ if isinstance(base_path, Path):
+ config_path = base_path / config_file
+ if config_path.exists():
+ # Look for corresponding weight file
+ for weight_file in WEIGHT_FORMAT_PRIORITY:
+ weight_path = base_path / weight_file
+ if weight_path.exists():
+ logger.info(
+ f"IoTDB format detected: {config_file} +
{weight_file}"
+ )
+ return "iotdb", config_file, weight_file
+ else:
+ # Skip detection for remote paths; will be handled during download
+ pass
+
+ # Check legacy format
+ if isinstance(base_path, Path):
+ legacy_config = base_path / DEFAULT_CONFIG_FILE_NAME
+ legacy_model = base_path / DEFAULT_MODEL_FILE_NAME
Review Comment:
The same
##########
iotdb-core/ainode/ainode/core/util/serde.py:
##########
@@ -155,3 +156,44 @@ def get_data_type_byte_from_str(value):
return TSDataType.DOUBLE.value
elif value == "text":
return TSDataType.TEXT.value
+
+
+def convert_iotdb_data_to_binary(
+ data_frame: pd.DataFrame, model_format: str = "legacy"
+):
+ """
+ Enhanced binary conversion with IoTDB model format support
+ """
+ try:
+ if model_format == "iotdb":
+ # Enhanced processing for IoTDB models
+ logger.debug(f"Converting IoTDB format data:
shape={data_frame.shape}")
+
+ # Use existing conversion logic
+ return convert_to_binary(data_frame)
+
+ except Exception as e:
+ logger.error(f"Binary conversion failed for format {model_format}:
{e}")
+ raise BadConfigValueError("data_conversion", model_format, str(e))
Review Comment:
What is this method for?
##########
iotdb-core/ainode/ainode/core/handler.py:
##########
@@ -17,6 +17,15 @@
#
from ainode.core.constant import TSStatusCode
+
+# only for test
+from ainode.core.exception import (
+ ConfigValidationError,
+ IoTDBModelError,
+ ModelFormatError,
+ ModelLoadingError,
+ WeightFileError,
+)
Review Comment:
Use unit tests / integration tests (UT/IT) for testing instead.
##########
iotdb-core/ainode/ainode/core/manager/inference_manager.py:
##########
@@ -50,32 +50,57 @@ def infer(self, full_data, **kwargs):
pass
-# [IoTDB] full data deserialized from iotdb is composed of [timestampList,
valueList, length],
-# we only get valueList currently.
+# [IoTDB] Full data deserialized from IoTDB is composed of [timestampList,
valueList, length],
+# currently we only use valueList.
class TimerXLStrategy(InferenceStrategy):
- def infer(self, full_data, predict_length=96, **_):
+ def infer(self, full_data, predict_length=96, **kwargs):
data = full_data[1][0]
if data.dtype.byteorder not in ("=", "|"):
data = data.byteswap().newbyteorder()
seqs = torch.tensor(data).unsqueeze(0).float()
- # TODO: unify model inference input
- output = self.model.generate(seqs, max_new_tokens=predict_length,
revin=True)
- df = pd.DataFrame(output[0])
- return convert_to_binary(df)
+
+ # Inference parameters for IoTDB models
+ revin = kwargs.get("revin", True)
+ max_tokens = kwargs.get("max_new_tokens", predict_length)
Review Comment:
Discuss with @CRZbulabula
##########
iotdb-core/ainode/ainode/core/handler.py:
##########
@@ -43,20 +52,141 @@ def __init__(self):
self._model_manager = ModelManager()
self._inference_manager =
InferenceManager(model_manager=self._model_manager)
+ # def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp:
+ # return self._model_manager.register_model(req)
+
+ # def deleteModel(self, req: TDeleteModelReq) -> TSStatus:
+ # return self._model_manager.delete_model(req)
+
+ # def inference(self, req: TInferenceReq) -> TInferenceResp:
+ # return self._inference_manager.inference(req)
+
+ # def forecast(self, req: TForecastReq) -> TSStatus:
+ # return self._inference_manager.forecast(req)
+
+ # def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
+ # return ClusterManager.get_heart_beat(req)
+
+ # def createTrainingTask(self, req: TTrainingReq) -> TSStatus:
+ # pass
+
+ # only for test
def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp:
- return self._model_manager.register_model(req)
+ """
+ register model
+ """
+ logger.info(f"Start register model: {req.modelId}, URI: {req.uri}")
+
+ try:
+ result = self._model_manager.register_model(req)
+ if result.status.code ==
TSStatusCode.SUCCESS_STATUS.get_status_code():
+ logger.info(f"Register model successfully: {req.modelId}")
+ else:
+ logger.warning(
+ f"Failed to register model: {req.modelId}, with status:
{result.status}"
+ )
+ return result
+
+ except (
+ ModelLoadingError,
+ ModelFormatError,
+ IoTDBModelError,
+ ConfigValidationError,
+ WeightFileError,
+ ) as e:
+ logger.error(
+ f"Failed to register model: error known: {req.modelId}, with
error: {e}"
+ )
+ from ainode.core.util.status import get_status
Review Comment:
Why put the import here? Would importing get_status at the top of the file take too much time?
##########
iotdb-core/ainode/ainode/core/manager/model_manager.py:
##########
@@ -112,6 +251,69 @@ def get_ckpt_path(self, model_id: str) -> str:
"""
return self.model_storage.get_ckpt_path(model_id)
+ def _validate_model_name(self, model_name: str) -> bool:
+ """
+ Validate the model name format
+
+ Args:
+ model_name: Name of the model
+
+ Returns:
+ Whether the name is valid
+ """
+ if not model_name:
+ return False
+
+ import re
+
+ pattern = r"^[a-zA-Z0-9_-]+$"
+
+ if re.match(pattern, model_name):
+ logger.debug(f"Model name validated: {model_name}")
+ return True
+ else:
+ logger.error(f"Illegal model name: {model_name}")
+ return False
+
+ def _update_model_status(self, model_id: str, status: str, message: str =
""):
+ """Update model status and notify ConfigNode"""
+ try:
+ with self._status_lock:
+ self._model_status_cache[model_id] = {
+ "status": status,
+ "message": message,
+ "timestamp": time.time(),
+ }
+
+ status_code_map = {"LOADING": 0, "ACTIVE": 1, "INACTIVE": 2,
"ERROR": 3}
Review Comment:
Define this as a constant in constant.py or somewhere similar.
##########
iotdb-core/ainode/ainode/core/manager/cluster_manager.py:
##########
@@ -15,32 +15,144 @@
# specific language governing permissions and limitations
# under the License.
#
+import threading
+import time
+
import psutil
+from ainode.core.config import AINodeDescriptor
+from ainode.core.log import Logger
from ainode.thrift.ainode.ttypes import TAIHeartbeatReq, TAIHeartbeatResp
from ainode.thrift.common.ttypes import TLoadSample
+logger = Logger()
+
class ClusterManager:
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls):
+ if cls._instance is None:
+ with cls._lock:
+ if cls._instance is None:
+ cls._instance = super().__new__(cls)
+ cls._instance._initialized = False
+ return cls._instance
Review Comment:
Use the `@singleton` decorator in util.py.
##########
iotdb-core/ainode/ainode/core/constant.py:
##########
@@ -53,8 +53,11 @@
TRIAL_ID_PREFIX = "__trial_"
DEFAULT_TRIAL_ID = TRIAL_ID_PREFIX + "0"
-DEFAULT_MODEL_FILE_NAME = "model.safetensors"
-DEFAULT_CONFIG_FILE_NAME = "config.json"
+# DEFAULT_MODEL_FILE_NAME = "model.safetensors"
+# DEFAULT_CONFIG_FILE_NAME = "config.json"
Review Comment:
Remove this.
##########
iotdb-core/ainode/ainode/core/model/config_parser.py:
##########
@@ -0,0 +1,317 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+from pathlib import Path
+from typing import Any, Dict, Union
+
+import yaml
+
+from ainode.core.log import Logger
+
+logger = Logger()
+
+
+def parse_config_file(config_path: Union[str, Path]) -> Dict[str, Any]:
+ """
+ Parse the configuration file, supporting JSON and YAML formats
+
+ Args:
+ config_path: Path to the configuration file
+
+ Returns:
+ Configuration dictionary
+ """
+ config_path = Path(config_path)
+
+ if not config_path.exists():
+ raise FileNotFoundError(f"Configuration file not found: {config_path}")
+
+ suffix = config_path.suffix.lower()
+
+ try:
+ with open(config_path, "r", encoding="utf-8") as f:
+ if suffix == ".json":
+ return json.load(f)
+ elif suffix in [".yaml", ".yml"]:
+ return yaml.safe_load(f)
+ else:
+ # Try JSON parsing
+ content = f.read()
+ try:
+ return json.loads(content)
+ except json.JSONDecodeError:
+ # Try YAML parsing
+ return yaml.safe_load(content)
+ except Exception as e:
+ logger.error(f"Failed to parse configuration file: {config_path},
Error: {e}")
+ raise
+
+
+def convert_iotdb_config_to_ainode_format(
+ iotdb_config: Dict[str, Any]
+) -> Dict[str, Any]:
+ """
+ Convert IoTDB configuration to AINode format (formerly thuTL_config)
+
+ Args:
+ iotdb_config: IoTDB configuration dictionary
+
+ Returns:
+ AINode format configuration dictionary
+ """
+ model_type = iotdb_config.get("model_type", "unknown")
+ input_length = iotdb_config.get("input_token_len", 96)
+ output_length = (
+ iotdb_config.get("output_token_lens", [96])[0]
+ if iotdb_config.get("output_token_lens")
+ else 96
+ )
Review Comment:
Discuss with @CRZbulabula
##########
iotdb-core/ainode/ainode/core/model/built_in_model_factory.py:
##########
@@ -528,8 +528,7 @@ def parse_attribute(
AINodeDescriptor().get_config().get_ain_models_dir(),
"weights",
"timerxl",
- "model.safetensors",
Review Comment:
Remove this.
##########
iotdb-core/ainode/ainode/core/constant.py:
##########
@@ -53,8 +53,11 @@
TRIAL_ID_PREFIX = "__trial_"
DEFAULT_TRIAL_ID = TRIAL_ID_PREFIX + "0"
-DEFAULT_MODEL_FILE_NAME = "model.safetensors"
-DEFAULT_CONFIG_FILE_NAME = "config.json"
+# DEFAULT_MODEL_FILE_NAME = "model.safetensors"
+# DEFAULT_CONFIG_FILE_NAME = "config.json"
+
+DEFAULT_MODEL_FILE_NAME = "model.pt" # change default file -> model.pt
+DEFAULT_CONFIG_FILE_NAME = "config.yaml" # change default config file ->
config.yaml
Review Comment:
There is no need to explain it.
##########
iotdb-core/ainode/ainode/core/handler.py:
##########
@@ -43,20 +52,141 @@ def __init__(self):
self._model_manager = ModelManager()
self._inference_manager =
InferenceManager(model_manager=self._model_manager)
+ # def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp:
+ # return self._model_manager.register_model(req)
+
+ # def deleteModel(self, req: TDeleteModelReq) -> TSStatus:
+ # return self._model_manager.delete_model(req)
+
+ # def inference(self, req: TInferenceReq) -> TInferenceResp:
+ # return self._inference_manager.inference(req)
+
+ # def forecast(self, req: TForecastReq) -> TSStatus:
+ # return self._inference_manager.forecast(req)
+
+ # def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
+ # return ClusterManager.get_heart_beat(req)
+
+ # def createTrainingTask(self, req: TTrainingReq) -> TSStatus:
+ # pass
Review Comment:
Don't keep this code if we don't need it anymore.
##########
iotdb-core/ainode/ainode/core/manager/model_manager.py:
##########
@@ -42,63 +50,194 @@
class ModelManager:
def __init__(self):
self.model_storage = ModelStorage()
+ self._model_status_cache = {} # Cache for model statuses
+ self._status_lock = threading.Lock()
Review Comment:
Use RWLock()
##########
iotdb-core/ainode/ainode/core/exception.py:
##########
@@ -75,10 +75,14 @@ def __init__(self, msg: str):
class InvalidUriError(_BaseError):
- def __init__(self, uri: str):
- self.message = "Invalid uri: {}, there are no {} or {} under this
uri.".format(
- uri, DEFAULT_MODEL_FILE_NAME, DEFAULT_CONFIG_FILE_NAME
- )
+ def __init__(self, uri: str, details: str = ""):
+ if details:
+ self.message = "Invalid uri: {}, {}".format(uri, details)
+ else:
+ # fix path error
+ self.message = "Invalid uri: {}, no valid model files found
(checked both IoTDB and legacy formats)".format(
Review Comment:
Use "no valid model files found (checked both IoTDB and legacy
formats)" as the message template and pass it through the `details` parameter.
##########
iotdb-core/ainode/ainode/core/handler.py:
##########
@@ -43,20 +52,141 @@ def __init__(self):
self._model_manager = ModelManager()
self._inference_manager =
InferenceManager(model_manager=self._model_manager)
+ # def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp:
+ # return self._model_manager.register_model(req)
+
+ # def deleteModel(self, req: TDeleteModelReq) -> TSStatus:
+ # return self._model_manager.delete_model(req)
+
+ # def inference(self, req: TInferenceReq) -> TInferenceResp:
+ # return self._inference_manager.inference(req)
+
+ # def forecast(self, req: TForecastReq) -> TSStatus:
+ # return self._inference_manager.forecast(req)
+
+ # def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
+ # return ClusterManager.get_heart_beat(req)
+
+ # def createTrainingTask(self, req: TTrainingReq) -> TSStatus:
+ # pass
+
+ # only for test
def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp:
- return self._model_manager.register_model(req)
+ """
+ register model
+ """
+ logger.info(f"Start register model: {req.modelId}, URI: {req.uri}")
+
+ try:
+ result = self._model_manager.register_model(req)
+ if result.status.code ==
TSStatusCode.SUCCESS_STATUS.get_status_code():
+ logger.info(f"Register model successfully: {req.modelId}")
+ else:
+ logger.warning(
+ f"Failed to register model: {req.modelId}, with status:
{result.status}"
+ )
+ return result
+
+ except (
+ ModelLoadingError,
+ ModelFormatError,
+ IoTDBModelError,
+ ConfigValidationError,
+ WeightFileError,
+ ) as e:
+ logger.error(
+ f"Failed to register model: error known: {req.modelId}, with
error: {e}"
+ )
+ from ainode.core.util.status import get_status
+
+ return TRegisterModelResp(
+ get_status(TSStatusCode.INVALID_URI_ERROR, str(e))
+ )
+ except Exception as e:
+ logger.error(
+ f"Failed to register model: unknown error: {req.modelId}, with
error: {e}"
+ )
+ from ainode.core.util.status import get_status
+
+ return TRegisterModelResp(
+ get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
+ )
def deleteModel(self, req: TDeleteModelReq) -> TSStatus:
- return self._model_manager.delete_model(req)
+ """
+ Reinforce log data in delete models
+ """
+ logger.info(f"Start delete models: {req.modelId}")
+
+ try:
+ result = self._model_manager.delete_model(req)
+ if result.code == TSStatusCode.SUCCESS_STATUS.get_status_code():
+ logger.info(f"Delete models successfully: {req.modelId}")
+ else:
+ logger.warning(
+ f"Failed to delete models: {req.modelId}, with status:
{result}"
+ )
+ return result
+
+ except Exception as e:
+ logger.error(f"Failed to delete models: {req.modelId}, with error:
{e}")
+ from ainode.core.util.status import get_status
+
+ return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
def inference(self, req: TInferenceReq) -> TInferenceResp:
- return self._inference_manager.inference(req)
+ """
+ Perform inference with enhanced logging
+ """
+ logger.debug(f"Starting inference: model ID={req.modelId}")
+
+ try:
+ result = self._inference_manager.inference(req)
+ if result.status.code ==
TSStatusCode.SUCCESS_STATUS.get_status_code():
+ logger.debug(f"Inference succeeded: {req.modelId}")
+ else:
+ logger.warning(
+ f"Inference failed: {req.modelId}, Status: {result.status}"
+ )
+ return result
+
+ except Exception as e:
+ logger.error(f"Inference failed: {req.modelId}, Error: {e}")
+ from ainode.core.util.status import get_status
+
+ return TInferenceResp(
+ get_status(TSStatusCode.INFERENCE_INTERNAL_ERROR, str(e)), []
+ )
def forecast(self, req: TForecastReq) -> TSStatus:
- return self._inference_manager.forecast(req)
+ """
+ Perform forecasting with enhanced logging
+ """
+ logger.debug(f"Starting forecast: model ID={req.modelId}")
+
+ try:
+ result = self._inference_manager.forecast(req)
+ if result.status.code ==
TSStatusCode.SUCCESS_STATUS.get_status_code():
+ logger.debug(f"Forecast succeeded: {req.modelId}")
+ else:
+ logger.warning(
+ f"Forecast failed: {req.modelId}, Status: {result.status}"
+ )
+ return result
+
+ except Exception as e:
+ logger.error(f"Forecast failed: {req.modelId}, Error: {e}")
+ from ainode.core.util.status import get_status
+
+ return get_status(TSStatusCode.INFERENCE_INTERNAL_ERROR, str(e))
def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
return ClusterManager.get_heart_beat(req)
Review Comment:
Why catch exception outside just like other methods?
##########
iotdb-core/ainode/ainode/core/model/model_factory.py:
##########
@@ -41,6 +53,50 @@
logger = Logger()
+def _detect_model_format(base_path: str) -> tuple:
+ """
+ Detect model format: supports both IoTDB and legacy formats
+
+ Args:
+ base_path: Model directory path or network URI
+
+ Returns:
+ (format_type, config_file, weight_file): Format type and corresponding
file names
+ """
+ base_path = (
+ Path(base_path)
+ if not base_path.startswith(("http://", "https://"))
+ else base_path
+ )
+
+ # Check IoTDB format first (higher priority)
+ for config_file in IOTDB_CONFIG_FILES:
+ if isinstance(base_path, Path):
+ config_path = base_path / config_file
+ if config_path.exists():
+ # Look for corresponding weight file
+ for weight_file in WEIGHT_FORMAT_PRIORITY:
+ weight_path = base_path / weight_file
Review Comment:
```suggestion
weight_path = os.path.join(base_path, weight_file)
```
##########
iotdb-core/ainode/ainode/core/util/serde.py:
##########
@@ -155,3 +156,44 @@ def get_data_type_byte_from_str(value):
return TSDataType.DOUBLE.value
elif value == "text":
return TSDataType.TEXT.value
+
+
+def convert_iotdb_data_to_binary(
+ data_frame: pd.DataFrame, model_format: str = "legacy"
+):
+ """
+ Enhanced binary conversion with IoTDB model format support
+ """
+ try:
+ if model_format == "iotdb":
+ # Enhanced processing for IoTDB models
+ logger.debug(f"Converting IoTDB format data:
shape={data_frame.shape}")
+
+ # Use existing conversion logic
+ return convert_to_binary(data_frame)
+
+ except Exception as e:
+ logger.error(f"Binary conversion failed for format {model_format}:
{e}")
+ raise BadConfigValueError("data_conversion", model_format, str(e))
+
+
+def validate_data_types(data_frame: pd.DataFrame) -> bool:
+ """
+ Validate data types in DataFrame for IoTDB compatibility
+ """
+ try:
+ supported_types = ["bool", "int32", "int64", "float32", "float64",
"object"]
Review Comment:
```suggestion
supported_types = ["int32", "int64", "float32", "float64"]
```
##########
iotdb-core/ainode/ainode/core/handler.py:
##########
@@ -43,20 +52,141 @@ def __init__(self):
self._model_manager = ModelManager()
self._inference_manager =
InferenceManager(model_manager=self._model_manager)
+ # def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp:
+ # return self._model_manager.register_model(req)
+
+ # def deleteModel(self, req: TDeleteModelReq) -> TSStatus:
+ # return self._model_manager.delete_model(req)
+
+ # def inference(self, req: TInferenceReq) -> TInferenceResp:
+ # return self._inference_manager.inference(req)
+
+ # def forecast(self, req: TForecastReq) -> TSStatus:
+ # return self._inference_manager.forecast(req)
+
+ # def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
+ # return ClusterManager.get_heart_beat(req)
+
+ # def createTrainingTask(self, req: TTrainingReq) -> TSStatus:
+ # pass
+
+ # only for test
def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp:
- return self._model_manager.register_model(req)
+ """
+ register model
+ """
+ logger.info(f"Start register model: {req.modelId}, URI: {req.uri}")
+
+ try:
+ result = self._model_manager.register_model(req)
+ if result.status.code ==
TSStatusCode.SUCCESS_STATUS.get_status_code():
+ logger.info(f"Register model successfully: {req.modelId}")
+ else:
+ logger.warning(
+ f"Failed to register model: {req.modelId}, with status:
{result.status}"
+ )
+ return result
+
+ except (
+ ModelLoadingError,
+ ModelFormatError,
+ IoTDBModelError,
+ ConfigValidationError,
+ WeightFileError,
+ ) as e:
+ logger.error(
+ f"Failed to register model: error known: {req.modelId}, with
error: {e}"
+ )
+ from ainode.core.util.status import get_status
+
+ return TRegisterModelResp(
+ get_status(TSStatusCode.INVALID_URI_ERROR, str(e))
+ )
+ except Exception as e:
+ logger.error(
+ f"Failed to register model: unknown error: {req.modelId}, with
error: {e}"
+ )
+ from ainode.core.util.status import get_status
+
+ return TRegisterModelResp(
+ get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
+ )
def deleteModel(self, req: TDeleteModelReq) -> TSStatus:
- return self._model_manager.delete_model(req)
+ """
+ Reinforce log data in delete models
+ """
+ logger.info(f"Start delete models: {req.modelId}")
+
+ try:
+ result = self._model_manager.delete_model(req)
+ if result.code == TSStatusCode.SUCCESS_STATUS.get_status_code():
+ logger.info(f"Delete models successfully: {req.modelId}")
+ else:
+ logger.warning(
+ f"Failed to delete models: {req.modelId}, with status:
{result}"
+ )
+ return result
+
+ except Exception as e:
+ logger.error(f"Failed to delete models: {req.modelId}, with error:
{e}")
+ from ainode.core.util.status import get_status
+
+ return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
def inference(self, req: TInferenceReq) -> TInferenceResp:
- return self._inference_manager.inference(req)
+ """
+ Perform inference with enhanced logging
+ """
+ logger.debug(f"Starting inference: model ID={req.modelId}")
Review Comment:
Why use debug? Take care of the log level.
##########
iotdb-core/ainode/ainode/core/model/model_factory.py:
##########
@@ -41,6 +53,50 @@
logger = Logger()
+def _detect_model_format(base_path: str) -> tuple:
+ """
+ Detect model format: supports both IoTDB and legacy formats
+
+ Args:
+ base_path: Model directory path or network URI
+
+ Returns:
+ (format_type, config_file, weight_file): Format type and corresponding
file names
+ """
+ base_path = (
+ Path(base_path)
+ if not base_path.startswith(("http://", "https://"))
+ else base_path
+ )
+
+ # Check IoTDB format first (higher priority)
+ for config_file in IOTDB_CONFIG_FILES:
+ if isinstance(base_path, Path):
+ config_path = base_path / config_file
Review Comment:
```suggestion
config_path = os.path.join(base_path, config_file)
```
##########
iotdb-core/ainode/ainode/core/model/model_factory.py:
##########
@@ -170,6 +354,32 @@ def _register_model_from_local(
return configs, attributes
+def _convert_iotdb_config_to_ainode_format(iotdb_config: dict) -> dict:
+ """
+ IoTDB -> AINode
+ """
+ model_type = iotdb_config.get("model_type", "unknown")
+ input_length = iotdb_config.get("input_token_len", 96)
+ output_length = iotdb_config.get("output_token_lens", [96])[0]
+
+ ainode_config = {
+ "configs": {
+ "input_shape": [input_length, 1],
+ "output_shape": [output_length, 1],
+ "input_type": ["float32"],
+ "output_type": ["float32"],
+ },
+ "attributes": {
+ "model_type": model_type,
+ "iotdb_model": True,
+ "original_config": iotdb_config,
+ },
+ }
+
+ logger.debug(f"转换IoTDB配置: {model_type} -> AINode格式")
Review Comment:
Please use English.
##########
iotdb-core/ainode/ainode/core/model/model_factory.py:
##########
@@ -90,39 +146,167 @@ def _download_file(url: str, storage_path: str) -> None:
logger.debug(f"download file from {url} to {storage_path} success")
+def _download_file_with_fallback(
+ base_uri: str, file_candidates: list, storage_path: str
+) -> str:
+ """
+ Try downloading files in priority order
+
+ Args:
+ base_uri: Base URI
+ file_candidates: List of candidate filenames
+ storage_path: Local storage path
+
+ Returns:
+ The successfully downloaded filename
+ """
+ base_uri = base_uri if base_uri.endswith("/") else base_uri + "/"
+
+ for filename in file_candidates:
+ try:
+ file_url = urljoin(base_uri, filename)
+ _download_file(file_url, storage_path)
+ logger.info(f"Successfully downloaded file: {filename}")
+ return filename
+ except Exception as e:
+ logger.debug(f"Failed to download file {filename}: {e}")
+ continue
+
+ raise InvalidUriError(
+ f"Unable to download any candidate file from {base_uri}:
{file_candidates}"
+ )
+
+
def _register_model_from_network(
uri: str, model_storage_path: str, config_storage_path: str
) -> [TConfigs, str]:
"""
- Args:
- uri: network dir path of model to register, where model.pt and
config.yaml are required,
- e.g. https://huggingface.co/user/modelname/resolve/main/
- model_storage_path: path to save model.pt
- config_storage_path: path to save config.yaml
- Returns:
- configs: TConfigs
- attributes: str
+ Register model from network with full integration of config_parser and
safetensor_loader
"""
- # concat uri to get complete url
uri = uri if uri.endswith("/") else uri + "/"
- target_model_path = urljoin(uri, DEFAULT_MODEL_FILE_NAME)
- target_config_path = urljoin(uri, DEFAULT_CONFIG_FILE_NAME)
- # download config file
- _download_file(target_config_path, config_storage_path)
+ # Try downloading configuration file (IoTDB format preferred)
+ try:
+ config_filename = _download_file_with_fallback(
+ uri, IOTDB_CONFIG_FILES + [DEFAULT_CONFIG_FILE_NAME],
config_storage_path
+ )
+ format_type = "iotdb" if config_filename in IOTDB_CONFIG_FILES else
"legacy"
+ except Exception as e:
+ logger.error(f"Failed to download config file: {e}")
+ raise InvalidUriError(uri)
- # read and parse config dict from config.yaml
- with open(config_storage_path, "r", encoding="utf-8") as file:
- config_dict = yaml.safe_load(file)
- configs, attributes = _parse_inference_config(config_dict)
+ # Parse configuration file using config_parser
+ try:
+ config_dict = parse_config_file(config_storage_path)
+
+ if format_type == "iotdb":
+ # Validate IoTDB configuration
+ if not validate_iotdb_config(config_dict):
+ raise BadConfigValueError(
+ "config_file",
+ config_storage_path,
+ "IoTDB configuration validation failed",
+ )
+
+ # Convert IoTDB config to AINode format
+ ainode_config = convert_iotdb_config_to_ainode_format(config_dict)
+ configs, attributes = _parse_inference_config(ainode_config)
+ else:
+ # Handle legacy format
+ configs, attributes = _parse_inference_config(config_dict)
+ except Exception as e:
+ logger.error(f"Failed to parse config file: {e}")
+ raise BadConfigValueError("config_file", config_storage_path, str(e))
+
+ # Download model weight file
+ try:
+ weight_candidates = (
+ WEIGHT_FORMAT_PRIORITY
+ if format_type == "iotdb"
+ else [DEFAULT_MODEL_FILE_NAME]
+ )
+ weight_filename = _download_file_with_fallback(
+ uri, weight_candidates, model_storage_path
+ )
+
+ # Validate downloaded weight file with safetensor_loader
+ if format_type == "iotdb":
+ try:
+ weights = load_weights_as_state_dict(model_storage_path)
+ logger.info(
+ f"Weight file validated successfully with {len(weights)}
parameters"
+ )
+ except Exception as e:
+ logger.error(f"Failed to validate downloaded weight file: {e}")
+ raise InvalidUriError(f"Corrupted weight file:
{weight_filename}")
+
+ except Exception as e:
+ logger.error(f"Failed to download model file: {e}")
Review Comment:
If an exception occurs, should we clean up the folder?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]