This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 4d8710d43d5 [AnomalyDetection] Support functions and classes as init arguments in specifiable. (#34273)
4d8710d43d5 is described below

commit 4d8710d43d5123d68a4383bb4bdbb7cbf58942f2
Author: Shunping Huang <[email protected]>
AuthorDate: Fri Mar 14 08:32:18 2025 -0400

    [AnomalyDetection] Support functions and classes as init arguments in specifiable. (#34273)
    
    * Support functions and classes as init arguments in specifiable.
    
    * Fix lints.
---
 sdks/python/apache_beam/ml/anomaly/specifiable.py  |  72 ++++++++++---
 .../apache_beam/ml/anomaly/specifiable_test.py     | 111 ++++++++++++++++++++-
 sdks/python/apache_beam/ml/anomaly/transforms.py   |   7 ++
 3 files changed, 170 insertions(+), 20 deletions(-)

diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable.py b/sdks/python/apache_beam/ml/anomaly/specifiable.py
index e73ef5513b6..2eeb1d0de76 100644
--- a/sdks/python/apache_beam/ml/anomaly/specifiable.py
+++ b/sdks/python/apache_beam/ml/anomaly/specifiable.py
@@ -25,12 +25,16 @@ import collections
 import dataclasses
 import inspect
 import logging
+import os
 from typing import Any
 from typing import ClassVar
+from typing import Dict
 from typing import List
+from typing import Optional
 from typing import Protocol
 from typing import Type
 from typing import TypeVar
+from typing import Union
 from typing import runtime_checkable
 
 from typing_extensions import Self
@@ -65,9 +69,11 @@ def _class_to_subspace(cls: Type) -> str:
   subspace list. This is usually called when registering a new specifiable
   class.
   """
-  for c in cls.mro():
-    if c.__name__ in _ACCEPTED_SUBSPACES:
-      return c.__name__
+  if hasattr(cls, "mro"):
+    # some classes do not have "mro", such as functions.
+    for c in cls.mro():
+      if c.__name__ in _ACCEPTED_SUBSPACES:
+        return c.__name__
 
   return _FALLBACK_SUBSPACE
 
@@ -92,8 +98,10 @@ class Spec():
   """
   #: A string indicating the concrete `Specifiable` class
   type: str
-  #: A dictionary of keyword arguments for the `__init__` method of the class.
-  config: dict[str, Any] = dataclasses.field(default_factory=dict)
+  #: An optional dictionary of keyword arguments for the `__init__` method of
+  #: the class. If None, when we materialize this Spec, we only return the
+  #: class without instantiating any objects from it.
+  config: Optional[Dict[str, Any]] = dataclasses.field(default_factory=dict)
 
 
 @runtime_checkable
@@ -122,7 +130,7 @@ class Specifiable(Protocol):
     return v
 
   @classmethod
-  def from_spec(cls, spec: Spec, _run_init: bool = True) -> Self:
+  def from_spec(cls, spec: Spec, _run_init: bool = True) -> Union[Self, type]:
     """Generate a `Specifiable` subclass object based on a spec.
 
     Args:
@@ -137,9 +145,15 @@ class Specifiable(Protocol):
 
     subspace = _spec_type_to_subspace(spec.type)
     subclass: Type[Self] = _KNOWN_SPECIFIABLE[subspace].get(spec.type, None)
+
     if subclass is None:
       raise ValueError(f"Unknown spec type '{spec.type}' in {spec}")
 
+    if spec.config is None:
+      # when functions or classes are used as arguments, we won't try to
+      # create an instance.
+      return subclass
+
     kwargs = {
         k: Specifiable._from_spec_helper(v, _run_init)
         for k,
@@ -158,6 +172,16 @@ class Specifiable(Protocol):
     if isinstance(v, List):
       return [Specifiable._to_spec_helper(e) for e in v]
 
+    if inspect.isfunction(v):
+      if not hasattr(v, "spec_type"):
+        _register(v, inject_spec_type=False)
+      return Spec(type=_get_default_spec_type(v), config=None)
+
+    if inspect.isclass(v):
+      if not hasattr(v, "spec_type"):
+        _register(v, inject_spec_type=False)
+      return Spec(type=_get_default_spec_type(v), config=None)
+
     return v
 
   def to_spec(self) -> Spec:
@@ -180,23 +204,40 @@ class Specifiable(Protocol):
     pass
 
 
+def _get_default_spec_type(cls):
+  spec_type = cls.__name__
+  if inspect.isfunction(cls) and cls.__name__ == "<lambda>":
+    # for lambda functions, we need to include more information to distinguish
+    # among them
+    spec_type = '<lambda at %s:%s>' % (
+        os.path.basename(cls.__code__.co_filename), cls.__code__.co_firstlineno)
+
+  return spec_type
+
+
 # Register a `Specifiable` subclass in `KNOWN_SPECIFIABLE`
-def _register(cls, spec_type=None) -> None:
+def _register(cls, spec_type=None, inject_spec_type=True) -> None:
+  assert spec_type is None or inject_spec_type, \
+      "need to inject spec_type to class if spec_type is not None"
   if spec_type is None:
-    # By default, spec type is the class name. Users can override this with
-    # other unique identifier.
-    spec_type = cls.__name__
+    # Use default spec_type for a class if users do not specify one.
+    spec_type = _get_default_spec_type(cls)
 
   subspace = _class_to_subspace(cls)
   if spec_type in _KNOWN_SPECIFIABLE[subspace]:
-    raise ValueError(
-        f"{spec_type} is already registered for "
-        f"specifiable class {_KNOWN_SPECIFIABLE[subspace][spec_type]}. "
-        "Please specify a different spec_type by @specifiable(spec_type=...).")
+    if cls is not _KNOWN_SPECIFIABLE[subspace][spec_type]:
+      # only raise exception if we register the same spec type with a different
+      # class
+      raise ValueError(
+          f"{spec_type} is already registered for "
+          f"specifiable class {_KNOWN_SPECIFIABLE[subspace][spec_type]}. "
+          "Please specify a different spec_type by @specifiable(spec_type=...)."
+      )
   else:
     _KNOWN_SPECIFIABLE[subspace][spec_type] = cls
 
-  cls.spec_type = spec_type
+  if inject_spec_type:
+    cls.spec_type = spec_type
 
 
 # Keep a copy of arguments that are used to call the `__init__` method when the
@@ -331,6 +372,7 @@ def specifiable(
     cls._to_spec_helper = staticmethod(Specifiable._to_spec_helper)
     cls.from_spec = Specifiable.from_spec
     cls._from_spec_helper = staticmethod(Specifiable._from_spec_helper)
+
     return cls
     # end of the function body of _wrapper
 
diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py
index a3133f32e99..4c1a7bdaf32 100644
--- a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py
+++ b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py
@@ -18,6 +18,7 @@
 import copy
 import dataclasses
 import logging
+import os
 import unittest
 from typing import List
 from typing import Optional
@@ -43,6 +44,9 @@ class TestSpecifiable(unittest.TestCase):
     class A():
       pass
 
+    class B():
+      pass
+
     # class is not decorated and thus not registered
     self.assertNotIn("A", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE])
 
@@ -53,8 +57,11 @@ class TestSpecifiable(unittest.TestCase):
     self.assertIn("A", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE])
     self.assertEqual(_KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]["A"], A)
 
-    # an error is raised if the specified spec_type already exists.
-    self.assertRaises(ValueError, specifiable, A)
+    # Re-registering spec_type with the same class is allowed
+    A = specifiable(A)
+
+    # Raise an error when re-registering spec_type with a different class
+    self.assertRaises(ValueError, specifiable(spec_type='A'), B)
 
     # apply the decorator function to an existing class with a different
     # spec_type
@@ -64,9 +71,6 @@ class TestSpecifiable(unittest.TestCase):
     self.assertIn("A_DUP", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE])
     self.assertEqual(_KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]["A_DUP"], A)
 
-    # an error is raised if the specified spec_type already exists.
-    self.assertRaises(ValueError, specifiable(spec_type="A_DUP"), A)
-
   def test_decorator_in_syntactic_sugar_form(self):
     # call decorator without parameters
     @specifiable
@@ -484,6 +488,103 @@ class TestNestedSpecifiable(unittest.TestCase):
     self.assertEqual(Child_2.counter, 0)
 
 
+def my_normal_func(x, y):
+  return x + y
+
+
+@specifiable
+class Wrapper():
+  def __init__(self, func=None, cls=None, **kwargs):
+    self._func = func
+    if cls is not None:
+      self._cls = cls(**kwargs)
+
+  def run_func(self, x, y):
+    return self._func(x, y)
+
+  def run_func_in_class(self, x, y):
+    return self._cls.apply(x, y)
+
+
+class TestFunctionAsArgument(unittest.TestCase):
+  def setUp(self) -> None:
+    self.saved_specifiable = copy.deepcopy(_KNOWN_SPECIFIABLE)
+
+  def tearDown(self) -> None:
+    _KNOWN_SPECIFIABLE.clear()
+    _KNOWN_SPECIFIABLE.update(self.saved_specifiable)
+
+  def test_normal_function(self):
+    w = Wrapper(my_normal_func)
+
+    self.assertEqual(w.run_func(1, 2), 3)
+
+    w_spec = w.to_spec()
+    self.assertEqual(
+        w_spec,
+        Spec(
+            type='Wrapper',
+            config={'func': Spec(type="my_normal_func", config=None)}))
+
+    w_2 = Specifiable.from_spec(w_spec)
+    self.assertEqual(w_2.run_func(2, 3), 5)
+
+  def test_lambda_function(self):
+    my_lambda_func = lambda x, y: x - y
+
+    w = Wrapper(my_lambda_func)
+
+    self.assertEqual(w.run_func(3, 2), 1)
+
+    w_spec = w.to_spec()
+    self.assertEqual(
+        w_spec,
+        Spec(
+            type='Wrapper',
+            config={
+                'func': Spec(
+                    type=
+                    f"<lambda at {os.path.basename(__file__)}:{my_lambda_func.__code__.co_firstlineno}>",  # pylint: disable=line-too-long
+                    config=None)
+            }
+        ))
+
+    w_2 = Specifiable.from_spec(w_spec)
+    self.assertEqual(w_2.run_func(5, 3), 2)
+
+
+class TestClassAsArgument(unittest.TestCase):
+  def setUp(self) -> None:
+    self.saved_specifiable = copy.deepcopy(_KNOWN_SPECIFIABLE)
+
+  def tearDown(self) -> None:
+    _KNOWN_SPECIFIABLE.clear()
+    _KNOWN_SPECIFIABLE.update(self.saved_specifiable)
+
+  def test_normal_class(self):
+    class InnerClass():
+      def __init__(self, multiplier):
+        self._multiplier = multiplier
+
+      def apply(self, x, y):
+        return x * y * self._multiplier
+
+    w = Wrapper(cls=InnerClass, multiplier=10)
+    self.assertEqual(w.run_func_in_class(2, 3), 60)
+
+    w_spec = w.to_spec()
+    self.assertEqual(
+        w_spec,
+        Spec(
+            type='Wrapper',
+            config={
+                'cls': Spec(type='InnerClass', config=None), 'multiplier': 10
+            }))
+
+    w_2 = Specifiable.from_spec(w_spec)
+    self.assertEqual(w_2.run_func_in_class(5, 3), 150)
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()
diff --git a/sdks/python/apache_beam/ml/anomaly/transforms.py b/sdks/python/apache_beam/ml/anomaly/transforms.py
index 35cae18a722..08b656072ac 100644
--- a/sdks/python/apache_beam/ml/anomaly/transforms.py
+++ b/sdks/python/apache_beam/ml/anomaly/transforms.py
@@ -18,6 +18,7 @@
 import dataclasses
 import uuid
 from typing import Callable
+from typing import Dict
 from typing import Iterable
 from typing import Optional
 from typing import Tuple
@@ -55,6 +56,8 @@ class _ScoreAndLearnDoFn(beam.DoFn):
 
   def __init__(self, detector_spec: Spec):
     self._detector_spec = detector_spec
+
+    assert isinstance(self._detector_spec.config, Dict)
     self._detector_spec.config["_run_init"] = True
 
   def score_and_learn(self, data):
@@ -172,8 +175,10 @@ class _StatelessThresholdDoFn(_BaseThresholdDoFn):
       creation of a stateful `ThresholdFn`.
   """
   def __init__(self, threshold_fn_spec: Spec):
+    assert isinstance(threshold_fn_spec.config, Dict)
     threshold_fn_spec.config["_run_init"] = True
     self._threshold_fn = Specifiable.from_spec(threshold_fn_spec)
+    assert isinstance(self._threshold_fn, ThresholdFn)
     assert not self._threshold_fn.is_stateful, \
       "This DoFn can only take stateless function as threshold_fn"
 
@@ -217,8 +222,10 @@ class _StatefulThresholdDoFn(_BaseThresholdDoFn):
   THRESHOLD_STATE_INDEX = ReadModifyWriteStateSpec('saved_tracker', DillCoder())
 
   def __init__(self, threshold_fn_spec: Spec):
+    assert isinstance(threshold_fn_spec.config, Dict)
     threshold_fn_spec.config["_run_init"] = True
     threshold_fn = Specifiable.from_spec(threshold_fn_spec)
+    assert isinstance(threshold_fn, ThresholdFn)
     assert threshold_fn.is_stateful, \
       "This DoFn can only take stateful function as threshold_fn"
     self._threshold_fn_spec = threshold_fn_spec

Reply via email to