You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by da...@apache.org on 2023/02/28 14:50:03 UTC
[beam] branch users/damccorm/tfhub-test created (now 167ace851fc)
This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a change to branch users/damccorm/tfhub-test
in repository https://gitbox.apache.org/repos/asf/beam.git
at 167ace851fc Fix tensorflowhub caching issue
This branch includes the following new commits:
new 167ace851fc Fix tensorflowhub caching issue
The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email. The revisions
listed as "add" were already present in the repository and have only
been added to this reference.
[beam] 01/01: Fix tensorflowhub caching issue
Posted by da...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch users/damccorm/tfhub-test
in repository https://gitbox.apache.org/repos/asf/beam.git
commit 167ace851fc250b980576f7812ae55519a609450
Author: Danny McCormick <da...@google.com>
AuthorDate: Tue Feb 28 09:49:25 2023 -0500
Fix tensorflowhub caching issue
---
.../ml/inference/tensorflow_inference_it_test.py | 23 ++++++++++++++++++++++
1 file changed, 23 insertions(+)
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py
index fb1a2964841..4e044082ac0 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py
@@ -25,10 +25,12 @@ import pytest
from apache_beam.io.filesystems import FileSystems
from apache_beam.testing.test_pipeline import TestPipeline
+from pathlib import Path
# pylint: disable=ungrouped-imports
try:
import tensorflow as tf
+ import tensorflow_hub as hub
from apache_beam.examples.inference import tensorflow_imagenet_segmentation
from apache_beam.examples.inference import tensorflow_mnist_classification
from apache_beam.examples.inference import tensorflow_mnist_with_weights
@@ -42,6 +44,26 @@ def process_outputs(filepath):
lines = [l.decode('utf-8').strip('\n') for l in lines]
return lines
+def rmdir(directory):
+ directory = Path(directory)
+ for item in directory.iterdir():
+ if item.is_dir():
+ rmdir(item)
+ else:
+ item.unlink()
+ directory.rmdir()
+
+def clear_tf_hub_temp_dir(model_path):
+ # When loading a tensorflow hub using tfhub.resolve, the model is saved in a
+ # temporary directory. That directory can persist between test runs, in which
+ # case tfhub.resolve will no-op. If the model is deleted and the directory isn't,
+ # tfhub.resolve will still no-op and tf.keras.models.load_model will throw.
+ # To avoid this (and test more robustly) we delete the temporary directory
+ # entirely between runs.
+ local_path = hub.resolve(model_path)
+ rmdir(local_path)
+
+
@unittest.skipIf(
tf is None, 'Missing dependencies. '
@@ -90,6 +112,7 @@ class TensorflowInference(unittest.TestCase):
output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt'])
model_path = (
'https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4')
+ clear_tf_hub_temp_dir(model_path)
extra_opts = {
'input': input_file,
'output': output_file,