Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2023/01/14 23:12:11 UTC

[GitHub] [beam] chamikaramj commented on a diff in pull request #24656: Cross Language RunInference

chamikaramj commented on code in PR #24656:
URL: https://github.com/apache/beam/pull/24656#discussion_r1070270963


##########
website/www/site/content/en/documentation/ml/multi-language-inference.md:
##########
@@ -0,0 +1,159 @@
+---
+title: "Cross Language RunInference"
+---
+<!--
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+# Cross Language RunInference
+
+This Cross Language RunInference example shows how to use the [RunInference](https://beam.apache.org/documentation/ml/overview/#runinference)
+transform in a multi-language pipeline. The pipeline is written in Java and reads the input data from
+GCS. With the help of a [PythonExternalTransform](https://beam.apache.org/documentation/programming-guide/#1312-creating-cross-language-python-transforms),
+it calls a composite Python transform that does the preprocessing, inference, and postprocessing.
+Lastly, the Java pipeline writes the data back to GCS.
+
+## NLP model and dataset
+A `bert-base-uncased` model is used to run inference. This open-source model is
+available on [HuggingFace](https://huggingface.co/bert-base-uncased). The BERT model is
+used to predict the last word of a sentence based on the context of the sentence.
+
+We also use the [IMDB movie reviews](https://www.kaggle.com/datasets/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews?select=IMDB+Dataset.csv) dataset, an open-source dataset available on Kaggle. A sample of the data after preprocessing is shown below:
+
+| **Text** 	|   **Last Word** 	|
+|---	|:---	|
+|<img width=700/>|<img width=100/>|
+| One of the other reviewers has mentioned that after watching just 1 Oz episode you'll be [MASK] 	| hooked 	|
+| A wonderful little [MASK] 	| production 	|
+| So im not a big fan of Boll's work but then again not many [MASK] 	| are 	|
+| This a fantastic movie of three prisoners who become [MASK] 	| famous 	|
+| Some films just simply should not be [MASK] 	| remade 	|
+| The Karen Carpenter Story shows a little more about singer Karen Carpenter's complex [MASK] 	| life 	|
+
+The full code used in this example can be found on GitHub [here](https://github.com/apache/beam/tree/master/sdks/python/apache_beam/examples/inference/multi_language_inference).
+
+
+## Multi-language RunInference pipeline
+### Cross-Language Python transform
+Besides running inference, the pipeline also needs to preprocess and postprocess the data so that its output is clean and easy to interpret. To perform these three tasks, we write a single composite custom PTransform, with a separate DoFn or PTransform for each task, as shown below:
+
+```python
+def expand(self, pcoll):
+    return (
+        pcoll
+        | 'Preprocess' >> beam.ParDo(self.Preprocess(self._tokenizer))
+        | 'Inference' >> RunInference(KeyedModelHandler(self._model_handler))
+        | 'Postprocess' >> beam.ParDo(self.Postprocess(
+            self._tokenizer)).with_input_types(typing.Iterable[str])
+    )
+```
+
+First, the data is preprocessed: the raw text is cleaned and tokenized for the BERT model. All of these steps run in the `Preprocess` DoFn. It takes a single element as input and returns a list with a single tuple: the cleaned sentence and its tokenized, masked version.
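The following is a minimal, dependency-free sketch of the cleaning and masking steps described above. It mirrors the `Preprocess` DoFn but stops before tokenization, so it runs without `transformers`. The function name `clean_and_mask` is hypothetical.

```python
from typing import Optional, Tuple

# Same token list and sentence delimiters as the Preprocess DoFn in this PR.
REMOVABLE_TOKENS = ['"', '*', '<br />', "'", '(', ')']
ENDING_CHARS = ['.', '!', '?']


def clean_and_mask(text: str) -> Optional[Tuple[str, str]]:
  """Hypothetical helper: returns (cleaned_sentence, masked_sentence) or None."""
  # Remove quote marks, emphasis and HTML line breaks.
  for token in REMOVABLE_TOKENS:
    text = text.replace(token, '')
  # Keep only the first sentence.
  for char in ENDING_CHARS:
    if char in text:
      text = text.split(char)[0]
  # Re-append the period as a separate token.
  text = text + ' .'
  stripped = text.strip()
  # Drop overly long sentences and the CSV header row.
  if not 0 < len(stripped) < 512 or 'review,sentiment' in stripped:
    return None
  words = text.split()
  # Replace the last real word (just before the '.') with [MASK].
  masked = ' '.join(words[:-2] + ['[MASK]', words[-1]])
  return text, masked


print(clean_and_mask('A wonderful little production. <br /><br />The filming technique'))
# ('A wonderful little production .', 'A wonderful little [MASK] .')
```

In the actual DoFn, the masked sentence is then passed to `BertTokenizer`, and the resulting tensors are squeezed.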
+
+The preprocessed data is then passed to the [`RunInference`](https://beam.apache.org/documentation/ml/overview/#runinference) PTransform, which is already available in the Apache Beam SDK. `RunInference` requires one parameter: a model handler. This example uses a `KeyedModelHandler` because the `Preprocess` DoFn also outputs the sentence itself, which serves as the key. This is a design choice that you can adapt to your own needs. The model handler is defined in the initialization function of the composite PTransform, shown below:
+
+```python
+def __init__(self, model, model_path):
+    self._model = model
+    logging.info(f"Downloading {self._model} model from GCS.")
+    self._model_config = BertConfig.from_pretrained(self._model)
+    self._tokenizer = BertTokenizer.from_pretrained(self._model)
+    self._model_handler = self.PytorchModelHandlerKeyedTensorWrapper(
+        state_dict_path=(model_path),
+        model_class=BertForMaskedLM,
+        model_params={'config': self._model_config},
+        device='cuda:0')
+```
+The code uses `PytorchModelHandlerKeyedTensorWrapper`, a wrapper around the `PytorchModelHandlerKeyedTensor` model handler. That handler runs inference on a PyTorch model whose input examples are dictionaries of tensors. The wrapper limits the batch size to 1. This is necessary because the tokenized strings generated by `BertTokenizer` can have different lengths. The current RunInference implementation batches examples with `torch.stack()`, which requires all tensors to have the same size. Restricting `max_batch_size` to 1 ensures that each `run_inference()` call receives only one example per batch. The definition of the wrapper is shown below:
+
+```python
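+from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
+
+
+class PytorchModelHandlerKeyedTensorWrapper(PytorchModelHandlerKeyedTensor):
+  """Limits each batch to one example so tensors of different sizes are never stacked."""
+  def batch_elements_kwargs(self):
+    return {'max_batch_size': 1}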
+```
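The stacking problem can be illustrated without PyTorch. In this sketch, `stack_rows` is a hypothetical stand-in that, like `torch.stack()`, rejects inputs of different lengths. `batches` is a simplified stand-in for the batching that `max_batch_size` controls. Neither is Beam's actual implementation, and the token ID lists are purely illustrative.

```python
from typing import Iterator, List, Sequence


def stack_rows(rows: Sequence[Sequence[int]]) -> List[List[int]]:
  """Stand-in for torch.stack(): every row must have the same length."""
  if len({len(row) for row in rows}) > 1:
    raise ValueError('stack expects each tensor to be equal size')
  return [list(row) for row in rows]


def batches(elements: Sequence, max_batch_size: int) -> Iterator[Sequence]:
  """Simplified batching: consecutive chunks of at most max_batch_size."""
  for start in range(0, len(elements), max_batch_size):
    yield elements[start:start + max_batch_size]


# Illustrative token ID sequences of different lengths.
token_ids = [[101, 1037, 2204, 102], [101, 2023, 102]]

try:
  [stack_rows(b) for b in batches(token_ids, max_batch_size=2)]
except ValueError as err:
  print('batch size 2 fails:', err)

# With max_batch_size=1 every batch holds a single example, so stacking succeeds.
print([stack_rows(b) for b in batches(token_ids, max_batch_size=1)])
```

The trade-off is throughput: one example per batch avoids the size mismatch but gives up the efficiency of batched inference.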
+
+Besides defining the model handler, the initialization function loads the model configuration (`BertConfig`) and the tokenizer (`BertTokenizer`). The configuration defines the model architecture, and the tokenizer tokenizes the input data. The initialization function takes the following two parameters:
+- `model`: The name of the model used for inference. In this example, it is `bert-base-uncased`.
+- `model_path`: The path to the `state_dict` of the model used for inference. In this example, it is a GCS path where the `state_dict` is stored.
+
+Both parameters are set from the Java pipeline options (`--modelName` and `--modelPath`).
+
+Finally, the model predictions are postprocessed in the `Postprocess` DoFn. It outputs a string containing the original text, the predicted word, and the actual last word of the sentence.
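The postprocessing logic can be sketched in plain Python, assuming a single `[MASK]` token per sentence. Here `logits` is a nested list standing in for the model's logits tensor (one row of vocabulary scores per token). A tiny `vocab` list stands in for `BertTokenizer.decode`. The vocabulary, the scores and both helper names are hypothetical.

```python
from typing import List, Sequence


def predict_masked_word(input_ids: Sequence[int], logits: Sequence[Sequence[float]],
                        mask_token_id: int, vocab: List[str]) -> str:
  """Pick the highest-scoring vocabulary entry at the [MASK] position."""
  mask_index = list(input_ids).index(mask_token_id)
  scores = logits[mask_index]
  predicted_id = max(range(len(scores)), key=lambda i: scores[i])  # argmax
  return vocab[predicted_id]


def format_prediction(text: str, predicted_word: str) -> str:
  """Mirror the output string built by the Postprocess DoFn."""
  text = text.replace('.', '').strip()
  return (f"{text} \n Predicted word: {predicted_word.upper()} -- "
          f"Actual word: {text.split()[-1].upper()}")


# Hypothetical 5-word vocabulary; the [MASK] token has ID 2.
vocab = ['[CLS]', '[SEP]', '[MASK]', 'production', 'show']
input_ids = [0, 2, 1]
logits = [[0.0] * 5, [0.1, 0.0, 0.0, 2.5, 1.0], [0.0] * 5]
word = predict_masked_word(input_ids, logits, mask_token_id=2, vocab=vocab)
print(format_prediction('A wonderful little production .', word))
```

The real DoFn locates the mask position with `nonzero()` on the `input_ids` tensor and takes `argmax` over the logits. This sketch does the same with list operations.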
+
+### Set up the expansion service

Review Comment:
   Add a note "This step is not needed for Beam 2.44.0 and later".



##########
sdks/python/apache_beam/examples/inference/multi_language_inference/README.md:
##########
@@ -0,0 +1,58 @@
+<!--
+    Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+-->
+## Setting up the Expansion service
+Because Beam 2.43 does not support adding local packages, we must create our own expansion service.

Review Comment:
   Support for using local packages with the default expansion service is available in 2.44. Starting with that release, neither a custom expansion service nor a custom container should be needed. It should be possible to just develop additional classes as a local Python package and specify it via the "withExtraPackages" option.
   https://github.com/apache/beam/blob/master/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/RunInference.java#L141
   
   I think we should add a note like "Skip this for Beam 2.44.0 and later". 



##########
sdks/python/apache_beam/examples/inference/multi_language_inference/last_word_prediction/src/main/java/org/MultiLangRunInference.java:
##########
@@ -0,0 +1,97 @@
+package org;
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms;
+import org.apache.beam.runners.core.construction.External;
+import org.apache.beam.sdk.extensions.python.PythonExternalTransform;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.RowCoder;
+import org.apache.beam.sdk.io.TextIO;
+import org.apache.beam.sdk.options.Description;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.options.Validation.Required;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.Field;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
+import org.apache.beam.sdk.schemas.SchemaTranslation;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.util.ByteStringOutputStream;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.PDone;
+import org.apache.beam.sdk.values.PBegin;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class MultiLangRunInference {
+    public interface MultiLanguageOptions extends PipelineOptions {
+
+        @Description("Path to an input file that contains the sentences to feed into the model.")
+        @Required
+        String getInputFile();
+
+        void setInputFile(String value);
+
+        @Description("Path to a stored model.")
+        @Required
+        String getModelPath();
+
+        void setModelPath(String value);
+
+        @Description("Path to the output file where the predictions are written.")
+        @Required
+        String getOutputFile();
+
+        void setOutputFile(String value);
+
+        @Description("Name of the model on HuggingFace.")
+        @Required
+        String getModelName();
+
+        void setModelName(String value);
+
+        @Description("Port number of the expansion service.")
+        @Required
+        String getPort();
+
+        void setPort(String value);
+    }
+
+    public static void main(String[] args) {
+
+        MultiLanguageOptions options = PipelineOptionsFactory.fromArgs(args).withValidation()
+                .as(MultiLanguageOptions.class);
+        
+        Pipeline p = Pipeline.create(options);
+        PCollection<String> input = p.apply("Read Input", TextIO.read().from(options.getInputFile()));
+    
+        input.apply("Predict", PythonExternalTransform.<PCollection<String>, PCollection<String>>from(
+            "expansion_service.run_inference_expansion.RunInferenceTransform", "localhost:" + options.getPort())
+            .withKwarg("model",  options.getModelName())
+            .withKwarg("model_path", options.getModelPath()))

Review Comment:
   Let's also provide simplified instructions once Beam 2.44.0 is finalized (setup a package and use "withExtraPackages" here).



##########
sdks/python/apache_beam/examples/inference/multi_language_inference/last_word_prediction/src/main/java/org/MultiLangRunInference.java:
##########
@@ -0,0 +1,97 @@
+        input.apply("Predict", PythonExternalTransform.<PCollection<String>, PCollection<String>>from(
+            "expansion_service.run_inference_expansion.RunInferenceTransform", "localhost:" + options.getPort())

Review Comment:
   Ditto regarding updating the package name.



##########
sdks/python/apache_beam/examples/inference/multi_language/expansion_service_package/expansion_service/run_inference_expansion.py:
##########
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# pytype: skip-file
+
+import argparse
+import logging
+import signal
+import sys
+import typing
+
+import apache_beam as beam
+from apache_beam.coders import RowCoder
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
+from apache_beam.pipeline import PipelineOptions
+from apache_beam.transforms import ptransform
+from apache_beam.transforms.external import ImplicitSchemaPayloadBuilder
+from transformers import BertConfig
+from transformers import BertForMaskedLM
+from transformers import BertTokenizer
+
+# This script provides a run inference transform with pre and post processing.
+# The model used is a BertLM, base uncased model.
+_LOGGER = logging.getLogger(__name__)
+
+# This URN will be used to register a transform that runs inference on a BERT model.
+TEST_RUN_BERT_URN = "beam:transforms:xlang:test:run_bert"

Review Comment:
   This URN (and references to it) should not be needed anymore.



##########
sdks/python/apache_beam/examples/inference/multi_language_inference/expansion_service/run_inference_expansion.py:
##########
@@ -0,0 +1,135 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+import logging
+import signal
+import sys
+import typing
+
+import apache_beam as beam
+from apache_beam.coders import RowCoder
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
+from apache_beam.pipeline import PipelineOptions
+from apache_beam.transforms import ptransform
+from apache_beam.transforms.external import ImplicitSchemaPayloadBuilder
+from transformers import BertConfig
+from transformers import BertForMaskedLM
+from transformers import BertTokenizer
+
+# This script provides a run inference transform with pre and post processing.
+# The model used is a BertLM, base uncased model.
+_LOGGER = logging.getLogger(__name__)
+
+# This URN will be used to register a transform that runs inference on a BERT model.
+TEST_RUN_BERT_URN = "beam:transforms:xlang:test:run_bert"
+
+
+@ptransform.PTransform.register_urn(TEST_RUN_BERT_URN, None)
+class RunInferenceTransform(ptransform.PTransform):
+  class PytorchModelHandlerKeyedTensorWrapper(PytorchModelHandlerKeyedTensor):
+    """Wrapper to PytorchModelHandler to limit batch size to 1.
+
+    The tokenized strings generated from BertTokenizer may have different
+    lengths, which doesn't work with torch.stack() in the current RunInference
+    implementation since stack() requires tensors to be the same size.
+    Restricting max_batch_size to 1 means there is only 1 example per
+    `batch` in the run_inference() call.
+    """
+    def batch_elements_kwargs(self):
+      return {'max_batch_size': 1}
+
+  class Preprocess(beam.DoFn):
+    def __init__(self, tokenizer):
+      self._tokenizer = tokenizer
+      logging.info('Starting Preprocess.')
+
+    def process(self, text: str):
+      import torch
+      # remove unusable tokens marks.
+      removable_tokens = ['"', '*', '<br />', "'", "(", ")"]
+      for token in removable_tokens:
+        text = text.replace(token, '')
+
+      # only take first sentence.
+      ending_chars = ['.', '!', '?']
+      for char in ending_chars:
+        if char in text:
+          text = text.split(char)[0]
+
+      # add dot to end of sentence.
+      text = text + ' .'
+
+      # mask the last word and drop very long sentences.
+      if len(text.strip()) > 0 and len(text.strip()) < 512:
+        logging.info('Preprocessing Line: %s', text)
+        text_list = text.split()
+        masked_text = ' '.join(text_list[:-2] + ['[MASK]', text_list[-1]])
+        tokens = self._tokenizer(masked_text, return_tensors='pt')
+        tokens = {key: torch.squeeze(val) for key, val in tokens.items()}
+
+        # skip first row of csv file.
+        if "review,sentiment" not in text.strip():
+          return [(text, tokens)]
+
+  class Postprocess(beam.DoFn):
+    def __init__(self, bert_tokenizer):
+      self.bert_tokenizer = bert_tokenizer
+      logging.info('Starting Postprocess')
+
+    def process(self, element: typing.Tuple[str, PredictionResult]) \
+        -> typing.Iterable[str]:
+      text, prediction_result = element
+      inputs = prediction_result.example
+      logits = prediction_result.inference['logits']
+      mask_token_index = (
+          inputs['input_ids'] == self.bert_tokenizer.mask_token_id).nonzero(
+              as_tuple=True)[0]
+      predicted_token_id = logits[mask_token_index].argmax(axis=-1)
+      decoded_word = self.bert_tokenizer.decode(predicted_token_id)
+      text = text.replace('.', '').strip()
+      yield (
+          f"{text} \n Predicted word: {decoded_word.upper()} -- "
+          f"Actual word: {text.split()[-1].upper()}")
+
+  def __init__(self, model, model_path):
+    self._model = model
+    logging.info(f"Downloading {self._model} model from GCS.")
+    self._model_config = BertConfig.from_pretrained(self._model)
+    self._tokenizer = BertTokenizer.from_pretrained(self._model)
+    self._model_handler = self.PytorchModelHandlerKeyedTensorWrapper(
+        state_dict_path=(model_path),
+        model_class=BertForMaskedLM,
+        model_params={'config': self._model_config},
+        device='cuda:0')
+
+  def expand(self, pcoll):
+    return (
+        pcoll
+        | 'Preprocess' >> beam.ParDo(self.Preprocess(self._tokenizer))
+        | 'Inference' >> RunInference(KeyedModelHandler(self._model_handler))
+        | 'Postprocess' >> beam.ParDo(self.Postprocess(
+            self._tokenizer)).with_input_types(typing.Iterable[str]))

Review Comment:
   What is the error if this is removed? I would have expected this to get auto-inferred as well (given that the predecessor is also Python).
   
   For x-lang, usually it's needed if Python type inferencing ends up choosing the default PickleCoder (for example, due to type being mapped to Any).



##########
sdks/python/apache_beam/examples/inference/multi_language_inference/expansion_service/__init__.py:
##########
@@ -0,0 +1,15 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with

Review Comment:
   Package "expansion_service" is too generic. Can you move to a package name that is specific to this example ?



##########
sdks/python/apache_beam/examples/inference/multi_language_inference/expansion_service/run_inference_expansion.py:
##########
@@ -0,0 +1,135 @@
+# This URN will be used to register a transform that runs inference on a BERT model.
+TEST_RUN_BERT_URN = "beam:transforms:xlang:test:run_bert"

Review Comment:
   This should not be needed anymore.



##########
website/www/site/content/en/documentation/ml/multi-language-inference.md:
##########
@@ -0,0 +1,159 @@
+### Set up the expansion service
+Because the pipeline uses transforms from two languages, we need an SDK for each language (in this case, Python and Java). We also need to set up an expansion service, which injects the cross-language Python transform into the Java pipeline. Multi-language pipelines give you access to a much larger pool of transforms. For more details, see the [multi-language pipelines](https://beam.apache.org/documentation/programming-guide/#multi-language-pipelines) section of the Beam programming guide.
+
+
+Setting up the expansion service is straightforward: run the following command in the terminal:

Review Comment:
   You mean, run this in a virtual env set up for released Beam, correct? We should clarify and refer to instructions for setting up a Beam Python virtual environment: https://beam.apache.org/get-started/quickstart-py/#create-and-activate-a-virtual-environment
   
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org