Posted to commits@tvm.apache.org by zh...@apache.org on 2020/07/10 18:18:42 UTC
[incubator-tvm] branch master updated: [Relay][Dyn] Dynamic TopK Op (#6008)
This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 474d472 [Relay][Dyn] Dynamic TopK Op (#6008)
474d472 is described below
commit 474d47234f8a2378f9135fa3200ca7ce75459889
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Fri Jul 10 11:18:34 2020 -0700
[Relay][Dyn] Dynamic TopK Op (#6008)
* add dynamic topk op
* add topk to dynamic_to_static pass
* fix TF test
* fix pylint
---
python/tvm/relay/op/_algorithm.py | 35 ++---------
python/tvm/relay/op/algorithm.py | 13 ++--
python/tvm/relay/op/dyn/__init__.py | 1 +
python/tvm/relay/op/{ => dyn}/_algorithm.py | 52 ++++------------
python/tvm/relay/op/strategy/generic.py | 3 +-
src/relay/analysis/util.cc | 9 +--
src/relay/op/algorithm/topk.cc | 24 +++----
src/relay/op/{ => dyn}/algorithm/topk.cc | 43 +++++++------
src/relay/transforms/dynamic_to_static.cc | 20 +++++-
tests/python/relay/dyn/test_dynamic_op_level6.py | 76 +++++++++++++++++++++++
tests/python/relay/test_pass_dynamic_to_static.py | 53 ++++++++++++++++
11 files changed, 211 insertions(+), 118 deletions(-)
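
As context for the diff below (not part of the commit): with this change, relay.topk dispatches on the type of k. A Python int (or relay Constant) builds the static topk op with k stored in the attributes, while a relay Expr builds the new dyn.topk op whose k is read at run time. A minimal sketch of the dynamic path, with names and the float32 dtype for k borrowed from the new test added in this commit:

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", relay.TensorType((20, 100), "float32"))
    k = relay.var("k", relay.TensorType((1,), "float32"))  # k supplied at run time
    out = relay.topk(x, k, axis=-1, ret_type="both")       # Expr k -> dyn.topk
    func = relay.Function([x, k], out.astuple())

    mod = tvm.ir.IRModule.from_expr(func)
    intrp = relay.create_executor("vm", mod=mod, target="llvm")
    res = intrp.evaluate()(np.random.rand(20, 100).astype("float32"),
                           np.array([5], dtype="float32"))
    values, indices = res[0].asnumpy(), res[1].asnumpy()  # top-5 along the last axis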
diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py
index 5a20480..cded2e1 100644
--- a/python/tvm/relay/op/_algorithm.py
+++ b/python/tvm/relay/op/_algorithm.py
@@ -35,25 +35,6 @@ register_strategy("topk", strategy.topk_strategy)
register_pattern("topk", OpPattern.OPAQUE)
@script
-def _topk_shape_func_input_data(data, k, axis):
- ndim = len(data.shape)
- val_out = output_tensor((ndim,), "int64")
- indices_out = output_tensor((ndim,), "int64")
-
- for i in const_range(ndim):
- if i != axis:
- val_out[i] = int64(data.shape[i])
- indices_out[i] = int64(data.shape[i])
- else:
- if k[0] < 1:
- val_out[i] = int64(data.shape[i])
- indices_out[i] = int64(data.shape[i])
- else:
- val_out[i] = int64(k[0])
- indices_out[i] = int64(k[0])
- return val_out, indices_out
-
-@script
def _topk_shape_func_input_shape(data_shape, k, axis):
ndim = data_shape.shape[0]
val_out = output_tensor((ndim,), "int64")
@@ -72,22 +53,16 @@ def _topk_shape_func_input_shape(data_shape, k, axis):
indices_out[i] = int64(k)
return val_out, indices_out
-@_reg.register_shape_func("topk", True)
+@_reg.register_shape_func("topk", False)
def topk_shape_func(attrs, inputs, _):
"""
Shape func for topk.
"""
axis = attrs.axis
- if attrs.k is not None:
- if axis < 0:
- axis += inputs[0].shape[0]
- val_out, indices_out = \
- _topk_shape_func_input_shape(inputs[0], attrs.k, convert(axis))
- else:
- if axis < 0:
- axis += len(inputs[0].shape)
- val_out, indices_out = \
- _topk_shape_func_input_data(inputs[0], inputs[1], convert(axis))
+ if axis < 0:
+ axis += inputs[0].shape[0]
+ val_out, indices_out = \
+ _topk_shape_func_input_shape(inputs[0], attrs.k, convert(axis))
ret_type = attrs.ret_type
if ret_type == "both":
ret = [val_out, indices_out]
diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py
index d31e89a..5aeb7e6 100644
--- a/python/tvm/relay/op/algorithm.py
+++ b/python/tvm/relay/op/algorithm.py
@@ -16,8 +16,10 @@
# under the License.
"""Classic algorithm operation"""
from __future__ import absolute_import as _abs
+import numpy as np
from . import _make
-from ..expr import TupleWrapper, const
+from .dyn import _make as _dyn_make
+from ..expr import TupleWrapper, Expr, Constant
def argsort(data, axis=-1, is_ascend=1, dtype="int32"):
"""Performs sorting along the given axis and returns an array of indicies
@@ -82,9 +84,12 @@ def topk(data, k=1, axis=-1, ret_type="both",
out : relay.Expr or List[relay.Expr]
The computed result.
"""
- if isinstance(k, int):
- k = const(k, "int64")
- out = _make.topk(data, k, axis, ret_type, is_ascend, dtype)
+ if isinstance(k, Constant):
+ k = np.asscalar(k.data.asnumpy())
+ if isinstance(k, Expr):
+ out = _dyn_make.topk(data, k, axis, ret_type, is_ascend, dtype)
+ else:
+ out = _make.topk(data, k, axis, ret_type, is_ascend, dtype)
if ret_type == "both":
return TupleWrapper(out, 2)
return out
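
For contrast with the Expr path above, a one-line sketch of the constant path (reusing x from the sketch near the top of this mail): a relay Constant is unwrapped to a Python scalar by the branch above, so it still builds the static topk with k folded into the attributes.

    out_static = relay.topk(x, relay.const(5, "int64"), axis=-1, ret_type="values")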
diff --git a/python/tvm/relay/op/dyn/__init__.py b/python/tvm/relay/op/dyn/__init__.py
index d659203..f4d47a6 100644
--- a/python/tvm/relay/op/dyn/__init__.py
+++ b/python/tvm/relay/op/dyn/__init__.py
@@ -17,4 +17,5 @@
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay namespace containing dynamic ops."""
+from . import _algorithm
from . import _transform
diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/dyn/_algorithm.py
similarity index 58%
copy from python/tvm/relay/op/_algorithm.py
copy to python/tvm/relay/op/dyn/_algorithm.py
index 5a20480..b98b775 100644
--- a/python/tvm/relay/op/_algorithm.py
+++ b/python/tvm/relay/op/dyn/_algorithm.py
@@ -21,18 +21,14 @@ from __future__ import absolute_import
from tvm.te.hybrid import script
from tvm.runtime import convert
-from . import strategy
-from . import op as _reg
-from .op import OpPattern, register_pattern
-from .op import register_strategy
-
-# argsort
-register_strategy("argsort", strategy.argsort_strategy)
-register_pattern("argsort", OpPattern.OPAQUE)
+from .. import strategy
+from .. import op as _reg
+from ..op import OpPattern, register_pattern
+from ..op import register_strategy
# topk
-register_strategy("topk", strategy.topk_strategy)
-register_pattern("topk", OpPattern.OPAQUE)
+register_strategy("dyn.topk", strategy.topk_strategy)
+register_pattern("dyn.topk", OpPattern.OPAQUE)
@script
def _topk_shape_func_input_data(data, k, axis):
@@ -53,41 +49,17 @@ def _topk_shape_func_input_data(data, k, axis):
indices_out[i] = int64(k[0])
return val_out, indices_out
-@script
-def _topk_shape_func_input_shape(data_shape, k, axis):
- ndim = data_shape.shape[0]
- val_out = output_tensor((ndim,), "int64")
- indices_out = output_tensor((ndim,), "int64")
-
- for i in const_range(ndim):
- if i != axis:
- val_out[i] = int64(data_shape[i])
- indices_out[i] = int64(data_shape[i])
- else:
- if k < 1:
- val_out[i] = int64(data_shape[i])
- indices_out[i] = int64(data_shape[i])
- else:
- val_out[i] = int64(k)
- indices_out[i] = int64(k)
- return val_out, indices_out
-
-@_reg.register_shape_func("topk", True)
+@_reg.register_shape_func("dyn.topk", True)
def topk_shape_func(attrs, inputs, _):
"""
Shape func for topk.
"""
axis = attrs.axis
- if attrs.k is not None:
- if axis < 0:
- axis += inputs[0].shape[0]
- val_out, indices_out = \
- _topk_shape_func_input_shape(inputs[0], attrs.k, convert(axis))
- else:
- if axis < 0:
- axis += len(inputs[0].shape)
- val_out, indices_out = \
- _topk_shape_func_input_data(inputs[0], inputs[1], convert(axis))
+ if axis < 0:
+ axis += len(inputs[0].shape)
+ val_out, indices_out = \
+ _topk_shape_func_input_data(inputs[0], inputs[1], convert(axis))
+
ret_type = attrs.ret_type
if ret_type == "both":
ret = [val_out, indices_out]
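
Restated in plain Python (an illustrative sketch, not code from the commit), the hybrid-script shape function above computes: keep the data shape everywhere except along axis, where the extent becomes k[0], with k[0] < 1 meaning "keep the whole axis".

    def topk_out_shape(data_shape, k, axis):
        # axis is assumed already normalized to be non-negative, as in the
        # wrapper above; k is a length-1 tensor read at run time
        out = list(data_shape)
        if k[0] >= 1:
            out[axis] = int(k[0])
        return out  # same shape for both the values and indices outputs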
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 632445b..db0577c 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -656,9 +656,10 @@ def argsort_strategy(attrs, inputs, out_type, target):
def wrap_compute_topk(topi_compute):
"""Wrap topk compute"""
def _compute_topk(attrs, inputs, out_type):
- k = inputs[1]
if attrs.k is not None:
k = attrs.k
+ else:
+ k = inputs[1]
axis = get_const_int(attrs.axis)
ret_type = attrs.ret_type
is_ascend = bool(get_const_int(attrs.is_ascend))
diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc
index 10c226e..c8dbb49 100644
--- a/src/relay/analysis/util.cc
+++ b/src/relay/analysis/util.cc
@@ -448,14 +448,7 @@ bool IsDataDependant(const CallNode* call) {
return false;
}
- if (op->name == "topk") {
- if (const auto* attrs = call->attrs.as<TopKAttrs>()) {
- if (attrs->k) {
- // If k attribute exists, it isn't data dependant.
- return false;
- }
- }
- } else if (op->name == "strided_slice") {
+ if (op->name == "strided_slice") {
if (const auto* attrs = call->attrs.as<StridedSliceAttrs>()) {
if (attrs->begin && attrs->end && attrs->strides) {
// not data dependant if begin, end and strides exist
diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc
index b02fe86..14308dd 100644
--- a/src/relay/op/algorithm/topk.cc
+++ b/src/relay/op/algorithm/topk.cc
@@ -27,7 +27,6 @@
namespace tvm {
namespace relay {
-using tir::make_const;
TVM_REGISTER_NODE_TYPE(TopKAttrs);
@@ -35,7 +34,7 @@ bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, result]
const TopKAttrs* param = attrs.as<TopKAttrs>();
- CHECK_EQ(types.size(), 3);
+ CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
CHECK(data);
int ndim = data->shape.size();
@@ -48,42 +47,38 @@ bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
for (int i = 0; i < ndim; ++i) {
if (i != axis) {
out_shape.push_back(data->shape[i]);
- } else if (param->k) {
+ } else {
const Integer& ck = param->k.value();
if (ck->value < 1) {
out_shape.push_back(data->shape[i]);
} else {
out_shape.push_back(ck);
}
- } else {
- out_shape.push_back(Any());
}
}
auto values_ty = TensorType(out_shape, data->dtype);
auto indices_ty = TensorType(out_shape, param->dtype);
if (param->ret_type == "both") {
- reporter->Assign(types[2], TupleType({values_ty, indices_ty}));
+ reporter->Assign(types[1], TupleType({values_ty, indices_ty}));
} else if (param->ret_type == "values") {
- reporter->Assign(types[2], values_ty);
+ reporter->Assign(types[1], values_ty);
} else if (param->ret_type == "indices") {
- reporter->Assign(types[2], indices_ty);
+ reporter->Assign(types[1], indices_ty);
} else {
LOG(FATAL) << "Unsupported ret type: " << param->ret_type;
}
return true;
}
-Expr MakeTopK(Expr data, Expr k, int axis, String ret_type, bool is_ascend, DataType dtype) {
+Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype) {
auto attrs = make_object<TopKAttrs>();
- if (const auto& ck = k.as<ConstantNode>()) {
- attrs->k = tvm::Integer(reinterpret_cast<int*>(ck->data->data)[0]);
- }
+ attrs->k = Integer(k);
attrs->axis = axis;
attrs->ret_type = ret_type;
attrs->is_ascend = is_ascend;
attrs->dtype = dtype;
static const Op& op = Op::Get("topk");
- return Call(op, {data, k}, Attrs(attrs), {});
+ return Call(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK);
@@ -91,10 +86,9 @@ TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK);
RELAY_REGISTER_OP("topk")
.describe(R"doc(Get the top k elements in an input tensor along the given axis.
)doc" TVM_ADD_FILELINE)
- .set_num_inputs(2)
+ .set_num_inputs(1)
.set_attrs_type<TopKAttrs>()
.add_argument("data", "Tensor", "Input data.")
- .add_argument("k", "Tensor", "Number of top elements.")
.set_support_level(6)
.add_type_rel("TopK", TopKRel);
diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/dyn/algorithm/topk.cc
similarity index 73%
copy from src/relay/op/algorithm/topk.cc
copy to src/relay/op/dyn/algorithm/topk.cc
index b02fe86..1c88730 100644
--- a/src/relay/op/algorithm/topk.cc
+++ b/src/relay/op/dyn/algorithm/topk.cc
@@ -27,17 +27,31 @@
namespace tvm {
namespace relay {
-using tir::make_const;
-
-TVM_REGISTER_NODE_TYPE(TopKAttrs);
+namespace dyn {
bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
- // `types` contains: [data, result]
+ // `types` contains: [data, k, result]
const TopKAttrs* param = attrs.as<TopKAttrs>();
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
- CHECK(data);
+ const auto* k = types[1].as<TensorTypeNode>();
+ if (data == nullptr) {
+ CHECK(types[0].as<IncompleteTypeNode>())
+ << "tile: expect input type to be TensorType but get " << types[0];
+ return false;
+ }
+ if (k == nullptr) {
+ CHECK(types[1].as<IncompleteTypeNode>())
+ << "tile: expect input type to be TensorType but get " << types[1];
+ return false;
+ }
+ CHECK(k->shape.size() <= 1) << "Parameter k must be a Scalar or a Tensor of shape (1, )";
+ if (k->shape.size() == 1) {
+ const IntImmNode* k_shape = k->shape[0].as<IntImmNode>();
+ CHECK(k_shape) << "Parameter k must have static shape";
+ CHECK_EQ(k_shape->value, 1) << "Parameter k must be a Scalar or a Tensor of shape (1, )";
+ }
int ndim = data->shape.size();
int axis = param->axis;
if (axis < 0) {
@@ -48,13 +62,6 @@ bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
for (int i = 0; i < ndim; ++i) {
if (i != axis) {
out_shape.push_back(data->shape[i]);
- } else if (param->k) {
- const Integer& ck = param->k.value();
- if (ck->value < 1) {
- out_shape.push_back(data->shape[i]);
- } else {
- out_shape.push_back(ck);
- }
} else {
out_shape.push_back(Any());
}
@@ -75,20 +82,17 @@ bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
Expr MakeTopK(Expr data, Expr k, int axis, String ret_type, bool is_ascend, DataType dtype) {
auto attrs = make_object<TopKAttrs>();
- if (const auto& ck = k.as<ConstantNode>()) {
- attrs->k = tvm::Integer(reinterpret_cast<int*>(ck->data->data)[0]);
- }
attrs->axis = axis;
attrs->ret_type = ret_type;
attrs->is_ascend = is_ascend;
attrs->dtype = dtype;
- static const Op& op = Op::Get("topk");
+ static const Op& op = Op::Get("dyn.topk");
return Call(op, {data, k}, Attrs(attrs), {});
}
-TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK);
+TVM_REGISTER_GLOBAL("relay.op.dyn._make.topk").set_body_typed(MakeTopK);
-RELAY_REGISTER_OP("topk")
+RELAY_REGISTER_OP("dyn.topk")
.describe(R"doc(Get the top k elements in an input tensor along the given axis.
)doc" TVM_ADD_FILELINE)
.set_num_inputs(2)
@@ -96,7 +100,8 @@ RELAY_REGISTER_OP("topk")
.add_argument("data", "Tensor", "Input data.")
.add_argument("k", "Tensor", "Number of top elements.")
.set_support_level(6)
- .add_type_rel("TopK", TopKRel);
+ .add_type_rel("DynTopK", TopKRel);
+} // namespace dyn
} // namespace relay
} // namespace tvm
diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc
index d09230a..dced502 100644
--- a/src/relay/transforms/dynamic_to_static.cc
+++ b/src/relay/transforms/dynamic_to_static.cc
@@ -22,6 +22,7 @@
* \file dynamic_to_static.cc
* \brief Rewrite Dynamic Operations to Static operations where possible
*/
+#include <tvm/relay/attrs/algorithm.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
@@ -33,7 +34,9 @@ namespace relay {
class DynamicToStaticMutator : public MixedModeMutator {
public:
DynamicToStaticMutator()
- : dyn_reshape_op_(Op::Get("dyn.reshape")), dyn_tile_op_(Op::Get("dyn.tile")) {}
+ : dyn_reshape_op_(Op::Get("dyn.reshape")),
+ dyn_tile_op_(Op::Get("dyn.tile")),
+ dyn_topk_op_(Op::Get("dyn.topk")) {}
private:
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
@@ -55,6 +58,19 @@ class DynamicToStaticMutator : public MixedModeMutator {
static const Op& op = Op::Get("tile");
return Call(op, {call_node->args[0]}, Attrs(attrs), {});
}
+ } else if (call_node->op == dyn_topk_op_) {
+ if (const ConstantNode* k = call_node->args[1].as<ConstantNode>()) {
+ const TopKAttrs* param = call_node->attrs.as<TopKAttrs>();
+ CHECK(param);
+ auto attrs = make_object<TopKAttrs>();
+ attrs->k = Integer(ToScalar(k->data, 0));
+ attrs->axis = param->axis;
+ attrs->ret_type = param->ret_type;
+ attrs->is_ascend = param->is_ascend;
+ attrs->dtype = param->dtype;
+ static const Op& op = Op::Get("topk");
+ return Call(op, {call_node->args[0]}, Attrs(attrs), {});
+ }
}
return post;
}
@@ -68,6 +85,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
const Op& dyn_reshape_op_;
const Op& dyn_tile_op_;
+ const Op& dyn_topk_op_;
};
Expr DynamicToStatic(Function f, IRModule m) {
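
A minimal sketch of invoking the new rewrite (assumed usage, mirroring the test added below): when the k input of a dyn.topk call is a Constant, e.g. after constant folding, DynamicToStatic replaces the call with the static topk and folds k into the attributes.

    from tvm.relay import transform

    mod = tvm.ir.IRModule.from_expr(func)   # func contains a dyn.topk call
    mod = transform.InferType()(mod)
    mod = transform.DynamicToStatic()(mod)  # dyn.topk with Constant k -> topk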
diff --git a/tests/python/relay/dyn/test_dynamic_op_level6.py b/tests/python/relay/dyn/test_dynamic_op_level6.py
new file mode 100644
index 0000000..60a1433
--- /dev/null
+++ b/tests/python/relay/dyn/test_dynamic_op_level6.py
@@ -0,0 +1,76 @@
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Support level6 operator test cases.
+"""
+import numpy as np
+import tvm
+from tvm import te
+from tvm import relay
+from tvm.relay.testing import ctx_list
+
+def test_dynamic_topk():
+ def verify_topk(k, axis, ret_type, is_ascend, dtype):
+ shape = (20, 100)
+ x = relay.var("x", relay.TensorType(shape, "float32"))
+ k_var = relay.var("x", relay.TensorType((1,), "float32"))
+ out = relay.topk(x, k_var, axis, ret_type, is_ascend, dtype)
+ if isinstance(out, relay.expr.TupleWrapper):
+ out = out.astuple()
+ func = relay.Function([x, k_var], out)
+
+ np_data = np.random.uniform(size=shape).astype("float32")
+ if is_ascend:
+ np_indices = np.argsort(np_data, axis=axis)
+ else:
+ np_indices = np.argsort(-np_data, axis=axis)
+ kk = k if k >= 1 else shape[axis]
+ if axis == 0:
+ np_indices = np_indices[:kk, :]
+ np_values = np.zeros(np_indices.shape).astype("float32")
+ for i in range(shape[1]):
+ np_values[:, i] = np_data[np_indices[:, i], i]
+ else:
+ np_indices = np_indices[:, :kk]
+ np_values = np.zeros(np_indices.shape).astype("float32")
+ for i in range(shape[0]):
+ np_values[i, :] = np_data[i, np_indices[i, :]]
+ np_indices = np_indices.astype(dtype)
+
+ for target, ctx in ctx_list():
+ if "llvm" not in target: continue
+ for kind in ["vm", "debug"]:
+ mod = tvm.ir.IRModule.from_expr(func)
+ intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+ op_res = intrp.evaluate()(np_data, np.array([k]).astype("float32"))
+ if ret_type == "both":
+ tvm.testing.assert_allclose(op_res[0].asnumpy(), np_values)
+ tvm.testing.assert_allclose(op_res[1].asnumpy(), np_indices)
+ elif ret_type == "values":
+ tvm.testing.assert_allclose(op_res.asnumpy(), np_values)
+ else:
+ tvm.testing.assert_allclose(op_res.asnumpy(), np_indices)
+ np.random.seed(0)
+ for k in [0, 1, 5]:
+ for axis in [0, -1, 1]:
+ for ret_type in ["both", "values", "indices"]:
+ verify_topk(k, axis, ret_type, True, "int64")
+ verify_topk(k, axis, ret_type, False, "float32")
+
+
+if __name__ == "__main__":
+ test_dynamic_topk()
diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py
index 3415ce0..bcd8a64 100644
--- a/tests/python/relay/test_pass_dynamic_to_static.py
+++ b/tests/python/relay/test_pass_dynamic_to_static.py
@@ -129,9 +129,62 @@ def test_dynamic_to_static_tile():
verify_tile((2, 3, 4), (2, 1, 5), (4, 3, 20))
verify_tile((4, 7), (4, 2), (16, 14))
+def test_dynamic_to_static_topk():
+ def verify_topk(k, axis, ret_type, is_ascend, dtype):
+ shape = (20, 100)
+ x = relay.var("x", relay.TensorType(shape, "float32"))
+ k_var = relay.const(k)
+ out = relay.topk(x, k_var, axis, ret_type, is_ascend, dtype)
+ if isinstance(out, relay.expr.TupleWrapper):
+ out = out.astuple()
+ func = relay.Function([x], out)
+
+ np_data = np.random.uniform(size=shape).astype("float32")
+ if is_ascend:
+ np_indices = np.argsort(np_data, axis=axis)
+ else:
+ np_indices = np.argsort(-np_data, axis=axis)
+ kk = k if k >= 1 else shape[axis]
+ if axis == 0:
+ np_indices = np_indices[:kk, :]
+ np_values = np.zeros(np_indices.shape).astype("float32")
+ for i in range(shape[1]):
+ np_values[:, i] = np_data[np_indices[:, i], i]
+ else:
+ np_indices = np_indices[:, :kk]
+ np_values = np.zeros(np_indices.shape).astype("float32")
+ for i in range(shape[0]):
+ np_values[i, :] = np_data[i, np_indices[i, :]]
+ np_indices = np_indices.astype(dtype)
+
+ func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+ zz = func2.body
+ assert isinstance(zz, relay.Call)
+ assert zz.op == relay.op.get("topk")
+
+ for target, ctx in ctx_list():
+ if "llvm" not in target: continue
+ for kind in ["graph", "vm", "debug"]:
+ mod = tvm.ir.IRModule.from_expr(func2)
+ intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+ op_res = intrp.evaluate()(np_data)
+ if ret_type == "both":
+ tvm.testing.assert_allclose(op_res[0].asnumpy(), np_values)
+ tvm.testing.assert_allclose(op_res[1].asnumpy(), np_indices)
+ elif ret_type == "values":
+ tvm.testing.assert_allclose(op_res.asnumpy(), np_values)
+ else:
+ tvm.testing.assert_allclose(op_res.asnumpy(), np_indices)
+ np.random.seed(0)
+ for k in [0, 1, 5]:
+ for axis in [0, -1, 1]:
+ for ret_type in ["both", "values", "indices"]:
+ verify_topk(k, axis, ret_type, True, "int64")
+ verify_topk(k, axis, ret_type, False, "float32")
if __name__=="__main__":
test_dynamic_to_static_reshape()
test_dynamic_to_static_double_reshape()
test_dynamic_to_static_quad_reshape()
test_dynamic_to_static_tile()
+ test_dynamic_to_static_topk()