This is an automated email from the ASF dual-hosted git repository.

anandinguva pushed a commit to branch fix-path
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 0a0d329e3ce4adafdcb9dac50597fc6fe0aa1348
Author: Anand Inguva <[email protected]>
AuthorDate: Fri Feb 2 16:20:25 2024 -0500

    Update artifacts fetcher to download artifacts
---
 sdks/python/apache_beam/ml/transforms/utils.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/sdks/python/apache_beam/ml/transforms/utils.py b/sdks/python/apache_beam/ml/transforms/utils.py
index fadf611b0e6..96f0f7f2f57 100644
--- a/sdks/python/apache_beam/ml/transforms/utils.py
+++ b/sdks/python/apache_beam/ml/transforms/utils.py
@@ -18,18 +18,28 @@
 __all__ = ['ArtifactsFetcher']
 
 import os
+import tempfile
 import typing
 
 import tensorflow_transform as tft
+from apache_beam.io.filesystems import FileSystems
 from apache_beam.ml.transforms import base
 
 
-class ArtifactsFetcher():
+class ArtifactsFetcher:
   """
   Utility class used to fetch artifacts from the artifact_location passed
   to the TFTProcessHandlers in MLTransform.
+
+  This is intended to be used for testing purposes only.
   """
   def __init__(self, artifact_location):
+    tempdir = tempfile.mkdtemp()
+    self._artifact_location = tempdir
+    # TODO: Can we use FileSystems.match() here with a * glob pattern?
+    # using match, does it output files and directories path?
+    FileSystems.copy(artifact_location, tempdir)
+    assert os.listdir(tempdir), f"No files found in {artifact_location}"
     files = os.listdir(artifact_location)
     files.remove(base._ATTRIBUTE_FILE_NAME)
     # TODO: https://github.com/apache/beam/issues/29356
@@ -43,9 +53,7 @@ class ArtifactsFetcher():
     self._artifact_location = os.path.join(artifact_location, files[0])
     self.transform_output = tft.TFTransformOutput(self._artifact_location)
 
-  def get_vocab_list(
-      self,
-      vocab_filename: str = 'compute_and_apply_vocab') -> typing.List[bytes]:
+  def get_vocab_list(self, vocab_filename: str) -> typing.List[bytes]:
     """
     Returns list of vocabulary terms created during MLTransform.
     """
@@ -57,8 +65,7 @@ class ArtifactsFetcher():
               vocab_filename)) from e
     return [x.decode('utf-8') for x in vocab_list]
 
-  def get_vocab_filepath(
-      self, vocab_filename: str = 'compute_and_apply_vocab') -> str:
+  def get_vocab_filepath(self, vocab_filename: str) -> str:
     """
     Return the path to the vocabulary file created during MLTransform.
     """

Reply via email to