This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 4d8710d43d5 [AnomalyDetection] Support functions and classes as init arguments in specifiable. (#34273)
4d8710d43d5 is described below
commit 4d8710d43d5123d68a4383bb4bdbb7cbf58942f2
Author: Shunping Huang <[email protected]>
AuthorDate: Fri Mar 14 08:32:18 2025 -0400
[AnomalyDetection] Support functions and classes as init arguments in specifiable. (#34273)
* Support functions and classes as init arguments in specifiable.
* Fix lints.
---
sdks/python/apache_beam/ml/anomaly/specifiable.py | 72 ++++++++++---
.../apache_beam/ml/anomaly/specifiable_test.py | 111 ++++++++++++++++++++-
sdks/python/apache_beam/ml/anomaly/transforms.py | 7 ++
3 files changed, 170 insertions(+), 20 deletions(-)
diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable.py b/sdks/python/apache_beam/ml/anomaly/specifiable.py
index e73ef5513b6..2eeb1d0de76 100644
--- a/sdks/python/apache_beam/ml/anomaly/specifiable.py
+++ b/sdks/python/apache_beam/ml/anomaly/specifiable.py
@@ -25,12 +25,16 @@ import collections
import dataclasses
import inspect
import logging
+import os
from typing import Any
from typing import ClassVar
+from typing import Dict
from typing import List
+from typing import Optional
from typing import Protocol
from typing import Type
from typing import TypeVar
+from typing import Union
from typing import runtime_checkable
from typing_extensions import Self
@@ -65,9 +69,11 @@ def _class_to_subspace(cls: Type) -> str:
subspace list. This is usually called when registering a new specifiable
class.
"""
- for c in cls.mro():
- if c.__name__ in _ACCEPTED_SUBSPACES:
- return c.__name__
+ if hasattr(cls, "mro"):
+ # some classes do not have "mro", such as functions.
+ for c in cls.mro():
+ if c.__name__ in _ACCEPTED_SUBSPACES:
+ return c.__name__
return _FALLBACK_SUBSPACE
@@ -92,8 +98,10 @@ class Spec():
"""
#: A string indicating the concrete `Specifiable` class
type: str
- #: A dictionary of keyword arguments for the `__init__` method of the class.
- config: dict[str, Any] = dataclasses.field(default_factory=dict)
+ #: An optional dictionary of keyword arguments for the `__init__` method of
+ #: the class. If None, when we materialize this Spec, we only return the
+ #: class without instantiate any objects from it.
+ config: Optional[Dict[str, Any]] = dataclasses.field(default_factory=dict)
@runtime_checkable
@@ -122,7 +130,7 @@ class Specifiable(Protocol):
return v
@classmethod
- def from_spec(cls, spec: Spec, _run_init: bool = True) -> Self:
+ def from_spec(cls, spec: Spec, _run_init: bool = True) -> Union[Self, type]:
"""Generate a `Specifiable` subclass object based on a spec.
Args:
@@ -137,9 +145,15 @@ class Specifiable(Protocol):
subspace = _spec_type_to_subspace(spec.type)
subclass: Type[Self] = _KNOWN_SPECIFIABLE[subspace].get(spec.type, None)
+
if subclass is None:
raise ValueError(f"Unknown spec type '{spec.type}' in {spec}")
+ if spec.config is None:
+ # when functions or classes are used as arguments, we won't try to
+ # create an instance.
+ return subclass
+
kwargs = {
k: Specifiable._from_spec_helper(v, _run_init)
for k,
@@ -158,6 +172,16 @@ class Specifiable(Protocol):
if isinstance(v, List):
return [Specifiable._to_spec_helper(e) for e in v]
+ if inspect.isfunction(v):
+ if not hasattr(v, "spec_type"):
+ _register(v, inject_spec_type=False)
+ return Spec(type=_get_default_spec_type(v), config=None)
+
+ if inspect.isclass(v):
+ if not hasattr(v, "spec_type"):
+ _register(v, inject_spec_type=False)
+ return Spec(type=_get_default_spec_type(v), config=None)
+
return v
def to_spec(self) -> Spec:
@@ -180,23 +204,40 @@ class Specifiable(Protocol):
pass
+def _get_default_spec_type(cls):
+ spec_type = cls.__name__
+ if inspect.isfunction(cls) and cls.__name__ == "<lambda>":
+ # for lambda functions, we need to include more information to distinguish
+ # among them
+ spec_type = '<lambda at %s:%s>' % (
+ os.path.basename(cls.__code__.co_filename),
cls.__code__.co_firstlineno)
+
+ return spec_type
+
+
# Register a `Specifiable` subclass in `KNOWN_SPECIFIABLE`
-def _register(cls, spec_type=None) -> None:
+def _register(cls, spec_type=None, inject_spec_type=True) -> None:
+ assert spec_type is None or inject_spec_type, \
+ "need to inject spec_type to class if spec_type is not None"
if spec_type is None:
- # By default, spec type is the class name. Users can override this with
- # other unique identifier.
- spec_type = cls.__name__
+ # Use default spec_type for a class if users do not specify one.
+ spec_type = _get_default_spec_type(cls)
subspace = _class_to_subspace(cls)
if spec_type in _KNOWN_SPECIFIABLE[subspace]:
- raise ValueError(
- f"{spec_type} is already registered for "
- f"specifiable class {_KNOWN_SPECIFIABLE[subspace][spec_type]}. "
- "Please specify a different spec_type by @specifiable(spec_type=...).")
+ if cls is not _KNOWN_SPECIFIABLE[subspace][spec_type]:
+ # only raise exception if we register the same spec type with a different
+ # class
+ raise ValueError(
+ f"{spec_type} is already registered for "
+ f"specifiable class {_KNOWN_SPECIFIABLE[subspace][spec_type]}. "
+ "Please specify a different spec_type by
@specifiable(spec_type=...)."
+ )
else:
_KNOWN_SPECIFIABLE[subspace][spec_type] = cls
- cls.spec_type = spec_type
+ if inject_spec_type:
+ cls.spec_type = spec_type
# Keep a copy of arguments that are used to call the `__init__` method when the
@@ -331,6 +372,7 @@ def specifiable(
cls._to_spec_helper = staticmethod(Specifiable._to_spec_helper)
cls.from_spec = Specifiable.from_spec
cls._from_spec_helper = staticmethod(Specifiable._from_spec_helper)
+
return cls
# end of the function body of _wrapper
diff --git a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py
index a3133f32e99..4c1a7bdaf32 100644
--- a/sdks/python/apache_beam/ml/anomaly/specifiable_test.py
+++ b/sdks/python/apache_beam/ml/anomaly/specifiable_test.py
@@ -18,6 +18,7 @@
import copy
import dataclasses
import logging
+import os
import unittest
from typing import List
from typing import Optional
@@ -43,6 +44,9 @@ class TestSpecifiable(unittest.TestCase):
class A():
pass
+ class B():
+ pass
+
# class is not decorated and thus not registered
self.assertNotIn("A", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE])
@@ -53,8 +57,11 @@ class TestSpecifiable(unittest.TestCase):
self.assertIn("A", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE])
self.assertEqual(_KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]["A"], A)
- # an error is raised if the specified spec_type already exists.
- self.assertRaises(ValueError, specifiable, A)
+ # Re-registering spec_type with the same class is allowed
+ A = specifiable(A)
+
+ # Raise an error when re-registering spec_type with a different class
+ self.assertRaises(ValueError, specifiable(spec_type='A'), B)
# apply the decorator function to an existing class with a different
# spec_type
@@ -64,9 +71,6 @@ class TestSpecifiable(unittest.TestCase):
self.assertIn("A_DUP", _KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE])
self.assertEqual(_KNOWN_SPECIFIABLE[_FALLBACK_SUBSPACE]["A_DUP"], A)
- # an error is raised if the specified spec_type already exists.
- self.assertRaises(ValueError, specifiable(spec_type="A_DUP"), A)
-
def test_decorator_in_syntactic_sugar_form(self):
# call decorator without parameters
@specifiable
@@ -484,6 +488,103 @@ class TestNestedSpecifiable(unittest.TestCase):
self.assertEqual(Child_2.counter, 0)
+def my_normal_func(x, y):
+ return x + y
+
+
+@specifiable
+class Wrapper():
+ def __init__(self, func=None, cls=None, **kwargs):
+ self._func = func
+ if cls is not None:
+ self._cls = cls(**kwargs)
+
+ def run_func(self, x, y):
+ return self._func(x, y)
+
+ def run_func_in_class(self, x, y):
+ return self._cls.apply(x, y)
+
+
+class TestFunctionAsArgument(unittest.TestCase):
+ def setUp(self) -> None:
+ self.saved_specifiable = copy.deepcopy(_KNOWN_SPECIFIABLE)
+
+ def tearDown(self) -> None:
+ _KNOWN_SPECIFIABLE.clear()
+ _KNOWN_SPECIFIABLE.update(self.saved_specifiable)
+
+ def test_normal_function(self):
+ w = Wrapper(my_normal_func)
+
+ self.assertEqual(w.run_func(1, 2), 3)
+
+ w_spec = w.to_spec()
+ self.assertEqual(
+ w_spec,
+ Spec(
+ type='Wrapper',
+ config={'func': Spec(type="my_normal_func", config=None)}))
+
+ w_2 = Specifiable.from_spec(w_spec)
+ self.assertEqual(w_2.run_func(2, 3), 5)
+
+ def test_lambda_function(self):
+ my_lambda_func = lambda x, y: x - y
+
+ w = Wrapper(my_lambda_func)
+
+ self.assertEqual(w.run_func(3, 2), 1)
+
+ w_spec = w.to_spec()
+ self.assertEqual(
+ w_spec,
+ Spec(
+ type='Wrapper',
+ config={
+ 'func': Spec(
+ type=
+ f"<lambda at
{os.path.basename(__file__)}:{my_lambda_func.__code__.co_firstlineno}>", #
pylint: disable=line-too-long
+ config=None)
+ }
+ ))
+
+ w_2 = Specifiable.from_spec(w_spec)
+ self.assertEqual(w_2.run_func(5, 3), 2)
+
+
+class TestClassAsArgument(unittest.TestCase):
+ def setUp(self) -> None:
+ self.saved_specifiable = copy.deepcopy(_KNOWN_SPECIFIABLE)
+
+ def tearDown(self) -> None:
+ _KNOWN_SPECIFIABLE.clear()
+ _KNOWN_SPECIFIABLE.update(self.saved_specifiable)
+
+ def test_normal_class(self):
+ class InnerClass():
+ def __init__(self, multiplier):
+ self._multiplier = multiplier
+
+ def apply(self, x, y):
+ return x * y * self._multiplier
+
+ w = Wrapper(cls=InnerClass, multiplier=10)
+ self.assertEqual(w.run_func_in_class(2, 3), 60)
+
+ w_spec = w.to_spec()
+ self.assertEqual(
+ w_spec,
+ Spec(
+ type='Wrapper',
+ config={
+ 'cls': Spec(type='InnerClass', config=None), 'multiplier': 10
+ }))
+
+ w_2 = Specifiable.from_spec(w_spec)
+ self.assertEqual(w_2.run_func_in_class(5, 3), 150)
+
+
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
diff --git a/sdks/python/apache_beam/ml/anomaly/transforms.py b/sdks/python/apache_beam/ml/anomaly/transforms.py
index 35cae18a722..08b656072ac 100644
--- a/sdks/python/apache_beam/ml/anomaly/transforms.py
+++ b/sdks/python/apache_beam/ml/anomaly/transforms.py
@@ -18,6 +18,7 @@
import dataclasses
import uuid
from typing import Callable
+from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Tuple
@@ -55,6 +56,8 @@ class _ScoreAndLearnDoFn(beam.DoFn):
def __init__(self, detector_spec: Spec):
self._detector_spec = detector_spec
+
+ assert isinstance(self._detector_spec.config, Dict)
self._detector_spec.config["_run_init"] = True
def score_and_learn(self, data):
@@ -172,8 +175,10 @@ class _StatelessThresholdDoFn(_BaseThresholdDoFn):
creation of a stateful `ThresholdFn`.
"""
def __init__(self, threshold_fn_spec: Spec):
+ assert isinstance(threshold_fn_spec.config, Dict)
threshold_fn_spec.config["_run_init"] = True
self._threshold_fn = Specifiable.from_spec(threshold_fn_spec)
+ assert isinstance(self._threshold_fn, ThresholdFn)
assert not self._threshold_fn.is_stateful, \
"This DoFn can only take stateless function as threshold_fn"
@@ -217,8 +222,10 @@ class _StatefulThresholdDoFn(_BaseThresholdDoFn):
THRESHOLD_STATE_INDEX = ReadModifyWriteStateSpec('saved_tracker',
DillCoder())
def __init__(self, threshold_fn_spec: Spec):
+ assert isinstance(threshold_fn_spec.config, Dict)
threshold_fn_spec.config["_run_init"] = True
threshold_fn = Specifiable.from_spec(threshold_fn_spec)
+ assert isinstance(threshold_fn, ThresholdFn)
assert threshold_fn.is_stateful, \
"This DoFn can only take stateful function as threshold_fn"
self._threshold_fn_spec = threshold_fn_spec