Posted to github@beam.apache.org by "damccorm (via GitHub)" <gi...@apache.org> on 2023/02/08 21:55:00 UTC

[GitHub] [beam] damccorm commented on a diff in pull request #25368: [Python] Added Tensorflow Model Handler

damccorm commented on code in PR #25368:
URL: https://github.com/apache/beam/pull/25368#discussion_r1100711310


##########
sdks/python/apache_beam/ml/inference/tensorflow_inference.py:
##########
@@ -0,0 +1,245 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+import enum
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Union
+
+import sys
+import numpy
+import tensorflow as tf
+
+from apache_beam.ml.inference import utils
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+
+__all__ = [
+    'TFModelHandlerNumpy',
+    'TFModelHandlerTensor',
+]
+
+TensorInferenceFn = Callable[[
+    tf.Module,
+    Sequence[Union[numpy.ndarray, tf.Tensor]],
+    Optional[Dict[str, Any]],
+    Optional[str]
+],
+                             Iterable[PredictionResult]]
+
+
+class ModelType(enum.Enum):

Review Comment:
   Since right now we only accept one model type, do we gain anything from including this as a parameter for the ModelHandlers? We can always add it later, right?
   
   Or are you planning on adding a follow up PR with more types?
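   
   For illustration, deferring it would just mean adding a defaulted keyword argument in a follow-up, which stays backward compatible because existing callers never pass it - a hypothetical sketch, not this PR's actual handler signature:
   
   ```python
   # Hypothetical: model_type added later, defaulted to the only value
   # supported today, so existing construction sites keep working.
   class TFModelHandlerNumpy(ModelHandler):
     def __init__(
         self,
         model_uri: str,
         model_type: ModelType = ModelType.SAVED_MODEL):
       self._model_uri = model_uri
       self._model_type = model_type
   ```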



##########
sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py:
##########
@@ -0,0 +1,148 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+import unittest
+
+import numpy
+import pytest
+
+try:
+  import tensorflow as tf
+  from apache_beam.ml.inference.sklearn_inference_test import _compare_prediction_result
+  from apache_beam.ml.inference.base import KeyedModelHandler, PredictionResult
+  from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy, TFModelHandlerTensor
+except ImportError:
+  raise unittest.SkipTest('Tensorflow dependencies are not installed')
+
+
+class FakeTFNumpyModel:
+  def predict(self, input: numpy.ndarray):
+    return numpy.multiply(input, 10)
+
+
+class FakeTFTensorModel:
+  def predict(self, input: tf.Tensor, add=False):
+    if add:
+      return tf.math.add(tf.math.multiply(input, 10), 10)
+    return tf.math.multiply(input, 10)
+
+
+def _compare_tensor_prediction_result(x, y):
+  return tf.math.equal(x.inference, y.inference)
+
+
+class TFRunInferenceTest(unittest.TestCase):
+  def test_predict_numpy(self):
+    fake_model = FakeTFNumpyModel()
+    inference_runner = TFModelHandlerNumpy(model_uri='unused')
+    batched_examples = [numpy.array([1]), numpy.array([10]), numpy.array([100])]
+    expected_predictions = [
+        PredictionResult(numpy.array([1]), 10),
+        PredictionResult(numpy.array([10]), 100),
+        PredictionResult(numpy.array([100]), 1000)
+    ]
+    inferences = inference_runner.run_inference(batched_examples, fake_model)
+    for actual, expected in zip(inferences, expected_predictions):
+      self.assertTrue(_compare_prediction_result(actual, expected))
+
+  @pytest.mark.uses_tf

Review Comment:
   I believe you can just use this marker on the class so you don't need it on every function
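   
   For example - a sketch relying on pytest's documented behavior that a class-level marker applies to every test method in the class:
   
   ```python
   import pytest
   
   @pytest.mark.uses_tf
   class TFRunInferenceTest(unittest.TestCase):
     def test_predict_numpy(self):  # inherits uses_tf from the class
       ...
   
     def test_predict_tensor(self):  # inherits it as well
       ...
   ```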



##########
sdks/python/apache_beam/ml/inference/tensorflow_inference.py:
##########
@@ -0,0 +1,245 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+import enum
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Union
+
+import sys
+import numpy
+import tensorflow as tf
+
+from apache_beam.ml.inference import utils
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+
+__all__ = [
+    'TFModelHandlerNumpy',
+    'TFModelHandlerTensor',
+]
+
+TensorInferenceFn = Callable[[
+    tf.Module,
+    Sequence[Union[numpy.ndarray, tf.Tensor]],
+    Optional[Dict[str, Any]],
+    Optional[str]
+],
+                             Iterable[PredictionResult]]
+
+
+class ModelType(enum.Enum):
+  """Defines how a model file should be loaded."""
+  SAVED_MODEL = 1
+
+
+def _load_model(model_uri, model_type):
+  if model_type == ModelType.SAVED_MODEL:
+    return tf.keras.models.load_model(model_uri)

Review Comment:
   From https://docs.google.com/document/d/1c2rWX7fA7UAl2qabzEXg5r7_zwosHfGo6mg1gz8VdSQ/edit?pli=1&disco=AAAAoHqF1dw I thought we were going to use `TFHub.resolve` - is there a reason we're switching back to `load_model`?
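   
   For comparison, the resolve-based loading would look roughly like this - a sketch assuming the tensorflow_hub package; `hub.resolve` downloads and caches a handle, returning a local directory path:
   
   ```python
   import tensorflow as tf
   import tensorflow_hub as hub
   
   def _load_model(model_uri, model_type):
     if model_type == ModelType.SAVED_MODEL:
       # hub.resolve accepts tfhub.dev handles as well as local paths.
       return tf.keras.models.load_model(hub.resolve(model_uri))
   ```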



##########
sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py:
##########
@@ -0,0 +1,148 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+import unittest
+
+import numpy
+import pytest
+
+try:
+  import tensorflow as tf
+  from apache_beam.ml.inference.sklearn_inference_test import _compare_prediction_result
+  from apache_beam.ml.inference.base import KeyedModelHandler, PredictionResult
+  from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy, TFModelHandlerTensor
+except ImportError:
+  raise unittest.SkipTest('Tensorflow dependencies are not installed')
+
+
+class FakeTFNumpyModel:
+  def predict(self, input: numpy.ndarray):
+    return numpy.multiply(input, 10)
+
+
+class FakeTFTensorModel:
+  def predict(self, input: tf.Tensor, add=False):
+    if add:
+      return tf.math.add(tf.math.multiply(input, 10), 10)
+    return tf.math.multiply(input, 10)
+
+
+def _compare_tensor_prediction_result(x, y):
+  return tf.math.equal(x.inference, y.inference)
+
+
+class TFRunInferenceTest(unittest.TestCase):
+  def test_predict_numpy(self):
+    fake_model = FakeTFNumpyModel()
+    inference_runner = TFModelHandlerNumpy(model_uri='unused')
+    batched_examples = [numpy.array([1]), numpy.array([10]), numpy.array([100])]
+    expected_predictions = [
+        PredictionResult(numpy.array([1]), 10),
+        PredictionResult(numpy.array([10]), 100),
+        PredictionResult(numpy.array([100]), 1000)
+    ]
+    inferences = inference_runner.run_inference(batched_examples, fake_model)
+    for actual, expected in zip(inferences, expected_predictions):
+      self.assertTrue(_compare_prediction_result(actual, expected))
+
+  @pytest.mark.uses_tf

Review Comment:
   Technically that lumps in the numpy ones that don't use tensorflow, but it's probably better to run them as a group anyway; otherwise they'll get run twice, I think.
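   
   Concretely - an illustration of the selection semantics only, not Beam's actual CI configuration: with the marker on the class, every test in it is picked up by exactly one of the two marker filters:
   
   ```python
   import pytest
   
   # TF suite: runs everything marked uses_tf, including the class's numpy tests.
   pytest.main(['-m', 'uses_tf', 'tensorflow_inference_test.py'])
   
   # Base suite: deselects the whole marked class, so nothing runs twice.
   pytest.main(['-m', 'not uses_tf', 'tensorflow_inference_test.py'])
   ```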



##########
sdks/python/apache_beam/ml/inference/tensorflow_inference_test.py:
##########
@@ -0,0 +1,148 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+import unittest
+
+import numpy
+import pytest
+
+try:
+  import tensorflow as tf
+  from apache_beam.ml.inference.sklearn_inference_test import _compare_prediction_result
+  from apache_beam.ml.inference.base import KeyedModelHandler, PredictionResult
+  from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy, TFModelHandlerTensor
+except ImportError:
+  raise unittest.SkipTest('Tensorflow dependencies are not installed')
+
+
+class FakeTFNumpyModel:
+  def predict(self, input: numpy.ndarray):
+    return numpy.multiply(input, 10)
+
+
+class FakeTFTensorModel:
+  def predict(self, input: tf.Tensor, add=False):
+    if add:
+      return tf.math.add(tf.math.multiply(input, 10), 10)
+    return tf.math.multiply(input, 10)
+
+
+def _compare_tensor_prediction_result(x, y):
+  return tf.math.equal(x.inference, y.inference)
+
+
+class TFRunInferenceTest(unittest.TestCase):
+  def test_predict_numpy(self):
+    fake_model = FakeTFNumpyModel()
+    inference_runner = TFModelHandlerNumpy(model_uri='unused')
+    batched_examples = [numpy.array([1]), numpy.array([10]), numpy.array([100])]
+    expected_predictions = [
+        PredictionResult(numpy.array([1]), 10),
+        PredictionResult(numpy.array([10]), 100),
+        PredictionResult(numpy.array([100]), 1000)
+    ]
+    inferences = inference_runner.run_inference(batched_examples, fake_model)
+    for actual, expected in zip(inferences, expected_predictions):
+      self.assertTrue(_compare_prediction_result(actual, expected))
+
+  @pytest.mark.uses_tf

Review Comment:
   Actually, we definitely need it on the class; otherwise the numpy ones will get skipped when we fail to import tensorflow.



##########
sdks/python/apache_beam/examples/inference/README.md:
##########
@@ -374,3 +383,51 @@ True Price 31000000.0, Predicted Price 25654277.256461
 ...
 ```
 
+## MNIST digit classification with Tensorflow
+[`tensorflow_mnist_classification.py`](./tensorflow_mnist_classification.py) contains an implementation for a RunInference pipeline that performs image classification on handwritten digits from the [MNIST](https://en.wikipedia.org/wiki/MNIST_database) database.
+
+The pipeline reads rows of pixels corresponding to a digit, performs basic preprocessing (reshaping the input to 28x28), passes the pixels to the trained Tensorflow model with RunInference, and then writes the predictions to a text file.
+
+### Dataset and model for image classification
+
+To use this transform, you need a dataset and model for image classification.
+
+1. Create a file named `INPUT.csv` that contains labels and pixels to feed into the model. Each row should have comma-separated elements. The first element is the label; all other elements are pixel values. The CSV should not have column headers. The content of the file should be similar to the following example:

Review Comment:
   Can we add an example CSV in GCS that users can just download as an option here?
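   
   In the meantime, a few lines of Python can generate a conforming `INPUT.csv` locally - a sketch assuming the Keras-bundled MNIST dataset; the filename and the 100-row cap are illustrative:
   
   ```python
   from tensorflow import keras
   
   # Each row: the label first, then the 784 flattened pixel values, no header.
   (x_train, y_train), _ = keras.datasets.mnist.load_data()
   with open('INPUT.csv', 'w') as f:
     for label, image in zip(y_train[:100], x_train[:100]):
       row = [str(label)] + [str(p) for p in image.flatten()]
       f.write(','.join(row) + '\n')
   ```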



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org