CRZbulabula commented on code in PR #16794:
URL: https://github.com/apache/iotdb/pull/16794#discussion_r2559661924


##########
iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/sundial_pipeline.py:
##########
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import pandas as pd
+import torch
+
+from iotdb.ainode.core.inference.pipeline.basic_pipeline import 
ForecastPipeline
+from iotdb.ainode.core.util.serde import convert_to_binary
+
+
+class SundialPipeline(ForecastPipeline):
+    def __init__(self, model_id, **infer_kwargs):
+        super().__init__(model_id, infer_kwargs=infer_kwargs)
+
+    def _preprocess(self, inputs):
+        return super()._preprocess(inputs)
+
+    def infer(self, inputs, **infer_kwargs):
+        predict_length = infer_kwargs.get("predict_length", 96)
+        num_samples = infer_kwargs.get("num_samples", 10)
+        revin = infer_kwargs.get("revin", True)
+
+        input_ids = self._preprocess(inputs)
+        output = self.model.generate(
+            input_ids,
+            max_new_tokens=predict_length,
+            num_samples=num_samples,
+            revin=revin,
+        )
+        return self._postprocess(output)
+
+    def _postprocess(self, output: torch.Tensor):
+        return output.mean(dim=1)

Review Comment:
   Maybe we should discuss this implementation tomorrow.



##########
iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py:
##########
@@ -47,24 +47,10 @@
 logger = Logger()
 
 
-def _ensure_device_id_is_available(device_id_list: list[str]) -> TSStatus:
-    """
-    Ensure that the device IDs in the provided list are available.
-    """
-    available_devices = get_available_devices()
-    for device_id in device_id_list:
-        if device_id not in available_devices:
-            return TSStatus(
-                code=TSStatusCode.INVALID_URI_ERROR.value,
-                message=f"Device ID [{device_id}] is not available. You can 
use 'SHOW AI_DEVICES' to retrieve the available devices.",
-            )
-    return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value)

Review Comment:
   Why did you remove my common helper function and repeat its code inside each interface? >_<



##########
iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py:
##########
@@ -255,30 +247,20 @@ def _expand_pools_on_device(self, model_id: str, 
device_id: str, count: int):
         """
 
         def _expand_pool_on_device(*_):
-            result_queue = mp.Queue()
+            request_queue = mp.Queue()
             pool_id = self._new_pool_id.get_and_increment()
-            model_info = self._model_manager.get_model_info(model_id)
-            model_type = model_info.model_type
-            if model_type == BuiltInModelType.SUNDIAL.value:
-                config = SundialConfig()
-            elif model_type == BuiltInModelType.TIMER_XL.value:
-                config = TimerConfig()
-            else:
-                raise InferenceModelInternalError(
-                    f"Unsupported model type {model_type} for loading model 
{model_id}"
-                )
+            model_info = MODEL_MANAGER.get_model_info(model_id)
             pool = InferenceRequestPool(
                 pool_id=pool_id,
                 model_info=model_info,
                 device=device_id,
-                config=config,
-                request_queue=result_queue,
+                request_queue=request_queue,
                 result_queue=self._result_queue,
                 ready_event=mp.Event(),
             )
             pool.start()
-            self._register_pool(model_id, device_id, pool_id, pool, 
result_queue)
-            if not pool.ready_event.wait(timeout=300):
+            self._register_pool(model_id, device_id, pool_id, pool, 
request_queue)
+            if not pool.ready_event.wait(timeout=30):

Review Comment:
   This should be set to 300, since pool expansion is very slow in the CI environment.



##########
iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py:
##########
@@ -40,130 +38,84 @@
 logger = Logger()
 
 
-@singleton
 class ModelManager:
     def __init__(self):
-        self.model_storage = ModelStorage()
-
-    def register_model(self, req: TRegisterModelReq) -> TRegisterModelResp:
-        logger.info(f"register model {req.modelId} from {req.uri}")
+        self.models_dir = os.path.join(
+            os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir()
+        )
+        self.storage = ModelStorage(models_dir=self.models_dir)
+        self.loader = ModelLoader(storage=self.storage)
+
+        # Automatically discover all models
+        self._models = self.storage.discover_all()

Review Comment:
   There is no need to maintain the models here, because ModelStorage already holds all of them.



##########
iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py:
##########
@@ -123,13 +121,14 @@ def showAIDevices(self) -> TShowAIDevicesResp:
             deviceIdList=get_available_devices(),
         )
 
+    def inference(self, req: TInferenceReq) -> TInferenceResp:
+        return self._inference_manager.inference(req)
+
+    def forecast(self, req: TForecastReq) -> TSStatus:
+        return self._inference_manager.forecast(req)
+
+    def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
+        return ClusterManager.get_heart_beat(req)
+
     def createTrainingTask(self, req: TTrainingReq) -> TSStatus:
         pass
-
-    def _ensure_model_is_built_in_or_fine_tuned(self, model_id: str) -> 
TSStatus:
-        if not self._model_manager.is_built_in_or_fine_tuned(model_id):
-            return TSStatus(
-                code=TSStatusCode.MODEL_NOT_FOUND_ERROR.value,
-                message=f"Model [{model_id}] is not a built-in or fine-tuned 
model. You can use 'SHOW MODELS' to retrieve the available models.",
-            )
-        return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value)

Review Comment:
   Shouldn't we check whether the specified model is available up front, in one uniform place?



##########
iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py:
##########
@@ -86,8 +82,8 @@ def __init__(
         self._batcher = BasicBatcher()
         self._stop_event = mp.Event()
 
-        self._model = None
-        self._model_manager = None
+        # self._inference_pipeline = get_pipeline(self.model_info.model_id, 
self.device)

Review Comment:
   Remove this.



##########
iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py:
##########
@@ -193,11 +156,8 @@ def run(self):
         self._logger = Logger(
             INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE.format(self.device)
         )
-        self._model_manager = ModelManager()
         self._request_scheduler.device = self.device
-        self._model = self._model_manager.load_model(self.model_info.model_id, 
{}).to(
-            self.device
-        )
+        self._inference_pipeline = get_pipeline(self.model_info.model_id, 
self.device)

Review Comment:
   Remember to fix the warning thrown by PyCharm.



##########
iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py:
##########
@@ -0,0 +1,111 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from abc import ABC
+
+import torch
+
+from iotdb.ainode.core.exception import InferenceModelInternalError
+from iotdb.ainode.core.manager.model_manager import get_model_manager
+
+
+class BasicPipeline(ABC):
+    def __init__(self, model_id, **infer_kwargs):
+        self.model_id = model_id
+        self.device = infer_kwargs.get("device", "cpu")
+        # self.model = get_model_manager().load_model(model_id).to(self.device)

Review Comment:
   Fix this, or it won't pass our CI, haha.



##########
iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py:
##########
@@ -40,130 +38,84 @@
 logger = Logger()
 
 
-@singleton

Review Comment:
   Why did you remove the original singleton declaration here?



##########
iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py:
##########
@@ -40,16 +40,14 @@
     ScaleActionType,
 )
 from iotdb.ainode.core.log import Logger
-from iotdb.ainode.core.manager.model_manager import ModelManager
-from iotdb.ainode.core.model.model_enums import BuiltInModelType
-from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig
-from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig
+from iotdb.ainode.core.manager.model_manager import get_model_manager
 from iotdb.ainode.core.util.atmoic_int import AtomicInt
 from iotdb.ainode.core.util.batch_executor import BatchExecutor
 from iotdb.ainode.core.util.decorator import synchronized
 from iotdb.ainode.core.util.thread_name import ThreadName
 
 logger = Logger()
+MODEL_MANAGER = get_model_manager()

Review Comment:
   I just removed the global singleton in this file... >_< Please rebase first.



##########
iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py:
##########
@@ -0,0 +1,29 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from iotdb.ainode.core.inference.pipeline.sundial_pipeline import 
SundialPipeline
+from iotdb.ainode.core.inference.pipeline.timerxl_pipeline import 
TimerxlPipeline
+
+
+def get_pipeline(model_id, device):
+    if model_id == "timerxl":
+        return TimerxlPipeline(model_id, device=device)
+    elif model_id == "sundial":
+        return SundialPipeline(model_id, device=device)

Review Comment:
   Under the current implementation, do we have to add another if-else branch here every time we integrate a new model? Isn't that far too much?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to