Posted to commits@mxnet.apache.org by sk...@apache.org on 2019/07/28 04:31:53 UTC

[incubator-mxnet] branch master updated: [Opperf] Add optimizer update operator benchmarks to opperf (#15522)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 3d366a3  [Opperf] Add optimizer update operator benchmarks to opperf (#15522)
3d366a3 is described below

commit 3d366a3aa24c2aabc8b67b82d2b834844449d1f7
Author: Chaitanya Prakash Bapat <ch...@gmail.com>
AuthorDate: Sat Jul 27 21:31:31 2019 -0700

    [Opperf] Add optimizer update operator benchmarks to opperf (#15522)
    
    * optimizer for opperf benchmark
    
    * Trigger notification
    
    * missed function call
    
    * added params
    
    * minor typos
    
    * Trigger notification
    
    * resolve default params
    
    * temp remove multi op
    
    * take care of #15643
    
    * numbering typo
---
 benchmark/opperf/nd_operations/README.md           | 13 +----
 .../opperf/nd_operations/nn_optimizer_operators.py | 64 ++++++++++++++++++++++
 benchmark/opperf/opperf.py                         |  3 +
 benchmark/opperf/rules/default_params.py           | 54 +++++++++++++++++-
 benchmark/opperf/utils/op_registry_utils.py        | 21 +++++++
 5 files changed, 142 insertions(+), 13 deletions(-)
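
With the wiring shown in the opperf.py hunk further below, the optimizer
category is picked up automatically whenever the whole suite runs. A minimal
sketch, assuming run_all_mxnet_operator_benchmarks is importable from
benchmark/opperf/opperf.py when the repository root is on PYTHONPATH (the
file is also directly executable; note its 100755 mode in the diff):

    import mxnet as mx
    from benchmark.opperf.opperf import run_all_mxnet_operator_benchmarks

    # Optimizer update operators are now benchmarked alongside every
    # other category covered by the suite.
    all_results = run_all_mxnet_operator_benchmarks(ctx=mx.cpu(), dtype='float32')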

diff --git a/benchmark/opperf/nd_operations/README.md b/benchmark/opperf/nd_operations/README.md
index 321158c..9595866 100644
--- a/benchmark/opperf/nd_operations/README.md
+++ b/benchmark/opperf/nd_operations/README.md
@@ -28,9 +28,7 @@
 6. reshape
 7. one_hot
 8. linalg_potri
-9. mp_sgd_update
 10. multi_sgd_update
-11. signum_update
 12. Convolution_v1
 13. repeat
 14. Custom
@@ -38,7 +36,6 @@
 16. SwapAxis
 17. norm
 18. Softmax
-19. rmspropalex_update
 20. fill_element_0index
 21. cast
 22. UpSampling
@@ -52,7 +49,6 @@
 30. Activation
 31. LinearRegressionOutput
 32. Pooling_v1
-33. ftml_update
 34. Crop
 35. ElementWiseSum
 36. diag
@@ -60,24 +56,20 @@
 38. Pad
 39. linalg_gemm2
 40. crop
-41. rmsprop_update
 43. RNN
 45. SoftmaxOutput
 46. linalg_extractdiag
-47. sgd_mom_update
 48. SequenceLast
 51. SequenceReverse
 53. SVMOutput
 54. linalg_trsm
 55. where
 56. SoftmaxActivation
-57. signsgd_update
 58. slice
 59. linalg_gelqf
 60. softmin
 61. linalg_gemm
 62. BilinearSampler
-63. mp_sgd_mom_update
 64. choose_element_0index
 65. tile
 67. gather_nd
@@ -110,7 +102,6 @@
 98. linalg_syrk
 99. squeeze
 101. ROIPooling
-102. ftrl_update
 103. SliceChannel
 104. slice_like
 106. linalg_maketrian
@@ -127,6 +118,4 @@
 119. normal
 120. take
 121. MakeLoss
-122. sgd_update
-123. adam_update
-124. concat
\ No newline at end of file
+124. concat
diff --git a/benchmark/opperf/nd_operations/nn_optimizer_operators.py b/benchmark/opperf/nd_operations/nn_optimizer_operators.py
new file mode 100644
index 0000000..130ab85
--- /dev/null
+++ b/benchmark/opperf/nd_operations/nn_optimizer_operators.py
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from benchmark.opperf.utils.benchmark_utils import run_op_benchmarks
+from benchmark.opperf.utils.op_registry_utils import get_all_optimizer_operators
+
+"""Performance benchmark tests for MXNet Neural Network Optimizer Update Operators.
+
+1. Stochastic Gradient Descent (SGD)
+    1.1 mp_sgd_update
+    1.2 sgd_mom_update
+    1.3 signsgd_update
+    1.4 mp_sgd_mom_update
+    1.5 sgd_update
+2. signum_update
+3. rmspropalex_update
+4. ftml_update
+5. rmsprop_update
+6. ftrl_update
+7. adam_update
+"""
+
+
+def run_optimizer_operators_benchmarks(ctx=mx.cpu(), dtype='float32', warmup=25, runs=100):
+    """Runs benchmarks with the given context and precision (dtype) for all the neural network
+    optimizer update operators in MXNet.
+
+    Parameters
+    ----------
+    ctx: mx.ctx
+        Context to run benchmarks
+    dtype: str, default 'float32'
+        Precision to use for benchmarks
+    warmup: int, default 25
+        Number of times to run for warmup
+    runs: int, default 100
+        Number of runs to capture benchmark results
+
+    Returns
+    -------
+    Dictionary of results. Key -> Name of the operator, Value -> Benchmark results.
+
+    """
+    # Fetch all optimizer operators
+    mx_optimizer_ops = get_all_optimizer_operators()
+
+    # Run benchmarks
+    mx_optimizer_op_results = run_op_benchmarks(mx_optimizer_ops, dtype, ctx, warmup, runs)
+    return mx_optimizer_op_results
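
For reference, the new runner can also be invoked standalone. A minimal
sketch, assuming an MXNet source tree with the repository root on
PYTHONPATH (which the benchmark.opperf imports above require):

    import mxnet as mx
    from benchmark.opperf.nd_operations.nn_optimizer_operators import \
        run_optimizer_operators_benchmarks

    # Benchmark every optimizer update operator on CPU with float32 inputs,
    # using the defaults of 25 warmup iterations and 100 timed runs.
    results = run_optimizer_operators_benchmarks(ctx=mx.cpu(), dtype='float32')
    print(results)  # keyed by operator name, per the docstring above
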
diff --git a/benchmark/opperf/opperf.py b/benchmark/opperf/opperf.py
index 77b1667..b8055d7 100755
--- a/benchmark/opperf/opperf.py
+++ b/benchmark/opperf/opperf.py
@@ -39,6 +39,7 @@ from benchmark.opperf.nd_operations.nn_activation_operators import run_activatio
 from benchmark.opperf.nd_operations.nn_conv_operators import run_pooling_operators_benchmarks, \
     run_convolution_operators_benchmarks, run_transpose_convolution_operators_benchmarks
 from benchmark.opperf.nd_operations.nn_basic_operators import run_nn_basic_operators_benchmarks
+from benchmark.opperf.nd_operations.nn_optimizer_operators import run_optimizer_operators_benchmarks
 from benchmark.opperf.nd_operations.array_rearrange import run_rearrange_operators_benchmarks
 
 from benchmark.opperf.utils.common_utils import merge_map_list, save_to_file
@@ -96,6 +97,8 @@ def run_all_mxnet_operator_benchmarks(ctx=mx.cpu(), dtype='float32'):
     # Run all Convolution operations benchmarks with default input values
     mxnet_operator_benchmark_results.append(run_convolution_operators_benchmarks(ctx=ctx, dtype=dtype))
 
+    # Run all Optimizer operations benchmarks with default input values
+    mxnet_operator_benchmark_results.append(run_optimizer_operators_benchmarks(ctx=ctx, dtype=dtype))
     # Run all Transpose Convolution operations benchmarks with default input values
     mxnet_operator_benchmark_results.append(run_transpose_convolution_operators_benchmarks(ctx=ctx, dtype=dtype))
 
diff --git a/benchmark/opperf/rules/default_params.py b/benchmark/opperf/rules/default_params.py
index 322fde2..c864c7d 100644
--- a/benchmark/opperf/rules/default_params.py
+++ b/benchmark/opperf/rules/default_params.py
@@ -63,6 +63,31 @@ DEFAULT_AXIS_SHAPE = [(), 0, (0, 1)]
 # NOTE: Data used is DEFAULT_DATA
 DEFAULT_AXIS = [0]
 
+# For optimizer operators
+DEFAULT_WEIGHT = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_GRAD = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_MOM = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_MEAN = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_VAR = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_N = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_D = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_V = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_Z = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_G = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_DELTA = [(1024, 1024), (10000, 1), (10000, 100)]
+DEFAULT_LRS = [(0.1, 0.1)]
+DEFAULT_LR = [0.1, 0.5, 0.9]
+DEFAULT_GAMMA_1 = [0.1, 0.5, 0.9]
+DEFAULT_GAMMA_2 = [0.1, 0.5, 0.9]
+DEFAULT_EPSILON = [1e-08]
+DEFAULT_BETA_1 = [0.1, 0.5, 0.9]
+DEFAULT_BETA_2 = [0.1, 0.5, 0.9]
+DEFAULT_T = [1, 5]
+DEFAULT_RESCALE_GRAD = [0.4, 0.77]
+DEFAULT_CLIP_GRADIENT = [-1.0, 0.8]
+DEFAULT_CLIP_WEIGHTS = [-1.0, 0.8]
+DEFAULT_LAZY_UPDATE = [0, 1]
+
 # For rearrange operators
 # NOTE: Data needs to be a 4D tensor for  operators like space_to_depth and depth_to_space
 # Hence below we append 4d to mark the difference.
@@ -72,6 +97,7 @@ DEFAULT_DIM_1 = [0, 1, 2, 3]
 DEFAULT_DIM_2 = [1, 2, 3, 0]
 DEFAULT_BLOCK_SIZE = [2, 5]
 
+
 # Default Inputs. MXNet Op Param Name to Default Input mapping
 DEFAULTS_INPUTS = {"data": DEFAULT_DATA,
                    "sample": DEFAULT_SAMPLE,
@@ -93,11 +119,36 @@ DEFAULTS_INPUTS = {"data": DEFAULT_DATA,
                    "p_nd": DEFAULT_P_ND,
                    "axis_shape": DEFAULT_AXIS_SHAPE,
                    "axis": DEFAULT_AXIS,
+                   "weight" : DEFAULT_WEIGHT,
+                   "weight32" : DEFAULT_WEIGHT,
+                   "grad" : DEFAULT_GRAD,
+                   "mean" : DEFAULT_MEAN,
+                   "var" : DEFAULT_VAR,
+                   "mom" : DEFAULT_MOM,
+                   "n" : DEFAULT_N,
+                   "d" : DEFAULT_D,
+                   "v" : DEFAULT_V,
+                   "z" : DEFAULT_Z,
+                   "g" : DEFAULT_G,
+                   "delta" : DEFAULT_DELTA,
+                   "lr" : DEFAULT_LR,
+                   "lrs" : DEFAULT_LRS,
+                   "wds" : DEFAULT_LRS,
+                   "gamma1" : DEFAULT_GAMMA_1,
+                   "gamma2" : DEFAULT_GAMMA_2,
+                   "epsilon" : DEFAULT_EPSILON,
+                   "beta1" : DEFAULT_BETA_1,
+                   "beta2" : DEFAULT_BETA_2,
+                   "t" : DEFAULT_T,
+                   "rescale_grad" : DEFAULT_RESCALE_GRAD,
+                   "clip_grad" : DEFAULT_CLIP_GRADIENT,
+                   "lazy_update" : DEFAULT_LAZY_UPDATE,
                    "data_4d": DEFAULT_DATA_4d,
                    "dim1": DEFAULT_DIM_1,
                    "dim2": DEFAULT_DIM_2,
                    "block_size": DEFAULT_BLOCK_SIZE}
 
+
 # These are names of MXNet operator parameters that is of type NDArray.
 # We maintain this list to automatically recognize these parameters are to be
 # given as NDArray and translate users inputs such as a shape tuple, Numpy Array or
@@ -105,4 +156,5 @@ DEFAULTS_INPUTS = {"data": DEFAULT_DATA,
 # can just say shape of the tensor, and we automatically create Tensors.
 PARAMS_OF_TYPE_NDARRAY = ["lhs", "rhs", "data", "base", "exp", "sample",
                           "mu", "sigma", "lam", "alpha", "beta", "gamma", "k", "p",
-                          "low", "high", "weight", "bias", "moving_mean", "moving_var"]
+                          "low", "high", "weight", "bias", "moving_mean", "moving_var",
+                          "weight", "weight32", "grad", "mean", "var", "mom", "n", "d", "v", "z", "g", "delta"]
diff --git a/benchmark/opperf/utils/op_registry_utils.py b/benchmark/opperf/utils/op_registry_utils.py
index f5e7506..860b83a 100644
--- a/benchmark/opperf/utils/op_registry_utils.py
+++ b/benchmark/opperf/utils/op_registry_utils.py
@@ -244,6 +244,27 @@ def get_all_reduction_operators():
     return reduction_mx_operators
 
 
+def get_all_optimizer_operators():
+    """Gets all Optimizer operators registered with MXNet.
+
+    Returns
+    -------
+    {"operator_name": {"has_backward", "nd_op_handle", "params"}}
+    """
+    optimizer_ops = ['mp_sgd_update', 'signum_update', 'rmspropalex_update', 'ftml_update', 'rmsprop_update',
+                     'sgd_mom_update', 'signsgd_update', 'mp_sgd_mom_update', 'ftrl_update', 'sgd_update',
+                     'adam_update']
+
+    # Get all mxnet operators
+    mx_operators = _get_all_mxnet_operators()
+
+    # Filter for Optimizer operators
+    optimizer_mx_operators = {}
+    for op_name, op_params in mx_operators.items():
+        if op_name in optimizer_ops and op_name not in unique_ops:
+            optimizer_mx_operators[op_name] = mx_operators[op_name]
+    return optimizer_mx_operators
+
 def get_all_sorting_searching_operators():
     """Gets all Sorting and Searching operators registered with MXNet.