You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2019/07/17 19:15:14 UTC

[incubator-mxnet] branch master updated: Add transpose_conv, sorting and searching operator benchmarks to Opperf (#15475)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 57d097b  Add transpose_conv, sorting and searching operator benchmarks to Opperf (#15475)
57d097b is described below

commit 57d097b18e0c117e2f0be8287bb94858266a9798
Author: Sandeep Krishnamurthy <sa...@gmail.com>
AuthorDate: Wed Jul 17 12:14:29 2019 -0700

    Add transpose_conv, sorting and searching operator benchmarks to Opperf (#15475)
    
    * Add Transpose Convolution op benchmarks
    
    * Add sorting and searching ops for opperf
    
    * Remove redundant logic, make it uniform
    
    * Dummy commit for CI retrigger
    
    * Address code review comments
---
 benchmark/opperf/nd_operations/README.md           |  6 ---
 .../opperf/nd_operations/nn_conv_operators.py      | 49 ++++++++++++++++++-
 .../nd_operations/sorting_searching_operators.py   | 56 ++++++++++++++++++++++
 benchmark/opperf/opperf.py                         | 12 +++--
 benchmark/opperf/rules/default_params.py           |  7 ++-
 benchmark/opperf/utils/op_registry_utils.py        | 26 +++++++++-
 6 files changed, 143 insertions(+), 13 deletions(-)

diff --git a/benchmark/opperf/nd_operations/README.md b/benchmark/opperf/nd_operations/README.md
index 7aa220c..b98a0d3 100644
--- a/benchmark/opperf/nd_operations/README.md
+++ b/benchmark/opperf/nd_operations/README.md
@@ -62,12 +62,10 @@
 40. crop
 41. rmsprop_update
 43. RNN
-44. argmin
 45. SoftmaxOutput
 46. linalg_extractdiag
 47. sgd_mom_update
 48. SequenceLast
-49. Deconvolution
 50. flip
 51. SequenceReverse
 52. swapaxes
@@ -86,18 +84,15 @@
 65. tile
 66. space_to_depth
 67. gather_nd
-68. argsort
 69. SequenceMask
 70. reshape_like
 71. slice_axis
 72. stack
-73. topk
 74. khatri_rao
 75. multi_mp_sgd_update
 76. linalg_sumlogdiag
 77. broadcast_to
 78. IdentityAttachKLSparseReg
-79. sort
 80. SpatialTransformer
 81. Concat
 82. uniform
@@ -129,7 +124,6 @@
 110. split
 111. MAERegressionOutput
 112. Correlation
-113. argmax
 114. batch_take
 115. L2Normalization
 116. broadcast_axis
diff --git a/benchmark/opperf/nd_operations/nn_conv_operators.py b/benchmark/opperf/nd_operations/nn_conv_operators.py
index e42205f..9730589 100644
--- a/benchmark/opperf/nd_operations/nn_conv_operators.py
+++ b/benchmark/opperf/nd_operations/nn_conv_operators.py
@@ -102,7 +102,7 @@ def run_convolution_operators_benchmarks(ctx=mx.cpu(), dtype='float32', warmup=2
                                                      dtype=dtype,
                                                      ctx=ctx,
                                                      inputs=[{"data": conv_data,
-                                                              "weight": (64, 3, 3,),
+                                                              "weight": (64, 3, 3),
                                                               "bias": (64,),
                                                               "kernel": (3,),
                                                               "stride": (1,),
@@ -135,3 +135,50 @@ def run_convolution_operators_benchmarks(ctx=mx.cpu(), dtype='float32', warmup=2
     # Prepare combined results
     mx_conv_op_results = merge_map_list(conv1d_benchmark_res + conv2d_benchmark_res)
     return mx_conv_op_results
+
+
+def run_transpose_convolution_operators_benchmarks(ctx=mx.cpu(), dtype='float32', warmup=10, runs=50):
+    # Conv1DTranspose Benchmarks
+    conv1d_transpose_benchmark_res = []
+    for conv_data in [(32, 3, 256), (32, 3, 64)]:
+        conv1d_transpose_benchmark_res += run_performance_test([getattr(MX_OP_MODULE, "Deconvolution")],
+                                                               run_backward=True,
+                                                               dtype=dtype,
+                                                               ctx=ctx,
+                                                               inputs=[{"data": conv_data,
+                                                                        "weight": (3, 64, 3),
+                                                                        "bias": (64,),
+                                                                        "kernel": (3,),
+                                                                        "stride": (1,),
+                                                                        "dilate": (1,),
+                                                                        "pad": (0,),
+                                                                        "adj": (0,),
+                                                                        "num_filter": 64,
+                                                                        "no_bias": False,
+                                                                        "layout": 'NCW'}
+                                                                       ],
+                                                               warmup=warmup,
+                                                               runs=runs)
+    # Conv2DTranspose Benchmarks
+    conv2d_transpose_benchmark_res = []
+    for conv_data in [(32, 3, 256, 256), (32, 3, 64, 64)]:
+        conv2d_transpose_benchmark_res += run_performance_test([getattr(MX_OP_MODULE, "Deconvolution")],
+                                                               run_backward=True,
+                                                               dtype=dtype,
+                                                               ctx=ctx,
+                                                               inputs=[{"data": conv_data,
+                                                                        "weight": (3, 64, 3, 3),
+                                                                        "bias": (64,),
+                                                                        "kernel": (3, 3),
+                                                                        "stride": (1, 1),
+                                                                        "dilate": (1, 1),
+                                                                        "pad": (0, 0),
+                                                                        "num_filter": 64,
+                                                                        "no_bias": False,
+                                                                        "layout": 'NCHW'}
+                                                                       ],
+                                                               warmup=warmup,
+                                                               runs=runs)
+    # Prepare combined results
+    mx_transpose_conv_op_results = merge_map_list(conv1d_transpose_benchmark_res + conv2d_transpose_benchmark_res)
+    return mx_transpose_conv_op_results
diff --git a/benchmark/opperf/nd_operations/sorting_searching_operators.py b/benchmark/opperf/nd_operations/sorting_searching_operators.py
new file mode 100644
index 0000000..ab98b3f
--- /dev/null
+++ b/benchmark/opperf/nd_operations/sorting_searching_operators.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from benchmark.opperf.utils.benchmark_utils import run_op_benchmarks
+from benchmark.opperf.utils.op_registry_utils import get_all_sorting_searching_operators
+
+
+""" Performance benchmark tests for MXNet NDArray Sorting and Searching Operations
+1. sort
+2. argsort
+3. topk
+4. argmax
+5. argmin
+"""
+
+
+def run_sorting_searching_operators_benchmarks(ctx=mx.cpu(), dtype='float32', warmup=25, runs=100):
+    """Runs benchmarks with the given context and precision (dtype) for all the sorting and searching
+    operators in MXNet.
+
+    Parameters
+    ----------
+    ctx: mx.ctx
+        Context to run benchmarks
+    dtype: str, default 'float32'
+        Precision to use for benchmarks
+    warmup: int, default 25
+        Number of times to run for warmup
+    runs: int, default 100
+        Number of runs to capture benchmark results
+
+    Returns
+    -------
+    Dictionary of results. Key -> Name of the operator, Value -> Benchmark results.
+
+    """
+    # Fetch all Sorting and Searching Operators
+    mx_sort_search_ops = get_all_sorting_searching_operators()
+    # Run benchmarks
+    mx_sort_search_op_results = run_op_benchmarks(mx_sort_search_ops, dtype, ctx, warmup, runs)
+    return mx_sort_search_op_results
diff --git a/benchmark/opperf/opperf.py b/benchmark/opperf/opperf.py
index a73db4f..b2258af 100755
--- a/benchmark/opperf/opperf.py
+++ b/benchmark/opperf/opperf.py
@@ -34,13 +34,14 @@ from benchmark.opperf.nd_operations.binary_operators import run_mx_binary_broadc
 from benchmark.opperf.nd_operations.gemm_operators import run_gemm_operators_benchmarks
 from benchmark.opperf.nd_operations.random_sampling_operators import run_mx_random_sampling_operators_benchmarks
 from benchmark.opperf.nd_operations.reduction_operators import run_mx_reduction_operators_benchmarks
+from benchmark.opperf.nd_operations.sorting_searching_operators import run_sorting_searching_operators_benchmarks
 from benchmark.opperf.nd_operations.nn_activation_operators import run_activation_operators_benchmarks
 from benchmark.opperf.nd_operations.nn_conv_operators import run_pooling_operators_benchmarks, \
-    run_convolution_operators_benchmarks
+    run_convolution_operators_benchmarks, run_transpose_convolution_operators_benchmarks
 from benchmark.opperf.nd_operations.nn_basic_operators import run_nn_basic_operators_benchmarks
 
 from benchmark.opperf.utils.common_utils import merge_map_list, save_to_file
-from benchmark.opperf.utils.op_registry_utils import get_operators_with_no_benchmark,\
+from benchmark.opperf.utils.op_registry_utils import get_operators_with_no_benchmark, \
     get_current_runtime_features
 
 
@@ -74,6 +75,9 @@ def run_all_mxnet_operator_benchmarks(ctx=mx.cpu(), dtype='float32'):
     # Run all Reduction operations benchmarks with default input values
     mxnet_operator_benchmark_results.append(run_mx_reduction_operators_benchmarks(ctx=ctx, dtype=dtype))
 
+    # Run all Sorting and Searching operations benchmarks with default input values
+    mxnet_operator_benchmark_results.append(run_sorting_searching_operators_benchmarks(ctx=ctx, dtype=dtype))
+
     # ************************ MXNET NN OPERATOR BENCHMARKS ****************************
 
     # Run all basic NN operations benchmarks with default input values
@@ -88,6 +92,9 @@ def run_all_mxnet_operator_benchmarks(ctx=mx.cpu(), dtype='float32'):
     # Run all Convolution operations benchmarks with default input values
     mxnet_operator_benchmark_results.append(run_convolution_operators_benchmarks(ctx=ctx, dtype=dtype))
 
+    # Run all Transpose Convolution operations benchmarks with default input values
+    mxnet_operator_benchmark_results.append(run_transpose_convolution_operators_benchmarks(ctx=ctx, dtype=dtype))
+
     # ****************************** PREPARE FINAL RESULTS ********************************
     final_benchmark_result_map = merge_map_list(mxnet_operator_benchmark_results)
     return final_benchmark_result_map
@@ -148,4 +155,3 @@ def main():
 
 if __name__ == '__main__':
     sys.exit(main())
-
diff --git a/benchmark/opperf/rules/default_params.py b/benchmark/opperf/rules/default_params.py
index df6cdae..2c8f3d4 100644
--- a/benchmark/opperf/rules/default_params.py
+++ b/benchmark/opperf/rules/default_params.py
@@ -56,7 +56,11 @@ DEFAULT_P_ND = [[0.4, 0.77]]
 
 # For reduction operators
 # NOTE: Data used is DEFAULT_DATA
-DEFAULT_AXIS = [(), 0, (0, 1)]
+DEFAULT_AXIS_SHAPE = [(), 0, (0, 1)]
+
+# For sorting and searching operators
+# NOTE: Data used is DEFAULT_DATA
+DEFAULT_AXIS = [0]
 
 # Default Inputs. MXNet Op Param Name to Default Input mapping
 DEFAULTS_INPUTS = {"data": DEFAULT_DATA,
@@ -76,6 +80,7 @@ DEFAULTS_INPUTS = {"data": DEFAULT_DATA,
                    "p": DEFAULT_P,
                    "k_nd": DEFAULT_K_ND,
                    "p_nd": DEFAULT_P_ND,
+                   "axis_shape": DEFAULT_AXIS_SHAPE,
                    "axis": DEFAULT_AXIS}
 
 # These are names of MXNet operator parameters that is of type NDArray.
diff --git a/benchmark/opperf/utils/op_registry_utils.py b/benchmark/opperf/utils/op_registry_utils.py
index 6509be3..88ea7a1 100644
--- a/benchmark/opperf/utils/op_registry_utils.py
+++ b/benchmark/opperf/utils/op_registry_utils.py
@@ -16,10 +16,8 @@
 # under the License.
 
 """Utilities to interact with MXNet operator registry."""
-import ctypes
 from operator import itemgetter
 from mxnet import runtime
-from mxnet.base import _LIB, check_call, py_str, OpHandle, c_str, mx_uint
 import mxnet as mx
 
 from benchmark.opperf.rules.default_params import DEFAULTS_INPUTS, MX_OP_MODULE
@@ -121,6 +119,10 @@ def prepare_op_inputs(arg_params):
             arg_values[arg_name] = DEFAULTS_INPUTS[arg_name]
         elif "float" in arg_type and arg_name + "_float" in DEFAULTS_INPUTS:
             arg_values[arg_name] = DEFAULTS_INPUTS[arg_name + "_float"]
+        elif "Shape" in arg_type and arg_name + "_shape" in DEFAULTS_INPUTS:
+            # Handles params such as 'axis' that are an int in some ops but a shape tuple in others.
+            # Ex: 'axis' in sum is a shape, whereas 'axis' in sort is an int.
+            arg_values[arg_name] = DEFAULTS_INPUTS[arg_name + "_shape"]
 
     # Number of different inputs we want to use to test
     # the operator
@@ -240,6 +242,26 @@ def get_all_reduction_operators():
     return reduction_mx_operators
 
 
+def get_all_sorting_searching_operators():
+    """Gets all Sorting and Searching operators registered with MXNet.
+
+    Returns
+    -------
+    {"operator_name": {"has_backward", "nd_op_handle", "params"}}
+    """
+    sort_search_ops = ['sort', 'argsort', 'argmax', 'argmin', 'topk']
+
+    # Get all mxnet operators
+    mx_operators = _get_all_mxnet_operators()
+
+    # Filter for Sort and search operators
+    sort_search_mx_operators = {}
+    for op_name, op_params in mx_operators.items():
+        if op_name in sort_search_ops and op_name not in unique_ops:
+            sort_search_mx_operators[op_name] = mx_operators[op_name]
+    return sort_search_mx_operators
+
+
 def get_operators_with_no_benchmark(operators_with_benchmark):
     """Gets all MXNet operators with not benchmark.