Posted to commits@beam.apache.org by ri...@apache.org on 2024/02/14 23:50:55 UTC

(beam) branch master updated: [Docs] Add example notebook for enrichment transform with bigtable (#30315)

This is an automated email from the ASF dual-hosted git repository.

riteshghorse pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 6377cbe20cc [Docs] Add example notebook for enrichment transform with bigtable (#30315)
6377cbe20cc is described below

commit 6377cbe20cc387d7986709014efc4cedfb2099f3
Author: Ritesh Ghorse <ri...@gmail.com>
AuthorDate: Wed Feb 14 18:50:47 2024 -0500

    [Docs] Add example notebook for enrichment transform with bigtable (#30315)
    
    * bigtable enrichment notebook
    
    * Update examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb
    
    Co-authored-by: Rebecca Szper <98...@users.noreply.github.com>
    
    * Update examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb
    
    Co-authored-by: Rebecca Szper <98...@users.noreply.github.com>
    
    * Update examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb
    
    Co-authored-by: Rebecca Szper <98...@users.noreply.github.com>
    
    * Update examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb
    
    Co-authored-by: Rebecca Szper <98...@users.noreply.github.com>
    
    * Update examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb
    
    Co-authored-by: Rebecca Szper <98...@users.noreply.github.com>
    
    * Update examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb
    
    Co-authored-by: Rebecca Szper <98...@users.noreply.github.com>
    
    * Update examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb
    
    Co-authored-by: Rebecca Szper <98...@users.noreply.github.com>
    
    * Update examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb
    
    Co-authored-by: Rebecca Szper <98...@users.noreply.github.com>
    
    * Update examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb
    
    Co-authored-by: Rebecca Szper <98...@users.noreply.github.com>
    
    * changes from the review
    
    * Apply suggestions from code review
    
    Co-authored-by: Rebecca Szper <98...@users.noreply.github.com>
    
    ---------
    
    Co-authored-by: Rebecca Szper <98...@users.noreply.github.com>
---
 .../beam-ml/bigtable_enrichment_transform.ipynb    | 854 +++++++++++++++++++++
 1 file changed, 854 insertions(+)

diff --git a/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb b/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb
new file mode 100644
index 00000000000..455b5cccd5d
--- /dev/null
+++ b/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb
@@ -0,0 +1,854 @@
+{
+  "cells": [
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "fFjof1NgAJwu",
+        "cellView": "form"
+      },
+      "outputs": [],
+      "source": [
+        "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
+        "\n",
+        "# Licensed to the Apache Software Foundation (ASF) under one\n",
+        "# or more contributor license agreements. See the NOTICE file\n",
+        "# distributed with this work for additional information\n",
+        "# regarding copyright ownership. The ASF licenses this file\n",
+        "# to you under the Apache License, Version 2.0 (the\n",
+        "# \"License\"); you may not use this file except in compliance\n",
+        "# with the License. You may obtain a copy of the License at\n",
+        "#\n",
+        "#   http://www.apache.org/licenses/LICENSE-2.0\n",
+        "#\n",
+        "# Unless required by applicable law or agreed to in writing,\n",
+        "# software distributed under the License is distributed on an\n",
+        "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
+        "# KIND, either express or implied. See the License for the\n",
+        "# specific language governing permissions and limitations\n",
+        "# under the License"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "A8xNRyZMW1yK"
+      },
+      "source": [
+        "# Use Apache Beam and Bigtable to enrich data\n",
+        "\n",
+        "<table align=\"left\">\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n",
+        "  </td>\n",
+        "  <td>\n",
+        "    <a target=\"_blank\" href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\" />View source on GitHub</a>\n",
+        "  </td>\n",
+        "</table>\n"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "HrCtxslBGK8Z"
+      },
+      "source": [
+        "This notebook shows how to enrich data by using the Apache Beam [enrichment transform](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment/) with [Bigtable](https://cloud.google.com/bigtable/docs/overview). The enrichment transform is a turnkey transform in Apache Beam that lets you enrich data by using a key-value lookup. This transform has the following features:\n",
+        "\n",
+        "- The transform has a built-in Apache Beam handler that interacts with Bigtable to get data to use in the enrichment.\n",
+        "- The enrichment transform uses client-side throttling to manage rate limiting the requests. The requests are exponentially backed off with a default retry strategy. You can configure rate limiting to suit your use case."
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "This notebook demonstrates the following ecommerce use case:\n",
+        "\n",
+        "A stream of online transaction from [Pub/Sub](https://cloud.google.com/pubsub/docs/guides) contains the following fields: `sale_id`, `product_id`, `customer_id`, `quantity`, and `price`. Additional customer demographic data is stored in a separate Bigtable cluster. The demographic data is used to enrich the event stream from Pub/Sub. Then, the enriched data is used to predict the next product to recommended to a customer."
+      ],
+      "metadata": {
+        "id": "ltn5zrBiGS9C"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "gVCtGOKTHMm4"
+      },
+      "source": [
+        "## Before you begin\n",
+        "Set up your environment and download dependencies."
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "YDHPlMjZRuY0"
+      },
+      "source": [
+        "### Install Apache Beam\n",
+        "To use the enrichment transform with the built-in Bigtable handler, install the Apache Beam SDK version 2.54.0 or later."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "jBakpNZnAhqk"
+      },
+      "outputs": [],
+      "source": [
+        "!pip install torch\n",
+        "!pip install apache_beam[interactive,gcp]==2.54.0 --quiet"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import datetime\n",
+        "import json\n",
+        "import math\n",
+        "\n",
+        "from typing import Any\n",
+        "from typing import Dict\n",
+        "\n",
+        "import torch\n",
+        "from google.cloud import pubsub_v1\n",
+        "from google.cloud.bigtable import Client\n",
+        "from google.cloud.bigtable import column_family\n",
+        "\n",
+        "import apache_beam as beam\n",
+        "import apache_beam.runners.interactive.interactive_beam as ib\n",
+        "from apache_beam.ml.inference.base import RunInference\n",
+        "from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor\n",
+        "from apache_beam.options import pipeline_options\n",
+        "from apache_beam.runners.interactive.interactive_runner import InteractiveRunner\n",
+        "from apache_beam.transforms.enrichment import Enrichment\n",
+        "from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler"
+      ],
+      "metadata": {
+        "id": "SiJii48A2Rnb"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "X80jy3FqHjK4"
+      },
+      "source": [
+        "### Authenticate with Google Cloud\n",
+        "This notebook reads data from Pub/Sub and Bigtable. To use your Google Cloud account, authenticate this notebook."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "Kz9sccyGBqz3"
+      },
+      "outputs": [],
+      "source": [
+        "from google.colab import auth\n",
+        "auth.authenticate_user()"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Replace `<PROJECT_ID>`, `<INSTANCE_ID>`, and `<TABLE_ID>` with the appropriate values for your setup. These fields are used with Bigtable."
+      ],
+      "metadata": {
+        "id": "nAmGgUMt48o9"
+      }
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 1,
+      "metadata": {
+        "id": "wEXucyi2liij"
+      },
+      "outputs": [],
+      "source": [
+        "PROJECT_ID = \"<PROJECT_ID>\"\n",
+        "INSTANCE_ID = \"<INSTANCE_ID>\"\n",
+        "TABLE_ID = \"<TABLE_ID>\""
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### Train the model\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "RpqZFfFfA_Dt"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Create sample data by using the format `[product_id, quantity, price, customer_id, customer_location, recommend_product_id]`."
+      ],
+      "metadata": {
+        "id": "8cUpV7mkB_xE"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "data = [\n",
+        "    [3, 5, 127, 9, 'China', 7], [1, 6, 167, 5, 'Peru', 4], [5, 4, 91, 2, 'USA', 8], [7, 2, 52, 1, 'India', 4], [1, 8, 118, 3, 'UK', 8], [4, 6, 132, 8, 'Mexico', 2],\n",
+        "    [6, 3, 154, 6, 'Brazil', 3], [4, 7, 163, 1, 'India', 7], [5, 2, 80, 4, 'Egypt', 9], [9, 4, 107, 7, 'Bangladesh', 1], [2, 9, 192, 8, 'Mexico', 4], [4, 5, 116, 5, 'Peru', 8],\n",
+        "    [8, 1, 195, 1, 'India', 7], [8, 6, 153, 5, 'Peru', 1], [5, 3, 120, 6, 'Brazil', 2], [2, 7, 187, 7, 'Bangladesh', 4], [1, 8, 103, 6, 'Brazil', 8], [2, 9, 181, 1, 'India', 8],\n",
+        "    [6, 5, 166, 3, 'UK', 5], [3, 4, 115, 8, 'Mexico', 1], [4, 7, 170, 4, 'Egypt', 2], [9, 3, 141, 7, 'Bangladesh', 3], [9, 3, 157, 1, 'India', 2], [7, 6, 128, 9, 'China', 1],\n",
+        "    [1, 8, 102, 3, 'UK', 4], [5, 2, 107, 4, 'Egypt', 6], [6, 5, 164, 8, 'Mexico', 9], [4, 7, 188, 5, 'Peru', 1], [8, 1, 184, 1, 'India', 2], [8, 6, 198, 2, 'USA', 5],\n",
+        "    [5, 3, 105, 6, 'Brazil', 7], [2, 7, 162, 7, 'Bangladesh', 7], [1, 8, 133, 9, 'China', 3], [2, 9, 173, 1, 'India', 7], [6, 5, 183, 5, 'Peru', 8], [3, 4, 191, 3, 'UK', 6],\n",
+        "    [4, 7, 123, 2, 'USA', 5], [9, 3, 159, 8, 'Mexico', 2], [9, 3, 146, 4, 'Egypt', 8], [7, 6, 194, 1, 'India', 8], [3, 5, 112, 6, 'Brazil', 1], [4, 6, 101, 7, 'Bangladesh', 2],\n",
+        "    [8, 1, 192, 4, 'Egypt', 4], [7, 2, 196, 5, 'Peru', 6], [9, 4, 124, 9, 'China', 7], [3, 4, 129, 5, 'Peru', 6], [6, 3, 151, 8, 'Mexico', 9], [5, 7, 114, 7, 'Bangladesh', 4],\n",
+        "    [4, 7, 175, 6, 'Brazil', 5], [1, 8, 121, 1, 'India', 2], [4, 6, 187, 2, 'USA', 5], [6, 5, 144, 9, 'China', 9], [9, 4, 103, 5, 'Peru', 3], [5, 3, 84, 3, 'UK', 1],\n",
+        "    [3, 5, 193, 2, 'USA', 4], [4, 7, 135, 1, 'India', 1], [7, 6, 148, 8, 'Mexico', 8], [1, 6, 160, 5, 'Peru', 7], [8, 6, 155, 6, 'Brazil', 9], [5, 7, 183, 7, 'Bangladesh', 2],\n",
+        "    [2, 9, 125, 4, 'Egypt', 4], [6, 3, 111, 9, 'China', 9], [5, 2, 132, 3, 'UK', 3], [4, 5, 104, 7, 'Bangladesh', 7], [2, 7, 177, 8, 'Mexico', 7]]"
+      ],
+      "metadata": {
+        "id": "TpxDHGObBEsj"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "countries_to_id = {'India': 1, 'USA': 2, 'UK': 3, 'Egypt': 4, 'Peru': 5,\n",
+        "                   'Brazil': 6, 'Bangladesh': 7, 'Mexico': 8, 'China': 9}"
+      ],
+      "metadata": {
+        "id": "bQt1cB4-CSBd"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Preprocess the data:\n",
+        "\n",
+        "1.   Convert the lists to tensors.\n",
+        "2.   Separate the features from the expected prediction."
+      ],
+      "metadata": {
+        "id": "Y0Duet4nCdN1"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "X = [torch.tensor(item[:4]+[countries_to_id[item[4]]], dtype=torch.float) for item in data]\n",
+        "Y = [torch.tensor(item[-1], dtype=torch.float) for item in data]"
+      ],
+      "metadata": {
+        "id": "7TT1O7sBCaZN"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Define a simple model that has five input features and predicts a single value."
+      ],
+      "metadata": {
+        "id": "q6wB_ZsXDjjd"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def build_model(n_inputs, n_outputs):\n",
+        "  \"\"\"build_model builds and returns a model that takes\n",
+        "  `n_inputs` features and predicts `n_outputs` value\"\"\"\n",
+        "  return torch.nn.Sequential(\n",
+        "      torch.nn.Linear(n_inputs, 8),\n",
+        "      torch.nn.ReLU(),\n",
+        "      torch.nn.Linear(8, 16),\n",
+        "      torch.nn.ReLU(),\n",
+        "      torch.nn.Linear(16, n_outputs))"
+      ],
+      "metadata": {
+        "id": "nphNfhUnESES"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Train the model."
+      ],
+      "metadata": {
+        "id": "_sBSzDllEmCz"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "model = build_model(n_inputs=5, n_outputs=1)\n",
+        "\n",
+        "loss_fn = torch.nn.MSELoss()\n",
+        "optimizer = torch.optim.Adam(model.parameters())\n",
+        "\n",
+        "for epoch in range(1000):\n",
+        "  print(f'Epoch {epoch}: ---')\n",
+        "  optimizer.zero_grad()\n",
+        "  for i in range(len(X)):\n",
+        "    pred = model(X[i])\n",
+        "    loss = loss_fn(pred, Y[i])\n",
+        "    loss.backward()\n",
+        "  optimizer.step()"
+      ],
+      "metadata": {
+        "id": "CaYrplaPDayp"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Save the model to the `STATE_DICT_PATH` variable."
+      ],
+      "metadata": {
+        "id": "_rJYv8fFFPYb"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "STATE_DICT_PATH = './model.pth'\n",
+        "\n",
+        "torch.save(model.state_dict(), STATE_DICT_PATH)"
+      ],
+      "metadata": {
+        "id": "W4t260o9FURP"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### Set up the Bigtable table\n",
+        "\n",
+        "Create a sample Bigtable table for this notebook."
+      ],
+      "metadata": {
+        "id": "ouMQZ4sC4zuO"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Connect to the Bigtable instance. If you don't have admin access, then drop `admin=True`.\n",
+        "client = Client(project=PROJECT_ID, admin=True)\n",
+        "instance = client.instance(INSTANCE_ID)\n",
+        "\n",
+        "# Create a column family.\n",
+        "column_family_id = 'demograph'\n",
+        "max_versions_rule = column_family.MaxVersionsGCRule(2)\n",
+        "column_families = {column_family_id: max_versions_rule}\n",
+        "\n",
+        "# Create a table.\n",
+        "table = instance.table(TABLE_ID)\n",
+        "\n",
+        "# You need admin access to use `.exists()`. If you don't have the admin access, then\n",
+        "# comment out the if-else block.\n",
+        "if not table.exists():\n",
+        "  table.create(column_families=column_families)\n",
+        "else:\n",
+        "  print(\"Table %s already exists in %s:%s\" % (TABLE_ID, PROJECT_ID, INSTANCE_ID))"
+      ],
+      "metadata": {
+        "id": "E7Y4ipuL5kFD"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Add rows to the table for the enrichment example."
+      ],
+      "metadata": {
+        "id": "eQLkSg3p7WAm"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Define column names for the table.\n",
+        "customer_id = 'customer_id'\n",
+        "customer_name = 'customer_name'\n",
+        "customer_location = 'customer_location'\n",
+        "\n",
+        "# The following data is sample data to insert into Bigtable.\n",
+        "customers = [\n",
+        "  {\n",
+        "    'customer_id': 1, 'customer_name': 'Sam', 'customer_location': 'India'\n",
+        "  },\n",
+        "  {\n",
+        "    'customer_id': 2, 'customer_name': 'John', 'customer_location': 'USA'\n",
+        "  },\n",
+        "  {\n",
+        "    'customer_id': 3, 'customer_name': 'Travis', 'customer_location': 'UK'\n",
+        "  },\n",
+        "]\n",
+        "\n",
+        "for customer in customers:\n",
+        "  row_key = str(customer[customer_id]).encode()\n",
+        "  row = table.direct_row(row_key)\n",
+        "  row.set_cell(\n",
+        "    column_family_id,\n",
+        "    customer_id.encode(),\n",
+        "    str(customer[customer_id]),\n",
+        "    timestamp=datetime.datetime.utcnow())\n",
+        "  row.set_cell(\n",
+        "    column_family_id,\n",
+        "    customer_name.encode(),\n",
+        "    customer[customer_name],\n",
+        "    timestamp=datetime.datetime.utcnow())\n",
+        "  row.set_cell(\n",
+        "    column_family_id,\n",
+        "    customer_location.encode(),\n",
+        "    customer[customer_location],\n",
+        "    timestamp=datetime.datetime.utcnow())\n",
+        "  row.commit()\n",
+        "  print('Inserted row for key: %s' % customer[customer_id])"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "LI6oYkZ97Vtu",
+        "outputId": "c72b28b5-8692-40f5-f8da-85622437d8f7"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Inserted row for key: 1\n",
+            "Inserted row for key: 2\n",
+            "Inserted row for key: 3\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "### Publish messages to Pub/Sub\n",
+        "\n",
+        "Use the Pub/Sub Python client to publish messages.\n"
+      ],
+      "metadata": {
+        "id": "pHODouJDwc60"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Replace <TOPIC_NAME> with the name of your Pub/Sub topic.\n",
+        "TOPIC = \"<TOPIC_NAME>\"\n",
+        "\n",
+        "# Replace <SUBSCRIPTION_PATH> with the subscription for your topic.\n",
+        "SUBSCRIPTION = \"<SUBSCRIPTION_PATH>\"\n"
+      ],
+      "metadata": {
+        "id": "QKCuwDioxw-f"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "messages = [\n",
+        "    {'sale_id': i, 'customer_id': i, 'product_id': i, 'quantity': i, 'price': i*100}\n",
+        "    for i in range(1,4)\n",
+        "  ]\n",
+        "\n",
+        "publisher = pubsub_v1.PublisherClient()\n",
+        "topic_name = publisher.topic_path(PROJECT_ID, TOPIC)\n",
+        "\n",
+        "for message in messages:\n",
+        "  data = json.dumps(message).encode('utf-8')\n",
+        "  publish_future = publisher.publish(topic_name, data)"
+      ],
+      "metadata": {
+        "id": "MaCJwaPexPKZ"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Use the Bigtable enrichment handler\n",
+        "\n",
+        "The [`BigTableEnrichmentHandler`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigtable.html#apache_beam.transforms.enrichment_handlers.bigtable.BigTableEnrichmentHandler) is a built-in handler included in the Apache Beam SDK versions 2.54.0 and later."
+      ],
+      "metadata": {
+        "id": "zPSFEMm02omi"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "To establish a client for the Bigtable enrichment handler, replace `<PROJECT_ID>`, `<INSTANCE_ID>`, and `<TABLE_ID>` with the appropriate values for those fields. The `row_key` variable is the field name from the input row that contains the row key to use when querying Bigtable.\n",
+        "\n",
+        "To convert a `string` type to a `byte` type or a `byte` type to a `string` type from Bigtable, you can configure additional options, such as [`app_profile_id`](https://cloud.google.com/bigtable/docs/app-profiles), [`row_filter`](https://cloud.google.com/python/docs/reference/bigtable/latest/row-filters), and [`encoding`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigtable.html#apache_beam.transforms.enrichment_handlers.bigtable.BigT [...]
+        "\n",
+        "The default `encoding` type is `utf-8`.\n",
+        "\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "K41xhvmA5yQk"
+      }
+    },
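+    {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "For illustration, the following sketch shows how those optional parameters might be passed to the handler. The `handler_with_options` name, the filter, and the app profile value are hypothetical and aren't required for this notebook; the parameter names follow the handler documentation linked above.\n",
+        "\n",
+        "```\n",
+        "from google.cloud.bigtable.row_filters import CellsColumnLimitFilter\n",
+        "\n",
+        "# Illustrative only: keep the latest cell for each column.\n",
+        "latest_cell_filter = CellsColumnLimitFilter(1)\n",
+        "\n",
+        "handler_with_options = BigTableEnrichmentHandler(\n",
+        "    project_id=PROJECT_ID,\n",
+        "    instance_id=INSTANCE_ID,\n",
+        "    table_id=TABLE_ID,\n",
+        "    row_key='customer_id',\n",
+        "    app_profile_id='my-app-profile',  # hypothetical app profile name\n",
+        "    row_filter=latest_cell_filter,\n",
+        "    encoding='utf-8')  # utf-8 is the default\n",
+        "```"
+      ]
+    },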
+    {
+      "cell_type": "code",
+      "source": [
+        "row_key = 'customer_id'"
+      ],
+      "metadata": {
+        "id": "3dB26jhI45gd"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "bigtable_handler = BigTableEnrichmentHandler(project_id=PROJECT_ID,\n",
+        "                                             instance_id=INSTANCE_ID,\n",
+        "                                             table_id=TABLE_ID,\n",
+        "                                             row_key=row_key)"
+      ],
+      "metadata": {
+        "id": "cr1j_DHK4gA4"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "The `BigTableEnrichmentHandler` returns the latest value from the table without its associated timestamp for the `row_key` that you provide. If you want to fetch the `timestamp` associated with the `row_key` value, then pass `include_timestamp=True` to the handler.\n",
+        "\n",
+        "**Note:** When exceptions occur, by default, the logging severity is set to warning ([`ExceptionLevel.WARN`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigtable.html#apache_beam.transforms.enrichment_handlers.bigtable.ExceptionLevel.WARN)).  To configure the severity to raise exceptions, set `exception_level` to [`ExceptionLevel.RAISE`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigtabl [...]
+      ],
+      "metadata": {
+        "id": "yFMcaf8i7TbI"
+      }
+    },
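+    {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "For example, a handler configured with both options might look like the following sketch. The `strict_handler` name is illustrative, and the sketch assumes that `ExceptionLevel` is importable from the Bigtable handler module, as the links above suggest; both parameters are optional.\n",
+        "\n",
+        "```\n",
+        "from apache_beam.transforms.enrichment_handlers.bigtable import ExceptionLevel\n",
+        "\n",
+        "strict_handler = BigTableEnrichmentHandler(\n",
+        "    project_id=PROJECT_ID,\n",
+        "    instance_id=INSTANCE_ID,\n",
+        "    table_id=TABLE_ID,\n",
+        "    row_key=row_key,\n",
+        "    include_timestamp=True,  # also return the cell timestamps\n",
+        "    exception_level=ExceptionLevel.RAISE)  # raise instead of logging a warning\n",
+        "```"
+      ]
+    },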
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Use the enrichment transform\n",
+        "\n",
+        "To use the [enrichment transform](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment.html#apache_beam.transforms.enrichment.Enrichment), the [`EnrichmentHandler`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment.html#apache_beam.transforms.enrichment.EnrichmentSourceHandler) parameter is required. You can also use a configuration parameter to specify a `lambda` for a join function, a timeout, a throttler, and a repeat [...]
+        "\n",
+        "\n",
+        "*   `join_fn`: A lambda function that takes dictionaries as input and returns an enriched row (`Callable[[Dict[str, Any], Dict[str, Any]], beam.Row]`). The enriched row specifies how to join the data fetched from the API. Defaults to a [cross-join](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment.html#apache_beam.transforms.enrichment.cross_join).\n",
+        "*   `timeout`: The number of seconds to wait for the request to be completed by the API before timing out. Defaults to 30 seconds.\n",
+        "*   `throttler`: Specifies the throttling mechanism. The only supported option is default client-side adaptive throttling.\n",
+        "*   `repeater`: Specifies the retry strategy when errors like `TooManyRequests` and `TimeoutException` occur. Defaults to [`ExponentialBackOffRepeater`](https://beam.apache.org/releases/pydoc/current/apache_beam.io.requestresponse.html#apache_beam.io.requestresponse.ExponentialBackOffRepeater).\n"
+      ],
+      "metadata": {
+        "id": "-Lvo8O2V-0Ey"
+      }
+    },
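+    {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "As a sketch, these parameters could be passed explicitly as shown below. The `input_rows` name and the lambda join are illustrative only, and the repeater import path follows the link above; this notebook instead uses a custom join function that is defined later.\n",
+        "\n",
+        "```\n",
+        "from apache_beam.io.requestresponse import ExponentialBackOffRepeater\n",
+        "\n",
+        "# input_rows stands for an upstream PCollection of beam.Row elements.\n",
+        "enriched = (input_rows\n",
+        "            | 'Enrich' >> Enrichment(\n",
+        "                bigtable_handler,\n",
+        "                join_fn=lambda left, right: beam.Row(**left, **right),  # simple merge, for illustration\n",
+        "                timeout=30,  # seconds (the default)\n",
+        "                repeater=ExponentialBackOffRepeater()))\n",
+        "```"
+      ]
+    },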
+    {
+      "cell_type": "markdown",
+      "source": [
+        "The following example demonstrates the code needed to add this transform to your pipeline.\n",
+        "\n",
+        "\n",
+        "```\n",
+        "with beam.Pipeline() as p:\n",
+        "  output = (p\n",
+        "            ...\n",
+        "            | \"Enrich with BigTable\" >> Enrichment(bigtable_handler, timeout=10)\n",
+        "            | \"RunInference\" >> RunInference(model_handler)\n",
+        "            ...\n",
+        "            )\n",
+        "```\n",
+        "\n",
+        "\n",
+        "\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "xJTCfSmiV1kv"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "To make a prediction, use the following fields: `product_id`, `quantity`, `price`, `customer_id`, and `customer_location`. Retrieve the value of the `customer_location` field from Bigtable.\n",
+        "\n",
+        "Because the enrichment transform performs a [`cross_join`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment.html#apache_beam.transforms.enrichment.cross_join) by default, design the custom join to enrich the input data. This design ensures that the join includes only the specified fields."
+      ],
+      "metadata": {
+        "id": "F-xjiP_pHWZr"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def custom_join(left: Dict[str, Any], right: Dict[str, Any]):\n",
+        "  enriched = {}\n",
+        "  enriched['product_id'] = left['product_id']\n",
+        "  enriched['quantity'] = left['quantity']\n",
+        "  enriched['price'] = left['price']\n",
+        "  enriched['customer_id'] = left['customer_id']\n",
+        "  enriched['customer_location'] = right['demograph']['customer_location']\n",
+        "  return beam.Row(**enriched)"
+      ],
+      "metadata": {
+        "id": "8LnCnEPNIPtg"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Use the `PyTorchModelHandlerTensor` interface to run inference\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "CX9Cqybu6scV"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Because the enrichment transform outputs data in the format `beam.Row`, to make it compatible with the [`PyTorchModelHandlerTensor`](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.pytorch_inference.html#apache_beam.ml.inference.pytorch_inference.PytorchModelHandlerTensor) interface,  convert it to `torch.tensor`. Additionally, the enriched field `customer_location` is a `string` type, but the model requires a `float` type. Convert the `customer_location` [...]
+      ],
+      "metadata": {
+        "id": "zy5Jl7_gLklX"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def convert_row_to_tensor(element: beam.Row):\n",
+        "  row_dict = element._asdict()\n",
+        "  row_dict['customer_location'] = countries_to_id[row_dict['customer_location']]\n",
+        "  return torch.tensor(list(row_dict.values()), dtype=torch.float)"
+      ],
+      "metadata": {
+        "id": "KBKoB06nL4LF"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Initialize the model handler with the preprocessing function."
+      ],
+      "metadata": {
+        "id": "-tGHyB_vL3rJ"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "model_handler = PytorchModelHandlerTensor(state_dict_path=STATE_DICT_PATH,\n",
+        "                                          model_class=build_model,\n",
+        "                                          model_params={'n_inputs':5, 'n_outputs':1}\n",
+        "                                          ).with_preprocess_fn(convert_row_to_tensor)"
+      ],
+      "metadata": {
+        "id": "VqUUEwcU-r2e"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Define a `DoFn` to format the output."
+      ],
+      "metadata": {
+        "id": "vNHI4gVgNec2"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "class PostProcessor(beam.DoFn):\n",
+        "  def process(self, element, *args, **kwargs):\n",
+        "    print('Customer %d who bought product %d is recommended to buy product %d' % (element.example[3], element.example[0], math.ceil(element.inference[0])))"
+      ],
+      "metadata": {
+        "id": "rkN-_Yf4Nlwy"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "0a1zerXycQ0z"
+      },
+      "source": [
+        "## Run the pipeline\n"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Configure the pipeline to run in streaming mode."
+      ],
+      "metadata": {
+        "id": "WrwY0_gV_IDK"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "options = pipeline_options.PipelineOptions()\n",
+        "options.view_as(pipeline_options.StandardOptions).streaming = True # Streaming mode is set True"
+      ],
+      "metadata": {
+        "id": "t0425sYBsYtB"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Pub/Sub sends the data in bytes. Convert the data to `beam.Row` objects by using a `DoFn`."
+      ],
+      "metadata": {
+        "id": "DBNijQDY_dRe"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "class DecodeBytes(beam.DoFn):\n",
+        "  \"\"\"\n",
+        "  The DecodeBytes `DoFn` converts the data read from Pub/Sub to `beam.Row`.\n",
+        "  First, decode the encoded string. Convert the output to\n",
+        "  a `dict` with `json.loads()`, which is used to create a `beam.Row`.\n",
+        "  \"\"\"\n",
+        "  def process(self, element, *args, **kwargs):\n",
+        "    element_dict = json.loads(element.decode('utf-8'))\n",
+        "    yield beam.Row(**element_dict)"
+      ],
+      "metadata": {
+        "id": "sRw9iL8pKP5O"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Use the following code to run the pipeline.\n",
+        "\n",
+        "**Note:** Because this pipeline is a streaming pipeline, you need to manually stop the cell. If you don't stop the cell, the pipeline continues to run."
+      ],
+      "metadata": {
+        "id": "xofUJym-_GuB"
+      }
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "St07XoibcQSb",
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 1000
+        },
+        "outputId": "34e0a603-fb77-455c-e40b-d15b672edeb2"
+      },
+      "outputs": [
+        {
+          "output_type": "display_data",
+          "data": {
+            "application/javascript": [
+              "\n",
+              "        if (typeof window.interactive_beam_jquery == 'undefined') {\n",
+              "          var jqueryScript = document.createElement('script');\n",
+              "          jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n",
+              "          jqueryScript.type = 'text/javascript';\n",
+              "          jqueryScript.onload = function() {\n",
+              "            var datatableScript = document.createElement('script');\n",
+              "            datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n",
+              "            datatableScript.type = 'text/javascript';\n",
+              "            datatableScript.onload = function() {\n",
+              "              window.interactive_beam_jquery = jQuery.noConflict(true);\n",
+              "              window.interactive_beam_jquery(document).ready(function($){\n",
+              "                \n",
+              "              });\n",
+              "            }\n",
+              "            document.head.appendChild(datatableScript);\n",
+              "          };\n",
+              "          document.head.appendChild(jqueryScript);\n",
+              "        } else {\n",
+              "          window.interactive_beam_jquery(document).ready(function($){\n",
+              "            \n",
+              "          });\n",
+              "        }"
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Customer 1 who bought product 1 is recommended to buy product 3\n",
+            "Customer 2 who bought product 2 is recommended to buy product 5\n",
+            "Customer 3 who bought product 3 is recommended to buy product 7\n"
+          ]
+        }
+      ],
+      "source": [
+        "with beam.Pipeline(options=options) as p:\n",
+        "  _ = (p\n",
+        "       | \"Read from Pub/Sub\" >> beam.io.ReadFromPubSub(subscription=SUBSCRIPTION)\n",
+        "       | \"ConvertToRow\" >> beam.ParDo(DecodeBytes())\n",
+        "       | \"Enrichment\" >> Enrichment(bigtable_handler, join_fn=custom_join)\n",
+        "       | \"RunInference\" >> RunInference(model_handler)\n",
+        "       | \"Format Output\" >> beam.ParDo(PostProcessor())\n",
+        "       )\n"
+      ]
+    }
+  ],
+  "metadata": {
+    "colab": {
+      "provenance": [],
+      "toc_visible": true
+    },
+    "kernelspec": {
+      "display_name": "Python 3",
+      "name": "python3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 0
+}
\ No newline at end of file