You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by da...@apache.org on 2023/02/28 14:50:03 UTC

[beam] branch users/damccorm/tfhub-test created (now 167ace851fc)

This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a change to branch users/damccorm/tfhub-test
in repository https://gitbox.apache.org/repos/asf/beam.git


      at 167ace851fc Fix tensorflowhub caching issue

This branch includes the following new commits:

     new 167ace851fc Fix tensorflowhub caching issue

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.



[beam] 01/01: Fix tensorflowhub caching issue

Posted by da...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch users/damccorm/tfhub-test
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 167ace851fc250b980576f7812ae55519a609450
Author: Danny McCormick <da...@google.com>
AuthorDate: Tue Feb 28 09:49:25 2023 -0500

    Fix tensorflowhub caching issue
---
 .../ml/inference/tensorflow_inference_it_test.py   | 23 ++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py
index fb1a2964841..4e044082ac0 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py
@@ -25,10 +25,12 @@ import pytest
 
 from apache_beam.io.filesystems import FileSystems
 from apache_beam.testing.test_pipeline import TestPipeline
+from pathlib import Path
 
 # pylint: disable=ungrouped-imports
 try:
   import tensorflow as tf
+  import tensorflow_hub as hub
   from apache_beam.examples.inference import tensorflow_imagenet_segmentation
   from apache_beam.examples.inference import tensorflow_mnist_classification
   from apache_beam.examples.inference import tensorflow_mnist_with_weights
@@ -42,6 +44,26 @@ def process_outputs(filepath):
   lines = [l.decode('utf-8').strip('\n') for l in lines]
   return lines
 
+def rmdir(directory):
+  directory = Path(directory)
+  for item in directory.iterdir():
+    if item.is_dir():
+      rmdir(item)
+    else:
+      item.unlink()
+  directory.rmdir()
+
+def clear_tf_hub_temp_dir(model_path):
+  # When loading a tensorflow hub using tfhub.resolve, the model is saved in a
+  # temporary directory. That file can be persisted between test runs, in which
+  # case tfhub.resolve will no-op. If the model is deleted and the file isn't,
+  # tfhub.resolve will still no-op and tf.keras.models.load_model will throw.
+  # To avoid this (and test more robustly) we delete the temporary directory
+  # entirely between runs.
+  local_path = hub.resolve(model_path)
+  rmdir(local_path)
+
+
 
 @unittest.skipIf(
     tf is None, 'Missing dependencies. '
@@ -90,6 +112,7 @@ class TensorflowInference(unittest.TestCase):
     output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt'])
     model_path = (
         'https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4')
+    clear_tf_hub_temp_dir(model_path)
     extra_opts = {
         'input': input_file,
         'output': output_file,