You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ha...@apache.org on 2021/01/25 17:28:03 UTC
[tvm] branch main updated: [FIX] Infer input shape in
sparse_dense_padded's alter_op if one does not exist (#7308)
This is an automated email from the ASF dual-hosted git repository.
haichen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f3b852d [FIX] Infer input shape in sparse_dense_padded's alter_op if one does not exist (#7308)
f3b852d is described below
commit f3b852df94398c76d9e91490c9c031be7b139584
Author: Tristan Konolige <tr...@gmail.com>
AuthorDate: Mon Jan 25 09:27:43 2021 -0800
[FIX] Infer input shape in sparse_dense_padded's alter_op if one does not exist (#7308)
* [FIX] Infer input shape in sparse_dense_padded's alter_op if one does not exist
If there are multiple alter_ops in a model, the first alteration does
not rerun type inference before the subsequent ones are applied. In that
case the input's shape information is missing, so we run type inference
manually.
* Add a TODO in AlterOpLayout to rerun type inference after applying an alter op
---
python/tvm/topi/cuda/sparse.py | 9 ++++++++-
src/relay/transforms/alter_op_layout.cc | 1 +
2 files changed, 9 insertions(+), 1 deletion(-)
diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py
index 0b46cf0..f68b31e 100644
--- a/python/tvm/topi/cuda/sparse.py
+++ b/python/tvm/topi/cuda/sparse.py
@@ -292,7 +292,14 @@ def is_valid_for_sparse_dense_padded(data, weight_data):
"""
# pylint:disable=invalid-name
warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
- m = get_const_tuple(data.checked_type.shape)[1]
+ # If there are multiple alter_ops in a model, the first alteration does not
+ # run type inference for the subsequent ones. In this case, we don't have
+ # the shape information, so we run the inferencer manually.
+ try:
+ m = get_const_tuple(data.checked_type.shape)[1]
+ except ValueError:
+ data_infered = relay.transform.InferType()(tvm.IRModule.from_expr(data))["main"]
+ m = get_const_tuple(data_infered.ret_type.shape)[1]
if len(weight_data.shape) == 1:
bs_m = 1
else:
diff --git a/src/relay/transforms/alter_op_layout.cc b/src/relay/transforms/alter_op_layout.cc
index 924e61a..d7ffff6 100644
--- a/src/relay/transforms/alter_op_layout.cc
+++ b/src/relay/transforms/alter_op_layout.cc
@@ -110,6 +110,7 @@ class AlterTransformMemorizer : public TransformMemorizer {
* 2. Do not support nested tuple arguments.
*/
Expr AlterOpLayout(const Expr& expr) {
+ // TODO(@icemelon9): need to rerun type inference after applying an alter op.
AlterTransformMemorizer alterMemorizer(make_object<AlterTransformMemorizerNode>());
auto fcontext = [&](const Call& call) -> ObjectRef { return alterMemorizer; };