This is an automated email from the ASF dual-hosted git repository.
shunping pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 74512d9d04d [YAML] Add a spec provider for transforms taking
specifiable arguments (#35187)
74512d9d04d is described below
commit 74512d9d04d810a0dc633928b699e4e35e57c98b
Author: Shunping Huang <[email protected]>
AuthorDate: Fri Jun 6 20:47:25 2025 -0400
[YAML] Add a spec provider for transforms taking specifiable arguments
(#35187)
* Add a test provider for specifiable and try it on AnomalyDetection.
Also add support for callables in spec.
* Minor renaming
* Fix lints.
---
sdks/python/apache_beam/yaml/yaml_provider.py | 2 +
sdks/python/apache_beam/yaml/yaml_specifiable.py | 59 +++++++++++
.../apache_beam/yaml/yaml_specifiable_test.py | 115 +++++++++++++++++++++
3 files changed, 176 insertions(+)
diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py
b/sdks/python/apache_beam/yaml/yaml_provider.py
index 0609c4ef8df..d7c427228a3 100755
--- a/sdks/python/apache_beam/yaml/yaml_provider.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -1475,6 +1475,7 @@ def standard_providers():
from apache_beam.yaml.yaml_mapping import create_mapping_providers
from apache_beam.yaml.yaml_join import create_join_providers
from apache_beam.yaml.yaml_io import io_providers
+ from apache_beam.yaml.yaml_specifiable import create_spec_providers
return merge_providers(
YamlProviders.create_builtin_provider(),
@@ -1483,6 +1484,7 @@ def standard_providers():
create_combine_providers(),
create_join_providers(),
io_providers(),
+ create_spec_providers(),
load_providers(yaml_utils.locate_data_file('standard_providers.yaml')))
diff --git a/sdks/python/apache_beam/yaml/yaml_specifiable.py
b/sdks/python/apache_beam/yaml/yaml_specifiable.py
new file mode 100644
index 00000000000..207bf608b4d
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_specifiable.py
@@ -0,0 +1,59 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.anomaly.specifiable import Spec
+from apache_beam.ml.anomaly.transforms import AnomalyDetection
+from apache_beam.ml.anomaly.transforms import Specifiable
+from apache_beam.utils import python_callable
+from apache_beam.yaml.yaml_provider import InlineProvider
+
+
def maybe_make_specifiable(v):
  """Recursively convert a YAML config value into Python objects.

  The following dict shapes are recognized and converted:

  * ``{"type": ..., "config": ...}`` becomes the object returned by
    ``Specifiable.from_spec``; ``config`` is converted recursively first.
  * ``{"callable": "<source>"}`` becomes a ``PythonCallableWithSource``.
  * ``{"path": ..., "name": ...}`` loads the function ``name`` from the
    script at ``path``, read through ``FileSystems``.

  Any other dict is converted value by value, and lists are converted
  element by element. All other values are returned unchanged.

  Args:
    v: A value taken from a YAML transform config.

  Returns:
    The converted value.

  Raises:
    ValueError: If a dict specifies ``callable`` together with ``path`` or
      ``name``.
  """
  if isinstance(v, list):
    # Lists can hold specs or callables as well, so convert each element the
    # same way nested dict values are converted.
    return [maybe_make_specifiable(e) for e in v]

  if not isinstance(v, dict):
    return v

  if "type" in v and "config" in v:
    return Specifiable.from_spec(
        Spec(type=v["type"], config=maybe_make_specifiable(v["config"])))

  if "callable" in v:
    if "path" in v or "name" in v:
      raise ValueError(
          "Cannot specify 'callable' with 'path' or 'name' for function.")
    return python_callable.PythonCallableWithSource(v["callable"])

  if "path" in v and "name" in v:
    # Close the file handle once the script source has been read.
    with FileSystems.open(v["path"]) as f:
      source = f.read().decode()
    return python_callable.PythonCallableWithSource.load_from_script(
        source, v["name"])

  return {k: maybe_make_specifiable(val) for k, val in v.items()}
+
+
class SpecProvider(InlineProvider):
  """An ``InlineProvider`` whose transforms accept specifiable arguments.

  Before the registered factory is called, every config argument is passed
  through ``maybe_make_specifiable``. This turns ``type``/``config`` dicts
  and callable descriptions into the corresponding Python objects.
  """
  def create_transform(self, type, args, yaml_create_transform):
    # Look up the factory before converting, matching the original ordering:
    # an unknown type fails before any argument conversion happens.
    factory = self._transform_factories[type]
    converted = {}
    for name, value in args.items():
      converted[name] = maybe_make_specifiable(value)
    return factory(**converted)
+
+
def create_spec_providers():
  """Return the provider for transforms that take specifiable arguments.

  Currently only ``AnomalyDetection`` is registered.
  """
  factories = dict(AnomalyDetection=AnomalyDetection)
  return SpecProvider(factories)
diff --git a/sdks/python/apache_beam/yaml/yaml_specifiable_test.py
b/sdks/python/apache_beam/yaml/yaml_specifiable_test.py
new file mode 100644
index 00000000000..62b455c4980
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_specifiable_test.py
@@ -0,0 +1,115 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import unittest
+from typing import Callable
+
+import apache_beam as beam
+from apache_beam.ml.anomaly.base import AnomalyDetector
+from apache_beam.ml.anomaly.specifiable import specifiable
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+from apache_beam.utils import python_callable
+from apache_beam.yaml.yaml_transform import YamlTransform
+
# Extra providers for the YAML pipelines below. `PyMap` wraps the given
# Python source string in a `beam.Map`.
TEST_PROVIDERS = dict(
    PyMap=lambda fn: beam.Map(python_callable.PythonCallableWithSource(fn)))
+
+
@specifiable
class FakeDetector(AnomalyDetector):  # pylint: disable=unused-variable
  """Test detector whose score is ``fn`` applied to the row's first value.

  ``learn_one`` does nothing, so a score depends only on the current row.
  """
  def __init__(self, fn: Callable):
    super().__init__()
    self._fn = fn

  def learn_one(self, x: beam.Row) -> None:
    """Ignore training rows; scoring does not depend on past input."""

  def score_one(self, x: beam.Row) -> float:
    first_value = next(iter(x))
    return self._fn(first_value)
+
+
class YamlSpecifiableTransformTest(unittest.TestCase):
  """Runs YAML pipelines whose AnomalyDetection config contains specs."""

  # Keyed input rows shared by every test; all rows use key 0.
  _TRAIN_DATA = [
      (0, beam.Row(x=1)),
      (0, beam.Row(x=2)),
      (0, beam.Row(x=2)),
      (0, beam.Row(x=4)),
      (0, beam.Row(x=9)),
  ]

  def _assert_yaml_output(self, yaml_spec, expected):
    """Apply `yaml_spec` to the training data and check its output."""
    options = beam.options.pipeline_options.PipelineOptions(
        pickle_library='cloudpickle')
    with beam.Pipeline(options=options) as p:
      result = (
          p
          | beam.Create(self._TRAIN_DATA)
          | YamlTransform(yaml_spec, providers=TEST_PROVIDERS))
      assert_that(result, equal_to(expected))

  def test_specifiable_transform(self):
    # Nested `type`/`config` dicts build a ZScore detector and its trackers.
    self._assert_yaml_output(
        '''
        type: chain
        transforms:
          - type: AnomalyDetection
            config:
              detector:
                type: 'ZScore'
                config:
                  sub_stat_tracker:
                    type: 'IncSlidingMeanTracker'
                    config:
                      window_size: 5
                  stdev_tracker:
                    type: 'IncSlidingStdevTracker'
                    config:
                      window_size: 5
          - type: PyMap
            config:
              fn: "lambda x: (x[1].predictions[0].label)"
        ''',
        expected=[-2, -2, 0, 1, 1])

  def test_specifiable_transform_with_callable(self):
    # A `callable` entry is turned into a Python function for the detector.
    self._assert_yaml_output(
        '''
        type: chain
        transforms:
          - type: AnomalyDetection
            config:
              detector:
                type: 'FakeDetector'
                config:
                  fn:
                    callable: "lambda x: x * 10.0"
          - type: PyMap
            config:
              fn: "lambda x: (x[1].predictions[0].score)"
        ''',
        expected=[10.0, 20.0, 20.0, 40.0, 90.0])
+
+
if __name__ == '__main__':
  # When the module is run directly, show INFO-level logs and run the tests.
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()