Posted to commits@beam.apache.org by ri...@apache.org on 2023/04/06 00:22:18 UTC
[beam] 01/01: Revert "[Python] Separate out notebooks for tensorflow with tfx and built-in model handler (#26105)"
This is an automated email from the ASF dual-hosted git repository.
riteshghorse pushed a commit to branch revert-26105-tf-notebook
in repository https://gitbox.apache.org/repos/asf/beam.git
commit 2234c8ebef629e90ec80ec37cacb20990e3f2a56
Author: Ritesh Ghorse <ri...@gmail.com>
AuthorDate: Wed Apr 5 20:22:10 2023 -0400
Revert "[Python] Separate out notebooks for tensorflow with tfx and built-in model handler (#26105)"
This reverts commit 126db67d8be2b4a91d8964e1a7c291041694e607.
---
.../beam-ml/run_inference_tensorflow.ipynb | 435 +++++++++++---
.../run_inference_tensorflow_with_tfx.ipynb | 657 ---------------------
2 files changed, 360 insertions(+), 732 deletions(-)
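The restored notebook wraps its model handler in `KeyedModelHandler` so that `(key, example)` pairs pass through `RunInference` with each key kept next to its prediction. Below is a minimal pure-Python sketch of that key-passing idea under stated assumptions. It is not Beam's implementation. `keyed_inference` and `multiply_five` are hypothetical names, and a plain function stands in for the trained "5 times table" model.

```python
from typing import Callable, Iterable, List, Sequence, Tuple, TypeVar

K = TypeVar("K")
E = TypeVar("E")
P = TypeVar("P")


def keyed_inference(
    predict_batch: Callable[[Sequence[E]], List[P]],
    keyed_examples: Iterable[Tuple[K, E]],
) -> List[Tuple[K, P]]:
    """Hypothetical stand-in for the KeyedModelHandler idea: strip the keys,
    run unkeyed batch prediction, then re-attach each key to its prediction."""
    pairs = list(keyed_examples)
    keys = [key for key, _ in pairs]
    examples = [example for _, example in pairs]
    predictions = predict_batch(examples)
    if len(predictions) != len(examples):
        raise ValueError("predict_batch must return one prediction per example")
    return list(zip(keys, predictions))


def multiply_five(batch: Sequence[float]) -> List[float]:
    # Hypothetical stand-in for the notebook's trained "5 times table" model.
    return [5.0 * x for x in batch]


result = keyed_inference(multiply_five, [("key 5.0", 5.0), ("key 50.0", 50.0)])
print(result)  # [('key 5.0', 25.0), ('key 50.0', 250.0)]
```

In Beam, RunInference also batches the input elements. This sketch only shows why each key stays attached to its own prediction.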
diff --git a/examples/notebooks/beam-ml/run_inference_tensorflow.ipynb b/examples/notebooks/beam-ml/run_inference_tensorflow.ipynb
index a541489a93c..034ab449a10 100644
--- a/examples/notebooks/beam-ml/run_inference_tensorflow.ipynb
+++ b/examples/notebooks/beam-ml/run_inference_tensorflow.ipynb
@@ -67,18 +67,21 @@
"source": [
"This notebook demonstrates the use of the [RunInference](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.RunInference) transform for [TensorFlow](https://www.tensorflow.org/).\n",
"\n",
- "Beam has built in support for 2 TensorFlow Model Handlers: [TFModelHandlerNumpy](https://github.com/apache/beam/blob/ca0787642a6b3804a742326147281c99ae8d08d2/sdks/python/apache_beam/ml/inference/tensorflow_inference.py#L91) and [TFModelHandlerTensor](https://github.com/apache/beam/blob/ca0787642a6b3804a742326147281c99ae8d08d2/sdks/python/apache_beam/ml/inference/tensorflow_inference.py#L184).\n",
- "TFModelHandlerNumpy can be used to run inference on models expecting a `numpy` array as an input while TFModelHandlerTensor can be used to run inference on models expecting a `tf.Tensor` as an input.\n",
+ "Beam has built in support for 2 Tensorflow Model Handlers: [TFModelHandlerNumpy](https://github.com/apache/beam/blob/ca0787642a6b3804a742326147281c99ae8d08d2/sdks/python/apache_beam/ml/inference/tensorflow_inference.py#L91) and [TFModelHandlerTensor](https://github.com/apache/beam/blob/ca0787642a6b3804a742326147281c99ae8d08d2/sdks/python/apache_beam/ml/inference/tensorflow_inference.py#L184).\n",
+ "TFModelHandlerNumpy can be used to run inference on models expecting a Numpy array as an input while TFModelHandlerTensor can be used to run inference on models expecting a Tensor as an input.\n",
"\n",
- "If your model needs input of type `tf.Example` see the [Apache Beam RunInference with `tfx-bsl` notebook](https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_tensorflow_with_tfx.ipynb).\n",
+ "Beam's Runinference transform also accepts a ModelHandler generated from [`tfx-bsl`](https://github.com/tensorflow/tfx-bsl) using `CreateModelHandler`.\n",
"\n",
"The Apache Beam RunInference transform is used to make predictions for\n",
- "a variety of machine learning models. For more information about the RunInference API, see [Get started with AI/ML pipelines](https://beam.apache.org/documentation/ml/overview/) in the Apache Beam documentation.\n",
+ "a variety of machine learning models. For more information about the RunInference API, see [Machine Learning](https://beam.apache.org/documentation/sdks/python-machine-learning) in the Apache Beam documentation.\n",
"\n",
"This notebook demonstrates the following steps:\n",
+ "- Import [`tfx-bsl`](https://github.com/tensorflow/tfx-bsl).\n",
"- Build a simple TensorFlow model.\n",
- "- Set up example data.\n",
- "- Run those examples with the built-in model handlers and get a prediction inside an Apache Beam pipeline."
+ "- Set up example data\n",
+ "- Run those examples with built-in model handlers and get a prediction inside an Apache Beam pipeline.\n",
+ "- Set up example data in TensorFlow protos.\n",
+ "- Run those examples with `tfx-bsl` model handler and get a prediction inside an Apache Beam pipeline."
],
"metadata": {
"id": "HrCtxslBGK8Z"
@@ -87,7 +90,19 @@
{
"cell_type": "markdown",
"source": [
- "To use RunInference with built-in Tensorflow model handler, install Apache Beam version 2.46 or later."
+ "## Before you begin\n",
+ "Complete the following setup steps.\n",
+ "\n",
+ "First, import `tfx-bsl`."
+ ],
+ "metadata": {
+ "id": "HrCtxslBGK8A"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "To use RunInference with built-in Tensorflow Model Handler, install Apache Beam version 2.46 or later. Creation of a ModelHandler is supported in `tfx-bsl` versions 1.10 and later."
],
"metadata": {
"id": "gVCtGOKTHMm4"
@@ -99,6 +114,7 @@
"id": "jBakpNZnAhqk"
},
"source": [
+ "!pip install tfx_bsl==1.10.0 --quiet\n",
"!pip install protobuf --quiet\n",
"!pip install apache_beam==2.46.0 --quiet"
],
@@ -124,7 +140,7 @@
"from google.colab import auth\n",
"auth.authenticate_user()"
],
- "execution_count": 2,
+ "execution_count": null,
"outputs": []
},
{
@@ -146,18 +162,23 @@
},
"source": [
"import argparse\n",
- "from typing import Dict, Text, Any, Tuple, List\n",
- "import numpy\n",
- "\n",
- "from google.protobuf import text_format\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
+ "from tensorflow_serving.apis import prediction_log_pb2\n",
+ "\n",
"import apache_beam as beam\n",
"from apache_beam.ml.inference.base import RunInference\n",
- "from apache_beam.ml.inference.base import KeyedModelHandler\n",
- "from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy\n",
- "from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor\n",
+ "import tfx_bsl\n",
+ "from tfx_bsl.public.beam.run_inference import CreateModelHandler\n",
+ "from tfx_bsl.public import tfxio\n",
+ "from tfx_bsl.public.proto import model_spec_pb2\n",
+ "from tensorflow_metadata.proto.v0 import schema_pb2\n",
+ "\n",
+ "import numpy\n",
+ "\n",
+ "from typing import Dict, Text, Any, Tuple, List\n",
+ "\n",
"from apache_beam.options.pipeline_options import PipelineOptions\n",
"\n",
"project = \"PROJECT_ID\"\n",
@@ -165,7 +186,7 @@
"\n",
"save_model_dir_multiply = f'gs://{bucket}/tfx-inference/model/multiply_five/v1/'\n"
],
- "execution_count": 10,
+ "execution_count": null,
"outputs": []
},
{
@@ -196,7 +217,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "SH7iq3zeBBJ-",
- "outputId": "e15cab6b-1271-4b0b-bac3-ba76f8991077"
+ "outputId": "2fb860fe-4420-4266-a51b-80e0c296b0fa"
},
"source": [
"# Create training data that represents the 5 times multiplication table for the numbers 0 to 99.\n",
@@ -213,7 +234,7 @@
"model.compile(optimizer=tf.optimizers.Adam(), loss='mean_absolute_error')\n",
"model.summary()"
],
- "execution_count": 6,
+ "execution_count": null,
"outputs": [
{
"output_type": "stream",
@@ -254,7 +275,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "5XkIYXhJBFmS",
- "outputId": "724cad1b-58f6-4e97-f7ec-9526297a108e"
+ "outputId": "599f8a10-5923-44ae-f1b2-f6d86a06c2ad"
},
"source": [
"model.fit(x, y, epochs=500, verbose=0)\n",
@@ -265,18 +286,18 @@
"print('Test Examples ' + str(test_examples))\n",
"print('Predictions ' + str(predictions))"
],
- "execution_count": 7,
+ "execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
- "1/1 [==============================] - 0s 64ms/step\n",
+ "1/1 [==============================] - 0s 71ms/step\n",
"Test Examples [20, 40, 60, 90]\n",
- "Predictions [[ 51.815357]\n",
- " [101.63492 ]\n",
- " [151.45448 ]\n",
- " [226.18384 ]]\n"
+ "Predictions [[ 34.466846]\n",
+ " [ 66.937996]\n",
+ " [ 99.409134]\n",
+ " [148.11584 ]]\n"
]
}
]
@@ -300,7 +321,7 @@
"metadata": {
"id": "2JbE7WkGcAkK"
},
- "execution_count": 8,
+ "execution_count": null,
"outputs": []
},
{
@@ -316,6 +337,9 @@
{
"cell_type": "code",
"source": [
+ "from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy\n",
+ "import apache_beam as beam\n",
+ "\n",
"class FormatOutput(beam.DoFn):\n",
" def process(self, element, *args, **kwargs):\n",
" yield \"example is {example} prediction is {prediction}\".format(example=element.example, prediction=element.inference)\n",
@@ -332,60 +356,294 @@
],
"metadata": {
"colab": {
- "base_uri": "https://localhost:8080/",
- "height": 124
+ "base_uri": "https://localhost:8080/"
},
"id": "St07XoibcQSb",
- "outputId": "028fb751-1f45-4c7b-da3f-5a3e31312798"
+ "outputId": "d373b1f9-00dc-4704-bf93-a27a74bf3673"
},
- "execution_count": 9,
+ "execution_count": null,
"outputs": [
{
"output_type": "stream",
- "name": "stderr",
+ "name": "stdout",
"text": [
- "WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.\n"
+ "example is 20.0 prediction is [102.867615]\n",
+ "example is 40.0 prediction is [201.72066]\n",
+ "example is 60.0 prediction is [300.5737]\n",
+ "example is 90.0 prediction is [448.85324]\n"
]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## KeyedModelHandler with TensorFlow using TFModelHandlerNumpy\n",
+ "\n",
+ "By default, the `ModelHandler` does not expect a key.\n",
+ "\n",
+ "* If you know that keys are associated with your examples, wrap the model handler with `beam.KeyedModelHandler`.\n",
+ "* If you don't know whether keys are associated with your examples, use `beam.MaybeKeyedModelHandler`."
+ ],
+ "metadata": {
+ "id": "tRLArcjOcYuO"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from apache_beam.ml.inference.base import KeyedModelHandler\n",
+ "from google.protobuf import text_format\n",
+ "import tensorflow as tf\n",
+ "from typing import Tuple\n",
+ "\n",
+ "class FormatOutputKeyed(FormatOutput):\n",
+ " # To simplify, inherit from FormatOutput.\n",
+ " def process(self, tuple_in: Tuple):\n",
+ " key, element = tuple_in\n",
+ " output = super().process(element)\n",
+ " yield \"{} : {}\".format(key, [op for op in output])\n",
+ "\n",
+ "examples = numpy.array([(1,20), (2,40), (3,60), (4,90)], dtype=numpy.float32)\n",
+ "keyed_model_handler = KeyedModelHandler(TFModelHandlerNumpy(save_model_dir_multiply))\n",
+ "with beam.Pipeline() as p:\n",
+ " _ = (p | 'CreateExamples' >> beam.Create(examples)\n",
+ " | RunInference(keyed_model_handler)\n",
+ " | beam.ParDo(FormatOutputKeyed())\n",
+ " | beam.Map(print)\n",
+ " )"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "P6l9RwL2cAW3",
+ "outputId": "13cc318e-51b1-4b1e-ba8b-7e7cbe1956e8"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "1.0 : ['example is 20.0 prediction is [102.867615]']\n",
+ "2.0 : ['example is 40.0 prediction is [201.72066]']\n",
+ "3.0 : ['example is 60.0 prediction is [300.5737]']\n",
+ "4.0 : ['example is 90.0 prediction is [448.85324]']\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## RunInference with Tensorflow using tfx-bsl\n",
+ "In versions 1.10.0 and later of `tfx-bsl`, you can\n",
+ "create a TensorFlow `ModelHandler` for use with Apache Beam. For more information about the RunInference API, see [Machine Learning](https://beam.apache.org/documentation/sdks/python-machine-learning) in the Apache Beam documentation.\n",
+ "\n",
+ "### Populate the data in a TensorFlow proto\n",
+ "\n",
+ "Tensorflow data uses protos. If you are loading from a file, helpers exist for this step. Because this example uses generated data, this code populates a proto."
+ ],
+ "metadata": {
+ "id": "dEmleqiH3t71"
+ }
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "XvKc9kQilPjx"
+ },
+ "source": [
+ "# This example shows a proto that converts the samples and labels into\n",
+ "# tensors usable by TensorFlow.\n",
+ "\n",
+ "class ExampleProcessor:\n",
+ " def create_example_with_label(self, feature: numpy.float32,\n",
+ " label: numpy.float32)-> tf.train.Example:\n",
+ " return tf.train.Example(\n",
+ " features=tf.train.Features(\n",
+ " feature={'x': self.create_feature(feature),\n",
+ " 'y' : self.create_feature(label)\n",
+ " }))\n",
+ "\n",
+ " def create_example(self, feature: numpy.float32):\n",
+ " return tf.train.Example(\n",
+ " features=tf.train.Features(\n",
+ " feature={'x' : self.create_feature(feature)})\n",
+ " )\n",
+ "\n",
+ " def create_feature(self, element: numpy.float32):\n",
+ " return tf.train.Feature(float_list=tf.train.FloatList(value=[element]))\n",
+ "\n",
+ "# Create a labeled example file for the 5 times table.\n",
+ "\n",
+ "example_five_times_table = 'example_five_times_table.tfrecord'\n",
+ "\n",
+ "with tf.io.TFRecordWriter(example_five_times_table) as writer:\n",
+ " for i in zip(x, y):\n",
+ " example = ExampleProcessor().create_example_with_label(\n",
+ " feature=i[0], label=i[1])\n",
+ " writer.write(example.SerializeToString())\n",
+ "\n",
+ "# Create a file containing the values to predict.\n",
+ "\n",
+ "predict_values_five_times_table = 'predict_values_five_times_table.tfrecord'\n",
+ "\n",
+ "with tf.io.TFRecordWriter(predict_values_five_times_table) as writer:\n",
+ " for i in value_to_predict:\n",
+ " example = ExampleProcessor().create_example(feature=i)\n",
+ " writer.write(example.SerializeToString())"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Fit The Model\n",
+ "\n",
+ "This step builds a model. Because RunInference requires pretrained models, this segment builds a usable model."
+ ],
+ "metadata": {
+ "id": "G-sAu3cf31f3"
+ }
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
},
+ "id": "AnbrxXPKeAOQ",
+ "outputId": "011b2fd3-722e-43b3-c914-c131eaa48bbc"
+ },
+ "source": [
+ "RAW_DATA_TRAIN_SPEC = {\n",
+ "'x': tf.io.FixedLenFeature([], tf.float32),\n",
+ "'y': tf.io.FixedLenFeature([], tf.float32)\n",
+ "}\n",
+ "\n",
+ "dataset = tf.data.TFRecordDataset(example_five_times_table)\n",
+ "dataset = dataset.map(lambda e : tf.io.parse_example(e, RAW_DATA_TRAIN_SPEC))\n",
+ "dataset = dataset.map(lambda t : (t['x'], t['y']))\n",
+ "dataset = dataset.batch(100)\n",
+ "dataset = dataset.repeat()\n",
+ "\n",
+ "model.fit(dataset, epochs=5000, steps_per_epoch=1, verbose=0)"
+ ],
+ "execution_count": null,
+ "outputs": [
{
- "output_type": "display_data",
+ "output_type": "execute_result",
"data": {
- "application/javascript": [
- "\n",
- " if (typeof window.interactive_beam_jquery == 'undefined') {\n",
- " var jqueryScript = document.createElement('script');\n",
- " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n",
- " jqueryScript.type = 'text/javascript';\n",
- " jqueryScript.onload = function() {\n",
- " var datatableScript = document.createElement('script');\n",
- " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n",
- " datatableScript.type = 'text/javascript';\n",
- " datatableScript.onload = function() {\n",
- " window.interactive_beam_jquery = jQuery.noConflict(true);\n",
- " window.interactive_beam_jquery(document).ready(function($){\n",
- " \n",
- " });\n",
- " }\n",
- " document.head.appendChild(datatableScript);\n",
- " };\n",
- " document.head.appendChild(jqueryScript);\n",
- " } else {\n",
- " window.interactive_beam_jquery(document).ready(function($){\n",
- " \n",
- " });\n",
- " }"
+ "text/plain": [
+ "<keras.callbacks.History at 0x7fa02b413130>"
]
},
- "metadata": {}
+ "metadata": {},
+ "execution_count": 26
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Save the model\n",
+ "\n",
+ "This step shows how to save your model."
+ ],
+ "metadata": {
+ "id": "r4dpR6dQ4JwX"
+ }
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "fYvrIYO3qiJx"
+ },
+ "source": [
+ "RAW_DATA_PREDICT_SPEC = {\n",
+ "'x': tf.io.FixedLenFeature([], tf.float32),\n",
+ "}\n",
+ "\n",
+ "# tf.function compiles the function into a callable TF graph.\n",
+ "# RunInference relies on calling a TF graph as a model.\n",
+ "# Note: The input signature should be type tf.string as supported by\n",
+ "# tfx-bsl ModelHandlers.\n",
+ "@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string , name='examples')])\n",
+ "def serve_tf_examples_fn(serialized_tf_examples):\n",
+ " \"\"\"Returns the output to be used in the serving signature.\"\"\"\n",
+ " features = tf.io.parse_example(serialized_tf_examples, RAW_DATA_PREDICT_SPEC)\n",
+ " return model(features, training=False)\n",
+ "\n",
+ "signature = {'serving_default': serve_tf_examples_fn}\n",
+ "\n",
+ "# Signatures define the input and output types for a computation. The optional\n",
+ "# save signatures argument controls which methods in obj will be available to\n",
+ "# programs which consume SavedModels, for example, serving APIs.\n",
+ "# See https://www.tensorflow.org/api_docs/python/tf/saved_model/save\n",
+ "tf.keras.models.save_model(model, save_model_dir_multiply, signatures=signature)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Run the Pipeline\n",
+ "Use the following code to run the pipeline.\n",
+ "\n",
+ "`FormatOutput` demonstrates how to extract values from the output protos.\n",
+ "\n",
+ "`CreateModelHandler` demonstrates the model handler that needs to be passed into the Apache Beam RunInference API."
+ ],
+ "metadata": {
+ "id": "P2UMmbNW4YQV"
+ }
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
},
+ "id": "PzjmXM_KvqHY",
+ "outputId": "67e61086-a83e-410c-c4e6-e6af77ad82bb"
+ },
+ "source": [
+ "from tfx_bsl.public.beam.run_inference import CreateModelHandler\n",
+ "\n",
+ "class FormatOutput(beam.DoFn):\n",
+ " def process(self, element: prediction_log_pb2.PredictionLog):\n",
+ " predict_log = element.predict_log\n",
+ " input_value = tf.train.Example.FromString(predict_log.request.inputs['examples'].string_val[0])\n",
+ " input_float_value = input_value.features.feature['x'].float_list.value[0]\n",
+ " output_value = predict_log.response.outputs\n",
+ " output_float_value = output_value['output_0'].float_val[0]\n",
+ " yield (f\"example is {input_float_value:.2f} prediction is {output_float_value:.2f}\")\n",
+ "\n",
+ "tfexample_beam_record = tfx_bsl.public.tfxio.TFExampleRecord(file_pattern=predict_values_five_times_table)\n",
+ "saved_model_spec = model_spec_pb2.SavedModelSpec(model_path=save_model_dir_multiply)\n",
+ "inference_spec_type = model_spec_pb2.InferenceSpecType(saved_model_spec=saved_model_spec)\n",
+ "model_handler = CreateModelHandler(inference_spec_type)\n",
+ "with beam.Pipeline() as p:\n",
+ " _ = (p | tfexample_beam_record.RawRecordBeamSource()\n",
+ " | RunInference(model_handler)\n",
+ " | beam.ParDo(FormatOutput())\n",
+ " | beam.Map(print)\n",
+ " )"
+ ],
+ "execution_count": null,
+ "outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
- "example is 20.0 prediction is [51.815357]\n",
- "example is 40.0 prediction is [101.63492]\n",
- "example is 60.0 prediction is [151.45448]\n",
- "example is 90.0 prediction is [226.18384]\n"
+ "example is 20.00 prediction is 100.00\n",
+ "example is 40.00 prediction is 200.00\n",
+ "example is 60.00 prediction is 300.01\n",
+ "example is 90.00 prediction is 450.01\n"
]
}
]
@@ -393,29 +651,56 @@
{
"cell_type": "markdown",
"source": [
- "## KeyedModelHandler with TensorFlow using TFModelHandlerNumpy\n",
+ "## KeyedModelHandler with TensorFlow using tfx-bsl\n",
"\n",
"By default, the `ModelHandler` does not expect a key.\n",
"\n",
"* If you know that keys are associated with your examples, wrap the model handler with `beam.KeyedModelHandler`.\n",
- "* If you don't know whether keys are associated with your examples, use `beam.MaybeKeyedModelHandler`."
+ "* If you don't know whether keys are associated with your examples, use `beam.MaybeKeyedModelHandler`.\n",
+ "\n",
+ "In addition to demonstrating how to use a keyed model handler, this step demonstrates how to use `tfx-bsl` examples."
],
"metadata": {
- "id": "tRLArcjOcYuO"
+ "id": "IXikjkGdHm9n"
}
},
{
"cell_type": "code",
"source": [
+ "from apache_beam.ml.inference.base import KeyedModelHandler\n",
+ "from google.protobuf import text_format\n",
+ "import tensorflow as tf\n",
+ "\n",
"class FormatOutputKeyed(FormatOutput):\n",
" # To simplify, inherit from FormatOutput.\n",
" def process(self, tuple_in: Tuple):\n",
" key, element = tuple_in\n",
" output = super().process(element)\n",
- " yield \"{} : {}\".format(key, [op for op in output])\n",
+ " yield ' : '.join([key, next(output)])\n",
"\n",
- "examples = numpy.array([(1,20), (2,40), (3,60), (4,90)], dtype=numpy.float32)\n",
- "keyed_model_handler = KeyedModelHandler(TFModelHandlerNumpy(save_model_dir_multiply))\n",
+ "def make_example(num):\n",
+ " # Return keyed values in the form of (key num, example).\n",
+ " key = f'key {num}'\n",
+ " tf_proto = text_format.Parse(\n",
+ " \"\"\"\n",
+ " features {\n",
+ " feature {key: \"x\" value { float_list { value: %f }}}\n",
+ " }\n",
+ " \"\"\"% num, tf.train.Example())\n",
+ " return (key, tf_proto)\n",
+ "\n",
+ "# Make a list of examples of type tf.train.Example.\n",
+ "examples = [\n",
+ " make_example(5.0),\n",
+ " make_example(50.0),\n",
+ " make_example(40.0),\n",
+ " make_example(100.0)\n",
+ "]\n",
+ "\n",
+ "tfexample_beam_record = tfx_bsl.public.tfxio.TFExampleRecord(file_pattern=predict_values_five_times_table)\n",
+ "saved_model_spec = model_spec_pb2.SavedModelSpec(model_path=save_model_dir_multiply)\n",
+ "inference_spec_type = model_spec_pb2.InferenceSpecType(saved_model_spec=saved_model_spec)\n",
+ "keyed_model_handler = KeyedModelHandler(CreateModelHandler(inference_spec_type))\n",
"with beam.Pipeline() as p:\n",
" _ = (p | 'CreateExamples' >> beam.Create(examples)\n",
" | RunInference(keyed_model_handler)\n",
@@ -427,19 +712,19 @@
"colab": {
"base_uri": "https://localhost:8080/"
},
- "id": "P6l9RwL2cAW3",
- "outputId": "03459fea-7d0a-4501-93cb-18bbad915d13"
+ "id": "KPtE3fmdJQry",
+ "outputId": "8729f479-e347-4243-b8de-757efd28dba7"
},
- "execution_count": 11,
+ "execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
- "1.0 : ['example is 20.0 prediction is [51.815357]']\n",
- "2.0 : ['example is 40.0 prediction is [101.63492]']\n",
- "3.0 : ['example is 60.0 prediction is [151.45448]']\n",
- "4.0 : ['example is 90.0 prediction is [226.18384]']\n"
+ "key 5.0 : example is 5.00 prediction is 25.00\n",
+ "key 50.0 : example is 50.00 prediction is 250.01\n",
+ "key 40.0 : example is 40.00 prediction is 200.00\n",
+ "key 100.0 : example is 100.00 prediction is 500.01\n"
]
}
]
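The two `FormatOutputKeyed` variants in this diff differ in how they consume the generator returned by `super().process(element)`. The restored TFModelHandlerNumpy cell collects it into a list, which produces output lines like `1.0 : ['example is 20.0 ...']`. The tfx-bsl cell instead takes only the first item with `next()` and joins it to a string key.

The sketch below shows that difference using only the standard library. It makes two simplifying assumptions: hypothetical stand-in functions replace the Beam `DoFn`s, and plain floats replace `PredictionResult` and `PredictionLog` objects.

```python
from typing import Iterator


def format_output(example: float, prediction: float) -> Iterator[str]:
    # Hypothetical stand-in for FormatOutput.process: a generator yielding one string.
    yield f"example is {example:.2f} prediction is {prediction:.2f}"


def format_keyed_as_list(key, example: float, prediction: float) -> str:
    # Pattern from the TFModelHandlerNumpy cell: the list comprehension drains
    # the generator, so the output is a one-element list; any key type works.
    output = format_output(example, prediction)
    return "{} : {}".format(key, [op for op in output])


def format_keyed_first(key: str, example: float, prediction: float) -> str:
    # Pattern from the tfx-bsl cell: next() takes only the first yielded string,
    # and str.join requires the key itself to be a str.
    output = format_output(example, prediction)
    return " : ".join([key, next(output)])


print(format_keyed_as_list(1.0, 20.0, 100.0))
# 1.0 : ['example is 20.00 prediction is 100.00']
print(format_keyed_first("key 5.0", 5.0, 25.0))
# key 5.0 : example is 5.00 prediction is 25.00
```

The `join` form raises `TypeError` for non-string keys, such as the float keys (`1.0`, `2.0`, ...) in the numpy example. It therefore only suits string keys like `'key 5.0'`, which `make_example` builds.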
diff --git a/examples/notebooks/beam-ml/run_inference_tensorflow_with_tfx.ipynb b/examples/notebooks/beam-ml/run_inference_tensorflow_with_tfx.ipynb
deleted file mode 100644
index fe9fc2b288a..00000000000
--- a/examples/notebooks/beam-ml/run_inference_tensorflow_with_tfx.ipynb
+++ /dev/null
@@ -1,657 +0,0 @@
-{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "provenance": [],
- "collapsed_sections": [
- "X80jy3FqHjK4",
- "40qtP6zJuMXm",
- "YzvZWEv-1oiK",
- "rIwD_qEpX7Gu",
- "O_a0-4Gb19cy",
- "G-sAu3cf31f3",
- "r4dpR6dQ4JwX",
- "P2UMmbNW4YQV"
- ]
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- },
- "language_info": {
- "name": "python"
- },
- "accelerator": "GPU"
- },
- "cells": [
- {
- "cell_type": "code",
- "source": [
- "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
- "\n",
- "# Licensed to the Apache Software Foundation (ASF) under one\n",
- "# or more contributor license agreements. See the NOTICE file\n",
- "# distributed with this work for additional information\n",
- "# regarding copyright ownership. The ASF licenses this file\n",
- "# to you under the Apache License, Version 2.0 (the\n",
- "# \"License\"); you may not use this file except in compliance\n",
- "# with the License. You may obtain a copy of the License at\n",
- "#\n",
- "# http://www.apache.org/licenses/LICENSE-2.0\n",
- "#\n",
- "# Unless required by applicable law or agreed to in writing,\n",
- "# software distributed under the License is distributed on an\n",
- "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
- "# KIND, either express or implied. See the License for the\n",
- "# specific language governing permissions and limitations\n",
- "# under the License"
- ],
- "metadata": {
- "cellView": "form",
- "id": "fFjof1NgAJwu"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "A8xNRyZMW1yK"
- },
- "source": [
- "# Apache Beam RunInference with `tfx-bsl`\n",
- "\n",
- "<table align=\"left\">\n",
- " <td>\n",
- " <a target=\"_blank\" href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_tensorflow_with_tfx.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n",
- " </td>\n",
- " <td>\n",
- " <a target=\"_blank\" href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_tensorflow_with_tfx.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\" />View source on GitHub</a>\n",
- " </td>\n",
- "</table>\n"
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "This notebook demonstrates the use of the [RunInference](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.RunInference) transform for [TensorFlow](https://www.tensorflow.org/) using [`tfx-bsl`](https://github.com/tensorflow/tfx-bsl).\n",
- "\n",
- "Use this approach when the trained model requires a `tf.Example` input. If you have `numpy` or `tf.Tensor` inputs, see the [Apache Beam RunInference with TensorFlow notebook](https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_tensorflow.ipynb) instead which demonstrates the use built-in model handlers for TensorFlow in Beam SDK starting 2.46.0.\n",
- "\n",
- "The Apache Beam RunInference transform accepts a model handler generated from [`tfx-bsl`](https://github.com/tensorflow/tfx-bsl) by using `CreateModelHandler`.\n",
- "\n",
- "The Apache Beam RunInference transform is used to make predictions for\n",
- "a variety of machine learning models. For more information about the RunInference API, see [Get started with AI/ML pipelines](https://beam.apache.org/documentation/ml/overview/) in the Apache Beam documentation.\n",
- "\n",
- "This notebook demonstrates the following steps:\n",
- "- Import [`tfx-bsl`](https://github.com/tensorflow/tfx-bsl).\n",
- "- Build a simple TensorFlow model.\n",
- "- Set up example data.\n",
- "- Run those examples with the `tfx-bsl` model handler and get a prediction inside an Apache Beam pipeline."
- ],
- "metadata": {
- "id": "HrCtxslBGK8Z"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Before you begin\n",
- "Complete the following setup steps.\n",
- "\n",
- "First, import `tfx-bsl`."
- ],
- "metadata": {
- "id": "HrCtxslBGK8A"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "Creation of a ModelHandler is supported in `tfx-bsl` versions 1.10 and later."
- ],
- "metadata": {
- "id": "gVCtGOKTHMm4"
- }
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "jBakpNZnAhqk"
- },
- "source": [
- "!pip install tfx_bsl==1.10.0 --quiet\n",
- "!pip install protobuf --quiet\n",
- "!pip install apache_beam --quiet"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Authenticate with Google Cloud\n",
- "This notebook relies on saving your model to Google Cloud. To use your Google Cloud account, authenticate this notebook."
- ],
- "metadata": {
- "id": "X80jy3FqHjK4"
- }
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "Kz9sccyGBqz3"
- },
- "source": [
- "from google.colab import auth\n",
- "auth.authenticate_user()"
- ],
- "execution_count": 2,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Import dependencies and set up your bucket\n",
- "Replace `PROJECT_ID` and `BUCKET_NAME` with the ID of your project and the name of your bucket.\n",
- "\n",
- "**Important**: If an error occurs, restart your runtime."
- ],
- "metadata": {
- "id": "40qtP6zJuMXm"
- }
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "eEle839_Akqx"
- },
- "source": [
- "import argparse\n",
- "\n",
- "import tensorflow as tf\n",
- "from tensorflow import keras\n",
- "from tensorflow_serving.apis import prediction_log_pb2\n",
- "\n",
- "import apache_beam as beam\n",
- "from apache_beam.ml.inference.base import RunInference\n",
- "import tfx_bsl\n",
- "from tfx_bsl.public.beam.run_inference import CreateModelHandler\n",
- "from tfx_bsl.public import tfxio\n",
- "from tfx_bsl.public.proto import model_spec_pb2\n",
- "from tensorflow_metadata.proto.v0 import schema_pb2\n",
- "\n",
- "import numpy\n",
- "\n",
- "from typing import Dict, Text, Any, Tuple, List\n",
- "\n",
- "from apache_beam.options.pipeline_options import PipelineOptions\n",
- "\n",
- "project = \"PROJECT_ID\"\n",
- "bucket = \"BUCKET_NAME\"\n",
- "\n",
- "save_model_dir_multiply = f'gs://{bucket}/tfx-inference/model/multiply_five/v1/'\n"
- ],
- "execution_count": 12,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Create and test a simple model\n",
- "\n",
- "This step creates and tests a model that predicts the 5 times table."
- ],
- "metadata": {
- "id": "YzvZWEv-1oiK"
- }
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "rIwD_qEpX7Gu"
- },
- "source": [
- "### Create the model\n",
- "Create training data and build a linear regression model."
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "SH7iq3zeBBJ-",
- "outputId": "c5adb7ec-285b-401e-f9be-1e9b83c6d0ba"
- },
- "source": [
- "# Create training data that represents the 5 times multiplication table for the numbers 0 to 99.\n",
- "# x is the data and y is the labels.\n",
- "x = numpy.arange(0, 100) # Examples\n",
- "y = x * 5 # Labels\n",
- "\n",
- "# Build a simple linear regression model.\n",
- "# Note that the model has a shape of (1) for its input layer and expects a single int64 value.\n",
- "input_layer = keras.layers.Input(shape=(1), dtype=tf.float32, name='x')\n",
- "output_layer= keras.layers.Dense(1)(input_layer)\n",
- "\n",
- "model = keras.Model(input_layer, output_layer)\n",
- "model.compile(optimizer=tf.optimizers.Adam(), loss='mean_absolute_error')\n",
- "model.summary()"
- ],
- "execution_count": 4,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Model: \"model\"\n",
- "_________________________________________________________________\n",
- " Layer (type) Output Shape Param # \n",
- "=================================================================\n",
- " x (InputLayer) [(None, 1)] 0 \n",
- " \n",
- " dense (Dense) (None, 1) 2 \n",
- " \n",
- "=================================================================\n",
- "Total params: 2\n",
- "Trainable params: 2\n",
- "Non-trainable params: 0\n",
- "_________________________________________________________________\n"
- ]
- }
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Test the model\n",
- "\n",
- "This step tests the model that you created."
- ],
- "metadata": {
- "id": "O_a0-4Gb19cy"
- }
- },
- {
- "cell_type": "code",
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "5XkIYXhJBFmS",
- "outputId": "e3bb5079-5cb8-4fe4-eb8d-d3d13d5f9f0c"
- },
- "source": [
- "model.fit(x, y, epochs=500, verbose=0)\n",
- "test_examples = [20, 40, 60, 90]\n",
- "value_to_predict = numpy.array(test_examples, dtype=numpy.float32)\n",
- "predictions = model.predict(value_to_predict)\n",
- "\n",
- "print('Test Examples ' + str(test_examples))\n",
- "print('Predictions ' + str(predictions))"
- ],
- "execution_count": 6,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "1/1 [==============================] - 0s 94ms/step\n",
- "Test Examples [20, 40, 60, 90]\n",
- "Predictions [[ 9.201942]\n",
- " [16.40566 ]\n",
- " [23.609379]\n",
- " [34.41496 ]]\n"
- ]
- }
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "## RunInference with TensorFlow using `tfx-bsl`\n",
- "In versions 1.10.0 and later of `tfx-bsl`, you can\n",
- "create a TensorFlow `ModelHandler` for use with Apache Beam. For more information about the RunInference API, see [Get started with AI/ML pipelines](https://beam.apache.org/documentation/ml/overview/) in the Apache Beam documentation.\n",
- "\n",
- "### Populate the data in a TensorFlow proto\n",
- "\n",
- "This example stores TensorFlow data as `tf.train.Example` protos. If you load data from a file, helpers exist for this step. Because this example uses generated data, the following code populates the protos directly."
- ],
- "metadata": {
- "id": "dEmleqiH3t71"
- }
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "XvKc9kQilPjx"
- },
- "source": [
- "# This helper class wraps the samples and labels in tf.train.Example protos\n",
- "# that TensorFlow can parse into tensors.\n",
- "\n",
- "class ExampleProcessor:\n",
- " def create_example_with_label(self, feature: numpy.float32,\n",
- " label: numpy.float32) -> tf.train.Example:\n",
- " return tf.train.Example(\n",
- " features=tf.train.Features(\n",
- " feature={'x': self.create_feature(feature),\n",
- " 'y': self.create_feature(label)\n",
- " }))\n",
- "\n",
- " def create_example(self, feature: numpy.float32) -> tf.train.Example:\n",
- " return tf.train.Example(\n",
- " features=tf.train.Features(\n",
- " feature={'x': self.create_feature(feature)})\n",
- " )\n",
- "\n",
- " def create_feature(self, element: numpy.float32) -> tf.train.Feature:\n",
- " return tf.train.Feature(float_list=tf.train.FloatList(value=[element]))\n",
- "\n",
- "# Create a labeled example file for the 5 times table.\n",
- "\n",
- "example_five_times_table = 'example_five_times_table.tfrecord'\n",
- "\n",
- "with tf.io.TFRecordWriter(example_five_times_table) as writer:\n",
- " for feature, label in zip(x, y):\n",
- " example = ExampleProcessor().create_example_with_label(\n",
- " feature=feature, label=label)\n",
- " writer.write(example.SerializeToString())\n",
- "\n",
- "# Create a file containing the values to predict.\n",
- "\n",
- "predict_values_five_times_table = 'predict_values_five_times_table.tfrecord'\n",
- "\n",
- "with tf.io.TFRecordWriter(predict_values_five_times_table) as writer:\n",
- " for i in value_to_predict:\n",
- " example = ExampleProcessor().create_example(feature=i)\n",
- " writer.write(example.SerializeToString())"
- ],
- "execution_count": 7,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Fit the model\n",
- "\n",
- "This step trains the model on the examples in the TFRecord file. Because RunInference requires a pretrained model, this step produces a usable model."
- ],
- "metadata": {
- "id": "G-sAu3cf31f3"
- }
- },
- {
- "cell_type": "code",
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "AnbrxXPKeAOQ",
- "outputId": "42439aac-3a10-4e86-829f-44332aad6173"
- },
- "source": [
- "RAW_DATA_TRAIN_SPEC = {\n",
- "'x': tf.io.FixedLenFeature([], tf.float32),\n",
- "'y': tf.io.FixedLenFeature([], tf.float32)\n",
- "}\n",
- "\n",
- "dataset = tf.data.TFRecordDataset(example_five_times_table)\n",
- "dataset = dataset.map(lambda e: tf.io.parse_example(e, RAW_DATA_TRAIN_SPEC))\n",
- "dataset = dataset.map(lambda t: (t['x'], t['y']))\n",
- "dataset = dataset.batch(100)\n",
- "dataset = dataset.repeat()\n",
- "\n",
- "model.fit(dataset, epochs=5000, steps_per_epoch=1, verbose=0)"
- ],
- "execution_count": 8,
- "outputs": [
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "<keras.callbacks.History at 0x7f6960074c70>"
- ]
- },
- "metadata": {},
- "execution_count": 8
- }
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Save the model\n",
- "\n",
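- "This step saves the model with a serving signature that accepts serialized `tf.train.Example` protos, which is the input type that `tfx-bsl` model handlers support."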
- ],
- "metadata": {
- "id": "r4dpR6dQ4JwX"
- }
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "fYvrIYO3qiJx"
- },
- "source": [
- "RAW_DATA_PREDICT_SPEC = {\n",
- "'x': tf.io.FixedLenFeature([], tf.float32),\n",
- "}\n",
- "\n",
- "# tf.function compiles the function into a callable TF graph.\n",
- "# RunInference relies on calling a TF graph as a model.\n",
- "# Note: The input signature must be of type tf.string, which is the type\n",
- "# that tfx-bsl ModelHandlers support.\n",
- "@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')])\n",
- "def serve_tf_examples_fn(serialized_tf_examples):\n",
- " \"\"\"Returns the output to be used in the serving signature.\"\"\"\n",
- " features = tf.io.parse_example(serialized_tf_examples, RAW_DATA_PREDICT_SPEC)\n",
- " return model(features, training=False)\n",
- "\n",
- "signature = {'serving_default': serve_tf_examples_fn}\n",
- "\n",
- "# Signatures define the input and output types for a computation. The optional\n",
- "# signatures argument controls which methods of the saved object are available\n",
- "# to programs that consume SavedModels, for example, serving APIs.\n",
- "# See https://www.tensorflow.org/api_docs/python/tf/saved_model/save\n",
- "tf.keras.models.save_model(model, save_model_dir_multiply, signatures=signature)"
- ],
- "execution_count": 9,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Run the pipeline\n",
- "Use the following code to run the pipeline.\n",
- "\n",
- "`FormatOutput` demonstrates how to extract values from the output protos.\n",
- "\n",
- "`CreateModelHandler` creates the model handler that you pass to the Apache Beam RunInference API."
- ],
- "metadata": {
- "id": "P2UMmbNW4YQV"
- }
- },
- {
- "cell_type": "code",
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 193
- },
- "id": "PzjmXM_KvqHY",
- "outputId": "0aa60bef-52a0-4ce2-d228-3fac977d59e0"
- },
- "source": [
- "from tfx_bsl.public.beam.run_inference import CreateModelHandler\n",
- "\n",
- "class FormatOutput(beam.DoFn):\n",
- " def process(self, element: prediction_log_pb2.PredictionLog):\n",
- " predict_log = element.predict_log\n",
- " input_value = tf.train.Example.FromString(predict_log.request.inputs['examples'].string_val[0])\n",
- " input_float_value = input_value.features.feature['x'].float_list.value[0]\n",
- " output_value = predict_log.response.outputs\n",
- " output_float_value = output_value['output_0'].float_val[0]\n",
- " yield (f\"example is {input_float_value:.2f} prediction is {output_float_value:.2f}\")\n",
- "\n",
- "tfexample_beam_record = tfx_bsl.public.tfxio.TFExampleRecord(file_pattern=predict_values_five_times_table)\n",
- "saved_model_spec = model_spec_pb2.SavedModelSpec(model_path=save_model_dir_multiply)\n",
- "inference_spec_type = model_spec_pb2.InferenceSpecType(saved_model_spec=saved_model_spec)\n",
- "model_handler = CreateModelHandler(inference_spec_type)\n",
- "with beam.Pipeline() as p:\n",
- " _ = (p | tfexample_beam_record.RawRecordBeamSource()\n",
- " | RunInference(model_handler)\n",
- " | beam.ParDo(FormatOutput())\n",
- " | beam.Map(print)\n",
- " )"
- ],
- "execution_count": 10,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stderr",
- "text": [
- "WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.\n"
- ]
- },
- {
- "output_type": "stream",
- "name": "stderr",
- "text": [
- "WARNING:tensorflow:From /usr/local/lib/python3.9/dist-packages/tfx_bsl/beam/run_inference.py:615: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.\n",
- "Instructions for updating:\n",
- "Use `tf.saved_model.load` instead.\n",
- "WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.\n"
- ]
- },
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "example is 20.00 prediction is 104.36\n",
- "example is 40.00 prediction is 202.62\n",
- "example is 60.00 prediction is 300.87\n",
- "example is 90.00 prediction is 448.26\n"
- ]
- }
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "## KeyedModelHandler with TensorFlow using `tfx-bsl`\n",
- "\n",
- "By default, the `ModelHandler` does not expect a key.\n",
- "\n",
- "* If you know that keys are associated with your examples, wrap the model handler with `KeyedModelHandler` from `apache_beam.ml.inference.base`.\n",
- "* If you don't know whether keys are associated with your examples, use `MaybeKeyedModelHandler` from the same module.\n",
- "\n",
- "In addition to demonstrating how to use a keyed model handler, this step demonstrates how to create `tf.train.Example` inputs in memory with `text_format` instead of reading them from a TFRecord file."
- ],
- "metadata": {
- "id": "IXikjkGdHm9n"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "from typing import Tuple\n",
- "\n",
- "from apache_beam.ml.inference.base import KeyedModelHandler\n",
- "from google.protobuf import text_format\n",
- "import tensorflow as tf\n",
- "\n",
- "class FormatOutputKeyed(FormatOutput):\n",
- " # To simplify, inherit from FormatOutput.\n",
- " def process(self, tuple_in: Tuple):\n",
- " key, element = tuple_in\n",
- " output = super().process(element)\n",
- " yield ' : '.join([key, next(output)])\n",
- "\n",
- "def make_example(num):\n",
- " # Return keyed values in the form of (key num, example).\n",
- " key = f'key {num}'\n",
- " tf_proto = text_format.Parse(\n",
- " \"\"\"\n",
- " features {\n",
- " feature {key: \"x\" value { float_list { value: %f }}}\n",
- " }\n",
- " \"\"\" % num, tf.train.Example())\n",
- " return (key, tf_proto)\n",
- "\n",
- "# Make a list of examples of type tf.train.Example.\n",
- "examples = [\n",
- " make_example(5.0),\n",
- " make_example(50.0),\n",
- " make_example(40.0),\n",
- " make_example(100.0)\n",
- "]\n",
- "\n",
- "saved_model_spec = model_spec_pb2.SavedModelSpec(model_path=save_model_dir_multiply)\n",
- "inference_spec_type = model_spec_pb2.InferenceSpecType(saved_model_spec=saved_model_spec)\n",
- "keyed_model_handler = KeyedModelHandler(CreateModelHandler(inference_spec_type))\n",
- "with beam.Pipeline() as p:\n",
- " _ = (p | 'CreateExamples' >> beam.Create(examples)\n",
- " | RunInference(keyed_model_handler)\n",
- " | beam.ParDo(FormatOutputKeyed())\n",
- " | beam.Map(print)\n",
- " )"
- ],
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "KPtE3fmdJQry",
- "outputId": "c33558fc-fb12-4c20-b828-b5520721f279"
- },
- "execution_count": 11,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "key 5.0 : example is 5.00 prediction is 30.67\n",
- "key 50.0 : example is 50.00 prediction is 251.75\n",
- "key 40.0 : example is 40.00 prediction is 202.62\n",
- "key 100.0 : example is 100.00 prediction is 497.38\n"
- ]
- }
- ]
- }
- ]
-}
\ No newline at end of file
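The keyed model handler section of the removed notebook describes `KeyedModelHandler` only in prose. As a rough, framework-free sketch of the idea (not Beam's implementation), the following Python strips the keys from `(key, example)` pairs, makes one batched model call, and re-attaches each key to its prediction. `toy_model` and `keyed_inference` are hypothetical names, and `toy_model` assumes the ideal 5 times table rather than the trained weights, so its outputs differ from the predictions logged in the notebook.

```python
from typing import Callable, Iterable, List, Sequence, Tuple


def toy_model(batch: Sequence[float]) -> List[float]:
    # Hypothetical stand-in for the trained Keras model: the exact 5 times
    # table that the notebook's linear regression only approximates.
    return [5.0 * x for x in batch]


def keyed_inference(
    keyed_examples: Iterable[Tuple[str, float]],
    model: Callable[[Sequence[float]], List[float]],
) -> List[Tuple[str, float]]:
    """Hypothetical helper: strip keys, run one batched call, re-attach keys."""
    pairs = list(keyed_examples)
    keys = [key for key, _ in pairs]
    batch = [example for _, example in pairs]
    # The model only ever sees unkeyed inputs.
    predictions = model(batch)
    # zip preserves order, so each prediction is paired with its own key.
    return list(zip(keys, predictions))


# Keys built the same way as make_example() in the notebook: f'key {num}'.
examples = [(f'key {num}', num) for num in (5.0, 50.0, 40.0, 100.0)]

if __name__ == '__main__':
    for key, prediction in keyed_inference(examples, toy_model):
        print(f'{key} : prediction is {prediction:.2f}')
```

Because the model receives only the unkeyed batch and `zip` preserves order, each key stays attached to the prediction computed from its own example, which mirrors how the keyed handler pairs keys with results inside a pipeline.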