yeandy commented on code in PR #17368:
URL: https://github.com/apache/beam/pull/17368#discussion_r851304959
##########
sdks/python/setup.py:
##########
@@ -169,6 +170,7 @@ def get_version():
'pytest>=4.4.0,<5.0',
'pytest-xdist>=1.29.0,<2',
'pytest-timeout>=1.3.3,<2',
+ 'scikit-learn>=0.24.2',
Review Comment:
How does this affect `scikit-learn` in the
[base_image_requirements_manual.txt](https://github.com/apache/beam/blob/master/sdks/python/container/base_image_requirements_manual.txt#L42)?
##########
sdks/python/apache_beam/ml/inference/sklearn_loader.py:
##########
@@ -0,0 +1,73 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import abc
+import enum
+import pickle
+import sys
+from dataclasses import dataclass
+from typing import Any
+from typing import Iterable
+from typing import List
+
+import joblib
+import numpy
+
+import apache_beam.ml.inference.api as api
+import apache_beam.ml.inference.base as base
+import sklearn_loader
+from apache_beam.io.filesystems import FileSystems
+
+
+class SerializationType(enum.Enum):
+ PICKLE = 1
+ JOBLIB = 2
+
+
+class SKLearnInferenceRunner(base.InferenceRunner):
+ def run_inference(self, batch: List[numpy.array],
+ model: Any) -> Iterable[numpy.array]:
+ # vectorize data for better performance
+ vectorized_batch = numpy.stack(batch, axis=0)
+ predictions = model.predict(vectorized_batch)
+ return [api.PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+ def get_num_bytes(self, batch: List[numpy.array]) -> int:
+ """Returns the number of bytes of data for a batch."""
+ return sum(sys.getsizeof(element) for element in batch)
+
+
+class SKLearnModelLoader(base.ModelLoader):
+ def __init__(
+ self,
+ serialization: SerializationType = SerializationType.PICKLE,
Review Comment:
```suggestion
serialization_type: SerializationType = SerializationType.PICKLE,
```
Would it be clearer to name this parameter `serialization_type`?
##########
sdks/python/apache_beam/ml/inference/sklearn_loader.py:
##########
@@ -0,0 +1,73 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import abc
+import enum
+import pickle
+import sys
+from dataclasses import dataclass
+from typing import Any
+from typing import Iterable
+from typing import List
+
+import joblib
+import numpy
+
+import apache_beam.ml.inference.api as api
+import apache_beam.ml.inference.base as base
+import sklearn_loader
+from apache_beam.io.filesystems import FileSystems
+
+
+class SerializationType(enum.Enum):
+ PICKLE = 1
+ JOBLIB = 2
+
+
+class SKLearnInferenceRunner(base.InferenceRunner):
+ def run_inference(self, batch: List[numpy.array],
+ model: Any) -> Iterable[numpy.array]:
+ # vectorize data for better performance
+ vectorized_batch = numpy.stack(batch, axis=0)
+ predictions = model.predict(vectorized_batch)
+ return [api.PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+ def get_num_bytes(self, batch: List[numpy.array]) -> int:
+ """Returns the number of bytes of data for a batch."""
+ return sum(sys.getsizeof(element) for element in batch)
Review Comment:
Does `sys.getsizeof(element)` return the size in bytes of all features in
the `element` numpy array? For example, if each `element` has 4 numeric
features of 4 bytes each, will it return 16? And if we have 2 such elements
in the `batch`, will we then return 32?
##########
sdks/python/setup.py:
##########
@@ -159,6 +159,7 @@ def get_version():
REQUIRED_TEST_PACKAGES = [
'freezegun>=0.3.12',
+ 'joblib>=1.1.0',
Review Comment:
Should `joblib` be in `REQUIRED_PACKAGES`? Technically, it's being used in
the regular `sklearn_loader.py` file.
##########
sdks/python/apache_beam/ml/inference/sklearn_loader.py:
##########
@@ -0,0 +1,73 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import abc
+import enum
+import pickle
+import sys
+from dataclasses import dataclass
+from typing import Any
+from typing import Iterable
+from typing import List
+
+import joblib
+import numpy
+
+import apache_beam.ml.inference.api as api
+import apache_beam.ml.inference.base as base
+import sklearn_loader
+from apache_beam.io.filesystems import FileSystems
+
+
+class SerializationType(enum.Enum):
+ PICKLE = 1
+ JOBLIB = 2
+
+
+class SKLearnInferenceRunner(base.InferenceRunner):
+ def run_inference(self, batch: List[numpy.array],
+ model: Any) -> Iterable[numpy.array]:
+ # vectorize data for better performance
+ vectorized_batch = numpy.stack(batch, axis=0)
+ predictions = model.predict(vectorized_batch)
+ return [api.PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+ def get_num_bytes(self, batch: List[numpy.array]) -> int:
+ """Returns the number of bytes of data for a batch."""
+ return sum(sys.getsizeof(element) for element in batch)
+
+
+class SKLearnModelLoader(base.ModelLoader):
+ def __init__(
+ self,
+ serialization: SerializationType = SerializationType.PICKLE,
+ model_uri: str = ''):
+ self._serialization = serialization
+ self._model_uri = model_uri
Review Comment:
Should we try to be as consistent as possible across frameworks? For
example, for Pytorch, I use `state_dict_path`, but I could change it to be
`state_dict_uri`.
##########
sdks/python/apache_beam/ml/inference/sklearn_loader.py:
##########
@@ -0,0 +1,73 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import abc
+import enum
+import pickle
+import sys
+from dataclasses import dataclass
+from typing import Any
+from typing import Iterable
+from typing import List
+
+import joblib
+import numpy
+
+import apache_beam.ml.inference.api as api
+import apache_beam.ml.inference.base as base
+import sklearn_loader
+from apache_beam.io.filesystems import FileSystems
+
+
+class SerializationType(enum.Enum):
+ PICKLE = 1
+ JOBLIB = 2
+
+
+class SKLearnInferenceRunner(base.InferenceRunner):
Review Comment:
Silly question, but for the sake of consistency and ease of use, what should
the naming convention be for the different frameworks? For example, technically
`PyTorch` has capital P and T, but in my implementation, I use `Pytorch` for
simplicity. (I can change it though)
And for `Scikit-learn`, it's often abbreviated as `sklearn` or `Sklearn`,
without the capital K or L. Should we change it to `SklearnInferenceRunner`?
##########
sdks/python/apache_beam/ml/inference/sklearn_loader.py:
##########
@@ -0,0 +1,73 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import abc
+import enum
+import pickle
+import sys
+from dataclasses import dataclass
+from typing import Any
+from typing import Iterable
+from typing import List
+
+import joblib
+import numpy
+
+import apache_beam.ml.inference.api as api
+import apache_beam.ml.inference.base as base
+import sklearn_loader
+from apache_beam.io.filesystems import FileSystems
+
+
+class SerializationType(enum.Enum):
+ PICKLE = 1
+ JOBLIB = 2
+
+
+class SKLearnInferenceRunner(base.InferenceRunner):
+ def run_inference(self, batch: List[numpy.array],
+ model: Any) -> Iterable[numpy.array]:
+ # vectorize data for better performance
+ vectorized_batch = numpy.stack(batch, axis=0)
+ predictions = model.predict(vectorized_batch)
+ return [api.PredictionResult(x, y) for x, y in zip(batch, predictions)]
Review Comment:
Nice!
##########
sdks/python/apache_beam/ml/inference/sklearn_loader.py:
##########
@@ -0,0 +1,73 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import abc
+import enum
+import pickle
+import sys
+from dataclasses import dataclass
+from typing import Any
+from typing import Iterable
+from typing import List
+
+import joblib
+import numpy
+
+import apache_beam.ml.inference.api as api
+import apache_beam.ml.inference.base as base
+import sklearn_loader
+from apache_beam.io.filesystems import FileSystems
+
+
+class SerializationType(enum.Enum):
+ PICKLE = 1
+ JOBLIB = 2
+
+
+class SKLearnInferenceRunner(base.InferenceRunner):
+ def run_inference(self, batch: List[numpy.array],
+ model: Any) -> Iterable[numpy.array]:
+ # vectorize data for better performance
+ vectorized_batch = numpy.stack(batch, axis=0)
+ predictions = model.predict(vectorized_batch)
+ return [api.PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+ def get_num_bytes(self, batch: List[numpy.array]) -> int:
+ """Returns the number of bytes of data for a batch."""
+ return sum(sys.getsizeof(element) for element in batch)
+
+
+class SKLearnModelLoader(base.ModelLoader):
+ def __init__(
+ self,
+ serialization: SerializationType = SerializationType.PICKLE,
+ model_uri: str = ''):
+ self._serialization = serialization
+ self._model_uri = model_uri
+ self._inference_runner = SKLearnInferenceRunner()
+
+ def load_model(self):
+ """Loads and initializes a model for processing."""
+ file = FileSystems.open(self._model_uri, 'rb')
+ if self._serialization == SerializationType.PICKLE:
+ return pickle.load(file)
+ elif self._serialization == SerializationType.JOBLIB:
+ return joblib.load(file)
+ raise ValueError('No supported serialization type.')
Review Comment:
Assuming the user only picks from the `SerializationType` enum members (and
uses type checking), we will never hit this case, right?
Can we include the value of `self._serialization` in the error message, and
also add a test for this?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]