ryanthompson591 commented on a change in pull request #16917:
URL: https://github.com/apache/beam/pull/16917#discussion_r822868000



##########
File path: sdks/python/apache_beam/ml/inference/api.py
##########
@@ -15,38 +15,29 @@
 # limitations under the License.
 #
 
+import abc
 from dataclasses import dataclass
-from enum import Enum
 from typing import Tuple
 from typing import TypeVar
 from typing import Union
 
 import apache_beam as beam
 
 
-class PyTorchDevice(Enum):
-  CPU = 1
-  GPU = 2
-
-
-class SklearnSerializationType(Enum):
-  PICKLE = 1
-  JOBLIB = 2
-
-
-@dataclass
 class BaseModelSpec:
-  model_uri: str
-
-
-@dataclass
-class PyTorchModelSpec(BaseModelSpec):
-  device: PyTorchDevice
-
+  """
+  Model factory that returns ModelLoader and
+  InferenceRunner objects to be used
+  """
+  @abc.abstractmethod
+  def get_model_loader(self):
+    "Returns ModelLoader object"

Review comment:
      I think this is good for now.
   
   Just a note, but I don't think you need to do anything: going forward, it
seems to me that there will be an implementation that uses a service (instead
of loading a model) and that may return None for get_model_loader. I haven't
decided whether this should return None by default or whether raising
NotImplementedError is fine.
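A minimal sketch of the two defaults mentioned above, assuming a service-backed spec simply inherits `get_model_loader` instead of overriding it. Every name here (`ModelLoader`, `LocalModelLoader`, `SpecWithNoneDefault`, `SpecWithRaisingDefault`, `uses_local_model`) is hypothetical and not part of this PR. Note also that `@abc.abstractmethod` is only enforced when the class uses `ABCMeta` (for example by inheriting from `abc.ABC`).

```python
import abc
from typing import Optional


class ModelLoader:
  """Hypothetical stand-in for the PR's ModelLoader type."""


class LocalModelLoader(ModelLoader):
  """Hypothetical loader for a model loaded in-process."""


# Option A: get_model_loader returns None unless a subclass overrides it.
class SpecWithNoneDefault(abc.ABC):
  def get_model_loader(self) -> Optional[ModelLoader]:
    # Service-backed specs inherit this and load nothing locally.
    return None


# Option B: get_model_loader raises unless a subclass overrides it.
class SpecWithRaisingDefault(abc.ABC):
  def get_model_loader(self) -> ModelLoader:
    raise NotImplementedError(
        f'{type(self).__name__} does not load a model locally')


class LocalSpecA(SpecWithNoneDefault):
  """Hypothetical spec that loads its model in-process."""
  def get_model_loader(self) -> ModelLoader:
    return LocalModelLoader()


class ServiceSpecA(SpecWithNoneDefault):
  """Hypothetical spec that calls a remote service (option A)."""


class ServiceSpecB(SpecWithRaisingDefault):
  """Hypothetical spec that calls a remote service (option B)."""


def uses_local_model(spec) -> bool:
  """Hypothetical caller-side check that handles both options."""
  try:
    return spec.get_model_loader() is not None
  except NotImplementedError:
    return False
```

Returning None keeps service-backed specs trivial but pushes a None check onto every caller; raising NotImplementedError makes an accidental call fail loudly at the cost of a try/except wherever the loader is optional.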




-- 
This is an automated message from the Apache Git Service.