Posted to commits@tvm.apache.org by ma...@apache.org on 2022/05/11 11:29:25 UTC

[tvm] branch main updated: [ETHOSN] Adding support for Leaky ReLU (#11261)

This is an automated email from the ASF dual-hosted git repository.

manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3be5622c7e [ETHOSN] Adding support for Leaky ReLU (#11261)
3be5622c7e is described below

commit 3be5622c7eb1d3f6fe76da23b02dcb59786abdcf
Author: Luke Hutton <lu...@arm.com>
AuthorDate: Wed May 11 12:29:16 2022 +0100

    [ETHOSN] Adding support for Leaky ReLU (#11261)
    
    * [ETHOSN] Adding support for Leaky ReLU
    
    Change-Id: Icad69b2ae6ed4b3f3949cf5673efe2571aa66f5f
    
    * add some missing error reporting
    
    Change-Id: I935054c4d19a939e122092fab3c6c77204d9ead8
---
 python/tvm/relay/op/contrib/ethosn.py              | 14 ++++
 src/relay/backend/contrib/ethosn/codegen.cc        | 44 ++++++++++-
 src/relay/backend/contrib/ethosn/codegen_ethosn.h  |  1 +
 src/relay/backend/contrib/ethosn/ethosn_api.cc     | 32 ++++++++
 src/relay/backend/contrib/ethosn/ethosn_api.h      |  7 ++
 .../python/contrib/test_ethosn/test_leaky_relu.py  | 86 ++++++++++++++++++++++
 6 files changed, 183 insertions(+), 1 deletion(-)
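
For reference, the composite being offloaded computes a quantized Leaky ReLU end to end. Below is a minimal NumPy sketch of the dequantize -> leaky_relu -> quantize semantics (illustrative only, not part of the patch; the function name and the round-to-nearest behaviour are assumptions):

    import numpy as np

    def quantized_leaky_relu_reference(q_in, alpha, input_sc, input_zp,
                                       output_sc, output_zp, dtype="int8"):
        # qnn.dequantize: map quantized integers back to floating point.
        x = input_sc * (q_in.astype("float32") - input_zp)
        # nn.leaky_relu: identity for positive values, alpha-scaled otherwise.
        y = np.where(x >= 0, x, alpha * x)
        # qnn.quantize: rescale into the output space, then round and clip.
        iinfo = np.iinfo(dtype)
        q = np.round(y / output_sc) + output_zp
        return np.clip(q, iinfo.min, iinfo.max).astype(dtype)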

diff --git a/python/tvm/relay/op/contrib/ethosn.py b/python/tvm/relay/op/contrib/ethosn.py
index 312bc874f1..a1a3e2dccc 100644
--- a/python/tvm/relay/op/contrib/ethosn.py
+++ b/python/tvm/relay/op/contrib/ethosn.py
@@ -131,6 +131,12 @@ def pattern_table():
         pattern = is_op("qnn.quantize")(pattern, is_constant(), is_constant())
         return pattern
 
+    def qnn_leaky_relu_pattern():
+        pattern = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
+        pattern = is_op("nn.leaky_relu")(pattern)
+        pattern = is_op("qnn.quantize")(pattern, is_constant(), is_constant())
+        return pattern
+
     def check_conv2d(extract):
         """Check if a conv2d is supported by Ethos-N."""
         if not ethosn_available():
@@ -173,6 +179,13 @@ def pattern_table():
 
         return support.tanh(extract)
 
+    def check_leaky_relu(extract):
+        """Check if Leaky ReLU is supported."""
+        if not ethosn_available():
+            return False
+
+        return support.leaky_relu(extract)
+
     return [
         ("ethos-n.qnn_conv2d", qnn_conv_pattern(), check_conv2d),
         ("ethos-n.qnn_avg_pool2d", qnn_avg_pool2d_pattern(), check_avg_pool2d),
@@ -180,6 +193,7 @@ def pattern_table():
         ("ethos-n.qnn_fc", qnn_fc_pattern(), check_fc),
         ("ethos-n.qnn_mean", qnn_mean_pattern(), check_mean),
         ("ethos-n.qnn_tanh", qnn_tanh_pattern(), check_tanh),
+        ("ethos-n.qnn_leaky_relu", qnn_leaky_relu_pattern(), check_leaky_relu),
     ]
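
A quick sketch of how the "ethos-n.qnn_leaky_relu" pattern registered above matches a quantized Leaky ReLU expression (illustrative; it only needs the Relay dataflow pattern matcher, and the shape and quantization constants are arbitrary):

    from tvm import relay
    from tvm.relay.dataflow_pattern import is_constant, is_op, wildcard

    # Same construction as qnn_leaky_relu_pattern() in the diff above.
    pattern = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
    pattern = is_op("nn.leaky_relu")(pattern)
    pattern = is_op("qnn.quantize")(pattern, is_constant(), is_constant())

    # A small quantized Leaky ReLU expression for the pattern to match.
    x = relay.var("x", shape=(1, 8, 8, 2), dtype="int8")
    dequantized = relay.qnn.op.dequantize(x, relay.const(0.0078125), relay.const(0))
    activated = relay.nn.leaky_relu(dequantized, alpha=0.2)
    quantized = relay.qnn.op.quantize(
        activated, relay.const(0.0078125), relay.const(0), out_dtype="int8")
    assert pattern.match(quantized)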
 
 
diff --git a/src/relay/backend/contrib/ethosn/codegen.cc b/src/relay/backend/contrib/ethosn/codegen.cc
index 674793e1bd..d9f7b84b2f 100644
--- a/src/relay/backend/contrib/ethosn/codegen.cc
+++ b/src/relay/backend/contrib/ethosn/codegen.cc
@@ -120,6 +120,10 @@ void InferTensorsVisitor::InferCall(const CallNode* cn) {
     TanhParams params;
     err += EthosnAPI::Tanh(cn->op.as<FunctionNode>()->body, &params);
     tensor_table_[cn->args[0]] = {params.input_info};
+  } else if (IsEthosnFunc(call, "ethos-n.qnn_leaky_relu")) {
+    LeakyReLUParams params;
+    err += EthosnAPI::LeakyReLU(cn->op.as<FunctionNode>()->body, &params);
+    tensor_table_[cn->args[0]] = {params.input_info};
   } else if (IsEthosnOp(call, "qnn.concatenate")) {
     ConcatenateParams params;
     err = EthosnAPI::Concatenate(call, &params);
@@ -290,6 +294,9 @@ sl::TensorsAndId ConstructNetworkVisitor::HandleCall(const CallNode* cn) {
   } else if (IsEthosnFunc(call, "ethos-n.qnn_tanh")) {
     if ((err = MakeTanhLayer(call, &tensor))) ReportFatalError(call, err);
     return MakeOps(tensor);
+  } else if (IsEthosnFunc(call, "ethos-n.qnn_leaky_relu")) {
+    if ((err = MakeLeakyReLULayer(call, &tensor))) ReportFatalError(call, err);
+    return MakeOps(tensor);
   } else if (IsEthosnOp(call, "qnn.concatenate")) {
     if ((err = MakeConcatenateLayer(call, &tensor))) ReportFatalError(call, err);
     return MakeOps(tensor);
@@ -492,6 +499,24 @@ EthosnError ConstructNetworkVisitor::MakeTanhLayer(const Call& call,
   return EthosnError();
 }
 
+EthosnError ConstructNetworkVisitor::MakeLeakyReLULayer(const Call& call,
+                                                        sl::TensorAndId<sl::Operand>* out) {
+  LeakyReLUParams params;
+  params.input_info = GetTensorInfo(tensor_table_, call);
+  if (auto err = EthosnAPI::LeakyReLU(call->op.as<FunctionNode>()->body, &params)) {
+    return err;
+  }
+
+  auto input = operand_table_[call->args[0]][0];
+
+  try {
+    *out = AddLeakyRelu(network_, *input, params.leaky_relu_info);
+  } catch (const sl::NotSupportedException& e) {
+    return EthosnError(e.what());
+  }
+  return EthosnError();
+}
+
 EthosnError ConstructNetworkVisitor::MakeConcatenateLayer(const Call& call,
                                                           sl::TensorAndId<sl::Operand>* out) {
   ConcatenateParams params;
@@ -793,7 +818,24 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.tanh")
       TanhParams params;
       auto err = EthosnAPI::Tanh(call, &params);
       err += EthosnCompiler::SupportedSetup();
-      *rv = !err && EthosnCompiler::GetSupported()->IsTanhSupported(params.input_info);
+      char reason[kReasonMaxLength];
+      reason[0] = '\0';
+      *rv = !err && EthosnCompiler::GetSupported()->IsTanhSupported(params.input_info, nullptr,
+                                                                    reason, sizeof(reason));
+      err += EthosnError(reason);
+    });
+
+TVM_REGISTER_GLOBAL("relay.ethos-n.support.leaky_relu")
+    .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
+      Call call = args[0];
+      LeakyReLUParams params;
+      auto err = EthosnAPI::LeakyReLU(call, &params);
+      err += EthosnCompiler::SupportedSetup();
+      char reason[kReasonMaxLength];
+      reason[0] = '\0';
+      *rv = !err && EthosnCompiler::GetSupported()->IsLeakyReluSupported(
+                        params.leaky_relu_info, params.input_info, nullptr, reason, sizeof(reason));
+      err += EthosnError(reason);
     });
 
 TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate")
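
The support-check global registered above can be looked up from Python through the packed-function registry; a speculative sketch (it resolves to a callable only in an Ethos-N-enabled build of TVM):

    import tvm

    # allow_missing=True returns None instead of raising when TVM was
    # built without the Ethos-N backend.
    leaky_relu_supported = tvm.get_global_func(
        "relay.ethos-n.support.leaky_relu", allow_missing=True)

    if leaky_relu_supported is None:
        print("TVM was built without the Ethos-N backend")
    # Otherwise the callable takes the matched composite expression (the
    # qnn.quantize call at the root of the pattern) and returns a bool.
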
diff --git a/src/relay/backend/contrib/ethosn/codegen_ethosn.h b/src/relay/backend/contrib/ethosn/codegen_ethosn.h
index b3b93ffb8b..cca96c044c 100644
--- a/src/relay/backend/contrib/ethosn/codegen_ethosn.h
+++ b/src/relay/backend/contrib/ethosn/codegen_ethosn.h
@@ -211,6 +211,7 @@ class ConstructNetworkVisitor : public MixedModeVisitor, private ErrorReportingP
   EthosnError MakeSplitLayer(const Call& call, sl::TensorsAndId* outs);
   EthosnError MakeDepthToSpaceLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
   EthosnError MakeReluLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
+  EthosnError MakeLeakyReLULayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
 
   /*! \brief A look-up table from Expr to layers. */
   std::map<Expr, std::vector<std::shared_ptr<sl::Operand>>> operand_table_;
diff --git a/src/relay/backend/contrib/ethosn/ethosn_api.cc b/src/relay/backend/contrib/ethosn/ethosn_api.cc
index 7a9cb37847..bf2f248b3f 100644
--- a/src/relay/backend/contrib/ethosn/ethosn_api.cc
+++ b/src/relay/backend/contrib/ethosn/ethosn_api.cc
@@ -445,6 +445,38 @@ EthosnError EthosnAPI::Tanh(const Expr& expr, TanhParams* params) {
   return err;
 }
 
+EthosnError EthosnAPI::LeakyReLU(const Expr& expr, LeakyReLUParams* params) {
+  Call quantize = Downcast<Call>(expr);
+  Call leaky_relu = Downcast<Call>(quantize->args[0]);
+  Call dequantize = Downcast<Call>(leaky_relu->args[0]);
+
+  const auto* input_dtype = quantize->checked_type().as<TensorTypeNode>();
+  sl::TensorShape input_tensor_shape = {1, 1, 1, 1};
+  sl::DataType input_tensor_dtype;
+  EthosnError err = Tvm2Npu(input_dtype->shape, &input_tensor_shape);
+  err += Tvm2Npu(input_dtype->dtype, &input_tensor_dtype);
+  float input_sc;
+  int input_zp;
+  err += AsConstant(dequantize->args[2], &input_zp);
+  err += AsConstant(dequantize->args[1], &input_sc);
+  float output_sc;
+  int output_zp;
+  err += AsConstant(quantize->args[2], &output_zp);
+  err += AsConstant(quantize->args[1], &output_sc);
+
+  const auto* attrs = leaky_relu->attrs.as<LeakyReluAttrs>();
+  double alpha = attrs->alpha;
+  if (alpha >= 1.0f || alpha <= 0.0f) {
+    err += EthosnError(
+        ErrStrm() << "leaky relu alpha must be less than 1 and greater than 0, but was " << alpha);
+    return err;
+  }
+  params->leaky_relu_info = sl::LeakyReluInfo(alpha, sl::QuantizationInfo(output_zp, output_sc));
+  params->input_info = sl::TensorInfo(input_tensor_shape, input_tensor_dtype, sl::DataFormat::NHWC,
+                                      sl::QuantizationInfo(input_zp, input_sc));
+  return err;
+}
+
 EthosnError EthosnAPI::Concatenate(const Expr& expr, ConcatenateParams* params) {
   Call call = Downcast<Call>(expr);
   const auto& attrs = call->attrs.as<ConcatenateAttrs>();
diff --git a/src/relay/backend/contrib/ethosn/ethosn_api.h b/src/relay/backend/contrib/ethosn/ethosn_api.h
index 2d49fb2355..6ab256231f 100644
--- a/src/relay/backend/contrib/ethosn/ethosn_api.h
+++ b/src/relay/backend/contrib/ethosn/ethosn_api.h
@@ -100,6 +100,11 @@ struct TanhParams {
   sl::TensorInfo input_info;
 };
 
+struct LeakyReLUParams {
+  sl::LeakyReluInfo leaky_relu_info;
+  sl::TensorInfo input_info;
+};
+
 struct ConcatenateParams {
   sl::QuantizationInfo qInfo;
   sl::ConcatenationInfo concat_info = sl::ConcatenationInfo(1, qInfo);
@@ -204,6 +209,8 @@ class EthosnAPI {
   static EthosnError Mean(const Expr& expr, MeanParams* params);
   /*! \brief Extract the Support Library tanh params from a Relay ethos-n tanh func */
   static EthosnError Tanh(const Expr& expr, TanhParams* params);
+  /*! \brief Extract the Support Library leaky relu params from a Relay ethos-n leaky relu call */
+  static EthosnError LeakyReLU(const Expr& expr, LeakyReLUParams* params);
   /*! \brief Extract the Support Library concatenate params from a Relay qnn.concatenate call */
   static EthosnError Concatenate(const Expr& expr, ConcatenateParams* params);
   /*! \brief Extract the Support Library split params from a Relay split call */
diff --git a/tests/python/contrib/test_ethosn/test_leaky_relu.py b/tests/python/contrib/test_ethosn/test_leaky_relu.py
new file mode 100644
index 0000000000..cdd06f5e73
--- /dev/null
+++ b/tests/python/contrib/test_ethosn/test_leaky_relu.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Integration tests for Leaky ReLU"""
+
+import pytest
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm.testing import requires_ethosn
+
+from . import infrastructure as tei
+
+
+def _get_model(shape, input_zp, input_sc, output_zp, output_sc, dtype, alpha):
+    x = relay.var("x", shape=shape, dtype=dtype)
+    x = relay.qnn.op.dequantize(
+        x,
+        input_scale=relay.const(input_sc, "float32"),
+        input_zero_point=relay.const(input_zp, "int32"),
+    )
+    x = relay.nn.leaky_relu(x, alpha=alpha)
+    return relay.qnn.op.quantize(
+        x,
+        output_scale=relay.const(output_sc, "float32"),
+        output_zero_point=relay.const(output_zp, "int32"),
+        out_dtype=dtype,
+    )
+
+
+@requires_ethosn
+@pytest.mark.parametrize("dtype", ["uint8", "int8"])
+@pytest.mark.parametrize("shape", [(1, 52, 52, 3), (1, 3, 8, 2)])
+@pytest.mark.parametrize("alpha", [0.001, 0.5678])
+def test_leaky_relu(dtype, shape, alpha):
+    """Compare Leaky ReLU output with TVM."""
+    np.random.seed(0)
+
+    iinfo = np.iinfo(dtype)
+    zp_min = iinfo.min
+    zp_max = iinfo.max
+    input_zp = zp_min + 120
+    input_sc = 0.0068132
+    output_zp = zp_min + 128
+    output_sc = 0.0078125
+
+    inputs = {"x": tvm.nd.array(np.random.randint(zp_min, high=zp_max, size=shape, dtype=dtype))}
+    outputs = []
+    for npu in [False, True]:
+        model = _get_model(shape, input_zp, input_sc, output_zp, output_sc, dtype, alpha)
+        mod = tei.make_module(model, [])
+        outputs.append(tei.build_and_run(mod, inputs, 1, {}, npu=npu))
+
+    tei.verify(outputs, dtype, 1)
+
+
+@requires_ethosn
+@pytest.mark.parametrize("dtype", ["int8"])
+@pytest.mark.parametrize("shape", [(1, 14, 14, 2)])
+@pytest.mark.parametrize("alpha", [-1.34, 2.32, 1, 0])
+def test_leaky_relu_unsupported_alpha(dtype, shape, alpha):
+    """Test unsupported values of alpha (<= 0, >= 1) in Leaky ReLU."""
+    iinfo = np.iinfo(dtype)
+    zp_min = iinfo.min
+
+    err_msg = f"leaky relu alpha must be less than 1 and greater than 0, but was {alpha}"
+
+    model = _get_model(shape, zp_min + 120, 0.0068132, zp_min + 128, 0.0078125, dtype, alpha)
+    model = tei.make_ethosn_composite(model, "ethos-n.qnn_leaky_relu")
+    mod = tei.make_ethosn_partition(model)
+    tei.test_error(mod, {}, err_msg)
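
The new tests can be run in the usual way (assuming a TVM build with the Ethos-N backend and the contrib test infrastructure available):

    pytest tests/python/contrib/test_ethosn/test_leaky_relu.py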