You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by da...@apache.org on 2023/02/28 21:35:47 UTC
[beam] 01/02: Fix tensorflowhub caching issue (#25661)
This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch users/damccorm/release-cp
in repository https://gitbox.apache.org/repos/asf/beam.git
commit c7323e8b9a8244e509795f74ef9eca257d079095
Author: Danny McCormick <da...@google.com>
AuthorDate: Tue Feb 28 15:16:54 2023 -0500
Fix tensorflowhub caching issue (#25661)
* Fix tensorflowhub caching issue
* Comment
* lint
---
.../ml/inference/tensorflow_inference_it_test.py | 24 ++++++++++++++++++++++
1 file changed, 24 insertions(+)
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py
index fb1a2964841..bdc0291dd1e 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference_it_test.py
@@ -20,6 +20,7 @@
import logging
import unittest
import uuid
+from pathlib import Path
import pytest
@@ -29,6 +30,7 @@ from apache_beam.testing.test_pipeline import TestPipeline
# pylint: disable=ungrouped-imports
try:
import tensorflow as tf
+ import tensorflow_hub as hub
from apache_beam.examples.inference import tensorflow_imagenet_segmentation
from apache_beam.examples.inference import tensorflow_mnist_classification
from apache_beam.examples.inference import tensorflow_mnist_with_weights
@@ -43,6 +45,27 @@ def process_outputs(filepath):
return lines
def rmdir(directory):
  """Recursively delete ``directory`` and everything beneath it.

  Walks the tree depth-first: files are unlinked, subdirectories are
  removed via a recursive call, and finally the (now empty) directory
  itself is removed.
  """
  root = Path(directory)
  for child in root.iterdir():
    if child.is_dir():
      rmdir(child)
    else:
      child.unlink()
  root.rmdir()
+
+
def clear_tf_hub_temp_dir(model_path):
  """Delete the local tensorflow-hub cache directory for ``model_path``.

  ``hub.resolve`` downloads a hub model into a temporary directory and
  no-ops when that directory already exists. A directory left over from a
  previous run can therefore mask a deleted model: ``hub.resolve`` still
  no-ops while ``tf.keras.models.load_model`` throws. Wiping the cached
  copy between runs keeps the test robust against stale state.
  """
  cached_copy = hub.resolve(model_path)
  rmdir(cached_copy)
+
+
@unittest.skipIf(
tf is None, 'Missing dependencies. '
'Test depends on tensorflow')
@@ -90,6 +113,7 @@ class TensorflowInference(unittest.TestCase):
output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt'])
model_path = (
'https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4')
+ clear_tf_hub_temp_dir(model_path)
extra_opts = {
'input': input_file,
'output': output_file,