You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved in this plain-text version.
Posted to commits@beam.apache.org by jr...@apache.org on 2022/11/22 17:19:51 UTC

[beam] branch master updated: TensorRT Custom Inference Function Implementation (#24039)

This is an automated email from the ASF dual-hosted git repository.

jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new b5ad9f039ca TensorRT Custom Inference Function Implementation (#24039)
b5ad9f039ca is described below

commit b5ad9f039ca4f70dc2788dad3c2df359ae0405a3
Author: Jack McCluskey <34...@users.noreply.github.com>
AuthorDate: Tue Nov 22 12:19:42 2022 -0500

    TensorRT Custom Inference Function Implementation (#24039)
    
    * Initial TensorRT implementation
    
    * Formatter
    
    * add unit test
    
    * linting
    
    * duplicate _assign_or_fail logic
---
 .../apache_beam/ml/inference/tensorrt_inference.py | 93 ++++++++++++--------
 .../ml/inference/tensorrt_inference_test.py        | 98 ++++++++++++++++++++++
 2 files changed, 156 insertions(+), 35 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
index 8ff65658c6b..5abbe50b329 100644
--- a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
@@ -22,6 +22,7 @@ from __future__ import annotations
 import logging
 import threading
 from typing import Any
+from typing import Callable
 from typing import Dict
 from typing import Iterable
 from typing import Optional
@@ -164,11 +165,63 @@ class TensorRTEngine:
         self.stream)
 
 
+TensorRTInferenceFn = Callable[
+    [Sequence[np.ndarray], TensorRTEngine, Optional[Dict[str, Any]]],
+    Iterable[PredictionResult]]
+
+
+def _default_tensorRT_inference_fn(
+    batch: Sequence[np.ndarray],
+    engine: TensorRTEngine,
+    inference_args: Optional[Dict[str,
+                                  Any]] = None) -> Iterable[PredictionResult]:
+  from cuda import cuda
+  (
+      engine,
+      context,
+      context_lock,
+      inputs,
+      outputs,
+      gpu_allocations,
+      cpu_allocations,
+      stream) = engine.get_engine_attrs()
+
+  # Process I/O and execute the network
+  with context_lock:
+    _assign_or_fail(
+        cuda.cuMemcpyHtoDAsync(
+            inputs[0]['allocation'],
+            np.ascontiguousarray(batch),
+            inputs[0]['size'],
+            stream))
+    context.execute_async_v2(gpu_allocations, stream)
+    for output in range(len(cpu_allocations)):
+      _assign_or_fail(
+          cuda.cuMemcpyDtoHAsync(
+              cpu_allocations[output],
+              outputs[output]['allocation'],
+              outputs[output]['size'],
+              stream))
+    _assign_or_fail(cuda.cuStreamSynchronize(stream))
+
+    return [
+        PredictionResult(
+            x, [prediction[idx] for prediction in cpu_allocations]) for idx,
+        x in enumerate(batch)
+    ]
+
+
 @experimental(extra_message="No backwards-compatibility guarantees.")
 class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
                                               PredictionResult,
                                               TensorRTEngine]):
-  def __init__(self, min_batch_size: int, max_batch_size: int, **kwargs):
+  def __init__(
+      self,
+      min_batch_size: int,
+      max_batch_size: int,
+      *,
+      inference_fn: TensorRTInferenceFn = _default_tensorRT_inference_fn,
+      **kwargs):
     """Implementation of the ModelHandler interface for TensorRT.
 
     Example Usage::
@@ -185,6 +238,8 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
     Args:
       min_batch_size: minimum accepted batch size.
       max_batch_size: maximum accepted batch size.
+      inference_fn: the inference function to use on RunInference calls.
+        default: _default_tensorRT_inference_fn
       kwargs: Additional arguments like 'engine_path' and 'onnx_path' are
         currently supported.
 
@@ -193,6 +248,7 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
     """
     self.min_batch_size = min_batch_size
     self.max_batch_size = max_batch_size
+    self.inference_fn = inference_fn
     if 'engine_path' in kwargs:
       self.engine_path = kwargs.get('engine_path')
     elif 'onnx_path' in kwargs:
@@ -241,40 +297,7 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
     Returns:
       An Iterable of type PredictionResult.
     """
-    from cuda import cuda
-    (
-        engine,
-        context,
-        context_lock,
-        inputs,
-        outputs,
-        gpu_allocations,
-        cpu_allocations,
-        stream) = engine.get_engine_attrs()
-
-    # Process I/O and execute the network
-    with context_lock:
-      _assign_or_fail(
-          cuda.cuMemcpyHtoDAsync(
-              inputs[0]['allocation'],
-              np.ascontiguousarray(batch),
-              inputs[0]['size'],
-              stream))
-      context.execute_async_v2(gpu_allocations, stream)
-      for output in range(len(cpu_allocations)):
-        _assign_or_fail(
-            cuda.cuMemcpyDtoHAsync(
-                cpu_allocations[output],
-                outputs[output]['allocation'],
-                outputs[output]['size'],
-                stream))
-      _assign_or_fail(cuda.cuStreamSynchronize(stream))
-
-      return [
-          PredictionResult(
-              x, [prediction[idx] for prediction in cpu_allocations]) for idx,
-          x in enumerate(batch)
-      ]
+    return self.inference_fn(batch, engine, inference_args)
 
   def get_num_bytes(self, batch: Sequence[np.ndarray]) -> int:
     """
diff --git a/sdks/python/apache_beam/ml/inference/tensorrt_inference_test.py b/sdks/python/apache_beam/ml/inference/tensorrt_inference_test.py
index d12442ae9a0..1cbc7130200 100644
--- a/sdks/python/apache_beam/ml/inference/tensorrt_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/tensorrt_inference_test.py
@@ -59,6 +59,14 @@ SINGLE_FEATURE_PREDICTIONS = [
          for example in SINGLE_FEATURE_EXAMPLES])
 ]
 
+SINGLE_FEATURE_CUSTOM_PREDICTIONS = [
+    PredictionResult(ex, pred) for ex,
+    pred in zip(
+        SINGLE_FEATURE_EXAMPLES,
+        [[np.array([(example * 2.0 + 0.5) * 2], dtype=np.float32)]
+         for example in SINGLE_FEATURE_EXAMPLES])
+]
+
 TWO_FEATURES_EXAMPLES = [
     np.array([1, 5], dtype=np.float32),
     np.array([3, 10], dtype=np.float32),
@@ -83,6 +91,58 @@ def _compare_prediction_result(a, b):
       expected in zip(a.inference, b.inference)))
 
 
+def _assign_or_fail(args):
+  """CUDA error checking."""
+  from cuda import cuda
+  err, ret = args[0], args[1:]
+  if isinstance(err, cuda.CUresult):
+    if err != cuda.CUresult.CUDA_SUCCESS:
+      raise RuntimeError("Cuda Error: {}".format(err))
+  else:
+    raise RuntimeError("Unknown error type: {}".format(err))
+  # Special case so that no unpacking is needed at call-site.
+  if len(ret) == 1:
+    return ret[0]
+  return ret
+
+
+def _custom_tensorRT_inference_fn(batch, engine, inference_args):
+  from cuda import cuda
+  (
+      engine,
+      context,
+      context_lock,
+      inputs,
+      outputs,
+      gpu_allocations,
+      cpu_allocations,
+      stream) = engine.get_engine_attrs()
+
+  # Process I/O and execute the network
+  with context_lock:
+    _assign_or_fail(
+        cuda.cuMemcpyHtoDAsync(
+            inputs[0]['allocation'],
+            np.ascontiguousarray(batch),
+            inputs[0]['size'],
+            stream))
+    context.execute_async_v2(gpu_allocations, stream)
+    for output in range(len(cpu_allocations)):
+      _assign_or_fail(
+          cuda.cuMemcpyDtoHAsync(
+              cpu_allocations[output],
+              outputs[output]['allocation'],
+              outputs[output]['size'],
+              stream))
+    _assign_or_fail(cuda.cuStreamSynchronize(stream))
+
+    return [
+        PredictionResult(
+            x, [prediction[idx] * 2 for prediction in cpu_allocations]) for idx,
+        x in enumerate(batch)
+    ]
+
+
 @pytest.mark.uses_tensorrt
 class TensorRTRunInferenceTest(unittest.TestCase):
   @unittest.skipIf(GCSFileSystem is None, 'GCP dependencies are not installed')
@@ -156,6 +216,44 @@ class TensorRTRunInferenceTest(unittest.TestCase):
     for actual, expected in zip(predictions, SINGLE_FEATURE_PREDICTIONS):
       self.assertEqual(actual, expected)
 
+  def test_inference_custom_single_tensor_feature(self):
+    """
+    This tests creating TensorRT network from scratch. Test replicates the same
+    ONNX network above but natively in TensorRT. After network creation, network
+    is used to build a TensorRT engine. Single feature tensors batched into size
+    of 4 are used as input. This routes through a custom inference function.
+    """
+    inference_runner = TensorRTEngineHandlerNumPy(
+        min_batch_size=4,
+        max_batch_size=4,
+        inference_fn=_custom_tensorRT_inference_fn)
+    builder = trt.Builder(LOGGER)
+    network = builder.create_network(
+        flags=1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
+    input_tensor = network.add_input(
+        name="input", dtype=trt.float32, shape=(4, 1))
+    weight_const = network.add_constant(
+        (1, 1), trt.Weights((np.ascontiguousarray([2.0], dtype=np.float32))))
+    mm = network.add_matrix_multiply(
+        input_tensor,
+        trt.MatrixOperation.NONE,
+        weight_const.get_output(0),
+        trt.MatrixOperation.NONE)
+    bias_const = network.add_constant(
+        (1, 1), trt.Weights((np.ascontiguousarray([0.5], dtype=np.float32))))
+    bias_add = network.add_elementwise(
+        mm.get_output(0),
+        bias_const.get_output(0),
+        trt.ElementWiseOperation.SUM)
+    bias_add.get_output(0).name = "output"
+    network.mark_output(tensor=bias_add.get_output(0))
+
+    engine = inference_runner.build_engine(network, builder)
+    predictions = inference_runner.run_inference(
+        SINGLE_FEATURE_EXAMPLES, engine)
+    for actual, expected in zip(predictions, SINGLE_FEATURE_CUSTOM_PREDICTIONS):
+      self.assertEqual(actual, expected)
+
   def test_inference_multiple_tensor_features(self):
     """
     This tests creating TensorRT network from scratch. Test replicates the same