riteshghorse commented on code in PR #30307:
URL: https://github.com/apache/beam/pull/30307#discussion_r1491178369


##########
sdks/python/apache_beam/io/requestresponse.py:
##########
@@ -411,3 +387,440 @@ def process(self, request: RequestT, *args, **kwargs):
   def teardown(self):
     self._metrics_collector.teardown_counter.inc(1)
     self._caller.__exit__(*sys.exc_info())
+
+
+class Cache(abc.ABC):
+  """Base Cache class for
+  :class:`apache_beam.io.requestresponse.RequestResponseIO`.
+
+  For adding cache support to RequestResponseIO, implement this class.
+  """
+  @abc.abstractmethod
+  def get_read(self):
+    """get_read returns a PTransform that reads from the cache."""
+    pass
+
+  @abc.abstractmethod
+  def get_write(self):
+    """get_write returns a PTransform that writes to the cache."""
+    pass
+
+  @abc.abstractmethod
+  def has_request_coder(self) -> bool:
+    """returns `True` if the request coder is present."""
+    pass
+
+  @abc.abstractmethod
+  def set_request_coder(self, request_coder: coders.Coder):
+    """sets the request coder to use with Cache."""
+    pass
+
+  @abc.abstractmethod
+  def set_response_coder(self, response_coder: coders.Coder):
+    """sets the response coder to use with Cache."""
+    pass
+
+  @abc.abstractmethod
+  def set_source_caller(self, caller: Caller):
+    """(Internal-only) This method allows
+    :class:`apache_beam.io.requestresponse.RequestResponseIO` to pull
+    cache requests from respective callers."""
+    pass
+
+
+class _RedisMode(enum.Enum):
+  """
+  Mode of operation for redis cache when using
+  :class:`apache_beam.io.requestresponse.RedisCaller`.
+  """
+  READ = 0
+  WRITE = 1
+
+
+class RedisCaller(Caller):
+  """`RedisCaller` is an implementation of
+  :class:`apache_beam.io.requestresponse.Caller` for Redis client.
+
+  It provides the functionality for making requests to Redis server using
+  :class:`apache_beam.io.requestresponse.RequestResponseIO`.
+  """
+  def __init__(
+      self,
+      host: str,
+      port: int,
+      time_to_live: Union[int, timedelta],
+      *,
+      request_coder: Optional[coders.Coder],
+      response_coder: Optional[coders.Coder],
+      kwargs: Optional[Dict[str, Any]] = None,
+      source_caller: Optional[Caller] = None,
+      mode: _RedisMode,
+  ):
+    """
+    Args:
+      host (str): The hostname or IP address of the Redis server.
+      port (int): The port number of the Redis server.
+      time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for
+        records stored in Redis. Provide an integer (in seconds) or a
+        `datetime.timedelta` object.
+      request_coder: (Optional[`coders.Coder`]) coder for requests stored
+        in Redis.
+      response_coder: (Optional[`coders.Coder`]) coder for decoding responses
+        received from Redis.
+      kwargs: Optional(Dict[str, Any]) additional keyword arguments that
+        are required to connect to your redis server. Same as `redis.Redis()`.
+      source_caller: (Optional[`Caller`]): The source caller using this Redis
+        cache in case of fetching the cache request to store in Redis.
+      mode: `_RedisMode` An enum type specifying the operational mode of
+        the `RedisCaller`.
+    """
+    self.host, self.port = host, port
+    self.time_to_live = time_to_live
+    self.request_coder = request_coder
+    self.response_coder = response_coder
+    self.kwargs = kwargs
+    self.source_caller = source_caller
+    self.mode = mode
+
+  def __enter__(self):
+    self.client = redis.Redis(self.host, self.port, **self.kwargs)
+
+  def __call__(self, element, *args, **kwargs):
+    if self.mode == _RedisMode.READ:
+      cache_request = self.source_caller.get_cache_request(element)
+      # check if the caller is a enrichment handler. EnrichmentHandler
+      # provides the request format for cache.
+      if cache_request:
+        encoded_request = self.request_coder.encode(cache_request)
+      else:
+        encoded_request = self.request_coder.encode(element)
+
+      encoded_response = self.client.get(encoded_request)
+      if not encoded_response:
+        # no cache entry present for this request.
+        return element, None
+
+      if self.response_coder is None:
+        try:
+          response_dict = json.loads(encoded_response.decode('utf-8'))
+          response = beam.Row(**response_dict)
+        except Exception:
+          _LOGGER.warning(
+              'cannot decode response from redis cache for %s.' % element)
+          return element, None
+      else:
+        response = self.response_coder.decode(encoded_response)
+      return element, response
+    else:
+      cache_request = self.source_caller.get_cache_request(element[0])
+      if cache_request:
+        encoded_request = self.request_coder.encode(cache_request)
+      else:
+        encoded_request = self.request_coder.encode(element[0])
+      if self.response_coder is None:
+        try:
+          encoded_response = json.dumps(element[1]._asdict()).encode('utf-8')
+        except Exception as e:
+          _LOGGER.warning(
+              'cannot encode response %s for %s to store in '
+              'redis cache.' % (element[1], element[0]))
+          raise e
+      else:
+        encoded_response = self.response_coder.encode(element[1])
+      # Write to cache with TTL. Set nx to True to prevent overwriting for the
+      # same key.
+      self.client.set(
+          encoded_request, encoded_response, self.time_to_live, nx=True)
+      return element
+
+  def __exit__(self, exc_type, exc_val, exc_tb):
+    self.client.close()
+
+
+class ReadFromRedis(beam.PTransform[beam.PCollection[RequestT],
+                                    beam.PCollection[ResponseT]]):
+  """ReadFromRedis is a `PTransform` that performs Redis cache read."""
+  def __init__(
+      self,
+      host: str,
+      port: int,
+      time_to_live: Union[int, timedelta],
+      *,
+      kwargs: Optional[Dict[str, Any]] = None,
+      request_coder: Optional[coders.Coder],
+      response_coder: Optional[coders.Coder],
+      source_caller: Optional[Caller[RequestT, ResponseT]] = None,
+  ):
+    """
+    Args:
+      host (str): The hostname or IP address of the Redis server.
+      port (int): The port number of the Redis server.
+      time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for
+        records stored in Redis. Provide an integer (in seconds) or a
+        `datetime.timedelta` object.
+      kwargs: Optional(Dict[str, Any]) additional keyword arguments that
+        are required to connect to your redis server. Same as `redis.Redis()`.
+      request_coder: (Optional[`coders.Coder`]) coder for requests stored
+        in Redis.
+      response_coder: (Optional[`coders.Coder`]) coder for decoding responses
+        received from Redis.
+      source_caller: (Optional[`Caller`]): The source caller using this Redis
+        cache in case of fetching the cache request to store in Redis.
+    """
+    self.request_coder = request_coder
+    self.response_coder = response_coder
+    self.redis_caller = RedisCaller(
+        host,
+        port,
+        time_to_live,
+        request_coder=self.request_coder,
+        response_coder=self.response_coder,
+        kwargs=kwargs,
+        source_caller=source_caller,
+        mode=_RedisMode.READ)
+
+  def expand(
+      self,
+      requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]:
+    return requests | RequestResponseIO(self.redis_caller)
+
+
+class WriteToRedis(beam.PTransform[beam.PCollection[Tuple[RequestT, ResponseT]],
+                                   beam.PCollection[ResponseT]]):
+  """WriteToRedis is a `PTransfrom` that performs write to Redis cache."""
+  def __init__(
+      self,
+      host: str,
+      port: int,
+      time_to_live: Union[int, timedelta],
+      *,
+      kwargs: Optional[Dict[str, Any]] = None,
+      request_coder: Optional[coders.Coder],
+      response_coder: Optional[coders.Coder],
+      source_caller: Optional[Caller[RequestT, ResponseT]] = None,
+  ):
+    """
+    Args:
+      host (str): The hostname or IP address of the Redis server.
+      port (int): The port number of the Redis server.
+      time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for
+        records stored in Redis. Provide an integer (in seconds) or a
+        `datetime.timedelta` object.
+      kwargs: Optional(Dict[str, Any]) additional keyword arguments that
+        are required to connect to your redis server. Same as `redis.Redis()`.
+      request_coder: (Optional[`coders.Coder`]) coder for requests stored
+        in Redis.
+      response_coder: (Optional[`coders.Coder`]) coder for decoding responses
+        received from Redis.
+      source_caller: (Optional[`Caller`]): The source caller using this Redis
+        cache in case of fetching the cache request to store in Redis.
+      """
+    self.request_coder = request_coder
+    self.response_coder = response_coder
+    self.redis_caller = RedisCaller(
+        host,
+        port,
+        time_to_live,
+        request_coder=self.request_coder,
+        response_coder=self.response_coder,
+        kwargs=kwargs,
+        source_caller=source_caller,
+        mode=_RedisMode.WRITE)
+
+  def expand(
+      self, elements: beam.PCollection[Tuple[RequestT, ResponseT]]
+  ) -> beam.PCollection[ResponseT]:
+    return elements | RequestResponseIO(self.redis_caller)
+
+
+def ensure_coders_exist(request_coder):
+  """checks if the coder exists to encode the request for caching."""
+  if not request_coder:
+    _LOGGER.warning(
+        'need request coder to be able to use'
+        'Cache with RequestResponseIO.')
+
+
+class RedisCache(Cache):
+  """RedisCache to configure cache using Redis for
+  :class:`apache_beam.io.requestresponse.RequestResponseIO`."""
+  def __init__(
+      self,
+      host: str,
+      port: int,
+      time_to_live: Union[int, timedelta] = DEFAULT_TIME_TO_LIVE_SECS,
+      request_coder: Optional[coders.Coder] = None,
+      response_coder: Optional[coders.Coder] = None,
+      *,
+      kwargs: Optional[Dict[str, Any]] = None,
+  ):
+    """
+    Args:
+      host (str): The hostname or IP address of the Redis server.
+      port (int): The port number of the Redis server.
+      time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for
+        records stored in Redis. Provide an integer (in seconds) or a
+        `datetime.timedelta` object.
+      request_coder: (Optional[`coders.Coder`]) coder for encoding requests.
+      response_coder: (Optional[`coders.Coder`]) coder for decoding responses
+        received from Redis.
+      kwargs: Optional(Dict[str, Any]) additional keyword arguments that
+        are required to connect to your redis server. Same as `redis.Redis()`.
+    """
+    self._host = host
+    self._port = port
+    self._time_to_live = time_to_live
+    self._request_coder = request_coder
+    self._response_coder = response_coder
+    self._kwargs = kwargs if kwargs else {}
+    self._source_caller = None
+
+  def get_read(self):
+    """get_read returns a callback that returns a PTransform
+    for reading from the cache."""
+    ensure_coders_exist(self._request_coder)
+
+    def callback():

Review Comment:
   Good catch. Oops, I may have been experimenting and left it here.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to