[ 
https://issues.apache.org/jira/browse/BEAM-13982?focusedWorklogId=752918&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-752918
 ]

ASF GitHub Bot logged work on BEAM-13982:
-----------------------------------------

                Author: ASF GitHub Bot
            Created on: 05/Apr/22 14:15
            Start Date: 05/Apr/22 14:15
    Worklog Time Spent: 10m 
      Work Description: ryanthompson591 commented on code in PR #16970:
URL: https://github.com/apache/beam/pull/16970#discussion_r842839800


##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -0,0 +1,252 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""An extensible run inference transform."""
+
+import logging
+import os
+import pickle
+import platform
+import sys
+import time
+from typing import Any
+from typing import Iterable
+from typing import Tuple
+
+import apache_beam as beam
+from apache_beam.utils import shared
+
+try:
+  # pylint: disable=g-import-not-at-top
+  import resource
+except ImportError:
+  resource = None
+
+_MILLISECOND_TO_MICROSECOND = 1000
+_MICROSECOND_TO_NANOSECOND = 1000
+_SECOND_TO_MICROSECOND = 1000000
+
+
+class InferenceRunner():
+  """Implements running inferences for a framework."""
+  def run_inference(self, batch: Any, model: Any) -> Iterable[Any]:
+    """Runs inferences on a batch of examples and returns an Iterable of Predictions."""
+    raise NotImplementedError(type(self))
+
+  def get_num_bytes(self, batch: Any) -> int:
+    """Returns the number of bytes of data for a batch."""
+    return len(pickle.dumps(batch))
+
+  def get_metrics_namespace(self) -> str:
+    """Returns a namespace for metrics collected by the RunInference transform."""
+    return 'RunInference'
+
+
+class ModelLoader():
+  """Has the ability to load an ML model."""
+  def load_model(self) -> Any:
+    """Loads and initializes a model for processing."""
+    raise NotImplementedError(type(self))
+
+  def get_inference_runner(self) -> InferenceRunner:
+    """Returns an implementation of InferenceRunner for this model."""
+    raise NotImplementedError(type(self))
+
+
+def _unbatch(maybe_keyed_batches: Tuple[Any, Any]):
+  keys, results = maybe_keyed_batches
+  if keys:
+    return zip(keys, results)
+  else:
+    return results
+
+
+class RunInference(beam.PTransform):
+  """An extensible transform for running inferences."""
+  def __init__(self, model_loader: ModelLoader, clock=None):
+    self._model_loader = model_loader
+    self._clock = clock
+
+  # TODO(BEAM-14208): Add batch_size back off in the case there
+  # are functional reasons large batch sizes cannot be handled.
+  def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
+    return (
+        pcoll
+        # TODO(BEAM-14044): Hook into the batching DoFn APIs.
+        | beam.BatchElements()
+        | beam.ParDo(
+            RunInferenceDoFn(shared.Shared(), self._model_loader, self._clock))

Review Comment:
   OK, I modified this to just initialize shared in the constructor.
   
   I originally thought that was cleaner.
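The diff defines two extension points: a `ModelLoader` that loads a model and hands back an `InferenceRunner`, which runs a batch through that model. The following is a minimal sketch of how a framework-specific pair might plug in. It re-declares trimmed copies of both base classes and `_unbatch` so it runs without apache_beam installed. `ScaleModelLoader` and `ScaleInferenceRunner` are hypothetical names (the "model" is just a scale factor), not Beam classes. The excerpt does not show how `RunInferenceDoFn` splits keyed input into `(keys, batch)`, so the tuples passed to `_unbatch` here are built by hand as an assumption.

```python
import pickle
from typing import Any, Iterable, Tuple


# Trimmed stand-ins mirroring the interfaces in the diff, so this sketch
# runs without apache_beam.
class InferenceRunner:
  def run_inference(self, batch: Any, model: Any) -> Iterable[Any]:
    raise NotImplementedError(type(self))

  def get_num_bytes(self, batch: Any) -> int:
    return len(pickle.dumps(batch))

  def get_metrics_namespace(self) -> str:
    return 'RunInference'


class ModelLoader:
  def load_model(self) -> Any:
    raise NotImplementedError(type(self))

  def get_inference_runner(self) -> InferenceRunner:
    raise NotImplementedError(type(self))


# Hypothetical "framework": the loaded model is a number that scales inputs.
class ScaleInferenceRunner(InferenceRunner):
  def run_inference(self, batch, model):
    return [x * model for x in batch]


class ScaleModelLoader(ModelLoader):
  def __init__(self, factor):
    self._factor = factor

  def load_model(self):
    return self._factor

  def get_inference_runner(self):
    return ScaleInferenceRunner()


def _unbatch(maybe_keyed_batches: Tuple[Any, Any]):
  # Same logic as the diff: re-attach keys positionally if there are any.
  keys, results = maybe_keyed_batches
  if keys:
    return zip(keys, results)
  else:
    return results


loader = ScaleModelLoader(10)
model = loader.load_model()
runner = loader.get_inference_runner()

# Unkeyed batch (assumed keys=None): results pass through unchanged.
unkeyed = list(_unbatch((None, runner.run_inference([1, 2, 3], model))))
# Keyed batch (assumed keys list): each result is paired with its key.
keyed = list(_unbatch((['a', 'b'], runner.run_inference([4, 5], model))))
print(unkeyed)  # [10, 20, 30]
print(keyed)    # [('a', 40), ('b', 50)]
```

Keeping loading and running behind separate interfaces lets the shared base transform own batching and metrics while each framework supplies only these two small classes.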





Issue Time Tracking
-------------------

    Worklog Id:     (was: 752918)
    Time Spent: 15h 10m  (was: 15h)

> Implement Generic RunInference Base class
> -----------------------------------------
>
>                 Key: BEAM-13982
>                 URL: https://issues.apache.org/jira/browse/BEAM-13982
>             Project: Beam
>          Issue Type: Sub-task
>          Components: sdk-py-core
>            Reporter: Andy Ye
>            Assignee: Ryan Thompson
>            Priority: P2
>              Labels: run-inference
>          Time Spent: 15h 10m
>  Remaining Estimate: 0h
>
> This base class will have
>  * Metrics
>  * Will call dependent framework-specific classes
>  * Unit tests



--
This message was sent by Atlassian Jira
(v8.20.1#820001)