This is an automated email from the ASF dual-hosted git repository.

anandinguva pushed a commit to branch fix-path
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 0a0d329e3ce4adafdcb9dac50597fc6fe0aa1348
Author: Anand Inguva <[email protected]>
AuthorDate: Fri Feb 2 16:20:25 2024 -0500

    Update artifacts fetcher to download artifacts
---
 sdks/python/apache_beam/ml/transforms/utils.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/sdks/python/apache_beam/ml/transforms/utils.py b/sdks/python/apache_beam/ml/transforms/utils.py
index fadf611b0e6..96f0f7f2f57 100644
--- a/sdks/python/apache_beam/ml/transforms/utils.py
+++ b/sdks/python/apache_beam/ml/transforms/utils.py
@@ -18,18 +18,28 @@
 __all__ = ['ArtifactsFetcher']
 
 import os
+import tempfile
 import typing
 
 import tensorflow_transform as tft
+from apache_beam.io.filesystems import FileSystems
 from apache_beam.ml.transforms import base
 
 
-class ArtifactsFetcher():
+class ArtifactsFetcher:
   """
   Utility class used to fetch artifacts from the artifact_location passed
   to the TFTProcessHandlers in MLTransform.
+
+  This is intended to be used for testing purposes only.
   """
   def __init__(self, artifact_location):
+    tempdir = tempfile.mkdtemp()
+    self._artifact_location = tempdir
+    # TODO: Can we use FileSystems.match() here with a * glob pattern?
+    # using match, does it output files and directories path?
+    FileSystems.copy(artifact_location, tempdir)
+    assert os.listdir(tempdir), f"No files found in {artifact_location}"
     files = os.listdir(artifact_location)
     files.remove(base._ATTRIBUTE_FILE_NAME)
     # TODO: https://github.com/apache/beam/issues/29356
@@ -43,9 +53,7 @@ class ArtifactsFetcher():
     self._artifact_location = os.path.join(artifact_location, files[0])
     self.transform_output = tft.TFTransformOutput(self._artifact_location)
 
-  def get_vocab_list(
-      self,
-      vocab_filename: str = 'compute_and_apply_vocab') -> typing.List[bytes]:
+  def get_vocab_list(self, vocab_filename: str) -> typing.List[bytes]:
     """
     Returns list of vocabulary terms created during MLTransform.
     """
@@ -57,8 +65,7 @@ class ArtifactsFetcher():
               vocab_filename)) from e
     return [x.decode('utf-8') for x in vocab_list]
 
-  def get_vocab_filepath(
-      self, vocab_filename: str = 'compute_and_apply_vocab') -> str:
+  def get_vocab_filepath(self, vocab_filename: str) -> str:
     """
     Return the path to the vocabulary file created during MLTransform.
     """
