This is an automated email from the ASF dual-hosted git repository.
robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new f91bb68b2e9 Add some simple annotations to Python transforms. (#28191)
f91bb68b2e9 is described below
commit f91bb68b2e9df3788914898c03fdf42445f912d5
Author: Robert Bradshaw <[email protected]>
AuthorDate: Wed Aug 30 17:23:33 2023 -0700
Add some simple annotations to Python transforms. (#28191)
---
sdks/python/apache_beam/ml/inference/base.py | 9 +++++++++
sdks/python/apache_beam/transforms/ptransform.py | 5 ++++-
sdks/python/apache_beam/yaml/yaml_transform_scope_test.py | 10 +++++++---
3 files changed, 20 insertions(+), 4 deletions(-)
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index b5aa4f352fa..0964fc46a95 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -952,6 +952,15 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
# allow us to effectively disambiguate in multi-model settings.
self._model_tag = uuid.uuid4().hex
+ def annotations(self):
+ return {
+ 'model_handler': str(self._model_handler),
+ 'model_handler_type': (
+ f'{self._model_handler.__class__.__module__}'
+ f'.{self._model_handler.__class__.__qualname__}'),
+ **super().annotations()
+ }
+
def _get_model_metadata_pcoll(self, pipeline):
# avoid circular imports.
# pylint: disable=wrong-import-position
diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index fd86ff1f934..c7eaa152ae0 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -371,7 +371,10 @@ class PTransform(WithTypeHints, HasDisplayData, Generic[InputT, OutputT]):
return self.__class__.__name__
def annotations(self) -> Dict[str, Union[bytes, str, message.Message]]:
- return {}
+ return {
+ 'python_type': #
+ f'{self.__class__.__module__}.{self.__class__.__qualname__}'
+ }
def default_type_hints(self):
fn_type_hints = IOTypeHints.from_callable(self.expand)
diff --git a/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py b/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py
index ead5d5d66d2..a22e4f851a1 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py
@@ -61,7 +61,7 @@ class ScopeTest(unittest.TestCase):
- type: PyMap
name: Square
input: Create
- config:
+ config:
fn: "lambda x: x*x"
'''
@@ -123,7 +123,11 @@ class ScopeTest(unittest.TestCase):
self.assertIsInstance(result, beam.transforms.ParDo)
self.assertEqual(result.label, 'Map(lambda x: x*x)')
- result_annotations = {**result.annotations()}
+ result_annotations = {
+ key: value
+ for (key, value) in result.annotations().items()
+ if key.startswith('yaml')
+ }
target_annotations = {
'yaml_type': 'PyMap',
'yaml_args': '{"fn": "lambda x: x*x"}',
@@ -146,7 +150,7 @@ class LightweightScopeTest(unittest.TestCase):
fn: "lambda x: x * x * x"
- type: Filter
name: FilterOutBigNumbers
- input: PyMap
+ input: PyMap
keep: "lambda x: x<100"
'''
return yaml.load(pipeline_yaml, Loader=SafeLineLoader)