Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2022/09/06 15:12:23 UTC

[GitHub] [beam] liferoad commented on a diff in pull request #23035: Add one NER example to use a spaCy model with RunInference

liferoad commented on code in PR #23035:
URL: https://github.com/apache/beam/pull/23035#discussion_r963833358


##########
examples/notebooks/beam-ml/run_custom_inference.ipynb:
##########
@@ -0,0 +1,575 @@
+{
+  "cells": [
+    {
+      "cell_type": "code",
+      "execution_count": 1,
+      "id": "C1rAsD2L-hSO",
+      "metadata": {
+        "cellView": "form",
+        "id": "C1rAsD2L-hSO"
+      },
+      "outputs": [],
+      "source": [
+        "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
+        "\n",
+        "# Licensed to the Apache Software Foundation (ASF) under one\n",
+        "# or more contributor license agreements. See the NOTICE file\n",
+        "# distributed with this work for additional information\n",
+        "# regarding copyright ownership. The ASF licenses this file\n",
+        "# to you under the Apache License, Version 2.0 (the\n",
+        "# \"License\"); you may not use this file except in compliance\n",
+        "# with the License. You may obtain a copy of the License at\n",
+        "#\n",
+        "#   http://www.apache.org/licenses/LICENSE-2.0\n",
+        "#\n",
+        "# Unless required by applicable law or agreed to in writing,\n",
+        "# software distributed under the License is distributed on an\n",
+        "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
+        "# KIND, either express or implied. See the License for the\n",
+        "# specific language governing permissions and limitations\n",
+        "# under the License.\n"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "id": "b6f8f3af-744e-4eaa-8a30-6d03e8e4d21e",
+      "metadata": {
+        "id": "b6f8f3af-744e-4eaa-8a30-6d03e8e4d21e"
+      },
+      "source": [
+        "# Bring your own Machine Learning (ML) model to Beam RunInference\n",
+        "\n",
+        "<button>\n",
+        "  <a href=\"https://beam.apache.org/documentation/sdks/python-machine-learning/\">\n",
+        "    <img src=\"https://beam.apache.org/images/favicon.ico\" alt=\"Open the docs\" height=\"16\"/>\n",
+        "    Beam RunInference\n",
+        "  </a>\n",
+        "</button>\n",
+        "\n",
+        "In this notebook, we walk through a simple example to show how to customize your own ML model handler using\n",
+        "[ModelHandler](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.ModelHandler).\n",
+        "\n",
+        "Named-Entity Recognition (NER) is one of the most common tasks in Natural Language Processing (NLP).\n",
+        "It locates named entities in unstructured text and classifies them into predefined labels such as person name, organization, and date.\n",
+        "In this example, we illustrate how to use the popular spaCy package to load an ML model and apply it inside a Beam pipeline.\n"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "id": "299af9bb-b2fc-405c-96e7-ee0a6ae24bdd",
+      "metadata": {
+        "id": "299af9bb-b2fc-405c-96e7-ee0a6ae24bdd"
+      },
+      "source": [
+        "## Package Dependencies\n",
+        "\n",
+        "The RunInference library is available in Apache Beam version <b>2.40</b> or later.\n",
+        "\n",
+        "`spaCy` and `pandas` need to be installed. A small NER model (`en_core_web_sm`) is also installed here, but any valid spaCy model can be used."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 2,
+      "id": "7f841596-f217-46d2-b64e-1952db4de4cb",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "7f841596-f217-46d2-b64e-1952db4de4cb",
+        "outputId": "da04ccb9-0801-47f6-ec9e-e87f0ca4569f"
+      },
+      "outputs": [],
+      "source": [
+        "# uncomment these to install the required packages\n",
+        "# %pip install spacy pandas\n",
+        "# %pip install \"apache-beam[gcp, dataframe, interactive]\"\n",
+        "# !python -m spacy download en_core_web_sm"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "## Let us play with spaCy first"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 3,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# create a spaCy language\n",
+        "\n",
+        "import spacy\n",
+        "\n",
+        "nlp = spacy.load(\"en_core_web_sm\")\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 4,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# some text strings for fun\n",
+        "text_strings = [\n",
+        "    \"The New York Times is an American daily newspaper based in New York City with a worldwide readership.\",\n",
+        "    \"It was founded in 1851 by Henry Jarvis Raymond and George Jones, and was initially published by Raymond, Jones & Company.\"\n",
+        "]\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 5,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# check what entities spaCy can recognize\n",
+        "doc = nlp(text_strings[0])\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 6,
+      "metadata": {},
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "The New York Times 0 18 ORG\n",
+            "American 25 33 NORP\n",
+            "daily 34 39 DATE\n",
+            "New York City 59 72 GPE\n"
+          ]
+        }
+      ],
+      "source": [
+        "for ent in doc.ents:\n",
+        "    print(ent.text, ent.start_char, ent.end_char, ent.label_)\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 7,
+      "metadata": {},
+      "outputs": [
+        {
+          "data": {
+            "text/html": [
+              "<span class=\"tex2jax_ignore\"><div class=\"entities\" style=\"line-height: 2.5; direction: ltr\">\n",
+              "<mark class=\"entity\" style=\"background: #7aecec; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
+              "    The New York Times\n",
+              "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">ORG</span>\n",
+              "</mark>\n",
+              " is an \n",
+              "<mark class=\"entity\" style=\"background: #c887fb; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
+              "    American\n",
+              "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">NORP</span>\n",
+              "</mark>\n",
+              " \n",
+              "<mark class=\"entity\" style=\"background: #bfe1d9; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
+              "    daily\n",
+              "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">DATE</span>\n",
+              "</mark>\n",
+              " newspaper based in \n",
+              "<mark class=\"entity\" style=\"background: #feca74; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
+              "    New York City\n",
+              "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">GPE</span>\n",
+              "</mark>\n",
+              " with a worldwide readership.</div></span>"
+            ],
+            "text/plain": [
+              "<IPython.core.display.HTML object>"
+            ]
+          },
+          "metadata": {},
+          "output_type": "display_data"
+        }
+      ],
+      "source": [
+        "# visualize the results\n",
+        "from spacy import displacy\n",
+        "displacy.render(doc, style=\"ent\")\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 8,
+      "metadata": {},
+      "outputs": [
+        {
+          "data": {
+            "text/html": [
+              "<span class=\"tex2jax_ignore\"><div class=\"entities\" style=\"line-height: 2.5; direction: ltr\">It was founded in \n",
+              "<mark class=\"entity\" style=\"background: #bfe1d9; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
+              "    1851\n",
+              "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">DATE</span>\n",
+              "</mark>\n",
+              " by \n",
+              "<mark class=\"entity\" style=\"background: #aa9cfc; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
+              "    Henry Jarvis\n",
+              "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">PERSON</span>\n",
+              "</mark>\n",
+              " \n",
+              "<mark class=\"entity\" style=\"background: #aa9cfc; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
+              "    Raymond\n",
+              "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">PERSON</span>\n",
+              "</mark>\n",
+              " and \n",
+              "<mark class=\"entity\" style=\"background: #aa9cfc; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
+              "    George Jones\n",
+              "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">PERSON</span>\n",
+              "</mark>\n",
+              ", and was initially published by \n",
+              "<mark class=\"entity\" style=\"background: #7aecec; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;\">\n",
+              "    Raymond, Jones &amp; Company\n",
+              "    <span style=\"font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem\">ORG</span>\n",
+              "</mark>\n",
+              ".</div></span>"
+            ],
+            "text/plain": [
+              "<IPython.core.display.HTML object>"
+            ]
+          },
+          "metadata": {},
+          "output_type": "display_data"
+        }
+      ],
+      "source": [
+        "# another example\n",
+        "displacy.render(nlp(text_strings[1]), style=\"ent\")"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "## Now time to create our own `ModelHandler` to use spaCy for inference"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 9,
+      "metadata": {},
+      "outputs": [
+        {
+          "data": {
+            "application/javascript": "\n        if (typeof window.interactive_beam_jquery == 'undefined') {\n          var jqueryScript = document.createElement('script');\n          jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n          jqueryScript.type = 'text/javascript';\n          jqueryScript.onload = function() {\n            var datatableScript = document.createElement('script');\n            datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n            datatableScript.type = 'text/javascript';\n            datatableScript.onload = function() {\n              window.interactive_beam_jquery = jQuery.noConflict(true);\n              window.interactive_beam_jquery(document).ready(function($){\n                \n              });\n            }\n            document.head.appendChild(datatableScript);\n          };\n          document.head.appendChild(jqueryScript);\n        } else {\n          window.interactive_beam_jquery(document).ready(function($){\n            \n          });\n        }"
+          },
+          "metadata": {},
+          "output_type": "display_data"
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "The New York Times is an American daily newspaper based in New York City with a worldwide readership.\n",
+            "It was founded in 1851 by Henry Jarvis Raymond and George Jones, and was initially published by Raymond, Jones & Company.\n"
+          ]
+        }
+      ],
+      "source": [
+        "\n",
+        "import apache_beam as beam\n",
+        "from apache_beam.options.pipeline_options import PipelineOptions\n",
+        "\n",
+        "import warnings\n",
+        "warnings.filterwarnings(\"ignore\")\n",
+        "\n",
+        "\n",
+        "pipeline = beam.Pipeline()\n",
+        "\n",
+        "# only print the results to check\n",
+        "with pipeline as p:\n",
+        "    (p \n",
+        "    | \"CreateSentences\" >> beam.Create(text_strings)\n",
+        "    | beam.Map(print)\n",
+        "    )\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 10,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# Now define SpacyModelHandler to load the model and perform the inference\n",
+        "\n",
+        "from apache_beam.ml.inference.base import RunInference\n",
+        "from apache_beam.ml.inference.base import ModelHandler\n",
+        "from apache_beam.ml.inference.base import PredictionResult\n",
+        "from spacy import Language\n",
+        "from typing import Any\n",
+        "from typing import Dict\n",
+        "from typing import Iterable\n",
+        "from typing import Optional\n",
+        "from typing import Sequence\n",
+        "\n",
+        "class SpacyModelHandler(ModelHandler[str,\n",
+        "                                     PredictionResult,\n",
+        "                                     Language]):\n",
+        "    def __init__(\n",
+        "        self,\n",
+        "        model_name: str = \"en_core_web_sm\",\n",
+        "    ):\n",
+        "        \"\"\" Implementation of the ModelHandler interface for spaCy using text as input.\n",
+        "\n",
+        "        Example Usage::\n",
+        "\n",
+        "          pcoll | RunInference(SpacyModelHandler())\n",
+        "\n",
+        "        Args:\n",
+        "          model_name: The spaCy model name. Default is en_core_web_sm.\n",
+        "        \"\"\"\n",
+        "        self._model_name = model_name\n",
+        "\n",
+        "    def load_model(self) -> Language:\n",
+        "        \"\"\"Loads and initializes a model for processing.\"\"\"\n",
+        "        return spacy.load(self._model_name)\n",
+        "\n",
+        "    def run_inference(\n",
+        "        self,\n",
+        "        batch: Sequence[str],\n",
+        "        model: Language,\n",
+        "        inference_args: Optional[Dict[str, Any]] = None\n",
+        "    ) -> Iterable[PredictionResult]:\n",
+        "        \"\"\"Runs inferences on a batch of text strings.\n",
+        "\n",
+        "        Args:\n",
+        "          batch: A sequence of examples as text strings. \n",
+        "          model: A spaCy language model\n",
+        "          inference_args: Any additional arguments for an inference.\n",
+        "\n",
+        "        Returns:\n",
+        "          An Iterable of type PredictionResult.\n",
+        "        \"\"\"\n",
+        "        # loop over each text string and store every recognized entity as a tuple\n",
+        "        predictions = []\n",
+        "        for one_text in batch:\n",
+        "            doc = model(one_text)\n",
+        "            predictions.append(\n",
+        "                [(ent.text, ent.start_char, ent.end_char, ent.label_) for ent in doc.ents])\n",
+        "        return [PredictionResult(x, y) for x, y in zip(batch, predictions)]\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 11,
+      "metadata": {},
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "PredictionResult(example='The New York Times is an American daily newspaper based in New York City with a worldwide readership.', inference=[('The New York Times', 0, 18, 'ORG'), ('American', 25, 33, 'NORP'), ('daily', 34, 39, 'DATE'), ('New York City', 59, 72, 'GPE')])\n",
+            "PredictionResult(example='It was founded in 1851 by Henry Jarvis Raymond and George Jones, and was initially published by Raymond, Jones & Company.', inference=[('1851', 18, 22, 'DATE'), ('Henry Jarvis', 26, 38, 'PERSON'), ('Raymond', 39, 46, 'PERSON'), ('George Jones', 51, 63, 'PERSON'), ('Raymond, Jones & Company', 96, 120, 'ORG')])\n"
+          ]
+        }
+      ],
+      "source": [
+        "# quick check to show the inference results are correct\n",
+        "# (a fresh pipeline is used so the earlier print-only branch is not rerun)\n",
+        "with beam.Pipeline() as p:\n",
+        "    (p \n",
+        "    | \"CreateSentences\" >> beam.Create(text_strings)\n",
+        "    | \"RunInferenceSpacy\" >> RunInference(SpacyModelHandler(\"en_core_web_sm\"))\n",
+        "    | beam.Map(print)\n",
+        "    )\n"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "## Use `KeyedModelHandler` to handle keyed data"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 12,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "# some text strings with keys, which are useful to distinguish examples\n",
+        "text_strings_with_keys = [\n",
+        "    (\"example_0\", \"The New York Times is an American daily newspaper based in New York City with a worldwide readership.\"),\n",
+        "    (\"example_1\", \"It was founded in 1851 by Henry Jarvis Raymond and George Jones, and was initially published by Raymond, Jones & Company.\")\n",
+        "]\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 13,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "from apache_beam.runners.interactive.interactive_runner import InteractiveRunner\n",
+        "from apache_beam.ml.inference.base import KeyedModelHandler\n",
+        "from apache_beam.dataframe.convert import to_dataframe\n",
+        "\n",
+        "pipeline = beam.Pipeline(InteractiveRunner())\n",
+        "\n",
+        "keyed_spacy_model_handler = KeyedModelHandler(SpacyModelHandler(\"en_core_web_sm\"))\n",
+        "\n",
+        "# quick check to show the inference results are correct\n",
+        "with pipeline as p:\n",
+        "    results = (p \n",
+        "    | \"CreateSentences\" >> beam.Create(text_strings_with_keys)\n",
+        "    | \"RunInferenceSpacy\" >> RunInference(keyed_spacy_model_handler)\n",
+        "    # Map to Row objects to generate a schema suitable for conversion\n",
+        "    # to a dataframe.\n",
+        "    | 'ToRows' >> beam.Map(lambda row: beam.Row(key=row[0], text=row[1][0], predictions=row[1][1]))\n",

Review Comment:
   The idea is to show another way, besides `print`, to quickly debug the pipeline: converting the results to a DataFrame. That is why this part is written differently.
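
   As a rough illustration of what the `ToRows` step does, here is a minimal sketch that runs without Beam. It rests on assumptions: `PredictionResult` and `Row` below are local namedtuple stand-ins (not the Beam classes), `to_row` is a hypothetical helper name, and the sample entity is written by hand rather than produced by spaCy. Flattening each keyed result this way gives rows whose fields map directly onto DataFrame columns.

```python
from collections import namedtuple

# Local stand-ins (assumptions for illustration only). The notebook itself uses
# apache_beam.ml.inference.base.PredictionResult and beam.Row.
PredictionResult = namedtuple("PredictionResult", ["example", "inference"])
Row = namedtuple("Row", ["key", "text", "predictions"])


def to_row(keyed_result):
    """Flatten one (key, PredictionResult) pair into a flat row."""
    key, result = keyed_result
    return Row(key=key, text=result.example, predictions=result.inference)


# Hand-written sample shaped like the output of a keyed model handler.
keyed_results = [
    (
        "example_0",
        PredictionResult(
            "The New York Times is an American daily newspaper.",
            [("The New York Times", 0, 18, "ORG")],
        ),
    ),
]

rows = [to_row(kr) for kr in keyed_results]
for row in rows:
    print(row.key, row.predictions)
```

   Named fields (`key`, `text`, `predictions`) are what make the tabular view possible; a bare tuple would carry no column names.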


