You are viewing a plain-text version of this content. The canonical link to the original page was not preserved in this copy.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/07/11 17:06:29 UTC

[tvm] branch main updated: [Frontend][TFLite] PreLU alpha can be an expr (#11879)

This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c4dc41a0dd [Frontend][TFLite] PreLU alpha can be an expr (#11879)
c4dc41a0dd is described below

commit c4dc41a0dde6ae3118823736c325811f15994615
Author: Rafael Stahl <r....@tum.de>
AuthorDate: Mon Jul 11 19:06:20 2022 +0200

    [Frontend][TFLite] PreLU alpha can be an expr (#11879)
    
    * [Frontend][TFLite] PreLU alpha can be an expr
    
    * [Frontend][TFLite] handle both cases of PreLU alpha param
---
 python/tvm/relay/frontend/tflite.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index d7ec441e0e..c8352a9949 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -3010,11 +3010,14 @@ class OperatorConverter(object):
 
         input_tensor = input_tensors[0]
         alpha_tensor = input_tensors[1]
-        alpha_tensor_type = alpha_tensor.tensor.Type()
-        alpha_tensor_type_str = self.get_tensor_type_str(alpha_tensor_type)
-        alpha_expr = self.exp_tab.new_const(
-            self.get_tensor_value(alpha_tensor), dtype=alpha_tensor_type_str
-        )
+        if self.has_expr(alpha_tensor.tensor_idx):
+            alpha_expr = self.get_expr(alpha_tensor.tensor_idx)
+        else:
+            alpha_tensor_type = alpha_tensor.tensor.Type()
+            alpha_tensor_type_str = self.get_tensor_type_str(alpha_tensor_type)
+            alpha_expr = self.exp_tab.new_const(
+                self.get_tensor_value(alpha_tensor), dtype=alpha_tensor_type_str
+            )
         in_expr = self.get_expr(input_tensor.tensor_idx)
         data_shape = to_int_list(self.get_tensor_shape(input_tensor))