This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch users/damccorm/tfhub-test
in repository https://gitbox.apache.org/repos/asf/beam.git
commit 167ace851fc250b980576f7812ae55519a609450 Author: Danny McCormick <[email protected]> AuthorDate: Tue Feb 28 09:49:25 2023 -0500 Fix tensorflowhub caching issue --- .../ml/inference/tensorflow_inference_it_test.py | 23 ++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py index fb1a2964841..4e044082ac0 100644 --- a/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py +++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py @@ -25,10 +25,12 @@ import pytest from apache_beam.io.filesystems import FileSystems from apache_beam.testing.test_pipeline import TestPipeline +from pathlib import Path # pylint: disable=ungrouped-imports try: import tensorflow as tf + import tensorflow_hub as hub from apache_beam.examples.inference import tensorflow_imagenet_segmentation from apache_beam.examples.inference import tensorflow_mnist_classification from apache_beam.examples.inference import tensorflow_mnist_with_weights @@ -42,6 +44,26 @@ def process_outputs(filepath): lines = [l.decode('utf-8').strip('\n') for l in lines] return lines +def rmdir(directory): + directory = Path(directory) + for item in directory.iterdir(): + if item.is_dir(): + rmdir(item) + else: + item.unlink() + directory.rmdir() + +def clear_tf_hub_temp_dir(model_path): + # When loading a tensorflow hub using tfhub.resolve, the model is saved in a + # temporary directory. That file can be persisted between test runs, in which + # case tfhub.resolve will no-op. If the model is deleted and the file isn't, + # tfhub.resolve will still no-op and tf.keras.models.load_model will throw. + # To avoid this (and test more robustly) we delete the temporary directory + # entirely between runs. + local_path = hub.resolve(model_path) + rmdir(local_path) + + @unittest.skipIf( tf is None, 'Missing dependencies. 
' @@ -90,6 +112,7 @@ class TensorflowInference(unittest.TestCase): output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt']) model_path = ( 'https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4') + clear_tf_hub_temp_dir(model_path) extra_opts = { 'input': input_file, 'output': output_file,
