This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 00136b112bc Update KeyMhMapping to KeyModelMapping (#28209)
00136b112bc is described below
commit 00136b112bc49e85c84dfe52d7b65560310012fe
Author: Danny McCormick <[email protected]>
AuthorDate: Wed Aug 30 09:34:26 2023 -0400
Update KeyMhMapping to KeyModelMapping (#28209)
---
.../pytorch_model_per_key_image_segmentation.py | 4 +-
sdks/python/apache_beam/ml/inference/base.py | 17 +++---
sdks/python/apache_beam/ml/inference/base_test.py | 62 +++++++++++-----------
3 files changed, 43 insertions(+), 40 deletions(-)
diff --git a/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py b/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py
index e09a348511b..f0b5462d533 100644
--- a/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py
+++ b/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py
@@ -33,7 +33,7 @@ import apache_beam as beam
import torch
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import KeyedModelHandler
-from apache_beam.ml.inference.base import KeyMhMapping
+from apache_beam.ml.inference.base import KeyModelMapping
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
@@ -257,7 +257,7 @@ def run(
# Note that multiple keys can also point to a single model handler,
# unlike this example.
model_handler = KeyedModelHandler(
- [KeyMhMapping(['v1'], mh1), KeyMhMapping(['v2'], mh2)])
+ [KeyModelMapping(['v1'], mh1), KeyModelMapping(['v2'], mh2)])
pipeline = test_pipeline
if not test_pipeline:
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 10e8981d8bd..b5aa4f352fa 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -396,10 +396,10 @@ class _ModelManager:
# Use a dataclass instead of named tuple because NamedTuples and generics don't
# mix well across the board for all versions:
# https://github.com/python/typing/issues/653
-class KeyMhMapping(Generic[KeyT, ExampleT, PredictionT, ModelT]):
+class KeyModelMapping(Generic[KeyT, ExampleT, PredictionT, ModelT]):
"""
- Dataclass for mapping 1 or more keys to 1 model handler.
- Given `KeyMhMapping(['key1', 'key2'], myMh)`, all examples with keys `key1`
+ Dataclass for mapping 1 or more keys to 1 model handler. Given
+ `KeyModelMapping(['key1', 'key2'], myMh)`, all examples with keys `key1`
or `key2` will be run against the model defined by the `myMh` ModelHandler.
"""
def __init__(
@@ -415,7 +415,8 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
def __init__(
self,
unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT],
- List[KeyMhMapping[KeyT, ExampleT, PredictionT, ModelT]]]):
+ List[KeyModelMapping[KeyT, ExampleT, PredictionT,
+ ModelT]]]):
"""A ModelHandler that takes keyed examples and returns keyed predictions.
For example, if the original model is used with RunInference to take a
@@ -429,7 +430,7 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
k1 = ['k1', 'k2', 'k3']
k2 = ['k4', 'k5']
- KeyedModelHandler([KeyMhMapping(k1, mh1), KeyMhMapping(k2, mh2)])
+ KeyedModelHandler([KeyModelMapping(k1, mh1), KeyModelMapping(k2, mh2)])
Note that a single copy of each of these models may all be held in memory
at the same time; be careful not to load too many large models or your
@@ -462,7 +463,7 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
Args:
unkeyed: Either (a) an implementation of ModelHandler that does not
- require keys or (b) a list of KeyMhMappings mapping lists of keys to
+ require keys or (b) a list of KeyModelMappings mapping lists of keys to
unkeyed ModelHandlers.
"""
self._single_model = not isinstance(unkeyed, list)
@@ -479,8 +480,8 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
return
# To maintain an efficient representation, we will map all keys in a given
- # KeyMhMapping to a single id (the first key in the KeyMhMapping list).
- # We will then map that key to a ModelHandler. This will allow us to
+ # KeyModelMapping to a single id (the first key in the KeyModelMapping
+ # list). We will then map that key to a ModelHandler. This will allow us to
# quickly look up the appropriate ModelHandler for any given key.
self._id_to_mh_map: Dict[str, ModelHandler[ExampleT, PredictionT,
ModelT]] = {}
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 4b551ce5584..af6168c80af 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -277,11 +277,11 @@ class RunInferenceBaseTest(unittest.TestCase):
expected[0] = (0, 200)
pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
mhs = [
- base.KeyMhMapping([0],
- FakeModelHandler(
- state=200, multi_process_shared=True)),
- base.KeyMhMapping([1, 2, 3],
- FakeModelHandler(multi_process_shared=True))
+ base.KeyModelMapping([0],
+ FakeModelHandler(
+ state=200, multi_process_shared=True)),
+ base.KeyModelMapping([1, 2, 3],
+ FakeModelHandler(multi_process_shared=True))
]
actual = pcoll | base.RunInference(base.KeyedModelHandler(mhs))
assert_that(actual, equal_to(expected), label='assert:inferences')
@@ -291,45 +291,45 @@ class RunInferenceBaseTest(unittest.TestCase):
return int(example) * 2
mhs = [
- base.KeyMhMapping(
+ base.KeyModelMapping(
[0],
FakeModelHandler(
state=200,
multi_process_shared=True).with_preprocess_fn(mult_two)),
- base.KeyMhMapping([1, 2, 3],
- FakeModelHandler(multi_process_shared=True))
+ base.KeyModelMapping([1, 2, 3],
+ FakeModelHandler(multi_process_shared=True))
]
with self.assertRaises(ValueError):
base.KeyedModelHandler(mhs)
mhs = [
- base.KeyMhMapping(
+ base.KeyModelMapping(
[0],
FakeModelHandler(
state=200,
multi_process_shared=True).with_postprocess_fn(mult_two)),
- base.KeyMhMapping([1, 2, 3],
- FakeModelHandler(multi_process_shared=True))
+ base.KeyModelMapping([1, 2, 3],
+ FakeModelHandler(multi_process_shared=True))
]
with self.assertRaises(ValueError):
base.KeyedModelHandler(mhs)
mhs = [
- base.KeyMhMapping([0],
- FakeModelHandler(
- state=200, multi_process_shared=True)),
- base.KeyMhMapping([0, 1, 2, 3],
- FakeModelHandler(multi_process_shared=True))
+ base.KeyModelMapping([0],
+ FakeModelHandler(
+ state=200, multi_process_shared=True)),
+ base.KeyModelMapping([0, 1, 2, 3],
+ FakeModelHandler(multi_process_shared=True))
]
with self.assertRaises(ValueError):
base.KeyedModelHandler(mhs)
mhs = [
- base.KeyMhMapping([],
- FakeModelHandler(
- state=200, multi_process_shared=True)),
- base.KeyMhMapping([0, 1, 2, 3],
- FakeModelHandler(multi_process_shared=True))
+ base.KeyModelMapping([],
+ FakeModelHandler(
+ state=200, multi_process_shared=True)),
+ base.KeyModelMapping([0, 1, 2, 3],
+ FakeModelHandler(multi_process_shared=True))
]
with self.assertRaises(ValueError):
base.KeyedModelHandler(mhs)
@@ -343,8 +343,10 @@ class RunInferenceBaseTest(unittest.TestCase):
def test_keyed_model_handler_multiple_models_get_num_bytes(self):
mhs = [
- base.KeyMhMapping(['key1'], FakeModelHandler(num_bytes_per_element=10)),
- base.KeyMhMapping(['key2'], FakeModelHandler(num_bytes_per_element=20))
+ base.KeyModelMapping(['key1'],
+ FakeModelHandler(num_bytes_per_element=10)),
+ base.KeyModelMapping(['key2'],
+ FakeModelHandler(num_bytes_per_element=20))
]
mh = base.KeyedModelHandler(mhs)
batch = [('key1', 1), ('key2', 2), ('key1', 3)]
@@ -1010,12 +1012,12 @@ class RunInferenceBaseTest(unittest.TestCase):
]
model_handler = base.KeyedModelHandler([
- base.KeyMhMapping(['key1'],
- FakeModelHandlerReturnsPredictionResult(
- multi_process_shared=True, state=True)),
- base.KeyMhMapping(['key2'],
- FakeModelHandlerReturnsPredictionResult(
- multi_process_shared=True, state=True))
+ base.KeyModelMapping(['key1'],
+ FakeModelHandlerReturnsPredictionResult(
+ multi_process_shared=True, state=True)),
+ base.KeyModelMapping(['key2'],
+ FakeModelHandlerReturnsPredictionResult(
+ multi_process_shared=True, state=True))
])
class _EmitElement(beam.DoFn):
@@ -1114,7 +1116,7 @@ class RunInferenceBaseTest(unittest.TestCase):
]
model_handler = base.KeyedModelHandler([
- base.KeyMhMapping(
+ base.KeyModelMapping(
['key1', 'key2'],
FakeModelHandlerReturnsPredictionResult(multi_process_shared=True))
])