You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/12/06 12:05:29 UTC

[GitHub] [tvm] ekalda commented on a change in pull request #9627: [microNPU] Add support for SIGMOID

ekalda commented on a change in pull request #9627:
URL: https://github.com/apache/tvm/pull/9627#discussion_r762947023



##########
File path: python/tvm/relay/backend/contrib/ethosu/legalize.py
##########
@@ -194,6 +194,76 @@ def __call__(self, *args, **kwargs):
         pass
 
 
+def sigmoid_calc_func(x):
+    """Function to calculate the values for sigmoid"""
+    # These limits are inherited from TFLite
+    upper_limit = 8.0
+    lower_limit = -8.0
+
+    if x <= lower_limit:
+        y = 0.0
+    elif x >= upper_limit:
+        y = 1.0
+    else:
+        y = 1 / (1 + math.exp(-x))
+    return y
+
+
+class SigmoidRewriter(DFPatternCallback):

Review comment:
       Good idea! I made the change

##########
File path: python/tvm/relay/op/contrib/ethosu.py
##########
@@ -944,6 +944,35 @@ def tanh_pattern():
     return quant
 
 
+class SigmoidParams:
+    """
+    This class will parse a call to an ethos-u.sigmoid composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.sigmoid"
+
+    def __init__(self, func_body: Call):
+        self.ofm = TensorParams(func_body)
+        self.ifm = TensorParams(func_body.args[0].args[0].args[0])
+
+    def is_valid(self):
+        """
+        This function checks whether reshape has compatible attributes with the NPU

Review comment:
       Done

##########
File path: tests/python/contrib/test_ethosu/test_codegen.py
##########
@@ -1038,5 +1038,70 @@ def representative_dataset():
     infra.verify_source(compiled_models, accel_type)
 
 
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize("ifm_shape", [[1, 115, 32, 7], [1, 4, 5, 2]])
+def test_tflite_sigmoid(accel_type, ifm_shape):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        tf.config.run_functions_eagerly(True)
+
+        class Model(tf.Module):
+            @tf.function
+            def tanh_function(self, x):

Review comment:
       Done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org