CRZbulabula commented on code in PR #15978:
URL: https://github.com/apache/iotdb/pull/15978#discussion_r2230712936


##########
iotdb-core/ainode/ainode/core/model/model_info.py:
##########
@@ -36,6 +37,9 @@ class BuiltInModelType(Enum):
     TIMER_XL = "Timer-XL"
     # sundial
     SUNDIAL = "Timer-Sundial"
+    # transformer models
+    TIMESFM = "TimesFM"
+    CHRONOS = "Chronos"

Review Comment:
   Remove this code.



##########
iotdb-core/ainode/ainode/core/model/model_storage.py:
##########
@@ -310,24 +313,105 @@ def load_model(
             else:
                 # load the user-defined model
                 model_dir = os.path.join(self._model_dir, f"{model_id}")
-                model_path = os.path.join(model_dir, DEFAULT_MODEL_FILE_NAME)
-
-                if not os.path.exists(model_path):
-                    raise ModelNotExistError(model_path)
-                model = torch.jit.load(model_path)
-                if (
-                    isinstance(model, torch._dynamo.eval_frame.OptimizedModule)
-                    or not acceleration
-                ):
+                if self._is_huggingface_format(model_dir):
+                    return self._load_huggingface_user_model(model_dir, inference_attrs)
+                else:
+                    model_path = os.path.join(model_dir, DEFAULT_MODEL_FILE_NAME)
+
+                    if not os.path.exists(model_path):
+                        raise ModelNotExistError(model_path)
+                    model = torch.jit.load(model_path)
+                    if (
+                        isinstance(model, torch._dynamo.eval_frame.OptimizedModule)
+                        or not acceleration
+                    ):
+                        return model
+
+                    try:
+                        model = torch.compile(model)
+                    except Exception as e:
+                        logger.warning(
+                            f"acceleration failed, fallback to normal mode: {str(e)}"
+                        )
                     return model
-
-                try:
-                    model = torch.compile(model)
-                except Exception as e:
-                    logger.warning(
-                        f"acceleration failed, fallback to normal mode: {str(e)}"
-                    )
-                return model
+                
+    def _is_huggingface_format(self, model_dir: str) -> bool:
+        config_path = os.path.join(model_dir, DEFAULT_CONFIG_FILE_NAME)
+        if os.path.exists(config_path):
+            try:
+                with open(config_path, "r", encoding="utf-8") as f:
+                    config_dict = yaml.safe_load(f)
+                model_type = config_dict.get("attributes", {}).get("model_type", "")
+                return model_type == "huggingface_transformers"

Review Comment:
   It seems none of the models will actually set `model_type == "huggingface_transformers"`.
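   A minimal sketch of an alternative (hypothetical helper name; detecting the format from on-disk artifacts instead of a magic `model_type` string):

   ```python
   import glob
   import os

   def looks_like_huggingface_dir(model_dir: str) -> bool:
       # Hypothetical check: a HuggingFace snapshot ships a config.json
       # plus at least one *.safetensors weight file.
       has_config = os.path.exists(os.path.join(model_dir, "config.json"))
       has_weights = bool(glob.glob(os.path.join(model_dir, "*.safetensors")))
       return has_config and has_weights
   ```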



##########
iotdb-core/ainode/ainode/core/model/model_storage.py:
##########
@@ -310,24 +313,105 @@ def load_model(
             else:
                 # load the user-defined model
                 model_dir = os.path.join(self._model_dir, f"{model_id}")
-                model_path = os.path.join(model_dir, DEFAULT_MODEL_FILE_NAME)
-
-                if not os.path.exists(model_path):
-                    raise ModelNotExistError(model_path)
-                model = torch.jit.load(model_path)
-                if (
-                    isinstance(model, torch._dynamo.eval_frame.OptimizedModule)
-                    or not acceleration
-                ):
+                if self._is_huggingface_format(model_dir):
+                    return self._load_huggingface_user_model(model_dir, inference_attrs)
+                else:
+                    model_path = os.path.join(model_dir, DEFAULT_MODEL_FILE_NAME)
+
+                    if not os.path.exists(model_path):
+                        raise ModelNotExistError(model_path)
+                    model = torch.jit.load(model_path)
+                    if (
+                        isinstance(model, torch._dynamo.eval_frame.OptimizedModule)
+                        or not acceleration
+                    ):
+                        return model
+
+                    try:
+                        model = torch.compile(model)
+                    except Exception as e:
+                        logger.warning(
+                            f"acceleration failed, fallback to normal mode: {str(e)}"
+                        )
                     return model
-
-                try:
-                    model = torch.compile(model)
-                except Exception as e:
-                    logger.warning(
-                        f"acceleration failed, fallback to normal mode: {str(e)}"
-                    )
-                return model
+                
+    def _is_huggingface_format(self, model_dir: str) -> bool:

Review Comment:
   What's the difference between this interface and the one in `model_info.py`? You need only one method for the same function.
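   One way to deduplicate (a sketch; the module-level placement and helper name are assumptions): keep a single predicate in `model_info.py` and import it from `model_storage.py` instead of redefining `_is_huggingface_format`:

   ```python
   # model_info.py -- single source of truth (sketch)
   import os
   import yaml

   def is_huggingface_format(model_dir: str, config_file_name: str) -> bool:
       config_path = os.path.join(model_dir, config_file_name)
       if not os.path.exists(config_path):
           return False
       try:
           with open(config_path, "r", encoding="utf-8") as f:
               config_dict = yaml.safe_load(f) or {}
           return config_dict.get("attributes", {}).get("model_type", "") == "huggingface_transformers"
       except (OSError, yaml.YAMLError):
           return False

   # model_storage.py -- reuse it rather than keeping a private copy
   # from ainode.core.model.model_info import is_huggingface_format
   ```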



##########
iotdb-core/ainode/ainode/core/model/model_info.py:
##########
@@ -54,6 +58,34 @@ def get_built_in_model_type(model_type: str) -> BuiltInModelType:
         raise ValueError(f"Invalid built-in model type: {model_type}")
     return BuiltInModelType(model_type)
 
+def get_model_loading_strategy(model_id_or_uri: str) -> str:

Review Comment:
   You can add a new Enum for classifying the file type as the result type of this interface. BTW, I think `safetensors` and `pt` are enough.
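   A sketch of such an Enum (names are illustrative):

   ```python
   import glob
   import os
   from enum import Enum

   class ModelFileType(Enum):
       SAFETENSORS = "safetensors"  # HuggingFace-style weight files
       PT = "pt"                    # TorchScript / PyTorch checkpoint

   def get_model_file_type(model_dir: str) -> ModelFileType:
       # Hypothetical classifier: *.safetensors wins, everything else falls back to pt.
       if glob.glob(os.path.join(model_dir, "*.safetensors")):
           return ModelFileType.SAFETENSORS
       return ModelFileType.PT
   ```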



##########
iotdb-core/ainode/ainode/core/model/model_factory.py:
##########
@@ -281,11 +283,99 @@ def _parse_inference_config(config_dict):
 
 
def fetch_model_by_uri(uri: str, model_storage_path: str, config_storage_path: str):
-    is_network_path, uri = _parse_uri(uri)
-
-    if is_network_path:
-        return _register_model_from_network(
-            uri, model_storage_path, config_storage_path
-        )
+    """
+    """
+    is_network_path, parsed_uri = _parse_uri(uri)
+    strategy = get_model_loading_strategy(uri)
+    
+    if strategy == "network_huggingface":
+        return _register_huggingface_model_from_network(parsed_uri, model_storage_path, config_storage_path)
+    elif strategy == "local_huggingface":
+        return _register_huggingface_model_from_local(parsed_uri, model_storage_path, config_storage_path)
+    elif strategy == "local_pytorch":
+        return _register_model_from_local(parsed_uri, model_storage_path, config_storage_path)
+    elif is_network_path:
+        return _register_model_from_network(parsed_uri, model_storage_path, config_storage_path)
     else:
-        return _register_model_from_local(uri, model_storage_path, config_storage_path)
+        return _register_model_from_local(parsed_uri, model_storage_path, config_storage_path)
+    
+def _register_huggingface_model_from_network(repo_id: str, model_storage_path: str, config_storage_path: str):
+    import tempfile
+    
+    temp_dir = tempfile.mkdtemp(prefix="hf_model_")
+    
+    try:
+        snapshot_download(
+            repo_id=repo_id,
+            local_dir=temp_dir,
+            local_dir_use_symlinks=False,
+        )
+        
+        return _process_huggingface_files(temp_dir, model_storage_path, config_storage_path)
+        
+    except Exception as e:
+        logger.error(f"Failed to download HuggingFace model {repo_id}: {e}")
+        raise InvalidUriError(repo_id)
+
+def _register_huggingface_model_from_local(local_path: str, model_storage_path: str, config_storage_path: str):
+    return _process_huggingface_files(local_path, model_storage_path, config_storage_path)
+
+def _process_huggingface_files(source_dir: str, model_storage_path: str, config_storage_path: str):
+    import glob
+    import json
+    
+    config_file = None
+    for config_name in ["config.json", "model_config.json"]:
+        config_path = os.path.join(source_dir, config_name)
+        if os.path.exists(config_path):
+            config_file = config_path
+            break
+    
+    if not config_file:
+        raise InvalidUriError(f"No config.json found in {source_dir}")
+    
+    safetensors_files = glob.glob(os.path.join(source_dir, "*.safetensors"))
+    if not safetensors_files:
+        raise InvalidUriError(f"No .safetensors files found in {source_dir}")
+    
+    with open(config_file, "r", encoding="utf-8") as f:
+        hf_config = json.load(f)
+    
+    ainode_config = _convert_hf_config_to_ainode(hf_config, source_dir)
+    
+    with open(config_storage_path, "w", encoding="utf-8") as f:
+        yaml.dump(ainode_config, f)
+    
+    with open(model_storage_path, "w") as f:
+        f.write(f"# HuggingFace model from: {source_dir}\n")
+        f.write(f"# Model files: {[os.path.basename(f) for f in 
safetensors_files]}\n")
+        f.write(f"# Source directory: {source_dir}\n")
+    
+    configs, attributes = _parse_inference_config(ainode_config)
+    return configs, attributes
+
+def _convert_hf_config_to_ainode(hf_config: dict, source_dir: str) -> dict:

Review Comment:
   Currently, we do not need to convert the config
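   i.e., the HF config could be persisted as-is (a sketch; the function name is illustrative, parameter names follow the diff above):

   ```python
   import json
   import yaml

   def persist_hf_config(config_file: str, config_storage_path: str) -> None:
       # Sketch: keep the HuggingFace config unchanged; no _convert_hf_config_to_ainode step.
       with open(config_file, "r", encoding="utf-8") as f:
           hf_config = json.load(f)
       with open(config_storage_path, "w", encoding="utf-8") as f:
           yaml.dump(hf_config, f)
   ```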



##########
iotdb-core/ainode/ainode/core/model/model_factory.py:
##########
@@ -281,11 +283,99 @@ def _parse_inference_config(config_dict):
 
 
def fetch_model_by_uri(uri: str, model_storage_path: str, config_storage_path: str):
-    is_network_path, uri = _parse_uri(uri)
-
-    if is_network_path:
-        return _register_model_from_network(
-            uri, model_storage_path, config_storage_path
-        )
+    """
+    """
+    is_network_path, parsed_uri = _parse_uri(uri)
+    strategy = get_model_loading_strategy(uri)
+    
+    if strategy == "network_huggingface":
+        return _register_huggingface_model_from_network(parsed_uri, model_storage_path, config_storage_path)
+    elif strategy == "local_huggingface":
+        return _register_huggingface_model_from_local(parsed_uri, model_storage_path, config_storage_path)
+    elif strategy == "local_pytorch":
+        return _register_model_from_local(parsed_uri, model_storage_path, config_storage_path)
+    elif is_network_path:
+        return _register_model_from_network(parsed_uri, model_storage_path, config_storage_path)
     else:
-        return _register_model_from_local(uri, model_storage_path, config_storage_path)
+        return _register_model_from_local(parsed_uri, model_storage_path, config_storage_path)
+    
+def _register_huggingface_model_from_network(repo_id: str, model_storage_path: str, config_storage_path: str):
+    import tempfile

Review Comment:
   Import all libs at the top of this file.
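   For example (a sketch of the import block; the exact set depends on which helpers survive review):

   ```python
   # Top of model_factory.py
   import glob
   import json
   import os
   import tempfile

   import yaml
   from huggingface_hub import snapshot_download
   ```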



##########
iotdb-core/ainode/ainode/core/model/model_storage.py:
##########
@@ -310,24 +313,105 @@ def load_model(
             else:
                 # load the user-defined model
                 model_dir = os.path.join(self._model_dir, f"{model_id}")
-                model_path = os.path.join(model_dir, DEFAULT_MODEL_FILE_NAME)
-
-                if not os.path.exists(model_path):
-                    raise ModelNotExistError(model_path)
-                model = torch.jit.load(model_path)
-                if (
-                    isinstance(model, torch._dynamo.eval_frame.OptimizedModule)
-                    or not acceleration
-                ):
+                if self._is_huggingface_format(model_dir):
+                    return self._load_huggingface_user_model(model_dir, inference_attrs)
+                else:
+                    model_path = os.path.join(model_dir, DEFAULT_MODEL_FILE_NAME)
+
+                    if not os.path.exists(model_path):
+                        raise ModelNotExistError(model_path)
+                    model = torch.jit.load(model_path)
+                    if (
+                        isinstance(model, torch._dynamo.eval_frame.OptimizedModule)
+                        or not acceleration
+                    ):
+                        return model
+
+                    try:
+                        model = torch.compile(model)
+                    except Exception as e:
+                        logger.warning(
+                            f"acceleration failed, fallback to normal mode: {str(e)}"
+                        )
                     return model
-
-                try:
-                    model = torch.compile(model)
-                except Exception as e:
-                    logger.warning(
-                        f"acceleration failed, fallback to normal mode: {str(e)}"
-                    )
-                return model
+                
+    def _is_huggingface_format(self, model_dir: str) -> bool:
+        config_path = os.path.join(model_dir, DEFAULT_CONFIG_FILE_NAME)
+        if os.path.exists(config_path):
+            try:
+                with open(config_path, "r", encoding="utf-8") as f:
+                    config_dict = yaml.safe_load(f)
+                model_type = config_dict.get("attributes", {}).get("model_type", "")
+                return model_type == "huggingface_transformers"
+            except:
+                return False
+        return False
+    
+    def _load_huggingface_user_model(self, model_dir: str, inference_attrs: Dict[str, str]) -> Callable:
+        config_path = os.path.join(model_dir, DEFAULT_CONFIG_FILE_NAME)
+        
+        with open(config_path, "r", encoding="utf-8") as f:
+            config_dict = yaml.safe_load(f)
+        
+        source_dir = config_dict.get("attributes", {}).get("source_dir", model_dir)
+        predict_length = int(inference_attrs.get("predict_length", 
+                                                config_dict.get("attributes", {}).get("predict_length", 96)))
+        
+        try:           
+            config = AutoConfig.from_pretrained(source_dir, trust_remote_code=True)
+            model = AutoModel.from_pretrained(source_dir, config=config, trust_remote_code=True)
+            
+            def inference(data):
+                return self._generic_transformers_inference(model, data, predict_length)
+            
+            return inference
+            
+        except Exception as e:
+            logger.error(f"Failed to load HuggingFace model: {e}")
+
+    def _generic_transformers_inference(self, model, data, predict_length: int):
+        try:
+            if isinstance(data, np.ndarray):
+                if len(data.shape) == 1:
+                    input_data = data.tolist()
+                else:
+                    input_data = [row.tolist() for row in data]
+            else:
+                input_data = data if isinstance(data, list) else [data]
+            
+            inference_methods = [
+                ('predict', lambda: model.predict(input_data)),
+                ('generate', lambda: model.generate(inputs=input_data, max_length=predict_length)),
+                ('forward', lambda: model(input_data)),
+                ('__call__', lambda: model(input_data)),
+            ]

Review Comment:
   Currently, we do not need to guarantee that all `USER-DEFINED` transformer models are available for inference. We only support downloading them and loading them into memory. Hence, remove this interface first.


