lukecwik commented on a change in pull request #11203: [BEAM-9577] Define and 
implement dependency-aware artifact staging service.
URL: https://github.com/apache/beam/pull/11203#discussion_r400603582
 
 

 ##########
 File path: sdks/python/apache_beam/runners/portability/artifact_service.py
 ##########
 @@ -263,3 +279,205 @@ def _open(self, path, mode='r'):
       return filesystems.FileSystems.create(path)
     else:
       return filesystems.FileSystems.open(path)
+
+
+# The dependency-aware artifact staging and retrieval services.
+
+
+def _queue_iter(queue, end_token):
+  while True:
+    item = queue.get()
+    if item is end_token:
+      break
+    yield item
+
+
class ArtifactRetrievalService(
    beam_artifact_api_pb2_grpc.ArtifactRetrievalServiceServicer):
  """Serves artifact payloads over gRPC for file-, URL-, and embedded-type
  artifacts, streaming the bytes back in fixed-size chunks.
  """

  # 2 << 20 bytes == 2 MiB per streamed GetArtifact response.
  _DEFAULT_CHUNK_SIZE = 2 << 20

  def __init__(
      self,
      file_reader,  # type: Callable[[str], BinaryIO]
      chunk_size=None,
  ):
    """Args:
      file_reader: a callable mapping a path to a readable binary handle,
        used to serve FILE-type artifacts.
      chunk_size: maximum number of bytes per response chunk; defaults to
        _DEFAULT_CHUNK_SIZE.
    """
    self._file_reader = file_reader
    self._chunk_size = chunk_size or self._DEFAULT_CHUNK_SIZE

  def ResolveArtifacts(self, request, context=None):
    # Renamed from ResolveArtifact for consistency with the callers in this
    # module (resolve_as_files and FakeRetrievalService both invoke
    # ResolveArtifacts); everything this service handles is already in a
    # fetchable form, so the artifacts are echoed back unchanged.
    return beam_artifact_api_pb2.ResolveArtifactResponse(
        replacements=request.artifacts)

  def GetArtifact(self, request, context=None):
    """Streams the requested artifact's bytes in chunks of _chunk_size.

    Raises:
      NotImplementedError: for artifact type URNs other than FILE, URL,
        or EMBEDDED.
    """
    if request.artifact.type_urn == common_urns.artifact_types.FILE.urn:
      payload = proto_utils.parse_Bytes(
          request.artifact.type_payload,
          beam_runner_api_pb2.ArtifactFilePayload)
      read_handle = self._file_reader(payload.path)
    elif request.artifact.type_urn == common_urns.artifact_types.URL.urn:
      payload = proto_utils.parse_Bytes(
          request.artifact.type_payload,
          beam_runner_api_pb2.ArtifactUrlPayload)
      # NOTE(review): ArtifactUrlPayload's field is conventionally named
      # `url`; confirm `path` is the actual field in this proto version.
      # TODO(Py3): Remove the unneeded contextlib wrapper.
      read_handle = contextlib.closing(urlopen(payload.path))
    elif request.artifact.type_urn == common_urns.artifact_types.EMBEDDED.urn:
      payload = proto_utils.parse_Bytes(
          request.artifact.type_payload,
          beam_runner_api_pb2.EmbeddedFilePayload)
      read_handle = BytesIO(payload.data)
    else:
      raise NotImplementedError(request.artifact.type_urn)

    with read_handle as fin:
      while True:
        chunk = fin.read(self._chunk_size)
        if not chunk:
          break
        yield beam_artifact_api_pb2.GetArtifactResponse(data=chunk)
+
+
class ArtifactStagingService(
    beam_artifact_api_pb2_grpc.ArtifactStagingServiceServicer):
  """Stages a job's dependencies by pulling them from the submitting client
  over the reverse artifact retrieval protocol and writing them via a
  user-supplied file writer.
  """
  def __init__(
      self,
      file_writer,  # type: Callable[[str, Optional[str]], Tuple[BinaryIO, str]]
  ):
    self._lock = threading.Lock()
    # staging_token -> (mutable list of dependencies, staging-done event)
    self._jobs_to_stage = {}
    # staging_token -> exception raised during resolution, if any
    self._failures = {}
    self._file_writer = file_writer

  def register_job(self, staging_token, dependencies):
    """Registers the dependencies to be staged under staging_token."""
    self._jobs_to_stage[staging_token] = list(dependencies), threading.Event()

  def resolved_deps(self, staging_token, timeout=None):
    """Returns the file-based dependencies for staging_token.

    Blocks until staging completes.

    Raises:
      concurrent.futures.TimeoutError: if staging does not finish in time.
      Exception: whatever error occurred while resolving the dependencies.
    """
    dependencies_list, event = self._jobs_to_stage[staging_token]
    try:
      if not event.wait(timeout):
        raise concurrent.futures.TimeoutError()
      failure = self._failures.pop(staging_token, None)
      if failure is not None:
        raise failure
      return dependencies_list
    finally:
      # One-shot: the job entry is consumed whether staging succeeded or not.
      del self._jobs_to_stage[staging_token]

  def ReverseArtifactRetrievalService(self, responses, context=None):
    staging_token = next(responses).staging_token
    dependencies, event = self._jobs_to_stage[staging_token]

    # Requests destined for the client are funneled through this queue; the
    # returned response stream is fed from it until a None sentinel arrives.
    requests = queue.Queue()

    class FakeRetrievalService(object):
      def ResolveArtifacts(self, request):
        requests.put(
            beam_artifact_api_pb2.ArtifactRequestWrapper(
                resolve_artifact=request))
        return next(responses).resolve_artifact_response

      def GetArtifact(self, request):
        requests.put(
            beam_artifact_api_pb2.ArtifactRequestWrapper(get_artifact=request))
        while True:
          response = next(responses)
          yield response.get_artifact_response
          if response.is_last:
            break

    def resolve():
      try:
        file_deps = resolve_as_files(
            FakeRetrievalService(),
            lambda name: self._file_writer(os.path.join(staging_token, name)),
            dependencies)
        dependencies[:] = file_deps
      except Exception as exn:  # pylint: disable=broad-except
        # Record the failure so resolved_deps() re-raises it to the caller.
        self._failures[staging_token] = exn
      finally:
        # Always terminate the request stream and wake any waiter, even on
        # failure; otherwise this RPC and resolved_deps() would hang forever.
        requests.put(None)
        event.set()

    t = threading.Thread(target=resolve)
    t.daemon = True
    t.start()

    return _queue_iter(requests, None)
+
+
+def resolve_as_files(retrieval_service, file_writer, dependencies):
+  """Translates a set of dependencies into file-based dependencies."""
+  # Resolve until nothing changes.  This ensures that they can be fetched.
+  resolution = retrieval_service.ResolveArtifacts(
+      beam_artifact_api_pb2.ResolveArtifactRequest(
+          artifacts=dependencies,
+          # Anything fetchable will do.
+          # TODO(robertwb): Take advantage of shared filesystems, urls.
+          preferred_urns=[],
+      ))
+  if resolution.error:
+    raise RuntimeError(resolution)
+  dependencies = resolution.replacements
+
+  # Fetch each of the dependencies, using file_writer to store them as
+  # file-based artifacts.
+  # TODO(robertwb): Consider parallelizing the actual writes.
+  for dep in dependencies:
+    if dep.role_urn == common_urns.artifact_roles.STAGING_TO.urn:
+      base_name = os.path.basename(
+          proto_utils.parse_Bytes(
+              dep.role_payload,
+              beam_runner_api_pb2.ArtifactStagingToRolePayload).staged_name)
+    else:
+      base_name = None
+    unique_name = '-'.join(
+        filter(
+            None,
+            [hashlib.sha256(dep.SerializeToString()).hexdigest(), base_name]))
 
 Review comment:
  It's typically pretty easy to figure out the name of a path component, so
there is no need.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to