This is an automated email from the ASF dual-hosted git repository.

shunping pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 74512d9d04d [YAML] Add a spec provider for transforms taking 
specifiable arguments (#35187)
74512d9d04d is described below

commit 74512d9d04d810a0dc633928b699e4e35e57c98b
Author: Shunping Huang <[email protected]>
AuthorDate: Fri Jun 6 20:47:25 2025 -0400

    [YAML] Add a spec provider for transforms taking specifiable arguments 
(#35187)
    
    * Add a spec provider for specifiable arguments and try it on AnomalyDetection.
    
    Also add support for callables in specs.
    
    * Minor renaming
    
    * Fix lints.
---
 sdks/python/apache_beam/yaml/yaml_provider.py      |   2 +
 sdks/python/apache_beam/yaml/yaml_specifiable.py   |  59 +++++++++++
 .../apache_beam/yaml/yaml_specifiable_test.py      | 115 +++++++++++++++++++++
 3 files changed, 176 insertions(+)

diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py
index 0609c4ef8df..d7c427228a3 100755
--- a/sdks/python/apache_beam/yaml/yaml_provider.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -1475,6 +1475,7 @@ def standard_providers():
   from apache_beam.yaml.yaml_mapping import create_mapping_providers
   from apache_beam.yaml.yaml_join import create_join_providers
   from apache_beam.yaml.yaml_io import io_providers
+  from apache_beam.yaml.yaml_specifiable import create_spec_providers
 
   return merge_providers(
       YamlProviders.create_builtin_provider(),
@@ -1483,6 +1484,7 @@ def standard_providers():
       create_combine_providers(),
       create_join_providers(),
       io_providers(),
+      create_spec_providers(),
       load_providers(yaml_utils.locate_data_file('standard_providers.yaml')))
 
 
diff --git a/sdks/python/apache_beam/yaml/yaml_specifiable.py b/sdks/python/apache_beam/yaml/yaml_specifiable.py
new file mode 100644
index 00000000000..207bf608b4d
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_specifiable.py
@@ -0,0 +1,59 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.anomaly.specifiable import Spec
+from apache_beam.ml.anomaly.transforms import AnomalyDetection
+from apache_beam.ml.anomaly.transforms import Specifiable
+from apache_beam.utils import python_callable
+from apache_beam.yaml.yaml_provider import InlineProvider
+
+
+def maybe_make_specifiable(v):
+  if isinstance(v, dict):
+    if "type" in v and "config" in v:
+      return Specifiable.from_spec(
+          Spec(type=v["type"], config=maybe_make_specifiable(v["config"])))
+
+    if "callable" in v:
+      if "path" in v or "name" in v:
+        raise ValueError(
+            "Cannot specify 'callable' with 'path' and 'name' for function.")
+      else:
+        return python_callable.PythonCallableWithSource(v["callable"])
+
+    if "path" in v and "name" in v:
+      return python_callable.PythonCallableWithSource.load_from_script(
+          FileSystems.open(v["path"]).read().decode(), v["name"])
+
+    ret = {k: maybe_make_specifiable(v[k]) for k in v}
+    return ret
+  else:
+    return v
+
+
+class SpecProvider(InlineProvider):
+  def create_transform(self, type, args, yaml_create_transform):
+    return self._transform_factories[type](
+        **{
+            k: maybe_make_specifiable(v)
+            for k, v in args.items()
+        })
+
+
+def create_spec_providers():
+  return SpecProvider({"AnomalyDetection": AnomalyDetection})
diff --git a/sdks/python/apache_beam/yaml/yaml_specifiable_test.py b/sdks/python/apache_beam/yaml/yaml_specifiable_test.py
new file mode 100644
index 00000000000..62b455c4980
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_specifiable_test.py
@@ -0,0 +1,115 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import unittest
+from typing import Callable
+
+import apache_beam as beam
+from apache_beam.ml.anomaly.base import AnomalyDetector
+from apache_beam.ml.anomaly.specifiable import specifiable
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+from apache_beam.utils import python_callable
+from apache_beam.yaml.yaml_transform import YamlTransform
+
+TEST_PROVIDERS = {
+    'PyMap': lambda fn: beam.Map(python_callable.PythonCallableWithSource(fn)),
+}
+
+
+@specifiable
+class FakeDetector(AnomalyDetector):  # pylint: disable=unused-variable
+  def __init__(self, fn: Callable):
+    super().__init__()
+    self._fn = fn
+
+  def learn_one(self, x: beam.Row) -> None:
+    pass
+
+  def score_one(self, x: beam.Row) -> float:
+    v = next(iter(x))
+    return self._fn(v)
+
+
+class YamlSpecifiableTransformTest(unittest.TestCase):
+  def test_specifiable_transform(self):
+    TRAIN_DATA = [
+        (0, beam.Row(x=1)),
+        (0, beam.Row(x=2)),
+        (0, beam.Row(x=2)),
+        (0, beam.Row(x=4)),
+        (0, beam.Row(x=9)),
+    ]
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      result = p | beam.Create(TRAIN_DATA) | YamlTransform(
+          '''
+          type: chain
+          transforms:
+            - type: AnomalyDetection
+              config:
+                detector:
+                  type: 'ZScore'
+                  config:
+                    sub_stat_tracker:
+                      type: 'IncSlidingMeanTracker'
+                      config:
+                        window_size: 5
+                    stdev_tracker:
+                      type: 'IncSlidingStdevTracker'
+                      config:
+                        window_size: 5
+            - type: PyMap
+              config:
+                  fn: "lambda x: (x[1].predictions[0].label)"
+          ''',
+          providers=TEST_PROVIDERS)
+      assert_that(result, equal_to([-2, -2, 0, 1, 1]))
+
+  def test_specifiable_transform_with_callable(self):
+    TRAIN_DATA = [
+        (0, beam.Row(x=1)),
+        (0, beam.Row(x=2)),
+        (0, beam.Row(x=2)),
+        (0, beam.Row(x=4)),
+        (0, beam.Row(x=9)),
+    ]
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      result = p | beam.Create(TRAIN_DATA) | YamlTransform(
+          '''
+          type: chain
+          transforms:
+            - type: AnomalyDetection
+              config:
+                detector:
+                  type: 'FakeDetector'
+                  config:
+                    fn:
+                      callable: "lambda x: x * 10.0"
+            - type: PyMap
+              config:
+                  fn: "lambda x: (x[1].predictions[0].score)"
+          ''',
+          providers=TEST_PROVIDERS)
+      assert_that(result, equal_to([10.0, 20.0, 20.0, 40.0, 90.0]))
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()

Reply via email to