yeandy commented on code in PR #17368:
URL: https://github.com/apache/beam/pull/17368#discussion_r851304959


##########
sdks/python/setup.py:
##########
@@ -169,6 +170,7 @@ def get_version():
     'pytest>=4.4.0,<5.0',
     'pytest-xdist>=1.29.0,<2',
     'pytest-timeout>=1.3.3,<2',
+    'scikit-learn>=0.24.2',

Review Comment:
   How does this interact with the `scikit-learn` entry in
[base_image_requirements_manual.txt](https://github.com/apache/beam/blob/master/sdks/python/container/base_image_requirements_manual.txt#L42)?



##########
sdks/python/apache_beam/ml/inference/sklearn_loader.py:
##########
@@ -0,0 +1,73 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import abc
+import enum
+import pickle
+import sys
+from dataclasses import dataclass
+from typing import Any
+from typing import Iterable
+from typing import List
+
+import joblib
+import numpy
+
+import apache_beam.ml.inference.api as api
+import apache_beam.ml.inference.base as base
+import sklearn_loader
+from apache_beam.io.filesystems import FileSystems
+
+
+class SerializationType(enum.Enum):
+  PICKLE = 1
+  JOBLIB = 2
+
+
+class SKLearnInferenceRunner(base.InferenceRunner):
+  def run_inference(self, batch: List[numpy.array],
+                    model: Any) -> Iterable[numpy.array]:
+    # vectorize data for better performance
+    vectorized_batch = numpy.stack(batch, axis=0)
+    predictions = model.predict(vectorized_batch)
+    return [api.PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+  def get_num_bytes(self, batch: List[numpy.array]) -> int:
+    """Returns the number of bytes of data for a batch."""
+    return sum(sys.getsizeof(element) for element in batch)
+
+
+class SKLearnModelLoader(base.ModelLoader):
+  def __init__(
+      self,
+      serialization: SerializationType = SerializationType.PICKLE,

Review Comment:
   ```suggestion
         serialization_type: SerializationType = SerializationType.PICKLE,
   ```
   Would `serialization_type` be clearer?



##########
sdks/python/apache_beam/ml/inference/sklearn_loader.py:
##########
@@ -0,0 +1,73 @@
+
+import abc
+import enum
+import pickle
+import sys
+from dataclasses import dataclass
+from typing import Any
+from typing import Iterable
+from typing import List
+
+import joblib
+import numpy
+
+import apache_beam.ml.inference.api as api
+import apache_beam.ml.inference.base as base
+import sklearn_loader
+from apache_beam.io.filesystems import FileSystems
+
+
+class SerializationType(enum.Enum):
+  PICKLE = 1
+  JOBLIB = 2
+
+
+class SKLearnInferenceRunner(base.InferenceRunner):
+  def run_inference(self, batch: List[numpy.array],
+                    model: Any) -> Iterable[numpy.array]:
+    # vectorize data for better performance
+    vectorized_batch = numpy.stack(batch, axis=0)
+    predictions = model.predict(vectorized_batch)
+    return [api.PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+  def get_num_bytes(self, batch: List[numpy.array]) -> int:
+    """Returns the number of bytes of data for a batch."""
+    return sum(sys.getsizeof(element) for element in batch)

Review Comment:
   Does `sys.getsizeof(element)` return the size in bytes of all the features
in the `element` numpy array? That is, if each `element` has 4 numeric features
of 4 bytes each, will it return 16? And if the `batch` holds 2 `element`s, will
we return 32?
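A quick way to check, assuming each `element` is a 1-D NumPy array that owns its data: `sys.getsizeof` counts the ndarray object header on top of the data buffer, while `ndarray.nbytes` counts only `size * itemsize`. For a view, `sys.getsizeof` reports only the header. The `get_num_bytes` below is a hypothetical `nbytes`-based variant, not the PR's code.

```python
import sys
from typing import List

import numpy


def get_num_bytes(batch: List[numpy.ndarray]) -> int:
  """Returns the bytes of raw feature data in a batch.

  Uses ndarray.nbytes (element count * itemsize), which excludes the
  Python object header that sys.getsizeof also counts.
  """
  return sum(element.nbytes for element in batch)


# Two examples, each with 4 float32 features (4 bytes per feature).
batch = [numpy.ones(4, dtype=numpy.float32) for _ in range(2)]
print(get_num_bytes(batch))  # 32: 2 elements * 4 features * 4 bytes

# For an array that owns its buffer, sys.getsizeof adds the header,
# so it is larger than nbytes.
print(sys.getsizeof(batch[0]) > batch[0].nbytes)  # True

# A view does not own its buffer, so sys.getsizeof reports only the
# header, not the 8000 bytes of data it points to.
big = numpy.zeros(1000, dtype=numpy.float64)
view = big[:]
print(sys.getsizeof(view) < view.nbytes)  # True
```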



##########
sdks/python/setup.py:
##########
@@ -159,6 +159,7 @@ def get_version():
 
 REQUIRED_TEST_PACKAGES = [
     'freezegun>=0.3.12',
+    'joblib>=1.1.0',

Review Comment:
   Should `joblib` be in `REQUIRED_PACKAGES` instead? Technically, it's
imported by the non-test module `sklearn_loader.py`.
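If `joblib` stays out of `REQUIRED_PACKAGES`, one option is to treat it as an optional dependency and fail with a clear message only when a JOBLIB model is actually loaded. A minimal sketch, assuming the caller passes an already-open binary file; `load_joblib_model` is a hypothetical helper, not part of the PR.

```python
from typing import Any, BinaryIO

try:
  import joblib  # optional dependency
except ImportError:
  joblib = None


def load_joblib_model(file: BinaryIO) -> Any:
  """Loads a joblib-serialized model, failing clearly if joblib is missing."""
  if joblib is None:
    raise ImportError(
        'Loading a JOBLIB-serialized model requires joblib; '
        'install it with `pip install joblib`.')
  return joblib.load(file)
```

With this pattern, importing the module never fails on a missing `joblib`; only the JOBLIB code path does.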



##########
sdks/python/apache_beam/ml/inference/sklearn_loader.py:
##########
@@ -0,0 +1,73 @@
+
+import abc
+import enum
+import pickle
+import sys
+from dataclasses import dataclass
+from typing import Any
+from typing import Iterable
+from typing import List
+
+import joblib
+import numpy
+
+import apache_beam.ml.inference.api as api
+import apache_beam.ml.inference.base as base
+import sklearn_loader
+from apache_beam.io.filesystems import FileSystems
+
+
+class SerializationType(enum.Enum):
+  PICKLE = 1
+  JOBLIB = 2
+
+
+class SKLearnInferenceRunner(base.InferenceRunner):
+  def run_inference(self, batch: List[numpy.array],
+                    model: Any) -> Iterable[numpy.array]:
+    # vectorize data for better performance
+    vectorized_batch = numpy.stack(batch, axis=0)
+    predictions = model.predict(vectorized_batch)
+    return [api.PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+  def get_num_bytes(self, batch: List[numpy.array]) -> int:
+    """Returns the number of bytes of data for a batch."""
+    return sum(sys.getsizeof(element) for element in batch)
+
+
+class SKLearnModelLoader(base.ModelLoader):
+  def __init__(
+      self,
+      serialization: SerializationType = SerializationType.PICKLE,
+      model_uri: str = ''):
+    self._serialization = serialization
+    self._model_uri = model_uri

Review Comment:
   Should we keep names as consistent as possible across frameworks? For
example, for PyTorch I use `state_dict_path`, but I could change it to
`state_dict_uri`.



##########
sdks/python/apache_beam/ml/inference/sklearn_loader.py:
##########
@@ -0,0 +1,73 @@
+
+import abc
+import enum
+import pickle
+import sys
+from dataclasses import dataclass
+from typing import Any
+from typing import Iterable
+from typing import List
+
+import joblib
+import numpy
+
+import apache_beam.ml.inference.api as api
+import apache_beam.ml.inference.base as base
+import sklearn_loader
+from apache_beam.io.filesystems import FileSystems
+
+
+class SerializationType(enum.Enum):
+  PICKLE = 1
+  JOBLIB = 2
+
+
+class SKLearnInferenceRunner(base.InferenceRunner):

Review Comment:
   Silly question, but for the sake of consistency and ease of use, what should
the naming convention be for the different frameworks? For example, technically
`PyTorch` has capital P and T, but in my implementation, I use `Pytorch` for
simplicity. (I can change it, though.)
   
   And `scikit-learn` is often abbreviated as `sklearn` or `Sklearn`, without
the capital K or L. Should we change it to `SklearnInferenceRunner`?
   



##########
sdks/python/apache_beam/ml/inference/sklearn_loader.py:
##########
@@ -0,0 +1,73 @@
+
+import abc
+import enum
+import pickle
+import sys
+from dataclasses import dataclass
+from typing import Any
+from typing import Iterable
+from typing import List
+
+import joblib
+import numpy
+
+import apache_beam.ml.inference.api as api
+import apache_beam.ml.inference.base as base
+import sklearn_loader
+from apache_beam.io.filesystems import FileSystems
+
+
+class SerializationType(enum.Enum):
+  PICKLE = 1
+  JOBLIB = 2
+
+
+class SKLearnInferenceRunner(base.InferenceRunner):
+  def run_inference(self, batch: List[numpy.array],
+                    model: Any) -> Iterable[numpy.array]:
+    # vectorize data for better performance
+    vectorized_batch = numpy.stack(batch, axis=0)
+    predictions = model.predict(vectorized_batch)
+    return [api.PredictionResult(x, y) for x, y in zip(batch, predictions)]

Review Comment:
   Nice!
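For reference, a self-contained sketch of the same vectorization pattern with a toy model: the N 1-D examples are stacked into one (N, F) matrix so the model makes a single `predict()` call. `SumModel` and the `PredictionResult` NamedTuple are hypothetical stand-ins for an sklearn estimator and `api.PredictionResult`; the return annotation here is `List[PredictionResult]`, which matches what the function actually returns.

```python
from typing import Any, List, NamedTuple

import numpy


class PredictionResult(NamedTuple):
  """Hypothetical stand-in for api.PredictionResult."""
  example: numpy.ndarray
  inference: Any


class SumModel:
  """Toy model with an sklearn-style predict(X) over a 2-D array."""
  def predict(self, x: numpy.ndarray) -> numpy.ndarray:
    return x.sum(axis=1)


def run_inference(
    batch: List[numpy.ndarray], model: Any) -> List[PredictionResult]:
  # Stack N examples of shape (F,) into one (N, F) matrix so the model
  # runs a single predict() call for the whole batch.
  vectorized_batch = numpy.stack(batch, axis=0)
  predictions = model.predict(vectorized_batch)
  # Pair each original example with its prediction, in order.
  return [PredictionResult(x, y) for x, y in zip(batch, predictions)]


batch = [numpy.array([1.0, 2.0]), numpy.array([3.0, 4.0])]
results = run_inference(batch, SumModel())
print([float(r.inference) for r in results])  # [3.0, 7.0]
```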



##########
sdks/python/apache_beam/ml/inference/sklearn_loader.py:
##########
@@ -0,0 +1,73 @@
+
+import abc
+import enum
+import pickle
+import sys
+from dataclasses import dataclass
+from typing import Any
+from typing import Iterable
+from typing import List
+
+import joblib
+import numpy
+
+import apache_beam.ml.inference.api as api
+import apache_beam.ml.inference.base as base
+import sklearn_loader
+from apache_beam.io.filesystems import FileSystems
+
+
+class SerializationType(enum.Enum):
+  PICKLE = 1
+  JOBLIB = 2
+
+
+class SKLearnInferenceRunner(base.InferenceRunner):
+  def run_inference(self, batch: List[numpy.array],
+                    model: Any) -> Iterable[numpy.array]:
+    # vectorize data for better performance
+    vectorized_batch = numpy.stack(batch, axis=0)
+    predictions = model.predict(vectorized_batch)
+    return [api.PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+  def get_num_bytes(self, batch: List[numpy.array]) -> int:
+    """Returns the number of bytes of data for a batch."""
+    return sum(sys.getsizeof(element) for element in batch)
+
+
+class SKLearnModelLoader(base.ModelLoader):
+  def __init__(
+      self,
+      serialization: SerializationType = SerializationType.PICKLE,
+      model_uri: str = ''):
+    self._serialization = serialization
+    self._model_uri = model_uri
+    self._inference_runner = SKLearnInferenceRunner()
+
+  def load_model(self):
+    """Loads and initializes a model for processing."""
+    file = FileSystems.open(self._model_uri, 'rb')
+    if self._serialization == SerializationType.PICKLE:
+      return pickle.load(file)
+    elif self._serialization == SerializationType.JOBLIB:
+      return joblib.load(file)
+    raise ValueError('No supported serialization type.')

Review Comment:
   Assuming the user only picks from the `SerializationType` enum values (and
type checking is enforced), we should never hit this case, right?
   
   Could we include the value of `self._serialization` in the error message,
and also add a test for this path?
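Python does not enforce type annotations at runtime, so a caller can still pass a value outside the enum (for example, a string), and the fallback is reachable. A minimal sketch of a more descriptive error plus a matching test; `load_model` here is a hypothetical standalone function that takes an already-open binary file, not the PR's `SKLearnModelLoader.load_model`.

```python
import enum
import io
import pickle
import unittest
from typing import Any, BinaryIO


class SerializationType(enum.Enum):
  PICKLE = 1
  JOBLIB = 2


def load_model(file: BinaryIO, serialization_type: Any) -> Any:
  """Hypothetical standalone version of SKLearnModelLoader.load_model."""
  if serialization_type == SerializationType.PICKLE:
    return pickle.load(file)
  if serialization_type == SerializationType.JOBLIB:
    import joblib  # only needed for JOBLIB models
    return joblib.load(file)
  # Include the offending value so misconfiguration is easy to diagnose.
  raise ValueError(
      'Unsupported serialization type: %r. Expected one of %s.' %
      (serialization_type, [t.name for t in SerializationType]))


class LoadModelTest(unittest.TestCase):
  def test_unsupported_serialization_type_raises(self):
    with self.assertRaisesRegex(ValueError, 'Unsupported serialization type'):
      load_model(io.BytesIO(b''), 'NOT_A_TYPE')
```

Passing `'NOT_A_TYPE'` produces: `Unsupported serialization type: 'NOT_A_TYPE'. Expected one of ['PICKLE', 'JOBLIB'].`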
   
   



-- 
This is an automated message from the Apache Git Service.