Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2022/09/13 14:49:41 UTC

[GitHub] [beam] damccorm commented on a diff in pull request #23173: Add a tensorflow example to the run_inference_basic notebook

damccorm commented on code in PR #23173:
URL: https://github.com/apache/beam/pull/23173#discussion_r969724448


##########
examples/notebooks/beam-ml/run_inference_basic.ipynb:
##########
@@ -1362,6 +1308,408 @@
         "      | beam.Map(print)\n",
         "  )\n"
       ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# TensorFlow - Examples\n",
+        "\n",
+        "[TFX-BSL](https://github.com/tensorflow/tfx-bsl) offers an API that returns a ModelHandler that integrates with beam and works with TensorFlow models.\n",
+        "\n",
+        "It also works with TensorFlow models hosted on VertexAI."
+      ],
+      "metadata": {
+        "id": "F_vhQJTgcc_y"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# The TF ModelHandler is implemented in tfx_bsl 1.10 and later.\n",
+        "!pip install tfx_bsl==1.10.0 --quiet\n",
+        "#NOTE: You may need to restart your runtime."
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "a2FoAA-feV5L",
+        "outputId": "638e06bd-a057-40ec-d263-4fb87568ef4f"
+      },
+      "execution_count": 1,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "\u001b[K     |████████████████████████████████| 21.6 MB 2.4 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 10.9 MB 29.1 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 578.0 MB 11 kB/s \n",
+            "\u001b[K     |████████████████████████████████| 2.4 MB 40.8 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 270 kB 60.2 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 508 kB 35.2 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 47 kB 3.7 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 151 kB 54.8 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 62 kB 1.4 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 124 kB 64.1 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 180 kB 38.9 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 83 kB 1.8 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 183 kB 52.2 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 255 kB 66.0 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 435 kB 52.7 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 173 kB 59.7 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 267 kB 55.2 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 267 kB 63.3 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 148 kB 57.6 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 235 kB 57.7 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 1.0 MB 51.5 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 119 kB 62.0 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 119 kB 53.5 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 118 kB 45.3 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 234 kB 62.3 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 234 kB 58.1 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 234 kB 56.7 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 234 kB 66.2 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 234 kB 36.8 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 265 kB 56.4 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 148 kB 55.4 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 147 kB 65.9 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 4.6 MB 35.9 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 5.9 MB 43.6 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 1.1 MB 50.5 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 438 kB 49.5 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 1.7 MB 48.0 MB/s \n",
+            "\u001b[?25h  Building wheel for dill (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+            "  Building wheel for google-apitools (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+            "  Building wheel for docopt (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Authenticate\n",
+        "\n",
+        "Authorize yourself so that you can save and load models."
+      ],
+      "metadata": {
+        "id": "8RHZxqB1HLlr"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from google.colab import auth\n",
+        "auth.authenticate_user()"
+      ],
+      "metadata": {
+        "id": "qDdGKYOoevNb"
+      },
+      "execution_count": 2,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import argparse\n",
+        "\n",
+        "import tensorflow as tf\n",
+        "from tensorflow import keras\n",
+        "from tensorflow_serving.apis import prediction_log_pb2\n",
+        "\n",
+        "import apache_beam as beam\n",
+        "from apache_beam.ml.inference.base import RunInference\n",
+        "import tfx_bsl\n",
+        "from tfx_bsl.public.beam.run_inference import CreateModelHandler\n",
+        "from tfx_bsl.public import tfxio\n",
+        "from tfx_bsl.public.proto import model_spec_pb2\n",
+        "from tensorflow_metadata.proto.v0 import schema_pb2\n",
+        "\n",
+        "import numpy\n",
+        "\n",
+        "from typing import Dict, Text, Any, Tuple, List\n",
+        "\n",
+        "from apache_beam.options.pipeline_options import PipelineOptions\n",
+        "\n",
+        "project = \"<Your Project>\"\n",
+        "bucket = \"<Your Bucket>\"\n",
+        "\n",
+        "save_model_dir_multiply = f'gs://{bucket}/tfx-inference/model/multiply_five/v1/'\n",
+        "\n",
+        "# NOTE: If you get a ContextualVersionConflict error, restart your runtime\n",
+        "# environment (under the menu Runtime->Restart Runtime)\n"
+      ],
+      "metadata": {
+        "id": "zSashIqde3Ar"
+      },
+      "execution_count": 2,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Create and Test a Simple Model\n",
+        "\n",
+        "This creates a model that predicts the 5 times table."
+      ],
+      "metadata": {
+        "id": "ohPytRVDHc94"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Create training data which represents the 5 times multiplication table for 0 to 99. x is the data and y the labels. \n",
+        "x = numpy.arange(0, 100)\n",
+        "y = x * 5\n",
+        "\n",
+        "\n",
+        "# Build a simple linear regression model.\n",
+        "# Note that the input layer has shape (1) and dtype float32, so the model expects a single float32 value.\n",
+        "\n",
+        "input_layer = keras.layers.Input(shape=(1), dtype=tf.float32, name='x')\n",
+        "output_layer= keras.layers.Dense(1)(input_layer)\n",
+        "\n",
+        "model = keras.Model(input_layer, output_layer)\n",
+        "model.compile(optimizer=tf.optimizers.Adam(), loss='mean_absolute_error')\n",
+        "\n",
+        "# Test the model\n",
+        "model.fit(x, y, epochs=4000, verbose=0)\n",
+        "value_to_predict = numpy.array([20, 40, 60, 90], dtype=numpy.float32)\n",
+        "model.predict(value_to_predict)"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "RmzuFB1ymqdP",
+        "outputId": "bec9a1d1-3d90-4b5f-9819-d20a7644adce"
+      },
+      "execution_count": 7,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "1/1 [==============================] - 0s 56ms/step\n"
+          ]
+        },
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "array([[100.00245],\n",
+              "       [200.0051 ],\n",
+              "       [300.00775],\n",
+              "       [450.01172]], dtype=float32)"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 7
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### Populate the Data into a TensorFlow Proto\n",

Review Comment:
   I would cut this section and the `Fit the Model` section. Right now this TensorFlow section spends four sections on training a TF model and only one on actually using RunInference, which is probably backwards from what we want. I'd assume people know how to train a model (or we can point them to docs that teach it).



##########
examples/notebooks/beam-ml/run_inference_basic.ipynb:
##########
@@ -1362,6 +1308,408 @@
         "      | beam.Map(print)\n",
         "  )\n"
       ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# TensorFlow - Examples\n",

Review Comment:
   Nit: Right now we have all 3 frameworks labeled differently (`# PyTorch`, `# Sklearn implementation of RunInference API.`, `# TensorFlow - Examples`). I don't really care which we choose, but can we consolidate on a pattern?



##########
examples/notebooks/beam-ml/run_inference_basic.ipynb:
##########
@@ -1362,6 +1308,408 @@
         "      | beam.Map(print)\n",
         "  )\n"
       ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# TensorFlow - Examples\n",
+        "\n",
+        "[TFX-BSL](https://github.com/tensorflow/tfx-bsl) offers an API that returns a ModelHandler that integrates with beam and works with TensorFlow models.\n",
+        "\n",
+        "It also works with TensorFlow models hosted on VertexAI."
+      ],
+      "metadata": {
+        "id": "F_vhQJTgcc_y"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# The TF ModelHandler is implemented in tfx_bsl 1.10 and later.\n",
+        "!pip install tfx_bsl==1.10.0 --quiet\n",
+        "#NOTE: You may need to restart your runtime."
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "a2FoAA-feV5L",
+        "outputId": "638e06bd-a057-40ec-d263-4fb87568ef4f"
+      },
+      "execution_count": 1,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "\u001b[K     |████████████████████████████████| 21.6 MB 2.4 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 10.9 MB 29.1 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 578.0 MB 11 kB/s \n",
+            "\u001b[K     |████████████████████████████████| 2.4 MB 40.8 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 270 kB 60.2 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 508 kB 35.2 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 47 kB 3.7 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 151 kB 54.8 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 62 kB 1.4 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 124 kB 64.1 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 180 kB 38.9 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 83 kB 1.8 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 183 kB 52.2 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 255 kB 66.0 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 435 kB 52.7 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 173 kB 59.7 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 267 kB 55.2 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 267 kB 63.3 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 148 kB 57.6 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 235 kB 57.7 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 1.0 MB 51.5 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 119 kB 62.0 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 119 kB 53.5 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 118 kB 45.3 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 234 kB 62.3 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 234 kB 58.1 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 234 kB 56.7 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 234 kB 66.2 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 234 kB 36.8 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 265 kB 56.4 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 148 kB 55.4 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 147 kB 65.9 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 4.6 MB 35.9 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 5.9 MB 43.6 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 1.1 MB 50.5 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 438 kB 49.5 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 1.7 MB 48.0 MB/s \n",
+            "\u001b[?25h  Building wheel for dill (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+            "  Building wheel for google-apitools (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+            "  Building wheel for docopt (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Authenticate\n",
+        "\n",
+        "Authorize yourself so that you can save and load models."
+      ],
+      "metadata": {
+        "id": "8RHZxqB1HLlr"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from google.colab import auth\n",
+        "auth.authenticate_user()"
+      ],
+      "metadata": {
+        "id": "qDdGKYOoevNb"
+      },
+      "execution_count": 2,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import argparse\n",
+        "\n",
+        "import tensorflow as tf\n",
+        "from tensorflow import keras\n",
+        "from tensorflow_serving.apis import prediction_log_pb2\n",
+        "\n",
+        "import apache_beam as beam\n",
+        "from apache_beam.ml.inference.base import RunInference\n",
+        "import tfx_bsl\n",
+        "from tfx_bsl.public.beam.run_inference import CreateModelHandler\n",
+        "from tfx_bsl.public import tfxio\n",
+        "from tfx_bsl.public.proto import model_spec_pb2\n",
+        "from tensorflow_metadata.proto.v0 import schema_pb2\n",
+        "\n",
+        "import numpy\n",
+        "\n",
+        "from typing import Dict, Text, Any, Tuple, List\n",
+        "\n",
+        "from apache_beam.options.pipeline_options import PipelineOptions\n",
+        "\n",
+        "project = \"<Your Project>\"\n",
+        "bucket = \"<Your Bucket>\"\n",
+        "\n",
+        "save_model_dir_multiply = f'gs://{bucket}/tfx-inference/model/multiply_five/v1/'\n",
+        "\n",
+        "# NOTE: If you get a ContextualVersionConflict error, restart your runtime\n",
+        "# environment (under the menu Runtime->Restart Runtime)\n"
+      ],
+      "metadata": {
+        "id": "zSashIqde3Ar"
+      },
+      "execution_count": 2,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Create and Test a Simple Model\n",
+        "\n",
+        "This creates a model that predicts the 5 times table."
+      ],
+      "metadata": {
+        "id": "ohPytRVDHc94"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Create training data which represents the 5 times multiplication table for 0 to 99. x is the data and y the labels. \n",
+        "x = numpy.arange(0, 100)\n",
+        "y = x * 5\n",
+        "\n",
+        "\n",
+        "# Build a simple linear regression model.\n",
+        "# Note that the input layer has shape (1) and dtype float32, so the model expects a single float32 value.\n",
+        "\n",
+        "input_layer = keras.layers.Input(shape=(1), dtype=tf.float32, name='x')\n",
+        "output_layer= keras.layers.Dense(1)(input_layer)\n",
+        "\n",
+        "model = keras.Model(input_layer, output_layer)\n",
+        "model.compile(optimizer=tf.optimizers.Adam(), loss='mean_absolute_error')\n",
+        "\n",
+        "# Test the model\n",
+        "model.fit(x, y, epochs=4000, verbose=0)\n",
+        "value_to_predict = numpy.array([20, 40, 60, 90], dtype=numpy.float32)\n",
+        "model.predict(value_to_predict)"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "RmzuFB1ymqdP",
+        "outputId": "bec9a1d1-3d90-4b5f-9819-d20a7644adce"
+      },
+      "execution_count": 7,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "1/1 [==============================] - 0s 56ms/step\n"
+          ]
+        },
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "array([[100.00245],\n",
+              "       [200.0051 ],\n",
+              "       [300.00775],\n",
+              "       [450.01172]], dtype=float32)"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 7
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### Populate the Data into a TensorFlow Proto\n",
+        "\n",
+        "Tensorflow data uses protos. If you are loading from a file there are helpers for this. Since we are using generated data, this code populates a proto."
+      ],
+      "metadata": {
+        "id": "y3Z68x9qI3Jt"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# This is an example of a proto that converts the samples and labels into\n",
+        "# tensors usable by tensorflow.\n",
+        "class ExampleProcessor:\n",
+        "    def create_example_with_label(self, feature: numpy.float32,\n",
+        "                             label: numpy.float32)-> tf.train.Example:\n",
+        "        return tf.train.Example(\n",
+        "            features=tf.train.Features(\n",
+        "                  feature={'x': self.create_feature(feature),\n",
+        "                           'y' : self.create_feature(label)\n",
+        "                  }))\n",
+        "\n",
+        "    def create_example(self, feature: numpy.float32):\n",
+        "        return tf.train.Example(\n",
+        "            features=tf.train.Features(\n",
+        "                  feature={'x' : self.create_feature(feature)})\n",
+        "            )\n",
+        "\n",
+        "    def create_feature(self, element: numpy.float32):\n",
+        "        return tf.train.Feature(float_list=tf.train.FloatList(value=[element]))\n",
+        "\n",
+        "# Create our labeled example file for 5 times table\n",
+        "\n",
+        "example_five_times_table = 'example_five_times_table.tfrecord'\n",
+        "\n",
+        "with tf.io.TFRecordWriter(example_five_times_table) as writer:\n",
+        "  for i in zip(x, y):\n",
+        "    example = ExampleProcessor().create_example_with_label(\n",
+        "        feature=i[0], label=i[1])\n",
+        "    writer.write(example.SerializeToString())\n",
+        "\n",
+        "# Create a file containing the values to predict\n",
+        "\n",
+        "predict_values_five_times_table = 'predict_values_five_times_table.tfrecord'\n",
+        "\n",
+        "with tf.io.TFRecordWriter(predict_values_five_times_table) as writer:\n",
+        "  for i in value_to_predict:\n",
+        "    example = ExampleProcessor().create_example(feature=i)\n",
+        "    writer.write(example.SerializeToString())\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "t2lQIfeonS0Q"
+      },
+      "execution_count": 8,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### Fit The Model\n",
+        "\n",
+        "This example builds a model. However, normal RunInference usage is for pretrained models."
+      ],
+      "metadata": {
+        "id": "qN5RobQ2JN3O"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Build the training dataset.\n",
+        "RAW_DATA_TRAIN_SPEC = {\n",
+        "'x': tf.io.FixedLenFeature([], tf.float32),\n",
+        "'y': tf.io.FixedLenFeature([], tf.float32)\n",
+        "}\n",
+        "\n",
+        "RAW_DATA_PREDICT_SPEC = {\n",
+        "'x': tf.io.FixedLenFeature([], tf.float32),\n",
+        "}\n",
+        "\n",
+        "dataset = tf.data.TFRecordDataset(example_five_times_table)\n",
+        "dataset = dataset.map(lambda e : tf.io.parse_example(e, RAW_DATA_TRAIN_SPEC))\n",
+        "dataset = dataset.map(lambda t : (t['x'], t['y']))\n",
+        "dataset = dataset.batch(100)\n",
+        "dataset = dataset.repeat()\n",
+        "\n",
+        "# Fit the model\n",
+        "model.fit(dataset, epochs=500, steps_per_epoch=1, verbose=0)"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "03OVIGDtn7WI",
+        "outputId": "b1eef9a4-b5db-42e3-8f8a-fd3da71f8c38"
+      },
+      "execution_count": 13,
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "<keras.callbacks.History at 0x7f5a3f8eebd0>"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 13
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### Save the Model"
+      ],
+      "metadata": {
+        "id": "KtrqJ6i2JmAS"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string , name='examples')])\n",
+        "def serve_tf_examples_fn(serialized_tf_examples):\n",
+        "  \"\"\"Returns the output to be used in the serving signature.\"\"\"\n",
+        "  features = tf.io.parse_example(serialized_tf_examples, RAW_DATA_PREDICT_SPEC)\n",
+        "  return model(features, training=False)\n",
+        "\n",
+        "signature = {'serving_default': serve_tf_examples_fn}\n",
+        "\n",
+        "tf.keras.models.save_model(model, save_model_dir_multiply, signatures=signature)"
+      ],
+      "metadata": {
+        "id": "pscBzhsakGnr"
+      },
+      "execution_count": 17,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Run the Pipeline\n",
+        "\n",
+        "PrintNicely demonstrates how to extract values from the output protos.\n",
+        "\n",
+        "CreateModelHandler demonstrates the model handler that needs to be passed into beams RunInference API."

Review Comment:
   ```suggestion
           "CreateModelHandler demonstrates the model handler that needs to be passed into Beam's RunInference API."
   ```



##########
examples/notebooks/beam-ml/run_inference_basic.ipynb:
##########
@@ -1362,6 +1308,408 @@
         "      | beam.Map(print)\n",
         "  )\n"
       ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# TensorFlow - Examples\n",
+        "\n",
+        "[TFX-BSL](https://github.com/tensorflow/tfx-bsl) offers an API that returns a ModelHandler that integrates with beam and works with TensorFlow models.\n",
+        "\n",
+        "It also works with TensorFlow models hosted on VertexAI."
+      ],
+      "metadata": {
+        "id": "F_vhQJTgcc_y"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# The TF ModelHandler is implemented in tfx_bsl 1.10 and later.\n",
+        "!pip install tfx_bsl==1.10.0 --quiet\n",
+        "#NOTE: You may need to restart your runtime."
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "a2FoAA-feV5L",
+        "outputId": "638e06bd-a057-40ec-d263-4fb87568ef4f"
+      },
+      "execution_count": 1,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "\u001b[K     |████████████████████████████████| 21.6 MB 2.4 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 10.9 MB 29.1 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 578.0 MB 11 kB/s \n",
+            "\u001b[K     |████████████████████████████████| 2.4 MB 40.8 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 270 kB 60.2 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 508 kB 35.2 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 47 kB 3.7 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 151 kB 54.8 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 62 kB 1.4 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 124 kB 64.1 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 180 kB 38.9 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 83 kB 1.8 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 183 kB 52.2 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 255 kB 66.0 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 435 kB 52.7 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 173 kB 59.7 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 267 kB 55.2 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 267 kB 63.3 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 148 kB 57.6 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 235 kB 57.7 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 1.0 MB 51.5 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 119 kB 62.0 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 119 kB 53.5 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 118 kB 45.3 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 234 kB 62.3 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 234 kB 58.1 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 234 kB 56.7 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 234 kB 66.2 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 234 kB 36.8 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 265 kB 56.4 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 148 kB 55.4 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 147 kB 65.9 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 4.6 MB 35.9 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 5.9 MB 43.6 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 1.1 MB 50.5 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 438 kB 49.5 MB/s \n",
+            "\u001b[K     |████████████████████████████████| 1.7 MB 48.0 MB/s \n",
+            "\u001b[?25h  Building wheel for dill (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+            "  Building wheel for google-apitools (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+            "  Building wheel for docopt (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Authenticate\n",
+        "\n",
+        "Authorize yourself so that you can save and load models."
+      ],
+      "metadata": {
+        "id": "8RHZxqB1HLlr"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from google.colab import auth\n",
+        "auth.authenticate_user()"
+      ],
+      "metadata": {
+        "id": "qDdGKYOoevNb"
+      },
+      "execution_count": 2,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import argparse\n",
+        "\n",
+        "import tensorflow as tf\n",
+        "from tensorflow import keras\n",
+        "from tensorflow_serving.apis import prediction_log_pb2\n",
+        "\n",
+        "import apache_beam as beam\n",
+        "from apache_beam.ml.inference.base import RunInference\n",
+        "import tfx_bsl\n",
+        "from tfx_bsl.public.beam.run_inference import CreateModelHandler\n",
+        "from tfx_bsl.public import tfxio\n",
+        "from tfx_bsl.public.proto import model_spec_pb2\n",
+        "from tensorflow_metadata.proto.v0 import schema_pb2\n",
+        "\n",
+        "import numpy\n",
+        "\n",
+        "from typing import Dict, Text, Any, Tuple, List\n",
+        "\n",
+        "from apache_beam.options.pipeline_options import PipelineOptions\n",
+        "\n",
+        "project = \"<Your Project>\"\n",
+        "bucket = \"<Your Bucket>\"\n",
+        "\n",
+        "save_model_dir_multiply = f'gs://{bucket}/tfx-inference/model/multiply_five/v1/'\n",
+        "\n",
+        "# NOTE: If you get a ContextualVersionConflict error, restart your runtime\n",
+        "# environment (under the menu Runtime->Restart Runtime)\n"
+      ],
+      "metadata": {
+        "id": "zSashIqde3Ar"
+      },
+      "execution_count": 2,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Create and Test a Simple Model\n",
+        "\n",
+        "This creates a model that predicts the 5 times table."
+      ],
+      "metadata": {
+        "id": "ohPytRVDHc94"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Create training data which represents the 5 times multiplication table for 0 to 99. x is the data and y the labels. \n",
+        "x = numpy.arange(0, 100)\n",
+        "y = x * 5\n",
+        "\n",
+        "\n",
+        "# Build a simple linear regression model.\n",
+        "# Note that the input layer has shape (1) and dtype float32, so the model expects a single float32 value.\n",
+        "\n",
+        "input_layer = keras.layers.Input(shape=(1), dtype=tf.float32, name='x')\n",
+        "output_layer= keras.layers.Dense(1)(input_layer)\n",
+        "\n",
+        "model = keras.Model(input_layer, output_layer)\n",
+        "model.compile(optimizer=tf.optimizers.Adam(), loss='mean_absolute_error')\n",
+        "\n",
+        "# Test the model\n",
+        "model.fit(x, y, epochs=4000, verbose=0)\n",
+        "value_to_predict = numpy.array([20, 40, 60, 90], dtype=numpy.float32)\n",
+        "model.predict(value_to_predict)"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "RmzuFB1ymqdP",
+        "outputId": "bec9a1d1-3d90-4b5f-9819-d20a7644adce"
+      },
+      "execution_count": 7,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "1/1 [==============================] - 0s 56ms/step\n"
+          ]
+        },
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "array([[100.00245],\n",
+              "       [200.0051 ],\n",
+              "       [300.00775],\n",
+              "       [450.01172]], dtype=float32)"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 7
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### Populate the Data into a TensorFlow Proto\n",
+        "\n",
+        "TensorFlow represents examples as protos (`tf.train.Example`). If you are loading data from a file, there are helpers for this. Because this notebook uses generated data, the following code populates the protos directly."
+      ],
+      "metadata": {
+        "id": "y3Z68x9qI3Jt"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Helper class that wraps the samples and labels in tf.train.Example protos\n",
+        "# that TensorFlow can parse into tensors.\n",
+        "class ExampleProcessor:\n",
+        "    def create_example_with_label(self, feature: numpy.float32,\n",
+        "                             label: numpy.float32)-> tf.train.Example:\n",
+        "        return tf.train.Example(\n",
+        "            features=tf.train.Features(\n",
+        "                  feature={'x': self.create_feature(feature),\n",
+        "                           'y' : self.create_feature(label)\n",
+        "                  }))\n",
+        "\n",
+        "    def create_example(self, feature: numpy.float32):\n",
+        "        return tf.train.Example(\n",
+        "            features=tf.train.Features(\n",
+        "                  feature={'x' : self.create_feature(feature)})\n",
+        "            )\n",
+        "\n",
+        "    def create_feature(self, element: numpy.float32):\n",
+        "        return tf.train.Feature(float_list=tf.train.FloatList(value=[element]))\n",
+        "\n",
+        "# Create our labeled example file for 5 times table\n",
+        "\n",
+        "example_five_times_table = 'example_five_times_table.tfrecord'\n",
+        "\n",
+        "with tf.io.TFRecordWriter(example_five_times_table) as writer:\n",
+        "  for i in zip(x, y):\n",
+        "    example = ExampleProcessor().create_example_with_label(\n",
+        "        feature=i[0], label=i[1])\n",
+        "    writer.write(example.SerializeToString())\n",
+        "\n",
+        "# Create a file containing the values to predict\n",
+        "\n",
+        "predict_values_five_times_table = 'predict_values_five_times_table.tfrecord'\n",
+        "\n",
+        "with tf.io.TFRecordWriter(predict_values_five_times_table) as writer:\n",
+        "  for i in value_to_predict:\n",
+        "    example = ExampleProcessor().create_example(feature=i)\n",
+        "    writer.write(example.SerializeToString())\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "t2lQIfeonS0Q"
+      },
+      "execution_count": 8,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### Fit the Model\n",
+        "\n",
+        "This example trains a model inside the notebook. In typical RunInference usage, you would load a pretrained model instead."
+      ],
+      "metadata": {
+        "id": "qN5RobQ2JN3O"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Build the training dataset.\n",
+        "RAW_DATA_TRAIN_SPEC = {\n",
+        "'x': tf.io.FixedLenFeature([], tf.float32),\n",
+        "'y': tf.io.FixedLenFeature([], tf.float32)\n",
+        "}\n",
+        "\n",
+        "RAW_DATA_PREDICT_SPEC = {\n",
+        "'x': tf.io.FixedLenFeature([], tf.float32),\n",
+        "}\n",
+        "\n",
+        "dataset = tf.data.TFRecordDataset(example_five_times_table)\n",
+        "dataset = dataset.map(lambda e : tf.io.parse_example(e, RAW_DATA_TRAIN_SPEC))\n",
+        "dataset = dataset.map(lambda t : (t['x'], t['y']))\n",
+        "dataset = dataset.batch(100)\n",
+        "dataset = dataset.repeat()\n",
+        "\n",
+        "# Fit the model\n",
+        "model.fit(dataset, epochs=500, steps_per_epoch=1, verbose=0)"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "03OVIGDtn7WI",
+        "outputId": "b1eef9a4-b5db-42e3-8f8a-fd3da71f8c38"
+      },
+      "execution_count": 13,
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "<keras.callbacks.History at 0x7f5a3f8eebd0>"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 13
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### Save the Model"
+      ],
+      "metadata": {
+        "id": "KtrqJ6i2JmAS"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string , name='examples')])\n",
+        "def serve_tf_examples_fn(serialized_tf_examples):\n",
+        "  \"\"\"Returns the output to be used in the serving signature.\"\"\"\n",
+        "  features = tf.io.parse_example(serialized_tf_examples, RAW_DATA_PREDICT_SPEC)\n",
+        "  return model(features, training=False)\n",
+        "\n",
+        "signature = {'serving_default': serve_tf_examples_fn}\n",
+        "\n",
+        "tf.keras.models.save_model(model, save_model_dir_multiply, signatures=signature)"
+      ],
+      "metadata": {
+        "id": "pscBzhsakGnr"
+      },
+      "execution_count": 17,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Run the Pipeline\n",

Review Comment:
   I'd love to see another section or 2 on common patterns (this could mirror PyTorch a little bit, though it doesn't have to cover the same ground - e.g. an A/B example could be neat, though it doesn't need to be that specifically.)
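
As a rough illustration of the A/B idea only, here is a minimal sketch of the routing step. It assumes elements are split deterministically by hashing their value. The name `ab_partition` is hypothetical and is not part of Beam or tfx_bsl. The function uses the `(element, num_partitions)` signature that `beam.Partition` expects, so in a pipeline each resulting branch could feed its own RunInference step with a different model; that wiring is omitted here and shown only with plain Python lists.

```python
import hashlib


def ab_partition(element, num_partitions=2):
    """Hypothetical helper: deterministically assign an element to an arm.

    Arm 0 would go to model A and arm 1 to model B. Hashing the element
    (instead of using random numbers) keeps the assignment stable across
    retries and reruns.
    """
    digest = hashlib.sha256(str(element).encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") % num_partitions


# Plain-Python stand-in for the two pipeline branches.
values_to_predict = [20.0, 40.0, 60.0, 90.0]
arms = {0: [], 1: []}
for value in values_to_predict:
    arms[ab_partition(value)].append(value)

# In a Beam pipeline, the assumed equivalent would be roughly:
#   arm_a, arm_b = examples | beam.Partition(ab_partition, 2)
# followed by a separate RunInference step on each branch.
```

Hashing is just one possible choice; a split based on a user or session key would work the same way, as long as the assignment is deterministic.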


