This is an automated email from the ASF dual-hosted git repository.
riteshghorse pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 1de8454ddd2 [Python] Enrichment Transform with BigTable handler
(#30001)
1de8454ddd2 is described below
commit 1de8454ddd215b9659fd049bded6a4b4c484b12d
Author: Ritesh Ghorse <[email protected]>
AuthorDate: Thu Jan 18 21:22:11 2024 -0500
[Python] Enrichment Transform with BigTable handler (#30001)
* enrichment v1
* add documentation
* add doc comment
* rerun
* update docs, lint
* update docs, lint
* add generic type
* add generic type
* adjust doc path
* create test row
* use request type
* use request type
* change module name
* more tests
* remove non-functional params
* lint, doc
* change types for general use
* callable type
* dict type
* update signatures
* fix unit test
* bigtable with column family, ids, rrio-throttler
* update tests for row filter
* convert handler types from dict to Row
* update tests for bigtable
* ran pydocs
* ran pydocs
* mark postcommit
* remove _test file, fix import
* enable postcommit
* add more tests
* skip tests when dependencies are not installed
* add deleted imports from last commit
* add skip test condition
* fix import order, add TooManyRequests to try-catch
* make throttler, repeater non-optional
* add exception level and tests
* correct pydoc statement
* add throttle tests
* add bigtable improvements
* default app_profile_id
* add documentation, ignore None assignment
* add to changes.md
* change test structure that throws exception, skip http test for now
* drop postcommit trigger file
---
CHANGES.md | 2 +-
sdks/python/apache_beam/io/requestresponse.py | 413 +++++++++++++++++++++
...nseio_it_test.py => requestresponse_it_test.py} | 37 +-
sdks/python/apache_beam/io/requestresponse_test.py | 156 ++++++++
sdks/python/apache_beam/io/requestresponseio.py | 218 -----------
.../apache_beam/io/requestresponseio_test.py | 88 -----
sdks/python/apache_beam/transforms/enrichment.py | 137 +++++++
.../transforms/enrichment_handlers/__init__.py | 16 +
.../transforms/enrichment_handlers/bigtable.py | 151 ++++++++
.../enrichment_handlers/bigtable_it_test.py | 300 +++++++++++++++
.../apache_beam/transforms/enrichment_it_test.py | 162 ++++++++
.../apache_beam/transforms/enrichment_test.py | 41 ++
12 files changed, 1398 insertions(+), 323 deletions(-)
diff --git a/CHANGES.md b/CHANGES.md
index 81a519b07d7..dbad15f3dba 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -69,7 +69,7 @@
## New Features / Improvements
-* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
+* [Enrichment Transform](https://s.apache.org/enrichment-transform) along with GCP BigTable handler added to Python SDK ([#30001](https://github.com/apache/beam/pull/30001)).
## Breaking Changes
diff --git a/sdks/python/apache_beam/io/requestresponse.py b/sdks/python/apache_beam/io/requestresponse.py
new file mode 100644
index 00000000000..63ec7061d3e
--- /dev/null
+++ b/sdks/python/apache_beam/io/requestresponse.py
@@ -0,0 +1,413 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""``PTransform`` for reading from and writing to Web APIs."""
+import abc
+import concurrent.futures
+import contextlib
+import logging
+import sys
+import time
+from typing import Generic
+from typing import Optional
+from typing import TypeVar
+
+from google.api_core.exceptions import TooManyRequests
+
+import apache_beam as beam
+from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
+from apache_beam.metrics import Metrics
+from apache_beam.ml.inference.vertex_ai_inference import MSEC_TO_SEC
+from apache_beam.utils import retry
+
+RequestT = TypeVar('RequestT')
+ResponseT = TypeVar('ResponseT')
+
+DEFAULT_TIMEOUT_SECS = 30 # seconds
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class UserCodeExecutionException(Exception):
+  """Root of the exception hierarchy for failures while calling Web APIs."""
+
+
+class UserCodeQuotaException(UserCodeExecutionException):
+  """A ``UserCodeExecutionException`` raised when the Web API client runs
+  into a quota or API-overuse error.
+  """
+
+
+class UserCodeTimeoutException(UserCodeExecutionException):
+  """A ``UserCodeExecutionException`` raised when user code times out."""
+
+
+def retry_on_exception(exception: Exception):
+  """Retry filter: tells whether ``exception`` should trigger a retry.
+
+  Returns ``True`` only for instances of ``TooManyRequests``,
+  ``UserCodeTimeoutException`` or ``UserCodeQuotaException``.
+  """
+  retryable_types = (
+      TooManyRequests, UserCodeTimeoutException, UserCodeQuotaException)
+  return isinstance(exception, retryable_types)
+
+
+class _MetricsCollector:
+  """Bundle of Beam counters that track RequestResponseIO activity."""
+  def __init__(self, namespace: str):
+    """
+    Args:
+      namespace: namespace under which every counter is registered.
+    """
+    def counter(name: str):
+      # Every counter shares the namespace; only the counter name varies.
+      return Metrics.counter(namespace, name)
+
+    self.requests = counter('requests')
+    self.responses = counter('responses')
+    self.failures = counter('failures')
+    self.throttled_requests = counter('throttled_requests')
+    self.throttled_secs = counter('cumulativeThrottlingSeconds')
+    self.timeout_requests = counter('requests_timed_out')
+    self.call_counter = counter('call_invocations')
+    self.setup_counter = counter('setup_counter')
+    self.teardown_counter = counter('teardown_counter')
+    self.backoff_counter = counter('backoff_counter')
+    self.sleeper_counter = counter('sleeper_counter')
+    self.should_backoff_counter = counter('should_backoff_counter')
+
+
+class Caller(contextlib.AbstractContextManager,
+             abc.ABC,
+             Generic[RequestT, ResponseT]):
+  """Contract for user-supplied code that performs an API call.
+
+  Override ``__enter__`` and ``__exit__`` to set up and tear down clients
+  when that is needed; the defaults here return ``self`` and ``None``."""
+  @abc.abstractmethod
+  def __call__(self, request: RequestT, *args, **kwargs) -> ResponseT:
+    """Sends ``request`` to a Web API and returns the ``ResponseT`` it
+    produced.
+
+    ``RequestResponseIO`` expects implementations to signal failure by
+    raising a ``UserCodeExecutionException``, ``UserCodeQuotaException``
+    or ``UserCodeTimeoutException``.
+    """
+
+  def __enter__(self):
+    return self
+
+  def __exit__(self, exc_type, exc_val, exc_tb):
+    return None
+
+
+class ShouldBackOff(abc.ABC):
+  """Marker interface for applying adaptive throttling.
+
+  No methods are declared on it yet.
+  """
+
+
+class Repeater(abc.ABC):
+  """Interface for strategies that repeat a request under a configurable
+  condition."""
+  @abc.abstractmethod
+  def repeat(
+      self,
+      caller: Caller[RequestT, ResponseT],
+      request: RequestT,
+      timeout: float,
+      metrics_collector: Optional[_MetricsCollector]) -> ResponseT:
+    """Invoked by RequestResponseIO for a request when a repeater is enabled.
+
+    Args:
+      caller: :class:`apache_beam.io.requestresponse.Caller` object that
+        performs the API call.
+      request: the input request to issue and, if needed, repeat.
+      timeout: how long to wait for the request to complete.
+      metrics_collector: (Optional) a
+        ``:class:`apache_beam.io.requestresponse._MetricsCollector``` object
+        that gathers the RequestResponseIO metrics.
+    """
+
+
+def _execute_request(
+    caller: Caller[RequestT, ResponseT],
+    request: RequestT,
+    timeout: float,
+    metrics_collector: Optional[_MetricsCollector] = None) -> ResponseT:
+  """Runs ``caller(request)`` on a worker thread and waits for its result.
+
+  Args:
+    caller: callable that performs the API call.
+    request: the request handed to ``caller``.
+    timeout: seconds to wait for the result of the call.
+    metrics_collector: (Optional) collector whose ``timeout_requests`` and
+      ``failures`` counters are incremented on the matching errors.
+
+  Returns:
+    Whatever ``caller(request)`` returns.
+
+  Raises:
+    TooManyRequests: re-raised unchanged after being logged.
+    UserCodeTimeoutException: if no result arrives within ``timeout``.
+    UserCodeExecutionException: if the call raises a ``RuntimeError``.
+  """
+  # NOTE(review): leaving the ``with`` block shuts the executor down and
+  # waits for the worker thread, so a timed-out call still blocks here until
+  # ``caller`` returns -- confirm this is acceptable for long-running callers.
+  with concurrent.futures.ThreadPoolExecutor() as executor:
+    future = executor.submit(caller, request)
+    try:
+      return future.result(timeout=timeout)
+    except TooManyRequests as e:
+      _LOGGER.info(
+          'request could not be completed. got code %i from the service.',
+          e.code)
+      # A bare ``raise`` keeps the original traceback untouched.
+      raise
+    except concurrent.futures.TimeoutError as e:
+      if metrics_collector:
+        metrics_collector.timeout_requests.inc(1)
+      raise UserCodeTimeoutException(
+          f'Timeout {timeout} exceeded '
+          f'while completing request: {request}') from e
+    except RuntimeError as e:
+      if metrics_collector:
+        metrics_collector.failures.inc(1)
+      # Chain the cause so the user's original error stays visible.
+      raise UserCodeExecutionException('could not complete request') from e
+
+
+class ExponentialBackOffRepeater(Repeater):
+  """Repeater that retries with exponential backoff when the remote service
+  is the cause of the failure: TooManyRequests (HTTP 429),
+  UserCodeTimeoutException or UserCodeQuotaException.
+
+  The retry loop comes from the decorator
+  :func:`apache_beam.utils.retry.with_exponential_backoff`.
+  """
+  def __init__(self):
+    """Takes no configuration."""
+
+  @retry.with_exponential_backoff(
+      num_retries=2, retry_filter=retry_on_exception)
+  def repeat(
+      self,
+      caller: Caller[RequestT, ResponseT],
+      request: RequestT,
+      timeout: float,
+      metrics_collector: Optional[_MetricsCollector] = None) -> ResponseT:
+    """Invoked by RequestResponseIO for a request when a repeater is enabled.
+
+    Args:
+      caller: :class:`apache_beam.io.requestresponse.Caller` object that
+        performs the API call.
+      request: the input request to issue and, if needed, repeat.
+      timeout: how long to wait for the request to complete.
+      metrics_collector: (Optional) a
+        ``:class:`apache_beam.io.requestresponse._MetricsCollector``` object
+        that gathers the RequestResponseIO metrics.
+    """
+    return _execute_request(
+        caller=caller,
+        request=request,
+        timeout=timeout,
+        metrics_collector=metrics_collector)
+
+
+class NoOpsRepeater(Repeater):
+  """
+  NoOpsRepeater executes a request just once irrespective of any exception.
+  """
+  def repeat(
+      self,
+      caller: Caller[RequestT, ResponseT],
+      request: RequestT,
+      timeout: float,
+      metrics_collector: Optional[_MetricsCollector] = None) -> ResponseT:
+    """Sends ``request`` through ``caller`` exactly once, without retrying.
+
+    Args:
+      caller: :class:`apache_beam.io.requestresponse.Caller` object that
+        performs the API call.
+      request: the input request to send.
+      timeout: how long to wait for the request to complete.
+      metrics_collector: (Optional) collector for RequestResponseIO metrics.
+        Defaults to ``None`` to match ``ExponentialBackOffRepeater.repeat``.
+    """
+    return _execute_request(caller, request, timeout, metrics_collector)
+
+
+class CacheReader(abc.ABC):
+  """Marker interface for reading from the cache.
+
+  No methods are declared on it yet.
+  """
+
+
+class CacheWriter(abc.ABC):
+  """Marker interface for writing to the cache.
+
+  No methods are declared on it yet.
+  """
+
+
+class PreCallThrottler(abc.ABC):
+  """Marker interface for throttling a request before it is sent.
+
+  No methods are declared on it yet.
+  """
+
+
+class DefaultThrottler(PreCallThrottler):
+  """Throttler backed by
+  :class:`apache_beam.io.components.adaptive_throttler.AdaptiveThrottler`
+
+  Args:
+    window_ms (int): length of history to consider, in ms, to set throttling.
+    bucket_ms (int): granularity of time buckets that we store data in, in ms.
+    overload_ratio (float): the target ratio between requests sent and
+      successful requests. This is "K" in the formula in
+      https://landing.google.com/sre/book/chapters/handling-overload.html.
+    delay_secs (int): minimum number of seconds to throttle a request.
+  """
+  def __init__(
+      self,
+      window_ms: int = 1,
+      bucket_ms: int = 1,
+      overload_ratio: float = 2,
+      delay_secs: int = 5):
+    # Stored as-is; read back by callers through ``delay_secs``.
+    self.delay_secs = delay_secs
+    self.throttler = AdaptiveThrottler(
+        window_ms=window_ms,
+        bucket_ms=bucket_ms,
+        overload_ratio=overload_ratio)
+
+
+class RequestResponseIO(beam.PTransform[beam.PCollection[RequestT],
+                                        beam.PCollection[ResponseT]]):
+  """A :class:`RequestResponseIO` transform to read and write to APIs.
+
+  Takes an input :class:`~apache_beam.pvalue.PCollection` of requests, calls
+  the API through :class:`Caller`'s `__call__` for each of them, and emits a
+  :class:`~apache_beam.pvalue.PCollection` of the responses.
+  """
+  def __init__(
+      self,
+      caller: Caller[RequestT, ResponseT],
+      timeout: Optional[float] = DEFAULT_TIMEOUT_SECS,
+      should_backoff: Optional[ShouldBackOff] = None,
+      repeater: Repeater = ExponentialBackOffRepeater(),
+      cache_reader: Optional[CacheReader] = None,
+      cache_writer: Optional[CacheWriter] = None,
+      throttler: PreCallThrottler = DefaultThrottler(),
+  ):
+    """
+    Instantiates a RequestResponseIO transform.
+
+    Args:
+      caller (~apache_beam.io.requestresponse.Caller): an implementation of
+        `Caller` object that makes call to the API.
+      timeout (float): timeout value in seconds to wait for response from API.
+      should_backoff (~apache_beam.io.requestresponse.ShouldBackOff):
+        (Optional) provides methods for backoff.
+      repeater (~apache_beam.io.requestresponse.Repeater): provides method to
+        repeat failed requests to API due to service errors. Defaults to
+        :class:`apache_beam.io.requestresponse.ExponentialBackOffRepeater` to
+        repeat requests with exponential backoff. A falsy value such as
+        ``None`` selects
+        :class:`apache_beam.io.requestresponse.NoOpsRepeater` instead.
+      cache_reader (~apache_beam.io.requestresponse.CacheReader): (Optional)
+        provides methods to read external cache.
+      cache_writer (~apache_beam.io.requestresponse.CacheWriter): (Optional)
+        provides methods to write to external cache.
+      throttler (~apache_beam.io.requestresponse.PreCallThrottler):
+        provides methods to pre-throttle a request. Defaults to
+        :class:`apache_beam.io.requestresponse.DefaultThrottler` for
+        client-side adaptive throttling using
+        :class:`apache_beam.io.components.adaptive_throttler.AdaptiveThrottler`
+    """
+    self._caller = caller
+    self._timeout = timeout
+    self._should_backoff = should_backoff
+    # No repeater supplied means "send once": substitute the no-op strategy.
+    self._repeater = repeater or NoOpsRepeater()
+    self._cache_reader = cache_reader
+    self._cache_writer = cache_writer
+    self._throttler = throttler
+
+  def expand(
+      self,
+      requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]:
+    # TODO(riteshghorse): handle Cache and Throttle PTransforms when available.
+    # Only a DefaultThrottler is forwarded; any other throttler is left out,
+    # which is the same as _Call's own ``throttler=None`` default.
+    forwarded_throttler = (
+        self._throttler
+        if isinstance(self._throttler, DefaultThrottler) else None)
+    call_transform = _Call(
+        caller=self._caller,
+        timeout=self._timeout,
+        should_backoff=self._should_backoff,
+        repeater=self._repeater,
+        throttler=forwarded_throttler)
+    return requests | call_transform
+
+
+class _Call(beam.PTransform[beam.PCollection[RequestT],
+                            beam.PCollection[ResponseT]]):
+  """(Internal-only) PTransform that invokes a remote function on each element
+  of the input PCollection.
+
+  The actual API calls are made through a `Caller` object, whose
+  ``__enter__`` and ``__exit__`` handle client setup and teardown when
+  applicable. A timeout bounds the duration of each call and defaults to
+  30 seconds.
+
+  Args:
+    caller (:class:`apache_beam.io.requestresponse.Caller`): a callable
+      object that invokes API call.
+    timeout (float): timeout value in seconds to wait for response from API.
+    should_backoff (~apache_beam.io.requestresponse.ShouldBackOff):
+      (Optional) provides methods for backoff.
+    repeater (~apache_beam.io.requestresponse.Repeater): (Optional) provides
+      methods to repeat requests to API.
+    throttler (~apache_beam.io.requestresponse.PreCallThrottler):
+      (Optional) provides methods to pre-throttle a request.
+  """
+  def __init__(
+      self,
+      caller: Caller[RequestT, ResponseT],
+      timeout: Optional[float] = DEFAULT_TIMEOUT_SECS,
+      should_backoff: Optional[ShouldBackOff] = None,
+      repeater: Optional[Repeater] = None,
+      throttler: Optional[PreCallThrottler] = None,
+  ):
+    self._caller = caller
+    self._timeout = timeout
+    # Stored only; ``expand`` does not forward it to the DoFn.
+    self._should_backoff = should_backoff
+    self._repeater = repeater
+    self._throttler = throttler
+
+  def expand(
+      self,
+      requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]:
+    call_fn = _CallDoFn(
+        self._caller, self._timeout, self._repeater, self._throttler)
+    return requests | beam.ParDo(call_fn)
+
+
+class _CallDoFn(beam.DoFn):
+  """(Internal-only) DoFn that throttles, issues and counts each request.
+
+  Args:
+    caller: callable that performs the API call; entered in ``setup`` and
+      exited in ``teardown``.
+    timeout: seconds to wait for each call, forwarded to the repeater.
+    repeater: strategy used to issue the request. ``None`` falls back to
+      ``NoOpsRepeater`` so that the request is still sent once.
+    throttler: (Optional) a ``DefaultThrottler``; ``None`` disables
+      client-side throttling.
+  """
+  def __init__(
+      self,
+      caller: Caller[RequestT, ResponseT],
+      timeout: float,
+      repeater: Repeater,
+      throttler: PreCallThrottler):
+    self._metrics_collector = None
+    self._caller = caller
+    self._timeout = timeout
+    # _Call defaults ``repeater`` to None; without this fallback ``process``
+    # would fail with an AttributeError on the first element.
+    self._repeater = repeater if repeater else NoOpsRepeater()
+    self._throttler = throttler
+
+  def setup(self):
+    self._caller.__enter__()
+    # Counters are namespaced by the caller's string representation.
+    self._metrics_collector = _MetricsCollector(self._caller.__str__())
+    self._metrics_collector.setup_counter.inc(1)
+
+  def process(self, request: RequestT, *args, **kwargs):
+    """Applies pre-call throttling, sends ``request`` and yields the response.
+
+    Exceptions raised by the repeater propagate to the runner unchanged.
+    """
+    self._metrics_collector.requests.inc(1)
+
+    is_throttled_request = False
+    if self._throttler:
+      adaptive_throttler = self._throttler.throttler
+      delay_secs = self._throttler.delay_secs
+      # NOTE(review): MSEC_TO_SEC comes from vertex_ai_inference; presumably
+      # it scales seconds to milliseconds -- confirm against that module.
+      while adaptive_throttler.throttle_request(time.time() * MSEC_TO_SEC):
+        _LOGGER.info("Delaying request for %d seconds", delay_secs)
+        time.sleep(delay_secs)
+        self._metrics_collector.throttled_secs.inc(delay_secs)
+        is_throttled_request = True
+
+    if is_throttled_request:
+      self._metrics_collector.throttled_requests.inc(1)
+
+    req_time = time.time()
+    response = self._repeater.repeat(
+        self._caller, request, self._timeout, self._metrics_collector)
+    self._metrics_collector.responses.inc(1)
+    # The throttler is None whenever RequestResponseIO was given anything
+    # other than a DefaultThrottler, so only report success when it exists.
+    if self._throttler:
+      self._throttler.throttler.successful_request(req_time * MSEC_TO_SEC)
+    yield response
+
+  def teardown(self):
+    self._metrics_collector.teardown_counter.inc(1)
+    self._caller.__exit__(*sys.exc_info())
diff --git a/sdks/python/apache_beam/io/requestresponseio_it_test.py b/sdks/python/apache_beam/io/requestresponse_it_test.py
similarity index 86%
rename from sdks/python/apache_beam/io/requestresponseio_it_test.py
rename to sdks/python/apache_beam/io/requestresponse_it_test.py
index aae6b4e6ef2..396347c58d1 100644
--- a/sdks/python/apache_beam/io/requestresponseio_it_test.py
+++ b/sdks/python/apache_beam/io/requestresponse_it_test.py
@@ -16,6 +16,7 @@
#
import base64
import sys
+import typing
import unittest
from dataclasses import dataclass
from typing import Tuple
@@ -24,13 +25,18 @@ from typing import Union
import urllib3
import apache_beam as beam
-from apache_beam.io.requestresponseio import Caller
-from apache_beam.io.requestresponseio import RequestResponseIO
-from apache_beam.io.requestresponseio import UserCodeExecutionException
-from apache_beam.io.requestresponseio import UserCodeQuotaException
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.testing.test_pipeline import TestPipeline
+# pylint: disable=ungrouped-imports
+try:
+ from apache_beam.io.requestresponse import Caller
+ from apache_beam.io.requestresponse import RequestResponseIO
+ from apache_beam.io.requestresponse import UserCodeExecutionException
+ from apache_beam.io.requestresponse import UserCodeQuotaException
+except ImportError:
+ raise unittest.SkipTest('RequestResponseIO dependencies are not installed.')
+
_HTTP_PATH = '/v1/echo'
_PAYLOAD = base64.b64encode(bytes('payload', 'utf-8'))
_HTTP_ENDPOINT_ADDRESS_FLAG = '--httpEndpointAddress'
@@ -61,28 +67,27 @@ class EchoITOptions(PipelineOptions):
help='The ID for an allocated quota that should exceed.')
-# TODO(riteshghorse,damondouglas) replace Echo(Request|Response) with proto
-# generated classes from .test-infra/mock-apis:
@dataclass
-class EchoRequest:
+class EchoResponse:
id: str
payload: bytes
-@dataclass
-class EchoResponse:
+# TODO(riteshghorse,damondouglas) replace Echo(Request|Response) with proto
+# generated classes from .test-infra/mock-apis:
+class Request(typing.NamedTuple):
id: str
payload: bytes
-class EchoHTTPCaller(Caller):
+class EchoHTTPCaller(Caller[Request, EchoResponse]):
"""Implements ``Caller`` to call the ``EchoServiceGrpc``'s HTTP handler.
The purpose of ``EchoHTTPCaller`` is to support integration tests.
"""
def __init__(self, url: str):
self.url = url + _HTTP_PATH
- def __call__(self, request: EchoRequest, *args, **kwargs) -> EchoResponse:
+ def __call__(self, request: Request, *args, **kwargs) -> EchoResponse:
"""Overrides ``Caller``'s call method invoking the
``EchoServiceGrpc``'s HTTP handler with an ``EchoRequest``, returning
either a successful ``EchoResponse`` or throwing either a
@@ -129,7 +134,7 @@ class EchoHTTPCallerTestIT(unittest.TestCase):
def setUp(self) -> None:
client, options = EchoHTTPCallerTestIT._get_client_and_options()
- req = EchoRequest(id=options.should_exceed_quota_id, payload=_PAYLOAD)
+ req = Request(id=options.should_exceed_quota_id, payload=_PAYLOAD)
try:
# The following is needed to exceed the API
client(req)
@@ -148,7 +153,7 @@ class EchoHTTPCallerTestIT(unittest.TestCase):
def test_given_valid_request_receives_response(self):
client, options = EchoHTTPCallerTestIT._get_client_and_options()
- req = EchoRequest(id=options.never_exceed_quota_id, payload=_PAYLOAD)
+ req = Request(id=options.never_exceed_quota_id, payload=_PAYLOAD)
response: EchoResponse = client(req)
@@ -158,20 +163,20 @@ class EchoHTTPCallerTestIT(unittest.TestCase):
def test_given_exceeded_quota_should_raise(self):
client, options = EchoHTTPCallerTestIT._get_client_and_options()
- req = EchoRequest(id=options.should_exceed_quota_id, payload=_PAYLOAD)
+ req = Request(id=options.should_exceed_quota_id, payload=_PAYLOAD)
self.assertRaises(UserCodeQuotaException, lambda: client(req))
def test_not_found_should_raise(self):
client, _ = EchoHTTPCallerTestIT._get_client_and_options()
- req = EchoRequest(id='i-dont-exist-quota-id', payload=_PAYLOAD)
+ req = Request(id='i-dont-exist-quota-id', payload=_PAYLOAD)
self.assertRaisesRegex(
UserCodeExecutionException, "Not Found", lambda: client(req))
def test_request_response_io(self):
client, options = EchoHTTPCallerTestIT._get_client_and_options()
- req = EchoRequest(id=options.never_exceed_quota_id, payload=_PAYLOAD)
+ req = Request(id=options.never_exceed_quota_id, payload=_PAYLOAD)
with TestPipeline(is_integration_test=True) as test_pipeline:
output = (
test_pipeline
diff --git a/sdks/python/apache_beam/io/requestresponse_test.py b/sdks/python/apache_beam/io/requestresponse_test.py
new file mode 100644
index 00000000000..6d807c2a8eb
--- /dev/null
+++ b/sdks/python/apache_beam/io/requestresponse_test.py
@@ -0,0 +1,156 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import time
+import unittest
+
+import apache_beam as beam
+from apache_beam.testing.test_pipeline import TestPipeline
+
+# pylint: disable=ungrouped-imports
+try:
+ from google.api_core.exceptions import TooManyRequests
+ from apache_beam.io.requestresponse import Caller, DefaultThrottler
+ from apache_beam.io.requestresponse import RequestResponseIO
+ from apache_beam.io.requestresponse import UserCodeExecutionException
+ from apache_beam.io.requestresponse import UserCodeTimeoutException
+ from apache_beam.io.requestresponse import retry_on_exception
+except ImportError:
+ raise unittest.SkipTest('RequestResponseIO dependencies are not installed.')
+
+
+class AckCaller(Caller[str, str]):
+  """Test caller that answers every request with the request prefixed by
+  ``ACK: ``."""
+  def __enter__(self):
+    return None
+
+  def __call__(self, request: str):
+    acknowledgement = f"ACK: {request}"
+    return acknowledgement
+
+  def __exit__(self, exc_type, exc_val, exc_tb):
+    return None
+
+
+class CallerWithTimeout(AckCaller):
+  """Test caller that waits 2 seconds before acknowledging a request.
+  Exercises the timeout handling of RequestResponseIO."""
+  def __call__(self, request: str, *args, **kwargs):
+    time.sleep(2)
+    acknowledgement = f"ACK: {request}"
+    return acknowledgement
+
+
+class CallerWithRuntimeError(AckCaller):
+  """Test caller that raises a `RuntimeError` for a falsy request so that
+  RequestResponseIO raises a UserCodeExecutionException."""
+  def __call__(self, request: str, *args, **kwargs):
+    if request:
+      # Non-empty requests are accepted and produce no response.
+      return None
+    raise RuntimeError("Exception expected, not an error.")
+
+
+class CallerThatRetries(AckCaller):
+  """Test caller that always fails with ``TooManyRequests``.
+
+  ``count`` starts at -1 and is incremented on every call, so the error
+  message of the n-th call reads ``retries = n - 1``.
+  """
+  def __init__(self):
+    self.count = -1
+
+  def __call__(self, request: str, *args, **kwargs):
+    self.count += 1
+    raise TooManyRequests('retries = %d' % self.count)
+
+
+class TestCaller(unittest.TestCase):
+  """Unit tests for RequestResponseIO using the fake callers of this module."""
+  @staticmethod
+  def _committed_counter(result, name):
+    """Returns the committed value of the first counter named ``name``."""
+    name_filter = beam.metrics.MetricsFilter().with_name(name)
+    return result.metrics().query(name_filter)['counters'][0].committed
+
+  def test_valid_call(self):
+    ack_caller = AckCaller()
+    with TestPipeline() as pipeline:
+      responses = (
+          pipeline
+          | beam.Create(["sample_request"])
+          | RequestResponseIO(caller=ack_caller))
+
+    self.assertIsNotNone(responses)
+
+  def test_call_timeout(self):
+    # The caller sleeps for 2 seconds, longer than the 1 second timeout.
+    slow_caller = CallerWithTimeout()
+    with self.assertRaises(UserCodeTimeoutException):
+      with TestPipeline() as pipeline:
+        _ = (
+            pipeline
+            | beam.Create(["timeout_request"])
+            | RequestResponseIO(caller=slow_caller, timeout=1))
+
+  def test_call_runtime_error(self):
+    # An empty request makes the caller raise a RuntimeError.
+    failing_caller = CallerWithRuntimeError()
+    with self.assertRaises(UserCodeExecutionException):
+      with TestPipeline() as pipeline:
+        _ = (
+            pipeline
+            | beam.Create([""])
+            | RequestResponseIO(caller=failing_caller))
+
+  def test_retry_on_exception(self):
+    self.assertFalse(retry_on_exception(RuntimeError()))
+    self.assertTrue(retry_on_exception(TooManyRequests("HTTP 429")))
+
+  def test_caller_backoff_retry_strategy(self):
+    retrying_caller = CallerThatRetries()
+    with self.assertRaises(TooManyRequests) as cm:
+      with TestPipeline() as pipeline:
+        _ = (
+            pipeline
+            | beam.Create(["sample_request"])
+            | RequestResponseIO(caller=retrying_caller))
+    # Default repeater: the message of the last failure reports 2 retries.
+    self.assertRegex(cm.exception.message, 'retries = 2')
+
+  def test_caller_no_retry_strategy(self):
+    retrying_caller = CallerThatRetries()
+    with self.assertRaises(TooManyRequests) as cm:
+      with TestPipeline() as pipeline:
+        _ = (
+            pipeline
+            | beam.Create(["sample_request"])
+            | RequestResponseIO(caller=retrying_caller, repeater=None))
+    # ``repeater=None`` means a single attempt, hence no retries.
+    self.assertRegex(cm.exception.message, 'retries = 0')
+
+  def test_default_throttler(self):
+    slow_caller = CallerWithTimeout()
+    throttler = DefaultThrottler(
+        window_ms=10000, bucket_ms=5000, overload_ratio=1)
+    # manually override the number of received requests for testing.
+    throttler.throttler._all_requests.add(time.time() * 1000, 100)
+    pipeline = TestPipeline()
+    _ = (
+        pipeline
+        | beam.Create(['sample_request'])
+        | RequestResponseIO(caller=slow_caller, throttler=throttler))
+    result = pipeline.run()
+    result.wait_until_finish()
+    self.assertEqual(
+        self._committed_counter(result, 'throttled_requests'), 1)
+    self.assertGreater(
+        self._committed_counter(result, 'cumulativeThrottlingSeconds'), 0)
+    self.assertEqual(self._committed_counter(result, 'responses'), 1)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/sdks/python/apache_beam/io/requestresponseio.py b/sdks/python/apache_beam/io/requestresponseio.py
deleted file mode 100644
index 0ec586e6401..00000000000
--- a/sdks/python/apache_beam/io/requestresponseio.py
+++ /dev/null
@@ -1,218 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""``PTransform`` for reading from and writing to Web APIs."""
-import abc
-import concurrent.futures
-import contextlib
-import logging
-import sys
-from typing import Generic
-from typing import Optional
-from typing import TypeVar
-
-import apache_beam as beam
-from apache_beam.pvalue import PCollection
-
-RequestT = TypeVar('RequestT')
-ResponseT = TypeVar('ResponseT')
-
-DEFAULT_TIMEOUT_SECS = 30 # seconds
-
-_LOGGER = logging.getLogger(__name__)
-
-
-class UserCodeExecutionException(Exception):
- """Base class for errors related to calling Web APIs."""
-
-
-class UserCodeQuotaException(UserCodeExecutionException):
- """Extends ``UserCodeExecutionException`` to signal specifically that
- the Web API client encountered a Quota or API overuse related error.
- """
-
-
-class UserCodeTimeoutException(UserCodeExecutionException):
- """Extends ``UserCodeExecutionException`` to signal a user code timeout."""
-
-
-class Caller(contextlib.AbstractContextManager, abc.ABC):
- """Interface for user custom code intended for API calls.
- For setup and teardown of clients when applicable, implement the
- ``__enter__`` and ``__exit__`` methods respectively."""
- @abc.abstractmethod
- def __call__(self, request: RequestT, *args, **kwargs) -> ResponseT:
- """Calls a Web API with the ``RequestT`` and returns a
- ``ResponseT``. ``RequestResponseIO`` expects implementations of the
- ``__call__`` method to throw either a ``UserCodeExecutionException``,
- ``UserCodeQuotaException``, or ``UserCodeTimeoutException``.
- """
- pass
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- return None
-
-
-class ShouldBackOff(abc.ABC):
- """
- ShouldBackOff provides mechanism to apply adaptive throttling.
- """
- pass
-
-
-class Repeater(abc.ABC):
- """Repeater provides mechanism to repeat requests for a
- configurable condition."""
- pass
-
-
-class CacheReader(abc.ABC):
- """CacheReader provides mechanism to read from the cache."""
- pass
-
-
-class CacheWriter(abc.ABC):
- """CacheWriter provides mechanism to write to the cache."""
- pass
-
-
-class PreCallThrottler(abc.ABC):
- """PreCallThrottler provides a throttle mechanism before sending request."""
- pass
-
-
-class RequestResponseIO(beam.PTransform[beam.PCollection[RequestT],
- beam.PCollection[ResponseT]]):
- """A :class:`RequestResponseIO` transform to read and write to APIs.
-
- Processes an input :class:`~apache_beam.pvalue.PCollection` of requests
- by making a call to the API as defined in :class:`Caller`'s `__call__`
- and returns a :class:`~apache_beam.pvalue.PCollection` of responses.
- """
- def __init__(
- self,
- caller: [Caller],
- timeout: Optional[float] = DEFAULT_TIMEOUT_SECS,
- should_backoff: Optional[ShouldBackOff] = None,
- repeater: Optional[Repeater] = None,
- cache_reader: Optional[CacheReader] = None,
- cache_writer: Optional[CacheWriter] = None,
- throttler: Optional[PreCallThrottler] = None,
- ):
- """
- Instantiates a RequestResponseIO transform.
-
- Args:
- caller (~apache_beam.io.requestresponseio.Caller): an implementation of
- `Caller` object that makes call to the API.
- timeout (float): timeout value in seconds to wait for response from API.
- should_backoff (~apache_beam.io.requestresponseio.ShouldBackOff):
- (Optional) provides methods for backoff.
- repeater (~apache_beam.io.requestresponseio.Repeater): (Optional)
- provides methods to repeat requests to API.
- cache_reader (~apache_beam.io.requestresponseio.CacheReader): (Optional)
- provides methods to read external cache.
- cache_writer (~apache_beam.io.requestresponseio.CacheWriter): (Optional)
- provides methods to write to external cache.
- throttler (~apache_beam.io.requestresponseio.PreCallThrottler):
- (Optional) provides methods to pre-throttle a request.
- """
- self._caller = caller
- self._timeout = timeout
- self._should_backoff = should_backoff
- self._repeater = repeater
- self._cache_reader = cache_reader
- self._cache_writer = cache_writer
- self._throttler = throttler
-
- def expand(self, requests: PCollection[RequestT]) -> PCollection[ResponseT]:
- # TODO(riteshghorse): add Cache and Throttle PTransforms.
- return requests | _Call(
- caller=self._caller,
- timeout=self._timeout,
- should_backoff=self._should_backoff,
- repeater=self._repeater)
-
-
-class _Call(beam.PTransform[beam.PCollection[RequestT],
- beam.PCollection[ResponseT]]):
- """(Internal-only) PTransform that invokes a remote function on each element
- of the input PCollection.
-
- This PTransform uses a `Caller` object to invoke the actual API calls,
- and uses ``__enter__`` and ``__exit__`` to manage setup and teardown of
- clients when applicable. Additionally, a timeout value is specified to
- regulate the duration of each call, defaults to 30 seconds.
-
- Args:
- caller (:class:`apache_beam.io.requestresponseio.Caller`): a callable
- object that invokes API call.
- timeout (float): timeout value in seconds to wait for response from API.
- """
- def __init__(
- self,
- caller: Caller,
- timeout: Optional[float] = DEFAULT_TIMEOUT_SECS,
- should_backoff: Optional[ShouldBackOff] = None,
- repeater: Optional[Repeater] = None,
- ):
- """Initialize the _Call transform.
- Args:
- caller (:class:`apache_beam.io.requestresponseio.Caller`): a callable
- object that invokes API call.
- timeout (float): timeout value in seconds to wait for response from API.
- should_backoff (~apache_beam.io.requestresponseio.ShouldBackOff):
- (Optional) provides methods for backoff.
- repeater (~apache_beam.io.requestresponseio.Repeater): (Optional) provides
- methods to repeat requests to API.
- """
- self._caller = caller
- self._timeout = timeout
- self._should_backoff = should_backoff
- self._repeater = repeater
-
- def expand(
- self,
- requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]:
- return requests | beam.ParDo(_CallDoFn(self._caller, self._timeout))
-
-
-class _CallDoFn(beam.DoFn, Generic[RequestT, ResponseT]):
- def setup(self):
- self._caller.__enter__()
-
- def __init__(self, caller: Caller, timeout: float):
- self._caller = caller
- self._timeout = timeout
-
- def process(self, request, *args, **kwargs):
- with concurrent.futures.ThreadPoolExecutor() as executor:
- future = executor.submit(self._caller, request)
- try:
- yield future.result(timeout=self._timeout)
- except concurrent.futures.TimeoutError:
- raise UserCodeTimeoutException(
- f'Timeout {self._timeout} exceeded '
- f'while completing request: {request}')
- except RuntimeError:
- raise UserCodeExecutionException('could not complete request')
-
- def teardown(self):
- self._caller.__exit__(*sys.exc_info())
diff --git a/sdks/python/apache_beam/io/requestresponseio_test.py b/sdks/python/apache_beam/io/requestresponseio_test.py
deleted file mode 100644
index 2828a357887..00000000000
--- a/sdks/python/apache_beam/io/requestresponseio_test.py
+++ /dev/null
@@ -1,88 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import time
-import unittest
-
-import apache_beam as beam
-from apache_beam.io.requestresponseio import Caller
-from apache_beam.io.requestresponseio import RequestResponseIO
-from apache_beam.io.requestresponseio import UserCodeExecutionException
-from apache_beam.io.requestresponseio import UserCodeTimeoutException
-from apache_beam.testing.test_pipeline import TestPipeline
-
-
-class AckCaller(Caller):
- """AckCaller acknowledges the incoming request by returning a
- request with ACK."""
- def __enter__(self):
- pass
-
- def __call__(self, request: str):
- return f"ACK: {request}"
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- return None
-
-
-class CallerWithTimeout(AckCaller):
- """CallerWithTimeout sleeps for 2 seconds before responding.
- Used to test timeout in RequestResponseIO."""
- def __call__(self, request: str, *args, **kwargs):
- time.sleep(2)
- return f"ACK: {request}"
-
-
-class CallerWithRuntimeError(AckCaller):
- """CallerWithRuntimeError raises a `RuntimeError` for RequestResponseIO
- to raise a UserCodeExecutionException."""
- def __call__(self, request: str, *args, **kwargs):
- if not request:
- raise RuntimeError("Exception expected, not an error.")
-
-
-class TestCaller(unittest.TestCase):
- def test_valid_call(self):
- caller = AckCaller()
- with TestPipeline() as test_pipeline:
- output = (
- test_pipeline
- | beam.Create(["sample_request"])
- | RequestResponseIO(caller=caller))
-
- self.assertIsNotNone(output)
-
- def test_call_timeout(self):
- caller = CallerWithTimeout()
- with self.assertRaises(UserCodeTimeoutException):
- with TestPipeline() as test_pipeline:
- _ = (
- test_pipeline
- | beam.Create(["timeout_request"])
- | RequestResponseIO(caller=caller, timeout=1))
-
- def test_call_runtime_error(self):
- caller = CallerWithRuntimeError()
- with self.assertRaises(UserCodeExecutionException):
- with TestPipeline() as test_pipeline:
- _ = (
- test_pipeline
- | beam.Create([""])
- | RequestResponseIO(caller=caller))
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/sdks/python/apache_beam/transforms/enrichment.py b/sdks/python/apache_beam/transforms/enrichment.py
new file mode 100644
index 00000000000..a2f961be643
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/enrichment.py
@@ -0,0 +1,137 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import logging
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Optional
+from typing import TypeVar
+
+import apache_beam as beam
+from apache_beam.io.requestresponse import DEFAULT_TIMEOUT_SECS
+from apache_beam.io.requestresponse import Caller
+from apache_beam.io.requestresponse import DefaultThrottler
+from apache_beam.io.requestresponse import ExponentialBackOffRepeater
+from apache_beam.io.requestresponse import PreCallThrottler
+from apache_beam.io.requestresponse import Repeater
+from apache_beam.io.requestresponse import RequestResponseIO
+
+__all__ = [
+ "EnrichmentSourceHandler",
+ "Enrichment",
+ "cross_join",
+]
+
+InputT = TypeVar('InputT')
+OutputT = TypeVar('OutputT')
+
+JoinFn = Callable[[Dict[str, Any], Dict[str, Any]], beam.Row]
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def cross_join(left: Dict[str, Any], right: Dict[str, Any]) -> beam.Row:
+ """cross_join performs a cross join on two `dict` objects.
+
+ Joins the columns of the right row onto the left row.
+
+ Args:
+ left (Dict[str, Any]): input request dictionary.
+ right (Dict[str, Any]): response dictionary from the API.
+
+ Returns:
+ `beam.Row` containing the merged columns.
+ """
+ for k, v in right.items():
+ if k not in left:
+ # Don't override the values in left.
+ left[k] = v
+ elif left[k] != v:
+ _LOGGER.warning(
+ '%s exists in the input row as well the row fetched '
+ 'from API but have different values - %s and %s. Using the input '
+ 'value (%s) for the enriched row. You can override this behavior by '
+ 'passing a custom `join_fn` to Enrichment transform.' %
+ (k, left[k], v, left[k]))
+ return beam.Row(**left)
+
+
+class EnrichmentSourceHandler(Caller[InputT, OutputT]):
+ """Wrapper class for :class:`apache_beam.io.requestresponse.Caller`.
+
+ Ensure that the implementation of ``__call__`` method returns a tuple
+ of `beam.Row` objects.
+ """
+ pass
+
+
+class Enrichment(beam.PTransform[beam.PCollection[InputT],
+ beam.PCollection[OutputT]]):
+ """A :class:`apache_beam.transforms.enrichment.Enrichment` transform to
+ enrich elements in a PCollection.
+ **NOTE:** This transform and its implementation are under development and
+ do not provide backward compatibility guarantees.
+ Uses the :class:`apache_beam.transforms.enrichment.EnrichmentSourceHandler`
+ to enrich elements by joining the metadata from external source.
+
+ Processes an input :class:`~apache_beam.pvalue.PCollection` of `beam.Row` by
+ applying a :class:`apache_beam.transforms.enrichment.EnrichmentSourceHandler`
+ to each element and returning the enriched
+ :class:`~apache_beam.pvalue.PCollection`.
+
+ Args:
+ source_handler: Handles source lookup and metadata retrieval.
+ Implements the
+ :class:`apache_beam.transforms.enrichment.EnrichmentSourceHandler`
+ join_fn: A lambda function to join original element with lookup metadata.
+ Defaults to `CROSS_JOIN`.
+ timeout: (Optional) timeout for source requests. Defaults to 30 seconds.
+ repeater (~apache_beam.io.requestresponse.Repeater): provides method to
+ repeat failed requests to API due to service errors. Defaults to
+ :class:`apache_beam.io.requestresponse.ExponentialBackOffRepeater` to
+ repeat requests with exponential backoff.
+ throttler (~apache_beam.io.requestresponse.PreCallThrottler):
+ provides methods to pre-throttle a request. Defaults to
+ :class:`apache_beam.io.requestresponse.DefaultThrottler` for
+ client-side adaptive throttling using
+ :class:`apache_beam.io.components.adaptive_throttler.AdaptiveThrottler`.
+ """
+ def __init__(
+ self,
+ source_handler: EnrichmentSourceHandler,
+ join_fn: JoinFn = cross_join,
+ timeout: Optional[float] = DEFAULT_TIMEOUT_SECS,
+ repeater: Repeater = ExponentialBackOffRepeater(),
+ throttler: PreCallThrottler = DefaultThrottler(),
+ ):
+ self._source_handler = source_handler
+ self._join_fn = join_fn
+ self._timeout = timeout
+ self._repeater = repeater
+ self._throttler = throttler
+
+ def expand(self,
+ input_row: beam.PCollection[InputT]) -> beam.PCollection[OutputT]:
+ fetched_data = input_row | RequestResponseIO(
+ caller=self._source_handler,
+ timeout=self._timeout,
+ repeater=self._repeater,
+ throttler=self._throttler)
+
+ # EnrichmentSourceHandler returns a tuple of (request,response).
+ return fetched_data | beam.Map(
+ lambda x: self._join_fn(x[0]._asdict(), x[1]._asdict()))
diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/__init__.py b/sdks/python/apache_beam/transforms/enrichment_handlers/__init__.py
new file mode 100644
index 00000000000..cce3acad34a
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/enrichment_handlers/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py
new file mode 100644
index 00000000000..86ff2f3b8e7
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py
@@ -0,0 +1,151 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import logging
+from enum import Enum
+from typing import Any
+from typing import Dict
+from typing import Optional
+
+from google.api_core.exceptions import NotFound
+from google.cloud import bigtable
+from google.cloud.bigtable import Client
+from google.cloud.bigtable.row_filters import CellsColumnLimitFilter
+from google.cloud.bigtable.row_filters import RowFilter
+
+import apache_beam as beam
+from apache_beam.transforms.enrichment import EnrichmentSourceHandler
+
+__all__ = [
+ 'EnrichWithBigTable',
+ 'ExceptionLevel',
+]
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class ExceptionLevel(Enum):
+ """ExceptionLevel defines the exception level options to either
+ log a warning, or raise an exception, or do nothing when a BigTable query
+ returns an empty row.
+
+ Members:
+ - RAISE: Raise the exception.
+ - WARN: Log a warning for exception without raising it.
+ - QUIET: Neither log nor raise the exception.
+ """
+ RAISE = 0
+ WARN = 1
+ QUIET = 2
+
+
+class EnrichWithBigTable(EnrichmentSourceHandler[beam.Row, beam.Row]):
+ """EnrichWithBigTable is a handler for
+ :class:`apache_beam.transforms.enrichment.Enrichment` transform to interact
+ with GCP BigTable.
+
+ Args:
+ project_id (str): GCP project-id of the BigTable cluster.
+ instance_id (str): GCP instance-id of the BigTable cluster.
+ table_id (str): GCP table-id of the BigTable.
+ row_key (str): unique row-key field name from the input `beam.Row` object
+ to use as `row_key` for BigTable querying.
+ row_filter: a ``:class:`google.cloud.bigtable.row_filters.RowFilter``` to
+ filter data read with ``read_row()``.
+ Defaults to `CellsColumnLimitFilter(1)`.
+ app_profile_id (str): App profile ID to use for BigTable.
+ See https://cloud.google.com/bigtable/docs/app-profiles for more details.
+ encoding (str): encoding type to convert the string to bytes and vice-versa
+ from BigTable. Default is `utf-8`.
+ exception_level: a `enum.Enum` value from
+ ``apache_beam.transforms.enrichment_handlers.bigtable.ExceptionLevel``
+ to set the level when an empty row is returned from the BigTable query.
+ Defaults to ``ExceptionLevel.WARN``.
+ """
+ def __init__(
+ self,
+ project_id: str,
+ instance_id: str,
+ table_id: str,
+ row_key: str,
+ row_filter: Optional[RowFilter] = CellsColumnLimitFilter(1),
+ app_profile_id: str = None, # type: ignore[assignment]
+ encoding: str = 'utf-8',
+ exception_level: ExceptionLevel = ExceptionLevel.WARN,
+ ):
+ self._project_id = project_id
+ self._instance_id = instance_id
+ self._table_id = table_id
+ self._row_key = row_key
+ self._row_filter = row_filter
+ self._app_profile_id = app_profile_id
+ self._encoding = encoding
+ self._exception_level = exception_level
+
+ def __enter__(self):
+ """connect to the Google BigTable cluster."""
+ self.client = Client(project=self._project_id)
+ self.instance = self.client.instance(self._instance_id)
+ self._table = bigtable.table.Table(
+ table_id=self._table_id,
+ instance=self.instance,
+ app_profile_id=self._app_profile_id)
+
+ def __call__(self, request: beam.Row, *args, **kwargs):
+ """
+ Reads a row from the GCP BigTable and returns
+ a `Tuple` of request and response.
+
+ Args:
+ request: the input `beam.Row` to enrich.
+ """
+ response_dict: Dict[str, Any] = {}
+ row_key_str: str = ""
+ try:
+ request_dict = request._asdict()
+ row_key_str = str(request_dict[self._row_key])
+ row_key = row_key_str.encode(self._encoding)
+ row = self._table.read_row(row_key, filter_=self._row_filter)
+ if row:
+ for cf_id, cf_v in row.cells.items():
+ response_dict[cf_id] = {}
+ for k, v in cf_v.items():
+ response_dict[cf_id][k.decode(self._encoding)] = \
+ v[0].value.decode(self._encoding)
+ elif self._exception_level == ExceptionLevel.WARN:
+ _LOGGER.warning(
+ 'no matching row found for row_key: %s '
+ 'with row_filter: %s' % (row_key_str, self._row_filter))
+ elif self._exception_level == ExceptionLevel.RAISE:
+ raise ValueError(
+ 'no matching row found for row_key: %s '
+ 'with row_filter=%s' % (row_key_str, self._row_filter))
+ except KeyError:
+      raise KeyError('row_key %s not found in input PCollection.' % row_key_str)
+ except NotFound:
+ raise NotFound(
+ 'GCP BigTable cluster `%s:%s:%s` not found.' %
+ (self._project_id, self._instance_id, self._table_id))
+ except Exception as e:
+ raise e
+
+ return request, beam.Row(**response_dict)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """Clean the instantiated BigTable client."""
+ self.client = None
+ self.instance = None
+ self._table = None
diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py
new file mode 100644
index 00000000000..dd48c8e5ef4
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py
@@ -0,0 +1,300 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import datetime
+import unittest
+from typing import Dict
+from typing import List
+from typing import NamedTuple
+
+import pytest
+
+import apache_beam as beam
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import BeamAssertException
+
+# pylint: disable=ungrouped-imports
+try:
+ from google.api_core.exceptions import NotFound
+ from google.cloud.bigtable import Client
+ from google.cloud.bigtable.row_filters import ColumnRangeFilter
+ from apache_beam.transforms.enrichment import Enrichment
+  from apache_beam.transforms.enrichment_handlers.bigtable import EnrichWithBigTable
+  from apache_beam.transforms.enrichment_handlers.bigtable import ExceptionLevel
+except ImportError:
+ raise unittest.SkipTest('GCP BigTable dependencies are not installed.')
+
+
+class ValidateResponse(beam.DoFn):
+ """ValidateResponse validates if a PCollection of `beam.Row`
+ has the required fields."""
+ def __init__(
+ self,
+ n_fields: int,
+ fields: List[str],
+ enriched_fields: Dict[str, List[str]]):
+ self.n_fields = n_fields
+ self._fields = fields
+ self._enriched_fields = enriched_fields
+
+ def process(self, element: beam.Row, *args, **kwargs):
+ element_dict = element.as_dict()
+ if len(element_dict.keys()) != self.n_fields:
+ raise BeamAssertException(
+ "Expected %d fields in enriched PCollection:" % self.n_fields)
+
+ for field in self._fields:
+ if field not in element_dict or element_dict[field] is None:
+ raise BeamAssertException(f"Expected a not None field: {field}")
+
+ for column_family, columns in self._enriched_fields.items():
+ if (len(element_dict[column_family]) != len(columns) or
+ not all(key in element_dict[column_family] for key in columns)):
+ raise BeamAssertException(
+ "Response from bigtable should contain a %s column_family with "
+ "%s keys." % (column_family, columns))
+
+
+class _Currency(NamedTuple):
+ s_id: int
+ id: str
+
+
+def create_rows(table):
+ product_id = 'product_id'
+ product_name = 'product_name'
+ product_stock = 'product_stock'
+
+ column_family_id = "product"
+ products = [
+ {
+ 'product_id': 1, 'product_name': 'pixel 5', 'product_stock': 2
+ },
+ {
+ 'product_id': 2, 'product_name': 'pixel 6', 'product_stock': 4
+ },
+ {
+ 'product_id': 3, 'product_name': 'pixel 7', 'product_stock': 20
+ },
+ {
+ 'product_id': 4, 'product_name': 'pixel 8', 'product_stock': 10
+ },
+ {
+ 'product_id': 5, 'product_name': 'iphone 11', 'product_stock': 3
+ },
+ {
+ 'product_id': 6, 'product_name': 'iphone 12', 'product_stock': 7
+ },
+ {
+ 'product_id': 7, 'product_name': 'iphone 13', 'product_stock': 8
+ },
+ {
+ 'product_id': 8, 'product_name': 'iphone 14', 'product_stock': 3
+ },
+ ]
+
+ for item in products:
+ row_key = str(item[product_id]).encode()
+ row = table.direct_row(row_key)
+ row.set_cell(
+ column_family_id,
+ product_id.encode(),
+ str(item[product_id]),
+ timestamp=datetime.datetime.utcnow())
+ row.set_cell(
+ column_family_id,
+ product_name.encode(),
+ item[product_name],
+ timestamp=datetime.datetime.utcnow())
+ row.set_cell(
+ column_family_id,
+ product_stock.encode(),
+ str(item[product_stock]),
+ timestamp=datetime.datetime.utcnow())
+ row.commit()
+
+
+@pytest.mark.it_postcommit
+class TestBigTableEnrichment(unittest.TestCase):
+ def setUp(self):
+ self.project_id = 'apache-beam-testing'
+ self.instance_id = 'beam-test'
+ self.table_id = 'bigtable-enrichment-test'
+ self.req = [
+ beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1),
+ beam.Row(sale_id=3, customer_id=3, product_id=2, quantity=3),
+ beam.Row(sale_id=5, customer_id=5, product_id=4, quantity=2),
+ beam.Row(sale_id=7, customer_id=7, product_id=1, quantity=1),
+ ]
+ self.row_key = 'product_id'
+ self.column_family_id = 'product'
+ client = Client(project=self.project_id)
+ instance = client.instance(self.instance_id)
+ self.table = instance.table(self.table_id)
+ create_rows(self.table)
+
+ def tearDown(self) -> None:
+ self.table = None
+
+ def test_enrichment_with_bigtable(self):
+ expected_fields = [
+ 'sale_id', 'customer_id', 'product_id', 'quantity', 'product'
+ ]
+ expected_enriched_fields = {
+ 'product': ['product_id', 'product_name', 'product_stock'],
+ }
+ bigtable = EnrichWithBigTable(
+ project_id=self.project_id,
+ instance_id=self.instance_id,
+ table_id=self.table_id,
+ row_key=self.row_key)
+ with TestPipeline(is_integration_test=True) as test_pipeline:
+ _ = (
+ test_pipeline
+ | "Create" >> beam.Create(self.req)
+ | "Enrich W/ BigTable" >> Enrichment(bigtable)
+ | "Validate Response" >> beam.ParDo(
+ ValidateResponse(
+ len(expected_fields),
+ expected_fields,
+ expected_enriched_fields)))
+
+ def test_enrichment_with_bigtable_row_filter(self):
+ expected_fields = [
+ 'sale_id', 'customer_id', 'product_id', 'quantity', 'product'
+ ]
+ expected_enriched_fields = {
+ 'product': ['product_name', 'product_stock'],
+ }
+ start_column = 'product_name'.encode()
+ column_filter = ColumnRangeFilter(self.column_family_id, start_column)
+ bigtable = EnrichWithBigTable(
+ project_id=self.project_id,
+ instance_id=self.instance_id,
+ table_id=self.table_id,
+ row_key=self.row_key,
+ row_filter=column_filter)
+ with TestPipeline(is_integration_test=True) as test_pipeline:
+ _ = (
+ test_pipeline
+ | "Create" >> beam.Create(self.req)
+ | "Enrich W/ BigTable" >> Enrichment(bigtable)
+ | "Validate Response" >> beam.ParDo(
+ ValidateResponse(
+ len(expected_fields),
+ expected_fields,
+ expected_enriched_fields)))
+
+ def test_enrichment_with_bigtable_no_enrichment(self):
+ # row_key which is product_id=11 doesn't exist, so the enriched field
+ # won't be added. Hence, the response is same as the request.
+ expected_fields = ['sale_id', 'customer_id', 'product_id', 'quantity']
+ expected_enriched_fields = {}
+ bigtable = EnrichWithBigTable(
+ project_id=self.project_id,
+ instance_id=self.instance_id,
+ table_id=self.table_id,
+ row_key=self.row_key)
+ req = [beam.Row(sale_id=1, customer_id=1, product_id=11, quantity=1)]
+ with TestPipeline(is_integration_test=True) as test_pipeline:
+ _ = (
+ test_pipeline
+ | "Create" >> beam.Create(req)
+ | "Enrich W/ BigTable" >> Enrichment(bigtable)
+ | "Validate Response" >> beam.ParDo(
+ ValidateResponse(
+ len(expected_fields),
+ expected_fields,
+ expected_enriched_fields)))
+
+ def test_enrichment_with_bigtable_bad_row_filter(self):
+ # in case of a bad column filter, that is, incorrect column_family_id and
+ # columns, no enrichment is done. If the column_family is correct but not
+ # column names then all columns in that column_family are returned.
+ start_column = 'car_name'.encode()
+ column_filter = ColumnRangeFilter('car_name', start_column)
+ bigtable = EnrichWithBigTable(
+ project_id=self.project_id,
+ instance_id=self.instance_id,
+ table_id=self.table_id,
+ row_key=self.row_key,
+ row_filter=column_filter)
+ with self.assertRaises(NotFound):
+ test_pipeline = beam.Pipeline()
+ _ = (
+ test_pipeline
+ | "Create" >> beam.Create(self.req)
+ | "Enrich W/ BigTable" >> Enrichment(bigtable))
+ res = test_pipeline.run()
+ res.wait_until_finish()
+
+ def test_enrichment_with_bigtable_raises_key_error(self):
+ """raises a `KeyError` when the row_key doesn't exist in
+ the input PCollection."""
+ bigtable = EnrichWithBigTable(
+ project_id=self.project_id,
+ instance_id=self.instance_id,
+ table_id=self.table_id,
+ row_key='car_name')
+ with self.assertRaises(KeyError):
+ test_pipeline = beam.Pipeline()
+ _ = (
+ test_pipeline
+ | "Create" >> beam.Create(self.req)
+ | "Enrich W/ BigTable" >> Enrichment(bigtable))
+ res = test_pipeline.run()
+ res.wait_until_finish()
+
+ def test_enrichment_with_bigtable_raises_not_found(self):
+ """raises a `NotFound` exception when the GCP BigTable Cluster
+ doesn't exist."""
+ bigtable = EnrichWithBigTable(
+ project_id=self.project_id,
+ instance_id=self.instance_id,
+ table_id='invalid_table',
+ row_key=self.row_key)
+ with self.assertRaises(NotFound):
+ test_pipeline = beam.Pipeline()
+ _ = (
+ test_pipeline
+ | "Create" >> beam.Create(self.req)
+ | "Enrich W/ BigTable" >> Enrichment(bigtable))
+ res = test_pipeline.run()
+ res.wait_until_finish()
+
+ def test_enrichment_with_bigtable_exception_level(self):
+ """raises a `ValueError` exception when the GCP BigTable query returns
+ an empty row."""
+ bigtable = EnrichWithBigTable(
+ project_id=self.project_id,
+ instance_id=self.instance_id,
+ table_id=self.table_id,
+ row_key=self.row_key,
+ exception_level=ExceptionLevel.RAISE)
+ req = [beam.Row(sale_id=1, customer_id=1, product_id=11, quantity=1)]
+ with self.assertRaises(ValueError):
+ test_pipeline = beam.Pipeline()
+ _ = (
+ test_pipeline
+ | "Create" >> beam.Create(req)
+ | "Enrich W/ BigTable" >> Enrichment(bigtable))
+ res = test_pipeline.run()
+ res.wait_until_finish()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/sdks/python/apache_beam/transforms/enrichment_it_test.py b/sdks/python/apache_beam/transforms/enrichment_it_test.py
new file mode 100644
index 00000000000..89842cb18be
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/enrichment_it_test.py
@@ -0,0 +1,162 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import time
+import unittest
+from typing import List
+from typing import NamedTuple
+from typing import Tuple
+from typing import Union
+
+import pytest
+import urllib3
+
+import apache_beam as beam
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import BeamAssertException
+
+# pylint: disable=ungrouped-imports
+try:
+ from apache_beam.io.requestresponse import UserCodeExecutionException
+ from apache_beam.io.requestresponse import UserCodeQuotaException
+ from apache_beam.io.requestresponse_it_test import _PAYLOAD
+ from apache_beam.io.requestresponse_it_test import EchoITOptions
+ from apache_beam.transforms.enrichment import Enrichment
+ from apache_beam.transforms.enrichment import EnrichmentSourceHandler
+except ImportError:
+ raise unittest.SkipTest('RequestResponseIO dependencies are not installed.')
+
+
+class Request(NamedTuple):
+ id: str
+ payload: bytes
+
+
+def _custom_join(left, right):
+ """custom_join returns the id and resp_payload along with a timestamp"""
+ right['timestamp'] = time.time()
+ return beam.Row(**right)
+
+
+class SampleHTTPEnrichment(EnrichmentSourceHandler[Request, beam.Row]):
+ """Implements ``EnrichmentSourceHandler`` to call the ``EchoServiceGrpc``'s
+ HTTP handler.
+ """
+ def __init__(self, url: str):
+ self.url = url + '/v1/echo' # append path to the mock API.
+
+ def __call__(self, request: Request, *args, **kwargs):
+ """Overrides ``Caller``'s call method invoking the
+ ``EchoServiceGrpc``'s HTTP handler with an `dict`, returning
+ either a successful ``Tuple[dict,dict]`` or throwing either a
+ ``UserCodeExecutionException``, ``UserCodeTimeoutException``,
+ or a ``UserCodeQuotaException``.
+ """
+ try:
+ resp = urllib3.request(
+ "POST",
+ self.url,
+ json={
+ "id": request.id, "payload": str(request.payload, 'utf-8')
+ },
+ retries=False)
+
+ if resp.status < 300:
+ resp_body = resp.json()
+ resp_id = resp_body['id']
+ payload = resp_body['payload']
+ return (
+            request, beam.Row(id=resp_id, resp_payload=bytes(payload, 'utf-8')))
+
+ if resp.status == 429: # Too Many Requests
+ raise UserCodeQuotaException(resp.reason)
+ elif resp.status != 200:
+ raise UserCodeExecutionException(resp.status, resp.reason, request)
+
+ except urllib3.exceptions.HTTPError as e:
+ raise UserCodeExecutionException(e)
+
+
+class ValidateFields(beam.DoFn):
+ """ValidateFields validates if a PCollection of `beam.Row`
+ has certain fields."""
+ def __init__(self, n_fields: int, fields: List[str]):
+ self.n_fields = n_fields
+ self._fields = fields
+
+ def process(self, element: beam.Row, *args, **kwargs):
+ element_dict = element.as_dict()
+ if len(element_dict.keys()) != self.n_fields:
+ raise BeamAssertException(
+ "Expected %d fields in enriched PCollection:"
+ " id, payload and resp_payload" % self.n_fields)
+
+ for field in self._fields:
+ if field not in element_dict or element_dict[field] is None:
+ raise BeamAssertException(f"Expected a not None field: {field}")
+
+
+@pytest.mark.it_postcommit
+class TestEnrichment(unittest.TestCase):
+ options: Union[EchoITOptions, None] = None
+ client: Union[SampleHTTPEnrichment, None] = None
+
+ @classmethod
+ def setUpClass(cls) -> None:
+ cls.options = EchoITOptions()
+ http_endpoint_address = cls.options.http_endpoint_address
+ if not http_endpoint_address or http_endpoint_address == '':
+ raise unittest.SkipTest('HTTP_ENDPOINT_ADDRESS is required.')
+ cls.client = SampleHTTPEnrichment(http_endpoint_address)
+
+ @classmethod
+ def _get_client_and_options(
+ cls) -> Tuple[SampleHTTPEnrichment, EchoITOptions]:
+ assert cls.options is not None
+ assert cls.client is not None
+ return cls.client, cls.options
+
+ def test_http_enrichment(self):
+ """Tests Enrichment Transform against the Mock-API HTTP endpoint
+ with the default cross join."""
+ client, options = TestEnrichment._get_client_and_options()
+ req = Request(id=options.never_exceed_quota_id, payload=_PAYLOAD)
+ fields = ['id', 'payload', 'resp_payload']
+ with TestPipeline(is_integration_test=True) as test_pipeline:
+ _ = (
+ test_pipeline
+ | 'Create PCollection' >> beam.Create([req])
+ | 'Enrichment Transform' >> Enrichment(client)
+ | 'Assert Fields' >> beam.ParDo(
+ ValidateFields(len(fields), fields=fields)))
+
+ def test_http_enrichment_custom_join(self):
+ """Tests Enrichment Transform against the Mock-API HTTP endpoint
+ with a custom join function."""
+ client, options = TestEnrichment._get_client_and_options()
+ req = Request(id=options.never_exceed_quota_id, payload=_PAYLOAD)
+ fields = ['id', 'resp_payload', 'timestamp']
+ with TestPipeline(is_integration_test=True) as test_pipeline:
+ _ = (
+ test_pipeline
+ | 'Create PCollection' >> beam.Create([req])
+ | 'Enrichment Transform' >> Enrichment(client, join_fn=_custom_join)
+ | 'Assert Fields' >> beam.ParDo(
+ ValidateFields(len(fields), fields=fields)))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/sdks/python/apache_beam/transforms/enrichment_test.py b/sdks/python/apache_beam/transforms/enrichment_test.py
new file mode 100644
index 00000000000..23b5f1828c1
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/enrichment_test.py
@@ -0,0 +1,41 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import unittest
+
+import apache_beam as beam
+
+# pylint: disable=ungrouped-imports
+try:
+ from apache_beam.transforms.enrichment import cross_join
+except ImportError:
+ raise unittest.SkipTest('RequestResponseIO dependencies are not installed.')
+
+
+class TestEnrichmentTransform(unittest.TestCase):
+ def test_cross_join(self):
+ left = {'id': 1, 'key': 'city'}
+ right = {'id': 1, 'value': 'durham'}
+ expected = beam.Row(id=1, key='city', value='durham')
+ output = cross_join(left, right)
+ self.assertEqual(expected, output)
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()