[
https://issues.apache.org/jira/browse/BEAM-13983?focusedWorklogId=761859&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-761859
]
ASF GitHub Bot logged work on BEAM-13983:
-----------------------------------------
Author: ASF GitHub Bot
Created on: 25/Apr/22 15:40
Start Date: 25/Apr/22 15:40
Worklog Time Spent: 10m
Work Description: ryanthompson591 commented on code in PR #17368:
URL: https://github.com/apache/beam/pull/17368#discussion_r857769301
##########
sdks/python/apache_beam/ml/inference/sklearn_loader.py:
##########
@@ -0,0 +1,78 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import enum
+import pickle
+import sys
+from typing import Any
+from typing import Iterable
+from typing import List
+
+import numpy
+
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.api import PredictionResult
+from apache_beam.ml.inference.base import InferenceRunner
+from apache_beam.ml.inference.base import ModelLoader
+
+try:
+  import joblib
+except ImportError:
+  # joblib is an optional dependency.
+  pass
+
+
+class ModelFileType(enum.Enum):
+  PICKLE = 1
+  JOBLIB = 2
+
+
+class SklearnInferenceRunner(InferenceRunner):
+  def run_inference(self, batch: List[numpy.array],
+                    model: Any) -> Iterable[numpy.array]:
+    # vectorize data for better performance
+    vectorized_batch = numpy.stack(batch, axis=0)
+    predictions = model.predict(vectorized_batch)
+    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+  def get_num_bytes(self, batch: List[numpy.array]) -> int:
+    """Returns the number of bytes of data for a batch."""
+    return sum(sys.getsizeof(element) for element in batch)
+
+
+class SklearnModelLoader(ModelLoader):
+  def __init__(
+      self,
+      model_file_type: ModelFileType = ModelFileType.PICKLE,
+      model_uri: str = ''):
+    self._model_file_type = model_file_type
+    self._model_uri = model_uri
+    self._inference_runner = SklearnInferenceRunner()
+
+  def load_model(self):
+    """Loads and initializes a model for processing."""
+    file = FileSystems.open(self._model_uri, 'rb')
+    if self._model_file_type == ModelFileType.PICKLE:
+      return pickle.load(file)
+    elif self._model_file_type == ModelFileType.JOBLIB:
+      if not joblib:
+        raise ImportError('Joblib not available in SklearnModelLoader.')
Review Comment:
Are you saying that we should not crash here? In my opinion, if this
requirement is missing, the pipeline should fail, since the whole transform
simply will not work. This is consistent with how other transforms behave
when a dependency is missing.
I made this error message a little more consistent and pointed it to the docs.
Other import errors:
https://github.com/apache/beam/blob/3f2e3c7c9eccb9d40370cbc70e9a451a4b5573f5/sdks/python/apache_beam/ml/gcp/visionml.py#L41
https://github.com/apache/beam/blob/3f2e3c7c9eccb9d40370cbc70e9a451a4b5573f5/sdks/python/apache_beam/runners/dataflow/dataflow_metrics.py#L298
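For illustration, a minimal sketch of the fail-fast pattern described above. The helpers `optional_import` and `require_joblib` are hypothetical names, not part of this PR or of Beam, and the error text is a placeholder rather than the message used in the PR. One detail this sketch assumes matters: if the `except ImportError` branch only does `pass`, the name `joblib` stays unbound, so a later `if not joblib:` would raise `NameError` instead of the intended `ImportError`. Binding the name to `None` avoids that.

```python
import importlib


def optional_import(name):
  """Returns the imported module, or None if it is not installed."""
  try:
    return importlib.import_module(name)
  except ImportError:
    return None


# Bind the name to None when the import fails, so that a later check
# cannot raise NameError on an unbound name.
joblib = optional_import('joblib')


def require_joblib(module=joblib):
  """Returns the module, or raises ImportError if it is absent.

  Hypothetical helper: the message text is a placeholder.
  """
  if module is None:
    raise ImportError(
        'Could not import joblib in this execution environment. '
        'Install joblib on the workers to load JOBLIB model files.')
  return module
```

A loader could call `require_joblib()` at the point where a JOBLIB model file is requested, so the pipeline fails with a clear message instead of continuing with a transform that cannot work.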
Issue Time Tracking
-------------------
Worklog Id: (was: 761859)
Time Spent: 4h 50m (was: 4h 40m)
> Implement RunInference for Scikit-learn
> ---------------------------------------
>
> Key: BEAM-13983
> URL: https://issues.apache.org/jira/browse/BEAM-13983
> Project: Beam
> Issue Type: Sub-task
> Components: sdk-py-core
> Reporter: Andy Ye
> Assignee: Ryan Thompson
> Priority: P2
> Labels: run-inference
> Time Spent: 4h 50m
> Remaining Estimate: 0h
>
> Implement RunInference for Scikit-learn as described in the design doc
> [https://s.apache.org/inference-sklearn-pytorch]
> There will be a sklearn_impl.py file that contains SklearnModelLoader and
> SklearnInferenceRunner classes.