This is an automated email from the ASF dual-hosted git repository.
bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 128d5989121 [BEAM-14218] Add resource location hints to base inference runner. (#17448)
128d5989121 is described below
commit 128d5989121155549e9bef823529153a3de6520e
Author: Ryan Thompson <[email protected]>
AuthorDate: Fri May 27 16:55:12 2022 -0400
[BEAM-14218] Add resource location hints to base inference runner. (#17448)
* added docs and location hints
* simplified model hints to come from the loader
* yapf
* added optional clock
* moved clock
* fixed typo
* keep clock in right space
* fixed bad hints
---
sdks/python/apache_beam/ml/inference/base.py | 25 ++++++++++++++++++++-----
1 file changed, 20 insertions(+), 5 deletions(-)
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 49753c4e7a3..7c7ef50e234 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -39,6 +39,7 @@ from typing import Generic
from typing import Iterable
from typing import List
from typing import Mapping
+from typing import Optional
from typing import TypeVar
import apache_beam as beam
@@ -57,7 +58,7 @@ _SECOND_TO_MICROSECOND = 1_000_000
T = TypeVar('T')
-class InferenceRunner():
+class InferenceRunner:
"""Implements running inferences for a framework."""
def run_inference(self, batch: List[Any], model: Any) -> Iterable[Any]:
"""Runs inferences on a batch of examples and
@@ -83,25 +84,38 @@ class ModelLoader(Generic[T]):
"""Returns an implementation of InferenceRunner for this model."""
raise NotImplementedError(type(self))
+ def get_resource_hints(self) -> dict:
+ """Returns resource hints for the transform."""
+ return {}
+
def batch_elements_kwargs(self) -> Mapping[str, Any]:
"""Returns kwargs suitable for beam.BatchElements."""
return {}
class RunInference(beam.PTransform):
- """An extensible transform for running inferences."""
- def __init__(self, model_loader: ModelLoader, clock=None):
+ """An extensible transform for running inferences.
+ Args:
+ model_loader: An implementation of ModelLoader.
+ clock: A clock implementing get_current_time_in_microseconds.
+ """
+ def __init__(
+ self, model_loader: ModelLoader, clock: Optional["_Clock"] = None):
self._model_loader = model_loader
self._clock = clock
# TODO(BEAM-14208): Add batch_size back off in the case there
# are functional reasons large batch sizes cannot be handled.
def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
+ resource_hints = self._model_loader.get_resource_hints()
return (
pcoll
# TODO(BEAM-14044): Hook into the batching DoFn APIs.
| beam.BatchElements(**self._model_loader.batch_elements_kwargs())
- | beam.ParDo(_RunInferenceDoFn(self._model_loader, self._clock)))
+ | (
+ beam.ParDo(_RunInferenceDoFn(
+ self._model_loader,
+ self._clock)).with_resource_hints(**resource_hints)))
class _MetricsCollector:
@@ -155,7 +169,8 @@ class _MetricsCollector:
class _RunInferenceDoFn(beam.DoFn):
"""A DoFn implementation generic to frameworks."""
- def __init__(self, model_loader: ModelLoader, clock=None):
+ def __init__(
+ self, model_loader: ModelLoader, clock: Optional["_Clock"] = None):
self._model_loader = model_loader
self._inference_runner = model_loader.get_inference_runner()
self._shared_model_handle = shared.Shared()