riteshghorse commented on code in PR #30307:
URL: https://github.com/apache/beam/pull/30307#discussion_r1491177265
##########
sdks/python/apache_beam/io/requestresponse.py:
##########
@@ -411,3 +387,440 @@ def process(self, request: RequestT, *args, **kwargs):
def teardown(self):
    """Records a teardown metric, then exits the wrapped caller's context.

    ``sys.exc_info()`` forwards the exception currently being handled (if
    any; otherwise ``(None, None, None)``) to the caller's ``__exit__``.
    """
    self._metrics_collector.teardown_counter.inc(1)
    self._caller.__exit__(*sys.exc_info())
+
+
class Cache(abc.ABC):
    """Base Cache class for
    :class:`apache_beam.io.requestresponse.RequestResponseIO`.

    For adding cache support to RequestResponseIO, implement this class.

    Note: :meth:`get_read` and :meth:`get_write` return zero-argument
    callables that build the cache PTransform, not the PTransform itself;
    RequestResponseIO calls the returned callable to obtain the transform.
    """
    @abc.abstractmethod
    def get_read(self):
        """Returns a zero-argument callable that builds a PTransform reading
        from the cache.

        RequestResponseIO filters the transform's output for entries whose
        cached response is ``None`` and sends those requests on to the
        wrapped caller.
        """

    @abc.abstractmethod
    def get_write(self):
        """Returns a zero-argument callable that builds a PTransform writing
        to the cache."""

    @abc.abstractmethod
    def has_request_coder(self) -> bool:
        """Returns ``True`` if a request coder is present.

        RequestResponseIO only reads from the cache when this returns True.
        """

    @abc.abstractmethod
    def set_request_coder(self, request_coder: coders.Coder):
        """Sets the request coder to use with the Cache."""

    @abc.abstractmethod
    def set_response_coder(self, response_coder: coders.Coder):
        """Sets the response coder to use with the Cache."""

    @abc.abstractmethod
    def set_source_caller(self, caller: Caller):
        """(Internal-only) This method allows
        :class:`apache_beam.io.requestresponse.RequestResponseIO` to pull
        cache requests from respective callers."""
+
+
+class _RedisMode(enum.Enum):
+ """
+ Mode of operation for redis cache when using
+ :class:`apache_beam.io.requestresponse.RedisCaller`.
+ """
+ READ = 0
+ WRITE = 1
+
+
class RedisCaller(Caller):
    """`RedisCaller` is an implementation of
    :class:`apache_beam.io.requestresponse.Caller` for Redis client.

    It provides the functionality for making requests to Redis server using
    :class:`apache_beam.io.requestresponse.RequestResponseIO`.

    In ``_RedisMode.READ`` mode each request is looked up in Redis and
    ``(request, response)`` is returned, with ``response`` set to ``None``
    when no usable cache entry exists. In any other mode the incoming
    ``(request, response)`` element is written to Redis and returned as-is.
    """
    def __init__(
        self,
        host: str,
        port: int,
        time_to_live: Union[int, timedelta],
        *,
        request_coder: Optional[coders.Coder],
        response_coder: Optional[coders.Coder],
        kwargs: Optional[Dict[str, Any]] = None,
        source_caller: Optional[Caller] = None,
        mode: _RedisMode,
    ):
        """
        Args:
          host (str): The hostname or IP address of the Redis server.
          port (int): The port number of the Redis server.
          time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for
            records stored in Redis. Provide an integer (in seconds) or a
            `datetime.timedelta` object.
          request_coder: (Optional[`coders.Coder`]) coder for requests stored
            in Redis.
          response_coder: (Optional[`coders.Coder`]) coder for encoding and
            decoding responses stored in Redis. When `None`, responses are
            stored as JSON and read back as `beam.Row`.
          kwargs: Optional(Dict[str, Any]) additional keyword arguments that
            are required to connect to your redis server. Same as
            `redis.Redis()`. `None` means no extra arguments.
          source_caller: (Optional[`Caller`]): The source caller using this
            Redis cache in case of fetching the cache request to store in
            Redis. When `None`, the request itself is used as the cache key.
          mode: `_RedisMode` An enum type specifying the operational mode of
            the `RedisCaller`.
        """
        self.host, self.port = host, port
        self.time_to_live = time_to_live
        self.request_coder = request_coder
        self.response_coder = response_coder
        # Normalize to a dict so `**self.kwargs` in __enter__ never sees None.
        self.kwargs = kwargs if kwargs else {}
        self.source_caller = source_caller
        self.mode = mode
        # Created in __enter__; None until then.
        self.client = None

    def __enter__(self):
        self.client = redis.Redis(self.host, self.port, **self.kwargs)
        return self

    def _encode_request(self, request):
        """Returns the encoded Redis key for `request`.

        If a source caller is configured (e.g. an enrichment handler) and it
        provides a non-empty cache request, that value is encoded instead of
        the raw request.
        """
        cache_request = None
        if self.source_caller is not None:
            cache_request = self.source_caller.get_cache_request(request)
        return self.request_coder.encode(
            cache_request if cache_request else request)

    def _read(self, request):
        """Looks up `request`; returns `(request, response_or_None)`."""
        encoded_response = self.client.get(self._encode_request(request))
        if not encoded_response:
            # No cache entry present for this request.
            return request, None

        if self.response_coder is not None:
            return request, self.response_coder.decode(encoded_response)

        # Without a response coder, responses are stored as JSON objects.
        # Decoding is best-effort: a bad entry is treated as a cache miss.
        try:
            response_dict = json.loads(encoded_response.decode('utf-8'))
            response = beam.Row(**response_dict)
        except Exception:
            # Lazy logging args: eager `%` formatting would itself raise if
            # the request happened to be a tuple.
            _LOGGER.warning(
                'cannot decode response from redis cache for %s.', request)
            return request, None
        return request, response

    def _write(self, element):
        """Stores `element[1]` under the key for `element[0]`; returns element."""
        request, response = element[0], element[1]
        encoded_request = self._encode_request(request)
        if self.response_coder is None:
            try:
                encoded_response = json.dumps(
                    response._asdict()).encode('utf-8')
            except Exception:
                _LOGGER.warning(
                    'cannot encode response %s for %s to store in '
                    'redis cache.',
                    response,
                    request)
                raise
        else:
            encoded_response = self.response_coder.encode(response)
        # Write to cache with TTL. Set nx to True to prevent overwriting for
        # the same key.
        self.client.set(
            encoded_request, encoded_response, ex=self.time_to_live, nx=True)
        return element

    def __call__(self, element, *args, **kwargs):
        if self.mode == _RedisMode.READ:
            return self._read(element)
        return self._write(element)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Guard against __exit__ running without a successful __enter__.
        if self.client is not None:
            self.client.close()
            self.client = None
+
+
class ReadFromRedis(beam.PTransform[beam.PCollection[RequestT],
                                    beam.PCollection[ResponseT]]):
    """ReadFromRedis is a `PTransform` that performs Redis cache read.

    Each request is emitted as a `(request, response)` pair, where
    `response` is `None` if no usable entry was found in the cache.
    """
    def __init__(
        self,
        host: str,
        port: int,
        time_to_live: Union[int, timedelta],
        *,
        kwargs: Optional[Dict[str, Any]] = None,
        request_coder: Optional[coders.Coder],
        response_coder: Optional[coders.Coder],
        source_caller: Optional[Caller[RequestT, ResponseT]] = None,
    ):
        """
        Args:
          host (str): The hostname or IP address of the Redis server.
          port (int): The port number of the Redis server.
          time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for
            records stored in Redis. Provide an integer (in seconds) or a
            `datetime.timedelta` object.
          kwargs: Optional(Dict[str, Any]) additional keyword arguments that
            are required to connect to your redis server. Same as
            `redis.Redis()`. `None` means no extra arguments.
          request_coder: (Optional[`coders.Coder`]) coder for requests stored
            in Redis.
          response_coder: (Optional[`coders.Coder`]) coder for decoding
            responses received from Redis.
          source_caller: (Optional[`Caller`]): The source caller using this
            Redis cache in case of fetching the cache request to store in
            Redis.
        """
        self.request_coder = request_coder
        self.response_coder = response_coder
        self.redis_caller = RedisCaller(
            host,
            port,
            time_to_live,
            request_coder=self.request_coder,
            response_coder=self.response_coder,
            # The client is built with `**kwargs`, so never pass None down.
            kwargs=kwargs if kwargs else {},
            source_caller=source_caller,
            mode=_RedisMode.READ)

    def expand(
        self,
        requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]:
        return requests | RequestResponseIO(self.redis_caller)
+
+
class WriteToRedis(beam.PTransform[beam.PCollection[Tuple[RequestT,
                                                          ResponseT]],
                                   beam.PCollection[ResponseT]]):
    """WriteToRedis is a `PTransform` that performs write to Redis cache.

    Each incoming `(request, response)` pair is stored in Redis (existing
    keys are not overwritten) and emitted unchanged.
    """
    def __init__(
        self,
        host: str,
        port: int,
        time_to_live: Union[int, timedelta],
        *,
        kwargs: Optional[Dict[str, Any]] = None,
        request_coder: Optional[coders.Coder],
        response_coder: Optional[coders.Coder],
        source_caller: Optional[Caller[RequestT, ResponseT]] = None,
    ):
        """
        Args:
          host (str): The hostname or IP address of the Redis server.
          port (int): The port number of the Redis server.
          time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for
            records stored in Redis. Provide an integer (in seconds) or a
            `datetime.timedelta` object.
          kwargs: Optional(Dict[str, Any]) additional keyword arguments that
            are required to connect to your redis server. Same as
            `redis.Redis()`. `None` means no extra arguments.
          request_coder: (Optional[`coders.Coder`]) coder for requests stored
            in Redis.
          response_coder: (Optional[`coders.Coder`]) coder for encoding
            responses before they are stored in Redis.
          source_caller: (Optional[`Caller`]): The source caller using this
            Redis cache in case of fetching the cache request to store in
            Redis.
        """
        self.request_coder = request_coder
        self.response_coder = response_coder
        self.redis_caller = RedisCaller(
            host,
            port,
            time_to_live,
            request_coder=self.request_coder,
            response_coder=self.response_coder,
            # The client is built with `**kwargs`, so never pass None down.
            kwargs=kwargs if kwargs else {},
            source_caller=source_caller,
            mode=_RedisMode.WRITE)

    def expand(
        self, elements: beam.PCollection[Tuple[RequestT, ResponseT]]
    ) -> beam.PCollection[ResponseT]:
        return elements | RequestResponseIO(self.redis_caller)
+
+
def ensure_coders_exist(request_coder):
    """Logs a warning if no request coder is available for caching.

    RequestResponseIO only reads from the cache when a request coder is
    set, so a missing coder is surfaced here. This never raises.

    Args:
      request_coder: the coder used to encode requests, or `None`.
    """
    if not request_coder:
        # Note the trailing space: the two literals are concatenated.
        _LOGGER.warning(
            'need request coder to be able to use '
            'Cache with RequestResponseIO.')
+
+
class RedisCache(Cache):
    """Redis-backed :class:`Cache` for
    :class:`apache_beam.io.requestresponse.RequestResponseIO`."""
    def __init__(
        self,
        host: str,
        port: int,
        time_to_live: Union[int, timedelta] = DEFAULT_TIME_TO_LIVE_SECS,
        request_coder: Optional[coders.Coder] = None,
        response_coder: Optional[coders.Coder] = None,
        *,
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
          host (str): hostname or IP address of the Redis server.
          port (int): port number of the Redis server.
          time_to_live: `(Union[int, timedelta])` TTL for cached records,
            either an integer number of seconds or a `datetime.timedelta`.
          request_coder: (Optional[`coders.Coder`]) coder used to encode
            requests into cache keys.
          response_coder: (Optional[`coders.Coder`]) coder used for responses
            stored in / read from Redis.
          kwargs: Optional(Dict[str, Any]) extra keyword arguments forwarded
            to `redis.Redis()` when connecting.
        """
        self._host = host
        self._port = port
        self._time_to_live = time_to_live
        self._request_coder = request_coder
        self._response_coder = response_coder
        self._kwargs = kwargs or {}
        # Filled in later by RequestResponseIO via set_source_caller().
        self._source_caller = None

    def _make_callback(self, transform_cls):
        """Warns on a missing request coder, then returns a zero-arg factory
        that builds `transform_cls` from this cache's current settings."""
        ensure_coders_exist(self._request_coder)

        def build_transform():
            # Attributes are read at call time, not when the factory is made.
            return transform_cls(
                self._host,
                self._port,
                time_to_live=self._time_to_live,
                kwargs=self._kwargs,
                request_coder=self._request_coder,
                response_coder=self._response_coder,
                source_caller=self._source_caller)

        return build_transform

    def get_read(self):
        """Returns a zero-arg callable producing a `ReadFromRedis` transform."""
        return self._make_callback(ReadFromRedis)

    def get_write(self):
        """Returns a zero-arg callable producing a `WriteToRedis` transform."""
        return self._make_callback(WriteToRedis)

    def has_request_coder(self) -> bool:
        """Whether a request coder has been configured."""
        return self._request_coder is not None

    def set_request_coder(self, request_coder: coders.Coder):
        """Adopts `request_coder` unless one is already set (or it is falsy)."""
        if not request_coder or self._request_coder:
            return
        self._request_coder = request_coder

    def set_response_coder(self, response_coder: coders.Coder):
        """Adopts `response_coder` unless one is already set (or it is falsy)."""
        if not response_coder or self._response_coder:
            return
        self._response_coder = response_coder

    def set_source_caller(self, caller: Caller[RequestT, ResponseT]):
        """Records the caller on whose behalf this cache is used."""
        self._source_caller = caller
+
+
+class RequestResponseIO(beam.PTransform[beam.PCollection[RequestT],
+ beam.PCollection[ResponseT]]):
+ """A :class:`RequestResponseIO` transform to read and write to APIs.
+
+ Processes an input :class:`~apache_beam.pvalue.PCollection` of requests
+ by making a call to the API as defined in :class:`Caller`'s `__call__`
+ and returns a :class:`~apache_beam.pvalue.PCollection` of responses.
+ """
+ def __init__(
+ self,
+ caller: Caller[RequestT, ResponseT],
+ timeout: Optional[float] = DEFAULT_TIMEOUT_SECS,
+ should_backoff: Optional[ShouldBackOff] = None,
+ repeater: Repeater = ExponentialBackOffRepeater(),
+ cache: Optional[Cache] = None,
+ throttler: PreCallThrottler = DefaultThrottler(),
+ ):
+ """
+ Instantiates a RequestResponseIO transform.
+
+ Args:
+ caller (~apache_beam.io.requestresponse.Caller): an implementation of
+ `Caller` object that makes call to the API.
+ timeout (float): timeout value in seconds to wait for response from API.
+ should_backoff (~apache_beam.io.requestresponse.ShouldBackOff):
+ (Optional) provides methods for backoff.
+ repeater (~apache_beam.io.requestresponse.Repeater): provides method to
+ repeat failed requests to API due to service errors. Defaults to
+ :class:`apache_beam.io.requestresponse.ExponentialBackOffRepeater` to
+ repeat requests with exponential backoff.
+ cache: (Optional) a :class:`apache_beam.io.requestresponse.Cache` object
+ to use the appropriate cache.
+ throttler (~apache_beam.io.requestresponse.PreCallThrottler):
+ provides methods to pre-throttle a request. Defaults to
+ :class:`apache_beam.io.requestresponse.DefaultThrottler` for
+ client-side adaptive throttling using
+ :class:`apache_beam.io.components.adaptive_throttler.AdaptiveThrottler`
+ """
+ self._caller = caller
+ self._timeout = timeout
+ self._should_backoff = should_backoff
+ if repeater:
+ self._repeater = repeater
+ else:
+ self._repeater = NoOpsRepeater()
+ self._cache = cache
+ self._throttler = throttler
+
+ def expand(
+ self,
+ requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]:
+ # TODO(riteshghorse): handle Throttle PTransforms when available.
+
+ if self._cache:
+ self._cache.set_source_caller(caller=self._caller)
+
+ inputs = requests
+
+ if self._cache and self._cache.has_request_coder():
+ # read from cache.
+ cache_read_callback = self._cache.get_read()
+ outputs = inputs | cache_read_callback()
+ # filter responses that are None and send them to the Call transform
+ # to fetch a value from external service.
+ cached_responses = outputs | beam.ParDo(_FilterNoneCacheReadFn())
+ inputs = outputs | beam.ParDo(_FilterCacheRequestsFn())
Review Comment:
Yeah, that looks clean; I just followed the Java version. But I agree this
small detail won't matter. Updated.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]