This is an automated email from the ASF dual-hosted git repository. damccorm pushed a commit to branch users/damccorm/release-cp in repository https://gitbox.apache.org/repos/asf/beam.git
commit c7323e8b9a8244e509795f74ef9eca257d079095 Author: Danny McCormick <[email protected]> AuthorDate: Tue Feb 28 15:16:54 2023 -0500 Fix tensorflowhub caching issue (#25661) * Fix tensorflowhub caching issue * Comment * lint --- .../ml/inference/tensorflow_inference_it_test.py | 24 ++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py index fb1a2964841..bdc0291dd1e 100644 --- a/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py +++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py @@ -20,6 +20,7 @@ import logging import unittest import uuid +from pathlib import Path import pytest @@ -29,6 +30,7 @@ from apache_beam.testing.test_pipeline import TestPipeline # pylint: disable=ungrouped-imports try: import tensorflow as tf + import tensorflow_hub as hub from apache_beam.examples.inference import tensorflow_imagenet_segmentation from apache_beam.examples.inference import tensorflow_mnist_classification from apache_beam.examples.inference import tensorflow_mnist_with_weights @@ -43,6 +45,27 @@ def process_outputs(filepath): return lines +def rmdir(directory): + directory = Path(directory) + for item in directory.iterdir(): + if item.is_dir(): + rmdir(item) + else: + item.unlink() + directory.rmdir() + + +def clear_tf_hub_temp_dir(model_path): + # When loading from tensorflow hub using tfhub.resolve, the model is saved in + # a temporary directory. That file can be persisted between test runs, in + # which case tfhub.resolve will no-op. If the model is deleted and the file + # isn't, tfhub.resolve will still no-op and tf.keras.models.load_model will + # throw. To avoid this (and test more robustly) we delete the temporary + # directory entirely between runs. + local_path = hub.resolve(model_path) + rmdir(local_path) + + @unittest.skipIf( tf is None, 'Missing dependencies. ' 'Test depends on tensorflow') @@ -90,6 +113,7 @@ class TensorflowInference(unittest.TestCase): output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt']) model_path = ( 'https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4') + clear_tf_hub_temp_dir(model_path) extra_opts = { 'input': input_file, 'output': output_file,
