This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 00136b112bc Update KeyMhMapping to KeyModelMapping (#28209)
00136b112bc is described below

commit 00136b112bc49e85c84dfe52d7b65560310012fe
Author: Danny McCormick <[email protected]>
AuthorDate: Wed Aug 30 09:34:26 2023 -0400

    Update KeyMhMapping to KeyModelMapping (#28209)
---
 .../pytorch_model_per_key_image_segmentation.py    |  4 +-
 sdks/python/apache_beam/ml/inference/base.py       | 17 +++---
 sdks/python/apache_beam/ml/inference/base_test.py  | 62 +++++++++++-----------
 3 files changed, 43 insertions(+), 40 deletions(-)

diff --git a/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py b/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py
index e09a348511b..f0b5462d533 100644
--- a/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py
+++ b/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py
@@ -33,7 +33,7 @@ import apache_beam as beam
 import torch
 from apache_beam.io.filesystems import FileSystems
 from apache_beam.ml.inference.base import KeyedModelHandler
-from apache_beam.ml.inference.base import KeyMhMapping
+from apache_beam.ml.inference.base import KeyModelMapping
 from apache_beam.ml.inference.base import PredictionResult
 from apache_beam.ml.inference.base import RunInference
 from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
@@ -257,7 +257,7 @@ def run(
   # Note that multiple keys can also point to a single model handler,
   # unlike this example.
   model_handler = KeyedModelHandler(
-      [KeyMhMapping(['v1'], mh1), KeyMhMapping(['v2'], mh2)])
+      [KeyModelMapping(['v1'], mh1), KeyModelMapping(['v2'], mh2)])
 
   pipeline = test_pipeline
   if not test_pipeline:
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 10e8981d8bd..b5aa4f352fa 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -396,10 +396,10 @@ class _ModelManager:
 # Use a dataclass instead of named tuple because NamedTuples and generics don't
 # mix well across the board for all versions:
 # https://github.com/python/typing/issues/653
-class KeyMhMapping(Generic[KeyT, ExampleT, PredictionT, ModelT]):
+class KeyModelMapping(Generic[KeyT, ExampleT, PredictionT, ModelT]):
   """
-  Dataclass for mapping 1 or more keys to 1 model handler.
-  Given `KeyMhMapping(['key1', 'key2'], myMh)`, all examples with keys `key1`
+  Dataclass for mapping 1 or more keys to 1 model handler. Given
+  `KeyModelMapping(['key1', 'key2'], myMh)`, all examples with keys `key1`
   or `key2` will be run against the model defined by the `myMh` ModelHandler.
   """
   def __init__(
@@ -415,7 +415,8 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
   def __init__(
       self,
       unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT],
-                     List[KeyMhMapping[KeyT, ExampleT, PredictionT, ModelT]]]):
+                     List[KeyModelMapping[KeyT, ExampleT, PredictionT,
+                                          ModelT]]]):
     """A ModelHandler that takes keyed examples and returns keyed predictions.
 
     For example, if the original model is used with RunInference to take a
@@ -429,7 +430,7 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
 
         k1 = ['k1', 'k2', 'k3']
         k2 = ['k4', 'k5']
-        KeyedModelHandler([KeyMhMapping(k1, mh1), KeyMhMapping(k2, mh2)])
+        KeyedModelHandler([KeyModelMapping(k1, mh1), KeyModelMapping(k2, mh2)])
 
     Note that a single copy of each of these models may all be held in memory
     at the same time; be careful not to load too many large models or your
@@ -462,7 +463,7 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
 
     Args:
       unkeyed: Either (a) an implementation of ModelHandler that does not
-        require keys or (b) a list of KeyMhMappings mapping lists of keys to
+        require keys or (b) a list of KeyModelMappings mapping lists of keys to
         unkeyed ModelHandlers.
     """
     self._single_model = not isinstance(unkeyed, list)
@@ -479,8 +480,8 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
       return
 
     # To maintain an efficient representation, we will map all keys in a given
-    # KeyMhMapping to a single id (the first key in the KeyMhMapping list).
-    # We will then map that key to a ModelHandler. This will allow us to
+    # KeyModelMapping to a single id (the first key in the KeyModelMapping
+    # list). We will then map that key to a ModelHandler. This will allow us to
     # quickly look up the appropriate ModelHandler for any given key.
     self._id_to_mh_map: Dict[str, ModelHandler[ExampleT, PredictionT,
                                                ModelT]] = {}
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 4b551ce5584..af6168c80af 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -277,11 +277,11 @@ class RunInferenceBaseTest(unittest.TestCase):
       expected[0] = (0, 200)
       pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
       mhs = [
-          base.KeyMhMapping([0],
-                            FakeModelHandler(
-                                state=200, multi_process_shared=True)),
-          base.KeyMhMapping([1, 2, 3],
-                            FakeModelHandler(multi_process_shared=True))
+          base.KeyModelMapping([0],
+                               FakeModelHandler(
+                                   state=200, multi_process_shared=True)),
+          base.KeyModelMapping([1, 2, 3],
+                               FakeModelHandler(multi_process_shared=True))
       ]
       actual = pcoll | base.RunInference(base.KeyedModelHandler(mhs))
       assert_that(actual, equal_to(expected), label='assert:inferences')
@@ -291,45 +291,45 @@ class RunInferenceBaseTest(unittest.TestCase):
       return int(example) * 2
 
     mhs = [
-        base.KeyMhMapping(
+        base.KeyModelMapping(
             [0],
             FakeModelHandler(
                 state=200,
                 multi_process_shared=True).with_preprocess_fn(mult_two)),
-        base.KeyMhMapping([1, 2, 3],
-                          FakeModelHandler(multi_process_shared=True))
+        base.KeyModelMapping([1, 2, 3],
+                             FakeModelHandler(multi_process_shared=True))
     ]
     with self.assertRaises(ValueError):
       base.KeyedModelHandler(mhs)
 
     mhs = [
-        base.KeyMhMapping(
+        base.KeyModelMapping(
             [0],
             FakeModelHandler(
                 state=200,
                 multi_process_shared=True).with_postprocess_fn(mult_two)),
-        base.KeyMhMapping([1, 2, 3],
-                          FakeModelHandler(multi_process_shared=True))
+        base.KeyModelMapping([1, 2, 3],
+                             FakeModelHandler(multi_process_shared=True))
     ]
     with self.assertRaises(ValueError):
       base.KeyedModelHandler(mhs)
 
     mhs = [
-        base.KeyMhMapping([0],
-                          FakeModelHandler(
-                              state=200, multi_process_shared=True)),
-        base.KeyMhMapping([0, 1, 2, 3],
-                          FakeModelHandler(multi_process_shared=True))
+        base.KeyModelMapping([0],
+                             FakeModelHandler(
+                                 state=200, multi_process_shared=True)),
+        base.KeyModelMapping([0, 1, 2, 3],
+                             FakeModelHandler(multi_process_shared=True))
     ]
     with self.assertRaises(ValueError):
       base.KeyedModelHandler(mhs)
 
     mhs = [
-        base.KeyMhMapping([],
-                          FakeModelHandler(
-                              state=200, multi_process_shared=True)),
-        base.KeyMhMapping([0, 1, 2, 3],
-                          FakeModelHandler(multi_process_shared=True))
+        base.KeyModelMapping([],
+                             FakeModelHandler(
+                                 state=200, multi_process_shared=True)),
+        base.KeyModelMapping([0, 1, 2, 3],
+                             FakeModelHandler(multi_process_shared=True))
     ]
     with self.assertRaises(ValueError):
       base.KeyedModelHandler(mhs)
@@ -343,8 +343,10 @@ class RunInferenceBaseTest(unittest.TestCase):
 
   def test_keyed_model_handler_multiple_models_get_num_bytes(self):
     mhs = [
-        base.KeyMhMapping(['key1'], FakeModelHandler(num_bytes_per_element=10)),
-        base.KeyMhMapping(['key2'], FakeModelHandler(num_bytes_per_element=20))
+        base.KeyModelMapping(['key1'],
+                             FakeModelHandler(num_bytes_per_element=10)),
+        base.KeyModelMapping(['key2'],
+                             FakeModelHandler(num_bytes_per_element=20))
     ]
     mh = base.KeyedModelHandler(mhs)
     batch = [('key1', 1), ('key2', 2), ('key1', 3)]
@@ -1010,12 +1012,12 @@ class RunInferenceBaseTest(unittest.TestCase):
     ]
 
     model_handler = base.KeyedModelHandler([
-        base.KeyMhMapping(['key1'],
-                          FakeModelHandlerReturnsPredictionResult(
-                              multi_process_shared=True, state=True)),
-        base.KeyMhMapping(['key2'],
-                          FakeModelHandlerReturnsPredictionResult(
-                              multi_process_shared=True, state=True))
+        base.KeyModelMapping(['key1'],
+                             FakeModelHandlerReturnsPredictionResult(
+                                 multi_process_shared=True, state=True)),
+        base.KeyModelMapping(['key2'],
+                             FakeModelHandlerReturnsPredictionResult(
+                                 multi_process_shared=True, state=True))
     ])
 
     class _EmitElement(beam.DoFn):
@@ -1114,7 +1116,7 @@ class RunInferenceBaseTest(unittest.TestCase):
     ]
 
     model_handler = base.KeyedModelHandler([
-        base.KeyMhMapping(
+        base.KeyModelMapping(
             ['key1', 'key2'],
             FakeModelHandlerReturnsPredictionResult(multi_process_shared=True))
     ])

Reply via email to