This is an automated email from the ASF dual-hosted git repository.

riteshghorse pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 1de8454ddd2 [Python] Enrichment Transform with BigTable handler 
(#30001)
1de8454ddd2 is described below

commit 1de8454ddd215b9659fd049bded6a4b4c484b12d
Author: Ritesh Ghorse <[email protected]>
AuthorDate: Thu Jan 18 21:22:11 2024 -0500

    [Python] Enrichment Transform with BigTable handler (#30001)
    
    * enrichment v1
    
    * add documentation
    
    * add doc comment
    
    * rerun
    
    * update docs, lint
    
    * update docs, lint
    
    * add generic type
    
    * add generic type
    
    * adjust doc path
    
    * create test row
    
    * use request type
    
    * use request type
    
    * change module name
    
    * more tests
    
    * remove non-functional params
    
    * lint, doc
    
    * change types for general use
    
    * callable type
    
    * dict type
    
    * update signatures
    
    * fix unit test
    
    * bigtable with column family, ids, rrio-throttler
    
    * update tests for row filter
    
    * convert handler types from dict to Row
    
    * update tests for bigtable
    
    * ran pydocs
    
    * ran pydocs
    
    * mark postcommit
    
    * remove _test file, fix import
    
    * enable postcommit
    
    * add more tests
    
    * skip tests when dependencies are not installed
    
    * add deleted imports from last commit
    
    * add skip test condition
    
    * fix import order, add TooManyRequests to try-catch
    
    * make throttler, repeater non-optional
    
    * add exception level and tests
    
    * correct pydoc statement
    
    * add throttle tests
    
    * add bigtable improvements
    
    * default app_profile_id
    
    * add documentation, ignore None assignment
    
    * add to changes.md
    
    * change test structure that throws exception, skip http test for now
    
    * drop postcommit trigger file
---
 CHANGES.md                                         |   2 +-
 sdks/python/apache_beam/io/requestresponse.py      | 413 +++++++++++++++++++++
 ...nseio_it_test.py => requestresponse_it_test.py} |  37 +-
 sdks/python/apache_beam/io/requestresponse_test.py | 156 ++++++++
 sdks/python/apache_beam/io/requestresponseio.py    | 218 -----------
 .../apache_beam/io/requestresponseio_test.py       |  88 -----
 sdks/python/apache_beam/transforms/enrichment.py   | 137 +++++++
 .../transforms/enrichment_handlers/__init__.py     |  16 +
 .../transforms/enrichment_handlers/bigtable.py     | 151 ++++++++
 .../enrichment_handlers/bigtable_it_test.py        | 300 +++++++++++++++
 .../apache_beam/transforms/enrichment_it_test.py   | 162 ++++++++
 .../apache_beam/transforms/enrichment_test.py      |  41 ++
 12 files changed, 1398 insertions(+), 323 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 81a519b07d7..dbad15f3dba 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -69,7 +69,7 @@
 
 ## New Features / Improvements
 
-* X feature added (Java/Python) 
([#X](https://github.com/apache/beam/issues/X)).
+* [Enrichment Transform](https://s.apache.org/enrichment-transform) along with 
GCP BigTable handler added to Python SDK 
([#30001](https://github.com/apache/beam/pull/30001)).
 
 ## Breaking Changes
 
diff --git a/sdks/python/apache_beam/io/requestresponse.py 
b/sdks/python/apache_beam/io/requestresponse.py
new file mode 100644
index 00000000000..63ec7061d3e
--- /dev/null
+++ b/sdks/python/apache_beam/io/requestresponse.py
@@ -0,0 +1,413 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""``PTransform`` for reading from and writing to Web APIs."""
+import abc
+import concurrent.futures
+import contextlib
+import logging
+import sys
+import time
+from typing import Generic
+from typing import Optional
+from typing import TypeVar
+
+from google.api_core.exceptions import TooManyRequests
+
+import apache_beam as beam
+from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler
+from apache_beam.metrics import Metrics
+from apache_beam.ml.inference.vertex_ai_inference import MSEC_TO_SEC
+from apache_beam.utils import retry
+
+RequestT = TypeVar('RequestT')
+ResponseT = TypeVar('ResponseT')
+
+DEFAULT_TIMEOUT_SECS = 30  # seconds
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class UserCodeExecutionException(Exception):
+  """Base class for errors related to calling Web APIs."""
+
+
+class UserCodeQuotaException(UserCodeExecutionException):
+  """Extends ``UserCodeExecutionException`` to signal specifically that
+  the Web API client encountered a Quota or API overuse related error.
+  """
+
+
+class UserCodeTimeoutException(UserCodeExecutionException):
+  """Extends ``UserCodeExecutionException`` to signal a user code timeout."""
+
+
+def retry_on_exception(exception: Exception):
+  """retry on exceptions caused by unavailability of the remote server."""
+  return isinstance(
+      exception,
+      (TooManyRequests, UserCodeTimeoutException, UserCodeQuotaException))
+
+
+class _MetricsCollector:
+  """A metrics collector that tracks RequestResponseIO related usage."""
+  def __init__(self, namespace: str):
+    """
+    Args:
+      namespace: Namespace for the metrics.
+    """
+    self.requests = Metrics.counter(namespace, 'requests')
+    self.responses = Metrics.counter(namespace, 'responses')
+    self.failures = Metrics.counter(namespace, 'failures')
+    self.throttled_requests = Metrics.counter(namespace, 'throttled_requests')
+    self.throttled_secs = Metrics.counter(
+        namespace, 'cumulativeThrottlingSeconds')
+    self.timeout_requests = Metrics.counter(namespace, 'requests_timed_out')
+    self.call_counter = Metrics.counter(namespace, 'call_invocations')
+    self.setup_counter = Metrics.counter(namespace, 'setup_counter')
+    self.teardown_counter = Metrics.counter(namespace, 'teardown_counter')
+    self.backoff_counter = Metrics.counter(namespace, 'backoff_counter')
+    self.sleeper_counter = Metrics.counter(namespace, 'sleeper_counter')
+    self.should_backoff_counter = Metrics.counter(
+        namespace, 'should_backoff_counter')
+
+
+class Caller(contextlib.AbstractContextManager,
+             abc.ABC,
+             Generic[RequestT, ResponseT]):
+  """Interface for user custom code intended for API calls.
+  For setup and teardown of clients when applicable, implement the
+  ``__enter__`` and ``__exit__`` methods respectively."""
+  @abc.abstractmethod
+  def __call__(self, request: RequestT, *args, **kwargs) -> ResponseT:
+    """Calls a Web API with the ``RequestT``  and returns a
+    ``ResponseT``. ``RequestResponseIO`` expects implementations of the
+    ``__call__`` method to throw either a ``UserCodeExecutionException``,
+    ``UserCodeQuotaException``, or ``UserCodeTimeoutException``.
+    """
+    pass
+
+  def __enter__(self):
+    return self
+
+  def __exit__(self, exc_type, exc_val, exc_tb):
+    return None
+
+
+class ShouldBackOff(abc.ABC):
+  """
+  ShouldBackOff provides mechanism to apply adaptive throttling.
+  """
+  pass
+
+
+class Repeater(abc.ABC):
+  """Repeater provides mechanism to repeat requests for a
+  configurable condition."""
+  @abc.abstractmethod
+  def repeat(
+      self,
+      caller: Caller[RequestT, ResponseT],
+      request: RequestT,
+      timeout: float,
+      metrics_collector: Optional[_MetricsCollector]) -> ResponseT:
+    """repeat method is called from the RequestResponseIO when
+    a repeater is enabled.
+
+    Args:
+      caller: :class:`apache_beam.io.requestresponse.Caller` object that calls
+        the API.
+      request: input request to repeat.
+      timeout: time to wait for the request to complete.
+      metrics_collector: (Optional) a
+        ``:class:`apache_beam.io.requestresponse._MetricsCollector``` object to
+        collect the metrics for RequestResponseIO.
+    """
+    pass
+
+
+def _execute_request(
+    caller: Caller[RequestT, ResponseT],
+    request: RequestT,
+    timeout: float,
+    metrics_collector: Optional[_MetricsCollector] = None) -> ResponseT:
+  with concurrent.futures.ThreadPoolExecutor() as executor:
+    future = executor.submit(caller, request)
+    try:
+      return future.result(timeout=timeout)
+    except TooManyRequests as e:
+      _LOGGER.info(
+          'request could not be completed. got code %i from the service.',
+          e.code)
+      raise e
+    except concurrent.futures.TimeoutError:
+      if metrics_collector:
+        metrics_collector.timeout_requests.inc(1)
+      raise UserCodeTimeoutException(
+          f'Timeout {timeout} exceeded '
+          f'while completing request: {request}')
+    except RuntimeError:
+      if metrics_collector:
+        metrics_collector.failures.inc(1)
+      raise UserCodeExecutionException('could not complete request')
+
+
+class ExponentialBackOffRepeater(Repeater):
+  """Exponential BackOff Repeater uses exponential backoff retry strategy for
+  exceptions due to the remote service such as TooManyRequests (HTTP 429),
+  UserCodeTimeoutException, UserCodeQuotaException.
+
+  It utilizes the decorator
+  :func:`apache_beam.utils.retry.with_exponential_backoff`.
+  """
+  def __init__(self):
+    pass
+
+  @retry.with_exponential_backoff(
+      num_retries=2, retry_filter=retry_on_exception)
+  def repeat(
+      self,
+      caller: Caller[RequestT, ResponseT],
+      request: RequestT,
+      timeout: float,
+      metrics_collector: Optional[_MetricsCollector] = None) -> ResponseT:
+    """repeat method is called from the RequestResponseIO when
+    a repeater is enabled.
+
+    Args:
+      caller: :class:`apache_beam.io.requestresponse.Caller` object that
+        calls the API.
+      request: input request to repeat.
+      timeout: time to wait for the request to complete.
+      metrics_collector: (Optional) a
+        ``:class:`apache_beam.io.requestresponse._MetricsCollector``` object to
+        collect the metrics for RequestResponseIO.
+    """
+    return _execute_request(caller, request, timeout, metrics_collector)
+
+
+class NoOpsRepeater(Repeater):
+  """
+  NoOpsRepeater executes a request just once irrespective of any exception.
+  """
+  def repeat(
+      self,
+      caller: Caller[RequestT, ResponseT],
+      request: RequestT,
+      timeout: float,
+      metrics_collector: Optional[_MetricsCollector]) -> ResponseT:
+    return _execute_request(caller, request, timeout, metrics_collector)
+
+
+class CacheReader(abc.ABC):
+  """CacheReader provides mechanism to read from the cache."""
+  pass
+
+
+class CacheWriter(abc.ABC):
+  """CacheWriter provides mechanism to write to the cache."""
+  pass
+
+
+class PreCallThrottler(abc.ABC):
+  """PreCallThrottler provides a throttle mechanism before sending request."""
+  pass
+
+
+class DefaultThrottler(PreCallThrottler):
+  """Default throttler that uses
+  :class:`apache_beam.io.components.adaptive_throttler.AdaptiveThrottler`
+
+  Args:
+    window_ms (int): length of history to consider, in ms, to set throttling.
+    bucket_ms (int): granularity of time buckets that we store data in, in ms.
+    overload_ratio (float): the target ratio between requests sent and
+      successful requests. This is "K" in the formula in
+      https://landing.google.com/sre/book/chapters/handling-overload.html.
+    delay_secs (int): minimum number of seconds to throttle a request.
+  """
+  def __init__(
+      self,
+      window_ms: int = 1,
+      bucket_ms: int = 1,
+      overload_ratio: float = 2,
+      delay_secs: int = 5):
+    self.throttler = AdaptiveThrottler(
+        window_ms=window_ms, bucket_ms=bucket_ms, 
overload_ratio=overload_ratio)
+    self.delay_secs = delay_secs
+
+
+class RequestResponseIO(beam.PTransform[beam.PCollection[RequestT],
+                                        beam.PCollection[ResponseT]]):
+  """A :class:`RequestResponseIO` transform to read and write to APIs.
+
+  Processes an input :class:`~apache_beam.pvalue.PCollection` of requests
+  by making a call to the API as defined in :class:`Caller`'s `__call__`
+  and returns a :class:`~apache_beam.pvalue.PCollection` of responses.
+  """
+  def __init__(
+      self,
+      caller: Caller[RequestT, ResponseT],
+      timeout: Optional[float] = DEFAULT_TIMEOUT_SECS,
+      should_backoff: Optional[ShouldBackOff] = None,
+      repeater: Repeater = ExponentialBackOffRepeater(),
+      cache_reader: Optional[CacheReader] = None,
+      cache_writer: Optional[CacheWriter] = None,
+      throttler: PreCallThrottler = DefaultThrottler(),
+  ):
+    """
+    Instantiates a RequestResponseIO transform.
+
+    Args:
+      caller (~apache_beam.io.requestresponse.Caller): an implementation of
+        `Caller` object that makes call to the API.
+      timeout (float): timeout value in seconds to wait for response from API.
+      should_backoff (~apache_beam.io.requestresponse.ShouldBackOff):
+        (Optional) provides methods for backoff.
+      repeater (~apache_beam.io.requestresponse.Repeater): provides method to
+        repeat failed requests to API due to service errors. Defaults to
+        :class:`apache_beam.io.requestresponse.ExponentialBackOffRepeater` to
+        repeat requests with exponential backoff.
+      cache_reader (~apache_beam.io.requestresponse.CacheReader): (Optional)
+        provides methods to read external cache.
+      cache_writer (~apache_beam.io.requestresponse.CacheWriter): (Optional)
+        provides methods to write to external cache.
+      throttler (~apache_beam.io.requestresponse.PreCallThrottler):
+        provides methods to pre-throttle a request. Defaults to
+        :class:`apache_beam.io.requestresponse.DefaultThrottler` for
+        client-side adaptive throttling using
+        :class:`apache_beam.io.components.adaptive_throttler.AdaptiveThrottler`
+    """
+    self._caller = caller
+    self._timeout = timeout
+    self._should_backoff = should_backoff
+    if repeater:
+      self._repeater = repeater
+    else:
+      self._repeater = NoOpsRepeater()
+    self._cache_reader = cache_reader
+    self._cache_writer = cache_writer
+    self._throttler = throttler
+
+  def expand(
+      self,
+      requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]:
+    # TODO(riteshghorse): handle Cache and Throttle PTransforms when available.
+    if isinstance(self._throttler, DefaultThrottler):
+      return requests | _Call(
+          caller=self._caller,
+          timeout=self._timeout,
+          should_backoff=self._should_backoff,
+          repeater=self._repeater,
+          throttler=self._throttler)
+    else:
+      return requests | _Call(
+          caller=self._caller,
+          timeout=self._timeout,
+          should_backoff=self._should_backoff,
+          repeater=self._repeater)
+
+
+class _Call(beam.PTransform[beam.PCollection[RequestT],
+                            beam.PCollection[ResponseT]]):
+  """(Internal-only) PTransform that invokes a remote function on each element
+   of the input PCollection.
+
+  This PTransform uses a `Caller` object to invoke the actual API calls,
+  and uses ``__enter__`` and ``__exit__`` to manage setup and teardown of
+  clients when applicable. Additionally, a timeout value is specified to
+  regulate the duration of each call, defaults to 30 seconds.
+
+  Args:
+      caller (:class:`apache_beam.io.requestresponse.Caller`): a callable
+        object that invokes API call.
+      timeout (float): timeout value in seconds to wait for response from API.
+      should_backoff (~apache_beam.io.requestresponse.ShouldBackOff):
+        (Optional) provides methods for backoff.
+      repeater (~apache_beam.io.requestresponse.Repeater): (Optional) provides
+        methods to repeat requests to API.
+      throttler (~apache_beam.io.requestresponse.PreCallThrottler):
+        (Optional) provides methods to pre-throttle a request.
+  """
+  def __init__(
+      self,
+      caller: Caller[RequestT, ResponseT],
+      timeout: Optional[float] = DEFAULT_TIMEOUT_SECS,
+      should_backoff: Optional[ShouldBackOff] = None,
+      repeater: Repeater = None,
+      throttler: PreCallThrottler = None,
+  ):
+    self._caller = caller
+    self._timeout = timeout
+    self._should_backoff = should_backoff
+    self._repeater = repeater
+    self._throttler = throttler
+
+  def expand(
+      self,
+      requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]:
+    return requests | beam.ParDo(
+        _CallDoFn(self._caller, self._timeout, self._repeater, 
self._throttler))
+
+
+class _CallDoFn(beam.DoFn):
+  def setup(self):
+    self._caller.__enter__()
+    self._metrics_collector = _MetricsCollector(self._caller.__str__())
+    self._metrics_collector.setup_counter.inc(1)
+
+  def __init__(
+      self,
+      caller: Caller[RequestT, ResponseT],
+      timeout: float,
+      repeater: Repeater,
+      throttler: PreCallThrottler):
+    self._metrics_collector = None
+    self._caller = caller
+    self._timeout = timeout
+    self._repeater = repeater
+    self._throttler = throttler
+
+  def process(self, request: RequestT, *args, **kwargs):
+    self._metrics_collector.requests.inc(1)
+
+    is_throttled_request = False
+    if self._throttler:
+      while self._throttler.throttler.throttle_request(time.time() *
+                                                       MSEC_TO_SEC):
+        _LOGGER.info(
+            "Delaying request for %d seconds" % self._throttler.delay_secs)
+        time.sleep(self._throttler.delay_secs)
+        self._metrics_collector.throttled_secs.inc(self._throttler.delay_secs)
+        is_throttled_request = True
+
+    if is_throttled_request:
+      self._metrics_collector.throttled_requests.inc(1)
+
+    try:
+      req_time = time.time()
+      response = self._repeater.repeat(
+          self._caller, request, self._timeout, self._metrics_collector)
+      self._metrics_collector.responses.inc(1)
+      self._throttler.throttler.successful_request(req_time * MSEC_TO_SEC)
+      yield response
+    except Exception as e:
+      raise e
+
+  def teardown(self):
+    self._metrics_collector.teardown_counter.inc(1)
+    self._caller.__exit__(*sys.exc_info())
diff --git a/sdks/python/apache_beam/io/requestresponseio_it_test.py 
b/sdks/python/apache_beam/io/requestresponse_it_test.py
similarity index 86%
rename from sdks/python/apache_beam/io/requestresponseio_it_test.py
rename to sdks/python/apache_beam/io/requestresponse_it_test.py
index aae6b4e6ef2..396347c58d1 100644
--- a/sdks/python/apache_beam/io/requestresponseio_it_test.py
+++ b/sdks/python/apache_beam/io/requestresponse_it_test.py
@@ -16,6 +16,7 @@
 #
 import base64
 import sys
+import typing
 import unittest
 from dataclasses import dataclass
 from typing import Tuple
@@ -24,13 +25,18 @@ from typing import Union
 import urllib3
 
 import apache_beam as beam
-from apache_beam.io.requestresponseio import Caller
-from apache_beam.io.requestresponseio import RequestResponseIO
-from apache_beam.io.requestresponseio import UserCodeExecutionException
-from apache_beam.io.requestresponseio import UserCodeQuotaException
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.testing.test_pipeline import TestPipeline
 
+# pylint: disable=ungrouped-imports
+try:
+  from apache_beam.io.requestresponse import Caller
+  from apache_beam.io.requestresponse import RequestResponseIO
+  from apache_beam.io.requestresponse import UserCodeExecutionException
+  from apache_beam.io.requestresponse import UserCodeQuotaException
+except ImportError:
+  raise unittest.SkipTest('RequestResponseIO dependencies are not installed.')
+
 _HTTP_PATH = '/v1/echo'
 _PAYLOAD = base64.b64encode(bytes('payload', 'utf-8'))
 _HTTP_ENDPOINT_ADDRESS_FLAG = '--httpEndpointAddress'
@@ -61,28 +67,27 @@ class EchoITOptions(PipelineOptions):
         help='The ID for an allocated quota that should exceed.')
 
 
-# TODO(riteshghorse,damondouglas) replace Echo(Request|Response) with proto
-#   generated classes from .test-infra/mock-apis:
 @dataclass
-class EchoRequest:
+class EchoResponse:
   id: str
   payload: bytes
 
 
-@dataclass
-class EchoResponse:
+# TODO(riteshghorse,damondouglas) replace Echo(Request|Response) with proto
+#   generated classes from .test-infra/mock-apis:
+class Request(typing.NamedTuple):
   id: str
   payload: bytes
 
 
-class EchoHTTPCaller(Caller):
+class EchoHTTPCaller(Caller[Request, EchoResponse]):
   """Implements ``Caller`` to call the ``EchoServiceGrpc``'s HTTP handler.
     The purpose of ``EchoHTTPCaller`` is to support integration tests.
     """
   def __init__(self, url: str):
     self.url = url + _HTTP_PATH
 
-  def __call__(self, request: EchoRequest, *args, **kwargs) -> EchoResponse:
+  def __call__(self, request: Request, *args, **kwargs) -> EchoResponse:
     """Overrides ``Caller``'s call method invoking the
         ``EchoServiceGrpc``'s HTTP handler with an ``EchoRequest``, returning
         either a successful ``EchoResponse`` or throwing either a
@@ -129,7 +134,7 @@ class EchoHTTPCallerTestIT(unittest.TestCase):
   def setUp(self) -> None:
     client, options = EchoHTTPCallerTestIT._get_client_and_options()
 
-    req = EchoRequest(id=options.should_exceed_quota_id, payload=_PAYLOAD)
+    req = Request(id=options.should_exceed_quota_id, payload=_PAYLOAD)
     try:
       # The following is needed to exceed the API
       client(req)
@@ -148,7 +153,7 @@ class EchoHTTPCallerTestIT(unittest.TestCase):
   def test_given_valid_request_receives_response(self):
     client, options = EchoHTTPCallerTestIT._get_client_and_options()
 
-    req = EchoRequest(id=options.never_exceed_quota_id, payload=_PAYLOAD)
+    req = Request(id=options.never_exceed_quota_id, payload=_PAYLOAD)
 
     response: EchoResponse = client(req)
 
@@ -158,20 +163,20 @@ class EchoHTTPCallerTestIT(unittest.TestCase):
   def test_given_exceeded_quota_should_raise(self):
     client, options = EchoHTTPCallerTestIT._get_client_and_options()
 
-    req = EchoRequest(id=options.should_exceed_quota_id, payload=_PAYLOAD)
+    req = Request(id=options.should_exceed_quota_id, payload=_PAYLOAD)
 
     self.assertRaises(UserCodeQuotaException, lambda: client(req))
 
   def test_not_found_should_raise(self):
     client, _ = EchoHTTPCallerTestIT._get_client_and_options()
 
-    req = EchoRequest(id='i-dont-exist-quota-id', payload=_PAYLOAD)
+    req = Request(id='i-dont-exist-quota-id', payload=_PAYLOAD)
     self.assertRaisesRegex(
         UserCodeExecutionException, "Not Found", lambda: client(req))
 
   def test_request_response_io(self):
     client, options = EchoHTTPCallerTestIT._get_client_and_options()
-    req = EchoRequest(id=options.never_exceed_quota_id, payload=_PAYLOAD)
+    req = Request(id=options.never_exceed_quota_id, payload=_PAYLOAD)
     with TestPipeline(is_integration_test=True) as test_pipeline:
       output = (
           test_pipeline
diff --git a/sdks/python/apache_beam/io/requestresponse_test.py 
b/sdks/python/apache_beam/io/requestresponse_test.py
new file mode 100644
index 00000000000..6d807c2a8eb
--- /dev/null
+++ b/sdks/python/apache_beam/io/requestresponse_test.py
@@ -0,0 +1,156 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import time
+import unittest
+
+import apache_beam as beam
+from apache_beam.testing.test_pipeline import TestPipeline
+
+# pylint: disable=ungrouped-imports
+try:
+  from google.api_core.exceptions import TooManyRequests
+  from apache_beam.io.requestresponse import Caller, DefaultThrottler
+  from apache_beam.io.requestresponse import RequestResponseIO
+  from apache_beam.io.requestresponse import UserCodeExecutionException
+  from apache_beam.io.requestresponse import UserCodeTimeoutException
+  from apache_beam.io.requestresponse import retry_on_exception
+except ImportError:
+  raise unittest.SkipTest('RequestResponseIO dependencies are not installed.')
+
+
+class AckCaller(Caller[str, str]):
+  """AckCaller acknowledges the incoming request by returning a
+  request with ACK."""
+  def __enter__(self):
+    pass
+
+  def __call__(self, request: str):
+    return f"ACK: {request}"
+
+  def __exit__(self, exc_type, exc_val, exc_tb):
+    return None
+
+
+class CallerWithTimeout(AckCaller):
+  """CallerWithTimeout sleeps for 2 seconds before responding.
+  Used to test timeout in RequestResponseIO."""
+  def __call__(self, request: str, *args, **kwargs):
+    time.sleep(2)
+    return f"ACK: {request}"
+
+
+class CallerWithRuntimeError(AckCaller):
+  """CallerWithRuntimeError raises a `RuntimeError` for RequestResponseIO
+  to raise a UserCodeExecutionException."""
+  def __call__(self, request: str, *args, **kwargs):
+    if not request:
+      raise RuntimeError("Exception expected, not an error.")
+
+
+class CallerThatRetries(AckCaller):
+  def __init__(self):
+    self.count = -1
+
+  def __call__(self, request: str, *args, **kwargs):
+    try:
+      pass
+    except Exception as e:
+      raise e
+    finally:
+      self.count += 1
+      raise TooManyRequests('retries = %d' % self.count)
+
+
+class TestCaller(unittest.TestCase):
+  def test_valid_call(self):
+    caller = AckCaller()
+    with TestPipeline() as test_pipeline:
+      output = (
+          test_pipeline
+          | beam.Create(["sample_request"])
+          | RequestResponseIO(caller=caller))
+
+    self.assertIsNotNone(output)
+
+  def test_call_timeout(self):
+    caller = CallerWithTimeout()
+    with self.assertRaises(UserCodeTimeoutException):
+      with TestPipeline() as test_pipeline:
+        _ = (
+            test_pipeline
+            | beam.Create(["timeout_request"])
+            | RequestResponseIO(caller=caller, timeout=1))
+
+  def test_call_runtime_error(self):
+    caller = CallerWithRuntimeError()
+    with self.assertRaises(UserCodeExecutionException):
+      with TestPipeline() as test_pipeline:
+        _ = (
+            test_pipeline
+            | beam.Create([""])
+            | RequestResponseIO(caller=caller))
+
+  def test_retry_on_exception(self):
+    self.assertFalse(retry_on_exception(RuntimeError()))
+    self.assertTrue(retry_on_exception(TooManyRequests("HTTP 429")))
+
+  def test_caller_backoff_retry_strategy(self):
+    caller = CallerThatRetries()
+    with self.assertRaises(TooManyRequests) as cm:
+      with TestPipeline() as test_pipeline:
+        _ = (
+            test_pipeline
+            | beam.Create(["sample_request"])
+            | RequestResponseIO(caller=caller))
+    self.assertRegex(cm.exception.message, 'retries = 2')
+
+  def test_caller_no_retry_strategy(self):
+    caller = CallerThatRetries()
+    with self.assertRaises(TooManyRequests) as cm:
+      with TestPipeline() as test_pipeline:
+        _ = (
+            test_pipeline
+            | beam.Create(["sample_request"])
+            | RequestResponseIO(caller=caller, repeater=None))
+    self.assertRegex(cm.exception.message, 'retries = 0')
+
+  def test_default_throttler(self):
+    caller = CallerWithTimeout()
+    throttler = DefaultThrottler(
+        window_ms=10000, bucket_ms=5000, overload_ratio=1)
+    # manually override the number of received requests for testing.
+    throttler.throttler._all_requests.add(time.time() * 1000, 100)
+    test_pipeline = TestPipeline()
+    _ = (
+        test_pipeline
+        | beam.Create(['sample_request'])
+        | RequestResponseIO(caller=caller, throttler=throttler))
+    result = test_pipeline.run()
+    result.wait_until_finish()
+    metrics = result.metrics().query(
+        beam.metrics.MetricsFilter().with_name('throttled_requests'))
+    self.assertEqual(metrics['counters'][0].committed, 1)
+    metrics = result.metrics().query(
+        beam.metrics.MetricsFilter().with_name('cumulativeThrottlingSeconds'))
+    self.assertGreater(metrics['counters'][0].committed, 0)
+    metrics = result.metrics().query(
+        beam.metrics.MetricsFilter().with_name('responses'))
+    self.assertEqual(metrics['counters'][0].committed, 1)
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/apache_beam/io/requestresponseio.py 
b/sdks/python/apache_beam/io/requestresponseio.py
deleted file mode 100644
index 0ec586e6401..00000000000
--- a/sdks/python/apache_beam/io/requestresponseio.py
+++ /dev/null
@@ -1,218 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""``PTransform`` for reading from and writing to Web APIs."""
-import abc
-import concurrent.futures
-import contextlib
-import logging
-import sys
-from typing import Generic
-from typing import Optional
-from typing import TypeVar
-
-import apache_beam as beam
-from apache_beam.pvalue import PCollection
-
-RequestT = TypeVar('RequestT')
-ResponseT = TypeVar('ResponseT')
-
-DEFAULT_TIMEOUT_SECS = 30  # seconds
-
-_LOGGER = logging.getLogger(__name__)
-
-
-class UserCodeExecutionException(Exception):
-  """Base class for errors related to calling Web APIs."""
-
-
-class UserCodeQuotaException(UserCodeExecutionException):
-  """Extends ``UserCodeExecutionException`` to signal specifically that
-  the Web API client encountered a Quota or API overuse related error.
-  """
-
-
-class UserCodeTimeoutException(UserCodeExecutionException):
-  """Extends ``UserCodeExecutionException`` to signal a user code timeout."""
-
-
-class Caller(contextlib.AbstractContextManager, abc.ABC):
-  """Interface for user custom code intended for API calls.
-  For setup and teardown of clients when applicable, implement the
-  ``__enter__`` and ``__exit__`` methods respectively."""
-  @abc.abstractmethod
-  def __call__(self, request: RequestT, *args, **kwargs) -> ResponseT:
-    """Calls a Web API with the ``RequestT``  and returns a
-    ``ResponseT``. ``RequestResponseIO`` expects implementations of the
-    ``__call__`` method to throw either a ``UserCodeExecutionException``,
-    ``UserCodeQuotaException``, or ``UserCodeTimeoutException``.
-    """
-    pass
-
-  def __enter__(self):
-    return self
-
-  def __exit__(self, exc_type, exc_val, exc_tb):
-    return None
-
-
-class ShouldBackOff(abc.ABC):
-  """
-  ShouldBackOff provides mechanism to apply adaptive throttling.
-  """
-  pass
-
-
-class Repeater(abc.ABC):
-  """Repeater provides mechanism to repeat requests for a
-  configurable condition."""
-  pass
-
-
-class CacheReader(abc.ABC):
-  """CacheReader provides mechanism to read from the cache."""
-  pass
-
-
-class CacheWriter(abc.ABC):
-  """CacheWriter provides mechanism to write to the cache."""
-  pass
-
-
-class PreCallThrottler(abc.ABC):
-  """PreCallThrottler provides a throttle mechanism before sending request."""
-  pass
-
-
-class RequestResponseIO(beam.PTransform[beam.PCollection[RequestT],
-                                        beam.PCollection[ResponseT]]):
-  """A :class:`RequestResponseIO` transform to read and write to APIs.
-
-  Processes an input :class:`~apache_beam.pvalue.PCollection` of requests
-  by making a call to the API as defined in :class:`Caller`'s `__call__`
-  and returns a :class:`~apache_beam.pvalue.PCollection` of responses.
-  """
-  def __init__(
-      self,
-      caller: [Caller],
-      timeout: Optional[float] = DEFAULT_TIMEOUT_SECS,
-      should_backoff: Optional[ShouldBackOff] = None,
-      repeater: Optional[Repeater] = None,
-      cache_reader: Optional[CacheReader] = None,
-      cache_writer: Optional[CacheWriter] = None,
-      throttler: Optional[PreCallThrottler] = None,
-  ):
-    """
-    Instantiates a RequestResponseIO transform.
-
-    Args:
-      caller (~apache_beam.io.requestresponseio.Caller): an implementation of
-        `Caller` object that makes call to the API.
-      timeout (float): timeout value in seconds to wait for response from API.
-      should_backoff (~apache_beam.io.requestresponseio.ShouldBackOff):
-        (Optional) provides methods for backoff.
-      repeater (~apache_beam.io.requestresponseio.Repeater): (Optional)
-        provides methods to repeat requests to API.
-      cache_reader (~apache_beam.io.requestresponseio.CacheReader): (Optional)
-        provides methods to read external cache.
-      cache_writer (~apache_beam.io.requestresponseio.CacheWriter): (Optional)
-        provides methods to write to external cache.
-      throttler (~apache_beam.io.requestresponseio.PreCallThrottler):
-        (Optional) provides methods to pre-throttle a request.
-    """
-    self._caller = caller
-    self._timeout = timeout
-    self._should_backoff = should_backoff
-    self._repeater = repeater
-    self._cache_reader = cache_reader
-    self._cache_writer = cache_writer
-    self._throttler = throttler
-
-  def expand(self, requests: PCollection[RequestT]) -> PCollection[ResponseT]:
-    # TODO(riteshghorse): add Cache and Throttle PTransforms.
-    return requests | _Call(
-        caller=self._caller,
-        timeout=self._timeout,
-        should_backoff=self._should_backoff,
-        repeater=self._repeater)
-
-
-class _Call(beam.PTransform[beam.PCollection[RequestT],
-                            beam.PCollection[ResponseT]]):
-  """(Internal-only) PTransform that invokes a remote function on each element
-   of the input PCollection.
-
-  This PTransform uses a `Caller` object to invoke the actual API calls,
-  and uses ``__enter__`` and ``__exit__`` to manage setup and teardown of
-  clients when applicable. Additionally, a timeout value is specified to
-  regulate the duration of each call, defaults to 30 seconds.
-
-  Args:
-      caller (:class:`apache_beam.io.requestresponseio.Caller`): a callable
-        object that invokes API call.
-      timeout (float): timeout value in seconds to wait for response from API.
-  """
-  def __init__(
-      self,
-      caller: Caller,
-      timeout: Optional[float] = DEFAULT_TIMEOUT_SECS,
-      should_backoff: Optional[ShouldBackOff] = None,
-      repeater: Optional[Repeater] = None,
-  ):
-    """Initialize the _Call transform.
-    Args:
-      caller (:class:`apache_beam.io.requestresponseio.Caller`): a callable
-        object that invokes API call.
-      timeout (float): timeout value in seconds to wait for response from API.
-      should_backoff (~apache_beam.io.requestresponseio.ShouldBackOff):
-        (Optional) provides methods for backoff.
-      repeater (~apache_beam.io.requestresponseio.Repeater): (Optional) 
provides
-        methods to repeat requests to API.
-    """
-    self._caller = caller
-    self._timeout = timeout
-    self._should_backoff = should_backoff
-    self._repeater = repeater
-
-  def expand(
-      self,
-      requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]:
-    return requests | beam.ParDo(_CallDoFn(self._caller, self._timeout))
-
-
-class _CallDoFn(beam.DoFn, Generic[RequestT, ResponseT]):
-  def setup(self):
-    self._caller.__enter__()
-
-  def __init__(self, caller: Caller, timeout: float):
-    self._caller = caller
-    self._timeout = timeout
-
-  def process(self, request, *args, **kwargs):
-    with concurrent.futures.ThreadPoolExecutor() as executor:
-      future = executor.submit(self._caller, request)
-      try:
-        yield future.result(timeout=self._timeout)
-      except concurrent.futures.TimeoutError:
-        raise UserCodeTimeoutException(
-            f'Timeout {self._timeout} exceeded '
-            f'while completing request: {request}')
-      except RuntimeError:
-        raise UserCodeExecutionException('could not complete request')
-
-  def teardown(self):
-    self._caller.__exit__(*sys.exc_info())
diff --git a/sdks/python/apache_beam/io/requestresponseio_test.py 
b/sdks/python/apache_beam/io/requestresponseio_test.py
deleted file mode 100644
index 2828a357887..00000000000
--- a/sdks/python/apache_beam/io/requestresponseio_test.py
+++ /dev/null
@@ -1,88 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import time
-import unittest
-
-import apache_beam as beam
-from apache_beam.io.requestresponseio import Caller
-from apache_beam.io.requestresponseio import RequestResponseIO
-from apache_beam.io.requestresponseio import UserCodeExecutionException
-from apache_beam.io.requestresponseio import UserCodeTimeoutException
-from apache_beam.testing.test_pipeline import TestPipeline
-
-
-class AckCaller(Caller):
-  """AckCaller acknowledges the incoming request by returning a
-  request with ACK."""
-  def __enter__(self):
-    pass
-
-  def __call__(self, request: str):
-    return f"ACK: {request}"
-
-  def __exit__(self, exc_type, exc_val, exc_tb):
-    return None
-
-
-class CallerWithTimeout(AckCaller):
-  """CallerWithTimeout sleeps for 2 seconds before responding.
-  Used to test timeout in RequestResponseIO."""
-  def __call__(self, request: str, *args, **kwargs):
-    time.sleep(2)
-    return f"ACK: {request}"
-
-
-class CallerWithRuntimeError(AckCaller):
-  """CallerWithRuntimeError raises a `RuntimeError` for RequestResponseIO
-  to raise a UserCodeExecutionException."""
-  def __call__(self, request: str, *args, **kwargs):
-    if not request:
-      raise RuntimeError("Exception expected, not an error.")
-
-
-class TestCaller(unittest.TestCase):
-  def test_valid_call(self):
-    caller = AckCaller()
-    with TestPipeline() as test_pipeline:
-      output = (
-          test_pipeline
-          | beam.Create(["sample_request"])
-          | RequestResponseIO(caller=caller))
-
-    self.assertIsNotNone(output)
-
-  def test_call_timeout(self):
-    caller = CallerWithTimeout()
-    with self.assertRaises(UserCodeTimeoutException):
-      with TestPipeline() as test_pipeline:
-        _ = (
-            test_pipeline
-            | beam.Create(["timeout_request"])
-            | RequestResponseIO(caller=caller, timeout=1))
-
-  def test_call_runtime_error(self):
-    caller = CallerWithRuntimeError()
-    with self.assertRaises(UserCodeExecutionException):
-      with TestPipeline() as test_pipeline:
-        _ = (
-            test_pipeline
-            | beam.Create([""])
-            | RequestResponseIO(caller=caller))
-
-
-if __name__ == '__main__':
-  unittest.main()
diff --git a/sdks/python/apache_beam/transforms/enrichment.py 
b/sdks/python/apache_beam/transforms/enrichment.py
new file mode 100644
index 00000000000..a2f961be643
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/enrichment.py
@@ -0,0 +1,137 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import logging
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Optional
+from typing import TypeVar
+
+import apache_beam as beam
+from apache_beam.io.requestresponse import DEFAULT_TIMEOUT_SECS
+from apache_beam.io.requestresponse import Caller
+from apache_beam.io.requestresponse import DefaultThrottler
+from apache_beam.io.requestresponse import ExponentialBackOffRepeater
+from apache_beam.io.requestresponse import PreCallThrottler
+from apache_beam.io.requestresponse import Repeater
+from apache_beam.io.requestresponse import RequestResponseIO
+
+__all__ = [
+    "EnrichmentSourceHandler",
+    "Enrichment",
+    "cross_join",
+]
+
+InputT = TypeVar('InputT')
+OutputT = TypeVar('OutputT')
+
+JoinFn = Callable[[Dict[str, Any], Dict[str, Any]], beam.Row]
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def cross_join(left: Dict[str, Any], right: Dict[str, Any]) -> beam.Row:
+  """cross_join performs a cross join on two `dict` objects.
+
+    Joins the columns of the right row onto the left row.
+
+    Args:
+      left (Dict[str, Any]): input request dictionary.
+      right (Dict[str, Any]): response dictionary from the API.
+
+    Returns:
+      `beam.Row` containing the merged columns.
+  """
+  for k, v in right.items():
+    if k not in left:
+      # Don't override the values in left.
+      left[k] = v
+    elif left[k] != v:
+      _LOGGER.warning(
+          '%s exists in the input row as well as the row fetched '
+          'from API but they have different values - %s and %s. Using the '
+          'input value (%s) for the enriched row. You can override this '
+          'behavior by passing a custom `join_fn` to Enrichment transform.',
+          k, left[k], v, left[k])
+  return beam.Row(**left)
+
+
+class EnrichmentSourceHandler(Caller[InputT, OutputT]):
+  """Wrapper class for :class:`apache_beam.io.requestresponse.Caller`.
+
+  Ensure that the implementation of ``__call__`` method returns a tuple
+  of `beam.Row`  objects.
+  """
+  pass
+
+
+class Enrichment(beam.PTransform[beam.PCollection[InputT],
+                                 beam.PCollection[OutputT]]):
+  """A :class:`apache_beam.transforms.enrichment.Enrichment` transform to
+  enrich elements in a PCollection.
+  **NOTE:** This transform and its implementation are under development and
+  do not provide backward compatibility guarantees.
+  Uses the :class:`apache_beam.transforms.enrichment.EnrichmentSourceHandler`
+  to enrich elements by joining the metadata from external source.
+
+  Processes an input :class:`~apache_beam.pvalue.PCollection` of `beam.Row` by
+  applying a :class:`apache_beam.transforms.enrichment.EnrichmentSourceHandler`
+  to each element and returning the enriched
+  :class:`~apache_beam.pvalue.PCollection`.
+
+  Args:
+    source_handler: Handles source lookup and metadata retrieval.
+      Implements the
+      :class:`apache_beam.transforms.enrichment.EnrichmentSourceHandler`
+    join_fn: A lambda function to join original element with lookup metadata.
+      Defaults to `CROSS_JOIN`.
+    timeout: (Optional) timeout for source requests. Defaults to 30 seconds.
+    repeater (~apache_beam.io.requestresponse.Repeater): provides method to
+      repeat failed requests to API due to service errors. Defaults to
+      :class:`apache_beam.io.requestresponse.ExponentialBackOffRepeater` to
+      repeat requests with exponential backoff.
+    throttler (~apache_beam.io.requestresponse.PreCallThrottler):
+      provides methods to pre-throttle a request. Defaults to
+      :class:`apache_beam.io.requestresponse.DefaultThrottler` for
+      client-side adaptive throttling using
+      :class:`apache_beam.io.components.adaptive_throttler.AdaptiveThrottler`.
+  """
+  def __init__(
+      self,
+      source_handler: EnrichmentSourceHandler,
+      join_fn: JoinFn = cross_join,
+      timeout: Optional[float] = DEFAULT_TIMEOUT_SECS,
+      repeater: Repeater = ExponentialBackOffRepeater(),
+      throttler: PreCallThrottler = DefaultThrottler(),
+  ):
+    self._source_handler = source_handler
+    self._join_fn = join_fn
+    self._timeout = timeout
+    self._repeater = repeater
+    self._throttler = throttler
+
+  def expand(self,
+             input_row: beam.PCollection[InputT]) -> beam.PCollection[OutputT]:
+    fetched_data = input_row | RequestResponseIO(
+        caller=self._source_handler,
+        timeout=self._timeout,
+        repeater=self._repeater,
+        throttler=self._throttler)
+
+    # EnrichmentSourceHandler returns a tuple of (request,response).
+    return fetched_data | beam.Map(
+        lambda x: self._join_fn(x[0]._asdict(), x[1]._asdict()))
diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/__init__.py 
b/sdks/python/apache_beam/transforms/enrichment_handlers/__init__.py
new file mode 100644
index 00000000000..cce3acad34a
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/enrichment_handlers/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py 
b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py
new file mode 100644
index 00000000000..86ff2f3b8e7
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py
@@ -0,0 +1,151 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import logging
+from enum import Enum
+from typing import Any
+from typing import Dict
+from typing import Optional
+
+from google.api_core.exceptions import NotFound
+from google.cloud import bigtable
+from google.cloud.bigtable import Client
+from google.cloud.bigtable.row_filters import CellsColumnLimitFilter
+from google.cloud.bigtable.row_filters import RowFilter
+
+import apache_beam as beam
+from apache_beam.transforms.enrichment import EnrichmentSourceHandler
+
+__all__ = [
+    'EnrichWithBigTable',
+    'ExceptionLevel',
+]
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class ExceptionLevel(Enum):
+  """ExceptionLevel defines the exception level options to either
+  log a warning, or raise an exception, or do nothing when a BigTable query
+  returns an empty row.
+
+  Members:
+    - RAISE: Raise the exception.
+    - WARN: Log a warning for exception without raising it.
+    - QUIET: Neither log nor raise the exception.
+  """
+  RAISE = 0
+  WARN = 1
+  QUIET = 2
+
+
+class EnrichWithBigTable(EnrichmentSourceHandler[beam.Row, beam.Row]):
+  """EnrichWithBigTable is a handler for
+  :class:`apache_beam.transforms.enrichment.Enrichment` transform to interact
+  with GCP BigTable.
+
+  Args:
+    project_id (str): GCP project-id of the BigTable cluster.
+    instance_id (str): GCP instance-id of the BigTable cluster.
+    table_id (str): GCP table-id of the BigTable.
+    row_key (str): unique row-key field name from the input `beam.Row` object
+      to use as `row_key` for BigTable querying.
+    row_filter: a ``:class:`google.cloud.bigtable.row_filters.RowFilter``` to
+      filter data read with ``read_row()``.
+      Defaults to `CellsColumnLimitFilter(1)`.
+    app_profile_id (Optional[str]): App profile ID to use for BigTable.
+      See https://cloud.google.com/bigtable/docs/app-profiles for more details.
+    encoding (str): encoding type to convert the string to bytes and vice-versa
+      from BigTable. Default is `utf-8`.
+    exception_level: a `enum.Enum` value from
+      ``apache_beam.transforms.enrichment_handlers.bigtable.ExceptionLevel``
+      to set the level when an empty row is returned from the BigTable query.
+      Defaults to ``ExceptionLevel.WARN``.
+  """
+  def __init__(
+      self,
+      project_id: str,
+      instance_id: str,
+      table_id: str,
+      row_key: str,
+      row_filter: Optional[RowFilter] = CellsColumnLimitFilter(1),
+      app_profile_id: Optional[str] = None,
+      encoding: str = 'utf-8',
+      exception_level: ExceptionLevel = ExceptionLevel.WARN,
+  ):
+    self._project_id = project_id
+    self._instance_id = instance_id
+    self._table_id = table_id
+    self._row_key = row_key
+    self._row_filter = row_filter
+    self._app_profile_id = app_profile_id
+    self._encoding = encoding
+    self._exception_level = exception_level
+
+  def __enter__(self):
+    """connect to the Google BigTable cluster."""
+    self.client = Client(project=self._project_id)
+    self.instance = self.client.instance(self._instance_id)
+    self._table = bigtable.table.Table(
+        table_id=self._table_id,
+        instance=self.instance,
+        app_profile_id=self._app_profile_id)
+
+  def __call__(self, request: beam.Row, *args, **kwargs):
+    """
+    Reads a row from the GCP BigTable and returns
+    a `Tuple` of request and response.
+
+    Args:
+    request: the input `beam.Row` to enrich.
+    """
+    response_dict: Dict[str, Any] = {}
+    row_key_str: str = ""
+    try:
+      request_dict = request._asdict()
+      row_key_str = str(request_dict[self._row_key])
+      row_key = row_key_str.encode(self._encoding)
+      row = self._table.read_row(row_key, filter_=self._row_filter)
+      if row:
+        for cf_id, cf_v in row.cells.items():
+          response_dict[cf_id] = {}
+          for k, v in cf_v.items():
+            response_dict[cf_id][k.decode(self._encoding)] = \
+              v[0].value.decode(self._encoding)
+      elif self._exception_level == ExceptionLevel.WARN:
+        _LOGGER.warning(
+            'no matching row found for row_key: %s '
+            'with row_filter: %s' % (row_key_str, self._row_filter))
+      elif self._exception_level == ExceptionLevel.RAISE:
+        raise ValueError(
+            'no matching row found for row_key: %s '
+            'with row_filter=%s' % (row_key_str, self._row_filter))
+    except KeyError:
+      raise KeyError('row_key %s not found in input PCollection.' % 
self._row_key)
+    except NotFound:
+      raise NotFound(
+          'GCP BigTable cluster `%s:%s:%s` not found.' %
+          (self._project_id, self._instance_id, self._table_id))
+    except Exception:
+      raise
+
+    return request, beam.Row(**response_dict)
+
+  def __exit__(self, exc_type, exc_val, exc_tb):
+    """Clean the instantiated BigTable client."""
+    self.client = None
+    self.instance = None
+    self._table = None
diff --git 
a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py 
b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py
new file mode 100644
index 00000000000..dd48c8e5ef4
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py
@@ -0,0 +1,300 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import datetime
+import unittest
+from typing import Dict
+from typing import List
+from typing import NamedTuple
+
+import pytest
+
+import apache_beam as beam
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import BeamAssertException
+
+# pylint: disable=ungrouped-imports
+try:
+  from google.api_core.exceptions import NotFound
+  from google.cloud.bigtable import Client
+  from google.cloud.bigtable.row_filters import ColumnRangeFilter
+  from apache_beam.transforms.enrichment import Enrichment
+  from apache_beam.transforms.enrichment_handlers.bigtable import 
EnrichWithBigTable
+  from apache_beam.transforms.enrichment_handlers.bigtable import 
ExceptionLevel
+except ImportError:
+  raise unittest.SkipTest('GCP BigTable dependencies are not installed.')
+
+
+class ValidateResponse(beam.DoFn):
+  """ValidateResponse validates if a PCollection of `beam.Row`
+  has the required fields."""
+  def __init__(
+      self,
+      n_fields: int,
+      fields: List[str],
+      enriched_fields: Dict[str, List[str]]):
+    self.n_fields = n_fields
+    self._fields = fields
+    self._enriched_fields = enriched_fields
+
+  def process(self, element: beam.Row, *args, **kwargs):
+    element_dict = element.as_dict()
+    if len(element_dict.keys()) != self.n_fields:
+      raise BeamAssertException(
+          "Expected %d fields in enriched PCollection:" % self.n_fields)
+
+    for field in self._fields:
+      if field not in element_dict or element_dict[field] is None:
+        raise BeamAssertException(f"Expected a not None field: {field}")
+
+    for column_family, columns in self._enriched_fields.items():
+      if (len(element_dict[column_family]) != len(columns) or
+          not all(key in element_dict[column_family] for key in columns)):
+        raise BeamAssertException(
+            "Response from bigtable should contain a %s column_family with "
+            "%s keys." % (column_family, columns))
+
+
+class _Currency(NamedTuple):
+  s_id: int
+  id: str
+
+
+def create_rows(table):
+  product_id = 'product_id'
+  product_name = 'product_name'
+  product_stock = 'product_stock'
+
+  column_family_id = "product"
+  products = [
+      {
+          'product_id': 1, 'product_name': 'pixel 5', 'product_stock': 2
+      },
+      {
+          'product_id': 2, 'product_name': 'pixel 6', 'product_stock': 4
+      },
+      {
+          'product_id': 3, 'product_name': 'pixel 7', 'product_stock': 20
+      },
+      {
+          'product_id': 4, 'product_name': 'pixel 8', 'product_stock': 10
+      },
+      {
+          'product_id': 5, 'product_name': 'iphone 11', 'product_stock': 3
+      },
+      {
+          'product_id': 6, 'product_name': 'iphone 12', 'product_stock': 7
+      },
+      {
+          'product_id': 7, 'product_name': 'iphone 13', 'product_stock': 8
+      },
+      {
+          'product_id': 8, 'product_name': 'iphone 14', 'product_stock': 3
+      },
+  ]
+
+  for item in products:
+    row_key = str(item[product_id]).encode()
+    row = table.direct_row(row_key)
+    row.set_cell(
+        column_family_id,
+        product_id.encode(),
+        str(item[product_id]),
+        timestamp=datetime.datetime.utcnow())
+    row.set_cell(
+        column_family_id,
+        product_name.encode(),
+        item[product_name],
+        timestamp=datetime.datetime.utcnow())
+    row.set_cell(
+        column_family_id,
+        product_stock.encode(),
+        str(item[product_stock]),
+        timestamp=datetime.datetime.utcnow())
+    row.commit()
+
+
[email protected]_postcommit
+class TestBigTableEnrichment(unittest.TestCase):
+  def setUp(self):
+    self.project_id = 'apache-beam-testing'
+    self.instance_id = 'beam-test'
+    self.table_id = 'bigtable-enrichment-test'
+    self.req = [
+        beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1),
+        beam.Row(sale_id=3, customer_id=3, product_id=2, quantity=3),
+        beam.Row(sale_id=5, customer_id=5, product_id=4, quantity=2),
+        beam.Row(sale_id=7, customer_id=7, product_id=1, quantity=1),
+    ]
+    self.row_key = 'product_id'
+    self.column_family_id = 'product'
+    client = Client(project=self.project_id)
+    instance = client.instance(self.instance_id)
+    self.table = instance.table(self.table_id)
+    create_rows(self.table)
+
+  def tearDown(self) -> None:
+    self.table = None
+
+  def test_enrichment_with_bigtable(self):
+    expected_fields = [
+        'sale_id', 'customer_id', 'product_id', 'quantity', 'product'
+    ]
+    expected_enriched_fields = {
+        'product': ['product_id', 'product_name', 'product_stock'],
+    }
+    bigtable = EnrichWithBigTable(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        table_id=self.table_id,
+        row_key=self.row_key)
+    with TestPipeline(is_integration_test=True) as test_pipeline:
+      _ = (
+          test_pipeline
+          | "Create" >> beam.Create(self.req)
+          | "Enrich W/ BigTable" >> Enrichment(bigtable)
+          | "Validate Response" >> beam.ParDo(
+              ValidateResponse(
+                  len(expected_fields),
+                  expected_fields,
+                  expected_enriched_fields)))
+
+  def test_enrichment_with_bigtable_row_filter(self):
+    expected_fields = [
+        'sale_id', 'customer_id', 'product_id', 'quantity', 'product'
+    ]
+    expected_enriched_fields = {
+        'product': ['product_name', 'product_stock'],
+    }
+    start_column = 'product_name'.encode()
+    column_filter = ColumnRangeFilter(self.column_family_id, start_column)
+    bigtable = EnrichWithBigTable(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        table_id=self.table_id,
+        row_key=self.row_key,
+        row_filter=column_filter)
+    with TestPipeline(is_integration_test=True) as test_pipeline:
+      _ = (
+          test_pipeline
+          | "Create" >> beam.Create(self.req)
+          | "Enrich W/ BigTable" >> Enrichment(bigtable)
+          | "Validate Response" >> beam.ParDo(
+              ValidateResponse(
+                  len(expected_fields),
+                  expected_fields,
+                  expected_enriched_fields)))
+
+  def test_enrichment_with_bigtable_no_enrichment(self):
+    # row_key which is product_id=11 doesn't exist, so the enriched field
+    # won't be added. Hence, the response is same as the request.
+    expected_fields = ['sale_id', 'customer_id', 'product_id', 'quantity']
+    expected_enriched_fields = {}
+    bigtable = EnrichWithBigTable(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        table_id=self.table_id,
+        row_key=self.row_key)
+    req = [beam.Row(sale_id=1, customer_id=1, product_id=11, quantity=1)]
+    with TestPipeline(is_integration_test=True) as test_pipeline:
+      _ = (
+          test_pipeline
+          | "Create" >> beam.Create(req)
+          | "Enrich W/ BigTable" >> Enrichment(bigtable)
+          | "Validate Response" >> beam.ParDo(
+              ValidateResponse(
+                  len(expected_fields),
+                  expected_fields,
+                  expected_enriched_fields)))
+
+  def test_enrichment_with_bigtable_bad_row_filter(self):
+    # in case of a bad column filter, that is, incorrect column_family_id and
+    # columns, no enrichment is done. If the column_family is correct but not
+    # column names then all columns in that column_family are returned.
+    start_column = 'car_name'.encode()
+    column_filter = ColumnRangeFilter('car_name', start_column)
+    bigtable = EnrichWithBigTable(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        table_id=self.table_id,
+        row_key=self.row_key,
+        row_filter=column_filter)
+    with self.assertRaises(NotFound):
+      test_pipeline = beam.Pipeline()
+      _ = (
+          test_pipeline
+          | "Create" >> beam.Create(self.req)
+          | "Enrich W/ BigTable" >> Enrichment(bigtable))
+      res = test_pipeline.run()
+      res.wait_until_finish()
+
+  def test_enrichment_with_bigtable_raises_key_error(self):
+    """raises a `KeyError` when the row_key doesn't exist in
+    the input PCollection."""
+    bigtable = EnrichWithBigTable(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        table_id=self.table_id,
+        row_key='car_name')
+    with self.assertRaises(KeyError):
+      test_pipeline = beam.Pipeline()
+      _ = (
+          test_pipeline
+          | "Create" >> beam.Create(self.req)
+          | "Enrich W/ BigTable" >> Enrichment(bigtable))
+      res = test_pipeline.run()
+      res.wait_until_finish()
+
+  def test_enrichment_with_bigtable_raises_not_found(self):
+    """raises a `NotFound` exception when the GCP BigTable Cluster
+    doesn't exist."""
+    bigtable = EnrichWithBigTable(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        table_id='invalid_table',
+        row_key=self.row_key)
+    with self.assertRaises(NotFound):
+      test_pipeline = beam.Pipeline()
+      _ = (
+          test_pipeline
+          | "Create" >> beam.Create(self.req)
+          | "Enrich W/ BigTable" >> Enrichment(bigtable))
+      res = test_pipeline.run()
+      res.wait_until_finish()
+
+  def test_enrichment_with_bigtable_exception_level(self):
+    """raises a `ValueError` exception when the GCP BigTable query returns
+    an empty row."""
+    bigtable = EnrichWithBigTable(
+        project_id=self.project_id,
+        instance_id=self.instance_id,
+        table_id=self.table_id,
+        row_key=self.row_key,
+        exception_level=ExceptionLevel.RAISE)
+    req = [beam.Row(sale_id=1, customer_id=1, product_id=11, quantity=1)]
+    with self.assertRaises(ValueError):
+      test_pipeline = beam.Pipeline()
+      _ = (
+          test_pipeline
+          | "Create" >> beam.Create(req)
+          | "Enrich W/ BigTable" >> Enrichment(bigtable))
+      res = test_pipeline.run()
+      res.wait_until_finish()
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/apache_beam/transforms/enrichment_it_test.py 
b/sdks/python/apache_beam/transforms/enrichment_it_test.py
new file mode 100644
index 00000000000..89842cb18be
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/enrichment_it_test.py
@@ -0,0 +1,162 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import time
+import unittest
+from typing import List
+from typing import NamedTuple
+from typing import Tuple
+from typing import Union
+
+import pytest
+import urllib3
+
+import apache_beam as beam
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import BeamAssertException
+
+# pylint: disable=ungrouped-imports
+try:
+  from apache_beam.io.requestresponse import UserCodeExecutionException
+  from apache_beam.io.requestresponse import UserCodeQuotaException
+  from apache_beam.io.requestresponse_it_test import _PAYLOAD
+  from apache_beam.io.requestresponse_it_test import EchoITOptions
+  from apache_beam.transforms.enrichment import Enrichment
+  from apache_beam.transforms.enrichment import EnrichmentSourceHandler
+except ImportError:
+  raise unittest.SkipTest('RequestResponseIO dependencies are not installed.')
+
+
+class Request(NamedTuple):
+  """Request record sent to the Echo mock API: an id plus a raw payload."""
+  id: str
+  payload: bytes
+
+
+def _custom_join(left, right):
+  """custom_join returns the id and resp_payload along with a timestamp"""
+  right['timestamp'] = time.time()
+  return beam.Row(**right)
+
+
+class SampleHTTPEnrichment(EnrichmentSourceHandler[Request, beam.Row]):
+  """Implements ``EnrichmentSourceHandler`` to call the ``EchoServiceGrpc``'s
+  HTTP handler.
+  """
+  def __init__(self, url: str):
+    self.url = url + '/v1/echo'  # append path to the mock API.
+
+  def __call__(self, request: Request, *args, **kwargs):
+    """Overrides ``Caller``'s call method invoking the
+    ``EchoServiceGrpc``'s HTTP handler with an `dict`, returning
+    either a successful ``Tuple[dict,dict]`` or throwing either a
+    ``UserCodeExecutionException``, ``UserCodeTimeoutException``,
+    or a ``UserCodeQuotaException``.
+    """
+    try:
+      resp = urllib3.request(
+          "POST",
+          self.url,
+          json={
+              "id": request.id, "payload": str(request.payload, 'utf-8')
+          },
+          retries=False)
+
+      if resp.status < 300:
+        resp_body = resp.json()
+        resp_id = resp_body['id']
+        payload = resp_body['payload']
+        return (
+            request, beam.Row(id=resp_id, resp_payload=bytes(payload, 
'utf-8')))
+
+      if resp.status == 429:  # Too Many Requests
+        raise UserCodeQuotaException(resp.reason)
+      elif resp.status != 200:
+        raise UserCodeExecutionException(resp.status, resp.reason, request)
+
+    except urllib3.exceptions.HTTPError as e:
+      raise UserCodeExecutionException(e)
+
+
+class ValidateFields(beam.DoFn):
+  """ValidateFields validates if a PCollection of `beam.Row`
+  has certain fields."""
+  def __init__(self, n_fields: int, fields: List[str]):
+    self.n_fields = n_fields
+    self._fields = fields
+
+  def process(self, element: beam.Row, *args, **kwargs):
+    element_dict = element.as_dict()
+    if len(element_dict.keys()) != self.n_fields:
+      raise BeamAssertException(
+          "Expected %d fields in enriched PCollection:"
+          " id, payload and resp_payload" % self.n_fields)
+
+    for field in self._fields:
+      if field not in element_dict or element_dict[field] is None:
+        raise BeamAssertException(f"Expected a not None field: {field}")
+
+
[email protected]_postcommit
+class TestEnrichment(unittest.TestCase):
+  """Integration tests for the Enrichment transform against the mock
+  Echo HTTP service; skipped unless the endpoint address is configured."""
+  # Shared across tests; populated once in setUpClass.
+  options: Union[EchoITOptions, None] = None
+  client: Union[SampleHTTPEnrichment, None] = None
+
+  @classmethod
+  def setUpClass(cls) -> None:
+    cls.options = EchoITOptions()
+    http_endpoint_address = cls.options.http_endpoint_address
+    if not http_endpoint_address or http_endpoint_address == '':
+      raise unittest.SkipTest('HTTP_ENDPOINT_ADDRESS is required.')
+    cls.client = SampleHTTPEnrichment(http_endpoint_address)
+
+  @classmethod
+  def _get_client_and_options(
+      cls) -> Tuple[SampleHTTPEnrichment, EchoITOptions]:
+    # Narrow the Optional class attributes; setUpClass guarantees both.
+    assert cls.options is not None
+    assert cls.client is not None
+    return cls.client, cls.options
+
+  def test_http_enrichment(self):
+    """Tests Enrichment Transform against the Mock-API HTTP endpoint
+    with the default cross join."""
+    client, options = TestEnrichment._get_client_and_options()
+    req = Request(id=options.never_exceed_quota_id, payload=_PAYLOAD)
+    # Default join keeps request fields (id, payload) plus resp_payload.
+    fields = ['id', 'payload', 'resp_payload']
+    with TestPipeline(is_integration_test=True) as test_pipeline:
+      _ = (
+          test_pipeline
+          | 'Create PCollection' >> beam.Create([req])
+          | 'Enrichment Transform' >> Enrichment(client)
+          | 'Assert Fields' >> beam.ParDo(
+              ValidateFields(len(fields), fields=fields)))
+
+  def test_http_enrichment_custom_join(self):
+    """Tests Enrichment Transform against the Mock-API HTTP endpoint
+    with a custom join function."""
+    client, options = TestEnrichment._get_client_and_options()
+    req = Request(id=options.never_exceed_quota_id, payload=_PAYLOAD)
+    # _custom_join yields only response fields plus a timestamp.
+    fields = ['id', 'resp_payload', 'timestamp']
+    with TestPipeline(is_integration_test=True) as test_pipeline:
+      _ = (
+          test_pipeline
+          | 'Create PCollection' >> beam.Create([req])
+          | 'Enrichment Transform' >> Enrichment(client, join_fn=_custom_join)
+          | 'Assert Fields' >> beam.ParDo(
+              ValidateFields(len(fields), fields=fields)))
+
+
+# Allow running the integration tests directly with the unittest runner.
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/apache_beam/transforms/enrichment_test.py b/sdks/python/apache_beam/transforms/enrichment_test.py
new file mode 100644
index 00000000000..23b5f1828c1
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/enrichment_test.py
@@ -0,0 +1,41 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import unittest
+
+import apache_beam as beam
+
+# pylint: disable=ungrouped-imports
+try:
+  from apache_beam.transforms.enrichment import cross_join
+except ImportError:
+  raise unittest.SkipTest('RequestResponseIO dependencies are not installed.')
+
+
+class TestEnrichmentTransform(unittest.TestCase):
+  def test_cross_join(self):
+    left = {'id': 1, 'key': 'city'}
+    right = {'id': 1, 'value': 'durham'}
+    expected = beam.Row(id=1, key='city', value='durham')
+    output = cross_join(left, right)
+    self.assertEqual(expected, output)
+
+
+# Direct execution: raise log verbosity to INFO, then run the tests.
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()

Reply via email to