Posted to commits@mxnet.apache.org by pt...@apache.org on 2020/04/17 16:12:34 UTC

[incubator-mxnet] branch master updated: Add gelu fuse ops (#18082)

This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b01d1dc  Add gelu fuse ops (#18082)
b01d1dc is described below

commit b01d1dc197cb1de3f0102cb4a2558ef4e320768d
Author: MoisesHer <50...@users.noreply.github.com>
AuthorDate: Fri Apr 17 09:11:40 2020 -0700

    Add gelu fuse ops (#18082)
    
    * Add LeakyReLU:Gelu (fwd and bwd) to fused ops
    
    * Add test LeakyReLU:gelu
    
    * cpplint
    
    * fix lint
    
    * fix bug SQRT_2 using constant memory
    
    * add comments
---
 src/executor/pointwise_fusion_pass.cc | 14 ++++++++++++++
 src/operator/fusion/fused_op-inl.h    | 23 ++++++++++++++++++++++
 src/operator/fusion/fused_op.cu       | 36 +++++++++++++++++++++++++++++++++++
 tests/python/gpu/test_fusion.py       | 13 +++++++++++++
 4 files changed, 86 insertions(+)
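
For reference, the erf-based GELU that the new op::gelu / op::backward_gelu device
functions below implement (compatible with the mshadow_op.h versions, per the code
comments) is

    gelu(x)  = 0.5 * x * (1 + erf(x / sqrt(2)))
    gelu'(x) = 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x*x / 2) / sqrt(2 * pi)

so backward_gelu scales the incoming gradient by gelu'(x) evaluated at the forward input.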

diff --git a/src/executor/pointwise_fusion_pass.cc b/src/executor/pointwise_fusion_pass.cc
index 5db9706..3203f67 100644
--- a/src/executor/pointwise_fusion_pass.cc
+++ b/src/executor/pointwise_fusion_pass.cc
@@ -71,6 +71,20 @@ namespace {
                   op_name) !=
         variable_io_ops.end())
       return true;
+    if (op_name == "LeakyReLU") {
+        std::string act_type = n->attrs.dict.at("act_type");
+        if (LeakyReLU_ops.count(act_type))
+          return true;
+        else
+          return false;
+    }
+    if (op_name == "_backward_LeakyReLU") {
+        std::string act_type = n->attrs.dict.at("act_type");
+        if (LeakyReLU_bwd_ops.count(act_type))
+          return true;
+        else
+          return false;
+    }
     return false;
   }
 
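A note on the pointwise_fusion_pass.cc change above: LeakyReLU and _backward_LeakyReLU
nodes are marked fusable only when their act_type appears in LeakyReLU_ops /
LeakyReLU_bwd_ops, which currently contain only "gelu". A minimal Python sketch of the
consequence (illustrative only; symbol names are arbitrary):

    import mxnet as mx

    a = mx.sym.Variable('a')
    fused     = mx.sym.LeakyReLU(2 * a, act_type='gelu')   # eligible for pointwise fusion
    not_fused = mx.sym.LeakyReLU(2 * a, act_type='leaky')  # this LeakyReLU node is not fused
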
diff --git a/src/operator/fusion/fused_op-inl.h b/src/operator/fusion/fused_op-inl.h
index e45569f..0b10f82 100644
--- a/src/operator/fusion/fused_op-inl.h
+++ b/src/operator/fusion/fused_op-inl.h
@@ -224,6 +224,14 @@ const std::map<std::string, std::vector<std::vector<std::string>>> ops_desc = {
                                           {"(% * % / op::hypot(%, %))", "_0", "_2", "_1", "_2"}}}
 };
 
+// LeakyReLU ops: based on "act_type" attribute
+const std::map<std::string, std::vector<std::vector<std::string>>> LeakyReLU_ops = {
+  {"gelu"                              , {{"op::gelu(%)", "_0"}}},
+};
+const std::map<std::string, std::vector<std::vector<std::string>>> LeakyReLU_bwd_ops = {
+  {"gelu"                              , {{"op::backward_gelu(%, %)", "_1", "_0"}}},
+};
+
 const std::map<std::string, std::string> slice_ops = {
   {"slice_axis"   , ""},
   {"slice"   , ""},
@@ -543,6 +551,14 @@ __device__ inline DType relu(const DType val) {
   return val > 0 ? val : 0;
 }
 
+const float SQRT_2 = 1.4142135623730950488016887242096;
+// compatible with mshadow_op.h version
+template <typename DType>
+__device__ inline DType gelu(const DType val) {
+  return DType(0.5f * static_cast<float>(val) *
+               (1.0f + erf(static_cast<float>(val) / SQRT_2)));
+}
+
 template <typename DType>
 __device__ inline DType sigmoid(const DType val) {
   return 1.f/(1 + expf(-val));
@@ -987,6 +1003,13 @@ __device__ inline DTypeGrad backward_smooth_l1(const DType val, const DType2 sca
   }
 }
 
+// compatible with mshadow_op.h version
+template <typename DType, typename DTypeGrad>
+__device__ inline DTypeGrad backward_gelu(const DType val, const DTypeGrad grad) {
+  return grad * DType(0.5f * (1.0f + erf(static_cast<float>(val) / SQRT_2) +
+                static_cast<float>(val) * backward_erf(static_cast<float>(val) / SQRT_2, 1.0f) / SQRT_2));
+}
+
 }  // namespace op
 
 )code";
diff --git a/src/operator/fusion/fused_op.cu b/src/operator/fusion/fused_op.cu
index 0088724..3d7caab 100644
--- a/src/operator/fusion/fused_op.cu
+++ b/src/operator/fusion/fused_op.cu
@@ -460,6 +460,42 @@ std::string FusedOp::GenerateCode(const std::vector<OpReqType> &req,
           continue;
         }
 
+        // LeakyReLU: dispatch on the act_type attribute
+        if (op_name == "LeakyReLU") {
+            std::string act_type = node.source->attrs.dict.at("act_type");
+            if (fusion::LeakyReLU_ops.find(act_type) != fusion::LeakyReLU_ops.end()) {
+              const std::vector<std::vector<std::string>>& op_descs =
+                  fusion::LeakyReLU_ops.at(act_type);
+              CHECK_EQ(outputs[i], op_descs.size());
+              size_t count = 0;
+              for (const auto& op_desc : op_descs) {
+                var_name = "temp" + std::to_string(temp_name_counter++);
+                const std::string& fmt = ParseOpDescription(op_desc, variables, node);
+                code += "const auto " + var_name + " = " + fmt + ";\n";
+                variables[{i, count}] = var_name;
+                ++count;
+              }
+              continue;
+            }
+        }
+        if (op_name == "_backward_LeakyReLU") {
+            std::string act_type = node.source->attrs.dict.at("act_type");
+            const std::vector<std::vector<std::string>>& op_descs =
+                fusion::LeakyReLU_bwd_ops.at(act_type);
+            if (fusion::LeakyReLU_ops.find(act_type) != fusion::LeakyReLU_bwd_ops.end()) {
+              CHECK_EQ(outputs[i], op_descs.size());
+              size_t count = 0;
+              for (const auto& op_desc : op_descs) {
+                var_name = "temp" + std::to_string(temp_name_counter++);
+                const std::string& fmt = ParseOpDescription(op_desc, variables, node);
+                code += "const auto " + var_name + " = " + fmt + ";\n";
+                variables[{i, count}] = var_name;
+                ++count;
+              }
+              continue;
+            }
+        }
+
         LOG(FATAL) << "Unrecognized op " + op_name;
       }
     } else {
diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py
index 8e0f063..a6be6c7 100644
--- a/tests/python/gpu/test_fusion.py
+++ b/tests/python/gpu/test_fusion.py
@@ -230,11 +230,24 @@ def check_other_ops():
     arr2 = mx.random.uniform(shape=(2,2,2,3))
     check_fused_symbol(mx.sym.broadcast_like(a, b, lhs_axes=[0], rhs_axes=[0]), a=arr1, b=arr2)
 
+def check_leakyrelu_ops():
+    a = mx.sym.Variable('a')
+    b = mx.sym.Variable('b')
+    shape = rand_shape_2d()
+    arr1 = mx.random.uniform(shape=shape)
+    arr2 = mx.random.uniform(shape=shape)
+
+    # Testing gelu
+    print("Checking fusion of LeakyReLU:gelu")
+    check_fused_symbol(mx.sym.LeakyReLU(a+b, act_type='gelu'), a=arr1, b=arr2)
+
+
 @with_seed()
 def test_fusion():
     check_unary_ops()
     check_binary_ops()
     check_other_ops()
+    check_leakyrelu_ops()
 
 @with_seed()
 def test_fusion_compiler_cache():
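
For completeness, a rough sketch of how the new fused path can be exercised end to end
(loosely mirroring what check_fused_symbol does; a GPU build with pointwise fusion
enabled via MXNET_USE_FUSION is assumed, and the binding code here is illustrative):

    import mxnet as mx

    a = mx.sym.Variable('a')
    b = mx.sym.Variable('b')
    sym = mx.sym.LeakyReLU(a + b, act_type='gelu')   # add + gelu can fuse into one kernel

    ctx = mx.gpu(0)
    args = {'a': mx.nd.random.uniform(shape=(2, 3), ctx=ctx),
            'b': mx.nd.random.uniform(shape=(2, 3), ctx=ctx)}
    exe = sym.bind(ctx=ctx, args=args, grad_req='null')  # forward-only executor
    out = exe.forward()[0]
    print(out.asnumpy())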