Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2023/01/08 22:35:06 UTC

[GitHub] [beam] AnandInguva commented on a diff in pull request #24911: Ziqima/onnx

AnandInguva commented on code in PR #24911:
URL: https://github.com/apache/beam/pull/24911#discussion_r1064076443


##########
sdks/python/apache_beam/ml/inference/onnx_inference.py:
##########
@@ -0,0 +1,144 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pickle
+import sys
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Union
+
+import numpy
+import pandas
+import onnx
+import onnxruntime as ort
+
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.utils.annotations import experimental
+
+try:
+  import joblib
+except ImportError:
+  # joblib is an optional dependency.
+  pass
+
+__all__ = [
+    'OnnxModelHandler'
+]
+
+NumpyInferenceFn = Callable[
+    [Sequence[numpy.ndarray], ort.InferenceSession, Optional[Dict[str, Any]]],
+    Iterable[PredictionResult]]
+
+
+def _load_model(model_uri):
+  ort_session = ort.InferenceSession(model_uri, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+  return ort_session
+
+
+def _convert_to_result(
+    batch: Iterable, predictions: Union[Iterable, Dict[Any, Iterable]]
+) -> Iterable[PredictionResult]:
+  if isinstance(predictions, dict):
+    # Go from one dictionary of type: {key_type1: Iterable<val_type1>,
+    # key_type2: Iterable<val_type2>, ...} where each Iterable is of
+    # length batch_size, to a list of dictionaries:
+    # [{key_type1: value_type1, key_type2: value_type2}]
+    predictions_per_tensor = [
+        dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
+    ]
+    return [
+        PredictionResult(x, y) for x, y in zip(batch, predictions_per_tensor)
+    ]
+  return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+
+def default_numpy_inference_fn(
+    inference_session: ort.InferenceSession,
+    batch: Sequence[numpy.ndarray],
+    inference_args: Optional[Dict[str, Any]] = None) -> Any:
+  ort_inputs = {inference_session.get_inputs()[0].name: numpy.stack(batch, axis=0)}
+  ort_outs = inference_session.run(None, ort_inputs)
+  return ort_outs
+
+
+class OnnxModelHandler(ModelHandler[numpy.ndarray,

Review Comment:
   Are there any other input types the ONNX model handler could potentially accept?
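
   For context, an ORT session ultimately consumes a dict of named numpy arrays, so another example type would mainly need a conversion step before the session call. A purely illustrative sketch of what a pandas-based inference fn could start from (hypothetical, not part of this PR):

   ```python
   import numpy
   import onnxruntime as ort

   def _pandas_inference_fn(
       inference_session: ort.InferenceSession,
       batch, inference_args=None):
     # Hypothetical: batch is a sequence of pandas Series examples.
     stacked = numpy.stack(
         [row.to_numpy(dtype=numpy.float32) for row in batch], axis=0)
     ort_inputs = {inference_session.get_inputs()[0].name: stacked}
     return inference_session.run(None, ort_inputs)
   ```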



##########
sdks/python/apache_beam/ml/inference/onnx_inference.py:
##########
@@ -0,0 +1,144 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pickle
+import sys
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Union
+
+import numpy
+import pandas
+import onnx
+import onnxruntime as ort
+
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.utils.annotations import experimental

Review Comment:
   I don't think we use this import anywhere. Do you want to mark this handler as experimental? If so, why?



##########
sdks/python/apache_beam/ml/inference/onnx_inference.py:
##########
@@ -0,0 +1,144 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pickle
+import sys
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Union
+
+import numpy
+import pandas
+import onnx
+import onnxruntime as ort
+
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.utils.annotations import experimental
+
+try:
+  import joblib
+except ImportError:
+  # joblib is an optional dependency.
+  pass
+
+__all__ = [
+    'OnnxModelHandler'
+]
+
+NumpyInferenceFn = Callable[
+    [Sequence[numpy.ndarray], ort.InferenceSession, Optional[Dict[str, Any]]],
+    Iterable[PredictionResult]]
+
+
+def _load_model(model_uri):
+  ort_session = ort.InferenceSession(model_uri, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+  return ort_session
+
+
+def _convert_to_result(
+    batch: Iterable, predictions: Union[Iterable, Dict[Any, Iterable]]
+) -> Iterable[PredictionResult]:
+  if isinstance(predictions, dict):
+    # Go from one dictionary of type: {key_type1: Iterable<val_type1>,
+    # key_type2: Iterable<val_type2>, ...} where each Iterable is of
+    # length batch_size, to a list of dictionaries:
+    # [{key_type1: value_type1, key_type2: value_type2}]
+    predictions_per_tensor = [
+        dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
+    ]
+    return [
+        PredictionResult(x, y) for x, y in zip(batch, predictions_per_tensor)
+    ]
+  return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+
+def default_numpy_inference_fn(
+    inference_session: ort.InferenceSession,
+    batch: Sequence[numpy.ndarray],
+    inference_args: Optional[Dict[str, Any]] = None) -> Any:
+  ort_inputs = {inference_session.get_inputs()[0].name: numpy.stack(batch, axis=0)}
+  ort_outs = inference_session.run(None, ort_inputs)
+  return ort_outs
+
+
+class OnnxModelHandler(ModelHandler[numpy.ndarray,
+                                    PredictionResult,
+                                    ort.InferenceSession]):
+  def __init__(
+      self,
+      model_uri: str,
+      *,
+      inference_fn: NumpyInferenceFn = default_numpy_inference_fn):
+    """ Implementation of the ModelHandler interface for onnx
+    using numpy arrays as input.
+
+    Example Usage::
+
+      pcoll | RunInference(OnnxModelHandler(model_uri="my_uri"))
+
+    Args:
+      model_uri: The URI to where the model is saved.
+      inference_fn: The inference function to use.
+        default=default_numpy_inference_fn
+    """
+    self._model_uri = model_uri
+    self._model_inference_fn = inference_fn
+
+  def load_model(self) -> ort.InferenceSession:
+    """Loads and initializes an onnx inference session for processing."""
+    return _load_model(self._model_uri)
+
+  def run_inference(
+      self,
+      batch: Sequence[numpy.ndarray],
+      inference_session: ort.InferenceSession,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
+    """Runs inferences on a batch of numpy arrays.
+
+    Args:
+      batch: A sequence of examples as numpy arrays. They should
+        be single examples.
+      inference_session: An onnx inference session. Must be runnable with input x where x is sequence of numpy array
+      inference_args: Any additional arguments for an inference.
+
+    Returns:
+      An Iterable of type PredictionResult.
+    """
+    predictions = self._model_inference_fn(inference_session, batch, inference_args)[0]
+
+    return _convert_to_result(batch, predictions)
+
+  def get_num_bytes(self, batch: Sequence[numpy.ndarray]) -> int:
+    """
+    Returns:
+      The number of bytes of data for a batch.
+    """
+    return sum(sys.getsizeof(element) for element in batch)

Review Comment:
   Can we follow this approach for `get_num_bytes`? https://github.com/apache/beam/blob/95e53916b6c9de6052ce8fc8409a53bad3d9d517/sdks/python/apache_beam/ml/inference/tensorrt_inference.py#L307
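
   For reference, the linked handler sizes the batch from the numpy arrays themselves rather than `sys.getsizeof`; roughly along these lines (a sketch of that approach, not a verbatim copy of the linked code):

   ```python
   def get_num_bytes(self, batch: Sequence[numpy.ndarray]) -> int:
     """
     Returns:
       The number of bytes of data for a batch.
     """
     # Size each array from its element count and per-element width.
     return sum(el.size * el.itemsize for el in batch)
   ```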



##########
sdks/python/apache_beam/ml/inference/onnx_inference.py:
##########
@@ -0,0 +1,144 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pickle
+import sys
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Union
+
+import numpy
+import pandas
+import onnx
+import onnxruntime as ort
+
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.utils.annotations import experimental
+
+try:
+  import joblib
+except ImportError:
+  # joblib is an optional dependency.
+  pass
+
+__all__ = [
+    'OnnxModelHandler'
+]
+
+NumpyInferenceFn = Callable[
+    [Sequence[numpy.ndarray], ort.InferenceSession, Optional[Dict[str, Any]]],
+    Iterable[PredictionResult]]
+
+
+def _load_model(model_uri):
+  ort_session = ort.InferenceSession(model_uri, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+  return ort_session
+
+
+def _convert_to_result(
+    batch: Iterable, predictions: Union[Iterable, Dict[Any, Iterable]]
+) -> Iterable[PredictionResult]:
+  if isinstance(predictions, dict):
+    # Go from one dictionary of type: {key_type1: Iterable<val_type1>,
+    # key_type2: Iterable<val_type2>, ...} where each Iterable is of
+    # length batch_size, to a list of dictionaries:
+    # [{key_type1: value_type1, key_type2: value_type2}]
+    predictions_per_tensor = [
+        dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
+    ]
+    return [
+        PredictionResult(x, y) for x, y in zip(batch, predictions_per_tensor)
+    ]
+  return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+
+def default_numpy_inference_fn(
+    inference_session: ort.InferenceSession,
+    batch: Sequence[numpy.ndarray],
+    inference_args: Optional[Dict[str, Any]] = None) -> Any:
+  ort_inputs = {inference_session.get_inputs()[0].name: numpy.stack(batch, axis=0)}
+  ort_outs = inference_session.run(None, ort_inputs)

Review Comment:
   ```suggestion
     ort_outs = inference_session.run(None, {**ort_inputs, **inference_args})
   ```
   
   If the user needs to pass any extra parameters beyond the model input to the predict call, they can pass them through `inference_args`. I think it's safe to merge `ort_inputs` and `inference_args` (even when `inference_args` is empty). Note that the default is `None`, though, so the merge needs a guard such as `inference_args = inference_args or {}` first; see the sketch after the example below.
   
   ```python
   import numpy as np
   import onnxruntime as ort

   ort_sess = ort.InferenceSession('loop.onnx')
   # dummy_input stands in for a tensor prepared elsewhere.
   outputs = ort_sess.run(None, {'input_data': dummy_input.numpy(),
                                 'loop_range': np.array(9).astype(np.int64)})
   ```
   
   Here `loop_range` could be an extra parameter passed to the predict call.
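
   A minimal sketch of the guarded version (my suggestion above assumes `inference_args` is a dict, while the current signature defaults it to `None`):

   ```python
   import numpy
   import onnxruntime as ort

   def default_numpy_inference_fn(
       inference_session: ort.InferenceSession,
       batch, inference_args=None):
     # Guard against the default None before merging the two dicts.
     inference_args = inference_args or {}
     ort_inputs = {
         inference_session.get_inputs()[0].name: numpy.stack(batch, axis=0)
     }
     return inference_session.run(None, {**ort_inputs, **inference_args})
   ```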



##########
sdks/python/apache_beam/ml/inference/onnx_inference.py:
##########
@@ -0,0 +1,144 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pickle
+import sys
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Union
+
+import numpy
+import pandas
+import onnx
+import onnxruntime as ort
+
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.utils.annotations import experimental
+
+try:
+  import joblib
+except ImportError:
+  # joblib is an optional dependency.
+  pass
+
+__all__ = [
+    'OnnxModelHandler'
+]
+
+NumpyInferenceFn = Callable[
+    [Sequence[numpy.ndarray], ort.InferenceSession, Optional[Dict[str, Any]]],
+    Iterable[PredictionResult]]
+
+
+def _load_model(model_uri):
+  ort_session = ort.InferenceSession(model_uri, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+  return ort_session
+
+
+def _convert_to_result(
+    batch: Iterable, predictions: Union[Iterable, Dict[Any, Iterable]]
+) -> Iterable[PredictionResult]:
+  if isinstance(predictions, dict):
+    # Go from one dictionary of type: {key_type1: Iterable<val_type1>,
+    # key_type2: Iterable<val_type2>, ...} where each Iterable is of
+    # length batch_size, to a list of dictionaries:
+    # [{key_type1: value_type1, key_type2: value_type2}]
+    predictions_per_tensor = [
+        dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
+    ]
+    return [
+        PredictionResult(x, y) for x, y in zip(batch, predictions_per_tensor)
+    ]
+  return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+
+def default_numpy_inference_fn(
+    inference_session: ort.InferenceSession,
+    batch: Sequence[numpy.ndarray],
+    inference_args: Optional[Dict[str, Any]] = None) -> Any:
+  ort_inputs = {inference_session.get_inputs()[0].name: numpy.stack(batch, axis=0)}
+  ort_outs = inference_session.run(None, ort_inputs)
+  return ort_outs
+
+
+class OnnxModelHandler(ModelHandler[numpy.ndarray,

Review Comment:
   ```suggestion
   class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray,
   ```



##########
sdks/python/apache_beam/ml/inference/onnx_inference.py:
##########
@@ -0,0 +1,144 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pickle
+import sys
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Union
+
+import numpy
+import pandas
+import onnx
+import onnxruntime as ort
+
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.utils.annotations import experimental
+
+try:
+  import joblib
+except ImportError:
+  # joblib is an optional dependency.
+  pass
+
+__all__ = [
+    'OnnxModelHandler'
+]
+
+NumpyInferenceFn = Callable[
+    [Sequence[numpy.ndarray], ort.InferenceSession, Optional[Dict[str, Any]]],
+    Iterable[PredictionResult]]
+
+
+def _load_model(model_uri):
+  ort_session = ort.InferenceSession(model_uri, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+  return ort_session
+
+
+def _convert_to_result(
+    batch: Iterable, predictions: Union[Iterable, Dict[Any, Iterable]]
+) -> Iterable[PredictionResult]:
+  if isinstance(predictions, dict):
+    # Go from one dictionary of type: {key_type1: Iterable<val_type1>,
+    # key_type2: Iterable<val_type2>, ...} where each Iterable is of
+    # length batch_size, to a list of dictionaries:
+    # [{key_type1: value_type1, key_type2: value_type2}]
+    predictions_per_tensor = [
+        dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
+    ]
+    return [
+        PredictionResult(x, y) for x, y in zip(batch, predictions_per_tensor)
+    ]
+  return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+
+def default_numpy_inference_fn(
+    inference_session: ort.InferenceSession,
+    batch: Sequence[numpy.ndarray],
+    inference_args: Optional[Dict[str, Any]] = None) -> Any:
+  ort_inputs = {inference_session.get_inputs()[0].name: numpy.stack(batch, axis=0)}
+  ort_outs = inference_session.run(None, ort_inputs)
+  return ort_outs
+
+
+class OnnxModelHandler(ModelHandler[numpy.ndarray,
+                                    PredictionResult,
+                                    ort.InferenceSession]):
+  def __init__(
+      self,
+      model_uri: str,
+      *,
+      inference_fn: NumpyInferenceFn = default_numpy_inference_fn):
+    """ Implementation of the ModelHandler interface for onnx
+    using numpy arrays as input.

Review Comment:
   Can we also add a note saying that the inputs to OnnxModelHandler must all be the same size? If differently sized inputs are expected, the user has to explicitly declare a batch size of 1.
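
   One way to force single-element batches is to override `batch_elements_kwargs` on the handler (if I remember right, that hook feeds `beam.BatchElements`); a sketch, assuming the `OnnxModelHandlerNumpy` rename suggested above:

   ```python
   class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray,
                                            PredictionResult,
                                            ort.InferenceSession]):
     def batch_elements_kwargs(self):
       # numpy.stack in default_numpy_inference_fn needs uniform shapes,
       # so cap batches at one element when input sizes vary.
       return {'max_batch_size': 1}
   ```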



##########
sdks/python/apache_beam/ml/inference/onnx_inference.py:
##########
@@ -0,0 +1,144 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pickle
+import sys
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Union
+
+import numpy
+import pandas
+import onnx
+import onnxruntime as ort
+
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.utils.annotations import experimental
+
+try:
+  import joblib
+except ImportError:
+  # joblib is an optional dependency.
+  pass
+
+__all__ = [
+    'OnnxModelHandler'
+]
+
+NumpyInferenceFn = Callable[
+    [Sequence[numpy.ndarray], ort.InferenceSession, Optional[Dict[str, Any]]],
+    Iterable[PredictionResult]]
+
+
+def _load_model(model_uri):

Review Comment:
   Can we move this into `load_model()`? A separate helper seems unnecessary here.
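
   Inlined, it would presumably look like this (a sketch based on the `_load_model` body in this diff):

   ```python
   def load_model(self) -> ort.InferenceSession:
     """Loads and initializes an onnx inference session for processing."""
     return ort.InferenceSession(
         self._model_uri,
         providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
   ```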



##########
sdks/python/apache_beam/ml/inference/onnx_inference.py:
##########
@@ -0,0 +1,144 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pickle
+import sys
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Union
+
+import numpy
+import pandas
+import onnx
+import onnxruntime as ort
+
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.utils.annotations import experimental
+
+try:
+  import joblib
+except ImportError:
+  # joblib is an optional dependency.
+  pass
+
+__all__ = [
+    'OnnxModelHandler'
+]
+
+NumpyInferenceFn = Callable[
+    [Sequence[numpy.ndarray], ort.InferenceSession, Optional[Dict[str, Any]]],
+    Iterable[PredictionResult]]
+
+
+def _load_model(model_uri):
+  ort_session = ort.InferenceSession(model_uri, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+  return ort_session
+
+
+def _convert_to_result(
+    batch: Iterable, predictions: Union[Iterable, Dict[Any, Iterable]]
+) -> Iterable[PredictionResult]:
+  if isinstance(predictions, dict):
+    # Go from one dictionary of type: {key_type1: Iterable<val_type1>,
+    # key_type2: Iterable<val_type2>, ...} where each Iterable is of
+    # length batch_size, to a list of dictionaries:
+    # [{key_type1: value_type1, key_type2: value_type2}]
+    predictions_per_tensor = [
+        dict(zip(predictions.keys(), v)) for v in zip(*predictions.values())
+    ]
+    return [
+        PredictionResult(x, y) for x, y in zip(batch, predictions_per_tensor)
+    ]
+  return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+
+def default_numpy_inference_fn(
+    inference_session: ort.InferenceSession,
+    batch: Sequence[numpy.ndarray],
+    inference_args: Optional[Dict[str, Any]] = None) -> Any:
+  ort_inputs = {inference_session.get_inputs()[0].name: numpy.stack(batch, axis=0)}
+  ort_outs = inference_session.run(None, ort_inputs)
+  return ort_outs
+
+
+class OnnxModelHandler(ModelHandler[numpy.ndarray,
+                                    PredictionResult,
+                                    ort.InferenceSession]):
+  def __init__(
+      self,
+      model_uri: str,
+      *,
+      inference_fn: NumpyInferenceFn = default_numpy_inference_fn):
+    """ Implementation of the ModelHandler interface for onnx
+    using numpy arrays as input.
+
+    Example Usage::
+
+      pcoll | RunInference(OnnxModelHandler(model_uri="my_uri"))
+
+    Args:
+      model_uri: The URI to where the model is saved.
+      inference_fn: The inference function to use.

Review Comment:
   ```suggestion
         inference_fn: The inference function to use on RunInference calls.
   ```



##########
sdks/python/apache_beam/ml/inference/onnx_inference.py:
##########
@@ -0,0 +1,144 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pickle
+import sys
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Union
+
+import numpy
+import pandas
+import onnx
+import onnxruntime as ort
+
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.utils.annotations import experimental
+
+try:

Review Comment:
   Can we remove the unused dependency imports? `pickle`, `pandas`, `onnx`, and the `joblib` fallback don't appear to be used anywhere in this module.
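
   A sketch of the trimmed import block, assuming those are indeed unused (and that `FileSystems` and `experimental` are dropped as well, per the earlier comments):

   ```python
   import sys
   from typing import Any
   from typing import Callable
   from typing import Dict
   from typing import Iterable
   from typing import Optional
   from typing import Sequence
   from typing import Union

   import numpy
   import onnxruntime as ort

   from apache_beam.ml.inference.base import ModelHandler
   from apache_beam.ml.inference.base import PredictionResult
   ```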



##########
sdks/python/apache_beam/ml/inference/onnx_inference_test.py:
##########
@@ -0,0 +1,477 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+import os
+import shutil
+import tempfile
+import unittest
+from collections import OrderedDict
+import sys
+import numpy as np
+import pytest
+
+import apache_beam as beam
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+
+# Protect against environments where onnx and pytorch library is not available.
+# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
+try:
+  import onnx
+  import onnxruntime as ort
+  import torch
+  from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument
+  import tensorflow as tf
+  import tf2onnx
+  from tensorflow import keras
+  from tensorflow.keras import layers
+  from sklearn import linear_model
+  from skl2onnx import convert_sklearn
+  from skl2onnx.common.data_types import FloatTensorType
+  from apache_beam.ml.inference.base import PredictionResult
+  from apache_beam.ml.inference.base import RunInference
+  from apache_beam.ml.inference.onnx_inference import default_numpy_inference_fn
+  from apache_beam.ml.inference.onnx_inference import OnnxModelHandler
+except ImportError:
+  raise unittest.SkipTest('Onnx dependencies are not installed')
+
+try:
+  from apache_beam.io.gcp.gcsfilesystem import GCSFileSystem
+except ImportError:
+  GCSFileSystem = None  # type: ignore
+
+
+class PytorchLinearRegression(torch.nn.Module):
+  def __init__(self, input_dim, output_dim):
+    super().__init__()
+    self.linear = torch.nn.Linear(input_dim, output_dim)
+
+  def forward(self, x):
+    out = self.linear(x)
+    return out
+
+  def generate(self, x):
+    out = self.linear(x) + 0.5
+    return out
+
+
+class TestDataAndModel():
+  def get_one_feature_samples(self):
+    return [
+        np.array([1], dtype="float32"),
+        np.array([5], dtype="float32"),
+        np.array([-3], dtype="float32"),
+        np.array([10.0], dtype="float32"),
+    ]
+
+  def get_one_feature_predictions(self):
+    return [
+        PredictionResult(ex, pred) for ex,
+        pred in zip(
+            self.get_one_feature_samples(),
+            [example * 2.0 + 0.5
+                      for example in self.get_one_feature_samples()])
+    ]
+
+  def get_two_feature_examples(self):
+    return [
+      np.array([1, 5], dtype="float32"),
+      np.array([3, 10], dtype="float32"),
+      np.array([-14, 0], dtype="float32"),
+      np.array([0.5, 0.5], dtype="float32")
+    ]
+
+  def get_two_feature_predictions(self):
+    return [
+      PredictionResult(ex, pred) for ex,
+      pred in zip(
+        self.get_two_feature_examples(),
+        [f1 * 2.0 + f2 * 3 + 0.5
+        for f1, f2 in self.get_two_feature_examples()])
+        ]
+
+  def get_torch_one_feature_model(self):
+    model = PytorchLinearRegression(input_dim=1, output_dim=1)
+    model.load_state_dict(
+        OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
+                     ('linear.bias', torch.Tensor([0.5]))]))
+    return model
+  
+  def get_tf_one_feature_model(self):
+    params = [np.array([[2.0]], dtype="float32"), np.array([0.5], dtype="float32")]
+    linear_layer = layers.Dense(units=1, weights=params)
+    linear_model = tf.keras.Sequential([linear_layer])
+    return linear_model
+
+  def get_sklearn_one_feature_model(self):
+    x = [[0],[1]]
+    y = [0.5, 2.5]
+    model = linear_model.LinearRegression()
+    model.fit(x, y)
+    return model
+
+  def get_torch_two_feature_model(self):
+    model = PytorchLinearRegression(input_dim=2, output_dim=1)
+    model.load_state_dict(
+      OrderedDict([('linear.weight', torch.Tensor([[2.0, 3]])),
+      ('linear.bias', torch.Tensor([0.5]))]))
+    return model
+
+  def get_tf_two_feature_model(self):
+    params = [np.array([[2.0], [3]]), np.array([0.5], dtype="float32")]
+    linear_layer = layers.Dense(units=1, weights=params)
+    linear_model = tf.keras.Sequential([linear_layer])
+    return linear_model
+
+  def get_sklearn_two_feature_model(self):
+    x = [[1,5],[3,2],[1,0]]
+    y = [17.5, 12.5, 2.5]
+    model = linear_model.LinearRegression()
+    model.fit(x, y)
+    return model
+
+
+def _compare_prediction_result(a, b):
+  example_equal = np.array_equal(a.example, b.example)
+  if isinstance(a.inference, dict):
+    return all(
+        x == y for x, y in zip(a.inference.values(),
+                               b.inference.values())) and example_equal
+  return a.inference == b.inference and example_equal
+
+def _to_numpy(tensor):
+      return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+class TestOnnxModelHandler(OnnxModelHandler):
+  def __init__(self,model_uri: str,*,inference_fn = default_numpy_inference_fn):
+    self._model_uri = model_uri
+    self._model_inference_fn = inference_fn
+
+class OnnxTestBase(unittest.TestCase):
+  def setUp(self):
+    self.tmpdir = tempfile.mkdtemp()
+    self.test_data_and_model = TestDataAndModel()
+
+  def tearDown(self):
+    shutil.rmtree(self.tmpdir)
+
+
+@pytest.mark.uses_pytorch

Review Comment:
   The `uses_pytorch` marker runs the PyTorch tests, which require this [requirements](https://github.com/apache/beam/blob/95e53916b6c9de6052ce8fc8409a53bad3d9d517/sdks/python/apache_beam/ml/inference/torch_tests_requirements.txt) file to be installed for dependencies. We should either add the onnx dependencies to that requirements file or give these tests their own requirements file.
   
   The same goes for tensorflow: https://github.com/apache/beam/blob/95e53916b6c9de6052ce8fc8409a53bad3d9d517/sdks/python/apache_beam/examples/inference/tfx_bsl/requirements.txt
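
   If a separate file is the way to go, a hypothetical `onnx_tests_requirements.txt` covering the imports in this test module might look like:

   ```
   onnx
   onnxruntime
   torch
   tensorflow
   tf2onnx
   scikit-learn
   skl2onnx
   ```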



##########
start-build-env-onnx.sh:
##########
@@ -0,0 +1,146 @@
+#!/usr/bin/env bash

Review Comment:
   For my own knowledge, why do we need this file?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org