[
https://issues.apache.org/jira/browse/BEAM-13982?focusedWorklogId=752918&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-752918
]
ASF GitHub Bot logged work on BEAM-13982:
-----------------------------------------
Author: ASF GitHub Bot
Created on: 05/Apr/22 14:15
Start Date: 05/Apr/22 14:15
Worklog Time Spent: 10m
Work Description: ryanthompson591 commented on code in PR #16970:
URL: https://github.com/apache/beam/pull/16970#discussion_r842839800
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -0,0 +1,252 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""An extensible run inference transform."""
+
+import logging
+import os
+import pickle
+import platform
+import sys
+import time
+from typing import Any
+from typing import Iterable
+from typing import Tuple
+
+import apache_beam as beam
+from apache_beam.utils import shared
+
+try:
+  # pylint: disable=g-import-not-at-top
+  import resource
+except ImportError:
+  resource = None
+
+_MILLISECOND_TO_MICROSECOND = 1000
+_MICROSECOND_TO_NANOSECOND = 1000
+_SECOND_TO_MICROSECOND = 1000000
+
+
+class InferenceRunner():
+  """Implements running inferences for a framework."""
+  def run_inference(self, batch: Any, model: Any) -> Iterable[Any]:
+    """Runs inferences on a batch of examples and returns an Iterable of
+    Predictions."""
+    raise NotImplementedError(type(self))
+
+  def get_num_bytes(self, batch: Any) -> int:
+    """Returns the number of bytes of data for a batch."""
+    return len(pickle.dumps(batch))
+
+  def get_metrics_namespace(self) -> str:
+    """Returns a namespace for metrics collected by the RunInference
+    transform."""
+    return 'RunInference'
+
+
+class ModelLoader():
+  """Has the ability to load an ML model."""
+  def load_model(self) -> Any:
+    """Loads and initializes a model for processing."""
+    raise NotImplementedError(type(self))
+
+  def get_inference_runner(self) -> InferenceRunner:
+    """Returns an implementation of InferenceRunner for this model."""
+    raise NotImplementedError(type(self))
+
+
+def _unbatch(maybe_keyed_batches: Tuple[Any, Any]):
+  keys, results = maybe_keyed_batches
+  if keys:
+    return zip(keys, results)
+  else:
+    return results
+
+
+class RunInference(beam.PTransform):
+  """An extensible transform for running inferences."""
+  def __init__(self, model_loader: ModelLoader, clock=None):
+    self._model_loader = model_loader
+    self._clock = clock
+
+  # TODO(BEAM-14208): Add batch_size back off in the case there
+  # are functional reasons large batch sizes cannot be handled.
+  def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
+    return (
+        pcoll
+        # TODO(BEAM-14044): Hook into the batching DoFn APIs.
+        | beam.BatchElements()
+        | beam.ParDo(
+            RunInferenceDoFn(shared.Shared(), self._model_loader, self._clock))
Review Comment:
OK, I modified this to just initialize shared in the constructor.
I originally thought that was cleaner.
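
A minimal sketch of the change described in this comment, not the code from the PR: the DoFn creates its own shared handle in `__init__` instead of receiving one from `expand`. To stay runnable without Beam installed, `_SharedHandle` is a simplified, hypothetical stand-in for `apache_beam.utils.shared.Shared`; `_FakeLoader`, `RunInferenceDoFnSketch`, and the field name `_shared_model_handle` are assumptions made for illustration.

```python
import threading
from typing import Any, Callable, Iterable, List


class _SharedHandle:
  """Hypothetical, simplified stand-in for apache_beam.utils.shared.Shared.

  Caches one object per handle so repeated acquire() calls reuse it.
  """
  def __init__(self):
    self._lock = threading.Lock()
    self._obj = None
    self._loaded = False

  def acquire(self, constructor_fn: Callable[[], Any]) -> Any:
    with self._lock:
      if not self._loaded:
        self._obj = constructor_fn()
        self._loaded = True
      return self._obj


class _FakeLoader:
  """Hypothetical loader that counts how often the model is loaded."""
  def __init__(self):
    self.load_count = 0

  def load_model(self) -> Callable[[int], int]:
    self.load_count += 1
    return lambda x: x * 2


class RunInferenceDoFnSketch:
  """Sketch of a DoFn that owns its shared handle (assumed shape)."""
  def __init__(self, model_loader, clock=None):
    # The handle is created here, so callers no longer pass one in.
    self._shared_model_handle = _SharedHandle()
    self._model_loader = model_loader
    self._clock = clock
    self._model = None

  def setup(self):
    # Load the model once per handle; later calls get the cached object.
    self._model = self._shared_model_handle.acquire(
        self._model_loader.load_model)

  def process(self, batch: List[int]) -> Iterable[int]:
    return [self._model(x) for x in batch]


loader = _FakeLoader()
fn = RunInferenceDoFnSketch(loader)
fn.setup()
fn.setup()  # A second setup() reuses the cached model.
print(fn.process([1, 2, 3]))
```

With this shape, the call site in `expand` would reduce to passing only the model loader and clock, which appears to be the simplification the comment refers to.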
Issue Time Tracking
-------------------
Worklog Id: (was: 752918)
Time Spent: 15h 10m (was: 15h)
> Implement Generic RunInference Base class
> -----------------------------------------
>
> Key: BEAM-13982
> URL: https://issues.apache.org/jira/browse/BEAM-13982
> Project: Beam
> Issue Type: Sub-task
> Components: sdk-py-core
> Reporter: Andy Ye
> Assignee: Ryan Thompson
> Priority: P2
> Labels: run-inference
> Time Spent: 15h 10m
> Remaining Estimate: 0h
>
> This base class will have
> * Metrics
> * Will call dependent framework-specific classes
> * Unit tests
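
As a rough illustration of how the base class might call framework-specific classes, the sketch below subclasses `ModelLoader` and `InferenceRunner` for a made-up framework. The base classes are condensed from the diff above; `ThresholdModel`, `ThresholdInferenceRunner`, and `ThresholdModelLoader` are hypothetical names that do not come from the PR.

```python
import pickle
from typing import Any, Iterable, List


class InferenceRunner:
  """Base class, condensed from the diff above."""
  def run_inference(self, batch: Any, model: Any) -> Iterable[Any]:
    raise NotImplementedError(type(self))

  def get_num_bytes(self, batch: Any) -> int:
    return len(pickle.dumps(batch))


class ModelLoader:
  """Base class, condensed from the diff above."""
  def load_model(self) -> Any:
    raise NotImplementedError(type(self))

  def get_inference_runner(self) -> InferenceRunner:
    raise NotImplementedError(type(self))


class ThresholdModel:
  """Hypothetical 'framework' model: labels a value 1 if above a threshold."""
  def __init__(self, threshold: float):
    self.threshold = threshold

  def predict(self, values: List[float]) -> List[int]:
    return [int(v > self.threshold) for v in values]


class ThresholdInferenceRunner(InferenceRunner):
  """Hypothetical runner that delegates to the model's predict()."""
  def run_inference(self, batch: List[float],
                    model: ThresholdModel) -> Iterable[int]:
    return model.predict(batch)


class ThresholdModelLoader(ModelLoader):
  """Hypothetical loader that builds the model and names its runner."""
  def __init__(self, threshold: float):
    self._threshold = threshold

  def load_model(self) -> ThresholdModel:
    return ThresholdModel(self._threshold)

  def get_inference_runner(self) -> InferenceRunner:
    return ThresholdInferenceRunner()


loader = ThresholdModelLoader(threshold=0.5)
model = loader.load_model()
runner = loader.get_inference_runner()
predictions = list(runner.run_inference([0.1, 0.9, 0.5], model))
print(predictions)
```

The generic transform only needs the two base-class interfaces, so supporting another framework would mean supplying a new loader and runner pair.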