Abacn commented on code in PR #25965:
URL: https://github.com/apache/beam/pull/25965#discussion_r1245797583


##########
sdks/python/apache_beam/io/gcp/gcsio.py:
##########
@@ -295,162 +208,92 @@ def delete_batch(self, paths):
              argument, where exception is None if the operation succeeded or
              the relevant exception if the operation failed.
     """
-    if not paths:
-      return []
-
-    paths = iter(paths)
-    result_statuses = []
-    while True:
-      paths_chunk = list(islice(paths, MAX_BATCH_OPERATION_SIZE))
-      if not paths_chunk:
-        return result_statuses
-      batch_request = BatchApiRequest(
-          batch_url=GCS_BATCH_ENDPOINT,
-          retryable_codes=retry.SERVER_ERROR_OR_TIMEOUT_CODES,
-          response_encoding='utf-8')
-      for path in paths_chunk:
-        bucket, object_path = parse_gcs_path(path)
-        request = storage.StorageObjectsDeleteRequest(
-            bucket=bucket, object=object_path)
-        batch_request.Add(self.client.objects, 'Delete', request)
-      api_calls = batch_request.Execute(self.client._http)  # pylint: 
disable=protected-access
-      for i, api_call in enumerate(api_calls):
-        path = paths_chunk[i]
-        exception = None
-        if api_call.is_error:
-          exception = api_call.exception
-          # Return success when the file doesn't exist anymore for idempotency.
-          if isinstance(exception, HttpError) and exception.status_code == 404:
-            exception = None
-        result_statuses.append((path, exception))
-    return result_statuses
+    final_results = []
+    s = 0
+    while s < len(paths):
+      if (s + MAX_BATCH_OPERATION_SIZE) < len(paths):
+        current_paths = paths[s:s + MAX_BATCH_OPERATION_SIZE]
+      else:
+        current_paths = paths[s:]
+      current_batch = self.client.batch(raise_exception=False)
+      with current_batch:
+        for path in current_paths:
+          bucket_name, blob_name = parse_gcs_path(path)
+          bucket = self.client.get_bucket(bucket_name)
+          blob = storage.Blob(blob_name, bucket)
+          blob.delete()
+
+      current_responses = current_batch._responses
+      for resp in current_responses:
+        if resp[1] == NotFound:
+          final_results.append((resp[0], None))
+        else:
+          final_results.append(resp)
+
+      s += MAX_BATCH_OPERATION_SIZE
+
+    return final_results
 
   @retry.with_exponential_backoff(
       retry_filter=retry.retry_on_server_errors_and_timeout_filter)
-  def copy(
-      self,
-      src,
-      dest,
-      dest_kms_key_name=None,
-      max_bytes_rewritten_per_call=None):
+  def copy(self, src, dest):
     """Copies the given GCS object from src to dest.
 
     Args:
       src: GCS file path pattern in the form gs://<bucket>/<name>.
       dest: GCS file path pattern in the form gs://<bucket>/<name>.
-      dest_kms_key_name: Experimental. No backwards compatibility guarantees.
-        Encrypt dest with this Cloud KMS key. If None, will use dest bucket
-        encryption defaults.
-      max_bytes_rewritten_per_call: Experimental. No backwards compatibility
-        guarantees. Each rewrite API call will return after these many bytes.
-        Used for testing.
 
     Raises:
       TimeoutError: on timeout.
     """
-    src_bucket, src_path = parse_gcs_path(src)
-    dest_bucket, dest_path = parse_gcs_path(dest)
-    request = storage.StorageObjectsRewriteRequest(
-        sourceBucket=src_bucket,
-        sourceObject=src_path,
-        destinationBucket=dest_bucket,
-        destinationObject=dest_path,
-        destinationKmsKeyName=dest_kms_key_name,
-        maxBytesRewrittenPerCall=max_bytes_rewritten_per_call)
-    response = self.client.objects.Rewrite(request)
-    while not response.done:
-      _LOGGER.debug(
-          'Rewrite progress: %d of %d bytes, %s to %s',
-          response.totalBytesRewritten,
-          response.objectSize,
-          src,
-          dest)
-      request.rewriteToken = response.rewriteToken
-      response = self.client.objects.Rewrite(request)
-      if self._rewrite_cb is not None:
-        self._rewrite_cb(response)
-
-    _LOGGER.debug('Rewrite done: %s to %s', src, dest)
-
-  # We intentionally do not decorate this method with a retry, as retrying is
-  # handled in BatchApiRequest.Execute().
-  def copy_batch(
-      self,
-      src_dest_pairs,
-      dest_kms_key_name=None,
-      max_bytes_rewritten_per_call=None):
-    """Copies the given GCS object from src to dest.
+    src_bucket_name, src_blob_name = parse_gcs_path(src)
+    dest_bucket_name, dest_blob_name= parse_gcs_path(dest, 
object_optional=True)
+    src_bucket = self.get_bucket(src_bucket_name)
+    src_blob = src_bucket.get_blob(src_blob_name)
+    if not src_blob:
+      raise NotFound("Source %s not found", src)
+    dest_bucket = self.get_bucket(dest_bucket_name)
+    if not dest_blob_name:
+      dest_blob_name = None
+    src_bucket.copy_blob(src_blob, dest_bucket, new_name=dest_blob_name)
+
+  def copy_batch(self, src_dest_pairs):
+    """Copies the given GCS objects from src to dest.
 
     Args:
       src_dest_pairs: list of (src, dest) tuples of gs://<bucket>/<name> files
                       paths to copy from src to dest, not to exceed
                       MAX_BATCH_OPERATION_SIZE in length.
-      dest_kms_key_name: Experimental. No backwards compatibility guarantees.
-        Encrypt dest with this Cloud KMS key. If None, will use dest bucket
-        encryption defaults.
-      max_bytes_rewritten_per_call: Experimental. No backwards compatibility
-        guarantees. Each rewrite call will return after these many bytes. Used
-        primarily for testing.
 
     Returns: List of tuples of (src, dest, exception) in the same order as the
              src_dest_pairs argument, where exception is None if the operation
              succeeded or the relevant exception if the operation failed.
     """
-    if not src_dest_pairs:
-      return []
-    pair_to_request = {}
-    for pair in src_dest_pairs:
-      src_bucket, src_path = parse_gcs_path(pair[0])
-      dest_bucket, dest_path = parse_gcs_path(pair[1])
-      request = storage.StorageObjectsRewriteRequest(
-          sourceBucket=src_bucket,
-          sourceObject=src_path,
-          destinationBucket=dest_bucket,
-          destinationObject=dest_path,
-          destinationKmsKeyName=dest_kms_key_name,
-          maxBytesRewrittenPerCall=max_bytes_rewritten_per_call)
-      pair_to_request[pair] = request
-    pair_to_status = {}
-    while True:
-      pairs_in_batch = list(set(src_dest_pairs) - set(pair_to_status))
-      if not pairs_in_batch:
-        break
-      batch_request = BatchApiRequest(
-          batch_url=GCS_BATCH_ENDPOINT,
-          retryable_codes=retry.SERVER_ERROR_OR_TIMEOUT_CODES,
-          response_encoding='utf-8')
-      for pair in pairs_in_batch:
-        batch_request.Add(self.client.objects, 'Rewrite', 
pair_to_request[pair])
-      api_calls = batch_request.Execute(self.client._http)  # pylint: 
disable=protected-access
-      for pair, api_call in zip(pairs_in_batch, api_calls):
-        src, dest = pair
-        response = api_call.response
-        if self._rewrite_cb is not None:
-          self._rewrite_cb(response)
-        if api_call.is_error:
-          exception = api_call.exception
-          # Translate 404 to the appropriate not found exception.
-          if isinstance(exception, HttpError) and exception.status_code == 404:
-            exception = (
-                GcsIOError(errno.ENOENT, 'Source file not found: %s' % src))
-          pair_to_status[pair] = exception
-        elif not response.done:
-          _LOGGER.debug(
-              'Rewrite progress: %d of %d bytes, %s to %s',
-              response.totalBytesRewritten,
-              response.objectSize,
-              src,
-              dest)
-          pair_to_request[pair].rewriteToken = response.rewriteToken
-        else:
-          _LOGGER.debug('Rewrite done: %s to %s', src, dest)
-          pair_to_status[pair] = None
+    final_results = []
+    s = 0
+    while s < len(src_dest_pairs):
+      if (s + MAX_BATCH_OPERATION_SIZE) < len(src_dest_pairs):
+        current_pairs = src_dest_pairs[s:s + MAX_BATCH_OPERATION_SIZE]
+      else:
+        current_pairs = src_dest_pairs[s:]
+      current_batch = self.client.batch(raise_exception=False)
+      with current_batch:
+        for pair in current_pairs:
+          src_bucket_name, src_blob_name = parse_gcs_path(pair[0])
+          dest_bucket_name, dest_blob_name = parse_gcs_path(pair[1])
+          src_bucket = self.client.get_bucket(src_bucket_name)
+          src_blob = src_bucket.get_blob(src_blob_name)
+          dest_bucket = self.client.get_bucket(dest_bucket_name)
+
+          src_bucket.copy_blob(src_blob, dest_bucket, dest_blob_name)
 
-    return [(pair[0], pair[1], pair_to_status[pair]) for pair in 
src_dest_pairs]
+      final_results += current_batch._responses

Review Comment:
   The return value no longer matches the docstring ("Returns: List of tuples of (src, dest, exception)"). It is now a list of (headers, payload) tuples, which comes from here:
https://github.com/googleapis/python-storage/blob/5b492d144216177714e95645467e01c7dbc82d19/google/cloud/storage/batch.py#L356
   
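   We probably still need to keep the original return type, i.e. map each batch response back to its (src, dest, exception) tuple, so that callers relying on the documented contract are not broken.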



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to