CRZbulabula commented on code in PR #16907:
URL: https://github.com/apache/iotdb/pull/16907#discussion_r2622626382


##########
iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py:
##########
@@ -29,59 +29,61 @@ def __init__(self, model_info, **model_kwargs):
         self.device = model_kwargs.get("device", "cpu")
         self.model = load_model(model_info, device_map=self.device, **model_kwargs)
 
-    def _preprocess(self, inputs):
+    @abstractmethod
+    def preprocess(self, inputs):
         """
         Preprocess the input before inference, including shape validation and value transformation.
         """
-        return inputs
+        raise NotImplementedError("preprocess not implemented")
 
-    def _postprocess(self, output: torch.Tensor):
+    @abstractmethod
+    def postprocess(self, output: torch.Tensor):
         """
         Post-process the outputs after the entire inference task.
         """
-        return output
+        raise NotImplementedError("postprocess not implemented")
 
 
 class ForecastPipeline(BasicPipeline):
     def __init__(self, model_info, **model_kwargs):
         super().__init__(model_info, model_kwargs=model_kwargs)
 
-    def _preprocess(self, inputs):
+    def preprocess(self, inputs):
         return inputs
 
     @abstractmethod
     def forecast(self, inputs, **infer_kwargs):
         pass
 
-    def _postprocess(self, output: torch.Tensor):
+    def postprocess(self, output: torch.Tensor):

Review Comment:
   Also, `inputs` should be 3-dimensional.



##########
iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py:
##########
@@ -29,59 +29,61 @@ def __init__(self, model_info, **model_kwargs):
         self.device = model_kwargs.get("device", "cpu")
         self.model = load_model(model_info, device_map=self.device, **model_kwargs)
 
-    def _preprocess(self, inputs):
+    @abstractmethod
+    def preprocess(self, inputs):
         """
         Preprocess the input before inference, including shape validation and value transformation.
         """
-        return inputs
+        raise NotImplementedError("preprocess not implemented")
 
-    def _postprocess(self, output: torch.Tensor):
+    @abstractmethod
+    def postprocess(self, output: torch.Tensor):
         """
         Post-process the outputs after the entire inference task.
         """
-        return output
+        raise NotImplementedError("postprocess not implemented")
 
 
 class ForecastPipeline(BasicPipeline):
     def __init__(self, model_info, **model_kwargs):
         super().__init__(model_info, model_kwargs=model_kwargs)
 
-    def _preprocess(self, inputs):
+    def preprocess(self, inputs):
         return inputs
 
     @abstractmethod
     def forecast(self, inputs, **infer_kwargs):
         pass
 
-    def _postprocess(self, output: torch.Tensor):
+    def postprocess(self, output: torch.Tensor):

Review Comment:
   You should add annotations (e.g. in the docstrings) documenting the expected shapes of `inputs` and `output` for the forecasting task.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to