Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2022/06/13 12:58:31 UTC

[GitHub] [beam] tvalentyn commented on a diff in pull request #21818: Add Bert Language Modeling example

tvalentyn commented on code in PR #21818:
URL: https://github.com/apache/beam/pull/21818#discussion_r895686758


##########
sdks/python/apache_beam/examples/inference/pytorch_bert.py:
##########
@@ -0,0 +1,214 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+""""A pipeline that uses RunInference to perform Language Modeling with Bert.
+
+This pipeline takes sentences from the bookcorpus dataset, removes the last word
+of the sentence, and then uses the BertForMaskedLM from Hugging Face to predict
+the best word to follow or continue that sentence given all the words already in
+the sentence. The pipeline then writes the prediction to an output file in
+which users can then compare against the original sentence.
+"""
+
+import argparse
+from typing import Iterable
+from typing import Dict
+from typing import Tuple
+
+import apache_beam as beam
+import torch
+from apache_beam.ml.inference.api import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.pytorch_inference import PytorchModelHandler
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from transformers import BertTokenizer, BertForMaskedLM, BertConfig
+
+BERT_TOKENIZER = BertTokenizer.from_pretrained('bert-base-uncased')
+
+
+def add_mask_to_last_word(text: str) -> Tuple[str, str]:
+  text_list = text.split()
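+  # Pair the original text with a copy in which the second-to-last word is
+  # replaced by the [MASK] token.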
+  return text, ' '.join(text_list[:-2] + ['[MASK]', text_list[-1]])
+
+
+def tokenize_sentence(
+    text_and_mask: Tuple[str, str]) -> Tuple[str, Dict[str, torch.Tensor]]:
+  text, masked_text = text_and_mask
+  tokenized_sentence = BERT_TOKENIZER.encode_plus(
+      masked_text, return_tensors="pt")
+
+  # Workaround to manually remove batch dim until we have the feature to
+  # add optional batching flag. TODO: Remove once optional batching flag added
+  return text, {
+      k: torch.squeeze(v)
+      for k, v in dict(tokenized_sentence).items()
+  }
+
+
+class PostProcessor(beam.DoFn):
+  def process(
+      self, element: Tuple[str, PredictionResult]) -> Iterable[Tuple[str, str]]:
+    text, prediction_result = element
+    inputs = prediction_result.example
+    logits = prediction_result.inference['logits']
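+    # Find the position of the [MASK] token in the (squeezed) input_ids;
+    # with max_batch_size=1 there is a single sequence per example.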
+    mask_token_index = (
+        inputs['input_ids'] == BERT_TOKENIZER.mask_token_id).nonzero(
+            as_tuple=True)[0]
+    predicted_token_id = logits[mask_token_index].argmax(axis=-1)
+    decoded_text = BERT_TOKENIZER.decode(predicted_token_id)
+    yield (text, decoded_text)
+
+
+def parse_known_args(argv):
+  """Parses args for the workflow."""
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--input',
+      dest='input',
+      default=
+      'gs://apache-beam-ml/datasets/bookcorpus/bookcorpus_subset.parquet',
+      help='Path to the Parquet file containing the input sentences.')
+  parser.add_argument(
+      '--output',
+      dest='output',
+      help='Path to the output text file where predictions are saved.')
+  parser.add_argument(
+      '--model_state_dict_path',
+      dest='model_state_dict_path',
+      default='/Users/yeandy/Downloads/'
+      'huggingface.BertForMaskedLM.bert-base-uncased.pth',
+      help="Path to the model's state_dict. "
+      "Default state_dict would be for the bert-base-uncased model.")
+  return parser.parse_known_args(argv)
+
+
+def run(argv=None, model_class=None, model_params=None, save_main_session=True):
+  """
+  Args:
+    argv: Command line arguments defined for this example.
+    model_class: Reference to the class definition of the model.
+                If None, BertForMaskedLM will be used by default.
+    model_params: Parameters passed to the constructor of the model_class.
+                  These will be used to instantiate the model object in the
+                  RunInference API.
+  """
+  known_args, pipeline_args = parse_known_args(argv)
+  pipeline_options = PipelineOptions(pipeline_args)
+  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
+
+  if not model_class:
+    model_config = BertConfig(is_decoder=False, return_dict=True)
+    model_class = BertForMaskedLM
+    model_params = {'config': model_config}
+
+  # TODO: Remove once optional batching flag added
+  class HuggingFaceStripBatchingWrapper(model_class):
+    """Wrapper class to convert output from dict of lists to list of dicts
+
+    The `forward()` function in Hugging Face models don't return a just a
+    standard torch.Tensor output. Instead, they can return a dictionary of
+    different outputs. To work with current RunInference implementation which
+    returns a PredictionResult object for each example, we must override the
+    `forward()` function and convert the standard Hugging Face forward output
+    into the appropriate format of List[Dict[str, torch.Tensor]].
+
+    Before:
+    output = {
+      'logits': torch.FloatTensor of shape
+        (batch_size, sequence_length, config.vocab_size),
+      'hidden_states': tuple(torch.FloatTensor) of shape
+        (batch_size, sequence_length, hidden_size)
+    }
+    After:
+    output = [
+      {
+        'logits': torch.FloatTensor of shape
+          (sequence_length, config.vocab_size),
+        'hidden_states': tuple(torch.FloatTensor) of
+          shape (sequence_length, hidden_size)
+      },
+      {
+        'logits': torch.FloatTensor of shape
+          (sequence_length, config.vocab_size),
+        'hidden_states': tuple(torch.FloatTensor) of shape
+          (sequence_length, hidden_size)
+      },
+      ...
+    ]
+    where len(output) is batch_size
+    """
+    def forward(self, **kwargs):
+      output = super().forward(**kwargs)
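+      # Transpose the dict of batched tensors into one dict per example:
+      # iterating each tensor yields its slices along the batch dimension,
+      # and zip() regroups them so each element maps the output keys to the
+      # corresponding per-example tensors.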
+      return [dict(zip(output, v)) for v in zip(*output.values())]
+
+  # TODO: Remove once nested tensors https://github.com/pytorch/nestedtensor
+  # is officially released.
+  class PytorchNoBatchModelHandler(PytorchModelHandler):
+    """Wrapper to PytorchModelHandler to limit batch size to 1.
+
+    The tokenized strings generated from BertTokenizer may have different
+    lengths, which doesn't work with torch.stack() in current RunInference
+    implementation since stack() requires tensors to be the same size.
+
+    Restricting max_batch_size to 1 means there is only 1 example per `batch`
+    in the run_inference() call.
+    """
+    def batch_elements_kwargs(self):
+      return {'max_batch_size': 1}
+
+  model_handler = PytorchNoBatchModelHandler(
+      state_dict_path=known_args.model_state_dict_path,
+      model_class=HuggingFaceStripBatchingWrapper,
+      model_params=model_params)
+
+  with beam.Pipeline(options=pipeline_options) as p:
+    text = (
+        p
+        | 'ReadSentences' >> beam.io.ReadFromParquet(known_args.input)
+        | 'ExtractTextFromDict' >> beam.Map(lambda x: x['text']))
+    text_and_masked_text_tuple = (
+        text
+        | 'AddMask' >> beam.Map(add_mask_to_last_word))
+    text_and_tokenized_text_tuple = (
+        text_and_masked_text_tuple
+        | 'TokenizeSentence' >> beam.Map(tokenize_sentence))
+    text_and_predictions = (
+        text_and_tokenized_text_tuple
+        |
+        'PyTorchRunInference' >> RunInference(model_handler).with_output_types(
+            Tuple[str, PredictionResult])

Review Comment:
   Is this still necessary with the proposed changes to make keyed and non-keyed model handlers separate classes?
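
   For illustration only, a rough sketch of how this step might look once
   keyed and non-keyed model handlers are split into separate classes. The
   `KeyedModelHandler` name and its location in `apache_beam.ml.inference.base`
   are assumptions about the proposal, not the current API; the snippet reuses
   `model_handler` and `text_and_tokenized_text_tuple` from the diff above:

       from apache_beam.ml.inference.base import KeyedModelHandler
       from apache_beam.ml.inference.base import RunInference

       # Hypothetical: wrap the unkeyed PyTorch handler so RunInference
       # consumes (key, example) tuples and emits (key, PredictionResult)
       # tuples, which may make the explicit with_output_types hint redundant.
       keyed_model_handler = KeyedModelHandler(model_handler)

       text_and_predictions = (
           text_and_tokenized_text_tuple
           | 'PyTorchRunInference' >> RunInference(keyed_model_handler))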


