You are viewing a plain-text version of this content. The canonical link to it (a hyperlink in the original page) is not included in this plain-text copy.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/04/10 14:46:32 UTC
[incubator-tvm] branch master updated: [REFACTOR][IR] Move to runtime::String (#5276)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 5da361d [REFACTOR][IR] Move to runtime::String (#5276)
5da361d is described below
commit 5da361d3adf87033b90ab5ff6f3117e8af1bee43
Author: Zhi <51...@users.noreply.github.com>
AuthorDate: Fri Apr 10 07:46:23 2020 -0700
[REFACTOR][IR] Move to runtime::String (#5276)
* Use runtime::String
* Move String into the tvm namespace
* Add a const char* constructor
* Allow implicit conversion from std::string
---
include/tvm/ir/expr.h | 7 +--
include/tvm/ir/transform.h | 11 ++--
include/tvm/node/container.h | 2 +
include/tvm/node/node.h | 2 +
include/tvm/relay/transform.h | 5 +-
include/tvm/runtime/container.h | 10 +++-
include/tvm/target/target.h | 6 +--
include/tvm/tir/stmt_functor.h | 4 +-
include/tvm/tir/transform.h | 4 +-
python/tvm/autotvm/task/task.py | 3 ++
python/tvm/relay/backend/graph_runtime_codegen.py | 3 +-
python/tvm/runtime/container.py | 63 ++++++++++++++++++----
python/tvm/runtime/object_generic.py | 4 +-
python/tvm/target/target.py | 10 ++--
src/autotvm/touch_extractor.cc | 8 +--
src/ir/attrs.cc | 2 +-
src/ir/expr.cc | 7 ++-
src/ir/op.cc | 8 +--
src/ir/transform.cc | 26 ++++-----
src/node/container.cc | 1 -
src/relay/backend/build_module.cc | 17 +++---
src/relay/backend/compile_engine.cc | 17 +++---
src/relay/backend/contrib/codegen_c/codegen_c.h | 2 +-
src/relay/backend/graph_runtime_codegen.cc | 9 ++--
src/relay/backend/vm/compiler.cc | 6 +--
src/relay/backend/vm/inline_primitives.cc | 2 +-
src/relay/backend/vm/lambda_lift.cc | 2 +-
src/relay/backend/vm/removed_unused_funcs.cc | 7 ++-
src/relay/ir/transform.cc | 4 +-
src/relay/op/tensor/transform.cc | 1 -
src/relay/transforms/alter_op_layout.cc | 3 +-
src/relay/transforms/annotate_target.cc | 11 ++--
src/relay/transforms/canonicalize_cast.cc | 3 +-
src/relay/transforms/canonicalize_ops.cc | 3 +-
src/relay/transforms/combine_parallel_conv2d.cc | 3 +-
src/relay/transforms/combine_parallel_dense.cc | 3 +-
src/relay/transforms/combine_parallel_op_batch.cc | 3 +-
src/relay/transforms/convert_layout.cc | 4 +-
src/relay/transforms/device_annotation.cc | 3 +-
src/relay/transforms/eliminate_common_subexpr.cc | 3 +-
src/relay/transforms/fast_math.cc | 3 +-
src/relay/transforms/fold_scale_axis.cc | 6 +--
src/relay/transforms/fuse_ops.cc | 3 +-
src/relay/transforms/inline.cc | 2 +-
src/relay/transforms/legalize.cc | 2 +-
src/relay/transforms/merge_composite.cc | 21 ++++----
src/relay/transforms/partition_graph.cc | 2 +-
src/relay/transforms/simplify_inference.cc | 3 +-
src/relay/transforms/to_a_normal_form.cc | 2 +-
src/runtime/container.cc | 32 ++++++++---
src/target/build_common.h | 2 +-
src/target/generic_func.cc | 5 +-
src/target/llvm/codegen_cpu.cc | 2 +-
src/target/llvm/codegen_llvm.cc | 2 +-
src/target/llvm/llvm_module.cc | 2 +-
src/target/source/codegen_c.cc | 2 +-
src/target/source/codegen_metal.cc | 2 +-
src/target/source/codegen_opengl.cc | 2 +-
src/target/source/codegen_vhls.cc | 7 ++-
src/target/spirv/build_vulkan.cc | 2 +-
src/target/spirv/codegen_spirv.cc | 2 +-
src/target/stackvm/codegen_stackvm.cc | 2 +-
src/target/target.cc | 40 +++++++-------
src/tir/ir/expr.cc | 25 +++++----
src/tir/ir/stmt_functor.cc | 6 +--
src/tir/ir/transform.cc | 2 +-
src/tir/pass/arg_binder.cc | 18 ++++---
src/tir/pass/hoist_if_then_else.cc | 11 ++--
src/tir/pass/tensor_core.cc | 2 +-
src/tir/transforms/bind_device_type.cc | 3 +-
src/tir/transforms/make_packed_api.cc | 13 +++--
src/tir/transforms/remap_thread_axis.cc | 8 ++-
src/tir/transforms/split_host_device.cc | 2 +-
tests/cpp/container_test.cc | 2 +-
tests/python/relay/test_annotate_target.py | 4 +-
tests/python/relay/test_call_graph.py | 2 +-
tests/python/relay/test_external_codegen.py | 5 +-
tests/python/relay/test_ir_nodes.py | 4 +-
.../python/relay/test_ir_structural_equal_hash.py | 6 +--
tests/python/relay/test_pass_inline.py | 28 +++++-----
tests/python/relay/test_pass_merge_composite.py | 32 +++++------
tests/python/relay/test_pass_partition_graph.py | 49 +++++++----------
tests/python/unittest/test_ir_attrs.py | 2 +-
topi/include/topi/contrib/cublas.h | 4 +-
topi/include/topi/contrib/rocblas.h | 2 +-
85 files changed, 364 insertions(+), 306 deletions(-)
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index a683fd6..6822159 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -107,11 +107,12 @@ class PrimExpr : public BaseExpr {
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(float value); // NOLINT(*)
+
/*!
- * \brief construct from string.
- * \param str The value to be constructed.
+ * \brief construct from runtime String.
+ * \param value The value to be constructed.
*/
- TVM_DLL PrimExpr(std::string str); // NOLINT(*)
+ TVM_DLL PrimExpr(runtime::String value); // NOLINT(*)
/*! \return the data type of this expression. */
DataType dtype() const {
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index ecd234a..3a9913f 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -57,6 +57,7 @@
#define TVM_IR_TRANSFORM_H_
#include <tvm/support/with.h>
+#include <tvm/runtime/container.h>
#include <tvm/node/container.h>
#include <tvm/ir/error.h>
#include <tvm/ir/module.h>
@@ -95,9 +96,9 @@ class PassContextNode : public Object {
int fallback_device{static_cast<int>(kDLCPU)};
/*! \brief The list of required passes. */
- Array<PrimExpr> required_pass;
+ Array<runtime::String> required_pass;
/*! \brief The list of disabled passes. */
- Array<PrimExpr> disabled_pass;
+ Array<runtime::String> disabled_pass;
TraceFunc trace_func;
@@ -197,7 +198,7 @@ class PassInfoNode : public Object {
std::string name;
/*! \brief The passes that are required to perform the current pass. */
- Array<PrimExpr> required;
+ Array<runtime::String> required;
PassInfoNode() = default;
@@ -226,7 +227,7 @@ class PassInfo : public ObjectRef {
*/
TVM_DLL PassInfo(int opt_level,
std::string name,
- Array<PrimExpr> required);
+ Array<runtime::String> required);
TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
};
@@ -346,7 +347,7 @@ Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
- const Array<PrimExpr>& required);
+ const Array<runtime::String>& required);
} // namespace transform
} // namespace tvm
diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h
index 461fa11..cf2ac26 100644
--- a/include/tvm/node/container.h
+++ b/include/tvm/node/container.h
@@ -36,6 +36,8 @@
namespace tvm {
+using runtime::String;
+using runtime::StringObj;
using runtime::Object;
using runtime::ObjectPtr;
using runtime::ObjectRef;
diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h
index 04f477b..b39e3b4 100644
--- a/include/tvm/node/node.h
+++ b/include/tvm/node/node.h
@@ -35,6 +35,7 @@
#define TVM_NODE_NODE_H_
#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/container.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/node/reflection.h>
@@ -62,6 +63,7 @@ using runtime::make_object;
using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
+using runtime::String;
} // namespace tvm
#endif // TVM_NODE_NODE_H_
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index deb084c..2dcf7f3 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -24,6 +24,7 @@
#ifndef TVM_RELAY_TRANSFORM_H_
#define TVM_RELAY_TRANSFORM_H_
+#include <tvm/runtime/container.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/ir/transform.h>
#include <tvm/relay/expr.h>
@@ -59,7 +60,7 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
Function(Function, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
- const tvm::Array<tvm::PrimExpr>& required);
+ const tvm::Array<runtime::String>& required);
/*! \brief Remove expressions which does not effect the program result.
*
@@ -355,7 +356,7 @@ TVM_DLL Pass Inline();
*
* \return The pass.
*/
-TVM_DLL Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions);
+TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
} // namespace transform
diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h
index 50b406b..083f87f 100644
--- a/include/tvm/runtime/container.h
+++ b/include/tvm/runtime/container.h
@@ -360,7 +360,15 @@ class String : public ObjectRef {
* \note If user passes const reference, it will trigger copy. If it's rvalue,
* it will be moved into other.
*/
- explicit String(std::string other);
+ String(std::string other); // NOLINT(*)
+
+ /*!
+ * \brief Construct a new String object
+ *
+ * \param other a char array.
+ */
+ String(const char* other) // NOLINT(*)
+ : String(std::string(other)) {}
/*!
* \brief Change the value the reference object points to.
diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h
index f6fd3c4..59aa955 100644
--- a/include/tvm/target/target.h
+++ b/include/tvm/target/target.h
@@ -52,11 +52,11 @@ class TargetNode : public Object {
/*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
int thread_warp_size = 1;
/*! \brief Keys for this target */
- Array<PrimExpr> keys_array;
+ Array<runtime::String> keys_array;
/*! \brief Options for this target */
- Array<PrimExpr> options_array;
+ Array<runtime::String> options_array;
/*! \brief Collection of imported libs */
- Array<PrimExpr> libs_array;
+ Array<runtime::String> libs_array;
/*! \return the full device string to pass to codegen::Build */
TVM_DLL const std::string& str() const;
diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h
index 6824022..ad5c5cd 100644
--- a/include/tvm/tir/stmt_functor.h
+++ b/include/tvm/tir/stmt_functor.h
@@ -326,7 +326,7 @@ class StmtExprMutator :
* won't do further recursion.
* \param postorder The function called after recursive mutation.
* The recursive mutation result is passed to postorder for further mutation.
- * \param only_enable List of StringImm.
+ * \param only_enable List of runtime::String.
* If it is empty, all IRNode will call preorder/postorder
* If it is not empty, preorder/postorder will only be called
* when the IRNode's type key is in the list.
@@ -334,7 +334,7 @@ class StmtExprMutator :
TVM_DLL Stmt IRTransform(Stmt node,
const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder,
- const Array<PrimExpr>& only_enable = {});
+ const Array<runtime::String>& only_enable = {});
/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 860014d..5ad40a3 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -56,7 +56,7 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
- const tvm::Array<tvm::PrimExpr>& required);
+ const tvm::Array<runtime::String>& required);
/*!
* \brief Transform the high-level PrimFunc to a low-level version
@@ -100,7 +100,7 @@ TVM_DLL Pass MakePackedAPI(int num_unpacked_args);
*
* \return The pass.
*/
-TVM_DLL Pass RemapThreadAxis(Map<PrimExpr, IterVar> axis_map);
+TVM_DLL Pass RemapThreadAxis(Map<runtime::String, IterVar> axis_map);
/*!
diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py
index ddee149..00b6676 100644
--- a/python/tvm/autotvm/task/task.py
+++ b/python/tvm/autotvm/task/task.py
@@ -24,6 +24,7 @@ registers the standard task.
import numpy as np
from tvm import target as _target
+from tvm import runtime
from tvm.ir import container
from tvm.tir import expr
from tvm.te import tensor, placeholder
@@ -55,6 +56,8 @@ def serialize_args(args):
return x
if isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
return x.value
+ if isinstance(x, runtime.container.String):
+ return str(x)
if x is None:
return None
raise RuntimeError('Do not support type "%s" in argument. Consider to use'
diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py
index 3e5f015..8210f27 100644
--- a/python/tvm/relay/backend/graph_runtime_codegen.py
+++ b/python/tvm/relay/backend/graph_runtime_codegen.py
@@ -84,8 +84,7 @@ class GraphRuntimeCodegen(object):
lowered_func = self._get_irmodule()
param_names = self._list_params_name()
params = {}
- for name in param_names:
- key = name.value
+ for key in param_names:
arr = self._get_param_by_name(key)
param = empty(arr.shape, dtype=arr.dtype, ctx=arr.ctx)
arr.copyto(param)
diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py
index dd59011..a719dcd 100644
--- a/python/tvm/runtime/container.py
+++ b/python/tvm/runtime/container.py
@@ -16,8 +16,9 @@
# under the License.
"""Runtime container structures."""
import tvm._ffi
-
+from tvm._ffi.base import string_types
from tvm.runtime import Object, ObjectTypes
+from tvm.runtime import _ffi_api
def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
@@ -75,18 +76,19 @@ class ADT(Object):
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f))
- self.__init_handle_by_constructor__(_ADT, tag, *fields)
+ self.__init_handle_by_constructor__(_ffi_api.ADT, tag,
+ *fields)
@property
def tag(self):
- return _GetADTTag(self)
+ return _ffi_api.GetADTTag(self)
def __getitem__(self, idx):
return getitem_helper(
- self, _GetADTFields, len(self), idx)
+ self, _ffi_api.GetADTFields, len(self), idx)
def __len__(self):
- return _GetADTSize(self)
+ return _ffi_api.GetADTSize(self)
def tuple_object(fields=None):
@@ -106,7 +108,7 @@ def tuple_object(fields=None):
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f))
- return _Tuple(*fields)
+ return _ffi_api.Tuple(*fields)
@tvm._ffi.register_object("runtime.String")
@@ -115,7 +117,7 @@ class String(Object):
Parameters
----------
- string : Str
+ string : str
The string used to construct a runtime String object
Returns
@@ -124,7 +126,50 @@ class String(Object):
The created object.
"""
def __init__(self, string):
- self.__init_handle_by_constructor__(_String, string)
+ self.__init_handle_by_constructor__(_ffi_api.String, string)
+
+ def __str__(self):
+ return _ffi_api.GetStdString(self)
+
+ def __len__(self):
+ return _ffi_api.GetStringSize(self)
+
+ def __hash__(self):
+ return _ffi_api.StringHash(self)
+
+ def __eq__(self, other):
+ if isinstance(other, string_types):
+ return self.__str__() == other
+
+ if not isinstance(other, String):
+ return False
+
+ return _ffi_api.CompareString(self, other) == 0
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __gt__(self, other):
+ return _ffi_api.CompareString(self, other) > 0
+
+ def __lt__(self, other):
+ return _ffi_api.CompareString(self, other) < 0
+
+ def __getitem__(self, key):
+ return self.__str__()[key]
+
+ def startswith(self, string):
+ """Check if the runtime string starts with a given string
+ Parameters
+ ----------
+ string : str
+ The provided string
-tvm._ffi._init_api("tvm.runtime.container")
+ Returns
+ -------
+ ret : boolean
+ Return true if the runtime string starts with the given string,
+ otherwise, false.
+ """
+ return self.__str__().startswith(string)
diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py
index 22354db..a7716df 100644
--- a/python/tvm/runtime/object_generic.py
+++ b/python/tvm/runtime/object_generic.py
@@ -19,7 +19,7 @@
from numbers import Number, Integral
from tvm._ffi.base import string_types
-from . import _ffi_node_api
+from . import _ffi_node_api, _ffi_api
from .object import ObjectBase, _set_class_object_generic
from .ndarray import NDArrayBase
from .packed_func import PackedFuncBase, convert_to_tvm_func
@@ -56,7 +56,7 @@ def convert_to_object(value):
if isinstance(value, Number):
return const(value)
if isinstance(value, string_types):
- return _ffi_node_api.String(value)
+ return _ffi_api.String(value)
if isinstance(value, (list, tuple)):
value = [convert_to_object(x) for x in value]
return _ffi_node_api.Array(*value)
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index a83ea0c..fd15ff9 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -48,26 +48,26 @@ class Target(Object):
@property
def keys(self):
if not self._keys:
- self._keys = [k.value for k in self.keys_array]
+ self._keys = [str(k) for k in self.keys_array]
return self._keys
@property
def options(self):
if not self._options:
- self._options = [o.value for o in self.options_array]
+ self._options = [str(o) for o in self.options_array]
return self._options
@property
def libs(self):
if not self._libs:
- self._libs = [l.value for l in self.libs_array]
+ self._libs = [str(l) for l in self.libs_array]
return self._libs
@property
def model(self):
for opt in self.options_array:
- if opt.value.startswith('-model='):
- return opt.value[7:]
+ if opt.startswith('-model='):
+ return opt[7:]
return 'unknown'
@property
diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc
index b5bf2ed..fbd0829 100644
--- a/src/autotvm/touch_extractor.cc
+++ b/src/autotvm/touch_extractor.cc
@@ -252,9 +252,9 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
for (auto var : vars) {
Array<Array<PrimExpr> > feature_row;
ItervarFeature &fea = touch_analyzer.itervar_map[var];
- feature_row.push_back(Array<PrimExpr>{std::string("_itervar_"), var});
+ feature_row.push_back(Array<PrimExpr>{tvm::tir::StringImmNode::make("_itervar_"), var});
- Array<PrimExpr> attr{std::string("_attr_"),
+ Array<PrimExpr> attr{tvm::tir::StringImmNode::make("_attr_"),
FloatImm(DataType::Float(32), trans(fea.length)),
IntImm(DataType::Int(32), fea.nest_level),
FloatImm(DataType::Float(32), trans(fea.topdown_product)),
@@ -267,7 +267,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
feature_row.push_back(attr);
// arithmetic
- feature_row.push_back(Array<PrimExpr>{std::string("_arith_"),
+ feature_row.push_back(Array<PrimExpr>{tvm::tir::StringImmNode::make("_arith_"),
FloatImm(DataType::Float(32), trans(fea.add_ct)),
FloatImm(DataType::Float(32), trans(fea.mul_ct)),
FloatImm(DataType::Float(32), trans(fea.div_ct)),
@@ -282,7 +282,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
for (auto k : bufs) {
TouchPattern &v = fea.touch_feature[k];
feature_row.push_back(
- Array<PrimExpr>{k,
+ Array<PrimExpr>{tvm::tir::StringImmNode::make(k),
FloatImm(DataType::Float(32), trans(v.stride)),
FloatImm(DataType::Float(32), trans(v.mod)),
FloatImm(DataType::Float(32), trans(v.count)),
diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc
index 066b8f9..bee103d 100644
--- a/src/ir/attrs.cc
+++ b/src/ir/attrs.cc
@@ -42,7 +42,7 @@ void DictAttrsNode::InitByPackedArgs(
if (val.IsObjectRef<ObjectRef>()) {
dict.Set(key, val.operator ObjectRef());
} else if (val.type_code() == kTVMStr) {
- dict.Set(key, PrimExpr(val.operator std::string()));
+ dict.Set(key, val.operator String());
} else {
dict.Set(key, val.operator PrimExpr());
}
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index b07f04a..1f0337e 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -40,8 +40,8 @@ PrimExpr::PrimExpr(int32_t value)
PrimExpr::PrimExpr(float value)
: PrimExpr(FloatImm(DataType::Float(32), value)) {}
-PrimExpr::PrimExpr(std::string str)
- : PrimExpr(tir::StringImmNode::make(str)) {}
+PrimExpr::PrimExpr(runtime::String value)
+ : PrimExpr(tir::StringImmNode::make(value)) {}
PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) {
using runtime::ObjectTypeChecker;
@@ -51,6 +51,9 @@ PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) {
if (ptr->IsInstance<te::TensorNode>()) {
return te::Tensor(ptr)();
}
+ if (ptr->IsInstance<runtime::StringObj>()) {
+ return tir::StringImmNode::make(runtime::String(ptr));
+ }
CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr.get()))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but get " << ptr->GetTypeKey();
diff --git a/src/ir/op.cc b/src/ir/op.cc
index 6a50240..b024165 100644
--- a/src/ir/op.cc
+++ b/src/ir/op.cc
@@ -24,6 +24,7 @@
#include <tvm/ir/op.h>
#include <tvm/ir/type.h>
#include <tvm/runtime/module.h>
+#include <tvm/runtime/container.h>
#include <tvm/runtime/packed_func.h>
#include <memory>
@@ -140,10 +141,9 @@ void OpRegistry::UpdateAttr(const std::string& key,
// Frontend APIs
TVM_REGISTER_GLOBAL("relay.op._ListOpNames")
.set_body_typed([]() {
- Array<tvm::PrimExpr> ret;
- for (const std::string& name :
- dmlc::Registry<OpRegistry>::ListAllNames()) {
- ret.push_back(tvm::PrimExpr(name));
+ Array<runtime::String> ret;
+ for (const std::string& name : dmlc::Registry<OpRegistry>::ListAllNames()) {
+ ret.push_back(name);
}
return ret;
});
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index 61c1fc2..6e38aac 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -23,6 +23,7 @@
*/
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
+#include <tvm/runtime/container.h>
#include <tvm/runtime/device_api.h>
#include <tvm/node/repr_printer.h>
#include <tvm/ir/transform.h>
@@ -212,7 +213,7 @@ class SequentialNode : public PassNode {
PassInfo::PassInfo(int opt_level,
std::string name,
- tvm::Array<tvm::PrimExpr> required) {
+ tvm::Array<runtime::String> required) {
auto pass_info = make_object<PassInfoNode>();
pass_info->opt_level = opt_level;
pass_info->name = std::move(name);
@@ -274,12 +275,10 @@ void SequentialNode::ResolveDependency(const IRModule& mod) {
}
// linearly scan the pass array to match pass_name
-inline bool PassArrayContains(const Array<tvm::PrimExpr>& pass_array,
+inline bool PassArrayContains(const Array<runtime::String>& pass_array,
const std::string& pass_name) {
for (auto x : pass_array) {
- auto* str_name = x.as<tir::StringImmNode>();
- CHECK(str_name) << "pass name must be str";
- if (str_name->value == pass_name) return true;
+ if (x == pass_name) return true;
}
return false;
}
@@ -324,9 +323,7 @@ IRModule SequentialNode::operator()(const IRModule& module,
if (!PassEnabled(pass_info)) continue;
// resolve dependencies
for (const auto& it : pass_info->required) {
- const auto* name = it.as<tvm::tir::StringImmNode>();
- CHECK(name);
- mod = GetPass(name->value)(mod, pass_ctx);
+ mod = GetPass(it)(mod, pass_ctx);
}
mod = pass(mod, pass_ctx);
}
@@ -337,7 +334,7 @@ Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
- const tvm::Array<tvm::PrimExpr>& required) {
+ const tvm::Array<runtime::String>& required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return ModulePass(pass_func, pass_info);
}
@@ -345,7 +342,7 @@ Pass CreateModulePass(
TVM_REGISTER_NODE_TYPE(PassInfoNode);
TVM_REGISTER_GLOBAL("transform.PassInfo")
-.set_body_typed([](int opt_level, std::string name, tvm::Array<PrimExpr> required) {
+.set_body_typed([](int opt_level, std::string name, tvm::Array<runtime::String> required) {
return PassInfo(opt_level, name, required);
});
@@ -363,8 +360,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "opt_level: " << node->opt_level;
p->stream << "required passes: [" << "\n";
for (const auto& it : node->required) {
- const auto* str = it.as<tvm::tir::StringImmNode>();
- p->stream << str->value << ", ";
+ p->stream << it << ", ";
}
p->stream << "]\n";
});
@@ -401,7 +397,7 @@ TVM_REGISTER_GLOBAL("transform.Sequential")
tvm::Array<Pass> passes = args[0];
int opt_level = args[1];
std::string name = args[2];
- tvm::Array<tvm::PrimExpr> required = args[3];
+ tvm::Array<runtime::String> required = args[3];
PassInfo pass_info = PassInfo(opt_level, name, required);
*ret = Sequential(passes, pass_info);
});
@@ -427,8 +423,8 @@ TVM_REGISTER_GLOBAL("transform.PassContext")
auto pctx = PassContext::Create();
int opt_level = args[0];
int fallback_device = args[1];
- tvm::Array<tvm::PrimExpr> required = args[2];
- tvm::Array<tvm::PrimExpr> disabled = args[3];
+ tvm::Array<runtime::String> required = args[2];
+ tvm::Array<runtime::String> disabled = args[3];
TraceFunc trace_func = args[4];
pctx->opt_level = opt_level;
pctx->fallback_device = fallback_device;
diff --git a/src/node/container.cc b/src/node/container.cc
index e7e4979..bce2eee 100644
--- a/src/node/container.cc
+++ b/src/node/container.cc
@@ -63,7 +63,6 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait)
static_cast<const runtime::StringObj*>(n)).operator std::string();
});
-
struct ADTObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index eaf78bc..e2d5e93 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -86,9 +86,10 @@ struct GraphCodegen {
std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() {
std::unordered_map<std::string, tvm::runtime::NDArray> ret;
- auto names = CallFunc<Array<tvm::PrimExpr>>("list_params_name", nullptr);
- for (auto expr : names) {
- auto key = expr.as<tir::StringImmNode>()->value;
+ auto names = CallFunc<Array<runtime::String>>("list_params_name", nullptr);
+ for (const auto& expr : names) {
+ // Implicit cast from runtime::String to std::string
+ std::string key = expr;
ret[key] = CallFunc<runtime::NDArray>("get_param_by_name", key);
}
return ret;
@@ -191,12 +192,12 @@ class RelayBuildModule : public runtime::ModuleNode {
/*!
* \brief List all paramter names
*
- * \return Array<StringImm> names of params
+ * \return Array<runtime::String> names of params
*/
- Array<tvm::PrimExpr> ListParamNames() {
- Array<tvm::PrimExpr> ret;
+ Array<runtime::String> ListParamNames() {
+ Array<runtime::String> ret;
for (const auto& kv : params_) {
- ret.push_back(tir::StringImmNode::make(kv.first));
+ ret.push_back(kv.first);
}
return ret;
}
@@ -272,7 +273,7 @@ class RelayBuildModule : public runtime::ModuleNode {
}
Array<Pass> pass_seqs;
- Array<tvm::PrimExpr> entry_functions{tvm::PrimExpr{"main"}};
+ Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
// Run all dialect legalization passes.
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index f75da07..9cb6b2e 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -617,17 +617,18 @@ class CompileEngineImpl : public CompileEngineNode {
for (const auto& it : cache_) {
auto src_func = it.first->source_func;
CHECK(src_func.defined());
- if (src_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
- auto code_gen = src_func->GetAttr<tir::StringImm>(attr::kCompiler);
+ if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+ auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
CHECK(code_gen.defined()) << "No external codegen is set";
- if (ext_mods.find(code_gen->value) == ext_mods.end()) {
- ext_mods[code_gen->value] = IRModule({}, {});
+ std::string code_gen_name = code_gen;
+ if (ext_mods.find(code_gen_name) == ext_mods.end()) {
+ ext_mods[code_gen_name] = IRModule({}, {});
}
- auto symbol_name = src_func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(symbol_name.defined()) << "No external symbol is set for:\n"
<< AsText(src_func, false);
auto gv = GlobalVar(std::string(symbol_name));
- ext_mods[code_gen->value]->Add(gv, src_func);
+ ext_mods[code_gen_name]->Add(gv, src_func);
cached_ext_funcs.push_back(it.first);
}
}
@@ -691,10 +692,10 @@ class CompileEngineImpl : public CompileEngineNode {
}
// No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together.
- if (key->source_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
+ if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
auto cache_node = make_object<CachedFuncNode>();
const auto name_node =
- key->source_func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined())
<< "External function has not been attached a name yet.";
cache_node->func_name = std::string(name_node);
diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h
index 79d4d3f..1db3f20 100644
--- a/src/relay/backend/contrib/codegen_c/codegen_c.h
+++ b/src/relay/backend/contrib/codegen_c/codegen_c.h
@@ -70,7 +70,7 @@ class CSourceModuleCodegenBase {
*/
std::string GetExtSymbol(const Function& func) const {
const auto name_node =
- func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined()) << "Fail to retrieve external symbol.";
return std::string(name_node);
}
diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc
index c7f1be8..4279db0 100644
--- a/src/relay/backend/graph_runtime_codegen.cc
+++ b/src/relay/backend/graph_runtime_codegen.cc
@@ -419,7 +419,7 @@ class GraphRuntimeCodegen
auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
Target target;
// Handle external function
- if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
+ if (func->GetAttr<String>(attr::kCompiler).defined()) {
target = tvm::target::ext_dev();
CCacheKey key = (*pf0)(func, target);
CachedFunc ext_func = (*pf1)(compile_engine_, key);
@@ -482,7 +482,7 @@ class GraphRuntimeCodegen
return {};
}
std::vector<GraphNodeRef> VisitExpr_(const FunctionNode* op) override {
- CHECK(op->GetAttr<tir::StringImm>(attr::kCompiler).defined())
+ CHECK(op->GetAttr<String>(attr::kCompiler).defined())
<< "Only functions supported by custom codegen";
return {};
}
@@ -633,10 +633,9 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
});
} else if (name == "list_params_name") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- Array<tvm::PrimExpr> ret;
+ Array<runtime::String> ret;
for (const auto &kv : this->output_.params) {
- tvm::PrimExpr name = tir::StringImmNode::make(kv.first);
- ret.push_back(name);
+ ret.push_back(kv.first);
}
*rv = ret;
});
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index d68bff6..e2b0fff 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -475,7 +475,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
Target target;
- if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
+ if (func->GetAttr<String>(attr::kCompiler).defined()) {
target = tvm::target::ext_dev();
} else {
// Next generate the invoke instruction.
@@ -493,7 +493,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
auto cfunc = engine_->Lower(key);
auto op_index = -1;
- if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
+ if (func->GetAttr<String>(attr::kCompiler).defined()) {
op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc);
} else {
@@ -873,7 +873,7 @@ void VMCompiler::Lower(IRModule mod,
IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets) {
Array<Pass> pass_seqs;
- Array<tvm::PrimExpr> entry_functions{tvm::PrimExpr{"main"}};
+ Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
// Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize());
diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc
index 74b2a47..12113b0 100644
--- a/src/relay/backend/vm/inline_primitives.cc
+++ b/src/relay/backend/vm/inline_primitives.cc
@@ -122,7 +122,7 @@ struct PrimitiveInliner : ExprMutator {
auto global = pair.first;
auto base_func = pair.second;
if (auto* n = base_func.as<FunctionNode>()) {
- if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
+ if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n);
DLOG(INFO) << "Before inlining primitives: " << global
diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc
index 80745e1..59c549c 100644
--- a/src/relay/backend/vm/lambda_lift.cc
+++ b/src/relay/backend/vm/lambda_lift.cc
@@ -190,7 +190,7 @@ class LambdaLifter : public ExprMutator {
auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) {
if (auto* n = pair.second.as<FunctionNode>()) {
- if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
+ if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n);
func = Function(func->params,
VisitExpr(func->body),
diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc
index dd11fce..c2fe37f 100644
--- a/src/relay/backend/vm/removed_unused_funcs.cc
+++ b/src/relay/backend/vm/removed_unused_funcs.cc
@@ -87,11 +87,10 @@ struct CallTracer : ExprVisitor {
* \return The module with dead functions removed.
*/
IRModule RemoveUnusedFunctions(const IRModule& module,
- Array<tvm::PrimExpr> entry_funcs) {
+ Array<runtime::String> entry_funcs) {
std::unordered_set<std::string> called_funcs{};
for (auto entry : entry_funcs) {
- auto* str_name = entry.as<tir::StringImmNode>();
- auto funcs = CallTracer(module).Trace(str_name->value);
+ auto funcs = CallTracer(module).Trace(entry);
called_funcs.insert(funcs.cbegin(), funcs.cend());
}
auto existing_functions = module->functions;
@@ -108,7 +107,7 @@ IRModule RemoveUnusedFunctions(const IRModule& module,
namespace transform {
-Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions) {
+Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m, PassContext pc) {
return relay::vm::RemoveUnusedFunctions(m, entry_functions);
diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc
index a4bab36..fa709eb 100644
--- a/src/relay/ir/transform.cc
+++ b/src/relay/ir/transform.cc
@@ -145,14 +145,14 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
bool FunctionPassNode::SkipFunction(const Function& func) const {
return func->GetAttr<Integer>(attr::kSkipOptimization, 0)->value != 0 ||
- (func->GetAttr<tir::StringImm>(attr::kCompiler).defined());
+ (func->GetAttr<String>(attr::kCompiler).defined());
}
Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
- const tvm::Array<tvm::PrimExpr>& required) {
+ const tvm::Array<runtime::String>& required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return FunctionPass(pass_func, pass_info);
}
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 87b4602..7aa8bf1 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1177,7 +1177,6 @@ Array<te::Tensor> ArangeCompute(const Attrs& attrs,
te::Tensor start = inputs[0];
te::Tensor stop = inputs[1];
te::Tensor step = inputs[2];
- Array<tvm::PrimExpr> empty = {0};
return { DynamicArange(start, stop, step, param->dtype) };
}
diff --git a/src/relay/transforms/alter_op_layout.cc b/src/relay/transforms/alter_op_layout.cc
index 63c1cb9..aab0b3a 100644
--- a/src/relay/transforms/alter_op_layout.cc
+++ b/src/relay/transforms/alter_op_layout.cc
@@ -125,8 +125,7 @@ Pass AlterOpLayout() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f));
};
- return CreateFunctionPass(pass_func, 3, "AlterOpLayout",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 3, "AlterOpLayout", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout")
diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc
index c3d34cb..44ef35a 100644
--- a/src/relay/transforms/annotate_target.cc
+++ b/src/relay/transforms/annotate_target.cc
@@ -59,11 +59,12 @@ class AnnotateTargetWrapper : public ExprMutator {
// handle composite functions
Function func = Downcast<Function>(call->op);
CHECK(func.defined());
- auto comp_name = func->GetAttr<tir::StringImm>(attr::kComposite);
+ auto comp_name = func->GetAttr<String>(attr::kComposite);
if (comp_name.defined()) {
- size_t i = comp_name->value.find('.');
+ std::string comp_name_str = comp_name;
+ size_t i = comp_name_str.find('.');
if (i != std::string::npos) {
- std::string target = comp_name->value.substr(0, i);
+ std::string target = comp_name_str.substr(0, i);
if (target == target_) return true;
}
}
@@ -147,7 +148,7 @@ class AnnotateTargetWrapper : public ExprMutator {
Function func;
Expr new_body;
// don't step into composite functions
- if (fn->GetAttr<tir::StringImm>(attr::kComposite).defined()) {
+ if (fn->GetAttr<String>(attr::kComposite).defined()) {
func = GetRef<Function>(fn);
new_body = func->body;
} else {
@@ -225,7 +226,7 @@ Pass AnnotateTarget(const std::string& target) {
return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, target));
};
auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc",
- {tir::StringImmNode::make("InferType")});
+ {"InferType"});
return transform::Sequential({func_pass, InferType()}, "AnnotateTarget");
}
diff --git a/src/relay/transforms/canonicalize_cast.cc b/src/relay/transforms/canonicalize_cast.cc
index 759a4ae..ebcbd57 100644
--- a/src/relay/transforms/canonicalize_cast.cc
+++ b/src/relay/transforms/canonicalize_cast.cc
@@ -133,8 +133,7 @@ Pass CanonicalizeCast() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CanonicalizeCast(f));
};
- return CreateFunctionPass(pass_func, 3, "CanonicalizeCast",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast")
diff --git a/src/relay/transforms/canonicalize_ops.cc b/src/relay/transforms/canonicalize_ops.cc
index 97a128d..1d3111b 100644
--- a/src/relay/transforms/canonicalize_ops.cc
+++ b/src/relay/transforms/canonicalize_ops.cc
@@ -74,8 +74,7 @@ Pass CanonicalizeOps() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CanonicalizeOps(f));
};
- return CreateFunctionPass(pass_func, 3, "CanonicalizeOps",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps")
diff --git a/src/relay/transforms/combine_parallel_conv2d.cc b/src/relay/transforms/combine_parallel_conv2d.cc
index 3884dac..af6b135 100644
--- a/src/relay/transforms/combine_parallel_conv2d.cc
+++ b/src/relay/transforms/combine_parallel_conv2d.cc
@@ -220,8 +220,7 @@ Pass CombineParallelConv2D(uint64_t min_num_branches) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CombineParallelConv2D(f, min_num_branches));
};
- return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D")
diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc
index 612dae5..1278020 100644
--- a/src/relay/transforms/combine_parallel_dense.cc
+++ b/src/relay/transforms/combine_parallel_dense.cc
@@ -80,8 +80,7 @@ Pass CombineParallelDense(uint64_t min_num_branches) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CombineParallelDense(f, min_num_branches));
};
- return CreateFunctionPass(pass_func, 4, "CombineParallelDense",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 4, "CombineParallelDense", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense")
diff --git a/src/relay/transforms/combine_parallel_op_batch.cc b/src/relay/transforms/combine_parallel_op_batch.cc
index 55ca3f6..361565e 100644
--- a/src/relay/transforms/combine_parallel_op_batch.cc
+++ b/src/relay/transforms/combine_parallel_op_batch.cc
@@ -193,8 +193,7 @@ Pass CombineParallelOpBatch(const std::string& op_name,
batch_op_name,
min_num_branches));
};
- return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.CombineParallelOpBatch")
diff --git a/src/relay/transforms/convert_layout.cc b/src/relay/transforms/convert_layout.cc
index 871969d..dbb2c38 100644
--- a/src/relay/transforms/convert_layout.cc
+++ b/src/relay/transforms/convert_layout.cc
@@ -133,9 +133,7 @@ Pass ConvertLayout(const std::string& desired_layout) {
return Downcast<Function>(relay::convert_op_layout::ConvertLayout(f, desired_layout));
};
return CreateFunctionPass(
- pass_func, 3, "ConvertLayout",
- {tir::StringImmNode::make("InferType"),
- tir::StringImmNode::make("CanonicalizeOps")});
+ pass_func, 3, "ConvertLayout", {"InferType", "CanonicalizeOps"});
}
TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout);
diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc
index b4d61f1..908ba87 100644
--- a/src/relay/transforms/device_annotation.cc
+++ b/src/relay/transforms/device_annotation.cc
@@ -573,8 +573,7 @@ Pass RewriteAnnotatedOps(int fallback_device) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::RewriteAnnotatedOps(f, fallback_device));
};
- return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation")
diff --git a/src/relay/transforms/eliminate_common_subexpr.cc b/src/relay/transforms/eliminate_common_subexpr.cc
index f905ba5..68c59f5 100644
--- a/src/relay/transforms/eliminate_common_subexpr.cc
+++ b/src/relay/transforms/eliminate_common_subexpr.cc
@@ -91,8 +91,7 @@ Pass EliminateCommonSubexpr(PackedFunc fskip) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
};
- return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr")
diff --git a/src/relay/transforms/fast_math.cc b/src/relay/transforms/fast_math.cc
index cf00a89..8234dea 100644
--- a/src/relay/transforms/fast_math.cc
+++ b/src/relay/transforms/fast_math.cc
@@ -70,8 +70,7 @@ Pass FastMath() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(FastMath(f));
};
- return CreateFunctionPass(pass_func, 4, "FastMath",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 4, "FastMath", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.FastMath")
diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc
index 49f6e3f..cfe74bf 100644
--- a/src/relay/transforms/fold_scale_axis.cc
+++ b/src/relay/transforms/fold_scale_axis.cc
@@ -960,8 +960,7 @@ Pass ForwardFoldScaleAxis() {
return Downcast<Function>(
relay::fold_scale_axis::ForwardFoldScaleAxis(f));
};
- return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis")
@@ -973,8 +972,7 @@ Pass BackwardFoldScaleAxis() {
return Downcast<Function>(
relay::fold_scale_axis::BackwardFoldScaleAxis(f));
};
- return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis")
diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc
index 9168898..f646042 100644
--- a/src/relay/transforms/fuse_ops.cc
+++ b/src/relay/transforms/fuse_ops.cc
@@ -980,8 +980,7 @@ Pass FuseOps(int fuse_opt_level) {
int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
return Downcast<Function>(FuseOps(f, opt_level, m));
};
- return CreateFunctionPass(pass_func, 1, "FuseOps",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.FuseOps")
diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc
index ef3c51f..ba0f568 100644
--- a/src/relay/transforms/inline.cc
+++ b/src/relay/transforms/inline.cc
@@ -131,7 +131,7 @@ class Inliner : ExprMutator {
fn->attrs);
// Inline the function body to the caller if this function uses default
// compiler, i.e. no external codegen is needed.
- if (!func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
+ if (!func->GetAttr<String>(attr::kCompiler).defined()) {
CHECK_EQ(func->params.size(), args.size())
<< "Mismatch found in the number of parameters and call args";
// Bind the parameters with call args.
diff --git a/src/relay/transforms/legalize.cc b/src/relay/transforms/legalize.cc
index 01411a6..0b5c671 100644
--- a/src/relay/transforms/legalize.cc
+++ b/src/relay/transforms/legalize.cc
@@ -101,7 +101,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
};
- return CreateFunctionPass(pass_func, 1, "Legalize", {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 1, "Legalize", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.Legalize").set_body_typed(Legalize);
diff --git a/src/relay/transforms/merge_composite.cc b/src/relay/transforms/merge_composite.cc
index 35b93dc..75d95f0 100644
--- a/src/relay/transforms/merge_composite.cc
+++ b/src/relay/transforms/merge_composite.cc
@@ -159,9 +159,9 @@ class MergeCompositeWrapper : public ExprMutator {
if (call->op->IsInstance<FunctionNode>()) {
Function func = Downcast<Function>(call->op);
CHECK(func.defined());
- const auto name_node = func->GetAttr<tir::StringImm>(attr::kComposite);
+ auto name_node = func->GetAttr<String>(attr::kComposite);
// don't step into existing composite functions
- if (name_node.defined() && name_node->value != "") {
+ if (name_node.defined() && name_node != "") {
tvm::Array<tvm::relay::Expr> new_args;
for (const auto& arg : call->args) {
auto new_e = this->Mutate(arg);
@@ -185,7 +185,7 @@ class MergeCompositeWrapper : public ExprMutator {
auto free_vars = FreeVars(extract);
// make the composite function
auto f = Function(free_vars, extract, call->checked_type_, {}, DictAttrs());
- f = WithAttr(std::move(f), attr::kComposite, tir::StringImmNode::make(pattern_name_));
+ f = WithAttr(std::move(f), attr::kComposite, runtime::String(pattern_name_));
// find the expressions associated with the free vars using the args_map
// this tells us which expressions should be given as inputs to the composite function
Array<Expr> args;
@@ -207,16 +207,14 @@ class MergeCompositeWrapper : public ExprMutator {
PackedFunc check_;
};
-Expr MergeComposite(const Expr& expr, const Array<tir::StringImm>& pattern_names,
+Expr MergeComposite(const Expr& expr, const Array<runtime::String>& pattern_names,
const Array<Expr>& patterns, const std::vector<PackedFunc>& checks) {
CHECK_EQ(pattern_names.size(), patterns.size());
Expr merged_expr = expr;
// merge the patterns one-by-one in order
for (size_t i = 0; i < patterns.size(); i++) {
- std::string pattern_name = pattern_names[i]->value;
- Expr pattern = patterns[i];
- PackedFunc check = checks[i];
- merged_expr = MergeCompositeWrapper(pattern_name, pattern, check).Mutate(merged_expr);
+ merged_expr =
+ MergeCompositeWrapper(pattern_names[i], patterns[i], checks[i]).Mutate(merged_expr);
}
return merged_expr;
}
@@ -225,7 +223,7 @@ Expr MergeComposite(const Expr& expr, const Array<tir::StringImm>& pattern_names
namespace transform {
-Pass MergeComposite(const tvm::Array<tir::StringImm>& pattern_names,
+Pass MergeComposite(const tvm::Array<runtime::String>& pattern_names,
const tvm::Array<Expr>& patterns, const std::vector<PackedFunc>& checks) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
@@ -236,8 +234,9 @@ Pass MergeComposite(const tvm::Array<tir::StringImm>& pattern_names,
return func_pass;
}
-TVM_REGISTER_GLOBAL("relay._transform.MergeComposite").set_body([](TVMArgs args, TVMRetValue* rv) {
- tvm::Array<tir::StringImm> pattern_names = args[0];
+TVM_REGISTER_GLOBAL("relay._transform.MergeComposite")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+ tvm::Array<runtime::String> pattern_names = args[0];
tvm::Array<Expr> patterns = args[1];
std::vector<PackedFunc> checks;
for (int i = 2; i < args.size(); i++) {
diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index a4e3863..8eeac17 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -245,7 +245,7 @@ class Partitioner : public ExprMutator {
global_region_func =
WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
- tvm::tir::StringImmNode::make(target));
+ tvm::runtime::String(target));
global_region_func =
WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
diff --git a/src/relay/transforms/simplify_inference.cc b/src/relay/transforms/simplify_inference.cc
index bc7c15e..d349fdd 100644
--- a/src/relay/transforms/simplify_inference.cc
+++ b/src/relay/transforms/simplify_inference.cc
@@ -204,8 +204,7 @@ Pass SimplifyInference() {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(SimplifyInference(f));
};
- return CreateFunctionPass(pass_func, 0, "SimplifyInference",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 0, "SimplifyInference", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference")
diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc
index 6e35dfb..21c5162 100644
--- a/src/relay/transforms/to_a_normal_form.cc
+++ b/src/relay/transforms/to_a_normal_form.cc
@@ -299,7 +299,7 @@ IRModule ToANormalForm(const IRModule& m) {
for (const auto& it : funcs) {
CHECK_EQ(FreeVars(it.second).size(), 0);
if (const auto* n = it.second.as<FunctionNode>()) {
- if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
+ if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
}
Expr ret =
TransformF([&](const Expr& e) {
diff --git a/src/runtime/container.cc b/src/runtime/container.cc
index 400f646..81dfd3d 100644
--- a/src/runtime/container.cc
+++ b/src/runtime/container.cc
@@ -32,14 +32,14 @@ namespace runtime {
using namespace vm;
-TVM_REGISTER_GLOBAL("runtime.container._GetADTTag")
+TVM_REGISTER_GLOBAL("runtime.GetADTTag")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj);
*rv = static_cast<int64_t>(adt.tag());
});
-TVM_REGISTER_GLOBAL("runtime.container._GetADTSize")
+TVM_REGISTER_GLOBAL("runtime.GetADTSize")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj);
@@ -47,7 +47,7 @@ TVM_REGISTER_GLOBAL("runtime.container._GetADTSize")
});
-TVM_REGISTER_GLOBAL("runtime.container._GetADTFields")
+TVM_REGISTER_GLOBAL("runtime.GetADTFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
int idx = args[1];
@@ -56,7 +56,7 @@ TVM_REGISTER_GLOBAL("runtime.container._GetADTFields")
*rv = adt[idx];
});
-TVM_REGISTER_GLOBAL("runtime.container._Tuple")
+TVM_REGISTER_GLOBAL("runtime.Tuple")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<ObjectRef> fields;
for (auto i = 0; i < args.size(); ++i) {
@@ -65,7 +65,7 @@ TVM_REGISTER_GLOBAL("runtime.container._Tuple")
*rv = ADT::Tuple(fields);
});
-TVM_REGISTER_GLOBAL("runtime.container._ADT")
+TVM_REGISTER_GLOBAL("runtime.ADT")
.set_body([](TVMArgs args, TVMRetValue* rv) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
@@ -76,11 +76,31 @@ TVM_REGISTER_GLOBAL("runtime.container._ADT")
*rv = ADT(tag, fields);
});
-TVM_REGISTER_GLOBAL("runtime.container._String")
+TVM_REGISTER_GLOBAL("runtime.String")
.set_body_typed([](std::string str) {
return String(std::move(str));
});
+TVM_REGISTER_GLOBAL("runtime.GetStringSize")
+.set_body_typed([](String str) {
+ return static_cast<int64_t>(str.size());
+});
+
+TVM_REGISTER_GLOBAL("runtime.GetStdString")
+.set_body_typed([](String str) {
+ return std::string(str);
+});
+
+TVM_REGISTER_GLOBAL("runtime.CompareString")
+.set_body_typed([](String lhs, String rhs) {
+ return lhs.compare(rhs);
+});
+
+TVM_REGISTER_GLOBAL("runtime.StringHash")
+.set_body_typed([](String str) {
+ return static_cast<int64_t>(std::hash<String>()(str));
+});
+
TVM_REGISTER_OBJECT_TYPE(ADTObj);
TVM_REGISTER_OBJECT_TYPE(StringObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj);
diff --git a/src/target/build_common.h b/src/target/build_common.h
index fc45cef..5ba51da 100644
--- a/src/target/build_common.h
+++ b/src/target/build_common.h
@@ -57,7 +57,7 @@ ExtractFuncInfo(const IRModule& mod) {
info.thread_axis_tags.push_back(thread_axis[i]->thread_tag);
}
}
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
fmap[static_cast<std::string>(global_symbol)] = info;
}
return fmap;
diff --git a/src/target/generic_func.cc b/src/target/generic_func.cc
index 8eef4b7..44d017f 100644
--- a/src/target/generic_func.cc
+++ b/src/target/generic_func.cc
@@ -22,6 +22,7 @@
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
+#include <tvm/runtime/container.h>
#include <tvm/node/node.h>
#include <tvm/node/repr_printer.h>
#include <tvm/target/target.h>
@@ -150,12 +151,12 @@ TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc")
GenericFunc generic_func = args[0];
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
PackedFunc* func = new PackedFunc(args[1].operator PackedFunc());
- Array<PrimExpr> tags = args[2];
+ Array<runtime::String> tags = args[2];
bool allow_override = args[3];
std::vector<std::string> tags_vector;
for (auto& tag : tags) {
- tags_vector.push_back(tag.as<tvm::tir::StringImmNode>()->value);
+ tags_vector.push_back(tag);
}
generic_func
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index f0b0a4b..a863056 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -126,7 +126,7 @@ void CodeGenCPU::Init(const std::string& module_name,
void CodeGenCPU::AddFunction(const PrimFunc& f) {
CodeGenLLVM::AddFunction(f);
if (f_tvm_register_system_symbol_ != nullptr) {
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
export_system_symbols_.emplace_back(
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 28f4efd..bb0b7e4 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -128,7 +128,7 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
llvm::FunctionType* ftype = llvm::FunctionType::get(
ret_void ? t_void_ : t_int_, param_types, false);
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
CHECK(module_->getFunction(static_cast<std::string>(global_symbol)) == nullptr)
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index 9ea77ac..52dccba 100644
--- a/src/target/llvm/llvm_module.cc
+++ b/src/target/llvm/llvm_module.cc
@@ -214,7 +214,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
<< "Can only lower IR Module with PrimFuncs";
auto f = Downcast<PrimFunc>(kv.second);
if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined());
entry_func = global_symbol;
}
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 0cb4742..a0e18a6 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -78,7 +78,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
// reserve keywords
ReserveKeywordsAsUnique();
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index 2f31a3e..715c0ae 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -56,7 +56,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
GetUniqueName("_");
// add to alloc buffer type.
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
diff --git a/src/target/source/codegen_opengl.cc b/src/target/source/codegen_opengl.cc
index 4748599..13d87d2 100644
--- a/src/target/source/codegen_opengl.cc
+++ b/src/target/source/codegen_opengl.cc
@@ -156,7 +156,7 @@ void CodeGenOpenGL::AddFunction(const PrimFunc& f) {
arg_kinds.push_back(kind);
}
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenOpenGL: Expect PrimFunc to have the global_symbol attribute";
diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc
index 6c1c3b9..7486164 100644
--- a/src/target/source/codegen_vhls.cc
+++ b/src/target/source/codegen_vhls.cc
@@ -147,7 +147,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) {
std::string whole_code = cg.Finish();
// Generate source code for compilation.
- Array<Array<PrimExpr> > kernel_info;
+ Array<Array<runtime::String> > kernel_info;
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
@@ -161,11 +161,10 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) {
code = (*f)(code).operator std::string();
}
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
- std::string func_name = global_symbol;
- kernel_info.push_back(Array<PrimExpr>({func_name, code}));
+ kernel_info.push_back({global_symbol, code});
}
std::string xclbin;
diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc
index b6f9b86..5872141 100644
--- a/src/target/spirv/build_vulkan.cc
+++ b/src/target/spirv/build_vulkan.cc
@@ -90,7 +90,7 @@ runtime::Module BuildSPIRV(IRModule mod) {
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index 0241e22..db2a2f3 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -78,7 +78,7 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f) {
builder_->MakeInst(spv::OpReturn);
builder_->MakeInst(spv::OpFunctionEnd);
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";
diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc
index af8b341..da75a70 100644
--- a/src/target/stackvm/codegen_stackvm.cc
+++ b/src/target/stackvm/codegen_stackvm.cc
@@ -536,7 +536,7 @@ runtime::Module BuildStackVM(const IRModule& mod) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenStackVM: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute";
std::string f_name = global_symbol;
diff --git a/src/target/target.cc b/src/target/target.cc
index 8fb9cb6..306fba4 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -62,39 +62,39 @@ Target CreateTarget(const std::string& target_name,
std::string device_flag = "-device=";
std::string keys_flag = "-keys=";
for (auto& item : options) {
- t->options_array.push_back(tir::StringImmNode::make(item));
+ t->options_array.push_back(item);
if (item.find(libs_flag) == 0) {
std::stringstream ss(item.substr(libs_flag.length()));
std::string lib_item;
while (std::getline(ss, lib_item, ',')) {
- t->libs_array.push_back(tir::StringImmNode::make(lib_item));
+ t->libs_array.push_back(lib_item);
}
} else if (item.find(device_flag) == 0) {
t->device_name = item.substr(device_flag.length());
- t->keys_array.push_back(tir::StringImmNode::make(t->device_name));
+ t->keys_array.push_back(t->device_name);
} else if (item.find(keys_flag) == 0) {
std::stringstream ss(item.substr(keys_flag.length()));
std::string key_item;
while (std::getline(ss, key_item, ',')) {
- t->keys_array.push_back(tir::StringImmNode::make(key_item));
+ t->keys_array.push_back(key_item);
}
}
}
if (t->device_name.length() > 0) {
- t->keys_array.push_back(tir::StringImmNode::make(t->device_name));
+ t->keys_array.push_back(t->device_name);
}
t->device_type = kDLCPU;
t->thread_warp_size = 1;
if (target_name == "c" && t->device_name == "micro_dev") {
t->device_type = kDLMicroDev;
} else if (target_name == "c" || target_name == "llvm") {
- t->keys_array.push_back(tir::StringImmNode::make("cpu"));
+ t->keys_array.push_back("cpu");
} else if (target_name == "cuda" || target_name == "nvptx") {
t->device_type = kDLGPU;
- t->keys_array.push_back(tir::StringImmNode::make("cuda"));
- t->keys_array.push_back(tir::StringImmNode::make("gpu"));
+ t->keys_array.push_back("cuda");
+ t->keys_array.push_back("gpu");
t->max_num_threads = 1024;
t->thread_warp_size = 32;
} else if (target_name == "rocm" || target_name == "opencl") {
@@ -104,8 +104,8 @@ Target CreateTarget(const std::string& target_name,
} else {
t->device_type = kDLROCM;
}
- t->keys_array.push_back(tir::StringImmNode::make(target_name));
- t->keys_array.push_back(tir::StringImmNode::make("gpu"));
+ t->keys_array.push_back(target_name);
+ t->keys_array.push_back("gpu");
t->max_num_threads = 256;
if (t->device_name == "intel_graphics") {
t->thread_warp_size = 16;
@@ -116,20 +116,20 @@ Target CreateTarget(const std::string& target_name,
} else {
t->device_type = kDLVulkan;
}
- t->keys_array.push_back(tir::StringImmNode::make(target_name));
- t->keys_array.push_back(tir::StringImmNode::make("gpu"));
+ t->keys_array.push_back(target_name);
+ t->keys_array.push_back("gpu");
t->max_num_threads = 256;
} else if (target_name == "sdaccel") {
t->device_type = kDLOpenCL;
- t->keys_array.push_back(tir::StringImmNode::make("sdaccel"));
- t->keys_array.push_back(tir::StringImmNode::make("hls"));
+ t->keys_array.push_back("sdaccel");
+ t->keys_array.push_back("hls");
} else if (target_name == "aocl" || target_name == "aocl_sw_emu") {
t->device_type = kDLAOCL;
- t->keys_array.push_back(tir::StringImmNode::make("aocl"));
- t->keys_array.push_back(tir::StringImmNode::make("hls"));
+ t->keys_array.push_back("aocl");
+ t->keys_array.push_back("hls");
} else if (target_name == "opengl") {
t->device_type = kOpenGL;
- t->keys_array.push_back(tir::StringImmNode::make("opengl"));
+ t->keys_array.push_back("opengl");
} else if (target_name == "stackvm") {
t->device_type = kDLCPU;
} else if (target_name == "ext_dev") {
@@ -168,7 +168,7 @@ TVM_REGISTER_GLOBAL("target.TargetFromString")
std::vector<std::string> TargetNode::keys() const {
std::vector<std::string> result;
for (auto& expr : keys_array) {
- result.push_back(expr.as<tir::StringImmNode>()->value);
+ result.push_back(expr);
}
return result;
}
@@ -176,7 +176,7 @@ std::vector<std::string> TargetNode::keys() const {
std::vector<std::string> TargetNode::options() const {
std::vector<std::string> result;
for (auto& expr : options_array) {
- result.push_back(expr.as<tir::StringImmNode>()->value);
+ result.push_back(expr);
}
return result;
}
@@ -184,7 +184,7 @@ std::vector<std::string> TargetNode::options() const {
std::unordered_set<std::string> TargetNode::libs() const {
std::unordered_set<std::string> result;
for (auto& expr : libs_array) {
- result.insert(expr.as<tir::StringImmNode>()->value);
+ result.insert(expr);
}
return result;
}
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 891d137..0efa33a 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -47,7 +47,6 @@ Var::Var(std::string name_hint, Type type_annotation) {
data_ = std::move(n);
}
-
Var Var::copy_with_suffix(const std::string& suffix) const {
const VarNode* node = get();
ObjectPtr<VarNode> new_ptr;
@@ -826,20 +825,28 @@ TVM_REGISTER_GLOBAL("tir.Load")
}
});
-
-
TVM_REGISTER_GLOBAL("tir.Call")
.set_body_typed([](
DataType type, std::string name,
- Array<PrimExpr> args, int call_type,
+ Array<ObjectRef> args, int call_type,
FunctionRef func, int value_index
) {
+ Array<PrimExpr> prim_expr_args;
+ for (const auto& it : args) {
+ CHECK(it->IsInstance<runtime::StringObj>() ||
+ it->IsInstance<PrimExprNode>());
+ if (const auto* str = it.as<runtime::StringObj>()) {
+ prim_expr_args.push_back(StringImmNode::make(str->data));
+ } else {
+ prim_expr_args.push_back(Downcast<PrimExpr>(it));
+ }
+ }
return CallNode::make(type,
- name,
- args,
- static_cast<CallNode::CallType>(call_type),
- func,
- value_index);
+ name,
+ prim_expr_args,
+ static_cast<CallNode::CallType>(call_type),
+ func,
+ value_index);
});
} // namespace tir
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index ea19982..96fc435 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -120,10 +120,10 @@ class IRTransformer final :
Stmt IRTransform(Stmt ir_node,
const runtime::PackedFunc& f_preorder,
const runtime::PackedFunc& f_postorder,
- const Array<PrimExpr>& only_enable) {
+ const Array<runtime::String>& only_enable) {
std::unordered_set<uint32_t> only_type_index;
- for (PrimExpr s : only_enable) {
- only_type_index.insert(Object::TypeKey2Index(s.as<StringImmNode>()->value.c_str()));
+ for (auto s : only_enable) {
+ only_type_index.insert(Object::TypeKey2Index(s.c_str()));
}
IRTransformer transform(f_preorder, f_postorder, only_type_index);
return transform(std::move(ir_node));
diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc
index 773c67d..001c7cf 100644
--- a/src/tir/ir/transform.cc
+++ b/src/tir/ir/transform.cc
@@ -124,7 +124,7 @@ Pass CreatePrimFuncPass(
const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
- const tvm::Array<tvm::PrimExpr>& required) {
+ const tvm::Array<runtime::String>& required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return PrimFuncPass(pass_func, pass_info);
}
diff --git a/src/tir/pass/arg_binder.cc b/src/tir/pass/arg_binder.cc
index 30542ea..c684b9e 100644
--- a/src/tir/pass/arg_binder.cc
+++ b/src/tir/pass/arg_binder.cc
@@ -42,7 +42,8 @@ void BinderAddAssert(PrimExpr cond,
if (!is_one(scond)) {
std::ostringstream os;
os << "Argument " << arg_name << " has an unsatisfied constraint";
- asserts->emplace_back(AssertStmtNode::make(scond, os.str(), EvaluateNode::make(0)));
+ asserts->emplace_back(AssertStmtNode::make(scond, tvm::tir::StringImmNode::make(os.str()),
+ EvaluateNode::make(0)));
}
}
@@ -173,7 +174,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
ndim_err_msg << arg_name
<< ".ndim is expected to equal "
<< buffer->shape.size();
- asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, ndim_err_msg.str(), nop));
+ auto msg = tvm::tir::StringImmNode::make(ndim_err_msg.str());
+ asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop));
// type checks
DataType dtype = buffer->dtype;
std::ostringstream type_err_msg;
@@ -187,7 +189,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
if (!(dtype == DataType::Int(4) ||
dtype == DataType::UInt(4) ||
dtype == DataType::Int(1))) {
- asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop));
+ auto type_msg = tvm::tir::StringImmNode::make(type_err_msg.str());
+ asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop));
+ asserts_.emplace_back(AssertStmtNode::make(cond, type_msg, nop));
}
// data field
if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData),
@@ -245,9 +249,10 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
stride_err_msg << arg_name << ".strides:"
<< " expected to be compact array";
if (conds.size() != 0) {
+ auto stride_msg = tvm::tir::StringImmNode::make(stride_err_msg.str());
Stmt check =
AssertStmtNode::make(arith::ComputeReduce<tir::AndNode>(conds, PrimExpr()),
- stride_err_msg.str(), EvaluateNode::make(0));
+ stride_msg, EvaluateNode::make(0));
check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt());
asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)}));
}
@@ -269,9 +274,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
} else {
std::ostringstream stride_null_err_msg;
stride_null_err_msg << arg_name << ".strides: expected non-null strides.";
- asserts_.emplace_back(
- AssertStmtNode::make(
- NotNode::make(is_null), stride_null_err_msg.str(), nop));
+ asserts_.emplace_back(AssertStmtNode::make(
+ NotNode::make(is_null), tvm::tir::StringImmNode::make(stride_null_err_msg.str()), nop));
for (size_t k = 0; k < buffer->strides.size(); ++k) {
std::ostringstream field_name;
diff --git a/src/tir/pass/hoist_if_then_else.cc b/src/tir/pass/hoist_if_then_else.cc
index 1fd43ff..8bc4620 100644
--- a/src/tir/pass/hoist_if_then_else.cc
+++ b/src/tir/pass/hoist_if_then_else.cc
@@ -159,8 +159,7 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) {
}
});
- return IRTransform(parent_for_stmt, nullptr, replace_target_for,
- {PrimExpr("For")});
+ return IRTransform(parent_for_stmt, nullptr, replace_target_for, {"For"});
}
// Remove IfThenElse node from a For node.
@@ -186,11 +185,9 @@ std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
}
});
- then_for = IRTransform(for_stmt, nullptr, replace_then_case,
- {PrimExpr("IfThenElse")});
+ then_for = IRTransform(for_stmt, nullptr, replace_then_case, {"IfThenElse"});
if (if_stmt.as<IfThenElseNode>()->else_case.defined()) {
- else_for = IRTransform(for_stmt, nullptr, replace_else_case,
- {PrimExpr("IfThenElse")});
+ else_for = IRTransform(for_stmt, nullptr, replace_else_case, {"IfThenElse"});
}
return std::make_pair(then_for, else_for);
@@ -411,7 +408,7 @@ Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) {
*ret = new_for;
}
});
- return IRTransform(stmt, nullptr, replace_top_for, {PrimExpr("For")});
+ return IRTransform(stmt, nullptr, replace_top_for, {runtime::String("For")});
}
Stmt HoistIfThenElse(Stmt stmt) {
diff --git a/src/tir/pass/tensor_core.cc b/src/tir/pass/tensor_core.cc
index 88f7496..dc2df98 100644
--- a/src/tir/pass/tensor_core.cc
+++ b/src/tir/pass/tensor_core.cc
@@ -860,7 +860,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
auto it = matrix_abc_.find(simplify_name(node->name));
CHECK(it != matrix_abc_.end())
<< "Cannot find matrix info for " << node->name;
- auto matrix_abc = "wmma." + it->second;
+ auto matrix_abc = tvm::tir::StringImmNode::make("wmma." + it->second);
Stmt body = this->VisitStmt(op->body);
return AttrStmtNode::make(op->node,
op->attr_key,
diff --git a/src/tir/transforms/bind_device_type.cc b/src/tir/transforms/bind_device_type.cc
index 486f21c..952d663 100644
--- a/src/tir/transforms/bind_device_type.cc
+++ b/src/tir/transforms/bind_device_type.cc
@@ -47,7 +47,8 @@ class DeviceTypeBinder: public StmtExprMutator {
var_ = nullptr;
std::ostringstream os;
os << "device_type need to be " << device_type_;
- return AssertStmtNode::make(op->value == value, os.str(), body);
+ return AssertStmtNode::make(op->value == value, tvm::tir::StringImmNode::make(os.str()),
+ body);
}
}
return StmtExprMutator::VisitStmt_(op);
diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc
index c49b044..b1dd235 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -41,12 +41,13 @@ namespace tvm {
namespace tir {
inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
- return AssertStmtNode::make(lhs == rhs, msg, EvaluateNode::make(0));
+ return AssertStmtNode::make(lhs == rhs, tvm::tir::StringImmNode::make(msg),
+ EvaluateNode::make(0));
}
PrimFunc MakePackedAPI(PrimFunc&& func,
int num_unpacked_args) {
- auto global_symbol = func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";
std::string name_hint = global_symbol;
@@ -140,17 +141,19 @@ PrimFunc MakePackedAPI(PrimFunc&& func,
AssertStmtNode::make(tcode == kTVMOpaqueHandle ||
tcode == kTVMNDArrayHandle ||
tcode == kTVMDLTensorHandle ||
- tcode == kTVMNullptr, msg.str(), nop));
+ tcode == kTVMNullptr,
+ tvm::tir::StringImmNode::make(msg.str()), nop));
} else if (t.is_int() || t.is_uint()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be int";
- seq_check.emplace_back(AssertStmtNode::make(tcode == kDLInt, msg.str(), nop));
+ seq_check.emplace_back(
+ AssertStmtNode::make(tcode == kDLInt, tvm::tir::StringImmNode::make(msg.str()), nop));
} else {
CHECK(t.is_float());
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be float";
seq_check.emplace_back(
- AssertStmtNode::make(tcode == kDLFloat, msg.str(), nop));
+ AssertStmtNode::make(tcode == kDLFloat, tvm::tir::StringImmNode::make(msg.str()), nop));
}
} else {
args.push_back(v_arg);
diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc
index f695b3c..f366353 100644
--- a/src/tir/transforms/remap_thread_axis.cc
+++ b/src/tir/transforms/remap_thread_axis.cc
@@ -76,12 +76,10 @@ class ThreadAxisRewriter : private StmtExprMutator {
};
-PrimFunc RemapThreadAxis(PrimFunc&& f, Map<PrimExpr, IterVar> thread_map) {
+PrimFunc RemapThreadAxis(PrimFunc&& f, Map<runtime::String, IterVar> thread_map) {
std::unordered_map<std::string, IterVar> tmap;
for (const auto& kv : thread_map) {
- const StringImmNode* str = kv.first.as<StringImmNode>();
- CHECK(str != nullptr);
- tmap[str->value] = kv.second;
+ tmap[kv.first] = kv.second;
}
auto thread_axis = f->GetAttr<Array<IterVar> >(tir::attr::kDeviceThreadAxis);
@@ -101,7 +99,7 @@ PrimFunc RemapThreadAxis(PrimFunc&& f, Map<PrimExpr, IterVar> thread_map) {
namespace transform {
-Pass RemapThreadAxis(Map<PrimExpr, IterVar> thread_map) {
+Pass RemapThreadAxis(Map<runtime::String, IterVar> thread_map) {
auto pass_func = [thread_map](PrimFunc f, IRModule m, PassContext ctx) {
return RemapThreadAxis(std::move(f), thread_map);
};
diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc
index ae32bdc..5149d28 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -272,7 +272,7 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModuleNode* device_mod) {
auto target = func->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined())
<< "SplitHostDevice: Require the target attribute";
- auto global_symbol = func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute";
diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc
index f1198e7..063247d 100644
--- a/tests/cpp/container_test.cc
+++ b/tests/cpp/container_test.cc
@@ -261,7 +261,7 @@ TEST(String, empty) {
using namespace std;
String s{"hello"};
CHECK_EQ(s.empty(), false);
- s = "";
+ s = std::string("");
CHECK_EQ(s.empty(), true);
}
diff --git a/tests/python/relay/test_annotate_target.py b/tests/python/relay/test_annotate_target.py
index 7301ef7..dd00d7e 100644
--- a/tests/python/relay/test_annotate_target.py
+++ b/tests/python/relay/test_annotate_target.py
@@ -231,7 +231,7 @@ def test_composite_function():
add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
- add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu"))
+ add_relu = add_relu.with_attr("Composite", "test.add_relu")
# merged function
r = relay.Call(add_relu, [a, b])
@@ -249,7 +249,7 @@ def test_composite_function():
add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
- add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu"))
+ add_relu = add_relu.with_attr("Composite", "test.add_relu")
# merged function
cb_1 = relay.annotation.compiler_begin(a, "test")
diff --git a/tests/python/relay/test_call_graph.py b/tests/python/relay/test_call_graph.py
index 0af55d2..bae077c 100644
--- a/tests/python/relay/test_call_graph.py
+++ b/tests/python/relay/test_call_graph.py
@@ -134,7 +134,7 @@ def test_recursive_func():
func = relay.Function([i],
sb.get(),
ret_type=relay.TensorType([], 'int32'))
- func = func.with_attr("Compiler", tvm.tir.StringImm("a"))
+ func = func.with_attr("Compiler", "a")
mod[sum_up] = func
iarg = relay.var('i', shape=[], dtype='int32')
mod["main"] = relay.Function([iarg], sum_up(iarg))
diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py
index 724e81d..b4496bb 100644
--- a/tests/python/relay/test_external_codegen.py
+++ b/tests/python/relay/test_external_codegen.py
@@ -79,9 +79,8 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
def set_external_func_attr(func, compiler, ext_symbol):
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
- func = func.with_attr("Compiler", tvm.tir.StringImm(compiler))
- func = func.with_attr("global_symbol",
- runtime.container.String(ext_symbol))
+ func = func.with_attr("Compiler", compiler)
+ func = func.with_attr("global_symbol", ext_symbol)
return func
diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py
index dbd5934..5a71023 100644
--- a/tests/python/relay/test_ir_nodes.py
+++ b/tests/python/relay/test_ir_nodes.py
@@ -96,12 +96,14 @@ def test_function():
body = relay.Tuple(tvm.runtime.convert([]))
type_params = tvm.runtime.convert([])
fn = relay.Function(params, body, ret_type, type_params)
- fn = fn.with_attr("test_attribute", tvm.tir.StringImm("value"))
+ fn = fn.with_attr("test_attribute", "value")
+ fn = fn.with_attr("test_attribute1", "value1")
assert fn.params == params
assert fn.body == body
assert fn.type_params == type_params
assert fn.span == None
assert fn.attrs["test_attribute"] == "value"
+ assert fn.attrs["test_attribute1"] == "value1"
str(fn)
check_json_roundtrip(fn)
diff --git a/tests/python/relay/test_ir_structural_equal_hash.py b/tests/python/relay/test_ir_structural_equal_hash.py
index 271960e..e1a0a01 100644
--- a/tests/python/relay/test_ir_structural_equal_hash.py
+++ b/tests/python/relay/test_ir_structural_equal_hash.py
@@ -356,7 +356,7 @@ def test_function_attr():
p00 = relay.subtract(z00, w01)
q00 = relay.multiply(p00, w02)
func0 = relay.Function([x0, w00, w01, w02], q00)
- func0 = func0.with_attr("FuncName", tvm.runtime.container.String("a"))
+ func0 = func0.with_attr("FuncName", "a")
x1 = relay.var('x1', shape=(10, 10))
w10 = relay.var('w10', shape=(10, 10))
@@ -366,7 +366,7 @@ def test_function_attr():
p10 = relay.subtract(z10, w11)
q10 = relay.multiply(p10, w12)
func1 = relay.Function([x1, w10, w11, w12], q10)
- func1 = func1.with_attr("FuncName", tvm.runtime.container.String("b"))
+ func1 = func1.with_attr("FuncName", "b")
assert not consistent_equal(func0, func1)
@@ -698,7 +698,7 @@ def test_fn_attribute():
d = relay.var('d', shape=(10, 10))
add_1 = relay.add(c, d)
add_1_fn = relay.Function([c, d], add_1)
- add_1_fn = add_1_fn.with_attr("TestAttribute", tvm.runtime.container.String("test"))
+ add_1_fn = add_1_fn.with_attr("TestAttribute", "test")
add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType())
assert not consistent_equal(add_1_fn, add_fn)
diff --git a/tests/python/relay/test_pass_inline.py b/tests/python/relay/test_pass_inline.py
index 0f6d539..3b41f07 100644
--- a/tests/python/relay/test_pass_inline.py
+++ b/tests/python/relay/test_pass_inline.py
@@ -209,7 +209,7 @@ def test_call_chain_inline_multiple_levels_extern_compiler():
g11 = relay.GlobalVar("g11")
fn11 = relay.Function([x11], x11)
fn11 = fn11.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn11 = fn11.with_attr("Compiler", tvm.tir.StringImm("a"))
+ fn11 = fn11.with_attr("Compiler", "a")
mod[g11] = fn11
x1 = relay.var("x1", shape=(3, 5))
@@ -244,7 +244,7 @@ def test_call_chain_inline_multiple_levels_extern_compiler():
x11 = relay.var("x11", shape=(3, 5))
fn11 = relay.Function([x11], x11)
fn11 = fn11.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn11 = fn11.with_attr("Compiler", tvm.tir.StringImm("a"))
+ fn11 = fn11.with_attr("Compiler", "a")
x2 = relay.var("x2", shape=(3, 5))
y2 = relay.var("y2", shape=(3, 5))
@@ -367,7 +367,7 @@ def test_recursive_not_called_extern_compiler():
x1 = relay.var("x1", shape=(2, 2))
fn1 = relay.Function([x1], x1)
fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
+ fn1 = fn1.with_attr("Compiler", "a")
g1 = relay.GlobalVar("g1")
mod[g1] = fn1
mod["main"] = relay.Function([x, y], x + y + g1(x))
@@ -380,7 +380,7 @@ def test_recursive_not_called_extern_compiler():
x1 = relay.var("x1", shape=(2, 2))
fn1 = relay.Function([x1], x1)
fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
+ fn1 = fn1.with_attr("Compiler", "a")
mod["main"] = relay.Function([x, y], x + y + fn1(x))
return mod
@@ -446,7 +446,7 @@ def test_globalvar_as_call_arg_extern_compiler():
sb.ret(x1 + y1)
fn1 = relay.Function([x1, y1], sb.get())
fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
+ fn1 = fn1.with_attr("Compiler", "a")
g1 = relay.GlobalVar("g1")
mod[g1] = fn1
@@ -456,7 +456,7 @@ def test_globalvar_as_call_arg_extern_compiler():
sb1.ret(x2 - y2)
fn2 = relay.Function([x2, y2], sb1.get())
fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b"))
+ fn2 = fn2.with_attr("Compiler", "b")
g2 = relay.GlobalVar("g2")
mod[g2] = fn2
@@ -478,7 +478,7 @@ def test_globalvar_as_call_arg_extern_compiler():
sb.ret(x1 + y1)
fn1 = relay.Function([x1, y1], sb.get())
fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
+ fn1 = fn1.with_attr("Compiler", "a")
x2 = relay.var("x2", shape=(3, 5))
y2 = relay.var("y2", shape=(3, 5))
@@ -486,7 +486,7 @@ def test_globalvar_as_call_arg_extern_compiler():
sb1.ret(x2 - y2)
fn2 = relay.Function([x2, y2], sb1.get())
fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b"))
+ fn2 = fn2.with_attr("Compiler", "b")
p0 = relay.var("p0", shape=(3, 5))
p1 = relay.var("p1", shape=(3, 5))
@@ -539,10 +539,10 @@ def test_inline_globalvar_without_args_extern_compiler():
mod = tvm.IRModule({})
fn1 = relay.Function([], relay.const(1))
fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
+ fn1 = fn1.with_attr("Compiler", "a")
fn2 = relay.Function([], relay.const(2))
fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b"))
+ fn2 = fn2.with_attr("Compiler", "b")
g1 = relay.GlobalVar('g1')
g2 = relay.GlobalVar('g2')
mod[g1] = fn1
@@ -555,10 +555,10 @@ def test_inline_globalvar_without_args_extern_compiler():
mod = tvm.IRModule({})
fn1 = relay.Function([], relay.const(1))
fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
+ fn1 = fn1.with_attr("Compiler", "a")
fn2 = relay.Function([], relay.const(2))
fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b"))
+ fn2 = fn2.with_attr("Compiler", "b")
p = relay.var('p', 'bool')
mod['main'] = relay.Function([p], relay.Call(
relay.If(p, fn1, fn2), []))
@@ -787,7 +787,7 @@ def test_callee_not_inline_leaf_inline_extern_compiler():
y0 = relay.var("y0", shape=(3, 5))
fn0 = relay.Function([x0, y0], x0 * y0)
fn0 = fn0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn0 = fn0.with_attr("Compiler", tvm.tir.StringImm("aa"))
+ fn0 = fn0.with_attr("Compiler", "aa")
g0 = relay.GlobalVar("g0")
mod[g0] = fn0
@@ -811,7 +811,7 @@ def test_callee_not_inline_leaf_inline_extern_compiler():
y0 = relay.var("y0", shape=(3, 5))
fn0 = relay.Function([x0, y0], x0 * y0)
fn0 = fn0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn0 = fn0.with_attr("Compiler", tvm.tir.StringImm("aa"))
+ fn0 = fn0.with_attr("Compiler", "aa")
x1 = relay.var("x1", shape=(3, 5))
y1 = relay.var("y1", shape=(3, 5))
diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py
index 110d855..e3c8991 100644
--- a/tests/python/relay/test_pass_merge_composite.py
+++ b/tests/python/relay/test_pass_merge_composite.py
@@ -184,7 +184,7 @@ def test_simple_merge():
add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
- add_relu = add_relu.with_attr("Composite", tir.StringImm("add_relu"))
+ add_relu = add_relu.with_attr("Composite", "add_relu")
# merged function
r = relay.Call(add_relu, [a, b])
@@ -249,8 +249,7 @@ def test_branch_merge():
sub_node = relay.subtract(in_1, in_2)
mul_node = relay.multiply(add_node, sub_node)
add_sub_mul = relay.Function([in_1, in_2], mul_node)
- add_sub_mul = add_sub_mul.with_attr("Composite",
- tir.StringImm("add_sub_mul"))
+ add_sub_mul = add_sub_mul.with_attr("Composite", "add_sub_mul")
# add_sub_mul1 function
in_3 = relay.var('in_3', shape=(10, 10))
@@ -259,8 +258,7 @@ def test_branch_merge():
sub_node_1 = relay.subtract(in_3, in_4)
mul_node_1 = relay.multiply(add_node_1, sub_node_1)
add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1)
- add_sub_mul_1 = add_sub_mul_1.with_attr("Composite",
- tir.StringImm("add_sub_mul"))
+ add_sub_mul_1 = add_sub_mul_1.with_attr("Composite", "add_sub_mul")
# merged function
m_add_sub_mul_1 = relay.Call(add_sub_mul, [a, b])
@@ -319,8 +317,7 @@ def test_reuse_call_merge():
add_node_1 = relay.add(in_1, add_node)
add_node_2 = relay.add(add_node_1, add_node)
add_add_add = relay.Function([in_1, in_2], add_node_2)
- add_add_add = add_add_add.with_attr("Composite",
- tir.StringImm("add_add_add"))
+ add_add_add = add_add_add.with_attr("Composite", "add_add_add")
# merged function
sub_node = relay.subtract(a, b)
@@ -404,7 +401,7 @@ def test_multiple_patterns():
r = relay.nn.relu(bias_node)
conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
conv_bias_add_relu = conv_bias_add_relu.with_attr("Composite",
- tir.StringImm("conv2d_bias_relu"))
+ "conv2d_bias_relu")
# add_relu function
in_4 = relay.var('in_4', shape=(1, 256, 28, 28))
@@ -412,7 +409,7 @@ def test_multiple_patterns():
add_node = relay.add(in_4, in_5)
r = relay.nn.relu(add_node)
add_relu = relay.Function([in_4, in_5], r)
- add_relu = add_relu.with_attr("Composite", tir.StringImm("add_relu"))
+ add_relu = add_relu.with_attr("Composite", "add_relu")
# merged function
conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias])
@@ -481,8 +478,7 @@ def test_merge_order():
out = relay.abs(out)
out = relay.nn.relu(out)
merged_func = relay.Function([x, y], out)
- merged_func = merged_func.with_attr('Composite',
- tir.StringImm(composite_name))
+ merged_func = merged_func.with_attr('Composite', composite_name)
ret = relay.Call(merged_func, [input_1, input_2])
return relay.Function([input_1, input_2], ret)
@@ -547,13 +543,13 @@ def test_parallel_merge():
y = relay.var('y')
branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y))
func_1 = relay.Function([x, y], branch_1)
- func_1 = func_1.with_attr('Composite', tir.StringImm("add_sub_mul"))
+ func_1 = func_1.with_attr('Composite', "add_sub_mul")
call_1 = relay.Call(func_1, [input_1, input_2])
x1 = relay.var('x1')
y1 = relay.var('y1')
branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1))
func_2 = relay.Function([x1, y1], branch_2)
- func_2 = func_2.with_attr('Composite', tir.StringImm("add_sub_mul"))
+ func_2 = func_2.with_attr('Composite', "add_sub_mul")
call_2 = relay.Call(func_2, [input_1, input_2])
out = relay.multiply(call_1, call_2)
return relay.Function([input_1, input_2], out)
@@ -632,14 +628,14 @@ def test_multiple_input_subgraphs():
add_relu_1 = relay.add(x, y)
add_relu_1 = relay.nn.relu(add_relu_1)
add_relu_1 = relay.Function([x, y], add_relu_1)
- add_relu_1 = add_relu_1.with_attr('Composite', tir.StringImm('add_relu'))
+ add_relu_1 = add_relu_1.with_attr('Composite', 'add_relu')
add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]])
x1 = relay.var('x1')
y1 = relay.var('y1')
add_relu_2 = relay.add(x1, y1)
add_relu_2 = relay.nn.relu(add_relu_2)
add_relu_2 = relay.Function([x1, y1], add_relu_2)
- add_relu_2 = add_relu_2.with_attr('Composite', tir.StringImm('add_relu'))
+ add_relu_2 = add_relu_2.with_attr('Composite', 'add_relu')
add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]])
x2 = relay.var('x2')
y2 = relay.var('y2')
@@ -647,7 +643,7 @@ def test_multiple_input_subgraphs():
sub = relay.subtract(x2, y2)
add_sub_mul = relay.multiply(add, sub)
add_sub_mul = relay.Function([x2, y2], add_sub_mul)
- add_sub_mul = add_sub_mul.with_attr('Composite', tir.StringImm('add_sub_mul'))
+ add_sub_mul = add_sub_mul.with_attr('Composite', 'add_sub_mul')
add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2])
return relay.Function(inputs, add_sub_mul_call)
@@ -660,7 +656,7 @@ def test_multiple_input_subgraphs():
add_relu = relay.add(x, y)
add_relu = relay.nn.relu(add_relu)
add_relu = relay.Function([x, y], add_relu)
- add_relu = add_relu.with_attr('Composite', tir.StringImm('add_relu'))
+ add_relu = add_relu.with_attr('Composite', 'add_relu')
add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]])
add_relu_calls.append(add_relu_call)
@@ -720,7 +716,7 @@ def test_tuple_get_item_merge():
tuple_get_item_node = bn_node[0]
relu_node = relay.nn.relu(tuple_get_item_node)
bn_relu = relay.Function([in_1, in_2, in_3, in_4, in_5], relu_node)
- bn_relu = bn_relu.with_attr("Composite", tir.StringImm("bn_relu"))
+ bn_relu = bn_relu.with_attr("Composite", "bn_relu")
# merged function
r = relay.Call(bn_relu, [x, gamma, beta, moving_mean, moving_var])
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 3959613..1968f34 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -24,7 +24,6 @@ import tvm
import tvm.relay.testing
from tvm import relay
from tvm import runtime
-from tvm.runtime import container
from tvm.relay import transform
from tvm.contrib import util
from tvm.relay.op.annotation import compiler_begin, compiler_end
@@ -307,8 +306,8 @@ def test_extern_ccompiler_default_ops():
func = relay.Function([x0, y0], add)
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
- func = func.with_attr("global_symbol", container.String("ccompiler_0"))
+ func = func.with_attr("Compiler", "ccompiler")
+ func = func.with_attr("global_symbol", "ccompiler_0")
glb_0 = relay.GlobalVar("ccompiler_0")
mod[glb_0] = func
add_call = relay.Call(glb_0, [x, y])
@@ -392,8 +391,8 @@ def test_extern_dnnl():
func = relay.Function([data0, input0, input1], out)
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- func = func.with_attr("Compiler", tvm.tir.StringImm("dnnl"))
- func = func.with_attr("global_symbol", container.String("dnnl_0"))
+ func = func.with_attr("Compiler", "dnnl")
+ func = func.with_attr("global_symbol", "dnnl_0")
glb_var = relay.GlobalVar("dnnl_0")
mod = tvm.IRModule()
mod[glb_var] = func
@@ -518,10 +517,8 @@ def test_function_lifting():
bn.astuple())
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- func0 = func0.with_attr("Compiler",
- tvm.tir.StringImm("test_compiler"))
- func0 = func0.with_attr("global_symbol",
- container.String("test_compiler_0"))
+ func0 = func0.with_attr("Compiler", "test_compiler")
+ func0 = func0.with_attr("global_symbol", "test_compiler_0")
gv0 = relay.GlobalVar("test_compiler_0")
mod[gv0] = func0
@@ -537,10 +534,8 @@ def test_function_lifting():
func1 = relay.Function([data1, weight1], conv)
func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- func1 = func1.with_attr("Compiler",
- tvm.tir.StringImm("test_compiler"))
- func1 = func1.with_attr("global_symbol",
- container.String("test_compiler_1"))
+ func1 = func1.with_attr("Compiler", "test_compiler")
+ func1 = func1.with_attr("global_symbol", "test_compiler_1")
gv1 = relay.GlobalVar("test_compiler_1")
mod[gv1] = func1
@@ -611,10 +606,8 @@ def test_function_lifting_inline():
bn.astuple())
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- func0 = func0.with_attr("Compiler",
- tvm.tir.StringImm("test_compiler"))
- func0 = func0.with_attr("global_symbol",
- container.String("test_compiler_0"))
+ func0 = func0.with_attr("Compiler", "test_compiler")
+ func0 = func0.with_attr("global_symbol", "test_compiler_0")
# main function
data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
@@ -648,8 +641,8 @@ def test_constant_propagation():
func = relay.Function([y0], add)
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
- func = func.with_attr("global_symbol", container.String("ccompiler_0"))
+ func = func.with_attr("Compiler", "ccompiler")
+ func = func.with_attr("global_symbol", "ccompiler_0")
glb_0 = relay.GlobalVar("ccompiler_0")
mod[glb_0] = func
add_call = relay.Call(glb_0, [y])
@@ -748,10 +741,8 @@ def test_multiple_outputs():
bn_mean, bn_var], tuple_o)
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- func0 = func0.with_attr("Compiler",
- tvm.tir.StringImm("test_target"))
- func0 = func0.with_attr("global_symbol",
- container.String("test_target_2"))
+ func0 = func0.with_attr("Compiler", "test_target")
+ func0 = func0.with_attr("global_symbol", "test_target_2")
gv0 = relay.GlobalVar("test_target_2")
mod[gv0] = func0
@@ -816,10 +807,8 @@ def test_mixed_single_multiple_outputs():
func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- func1 = func1.with_attr("Compiler",
- tvm.tir.StringImm("test_target"))
- func1 = func1.with_attr("global_symbol",
- container.String("test_target_1"))
+ func1 = func1.with_attr("Compiler", "test_target")
+ func1 = func1.with_attr("global_symbol", "test_target_1")
gv1 = relay.GlobalVar("test_target_1")
mod[gv1] = func1
@@ -831,10 +820,8 @@ def test_mixed_single_multiple_outputs():
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- func0 = func0.with_attr("Compiler",
- tvm.tir.StringImm("test_target"))
- func0 = func0.with_attr("global_symbol",
- container.String("test_target_0"))
+ func0 = func0.with_attr("Compiler", "test_target")
+ func0 = func0.with_attr("global_symbol", "test_target_0")
gv0 = relay.GlobalVar("test_target_0")
mod[gv0] = func0
diff --git a/tests/python/unittest/test_ir_attrs.py b/tests/python/unittest/test_ir_attrs.py
index 8f2e9bb..48495f4 100644
--- a/tests/python/unittest/test_ir_attrs.py
+++ b/tests/python/unittest/test_ir_attrs.py
@@ -41,7 +41,7 @@ def test_dict_attrs():
dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
assert dattr.x.value == 1
datrr = tvm.ir.load_json(tvm.ir.save_json(dattr))
- assert dattr.name.value == "xyz"
+ assert dattr.name == "xyz"
assert isinstance(dattr, tvm.ir.DictAttrs)
assert "name" in dattr
assert dattr["x"].value == 1
diff --git a/topi/include/topi/contrib/cublas.h b/topi/include/topi/contrib/cublas.h
index 66b8a10..ee18dea 100644
--- a/topi/include/topi/contrib/cublas.h
+++ b/topi/include/topi/contrib/cublas.h
@@ -53,7 +53,7 @@ inline Tensor cublas_matmul(const Tensor& lhs,
{ { n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
- PrimExpr("tvm.contrib.cublas.matmul"),
+ runtime::String("tvm.contrib.cublas.matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
@@ -85,7 +85,7 @@ inline Tensor cublas_batch_matmul(const Tensor& lhs,
{ { b, n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
- PrimExpr("tvm.contrib.cublas.batch_matmul"),
+ runtime::String("tvm.contrib.cublas.batch_matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
diff --git a/topi/include/topi/contrib/rocblas.h b/topi/include/topi/contrib/rocblas.h
index 2fcafc7..9fe1825 100644
--- a/topi/include/topi/contrib/rocblas.h
+++ b/topi/include/topi/contrib/rocblas.h
@@ -52,7 +52,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs,
{ { n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
- PrimExpr("tvm.contrib.rocblas.matmul"),
+ runtime::String("tvm.contrib.rocblas.matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),