You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zh...@apache.org on 2020/07/10 18:18:42 UTC

[incubator-tvm] branch master updated: [Relay][Dyn] Dynamic TopK Op (#6008)

This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 474d472  [Relay][Dyn] Dynamic TopK Op (#6008)
474d472 is described below

commit 474d47234f8a2378f9135fa3200ca7ce75459889
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Fri Jul 10 11:18:34 2020 -0700

    [Relay][Dyn] Dynamic TopK Op (#6008)
    
    * add dynamic topk op
    
    * add topk to dynamic_to_static pass
    
    * fix TF test
    
    * fix pylint
---
 python/tvm/relay/op/_algorithm.py                 | 35 ++---------
 python/tvm/relay/op/algorithm.py                  | 13 ++--
 python/tvm/relay/op/dyn/__init__.py               |  1 +
 python/tvm/relay/op/{ => dyn}/_algorithm.py       | 52 ++++------------
 python/tvm/relay/op/strategy/generic.py           |  3 +-
 src/relay/analysis/util.cc                        |  9 +--
 src/relay/op/algorithm/topk.cc                    | 24 +++----
 src/relay/op/{ => dyn}/algorithm/topk.cc          | 43 +++++++------
 src/relay/transforms/dynamic_to_static.cc         | 20 +++++-
 tests/python/relay/dyn/test_dynamic_op_level6.py  | 76 +++++++++++++++++++++++
 tests/python/relay/test_pass_dynamic_to_static.py | 53 ++++++++++++++++
 11 files changed, 211 insertions(+), 118 deletions(-)

diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py
index 5a20480..cded2e1 100644
--- a/python/tvm/relay/op/_algorithm.py
+++ b/python/tvm/relay/op/_algorithm.py
@@ -35,25 +35,6 @@ register_strategy("topk", strategy.topk_strategy)
 register_pattern("topk", OpPattern.OPAQUE)
 
 @script
-def _topk_shape_func_input_data(data, k, axis):
-    ndim = len(data.shape)
-    val_out = output_tensor((ndim,), "int64")
-    indices_out = output_tensor((ndim,), "int64")
-
-    for i in const_range(ndim):
-        if i != axis:
-            val_out[i] = int64(data.shape[i])
-            indices_out[i] = int64(data.shape[i])
-        else:
-            if k[0] < 1:
-                val_out[i] = int64(data.shape[i])
-                indices_out[i] = int64(data.shape[i])
-            else:
-                val_out[i] = int64(k[0])
-                indices_out[i] = int64(k[0])
-    return val_out, indices_out
-
-@script
 def _topk_shape_func_input_shape(data_shape, k, axis):
     ndim = data_shape.shape[0]
     val_out = output_tensor((ndim,), "int64")
@@ -72,22 +53,16 @@ def _topk_shape_func_input_shape(data_shape, k, axis):
                 indices_out[i] = int64(k)
     return val_out, indices_out
 
-@_reg.register_shape_func("topk", True)
+@_reg.register_shape_func("topk", False)
 def topk_shape_func(attrs, inputs, _):
     """
     Shape func for topk.
     """
     axis = attrs.axis
-    if attrs.k is not None:
-        if axis < 0:
-            axis += inputs[0].shape[0]
-        val_out, indices_out = \
-            _topk_shape_func_input_shape(inputs[0], attrs.k, convert(axis))
-    else:
-        if axis < 0:
-            axis += len(inputs[0].shape)
-        val_out, indices_out = \
-            _topk_shape_func_input_data(inputs[0], inputs[1], convert(axis))
+    if axis < 0:
+        axis += inputs[0].shape[0]
+    val_out, indices_out = \
+        _topk_shape_func_input_shape(inputs[0], attrs.k, convert(axis))
     ret_type = attrs.ret_type
     if ret_type == "both":
         ret = [val_out, indices_out]
diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py
index d31e89a..5aeb7e6 100644
--- a/python/tvm/relay/op/algorithm.py
+++ b/python/tvm/relay/op/algorithm.py
@@ -16,8 +16,10 @@
 # under the License.
 """Classic algorithm operation"""
 from __future__ import absolute_import as _abs
+import numpy as np
 from . import _make
-from ..expr import TupleWrapper, const
+from .dyn import _make as _dyn_make
+from ..expr import TupleWrapper, Expr, Constant
 
 def argsort(data, axis=-1, is_ascend=1, dtype="int32"):
     """Performs sorting along the given axis and returns an array of indicies
@@ -82,9 +84,12 @@ def topk(data, k=1, axis=-1, ret_type="both",
     out : relay.Expr or List[relay.Expr]
         The computed result.
     """
-    if isinstance(k, int):
-        k = const(k, "int64")
-    out = _make.topk(data, k, axis, ret_type, is_ascend, dtype)
+    if isinstance(k, Constant):
+        k = k.data.asnumpy().item()
+    if isinstance(k, Expr):
+        out = _dyn_make.topk(data, k, axis, ret_type, is_ascend, dtype)
+    else:
+        out = _make.topk(data, k, axis, ret_type, is_ascend, dtype)
     if ret_type == "both":
         return TupleWrapper(out, 2)
     return out
diff --git a/python/tvm/relay/op/dyn/__init__.py b/python/tvm/relay/op/dyn/__init__.py
index d659203..f4d47a6 100644
--- a/python/tvm/relay/op/dyn/__init__.py
+++ b/python/tvm/relay/op/dyn/__init__.py
@@ -17,4 +17,5 @@
 # pylint: disable=wildcard-import, redefined-builtin, invalid-name
 """The Relay namespace containing dynamic ops."""
 
+from . import _algorithm
 from . import _transform
diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/dyn/_algorithm.py
similarity index 58%
copy from python/tvm/relay/op/_algorithm.py
copy to python/tvm/relay/op/dyn/_algorithm.py
index 5a20480..b98b775 100644
--- a/python/tvm/relay/op/_algorithm.py
+++ b/python/tvm/relay/op/dyn/_algorithm.py
@@ -21,18 +21,14 @@ from __future__ import absolute_import
 from tvm.te.hybrid import script
 from tvm.runtime import convert
 
-from . import strategy
-from . import op as _reg
-from .op import OpPattern, register_pattern
-from .op import register_strategy
-
-# argsort
-register_strategy("argsort", strategy.argsort_strategy)
-register_pattern("argsort", OpPattern.OPAQUE)
+from .. import strategy
+from .. import op as _reg
+from ..op import OpPattern, register_pattern
+from ..op import register_strategy
 
 # topk
-register_strategy("topk", strategy.topk_strategy)
-register_pattern("topk", OpPattern.OPAQUE)
+register_strategy("dyn.topk", strategy.topk_strategy)
+register_pattern("dyn.topk", OpPattern.OPAQUE)
 
 @script
 def _topk_shape_func_input_data(data, k, axis):
@@ -53,41 +49,17 @@ def _topk_shape_func_input_data(data, k, axis):
                 indices_out[i] = int64(k[0])
     return val_out, indices_out
 
-@script
-def _topk_shape_func_input_shape(data_shape, k, axis):
-    ndim = data_shape.shape[0]
-    val_out = output_tensor((ndim,), "int64")
-    indices_out = output_tensor((ndim,), "int64")
-
-    for i in const_range(ndim):
-        if i != axis:
-            val_out[i] = int64(data_shape[i])
-            indices_out[i] = int64(data_shape[i])
-        else:
-            if k < 1:
-                val_out[i] = int64(data_shape[i])
-                indices_out[i] = int64(data_shape[i])
-            else:
-                val_out[i] = int64(k)
-                indices_out[i] = int64(k)
-    return val_out, indices_out
-
-@_reg.register_shape_func("topk", True)
+@_reg.register_shape_func("dyn.topk", True)
 def topk_shape_func(attrs, inputs, _):
     """
     Shape func for topk.
     """
     axis = attrs.axis
-    if attrs.k is not None:
-        if axis < 0:
-            axis += inputs[0].shape[0]
-        val_out, indices_out = \
-            _topk_shape_func_input_shape(inputs[0], attrs.k, convert(axis))
-    else:
-        if axis < 0:
-            axis += len(inputs[0].shape)
-        val_out, indices_out = \
-            _topk_shape_func_input_data(inputs[0], inputs[1], convert(axis))
+    if axis < 0:
+        axis += len(inputs[0].shape)
+    val_out, indices_out = \
+        _topk_shape_func_input_data(inputs[0], inputs[1], convert(axis))
+
     ret_type = attrs.ret_type
     if ret_type == "both":
         ret = [val_out, indices_out]
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 632445b..db0577c 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -656,9 +656,10 @@ def argsort_strategy(attrs, inputs, out_type, target):
 def wrap_compute_topk(topi_compute):
     """Wrap topk compute"""
     def _compute_topk(attrs, inputs, out_type):
-        k = inputs[1]
         if attrs.k is not None:
             k = attrs.k
+        else:
+            k = inputs[1]
         axis = get_const_int(attrs.axis)
         ret_type = attrs.ret_type
         is_ascend = bool(get_const_int(attrs.is_ascend))
diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc
index 10c226e..c8dbb49 100644
--- a/src/relay/analysis/util.cc
+++ b/src/relay/analysis/util.cc
@@ -448,14 +448,7 @@ bool IsDataDependant(const CallNode* call) {
     return false;
   }
 
-  if (op->name == "topk") {
-    if (const auto* attrs = call->attrs.as<TopKAttrs>()) {
-      if (attrs->k) {
-        // If k attribute exists, it isn't data dependant.
-        return false;
-      }
-    }
-  } else if (op->name == "strided_slice") {
+  if (op->name == "strided_slice") {
     if (const auto* attrs = call->attrs.as<StridedSliceAttrs>()) {
       if (attrs->begin && attrs->end && attrs->strides) {
         // not data dependant if begin, end and strides exist
diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc
index b02fe86..14308dd 100644
--- a/src/relay/op/algorithm/topk.cc
+++ b/src/relay/op/algorithm/topk.cc
@@ -27,7 +27,6 @@
 
 namespace tvm {
 namespace relay {
-using tir::make_const;
 
 TVM_REGISTER_NODE_TYPE(TopKAttrs);
 
@@ -35,7 +34,7 @@ bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
              const TypeReporter& reporter) {
   // `types` contains: [data, result]
   const TopKAttrs* param = attrs.as<TopKAttrs>();
-  CHECK_EQ(types.size(), 3);
+  CHECK_EQ(types.size(), 2);
   const auto* data = types[0].as<TensorTypeNode>();
   CHECK(data);
   int ndim = data->shape.size();
@@ -48,42 +47,38 @@ bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   for (int i = 0; i < ndim; ++i) {
     if (i != axis) {
       out_shape.push_back(data->shape[i]);
-    } else if (param->k) {
+    } else {
       const Integer& ck = param->k.value();
       if (ck->value < 1) {
         out_shape.push_back(data->shape[i]);
       } else {
         out_shape.push_back(ck);
       }
-    } else {
-      out_shape.push_back(Any());
     }
   }
   auto values_ty = TensorType(out_shape, data->dtype);
   auto indices_ty = TensorType(out_shape, param->dtype);
   if (param->ret_type == "both") {
-    reporter->Assign(types[2], TupleType({values_ty, indices_ty}));
+    reporter->Assign(types[1], TupleType({values_ty, indices_ty}));
   } else if (param->ret_type == "values") {
-    reporter->Assign(types[2], values_ty);
+    reporter->Assign(types[1], values_ty);
   } else if (param->ret_type == "indices") {
-    reporter->Assign(types[2], indices_ty);
+    reporter->Assign(types[1], indices_ty);
   } else {
     LOG(FATAL) << "Unsupported ret type: " << param->ret_type;
   }
   return true;
 }
 
-Expr MakeTopK(Expr data, Expr k, int axis, String ret_type, bool is_ascend, DataType dtype) {
+Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype) {
   auto attrs = make_object<TopKAttrs>();
-  if (const auto& ck = k.as<ConstantNode>()) {
-    attrs->k = tvm::Integer(reinterpret_cast<int*>(ck->data->data)[0]);
-  }
+  attrs->k = Integer(k);
   attrs->axis = axis;
   attrs->ret_type = ret_type;
   attrs->is_ascend = is_ascend;
   attrs->dtype = dtype;
   static const Op& op = Op::Get("topk");
-  return Call(op, {data, k}, Attrs(attrs), {});
+  return Call(op, {data}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK);
@@ -91,10 +86,9 @@ TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK);
 RELAY_REGISTER_OP("topk")
     .describe(R"doc(Get the top k elements in an input tensor along the given axis.
 )doc" TVM_ADD_FILELINE)
-    .set_num_inputs(2)
+    .set_num_inputs(1)
     .set_attrs_type<TopKAttrs>()
     .add_argument("data", "Tensor", "Input data.")
-    .add_argument("k", "Tensor", "Number of top elements.")
     .set_support_level(6)
     .add_type_rel("TopK", TopKRel);
 
diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/dyn/algorithm/topk.cc
similarity index 73%
copy from src/relay/op/algorithm/topk.cc
copy to src/relay/op/dyn/algorithm/topk.cc
index b02fe86..1c88730 100644
--- a/src/relay/op/algorithm/topk.cc
+++ b/src/relay/op/dyn/algorithm/topk.cc
@@ -27,17 +27,31 @@
 
 namespace tvm {
 namespace relay {
-using tir::make_const;
-
-TVM_REGISTER_NODE_TYPE(TopKAttrs);
+namespace dyn {
 
 bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
              const TypeReporter& reporter) {
-  // `types` contains: [data, result]
+  // `types` contains: [data, k, result]
   const TopKAttrs* param = attrs.as<TopKAttrs>();
   CHECK_EQ(types.size(), 3);
   const auto* data = types[0].as<TensorTypeNode>();
-  CHECK(data);
+  const auto* k = types[1].as<TensorTypeNode>();
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "topk: expect input type to be TensorType but get " << types[0];
+    return false;
+  }
+  if (k == nullptr) {
+    CHECK(types[1].as<IncompleteTypeNode>())
+        << "topk: expect k type to be TensorType but get " << types[1];
+    return false;
+  }
+  CHECK(k->shape.size() <= 1) << "Parameter k must be a Scalar or a Tensor of shape (1, )";
+  if (k->shape.size() == 1) {
+    const IntImmNode* k_shape = k->shape[0].as<IntImmNode>();
+    CHECK(k_shape) << "Parameter k must have static shape";
+    CHECK_EQ(k_shape->value, 1) << "Parameter k must be a Scalar or a Tensor of shape (1, )";
+  }
   int ndim = data->shape.size();
   int axis = param->axis;
   if (axis < 0) {
@@ -48,13 +62,6 @@ bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   for (int i = 0; i < ndim; ++i) {
     if (i != axis) {
       out_shape.push_back(data->shape[i]);
-    } else if (param->k) {
-      const Integer& ck = param->k.value();
-      if (ck->value < 1) {
-        out_shape.push_back(data->shape[i]);
-      } else {
-        out_shape.push_back(ck);
-      }
     } else {
       out_shape.push_back(Any());
     }
@@ -75,20 +82,17 @@ bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
 Expr MakeTopK(Expr data, Expr k, int axis, String ret_type, bool is_ascend, DataType dtype) {
   auto attrs = make_object<TopKAttrs>();
-  if (const auto& ck = k.as<ConstantNode>()) {
-    attrs->k = tvm::Integer(reinterpret_cast<int*>(ck->data->data)[0]);
-  }
   attrs->axis = axis;
   attrs->ret_type = ret_type;
   attrs->is_ascend = is_ascend;
   attrs->dtype = dtype;
-  static const Op& op = Op::Get("topk");
+  static const Op& op = Op::Get("dyn.topk");
   return Call(op, {data, k}, Attrs(attrs), {});
 }
 
-TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK);
+TVM_REGISTER_GLOBAL("relay.op.dyn._make.topk").set_body_typed(MakeTopK);
 
-RELAY_REGISTER_OP("topk")
+RELAY_REGISTER_OP("dyn.topk")
     .describe(R"doc(Get the top k elements in an input tensor along the given axis.
 )doc" TVM_ADD_FILELINE)
     .set_num_inputs(2)
@@ -96,7 +100,8 @@ RELAY_REGISTER_OP("topk")
     .add_argument("data", "Tensor", "Input data.")
     .add_argument("k", "Tensor", "Number of top elements.")
     .set_support_level(6)
-    .add_type_rel("TopK", TopKRel);
+    .add_type_rel("DynTopK", TopKRel);
 
+}  // namespace dyn
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc
index d09230a..dced502 100644
--- a/src/relay/transforms/dynamic_to_static.cc
+++ b/src/relay/transforms/dynamic_to_static.cc
@@ -22,6 +22,7 @@
  * \file dynamic_to_static.cc
  * \brief Rewrite Dynamic Operations to Static operations where possible
  */
+#include <tvm/relay/attrs/algorithm.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
 
@@ -33,7 +34,9 @@ namespace relay {
 class DynamicToStaticMutator : public MixedModeMutator {
  public:
   DynamicToStaticMutator()
-      : dyn_reshape_op_(Op::Get("dyn.reshape")), dyn_tile_op_(Op::Get("dyn.tile")) {}
+      : dyn_reshape_op_(Op::Get("dyn.reshape")),
+        dyn_tile_op_(Op::Get("dyn.tile")),
+        dyn_topk_op_(Op::Get("dyn.topk")) {}
 
  private:
   Expr Rewrite_(const CallNode* pre, const Expr& post) override {
@@ -55,6 +58,20 @@ class DynamicToStaticMutator : public MixedModeMutator {
         static const Op& op = Op::Get("tile");
         return Call(op, {call_node->args[0]}, Attrs(attrs), {});
       }
+    } else if (call_node->op == dyn_topk_op_) {
+      if (const ConstantNode* k = call_node->args[1].as<ConstantNode>()) {
+        const TopKAttrs* param = call_node->attrs.as<TopKAttrs>();
+        CHECK(param);
+        auto attrs = make_object<TopKAttrs>();
+        // Fold the constant k (its first element) into the static op's attribute.
+        attrs->k = Integer(ToScalar(k->data, 0));
+        attrs->axis = param->axis;
+        attrs->ret_type = param->ret_type;
+        attrs->is_ascend = param->is_ascend;
+        attrs->dtype = param->dtype;
+        static const Op& op = Op::Get("topk");
+        return Call(op, {call_node->args[0]}, Attrs(attrs), {});
+      }
     }
     return post;
   }
@@ -68,6 +85,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
 
   const Op& dyn_reshape_op_;
   const Op& dyn_tile_op_;
+  const Op& dyn_topk_op_;
 };
 
 Expr DynamicToStatic(Function f, IRModule m) {
diff --git a/tests/python/relay/dyn/test_dynamic_op_level6.py b/tests/python/relay/dyn/test_dynamic_op_level6.py
new file mode 100644
index 0000000..60a1433
--- /dev/null
+++ b/tests/python/relay/dyn/test_dynamic_op_level6.py
@@ -0,0 +1,76 @@
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Support level6 operator test cases.
+"""
+import numpy as np
+import tvm
+from tvm import te
+from tvm import relay
+from tvm.relay.testing import ctx_list
+
+def test_dynamic_topk():
+    def verify_topk(k, axis, ret_type, is_ascend, dtype):
+        shape = (20, 100)
+        x = relay.var("x", relay.TensorType(shape, "float32"))
+        k_var = relay.var("k", relay.TensorType((1,), "float32"))
+        out = relay.topk(x, k_var, axis, ret_type, is_ascend, dtype)
+        if isinstance(out, relay.expr.TupleWrapper):
+            out = out.astuple()
+        func = relay.Function([x, k_var], out)
+
+        np_data = np.random.uniform(size=shape).astype("float32")
+        if is_ascend:
+            np_indices = np.argsort(np_data, axis=axis)
+        else:
+            np_indices = np.argsort(-np_data, axis=axis)
+        kk = k if k >= 1 else shape[axis]
+        if axis == 0:
+            np_indices = np_indices[:kk, :]
+            np_values = np.zeros(np_indices.shape).astype("float32")
+            for i in range(shape[1]):
+                np_values[:, i] = np_data[np_indices[:, i], i]
+        else:
+            np_indices = np_indices[:, :kk]
+            np_values = np.zeros(np_indices.shape).astype("float32")
+            for i in range(shape[0]):
+                np_values[i, :] = np_data[i, np_indices[i, :]]
+        np_indices = np_indices.astype(dtype)
+
+        for target, ctx in ctx_list():
+            if "llvm" not in target: continue
+            for kind in ["vm", "debug"]:
+                mod = tvm.ir.IRModule.from_expr(func)
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(np_data, np.array([k]).astype("float32"))
+                if ret_type == "both":
+                    tvm.testing.assert_allclose(op_res[0].asnumpy(), np_values)
+                    tvm.testing.assert_allclose(op_res[1].asnumpy(), np_indices)
+                elif ret_type == "values":
+                    tvm.testing.assert_allclose(op_res.asnumpy(), np_values)
+                else:
+                    tvm.testing.assert_allclose(op_res.asnumpy(), np_indices)
+    np.random.seed(0)
+    for k in [0, 1, 5]:
+        for axis in [0, -1, 1]:
+            for ret_type in ["both", "values", "indices"]:
+                verify_topk(k, axis, ret_type, True, "int64")
+                verify_topk(k, axis, ret_type, False, "float32")
+
+
+if __name__ == "__main__":
+    test_dynamic_topk()
diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py
index 3415ce0..bcd8a64 100644
--- a/tests/python/relay/test_pass_dynamic_to_static.py
+++ b/tests/python/relay/test_pass_dynamic_to_static.py
@@ -129,9 +129,62 @@ def test_dynamic_to_static_tile():
     verify_tile((2, 3, 4), (2, 1, 5), (4, 3, 20))
     verify_tile((4, 7), (4, 2), (16, 14))
 
+def test_dynamic_to_static_topk():
+    def verify_topk(k, axis, ret_type, is_ascend, dtype):
+        shape = (20, 100)
+        x = relay.var("x", relay.TensorType(shape, "float32"))
+        k_var = relay.const(k)
+        out = relay.topk(x, k_var, axis, ret_type, is_ascend, dtype)
+        if isinstance(out, relay.expr.TupleWrapper):
+            out = out.astuple()
+        func = relay.Function([x], out)
+
+        np_data = np.random.uniform(size=shape).astype("float32")
+        if is_ascend:
+            np_indices = np.argsort(np_data, axis=axis)
+        else:
+            np_indices = np.argsort(-np_data, axis=axis)
+        kk = k if k >= 1 else shape[axis]
+        if axis == 0:
+            np_indices = np_indices[:kk, :]
+            np_values = np.zeros(np_indices.shape).astype("float32")
+            for i in range(shape[1]):
+                np_values[:, i] = np_data[np_indices[:, i], i]
+        else:
+            np_indices = np_indices[:, :kk]
+            np_values = np.zeros(np_indices.shape).astype("float32")
+            for i in range(shape[0]):
+                np_values[i, :] = np_data[i, np_indices[i, :]]
+        np_indices = np_indices.astype(dtype)
+
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+        zz = func2.body
+        assert isinstance(zz, relay.Call)
+        assert zz.op == relay.op.get("topk")
+
+        for target, ctx in ctx_list():
+            if "llvm" not in target: continue
+            for kind in ["graph", "vm", "debug"]:
+                mod = tvm.ir.IRModule.from_expr(func2)
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(np_data)
+                if ret_type == "both":
+                    tvm.testing.assert_allclose(op_res[0].asnumpy(), np_values)
+                    tvm.testing.assert_allclose(op_res[1].asnumpy(), np_indices)
+                elif ret_type == "values":
+                    tvm.testing.assert_allclose(op_res.asnumpy(), np_values)
+                else:
+                    tvm.testing.assert_allclose(op_res.asnumpy(), np_indices)
+    np.random.seed(0)
+    for k in [0, 1, 5]:
+        for axis in [0, -1, 1]:
+            for ret_type in ["both", "values", "indices"]:
+                verify_topk(k, axis, ret_type, True, "int64")
+                verify_topk(k, axis, ret_type, False, "float32")
 if __name__=="__main__":
     test_dynamic_to_static_reshape()
     test_dynamic_to_static_double_reshape()
     test_dynamic_to_static_quad_reshape()
     test_dynamic_to_static_tile()
+    test_dynamic_to_static_topk()