You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ha...@apache.org on 2021/01/25 17:28:03 UTC

[tvm] branch main updated: [FIX] Infer input shape in sparse_dense_padded's alter_op if one does not exist (#7308)

This is an automated email from the ASF dual-hosted git repository.

haichen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f3b852d  [FIX] Infer input shape in sparse_dense_padded's alter_op if one does not exist (#7308)
f3b852d is described below

commit f3b852df94398c76d9e91490c9c031be7b139584
Author: Tristan Konolige <tr...@gmail.com>
AuthorDate: Mon Jan 25 09:27:43 2021 -0800

    [FIX] Infer input shape in sparse_dense_padded's alter_op if one does not exist (#7308)
    
    * [FIX] Infer input shape in sparse_dense_padded's alter_op if one does not exist
    
    If there are multiple alter_ops in a model, type inference is not rerun
    after the first alteration, so subsequent alter_ops see expressions that
    have no shape information. In that case, we run type inference manually.
    
    * add todo
---
 python/tvm/topi/cuda/sparse.py          | 9 ++++++++-
 src/relay/transforms/alter_op_layout.cc | 1 +
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py
index 0b46cf0..f68b31e 100644
--- a/python/tvm/topi/cuda/sparse.py
+++ b/python/tvm/topi/cuda/sparse.py
@@ -292,7 +292,14 @@ def is_valid_for_sparse_dense_padded(data, weight_data):
     """
     # pylint:disable=invalid-name
     warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
-    m = get_const_tuple(data.checked_type.shape)[1]
+    # If there are multiple alter_ops in a model, the first alteration does not
+    # run type inference for the subsequent ones. In this case, we don't have
+    # the shape information, so we run the inferencer manually.
+    try:
+        m = get_const_tuple(data.checked_type.shape)[1]
+    except ValueError:
+        data_infered = relay.transform.InferType()(tvm.IRModule.from_expr(data))["main"]
+        m = get_const_tuple(data_infered.ret_type.shape)[1]
     if len(weight_data.shape) == 1:
         bs_m = 1
     else:
diff --git a/src/relay/transforms/alter_op_layout.cc b/src/relay/transforms/alter_op_layout.cc
index 924e61a..d7ffff6 100644
--- a/src/relay/transforms/alter_op_layout.cc
+++ b/src/relay/transforms/alter_op_layout.cc
@@ -110,6 +110,7 @@ class AlterTransformMemorizer : public TransformMemorizer {
  * 2. Do not support nested tuple arguments.
  */
 Expr AlterOpLayout(const Expr& expr) {
+  // TODO(@icemelon9): need to rerun type inference after applying an alter op.
   AlterTransformMemorizer alterMemorizer(make_object<AlterTransformMemorizerNode>());
   auto fcontext = [&](const Call& call) -> ObjectRef { return alterMemorizer; };